Skip to content

Commit 6f32ab0

Browse files
Config
Co-authored-by: Ioannis Panagiotas <ioannis.panagiotas@neo4j.com>
1 parent 5faedc4 commit 6f32ab0

File tree

19 files changed

+401
-29
lines changed

19 files changed

+401
-29
lines changed

config-api/src/main/java/org/neo4j/gds/config/InputNodes.java renamed to algo-params/common/src/main/java/org/neo4j/gds/InputNodes.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
* You should have received a copy of the GNU General Public License
1818
* along with this program. If not, see <http://www.gnu.org/licenses/>.
1919
*/
20-
package org.neo4j.gds.config;
20+
package org.neo4j.gds;
2121

2222
import java.util.Collection;
2323
import java.util.List;

config-api/src/main/java/org/neo4j/gds/config/ListInputNodes.java renamed to algo-params/common/src/main/java/org/neo4j/gds/ListInputNodes.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
* You should have received a copy of the GNU General Public License
1818
* along with this program. If not, see <http://www.gnu.org/licenses/>.
1919
*/
20-
package org.neo4j.gds.config;
20+
package org.neo4j.gds;
2121

2222
import java.util.Collection;
2323
import java.util.List;

config-api/src/main/java/org/neo4j/gds/config/MapInputNodes.java renamed to algo-params/common/src/main/java/org/neo4j/gds/MapInputNodes.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
* You should have received a copy of the GNU General Public License
1818
* along with this program. If not, see <http://www.gnu.org/licenses/>.
1919
*/
20-
package org.neo4j.gds.config;
20+
package org.neo4j.gds;
2121

2222
import java.util.Collection;
2323
import java.util.Map;

algo-params/path-finding-params/src/main/java/org/neo4j/gds/maxflow/MaxFlowParameters.java

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,10 @@
2020
*/
2121
package org.neo4j.gds.maxflow;
2222

23+
import org.neo4j.gds.InputNodes;
24+
import org.neo4j.gds.annotation.Parameters;
2325
import org.neo4j.gds.core.concurrency.Concurrency;
2426

