Skip to content

Commit d3b3c47

Browse files
Merge pull request #6567 from breakanalysis/hashgnn-validation
Add HashGNN validation
2 parents 6a4152f + a55916e commit d3b3c47

File tree

9 files changed

+130
-6
lines changed

9 files changed

+130
-6
lines changed

algo/src/main/java/org/neo4j/gds/embeddings/hashgnn/BinarizeFeaturesConfig.java

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,8 @@
2222
import org.immutables.value.Value;
2323
import org.neo4j.gds.annotation.Configuration;
2424

25+
import java.util.Collection;
26+
import java.util.List;
2527
import java.util.Map;
2628

2729
@Configuration
@@ -39,4 +41,13 @@ default double threshold() {
3941
/**
 * Map representation of this configuration.
 * The body here is only a placeholder returning an empty map; per the inline note it is
 * expected to be overwritten. NOTE(review): presumably by the generated implementation — confirm.
 */
default Map<String, Object> toMap() {
4042
return Map.of(); // Will be overwritten
4143
}
44+
45+
/**
 * Collection of configuration keys for this config.
 * This default returns an empty list. NOTE(review): the {@code @Configuration.CollectKeys}
 * annotation suggests the generated implementation supplies the actual key set — confirm
 * against the annotation processor.
 */
@Configuration.CollectKeys
46+
@Value.Auxiliary
47+
@Value.Default
48+
@Value.Parameter(false)
49+
default Collection<String> configKeys() {
50+
return List.of();
51+
}
52+
4253
}

algo/src/main/java/org/neo4j/gds/embeddings/hashgnn/GenerateFeaturesConfig.java

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,8 +22,12 @@
2222
import org.immutables.value.Value;
2323
import org.neo4j.gds.annotation.Configuration;
2424

25+
import java.util.Collection;
26+
import java.util.List;
2527
import java.util.Map;
2628

29+
import static org.neo4j.gds.utils.StringFormatting.formatWithLocale;
30+
2731
@Configuration
2832
public interface GenerateFeaturesConfig {
2933
@Configuration.IntegerRange(min = 1)
@@ -37,4 +41,19 @@ public interface GenerateFeaturesConfig {
3741
/**
 * Map representation of this configuration.
 * The body here is only a placeholder returning an empty map; per the inline note it is
 * expected to be overwritten. NOTE(review): presumably by the generated implementation — confirm.
 */
default Map<String, Object> toMap() {
3842
return Map.of(); // Will be overwritten
3943
}
44+
45+
/**
 * Collection of configuration keys for this config.
 * This default returns an empty list. NOTE(review): the {@code @Configuration.CollectKeys}
 * annotation suggests the generated implementation supplies the actual key set — confirm
 * against the annotation processor.
 */
@Configuration.CollectKeys
46+
@Value.Auxiliary
47+
@Value.Default
48+
@Value.Parameter(false)
49+
default Collection<String> configKeys() {
50+
return List.of();
51+
}
52+
53+
@Value.Check
54+
default void validate() {
55+
if (densityLevel() > dimension()) {
56+
throw new IllegalArgumentException(formatWithLocale("Generate features requires `densityLevel` to be at most `dimension` but was %d > %d.", densityLevel(), dimension()));
57+
}
58+
}
4059
}

algo/src/main/java/org/neo4j/gds/embeddings/hashgnn/HashGNNConfig.java

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,10 @@ static Optional<BinarizeFeaturesConfig> parseBinarizationConfig(Object o) {
7777
if (o instanceof Optional) {
7878
return (Optional<BinarizeFeaturesConfig>) o;
7979
}
80-
return Optional.of(new BinarizeFeaturesConfigImpl(CypherMapWrapper.create((Map<String, Object>) o)));
80+
var cypherMapWrapper = CypherMapWrapper.create((Map<String, Object>) o);
81+
var binarizeFeaturesConfig = new BinarizeFeaturesConfigImpl(cypherMapWrapper);
82+
cypherMapWrapper.requireOnlyKeysFrom(binarizeFeaturesConfig.configKeys());
83+
return Optional.of(binarizeFeaturesConfig);
8184
}
8285

8386
static Map<String, Object> toMapBinarizationConfig(BinarizeFeaturesConfig config) {
@@ -88,10 +91,14 @@ static Optional<GenerateFeaturesConfig> parseGenerateFeaturesConfig(Object o) {
8891
if (o instanceof Optional) {
8992
return (Optional<GenerateFeaturesConfig>) o;
9093
}
91-
return Optional.of(new GenerateFeaturesConfigImpl(CypherMapWrapper.create((Map<String, Object>) o)));
94+
var cypherMapWrapper = CypherMapWrapper.create((Map<String, Object>) o);
95+
var generateFeaturesConfig = new GenerateFeaturesConfigImpl(cypherMapWrapper);
96+
cypherMapWrapper.requireOnlyKeysFrom(generateFeaturesConfig.configKeys());
97+
return Optional.of(generateFeaturesConfig);
9298
}
9399

94100
/**
 * Converts a {@link GenerateFeaturesConfig} into its map form by delegating to
 * {@code config.toMap()}.
 *
 * @param config the generate-features configuration to convert
 * @return whatever {@code config.toMap()} returns
 */
static Map<String, Object> toMapGenerateFeaturesConfig(GenerateFeaturesConfig config) {
95101
return config.toMap();
96102
}
103+
97104
}

algo/src/main/java/org/neo4j/gds/embeddings/hashgnn/HashGNNMutateConfig.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,4 +29,5 @@ public interface HashGNNMutateConfig extends HashGNNConfig, MutatePropertyConfig
2929
/**
 * Factory method: wraps the given configuration map in a new {@code HashGNNMutateConfigImpl}.
 *
 * @param config the user-supplied configuration values
 * @return a new mutate-mode HashGNN configuration
 */
static HashGNNMutateConfig of(CypherMapWrapper config) {
3030
return new HashGNNMutateConfigImpl(config);
3131
}
32+
3233
}

