|
5 | 5 | ======================== |
6 | 6 |
|
7 | 7 | This example introduces a domain adaptation in a 2D setting and OTDA |
8 | | -approaches with Laplacian regularization. |
| 8 | +approache with Laplacian regularization. |
9 | 9 |
|
10 | 10 | """ |
11 | 11 |
|
|
36 | 36 | ot_emd.fit(Xs=Xs, Xt=Xt) |
37 | 37 |
|
38 | 38 | # Sinkhorn Transport |
39 | | -ot_sinkhorn = ot.da.SinkhornTransport(reg_e=.5) |
| 39 | +ot_sinkhorn = ot.da.SinkhornTransport(reg_e=.01) |
40 | 40 | ot_sinkhorn.fit(Xs=Xs, Xt=Xt) |
41 | 41 |
|
42 | 42 | # EMD Transport with Laplacian regularization |
43 | 43 | ot_emd_laplace = ot.da.EMDLaplaceTransport(reg_lap=100, reg_src=1) |
44 | 44 | ot_emd_laplace.fit(Xs=Xs, Xt=Xt) |
45 | 45 |
|
46 | | -# Sinkhorn Transport with Laplacian regularization |
47 | | -ot_sinkhorn_laplace = ot.da.SinkhornLaplaceTransport(reg_e=.5, reg_lap=100, reg_src=1) |
48 | | -ot_sinkhorn_laplace.fit(Xs=Xs, Xt=Xt) |
49 | | - |
50 | 46 | # transport source samples onto target samples |
51 | 47 | transp_Xs_emd = ot_emd.transform(Xs=Xs) |
52 | 48 | transp_Xs_sinkhorn = ot_sinkhorn.transform(Xs=Xs) |
53 | 49 | transp_Xs_emd_laplace = ot_emd_laplace.transform(Xs=Xs) |
54 | | -transp_Xs_sinkhorn_laplace = ot_sinkhorn_laplace.transform(Xs=Xs) |
55 | 50 |
|
56 | 51 | ############################################################################## |
57 | 52 | # Fig 1 : plots source and target samples |
|
80 | 75 |
|
81 | 76 | param_img = {'interpolation': 'nearest'} |
82 | 77 |
|
83 | | -n_plots = 2 |
84 | | - |
85 | 78 | pl.figure(2, figsize=(15, 8)) |
86 | | -pl.subplot(2, 2*n_plots, 1) |
| 79 | +pl.subplot(2, 3, 1) |
87 | 80 | pl.imshow(ot_emd.coupling_, **param_img) |
88 | 81 | pl.xticks([]) |
89 | 82 | pl.yticks([]) |
90 | 83 | pl.title('Optimal coupling\nEMDTransport') |
91 | 84 |
|
92 | 85 | pl.figure(2, figsize=(15, 8)) |
93 | | -pl.subplot(2, 2*n_plots, 2) |
| 86 | +pl.subplot(2, 3, 2) |
94 | 87 | pl.imshow(ot_sinkhorn.coupling_, **param_img) |
95 | 88 | pl.xticks([]) |
96 | 89 | pl.yticks([]) |
97 | 90 | pl.title('Optimal coupling\nSinkhornTransport') |
98 | 91 |
|
99 | | -pl.subplot(2, 2*n_plots, 3) |
| 92 | +pl.subplot(2, 3, 3) |
100 | 93 | pl.imshow(ot_emd_laplace.coupling_, **param_img) |
101 | 94 | pl.xticks([]) |
102 | 95 | pl.yticks([]) |
103 | 96 | pl.title('Optimal coupling\nEMDLaplaceTransport') |
104 | 97 |
|
105 | | -pl.subplot(2, 2*n_plots, 4) |
106 | | -pl.imshow(ot_emd_laplace.coupling_, **param_img) |
107 | | -pl.xticks([]) |
108 | | -pl.yticks([]) |
109 | | -pl.title('Optimal coupling\nSinkhornLaplaceTransport') |
110 | | - |
111 | | -pl.subplot(2, 2*n_plots, 5) |
| 98 | +pl.subplot(2, 3, 4) |
112 | 99 | pl.scatter(Xt[:, 0], Xt[:, 1], c=yt, marker='o', |
113 | 100 | label='Target samples', alpha=0.3) |
114 | 101 | pl.scatter(transp_Xs_emd[:, 0], transp_Xs_emd[:, 1], c=ys, |
|
118 | 105 | pl.title('Transported samples\nEmdTransport') |
119 | 106 | pl.legend(loc="lower left") |
120 | 107 |
|
121 | | -pl.subplot(2, 2*n_plots, 6) |
| 108 | +pl.subplot(2, 3, 5) |
122 | 109 | pl.scatter(Xt[:, 0], Xt[:, 1], c=yt, marker='o', |
123 | 110 | label='Target samples', alpha=0.3) |
124 | 111 | pl.scatter(transp_Xs_sinkhorn[:, 0], transp_Xs_sinkhorn[:, 1], c=ys, |
|
127 | 114 | pl.yticks([]) |
128 | 115 | pl.title('Transported samples\nSinkhornTransport') |
129 | 116 |
|
130 | | -pl.subplot(2, 2*n_plots, 7) |
| 117 | +pl.subplot(2, 3, 6) |
131 | 118 | pl.scatter(Xt[:, 0], Xt[:, 1], c=yt, marker='o', |
132 | 119 | label='Target samples', alpha=0.3) |
133 | 120 | pl.scatter(transp_Xs_emd_laplace[:, 0], transp_Xs_emd_laplace[:, 1], c=ys, |
134 | 121 | marker='+', label='Transp samples', s=30) |
135 | 122 | pl.xticks([]) |
136 | 123 | pl.yticks([]) |
137 | 124 | pl.title('Transported samples\nEMDLaplaceTransport') |
138 | | - |
139 | | -pl.subplot(2, 2*n_plots, 8) |
140 | | -pl.scatter(Xt[:, 0], Xt[:, 1], c=yt, marker='o', |
141 | | - label='Target samples', alpha=0.3) |
142 | | -pl.scatter(transp_Xs_sinkhorn_laplace[:, 0], transp_Xs_sinkhorn_laplace[:, 1], c=ys, |
143 | | - marker='+', label='Transp samples', s=30) |
144 | | -pl.xticks([]) |
145 | | -pl.yticks([]) |
146 | | -pl.title('Transported samples\nSinkhornLaplaceTransport') |
147 | 125 | pl.tight_layout() |
148 | 126 |
|
149 | 127 | pl.show() |
0 commit comments