Skip to content

Commit 1e2e118

Browse files
Fix test
1 parent 592f933 commit 1e2e118

File tree

1 file changed

+6
-12
lines changed

1 file changed

+6
-12
lines changed

test/test_ot.py

Lines changed: 6 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -7,11 +7,11 @@
77
import warnings
88

99
import numpy as np
10+
import pytest
1011
from scipy.stats import wasserstein_distance
1112

1213
import ot
1314
from ot.datasets import make_1D_gauss as gauss
14-
import pytest
1515

1616

1717
def test_emd_dimension_mismatch():
@@ -75,12 +75,12 @@ def test_emd_1d_emd2_1d():
7575
np.testing.assert_allclose(wass, wass1d_emd2)
7676

7777
# check loss is similar to scipy's implementation for Euclidean metric
78-
wass_sp = wasserstein_distance(u.reshape((-1, )), v.reshape((-1, )))
78+
wass_sp = wasserstein_distance(u.reshape((-1,)), v.reshape((-1,)))
7979
np.testing.assert_allclose(wass_sp, wass1d_euc)
8080

8181
# check constraints
82-
np.testing.assert_allclose(np.ones((n, )) / n, G.sum(1))
83-
np.testing.assert_allclose(np.ones((m, )) / m, G.sum(0))
82+
np.testing.assert_allclose(np.ones((n,)) / n, G.sum(1))
83+
np.testing.assert_allclose(np.ones((m,)) / m, G.sum(0))
8484

8585
# check G is similar
8686
np.testing.assert_allclose(G, G_1d)
@@ -91,8 +91,8 @@ def test_emd_1d_emd2_1d():
9191
with pytest.raises(AssertionError):
9292
ot.emd_1d(u, v, [], [])
9393

94-
def test_emd_1d_emd2_1d_with_weights():
9594

95+
def test_emd_1d_emd2_1d_with_weights():
9696
# test emd1d gives similar results as emd
9797
n = 20
9898
m = 30
@@ -120,16 +120,14 @@ def test_emd_1d_emd2_1d_with_weights():
120120
np.testing.assert_allclose(wass, wass1d_emd2)
121121

122122
# check loss is similar to scipy's implementation for Euclidean metric
123-
wass_sp = wasserstein_distance(u.reshape((-1,)), v.reshape((-1,)))
123+
wass_sp = wasserstein_distance(u.reshape((-1,)), v.reshape((-1,)), w_u, w_v)
124124
np.testing.assert_allclose(wass_sp, wass1d_euc)
125125

126126
# check constraints
127127
np.testing.assert_allclose(w_u, G.sum(1))
128128
np.testing.assert_allclose(w_v, G.sum(0))
129129

130130

131-
132-
133131
def test_wass_1d():
134132
# test emd1d gives similar results as emd
135133
n = 20
@@ -173,7 +171,6 @@ def test_emd_empty():
173171

174172

175173
def test_emd_sparse():
176-
177174
n = 100
178175
rng = np.random.RandomState(0)
179176

@@ -249,7 +246,6 @@ def test_emd2_multi():
249246

250247

251248
def test_lp_barycenter():
252-
253249
a1 = np.array([1.0, 0, 0])[:, None]
254250
a2 = np.array([0, 0, 1.0])[:, None]
255251

@@ -266,7 +262,6 @@ def test_lp_barycenter():
266262

267263

268264
def test_free_support_barycenter():
269-
270265
measures_locations = [np.array([-1.]).reshape((1, 1)), np.array([1.]).reshape((1, 1))]
271266
measures_weights = [np.array([1.]), np.array([1.])]
272267

@@ -282,7 +277,6 @@ def test_free_support_barycenter():
282277

283278
@pytest.mark.skipif(not ot.lp.cvx.cvxopt, reason="No cvxopt available")
284279
def test_lp_barycenter_cvxopt():
285-
286280
a1 = np.array([1.0, 0, 0])[:, None]
287281
a2 = np.array([0, 0, 1.0])[:, None]
288282

0 commit comments

Comments
 (0)