Skip to content

Commit 653fd00

Browse files
committed
adding greenkhorn
1 parent 4367a34 commit 653fd00

File tree

2 files changed

+153
-2
lines changed

2 files changed

+153
-2
lines changed

ot/bregman.py

Lines changed: 150 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ def sinkhorn(a, b, M, reg, method='sinkhorn', numItermax=1000,
4747
reg : float
4848
Regularization term >0
4949
method : str
50-
method used for the solver either 'sinkhorn', 'sinkhorn_stabilized' or
50+
method used for the solver either 'sinkhorn', 'greenkhorn', 'sinkhorn_stabilized' or
5151
'sinkhorn_epsilon_scaling', see those function for specific parameters
5252
numItermax : int, optional
5353
Max number of iterations
@@ -103,6 +103,10 @@ def sinkhorn(a, b, M, reg, method='sinkhorn', numItermax=1000,
103103
def sink():
104104
return sinkhorn_knopp(a, b, M, reg, numItermax=numItermax,
105105
stopThr=stopThr, verbose=verbose, log=log, **kwargs)
106+
if method.lower() == 'greenkhorn':
107+
def sink():
108+
return greenkhorn(a, b, M, reg, numItermax=numItermax,
109+
stopThr=stopThr, verbose=verbose, log=log)
106110
elif method.lower() == 'sinkhorn_stabilized':
107111
def sink():
108112
return sinkhorn_stabilized(a, b, M, reg, numItermax=numItermax,
@@ -197,13 +201,16 @@ def sinkhorn2(a, b, M, reg, method='sinkhorn', numItermax=1000,
197201
198202
.. [10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016). Scaling algorithms for unbalanced transport problems. arXiv preprint arXiv:1607.05816.
199203
204+
.. [21] Altschuler J., Weed J., Rigollet P. : Near-linear time approximation algorithms for optimal transport via Sinkhorn iteration, Advances in Neural Information Processing Systems (NIPS) 30, 2017
205+
200206
201207
202208
See Also
203209
--------
204210
ot.lp.emd : Unregularized OT
205211
ot.optim.cg : General regularized OT
206212
ot.bregman.sinkhorn_knopp : Classic Sinkhorn [2]
213+
ot.bregman.greenkhorn : Greenkhorn [21]
207214
ot.bregman.sinkhorn_stabilized: Stabilized sinkhorn [9][10]
208215
ot.bregman.sinkhorn_epsilon_scaling: Sinkhorn with epsilon scaling [9][10]
209216
@@ -410,6 +417,148 @@ def sinkhorn_knopp(a, b, M, reg, numItermax=1000,
410417
return u.reshape((-1, 1)) * K * v.reshape((1, -1))
411418

412419

420+
421+
def greenkhorn(a, b, M, reg, numItermax=10000, stopThr=1e-9, verbose=False, log=False):
    r"""
    Solve the entropic regularization optimal transport problem and return the OT matrix

    The algorithm used is based on the paper

    Near-linear time approximation algorithms for optimal transport via Sinkhorn iteration
        by Jason Altschuler, Jonathan Weed, Philippe Rigollet
        appeared at NIPS 2017

    which is a greedy, coordinate-wise version of the Sinkhorn-Knopp
    algorithm [2]: at each iteration only the single row or column whose
    marginal is the most violated is rescaled.

    The function solves the following optimization problem:

    .. math::
        \gamma = arg\min_\gamma <\gamma,M>_F + reg\cdot\Omega(\gamma)

        s.t. \gamma 1 = a

             \gamma^T 1= b

             \gamma\geq 0
    where :

    - M is the (ns,nt) metric cost matrix
    - :math:`\Omega` is the entropic regularization term :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
    - a and b are source and target weights (sum to 1)


    Parameters
    ----------
    a : np.ndarray (ns,)
        samples weights in the source domain
    b : np.ndarray (nt,)
        samples weights in the target domain (a single histogram; a matrix
        of several targets is not handled by this solver)
    M : np.ndarray (ns,nt)
        loss matrix
    reg : float
        Regularization term >0
    numItermax : int, optional
        Max number of iterations (one iteration updates one row or column)
    stopThr : float, optional
        Stop threshold on the largest absolute marginal violation (>0)
    verbose : bool, optional
        If True, print a message when the iteration budget is exhausted
        before the stop threshold is reached
    log : bool, optional
        record log if True


    Returns
    -------
    gamma : (ns x nt) ndarray
        Optimal transportation matrix for the given parameters
    log : dict
        log dictionary returned only if log==True in parameters; it holds
        the final scaling vectors under the keys 'u' and 'v'

    Examples
    --------

    >>> import ot
    >>> a=[.5,.5]
    >>> b=[.5,.5]
    >>> M=[[0.,1.],[1.,0.]]
    >>> ot.bregman.greenkhorn(a,b,M,1)
    array([[ 0.36552929,  0.13447071],
           [ 0.13447071,  0.36552929]])


    References
    ----------

    .. [2] M. Cuturi, Sinkhorn Distances : Lightspeed Computation of Optimal Transport, Advances in Neural Information Processing Systems (NIPS) 26, 2013
    .. [21] J. Altschuler, J.Weed, P. Rigollet : Near-linear time approximation algorithms for optimal transport via Sinkhorn iteration, Advances in Neural Information Processing Systems (NIPS) 30, 2017


    See Also
    --------
    ot.lp.emd : Unregularized OT
    ot.optim.cg : General regularized OT

    """

    # Accept lists as well as arrays (as in the docstring example) and force
    # a floating dtype so the in-place kernel computation below cannot fail
    # on an integer cost matrix.
    a = np.asarray(a, dtype=np.float64)
    b = np.asarray(b, dtype=np.float64)
    M = np.asarray(M, dtype=np.float64)

    n = a.shape[0]
    m = b.shape[0]

    # Next 3 lines equivalent to K= np.exp(-M/reg), but faster to compute
    K = np.empty(M.shape, dtype=M.dtype)
    np.divide(M, -reg, out=K)
    np.exp(K, out=K)

    # Uniform initial scalings; G = diag(u) K diag(v) is built by
    # broadcasting instead of materialising two dense diagonal matrices.
    u = np.full(n, 1.0 / n)
    v = np.full(m, 1.0 / m)
    G = u[:, np.newaxis] * K * v[np.newaxis, :]

    # Signed violations of the row (source) and column (target) marginals.
    viol = G.sum(axis=1) - a
    viol_2 = G.sum(axis=0) - b
    stopThr_val = 1

    i = 0
    while i < numItermax and stopThr_val > stopThr:
        i += 1

        # Most violated row / column and the size of each violation. The
        # stopping value is measured here, i.e. before this iteration's update.
        i_1 = np.argmax(np.abs(viol))
        i_2 = np.argmax(np.abs(viol_2))
        m_viol_1 = np.abs(viol[i_1])
        m_viol_2 = np.abs(viol_2[i_2])
        stopThr_val = np.maximum(m_viol_1, m_viol_2)

        if m_viol_1 > m_viol_2:
            # Rescale row i_1 so that its sum matches a[i_1].
            old_u = u[i_1]
            row_dot = K[i_1, :].dot(v)
            u[i_1] = a[i_1] / row_dot
            G[i_1, :] = u[i_1] * K[i_1, :] * v

            viol[i_1] = u[i_1] * row_dot - a[i_1]
            # Only row i_1 changed, so every column sum moves by the change
            # of that row: an O(nt) update instead of recomputing G.T 1.
            viol_2 = viol_2 + K[i_1, :] * (u[i_1] - old_u) * v
        else:
            # Rescale column i_2 so that its sum matches b[i_2].
            old_v = v[i_2]
            col_dot = K[:, i_2].dot(u)
            v[i_2] = b[i_2] / col_dot
            G[:, i_2] = u * K[:, i_2] * v[i_2]

            # Symmetric O(ns) incremental update of the row violations.
            viol = viol + (v[i_2] - old_v) * K[:, i_2] * u
            viol_2[i_2] = v[i_2] * col_dot - b[i_2]

    if verbose and stopThr_val > stopThr:
        print('Warning: greenkhorn did not converge in {} iterations '
              '(max violation {:.3e})'.format(i, float(stopThr_val)))

    if log:
        # `log` is only a boolean flag: build the dictionary here rather
        # than indexing into the flag itself.
        return G, {'u': u, 'v': v}
    return G
561+
413562
def sinkhorn_stabilized(a, b, M, reg, numItermax=1000, tau=1e3, stopThr=1e-9,
414563
warmstart=None, verbose=False, print_period=20, log=False, **kwargs):
415564
"""

test/test_bregman.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -71,12 +71,14 @@ def test_sinkhorn_variants():
7171
Ges = ot.sinkhorn(
7272
u, u, M, 1, method='sinkhorn_epsilon_scaling', stopThr=1e-10)
7373
Gerr = ot.sinkhorn(u, u, M, 1, method='do_not_exists', stopThr=1e-10)
74+
G_green = ot.sinkhorn(u, u, M, 1, method='greenkhorn', stopThr=1e-10)
7475

7576
# check values
7677
np.testing.assert_allclose(G0, Gs, atol=1e-05)
7778
np.testing.assert_allclose(G0, Ges, atol=1e-05)
7879
np.testing.assert_allclose(G0, Gerr)
79-
80+
np.testing.assert_allclose(G0, G_green, atol = 1e-32)
81+
print(G0,G_green)
8082

8183
def test_bary():
8284

0 commit comments

Comments
 (0)