@@ -264,13 +264,19 @@ def emd(a, b, M, numItermax=100000, log=False, dense=True, center_dual=True):
264264 if dense :
265265 G , cost , u , v , result_code = emd_c (a , b , M , numItermax , dense )
266266
267+ if center_dual :
268+ u , v = center_ot_dual (u , v , a , b )
269+
267270 if np .any (~ asel ) or np .any (~ bsel ):
268271 u , v = estimate_dual_null_weights (u , v , a , b , M )
269272
270273 else :
271274 Gv , iG , jG , cost , u , v , result_code = emd_c (a , b , M , numItermax , dense )
272275 G = coo_matrix ((Gv , (iG , jG )), shape = (a .shape [0 ], b .shape [0 ]))
273276
277+ if center_dual :
278+ u , v = center_ot_dual (u , v , a , b )
279+
274280 if np .any (~ asel ) or np .any (~ bsel ):
275281 u , v = estimate_dual_null_weights (u , v , a , b , M )
276282
@@ -287,7 +293,8 @@ def emd(a, b, M, numItermax=100000, log=False, dense=True, center_dual=True):
287293
288294
289295def emd2 (a , b , M , processes = multiprocessing .cpu_count (),
290- numItermax = 100000 , log = False , dense = True , return_matrix = False ):
296+ numItermax = 100000 , log = False , dense = True , return_matrix = False ,
297+ center_dual = True ):
291298 r"""Solves the Earth Movers distance problem and returns the loss
292299
293300 .. math::
@@ -329,6 +336,9 @@ def emd2(a, b, M, processes=multiprocessing.cpu_count(),
329336 If True, returns math:`\gamma` as a dense ndarray of shape (ns, nt).
330337 Otherwise returns a sparse representation using scipy's `coo_matrix`
331338 format.
339+ center_dual: boolean, optional (default=True)
340+ If True, centers the dual potential using function
341+ :ref:`center_ot_dual`.
332342
333343 Returns
334344 -------
@@ -383,14 +393,23 @@ def emd2(a, b, M, processes=multiprocessing.cpu_count(),
383393 assert (a .shape [0 ] == M .shape [0 ] and b .shape [0 ] == M .shape [1 ]), \
384394 "Dimension mismatch, check dimensions of M with a and b"
385395
396+ asel = a != 0
397+
386398 if log or return_matrix :
387399 def f (b ):
400+ bsel = b != 0
388401 if dense :
389402 G , cost , u , v , result_code = emd_c (a , b , M , numItermax , dense )
390403 else :
391404 Gv , iG , jG , cost , u , v , result_code = emd_c (a , b , M , numItermax , dense )
392405 G = coo_matrix ((Gv , (iG , jG )), shape = (a .shape [0 ], b .shape [0 ]))
393406
407+ if center_dual :
408+ u , v = center_ot_dual (u , v , a , b )
409+
410+ if np .any (~ asel ) or np .any (~ bsel ):
411+ u , v = estimate_dual_null_weights (u , v , a , b , M )
412+
394413 result_code_string = check_result (result_code )
395414 log = {}
396415 if return_matrix :
@@ -402,12 +421,19 @@ def f(b):
402421 return [cost , log ]
403422 else :
404423 def f (b ):
424+ bsel = b != 0
405425 if dense :
406426 G , cost , u , v , result_code = emd_c (a , b , M , numItermax , dense )
407427 else :
408428 Gv , iG , jG , cost , u , v , result_code = emd_c (a , b , M , numItermax , dense )
409429 G = coo_matrix ((Gv , (iG , jG )), shape = (a .shape [0 ], b .shape [0 ]))
410430
431+ if center_dual :
432+ u , v = center_ot_dual (u , v , a , b )
433+
434+ if np .any (~ asel ) or np .any (~ bsel ):
435+ u , v = estimate_dual_null_weights (u , v , a , b , M )
436+
411437 result_code_string = check_result (result_code )
412438 check_result (result_code )
413439 return cost
0 commit comments