77import warnings
88
99import numpy as np
10+ import pytest
1011from scipy .stats import wasserstein_distance
1112
1213import ot
1314from ot .datasets import make_1D_gauss as gauss
14- import pytest
1515
1616
1717def test_emd_dimension_mismatch ():
@@ -75,12 +75,12 @@ def test_emd_1d_emd2_1d():
7575 np .testing .assert_allclose (wass , wass1d_emd2 )
7676
7777 # check loss is similar to scipy's implementation for Euclidean metric
78- wass_sp = wasserstein_distance (u .reshape ((- 1 , )), v .reshape ((- 1 , )))
78+ wass_sp = wasserstein_distance (u .reshape ((- 1 ,)), v .reshape ((- 1 ,)))
7979 np .testing .assert_allclose (wass_sp , wass1d_euc )
8080
8181 # check constraints
82- np .testing .assert_allclose (np .ones ((n , )) / n , G .sum (1 ))
83- np .testing .assert_allclose (np .ones ((m , )) / m , G .sum (0 ))
82+ np .testing .assert_allclose (np .ones ((n ,)) / n , G .sum (1 ))
83+ np .testing .assert_allclose (np .ones ((m ,)) / m , G .sum (0 ))
8484
8585 # check G is similar
8686 np .testing .assert_allclose (G , G_1d )
@@ -92,6 +92,42 @@ def test_emd_1d_emd2_1d():
9292 ot .emd_1d (u , v , [], [])
9393
9494
95+ def test_emd_1d_emd2_1d_with_weights ():
96+ # test emd1d gives similar results as emd
97+ n = 20
98+ m = 30
99+ rng = np .random .RandomState (0 )
100+ u = rng .randn (n , 1 )
101+ v = rng .randn (m , 1 )
102+
103+ w_u = rng .uniform (0. , 1. , n )
104+ w_u = w_u / w_u .sum ()
105+
106+ w_v = rng .uniform (0. , 1. , m )
107+ w_v = w_v / w_v .sum ()
108+
109+ M = ot .dist (u , v , metric = 'sqeuclidean' )
110+
111+ G , log = ot .emd (w_u , w_v , M , log = True )
112+ wass = log ["cost" ]
113+ G_1d , log = ot .emd_1d (u , v , w_u , w_v , metric = 'sqeuclidean' , log = True )
114+ wass1d = log ["cost" ]
115+ wass1d_emd2 = ot .emd2_1d (u , v , w_u , w_v , metric = 'sqeuclidean' , log = False )
116+ wass1d_euc = ot .emd2_1d (u , v , w_u , w_v , metric = 'euclidean' , log = False )
117+
118+ # check loss is similar
119+ np .testing .assert_allclose (wass , wass1d )
120+ np .testing .assert_allclose (wass , wass1d_emd2 )
121+
122+ # check loss is similar to scipy's implementation for Euclidean metric
123+ wass_sp = wasserstein_distance (u .reshape ((- 1 ,)), v .reshape ((- 1 ,)), w_u , w_v )
124+ np .testing .assert_allclose (wass_sp , wass1d_euc )
125+
126+ # check constraints
127+ np .testing .assert_allclose (w_u , G .sum (1 ))
128+ np .testing .assert_allclose (w_v , G .sum (0 ))
129+
130+
95131def test_wass_1d ():
96132 # test emd1d gives similar results as emd
97133 n = 20
@@ -135,7 +171,6 @@ def test_emd_empty():
135171
136172
137173def test_emd_sparse ():
138-
139174 n = 100
140175 rng = np .random .RandomState (0 )
141176
@@ -211,7 +246,6 @@ def test_emd2_multi():
211246
212247
213248def test_lp_barycenter ():
214-
215249 a1 = np .array ([1.0 , 0 , 0 ])[:, None ]
216250 a2 = np .array ([0 , 0 , 1.0 ])[:, None ]
217251
@@ -228,7 +262,6 @@ def test_lp_barycenter():
228262
229263
230264def test_free_support_barycenter ():
231-
232265 measures_locations = [np .array ([- 1. ]).reshape ((1 , 1 )), np .array ([1. ]).reshape ((1 , 1 ))]
233266 measures_weights = [np .array ([1. ]), np .array ([1. ])]
234267
@@ -244,7 +277,6 @@ def test_free_support_barycenter():
244277
245278@pytest .mark .skipif (not ot .lp .cvx .cvxopt , reason = "No cvxopt available" )
246279def test_lp_barycenter_cvxopt ():
247-
248280 a1 = np .array ([1.0 , 0 , 0 ])[:, None ]
249281 a2 = np .array ([0 , 0 , 1.0 ])[:, None ]
250282
0 commit comments