Skip to content

Commit bd15381

Browse files
committed
Implement NodeEmbeddingComputeFacade
1 parent 6efff34 commit bd15381

File tree

7 files changed

+561
-5
lines changed

7 files changed

+561
-5
lines changed

algo/src/main/java/org/neo4j/gds/embeddings/fastrp/FastRPResult.java

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,5 +21,8 @@
2121

2222
import org.neo4j.gds.collections.ha.HugeObjectArray;
2323

24-
public record FastRPResult(HugeObjectArray<float[]> embeddings){}
25-
24+
public record FastRPResult(HugeObjectArray<float[]> embeddings) {
25+
public static FastRPResult empty() {
26+
return new FastRPResult(HugeObjectArray.newArray(float[].class, 0L));
27+
}
28+
}

algo/src/main/java/org/neo4j/gds/embeddings/hashgnn/HashGNNResult.java

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,12 @@
1919
*/
2020
package org.neo4j.gds.embeddings.hashgnn;
2121

22+
import org.neo4j.gds.api.properties.nodes.EmptyDoubleArrayNodePropertyValues;
2223
import org.neo4j.gds.api.properties.nodes.NodePropertyValues;
2324

24-
public record HashGNNResult(NodePropertyValues embeddings) {}
25+
public record HashGNNResult(NodePropertyValues embeddings) {
26+
27+
public static HashGNNResult empty() {
28+
return new HashGNNResult(EmptyDoubleArrayNodePropertyValues.INSTANCE);
29+
}
30+
}

algo/src/main/java/org/neo4j/gds/embeddings/node2vec/Node2VecResult.java

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,17 @@
2222
import org.neo4j.gds.collections.ha.HugeObjectArray;
2323
import org.neo4j.gds.ml.core.tensor.FloatVector;
2424

25+
import java.util.Collections;
2526
import java.util.List;
2627

27-
public record Node2VecResult(HugeObjectArray<FloatVector> embeddings,List<Double> lossPerIteration)
28-
{ }
28+
public record Node2VecResult(
29+
HugeObjectArray<FloatVector> embeddings,
30+
List<Double> lossPerIteration
31+
) {
32+
public static Node2VecResult empty() {
33+
return new Node2VecResult(
34+
HugeObjectArray.newArray(FloatVector.class, 0L),
35+
Collections.emptyList()
36+
);
37+
}
38+
}

