@@ -23,6 +23,33 @@ def test_sinkhorn():
2323 assert np .allclose (u , G .sum (0 ), atol = 1e-05 ) # cf convergence sinkhorn
2424
2525
26+ def test_sinkhorn_empty ():
27+ # test sinkhorn
28+ n = 100
29+ np .random .seed (0 )
30+
31+ x = np .random .randn (n , 2 )
32+ u = ot .utils .unif (n )
33+
34+ M = ot .dist (x , x )
35+
36+ G = ot .sinkhorn ([], [], M , 1 , stopThr = 1e-10 )
37+ # check constratints
38+ assert np .allclose (u , G .sum (1 ), atol = 1e-05 ) # cf convergence sinkhorn
39+ assert np .allclose (u , G .sum (0 ), atol = 1e-05 ) # cf convergence sinkhorn
40+
41+ G = ot .sinkhorn ([], [], M , 1 , stopThr = 1e-10 , method = 'sinkhorn_stabilized' )
42+ # check constratints
43+ assert np .allclose (u , G .sum (1 ), atol = 1e-05 ) # cf convergence sinkhorn
44+ assert np .allclose (u , G .sum (0 ), atol = 1e-05 ) # cf convergence sinkhorn
45+
46+ G = ot .sinkhorn (
47+ [], [], M , 1 , stopThr = 1e-10 , method = 'sinkhorn_epsilon_scaling' )
48+ # check constratints
49+ assert np .allclose (u , G .sum (1 ), atol = 1e-05 ) # cf convergence sinkhorn
50+ assert np .allclose (u , G .sum (0 ), atol = 1e-05 ) # cf convergence sinkhorn
51+
52+
2653def test_sinkhorn_variants ():
2754 # test sinkhorn
2855 n = 100
@@ -37,7 +64,9 @@ def test_sinkhorn_variants():
3764 Gs = ot .sinkhorn (u , u , M , 1 , method = 'sinkhorn_stabilized' , stopThr = 1e-10 )
3865 Ges = ot .sinkhorn (
3966 u , u , M , 1 , method = 'sinkhorn_epsilon_scaling' , stopThr = 1e-10 )
67+ Gerr = ot .sinkhorn (u , u , M , 1 , method = 'do_not_exists' , stopThr = 1e-10 )
4068
41- # check constratints
69+ # check values
4270 assert np .allclose (G0 , Gs , atol = 1e-05 )
4371 assert np .allclose (G0 , Ges , atol = 1e-05 )
72+ assert np .allclose (G0 , Gerr )
0 commit comments