|
17 | 17 | import matplotlib.pylab as pl |
18 | 18 | import ot.plot |
19 | 19 |
|
20 | | - |
21 | 20 | ############################################################################## |
22 | 21 | # Generate data |
23 | 22 | # ------------- |
|
29 | 28 |
|
30 | 29 | for i in range(N): |
31 | 30 |
|
32 | | - n = np.rand.int(low=1, high=20) # nb samples |
| 31 | + n = np.random.randint(low=1, high=20) # nb samples |
33 | 32 |
|
34 | 33 | mu = np.random.normal(0., 1., (d,)) |
35 | | - cov = np.random.normal(0., 1., (d,d)) |
| 34 | + cov = np.random.uniform(0., 1., (d,d)) |
36 | 35 |
|
37 | 36 | xs = ot.datasets.make_2D_samples_gauss(n, mu, cov) |
38 | 37 | b = np.random.uniform(0., 1., n) |
|
49 | 48 | ############################################################################## |
50 | 49 | # Compute free support barycenter |
51 | 50 | # ------------- |
52 | | -X = ot.lp.barycenter(measures_locations, measures_weights, X_init, b_init) |
| 51 | +X = ot.lp.cvx.free_support_barycenter(measures_locations, measures_weights, X_init, b_init) |
53 | 52 |
|
54 | 53 |
|
55 | 54 | ############################################################################## |
|
60 | 59 |
|
61 | 60 | pl.figure(1) |
62 | 61 | for (xs, b) in zip(measures_locations, measures_weights): |
63 | | - pl.scatter(xs[:, 0], xs[:, 1], s=b, c=np.tile(np.rand(0. ,255., size=(3,)), (1,b.size(0))) , label='Data measures') |
| 62 | + pl.scatter(xs[:, 0], xs[:, 1], s=b, c=np.tile(np.random.uniform(0. ,255., size=(3,)), (1,b.size(0))) , label='Data measures') |
64 | 63 | pl.scatter(xs[:, 0], xs[:, 1], s=b, c='black' , label='2-Wasserstein barycenter') |
65 | 64 | pl.legend(loc=0) |
66 | 65 | pl.title('Data measures and their barycenter') |
0 commit comments