@@ -188,21 +188,46 @@ def interp(self,direction=1):
188188 return None
189189
190190
191- def predict (x ,direction = 1 ):
191+ def predict (self , x ,direction = 1 ):
192192 """ Out of sample mapping using the formulation from Ferradans
193193
194+ It basically find the source sample the nearset to the nex sample and
195+ apply the difference to the displaced source sample.
196+
194197 """
195198 if direction > 0 : # >0 then source to target
196- G = self .G
197- w = self .ws .reshape ((self .xs .shape [0 ],1 ))
198- x = self .xt
199+ xf = self .xt
200+ x0 = self .xs
199201 else :
200- G = self .G .T
201- w = self .wt .reshape ((self .xt .shape [0 ],1 ))
202- x = self .xs
202+ xf = self .xs
203+ x0 = self .xt
204+
205+ D0 = dist (x ,x0 ) # dist netween new samples an source
206+ idx = np .argmin (D0 ,1 ) # closest one
207+ xf = self .interp (direction )# interp the source samples
208+ return xf [idx ,:]+ x - x0 [idx ,:] # aply the delta to the interpolation
203209
204210
211+
212+ class OTDA_sinkhorn (OTDA ):
213+
214+ def fit (self ,xs ,xt ,ws = None ,wt = None ,reg = 1 ,** kwargs ):
215+ """ Fit domain adaptation between samples is xs and xt (with optional
216+ weights)"""
217+ self .xs = xs
218+ self .xt = xt
205219
220+ if wt is None :
221+ wt = unif (xt .shape [0 ])
222+ if ws is None :
223+ ws = unif (xs .shape [0 ])
224+
225+ self .ws = ws
226+ self .wt = wt
227+
228+ self .M = dist (xs ,xt ,metric = self .metric )
229+ self .G = sinkhorn (ws ,wt ,self .M ,reg ,** kwargs )
230+ self .computed = True
206231
207232
208233
0 commit comments