Skip to content

Commit 871302f

Browse files
authored
Merge pull request #13 from agramfort/pimp_it_up
[MRG] flake8 + pimp example figures Fixes #2 Big thanks to @agramfort for the counselling and PEP8 cleanup
2 parents 6088823 + 6ada23e commit 871302f

29 files changed

+1477
-1379
lines changed

.gitignore

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -97,3 +97,6 @@ ENV/
9797

9898
# Rope project settings
9999
.ropeproject
100+
101+
# Mac stuff
102+
.DS_Store

.travis.yml

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,10 @@ before_install:
1616
# command to install dependencies
1717
install:
1818
- pip install -r requirements.txt
19+
- pip install flake8 pytest
1920
- python setup.py install
20-
# command to run tests
21-
script: python test/test_load_module.py -v
21+
# command to run tests + check syntax style
22+
script:
23+
- python test/test_load_module.py -v
24+
- flake8 examples/ ot/ test/
25+
# - py.test ot test

examples/plot_OTDA_2D.py

Lines changed: 58 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -11,110 +11,112 @@
1111
import ot
1212

1313

14-
1514
#%% parameters
1615

17-
n=150 # nb bins
16+
n = 150 # nb bins
1817

19-
xs,ys=ot.datasets.get_data_classif('3gauss',n)
20-
xt,yt=ot.datasets.get_data_classif('3gauss2',n)
18+
xs, ys = ot.datasets.get_data_classif('3gauss', n)
19+
xt, yt = ot.datasets.get_data_classif('3gauss2', n)
2120

22-
a,b = ot.unif(n),ot.unif(n)
21+
a, b = ot.unif(n), ot.unif(n)
2322
# loss matrix
24-
M=ot.dist(xs,xt)
25-
#M/=M.max()
23+
M = ot.dist(xs, xt)
24+
# M/=M.max()
2625

2726
#%% plot samples
2827

2928
pl.figure(1)
30-
31-
pl.subplot(2,2,1)
32-
pl.scatter(xs[:,0],xs[:,1],c=ys,marker='+',label='Source samples')
29+
pl.subplot(2, 2, 1)
30+
pl.scatter(xs[:, 0], xs[:, 1], c=ys, marker='+', label='Source samples')
3331
pl.legend(loc=0)
3432
pl.title('Source distributions')
3533

36-
pl.subplot(2,2,2)
37-
pl.scatter(xt[:,0],xt[:,1],c=yt,marker='o',label='Target samples')
34+
pl.subplot(2, 2, 2)
35+
pl.scatter(xt[:, 0], xt[:, 1], c=yt, marker='o', label='Target samples')
3836
pl.legend(loc=0)
3937
pl.title('target distributions')
4038

4139
pl.figure(2)
42-
pl.imshow(M,interpolation='nearest')
40+
pl.imshow(M, interpolation='nearest')
4341
pl.title('Cost matrix M')
4442

4543

4644
#%% OT estimation
4745

4846
# EMD
49-
G0=ot.emd(a,b,M)
47+
G0 = ot.emd(a, b, M)
5048

5149
# sinkhorn
52-
lambd=1e-1
53-
Gs=ot.sinkhorn(a,b,M,lambd)
50+
lambd = 1e-1
51+
Gs = ot.sinkhorn(a, b, M, lambd)
5452

5553

5654
# Group lasso regularization
57-
reg=1e-1
58-
eta=1e0
59-
Gg=ot.da.sinkhorn_lpl1_mm(a,ys.astype(np.int),b,M,reg,eta)
55+
reg = 1e-1
56+
eta = 1e0
57+
Gg = ot.da.sinkhorn_lpl1_mm(a, ys.astype(np.int), b, M, reg, eta)
6058

6159

6260
#%% visu matrices
6361

6462
pl.figure(3)
6563

66-
pl.subplot(2,3,1)
67-
pl.imshow(G0,interpolation='nearest')
64+
pl.subplot(2, 3, 1)
65+
pl.imshow(G0, interpolation='nearest')
6866
pl.title('OT matrix ')
6967

70-
pl.subplot(2,3,2)
71-
pl.imshow(Gs,interpolation='nearest')
68+
pl.subplot(2, 3, 2)
69+
pl.imshow(Gs, interpolation='nearest')
7270
pl.title('OT matrix Sinkhorn')
7371

