Skip to content

Commit 7d27120

Browse files
committed
Allow querying TaskStore over all users
This is intended for admins.
1 parent a912ba3 commit 7d27120

File tree

7 files changed

+117
-21
lines changed

7 files changed

+117
-21
lines changed

core/src/main/java/org/neo4j/gds/core/utils/progress/EmptyTaskStore.java

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,13 +35,23 @@ public void store(String username, JobId jobId, Task task) {}
3535
@Override
3636
public void remove(String username, JobId jobId) {}
3737

38+
@Override
39+
public Stream<UserTask> query() {
40+
return Stream.empty();
41+
}
42+
43+
@Override
44+
public Stream<UserTask> query(JobId jobId) {
45+
return Stream.empty();
46+
}
47+
3848
@Override
3949
public @NotNull Map<JobId, Task> query(String username) {
4050
return Map.of();
4151
}
4252

4353
@Override
44-
public Optional<Task> query(String username, JobId jobId) {
54+
public Optional<UserTask> query(String username, JobId jobId) {
4555
return Optional.empty();
4656
}
4757

core/src/main/java/org/neo4j/gds/core/utils/progress/GlobalTaskStore.java

Lines changed: 21 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -54,15 +54,34 @@ public void remove(String username, JobId jobId) {
5454
}
5555
}
5656

57+
@Override
58+
public Stream<UserTask> query() {
59+
return registeredTasks
60+
.entrySet()
61+
.stream()
62+
.flatMap(tasksPerUsers -> tasksPerUsers
63+
.getValue()
64+
.entrySet()
65+
.stream()
66+
.map(jobTask -> ImmutableUserTask.of(tasksPerUsers.getKey(), jobTask.getKey(), jobTask.getValue())));
67+
}
68+
69+
@Override
70+
public Stream<UserTask> query(JobId jobId) {
71+
return query().filter(userTask -> userTask.jobId().equals(jobId));
72+
}
73+
5774
@Override
5875
public @NotNull Map<JobId, Task> query(String username) {
5976
return registeredTasks.getOrDefault(username, Map.of());
6077
}
6178

6279
@Override
63-
public Optional<Task> query(String username, JobId jobId) {
80+
public Optional<UserTask> query(String username, JobId jobId) {
6481
if (registeredTasks.containsKey(username)) {
65-
return Optional.ofNullable(registeredTasks.get(username).get(jobId));
82+
return Optional
83+
.ofNullable(registeredTasks.get(username).get(jobId))
84+
.map(task -> ImmutableUserTask.of(username, jobId, task));
6685
}
6786
return Optional.empty();
6887
}

core/src/main/java/org/neo4j/gds/core/utils/progress/TaskStore.java

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
package org.neo4j.gds.core.utils.progress;
2121

2222
import org.jetbrains.annotations.NotNull;
23+
import org.neo4j.gds.annotation.ValueClass;
2324
import org.neo4j.gds.core.utils.progress.tasks.Task;
2425

2526
import java.util.Map;
@@ -32,14 +33,27 @@ public interface TaskStore {
3233

3334
void remove(String username, JobId jobId);
3435

36+
Stream<UserTask> query();
37+
38+
Stream<UserTask> query(JobId jobId);
39+
3540
@NotNull
3641
Map<JobId, Task> query(String username);
3742

38-
Optional<Task> query(String username, JobId jobId);
43+
Optional<UserTask> query(String username, JobId jobId);
3944

4045
Stream<Task> taskStream();
4146

4247
boolean isEmpty();
4348

4449
long taskCount();
50+
51+
@ValueClass
52+
interface UserTask {
53+
String username();
54+
55+
JobId jobId();
56+
57+
Task task();
58+
}
4559
}

core/src/test/java/org/neo4j/gds/core/utils/progress/GlobalTaskStoreTest.java

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,4 +58,15 @@ void shouldCountAcrossUsers() {
5858

5959
assertThat(taskStore.taskCount()).isEqualTo(3);
6060
}
61+
62+
@Test
63+
void shouldQueryMultipleUsers() {
64+
var taskStore = new GlobalTaskStore();
65+
taskStore.store("alice", new JobId("42"), Tasks.leaf("leaf"));
66+
taskStore.store("bob", new JobId("1337"), Tasks.leaf("other"));
67+
68+
assertThat(taskStore.query()).hasSize(2);
69+
assertThat(taskStore.query(new JobId("42"))).hasSize(1);
70+
assertThat(taskStore.query(new JobId(""))).hasSize(0);
71+
}
6172
}

