Skip to content

Commit 67b011a

Browse files
committed
numpy assert test_da
1 parent 68d7490 commit 67b011a

File tree

1 file changed

+11
-10
lines changed

1 file changed

+11
-10
lines changed

test/test_da.py

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,10 @@
66
# import pytest
77

88

9-
def test_OTDA():
9+
def test_otda():
1010

11-
n = 150 # nb bins
11+
n = 150 # nb samples
12+
np.random.seed(0)
1213

1314
xs, ys = ot.datasets.get_data_classif('3gauss', n)
1415
xt, yt = ot.datasets.get_data_classif('3gauss2', n)
@@ -21,8 +22,8 @@ def test_OTDA():
2122
da_emd.interp() # interpolation of source samples
2223
da_emd.predict(xs) # interpolation of source samples
2324

24-
assert np.allclose(a, np.sum(da_emd.G, 1))
25-
assert np.allclose(b, np.sum(da_emd.G, 0))
25+
np.testing.assert_allclose(a, np.sum(da_emd.G, 1))
26+
np.testing.assert_allclose(b, np.sum(da_emd.G, 0))
2627

2728
# sinkhorn regularization
2829
lambd = 1e-1
@@ -31,8 +32,8 @@ def test_OTDA():
3132
da_entrop.interp()
3233
da_entrop.predict(xs)
3334

34-
assert np.allclose(a, np.sum(da_entrop.G, 1), rtol=1e-3, atol=1e-3)
35-
assert np.allclose(b, np.sum(da_entrop.G, 0), rtol=1e-3, atol=1e-3)
35+
np.testing.assert_allclose(a, np.sum(da_entrop.G, 1), rtol=1e-3, atol=1e-3)
36+
np.testing.assert_allclose(b, np.sum(da_entrop.G, 0), rtol=1e-3, atol=1e-3)
3637

3738
# non-convex Group lasso regularization
3839
reg = 1e-1
@@ -42,8 +43,8 @@ def test_OTDA():
4243
da_lpl1.interp()
4344
da_lpl1.predict(xs)
4445

45-
assert np.allclose(a, np.sum(da_lpl1.G, 1), rtol=1e-3, atol=1e-3)
46-
assert np.allclose(b, np.sum(da_lpl1.G, 0), rtol=1e-3, atol=1e-3)
46+
np.testing.assert_allclose(a, np.sum(da_lpl1.G, 1), rtol=1e-3, atol=1e-3)
47+
np.testing.assert_allclose(b, np.sum(da_lpl1.G, 0), rtol=1e-3, atol=1e-3)
4748

4849
# True Group lasso regularization
4950
reg = 1e-1
@@ -53,8 +54,8 @@ def test_OTDA():
5354
da_l1l2.interp()
5455
da_l1l2.predict(xs)
5556

56-
assert np.allclose(a, np.sum(da_l1l2.G, 1), rtol=1e-3, atol=1e-3)
57-
assert np.allclose(b, np.sum(da_l1l2.G, 0), rtol=1e-3, atol=1e-3)
57+
np.testing.assert_allclose(a, np.sum(da_l1l2.G, 1), rtol=1e-3, atol=1e-3)
58+
np.testing.assert_allclose(b, np.sum(da_l1l2.G, 0), rtol=1e-3, atol=1e-3)
5859

5960
# linear mapping
6061
da_emd = ot.da.OTDA_mapping_linear() # init class

0 commit comments

Comments
 (0)