Skip to content

Commit 9a9b354

Browse files
committed
correct emd2 and add centering for dual potentials
1 parent e5196fa commit 9a9b354

File tree

1 file changed

+27
-1
lines changed

1 file changed

+27
-1
lines changed

ot/lp/__init__.py

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -264,13 +264,19 @@ def emd(a, b, M, numItermax=100000, log=False, dense=True, center_dual=True):
264264
if dense:
265265
G, cost, u, v, result_code = emd_c(a, b, M, numItermax, dense)
266266

267+
if center_dual:
268+
u, v = center_ot_dual(u, v, a, b)
269+
267270
if np.any(~asel) or np.any(~bsel):
268271
u, v = estimate_dual_null_weights(u, v, a, b, M)
269272

270273
else:
271274
Gv, iG, jG, cost, u, v, result_code = emd_c(a, b, M, numItermax, dense)
272275
G = coo_matrix((Gv, (iG, jG)), shape=(a.shape[0], b.shape[0]))
273276

277+
if center_dual:
278+
u, v = center_ot_dual(u, v, a, b)
279+
274280
if np.any(~asel) or np.any(~bsel):
275281
u, v = estimate_dual_null_weights(u, v, a, b, M)
276282

@@ -287,7 +293,8 @@ def emd(a, b, M, numItermax=100000, log=False, dense=True, center_dual=True):
287293

288294

289295
def emd2(a, b, M, processes=multiprocessing.cpu_count(),
290-
numItermax=100000, log=False, dense=True, return_matrix=False):
296+
numItermax=100000, log=False, dense=True, return_matrix=False,
297+
center_dual=True):
291298
r"""Solves the Earth Movers distance problem and returns the loss
292299
293300
.. math::
@@ -329,6 +336,9 @@ def emd2(a, b, M, processes=multiprocessing.cpu_count(),
329336
If True, returns math:`\gamma` as a dense ndarray of shape (ns, nt).
330337
Otherwise returns a sparse representation using scipy's `coo_matrix`
331338
format.
339+
center_dual: boolean, optional (default=True)
340+
If True, centers the dual potential using function
341+
:ref:`center_ot_dual`.
332342
333343
Returns
334344
-------
@@ -383,14 +393,23 @@ def emd2(a, b, M, processes=multiprocessing.cpu_count(),
383393
assert (a.shape[0] == M.shape[0] and b.shape[0] == M.shape[1]), \
384394
"Dimension mismatch, check dimensions of M with a and b"
385395

396+
asel = a != 0
397+
386398
if log or return_matrix:
387399
def f(b):
400+
bsel = b != 0
388401
if dense:
389402
G, cost, u, v, result_code = emd_c(a, b, M, numItermax, dense)
390403
else:
391404
Gv, iG, jG, cost, u, v, result_code = emd_c(a, b, M, numItermax, dense)
392405
G = coo_matrix((Gv, (iG, jG)), shape=(a.shape[0], b.shape[0]))
393406

407+
if center_dual:
408+
u, v = center_ot_dual(u, v, a, b)
409+
410+
if np.any(~asel) or np.any(~bsel):
411+
u, v = estimate_dual_null_weights(u, v, a, b, M)
412+
394413
result_code_string = check_result(result_code)
395414
log = {}
396415
if return_matrix:
@@ -402,12 +421,19 @@ def f(b):
402421
return [cost, log]
403422
else:
404423
def f(b):
424+
bsel = b != 0
405425
if dense:
406426
G, cost, u, v, result_code = emd_c(a, b, M, numItermax, dense)
407427
else:
408428
Gv, iG, jG, cost, u, v, result_code = emd_c(a, b, M, numItermax, dense)
409429
G = coo_matrix((Gv, (iG, jG)), shape=(a.shape[0], b.shape[0]))
410430

431+
if center_dual:
432+
u, v = center_ot_dual(u, v, a, b)
433+
434+
if np.any(~asel) or np.any(~bsel):
435+
u, v = estimate_dual_null_weights(u, v, a, b, M)
436+
411437
result_code_string = check_result(result_code)
412438
check_result(result_code)
413439
return cost

0 commit comments

Comments
 (0)