Skip to content

Commit 7ffd4fe

Browse files
committed
remove @ for python compatibility+ comments alexandre
1 parent f3433fd commit 7ffd4fe

File tree

1 file changed

+14
-14
lines changed

1 file changed

+14
-14
lines changed

ot/bregman.py

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -480,7 +480,7 @@ def greenkhorn(a, b, M, reg, numItermax=10000, stopThr=1e-9, verbose=False, log=
480480
>>> a=[.5,.5]
481481
>>> b=[.5,.5]
482482
>>> M=[[0.,1.],[1.,0.]]
483-
>>> ot.sinkhorn(a,b,M,1)
483+
>>> ot.bregman.greenkhorn(a,b,M,1)
484484
array([[ 0.36552929, 0.13447071],
485485
[ 0.13447071, 0.36552929]])
486486
@@ -505,18 +505,18 @@ def greenkhorn(a, b, M, reg, numItermax=10000, stopThr=1e-9, verbose=False, log=
505505
m = b.shape[0]
506506

507507
# Next 3 lines equivalent to K= np.exp(-M/reg), but faster to compute
508-
K = np.empty(M.shape, dtype=M.dtype)
508+
K = np.empty_like(M)
509509
np.divide(M, -reg, out=K)
510510
np.exp(K, out=K)
511511

512-
u = np.ones(n) / n
513-
v = np.ones(m) / m
514-
G = np.diag(u)@K@np.diag(v)
512+
u = np.full(n, 1. / n)
513+
v = np.full(m, 1. / m)
514+
G = u[:, np.newaxis] * K * v[np.newaxis, :]
515515

516516
one_n = np.ones(n)
517517
one_m = np.ones(m)
518-
viol = G@one_m - a
519-
viol_2 = G.T@one_n - b
518+
viol = G.sum(1) - a
519+
viol_2 = G.sum(0) - b
520520
stopThr_val = 1
521521
if log:
522522
log['u'] = u
@@ -532,26 +532,26 @@ def greenkhorn(a, b, M, reg, numItermax=10000, stopThr=1e-9, verbose=False, log=
532532

533533
if m_viol_1 > m_viol_2:
534534
old_u = u[i_1]
535-
u[i_1] = a[i_1] / (K[i_1, :]@v)
535+
u[i_1] = a[i_1] / (K[i_1, :].dot(v))
536536
G[i_1, :] = u[i_1] * K[i_1, :] * v
537537

538-
viol[i_1] = u[i_1] * K[i_1, :]@v - a[i_1]
538+
viol[i_1] = u[i_1] * K[i_1, :].dot(v) - a[i_1]
539539
viol_2 = viol_2 + (K[i_1, :].T * (u[i_1] - old_u) * v)
540540

541541
else:
542542
old_v = v[i_2]
543-
v[i_2] = b[i_2] / (K[:, i_2].T@u)
543+
v[i_2] = b[i_2] / (K[:, i_2].T.dot(u))
544544
G[:, i_2] = u * K[:, i_2] * v[i_2]
545545
#aviol = (G@one_m - a)
546546
#aviol_2 = (G.T@one_n - b)
547547
viol = viol + (-old_v + v[i_2]) * K[:, i_2] * u
548-
viol_2[i_2] = v[i_2] * K[:, i_2]@u - b[i_2]
548+
viol_2[i_2] = v[i_2] * K[:, i_2].dot(u) - b[i_2]
549549

550550
#print('b',np.max(abs(aviol -viol)),np.max(abs(aviol_2 - viol_2)))
551551

552-
if log:
553-
log['u'] = u
554-
log['v'] = v
552+
if log:
553+
log['u'] = u
554+
log['v'] = v
555555

556556
if log:
557557
return G, log

0 commit comments

Comments
 (0)