1515
1616#%% parameters
1717
18- n = 100 # nb bins
19- n_target = 50 # nb target distributions
18+ n = 100 # nb bins
19+ n_target = 50 # nb target distributions
2020
2121
2222# bin positions
23- x = np .arange (n ,dtype = np .float64 )
23+ x = np .arange (n , dtype = np .float64 )
2424
25- lst_m = np .linspace (20 ,90 ,n_target )
25+ lst_m = np .linspace (20 , 90 , n_target )
2626
2727# Gaussian distributions
28- a = gauss (n ,m = 20 ,s = 5 ) # m= mean, s= std
28+ a = gauss (n , m = 20 , s = 5 ) # m= mean, s= std
2929
30- B = np .zeros ((n ,n_target ))
30+ B = np .zeros ((n , n_target ))
3131
32- for i ,m in enumerate (lst_m ):
33- B [:,i ] = gauss (n ,m = m ,s = 5 )
32+ for i , m in enumerate (lst_m ):
33+ B [:, i ] = gauss (n , m = m , s = 5 )
3434
3535# loss matrix and normalization
36- M = ot .dist (x .reshape ((n ,1 )),x .reshape ((n ,1 )),'euclidean' )
37- M /= M .max ()
38- M2 = ot .dist (x .reshape ((n ,1 )),x .reshape ((n ,1 )),'sqeuclidean' )
39- M2 /= M2 .max ()
36+ M = ot .dist (x .reshape ((n , 1 )), x .reshape ((n , 1 )), 'euclidean' )
37+ M /= M .max ()
38+ M2 = ot .dist (x .reshape ((n , 1 )), x .reshape ((n , 1 )), 'sqeuclidean' )
39+ M2 /= M2 .max ()
4040#%% plot the distributions
4141
4242pl .figure (1 )
43- pl .subplot (2 ,1 , 1 )
44- pl .plot (x ,a , 'b' ,label = 'Source distribution' )
43+ pl .subplot (2 , 1 , 1 )
44+ pl .plot (x , a , 'b' , label = 'Source distribution' )
4545pl .title ('Source distribution' )
46- pl .subplot (2 ,1 , 2 )
47- pl .plot (x ,B , label = 'Target distributions' )
46+ pl .subplot (2 , 1 , 2 )
47+ pl .plot (x , B , label = 'Target distributions' )
4848pl .title ('Target distributions' )
49+ pl .tight_layout ()
4950
5051#%% Compute and plot distributions and loss matrix
5152
52- d_emd = ot .emd2 (a ,B , M ) # direct computation of EMD
53- d_emd2 = ot .emd2 (a ,B , M2 ) # direct computation of EMD with loss M3
53+ d_emd = ot .emd2 (a , B , M ) # direct computation of EMD
54+ d_emd2 = ot .emd2 (a , B , M2 ) # direct computation of EMD with loss M3
5455
5556
5657pl .figure (2 )
57- pl .plot (d_emd ,label = 'Euclidean EMD' )
58- pl .plot (d_emd2 ,label = 'Squared Euclidean EMD' )
58+ pl .plot (d_emd , label = 'Euclidean EMD' )
59+ pl .plot (d_emd2 , label = 'Squared Euclidean EMD' )
5960pl .title ('EMD distances' )
6061pl .legend ()
6162
6263#%%
63- reg = 1e-2
64- d_sinkhorn = ot .sinkhorn2 (a ,B , M , reg )
65- d_sinkhorn2 = ot .sinkhorn2 (a ,B , M2 ,reg )
64+ reg = 1e-2
65+ d_sinkhorn = ot .sinkhorn2 (a , B , M , reg )
66+ d_sinkhorn2 = ot .sinkhorn2 (a , B , M2 , reg )
6667
6768pl .figure (2 )
6869pl .clf ()
69- pl .plot (d_emd ,label = 'Euclidean EMD' )
70- pl .plot (d_emd2 ,label = 'Squared Euclidean EMD' )
71- pl .plot (d_sinkhorn ,'+' ,label = 'Euclidean Sinkhorn' )
72- pl .plot (d_sinkhorn2 ,'+' ,label = 'Squared Euclidean Sinkhorn' )
70+ pl .plot (d_emd , label = 'Euclidean EMD' )
71+ pl .plot (d_emd2 , label = 'Squared Euclidean EMD' )
72+ pl .plot (d_sinkhorn , '+' , label = 'Euclidean Sinkhorn' )
73+ pl .plot (d_sinkhorn2 , '+' , label = 'Squared Euclidean Sinkhorn' )
7374pl .title ('EMD distances' )
74- pl .legend ()
75+ pl .legend ()
76+
77+ pl .show ()
0 commit comments