Skip to content

Commit d19295b

Browse files
committed
stabThr and pep8
1 parent dd200d5 commit d19295b

File tree

1 file changed

+7
-6
lines changed

1 file changed

+7
-6
lines changed

ot/bregman.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -919,7 +919,7 @@ def barycenter(A, M, reg, weights=None, numItermax=1000,
919919
return geometricBar(weights, UKv)
920920

921921

922-
def convolutional_barycenter2d(A, reg, weights=None, numItermax=10000, stopThr=1e-9, verbose=False, log=False):
922+
def convolutional_barycenter2d(A, reg, weights=None, numItermax=10000, stopThr=1e-9, stabThr=1e-30, verbose=False, log=False):
923923
"""Compute the entropic regularized wasserstein barycenter of distributions A
924924
where A is a collection of 2D images.
925925
@@ -948,6 +948,8 @@ def convolutional_barycenter2d(A, reg, weights=None, numItermax=10000, stopThr=1
948948
Max number of iterations
949949
stopThr : float, optional
950950
Stop threshol on error (>0)
951+
stabThr : float, optional
952+
Stabilization threshold to avoid numerical precision issue
951953
verbose : bool, optional
952954
Print information along iterations
953955
log : bool, optional
@@ -983,7 +985,6 @@ def convolutional_barycenter2d(A, reg, weights=None, numItermax=10000, stopThr=1
983985
b = np.zeros_like(A[0, :, :])
984986
U = np.ones_like(A)
985987
KV = np.ones_like(A)
986-
threshold = 1e-30 # in order to avoids numerical precision issues
987988

988989
cpt = 0
989990
err = 1
@@ -993,7 +994,7 @@ def convolutional_barycenter2d(A, reg, weights=None, numItermax=10000, stopThr=1
993994
[Y, X] = np.meshgrid(t, t)
994995
xi1 = np.exp(-(X - Y)**2 / reg)
995996

996-
def K(x):
997+
def K(x):
997998
return np.dot(np.dot(xi1, x), xi1)
998999

9991000
while (err > stopThr and cpt < numItermax):
@@ -1003,11 +1004,11 @@ def K(x):
10031004

10041005
b = np.zeros_like(A[0, :, :])
10051006
for r in range(A.shape[0]):
1006-
KV[r, :, :] = K(A[r, :, :] / np.maximum(threshold, K(U[r, :, :])))
1007-
b += weights[r] * np.log(np.maximum(threshold, U[r, :, :] * KV[r, :, :]))
1007+
KV[r, :, :] = K(A[r, :, :] / np.maximum(stabThr, K(U[r, :, :])))
1008+
b += weights[r] * np.log(np.maximum(stabThr, U[r, :, :] * KV[r, :, :]))
10081009
b = np.exp(b)
10091010
for r in range(A.shape[0]):
1010-
U[r, :, :] = b / np.maximum(threshold, KV[r, :, :])
1011+
U[r, :, :] = b / np.maximum(stabThr, KV[r, :, :])
10111012

10121013
if cpt % 10 == 1:
10131014
err = np.sum(np.abs(bold - b))

0 commit comments

Comments
 (0)