@@ -107,7 +107,7 @@ def emd(a, b, M, num_iter_max=100000, log=False):
107107 return G
108108
109109
110- def emd2 (a , b , M , processes = multiprocessing .cpu_count (), num_iter_max = 100000 , log = False ):
110+ def emd2 (a , b , M , processes = multiprocessing .cpu_count (), num_iter_max = 100000 , log = False , return_matrix = False ):
111111 """Solves the Earth Movers distance problem and returns the loss
112112
113113 .. math::
@@ -134,6 +134,11 @@ def emd2(a, b, M, processes=multiprocessing.cpu_count(), num_iter_max=100000, lo
134134 num_iter_max : int, optional (default=100000)
135135 The maximum number of iterations before stopping the optimization
136136 algorithm if it has not converged.
137+ log: boolean, optional (default=False)
138+ If True, returns a dictionary containing the cost and dual
139+ variables. Otherwise returns only the optimal transportation cost.
140+ return_matrix: boolean, optional (default=False)
141+ If True, returns the optimal transportation matrix in the log.
137142
138143 Returns
139144 -------
@@ -181,12 +186,13 @@ def emd2(a, b, M, processes=multiprocessing.cpu_count(), num_iter_max=100000, lo
181186 if len (b ) == 0 :
182187 b = np .ones ((M .shape [1 ],), dtype = np .float64 ) / M .shape [1 ]
183188
184- if log :
189+ if log or return_matrix :
185190 def f (b ):
186191 G , cost , u , v , resultCode = emd_c (a , b , M , num_iter_max )
187192 result_code_string = check_result (resultCode )
188193 log = {}
189- log ['G' ] = G
194+ if return_matrix :
195+ log ['G' ] = G
190196 log ['u' ] = u
191197 log ['v' ] = v
192198 log ['warning' ] = result_code_string
0 commit comments