Skip to content

Commit 92233f7

Browse files
committed
add assert for emd dimension mismatch
1 parent 0280a34 commit 92233f7

File tree

2 files changed

+22
-0
lines changed

2 files changed

+22
-0
lines changed

ot/lp/__init__.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -109,6 +109,9 @@ def emd(a, b, M, numItermax=100000, log=False):
109109
if len(b) == 0:
110110
b = np.ones((M.shape[1],), dtype=np.float64) / M.shape[1]
111111

112+
assert (a.shape[0] == M.shape[0] or b.shape[0] == M.shape[1]), \
113+
"Dimension mismatch, check dimensions of M with a and b"
114+
112115
G, cost, u, v, result_code = emd_c(a, b, M, numItermax)
113116
result_code_string = check_result(result_code)
114117
if log:
@@ -212,6 +215,9 @@ def emd2(a, b, M, processes=multiprocessing.cpu_count(),
212215
if len(b) == 0:
213216
b = np.ones((M.shape[1],), dtype=np.float64) / M.shape[1]
214217

218+
assert (a.shape[0] == M.shape[0] or b.shape[0] == M.shape[1]), \
219+
"Dimension mismatch, check dimensions of M with a and b"
220+
215221
if log or return_matrix:
216222
def f(b):
217223
G, cost, u, v, resultCode = emd_c(a, b, M, numItermax)

test/test_ot.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,22 @@
1414
import pytest
1515

1616

17+
def test_emd_dimension_mismatch():
18+
# test emd and emd2 for simple identity
19+
n_samples = 100
20+
n_features = 2
21+
rng = np.random.RandomState(0)
22+
23+
x = rng.randn(n_samples, n_features)
24+
a = ot.utils.unif(n_samples + 1)
25+
26+
M = ot.dist(x, x)
27+
28+
np.testing.assert_raises(AssertionError, emd, a, a, M)
29+
30+
np.testing.assert_raises(AssertionError, emd2, a, a, M)
31+
32+
1733
def test_emd_emd2():
1834
# test emd and emd2 for simple identity
1935
n = 100

0 commit comments

Comments
 (0)