Skip to content

Commit f8e822c

Browse files
committed
test sinkhorn with empty marginals
1 parent a31d3c2 commit f8e822c

File tree

2 files changed

+53
-1
lines changed

2 files changed

+53
-1
lines changed

test/test_bregman.py

Lines changed: 30 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,33 @@ def test_sinkhorn():
2323
assert np.allclose(u, G.sum(0), atol=1e-05) # cf convergence sinkhorn
2424

2525

26+
def test_sinkhorn_empty():
27+
# test sinkhorn
28+
n = 100
29+
np.random.seed(0)
30+
31+
x = np.random.randn(n, 2)
32+
u = ot.utils.unif(n)
33+
34+
M = ot.dist(x, x)
35+
36+
G = ot.sinkhorn([], [], M, 1, stopThr=1e-10)
37+
# check constratints
38+
assert np.allclose(u, G.sum(1), atol=1e-05) # cf convergence sinkhorn
39+
assert np.allclose(u, G.sum(0), atol=1e-05) # cf convergence sinkhorn
40+
41+
G = ot.sinkhorn([], [], M, 1, stopThr=1e-10, method='sinkhorn_stabilized')
42+
# check constratints
43+
assert np.allclose(u, G.sum(1), atol=1e-05) # cf convergence sinkhorn
44+
assert np.allclose(u, G.sum(0), atol=1e-05) # cf convergence sinkhorn
45+
46+
G = ot.sinkhorn(
47+
[], [], M, 1, stopThr=1e-10, method='sinkhorn_epsilon_scaling')
48+
# check constratints
49+
assert np.allclose(u, G.sum(1), atol=1e-05) # cf convergence sinkhorn
50+
assert np.allclose(u, G.sum(0), atol=1e-05) # cf convergence sinkhorn
51+
52+
2653
def test_sinkhorn_variants():
2754
# test sinkhorn
2855
n = 100
@@ -37,7 +64,9 @@ def test_sinkhorn_variants():
3764
Gs = ot.sinkhorn(u, u, M, 1, method='sinkhorn_stabilized', stopThr=1e-10)
3865
Ges = ot.sinkhorn(
3966
u, u, M, 1, method='sinkhorn_epsilon_scaling', stopThr=1e-10)
67+
Gerr = ot.sinkhorn(u, u, M, 1, method='do_not_exists', stopThr=1e-10)
4068

41-
# check constratints
69+
# check values
4270
assert np.allclose(G0, Gs, atol=1e-05)
4371
assert np.allclose(G0, Ges, atol=1e-05)
72+
assert np.allclose(G0, Gerr)

test/test_ot.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,29 @@ def test_emd_emd2():
4040
assert np.allclose(w, 0)
4141

4242

43+
def test_emd_empty():
44+
# test emd and emd2 for simple identity
45+
n = 100
46+
np.random.seed(0)
47+
48+
x = np.random.randn(n, 2)
49+
u = ot.utils.unif(n)
50+
51+
M = ot.dist(x, x)
52+
53+
G = ot.emd([], [], M)
54+
55+
# check G is identity
56+
assert np.allclose(G, np.eye(n) / n)
57+
# check constratints
58+
assert np.allclose(u, G.sum(1)) # cf convergence sinkhorn
59+
assert np.allclose(u, G.sum(0)) # cf convergence sinkhorn
60+
61+
w = ot.emd2([], [], M)
62+
# check loss=0
63+
assert np.allclose(w, 0)
64+
65+
4366
def test_emd2_multi():
4467

4568
from ot.datasets import get_1D_gauss as gauss

0 commit comments

Comments
 (0)