Skip to content

Commit e39f04a

Browse files
author
Vivien Seguy
committed
add free support barycenter algorithm
1 parent 98ce4cc commit e39f04a

File tree

1 file changed

+11
-7
lines changed

1 file changed

+11
-7
lines changed

examples/plot_free_support_barycenter.py

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
# Generate data
2222
# -------------
2323
#%% parameters and data generation
24-
N = 4
24+
N = 6
2525
d = 2
2626
measures_locations = []
2727
measures_weights = []
@@ -30,11 +30,13 @@
3030

3131
n = np.random.randint(low=1, high=20) # nb samples
3232

33-
mu = np.random.normal(0., 1., (d,))
34-
cov = np.random.uniform(0., 1., (d,d))
33+
mu = np.random.normal(0., 4., (d,))
34+
35+
A = np.random.rand(d, d)
36+
cov = np.dot(A,A.transpose())
3537

3638
xs = ot.datasets.make_2D_samples_gauss(n, mu, cov)
37-
b = np.random.uniform(0., 1., n)
39+
b = np.random.uniform(0., 1., (n,))
3840
b = b/np.sum(b)
3941

4042
measures_locations.append(xs)
@@ -59,7 +61,9 @@
5961

6062
pl.figure(1)
6163
for (xs, b) in zip(measures_locations, measures_weights):
62-
pl.scatter(xs[:, 0], xs[:, 1], s=b, c=np.tile(np.random.uniform(0. ,255., size=(3,)), (1,b.size(0))) , label='Data measures')
63-
pl.scatter(xs[:, 0], xs[:, 1], s=b, c='black' , label='2-Wasserstein barycenter')
64-
pl.legend(loc=0)
64+
color = np.random.randint(low=1, high=10*N)
65+
pl.scatter(xs[:, 0], xs[:, 1], s=b*1000, label='input measure')
66+
pl.scatter(X[:, 0], X[:, 1], s=b_init*1000, c='black' , marker='^', label='2-Wasserstein barycenter')
6567
pl.title('Data measures and their barycenter')
68+
pl.legend(loc=0)
69+
pl.show()

0 commit comments

Comments
 (0)