Skip to content

Commit 875867f

Browse files
Implement discharge
Co-authored-by: Ioannis Panagiotas <ioannis.panagiotas@neo4j.com>
1 parent 146af9e commit 875867f

File tree

5 files changed

+440
-27
lines changed

5 files changed

+440
-27
lines changed

algo/src/main/java/org/neo4j/gds/maxflow/AtomicWorkingSet.java

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
import org.neo4j.gds.core.utils.paged.HugeLongArrayQueue;
2424

2525
import java.util.concurrent.atomic.AtomicLong;
26+
import java.util.function.Consumer;
2627

2728
public class AtomicWorkingSet {
2829
private final HugeLongArray workingSet;
@@ -59,7 +60,7 @@ void push(long value) {
5960

6061
void batchPush(HugeLongArrayQueue queue) {
6162
long idx = size.getAndAdd(queue.size());
62-
while(!queue.isEmpty()) {
63+
while (!queue.isEmpty()) {
6364
var node = queue.remove();
6465
workingSet.set(idx++, node);
6566
}
@@ -82,4 +83,11 @@ long pop() {
8283
return -1L;
8384
}
8485
}
86+
87+
void consumeBatch(long from, long to, Consumer<Long> consumer) {
88+
for (long idx = from; idx < to; idx++) {
89+
long v = unsafePeek(idx);
90+
consumer.accept(v);
91+
}
92+
}
8593
}
Lines changed: 201 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,201 @@
1+
/*
2+
* Copyright (c) "Neo4j"
3+
* Neo4j Sweden AB [http://neo4j.com]
4+
*
5+
* This file is part of Neo4j.
6+
*
7+
* Neo4j is free software: you can redistribute it and/or modify
8+
* it under the terms of the GNU General Public License as published by
9+
* the Free Software Foundation, either version 3 of the License, or
10+
* (at your option) any later version.
11+
*
12+
* This program is distributed in the hope that it will be useful,
13+
* but WITHOUT ANY WARRANTY; without even the implied warranty of
14+
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
15+
* GNU General Public License for more details.
16+
*
17+
* You should have received a copy of the GNU General Public License
18+
* along with this program. If not, see <http://www.gnu.org/licenses/>.
19+
*/
20+
package org.neo4j.gds.maxflow;
21+
22+
import org.apache.commons.lang3.mutable.MutableBoolean;
23+
import org.apache.commons.lang3.mutable.MutableDouble;
24+
import org.apache.commons.lang3.mutable.MutableLong;
25+
import org.neo4j.gds.collections.ha.HugeDoubleArray;
26+
import org.neo4j.gds.collections.ha.HugeLongArray;
27+
import org.neo4j.gds.collections.haa.HugeAtomicDoubleArray;
28+
import org.neo4j.gds.core.utils.paged.HugeAtomicBitSet;
29+
import org.neo4j.gds.core.utils.paged.HugeLongArrayQueue;
30+
31+
import java.util.concurrent.atomic.AtomicLong;
32+
import java.util.function.Consumer;
33+
34+
public class DischargeTask implements Runnable {
35+
private final FlowGraph flowGraph;
36+
private final HugeDoubleArray excess;
37+
private final HugeLongArray label;
38+
private final HugeLongArray tempLabel;
39+
private final AtomicWorkingSet workingSet;
40+
private final HugeAtomicDoubleArray addedExcess;
41+
private final HugeAtomicBitSet isDiscovered;
42+
private final long targetNode;
43+
private final long beta;
44+
private final AtomicLong workSinceLastGR;
45+
46+
private final long batchSize;
47+
private final long nodeCount;
48+
private final HugeLongArrayQueue localDiscoveredVertices;
49+
private PHASE phase;
50+
private long localWork;
51+
52+
public DischargeTask(
53+
FlowGraph flowGraph,
54+
HugeDoubleArray excess,
55+
HugeLongArray label,
56+
HugeLongArray tempLabel,
57+
HugeAtomicDoubleArray addedExcess,
58+
HugeAtomicBitSet isDiscovered,
59+
AtomicWorkingSet workingSet,
60+
long targetNode,
61+
long beta,
62+
AtomicLong workSinceLastGR
63+
) {
64+
this.excess = excess;
65+
this.flowGraph = flowGraph;
66+
this.label = label;
67+
this.tempLabel = tempLabel;
68+
this.addedExcess = addedExcess;
69+
this.isDiscovered = isDiscovered;
70+
this.workingSet = workingSet;
71+
this.targetNode = targetNode;
72+
this.beta = beta;
73+
this.workSinceLastGR = workSinceLastGR;
74+
75+
this.batchSize = 8;
76+
this.phase = PHASE.DISCHARGE;
77+
78+
this.localWork = 0;
79+
this.nodeCount = flowGraph.nodeCount();
80+
this.localDiscoveredVertices = HugeLongArrayQueue.newQueue(flowGraph.nodeCount());
81+
}
82+
83+
public void run() {
84+
switch (phase) {
85+
case PHASE.DISCHARGE -> dischargeWorkingSet();
86+
case PHASE.SYNC_WORKING_SET -> syncWorkingSet();
87+
case PHASE.UPDATE_WORKING_SET -> updateWorkingSet();
88+
case PHASE.SYNC_NEW_WORKING_SET -> syncNewWorkingSet();
89+
}
90+
}
91+
92+
private void batchConsumeWorkingSet(Consumer<Long> consumer) {
93+
long oldIdx;
94+
while ((oldIdx = workingSet.getAndAdd(batchSize)) < workingSet.size()) {
95+
long toIdx = Math.min(oldIdx + batchSize, workingSet.size());
96+
workingSet.consumeBatch(oldIdx, toIdx, consumer);
97+
}
98+
}
99+
100+
private void dischargeWorkingSet() {
101+
batchConsumeWorkingSet(this::discharge);
102+
phase = PHASE.SYNC_WORKING_SET;
103+
}
104+
105+
void discharge(long v) {
106+
if (label.get(v) >= nodeCount || v == targetNode) {
107+
return;
108+
}
109+
110+
tempLabel.set(v, label.get(v));
111+
final var e = new MutableDouble(excess.get(v));
112+
113+
while (e.doubleValue() > 0) {
114+
final var newLabel = new MutableLong(nodeCount);
115+
final var breakOuter = new MutableBoolean(false);
116+
117+
//todo: Check to improve zero comparisons!
118+
119+
ResidualEdgeConsumer consumer = (long s, long t, long relIdx, double residualCapacity, boolean isReverse) -> {
120+
if (residualCapacity <= 0.0) {
121+
return true; //skip
122+
}
123+
var admissible = (tempLabel.get(s) == label.get(t) + 1);
124+
if (admissible) {
125+
if (excess.get(t) > 0.0) {
126+
boolean win = (label.get(s) == label.get(t) + 1 || label.get(s) + 1 < label.get(t) || (label.get(
127+
s) == label.get(t) && s < t));
128+
if (!win) {
129+
breakOuter.setTrue();
130+
return true;
131+
}
132+
}
133+
var delta = Math.min(e.doubleValue(), residualCapacity);
134+
flowGraph.push(relIdx, delta, isReverse);
135+
e.subtract(delta);
136+
addedExcess.getAndAdd(t, delta);
137+
138+
if (!isDiscovered.getAndSet(t)) {
139+
localDiscoveredVertices.add(t);
140+
}
141+
if (e.doubleValue() <= 0.0) {
142+
breakOuter.setTrue();
143+
return false;
144+
}
145+
} else if (label.get(t) >= tempLabel.get(s)) {
146+
newLabel.setValue(Math.min(newLabel.longValue(), label.get(t) + 1));
147+
//if ws are sorted by label ascendingly, then later values will be neither better(lower) than this, nor admissible -> return false; (break)
148+
}
149+
return true;
150+
};
151+
152+
flowGraph.forEachRelationship(v, consumer);
153+
154+
if (breakOuter.isTrue()) {
155+
break;
156+
}
157+
tempLabel.set(v, newLabel.longValue());
158+
localWork += flowGraph.outDegree(v) + beta;
159+
if (tempLabel.get(v) == nodeCount) {
160+
break;
161+
}
162+
}
163+
addedExcess.getAndAdd(v, (e.doubleValue() - excess.get(v)));
164+
if (e.doubleValue() > 0.0 && !isDiscovered.getAndSet(v)) {
165+
localDiscoveredVertices.add(v);
166+
}
167+
}
168+
169+
void syncWorkingSet() {
170+
batchConsumeWorkingSet((v) -> {
171+
label.set(v, tempLabel.get(v));
172+
excess.addTo(v, addedExcess.get(v));
173+
addedExcess.set(v, 0);
174+
isDiscovered.clear(v);
175+
});
176+
workSinceLastGR.addAndGet(localWork);
177+
localWork = 0;
178+
phase = PHASE.UPDATE_WORKING_SET;
179+
}
180+
181+
void updateWorkingSet() {
182+
workingSet.batchPush(localDiscoveredVertices);
183+
phase = PHASE.SYNC_NEW_WORKING_SET;
184+
}
185+
186+
void syncNewWorkingSet() {
187+
batchConsumeWorkingSet((v) -> {
188+
excess.addTo(v, addedExcess.get(v));
189+
addedExcess.set(v, 0);
190+
isDiscovered.clear(v);
191+
});
192+
phase = PHASE.UPDATE_WORKING_SET;
193+
}
194+
195+
enum PHASE {
196+
DISCHARGE,
197+
SYNC_WORKING_SET,
198+
UPDATE_WORKING_SET,
199+
SYNC_NEW_WORKING_SET
200+
}
201+
}
Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,77 @@
1+
/*
2+
* Copyright (c) "Neo4j"
3+
* Neo4j Sweden AB [http://neo4j.com]
4+
*
5+
* This file is part of Neo4j.
6+
*
7+
* Neo4j is free software: you can redistribute it and/or modify
8+
* it under the terms of the GNU General Public License as published by
9+
* the Free Software Foundation, either version 3 of the License, or
10+
* (at your option) any later version.
11+
*
12+
* This program is distributed in the hope that it will be useful,
13+
* but WITHOUT ANY WARRANTY; without even the implied warranty of
14+
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
15+
* GNU General Public License for more details.
16+
*
17+
* You should have received a copy of the GNU General Public License
18+
* along with this program. If not, see <http://www.gnu.org/licenses/>.
19+
*/
20+
package org.neo4j.gds.maxflow;
21+
22+
import org.neo4j.gds.collections.ha.HugeDoubleArray;
23+
import org.neo4j.gds.collections.ha.HugeLongArray;
24+
import org.neo4j.gds.collections.haa.HugeAtomicDoubleArray;
25+
import org.neo4j.gds.core.concurrency.Concurrency;
26+
import org.neo4j.gds.core.concurrency.ParallelUtil;
27+
import org.neo4j.gds.core.concurrency.RunWithConcurrency;
28+
import org.neo4j.gds.core.utils.paged.HugeAtomicBitSet;
29+
30+
import java.util.concurrent.atomic.AtomicLong;
31+
32+
public class Discharging {
33+
public static void processWorkingSet(
34+
FlowGraph flowGraph,
35+
HugeDoubleArray excess,
36+
HugeLongArray label,
37+
HugeLongArray tempLabel,
38+
HugeAtomicDoubleArray addedExcess,
39+
HugeAtomicBitSet isDiscovered,
40+
AtomicWorkingSet workingSet,
41+
long targetNode,
42+
long beta,
43+
AtomicLong workSinceLastGR,
44+
Concurrency concurrency
45+
) {
46+
var dischargeTasks = ParallelUtil.tasks(
47+
concurrency,
48+
() -> new DischargeTask(
49+
flowGraph.concurrentCopy(),
50+
excess,
51+
label,
52+
tempLabel,
53+
addedExcess,
54+
isDiscovered,
55+
workingSet,
56+
targetNode,
57+
beta,
58+
workSinceLastGR
59+
)
60+
);
61+
62+
//Discharge working set
63+
RunWithConcurrency.builder().concurrency(concurrency).tasks(dischargeTasks).build().run();
64+
workingSet.resetIdx();
65+
66+
//Sync working set
67+
RunWithConcurrency.builder().concurrency(concurrency).tasks(dischargeTasks).build().run();
68+
workingSet.reset();
69+
70+
//Update working set
71+
RunWithConcurrency.builder().concurrency(concurrency).tasks(dischargeTasks).build().run();
72+
73+
//Sync new working set
74+
RunWithConcurrency.builder().concurrency(concurrency).tasks(dischargeTasks).build().run();
75+
workingSet.resetIdx();
76+
}
77+
}

0 commit comments

Comments
 (0)