@@ -40,7 +40,7 @@ def test_gromov():
4040
4141 G = log ['T' ]
4242
43- np .testing .assert_allclose (gw , 0 , atol = 1e-2 , rtol = 1e-2 )
43+ np .testing .assert_allclose (gw , 0 , atol = 1e-1 , rtol = 1e-1 )
4444
4545 # check constratints
4646 np .testing .assert_allclose (
@@ -82,10 +82,37 @@ def test_entropic_gromov():
8282
8383 G = log ['T' ]
8484
85- np .testing .assert_allclose (gw , 0 , atol = 1e-1 , rtol = 1e1 )
85+ np .testing .assert_allclose (gw , 0 , atol = 1e-1 , rtol = 1e-1 )
8686
8787 # check constratints
8888 np .testing .assert_allclose (
8989 p , G .sum (1 ), atol = 1e-04 ) # cf convergence gromov
9090 np .testing .assert_allclose (
9191 q , G .sum (0 ), atol = 1e-04 ) # cf convergence gromov
92+
93+
94+ def test_gromov_barycenter ():
95+
96+ ns = 50
97+ nt = 60
98+
99+ Xs , ys = ot .datasets .get_data_classif ('3gauss' , ns )
100+ Xt , yt = ot .datasets .get_data_classif ('3gauss2' , nt )
101+
102+ C1 = ot .dist (Xs )
103+ C2 = ot .dist (Xt )
104+
105+ n_samples = 3
106+ Cb = ot .gromov .gromov_barycenters (n_samples , [C1 , C2 ],
107+ [ot .unif (ns ), ot .unif (nt )
108+ ], ot .unif (n_samples ), [.5 , .5 ],
109+ 'square_loss' , # 5e-4,
110+ max_iter = 100 , tol = 1e-3 )
111+ np .testing .assert_allclose (Cb .shape , (n_samples , n_samples ))
112+
113+ Cb2 = ot .gromov .gromov_barycenters (n_samples , [C1 , C2 ],
114+ [ot .unif (ns ), ot .unif (nt )
115+ ], ot .unif (n_samples ), [.5 , .5 ],
116+ 'kl_loss' , # 5e-4,
117+ max_iter = 100 , tol = 1e-3 )
118+ np .testing .assert_allclose (Cb2 .shape , (n_samples , n_samples ))
0 commit comments