Skip to content

Commit 8239423

Browse files
committed
more
1 parent 25ef32f commit 8239423

File tree

1 file changed

+64
-54
lines changed

1 file changed

+64
-54
lines changed

examples/plot_OTDA_mapping.py

Lines changed: 64 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -4,107 +4,117 @@
44
OT mapping estimation for domain adaptation [8]
55
===============================================
66
7-
[8] M. Perrot, N. Courty, R. Flamary, A. Habrard, "Mapping estimation for
8-
discrete optimal transport", Neural Information Processing Systems (NIPS), 2016.
7+
[8] M. Perrot, N. Courty, R. Flamary, A. Habrard,
8+
"Mapping estimation for discrete optimal transport",
9+
Neural Information Processing Systems (NIPS), 2016.
910
"""
1011

1112
import numpy as np
1213
import matplotlib.pylab as pl
1314
import ot
1415

1516

16-
1717
#%% dataset generation
1818

19-
np.random.seed(0) # makes example reproducible
19+
np.random.seed(0) # makes example reproducible
2020

21-
n=100 # nb samples in source and target datasets
22-
theta=2*np.pi/20
23-
nz=0.1
24-
xs,ys=ot.datasets.get_data_classif('gaussrot',n,nz=nz)
25-
xt,yt=ot.datasets.get_data_classif('gaussrot',n,theta=theta,nz=nz)
21+
n = 100 # nb samples in source and target datasets
22+
theta = 2 * np.pi / 20
23+
nz = 0.1
24+
xs, ys = ot.datasets.get_data_classif('gaussrot', n, nz=nz)
25+
xt, yt = ot.datasets.get_data_classif('gaussrot', n, theta=theta, nz=nz)
2626

2727
# one of the target mode changes its variance (no linear mapping)
28-
xt[yt==2]*=3
29-
xt=xt+4
28+
xt[yt == 2] *= 3
29+
xt = xt + 4
3030

3131

3232
#%% plot samples
3333

34-
pl.figure(1,(8,5))
34+
pl.figure(1, (6.4, 3))
3535
pl.clf()
36-
37-
pl.scatter(xs[:,0],xs[:,1],c=ys,marker='+',label='Source samples')
38-
pl.scatter(xt[:,0],xt[:,1],c=yt,marker='o',label='Target samples')
39-
36+
pl.scatter(xs[:, 0], xs[:, 1], c=ys, marker='+', label='Source samples')
37+
pl.scatter(xt[:, 0], xt[:, 1], c=yt, marker='o', label='Target samples')
4038
pl.legend(loc=0)
4139
pl.title('Source and target distributions')
4240

4341

44-
4542
#%% OT linear mapping estimation
4643

47-
eta=1e-8 # quadratic regularization for regression
48-
mu=1e0 # weight of the OT linear term
49-
bias=True # estimate a bias
44+
eta = 1e-8 # quadratic regularization for regression
45+
mu = 1e0 # weight of the OT linear term
46+
bias = True # estimate a bias
5047

51-
ot_mapping=ot.da.OTDA_mapping_linear()
52-
ot_mapping.fit(xs,xt,mu=mu,eta=eta,bias=bias,numItermax = 20,verbose=True)
48+
ot_mapping = ot.da.OTDA_mapping_linear()
49+
ot_mapping.fit(xs, xt, mu=mu, eta=eta, bias=bias, numItermax=20, verbose=True)
5350

54-
xst=ot_mapping.predict(xs) # use the estimated mapping
55-
xst0=ot_mapping.interp() # use barycentric mapping
51+
xst = ot_mapping.predict(xs) # use the estimated mapping
52+
xst0 = ot_mapping.interp() # use barycentric mapping
5653

5754

58-
pl.figure(2,(10,7))
55+
pl.figure(2)
5956
pl.clf()
60-
pl.subplot(2,2,1)
61-
pl.scatter(xt[:,0],xt[:,1],c=yt,marker='o',label='Target samples',alpha=.3)
62-
pl.scatter(xst0[:,0],xst0[:,1],c=ys,marker='+',label='barycentric mapping')
57+
pl.subplot(2, 2, 1)
58+
pl.scatter(xt[:, 0], xt[:, 1], c=yt, marker='o',
59+
label='Target samples', alpha=.3)
60+
pl.scatter(xst0[:, 0], xst0[:, 1], c=ys,
61+
marker='+', label='barycentric mapping')
6362
pl.title("barycentric mapping")
6463

65-
pl.subplot(2,2,2)
66-
pl.scatter(xt[:,0],xt[:,1],c=yt,marker='o',label='Target samples',alpha=.3)
67-
pl.scatter(xst[:,0],xst[:,1],c=ys,marker='+',label='Learned mapping')
64+
pl.subplot(2, 2, 2)
65+
pl.scatter(xt[:, 0], xt[:, 1], c=yt, marker='o',
66+
label='Target samples', alpha=.3)
67+
pl.scatter(xst[:, 0], xst[:, 1], c=ys, marker='+', label='Learned mapping')
6868
pl.title("Learned mapping")
69-
70-
69+
pl.tight_layout()
7170

7271
#%% Kernel mapping estimation
7372

74-
eta=1e-5 # quadratic regularization for regression
75-
mu=1e-1 # weight of the OT linear term
76-
bias=True # estimate a bias
77-
sigma=1 # sigma bandwidth fot gaussian kernel
73+
eta = 1e-5 # quadratic regularization for regression
74+
mu = 1e-1 # weight of the OT linear term
75+
bias = True # estimate a bias
76+
sigma = 1 # sigma bandwidth fot gaussian kernel
7877

7978

80-
ot_mapping_kernel=ot.da.OTDA_mapping_kernel()
81-
ot_mapping_kernel.fit(xs,xt,mu=mu,eta=eta,sigma=sigma,bias=bias,numItermax = 10,verbose=True)
79+
ot_mapping_kernel = ot.da.OTDA_mapping_kernel()
80+
ot_mapping_kernel.fit(
81+
xs, xt, mu=mu, eta=eta, sigma=sigma, bias=bias, numItermax=10, verbose=True)
8282

83-
xst_kernel=ot_mapping_kernel.predict(xs) # use the estimated mapping
84-
xst0_kernel=ot_mapping_kernel.interp() # use barycentric mapping
83+
xst_kernel = ot_mapping_kernel.predict(xs) # use the estimated mapping
84+
xst0_kernel = ot_mapping_kernel.interp() # use barycentric mapping
8585

8686

8787
#%% Plotting the mapped samples
8888

89-
pl.figure(2,(10,7))
89+
pl.figure(2)
9090
pl.clf()
91-
pl.subplot(2,2,1)
92-
pl.scatter(xt[:,0],xt[:,1],c=yt,marker='o',label='Target samples',alpha=.2)
93-
pl.scatter(xst0[:,0],xst0[:,1],c=ys,marker='+',label='Mapped source samples')
91+
pl.subplot(2, 2, 1)
92+
pl.scatter(xt[:, 0], xt[:, 1], c=yt, marker='o',
93+
label='Target samples', alpha=.2)
94+
pl.scatter(xst0[:, 0], xst0[:, 1], c=ys, marker='+',
95+
label='Mapped source samples')
9496
pl.title("Bary. mapping (linear)")
9597
pl.legend(loc=0)
9698

97-
pl.subplot(2,2,2)
98-
pl.scatter(xt[:,0],xt[:,1],c=yt,marker='o',label='Target samples',alpha=.2)
99-
pl.scatter(xst[:,0],xst[:,1],c=ys,marker='+',label='Learned mapping')
99+
pl.subplot(2, 2, 2)
100+
pl.scatter(xt[:, 0], xt[:, 1], c=yt, marker='o',
101+
label='Target samples', alpha=.2)
102+
pl.scatter(xst[:, 0], xst[:, 1], c=ys, marker='+', label='Learned mapping')
100103
pl.title("Estim. mapping (linear)")
101104

102-
pl.subplot(2,2,3)
103-
pl.scatter(xt[:,0],xt[:,1],c=yt,marker='o',label='Target samples',alpha=.2)
104-
pl.scatter(xst0_kernel[:,0],xst0_kernel[:,1],c=ys,marker='+',label='barycentric mapping')
105+
pl.subplot(2, 2, 3)
106+
pl.scatter(xt[:, 0], xt[:, 1], c=yt, marker='o',
107+
label='Target samples', alpha=.2)
108+
pl.scatter(xst0_kernel[:, 0], xst0_kernel[:, 1], c=ys,
109+
marker='+', label='barycentric mapping')
105110
pl.title("Bary. mapping (kernel)")
106111

107-
pl.subplot(2,2,4)
108-
pl.scatter(xt[:,0],xt[:,1],c=yt,marker='o',label='Target samples',alpha=.2)
109-
pl.scatter(xst_kernel[:,0],xst_kernel[:,1],c=ys,marker='+',label='Learned mapping')
112+
pl.subplot(2, 2, 4)
113+
pl.scatter(xt[:, 0], xt[:, 1], c=yt, marker='o',
114+
label='Target samples', alpha=.2)
115+
pl.scatter(xst_kernel[:, 0], xst_kernel[:, 1], c=ys,
116+
marker='+', label='Learned mapping')
110117
pl.title("Estim. mapping (kernel)")
118+
pl.tight_layout()
119+
120+
pl.show()

0 commit comments

Comments
 (0)