Skip to content

Commit 81e9d42

Browse files
authored
Merge pull request #123 from rflamary/bug_dual
[WIP] Bug on dual potential (constraint violation when some weights are 0)
2 parents da37513 + f65073f commit 81e9d42

File tree

3 files changed

+219
-23
lines changed

3 files changed

+219
-23
lines changed

ot/lp/__init__.py

Lines changed: 207 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -23,11 +23,158 @@
2323
from .cvx import barycenter
2424
from ..utils import dist
2525

26-
__all__=['emd', 'emd2', 'barycenter', 'free_support_barycenter', 'cvx',
27-
'emd_1d', 'emd2_1d', 'wasserstein_1d']
26+
__all__ = ['emd', 'emd2', 'barycenter', 'free_support_barycenter', 'cvx',
27+
'emd_1d', 'emd2_1d', 'wasserstein_1d']
2828

2929

30-
def emd(a, b, M, numItermax=100000, log=False, dense=True):
30+
def center_ot_dual(alpha0, beta0, a=None, b=None):
    r"""Center dual OT potentials w.r.t. their weights

    Dual potentials are only defined up to a constant shift: adding a
    constant to :math:`\alpha` and removing it from :math:`\beta` leaves
    every sum :math:`\alpha_i+\beta_j` unchanged. This function picks one
    specific shift so that repeated calls to the OT solver on slightly
    different problems return comparable (stable) potentials.

    The extra constraint that is enforced, in addition to the OT problem
    constraints, is:

    .. math::
        \alpha^T a= \beta^T b

    Since :math:`\sum_i a_i=\sum_j b_j` this can be achieved, without
    changing the objective value, by adding a constant to
    :math:`\alpha_0` and removing it from :math:`\beta_0`:

    .. math::
        c=\frac{\beta_0^T b-\alpha_0^T a}{1^Tb+1^Ta}

        \alpha=\alpha_0+c

        \beta=\beta_0-c

    Parameters
    ----------
    alpha0 : (ns,) numpy.ndarray, float64
        Source dual potential
    beta0 : (nt,) numpy.ndarray, float64
        Target dual potential
    a : (ns,) numpy.ndarray, float64, optional
        Source histogram (uniform weights if None or empty)
    b : (nt,) numpy.ndarray, float64, optional
        Target histogram (uniform weights if None or empty)

    Returns
    -------
    alpha : (ns,) numpy.ndarray, float64
        Source centered dual potential
    beta : (nt,) numpy.ndarray, float64
        Target centered dual potential

    """
    # accept any array-like so that .dot/.sum/.shape below are available
    alpha0 = np.asarray(alpha0)
    beta0 = np.asarray(beta0)

    # missing or empty weights mean "uniform", the same convention emd uses
    if a is None or np.size(a) == 0:
        a = np.ones(alpha0.shape[0]) / alpha0.shape[0]
    else:
        a = np.asarray(a)
    if b is None or np.size(b) == 0:
        b = np.ones(beta0.shape[0]) / beta0.shape[0]
    else:
        b = np.asarray(b)

    # constant that balances the weighted sums of the two potentials
    c = (b.dot(beta0) - a.dot(alpha0)) / (a.sum() + b.sum())

    # opposite shifts keep alpha_i + beta_j (and thus feasibility) unchanged
    alpha = alpha0 + c
    beta = beta0 - c

    return alpha, beta
