|
12 | 12 |
|
13 | 13 | import multiprocessing |
14 | 14 | import sys |
| 15 | + |
15 | 16 | import numpy as np |
16 | 17 | from scipy.sparse import coo_matrix |
17 | 18 |
|
18 | | -from .import cvx |
19 | | - |
| 19 | +from . import cvx |
| 20 | +from .cvx import barycenter |
20 | 21 | # import compiled emd |
21 | 22 | from .emd_wrap import emd_c, check_result, emd_1d_sorted |
22 | | -from ..utils import parmap |
23 | | -from .cvx import barycenter |
24 | 23 | from ..utils import dist |
| 24 | +from ..utils import parmap |
25 | 25 |
|
26 | 26 | __all__ = ['emd', 'emd2', 'barycenter', 'free_support_barycenter', 'cvx', |
27 | 27 | 'emd_1d', 'emd2_1d', 'wasserstein_1d'] |
@@ -458,7 +458,8 @@ def f(b): |
458 | 458 | return res |
459 | 459 |
|
460 | 460 |
|
461 | | -def free_support_barycenter(measures_locations, measures_weights, X_init, b=None, weights=None, numItermax=100, stopThr=1e-7, verbose=False, log=None): |
| 461 | +def free_support_barycenter(measures_locations, measures_weights, X_init, b=None, weights=None, numItermax=100, |
| 462 | + stopThr=1e-7, verbose=False, log=None): |
462 | 463 | """ |
463 | 464 | Solves the free support (locations of the barycenters are optimized, not the weights) Wasserstein barycenter problem (i.e. the weighted Frechet mean for the 2-Wasserstein distance) |
464 | 465 |
|
@@ -525,8 +526,8 @@ def free_support_barycenter(measures_locations, measures_weights, X_init, b=None |
525 | 526 |
|
526 | 527 | T_sum = np.zeros((k, d)) |
527 | 528 |
|
528 | | - for (measure_locations_i, measure_weights_i, weight_i) in zip(measures_locations, measures_weights, weights.tolist()): |
529 | | - |
| 529 | + for (measure_locations_i, measure_weights_i, weight_i) in zip(measures_locations, measures_weights, |
| 530 | + weights.tolist()): |
530 | 531 | M_i = dist(X, measure_locations_i) |
531 | 532 | T_i = emd(b, measure_weights_i, M_i) |
532 | 533 | T_sum = T_sum + weight_i * np.reshape(1. / b, (-1, 1)) * np.matmul(T_i, measure_locations_i) |
@@ -651,8 +652,8 @@ def emd_1d(x_a, x_b, a=None, b=None, metric='sqeuclidean', p=1., dense=True, |
651 | 652 | if b.ndim == 0 or len(b) == 0: |
652 | 653 | b = np.ones((x_b.shape[0],), dtype=np.float64) / x_b.shape[0] |
653 | 654 |
|
654 | | - x_a_1d = x_a.reshape((-1, )) |
655 | | - x_b_1d = x_b.reshape((-1, )) |
| 655 | + x_a_1d = x_a.reshape((-1,)) |
| 656 | + x_b_1d = x_b.reshape((-1,)) |
656 | 657 | perm_a = np.argsort(x_a_1d) |
657 | 658 | perm_b = np.argsort(x_b_1d) |
658 | 659 |
|
|
0 commit comments