Skip to content

Commit ead43bf

Browse files
Don't sample input features in HashGNN
Co-Authored-By: Jacob Sznajdman <breakanalysis@gmail.com>
1 parent b569449 commit ead43bf

File tree

9 files changed, with 95 additions and 148 deletions.

algo/src/main/java/org/neo4j/gds/embeddings/hashgnn/BinarizeTask.java

Lines changed: 16 additions & 43 deletions
Original file line number | Diff line number | Diff line change
@@ -19,7 +19,7 @@
1919
*/
2020
package org.neo4j.gds.embeddings.hashgnn;
2121

22-
import com.carrotsearch.hppc.BitSet;
22+
import org.apache.commons.lang3.mutable.MutableLong;
2323
import org.neo4j.gds.api.Graph;
2424
import org.neo4j.gds.core.concurrency.RunWithConcurrency;
2525
import org.neo4j.gds.core.utils.TerminationFlag;
@@ -32,47 +32,34 @@
3232
import org.neo4j.gds.ml.core.features.FeatureExtractor;
3333
import org.neo4j.gds.ml.util.ShuffleUtil;
3434

35-
import java.util.ArrayList;
3635
import java.util.Arrays;
3736
import java.util.List;
3837
import java.util.SplittableRandom;
3938
import java.util.stream.Collectors;
4039

