Skip to content

Commit eb17e02

Browse files
committed
correct if error bug
1 parent 653fd00 commit eb17e02

File tree

1 file changed

+29
-29
lines changed

1 file changed

+29
-29
lines changed

ot/bregman.py

Lines changed: 29 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -103,10 +103,10 @@ def sinkhorn(a, b, M, reg, method='sinkhorn', numItermax=1000,
103103
def sink():
104104
return sinkhorn_knopp(a, b, M, reg, numItermax=numItermax,
105105
stopThr=stopThr, verbose=verbose, log=log, **kwargs)
106-
if method.lower() == 'greenkhorn':
106+
elif method.lower() == 'greenkhorn':
107107
def sink():
108108
return greenkhorn(a, b, M, reg, numItermax=numItermax,
109-
stopThr=stopThr, verbose=verbose, log=log)
109+
stopThr=stopThr, verbose=verbose, log=log)
110110
elif method.lower() == 'sinkhorn_stabilized':
111111
def sink():
112112
return sinkhorn_stabilized(a, b, M, reg, numItermax=numItermax,
@@ -417,17 +417,16 @@ def sinkhorn_knopp(a, b, M, reg, numItermax=1000,
417417
return u.reshape((-1, 1)) * K * v.reshape((1, -1))
418418

419419

420-
421-
def greenkhorn(a, b, M, reg, numItermax=10000, stopThr=1e-9, verbose=False, log = False):
420+
def greenkhorn(a, b, M, reg, numItermax=10000, stopThr=1e-9, verbose=False, log=False):
422421
"""
423422
Solve the entropic regularization optimal transport problem and return the OT matrix
424-
423+
425424
The algorithm used is based on the paper
426-
425+
427426
Near-linear time approximation algorithms for optimal transport via Sinkhorn iteration
428427
by Jason Altschuler, Jonathan Weed, Philippe Rigollet
429428
appeared at NIPS 2017
430-
429+
431430
which is a stochastic version of the Sinkhorn-Knopp algorithm [2].
432431
433432
The function solves the following optimization problem:
@@ -499,21 +498,21 @@ def greenkhorn(a, b, M, reg, numItermax=10000, stopThr=1e-9, verbose=False, log
499498
ot.optim.cg : General regularized OT
500499
501500
"""
502-
501+
503502
i = 0
504-
503+
505504
n = a.shape[0]
506505
m = b.shape[0]
507-
506+
508507
# Next 3 lines equivalent to K= np.exp(-M/reg), but faster to compute
509508
K = np.empty(M.shape, dtype=M.dtype)
510509
np.divide(M, -reg, out=K)
511510
np.exp(K, out=K)
512-
513-
u = np.ones(n)/n
514-
v = np.ones(m)/m
511+
512+
u = np.ones(n) / n
513+
v = np.ones(m) / m
515514
G = np.diag(u)@K@np.diag(v)
516-
515+
517516
one_n = np.ones(n)
518517
one_m = np.ones(m)
519518
viol = G@one_m - a
@@ -524,41 +523,42 @@ def greenkhorn(a, b, M, reg, numItermax=10000, stopThr=1e-9, verbose=False, log
524523
log['v'] = v
525524

526525
while i < numItermax and stopThr_val > stopThr:
527-
i +=1
526+
i += 1
528527
i_1 = np.argmax(np.abs(viol))
529528
i_2 = np.argmax(np.abs(viol_2))
530529
m_viol_1 = np.abs(viol[i_1])
531530
m_viol_2 = np.abs(viol_2[i_2])
532-
stopThr_val = np.maximum(m_viol_1,m_viol_2)
533-
531+
stopThr_val = np.maximum(m_viol_1, m_viol_2)
532+
534533
if m_viol_1 > m_viol_2:
535534
old_u = u[i_1]
536-
u[i_1] = a[i_1]/(K[i_1,:]@v)
537-
G[i_1,:] = u[i_1]*K[i_1,:]*v
535+
u[i_1] = a[i_1] / (K[i_1, :]@v)
536+
G[i_1, :] = u[i_1] * K[i_1, :] * v
538537

539-
viol[i_1] = u[i_1]*K[i_1,:]@v - a[i_1]
540-
viol_2 = viol_2 + ( K[i_1,:].T*(u[i_1] - old_u)*v)
538+
viol[i_1] = u[i_1] * K[i_1, :]@v - a[i_1]
539+
viol_2 = viol_2 + (K[i_1, :].T * (u[i_1] - old_u) * v)
541540

542541
else:
543542
old_v = v[i_2]
544-
v[i_2] = b[i_2]/(K[:,i_2].T@u)
545-
G[:,i_2] = u*K[:,i_2]*v[i_2]
543+
v[i_2] = b[i_2] / (K[:, i_2].T@u)
544+
G[:, i_2] = u * K[:, i_2] * v[i_2]
546545
#aviol = (G@one_m - a)
547546
#aviol_2 = (G.T@one_n - b)
548-
viol = viol + ( -old_v + v[i_2])*K[:,i_2]*u
549-
viol_2[i_2] = v[i_2]*K[:,i_2]@u - b[i_2]
550-
547+
viol = viol + (-old_v + v[i_2]) * K[:, i_2] * u
548+
viol_2[i_2] = v[i_2] * K[:, i_2]@u - b[i_2]
549+
551550
#print('b',np.max(abs(aviol -viol)),np.max(abs(aviol_2 - viol_2)))
552-
551+
553552
if log:
554553
log['u'] = u
555554
log['v'] = v
556-
555+
557556
if log:
558-
return G,log
557+
return G, log
559558
else:
560559
return G
561560

561+
562562
def sinkhorn_stabilized(a, b, M, reg, numItermax=1000, tau=1e3, stopThr=1e-9,
563563
warmstart=None, verbose=False, print_period=20, log=False, **kwargs):
564564
"""

0 commit comments

Comments
 (0)