
Commit e5196fa

correct bug in emd; emd2 still todo

1 parent 30fc233 · commit e5196fa

File tree

ot/lp/__init__.py
ot/lp/emd_wrap.pyx

2 files changed: +174 −22 lines changed


ot/lp/__init__.py

Lines changed: 172 additions & 22 deletions
@@ -23,11 +23,150 @@
 from .cvx import barycenter
 from ..utils import dist
 
-__all__=['emd', 'emd2', 'barycenter', 'free_support_barycenter', 'cvx',
-        'emd_1d', 'emd2_1d', 'wasserstein_1d']
+__all__ = ['emd', 'emd2', 'barycenter', 'free_support_barycenter', 'cvx',
+           'emd_1d', 'emd2_1d', 'wasserstein_1d']
 
 
-def emd(a, b, M, numItermax=100000, log=False, dense=True):
+def center_ot_dual(alpha0, beta0, a=None, b=None):
+    r"""Center dual OT potentials wrt their weights
+
+    The main idea of this function is to find unique dual potentials
+    that ensure some kind of centering/fairness. It helps keep the solver
+    stable when it is called repeatedly with small changes in the data.
+
+    Basically we add another constraint to the potentials that does not
+    change the objective value but ensures uniqueness. The constraint
+    is the following:
+
+    .. math::
+        \alpha^T a = \beta^T b
+
+    in addition to the OT problem constraints.
+
+    Since :math:`\sum_i a_i = \sum_j b_j`, this can be solved by adding/removing
+    a constant from both :math:`\alpha_0` and :math:`\beta_0`.
+
+    .. math::
+        c = \frac{\beta_0^T b - \alpha_0^T a}{1^T b + 1^T a}
+
+        \alpha = \alpha_0 + c
+
+        \beta = \beta_0 - c
+
+    Parameters
+    ----------
+    alpha0 : (ns,) numpy.ndarray, float64
+        Source dual potential
+    beta0 : (nt,) numpy.ndarray, float64
+        Target dual potential
+    a : (ns,) numpy.ndarray, float64
+        Source histogram (uniform weights if empty list)
+    b : (nt,) numpy.ndarray, float64
+        Target histogram (uniform weights if empty list)
+
+    Returns
+    -------
+    alpha : (ns,) numpy.ndarray, float64
+        Source centered dual potential
+    beta : (nt,) numpy.ndarray, float64
+        Target centered dual potential
+
+    """
+    # if no weights are provided, use uniform
+    if a is None:
+        a = np.ones(alpha0.shape[0]) / alpha0.shape[0]
+    if b is None:
+        b = np.ones(beta0.shape[0]) / beta0.shape[0]
+
+    # compute constant that balances the weighted sums of the duals
+    c = (b.dot(beta0) - a.dot(alpha0)) / (a.sum() + b.sum())
+
+    # update duals
+    alpha = alpha0 + c
+    beta = beta0 - c
+
+    return alpha, beta
+
+
+def estimate_dual_null_weights(alpha0, beta0, a, b, M):
+    r"""Estimate feasible values for 0-weighted dual potentials
+
+    The feasible values are computed efficiently but rather coarsely.
+    First we compute the constraint violations:
+
+    .. math::
+        V = \alpha + \beta^T - M
+
+    Next we compute the maximal amount of violation per row (alpha) and
+    per column (beta):
+
+    .. math::
+        v^a_i = \max_j V_{i,j}
+
+        v^b_j = \max_i V_{i,j}
+
+    Finally we update the dual potentials with 0 weights when a
+    constraint is violated:
+
+    .. math::
+        \alpha_i = \alpha_i - v^a_i \quad \text{ if } a_i=0 \text{ and } v^a_i>0
+
+        \beta_j = \beta_j - v^b_j \quad \text{ if } b_j=0 \text{ and } v^b_j>0
+
+    In the end the dual potentials are centered using the function
+    :ref:`center_ot_dual`.
+
+    Note that all those updates do not change the objective value of the
+    solution but provide dual potentials that do not violate the constraints.
+
+    Parameters
+    ----------
+    alpha0 : (ns,) numpy.ndarray, float64
+        Source dual potential
+    beta0 : (nt,) numpy.ndarray, float64
+        Target dual potential
+    a : (ns,) numpy.ndarray, float64
+        Source histogram (uniform weights if empty list)
+    b : (nt,) numpy.ndarray, float64
+        Target histogram (uniform weights if empty list)
+    M : (ns,nt) numpy.ndarray, float64
+        Loss matrix (c-order array with type float64)
+
+    Returns
+    -------
+    alpha : (ns,) numpy.ndarray, float64
+        Source corrected dual potential
+    beta : (nt,) numpy.ndarray, float64
+        Target corrected dual potential
+
+    """
+
+    # binary indexing of non-zero weights
+    asel = a != 0
+    bsel = b != 0
+
+    # compute dual constraint violations
+    Viol = alpha0[:, None] + beta0[None, :] - M
+
+    # compute worst violation per row and per column
+    aviol = np.max(Viol, 1)
+    bviol = np.max(Viol, 0)
+
+    # the updates cancel the violations on the zero-weight entries only
+    alpha_up = -1 * ~asel * np.maximum(aviol, 0)
+    beta_up = -1 * ~bsel * np.maximum(bviol, 0)
+
+    alpha = alpha0 + alpha_up
+    beta = beta0 + beta_up
+
+    return center_ot_dual(alpha, beta, a, b)
+
+
+def emd(a, b, M, numItermax=100000, log=False, dense=True, center_dual=True):
     r"""Solves the Earth Movers distance problem and returns the OT matrix
 
 
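
Note: a minimal sketch of what the new centering does, on made-up potentials. It assumes center_ot_dual is importable from ot.lp as added in this hunk, and only checks the constraint stated in the docstring::

    import numpy as np
    from ot.lp import center_ot_dual

    a = np.array([0.5, 0.5])           # source weights
    b = np.array([0.25, 0.75])         # target weights
    alpha0 = np.array([0.3, -0.1])     # toy dual potentials
    beta0 = np.array([0.2, 0.4])

    alpha, beta = center_ot_dual(alpha0, beta0, a, b)

    # alpha_i + beta_j is unchanged (same objective), but the weighted
    # sums of the two potentials now agree:
    assert np.allclose(alpha.dot(a), beta.dot(b))
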
@@ -43,7 +182,7 @@ def emd(a, b, M, numItermax=100000, log=False, dense=True):
     - a and b are the sample weights
 
     .. warning::
-        Note that the M matrix needs to be a C-order numpy.array in float64
+        Note that the M matrix needs to be a C-order numpy.array in float64
         format.
 
     Uses the algorithm proposed in [1]_
@@ -66,6 +205,9 @@ def emd(a, b, M, numItermax=100000, log=False, dense=True):
         If True, returns math:`\gamma` as a dense ndarray of shape (ns, nt).
         Otherwise returns a sparse representation using scipy's `coo_matrix`
         format.
+    center_dual: boolean, optional (default=True)
+        If True, centers the dual potentials using the function
+        :ref:`center_ot_dual`.
 
     Returns
     -------
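
Note: a hedged usage sketch of the new keyword on toy data, following only the signature and docstring added in this commit::

    import numpy as np
    import ot

    a = np.array([0.5, 0.5])
    b = np.array([0.5, 0.5])
    M = np.array([[0.0, 1.0],
                  [1.0, 0.0]])

    # transport plan with centered dual potentials (the documented default)
    G = ot.emd(a, b, M, center_dual=True)
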
@@ -107,7 +249,6 @@ def emd(a, b, M, numItermax=100000, log=False, dense=True):
     b = np.asarray(b, dtype=np.float64)
     M = np.asarray(M, dtype=np.float64)
 
-
     # if empty array given then use uniform distributions
     if len(a) == 0:
         a = np.ones((M.shape[0],), dtype=np.float64) / M.shape[0]
@@ -117,11 +258,21 @@ def emd(a, b, M, numItermax=100000, log=False, dense=True):
     assert (a.shape[0] == M.shape[0] and b.shape[0] == M.shape[1]), \
         "Dimension mismatch, check dimensions of M with a and b"
 
+    asel = a != 0
+    bsel = b != 0
+
     if dense:
-        G, cost, u, v, result_code = emd_c(a, b, M, numItermax,dense)
+        G, cost, u, v, result_code = emd_c(a, b, M, numItermax, dense)
+
+        if np.any(~asel) or np.any(~bsel):
+            u, v = estimate_dual_null_weights(u, v, a, b, M)
+
     else:
-        Gv, iG, jG, cost, u, v, result_code = emd_c(a, b, M, numItermax,dense)
-        G = coo_matrix((Gv, (iG, jG)), shape=(a.shape[0], b.shape[0]))
+        Gv, iG, jG, cost, u, v, result_code = emd_c(a, b, M, numItermax, dense)
+        G = coo_matrix((Gv, (iG, jG)), shape=(a.shape[0], b.shape[0]))
+
+        if np.any(~asel) or np.any(~bsel):
+            u, v = estimate_dual_null_weights(u, v, a, b, M)
 
     result_code_string = check_result(result_code)
     if log:
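
Note: to illustrate what the zero-weight branch above repairs, here is a sketch on made-up inputs. It assumes estimate_dual_null_weights is importable from ot.lp; the third source bin carries no mass and starts with an infeasible dual value::

    import numpy as np
    from ot.lp import estimate_dual_null_weights

    a = np.array([0.5, 0.5, 0.0])      # last source bin has zero weight
    b = np.array([0.5, 0.5])
    M = np.array([[0.0, 1.0],
                  [1.0, 0.0],
                  [0.5, 0.5]])
    alpha0 = np.array([0.0, 0.0, 10.0])   # infeasible value on the 0-weight bin
    beta0 = np.array([0.0, 0.0])

    alpha, beta = estimate_dual_null_weights(alpha0, beta0, a, b, M)

    # the corrected duals violate no constraint alpha_i + beta_j <= M_ij
    assert np.all(alpha[:, None] + beta[None, :] <= M + 1e-12)
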
@@ -151,7 +302,7 @@ def emd2(a, b, M, processes=multiprocessing.cpu_count(),
     - a and b are the sample weights
 
     .. warning::
-        Note that the M matrix needs to be a C-order numpy.array in float64
+        Note that the M matrix needs to be a C-order numpy.array in float64
         format.
 
     Uses the algorithm proposed in [1]_
@@ -177,7 +328,7 @@ def emd2(a, b, M, processes=multiprocessing.cpu_count(),
     dense: boolean, optional (default=True)
         If True, returns math:`\gamma` as a dense ndarray of shape (ns, nt).
         Otherwise returns a sparse representation using scipy's `coo_matrix`
-        format.
+        format.
 
     Returns
     -------
@@ -221,7 +372,7 @@ def emd2(a, b, M, processes=multiprocessing.cpu_count(),
 
     # problem with pikling Forks
     if sys.platform.endswith('win32'):
-        processes=1
+        processes = 1
 
     # if empty array given then use uniform distributions
     if len(a) == 0:
@@ -235,10 +386,10 @@ def emd2(a, b, M, processes=multiprocessing.cpu_count(),
     if log or return_matrix:
         def f(b):
             if dense:
-                G, cost, u, v, result_code = emd_c(a, b, M, numItermax,dense)
+                G, cost, u, v, result_code = emd_c(a, b, M, numItermax, dense)
             else:
-                Gv, iG, jG, cost, u, v, result_code = emd_c(a, b, M, numItermax,dense)
-                G = coo_matrix((Gv, (iG, jG)), shape=(a.shape[0], b.shape[0]))
+                Gv, iG, jG, cost, u, v, result_code = emd_c(a, b, M, numItermax, dense)
+                G = coo_matrix((Gv, (iG, jG)), shape=(a.shape[0], b.shape[0]))
 
             result_code_string = check_result(result_code)
             log = {}
@@ -252,10 +403,10 @@ def f(b):
     else:
         def f(b):
             if dense:
-                G, cost, u, v, result_code = emd_c(a, b, M, numItermax,dense)
+                G, cost, u, v, result_code = emd_c(a, b, M, numItermax, dense)
             else:
-                Gv, iG, jG, cost, u, v, result_code = emd_c(a, b, M, numItermax,dense)
-                G = coo_matrix((Gv, (iG, jG)), shape=(a.shape[0], b.shape[0]))
+                Gv, iG, jG, cost, u, v, result_code = emd_c(a, b, M, numItermax, dense)
+                G = coo_matrix((Gv, (iG, jG)), shape=(a.shape[0], b.shape[0]))
 
             result_code_string = check_result(result_code)
             check_result(result_code)
@@ -265,15 +416,14 @@ def f(b):
         return f(b)
     nb = b.shape[1]
 
-    if processes>1:
+    if processes > 1:
         res = parmap(f, [b[:, i] for i in range(nb)], processes)
     else:
         res = list(map(f, [b[:, i].copy() for i in range(nb)]))
 
     return res
 
 
-
 def free_support_barycenter(measures_locations, measures_weights, X_init, b=None, weights=None, numItermax=100, stopThr=1e-7, verbose=False, log=None):
     """
     Solves the free support (locations of the barycenters are optimized, not the weights) Wasserstein barycenter problem (i.e. the weighted Frechet mean for the 2-Wasserstein distance)
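
Note: the hunk above also covers the parallel path that maps f over the columns of b. A hedged sketch of that use, with hypothetical histograms::

    import numpy as np
    import ot

    a = np.array([0.5, 0.5])
    B = np.array([[0.5, 0.25],
                  [0.5, 0.75]])        # each column is one target histogram
    M = np.array([[0.0, 1.0],
                  [1.0, 0.0]])

    # one EMD value per column of B; processes only takes effect off win32
    costs = ot.emd2(a, B, M, processes=2)
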
@@ -326,7 +476,7 @@ def free_support_barycenter(measures_locations, measures_weights, X_init, b=None
     k = X_init.shape[0]
     d = X_init.shape[1]
     if b is None:
-        b = np.ones((k,))/k
+        b = np.ones((k,)) / k
     if weights is None:
         weights = np.ones((N,)) / N
 
@@ -337,7 +487,7 @@ def free_support_barycenter(measures_locations, measures_weights, X_init, b=None
 
     displacement_square_norm = stopThr + 1.
 
-    while ( displacement_square_norm > stopThr and iter_count < numItermax ):
+    while (displacement_square_norm > stopThr and iter_count < numItermax):
 
         T_sum = np.zeros((k, d))
 
@@ -347,7 +497,7 @@ def free_support_barycenter(measures_locations, measures_weights, X_init, b=None
             T_i = emd(b, measure_weights_i, M_i)
             T_sum = T_sum + weight_i * np.reshape(1. / b, (-1, 1)) * np.matmul(T_i, measure_locations_i)
 
-        displacement_square_norm = np.sum(np.square(T_sum-X))
+        displacement_square_norm = np.sum(np.square(T_sum - X))
         if log:
             displacement_square_norms.append(displacement_square_norm)
 
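
Note: the loop touched above is the fixed-point update of free_support_barycenter. A minimal, hypothetical call following the signature shown in the context lines::

    import numpy as np
    import ot

    rng = np.random.RandomState(0)
    X1 = rng.randn(5, 2)               # first measure: 5 points in 2D
    X2 = rng.randn(7, 2) + 2.0         # second measure: 7 points in 2D
    w1 = np.ones(5) / 5
    w2 = np.ones(7) / 7
    X_init = rng.randn(4, 2)           # 4 free support points to optimize

    X_bar = ot.lp.free_support_barycenter([X1, X2], [w1, w2], X_init)
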

ot/lp/emd_wrap.pyx

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,8 @@ def check_result(result_code):
4040
return message
4141

4242

43+
44+
4345
@cython.boundscheck(False)
4446
@cython.wraparound(False)
4547
def emd_c(np.ndarray[double, ndim=1, mode="c"] a, np.ndarray[double, ndim=1, mode="c"] b, np.ndarray[double, ndim=2, mode="c"] M, int max_iter, bint dense):

0 commit comments
