@@ -606,7 +606,7 @@ def __init__(self,metric='sqeuclidean'):
606606 self .computed = False
607607
608608
609- def fit (self ,xs ,xt ,ws = None ,wt = None ):
609+ def fit (self ,xs ,xt ,ws = None ,wt = None , norm = None ):
610610 """ Fit domain adaptation between samples is xs and xt (with optional weights)"""
611611 self .xs = xs
612612 self .xt = xt
@@ -620,6 +620,7 @@ def fit(self,xs,xt,ws=None,wt=None):
620620 self .wt = wt
621621
622622 self .M = dist (xs ,xt ,metric = self .metric )
623+ self .normalize ()
623624 self .G = emd (ws ,wt ,self .M )
624625 self .computed = True
625626
@@ -684,11 +685,25 @@ def predict(self,x,direction=1):
684685 xf = self .interp (direction )# interp the source samples
685686 return xf [idx ,:]+ x - x0 [idx ,:] # aply the delta to the interpolation
686687
def normalizeM(self, norm):
    """Rescale the cost matrix ``self.M`` according to ``norm``.

    It may help to normalize the cost matrix ``self.M`` if there are
    numerical errors during the sinkhorn based algorithms.

    Parameters
    ----------
    norm : str or None
        Normalization to apply to ``self.M``:

        * ``None``     -- leave ``self.M`` untouched
        * ``"median"`` -- divide ``self.M`` by its median
        * ``"max"``    -- divide ``self.M`` by its maximum
        * ``"log"``    -- replace ``self.M`` with ``log(1 + M)``
        * ``"loglog"`` -- replace ``self.M`` with ``log(1 + log(1 + M))``

    Raises
    ------
    ValueError
        If ``norm`` is neither ``None`` nor one of the names above.
    """
    # No normalization requested: keep the cost matrix exactly as computed.
    if norm is None:
        return

    if norm == "median":
        # In-place division, same as the historical behaviour.
        self.M /= float(np.median(self.M))
    elif norm == "max":
        self.M /= float(np.max(self.M))
    elif norm == "log":
        self.M = np.log(1 + self.M)
    elif norm == "loglog":
        self.M = np.log(1 + np.log(1 + self.M))
    else:
        # A misspelled option used to be ignored silently, leaving the
        # caller with an un-normalized matrix; fail loudly instead.
        raise ValueError(
            "Norm %r is not a valid option. "
            "Valid options are: 'median', 'max', 'log', 'loglog'" % (norm,)
        )
701+
687702
688703class OTDA_sinkhorn (OTDA ):
689704 """Class for domain adaptation with optimal transport with entropic regularization"""
690705
691- def fit (self ,xs ,xt ,reg = 1 ,ws = None ,wt = None ,** kwargs ):
706+ def fit (self ,xs ,xt ,reg = 1 ,ws = None ,wt = None ,norm = None , ** kwargs ):
692707 """ Fit regularized domain adaptation between samples is xs and xt (with optional weights)"""
693708 self .xs = xs
694709 self .xt = xt
@@ -702,6 +717,7 @@ def fit(self,xs,xt,reg=1,ws=None,wt=None,**kwargs):
702717 self .wt = wt
703718
704719 self .M = dist (xs ,xt ,metric = self .metric )
720+ self .normalizeM (norm )
705721 self .G = sinkhorn (ws ,wt ,self .M ,reg ,** kwargs )
706722 self .computed = True
707723
@@ -710,7 +726,7 @@ class OTDA_lpl1(OTDA):
710726 """Class for domain adaptation with optimal transport with entropic and group regularization"""
711727
712728
713- def fit (self ,xs ,ys ,xt ,reg = 1 ,eta = 1 ,ws = None ,wt = None ,** kwargs ):
729+ def fit (self ,xs ,ys ,xt ,reg = 1 ,eta = 1 ,ws = None ,wt = None ,norm = None , ** kwargs ):
714730 """ Fit regularized domain adaptation between samples is xs and xt (with optional weights), See ot.da.sinkhorn_lpl1_mm for fit parameters"""
715731 self .xs = xs
716732 self .xt = xt
@@ -724,14 +740,15 @@ def fit(self,xs,ys,xt,reg=1,eta=1,ws=None,wt=None,**kwargs):
724740 self .wt = wt
725741
726742 self .M = dist (xs ,xt ,metric = self .metric )
743+ self .normalizeM (norm )
727744 self .G = sinkhorn_lpl1_mm (ws ,ys ,wt ,self .M ,reg ,eta ,** kwargs )
728745 self .computed = True
729746
730747class OTDA_l1l2 (OTDA ):
731748 """Class for domain adaptation with optimal transport with entropic and group lasso regularization"""
732749
733750
734- def fit (self ,xs ,ys ,xt ,reg = 1 ,eta = 1 ,ws = None ,wt = None ,** kwargs ):
751+ def fit (self ,xs ,ys ,xt ,reg = 1 ,eta = 1 ,ws = None ,wt = None ,norm = None , ** kwargs ):
735752 """ Fit regularized domain adaptation between samples is xs and xt (with optional weights), See ot.da.sinkhorn_lpl1_gl for fit parameters"""
736753 self .xs = xs
737754 self .xt = xt
@@ -745,6 +762,7 @@ def fit(self,xs,ys,xt,reg=1,eta=1,ws=None,wt=None,**kwargs):
745762 self .wt = wt
746763
747764 self .M = dist (xs ,xt ,metric = self .metric )
765+ self .normalizeM (norm )
748766 self .G = sinkhorn_l1l2_gl (ws ,ys ,wt ,self .M ,reg ,eta ,** kwargs )
749767 self .computed = True
750768
0 commit comments