@@ -103,10 +103,10 @@ def sinkhorn(a, b, M, reg, method='sinkhorn', numItermax=1000,
103103 def sink ():
104104 return sinkhorn_knopp (a , b , M , reg , numItermax = numItermax ,
105105 stopThr = stopThr , verbose = verbose , log = log , ** kwargs )
106- if method .lower () == 'greenkhorn' :
106+ elif method .lower () == 'greenkhorn' :
107107 def sink ():
108108 return greenkhorn (a , b , M , reg , numItermax = numItermax ,
109- stopThr = stopThr , verbose = verbose , log = log )
109+ stopThr = stopThr , verbose = verbose , log = log )
110110 elif method .lower () == 'sinkhorn_stabilized' :
111111 def sink ():
112112 return sinkhorn_stabilized (a , b , M , reg , numItermax = numItermax ,
@@ -417,17 +417,16 @@ def sinkhorn_knopp(a, b, M, reg, numItermax=1000,
417417 return u .reshape ((- 1 , 1 )) * K * v .reshape ((1 , - 1 ))
418418
419419
420-
421- def greenkhorn (a , b , M , reg , numItermax = 10000 , stopThr = 1e-9 , verbose = False , log = False ):
420+ def greenkhorn (a , b , M , reg , numItermax = 10000 , stopThr = 1e-9 , verbose = False , log = False ):
422421 """
423422 Solve the entropic regularization optimal transport problem and return the OT matrix
424-
423+
425424 The algorithm used is based on the paper
426-
425+
427426 Near-linear time approximation algorithms for optimal transport via Sinkhorn iteration
428427 by Jason Altschuler, Jonathan Weed, Philippe Rigollet
429428 appeared at NIPS 2017
430-
429+
431430 which is a stochastic version of the Sinkhorn-Knopp algorithm [2].
432431
433432 The function solves the following optimization problem:
@@ -499,21 +498,21 @@ def greenkhorn(a, b, M, reg, numItermax=10000, stopThr=1e-9, verbose=False, log
499498 ot.optim.cg : General regularized OT
500499
501500 """
502-
501+
503502 i = 0
504-
503+
505504 n = a .shape [0 ]
506505 m = b .shape [0 ]
507-
506+
508507 # Next 3 lines equivalent to K= np.exp(-M/reg), but faster to compute
509508 K = np .empty (M .shape , dtype = M .dtype )
510509 np .divide (M , - reg , out = K )
511510 np .exp (K , out = K )
512-
513- u = np .ones (n )/ n
514- v = np .ones (m )/ m
511+
512+ u = np .ones (n ) / n
513+ v = np .ones (m ) / m
515514 G = np .diag (u )@K @np .diag (v )
516-
515+
517516 one_n = np .ones (n )
518517 one_m = np .ones (m )
519518 viol = G @one_m - a
@@ -524,41 +523,42 @@ def greenkhorn(a, b, M, reg, numItermax=10000, stopThr=1e-9, verbose=False, log
524523 log ['v' ] = v
525524
526525 while i < numItermax and stopThr_val > stopThr :
527- i += 1
526+ i += 1
528527 i_1 = np .argmax (np .abs (viol ))
529528 i_2 = np .argmax (np .abs (viol_2 ))
530529 m_viol_1 = np .abs (viol [i_1 ])
531530 m_viol_2 = np .abs (viol_2 [i_2 ])
532- stopThr_val = np .maximum (m_viol_1 ,m_viol_2 )
533-
531+ stopThr_val = np .maximum (m_viol_1 , m_viol_2 )
532+
534533 if m_viol_1 > m_viol_2 :
535534 old_u = u [i_1 ]
536- u [i_1 ] = a [i_1 ]/ (K [i_1 ,:]@v )
537- G [i_1 ,:] = u [i_1 ]* K [i_1 ,:] * v
535+ u [i_1 ] = a [i_1 ] / (K [i_1 , :]@v )
536+ G [i_1 , :] = u [i_1 ] * K [i_1 , :] * v
538537
539- viol [i_1 ] = u [i_1 ]* K [i_1 ,:]@v - a [i_1 ]
540- viol_2 = viol_2 + ( K [i_1 ,:].T * (u [i_1 ] - old_u )* v )
538+ viol [i_1 ] = u [i_1 ] * K [i_1 , :]@v - a [i_1 ]
539+ viol_2 = viol_2 + (K [i_1 , :].T * (u [i_1 ] - old_u ) * v )
541540
542541 else :
543542 old_v = v [i_2 ]
544- v [i_2 ] = b [i_2 ]/ (K [:,i_2 ].T @u )
545- G [:,i_2 ] = u * K [:,i_2 ]* v [i_2 ]
543+ v [i_2 ] = b [i_2 ] / (K [:, i_2 ].T @u )
544+ G [:, i_2 ] = u * K [:, i_2 ] * v [i_2 ]
546545 #aviol = (G@one_m - a)
547546 #aviol_2 = (G.T@one_n - b)
548- viol = viol + ( - old_v + v [i_2 ])* K [:,i_2 ]* u
549- viol_2 [i_2 ] = v [i_2 ]* K [:,i_2 ]@u - b [i_2 ]
550-
547+ viol = viol + (- old_v + v [i_2 ]) * K [:, i_2 ] * u
548+ viol_2 [i_2 ] = v [i_2 ] * K [:, i_2 ]@u - b [i_2 ]
549+
551550 #print('b',np.max(abs(aviol -viol)),np.max(abs(aviol_2 - viol_2)))
552-
551+
553552 if log :
554553 log ['u' ] = u
555554 log ['v' ] = v
556-
555+
557556 if log :
558- return G ,log
557+ return G , log
559558 else :
560559 return G
561560
561+
562562def sinkhorn_stabilized (a , b , M , reg , numItermax = 1000 , tau = 1e3 , stopThr = 1e-9 ,
563563 warmstart = None , verbose = False , print_period = 20 , log = False , ** kwargs ):
564564 """
0 commit comments