Skip to content

Commit 81118f2

Browse files
committed
test_ot random state
1 parent 0e06129 commit 81118f2

File tree

1 file changed

+13
-14
lines changed

1 file changed

+13
-14
lines changed

test/test_ot.py

Lines changed: 13 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -18,55 +18,54 @@ def test_doctest():
1818
def test_emd_emd2():
1919
# test emd and emd2 for simple identity
2020
n = 100
21-
np.random.seed(0)
21+
rng = np.random.RandomState(0)
2222

23-
x = np.random.randn(n, 2)
23+
x = rng.randn(n, 2)
2424
u = ot.utils.unif(n)
2525

2626
M = ot.dist(x, x)
2727

2828
G = ot.emd(u, u, M)
2929

3030
# check G is identity
31-
assert np.allclose(G, np.eye(n) / n)
31+
np.testing.assert_allclose(G, np.eye(n) / n)
3232
# check constratints
33-
assert np.allclose(u, G.sum(1)) # cf convergence sinkhorn
34-
assert np.allclose(u, G.sum(0)) # cf convergence sinkhorn
33+
np.testing.assert_allclose(u, G.sum(1)) # cf convergence sinkhorn
34+
np.testing.assert_allclose(u, G.sum(0)) # cf convergence sinkhorn
3535

3636
w = ot.emd2(u, u, M)
3737
# check loss=0
38-
assert np.allclose(w, 0)
38+
np.testing.assert_allclose(w, 0)
3939

4040

4141
def test_emd_empty():
4242
# test emd and emd2 for simple identity
4343
n = 100
44-
np.random.seed(0)
44+
rng = np.random.RandomState(0)
4545

46-
x = np.random.randn(n, 2)
46+
x = rng.randn(n, 2)
4747
u = ot.utils.unif(n)
4848

4949
M = ot.dist(x, x)
5050

5151
G = ot.emd([], [], M)
5252

5353
# check G is identity
54-
assert np.allclose(G, np.eye(n) / n)
54+
np.testing.assert_allclose(G, np.eye(n) / n)
5555
# check constratints
56-
assert np.allclose(u, G.sum(1)) # cf convergence sinkhorn
57-
assert np.allclose(u, G.sum(0)) # cf convergence sinkhorn
56+
np.testing.assert_allclose(u, G.sum(1)) # cf convergence sinkhorn
57+
np.testing.assert_allclose(u, G.sum(0)) # cf convergence sinkhorn
5858

5959
w = ot.emd2([], [], M)
6060
# check loss=0
61-
assert np.allclose(w, 0)
61+
np.testing.assert_allclose(w, 0)
6262

6363

6464
def test_emd2_multi():
6565

6666
from ot.datasets import get_1D_gauss as gauss
6767

6868
n = 1000 # nb bins
69-
np.random.seed(0)
7069

7170
# bin positions
7271
x = np.arange(n, dtype=np.float64)
@@ -96,4 +95,4 @@ def test_emd2_multi():
9695
emdn = ot.emd2(a, b, M)
9796
ot.toc('multi proc : {} s')
9897

99-
assert np.allclose(emd1, emdn)
98+
np.testing.assert_allclose(emd1, emdn)

0 commit comments

Comments
 (0)