Skip to content

Commit 4671279

Browse files
author
Vivien Seguy
committed
add test free support barycenter algorithm + cleaning
1 parent 67ddb92 commit 4671279

File tree

4 files changed

+121
-97
lines changed

4 files changed

+121
-97
lines changed

examples/plot_free_support_barycenter.py

Lines changed: 14 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
# -*- coding: utf-8 -*-
22
"""
33
====================================================
4-
2D Wasserstein barycenters of distributions
4+
2D free support Wasserstein barycenters of distributions
55
====================================================
66
77
Illustration of 2D Wasserstein barycenters if discributions that are weighted
@@ -15,7 +15,8 @@
1515

1616
import numpy as np
1717
import matplotlib.pylab as pl
18-
import ot.plot
18+
import ot
19+
1920

2021
##############################################################################
2122
# Generate data
@@ -28,16 +29,16 @@
2829

2930
for i in range(N):
3031

31-
n = np.random.randint(low=1, high=20) # nb samples
32+
n_i = np.random.randint(low=1, high=20) # nb samples
3233

33-
mu = np.random.normal(0., 4., (d,))
34+
mu_i = np.random.normal(0., 4., (d,)) # Gaussian mean
3435

35-
A = np.random.rand(d, d)
36-
cov = np.dot(A, A.transpose())
36+
A_i = np.random.rand(d, d)
37+
cov_i = np.dot(A_i, A_i.transpose()) # Gaussian covariance matrix
3738

38-
x_i = ot.datasets.make_2D_samples_gauss(n, mu, cov)
39-
b_i = np.random.uniform(0., 1., (n,))
40-
b_i = b_i / np.sum(b_i)
39+
x_i = ot.datasets.make_2D_samples_gauss(n_i, mu_i, cov_i) # Dirac locations
40+
b_i = np.random.uniform(0., 1., (n_i,))
41+
b_i = b_i / np.sum(b_i) # Dirac weights
4142

4243
measures_locations.append(x_i)
4344
measures_weights.append(b_i)
@@ -47,19 +48,17 @@
4748
# Compute free support barycenter
4849
# -------------
4950

50-
k = 10
51-
X_init = np.random.normal(0., 1., (k, d))
52-
b = np.ones((k,)) / k
51+
k = 10 # number of Diracs of the barycenter
52+
X_init = np.random.normal(0., 1., (k, d)) # initial Dirac locations
53+
b = np.ones((k,)) / k # weights of the barycenter (it will not be optimized, only the locations are optimized)
5354

54-
X = ot.lp.cvx.free_support_barycenter(measures_locations, measures_weights, X_init, b)
55+
X = ot.lp.free_support_barycenter(measures_locations, measures_weights, X_init, b)
5556

5657

5758
##############################################################################
5859
# Plot data
5960
# ---------
6061

61-
#%% plot samples
62-
6362
pl.figure(1)
6463
for (x_i, b_i) in zip(measures_locations, measures_weights):
6564
color = np.random.randint(low=1, high=10 * N)

ot/lp/__init__.py

Lines changed: 91 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,9 @@
1717
from .emd_wrap import emd_c, check_result
1818
from ..utils import parmap
1919
from .cvx import barycenter
20+
from ..utils import dist
2021

21-
__all__=['emd', 'emd2', 'barycenter', 'cvx']
22+
__all__=['emd', 'emd2', 'barycenter', 'free_support_barycenter', 'cvx']
2223

2324

2425
def emd(a, b, M, numItermax=100000, log=False):
@@ -216,3 +217,92 @@ def f(b):
216217

217218
res = parmap(f, [b[:, i] for i in range(nb)], processes)
218219
return res
220+
221+
222+
223+
def free_support_barycenter(measures_locations, measures_weights, X_init, b=None, weights=None, numItermax=100, stopThr=1e-7, verbose=False, log=None):
224+
"""
225+
Solves the free support (locations of the barycenters are optimized, not the weights) Wasserstein barycenter problem (i.e. the weighted Frechet mean for the 2-Wasserstein distance)
226+
227+
The function solves the Wasserstein barycenter problem when the barycenter measure is constrained to be supported on k atoms.
228+
This problem is considered in [1] (Algorithm 2). There are two differences with the following codes:
229+
- we do not optimize over the weights
230+
- we do not do line search for the locations updates, we use i.e. theta = 1 in [1] (Algorithm 2). This can be seen as a discrete implementation of the fixed-point algorithm of [2] proposed in the continuous setting.
231+
232+
Parameters
233+
----------
234+
measures_locations : list of (k_i,d) np.ndarray
235+
The discrete support of a measure supported on k_i locations of a d-dimensional space (k_i can be different for each element of the list)
236+
measures_weights : list of (k_i,) np.ndarray
237+
Numpy arrays where each numpy array has k_i non-negatives values summing to one representing the weights of each discrete input measure
238+
239+
X_init : (k,d) np.ndarray
240+
Initialization of the support locations (on k atoms) of the barycenter
241+
b : (k,) np.ndarray
242+
Initialization of the weights of the barycenter (non-negatives, sum to 1)
243+
weights : (k,) np.ndarray
244+
Initialization of the coefficients of the barycenter (non-negatives, sum to 1)
245+
246+
numItermax : int, optional
247+
Max number of iterations
248+
stopThr : float, optional
249+
Stop threshol on error (>0)
250+
verbose : bool, optional
251+
Print information along iterations
252+
log : bool, optional
253+
record log if True
254+
255+
Returns
256+
-------
257+
X : (k,d) np.ndarray
258+
Support locations (on k atoms) of the barycenter
259+
260+
References
261+
----------
262+
263+
.. [1] Cuturi, Marco, and Arnaud Doucet. "Fast computation of Wasserstein barycenters." International Conference on Machine Learning. 2014.
264+
265+
.. [2] Álvarez-Esteban, Pedro C., et al. "A fixed-point approach to barycenters in Wasserstein space." Journal of Mathematical Analysis and Applications 441.2 (2016): 744-762.
266+
267+
"""
268+
269+
iter_count = 0
270+
271+
N = len(measures_locations)
272+
k = X_init.shape[0]
273+
d = X_init.shape[1]
274+
if b is None:
275+
b = np.ones((k,))/k
276+
if weights is None:
277+
weights = np.ones((N,)) / N
278+
279+
X = X_init
280+
281+
displacement_square_norms = []
282+
displacement_square_norm = stopThr + 1.
283+
284+
while ( displacement_square_norm > stopThr and iter_count < numItermax ):
285+
286+
T_sum = np.zeros((k, d))
287+
288+
for (measure_locations_i, measure_weights_i, weight_i) in zip(measures_locations, measures_weights, weights.tolist()):
289+
290+
M_i = dist(X, measure_locations_i)
291+
T_i = emd(b, measure_weights_i, M_i)
292+
T_sum = T_sum + weight_i * np.reshape(1. / b, (-1, 1)) * np.matmul(T_i, measure_locations_i)
293+
294+
displacement_square_norm = np.sum(np.square(T_sum-X))
295+
if log:
296+
displacement_square_norms.append(displacement_square_norm)
297+
298+
X = T_sum
299+
300+
if verbose:
301+
print('iteration %d, displacement_square_norm=%f\n', iter_count, displacement_square_norm)
302+
303+
iter_count += 1
304+
305+
if log:
306+
return X, displacement_square_norms
307+
else:
308+
return X

ot/lp/cvx.py

Lines changed: 1 addition & 81 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
import numpy as np
1111
import scipy as sp
1212
import scipy.sparse as sps
13-
import ot
13+
1414

1515
try:
1616
import cvxopt
@@ -145,83 +145,3 @@ def barycenter(A, M, weights=None, verbose=False, log=False, solver='interior-po
145145
return b, sol
146146
else:
147147
return b
148-
149-
150-
def free_support_barycenter(measures_locations, measures_weights, X_init, b, weights=None, numItermax=100, stopThr=1e-6, verbose=False):
151-
"""
152-
Solves the free support (locations of the barycenters are optimized, not the weights) Wasserstein barycenter problem (i.e. the weighted Frechet mean for the 2-Wasserstein distance)
153-
154-
The function solves the Wasserstein barycenter problem when the barycenter measure is constrained to be supported on k atoms.
155-
This problem is considered in [1] (Algorithm 2). There are two differences with the following codes:
156-
- we do not optimize over the weights
157-
- we do not do line search for the locations updates, we use i.e. theta = 1 in [1] (Algorithm 2). This can be seen as a discrete implementation of the fixed-point algorithm of [2] proposed in the continuous setting.
158-
159-
Parameters
160-
----------
161-
data_positions : list of (k_i,d) np.ndarray
162-
The discrete support of a measure supported on k_i locations of a d-dimensional space (k_i can be different for each element of the list)
163-
data_weights : list of (k_i,) np.ndarray
164-
Numpy arrays where each numpy array has k_i non-negatives values summing to one representing the weights of each discrete input measure
165-
166-
X_init : (k,d) np.ndarray
167-
Initialization of the support locations (on k atoms) of the barycenter
168-
b : (k,) np.ndarray
169-
Initialization of the weights of the barycenter (non-negatives, sum to 1)
170-
weights : (k,) np.ndarray
171-
Initialization of the coefficients of the barycenter (non-negatives, sum to 1)
172-
173-
numItermax : int, optional
174-
Max number of iterations
175-
stopThr : float, optional
176-
Stop threshol on error (>0)
177-
verbose : bool, optional
178-
Print information along iterations
179-
log : bool, optional
180-
record log if True
181-
182-
Returns
183-
-------
184-
X : (k,d) np.ndarray
185-
Support locations (on k atoms) of the barycenter
186-
187-
References
188-
----------
189-
190-
.. [1] Cuturi, Marco, and Arnaud Doucet. "Fast computation of Wasserstein barycenters." International Conference on Machine Learning. 2014.
191-
192-
.. [2] Álvarez-Esteban, Pedro C., et al. "A fixed-point approach to barycenters in Wasserstein space." Journal of Mathematical Analysis and Applications 441.2 (2016): 744-762.
193-
194-
"""
195-
196-
iter_count = 0
197-
198-
d = X_init.shape[1]
199-
k = b.size
200-
N = len(measures_locations)
201-
202-
if not weights:
203-
weights = np.ones((N,)) / N
204-
205-
X = X_init
206-
207-
displacement_square_norm = stopThr + 1.
208-
209-
while (displacement_square_norm > stopThr and iter_count < numItermax):
210-
211-
T_sum = np.zeros((k, d))
212-
213-
for (measure_locations_i, measure_weights_i, weight_i) in zip(measures_locations, measures_weights, weights.tolist()):
214-
215-
M_i = ot.dist(X, measure_locations_i)
216-
T_i = ot.emd(b, measure_weights_i, M_i)
217-
T_sum = T_sum + weight_i * np.reshape(1. / b, (-1, 1)) * np.matmul(T_i, measure_locations_i)
218-
219-
displacement_square_norm = np.sum(np.square(X - T_sum))
220-
X = T_sum
221-
222-
if verbose:
223-
print('iteration %d, displacement_square_norm=%f\n', iter_count, displacement_square_norm)
224-
225-
iter_count += 1
226-
227-
return X

test/test_ot.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -135,6 +135,21 @@ def test_lp_barycenter():
135135
np.testing.assert_allclose(bary.sum(), 1)
136136

137137

138+
def test_free_support_barycenter():
139+
140+
measures_locations = [np.array([-1.]).reshape((1,1)), np.array([1.]).reshape((1,1))]
141+
measures_weights = [np.array([1.]), np.array([1.])]
142+
143+
X_init = np.array([-12.]).reshape((1,1))
144+
145+
# obvious barycenter location between two diracs
146+
bar_locations = np.array([0.]).reshape((1,1))
147+
148+
X = ot.lp.free_support_barycenter(measures_locations, measures_weights, X_init)
149+
150+
np.testing.assert_allclose(X, bar_locations, rtol=1e-5, atol=1e-7)
151+
152+
138153
@pytest.mark.skipif(not ot.lp.cvx.cvxopt, reason="No cvxopt available")
139154
def test_lp_barycenter_cvxopt():
140155

0 commit comments

Comments
 (0)