Skip to content

Commit 8e028be

Browse files
committed
Add memory estimation for AllPairsShortestPaths stream
1 parent 42f99c6 commit 8e028be

File tree

9 files changed

+167
-4
lines changed

9 files changed

+167
-4
lines changed
Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,63 @@
1+
/*
2+
* Copyright (c) "Neo4j"
3+
* Neo4j Sweden AB [http://neo4j.com]
4+
*
5+
* This file is part of Neo4j.
6+
*
7+
* Neo4j is free software: you can redistribute it and/or modify
8+
* it under the terms of the GNU General Public License as published by
9+
* the Free Software Foundation, either version 3 of the License, or
10+
* (at your option) any later version.
11+
*
12+
* This program is distributed in the hope that it will be useful,
13+
* but WITHOUT ANY WARRANTY; without even the implied warranty of
14+
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
15+
* GNU General Public License for more details.
16+
*
17+
* You should have received a copy of the GNU General Public License
18+
* along with this program. If not, see <http://www.gnu.org/licenses/>.
19+
*/
20+
package org.neo4j.gds.allshortestpaths;
21+
22+
import org.neo4j.gds.mem.MemoryEstimateDefinition;
23+
import org.neo4j.gds.mem.MemoryEstimation;
24+
import org.neo4j.gds.mem.MemoryEstimations;
25+
import org.neo4j.gds.mem.Estimate;
26+
import org.neo4j.gds.msbfs.MSBFSMemoryEstimation;
27+
28+
public final class AllShortestPathsMemoryEstimateDefinition implements MemoryEstimateDefinition {
29+
30+
private final boolean hasRelationshipWeightProperty;
31+
32+
public AllShortestPathsMemoryEstimateDefinition(boolean hasRelationshipWeightProperty) {
33+
this.hasRelationshipWeightProperty = hasRelationshipWeightProperty;
34+
}
35+
36+
@Override
37+
public MemoryEstimation memoryEstimation() {
38+
if (hasRelationshipWeightProperty) {
39+
return weightedMemoryEstimation();
40+
} else {
41+
return unweightedMemoryEstimation();
42+
}
43+
}
44+
45+
private MemoryEstimation weightedMemoryEstimation() {
46+
return MemoryEstimations.builder(WeightedAllShortestPaths.class)
47+
.perThread("ShortestPathTask", shortestPathTaskMemoryEstimation())
48+
.build();
49+
}
50+
51+
private MemoryEstimation unweightedMemoryEstimation() {
52+
return MemoryEstimations.builder(MSBFSAllShortestPaths.class)
53+
.add("MSBFS", MSBFSMemoryEstimation.MSBFSWithANPStrategy(0))
54+
.build();
55+
}
56+
57+
private MemoryEstimation shortestPathTaskMemoryEstimation() {
58+
return MemoryEstimations.builder(WeightedAllShortestPaths.ShortestPathTask.class)
59+
.perNode("distance array", Estimate::sizeOfDoubleArray)
60+
.add("priority queue", IntPriorityQueue.memoryEstimation())
61+
.build();
62+
}
63+
}

algo/src/main/java/org/neo4j/gds/allshortestpaths/IntPriorityQueue.java

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,9 @@
2121
import org.jetbrains.annotations.TestOnly;
2222
import org.neo4j.gds.collections.ArrayUtil;
2323
import org.neo4j.gds.collections.ha.HugeIntArray;
24+
import org.neo4j.gds.mem.MemoryEstimation;
25+
import org.neo4j.gds.mem.MemoryEstimations;
26+
import org.neo4j.gds.mem.Estimate;
2427

