|
16 | 16 |
|
17 | 17 | #%% parameters |
18 | 18 |
|
19 | | -n=1000 # nb samples in source and target datasets |
20 | | -nz=0.2 |
| 19 | +n = 1000 # nb samples in source and target datasets |
| 20 | +nz = 0.2 |
21 | 21 |
|
22 | 22 | # generate circle dataset |
23 | | -t=np.random.rand(n)*2*np.pi |
24 | | -ys=np.floor((np.arange(n)*1.0/n*3))+1 |
25 | | -xs=np.concatenate((np.cos(t).reshape((-1,1)),np.sin(t).reshape((-1,1))),1) |
26 | | -xs=xs*ys.reshape(-1,1)+nz*np.random.randn(n,2) |
| 23 | +t = np.random.rand(n) * 2 * np.pi |
| 24 | +ys = np.floor((np.arange(n) * 1.0 / n * 3)) + 1 |
| 25 | +xs = np.concatenate( |
| 26 | + (np.cos(t).reshape((-1, 1)), np.sin(t).reshape((-1, 1))), 1) |
| 27 | +xs = xs * ys.reshape(-1, 1) + nz * np.random.randn(n, 2) |
27 | 28 |
|
28 | | -t=np.random.rand(n)*2*np.pi |
29 | | -yt=np.floor((np.arange(n)*1.0/n*3))+1 |
30 | | -xt=np.concatenate((np.cos(t).reshape((-1,1)),np.sin(t).reshape((-1,1))),1) |
31 | | -xt=xt*yt.reshape(-1,1)+nz*np.random.randn(n,2) |
| 29 | +t = np.random.rand(n) * 2 * np.pi |
| 30 | +yt = np.floor((np.arange(n) * 1.0 / n * 3)) + 1 |
| 31 | +xt = np.concatenate( |
| 32 | + (np.cos(t).reshape((-1, 1)), np.sin(t).reshape((-1, 1))), 1) |
| 33 | +xt = xt * yt.reshape(-1, 1) + nz * np.random.randn(n, 2) |
32 | 34 |
|
33 | | -nbnoise=8 |
| 35 | +nbnoise = 8 |
34 | 36 |
|
35 | | -xs=np.hstack((xs,np.random.randn(n,nbnoise))) |
36 | | -xt=np.hstack((xt,np.random.randn(n,nbnoise))) |
| 37 | +xs = np.hstack((xs, np.random.randn(n, nbnoise))) |
| 38 | +xt = np.hstack((xt, np.random.randn(n, nbnoise))) |
37 | 39 |
|
38 | 40 | #%% plot samples |
39 | | -pl.figure(1,(10,5)) |
| 41 | +pl.figure(1, figsize=(6.4, 3.5)) |
40 | 42 |
|
41 | | -pl.subplot(1,2,1) |
42 | | -pl.scatter(xt[:,0],xt[:,1],c=ys,marker='+',label='Source samples') |
| 43 | +pl.subplot(1, 2, 1) |
| 44 | +pl.scatter(xt[:, 0], xt[:, 1], c=ys, marker='+', label='Source samples') |
43 | 45 | pl.legend(loc=0) |
44 | 46 | pl.title('Discriminant dimensions') |
45 | 47 |
|
46 | | -pl.subplot(1,2,2) |
47 | | -pl.scatter(xt[:,2],xt[:,3],c=ys,marker='+',label='Source samples') |
| 48 | +pl.subplot(1, 2, 2) |
| 49 | +pl.scatter(xt[:, 2], xt[:, 3], c=ys, marker='+', label='Source samples') |
48 | 50 | pl.legend(loc=0) |
49 | 51 | pl.title('Other dimensions') |
50 | | -pl.show() |
| 52 | +pl.tight_layout() |
51 | 53 |
|
52 | 54 | #%% Compute FDA |
53 | | -p=2 |
| 55 | +p = 2 |
54 | 56 |
|
55 | | -Pfda,projfda = fda(xs,ys,p) |
| 57 | +Pfda, projfda = fda(xs, ys, p) |
56 | 58 |
|
57 | 59 | #%% Compute WDA |
58 | | -p=2 |
59 | | -reg=1e0 |
60 | | -k=10 |
61 | | -maxiter=100 |
| 60 | +p = 2 |
| 61 | +reg = 1e0 |
| 62 | +k = 10 |
| 63 | +maxiter = 100 |
62 | 64 |
|
63 | | -Pwda,projwda = wda(xs,ys,p,reg,k,maxiter=maxiter) |
| 65 | +Pwda, projwda = wda(xs, ys, p, reg, k, maxiter=maxiter) |
64 | 66 |
|
65 | 67 | #%% plot samples |
66 | 68 |
|
67 | | -xsp=projfda(xs) |
68 | | -xtp=projfda(xt) |
| 69 | +xsp = projfda(xs) |
| 70 | +xtp = projfda(xt) |
69 | 71 |
|
70 | | -xspw=projwda(xs) |
71 | | -xtpw=projwda(xt) |
| 72 | +xspw = projwda(xs) |
| 73 | +xtpw = projwda(xt) |
72 | 74 |
|
73 | | -pl.figure(1,(10,10)) |
| 75 | +pl.figure(2) |
74 | 76 |
|
75 | | -pl.subplot(2,2,1) |
76 | | -pl.scatter(xsp[:,0],xsp[:,1],c=ys,marker='+',label='Projected samples') |
| 77 | +pl.subplot(2, 2, 1) |
| 78 | +pl.scatter(xsp[:, 0], xsp[:, 1], c=ys, marker='+', label='Projected samples') |
77 | 79 | pl.legend(loc=0) |
78 | 80 | pl.title('Projected training samples FDA') |
79 | 81 |
|
80 | | - |
81 | | -pl.subplot(2,2,2) |
82 | | -pl.scatter(xtp[:,0],xtp[:,1],c=ys,marker='+',label='Projected samples') |
| 82 | +pl.subplot(2, 2, 2) |
| 83 | +pl.scatter(xtp[:, 0], xtp[:, 1], c=ys, marker='+', label='Projected samples') |
83 | 84 | pl.legend(loc=0) |
84 | 85 | pl.title('Projected test samples FDA') |
85 | 86 |
|
86 | | - |
87 | | -pl.subplot(2,2,3) |
88 | | -pl.scatter(xspw[:,0],xspw[:,1],c=ys,marker='+',label='Projected samples') |
| 87 | +pl.subplot(2, 2, 3) |
| 88 | +pl.scatter(xspw[:, 0], xspw[:, 1], c=ys, marker='+', label='Projected samples') |
89 | 89 | pl.legend(loc=0) |
90 | 90 | pl.title('Projected training samples WDA') |
91 | 91 |
|
92 | | - |
93 | | -pl.subplot(2,2,4) |
94 | | -pl.scatter(xtpw[:,0],xtpw[:,1],c=ys,marker='+',label='Projected samples') |
| 92 | +pl.subplot(2, 2, 4) |
| 93 | +pl.scatter(xtpw[:, 0], xtpw[:, 1], c=ys, marker='+', label='Projected samples') |
95 | 94 | pl.legend(loc=0) |
96 | 95 | pl.title('Projected test samples WDA') |
| 96 | +pl.tight_layout() |
| 97 | + |
| 98 | +pl.show() |
0 commit comments