2323from .cvx import barycenter
2424from ..utils import dist
2525
26- __all__ = ['emd' , 'emd2' , 'barycenter' , 'free_support_barycenter' , 'cvx' ,
27- 'emd_1d' , 'emd2_1d' , 'wasserstein_1d' ]
26+ __all__ = ['emd' , 'emd2' , 'barycenter' , 'free_support_barycenter' , 'cvx' ,
27+ 'emd_1d' , 'emd2_1d' , 'wasserstein_1d' ]
2828
2929
30- def emd (a , b , M , numItermax = 100000 , log = False , dense = True ):
30+ def center_ot_dual (alpha0 , beta0 , a = None , b = None ):
31+ r"""Center dual OT potentials wrt theirs weights
32+
33+ The main idea of this function is to find unique dual potentials
34+ that ensure some kind of centering/fairness. It will help have
35+ stability when multiple calling of the OT solver with small changes.
36+
37+ Basically we add another constraint to the potential that will not
38+ change the objective value but will ensure unicity. The constraint
39+ is the following:
40+
41+ .. math::
42+ \alpha^T a= \beta^T b
43+
44+ in addition to the OT problem constraints.
45+
46+ since :math:`\sum_i a_i=\sum_j b_j` this can be solved by adding/removing
47+ a constant from both :math:`\alpha_0` and :math:`\beta_0`.
48+
49+ .. math::
50+ c=\frac{\beta0^T b-\alpha_0^T a}{1^Tb+1^Ta}
51+
52+ \alpha=\alpha_0+c
53+
54+ \beta=\beta0+c
55+
56+ Parameters
57+ ----------
58+ alpha0 : (ns,) numpy.ndarray, float64
59+ Source dual potential
60+ beta0 : (nt,) numpy.ndarray, float64
61+ Target dual potential
62+ a : (ns,) numpy.ndarray, float64
63+ Source histogram (uniform weight if empty list)
64+ b : (nt,) numpy.ndarray, float64
65+ Target histogram (uniform weight if empty list)
66+
67+ Returns
68+ -------
69+ alpha : (ns,) numpy.ndarray, float64
70+ Source centered dual potential
71+ beta : (nt,) numpy.ndarray, float64
72+ Target centered dual potential
73+
74+ """
75+ # if no weights are provided, use uniform
76+ if a is None :
77+ a = np .ones (alpha0 .shape [0 ]) / alpha0 .shape [0 ]
78+ if b is None :
79+ b = np .ones (beta0 .shape [0 ]) / beta0 .shape [0 ]
80+
81+ # compute constant that balances the weighted sums of the duals
82+ c = (b .dot (beta0 ) - a .dot (alpha0 )) / (a .sum () + b .sum ())
83+
84+ # update duals
85+ alpha = alpha0 + c
86+ beta = beta0 - c
87+
88+ return alpha , beta
89+
90+
91+ def estimate_dual_null_weights (alpha0 , beta0 , a , b , M ):
92+ r"""Estimate feasible values for 0-weighted dual potentials
93+
94+ The feasible values are computed efficiently bjt rather coarsely.
95+ First we compute the constraints violations:
96+
97+ .. math::
98+ V=\alpha+\beta^T-M
99+
100+ Next we compute the max amount of violation per row (alpha) and
101+ columns (beta)
102+
103+ .. math::
104+ v^a_i=\max_j V_{i,j}
105+
106+ v^b_j=\max_i V_{i,j}
107+
108+ Finally we update the dual potential with 0 weights if a
109+ constraint is violated
110+
111+ .. math::
112+ \alpha_i = \alpha_i -v^a_i \quad \text{ if } a_i=0 \text{ and } v^a_i>0
113+
114+ \beta_j = \beta_j -v^b_j \quad \text{ if } b_j=0 \text{ and } v^b_j>0
115+
116+ In the end the dual potential are centred using function
117+ :ref:`center_ot_dual`.
118+
119+ Note that all those updates do not change the objective value of the
120+ solution but provide dual potential that do not violate the constraints.
121+
122+ Parameters
123+ ----------
124+ alpha0 : (ns,) numpy.ndarray, float64
125+ Source dual potential
126+ beta0 : (nt,) numpy.ndarray, float64
127+ Target dual potential
128+ alpha0 : (ns,) numpy.ndarray, float64
129+ Source dual potential
130+ beta0 : (nt,) numpy.ndarray, float64
131+ Target dual potential
132+ a : (ns,) numpy.ndarray, float64
133+ Source histogram (uniform weight if empty list)
134+ b : (nt,) numpy.ndarray, float64
135+ Target histogram (uniform weight if empty list)
136+ M : (ns,nt) numpy.ndarray, float64
137+ Loss matrix (c-order array with type float64)
138+
139+ Returns
140+ -------
141+ alpha : (ns,) numpy.ndarray, float64
142+ Source corrected dual potential
143+ beta : (nt,) numpy.ndarray, float64
144+ Target corrected dual potential
145+
146+ """
147+
148+ # binary indexing of non-zeros weights
149+ asel = a != 0
150+ bsel = b != 0
151+
152+ # compute dual constraints violation
153+ Viol = alpha0 [:, None ] + beta0 [None , :] - M
154+
155+ # Compute worst violation per line and columns
156+ aviol = np .max (Viol , 1 )
157+ bviol = np .max (Viol , 0 )
158+
159+ # update corrects violation of
160+ alpha_up = - 1 * ~ asel * np .maximum (aviol , 0 )
161+ beta_up = - 1 * ~ bsel * np .maximum (bviol , 0 )
162+
163+ alpha = alpha0 + alpha_up
164+ beta = beta0 + beta_up
165+
166+ return center_ot_dual (alpha , beta , a , b )
167+
168+
169+ def emd (a , b , M , numItermax = 100000 , log = False , dense = True , center_dual = True ):
31170 r"""Solves the Earth Movers distance problem and returns the OT matrix
32171
33172
@@ -43,7 +182,7 @@ def emd(a, b, M, numItermax=100000, log=False, dense=True):
43182 - a and b are the sample weights
44183
45184 .. warning::
46- Note that the M matrix needs to be a C-order numpy.array in float64
185+ Note that the M matrix needs to be a C-order numpy.array in float64
47186 format.
48187
49188 Uses the algorithm proposed in [1]_
@@ -66,6 +205,9 @@ def emd(a, b, M, numItermax=100000, log=False, dense=True):
66205 If True, returns math:`\gamma` as a dense ndarray of shape (ns, nt).
67206 Otherwise returns a sparse representation using scipy's `coo_matrix`
68207 format.
208+ center_dual: boolean, optional (default=True)
209+ If True, centers the dual potential using function
210+ :ref:`center_ot_dual`.
69211
70212 Returns
71213 -------
@@ -107,7 +249,6 @@ def emd(a, b, M, numItermax=100000, log=False, dense=True):
107249 b = np .asarray (b , dtype = np .float64 )
108250 M = np .asarray (M , dtype = np .float64 )
109251
110-
111252 # if empty array given then use uniform distributions
112253 if len (a ) == 0 :
113254 a = np .ones ((M .shape [0 ],), dtype = np .float64 ) / M .shape [0 ]
@@ -117,11 +258,21 @@ def emd(a, b, M, numItermax=100000, log=False, dense=True):
117258 assert (a .shape [0 ] == M .shape [0 ] and b .shape [0 ] == M .shape [1 ]), \
118259 "Dimension mismatch, check dimensions of M with a and b"
119260
261+ asel = a != 0
262+ bsel = b != 0
263+
120264 if dense :
121- G , cost , u , v , result_code = emd_c (a , b , M , numItermax ,dense )
265+ G , cost , u , v , result_code = emd_c (a , b , M , numItermax , dense )
266+
267+ if np .any (~ asel ) or np .any (~ bsel ):
268+ u , v = estimate_dual_null_weights (u , v , a , b , M )
269+
122270 else :
123- Gv , iG , jG , cost , u , v , result_code = emd_c (a , b , M , numItermax ,dense )
124- G = coo_matrix ((Gv , (iG , jG )), shape = (a .shape [0 ], b .shape [0 ]))
271+ Gv , iG , jG , cost , u , v , result_code = emd_c (a , b , M , numItermax , dense )
272+ G = coo_matrix ((Gv , (iG , jG )), shape = (a .shape [0 ], b .shape [0 ]))
273+
274+ if np .any (~ asel ) or np .any (~ bsel ):
275+ u , v = estimate_dual_null_weights (u , v , a , b , M )
125276
126277 result_code_string = check_result (result_code )
127278 if log :
@@ -151,7 +302,7 @@ def emd2(a, b, M, processes=multiprocessing.cpu_count(),
151302 - a and b are the sample weights
152303
153304 .. warning::
154- Note that the M matrix needs to be a C-order numpy.array in float64
305+ Note that the M matrix needs to be a C-order numpy.array in float64
155306 format.
156307
157308 Uses the algorithm proposed in [1]_
@@ -177,7 +328,7 @@ def emd2(a, b, M, processes=multiprocessing.cpu_count(),
177328 dense: boolean, optional (default=True)
178329 If True, returns math:`\gamma` as a dense ndarray of shape (ns, nt).
179330 Otherwise returns a sparse representation using scipy's `coo_matrix`
180- format.
331+ format.
181332
182333 Returns
183334 -------
@@ -221,7 +372,7 @@ def emd2(a, b, M, processes=multiprocessing.cpu_count(),
221372
222373 # problem with pikling Forks
223374 if sys .platform .endswith ('win32' ):
224- processes = 1
375+ processes = 1
225376
226377 # if empty array given then use uniform distributions
227378 if len (a ) == 0 :
@@ -235,10 +386,10 @@ def emd2(a, b, M, processes=multiprocessing.cpu_count(),
235386 if log or return_matrix :
236387 def f (b ):
237388 if dense :
238- G , cost , u , v , result_code = emd_c (a , b , M , numItermax ,dense )
389+ G , cost , u , v , result_code = emd_c (a , b , M , numItermax , dense )
239390 else :
240- Gv , iG , jG , cost , u , v , result_code = emd_c (a , b , M , numItermax ,dense )
241- G = coo_matrix ((Gv , (iG , jG )), shape = (a .shape [0 ], b .shape [0 ]))
391+ Gv , iG , jG , cost , u , v , result_code = emd_c (a , b , M , numItermax , dense )
392+ G = coo_matrix ((Gv , (iG , jG )), shape = (a .shape [0 ], b .shape [0 ]))
242393
243394 result_code_string = check_result (result_code )
244395 log = {}
@@ -252,10 +403,10 @@ def f(b):
252403 else :
253404 def f (b ):
254405 if dense :
255- G , cost , u , v , result_code = emd_c (a , b , M , numItermax ,dense )
406+ G , cost , u , v , result_code = emd_c (a , b , M , numItermax , dense )
256407 else :
257- Gv , iG , jG , cost , u , v , result_code = emd_c (a , b , M , numItermax ,dense )
258- G = coo_matrix ((Gv , (iG , jG )), shape = (a .shape [0 ], b .shape [0 ]))
408+ Gv , iG , jG , cost , u , v , result_code = emd_c (a , b , M , numItermax , dense )
409+ G = coo_matrix ((Gv , (iG , jG )), shape = (a .shape [0 ], b .shape [0 ]))
259410
260411 result_code_string = check_result (result_code )
261412 check_result (result_code )
@@ -265,15 +416,14 @@ def f(b):
265416 return f (b )
266417 nb = b .shape [1 ]
267418
268- if processes > 1 :
419+ if processes > 1 :
269420 res = parmap (f , [b [:, i ] for i in range (nb )], processes )
270421 else :
271422 res = list (map (f , [b [:, i ].copy () for i in range (nb )]))
272423
273424 return res
274425
275426
276-
277427def free_support_barycenter (measures_locations , measures_weights , X_init , b = None , weights = None , numItermax = 100 , stopThr = 1e-7 , verbose = False , log = None ):
278428 """
279429 Solves the free support (locations of the barycenters are optimized, not the weights) Wasserstein barycenter problem (i.e. the weighted Frechet mean for the 2-Wasserstein distance)
@@ -326,7 +476,7 @@ def free_support_barycenter(measures_locations, measures_weights, X_init, b=None
326476 k = X_init .shape [0 ]
327477 d = X_init .shape [1 ]
328478 if b is None :
329- b = np .ones ((k ,))/ k
479+ b = np .ones ((k ,)) / k
330480 if weights is None :
331481 weights = np .ones ((N ,)) / N
332482
@@ -337,7 +487,7 @@ def free_support_barycenter(measures_locations, measures_weights, X_init, b=None
337487
338488 displacement_square_norm = stopThr + 1.
339489
340- while ( displacement_square_norm > stopThr and iter_count < numItermax ):
490+ while (displacement_square_norm > stopThr and iter_count < numItermax ):
341491
342492 T_sum = np .zeros ((k , d ))
343493
@@ -347,7 +497,7 @@ def free_support_barycenter(measures_locations, measures_weights, X_init, b=None
347497 T_i = emd (b , measure_weights_i , M_i )
348498 T_sum = T_sum + weight_i * np .reshape (1. / b , (- 1 , 1 )) * np .matmul (T_i , measure_locations_i )
349499
350- displacement_square_norm = np .sum (np .square (T_sum - X ))
500+ displacement_square_norm = np .sum (np .square (T_sum - X ))
351501 if log :
352502 displacement_square_norms .append (displacement_square_norm )
353503
0 commit comments