Skip to content

Commit 13c6e1b

Browse files
Boruvka progress tracking
Co-authored-by: Veselin Nikolov <veselin.nikolov@neotechnology.com>
1 parent 2b3fcc0 commit 13c6e1b

File tree

5 files changed

+149
-20
lines changed

5 files changed

+149
-20
lines changed

algo/src/main/java/org/neo4j/gds/hdbscan/BoruvkaMST.java

Lines changed: 23 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -53,9 +53,11 @@ private BoruvkaMST(
5353
KdTree kdTree,
5454
ClosestDistanceInformationTracker closestDistanceTracker,
5555
HugeDoubleArray coreValues,
56-
long nodeCount, Concurrency concurrency
56+
long nodeCount,
57+
Concurrency concurrency,
58+
ProgressTracker progressTracker
5759
) {
58-
super(ProgressTracker.NULL_TRACKER);
60+
super(progressTracker);
5961
this.nodePropertyValues = nodePropertyValues;
6062
this.closestDistanceTracker = closestDistanceTracker;
6163
this.kdTree = kdTree;
@@ -75,7 +77,8 @@ public static BoruvkaMST createWithZeroCores(
7577
NodePropertyValues nodePropertyValues,
7678
KdTree kdTree,
7779
long nodeCount,
78-
Concurrency concurrency
80+
Concurrency concurrency,
81+
ProgressTracker progressTracker
7982
) {
8083
var zeroCores = HugeDoubleArray.newArray(nodeCount);
8184

@@ -85,7 +88,8 @@ public static BoruvkaMST createWithZeroCores(
8588
ClosestDistanceInformationTracker.create(nodeCount),
8689
zeroCores,
8790
nodeCount,
88-
concurrency
91+
concurrency,
92+
progressTracker
8993
);
9094
}
9195

@@ -94,22 +98,33 @@ public static BoruvkaMST create(
9498
KdTree kdTree,
9599
CoreResult coreResult,
96100
long nodeCount,
97-
Concurrency concurrency
101+
Concurrency concurrency,
102+
ProgressTracker progressTracker
98103
) {
99104
var cores = coreResult.createCoreArray();
100105
var closestTracker = ClosestDistanceInformationTracker.create(nodeCount, cores, coreResult);
101106

102-
return new BoruvkaMST(nodePropertyValues, kdTree, closestTracker, cores, nodeCount,concurrency);
107+
return new BoruvkaMST(nodePropertyValues,
108+
kdTree,
109+
closestTracker,
110+
cores,
111+
nodeCount,
112+
concurrency,
113+
progressTracker
114+
);
103115
}
104116

105117

106118
@Override
107119
public GeometricMSTResult compute() {
120+
progressTracker.beginSubTask();
108121
var kdRoot = kdTree.root();
109122
var rootId = kdRoot.id();
110123
while (!kdNodeSingleComponent.get(rootId)) {
111124
performIteration();
112125
}
126+
progressTracker.endSubTask();
127+
113128
return new GeometricMSTResult(edges, totalEdgeSum);
114129
}
115130

@@ -220,7 +235,7 @@ void mergeComponents() {
220235
this.edgeCount++;
221236
this.totalEdgeSum += distance;
222237

223-
unionFind.union(uComponent, vComponent);
238+
mergeComponents(uComponent,vComponent);
224239
}
225240

226241
}
@@ -265,6 +280,7 @@ boolean updateSingleComponent(KdNode node) {
265280

266281
void mergeComponents(long comp0, long comp1) {
267282
unionFind.union(comp0, comp1);
283+
progressTracker.logProgress();
268284
}
269285

270286
}

algo/src/main/java/org/neo4j/gds/hdbscan/HDBScan.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -111,7 +111,8 @@ GeometricMSTResult boruvka(KdTree kdTree, CoreResult coreResult) {
111111
kdTree,
112112
coreResult,
113113
nodes.nodeCount(),
114-
concurrency
114+
concurrency,
115+
progressTracker
115116
);
116117
return boruvkaMST.compute();
117118
}

