66# import pytest
77
88
9- def test_OTDA ():
9+ def test_otda ():
1010
11- n = 150 # nb bins
11+ n = 150 # nb samples
12+ np .random .seed (0 )
1213
1314 xs , ys = ot .datasets .get_data_classif ('3gauss' , n )
1415 xt , yt = ot .datasets .get_data_classif ('3gauss2' , n )
@@ -21,8 +22,8 @@ def test_OTDA():
2122 da_emd .interp () # interpolation of source samples
2223 da_emd .predict (xs ) # interpolation of source samples
2324
24- assert np .allclose (a , np .sum (da_emd .G , 1 ))
25- assert np .allclose (b , np .sum (da_emd .G , 0 ))
25+ np .testing . assert_allclose (a , np .sum (da_emd .G , 1 ))
26+ np .testing . assert_allclose (b , np .sum (da_emd .G , 0 ))
2627
2728 # sinkhorn regularization
2829 lambd = 1e-1
@@ -31,8 +32,8 @@ def test_OTDA():
3132 da_entrop .interp ()
3233 da_entrop .predict (xs )
3334
34- assert np .allclose (a , np .sum (da_entrop .G , 1 ), rtol = 1e-3 , atol = 1e-3 )
35- assert np .allclose (b , np .sum (da_entrop .G , 0 ), rtol = 1e-3 , atol = 1e-3 )
35+ np .testing . assert_allclose (a , np .sum (da_entrop .G , 1 ), rtol = 1e-3 , atol = 1e-3 )
36+ np .testing . assert_allclose (b , np .sum (da_entrop .G , 0 ), rtol = 1e-3 , atol = 1e-3 )
3637
3738 # non-convex Group lasso regularization
3839 reg = 1e-1
@@ -42,8 +43,8 @@ def test_OTDA():
4243 da_lpl1 .interp ()
4344 da_lpl1 .predict (xs )
4445
45- assert np .allclose (a , np .sum (da_lpl1 .G , 1 ), rtol = 1e-3 , atol = 1e-3 )
46- assert np .allclose (b , np .sum (da_lpl1 .G , 0 ), rtol = 1e-3 , atol = 1e-3 )
46+ np .testing . assert_allclose (a , np .sum (da_lpl1 .G , 1 ), rtol = 1e-3 , atol = 1e-3 )
47+ np .testing . assert_allclose (b , np .sum (da_lpl1 .G , 0 ), rtol = 1e-3 , atol = 1e-3 )
4748
4849 # True Group lasso regularization
4950 reg = 1e-1
@@ -53,8 +54,8 @@ def test_OTDA():
5354 da_l1l2 .interp ()
5455 da_l1l2 .predict (xs )
5556
56- assert np .allclose (a , np .sum (da_l1l2 .G , 1 ), rtol = 1e-3 , atol = 1e-3 )
57- assert np .allclose (b , np .sum (da_l1l2 .G , 0 ), rtol = 1e-3 , atol = 1e-3 )
57+ np .testing . assert_allclose (a , np .sum (da_l1l2 .G , 1 ), rtol = 1e-3 , atol = 1e-3 )
58+ np .testing . assert_allclose (b , np .sum (da_l1l2 .G , 0 ), rtol = 1e-3 , atol = 1e-3 )
5859
5960 # linear mapping
6061 da_emd = ot .da .OTDA_mapping_linear () # init class
0 commit comments