25-
public record MaxFlowParameters(NodeWithValue[] supply, NodeWithValue[] demand, Concurrency concurrency, long alpha, long beta, double freq) {
27+
@Parameters
28+
public record MaxFlowParameters(InputNodes sourceNodes, InputNodes targetNodes, Concurrency concurrency, long alpha, long beta, double freq) {
2629
}

algo/src/main/java/org/neo4j/gds/maxflow/FlowGraph.java

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,9 @@ public static FlowGraph create(Graph graph, NodeWithValue[] supply, NodeWithValu
6969
for (long nodeId = 0; nodeId < graph.nodeCount(); nodeId++) {
7070
graph.forEachRelationship(
7171
nodeId, 0D, (s, t, capacity) -> {
72+
if(capacity < 0D){
73+
throw new IllegalArgumentException("Negative capacity not allowed");
74+
}
7275
reverseDegree.addTo(t, 1);
7376
return true;
7477
}

algo/src/main/java/org/neo4j/gds/maxflow/MaxFlow.java

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,6 @@
3030
import org.neo4j.gds.core.utils.progress.tasks.ProgressTracker;
3131
import org.neo4j.gds.termination.TerminationFlag;
3232

33-
import java.util.concurrent.ExecutorService;
3433
import java.util.concurrent.atomic.AtomicLong;
3534

3635
public final class MaxFlow extends Algorithm<FlowResult> {
@@ -39,19 +38,16 @@ public final class MaxFlow extends Algorithm<FlowResult> {
3938
static final int BETA = 12;
4039
private final Graph graph;
4140
private final MaxFlowParameters parameters;
42-
private final ExecutorService executorService;
4341

4442
public MaxFlow(
4543
Graph graph,
4644
MaxFlowParameters parameters,
47-
ExecutorService executorService,
4845
ProgressTracker progressTracker,
4946
TerminationFlag terminationFlag
5047
) {
5148
super(progressTracker);
5249
this.graph = graph;
5350
this.parameters = parameters;
54-
this.executorService = executorService;
5551
this.terminationFlag = terminationFlag;
5652
}
5753

@@ -65,7 +61,8 @@ public FlowResult compute() {
6561
}
6662

6763
private Preflow initPreflow() {
68-
var flowGraph = FlowGraph.create(graph, parameters.supply(), parameters.demand());
64+
var supplyAndDemand = SupplyAndDemandFactory.create(graph, parameters.sourceNodes(), parameters.targetNodes());
65+
var flowGraph = FlowGraph.create(graph, supplyAndDemand.getLeft(), supplyAndDemand.getRight());
6966
var excess = HugeDoubleArray.newArray(flowGraph.nodeCount());
7067
excess.setAll(x -> 0D);
7168
flowGraph.forEachRelationship(
Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,76 @@
1+
/*
2+
* Copyright (c) "Neo4j"
3+
* Neo4j Sweden AB [http://neo4j.com]
4+
*
5+
* This file is part of Neo4j.
6+
*
7+
* Neo4j is free software: you can redistribute it and/or modify
8+
* it under the terms of the GNU General Public License as published by
9+
* the Free Software Foundation, either version 3 of the License, or
10+
* (at your option) any later version.
11+
*
12+
* This program is distributed in the hope that it will be useful,
13+
* but WITHOUT ANY WARRANTY; without even the implied warranty of
14+
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
15+
* GNU General Public License for more details.
16+
*
17+
* You should have received a copy of the GNU General Public License
18+
* along with this program. If not, see <http://www.gnu.org/licenses/>.
19+
*/
20+
package org.neo4j.gds.maxflow;
21+
22+
import org.apache.commons.lang3.tuple.Pair;
23+
import org.neo4j.gds.api.Graph;
24+
import org.neo4j.gds.api.properties.relationships.RelationshipCursor;
25+
import org.neo4j.gds.InputNodes;
26+
import org.neo4j.gds.ListInputNodes;
27+
import org.neo4j.gds.MapInputNodes;
28+
29+
import java.util.Arrays;
30+
31+
final class SupplyAndDemandFactory {
32+
33+
private SupplyAndDemandFactory() {}
34+
35+
public static Pair<NodeWithValue[], NodeWithValue[]> create(
36+
Graph graph,
37+
InputNodes sourceNodes,
38+
InputNodes targetNodes
39+
) {
40+
var supply = createSupply(sourceNodes, graph);
41+
var demand = createDemand(targetNodes, supply);
42+
return Pair.of(supply, demand);
43+
}
44+
45+
private static NodeWithValue[] createSupply(InputNodes sourceNodes, Graph graph) {
46+
return
47+
switch (sourceNodes) {
48+
case ListInputNodes list -> list.inputNodes().stream().map(sourceNode -> new NodeWithValue(
49+
sourceNode, graph.streamRelationships(
50+
sourceNode,
51+
0D
52+
).map(RelationshipCursor::property).reduce(0D, Double::sum)
53+
)).toArray(NodeWithValue[]::new);
54+
case MapInputNodes map -> map.map().entrySet().stream().map(entry -> new NodeWithValue(
55+
entry.getKey(),
56+
entry.getValue()
57+
)).toArray(NodeWithValue[]::new);
58+
default -> throw new IllegalStateException("Unexpected value: " + sourceNodes); //fixme
59+
};
60+
}
61+
62+
private static NodeWithValue[] createDemand(InputNodes targetNodes, NodeWithValue[] supply) {
63+
return
64+
switch (targetNodes) {
65+
case ListInputNodes list -> {
66+
var totalOutgoing = Arrays.stream(supply).map(nodeWithValue -> nodeWithValue.value()).reduce(0D, Double::sum);
67+
yield list.inputNodes().stream().map(sourceNode -> new NodeWithValue(sourceNode, totalOutgoing)).toArray(NodeWithValue[]::new);
68+
}
69+
case MapInputNodes map -> map.map().entrySet().stream().map(entry -> new NodeWithValue(
70+
entry.getKey(),
71+
entry.getValue()
72+
)).toArray(NodeWithValue[]::new);
73+
default -> throw new IllegalStateException("Unexpected value: " + targetNodes); //fixme
74+
};
75+
}
76+
}

algo/src/main/java/org/neo4j/gds/pagerank/InitialProbabilityFactory.java

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,9 +19,9 @@
1919
*/
2020
package org.neo4j.gds.pagerank;
2121

22-
import org.neo4j.gds.config.InputNodes;
23-
import org.neo4j.gds.config.ListInputNodes;
24-
import org.neo4j.gds.config.MapInputNodes;
22+
import org.neo4j.gds.InputNodes;
23+
import org.neo4j.gds.ListInputNodes;
24+
import org.neo4j.gds.MapInputNodes;
2525

2626
import java.util.HashMap;
2727
import java.util.function.LongUnaryOperator;

algo/src/test/java/org/neo4j/gds/maxflow/MaxFlowTest.java

Lines changed: 16 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -22,10 +22,12 @@
2222

2323
import org.assertj.core.data.Offset;
2424
import org.junit.jupiter.api.Test;
25+
import org.neo4j.gds.InputNodes;
26+
import org.neo4j.gds.ListInputNodes;
27+
import org.neo4j.gds.MapInputNodes;
2528
import org.neo4j.gds.Orientation;
2629
import org.neo4j.gds.TestSupport;
2730
import org.neo4j.gds.api.Graph;
28-
import org.neo4j.gds.api.properties.relationships.RelationshipCursor;
2931
import org.neo4j.gds.beta.generator.PropertyProducer;
3032
import org.neo4j.gds.beta.generator.RandomGraphGenerator;
3133
import org.neo4j.gds.core.concurrency.Concurrency;
@@ -34,6 +36,9 @@
3436
import org.neo4j.gds.extension.Inject;
3537
import org.neo4j.gds.extension.TestGraph;
3638

39+
import java.util.List;
40+
import java.util.Map;
41+
3742
import static org.assertj.core.api.Assertions.assertThat;
3843
import static org.neo4j.gds.beta.generator.RelationshipDistribution.UNIFORM;
3944
import static org.neo4j.gds.maxflow.MaxFlow.ALPHA;
@@ -75,21 +80,18 @@ Graph generateUniform(long nodeCount, int avgDegree) {
7580
return graph;
7681
}
7782

78-
void testGraph(Graph graph, NodeWithValue[] supply, NodeWithValue[] demand, double expectedFlow, int concurrency) {
79-
var params = new MaxFlowParameters(supply, demand, new Concurrency(concurrency), ALPHA, BETA, FREQ);
80-
var x = new MaxFlow(graph, params, null, null, null); //fixme
83+
void testGraph(Graph graph, InputNodes sourceNodes, InputNodes targetNodes, double expectedFlow, int concurrency) {
84+
var params = new MaxFlowParameters(sourceNodes, targetNodes, new Concurrency(concurrency), ALPHA, BETA, FREQ);
85+
var x = new MaxFlow(graph, params, null, null); //fixme
8186
var result = x.compute();
8287
assertThat(result.totalFlow).isCloseTo(expectedFlow, Offset.offset(TOLERANCE));
8388
}
8489

8590
void testGraph(Graph graph, long sourceNode, long targetNode, double expectedFlow, int concurrency) {
86-
double outgoingCapacityFromSource = graph.streamRelationships(sourceNode, 0D)
87-
.map(RelationshipCursor::property)
88-
.reduce(0D, Double::sum);
89-
NodeWithValue[] supply = {new NodeWithValue(sourceNode, outgoingCapacityFromSource)};
90-
NodeWithValue[] demand = {new NodeWithValue(targetNode, outgoingCapacityFromSource)}; //more is useless since this is max in network
91+
var sourceNodes = new ListInputNodes(List.of(sourceNode));
92+
var targetNodes = new ListInputNodes(List.of(targetNode));
9193

92-
testGraph(graph, supply, demand, expectedFlow, concurrency);
94+
testGraph(graph, sourceNodes, targetNodes, expectedFlow, concurrency);
9395
}
9496

9597
void testGraph(TestGraph graph, String sourceNode, String targetNode, double expectedFlow) {
@@ -322,8 +324,8 @@ void test4() {
322324
testGraph(graph, 50, 100, 434.3606561583014, 4);
323325

324326
testGraph(graph,
325-
new NodeWithValue[]{new NodeWithValue(1, 103.1), new NodeWithValue(23, 129.5), new NodeWithValue(101, 242.2)},
326-
new NodeWithValue[]{new NodeWithValue(5, 117.7), new NodeWithValue(199, 199.0), new NodeWithValue(150, 204.5)},
327+
new MapInputNodes(Map.of(1L, 103.1, 23L, 129.5, 101L, 242.2)),
328+
new MapInputNodes(Map.of(5L, 117.7, 199L, 199.0, 150L, 204.5)),
327329
474.8,
328330
4);
329331
}
@@ -335,8 +337,8 @@ void test5() {
335337

336338

337339
testGraph(graph,
338-
new NodeWithValue[]{new NodeWithValue(1, 100.1), new NodeWithValue(23, 120.5), new NodeWithValue(501, 142.2)},
339-
new NodeWithValue[]{new NodeWithValue(5, 157.7), new NodeWithValue(299, 109.0), new NodeWithValue(450, 204.5)},
340+
new MapInputNodes(Map.of(1L, 100.1, 23L, 120.5, 501L, 142.2)),
341+
new MapInputNodes(Map.of(5L, 157.7, 299L, 109.0, 450L, 204.5)),
340342
362.79999999999995,
341343
4);
342344
}
Lines changed: 123 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,123 @@
1+
/*
2+
* Copyright (c) "Neo4j"
3+
* Neo4j Sweden AB [http://neo4j.com]
4+
*
5+
* This file is part of Neo4j.
6+
*
7+
* Neo4j is free software: you can redistribute it and/or modify
8+
* it under the terms of the GNU General Public License as published by
9+
* the Free Software Foundation, either version 3 of the License, or
10+
* (at your option) any later version.
11+
*
12+
* This program is distributed in the hope that it will be useful,
13+
* but WITHOUT ANY WARRANTY; without even the implied warranty of
14+
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
15+
* GNU General Public License for more details.
16+
*
17+
* You should have received a copy of the GNU General Public License
18+
* along with this program. If not, see <http://www.gnu.org/licenses/>.
19+
*/
20+
package org.neo4j.gds.maxflow;
21+
22+
import org.junit.jupiter.api.Test;
23+
import org.neo4j.gds.ListInputNodes;
24+
import org.neo4j.gds.MapInputNodes;
25+
import org.neo4j.gds.extension.GdlExtension;
26+
import org.neo4j.gds.extension.GdlGraph;
27+
import org.neo4j.gds.extension.Inject;
28+
import org.neo4j.gds.extension.TestGraph;
29+
30+
import java.util.List;
31+
import java.util.Map;
32+
33+
import static org.assertj.core.api.Assertions.assertThat;
34+
35+
@GdlExtension
36+
class SupplyAndDemandFactoryTest {
37+
38+
@GdlGraph
39+
private static final String GRAPH =
40+
"""
41+
CREATE
42+
(a:Node {id: 0}),
43+
(b:Node {id: 1}),
44+
(c:Node {id: 2}),
45+
(d:Node {id: 3}),
46+
(e:Node {id: 4}),
47+
(a)-[:R {w: 4.0}]->(d),
48+
(b)-[:R {w: 3.0}]->(a),
49+
(c)-[:R {w: 2.0}]->(a),
50+
(c)-[:R {w: 0.0}]->(b),
51+
(d)-[:R {w: 5.0}]->(e)
52+
""";
53+
54+
@Inject
55+
private TestGraph graph;
56+
57+
@Test
58+
void testCreateWithListSourceAndListTargetNodes() {
59+
var a = graph.toMappedNodeId("a");
60+
var b = graph.toMappedNodeId("b");
61+
var c = graph.toMappedNodeId("c");
62+
var d = graph.toMappedNodeId("d");
63+
64+
var sourceNodes = new ListInputNodes(List.of(a, b));
65+
var targetNodes = new ListInputNodes(List.of(c, d));
66+
67+
var result = SupplyAndDemandFactory.create(graph, sourceNodes, targetNodes);
68+
69+
assertThat(result.getLeft()).containsExactlyInAnyOrder(new NodeWithValue(a, 4.0), new NodeWithValue(b, 3.0));
70+
assertThat(result.getRight()).containsExactlyInAnyOrder(new NodeWithValue(c, 7.0), new NodeWithValue(d, 7.0));
71+
}
72+
73+
@Test
74+
void testCreateWithListSourceAndMapTargetNodes() {
75+
var a = graph.toMappedNodeId("a");
76+
var b = graph.toMappedNodeId("b");
77+
var c = graph.toMappedNodeId("c");
78+
var e = graph.toMappedNodeId("e");
79+
80+
// Arrange
81+
var sourceNodes = new ListInputNodes(List.of(a, b));
82+
var targetNodes = new MapInputNodes(Map.of(c, 5.0, e, 8.0));
83+
84+
// Act
85+
var result = SupplyAndDemandFactory.create(graph, sourceNodes, targetNodes);
86+
87+
// Assert
88+
assertThat(result.getLeft()).containsExactlyInAnyOrder(new NodeWithValue(a, 4.0), new NodeWithValue(b, 3.0));
89+
assertThat(result.getRight()).containsExactlyInAnyOrder(new NodeWithValue(c, 5.0), new NodeWithValue(e, 8.0));
90+
}
91+
92+
@Test
93+
void testCreateWithMapSourceAndListTargetNodes() {
94+
var a = graph.toMappedNodeId("a");
95+
var c = graph.toMappedNodeId("c");
96+
var d = graph.toMappedNodeId("d");
97+
var e = graph.toMappedNodeId("e");
98+
99+
var sourceNodes = new MapInputNodes(Map.of(a, 1.2, d, 10.0));
100+
var targetNodes = new ListInputNodes(List.of(c, e));
101+
102+
var result = SupplyAndDemandFactory.create(graph, sourceNodes, targetNodes);
103+
104+
assertThat(result.getLeft()).containsExactlyInAnyOrder(new NodeWithValue(a, 1.2), new NodeWithValue(d, 10.0));
105+
assertThat(result.getRight()).containsExactlyInAnyOrder(new NodeWithValue(c, 11.2), new NodeWithValue(e, 11.2));
106+
}
107+
108+
@Test
109+
void testCreateWithMapSourceAndMapTargetNodes() {
110+
var a = graph.toMappedNodeId("a");
111+
var c = graph.toMappedNodeId("c");
112+
var d = graph.toMappedNodeId("d");
113+
var e = graph.toMappedNodeId("e");
114+
115+
var sourceNodes = new MapInputNodes(Map.of(a, 1.0, c, 3.0));
116+
var targetNodes = new MapInputNodes(Map.of(d, 5.1, e, 9.0));
117+
118+
var result = SupplyAndDemandFactory.create(graph, sourceNodes, targetNodes);
119+
120+
assertThat(result.getLeft()).containsExactlyInAnyOrder(new NodeWithValue(a, 1.0), new NodeWithValue(c, 3.0));
121+
assertThat(result.getRight()).containsExactlyInAnyOrder(new NodeWithValue(d, 5.1), new NodeWithValue(e, 9.0));
122+
}
123+
}

0 commit comments

Comments
 (0)