proc/misc/src/main/java/org/neo4j/gds/ListProgressProc.java

Lines changed: 53 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@
3838
import java.util.ArrayList;
3939
import java.util.List;
4040
import java.util.Map;
41+
import java.util.stream.Collectors;
4142
import java.util.stream.Stream;
4243

4344
import static org.neo4j.gds.utils.StringFormatting.formatWithLocale;
@@ -60,21 +61,53 @@ public Stream<ProgressResult> listProgress(
6061
}
6162

6263
private Stream<ProgressResult> jobsSummaryView() {
63-
return taskStore.query(username()).entrySet().stream().map(ProgressResult::fromTaskStoreEntry);
64+
if (isGdsAdmin()) {
65+
return taskStore.query().map(ProgressResult::fromTaskStoreEntry);
66+
} else {
67+
return taskStore
68+
.query(username())
69+
.entrySet()
70+
.stream()
71+
.map(entry -> ProgressResult.fromTaskStoreEntry(username(), entry));
72+
}
6473
}
6574

6675
private Stream<ProgressResult> jobDetailView(String jobIdAsString) {
6776
var jobId = new JobId(jobIdAsString);
68-
var task = taskStore.query(username(), jobId).orElseThrow(
69-
() -> new IllegalArgumentException(formatWithLocale("No task with job id `%s` was found.", jobIdAsString))
70-
);
71-
var jobProgressVisitor = new JobProgressVisitor(jobId);
72-
TaskTraversal.visitPreOrderWithDepth(task, jobProgressVisitor);
77+
78+
if (isGdsAdmin()) {
79+
var progressResults = taskStore
80+
.query(jobId)
81+
.flatMap(ListProgressProc::jobProgress)
82+
.collect(Collectors.toList());
83+
84+
if (progressResults.isEmpty()) {
85+
throw new IllegalArgumentException(formatWithLocale(
86+
"No task with job id `%s` was found.",
87+
jobIdAsString
88+
));
89+
}
90+
91+
return progressResults.stream();
92+
} else {
93+
return taskStore.query(username(), jobId).map(ListProgressProc::jobProgress).orElseThrow(
94+
() -> new IllegalArgumentException(formatWithLocale(
95+
"No task with job id `%s` was found.",
96+
jobIdAsString
97+
))
98+
);
99+
}
100+
}
101+
102+
private static Stream<ProgressResult> jobProgress(TaskStore.UserTask userTask) {
103+
var jobProgressVisitor = new JobProgressVisitor(userTask.jobId(), userTask.username());
104+
TaskTraversal.visitPreOrderWithDepth(userTask.task(), jobProgressVisitor);
73105
return jobProgressVisitor.progressRowsStream();
74106
}
75107

