Skip to content

Commit a7bed09

Browse files
committed
implement parallel sinkhorn
1 parent 3af9b06 commit a7bed09

File tree

3 files changed

+64
-23
lines changed

3 files changed

+64
-23
lines changed

examples/plot_OT_1D.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@
5050
#%% Sinkhorn
5151

5252
lambd=1e-3
53-
Gs=ot.sinkhorn(a,b,M,lambd)
53+
Gs=ot.sinkhorn(a,b,M,lambd,verbose=True)
5454

5555
pl.figure(4)
5656
ot.plot.plot1D_mat(a,b,Gs,'OT matrix Sinkhorn')

examples/plot_compute_emd.py

Lines changed: 24 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -32,10 +32,11 @@
3232
for i,m in enumerate(lst_m):
3333
B[:,i]=gauss(n,m=m,s=5)
3434

35-
# loss matrix
35+
# loss matrix and normalization
3636
M=ot.dist(x.reshape((n,1)),x.reshape((n,1)),'euclidean')
37+
M/=M.max()
3738
M2=ot.dist(x.reshape((n,1)),x.reshape((n,1)),'sqeuclidean')
38-
39+
M2/=M2.max()
3940
#%% plot the distributions
4041

4142
pl.figure(1)
@@ -46,12 +47,28 @@
4647
pl.plot(x,B,label='Target distributions')
4748
pl.title('Target distributions')
4849

49-
#%% plot distributions and loss matrix
50+
#%% Compute and plot distributions and loss matrix
51+
52+
d_emd=ot.emd2(a,B,M) # direct computation of EMD
53+
d_emd2=ot.emd2(a,B,M2) # direct computation of EMD with loss M2
54+
5055

51-
emd=ot.emd2(a,B,M)
52-
emd2=ot.emd2(a,B,M2)
5356
pl.figure(2)
54-
pl.plot(emd,label='Euclidean loss')
55-
pl.plot(emd,label='Squared Euclidean loss')
57+
pl.plot(d_emd,label='Euclidean EMD')
58+
pl.plot(d_emd2,label='Squared Euclidean EMD')
59+
pl.title('EMD distances')
5660
pl.legend()
5761

62+
#%%
63+
reg=1e-2
64+
d_sinkhorn=ot.sinkhorn(a,B,M,reg)
65+
d_sinkhorn2=ot.sinkhorn(a,B,M2,reg)
66+
67+
pl.figure(2)
68+
pl.clf()
69+
pl.plot(d_emd,label='Euclidean EMD')
70+
pl.plot(d_emd2,label='Squared Euclidean EMD')
71+
pl.plot(d_sinkhorn,label='Euclidean Sinkhorn')
72+
pl.plot(d_emd2,label='Squared Euclidean Sinkhorn')
73+
pl.title('EMD distances')
74+
pl.legend()

ot/bregman.py

Lines changed: 39 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77

88
def sinkhorn(a,b, M, reg,method='sinkhorn', numItermax = 1000, stopThr=1e-9, verbose=False, log=False,**kwargs):
99
u"""
10-
Solve the entropic regularization optimal transport problem
10+
Solve the entropic regularization optimal transport problem and return the OT matrix
1111
1212
The function solves the following optimization problem:
1313
@@ -107,12 +107,9 @@ def sinkhorn(a,b, M, reg,method='sinkhorn', numItermax = 1000, stopThr=1e-9, ver
107107
return sink()
108108

109109

110-
111-
112-
113110
def sinkhorn_knopp(a,b, M, reg, numItermax = 1000, stopThr=1e-9, verbose=False, log=False,**kwargs):
114111
"""
115-
Solve the entropic regularization optimal transport problem
112+
Solve the entropic regularization optimal transport problem and return the OT matrix (or only the loss values, one per column of b, when b is 2-D)
116113
117114
The function solves the following optimization problem:
118115
@@ -188,22 +185,35 @@ def sinkhorn_knopp(a,b, M, reg, numItermax = 1000, stopThr=1e-9, verbose=False,
188185
a=np.asarray(a,dtype=np.float64)
189186
b=np.asarray(b,dtype=np.float64)
190187
M=np.asarray(M,dtype=np.float64)
188+
191189

192190
if len(a)==0:
193191
a=np.ones((M.shape[0],),dtype=np.float64)/M.shape[0]
194192
if len(b)==0:
195193
b=np.ones((M.shape[1],),dtype=np.float64)/M.shape[1]
194+
196195

197196
# init data
198197
Nini = len(a)
199198
Nfin = len(b)
199+
200+
if len(b.shape)>1:
201+
nbb=b.shape[1]
202+
else:
203+
nbb=0
204+
200205

201206
if log:
202207
log={'err':[]}
203208

204209
# we assume that no distances are null except those of the diagonal of distances
205-
u = np.ones(Nini)/Nini
206-
v = np.ones(Nfin)/Nfin
210+
if nbb:
211+
u = np.ones((Nini,nbb))/Nini
212+
v = np.ones((Nfin,nbb))/Nfin
213+
else:
214+
u = np.ones(Nini)/Nini
215+
v = np.ones(Nfin)/Nfin
216+
207217

208218
#print(reg)
209219

@@ -231,8 +241,11 @@ def sinkhorn_knopp(a,b, M, reg, numItermax = 1000, stopThr=1e-9, verbose=False,
231241
break
232242
if cpt%10==0:
233243
# we can speed up the process by checking the error only every 10th iteration
234-
transp = u.reshape(-1, 1) * (K * v)
235-
err = np.linalg.norm((np.sum(transp,axis=0)-b))**2
244+
if nbb:
245+
err = np.sum((u-uprev)**2)/np.sum((u)**2)+np.sum((v-vprev)**2)/np.sum((v)**2)
246+
else:
247+
transp = u.reshape(-1, 1) * (K * v)
248+
err = np.linalg.norm((np.sum(transp,axis=0)-b))**2
236249
if log:
237250
log['err'].append(err)
238251

@@ -244,12 +257,23 @@ def sinkhorn_knopp(a,b, M, reg, numItermax = 1000, stopThr=1e-9, verbose=False,
244257
if log:
245258
log['u']=u
246259
log['v']=v
247-
248-
#print('err=',err,' cpt=',cpt)
249-
if log:
250-
return u.reshape((-1,1))*K*v.reshape((1,-1)),log
251-
else:
252-
return u.reshape((-1,1))*K*v.reshape((1,-1))
260+
261+
if nbb: #return only loss
262+
res=np.zeros((nbb))
263+
for i in range(nbb):
264+
res[i]=np.sum(u[:,i].reshape((-1,1))*K*v[:,i].reshape((1,-1))*M)
265+
if log:
266+
return res,log
267+
else:
268+
return res
269+
270+
else: # return OT matrix
271+
272+
if log:
273+
return u.reshape((-1,1))*K*v.reshape((1,-1)),log
274+
else:
275+
return u.reshape((-1,1))*K*v.reshape((1,-1))
276+
253277

254278
def sinkhorn_stabilized(a,b, M, reg, numItermax = 1000,tau=1e3, stopThr=1e-9,warmstart=None, verbose=False,print_period=20, log=False,**kwargs):
255279
"""

0 commit comments

Comments
 (0)