Skip to content

Commit c204aed

Browse files
committed
Allow to switch off progress tracking per procedure
1 parent 0f2ebed commit c204aed

File tree

6 files changed

+178
-60
lines changed

6 files changed

+178
-60
lines changed

algo-common/src/main/java/org/neo4j/gds/AlgorithmFactory.java

Lines changed: 14 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -55,15 +55,20 @@ default ALGO build(
5555
TaskRegistryFactory taskRegistryFactory,
5656
UserLogRegistryFactory userLogRegistryFactory
5757
) {
58-
var progressTask = progressTask(graphOrGraphStore, configuration);
59-
var progressTracker = new TaskProgressTracker(
60-
progressTask,
61-
log,
62-
configuration.concurrency(),
63-
configuration.jobId(),
64-
taskRegistryFactory,
65-
userLogRegistryFactory
66-
);
58+
ProgressTracker progressTracker;
59+
if (configuration.logProgress()) {
60+
var progressTask = progressTask(graphOrGraphStore, configuration);
61+
progressTracker = new TaskProgressTracker(
62+
progressTask,
63+
log,
64+
configuration.concurrency(),
65+
configuration.jobId(),
66+
taskRegistryFactory,
67+
userLogRegistryFactory
68+
);
69+
} else {
70+
progressTracker = ProgressTracker.NULL_TRACKER;
71+
}
6772
return build(graphOrGraphStore, configuration, progressTracker);
6873
}
6974

algo/src/test/java/org/neo4j/gds/similarity/nodesim/NodeSimilarityTest.java

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -799,6 +799,25 @@ void shouldLogMessages(int topK, int concurrency) {
799799
);
800800
}
801801

