Skip to content

Commit 104627b

Browse files
committed
commit doc
1 parent 9e40820 commit 104627b

File tree

2 files changed

+38
-8
lines changed

2 files changed

+38
-8
lines changed

docs/source/conf.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,9 +14,14 @@
1414

1515
import sys
1616
import os
17-
from unittest.mock import MagicMock
17+
try:
18+
from unittest.mock import MagicMock
19+
except ImportError:
20+
from mock import MagicMock
1821

1922
sys.path.insert(0, os.path.abspath("../.."))
23+
sys.setrecursionlimit(1500)
24+
2025

2126

2227
class Mock(MagicMock):

ot/da.py

Lines changed: 32 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -188,21 +188,46 @@ def interp(self,direction=1):
188188
return None
189189

190190

191-
def predict(x,direction=1):
191+
def predict(self,x,direction=1):
192192
""" Out of sample mapping using the formulation from Ferradans
193193
194+
It basically find the source sample the nearset to the nex sample and
195+
apply the difference to the displaced source sample.
196+
194197
"""
195198
if direction>0: # >0 then source to target
196-
G=self.G
197-
w=self.ws.reshape((self.xs.shape[0],1))
198-
x=self.xt
199+
xf=self.xt
200+
x0=self.xs
199201
else:
200-
G=self.G.T
201-
w=self.wt.reshape((self.xt.shape[0],1))
202-
x=self.xs
202+
xf=self.xs
203+
x0=self.xt
204+
205+
D0=dist(x,x0) # dist netween new samples an source
206+
idx=np.argmin(D0,1) # closest one
207+
xf=self.interp(direction)# interp the source samples
208+
return xf[idx,:]+x-x0[idx,:] # aply the delta to the interpolation
203209

204210

211+
212+
class OTDA_sinkhorn(OTDA):
213+
214+
def fit(self,xs,xt,ws=None,wt=None,reg=1,**kwargs):
215+
""" Fit domain adaptation between samples is xs and xt (with optional
216+
weights)"""
217+
self.xs=xs
218+
self.xt=xt
205219

220+
if wt is None:
221+
wt=unif(xt.shape[0])
222+
if ws is None:
223+
ws=unif(xs.shape[0])
224+
225+
self.ws=ws
226+
self.wt=wt
227+
228+
self.M=dist(xs,xt,metric=self.metric)
229+
self.G=sinkhorn(ws,wt,self.M,reg,**kwargs)
230+
self.computed=True
206231

207232

208233

0 commit comments

Comments
 (0)