Skip to content

Commit f8c1c87

Browse files
committed
Added MAX_ITER_REACHED flag and warning
1 parent a3497b1 commit f8c1c87

File tree

5 files changed

+102
-49
lines changed

5 files changed

+102
-49
lines changed

ot/lp/EMD.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,8 @@ typedef unsigned int node_id_type;
2626
enum ProblemType {
2727
INFEASIBLE,
2828
OPTIMAL,
29-
UNBOUNDED
29+
UNBOUNDED,
30+
MAX_ITER_REACHED
3031
};
3132

3233
int EMD_wrap(int n1,int n2, double *X, double *Y,double *D, double *G, double* alpha, double* beta, double *cost, int max_iter);

ot/lp/EMD_wrapper.cpp

Lines changed: 8 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -29,14 +29,18 @@ int EMD_wrap(int n1, int n2, double *X, double *Y, double *D, double *G,
2929
double val=*(X+i);
3030
if (val>0) {
3131
n++;
32-
}
32+
}else if(val<0){
33+
return INFEASIBLE;
34+
}
3335
}
3436
m=0;
3537
for (int i=0; i<n2; i++) {
3638
double val=*(Y+i);
3739
if (val>0) {
3840
m++;
39-
}
41+
}else if(val<0){
42+
return INFEASIBLE;
43+
}
4044
}
4145

