Skip to content

Commit 0b402fd

Browse files
author
ievred
committed
add label prop + inverse
1 parent bc51793 commit 0b402fd

File tree

2 files changed

+214
-4
lines changed

2 files changed

+214
-4
lines changed

ot/da.py

Lines changed: 167 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -943,6 +943,46 @@ class label
943943

944944
return transp_Xs
945945

946+
def transform_labels(self, ys=None):
    """Propagate source labels ys to obtain estimated target labels

    Every target sample collects, for each class, the coupling mass it
    receives from the source samples of that class. The class holding the
    largest share of that mass is returned as the estimated label.

    Parameters
    ----------
    ys : array-like, shape (n_source_samples,)
        The class labels

    Returns
    -------
    transp_ys : array-like, shape (n_target_samples,)
        Estimated target labels, expressed with the same label values
        as the ones found in ys.
    """

    # check the necessary inputs parameters are here
    if check_params(ys=ys):

        # work with the rank of each label among the sorted unique classes
        # so that arbitrary label values (not starting at 0, negative or
        # non-contiguous) index the one-hot matrix safely
        classes, ys_idx = np.unique(np.asarray(ys), return_inverse=True)
        ys_idx = ys_idx.reshape(-1)

        # one-hot encoding of the source labels, shape (n_classes, n_source)
        D1 = np.zeros((len(classes), len(ys_idx)))
        D1[ys_idx, np.arange(len(ys_idx))] = 1

        # perform label propagation: the result is indexed by target
        # samples, hence the mass received by each target sample (columns
        # of the coupling) is what has to be normalized
        transp = self.coupling_ / np.sum(self.coupling_, 0, keepdims=True)

        # set nans to 0
        transp[~ np.isfinite(transp)] = 0

        # class scores of every target sample, shape (n_classes, n_target)
        transp_ys = np.dot(D1, transp)

        # map the winning class index back to the original label value
        return classes[np.argmax(transp_ys, axis=0)]
