Skip to content

Commit da8f611

Browse files
committed
add example
1 parent bd1af44 commit da8f611

File tree

1 file changed

+284
-0
lines changed

1 file changed

+284
-0
lines changed
Lines changed: 284 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,284 @@
1+
# -*- coding: utf-8 -*-
2+
"""
3+
=================================================================================
4+
1D Wasserstein barycenter comparison between exact LP and entropic regularization
5+
=================================================================================
6+
7+
This example illustrates the computation of regularized Wassersyein Barycenter
8+
as proposed in [3].
9+
10+
11+
[3] Benamou, J. D., Carlier, G., Cuturi, M., Nenna, L., & Peyré, G. (2015).
12+
Iterative Bregman projections for regularized transportation problems
13+
SIAM Journal on Scientific Computing, 37(2), A1111-A1138.
14+
15+
"""
16+
17+
# Author: Remi Flamary <remi.flamary@unice.fr>
18+
#
19+
# License: MIT License
20+
21+
import numpy as np
22+
import matplotlib.pylab as pl
23+
import ot
24+
# necessary for 3d plot even if not used
25+
from mpl_toolkits.mplot3d import Axes3D # noqa
26+
from matplotlib.collections import PolyCollection # noqa
27+
28+
#import ot.lp.cvx as cvx
29+
30+
#
31+
# Generate data
32+
# -------------
33+
34+
#%% parameters
35+
36+
problems = []
37+
38+
n = 100 # nb bins
39+
40+
# bin positions
41+
x = np.arange(n, dtype=np.float64)
42+
43+
# Gaussian distributions
44+
# Gaussian distributions
45+
a1 = ot.datasets.get_1D_gauss(n, m=20, s=5) # m= mean, s= std
46+
a2 = ot.datasets.get_1D_gauss(n, m=60, s=8)
47+
48+
# creating matrix A containing all distributions
49+
A = np.vstack((a1, a2)).T
50+
n_distributions = A.shape[1]
51+
52+
# loss matrix + normalization
53+
M = ot.utils.dist0(n)
54+
M /= M.max()
55+
56+
#
57+
# Plot data
58+
# ---------
59+
60+
#%% plot the distributions
61+
62+
pl.figure(1, figsize=(6.4, 3))
63+
for i in range(n_distributions):
64+
pl.plot(x, A[:, i])
65+
pl.title('Distributions')
66+
pl.tight_layout()
67+
68+
#
69+
# Barycenter computation
70+
# ----------------------
71+
72+
#%% barycenter computation
73+
74+
alpha = 0.5 # 0<=alpha<=1
75+
weights = np.array([1 - alpha, alpha])
76+
77+
# l2bary
78+
bary_l2 = A.dot(weights)
79+
80+
# wasserstein
81+
reg = 1e-3
82+
ot.tic()
83+
bary_wass = ot.bregman.barycenter(A, M, reg, weights)
84+
ot.toc()
85+
86+
87+
ot.tic()
88+
bary_wass2 = ot.lp.barycenter(A, M, weights, solver='interior-point', verbose=True)
89+
ot.toc()
90+
91+
pl.figure(2)
92+
pl.clf()
93+
pl.subplot(2, 1, 1)
94+
for i in range(n_distributions):
95+
pl.plot(x, A[:, i])
96+
pl.title('Distributions')
97+
98+
pl.subplot(2, 1, 2)
99+
pl.plot(x, bary_l2, 'r', label='l2')
100+
pl.plot(x, bary_wass, 'g', label='Reg Wasserstein')
101+
pl.plot(x, bary_wass2, 'b', label='LP Wasserstein')
102+
pl.legend()
103+
pl.title('Barycenters')
104+
pl.tight_layout()
105+
106+
problems.append([A, [bary_l2, bary_wass, bary_wass2]])
107+
108+
#%% parameters
109+
110+
a1 = 1.0 * (x > 10) * (x < 50)
111+
a2 = 1.0 * (x > 60) * (x < 80)
112+
113+
a1 /= a1.sum()
114+
a2 /= a2.sum()
115+
116+
# creating matrix A containing all distributions
117+
A = np.vstack((a1, a2)).T
118+
n_distributions = A.shape[1]
119+
120+
# loss matrix + normalization
121+
M = ot.utils.dist0(n)
122+
M /= M.max()
123+
124+
125+
#%% plot the distributions
126+
127+
pl.figure(1, figsize=(6.4, 3))
128+
for i in range(n_distributions):
129+
pl.plot(x, A[:, i])
130+
pl.title('Distributions')
131+
pl.tight_layout()
132+
133+
#
134+
# Barycenter computation
135+
# ----------------------
136+
137+
#%% barycenter computation
138+
139+
alpha = 0.5 # 0<=alpha<=1
140+
weights = np.array([1 - alpha, alpha])
141+
142+
# l2bary
143+
bary_l2 = A.dot(weights)
144+
145+
# wasserstein
146+
reg = 1e-3
147+
ot.tic()
148+
bary_wass = ot.bregman.barycenter(A, M, reg, weights)
149+
ot.toc()
150+
151+
152+
ot.tic()
153+
bary_wass2 = ot.lp.barycenter(A, M, weights, solver='interior-point', verbose=True)
154+
ot.toc()
155+
156+
157+
problems.append([A, [bary_l2, bary_wass, bary_wass2]])
158+
159+
pl.figure(2)
160+
pl.clf()
161+
pl.subplot(2, 1, 1)
162+
for i in range(n_distributions):
163+
pl.plot(x, A[:, i])
164+
pl.title('Distributions')
165+
166+
pl.subplot(2, 1, 2)
167+
pl.plot(x, bary_l2, 'r', label='l2')
168+
pl.plot(x, bary_wass, 'g', label='Reg Wasserstein')
169+
pl.plot(x, bary_wass2, 'b', label='LP Wasserstein')
170+
pl.legend()
171+
pl.title('Barycenters')
172+
pl.tight_layout()
173+
174+
#%% parameters
175+
176+
a1 = np.zeros(n)
177+
a2 = np.zeros(n)
178+
179+
a1[10] = .25
180+
a1[20] = .5
181+
a1[30] = .25
182+
a2[80] = 1
183+
184+
185+
a1 /= a1.sum()
186+
a2 /= a2.sum()
187+
188+
# creating matrix A containing all distributions
189+
A = np.vstack((a1, a2)).T
190+
n_distributions = A.shape[1]
191+
192+
# loss matrix + normalization
193+
M = ot.utils.dist0(n)
194+
M /= M.max()
195+
196+
197+
#%% plot the distributions
198+
199+
pl.figure(1, figsize=(6.4, 3))
200+
for i in range(n_distributions):
201+
pl.plot(x, A[:, i])
202+
pl.title('Distributions')
203+
pl.tight_layout()
204+
205+
#
206+
# Barycenter computation
207+
# ----------------------
208+
209+
#%% barycenter computation
210+
211+
alpha = 0.5 # 0<=alpha<=1
212+
weights = np.array([1 - alpha, alpha])
213+
214+
# l2bary
215+
bary_l2 = A.dot(weights)
216+
217+
# wasserstein
218+
reg = 1e-3
219+
ot.tic()
220+
bary_wass = ot.bregman.barycenter(A, M, reg, weights)
221+
ot.toc()
222+
223+
224+
ot.tic()
225+
bary_wass2 = ot.lp.barycenter(A, M, weights, solver='interior-point', verbose=True)
226+
ot.toc()
227+
228+
229+
problems.append([A, [bary_l2, bary_wass, bary_wass2]])
230+
231+
pl.figure(2)
232+
pl.clf()
233+
pl.subplot(2, 1, 1)
234+
for i in range(n_distributions):
235+
pl.plot(x, A[:, i])
236+
pl.title('Distributions')
237+
238+
pl.subplot(2, 1, 2)
239+
pl.plot(x, bary_l2, 'r', label='l2')
240+
pl.plot(x, bary_wass, 'g', label='Reg Wasserstein')
241+
pl.plot(x, bary_wass2, 'b', label='LP Wasserstein')
242+
pl.legend()
243+
pl.title('Barycenters')
244+
pl.tight_layout()
245+
246+
247+
#
248+
# Final figure
249+
# ------------
250+
#
251+
252+
#%% plot
253+
254+
nbm = len(problems)
255+
nbm2 = (nbm // 2)
256+
257+
258+
pl.figure(2, (20, 6))
259+
pl.clf()
260+
261+
for i in range(nbm):
262+
263+
A = problems[i][0]
264+
bary_l2 = problems[i][1][0]
265+
bary_wass = problems[i][1][1]
266+
bary_wass2 = problems[i][1][2]
267+
268+
pl.subplot(2, nbm, 1 + i)
269+
for j in range(n_distributions):
270+
pl.plot(x, A[:, j])
271+
if i == nbm2:
272+
pl.title('Distributions')
273+
pl.xticks(())
274+
pl.yticks(())
275+
276+
pl.subplot(2, nbm, 1 + i)
277+
278+
pl.plot(x, bary_l2, 'r', label='L2 (Euclidean)')
279+
pl.plot(x, bary_wass, 'g', label='Reg Wasserstein')
280+
pl.plot(x, bary_wass2, 'b', label='LP Wasserstein')
281+
if i == nbm - 1:
282+
pl.legend()
283+
if i == nbm2:
284+
pl.title('Barycenters')

0 commit comments

Comments
 (0)