@@ -91,6 +91,44 @@ def test_emd_1d_emd2_1d():
9191 with pytest .raises (AssertionError ):
9292 ot .emd_1d (u , v , [], [])
9393
94+ def test_emd_1d_emd2_1d_with_weights ():
95+
96+ # test emd1d gives similar results as emd
97+ n = 20
98+ m = 30
99+ rng = np .random .RandomState (0 )
100+ u = rng .randn (n , 1 )
101+ v = rng .randn (m , 1 )
102+
103+ w_u = rng .uniform (0. , 1. , n )
104+ w_u = w_u / w_u .sum ()
105+
106+ w_v = rng .uniform (0. , 1. , m )
107+ w_v = w_v / w_v .sum ()
108+
109+ M = ot .dist (u , v , metric = 'sqeuclidean' )
110+
111+ G , log = ot .emd (w_u , w_v , M , log = True )
112+ wass = log ["cost" ]
113+ G_1d , log = ot .emd_1d (u , v , w_u , w_v , metric = 'sqeuclidean' , log = True )
114+ wass1d = log ["cost" ]
115+ wass1d_emd2 = ot .emd2_1d (u , v , w_u , w_v , metric = 'sqeuclidean' , log = False )
116+ wass1d_euc = ot .emd2_1d (u , v , w_u , w_v , metric = 'euclidean' , log = False )
117+
118+ # check loss is similar
119+ np .testing .assert_allclose (wass , wass1d )
120+ np .testing .assert_allclose (wass , wass1d_emd2 )
121+
122+ # check loss is similar to scipy's implementation for Euclidean metric
123+ wass_sp = wasserstein_distance (u .reshape ((- 1 ,)), v .reshape ((- 1 ,)))
124+ np .testing .assert_allclose (wass_sp , wass1d_euc )
125+
126+ # check constraints
127+ np .testing .assert_allclose (w_u , G .sum (1 ))
128+ np .testing .assert_allclose (w_v , G .sum (0 ))
129+
130+
131+
94132
95133def test_wass_1d ():
96134 # test emd1d gives similar results as emd
0 commit comments