Skip to content

Commit 4015474

Browse files
authored
Merge branch 'master' into emd_dimension
2 parents a9bbc2c + c5039bc commit 4015474

File tree

8 files changed

+330
-18
lines changed

8 files changed

+330
-18
lines changed
Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
name: Test Package
2+
3+
on: [push]
4+
5+
jobs:
6+
build:
7+
8+
runs-on: ubuntu-latest
9+
strategy:
10+
max-parallel: 4
11+
matrix:
12+
python-version: [2.7, 3.5, 3.6, 3.7]
13+
14+
steps:
15+
- uses: actions/checkout@v1
16+
- name: Set up Python ${{ matrix.python-version }}
17+
uses: actions/setup-python@v1
18+
with:
19+
python-version: ${{ matrix.python-version }}
20+
- name: Install dependencies
21+
run: |
22+
python -m pip install --upgrade pip
23+
pip install -r requirements.txt
24+
- name: Lint with flake8
25+
run: |
26+
pip install flake8
27+
# stop the build if there are Python syntax errors or undefined names
28+
flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics
29+
# exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide
30+
flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics

.gitignore

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -106,3 +106,15 @@ ENV/
106106

107107
# coverage output folder
108108
cov_html/
109+
110+
docs/source/modules/generated/*
111+
docs/source/_build/*
112+
113+
# local debug folder
114+
debug
115+
116+
# vscode parameters
117+
.vscode
118+
119+
# pytest cahche
120+
.pytest_cache

ot/lp/EMD.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,4 +32,9 @@ enum ProblemType {
3232

3333
int EMD_wrap(int n1,int n2, double *X, double *Y,double *D, double *G, double* alpha, double* beta, double *cost, int maxIter);
3434

35+
int EMD_wrap_return_sparse(int n1, int n2, double *X, double *Y, double *D,
36+
long *iG, long *jG, double *G, long * nG,
37+
double* alpha, double* beta, double *cost, int maxIter);
38+
39+
3540
#endif

ot/lp/EMD_wrapper.cpp

Lines changed: 187 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,13 +17,13 @@
1717

1818
int EMD_wrap(int n1, int n2, double *X, double *Y, double *D, double *G,
1919
double* alpha, double* beta, double *cost, int maxIter) {
20-
// beware M and C anre strored in row major C style!!!
21-
int n, m, i, cur;
20+
// beware M and C anre strored in row major C style!!!
21+
int n, m, i, cur;
2222

2323
typedef FullBipartiteDigraph Digraph;
24-
DIGRAPH_TYPEDEFS(FullBipartiteDigraph);
24+
DIGRAPH_TYPEDEFS(FullBipartiteDigraph);
2525

26-
// Get the number of non zero coordinates for r and c
26+
// Get the number of non zero coordinates for r and c
2727
n=0;
2828
for (int i=0; i<n1; i++) {
2929
double val=*(X+i);
@@ -105,3 +105,186 @@ int EMD_wrap(int n1, int n2, double *X, double *Y, double *D, double *G,
105105

106106
return ret;
107107
}
108+
109+
110+
int EMD_wrap_return_sparse(int n1, int n2, double *X, double *Y, double *D,
111+
long *iG, long *jG, double *G, long * nG,
112+
double* alpha, double* beta, double *cost, int maxIter) {
113+
// beware M and C anre strored in row major C style!!!
114+
115+
// Get the number of non zero coordinates for r and c and vectors
116+
int n, m, i, cur;
117+
118+
typedef FullBipartiteDigraph Digraph;
119+
DIGRAPH_TYPEDEFS(FullBipartiteDigraph);
120+
121+
// Get the number of non zero coordinates for r and c
122+
n=0;
123+
for (int i=0; i<n1; i++) {
124+
double val=*(X+i);
125+
if (val>0) {
126+
n++;
127+
}else if(val<0){
128+
return INFEASIBLE;
129+
}
130+
}
131+
m=0;
132+
for (int i=0; i<n2; i++) {
133+
double val=*(Y+i);
134+
if (val>0) {
135+
m++;
136+
}else if(val<0){
137+
return INFEASIBLE;
138+
}
139+
}
140+
141+
// Define the graph
142+
143+
std::vector<int> indI(n), indJ(m);
144+
std::vector<double> weights1(n), weights2(m);
145+
Digraph di(n, m);
146+
NetworkSimplexSimple<Digraph,double,double, node_id_type> net(di, true, n+m, n*m, maxIter);
147+
148+
// Set supply and demand, don't account for 0 values (faster)
149+
150+
cur=0;
151+
for (int i=0; i<n1; i++) {
152+
double val=*(X+i);
153+
if (val>0) {
154+
weights1[ cur ] = val;
155+
indI[cur++]=i;
156+
}
157+
}
158+
159+
// Demand is actually negative supply...
160+
161+
cur=0;
162+
for (int i=0; i<n2; i++) {
163+
double val=*(Y+i);
164+
if (val>0) {
165+
weights2[ cur ] = -val;
166+
indJ[cur++]=i;
167+
}
168+
}
169+
170+
// Define the graph
171+
net.supplyMap(&weights1[0], n, &weights2[0], m);
172+
173+
// Set the cost of each edge
174+
for (int i=0; i<n; i++) {
175+
for (int j=0; j<m; j++) {
176+
double val=*(D+indI[i]*n2+indJ[j]);
177+
net.setCost(di.arcFromId(i*m+j), val);
178+
}
179+
}
180+
181+
182+
// Solve the problem with the network simplex algorithm
183+
184+
int ret=net.run();
185+
if (ret==(int)net.OPTIMAL || ret==(int)net.MAX_ITER_REACHED) {
186+
*cost = 0;
187+
Arc a; di.first(a);
188+
cur=0;
189+
for (; a != INVALID; di.next(a)) {
190+
int i = di.source(a);
191+
int j = di.target(a);
192+
double flow = net.flow(a);
193+
if (flow>0)
194+
{
195+
*cost += flow * (*(D+indI[i]*n2+indJ[j-n]));
196+
197+
*(G+cur) = flow;
198+
*(iG+cur) = indI[i];
199+
*(jG+cur) = indJ[j-n];
200+
*(alpha + indI[i]) = -net.potential(i);
201+
*(beta + indJ[j-n]) = net.potential(j);
202+
cur++;
203+
}
204+
}
205+
*nG=cur; // nb of value +1 for numpy indexing
206+
207+
}
208+
209+
210+
return ret;
211+
}
212+
213+
int EMD_wrap_all_sparse(int n1, int n2, double *X, double *Y,
214+
long *iD, long *jD, double *D, long nD,
215+
long *iG, long *jG, double *G, long * nG,
216+
double* alpha, double* beta, double *cost, int maxIter) {
217+
// beware M and C anre strored in row major C style!!!
218+
219+
// Get the number of non zero coordinates for r and c and vectors
220+
int n, m, cur;
221+
222+
typedef FullBipartiteDigraph Digraph;
223+
DIGRAPH_TYPEDEFS(FullBipartiteDigraph);
224+
225+
n=n1;
226+
m=n2;
227+
228+
229+
// Define the graph
230+
231+
232+
std::vector<double> weights2(m);
233+
Digraph di(n, m);
234+
NetworkSimplexSimple<Digraph,double,double, node_id_type> net(di, true, n+m, n*m, maxIter);
235+
236+
// Set supply and demand, don't account for 0 values (faster)
237+
238+
239+
// Demand is actually negative supply...
240+
241+
cur=0;
242+
for (int i=0; i<n2; i++) {
243+
double val=*(Y+i);
244+
if (val>0) {
245+
weights2[ cur ] = -val;
246+
}
247+
}
248+
249+
// Define the graph
250+
net.supplyMap(X, n, &weights2[0], m);
251+
252+
// Set the cost of each edge
253+
for (int k=0; k<nD; k++) {
254+
int i = iD[k];
255+
int j = jD[k];
256+
net.setCost(di.arcFromId(i*m+j), D[k]);
257+
258+
}
259+
260+
261+
// Solve the problem with the network simplex algorithm
262+
263+
int ret=net.run();
264+
if (ret==(int)net.OPTIMAL || ret==(int)net.MAX_ITER_REACHED) {
265+
*cost = net.totalCost();
266+
Arc a; di.first(a);
267+
cur=0;
268+
for (; a != INVALID; di.next(a)) {
269+
int i = di.source(a);
270+
int j = di.target(a);
271+
double flow = net.flow(a);
272+
if (flow>0)
273+
{
274+
275+
*(G+cur) = flow;
276+
*(iG+cur) = i;
277+
*(jG+cur) = j-n;
278+
*(alpha + i) = -net.potential(i);
279+
*(beta + j-n) = net.potential(j);
280+
cur++;
281+
}
282+
}
283+
*nG=cur; // nb of value +1 for numpy indexing
284+
285+
}
286+
287+
288+
return ret;
289+
}
290+

ot/lp/__init__.py

Lines changed: 32 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@
2727
'emd_1d', 'emd2_1d', 'wasserstein_1d']
2828

2929

30-
def emd(a, b, M, numItermax=100000, log=False):
30+
def emd(a, b, M, numItermax=100000, log=False, dense=True):
3131
r"""Solves the Earth Movers distance problem and returns the OT matrix
3232
3333
@@ -62,6 +62,10 @@ def emd(a, b, M, numItermax=100000, log=False):
6262
log: bool, optional (default=False)
6363
If True, returns a dictionary containing the cost and dual
6464
variables. Otherwise returns only the optimal transportation matrix.
65+
dense: boolean, optional (default=True)
66+
If True, returns math:`\gamma` as a dense ndarray of shape (ns, nt).
67+
Otherwise returns a sparse representation using scipy's `coo_matrix`
68+
format.
6569
6670
Returns
6771
-------
@@ -103,6 +107,7 @@ def emd(a, b, M, numItermax=100000, log=False):
103107
b = np.asarray(b, dtype=np.float64)
104108
M = np.asarray(M, dtype=np.float64)
105109

110+
106111
# if empty array given then use uniform distributions
107112
if len(a) == 0:
108113
a = np.ones((M.shape[0],), dtype=np.float64) / M.shape[0]
@@ -112,7 +117,12 @@ def emd(a, b, M, numItermax=100000, log=False):
112117
assert (a.shape[0] == M.shape[0] and b.shape[0] == M.shape[1]), \
113118
"Dimension mismatch, check dimensions of M with a and b"
114119

115-
G, cost, u, v, result_code = emd_c(a, b, M, numItermax)
120+
if dense:
121+
G, cost, u, v, result_code = emd_c(a, b, M, numItermax,dense)
122+
else:
123+
Gv, iG, jG, cost, u, v, result_code = emd_c(a, b, M, numItermax,dense)
124+
G = coo_matrix((Gv, (iG, jG)), shape=(a.shape[0], b.shape[0]))
125+
116126
result_code_string = check_result(result_code)
117127
if log:
118128
log = {}
@@ -126,7 +136,7 @@ def emd(a, b, M, numItermax=100000, log=False):
126136

127137

128138
def emd2(a, b, M, processes=multiprocessing.cpu_count(),
129-
numItermax=100000, log=False, return_matrix=False):
139+
numItermax=100000, log=False, dense=True, return_matrix=False):
130140
r"""Solves the Earth Movers distance problem and returns the loss
131141
132142
.. math::
@@ -164,6 +174,10 @@ def emd2(a, b, M, processes=multiprocessing.cpu_count(),
164174
variables. Otherwise returns only the optimal transportation cost.
165175
return_matrix: boolean, optional (default=False)
166176
If True, returns the optimal transportation matrix in the log.
177+
dense: boolean, optional (default=True)
178+
If True, returns math:`\gamma` as a dense ndarray of shape (ns, nt).
179+
Otherwise returns a sparse representation using scipy's `coo_matrix`
180+
format.
167181
168182
Returns
169183
-------
@@ -220,19 +234,30 @@ def emd2(a, b, M, processes=multiprocessing.cpu_count(),
220234

221235
if log or return_matrix:
222236
def f(b):
223-
G, cost, u, v, resultCode = emd_c(a, b, M, numItermax)
224-
result_code_string = check_result(resultCode)
237+
if dense:
238+
G, cost, u, v, result_code = emd_c(a, b, M, numItermax,dense)
239+
else:
240+
Gv, iG, jG, cost, u, v, result_code = emd_c(a, b, M, numItermax,dense)
241+
G = coo_matrix((Gv, (iG, jG)), shape=(a.shape[0], b.shape[0]))
242+
243+
result_code_string = check_result(result_code)
225244
log = {}
226245
if return_matrix:
227246
log['G'] = G
228247
log['u'] = u
229248
log['v'] = v
230249
log['warning'] = result_code_string
231-
log['result_code'] = resultCode
250+
log['result_code'] = result_code
232251
return [cost, log]
233252
else:
234253
def f(b):
235-
G, cost, u, v, result_code = emd_c(a, b, M, numItermax)
254+
if dense:
255+
G, cost, u, v, result_code = emd_c(a, b, M, numItermax,dense)
256+
else:
257+
Gv, iG, jG, cost, u, v, result_code = emd_c(a, b, M, numItermax,dense)
258+
G = coo_matrix((Gv, (iG, jG)), shape=(a.shape[0], b.shape[0]))
259+
260+
result_code_string = check_result(result_code)
236261
check_result(result_code)
237262
return cost
238263

0 commit comments

Comments
 (0)