1616from ..utils import parmap
1717
1818
19- def emd (a , b , M , max_iter = 100000 , log = False ):
19+ def emd (a , b , M , num_iter_max = 100000 , log = False ):
2020 """Solves the Earth Movers distance problem and returns the OT matrix
2121
2222
@@ -41,7 +41,7 @@ def emd(a, b, M, max_iter=100000, log=False):
4141 Target histogram (uniform weigth if empty list)
4242 M : (ns,nt) ndarray, float64
4343 loss matrix
44- max_iter : int, optional (default=100000)
44+ num_iter_max : int, optional (default=100000)
4545 The maximum number of iterations before stopping the optimization
4646 algorithm if it has not converged.
4747 log: boolean, optional (default=False)
@@ -94,7 +94,7 @@ def emd(a, b, M, max_iter=100000, log=False):
9494 if len (b ) == 0 :
9595 b = np .ones ((M .shape [1 ],), dtype = np .float64 ) / M .shape [1 ]
9696
97- G , cost , u , v , result_code = emd_c (a , b , M , max_iter )
97+ G , cost , u , v , result_code = emd_c (a , b , M , num_iter_max )
9898 result_code_string = check_result (result_code )
9999 if log :
100100 log = {}
@@ -107,7 +107,7 @@ def emd(a, b, M, max_iter=100000, log=False):
107107 return G
108108
109109
110- def emd2 (a , b , M , processes = multiprocessing .cpu_count (), max_iter = 100000 , log = False ):
110+ def emd2 (a , b , M , processes = multiprocessing .cpu_count (), num_iter_max = 100000 , log = False ):
111111 """Solves the Earth Movers distance problem and returns the loss
112112
113113 .. math::
@@ -183,7 +183,7 @@ def emd2(a, b, M, processes=multiprocessing.cpu_count(), max_iter=100000, log=Fa
183183
184184 if log :
185185 def f (b ):
186- G , cost , u , v , resultCode = emd_c (a , b , M , max_iter )
186+ G , cost , u , v , resultCode = emd_c (a , b , M , num_iter_max )
187187 result_code_string = check_result (resultCode )
188188 log = {}
189189 log ['G' ] = G
@@ -194,7 +194,7 @@ def f(b):
194194 return [cost , log ]
195195 else :
196196 def f (b ):
197- G , cost , u , v , result_code = emd_c (a , b , M , max_iter )
197+ G , cost , u , v , result_code = emd_c (a , b , M , num_iter_max )
198198 check_result (result_code )
199199 return cost
200200
0 commit comments