Skip to content

Commit 6492e95

Browse files
committed
free support barycenter
1 parent 39cbcd3 commit 6492e95

File tree

1 file changed

+80
-0
lines changed

1 file changed

+80
-0
lines changed

ot/lp/cvx.py

Lines changed: 80 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -10,6 +10,7 @@
1010
import numpy as np
1111
import scipy as sp
1212
import scipy.sparse as sps
13+
import ot
1314

1415
try:
1516
import cvxopt
@@ -144,3 +145,82 @@ def barycenter(A, M, weights=None, verbose=False, log=False, solver='interior-po
144145
return b, sol
145146
else:
146147
return b
148+
149+
150+
151+
152+
def free_support_barycenter(data_positions, data_weights, X_init, b_init, lamda, numItermax=100, stopThr=1e-5, verbose=False, log=False, **kwargs):
    """
    Solves the free support (locations of the barycenters are optimized, not the weights) Wasserstein barycenter problem (i.e. the weighted Frechet mean for the 2-Wasserstein distance)

    The function solves the Wasserstein barycenter problem when the barycenter measure is constrained to be supported on k atoms.
    This problem is considered in [1] (Algorithm 2). There are two differences with the following codes:
    - we do not optimize over the weights
    - we do not do line search for the locations updates, we use i.e. theta = 1 in [1] (Algorithm 2). This can be seen as a discrete implementation of the fixed-point algorithm of [2] proposed in the continuous setting.

    At every iteration, for each input measure, a cost matrix between the
    current support and the measure's support is built with ``ot.dist``
    (using its default metric), a transport plan is computed with ``ot.emd``,
    and each barycenter atom is moved to the ``lamda``-weighted average of
    its barycentric projections onto the input measures.

    Parameters
    ----------
    data_positions : list of (k_i,d) np.ndarray
        The discrete support of a measure supported on k_i locations of a d-dimensional space (k_i can be different for each element of the list)
    data_weights : list of (k_i,) np.ndarray
        Numpy arrays where each numpy array has k_i non-negatives values summing to one representing the weights of each discrete input measure
    X_init : (k,d) np.ndarray
        Initialization of the support locations (on k atoms) of the barycenter
    b_init : (k,) np.ndarray
        Weights of the barycenter (non-negatives, sum to 1). They are kept
        fixed during the optimization. Entries must be non-zero because the
        update divides by them.
    lamda : (N,) array-like or None
        Barycentric coefficients, one per input measure (non-negatives, with
        a positive sum). They are rescaled to sum to one. ``None`` selects
        uniform coefficients ``1 / N``.
    numItermax : int, optional
        Max number of iterations
    stopThr : float, optional
        Stop threshold on the squared displacement of the support between
        two consecutive iterations (>0)
    verbose : bool, optional
        Print information along iterations
    log : bool, optional
        record log if True
    **kwargs
        Accepted for call compatibility and ignored.

    Returns
    -------
    X : (k,d) np.ndarray
        Support locations (on k atoms) of the barycenter
    log : dict
        Only returned if ``log`` is True. Contains
        ``'displacement_square_norms'`` (one value per iteration) and
        ``'niter'`` (number of iterations performed).

    Raises
    ------
    ValueError
        If no input measure is given, if ``data_positions`` and
        ``data_weights`` have different lengths, if ``X_init`` and ``b_init``
        disagree on the number of atoms, or if ``lamda`` does not hold one
        coefficient per measure or has a non-positive sum.

    References
    ----------

    .. [1] Cuturi, Marco, and Arnaud Doucet. "Fast computation of Wasserstein barycenters." International Conference on Machine Learning. 2014.

    .. [2] Álvarez-Esteban, Pedro C., et al. "A fixed-point approach to barycenters in Wasserstein space." Journal of Mathematical Analysis and Applications 441.2 (2016): 744-762.

    """

    N = len(data_positions)
    if N == 0:
        # Without this guard the update would divide by zero and return NaNs.
        raise ValueError("data_positions must contain at least one measure")
    if len(data_weights) != N:
        # zip() would otherwise silently drop the unmatched measures.
        raise ValueError("data_positions and data_weights must have the same length")

    d = X_init.shape[1]
    k = b_init.size
    if X_init.shape[0] != k:
        raise ValueError("X_init must have one row per entry of b_init")

    if lamda is None:
        lamda = np.full(N, 1.0 / N)
    else:
        lamda = np.asarray(lamda, dtype=float).ravel()
        if lamda.size != N:
            raise ValueError("lamda must contain one coefficient per input measure")
        total = lamda.sum()
        if not total > 0:
            raise ValueError("lamda must have a positive sum")
        # Rescale so the update stays a convex combination of the projections.
        lamda = lamda / total

    # Loop-invariant: turns each row of (T_i @ positions_i) into the
    # barycentric projection of the corresponding barycenter atom.
    inv_b = np.reshape(1. / b_init, (-1, 1))

    X = X_init
    iter_count = 0
    displacement_square_norms = []
    # Start at +inf so that the first iteration runs whatever stopThr is.
    displacement_square_norm = np.inf

    while displacement_square_norm > stopThr and iter_count < numItermax:

        T_sum = np.zeros((k, d))

        for (data_positions_i, data_weights_i, lamda_i) in zip(data_positions, data_weights, lamda):
            M_i = ot.dist(X, data_positions_i)
            T_i = ot.emd(b_init, data_weights_i, M_i)
            T_sum += lamda_i * inv_b * np.matmul(T_i, data_positions_i)

        displacement_square_norm = np.sum(np.square(T_sum - X))
        X = T_sum
        iter_count += 1

        if log:
            displacement_square_norms.append(displacement_square_norm)
        if verbose:
            print('iteration %d, displacement_square_norm=%e' % (iter_count, displacement_square_norm))

    if log:
        log_dict = {'displacement_square_norms': displacement_square_norms,
                    'niter': iter_count}
        return X, log_dict
    else:
        return X
226+

0 commit comments

Comments
 (0)