@@ -480,7 +480,7 @@ def greenkhorn(a, b, M, reg, numItermax=10000, stopThr=1e-9, verbose=False, log=
480480 >>> a=[.5,.5]
481481 >>> b=[.5,.5]
482482 >>> M=[[0.,1.],[1.,0.]]
483- >>> ot.sinkhorn (a,b,M,1)
483+ >>> ot.bregman.greenkhorn (a,b,M,1)
484484 array([[ 0.36552929, 0.13447071],
485485 [ 0.13447071, 0.36552929]])
486486
@@ -505,18 +505,18 @@ def greenkhorn(a, b, M, reg, numItermax=10000, stopThr=1e-9, verbose=False, log=
505505 m = b .shape [0 ]
506506
507507 # Next 3 lines equivalent to K= np.exp(-M/reg), but faster to compute
508- K = np .empty ( M . shape , dtype = M . dtype )
508+ K = np .empty_like ( M )
509509 np .divide (M , - reg , out = K )
510510 np .exp (K , out = K )
511511
512- u = np .ones ( n ) / n
513- v = np .ones ( m ) / m
514- G = np .diag ( u )@ K @ np .diag ( v )
512+ u = np .full ( n , 1. / n )
513+ v = np .full ( m , 1. / m )
514+ G = u [:, np .newaxis ] * K * v [ np .newaxis , :]
515515
516516 one_n = np .ones (n )
517517 one_m = np .ones (m )
518- viol = G @ one_m - a
519- viol_2 = G .T @ one_n - b
518+ viol = G . sum ( 1 ) - a
519+ viol_2 = G .sum ( 0 ) - b
520520 stopThr_val = 1
521521 if log :
522522 log ['u' ] = u
@@ -532,26 +532,26 @@ def greenkhorn(a, b, M, reg, numItermax=10000, stopThr=1e-9, verbose=False, log=
532532
533533 if m_viol_1 > m_viol_2 :
534534 old_u = u [i_1 ]
535- u [i_1 ] = a [i_1 ] / (K [i_1 , :]@ v )
535+ u [i_1 ] = a [i_1 ] / (K [i_1 , :]. dot ( v ) )
536536 G [i_1 , :] = u [i_1 ] * K [i_1 , :] * v
537537
538- viol [i_1 ] = u [i_1 ] * K [i_1 , :]@ v - a [i_1 ]
538+ viol [i_1 ] = u [i_1 ] * K [i_1 , :]. dot ( v ) - a [i_1 ]
539539 viol_2 = viol_2 + (K [i_1 , :].T * (u [i_1 ] - old_u ) * v )
540540
541541 else :
542542 old_v = v [i_2 ]
543- v [i_2 ] = b [i_2 ] / (K [:, i_2 ].T @ u )
543+ v [i_2 ] = b [i_2 ] / (K [:, i_2 ].T . dot ( u ) )
544544 G [:, i_2 ] = u * K [:, i_2 ] * v [i_2 ]
545545 #aviol = (G@one_m - a)
546546 #aviol_2 = (G.T@one_n - b)
547547 viol = viol + (- old_v + v [i_2 ]) * K [:, i_2 ] * u
548- viol_2 [i_2 ] = v [i_2 ] * K [:, i_2 ]@ u - b [i_2 ]
548+ viol_2 [i_2 ] = v [i_2 ] * K [:, i_2 ]. dot ( u ) - b [i_2 ]
549549
550550 #print('b',np.max(abs(aviol -viol)),np.max(abs(aviol_2 - viol_2)))
551551
552- if log :
553- log ['u' ] = u
554- log ['v' ] = v
552+ if log :
553+ log ['u' ] = u
554+ log ['v' ] = v
555555
556556 if log :
557557 return G , log
0 commit comments