Skip to content

Commit 7429c08

Browse files
Reuse tasks
Co-authored-by: Ioannis Panagiotas <ioannis.panagiotas@neo4j.com>
1 parent 4f277a9 commit 7429c08

File tree

6 files changed

+158
-54
lines changed

6 files changed

+158
-54
lines changed

algo/src/main/java/org/neo4j/gds/maxflow/DischargeTask.java

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,8 @@ public class DischargeTask implements Runnable {
4949
private PHASE phase;
5050
private long localWork;
5151

52+
53+
5254
public DischargeTask(
5355
FlowGraph flowGraph,
5456
HugeDoubleArray excess,
@@ -59,7 +61,8 @@ public DischargeTask(
5961
AtomicWorkingSet workingSet,
6062
long targetNode,
6163
long beta,
62-
AtomicLong workSinceLastGR
64+
AtomicLong workSinceLastGR,
65+
HugeLongArrayQueue localDiscoveredVertices
6366
) {
6467
this.excess = excess;
6568
this.flowGraph = flowGraph;
@@ -77,7 +80,7 @@ public DischargeTask(
7780

7881
this.localWork = 0;
7982
this.nodeCount = flowGraph.nodeCount();
80-
this.localDiscoveredVertices = HugeLongArrayQueue.newQueue(flowGraph.nodeCount());
83+
this.localDiscoveredVertices = localDiscoveredVertices;
8184
}
8285

8386
public void run() {

algo/src/main/java/org/neo4j/gds/maxflow/Discharging.java

Lines changed: 30 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -23,14 +23,21 @@
2323
import org.neo4j.gds.collections.ha.HugeLongArray;
2424
import org.neo4j.gds.collections.haa.HugeAtomicDoubleArray;
2525
import org.neo4j.gds.core.concurrency.Concurrency;
26-
import org.neo4j.gds.core.concurrency.ParallelUtil;
2726
import org.neo4j.gds.core.concurrency.RunWithConcurrency;
2827
import org.neo4j.gds.core.utils.paged.HugeAtomicBitSet;
28+
import org.neo4j.gds.core.utils.paged.HugeLongArrayQueue;
2929

30+
import java.util.ArrayList;
31+
import java.util.Collection;
32+
import java.util.List;
3033
import java.util.concurrent.atomic.AtomicLong;
3134

3235
public class Discharging {
33-
public static void processWorkingSet(
36+
private final AtomicWorkingSet workingSet;
37+
private final Concurrency concurrency;
38+
private final Collection<Runnable> dischargeTasks;
39+
40+
public static Discharging createDischarging(
3441
FlowGraph flowGraph,
3542
HugeDoubleArray excess,
3643
HugeLongArray label,
@@ -41,14 +48,13 @@ public static void processWorkingSet(
4148
long targetNode,
4249
long beta,
4350
AtomicLong workSinceLastGR,
44-
Concurrency concurrency
45-
) {
46-
47-
//Todo: Refactor so that dischargeTasks can be reused between iterations.
48-
49-
var dischargeTasks = ParallelUtil.tasks(
50-
concurrency,
51-
() -> new DischargeTask(
51+
Concurrency concurrency,
52+
HugeLongArrayQueue[] threadQueues
53+
)
54+
{
55+
List<Runnable> dischargeTasks = new ArrayList<>();
56+
for (int i = 0; i < concurrency.value(); i++) {
57+
dischargeTasks.add(new DischargeTask(
5258
flowGraph.concurrentCopy(),
5359
excess,
5460
label,
@@ -58,10 +64,21 @@ public static void processWorkingSet(
5864
workingSet,
5965
targetNode,
6066
beta,
61-
workSinceLastGR
62-
)
63-
);
67+
workSinceLastGR,
68+
threadQueues[i]
69+
));
70+
}
71+
72+
return new Discharging(workingSet, dischargeTasks, concurrency);
73+
}
74+
75+
private Discharging(AtomicWorkingSet workingSet, Collection<Runnable> dischargeTasks, Concurrency concurrency) {
76+
this.workingSet = workingSet;
77+
this.dischargeTasks = dischargeTasks;
78+
this.concurrency = concurrency;
79+
}
6480

81+
public void processWorkingSet() {
6582
//Discharge working set
6683
RunWithConcurrency.builder().concurrency(concurrency).tasks(dischargeTasks).build().run();
6784
workingSet.resetIdx();

algo/src/main/java/org/neo4j/gds/maxflow/GlobalRelabeling.java

Lines changed: 49 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -21,45 +21,74 @@
2121

2222
import org.neo4j.gds.collections.ha.HugeLongArray;
2323
import org.neo4j.gds.core.concurrency.Concurrency;
24-
import org.neo4j.gds.core.concurrency.ParallelUtil;
2524
import org.neo4j.gds.core.concurrency.RunWithConcurrency;
2625
import org.neo4j.gds.core.utils.paged.HugeAtomicBitSet;
2726
import org.neo4j.gds.core.utils.paged.HugeLongArrayQueue;
2827

28+
import java.util.ArrayList;
29+
import java.util.Collection;
30+
import java.util.List;
31+
2932
enum Phase {
3033
TRAVERSE,
3134
SYNC
3235
}
3336

3437
public class GlobalRelabeling {
35-
public static void globalRelabeling(
38+
private final long nodeCount;
39+
private final HugeLongArray label;
40+
private final AtomicWorkingSet frontier;
41+
private final HugeAtomicBitSet vertexIsDiscovered;
42+
private final long source;
43+
private final long target;
44+
private final Concurrency concurrency;
45+
private final Collection<Runnable> globalRelabelingTasks;
46+
47+
static GlobalRelabeling createRelabeling(
3648
FlowGraph flowGraph,
3749
HugeLongArray label,
3850
long source,
3951
long target,
40-
Concurrency concurrency
52+
Concurrency concurrency,
53+
HugeLongArrayQueue[] threadQueues
4154
) {
42-
label.setAll((i) -> flowGraph.nodeCount());
43-
label.set(target, 0L);
4455
var vertexIsDiscovered = HugeAtomicBitSet.create(flowGraph.nodeCount());
45-
4656
var frontier = new AtomicWorkingSet(flowGraph.nodeCount());
47-
frontier.push(target);
48-
vertexIsDiscovered.set(target);
49-
vertexIsDiscovered.set(source);
5057

58+
List<Runnable> globalRelabelingTasks = new ArrayList<>();
59+
for (int i = 0; i < concurrency.value(); i++) {
60+
globalRelabelingTasks.add(new GlobalRelabellingBFSTask(flowGraph.concurrentCopy(), frontier, vertexIsDiscovered, label, threadQueues[i]));
61+
}
5162

52-
var tasks = ParallelUtil.tasks(
53-
concurrency,
54-
() -> new GlobalRelabellingBFSTask(flowGraph.concurrentCopy(), frontier, vertexIsDiscovered, label)
55-
);
63+
return new GlobalRelabeling(flowGraph.nodeCount(), label, frontier, vertexIsDiscovered, source, target, concurrency, globalRelabelingTasks);
64+
}
65+
66+
private GlobalRelabeling(long nodeCount, HugeLongArray label, AtomicWorkingSet frontier, HugeAtomicBitSet vertexIsDiscovered, long source, long target, Concurrency concurrency, Collection<Runnable> globalRelabelingTasks) {
67+
this.nodeCount = nodeCount;
68+
this.label = label;
69+
this.frontier = frontier;
70+
this.vertexIsDiscovered = vertexIsDiscovered;
71+
this.source = source;
72+
this.target = target;
73+
this.concurrency = concurrency;
74+
this.globalRelabelingTasks = globalRelabelingTasks;
75+
}
5676

77+
public void globalRelabeling() {
78+
label.setAll((i) -> nodeCount);
79+
label.set(target, 0L);
80+
frontier.reset();
81+
frontier.push(target);
82+
vertexIsDiscovered.clear();
83+
vertexIsDiscovered.set(source);
84+
vertexIsDiscovered.set(target);
5785
while (!frontier.isEmpty()) {
58-
RunWithConcurrency.builder().concurrency(concurrency).tasks(tasks).build().run();
86+
RunWithConcurrency.builder().concurrency(concurrency).tasks(globalRelabelingTasks).build().run();
5987
frontier.reset();
60-
RunWithConcurrency.builder().concurrency(concurrency).tasks(tasks).build().run();
88+
RunWithConcurrency.builder().concurrency(concurrency).tasks(globalRelabelingTasks).build().run();
6189
}
62-
label.set(source, flowGraph.nodeCount());
90+
label.set(source, nodeCount);
91+
6392
}
6493
}
6594

@@ -77,11 +106,13 @@ class GlobalRelabellingBFSTask implements Runnable {
77106
FlowGraph flowGraph,
78107
AtomicWorkingSet frontier,
79108
HugeAtomicBitSet vertexIsDiscovered,
80-
HugeLongArray label
109+
HugeLongArray label,
110+
HugeLongArrayQueue localDiscoveredVertices
81111
) {
82112
this.flowGraph = flowGraph;
83113
this.frontier = frontier;
84-
this.localDiscoveredVertices = HugeLongArrayQueue.newQueue(flowGraph.nodeCount()); //think
114+
// this.localDiscoveredVertices = HugeLongArrayQueue.newQueue(flowGraph.nodeCount()); //think //fixme: Don't allocate every time
115+
this.localDiscoveredVertices = localDiscoveredVertices;
85116
this.verticesIsDiscovered = vertexIsDiscovered;
86117
this.label = label;
87118
this.phase = Phase.TRAVERSE;

algo/src/main/java/org/neo4j/gds/maxflow/MaxFlow.java

Lines changed: 32 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
import org.neo4j.gds.collections.ha.HugeLongArray;
2626
import org.neo4j.gds.collections.haa.HugeAtomicDoubleArray;
2727
import org.neo4j.gds.core.utils.paged.HugeAtomicBitSet;
28+
import org.neo4j.gds.core.utils.paged.HugeLongArrayQueue;
2829
import org.neo4j.gds.core.utils.paged.ParallelDoublePageCreator;
2930
import org.neo4j.gds.core.utils.progress.tasks.ProgressTracker;
3031
import org.neo4j.gds.termination.TerminationFlag;
@@ -127,24 +128,41 @@ private void maximizeFlow(Preflow preflow, long sourceNode, long targetNode) { /
127128

128129
var workSinceLastGR = new AtomicLong(Long.MAX_VALUE);
129130

131+
HugeLongArrayQueue[] threadQueues = new HugeLongArrayQueue[parameters.concurrency().value()];
132+
for (int i = 0; i < threadQueues.length; i++) {
133+
threadQueues[i] = HugeLongArrayQueue.newQueue(nodeCount);
134+
}
135+
136+
var globalRelabeling = GlobalRelabeling.createRelabeling(
137+
flowGraph,
138+
label,
139+
sourceNode,
140+
targetNode,
141+
parameters.concurrency(),
142+
threadQueues
143+
);
144+
145+
var discharging = Discharging.createDischarging(
146+
flowGraph,
147+
excess,
148+
label,
149+
tempLabel,
150+
addedExcess,
151+
isDiscovered,
152+
workingSet,
153+
targetNode,
154+
parameters.beta(),
155+
workSinceLastGR,
156+
parameters.concurrency(),
157+
threadQueues
158+
);
159+
130160
while (!workingSet.isEmpty()) {
131161
if (parameters.freq() * workSinceLastGR.doubleValue() > parameters.alpha() * nodeCount + edgeCount) {
132-
GlobalRelabeling.globalRelabeling(flowGraph, label, sourceNode, targetNode, parameters.concurrency());
162+
globalRelabeling.globalRelabeling();
133163
workSinceLastGR.set(0L);
134164
}
135-
Discharging.processWorkingSet(
136-
flowGraph,
137-
excess,
138-
label,
139-
tempLabel,
140-
addedExcess,
141-
isDiscovered,
142-
workingSet,
143-
targetNode,
144-
parameters.beta(),
145-
workSinceLastGR,
146-
parameters.concurrency()
147-
);
165+
discharging.processWorkingSet();
148166
}
149167
}
150168
}

algo/src/test/java/org/neo4j/gds/maxflow/DischargeTaskTest.java

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
import org.neo4j.gds.collections.haa.HugeAtomicDoubleArray;
2626
import org.neo4j.gds.core.concurrency.Concurrency;
2727
import org.neo4j.gds.core.utils.paged.HugeAtomicBitSet;
28+
import org.neo4j.gds.core.utils.paged.HugeLongArrayQueue;
2829
import org.neo4j.gds.core.utils.paged.ParallelDoublePageCreator;
2930
import org.neo4j.gds.extension.GdlExtension;
3031
import org.neo4j.gds.extension.GdlGraph;
@@ -86,7 +87,13 @@ void discharge() {
8687

8788

8889
workingSet.push(graph.toMappedNodeId("c"));
89-
var task = new DischargeTask(flowGraph.concurrentCopy(), excess, label, tempLabel, addedExcess, isDiscovered, workingSet, targetNode, beta, workSinceLastGR);
90+
91+
HugeLongArrayQueue[] threadQueues = new HugeLongArrayQueue[1];
92+
for (int i = 0; i < threadQueues.length; i++) {
93+
threadQueues[i] = HugeLongArrayQueue.newQueue(flowGraph.nodeCount());
94+
}
95+
96+
var task = new DischargeTask(flowGraph.concurrentCopy(), excess, label, tempLabel, addedExcess, isDiscovered, workingSet, targetNode, beta, workSinceLastGR, threadQueues[0]);
9097

9198
task.discharge(graph.toMappedNodeId("c"));
9299
workingSet.resetIdx();

algo/src/test/java/org/neo4j/gds/maxflow/GlobalRelabelingTest.java

Lines changed: 34 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
import org.junit.jupiter.api.Test;
2323
import org.neo4j.gds.collections.ha.HugeLongArray;
2424
import org.neo4j.gds.core.concurrency.Concurrency;
25+
import org.neo4j.gds.core.utils.paged.HugeLongArrayQueue;
2526
import org.neo4j.gds.extension.GdlExtension;
2627
import org.neo4j.gds.extension.GdlGraph;
2728
import org.neo4j.gds.extension.Inject;
@@ -63,14 +64,23 @@ void test() {
6364

6465
var label = HugeLongArray.newArray(flowGraph.nodeCount());
6566
label.setAll((i) -> flowGraph.nodeCount());
66-
GlobalRelabeling.globalRelabeling(
67+
68+
HugeLongArrayQueue[] threadQueues = new HugeLongArrayQueue[1];
69+
for (int i = 0; i < threadQueues.length; i++) {
70+
threadQueues[i] = HugeLongArrayQueue.newQueue(flowGraph.nodeCount());
71+
}
72+
73+
var globalRelabeling = GlobalRelabeling.createRelabeling(
6774
flowGraph,
6875
label,
6976
graph.toMappedNodeId("c"),
7077
graph.toMappedNodeId("d"),
71-
new Concurrency(1)
78+
new Concurrency(1),
79+
threadQueues
7280
);
7381

82+
globalRelabeling.globalRelabeling();
83+
7484
assertThat(label.get(graph.toMappedNodeId("a"))).isEqualTo(1L);
7585
assertThat(label.get(graph.toMappedNodeId("b"))).isEqualTo(2L);
7686
assertThat(label.get(graph.toMappedNodeId("c"))).isEqualTo(7L);
@@ -84,14 +94,23 @@ void test2() {
8494

8595
var label = HugeLongArray.newArray(flowGraph.nodeCount());
8696
label.setAll((i) -> flowGraph.nodeCount());
87-
GlobalRelabeling.globalRelabeling(
97+
98+
HugeLongArrayQueue[] threadQueues = new HugeLongArrayQueue[1];
99+
for (int i = 0; i < threadQueues.length; i++) {
100+
threadQueues[i] = HugeLongArrayQueue.newQueue(flowGraph.nodeCount());
101+
}
102+
103+
var globalRelabeling = GlobalRelabeling.createRelabeling(
88104
flowGraph,
89105
label,
90106
graph.toMappedNodeId("a"),
91107
graph.toMappedNodeId("e"),
92-
new Concurrency(1)
108+
new Concurrency(1),
109+
threadQueues
93110
);
94111

112+
globalRelabeling.globalRelabeling();
113+
95114
assertThat(label.get(graph.toMappedNodeId("a"))).isEqualTo(7L);
96115
assertThat(label.get(graph.toMappedNodeId("b"))).isEqualTo(7L);
97116
assertThat(label.get(graph.toMappedNodeId("c"))).isEqualTo(7L);
@@ -105,14 +124,23 @@ void test3() {
105124

106125
var label = HugeLongArray.newArray(flowGraph.nodeCount());
107126
label.setAll((i) -> flowGraph.nodeCount());
108-
GlobalRelabeling.globalRelabeling(
127+
128+
HugeLongArrayQueue[] threadQueues = new HugeLongArrayQueue[1];
129+
for (int i = 0; i < threadQueues.length; i++) {
130+
threadQueues[i] = HugeLongArrayQueue.newQueue(flowGraph.nodeCount());
131+
}
132+
133+
var globalRelabeling = GlobalRelabeling.createRelabeling(
109134
flowGraph,
110135
label,
111136
flowGraph.superSource(),
112137
flowGraph.superTarget(),
113-
new Concurrency(1)
138+
new Concurrency(1),
139+
threadQueues
114140
);
115141

142+
globalRelabeling.globalRelabeling();
143+
116144
assertThat(label.get(flowGraph.superTarget())).isEqualTo(0L);
117145
assertThat(label.get(graph.toMappedNodeId("e"))).isEqualTo(1L);
118146
assertThat(label.get(graph.toMappedNodeId("d"))).isEqualTo(2L);

0 commit comments

Comments
 (0)