Skip to content

Commit fdceacb

Browse files
committed
add classes for entropic and group lasso regularization
1 parent 104627b commit fdceacb

File tree

1 file changed

+23
-2
lines changed

1 file changed

+23
-2
lines changed

ot/da.py

Lines changed: 23 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -210,8 +210,8 @@ def predict(self,x,direction=1):
210210

211211

212212
class OTDA_sinkhorn(OTDA):
213-
214-
def fit(self,xs,xt,ws=None,wt=None,reg=1,**kwargs):
213+
"""Class for domain adaptation with optimal transport with entropic regularization"""
214+
def fit(self,xs,xt,reg=1,ws=None,wt=None,**kwargs):
215215
""" Fit domain adaptation between samples is xs and xt (with optional
216216
weights)"""
217217
self.xs=xs
@@ -230,5 +230,26 @@ def fit(self,xs,xt,ws=None,wt=None,reg=1,**kwargs):
230230
self.computed=True
231231

232232

233+
class OTDA_lpl1(OTDA):
234+
"""Class for domain adaptation with optimal transport with entropic an group regularization"""
235+
233236

237+
def fit(self,xs,ys,xt,reg=1,eta=1,ws=None,wt=None,**kwargs):
238+
""" Fit domain adaptation between samples is xs and xt (with optional
239+
weights)"""
240+
self.xs=xs
241+
self.xt=xt
242+
243+
if wt is None:
244+
wt=unif(xt.shape[0])
245+
if ws is None:
246+
ws=unif(xs.shape[0])
247+
248+
self.ws=ws
249+
self.wt=wt
250+
251+
self.M=dist(xs,xt,metric=self.metric)
252+
self.G=sinkhorn_lpl1_mm(ws,ys,wt,self.M,reg,eta,**kwargs)
253+
self.computed=True
254+
234255

0 commit comments

Comments
 (0)