74-
pl.subplot(2,3,3)
75-
pl.imshow(Gg,interpolation='nearest')
72+
pl.subplot(2, 3, 3)
73+
pl.imshow(Gg, interpolation='nearest')
7674
pl.title('OT matrix Group lasso')
7775

78-
pl.subplot(2,3,4)
79-
ot.plot.plot2D_samples_mat(xs,xt,G0,c=[.5,.5,1])
80-
pl.scatter(xs[:,0],xs[:,1],c=ys,marker='+',label='Source samples')
81-
pl.scatter(xt[:,0],xt[:,1],c=yt,marker='o',label='Target samples')
76+
pl.subplot(2, 3, 4)
77+
ot.plot.plot2D_samples_mat(xs, xt, G0, c=[.5, .5, 1])
78+
pl.scatter(xs[:, 0], xs[:, 1], c=ys, marker='+', label='Source samples')
79+
pl.scatter(xt[:, 0], xt[:, 1], c=yt, marker='o', label='Target samples')
8280

8381

84-
pl.subplot(2,3,5)
85-
ot.plot.plot2D_samples_mat(xs,xt,Gs,c=[.5,.5,1])
86-
pl.scatter(xs[:,0],xs[:,1],c=ys,marker='+',label='Source samples')
87-
pl.scatter(xt[:,0],xt[:,1],c=yt,marker='o',label='Target samples')
82+
pl.subplot(2, 3, 5)
83+
ot.plot.plot2D_samples_mat(xs, xt, Gs, c=[.5, .5, 1])
84+
pl.scatter(xs[:, 0], xs[:, 1], c=ys, marker='+', label='Source samples')
85+
pl.scatter(xt[:, 0], xt[:, 1], c=yt, marker='o', label='Target samples')
8886

89-
pl.subplot(2,3,6)
90-
ot.plot.plot2D_samples_mat(xs,xt,Gg,c=[.5,.5,1])
91-
pl.scatter(xs[:,0],xs[:,1],c=ys,marker='+',label='Source samples')
92-
pl.scatter(xt[:,0],xt[:,1],c=yt,marker='o',label='Target samples')
87+
pl.subplot(2, 3, 6)
88+
ot.plot.plot2D_samples_mat(xs, xt, Gg, c=[.5, .5, 1])
89+
pl.scatter(xs[:, 0], xs[:, 1], c=ys, marker='+', label='Source samples')
90+
pl.scatter(xt[:, 0], xt[:, 1], c=yt, marker='o', label='Target samples')
91+
pl.tight_layout()
9392

9493
#%% sample interpolation
9594

96-
xst0=n*G0.dot(xt)
97-
xsts=n*Gs.dot(xt)
98-
xstg=n*Gg.dot(xt)
99-
100-
pl.figure(4)
101-
pl.subplot(2,3,1)
102-
95+
xst0 = n * G0.dot(xt)
96+
xsts = n * Gs.dot(xt)
97+
xstg = n * Gg.dot(xt)
10398

104-
pl.scatter(xt[:,0],xt[:,1],c=yt,marker='o',label='Target samples',alpha=0.5)
105-
pl.scatter(xst0[:,0],xst0[:,1],c=ys,marker='+',label='Transp samples',s=30)
99+
pl.figure(4, figsize=(8, 3))
100+
pl.subplot(1, 3, 1)
101+
pl.scatter(xt[:, 0], xt[:, 1], c=yt, marker='o',
102+
label='Target samples', alpha=0.5)
103+
pl.scatter(xst0[:, 0], xst0[:, 1], c=ys,
104+
marker='+', label='Transp samples', s=30)
106105
pl.title('Interp samples')
107106
pl.legend(loc=0)
108107

109-
pl.subplot(2,3,2)
110-
111-
112-
pl.scatter(xt[:,0],xt[:,1],c=yt,marker='o',label='Target samples',alpha=0.5)
113-
pl.scatter(xsts[:,0],xsts[:,1],c=ys,marker='+',label='Transp samples',s=30)
108+
pl.subplot(1, 3, 2)
109+
pl.scatter(xt[:, 0], xt[:, 1], c=yt, marker='o',
110+
label='Target samples', alpha=0.5)
111+
pl.scatter(xsts[:, 0], xsts[:, 1], c=ys,
112+
marker='+', label='Transp samples', s=30)
114113
pl.title('Interp samples Sinkhorn')
115114

