Skip to content

Commit 8b45573

Browse files
Progress tracking for Steiner Tree
1 parent ce8eb14 commit 8b45573

File tree

6 files changed

+149
-22
lines changed

6 files changed

+149
-22
lines changed

algo/src/main/java/org/neo4j/gds/steiner/ShortestPathsSteinerAlgorithm.java

Lines changed: 16 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,6 @@ public class ShortestPathsSteinerAlgorithm extends Algorithm<SteinerTreeResult>
4949
private final boolean applyRerouting;
5050
private final double delta;
5151
private final ExecutorService executorService;
52-
5352
private int binSizeThreshold;
5453

5554
public ShortestPathsSteinerAlgorithm(
@@ -59,9 +58,10 @@ public ShortestPathsSteinerAlgorithm(
5958
double delta,
6059
int concurrency,
6160
boolean applyRerouting,
62-
ExecutorService executorService
61+
ExecutorService executorService,
62+
ProgressTracker progressTracker
6363
) {
64-
super(ProgressTracker.NULL_TRACKER);
64+
super(progressTracker);
6565
this.graph = graph;
6666
this.sourceId = sourceId;
6767
this.terminals = terminals;
@@ -82,9 +82,10 @@ public ShortestPathsSteinerAlgorithm(
8282
int concurrency,
8383
boolean applyRerouting,
8484
int binSizeThreshold,
85-
ExecutorService executorService
85+
ExecutorService executorService,
86+
ProgressTracker progressTracker
8687
) {
87-
super(ProgressTracker.NULL_TRACKER);
88+
super(progressTracker);
8889
this.graph = graph;
8990
this.sourceId = sourceId;
9091
this.terminals = terminals;
@@ -112,7 +113,10 @@ private BitSet createTerminals() {
112113

113114
@Override
114115
public SteinerTreeResult compute() {
115-
116+
progressTracker.beginSubTask("SteinerTree");
117+
if (applyRerouting) {
118+
progressTracker.beginSubTask("Main");
119+
}
116120
HugeLongArray parent = HugeLongArray.newArray(graph.nodeCount());
117121
HugeDoubleArray parentCost = HugeDoubleArray.newArray(graph.nodeCount());
118122
ParallelUtil.parallelForEachNode(graph.nodeCount(), concurrency, v -> {
@@ -134,8 +138,10 @@ public SteinerTreeResult compute() {
134138
});
135139

136140
if (applyRerouting) {
141+
progressTracker.endSubTask("Main");
137142
reroute(parent, parentCost, totalCost, effectiveNodeCount);
138143
}
144+
progressTracker.endSubTask("SteinerTree");
139145
return SteinerTreeResult.of(
140146
parent,
141147
parentCost,
@@ -187,7 +193,6 @@ private void processPath(
187193

188194
private DijkstraResult runShortestPaths() {
189195

190-
// var steinerBasedDijkstra = new SteinerBasedDijkstra(graph, sourceId, isTerminal);
191196
var steinerBasedDelta = new SteinerBasedDeltaStepping(
192197
graph,
193198
sourceId,
@@ -196,7 +201,7 @@ private DijkstraResult runShortestPaths() {
196201
concurrency,
197202
binSizeThreshold,
198203
executorService,
199-
ProgressTracker.NULL_TRACKER
204+
progressTracker
200205
);
201206

202207
return steinerBasedDelta.compute();
@@ -293,6 +298,7 @@ private void reroute(
293298
DoubleAdder totalCost,
294299
LongAdder effectiveNodeCount
295300
) {
301+
progressTracker.beginSubTask("Rerouting");
296302
//First, represent the tree as an LinkCutTree:
297303
// This is a dynamic tree (can answer connectivity like UnionFind)
298304
// but can also do some other cool stuff like answering path queries
@@ -318,13 +324,13 @@ private void reroute(
318324
return true;
319325
});
320326
}
327+
progressTracker.logProgress();
321328
return true;
322-
323329
});
324330
if (didReroutes.isTrue()) {
325331
cutNodesAfterRerouting(parent, parentCost, totalCost, effectiveNodeCount);
326332
}
327-
333+
progressTracker.endSubTask("Rerouting");
328334
}
329335

330336

algo/src/main/java/org/neo4j/gds/steiner/SteinerBasedDeltaStepping.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -143,7 +143,6 @@ private void syncPhase(List<SteinerBasedDeltaTask> tasks, int currentBin, Atomic
143143
task.setBinIndex(currentBin);
144144
}
145145
ParallelUtil.run(tasks, executorService);
146-
progressTracker.endSubTask();
147146
}
148147

149148
private long nextTerminal(HugeLongPriorityQueue terminalQueue) {
@@ -170,6 +169,8 @@ private boolean updateSteinerTree(
170169
metTerminals.increment();
171170
unvisitedTerminal.flip(terminalId);
172171

172+
progressTracker.logProgress();
173+
173174
if (metTerminals.longValue() == numOfTerminals) { //if we have found paths to all terminals, terminate early
174175
return true;
175176
}

algo/src/main/java/org/neo4j/gds/steiner/SteinerTreeAlgorithmFactory.java

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,8 @@
2323
import org.neo4j.gds.api.Graph;
2424
import org.neo4j.gds.core.concurrency.Pools;
2525
import org.neo4j.gds.core.utils.progress.tasks.ProgressTracker;
26+
import org.neo4j.gds.core.utils.progress.tasks.Task;
27+
import org.neo4j.gds.core.utils.progress.tasks.Tasks;
2628

2729
import java.util.List;
2830
import java.util.stream.Collectors;
@@ -50,7 +52,8 @@ public ShortestPathsSteinerAlgorithm build(
5052
configuration.delta(),
5153
configuration.concurrency(),
5254
configuration.applyRerouting(),
53-
Pools.DEFAULT
55+
Pools.DEFAULT,
56+
progressTracker
5457
);
5558

5659
}
@@ -59,4 +62,18 @@ public ShortestPathsSteinerAlgorithm build(
5962
public String taskName() {
6063
return "SteinerTree";
6164
}
65+
66+
@Override
67+
public Task progressTask(Graph graph, CONFIG config) {
68+
var targetNodesSize = config.targetNodes().size();
69+
if (config.applyRerouting()) {
70+
long nodeCount = graph.nodeCount();
71+
return Tasks.task(taskName(), List.of(
72+
Tasks.leaf("Main", targetNodesSize),
73+
Tasks.leaf("Rerouting", nodeCount)
74+
));
75+
} else {
76+
return Tasks.leaf(taskName(), targetNodesSize);
77+
}
78+
}
6279
}

algo/src/test/java/org/neo4j/gds/steiner/ShortestPathSteinerAlgorithmExtendedTest.java

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -166,7 +166,8 @@ void shouldWorkCorrectly(double delta, int binSizeThreshold) {
166166
false,
167167
binSizeThreshold,
168168
//setting custom threshold for such a small graph allows to not examine everything in a single iteration
169-
Pools.DEFAULT
169+
Pools.DEFAULT,
170+
ProgressTracker.NULL_TRACKER
170171
).compute();
171172

172173
long[] parentArray = new long[]{ShortestPathsSteinerAlgorithm.ROOTNODE, a[0], a[1], a[2], a[3], a[4]};
@@ -185,7 +186,8 @@ void shouldWorkCorrectlyWithLineGraph() {
185186
2.0,
186187
1,
187188
false,
188-
Pools.DEFAULT
189+
Pools.DEFAULT,
190+
ProgressTracker.NULL_TRACKER
189191
)
190192
.compute();
191193

@@ -236,7 +238,8 @@ void shouldWorkIfRevisitsVertices() {
236238
2.0,
237239
1,
238240
false,
239-
Pools.DEFAULT
241+
Pools.DEFAULT,
242+
ProgressTracker.NULL_TRACKER
240243
).compute();
241244

242245
long[] parentArray = new long[]{
@@ -264,7 +267,8 @@ void shouldWorkOnTriangle() {
264267
2.0,
265268
1,
266269
false,
267-
Pools.DEFAULT
270+
Pools.DEFAULT,
271+
ProgressTracker.NULL_TRACKER
268272
).compute();
269273

270274
long[] parentArray = new long[]{ShortestPathsSteinerAlgorithm.ROOTNODE, a[0], a[1], a[2]};

algo/src/test/java/org/neo4j/gds/steiner/ShortestPathsSteinerAlgorithmReroutingTest.java

Lines changed: 102 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,13 @@
2222
import org.junit.jupiter.api.Assertions;
2323
import org.junit.jupiter.api.Test;
2424
import org.neo4j.gds.Orientation;
25+
import org.neo4j.gds.TestProgressTracker;
26+
import org.neo4j.gds.compat.Neo4jProxy;
27+
import org.neo4j.gds.compat.TestLog;
2528
import org.neo4j.gds.core.concurrency.Pools;
29+
import org.neo4j.gds.core.utils.progress.EmptyTaskRegistryFactory;
30+
import org.neo4j.gds.core.utils.progress.tasks.ProgressTracker;
31+
import org.neo4j.gds.core.utils.progress.tasks.Task;
2632
import org.neo4j.gds.extension.GdlExtension;
2733
import org.neo4j.gds.extension.GdlGraph;
2834
import org.neo4j.gds.extension.IdFunction;
@@ -32,6 +38,8 @@
3238
import java.util.List;
3339

3440
import static org.assertj.core.api.Assertions.assertThat;
41+
import static org.neo4j.gds.assertj.Extractors.removingThreadId;
42+
import static org.neo4j.gds.assertj.Extractors.replaceTimings;
3543

3644
@GdlExtension
3745
class ShortestPathsSteinerAlgorithmReroutingTest {
@@ -87,7 +95,8 @@ void shouldPruneUnusedIfRerouting() {
8795
2.0,
8896
1,
8997
false,
90-
Pools.DEFAULT
98+
Pools.DEFAULT,
99+
ProgressTracker.NULL_TRACKER
91100
).compute();
92101
assertThat(steinerResult.totalCost()).isEqualTo(7.0);
93102
assertThat(steinerResult.effectiveNodeCount()).isEqualTo(5);
@@ -100,7 +109,8 @@ void shouldPruneUnusedIfRerouting() {
100109
2.0,
101110
1,
102111
true,
103-
Pools.DEFAULT
112+
Pools.DEFAULT,
113+
ProgressTracker.NULL_TRACKER
104114
).compute();
105115
assertThat(steinerResultWithReroute.totalCost()).isEqualTo(4.0);
106116
assertThat(steinerResultWithReroute.effectiveNodeCount()).isEqualTo(3);
@@ -119,7 +129,8 @@ void rerouteShouldNotCreateLoops() {
119129
2.0,
120130
1,
121131
true,
122-
Pools.DEFAULT
132+
Pools.DEFAULT,
133+
ProgressTracker.NULL_TRACKER
123134
).compute();
124135
var parent = steinerResult.parentArray().toArray();
125136

@@ -144,7 +155,8 @@ void shouldWorkForUnreachableAndReachableTerminals() {
144155
2.0,
145156
1,
146157
true,
147-
Pools.DEFAULT
158+
Pools.DEFAULT,
159+
ProgressTracker.NULL_TRACKER
148160
).compute();
149161
assertThat(steinerTreeResult.effectiveTargetNodesCount()).isEqualTo(2);
150162
});
@@ -163,12 +175,97 @@ void shouldWorkIfNoReachableTerminals() {
163175
2.0,
164176
1,
165177
true,
166-
Pools.DEFAULT
178+
Pools.DEFAULT,
179+
ProgressTracker.NULL_TRACKER
167180
).compute();
168181
assertThat(steinerTreeResult.effectiveTargetNodesCount()).isEqualTo(0);
169182
assertThat(steinerTreeResult.effectiveNodeCount()).isEqualTo(1);
170183

171184
});
185+
}
186+
187+
@Test
188+
void shouldLogProgress() {
189+
190+
var sourceId = graph.toOriginalNodeId(idFunction.of("a0"));
191+
var target1 = graph.toOriginalNodeId(idFunction.of("a3"));
192+
var target2 = graph.toOriginalNodeId(idFunction.of("a4"));
193+
194+
var config = SteinerTreeStatsConfigImpl
195+
.builder()
196+
.sourceNode(sourceId)
197+
.targetNodes(List.of(target1, target2))
198+
.build();
172199

200+
var steinerTreeAlgorithmFactory = new SteinerTreeAlgorithmFactory();
201+
var log = Neo4jProxy.testLog();
202+
Task baseTask = steinerTreeAlgorithmFactory.progressTask(graph, config);
203+
var progressTracker = new TestProgressTracker(
204+
baseTask,
205+
log,
206+
4,
207+
EmptyTaskRegistryFactory.INSTANCE
208+
);
209+
210+
211+
steinerTreeAlgorithmFactory.build(graph, config, progressTracker).compute();
212+
213+
assertThat(log.getMessages(TestLog.INFO))
214+
.extracting(removingThreadId())
215+
.extracting(replaceTimings())
216+
.containsExactly(
217+
"SteinerTree :: Start",
218+
"SteinerTree 50%",
219+
"SteinerTree 100%",
220+
"SteinerTree :: Finished"
221+
);
173222
}
223+
224+
@Test
225+
void shouldLogProgressWithRerouting() {
226+
227+
var sourceId = graph.toOriginalNodeId(idFunction.of("a0"));
228+
var target1 = graph.toOriginalNodeId(idFunction.of("a3"));
229+
var target2 = graph.toOriginalNodeId(idFunction.of("a4"));
230+
231+
var config = SteinerTreeStatsConfigImpl
232+
.builder()
233+
.sourceNode(sourceId)
234+
.applyRerouting(true)
235+
.targetNodes(List.of(target1, target2))
236+
.build();
237+
238+
var steinerTreeAlgorithmFactory = new SteinerTreeAlgorithmFactory();
239+
var log = Neo4jProxy.testLog();
240+
Task baseTask = steinerTreeAlgorithmFactory.progressTask(graph, config);
241+
var progressTracker = new TestProgressTracker(
242+
baseTask,
243+
log,
244+
4,
245+
EmptyTaskRegistryFactory.INSTANCE
246+
);
247+
248+
steinerTreeAlgorithmFactory.build(graph, config, progressTracker).compute();
249+
250+
assertThat(log.getMessages(TestLog.INFO))
251+
.extracting(removingThreadId())
252+
.extracting(replaceTimings())
253+
.containsExactly(
254+
"SteinerTree :: Start",
255+
"SteinerTree :: Main :: Start",
256+
"SteinerTree :: Main 50%",
257+
"SteinerTree :: Main 100%",
258+
"SteinerTree :: Main :: Finished",
259+
"SteinerTree :: Rerouting :: Start",
260+
"SteinerTree :: Rerouting 16%",
261+
"SteinerTree :: Rerouting 33%",
262+
"SteinerTree :: Rerouting 50%",
263+
"SteinerTree :: Rerouting 66%",
264+
"SteinerTree :: Rerouting 83%",
265+
"SteinerTree :: Rerouting 100%",
266+
"SteinerTree :: Rerouting :: Finished",
267+
"SteinerTree :: Finished"
268+
);
269+
}
270+
174271
}

algo/src/test/java/org/neo4j/gds/steiner/ShortestPathsSteinerAlgorithmTest.java

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
import org.junit.jupiter.api.Test;
2323
import org.neo4j.gds.Orientation;
2424
import org.neo4j.gds.core.concurrency.Pools;
25+
import org.neo4j.gds.core.utils.progress.tasks.ProgressTracker;
2526
import org.neo4j.gds.extension.GdlExtension;
2627
import org.neo4j.gds.extension.GdlGraph;
2728
import org.neo4j.gds.extension.IdFunction;
@@ -78,7 +79,8 @@ void shouldWorkCorrectly() {
7879
2.0,
7980
1,
8081
false,
81-
Pools.DEFAULT
82+
Pools.DEFAULT,
83+
ProgressTracker.NULL_TRACKER
8284
).compute();
8385
var pruned = ShortestPathsSteinerAlgorithm.PRUNED;
8486
var rootnode = ShortestPathsSteinerAlgorithm.ROOTNODE;

0 commit comments

Comments
 (0)