4246
// Define the graph
@@ -83,16 +87,7 @@ int EMD_wrap(int n1, int n2, double *X, double *Y, double *D, double *G,
8387
// Solve the problem with the network simplex algorithm
8488

8589
int ret=net.run();
86-
if (ret!=(int)net.OPTIMAL) {
87-
if (ret==(int)net.INFEASIBLE) {
88-
std::cout << "Infeasible problem";
89-
}
90-
if (ret==(int)net.UNBOUNDED)
91-
{
92-
std::cout << "Unbounded problem";
93-
}
94-
} else
95-
{
90+
if (ret==(int)net.OPTIMAL || ret==(int)net.MAX_ITER_REACHED) {
9691
*cost = 0;
9792
Arc a; di.first(a);
9893
for (; a != INVALID; di.next(a)) {
@@ -105,7 +100,7 @@ int EMD_wrap(int n1, int n2, double *X, double *Y, double *D, double *G,
105100
*(beta + indJ[j-n]) = net.potential(j);
106101
}
107102

108-
};
103+
}
109104

110105

111106
return ret;

ot/lp/emd_wrap.pyx

Lines changed: 17 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ Cython linker with C solver
77
#
88
# License: MIT License
99

10+
import warnings
1011
import numpy as np
1112
cimport numpy as np
1213

@@ -15,14 +16,14 @@ cimport cython
1516

1617

1718
cdef extern from "EMD.h":
18-
int EMD_wrap(int n1,int n2, double *X, double *Y,double *D, double *G, double* alpha, double* beta, double *cost, int max_iter)
19-
cdef enum ProblemType: INFEASIBLE, OPTIMAL, UNBOUNDED
19+
int EMD_wrap(int n1,int n2, double *X, double *Y,double *D, double *G, double* alpha, double* beta, double *cost, int numItermax)
20+
cdef enum ProblemType: INFEASIBLE, OPTIMAL, UNBOUNDED, MAX_ITER_REACHED
2021

2122

2223

2324
@cython.boundscheck(False)
2425
@cython.wraparound(False)
25-
def emd_c( np.ndarray[double, ndim=1, mode="c"] a,np.ndarray[double, ndim=1, mode="c"] b,np.ndarray[double, ndim=2, mode="c"] M, int max_iter):
26+
def emd_c( np.ndarray[double, ndim=1, mode="c"] a,np.ndarray[double, ndim=1, mode="c"] b,np.ndarray[double, ndim=2, mode="c"] M, int numItermax):
2627
"""
2728
Solves the Earth Movers distance problem and returns the optimal transport matrix
2829
@@ -49,7 +50,7 @@ def emd_c( np.ndarray[double, ndim=1, mode="c"] a,np.ndarray[double, ndim=1, mod
4950
target histogram
5051
M : (ns,nt) ndarray, float64
5152
loss matrix
52-
max_iter : int
53+
numItermax : int
5354
The maximum number of iterations before stopping the optimization
5455
algorithm if it has not converged.
5556
@@ -76,18 +77,20 @@ def emd_c( np.ndarray[double, ndim=1, mode="c"] a,np.ndarray[double, ndim=1, mod
7677
b=np.ones((n2,))/n2
7778

7879
# calling the function
79-
cdef int resultSolver = EMD_wrap(n1,n2,<double*> a.data,<double*> b.data,<double*> M.data,<double*> G.data, <double*> alpha.data, <double*> beta.data, <double*> &cost, max_iter)
80+
cdef int resultSolver = EMD_wrap(n1,n2,<double*> a.data,<double*> b.data,<double*> M.data,<double*> G.data, <double*> alpha.data, <double*> beta.data, <double*> &cost, numItermax)
8081
if resultSolver != OPTIMAL:
8182
if resultSolver == INFEASIBLE:
82-
print("Problem infeasible. Try to increase numItermax.")
83+
warnings.warn("Problem infeasible. Check that a and b are in the simplex")
8384
elif resultSolver == UNBOUNDED:
84-
print("Problem unbounded")
85+
warnings.warn("Problem unbounded")
86+
elif resultSolver == MAX_ITER_REACHED:
87+
warnings.warn("numItermax reached before optimality. Try to increase numItermax.")
8588

8689
return G, alpha, beta
8790

8891
@cython.boundscheck(False)
8992
@cython.wraparound(False)
90-
def emd2_c( np.ndarray[double, ndim=1, mode="c"] a,np.ndarray[double, ndim=1, mode="c"] b,np.ndarray[double, ndim=2, mode="c"] M, int max_iter):
93+
def emd2_c( np.ndarray[double, ndim=1, mode="c"] a,np.ndarray[double, ndim=1, mode="c"] b,np.ndarray[double, ndim=2, mode="c"] M, int numItermax):
9194
"""
9295
Solves the Earth Movers distance problem and returns the optimal transport loss
9396
@@ -114,7 +117,7 @@ def emd2_c( np.ndarray[double, ndim=1, mode="c"] a,np.ndarray[double, ndim=1, mo
114117
target histogram
115118
M : (ns,nt) ndarray, float64
116119
loss matrix
117-
max_iter : int
120+
numItermax : int
118121
The maximum number of iterations before stopping the optimization
119122
algorithm if it has not converged.
120123
@@ -140,12 +143,14 @@ def emd2_c( np.ndarray[double, ndim=1, mode="c"] a,np.ndarray[double, ndim=1, mo
140143
if not len(b):
141144
b=np.ones((n2,))/n2
142145
# calling the function
143-
cdef int resultSolver = EMD_wrap(n1,n2,<double*> a.data,<double*> b.data,<double*> M.data,<double*> G.data, <double*> alpha.data, <double*> beta.data, <double*> &cost, max_iter)
146+
cdef int resultSolver = EMD_wrap(n1,n2,<double*> a.data,<double*> b.data,<double*> M.data,<double*> G.data, <double*> alpha.data, <double*> beta.data, <double*> &cost, numItermax)
144147
if resultSolver != OPTIMAL:
145148
if resultSolver == INFEASIBLE:
146-
print("Problem infeasible. Try to inscrease numItermax.")
149+
warnings.warn("Problem infeasible. Check that a and b are in the simplex")
147150
elif resultSolver == UNBOUNDED:
148-
print("Problem unbounded")
151+
warnings.warn("Problem unbounded")
152+
elif resultSolver == MAX_ITER_REACHED:
153+
warnings.warn("numItermax reached before optimality. Try to increase numItermax.")
149154

150155
return cost, alpha, beta
151156

ot/lp/network_simplex_simple.h

Lines changed: 25 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,8 @@
3434
#endif
3535

3636

37-
#define EPSILON 10*2.2204460492503131e-016
37+
#define EPSILON 2.2204460492503131e-15
38+
#define _EPSILON 1e-8
3839
#define MAX_DEBUG_ITER 100000
3940

4041

@@ -260,7 +261,9 @@ namespace lemon {
260261
/// The objective function of the problem is unbounded, i.e.
261262
/// there is a directed cycle having negative total cost and
262263
/// infinite upper bound.
263-
UNBOUNDED
264+
UNBOUNDED,
265+
/// The maximum number of iteration has been reached
266+
MAX_ITER_REACHED
264267
};
265268

266269
/// \brief Constants for selecting the type of the supply constraints.
@@ -683,7 +686,7 @@ namespace lemon {
683686
/// \see resetParams(), reset()
684687
ProblemType run() {
685688
#if DEBUG_LVL>0
686-
std::cout << "OPTIMAL = " << OPTIMAL << "\nINFEASIBLE = " << INFEASIBLE << "nUNBOUNDED = " << UNBOUNDED << "\n";
689+
std::cout << "OPTIMAL = " << OPTIMAL << "\nINFEASIBLE = " << INFEASIBLE << "\nUNBOUNDED = " << UNBOUNDED << "\nMAX_ITER_REACHED" << MAX_ITER_REACHED\n";
687690
#endif
688691
689692
if (!init()) return INFEASIBLE;
@@ -941,15 +944,15 @@ namespace lemon {
941944
// Initialize internal data structures
942945
bool init() {
943946
if (_node_num == 0) return false;
944-
/*
947+
945948
// Check the sum of supply values
946949
_sum_supply = 0;
947950
for (int i = 0; i != _node_num; ++i) {
948951
_sum_supply += _supply[i];
949952
}
950-
if ( !((_stype == GEQ && _sum_supply <= _epsilon ) ||
951-
(_stype == LEQ && _sum_supply >= -_epsilon )) ) return false;
952-
*/
953+
if ( fabs(_sum_supply) > _EPSILON ) return false;
954+
955+
_sum_supply = 0;
953956
954957
// Initialize artifical cost
955958
Cost ART_COST;
@@ -1416,13 +1419,11 @@ namespace lemon {
14161419
ProblemType start() {
14171420
PivotRuleImpl pivot(*this);
14181421
double prevCost=-1;
1422+
ProblemType retVal = OPTIMAL;
14191423
14201424
// Perform heuristic initial pivots
14211425
if (!initialPivots()) return UNBOUNDED;
14221426
1423-
#if DEBUG_LVL>0
1424-
int niter=0;
1425-
#endif
14261427
int iter_number=0;
14271428
//pivot.setDantzig(true);
14281429
// Execute the Network Simplex algorithm
@@ -1431,12 +1432,13 @@ namespace lemon {
14311432
char errMess[1000];
14321433
sprintf( errMess, "RESULT MIGHT BE INACURATE\nMax number of iteration reached, currently \%d. Sometimes iterations go on in cycle even though the solution has been reached, to check if it's the case here have a look at the minimal reduced cost. If it is very close to machine precision, you might actually have the correct solution, if not try setting the maximum number of iterations a bit higher\n",iter_number );
14331434
std::cerr << errMess;
1435+
retVal = MAX_ITER_REACHED;
14341436
break;
14351437
}
14361438
#if DEBUG_LVL>0
1437-
if(niter>MAX_DEBUG_ITER)
1439+
if(iter_number>MAX_DEBUG_ITER)
14381440
break;
1439-
if(++niter%1000==0||niter%1000==1){
1441+
if(iter_number%1000==0||iter_number%1000==1){
14401442
double curCost=totalCost();
14411443
double sumFlow=0;
14421444
double a;
@@ -1445,7 +1447,7 @@ namespace lemon {
14451447
for (int i=0; i<_flow.size(); i++) {
14461448
sumFlow+=_state[i]*_flow[i];
14471449
}
1448-
std::cout << "Sum of the flow " << std::setprecision(20) << sumFlow << "\n" << niter << " iterations, current cost=" << curCost << "\nReduced cost=" << _state[in_arc] * (_cost[in_arc] + _pi[_source[in_arc]] -_pi[_target[in_arc]]) << "\nPrecision = "<< -EPSILON*(a) << "\n";
1450+
std::cout << "Sum of the flow " << std::setprecision(20) << sumFlow << "\n" << iter_number << " iterations, current cost=" << curCost << "\nReduced cost=" << _state[in_arc] * (_cost[in_arc] + _pi[_source[in_arc]] -_pi[_target[in_arc]]) << "\nPrecision = "<< -EPSILON*(a) << "\n";
14491451
std::cout << "Arc in = (" << _node_id(_source[in_arc]) << ", " << _node_id(_target[in_arc]) <<")\n";
14501452
std::cout << "Supplies = (" << _supply[_source[in_arc]] << ", " << _supply[_target[in_arc]] << ")\n";
14511453
std::cout << _cost[in_arc] << "\n";
@@ -1503,15 +1505,17 @@ namespace lemon {
15031505
std::cout << "Sum of the flow " << sumFlow << "\n"<< niter <<" iterations, current cost=" << totalCost() << "\n";
15041506
#endif
15051507
// Check feasibility
1506-
for (int e = _search_arc_num; e != _all_arc_num; ++e) {
1507-
if (_flow[e] != 0){
1508-
if (abs(_flow[e]) > EPSILON)
1509-
return INFEASIBLE;
1510-
else
1511-
_flow[e]=0;
1508+
if( retVal == OPTIMAL){
1509+
for (int e = _search_arc_num; e != _all_arc_num; ++e) {
1510+
if (_flow[e] != 0){
1511+
if (abs(_flow[e]) > EPSILON)
1512+
return INFEASIBLE;
1513+
else
1514+
_flow[e]=0;
15121515
1516+
}
15131517
}
1514-
}
1518+
}
15151519
15161520
// Shift potentials to meet the requirements of the GEQ/LEQ type
15171521
// optimality conditions
@@ -1537,7 +1541,7 @@ namespace lemon {
15371541
}
15381542
}
15391543
1540-
return OPTIMAL;
1544+
return retVal;
15411545
}
15421546
15431547
}; //class NetworkSimplexSimple

test/test_ot.py

Lines changed: 50 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88

99
import ot
1010
from ot.datasets import get_1D_gauss as gauss
11+
import warnings
1112

1213

1314
def test_doctest():
@@ -100,9 +101,56 @@ def test_emd2_multi():
100101
np.testing.assert_allclose(emd1, emdn)
101102

102103

103-
def test_dual_variables():
104-
# %% parameters
104+
def test_warnings():
105+
n = 100 # nb bins
106+
m = 100 # nb bins
107+
108+
mean1 = 30
109+
mean2 = 50
110+
111+
# bin positions
112+
x = np.arange(n, dtype=np.float64)
113+
y = np.arange(m, dtype=np.float64)
114+
115+
# Gaussian distributions
116+
a = gauss(n, m=mean1, s=5) # m= mean, s= std
117+
118+
b = gauss(m, m=mean2, s=10)
105119

120+
# loss matrix
121+
M = ot.dist(x.reshape((-1, 1)), y.reshape((-1, 1))) ** (1. / 2)
122+
# M/=M.max()
123+
124+
# %%
125+
126+
print('Computing {} EMD '.format(1))
127+
G, alpha, beta = ot.emd(a, b, M, dual_variables=True)
128+
with warnings.catch_warnings(record=True) as w:
129+
# Cause all warnings to always be triggered.
130+
warnings.simplefilter("always")
131+
# Trigger a warning.
132+
print('Computing {} EMD '.format(1))
133+
G, alpha, beta = ot.emd(a, b, M, dual_variables=True, numItermax=1)
134+
# Verify some things
135+
assert "numItermax" in str(w[-1].message)
136+
assert len(w) == 1
137+
# Trigger a warning.
138+
a[0]=100
139+
print('Computing {} EMD '.format(2))
140+
G, alpha, beta = ot.emd(a, b, M, dual_variables=True)
141+
# Verify some things
142+
assert "infeasible" in str(w[-1].message)
143+
assert len(w) == 2
144+
# Trigger a warning.
145+
a[0]=-1
146+
print('Computing {} EMD '.format(2))
147+
G, alpha, beta = ot.emd(a, b, M, dual_variables=True)
148+
# Verify some things
149+
assert "infeasible" in str(w[-1].message)
150+
assert len(w) == 3
151+
152+
153+
def test_dual_variables():
106154
n = 5000 # nb bins
107155
m = 6000 # nb bins
108156

0 commit comments

Comments
 (0)