116-
pl.subplot(2,3,3)
117-
118-
pl.scatter(xt[:,0],xt[:,1],c=yt,marker='o',label='Target samples',alpha=0.5)
119-
pl.scatter(xstg[:,0],xstg[:,1],c=ys,marker='+',label='Transp samples',s=30)
120-
pl.title('Interp samples Grouplasso')
115+
pl.subplot(1, 3, 3)
116+
pl.scatter(xt[:, 0], xt[:, 1], c=yt, marker='o',
117+
label='Target samples', alpha=0.5)
118+
pl.scatter(xstg[:, 0], xstg[:, 1], c=ys,
119+
marker='+', label='Transp samples', s=30)
120+
pl.title('Interp samples Grouplasso')
121+
pl.tight_layout()
122+
pl.show()

examples/plot_OTDA_classes.py

Lines changed: 65 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -10,103 +10,104 @@
1010
import ot
1111

1212

13-
14-
1513
#%% parameters
1614

17-
n=150 # nb samples in source and target datasets
18-
19-
xs,ys=ot.datasets.get_data_classif('3gauss',n)
20-
xt,yt=ot.datasets.get_data_classif('3gauss2',n)
21-
15+
n = 150 # nb samples in source and target datasets
2216

17+
xs, ys = ot.datasets.get_data_classif('3gauss', n)
18+
xt, yt = ot.datasets.get_data_classif('3gauss2', n)
2319

2420

2521
#%% plot samples
2622

27-
pl.figure(1)
23+
pl.figure(1, figsize=(6.4, 3))
2824

29-
pl.subplot(2,2,1)
30-
pl.scatter(xs[:,0],xs[:,1],c=ys,marker='+',label='Source samples')
25+
pl.subplot(1, 2, 1)
26+
pl.scatter(xs[:, 0], xs[:, 1], c=ys, marker='+', label='Source samples')
3127
pl.legend(loc=0)
3228
pl.title('Source distributions')
3329

34-
pl.subplot(2,2,2)
35-
pl.scatter(xt[:,0],xt[:,1],c=yt,marker='o',label='Target samples')
30+
pl.subplot(1, 2, 2)
31+
pl.scatter(xt[:, 0], xt[:, 1], c=yt, marker='o', label='Target samples')
3632
pl.legend(loc=0)
3733
pl.title('target distributions')
3834

3935

4036
#%% OT estimation
4137

4238
# LP problem
43-
da_emd=ot.da.OTDA() # init class
44-
da_emd.fit(xs,xt) # fit distributions
45-
xst0=da_emd.interp() # interpolation of source samples
46-
39+
da_emd = ot.da.OTDA() # init class
40+
da_emd.fit(xs, xt) # fit distributions
41+
xst0 = da_emd.interp() # interpolation of source samples
4742

4843
# sinkhorn regularization
49-
lambd=1e-1
50-
da_entrop=ot.da.OTDA_sinkhorn()
51-
da_entrop.fit(xs,xt,reg=lambd)
52-
xsts=da_entrop.interp()
44+
lambd = 1e-1
45+
da_entrop = ot.da.OTDA_sinkhorn()
46+
da_entrop.fit(xs, xt, reg=lambd)
47+
xsts = da_entrop.interp()
5348

5449
# non-convex Group lasso regularization
55-
reg=1e-1
56-
eta=1e0
57-
da_lpl1=ot.da.OTDA_lpl1()
58-
da_lpl1.fit(xs,ys,xt,reg=reg,eta=eta)
59-
xstg=da_lpl1.interp()
60-
50+
reg = 1e-1
51+
eta = 1e0
52+
da_lpl1 = ot.da.OTDA_lpl1()
53+
da_lpl1.fit(xs, ys, xt, reg=reg, eta=eta)
54+
xstg = da_lpl1.interp()
6155

6256
# True Group lasso regularization
63-
reg=1e-1
64-
eta=2e0
65-
da_l1l2=ot.da.OTDA_l1l2()
66-
da_l1l2.fit(xs,ys,xt,reg=reg,eta=eta,numItermax=20,verbose=True)
67-
xstgl=da_l1l2.interp()
68-
57+
reg = 1e-1
58+
eta = 2e0
59+
da_l1l2 = ot.da.OTDA_l1l2()
60+
da_l1l2.fit(xs, ys, xt, reg=reg, eta=eta, numItermax=20, verbose=True)
61+
xstgl = da_l1l2.interp()
6962

7063
#%% plot interpolated source samples
71-
pl.figure(4,(15,8))
7264

73-
param_img={'interpolation':'nearest','cmap':'jet'}
65+
param_img = {'interpolation': 'nearest', 'cmap': 'spectral'}
7466

