Skip to content

Commit 00e7987

Browse files
Use gaussian hyperplane rounding
Also contains other changes and test fixes Co-Authored-By: Adam Schill Collberg <adam.schill.collberg@protonmail.com>
1 parent 80b456e commit 00e7987

File tree

17 files changed

+685
-373
lines changed

17 files changed

+685
-373
lines changed
Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
/*
2+
* Copyright (c) "Neo4j"
3+
* Neo4j Sweden AB [http://neo4j.com]
4+
*
5+
* This file is part of Neo4j.
6+
*
7+
* Neo4j is free software: you can redistribute it and/or modify
8+
* it under the terms of the GNU General Public License as published by
9+
* the Free Software Foundation, either version 3 of the License, or
10+
* (at your option) any later version.
11+
*
12+
* This program is distributed in the hope that it will be useful,
13+
* but WITHOUT ANY WARRANTY; without even the implied warranty of
14+
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
15+
* GNU General Public License for more details.
16+
*
17+
* You should have received a copy of the GNU General Public License
18+
* along with this program. If not, see <http://www.gnu.org/licenses/>.
19+
*/
20+
package org.neo4j.gds.embeddings.hashgnn;
21+
22+
import org.immutables.value.Value;
23+
import org.neo4j.gds.annotation.Configuration;
24+
25+
import java.util.Map;
26+
27+
@Configuration
28+
public interface BinarizeFeaturesConfig {
29+
@Configuration.IntegerRange(min = 1)
30+
int dimension();
31+
32+
default double threshold() {
33+
return 0.0;
34+
}
35+
36+
@Configuration.ToMap
37+
@Value.Auxiliary
38+
@Value.Derived
39+
default Map<String, Object> toMap() {
40+
return Map.of(); // Will be overwritten
41+
}
42+
}

algo/src/main/java/org/neo4j/gds/embeddings/hashgnn/BinarizeTask.java

Lines changed: 62 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -30,32 +30,39 @@
3030
import org.neo4j.gds.ml.core.features.FeatureConsumer;
3131
import org.neo4j.gds.ml.core.features.FeatureExtraction;
3232
import org.neo4j.gds.ml.core.features.FeatureExtractor;
33-
import org.neo4j.gds.ml.util.ShuffleUtil;
3433

35-
import java.util.Arrays;
3634
import java.util.List;
3735
import java.util.SplittableRandom;
3836
import java.util.stream.Collectors;
3937

