Skip to content

Commit ef12867

Browse files
ncourty and rflamary authored
[WIP] Issue with sparse emd and adding tests on macos (#158)
* First commit-warning removal
* remove dense feature
* pep8
* pep8
* EMD.h
* pep8 again
* tic toc tolerance

Co-authored-by: Rémi Flamary <remi.flamary@gmail.com>
1 parent bacb0b9 commit ef12867

File tree

8 files changed

+46
-305
lines changed

8 files changed

+46
-305
lines changed

.github/workflows/pythonpackage.yml

Lines changed: 24 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -47,31 +47,31 @@ jobs:
4747
run: |
4848
codecov
4949
50-
# macos:
51-
# runs-on: macOS-latest
52-
# strategy:
53-
# max-parallel: 4
54-
# matrix:
55-
# python-version: [3.7]
50+
macos:
51+
runs-on: macOS-latest
52+
strategy:
53+
max-parallel: 4
54+
matrix:
55+
python-version: [3.7]
5656

57-
# steps:
58-
# - uses: actions/checkout@v1
59-
# - name: Set up Python ${{ matrix.python-version }}
60-
# uses: actions/setup-python@v1
61-
# with:
62-
# python-version: ${{ matrix.python-version }}
63-
# - name: Install dependencies
64-
# run: |
65-
# python -m pip install --upgrade pip
66-
# pip install -r requirements.txt
67-
# pip install pytest "pytest-cov<2.6"
68-
# pip install -U "sklearn"
69-
# - name: Install POT
70-
# run: |
71-
# pip install -e .
72-
# - name: Run tests
73-
# run: |
74-
# python -m pytest -v test/ ot/ --doctest-modules --ignore ot/gpu/ --cov=ot
57+
steps:
58+
- uses: actions/checkout@v1
59+
- name: Set up Python ${{ matrix.python-version }}
60+
uses: actions/setup-python@v1
61+
with:
62+
python-version: ${{ matrix.python-version }}
63+
- name: Install dependencies
64+
run: |
65+
python -m pip install --upgrade pip
66+
pip install -r requirements.txt
67+
pip install pytest "pytest-cov<2.6"
68+
pip install -U "sklearn"
69+
- name: Install POT
70+
run: |
71+
pip install -e .
72+
- name: Run tests
73+
run: |
74+
python -m pytest -v test/ ot/ --doctest-modules --ignore ot/gpu/ --cov=ot
7575
7676
7777
windows:

ot/lp/EMD.h

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -32,9 +32,6 @@ enum ProblemType {
3232

3333
int EMD_wrap(int n1,int n2, double *X, double *Y,double *D, double *G, double* alpha, double* beta, double *cost, int maxIter);
3434

35-
int EMD_wrap_return_sparse(int n1, int n2, double *X, double *Y, double *D,
36-
long *iG, long *jG, double *G, long * nG,
37-
double* alpha, double* beta, double *cost, int maxIter);
3835

3936

4037
#endif

ot/lp/EMD_wrapper.cpp

Lines changed: 0 additions & 182 deletions
Original file line numberDiff line numberDiff line change
@@ -106,185 +106,3 @@ int EMD_wrap(int n1, int n2, double *X, double *Y, double *D, double *G,
106106
return ret;
107107
}
108108

109-
110-
int EMD_wrap_return_sparse(int n1, int n2, double *X, double *Y, double *D,
111-
long *iG, long *jG, double *G, long * nG,
112-
double* alpha, double* beta, double *cost, int maxIter) {
113-
// beware M and C anre strored in row major C style!!!
114-
115-
// Get the number of non zero coordinates for r and c and vectors
116-
int n, m, i, cur;
117-
118-
typedef FullBipartiteDigraph Digraph;
119-
DIGRAPH_TYPEDEFS(FullBipartiteDigraph);
120-
121-
// Get the number of non zero coordinates for r and c
122-
n=0;
123-
for (int i=0; i<n1; i++) {
124-
double val=*(X+i);
125-
if (val>0) {
126-
n++;
127-
}else if(val<0){
128-
return INFEASIBLE;
129-
}
130-
}
131-
m=0;
132-
for (int i=0; i<n2; i++) {
133-
double val=*(Y+i);
134-
if (val>0) {
135-
m++;
136-
}else if(val<0){
137-
return INFEASIBLE;
138-
}
139-
}
140-
141-
// Define the graph
142-
143-
std::vector<int> indI(n), indJ(m);
144-
std::vector<double> weights1(n), weights2(m);
145-
Digraph di(n, m);
146-
NetworkSimplexSimple<Digraph,double,double, node_id_type> net(di, true, n+m, n*m, maxIter);
147-
148-
// Set supply and demand, don't account for 0 values (faster)
149-
150-
cur=0;
151-
for (int i=0; i<n1; i++) {
152-
double val=*(X+i);
153-
if (val>0) {
154-
weights1[ cur ] = val;
155-
indI[cur++]=i;
156-
}
157-
}
158-
159-
// Demand is actually negative supply...
160-
161-
cur=0;
162-
for (int i=0; i<n2; i++) {
163-
double val=*(Y+i);
164-
if (val>0) {
165-
weights2[ cur ] = -val;
166-
indJ[cur++]=i;
167-
}
168-
}
169-
170-
// Define the graph
171-
net.supplyMap(&weights1[0], n, &weights2[0], m);
172-
173-
// Set the cost of each edge
174-
for (int i=0; i<n; i++) {
175-
for (int j=0; j<m; j++) {
176-
double val=*(D+indI[i]*n2+indJ[j]);
177-
net.setCost(di.arcFromId(i*m+j), val);
178-
}
179-
}
180-
181-
182-
// Solve the problem with the network simplex algorithm
183-
184-
int ret=net.run();
185-
if (ret==(int)net.OPTIMAL || ret==(int)net.MAX_ITER_REACHED) {
186-
*cost = 0;
187-
Arc a; di.first(a);
188-
cur=0;
189-
for (; a != INVALID; di.next(a)) {
190-
int i = di.source(a);
191-
int j = di.target(a);
192-
double flow = net.flow(a);
193-
if (flow>0)
194-
{
195-
*cost += flow * (*(D+indI[i]*n2+indJ[j-n]));
196-
197-
*(G+cur) = flow;
198-
*(iG+cur) = indI[i];
199-
*(jG+cur) = indJ[j-n];
200-
*(alpha + indI[i]) = -net.potential(i);
201-
*(beta + indJ[j-n]) = net.potential(j);
202-
cur++;
203-
}
204-
}
205-
*nG=cur; // nb of value +1 for numpy indexing
206-
207-
}
208-
209-
210-
return ret;
211-
}
212-
213-
int EMD_wrap_all_sparse(int n1, int n2, double *X, double *Y,
214-
long *iD, long *jD, double *D, long nD,
215-
long *iG, long *jG, double *G, long * nG,
216-
double* alpha, double* beta, double *cost, int maxIter) {
217-
// beware M and C anre strored in row major C style!!!
218-
219-
// Get the number of non zero coordinates for r and c and vectors
220-
int n, m, cur;
221-
222-
typedef FullBipartiteDigraph Digraph;
223-
DIGRAPH_TYPEDEFS(FullBipartiteDigraph);
224-
225-
n=n1;
226-
m=n2;
227-
228-
229-
// Define the graph
230-
231-
232-
std::vector<double> weights2(m);
233-
Digraph di(n, m);
234-
NetworkSimplexSimple<Digraph,double,double, node_id_type> net(di, true, n+m, n*m, maxIter);
235-
236-
// Set supply and demand, don't account for 0 values (faster)
237-
238-
239-
// Demand is actually negative supply...
240-
241-
cur=0;
242-
for (int i=0; i<n2; i++) {
243-
double val=*(Y+i);
244-
if (val>0) {
245-
weights2[ cur ] = -val;
246-
}
247-
}
248-
249-
// Define the graph
250-
net.supplyMap(X, n, &weights2[0], m);
251-
252-
// Set the cost of each edge
253-
for (int k=0; k<nD; k++) {
254-
int i = iD[k];
255-
int j = jD[k];
256-
net.setCost(di.arcFromId(i*m+j), D[k]);
257-
258-
}
259-
260-
261-
// Solve the problem with the network simplex algorithm
262-
263-
int ret=net.run();
264-
if (ret==(int)net.OPTIMAL || ret==(int)net.MAX_ITER_REACHED) {
265-
*cost = net.totalCost();
266-
Arc a; di.first(a);
267-
cur=0;
268-
for (; a != INVALID; di.next(a)) {
269-
int i = di.source(a);
270-
int j = di.target(a);
271-
double flow = net.flow(a);
272-
if (flow>0)
273-
{
274-
275-
*(G+cur) = flow;
276-
*(iG+cur) = i;
277-
*(jG+cur) = j-n;
278-
*(alpha + i) = -net.potential(i);
279-
*(beta + j-n) = net.potential(j);
280-
cur++;
281-
}
282-
}
283-
*nG=cur; // nb of value +1 for numpy indexing
284-
285-
}
286-
287-
288-
return ret;
289-
}
290-

ot/lp/__init__.py

Lines changed: 11 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -172,7 +172,7 @@ def estimate_dual_null_weights(alpha0, beta0, a, b, M):
172172
return center_ot_dual(alpha, beta, a, b)
173173

174174

175-
def emd(a, b, M, numItermax=100000, log=False, dense=True, center_dual=True):
175+
def emd(a, b, M, numItermax=100000, log=False, center_dual=True):
176176
r"""Solves the Earth Movers distance problem and returns the OT matrix
177177
178178
@@ -207,10 +207,6 @@ def emd(a, b, M, numItermax=100000, log=False, dense=True, center_dual=True):
207207
log: bool, optional (default=False)
208208
If True, returns a dictionary containing the cost and dual
209209
variables. Otherwise returns only the optimal transportation matrix.
210-
dense: boolean, optional (default=True)
211-
If True, returns math:`\gamma` as a dense ndarray of shape (ns, nt).
212-
Otherwise returns a sparse representation using scipy's `coo_matrix`
213-
format.
214210
center_dual: boolean, optional (default=True)
215211
If True, centers the dual potential using function
216212
:ref:`center_ot_dual`.
@@ -267,25 +263,14 @@ def emd(a, b, M, numItermax=100000, log=False, dense=True, center_dual=True):
267263
asel = a != 0
268264
bsel = b != 0
269265

270-
if dense:
271-
G, cost, u, v, result_code = emd_c(a, b, M, numItermax, dense)
272-
273-
if center_dual:
274-
u, v = center_ot_dual(u, v, a, b)
266+
G, cost, u, v, result_code = emd_c(a, b, M, numItermax)
275267

276-
if np.any(~asel) or np.any(~bsel):
277-
u, v = estimate_dual_null_weights(u, v, a, b, M)
278-
279-
else:
280-
Gv, iG, jG, cost, u, v, result_code = emd_c(a, b, M, numItermax, dense)
281-
G = coo_matrix((Gv, (iG, jG)), shape=(a.shape[0], b.shape[0]))
282-
283-
if center_dual:
284-
u, v = center_ot_dual(u, v, a, b)
285-
286-
if np.any(~asel) or np.any(~bsel):
287-
u, v = estimate_dual_null_weights(u, v, a, b, M)
268+
if center_dual:
269+
u, v = center_ot_dual(u, v, a, b)
288270

271+
if np.any(~asel) or np.any(~bsel):
272+
u, v = estimate_dual_null_weights(u, v, a, b, M)
273+
289274
result_code_string = check_result(result_code)
290275
if log:
291276
log = {}
@@ -299,7 +284,7 @@ def emd(a, b, M, numItermax=100000, log=False, dense=True, center_dual=True):
299284

300285

301286
def emd2(a, b, M, processes=multiprocessing.cpu_count(),
302-
numItermax=100000, log=False, dense=True, return_matrix=False,
287+
numItermax=100000, log=False, return_matrix=False,
303288
center_dual=True):
304289
r"""Solves the Earth Movers distance problem and returns the loss
305290
@@ -404,11 +389,8 @@ def emd2(a, b, M, processes=multiprocessing.cpu_count(),
404389
if log or return_matrix:
405390
def f(b):
406391
bsel = b != 0
407-
if dense:
408-
G, cost, u, v, result_code = emd_c(a, b, M, numItermax, dense)
409-
else:
410-
Gv, iG, jG, cost, u, v, result_code = emd_c(a, b, M, numItermax, dense)
411-
G = coo_matrix((Gv, (iG, jG)), shape=(a.shape[0], b.shape[0]))
392+
393+
G, cost, u, v, result_code = emd_c(a, b, M, numItermax)
412394

413395
if center_dual:
414396
u, v = center_ot_dual(u, v, a, b)
@@ -428,19 +410,14 @@ def f(b):
428410
else:
429411
def f(b):
430412
bsel = b != 0
431-
if dense:
432-
G, cost, u, v, result_code = emd_c(a, b, M, numItermax, dense)
433-
else:
434-
Gv, iG, jG, cost, u, v, result_code = emd_c(a, b, M, numItermax, dense)
435-
G = coo_matrix((Gv, (iG, jG)), shape=(a.shape[0], b.shape[0]))
413+
G, cost, u, v, result_code = emd_c(a, b, M, numItermax)
436414

437415
if center_dual:
438416
u, v = center_ot_dual(u, v, a, b)
439417

440418
if np.any(~asel) or np.any(~bsel):
441419
u, v = estimate_dual_null_weights(u, v, a, b, M)
442420

443-
result_code_string = check_result(result_code)
444421
check_result(result_code)
445422
return cost
446423

0 commit comments

Comments
 (0)