Skip to content

Commit 9569f89

Browse files
Kilian FatrasKilian Fatras
authored andcommitted
fix pep8
1 parent a2545b5 commit 9569f89

File tree

2 files changed

+17
-18
lines changed

2 files changed

+17
-18
lines changed

ot/bregman.py

Lines changed: 15 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1375,17 +1375,18 @@ def empirical_sinkhorn(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', numI
13751375
'''
13761376

13771377
if a is None:
1378-
a = ot.unif(np.shape(X_s)[0])
1378+
a = utils.unif(np.shape(X_s)[0])
13791379
if b is None:
1380-
b = ot.unif(np.shape(X_t)[0])
1380+
b = utils.unif(np.shape(X_t)[0])
1381+
13811382
M = ot.dist(X_s, X_t, metric=metric)
1382-
if log == False:
1383-
pi = ot.sinkhorn(a, b, M, reg, numItermax=numIterMax, stopThr=stopThr, verbose=verbose, log=False, **kwargs)
1384-
return pi
13851383

1386-
if log == True:
1387-
pi, log = ot.sinkhorn(a, b, M, reg, numItermax=numIterMax, stopThr=stopThr, verbose=verbose, log=True, **kwargs)
1384+
if log:
1385+
pi, log = sinkhorn(a, b, M, reg, numItermax=numIterMax, stopThr=stopThr, verbose=verbose, log=True, **kwargs)
13881386
return pi, log
1387+
else:
1388+
pi = sinkhorn(a, b, M, reg, numItermax=numIterMax, stopThr=stopThr, verbose=verbose, log=False, **kwargs)
1389+
return pi
13891390

13901391

13911392
def empirical_sinkhorn2(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', numIterMax=10000, stopThr=1e-9, verbose=False, log=False, **kwargs):
@@ -1464,18 +1465,18 @@ def empirical_sinkhorn2(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', num
14641465
'''
14651466

14661467
if a is None:
1467-
a = ot.unif(np.shape(X_s)[0])
1468+
a = utils.unif(np.shape(X_s)[0])
14681469
if b is None:
1469-
b = ot.unif(np.shape(X_t)[0])
1470+
b = utils.unif(np.shape(X_t)[0])
14701471

14711472
M = ot.dist(X_s, X_t, metric=metric)
1472-
if log == False:
1473-
sinkhorn_loss = ot.sinkhorn2(a, b, M, reg, numItermax=numIterMax, stopThr=stopThr, verbose=verbose, log=log, **kwargs)
1474-
return sinkhorn_loss
14751473

1476-
if log == True:
1477-
sinkhorn_loss, log = ot.sinkhorn2(a, b, M, reg, numItermax=numIterMax, stopThr=stopThr, verbose=verbose, log=log, **kwargs)
1474+
if log:
1475+
sinkhorn_loss, log = sinkhorn2(a, b, M, reg, numItermax=numIterMax, stopThr=stopThr, verbose=verbose, log=log, **kwargs)
14781476
return sinkhorn_loss, log
1477+
else:
1478+
sinkhorn_loss = sinkhorn2(a, b, M, reg, numItermax=numIterMax, stopThr=stopThr, verbose=verbose, log=log, **kwargs)
1479+
return sinkhorn_loss
14791480

14801481

14811482
def empirical_sinkhorn_divergence(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', numIterMax=10000, stopThr=1e-9, verbose=False, log=False, **kwargs):

test/test_bregman.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -195,13 +195,11 @@ def test_empirical_sinkhorn():
195195
n = 100
196196
a = ot.unif(n)
197197
b = ot.unif(n)
198-
M = ot.dist(X_s, X_t)
199-
M_e = ot.dist(X_s, X_t, metric='euclidean')
200-
201-
rng = np.random.RandomState(0)
202198

203199
X_s = np.reshape(np.arange(n), (n, 1))
204200
X_t = np.reshape(np.arange(0, n), (n, 1))
201+
M = ot.dist(X_s, X_t)
202+
M_e = ot.dist(X_s, X_t, metric='euclidean')
205203

206204
G_sqe = ot.bregman.empirical_sinkhorn(X_s, X_t, 1)
207205
sinkhorn_sqe = ot.sinkhorn(a, b, M, 1)

0 commit comments

Comments
 (0)