Skip to content

Commit f9e70bf

Browse files
Max-flow mutate facade
Co-authored-by: Ioannis Panagiotas <ioannis.panagiotas@neo4j.com>
1 parent 19ccd3a commit f9e70bf

File tree

20 files changed

+602
-20
lines changed

20 files changed

+602
-20
lines changed

applications/algorithms/path-finding/src/main/java/org/neo4j/gds/applications/algorithms/pathfinding/PathFindingAlgorithmsEstimationModeBusinessFacade.java

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,10 +19,13 @@
1919
*/
2020
package org.neo4j.gds.applications.algorithms.pathfinding;
2121

22+
import org.neo4j.gds.allshortestpaths.AllShortestPathsConfig;
23+
import org.neo4j.gds.allshortestpaths.AllShortestPathsMemoryEstimateDefinition;
2224
import org.neo4j.gds.applications.algorithms.machinery.AlgorithmEstimationTemplate;
2325
import org.neo4j.gds.applications.algorithms.machinery.MemoryEstimateResult;
2426
import org.neo4j.gds.config.AlgoBaseConfig;
2527
import org.neo4j.gds.exceptions.MemoryEstimationNotImplementedException;
28+
import org.neo4j.gds.maxflow.MaxFlowBaseConfig;
2629
import org.neo4j.gds.mem.MemoryEstimation;
2730
import org.neo4j.gds.paths.astar.AStarMemoryEstimateDefinition;
2831
import org.neo4j.gds.paths.astar.config.ShortestPathAStarBaseConfig;
@@ -49,8 +52,6 @@
4952
import org.neo4j.gds.traversal.RandomWalkCountingVisitsMemoryEstimateDefinition;
5053
import org.neo4j.gds.traversal.RandomWalkMemoryEstimateDefinition;
5154
import org.neo4j.gds.traversal.RandomWalkMutateConfig;
52-
import org.neo4j.gds.allshortestpaths.AllShortestPathsMemoryEstimateDefinition;
53-
import org.neo4j.gds.allshortestpaths.AllShortestPathsConfig;
5455

