Skip to content

Commit 749378a

Browse files
author
ievred
committed
fix soft labels, remove gammas from jcpot
1 parent 1a4c264 commit 749378a

File tree

3 files changed

+38
-25
lines changed

3 files changed

+38
-25
lines changed

ot/bregman.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1553,8 +1553,6 @@ def jcpot_barycenter(Xs, Ys, Xt, reg, metric='sqeuclidean', numItermax=100,
15531553
15541554
Returns
15551555
-------
1556-
gamma : List of K (nsk x nt) ndarrays
1557-
Optimal transportation matrices for the given parameters for each pair of source and target domains
15581556
h : (C,) ndarray
15591557
proportion estimation in the target domain
15601558
log : dict
@@ -1574,7 +1572,7 @@ def jcpot_barycenter(Xs, Ys, Xt, reg, metric='sqeuclidean', numItermax=100,
15741572

15751573
# log dictionary
15761574
if log:
1577-
log = {'niter': 0, 'err': [], 'M': [], 'D1': [], 'D2': []}
1575+
log = {'niter': 0, 'err': [], 'M': [], 'D1': [], 'D2': [], 'gamma': []}
15781576

15791577
K = []
15801578
M = []
@@ -1657,9 +1655,10 @@ def jcpot_barycenter(Xs, Ys, Xt, reg, metric='sqeuclidean', numItermax=100,
16571655
log['M'] = M
16581656
log['D1'] = D1
16591657
log['D2'] = D2
1660-
return K, bary, log
1658+
log['gamma'] = K
1659+
return bary, log
16611660
else:
1662-
return K, bary
1661+
return bary
16631662

16641663

16651664
def empirical_sinkhorn(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean',

ot/da.py

Lines changed: 21 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -956,8 +956,8 @@ def transform_labels(self, ys=None):
956956
957957
Returns
958958
-------
959-
transp_ys : array-like, shape (n_target_samples,)
960-
Estimated target labels.
959+
transp_ys : array-like, shape (n_target_samples, nb_classes)
960+
Estimated soft target labels.
961961
962962
References
963963
----------
@@ -985,10 +985,10 @@ def transform_labels(self, ys=None):
985985
for c in classes:
986986
D1[int(c), ysTemp == c] = 1
987987

988-
# compute transported samples
988+
# compute propagated labels
989989
transp_ys = np.dot(D1, transp)
990990

991-
return np.argmax(transp_ys, axis=0)
991+
return transp_ys.T
992992

993993
def inverse_transform(self, Xs=None, ys=None, Xt=None, yt=None,
994994
batch_size=128):
@@ -1066,8 +1066,8 @@ def inverse_transform_labels(self, yt=None):
10661066
10671067
Returns
10681068
-------
1069-
transp_ys : array-like, shape (n_source_samples,)
1070-
Estimated source labels.
1069+
transp_ys : array-like, shape (n_source_samples, nb_classes)
1070+
Estimated soft source labels.
10711071
"""
10721072

10731073
# check the necessary inputs parameters are here
@@ -1087,10 +1087,10 @@ def inverse_transform_labels(self, yt=None):
10871087
for c in classes:
10881088
D1[int(c), ytTemp == c] = 1
10891089

1090-
# compute transported samples
1090+
# compute propagated samples
10911091
transp_ys = np.dot(D1, transp.T)
10921092

1093-
return np.argmax(transp_ys, axis=0)
1093+
return transp_ys.T
10941094

10951095

10961096
class LinearTransport(BaseTransport):
@@ -2083,13 +2083,15 @@ class label
20832083

20842084
returned_ = jcpot_barycenter(Xs=Xs, Ys=ys, Xt=Xt, reg=self.reg_e,
20852085
metric=self.metric, distrinumItermax=self.max_iter, stopThr=self.tol,
2086-
verbose=self.verbose, log=self.log)
2086+
verbose=self.verbose, log=True)
2087+
2088+
self.coupling_ = returned_[1]['gamma']
20872089

20882090
# deal with the value of log
20892091
if self.log:
2090-
self.coupling_, self.proportions_, self.log_ = returned_
2092+
self.proportions_, self.log_ = returned_
20912093
else:
2092-
self.coupling_, self.proportions_ = returned_
2094+
self.proportions_ = returned_
20932095
self.log_ = dict()
20942096

20952097
return self
@@ -2176,8 +2178,8 @@ def transform_labels(self, ys=None):
21762178
21772179
Returns
21782180
-------
2179-
yt : array-like, shape (n_target_samples,)
2180-
Estimated target labels.
2181+
yt : array-like, shape (n_target_samples, nb_classes)
2182+
Estimated soft target labels.
21812183
"""
21822184

21832185
# check the necessary inputs parameters are here
@@ -2203,10 +2205,10 @@ def transform_labels(self, ys=None):
22032205
for c in classes:
22042206
D1[int(c), ysTemp == c] = 1
22052207

2206-
# compute transported samples
2208+
# compute propagated labels
22072209
yt = yt + np.dot(D1, transp) / len(ys)
22082210

2209-
return np.argmax(yt, axis=0)
2211+
return yt.T
22102212

22112213
def inverse_transform_labels(self, yt=None):
22122214
"""Propagate source labels ys to obtain target labels
@@ -2218,8 +2220,8 @@ def inverse_transform_labels(self, yt=None):
22182220
22192221
Returns
22202222
-------
2221-
transp_ys : list of K array-like objects, shape K x (nk_source_samples,)
2222-
A list of estimated source labels
2223+
transp_ys : list of K array-like objects, shape K x (nk_source_samples, nb_classes)
2224+
A list of estimated soft source labels
22232225
"""
22242226

22252227
# check the necessary inputs parameters are here
@@ -2241,7 +2243,7 @@ def inverse_transform_labels(self, yt=None):
22412243
# set nans to 0
22422244
transp[~ np.isfinite(transp)] = 0
22432245

2244-
# compute transported labels
2245-
transp_ys.append(np.argmax(np.dot(D1, transp.T), axis=0))
2246+
# compute propagated labels
2247+
transp_ys.append(np.dot(D1, transp.T).T)
22462248

22472249
return transp_ys

test/test_da.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -68,10 +68,12 @@ def test_sinkhorn_lpl1_transport_class():
6868
# check label propagation
6969
transp_yt = otda.transform_labels(ys)
7070
assert_equal(transp_yt.shape[0], yt.shape[0])
71+
assert_equal(transp_yt.shape[1], len(np.unique(ys)))
7172

7273
# check inverse label propagation
7374
transp_ys = otda.inverse_transform_labels(yt)
7475
assert_equal(transp_ys.shape[0], ys.shape[0])
76+
assert_equal(transp_ys.shape[1], len(np.unique(yt)))
7577

7678
# test unsupervised vs semi-supervised mode
7779
otda_unsup = ot.da.SinkhornLpl1Transport()
@@ -140,10 +142,12 @@ def test_sinkhorn_l1l2_transport_class():
140142
# check label propagation
141143
transp_yt = otda.transform_labels(ys)
142144
assert_equal(transp_yt.shape[0], yt.shape[0])
145+
assert_equal(transp_yt.shape[1], len(np.unique(ys)))
143146

144147
# check inverse label propagation
145148
transp_ys = otda.inverse_transform_labels(yt)
146149
assert_equal(transp_ys.shape[0], ys.shape[0])
150+
assert_equal(transp_ys.shape[1], len(np.unique(yt)))
147151

148152
Xt_new, _ = make_data_classif('3gauss2', nt + 1)
149153
transp_Xt_new = otda.inverse_transform(Xt=Xt_new)
@@ -229,10 +233,12 @@ def test_sinkhorn_transport_class():
229233
# check label propagation
230234
transp_yt = otda.transform_labels(ys)
231235
assert_equal(transp_yt.shape[0], yt.shape[0])
236+
assert_equal(transp_yt.shape[1], len(np.unique(ys)))
232237

233238
# check inverse label propagation
234239
transp_ys = otda.inverse_transform_labels(yt)
235240
assert_equal(transp_ys.shape[0], ys.shape[0])
241+
assert_equal(transp_ys.shape[1], len(np.unique(yt)))
236242

237243
Xt_new, _ = make_data_classif('3gauss2', nt + 1)
238244
transp_Xt_new = otda.inverse_transform(Xt=Xt_new)
@@ -298,10 +304,12 @@ def test_unbalanced_sinkhorn_transport_class():
298304
# check label propagation
299305
transp_yt = otda.transform_labels(ys)
300306
assert_equal(transp_yt.shape[0], yt.shape[0])
307+
assert_equal(transp_yt.shape[1], len(np.unique(ys)))
301308

302309
# check inverse label propagation
303310
transp_ys = otda.inverse_transform_labels(yt)
304311
assert_equal(transp_ys.shape[0], ys.shape[0])
312+
assert_equal(transp_ys.shape[1], len(np.unique(yt)))
305313

306314
Xs_new, _ = make_data_classif('3gauss', ns + 1)
307315
transp_Xs_new = otda.transform(Xs_new)
@@ -388,10 +396,12 @@ def test_emd_transport_class():
388396
# check label propagation
389397
transp_yt = otda.transform_labels(ys)
390398
assert_equal(transp_yt.shape[0], yt.shape[0])
399+
assert_equal(transp_yt.shape[1], len(np.unique(ys)))
391400

392401
# check inverse label propagation
393402
transp_ys = otda.inverse_transform_labels(yt)
394403
assert_equal(transp_ys.shape[0], ys.shape[0])
404+
assert_equal(transp_ys.shape[1], len(np.unique(yt)))
395405

396406
Xt_new, _ = make_data_classif('3gauss2', nt + 1)
397407
transp_Xt_new = otda.inverse_transform(Xt=Xt_new)
@@ -645,10 +655,12 @@ def test_jcpot_transport_class():
645655
# check label propagation
646656
transp_yt = otda.transform_labels(ys)
647657
assert_equal(transp_yt.shape[0], yt.shape[0])
658+
assert_equal(transp_yt.shape[1], len(np.unique(ys)))
648659

649660
# check inverse label propagation
650661
transp_ys = otda.inverse_transform_labels(yt)
651-
[assert_equal(x.shape, y.shape) for x, y in zip(transp_ys, ys)]
662+
[assert_equal(x.shape[0], y.shape[0]) for x, y in zip(transp_ys, ys)]
663+
[assert_equal(x.shape[1], len(np.unique(y))) for x, y in zip(transp_ys, ys)]
652664

653665

654666
def test_jcpot_barycenter():

0 commit comments

Comments
 (0)