77
88def sinkhorn (a ,b , M , reg ,method = 'sinkhorn' , numItermax = 1000 , stopThr = 1e-9 , verbose = False , log = False ,** kwargs ):
99 u"""
10- Solve the entropic regularization optimal transport problem
10+ Solve the entropic regularization optimal transport problem and return the OT matrix
1111
1212 The function solves the following optimization problem:
1313
@@ -107,12 +107,9 @@ def sinkhorn(a,b, M, reg,method='sinkhorn', numItermax = 1000, stopThr=1e-9, ver
107107 return sink ()
108108
109109
110-
111-
112-
113110def sinkhorn_knopp (a ,b , M , reg , numItermax = 1000 , stopThr = 1e-9 , verbose = False , log = False ,** kwargs ):
114111 """
115- Solve the entropic regularization optimal transport problem
112+ Solve the entropic regularization optimal transport problem and return the OT matrix
116113
117114 The function solves the following optimization problem:
118115
@@ -188,22 +185,35 @@ def sinkhorn_knopp(a,b, M, reg, numItermax = 1000, stopThr=1e-9, verbose=False,
188185 a = np .asarray (a ,dtype = np .float64 )
189186 b = np .asarray (b ,dtype = np .float64 )
190187 M = np .asarray (M ,dtype = np .float64 )
188+
191189
192190 if len (a )== 0 :
193191 a = np .ones ((M .shape [0 ],),dtype = np .float64 )/ M .shape [0 ]
194192 if len (b )== 0 :
195193 b = np .ones ((M .shape [1 ],),dtype = np .float64 )/ M .shape [1 ]
194+
196195
197196 # init data
198197 Nini = len (a )
199198 Nfin = len (b )
199+
200+ if len (b .shape )> 1 :
201+ nbb = b .shape [1 ]
202+ else :
203+ nbb = 0
204+
200205
201206 if log :
202207 log = {'err' :[]}
203208
204209 # we assume that no distances are null except those of the diagonal of distances
205- u = np .ones (Nini )/ Nini
206- v = np .ones (Nfin )/ Nfin
210+ if nbb :
211+ u = np .ones ((Nini ,nbb ))/ Nini
212+ v = np .ones ((Nfin ,nbb ))/ Nfin
213+ else :
214+ u = np .ones (Nini )/ Nini
215+ v = np .ones (Nfin )/ Nfin
216+
207217
208218 #print(reg)
209219
@@ -231,8 +241,11 @@ def sinkhorn_knopp(a,b, M, reg, numItermax = 1000, stopThr=1e-9, verbose=False,
231241 break
232242 if cpt % 10 == 0 :
233243 # we can speed up the process by checking for the error only all the 10th iterations
234- transp = u .reshape (- 1 , 1 ) * (K * v )
235- err = np .linalg .norm ((np .sum (transp ,axis = 0 )- b ))** 2
244+ if nbb :
245+ err = np .sum ((u - uprev )** 2 )/ np .sum ((u )** 2 )+ np .sum ((v - vprev )** 2 )/ np .sum ((v )** 2 )
246+ else :
247+ transp = u .reshape (- 1 , 1 ) * (K * v )
248+ err = np .linalg .norm ((np .sum (transp ,axis = 0 )- b ))** 2
236249 if log :
237250 log ['err' ].append (err )
238251
@@ -244,12 +257,23 @@ def sinkhorn_knopp(a,b, M, reg, numItermax = 1000, stopThr=1e-9, verbose=False,
244257 if log :
245258 log ['u' ]= u
246259 log ['v' ]= v
247-
248- #print('err=',err,' cpt=',cpt)
249- if log :
250- return u .reshape ((- 1 ,1 ))* K * v .reshape ((1 ,- 1 )),log
251- else :
252- return u .reshape ((- 1 ,1 ))* K * v .reshape ((1 ,- 1 ))
260+
261+ if nbb : #return only loss
262+ res = np .zeros ((nbb ))
263+ for i in range (nbb ):
264+ res [i ]= np .sum (u [:,i ].reshape ((- 1 ,1 ))* K * v [:,i ].reshape ((1 ,- 1 ))* M )
265+ if log :
266+ return res ,log
267+ else :
268+ return res
269+
270+ else : # return OT matrix
271+
272+ if log :
273+ return u .reshape ((- 1 ,1 ))* K * v .reshape ((1 ,- 1 )),log
274+ else :
275+ return u .reshape ((- 1 ,1 ))* K * v .reshape ((1 ,- 1 ))
276+
253277
254278def sinkhorn_stabilized (a ,b , M , reg , numItermax = 1000 ,tau = 1e3 , stopThr = 1e-9 ,warmstart = None , verbose = False ,print_period = 20 , log = False ,** kwargs ):
255279 """
0 commit comments