76108
@SuppressWarnings("unused")
77109
public static class ProgressResult {
110+
public String username;
78111
public String jobId;
79112
public String taskName;
80113
public String progress;
@@ -83,22 +116,27 @@ public static class ProgressResult {
83116
public LocalTimeValue timeStarted;
84117
public String elapsedTime;
85118

86-
static ProgressResult fromTaskStoreEntry(Map.Entry<JobId, Task> taskStoreEntry) {
119+
static ProgressResult fromTaskStoreEntry(String username, Map.Entry<JobId, Task> taskStoreEntry) {
87120
var jobId = taskStoreEntry.getKey();
88121
var task = taskStoreEntry.getValue();
89-
return new ProgressResult(task, jobId, task.description());
122+
return new ProgressResult(username, task, jobId, task.description());
123+
}
124+
125+
static ProgressResult fromTaskStoreEntry(TaskStore.UserTask userTask) {
126+
return new ProgressResult(userTask.username(), userTask.task(), userTask.jobId(), userTask.task().description());
90127
}
91128

92-
static ProgressResult fromTaskWithDepth(Task task, JobId jobId, int depth) {
129+
static ProgressResult fromTaskWithDepth(String username, Task task, JobId jobId, int depth) {
93130
var treeViewTaskName = StructuredOutputHelper.treeViewDescription(task.description(), depth);
94-
return new ProgressResult(task, jobId, treeViewTaskName);
131+
return new ProgressResult(username, task, jobId, treeViewTaskName);
95132
}
96133

97-
public ProgressResult(Task task, JobId jobId, String taskName) {
134+
public ProgressResult(String username, Task task, JobId jobId, String taskName) {
98135
var progressContainer = task.getProgress();
99136

100137
this.jobId = jobId.asString();
101138
this.taskName = taskName;
139+
this.username = username;
102140
this.progress = StructuredOutputHelper.computeProgress(progressContainer);
103141
this.progressBar = StructuredOutputHelper.progressBar(progressContainer, PROGRESS_BAR_LENGTH);
104142
this.status = task.status().name();
@@ -132,10 +170,12 @@ private String prettyElapsedTime(Task task) {
132170
public static class JobProgressVisitor extends DepthAwareTaskVisitor {
133171

134172
private final JobId jobId;
173+
private final String username;
135174
private final List<ProgressResult> progressRows;
136175

137-
JobProgressVisitor(JobId jobId) {
176+
JobProgressVisitor(JobId jobId, String username) {
138177
this.jobId = jobId;
178+
this.username = username;
139179
this.progressRows = new ArrayList<>();
140180
}
141181

@@ -145,7 +185,7 @@ Stream<ProgressResult> progressRowsStream() {
145185

146186
@Override
147187
public void visit(Task task) {
148-
progressRows.add(ProgressResult.fromTaskWithDepth(task, jobId, depth()));
188+
progressRows.add(ProgressResult.fromTaskWithDepth(username, task, jobId, depth()));
149189
}
150190
}
151191
}

proc/misc/src/test/java/org/neo4j/gds/ListProgressProcTest.java

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
import org.junit.jupiter.api.Test;
2424
import org.neo4j.gds.beta.generator.GraphGenerateProc;
2525
import org.neo4j.gds.compat.GraphDatabaseApiProxy;
26+
import org.neo4j.gds.core.Username;
2627
import org.neo4j.gds.core.utils.RenamesCurrentThread;
2728
import org.neo4j.gds.core.utils.progress.JobId;
2829
import org.neo4j.gds.embeddings.fastrp.FastRP;
@@ -72,11 +73,12 @@ void canListProgressEvent() {
7273
runQuery("CALL gds.test.pl('foo')");
7374
assertCypherResult(
7475
"CALL gds.beta.listProgress() " +
75-
"YIELD taskName, progress, progressBar, status, timeStarted, elapsedTime " +
76-
"RETURN taskName, progress, progressBar, status, timeStarted, elapsedTime ",
76+
"YIELD username, taskName, progress, progressBar, status, timeStarted, elapsedTime " +
77+
"RETURN username, taskName, progress, progressBar, status, timeStarted, elapsedTime ",
7778
List.of(
7879
Map.of(
7980
"taskName","foo",
81+
"username", Username.EMPTY_USERNAME.username(),
8082
"progress", "33.33%",
8183
"progressBar", "[###~~~~~~~]",
8284
"status", "RUNNING",

test-utils/src/main/java/org/neo4j/gds/InspectableTestProgressTracker.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -77,13 +77,13 @@ public void logProgress(long progress) {
7777
@Override
7878
public void beginSubTask() {
7979
super.beginSubTask();
80-
progressHistory.add(taskStore.query(userName, jobId).map(Task::getProgress));
80+
progressHistory.add(taskStore.query(userName, jobId).map(userTask -> userTask.task().getProgress()));
8181
}
8282

8383
@Override
8484
public void endSubTask() {
8585
super.endSubTask();
86-
progressHistory.add(taskStore.query(userName, jobId).map(Task::getProgress));
86+
progressHistory.add(taskStore.query(userName, jobId).map(userTask -> userTask.task().getProgress()));
8787
}
8888

8989
@Override

0 commit comments

Comments
 (0)