algo/src/main/java/org/neo4j/gds/embeddings/hashgnn/RawFeaturesTask.java

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -34,8 +34,11 @@
3434
import java.util.List;
3535
import java.util.stream.Collectors;
3636

37+
import static org.neo4j.gds.utils.StringFormatting.formatWithLocale;
38+
3739
class RawFeaturesTask implements Runnable {
3840
private final Partition partition;
41+
private final Graph graph;
3942
private final List<FeatureExtractor> featureExtractors;
4043
private final int inputDimension;
4144
private final HugeObjectArray<HugeAtomicBitSet> features;
@@ -44,12 +47,14 @@ class RawFeaturesTask implements Runnable {
4447

4548
RawFeaturesTask(
4649
Partition partition,
50+
Graph graph,
4751
List<FeatureExtractor> featureExtractors,
4852
int inputDimension,
4953
HugeObjectArray<HugeAtomicBitSet> features,
5054
ProgressTracker progressTracker
5155
) {
5256
this.partition = partition;
57+
this.graph = graph;
5358
this.featureExtractors = featureExtractors;
5459
this.inputDimension = inputDimension;
5560
this.features = features;
@@ -77,6 +82,7 @@ static HugeObjectArray<HugeAtomicBitSet> compute(
7782
var tasks = partitions.stream()
7883
.map(p -> new RawFeaturesTask(
7984
p,
85+
graph,
8086
featureExtractors,
8187
inputDimension,
8288
features,
@@ -103,16 +109,21 @@ public void run() {
103109
FeatureExtraction.extract(nodeId, -1, featureExtractors, new FeatureConsumer() {
104110
@Override
105111
public void acceptScalar(long nodeOffset, int offset, double value) {
106-
if (value != 0.0) {
112+
if (value == 1.0) {
107113
nodeFeatures.set(offset);
114+
} else if (value != 0.0) {
115+
throw new IllegalArgumentException(formatWithLocale("Feature properties may only contain values 0 and 1 unless `binarizeFeatures` is used. Node %d and possibly other nodes have a feature property containing value %f", graph.toOriginalNodeId(nodeId), value));
108116
}
109117
}
110118

111119
@Override
112120
public void acceptArray(long nodeOffset, int offset, double[] values) {
113121
for (int inputFeatureOffset = 0; inputFeatureOffset < values.length; inputFeatureOffset++) {
114-
if (values[inputFeatureOffset] != 0.0) {
122+
var value = values[inputFeatureOffset];
123+
if (value == 1.0) {
115124
nodeFeatures.set(offset + inputFeatureOffset);
125+
} else if (value != 0.0) {
126+
throw new IllegalArgumentException(formatWithLocale("Feature properties may only contain values 0 and 1 unless `binarizeFeatures` is used. Node %d and possibly other nodes have a feature property containing value %.17f", graph.toOriginalNodeId(nodeId), value));
116127
}
117128
}
118129
}

algo/src/test/java/org/neo4j/gds/embeddings/hashgnn/HashGNNConfigTest.java

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
package org.neo4j.gds.embeddings.hashgnn;
2121

2222
import org.junit.jupiter.api.Test;
23+
import org.neo4j.gds.core.CypherMapWrapper;
2324

2425
import java.util.List;
2526
import java.util.Map;
@@ -63,4 +64,48 @@ void requiresFeaturePropertiesIfNoGeneratedFeatures() {
6364
.build();
6465
}).hasMessage("When `generateFeatures` is not given, `featureProperties` must be non-empty.");
6566
}
67+
68+
@Test
69+
void requiresDensityLevelAtMostDensity() {
70+
assertThatThrownBy(() -> {
71+
HashGNNConfigImpl
72+
.builder()
73+
.embeddingDensity(4)
74+
.generateFeatures(Map.of("dimension", 4, "densityLevel", 5))
75+
.iterations(100)
76+
.build();
77+
}).hasMessage("Generate features requires `densityLevel` to be at most `dimension` but was 5 > 4.");
78+
}
79+
80+
@Test
81+
void failsOnInvalidBinarizationKeys() {
82+
assertThatThrownBy(() -> {
83+
new HashGNNConfigImpl(CypherMapWrapper.create(
84+
Map.of(
85+
"mutateProperty", "foo",
86+
"featureProperties", List.of("x"),
87+
"binarizeFeatures", Map.of("dimension", 100, "treshold", 2.0),
88+
"embeddingDensity", 4,
89+
"iterations", 100
90+
)
91+
));
92+
93+
}).isInstanceOf(IllegalArgumentException.class)
94+
.hasMessage("Unexpected configuration key: treshold (Did you mean [threshold]?)");
95+
}
96+
97+
@Test
98+
void failsOnInvalidGenerateFeaturesKeys() {
99+
assertThatThrownBy(() -> {
100+
new HashGNNConfigImpl(CypherMapWrapper.create(
101+
Map.of(
102+
"generateFeatures", Map.of("dimension", 100, "densityElfen", 2),
103+
"embeddingDensity", 4,
104+
"iterations", 100
105+
)
106+
));
107+
108+
}).isInstanceOf(IllegalArgumentException.class)
109+
.hasMessage("No value specified for the mandatory configuration parameter `densityLevel` (a similar parameter exists: [densityElfen])");
110+
}
66111
}

algo/src/test/java/org/neo4j/gds/embeddings/hashgnn/RawFeaturesTaskTest.java

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
import java.util.List;
3535

3636
import static org.assertj.core.api.Assertions.assertThat;
37+
import static org.assertj.core.api.Assertions.assertThatThrownBy;
3738

3839
@GdlExtension
3940
class RawFeaturesTaskTest {
@@ -45,12 +46,41 @@ class RawFeaturesTaskTest {
4546
", (b:N {f1: 1, f2: [1.0, 0.0]})" +
4647
", (c:N {f1: 1, f2: [0.0, 1.0]})";
4748

49+
// Test graph for the non-binary failure case: node b's f2 array holds 0.0000000001,
// the only feature value in this graph that is neither 0 nor 1.
@GdlGraph(graphNamePrefix = "nonBinary")
50+
private static final String NON_BINARY =
51+
"CREATE" +
52+
" (a:N {f1: 1, f2: [1.0, 1.0]})" +
53+
", (b:N {f1: 1, f2: [1.0, 0.0000000001]})" +
54+
", (c:N {f1: 1, f2: [0.0, 1.0]})";
55+
4856
@Inject
4957
private Graph graph;
5058

59+
@Inject
60+
private Graph nonBinaryGraph;
61+
5162
@Inject
5263
private IdFunction idFunction;
5364

65+
@Test
66+
void shouldFailOnNonBinaryFeatures() {
67+
var partition = new Partition(0, nonBinaryGraph.nodeCount());
68+
var featureExtractors = FeatureExtraction.propertyExtractors(nonBinaryGraph, List.of("f1", "f2"));
69+
var features = HugeObjectArray.newArray(HugeAtomicBitSet.class, nonBinaryGraph.nodeCount());
70+
var inputDimension = FeatureExtraction.featureCount(featureExtractors);
71+
72+
assertThatThrownBy(() -> {
73+
new RawFeaturesTask(
74+
partition,
75+
nonBinaryGraph,
76+
featureExtractors,
77+
inputDimension,
78+
features,
79+
ProgressTracker.NULL_TRACKER
80+
).run();
81+
}).isInstanceOf(IllegalArgumentException.class)
82+
.hasMessage("Feature properties may only contain values 0 and 1 unless `binarizeFeatures` is used. Node 1 and possibly other nodes have a feature property containing value 0.00000000010000000");
83+
}
5484
@Test
5585
void shouldPickCorrectFeatures() {
5686
var partition = new Partition(0, graph.nodeCount());
@@ -60,6 +90,7 @@ void shouldPickCorrectFeatures() {
6090

6191
new RawFeaturesTask(
6292
partition,
93+
graph,
6394
featureExtractors,
6495
inputDimension,
6596
features,

doc/modules/ROOT/partials/machine-learning/node-embeddings/hashgnn/specific-configuration.adoc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,6 @@
44
| heterogeneous | Boolean | false | yes | Whether different relationship types should be treated differently.
55
| neighborInfluence | Float | 1.0 | yes | Controls how often neighbors' features are sampled in each iteration relative to sampling the node's own features. Must be non-negative.
66
| binarizeFeatures | Map | n/a | yes | A map with keys `dimension` and `threshold`. If given, features are transformed into `dimension` binary features via hyperplane rounding. Increasing `threshold` makes the output more sparse, and it defaults to `0`. The value of `dimension` must be at least 1.
7-
| generateFeatures | Map | n/a | yes | A map with keys `dimension` and `densityLevel`. Should be given if and only if `featureProperties` is empty. If given, `dimension` binary features are generated with approximately `densityLevel` active features per node. Both must be at least 1.
7+
| generateFeatures | Map | n/a | yes | A map with keys `dimension` and `densityLevel`. Should be given if and only if `featureProperties` is empty. If given, `dimension` binary features are generated with approximately `densityLevel` active features per node. Both must be at least 1, and `densityLevel` must be at most `dimension`.
88
| outputDimension | Integer | n/a | yes | If given, the embeddings are projected randomly into `outputDimension` dense features. Must be at least 1.
99
| randomSeed | Integer | n/a | yes | A random seed which is used for all randomness in computing the embeddings.

proc/embeddings/src/test/java/org/neo4j/gds/embeddings/hashgnn/HashGNNStreamProcTest.java

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,5 @@ void shouldComputeNonZeroEmbeddings() {
6868
.hasSize(3)
6969
.anyMatch(value -> value != 0.0);
7070
});
71-
7271
}
7372
}

0 commit comments

Comments
 (0)