Skip to content

Commit 75c988f

Browse files
committed
do plot_barycenter_1D
1 parent c6cb1cd commit 75c988f

File tree

2 files changed

+58
-57
lines changed

2 files changed

+58
-57
lines changed

examples/plot_barycenter_1D.py

Lines changed: 54 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -11,128 +11,125 @@
1111
import numpy as np
1212
import matplotlib.pylab as pl
1313
import ot
14-
from mpl_toolkits.mplot3d import Axes3D #necessary for 3d plot even if not used
14+
# necessary for 3d plot even if not used
15+
from mpl_toolkits.mplot3d import Axes3D # noqa
1516
from matplotlib.collections import PolyCollection
1617

1718

1819
#%% parameters
1920

20-
n=100 # nb bins
21+
n = 100 # nb bins
2122

2223
# bin positions
23-
x=np.arange(n,dtype=np.float64)
24+
x = np.arange(n, dtype=np.float64)
2425

2526
# Gaussian distributions
26-
a1=ot.datasets.get_1D_gauss(n,m=20,s=5) # m= mean, s= std
27-
a2=ot.datasets.get_1D_gauss(n,m=60,s=8)
27+
a1 = ot.datasets.get_1D_gauss(n, m=20, s=5) # m= mean, s= std
28+
a2 = ot.datasets.get_1D_gauss(n, m=60, s=8)
2829

2930
# creating matrix A containing all distributions
30-
A=np.vstack((a1,a2)).T
31-
nbd=A.shape[1]
31+
A = np.vstack((a1, a2)).T
32+
n_distributions = A.shape[1]
3233

3334
# loss matrix + normalization
34-
M=ot.utils.dist0(n)
35-
M/=M.max()
35+
M = ot.utils.dist0(n)
36+
M /= M.max()
3637

3738
#%% plot the distributions
3839

39-
pl.figure(1)
40-
for i in range(nbd):
41-
pl.plot(x,A[:,i])
40+
pl.figure(1, figsize=(6.4, 3))
41+
for i in range(n_distributions):
42+
pl.plot(x, A[:, i])
4243
pl.title('Distributions')
44+
pl.tight_layout()
4345

4446
#%% barycenter computation
4547

46-
alpha=0.2 # 0<=alpha<=1
47-
weights=np.array([1-alpha,alpha])
48+
alpha = 0.2 # 0<=alpha<=1
49+
weights = np.array([1 - alpha, alpha])
4850

4951
# l2bary
50-
bary_l2=A.dot(weights)
52+
bary_l2 = A.dot(weights)
5153

5254
# wasserstein
53-
reg=1e-3
54-
bary_wass=ot.bregman.barycenter(A,M,reg,weights)
55+
reg = 1e-3
56+
bary_wass = ot.bregman.barycenter(A, M, reg, weights)
5557

5658
pl.figure(2)
5759
pl.clf()
58-
pl.subplot(2,1,1)
59-
for i in range(nbd):
60-
pl.plot(x,A[:,i])
60+
pl.subplot(2, 1, 1)
61+
for i in range(n_distributions):
62+
pl.plot(x, A[:, i])
6163
pl.title('Distributions')
6264

63-
pl.subplot(2,1,2)
64-
pl.plot(x,bary_l2,'r',label='l2')
65-
pl.plot(x,bary_wass,'g',label='Wasserstein')
65+
pl.subplot(2, 1, 2)
66+
pl.plot(x, bary_l2, 'r', label='l2')
67+
pl.plot(x, bary_wass, 'g', label='Wasserstein')
6668
pl.legend()
6769
pl.title('Barycenters')
68-
70+
pl.tight_layout()
6971

7072
#%% barycenter interpolation
7173

72-
nbalpha=11
73-
alphalist=np.linspace(0,1,nbalpha)
74+
n_alpha = 11
75+
alpha_list = np.linspace(0, 1, n_alpha)
7476

7577

76-
B_l2=np.zeros((n,nbalpha))
78+
B_l2 = np.zeros((n, n_alpha))
7779

