2323from .cvx import barycenter
2424from ..utils import dist
2525
26- __all__ = ['emd' , 'emd2' , 'barycenter' , 'free_support_barycenter' , 'cvx' ,
27- 'emd_1d' , 'emd2_1d' , 'wasserstein_1d' ]
26+ __all__ = ['emd' , 'emd2' , 'barycenter' , 'free_support_barycenter' , 'cvx' ,
27+ 'emd_1d' , 'emd2_1d' , 'wasserstein_1d' ]
2828
2929
30- def emd (a , b , M , numItermax = 100000 , log = False , dense = True ):
def center_ot_dual(alpha0, beta0, a=None, b=None):
    r"""Center dual OT potentials w.r.t. their weights

    Shifting the source potential by a constant :math:`+c` and the target
    potential by :math:`-c` leaves every sum :math:`\alpha_i+\beta_j`
    unchanged. This function uses that freedom to pick the shift for which
    both weighted sums of the potentials are equal, which helps stability
    when the OT solver is called several times with small changes.

    The constraint enforced on the returned potentials is:

    .. math::
        \alpha^T a = \beta^T b

    It is obtained by moving a constant from one potential to the other:

    .. math::
        c=\frac{\beta_0^T b-\alpha_0^T a}{1^T b+1^T a}

        \alpha=\alpha_0+c

        \beta=\beta_0-c

    Parameters
    ----------
    alpha0 : (ns,) numpy.ndarray, float64
        Source dual potential
    beta0 : (nt,) numpy.ndarray, float64
        Target dual potential
    a : (ns,) numpy.ndarray, float64, optional
        Source histogram (uniform weights if None or empty)
    b : (nt,) numpy.ndarray, float64, optional
        Target histogram (uniform weights if None or empty)

    Returns
    -------
    alpha : (ns,) numpy.ndarray, float64
        Source centered dual potential
    beta : (nt,) numpy.ndarray, float64
        Target centered dual potential

    """
    # accept any array-like; ndarrays are passed through untouched
    alpha0 = np.asarray(alpha0)
    beta0 = np.asarray(beta0)

    # missing or empty weights fall back to the uniform distribution,
    # as documented
    if a is not None:
        a = np.asarray(a)
    if a is None or a.size == 0:
        a = np.ones(alpha0.shape[0]) / alpha0.shape[0]

    if b is not None:
        b = np.asarray(b)
    if b is None or b.size == 0:
        b = np.ones(beta0.shape[0]) / beta0.shape[0]

    # constant that balances the weighted sums of the two potentials
    c = (b.dot(beta0) - a.dot(alpha0)) / (a.sum() + b.sum())

    # opposite signs keep every alpha_i + beta_j (hence feasibility) unchanged
    alpha = alpha0 + c
    beta = beta0 - c

    return alpha, beta
89+
90+
def estimate_dual_null_weights(alpha0, beta0, a, b, M):
    r"""Estimate feasible values for 0-weighted dual potentials

    The feasible values are computed efficiently but rather coarsely.

    .. warning::
        This function is necessary because the C++ solver in emd_c
        discards all samples in the distributions with
        zero weights. This means that while the primal variable (transport
        matrix) is exact, the solver only returns feasible dual potentials
        on the samples with weights different from zero.

    First the constraint violations are computed:

    .. math::
        V=\alpha+\beta^T-M

    Then the largest violation per row (alpha) and per column (beta):

    .. math::
        v^a_i=\max_j V_{i,j}

        v^b_j=\max_i V_{i,j}

    Finally the potentials attached to zero weights are lowered whenever a
    constraint is violated:

    .. math::
        \alpha_i = \alpha_i -v^a_i \quad \text{ if } a_i=0 \text{ and } v^a_i>0

        \beta_j = \beta_j -v^b_j \quad \text{ if } b_j=0 \text{ and } v^b_j>0

    The corrected potentials are then passed through ``center_ot_dual``
    together with ``a`` and ``b``, and its result is returned.

    Parameters
    ----------
    alpha0 : (ns,) numpy.ndarray, float64
        Source dual potential
    beta0 : (nt,) numpy.ndarray, float64
        Target dual potential
    a : (ns,) numpy.ndarray, float64
        Source distribution; entries equal to 0 select the source
        potentials that may be corrected
    b : (nt,) numpy.ndarray, float64
        Target distribution; entries equal to 0 select the target
        potentials that may be corrected
    M : (ns,nt) numpy.ndarray, float64
        Loss matrix

    Returns
    -------
    alpha : (ns,) numpy.ndarray, float64
        Source corrected dual potential
    beta : (nt,) numpy.ndarray, float64
        Target corrected dual potential

    """
    # masks flagging the samples that carry no mass
    null_a = np.logical_not(a != 0)
    null_b = np.logical_not(b != 0)

    # amount by which alpha_i + beta_j exceeds the cost M_ij
    slack = alpha0[:, None] + beta0[None, :] - M

    # worst positive excess along each row / column (0 when none)
    row_excess = np.maximum(np.max(slack, axis=1), 0)
    col_excess = np.maximum(np.max(slack, axis=0), 0)

    # lower only the potentials attached to zero weights
    alpha = alpha0 + (-1 * null_a) * row_excess
    beta = beta0 + (-1 * null_b) * col_excess

    return center_ot_dual(alpha, beta, a, b)
