@@ -7,6 +7,7 @@ Cython linker with C solver
77#
88# License: MIT License
99
10+ import warnings
1011import numpy as np
1112cimport numpy as np
1213
@@ -15,14 +16,14 @@ cimport cython
1516
1617
1718cdef extern from " EMD.h" :
18- int EMD_wrap(int n1,int n2, double * X, double * Y,double * D, double * G, double * alpha, double * beta, double * cost, int max_iter )
19- cdef enum ProblemType: INFEASIBLE, OPTIMAL, UNBOUNDED
19+ int EMD_wrap(int n1,int n2, double * X, double * Y,double * D, double * G, double * alpha, double * beta, double * cost, int numItermax )
20+ cdef enum ProblemType: INFEASIBLE, OPTIMAL, UNBOUNDED, MAX_ITER_REACHED
2021
2122
2223
2324@ cython.boundscheck (False )
2425@ cython.wraparound (False )
25- def emd_c ( np.ndarray[double , ndim = 1 , mode = " c" ] a,np.ndarray[double , ndim = 1 , mode = " c" ] b,np.ndarray[double , ndim = 2 , mode = " c" ] M, int max_iter ):
26+ def emd_c ( np.ndarray[double , ndim = 1 , mode = " c" ] a,np.ndarray[double , ndim = 1 , mode = " c" ] b,np.ndarray[double , ndim = 2 , mode = " c" ] M, int numItermax ):
2627 """
2728 Solves the Earth Movers distance problem and returns the optimal transport matrix
2829
@@ -49,7 +50,7 @@ def emd_c( np.ndarray[double, ndim=1, mode="c"] a,np.ndarray[double, ndim=1, mod
4950 target histogram
5051 M : (ns,nt) ndarray, float64
5152 loss matrix
52- max_iter : int
53+ numItermax : int
5354 The maximum number of iterations before stopping the optimization
5455 algorithm if it has not converged.
5556
@@ -76,18 +77,20 @@ def emd_c( np.ndarray[double, ndim=1, mode="c"] a,np.ndarray[double, ndim=1, mod
7677 b= np.ones((n2,))/ n2
7778
7879 # calling the function
79- cdef int resultSolver = EMD_wrap(n1,n2,< double * > a.data,< double * > b.data,< double * > M.data,< double * > G.data, < double * > alpha.data, < double * > beta.data, < double * > & cost, max_iter )
80+ cdef int resultSolver = EMD_wrap(n1,n2,< double * > a.data,< double * > b.data,< double * > M.data,< double * > G.data, < double * > alpha.data, < double * > beta.data, < double * > & cost, numItermax )
8081 if resultSolver != OPTIMAL:
8182 if resultSolver == INFEASIBLE:
82- print (" Problem infeasible. Try to increase numItermax. " )
83+ warnings.warn (" Problem infeasible. Check that a and b are in the simplex " )
8384 elif resultSolver == UNBOUNDED:
84- print (" Problem unbounded" )
85+ warnings.warn(" Problem unbounded" )
86+ elif resultSolver == MAX_ITER_REACHED:
87+ warnings.warn(" numItermax reached before optimality. Try to increase numItermax." )
8588
8689 return G, alpha, beta
8790
8891@ cython.boundscheck (False )
8992@ cython.wraparound (False )
90- def emd2_c ( np.ndarray[double , ndim = 1 , mode = " c" ] a,np.ndarray[double , ndim = 1 , mode = " c" ] b,np.ndarray[double , ndim = 2 , mode = " c" ] M, int max_iter ):
93+ def emd2_c ( np.ndarray[double , ndim = 1 , mode = " c" ] a,np.ndarray[double , ndim = 1 , mode = " c" ] b,np.ndarray[double , ndim = 2 , mode = " c" ] M, int numItermax ):
9194 """
9295 Solves the Earth Movers distance problem and returns the optimal transport loss
9396
@@ -114,7 +117,7 @@ def emd2_c( np.ndarray[double, ndim=1, mode="c"] a,np.ndarray[double, ndim=1, mo
114117 target histogram
115118 M : (ns,nt) ndarray, float64
116119 loss matrix
117- max_iter : int
120+ numItermax : int
118121 The maximum number of iterations before stopping the optimization
119122 algorithm if it has not converged.
120123
@@ -140,12 +143,14 @@ def emd2_c( np.ndarray[double, ndim=1, mode="c"] a,np.ndarray[double, ndim=1, mo
140143 if not len (b):
141144 b= np.ones((n2,))/ n2
142145 # calling the function
143- cdef int resultSolver = EMD_wrap(n1,n2,< double * > a.data,< double * > b.data,< double * > M.data,< double * > G.data, < double * > alpha.data, < double * > beta.data, < double * > & cost, max_iter )
146+ cdef int resultSolver = EMD_wrap(n1,n2,< double * > a.data,< double * > b.data,< double * > M.data,< double * > G.data, < double * > alpha.data, < double * > beta.data, < double * > & cost, numItermax )
144147 if resultSolver != OPTIMAL:
145148 if resultSolver == INFEASIBLE:
146- print (" Problem infeasible. Try to inscrease numItermax. " )
149+ warnings.warn (" Problem infeasible. Check that a and b are in the simplex " )
147150 elif resultSolver == UNBOUNDED:
148- print (" Problem unbounded" )
151+ warnings.warn(" Problem unbounded" )
152+ elif resultSolver == MAX_ITER_REACHED:
153+ warnings.warn(" numItermax reached before optimality. Try to increase numItermax." )
149154
150155 return cost, alpha, beta
151156
0 commit comments