Skip to content

Commit 4a45135

Browse files
committed
dr +gpu numpy assert
1 parent 347e628 commit 4a45135

File tree

2 files changed

+19
-19
lines changed

2 files changed

+19
-19
lines changed

test/test_dr.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ def test_fda():
2929

3030
projfda(xs)
3131

32-
assert np.allclose(np.sum(Pfda**2, 0), np.ones(p))
32+
np.testing.assert_allclose(np.sum(Pfda**2, 0), np.ones(p))
3333

3434

3535
@pytest.mark.skipif(nogo, reason="Missing modules (autograd or pymanopt)")
@@ -51,4 +51,4 @@ def test_wda():
5151

5252
projwda(xs)
5353

54-
assert np.allclose(np.sum(Pwda**2, 0), np.ones(p))
54+
np.testing.assert_allclose(np.sum(Pwda**2, 0), np.ones(p))

test/test_gpu.py

Lines changed: 17 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -16,14 +16,14 @@ def test_gpu_sinkhorn():
1616

1717
np.random.seed(0)
1818

19-
def describeRes(r):
19+
def describe_res(r):
2020
print("min:{:.3E}, max::{:.3E}, mean::{:.3E}, std::{:.3E}".format(
2121
np.min(r), np.max(r), np.mean(r), np.std(r)))
2222

23-
for n in [50, 100, 500, 1000]:
24-
print(n)
25-
a = np.random.rand(n // 4, 100)
26-
b = np.random.rand(n, 100)
23+
for n_samples in [50, 100, 500, 1000]:
24+
print(n_samples)
25+
a = np.random.rand(n_samples // 4, 100)
26+
b = np.random.rand(n_samples, 100)
2727
time1 = time.time()
2828
transport = ot.da.OTDA_sinkhorn()
2929
transport.fit(a, b)
@@ -34,26 +34,26 @@ def describeRes(r):
3434
G2 = transport.G
3535
time3 = time.time()
3636
print("Normal sinkhorn, time: {:6.2f} sec ".format(time2 - time1))
37-
describeRes(G1)
37+
describe_res(G1)
3838
print(" GPU sinkhorn, time: {:6.2f} sec ".format(time3 - time2))
39-
describeRes(G2)
39+
describe_res(G2)
4040

41-
assert np.allclose(G1, G2, rtol=1e-5, atol=1e-5)
41+
np.testing.assert_allclose(G1, G2, rtol=1e-5, atol=1e-5)
4242

4343

4444
@pytest.mark.skipif(nogpu, reason="No GPU available")
4545
def test_gpu_sinkhorn_lpl1():
4646
np.random.seed(0)
4747

48-
def describeRes(r):
48+
def describe_res(r):
4949
print("min:{:.3E}, max:{:.3E}, mean:{:.3E}, std:{:.3E}"
5050
.format(np.min(r), np.max(r), np.mean(r), np.std(r)))
5151

52-
for n in [50, 100, 500]:
53-
print(n)
54-
a = np.random.rand(n // 4, 100)
55-
labels_a = np.random.randint(10, size=(n // 4))
56-
b = np.random.rand(n, 100)
52+
for n_samples in [50, 100, 500]:
53+
print(n_samples)
54+
a = np.random.rand(n_samples // 4, 100)
55+
labels_a = np.random.randint(10, size=(n_samples // 4))
56+
b = np.random.rand(n_samples, 100)
5757
time1 = time.time()
5858
transport = ot.da.OTDA_lpl1()
5959
transport.fit(a, labels_a, b)
@@ -65,9 +65,9 @@ def describeRes(r):
6565
time3 = time.time()
6666
print("Normal sinkhorn lpl1, time: {:6.2f} sec ".format(
6767
time2 - time1))
68-
describeRes(G1)
68+
describe_res(G1)
6969
print(" GPU sinkhorn lpl1, time: {:6.2f} sec ".format(
7070
time3 - time2))
71-
describeRes(G2)
71+
describe_res(G2)
7272

73-
assert np.allclose(G1, G2, rtol=1e-5, atol=1e-5)
73+
np.testing.assert_allclose(G1, G2, rtol=1e-5, atol=1e-5)

0 commit comments

Comments
 (0)