Skip to content

Commit 9b292e3

Browse files
Max-flow supports termination flag
Co-authored-by: Alfred Clemedtson <alfred.clemedtson@neo4j.com>
1 parent 9b239d8 commit 9b292e3

File tree

6 files changed

+73
-19
lines changed

6 files changed

+73
-19
lines changed

algo/src/main/java/org/neo4j/gds/maxflow/Discharging.java

Lines changed: 22 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
import org.neo4j.gds.core.concurrency.RunWithConcurrency;
2727
import org.neo4j.gds.core.utils.paged.HugeAtomicBitSet;
2828
import org.neo4j.gds.core.utils.paged.HugeLongArrayQueue;
29+
import org.neo4j.gds.termination.TerminationFlag;
2930

3031
import java.util.ArrayList;
3132
import java.util.Collection;
@@ -36,6 +37,7 @@ final class Discharging {
3637
private final AtomicWorkingSet workingSet;
3738
private final Concurrency concurrency;
3839
private final Collection<Runnable> dischargeTasks;
40+
private final TerminationFlag terminationFlag;
3941

4042

4143
static Discharging createDischarging(
@@ -48,7 +50,8 @@ static Discharging createDischarging(
4850
long beta,
4951
AtomicLong workSinceLastGR,
5052
Concurrency concurrency,
51-
HugeLongArrayQueue[] threadQueues
53+
HugeLongArrayQueue[] threadQueues,
54+
TerminationFlag terminationFlag
5255
)
5356
{
5457

@@ -72,25 +75,38 @@ static Discharging createDischarging(
7275
));
7376
}
7477

75-
return new Discharging(workingSet, dischargeTasks, concurrency);
78+
return new Discharging(workingSet, dischargeTasks, concurrency, terminationFlag);
7679
}
7780

78-
private Discharging(AtomicWorkingSet workingSet, Collection<Runnable> dischargeTasks, Concurrency concurrency) {
81+
private Discharging(AtomicWorkingSet workingSet, Collection<Runnable> dischargeTasks, Concurrency concurrency,
82+
TerminationFlag terminationFlag
83+
) {
7984
this.workingSet = workingSet;
8085
this.dischargeTasks = dischargeTasks;
8186
this.concurrency = concurrency;
87+
this.terminationFlag = terminationFlag;
8288
}
8389

8490
void processWorkingSet() {
8591
//Discharge working set
86-
RunWithConcurrency.builder().concurrency(concurrency).tasks(dischargeTasks).build().run();
92+
runTasks();
8793
workingSet.resetIdx();
8894

8995
//Sync working set
90-
RunWithConcurrency.builder().concurrency(concurrency).tasks(dischargeTasks).build().run();
96+
runTasks();
9197
workingSet.reset();
9298

9399
//Update and sync new working set
94-
RunWithConcurrency.builder().concurrency(concurrency).tasks(dischargeTasks).build().run();
100+
runTasks();
101+
102+
}
103+
104+
private void runTasks(){
105+
RunWithConcurrency.builder()
106+
.concurrency(concurrency)
107+
.terminationFlag(terminationFlag)
108+
.tasks(dischargeTasks)
109+
.build()
110+
.run();
95111
}
96112
}

algo/src/main/java/org/neo4j/gds/maxflow/FlowGraph.java

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
import org.neo4j.gds.collections.ha.HugeDoubleArray;
2727
import org.neo4j.gds.collections.ha.HugeLongArray;
2828
import org.neo4j.gds.collections.ha.HugeObjectArray;
29+
import org.neo4j.gds.termination.TerminationFlag;
2930

