Skip to content

Commit 7cb526d

Browse files
Extend testing for the new pruning parameter.
1 parent 805f358 commit 7cb526d

File tree

4 files changed

+53
-13
lines changed

4 files changed

+53
-13
lines changed

algo/src/main/java/org/neo4j/gds/steiner/ShortestPathsSteinerAlgorithm.java

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,8 @@ public class ShortestPathsSteinerAlgorithm extends Algorithm<SteinerTreeResult>
4949
private final double delta;
5050
private final ExecutorService executorService;
5151

52+
private int binSizeThreshold;
53+
5254

5355
public ShortestPathsSteinerAlgorithm(
5456
Graph graph,
@@ -70,6 +72,28 @@ public ShortestPathsSteinerAlgorithm(
7072
this.executorService = executorService;
7173
}
7274

75+
ShortestPathsSteinerAlgorithm(
76+
Graph graph,
77+
long sourceId,
78+
List<Long> terminals,
79+
double delta,
80+
int concurrency,
81+
boolean applyRerouting,
82+
int binSizeThreshold,
83+
ExecutorService executorService
84+
) {
85+
super(ProgressTracker.NULL_TRACKER);
86+
this.graph = graph;
87+
this.sourceId = sourceId;
88+
this.terminals = terminals;
89+
this.concurrency = concurrency;
90+
this.delta = delta;
91+
this.isTerminal = createTerminals();
92+
this.applyRerouting = applyRerouting;
93+
this.executorService = executorService;
94+
this.binSizeThreshold = binSizeThreshold;
95+
}
96+
7397
private BitSet createTerminals() {
7498
long maxTerminalId = -1;
7599
for (long terminalId : terminals) {
@@ -168,6 +192,7 @@ private DijkstraResult runShortestPaths() {
168192
delta,
169193
isTerminal,
170194
concurrency,
195+
binSizeThreshold,
171196
executorService,
172197
ProgressTracker.NULL_TRACKER
173198
);

algo/src/main/java/org/neo4j/gds/steiner/SteinerBasedDeltaStepping.java

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -72,13 +72,15 @@ public final class SteinerBasedDeltaStepping extends Algorithm<DijkstraResult> {
7272
private final BitSet unvisitedTerminal;
7373
private final BitSet mergedWithSource;
7474
private final LongAdder metTerminals;
75+
private int binSizeThreshold;
7576

7677
SteinerBasedDeltaStepping(
7778
Graph graph,
7879
long startNode,
7980
double delta,
8081
BitSet isTerminal,
8182
int concurrency,
83+
int binSizeThreshold,
8284
ExecutorService executorService,
8385
ProgressTracker progressTracker
8486
) {
@@ -99,7 +101,7 @@ public final class SteinerBasedDeltaStepping extends Algorithm<DijkstraResult> {
99101
this.pathIndex = 0;
100102
this.metTerminals = new LongAdder();
101103
this.numOfTerminals = isTerminal.cardinality();
102-
104+
this.binSizeThreshold = binSizeThreshold;
103105
}
104106

105107
private void mergeNodesOnPathToSource(long nodeId, AtomicLong frontierIndex) {
@@ -198,7 +200,6 @@ private boolean ensureShortest(
198200
.min()
199201
.orElseThrow();
200202
//return true if the closet terminal is at least as close as the closest next node
201-
System.out.println("hi" + distance + " " + currentMinDistance + " " + (distance <= currentMinDistance) + " " + oldBin + " " + currentBin);
202203
return distance <= currentMinDistance;
203204
} else {
204205
return (distance < currentBin * delta);
@@ -286,7 +287,8 @@ public DijkstraResult compute() {
286287
mergedWithSource,
287288
terminalQueue,
288289
terminalQueueLock,
289-
unvisitedTerminal
290+
unvisitedTerminal,
291+
binSizeThreshold
290292
))
291293
.collect(Collectors.toList());
292294

algo/src/main/java/org/neo4j/gds/steiner/SteinerBasedDeltaTask.java

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,6 @@
3232
import java.util.concurrent.atomic.AtomicLong;
3333
import java.util.concurrent.locks.ReentrantLock;
3434

35-
import static org.neo4j.gds.steiner.SteinerBasedDeltaStepping.BIN_SIZE_THRESHOLD;
3635
import static org.neo4j.gds.steiner.SteinerBasedDeltaStepping.NO_BIN;
3736

3837
class SteinerBasedDeltaTask implements Runnable {
@@ -49,8 +48,8 @@ class SteinerBasedDeltaTask implements Runnable {
4948
private final BitSet uninvisitedTerminal;
5049
private final HugeLongPriorityQueue terminalQueue;
5150
private final ReentrantLock terminalQueueLock;
52-
5351
private double smallestConsideredDistance;
52+
private int binSizeThreshold;
5453

5554
SteinerBasedDeltaTask(
5655
Graph graph,
@@ -61,7 +60,8 @@ class SteinerBasedDeltaTask implements Runnable {
6160
BitSet mergedToSource,
6261
HugeLongPriorityQueue terminalQueue,
6362
ReentrantLock terminalQueueLock,
64-
BitSet uninvisitedTerminal
63+
BitSet uninvisitedTerminal,
64+
int binSizeThreshold
6565
) {
6666

6767
this.graph = graph;
@@ -74,6 +74,7 @@ class SteinerBasedDeltaTask implements Runnable {
7474
this.terminalQueue = terminalQueue;
7575
this.terminalQueueLock = terminalQueueLock;
7676
this.uninvisitedTerminal = uninvisitedTerminal;
77+
this.binSizeThreshold = binSizeThreshold;
7778
}
7879

7980
@Override
@@ -128,7 +129,7 @@ private void relaxLocalBin() {
128129
while (binIndex < localBins.length
129130
&& localBins[binIndex] != null
130131
&& !localBins[binIndex].isEmpty()
131-
&& localBins[binIndex].size() < BIN_SIZE_THRESHOLD) {
132+
&& localBins[binIndex].size() < binSizeThreshold) {
132133
var binCopy = localBins[binIndex].clone();
133134
localBins[binIndex].elementsCount = 0;
134135
binCopy.forEach((LongProcedure) this::relaxNode);

algo/src/test/java/org/neo4j/gds/steiner/ShortestPathSteinerAlgorithmExtendedTest.java

Lines changed: 18 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,8 @@
2222
import com.carrotsearch.hppc.BitSet;
2323
import org.junit.jupiter.api.Test;
2424
import org.junit.jupiter.params.ParameterizedTest;
25-
import org.junit.jupiter.params.provider.ValueSource;
25+
import org.junit.jupiter.params.provider.Arguments;
26+
import org.junit.jupiter.params.provider.MethodSource;
2627
import org.neo4j.gds.Orientation;
2728
import org.neo4j.gds.core.concurrency.Pools;
2829
import org.neo4j.gds.core.utils.progress.tasks.ProgressTracker;
@@ -34,8 +35,10 @@
3435
import org.neo4j.gds.paths.PathResult;
3536

3637
import java.util.List;
38+
import java.util.stream.Stream;
3739

3840
import static org.assertj.core.api.Assertions.assertThat;
41+
import static org.junit.jupiter.params.provider.Arguments.arguments;
3942

4043
@GdlExtension
4144
class ShortestPathSteinerAlgorithmExtendedTest {
@@ -140,12 +143,18 @@ class ShortestPathSteinerAlgorithmExtendedTest {
140143
@Inject
141144
private IdFunction triangleIdFunction;
142145

146+
static Stream<Arguments> inputTuples() {
147+
return Stream.of(
148+
149+
arguments(2.0, SteinerBasedDeltaStepping.BIN_SIZE_THRESHOLD), //default settings
150+
arguments(5.0, 0), //values in two buckets, can deduce lowest shortest without changing bucket
151+
arguments(100, 0) //everything in a single bucket, can likewise
152+
);
153+
}
154+
143155
@ParameterizedTest
144-
@ValueSource(doubles = {2.0, 5.0, 100.0})
145-
//2.0 is the standard one we use everything
146-
//5.0 is one such that we can deduce a2 is best without changing bucket
147-
//100.0 is a big one where everything takes place inside a single bucket
148-
void shouldWorkCorrectly(double delta) {
156+
@MethodSource("inputTuples")
157+
void shouldWorkCorrectly(double delta, int binSizeThreshold) {
149158

150159
var a = SteinerTestUtils.getNodes(idFunction, 6);
151160
var steinerTreeResult = new ShortestPathsSteinerAlgorithm(
@@ -155,6 +164,8 @@ void shouldWorkCorrectly(double delta) {
155164
delta,
156165
1,
157166
false,
167+
binSizeThreshold,
168+
//setting custom threshold for such a small graph allows to not examine everything in a single iteration
158169
Pools.DEFAULT
159170
).compute();
160171

@@ -198,6 +209,7 @@ void deltaSteppingShouldWorkCorrectly() {
198209
2.0,
199210
isTerminal,
200211
1,
212+
SteinerBasedDeltaStepping.BIN_SIZE_THRESHOLD,
201213
Pools.DEFAULT,
202214
ProgressTracker.NULL_TRACKER
203215
);

0 commit comments

Comments
 (0)