99from .utils import unif ,dist
1010
1111
12-
1312def indices (a , func ):
1413 return [i for (i , val ) in enumerate (a ) if func (val )]
1514
@@ -124,15 +123,20 @@ def sinkhorn_lpl1_mm(a,labels_a, b, M, reg, eta=0.1,numItermax = 10,numInnerIter
124123
125124
126125class OTDA ():
127- """Class for optimal transport with domain adaptation """
126+ """Class for domain adaptation with optimal transport """
128127
129- def __init__ (self ):
128+ def __init__ (self ,metric = 'sqeuclidean' ):
129+ """ Class initialization"""
130130 self .xs = 0
131131 self .xt = 0
132132 self .G = 0
133+ self .metric = metric
134+ self .computed = False
133135
134136
135137 def fit (self ,xs ,xt ,ws = None ,wt = None ):
138+ """ Fit domain adaptation between samples is xs and xt (with optional
139+ weights)"""
136140 self .xs = xs
137141 self .xt = xt
138142
@@ -144,17 +148,58 @@ def fit(self,xs,xt,ws=None,wt=None):
144148 self .ws = ws
145149 self .wt = wt
146150
147- self .M = dist (xs ,xt )
151+ self .M = dist (xs ,xt , metric = self . metric )
148152 self .G = emd (ws ,wt ,self .M )
153+ self .computed = True
149154
150155 def interp (self ,direction = 1 ):
151- """Barycentric interpolation for the source (1) or target (-1)"""
156+ """Barycentric interpolation for the source (1) or target (-1)
157+
158+ This Barycentric interpolation solves for each source (resp target)
159+ sample xs (resp xt) the following optimization problem:
160+
161+ .. math::
162+ arg\min_x \sum_i \gamma_{k,i} c(x,x_i^t)
163+
164+ where k is the index of the sample in xs
165+
166+ For the moment only squared euclidean distance is provided but more
167+ metric c can be used in teh future.
168+
169+ """
170+ if direction > 0 : # >0 then source to target
171+ G = self .G
172+ w = self .ws .reshape ((self .xs .shape [0 ],1 ))
173+ x = self .xt
174+ else :
175+ G = self .G .T
176+ w = self .wt .reshape ((self .xt .shape [0 ],1 ))
177+ x = self .xs
178+
179+ if self .computed :
180+ if self .metric == 'sqeuclidean' :
181+ return np .dot (G / w ,x ) # weighted mean
182+ else :
183+ print ("Warning, metric not handled yet, using weighted average" )
184+ return np .dot (G / w ,x ) # weighted mean
185+ return None
186+ else :
187+ print ("Warning, model not fitted yet, returning None" )
188+ return None
189+
152190
153- if self .G and direction > 0 :
154- return (self .G / self .ws ).dot (self .xt )
155- elif self .G and direction < 0 :
156- return (self .G .T / self .wt ).dot (self .xs )
191+ def predict (x ,direction = 1 ):
192+ """ Out of sample mapping using the formulation from Ferradans
157193
194+ """
195+ if direction > 0 : # >0 then source to target
196+ G = self .G
197+ w = self .ws .reshape ((self .xs .shape [0 ],1 ))
198+ x = self .xt
199+ else :
200+ G = self .G .T
201+ w = self .wt .reshape ((self .xt .shape [0 ],1 ))
202+ x = self .xs
158203
159204
160205
0 commit comments