3031
public final class FlowGraph {
3132
private final Graph graph;
@@ -60,7 +61,7 @@ private FlowGraph(
6061
this.demand = demand;
6162
}
6263

63-
public static FlowGraph create(Graph graph, NodeWithValue[] supply, NodeWithValue[] demand) {
64+
public static FlowGraph create(Graph graph, NodeWithValue[] supply, NodeWithValue[] demand, TerminationFlag terminationFlag) {
6465
var superSource = graph.nodeCount();
6566
var superTarget = graph.nodeCount() + 1;
6667
var newNodeCount = graph.nodeCount() + 2;
@@ -69,6 +70,7 @@ public static FlowGraph create(Graph graph, NodeWithValue[] supply, NodeWithValu
6970
reverseDegree.setAll(x -> 0L);
7071

7172
for (long nodeId = 0; nodeId < graph.nodeCount(); nodeId++) {
73+
terminationFlag.assertRunning();
7274
graph.forEachRelationship(
7375
nodeId, 0D, (s, t, capacity) -> {
7476
if(capacity < 0D){
@@ -120,9 +122,11 @@ public static FlowGraph create(Graph graph, NodeWithValue[] supply, NodeWithValu
120122
return true;
121123
};
122124
for (long nodeId = 0; nodeId < graph.nodeCount(); nodeId++) {
125+
terminationFlag.assertRunning();
123126
graph.forEachRelationship(nodeId, 0D, consumer);
124127
}
125128
for (var source : supply) {
129+
terminationFlag.assertRunning();
126130
consumer.accept(superSource, source.node(), source.value());
127131
}
128132
for (var target : demand) {

algo/src/main/java/org/neo4j/gds/maxflow/GlobalRelabeling.java

Lines changed: 32 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
import org.neo4j.gds.core.concurrency.RunWithConcurrency;
2525
import org.neo4j.gds.core.utils.paged.HugeAtomicBitSet;
2626
import org.neo4j.gds.core.utils.paged.HugeLongArrayQueue;
27+
import org.neo4j.gds.termination.TerminationFlag;
2728

2829
import java.util.ArrayList;
2930
import java.util.Collection;
@@ -38,14 +39,16 @@ public final class GlobalRelabeling {
3839
private final long target;
3940
private final Concurrency concurrency;
4041
private final Collection<Runnable> globalRelabelingTasks;
42+
private final TerminationFlag terminationFlag;
4143

4244
static GlobalRelabeling createRelabeling(
4345
FlowGraph flowGraph,
4446
HugeLongArray label,
4547
long source,
4648
long target,
4749
Concurrency concurrency,
48-
HugeLongArrayQueue[] threadQueues
50+
HugeLongArrayQueue[] threadQueues,
51+
TerminationFlag terminationFlag
4952
) {
5053
var vertexIsDiscovered = HugeAtomicBitSet.create(flowGraph.nodeCount());
5154
var frontier = new AtomicWorkingSet(flowGraph.nodeCount());
@@ -55,10 +58,22 @@ static GlobalRelabeling createRelabeling(
5558
globalRelabelingTasks.add(new GlobalRelabellingBFSTask(flowGraph.concurrentCopy(), frontier, vertexIsDiscovered, label, threadQueues[i]));
5659
}
5760

58-
return new GlobalRelabeling(flowGraph.nodeCount(), label, frontier, vertexIsDiscovered, source, target, concurrency, globalRelabelingTasks);
61+
return new GlobalRelabeling(
62+
flowGraph.nodeCount(),
63+
label,
64+
frontier,
65+
vertexIsDiscovered,
66+
source,
67+
target,
68+
concurrency,
69+
globalRelabelingTasks,
70+
terminationFlag
71+
);
5972
}
6073

61-
private GlobalRelabeling(long nodeCount, HugeLongArray label, AtomicWorkingSet frontier, HugeAtomicBitSet vertexIsDiscovered, long source, long target, Concurrency concurrency, Collection<Runnable> globalRelabelingTasks) {
74+
private GlobalRelabeling(long nodeCount, HugeLongArray label, AtomicWorkingSet frontier, HugeAtomicBitSet vertexIsDiscovered, long source, long target, Concurrency concurrency, Collection<Runnable> globalRelabelingTasks,
75+
TerminationFlag terminationFlag
76+
) {
6277
this.nodeCount = nodeCount;
6378
this.label = label;
6479
this.frontier = frontier;
@@ -67,6 +82,7 @@ private GlobalRelabeling(long nodeCount, HugeLongArray label, AtomicWorkingSet f
6782
this.target = target;
6883
this.concurrency = concurrency;
6984
this.globalRelabelingTasks = globalRelabelingTasks;
85+
this.terminationFlag = terminationFlag;
7086
}
7187

7288
public void globalRelabeling() {
@@ -78,11 +94,22 @@ public void globalRelabeling() {
7894
vertexIsDiscovered.set(source);
7995
vertexIsDiscovered.set(target);
8096
while (!frontier.isEmpty()) {
81-
RunWithConcurrency.builder().concurrency(concurrency).tasks(globalRelabelingTasks).build().run();
97+
//relax nodes in the frontier
98+
runTasks();
8299
frontier.reset();
83-
RunWithConcurrency.builder().concurrency(concurrency).tasks(globalRelabelingTasks).build().run();
100+
//update the frontier
101+
runTasks();
84102
}
85103
label.set(source, nodeCount);
86104

87105
}
106+
107+
private void runTasks(){
108+
RunWithConcurrency.builder()
109+
.concurrency(concurrency)
110+
.terminationFlag(terminationFlag)
111+
.tasks(globalRelabelingTasks)
112+
.build()
113+
.run();
114+
}
88115
}

algo/src/main/java/org/neo4j/gds/maxflow/MaxFlow.java

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@ public FlowResult compute() {
6161

6262
private Preflow initPreflow() {
6363
var supplyAndDemand = SupplyAndDemandFactory.create(graph, parameters.sourceNodes(), parameters.targetNodes());
64-
var flowGraph = FlowGraph.create(graph, supplyAndDemand.getLeft(), supplyAndDemand.getRight());
64+
var flowGraph = FlowGraph.create(graph, supplyAndDemand.getLeft(), supplyAndDemand.getRight(), terminationFlag);
6565
var excess = HugeDoubleArray.newArray(flowGraph.nodeCount());
6666
excess.setAll(x -> 0D);
6767
flowGraph.forEachRelationship(
@@ -107,7 +107,8 @@ private void maximizeFlow(Preflow preflow, long sourceNode, long targetNode) { /
107107
sourceNode,
108108
targetNode,
109109
parameters.concurrency(),
110-
threadQueues
110+
threadQueues,
111+
terminationFlag
111112
);
112113

113114
var discharging = Discharging.createDischarging(
@@ -120,7 +121,8 @@ private void maximizeFlow(Preflow preflow, long sourceNode, long targetNode) { /
120121
parameters.beta(),
121122
workSinceLastGR,
122123
parameters.concurrency(),
123-
threadQueues
124+
threadQueues,
125+
terminationFlag
124126
);
125127

126128
while (!workingSet.isEmpty()) {

algo/src/test/java/org/neo4j/gds/maxflow/FlowGraphTest.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
import org.neo4j.gds.extension.GdlGraph;
2727
import org.neo4j.gds.extension.Inject;
2828
import org.neo4j.gds.extension.TestGraph;
29+
import org.neo4j.gds.termination.TerminationFlag;
2930

3031
import java.util.HashMap;
3132
import java.util.HashSet;
@@ -61,7 +62,7 @@ static FlowGraph createFlowGraph(Graph graph, long source, long target) {
6162
.reduce(0D, Double::sum);
6263
NodeWithValue[] supply = {new NodeWithValue(source, outgoingCapacityFromSource)};
6364
NodeWithValue[] demand = {new NodeWithValue(target, outgoingCapacityFromSource)}; //more is useless since this is max in network
64-
return FlowGraph.create(graph, supply, demand);
65+
return FlowGraph.create(graph, supply, demand, TerminationFlag.RUNNING_TRUE);
6566
}
6667

6768
@Test

algo/src/test/java/org/neo4j/gds/maxflow/GlobalRelabelingTest.java

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
import org.neo4j.gds.extension.GdlGraph;
2828
import org.neo4j.gds.extension.Inject;
2929
import org.neo4j.gds.extension.TestGraph;
30+
import org.neo4j.gds.termination.TerminationFlag;
3031

3132
import static org.assertj.core.api.Assertions.assertThat;
3233
import static org.neo4j.gds.maxflow.FlowGraphTest.createFlowGraph;
@@ -76,7 +77,8 @@ void test() {
7677
graph.toMappedNodeId("c"),
7778
graph.toMappedNodeId("d"),
7879
new Concurrency(1),
79-
threadQueues
80+
threadQueues,
81+
TerminationFlag.RUNNING_TRUE
8082
);
8183

8284
globalRelabeling.globalRelabeling();
@@ -106,7 +108,8 @@ void test2() {
106108
graph.toMappedNodeId("a"),
107109
graph.toMappedNodeId("e"),
108110
new Concurrency(1),
109-
threadQueues
111+
threadQueues,
112+
TerminationFlag.RUNNING_TRUE
110113
);
111114

112115
globalRelabeling.globalRelabeling();
@@ -136,7 +139,8 @@ void test3() {
136139
flowGraph.superSource(),
137140
flowGraph.superTarget(),
138141
new Concurrency(1),
139-
threadQueues
142+
threadQueues,
143+
TerminationFlag.RUNNING_TRUE
140144
);
141145

142146
globalRelabeling.globalRelabeling();

0 commit comments

Comments
 (0)