Skip to content

Commit a3f24a0

Browse files
Merge pull request #9982 from IoannisPanagiotas/hashgnn-write-mode
HashGNN Write mode
2 parents 4ae2bd0 + 9667295 commit a3f24a0

File tree

12 files changed

+466
-3
lines changed

12 files changed

+466
-3
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
/*
2+
* Copyright (c) "Neo4j"
3+
* Neo4j Sweden AB [http://neo4j.com]
4+
*
5+
* This file is part of Neo4j.
6+
*
7+
* Neo4j is free software: you can redistribute it and/or modify
8+
* it under the terms of the GNU General Public License as published by
9+
* the Free Software Foundation, either version 3 of the License, or
10+
* (at your option) any later version.
11+
*
12+
* This program is distributed in the hope that it will be useful,
13+
* but WITHOUT ANY WARRANTY; without even the implied warranty of
14+
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
15+
* GNU General Public License for more details.
16+
*
17+
* You should have received a copy of the GNU General Public License
18+
* along with this program. If not, see <http://www.gnu.org/licenses/>.
19+
*/
20+
package org.neo4j.gds.applications.algorithms.embeddings;
21+
22+
import org.neo4j.gds.api.Graph;
23+
import org.neo4j.gds.api.GraphStore;
24+
import org.neo4j.gds.api.ResultStore;
25+
import org.neo4j.gds.applications.algorithms.machinery.WriteStep;
26+
import org.neo4j.gds.applications.algorithms.machinery.WriteToDatabase;
27+
import org.neo4j.gds.applications.algorithms.metadata.NodePropertiesWritten;
28+
import org.neo4j.gds.core.utils.progress.JobId;
29+
import org.neo4j.gds.embeddings.hashgnn.HashGNNResult;
30+
import org.neo4j.gds.embeddings.hashgnn.HashGNNWriteConfig;
31+
32+
import static org.neo4j.gds.applications.algorithms.machinery.AlgorithmLabel.GraphSage;
import static org.neo4j.gds.applications.algorithms.machinery.AlgorithmLabel.HashGNN;
33+
34+
class HashGnnWriteStep implements WriteStep<HashGNNResult, NodePropertiesWritten> {
35+
private final WriteToDatabase writeToDatabase;
36+
private final HashGNNWriteConfig configuration;
37+
38+
HashGnnWriteStep(WriteToDatabase writeToDatabase, HashGNNWriteConfig configuration) {
39+
this.writeToDatabase = writeToDatabase;
40+
this.configuration = configuration;
41+
}
42+
43+
@Override
44+
public NodePropertiesWritten execute(
45+
Graph graph,
46+
GraphStore graphStore,
47+
ResultStore resultStore,
48+
HashGNNResult result,
49+
JobId jobId
50+
) {
51+
var nodePropertyValues = result.embeddings();
52+
53+
return writeToDatabase.perform(
54+
graph,
55+
graphStore,
56+
resultStore,
57+
configuration,
58+
configuration,
59+
GraphSage,
60+
jobId,
61+
nodePropertyValues
62+
);
63+
}
64+
}

applications/algorithms/node-embeddings/src/main/java/org/neo4j/gds/applications/algorithms/embeddings/NodeEmbeddingAlgorithmsWriteModeBusinessFacade.java

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,8 @@
3030
import org.neo4j.gds.embeddings.fastrp.FastRPWriteConfig;
3131
import org.neo4j.gds.embeddings.graphsage.algo.GraphSageResult;
3232
import org.neo4j.gds.embeddings.graphsage.algo.GraphSageWriteConfig;
33+
import org.neo4j.gds.embeddings.hashgnn.HashGNNResult;
34+
import org.neo4j.gds.embeddings.hashgnn.HashGNNWriteConfig;
3335
import org.neo4j.gds.embeddings.node2vec.Node2VecResult;
3436
import org.neo4j.gds.embeddings.node2vec.Node2VecWriteConfig;
3537
import org.neo4j.gds.logging.Log;
@@ -39,6 +41,7 @@
3941

4042
import static org.neo4j.gds.applications.algorithms.machinery.AlgorithmLabel.FastRP;
4143
import static org.neo4j.gds.applications.algorithms.machinery.AlgorithmLabel.GraphSage;
44+
import static org.neo4j.gds.applications.algorithms.machinery.AlgorithmLabel.HashGNN;
4245
import static org.neo4j.gds.applications.algorithms.machinery.AlgorithmLabel.Node2Vec;
4346

