@@ -109,7 +109,7 @@ def proj(X):
109109 return Popt , proj
110110
111111
112- def wda (X , y , p = 2 , reg = 1 , k = 10 , solver = None , maxiter = 100 , verbose = 0 , P0 = None ):
112+ def wda (X , y , p = 2 , reg = 1 , k = 10 , solver = None , maxiter = 100 , verbose = 0 , P0 = None , normalize = False ):
113113 r"""
114114 Wasserstein Discriminant Analysis [11]_
115115
@@ -139,6 +139,8 @@ def wda(X, y, p=2, reg=1, k=10, solver=None, maxiter=100, verbose=0, P0=None):
139139 else should be a pymanopt.solvers
140140 P0 : ndarray, shape (d, p)
141141 Initial starting point for projection.
142+ normalize : bool, optional
143+         Normalize the Wasserstein distance by the average distance on P0 (default: False)
142144 verbose : int, optional
143145 Print information along iterations.
144146
@@ -164,6 +166,18 @@ def wda(X, y, p=2, reg=1, k=10, solver=None, maxiter=100, verbose=0, P0=None):
164166 # compute uniform weighs
165167 wc = [np .ones ((x .shape [0 ]), dtype = np .float32 ) / x .shape [0 ] for x in xc ]
166168
169+ # pre-compute reg_c,c'
170+ if P0 is not None and normalize :
171+ regmean = np .zeros ((len (xc ), len (xc )))
172+ for i , xi in enumerate (xc ):
173+ xi = np .dot (xi , P0 )
174+ for j , xj in enumerate (xc [i :]):
175+ xj = np .dot (xj , P0 )
176+ M = dist (xi , xj )
177+ regmean [i , j ] = np .sum (M ) / (len (xi ) * len (xj ))
178+ else :
179+ regmean = np .ones ((len (xc ), len (xc )))
180+
167181 def cost (P ):
168182 # wda loss
169183 loss_b = 0
@@ -174,7 +188,7 @@ def cost(P):
174188 for j , xj in enumerate (xc [i :]):
175189 xj = np .dot (xj , P )
176190 M = dist (xi , xj )
177- G = sinkhorn (wc [i ], wc [j + i ], M , reg , k )
191+ G = sinkhorn (wc [i ], wc [j + i ], M , reg * regmean [ i , j ] , k )
178192 if j == 0 :
179193 loss_w += np .sum (G * M )
180194 else :
0 commit comments