Skip to content

Commit faa4744

Browse files
committed
add init WDA
1 parent a1ae72e commit faa4744

File tree

1 file changed

+5
-3
lines changed

1 file changed

+5
-3
lines changed

ot/dr.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -100,7 +100,7 @@ def proj(X):
100100

101101
return Popt, proj
102102

103-
def wda(X,y,p=2,reg=1,k=10,solver = None,maxiter=100,verbose=0):
103+
def wda(X,y,p=2,reg=1,k=10,solver = None,maxiter=100,verbose=0,P0=None):
104104
"""
105105
Wasserstein Discriminant Analysis [11]_
106106
@@ -127,7 +127,9 @@ def wda(X,y,p=2,reg=1,k=10,solver = None,maxiter=100,verbose=0):
127127
Regularization term >0 (entropic regularization)
128128
solver : str, optional
129129
None for steepest decsent or 'TrustRegions' for trust regions algorithm
130-
else shoudl be a pymanopt.sovers
130+
else shoudl be a pymanopt.solvers
131+
P0 : numpy.ndarray (d,p)
132+
Initial starting point for projection
131133
verbose : int, optional
132134
Print information along iterations
133135
@@ -187,7 +189,7 @@ def cost(P):
187189
elif solver in ['tr','TrustRegions']:
188190
solver= TrustRegions(maxiter=maxiter,logverbosity=verbose)
189191

190-
Popt = solver.solve(problem)
192+
Popt = solver.solve(problem,x=P0)
191193

192194
def proj(X):
193195
return (X-mx.reshape((1,-1))).dot(Popt)

0 commit comments

Comments
 (0)