@@ -124,27 +124,26 @@ def test_warnings():
124124 # %%
125125
126126 print ('Computing {} EMD ' .format (1 ))
127- G , alpha , beta = ot .emd (a , b , M , dual_variables = True )
128127 with warnings .catch_warnings (record = True ) as w :
129128 # Cause all warnings to always be triggered.
130129 warnings .simplefilter ("always" )
131130 # Trigger a warning.
132131 print ('Computing {} EMD ' .format (1 ))
133- G , alpha , beta = ot .emd (a , b , M , dual_variables = True , numItermax = 1 )
132+ G = ot .emd (a , b , M , numItermax = 1 )
134133 # Verify some things
135134 assert "numItermax" in str (w [- 1 ].message )
136135 assert len (w ) == 1
137136 # Trigger a warning.
138137 a [0 ]= 100
139138 print ('Computing {} EMD ' .format (2 ))
140- G , alpha , beta = ot .emd (a , b , M , dual_variables = True )
139+ G = ot .emd (a , b , M )
141140 # Verify some things
142141 assert "infeasible" in str (w [- 1 ].message )
143142 assert len (w ) == 2
144143 # Trigger a warning.
145144 a [0 ]= - 1
146145 print ('Computing {} EMD ' .format (2 ))
147- G , alpha , beta = ot .emd (a , b , M , dual_variables = True )
146+ G = ot .emd (a , b , M )
148147 # Verify some things
149148 assert "infeasible" in str (w [- 1 ].message )
150149 assert len (w ) == 3
@@ -176,16 +175,11 @@ def test_dual_variables():
176175
177176 # emd loss 1 proc
178177 ot .tic ()
179- G , alpha , beta = ot .emd (a , b , M , dual_variables = True )
178+ G , log = ot .emd (a , b , M , log = True )
180179 ot .toc ('1 proc : {} s' )
181180
182181 cost1 = (G * M ).sum ()
183- cost_dual = np .vdot (a , alpha ) + np .vdot (b , beta )
184-
185- # emd loss 1 proc
186- ot .tic ()
187- cost_emd2 = ot .emd2 (a , b , M )
188- ot .toc ('1 proc : {} s' )
182+ cost_dual = np .vdot (a , log ['u' ]) + np .vdot (b , log ['v' ])
189183
190184 ot .tic ()
191185 G2 = ot .emd (b , a , np .ascontiguousarray (M .T ))
@@ -194,7 +188,7 @@ def test_dual_variables():
194188 cost2 = (G2 * M .T ).sum ()
195189
196190 # Check that both cost computations are equivalent
197- np .testing .assert_almost_equal (cost1 , cost_emd2 )
191+ np .testing .assert_almost_equal (cost1 , log [ 'cost' ] )
198192 # Check that dual and primal cost are equal
199193 np .testing .assert_almost_equal (cost1 , cost_dual )
200194 # Check symmetry
@@ -205,5 +199,5 @@ def test_dual_variables():
205199 [ind1 , ind2 ] = np .nonzero (G )
206200
207201 # Check that reduced cost is zero on transport arcs
208- np .testing .assert_array_almost_equal ((M - alpha .reshape (- 1 , 1 ) - beta .reshape (1 , - 1 ))[ind1 , ind2 ],
202+ np .testing .assert_array_almost_equal ((M - log [ 'u' ] .reshape (- 1 , 1 ) - log [ 'v' ] .reshape (1 , - 1 ))[ind1 , ind2 ],
209203 np .zeros (ind1 .size ))
0 commit comments