|
18 | 18 |
|
19 | 19 | n=1000 # nb samples in source and target datasets |
20 | 20 | nz=0.2 |
21 | | -xs,ys=ot.datasets.get_data_classif('3gauss',n,nz) |
22 | | -xt,yt=ot.datasets.get_data_classif('3gauss',n,nz) |
| 21 | + |
| 22 | +# generate circle dataset |
| 23 | +t=np.random.rand(n)*2*np.pi |
| 24 | +ys=np.floor((np.arange(n)*1.0/n*3))+1 |
| 25 | +xs=np.concatenate((np.cos(t).reshape((-1,1)),np.sin(t).reshape((-1,1))),1) |
| 26 | +xs=xs*ys.reshape(-1,1)+nz*np.random.randn(n,2) |
| 27 | + |
| 28 | +t=np.random.rand(n)*2*np.pi |
| 29 | +yt=np.floor((np.arange(n)*1.0/n*3))+1 |
| 30 | +xt=np.concatenate((np.cos(t).reshape((-1,1)),np.sin(t).reshape((-1,1))),1) |
| 31 | +xt=xt*yt.reshape(-1,1)+nz*np.random.randn(n,2) |
23 | 32 |
|
24 | 33 | nbnoise=8 |
25 | 34 |
|
26 | 35 | xs=np.hstack((xs,np.random.randn(n,nbnoise))) |
27 | 36 | xt=np.hstack((xt,np.random.randn(n,nbnoise))) |
28 | 37 |
|
29 | 38 | #%% plot samples |
| 39 | +pl.figure(1,(10,5)) |
30 | 40 |
|
31 | | -pl.figure(1) |
32 | | - |
33 | | - |
| 41 | +pl.subplot(1,2,1) |
34 | 42 | pl.scatter(xt[:,0],xt[:,1],c=ys,marker='+',label='Source samples') |
35 | 43 | pl.legend(loc=0) |
36 | 44 | pl.title('Discriminant dimensions') |
37 | 45 |
|
| 46 | +pl.subplot(1,2,2) |
| 47 | +pl.scatter(xt[:,2],xt[:,3],c=ys,marker='+',label='Source samples') |
| 48 | +pl.legend(loc=0) |
| 49 | +pl.title('Other dimensions') |
| 50 | +pl.show() |
38 | 51 |
|
39 | | -#%% Comlpute FDA |
| 52 | +#%% Compute FDA |
40 | 53 | p=2 |
41 | 54 |
|
42 | 55 | Pfda,projfda = fda(xs,ys,p) |
43 | 56 |
|
44 | 57 | #%% Compute WDA |
45 | 58 | p=2 |
46 | | -reg=1 |
| 59 | +reg=1e-1 |
47 | 60 | k=10 |
48 | 61 | maxiter=100 |
49 | 62 |
|
50 | | -P,proj = wda(xs,ys,p,reg,k,maxiter=maxiter) |
| 63 | +Pwda,projwda = wda(xs,ys,p,reg,k,maxiter=maxiter) |
51 | 64 |
|
52 | 65 | #%% plot samples |
53 | 66 |
|
54 | 67 | xsp=projfda(xs) |
55 | 68 | xtp=projfda(xt) |
56 | 69 |
|
57 | | -pl.figure(1,(10,5)) |
| 70 | +xspw=projwda(xs) |
| 71 | +xtpw=projwda(xt) |
58 | 72 |
|
59 | | -pl.subplot(1,2,1) |
| 73 | +pl.figure(1,(10,10)) |
| 74 | + |
| 75 | +pl.subplot(2,2,1) |
60 | 76 | pl.scatter(xsp[:,0],xsp[:,1],c=ys,marker='+',label='Projected samples') |
61 | 77 | pl.legend(loc=0) |
62 | | -pl.title('Projected training samples') |
| 78 | +pl.title('Projected training samples FDA') |
63 | 79 |
|
64 | 80 |
|
65 | | -pl.subplot(1,2,2) |
| 81 | +pl.subplot(2,2,2) |
66 | 82 | pl.scatter(xtp[:,0],xtp[:,1],c=ys,marker='+',label='Projected samples') |
67 | 83 | pl.legend(loc=0) |
68 | | -pl.title('Projected test samples') |
| 84 | +pl.title('Projected test samples FDA') |
| 85 | + |
| 86 | + |
| 87 | +pl.subplot(2,2,3) |
| 88 | +pl.scatter(xspw[:,0],xspw[:,1],c=ys,marker='+',label='Projected samples') |
| 89 | +pl.legend(loc=0) |
| 90 | +pl.title('Projected training samples WDA') |
| 91 | + |
| 92 | + |
| 93 | +pl.subplot(2,2,4) |
| 94 | +pl.scatter(xtpw[:,0],xtpw[:,1],c=ys,marker='+',label='Projected samples') |
| 95 | +pl.legend(loc=0) |
| 96 | +pl.title('Projected test samples WDA') |
0 commit comments