5556
/**
5657
* Here is the top level business facade for all your path finding memory estimation needs.
@@ -130,10 +131,17 @@ MemoryEstimation longestPath() {
130131
throw new MemoryEstimationNotImplementedException();
131132
}
132133

133-
MemoryEstimation maxFlow() {
134+
public MemoryEstimation maxFlow() {
134135
throw new MemoryEstimationNotImplementedException();
135136
}
136137

138+
MemoryEstimation maxFlow(
139+
MaxFlowBaseConfig configuration,
140+
Object graphNameOrConfiguration
141+
) {
142+
return maxFlow();
143+
}
144+
137145
public MemoryEstimateResult pcst(
138146
PCSTBaseConfig configuration,
139147
Object graphNameOrConfiguration

applications/algorithms/path-finding/src/main/java/org/neo4j/gds/applications/algorithms/pathfinding/PathFindingAlgorithmsMutateModeBusinessFacade.java

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,10 @@
3232
import org.neo4j.gds.applications.algorithms.metadata.RelationshipsWritten;
3333
import org.neo4j.gds.collections.ha.HugeLongArray;
3434
import org.neo4j.gds.collections.haa.HugeAtomicLongArray;
35+
import org.neo4j.gds.maxflow.FlowResult;
36+
import org.neo4j.gds.maxflow.MaxFlowMutateConfig;
3537
import org.neo4j.gds.pathfinding.BellmanFordMutateStep;
38+
import org.neo4j.gds.pathfinding.MaxFlowMutateStep;
3639
import org.neo4j.gds.pathfinding.PrizeCollectingSteinerTreeMutateStep;
3740
import org.neo4j.gds.pathfinding.RandomWalkCountingNodeVisitsMutateStep;
3841
import org.neo4j.gds.pathfinding.SearchMutateStep;
@@ -67,6 +70,7 @@
6770
import static org.neo4j.gds.applications.algorithms.machinery.AlgorithmLabel.DFS;
6871
import static org.neo4j.gds.applications.algorithms.machinery.AlgorithmLabel.DeltaStepping;
6972
import static org.neo4j.gds.applications.algorithms.machinery.AlgorithmLabel.Dijkstra;
73+
import static org.neo4j.gds.applications.algorithms.machinery.AlgorithmLabel.MaxFlow;
7074
import static org.neo4j.gds.applications.algorithms.machinery.AlgorithmLabel.RandomWalk;
7175
import static org.neo4j.gds.applications.algorithms.machinery.AlgorithmLabel.SingleSourceDijkstra;
7276
import static org.neo4j.gds.applications.algorithms.machinery.AlgorithmLabel.SteinerTree;
@@ -182,6 +186,28 @@ public <RESULT> RESULT depthFirstSearch(
182186
);
183187
}
184188

189+
public <RESULT> RESULT maxFlow(
190+
GraphName graphName,
191+
MaxFlowMutateConfig configuration,
192+
ResultBuilder<MaxFlowMutateConfig, FlowResult, RESULT, RelationshipsWritten> resultBuilder
193+
) {
194+
var mutateStep = new MaxFlowMutateStep(
195+
configuration.mutateRelationshipType(),
196+
configuration.mutateProperty(),
197+
mutateRelationshipService
198+
);
199+
200+
return algorithmProcessingTemplateConvenience.processRegularAlgorithmInMutateMode(
201+
graphName,
202+
configuration,
203+
MaxFlow,
204+
estimationFacade::maxFlow,
205+
(graph, __) -> pathFindingAlgorithms.maxFlow(graph, configuration),
206+
mutateStep,
207+
resultBuilder
208+
);
209+
}
210+
185211
public <RESULT> RESULT pcst(
186212
GraphName graphName,
187213
PCSTMutateConfig configuration,

doc/modules/ROOT/pages/operations-reference/algorithm-references.adoc

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -227,7 +227,8 @@
227227
| `gds.louvain.stream.estimate` label:procedure[Procedure]
228228
| `gds.louvain.stats` label:procedure[Procedure]
229229
| `gds.louvain.stats.estimate` label:procedure[Procedure]
230-
.2+<.^|Max flow
230+
.3+<.^|Max flow
231+
| `gds.maxFlow.mutate` label:procedure[Procedure]
231232
| `gds.maxFlow.stats` label:procedure[Procedure]
232233
| `gds.maxFlow.stream` label:procedure[Procedure]
233234
.4+<.^|xref:algorithms/approx-max-k-cut.adoc[Approximate Maximum k-cut]

open-packaging/src/test/java/org/neo4j/gds/OpenGdsProcedureSmokeTest.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -469,6 +469,7 @@ class OpenGdsProcedureSmokeTest extends BaseProcTest {
469469
"gds.louvain.write",
470470
"gds.louvain.write.estimate",
471471

472+
"gds.maxFlow.mutate",
472473
"gds.maxFlow.stats",
473474
"gds.maxFlow.stream",
474475

@@ -632,7 +633,7 @@ void countShouldMatch() {
632633
);
633634

634635
// If you find yourself updating this count, please also update the count in SmokeTest.kt
635-
int expectedCount = 470;
636+
int expectedCount = 471;
636637
assertEquals(
637638
expectedCount,
638639
returnedRows,

pipeline/src/main/java/org/neo4j/gds/ml/pipeline/MutateModeAlgorithmLibrary.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -98,7 +98,7 @@ static CanonicalProcedureName algorithmToName(Algorithm algorithm) {
9898
case Leiden -> CanonicalProcedureName.parse("gds.leiden");
9999
case Louvain -> CanonicalProcedureName.parse("gds.louvain");
100100
case LongestPath -> null;
101-
case MaxFlow -> null;
101+
case MaxFlow -> CanonicalProcedureName.parse("gds.maxflow");
102102
case Modularity -> null;
103103
case ModularityOptimization -> CanonicalProcedureName.parse("gds.modularityOptimization");
104104
case NodeSimilarity -> CanonicalProcedureName.parse("gds.nodeSimilarity");

pipeline/src/main/java/org/neo4j/gds/ml/pipeline/StubbyHolder.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@
5151
import org.neo4j.gds.ml.pipeline.stubs.LccStub;
5252
import org.neo4j.gds.ml.pipeline.stubs.LeidenStub;
5353
import org.neo4j.gds.ml.pipeline.stubs.LouvainStub;
54+
import org.neo4j.gds.ml.pipeline.stubs.MaxFlowStub;
5455
import org.neo4j.gds.ml.pipeline.stubs.ModularityOptimizationStub;
5556
import org.neo4j.gds.ml.pipeline.stubs.Node2VecStub;
5657
import org.neo4j.gds.ml.pipeline.stubs.NodeSimilarityStub;
@@ -124,7 +125,7 @@ Stub get(Algorithm algorithm) {
124125
case Leiden -> new LeidenStub();
125126
case Louvain -> new LouvainStub();
126127
case LongestPath -> null;
127-
case MaxFlow -> null;
128+
case MaxFlow -> new MaxFlowStub();
128129
case Modularity -> null;
129130
case ModularityOptimization -> new ModularityOptimizationStub();
130131
case NodeSimilarity -> new NodeSimilarityStub();
Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
/*
2+
* Copyright (c) "Neo4j"
3+
* Neo4j Sweden AB [http://neo4j.com]
4+
*
5+
* This file is part of Neo4j.
6+
*
7+
* Neo4j is free software: you can redistribute it and/or modify
8+
* it under the terms of the GNU General Public License as published by
9+
* the Free Software Foundation, either version 3 of the License, or
10+
* (at your option) any later version.
11+
*
12+
* This program is distributed in the hope that it will be useful,
13+
* but WITHOUT ANY WARRANTY; without even the implied warranty of
14+
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
15+
* GNU General Public License for more details.
16+
*
17+
* You should have received a copy of the GNU General Public License
18+
* along with this program. If not, see <http://www.gnu.org/licenses/>.
19+
*/
20+
package org.neo4j.gds.ml.pipeline.stubs;
21+
22+
import org.neo4j.gds.maxflow.MaxFlowMutateConfig;
23+
import org.neo4j.gds.procedures.algorithms.AlgorithmsProcedureFacade;
24+
import org.neo4j.gds.procedures.algorithms.pathfinding.MaxFlowMutateResult;
25+
import org.neo4j.gds.procedures.algorithms.stubs.MutateStub;
26+
27+
public class MaxFlowStub extends AbstractStub<MaxFlowMutateConfig, MaxFlowMutateResult> {
28+
protected MutateStub<MaxFlowMutateConfig, MaxFlowMutateResult> stub(AlgorithmsProcedureFacade facade) {
29+
return facade.pathFinding().stubs().maxFlow();
30+
}
31+
}
Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
/*
2+
* Copyright (c) "Neo4j"
3+
* Neo4j Sweden AB [http://neo4j.com]
4+
*
5+
* This file is part of Neo4j.
6+
*
7+
* Neo4j is free software: you can redistribute it and/or modify
8+
* it under the terms of the GNU General Public License as published by
9+
* the Free Software Foundation, either version 3 of the License, or
10+
* (at your option) any later version.
11+
*
12+
* This program is distributed in the hope that it will be useful,
13+
* but WITHOUT ANY WARRANTY; without even the implied warranty of
14+
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
15+
* GNU General Public License for more details.
16+
*
17+
* You should have received a copy of the GNU General Public License
18+
* along with this program. If not, see <http://www.gnu.org/licenses/>.
19+
*/
20+
package org.neo4j.gds.paths.maxflow;
21+
22+
import org.neo4j.gds.procedures.GraphDataScienceProcedures;
23+
import org.neo4j.gds.procedures.algorithms.pathfinding.MaxFlowMutateResult;
24+
import org.neo4j.procedure.Context;
25+
import org.neo4j.procedure.Description;
26+
import org.neo4j.procedure.Name;
27+
import org.neo4j.procedure.Procedure;
28+
29+
import java.util.Map;
30+
import java.util.stream.Stream;
31+
32+
import static org.neo4j.gds.paths.maxflow.MaxFlowConstants.MAXFLOW_DESCRIPTION;
33+
import static org.neo4j.procedure.Mode.READ;
34+
35+
public class MaxFlowMutateProc {
36+
@Context
37+
public GraphDataScienceProcedures facade;
38+
39+
@Procedure(value = "gds.maxFlow.mutate", mode = READ)
40+
@Description(MAXFLOW_DESCRIPTION)
41+
public Stream<MaxFlowMutateResult> maxFlow(
42+
@Name(value = "graphName") String graphName,
43+
@Name(value = "configuration", defaultValue = "{}") Map<String, Object> configuration
44+
) {
45+
return facade.algorithms().pathFinding().maxFlowMutate(graphName, configuration);
46+
}
47+
}
Lines changed: 116 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,116 @@
1+
/*
2+
* Copyright (c) "Neo4j"
3+
* Neo4j Sweden AB [http://neo4j.com]
4+
*
5+
* This file is part of Neo4j.
6+
*
7+
* Neo4j is free software: you can redistribute it and/or modify
8+
* it under the terms of the GNU General Public License as published by
9+
* the Free Software Foundation, either version 3 of the License, or
10+
* (at your option) any later version.
11+
*
12+
* This program is distributed in the hope that it will be useful,
13+
* but WITHOUT ANY WARRANTY; without even the implied warranty of
14+
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
15+
* GNU General Public License for more details.
16+
*
17+
* You should have received a copy of the GNU General Public License
18+
* along with this program. If not, see <http://www.gnu.org/licenses/>.
19+
*/
20+
package org.neo4j.gds.paths.maxflow;
21+
22+
import org.junit.jupiter.api.BeforeEach;
23+
import org.junit.jupiter.api.Test;
24+
import org.neo4j.gds.BaseProcTest;
25+
import org.neo4j.gds.GdsCypher;
26+
import org.neo4j.gds.RelationshipType;
27+
import org.neo4j.gds.catalog.GraphProjectProc;
28+
import org.neo4j.gds.core.Username;
29+
import org.neo4j.gds.core.loading.GraphStoreCatalog;
30+
import org.neo4j.gds.extension.IdFunction;
31+
import org.neo4j.gds.extension.Inject;
32+
import org.neo4j.gds.extension.Neo4jGraph;
33+
34+
import java.util.Optional;
35+
import java.util.concurrent.atomic.LongAdder;
36+
37+
import static org.assertj.core.api.Assertions.assertThat;
38+
39+
class MaxFlowMutateProcTest extends BaseProcTest {
40+
41+
@Neo4jGraph
42+
static final String DB_CYPHER = """
43+
CREATE
44+
(a:Node {id: 0}),
45+
(b:Node {id: 1}),
46+
(c:Node {id: 2}),
47+
(d:Node {id: 3}),
48+
(e:Node {id: 4}),
49+
(a)-[:R {w: 4.0}]->(d),
50+
(b)-[:R {w: 3.0}]->(a),
51+
(c)-[:R {w: 2.0}]->(a),
52+
(c)-[:R {w: 0.0}]->(b),
53+
(d)-[:R {w: 5.0}]->(e)
54+
""";
55+
56+
@Inject
57+
private IdFunction idFunction;
58+
59+
@BeforeEach
60+
void setup() throws Exception {
61+
registerProcedures(MaxFlowMutateProc.class, GraphProjectProc.class);
62+
var createQuery = GdsCypher.call(DEFAULT_GRAPH_NAME)
63+
.graphProject()
64+
.withAnyLabel()
65+
.withRelationshipProperty("w")
66+
.yields();
67+
runQuery(createQuery);
68+
}
69+
70+
@Test
71+
void testMutate() {
72+
String query = GdsCypher.call(DEFAULT_GRAPH_NAME)
73+
.algo("gds.maxFlow")
74+
.mutateMode()
75+
.addParameter("sourceNodes", idFunction.of("a"))
76+
.addParameter("capacityProperty", "w")
77+
.addParameter("targetNodes", idFunction.of("e"))
78+
.addParameter("mutateRelationshipType", "MAX_FLOW")
79+
.addParameter("mutateProperty", "flow")
80+
.yields("totalFlow", "relationshipsWritten");
81+
82+
var rowCount = runQueryWithRowConsumer(query,
83+
resultRow -> {
84+
85+
assertThat(resultRow.get("totalFlow")).isInstanceOf(Double.class);
86+
assertThat(resultRow.get("relationshipsWritten")).isInstanceOf(Long.class);
87+
88+
assertThat((double) resultRow.get("totalFlow")).isEqualTo(4D);
89+
assertThat((long) resultRow.get("relationshipsWritten")).isEqualTo(2L);
90+
});
91+
assertThat(rowCount).isEqualTo(1L);
92+
93+
var mutatedGraph = GraphStoreCatalog
94+
.get(Username.EMPTY_USERNAME.username(), db.databaseName(), DEFAULT_GRAPH_NAME)
95+
.graphStore()
96+
.getGraph(RelationshipType.of("MAX_FLOW"), Optional.of("flow"));
97+
98+
assertThat(mutatedGraph.relationshipCount()).isEqualTo(2L);
99+
100+
var relationshipCounter = new LongAdder();
101+
mutatedGraph.forEachRelationship(mutatedGraph.toMappedNodeId(idFunction.of("a")), 0, (s, t, w) -> {
102+
assertThat(t).isEqualTo(mutatedGraph.toMappedNodeId(idFunction.of("d")));
103+
assertThat(w).isEqualTo(4D);
104+
relationshipCounter.increment();
105+
return true;
106+
});
107+
mutatedGraph.forEachRelationship(mutatedGraph.toMappedNodeId(idFunction.of("d")), 0, (s, t, w) -> {
108+
assertThat(t).isEqualTo(mutatedGraph.toMappedNodeId(idFunction.of("e")));
109+
assertThat(w).isEqualTo(4D);
110+
relationshipCounter.increment();
111+
return true;
112+
});
113+
assertThat(relationshipCounter.longValue()).isEqualTo(2L);
114+
115+
}
116+
}

proc/path-finding/src/test/java/org/neo4j/gds/paths/maxflow/MaxFlowStreamProcTest.java

Lines changed: 0 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -33,16 +33,6 @@
3333

3434
import static org.assertj.core.api.Assertions.assertThat;
3535

36-
/**
37-
* a a
38-
* 1 / \ 2 / \
39-
* / \ / \
40-
* b --3-- c b c
41-
* | | => | |
42-
* 4 5 | |
43-
* | | | |
44-
* d --6-- e d e
45-
*/
4636
class MaxFlowStreamProcTest extends BaseProcTest {
4737

4838
@Neo4jGraph(offsetIds = true)

0 commit comments

Comments
 (0)