Skip to content

Commit 98b68f1

Browse files
author
ievred
committed
autopep+remove sinkhorn+add simtype
1 parent fa99199 commit 98b68f1

File tree

4 files changed

+28
-351
lines changed

4 files changed

+28
-351
lines changed

examples/plot_otda_laplacian.py

Lines changed: 8 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
========================
66
77
This example introduces a domain adaptation in a 2D setting and OTDA
8-
approaches with Laplacian regularization.
8+
approache with Laplacian regularization.
99
1010
"""
1111

@@ -36,22 +36,17 @@
3636
ot_emd.fit(Xs=Xs, Xt=Xt)
3737

3838
# Sinkhorn Transport
39-
ot_sinkhorn = ot.da.SinkhornTransport(reg_e=.5)
39+
ot_sinkhorn = ot.da.SinkhornTransport(reg_e=.01)
4040
ot_sinkhorn.fit(Xs=Xs, Xt=Xt)
4141

4242
# EMD Transport with Laplacian regularization
4343
ot_emd_laplace = ot.da.EMDLaplaceTransport(reg_lap=100, reg_src=1)
4444
ot_emd_laplace.fit(Xs=Xs, Xt=Xt)
4545

46-
# Sinkhorn Transport with Laplacian regularization
47-
ot_sinkhorn_laplace = ot.da.SinkhornLaplaceTransport(reg_e=.5, reg_lap=100, reg_src=1)
48-
ot_sinkhorn_laplace.fit(Xs=Xs, Xt=Xt)
49-
5046
# transport source samples onto target samples
5147
transp_Xs_emd = ot_emd.transform(Xs=Xs)
5248
transp_Xs_sinkhorn = ot_sinkhorn.transform(Xs=Xs)
5349
transp_Xs_emd_laplace = ot_emd_laplace.transform(Xs=Xs)
54-
transp_Xs_sinkhorn_laplace = ot_sinkhorn_laplace.transform(Xs=Xs)
5550

5651
##############################################################################
5752
# Fig 1 : plots source and target samples
@@ -80,35 +75,27 @@
8075

8176
param_img = {'interpolation': 'nearest'}
8277

83-
n_plots = 2
84-
8578
pl.figure(2, figsize=(15, 8))
86-
pl.subplot(2, 2*n_plots, 1)
79+
pl.subplot(2, 3, 1)
8780
pl.imshow(ot_emd.coupling_, **param_img)
8881
pl.xticks([])
8982
pl.yticks([])
9083
pl.title('Optimal coupling\nEMDTransport')
9184

9285
pl.figure(2, figsize=(15, 8))
93-
pl.subplot(2, 2*n_plots, 2)
86+
pl.subplot(2, 3, 2)
9487
pl.imshow(ot_sinkhorn.coupling_, **param_img)
9588
pl.xticks([])
9689
pl.yticks([])
9790
pl.title('Optimal coupling\nSinkhornTransport')
9891

99-
pl.subplot(2, 2*n_plots, 3)
92+
pl.subplot(2, 3, 3)
10093
pl.imshow(ot_emd_laplace.coupling_, **param_img)
10194
pl.xticks([])
10295
pl.yticks([])
10396
pl.title('Optimal coupling\nEMDLaplaceTransport')
10497

105-
pl.subplot(2, 2*n_plots, 4)
106-
pl.imshow(ot_emd_laplace.coupling_, **param_img)
107-
pl.xticks([])
108-
pl.yticks([])
109-
pl.title('Optimal coupling\nSinkhornLaplaceTransport')
110-
111-
pl.subplot(2, 2*n_plots, 5)
98+
pl.subplot(2, 3, 4)
11299
pl.scatter(Xt[:, 0], Xt[:, 1], c=yt, marker='o',
113100
label='Target samples', alpha=0.3)
114101
pl.scatter(transp_Xs_emd[:, 0], transp_Xs_emd[:, 1], c=ys,
@@ -118,7 +105,7 @@
118105
pl.title('Transported samples\nEmdTransport')
119106
pl.legend(loc="lower left")
120107

121-
pl.subplot(2, 2*n_plots, 6)
108+
pl.subplot(2, 3, 5)
122109
pl.scatter(Xt[:, 0], Xt[:, 1], c=yt, marker='o',
123110
label='Target samples', alpha=0.3)
124111
pl.scatter(transp_Xs_sinkhorn[:, 0], transp_Xs_sinkhorn[:, 1], c=ys,
@@ -127,23 +114,14 @@
127114
pl.yticks([])
128115
pl.title('Transported samples\nSinkhornTransport')
129116

130-
pl.subplot(2, 2*n_plots, 7)
117+
pl.subplot(2, 3, 6)
131118
pl.scatter(Xt[:, 0], Xt[:, 1], c=yt, marker='o',
132119
label='Target samples', alpha=0.3)
133120
pl.scatter(transp_Xs_emd_laplace[:, 0], transp_Xs_emd_laplace[:, 1], c=ys,
134121
marker='+', label='Transp samples', s=30)
135122
pl.xticks([])
136123
pl.yticks([])
137124
pl.title('Transported samples\nEMDLaplaceTransport')
138-
139-
pl.subplot(2, 2*n_plots, 8)
140-
pl.scatter(Xt[:, 0], Xt[:, 1], c=yt, marker='o',
141-
label='Target samples', alpha=0.3)
142-
pl.scatter(transp_Xs_sinkhorn_laplace[:, 0], transp_Xs_sinkhorn_laplace[:, 1], c=ys,
143-
marker='+', label='Transp samples', s=30)
144-
pl.xticks([])
145-
pl.yticks([])
146-
pl.title('Transported samples\nSinkhornLaplaceTransport')
147125
pl.tight_layout()
148126

149127
pl.show()

0 commit comments

Comments
 (0)