Skip to content

Commit c5039bc

Browse files
authored
Merge pull request #109 from rflamary/sparse_emd
[MRG] Sparse emd solution
2 parents bbd8f20 + e9954bb commit c5039bc

File tree

7 files changed

+300
-18
lines changed

7 files changed

+300
-18
lines changed

.gitignore

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -106,3 +106,15 @@ ENV/
106106

107107
# coverage output folder
108108
cov_html/
109+
110+
docs/source/modules/generated/*
111+
docs/source/_build/*
112+
113+
# local debug folder
114+
debug
115+
116+
# vscode parameters
117+
.vscode
118+
119+
# pytest cahche
120+
.pytest_cache

ot/lp/EMD.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,4 +32,9 @@ enum ProblemType {
3232

3333
int EMD_wrap(int n1,int n2, double *X, double *Y,double *D, double *G, double* alpha, double* beta, double *cost, int maxIter);
3434

35+
int EMD_wrap_return_sparse(int n1, int n2, double *X, double *Y, double *D,
36+
long *iG, long *jG, double *G, long * nG,
37+
double* alpha, double* beta, double *cost, int maxIter);
38+
39+
3540
#endif

ot/lp/EMD_wrapper.cpp

Lines changed: 187 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,13 +17,13 @@
1717

1818
int EMD_wrap(int n1, int n2, double *X, double *Y, double *D, double *G,
1919
double* alpha, double* beta, double *cost, int maxIter) {
20-
// beware M and C anre strored in row major C style!!!
21-
int n, m, i, cur;
20+
// beware M and C anre strored in row major C style!!!
21+
int n, m, i, cur;
2222

2323
typedef FullBipartiteDigraph Digraph;
24-
DIGRAPH_TYPEDEFS(FullBipartiteDigraph);
24+
DIGRAPH_TYPEDEFS(FullBipartiteDigraph);
2525

26-
// Get the number of non zero coordinates for r and c
26+
// Get the number of non zero coordinates for r and c
2727
n=0;
2828
for (int i=0; i<n1; i++) {
2929
double val=*(X+i);
@@ -105,3 +105,186 @@ int EMD_wrap(int n1, int n2, double *X, double *Y, double *D, double *G,
105105

106106
return ret;
107107
}
108+
109+
110+
int EMD_wrap_return_sparse(int n1, int n2, double *X, double *Y, double *D,
111+
long *iG, long *jG, double *G, long * nG,
112+
double* alpha, double* beta, double *cost, int maxIter) {
113+
// beware M and C anre strored in row major C style!!!
114+
115+
// Get the number of non zero coordinates for r and c and vectors
116+
int n, m, i, cur;
117+
118+
typedef FullBipartiteDigraph Digraph;
119+
DIGRAPH_TYPEDEFS(FullBipartiteDigraph);
120+
121+
// Get the number of non zero coordinates for r and c
122+
n=0;
123+
for (int i=0; i<n1; i++) {
124+
double val=*(X+i);
125+
if (val>0) {
126+
n++;
127+
}else if(val<0){
128+
return INFEASIBLE;
129+
}
130+
}
131+
m=0;
132+
for (int i=0; i<n2; i++) {
133+
double val=*(Y+i);
134+
if (val>0) {
135+
m++;
136+
}else if(val<0){
137+
return INFEASIBLE;
138+
}
139+
}
140+
141+
// Define the graph
142+
143+
std::vector<int> indI(n), indJ(m);
144+
std::vector<double> weights1(n), weights2(m);
145+
Digraph di(n, m);
146+
NetworkSimplexSimple<Digraph,double,double, node_id_type> net(di, true, n+m, n*m, maxIter);
147+
148+
// Set supply and demand, don't account for 0 values (faster)
149+
150+
cur=0;
151+
for (int i=0; i<n1; i++) {
152+
double val=*(X+i);
153+
if (val>0) {
154+
weights1[ cur ] = val;
155+
indI[cur++]=i;
156+
}
157+
}
158+
159+
// Demand is actually negative supply...
160+
161+
cur=0;
162+
for (int i=0; i<n2; i++) {
163+
double val=*(Y+i);
164+
if (val>0) {
165+
weights2[ cur ] = -val;
166+
indJ[cur++]=i;
167+
}
168+
}
169+
170+
// Define the graph
171+
net.supplyMap(&weights1[0], n, &weights2[0], m);
172+
173+
// Set the cost of each edge
174+
for (int i=0; i<n; i++) {
175+
for (int j=0; j<m; j++) {
176+
double val=*(D+indI[i]*n2+indJ[j]);
177+
net.setCost(di.arcFromId(i*m+j), val);
178+
}
179+
}
180+
181+
182+
// Solve the problem with the network simplex algorithm
183+
184+
int ret=net.run();
185+
if (ret==(int)net.OPTIMAL || ret==(int)net.MAX_ITER_REACHED) {
186+
*cost = 0;
187+
Arc a; di.first(a);
188+
cur=0;
189+
for (; a != INVALID; di.next(a)) {
190+
int i = di.source(a);
191+
int j = di.target(a);
192+
double flow = net.flow(a);
193+
if (flow>0)
194+
{
195+
*cost += flow * (*(D+indI[i]*n2+indJ[j-n]));
196+
197+
*(G+cur) = flow;
198+
*(iG+cur) = indI[i];
199+
*(jG+cur) = indJ[j-n];
200+
*(alpha + indI[i]) = -net.potential(i);
201+
*(beta + indJ[j-n]) = net.potential(j);
202+
cur++;
203+
}
204+
}
205+
*nG=cur; // nb of value +1 for numpy indexing
206+
207+
}
208+
209+
210+
return ret;
211+
}
212+
213+
int EMD_wrap_all_sparse(int n1, int n2, double *X, double *Y,
214+
long *iD, long *jD, double *D, long nD,
215+
long *iG, long *jG, double *G, long * nG,
216+
double* alpha, double* beta, double *cost, int maxIter) {
217+
// beware M and C anre strored in row major C style!!!
218+
219+
// Get the number of non zero coordinates for r and c and vectors
220+
int n, m, cur;
221+
222+
typedef FullBipartiteDigraph Digraph;
223+
DIGRAPH_TYPEDEFS(FullBipartiteDigraph);
224+
225+
n=n1;
226+
m=n2;
227+
228+
229+
// Define the graph
230+
231+
232+
std::vector<double> weights2(m);
233+
Digraph di(n, m);
234+
NetworkSimplexSimple<Digraph,double,double, node_id_type> net(di, true, n+m, n*m, maxIter);
235+
236+
// Set supply and demand, don't account for 0 values (faster)
237+
238+
239+
// Demand is actually negative supply...
240+
241+
cur=0;
242+
for (int i=0; i<n2; i++) {
243+
double val=*(Y+i);
244+
if (val>0) {
245+
weights2[ cur ] = -val;
246+
}
247+
}
248+
249+
// Define the graph
250+
net.supplyMap(X, n, &weights2[0], m);
251+
252+
// Set the cost of each edge
253+
for (int k=0; k<nD; k++) {
254+
int i = iD[k];
255+
int j = jD[k];
256+
net.setCost(di.arcFromId(i*m+j), D[k]);
257+
258+
}
259+
260+
261+
// Solve the problem with the network simplex algorithm
262+
263+
int ret=net.run();
264+
if (ret==(int)net.OPTIMAL || ret==(int)net.MAX_ITER_REACHED) {
265+
*cost = net.totalCost();
266+
Arc a; di.first(a);
267+
cur=0;
268+
for (; a != INVALID; di.next(a)) {
269+
int i = di.source(a);
270+
int j = di.target(a);
271+
double flow = net.flow(a);
272+
if (flow>0)
273+
{
274+
275+
*(G+cur) = flow;
276+
*(iG+cur) = i;
277+
*(jG+cur) = j-n;
278+
*(alpha + i) = -net.potential(i);
279+
*(beta + j-n) = net.potential(j);
280+
cur++;
281+
}
282+
}
283+
*nG=cur; // nb of value +1 for numpy indexing
284+
285+
}
286+
287+
288+
return ret;
289+
}
290+

ot/lp/__init__.py

Lines changed: 32 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@
2727
'emd_1d', 'emd2_1d', 'wasserstein_1d']
2828

2929

30-
def emd(a, b, M, numItermax=100000, log=False):
30+
def emd(a, b, M, numItermax=100000, log=False, dense=True):
3131
r"""Solves the Earth Movers distance problem and returns the OT matrix
3232
3333
@@ -62,6 +62,10 @@ def emd(a, b, M, numItermax=100000, log=False):
6262
log: bool, optional (default=False)
6363
If True, returns a dictionary containing the cost and dual
6464
variables. Otherwise returns only the optimal transportation matrix.
65+
dense: boolean, optional (default=True)
66+
If True, returns math:`\gamma` as a dense ndarray of shape (ns, nt).
67+
Otherwise returns a sparse representation using scipy's `coo_matrix`
68+
format.
6569
6670
Returns
6771
-------
@@ -103,13 +107,19 @@ def emd(a, b, M, numItermax=100000, log=False):
103107
b = np.asarray(b, dtype=np.float64)
104108
M = np.asarray(M, dtype=np.float64)
105109

110+
106111
# if empty array given then use uniform distributions
107112
if len(a) == 0:
108113
a = np.ones((M.shape[0],), dtype=np.float64) / M.shape[0]
109114
if len(b) == 0:
110115
b = np.ones((M.shape[1],), dtype=np.float64) / M.shape[1]
111116

112-
G, cost, u, v, result_code = emd_c(a, b, M, numItermax)
117+
if dense:
118+
G, cost, u, v, result_code = emd_c(a, b, M, numItermax,dense)
119+
else:
120+
Gv, iG, jG, cost, u, v, result_code = emd_c(a, b, M, numItermax,dense)
121+
G = coo_matrix((Gv, (iG, jG)), shape=(a.shape[0], b.shape[0]))
122+
113123
result_code_string = check_result(result_code)
114124
if log:
115125
log = {}
@@ -123,7 +133,7 @@ def emd(a, b, M, numItermax=100000, log=False):
123133

124134

125135
def emd2(a, b, M, processes=multiprocessing.cpu_count(),
126-
numItermax=100000, log=False, return_matrix=False):
136+
numItermax=100000, log=False, dense=True, return_matrix=False):
127137
r"""Solves the Earth Movers distance problem and returns the loss
128138
129139
.. math::
@@ -161,6 +171,10 @@ def emd2(a, b, M, processes=multiprocessing.cpu_count(),
161171
variables. Otherwise returns only the optimal transportation cost.
162172
return_matrix: boolean, optional (default=False)
163173
If True, returns the optimal transportation matrix in the log.
174+
dense: boolean, optional (default=True)
175+
If True, returns math:`\gamma` as a dense ndarray of shape (ns, nt).
176+
Otherwise returns a sparse representation using scipy's `coo_matrix`
177+
format.
164178
165179
Returns
166180
-------
@@ -214,19 +228,30 @@ def emd2(a, b, M, processes=multiprocessing.cpu_count(),
214228

215229
if log or return_matrix:
216230
def f(b):
217-
G, cost, u, v, resultCode = emd_c(a, b, M, numItermax)
218-
result_code_string = check_result(resultCode)
231+
if dense:
232+
G, cost, u, v, result_code = emd_c(a, b, M, numItermax,dense)
233+
else:
234+
Gv, iG, jG, cost, u, v, result_code = emd_c(a, b, M, numItermax,dense)
235+
G = coo_matrix((Gv, (iG, jG)), shape=(a.shape[0], b.shape[0]))
236+
237+
result_code_string = check_result(result_code)
219238
log = {}
220239
if return_matrix:
221240
log['G'] = G
222241
log['u'] = u
223242
log['v'] = v
224243
log['warning'] = result_code_string
225-
log['result_code'] = resultCode
244+
log['result_code'] = result_code
226245
return [cost, log]
227246
else:
228247
def f(b):
229-
G, cost, u, v, result_code = emd_c(a, b, M, numItermax)
248+
if dense:
249+
G, cost, u, v, result_code = emd_c(a, b, M, numItermax,dense)
250+
else:
251+
Gv, iG, jG, cost, u, v, result_code = emd_c(a, b, M, numItermax,dense)
252+
G = coo_matrix((Gv, (iG, jG)), shape=(a.shape[0], b.shape[0]))
253+
254+
result_code_string = check_result(result_code)
230255
check_result(result_code)
231256
return cost
232257

0 commit comments

Comments
 (0)