77from pymanopt .manifolds import Stiefel
88from pymanopt import Problem
99from pymanopt .solvers import SteepestDescent , TrustRegions
10+ import scipy .linalg as la
1011
1112def dist (x1 ,x2 ):
1213 """ Compute squared euclidean distance between samples (autograd)
@@ -32,9 +33,73 @@ def split_classes(X,y):
3233 """
3334 lstsclass = np .unique (y )
3435 return [X [y == i ,:].astype (np .float32 ) for i in lstsclass ]
36+
37+
38+ def fda (X ,y ,p = 2 ,reg = 1e-16 ):
39+ """
40+ Fisher Discriminant Analysis
41+
3542
43+ Parameters
44+ ----------
45+ X : numpy.ndarray (n,d)
46+ Training samples
47+ y : np.ndarray (n,)
48+ labels for training samples
49+ p : int, optional
50+ size of dimensionnality reduction
51+ reg : float, optional
52+ Regularization term >0 (ridge regularization)
3653
3754
55+ Returns
56+ -------
57+ P : (d x p) ndarray
58+ Optimal transportation matrix for the given parameters
59+ proj : fun
60+ projection function including mean centering
61+
62+
63+ """
64+
65+ mx = np .mean (X )
66+ X -= mx .reshape ((1 ,- 1 ))
67+
68+ # data split between classes
69+ d = X .shape [1 ]
70+ xc = split_classes (X ,y )
71+ nc = len (xc )
72+
73+ p = min (nc - 1 ,p )
74+
75+ Cw = 0
76+ for x in xc :
77+ Cw += np .cov (x ,rowvar = False )
78+ Cw /= nc
79+
80+ mxc = np .zeros ((d ,nc ))
81+
82+ for i in range (nc ):
83+ mxc [:,i ]= np .mean (xc [i ])
84+
85+ mx0 = np .mean (mxc ,1 )
86+ Cb = 0
87+ for i in range (nc ):
88+ Cb += (mxc [:,i ]- mx0 ).reshape ((- 1 ,1 ))* (mxc [:,i ]- mx0 ).reshape ((1 ,- 1 ))
89+
90+ w ,V = la .eig (Cb ,Cw + reg * np .eye (d ))
91+
92+ idx = np .argsort (w .real )
93+
94+ Popt = V [:,idx [- p :]]
95+
96+
97+
98+ def proj (X ):
99+ return (X - mx .reshape ((1 ,- 1 ))).dot (Popt )
100+
101+ return Popt , proj
102+
38103def wda (X ,y ,p = 2 ,reg = 1 ,k = 10 ,solver = None ,maxiter = 100 ,verbose = 0 ):
39104 """
40105 Wasserstein Discriminant Analysis [11]_
@@ -73,16 +138,13 @@ def wda(X,y,p=2,reg=1,k=10,solver = None,maxiter=100,verbose=0):
73138 P : (d x p) ndarray
74139 Optimal transportation matrix for the given parameters
75140 proj : fun
76- projectiuon function including mean centering
141+ projection function including mean centering
77142
78143
79144 References
80145 ----------
81146
82147 .. [11] Flamary, R., Cuturi, M., Courty, N., & Rakotomamonjy, A. (2016). Wasserstein Discriminant Analysis. arXiv preprint arXiv:1608.08063.
83-
84-
85-
86148
87149 """
88150
@@ -131,3 +193,6 @@ def proj(X):
131193 return (X - mx .reshape ((1 ,- 1 ))).dot (Popt )
132194
133195 return Popt , proj
196+
197+
198+
0 commit comments