984+
985+
946986
def inverse_transform(self, Xs=None, ys=None, Xt=None, yt=None,
947987
batch_size=128):
948988
"""Transports target samples Xt onto target samples Xs
@@ -1010,6 +1050,44 @@ class label
10101050

10111051
return transp_Xt
10121052

1053+
def inverse_transform_labels(self, yt=None):
    """Propagate target labels yt to obtain estimated source labels ys

    Every source sample collects, for each class, the coupling mass it
    sends to the target samples of that class. The class holding the
    largest share of that mass is returned as the estimated label.

    Parameters
    ----------
    yt : array-like, shape (n_target_samples,)
        The target class labels

    Returns
    -------
    transp_ys : array-like, shape (n_source_samples,)
        Estimated source labels, expressed with the same label values
        as the ones found in yt.
    """

    # check the necessary inputs parameters are here
    if check_params(yt=yt):

        # work with the rank of each label among the sorted unique classes
        # so that arbitrary label values (not starting at 0, negative or
        # non-contiguous) index the one-hot matrix safely
        classes, yt_idx = np.unique(np.asarray(yt), return_inverse=True)
        yt_idx = yt_idx.reshape(-1)

        # one-hot encoding of the target labels, shape (n_classes, n_target)
        D1 = np.zeros((len(classes), len(yt_idx)))
        D1[yt_idx, np.arange(len(yt_idx))] = 1

        # perform label propagation: normalize the mass sent by each
        # source sample (rows of the coupling)
        transp = self.coupling_ / np.sum(self.coupling_, 1)[:, None]

        # set nans to 0
        transp[~ np.isfinite(transp)] = 0

        # class scores of every source sample, shape (n_classes, n_source)
        transp_ys = np.dot(D1, transp.T)

        # map the winning class index back to the original label value
        return classes[np.argmax(transp_ys, axis=0)]
1090+
10131091

10141092
class LinearTransport(BaseTransport):
10151093

@@ -2017,10 +2095,10 @@ def transform(self, Xs=None, ys=None, Xt=None, yt=None, batch_size=128):
20172095
20182096
Parameters
20192097
----------
2020-
Xs : array-like, shape (n_source_samples, n_features)
2021-
The training input samples.
2022-
ys : array-like, shape (n_source_samples,)
2023-
The class labels
2098+
Xs : list of K array-like objects, shape K x (nk_source_samples, n_features)
2099+
A list of the training input samples.
2100+
ys : list of K array-like objects, shape K x (nk_source_samples,)
2101+
A list of the class labels
20242102
Xt : array-like, shape (n_target_samples, n_features)
20252103
The training input samples.
20262104
yt : array-like, shape (n_target_samples,)
@@ -2083,3 +2161,88 @@ class label
20832161
transp_Xs = np.concatenate(transp_Xs, axis=0)
20842162

20852163
return transp_Xs
2164+
2165+
def transform_labels(self, ys=None):
    """Propagate source labels ys to obtain target labels

    The class scores propagated from each of the K source domains are
    averaged, and every target sample is assigned the class with the
    highest averaged score.

    Parameters
    ----------
    ys : list of K array-like objects, shape K x (nk_source_samples,)
        A list of the class labels

    Returns
    -------
    yt : array-like, shape (n_target_samples,)
        Estimated target labels, expressed with the same label values
        as the ones found in ys.
    """

    # check the necessary inputs parameters are here
    if check_params(ys=ys):

        # classes are shared by all the domains: a domain that lacks some
        # class must still yield a score matrix with one row per class
        classes = np.unique(np.concatenate(ys))
        n_domains = len(ys)
        yt = np.zeros((len(classes), self.xt_.shape[0]))

        for i, ys_i in enumerate(ys):
            ys_i = np.asarray(ys_i).reshape(-1)

            # perform label propagation: the result is indexed by target
            # samples, hence the mass received by each target sample
            # (columns of the coupling) is what has to be normalized
            transp = self.coupling_[i] / np.sum(
                self.coupling_[i], 0, keepdims=True)

            # set nans to 0
            transp[~ np.isfinite(transp)] = 0

            # one-hot encoding of the labels of this domain, built from the
            # labels actually passed in and indexed by the position of each
            # label in the sorted global classes
            D1 = np.zeros((len(classes), len(ys_i)))
            D1[np.searchsorted(classes, ys_i), np.arange(len(ys_i))] = 1

            # accumulate the averaged class scores of the target samples
            yt = yt + np.dot(D1, transp) / n_domains

        # map the winning class index back to the original label value
        return classes[np.argmax(yt, axis=0)]
2208+
2209+
def inverse_transform_labels(self, yt=None):
    """Propagate target labels yt to obtain estimated source labels ys

    For each source domain, every source sample collects, for each class,
    the coupling mass it sends to the target samples of that class. The
    class holding the largest share is returned as the estimated label.

    Parameters
    ----------
    yt : array-like, shape (n_target_samples,)
        The target class labels

    Returns
    -------
    transp_ys : list of K array-like objects, shape K x (nk_source_samples,)
        A list of estimated source labels, expressed with the same label
        values as the ones found in yt.
    """

    # check the necessary inputs parameters are here
    if check_params(yt=yt):

        # work with the rank of each label among the sorted unique classes
        # so that arbitrary label values (not starting at 0, negative or
        # non-contiguous) index the one-hot matrix safely
        classes, yt_idx = np.unique(np.asarray(yt), return_inverse=True)
        yt_idx = yt_idx.reshape(-1)

        # one-hot encoding of the target labels, shape (n_classes, n_target);
        # it does not depend on the source domain, so it is built once
        D1 = np.zeros((len(classes), len(yt_idx)))
        D1[yt_idx, np.arange(len(yt_idx))] = 1

        transp_ys = []
        for i in range(len(self.xs_)):

            # perform label propagation: normalize the mass sent by each
            # source sample (rows of the coupling)
            transp = self.coupling_[i] / np.sum(
                self.coupling_[i], 1)[:, None]

            # set nans to 0
            transp[~ np.isfinite(transp)] = 0

            # compute transported labels and map the winning class index
            # back to the original label value
            scores = np.dot(D1, transp.T)
            transp_ys.append(classes[np.argmax(scores, axis=0)])

        return transp_ys

test/test_da.py

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,14 @@ def test_sinkhorn_lpl1_transport_class():
6565
transp_Xs = otda.fit_transform(Xs=Xs, ys=ys, Xt=Xt)
6666
assert_equal(transp_Xs.shape, Xs.shape)
6767

68+
# check label propagation
69+
transp_yt = otda.transform_labels(ys)
70+
assert_equal(transp_yt.shape[0], yt.shape[0])
71+
72+
# check inverse label propagation
73+
transp_ys = otda.inverse_transform_labels(yt)
74+
assert_equal(transp_ys.shape[0], ys.shape[0])
75+
6876
# test unsupervised vs semi-supervised mode
6977
otda_unsup = ot.da.SinkhornLpl1Transport()
7078
otda_unsup.fit(Xs=Xs, ys=ys, Xt=Xt)
@@ -129,6 +137,14 @@ def test_sinkhorn_l1l2_transport_class():
129137
transp_Xt = otda.inverse_transform(Xt=Xt)
130138
assert_equal(transp_Xt.shape, Xt.shape)
131139

140+
# check label propagation
141+
transp_yt = otda.transform_labels(ys)
142+
assert_equal(transp_yt.shape[0], yt.shape[0])
143+
144+
# check inverse label propagation
145+
transp_ys = otda.inverse_transform_labels(yt)
146+
assert_equal(transp_ys.shape[0], ys.shape[0])
147+
132148
Xt_new, _ = make_data_classif('3gauss2', nt + 1)
133149
transp_Xt_new = otda.inverse_transform(Xt=Xt_new)
134150

@@ -210,6 +226,14 @@ def test_sinkhorn_transport_class():
210226
transp_Xt = otda.inverse_transform(Xt=Xt)
211227
assert_equal(transp_Xt.shape, Xt.shape)
212228

229+
# check label propagation
230+
transp_yt = otda.transform_labels(ys)
231+
assert_equal(transp_yt.shape[0], yt.shape[0])
232+
233+
# check inverse label propagation
234+
transp_ys = otda.inverse_transform_labels(yt)
235+
assert_equal(transp_ys.shape[0], ys.shape[0])
236+
213237
Xt_new, _ = make_data_classif('3gauss2', nt + 1)
214238
transp_Xt_new = otda.inverse_transform(Xt=Xt_new)
215239

@@ -271,6 +295,14 @@ def test_unbalanced_sinkhorn_transport_class():
271295
transp_Xs = otda.transform(Xs=Xs)
272296
assert_equal(transp_Xs.shape, Xs.shape)
273297

298+
# check label propagation
299+
transp_yt = otda.transform_labels(ys)
300+
assert_equal(transp_yt.shape[0], yt.shape[0])
301+
302+
# check inverse label propagation
303+
transp_ys = otda.inverse_transform_labels(yt)
304+
assert_equal(transp_ys.shape[0], ys.shape[0])
305+
274306
Xs_new, _ = make_data_classif('3gauss', ns + 1)
275307
transp_Xs_new = otda.transform(Xs_new)
276308

@@ -353,6 +385,14 @@ def test_emd_transport_class():
353385
transp_Xt = otda.inverse_transform(Xt=Xt)
354386
assert_equal(transp_Xt.shape, Xt.shape)
355387

388+
# check label propagation
389+
transp_yt = otda.transform_labels(ys)
390+
assert_equal(transp_yt.shape[0], yt.shape[0])
391+
392+
# check inverse label propagation
393+
transp_ys = otda.inverse_transform_labels(yt)
394+
assert_equal(transp_ys.shape[0], ys.shape[0])
395+
356396
Xt_new, _ = make_data_classif('3gauss2', nt + 1)
357397
transp_Xt_new = otda.inverse_transform(Xt=Xt_new)
358398

@@ -602,6 +642,13 @@ def test_jcpot_transport_class():
602642
# check that the oos method is working
603643
assert_equal(transp_Xs_new.shape, Xs_new.shape)
604644

645+
# check label propagation
646+
transp_yt = otda.transform_labels(ys)
647+
assert_equal(transp_yt.shape[0], yt.shape[0])
648+
649+
# check inverse label propagation
650+
transp_ys = otda.inverse_transform_labels(yt)
651+
[assert_equal(x.shape, y.shape) for x, y in zip(transp_ys, ys)]
605652

606653
def test_jcpot_barycenter():
607654
"""test_jcpot_barycenter

0 commit comments

Comments
 (0)