|
4 | 4 | OT mapping estimation for domain adaptation [8] |
5 | 5 | =============================================== |
6 | 6 |
|
7 | | -[8] M. Perrot, N. Courty, R. Flamary, A. Habrard, "Mapping estimation for |
8 | | - discrete optimal transport", Neural Information Processing Systems (NIPS), 2016. |
| 7 | +[8] M. Perrot, N. Courty, R. Flamary, A. Habrard, |
| 8 | + "Mapping estimation for discrete optimal transport", |
| 9 | + Neural Information Processing Systems (NIPS), 2016. |
9 | 10 | """ |
10 | 11 |
|
11 | 12 | import numpy as np |
12 | 13 | import matplotlib.pylab as pl |
13 | 14 | import ot |
14 | 15 |
|
15 | 16 |
|
16 | | - |
17 | 17 | #%% dataset generation |
18 | 18 |
|
19 | | -np.random.seed(0) # makes example reproducible |
| 19 | +np.random.seed(0) # makes example reproducible |
20 | 20 |
|
21 | | -n=100 # nb samples in source and target datasets |
22 | | -theta=2*np.pi/20 |
23 | | -nz=0.1 |
24 | | -xs,ys=ot.datasets.get_data_classif('gaussrot',n,nz=nz) |
25 | | -xt,yt=ot.datasets.get_data_classif('gaussrot',n,theta=theta,nz=nz) |
| 21 | +n = 100 # nb samples in source and target datasets |
| 22 | +theta = 2 * np.pi / 20 |
| 23 | +nz = 0.1 |
| 24 | +xs, ys = ot.datasets.get_data_classif('gaussrot', n, nz=nz) |
| 25 | +xt, yt = ot.datasets.get_data_classif('gaussrot', n, theta=theta, nz=nz) |
26 | 26 |
|
27 | 27 | # one of the target modes changes its variance (no linear mapping) |
28 | | -xt[yt==2]*=3 |
29 | | -xt=xt+4 |
| 28 | +xt[yt == 2] *= 3 |
| 29 | +xt = xt + 4 |
30 | 30 |
|
31 | 31 |
|
32 | 32 | #%% plot samples |
33 | 33 |
|
34 | | -pl.figure(1,(8,5)) |
| 34 | +pl.figure(1, (6.4, 3)) |
35 | 35 | pl.clf() |
36 | | - |
37 | | -pl.scatter(xs[:,0],xs[:,1],c=ys,marker='+',label='Source samples') |
38 | | -pl.scatter(xt[:,0],xt[:,1],c=yt,marker='o',label='Target samples') |
39 | | - |
| 36 | +pl.scatter(xs[:, 0], xs[:, 1], c=ys, marker='+', label='Source samples') |
| 37 | +pl.scatter(xt[:, 0], xt[:, 1], c=yt, marker='o', label='Target samples') |
40 | 38 | pl.legend(loc=0) |
41 | 39 | pl.title('Source and target distributions') |
42 | 40 |
|
43 | 41 |
|
44 | | - |
45 | 42 | #%% OT linear mapping estimation |
46 | 43 |
|
47 | | -eta=1e-8 # quadratic regularization for regression |
48 | | -mu=1e0 # weight of the OT linear term |
49 | | -bias=True # estimate a bias |
| 44 | +eta = 1e-8 # quadratic regularization for regression |
| 45 | +mu = 1e0 # weight of the OT linear term |
| 46 | +bias = True # estimate a bias |
50 | 47 |
|
51 | | -ot_mapping=ot.da.OTDA_mapping_linear() |
52 | | -ot_mapping.fit(xs,xt,mu=mu,eta=eta,bias=bias,numItermax = 20,verbose=True) |
| 48 | +ot_mapping = ot.da.OTDA_mapping_linear() |
| 49 | +ot_mapping.fit(xs, xt, mu=mu, eta=eta, bias=bias, numItermax=20, verbose=True) |
53 | 50 |
|
54 | | -xst=ot_mapping.predict(xs) # use the estimated mapping |
55 | | -xst0=ot_mapping.interp() # use barycentric mapping |
| 51 | +xst = ot_mapping.predict(xs) # use the estimated mapping |
| 52 | +xst0 = ot_mapping.interp() # use barycentric mapping |
56 | 53 |
|
57 | 54 |
|
58 | | -pl.figure(2,(10,7)) |
| 55 | +pl.figure(2) |
59 | 56 | pl.clf() |
60 | | -pl.subplot(2,2,1) |
61 | | -pl.scatter(xt[:,0],xt[:,1],c=yt,marker='o',label='Target samples',alpha=.3) |
62 | | -pl.scatter(xst0[:,0],xst0[:,1],c=ys,marker='+',label='barycentric mapping') |
| 57 | +pl.subplot(2, 2, 1) |
| 58 | +pl.scatter(xt[:, 0], xt[:, 1], c=yt, marker='o', |
| 59 | + label='Target samples', alpha=.3) |
| 60 | +pl.scatter(xst0[:, 0], xst0[:, 1], c=ys, |
| 61 | + marker='+', label='barycentric mapping') |
63 | 62 | pl.title("barycentric mapping") |
64 | 63 |
|
65 | | -pl.subplot(2,2,2) |
66 | | -pl.scatter(xt[:,0],xt[:,1],c=yt,marker='o',label='Target samples',alpha=.3) |
67 | | -pl.scatter(xst[:,0],xst[:,1],c=ys,marker='+',label='Learned mapping') |
| 64 | +pl.subplot(2, 2, 2) |
| 65 | +pl.scatter(xt[:, 0], xt[:, 1], c=yt, marker='o', |
| 66 | + label='Target samples', alpha=.3) |
| 67 | +pl.scatter(xst[:, 0], xst[:, 1], c=ys, marker='+', label='Learned mapping') |
68 | 68 | pl.title("Learned mapping") |
69 | | - |
70 | | - |
| 69 | +pl.tight_layout() |
71 | 70 |
|
72 | 71 | #%% Kernel mapping estimation |
73 | 72 |
|
74 | | -eta=1e-5 # quadratic regularization for regression |
75 | | -mu=1e-1 # weight of the OT linear term |
76 | | -bias=True # estimate a bias |
77 | | -sigma=1 # sigma bandwidth fot gaussian kernel |
| 73 | +eta = 1e-5 # quadratic regularization for regression |
| 74 | +mu = 1e-1 # weight of the OT linear term |
| 75 | +bias = True # estimate a bias |
| 76 | +sigma = 1  # sigma bandwidth for gaussian kernel |
78 | 77 |
|
79 | 78 |
|
80 | | -ot_mapping_kernel=ot.da.OTDA_mapping_kernel() |
81 | | -ot_mapping_kernel.fit(xs,xt,mu=mu,eta=eta,sigma=sigma,bias=bias,numItermax = 10,verbose=True) |
| 79 | +ot_mapping_kernel = ot.da.OTDA_mapping_kernel() |
| 80 | +ot_mapping_kernel.fit( |
| 81 | + xs, xt, mu=mu, eta=eta, sigma=sigma, bias=bias, numItermax=10, verbose=True) |
82 | 82 |
|
83 | | -xst_kernel=ot_mapping_kernel.predict(xs) # use the estimated mapping |
84 | | -xst0_kernel=ot_mapping_kernel.interp() # use barycentric mapping |
| 83 | +xst_kernel = ot_mapping_kernel.predict(xs) # use the estimated mapping |
| 84 | +xst0_kernel = ot_mapping_kernel.interp() # use barycentric mapping |
85 | 85 |
|
86 | 86 |
|
87 | 87 | #%% Plotting the mapped samples |
88 | 88 |
|
89 | | -pl.figure(2,(10,7)) |
| 89 | +pl.figure(2) |
90 | 90 | pl.clf() |
91 | | -pl.subplot(2,2,1) |
92 | | -pl.scatter(xt[:,0],xt[:,1],c=yt,marker='o',label='Target samples',alpha=.2) |
93 | | -pl.scatter(xst0[:,0],xst0[:,1],c=ys,marker='+',label='Mapped source samples') |
| 91 | +pl.subplot(2, 2, 1) |
| 92 | +pl.scatter(xt[:, 0], xt[:, 1], c=yt, marker='o', |
| 93 | + label='Target samples', alpha=.2) |
| 94 | +pl.scatter(xst0[:, 0], xst0[:, 1], c=ys, marker='+', |
| 95 | + label='Mapped source samples') |
94 | 96 | pl.title("Bary. mapping (linear)") |
95 | 97 | pl.legend(loc=0) |
96 | 98 |
|
97 | | -pl.subplot(2,2,2) |
98 | | -pl.scatter(xt[:,0],xt[:,1],c=yt,marker='o',label='Target samples',alpha=.2) |
99 | | -pl.scatter(xst[:,0],xst[:,1],c=ys,marker='+',label='Learned mapping') |
| 99 | +pl.subplot(2, 2, 2) |
| 100 | +pl.scatter(xt[:, 0], xt[:, 1], c=yt, marker='o', |
| 101 | + label='Target samples', alpha=.2) |
| 102 | +pl.scatter(xst[:, 0], xst[:, 1], c=ys, marker='+', label='Learned mapping') |
100 | 103 | pl.title("Estim. mapping (linear)") |
101 | 104 |
|
102 | | -pl.subplot(2,2,3) |
103 | | -pl.scatter(xt[:,0],xt[:,1],c=yt,marker='o',label='Target samples',alpha=.2) |
104 | | -pl.scatter(xst0_kernel[:,0],xst0_kernel[:,1],c=ys,marker='+',label='barycentric mapping') |
| 105 | +pl.subplot(2, 2, 3) |
| 106 | +pl.scatter(xt[:, 0], xt[:, 1], c=yt, marker='o', |
| 107 | + label='Target samples', alpha=.2) |
| 108 | +pl.scatter(xst0_kernel[:, 0], xst0_kernel[:, 1], c=ys, |
| 109 | + marker='+', label='barycentric mapping') |
105 | 110 | pl.title("Bary. mapping (kernel)") |
106 | 111 |
|
107 | | -pl.subplot(2,2,4) |
108 | | -pl.scatter(xt[:,0],xt[:,1],c=yt,marker='o',label='Target samples',alpha=.2) |
109 | | -pl.scatter(xst_kernel[:,0],xst_kernel[:,1],c=ys,marker='+',label='Learned mapping') |
| 112 | +pl.subplot(2, 2, 4) |
| 113 | +pl.scatter(xt[:, 0], xt[:, 1], c=yt, marker='o', |
| 114 | + label='Target samples', alpha=.2) |
| 115 | +pl.scatter(xst_kernel[:, 0], xst_kernel[:, 1], c=ys, |
| 116 | + marker='+', label='Learned mapping') |
110 | 117 | pl.title("Estim. mapping (kernel)") |
| 118 | +pl.tight_layout() |
| 119 | + |
| 120 | +pl.show() |
0 commit comments