802+
@ParameterizedTest(name = "topK = {0}, concurrency = {1}")
803+
@MethodSource("topKAndConcurrencies")
804+
void shouldNotLogMessagesWhenLoggingIsDisabled(int topK, int concurrency) {
805+
var graph = naturalGraph;
806+
var config = configBuilder().topN(100).topK(topK).concurrency(concurrency).logProgress(false).build();
807+
808+
var progressLog = Neo4jProxy.testLog();
809+
var nodeSimilarity = new NodeSimilarityFactory<>().build(
810+
graph,
811+
config,
812+
progressLog,
813+
EmptyTaskRegistryFactory.INSTANCE
814+
);
815+
816+
nodeSimilarity.compute();
817+
818+
assertThat(progressLog.getMessages(INFO)).isEmpty();
819+
}
820+
802821
@ParameterizedTest(name = "concurrency = {0}")
803822
@ValueSource(ints = {1,2})
804823
void shouldLogProgress(int concurrency) {

config-api/src/main/java/org/neo4j/gds/config/BaseConfig.java

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
public interface BaseConfig extends ToMapConvertible {
3333

3434
String SUDO_KEY = "sudo";
35+
String LOG_PROGRESS_KEY = "logProgress";
3536

3637
@Value.Parameter(false)
3738
@Configuration.Key("username")
@@ -45,6 +46,13 @@ default boolean sudo() {
4546
return false;
4647
}
4748

49+
@Value.Default
50+
@Value.Parameter(false)
51+
@Configuration.Key(LOG_PROGRESS_KEY)
52+
default boolean logProgress() {
53+
return true;
54+
}
55+
4856
@Configuration.CollectKeys
4957
@Value.Auxiliary
5058
@Value.Default

core/src/main/java/org/neo4j/gds/core/loading/CypherFactory.java

Lines changed: 17 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -160,19 +160,23 @@ public CSRGraphStore build() {
160160

161161
@Override
162162
protected ProgressTracker initProgressTracker() {
163-
var task = Tasks.task(
164-
"Loading",
165-
Tasks.leaf("Nodes"),
166-
Tasks.leaf("Relationships", dimensions.relCountUpperBound())
167-
);
168-
return new TaskProgressTracker(
169-
task,
170-
loadingContext.log(),
171-
graphProjectConfig.readConcurrency(),
172-
graphProjectConfig.jobId(),
173-
loadingContext.taskRegistryFactory(),
174-
EmptyUserLogRegistryFactory.INSTANCE
175-
);
163+
if (graphProjectConfig.logProgress()) {
164+
var task = Tasks.task(
165+
"Loading",
166+
Tasks.leaf("Nodes"),
167+
Tasks.leaf("Relationships", dimensions.relCountUpperBound())
168+
);
169+
return new TaskProgressTracker(
170+
task,
171+
loadingContext.log(),
172+
graphProjectConfig.readConcurrency(),
173+
graphProjectConfig.jobId(),
174+
loadingContext.taskRegistryFactory(),
175+
EmptyUserLogRegistryFactory.INSTANCE
176+
);
177+
}
178+
179+
return ProgressTracker.NULL_TRACKER;
176180
}
177181

178182
private String nodeQuery() {

core/src/main/java/org/neo4j/gds/core/loading/NativeFactory.java

Lines changed: 42 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -251,47 +251,51 @@ private static void relationshipEstimationAfterLoading(
251251

252252
@Override
253253
protected ProgressTracker initProgressTracker() {
254-
long relationshipCount = graphProjectConfig
255-
.relationshipProjections()
256-
.projections()
257-
.entrySet()
258-
.stream()
259-
.map(entry -> {
260-
long relCount = entry.getKey().name.equals("*")
261-
? dimensions.relationshipCounts().values().stream().reduce(Long::sum).orElse(0L)
262-
: dimensions.relationshipCounts().getOrDefault(entry.getKey(), 0L);
263-
264-
return entry.getValue().orientation() == Orientation.UNDIRECTED
265-
? relCount * 2
266-
: relCount;
267-
}).mapToLong(Long::longValue).sum();
268-
269-
var properties = IndexPropertyMappings.prepareProperties(
270-
graphProjectConfig,
271-
dimensions,
272-
loadingContext.transactionContext()
273-
);
254+
if (graphProjectConfig.logProgress()) {
255+
long relationshipCount = graphProjectConfig
256+
.relationshipProjections()
257+
.projections()
258+
.entrySet()
259+
.stream()
260+
.map(entry -> {
261+
long relCount = entry.getKey().name.equals("*")
262+
? dimensions.relationshipCounts().values().stream().reduce(Long::sum).orElse(0L)
263+
: dimensions.relationshipCounts().getOrDefault(entry.getKey(), 0L);
264+
265+
return entry.getValue().orientation() == Orientation.UNDIRECTED
266+
? relCount * 2
267+
: relCount;
268+
}).mapToLong(Long::longValue).sum();
269+
270+
var properties = IndexPropertyMappings.prepareProperties(
271+
graphProjectConfig,
272+
dimensions,
273+
loadingContext.transactionContext()
274+
);
274275

275-
List<Task> nodeTasks = properties.indexedProperties().isEmpty()
276-
? List.of(Tasks.leaf("Store Scan", dimensions.nodeCount()))
277-
: List.of(
278-
Tasks.leaf("Store Scan", dimensions.nodeCount()),
279-
Tasks.leaf("Property Index Scan", properties.indexedProperties().size() * dimensions.nodeCount())
276+
List<Task> nodeTasks = properties.indexedProperties().isEmpty()
277+
? List.of(Tasks.leaf("Store Scan", dimensions.nodeCount()))
278+
: List.of(
279+
Tasks.leaf("Store Scan", dimensions.nodeCount()),
280+
Tasks.leaf("Property Index Scan", properties.indexedProperties().size() * dimensions.nodeCount())
281+
);
282+
283+
var task = Tasks.task(
284+
"Loading",
285+
Tasks.task("Nodes", nodeTasks),
286+
Tasks.task("Relationships", Tasks.leaf("Store Scan", relationshipCount))
280287
);
288+
return new TaskProgressTracker(
289+
task,
290+
loadingContext.log(),
291+
graphProjectConfig.readConcurrency(),
292+
graphProjectConfig.jobId(),
293+
loadingContext.taskRegistryFactory(),
294+
EmptyUserLogRegistryFactory.INSTANCE
295+
);
296+
}
281297

282-
var task = Tasks.task(
283-
"Loading",
284-
Tasks.task("Nodes", nodeTasks),
285-
Tasks.task("Relationships", Tasks.leaf("Store Scan", relationshipCount))
286-
);
287-
return new TaskProgressTracker(
288-
task,
289-
loadingContext.log(),
290-
graphProjectConfig.readConcurrency(),
291-
graphProjectConfig.jobId(),
292-
loadingContext.taskRegistryFactory(),
293-
EmptyUserLogRegistryFactory.INSTANCE
294-
);
298+
return ProgressTracker.NULL_TRACKER;
295299
}
296300

297301
@Override
Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,78 @@
1+
/*
2+
* Copyright (c) "Neo4j"
3+
* Neo4j Sweden AB [http://neo4j.com]
4+
*
5+
* This file is part of Neo4j.
6+
*
7+
* Neo4j is free software: you can redistribute it and/or modify
8+
* it under the terms of the GNU General Public License as published by
9+
* the Free Software Foundation, either version 3 of the License, or
10+
* (at your option) any later version.
11+
*
12+
* This program is distributed in the hope that it will be useful,
13+
* but WITHOUT ANY WARRANTY; without even the implied warranty of
14+
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
15+
* GNU General Public License for more details.
16+
*
17+
* You should have received a copy of the GNU General Public License
18+
* along with this program. If not, see <http://www.gnu.org/licenses/>.
19+
*/
20+
package org.neo4j.gds.test;
21+
22+
import org.junit.jupiter.api.Test;
23+
import org.neo4j.gds.api.Graph;
24+
import org.neo4j.gds.compat.Neo4jProxy;
25+
import org.neo4j.gds.compat.TestLog;
26+
import org.neo4j.gds.core.utils.progress.TaskRegistryFactory;
27+
import org.neo4j.gds.extension.GdlExtension;
28+
import org.neo4j.gds.extension.GdlGraph;
29+
import org.neo4j.gds.extension.Inject;
30+
31+
import static org.assertj.core.api.Assertions.assertThat;
32+
import static org.neo4j.gds.assertj.Extractors.removingThreadId;
33+
import static org.neo4j.gds.assertj.Extractors.replaceTimings;
34+
35+
@GdlExtension
36+
class ProgressTrackingTest {
37+
38+
@GdlGraph
39+
static String GDL =
40+
"CREATE " +
41+
" ()-[:REL]->()," +
42+
" ()-[:REL2]->(),";
43+
44+
@Inject
45+
Graph graph;
46+
47+
@Test
48+
void shouldLogProgress() {
49+
var factory = new TestAlgorithmFactory<>();
50+
var testConfig = TestConfigImpl.builder().logProgress(true).build();
51+
var log = Neo4jProxy.testLog();
52+
53+
factory.build(graph, testConfig, log, TaskRegistryFactory.empty()).compute();
54+
55+
assertThat(log.getMessages(TestLog.INFO))
56+
.extracting(removingThreadId())
57+
.extracting(replaceTimings())
58+
.containsExactly(
59+
"TestAlgorithm :: Start",
60+
"TestAlgorithm 100%",
61+
"TestAlgorithm :: Finished"
62+
);
63+
64+
}
65+
66+
@Test
67+
void shouldNotLogProgress() {
68+
var factory = new TestAlgorithmFactory<>();
69+
var testConfig = TestConfigImpl.builder().logProgress(false).build();
70+
var log = Neo4jProxy.testLog();
71+
72+
factory.build(graph, testConfig, log, TaskRegistryFactory.empty()).compute();
73+
74+
assertThat(log.getMessages(TestLog.INFO))
75+
.as("When `logProgress` is set to `false` there should be no log messages")
76+
.isEmpty();
77+
}
78+
}

0 commit comments

Comments
 (0)