Skip to content

Commit bc51793

Browse files
author
ievred
committed
added test barycenter + modif target
1 parent 08d0bf9 commit bc51793

File tree

2 files changed

+29
-1
lines changed

2 files changed

+29
-1
lines changed

ot/bregman.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1528,7 +1528,7 @@ def jcpot_barycenter(Xs, Ys, Xt, reg, metric='sqeuclidean', numItermax=100,
15281528
The problem consist in solving a Wasserstein barycenter problem to estimate the proportions :math:`\mathbf{h}` in the target domain.
15291529
15301530
The algorithm used for solving the problem is the Iterative Bregman projections algorithm
1531-
with two sets of marginal constraints related to the unknown vector :math:`\mathbf{h}` and uniform tarhet distribution.
1531+
with two sets of marginal constraints related to the unknown vector :math:`\mathbf{h}` and uniform target distribution.
15321532
15331533
Parameters
15341534
----------

test/test_da.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -601,3 +601,31 @@ def test_jcpot_transport_class():
601601

602602
# check that the oos method is working
603603
assert_equal(transp_Xs_new.shape, Xs_new.shape)
604+
605+
606+
def test_jcpot_barycenter():
607+
"""test_jcpot_barycenter
608+
"""
609+
610+
ns1 = 150
611+
ns2 = 150
612+
nt = 200
613+
614+
sigma = 0.1
615+
np.random.seed(1985)
616+
617+
ps1 = .2
618+
ps2 = .9
619+
pt = .4
620+
621+
Xs1, ys1 = make_data_classif('2gauss_prop', ns1, nz=sigma, p=ps1)
622+
Xs2, ys2 = make_data_classif('2gauss_prop', ns2, nz=sigma, p=ps2)
623+
Xt, yt = make_data_classif('2gauss_prop', nt, nz=sigma, p=pt)
624+
625+
Xs = [Xs1, Xs2]
626+
ys = [ys1, ys2]
627+
628+
_, prop, = ot.bregman.jcpot_barycenter(Xs, ys, Xt, reg=.5, metric='sqeuclidean',
629+
numItermax=10000, stopThr=1e-9, verbose=False, log=False)
630+
631+
np.testing.assert_allclose(prop, [1 - pt, pt], rtol=1e-3, atol=1e-3)

0 commit comments

Comments
 (0)