Skip to content

Commit bd4dcee

Browse files
Merge pull request #6577 from IoannisPanagiotas/steiner-quicker-exit-on-same-bucket
[Steiner] Discover shortest path to a target node earlier when next bin is the same in the Delta-Stepping process
2 parents 706145d + baea033 commit bd4dcee

File tree

4 files changed

+165
-45
lines changed

4 files changed

+165
-45
lines changed

algo/src/main/java/org/neo4j/gds/steiner/ShortestPathsSteinerAlgorithm.java

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121

2222
import com.carrotsearch.hppc.BitSet;
2323
import org.apache.commons.lang3.mutable.MutableBoolean;
24+
import org.jetbrains.annotations.TestOnly;
2425
import org.neo4j.gds.Algorithm;
2526
import org.neo4j.gds.api.Graph;
2627
import org.neo4j.gds.core.concurrency.ParallelUtil;
@@ -49,6 +50,7 @@ public class ShortestPathsSteinerAlgorithm extends Algorithm<SteinerTreeResult>
4950
private final double delta;
5051
private final ExecutorService executorService;
5152

53+
private int binSizeThreshold;
5254

5355
public ShortestPathsSteinerAlgorithm(
5456
Graph graph,
@@ -68,6 +70,30 @@ public ShortestPathsSteinerAlgorithm(
6870
this.isTerminal = createTerminals();
6971
this.applyRerouting = applyRerouting;
7072
this.executorService = executorService;
73+
this.binSizeThreshold = SteinerBasedDeltaStepping.BIN_SIZE_THRESHOLD;
74+
}
75+
76+
@TestOnly
77+
ShortestPathsSteinerAlgorithm(
78+
Graph graph,
79+
long sourceId,
80+
List<Long> terminals,
81+
double delta,
82+
int concurrency,
83+
boolean applyRerouting,
84+
int binSizeThreshold,
85+
ExecutorService executorService
86+
) {
87+
super(ProgressTracker.NULL_TRACKER);
88+
this.graph = graph;
89+
this.sourceId = sourceId;
90+
this.terminals = terminals;
91+
this.concurrency = concurrency;
92+
this.delta = delta;
93+
this.isTerminal = createTerminals();
94+
this.applyRerouting = applyRerouting;
95+
this.executorService = executorService;
96+
this.binSizeThreshold = binSizeThreshold;
7197
}
7298

7399
private BitSet createTerminals() {
@@ -168,6 +194,7 @@ private DijkstraResult runShortestPaths() {
168194
delta,
169195
isTerminal,
170196
concurrency,
197+
binSizeThreshold,
171198
executorService,
172199
ProgressTracker.NULL_TRACKER
173200
);

algo/src/main/java/org/neo4j/gds/steiner/SteinerBasedDeltaStepping.java

Lines changed: 105 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,6 @@
5858
public final class SteinerBasedDeltaStepping extends Algorithm<DijkstraResult> {
5959

6060
public static final int NO_BIN = Integer.MAX_VALUE;
61-
6261
private static final long NO_TERMINAL = -1;
6362
public static final int BIN_SIZE_THRESHOLD = 1000;
6463
private final Graph graph;
@@ -70,16 +69,18 @@ public final class SteinerBasedDeltaStepping extends Algorithm<DijkstraResult> {
7069
private final ExecutorService executorService;
7170
private long pathIndex;
7271
private final long numOfTerminals;
73-
private final BitSet unvisitedTerminal;
72+
private final BitSet unvisitedTerminal;
7473
private final BitSet mergedWithSource;
7574
private final LongAdder metTerminals;
75+
private int binSizeThreshold;
7676

7777
SteinerBasedDeltaStepping(
7878
Graph graph,
7979
long startNode,
8080
double delta,
8181
BitSet isTerminal,
8282
int concurrency,
83+
int binSizeThreshold,
8384
ExecutorService executorService,
8485
ProgressTracker progressTracker
8586
) {
@@ -94,16 +95,16 @@ public final class SteinerBasedDeltaStepping extends Algorithm<DijkstraResult> {
9495
graph.nodeCount(),
9596
concurrency
9697
);
97-
this.mergedWithSource =new BitSet(graph.nodeCount());
98-
this.unvisitedTerminal= new BitSet(isTerminal.size());
98+
this.mergedWithSource = new BitSet(graph.nodeCount());
99+
this.unvisitedTerminal = new BitSet(isTerminal.size());
99100
unvisitedTerminal.or(isTerminal);
100-
this.pathIndex=0;
101-
this.metTerminals=new LongAdder();
102-
this.numOfTerminals=isTerminal.cardinality();
103-
101+
this.pathIndex = 0;
102+
this.metTerminals = new LongAdder();
103+
this.numOfTerminals = isTerminal.cardinality();
104+
this.binSizeThreshold = binSizeThreshold;
104105
}
105106

106-
private void mergeNodesOnPathToSource(long nodeId,AtomicLong frontierIndex) {
107+
private void mergeNodesOnPathToSource(long nodeId, AtomicLong frontierIndex) {
107108
long currentId = nodeId;
108109
//while not meeting merged nodes, add the current path node to the merge set
109110
//if the parent is merged, then its path to the source has already been zeroed,
@@ -123,7 +124,7 @@ private void mergeNodesOnPathToSource(long nodeId,AtomicLong frontierIndex) {
123124
}
124125
}
125126

126-
private void relaxPhase(List<SteinerBasedDeltaTask> tasks,int currentBin,AtomicLong frontierSize){
127+
private void relaxPhase(List<SteinerBasedDeltaTask> tasks, int currentBin, AtomicLong frontierSize) {
127128
// Phase 1
128129
for (var task : tasks) {
129130
task.setPhase(Phase.RELAX);
@@ -133,7 +134,7 @@ private void relaxPhase(List<SteinerBasedDeltaTask> tasks,int currentBin,AtomicL
133134
ParallelUtil.run(tasks, executorService);
134135
}
135136

136-
private void syncPhase(List<SteinerBasedDeltaTask> tasks,int currentBin, AtomicLong frontierIndex){
137+
private void syncPhase(List<SteinerBasedDeltaTask> tasks, int currentBin, AtomicLong frontierIndex) {
137138
frontierIndex.set(0);
138139
tasks.forEach(task -> task.setPhase(Phase.SYNC));
139140

@@ -149,7 +150,12 @@ private long nextTerminal(HugeLongPriorityQueue terminalQueue) {
149150
return (terminalQueue.isEmpty()) ? NO_TERMINAL : terminalQueue.top();
150151
}
151152

152-
private boolean updateSteinerTree(long terminalId,AtomicLong frontierIndex,List<PathResult> paths, ImmutablePathResult.Builder pathResultBuilder) {
153+
private boolean updateSteinerTree(
154+
long terminalId,
155+
AtomicLong frontierIndex,
156+
List<PathResult> paths,
157+
ImmutablePathResult.Builder pathResultBuilder
158+
) {
153159
//add the new path to the solution
154160

155161
paths.add(pathResult(
@@ -173,34 +179,87 @@ private boolean updateSteinerTree(long terminalId,AtomicLong frontierIndex,List<
173179

174180
}
175181

176-
private long tryToUpdateSteinerTree(long oldBin, long currentBin, HugeLongPriorityQueue terminalQueue) {
177-
boolean shouldComputeClosestTerminal = false;
178-
//delta-Stepping differs by Dijkstra in that it processes the nodes not one-by-one but in batches
179-
//whereas in dijkstra once we examine a node, we are certain we have found the shortest path to it,
180-
//in delta-stepping this is not the case
181-
//for example assume a huge delta and assume a bin contains two nodes with distance a (distance=101) and
182-
//b (distance=98) in the same bucket. Assume furthermore, the edge b->a with cost 1 exists.
183-
//Then a is examined, and because of b->a it is re-examined and hten we find a smaller distance from it (99).
184-
185-
//For the moment, we use a simple criteria to discover if there is a terminal for which with full certainty,
186-
//we have found a shortest to it: Whenever we change from one bin to another, we find the terminal of smallest distance
187-
//if it's distance is below the currentBin, the path to it is optimal.
188-
if (currentBin == NO_BIN || oldBin < currentBin) {
189-
shouldComputeClosestTerminal = true;
182+
private boolean ensureShortest(
183+
double distance,
184+
long oldBin,
185+
long currentBin,
186+
List<SteinerBasedDeltaTask> tasks
187+
) {
188+
if (currentBin == NO_BIN) { //there is no more nodes to relax, path to target node is certainly shortest
189+
return true;
190190
}
191-
if (shouldComputeClosestTerminal) {
192-
long terminalId = nextTerminal(terminalQueue);
193-
if (terminalId == NO_TERMINAL) return NO_TERMINAL;
194-
if (distances.distance(terminalId) < currentBin * delta) {
195-
return terminalId;
191+
if (oldBin == currentBin) {
192+
//if closest terminal is in another bucket, can't be sure it's the best path
193+
if (distance >= (currentBin + 1) * delta) {
194+
return false;
196195
}
196+
//find closest node to be processed afterwards
197+
double currentMinDistance = tasks
198+
.stream()
199+
.mapToDouble(SteinerBasedDeltaTask::getSmallestConsideredDistance)
200+
.min()
201+
.orElseThrow();
202+
//return true if the closest terminal is at least as close as the closest next node
203+
return distance <= currentMinDistance;
204+
} else {
205+
return (distance < currentBin * delta);
197206
}
198-
return -1;
207+
}
208+
209+
//Delta-Stepping differs from Dijkstra in that it processes the nodes not one-by-one but in batches
210+
//whereas in Dijkstra, once we examine a node, we are certain we have found the shortest path to it,
211+
//in delta-stepping this is not the case
212+
//for example, assume a huge delta and assume a bin contains two nodes, a (distance=101) and
213+
//b (distance=98) in the same bucket. Assume furthermore, the edge b->a with cost 1 exists.
214+
//Then a is examined, and because of b->a it is re-examined and then we find a smaller distance for it (99).
215+
private long tryToUpdateSteinerTree(
216+
long oldBin,
217+
long currentBin,
218+
HugeLongPriorityQueue terminalQueue,
219+
List<SteinerBasedDeltaTask> tasks
220+
) {
221+
222+
//We use two simple criteria to discover if there is a terminal for which, with full certainty,
223+
//we have found a shortest path to it:
224+
225+
// The first criterion applies whenever we move to a new bucket in the next iteration.
226+
//We first consider the terminal which currently has the shortest distance from the source node.
227+
// If its distance is below the threshold of the currentBin (i.e., the one processed in the
228+
//next step), then it means its path from source cannot be improved further
229+
// (since all values in the next bucket are further from the source node).
230+
231+
//The next criterion is relevant when we continue with the same bucket in the next iteration.
232+
//Again, we consider the terminal t which currently has the shortest distance from the source node.
233+
//If this terminal is inside the current Bucket,
234+
//we consider, among all nodes that will be relaxed later on, the one with the smallest distance from the source
235+
//(let's call this node r)
236+
//Since any node that will be updated in the next iteration will end up having a distance >= d(r),
237+
//any future nodes examined for the same bucket during the current phase will have a distance d >= d(r).
238+
//Hence, if d(t) <= d(r), we can be certain that none of those subsequently examined vertices will be able to improve
239+
// t's path (or improve the path to any other terminal since they too would end up with a value >=d(r)).
240+
241+
//Note it is not required that t (the closest terminal) is in the current bucket B.
242+
//After a path has been merged, we return to bucket 0. Suppose t is in a bucket B where B > 0.
243+
//Then at some point, we must pass from a bucket B' < B to a bucket B'' >= B; hence the first criterion
244+
//will be able to locate t via the change of bucket criterion.
245+
246+
long terminalId = nextTerminal(terminalQueue);
247+
if (terminalId == NO_TERMINAL) {
248+
return NO_TERMINAL;
249+
}
250+
251+
boolean shouldReturnTerminal = ensureShortest(
252+
distances.distance(terminalId),
253+
oldBin,
254+
currentBin,
255+
tasks
256+
);
257+
258+
return (shouldReturnTerminal) ? terminalId : NO_TERMINAL;
199259
}
200260

201261
@Override
202262
public DijkstraResult compute() {
203-
int iteration = 0;
204263
int currentBin = 0;
205264

206265
var pathResultBuilder = ImmutablePathResult.builder()
@@ -228,7 +287,8 @@ public DijkstraResult compute() {
228287
mergedWithSource,
229288
terminalQueue,
230289
terminalQueueLock,
231-
unvisitedTerminal
290+
unvisitedTerminal,
291+
binSizeThreshold
232292
))
233293
.collect(Collectors.toList());
234294

@@ -241,18 +301,25 @@ public DijkstraResult compute() {
241301
// Find smallest non-empty bin across all tasks
242302
currentBin = tasks.stream().mapToInt(SteinerBasedDeltaTask::minNonEmptyBin).min().orElseThrow();
243303

244-
long terminalId = tryToUpdateSteinerTree(oldCurrentBin, currentBin, terminalQueue);
304+
long terminalId = tryToUpdateSteinerTree(oldCurrentBin, currentBin, terminalQueue, tasks);
245305

246-
if (terminalId != -1) { //if we are certain that we have found a shortest path to one of the remaining terminals
306+
if (terminalId != NO_TERMINAL) { //if we are certain that we have found a shortest path to one of the remaining terminals
247307
//we update the solution and merge its path to the root
248308
terminalQueue.pop();
249309
shouldBreak = updateSteinerTree(terminalId, frontierIndex, paths, pathResultBuilder);
250310
currentBin = 0;
251-
} else { //otherwise proceed as normal, sync the contents of the bucket for each thread to the global queue.
311+
//Note if this scenario occurs:
312+
// The content of the local buckets, which normally would have been synced, remains stored inside the buckets
313+
//and can be processed next time the currentBin is set as the smallest bin.
314+
//There might be some invalid situations (for example node a is located in multiple local buckets of different threads)
315+
//But such situations are also possible for the original delta-stepping algorithm,
316+
//although admittedly, this might be more frequent due to the "revert path to zero" steiner heuristic!
317+
// I doubt this is a very big problem though
318+
} else {
319+
//otherwise proceed as normal, sync the contents of the bucket for each thread to the global queue.
252320
// Phase 2
253321
syncPhase(tasks, currentBin, frontierIndex);
254322
}
255-
iteration += 1;
256323
frontierSize.set(frontierIndex.longValue());
257324
frontierIndex.set(0);
258325
}

algo/src/main/java/org/neo4j/gds/steiner/SteinerBasedDeltaTask.java

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,6 @@
3232
import java.util.concurrent.atomic.AtomicLong;
3333
import java.util.concurrent.locks.ReentrantLock;
3434

35-
import static org.neo4j.gds.steiner.SteinerBasedDeltaStepping.BIN_SIZE_THRESHOLD;
3635
import static org.neo4j.gds.steiner.SteinerBasedDeltaStepping.NO_BIN;
3736

3837
class SteinerBasedDeltaTask implements Runnable {
@@ -49,6 +48,8 @@ class SteinerBasedDeltaTask implements Runnable {
4948
private final BitSet uninvisitedTerminal;
5049
private final HugeLongPriorityQueue terminalQueue;
5150
private final ReentrantLock terminalQueueLock;
51+
private double smallestConsideredDistance;
52+
private int binSizeThreshold;
5253

5354
SteinerBasedDeltaTask(
5455
Graph graph,
@@ -59,7 +60,8 @@ class SteinerBasedDeltaTask implements Runnable {
5960
BitSet mergedToSource,
6061
HugeLongPriorityQueue terminalQueue,
6162
ReentrantLock terminalQueueLock,
62-
BitSet uninvisitedTerminal
63+
BitSet uninvisitedTerminal,
64+
int binSizeThreshold
6365
) {
6466

6567
this.graph = graph;
@@ -72,18 +74,22 @@ class SteinerBasedDeltaTask implements Runnable {
7274
this.terminalQueue = terminalQueue;
7375
this.terminalQueueLock = terminalQueueLock;
7476
this.uninvisitedTerminal = uninvisitedTerminal;
77+
this.binSizeThreshold = binSizeThreshold;
7578
}
7679

7780
@Override
7881
public void run() {
7982
if (phase == SteinerBasedDeltaStepping.Phase.RELAX) {
83+
smallestConsideredDistance = Double.MAX_VALUE;
8084
relaxGlobalBin();
8185
relaxLocalBin();
8286
} else if (phase == SteinerBasedDeltaStepping.Phase.SYNC) {
8387
updateFrontier();
8488
}
8589
}
8690

91+
double getSmallestConsideredDistance() {return smallestConsideredDistance;}
92+
8793
void setPhase(SteinerBasedDeltaStepping.Phase phase) {
8894
this.phase = phase;
8995
}
@@ -123,7 +129,7 @@ private void relaxLocalBin() {
123129
while (binIndex < localBins.length
124130
&& localBins[binIndex] != null
125131
&& !localBins[binIndex].isEmpty()
126-
&& localBins[binIndex].size() < BIN_SIZE_THRESHOLD) {
132+
&& localBins[binIndex].size() < binSizeThreshold) {
127133
var binCopy = localBins[binIndex].clone();
128134
localBins[binIndex].elementsCount = 0;
129135
binCopy.forEach((LongProcedure) this::relaxNode);
@@ -132,7 +138,7 @@ private void relaxLocalBin() {
132138
private void relaxNode(long nodeId) {
133139
graph.forEachRelationship(nodeId, 1.0, (sourceNodeId, targetNodeId, weight) -> {
134140
if (!mergedToSource.get(targetNodeId)) { //ignore merged vertices
135-
tryToUpdate(sourceNodeId, targetNodeId,weight);
141+
tryToUpdate(sourceNodeId, targetNodeId, weight);
136142
}
137143
return true;
138144
});
@@ -159,6 +165,8 @@ private void tryToUpdate(long sourceNodeId, long targetNodeId,double weight) {
159165
// CAX failed, retry
160166
oldDist = witness;
161167
}
168+
smallestConsideredDistance = Math.min(newDist, smallestConsideredDistance);
169+
162170
if (uninvisitedTerminal.get(targetNodeId)) {
163171

164172
terminalQueueLock.lock();

0 commit comments

Comments
 (0)