75-
pl.subplot(2,4,1)
76-
pl.imshow(da_emd.G,**param_img)
67+
pl.figure(2, figsize=(8, 4.5))
68+
pl.subplot(2, 4, 1)
69+
pl.imshow(da_emd.G, **param_img)
7770
pl.title('OT matrix')
7871

72+
pl.subplot(2, 4, 2)
73+
pl.imshow(da_entrop.G, **param_img)
74+
pl.title('OT matrix\nsinkhorn')
7975

80-
pl.subplot(2,4,2)
81-
pl.imshow(da_entrop.G,**param_img)
82-
pl.title('OT matrix sinkhorn')
83-
84-
pl.subplot(2,4,3)
85-
pl.imshow(da_lpl1.G,**param_img)
86-
pl.title('OT matrix non-convex Group Lasso')
87-
88-
pl.subplot(2,4,4)
89-
pl.imshow(da_l1l2.G,**param_img)
90-
pl.title('OT matrix Group Lasso')
76+
pl.subplot(2, 4, 3)
77+
pl.imshow(da_lpl1.G, **param_img)
78+
pl.title('OT matrix\nnon-convex Group Lasso')
9179

80+
pl.subplot(2, 4, 4)
81+
pl.imshow(da_l1l2.G, **param_img)
82+
pl.title('OT matrix\nGroup Lasso')
9283

93-
pl.subplot(2,4,5)
94-
pl.scatter(xt[:,0],xt[:,1],c=yt,marker='o',label='Target samples',alpha=0.3)
95-
pl.scatter(xst0[:,0],xst0[:,1],c=ys,marker='+',label='Transp samples',s=30)
84+
pl.subplot(2, 4, 5)
85+
pl.scatter(xt[:, 0], xt[:, 1], c=yt, marker='o',
86+
label='Target samples', alpha=0.3)
87+
pl.scatter(xst0[:, 0], xst0[:, 1], c=ys,
88+
marker='+', label='Transp samples', s=30)
9689
pl.title('Interp samples')
9790
pl.legend(loc=0)
9891

99-
pl.subplot(2,4,6)
100-
pl.scatter(xt[:,0],xt[:,1],c=yt,marker='o',label='Target samples',alpha=0.3)
101-
pl.scatter(xsts[:,0],xsts[:,1],c=ys,marker='+',label='Transp samples',s=30)
102-
pl.title('Interp samples Sinkhorn')
103-
104-
pl.subplot(2,4,7)
105-
pl.scatter(xt[:,0],xt[:,1],c=yt,marker='o',label='Target samples',alpha=0.3)
106-
pl.scatter(xstg[:,0],xstg[:,1],c=ys,marker='+',label='Transp samples',s=30)
107-
pl.title('Interp samples non-convex Group Lasso')
108-
109-
pl.subplot(2,4,8)
110-
pl.scatter(xt[:,0],xt[:,1],c=yt,marker='o',label='Target samples',alpha=0.3)
111-
pl.scatter(xstgl[:,0],xstgl[:,1],c=ys,marker='+',label='Transp samples',s=30)
112-
pl.title('Interp samples Group Lasso')
92+
pl.subplot(2, 4, 6)
93+
pl.scatter(xt[:, 0], xt[:, 1], c=yt, marker='o',
94+
label='Target samples', alpha=0.3)
95+
pl.scatter(xsts[:, 0], xsts[:, 1], c=ys,
96+
marker='+', label='Transp samples', s=30)
97+
pl.title('Interp samples\nSinkhorn')
98+
99+
pl.subplot(2, 4, 7)
100+
pl.scatter(xt[:, 0], xt[:, 1], c=yt, marker='o',
101+
label='Target samples', alpha=0.3)
102+
pl.scatter(xstg[:, 0], xstg[:, 1], c=ys,
103+
marker='+', label='Transp samples', s=30)
104+
pl.title('Interp samples\nnon-convex Group Lasso')
105+
106+
pl.subplot(2, 4, 8)
107+
pl.scatter(xt[:, 0], xt[:, 1], c=yt, marker='o',
108+
label='Target samples', alpha=0.3)
109+
pl.scatter(xstgl[:, 0], xstgl[:, 1], c=ys,
110+
marker='+', label='Transp samples', s=30)
111+
pl.title('Interp samples\nGroup Lasso')
112+
pl.tight_layout()
113+
pl.show()

0 commit comments

Comments
 (0)