@@ -36,6 +36,18 @@ def test_gromov():
3636 np .testing .assert_allclose (
3737 q , G .sum (0 ), atol = 1e-04 ) # cf convergence gromov
3838
39+ gw , log = ot .gromov .gromov_wasserstein2 (C1 , C2 , p , q , 'kl_loss' , log = True )
40+
41+ G = log ['T' ]
42+
43+ np .testing .assert_allclose (gw , 0 , atol = 1e-04 , rtol = 1e-4 )
44+
45+ # check constratints
46+ np .testing .assert_allclose (
47+ p , G .sum (1 ), atol = 1e-04 ) # cf convergence gromov
48+ np .testing .assert_allclose (
49+ q , G .sum (0 ), atol = 1e-04 ) # cf convergence gromov
50+
3951
4052def test_entropic_gromov ():
4153 n_samples = 50 # nb samples
@@ -64,3 +76,16 @@ def test_entropic_gromov():
6476 p , G .sum (1 ), atol = 1e-04 ) # cf convergence gromov
6577 np .testing .assert_allclose (
6678 q , G .sum (0 ), atol = 1e-04 ) # cf convergence gromov
79+
80+ gw , log = ot .gromov .entropic_gromov_wasserstein2 (
81+ C1 , C2 , p , q , 'kl_loss' , epsilon = 1e-2 , log = True )
82+
83+ G = log ['T' ]
84+
85+ np .testing .assert_allclose (gw , 0 , atol = 1e-1 , rtol = 1e1 )
86+
87+ # check constratints
88+ np .testing .assert_allclose (
89+ p , G .sum (1 ), atol = 1e-04 ) # cf convergence gromov
90+ np .testing .assert_allclose (
91+ q , G .sum (0 ), atol = 1e-04 ) # cf convergence gromov
0 commit comments