89+
90+
91+
def estimate_dual_null_weights(alpha0, beta0, a, b, M):
    r"""Estimate feasible values for 0-weighted dual potentials

    The feasible values are computed efficiently but rather coarsely.

    .. warning::
        This function is necessary because the C++ solver in emd_c
        discards all samples in the distributions with
        zero weights. This means that while the primal variable (transport
        matrix) is exact, the solver only returns feasible dual potentials
        on the samples with weights different from zero.

    First we compute the constraints violations:

    .. math::
        V=\alpha+\beta^T-M

    Next we compute the max amount of violation per row (alpha) and
    columns (beta)

    .. math::
        v^a_i=\max_j V_{i,j}

        v^b_j=\max_i V_{i,j}

    Finally we update the dual potential with 0 weights if a
    constraint is violated

    .. math::
        \alpha_i = \alpha_i -v^a_i \quad \text{ if } a_i=0 \text{ and } v^a_i>0

        \beta_j = \beta_j -v^b_j \quad \text{ if } b_j=0 \text{ and } v^b_j>0

    In the end the dual potentials are centered using function
    :ref:`center_ot_dual`.

    Note that all those updates do not change the objective value of the
    solution but provide dual potentials that do not violate the constraints.

    Parameters
    ----------
    alpha0 : (ns,) numpy.ndarray, float64
        Source dual potential
    beta0 : (nt,) numpy.ndarray, float64
        Target dual potential
    a : (ns,) numpy.ndarray, float64
        Source distribution
    b : (nt,) numpy.ndarray, float64
        Target distribution
    M : (ns,nt) numpy.ndarray, float64
        Loss matrix (c-order array with type float64)

    Returns
    -------
    alpha : (ns,) numpy.ndarray, float64
        Source corrected dual potential
    beta : (nt,) numpy.ndarray, float64
        Target corrected dual potential

    """
    # Coerce array-likes: on a plain list `a == 0` would collapse to a single
    # Python bool instead of an element-wise mask.
    alpha0 = np.asarray(alpha0)
    beta0 = np.asarray(beta0)
    a = np.asarray(a)
    b = np.asarray(b)
    M = np.asarray(M)

    # element-wise masks of the samples with zero weight
    a_null = a == 0
    b_null = b == 0

    # dual constraint violation: positive where alpha_i + beta_j exceeds M_ij
    constraint_violation = alpha0[:, None] + beta0[None, :] - M

    # largest violation per row (source) and per column (target)
    aviol = np.max(constraint_violation, 1)
    bviol = np.max(constraint_violation, 0)

    # Lower only the zero-weight potentials, and only by a positive violation.
    # np.where leaves the other entries untouched even if a violation is
    # infinite (a multiplicative mask would produce 0 * inf = nan there).
    alpha = np.where(a_null, alpha0 - np.maximum(aviol, 0), alpha0)
    beta = np.where(b_null, beta0 - np.maximum(bviol, 0), beta0)

    return center_ot_dual(alpha, beta, a, b)
