Skip to content

Commit 90efa5a

Browse files
authored
Merge pull request #47 from rflamary/bary
LP Wasserstein barycenter with scipy linear solver and/or cvxopt
2 parents ec79b79 + 54f0b47 commit 90efa5a

File tree

9 files changed

+488
-6
lines changed

9 files changed

+488
-6
lines changed

Makefile

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -58,9 +58,9 @@ notebook :
5858
ipython notebook --matplotlib=inline --notebook-dir=notebooks/
5959

6060
autopep8 :
61-
autopep8 -ir test ot examples
61+
autopep8 -ir test ot examples --jobs -1
6262

6363
aautopep8 :
64-
autopep8 -air test ot examples
64+
autopep8 -air test ot examples --jobs -1
6565

6666
FORCE :

README.md

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,8 @@ This open source Python library provide several solvers for optimization problem
1414
It provides the following solvers:
1515

1616
* OT Network Flow solver for the linear program/ Earth Movers Distance [1].
17-
* Entropic regularization OT solver with Sinkhorn Knopp Algorithm [2] and stabilized version [9][10] with optional GPU implementation (required cudamat).
17+
* Entropic regularization OT solver with Sinkhorn Knopp Algorithm [2] and stabilized version [9][10] with optional GPU implementation (requires cudamat).
18+
* Non regularized Wasserstein barycenters [16] with LP solver.
1819
* Bregman projections for Wasserstein barycenter [3] and unmixing [4].
1920
* Optimal transport for domain adaptation with group lasso regularization [5]
2021
* Conditional gradient [6] and Generalized conditional gradient for regularized OT [7].
@@ -210,3 +211,5 @@ You can also post bug reports and feature requests in Github issues. Make sure t
210211
[14] Knott, M. and Smith, C. S. (1984).[On the optimal mapping of distributions](https://link.springer.com/article/10.1007/BF00934745), Journal of Optimization Theory and Applications Vol 43.
211212

212213
[15] Peyré, G., & Cuturi, M. (2018). [Computational Optimal Transport](https://arxiv.org/pdf/1803.00567.pdf) .
214+
215+
[16] Agueh, M., & Carlier, G. (2011). [Barycenters in the Wasserstein space](https://hal.archives-ouvertes.fr/hal-00637399/document). SIAM Journal on Mathematical Analysis, 43(2), 904-924.
Lines changed: 292 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,292 @@
1+
# -*- coding: utf-8 -*-
2+
"""
3+
=================================================================================
4+
1D Wasserstein barycenter comparison between exact LP and entropic regularization
5+
=================================================================================
6+
7+
This example illustrates the computation of regularized Wasserstein Barycenter
8+
as proposed in [3] and exact LP barycenters using standard LP solver.
9+
10+
It reproduces approximately Figure 3.1 and 3.2 from the following paper:
11+
Cuturi, M., & Peyré, G. (2016). A smoothed dual approach for variational
12+
Wasserstein problems. SIAM Journal on Imaging Sciences, 9(1), 320-343.
13+
14+
[3] Benamou, J. D., Carlier, G., Cuturi, M., Nenna, L., & Peyré, G. (2015).
15+
Iterative Bregman projections for regularized transportation problems
16+
SIAM Journal on Scientific Computing, 37(2), A1111-A1138.
17+
18+
19+
20+
"""
21+
22+
# Author: Remi Flamary <remi.flamary@unice.fr>
23+
#
24+
# License: MIT License
25+
26+
import numpy as np
27+
import matplotlib.pylab as pl
28+
import ot
29+
# necessary for 3d plot even if not used
30+
from mpl_toolkits.mplot3d import Axes3D # noqa
31+
from matplotlib.collections import PolyCollection # noqa
32+
33+
#import ot.lp.cvx as cvx
34+
35+
#
36+
# Generate data
37+
# -------------
38+
39+
#%% parameters
40+
41+
problems = []
42+
43+
n = 100 # nb bins
44+
45+
# bin positions
46+
x = np.arange(n, dtype=np.float64)
47+
48+
# Gaussian distributions
49+
# Gaussian distributions
50+
a1 = ot.datasets.get_1D_gauss(n, m=20, s=5) # m= mean, s= std
51+
a2 = ot.datasets.get_1D_gauss(n, m=60, s=8)
52+
53+
# creating matrix A containing all distributions
54+
A = np.vstack((a1, a2)).T
55+
n_distributions = A.shape[1]
56+
57+
# loss matrix + normalization
58+
M = ot.utils.dist0(n)
59+
M /= M.max()
60+
61+
#
62+
# Plot data
63+
# ---------
64+
65+
#%% plot the distributions
66+
67+
pl.figure(1, figsize=(6.4, 3))
68+
for i in range(n_distributions):
69+
pl.plot(x, A[:, i])
70+
pl.title('Distributions')
71+
pl.tight_layout()
72+
73+
#
74+
# Barycenter computation
75+
# ----------------------
76+
77+
#%% barycenter computation
78+
79+
alpha = 0.5 # 0<=alpha<=1
80+
weights = np.array([1 - alpha, alpha])
81+
82+
# l2bary
83+
bary_l2 = A.dot(weights)
84+
85+
# wasserstein
86+
reg = 1e-3
87+
ot.tic()
88+
bary_wass = ot.bregman.barycenter(A, M, reg, weights)
89+
ot.toc()
90+
91+
92+
ot.tic()
93+
bary_wass2 = ot.lp.barycenter(A, M, weights, solver='interior-point', verbose=True)
94+
ot.toc()
95+
96+
pl.figure(2)
97+
pl.clf()
98+
pl.subplot(2, 1, 1)
99+
for i in range(n_distributions):
100+
pl.plot(x, A[:, i])
101+
pl.title('Distributions')
102+
103+
pl.subplot(2, 1, 2)
104+
pl.plot(x, bary_l2, 'r', label='l2')
105+
pl.plot(x, bary_wass, 'g', label='Reg Wasserstein')
106+
pl.plot(x, bary_wass2, 'b', label='LP Wasserstein')
107+
pl.legend()
108+
pl.title('Barycenters')
109+
pl.tight_layout()
110+
111+
problems.append([A, [bary_l2, bary_wass, bary_wass2]])
112+
113+
#%% parameters
114+
115+
a1 = 1.0 * (x > 10) * (x < 50)
116+
a2 = 1.0 * (x > 60) * (x < 80)
117+
118+
a1 /= a1.sum()
119+
a2 /= a2.sum()
120+
121+
# creating matrix A containing all distributions
122+
A = np.vstack((a1, a2)).T
123+
n_distributions = A.shape[1]
124+
125+
# loss matrix + normalization
126+
M = ot.utils.dist0(n)
127+
M /= M.max()
128+
129+
130+
#%% plot the distributions
131+
132+
pl.figure(1, figsize=(6.4, 3))
133+
for i in range(n_distributions):
134+
pl.plot(x, A[:, i])
135+
pl.title('Distributions')
136+
pl.tight_layout()
137+
138+
#
139+
# Barycenter computation
140+
# ----------------------
141+
142+
#%% barycenter computation
143+
144+
alpha = 0.5 # 0<=alpha<=1
145+
weights = np.array([1 - alpha, alpha])
146+
147+
# l2bary
148+
bary_l2 = A.dot(weights)
149+
150+
# wasserstein
151+
reg = 1e-3
152+
ot.tic()
153+
bary_wass = ot.bregman.barycenter(A, M, reg, weights)
154+
ot.toc()
155+
156+
157+
ot.tic()
158+
bary_wass2 = ot.lp.barycenter(A, M, weights, solver='interior-point', verbose=True)
159+
ot.toc()
160+
161+
162+
problems.append([A, [bary_l2, bary_wass, bary_wass2]])
163+
164+
pl.figure(2)
165+
pl.clf()
166+
pl.subplot(2, 1, 1)
167+
for i in range(n_distributions):
168+
pl.plot(x, A[:, i])
169+
pl.title('Distributions')
170+
171+
pl.subplot(2, 1, 2)
172+
pl.plot(x, bary_l2, 'r', label='l2')
173+
pl.plot(x, bary_wass, 'g', label='Reg Wasserstein')
174+
pl.plot(x, bary_wass2, 'b', label='LP Wasserstein')
175+
pl.legend()
176+
pl.title('Barycenters')
177+
pl.tight_layout()
178+
179+
#%% parameters
180+
181+
a1 = np.zeros(n)
182+
a2 = np.zeros(n)
183+
184+
a1[10] = .25
185+
a1[20] = .5
186+
a1[30] = .25
187+
a2[80] = 1
188+
189+
190+
a1 /= a1.sum()
191+
a2 /= a2.sum()
192+
193+
# creating matrix A containing all distributions
194+
A = np.vstack((a1, a2)).T
195+
n_distributions = A.shape[1]
196+
197+
# loss matrix + normalization
198+
M = ot.utils.dist0(n)
199+
M /= M.max()
200+
201+
202+
#%% plot the distributions
203+
204+
pl.figure(1, figsize=(6.4, 3))
205+
for i in range(n_distributions):
206+
pl.plot(x, A[:, i])
207+
pl.title('Distributions')
208+
pl.tight_layout()
209+
210+
#
211+
# Barycenter computation
212+
# ----------------------
213+
214+
#%% barycenter computation
215+
216+
alpha = 0.5 # 0<=alpha<=1
217+
weights = np.array([1 - alpha, alpha])
218+
219+
# l2bary
220+
bary_l2 = A.dot(weights)
221+
222+
# wasserstein
223+
reg = 1e-3
224+
ot.tic()
225+
bary_wass = ot.bregman.barycenter(A, M, reg, weights)
226+
ot.toc()
227+
228+
229+
ot.tic()
230+
bary_wass2 = ot.lp.barycenter(A, M, weights, solver='interior-point', verbose=True)
231+
ot.toc()
232+
233+
234+
problems.append([A, [bary_l2, bary_wass, bary_wass2]])
235+
236+
pl.figure(2)
237+
pl.clf()
238+
pl.subplot(2, 1, 1)
239+
for i in range(n_distributions):
240+
pl.plot(x, A[:, i])
241+
pl.title('Distributions')
242+
243+
pl.subplot(2, 1, 2)
244+
pl.plot(x, bary_l2, 'r', label='l2')
245+
pl.plot(x, bary_wass, 'g', label='Reg Wasserstein')
246+
pl.plot(x, bary_wass2, 'b', label='LP Wasserstein')
247+
pl.legend()
248+
pl.title('Barycenters')
249+
pl.tight_layout()
250+
251+
252+
#
253+
# Final figure
254+
# ------------
255+
#
256+
257+
#%% plot
258+
259+
nbm = len(problems)
260+
nbm2 = (nbm // 2)
261+
262+
263+
pl.figure(2, (20, 6))
264+
pl.clf()
265+
266+
for i in range(nbm):
267+
268+
A = problems[i][0]
269+
bary_l2 = problems[i][1][0]
270+
bary_wass = problems[i][1][1]
271+
bary_wass2 = problems[i][1][2]
272+
273+
pl.subplot(2, nbm, 1 + i)
274+
for j in range(n_distributions):
275+
pl.plot(x, A[:, j])
276+
if i == nbm2:
277+
pl.title('Distributions')
278+
pl.xticks(())
279+
pl.yticks(())
280+
281+
pl.subplot(2, nbm, 1 + i + nbm)
282+
283+
pl.plot(x, bary_l2, 'r', label='L2 (Euclidean)')
284+
pl.plot(x, bary_wass, 'g', label='Reg Wasserstein')
285+
pl.plot(x, bary_wass2, 'b', label='LP Wasserstein')
286+
if i == nbm - 1:
287+
pl.legend()
288+
if i == nbm2:
289+
pl.title('Barycenters')
290+
291+
pl.xticks(())
292+
pl.yticks(())

ot/bregman.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -839,11 +839,13 @@ def barycenter(A, M, reg, weights=None, numItermax=1000,
839839
Parameters
840840
----------
841841
A : np.ndarray (d,n)
842-
n training distributions of size d
842+
n training distributions a_i of size d
843843
M : np.ndarray (d,d)
844844
loss matrix for OT
845845
reg : float
846846
Regularization term >0
847+
weights : np.ndarray (n,)
848+
Weights of each histogram a_i on the simplex (barycentric coodinates)
847849
numItermax : int, optional
848850
Max number of iterations
849851
stopThr : float, optional

ot/lp/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,9 +11,12 @@
1111

1212
import numpy as np
1313

14+
from .import cvx
15+
1416
# import compiled emd
1517
from .emd_wrap import emd_c, check_result
1618
from ..utils import parmap
19+
from .cvx import barycenter
1720

1821

1922
def emd(a, b, M, numItermax=100000, log=False):

0 commit comments

Comments
 (0)