
Commit 953a262

adamnsch, breakanalysis, and Mats-SX committed

Address review comments

Co-Authored-By: Jacob Sznajdman <breakanalysis@gmail.com>
Co-Authored-By: Mats Rydberg <mats.rydberg@neotechnology.com>
1 parent a85b88b commit 953a262

File tree: 10 files changed (+114 / -146 lines)

algo/src/main/java/org/neo4j/gds/embeddings/hashgnn/BinarizeTask.java

Lines changed: 33 additions & 23 deletions

@@ -44,25 +44,25 @@ class BinarizeTask implements Runnable {
     private final double[][] propertyEmbeddings;
 
     private final double threshold;
-    private final BinarizeFeaturesConfig binarizationConfig;
+    private final int dimension;
     private final ProgressTracker progressTracker;
-    private long totalNumFeatures;
+    private long totalFeatureCount;
 
     private double scalarProductSum;
 
     private double scalarProductSumOfSquares;
 
     BinarizeTask(
         Partition partition,
-        HashGNNConfig config,
+        BinarizeFeaturesConfig config,
         HugeObjectArray<HugeAtomicBitSet> truncatedFeatures,
         List<FeatureExtractor> featureExtractors,
         double[][] propertyEmbeddings,
         ProgressTracker progressTracker
     ) {
         this.partition = partition;
-        this.binarizationConfig = config.binarizeFeatures().orElseThrow();
-        this.threshold = binarizationConfig.threshold();
+        this.dimension = config.dimension();
+        this.threshold = config.threshold();
         this.truncatedFeatures = truncatedFeatures;
         this.featureExtractors = featureExtractors;
         this.propertyEmbeddings = propertyEmbeddings;
@@ -76,7 +76,7 @@ static HugeObjectArray<HugeAtomicBitSet> compute(
         SplittableRandom rng,
         ProgressTracker progressTracker,
         TerminationFlag terminationFlag,
-        MutableLong totalNumFeaturesOutput
+        MutableLong totalFeatureCountOutput
     ) {
         progressTracker.beginSubTask("Binarize node property features");
 
@@ -88,14 +88,14 @@ static HugeObjectArray<HugeAtomicBitSet> compute(
         );
 
         var inputDimension = FeatureExtraction.featureCount(featureExtractors);
-        var propertyEmbeddings = embedProperties(config, rng, inputDimension);
+        var propertyEmbeddings = embedProperties(binarizationConfig.dimension(), rng, inputDimension);
 
         var truncatedFeatures = HugeObjectArray.newArray(HugeAtomicBitSet.class, graph.nodeCount());
 
         var tasks = partition.stream()
             .map(p -> new BinarizeTask(
                 p,
-                config,
+                binarizationConfig,
                 truncatedFeatures,
                 featureExtractors,
                 propertyEmbeddings,
@@ -108,7 +108,7 @@ static HugeObjectArray<HugeAtomicBitSet> compute(
             .terminationFlag(terminationFlag)
             .run();
 
-        totalNumFeaturesOutput.add(tasks.stream().mapToLong(BinarizeTask::totalNumFeatures).sum());
+        totalFeatureCountOutput.add(tasks.stream().mapToLong(BinarizeTask::totalFeatureCount).sum());
 
         var squaredSum = tasks.stream().mapToDouble(BinarizeTask::scalarProductSumOfSquares).sum();
         var sum = tasks.stream().mapToDouble(BinarizeTask::scalarProductSum).sum();
@@ -118,38 +118,47 @@ static HugeObjectArray<HugeAtomicBitSet> compute(
         var variance = (squaredSum - exampleCount * avg * avg) / exampleCount;
         var std = Math.sqrt(variance);
 
-        progressTracker.logInfo(formatWithLocale("Hyperplane scalar products have mean %.4f and standard deviation %.4f. A threshold for binarization may be set to the average plus a few standard deviations.", avg, std));
+        progressTracker.logInfo(formatWithLocale(
+            "Hyperplane scalar products have mean %.4f and standard deviation %.4f. A threshold for binarization may be set to the mean plus a few standard deviations.",
+            avg,
+            std
+        ));
 
         progressTracker.endSubTask("Binarize node property features");
 
         return truncatedFeatures;
     }
+
     // creates a random projection vector for each feature
     // (input features vector for each node is the concatenation of the node's properties)
    // this array is used embed the properties themselves from inputDimension to embeddingDimension dimensions.
-    public static double[][] embedProperties(HashGNNConfig config, SplittableRandom rng, int inputDimension) {
-        var binarizationConfig = config.binarizeFeatures().orElseThrow();
+    public static double[][] embedProperties(int vectorDimension, SplittableRandom rng, int inputDimension) {
         var propertyEmbeddings = new double[inputDimension][];
 
         for (int inputFeature = 0; inputFeature < inputDimension; inputFeature++) {
-            propertyEmbeddings[inputFeature] = new double[binarizationConfig.dimension()];
-            for (int feature = 0; feature < binarizationConfig.dimension(); feature++) {
-                // Box-muller transformation to generate gaussian
-                double matrixValue = Math.sqrt(-2*Math.log(rng.nextDouble(0.0, 1.0))) * Math.cos(2*Math.PI * rng.nextDouble(0.0, 1.0));
-                propertyEmbeddings[inputFeature][feature] = matrixValue;
+            propertyEmbeddings[inputFeature] = new double[vectorDimension];
+            for (int feature = 0; feature < vectorDimension; feature++) {
+                propertyEmbeddings[inputFeature][feature] = boxMullerGaussianRandom(rng);
             }
         }
         return propertyEmbeddings;
     }
 
+    private static double boxMullerGaussianRandom(SplittableRandom rng) {
+        return Math.sqrt(-2 * Math.log(rng.nextDouble(
+            0.0,
+            1.0
+        ))) * Math.cos(2 * Math.PI * rng.nextDouble(0.0, 1.0));
+    }
+
     @Override
     public void run() {
         partition.consume(nodeId -> {
-            var featureVector = new float[binarizationConfig.dimension()];
+            var featureVector = new float[dimension];
             FeatureExtraction.extract(nodeId, -1, featureExtractors, new FeatureConsumer() {
                 @Override
                 public void acceptScalar(long nodeOffset, int offset, double value) {
-                    for (int feature = 0; feature < binarizationConfig.dimension(); feature++) {
+                    for (int feature = 0; feature < dimension; feature++) {
                         double featureValue = propertyEmbeddings[offset][feature];
                         featureVector[feature] += value * featureValue;
                     }
@@ -159,7 +168,7 @@ public void acceptScalar(long nodeOffset, int offset, double value) {
                 public void acceptArray(long nodeOffset, int offset, double[] values) {
                     for (int inputFeatureOffset = 0; inputFeatureOffset < values.length; inputFeatureOffset++) {
                         double value = values[inputFeatureOffset];
-                        for (int feature = 0; feature < binarizationConfig.dimension(); feature++) {
+                        for (int feature = 0; feature < dimension; feature++) {
                             double featureValue = propertyEmbeddings[offset + inputFeatureOffset][feature];
                             featureVector[feature] += value * featureValue;
                         }
@@ -168,7 +177,7 @@ public void acceptArray(long nodeOffset, int offset, double[] values) {
             });
 
             var featureSet = round(featureVector);
-            totalNumFeatures += featureSet.cardinality();
+            totalFeatureCount += featureSet.cardinality();
             truncatedFeatures.set(nodeId, featureSet);
         });
 
@@ -188,13 +197,14 @@ private HugeAtomicBitSet round(float[] floatVector) {
         return bitset;
     }
 
-    public long totalNumFeatures() {
-        return totalNumFeatures;
+    public long totalFeatureCount() {
+        return totalFeatureCount;
     }
 
     public double scalarProductSum() {
         return scalarProductSum;
     }
+
     public double scalarProductSumOfSquares() {
         return scalarProductSumOfSquares;
     }
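Note on the extracted boxMullerGaussianRandom helper: it is the standard Box-Muller transform, mapping two independent uniform samples u1, u2 from [0, 1) to one standard-normal sample via sqrt(-2 ln u1) * cos(2 pi u2). The standalone sketch below is illustrative only; the class name, seed, and row length are made up, and only java.util.SplittableRandom is assumed. It shows the same sampling step that embedProperties uses to fill one row of the random-projection matrix.

import java.util.Arrays;
import java.util.SplittableRandom;

// Illustrative sketch only: mirrors the Box-Muller step factored into boxMullerGaussianRandom.
public final class BoxMullerSketch {

    // Maps two uniform samples from [0, 1) to one sample from a standard normal distribution.
    static double gaussian(SplittableRandom rng) {
        double u1 = rng.nextDouble(0.0, 1.0);
        double u2 = rng.nextDouble(0.0, 1.0);
        return Math.sqrt(-2 * Math.log(u1)) * Math.cos(2 * Math.PI * u2);
    }

    public static void main(String[] args) {
        var rng = new SplittableRandom(42); // hypothetical seed
        // Fill one hyperplane row the way embedProperties fills propertyEmbeddings[inputFeature].
        double[] row = new double[8];
        for (int feature = 0; feature < row.length; feature++) {
            row[feature] = gaussian(rng);
        }
        System.out.println(Arrays.toString(row));
    }
}

Factoring the transform into a named helper keeps the projection loop in embedProperties focused on filling the matrix.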

algo/src/main/java/org/neo4j/gds/embeddings/hashgnn/GenerateFeaturesTask.java

Lines changed: 7 additions & 9 deletions

@@ -40,7 +40,7 @@ class GenerateFeaturesTask implements Runnable {
     private final GenerateFeaturesConfig generateFeaturesConfig;
     private final ProgressTracker progressTracker;
     private final long randomSeed;
-    private long totalNumFeatures = 0;
+    private long totalFeatureCount = 0;
 
     GenerateFeaturesTask(
         Partition partition,
@@ -66,7 +66,7 @@ static HugeObjectArray<HugeAtomicBitSet> compute(
         long randomSeed,
         ProgressTracker progressTracker,
         TerminationFlag terminationFlag,
-        MutableLong totalNumFeaturesOutput
+        MutableLong totalFeatureCountOutput
     ) {
         progressTracker.beginSubTask("Generate base node property features");
 
@@ -77,7 +77,7 @@ static HugeObjectArray<HugeAtomicBitSet> compute(
                 p,
                 graph,
                 randomSeed,
-                config.generateFeatures().get(),
+                config.generateFeatures().orElseThrow(),
                 output,
                 progressTracker
             ))
@@ -88,7 +88,7 @@ static HugeObjectArray<HugeAtomicBitSet> compute(
             .terminationFlag(terminationFlag)
             .run();
 
-        totalNumFeaturesOutput.add(tasks.stream().mapToLong(GenerateFeaturesTask::totalNumFeatures).sum());
+        totalFeatureCountOutput.add(tasks.stream().mapToLong(GenerateFeaturesTask::totalFeatureCount).sum());
 
         progressTracker.endSubTask("Generate base node property features");
 
@@ -104,21 +104,19 @@ public void run() {
             var generatedFeatures = HugeAtomicBitSet.create(dimension);
 
             rng.setSeed(this.randomSeed ^ graph.toOriginalNodeId(nodeId));
-            // without this we get same results for different result, at least on example in doc test
-            rng.setSeed(rng.nextLong());
 
             var randomInts = rng.ints(densityLevel, 0, dimension);
             randomInts.forEach(generatedFeatures::set);
 
-            totalNumFeatures += generatedFeatures.cardinality();
+            totalFeatureCount += generatedFeatures.cardinality();
 
             output.set(nodeId, generatedFeatures);
         });
 
         progressTracker.logProgress(partition.nodeCount());
     }
 
-    public long totalNumFeatures() {
-        return totalNumFeatures;
+    public long totalFeatureCount() {
+        return totalFeatureCount;
     }
 }
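Note on the seeding in run(): reseeding with randomSeed ^ graph.toOriginalNodeId(nodeId) makes each node's generated bitset a pure function of the configured seed and the node id, independent of partitioning and thread scheduling, which is why the second setSeed call could be dropped. The following is a minimal sketch of that scheme, not the GDS implementation: the class and method names are made up, and java.util.Random plus java.util.BitSet stand in for the GDS random and bitset types.

import java.util.BitSet;
import java.util.Random;

// Illustrative sketch only: deterministic per-node feature generation via seed XOR node id.
public final class GenerateFeaturesSketch {

    // Sets densityLevel random bit positions (collisions possible) in a bitset of size dimension.
    static BitSet generate(long randomSeed, long originalNodeId, int densityLevel, int dimension) {
        var rng = new Random(randomSeed ^ originalNodeId); // one reseed per node is enough
        var features = new BitSet(dimension);
        rng.ints(densityLevel, 0, dimension).forEach(features::set);
        return features;
    }

    public static void main(String[] args) {
        var first = generate(42L, 7L, 4, 16);
        var second = generate(42L, 7L, 4, 16);
        // Same seed and node id give the same features, no matter when or on which thread they are computed.
        System.out.println(first + " equals " + second + ": " + first.equals(second));
    }
}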

algo/src/main/java/org/neo4j/gds/embeddings/hashgnn/HashGNN.java

Lines changed: 32 additions & 65 deletions

@@ -19,20 +19,16 @@
  */
 package org.neo4j.gds.embeddings.hashgnn;
 
-import org.apache.commons.lang3.mutable.MutableInt;
 import org.apache.commons.lang3.mutable.MutableLong;
 import org.neo4j.gds.Algorithm;
 import org.neo4j.gds.api.Graph;
 import org.neo4j.gds.api.schema.GraphSchema;
-import org.neo4j.gds.core.concurrency.RunWithConcurrency;
 import org.neo4j.gds.core.utils.paged.HugeAtomicBitSet;
 import org.neo4j.gds.core.utils.paged.HugeObjectArray;
 import org.neo4j.gds.core.utils.partition.Partition;
 import org.neo4j.gds.core.utils.partition.PartitionUtils;
 import org.neo4j.gds.core.utils.progress.tasks.ProgressTracker;
-import org.neo4j.gds.ml.core.features.FeatureExtraction;
 
-import java.util.ArrayList;
 import java.util.List;
 import java.util.Optional;
 import java.util.Set;
@@ -51,7 +47,7 @@ public class HashGNN extends Algorithm<HashGNN.HashGNNResult> {
     private final Graph graph;
     private final SplittableRandom rng;
     private final HashGNNConfig config;
-    private final MutableLong totalSetBits = new MutableLong();
+    private final MutableLong currentTotalFeatureCount = new MutableLong();
 
     public HashGNN(Graph graph, HashGNNConfig config, ProgressTracker progressTracker) {
         super(progressTracker);
@@ -93,8 +89,11 @@ public HashGNNResult compute() {
         var embeddingsB = constructInputEmbeddings(rangePartition);
         int embeddingDimension = (int) embeddingsB.get(0).size();
 
-        double avgInputActiveFeatures = totalSetBits.doubleValue() / graph.nodeCount();
-        progressTracker.logInfo(formatWithLocale("Density (number of active features) of binary input features is %.4f.", avgInputActiveFeatures));
+        double avgInputActiveFeatures = currentTotalFeatureCount.doubleValue() / graph.nodeCount();
+        progressTracker.logInfo(formatWithLocale(
+            "Density (number of active features) of binary input features is %.4f.",
+            avgInputActiveFeatures
+        ));
 
         var embeddingsA = HugeObjectArray.newArray(HugeAtomicBitSet.class, graph.nodeCount());
         embeddingsA.setAll(unused -> HugeAtomicBitSet.create(embeddingDimension));
@@ -104,7 +103,8 @@ public HashGNNResult compute() {
             ? 1
             : embeddingDimension * (1 - Math.pow(
                 1 - (1.0 / embeddingDimension),
-                avgDegree)
+                avgDegree
+            )
         );
 
         progressTracker.beginSubTask("Propagate embeddings");
@@ -118,8 +118,8 @@ public HashGNNResult compute() {
                 currentEmbeddings.get(i).clear();
             }
 
-            double scaledNeighborInfluence = graph.relationshipCount() == 0 ? 1.0 : (totalSetBits.doubleValue() / graph.nodeCount()) * config.neighborInfluence() / upperBoundNeighborExpectedBits;
-            totalSetBits.setValue(0);
+            double scaledNeighborInfluence = graph.relationshipCount() == 0 ? 1.0 : (currentTotalFeatureCount.doubleValue() / graph.nodeCount()) * config.neighborInfluence() / upperBoundNeighborExpectedBits;
+            currentTotalFeatureCount.setValue(0);
 
             var hashes = HashTask.compute(
                 embeddingDimension,
@@ -141,11 +141,15 @@ public HashGNNResult compute() {
                 hashes,
                 progressTracker,
                 terminationFlag,
-                totalSetBits
+                currentTotalFeatureCount
             );
 
-            double avgActiveFeatures = totalSetBits.doubleValue() / graph.nodeCount();
-            progressTracker.logInfo(formatWithLocale("After iteration %d average node embedding density (number of active features) is %.4f.", iteration, avgActiveFeatures));
+            double avgActiveFeatures = currentTotalFeatureCount.doubleValue() / graph.nodeCount();
+            progressTracker.logInfo(formatWithLocale(
+                "After iteration %d average node embedding density (number of active features) is %.4f.",
+                iteration,
+                avgActiveFeatures
+            ));
         }
 
         progressTracker.endSubTask("Propagate embeddings");
@@ -209,74 +213,37 @@ public HugeObjectArray<double[]> embeddings() {
     }
 
     private HugeObjectArray<HugeAtomicBitSet> constructInputEmbeddings(List<Partition> partition) {
-        List<HugeObjectArray<HugeAtomicBitSet>> inputEmbeddingsList = new ArrayList<>();
-        var embeddingDimension = new MutableInt();
-        var bitOffsets = new ArrayList<Integer>();
-
         if (!config.featureProperties().isEmpty()) {
             if (config.binarizeFeatures().isPresent()) {
-                inputEmbeddingsList.add(BinarizeTask.compute(
+                return BinarizeTask.compute(
                     graph,
                     partition,
                     config,
                     rng,
                     progressTracker,
                     terminationFlag,
-                    totalSetBits
-                ));
-                bitOffsets.add(embeddingDimension.getValue());
-                embeddingDimension.add(config.binarizeFeatures().get().dimension());
+                    currentTotalFeatureCount
+                );
             } else {
-                inputEmbeddingsList.add(RawFeaturesTask.compute(
+                return RawFeaturesTask.compute(
                     config,
                     progressTracker,
                     graph,
                     partition,
                     terminationFlag,
-                    totalSetBits
-                ));
-                var featureExtractors = FeatureExtraction.propertyExtractors(
-                    graph,
-                    config.featureProperties()
+                    currentTotalFeatureCount
                 );
-                bitOffsets.add(embeddingDimension.getValue());
-                embeddingDimension.add(FeatureExtraction.featureCount(featureExtractors));
             }
+        } else {
+            return GenerateFeaturesTask.compute(
+                graph,
+                partition,
+                config,
+                randomSeed,
+                progressTracker,
+                terminationFlag,
+                currentTotalFeatureCount
+            );
         }
-
-        if (!config.generateFeatures().isPresent()) {
-            return inputEmbeddingsList.get(0);
-        }
-
-        inputEmbeddingsList.add(GenerateFeaturesTask.compute(
-            graph,
-            partition,
-            config,
-            randomSeed,
-            progressTracker,
-            terminationFlag,
-            totalSetBits
-        ));
-        bitOffsets.add(embeddingDimension.getValue());
-        embeddingDimension.add(config.generateFeatures().get().dimension());
-
-        var concatInputEmbeddings = HugeObjectArray.newArray(HugeAtomicBitSet.class, graph.nodeCount());
-
-        var concatTasks = partition.stream().map(p -> (Runnable) () -> p.consume(nodeId -> {
-            var concatFeatures = HugeAtomicBitSet.create(embeddingDimension.getValue());
-            for (int i = 0; i < inputEmbeddingsList.size(); i++) {
-                var embedding = inputEmbeddingsList.get(i).get(nodeId);
-                int bitOffset = bitOffsets.get(i);
-                embedding.forEachSetBit(bit -> concatFeatures.set(bitOffset + bit));
-            }
-            concatInputEmbeddings.set(nodeId, concatFeatures);
-        })).collect(Collectors.toList());
-
-        RunWithConcurrency.builder()
-            .concurrency(config.concurrency())
-            .tasks(concatTasks)
-            .run();
-
-        return concatInputEmbeddings;
     }
 }
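Note on the density bookkeeping around the renamed counter: currentTotalFeatureCount accumulates set bits per phase and is divided by the node count to log the average density; that same average feeds scaledNeighborInfluence together with upperBoundNeighborExpectedBits, the expected number of distinct bits a node receives from avgDegree uniform neighbor contributions (a bit stays unset with probability (1 - 1/d)^avgDegree). The standalone sketch below reproduces the two formulas from this diff with entirely made-up numbers; the class name and all values are hypothetical.

// Illustrative sketch only: the two density formulas from HashGNN.compute(), with made-up inputs.
public final class NeighborInfluenceSketch {

    public static void main(String[] args) {
        int embeddingDimension = 256;            // hypothetical
        double avgDegree = 8.0;                  // hypothetical
        long nodeCount = 10_000;                 // hypothetical
        long currentTotalFeatureCount = 120_000; // hypothetical sum of set bits over all nodes
        double neighborInfluence = 1.0;          // hypothetical config value

        // Expected number of distinct bits set by avgDegree uniform single-bit contributions.
        double upperBoundNeighborExpectedBits = embeddingDimension * (1 - Math.pow(
            1 - (1.0 / embeddingDimension),
            avgDegree
        ));

        // Average density of the current embeddings, then the influence scaling used per iteration.
        double avgActiveFeatures = (double) currentTotalFeatureCount / nodeCount;
        double scaledNeighborInfluence = avgActiveFeatures * neighborInfluence / upperBoundNeighborExpectedBits;

        System.out.printf(
            "expected neighbor bits <= %.2f, scaled influence = %.4f%n",
            upperBoundNeighborExpectedBits,
            scaledNeighborInfluence
        );
    }
}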
