44#
55# License: MIT License
66
7+ import warnings
8+
79import numpy as np
810
911import ot
1012from ot .datasets import get_1D_gauss as gauss
11- import warnings
1213
1314
1415def test_doctest ():
@@ -100,6 +101,21 @@ def test_emd2_multi():
100101
101102 np .testing .assert_allclose (emd1 , emdn )
102103
104+ # emd loss multipro proc with log
105+ ot .tic ()
106+ emdn = ot .emd2 (a , b , M , log = True )
107+ ot .toc ('multi proc : {} s' )
108+
109+ for i in range (len (emdn )):
110+ emd = emdn [i ]
111+ log = emd [1 ]
112+ cost = emd [0 ]
113+ check_duality_gap (a , b [:, i ], M , log ['G' ], log ['u' ], log ['v' ], cost )
114+ emdn [i ] = cost
115+
116+ emdn = np .array (emdn )
117+ np .testing .assert_allclose (emd1 , emdn )
118+
103119
104120def test_warnings ():
105121 n = 100 # nb bins
@@ -119,32 +135,22 @@ def test_warnings():
119135
120136 # loss matrix
121137 M = ot .dist (x .reshape ((- 1 , 1 )), y .reshape ((- 1 , 1 ))) ** (1. / 2 )
122- # M/=M.max()
123-
124- # %%
125138
126139 print ('Computing {} EMD ' .format (1 ))
127140 with warnings .catch_warnings (record = True ) as w :
128- # Cause all warnings to always be triggered.
129141 warnings .simplefilter ("always" )
130- # Trigger a warning.
131142 print ('Computing {} EMD ' .format (1 ))
132143 G = ot .emd (a , b , M , numItermax = 1 )
133- # Verify some things
134144 assert "numItermax" in str (w [- 1 ].message )
135145 assert len (w ) == 1
136- # Trigger a warning.
137- a [0 ]= 100
146+ a [0 ] = 100
138147 print ('Computing {} EMD ' .format (2 ))
139148 G = ot .emd (a , b , M )
140- # Verify some things
141149 assert "infeasible" in str (w [- 1 ].message )
142150 assert len (w ) == 2
143- # Trigger a warning.
144- a [0 ]= - 1
151+ a [0 ] = - 1
145152 print ('Computing {} EMD ' .format (2 ))
146153 G = ot .emd (a , b , M )
147- # Verify some things
148154 assert "infeasible" in str (w [- 1 ].message )
149155 assert len (w ) == 3
150156
@@ -167,9 +173,6 @@ def test_dual_variables():
167173
168174 # loss matrix
169175 M = ot .dist (x .reshape ((- 1 , 1 )), y .reshape ((- 1 , 1 ))) ** (1. / 2 )
170- # M/=M.max()
171-
172- # %%
173176
174177 print ('Computing {} EMD ' .format (1 ))
175178
@@ -178,26 +181,28 @@ def test_dual_variables():
178181 G , log = ot .emd (a , b , M , log = True )
179182 ot .toc ('1 proc : {} s' )
180183
181- cost1 = (G * M ).sum ()
182- cost_dual = np .vdot (a , log ['u' ]) + np .vdot (b , log ['v' ])
183-
184184 ot .tic ()
185185 G2 = ot .emd (b , a , np .ascontiguousarray (M .T ))
186186 ot .toc ('1 proc : {} s' )
187187
188- cost2 = (G2 * M .T ).sum ()
188+ cost1 = (G * M ).sum ()
189+ # Check symmetry
190+ np .testing .assert_array_almost_equal (cost1 , (M * G2 .T ).sum ())
191+ # Check with closed-form solution for gaussians
192+ np .testing .assert_almost_equal (cost1 , np .abs (mean1 - mean2 ))
189193
190194 # Check that both cost computations are equivalent
191195 np .testing .assert_almost_equal (cost1 , log ['cost' ])
196+ check_duality_gap (a , b , M , G , log ['u' ], log ['v' ], log ['cost' ])
197+
198+
199+ def check_duality_gap (a , b , M , G , u , v , cost ):
200+ cost_dual = np .vdot (a , u ) + np .vdot (b , v )
192201 # Check that dual and primal cost are equal
193- np .testing .assert_almost_equal (cost1 , cost_dual )
194- # Check symmetry
195- np .testing .assert_almost_equal (cost1 , cost2 )
196- # Check with closed-form solution for gaussians
197- np .testing .assert_almost_equal (cost1 , np .abs (mean1 - mean2 ))
202+ np .testing .assert_almost_equal (cost_dual , cost )
198203
199204 [ind1 , ind2 ] = np .nonzero (G )
200205
201206 # Check that reduced cost is zero on transport arcs
202- np .testing .assert_array_almost_equal ((M - log [ 'u' ] .reshape (- 1 , 1 ) - log [ 'v' ] .reshape (1 , - 1 ))[ind1 , ind2 ],
207+ np .testing .assert_array_almost_equal ((M - u .reshape (- 1 , 1 ) - v .reshape (1 , - 1 ))[ind1 , ind2 ],
203208 np .zeros (ind1 .size ))
0 commit comments