Skip to content

Commit b569449

Browse files
Add support for generating node features in HashGNN
Co-Authored-By: Jacob Sznajdman <breakanalysis@gmail.com>
1 parent 04b8b31 commit b569449

File tree

6 files changed

+209
-20
lines changed

6 files changed

+209
-20
lines changed

algo/src/main/java/org/neo4j/gds/embeddings/hashgnn/BinarizeTask.java

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -42,14 +42,14 @@
4242

4343
class BinarizeTask implements Runnable {
4444
private final Partition partition;
45-
private final HashGNNConfig config;
4645
private final HugeObjectArray<HugeAtomicBitSet> truncatedFeatures;
4746
private final List<FeatureExtractor> featureExtractors;
4847
private final int[][] propertyEmbeddings;
4948
private final List<int[]> hashesList;
5049
private final HashGNN.MinAndArgmin minAndArgMin;
5150
private final FeatureBinarizationConfig binarizationConfig;
5251
private final ProgressTracker progressTracker;
52+
private final int sampledBits;
5353

5454
BinarizeTask(
5555
Partition partition,
@@ -61,14 +61,18 @@ class BinarizeTask implements Runnable {
6161
ProgressTracker progressTracker
6262
) {
6363
this.partition = partition;
64-
this.config = config;
6564
this.binarizationConfig = config.binarizeFeatures().orElseThrow();
6665
this.truncatedFeatures = truncatedFeatures;
6766
this.featureExtractors = featureExtractors;
6867
this.propertyEmbeddings = propertyEmbeddings;
6968
this.hashesList = hashesList;
7069
this.minAndArgMin = new HashGNN.MinAndArgmin();
7170
this.progressTracker = progressTracker;
71+
72+
var densityOffset = config.generateFeatures().isPresent()
73+
? config.generateFeatures().get().densityLevel()
74+
: 0;
75+
this.sampledBits = config.embeddingDensity() - densityOffset;
7276
}
7377

7478
static HugeObjectArray<HugeAtomicBitSet> compute(
@@ -193,7 +197,7 @@ private HugeAtomicBitSet roundAndSample(BitSet tempBitSet, float[] floatVector)
193197
}
194198
}
195199
var sampledBitset = HugeAtomicBitSet.create(binarizationConfig.dimension());
196-
for (int i = 0; i < config.embeddingDensity(); i++) {
200+
for (int i = 0; i < sampledBits; i++) {
197201
hashArgMin(tempBitSet, hashesList.get(i), minAndArgMin);
198202
if (minAndArgMin.argMin != -1) {
199203
sampledBitset.set(minAndArgMin.argMin);
algo/src/main/java/org/neo4j/gds/embeddings/hashgnn/GenerateFeaturesTask.java

Lines changed: 104 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,104 @@
1+
/*
2+
* Copyright (c) "Neo4j"
3+
* Neo4j Sweden AB [http://neo4j.com]
4+
*
5+
* This file is part of Neo4j.
6+
*
7+
* Neo4j is free software: you can redistribute it and/or modify
8+
* it under the terms of the GNU General Public License as published by
9+
* the Free Software Foundation, either version 3 of the License, or
10+
* (at your option) any later version.
11+
*
12+
* This program is distributed in the hope that it will be useful,
13+
* but WITHOUT ANY WARRANTY; without even the implied warranty of
14+
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
15+
* GNU General Public License for more details.
16+
*
17+
* You should have received a copy of the GNU General Public License
18+
* along with this program. If not, see <http://www.gnu.org/licenses/>.
19+
*/
20+
package org.neo4j.gds.embeddings.hashgnn;
21+
22+
import org.neo4j.gds.api.Graph;
23+
import org.neo4j.gds.core.concurrency.RunWithConcurrency;
24+
import org.neo4j.gds.core.utils.TerminationFlag;
25+
import org.neo4j.gds.core.utils.paged.HugeAtomicBitSet;
26+
import org.neo4j.gds.core.utils.paged.HugeObjectArray;
27+
import org.neo4j.gds.core.utils.partition.Partition;
28+
import org.neo4j.gds.core.utils.progress.tasks.ProgressTracker;
29+
30+
import java.util.List;
31+
import java.util.SplittableRandom;
32+
import java.util.stream.Collectors;
33+
34+
class GenerateFeaturesTask implements Runnable {
35+
private final Partition partition;
36+
private final HugeObjectArray<HugeAtomicBitSet> output;
37+
private final SplittableRandom rng;
38+
private final FeatureBinarizationConfig generateFeaturesConfig;
39+
private final ProgressTracker progressTracker;
40+
41+
GenerateFeaturesTask(
42+
Partition partition,
43+
SplittableRandom rng,
44+
FeatureBinarizationConfig config,
45+
HugeObjectArray<HugeAtomicBitSet> output,
46+
ProgressTracker progressTracker
47+
) {
48+
this.partition = partition;
49+
this.rng = rng;
50+
this.generateFeaturesConfig = config;
51+
this.output = output;
52+
this.progressTracker = progressTracker;
53+
}
54+
55+
static HugeObjectArray<HugeAtomicBitSet> compute(
56+
Graph graph,
57+
List<Partition> partition,
58+
HashGNNConfig config,
59+
SplittableRandom rng,
60+
ProgressTracker progressTracker,
61+
TerminationFlag terminationFlag
62+
) {
63+
progressTracker.beginSubTask("Generate base node property features");
64+
65+
var output = HugeObjectArray.newArray(HugeAtomicBitSet.class, graph.nodeCount());
66+
67+
var tasks = partition.stream()
68+
.map(p -> new GenerateFeaturesTask(
69+
p,
70+
rng.split(),
71+
config.generateFeatures().get(),
72+
output,
73+
progressTracker
74+
))
75+
.collect(Collectors.toList());
76+
RunWithConcurrency.builder()
77+
.concurrency(config.concurrency())
78+
.tasks(tasks)
79+
.terminationFlag(terminationFlag)
80+
.run();
81+
82+
progressTracker.endSubTask("Generate base node property features");
83+
84+
return output;
85+
}
86+
87+
@Override
88+
public void run() {
89+
int dimension = generateFeaturesConfig.dimension();
90+
int densityLevel = generateFeaturesConfig.densityLevel();
91+
92+
partition.consume(nodeId -> {
93+
var generatedFeatures = HugeAtomicBitSet.create(dimension);
94+
95+
var randomInts = rng.ints(densityLevel, 0, dimension);
96+
randomInts.forEach(generatedFeatures::set);
97+
98+
output.set(nodeId, generatedFeatures);
99+
});
100+
101+
progressTracker.logProgress(partition.nodeCount());
102+
}
103+
104+
}

algo/src/main/java/org/neo4j/gds/embeddings/hashgnn/HashGNN.java

Lines changed: 75 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -19,15 +19,19 @@
1919
*/
2020
package org.neo4j.gds.embeddings.hashgnn;
2121

22+
import org.apache.commons.lang3.mutable.MutableInt;
2223
import org.neo4j.gds.Algorithm;
2324
import org.neo4j.gds.api.Graph;
2425
import org.neo4j.gds.api.schema.GraphSchema;
26+
import org.neo4j.gds.core.concurrency.RunWithConcurrency;
2527
import org.neo4j.gds.core.utils.paged.HugeAtomicBitSet;
2628
import org.neo4j.gds.core.utils.paged.HugeObjectArray;
29+
import org.neo4j.gds.core.utils.partition.Partition;
2730
import org.neo4j.gds.core.utils.partition.PartitionUtils;
2831
import org.neo4j.gds.core.utils.progress.tasks.ProgressTracker;
2932
import org.neo4j.gds.ml.core.features.FeatureExtraction;
3033

34+
import java.util.ArrayList;
3135
import java.util.List;
3236
import java.util.Optional;
3337
import java.util.Set;
@@ -80,16 +84,8 @@ public HashGNNResult compute() {
8084
.collect(Collectors.toList())
8185
: List.of(graphCopy);
8286

83-
var embeddingsB = config.binarizeFeatures().isPresent()
84-
? BinarizeTask.compute(graph, rangePartition, config, rng, progressTracker, terminationFlag)
85-
: RawFeaturesTask.compute(config, rng, progressTracker, graph, rangePartition, terminationFlag);
86-
int embeddingDimension = config.binarizeFeatures().map(FeatureBinarizationConfig::dimension).orElseGet(() -> {
87-
var featureExtractors = FeatureExtraction.propertyExtractors(
88-
graph,
89-
config.featureProperties()
90-
);
91-
return FeatureExtraction.featureCount(featureExtractors);
92-
});
87+
var embeddingsB = constructInputEmbeddings(rangePartition);
88+
int embeddingDimension = (int) embeddingsB.get(0).size();
9389

9490
var embeddingsA = HugeObjectArray.newArray(HugeAtomicBitSet.class, graph.nodeCount());
9591
embeddingsA.setAll(unused -> HugeAtomicBitSet.create(embeddingDimension));
@@ -201,4 +197,73 @@ public HugeObjectArray<double[]> embeddings() {
201197
}
202198
}
203199

200+
private HugeObjectArray<HugeAtomicBitSet> constructInputEmbeddings(List<Partition> partition) {
201+
List<HugeObjectArray<HugeAtomicBitSet>> inputEmbeddingsList = new ArrayList<>();
202+
var embeddingDimension = new MutableInt();
203+
var bitOffsets = new ArrayList<Integer>();
204+
205+
if (!config.featureProperties().isEmpty()) {
206+
if (config.binarizeFeatures().isPresent()) {
207+
inputEmbeddingsList.add(BinarizeTask.compute(
208+
graph,
209+
partition,
210+
config,
211+
rng,
212+
progressTracker,
213+
terminationFlag
214+
));
215+
bitOffsets.add(embeddingDimension.getValue());
216+
embeddingDimension.add(config.binarizeFeatures().get().dimension());
217+
} else {
218+
inputEmbeddingsList.add(RawFeaturesTask.compute(
219+
config,
220+
rng,
221+
progressTracker,
222+
graph,
223+
partition,
224+
terminationFlag
225+
));
226+
var featureExtractors = FeatureExtraction.propertyExtractors(
227+
graph,
228+
config.featureProperties()
229+
);
230+
bitOffsets.add(embeddingDimension.getValue());
231+
embeddingDimension.add(FeatureExtraction.featureCount(featureExtractors));
232+
}
233+
}
234+
235+
if (!config.generateFeatures().isPresent()) {
236+
return inputEmbeddingsList.get(0);
237+
}
238+
239+
inputEmbeddingsList.add(GenerateFeaturesTask.compute(
240+
graph,
241+
partition,
242+
config,
243+
rng,
244+
progressTracker,
245+
terminationFlag
246+
));
247+
bitOffsets.add(embeddingDimension.getValue());
248+
embeddingDimension.add(config.generateFeatures().get().dimension());
249+
250+
var concatInputEmbeddings = HugeObjectArray.newArray(HugeAtomicBitSet.class, graph.nodeCount());
251+
252+
var concatTasks = partition.stream().map(p -> (Runnable) () -> p.consume(nodeId -> {
253+
var concatFeatures = HugeAtomicBitSet.create(embeddingDimension.getValue());
254+
for (int i = 0; i < inputEmbeddingsList.size(); i++) {
255+
var embedding = inputEmbeddingsList.get(i).get(nodeId);
256+
int bitOffset = bitOffsets.get(i);
257+
embedding.forEachSetBit(bit -> concatFeatures.set(bitOffset + bit));
258+
}
259+
concatInputEmbeddings.set(nodeId, concatFeatures);
260+
})).collect(Collectors.toList());
261+
262+
RunWithConcurrency.builder()
263+
.concurrency(config.concurrency())
264+
.tasks(concatTasks)
265+
.run();
266+
267+
return concatInputEmbeddings;
268+
}
204269
}

algo/src/main/java/org/neo4j/gds/embeddings/hashgnn/HashGNNConfig.java

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,12 @@ default boolean heterogeneous() {
4949
return false;
5050
}
5151

52+
/**
 * Optional configuration for generating input node features instead of, or in addition to,
 * features taken from node properties. Empty by default.
 *
 * <p>The value has the same type as {@code binarizeFeatures()} and, per the annotations below,
 * is parsed and serialized with the same {@code parseBinarizationConfig} /
 * {@code toMapBinarizationConfig} converters.
 */
@Configuration.ToMapValue("org.neo4j.gds.embeddings.hashgnn.HashGNNConfig#toMapBinarizationConfig")
53+
@Configuration.ConvertWith(method = "org.neo4j.gds.embeddings.hashgnn.HashGNNConfig#parseBinarizationConfig", inverse = Configuration.ConvertWith.INVERSE_IS_TO_MAP)
54+
default Optional<FeatureBinarizationConfig> generateFeatures() {
55+
return Optional.empty();
56+
}
57+
5258
@Configuration.ToMapValue("org.neo4j.gds.embeddings.hashgnn.HashGNNConfig#toMapBinarizationConfig")
5359
@Configuration.ConvertWith(method = "org.neo4j.gds.embeddings.hashgnn.HashGNNConfig#parseBinarizationConfig", inverse = Configuration.ConvertWith.INVERSE_IS_TO_MAP)
5460
default Optional<FeatureBinarizationConfig> binarizeFeatures() {

algo/src/main/java/org/neo4j/gds/embeddings/hashgnn/HashGNNFactory.java

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -58,10 +58,16 @@ public HashGNN build(
5858
public Task progressTask(Graph graph, CONFIG config) {
5959
var tasks = new ArrayList<Task>();
6060

61-
if (config.binarizeFeatures().isPresent()) {
62-
tasks.add(Tasks.leaf("Binarize node property features", graph.nodeCount()));
63-
} else {
64-
tasks.add(Tasks.leaf("Extract raw node property features", graph.nodeCount()));
61+
if (!config.featureProperties().isEmpty()) {
62+
if (config.binarizeFeatures().isPresent()) {
63+
tasks.add(Tasks.leaf("Binarize node property features", graph.nodeCount()));
64+
} else {
65+
tasks.add(Tasks.leaf("Extract raw node property features", graph.nodeCount()));
66+
}
67+
}
68+
69+
if (config.generateFeatures().isPresent()) {
70+
tasks.add(Tasks.leaf("Generate base node property features", graph.nodeCount()));
6571
}
6672

6773
int numRelTypes = config.heterogeneous() ? config.relationshipTypes().size() : 1;

algo/src/main/java/org/neo4j/gds/embeddings/hashgnn/RawFeaturesTask.java

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -41,12 +41,12 @@
4141

4242
class RawFeaturesTask implements Runnable {
4343
private final Partition partition;
44-
private final HashGNNConfig config;
4544
private final List<FeatureExtractor> featureExtractors;
4645
private final int inputDimension;
4746
private final HugeObjectArray<HugeAtomicBitSet> features;
4847
private final List<int[]> hashesList;
4948
private final ProgressTracker progressTracker;
49+
private final int sampledBits;
5050

5151
RawFeaturesTask(
5252
Partition partition,
@@ -58,12 +58,16 @@ class RawFeaturesTask implements Runnable {
5858
ProgressTracker progressTracker
5959
) {
6060
this.partition = partition;
61-
this.config = config;
6261
this.featureExtractors = featureExtractors;
6362
this.inputDimension = inputDimension;
6463
this.features = features;
6564
this.hashesList = hashesList;
6665
this.progressTracker = progressTracker;
66+
67+
var densityOffset = config.generateFeatures().isPresent()
68+
? config.generateFeatures().get().densityLevel()
69+
: 0;
70+
this.sampledBits = config.embeddingDensity() - densityOffset;
6771
}
6872

6973
static HugeObjectArray<HugeAtomicBitSet> compute(
@@ -140,7 +144,7 @@ public void acceptArray(long nodeOffset, int offset, double[] values) {
140144
return;
141145
}
142146
var sampledBitset = HugeAtomicBitSet.create(inputDimension);
143-
for (int i = 0; i < config.embeddingDensity(); i++) {
147+
for (int i = 0; i < sampledBits; i++) {
144148
hashArgMin(nodeFeatures, hashesList.get(i), resMinAndArgMin, tempMinAndArgMin);
145149
if (resMinAndArgMin.argMin != -1) {
146150
sampledBitset.set(resMinAndArgMin.argMin);

0 commit comments

Comments
 (0)