Skip to content

Commit 1a4c264

Browse files
author
ievred
committed
added label normalization to utils
1 parent 0b402fd commit 1a4c264

File tree

3 files changed

+60
-38
lines changed

3 files changed

+60
-38
lines changed

ot/da.py

Lines changed: 37 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616

1717
from .bregman import sinkhorn, jcpot_barycenter
1818
from .lp import emd
19-
from .utils import unif, dist, kernel, cost_normalization
19+
from .utils import unif, dist, kernel, cost_normalization, label_normalization
2020
from .utils import check_params, BaseEstimator
2121
from .unbalanced import sinkhorn_unbalanced
2222
from .optim import cg
@@ -786,6 +786,9 @@ class BaseTransport(BaseEstimator):
786786
787787
transform method should always get as input a Xs parameter
788788
inverse_transform method should always get as input a Xt parameter
789+
790+
transform_labels method should always get as input a ys parameter
791+
inverse_transform_labels method should always get as input a yt parameter
789792
"""
790793

791794
def fit(self, Xs=None, ys=None, Xt=None, yt=None):
@@ -944,7 +947,7 @@ class label
944947
return transp_Xs
945948

946949
def transform_labels(self, ys=None):
947-
"""Propagate source labels ys to obtain estimated target labels
950+
"""Propagate source labels ys to obtain estimated target labels as in [27]
948951
949952
Parameters
950953
----------
@@ -955,33 +958,37 @@ def transform_labels(self, ys=None):
955958
-------
956959
transp_ys : array-like, shape (n_target_samples,)
957960
Estimated target labels.
961+
962+
References
963+
----------
964+
965+
.. [27] Ievgen Redko, Nicolas Courty, Rémi Flamary, Devis Tuia
966+
"Optimal transport for multi-source domain adaptation under target shift",
967+
International Conference on Artificial Intelligence and Statistics (AISTATS), 2019.
968+
958969
"""
959970

960971
# check the necessary inputs parameters are here
961972
if check_params(ys=ys):
962973

963-
classes = np.unique(ys)
974+
ysTemp = label_normalization(np.copy(ys))
975+
classes = np.unique(ysTemp)
964976
n = len(classes)
965-
D1 = np.zeros((n, len(ys)))
977+
D1 = np.zeros((n, len(ysTemp)))
966978

967979
# perform label propagation
968980
transp = self.coupling_ / np.sum(self.coupling_, 1)[:, None]
969981

970982
# set nans to 0
971983
transp[~ np.isfinite(transp)] = 0
972984

973-
if np.min(classes) != 0:
974-
ys = ys - np.min(classes)
975-
classes = np.unique(ys)
976-
977985
for c in classes:
978-
D1[int(c), ys == c] = 1
986+
D1[int(c), ysTemp == c] = 1
979987

980988
# compute transported samples
981989
transp_ys = np.dot(D1, transp)
982990

983-
return np.argmax(transp_ys,axis=0)
984-
991+
return np.argmax(transp_ys, axis=0)
985992

986993
def inverse_transform(self, Xs=None, ys=None, Xt=None, yt=None,
987994
batch_size=128):
@@ -1066,27 +1073,24 @@ def inverse_transform_labels(self, yt=None):
10661073
# check the necessary inputs parameters are here
10671074
if check_params(yt=yt):
10681075

1069-
classes = np.unique(yt)
1076+
ytTemp = label_normalization(np.copy(yt))
1077+
classes = np.unique(ytTemp)
10701078
n = len(classes)
1071-
D1 = np.zeros((n, len(yt)))
1079+
D1 = np.zeros((n, len(ytTemp)))
10721080

10731081
# perform label propagation
10741082
transp = self.coupling_ / np.sum(self.coupling_, 1)[:, None]
10751083

10761084
# set nans to 0
10771085
transp[~ np.isfinite(transp)] = 0
10781086

1079-
if np.min(classes) != 0:
1080-
yt = yt - np.min(classes)
1081-
classes = np.unique(yt)
1082-
10831087
for c in classes:
1084-
D1[int(c), yt == c] = 1
1088+
D1[int(c), ytTemp == c] = 1
10851089

10861090
# compute transported samples
10871091
transp_ys = np.dot(D1, transp.T)
10881092

1089-
return np.argmax(transp_ys,axis=0)
1093+
return np.argmax(transp_ys, axis=0)
10901094

10911095

10921096
class LinearTransport(BaseTransport):
@@ -2163,7 +2167,7 @@ class label
21632167
return transp_Xs
21642168

