Skip to content

Commit 6d7fd7e

Browse files
committed
more
1 parent 7ad4725 commit 6d7fd7e

File tree

2 files changed

+46
-42
lines changed

2 files changed

+46
-42
lines changed

examples/plot_WDA.py

Lines changed: 44 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -16,81 +16,83 @@
1616

1717
#%% parameters
1818

19-
n=1000 # nb samples in source and target datasets
20-
nz=0.2
19+
n = 1000 # nb samples in source and target datasets
20+
nz = 0.2
2121

2222
# generate circle dataset
23-
t=np.random.rand(n)*2*np.pi
24-
ys=np.floor((np.arange(n)*1.0/n*3))+1
25-
xs=np.concatenate((np.cos(t).reshape((-1,1)),np.sin(t).reshape((-1,1))),1)
26-
xs=xs*ys.reshape(-1,1)+nz*np.random.randn(n,2)
23+
t = np.random.rand(n) * 2 * np.pi
24+
ys = np.floor((np.arange(n) * 1.0 / n * 3)) + 1
25+
xs = np.concatenate(
26+
(np.cos(t).reshape((-1, 1)), np.sin(t).reshape((-1, 1))), 1)
27+
xs = xs * ys.reshape(-1, 1) + nz * np.random.randn(n, 2)
2728

28-
t=np.random.rand(n)*2*np.pi
29-
yt=np.floor((np.arange(n)*1.0/n*3))+1
30-
xt=np.concatenate((np.cos(t).reshape((-1,1)),np.sin(t).reshape((-1,1))),1)
31-
xt=xt*yt.reshape(-1,1)+nz*np.random.randn(n,2)
29+
t = np.random.rand(n) * 2 * np.pi
30+
yt = np.floor((np.arange(n) * 1.0 / n * 3)) + 1
31+
xt = np.concatenate(
32+
(np.cos(t).reshape((-1, 1)), np.sin(t).reshape((-1, 1))), 1)
33+
xt = xt * yt.reshape(-1, 1) + nz * np.random.randn(n, 2)
3234

33-
nbnoise=8
35+
nbnoise = 8
3436

35-
xs=np.hstack((xs,np.random.randn(n,nbnoise)))
36-
xt=np.hstack((xt,np.random.randn(n,nbnoise)))
37+
xs = np.hstack((xs, np.random.randn(n, nbnoise)))
38+
xt = np.hstack((xt, np.random.randn(n, nbnoise)))
3739

3840
#%% plot samples
39-
pl.figure(1,(10,5))
41+
pl.figure(1, figsize=(6.4, 3.5))
4042

41-
pl.subplot(1,2,1)
42-
pl.scatter(xt[:,0],xt[:,1],c=ys,marker='+',label='Source samples')
43+
pl.subplot(1, 2, 1)
44+
pl.scatter(xt[:, 0], xt[:, 1], c=ys, marker='+', label='Source samples')
4345
pl.legend(loc=0)
4446
pl.title('Discriminant dimensions')
4547

46-
pl.subplot(1,2,2)
47-
pl.scatter(xt[:,2],xt[:,3],c=ys,marker='+',label='Source samples')
48+
pl.subplot(1, 2, 2)
49+
pl.scatter(xt[:, 2], xt[:, 3], c=ys, marker='+', label='Source samples')
4850
pl.legend(loc=0)
4951
pl.title('Other dimensions')
50-
pl.show()
52+
pl.tight_layout()
5153

5254
#%% Compute FDA
53-
p=2
55+
p = 2
5456

55-
Pfda,projfda = fda(xs,ys,p)
57+
Pfda, projfda = fda(xs, ys, p)
5658

5759
#%% Compute WDA
58-
p=2
59-
reg=1e0
60-
k=10
61-
maxiter=100
60+
p = 2
61+
reg = 1e0
62+
k = 10
63+
maxiter = 100
6264

63-
Pwda,projwda = wda(xs,ys,p,reg,k,maxiter=maxiter)
65+
Pwda, projwda = wda(xs, ys, p, reg, k, maxiter=maxiter)
6466

6567
#%% plot samples
6668

67-
xsp=projfda(xs)
68-
xtp=projfda(xt)
69+
xsp = projfda(xs)
70+
xtp = projfda(xt)
6971

70-
xspw=projwda(xs)
71-
xtpw=projwda(xt)
72+
xspw = projwda(xs)
73+
xtpw = projwda(xt)
7274

73-
pl.figure(1,(10,10))
75+
pl.figure(2)
7476

75-
pl.subplot(2,2,1)
76-
pl.scatter(xsp[:,0],xsp[:,1],c=ys,marker='+',label='Projected samples')
77+
pl.subplot(2, 2, 1)
78+
pl.scatter(xsp[:, 0], xsp[:, 1], c=ys, marker='+', label='Projected samples')
7779
pl.legend(loc=0)
7880
pl.title('Projected training samples FDA')
7981

80-
81-
pl.subplot(2,2,2)
82-
pl.scatter(xtp[:,0],xtp[:,1],c=ys,marker='+',label='Projected samples')
82+
pl.subplot(2, 2, 2)
83+
pl.scatter(xtp[:, 0], xtp[:, 1], c=ys, marker='+', label='Projected samples')
8384
pl.legend(loc=0)
8485
pl.title('Projected test samples FDA')
8586

86-
87-
pl.subplot(2,2,3)
88-
pl.scatter(xspw[:,0],xspw[:,1],c=ys,marker='+',label='Projected samples')
87+
pl.subplot(2, 2, 3)
88+
pl.scatter(xspw[:, 0], xspw[:, 1], c=ys, marker='+', label='Projected samples')
8989
pl.legend(loc=0)
9090
pl.title('Projected training samples WDA')
9191

92-
93-
pl.subplot(2,2,4)
94-
pl.scatter(xtpw[:,0],xtpw[:,1],c=ys,marker='+',label='Projected samples')
92+
pl.subplot(2, 2, 4)
93+
pl.scatter(xtpw[:, 0], xtpw[:, 1], c=ys, marker='+', label='Projected samples')
9594
pl.legend(loc=0)
9695
pl.title('Projected test samples WDA')
96+
pl.tight_layout()
97+
98+
pl.show()

requirements.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,3 +3,5 @@ scipy
33
cython
44
matplotlib
55
sphinx-gallery
6+
autograd
7+
pymanopt

0 commit comments

Comments
 (0)