Skip to content

Commit f5de17d

Browse files
Add procedures
1 parent a0689e3 commit f5de17d

File tree

9 files changed

+286
-21
lines changed

9 files changed

+286
-21
lines changed

applications/algorithms/machinery/src/main/java/org/neo4j/gds/applications/algorithms/machinery/MemoryGuard.java

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,6 @@ public <CONFIGURATION extends AlgoBaseConfig> void assertAlgorithmCanRun(
4242
Label label,
4343
DimensionTransformer dimensionTransformer
4444
) {
45-
// do nothing
4645
}
4746
};
4847

Lines changed: 150 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,150 @@
1+
/*
2+
* Copyright (c) "Neo4j"
3+
* Neo4j Sweden AB [http://neo4j.com]
4+
*
5+
* This file is part of Neo4j.
6+
*
7+
* Neo4j is free software: you can redistribute it and/or modify
8+
* it under the terms of the GNU General Public License as published by
9+
* the Free Software Foundation, either version 3 of the License, or
10+
* (at your option) any later version.
11+
*
12+
* This program is distributed in the hope that it will be useful,
13+
* but WITHOUT ANY WARRANTY; without even the implied warranty of
14+
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
15+
* GNU General Public License for more details.
16+
*
17+
* You should have received a copy of the GNU General Public License
18+
* along with this program. If not, see <http://www.gnu.org/licenses/>.
19+
*/
20+
package org.neo4j.gds;
21+
22+
import org.junit.jupiter.api.BeforeEach;
23+
import org.junit.jupiter.api.Test;
24+
import org.neo4j.gds.beta.generator.GraphGenerateProc;
25+
import org.neo4j.gds.compat.GraphDatabaseApiProxy;
26+
import org.neo4j.gds.memory.MemoryProc;
27+
28+
import static org.assertj.core.api.Assertions.assertThat;
29+
import static org.assertj.core.api.InstanceOfAssertFactories.LONG;
30+
31+
class MemoryProcTest extends BaseProgressTest {
32+
33+
@BeforeEach
34+
void setUp() throws Exception {
35+
GraphDatabaseApiProxy.registerProcedures(
36+
db,
37+
MemoryProc.class,
38+
BaseProgressTestProc.class,
39+
GraphGenerateProc.class
40+
);
41+
42+
}
43+
44+
@Test
45+
void shouldReturnEmptyIfEmpty() {
46+
var rowCount = runQueryWithRowConsumer("CALL gds.listMemory", result->{} );
47+
assertThat(rowCount).isEqualTo(0L);
48+
}
49+
50+
@Test
51+
void shouldReturnEmptySummary() {
52+
var rowCount = runQueryWithRowConsumer("alice",
53+
"CALL gds.listMemory.summary",
54+
resultRow -> {
55+
assertThat(resultRow.getString("user")).isEqualTo("alice");
56+
assertThat(resultRow.getNumber("totalGraphMemory")).asInstanceOf(LONG).isEqualTo(0);
57+
assertThat(resultRow.getNumber("totalTasksMemory")).asInstanceOf(LONG).isEqualTo(0);
58+
});
59+
assertThat(rowCount).isEqualTo(1L);
60+
}
61+
62+
@Test
63+
void shouldListGraphsAccordingly() {
64+
runQuery("alice", " CALL gds.graph.generate('random',10,1)");
65+
var rowCountAlice = runQueryWithRowConsumer("alice",
66+
"CALL gds.listMemory",
67+
resultRow -> {
68+
assertThat(resultRow.getString("user")).isEqualTo("alice");
69+
assertThat(resultRow.getString("name")).isEqualTo("random");
70+
assertThat(resultRow.getString("entity")).isEqualTo("graph");
71+
assertThat(resultRow.getNumber("memoryInBytes")).asInstanceOf(LONG).isGreaterThan(0);
72+
});
73+
assertThat(rowCountAlice).isEqualTo(1L);
74+
75+
var rowCountBob = runQueryWithRowConsumer("bob",
76+
"CALL gds.listMemory",resultRow ->{}
77+
);
78+
79+
assertThat(rowCountBob).isEqualTo(0L);
80+
}
81+
82+
@Test
83+
void shouldSummarizeAccordingly() {
84+
runQuery("alice", " CALL gds.graph.generate('random',10,1)");
85+
var rowCountAlice = runQueryWithRowConsumer("alice",
86+
"CALL gds.listMemory",
87+
resultRow -> {
88+
assertThat(resultRow.getString("user")).isEqualTo("alice");
89+
assertThat(resultRow.getString("name")).isEqualTo("random");
90+
assertThat(resultRow.getString("entity")).isEqualTo("graph");
91+
assertThat(resultRow.getNumber("memoryInBytes")).asInstanceOf(LONG).isGreaterThan(0);
92+
93+
});
94+
assertThat(rowCountAlice).isEqualTo(1L);
95+
96+
var rowCountBob = runQueryWithRowConsumer("bob",
97+
"CALL gds.listMemory.summary",
98+
resultRow -> {
99+
assertThat(resultRow.getString("user")).isEqualTo("bob");
100+
assertThat(resultRow.getNumber("totalGraphMemory")).asInstanceOf(LONG).isEqualTo(0);
101+
assertThat(resultRow.getNumber("totalTasksMemory")).asInstanceOf(LONG).isEqualTo(0);
102+
});
103+
104+
assertThat(rowCountBob).isEqualTo(1L);
105+
}
106+
107+
@Test
108+
void canListRunningTask() {
109+
runQuery("alice","CALL gds.test.pl('foo',true,false)");
110+
var rowCountAlice = runQueryWithRowConsumer("alice",
111+
"CALL gds.listMemory",
112+
resultRow -> {
113+
assertThat(resultRow.getString("user")).isEqualTo("alice");
114+
assertThat(resultRow.getString("entity")).isNotEqualTo("graph");
115+
assertThat(resultRow.getNumber("memoryInBytes")).asInstanceOf(LONG).isGreaterThan(0);
116+
});
117+
assertThat(rowCountAlice).isEqualTo(1L);
118+
119+
var rowCountBob = runQueryWithRowConsumer("bob",
120+
"CALL gds.listMemory",resultRow ->{}
121+
);
122+
123+
assertThat(rowCountBob).isEqualTo(0L);
124+
}
125+
126+
@Test
127+
void canSummarizeRunningTask() {
128+
runQuery("alice","CALL gds.test.pl('foo',true,false)");
129+
130+
var rowCountAlice = runQueryWithRowConsumer("alice",
131+
"CALL gds.listMemory.summary",
132+
resultRow -> {
133+
assertThat(resultRow.getString("user")).isEqualTo("alice");
134+
assertThat(resultRow.getNumber("totalGraphMemory")).asInstanceOf(LONG).isEqualTo(0);
135+
assertThat(resultRow.getNumber("totalTasksMemory")).asInstanceOf(LONG).isGreaterThan(0);
136+
});
137+
assertThat(rowCountAlice).isEqualTo(1L);
138+
139+
var rowCountBob = runQueryWithRowConsumer("bob",
140+
"CALL gds.listMemory.summary",
141+
resultRow -> {
142+
assertThat(resultRow.getString("user")).isEqualTo("bob");
143+
assertThat(resultRow.getNumber("totalGraphMemory")).asInstanceOf(LONG).isEqualTo(0);
144+
assertThat(resultRow.getNumber("totalTasksMemory")).asInstanceOf(LONG).isEqualTo(0);
145+
});
146+
147+
assertThat(rowCountBob).isEqualTo(1L);
148+
}
149+
150+
}

