Skip to content

Commit 2d78e66

Browse files
committed
Fix bug in SubGraph
The neighbor adjacency is looked up by (mapped) node ids, but it was stored using batch offsets. Because node ids can be duplicated within a batch, the batch offset is not necessarily the same as the mapped id.
1 parent f6320f4 commit 2d78e66

File tree

4 files changed

+68
-40
lines changed

4 files changed

+68
-40
lines changed

algo/src/test/java/org/neo4j/gds/embeddings/graphsage/GraphSageModelTrainerTest.java

Lines changed: 54 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -26,10 +26,14 @@
2626
import org.junit.jupiter.api.Test;
2727
import org.junit.jupiter.params.ParameterizedTest;
2828
import org.junit.jupiter.params.provider.CsvSource;
29+
import org.junit.jupiter.params.provider.EnumSource;
2930
import org.junit.jupiter.params.provider.ValueSource;
3031
import org.neo4j.gds.Orientation;
3132
import org.neo4j.gds.api.Graph;
3233
import org.neo4j.gds.api.GraphStore;
34+
import org.neo4j.gds.beta.generator.PropertyProducer;
35+
import org.neo4j.gds.beta.generator.RandomGraphGenerator;
36+
import org.neo4j.gds.beta.generator.RelationshipDistribution;
3337
import org.neo4j.gds.core.concurrency.Pools;
3438
import org.neo4j.gds.core.utils.paged.HugeObjectArray;
3539
import org.neo4j.gds.core.utils.progress.tasks.ProgressTracker;
@@ -77,6 +81,7 @@ class GraphSageModelTrainerTest {
7781
private Graph unweightedGraph;
7882
@Inject
7983
private Graph arrayGraph;
84+
8085
private HugeObjectArray<double[]> features;
8186
private GraphSageTrainConfigImpl.Builder configBuilder;
8287

@@ -97,6 +102,53 @@ void setUp() {
97102
.embeddingDimension(EMBEDDING_DIMENSION);
98103
}
99104

105+
// This test reproduces the bug reported in https://trello.com/c/BQ3e12K3/7826-250-graphsage-returns-nan-when-using-relationship-weights
106+
// https://github.com/neo4j/graph-data-science/issues/250
107+
@ParameterizedTest
108+
@EnumSource(AggregatorType.class)
109+
void trainsWithRelationshipWeight(AggregatorType aggregatorType) {
110+
111+
var config = GraphSageTrainConfigImpl.builder()
112+
.randomSeed(42L)
113+
.batchSize(100)
114+
.relationshipWeightProperty("p")
115+
.embeddingDimension(2)
116+
.aggregator(aggregatorType)
117+
.activationFunction(ActivationFunction.SIGMOID)
118+
.featureProperties(List.of("features"))
119+
.modelName("model")
120+
.modelUser("")
121+
.build();
122+
123+
var trainModel = new GraphSageModelTrainer(config, Pools.DEFAULT, ProgressTracker.NULL_TRACKER);
124+
125+
int nodeCount = 5_000;
126+
var bigGraph = RandomGraphGenerator
127+
.builder()
128+
.nodeCount(nodeCount)
129+
.averageDegree(1)
130+
.relationshipDistribution(RelationshipDistribution.UNIFORM)
131+
.nodePropertyProducer(PropertyProducer.randomEmbedding("features", 1, -100, 100))
132+
.relationshipPropertyProducer(PropertyProducer.fixedDouble("p", 0.5))
133+
.seed(42L)
134+
.build()
135+
.generate();
136+
137+
features = HugeObjectArray.newArray(double[].class, nodeCount);
138+
139+
LongStream.range(0, nodeCount).forEach(n -> features.set(n, bigGraph.nodeProperties().get("features").doubleArrayValue(n)));
140+
141+
GraphSageModelTrainer.ModelTrainResult result = trainModel.train(
142+
bigGraph,
143+
features
144+
);
145+
146+
assertThat(result.layers())
147+
.allSatisfy(layer -> assertThat(layer.weights())
148+
.noneMatch(weights -> TensorTestUtils.containsNaN(weights.data()))
149+
);
150+
}
151+
100152
@ParameterizedTest
101153
@ValueSource(booleans = {false, true})
102154
void trainsWithMeanAggregator(boolean useRelationshipWeight) {
@@ -236,18 +288,7 @@ void testLosses() {
236288
assertThat(metrics.ranIterationsPerEpoch()).containsExactly(100, 100, 100, 100, 100, 100, 100, 100, 100, 100);
237289

238290
assertThat(metrics.epochLosses().stream().mapToDouble(Double::doubleValue).toArray())
239-
.contains(new double[]{
240-
18.25,
241-
16.31,
242-
16.41,
243-
16.21,
244-
14.96,
245-
14.97,
246-
14.31,
247-
16.17,
248-
14.90,
249-
15.58
250-
}, Offset.offset(0.05)
291+
.contains(new double[]{19.55, 21.24, 19.90, 19.42, 17.87, 17.03, 17.04, 20.42, 15.86, 20.56}, Offset.offset(0.05)
251292
);
252293
}
253294

@@ -280,18 +321,7 @@ void testLossesWithPoolAggregator() {
280321
assertThat(metrics.ranIterationsPerEpoch()).containsOnly(10);
281322

282323
assertThat(metrics.epochLosses().stream().mapToDouble(Double::doubleValue).toArray())
283-
.contains(new double[]{
284-
23.41,
285-
19.94,
286-
19.70,
287-
21.62,
288-
19.06,
289-
24.11,
290-
19.72,
291-
16.47,
292-
19.74,
293-
20.97
294-
}, Offset.offset(0.05)
324+
.contains(new double[]{19.73, 21.25, 20.81, 23.13, 19.70, 25.34, 20.65, 17.10, 20.48, 21.51}, Offset.offset(0.05)
295325
);
296326
}
297327

algo/src/test/java/org/neo4j/gds/embeddings/graphsage/algo/GraphSageTrainAlgorithmFactoryTest.java

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -499,21 +499,21 @@ void testLogging() {
499499
"GraphSageTrain :: Train model :: Start",
500500
"GraphSageTrain :: Train model :: Epoch 1 of 2 :: Start",
501501
"GraphSageTrain :: Train model :: Epoch 1 of 2 :: Iteration 1 of 2 :: Start",
502-
"GraphSageTrain :: Train model :: Epoch 1 of 2 :: Iteration 1 of 2 :: Average loss per node: 26.49",
502+
"GraphSageTrain :: Train model :: Epoch 1 of 2 :: Iteration 1 of 2 :: Average loss per node: 25.58",
503503
"GraphSageTrain :: Train model :: Epoch 1 of 2 :: Iteration 1 of 2 100%",
504504
"GraphSageTrain :: Train model :: Epoch 1 of 2 :: Iteration 1 of 2 :: Finished",
505505
"GraphSageTrain :: Train model :: Epoch 1 of 2 :: Iteration 2 of 2 :: Start",
506-
"GraphSageTrain :: Train model :: Epoch 1 of 2 :: Iteration 2 of 2 :: Average loss per node: 25.58",
506+
"GraphSageTrain :: Train model :: Epoch 1 of 2 :: Iteration 2 of 2 :: Average loss per node: 26.69",
507507
"GraphSageTrain :: Train model :: Epoch 1 of 2 :: Iteration 2 of 2 100%",
508508
"GraphSageTrain :: Train model :: Epoch 1 of 2 :: Iteration 2 of 2 :: Finished",
509509
"GraphSageTrain :: Train model :: Epoch 1 of 2 :: Finished",
510510
"GraphSageTrain :: Train model :: Epoch 2 of 2 :: Start",
511511
"GraphSageTrain :: Train model :: Epoch 2 of 2 :: Iteration 1 of 2 :: Start",
512-
"GraphSageTrain :: Train model :: Epoch 2 of 2 :: Iteration 1 of 2 :: Average loss per node: 25.28",
512+
"GraphSageTrain :: Train model :: Epoch 2 of 2 :: Iteration 1 of 2 :: Average loss per node: 23.29",
513513
"GraphSageTrain :: Train model :: Epoch 2 of 2 :: Iteration 1 of 2 100%",
514514
"GraphSageTrain :: Train model :: Epoch 2 of 2 :: Iteration 1 of 2 :: Finished",
515515
"GraphSageTrain :: Train model :: Epoch 2 of 2 :: Iteration 2 of 2 :: Start",
516-
"GraphSageTrain :: Train model :: Epoch 2 of 2 :: Iteration 2 of 2 :: Average loss per node: 25.23",
516+
"GraphSageTrain :: Train model :: Epoch 2 of 2 :: Iteration 2 of 2 :: Average loss per node: 22.88",
517517
"GraphSageTrain :: Train model :: Epoch 2 of 2 :: Iteration 2 of 2 100%",
518518
"GraphSageTrain :: Train model :: Epoch 2 of 2 :: Iteration 2 of 2 :: Finished",
519519
"GraphSageTrain :: Train model :: Epoch 2 of 2 :: Finished",

ml/ml-core/src/main/java/org/neo4j/gds/ml/core/subgraph/SubGraph.java

Lines changed: 9 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -69,35 +69,33 @@ public static List<SubGraph> buildSubGraphs(
6969
}
7070

7171
public static SubGraph buildSubGraph(long[] batchNodeIds, NeighborhoodFunction neighborhoodFunction, RelationshipWeights weightFunction) {
72-
int[][] adjacency = new int[batchNodeIds.length][];
73-
int[] batchedNodeIds = new int[batchNodeIds.length];
72+
int[] mappedBatchNodeIds = new int[batchNodeIds.length];
7473

7574
// mapping original long-based nodeIds into consecutive int-based ids
7675
LocalIdMap idmap = new LocalIdMap();
7776

7877
// map the input node ids
7978
// this assures they are in consecutive order
80-
for (long nodeId : batchNodeIds) {
81-
idmap.toMapped(nodeId);
82-
}
83-
8479
for (int nodeOffset = 0, nodeIdsLength = batchNodeIds.length; nodeOffset < nodeIdsLength; nodeOffset++) {
85-
long nodeId = batchNodeIds[nodeOffset];
80+
int mappedNodeId = idmap.toMapped(batchNodeIds[nodeOffset]);
81+
mappedBatchNodeIds[nodeOffset] = mappedNodeId;
82+
}
83+
int[][] adjacency = new int[idmap.size()][];
8684

87-
batchedNodeIds[nodeOffset] = idmap.toMapped(nodeId);
85+
for (int mappedNodeId = 0, mappedBatchIds = idmap.size(); mappedNodeId < mappedBatchIds; mappedNodeId++) {
8886

89-
var nodeNeighbors = neighborhoodFunction.sample(nodeId);
87+
var nodeNeighbors = neighborhoodFunction.sample(idmap.toOriginal(mappedNodeId));
9088

9189
// map sampled neighbors into local id space
9290
// this also expands the id mapping as the neighbours could be not in the nodeIds[]
9391
int[] neighborInternalIds = nodeNeighbors
9492
.mapToInt(idmap::toMapped)
9593
.toArray();
9694

97-
adjacency[nodeOffset] = neighborInternalIds;
95+
adjacency[mappedNodeId] = neighborInternalIds;
9896
}
9997

100-
return new SubGraph(adjacency, batchedNodeIds, idmap.originalIds(), weightFunction);
98+
return new SubGraph(adjacency, mappedBatchNodeIds, idmap.originalIds(), weightFunction);
10199
}
102100

103101
@Override

ml/ml-core/src/test/java/org/neo4j/gds/ml/core/subgraph/SubGraphBuilderTest.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -205,7 +205,7 @@ void shouldHandleDuplicatedNodes() {
205205
RelationshipWeights.UNWEIGHTED
206206
);
207207

208-
assertEquals(6, subGraph.neighbors.length);
208+
assertEquals(3, subGraph.neighbors.length);
209209
}
210210

211211
@Test

0 commit comments

Comments
 (0)