algo/src/main/java/org/neo4j/gds/hdbscan/HDBScanProgressTrackerCreator.java

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,4 +49,8 @@ static Task labellingTask(String name, long nodeCount){
4949
);
5050
}
5151

52+
static Task boruvkaTask(String name, long nodeCount){
53+
return Tasks.leaf(name,nodeCount - 1);
54+
}
55+
5256
}

algo/src/test/java/org/neo4j/gds/hdbscan/BoruvkaAlgorithmFunctionsTest.java

Lines changed: 47 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
import org.neo4j.gds.collections.ha.HugeDoubleArray;
2525
import org.neo4j.gds.collections.ha.HugeLongArray;
2626
import org.neo4j.gds.core.concurrency.Concurrency;
27+
import org.neo4j.gds.core.utils.progress.tasks.ProgressTracker;
2728

2829
import static org.assertj.core.api.Assertions.assertThat;
2930
import static org.mockito.ArgumentMatchers.anyLong;
@@ -42,7 +43,13 @@ void singleComponentShouldWorkOnLeaf(){
4243
1
4344
);
4445

45-
var boruvkaMST = BoruvkaMST.createWithZeroCores(null,kdTree,4,new Concurrency(1));
46+
var boruvkaMST = BoruvkaMST.createWithZeroCores(
47+
null,
48+
kdTree,
49+
4,
50+
new Concurrency(1),
51+
ProgressTracker.NULL_TRACKER
52+
);
4653

4754
assertThat(boruvkaMST.updateSingleComponent(kdNode)).isFalse();
4855
boruvkaMST.mergeComponents(0,1);
@@ -61,7 +68,15 @@ void singleComponentOrShouldWork(){
6168
null,
6269
1
6370
);
64-
var boruvkaMST = BoruvkaMST.createWithZeroCores(null,kdTree,4,new Concurrency(1));
71+
72+
var boruvkaMST = BoruvkaMST.createWithZeroCores(
73+
null,
74+
kdTree,
75+
4,
76+
new Concurrency(1),
77+
ProgressTracker.NULL_TRACKER
78+
);
79+
6580
KdNode kdNode1 = KdNode.createLeaf(1, 1, 2, null);
6681
KdNode kdNode2 = KdNode.createLeaf(2, 2, 4, null);
6782

@@ -94,7 +109,13 @@ void singleComponentShouldWorkOnSplitNode(){
94109
1
95110
);
96111

97-
var boruvkaMST = BoruvkaMST.createWithZeroCores(null,kdTree,8, new Concurrency(1));
112+
var boruvkaMST = BoruvkaMST.createWithZeroCores(
113+
null,
114+
kdTree,
115+
8,
116+
new Concurrency(1),
117+
ProgressTracker.NULL_TRACKER
118+
);
98119

