@@ -385,6 +385,20 @@ def test_gromov_barycenter(nx):
385385 np .testing .assert_allclose (Cb , Cbb , atol = 1e-06 )
386386 np .testing .assert_allclose (Cbb .shape , (n_samples , n_samples ))
387387
388+ # test of gromov_barycenters with `log` on
389+ Cb_ , err_ = ot .gromov .gromov_barycenters (
390+ n_samples , [C1 , C2 ], [p1 , p2 ], p , [.5 , .5 ],
391+ 'square_loss' , max_iter = 100 , tol = 1e-3 , verbose = True , random_state = 42 , log = True
392+ )
393+ Cbb_ , errb_ = ot .gromov .gromov_barycenters (
394+ n_samples , [C1b , C2b ], [p1b , p2b ], pb , [.5 , .5 ],
395+ 'square_loss' , max_iter = 100 , tol = 1e-3 , verbose = True , random_state = 42 , log = True
396+ )
397+ Cbb_ = nx .to_numpy (Cbb_ )
398+ np .testing .assert_allclose (Cb_ , Cbb_ , atol = 1e-06 )
399+ np .testing .assert_array_almost_equal (err_ ['err' ], errb_ ['err' ])
400+ np .testing .assert_allclose (Cbb_ .shape , (n_samples , n_samples ))
401+
388402 Cb2 = ot .gromov .gromov_barycenters (
389403 n_samples , [C1 , C2 ], [p1 , p2 ], p , [.5 , .5 ],
390404 'kl_loss' , max_iter = 100 , tol = 1e-3 , random_state = 42
@@ -396,6 +410,20 @@ def test_gromov_barycenter(nx):
396410 np .testing .assert_allclose (Cb2 , Cb2b , atol = 1e-06 )
397411 np .testing .assert_allclose (Cb2b .shape , (n_samples , n_samples ))
398412
413+ # test of gromov_barycenters with `log` on
414+ Cb2_ , err2_ = ot .gromov .gromov_barycenters (
415+ n_samples , [C1 , C2 ], [p1 , p2 ], p , [.5 , .5 ],
416+ 'kl_loss' , max_iter = 100 , tol = 1e-3 , verbose = True , random_state = 42 , log = True
417+ )
418+ Cb2b_ , err2b_ = ot .gromov .gromov_barycenters (
419+ n_samples , [C1b , C2b ], [p1b , p2b ], pb , [.5 , .5 ],
420+ 'kl_loss' , max_iter = 100 , tol = 1e-3 , verbose = True , random_state = 42 , log = True
421+ )
422+ Cb2b_ = nx .to_numpy (Cb2b_ )
423+ np .testing .assert_allclose (Cb2_ , Cb2b_ , atol = 1e-06 )
424+ np .testing .assert_array_almost_equal (err2_ ['err' ], err2_ ['err' ])
425+ np .testing .assert_allclose (Cb2b_ .shape , (n_samples , n_samples ))
426+
399427
400428@pytest .mark .filterwarnings ("ignore:divide" )
401429def test_gromov_entropic_barycenter (nx ):
@@ -429,6 +457,20 @@ def test_gromov_entropic_barycenter(nx):
429457 np .testing .assert_allclose (Cb , Cbb , atol = 1e-06 )
430458 np .testing .assert_allclose (Cbb .shape , (n_samples , n_samples ))
431459
460+ # test of entropic_gromov_barycenters with `log` on
461+ Cb_ , err_ = ot .gromov .entropic_gromov_barycenters (
462+ n_samples , [C1 , C2 ], [p1 , p2 ], p , [.5 , .5 ],
463+ 'square_loss' , 1e-3 , max_iter = 100 , tol = 1e-3 , verbose = True , random_state = 42 , log = True
464+ )
465+ Cbb_ , errb_ = ot .gromov .entropic_gromov_barycenters (
466+ n_samples , [C1b , C2b ], [p1b , p2b ], pb , [.5 , .5 ],
467+ 'square_loss' , 1e-3 , max_iter = 100 , tol = 1e-3 , verbose = True , random_state = 42 , log = True
468+ )
469+ Cbb_ = nx .to_numpy (Cbb_ )
470+ np .testing .assert_allclose (Cb_ , Cbb_ , atol = 1e-06 )
471+ np .testing .assert_array_almost_equal (err_ ['err' ], errb_ ['err' ])
472+ np .testing .assert_allclose (Cbb_ .shape , (n_samples , n_samples ))
473+
432474 Cb2 = ot .gromov .entropic_gromov_barycenters (
433475 n_samples , [C1 , C2 ], [p1 , p2 ], p , [.5 , .5 ],
434476 'kl_loss' , 1e-3 , max_iter = 100 , tol = 1e-3 , random_state = 42
@@ -440,6 +482,20 @@ def test_gromov_entropic_barycenter(nx):
440482 np .testing .assert_allclose (Cb2 , Cb2b , atol = 1e-06 )
441483 np .testing .assert_allclose (Cb2b .shape , (n_samples , n_samples ))
442484
485+ # test of entropic_gromov_barycenters with `log` on
486+ Cb2_ , err2_ = ot .gromov .entropic_gromov_barycenters (
487+ n_samples , [C1 , C2 ], [p1 , p2 ], p , [.5 , .5 ],
488+ 'kl_loss' , 1e-3 , max_iter = 100 , tol = 1e-3 , verbose = True , random_state = 42 , log = True
489+ )
490+ Cb2b_ , err2b_ = ot .gromov .entropic_gromov_barycenters (
491+ n_samples , [C1b , C2b ], [p1b , p2b ], pb , [.5 , .5 ],
492+ 'kl_loss' , 1e-3 , max_iter = 100 , tol = 1e-3 , verbose = True , random_state = 42 , log = True
493+ )
494+ Cb2b_ = nx .to_numpy (Cb2b_ )
495+ np .testing .assert_allclose (Cb2_ , Cb2b_ , atol = 1e-06 )
496+ np .testing .assert_array_almost_equal (err2_ ['err' ], err2_ ['err' ])
497+ np .testing .assert_allclose (Cb2b_ .shape , (n_samples , n_samples ))
498+
443499
444500def test_fgw (nx ):
445501 n_samples = 50 # nb samples
0 commit comments