Skip to content

Commit 4d4bd00

Browse files
authored
Merge pull request #9909 from lassewesth/diespecs18
migrate estimation cli hashgnn to application layer
2 parents adf3746 + 5228383 commit 4d4bd00

File tree

8 files changed

+37
-231
lines changed

8 files changed

+37
-231
lines changed

algo/src/main/java/org/neo4j/gds/embeddings/hashgnn/HashGNNFactory.java renamed to algo/src/main/java/org/neo4j/gds/embeddings/hashgnn/HashGNNTask.java

Lines changed: 4 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -19,48 +19,16 @@
1919
*/
2020
package org.neo4j.gds.embeddings.hashgnn;
2121

22-
import org.neo4j.gds.GraphAlgorithmFactory;
2322
import org.neo4j.gds.api.Graph;
24-
import org.neo4j.gds.mem.MemoryEstimation;
25-
import org.neo4j.gds.core.utils.progress.tasks.ProgressTracker;
23+
import org.neo4j.gds.applications.algorithms.machinery.AlgorithmLabel;
2624
import org.neo4j.gds.core.utils.progress.tasks.Task;
2725
import org.neo4j.gds.core.utils.progress.tasks.Tasks;
28-
import org.neo4j.gds.termination.TerminationFlag;
2926

3027
import java.util.ArrayList;
3128
import java.util.List;
3229