78-
B_wass=np.copy(B_l2)
80+
B_wass = np.copy(B_l2)
7981

80-
for i in range(0,nbalpha):
81-
alpha=alphalist[i]
82-
weights=np.array([1-alpha,alpha])
83-
B_l2[:,i]=A.dot(weights)
84-
B_wass[:,i]=ot.bregman.barycenter(A,M,reg,weights)
82+
for i in range(0, n_alpha):
83+
alpha = alpha_list[i]
84+
weights = np.array([1 - alpha, alpha])
85+
B_l2[:, i] = A.dot(weights)
86+
B_wass[:, i] = ot.bregman.barycenter(A, M, reg, weights)
8587

8688
#%% plot interpolation
8789

88-
pl.figure(3,(10,5))
90+
pl.figure(3)
8991

90-
#pl.subplot(1,2,1)
91-
cmap=pl.cm.get_cmap('viridis')
92+
cmap = pl.cm.get_cmap('viridis')
9293
verts = []
93-
zs = alphalist
94-
for i,z in enumerate(zs):
95-
ys = B_l2[:,i]
94+
zs = alpha_list
95+
for i, z in enumerate(zs):
96+
ys = B_l2[:, i]
9697
verts.append(list(zip(x, ys)))
9798

9899
ax = pl.gcf().gca(projection='3d')
99100

100-
poly = PolyCollection(verts,facecolors=[cmap(a) for a in alphalist])
101+
poly = PolyCollection(verts, facecolors=[cmap(a) for a in alpha_list])
101102
poly.set_alpha(0.7)
102103
ax.add_collection3d(poly, zs=zs, zdir='y')
103-
104104
ax.set_xlabel('x')
105105
ax.set_xlim3d(0, n)
106106
ax.set_ylabel('$\\alpha$')
107-
ax.set_ylim3d(0,1)
107+
ax.set_ylim3d(0, 1)
108108
ax.set_zlabel('')
109-
ax.set_zlim3d(0, B_l2.max()*1.01)
109+
ax.set_zlim3d(0, B_l2.max() * 1.01)
110110
pl.title('Barycenter interpolation with l2')
111+
pl.tight_layout()
111112

112-
pl.show()
113-
114-
pl.figure(4,(10,5))
115-
116-
#pl.subplot(1,2,1)
117-
cmap=pl.cm.get_cmap('viridis')
113+
pl.figure(4)
114+
cmap = pl.cm.get_cmap('viridis')
118115
verts = []
119-
zs = alphalist
120-
for i,z in enumerate(zs):
121-
ys = B_wass[:,i]
116+
zs = alpha_list
117+
for i, z in enumerate(zs):
118+
ys = B_wass[:, i]
122119
verts.append(list(zip(x, ys)))
123120

124121
ax = pl.gcf().gca(projection='3d')
125122

126-
poly = PolyCollection(verts,facecolors=[cmap(a) for a in alphalist])
123+
poly = PolyCollection(verts, facecolors=[cmap(a) for a in alpha_list])
127124
poly.set_alpha(0.7)
128125
ax.add_collection3d(poly, zs=zs, zdir='y')
129-
130126
ax.set_xlabel('x')
131127
ax.set_xlim3d(0, n)
132128
ax.set_ylabel('$\\alpha$')
133-
ax.set_ylim3d(0,1)
129+
ax.set_ylim3d(0, 1)
134130
ax.set_zlabel('')
135-
ax.set_zlim3d(0, B_l2.max()*1.01)
131+
ax.set_zlim3d(0, B_l2.max() * 1.01)
136132
pl.title('Barycenter interpolation with Wasserstein')
133+
pl.tight_layout()
137134

138-
pl.show()
135+
pl.show()

setup.cfg

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,6 @@
11
[metadata]
22
description-file = README.md
3+
4+
[flake8]
5+
exclude = __init__.py
6+
ignore = E265

0 commit comments

Comments
 (0)