99120
assertThat(boruvkaMST.updateSingleComponent(kdNode)).isFalse();
100121
boruvkaMST.mergeComponents(0,1);
@@ -133,7 +154,14 @@ public long nodeCount() {
133154
when(coreResult.createCoreArray()).thenReturn(HugeDoubleArray.of(0,0,10,10,0,0,0,0,0,0));
134155
when(coreResult.neighboursOf(anyLong())).thenReturn(new Neighbour[0]);
135156

136-
var boruvkaMST = BoruvkaMST.create(nodeProps,kdTree, coreResult,8,new Concurrency(1));
157+
var boruvkaMST = BoruvkaMST.create(
158+
nodeProps,
159+
kdTree,
160+
coreResult,
161+
8,
162+
new Concurrency(1),
163+
ProgressTracker.NULL_TRACKER
164+
);
137165

138166
assertThat(boruvkaMST.baseCase(0,1,nodeProps.doubleArrayValue(0),0)).isEqualTo(1); //distance
139167
assertThat(boruvkaMST.baseCase(2,3,nodeProps.doubleArrayValue(2),2)).isEqualTo(10); //corevalue
@@ -151,7 +179,14 @@ void baseCaseShouldIgnoreSameComponents(){
151179
when(coreResult.createCoreArray()).thenReturn(HugeDoubleArray.of(0,0,10,10));
152180
when(coreResult.neighboursOf(anyLong())).thenReturn(new Neighbour[0]);
153181

154-
var boruvkaMST = BoruvkaMST.create(nodeProps,kdTree, coreResult,8,new Concurrency(1));
182+
var boruvkaMST = BoruvkaMST.create(
183+
nodeProps,
184+
kdTree,
185+
coreResult,
186+
8,
187+
new Concurrency(1),
188+
ProgressTracker.NULL_TRACKER
189+
);
155190

156191
assertThat(boruvkaMST.baseCase(0,1,nodeProps.doubleArrayValue(0),1)).isEqualTo(-1); //distance
157192

@@ -163,7 +198,13 @@ void shouldPruneProperly(){
163198
DoubleArrayNodePropertyValues nodeProps=mock(DoubleArrayNodePropertyValues.class);
164199
var kdRoot = KdNode.createLeaf(0,0,2,mock(AABB.class));
165200
var kdTree =new KdTree(HugeLongArray.of(0,1,2),nodeProps,kdRoot,1);
166-
var boruvkaMST = BoruvkaMST.createWithZeroCores(nodeProps,kdTree, 3,new Concurrency(1));
201+
var boruvkaMST = BoruvkaMST.createWithZeroCores(
202+
nodeProps,
203+
kdTree,
204+
3,
205+
new Concurrency(1),
206+
ProgressTracker.NULL_TRACKER
207+
);
167208

168209
//prune based on distance
169210
boruvkaMST.tryUpdate(2,2,3,5);

algo/src/test/java/org/neo4j/gds/hdbscan/BoruvkaMSTTest.java

Lines changed: 73 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -19,20 +19,29 @@
1919
*/
2020
package org.neo4j.gds.hdbscan;
2121

22+
import org.assertj.core.api.Assertions;
2223
import org.assertj.core.data.Offset;
2324
import org.junit.jupiter.api.Nested;
25+
import org.junit.jupiter.api.Test;
2426
import org.junit.jupiter.params.ParameterizedTest;
2527
import org.junit.jupiter.params.provider.ValueSource;
28+
import org.neo4j.gds.compat.TestLog;
2629
import org.neo4j.gds.core.concurrency.Concurrency;
30+
import org.neo4j.gds.core.utils.logging.LoggerForProgressTrackingAdapter;
31+
import org.neo4j.gds.core.utils.progress.EmptyTaskRegistryFactory;
2732
import org.neo4j.gds.core.utils.progress.tasks.ProgressTracker;
33+
import org.neo4j.gds.core.utils.progress.tasks.TaskProgressTracker;
2834
import org.neo4j.gds.extension.GdlExtension;
2935
import org.neo4j.gds.extension.GdlGraph;
3036
import org.neo4j.gds.extension.Inject;
3137
import org.neo4j.gds.extension.TestGraph;
38+
import org.neo4j.gds.logging.GdsTestLog;
3239

3340
import java.util.List;
3441

3542
import static org.assertj.core.api.Assertions.assertThat;
43+
import static org.neo4j.gds.assertj.Extractors.removingThreadId;
44+
import static org.neo4j.gds.assertj.Extractors.replaceTimings;
3645

3746
@GdlExtension
3847
class BoruvkaMSTTest {
@@ -62,13 +71,19 @@ class Case1 {
6271
@ValueSource(ints={1,4})
6372
void shouldReturnEuclideanMSTWithZeroCoreValues(int concurrency) {
6473
var nodePropertyValues = graph.nodeProperties("point");
65-
var kdTree = new KdTreeBuilder(graph, nodePropertyValues, 1, 1, ProgressTracker.NULL_TRACKER).build();
74+
var kdTree = new KdTreeBuilder(graph,
75+
nodePropertyValues,
76+
1,
77+
1,
78+
ProgressTracker.NULL_TRACKER
79+
).build();
6680

6781
var dualTree = BoruvkaMST.createWithZeroCores(
6882
nodePropertyValues,
6983
kdTree,
7084
graph.nodeCount(),
71-
new Concurrency(concurrency)
85+
new Concurrency(concurrency),
86+
ProgressTracker.NULL_TRACKER
7287
);
7388

7489
var result = dualTree.compute();
@@ -119,13 +134,20 @@ class Case2 {
119134
@ValueSource(ints={1,4})
120135
void shouldReturnEuclideanMSTWithZeroCoreValues(int concurrency) {
121136
var nodePropertyValues = graph.nodeProperties("point");
122-
var kdTree = new KdTreeBuilder(graph, nodePropertyValues, 1, 1, ProgressTracker.NULL_TRACKER).build();
137+
var kdTree = new KdTreeBuilder(
138+
graph,
139+
nodePropertyValues,
140+
1,
141+
1,
142+
ProgressTracker.NULL_TRACKER
143+
).build();
123144

124145
var dualTree = BoruvkaMST.createWithZeroCores(
125146
nodePropertyValues,
126147
kdTree,
127148
graph.nodeCount(),
128-
new Concurrency(concurrency)
149+
new Concurrency(concurrency),
150+
ProgressTracker.NULL_TRACKER
129151
);
130152

131153
var result = dualTree.compute();
@@ -176,13 +198,19 @@ class Case3 {
176198
@ValueSource(ints={1,4})
177199
void shouldReturnEuclideanMSTWithZeroCoreValues(int concurrency) {
178200
var nodePropertyValues = graph.nodeProperties("point");
179-
var kdTree = new KdTreeBuilder(graph, nodePropertyValues, 1, 1, ProgressTracker.NULL_TRACKER).build();
201+
var kdTree = new KdTreeBuilder(graph,
202+
nodePropertyValues,
203+
1,
204+
1,
205+
ProgressTracker.NULL_TRACKER
206+
).build();
180207

181208
var dualTree = BoruvkaMST.createWithZeroCores(
182209
nodePropertyValues,
183210
kdTree,
184211
graph.nodeCount(),
185-
new Concurrency(concurrency)
212+
new Concurrency(concurrency),
213+
ProgressTracker.NULL_TRACKER
186214
);
187215

188216
var result = dualTree.compute();
@@ -208,6 +236,45 @@ void shouldReturnEuclideanMSTWithZeroCoreValues(int concurrency) {
208236
);
209237
}
210238

239+
@Test
240+
void shouldLogProgress(){
241+
242+
var progressTask = HDBScanProgressTrackerCreator.boruvkaTask("boruvka",graph.nodeCount());
243+
var log = new GdsTestLog();
244+
var progressTracker = new TaskProgressTracker(progressTask, new LoggerForProgressTrackingAdapter(log), new Concurrency(1), EmptyTaskRegistryFactory.INSTANCE);
245+
246+
var nodePropertyValues = graph.nodeProperties("point");
247+
var kdTree = new KdTreeBuilder(graph,
248+
nodePropertyValues,
249+
1,
250+
1,
251+
ProgressTracker.NULL_TRACKER
252+
).build();
253+
254+
BoruvkaMST.createWithZeroCores(
255+
nodePropertyValues,
256+
kdTree,
257+
graph.nodeCount(),
258+
new Concurrency(1),
259+
progressTracker
260+
).compute();
261+
262+
Assertions.assertThat(log.getMessages(TestLog.INFO))
263+
.extracting(removingThreadId())
264+
.extracting(replaceTimings())
265+
.containsExactly(
266+
"boruvka :: Start",
267+
"boruvka 20%",
268+
"boruvka 40%",
269+
"boruvka 60%",
270+
"boruvka 80%",
271+
"boruvka 100%",
272+
"boruvka :: Finished"
273+
);
274+
275+
}
276+
211277
}
212278

279+
213280
}

0 commit comments

Comments
 (0)