|
20 | 20 | package org.neo4j.gds.embeddings.hashgnn; |
21 | 21 |
|
22 | 22 | import org.assertj.core.api.Assertions; |
| 23 | +import org.assertj.core.data.Offset; |
23 | 24 | import org.junit.jupiter.api.Test; |
24 | 25 | import org.junit.jupiter.params.ParameterizedTest; |
25 | 26 | import org.junit.jupiter.params.provider.Arguments; |
26 | 27 | import org.junit.jupiter.params.provider.CsvSource; |
27 | 28 | import org.junit.jupiter.params.provider.MethodSource; |
| 29 | +import org.neo4j.gds.NodeLabel; |
| 30 | +import org.neo4j.gds.Orientation; |
28 | 31 | import org.neo4j.gds.ResourceUtil; |
29 | 32 | import org.neo4j.gds.TestSupport; |
30 | 33 | import org.neo4j.gds.api.Graph; |
| 34 | +import org.neo4j.gds.collections.HugeSparseLongArray; |
31 | 35 | import org.neo4j.gds.compat.Neo4jProxy; |
32 | 36 | import org.neo4j.gds.compat.TestLog; |
| 37 | +import org.neo4j.gds.core.concurrency.Pools; |
| 38 | +import org.neo4j.gds.core.loading.ArrayIdMap; |
| 39 | +import org.neo4j.gds.core.loading.LabelInformation; |
| 40 | +import org.neo4j.gds.core.loading.construction.GraphFactory; |
| 41 | +import org.neo4j.gds.core.loading.construction.RelationshipsBuilder; |
| 42 | +import org.neo4j.gds.core.utils.Intersections; |
33 | 43 | import org.neo4j.gds.core.utils.mem.MemoryRange; |
| 44 | +import org.neo4j.gds.core.utils.paged.HugeLongArray; |
34 | 45 | import org.neo4j.gds.core.utils.progress.EmptyTaskRegistryFactory; |
35 | 46 | import org.neo4j.gds.core.utils.progress.tasks.ProgressTracker; |
36 | 47 | import org.neo4j.gds.core.utils.progress.tasks.TaskProgressTracker; |
37 | 48 | import org.neo4j.gds.extension.GdlExtension; |
38 | 49 | import org.neo4j.gds.extension.GdlGraph; |
39 | 50 | import org.neo4j.gds.extension.IdFunction; |
40 | 51 | import org.neo4j.gds.extension.Inject; |
| 52 | +import org.neo4j.gds.ml.util.ShuffleUtil; |
41 | 53 |
|
42 | 54 | import java.util.List; |
43 | 55 | import java.util.Map; |
| 56 | +import java.util.Optional; |
| 57 | +import java.util.SplittableRandom; |
44 | 58 | import java.util.stream.Stream; |
45 | 59 |
|
46 | 60 | import static org.assertj.core.api.AssertionsForClassTypes.assertThat; |
@@ -336,4 +350,88 @@ void shouldLogProgress(boolean dense) { |
336 | 350 | .extracting(removingThreadId()) |
337 | 351 | .containsExactlyElementsOf(ResourceUtil.lines(logResource)); |
338 | 352 | } |
| 353 | + |
| 354 | + @Test |
| 355 | + void shouldBeDeterministicGivenSameOriginalIds() { |
| 356 | + long nodeCount = 1000; |
| 357 | + int embeddingDimension = 32; |
| 358 | + long degree = 4; |
| 359 | + |
| 360 | + var firstMappedToOriginal = HugeLongArray.newArray(nodeCount); |
| 361 | + firstMappedToOriginal.setAll(nodeId -> nodeId); |
| 362 | + var firstOriginalToMappedBuilder = HugeSparseLongArray.builder(nodeCount); |
| 363 | + for (long nodeId = 0; nodeId < nodeCount; nodeId++) { |
| 364 | + firstOriginalToMappedBuilder.set(nodeId, nodeId); |
| 365 | + } |
| 366 | + var firstIdMap = new ArrayIdMap( |
| 367 | + firstMappedToOriginal, |
| 368 | + firstOriginalToMappedBuilder.build(), |
| 369 | + LabelInformation.single(new NodeLabel("hello")).build(nodeCount, firstMappedToOriginal::get), |
| 370 | + nodeCount, |
| 371 | + nodeCount - 1 |
| 372 | + ); |
| 373 | + RelationshipsBuilder firstRelationshipsBuilder = GraphFactory.initRelationshipsBuilder() |
| 374 | + .nodes(firstIdMap) |
| 375 | + .orientation(Orientation.UNDIRECTED) |
| 376 | + .executorService(Pools.DEFAULT) |
| 377 | + .build(); |
| 378 | + |
| 379 | + var secondMappedToOriginal = HugeLongArray.newArray(nodeCount); |
| 380 | + secondMappedToOriginal.setAll(nodeId -> nodeId); |
| 381 | + |
| 382 | + var gen = ShuffleUtil.createRandomDataGenerator(Optional.of(42L)); |
| 383 | + ShuffleUtil.shuffleArray(secondMappedToOriginal, gen); |
| 384 | + var secondOriginalToMappedBuilder = HugeSparseLongArray.builder(nodeCount); |
| 385 | + for (long nodeId = 0; nodeId < nodeCount; nodeId++) { |
| 386 | + secondOriginalToMappedBuilder.set(secondMappedToOriginal.get(nodeId), nodeId); |
| 387 | + } |
| 388 | + |
| 389 | + var secondIdMap = new ArrayIdMap( |
| 390 | + secondMappedToOriginal, |
| 391 | + secondOriginalToMappedBuilder.build(), |
| 392 | + LabelInformation.single(new NodeLabel("hello")).build(nodeCount, secondMappedToOriginal::get), |
| 393 | + nodeCount, |
| 394 | + nodeCount - 1 |
| 395 | + ); |
| 396 | + RelationshipsBuilder secondRelationshipsBuilder = GraphFactory.initRelationshipsBuilder() |
| 397 | + .nodes(secondIdMap) |
| 398 | + .orientation(Orientation.UNDIRECTED) |
| 399 | + .executorService(Pools.DEFAULT) |
| 400 | + .build(); |
| 401 | + |
| 402 | + var random = new SplittableRandom(42); |
| 403 | + for (long nodeId = 0; nodeId < nodeCount; nodeId++) { |
| 404 | + for (int j = 0; j < degree; j++) { |
| 405 | + long target = random.nextLong(nodeCount); |
| 406 | + firstRelationshipsBuilder.add(nodeId, target); |
| 407 | + secondRelationshipsBuilder.add(nodeId, target); |
| 408 | + } |
| 409 | + } |
| 410 | + |
| 411 | + var firstRelationships = firstRelationshipsBuilder.build(); |
| 412 | + var secondRelationships = secondRelationshipsBuilder.build(); |
| 413 | + |
| 414 | + var firstGraph = GraphFactory.create(firstIdMap, firstRelationships); |
| 415 | + var secondGraph = GraphFactory.create(secondIdMap, secondRelationships); |
| 416 | + |
| 417 | + var config = HashGNNConfigImpl |
| 418 | + .builder() |
| 419 | + .embeddingDensity(8) |
| 420 | + .generateFeatures(Map.of("dimension", embeddingDimension, "densityLevel", 2)) |
| 421 | + .iterations(2) |
| 422 | + .randomSeed(42L) |
| 423 | + .build(); |
| 424 | + |
| 425 | + var firstEmbeddings = new HashGNN(firstGraph, config, ProgressTracker.NULL_TRACKER).compute().embeddings(); |
| 426 | + var secondEmbeddings = new HashGNN(secondGraph, config, ProgressTracker.NULL_TRACKER).compute().embeddings(); |
| 427 | + |
| 428 | + double cosineSum = 0; |
| 429 | + for (long originalNodeId = 0; originalNodeId < nodeCount; originalNodeId++) { |
| 430 | + var firstVector = firstEmbeddings.get(firstGraph.toMappedNodeId(originalNodeId)); |
| 431 | + var secondVector = secondEmbeddings.get(secondGraph.toMappedNodeId(originalNodeId)); |
| 432 | + double cosine = Intersections.cosine(firstVector, secondVector, secondVector.length); |
| 433 | + cosineSum += cosine; |
| 434 | + } |
| 435 | + Assertions.assertThat(cosineSum / nodeCount).isCloseTo(1, Offset.offset(0.000001)); |
| 436 | + } |
339 | 437 | } |
0 commit comments