Skip to content

Commit 19ef29d

Browse files
Fix graphsage memory estimation
1 parent 5718d23 commit 19ef29d

File tree

4 files changed

+41
-3
lines changed

4 files changed

+41
-3
lines changed

algo/src/main/java/org/neo4j/gds/algorithms/embeddings/NodeEmbeddingsAlgorithmsEstimateBusinessFacade.java

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,18 +30,23 @@
3030
import org.neo4j.gds.embeddings.hashgnn.HashGNNFactory;
3131
import org.neo4j.gds.embeddings.node2vec.Node2VecAlgorithmFactory;
3232
import org.neo4j.gds.embeddings.node2vec.Node2VecBaseConfig;
33+
import org.neo4j.gds.modelcatalogservices.ModelCatalogService;
3334
import org.neo4j.gds.results.MemoryEstimateResult;
3435

3536
import java.util.Optional;
3637

3738
public class NodeEmbeddingsAlgorithmsEstimateBusinessFacade {
3839

3940
private final AlgorithmEstimator algorithmEstimator;
41+
private final ModelCatalogService modelCatalogService;
42+
4043

4144
public NodeEmbeddingsAlgorithmsEstimateBusinessFacade(
42-
AlgorithmEstimator algorithmEstimator
45+
AlgorithmEstimator algorithmEstimator,
46+
ModelCatalogService modelCatalogService
4347
) {
4448
this.algorithmEstimator = algorithmEstimator;
49+
this.modelCatalogService = modelCatalogService;
4550
}
4651

4752
public <C extends Node2VecBaseConfig> MemoryEstimateResult node2Vec(
@@ -64,7 +69,7 @@ public <C extends GraphSageBaseConfig> MemoryEstimateResult graphSage(
6469
graphNameOrConfiguration,
6570
configuration,
6671
Optional.empty(),
67-
new GraphSageAlgorithmFactory(null)
72+
new GraphSageAlgorithmFactory(modelCatalogService.get())
6873
);
6974
}
7075

proc/embeddings/src/test/java/org/neo4j/gds/embeddings/graphsage/GraphSageMutateProcTest.java

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@
4848
import java.util.stream.Stream;
4949

5050
import static org.assertj.core.api.Assertions.assertThat;
51+
import static org.assertj.core.api.Assertions.assertThatNoException;
5152
import static org.assertj.core.api.Assertions.assertThatThrownBy;
5253
import static org.junit.jupiter.api.Assertions.assertEquals;
5354
import static org.neo4j.gds.ElementProjection.PROJECT_ALL;
@@ -262,4 +263,32 @@ private static Stream<Arguments> missingNodeProperties() {
262263
);
263264
}
264265

266+
@Test
267+
void shouldEstimateMemory() {
268+
var trainQuery = GdsCypher.call(graphName)
269+
.algo("gds.beta.graphSage")
270+
.trainMode()
271+
.addParameter("sampleSizes", List.of(2, 4))
272+
.addParameter("featureProperties", List.of("age", "birth_year", "death_year"))
273+
.addParameter("embeddingDimension", 16)
274+
.addParameter("activationFunction", ActivationFunction.SIGMOID)
275+
.addParameter("aggregator", "mean")
276+
.addParameter("modelName", modelName)
277+
.yields();
278+
279+
runQuery(trainQuery);
280+
281+
var mutatePropertyKey = "embedding";
282+
var query = GdsCypher.call("embeddingsGraph")
283+
.algo("gds.beta.graphSage")
284+
.mutateEstimation()
285+
.addParameter("mutateProperty", mutatePropertyKey)
286+
.addParameter("modelName", modelName)
287+
.yields("requiredMemory");
288+
289+
assertThatNoException().isThrownBy(() -> runQuery(query));
290+
291+
292+
}
293+
265294
}

proc/embeddings/src/test/java/org/neo4j/gds/embeddings/graphsage/GraphSageStreamProcTest.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -177,6 +177,7 @@ void weightedGraphSage() {
177177
.addParameter("concurrency", 1)
178178
.addParameter("modelName", modelName)
179179
.yields();
180+
180181
long[] nodeIds = idFunction.of("a", "b", "c", "d", "e", "f", "g", "h", "i", "j", "k", "l", "m", "n", "o");
181182
var rowCount = runQueryWithRowConsumer(streamQuery, (resultRow) -> {
182183
assertThat(nodeIds).contains((long) resultRow.getNumber("nodeId"));

procedures/integration/src/main/java/org/neo4j/gds/procedures/integration/AlgorithmProcedureFacadeProvider.java

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -294,7 +294,10 @@ NodeEmbeddingsProcedureFacade createNodeEmbeddingsProcedureFacade() {
294294
writeNodePropertyService
295295
);
296296

297-
var estimateBusinessFacade = new NodeEmbeddingsAlgorithmsEstimateBusinessFacade(algorithmEstimator);
297+
var estimateBusinessFacade = new NodeEmbeddingsAlgorithmsEstimateBusinessFacade(
298+
algorithmEstimator,
299+
modelCatalogService
300+
);
298301

299302
// procedure facade
300303
return new NodeEmbeddingsProcedureFacade(

0 commit comments

Comments
 (0)