Skip to content

Commit 55aaf78

Browse files
committed
add test gromov + debug sklearn Basestimator
1 parent 927395b commit 55aaf78

File tree

2 files changed

+27
-2
lines changed

2 files changed

+27
-2
lines changed

test/test_da.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -326,8 +326,8 @@ def test_mapping_transport_class():
326326
"""test_mapping_transport
327327
"""
328328

329-
ns = 150
330-
nt = 200
329+
ns = 60
330+
nt = 120
331331

332332
Xs, ys = get_data_classif('3gauss', ns)
333333
Xt, yt = get_data_classif('3gauss2', nt)

test/test_gromov.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,18 @@ def test_gromov():
3636
np.testing.assert_allclose(
3737
q, G.sum(0), atol=1e-04) # cf convergence gromov
3838

39+
gw, log = ot.gromov.gromov_wasserstein2(C1, C2, p, q, 'kl_loss', log=True)
40+
41+
G = log['T']
42+
43+
np.testing.assert_allclose(gw, 0, atol=1e-04, rtol=1e-4)
44+
45+
# check constratints
46+
np.testing.assert_allclose(
47+
p, G.sum(1), atol=1e-04) # cf convergence gromov
48+
np.testing.assert_allclose(
49+
q, G.sum(0), atol=1e-04) # cf convergence gromov
50+
3951

4052
def test_entropic_gromov():
4153
n_samples = 50 # nb samples
@@ -64,3 +76,16 @@ def test_entropic_gromov():
6476
p, G.sum(1), atol=1e-04) # cf convergence gromov
6577
np.testing.assert_allclose(
6678
q, G.sum(0), atol=1e-04) # cf convergence gromov
79+
80+
gw, log = ot.gromov.entropic_gromov_wasserstein2(
81+
C1, C2, p, q, 'kl_loss', epsilon=1e-2, log=True)
82+
83+
G = log['T']
84+
85+
np.testing.assert_allclose(gw, 0, atol=1e-1, rtol=1e1)
86+
87+
# check constratints
88+
np.testing.assert_allclose(
89+
p, G.sum(1), atol=1e-04) # cf convergence gromov
90+
np.testing.assert_allclose(
91+
q, G.sum(0), atol=1e-04) # cf convergence gromov

0 commit comments

Comments
 (0)