@@ -193,30 +193,30 @@ def sinkhorn_l1l2_gl(a,labels_a, b, M, reg, eta=0.1,numItermax = 10,numInnerIter
193193
194194 """
195195 lstlab = np .unique (labels_a )
196-
196+
197197 def f (G ):
198198 res = 0
199199 for i in range (G .shape [1 ]):
200200 for lab in lstlab :
201201 temp = G [labels_a == lab ,i ]
202- res += np .linalg .norm (temp )
202+ res += np .linalg .norm (temp )
203203 return res
204-
204+
205205 def df (G ):
206- W = np .zeros (G .shape )
206+ W = np .zeros (G .shape )
207207 for i in range (G .shape [1 ]):
208208 for lab in lstlab :
209209 temp = G [labels_a == lab ,i ]
210210 n = np .linalg .norm (temp )
211211 if n :
212- W [labels_a == lab ,i ]= temp / n
213- return W
212+ W [labels_a == lab ,i ]= temp / n
213+ return W
214+
214215
215-
216216 return gcg (a ,b ,M ,reg ,eta ,f ,df ,G0 = None ,numItermax = numItermax ,numInnerItermax = numInnerItermax , stopThr = stopInnerThr ,verbose = verbose ,log = log )
217-
218-
219-
217+
218+
219+
220220def joint_OT_mapping_linear (xs ,xt ,mu = 1 ,eta = 0.001 ,bias = False ,verbose = False ,verbose2 = False ,numItermax = 100 ,numInnerItermax = 10 ,stopInnerThr = 1e-6 ,stopThr = 1e-5 ,log = False ,** kwargs ):
221221 """Joint OT and linear mapping estimation as proposed in [8]
222222
@@ -606,7 +606,7 @@ def __init__(self,metric='sqeuclidean'):
606606 self .computed = False
607607
608608
609- def fit (self ,xs ,xt ,ws = None ,wt = None ):
609+ def fit (self ,xs ,xt ,ws = None ,wt = None , norm = None ):
610610 """ Fit domain adaptation between samples is xs and xt (with optional weights)"""
611611 self .xs = xs
612612 self .xt = xt
@@ -620,6 +620,7 @@ def fit(self,xs,xt,ws=None,wt=None):
620620 self .wt = wt
621621
622622 self .M = dist (xs ,xt ,metric = self .metric )
623+ self .normalize ()
623624 self .G = emd (ws ,wt ,self .M )
624625 self .computed = True
625626
@@ -684,12 +685,25 @@ def predict(self,x,direction=1):
684685 xf = self .interp (direction )# interp the source samples
685686 return xf [idx ,:]+ x - x0 [idx ,:] # aply the delta to the interpolation
686687
688+ def normalizeM (self , norm ):
689+ """
690+ It may help to normalize the cost matrix self.M if there are numerical
691+ errors during the sinkhorn based algorithms.
692+ """
693+ if norm == "median" :
694+ self .M /= float (np .median (self .M ))
695+ elif norm == "max" :
696+ self .M /= float (np .max (self .M ))
697+ elif norm == "log" :
698+ self .M = np .log (1 + self .M )
699+ elif norm == "loglog" :
700+ self .M = np .log (1 + np .log (1 + self .M ))
687701
688702
689703class OTDA_sinkhorn (OTDA ):
690704 """Class for domain adaptation with optimal transport with entropic regularization"""
691705
692- def fit (self ,xs ,xt ,reg = 1 ,ws = None ,wt = None ,** kwargs ):
706+ def fit (self ,xs ,xt ,reg = 1 ,ws = None ,wt = None ,norm = None , ** kwargs ):
693707 """ Fit regularized domain adaptation between samples is xs and xt (with optional weights)"""
694708 self .xs = xs
695709 self .xt = xt
@@ -703,6 +717,7 @@ def fit(self,xs,xt,reg=1,ws=None,wt=None,**kwargs):
703717 self .wt = wt
704718
705719 self .M = dist (xs ,xt ,metric = self .metric )
720+ self .normalizeM (norm )
706721 self .G = sinkhorn (ws ,wt ,self .M ,reg ,** kwargs )
707722 self .computed = True
708723
@@ -711,7 +726,7 @@ class OTDA_lpl1(OTDA):
711726 """Class for domain adaptation with optimal transport with entropic and group regularization"""
712727
713728
714- def fit (self ,xs ,ys ,xt ,reg = 1 ,eta = 1 ,ws = None ,wt = None ,** kwargs ):
729+ def fit (self ,xs ,ys ,xt ,reg = 1 ,eta = 1 ,ws = None ,wt = None ,norm = None , ** kwargs ):
715730 """ Fit regularized domain adaptation between samples is xs and xt (with optional weights), See ot.da.sinkhorn_lpl1_mm for fit parameters"""
716731 self .xs = xs
717732 self .xt = xt
@@ -725,14 +740,15 @@ def fit(self,xs,ys,xt,reg=1,eta=1,ws=None,wt=None,**kwargs):
725740 self .wt = wt
726741
727742 self .M = dist (xs ,xt ,metric = self .metric )
743+ self .normalizeM (norm )
728744 self .G = sinkhorn_lpl1_mm (ws ,ys ,wt ,self .M ,reg ,eta ,** kwargs )
729745 self .computed = True
730-
746+
731747class OTDA_l1l2 (OTDA ):
732748 """Class for domain adaptation with optimal transport with entropic and group lasso regularization"""
733749
734750
735- def fit (self ,xs ,ys ,xt ,reg = 1 ,eta = 1 ,ws = None ,wt = None ,** kwargs ):
751+ def fit (self ,xs ,ys ,xt ,reg = 1 ,eta = 1 ,ws = None ,wt = None ,norm = None , ** kwargs ):
736752 """ Fit regularized domain adaptation between samples is xs and xt (with optional weights), See ot.da.sinkhorn_lpl1_gl for fit parameters"""
737753 self .xs = xs
738754 self .xt = xt
@@ -746,6 +762,7 @@ def fit(self,xs,ys,xt,reg=1,eta=1,ws=None,wt=None,**kwargs):
746762 self .wt = wt
747763
748764 self .M = dist (xs ,xt ,metric = self .metric )
765+ self .normalizeM (norm )
749766 self .G = sinkhorn_l1l2_gl (ws ,ys ,wt ,self .M ,reg ,eta ,** kwargs )
750767 self .computed = True
751768
0 commit comments