Skip to content

Commit 3cb0315

Browse files
committed
cleanup variable name dense
1 parent a4afee8 commit 3cb0315

File tree

2 files changed

+27
-29
lines changed

2 files changed

+27
-29
lines changed

ot/lp/__init__.py

Lines changed: 14 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -107,19 +107,18 @@ def emd(a, b, M, numItermax=100000, log=False, dense=True):
107107
b = np.asarray(b, dtype=np.float64)
108108
M = np.asarray(M, dtype=np.float64)
109109

110-
sparse= not dense
111110

112111
# if empty array given then use uniform distributions
113112
if len(a) == 0:
114113
a = np.ones((M.shape[0],), dtype=np.float64) / M.shape[0]
115114
if len(b) == 0:
116115
b = np.ones((M.shape[1],), dtype=np.float64) / M.shape[1]
117116

118-
if sparse:
119-
Gv, iG, jG, cost, u, v, result_code = emd_c(a, b, M, numItermax,sparse)
120-
G = coo_matrix((Gv, (iG, jG)), shape=(a.shape[0], b.shape[0]))
117+
if dense:
118+
G, cost, u, v, result_code = emd_c(a, b, M, numItermax,dense)
121119
else:
122-
G, cost, u, v, result_code = emd_c(a, b, M, numItermax,sparse)
120+
Gv, iG, jG, cost, u, v, result_code = emd_c(a, b, M, numItermax,dense)
121+
G = coo_matrix((Gv, (iG, jG)), shape=(a.shape[0], b.shape[0]))
123122

124123
result_code_string = check_result(result_code)
125124
if log:
@@ -217,8 +216,6 @@ def emd2(a, b, M, processes=multiprocessing.cpu_count(),
217216
b = np.asarray(b, dtype=np.float64)
218217
M = np.asarray(M, dtype=np.float64)
219218

220-
sparse=not dense
221-
222219
# problem with pikling Forks
223220
if sys.platform.endswith('win32'):
224221
processes=1
@@ -231,12 +228,11 @@ def emd2(a, b, M, processes=multiprocessing.cpu_count(),
231228

232229
if log or return_matrix:
233230
def f(b):
234-
235-
if sparse:
236-
Gv, iG, jG, cost, u, v, result_code = emd_c(a, b, M, numItermax,sparse)
237-
G = coo_matrix((Gv, (iG, jG)), shape=(a.shape[0], b.shape[0]))
231+
if dense:
232+
G, cost, u, v, result_code = emd_c(a, b, M, numItermax,dense)
238233
else:
239-
G, cost, u, v, result_code = emd_c(a, b, M, numItermax,sparse)
234+
Gv, iG, jG, cost, u, v, result_code = emd_c(a, b, M, numItermax,dense)
235+
G = coo_matrix((Gv, (iG, jG)), shape=(a.shape[0], b.shape[0]))
240236

241237
result_code_string = check_result(result_code)
242238
log = {}
@@ -249,11 +245,13 @@ def f(b):
249245
return [cost, log]
250246
else:
251247
def f(b):
252-
if sparse:
253-
Gv, iG, jG, cost, u, v, result_code = emd_c(a, b, M, numItermax,sparse)
254-
G = coo_matrix((Gv, (iG, jG)), shape=(a.shape[0], b.shape[0]))
248+
if dense:
249+
G, cost, u, v, result_code = emd_c(a, b, M, numItermax,dense)
255250
else:
256-
G, cost, u, v, result_code = emd_c(a, b, M, numItermax,sparse)
251+
Gv, iG, jG, cost, u, v, result_code = emd_c(a, b, M, numItermax,dense)
252+
G = coo_matrix((Gv, (iG, jG)), shape=(a.shape[0], b.shape[0]))
253+
254+
result_code_string = check_result(result_code)
257255
check_result(result_code)
258256
return cost
259257

ot/lp/emd_wrap.pyx

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ def check_result(result_code):
4646

4747
@cython.boundscheck(False)
4848
@cython.wraparound(False)
49-
def emd_c(np.ndarray[double, ndim=1, mode="c"] a, np.ndarray[double, ndim=1, mode="c"] b, np.ndarray[double, ndim=2, mode="c"] M, int max_iter, bint sparse):
49+
def emd_c(np.ndarray[double, ndim=1, mode="c"] a, np.ndarray[double, ndim=1, mode="c"] b, np.ndarray[double, ndim=2, mode="c"] M, int max_iter, bint dense):
5050
"""
5151
Solves the Earth Movers distance problem and returns the optimal transport matrix
5252
@@ -110,8 +110,19 @@ def emd_c(np.ndarray[double, ndim=1, mode="c"] a, np.ndarray[double, ndim=1, mod
110110
if not len(b):
111111
b=np.ones((n2,))/n2
112112

113-
if sparse:
113+
if dense:
114+
# init OT matrix
115+
G=np.zeros([n1, n2])
116+
117+
# calling the function
118+
result_code = EMD_wrap(n1, n2, <double*> a.data, <double*> b.data, <double*> M.data, <double*> G.data, <double*> alpha.data, <double*> beta.data, <double*> &cost, max_iter)
119+
120+
return G, cost, alpha, beta, result_code
121+
114122

123+
else:
124+
125+
# init sparse OT matrix
115126
Gv=np.zeros(nmax)
116127
iG=np.zeros(nmax,dtype=np.int)
117128
jG=np.zeros(nmax,dtype=np.int)
@@ -123,17 +134,6 @@ def emd_c(np.ndarray[double, ndim=1, mode="c"] a, np.ndarray[double, ndim=1, mod
123134
return Gv[:nG], iG[:nG], jG[:nG], cost, alpha, beta, result_code
124135

125136

126-
else:
127-
128-
129-
G=np.zeros([n1, n2])
130-
131-
132-
# calling the function
133-
result_code = EMD_wrap(n1, n2, <double*> a.data, <double*> b.data, <double*> M.data, <double*> G.data, <double*> alpha.data, <double*> beta.data, <double*> &cost, max_iter)
134-
135-
return G, cost, alpha, beta, result_code
136-
137137

138138
@cython.boundscheck(False)
139139
@cython.wraparound(False)

0 commit comments

Comments
 (0)