38+
import static org.neo4j.gds.utils.StringFormatting.formatWithLocale;
39+
4040
class BinarizeTask implements Runnable {
4141
private final Partition partition;
4242
private final HugeObjectArray<HugeAtomicBitSet> truncatedFeatures;
4343
private final List<FeatureExtractor> featureExtractors;
44-
private final int[][] propertyEmbeddings;
45-
private final FeatureBinarizationConfig binarizationConfig;
44+
private final double[][] propertyEmbeddings;
45+
46+
private final double threshold;
47+
private final BinarizeFeaturesConfig binarizationConfig;
4648
private final ProgressTracker progressTracker;
4749
private long totalNumFeatures;
4850

51+
private double scalarProductSum;
52+
53+
private double scalarProductSumOfSquares;
54+
4955
BinarizeTask(
5056
Partition partition,
5157
HashGNNConfig config,
5258
HugeObjectArray<HugeAtomicBitSet> truncatedFeatures,
5359
List<FeatureExtractor> featureExtractors,
54-
int[][] propertyEmbeddings,
60+
double[][] propertyEmbeddings,
5561
ProgressTracker progressTracker
5662
) {
5763
this.partition = partition;
5864
this.binarizationConfig = config.binarizeFeatures().orElseThrow();
65+
this.threshold = binarizationConfig.threshold();
5966
this.truncatedFeatures = truncatedFeatures;
6067
this.featureExtractors = featureExtractors;
6168
this.propertyEmbeddings = propertyEmbeddings;
@@ -73,6 +80,8 @@ static HugeObjectArray<HugeAtomicBitSet> compute(
7380
) {
7481
progressTracker.beginSubTask("Binarize node property features");
7582

83+
var binarizationConfig = config.binarizeFeatures().orElseThrow();
84+
7685
var featureExtractors = FeatureExtraction.propertyExtractors(
7786
graph,
7887
config.featureProperties()
@@ -101,28 +110,33 @@ static HugeObjectArray<HugeAtomicBitSet> compute(
101110

102111
totalNumFeaturesOutput.add(tasks.stream().mapToLong(BinarizeTask::totalNumFeatures).sum());
103112

113+
var squaredSum = tasks.stream().mapToDouble(BinarizeTask::scalarProductSumOfSquares).sum();
114+
var sum = tasks.stream().mapToDouble(BinarizeTask::scalarProductSum).sum();
115+
long exampleCount = graph.nodeCount() * binarizationConfig.dimension();
116+
var avg = sum / exampleCount;
117+
118+
var variance = (squaredSum - exampleCount * avg * avg) / exampleCount;
119+
var std = Math.sqrt(variance);
120+
121+
progressTracker.logInfo(formatWithLocale("Hyperplane scalar products have mean %.4f and standard deviation %.4f. A threshold for binarization may be set to the average plus a few standard deviations.", avg, std));
122+
104123
progressTracker.endSubTask("Binarize node property features");
105124

106125
return truncatedFeatures;
107126
}
108-
109-
// creates a sparse projection array with one row per input feature
127+
// creates a random projection vector for each feature
110128
// (input features vector for each node is the concatenation of the node's properties)
111-
// the first half of each row contains indices of positive output features in the projected space
112-
// the second half of each row contains indices of negative output features in the projected space
113129
// this array is used embed the properties themselves from inputDimension to embeddingDimension dimensions.
114-
public static int[][] embedProperties(HashGNNConfig config, SplittableRandom rng, int inputDimension) {
130+
public static double[][] embedProperties(HashGNNConfig config, SplittableRandom rng, int inputDimension) {
115131
var binarizationConfig = config.binarizeFeatures().orElseThrow();
116-
var permutation = new int[binarizationConfig.dimension()];
117-
Arrays.setAll(permutation, i -> i);
118-
119-
var propertyEmbeddings = new int[inputDimension][];
132+
var propertyEmbeddings = new double[inputDimension][];
120133

121134
for (int inputFeature = 0; inputFeature < inputDimension; inputFeature++) {
122-
ShuffleUtil.shuffleArray(permutation, rng);
123-
propertyEmbeddings[inputFeature] = new int[2 * binarizationConfig.densityLevel()];
124-
for (int feature = 0; feature < 2 * binarizationConfig.densityLevel(); feature++) {
125-
propertyEmbeddings[inputFeature][feature] = permutation[feature];
135+
propertyEmbeddings[inputFeature] = new double[binarizationConfig.dimension()];
136+
for (int feature = 0; feature < binarizationConfig.dimension(); feature++) {
137+
// Box-muller transformation to generate gaussian
138+
double matrixValue = Math.sqrt(-2*Math.log(rng.nextDouble(0.0, 1.0))) * Math.cos(2*Math.PI * rng.nextDouble(0.0, 1.0));
139+
propertyEmbeddings[inputFeature][feature] = matrixValue;
126140
}
127141
}
128142
return propertyEmbeddings;
@@ -135,48 +149,54 @@ public void run() {
135149
FeatureExtraction.extract(nodeId, -1, featureExtractors, new FeatureConsumer() {
136150
@Override
137151
public void acceptScalar(long nodeOffset, int offset, double value) {
138-
for (int feature = 0; feature < binarizationConfig.densityLevel(); feature++) {
139-
int positiveFeature = propertyEmbeddings[offset][feature];
140-
featureVector[positiveFeature] += value;
141-
}
142-
143-
for (int feature = binarizationConfig.densityLevel(); feature < 2 * binarizationConfig.densityLevel(); feature++) {
144-
int negativeFeature = propertyEmbeddings[offset][feature];
145-
featureVector[negativeFeature] -= value;
146-
152+
for (int feature = 0; feature < binarizationConfig.dimension(); feature++) {
153+
double featureValue = propertyEmbeddings[offset][feature];
154+
featureVector[feature] += value * featureValue;
147155
}
148156
}
149157

150158
@Override
151159
public void acceptArray(long nodeOffset, int offset, double[] values) {
152160
for (int inputFeatureOffset = 0; inputFeatureOffset < values.length; inputFeatureOffset++) {
153-
for (int feature = 0; feature < binarizationConfig.densityLevel(); feature++) {
154-
int positiveFeature = propertyEmbeddings[offset + inputFeatureOffset][feature];
155-
featureVector[positiveFeature] += values[inputFeatureOffset];
156-
}
157-
for (int feature = binarizationConfig.densityLevel(); feature < 2 * binarizationConfig.densityLevel(); feature++) {
158-
int negativeFeature = propertyEmbeddings[offset + inputFeatureOffset][feature];
159-
featureVector[negativeFeature] -= values[inputFeatureOffset];
161+
double value = values[inputFeatureOffset];
162+
for (int feature = 0; feature < binarizationConfig.dimension(); feature++) {
163+
double featureValue = propertyEmbeddings[offset + inputFeatureOffset][feature];
164+
featureVector[feature] += value * featureValue;
160165
}
161166
}
162167
}
163168
});
164169

165-
var bitSet = HugeAtomicBitSet.create(binarizationConfig.dimension());
166-
for (int feature = 0; feature < featureVector.length; feature++) {
167-
if (featureVector[feature] > 0) {
168-
bitSet.set(feature);
169-
}
170-
}
171-
totalNumFeatures += bitSet.cardinality();
172-
truncatedFeatures.set(nodeId, bitSet);
170+
var featureSet = round(featureVector);
171+
totalNumFeatures += featureSet.cardinality();
172+
truncatedFeatures.set(nodeId, featureSet);
173173
});
174174

175175
progressTracker.logProgress(partition.nodeCount());
176176
}
177177

178+
private HugeAtomicBitSet round(float[] floatVector) {
179+
var bitset = HugeAtomicBitSet.create(floatVector.length);
180+
for (int feature = 0; feature < floatVector.length; feature++) {
181+
var scalarProduct = floatVector[feature];
182+
scalarProductSum += scalarProduct;
183+
scalarProductSumOfSquares += scalarProduct * scalarProduct;
184+
if (scalarProduct > threshold) {
185+
bitset.set(feature);
186+
}
187+
}
188+
return bitset;
189+
}
190+
178191
public long totalNumFeatures() {
179192
return totalNumFeatures;
180193
}
181194

195+
public double scalarProductSum() {
196+
return scalarProductSum;
197+
}
198+
public double scalarProductSumOfSquares() {
199+
return scalarProductSumOfSquares;
200+
}
201+
182202
}

algo/src/main/java/org/neo4j/gds/embeddings/hashgnn/FeatureBinarizationConfig.java renamed to algo/src/main/java/org/neo4j/gds/embeddings/hashgnn/GenerateFeaturesConfig.java

Lines changed: 1 addition & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -24,22 +24,13 @@
2424

2525
import java.util.Map;
2626

27-
import static org.neo4j.gds.utils.StringFormatting.formatWithLocale;
28-
2927
@Configuration
30-
public interface FeatureBinarizationConfig {
28+
public interface GenerateFeaturesConfig {
3129
@Configuration.IntegerRange(min = 1)
3230
int dimension();
3331
@Configuration.IntegerRange(min = 1)
3432
int densityLevel();
3533

36-
@Value.Check
37-
default void validate() {
38-
if (2 * densityLevel() > dimension()) {
39-
throw new IllegalArgumentException(formatWithLocale("The value %d of `densityLevel` may not exceed half of the value %d of `dimension`.", densityLevel(), dimension()));
40-
}
41-
}
42-
4334
@Configuration.ToMap
4435
@Value.Auxiliary
4536
@Value.Derived

algo/src/main/java/org/neo4j/gds/embeddings/hashgnn/GenerateFeaturesTask.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -36,14 +36,14 @@ class GenerateFeaturesTask implements Runnable {
3636
private final Partition partition;
3737
private final HugeObjectArray<HugeAtomicBitSet> output;
3838
private final SplittableRandom rng;
39-
private final FeatureBinarizationConfig generateFeaturesConfig;
39+
private final GenerateFeaturesConfig generateFeaturesConfig;
4040
private final ProgressTracker progressTracker;
4141
private long totalNumFeatures = 0;
4242

4343
GenerateFeaturesTask(
4444
Partition partition,
4545
SplittableRandom rng,
46-
FeatureBinarizationConfig config,
46+
GenerateFeaturesConfig config,
4747
HugeObjectArray<HugeAtomicBitSet> output,
4848
ProgressTracker progressTracker
4949
) {

algo/src/main/java/org/neo4j/gds/embeddings/hashgnn/HashGNN.java

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,8 @@
4040
import java.util.function.Function;
4141
import java.util.stream.Collectors;
4242

43+
import static org.neo4j.gds.utils.StringFormatting.formatWithLocale;
44+
4345
/**
4446
* Based on the paper "Hashing-Accelerated Graph Neural Networks for Link Prediction"
4547
*/
@@ -89,6 +91,9 @@ public HashGNNResult compute() {
8991
var embeddingsB = constructInputEmbeddings(rangePartition);
9092
int embeddingDimension = (int) embeddingsB.get(0).size();
9193

94+
double avgInputActiveFeatures = totalSetBits.doubleValue() / graph.nodeCount();
95+
progressTracker.logInfo(formatWithLocale("Density (number of active features) of binary input features is %.4f.", avgInputActiveFeatures));
96+
9297
var embeddingsA = HugeObjectArray.newArray(HugeAtomicBitSet.class, graph.nodeCount());
9398
embeddingsA.setAll(unused -> HugeAtomicBitSet.create(embeddingDimension));
9499

@@ -136,6 +141,8 @@ public HashGNNResult compute() {
136141
terminationFlag,
137142
totalSetBits
138143
);
144+
double avgActiveFeatures = totalSetBits.doubleValue() / graph.nodeCount();
145+
progressTracker.logInfo(formatWithLocale("After iteration %d average node embedding density (number of active features) is %.4f.", iteration, avgActiveFeatures));
139146
}
140147

141148
progressTracker.endSubTask("Propagate embeddings");

algo/src/main/java/org/neo4j/gds/embeddings/hashgnn/HashGNNConfig.java

Lines changed: 31 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
*/
2020
package org.neo4j.gds.embeddings.hashgnn;
2121

22+
import org.immutables.value.Value;
2223
import org.neo4j.gds.annotation.Configuration;
2324
import org.neo4j.gds.config.AlgoBaseConfig;
2425
import org.neo4j.gds.config.FeaturePropertiesConfig;
@@ -49,26 +50,48 @@ default boolean heterogeneous() {
4950
return false;
5051
}
5152

52-
@Configuration.ToMapValue("org.neo4j.gds.embeddings.hashgnn.HashGNNConfig#toMapBinarizationConfig")
53-
@Configuration.ConvertWith(method = "org.neo4j.gds.embeddings.hashgnn.HashGNNConfig#parseBinarizationConfig", inverse = Configuration.ConvertWith.INVERSE_IS_TO_MAP)
54-
default Optional<FeatureBinarizationConfig> generateFeatures() {
53+
@Configuration.ToMapValue("org.neo4j.gds.embeddings.hashgnn.HashGNNConfig#toMapGenerateFeaturesConfig")
54+
@Configuration.ConvertWith(method = "org.neo4j.gds.embeddings.hashgnn.HashGNNConfig#parseGenerateFeaturesConfig", inverse = Configuration.ConvertWith.INVERSE_IS_TO_MAP)
55+
default Optional<GenerateFeaturesConfig> generateFeatures() {
5556
return Optional.empty();
5657
}
5758

5859
@Configuration.ToMapValue("org.neo4j.gds.embeddings.hashgnn.HashGNNConfig#toMapBinarizationConfig")
5960
@Configuration.ConvertWith(method = "org.neo4j.gds.embeddings.hashgnn.HashGNNConfig#parseBinarizationConfig", inverse = Configuration.ConvertWith.INVERSE_IS_TO_MAP)
60-
default Optional<FeatureBinarizationConfig> binarizeFeatures() {
61+
default Optional<BinarizeFeaturesConfig> binarizeFeatures() {
6162
return Optional.empty();
6263
}
6364

64-
static Optional<FeatureBinarizationConfig> parseBinarizationConfig(Object o) {
65+
@Value.Check
66+
default void validate() {
67+
if (!featureProperties().isEmpty() && generateFeatures().isPresent()) {
68+
throw new IllegalArgumentException("It is not allowed to use `generateFeatures` and have non-empty `featureProperties`.");
69+
}
70+
if (generateFeatures().isPresent()) return;
71+
if (featureProperties().isEmpty()) {
72+
throw new IllegalArgumentException("When `generateFeatures` is not given, `featureProperties` must be non-empty.");
73+
}
74+
}
75+
76+
static Optional<BinarizeFeaturesConfig> parseBinarizationConfig(Object o) {
77+
if (o instanceof Optional) {
78+
return (Optional<BinarizeFeaturesConfig>) o;
79+
}
80+
return Optional.of(new BinarizeFeaturesConfigImpl(CypherMapWrapper.create((Map<String, Object>) o)));
81+
}
82+
83+
static Map<String, Object> toMapBinarizationConfig(BinarizeFeaturesConfig config) {
84+
return config.toMap();
85+
}
86+
87+
static Optional<GenerateFeaturesConfig> parseGenerateFeaturesConfig(Object o) {
6588
if (o instanceof Optional) {
66-
return (Optional<FeatureBinarizationConfig>) o;
89+
return (Optional<GenerateFeaturesConfig>) o;
6790
}
68-
return Optional.of(new FeatureBinarizationConfigImpl(CypherMapWrapper.create((Map<String, Object>) o)));
91+
return Optional.of(new GenerateFeaturesConfigImpl(CypherMapWrapper.create((Map<String, Object>) o)));
6992
}
7093

71-
static Map<String, Object> toMapBinarizationConfig(FeatureBinarizationConfig config) {
94+
static Map<String, Object> toMapGenerateFeaturesConfig(GenerateFeaturesConfig config) {
7295
return config.toMap();
7396
}
7497
}

0 commit comments

Comments
 (0)