Skip to content

Commit 60943d0

Browse files
Auto PEP8
1 parent 1e2e118 commit 60943d0

File tree

1 file changed

+10
-9
lines changed

1 file changed

+10
-9
lines changed

ot/lp/__init__.py

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -12,16 +12,16 @@
1212

1313
import multiprocessing
1414
import sys
15+
1516
import numpy as np
1617
from scipy.sparse import coo_matrix
1718

18-
from .import cvx
19-
19+
from . import cvx
20+
from .cvx import barycenter
2021
# import compiled emd
2122
from .emd_wrap import emd_c, check_result, emd_1d_sorted
22-
from ..utils import parmap
23-
from .cvx import barycenter
2423
from ..utils import dist
24+
from ..utils import parmap
2525

2626
__all__ = ['emd', 'emd2', 'barycenter', 'free_support_barycenter', 'cvx',
2727
'emd_1d', 'emd2_1d', 'wasserstein_1d']
@@ -458,7 +458,8 @@ def f(b):
458458
return res
459459

460460

461-
def free_support_barycenter(measures_locations, measures_weights, X_init, b=None, weights=None, numItermax=100, stopThr=1e-7, verbose=False, log=None):
461+
def free_support_barycenter(measures_locations, measures_weights, X_init, b=None, weights=None, numItermax=100,
462+
stopThr=1e-7, verbose=False, log=None):
462463
"""
463464
Solves the free support (locations of the barycenters are optimized, not the weights) Wasserstein barycenter problem (i.e. the weighted Frechet mean for the 2-Wasserstein distance)
464465
@@ -525,8 +526,8 @@ def free_support_barycenter(measures_locations, measures_weights, X_init, b=None
525526

526527
T_sum = np.zeros((k, d))
527528

528-
for (measure_locations_i, measure_weights_i, weight_i) in zip(measures_locations, measures_weights, weights.tolist()):
529-
529+
for (measure_locations_i, measure_weights_i, weight_i) in zip(measures_locations, measures_weights,
530+
weights.tolist()):
530531
M_i = dist(X, measure_locations_i)
531532
T_i = emd(b, measure_weights_i, M_i)
532533
T_sum = T_sum + weight_i * np.reshape(1. / b, (-1, 1)) * np.matmul(T_i, measure_locations_i)
@@ -651,8 +652,8 @@ def emd_1d(x_a, x_b, a=None, b=None, metric='sqeuclidean', p=1., dense=True,
651652
if b.ndim == 0 or len(b) == 0:
652653
b = np.ones((x_b.shape[0],), dtype=np.float64) / x_b.shape[0]
653654

654-
x_a_1d = x_a.reshape((-1, ))
655-
x_b_1d = x_b.reshape((-1, ))
655+
x_a_1d = x_a.reshape((-1,))
656+
x_b_1d = x_b.reshape((-1,))
656657
perm_a = np.argsort(x_a_1d)
657658
perm_b = np.argsort(x_b_1d)
658659

0 commit comments

Comments
 (0)