Skip to content

Commit 98ce4cc

Browse files
committed
free support barycenter
1 parent 3f23fa1 commit 98ce4cc

File tree

2 files changed

+5
-6
lines changed

2 files changed

+5
-6
lines changed

examples/plot_free_support_barycenter.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@
1717
import matplotlib.pylab as pl
1818
import ot.plot
1919

20-
2120
##############################################################################
2221
# Generate data
2322
# -------------
@@ -29,10 +28,10 @@
2928

3029
for i in range(N):
3130

32-
n = np.rand.int(low=1, high=20) # nb samples
31+
n = np.random.randint(low=1, high=20) # nb samples
3332

3433
mu = np.random.normal(0., 1., (d,))
35-
cov = np.random.normal(0., 1., (d,d))
34+
cov = np.random.uniform(0., 1., (d,d))
3635

3736
xs = ot.datasets.make_2D_samples_gauss(n, mu, cov)
3837
b = np.random.uniform(0., 1., n)
@@ -49,7 +48,7 @@
4948
##############################################################################
5049
# Compute free support barycenter
5150
# -------------
52-
X = ot.lp.barycenter(measures_locations, measures_weights, X_init, b_init)
51+
X = ot.lp.cvx.free_support_barycenter(measures_locations, measures_weights, X_init, b_init)
5352

5453

5554
##############################################################################
@@ -60,7 +59,7 @@
6059

6160
pl.figure(1)
6261
for (xs, b) in zip(measures_locations, measures_weights):
63-
pl.scatter(xs[:, 0], xs[:, 1], s=b, c=np.tile(np.rand(0. ,255., size=(3,)), (1,b.size(0))) , label='Data measures')
62+
pl.scatter(xs[:, 0], xs[:, 1], s=b, c=np.tile(np.random.uniform(0. ,255., size=(3,)), (1,b.size(0))) , label='Data measures')
6463
pl.scatter(xs[:, 0], xs[:, 1], s=b, c='black' , label='2-Wasserstein barycenter')
6564
pl.legend(loc=0)
6665
pl.title('Data measures and their barycenter')

ot/lp/cvx.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -217,7 +217,7 @@ def free_support_barycenter(measures_locations, measures_weights, X_init, b_init
217217

218218
M_i = ot.dist(X, measure_locations_i)
219219
T_i = ot.emd(b_init, measure_weights_i, M_i)
220-
T_sum += np.reshape(1. / b_init, (-1, 1)) * np.matmul(T_i, measure_locations_i)
220+
T_sum = T_sum + weight_i*np.reshape(1. / b_init, (-1, 1)) * np.matmul(T_i, measure_locations_i)
221221

222222
displacement_square_norm = np.sum(np.square(X-T_sum))
223223
X = T_sum

0 commit comments

Comments
 (0)