1212import numpy as np
1313
1414# import compiled emd
15- from .emd_wrap import emd_c , checkResult
15+ from .emd_wrap import emd_c , check_result
1616from ..utils import parmap
1717
1818
19- def emd (a , b , M , numItermax = 100000 , log = False ):
19+ def emd (a , b , M , num_iter_max = 100000 , log = False ):
2020 """Solves the Earth Movers distance problem and returns the OT matrix
2121
2222
@@ -41,7 +41,7 @@ def emd(a, b, M, numItermax=100000, log=False):
4141 Target histogram (uniform weigth if empty list)
4242 M : (ns,nt) ndarray, float64
4343 loss matrix
44- numItermax : int, optional (default=100000)
44+ num_iter_max : int, optional (default=100000)
4545 The maximum number of iterations before stopping the optimization
4646 algorithm if it has not converged.
4747 log: boolean, optional (default=False)
@@ -54,7 +54,7 @@ def emd(a, b, M, numItermax=100000, log=False):
5454 Optimal transportation matrix for the given parameters
5555 log: dict
5656 If input log is true, a dictionary containing the cost and dual
57- variables
57+ variables and exit status
5858
5959
6060 Examples
@@ -94,20 +94,20 @@ def emd(a, b, M, numItermax=100000, log=False):
9494 if len (b ) == 0 :
9595 b = np .ones ((M .shape [1 ],), dtype = np .float64 ) / M .shape [1 ]
9696
97- G , cost , u , v , resultCode = emd_c (a , b , M , numItermax )
98- resultCodeString = checkResult ( resultCode )
97+ G , cost , u , v , result_code = emd_c (a , b , M , num_iter_max )
98+ resultCodeString = check_result ( result_code )
9999 if log :
100100 log = {}
101101 log ['cost' ] = cost
102102 log ['u' ] = u
103103 log ['v' ] = v
104104 log ['warning' ] = resultCodeString
105- log ['resultCode ' ] = resultCode
105+ log ['result_code ' ] = result_code
106106 return G , log
107107 return G
108108
109109
110- def emd2 (a , b , M , processes = multiprocessing .cpu_count (), numItermax = 100000 , log = False ):
110+ def emd2 (a , b , M , processes = multiprocessing .cpu_count (), num_iter_max = 100000 , log = False ):
111111 """Solves the Earth Movers distance problem and returns the loss
112112
113113 .. math::
@@ -131,14 +131,17 @@ def emd2(a, b, M, processes=multiprocessing.cpu_count(), numItermax=100000, log=
131131 Target histogram (uniform weigth if empty list)
132132 M : (ns,nt) ndarray, float64
133133 loss matrix
134- numItermax : int, optional (default=100000)
134+ num_iter_max : int, optional (default=100000)
135135 The maximum number of iterations before stopping the optimization
136136 algorithm if it has not converged.
137137
138138 Returns
139139 -------
140140 gamma: (ns x nt) ndarray
141141 Optimal transportation matrix for the given parameters
142+ log: dict
143+ If input log is true, a dictionary containing the cost and dual
144+ variables and exit status
142145
143146
144147 Examples
@@ -180,19 +183,19 @@ def emd2(a, b, M, processes=multiprocessing.cpu_count(), numItermax=100000, log=
180183
181184 if log :
182185 def f (b ):
183- G , cost , u , v , resultCode = emd_c (a , b , M , numItermax )
184- resultCodeString = checkResult (resultCode )
186+ G , cost , u , v , resultCode = emd_c (a , b , M , num_iter_max )
187+ resultCodeString = check_result (resultCode )
185188 log = {}
186189 log ['G' ] = G
187190 log ['u' ] = u
188191 log ['v' ] = v
189192 log ['warning' ] = resultCodeString
190- log ['resultCode ' ] = resultCode
193+ log ['result_code ' ] = resultCode
191194 return [cost , log ]
192195 else :
193196 def f (b ):
194- G , cost , u , v , resultCode = emd_c (a , b , M , numItermax )
195- checkResult ( resultCode )
197+ G , cost , u , v , result_code = emd_c (a , b , M , num_iter_max )
198+ check_result ( result_code )
196199 return cost
197200
198201 if len (b .shape ) == 1 :
0 commit comments