@@ -100,7 +100,7 @@ def proj(X):
100100
101101 return Popt , proj
102102
103- def wda (X ,y ,p = 2 ,reg = 1 ,k = 10 ,solver = None ,maxiter = 100 ,verbose = 0 ):
103+ def wda (X ,y ,p = 2 ,reg = 1 ,k = 10 ,solver = None ,maxiter = 100 ,verbose = 0 , P0 = None ):
104104 """
105105 Wasserstein Discriminant Analysis [11]_
106106
@@ -127,7 +127,9 @@ def wda(X,y,p=2,reg=1,k=10,solver = None,maxiter=100,verbose=0):
127127 Regularization term >0 (entropic regularization)
128128 solver : str, optional
129129 None for steepest decsent or 'TrustRegions' for trust regions algorithm
130- else shoudl be a pymanopt.sovers
130+ else shoudl be a pymanopt.solvers
131+ P0 : numpy.ndarray (d,p)
132+ Initial starting point for projection
131133 verbose : int, optional
132134 Print information along iterations
133135
@@ -187,7 +189,7 @@ def cost(P):
187189 elif solver in ['tr' ,'TrustRegions' ]:
188190 solver = TrustRegions (maxiter = maxiter ,logverbosity = verbose )
189191
190- Popt = solver .solve (problem )
192+ Popt = solver .solve (problem , x = P0 )
191193
192194 def proj (X ):
193195 return (X - mx .reshape ((1 ,- 1 ))).dot (Popt )
0 commit comments