175+
176+
177+ def emd (a , b , M , numItermax = 100000 , log = False , dense = True , center_dual = True ):
31178 r"""Solves the Earth Movers distance problem and returns the OT matrix
32179
33180
@@ -43,7 +190,7 @@ def emd(a, b, M, numItermax=100000, log=False, dense=True):
43190 - a and b are the sample weights
44191
45192 .. warning::
46- Note that the M matrix needs to be a C-order numpy.array in float64
193+ Note that the M matrix needs to be a C-order numpy.array in float64
47194 format.
48195
49196 Uses the algorithm proposed in [1]_
@@ -66,6 +213,9 @@ def emd(a, b, M, numItermax=100000, log=False, dense=True):
66213 If True, returns :math:`\gamma` as a dense ndarray of shape (ns, nt).
67214 Otherwise returns a sparse representation using scipy's `coo_matrix`
68215 format.
216+ center_dual: boolean, optional (default=True)
217+ If True, centers the dual potential using function
218+ :ref:`center_ot_dual`.
69219
70220 Returns
71221 -------
@@ -107,7 +257,6 @@ def emd(a, b, M, numItermax=100000, log=False, dense=True):
107257 b = np .asarray (b , dtype = np .float64 )
108258 M = np .asarray (M , dtype = np .float64 )
109259
110-
111260 # if empty array given then use uniform distributions
112261 if len (a ) == 0 :
113262 a = np .ones ((M .shape [0 ],), dtype = np .float64 ) / M .shape [0 ]
@@ -117,11 +266,27 @@ def emd(a, b, M, numItermax=100000, log=False, dense=True):
117266 assert (a .shape [0 ] == M .shape [0 ] and b .shape [0 ] == M .shape [1 ]), \
118267 "Dimension mismatch, check dimensions of M with a and b"
119268
269+ asel = a != 0
270+ bsel = b != 0
271+
120272 if dense :
121- G , cost , u , v , result_code = emd_c (a , b , M , numItermax ,dense )
273+ G , cost , u , v , result_code = emd_c (a , b , M , numItermax , dense )
274+
275+ if center_dual :
276+ u , v = center_ot_dual (u , v , a , b )
277+
278+ if np .any (~ asel ) or np .any (~ bsel ):
279+ u , v = estimate_dual_null_weights (u , v , a , b , M )
280+
122281 else :
123- Gv , iG , jG , cost , u , v , result_code = emd_c (a , b , M , numItermax ,dense )
124- G = coo_matrix ((Gv , (iG , jG )), shape = (a .shape [0 ], b .shape [0 ]))
282+ Gv , iG , jG , cost , u , v , result_code = emd_c (a , b , M , numItermax , dense )
283+ G = coo_matrix ((Gv , (iG , jG )), shape = (a .shape [0 ], b .shape [0 ]))
284+
285+ if center_dual :
286+ u , v = center_ot_dual (u , v , a , b )
287+
288+ if np .any (~ asel ) or np .any (~ bsel ):
289+ u , v = estimate_dual_null_weights (u , v , a , b , M )
125290
126291 result_code_string = check_result (result_code )
127292 if log :
@@ -136,7 +301,8 @@ def emd(a, b, M, numItermax=100000, log=False, dense=True):
136301
137302
138303def emd2 (a , b , M , processes = multiprocessing .cpu_count (),
139- numItermax = 100000 , log = False , dense = True , return_matrix = False ):
304+ numItermax = 100000 , log = False , dense = True , return_matrix = False ,
305+ center_dual = True ):
140306 r"""Solves the Earth Movers distance problem and returns the loss
141307
142308 .. math::
@@ -151,7 +317,7 @@ def emd2(a, b, M, processes=multiprocessing.cpu_count(),
151317 - a and b are the sample weights
152318
153319 .. warning::
154- Note that the M matrix needs to be a C-order numpy.array in float64
320+ Note that the M matrix needs to be a C-order numpy.array in float64
155321 format.
156322
157323 Uses the algorithm proposed in [1]_
@@ -177,7 +343,10 @@ def emd2(a, b, M, processes=multiprocessing.cpu_count(),
177343 dense: boolean, optional (default=True)
178344 If True, returns :math:`\gamma` as a dense ndarray of shape (ns, nt).
179345 Otherwise returns a sparse representation using scipy's `coo_matrix`
180- format.
346+ format.
347+ center_dual: boolean, optional (default=True)
348+ If True, centers the dual potential using function
349+ :ref:`center_ot_dual`.
181350
182351 Returns
183352 -------
@@ -221,7 +390,7 @@ def emd2(a, b, M, processes=multiprocessing.cpu_count(),
221390
222391 # problem with pickling Forks
223392 if sys .platform .endswith ('win32' ):
224- processes = 1
393+ processes = 1
225394
226395 # if empty array given then use uniform distributions
227396 if len (a ) == 0 :
@@ -232,13 +401,22 @@ def emd2(a, b, M, processes=multiprocessing.cpu_count(),
232401 assert (a .shape [0 ] == M .shape [0 ] and b .shape [0 ] == M .shape [1 ]), \
233402 "Dimension mismatch, check dimensions of M with a and b"
234403
404+ asel = a != 0
405+
235406 if log or return_matrix :
236407 def f (b ):
408+ bsel = b != 0
237409 if dense :
238- G , cost , u , v , result_code = emd_c (a , b , M , numItermax ,dense )
410+ G , cost , u , v , result_code = emd_c (a , b , M , numItermax , dense )
239411 else :
240- Gv , iG , jG , cost , u , v , result_code = emd_c (a , b , M , numItermax ,dense )
241- G = coo_matrix ((Gv , (iG , jG )), shape = (a .shape [0 ], b .shape [0 ]))
412+ Gv , iG , jG , cost , u , v , result_code = emd_c (a , b , M , numItermax , dense )
413+ G = coo_matrix ((Gv , (iG , jG )), shape = (a .shape [0 ], b .shape [0 ]))
414+
415+ if center_dual :
416+ u , v = center_ot_dual (u , v , a , b )
417+
418+ if np .any (~ asel ) or np .any (~ bsel ):
419+ u , v = estimate_dual_null_weights (u , v , a , b , M )
242420
243421 result_code_string = check_result (result_code )
244422 log = {}
@@ -251,11 +429,18 @@ def f(b):
251429 return [cost , log ]
252430 else :
253431 def f (b ):
432+ bsel = b != 0
254433 if dense :
255- G , cost , u , v , result_code = emd_c (a , b , M , numItermax ,dense )
434+ G , cost , u , v , result_code = emd_c (a , b , M , numItermax , dense )
256435 else :
257- Gv , iG , jG , cost , u , v , result_code = emd_c (a , b , M , numItermax ,dense )
258- G = coo_matrix ((Gv , (iG , jG )), shape = (a .shape [0 ], b .shape [0 ]))
436+ Gv , iG , jG , cost , u , v , result_code = emd_c (a , b , M , numItermax , dense )
437+ G = coo_matrix ((Gv , (iG , jG )), shape = (a .shape [0 ], b .shape [0 ]))
438+
439+ if center_dual :
440+ u , v = center_ot_dual (u , v , a , b )
441+
442+ if np .any (~ asel ) or np .any (~ bsel ):
443+ u , v = estimate_dual_null_weights (u , v , a , b , M )
259444
260445 result_code_string = check_result (result_code )
261446 check_result (result_code )
@@ -265,15 +450,14 @@ def f(b):
265450 return f (b )
266451 nb = b .shape [1 ]
267452
268- if processes > 1 :
453+ if processes > 1 :
269454 res = parmap (f , [b [:, i ] for i in range (nb )], processes )
270455 else :
271456 res = list (map (f , [b [:, i ].copy () for i in range (nb )]))
272457
273458 return res
274459
275460
276-
277461def free_support_barycenter (measures_locations , measures_weights , X_init , b = None , weights = None , numItermax = 100 , stopThr = 1e-7 , verbose = False , log = None ):
278462 """
279463 Solves the free support (locations of the barycenters are optimized, not the weights) Wasserstein barycenter problem (i.e. the weighted Frechet mean for the 2-Wasserstein distance)
@@ -326,7 +510,7 @@ def free_support_barycenter(measures_locations, measures_weights, X_init, b=None
326510 k = X_init .shape [0 ]
327511 d = X_init .shape [1 ]
328512 if b is None :
329- b = np .ones ((k ,))/ k
513+ b = np .ones ((k ,)) / k
330514 if weights is None :
331515 weights = np .ones ((N ,)) / N
332516
@@ -337,7 +521,7 @@ def free_support_barycenter(measures_locations, measures_weights, X_init, b=None
337521
338522 displacement_square_norm = stopThr + 1.
339523
340- while ( displacement_square_norm > stopThr and iter_count < numItermax ):
524+ while (displacement_square_norm > stopThr and iter_count < numItermax ):
341525
342526 T_sum = np .zeros ((k , d ))
343527
@@ -347,7 +531,7 @@ def free_support_barycenter(measures_locations, measures_weights, X_init, b=None
347531 T_i = emd (b , measure_weights_i , M_i )
348532 T_sum = T_sum + weight_i * np .reshape (1. / b , (- 1 , 1 )) * np .matmul (T_i , measure_locations_i )
349533
350- displacement_square_norm = np .sum (np .square (T_sum - X ))
534+ displacement_square_norm = np .sum (np .square (T_sum - X ))
351535 if log :
352536 displacement_square_norms .append (displacement_square_norm )
353537
0 commit comments