Skip to content

Commit c3b1d74

Browse files
Add some semblance of progress tracking
1 parent f39bc49 commit c3b1d74

File tree

3 files changed

+58
-7
lines changed

3 files changed

+58
-7
lines changed

alpha/alpha-algo/src/main/java/org/neo4j/gds/impl/spanningtree/KSpanningTree.java

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,10 @@ public SpanningTree compute() {
7777

7878
prim.setTerminationFlag(getTerminationFlag());
7979
SpanningTree spanningTree = prim.compute();
80-
return combineApproach(spanningTree);
80+
81+
var outputTree = combineApproach(spanningTree);
82+
progressTracker.endSubTask();
83+
return outputTree;
8184
}
8285

8386
@NotNull
@@ -121,7 +124,9 @@ private SpanningTree cutLeafApproach(SpanningTree spanningTree) {
121124
HugeDoubleArray costToParent = HugeDoubleArray.newArray(graph.nodeCount());
122125

123126
double totalCost = init(parent, costToParent, spanningTree);
127+
long numberOfDeletions = spanningTree.effectiveNodeCount() - k;
124128

129+
progressTracker.beginSubTask(numberOfDeletions);
125130
//calculate degree of each node in MST
126131
for (long nodeId = 0; nodeId < graph.nodeCount(); ++nodeId) {
127132
var nodeParent = parent.get(nodeId);
@@ -145,7 +150,6 @@ private SpanningTree cutLeafApproach(SpanningTree spanningTree) {
145150
}
146151
}
147152

148-
long numberOfDeletions = spanningTree.effectiveNodeCount() - k;
149153
for (long i = 0; i < numberOfDeletions; ++i) {
150154
var nextNode = priorityQueue.pop();
151155
long affectedNode;
@@ -184,9 +188,10 @@ private SpanningTree cutLeafApproach(SpanningTree spanningTree) {
184188
associatedCost = costToParent.get(affectedNode);
185189
}
186190
priorityQueue.add(affectedNode, associatedCost);
187-
188191
}
192+
progressTracker.logProgress();
189193
}
194+
progressTracker.endSubTask();
190195
return new SpanningTree(-1, graph.nodeCount(), k, parent, costToParent, totalCost);
191196
}
192197

@@ -218,8 +223,10 @@ private SpanningTree growApproach(SpanningTree spanningTree) {
218223
priorityQueue.add(startNodeId, 0);
219224
long root = startNodeId; //current root is startNodeId
220225
long nodesInTree = 0;
226+
progressTracker.beginSubTask(graph.nodeCount());
221227
while (!priorityQueue.isEmpty()) {
222228
long node = priorityQueue.top();
229+
progressTracker.logProgress();
223230
double associatedCost = priorityQueue.cost(node);
224231
priorityQueue.pop();
225232
long nodeParent = parent.get(node);
@@ -333,7 +340,7 @@ private SpanningTree growApproach(SpanningTree spanningTree) {
333340
}
334341
return true;
335342
});
336-
343+
progressTracker.endSubTask();
337344
return new SpanningTree(-1, graph.nodeCount(), k, parent, costToParent, totalCost);
338345

339346
}

alpha/alpha-algo/src/main/java/org/neo4j/gds/impl/spanningtree/KSpanningTreeAlgorithmFactory.java

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -52,9 +52,10 @@ public Task progressTask(
5252
) {
5353
return Tasks.task(
5454
taskName(),
55-
Tasks.leaf("SpanningTree", graph.nodeCount()),
56-
Tasks.leaf("Add relationship weights"),
57-
Tasks.leaf("Remove relationships")
55+
Tasks.leaf("SpanningTree", graph.relationshipCount()),
56+
Tasks.leaf("Remove relationships 1"),
57+
Tasks.leaf("Remove relationships 2")
58+
5859
);
5960
}
6061

alpha/alpha-algo/src/test/java/org/neo4j/gds/impl/spanningtree/KSpanningTreeTest.java

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,11 @@
2525
import org.junit.jupiter.params.ParameterizedTest;
2626
import org.junit.jupiter.params.provider.CsvSource;
2727
import org.neo4j.gds.Orientation;
28+
import org.neo4j.gds.TestProgressTracker;
2829
import org.neo4j.gds.api.Graph;
30+
import org.neo4j.gds.compat.Neo4jProxy;
31+
import org.neo4j.gds.compat.TestLog;
32+
import org.neo4j.gds.core.utils.progress.EmptyTaskRegistryFactory;
2933
import org.neo4j.gds.core.utils.progress.tasks.ProgressTracker;
3034
import org.neo4j.gds.extension.GdlExtension;
3135
import org.neo4j.gds.extension.GdlGraph;
@@ -37,6 +41,8 @@
3741
import java.util.HashSet;
3842

3943
import static org.assertj.core.api.Assertions.assertThat;
44+
import static org.neo4j.gds.assertj.Extractors.removingThreadId;
45+
import static org.neo4j.gds.assertj.Extractors.replaceTimings;
4046

4147
/**
4248
* 1
@@ -286,4 +292,41 @@ void worstCaseForPruningLeaves() {
286292

287293
}
288294

295+
@Test
296+
void shouldLogProgress() {
297+
var config = KSpanningTreeBaseConfigImpl.builder().sourceNode(idFunction.of("a")).k(2).build();
298+
var factory = new KSpanningTreeAlgorithmFactory<>();
299+
var log = Neo4jProxy.testLog();
300+
var progressTracker = new TestProgressTracker(
301+
factory.progressTask(graph, config),
302+
log,
303+
1,
304+
EmptyTaskRegistryFactory.INSTANCE
305+
);
306+
factory.build(graph, config, progressTracker).compute();
307+
assertThat(log.getMessages(TestLog.INFO))
308+
.extracting(removingThreadId())
309+
.extracting(replaceTimings())
310+
.containsExactly(
311+
"KSpanningTree :: Start",
312+
"KSpanningTree :: SpanningTree :: Start",
313+
"KSpanningTree :: SpanningTree 30%",
314+
"KSpanningTree :: SpanningTree 50%",
315+
"KSpanningTree :: SpanningTree 80%",
316+
"KSpanningTree :: SpanningTree 100%",
317+
"KSpanningTree :: SpanningTree :: Finished",
318+
"KSpanningTree :: Remove relationships 1 :: Start",
319+
"KSpanningTree :: Remove relationships 1 50%",
320+
"KSpanningTree :: Remove relationships 1 100%",
321+
"KSpanningTree :: Remove relationships 1 :: Finished",
322+
"KSpanningTree :: Remove relationships 2 :: Start",
323+
"KSpanningTree :: Remove relationships 2 20%",
324+
"KSpanningTree :: Remove relationships 2 40%",
325+
"KSpanningTree :: Remove relationships 2 60%",
326+
"KSpanningTree :: Remove relationships 2 100%",
327+
"KSpanningTree :: Remove relationships 2 :: Finished",
328+
"KSpanningTree :: Finished"
329+
);
330+
}
331+
289332
}

0 commit comments

Comments
 (0)