Skip to content

Commit 06f891c

Browse files
authored
Merge pull request #6520 from adamnsch/no-prop-hashgnn
Update handling of features in HashGNN
2 parents a540843 + 953a262 commit 06f891c

File tree

24 files changed

+1722
-1002
lines changed

24 files changed

+1722
-1002
lines changed
Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
/*
 * Copyright (c) "Neo4j"
 * Neo4j Sweden AB [http://neo4j.com]
 *
 * This file is part of Neo4j.
 *
 * Neo4j is free software: you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with this program. If not, see <http://www.gnu.org/licenses/>.
 */
package org.neo4j.gds.embeddings.hashgnn;

import org.immutables.value.Value;
import org.neo4j.gds.annotation.Configuration;

import java.util.Map;

/**
 * Configuration for binarizing real-valued node property features in HashGNN.
 * The properties are projected with random vectors and each resulting scalar
 * product is thresholded into a single bit (see {@code BinarizeTask}).
 */
@Configuration
public interface BinarizeFeaturesConfig {

    /** Dimension of the binarized feature vectors. Must be at least 1. */
    @Configuration.IntegerRange(min = 1)
    int dimension();

    /** Scalar products strictly above this value produce a set bit. Defaults to 0.0. */
    default double threshold() {
        return 0.0;
    }

    @Configuration.ToMap
    @Value.Auxiliary
    @Value.Derived
    default Map<String, Object> toMap() {
        return Map.of(); // Replaced by the annotation-processor-generated implementation
    }
}

algo/src/main/java/org/neo4j/gds/embeddings/hashgnn/BinarizeTask.java

Lines changed: 80 additions & 73 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
*/
2020
package org.neo4j.gds.embeddings.hashgnn;
2121

22-
import com.carrotsearch.hppc.BitSet;
22+
import org.apache.commons.lang3.mutable.MutableLong;
2323
import org.neo4j.gds.api.Graph;
2424
import org.neo4j.gds.core.concurrency.RunWithConcurrency;
2525
import org.neo4j.gds.core.utils.TerminationFlag;
@@ -30,44 +30,42 @@
3030
import org.neo4j.gds.ml.core.features.FeatureConsumer;
3131
import org.neo4j.gds.ml.core.features.FeatureExtraction;
3232
import org.neo4j.gds.ml.core.features.FeatureExtractor;
33-
import org.neo4j.gds.ml.util.ShuffleUtil;
3433

35-
import java.util.ArrayList;
36-
import java.util.Arrays;
3734
import java.util.List;
3835
import java.util.SplittableRandom;
3936
import java.util.stream.Collectors;
4037

41-
import static org.neo4j.gds.embeddings.hashgnn.HashGNNCompanion.hashArgMin;
38+
import static org.neo4j.gds.utils.StringFormatting.formatWithLocale;
4239

