Skip to content

Commit 12ac66d

Browse files
Seed HashGNN random feature generation by original node IDs
Co-Authored-By: Jacob Sznajdman <breakanalysis@gmail.com>
1 parent 00e7987 commit 12ac66d

File tree

5 files changed

+204
-10
lines changed

5 files changed

+204
-10
lines changed

algo/src/main/java/org/neo4j/gds/embeddings/hashgnn/GenerateFeaturesTask.java

Lines changed: 13 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -29,26 +29,31 @@
2929
import org.neo4j.gds.core.utils.progress.tasks.ProgressTracker;
3030

3131
import java.util.List;
32-
import java.util.SplittableRandom;
32+
import java.util.Random;
3333
import java.util.stream.Collectors;
3434

3535
class GenerateFeaturesTask implements Runnable {
3636
private final Partition partition;
3737
private final HugeObjectArray<HugeAtomicBitSet> output;
38-
private final SplittableRandom rng;
38+
private final Graph graph;
39+
private final Random rng;
3940
private final GenerateFeaturesConfig generateFeaturesConfig;
4041
private final ProgressTracker progressTracker;
42+
private final long randomSeed;
4143
private long totalNumFeatures = 0;
4244

4345
GenerateFeaturesTask(
4446
Partition partition,
45-
SplittableRandom rng,
47+
Graph graph,
48+
long randomSeed,
4649
GenerateFeaturesConfig config,
4750
HugeObjectArray<HugeAtomicBitSet> output,
4851
ProgressTracker progressTracker
4952
) {
5053
this.partition = partition;
51-
this.rng = rng;
54+
this.graph = graph;
55+
this.rng = new Random();
56+
this.randomSeed = randomSeed;
5257
this.generateFeaturesConfig = config;
5358
this.output = output;
5459
this.progressTracker = progressTracker;
@@ -58,7 +63,7 @@ static HugeObjectArray<HugeAtomicBitSet> compute(
5863
Graph graph,
5964
List<Partition> partition,
6065
HashGNNConfig config,
61-
SplittableRandom rng,
66+
long randomSeed,
6267
ProgressTracker progressTracker,
6368
TerminationFlag terminationFlag,
6469
MutableLong totalNumFeaturesOutput
@@ -70,7 +75,8 @@ static HugeObjectArray<HugeAtomicBitSet> compute(
7075
var tasks = partition.stream()
7176
.map(p -> new GenerateFeaturesTask(
7277
p,
73-
rng.split(),
78+
graph,
79+
randomSeed,
7480
config.generateFeatures().get(),
7581
output,
7682
progressTracker
@@ -97,6 +103,7 @@ public void run() {
97103
partition.consume(nodeId -> {
98104
var generatedFeatures = HugeAtomicBitSet.create(dimension);
99105

106+
rng.setSeed(this.randomSeed ^ graph.toOriginalNodeId(nodeId));
100107
var randomInts = rng.ints(densityLevel, 0, dimension);
101108
randomInts.forEach(generatedFeatures::set);
102109

algo/src/main/java/org/neo4j/gds/embeddings/hashgnn/HashGNN.java

Lines changed: 5 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -57,7 +57,9 @@ public HashGNN(Graph graph, HashGNNConfig config, ProgressTracker progressTracke
5757
super(progressTracker);
5858
this.graph = graph;
5959
this.config = config;
60-
this.randomSeed = config.randomSeed().orElse((new SplittableRandom().nextLong()));
60+
61+
long tempRandomSeed = config.randomSeed().orElse((new SplittableRandom().nextLong()));
62+
this.randomSeed = new SplittableRandom(tempRandomSeed).nextLong();
6163
this.rng = new SplittableRandom(randomSeed);
6264
}
6365

@@ -141,6 +143,7 @@ public HashGNNResult compute() {
141143
terminationFlag,
142144
totalSetBits
143145
);
146+
144147
double avgActiveFeatures = totalSetBits.doubleValue() / graph.nodeCount();
145148
progressTracker.logInfo(formatWithLocale("After iteration %d average node embedding density (number of active features) is %.4f.", iteration, avgActiveFeatures));
146149
}
@@ -249,7 +252,7 @@ private HugeObjectArray<HugeAtomicBitSet> constructInputEmbeddings(List<Partitio
249252
graph,
250253
partition,
251254
config,
252-
rng,
255+
randomSeed,
253256
progressTracker,
254257
terminationFlag,
255258
totalSetBits
algo/src/test/java/org/neo4j/gds/embeddings/hashgnn/GenerateFeaturesTaskTest.java

Lines changed: 86 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,86 @@
1+
/*
2+
* Copyright (c) "Neo4j"
3+
* Neo4j Sweden AB [http://neo4j.com]
4+
*
5+
* This file is part of Neo4j.
6+
*
7+
* Neo4j is free software: you can redistribute it and/or modify
8+
* it under the terms of the GNU General Public License as published by
9+
* the Free Software Foundation, either version 3 of the License, or
10+
* (at your option) any later version.
11+
*
12+
* This program is distributed in the hope that it will be useful,
13+
* but WITHOUT ANY WARRANTY; without even the implied warranty of
14+
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
15+
* GNU General Public License for more details.
16+
*
17+
* You should have received a copy of the GNU General Public License
18+
* along with this program. If not, see <http://www.gnu.org/licenses/>.
19+
*/
20+
package org.neo4j.gds.embeddings.hashgnn;
21+
22+
import org.apache.commons.lang3.mutable.MutableLong;
23+
import org.assertj.core.data.Percentage;
24+
import org.junit.jupiter.api.Test;
25+
import org.neo4j.gds.api.Graph;
26+
import org.neo4j.gds.beta.generator.RandomGraphGenerator;
27+
import org.neo4j.gds.beta.generator.RelationshipDistribution;
28+
import org.neo4j.gds.core.utils.TerminationFlag;
29+
import org.neo4j.gds.core.utils.partition.Partition;
30+
import org.neo4j.gds.core.utils.progress.tasks.ProgressTracker;
31+
32+
import java.util.List;
33+
import java.util.Map;
34+
35+
import static org.assertj.core.api.Assertions.assertThat;
36+
37+
// Test for GenerateFeaturesTask.compute: checks how many random features are generated per node and in total.
class GenerateFeaturesTaskTest {
38+
39+
// Generates features for every node of a seeded random graph and verifies the output size,
// the total feature count, and the per-node bit set size and cardinality.
@Test
40+
void shouldGenerateCorrectNumberOfFeatures() {
41+
int embeddingDimension = 100;
42+
int densityLevel = 8;
43+
44+
// Fixed generator seed so the same graph is built on every run.
Graph graph = RandomGraphGenerator.builder()
45+
.nodeCount(1000)
46+
.averageDegree(1)
47+
.relationshipDistribution(RelationshipDistribution.UNIFORM)
48+
.seed(42L)
49+
.build()
50+
.generate();
51+
52+
// A single partition is handed to compute; presumably (start, count) covering all nodes — confirm against Partition.
var partition = new Partition(0, graph.nodeCount());
53+
// Passed to compute as its last argument and read back in the total-count assertion below.
var totalNumFeatures = new MutableLong(0);
54+
var config = HashGNNConfigImpl
55+
.builder()
56+
.generateFeatures(Map.of("dimension", embeddingDimension, "densityLevel", densityLevel))
57+
// NOTE(review): iterations and embeddingDensity are not asserted on in this test;
// the 1337 values look like placeholders — confirm they are irrelevant to feature generation.
.iterations(1337)
58+
.embeddingDensity(1337)
59+
.randomSeed(42L)
60+
.build();
61+
62+
// The random seed is passed explicitly as a long (fourth argument), separately from the config.
var output = GenerateFeaturesTask.compute(
63+
graph,
64+
List.of(partition),
65+
config,
66+
42L,
67+
ProgressTracker.NULL_TRACKER,
68+
TerminationFlag.RUNNING_TRUE,
69+
totalNumFeatures
70+
);
71+
72+
// One output entry per node.
assertThat(output.size()).isEqualTo(graph.nodeCount());
73+
// Only approximate equality (within 10%) is required for the total; presumably because
// repeated random indices for a node collapse into a single set bit — confirm against GenerateFeaturesTask.run.
assertThat(totalNumFeatures.getValue()).isCloseTo(
74+
densityLevel * graph.nodeCount(),
75+
Percentage.withPercentage(10)
76+
);
77+
78+
// Every node must have a bit set of the configured dimension with between 1 and densityLevel bits set.
for (int nodeId = 0; nodeId < graph.nodeCount(); nodeId++) {
79+
var features = output.get(nodeId);
80+
assertThat(features.size()).isEqualTo(embeddingDimension);
81+
assertThat(features.cardinality()).isGreaterThanOrEqualTo(1);
82+
assertThat(features.cardinality()).isLessThanOrEqualTo(densityLevel);
83+
}
84+
}
85+
86+
}

algo/src/test/java/org/neo4j/gds/embeddings/hashgnn/HashGNNConfigTest.java

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -43,7 +43,7 @@ void binarizationConfigCorrectType() {
4343
@Test
4444
void shouldNotAllowGeneratedAndFeatureProperties() {
4545
assertThatThrownBy(() -> {
46-
var config = HashGNNConfigImpl
46+
HashGNNConfigImpl
4747
.builder()
4848
.featureProperties(List.of("x"))
4949
.generateFeatures(Map.of("dimension", 100, "densityLevel", 2))
@@ -56,7 +56,7 @@ void shouldNotAllowGeneratedAndFeatureProperties() {
5656
@Test
5757
void requiresFeaturePropertiesIfNoGeneratedFeatures() {
5858
assertThatThrownBy(() -> {
59-
var config = HashGNNConfigImpl
59+
HashGNNConfigImpl
6060
.builder()
6161
.embeddingDensity(4)
6262
.iterations(100)

algo/src/test/java/org/neo4j/gds/embeddings/hashgnn/HashGNNTest.java

Lines changed: 98 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -20,27 +20,41 @@
2020
package org.neo4j.gds.embeddings.hashgnn;
2121

2222
import org.assertj.core.api.Assertions;
23+
import org.assertj.core.data.Offset;
2324
import org.junit.jupiter.api.Test;
2425
import org.junit.jupiter.params.ParameterizedTest;
2526
import org.junit.jupiter.params.provider.Arguments;
2627
import org.junit.jupiter.params.provider.CsvSource;
2728
import org.junit.jupiter.params.provider.MethodSource;
29+
import org.neo4j.gds.NodeLabel;
30+
import org.neo4j.gds.Orientation;
2831
import org.neo4j.gds.ResourceUtil;
2932
import org.neo4j.gds.TestSupport;
3033
import org.neo4j.gds.api.Graph;
34+
import org.neo4j.gds.collections.HugeSparseLongArray;
3135
import org.neo4j.gds.compat.Neo4jProxy;
3236
import org.neo4j.gds.compat.TestLog;
37+
import org.neo4j.gds.core.concurrency.Pools;
38+
import org.neo4j.gds.core.loading.ArrayIdMap;
39+
import org.neo4j.gds.core.loading.LabelInformation;
40+
import org.neo4j.gds.core.loading.construction.GraphFactory;
41+
import org.neo4j.gds.core.loading.construction.RelationshipsBuilder;
42+
import org.neo4j.gds.core.utils.Intersections;
3343
import org.neo4j.gds.core.utils.mem.MemoryRange;
44+
import org.neo4j.gds.core.utils.paged.HugeLongArray;
3445
import org.neo4j.gds.core.utils.progress.EmptyTaskRegistryFactory;
3546
import org.neo4j.gds.core.utils.progress.tasks.ProgressTracker;
3647
import org.neo4j.gds.core.utils.progress.tasks.TaskProgressTracker;
3748
import org.neo4j.gds.extension.GdlExtension;
3849
import org.neo4j.gds.extension.GdlGraph;
3950
import org.neo4j.gds.extension.IdFunction;
4051
import org.neo4j.gds.extension.Inject;
52+
import org.neo4j.gds.ml.util.ShuffleUtil;
4153

4254
import java.util.List;
4355
import java.util.Map;
56+
import java.util.Optional;
57+
import java.util.SplittableRandom;
4458
import java.util.stream.Stream;
4559

4660
import static org.assertj.core.api.AssertionsForClassTypes.assertThat;
@@ -336,4 +350,88 @@ void shouldLogProgress(boolean dense) {
336350
.extracting(removingThreadId())
337351
.containsExactlyElementsOf(ResourceUtil.lines(logResource));
338352
}
353+
354+
// Builds two graphs over the same original node ids (0..nodeCount-1) that differ only in how
// original ids are mapped to internal ids, feeds both the same relationship pairs, runs HashGNN
// with the same config and seed on each, and asserts that the embeddings of each original node
// have an average cosine similarity of (almost exactly) 1.
@Test
355+
void shouldBeDeterministicGivenSameOriginalIds() {
356+
long nodeCount = 1000;
357+
int embeddingDimension = 32;
358+
long degree = 4;
359+
360+
// First id map: identity mapping in both directions (mapped id i <-> original id i).
var firstMappedToOriginal = HugeLongArray.newArray(nodeCount);
361+
firstMappedToOriginal.setAll(nodeId -> nodeId);
362+
var firstOriginalToMappedBuilder = HugeSparseLongArray.builder(nodeCount);
363+
for (long nodeId = 0; nodeId < nodeCount; nodeId++) {
364+
firstOriginalToMappedBuilder.set(nodeId, nodeId);
365+
}
366+
// NOTE(review): the last argument (nodeCount - 1) is presumably the highest original node id — confirm against ArrayIdMap.
var firstIdMap = new ArrayIdMap(
367+
firstMappedToOriginal,
368+
firstOriginalToMappedBuilder.build(),
369+
LabelInformation.single(new NodeLabel("hello")).build(nodeCount, firstMappedToOriginal::get),
370+
nodeCount,
371+
nodeCount - 1
372+
);
373+
RelationshipsBuilder firstRelationshipsBuilder = GraphFactory.initRelationshipsBuilder()
374+
.nodes(firstIdMap)
375+
.orientation(Orientation.UNDIRECTED)
376+
.executorService(Pools.DEFAULT)
377+
.build();
378+
379+
// Second id map: starts as the identity and is then shuffled with a fixed seed, so the same
// set of original ids ends up at different mapped positions.
var secondMappedToOriginal = HugeLongArray.newArray(nodeCount);
380+
secondMappedToOriginal.setAll(nodeId -> nodeId);
381+
382+
var gen = ShuffleUtil.createRandomDataGenerator(Optional.of(42L));
383+
ShuffleUtil.shuffleArray(secondMappedToOriginal, gen);
384+
// Build the reverse lookup from the shuffled array: original id -> mapped id.
var secondOriginalToMappedBuilder = HugeSparseLongArray.builder(nodeCount);
385+
for (long nodeId = 0; nodeId < nodeCount; nodeId++) {
386+
secondOriginalToMappedBuilder.set(secondMappedToOriginal.get(nodeId), nodeId);
387+
}
388+
389+
var secondIdMap = new ArrayIdMap(
390+
secondMappedToOriginal,
391+
secondOriginalToMappedBuilder.build(),
392+
LabelInformation.single(new NodeLabel("hello")).build(nodeCount, secondMappedToOriginal::get),
393+
nodeCount,
394+
nodeCount - 1
395+
);
396+
RelationshipsBuilder secondRelationshipsBuilder = GraphFactory.initRelationshipsBuilder()
397+
.nodes(secondIdMap)
398+
.orientation(Orientation.UNDIRECTED)
399+
.executorService(Pools.DEFAULT)
400+
.build();
401+
402+
// Add `degree` seeded-random targets per node, passing the identical (source, target) pair to both builders.
// NOTE(review): presumably add() interprets these as original node ids, which is what makes
// the two graphs structurally equal — confirm against RelationshipsBuilder.
var random = new SplittableRandom(42);
403+
for (long nodeId = 0; nodeId < nodeCount; nodeId++) {
404+
for (int j = 0; j < degree; j++) {
405+
long target = random.nextLong(nodeCount);
406+
firstRelationshipsBuilder.add(nodeId, target);
407+
secondRelationshipsBuilder.add(nodeId, target);
408+
}
409+
}
410+
411+
var firstRelationships = firstRelationshipsBuilder.build();
412+
var secondRelationships = secondRelationshipsBuilder.build();
413+
414+
var firstGraph = GraphFactory.create(firstIdMap, firstRelationships);
415+
var secondGraph = GraphFactory.create(secondIdMap, secondRelationships);
416+
417+
// One config, including the random seed, is shared by both runs; features are generated rather than read from properties.
var config = HashGNNConfigImpl
418+
.builder()
419+
.embeddingDensity(8)
420+
.generateFeatures(Map.of("dimension", embeddingDimension, "densityLevel", 2))
421+
.iterations(2)
422+
.randomSeed(42L)
423+
.build();
424+
425+
var firstEmbeddings = new HashGNN(firstGraph, config, ProgressTracker.NULL_TRACKER).compute().embeddings();
426+
var secondEmbeddings = new HashGNN(secondGraph, config, ProgressTracker.NULL_TRACKER).compute().embeddings();
427+
428+
// Compare per original node: each graph translates the original id to its own mapped id
// before indexing into its embeddings.
double cosineSum = 0;
429+
for (long originalNodeId = 0; originalNodeId < nodeCount; originalNodeId++) {
430+
var firstVector = firstEmbeddings.get(firstGraph.toMappedNodeId(originalNodeId));
431+
var secondVector = secondEmbeddings.get(secondGraph.toMappedNodeId(originalNodeId));
432+
double cosine = Intersections.cosine(firstVector, secondVector, secondVector.length);
433+
cosineSum += cosine;
434+
}
435+
// The mean cosine similarity over all nodes must be 1 within an offset of 1e-6.
Assertions.assertThat(cosineSum / nodeCount).isCloseTo(1, Offset.offset(0.000001));
436+
}
339437
}

0 commit comments

Comments
 (0)