Skip to content

Commit c6cb1cd

Browse files
committed
pimp + pep8 on plot_OT_L1_vs_L2
1 parent 35b25ad commit c6cb1cd

File tree

1 file changed

+86
-82
lines changed

1 file changed

+86
-82
lines changed

examples/plot_OT_L1_vs_L2.py

Lines changed: 86 additions & 82 deletions
Original file line numberDiff line numberDiff line change
@@ -4,105 +4,109 @@
44
2D Optimal transport for different metrics
55
==========================================
66
7-
Stole the figure idea from Fig. 1 and 2 in
7+
Stole the figure idea from Fig. 1 and 2 in
88
https://arxiv.org/pdf/1706.07650.pdf
99
1010
1111
@author: rflamary
1212
"""
1313

1414
import numpy as np
15-
import matplotlib.pylab as pl
15+
import matplotlib.pylab as plt
1616
import ot
1717

1818
#%% parameters and data generation
1919

2020
for data in range(2):
2121

2222
if data:
23-
n=20 # nb samples
24-
xs=np.zeros((n,2))
25-
xs[:,0]=np.arange(n)+1
26-
xs[:,1]=(np.arange(n)+1)*-0.001 # to make it strictly convex...
27-
28-
xt=np.zeros((n,2))
29-
xt[:,1]=np.arange(n)+1
23+
n = 20 # nb samples
24+
xs = np.zeros((n, 2))
25+
xs[:, 0] = np.arange(n) + 1
26+
xs[:, 1] = (np.arange(n) + 1) * -0.001 # to make it strictly convex...
27+
28+
xt = np.zeros((n, 2))
29+
xt[:, 1] = np.arange(n) + 1
3030
else:
31-
32-
n=50 # nb samples
33-
xtot=np.zeros((n+1,2))
34-
xtot[:,0]=np.cos((np.arange(n+1)+1.0)*0.9/(n+2)*2*np.pi)
35-
xtot[:,1]=np.sin((np.arange(n+1)+1.0)*0.9/(n+2)*2*np.pi)
36-
37-
xs=xtot[:n,:]
38-
xt=xtot[1:,:]
39-
40-
41-
42-
a,b = ot.unif(n),ot.unif(n) # uniform distribution on samples
43-
31+
32+
n = 50 # nb samples
33+
xtot = np.zeros((n + 1, 2))
34+
xtot[:, 0] = np.cos(
35+
(np.arange(n + 1) + 1.0) * 0.9 / (n + 2) * 2 * np.pi)
36+
xtot[:, 1] = np.sin(
37+
(np.arange(n + 1) + 1.0) * 0.9 / (n + 2) * 2 * np.pi)
38+
39+
xs = xtot[:n, :]
40+
xt = xtot[1:, :]
41+
42+
a, b = ot.unif(n), ot.unif(n) # uniform distribution on samples
43+
4444
# loss matrix
45-
M1=ot.dist(xs,xt,metric='euclidean')
46-
M1/=M1.max()
47-
45+
M1 = ot.dist(xs, xt, metric='euclidean')
46+
M1 /= M1.max()
47+
4848
# loss matrix
49-
M2=ot.dist(xs,xt,metric='sqeuclidean')
50-
M2/=M2.max()
51-
49+
M2 = ot.dist(xs, xt, metric='sqeuclidean')
50+
M2 /= M2.max()
51+
5252
# loss matrix
53-
Mp=np.sqrt(ot.dist(xs,xt,metric='euclidean'))
54-
Mp/=Mp.max()
55-
53+
Mp = np.sqrt(ot.dist(xs, xt, metric='euclidean'))
54+
Mp /= Mp.max()
55+
5656
#%% plot samples
57-
58-
pl.figure(1+3*data)
59-
pl.clf()
60-
pl.plot(xs[:,0],xs[:,1],'+b',label='Source samples')
61-
pl.plot(xt[:,0],xt[:,1],'xr',label='Target samples')
62-
pl.axis('equal')
63-
pl.title('Source and traget distributions')
64-
65-
pl.figure(2+3*data,(15,5))
66-
pl.subplot(1,3,1)
67-
pl.imshow(M1,interpolation='nearest')
68-
pl.title('Eucidean cost')
69-
pl.subplot(1,3,2)
70-
pl.imshow(M2,interpolation='nearest')
71-
pl.title('Squared Euclidean cost')
72-
73-
pl.subplot(1,3,3)
74-
pl.imshow(Mp,interpolation='nearest')
75-
pl.title('Sqrt Euclidean cost')
57+
58+
plt.figure(1 + 3 * data, figsize=(7, 3))
59+
plt.clf()
60+
plt.plot(xs[:, 0], xs[:, 1], '+b', label='Source samples')
61+
plt.plot(xt[:, 0], xt[:, 1], 'xr', label='Target samples')
62+
plt.axis('equal')
63+
plt.title('Source and traget distributions')
64+
65+
plt.figure(2 + 3 * data, figsize=(7, 3))
66+
67+
plt.subplot(1, 3, 1)
68+
plt.imshow(M1, interpolation='nearest')
69+
plt.title('Euclidean cost')
70+
71+
plt.subplot(1, 3, 2)
72+
plt.imshow(M2, interpolation='nearest')
73+
plt.title('Squared Euclidean cost')
74+
75+
plt.subplot(1, 3, 3)
76+
plt.imshow(Mp, interpolation='nearest')
77+
plt.title('Sqrt Euclidean cost')
78+
plt.tight_layout()
79+
7680
#%% EMD
77-
78-
G1=ot.emd(a,b,M1)
79-
G2=ot.emd(a,b,M2)
80-
Gp=ot.emd(a,b,Mp)
81-
82-
pl.figure(3+3*data,(15,5))
83-
84-
pl.subplot(1,3,1)
85-
ot.plot.plot2D_samples_mat(xs,xt,G1,c=[.5,.5,1])
86-
pl.plot(xs[:,0],xs[:,1],'+b',label='Source samples')
87-
pl.plot(xt[:,0],xt[:,1],'xr',label='Target samples')
88-
pl.axis('equal')
89-
#pl.legend(loc=0)
90-
pl.title('OT Euclidean')
91-
92-
pl.subplot(1,3,2)
93-
94-
ot.plot.plot2D_samples_mat(xs,xt,G2,c=[.5,.5,1])
95-
pl.plot(xs[:,0],xs[:,1],'+b',label='Source samples')
96-
pl.plot(xt[:,0],xt[:,1],'xr',label='Target samples')
97-
pl.axis('equal')
98-
#pl.legend(loc=0)
99-
pl.title('OT squared Euclidean')
100-
101-
pl.subplot(1,3,3)
102-
103-
ot.plot.plot2D_samples_mat(xs,xt,Gp,c=[.5,.5,1])
104-
pl.plot(xs[:,0],xs[:,1],'+b',label='Source samples')
105-
pl.plot(xt[:,0],xt[:,1],'xr',label='Target samples')
106-
pl.axis('equal')
107-
#pl.legend(loc=0)
108-
pl.title('OT sqrt Euclidean')
81+
G1 = ot.emd(a, b, M1)
82+
G2 = ot.emd(a, b, M2)
83+
Gp = ot.emd(a, b, Mp)
84+
85+
plt.figure(3 + 3 * data, figsize=(7, 3))
86+
87+
plt.subplot(1, 3, 1)
88+
ot.plot.plot2D_samples_mat(xs, xt, G1, c=[.5, .5, 1])
89+
plt.plot(xs[:, 0], xs[:, 1], '+b', label='Source samples')
90+
plt.plot(xt[:, 0], xt[:, 1], 'xr', label='Target samples')
91+
plt.axis('equal')
92+
# plt.legend(loc=0)
93+
plt.title('OT Euclidean')
94+
95+
plt.subplot(1, 3, 2)
96+
ot.plot.plot2D_samples_mat(xs, xt, G2, c=[.5, .5, 1])
97+
plt.plot(xs[:, 0], xs[:, 1], '+b', label='Source samples')
98+
plt.plot(xt[:, 0], xt[:, 1], 'xr', label='Target samples')
99+
plt.axis('equal')
100+
# plt.legend(loc=0)
101+
plt.title('OT squared Euclidean')
102+
103+
plt.subplot(1, 3, 3)
104+
ot.plot.plot2D_samples_mat(xs, xt, Gp, c=[.5, .5, 1])
105+
plt.plot(xs[:, 0], xs[:, 1], '+b', label='Source samples')
106+
plt.plot(xt[:, 0], xt[:, 1], 'xr', label='Target samples')
107+
plt.axis('equal')
108+
# plt.legend(loc=0)
109+
plt.title('OT sqrt Euclidean')
110+
plt.tight_layout()
111+
112+
plt.show()

0 commit comments

Comments
 (0)