21652169
def transform_labels(self, ys=None):
2166-
"""Propagate source labels ys to obtain target labels
2170+
"""Propagate source labels ys to obtain target labels as in [27]
21672171
21682172
Parameters
21692173
----------
@@ -2178,11 +2182,12 @@ def transform_labels(self, ys=None):
21782182

21792183
# check the necessary inputs parameters are here
21802184
if check_params(ys=ys):
2181-
yt = np.zeros((len(np.unique(np.concatenate(ys))),self.xt_.shape[0]))
2185+
yt = np.zeros((len(np.unique(np.concatenate(ys))), self.xt_.shape[0]))
21822186
for i in range(len(ys)):
2183-
classes = np.unique(ys[i])
2187+
ysTemp = label_normalization(np.copy(ys[i]))
2188+
classes = np.unique(ysTemp)
21842189
n = len(classes)
2185-
ns = len(ys[i])
2190+
ns = len(ysTemp)
21862191

21872192
# perform label propagation
21882193
transp = self.coupling_[i] / np.sum(self.coupling_[i], 1)[:, None]
@@ -2195,16 +2200,13 @@ def transform_labels(self, ys=None):
21952200
else:
21962201
D1 = np.zeros((n, ns))
21972202

2198-
if np.min(classes) != 0:
2199-
ys = ys - np.min(classes)
2200-
classes = np.unique(ys)
2201-
22022203
for c in classes:
2203-
D1[int(c), ys == c] = 1
2204+
D1[int(c), ysTemp == c] = 1
2205+
22042206
# compute transported samples
2205-
yt = yt + np.dot(D1, transp)/len(ys)
2207+
yt = yt + np.dot(D1, transp) / len(ys)
22062208

2207-
return np.argmax(yt,axis=0)
2209+
return np.argmax(yt, axis=0)
22082210

22092211
def inverse_transform_labels(self, yt=None):
22102212
"""Propagate source labels ys to obtain target labels
@@ -2223,16 +2225,13 @@ def inverse_transform_labels(self, yt=None):
22232225
# check the necessary inputs parameters are here
22242226
if check_params(yt=yt):
22252227
transp_ys = []
2226-
classes = np.unique(yt)
2228+
ytTemp = label_normalization(np.copy(yt))
2229+
classes = np.unique(ytTemp)
22272230
n = len(classes)
2228-
D1 = np.zeros((n, len(yt)))
2229-
2230-
if np.min(classes) != 0:
2231-
yt = yt - np.min(classes)
2232-
classes = np.unique(yt)
2231+
D1 = np.zeros((n, len(ytTemp)))
22332232

22342233
for c in classes:
2235-
D1[int(c), yt == c] = 1
2234+
D1[int(c), ytTemp == c] = 1
22362235

22372236
for i in range(len(self.xs_)):
22382237

@@ -2243,6 +2242,6 @@ def inverse_transform_labels(self, yt=None):
22432242
transp[~ np.isfinite(transp)] = 0
22442243

22452244
# compute transported labels
2246-
transp_ys.append(np.argmax(np.dot(D1, transp.T),axis=0))
2245+
transp_ys.append(np.argmax(np.dot(D1, transp.T), axis=0))
22472246

22482247
return transp_ys

ot/utils.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -200,6 +200,28 @@ def dots(*args):
200200
return reduce(np.dot, args)
201201

202202

203+
def label_normalization(y, start=0):
204+
""" Transform labels to start at a given value
205+
206+
Parameters
207+
----------
208+
y : array-like, shape (n, )
209+
The vector of labels to be normalized.
210+
start : int
211+
Desired value for the smallest label in y (default=0)
212+
213+
Returns
214+
-------
215+
y : array-like, shape (n1, )
216+
The input vector of labels normalized according to given start value.
217+
"""
218+
219+
diff = np.min(np.unique(y)) - start
220+
if diff != 0:
221+
y -= diff
222+
return y
223+
224+
203225
def fun(f, q_in, q_out):
204226
""" Utility function for parmap with no serializing problems """
205227
while True:

test/test_da.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -650,6 +650,7 @@ def test_jcpot_transport_class():
650650
transp_ys = otda.inverse_transform_labels(yt)
651651
[assert_equal(x.shape, y.shape) for x, y in zip(transp_ys, ys)]
652652

653+
653654
def test_jcpot_barycenter():
654655
"""test_jcpot_barycenter
655656
"""

0 commit comments

Comments
 (0)