4340
class BinarizeTask implements Runnable {
4441
private final Partition partition;
45-
private final HashGNNConfig config;
4642
private final HugeObjectArray<HugeAtomicBitSet> truncatedFeatures;
4743
private final List<FeatureExtractor> featureExtractors;
48-
private final int[][] propertyEmbeddings;
49-
private final List<int[]> hashesList;
50-
private final HashGNN.MinAndArgmin minAndArgMin;
51-
private final FeatureBinarizationConfig binarizationConfig;
44+
private final double[][] propertyEmbeddings;
45+
46+
private final double threshold;
47+
private final int dimension;
5248
private final ProgressTracker progressTracker;
49+
private long totalFeatureCount;
50+
51+
private double scalarProductSum;
52+
53+
private double scalarProductSumOfSquares;
5354

5455
BinarizeTask(
5556
Partition partition,
56-
HashGNNConfig config,
57+
BinarizeFeaturesConfig config,
5758
HugeObjectArray<HugeAtomicBitSet> truncatedFeatures,
5859
List<FeatureExtractor> featureExtractors,
59-
int[][] propertyEmbeddings,
60-
List<int[]> hashesList,
60+
double[][] propertyEmbeddings,
6161
ProgressTracker progressTracker
6262
) {
6363
this.partition = partition;
64-
this.config = config;
65-
this.binarizationConfig = config.binarizeFeatures().orElseThrow();
64+
this.dimension = config.dimension();
65+
this.threshold = config.threshold();
6666
this.truncatedFeatures = truncatedFeatures;
6767
this.featureExtractors = featureExtractors;
6868
this.propertyEmbeddings = propertyEmbeddings;
69-
this.hashesList = hashesList;
70-
this.minAndArgMin = new HashGNN.MinAndArgmin();
7169
this.progressTracker = progressTracker;
7270
}
7371

@@ -77,36 +75,30 @@ static HugeObjectArray<HugeAtomicBitSet> compute(
7775
HashGNNConfig config,
7876
SplittableRandom rng,
7977
ProgressTracker progressTracker,
80-
TerminationFlag terminationFlag
78+
TerminationFlag terminationFlag,
79+
MutableLong totalFeatureCountOutput
8180
) {
8281
progressTracker.beginSubTask("Binarize node property features");
8382

84-
var hashesList = new ArrayList<int[]>(config.embeddingDensity());
85-
for (int i = 0; i < config.embeddingDensity(); i++) {
86-
hashesList.add(HashGNNCompanion.HashTriple.computeHashesFromTriple(
87-
config.binarizeFeatures().get().dimension(),
88-
HashGNNCompanion.HashTriple.generate(rng)
89-
));
90-
}
83+
var binarizationConfig = config.binarizeFeatures().orElseThrow();
9184

9285
var featureExtractors = FeatureExtraction.propertyExtractors(
9386
graph,
9487
config.featureProperties()
9588
);
9689

9790
var inputDimension = FeatureExtraction.featureCount(featureExtractors);
98-
var propertyEmbeddings = embedProperties(config, rng, inputDimension);
91+
var propertyEmbeddings = embedProperties(binarizationConfig.dimension(), rng, inputDimension);
9992

10093
var truncatedFeatures = HugeObjectArray.newArray(HugeAtomicBitSet.class, graph.nodeCount());
10194

10295
var tasks = partition.stream()
10396
.map(p -> new BinarizeTask(
10497
p,
105-
config,
98+
binarizationConfig,
10699
truncatedFeatures,
107100
featureExtractors,
108101
propertyEmbeddings,
109-
hashesList,
110102
progressTracker
111103
))
112104
.collect(Collectors.toList());
@@ -116,90 +108,105 @@ static HugeObjectArray<HugeAtomicBitSet> compute(
116108
.terminationFlag(terminationFlag)
117109
.run();
118110

111+
totalFeatureCountOutput.add(tasks.stream().mapToLong(BinarizeTask::totalFeatureCount).sum());
112+
113+
var squaredSum = tasks.stream().mapToDouble(BinarizeTask::scalarProductSumOfSquares).sum();
114+
var sum = tasks.stream().mapToDouble(BinarizeTask::scalarProductSum).sum();
115+
long exampleCount = graph.nodeCount() * binarizationConfig.dimension();
116+
var avg = sum / exampleCount;
117+
118+
var variance = (squaredSum - exampleCount * avg * avg) / exampleCount;
119+
var std = Math.sqrt(variance);
120+
121+
progressTracker.logInfo(formatWithLocale(
122+
"Hyperplane scalar products have mean %.4f and standard deviation %.4f. A threshold for binarization may be set to the mean plus a few standard deviations.",
123+
avg,
124+
std
125+
));
126+
119127
progressTracker.endSubTask("Binarize node property features");
120128

121129
return truncatedFeatures;
122130
}
123131

124-
// creates a sparse projection array with one row per input feature
132+
// creates a random projection vector for each feature
125133
// (input features vector for each node is the concatenation of the node's properties)
126-
// the first half of each row contains indices of positive output features in the projected space
127-
// the second half of each row contains indices of negative output features in the projected space
128134
// this array is used embed the properties themselves from inputDimension to embeddingDimension dimensions.
129-
public static int[][] embedProperties(HashGNNConfig config, SplittableRandom rng, int inputDimension) {
130-
var binarizationConfig = config.binarizeFeatures().orElseThrow();
131-
var permutation = new int[binarizationConfig.dimension()];
132-
Arrays.setAll(permutation, i -> i);
133-
134-
var propertyEmbeddings = new int[inputDimension][];
135+
public static double[][] embedProperties(int vectorDimension, SplittableRandom rng, int inputDimension) {
136+
var propertyEmbeddings = new double[inputDimension][];
135137

136138
for (int inputFeature = 0; inputFeature < inputDimension; inputFeature++) {
137-
ShuffleUtil.shuffleArray(permutation, rng);
138-
propertyEmbeddings[inputFeature] = new int[2 * binarizationConfig.densityLevel()];
139-
for (int feature = 0; feature < 2 * binarizationConfig.densityLevel(); feature++) {
140-
propertyEmbeddings[inputFeature][feature] = permutation[feature];
139+
propertyEmbeddings[inputFeature] = new double[vectorDimension];
140+
for (int feature = 0; feature < vectorDimension; feature++) {
141+
propertyEmbeddings[inputFeature][feature] = boxMullerGaussianRandom(rng);
141142
}
142143
}
143144
return propertyEmbeddings;
144145
}
145146

147+
private static double boxMullerGaussianRandom(SplittableRandom rng) {
148+
return Math.sqrt(-2 * Math.log(rng.nextDouble(
149+
0.0,
150+
1.0
151+
))) * Math.cos(2 * Math.PI * rng.nextDouble(0.0, 1.0));
152+
}
153+
146154
@Override
147155
public void run() {
148-
var tempFeatureContainer = new BitSet(binarizationConfig.dimension());
149-
150156
partition.consume(nodeId -> {
151-
var featureVector = new float[binarizationConfig.dimension()];
157+
var featureVector = new float[dimension];
152158
FeatureExtraction.extract(nodeId, -1, featureExtractors, new FeatureConsumer() {
153159
@Override
154160
public void acceptScalar(long nodeOffset, int offset, double value) {
155-
for (int feature = 0; feature < binarizationConfig.densityLevel(); feature++) {
156-
int positiveFeature = propertyEmbeddings[offset][feature];
157-
featureVector[positiveFeature] += value;
158-
}
159-
160-
for (int feature = binarizationConfig.densityLevel(); feature < 2 * binarizationConfig.densityLevel(); feature++) {
161-
int negativeFeature = propertyEmbeddings[offset][feature];
162-
featureVector[negativeFeature] -= value;
163-
161+
for (int feature = 0; feature < dimension; feature++) {
162+
double featureValue = propertyEmbeddings[offset][feature];
163+
featureVector[feature] += value * featureValue;
164164
}
165165
}
166166

167167
@Override
168168
public void acceptArray(long nodeOffset, int offset, double[] values) {
169169
for (int inputFeatureOffset = 0; inputFeatureOffset < values.length; inputFeatureOffset++) {
170-
for (int feature = 0; feature < binarizationConfig.densityLevel(); feature++) {
171-
int positiveFeature = propertyEmbeddings[offset + inputFeatureOffset][feature];
172-
featureVector[positiveFeature] += values[inputFeatureOffset];
173-
}
174-
for (int feature = binarizationConfig.densityLevel(); feature < 2 * binarizationConfig.densityLevel(); feature++) {
175-
int negativeFeature = propertyEmbeddings[offset + inputFeatureOffset][feature];
176-
featureVector[negativeFeature] -= values[inputFeatureOffset];
170+
double value = values[inputFeatureOffset];
171+
for (int feature = 0; feature < dimension; feature++) {
172+
double featureValue = propertyEmbeddings[offset + inputFeatureOffset][feature];
173+
featureVector[feature] += value * featureValue;
177174
}
178175
}
179176
}
180177
});
181178

182-
truncatedFeatures.set(nodeId, roundAndSample(tempFeatureContainer, featureVector));
179+
var featureSet = round(featureVector);
180+
totalFeatureCount += featureSet.cardinality();
181+
truncatedFeatures.set(nodeId, featureSet);
183182
});
184183

185184
progressTracker.logProgress(partition.nodeCount());
186185
}
187186

188-
private HugeAtomicBitSet roundAndSample(BitSet tempBitSet, float[] floatVector) {
189-
tempBitSet.clear();
187+
private HugeAtomicBitSet round(float[] floatVector) {
188+
var bitset = HugeAtomicBitSet.create(floatVector.length);
190189
for (int feature = 0; feature < floatVector.length; feature++) {
191-
if (floatVector[feature] > 0) {
192-
tempBitSet.set(feature);
193-
}
194-
}
195-
var sampledBitset = HugeAtomicBitSet.create(binarizationConfig.dimension());
196-
for (int i = 0; i < config.embeddingDensity(); i++) {
197-
hashArgMin(tempBitSet, hashesList.get(i), minAndArgMin);
198-
if (minAndArgMin.argMin != -1) {
199-
sampledBitset.set(minAndArgMin.argMin);
190+
var scalarProduct = floatVector[feature];
191+
scalarProductSum += scalarProduct;
192+
scalarProductSumOfSquares += scalarProduct * scalarProduct;
193+
if (scalarProduct > threshold) {
194+
bitset.set(feature);
200195
}
201196
}
202-
return sampledBitset;
197+
return bitset;
198+
}
199+
200+
public long totalFeatureCount() {
201+
return totalFeatureCount;
202+
}
203+
204+
public double scalarProductSum() {
205+
return scalarProductSum;
206+
}
207+
208+
public double scalarProductSumOfSquares() {
209+
return scalarProductSumOfSquares;
203210
}
204211

205212
}

algo/src/main/java/org/neo4j/gds/embeddings/hashgnn/FeatureBinarizationConfig.java renamed to algo/src/main/java/org/neo4j/gds/embeddings/hashgnn/GenerateFeaturesConfig.java

Lines changed: 1 addition & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -24,22 +24,13 @@
2424

2525
import java.util.Map;
2626

27-
import static org.neo4j.gds.utils.StringFormatting.formatWithLocale;
28-
2927
@Configuration
30-
public interface FeatureBinarizationConfig {
28+
public interface GenerateFeaturesConfig {
3129
@Configuration.IntegerRange(min = 1)
3230
int dimension();
3331
@Configuration.IntegerRange(min = 1)
3432
int densityLevel();
3533

36-
@Value.Check
37-
default void validate() {
38-
if (2 * densityLevel() > dimension()) {
39-
throw new IllegalArgumentException(formatWithLocale("The value %d of `densityLevel` may not exceed half of the value %d of `dimension`.", densityLevel(), dimension()));
40-
}
41-
}
42-
4334
@Configuration.ToMap
4435
@Value.Auxiliary
4536
@Value.Derived

0 commit comments

Comments
 (0)