@@ -27,7 +27,7 @@ def scipy_sparse_to_spmatrix(A):
2727
2828
2929def barycenter (A , M , weights = None , verbose = False , log = False , solver = 'interior-point' ):
30- """Compute the entropic regularized wasserstein barycenter of distributions A
30+ """Compute the Wasserstein barycenter of distributions A
3131
3232 The function solves the following optimization problem [16]:
3333
@@ -149,7 +149,7 @@ def barycenter(A, M, weights=None, verbose=False, log=False, solver='interior-po
149149
150150
151151
152- def free_support_barycenter (data_positions , data_weights , X_init , b_init , lamda , numItermax = 100 , stopThr = 1e-5 , verbose = False , log = False , ** kwargs ):
152+ def free_support_barycenter (measures_locations , measures_weights , X_init , b_init , weights = None , numItermax = 100 , stopThr = 1e-6 , verbose = False ):
153153
154154 """
155155 Solves the free support (locations of the barycenters are optimized, not the weights) Wasserstein barycenter problem (i.e. the weighted Frechet mean for the 2-Wasserstein distance)
@@ -170,7 +170,7 @@ def free_support_barycenter(data_positions, data_weights, X_init, b_init, lamda,
170170 Initialization of the support locations (on k atoms) of the barycenter
171171 b_init : (k,) np.ndarray
172172 Initialization of the weights of the barycenter (non-negatives, sum to 1)
173- lambda : (k,) np.ndarray
173+ weights : (k,) np.ndarray
174174 Initialization of the coefficients of the barycenter (non-negatives, sum to 1)
175175
176176 numItermax : int, optional
@@ -200,25 +200,30 @@ def free_support_barycenter(data_positions, data_weights, X_init, b_init, lamda,
200200
201201 d = X_init .shape [1 ]
202202 k = b_init .size
203- N = len (data_positions )
203+ N = len (measures_locations )
204+
205+ if not weights :
206+ weights = np .ones ((N ,))/ N
204207
205208 X = X_init
206209
207- displacement_square_norm = 1e3
210+ displacement_square_norm = stopThr + 1.
208211
209212 while ( displacement_square_norm > stopThr and iter_count < numItermax ):
210213
211214 T_sum = np .zeros ((k , d ))
212215
213- for (data_positions_i , data_weights_i ) in zip (data_positions , data_weights ):
214- M_i = ot .dist (X , data_positions_i )
215- T_i = ot .emd (b_init , data_weights_i , M_i )
216- T_sum += np .reshape (1. / b_init , (- 1 , 1 )) * np .matmul (T_i , data_positions_i )
216+ for (measure_locations_i , measure_weights_i , weight_i ) in zip (measures_locations , measures_weights , weights .tolist ()):
217+
218+ M_i = ot .dist (X , measure_locations_i )
219+ T_i = ot .emd (b_init , measure_weights_i , M_i )
220+ T_sum += np .reshape (1. / b_init , (- 1 , 1 )) * np .matmul (T_i , measure_locations_i )
217221
218- X_previous = X
219- X = T_sum / N
222+ displacement_square_norm = np . sum ( np . square ( X - T_sum ))
223+ X = T_sum
220224
221- displacement_square_norm = np .sum (np .square (X - X_previous ))
225+ if verbose :
226+ print ('iteration %d, displacement_square_norm=%f\n ' , iter_count , displacement_square_norm )
222227
223228 iter_count += 1
224229
0 commit comments