Skip to content

Commit d3d8689

Browse files
committed
first class for DA
1 parent 996c668 commit d3d8689

File tree

1 file changed

+44
-1
lines changed

1 file changed

+44
-1
lines changed

ot/da.py

Lines changed: 44 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@
55

66
import numpy as np
77
from .bregman import sinkhorn
8+
from .lp import emd
9+
from .utils import unif,dist
810

911

1012

@@ -118,4 +120,45 @@ def sinkhorn_lpl1_mm(a,labels_a, b, M, reg, eta=0.1,numItermax = 10,numInnerIter
118120
W[indices_labels[0],t]=np.min(all_maj)
119121

120122
return transp
121-
123+
124+
125+
126+
class OTDA():
127+
"""Class for optimal transport with domain adaptation"""
128+
129+
def __init__(self):
130+
self.xs=0
131+
self.xt=0
132+
self.G=0
133+
134+
135+
def fit(self,xs,xt,ws=None,wt=None):
136+
self.xs=xs
137+
self.xt=xt
138+
139+
if wt is None:
140+
wt=unif(xt.shape[0])
141+
if ws is None:
142+
ws=unif(xs.shape[0])
143+
144+
self.ws=ws
145+
self.wt=wt
146+
147+
self.M=dist(xs,xt)
148+
self.G=emd(ws,wt,self.M)
149+
150+
def interp(self,direction=1):
151+
"""Barycentric interpolation for the source (1) or target (-1)"""
152+
153+
if self.G and direction>0:
154+
return (self.G/self.ws).dot(self.xt)
155+
elif self.G and direction<0:
156+
return (self.G.T/self.wt).dot(self.xs)
157+
158+
159+
160+
161+
162+
163+
164+

0 commit comments

Comments
 (0)