|
8 | 8 | """ |
9 | 9 |
|
10 | 10 | import numpy as np |
11 | | -import matplotlib.pylab as pl |
| 11 | +import matplotlib.pylab as plt |
12 | 12 | import ot |
13 | 13 |
|
14 | 14 | #%% parameters and data generation |
15 | 15 |
|
16 | | -n=50 # nb samples |
| 16 | +n = 50 # nb samples |
17 | 17 |
|
18 | | -mu_s=np.array([0,0]) |
19 | | -cov_s=np.array([[1,0],[0,1]]) |
| 18 | +mu_s = np.array([0, 0]) |
| 19 | +cov_s = np.array([[1, 0], [0, 1]]) |
20 | 20 |
|
21 | | -mu_t=np.array([4,4]) |
22 | | -cov_t=np.array([[1,-.8],[-.8,1]]) |
| 21 | +mu_t = np.array([4, 4]) |
| 22 | +cov_t = np.array([[1, -.8], [-.8, 1]]) |
23 | 23 |
|
24 | | -xs=ot.datasets.get_2D_samples_gauss(n,mu_s,cov_s) |
25 | | -xt=ot.datasets.get_2D_samples_gauss(n,mu_t,cov_t) |
| 24 | +xs = ot.datasets.get_2D_samples_gauss(n, mu_s, cov_s) |
| 25 | +xt = ot.datasets.get_2D_samples_gauss(n, mu_t, cov_t) |
26 | 26 |
|
27 | | -a,b = ot.unif(n),ot.unif(n) # uniform distribution on samples |
| 27 | +a, b = np.ones((n,)) / n, np.ones((n,)) / n # uniform distribution on samples |
28 | 28 |
|
29 | 29 | # loss matrix |
30 | | -M=ot.dist(xs,xt) |
31 | | -M/=M.max() |
| 30 | +M = ot.dist(xs, xt) |
| 31 | +M /= M.max() |
32 | 32 |
|
33 | 33 | #%% plot samples |
34 | 34 |
|
35 | | -pl.figure(1) |
36 | | -pl.plot(xs[:,0],xs[:,1],'+b',label='Source samples') |
37 | | -pl.plot(xt[:,0],xt[:,1],'xr',label='Target samples') |
38 | | -pl.legend(loc=0) |
39 | | -pl.title('Source and traget distributions') |
| 35 | +plt.figure(1) |
| 36 | +plt.plot(xs[:, 0], xs[:, 1], '+b', label='Source samples') |
| 37 | +plt.plot(xt[:, 0], xt[:, 1], 'xr', label='Target samples') |
| 38 | +plt.legend(loc=0) |
| 39 | +plt.title('Source and target distributions') |
40 | 40 |
|
41 | | -pl.figure(2) |
42 | | -pl.imshow(M,interpolation='nearest') |
43 | | -pl.title('Cost matrix M') |
| 41 | +plt.figure(2) |
| 42 | +plt.imshow(M, interpolation='nearest') |
| 43 | +plt.title('Cost matrix M') |
44 | 44 |
|
45 | 45 |
|
46 | 46 | #%% EMD |
47 | 47 |
|
48 | | -G0=ot.emd(a,b,M) |
| 48 | +G0 = ot.emd(a, b, M) |
49 | 49 |
|
50 | | -pl.figure(3) |
51 | | -pl.imshow(G0,interpolation='nearest') |
52 | | -pl.title('OT matrix G0') |
| 50 | +plt.figure(3) |
| 51 | +plt.imshow(G0, interpolation='nearest') |
| 52 | +plt.title('OT matrix G0') |
53 | 53 |
|
54 | | -pl.figure(4) |
55 | | -ot.plot.plot2D_samples_mat(xs,xt,G0,c=[.5,.5,1]) |
56 | | -pl.plot(xs[:,0],xs[:,1],'+b',label='Source samples') |
57 | | -pl.plot(xt[:,0],xt[:,1],'xr',label='Target samples') |
58 | | -pl.legend(loc=0) |
59 | | -pl.title('OT matrix with samples') |
| 54 | +plt.figure(4) |
| 55 | +ot.plot.plot2D_samples_mat(xs, xt, G0, c=[.5, .5, 1]) |
| 56 | +plt.plot(xs[:, 0], xs[:, 1], '+b', label='Source samples') |
| 57 | +plt.plot(xt[:, 0], xt[:, 1], 'xr', label='Target samples') |
| 58 | +plt.legend(loc=0) |
| 59 | +plt.title('OT matrix with samples') |
60 | 60 |
|
61 | 61 |
|
62 | 62 | #%% sinkhorn |
63 | 63 |
|
64 | 64 | # reg term |
65 | | -lambd=5e-4 |
| 65 | +lambd = 5e-4 |
66 | 66 |
|
67 | | -Gs=ot.sinkhorn(a,b,M,lambd) |
| 67 | +Gs = ot.sinkhorn(a, b, M, lambd) |
68 | 68 |
|
69 | | -pl.figure(5) |
70 | | -pl.imshow(Gs,interpolation='nearest') |
71 | | -pl.title('OT matrix sinkhorn') |
| 69 | +plt.figure(5) |
| 70 | +plt.imshow(Gs, interpolation='nearest') |
| 71 | +plt.title('OT matrix sinkhorn') |
72 | 72 |
|
73 | | -pl.figure(6) |
74 | | -ot.plot.plot2D_samples_mat(xs,xt,Gs,color=[.5,.5,1]) |
75 | | -pl.plot(xs[:,0],xs[:,1],'+b',label='Source samples') |
76 | | -pl.plot(xt[:,0],xt[:,1],'xr',label='Target samples') |
77 | | -pl.legend(loc=0) |
78 | | -pl.title('OT matrix Sinkhorn with samples') |
| 73 | +plt.figure(6) |
| 74 | +ot.plot.plot2D_samples_mat(xs, xt, Gs, color=[.5, .5, 1]) |
| 75 | +plt.plot(xs[:, 0], xs[:, 1], '+b', label='Source samples') |
| 76 | +plt.plot(xt[:, 0], xt[:, 1], 'xr', label='Target samples') |
| 77 | +plt.legend(loc=0) |
| 78 | +plt.title('OT matrix Sinkhorn with samples') |
| 79 | + |
| 80 | +plt.show() |
0 commit comments