4447
public final class NodeEmbeddingAlgorithmsWriteModeBusinessFacade {
@@ -120,6 +123,23 @@ public <RESULT> RESULT graphSage(
120123
resultBuilder
121124
);
122125
}
126+
public <RESULT> RESULT hashGnn(
127+
GraphName graphName,
128+
HashGNNWriteConfig configuration,
129+
ResultBuilder<HashGNNWriteConfig, HashGNNResult, RESULT, NodePropertiesWritten> resultBuilder
130+
) {
131+
var writeStep = new HashGnnWriteStep(writeToDatabase, configuration);
132+
133+
return algorithmProcessingTemplateConvenience.processRegularAlgorithmInWriteMode(
134+
graphName,
135+
configuration,
136+
HashGNN,
137+
() -> estimationFacade.hashGnn(configuration),
138+
(graph, __) -> algorithms.hashGnn(graph, configuration),
139+
writeStep,
140+
resultBuilder
141+
);
142+
}
123143

124144
public <RESULT> RESULT node2Vec(
125145
GraphName graphName,

doc-test/src/test/java/org/neo4j/gds/doc/HashGNNDocTest.java

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121

2222
import org.neo4j.gds.embeddings.hashgnn.HashGNNMutateProc;
2323
import org.neo4j.gds.embeddings.hashgnn.HashGNNStreamProc;
24+
import org.neo4j.gds.embeddings.hashgnn.HashGNNWriteProc;
2425
import org.neo4j.gds.functions.AsNodeFunc;
2526
import org.neo4j.gds.scaling.ScalePropertiesMutateProc;
2627

@@ -33,6 +34,7 @@ protected List<Class<?>> procedures() {
3334
return List.of(
3435
HashGNNStreamProc.class,
3536
HashGNNMutateProc.class,
37+
HashGNNWriteProc.class,
3638
ScalePropertiesMutateProc.class
3739
);
3840
}

doc/modules/ROOT/pages/machine-learning/node-embeddings/hashgnn.adoc

Lines changed: 72 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -342,9 +342,48 @@ include::partial$/machine-learning/node-embeddings/hashgnn/specific-configuratio
342342
| configuration | Map | Configuration used for running the algorithm.
343343
|===
344344
======
345+
[.include-with-write]
346+
======
345347

346-
====
348+
.Run HashGNN in write mode on a named graph.
349+
[source, cypher, role=noplay]
350+
----
351+
CALL gds.hashgnn.write(
352+
graphName: String,
353+
configuration: Map
354+
) YIELD
355+
nodeCount: Integer,
356+
nodePropertiesWritten: Integer,
357+
preProcessingMillis: Integer,
358+
computeMillis: Integer,
359+
writeMillis: Integer,
360+
configuration: Map
361+
----
347362

363+
include::partial$/algorithms/common-configuration/common-parameters.adoc[]
364+
365+
.Configuration
366+
[opts="header",cols="3,2,3m,2,8"]
367+
|===
368+
| Name | Type | Default | Optional | Description
369+
include::partial$/algorithms/common-configuration/common-write-configuration-entries.adoc[]
370+
include::partial$/machine-learning/node-embeddings/hashgnn/specific-configuration.adoc[]
371+
|===
372+
373+
.Results
374+
[opts="header"]
375+
|===
376+
| Name | Type | Description
377+
| nodeCount | Integer | Number of nodes processed.
378+
| nodePropertiesWritten | Integer | Number of node properties written.
379+
| preProcessingMillis | Integer | Milliseconds for preprocessing the graph.
380+
| computeMillis | Integer | Milliseconds for running the algorithm.
381+
| writeMillis | Integer | Milliseconds for writing back results.
382+
| configuration | Map | Configuration used for running the algorithm.
383+
|===
384+
======
385+
386+
====
348387

349388
[[algorithms-embeddings-hashgnn-examples]]
350389
== Examples
@@ -702,6 +741,38 @@ YIELD nodePropertiesWritten
702741
The graph 'persons' now has a node property `hashgnn-embedding` which stores the node embedding for each node.
703742
To find out how to inspect the new schema of the in-memory graph, see xref:management-ops/graph-list.adoc[Listing graphs].
704743

744+
[[algorithms-embeddings-hashgnn-examples-write]]
745+
=== Write
746+
747+
include::partial$/algorithms/shared/examples-write-intro.adoc[]
748+
749+
[role=query-example]
750+
--
751+
.The following will run the algorithm in `write` mode:
752+
[source, cypher, role=noplay]
753+
----
754+
CALL gds.hashgnn.write(
755+
'persons',
756+
{
757+
writeProperty: 'hashgnn-embedding',
758+
heterogeneous: true,
759+
iterations: 2,
760+
embeddingDensity: 4,
761+
binarizeFeatures: {dimension: 6, threshold: 0.2},
762+
featureProperties: ['experience_scaled', 'sourness', 'sweetness', 'tropical'],
763+
randomSeed: 42
764+
}
765+
)
766+
YIELD nodePropertiesWritten
767+
----
768+
769+
[opts=header]
770+
.Results
771+
|===
772+
| nodePropertiesWritten
773+
| 11
774+
|===
775+
--
705776

706777
[[algorithms-embeddings-hashgnn-virtual-example]]
707778
=== Virtual example

doc/modules/ROOT/pages/operations-reference/algorithm-references.adoc

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -366,11 +366,13 @@
366366
| `gds.beta.graphSage.write.estimate` label:procedure[Procedure]
367367
| `gds.beta.graphSage.train` label:procedure[Procedure]
368368
| `gds.beta.graphSage.train.estimate` label:procedure[Procedure]
369-
.4+<.^|xref:machine-learning/node-embeddings/hashgnn.adoc[HashGNN]
369+
.6+<.^|xref:machine-learning/node-embeddings/hashgnn.adoc[HashGNN]
370370
| `gds.hashgnn.mutate` label:procedure[Procedure]
371371
| `gds.hashgnn.mutate.estimate` label:procedure[Procedure]
372372
| `gds.hashgnn.stream` label:procedure[Procedure]
373373
| `gds.hashgnn.stream.estimate` label:procedure[Procedure]
374+
| `gds.hashgnn.write` label:procedure[Procedure]
375+
| `gds.hashgnn.write.estimate` label:procedure[Procedure]
374376
.6+<.^|xref:machine-learning/node-embeddings/node2vec.adoc[Node2Vec]
375377
| `gds.node2vec.mutate` label:procedure[Procedure]
376378
| `gds.node2vec.mutate.estimate` label:procedure[Procedure]

open-packaging/src/test/java/org/neo4j/gds/OpenGdsProcedureSmokeTest.java

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -151,6 +151,8 @@ class OpenGdsProcedureSmokeTest extends BaseProcTest {
151151
"gds.hashgnn.mutate.estimate",
152152
"gds.hashgnn.stream",
153153
"gds.hashgnn.stream.estimate",
154+
"gds.hashgnn.write",
155+
"gds.hashgnn.write.estimate",
154156

155157
"gds.beta.pipeline.linkPrediction.addFeature",
156158
"gds.beta.pipeline.linkPrediction.addNodeProperty",
@@ -600,7 +602,7 @@ void countShouldMatch() {
600602
);
601603

602604
// If you find yourself updating this count, please also update the count in SmokeTest.kt
603-
int expectedCount = 441;
605+
int expectedCount = 443;
604606
assertEquals(
605607
expectedCount,
606608
returnedRows,
Lines changed: 116 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,116 @@
1+
/*
2+
* Copyright (c) "Neo4j"
3+
* Neo4j Sweden AB [http://neo4j.com]
4+
*
5+
* This file is part of Neo4j.
6+
*
7+
* Neo4j is free software: you can redistribute it and/or modify
8+
* it under the terms of the GNU General Public License as published by
9+
* the Free Software Foundation, either version 3 of the License, or
10+
* (at your option) any later version.
11+
*
12+
* This program is distributed in the hope that it will be useful,
13+
* but WITHOUT ANY WARRANTY; without even the implied warranty of
14+
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
15+
* GNU General Public License for more details.
16+
*
17+
* You should have received a copy of the GNU General Public License
18+
* along with this program. If not, see <http://www.gnu.org/licenses/>.
19+
*/
20+
package org.neo4j.gds.embeddings.hashgnn;
21+
22+
import org.junit.jupiter.api.BeforeEach;
23+
import org.junit.jupiter.api.Test;
24+
import org.neo4j.gds.BaseProcTest;
25+
import org.neo4j.gds.GdsCypher;
26+
import org.neo4j.gds.catalog.GraphProjectProc;
27+
import org.neo4j.gds.extension.Neo4jGraph;
28+
29+
import java.util.List;
30+
import java.util.Map;
31+
32+
import static org.assertj.core.api.Assertions.assertThat;
33+
import static org.assertj.core.api.InstanceOfAssertFactories.LONG;
34+
35+
class HashGNNWriteProcTest extends BaseProcTest {
36+
37+
@Neo4jGraph
38+
private static final String DB_CYPHER =
39+
"CREATE" +
40+
" (a:N {f1: 1, f2: [0.0, 0.0]})" +
41+
", (b:N {f1: 0, f2: [1.0, 0.0]})" +
42+
", (c:N {f1: 0, f2: [0.0, 1.0]})" +
43+
", (b)-[:R1]->(a)" +
44+
", (b)-[:R2]->(c)";
45+
46+
static String expectedMutatedGraph = "CREATE" +
47+
" (a:N {f1: 1, f2: [0.0, 0.0], embedding: [1.0, 0.0, 0.0]})" +
48+
", (b:N {f1: 0, f2: [1.0, 0.0], embedding: [1.0, 0.0, 1.0]})" +
49+
", (c:N {f1: 0, f2: [0.0, 1.0], embedding: [0.0, 0.0, 1.0]})" +
50+
", (b)-[:R1]->(a)" +
51+
", (b)-[:R2]->(c)";
52+
53+
@BeforeEach
54+
void setupGraph() throws Exception {
55+
registerProcedures(
56+
HashGNNWriteProc.class,
57+
GraphProjectProc.class
58+
);
59+
60+
String graphCreateQuery = GdsCypher.call("graph")
61+
.graphProject()
62+
.withNodeLabel("N")
63+
.withNodeProperty("f1")
64+
.withNodeProperty("f2")
65+
.withRelationshipType("R1")
66+
.withRelationshipType("R2")
67+
.yields();
68+
runQuery(graphCreateQuery);
69+
}
70+
71+
@Test
72+
void shouldWrite() {
73+
var query = GdsCypher.call("graph").algo("hashgnn")
74+
.writeMode().addParameter("heterogeneous", true)
75+
.addParameter("iterations", 2)
76+
.addParameter("embeddingDensity", 2)
77+
.addParameter("randomSeed", 42L)
78+
.addParameter("featureProperties", List.of("f1", "f2"))
79+
.addParameter("writeProperty", "embedding")
80+
.yields();
81+
82+
var rowCount = runQueryWithRowConsumer(query, row -> {
83+
assertThat(row.getNumber("preProcessingMillis"))
84+
.asInstanceOf(LONG)
85+
.isGreaterThan(-1L);
86+
87+
assertThat(row.getNumber("computeMillis"))
88+
.asInstanceOf(LONG)
89+
.isGreaterThan(-1L);
90+
91+
assertThat(row.getNumber("writeMillis"))
92+
.asInstanceOf(LONG)
93+
.isGreaterThan(-1L);
94+
95+
assertThat(row.getNumber("nodePropertiesWritten"))
96+
.asInstanceOf(LONG)
97+
.isEqualTo(3L);
98+
99+
assertThat(row.getNumber("nodeCount"))
100+
.asInstanceOf(LONG)
101+
.isEqualTo(3L);
102+
103+
assertThat(row.get("configuration"))
104+
.isInstanceOf(Map.class);
105+
});
106+
107+
assertThat(rowCount).isEqualTo(1L);
108+
109+
var nodePropertiesWritten = runQueryWithRowConsumer("MATCH (n) RETURN size(n.embedding) AS size", row -> {
110+
assertThat(row.getNumber("size")).asInstanceOf(LONG).isEqualTo(3L);
111+
});
112+
113+
assertThat(nodePropertiesWritten).isEqualTo(3L);
114+
}
115+
116+
}

0 commit comments

Comments
 (0)