99import matplotlib .pylab as pl
1010import ot
1111
12-
12+ from mpl_toolkits .mplot3d import Axes3D
13+ from matplotlib .collections import PolyCollection
14+ from matplotlib .colors import colorConverter
1315
1416#%% parameters
1517
1921x = np .arange (n ,dtype = np .float64 )
2022
2123# Gaussian distributions
22- a1 = ot .datasets .get_1D_gauss (n ,m = 20 ,s = 20 ) # m= mean, s= std
23- a2 = ot .datasets .get_1D_gauss (n ,m = 60 ,s = 60 )
24+ a1 = ot .datasets .get_1D_gauss (n ,m = 20 ,s = 5 ) # m= mean, s= std
25+ a2 = ot .datasets .get_1D_gauss (n ,m = 60 ,s = 8 )
2426
2527# creating matrix A containing all distributions
2628A = np .vstack ((a1 ,a2 )).T
3941
4042#%% barycenter computation
4143
44+ alpha = 0.2 # 0<=alpha<=1
45+ weights = np .array ([1 - alpha ,alpha ])
46+
4247# l2bary
43- bary_l2 = A .mean ( 1 )
48+ bary_l2 = A .dot ( weights )
4449
4550# wasserstein
4651reg = 1e-3
47- bary_wass = ot .bregman .barycenter (A ,M ,reg )
52+ bary_wass = ot .bregman .barycenter (A ,M ,reg , weights )
4853
4954pl .figure (2 )
5055pl .clf ()
5863pl .plot (x ,bary_wass ,'g' ,label = 'Wasserstein' )
5964pl .legend ()
6065pl .title ('Barycenters' )
66+
67+
68+ #%% barycenter interpolation
69+
70+ nbalpha = 11
71+ alphalist = np .linspace (0 ,1 ,nbalpha )
72+
73+
74+ B_l2 = np .zeros ((n ,nbalpha ))
75+
76+ B_wass = np .copy (B_l2 )
77+
78+ for i in range (0 ,nbalpha ):
79+ alpha = alphalist [i ]
80+ weights = np .array ([1 - alpha ,alpha ])
81+ B_l2 [:,i ]= A .dot (weights )
82+ B_wass [:,i ]= ot .bregman .barycenter (A ,M ,reg ,weights )
83+
84+ #%% plot interpolation
85+
86+ pl .figure (3 ,(10 ,5 ))
87+
88+ #pl.subplot(1,2,1)
89+ cmap = pl .cm .get_cmap ('viridis' )
90+ verts = []
91+ zs = alphalist
92+ for i ,z in enumerate (zs ):
93+ ys = B_l2 [:,i ]
94+ verts .append (list (zip (x , ys )))
95+
96+ ax = pl .gcf ().gca (projection = '3d' )
97+
98+ poly = PolyCollection (verts ,facecolors = [cmap (a ) for a in alphalist ])
99+ poly .set_alpha (0.7 )
100+ ax .add_collection3d (poly , zs = zs , zdir = 'y' )
101+
102+ ax .set_xlabel ('x' )
103+ ax .set_xlim3d (0 , n )
104+ ax .set_ylabel ('$\\ alpha$' )
105+ ax .set_ylim3d (0 ,1 )
106+ ax .set_zlabel ('' )
107+ ax .set_zlim3d (0 , B_l2 .max ()* 1.01 )
108+ pl .title ('Barycenter interpolation with l2' )
109+
110+ pl .show ()
111+
112+ pl .figure (4 ,(10 ,5 ))
113+
114+ #pl.subplot(1,2,1)
115+ cmap = pl .cm .get_cmap ('viridis' )
116+ verts = []
117+ zs = alphalist
118+ for i ,z in enumerate (zs ):
119+ ys = B_wass [:,i ]
120+ verts .append (list (zip (x , ys )))
121+
122+ ax = pl .gcf ().gca (projection = '3d' )
123+
124+ poly = PolyCollection (verts ,facecolors = [cmap (a ) for a in alphalist ])
125+ poly .set_alpha (0.7 )
126+ ax .add_collection3d (poly , zs = zs , zdir = 'y' )
127+
128+ ax .set_xlabel ('x' )
129+ ax .set_xlim3d (0 , n )
130+ ax .set_ylabel ('$\\ alpha$' )
131+ ax .set_ylim3d (0 ,1 )
132+ ax .set_zlabel ('' )
133+ ax .set_zlim3d (0 , B_l2 .max ()* 1.01 )
134+ pl .title ('Barycenter interpolation with Wasserstein' )
135+
136+ pl .show ()
0 commit comments