1616
1717from .bregman import sinkhorn , jcpot_barycenter
1818from .lp import emd
19- from .utils import unif , dist , kernel , cost_normalization
19+ from .utils import unif , dist , kernel , cost_normalization , label_normalization
2020from .utils import check_params , BaseEstimator
2121from .unbalanced import sinkhorn_unbalanced
2222from .optim import cg
@@ -786,6 +786,9 @@ class BaseTransport(BaseEstimator):
786786
787787 transform method should always get as input a Xs parameter
788788 inverse_transform method should always get as input a Xt parameter
789+
790+ transform_labels method should always get as input a ys parameter
791+ inverse_transform_labels method should always get as input a yt parameter
789792 """
790793
791794 def fit (self , Xs = None , ys = None , Xt = None , yt = None ):
@@ -944,7 +947,7 @@ class label
944947 return transp_Xs
945948
946949 def transform_labels (self , ys = None ):
947- """Propagate source labels ys to obtain estimated target labels
950+ """Propagate source labels ys to obtain estimated target labels as in [27]
948951
949952 Parameters
950953 ----------
@@ -955,33 +958,37 @@ def transform_labels(self, ys=None):
955958 -------
956959 transp_ys : array-like, shape (n_target_samples,)
957960 Estimated target labels.
961+
962+ References
963+ ----------
964+
965+ .. [27] Ievgen Redko, Nicolas Courty, Rémi Flamary, Devis Tuia
966+ "Optimal transport for multi-source domain adaptation under target shift",
967+ International Conference on Artificial Intelligence and Statistics (AISTATS), 2019.
968+
958969 """
959970
960971 # check the necessary inputs parameters are here
961972 if check_params (ys = ys ):
962973
963- classes = np .unique (ys )
974+ ysTemp = label_normalization (np .copy (ys ))
975+ classes = np .unique (ysTemp )
964976 n = len (classes )
965- D1 = np .zeros ((n , len (ys )))
977+ D1 = np .zeros ((n , len (ysTemp )))
966978
967979 # perform label propagation
968980 transp = self .coupling_ / np .sum (self .coupling_ , 1 )[:, None ]
969981
970982 # set nans to 0
971983 transp [~ np .isfinite (transp )] = 0
972984
973- if np .min (classes ) != 0 :
974- ys = ys - np .min (classes )
975- classes = np .unique (ys )
976-
977985 for c in classes :
978- D1 [int (c ), ys == c ] = 1
986+ D1 [int (c ), ysTemp == c ] = 1
979987
980988 # compute transported samples
981989 transp_ys = np .dot (D1 , transp )
982990
983- return np .argmax (transp_ys ,axis = 0 )
984-
991+ return np .argmax (transp_ys , axis = 0 )
985992
986993 def inverse_transform (self , Xs = None , ys = None , Xt = None , yt = None ,
987994 batch_size = 128 ):
@@ -1066,27 +1073,24 @@ def inverse_transform_labels(self, yt=None):
10661073 # check the necessary inputs parameters are here
10671074 if check_params (yt = yt ):
10681075
1069- classes = np .unique (yt )
1076+ ytTemp = label_normalization (np .copy (yt ))
1077+ classes = np .unique (ytTemp )
10701078 n = len (classes )
1071- D1 = np .zeros ((n , len (yt )))
1079+ D1 = np .zeros ((n , len (ytTemp )))
10721080
10731081 # perform label propagation
10741082 transp = self .coupling_ / np .sum (self .coupling_ , 1 )[:, None ]
10751083
10761084 # set nans to 0
10771085 transp [~ np .isfinite (transp )] = 0
10781086
1079- if np .min (classes ) != 0 :
1080- yt = yt - np .min (classes )
1081- classes = np .unique (yt )
1082-
10831087 for c in classes :
1084- D1 [int (c ), yt == c ] = 1
1088+ D1 [int (c ), ytTemp == c ] = 1
10851089
10861090 # compute transported samples
10871091 transp_ys = np .dot (D1 , transp .T )
10881092
1089- return np .argmax (transp_ys ,axis = 0 )
1093+ return np .argmax (transp_ys , axis = 0 )
10901094
10911095
10921096class LinearTransport (BaseTransport ):
@@ -2163,7 +2167,7 @@ class label
21632167 return transp_Xs
21642168
21652169 def transform_labels (self , ys = None ):
2166- """Propagate source labels ys to obtain target labels
2170+ """Propagate source labels ys to obtain target labels as in [27]
21672171
21682172 Parameters
21692173 ----------
@@ -2178,11 +2182,12 @@ def transform_labels(self, ys=None):
21782182
21792183 # check the necessary inputs parameters are here
21802184 if check_params (ys = ys ):
2181- yt = np .zeros ((len (np .unique (np .concatenate (ys ))),self .xt_ .shape [0 ]))
2185+ yt = np .zeros ((len (np .unique (np .concatenate (ys ))), self .xt_ .shape [0 ]))
21822186 for i in range (len (ys )):
2183- classes = np .unique (ys [i ])
2187+ ysTemp = label_normalization (np .copy (ys [i ]))
2188+ classes = np .unique (ysTemp )
21842189 n = len (classes )
2185- ns = len (ys [ i ] )
2190+ ns = len (ysTemp )
21862191
21872192 # perform label propagation
21882193 transp = self .coupling_ [i ] / np .sum (self .coupling_ [i ], 1 )[:, None ]
@@ -2195,16 +2200,13 @@ def transform_labels(self, ys=None):
21952200 else :
21962201 D1 = np .zeros ((n , ns ))
21972202
2198- if np .min (classes ) != 0 :
2199- ys = ys - np .min (classes )
2200- classes = np .unique (ys )
2201-
22022203 for c in classes :
2203- D1 [int (c ), ys == c ] = 1
2204+ D1 [int (c ), ysTemp == c ] = 1
2205+
22042206 # compute transported samples
2205- yt = yt + np .dot (D1 , transp )/ len (ys )
2207+ yt = yt + np .dot (D1 , transp ) / len (ys )
22062208
2207- return np .argmax (yt ,axis = 0 )
2209+ return np .argmax (yt , axis = 0 )
22082210
22092211 def inverse_transform_labels (self , yt = None ):
22102212 """Propagate source labels ys to obtain target labels
@@ -2223,16 +2225,13 @@ def inverse_transform_labels(self, yt=None):
22232225 # check the necessary inputs parameters are here
22242226 if check_params (yt = yt ):
22252227 transp_ys = []
2226- classes = np .unique (yt )
2228+ ytTemp = label_normalization (np .copy (yt ))
2229+ classes = np .unique (ytTemp )
22272230 n = len (classes )
2228- D1 = np .zeros ((n , len (yt )))
2229-
2230- if np .min (classes ) != 0 :
2231- yt = yt - np .min (classes )
2232- classes = np .unique (yt )
2231+ D1 = np .zeros ((n , len (ytTemp )))
22332232
22342233 for c in classes :
2235- D1 [int (c ), yt == c ] = 1
2234+ D1 [int (c ), ytTemp == c ] = 1
22362235
22372236 for i in range (len (self .xs_ )):
22382237
@@ -2243,6 +2242,6 @@ def inverse_transform_labels(self, yt=None):
22432242 transp [~ np .isfinite (transp )] = 0
22442243
22452244 # compute transported labels
2246- transp_ys .append (np .argmax (np .dot (D1 , transp .T ),axis = 0 ))
2245+ transp_ys .append (np .argmax (np .dot (D1 , transp .T ), axis = 0 ))
22472246
22482247 return transp_ys
0 commit comments