Skip to content

Commit 0240426

Browse files
Make Node2Vec respect TerminationFlag during Training
Co-authored-by: Ioannis Panagiotas <ioannis.panagiotas@neotechnology.com>
1 parent 3a3a97e commit 0240426

File tree

3 files changed

+78
-9
lines changed

3 files changed

+78
-9
lines changed

algo/src/main/java/org/neo4j/gds/embeddings/node2vec/Node2Vec.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -110,7 +110,8 @@ public Node2VecResult compute() {
110110
maybeRandomSeed,
111111
walks,
112112
probabilitiesBuilder.build(),
113-
progressTracker
113+
progressTracker,
114+
terminationFlag
114115
);
115116

116117
var result = node2VecModel.train();

algo/src/main/java/org/neo4j/gds/embeddings/node2vec/Node2VecModel.java

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
import org.neo4j.gds.core.utils.progress.tasks.ProgressTracker;
2828
import org.neo4j.gds.mem.BitUtil;
2929
import org.neo4j.gds.ml.core.tensor.FloatVector;
30+
import org.neo4j.gds.termination.TerminationFlag;
3031

3132
import java.util.ArrayList;
3233
import java.util.List;
@@ -54,6 +55,7 @@ public class Node2VecModel {
5455
private final RandomWalkProbabilities randomWalkProbabilities;
5556
private final ProgressTracker progressTracker;
5657
private final long randomSeed;
58+
private final TerminationFlag terminationFlag;
5759

5860
static final double EPSILON = 1e-10;
5961

@@ -65,7 +67,8 @@ public class Node2VecModel {
6567
Optional<Long> maybeRandomSeed,
6668
CompressedRandomWalks walks,
6769
RandomWalkProbabilities randomWalkProbabilities,
68-
ProgressTracker progressTracker
70+
ProgressTracker progressTracker,
71+
TerminationFlag terminationFlag
6972
) {
7073
this(
7174
toOriginalId,
@@ -81,7 +84,8 @@ public class Node2VecModel {
8184
maybeRandomSeed,
8285
walks,
8386
randomWalkProbabilities,
84-
progressTracker
87+
progressTracker,
88+
terminationFlag
8589
);
8690
}
8791

@@ -99,7 +103,8 @@ public class Node2VecModel {
99103
Optional<Long> maybeRandomSeed,
100104
CompressedRandomWalks walks,
101105
RandomWalkProbabilities randomWalkProbabilities,
102-
ProgressTracker progressTracker
106+
ProgressTracker progressTracker,
107+
TerminationFlag terminationFlag
103108
) {
104109
this.initialLearningRate = initialLearningRate;
105110
this.minLearningRate = minLearningRate;
@@ -113,6 +118,7 @@ public class Node2VecModel {
113118
this.randomWalkProbabilities = randomWalkProbabilities;
114119
this.progressTracker = progressTracker;
115120
this.randomSeed = maybeRandomSeed.orElseGet(() -> new SplittableRandom().nextLong());
121+
this.terminationFlag = terminationFlag;
116122

117123
var random = new Random();
118124
centerEmbeddings = initializeEmbeddings(toOriginalId, nodeCount, embeddingDimension, random);
@@ -136,6 +142,7 @@ Node2VecResult train() {
136142

137143
RunWithConcurrency.builder()
138144
.concurrency(concurrency)
145+
.terminationFlag(terminationFlag)
139146
.tasks(tasks)
140147
.run();
141148

algo/src/test/java/org/neo4j/gds/embeddings/node2vec/Node2VecModelTest.java

Lines changed: 66 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,8 @@
3131
import org.neo4j.gds.core.concurrency.Concurrency;
3232
import org.neo4j.gds.core.utils.Intersections;
3333
import org.neo4j.gds.core.utils.progress.tasks.ProgressTracker;
34+
import org.neo4j.gds.termination.TerminatedException;
35+
import org.neo4j.gds.termination.TerminationFlag;
3436

3537
import java.util.Optional;
3638
import java.util.Random;
@@ -39,6 +41,7 @@
3941
import java.util.stream.LongStream;
4042

4143
import static org.assertj.core.api.Assertions.assertThat;
44+
import static org.assertj.core.api.Assertions.assertThatExceptionOfType;
4245
import static org.junit.jupiter.api.Assertions.assertEquals;
4346
import static org.mockito.ArgumentMatchers.any;
4447
import static org.mockito.ArgumentMatchers.anyLong;
@@ -92,7 +95,8 @@ void testModel() {
9295
Optional.empty(),
9396
walks,
9497
probabilitiesBuilder.build(),
95-
ProgressTracker.NULL_TRACKER
98+
ProgressTracker.NULL_TRACKER,
99+
TerminationFlag.RUNNING_TRUE
96100
);
97101

98102
var trainResult = node2VecModel.train();
@@ -194,7 +198,8 @@ void twoRunsSingleThreadedWithTheSameRandomSeed(int iterations) {
194198
Optional.of(1337L),
195199
walks,
196200
probabilitiesBuilder.build(),
197-
ProgressTracker.NULL_TRACKER
201+
ProgressTracker.NULL_TRACKER,
202+
TerminationFlag.RUNNING_TRUE
198203
).train().embeddings();
199204

200205
var secondRunEmbedding = new Node2VecModel(
@@ -205,7 +210,8 @@ void twoRunsSingleThreadedWithTheSameRandomSeed(int iterations) {
205210
Optional.of(1337L),
206211
walks,
207212
probabilitiesBuilder.build(),
208-
ProgressTracker.NULL_TRACKER
213+
ProgressTracker.NULL_TRACKER,
214+
TerminationFlag.RUNNING_TRUE
209215
).train().embeddings();
210216

211217
for (long node = 0; node < nodeCount; node++) {
@@ -239,7 +245,8 @@ void shouldCreateTrainingTasksWithCorrectRandomSeed() {
239245
Optional.of(1L), // Random Seed
240246
randomWalksMock,
241247
randomWalkProbabilitiesMock,
242-
ProgressTracker.NULL_TRACKER
248+
ProgressTracker.NULL_TRACKER,
249+
TerminationFlag.RUNNING_TRUE
243250
)
244251
);
245252

@@ -311,7 +318,8 @@ void shouldHaveCorrectLearningRate(){
311318
Optional.empty(),
312319
null,
313320
null,
314-
ProgressTracker.NULL_TRACKER
321+
ProgressTracker.NULL_TRACKER,
322+
TerminationFlag.RUNNING_TRUE
315323
);
316324

317325
assertThat(node2VecModel.learningRate(0)).isEqualTo(10f);
@@ -322,4 +330,57 @@ void shouldHaveCorrectLearningRate(){
322330
assertThat(node2VecModel.learningRate(10000)).isEqualTo(5f);
323331

324332
}
333+
334+
@Test
335+
void shouldRespectTerminationFlag() {
336+
var random = new Random(42);
337+
int numberOfClusters = 2;
338+
int clusterSize = 5;
339+
int numberOfWalks = 2;
340+
int walkLength = 5;
341+
342+
var probabilitiesBuilder = new RandomWalkProbabilitiesBuilder(
343+
numberOfClusters * clusterSize,
344+
new Concurrency(1),
345+
0.001,
346+
0.75
347+
);
348+
349+
var walks = generateRandomWalks(
350+
probabilitiesBuilder,
351+
numberOfClusters,
352+
clusterSize,
353+
numberOfWalks,
354+
walkLength,
355+
random
356+
);
357+
358+
var trainParameters = new TrainParameters(0.05, 0.0001, 10, 2, 1, 2, EmbeddingInitializer.NORMALIZED);
359+
360+
var terminationFlag = new TerminationFlag() {
361+
private int callCount = 0;
362+
@Override
363+
public boolean running() {
364+
++callCount;
365+
return callCount == 2;
366+
}
367+
};
368+
369+
var node2VecModel = new Node2VecModel(
370+
nodeId -> nodeId,
371+
1000,
372+
trainParameters,
373+
new Concurrency(4),
374+
Optional.of(19L),
375+
walks,
376+
probabilitiesBuilder.build(),
377+
ProgressTracker.NULL_TRACKER,
378+
terminationFlag
379+
);
380+
381+
assertThatExceptionOfType(TerminatedException.class)
382+
.isThrownBy(node2VecModel::train);
383+
384+
}
385+
325386
}

0 commit comments

Comments
 (0)