@@ -107,19 +107,18 @@ def emd(a, b, M, numItermax=100000, log=False, dense=True):
107107 b = np .asarray (b , dtype = np .float64 )
108108 M = np .asarray (M , dtype = np .float64 )
109109
110- sparse = not dense
111110
112111 # if empty array given then use uniform distributions
113112 if len (a ) == 0 :
114113 a = np .ones ((M .shape [0 ],), dtype = np .float64 ) / M .shape [0 ]
115114 if len (b ) == 0 :
116115 b = np .ones ((M .shape [1 ],), dtype = np .float64 ) / M .shape [1 ]
117116
118- if sparse :
119- Gv , iG , jG , cost , u , v , result_code = emd_c (a , b , M , numItermax ,sparse )
120- G = coo_matrix ((Gv , (iG , jG )), shape = (a .shape [0 ], b .shape [0 ]))
117+ if dense :
118+ G , cost , u , v , result_code = emd_c (a , b , M , numItermax ,dense )
121119 else :
122- G , cost , u , v , result_code = emd_c (a , b , M , numItermax ,sparse )
120+ Gv , iG , jG , cost , u , v , result_code = emd_c (a , b , M , numItermax ,dense )
121+ G = coo_matrix ((Gv , (iG , jG )), shape = (a .shape [0 ], b .shape [0 ]))
123122
124123 result_code_string = check_result (result_code )
125124 if log :
@@ -217,8 +216,6 @@ def emd2(a, b, M, processes=multiprocessing.cpu_count(),
217216 b = np .asarray (b , dtype = np .float64 )
218217 M = np .asarray (M , dtype = np .float64 )
219218
220- sparse = not dense
221-
222219 # problem with pikling Forks
223220 if sys .platform .endswith ('win32' ):
224221 processes = 1
@@ -231,12 +228,11 @@ def emd2(a, b, M, processes=multiprocessing.cpu_count(),
231228
232229 if log or return_matrix :
233230 def f (b ):
234-
235- if sparse :
236- Gv , iG , jG , cost , u , v , result_code = emd_c (a , b , M , numItermax ,sparse )
237- G = coo_matrix ((Gv , (iG , jG )), shape = (a .shape [0 ], b .shape [0 ]))
231+ if dense :
232+ G , cost , u , v , result_code = emd_c (a , b , M , numItermax ,dense )
238233 else :
239- G , cost , u , v , result_code = emd_c (a , b , M , numItermax ,sparse )
234+ Gv , iG , jG , cost , u , v , result_code = emd_c (a , b , M , numItermax ,dense )
235+ G = coo_matrix ((Gv , (iG , jG )), shape = (a .shape [0 ], b .shape [0 ]))
240236
241237 result_code_string = check_result (result_code )
242238 log = {}
@@ -249,11 +245,13 @@ def f(b):
249245 return [cost , log ]
250246 else :
251247 def f (b ):
252- if sparse :
253- Gv , iG , jG , cost , u , v , result_code = emd_c (a , b , M , numItermax ,sparse )
254- G = coo_matrix ((Gv , (iG , jG )), shape = (a .shape [0 ], b .shape [0 ]))
248+ if dense :
249+ G , cost , u , v , result_code = emd_c (a , b , M , numItermax ,dense )
255250 else :
256- G , cost , u , v , result_code = emd_c (a , b , M , numItermax ,sparse )
251+ Gv , iG , jG , cost , u , v , result_code = emd_c (a , b , M , numItermax ,dense )
252+ G = coo_matrix ((Gv , (iG , jG )), shape = (a .shape [0 ], b .shape [0 ]))
253+
254+ result_code_string = check_result (result_code )
257255 check_result (result_code )
258256 return cost
259257
0 commit comments