Skip to content

Commit 9e40820

Browse files
committed
firt DA class
1 parent e3b1150 commit 9e40820

File tree

1 file changed

+54
-9
lines changed

1 file changed

+54
-9
lines changed

ot/da.py

Lines changed: 54 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@
99
from .utils import unif,dist
1010

1111

12-
1312
def indices(a, func):
1413
return [i for (i, val) in enumerate(a) if func(val)]
1514

@@ -124,15 +123,20 @@ def sinkhorn_lpl1_mm(a,labels_a, b, M, reg, eta=0.1,numItermax = 10,numInnerIter
124123

125124

126125
class OTDA():
127-
"""Class for optimal transport with domain adaptation"""
126+
"""Class for domain adaptation with optimal transport"""
128127

129-
def __init__(self):
128+
def __init__(self,metric='sqeuclidean'):
129+
""" Class initialization"""
130130
self.xs=0
131131
self.xt=0
132132
self.G=0
133+
self.metric=metric
134+
self.computed=False
133135

134136

135137
def fit(self,xs,xt,ws=None,wt=None):
138+
""" Fit domain adaptation between samples is xs and xt (with optional
139+
weights)"""
136140
self.xs=xs
137141
self.xt=xt
138142

@@ -144,17 +148,58 @@ def fit(self,xs,xt,ws=None,wt=None):
144148
self.ws=ws
145149
self.wt=wt
146150

147-
self.M=dist(xs,xt)
151+
self.M=dist(xs,xt,metric=self.metric)
148152
self.G=emd(ws,wt,self.M)
153+
self.computed=True
149154

150155
def interp(self,direction=1):
151-
"""Barycentric interpolation for the source (1) or target (-1)"""
156+
"""Barycentric interpolation for the source (1) or target (-1)
157+
158+
This Barycentric interpolation solves for each source (resp target)
159+
sample xs (resp xt) the following optimization problem:
160+
161+
.. math::
162+
arg\min_x \sum_i \gamma_{k,i} c(x,x_i^t)
163+
164+
where k is the index of the sample in xs
165+
166+
For the moment only squared euclidean distance is provided but more
167+
metric c can be used in teh future.
168+
169+
"""
170+
if direction>0: # >0 then source to target
171+
G=self.G
172+
w=self.ws.reshape((self.xs.shape[0],1))
173+
x=self.xt
174+
else:
175+
G=self.G.T
176+
w=self.wt.reshape((self.xt.shape[0],1))
177+
x=self.xs
178+
179+
if self.computed:
180+
if self.metric=='sqeuclidean':
181+
return np.dot(G/w,x) # weighted mean
182+
else:
183+
print("Warning, metric not handled yet, using weighted average")
184+
return np.dot(G/w,x) # weighted mean
185+
return None
186+
else:
187+
print("Warning, model not fitted yet, returning None")
188+
return None
189+
152190

153-
if self.G and direction>0:
154-
return (self.G/self.ws).dot(self.xt)
155-
elif self.G and direction<0:
156-
return (self.G.T/self.wt).dot(self.xs)
191+
def predict(x,direction=1):
192+
""" Out of sample mapping using the formulation from Ferradans
157193
194+
"""
195+
if direction>0: # >0 then source to target
196+
G=self.G
197+
w=self.ws.reshape((self.xs.shape[0],1))
198+
x=self.xt
199+
else:
200+
G=self.G.T
201+
w=self.wt.reshape((self.xt.shape[0],1))
202+
x=self.xs
158203

159204

160205

0 commit comments

Comments
 (0)