Skip to content

Commit 2c9f992

Browse files
author
ievred
committed
upd
1 parent 34e13d4 commit 2c9f992

File tree

4 files changed

+10
-11
lines changed

4 files changed

+10
-11
lines changed

examples/plot_otda_classes.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@
1717
import matplotlib.pylab as pl
1818
import ot
1919

20-
2120
##############################################################################
2221
# Generate data
2322
# -------------

examples/plot_otda_jcpot.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -118,16 +118,16 @@ def print_G(G, xs, ys, xt):
118118
otda = ot.da.JCPOTTransport(reg_e=1e-2, max_iter=1000, metric='sqeuclidean', tol=1e-9, verbose=True, log=True)
119119
otda.fit(all_Xr, all_Yr, xt)
120120

121-
ws1 = otda.proportions_.dot(otda.log_['all_domains'][0]['D2'])
122-
ws2 = otda.proportions_.dot(otda.log_['all_domains'][1]['D2'])
121+
ws1 = otda.proportions_.dot(otda.log_['D2'][0])
122+
ws2 = otda.proportions_.dot(otda.log_['D2'][1])
123123

124124
pl.figure(3)
125125
pl.clf()
126126
plot_ax(dec1, 'Source 1')
127127
plot_ax(dec2, 'Source 2')
128128
plot_ax(dect, 'Target')
129-
print_G(ot.bregman.sinkhorn(ws1, [], otda.log_['all_domains'][0]['M'], reg=1e-2), xs1, ys1, xt)
130-
print_G(ot.bregman.sinkhorn(ws2, [], otda.log_['all_domains'][1]['M'], reg=1e-2), xs2, ys2, xt)
129+
print_G(ot.bregman.sinkhorn(ws1, [], otda.log_['M'][0], reg=1e-2), xs1, ys1, xt)
130+
print_G(ot.bregman.sinkhorn(ws2, [], otda.log_['M'][1], reg=1e-2), xs2, ys2, xt)
131131
pl.scatter(xs1[:, 0], xs1[:, 1], c=ys1, s=35, marker='x', cmap='Set1', vmax=9)
132132
pl.scatter(xs2[:, 0], xs2[:, 1], c=ys2, s=35, marker='+', cmap='Set1', vmax=9)
133133
pl.scatter(xt[:, 0], xt[:, 1], c=yt, s=35, marker='o', cmap='Set1', vmax=9)
@@ -146,16 +146,16 @@ def print_G(G, xs, ys, xt):
146146
# ----------------------------------------------------------------------------
147147
h_res = np.array([1 - pt, pt])
148148

149-
ws1 = h_res.dot(otda.log_['all_domains'][0]['D2'])
150-
ws2 = h_res.dot(otda.log_['all_domains'][1]['D2'])
149+
ws1 = h_res.dot(otda.log_['D2'][0])
150+
ws2 = h_res.dot(otda.log_['D2'][1])
151151

152152
pl.figure(4)
153153
pl.clf()
154154
plot_ax(dec1, 'Source 1')
155155
plot_ax(dec2, 'Source 2')
156156
plot_ax(dect, 'Target')
157-
print_G(ot.bregman.sinkhorn(ws1, [], otda.log_['all_domains'][0]['M'], reg=1e-2), xs1, ys1, xt)
158-
print_G(ot.bregman.sinkhorn(ws2, [], otda.log_['all_domains'][1]['M'], reg=1e-2), xs2, ys2, xt)
157+
print_G(ot.bregman.sinkhorn(ws1, [], otda.log_['M'][0], reg=1e-2), xs1, ys1, xt)
158+
print_G(ot.bregman.sinkhorn(ws2, [], otda.log_['M'][1], reg=1e-2), xs2, ys2, xt)
159159
pl.scatter(xs1[:, 0], xs1[:, 1], c=ys1, s=35, marker='x', cmap='Set1', vmax=9)
160160
pl.scatter(xs2[:, 0], xs2[:, 1], c=ys2, s=35, marker='+', cmap='Set1', vmax=9)
161161
pl.scatter(xt[:, 0], xt[:, 1], c=yt, s=35, marker='o', cmap='Set1', vmax=9)

ot/bregman.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1608,7 +1608,7 @@ def jcpot_barycenter(Xs, Ys, Xt, reg, metric='sqeuclidean', numItermax=100,
16081608
# build the cost matrix and the Gibbs kernel
16091609
Mtmp = dist(Xs[d], Xt, metric=metric)
16101610
Mtmp = Mtmp / np.median(Mtmp)
1611-
M.append(M)
1611+
M.append(Mtmp)
16121612

16131613
Ktmp = np.empty(Mtmp.shape, dtype=Mtmp.dtype)
16141614
np.divide(Mtmp, -reg, out=Ktmp)

test/test_da.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -589,7 +589,7 @@ def test_jcpot_transport_class():
589589
# test margin constraints w.r.t. modified source weights for each source domain
590590

591591
assert_allclose(
592-
np.dot(otda.log_['all_domains'][i]['D1'], np.sum(otda.coupling_[i], axis=1)), otda.proportions_, rtol=1e-3,
592+
np.dot(otda.log_['D1'][i], np.sum(otda.coupling_[i], axis=1)), otda.proportions_, rtol=1e-3,
593593
atol=1e-3)
594594

595595
# test transform

0 commit comments

Comments
 (0)