41-
import static org.neo4j.gds.embeddings.hashgnn.HashGNNCompanion.hashArgMin;
42-
4340
class BinarizeTask implements Runnable {
4441
private final Partition partition;
4542
private final HugeObjectArray<HugeAtomicBitSet> truncatedFeatures;
4643
private final List<FeatureExtractor> featureExtractors;
4744
private final int[][] propertyEmbeddings;
48-
private final List<int[]> hashesList;
49-
private final HashGNN.MinAndArgmin minAndArgMin;
5045
private final FeatureBinarizationConfig binarizationConfig;
5146
private final ProgressTracker progressTracker;
52-
private final int sampledBits;
47+
private long totalNumFeatures;
5348

5449
BinarizeTask(
5550
Partition partition,
5651
HashGNNConfig config,
5752
HugeObjectArray<HugeAtomicBitSet> truncatedFeatures,
5853
List<FeatureExtractor> featureExtractors,
5954
int[][] propertyEmbeddings,
60-
List<int[]> hashesList,
6155
ProgressTracker progressTracker
6256
) {
6357
this.partition = partition;
6458
this.binarizationConfig = config.binarizeFeatures().orElseThrow();
6559
this.truncatedFeatures = truncatedFeatures;
6660
this.featureExtractors = featureExtractors;
6761
this.propertyEmbeddings = propertyEmbeddings;
68-
this.hashesList = hashesList;
69-
this.minAndArgMin = new HashGNN.MinAndArgmin();
7062
this.progressTracker = progressTracker;
71-
72-
var densityOffset = config.generateFeatures().isPresent()
73-
? config.generateFeatures().get().densityLevel()
74-
: 0;
75-
this.sampledBits = config.embeddingDensity() - densityOffset;
7663
}
7764

7865
static HugeObjectArray<HugeAtomicBitSet> compute(
@@ -81,18 +68,11 @@ static HugeObjectArray<HugeAtomicBitSet> compute(
8168
HashGNNConfig config,
8269
SplittableRandom rng,
8370
ProgressTracker progressTracker,
84-
TerminationFlag terminationFlag
71+
TerminationFlag terminationFlag,
72+
MutableLong totalNumFeaturesOutput
8573
) {
8674
progressTracker.beginSubTask("Binarize node property features");
8775

88-
var hashesList = new ArrayList<int[]>(config.embeddingDensity());
89-
for (int i = 0; i < config.embeddingDensity(); i++) {
90-
hashesList.add(HashGNNCompanion.HashTriple.computeHashesFromTriple(
91-
config.binarizeFeatures().get().dimension(),
92-
HashGNNCompanion.HashTriple.generate(rng)
93-
));
94-
}
95-
9676
var featureExtractors = FeatureExtraction.propertyExtractors(
9777
graph,
9878
config.featureProperties()
@@ -110,7 +90,6 @@ static HugeObjectArray<HugeAtomicBitSet> compute(
11090
truncatedFeatures,
11191
featureExtractors,
11292
propertyEmbeddings,
113-
hashesList,
11493
progressTracker
11594
))
11695
.collect(Collectors.toList());
@@ -120,6 +99,8 @@ static HugeObjectArray<HugeAtomicBitSet> compute(
12099
.terminationFlag(terminationFlag)
121100
.run();
122101

102+
totalNumFeaturesOutput.add(tasks.stream().mapToLong(BinarizeTask::totalNumFeatures).sum());
103+
123104
progressTracker.endSubTask("Binarize node property features");
124105

125106
return truncatedFeatures;
@@ -149,8 +130,6 @@ public static int[][] embedProperties(HashGNNConfig config, SplittableRandom rng
149130

150131
@Override
151132
public void run() {
152-
var tempFeatureContainer = new BitSet(binarizationConfig.dimension());
153-
154133
partition.consume(nodeId -> {
155134
var featureVector = new float[binarizationConfig.dimension()];
156135
FeatureExtraction.extract(nodeId, -1, featureExtractors, new FeatureConsumer() {
@@ -183,27 +162,21 @@ public void acceptArray(long nodeOffset, int offset, double[] values) {
183162
}
184163
});
185164

186-
truncatedFeatures.set(nodeId, roundAndSample(tempFeatureContainer, featureVector));
165+
var bitSet = HugeAtomicBitSet.create(binarizationConfig.dimension());
166+
for (int feature = 0; feature < featureVector.length; feature++) {
167+
if (featureVector[feature] > 0) {
168+
bitSet.set(feature);
169+
}
170+
}
171+
totalNumFeatures += bitSet.cardinality();
172+
truncatedFeatures.set(nodeId, bitSet);
187173
});
188174

189175
progressTracker.logProgress(partition.nodeCount());
190176
}
191177

192-
private HugeAtomicBitSet roundAndSample(BitSet tempBitSet, float[] floatVector) {
193-
tempBitSet.clear();
194-
for (int feature = 0; feature < floatVector.length; feature++) {
195-
if (floatVector[feature] > 0) {
196-
tempBitSet.set(feature);
197-
}
198-
}
199-
var sampledBitset = HugeAtomicBitSet.create(binarizationConfig.dimension());
200-
for (int i = 0; i < sampledBits; i++) {
201-
hashArgMin(tempBitSet, hashesList.get(i), minAndArgMin);
202-
if (minAndArgMin.argMin != -1) {
203-
sampledBitset.set(minAndArgMin.argMin);
204-
}
205-
}
206-
return sampledBitset;
178+
public long totalNumFeatures() {
179+
return totalNumFeatures;
207180
}
208181

209182
}

algo/src/main/java/org/neo4j/gds/embeddings/hashgnn/GenerateFeaturesTask.java

Lines changed: 11 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -19,6 +19,7 @@
1919
*/
2020
package org.neo4j.gds.embeddings.hashgnn;
2121

22+
import org.apache.commons.lang3.mutable.MutableLong;
2223
import org.neo4j.gds.api.Graph;
2324
import org.neo4j.gds.core.concurrency.RunWithConcurrency;
2425
import org.neo4j.gds.core.utils.TerminationFlag;
@@ -37,6 +38,7 @@ class GenerateFeaturesTask implements Runnable {
3738
private final SplittableRandom rng;
3839
private final FeatureBinarizationConfig generateFeaturesConfig;
3940
private final ProgressTracker progressTracker;
41+
private long totalNumFeatures = 0;
4042

4143
GenerateFeaturesTask(
4244
Partition partition,
@@ -58,7 +60,8 @@ static HugeObjectArray<HugeAtomicBitSet> compute(
5860
HashGNNConfig config,
5961
SplittableRandom rng,
6062
ProgressTracker progressTracker,
61-
TerminationFlag terminationFlag
63+
TerminationFlag terminationFlag,
64+
MutableLong totalNumFeaturesOutput
6265
) {
6366
progressTracker.beginSubTask("Generate base node property features");
6467

@@ -79,6 +82,8 @@ static HugeObjectArray<HugeAtomicBitSet> compute(
7982
.terminationFlag(terminationFlag)
8083
.run();
8184

85+
totalNumFeaturesOutput.add(tasks.stream().mapToLong(GenerateFeaturesTask::totalNumFeatures).sum());
86+
8287
progressTracker.endSubTask("Generate base node property features");
8388

8489
return output;
@@ -95,10 +100,15 @@ public void run() {
95100
var randomInts = rng.ints(densityLevel, 0, dimension);
96101
randomInts.forEach(generatedFeatures::set);
97102

103+
totalNumFeatures += generatedFeatures.cardinality();
104+
98105
output.set(nodeId, generatedFeatures);
99106
});
100107

101108
progressTracker.logProgress(partition.nodeCount());
102109
}
103110

111+
public long totalNumFeatures() {
112+
return totalNumFeatures;
113+
}
104114
}

algo/src/main/java/org/neo4j/gds/embeddings/hashgnn/HashGNN.java

Lines changed: 29 additions & 26 deletions
Original file line number | Diff line number | Diff line change
@@ -20,6 +20,7 @@
2020
package org.neo4j.gds.embeddings.hashgnn;
2121

2222
import org.apache.commons.lang3.mutable.MutableInt;
23+
import org.apache.commons.lang3.mutable.MutableLong;
2324
import org.neo4j.gds.Algorithm;
2425
import org.neo4j.gds.api.Graph;
2526
import org.neo4j.gds.api.schema.GraphSchema;
@@ -48,6 +49,7 @@ public class HashGNN extends Algorithm<HashGNN.HashGNNResult> {
4849
private final Graph graph;
4950
private final SplittableRandom rng;
5051
private final HashGNNConfig config;
52+
private final MutableLong totalSetBits = new MutableLong();
5153

5254
public HashGNN(Graph graph, HashGNNConfig config, ProgressTracker progressTracker) {
5355
super(progressTracker);
@@ -91,26 +93,12 @@ public HashGNNResult compute() {
9193
embeddingsA.setAll(unused -> HugeAtomicBitSet.create(embeddingDimension));
9294

9395
double avgDegree = (graph.relationshipCount() / (double) graph.nodeCount());
94-
int upperBoundBits = Math.max(Math.min(embeddingDimension, config.embeddingDensity()), 1);
95-
int upperBoundSelfExpectedBits = (int) Math.round(upperBoundBits * (1 - Math.pow(
96-
1 - (1.0 / upperBoundBits),
97-
config.embeddingDensity()
98-
)));
99-
double upperBoundNeighborExpectedBits = upperBoundBits * (1 - Math.pow(
100-
1 - (1.0 / upperBoundBits),
101-
avgDegree
102-
));
103-
double scaledNeighborInfluence = graph.relationshipCount() == 0 ? 1.0 : upperBoundSelfExpectedBits * config.neighborInfluence() / upperBoundNeighborExpectedBits;
104-
105-
var hashes = HashTask.compute(
106-
embeddingDimension,
107-
scaledNeighborInfluence,
108-
graphs.size(),
109-
config,
110-
randomSeed,
111-
terminationFlag,
112-
progressTracker
113-
);
96+
double upperBoundNeighborExpectedBits = embeddingDimension == 0
97+
? 1
98+
: embeddingDimension * (1 - Math.pow(
99+
1 - (1.0 / embeddingDimension),
100+
avgDegree)
101+
);
114102

115103
progressTracker.beginSubTask("Propagate embeddings");
116104

@@ -123,17 +111,30 @@ public HashGNNResult compute() {
123111
currentEmbeddings.get(i).clear();
124112
}
125113

114+
double scaledNeighborInfluence = graph.relationshipCount() == 0 ? 1.0 : (totalSetBits.doubleValue() / graph.nodeCount()) * config.neighborInfluence() / upperBoundNeighborExpectedBits;
115+
totalSetBits.setValue(0);
116+
117+
var hashes = HashTask.compute(
118+
embeddingDimension,
119+
scaledNeighborInfluence,
120+
graphs.size(),
121+
config,
122+
randomSeed,
123+
terminationFlag,
124+
progressTracker
125+
);
126+
126127
MinHashTask.compute(
127128
degreePartition,
128129
graphs,
129130
config,
130131
embeddingDimension,
131132
currentEmbeddings,
132133
previousEmbeddings,
133-
iteration,
134134
hashes,
135135
progressTracker,
136-
terminationFlag
136+
terminationFlag,
137+
totalSetBits
137138
);
138139
}
139140

@@ -210,18 +211,19 @@ private HugeObjectArray<HugeAtomicBitSet> constructInputEmbeddings(List<Partitio
210211
config,
211212
rng,
212213
progressTracker,
213-
terminationFlag
214+
terminationFlag,
215+
totalSetBits
214216
));
215217
bitOffsets.add(embeddingDimension.getValue());
216218
embeddingDimension.add(config.binarizeFeatures().get().dimension());
217219
} else {
218220
inputEmbeddingsList.add(RawFeaturesTask.compute(
219221
config,
220-
rng,
221222
progressTracker,
222223
graph,
223224
partition,
224-
terminationFlag
225+
terminationFlag,
226+
totalSetBits
225227
));
226228
var featureExtractors = FeatureExtraction.propertyExtractors(
227229
graph,
@@ -242,7 +244,8 @@ private HugeObjectArray<HugeAtomicBitSet> constructInputEmbeddings(List<Partitio
242244
config,
243245
rng,
244246
progressTracker,
245-
terminationFlag
247+
terminationFlag,
248+
totalSetBits
246249
));
247250
bitOffsets.add(embeddingDimension.getValue());
248251
embeddingDimension.add(config.generateFeatures().get().dimension());

algo/src/main/java/org/neo4j/gds/embeddings/hashgnn/HashGNNFactory.java

Lines changed: 10 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -71,17 +71,19 @@ public Task progressTask(Graph graph, CONFIG config) {
7171
}
7272

7373
int numRelTypes = config.heterogeneous() ? config.relationshipTypes().size() : 1;
74-
tasks.add(Tasks.leaf(
75-
"Precompute hashes",
76-
config.iterations() * config.embeddingDensity() * (1 + 1 + numRelTypes)
77-
));
7874

7975
tasks.add(Tasks.iterativeFixed(
8076
"Propagate embeddings",
81-
() -> List.of(Tasks.leaf(
82-
"Propagate embeddings iteration",
83-
(2 * graph.nodeCount() + graph.relationshipCount()) * config.embeddingDensity()
84-
)),
77+
() -> List.of(
78+
Tasks.leaf(
79+
"Precompute hashes",
80+
config.embeddingDensity() * (1 + 1 + numRelTypes)
81+
),
82+
Tasks.leaf(
83+
"Propagate embeddings iteration",
84+
(2 * graph.nodeCount() + graph.relationshipCount()) * config.embeddingDensity()
85+
)
86+
),
8587
config.iterations()
8688
));
8789

algo/src/main/java/org/neo4j/gds/embeddings/hashgnn/HashTask.java

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -69,9 +69,9 @@ public static List<Hashes> compute(
6969
) {
7070
progressTracker.beginSubTask("Precompute hashes");
7171

72-
progressTracker.setSteps(config.iterations() * config.embeddingDensity());
72+
progressTracker.setSteps(config.embeddingDensity());
7373

74-
var hashTasks = IntStream.range(0, config.iterations() * config.embeddingDensity()).mapToObj(seedOffset ->
74+
var hashTasks = IntStream.range(0, config.embeddingDensity()).mapToObj(seedOffset ->
7575
new HashTask(
7676
embeddingDimension,
7777
scaledNeighborInfluence,

0 commit comments

Comments
 (0)