algorithms-compute-facade/build.gradle

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@ dependencies {
1212
implementation project(':core')
1313
implementation project(':core-api')
1414
implementation project(':logging')
15+
implementation project(':ml-core')
16+
implementation project(':node-embeddings-params')
1517
implementation project(':path-finding-algorithms')
1618
implementation project(':path-finding-params')
1719
implementation project(':progress-tracking')
Lines changed: 163 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,163 @@
1+
/*
2+
* Copyright (c) "Neo4j"
3+
* Neo4j Sweden AB [http://neo4j.com]
4+
*
5+
* This file is part of Neo4j.
6+
*
7+
* Neo4j is free software: you can redistribute it and/or modify
8+
* it under the terms of the GNU General Public License as published by
9+
* the Free Software Foundation, either version 3 of the License, or
10+
* (at your option) any later version.
11+
*
12+
* This program is distributed in the hope that it will be useful,
13+
* but WITHOUT ANY WARRANTY; without even the implied warranty of
14+
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
15+
* GNU General Public License for more details.
16+
*
17+
* You should have received a copy of the GNU General Public License
18+
* along with this program. If not, see <http://www.gnu.org/licenses/>.
19+
*/
20+
package org.neo4j.gds.embeddings;
21+
22+
import org.neo4j.gds.NodeEmbeddingsAlgorithmTasks;
23+
import org.neo4j.gds.ProgressTrackerFactory;
24+
import org.neo4j.gds.api.Graph;
25+
import org.neo4j.gds.async.AsyncAlgorithmCaller;
26+
import org.neo4j.gds.core.utils.progress.JobId;
27+
import org.neo4j.gds.embeddings.fastrp.FastRP;
28+
import org.neo4j.gds.embeddings.fastrp.FastRPParameters;
29+
import org.neo4j.gds.embeddings.fastrp.FastRPResult;
30+
import org.neo4j.gds.embeddings.hashgnn.HashGNN;
31+
import org.neo4j.gds.embeddings.hashgnn.HashGNNParameters;
32+
import org.neo4j.gds.embeddings.hashgnn.HashGNNResult;
33+
import org.neo4j.gds.embeddings.node2vec.Node2Vec;
34+
import org.neo4j.gds.embeddings.node2vec.Node2VecParameters;
35+
import org.neo4j.gds.embeddings.node2vec.Node2VecResult;
36+
import org.neo4j.gds.ml.core.features.FeatureExtraction;
37+
import org.neo4j.gds.result.TimedAlgorithmResult;
38+
import org.neo4j.gds.termination.TerminationFlag;
39+
40+
import java.util.List;
41+
import java.util.concurrent.CompletableFuture;
42+
43+
public class NodeEmbeddingComputeFacade {
44+
45+
// Global dependencies
46+
// This is created with its own ExecutorService workerPool,
47+
// which determines how many algorithms can run in parallel.
48+
private final AsyncAlgorithmCaller algorithmCaller;
49+
private final ProgressTrackerFactory progressTrackerFactory;
50+
51+
// Request scope dependencies
52+
private final TerminationFlag terminationFlag;
53+
54+
// Local dependencies
55+
private final NodeEmbeddingsAlgorithmTasks tasks = new NodeEmbeddingsAlgorithmTasks();
56+
57+
58+
public NodeEmbeddingComputeFacade(
59+
AsyncAlgorithmCaller algorithmCaller,
60+
ProgressTrackerFactory progressTrackerFactory,
61+
TerminationFlag terminationFlag
62+
) {
63+
this.algorithmCaller = algorithmCaller;
64+
this.progressTrackerFactory = progressTrackerFactory;
65+
this.terminationFlag = terminationFlag;
66+
}
67+
68+
69+
public CompletableFuture<TimedAlgorithmResult<FastRPResult>> fastRP(
70+
Graph graph,
71+
FastRPParameters parameters,
72+
JobId jobId,
73+
boolean logProgress
74+
) {
75+
if (graph.isEmpty()) {
76+
return CompletableFuture.completedFuture(TimedAlgorithmResult.empty(FastRPResult.empty()));
77+
}
78+
79+
var progressTracker = progressTrackerFactory.create(
80+
tasks.fastRP(graph, parameters),
81+
jobId,
82+
parameters.concurrency(),
83+
logProgress
84+
);
85+
86+
var featureExtractors = FeatureExtraction.propertyExtractors(graph, parameters.featureProperties());
87+
88+
var fastRP = new FastRP(
89+
graph,
90+
parameters,
91+
10_000,
92+
featureExtractors,
93+
progressTracker,
94+
terminationFlag
95+
);
96+
97+
return algorithmCaller.run(
98+
fastRP::compute,
99+
jobId
100+
);
101+
}
102+
103+
public CompletableFuture<TimedAlgorithmResult<HashGNNResult>> hashGnn(
104+
Graph graph,
105+
HashGNNParameters parameters,
106+
List<String> relationshipTypes,
107+
JobId jobId,
108+
boolean logProgress
109+
) {
110+
if (graph.isEmpty()) {
111+
return CompletableFuture.completedFuture(TimedAlgorithmResult.empty(HashGNNResult.empty()));
112+
}
113+
114+
var progressTracker = progressTrackerFactory.create(
115+
tasks.hashGNN(graph, parameters, relationshipTypes),
116+
jobId,
117+
parameters.concurrency(),
118+
logProgress
119+
);
120+
121+
var hashGNN = new HashGNN(
122+
graph,
123+
parameters,
124+
progressTracker,
125+
terminationFlag
126+
);
127+
128+
return algorithmCaller.run(
129+
hashGNN::compute,
130+
jobId
131+
);
132+
}
133+
134+
public CompletableFuture<TimedAlgorithmResult<Node2VecResult>> node2Vec(
135+
Graph graph,
136+
Node2VecParameters parameters,
137+
JobId jobId,
138+
boolean logProgress
139+
) {
140+
if (graph.isEmpty()) {
141+
return CompletableFuture.completedFuture(TimedAlgorithmResult.empty(Node2VecResult.empty()));
142+
}
143+
144+
var progressTracker = progressTrackerFactory.create(
145+
tasks.node2Vec(graph, parameters),
146+
jobId,
147+
parameters.concurrency(),
148+
logProgress
149+
);
150+
151+
var node2Vec = Node2Vec.create(
152+
graph,
153+
parameters,
154+
progressTracker,
155+
terminationFlag
156+
);
157+
158+
return algorithmCaller.run(
159+
node2Vec::compute,
160+
jobId
161+
);
162+
}
163+
}
Lines changed: 118 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,118 @@
1+
/*
2+
* Copyright (c) "Neo4j"
3+
* Neo4j Sweden AB [http://neo4j.com]
4+
*
5+
* This file is part of Neo4j.
6+
*
7+
* Neo4j is free software: you can redistribute it and/or modify
8+
* it under the terms of the GNU General Public License as published by
9+
* the Free Software Foundation, either version 3 of the License, or
10+
* (at your option) any later version.
11+
*
12+
* This program is distributed in the hope that it will be useful,
13+
* but WITHOUT ANY WARRANTY; without even the implied warranty of
14+
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
15+
* GNU General Public License for more details.
16+
*
17+
* You should have received a copy of the GNU General Public License
18+
* along with this program. If not, see <http://www.gnu.org/licenses/>.
19+
*/
20+
package org.neo4j.gds.embeddings;
21+
22+
import org.junit.jupiter.api.BeforeEach;
23+
import org.junit.jupiter.api.Test;
24+
import org.junit.jupiter.api.extension.ExtendWith;
25+
import org.mockito.Mock;
26+
import org.mockito.junit.jupiter.MockitoExtension;
27+
import org.neo4j.gds.ProgressTrackerFactory;
28+
import org.neo4j.gds.api.Graph;
29+
import org.neo4j.gds.async.AsyncAlgorithmCaller;
30+
import org.neo4j.gds.core.utils.progress.JobId;
31+
import org.neo4j.gds.embeddings.fastrp.FastRPParameters;
32+
import org.neo4j.gds.embeddings.hashgnn.HashGNNParameters;
33+
import org.neo4j.gds.embeddings.node2vec.Node2VecParameters;
34+
import org.neo4j.gds.termination.TerminationFlag;
35+
36+
import java.util.List;
37+
38+
import static org.assertj.core.api.Assertions.assertThat;
39+
import static org.mockito.Mockito.mock;
40+
import static org.mockito.Mockito.verifyNoInteractions;
41+
import static org.mockito.Mockito.when;
42+
43+
@ExtendWith(MockitoExtension.class)
44+
class NodeEmbeddingComputeFacadeEmptyGraphTest {
45+
@Mock
46+
private Graph graph;
47+
48+
@Mock
49+
private ProgressTrackerFactory progressTrackerFactoryMock;
50+
51+
@Mock
52+
private AsyncAlgorithmCaller algorithmCallerMock;
53+
54+
@Mock
55+
private JobId jobIdMock;
56+
57+
private NodeEmbeddingComputeFacade facade;
58+
59+
@BeforeEach
60+
void setUp() {
61+
when(graph.isEmpty()).thenReturn(true);
62+
facade = new NodeEmbeddingComputeFacade(
63+
algorithmCallerMock,
64+
progressTrackerFactoryMock,
65+
TerminationFlag.RUNNING_TRUE
66+
);
67+
}
68+
69+
@Test
70+
void fastRP() {
71+
var future = facade.fastRP(
72+
graph,
73+
mock(FastRPParameters.class),
74+
jobIdMock,
75+
true
76+
);
77+
var result = future.join();
78+
assertThat(result.result().embeddings().size()).isZero();
79+
80+
verifyNoInteractions(progressTrackerFactoryMock);
81+
verifyNoInteractions(algorithmCallerMock);
82+
}
83+
84+
@Test
85+
void hashGnn() {
86+
var future = facade.hashGnn(
87+
graph,
88+
mock(HashGNNParameters.class),
89+
List.of("R"),
90+
jobIdMock,
91+
true
92+
);
93+
var result = future.join();
94+
var algorithmResult = result.result();
95+
assertThat(algorithmResult.embeddings().nodeCount()).isZero();
96+
97+
verifyNoInteractions(progressTrackerFactoryMock);
98+
verifyNoInteractions(algorithmCallerMock);
99+
}
100+
101+
@Test
102+
void node2Vec() {
103+
var future = facade.node2Vec(
104+
graph,
105+
mock(Node2VecParameters.class),
106+
jobIdMock,
107+
true
108+
);
109+
var result = future.join();
110+
var algorithmResult = result.result();
111+
assertThat(algorithmResult.embeddings().size()).isZero();
112+
assertThat(algorithmResult.lossPerIteration()).isEmpty();
113+
114+
verifyNoInteractions(progressTrackerFactoryMock);
115+
verifyNoInteractions(algorithmCallerMock);
116+
}
117+
118+
}

0 commit comments

Comments
 (0)