Skip to content

Commit a632c40

Browse files
committed
make sinkhorn more general with method selection
1 parent 05da582 commit a632c40

File tree

1 file changed

+107
-3
lines changed

1 file changed

+107
-3
lines changed

ot/bregman.py

Lines changed: 107 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,112 @@
55

66
import numpy as np
77

8+
def sinkhorn(a,b, M, reg,method='sinkhorn', numItermax = 1000, stopThr=1e-9, verbose=False, log=False,**kwargs):
9+
u"""
10+
Solve the entropic regularization optimal transport problem
11+
12+
The function solves the following optimization problem:
13+
14+
.. math::
15+
\gamma = arg\min_\gamma <\gamma,M>_F + reg\cdot\Omega(\gamma)
16+
17+
s.t. \gamma 1 = a
18+
19+
\gamma^T 1= b
20+
21+
\gamma\geq 0
22+
where :
23+
24+
- M is the (ns,nt) metric cost matrix
25+
- :math:`\Omega` is the entropic regularization term :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
26+
- a and b are source and target weights (sum to 1)
27+
28+
The algorithm used for solving the problem is the Sinkhorn-Knopp matrix scaling algorithm as proposed in [2]_
29+
30+
31+
Parameters
32+
----------
33+
a : np.ndarray (ns,)
34+
samples weights in the source domain
35+
b : np.ndarray (nt,)
36+
samples in the target domain
37+
M : np.ndarray (ns,nt)
38+
loss matrix
39+
reg : float
40+
Regularization term >0
41+
method : str
42+
method used for the solver either 'sinkhorn', 'sinkhorn_stabilized' or
43+
'sinkhorn_epsilon_scaling', see those function for specific parameters
44+
numItermax : int, optional
45+
Max number of iterations
46+
stopThr : float, optional
47+
Stop threshol on error (>0)
48+
verbose : bool, optional
49+
Print information along iterations
50+
log : bool, optional
51+
record log if True
52+
53+
54+
Returns
55+
-------
56+
gamma : (ns x nt) ndarray
57+
Optimal transportation matrix for the given parameters
58+
log : dict
59+
log dictionary return only if log==True in parameters
60+
61+
Examples
62+
--------
63+
64+
>>> import ot
65+
>>> a=[.5,.5]
66+
>>> b=[.5,.5]
67+
>>> M=[[0.,1.],[1.,0.]]
68+
>>> ot.sinkhorn(a,b,M,1)
69+
array([[ 0.36552929, 0.13447071],
70+
[ 0.13447071, 0.36552929]])
71+
72+
73+
References
74+
----------
75+
76+
.. [2] M. Cuturi, Sinkhorn Distances : Lightspeed Computation of Optimal Transport, Advances in Neural Information Processing Systems (NIPS) 26, 2013
77+
78+
.. [9] Schmitzer, B. (2016). Stabilized Sparse Scaling Algorithms for Entropy Regularized Transport Problems. arXiv preprint arXiv:1610.06519.
79+
80+
.. [10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016). Scaling algorithms for unbalanced transport problems. arXiv preprint arXiv:1607.05816.
81+
82+
83+
84+
See Also
85+
--------
86+
ot.lp.emd : Unregularized OT
87+
ot.optim.cg : General regularized OT
88+
ot.bregman.sinkhorn_knopp : Classic Sinkhorn [2]
89+
ot.bregman.sinkhorn_stabilized: Stabilized sinkhorn [9][10]
90+
ot.bregman.sinkhorn_epsilon_scaling: Sinkhorn with epslilon scaling [9][10]
91+
92+
"""
93+
94+
if method.lower()=='sinkhorn':
95+
sink= lambda: sinkhorn_knopp(a,b, M, reg,numItermax=numItermax,
96+
stopThr=stopThr, verbose=verbose, log=log,**kwargs)
97+
elif method.lower()=='sinkhorn_stabilized':
98+
sink= lambda: sinkhorn_stabilized(a,b, M, reg,numItermax=numItermax,
99+
stopThr=stopThr, verbose=verbose, log=log, **kwargs)
100+
elif method.lower()=='sinkhorn_epsilon_scaling':
101+
sink= lambda: sinkhorn_epsilon_scaling(a,b, M, reg,numItermax=numItermax,
102+
stopThr=stopThr, verbose=verbose, log=log, **kwargs)
103+
else:
104+
print('Warning : unknown method using classic Sinkhorn Knopp')
105+
sink= lambda: sinkhorn_knopp(a,b, M, reg, **kwargs)
106+
107+
return sink()
108+
109+
110+
111+
8112

9-
def sinkhorn(a,b, M, reg, numItermax = 1000, stopThr=1e-9, verbose=False, log=False):
113+
def sinkhorn_knopp(a,b, M, reg, numItermax = 1000, stopThr=1e-9, verbose=False, log=False,**kwargs):
10114
"""
11115
Solve the entropic regularization optimal transport problem
12116
@@ -147,7 +251,7 @@ def sinkhorn(a,b, M, reg, numItermax = 1000, stopThr=1e-9, verbose=False, log=Fa
147251
else:
148252
return u.reshape((-1,1))*K*v.reshape((1,-1))
149253

150-
def sinkhorn_stabilized(a,b, M, reg, numItermax = 1000,tau=1e3, stopThr=1e-9,warmstart=None, verbose=False,print_period=20, log=False):
254+
def sinkhorn_stabilized(a,b, M, reg, numItermax = 1000,tau=1e3, stopThr=1e-9,warmstart=None, verbose=False,print_period=20, log=False,**kwargs):
151255
"""
152256
Solve the entropic regularization OT problem with log stabilization
153257
@@ -331,7 +435,7 @@ def get_Gamma(alpha,beta,u,v):
331435
else:
332436
return get_Gamma(alpha,beta,u,v)
333437

334-
def sinkhorn_epsilon_scaling(a,b, M, reg, numItermax = 100, epsilon0=1e4, numInnerItermax = 100,tau=1e3, stopThr=1e-9,warmstart=None, verbose=False,print_period=10, log=False):
438+
def sinkhorn_epsilon_scaling(a,b, M, reg, numItermax = 100, epsilon0=1e4, numInnerItermax = 100,tau=1e3, stopThr=1e-9,warmstart=None, verbose=False,print_period=10, log=False,**kwargs):
335439
"""
336440
Solve the entropic regularization optimal transport problem with log
337441
stabilization and epsilon scaling.

0 commit comments

Comments
 (0)