Skip to content

Commit d4071e0

Browse files
committed
Fix N2V progress tracking and float loss of precision
1 parent d065911 commit d4071e0

File tree

3 files changed

+20
-27
lines changed

3 files changed

+20
-27
lines changed

algo/src/main/java/org/neo4j/gds/embeddings/node2vec/Node2VecModel.java

Lines changed: 13 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -162,8 +162,7 @@ Result train() {
162162
var positiveSampleProducer = new PositiveSampleProducer(
163163
walks.iterator(partition.startNode(), partition.nodeCount()),
164164
randomWalkProbabilities.positiveSamplingProbabilities(),
165-
windowSize,
166-
progressTracker
165+
windowSize
167166
);
168167

169168
return new TrainingTask(
@@ -173,7 +172,8 @@ Result train() {
173172
negativeSamples,
174173
learningRate,
175174
negativeSamplingRate,
176-
embeddingDimension
175+
embeddingDimension,
176+
progressTracker
177177
);
178178
}
179179
);
@@ -184,7 +184,7 @@ Result train() {
184184
.run();
185185

186186
double loss = tasks.stream().mapToDouble(TrainingTask::lossSum).sum();
187-
progressTracker.logInfo(formatWithLocale("Maximum likelihood objective is %.4f", loss));
187+
progressTracker.logInfo(formatWithLocale("Loss %.4f", loss));
188188
lossPerIteration.add(loss);
189189

190190
progressTracker.endSubTask();
@@ -235,6 +235,8 @@ private static final class TrainingTask implements Runnable {
235235
private final int negativeSamplingRate;
236236
private final float learningRate;
237237

238+
private final ProgressTracker progressTracker;
239+
238240
private double lossSum;
239241

240242
private TrainingTask(
@@ -244,7 +246,8 @@ private TrainingTask(
244246
NegativeSampleProducer negativeSampleProducer,
245247
float learningRate,
246248
int negativeSamplingRate,
247-
int embeddingDimensions
249+
int embeddingDimensions,
250+
ProgressTracker progressTracker
248251
) {
249252
this.centerEmbeddings = centerEmbeddings;
250253
this.contextEmbeddings = contextEmbeddings;
@@ -255,6 +258,7 @@ private TrainingTask(
255258

256259
this.centerGradientBuffer = new FloatVector(embeddingDimensions);
257260
this.contextGradientBuffer = new FloatVector(embeddingDimensions);
261+
this.progressTracker = progressTracker;
258262
}
259263

260264
@Override
@@ -268,6 +272,7 @@ public void run() {
268272
for (var i = 0; i < negativeSamplingRate; i++) {
269273
trainSample(buffer[0], negativeSampleProducer.next(), false);
270274
}
275+
progressTracker.logProgress();
271276
}
272277
}
273278

@@ -279,13 +284,13 @@ private void trainSample(long center, long context, boolean positive) {
279284
// L_neg = -log sigmoid(-center * context) ; gradient: sigmoid (center * context)
280285
float affinity = centerEmbedding.innerProduct(contextEmbedding);
281286

282-
float positiveSigmoid = (float) Sigmoid.sigmoid(affinity);
283-
float negativeSigmoid = 1 - positiveSigmoid;
287+
double positiveSigmoid = Sigmoid.sigmoid(affinity);
288+
double negativeSigmoid = 1 - positiveSigmoid;
284289

285290

286291
lossSum -= positive ? Math.log(positiveSigmoid) : Math.log(negativeSigmoid);
287292

288-
float gradient = positive ? -negativeSigmoid : positiveSigmoid;
293+
float gradient = positive ? (float) -negativeSigmoid : (float) positiveSigmoid;
289294
// we are doing gradient descent, so we go in the negative direction of the gradient here
290295
float scaledGradient = -gradient * learningRate;
291296

algo/src/main/java/org/neo4j/gds/embeddings/node2vec/PositiveSampleProducer.java

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,6 @@
2020
package org.neo4j.gds.embeddings.node2vec;
2121

2222
import org.neo4j.gds.collections.ha.HugeDoubleArray;
23-
import org.neo4j.gds.core.utils.progress.tasks.ProgressTracker;
2423

2524
import java.util.Iterator;
2625
import java.util.concurrent.ThreadLocalRandom;
@@ -35,7 +34,6 @@ public class PositiveSampleProducer {
3534
private final HugeDoubleArray samplingProbabilities;
3635
private final int prefixWindowSize;
3736
private final int postfixWindowSize;
38-
private final ProgressTracker progressTracker;
3937
private long[] currentWalk;
4038
private int centerWordIndex;
4139
private long currentCenterWord;
@@ -46,11 +44,9 @@ public class PositiveSampleProducer {
4644
PositiveSampleProducer(
4745
Iterator<long[]> walks,
4846
HugeDoubleArray samplingProbabilities,
49-
int windowSize,
50-
ProgressTracker progressTracker
47+
int windowSize
5148
) {
5249
this.walks = walks;
53-
this.progressTracker = progressTracker;
5450
this.samplingProbabilities = samplingProbabilities;
5551

5652
prefixWindowSize = ceilDiv(windowSize - 1, 2);
@@ -76,7 +72,6 @@ private boolean nextWalk() {
7672
return false;
7773
}
7874
long[] walk = walks.next();
79-
progressTracker.logProgress();
8075
int filteredWalkLength = filter(walk);
8176

8277
while (filteredWalkLength < 2 && walks.hasNext()) {

algo/src/test/java/org/neo4j/gds/embeddings/node2vec/PositiveSampleProducerTest.java

Lines changed: 6 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,6 @@
2525
import org.junit.jupiter.params.provider.Arguments;
2626
import org.junit.jupiter.params.provider.MethodSource;
2727
import org.neo4j.gds.collections.ha.HugeDoubleArray;
28-
import org.neo4j.gds.core.utils.progress.tasks.ProgressTracker;
2928

3029
import java.util.ArrayList;
3130
import java.util.Collection;
@@ -60,8 +59,7 @@ void doesNotCauseStackOverflow() {
6059
var sampleProducer = new PositiveSampleProducer(
6160
walks.iterator(0, nbrOfWalks),
6261
HugeDoubleArray.of(LongStream.range(0, nbrOfWalks).mapToDouble((l) -> 1.0).toArray()),
63-
10,
64-
ProgressTracker.NULL_TRACKER
62+
10
6563
);
6664

6765
var counter = 0L;
@@ -88,8 +86,7 @@ void doesNotCauseStackOverflowDueToBadLuck() {
8886
var sampleProducer = new PositiveSampleProducer(
8987
walks.iterator(0, nbrOfWalks),
9088
probabilities,
91-
10,
92-
ProgressTracker.NULL_TRACKER
89+
10
9390
);
9491
// does not overflow the stack = passes test
9592

@@ -112,8 +109,7 @@ void doesNotAttemptToFetchOutsideBatch() {
112109
var sampleProducer = new PositiveSampleProducer(
113110
walks.iterator(0, nbrOfWalks / 2),
114111
HugeDoubleArray.of(LongStream.range(0, nbrOfWalks).mapToDouble((l) -> 1.0).toArray()),
115-
10,
116-
ProgressTracker.NULL_TRACKER
112+
10
117113
);
118114

119115
var counter = 0L;
@@ -137,8 +133,7 @@ void shouldProducePairsWith(
137133
PositiveSampleProducer producer = new PositiveSampleProducer(
138134
walks.iterator(0, walks.size()),
139135
centerNodeProbabilities,
140-
windowSize,
141-
ProgressTracker.NULL_TRACKER
136+
windowSize
142137
);
143138
while (producer.next(buffer)) {
144139
actualPairs.add(Pair.of(buffer[0], buffer[1]));
@@ -160,8 +155,7 @@ void shouldProducePairsWithBounds() {
160155
PositiveSampleProducer producer = new PositiveSampleProducer(
161156
walks.iterator(0, 2),
162157
centerNodeProbabilities,
163-
3,
164-
ProgressTracker.NULL_TRACKER
158+
3
165159
);
166160
while (producer.next(buffer)) {
167161
actualPairs.add(Pair.of(buffer[0], buffer[1]));
@@ -206,8 +200,7 @@ void shouldRemoveDownsampledWordFromWalk() {
206200
PositiveSampleProducer producer = new PositiveSampleProducer(
207201
walks.iterator(0, walks.size()),
208202
centerNodeProbabilities,
209-
3,
210-
ProgressTracker.NULL_TRACKER
203+
3
211204
);
212205

213206
while (producer.next(buffer)) {

0 commit comments

Comments
 (0)