Skip to content

Commit 897f2ba

Browse files
Merge pull request #11148 from neo-technology/max-flow-stream-facade
Max-flow stream facade
2 parents fb0a8d1 + b1f4122 commit 897f2ba

File tree

35 files changed

+676
-49
lines changed

35 files changed

+676
-49
lines changed

algo-common/src/main/java/org/neo4j/gds/applications/algorithms/machinery/AlgorithmLabel.java

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,7 @@ public enum AlgorithmLabel implements Label {
6464
Leiden("Leiden"),
6565
Louvain("Louvain"),
6666
LongestPath("LongestPath"),
67+
MaxFlow("MaxFlow"),
6768
Modularity("Modularity"),
6869
ModularityOptimization("ModularityOptimization"),
6970
NodeSimilarity("Node Similarity"),
@@ -135,6 +136,7 @@ public static Label from(Algorithm algorithm) {
135136
case Algorithm.Leiden -> Leiden;
136137
case Algorithm.Louvain -> Louvain;
137138
case Algorithm.LongestPath -> LongestPath;
139+
case Algorithm.MaxFlow -> MaxFlow;
138140
case Algorithm.Modularity -> Modularity;
139141
case Algorithm.ModularityOptimization -> ModularityOptimization;
140142
case Algorithm.NodeSimilarity -> NodeSimilarity;

algo-common/src/main/java/org/neo4j/gds/applications/algorithms/metadata/Algorithm.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,7 @@ public enum Algorithm {
6868
Leiden,
6969
Louvain,
7070
LongestPath,
71+
MaxFlow,
7172
Modularity,
7273
ModularityOptimization,
7374
NodeSimilarity,

algo/src/main/java/org/neo4j/gds/maxflow/FlowGraph.java

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -19,11 +19,13 @@
1919
*/
2020
package org.neo4j.gds.maxflow;
2121

22+
import org.apache.commons.lang3.mutable.MutableDouble;
2223
import org.apache.commons.lang3.mutable.MutableLong;
2324
import org.neo4j.gds.api.Graph;
2425
import org.neo4j.gds.api.properties.relationships.RelationshipWithPropertyConsumer;
2526
import org.neo4j.gds.collections.ha.HugeDoubleArray;
2627
import org.neo4j.gds.collections.ha.HugeLongArray;
28+
import org.neo4j.gds.collections.ha.HugeObjectArray;
2729

2830
public final class FlowGraph {
2931
private final Graph graph;
@@ -244,21 +246,20 @@ long superTarget() {
244246
}
245247

246248
FlowResult createFlowResult() {
247-
var flowResult = new FlowResult(originalEdgeCount());
249+
var flow = HugeObjectArray.newArray(FlowRelationship.class, originalEdgeCount());
250+
var totalFlow = new MutableDouble(0D);
251+
248252
var idx = new MutableLong(0L);
249253
for (long nodeId = 0; nodeId < originalNodeCount(); nodeId++) {
250254
var relIdx = new MutableLong(indPtr.get(nodeId));
251255
graph.forEachRelationship(
252256
nodeId,
253257
0D,
254258
(s, t, _capacity) -> {
255-
var flow = this.flow.get(relIdx.longValue());
256-
if (flow > 0.0) {
257-
var flowRelationship = new FlowRelationship(s, t, flow);
258-
flowResult.flow.set(idx.getAndIncrement(), flowRelationship);
259-
// if (t == target) {
260-
// flowResult.totalFlow += flow;
261-
// }
259+
var flow_ = this.flow.get(relIdx.longValue());
260+
if (flow_ > 0.0) {
261+
var flowRelationship = new FlowRelationship(s, t, flow_);
262+
flow.set(idx.getAndIncrement(), flowRelationship);
262263
}
263264
relIdx.increment();
264265
return true;
@@ -273,11 +274,10 @@ FlowResult createFlowResult() {
273274
var fakeFlowFromSuperTarget = this.flow.get(relIdx);
274275
var actualFlowFromSuperTarget = fakeFlowFromSuperTarget - originalCapacity.get(relIdx);
275276
var actualFlowToSuperTarget = -actualFlowFromSuperTarget;
276-
flowResult.totalFlow += actualFlowToSuperTarget;
277+
totalFlow.add(actualFlowToSuperTarget);
277278
return true;
278279
}
279280
);
280-
flowResult.chop(idx.longValue());
281-
return flowResult;
281+
return new FlowResult(flow.copyOf(idx.longValue()), totalFlow.doubleValue());
282282
}
283283
}

algo/src/main/java/org/neo4j/gds/maxflow/FlowResult.java

Lines changed: 2 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -21,21 +21,7 @@
2121

2222
import org.neo4j.gds.collections.ha.HugeObjectArray;
2323

24-
public class FlowResult {
25-
HugeObjectArray<FlowRelationship> flow;
26-
double totalFlow;
24+
public record FlowResult(HugeObjectArray<FlowRelationship> flow, double totalFlow) {
2725

28-
FlowResult(long edgeCount) {
29-
flow = HugeObjectArray.newArray(FlowRelationship.class, edgeCount);
30-
totalFlow = 0;
31-
}
32-
33-
void chop(long newLength) {
34-
flow = flow.copyOf(newLength);
35-
}
36-
37-
38-
public double totalFlow() {
39-
return totalFlow;
40-
}
26+
public static FlowResult EMPTY = new FlowResult(HugeObjectArray.newArray(FlowRelationship.class, 0), 0D);
4127
}

algo/src/main/java/org/neo4j/gds/maxflow/SupplyAndDemandFactory.java

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -38,36 +38,36 @@ public static Pair<NodeWithValue[], NodeWithValue[]> create(
3838
InputNodes targetNodes
3939
) {
4040
var supply = createSupply(sourceNodes, graph);
41-
var demand = createDemand(targetNodes, supply);
41+
var demand = createDemand(targetNodes, graph, supply);
4242
return Pair.of(supply, demand);
4343
}
4444

4545
private static NodeWithValue[] createSupply(InputNodes sourceNodes, Graph graph) {
4646
return
4747
switch (sourceNodes) {
4848
case ListInputNodes list -> list.inputNodes().stream().map(sourceNode -> new NodeWithValue(
49-
sourceNode, graph.streamRelationships(
50-
sourceNode,
49+
graph.toMappedNodeId(sourceNode), graph.streamRelationships(
50+
graph.toMappedNodeId(sourceNode),
5151
0D
5252
).map(RelationshipCursor::property).reduce(0D, Double::sum)
5353
)).toArray(NodeWithValue[]::new);
5454
case MapInputNodes map -> map.map().entrySet().stream().map(entry -> new NodeWithValue(
55-
entry.getKey(),
55+
graph.toMappedNodeId(entry.getKey()),
5656
entry.getValue()
5757
)).toArray(NodeWithValue[]::new);
5858
default -> throw new IllegalStateException("Unexpected value: " + sourceNodes); //fixme
5959
};
6060
}
6161

62-
private static NodeWithValue[] createDemand(InputNodes targetNodes, NodeWithValue[] supply) {
62+
private static NodeWithValue[] createDemand(InputNodes targetNodes, Graph graph, NodeWithValue[] supply) {
6363
return
6464
switch (targetNodes) {
6565
case ListInputNodes list -> {
6666
var totalOutgoing = Arrays.stream(supply).map(nodeWithValue -> nodeWithValue.value()).reduce(0D, Double::sum);
67-
yield list.inputNodes().stream().map(sourceNode -> new NodeWithValue(sourceNode, totalOutgoing)).toArray(NodeWithValue[]::new);
67+
yield list.inputNodes().stream().map(sourceNode -> new NodeWithValue(graph.toMappedNodeId(sourceNode), totalOutgoing)).toArray(NodeWithValue[]::new);
6868
}
6969
case MapInputNodes map -> map.map().entrySet().stream().map(entry -> new NodeWithValue(
70-
entry.getKey(),
70+
graph.toMappedNodeId(entry.getKey()),
7171
entry.getValue()
7272
)).toArray(NodeWithValue[]::new);
7373
default -> throw new IllegalStateException("Unexpected value: " + targetNodes); //fixme

algo/src/test/java/org/neo4j/gds/maxflow/FlowGraphTest.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -145,8 +145,8 @@ void testCreateFlowResult() {
145145

146146
var result = flowGraph.createFlowResult();
147147

148-
assertThat(result.totalFlow).isEqualTo(2D);
149-
assertThat(result.flow.toArray()).containsExactlyInAnyOrder(new FlowRelationship(
148+
assertThat(result.totalFlow()).isEqualTo(2D);
149+
assertThat(result.flow().toArray()).containsExactlyInAnyOrder(new FlowRelationship(
150150
graph.toMappedNodeId("a"),
151151
graph.toMappedNodeId("d"),
152152
2D

algo/src/test/java/org/neo4j/gds/maxflow/MaxFlowTest.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,7 @@ void testGraph(Graph graph, InputNodes sourceNodes, InputNodes targetNodes, doub
8484
var params = new MaxFlowParameters(sourceNodes, targetNodes, new Concurrency(concurrency), ALPHA, BETA, FREQ);
8585
var x = new MaxFlow(graph, params, null, null); //fixme
8686
var result = x.compute();
87-
assertThat(result.totalFlow).isCloseTo(expectedFlow, Offset.offset(TOLERANCE));
87+
assertThat(result.totalFlow()).isCloseTo(expectedFlow, Offset.offset(TOLERANCE));
8888
}
8989

9090
void testGraph(Graph graph, long sourceNode, long targetNode, double expectedFlow, int concurrency) {
@@ -95,7 +95,7 @@ void testGraph(Graph graph, long sourceNode, long targetNode, double expectedFlo
9595
}
9696

9797
void testGraph(TestGraph graph, String sourceNode, String targetNode, double expectedFlow) {
98-
testGraph(graph.graph(), graph.toMappedNodeId(sourceNode), graph.toMappedNodeId(targetNode), expectedFlow, 1);
98+
testGraph(graph.graph(), graph.toOriginalNodeId(sourceNode), graph.toOriginalNodeId(targetNode), expectedFlow, 1);
9999
}
100100

101101
@Test

algo/src/test/java/org/neo4j/gds/maxflow/SupplyAndDemandFactoryTest.java

Lines changed: 28 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -60,9 +60,13 @@ void testCreateWithListSourceAndListTargetNodes() {
6060
var b = graph.toMappedNodeId("b");
6161
var c = graph.toMappedNodeId("c");
6262
var d = graph.toMappedNodeId("d");
63+
var aOriginal = graph.toOriginalNodeId("a");
64+
var bOriginal = graph.toOriginalNodeId("b");
65+
var cOriginal = graph.toOriginalNodeId("c");
66+
var dOriginal = graph.toOriginalNodeId("d");
6367

64-
var sourceNodes = new ListInputNodes(List.of(a, b));
65-
var targetNodes = new ListInputNodes(List.of(c, d));
68+
var sourceNodes = new ListInputNodes(List.of(aOriginal, bOriginal));
69+
var targetNodes = new ListInputNodes(List.of(cOriginal, dOriginal));
6670

6771
var result = SupplyAndDemandFactory.create(graph, sourceNodes, targetNodes);
6872

@@ -77,9 +81,15 @@ void testCreateWithListSourceAndMapTargetNodes() {
7781
var c = graph.toMappedNodeId("c");
7882
var e = graph.toMappedNodeId("e");
7983

84+
85+
var aOriginal = graph.toOriginalNodeId("a");
86+
var bOriginal = graph.toOriginalNodeId("b");
87+
var cOriginal = graph.toOriginalNodeId("c");
88+
var eOriginal = graph.toOriginalNodeId("e");
89+
8090
// Arrange
81-
var sourceNodes = new ListInputNodes(List.of(a, b));
82-
var targetNodes = new MapInputNodes(Map.of(c, 5.0, e, 8.0));
91+
var sourceNodes = new ListInputNodes(List.of(aOriginal, bOriginal));
92+
var targetNodes = new MapInputNodes(Map.of(cOriginal, 5.0, eOriginal, 8.0));
8393

8494
// Act
8595
var result = SupplyAndDemandFactory.create(graph, sourceNodes, targetNodes);
@@ -96,8 +106,13 @@ void testCreateWithMapSourceAndListTargetNodes() {
96106
var d = graph.toMappedNodeId("d");
97107
var e = graph.toMappedNodeId("e");
98108

99-
var sourceNodes = new MapInputNodes(Map.of(a, 1.2, d, 10.0));
100-
var targetNodes = new ListInputNodes(List.of(c, e));
109+
var aOriginal = graph.toOriginalNodeId("a");
110+
var cOriginal = graph.toOriginalNodeId("c");
111+
var dOriginal = graph.toOriginalNodeId("d");
112+
var eOriginal = graph.toOriginalNodeId("e");
113+
114+
var sourceNodes = new MapInputNodes(Map.of(aOriginal, 1.2, dOriginal, 10.0));
115+
var targetNodes = new ListInputNodes(List.of(cOriginal, eOriginal));
101116

102117
var result = SupplyAndDemandFactory.create(graph, sourceNodes, targetNodes);
103118

@@ -112,8 +127,13 @@ void testCreateWithMapSourceAndMapTargetNodes() {
112127
var d = graph.toMappedNodeId("d");
113128
var e = graph.toMappedNodeId("e");
114129

115-
var sourceNodes = new MapInputNodes(Map.of(a, 1.0, c, 3.0));
116-
var targetNodes = new MapInputNodes(Map.of(d, 5.1, e, 9.0));
130+
var aOriginal = graph.toOriginalNodeId("a");
131+
var cOriginal = graph.toOriginalNodeId("c");
132+
var dOriginal = graph.toOriginalNodeId("d");
133+
var eOriginal = graph.toOriginalNodeId("e");
134+
135+
var sourceNodes = new MapInputNodes(Map.of(aOriginal, 1.0, cOriginal, 3.0));
136+
var targetNodes = new MapInputNodes(Map.of(dOriginal, 5.1, eOriginal, 9.0));
117137

118138
var result = SupplyAndDemandFactory.create(graph, sourceNodes, targetNodes);
119139

algorithms-compute-business-facade/src/main/java/org/neo4j/gds/pathfinding/PathFindingComputeBusinessFacade.java

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,8 @@
3838
import org.neo4j.gds.dag.topologicalsort.TopologicalSortParameters;
3939
import org.neo4j.gds.dag.topologicalsort.TopologicalSortResult;
4040
import org.neo4j.gds.kspanningtree.KSpanningTreeParameters;
41+
import org.neo4j.gds.maxflow.FlowResult;
42+
import org.neo4j.gds.maxflow.MaxFlowParameters;
4143
import org.neo4j.gds.pathfinding.validation.KSpanningTreeGraphStoreValidation;
4244
import org.neo4j.gds.pathfinding.validation.PCSTGraphStoreValidation;
4345
import org.neo4j.gds.pathfinding.validation.RandomWalkGraphValidation;
@@ -294,6 +296,34 @@ public <TR> CompletableFuture<TR> longestPath(
294296
).thenApply(resultTransformerBuilder.build(graphResources));
295297
}
296298

299+
public <TR> CompletableFuture<TR> maxFlow(
300+
GraphName graphName,
301+
GraphParameters graphParameters,
302+
Optional<String> relationshipProperty,
303+
MaxFlowParameters parameters,
304+
JobId jobId,
305+
boolean logProgress,
306+
ResultTransformerBuilder<TimedAlgorithmResult<FlowResult>, TR> resultTransformerBuilder
307+
) {
308+
var graphResources = graphStoreCatalogService.fetchGraphResources(
309+
graphName,
310+
graphParameters,
311+
relationshipProperty,
312+
new NoAlgorithmValidation(),
313+
Optional.empty(),
314+
user,
315+
databaseId
316+
);
317+
var graph = graphResources.graph();
318+
319+
return computeFacade.maxFlow(
320+
graph,
321+
parameters,
322+
jobId,
323+
logProgress
324+
).thenApply(resultTransformerBuilder.build(graphResources));
325+
}
326+
297327
public <TR> CompletableFuture<TR> randomWalk(
298328
GraphName graphName,
299329
GraphParameters graphParameters,

algorithms-compute-facade/src/main/java/org/neo4j/gds/pathfinding/PathFindingComputeFacade.java

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
import org.neo4j.gds.collections.haa.HugeAtomicLongArray;
3131
import org.neo4j.gds.core.utils.paged.ParalleLongPageCreator;
3232
import org.neo4j.gds.core.utils.progress.JobId;
33+
import org.neo4j.gds.core.utils.progress.tasks.ProgressTracker;
3334
import org.neo4j.gds.dag.longestPath.DagLongestPath;
3435
import org.neo4j.gds.dag.longestPath.DagLongestPathParameters;
3536
import org.neo4j.gds.dag.longestPath.LongestPathTask;
@@ -40,6 +41,9 @@
4041
import org.neo4j.gds.kspanningtree.KSpanningTree;
4142
import org.neo4j.gds.kspanningtree.KSpanningTreeParameters;
4243
import org.neo4j.gds.kspanningtree.KSpanningTreeTask;
44+
import org.neo4j.gds.maxflow.FlowResult;
45+
import org.neo4j.gds.maxflow.MaxFlow;
46+
import org.neo4j.gds.maxflow.MaxFlowParameters;
4347
import org.neo4j.gds.paths.RelationshipCountProgressTaskFactory;
4448
import org.neo4j.gds.paths.astar.AStar;
4549
import org.neo4j.gds.paths.astar.AStarParameters;
@@ -361,6 +365,35 @@ public CompletableFuture<TimedAlgorithmResult<PathFindingResult>> longestPath(
361365
);
362366
}
363367

368+
public CompletableFuture<TimedAlgorithmResult<FlowResult>> maxFlow(
369+
Graph graph,
370+
MaxFlowParameters parameters,
371+
JobId jobId,
372+
boolean logProgress
373+
) {
374+
// If the input graph is empty return a completed future with empty result
375+
if (graph.isEmpty()) {
376+
return CompletableFuture.completedFuture(TimedAlgorithmResult.empty(FlowResult.EMPTY));
377+
}
378+
379+
// Create ProgressTracker
380+
var progressTracker = ProgressTracker.NULL_TRACKER;
381+
382+
// Create the algorithm
383+
var algo = new MaxFlow(
384+
graph,
385+
parameters,
386+
progressTracker,
387+
terminationFlag
388+
);
389+
390+
// Submit the algorithm for async computation
391+
return algorithmCaller.run(
392+
algo::compute,
393+
jobId
394+
);
395+
}
396+
364397
public CompletableFuture<TimedAlgorithmResult<Stream<long[]>>> randomWalk(
365398
Graph graph,
366399
RandomWalkParameters parameters,

0 commit comments

Comments
 (0)