Skip to content

Commit 95db977

Browse files
committed
pep8
1 parent 37ca314 commit 95db977

File tree

6 files changed

+61
-51
lines changed

6 files changed

+61
-51
lines changed

ot/gpu/bregman.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,7 @@ def sinkhorn(a, b, M_GPU, reg, numItermax=1000, stopThr=1e-9, verbose=False,
8282
ot.lp.emd : Unregularized OT
8383
ot.optim.cg : General regularized OT
8484
85-
"""
85+
"""
8686
# init data
8787
Nini = len(a)
8888
Nfin = len(b)
@@ -92,11 +92,11 @@ def sinkhorn(a, b, M_GPU, reg, numItermax=1000, stopThr=1e-9, verbose=False,
9292

9393
# we assume that no distances are null except those of the diagonal of
9494
# distances
95-
u = (np.ones(Nini)/Nini).reshape((Nini, 1))
95+
u = (np.ones(Nini) / Nini).reshape((Nini, 1))
9696
u_GPU = cudamat.CUDAMatrix(u)
9797
a_GPU = cudamat.CUDAMatrix(a.reshape((Nini, 1)))
9898
ones_GPU = cudamat.empty(u_GPU.shape).assign(1)
99-
v = (np.ones(Nfin)/Nfin).reshape((Nfin, 1))
99+
v = (np.ones(Nfin) / Nfin).reshape((Nfin, 1))
100100
v_GPU = cudamat.CUDAMatrix(v)
101101
b_GPU = cudamat.CUDAMatrix(b.reshape((Nfin, 1)))
102102

@@ -121,7 +121,7 @@ def sinkhorn(a, b, M_GPU, reg, numItermax=1000, stopThr=1e-9, verbose=False,
121121
ones_GPU.divide(Kp_GPU.dot(v_GPU), target=u_GPU)
122122

123123
if (np.any(KtransposeU_GPU.asarray() == 0) or
124-
not u_GPU.allfinite() or not v_GPU.allfinite()):
124+
not u_GPU.allfinite() or not v_GPU.allfinite()):
125125
# we have reached the machine precision
126126
# come back to previous solution and quit loop
127127
print('Warning: numerical errors at iteration', cpt)
@@ -142,7 +142,8 @@ def sinkhorn(a, b, M_GPU, reg, numItermax=1000, stopThr=1e-9, verbose=False,
142142

143143
if verbose:
144144
if cpt % 200 == 0:
145-
print('{:5s}|{:12s}'.format('It.', 'Err')+'\n'+'-'*19)
145+
print(
146+
'{:5s}|{:12s}'.format('It.', 'Err') + '\n' + '-' * 19)
146147
print('{:5d}|{:8e}|'.format(cpt, err))
147148
cpt += 1
148149
if log:

ot/gpu/da.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -167,7 +167,7 @@ def sinkhorn_lpl1_mm(a, labels_a, b, M_GPU, reg, eta=0.1, numItermax=10,
167167
tmpC_GPU = cudamat.empty((Nfin, nbRow)).assign(0)
168168
transp_GPU.transpose().select_columns(indices_labels[i], tmpC_GPU)
169169
majs_GPU = tmpC_GPU.sum(axis=1).add(epsilon)
170-
cudamat.pow(majs_GPU, (p-1))
170+
cudamat.pow(majs_GPU, (p - 1))
171171
majs_GPU.mult(p)
172172

173173
tmpC_GPU.assign(0)

ot/utils.py

Lines changed: 35 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -7,31 +7,35 @@
77
import multiprocessing
88

99
import time
10-
__time_tic_toc=time.time()
10+
__time_tic_toc = time.time()
11+
1112

1213
def tic():
1314
""" Python implementation of Matlab tic() function """
1415
global __time_tic_toc
15-
__time_tic_toc=time.time()
16+
__time_tic_toc = time.time()
17+
1618

1719
def toc(message='Elapsed time : {} s'):
1820
""" Python implementation of Matlab toc() function """
19-
t=time.time()
20-
print(message.format(t-__time_tic_toc))
21-
return t-__time_tic_toc
21+
t = time.time()
22+
print(message.format(t - __time_tic_toc))
23+
return t - __time_tic_toc
24+
2225

2326
def toq():
2427
""" Python implementation of Julia toc() function """
25-
t=time.time()
26-
return t-__time_tic_toc
28+
t = time.time()
29+
return t - __time_tic_toc
2730

2831

29-
def kernel(x1,x2,method='gaussian',sigma=1,**kwargs):
32+
def kernel(x1, x2, method='gaussian', sigma=1, **kwargs):
3033
"""Compute kernel matrix"""
31-
if method.lower() in ['gaussian','gauss','rbf']:
32-
K=np.exp(-dist(x1,x2)/(2*sigma**2))
34+
if method.lower() in ['gaussian', 'gauss', 'rbf']:
35+
K = np.exp(-dist(x1, x2) / (2 * sigma**2))
3336
return K
3437

38+
3539
def unif(n):
3640
""" return a uniform histogram of length n (simplex)
3741
@@ -48,17 +52,19 @@ def unif(n):
4852
4953
5054
"""
51-
return np.ones((n,))/n
55+
return np.ones((n,)) / n
5256

53-
def clean_zeros(a,b,M):
54-
""" Remove all components with zeros weights in a and b
57+
58+
def clean_zeros(a, b, M):
59+
""" Remove all components with zeros weights in a and b
5560
"""
56-
M2=M[a>0,:][:,b>0].copy() # copy force c style matrix (froemd)
57-
a2=a[a>0]
58-
b2=b[b>0]
59-
return a2,b2,M2
61+
M2 = M[a > 0, :][:, b > 0].copy() # copy force c style matrix (froemd)
62+
a2 = a[a > 0]
63+
b2 = b[b > 0]
64+
return a2, b2, M2
65+
6066

61-
def dist(x1,x2=None,metric='sqeuclidean'):
67+
def dist(x1, x2=None, metric='sqeuclidean'):
6268
"""Compute distance between samples in x1 and x2 using function scipy.spatial.distance.cdist
6369
6470
Parameters
@@ -84,12 +90,12 @@ def dist(x1,x2=None,metric='sqeuclidean'):
8490
8591
"""
8692
if x2 is None:
87-
x2=x1
93+
x2 = x1
8894

89-
return cdist(x1,x2,metric=metric)
95+
return cdist(x1, x2, metric=metric)
9096

9197

92-
def dist0(n,method='lin_square'):
98+
def dist0(n, method='lin_square'):
9399
"""Compute standard cost matrices of size (n,n) for OT problems
94100
95101
Parameters
@@ -111,16 +117,17 @@ def dist0(n,method='lin_square'):
111117
112118
113119
"""
114-
res=0
115-
if method=='lin_square':
116-
x=np.arange(n,dtype=np.float64).reshape((n,1))
117-
res=dist(x,x)
120+
res = 0
121+
if method == 'lin_square':
122+
x = np.arange(n, dtype=np.float64).reshape((n, 1))
123+
res = dist(x, x)
118124
return res
119125

120126

121127
def dots(*args):
122128
""" dots function for multiple matrix multiply """
123-
return reduce(np.dot,args)
129+
return reduce(np.dot, args)
130+
124131

125132
def fun(f, q_in, q_out):
126133
""" Utility function for parmap with no serializing problems """
@@ -130,6 +137,7 @@ def fun(f, q_in, q_out):
130137
break
131138
q_out.put((i, f(x)))
132139

140+
133141
def parmap(f, X, nprocs=multiprocessing.cpu_count()):
134142
""" paralell map for multiprocessing """
135143
q_in = multiprocessing.Queue(1)
@@ -147,4 +155,4 @@ def parmap(f, X, nprocs=multiprocessing.cpu_count()):
147155

148156
[p.join() for p in proc]
149157

150-
return [x for i, x in sorted(res)]
158+
return [x for i, x in sorted(res)]

test/test_emd_multi.py

Lines changed: 13 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -7,42 +7,41 @@
77
"""
88

99
import numpy as np
10-
import pylab as pl
11-
import ot
1210

11+
import ot
1312
from ot.datasets import get_1D_gauss as gauss
14-
reload(ot.lp)
13+
# reload(ot.lp)
1514

1615
#%% parameters
1716

18-
n=5000 # nb bins
17+
n = 5000 # nb bins
1918

2019
# bin positions
21-
x=np.arange(n,dtype=np.float64)
20+
x = np.arange(n, dtype=np.float64)
2221

2322
# Gaussian distributions
24-
a=gauss(n,m=20,s=5) # m= mean, s= std
23+
a = gauss(n, m=20, s=5) # m= mean, s= std
2524

26-
ls= range(20,1000,10)
27-
nb=len(ls)
28-
b=np.zeros((n,nb))
25+
ls = range(20, 1000, 10)
26+
nb = len(ls)
27+
b = np.zeros((n, nb))
2928
for i in range(nb):
30-
b[:,i]=gauss(n,m=ls[i],s=10)
29+
b[:, i] = gauss(n, m=ls[i], s=10)
3130

3231
# loss matrix
33-
M=ot.dist(x.reshape((n,1)),x.reshape((n,1)))
34-
#M/=M.max()
32+
M = ot.dist(x.reshape((n, 1)), x.reshape((n, 1)))
33+
# M/=M.max()
3534

3635
#%%
3736

3837
print('Computing {} EMD '.format(nb))
3938

4039
# emd loss 1 proc
4140
ot.tic()
42-
emd_loss4=ot.emd2(a,b,M,1)
41+
emd_loss4 = ot.emd2(a, b, M, 1)
4342
ot.toc('1 proc : {} s')
4443

4544
# emd loss multipro proc
4645
ot.tic()
47-
emd_loss4=ot.emd2(a,b,M)
46+
emd_loss4 = ot.emd2(a, b, M)
4847
ot.toc('multi proc : {} s')

test/test_gpu_sinkhorn.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,10 @@
33
import time
44
import ot.gpu
55

6+
67
def describeRes(r):
7-
print("min:{:.3E}, max::{:.3E}, mean::{:.3E}, std::{:.3E}".format(np.min(r),np.max(r),np.mean(r),np.std(r)))
8+
print("min:{:.3E}, max::{:.3E}, mean::{:.3E}, std::{:.3E}".format(
9+
np.min(r), np.max(r), np.mean(r), np.std(r)))
810

911

1012
for n in [5000, 10000, 15000, 20000]:
@@ -23,4 +25,4 @@ def describeRes(r):
2325
print("Normal sinkhorn, time: {:6.2f} sec ".format(time2 - time1))
2426
describeRes(G1)
2527
print(" GPU sinkhorn, time: {:6.2f} sec ".format(time3 - time2))
26-
describeRes(G2)
28+
describeRes(G2)

test/test_load_module.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import doctest
55

66
# test lp solver
7-
doctest.testmod(ot.lp,verbose=True)
7+
doctest.testmod(ot.lp, verbose=True)
88

99
# test bregman solver
10-
doctest.testmod(ot.bregman,verbose=True)
10+
doctest.testmod(ot.bregman, verbose=True)

0 commit comments

Comments
 (0)