2121# Generate data
2222# -------------
2323#%% parameters and data generation
24- N = 4
24+ N = 6
2525d = 2
2626measures_locations = []
2727measures_weights = []
3030
3131 n = np .random .randint (low = 1 , high = 20 ) # nb samples
3232
33- mu = np .random .normal (0. , 1. , (d ,))
34- cov = np .random .uniform (0. , 1. , (d ,d ))
33+ mu = np .random .normal (0. , 4. , (d ,))
34+
35+ A = np .random .rand (d , d )
36+ cov = np .dot (A ,A .transpose ())
3537
3638 xs = ot .datasets .make_2D_samples_gauss (n , mu , cov )
37- b = np .random .uniform (0. , 1. , n )
39+ b = np .random .uniform (0. , 1. , ( n ,) )
3840 b = b / np .sum (b )
3941
4042 measures_locations .append (xs )
5961
6062pl .figure (1 )
6163for (xs , b ) in zip (measures_locations , measures_weights ):
62- pl . scatter ( xs [:, 0 ], xs [:, 1 ], s = b , c = np .tile ( np . random .uniform ( 0. , 255. , size = ( 3 ,)), ( 1 , b . size ( 0 ))) , label = 'Data measures' )
63- pl .scatter (xs [:, 0 ], xs [:, 1 ], s = b , c = 'black' , label = '2-Wasserstein barycenter ' )
64- pl .legend ( loc = 0 )
64+ color = np .random .randint ( low = 1 , high = 10 * N )
65+ pl .scatter (xs [:, 0 ], xs [:, 1 ], s = b * 1000 , label = 'input measure ' )
66+ pl .scatter ( X [:, 0 ], X [:, 1 ], s = b_init * 1000 , c = 'black' , marker = '^' , label = '2-Wasserstein barycenter' )
6567pl .title ('Data measures and their barycenter' )
68+ pl .legend (loc = 0 )
69+ pl .show ()
0 commit comments