Skip to content

Commit 1f30759

Browse files
authored
[MRG] numItermax in 64 bits in EMD solver (#380)
* Correct test_mm_convergence for cupy * Fix bug where number of iterations is limited to 2^31 * Update RELEASES.md * Replace size_t with long long * Use uint64_t instead of long long
1 parent 951209a commit 1f30759

File tree

7 files changed

+41
-34
lines changed

7 files changed

+41
-34
lines changed

RELEASES.md

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,10 @@
1010

1111
- Fixed an issue where we could not ask TorchBackend to place a random tensor on GPU
1212
(Issue #371, PR #373)
13-
- Fixed an issue where hitting iteration limits would be reported to stderr by std::cerr regardless of Python's stderr stream status.
13+
- Fixed an issue where Sinkhorn solver assumed a symmetric cost matrix (Issue #374, PR #375)
14+
- Fixed an issue where hitting iteration limits would be reported to stderr by std::cerr regardless of Python's stderr stream status (PR #377)
15+
- Fixed an issue where the metric argument in ot.dist did not allow a callable parameter (Issue #378, PR #379)
16+
- Fixed an issue where the max number of iterations in ot.emd was not allowed to go beyond 2^31 (PR #380)
1417

1518

1619
## 0.8.2

ot/lp/EMD.h

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818

1919
#include <iostream>
2020
#include <vector>
21+
#include <cstdint>
2122

2223
typedef unsigned int node_id_type;
2324

@@ -28,8 +29,8 @@ enum ProblemType {
2829
MAX_ITER_REACHED
2930
};
3031

31-
int EMD_wrap(int n1,int n2, double *X, double *Y,double *D, double *G, double* alpha, double* beta, double *cost, int maxIter);
32-
int EMD_wrap_omp(int n1,int n2, double *X, double *Y,double *D, double *G, double* alpha, double* beta, double *cost, int maxIter, int numThreads);
32+
int EMD_wrap(int n1,int n2, double *X, double *Y,double *D, double *G, double* alpha, double* beta, double *cost, uint64_t maxIter);
33+
int EMD_wrap_omp(int n1,int n2, double *X, double *Y,double *D, double *G, double* alpha, double* beta, double *cost, uint64_t maxIter, int numThreads);
3334

3435

3536

ot/lp/EMD_wrapper.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020

2121

2222
int EMD_wrap(int n1, int n2, double *X, double *Y, double *D, double *G,
23-
double* alpha, double* beta, double *cost, int maxIter) {
23+
double* alpha, double* beta, double *cost, uint64_t maxIter) {
2424
// beware M and C are stored in row major C style!!!
2525

2626
using namespace lemon;
@@ -122,7 +122,7 @@ int EMD_wrap(int n1, int n2, double *X, double *Y, double *D, double *G,
122122

123123

124124
int EMD_wrap_omp(int n1, int n2, double *X, double *Y, double *D, double *G,
125-
double* alpha, double* beta, double *cost, int maxIter, int numThreads) {
125+
double* alpha, double* beta, double *cost, uint64_t maxIter, int numThreads) {
126126
// beware M and C are stored in row major C style!!!
127127

128128
using namespace lemon_omp;

ot/lp/emd_wrap.pyx

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,13 +14,14 @@ from ..utils import dist
1414

1515
cimport cython
1616
cimport libc.math as math
17+
from libc.stdint cimport uint64_t
1718

1819
import warnings
1920

2021

2122
cdef extern from "EMD.h":
22-
int EMD_wrap(int n1,int n2, double *X, double *Y,double *D, double *G, double* alpha, double* beta, double *cost, int maxIter) nogil
23-
int EMD_wrap_omp(int n1,int n2, double *X, double *Y,double *D, double *G, double* alpha, double* beta, double *cost, int maxIter, int numThreads) nogil
23+
int EMD_wrap(int n1,int n2, double *X, double *Y,double *D, double *G, double* alpha, double* beta, double *cost, uint64_t maxIter) nogil
24+
int EMD_wrap_omp(int n1,int n2, double *X, double *Y,double *D, double *G, double* alpha, double* beta, double *cost, uint64_t maxIter, int numThreads) nogil
2425
cdef enum ProblemType: INFEASIBLE, OPTIMAL, UNBOUNDED, MAX_ITER_REACHED
2526

2627

@@ -39,7 +40,7 @@ def check_result(result_code):
3940

4041
@cython.boundscheck(False)
4142
@cython.wraparound(False)
42-
def emd_c(np.ndarray[double, ndim=1, mode="c"] a, np.ndarray[double, ndim=1, mode="c"] b, np.ndarray[double, ndim=2, mode="c"] M, int max_iter, int numThreads):
43+
def emd_c(np.ndarray[double, ndim=1, mode="c"] a, np.ndarray[double, ndim=1, mode="c"] b, np.ndarray[double, ndim=2, mode="c"] M, uint64_t max_iter, int numThreads):
4344
"""
4445
Solves the Earth Movers distance problem and returns the optimal transport matrix
4546
@@ -75,7 +76,7 @@ def emd_c(np.ndarray[double, ndim=1, mode="c"] a, np.ndarray[double, ndim=1, mod
7576
target histogram
7677
M : (ns,nt) numpy.ndarray, float64
7778
loss matrix
78-
max_iter : int
79+
max_iter : uint64_t
7980
The maximum number of iterations before stopping the optimization
8081
algorithm if it has not converged.
8182

ot/lp/network_simplex_simple.h

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -233,7 +233,7 @@ namespace lemon {
233233
/// mixed order in the internal data structure.
234234
/// In special cases, it could lead to better overall performance,
235235
/// but it is usually slower. Therefore it is disabled by default.
236-
NetworkSimplexSimple(const GR& graph, bool arc_mixing, int nbnodes, ArcsType nb_arcs, size_t maxiters) :
236+
NetworkSimplexSimple(const GR& graph, bool arc_mixing, int nbnodes, ArcsType nb_arcs, uint64_t maxiters) :
237237
_graph(graph), //_arc_id(graph),
238238
_arc_mixing(arc_mixing), _init_nb_nodes(nbnodes), _init_nb_arcs(nb_arcs),
239239
MAX(std::numeric_limits<Value>::max()),
@@ -242,7 +242,7 @@ namespace lemon {
242242
{
243243
// Reset data structures
244244
reset();
245-
max_iter=maxiters;
245+
max_iter = maxiters;
246246
}
247247

248248
/// The type of the flow amounts, capacity bounds and supply values
@@ -293,7 +293,7 @@ namespace lemon {
293293

294294
private:
295295

296-
size_t max_iter;
296+
uint64_t max_iter;
297297
TEMPLATE_DIGRAPH_TYPEDEFS(GR);
298298

299299
typedef std::vector<int> IntVector;
@@ -1427,7 +1427,7 @@ namespace lemon {
14271427
// Perform heuristic initial pivots
14281428
if (!initialPivots()) return UNBOUNDED;
14291429

1430-
size_t iter_number=0;
1430+
uint64_t iter_number = 0;
14311431
//pivot.setDantzig(true);
14321432
// Execute the Network Simplex algorithm
14331433
while (pivot.findEnteringArc()) {

ot/lp/network_simplex_simple_omp.h

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -244,7 +244,7 @@ namespace lemon_omp {
244244
/// mixed order in the internal data structure.
245245
/// In special cases, it could lead to better overall performance,
246246
/// but it is usually slower. Therefore it is disabled by default.
247-
NetworkSimplexSimple(const GR& graph, bool arc_mixing, int nbnodes, ArcsType nb_arcs, size_t maxiters = 0, int numThreads=-1) :
247+
NetworkSimplexSimple(const GR& graph, bool arc_mixing, int nbnodes, ArcsType nb_arcs, uint64_t maxiters = 0, int numThreads=-1) :
248248
_graph(graph), //_arc_id(graph),
249249
_arc_mixing(arc_mixing), _init_nb_nodes(nbnodes), _init_nb_arcs(nb_arcs),
250250
MAX(std::numeric_limits<Value>::max()),
@@ -317,7 +317,7 @@ namespace lemon_omp {
317317

318318

319319
private:
320-
size_t max_iter;
320+
uint64_t max_iter;
321321
int num_threads;
322322
TEMPLATE_DIGRAPH_TYPEDEFS(GR);
323323

@@ -1563,7 +1563,7 @@ namespace lemon_omp {
15631563
// Perform heuristic initial pivots
15641564
if (!initialPivots()) return UNBOUNDED;
15651565

1566-
size_t iter_number = 0;
1566+
uint64_t iter_number = 0;
15671567
// Execute the Network Simplex algorithm
15681568
while (pivot.findEnteringArc()) {
15691569
if ((++iter_number <= max_iter&&max_iter > 0) || max_iter<=0) {

test/test_unbalanced.py

Lines changed: 20 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -295,26 +295,27 @@ def test_mm_convergence(nx):
295295
x = rng.randn(n, 2)
296296
rng = np.random.RandomState(75)
297297
y = rng.randn(n, 2)
298-
a = ot.utils.unif(n)
299-
b = ot.utils.unif(n)
298+
a_np = ot.utils.unif(n)
299+
b_np = ot.utils.unif(n)
300300

301301
M = ot.dist(x, y)
302302
M = M / M.max()
303303
reg_m = 100
304-
a, b, M = nx.from_numpy(a, b, M)
304+
a, b, M = nx.from_numpy(a_np, b_np, M)
305305

306306
G_kl, _ = ot.unbalanced.mm_unbalanced(a, b, M, reg_m=reg_m, div='kl',
307-
verbose=True, log=True)
308-
loss_kl = nx.to_numpy(ot.unbalanced.mm_unbalanced2(
309-
a, b, M, reg_m, div='kl', verbose=True))
307+
verbose=False, log=True)
308+
loss_kl = nx.to_numpy(
309+
ot.unbalanced.mm_unbalanced2(a, b, M, reg_m, div='kl', verbose=True)
310+
)
310311
G_l2, _ = ot.unbalanced.mm_unbalanced(a, b, M, reg_m=reg_m, div='l2',
311312
verbose=False, log=True)
312313

313314
# check if the marginals come close to the true ones when large reg
314-
np.testing.assert_allclose(np.sum(nx.to_numpy(G_kl), 1), a, atol=1e-03)
315-
np.testing.assert_allclose(np.sum(nx.to_numpy(G_kl), 0), b, atol=1e-03)
316-
np.testing.assert_allclose(np.sum(nx.to_numpy(G_l2), 1), a, atol=1e-03)
317-
np.testing.assert_allclose(np.sum(nx.to_numpy(G_l2), 0), b, atol=1e-03)
315+
np.testing.assert_allclose(np.sum(nx.to_numpy(G_kl), 1), a_np, atol=1e-03)
316+
np.testing.assert_allclose(np.sum(nx.to_numpy(G_kl), 0), b_np, atol=1e-03)
317+
np.testing.assert_allclose(np.sum(nx.to_numpy(G_l2), 1), a_np, atol=1e-03)
318+
np.testing.assert_allclose(np.sum(nx.to_numpy(G_l2), 0), b_np, atol=1e-03)
318319

319320
# check if mm_unbalanced2 returns the correct loss
320321
np.testing.assert_allclose(nx.to_numpy(nx.sum(G_kl * M)), loss_kl,
@@ -324,15 +325,16 @@ def test_mm_convergence(nx):
324325
a_np, b_np = np.array([]), np.array([])
325326
a, b = nx.from_numpy(a_np, b_np)
326327

327-
G_kl_null = ot.unbalanced.mm_unbalanced(a, b, M, reg_m=reg_m, div='kl')
328-
G_l2_null = ot.unbalanced.mm_unbalanced(a, b, M, reg_m=reg_m, div='l2')
329-
np.testing.assert_allclose(G_kl_null, G_kl)
330-
np.testing.assert_allclose(G_l2_null, G_l2)
328+
G_kl_null = ot.unbalanced.mm_unbalanced(a, b, M, reg_m=reg_m, div='kl', verbose=False)
329+
G_l2_null = ot.unbalanced.mm_unbalanced(a, b, M, reg_m=reg_m, div='l2', verbose=False)
330+
np.testing.assert_allclose(nx.to_numpy(G_kl_null), nx.to_numpy(G_kl))
331+
np.testing.assert_allclose(nx.to_numpy(G_l2_null), nx.to_numpy(G_l2))
331332

332333
# test when G0 is given
333334
G0 = ot.emd(a, b, M)
335+
G0_np = nx.to_numpy(G0)
334336
reg_m = 10000
335-
G_kl = ot.unbalanced.mm_unbalanced(a, b, M, reg_m=reg_m, div='kl', G0=G0)
336-
G_l2 = ot.unbalanced.mm_unbalanced(a, b, M, reg_m=reg_m, div='l2', G0=G0)
337-
np.testing.assert_allclose(G0, G_kl, atol=1e-05)
338-
np.testing.assert_allclose(G0, G_l2, atol=1e-05)
337+
G_kl = ot.unbalanced.mm_unbalanced(a, b, M, reg_m=reg_m, div='kl', G0=G0, verbose=False)
338+
G_l2 = ot.unbalanced.mm_unbalanced(a, b, M, reg_m=reg_m, div='l2', G0=G0, verbose=False)
339+
np.testing.assert_allclose(G0_np, nx.to_numpy(G_kl), atol=1e-05)
340+
np.testing.assert_allclose(G0_np, nx.to_numpy(G_l2), atol=1e-05)

0 commit comments

Comments
 (0)