175+
176+
177+
def emd(a, b, M, numItermax=100000, log=False, dense=True, center_dual=True):
31178
r"""Solves the Earth Movers distance problem and returns the OT matrix
32179
33180
@@ -43,7 +190,7 @@ def emd(a, b, M, numItermax=100000, log=False, dense=True):
43190
- a and b are the sample weights
44191
45192
.. warning::
46-
Note that the M matrix needs to be a C-order numpy.array in float64
193+
Note that the M matrix needs to be a C-order numpy.array in float64
47194
format.
48195
49196
Uses the algorithm proposed in [1]_
@@ -66,6 +213,9 @@ def emd(a, b, M, numItermax=100000, log=False, dense=True):
66213
If True, returns :math:`\gamma` as a dense ndarray of shape (ns, nt).
67214
Otherwise returns a sparse representation using scipy's `coo_matrix`
68215
format.
216+
center_dual: boolean, optional (default=True)
217+
If True, centers the dual potential using function
218+
:ref:`center_ot_dual`.
69219
70220
Returns
71221
-------
@@ -107,7 +257,6 @@ def emd(a, b, M, numItermax=100000, log=False, dense=True):
107257
b = np.asarray(b, dtype=np.float64)
108258
M = np.asarray(M, dtype=np.float64)
109259

110-
111260
# if empty array given then use uniform distributions
112261
if len(a) == 0:
113262
a = np.ones((M.shape[0],), dtype=np.float64) / M.shape[0]
@@ -117,11 +266,27 @@ def emd(a, b, M, numItermax=100000, log=False, dense=True):
117266
assert (a.shape[0] == M.shape[0] and b.shape[0] == M.shape[1]), \
118267
"Dimension mismatch, check dimensions of M with a and b"
119268

269+
asel = a != 0
270+
bsel = b != 0
271+
120272
if dense:
121-
G, cost, u, v, result_code = emd_c(a, b, M, numItermax,dense)
273+
G, cost, u, v, result_code = emd_c(a, b, M, numItermax, dense)
274+
275+
if center_dual:
276+
u, v = center_ot_dual(u, v, a, b)
277+
278+
if np.any(~asel) or np.any(~bsel):
279+
u, v = estimate_dual_null_weights(u, v, a, b, M)
280+
122281
else:
123-
Gv, iG, jG, cost, u, v, result_code = emd_c(a, b, M, numItermax,dense)
124-
G = coo_matrix((Gv, (iG, jG)), shape=(a.shape[0], b.shape[0]))
282+
Gv, iG, jG, cost, u, v, result_code = emd_c(a, b, M, numItermax, dense)
283+
G = coo_matrix((Gv, (iG, jG)), shape=(a.shape[0], b.shape[0]))
284+
285+
if center_dual:
286+
u, v = center_ot_dual(u, v, a, b)
287+
288+
if np.any(~asel) or np.any(~bsel):
289+
u, v = estimate_dual_null_weights(u, v, a, b, M)
125290

126291
result_code_string = check_result(result_code)
127292
if log:
@@ -136,7 +301,8 @@ def emd(a, b, M, numItermax=100000, log=False, dense=True):
136301

137302

138303
def emd2(a, b, M, processes=multiprocessing.cpu_count(),
139-
numItermax=100000, log=False, dense=True, return_matrix=False):
304+
numItermax=100000, log=False, dense=True, return_matrix=False,
305+
center_dual=True):
140306
r"""Solves the Earth Movers distance problem and returns the loss
141307
142308
.. math::
@@ -151,7 +317,7 @@ def emd2(a, b, M, processes=multiprocessing.cpu_count(),
151317
- a and b are the sample weights
152318
153319
.. warning::
154-
Note that the M matrix needs to be a C-order numpy.array in float64
320+
Note that the M matrix needs to be a C-order numpy.array in float64
155321
format.
156322
157323
Uses the algorithm proposed in [1]_
@@ -177,7 +343,10 @@ def emd2(a, b, M, processes=multiprocessing.cpu_count(),
177343
dense: boolean, optional (default=True)
178344
If True, returns :math:`\gamma` as a dense ndarray of shape (ns, nt).
179345
Otherwise returns a sparse representation using scipy's `coo_matrix`
180-
format.
346+
format.
347+
center_dual: boolean, optional (default=True)
348+
If True, centers the dual potential using function
349+
:ref:`center_ot_dual`.
181350
182351
Returns
183352
-------
@@ -221,7 +390,7 @@ def emd2(a, b, M, processes=multiprocessing.cpu_count(),
221390

222391
# problem with pickling forks
223392
if sys.platform.endswith('win32'):
224-
processes=1
393+
processes = 1
225394

226395
# if empty array given then use uniform distributions
227396
if len(a) == 0:
@@ -232,13 +401,22 @@ def emd2(a, b, M, processes=multiprocessing.cpu_count(),
232401
assert (a.shape[0] == M.shape[0] and b.shape[0] == M.shape[1]), \
233402
"Dimension mismatch, check dimensions of M with a and b"
234403

404+
asel = a != 0
405+
235406
if log or return_matrix:
236407
def f(b):
408+
bsel = b != 0
237409
if dense:
238-
G, cost, u, v, result_code = emd_c(a, b, M, numItermax,dense)
410+
G, cost, u, v, result_code = emd_c(a, b, M, numItermax, dense)
239411
else:
240-
Gv, iG, jG, cost, u, v, result_code = emd_c(a, b, M, numItermax,dense)
241-
G = coo_matrix((Gv, (iG, jG)), shape=(a.shape[0], b.shape[0]))
412+
Gv, iG, jG, cost, u, v, result_code = emd_c(a, b, M, numItermax, dense)
413+
G = coo_matrix((Gv, (iG, jG)), shape=(a.shape[0], b.shape[0]))
414+
415+
if center_dual:
416+
u, v = center_ot_dual(u, v, a, b)
417+
418+
if np.any(~asel) or np.any(~bsel):
419+
u, v = estimate_dual_null_weights(u, v, a, b, M)
242420

243421
result_code_string = check_result(result_code)
244422
log = {}
@@ -251,11 +429,18 @@ def f(b):
251429
return [cost, log]
252430
else:
253431
def f(b):
432+
bsel = b != 0
254433
if dense:
255-
G, cost, u, v, result_code = emd_c(a, b, M, numItermax,dense)
434+
G, cost, u, v, result_code = emd_c(a, b, M, numItermax, dense)
256435
else:
257-
Gv, iG, jG, cost, u, v, result_code = emd_c(a, b, M, numItermax,dense)
258-
G = coo_matrix((Gv, (iG, jG)), shape=(a.shape[0], b.shape[0]))
436+
Gv, iG, jG, cost, u, v, result_code = emd_c(a, b, M, numItermax, dense)
437+
G = coo_matrix((Gv, (iG, jG)), shape=(a.shape[0], b.shape[0]))
438+
439+
if center_dual:
440+
u, v = center_ot_dual(u, v, a, b)
441+
442+
if np.any(~asel) or np.any(~bsel):
443+
u, v = estimate_dual_null_weights(u, v, a, b, M)
259444

260445
result_code_string = check_result(result_code)
261446
check_result(result_code)
@@ -265,15 +450,14 @@ def f(b):
265450
return f(b)
266451
nb = b.shape[1]
267452

268-
if processes>1:
453+
if processes > 1:
269454
res = parmap(f, [b[:, i] for i in range(nb)], processes)
270455
else:
271456
res = list(map(f, [b[:, i].copy() for i in range(nb)]))
272457

273458
return res
274459

275460

276-
277461
def free_support_barycenter(measures_locations, measures_weights, X_init, b=None, weights=None, numItermax=100, stopThr=1e-7, verbose=False, log=None):
278462
"""
279463
Solves the free support (locations of the barycenters are optimized, not the weights) Wasserstein barycenter problem (i.e. the weighted Frechet mean for the 2-Wasserstein distance)
@@ -326,7 +510,7 @@ def free_support_barycenter(measures_locations, measures_weights, X_init, b=None
326510
k = X_init.shape[0]
327511
d = X_init.shape[1]
328512
if b is None:
329-
b = np.ones((k,))/k
513+
b = np.ones((k,)) / k
330514
if weights is None:
331515
weights = np.ones((N,)) / N
332516

@@ -337,7 +521,7 @@ def free_support_barycenter(measures_locations, measures_weights, X_init, b=None
337521

338522
displacement_square_norm = stopThr + 1.
339523

340-
while ( displacement_square_norm > stopThr and iter_count < numItermax ):
524+
while (displacement_square_norm > stopThr and iter_count < numItermax):
341525

342526
T_sum = np.zeros((k, d))
343527

@@ -347,7 +531,7 @@ def free_support_barycenter(measures_locations, measures_weights, X_init, b=None
347531
T_i = emd(b, measure_weights_i, M_i)
348532
T_sum = T_sum + weight_i * np.reshape(1. / b, (-1, 1)) * np.matmul(T_i, measure_locations_i)
349533

350-
displacement_square_norm = np.sum(np.square(T_sum-X))
534+
displacement_square_norm = np.sum(np.square(T_sum - X))
351535
if log:
352536
displacement_square_norms.append(displacement_square_norm)
353537

ot/lp/emd_wrap.pyx

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,8 @@ def check_result(result_code):
4040
return message
4141

4242

43+
44+
4345
@cython.boundscheck(False)
4446
@cython.wraparound(False)
4547
def emd_c(np.ndarray[double, ndim=1, mode="c"] a, np.ndarray[double, ndim=1, mode="c"] b, np.ndarray[double, ndim=2, mode="c"] M, int max_iter, bint dense):
@@ -64,6 +66,12 @@ def emd_c(np.ndarray[double, ndim=1, mode="c"] a, np.ndarray[double, ndim=1, mod
6466
.. warning::
6567
Note that the M matrix needs to be a C-order :py.cls:`numpy.array`
6668
69+
.. warning::
70+
The C++ solver discards all samples in the distributions with
71+
zeros weights. This means that while the primal variable (transport
72+
matrix) is exact, the solver only returns feasible dual potentials
73+
on the samples with weights different from zero.
74+
6775
Parameters
6876
----------
6977
a : (ns,) numpy.ndarray, float64

test/test_ot.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -338,6 +338,10 @@ def test_dual_variables():
338338
np.testing.assert_almost_equal(cost1, log['cost'])
339339
check_duality_gap(a, b, M, G, log['u'], log['v'], log['cost'])
340340

341+
constraint_violation = log['u'][:, None] + log['v'][None, :] - M
342+
343+
assert constraint_violation.max() < 1e-8
344+
341345

342346
def check_duality_gap(a, b, M, G, u, v, cost):
343347
cost_dual = np.vdot(a, u) + np.vdot(b, v)

0 commit comments

Comments
 (0)