Skip to content

Commit 90f5d5f

Browse files
author
ievred
committed
laplace emd+sinkhorn
1 parent 9200af5 commit 90f5d5f

File tree

3 files changed

+2806
-1
lines changed

3 files changed

+2806
-1
lines changed

examples/plot_otda_laplacian.py

Lines changed: 149 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,149 @@
1+
# -*- coding: utf-8 -*-
2+
"""
3+
========================
4+
OT for domain adaptation
5+
========================
6+
7+
This example introduces a domain adaptation in a 2D setting and OTDA
8+
approaches with Laplacian regularization.
9+
10+
"""
11+
12+
# Authors: Ievgen Redko <ievgen.redko@univ-st-etienne.fr>
13+
14+
# License: MIT License
15+
16+
import matplotlib.pylab as pl
17+
import ot
18+
19+
##############################################################################
20+
# Generate data
21+
# -------------
22+
23+
n_source_samples = 150
24+
n_target_samples = 150
25+
26+
Xs, ys = ot.datasets.make_data_classif('3gauss', n_source_samples)
27+
Xt, yt = ot.datasets.make_data_classif('3gauss2', n_target_samples)
28+
29+
30+
##############################################################################
31+
# Instantiate the different transport algorithms and fit them
32+
# -----------------------------------------------------------
33+
34+
# EMD Transport
35+
ot_emd = ot.da.EMDTransport()
36+
ot_emd.fit(Xs=Xs, Xt=Xt)
37+
38+
# Sinkhorn Transport
39+
ot_sinkhorn = ot.da.SinkhornTransport(reg_e=.5)
40+
ot_sinkhorn.fit(Xs=Xs, Xt=Xt)
41+
42+
# EMD Transport with Laplacian regularization
43+
ot_emd_laplace = ot.da.EMDLaplaceTransport(reg_lap=100, reg_src=1)
44+
ot_emd_laplace.fit(Xs=Xs, Xt=Xt)
45+
46+
# Sinkhorn Transport with Laplacian regularization
47+
ot_sinkhorn_laplace = ot.da.SinkhornLaplaceTransport(reg_e=.5, reg_lap=100, reg_src=1)
48+
ot_sinkhorn_laplace.fit(Xs=Xs, Xt=Xt)
49+
50+
# transport source samples onto target samples
51+
transp_Xs_emd = ot_emd.transform(Xs=Xs)
52+
transp_Xs_sinkhorn = ot_sinkhorn.transform(Xs=Xs)
53+
transp_Xs_emd_laplace = ot_emd_laplace.transform(Xs=Xs)
54+
transp_Xs_sinkhorn_laplace = ot_sinkhorn_laplace.transform(Xs=Xs)
55+
56+
##############################################################################
57+
# Fig 1 : plots source and target samples
58+
# ---------------------------------------
59+
60+
pl.figure(1, figsize=(10, 5))
61+
pl.subplot(1, 2, 1)
62+
pl.scatter(Xs[:, 0], Xs[:, 1], c=ys, marker='+', label='Source samples')
63+
pl.xticks([])
64+
pl.yticks([])
65+
pl.legend(loc=0)
66+
pl.title('Source samples')
67+
68+
pl.subplot(1, 2, 2)
69+
pl.scatter(Xt[:, 0], Xt[:, 1], c=yt, marker='o', label='Target samples')
70+
pl.xticks([])
71+
pl.yticks([])
72+
pl.legend(loc=0)
73+
pl.title('Target samples')
74+
pl.tight_layout()
75+
76+
77+
##############################################################################
78+
# Fig 2 : plot optimal couplings and transported samples
79+
# ------------------------------------------------------
80+
81+
param_img = {'interpolation': 'nearest'}
82+
83+
n_plots = 2
84+
85+
pl.figure(2, figsize=(15, 8))
86+
pl.subplot(2, 2*n_plots, 1)
87+
pl.imshow(ot_emd.coupling_, **param_img)
88+
pl.xticks([])
89+
pl.yticks([])
90+
pl.title('Optimal coupling\nEMDTransport')
91+
92+
pl.figure(2, figsize=(15, 8))
93+
pl.subplot(2, 2*n_plots, 2)
94+
pl.imshow(ot_sinkhorn.coupling_, **param_img)
95+
pl.xticks([])
96+
pl.yticks([])
97+
pl.title('Optimal coupling\nSinkhornTransport')
98+
99+
pl.subplot(2, 2*n_plots, 3)
100+
pl.imshow(ot_emd_laplace.coupling_, **param_img)
101+
pl.xticks([])
102+
pl.yticks([])
103+
pl.title('Optimal coupling\nEMDLaplaceTransport')
104+
105+
pl.subplot(2, 2*n_plots, 4)
106+
pl.imshow(ot_emd_laplace.coupling_, **param_img)
107+
pl.xticks([])
108+
pl.yticks([])
109+
pl.title('Optimal coupling\nSinkhornLaplaceTransport')
110+
111+
pl.subplot(2, 2*n_plots, 5)
112+
pl.scatter(Xt[:, 0], Xt[:, 1], c=yt, marker='o',
113+
label='Target samples', alpha=0.3)
114+
pl.scatter(transp_Xs_emd[:, 0], transp_Xs_emd[:, 1], c=ys,
115+
marker='+', label='Transp samples', s=30)
116+
pl.xticks([])
117+
pl.yticks([])
118+
pl.title('Transported samples\nEmdTransport')
119+
pl.legend(loc="lower left")
120+
121+
pl.subplot(2, 2*n_plots, 6)
122+
pl.scatter(Xt[:, 0], Xt[:, 1], c=yt, marker='o',
123+
label='Target samples', alpha=0.3)
124+
pl.scatter(transp_Xs_sinkhorn[:, 0], transp_Xs_sinkhorn[:, 1], c=ys,
125+
marker='+', label='Transp samples', s=30)
126+
pl.xticks([])
127+
pl.yticks([])
128+
pl.title('Transported samples\nSinkhornTransport')
129+
130+
pl.subplot(2, 2*n_plots, 7)
131+
pl.scatter(Xt[:, 0], Xt[:, 1], c=yt, marker='o',
132+
label='Target samples', alpha=0.3)
133+
pl.scatter(transp_Xs_emd_laplace[:, 0], transp_Xs_emd_laplace[:, 1], c=ys,
134+
marker='+', label='Transp samples', s=30)
135+
pl.xticks([])
136+
pl.yticks([])
137+
pl.title('Transported samples\nEMDLaplaceTransport')
138+
139+
pl.subplot(2, 2*n_plots, 8)
140+
pl.scatter(Xt[:, 0], Xt[:, 1], c=yt, marker='o',
141+
label='Target samples', alpha=0.3)
142+
pl.scatter(transp_Xs_sinkhorn_laplace[:, 0], transp_Xs_sinkhorn_laplace[:, 1], c=ys,
143+
marker='+', label='Transp samples', s=30)
144+
pl.xticks([])
145+
pl.yticks([])
146+
pl.title('Transported samples\nSinkhornLaplaceTransport')
147+
pl.tight_layout()
148+
149+
pl.show()

0 commit comments

Comments
 (0)