77import warnings
88
99import numpy as np
10+ import pytest
1011from scipy .stats import wasserstein_distance
1112
1213import ot
1314from ot .datasets import make_1D_gauss as gauss
14- import pytest
1515
1616
1717def test_emd_dimension_mismatch ():
@@ -75,12 +75,12 @@ def test_emd_1d_emd2_1d():
7575 np .testing .assert_allclose (wass , wass1d_emd2 )
7676
7777 # check loss is similar to scipy's implementation for Euclidean metric
78- wass_sp = wasserstein_distance (u .reshape ((- 1 , )), v .reshape ((- 1 , )))
78+ wass_sp = wasserstein_distance (u .reshape ((- 1 ,)), v .reshape ((- 1 ,)))
7979 np .testing .assert_allclose (wass_sp , wass1d_euc )
8080
8181 # check constraints
82- np .testing .assert_allclose (np .ones ((n , )) / n , G .sum (1 ))
83- np .testing .assert_allclose (np .ones ((m , )) / m , G .sum (0 ))
82+ np .testing .assert_allclose (np .ones ((n ,)) / n , G .sum (1 ))
83+ np .testing .assert_allclose (np .ones ((m ,)) / m , G .sum (0 ))
8484
8585 # check G is similar
8686 np .testing .assert_allclose (G , G_1d )
@@ -91,8 +91,8 @@ def test_emd_1d_emd2_1d():
9191 with pytest .raises (AssertionError ):
9292 ot .emd_1d (u , v , [], [])
9393
94- def test_emd_1d_emd2_1d_with_weights ():
9594
95+ def test_emd_1d_emd2_1d_with_weights ():
9696 # test emd1d gives similar results as emd
9797 n = 20
9898 m = 30
@@ -120,16 +120,14 @@ def test_emd_1d_emd2_1d_with_weights():
120120 np .testing .assert_allclose (wass , wass1d_emd2 )
121121
122122 # check loss is similar to scipy's implementation for Euclidean metric
123- wass_sp = wasserstein_distance (u .reshape ((- 1 ,)), v .reshape ((- 1 ,)))
123+ wass_sp = wasserstein_distance (u .reshape ((- 1 ,)), v .reshape ((- 1 ,)), w_u , w_v )
124124 np .testing .assert_allclose (wass_sp , wass1d_euc )
125125
126126 # check constraints
127127 np .testing .assert_allclose (w_u , G .sum (1 ))
128128 np .testing .assert_allclose (w_v , G .sum (0 ))
129129
130130
131-
132-
133131def test_wass_1d ():
134132 # test emd1d gives similar results as emd
135133 n = 20
@@ -173,7 +171,6 @@ def test_emd_empty():
173171
174172
175173def test_emd_sparse ():
176-
177174 n = 100
178175 rng = np .random .RandomState (0 )
179176
@@ -249,7 +246,6 @@ def test_emd2_multi():
249246
250247
251248def test_lp_barycenter ():
252-
253249 a1 = np .array ([1.0 , 0 , 0 ])[:, None ]
254250 a2 = np .array ([0 , 0 , 1.0 ])[:, None ]
255251
@@ -266,7 +262,6 @@ def test_lp_barycenter():
266262
267263
268264def test_free_support_barycenter ():
269-
270265 measures_locations = [np .array ([- 1. ]).reshape ((1 , 1 )), np .array ([1. ]).reshape ((1 , 1 ))]
271266 measures_weights = [np .array ([1. ]), np .array ([1. ])]
272267
@@ -282,7 +277,6 @@ def test_free_support_barycenter():
282277
283278@pytest .mark .skipif (not ot .lp .cvx .cvxopt , reason = "No cvxopt available" )
284279def test_lp_barycenter_cvxopt ():
285-
286280 a1 = np .array ([1.0 , 0 , 0 ])[:, None ]
287281 a2 = np .array ([0 , 0 , 1.0 ])[:, None ]
288282
0 commit comments