Skip to content

Commit 23c866f

Browse files
Add ProgressTracker that doesn't log inner tasks progress
Co-authored-by: Ioannis Panagiotas <ioannis.panagiotas@neotechnology.com>
1 parent c204aed commit 23c866f

File tree

14 files changed

+198
-61
lines changed

14 files changed

+198
-61
lines changed

algo-common/src/main/java/org/neo4j/gds/AlgorithmFactory.java

Lines changed: 29 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,13 +19,15 @@
1919
*/
2020
package org.neo4j.gds;
2121

22+
import org.jetbrains.annotations.NotNull;
2223
import org.neo4j.gds.config.AlgoBaseConfig;
2324
import org.neo4j.gds.core.GraphDimensions;
2425
import org.neo4j.gds.core.utils.mem.MemoryEstimation;
2526
import org.neo4j.gds.core.utils.progress.TaskRegistryFactory;
2627
import org.neo4j.gds.core.utils.progress.tasks.ProgressTracker;
2728
import org.neo4j.gds.core.utils.progress.tasks.Task;
2829
import org.neo4j.gds.core.utils.progress.tasks.TaskProgressTracker;
30+
import org.neo4j.gds.core.utils.progress.tasks.TaskTreeProgressTracker;
2931
import org.neo4j.gds.core.utils.progress.tasks.Tasks;
3032
import org.neo4j.gds.core.utils.warnings.EmptyUserLogRegistryFactory;
3133
import org.neo4j.gds.core.utils.warnings.UserLogRegistryFactory;
@@ -54,10 +56,27 @@ default ALGO build(
5456
Log log,
5557
TaskRegistryFactory taskRegistryFactory,
5658
UserLogRegistryFactory userLogRegistryFactory
59+
) {
60+
var progressTracker = createProgressTracker(
61+
configuration,
62+
log,
63+
taskRegistryFactory,
64+
userLogRegistryFactory,
65+
progressTask(graphOrGraphStore, configuration)
66+
);
67+
return build(graphOrGraphStore, configuration, progressTracker);
68+
}
69+
70+
@NotNull
71+
private ProgressTracker createProgressTracker(
72+
CONFIG configuration,
73+
Log log,
74+
TaskRegistryFactory taskRegistryFactory,
75+
UserLogRegistryFactory userLogRegistryFactory,
76+
Task progressTask
5777
) {
5878
ProgressTracker progressTracker;
5979
if (configuration.logProgress()) {
60-
var progressTask = progressTask(graphOrGraphStore, configuration);
6180
progressTracker = new TaskProgressTracker(
6281
progressTask,
6382
log,
@@ -67,9 +86,16 @@ default ALGO build(
6786
userLogRegistryFactory
6887
);
6988
} else {
70-
progressTracker = ProgressTracker.NULL_TRACKER;
89+
progressTracker = new TaskTreeProgressTracker(
90+
progressTask,
91+
log,
92+
configuration.concurrency(),
93+
configuration.jobId(),
94+
taskRegistryFactory,
95+
userLogRegistryFactory
96+
);
7197
}
72-
return build(graphOrGraphStore, configuration, progressTracker);
98+
return progressTracker;
7399
}
74100

75101
ALGO build(

algo-test/src/main/java/org/neo4j/gds/test/TestAlgorithm.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,11 +41,12 @@ public TestAlgorithm(
4141

4242
@Override
4343
public TestAlgorithm compute() {
44-
progressTracker.beginSubTask();
44+
progressTracker.beginSubTask(100);
4545

4646
if (throwInCompute) {
4747
throw new IllegalStateException("boo");
4848
}
49+
progressTracker.logProgress(50);
4950
relationshipCount = graph.relationshipCount();
5051

5152
progressTracker.endSubTask();

algo/src/test/java/org/neo4j/gds/similarity/nodesim/NodeSimilarityTest.java

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -815,7 +815,19 @@ void shouldNotLogMessagesWhenLoggingIsDisabled(int topK, int concurrency) {
815815

816816
nodeSimilarity.compute();
817817

818-
assertThat(progressLog.getMessages(INFO)).isEmpty();
818+
assertThat(progressLog.getMessages(INFO))
819+
.as("When progress logging is disabled we only log `start`, `100%` and `finished`.")
820+
.extracting(removingThreadId())
821+
.containsExactly(
822+
"NodeSimilarity :: Start",
823+
"NodeSimilarity :: prepare :: Start",
824+
"NodeSimilarity :: prepare 100%",
825+
"NodeSimilarity :: prepare :: Finished",
826+
"NodeSimilarity :: compare node pairs :: Start",
827+
"NodeSimilarity :: compare node pairs 100%",
828+
"NodeSimilarity :: compare node pairs :: Finished",
829+
"NodeSimilarity :: Finished"
830+
);
819831
}
820832

821833
@ParameterizedTest(name = "concurrency = {0}")

core/src/main/java/org/neo4j/gds/core/loading/CypherFactory.java

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -160,12 +160,12 @@ public CSRGraphStore build() {
160160

161161
@Override
162162
protected ProgressTracker initProgressTracker() {
163+
var task = Tasks.task(
164+
"Loading",
165+
Tasks.leaf("Nodes"),
166+
Tasks.leaf("Relationships", dimensions.relCountUpperBound())
167+
);
163168
if (graphProjectConfig.logProgress()) {
164-
var task = Tasks.task(
165-
"Loading",
166-
Tasks.leaf("Nodes"),
167-
Tasks.leaf("Relationships", dimensions.relCountUpperBound())
168-
);
169169
return new TaskProgressTracker(
170170
task,
171171
loadingContext.log(),
@@ -176,7 +176,14 @@ protected ProgressTracker initProgressTracker() {
176176
);
177177
}
178178

179-
return ProgressTracker.NULL_TRACKER;
179+
return new TaskProgressTracker(
180+
task,
181+
loadingContext.log(),
182+
graphProjectConfig.readConcurrency(),
183+
graphProjectConfig.jobId(),
184+
loadingContext.taskRegistryFactory(),
185+
EmptyUserLogRegistryFactory.INSTANCE
186+
);
180187
}
181188

182189
private String nodeQuery() {

core/src/main/java/org/neo4j/gds/core/loading/NativeFactory.java

Lines changed: 42 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@
4343
import org.neo4j.gds.core.utils.progress.tasks.ProgressTracker;
4444
import org.neo4j.gds.core.utils.progress.tasks.Task;
4545
import org.neo4j.gds.core.utils.progress.tasks.TaskProgressTracker;
46+
import org.neo4j.gds.core.utils.progress.tasks.TaskTreeProgressTracker;
4647
import org.neo4j.gds.core.utils.progress.tasks.Tasks;
4748
import org.neo4j.gds.core.utils.warnings.EmptyUserLogRegistryFactory;
4849
import org.neo4j.internal.id.IdGeneratorFactory;
@@ -251,40 +252,41 @@ private static void relationshipEstimationAfterLoading(
251252

252253
@Override
253254
protected ProgressTracker initProgressTracker() {
254-
if (graphProjectConfig.logProgress()) {
255-
long relationshipCount = graphProjectConfig
256-
.relationshipProjections()
257-
.projections()
258-
.entrySet()
259-
.stream()
260-
.map(entry -> {
261-
long relCount = entry.getKey().name.equals("*")
262-
? dimensions.relationshipCounts().values().stream().reduce(Long::sum).orElse(0L)
263-
: dimensions.relationshipCounts().getOrDefault(entry.getKey(), 0L);
264-
265-
return entry.getValue().orientation() == Orientation.UNDIRECTED
266-
? relCount * 2
267-
: relCount;
268-
}).mapToLong(Long::longValue).sum();
269-
270-
var properties = IndexPropertyMappings.prepareProperties(
271-
graphProjectConfig,
272-
dimensions,
273-
loadingContext.transactionContext()
274-
);
255+
long relationshipCount = graphProjectConfig
256+
.relationshipProjections()
257+
.projections()
258+
.entrySet()
259+
.stream()
260+
.map(entry -> {
261+
long relCount = entry.getKey().name.equals("*")
262+
? dimensions.relationshipCounts().values().stream().reduce(Long::sum).orElse(0L)
263+
: dimensions.relationshipCounts().getOrDefault(entry.getKey(), 0L);
264+
265+
return entry.getValue().orientation() == Orientation.UNDIRECTED
266+
? relCount * 2
267+
: relCount;
268+
}).mapToLong(Long::longValue).sum();
269+
270+
var properties = IndexPropertyMappings.prepareProperties(
271+
graphProjectConfig,
272+
dimensions,
273+
loadingContext.transactionContext()
274+
);
275275

276-
List<Task> nodeTasks = properties.indexedProperties().isEmpty()
277-
? List.of(Tasks.leaf("Store Scan", dimensions.nodeCount()))
278-
: List.of(
279-
Tasks.leaf("Store Scan", dimensions.nodeCount()),
280-
Tasks.leaf("Property Index Scan", properties.indexedProperties().size() * dimensions.nodeCount())
281-
);
282-
283-
var task = Tasks.task(
284-
"Loading",
285-
Tasks.task("Nodes", nodeTasks),
286-
Tasks.task("Relationships", Tasks.leaf("Store Scan", relationshipCount))
276+
List<Task> nodeTasks = properties.indexedProperties().isEmpty()
277+
? List.of(Tasks.leaf("Store Scan", dimensions.nodeCount()))
278+
: List.of(
279+
Tasks.leaf("Store Scan", dimensions.nodeCount()),
280+
Tasks.leaf("Property Index Scan", properties.indexedProperties().size() * dimensions.nodeCount())
287281
);
282+
283+
var task = Tasks.task(
284+
"Loading",
285+
Tasks.task("Nodes", nodeTasks),
286+
Tasks.task("Relationships", Tasks.leaf("Store Scan", relationshipCount))
287+
);
288+
289+
if (graphProjectConfig.logProgress()) {
288290
return new TaskProgressTracker(
289291
task,
290292
loadingContext.log(),
@@ -295,7 +297,14 @@ protected ProgressTracker initProgressTracker() {
295297
);
296298
}
297299

298-
return ProgressTracker.NULL_TRACKER;
300+
return new TaskTreeProgressTracker(
301+
task,
302+
loadingContext.log(),
303+
graphProjectConfig.readConcurrency(),
304+
graphProjectConfig.jobId(),
305+
loadingContext.taskRegistryFactory(),
306+
EmptyUserLogRegistryFactory.INSTANCE
307+
);
299308
}
300309

301310
@Override

core/src/main/java/org/neo4j/gds/core/utils/progress/tasks/TaskProgressTracker.java

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ public class TaskProgressTracker implements ProgressTracker {
5151
private long currentTotalSteps;
5252
private double progressLeftOvers;
5353

54-
private Runnable onError;
54+
private final Runnable onError;
5555

5656
public TaskProgressTracker(Task baseTask, Log log, int concurrency, TaskRegistryFactory taskRegistryFactory) {
5757
this(baseTask, log, concurrency, new JobId(), taskRegistryFactory, EmptyUserLogRegistryFactory.INSTANCE);
@@ -253,9 +253,9 @@ public void endSubTaskWithFailure(String expectedTaskDescription) {
253253
}
254254

255255
@TestOnly
256-
public Task currentSubTask() {
256+
Task currentSubTask() {
257257
requireCurrentTask();
258-
return currentTask.get();
258+
return currentTask.orElseThrow();
259259
}
260260

261261
@Nullable
core/src/main/java/org/neo4j/gds/core/utils/progress/tasks/TaskTreeProgressTracker.java

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
/*
2+
* Copyright (c) "Neo4j"
3+
* Neo4j Sweden AB [http://neo4j.com]
4+
*
5+
* This file is part of Neo4j.
6+
*
7+
* Neo4j is free software: you can redistribute it and/or modify
8+
* it under the terms of the GNU General Public License as published by
9+
* the Free Software Foundation, either version 3 of the License, or
10+
* (at your option) any later version.
11+
*
12+
* This program is distributed in the hope that it will be useful,
13+
* but WITHOUT ANY WARRANTY; without even the implied warranty of
14+
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
15+
* GNU General Public License for more details.
16+
*
17+
* You should have received a copy of the GNU General Public License
18+
* along with this program. If not, see <http://www.gnu.org/licenses/>.
19+
*/
20+
package org.neo4j.gds.core.utils.progress.tasks;
21+
22+
import org.neo4j.gds.core.utils.progress.JobId;
23+
import org.neo4j.gds.core.utils.progress.TaskRegistryFactory;
24+
import org.neo4j.gds.core.utils.warnings.UserLogRegistryFactory;
25+
import org.neo4j.logging.Log;
26+
27+
/**
 * A {@link TaskProgressTracker} variant that does not log progress made inside a task.
 *
 * The three incremental reporting methods ({@code logSteps} and both
 * {@code logProgress} overloads) are overridden with empty bodies; every other
 * method is inherited from {@link TaskProgressTracker} without change.
 *
 * NOTE(review): presumably the inherited begin/end sub-task handling still emits the
 * task-level messages and keeps the task registered — confirm against
 * {@code TaskProgressTracker}.
 */
public final class TaskTreeProgressTracker extends TaskProgressTracker {
28+
29+
/**
 * Creates the tracker by forwarding all arguments, unchanged and in the same order,
 * to the {@link TaskProgressTracker} constructor.
 */
public TaskTreeProgressTracker(
30+
Task baseTask,
31+
Log log,
32+
int concurrency,
33+
JobId jobId,
34+
TaskRegistryFactory taskRegistryFactory,
35+
UserLogRegistryFactory userLogRegistryFactory
36+
) {
37+
super(baseTask, log, concurrency, jobId, taskRegistryFactory, userLogRegistryFactory);
38+
}
39+
40+
/** Intentionally does nothing: the {@code steps} argument is ignored. */
@Override
41+
public void logSteps(long steps) {
42+
// NOOP
43+
}
44+
45+
/** Intentionally does nothing: the {@code value} argument is ignored. */
@Override
46+
public void logProgress(long value) {
47+
// NOOP
48+
}
49+
50+
/** Intentionally does nothing: both {@code value} and {@code messageTemplate} are ignored. */
@Override
51+
public void logProgress(long value, String messageTemplate) {
52+
// NOOP
53+
}
54+
}

proc/catalog/src/test/java/org/neo4j/gds/catalog/GraphDropProcTest.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -94,7 +94,7 @@ void dropGraphFromCatalog() {
9494
new Condition<>(config -> {
9595
assertThat(config)
9696
.asInstanceOf(stringObjectMapAssertFactory())
97-
.hasSize(9)
97+
.hasSize(10)
9898
.containsEntry(
9999
"nodeProjection", map(
100100
"A", map(
@@ -126,6 +126,7 @@ void dropGraphFromCatalog() {
126126
intAssertConsumer(readConcurrency -> readConcurrency.isEqualTo(4))
127127
)
128128
.hasEntrySatisfying("sudo", booleanAssertConsumer(AbstractBooleanAssert::isFalse))
129+
.hasEntrySatisfying("logProgress", booleanAssertConsumer(AbstractBooleanAssert::isTrue))
129130
.doesNotContainKeys(
130131
"username",
131132
GraphProjectConfig.NODE_COUNT_KEY,

proc/catalog/src/test/java/org/neo4j/gds/catalog/GraphListProcTest.java

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -104,7 +104,7 @@ void listASingleLabelRelationshipTypeProjection() {
104104
new Condition<>(config -> {
105105
assertThat(config)
106106
.asInstanceOf(stringObjectMapAssertFactory())
107-
.hasSize(9)
107+
.hasSize(10)
108108
.containsEntry(
109109
"nodeProjection", map(
110110
"A", map(
@@ -136,6 +136,7 @@ void listASingleLabelRelationshipTypeProjection() {
136136
intAssertConsumer(readConcurrency -> readConcurrency.isEqualTo(4))
137137
)
138138
.hasEntrySatisfying("sudo", booleanAssertConsumer(AbstractBooleanAssert::isFalse))
139+
.hasEntrySatisfying("logProgress", booleanAssertConsumer(AbstractBooleanAssert::isTrue))
139140
.doesNotContainKeys(
140141
GraphProjectConfig.NODE_COUNT_KEY,
141142
GraphProjectConfig.RELATIONSHIP_COUNT_KEY,
@@ -217,7 +218,7 @@ void listGeneratedGraph() {
217218
"configuration", new Condition<>(config -> {
218219
assertThat(config)
219220
.asInstanceOf(stringObjectMapAssertFactory())
220-
.hasSize(11)
221+
.hasSize(12)
221222
.containsEntry("nodeProjections", map(
222223
"10_Nodes", map(
223224
"label", "10_Nodes",
@@ -250,6 +251,7 @@ void listGeneratedGraph() {
250251
)
251252
.hasEntrySatisfying("allowSelfLoops", booleanAssertConsumer(AbstractBooleanAssert::isFalse))
252253
.hasEntrySatisfying("sudo", booleanAssertConsumer(AbstractBooleanAssert::isFalse))
254+
.hasEntrySatisfying("logProgress", booleanAssertConsumer(AbstractBooleanAssert::isTrue))
253255
.hasEntrySatisfying(
254256
"relationshipDistribution",
255257
stringAssertConsumer(relationshipDistribution -> relationshipDistribution.isEqualTo(
@@ -344,7 +346,7 @@ void listCypherProjection() {
344346
"configuration", new Condition<>(config -> {
345347
assertThat(config)
346348
.asInstanceOf(stringObjectMapAssertFactory())
347-
.hasSize(8)
349+
.hasSize(9)
348350
.hasEntrySatisfying(
349351
"relationshipQuery",
350352
stringAssertConsumer(relationshipQuery -> relationshipQuery.isEqualTo(
@@ -360,6 +362,7 @@ void listCypherProjection() {
360362
stringAssertConsumer(nodeQuery -> nodeQuery.isEqualTo(ALL_NODES_QUERY))
361363
)
362364
.hasEntrySatisfying("sudo", booleanAssertConsumer(AbstractBooleanAssert::isTrue))
365+
.hasEntrySatisfying("logProgress", booleanAssertConsumer(AbstractBooleanAssert::isTrue))
363366
.hasEntrySatisfying(
364367
"readConcurrency",
365368
intAssertConsumer(readConcurrency -> readConcurrency.isEqualTo(4))
@@ -417,11 +420,12 @@ void listCypherAggregation() {
417420
"configuration", new Condition<>(config -> {
418421
assertThat(config)
419422
.asInstanceOf(stringObjectMapAssertFactory())
420-
.hasSize(4)
423+
.hasSize(5)
421424
.hasEntrySatisfying("creationTime", creationTimeAssertConsumer())
422425
.hasEntrySatisfying("jobId", jobId -> assertThat(jobId).isNotNull())
423426
.hasEntrySatisfying("undirectedRelationshipTypes", t -> assertThat(t).isEqualTo(List.of()))
424-
.hasEntrySatisfying("inverseIndexedRelationshipTypes", t -> assertThat(t).isEqualTo(List.of()));
427+
.hasEntrySatisfying("inverseIndexedRelationshipTypes", t -> assertThat(t).isEqualTo(List.of()))
428+
.hasEntrySatisfying("logProgress", booleanAssertConsumer(AbstractBooleanAssert::isTrue));
425429

426430
return true;
427431
}, "Assert Cypher Aggregation `configuration` map"),

proc/embeddings/src/test/java/org/neo4j/gds/embeddings/graphsage/GraphSageIntegrationTest.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -113,7 +113,7 @@ private void dropModel() {
113113
),
114114
"creationTime", isA(ZonedDateTime.class),
115115
"trainConfig", allOf(
116-
aMapWithSize(19),
116+
aMapWithSize(20),
117117
hasEntry("modelName", modelName),
118118
hasEntry("aggregator", "MEAN"),
119119
hasEntry("activationFunction", "SIGMOID")

0 commit comments

Comments
 (0)