Skip to content

Commit 161d68a

Browse files
author
Hicham Janati
committed
fix loop counter in barycenter + precision of dual variables
1 parent a919f96 commit 161d68a

File tree

1 file changed

+50
-52
lines changed

1 file changed

+50
-52
lines changed

ot/unbalanced.py

Lines changed: 50 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -384,10 +384,9 @@ def sinkhorn_knopp_unbalanced(a, b, M, reg, reg_m, numItermax=1000,
384384

385385
fi = reg_m / (reg_m + reg)
386386

387-
cpt = 0
388387
err = 1.
389388

390-
while (err > stopThr and cpt < numItermax):
389+
for i in range(numItermax):
391390
uprev = u
392391
vprev = v
393392

@@ -401,28 +400,27 @@ def sinkhorn_knopp_unbalanced(a, b, M, reg, reg_m, numItermax=1000,
401400
or np.any(np.isinf(u)) or np.any(np.isinf(v))):
402401
# we have reached the machine precision
403402
# come back to previous solution and quit loop
404-
warnings.warn('Numerical errors at iteration %s' % cpt)
403+
warnings.warn('Numerical errors at iteration %s' % i)
405404
u = uprev
406405
v = vprev
407406
break
408-
if cpt % 10 == 0:
409-
# we can speed up the process by checking for the error only all
410-
# the 10th iterations
411-
err_u = abs(u - uprev).max() / max(abs(u).max(), abs(uprev).max(), 1.)
412-
err_v = abs(v - vprev).max() / max(abs(v).max(), abs(vprev).max(), 1.)
413-
err = 0.5 * (err_u + err_v)
414-
if log:
415-
log['err'].append(err)
407+
408+
err_u = abs(u - uprev).max() / max(abs(u).max(), abs(uprev).max(), 1.)
409+
err_v = abs(v - vprev).max() / max(abs(v).max(), abs(vprev).max(), 1.)
410+
err = 0.5 * (err_u + err_v)
411+
if log:
412+
log['err'].append(err)
416413
if verbose:
417-
if cpt % 200 == 0:
414+
if i % 50 == 0:
418415
print(
419416
'{:5s}|{:12s}'.format('It.', 'Err') + '\n' + '-' * 19)
420-
print('{:5d}|{:8e}|'.format(cpt, err))
421-
cpt += 1
417+
print('{:5d}|{:8e}|'.format(i, err))
418+
if err < stopThr:
419+
break
422420

423421
if log:
424-
log['logu'] = np.log(u + 1e-16)
425-
log['logv'] = np.log(v + 1e-16)
422+
log['logu'] = np.log(u + 1e-300)
423+
log['logv'] = np.log(v + 1e-300)
426424

427425
if n_hists: # return only loss
428426
res = np.einsum('ik,ij,jk,ij->k', u, K, v, M)
@@ -747,8 +745,8 @@ def barycenter_unbalanced_stabilized(A, M, reg, reg_m, weights=None, tau=1e3,
747745
alpha = np.zeros(dim)
748746
beta = np.zeros(dim)
749747
q = np.ones(dim) / dim
750-
while (err > stopThr and cpt < numItermax):
751-
qprev = q
748+
for i in range(numItermax):
749+
qprev = q.copy()
752750
Kv = K.dot(v)
753751
f_alpha = np.exp(- alpha / (reg + reg_m))
754752
f_beta = np.exp(- beta / (reg + reg_m))
@@ -777,28 +775,29 @@ def barycenter_unbalanced_stabilized(A, M, reg, reg_m, weights=None, tau=1e3,
777775
warnings.warn('Numerical errors at iteration %s' % cpt)
778776
q = qprev
779777
break
780-
if (cpt % 10 == 0 and not absorbing) or cpt == 0:
778+
if (i % 10 == 0 and not absorbing) or i == 0:
781779
# we can speed up the process by checking for the error only all
782780
# the 10th iterations
783781
err = abs(q - qprev).max() / max(abs(q).max(),
784782
abs(qprev).max(), 1.)
785783
if log:
786784
log['err'].append(err)
787785
if verbose:
788-
if cpt % 50 == 0:
786+
if i % 50 == 0:
789787
print(
790788
'{:5s}|{:12s}'.format('It.', 'Err') + '\n' + '-' * 19)
791-
print('{:5d}|{:8e}|'.format(cpt, err))
789+
print('{:5d}|{:8e}|'.format(i, err))
790+
if err < stopThr:
791+
break
792792

793-
cpt += 1
794793
if err > stopThr:
795794
warnings.warn("Stabilized Unbalanced Sinkhorn did not converge." +
796795
"Try a larger entropy `reg` or a lower mass `reg_m`." +
797796
"Or a larger absorption threshold `tau`.")
798797
if log:
799-
log['niter'] = cpt
800-
log['logu'] = np.log(u + 1e-16)
801-
log['logv'] = np.log(v + 1e-16)
798+
log['niter'] = i
799+
log['logu'] = np.log(u + 1e-300)
800+
log['logv'] = np.log(v + 1e-300)
802801
return q, log
803802
else:
804803
return q
@@ -882,15 +881,15 @@ def barycenter_unbalanced_sinkhorn(A, M, reg, reg_m, weights=None,
882881

883882
fi = reg_m / (reg_m + reg)
884883

885-
v = np.ones((dim, n_hists)) / dim
886-
u = np.ones((dim, 1)) / dim
887-
888-
cpt = 0
884+
v = np.ones((dim, n_hists))
885+
u = np.ones((dim, 1))
886+
q = np.ones(dim)
889887
err = 1.
890888

891-
while (err > stopThr and cpt < numItermax):
892-
uprev = u
893-
vprev = v
889+
for i in range(numItermax):
890+
uprev = u.copy()
891+
vprev = v.copy()
892+
qprev = q.copy()
894893

895894
Kv = K.dot(v)
896895
u = (A / Kv) ** fi
@@ -905,31 +904,30 @@ def barycenter_unbalanced_sinkhorn(A, M, reg, reg_m, weights=None,
905904
or np.any(np.isinf(u)) or np.any(np.isinf(v))):
906905
# we have reached the machine precision
907906
# come back to previous solution and quit loop
908-
warnings.warn('Numerical errors at iteration %s' % cpt)
907+
warnings.warn('Numerical errors at iteration %s' % i)
909908
u = uprev
910909
v = vprev
910+
q = qprev
911911
break
912-
if cpt % 10 == 0:
913-
# we can speed up the process by checking for the error only all
914-
# the 10th iterations
915-
err_u = abs(u - uprev).max()
916-
err_u /= max(abs(u).max(), abs(uprev).max(), 1.)
917-
err_v = abs(v - vprev).max()
918-
err_v /= max(abs(v).max(), abs(vprev).max(), 1.)
919-
err = 0.5 * (err_u + err_v)
920-
if log:
921-
log['err'].append(err)
922-
if verbose:
923-
if cpt % 50 == 0:
924-
print(
925-
'{:5s}|{:12s}'.format('It.', 'Err') + '\n' + '-' * 19)
926-
print('{:5d}|{:8e}|'.format(cpt, err))
912+
# compute change in barycenter
913+
err = abs(q - qprev).max()
914+
err /= max(abs(q).max(), abs(qprev).max(), 1.)
915+
if log:
916+
log['err'].append(err)
917+
# if barycenter did not change + at least 10 iterations - stop
918+
if err < stopThr and i > 10:
919+
break
920+
921+
if verbose:
922+
if i % 10 == 0:
923+
print(
924+
'{:5s}|{:12s}'.format('It.', 'Err') + '\n' + '-' * 19)
925+
print('{:5d}|{:8e}|'.format(i, err))
927926

928-
cpt += 1
929927
if log:
930-
log['niter'] = cpt
931-
log['logu'] = np.log(u + 1e-16)
932-
log['logv'] = np.log(v + 1e-16)
928+
log['niter'] = i
929+
log['logu'] = np.log(u + 1e-300)
930+
log['logv'] = np.log(v + 1e-300)
933931
return q, log
934932
else:
935933
return q

0 commit comments

Comments
 (0)