Skip to content

Commit 3f23fa1

Browse files
committed
free support barycenter
1 parent 6492e95 commit 3f23fa1

File tree

2 files changed

+83
-12
lines changed

2 files changed

+83
-12
lines changed
Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
# -*- coding: utf-8 -*-
"""
==========================================================
2D Wasserstein barycenters between empirical distributions
==========================================================

Illustration of 2D Wasserstein barycenters between distributions that are
weighted sums of Diracs.

"""

# Author: Vivien Seguy <vivien.seguy@iip.ist.i.kyoto-u.ac.jp>
#
# License: MIT License

import numpy as np
import matplotlib.pylab as pl
import ot
import ot.plot


##############################################################################
# Generate data
# -------------
#%% parameters and data generation
N = 4  # number of input measures
d = 2  # dimension of the ambient space
measures_locations = []
measures_weights = []

for i in range(N):

    n = np.random.randint(low=1, high=20)  # nb samples (high is exclusive)

    mu = np.random.normal(0., 1., (d,))  # Gaussian mean
    # A covariance matrix must be symmetric positive semi-definite, so it is
    # built as A A^T instead of using a raw random matrix directly.
    A = np.random.normal(0., 1., (d, d))
    cov = np.dot(A, A.T)

    xs = ot.datasets.make_2D_samples_gauss(n, mu, cov)  # Dirac locations
    b = np.random.uniform(0., 1., n)
    b = b / np.sum(b)  # Dirac weights, normalized to sum to 1

    measures_locations.append(xs)
    measures_weights.append(b)

k = 10  # number of Diracs of the barycenter
X_init = np.random.normal(0., 1., (k, d))  # initial barycenter locations
b_init = np.ones((k,)) / k  # uniform barycenter weights


##############################################################################
# Compute free support barycenter
# -------------------------------
X = ot.lp.free_support_barycenter(measures_locations, measures_weights,
                                  X_init, b_init)


##############################################################################
# Plot data
# ---------

#%% plot samples

pl.figure(1)
for (xs, b) in zip(measures_locations, measures_weights):
    # One random RGB color (components in [0, 1]) per input measure.
    rgb = np.random.uniform(0., 1., size=(3,))
    # The weights sum to 1, so they are scaled up to get visible marker areas.
    pl.scatter(xs[:, 0], xs[:, 1], s=b * 1000, color=tuple(rgb),
               label='Data measures')
# Plot the computed barycenter support with its own weights.
pl.scatter(X[:, 0], X[:, 1], s=b_init * 1000, c='black',
           label='2-Wasserstein barycenter')
pl.legend(loc=0)
pl.title('Data measures and their barycenter')
pl.show()

ot/lp/cvx.py

Lines changed: 17 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ def scipy_sparse_to_spmatrix(A):
2727

2828

2929
def barycenter(A, M, weights=None, verbose=False, log=False, solver='interior-point'):
30-
"""Compute the entropic regularized wasserstein barycenter of distributions A
30+
"""Compute the Wasserstein barycenter of distributions A
3131
3232
The function solves the following optimization problem [16]:
3333
@@ -149,7 +149,7 @@ def barycenter(A, M, weights=None, verbose=False, log=False, solver='interior-po
149149

150150

151151

152-
def free_support_barycenter(data_positions, data_weights, X_init, b_init, lamda, numItermax=100, stopThr=1e-5, verbose=False, log=False, **kwargs):
152+
def free_support_barycenter(measures_locations, measures_weights, X_init, b_init, weights=None, numItermax=100, stopThr=1e-6, verbose=False):
153153

154154
"""
155155
Solves the free support (locations of the barycenters are optimized, not the weights) Wasserstein barycenter problem (i.e. the weighted Frechet mean for the 2-Wasserstein distance)
@@ -170,7 +170,7 @@ def free_support_barycenter(data_positions, data_weights, X_init, b_init, lamda,
170170
Initialization of the support locations (on k atoms) of the barycenter
171171
b_init : (k,) np.ndarray
172172
Initialization of the weights of the barycenter (non-negatives, sum to 1)
173-
lambda : (k,) np.ndarray
173+
weights : (N,) np.ndarray, optional
174174
Weight of each input measure in the barycenter (non-negatives, sum to 1); uniform weights 1/N are used when not given
175175
176176
numItermax : int, optional
@@ -200,25 +200,30 @@ def free_support_barycenter(data_positions, data_weights, X_init, b_init, lamda,
200200

201201
d = X_init.shape[1]
202202
k = b_init.size
203-
N = len(data_positions)
203+
N = len(measures_locations)
204+
205+
if not weights:
206+
weights = np.ones((N,))/N
204207

205208
X = X_init
206209

207-
displacement_square_norm = 1e3
210+
displacement_square_norm = stopThr+1.
208211

209212
while ( displacement_square_norm > stopThr and iter_count < numItermax ):
210213

211214
T_sum = np.zeros((k, d))
212215

213-
for (data_positions_i, data_weights_i) in zip(data_positions, data_weights):
214-
M_i = ot.dist(X, data_positions_i)
215-
T_i = ot.emd(b_init, data_weights_i, M_i)
216-
T_sum += np.reshape(1. / b_init, (-1, 1)) * np.matmul(T_i, data_positions_i)
216+
for (measure_locations_i, measure_weights_i, weight_i) in zip(measures_locations, measures_weights, weights.tolist()):
217+
218+
M_i = ot.dist(X, measure_locations_i)
219+
T_i = ot.emd(b_init, measure_weights_i, M_i)
220+
T_sum += np.reshape(1. / b_init, (-1, 1)) * np.matmul(T_i, measure_locations_i)
217221

218-
X_previous = X
219-
X = T_sum / N
222+
displacement_square_norm = np.sum(np.square(X-T_sum))
223+
X = T_sum
220224

221-
displacement_square_norm = np.sum(np.square(X-X_previous))
225+
if verbose:
226+
print('iteration %d, displacement_square_norm=%f\n', iter_count, displacement_square_norm)
222227

223228
iter_count += 1
224229

0 commit comments

Comments
 (0)