33-
public class HashGNNFactory<CONFIG extends HashGNNConfig> extends GraphAlgorithmFactory<HashGNN, CONFIG> {
34-
35-
@Override
36-
public String taskName() {
37-
return "HashGNN";
38-
}
39-
40-
public HashGNN build(
41-
Graph graph,
42-
HashGNNParameters parameters,
43-
ProgressTracker progressTracker
44-
) {
45-
return new HashGNN(
46-
graph,
47-
parameters,
48-
progressTracker,
49-
TerminationFlag.RUNNING_TRUE
50-
);
51-
}
52-
53-
@Override
54-
public HashGNN build(
55-
Graph graph,
56-
CONFIG configuration,
57-
ProgressTracker progressTracker
58-
) {
59-
return build(graph, HashGNNConfigTransformer.toParameters(configuration), progressTracker);
60-
}
61-
62-
@Override
63-
public Task progressTask(Graph graph, CONFIG config) {
30+
public class HashGNNTask {
31+
public static Task create(Graph graph, HashGNNConfig config) {
6432
var tasks = new ArrayList<Task>();
6533

6634
if (config.generateFeatures().isPresent()) {
@@ -93,17 +61,8 @@ public Task progressTask(Graph graph, CONFIG config) {
9361
}
9462

9563
return Tasks.task(
96-
taskName(),
64+
AlgorithmLabel.HashGNN.asString(),
9765
tasks
9866
);
9967
}
100-
101-
public MemoryEstimation memoryEstimation(HashGNNParameters parameters) {
102-
return new HashGNNMemoryEstimateDefinition(parameters).memoryEstimation();
103-
}
104-
105-
@Override
106-
public MemoryEstimation memoryEstimation(CONFIG config) {
107-
return memoryEstimation(HashGNNConfigTransformer.toParameters(config));
108-
}
10968
}

algo/src/test/java/org/neo4j/gds/embeddings/hashgnn/HashGNNMemoryEstimateDefinitionTest.java

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -111,8 +111,8 @@ void estimationShouldUseGeneratedDimensionIfOutputIsMissing() {
111111
Optional.of(GenerateFeaturesConfigImpl.builder().dimension(inputDimension).densityLevel(1).build()),
112112
Optional.empty()
113113
);
114-
var bigEstimation = new HashGNNFactory<>()
115-
.memoryEstimation(bigParameters)
114+
var bigEstimation = new HashGNNMemoryEstimateDefinition(bigParameters)
115+
.memoryEstimation()
116116
.estimate(graphDims, concurrency)
117117
.memoryUsage();
118118

@@ -128,8 +128,8 @@ void estimationShouldUseGeneratedDimensionIfOutputIsMissing() {
128128
Optional.of(GenerateFeaturesConfigImpl.builder().dimension((int) (inputRatio * inputDimension)).densityLevel(1).build()),
129129
Optional.empty()
130130
);
131-
var smallEstimation = new HashGNNFactory<>()
132-
.memoryEstimation(smallParameters)
131+
var smallEstimation = new HashGNNMemoryEstimateDefinition(smallParameters)
132+
.memoryEstimation()
133133
.estimate(graphDims, concurrency)
134134
.memoryUsage();
135135

algo/src/test/java/org/neo4j/gds/embeddings/hashgnn/HashGNNTest.java

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,9 @@
3030
import org.neo4j.gds.RelationshipType;
3131
import org.neo4j.gds.ResourceUtil;
3232
import org.neo4j.gds.api.Graph;
33+
import org.neo4j.gds.applications.algorithms.embeddings.NodeEmbeddingAlgorithms;
34+
import org.neo4j.gds.applications.algorithms.machinery.ProgressTrackerCreator;
35+
import org.neo4j.gds.applications.algorithms.machinery.RequestScopedDependencies;
3336
import org.neo4j.gds.collections.ha.HugeLongArray;
3437
import org.neo4j.gds.collections.hsa.HugeSparseLongArray;
3538
import org.neo4j.gds.compat.TestLog;
@@ -272,6 +275,13 @@ void outputDimensionIsApplied() {
272275
@ParameterizedTest
273276
@CsvSource(value = {"true", "false"})
274277
void shouldLogProgress(boolean dense) {
278+
var log = new GdsTestLog();
279+
var requestScopedDependencies = RequestScopedDependencies.builder()
280+
.terminationFlag(TerminationFlag.RUNNING_TRUE)
281+
.build();
282+
var progressTrackerCreator = new ProgressTrackerCreator(log, requestScopedDependencies);
283+
var nodeEmbeddingAlgorithms = new NodeEmbeddingAlgorithms(null, progressTrackerCreator, requestScopedDependencies.terminationFlag());
284+
275285
var g = dense ? doubleGraph : binaryGraph;
276286

277287
int embeddingDensity = 200;
@@ -290,12 +300,10 @@ void shouldLogProgress(boolean dense) {
290300
}
291301
var config = configBuilder.build();
292302

293-
var factory = new HashGNNFactory<>();
294-
var progressTask = factory.progressTask(g, config);
295-
var log = new GdsTestLog();
303+
var progressTask = HashGNNTask.create(g, config);
296304
var progressTracker = new TaskProgressTracker(progressTask, log, new Concurrency(4), EmptyTaskRegistryFactory.INSTANCE);
297305

298-
factory.build(g, config, progressTracker).compute();
306+
nodeEmbeddingAlgorithms.hashGnn(g, config, progressTracker);
299307

300308
String logResource;
301309
if (dense) {

algorithm-specifications/src/main/java/org/neo4j/gds/embeddings/hashgnn/Constants.java

Lines changed: 0 additions & 26 deletions
This file was deleted.

algorithm-specifications/src/main/java/org/neo4j/gds/embeddings/hashgnn/HashGNNMutateSpec.java

Lines changed: 0 additions & 57 deletions
This file was deleted.

algorithm-specifications/src/main/java/org/neo4j/gds/embeddings/hashgnn/HashGNNStreamSpec.java

Lines changed: 0 additions & 57 deletions
This file was deleted.

applications/algorithms/node-embeddings/src/main/java/org/neo4j/gds/applications/algorithms/embeddings/NodeEmbeddingAlgorithms.java

Lines changed: 9 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@
4848
import org.neo4j.gds.embeddings.hashgnn.HashGNNConfig;
4949
import org.neo4j.gds.embeddings.hashgnn.HashGNNConfigTransformer;
5050
import org.neo4j.gds.embeddings.hashgnn.HashGNNResult;
51+
import org.neo4j.gds.embeddings.hashgnn.HashGNNTask;
5152
import org.neo4j.gds.embeddings.node2vec.Node2Vec;
5253
import org.neo4j.gds.embeddings.node2vec.Node2VecBaseConfig;
5354
import org.neo4j.gds.embeddings.node2vec.Node2VecConfigTransformer;
@@ -196,10 +197,16 @@ private static GraphSageTrain constructGraphSageTrainAlgorithm(
196197
}
197198

198199
HashGNNResult hashGnn(Graph graph, HashGNNConfig configuration) {
199-
var task = createHashGnnTask(graph, configuration);
200+
var task = HashGNNTask.create(graph, configuration);
200201
var progressTracker = progressTrackerCreator.createProgressTracker(configuration, task);
201202

202-
var algorithm = new HashGNN(graph, HashGNNConfigTransformer.toParameters(configuration), progressTracker, terminationFlag);
203+
return hashGnn(graph, configuration, progressTracker);
204+
}
205+
206+
public HashGNNResult hashGnn(Graph graph, HashGNNConfig configuration, ProgressTracker progressTracker) {
207+
var parameters = HashGNNConfigTransformer.toParameters(configuration);
208+
209+
var algorithm = new HashGNN(graph, parameters, progressTracker, terminationFlag);
203210

204211
return algorithmMachinery.runAlgorithmsAndManageProgressTracker(
205212
algorithm,
@@ -246,41 +253,6 @@ private Task createFastRPTask(Graph graph, Number nodeSelfInfluence, int iterati
246253
return Tasks.task(AlgorithmLabel.FastRP.asString(), tasks);
247254
}
248255

249-
private static Task createHashGnnTask(Graph graph, HashGNNConfig configuration) {
250-
var tasks = new ArrayList<Task>();
251-
252-
if (configuration.generateFeatures().isPresent()) {
253-
tasks.add(Tasks.leaf("Generate base node property features", graph.nodeCount()));
254-
} else if (configuration.binarizeFeatures().isPresent()) {
255-
tasks.add(Tasks.leaf("Binarize node property features", graph.nodeCount()));
256-
} else {
257-
tasks.add(Tasks.leaf("Extract raw node property features", graph.nodeCount()));
258-
}
259-
260-
int numRelTypes = configuration.heterogeneous() ? configuration.relationshipTypes().size() : 1;
261-
262-
tasks.add(Tasks.iterativeFixed(
263-
"Propagate embeddings",
264-
() -> List.of(
265-
Tasks.leaf(
266-
"Precompute hashes",
267-
configuration.embeddingDensity() * (1L + 1 + numRelTypes)
268-
),
269-
Tasks.leaf(
270-
"Perform min-hashing",
271-
(2 * graph.nodeCount() + graph.relationshipCount()) * configuration.embeddingDensity()
272-
)
273-
),
274-
configuration.iterations()
275-
));
276-
277-
if (configuration.outputDimension().isPresent()) {
278-
tasks.add(Tasks.leaf("Densify output embeddings", graph.nodeCount()));
279-
}
280-
281-
return Tasks.task(AlgorithmLabel.HashGNN.asString(), tasks);
282-
}
283-
284256
private Task createNode2VecTask(Graph graph, Node2VecBaseConfig configuration) {
285257
var randomWalkTasks = new ArrayList<Task>();
286258
if (graph.hasRelationshipProperty()) {

applications/algorithms/node-embeddings/src/main/java/org/neo4j/gds/applications/algorithms/embeddings/NodeEmbeddingAlgorithmsEstimationModeBusinessFacade.java

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
import org.neo4j.gds.embeddings.hashgnn.HashGNNConfig;
3333
import org.neo4j.gds.embeddings.hashgnn.HashGNNConfigTransformer;
3434
import org.neo4j.gds.embeddings.hashgnn.HashGNNMemoryEstimateDefinition;
35+
import org.neo4j.gds.embeddings.hashgnn.HashGNNParameters;
3536
import org.neo4j.gds.embeddings.node2vec.Node2VecBaseConfig;
3637
import org.neo4j.gds.embeddings.node2vec.Node2VecConfigTransformer;
3738
import org.neo4j.gds.embeddings.node2vec.Node2VecMemoryEstimateDefinition;
@@ -97,7 +98,13 @@ public MemoryEstimateResult graphSageTrain(GraphSageTrainConfig configuration, O
9798
}
9899

99100
public MemoryEstimation hashGnn(HashGNNConfig configuration) {
100-
return new HashGNNMemoryEstimateDefinition(HashGNNConfigTransformer.toParameters(configuration)).memoryEstimation();
101+
var parameters = HashGNNConfigTransformer.toParameters(configuration);
102+
103+
return hashGnn(parameters);
104+
}
105+
106+
private MemoryEstimation hashGnn(HashGNNParameters parameters) {
107+
return new HashGNNMemoryEstimateDefinition(parameters).memoryEstimation();
101108
}
102109

103110
public MemoryEstimateResult hashGnn(HashGNNConfig configuration, Object graphNameOrConfiguration) {

0 commit comments

Comments
 (0)