Skip to content

Commit a4afee8

Browse files
committed
first implemntation sparse loss
1 parent c439e3e commit a4afee8

File tree

4 files changed

+88
-1
lines changed

4 files changed

+88
-1
lines changed

ot/lp/EMD.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,4 +36,9 @@ int EMD_wrap_return_sparse(int n1, int n2, double *X, double *Y, double *D,
3636
long *iG, long *jG, double *G, long * nG,
3737
double* alpha, double* beta, double *cost, int maxIter);
3838

39+
int EMD_wrap_all_sparse(int n1, int n2, double *X, double *Y,
40+
long *iD, long *jD, double *D, long nD,
41+
long *iG, long *jG, double *G, long * nG,
42+
double* alpha, double* beta, double *cost, int maxIter);
43+
3944
#endif

ot/lp/EMD_wrapper.cpp

Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -210,3 +210,81 @@ int EMD_wrap_return_sparse(int n1, int n2, double *X, double *Y, double *D,
210210
return ret;
211211
}
212212

213+
int EMD_wrap_all_sparse(int n1, int n2, double *X, double *Y,
214+
long *iD, long *jD, double *D, long nD,
215+
long *iG, long *jG, double *G, long * nG,
216+
double* alpha, double* beta, double *cost, int maxIter) {
217+
// beware M and C anre strored in row major C style!!!
218+
219+
// Get the number of non zero coordinates for r and c and vectors
220+
int n, m, cur;
221+
222+
typedef FullBipartiteDigraph Digraph;
223+
DIGRAPH_TYPEDEFS(FullBipartiteDigraph);
224+
225+
n=n1;
226+
m=n2;
227+
228+
229+
// Define the graph
230+
231+
232+
std::vector<double> weights2(m);
233+
Digraph di(n, m);
234+
NetworkSimplexSimple<Digraph,double,double, node_id_type> net(di, true, n+m, n*m, maxIter);
235+
236+
// Set supply and demand, don't account for 0 values (faster)
237+
238+
239+
// Demand is actually negative supply...
240+
241+
cur=0;
242+
for (int i=0; i<n2; i++) {
243+
double val=*(Y+i);
244+
if (val>0) {
245+
weights2[ cur ] = -val;
246+
}
247+
}
248+
249+
// Define the graph
250+
net.supplyMap(X, n, &weights2[0], m);
251+
252+
// Set the cost of each edge
253+
for (int k=0; k<nD; k++) {
254+
int i = iD[k];
255+
int j = jD[k];
256+
net.setCost(di.arcFromId(i*m+j), D[k]);
257+
258+
}
259+
260+
261+
// Solve the problem with the network simplex algorithm
262+
263+
int ret=net.run();
264+
if (ret==(int)net.OPTIMAL || ret==(int)net.MAX_ITER_REACHED) {
265+
*cost = net.totalCost();
266+
Arc a; di.first(a);
267+
cur=0;
268+
for (; a != INVALID; di.next(a)) {
269+
int i = di.source(a);
270+
int j = di.target(a);
271+
double flow = net.flow(a);
272+
if (flow>0)
273+
{
274+
275+
*(G+cur) = flow;
276+
*(iG+cur) = i;
277+
*(jG+cur) = j-n;
278+
*(alpha + i) = -net.potential(i);
279+
*(beta + j-n) = net.potential(j);
280+
cur++;
281+
}
282+
}
283+
*nG=cur; // nb of value +1 for numpy indexing
284+
285+
}
286+
287+
288+
return ret;
289+
}
290+

ot/lp/emd_wrap.pyx

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,10 @@ cdef extern from "EMD.h":
2323
int EMD_wrap_return_sparse(int n1, int n2, double *X, double *Y, double *D,
2424
long *iG, long *jG, double *G, long * nG,
2525
double* alpha, double* beta, double *cost, int maxIter)
26+
int EMD_wrap_all_sparse(int n1, int n2, double *X, double *Y,
27+
long *iD, long *jD, double *D, long nD,
28+
long *iG, long *jG, double *G, long * nG,
29+
double* alpha, double* beta, double *cost, int maxIter)
2630
cdef enum ProblemType: INFEASIBLE, OPTIMAL, UNBOUNDED, MAX_ITER_REACHED
2731

2832

ot/lp/network_simplex_simple.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -686,7 +686,7 @@ namespace lemon {
686686
/// \see resetParams(), reset()
687687
ProblemType run() {
688688
#if DEBUG_LVL>0
689-
std::cout << "OPTIMAL = " << OPTIMAL << "\nINFEASIBLE = " << INFEASIBLE << "\nUNBOUNDED = " << UNBOUNDED << "\nMAX_ITER_REACHED" << MAX_ITER_REACHED\n";
689+
std::cout << "OPTIMAL = " << OPTIMAL << "\nINFEASIBLE = " << INFEASIBLE << "\nUNBOUNDED = " << UNBOUNDED << "\nMAX_ITER_REACHED" << MAX_ITER_REACHED << "\n" ;
690690
#endif
691691

692692
if (!init()) return INFEASIBLE;

0 commit comments

Comments
 (0)