proc/misc/src/main/java/org/neo4j/gds/memory/MemoryProc.java

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,9 @@
1919
*/
2020
package org.neo4j.gds.memory;
2121

22-
import org.neo4j.gds.procedures.operations.ProgressResult;
22+
import org.neo4j.gds.mem.UserEntityMemory;
23+
import org.neo4j.gds.mem.UserMemorySummary;
24+
import org.neo4j.gds.procedures.memory.MemoryFacade;
2325
import org.neo4j.procedure.Context;
2426
import org.neo4j.procedure.Description;
2527
import org.neo4j.procedure.Procedure;
@@ -34,7 +36,13 @@ public class MemoryProc {
3436

3537
@Procedure("gds.listMemory")
3638
@Description(DESCRIPTION)
37-
public Stream<ProgressResult> listMemory() {
38-
return facade.operations().listProgress(jobId);
39+
public Stream<UserEntityMemory> listMemory() {
40+
return facade.list();
41+
}
42+
43+
@Procedure("gds.listMemory.summary")
44+
@Description(DESCRIPTION)
45+
public Stream<UserMemorySummary> summary() {
46+
return facade.memorySummary();
3947
}
4048
}

proc/misc/src/test/java/org/neo4j/gds/BaseProgressTest.java

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,13 +20,15 @@
2020
package org.neo4j.gds;
2121

2222
import org.neo4j.gds.core.concurrency.Concurrency;
23+
import org.neo4j.gds.core.utils.progress.JobId;
2324
import org.neo4j.gds.core.utils.progress.ProgressFeatureSettings;
2425
import org.neo4j.gds.core.utils.progress.TaskRegistryExtension;
2526
import org.neo4j.gds.core.utils.progress.TaskRegistryFactory;
2627
import org.neo4j.gds.core.utils.progress.tasks.TaskProgressTracker;
2728
import org.neo4j.gds.core.utils.progress.tasks.Tasks;
2829
import org.neo4j.gds.logging.Log;
2930
import org.neo4j.gds.mem.MemoryRange;
31+
import org.neo4j.gds.procedures.memory.MemoryFacade;
3032
import org.neo4j.procedure.Context;
3133
import org.neo4j.procedure.Name;
3234
import org.neo4j.procedure.Procedure;
@@ -55,6 +57,9 @@ public static class BaseProgressTestProc {
5557
@Context
5658
public TaskRegistryFactory taskRegistryFactory;
5759

60+
@Context
61+
public MemoryFacade memoryFacade;
62+
5863
@Procedure("gds.test.pl")
5964
public Stream<Bar> foo(
6065
@Name(value = "taskName") String taskName,
@@ -64,18 +69,21 @@ public Stream<Bar> foo(
6469
var task = Tasks.task(taskName, Tasks.leaf("leaf", 3));
6570
if (withMemoryEstimation) {
6671
task.setEstimatedMemoryRangeInBytes(MEMORY_ESTIMATION_RANGE);
72+
73+
memoryFacade.track(task.description(),new JobId(),task.estimatedMemoryRangeInBytes().max);
6774
}
75+
6876
if (withConcurrency) {
6977
task.setMaxConcurrency(new Concurrency(REQUESTED_CPU_CORES));
7078
}
79+
7180
var taskProgressTracker = new TaskProgressTracker(task, Log.noOpLog(), new Concurrency(1), taskRegistryFactory);
7281
taskProgressTracker.beginSubTask();
7382
taskProgressTracker.beginSubTask();
7483
taskProgressTracker.logProgress(1);
7584
return Stream.empty();
7685
}
7786

78-
7987
}
8088

8189
public static class Bar {

procedures/integration/src/main/java/org/neo4j/gds/procedures/integration/OpenGraphDataScienceExtensionBuilder.java

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -36,12 +36,13 @@
3636
import org.neo4j.gds.core.utils.warnings.UserLogRegistryFactory;
3737
import org.neo4j.gds.logging.Log;
3838
import org.neo4j.gds.mem.MemoryTracker;
39-
import org.neo4j.gds.memory.MemoryFacade;
4039
import org.neo4j.gds.metrics.Metrics;
4140
import org.neo4j.gds.procedures.ExporterBuildersProviderService;
4241
import org.neo4j.gds.procedures.GraphDataScienceProcedures;
4342
import org.neo4j.gds.procedures.TaskRegistryFactoryService;
43+
import org.neo4j.gds.procedures.UserAccessor;
4444
import org.neo4j.gds.procedures.UserLogServices;
45+
import org.neo4j.gds.procedures.memory.MemoryFacade;
4546
import org.neo4j.gds.settings.GdsSettings;
4647
import org.neo4j.graphdb.config.Configuration;
4748
import org.neo4j.kernel.api.procedure.GlobalProcedures;
@@ -156,9 +157,11 @@ public static Triple<OpenGraphDataScienceExtensionBuilder, TaskRegistryFactorySe
156157

157158
var componentRegistration = new ComponentRegistration(log, globalProcedures);
158159

159-
var memoryFacade = new MemoryFacade(memoryTracker);
160-
161-
componentRegistration.registerComponent("GDS Memory Facade", MemoryFacade.class, __ -> memoryFacade);
160+
componentRegistration.registerComponent("GDS Memory Facade", MemoryFacade.class, context -> {
161+
var userAccessor = new UserAccessor();
162+
var user = userAccessor.getUser(context.securityContext());
163+
return new MemoryFacade(user,memoryTracker);
164+
});
162165

163166
var graphDataScienceProviderFactory = new GraphDataScienceProceduresProviderFactory(
164167
log,

procedures/memory-facade/build.gradle

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,4 +8,7 @@ dependencies {
88

99
implementation project(':progress-tracking')
1010

11+
testImplementation project(':test-utils')
12+
13+
1114
}

procedures/memory-facade/src/main/java/org/neo4j/gds/memory/MemoryFacade.java renamed to procedures/memory-facade/src/main/java/org/neo4j/gds/procedures/memory/MemoryFacade.java

Lines changed: 19 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,10 @@
1717
* You should have received a copy of the GNU General Public License
1818
* along with this program. If not, see <http://www.gnu.org/licenses/>.
1919
*/
20-
package org.neo4j.gds.memory;
20+
package org.neo4j.gds.procedures.memory;
2121

22+
import org.neo4j.gds.api.User;
23+
import org.neo4j.gds.core.utils.progress.JobId;
2224
import org.neo4j.gds.mem.MemoryTracker;
2325
import org.neo4j.gds.mem.UserEntityMemory;
2426
import org.neo4j.gds.mem.UserMemorySummary;
@@ -28,26 +30,32 @@
2830
public class MemoryFacade {
2931

3032
private final MemoryTracker memoryTracker;
33+
private final User user;
3134

32-
public MemoryFacade(MemoryTracker memoryTracker){
35+
public MemoryFacade(User user,MemoryTracker memoryTracker){
3336
this.memoryTracker = memoryTracker;
37+
this.user = user;
3438
}
3539

36-
public Stream<UserEntityMemory> listUser(String user){
37-
return memoryTracker.listUser(user);
40+
public void track(String taskName, JobId jobId, long memoryEstimate) {
41+
memoryTracker.track(user.getUsername(), taskName,jobId,memoryEstimate);
3842
}
3943

40-
public Stream<UserEntityMemory> listAll(){
41-
return memoryTracker.listAll();
44+
public Stream<UserEntityMemory> list() {
45+
if (user.isAdmin()){
46+
return memoryTracker.listAll();
47+
}else{
48+
return memoryTracker.listUser(user.getUsername());
49+
}
4250
}
4351

44-
public UserMemorySummary memorySummary(String user){
45-
return memoryTracker.memorySummary(user);
46-
47-
}
4852

4953
public Stream<UserMemorySummary> memorySummary() {
50-
return memoryTracker.memorySummary();
54+
if (user.isAdmin()){
55+
return memoryTracker.memorySummary();
56+
}else{
57+
return Stream.of(memoryTracker.memorySummary(user.getUsername()));
58+
}
5159
}
5260

5361
}

0 commit comments

Comments
 (0)