1111import numpy as np
1212import matplotlib .pylab as pl
1313import ot
14- from mpl_toolkits .mplot3d import Axes3D #necessary for 3d plot even if not used
14+ # necessary for 3d plot even if not used
15+ from mpl_toolkits .mplot3d import Axes3D # noqa
1516from matplotlib .collections import PolyCollection
1617
1718
1819#%% parameters
1920
20- n = 100 # nb bins
21+ n = 100 # nb bins
2122
2223# bin positions
23- x = np .arange (n ,dtype = np .float64 )
24+ x = np .arange (n , dtype = np .float64 )
2425
2526# Gaussian distributions
26- a1 = ot .datasets .get_1D_gauss (n ,m = 20 ,s = 5 ) # m= mean, s= std
27- a2 = ot .datasets .get_1D_gauss (n ,m = 60 ,s = 8 )
27+ a1 = ot .datasets .get_1D_gauss (n , m = 20 , s = 5 ) # m= mean, s= std
28+ a2 = ot .datasets .get_1D_gauss (n , m = 60 , s = 8 )
2829
2930# creating matrix A containing all distributions
30- A = np .vstack ((a1 ,a2 )).T
31- nbd = A .shape [1 ]
31+ A = np .vstack ((a1 , a2 )).T
32+ n_distributions = A .shape [1 ]
3233
3334# loss matrix + normalization
34- M = ot .utils .dist0 (n )
35- M /= M .max ()
35+ M = ot .utils .dist0 (n )
36+ M /= M .max ()
3637
3738#%% plot the distributions
3839
39- pl .figure (1 )
40- for i in range (nbd ):
41- pl .plot (x ,A [:,i ])
40+ pl .figure (1 , figsize = ( 6.4 , 3 ) )
41+ for i in range (n_distributions ):
42+ pl .plot (x , A [:, i ])
4243pl .title ('Distributions' )
44+ pl .tight_layout ()
4345
4446#%% barycenter computation
4547
46- alpha = 0.2 # 0<=alpha<=1
47- weights = np .array ([1 - alpha ,alpha ])
48+ alpha = 0.2 # 0<=alpha<=1
49+ weights = np .array ([1 - alpha , alpha ])
4850
4951# l2bary
50- bary_l2 = A .dot (weights )
52+ bary_l2 = A .dot (weights )
5153
5254# wasserstein
53- reg = 1e-3
54- bary_wass = ot .bregman .barycenter (A ,M , reg ,weights )
55+ reg = 1e-3
56+ bary_wass = ot .bregman .barycenter (A , M , reg , weights )
5557
5658pl .figure (2 )
5759pl .clf ()
58- pl .subplot (2 ,1 , 1 )
59- for i in range (nbd ):
60- pl .plot (x ,A [:,i ])
60+ pl .subplot (2 , 1 , 1 )
61+ for i in range (n_distributions ):
62+ pl .plot (x , A [:, i ])
6163pl .title ('Distributions' )
6264
63- pl .subplot (2 ,1 , 2 )
64- pl .plot (x ,bary_l2 ,'r' ,label = 'l2' )
65- pl .plot (x ,bary_wass ,'g' ,label = 'Wasserstein' )
65+ pl .subplot (2 , 1 , 2 )
66+ pl .plot (x , bary_l2 , 'r' , label = 'l2' )
67+ pl .plot (x , bary_wass , 'g' , label = 'Wasserstein' )
6668pl .legend ()
6769pl .title ('Barycenters' )
68-
70+ pl . tight_layout ()
6971
7072#%% barycenter interpolation
7173
72- nbalpha = 11
73- alphalist = np .linspace (0 ,1 , nbalpha )
74+ n_alpha = 11
75+ alpha_list = np .linspace (0 , 1 , n_alpha )
7476
7577
76- B_l2 = np .zeros ((n ,nbalpha ))
78+ B_l2 = np .zeros ((n , n_alpha ))
7779
78- B_wass = np .copy (B_l2 )
80+ B_wass = np .copy (B_l2 )
7981
80- for i in range (0 ,nbalpha ):
81- alpha = alphalist [i ]
82- weights = np .array ([1 - alpha ,alpha ])
83- B_l2 [:,i ] = A .dot (weights )
84- B_wass [:,i ] = ot .bregman .barycenter (A ,M , reg ,weights )
82+ for i in range (0 , n_alpha ):
83+ alpha = alpha_list [i ]
84+ weights = np .array ([1 - alpha , alpha ])
85+ B_l2 [:, i ] = A .dot (weights )
86+ B_wass [:, i ] = ot .bregman .barycenter (A , M , reg , weights )
8587
8688#%% plot interpolation
8789
88- pl .figure (3 ,( 10 , 5 ) )
90+ pl .figure (3 )
8991
90- #pl.subplot(1,2,1)
91- cmap = pl .cm .get_cmap ('viridis' )
92+ cmap = pl .cm .get_cmap ('viridis' )
9293verts = []
93- zs = alphalist
94- for i ,z in enumerate (zs ):
95- ys = B_l2 [:,i ]
94+ zs = alpha_list
95+ for i , z in enumerate (zs ):
96+ ys = B_l2 [:, i ]
9697 verts .append (list (zip (x , ys )))
9798
9899ax = pl .gcf ().gca (projection = '3d' )
99100
100- poly = PolyCollection (verts ,facecolors = [cmap (a ) for a in alphalist ])
101+ poly = PolyCollection (verts , facecolors = [cmap (a ) for a in alpha_list ])
101102poly .set_alpha (0.7 )
102103ax .add_collection3d (poly , zs = zs , zdir = 'y' )
103-
104104ax .set_xlabel ('x' )
105105ax .set_xlim3d (0 , n )
106106ax .set_ylabel ('$\\ alpha$' )
107- ax .set_ylim3d (0 ,1 )
107+ ax .set_ylim3d (0 , 1 )
108108ax .set_zlabel ('' )
109- ax .set_zlim3d (0 , B_l2 .max ()* 1.01 )
109+ ax .set_zlim3d (0 , B_l2 .max () * 1.01 )
110110pl .title ('Barycenter interpolation with l2' )
111+ pl .tight_layout ()
111112
112- pl .show ()
113-
114- pl .figure (4 ,(10 ,5 ))
115-
116- #pl.subplot(1,2,1)
117- cmap = pl .cm .get_cmap ('viridis' )
113+ pl .figure (4 )
114+ cmap = pl .cm .get_cmap ('viridis' )
118115verts = []
119- zs = alphalist
120- for i ,z in enumerate (zs ):
121- ys = B_wass [:,i ]
116+ zs = alpha_list
117+ for i , z in enumerate (zs ):
118+ ys = B_wass [:, i ]
122119 verts .append (list (zip (x , ys )))
123120
124121ax = pl .gcf ().gca (projection = '3d' )
125122
126- poly = PolyCollection (verts ,facecolors = [cmap (a ) for a in alphalist ])
123+ poly = PolyCollection (verts , facecolors = [cmap (a ) for a in alpha_list ])
127124poly .set_alpha (0.7 )
128125ax .add_collection3d (poly , zs = zs , zdir = 'y' )
129-
130126ax .set_xlabel ('x' )
131127ax .set_xlim3d (0 , n )
132128ax .set_ylabel ('$\\ alpha$' )
133- ax .set_ylim3d (0 ,1 )
129+ ax .set_ylim3d (0 , 1 )
134130ax .set_zlabel ('' )
135- ax .set_zlim3d (0 , B_l2 .max ()* 1.01 )
131+ ax .set_zlim3d (0 , B_l2 .max () * 1.01 )
136132pl .title ('Barycenter interpolation with Wasserstein' )
133+ pl .tight_layout ()
137134
138- pl .show ()
135+ pl .show ()
0 commit comments