Skip to content

Commit 4fba2c9

Browse files
committed
debug bregman stabilized
1 parent 11239a9 commit 4fba2c9

File tree

1 file changed

+11
-9
lines changed

1 file changed

+11
-9
lines changed

ot/bregman.py

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -416,6 +416,16 @@ def get_Gamma(alpha,beta,u,v):
416416
err=1
417417
while loop:
418418

419+
420+
421+
uprev = u
422+
vprev = v
423+
424+
# sinkhorn update
425+
v = b/(np.dot(K.T,u)+1e-16)
426+
u = a/(np.dot(K,v)+1e-16)
427+
428+
419429
# remove numerical problems and store them in K
420430
if np.abs(u).max()>tau or np.abs(v).max()>tau:
421431
if nbb:
@@ -428,12 +438,6 @@ def get_Gamma(alpha,beta,u,v):
428438
u,v = np.ones(na)/na,np.ones(nb)/nb
429439
K=get_K(alpha,beta)
430440

431-
uprev = u
432-
vprev = v
433-
434-
# sinkhorn update
435-
v = b/np.dot(K.T,u)
436-
u = a/np.dot(K,v)
437441

438442
if cpt%print_period==0:
439443
# we can speed up the process by checking for the error only all the 10th iterations
@@ -458,9 +462,7 @@ def get_Gamma(alpha,beta,u,v):
458462
loop=False
459463

460464

461-
if (np.any(np.dot(K.T,u)==0) or
462-
np.any(np.isnan(u)) or np.any(np.isnan(v)) or
463-
np.any(np.isinf(u)) or np.any(np.isinf(v))):
465+
if np.any(np.isnan(u)) or np.any(np.isnan(v)):
464466
# we have reached the machine precision
465467
# come back to previous solution and quit loop
466468
print('Warning: numerical errors at iteration', cpt)

0 commit comments

Comments
 (0)