2528
/**
2629
* A PriorityQueue specialized for ints that maintains a partial ordering of
@@ -44,6 +47,14 @@ public abstract class IntPriorityQueue {
4447
private final IntLongScatterMap mapElementToIndex;
4548
private long size = 0;
4649

50+
public static MemoryEstimation memoryEstimation() {
51+
return MemoryEstimations.builder(IntPriorityQueue.class)
52+
.perNode("heap", HugeIntArray::memoryEstimation)
53+
.add("costs", MemoryEstimations.builder(IntDoubleScatterMap.class).build())
54+
.add("element to index map", MemoryEstimations.builder(IntLongScatterMap.class).build())
55+
.build();
56+
}
57+
4758
/**
4859
* Creates a new queue with the given capacity.
4960
* The queue dynamically grows to hold all elements.

algo/src/main/java/org/neo4j/gds/allshortestpaths/WeightedAllShortestPaths.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -109,7 +109,7 @@ public Stream<AllShortestPathsStreamResult> compute() {
109109
* Dijkstra Task. Takes one element of the counter at a time
110110
* and starts dijkstra on it.
111111
*/
112-
private final class ShortestPathTask implements Runnable {
112+
public final class ShortestPathTask implements Runnable {
113113

114114
private final IntPriorityQueue queue;
115115
private final double[] distance;
Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
/*
2+
* Copyright (c) "Neo4j"
3+
* Neo4j Sweden AB [http://neo4j.com]
4+
*
5+
* This file is part of Neo4j.
6+
*
7+
* Neo4j is free software: you can redistribute it and/or modify
8+
* it under the terms of the GNU General Public License as published by
9+
* the Free Software Foundation, either version 3 of the License, or
10+
* (at your option) any later version.
11+
*
12+
* This program is distributed in the hope that it will be useful,
13+
* but WITHOUT ANY WARRANTY; without even the implied warranty of
14+
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
15+
* GNU General Public License for more details.
16+
*
17+
* You should have received a copy of the GNU General Public License
18+
* along with this program. If not, see <http://www.gnu.org/licenses/>.
19+
*/
20+
package org.neo4j.gds.allshortestpaths;
21+
22+
import org.junit.jupiter.params.ParameterizedTest;
23+
import org.junit.jupiter.params.provider.CsvSource;
24+
import org.neo4j.gds.assertions.MemoryEstimationAssert;
25+
import org.neo4j.gds.core.concurrency.Concurrency;
26+
27+
import static org.neo4j.gds.assertions.MemoryEstimationAssert.assertThat;
28+
29+
class AllShortestPathsMemoryEstimateDefinitionTest {
30+
31+
@ParameterizedTest
32+
@CsvSource({
33+
"10_000, 1, false, 280",
34+
"10_000, 4, false, 544",
35+
"500_000, 4, false, 27_200",
36+
"10_000_000, 4, false, 544_000",
37+
"10_000, 1, true, 1_120_456",
38+
"10_000, 4, true, 1_120_720",
39+
"500_000, 4, true, 56_000_720",
40+
"10_000_000, 4, true, 1_120_000_720"
41+
})
42+
void testMemoryEstimation(long nodeCount, int concurrency, boolean weighted, long expectedMemory) {
43+
var memoryEstimation = new AllShortestPathsMemoryEstimateDefinition(weighted).memoryEstimation();
44+
assertThat(memoryEstimation)
45+
.memoryRange(nodeCount, new Concurrency(concurrency))
46+
.hasSameMinAndMaxEqualTo(expectedMemory);
47+
}
48+
}

applications/algorithms/path-finding/src/main/java/org/neo4j/gds/applications/algorithms/pathfinding/PathFindingAlgorithmsEstimationModeBusinessFacade.java

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,8 @@
4949
import org.neo4j.gds.traversal.RandomWalkCountingVisitsMemoryEstimateDefinition;
5050
import org.neo4j.gds.traversal.RandomWalkMemoryEstimateDefinition;
5151
import org.neo4j.gds.traversal.RandomWalkMutateConfig;
52+
import org.neo4j.gds.allshortestpaths.AllShortestPathsMemoryEstimateDefinition;
53+
import org.neo4j.gds.allshortestpaths.AllShortestPathsConfig;
5254

5355
/**
5456
* Here is the top level business facade for all your path finding memory estimation needs.
@@ -61,8 +63,14 @@ public PathFindingAlgorithmsEstimationModeBusinessFacade(AlgorithmEstimationTemp
6163
this.algorithmEstimationTemplate = algorithmEstimationTemplate;
6264
}
6365

64-
MemoryEstimation allShortestPaths() {
65-
throw new MemoryEstimationNotImplementedException();
66+
public MemoryEstimation allShortestPaths(AllShortestPathsConfig configuration) {
67+
return new AllShortestPathsMemoryEstimateDefinition(configuration.hasRelationshipWeightProperty()).memoryEstimation();
68+
}
69+
70+
public MemoryEstimateResult allShortestPaths(AllShortestPathsConfig configuration, Object graphNameOrConfiguration) {
71+
var memoryEstimation = allShortestPaths(configuration);
72+
73+
return runEstimation(configuration, graphNameOrConfiguration, memoryEstimation);
6674
}
6775

6876
public MemoryEstimateResult bellmanFord(AllShortestPathsBellmanFordBaseConfig configuration, Object graphNameOrConfiguration) {

open-packaging/src/test/java/org/neo4j/gds/OpenGdsProcedureSmokeTest.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@ class OpenGdsProcedureSmokeTest extends BaseProcTest {
5353
"gds.graph.sample.cnarw.estimate",
5454

5555
"gds.allShortestPaths.stream",
56+
"gds.allShortestPaths.stream.estimate",
5657

5758
"gds.articulationPoints.mutate",
5859
"gds.articulationPoints.mutate.estimate",
@@ -619,7 +620,7 @@ void countShouldMatch() {
619620
);
620621

621622
// If you find yourself updating this count, please also update the count in SmokeTest.kt
622-
int expectedCount = 459;
623+
int expectedCount = 460;
623624
assertEquals(
624625
expectedCount,
625626
returnedRows,

proc/path-finding/src/main/java/org/neo4j/gds/paths/all/AllShortestPathsStreamProc.java

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121

2222
import org.neo4j.gds.allshortestpaths.AllShortestPathsStreamResult;
2323
import org.neo4j.gds.procedures.GraphDataScienceProcedures;
24+
import org.neo4j.gds.applications.algorithms.machinery.MemoryEstimateResult;
2425
import org.neo4j.procedure.Context;
2526
import org.neo4j.procedure.Description;
2627
import org.neo4j.procedure.Internal;
@@ -31,6 +32,7 @@
3132
import java.util.stream.Stream;
3233

3334
import static org.neo4j.gds.paths.all.Constants.ALL_PAIRS_SHORTEST_PATH_DESCRIPTION;
35+
import static org.neo4j.gds.procedures.ProcedureConstants.MEMORY_ESTIMATION_DESCRIPTION;
3436
import static org.neo4j.procedure.Mode.READ;
3537

3638
public class AllShortestPathsStreamProc {
@@ -46,6 +48,15 @@ public Stream<AllShortestPathsStreamResult> stream(
4648
return facade.algorithms().pathFinding().allShortestPathStream(graphName, configuration);
4749
}
4850

51+
@Procedure(name = "gds.allShortestPaths.stream.estimate", mode = READ)
52+
@Description(MEMORY_ESTIMATION_DESCRIPTION)
53+
public Stream<MemoryEstimateResult> estimate(
54+
@Name(value = "graphNameOrConfiguration") Object graphNameOrConfiguration,
55+
@Name(value = "algoConfiguration") Map<String, Object> algoConfiguration
56+
) {
57+
return facade.algorithms().pathFinding().allShortestPathStreamEstimate(graphNameOrConfiguration, algoConfiguration);
58+
}
59+
4960
@Procedure(name = "gds.alpha.allShortestPaths.stream", mode = READ, deprecatedBy = "gds.allShortestPaths.stream")
5061
@Description(ALL_PAIRS_SHORTEST_PATH_DESCRIPTION)
5162
@Internal

procedures/algorithms-facade/src/main/java/org/neo4j/gds/procedures/algorithms/pathfinding/LocalPathFindingProcedureFacade.java

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -265,6 +265,22 @@ public Stream<AllShortestPathsStreamResult> allShortestPathStream(
265265
);
266266
}
267267

268+
@Override
269+
public Stream<MemoryEstimateResult> allShortestPathStreamEstimate(
270+
Object graphNameOrConfiguration,
271+
Map<String, Object> algorithmConfiguration
272+
) {
273+
var parsedConfiguration = configurationParser.parseConfiguration(
274+
algorithmConfiguration,
275+
AllShortestPathsConfig::of
276+
);
277+
278+
return Stream.of(estimationModeBusinessFacade.allShortestPaths(
279+
parsedConfiguration,
280+
graphNameOrConfiguration
281+
));
282+
}
283+
268284
@Override
269285
public Stream<BellmanFordStreamResult> bellmanFordStream(String graphName, Map<String, Object> configuration) {
270286
var routeRequested = procedureReturnColumns.contains("route");

procedures/facade-api/path-finding-facade-api/src/main/java/org/neo4j/gds/procedures/algorithms/pathfinding/PathFindingProcedureFacade.java

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,11 @@ Stream<AllShortestPathsStreamResult> allShortestPathStream(
3838
Map<String, Object> configuration
3939
);
4040

41+
Stream<MemoryEstimateResult> allShortestPathStreamEstimate(
42+
Object graphNameOrConfiguration,
43+
Map<String, Object> algorithmConfiguration
44+
);
45+
4146
Stream<BellmanFordStreamResult> bellmanFordStream(String graphName, Map<String, Object> configuration);
4247

4348
Stream<MemoryEstimateResult> bellmanFordStreamEstimate(

0 commit comments

Comments
 (0)