Skip to content

Commit e58cd78

Browse files
committed
Added convergence status to the log
1 parent a37e52e commit e58cd78

File tree

2 files changed

+29
-15
lines changed

2 files changed

+29
-15
lines changed

ot/lp/__init__.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
import numpy as np
1313

1414
# import compiled emd
15-
from .emd_wrap import emd_c
15+
from .emd_wrap import emd_c, checkResult
1616
from ..utils import parmap
1717

1818

@@ -94,12 +94,15 @@ def emd(a, b, M, numItermax=100000, log=False):
9494
if len(b) == 0:
9595
b = np.ones((M.shape[1],), dtype=np.float64) / M.shape[1]
9696

97-
G, cost, u, v = emd_c(a, b, M, numItermax)
97+
G, cost, u, v, resultCode = emd_c(a, b, M, numItermax)
98+
resultCodeString = checkResult(resultCode)
9899
if log:
99100
log = {}
100101
log['cost'] = cost
101102
log['u'] = u
102103
log['v'] = v
104+
log['warning'] = resultCodeString
105+
log['resultCode'] = resultCode
103106
return G, log
104107
return G
105108

@@ -177,15 +180,20 @@ def emd2(a, b, M, processes=multiprocessing.cpu_count(), numItermax=100000, log=
177180

178181
if log:
179182
def f(b):
180-
G, cost, u, v = emd_c(a, b, M, numItermax)
183+
G, cost, u, v, resultCode = emd_c(a, b, M, numItermax)
184+
resultCodeString = checkResult(resultCode)
181185
log = {}
182186
log['G'] = G
183187
log['u'] = u
184188
log['v'] = v
189+
log['warning'] = resultCodeString
190+
log['resultCode'] = resultCode
185191
return [cost, log]
186192
else:
187193
def f(b):
188-
return emd_c(a, b, M, numItermax)[1]
194+
G, cost, u, v, resultCode = emd_c(a, b, M, numItermax)
195+
checkResult(resultCode)
196+
return cost
189197

190198
if len(b.shape) == 1:
191199
return f(b)

ot/lp/emd_wrap.pyx

Lines changed: 17 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -7,19 +7,32 @@ Cython linker with C solver
77
#
88
# License: MIT License
99

10-
import warnings
1110
import numpy as np
1211
cimport numpy as np
1312

1413
cimport cython
1514

15+
import warnings
1616

1717

1818
cdef extern from "EMD.h":
1919
int EMD_wrap(int n1,int n2, double *X, double *Y,double *D, double *G, double* alpha, double* beta, double *cost, int numItermax)
2020
cdef enum ProblemType: INFEASIBLE, OPTIMAL, UNBOUNDED, MAX_ITER_REACHED
2121

2222

23+
def checkResult(resultCode):
24+
if resultCode == OPTIMAL:
25+
return None
26+
27+
if resultCode == INFEASIBLE:
28+
message = "Problem infeasible. Check that a and b are in the simplex"
29+
elif resultCode == UNBOUNDED:
30+
message = "Problem unbounded"
31+
elif resultCode == MAX_ITER_REACHED:
32+
message = "numItermax reached before optimality. Try to increase numItermax."
33+
warnings.warn(message)
34+
return message
35+
2336

2437
@cython.boundscheck(False)
2538
@cython.wraparound(False)
@@ -77,13 +90,6 @@ def emd_c( np.ndarray[double, ndim=1, mode="c"] a,np.ndarray[double, ndim=1, mod
7790
b=np.ones((n2,))/n2
7891

7992
# calling the function
80-
cdef int resultSolver = EMD_wrap(n1,n2,<double*> a.data,<double*> b.data,<double*> M.data,<double*> G.data, <double*> alpha.data, <double*> beta.data, <double*> &cost, numItermax)
81-
if resultSolver != OPTIMAL:
82-
if resultSolver == INFEASIBLE:
83-
warnings.warn("Problem infeasible. Check that a and b are in the simplex")
84-
elif resultSolver == UNBOUNDED:
85-
warnings.warn("Problem unbounded")
86-
elif resultSolver == MAX_ITER_REACHED:
87-
warnings.warn("numItermax reached before optimality. Try to increase numItermax.")
88-
89-
return G, cost, alpha, beta
93+
cdef int resultCode = EMD_wrap(n1,n2,<double*> a.data,<double*> b.data,<double*> M.data,<double*> G.data, <double*> alpha.data, <double*> beta.data, <double*> &cost, numItermax)
94+
95+
return G, cost, alpha, beta, resultCode

0 commit comments

Comments
 (0)