Skip to content

Commit 624bd68

Browse files
hindogIoannisPanagiotas
authored andcommitted
RandomWalk: ensure we read updated termination flag in algorithm
1 parent 8e075c4 commit 624bd68

File tree

2 files changed

+38
-5
lines changed

2 files changed

+38
-5
lines changed

algo/src/main/java/org/neo4j/gds/traversal/RandomWalk.java

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -96,7 +96,7 @@ public Stream<long[]> compute() {
9696
? new NextNodeSupplier.GraphNodeSupplier(graph.nodeCount())
9797
: NextNodeSupplier.ListNodeSupplier.of(config, graph);
9898

99-
var terminationFlag = new ExternalTerminationFlag(this.terminationFlag);
99+
var terminationFlag = new ExternalTerminationFlag(this);
100100

101101
BlockingQueue<long[]> walks = new ArrayBlockingQueue<>(config.walkBufferSize());
102102
long[] TOMB = new long[0];
@@ -198,15 +198,15 @@ private Stream<long[]> walksQueueConsumer(
198198

199199
private static final class ExternalTerminationFlag implements TerminationFlag {
200200
private volatile boolean running = true;
201-
private final TerminationFlag inner;
201+
private final Algorithm<?> algo;
202202

203-
ExternalTerminationFlag(TerminationFlag inner) {
204-
this.inner = inner;
203+
ExternalTerminationFlag(Algorithm<?> algo) {
204+
this.algo = algo;
205205
}
206206

207207
@Override
208208
public boolean running() {
209-
return this.running && this.inner.running();
209+
return this.running && this.algo.getTerminationFlag().running();
210210
}
211211

212212
void stop() {

algo/src/test/java/org/neo4j/gds/traversal/RandomWalkTest.java

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@
3636
import org.neo4j.gds.compat.Neo4jProxy;
3737
import org.neo4j.gds.compat.TestLog;
3838
import org.neo4j.gds.core.concurrency.Pools;
39+
import org.neo4j.gds.core.utils.TerminationFlag;
3940
import org.neo4j.gds.core.utils.progress.GlobalTaskStore;
4041
import org.neo4j.gds.core.utils.progress.TaskRegistryFactory;
4142
import org.neo4j.gds.core.utils.progress.tasks.ProgressTracker;
@@ -411,6 +412,38 @@ void testWithConfiguredOffsetStartNodes() {
411412
.anyMatch(walk -> walk[0] == bInternalId);
412413
}
413414

415+
/**
416+
* Ensure that when termination flag is set externally, we terminate the walk
417+
* @throws InterruptedException
418+
*/
419+
@Test
420+
void testPartialReadMultipleRuns() {
421+
for (int i = 0; i < 3; i++) {
422+
Node2VecStreamConfig config = ImmutableNode2VecStreamConfig.builder()
423+
.walkBufferSize(1)
424+
.build();
425+
426+
var randomWalk = RandomWalk.create(
427+
graph,
428+
config,
429+
ProgressTracker.NULL_TRACKER,
430+
Pools.DEFAULT
431+
);
432+
433+
var stream = randomWalk.compute();
434+
long count = stream.limit(10).count();
435+
436+
randomWalk.setTerminationFlag(new TerminationFlag() {
437+
@Override
438+
public boolean running() {
439+
return false;
440+
}
441+
});
442+
443+
assertEquals(10, count);
444+
}
445+
}
446+
414447
@Nested
415448
class ProgressTracking {
416449

0 commit comments

Comments
 (0)