Skip to content

Commit 95b2a58

Browse files
committed
pep8 + pimp plot1D_mat rendering
1 parent d0258f1 commit 95b2a58

File tree

2 files changed

+75
-84
lines changed

2 files changed

+75
-84
lines changed

examples/plot_OT_1D.py

Lines changed: 22 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -8,49 +8,50 @@
88
"""
99

1010
import numpy as np
11-
import matplotlib.pylab as pl
11+
import matplotlib.pylab as plt
1212
import ot
1313
from ot.datasets import get_1D_gauss as gauss
1414

15-
1615
#%% parameters
1716

18-
n=100 # nb bins
17+
n = 100 # nb bins
1918

2019
# bin positions
21-
x=np.arange(n,dtype=np.float64)
20+
x = np.arange(n, dtype=np.float64)
2221

2322
# Gaussian distributions
24-
a=gauss(n,m=20,s=5) # m= mean, s= std
25-
b=gauss(n,m=60,s=10)
23+
a = gauss(n, m=20, s=5) # m= mean, s= std
24+
b = gauss(n, m=60, s=10)
2625

2726
# loss matrix
28-
M=ot.dist(x.reshape((n,1)),x.reshape((n,1)))
29-
M/=M.max()
27+
M = ot.dist(x.reshape((n, 1)), x.reshape((n, 1)))
28+
M /= M.max()
3029

3130
#%% plot the distributions
3231

33-
pl.figure(1)
34-
pl.plot(x,a,'b',label='Source distribution')
35-
pl.plot(x,b,'r',label='Target distribution')
36-
pl.legend()
32+
plt.figure(1)
33+
plt.plot(x, a, 'b', label='Source distribution')
34+
plt.plot(x, b, 'r', label='Target distribution')
35+
plt.legend()
3736

3837
#%% plot distributions and loss matrix
3938

40-
pl.figure(2)
41-
ot.plot.plot1D_mat(a,b,M,'Cost matrix M')
39+
plt.figure(2, figsize=(5, 5))
40+
ot.plot.plot1D_mat(a, b, M, 'Cost matrix M')
4241

4342
#%% EMD
4443

45-
G0=ot.emd(a,b,M)
44+
G0 = ot.emd(a, b, M)
4645

47-
pl.figure(3)
48-
ot.plot.plot1D_mat(a,b,G0,'OT matrix G0')
46+
plt.figure(3, figsize=(5, 5))
47+
ot.plot.plot1D_mat(a, b, G0, 'OT matrix G0')
4948

5049
#%% Sinkhorn
5150

52-
lambd=1e-3
53-
Gs=ot.sinkhorn(a,b,M,lambd,verbose=True)
51+
lambd = 1e-3
52+
Gs = ot.sinkhorn(a, b, M, lambd, verbose=True)
53+
54+
plt.figure(4, figsize=(5, 5))
55+
ot.plot.plot1D_mat(a, b, Gs, 'OT matrix Sinkhorn')
5456

55-
pl.figure(4)
56-
ot.plot.plot1D_mat(a,b,Gs,'OT matrix Sinkhorn')
57+
plt.show()

ot/plot.py

Lines changed: 53 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -4,89 +4,79 @@
44

55

66
import numpy as np
7-
import matplotlib.pylab as pl
7+
import matplotlib.pylab as plt
88
from matplotlib import gridspec
99

1010

11-
def plot1D_mat(a,b,M,title=''):
12-
""" Plot matrix M with the source and target 1D distribution
13-
14-
Creates a subplot with the source distribution a on the left and
11+
def plot1D_mat(a, b, M, title=''):
12+
""" Plot matrix M with the source and target 1D distribution
13+
14+
Creates a subplot with the source distribution a on the left and
1515
target distribution b on the tot. The matrix M is shown in between.
16-
17-
16+
17+
1818
Parameters
1919
----------
20-
21-
a : np.array (na,)
20+
a : np.array, shape (na,)
2221
Source distribution
23-
b : np.array (nb,)
24-
Target distribution
25-
M : np.array (na,nb)
22+
b : np.array, shape (nb,)
23+
Target distribution
24+
M : np.array, shape (na,nb)
2625
Matrix to plot
27-
28-
29-
3026
"""
31-
32-
na=M.shape[0]
33-
nb=M.shape[1]
34-
27+
na, nb = M.shape
28+
3529
gs = gridspec.GridSpec(3, 3)
36-
37-
38-
xa=np.arange(na)
39-
xb=np.arange(nb)
40-
41-
42-
ax1=pl.subplot(gs[0,1:])
43-
pl.plot(xb,b,'r',label='Target distribution')
44-
pl.yticks(())
45-
pl.title(title)
46-
47-
#pl.axis('off')
48-
49-
ax2=pl.subplot(gs[1:,0])
50-
pl.plot(a,xa,'b',label='Source distribution')
51-
pl.gca().invert_xaxis()
52-
pl.gca().invert_yaxis()
53-
pl.xticks(())
54-
#pl.ylim((0,n))
55-
#pl.axis('off')
56-
57-
pl.subplot(gs[1:,1:],sharex=ax1,sharey=ax2)
58-
pl.imshow(M,interpolation='nearest')
59-
60-
pl.xlim((0,nb))
61-
62-
63-
def plot2D_samples_mat(xs,xt,G,thr=1e-8,**kwargs):
30+
31+
xa = np.arange(na)
32+
xb = np.arange(nb)
33+
34+
ax1 = plt.subplot(gs[0, 1:])
35+
plt.plot(xb, b, 'r', label='Target distribution')
36+
plt.yticks(())
37+
plt.title(title)
38+
39+
ax2 = plt.subplot(gs[1:, 0])
40+
plt.plot(a, xa, 'b', label='Source distribution')
41+
plt.gca().invert_xaxis()
42+
plt.gca().invert_yaxis()
43+
plt.xticks(())
44+
45+
plt.subplot(gs[1:, 1:], sharex=ax1, sharey=ax2)
46+
plt.imshow(M, interpolation='nearest')
47+
plt.axis('off')
48+
49+
plt.xlim((0, nb))
50+
plt.tight_layout()
51+
plt.subplots_adjust(wspace=0., hspace=0.2)
52+
53+
54+
def plot2D_samples_mat(xs, xt, G, thr=1e-8, **kwargs):
6455
""" Plot matrix M in 2D with lines using alpha values
65-
66-
Plot lines between source and target 2D samples with a color
56+
57+
Plot lines between source and target 2D samples with a color
6758
proportional to the value of the matrix G between samples.
68-
69-
59+
60+
7061
Parameters
7162
----------
72-
73-
xs : np.array (ns,2)
63+
xs : ndarray, shape (ns,2)
7464
Source samples positions
75-
b : np.array (nt,2)
65+
b : ndarray, shape (nt,2)
7666
Target samples positions
77-
G : np.array (na,nb)
67+
G : ndarray, shape (na,nb)
7868
OT matrix
7969
thr : float, optional
8070
threshold above which the line is drawn
8171
**kwargs : dict
82-
paameters given to the plot functions (default color is black if nothing given)
83-
72+
paameters given to the plot functions (default color is black if
73+
nothing given)
8474
"""
85-
if ('color' not in kwargs) and ('c' not in kwargs):
86-
kwargs['color']='k'
87-
mx=G.max()
75+
if ('color' not in kwargs) and ('c' not in kwargs):
76+
kwargs['color'] = 'k'
77+
mx = G.max()
8878
for i in range(xs.shape[0]):
8979
for j in range(xt.shape[0]):
90-
if G[i,j]/mx>thr:
91-
pl.plot([xs[i,0],xt[j,0]],[xs[i,1],xt[j,1]],alpha=G[i,j]/mx,**kwargs)
92-
80+
if G[i, j] / mx > thr:
81+
plt.plot([xs[i, 0], xt[j, 0]], [xs[i, 1], xt[j, 1]],
82+
alpha=G[i, j] / mx, **kwargs)

0 commit comments

Comments
 (0)