Skip to content

Commit 1b58440

Browse files
authored
Merge pull request #116 from kilianFatras/emd_dimension
[MRG] Add assert for emd dimension mismatch
2 parents c5039bc + 4015474 commit 1b58440

File tree

2 files changed

+22
-0
lines changed

2 files changed

+22
-0
lines changed

ot/lp/__init__.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -114,6 +114,9 @@ def emd(a, b, M, numItermax=100000, log=False, dense=True):
114114
if len(b) == 0:
115115
b = np.ones((M.shape[1],), dtype=np.float64) / M.shape[1]
116116

117+
assert (a.shape[0] == M.shape[0] and b.shape[0] == M.shape[1]), \
118+
"Dimension mismatch, check dimensions of M with a and b"
119+
117120
if dense:
118121
G, cost, u, v, result_code = emd_c(a, b, M, numItermax,dense)
119122
else:
@@ -226,6 +229,9 @@ def emd2(a, b, M, processes=multiprocessing.cpu_count(),
226229
if len(b) == 0:
227230
b = np.ones((M.shape[1],), dtype=np.float64) / M.shape[1]
228231

232+
assert (a.shape[0] == M.shape[0] and b.shape[0] == M.shape[1]), \
233+
"Dimension mismatch, check dimensions of M with a and b"
234+
229235
if log or return_matrix:
230236
def f(b):
231237
if dense:

test/test_ot.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,22 @@
1414
import pytest
1515

1616

17+
def test_emd_dimension_mismatch():
18+
# test emd and emd2 for dimension mismatch
19+
n_samples = 100
20+
n_features = 2
21+
rng = np.random.RandomState(0)
22+
23+
x = rng.randn(n_samples, n_features)
24+
a = ot.utils.unif(n_samples + 1)
25+
26+
M = ot.dist(x, x)
27+
28+
np.testing.assert_raises(AssertionError, ot.emd, a, a, M)
29+
30+
np.testing.assert_raises(AssertionError, ot.emd2, a, a, M)
31+
32+
1733
def test_emd_emd2():
1834
# test emd and emd2 for simple identity
1935
n = 100

0 commit comments

Comments
 (0)