@@ -18,55 +18,54 @@ def test_doctest():
1818def test_emd_emd2 ():
1919 # test emd and emd2 for simple identity
2020 n = 100
21- np .random .seed (0 )
21+ rng = np .random .RandomState (0 )
2222
23- x = np . random .randn (n , 2 )
23+ x = rng .randn (n , 2 )
2424 u = ot .utils .unif (n )
2525
2626 M = ot .dist (x , x )
2727
2828 G = ot .emd (u , u , M )
2929
3030 # check G is identity
31- assert np .allclose (G , np .eye (n ) / n )
31+ np .testing . assert_allclose (G , np .eye (n ) / n )
3232 # check constratints
33- assert np .allclose (u , G .sum (1 )) # cf convergence sinkhorn
34- assert np .allclose (u , G .sum (0 )) # cf convergence sinkhorn
33+ np .testing . assert_allclose (u , G .sum (1 )) # cf convergence sinkhorn
34+ np .testing . assert_allclose (u , G .sum (0 )) # cf convergence sinkhorn
3535
3636 w = ot .emd2 (u , u , M )
3737 # check loss=0
38- assert np .allclose (w , 0 )
38+ np .testing . assert_allclose (w , 0 )
3939
4040
4141def test_emd_empty ():
4242 # test emd and emd2 for simple identity
4343 n = 100
44- np .random .seed (0 )
44+ rng = np .random .RandomState (0 )
4545
46- x = np . random .randn (n , 2 )
46+ x = rng .randn (n , 2 )
4747 u = ot .utils .unif (n )
4848
4949 M = ot .dist (x , x )
5050
5151 G = ot .emd ([], [], M )
5252
5353 # check G is identity
54- assert np .allclose (G , np .eye (n ) / n )
54+ np .testing . assert_allclose (G , np .eye (n ) / n )
5555 # check constratints
56- assert np .allclose (u , G .sum (1 )) # cf convergence sinkhorn
57- assert np .allclose (u , G .sum (0 )) # cf convergence sinkhorn
56+ np .testing . assert_allclose (u , G .sum (1 )) # cf convergence sinkhorn
57+ np .testing . assert_allclose (u , G .sum (0 )) # cf convergence sinkhorn
5858
5959 w = ot .emd2 ([], [], M )
6060 # check loss=0
61- assert np .allclose (w , 0 )
61+ np .testing . assert_allclose (w , 0 )
6262
6363
6464def test_emd2_multi ():
6565
6666 from ot .datasets import get_1D_gauss as gauss
6767
6868 n = 1000 # nb bins
69- np .random .seed (0 )
7069
7170 # bin positions
7271 x = np .arange (n , dtype = np .float64 )
@@ -96,4 +95,4 @@ def test_emd2_multi():
9695 emdn = ot .emd2 (a , b , M )
9796 ot .toc ('multi proc : {} s' )
9897
99- assert np .allclose (emd1 , emdn )
98+ np .testing . assert_allclose (emd1 , emdn )
0 commit comments