Skip to content

Commit f86dbde

Browse files
committed
pep8 normalization
1 parent d99abf0 commit f86dbde

File tree

2 files changed

+70
-68
lines changed

2 files changed

+70
-68
lines changed
Lines changed: 35 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
1+
22
#%%
33
# -*- coding: utf-8 -*-
44
"""
@@ -32,61 +32,61 @@
3232
f4 = 1 - pl.imread('../data/tooth.png')[:, :, 2]
3333

3434
A = []
35-
f1=f1/np.sum(f1)
36-
f2=f2/np.sum(f2)
37-
f3=f3/np.sum(f3)
38-
f4=f4/np.sum(f4)
35+
f1 = f1 / np.sum(f1)
36+
f2 = f2 / np.sum(f2)
37+
f3 = f3 / np.sum(f3)
38+
f4 = f4 / np.sum(f4)
3939
A.append(f1)
4040
A.append(f2)
4141
A.append(f3)
4242
A.append(f4)
43-
A=np.array(A)
43+
A = np.array(A)
4444

4545
nb_images = 5
4646

4747
# those are the four corners coordinates that will be interpolated by bilinear
4848
# interpolation
49-
v1=np.array((1,0,0,0))
50-
v2=np.array((0,1,0,0))
51-
v3=np.array((0,0,1,0))
52-
v4=np.array((0,0,0,1))
49+
v1 = np.array((1, 0, 0, 0))
50+
v2 = np.array((0, 1, 0, 0))
51+
v3 = np.array((0, 0, 1, 0))
52+
v4 = np.array((0, 0, 0, 1))
5353

5454

5555
##############################################################################
5656
# Barycenter computation and visualization
5757
# ----------------------------------------
5858
#
5959

60-
pl.figure(figsize=(10,10))
60+
pl.figure(figsize=(10, 10))
6161
pl.title('Convolutional Wasserstein Barycenters in POT')
62-
cm='Blues'
62+
cm = 'Blues'
6363
# regularization parameter
64-
reg=0.004
64+
reg = 0.004
6565
for i in range(nb_images):
6666
for j in range(nb_images):
67-
pl.subplot(nb_images,nb_images,i*nb_images+j+1)
68-
tx=float(i)/(nb_images-1)
69-
ty=float(j)/(nb_images-1)
70-
67+
pl.subplot(nb_images, nb_images, i * nb_images + j + 1)
68+
tx = float(i) / (nb_images - 1)
69+
ty = float(j) / (nb_images - 1)
70+
7171
# weights are constructed by bilinear interpolation
72-
tmp1=(1-tx)*v1+tx*v2
73-
tmp2=(1-tx)*v3+tx*v4
74-
weights=(1-ty)*tmp1+ty*tmp2
75-
76-
if i==0 and j==0:
77-
pl.imshow(f1,cmap=cm)
78-
pl.axis('off')
79-
elif i==0 and j==(nb_images-1):
80-
pl.imshow(f3,cmap=cm)
81-
pl.axis('off')
82-
elif i==(nb_images-1) and j==0:
83-
pl.imshow(f2,cmap=cm)
84-
pl.axis('off')
85-
elif i==(nb_images-1) and j==(nb_images-1):
86-
pl.imshow(f4,cmap=cm)
87-
pl.axis('off')
72+
tmp1 = (1 - tx) * v1 + tx * v2
73+
tmp2 = (1 - tx) * v3 + tx * v4
74+
weights = (1 - ty) * tmp1 + ty * tmp2
75+
76+
if i == 0 and j == 0:
77+
pl.imshow(f1, cmap=cm)
78+
pl.axis('off')
79+
elif i == 0 and j == (nb_images - 1):
80+
pl.imshow(f3, cmap=cm)
81+
pl.axis('off')
82+
elif i == (nb_images - 1) and j == 0:
83+
pl.imshow(f2, cmap=cm)
84+
pl.axis('off')
85+
elif i == (nb_images - 1) and j == (nb_images - 1):
86+
pl.imshow(f4, cmap=cm)
87+
pl.axis('off')
8888
else:
8989
# call to barycenter computation
90-
pl.imshow(ot.convolutional_barycenter2d(A,reg,weights),cmap=cm)
90+
pl.imshow(ot.convolutional_barycenter2d(A, reg, weights), cmap=cm)
9191
pl.axis('off')
92-
pl.show()
92+
pl.show()

ot/bregman.py

Lines changed: 35 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -918,7 +918,8 @@ def barycenter(A, M, reg, weights=None, numItermax=1000,
918918
else:
919919
return geometricBar(weights, UKv)
920920

921-
def convolutional_barycenter2d(A,reg,weights=None,numItermax = 10000, stopThr=1e-9, verbose=False, log=False):
921+
922+
def convolutional_barycenter2d(A, reg, weights=None, numItermax=10000, stopThr=1e-9, verbose=False, log=False):
922923
"""Compute the entropic regularized wasserstein barycenter of distributions A
923924
where A is a collection of 2D images.
924925
@@ -979,51 +980,52 @@ def convolutional_barycenter2d(A,reg,weights=None,numItermax = 10000, stopThr=1e
979980
if log:
980981
log = {'err': []}
981982

982-
b=np.zeros_like(A[0,:,:])
983-
U=np.ones_like(A)
984-
KV=np.ones_like(A)
985-
threshold = 1e-30 # in order to avoids numerical precision issues
983+
b = np.zeros_like(A[0, :, :])
984+
U = np.ones_like(A)
985+
KV = np.ones_like(A)
986+
threshold = 1e-30 # in order to avoids numerical precision issues
986987

987988
cpt = 0
988-
err=1
989-
990-
# build the convolution operator
991-
t = np.linspace(0,1,A.shape[1])
992-
[Y,X] = np.meshgrid(t,t)
993-
xi1 = np.exp(-(X-Y)**2/reg)
994-
K = lambda x: np.dot(np.dot(xi1,x),xi1)
995-
996-
while (err>stopThr and cpt<numItermax):
997-
998-
bold=b
999-
cpt = cpt +1
1000-
1001-
b=np.zeros_like(A[0,:,:])
989+
err = 1
990+
991+
# build the convolution operator
992+
t = np.linspace(0, 1, A.shape[1])
993+
[Y, X] = np.meshgrid(t, t)
994+
xi1 = np.exp(-(X - Y)**2 / reg)
995+
996+
def K(x): return np.dot(np.dot(xi1, x), xi1)
997+
998+
while (err > stopThr and cpt < numItermax):
999+
1000+
bold = b
1001+
cpt = cpt + 1
1002+
1003+
b = np.zeros_like(A[0, :, :])
10021004
for r in range(A.shape[0]):
1003-
KV[r,:,:]=K(A[r,:,:]/np.maximum(threshold,K(U[r,:,:])))
1004-
b += weights[r] * np.log(np.maximum(threshold, U[r,:,:]*KV[r,:,:]))
1005+
KV[r, :, :] = K(A[r, :, :] / np.maximum(threshold, K(U[r, :, :])))
1006+
b += weights[r] * np.log(np.maximum(threshold, U[r, :, :] * KV[r, :, :]))
10051007
b = np.exp(b)
10061008
for r in range(A.shape[0]):
1007-
U[r,:,:]=b/np.maximum(threshold,KV[r,:,:])
1008-
1009-
if cpt%10==1:
1010-
err=np.sum(np.abs(bold-b))
1009+
U[r, :, :] = b / np.maximum(threshold, KV[r, :, :])
1010+
1011+
if cpt % 10 == 1:
1012+
err = np.sum(np.abs(bold - b))
10111013
# log and verbose print
10121014
if log:
10131015
log['err'].append(err)
10141016

10151017
if verbose:
1016-
if cpt%200 ==0:
1017-
print('{:5s}|{:12s}'.format('It.','Err')+'\n'+'-'*19)
1018-
print('{:5d}|{:8e}|'.format(cpt,err))
1018+
if cpt % 200 == 0:
1019+
print('{:5s}|{:12s}'.format('It.', 'Err') + '\n' + '-' * 19)
1020+
print('{:5d}|{:8e}|'.format(cpt, err))
10191021

10201022
if log:
1021-
log['niter']=cpt
1022-
log['U']=U
1023-
return b,log
1023+
log['niter'] = cpt
1024+
log['U'] = U
1025+
return b, log
10241026
else:
1025-
return b
1026-
1027+
return b
1028+
10271029

10281030
def unmix(a, D, M, M0, h0, reg, reg0, alpha, numItermax=1000,
10291031
stopThr=1e-3, verbose=False, log=False):

0 commit comments

Comments
 (0)