Skip to content

Commit 8ed3f6a

Browse files
improve run time for terminal selection
1 parent 803bbc4 commit 8ed3f6a

File tree

2 files changed

+39
-23
lines changed

2 files changed

+39
-23
lines changed

algo/src/main/java/org/neo4j/gds/steiner/SteinerBasedDeltaStepping.java

Lines changed: 17 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
import org.neo4j.gds.core.utils.paged.HugeAtomicLongArray;
3030
import org.neo4j.gds.core.utils.paged.HugeLongArray;
3131
import org.neo4j.gds.core.utils.progress.tasks.ProgressTracker;
32+
import org.neo4j.gds.core.utils.queue.HugeLongPriorityQueue;
3233
import org.neo4j.gds.paths.ImmutablePathResult;
3334
import org.neo4j.gds.paths.PathResult;
3435
import org.neo4j.gds.paths.delta.TentativeDistances;
@@ -39,6 +40,7 @@
3940
import java.util.concurrent.ExecutorService;
4041
import java.util.concurrent.atomic.AtomicLong;
4142
import java.util.concurrent.atomic.LongAdder;
43+
import java.util.concurrent.locks.ReentrantLock;
4244
import java.util.stream.Collectors;
4345
import java.util.stream.IntStream;
4446

@@ -65,11 +67,9 @@ public final class SteinerBasedDeltaStepping extends Algorithm<DijkstraResult> {
6567
private final TentativeDistances distances;
6668
private final ExecutorService executorService;
6769
private long pathIndex;
68-
6970
private final long numOfTerminals;
7071
private final BitSet unvisitedTerminal;
7172
private final BitSet mergedWithSource;
72-
7373
private final LongAdder metTerminals;
7474

7575
SteinerBasedDeltaStepping(
@@ -143,24 +143,13 @@ private void syncPhase(List<SteinerBasedDeltaTask> tasks,int currentBin, AtomicL
143143
progressTracker.endSubTask();
144144
}
145145

146-
private long nextTerminal(){
147-
long index=unvisitedTerminal.nextSetBit(0);
148-
long bestTerminal=index;
149-
double bestDistance=distances.distance(bestTerminal);
150-
index=unvisitedTerminal.nextSetBit(index+1);
151-
while (index!=-1){
152-
double currentDistance=distances.distance(index);
153-
if (currentDistance < bestDistance){
154-
bestTerminal=index;
155-
bestDistance=currentDistance;
156-
}
157-
index=unvisitedTerminal.nextSetBit(index+1);
158-
}
159-
return bestTerminal;
146+
private long nextTerminal(HugeLongPriorityQueue terminalQueue) {
147+
return (terminalQueue.isEmpty()) ? -1 : terminalQueue.top();
160148
}
161149

162150
private boolean updateSteinerTree(long terminalId,AtomicLong frontierIndex,List<PathResult> paths, ImmutablePathResult.Builder pathResultBuilder) {
163151
//add the new path to the solution
152+
164153
paths.add(pathResult(
165154
pathResultBuilder,
166155
pathIndex++,
@@ -182,7 +171,7 @@ private boolean updateSteinerTree(long terminalId,AtomicLong frontierIndex,List<
182171

183172
}
184173

185-
private long tryToUpdateSteinerTree(long oldBin, long currentBin) {
174+
private long tryToUpdateSteinerTree(long oldBin, long currentBin, HugeLongPriorityQueue terminalQueue) {
186175
boolean shouldComputeClosestTerminal = false;
187176
//delta-Stepping differs by Dijkstra in that it processes the nodes not one-by-one but in batches
188177
//whereas in dijkstra once we examine a node, we are certain we have found the shortest path to it,
@@ -198,7 +187,8 @@ private long tryToUpdateSteinerTree(long oldBin, long currentBin) {
198187
shouldComputeClosestTerminal = true;
199188
}
200189
if (shouldComputeClosestTerminal) {
201-
long terminalId = nextTerminal();
190+
long terminalId = nextTerminal(terminalQueue);
191+
if (terminalId == -1) return -1;
202192
if (distances.distance(terminalId) < currentBin * delta) {
203193
return terminalId;
204194
}
@@ -217,12 +207,14 @@ public DijkstraResult compute() {
217207
var frontierIndex = new AtomicLong(0);
218208
var frontierSize = new AtomicLong(1);
219209

220-
List<PathResult> paths=new ArrayList<>();
210+
List<PathResult> paths = new ArrayList<>();
221211

222212
this.frontier.set(currentBin, startNode);
223213
mergedWithSource.set(startNode);
224214
this.distances.set(startNode, -1, 0);
225215

216+
HugeLongPriorityQueue terminalQueue = HugeLongPriorityQueue.min(unvisitedTerminal.size());
217+
var terminalQueueLock = new ReentrantLock();
226218
var tasks = IntStream
227219
.range(0, concurrency)
228220
.mapToObj(i -> new SteinerBasedDeltaTask(
@@ -231,7 +223,10 @@ public DijkstraResult compute() {
231223
distances,
232224
delta,
233225
frontierIndex,
234-
mergedWithSource
226+
mergedWithSource,
227+
terminalQueue,
228+
terminalQueueLock,
229+
unvisitedTerminal
235230
))
236231
.collect(Collectors.toList());
237232

@@ -244,10 +239,11 @@ public DijkstraResult compute() {
244239
// Find smallest non-empty bin across all tasks
245240
currentBin = tasks.stream().mapToInt(SteinerBasedDeltaTask::minNonEmptyBin).min().orElseThrow();
246241

247-
long terminalId = tryToUpdateSteinerTree(oldCurrentBin, currentBin);
242+
long terminalId = tryToUpdateSteinerTree(oldCurrentBin, currentBin, terminalQueue);
248243

249244
if (terminalId != -1) { //if we are certain that we have found a shortest path to one of the remaining terminals
250245
//we update the solution and merge its path to the root
246+
terminalQueue.pop();
251247
shouldBreak = updateSteinerTree(terminalId, frontierIndex, paths, pathResultBuilder);
252248
currentBin = 0;
253249
} else { //otherwise proceed as normal, sync the contents of the bucket for each thread to the global queue.

algo/src/main/java/org/neo4j/gds/steiner/SteinerBasedDeltaTask.java

Lines changed: 22 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,10 +25,12 @@
2525
import com.carrotsearch.hppc.procedures.LongProcedure;
2626
import org.neo4j.gds.api.Graph;
2727
import org.neo4j.gds.core.utils.paged.HugeLongArray;
28+
import org.neo4j.gds.core.utils.queue.HugeLongPriorityQueue;
2829
import org.neo4j.gds.paths.delta.TentativeDistances;
2930

3031
import java.util.Arrays;
3132
import java.util.concurrent.atomic.AtomicLong;
33+
import java.util.concurrent.locks.ReentrantLock;
3234

3335
import static org.neo4j.gds.steiner.SteinerBasedDeltaStepping.BIN_SIZE_THRESHOLD;
3436
import static org.neo4j.gds.steiner.SteinerBasedDeltaStepping.NO_BIN;
@@ -44,23 +46,32 @@ class SteinerBasedDeltaTask implements Runnable {
4446
private LongArrayList[] localBins;
4547
private SteinerBasedDeltaStepping.Phase phase = SteinerBasedDeltaStepping.Phase.RELAX;
4648
private final BitSet mergedToSource;
49+
private final BitSet uninvisitedTerminal;
50+
private final HugeLongPriorityQueue terminalQueue;
51+
private final ReentrantLock terminalQueueLock;
4752

4853
SteinerBasedDeltaTask(
4954
Graph graph,
5055
HugeLongArray frontier,
5156
TentativeDistances distances,
5257
double delta,
5358
AtomicLong frontierIndex,
54-
BitSet mergedToSource
59+
BitSet mergedToSource,
60+
HugeLongPriorityQueue terminalQueue,
61+
ReentrantLock terminalQueueLock,
62+
BitSet uninvisitedTerminal
5563
) {
5664

5765
this.graph = graph;
5866
this.frontier = frontier;
5967
this.distances = distances;
6068
this.delta = delta;
6169
this.frontierIndex = frontierIndex;
62-
this.mergedToSource=mergedToSource;
70+
this.mergedToSource = mergedToSource;
6371
this.localBins = new LongArrayList[0];
72+
this.terminalQueue = terminalQueue;
73+
this.terminalQueueLock = terminalQueueLock;
74+
this.uninvisitedTerminal = uninvisitedTerminal;
6475
}
6576

6677
@Override
@@ -148,6 +159,15 @@ private void tryToUpdate(long sourceNodeId, long targetNodeId,double weight) {
148159
// CAX failed, retry
149160
oldDist = witness;
150161
}
162+
if (uninvisitedTerminal.get(targetNodeId)) {
163+
164+
terminalQueueLock.lock();
165+
if (!terminalQueue.containsElement(targetNodeId) || terminalQueue.cost(targetNodeId) > newDist) {
166+
terminalQueue.set(targetNodeId, newDist);
167+
}
168+
terminalQueueLock.unlock();
169+
170+
}
151171
}
152172

153173
private void updateFrontier() {

0 commit comments

Comments
 (0)