@@ -16,14 +16,14 @@ def test_gpu_sinkhorn():
1616
1717 np .random .seed (0 )
1818
19- def describeRes (r ):
19+ def describe_res (r ):
2020 print ("min:{:.3E}, max::{:.3E}, mean::{:.3E}, std::{:.3E}" .format (
2121 np .min (r ), np .max (r ), np .mean (r ), np .std (r )))
2222
23- for n in [50 , 100 , 500 , 1000 ]:
24- print (n )
25- a = np .random .rand (n // 4 , 100 )
26- b = np .random .rand (n , 100 )
23+ for n_samples in [50 , 100 , 500 , 1000 ]:
24+ print (n_samples )
25+ a = np .random .rand (n_samples // 4 , 100 )
26+ b = np .random .rand (n_samples , 100 )
2727 time1 = time .time ()
2828 transport = ot .da .OTDA_sinkhorn ()
2929 transport .fit (a , b )
@@ -34,26 +34,26 @@ def describeRes(r):
3434 G2 = transport .G
3535 time3 = time .time ()
3636 print ("Normal sinkhorn, time: {:6.2f} sec " .format (time2 - time1 ))
37- describeRes (G1 )
37+ describe_res (G1 )
3838 print (" GPU sinkhorn, time: {:6.2f} sec " .format (time3 - time2 ))
39- describeRes (G2 )
39+ describe_res (G2 )
4040
41- assert np .allclose (G1 , G2 , rtol = 1e-5 , atol = 1e-5 )
41+ np .testing . assert_allclose (G1 , G2 , rtol = 1e-5 , atol = 1e-5 )
4242
4343
4444@pytest .mark .skipif (nogpu , reason = "No GPU available" )
4545def test_gpu_sinkhorn_lpl1 ():
4646 np .random .seed (0 )
4747
48- def describeRes (r ):
48+ def describe_res (r ):
4949 print ("min:{:.3E}, max:{:.3E}, mean:{:.3E}, std:{:.3E}"
5050 .format (np .min (r ), np .max (r ), np .mean (r ), np .std (r )))
5151
52- for n in [50 , 100 , 500 ]:
53- print (n )
54- a = np .random .rand (n // 4 , 100 )
55- labels_a = np .random .randint (10 , size = (n // 4 ))
56- b = np .random .rand (n , 100 )
52+ for n_samples in [50 , 100 , 500 ]:
53+ print (n_samples )
54+ a = np .random .rand (n_samples // 4 , 100 )
55+ labels_a = np .random .randint (10 , size = (n_samples // 4 ))
56+ b = np .random .rand (n_samples , 100 )
5757 time1 = time .time ()
5858 transport = ot .da .OTDA_lpl1 ()
5959 transport .fit (a , labels_a , b )
@@ -65,9 +65,9 @@ def describeRes(r):
6565 time3 = time .time ()
6666 print ("Normal sinkhorn lpl1, time: {:6.2f} sec " .format (
6767 time2 - time1 ))
68- describeRes (G1 )
68+ describe_res (G1 )
6969 print (" GPU sinkhorn lpl1, time: {:6.2f} sec " .format (
7070 time3 - time2 ))
71- describeRes (G2 )
71+ describe_res (G2 )
7272
73- assert np .allclose (G1 , G2 , rtol = 1e-5 , atol = 1e-5 )
73+ np .testing . assert_allclose (G1 , G2 , rtol = 1e-5 , atol = 1e-5 )
0 commit comments