|
4 | 4 | 2D Optimal transport for different metrics |
5 | 5 | ========================================== |
6 | 6 |
|
7 | | -Stole the figure idea from Fig. 1 and 2 in |
| 7 | +Stole the figure idea from Fig. 1 and 2 in |
8 | 8 | https://arxiv.org/pdf/1706.07650.pdf |
9 | 9 |
|
10 | 10 |
|
11 | 11 | @author: rflamary |
12 | 12 | """ |
13 | 13 |
|
14 | 14 | import numpy as np |
15 | | -import matplotlib.pylab as pl |
| 15 | +import matplotlib.pylab as plt |
16 | 16 | import ot |
17 | 17 |
|
18 | 18 | #%% parameters and data generation |
19 | 19 |
|
20 | 20 | for data in range(2): |
21 | 21 |
|
22 | 22 | if data: |
23 | | - n=20 # nb samples |
24 | | - xs=np.zeros((n,2)) |
25 | | - xs[:,0]=np.arange(n)+1 |
26 | | - xs[:,1]=(np.arange(n)+1)*-0.001 # to make it strictly convex... |
27 | | - |
28 | | - xt=np.zeros((n,2)) |
29 | | - xt[:,1]=np.arange(n)+1 |
| 23 | + n = 20 # nb samples |
| 24 | + xs = np.zeros((n, 2)) |
| 25 | + xs[:, 0] = np.arange(n) + 1 |
| 26 | + xs[:, 1] = (np.arange(n) + 1) * -0.001 # to make it strictly convex... |
| 27 | + |
| 28 | + xt = np.zeros((n, 2)) |
| 29 | + xt[:, 1] = np.arange(n) + 1 |
30 | 30 | else: |
31 | | - |
32 | | - n=50 # nb samples |
33 | | - xtot=np.zeros((n+1,2)) |
34 | | - xtot[:,0]=np.cos((np.arange(n+1)+1.0)*0.9/(n+2)*2*np.pi) |
35 | | - xtot[:,1]=np.sin((np.arange(n+1)+1.0)*0.9/(n+2)*2*np.pi) |
36 | | - |
37 | | - xs=xtot[:n,:] |
38 | | - xt=xtot[1:,:] |
39 | | - |
40 | | - |
41 | | - |
42 | | - a,b = ot.unif(n),ot.unif(n) # uniform distribution on samples |
43 | | - |
| 31 | + |
| 32 | + n = 50 # nb samples |
| 33 | + xtot = np.zeros((n + 1, 2)) |
| 34 | + xtot[:, 0] = np.cos( |
| 35 | + (np.arange(n + 1) + 1.0) * 0.9 / (n + 2) * 2 * np.pi) |
| 36 | + xtot[:, 1] = np.sin( |
| 37 | + (np.arange(n + 1) + 1.0) * 0.9 / (n + 2) * 2 * np.pi) |
| 38 | + |
| 39 | + xs = xtot[:n, :] |
| 40 | + xt = xtot[1:, :] |
| 41 | + |
| 42 | + a, b = ot.unif(n), ot.unif(n) # uniform distribution on samples |
| 43 | + |
44 | 44 | # loss matrix |
45 | | - M1=ot.dist(xs,xt,metric='euclidean') |
46 | | - M1/=M1.max() |
47 | | - |
| 45 | + M1 = ot.dist(xs, xt, metric='euclidean') |
| 46 | + M1 /= M1.max() |
| 47 | + |
48 | 48 | # loss matrix |
49 | | - M2=ot.dist(xs,xt,metric='sqeuclidean') |
50 | | - M2/=M2.max() |
51 | | - |
| 49 | + M2 = ot.dist(xs, xt, metric='sqeuclidean') |
| 50 | + M2 /= M2.max() |
| 51 | + |
52 | 52 | # loss matrix |
53 | | - Mp=np.sqrt(ot.dist(xs,xt,metric='euclidean')) |
54 | | - Mp/=Mp.max() |
55 | | - |
| 53 | + Mp = np.sqrt(ot.dist(xs, xt, metric='euclidean')) |
| 54 | + Mp /= Mp.max() |
| 55 | + |
56 | 56 | #%% plot samples |
57 | | - |
58 | | - pl.figure(1+3*data) |
59 | | - pl.clf() |
60 | | - pl.plot(xs[:,0],xs[:,1],'+b',label='Source samples') |
61 | | - pl.plot(xt[:,0],xt[:,1],'xr',label='Target samples') |
62 | | - pl.axis('equal') |
63 | | - pl.title('Source and traget distributions') |
64 | | - |
65 | | - pl.figure(2+3*data,(15,5)) |
66 | | - pl.subplot(1,3,1) |
67 | | - pl.imshow(M1,interpolation='nearest') |
68 | | - pl.title('Eucidean cost') |
69 | | - pl.subplot(1,3,2) |
70 | | - pl.imshow(M2,interpolation='nearest') |
71 | | - pl.title('Squared Euclidean cost') |
72 | | - |
73 | | - pl.subplot(1,3,3) |
74 | | - pl.imshow(Mp,interpolation='nearest') |
75 | | - pl.title('Sqrt Euclidean cost') |
| 57 | + |
| 58 | + plt.figure(1 + 3 * data, figsize=(7, 3)) |
| 59 | + plt.clf() |
| 60 | + plt.plot(xs[:, 0], xs[:, 1], '+b', label='Source samples') |
| 61 | + plt.plot(xt[:, 0], xt[:, 1], 'xr', label='Target samples') |
| 62 | + plt.axis('equal') |
| 63 | + plt.title('Source and traget distributions') |
| 64 | + |
| 65 | + plt.figure(2 + 3 * data, figsize=(7, 3)) |
| 66 | + |
| 67 | + plt.subplot(1, 3, 1) |
| 68 | + plt.imshow(M1, interpolation='nearest') |
| 69 | + plt.title('Euclidean cost') |
| 70 | + |
| 71 | + plt.subplot(1, 3, 2) |
| 72 | + plt.imshow(M2, interpolation='nearest') |
| 73 | + plt.title('Squared Euclidean cost') |
| 74 | + |
| 75 | + plt.subplot(1, 3, 3) |
| 76 | + plt.imshow(Mp, interpolation='nearest') |
| 77 | + plt.title('Sqrt Euclidean cost') |
| 78 | + plt.tight_layout() |
| 79 | + |
76 | 80 | #%% EMD |
77 | | - |
78 | | - G1=ot.emd(a,b,M1) |
79 | | - G2=ot.emd(a,b,M2) |
80 | | - Gp=ot.emd(a,b,Mp) |
81 | | - |
82 | | - pl.figure(3+3*data,(15,5)) |
83 | | - |
84 | | - pl.subplot(1,3,1) |
85 | | - ot.plot.plot2D_samples_mat(xs,xt,G1,c=[.5,.5,1]) |
86 | | - pl.plot(xs[:,0],xs[:,1],'+b',label='Source samples') |
87 | | - pl.plot(xt[:,0],xt[:,1],'xr',label='Target samples') |
88 | | - pl.axis('equal') |
89 | | - #pl.legend(loc=0) |
90 | | - pl.title('OT Euclidean') |
91 | | - |
92 | | - pl.subplot(1,3,2) |
93 | | - |
94 | | - ot.plot.plot2D_samples_mat(xs,xt,G2,c=[.5,.5,1]) |
95 | | - pl.plot(xs[:,0],xs[:,1],'+b',label='Source samples') |
96 | | - pl.plot(xt[:,0],xt[:,1],'xr',label='Target samples') |
97 | | - pl.axis('equal') |
98 | | - #pl.legend(loc=0) |
99 | | - pl.title('OT squared Euclidean') |
100 | | - |
101 | | - pl.subplot(1,3,3) |
102 | | - |
103 | | - ot.plot.plot2D_samples_mat(xs,xt,Gp,c=[.5,.5,1]) |
104 | | - pl.plot(xs[:,0],xs[:,1],'+b',label='Source samples') |
105 | | - pl.plot(xt[:,0],xt[:,1],'xr',label='Target samples') |
106 | | - pl.axis('equal') |
107 | | - #pl.legend(loc=0) |
108 | | - pl.title('OT sqrt Euclidean') |
| 81 | + G1 = ot.emd(a, b, M1) |
| 82 | + G2 = ot.emd(a, b, M2) |
| 83 | + Gp = ot.emd(a, b, Mp) |
| 84 | + |
| 85 | + plt.figure(3 + 3 * data, figsize=(7, 3)) |
| 86 | + |
| 87 | + plt.subplot(1, 3, 1) |
| 88 | + ot.plot.plot2D_samples_mat(xs, xt, G1, c=[.5, .5, 1]) |
| 89 | + plt.plot(xs[:, 0], xs[:, 1], '+b', label='Source samples') |
| 90 | + plt.plot(xt[:, 0], xt[:, 1], 'xr', label='Target samples') |
| 91 | + plt.axis('equal') |
| 92 | + # plt.legend(loc=0) |
| 93 | + plt.title('OT Euclidean') |
| 94 | + |
| 95 | + plt.subplot(1, 3, 2) |
| 96 | + ot.plot.plot2D_samples_mat(xs, xt, G2, c=[.5, .5, 1]) |
| 97 | + plt.plot(xs[:, 0], xs[:, 1], '+b', label='Source samples') |
| 98 | + plt.plot(xt[:, 0], xt[:, 1], 'xr', label='Target samples') |
| 99 | + plt.axis('equal') |
| 100 | + # plt.legend(loc=0) |
| 101 | + plt.title('OT squared Euclidean') |
| 102 | + |
| 103 | + plt.subplot(1, 3, 3) |
| 104 | + ot.plot.plot2D_samples_mat(xs, xt, Gp, c=[.5, .5, 1]) |
| 105 | + plt.plot(xs[:, 0], xs[:, 1], '+b', label='Source samples') |
| 106 | + plt.plot(xt[:, 0], xt[:, 1], 'xr', label='Target samples') |
| 107 | + plt.axis('equal') |
| 108 | + # plt.legend(loc=0) |
| 109 | + plt.title('OT sqrt Euclidean') |
| 110 | + plt.tight_layout() |
| 111 | + |
| 112 | +plt.show() |
0 commit comments