Skip to content

Commit de59036

Browse files
committed
do plot_compute_emd
1 parent 75c988f commit de59036

File tree

1 file changed

+31
-28
lines changed

1 file changed

+31
-28
lines changed

examples/plot_compute_emd.py

Lines changed: 31 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -15,60 +15,63 @@
1515

1616
#%% parameters
1717

18-
n=100 # nb bins
19-
n_target=50 # nb target distributions
18+
n = 100 # nb bins
19+
n_target = 50 # nb target distributions
2020

2121

2222
# bin positions
23-
x=np.arange(n,dtype=np.float64)
23+
x = np.arange(n, dtype=np.float64)
2424

25-
lst_m=np.linspace(20,90,n_target)
25+
lst_m = np.linspace(20, 90, n_target)
2626

2727
# Gaussian distributions
28-
a=gauss(n,m=20,s=5) # m= mean, s= std
28+
a = gauss(n, m=20, s=5) # m= mean, s= std
2929

30-
B=np.zeros((n,n_target))
30+
B = np.zeros((n, n_target))
3131

32-
for i,m in enumerate(lst_m):
33-
B[:,i]=gauss(n,m=m,s=5)
32+
for i, m in enumerate(lst_m):
33+
B[:, i] = gauss(n, m=m, s=5)
3434

3535
# loss matrix and normalization
36-
M=ot.dist(x.reshape((n,1)),x.reshape((n,1)),'euclidean')
37-
M/=M.max()
38-
M2=ot.dist(x.reshape((n,1)),x.reshape((n,1)),'sqeuclidean')
39-
M2/=M2.max()
36+
M = ot.dist(x.reshape((n, 1)), x.reshape((n, 1)), 'euclidean')
37+
M /= M.max()
38+
M2 = ot.dist(x.reshape((n, 1)), x.reshape((n, 1)), 'sqeuclidean')
39+
M2 /= M2.max()
4040
#%% plot the distributions
4141

4242
pl.figure(1)
43-
pl.subplot(2,1,1)
44-
pl.plot(x,a,'b',label='Source distribution')
43+
pl.subplot(2, 1, 1)
44+
pl.plot(x, a, 'b', label='Source distribution')
4545
pl.title('Source distribution')
46-
pl.subplot(2,1,2)
47-
pl.plot(x,B,label='Target distributions')
46+
pl.subplot(2, 1, 2)
47+
pl.plot(x, B, label='Target distributions')
4848
pl.title('Target distributions')
49+
pl.tight_layout()
4950

5051
#%% Compute and plot distributions and loss matrix
5152

52-
d_emd=ot.emd2(a,B,M) # direct computation of EMD
53-
d_emd2=ot.emd2(a,B,M2) # direct computation of EMD with loss M3
53+
d_emd = ot.emd2(a, B, M) # direct computation of EMD
54+
d_emd2 = ot.emd2(a, B, M2) # direct computation of EMD with loss M3
5455

5556

5657
pl.figure(2)
57-
pl.plot(d_emd,label='Euclidean EMD')
58-
pl.plot(d_emd2,label='Squared Euclidean EMD')
58+
pl.plot(d_emd, label='Euclidean EMD')
59+
pl.plot(d_emd2, label='Squared Euclidean EMD')
5960
pl.title('EMD distances')
6061
pl.legend()
6162

6263
#%%
63-
reg=1e-2
64-
d_sinkhorn=ot.sinkhorn2(a,B,M,reg)
65-
d_sinkhorn2=ot.sinkhorn2(a,B,M2,reg)
64+
reg = 1e-2
65+
d_sinkhorn = ot.sinkhorn2(a, B, M, reg)
66+
d_sinkhorn2 = ot.sinkhorn2(a, B, M2, reg)
6667

6768
pl.figure(2)
6869
pl.clf()
69-
pl.plot(d_emd,label='Euclidean EMD')
70-
pl.plot(d_emd2,label='Squared Euclidean EMD')
71-
pl.plot(d_sinkhorn,'+',label='Euclidean Sinkhorn')
72-
pl.plot(d_sinkhorn2,'+',label='Squared Euclidean Sinkhorn')
70+
pl.plot(d_emd, label='Euclidean EMD')
71+
pl.plot(d_emd2, label='Squared Euclidean EMD')
72+
pl.plot(d_sinkhorn, '+', label='Euclidean Sinkhorn')
73+
pl.plot(d_sinkhorn2, '+', label='Squared Euclidean Sinkhorn')
7374
pl.title('EMD distances')
74-
pl.legend()
75+
pl.legend()
76+
77+
pl.show()

0 commit comments

Comments
 (0)