Skip to content

Commit c1211ba

Browse files
KD-Tree progress tracking
Co-authored-by: Veselin Nikolov <veselin.nikolov@neotechnology.com>
1 parent d623614 commit c1211ba

File tree

8 files changed

+184
-19
lines changed

8 files changed

+184
-19
lines changed

algo/src/main/java/org/neo4j/gds/hdbscan/HDBScan.java

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -104,7 +104,13 @@ ClusterHierarchy createClusterHierarchy(DualTreeMSTResult dualTreeMSTResult){
104104
}
105105

106106
KdTree buildKDTree() {
107-
var builder = new KdTreeBuilder(nodes, nodePropertyValues, concurrency.value(), leafSize);
107+
var builder = new KdTreeBuilder(
108+
nodes,
109+
nodePropertyValues,
110+
concurrency.value(),
111+
leafSize,
112+
progressTracker
113+
);
108114
return builder.build();
109115
}
110116
}
Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
/*
2+
* Copyright (c) "Neo4j"
3+
* Neo4j Sweden AB [http://neo4j.com]
4+
*
5+
* This file is part of Neo4j.
6+
*
7+
* Neo4j is free software: you can redistribute it and/or modify
8+
* it under the terms of the GNU General Public License as published by
9+
* the Free Software Foundation, either version 3 of the License, or
10+
* (at your option) any later version.
11+
*
12+
* This program is distributed in the hope that it will be useful,
13+
* but WITHOUT ANY WARRANTY; without even the implied warranty of
14+
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
15+
* GNU General Public License for more details.
16+
*
17+
* You should have received a copy of the GNU General Public License
18+
* along with this program. If not, see <http://www.gnu.org/licenses/>.
19+
*/
20+
package org.neo4j.gds.hdbscan;
21+
22+
import org.neo4j.gds.core.utils.progress.tasks.Task;
23+
import org.neo4j.gds.core.utils.progress.tasks.Tasks;
24+
25+
public class HDBScanProgressTrackerCreator {
26+
27+
static Task kdBuildingTask(String name, long nodeCount){
28+
return Tasks.leaf(name,nodeCount);
29+
}
30+
31+
}

algo/src/main/java/org/neo4j/gds/hdbscan/KdTreeBuilder.java

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
import org.neo4j.gds.api.IdMap;
2323
import org.neo4j.gds.api.properties.nodes.NodePropertyValues;
2424
import org.neo4j.gds.collections.ha.HugeLongArray;
25+
import org.neo4j.gds.core.utils.progress.tasks.ProgressTracker;
2526

2627
import java.util.concurrent.atomic.AtomicInteger;
2728

@@ -31,25 +32,37 @@ public class KdTreeBuilder {
3132
private final NodePropertyValues nodePropertyValues;
3233
private final int concurrency;
3334
private final long leafSize;
35+
private final ProgressTracker progressTracker;
3436

35-
public KdTreeBuilder(IdMap nodes, NodePropertyValues nodePropertyValues,int concurrency,long leafSize) {
37+
public KdTreeBuilder(IdMap nodes, NodePropertyValues nodePropertyValues,int concurrency,long leafSize, ProgressTracker progressTracker) {
3638
this.nodes = nodes;
3739
this.nodePropertyValues = nodePropertyValues;
3840
this. concurrency = concurrency;
3941
this.leafSize = leafSize;
42+
this.progressTracker = progressTracker;
4043
}
4144

4245
public KdTree build(){
4346

4447
var ids = HugeLongArray.newArray(nodes.nodeCount());
4548
ids.setAll( v-> v);
4649
AtomicInteger nodeIndex = new AtomicInteger(0);
47-
var builderTask = new KdTreeNodeBuilderTask(ids,nodePropertyValues,0,nodePropertyValues.nodeCount(),leafSize,false,null,
48-
nodeIndex
50+
var builderTask = new KdTreeNodeBuilderTask(
51+
ids,
52+
nodePropertyValues,
53+
0,
54+
nodePropertyValues.nodeCount(),
55+
leafSize,
56+
false,
57+
null,
58+
nodeIndex,
59+
progressTracker
4960
);
61+
62+
progressTracker.beginSubTask();
5063
builderTask.run();
5164
var root = builderTask.kdNode();
52-
65+
progressTracker.endSubTask();
5366
return new KdTree(ids, nodePropertyValues, root, nodeIndex.get());
5467
}
5568

algo/src/main/java/org/neo4j/gds/hdbscan/KdTreeNodeBuilderTask.java

Lines changed: 31 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121

2222
import org.neo4j.gds.api.properties.nodes.NodePropertyValues;
2323
import org.neo4j.gds.collections.ha.HugeLongArray;
24+
import org.neo4j.gds.core.utils.progress.tasks.ProgressTracker;
2425

2526
import java.util.concurrent.atomic.AtomicInteger;
2627
import java.util.function.LongToDoubleFunction;
@@ -37,14 +38,19 @@ class KdTreeNodeBuilderTask implements Runnable {
3738
private final boolean amLeftChild;
3839
private final KdNode parent;
3940
private final AtomicInteger nodeIndex;
41+
private final ProgressTracker progressTracker;
4042

4143

4244
KdTreeNodeBuilderTask(
4345
HugeLongArray ids,
4446
NodePropertyValues nodePropertyValues,
4547
long start,
4648
long end,
47-
long maxLeafSize, boolean amLeftChild, KdNode parent, AtomicInteger nodeIndex
49+
long maxLeafSize,
50+
boolean amLeftChild,
51+
KdNode parent,
52+
AtomicInteger nodeIndex,
53+
ProgressTracker progressTracker
4854
) {
4955
this.ids = ids;
5056
this.nodePropertyValues = nodePropertyValues;
@@ -55,6 +61,7 @@ class KdTreeNodeBuilderTask implements Runnable {
5561
this.amLeftChild = amLeftChild;
5662
this.parent = parent;
5763
this.nodeIndex = nodeIndex;
64+
this.progressTracker = progressTracker;
5865
}
5966

6067
@Override
@@ -64,7 +71,7 @@ public void run() {
6471
var treeNodeId = nodeIndex.getAndIncrement();
6572
if (nodeSize <= maxLeafSize) {
6673
kdNode = KdNode.createLeaf(treeNodeId, start, end, aabb);
67-
74+
progressTracker.logProgress(nodeSize);
6875
} else {
6976

7077
int indexToSplit = aabb.mostSpreadDimension(); //step. 1: find the index to dimension split
@@ -73,10 +80,30 @@ public void run() {
7380

7481
kdNode = KdNode.createSplitNode(treeNodeId, start,end,aabb,new SplitInformation(medianValue,indexToSplit));
7582
//TODO: step.4 add these builder tasks into a fork-join
76-
var leftChildBuilder = new KdTreeNodeBuilderTask(ids, nodePropertyValues, start, median, maxLeafSize,true,kdNode, nodeIndex);
83+
var leftChildBuilder = new KdTreeNodeBuilderTask(ids,
84+
nodePropertyValues,
85+
start,
86+
median,
87+
maxLeafSize,
88+
true,
89+
kdNode,
90+
nodeIndex,
91+
progressTracker
92+
);
7793
leftChildBuilder.run();
7894

79-
var rightChildBuilder = new KdTreeNodeBuilderTask(ids, nodePropertyValues, median, end, maxLeafSize,false,kdNode, nodeIndex);
95+
var rightChildBuilder = new KdTreeNodeBuilderTask(
96+
ids,
97+
nodePropertyValues,
98+
median,
99+
end,
100+
maxLeafSize,
101+
false,
102+
kdNode,
103+
nodeIndex,
104+
progressTracker
105+
);
106+
80107
rightChildBuilder.run();
81108

82109
var leftChild = leftChildBuilder.kdNode();

algo/src/test/java/org/neo4j/gds/hdbscan/DualTreeMSTAlgorithmTest.java

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
import org.assertj.core.data.Offset;
2323
import org.junit.jupiter.api.Nested;
2424
import org.junit.jupiter.api.Test;
25+
import org.neo4j.gds.core.utils.progress.tasks.ProgressTracker;
2526
import org.neo4j.gds.extension.GdlExtension;
2627
import org.neo4j.gds.extension.GdlGraph;
2728
import org.neo4j.gds.extension.Inject;
@@ -58,7 +59,7 @@ class Case1 {
5859
@Test
5960
void shouldReturnEuclideanMSTWithZeroCoreValues() {
6061
var nodePropertyValues = graph.nodeProperties("point");
61-
var kdTree = new KdTreeBuilder(graph, nodePropertyValues, 1, 1).build();
62+
var kdTree = new KdTreeBuilder(graph, nodePropertyValues, 1, 1,ProgressTracker.NULL_TRACKER).build();
6263

6364
var dualTree = DualTreeMSTAlgorithm.createWithZeroCores(
6465
nodePropertyValues,
@@ -113,7 +114,7 @@ class Case2 {
113114
@Test
114115
void shouldReturnEuclideanMSTWithZeroCoreValues() {
115116
var nodePropertyValues = graph.nodeProperties("point");
116-
var kdTree = new KdTreeBuilder(graph, nodePropertyValues, 1, 1).build();
117+
var kdTree = new KdTreeBuilder(graph, nodePropertyValues, 1, 1, ProgressTracker.NULL_TRACKER).build();
117118

118119
var dualTree = DualTreeMSTAlgorithm.createWithZeroCores(
119120
nodePropertyValues,
@@ -168,7 +169,7 @@ class Case3 {
168169
@Test
169170
void shouldReturnEuclideanMSTWithZeroCoreValues() {
170171
var nodePropertyValues = graph.nodeProperties("point");
171-
var kdTree = new KdTreeBuilder(graph, nodePropertyValues, 1, 1).build();
172+
var kdTree = new KdTreeBuilder(graph, nodePropertyValues, 1, 1,ProgressTracker.NULL_TRACKER).build();
172173

173174
var dualTree = DualTreeMSTAlgorithm.createWithZeroCores(
174175
nodePropertyValues,

algo/src/test/java/org/neo4j/gds/hdbscan/KdTreeBuilderTest.java

Lines changed: 40 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,13 +19,23 @@
1919
*/
2020
package org.neo4j.gds.hdbscan;
2121

22+
import org.assertj.core.api.Assertions;
2223
import org.junit.jupiter.api.Test;
24+
import org.neo4j.gds.compat.TestLog;
25+
import org.neo4j.gds.core.concurrency.Concurrency;
26+
import org.neo4j.gds.core.utils.logging.LoggerForProgressTrackingAdapter;
27+
import org.neo4j.gds.core.utils.progress.EmptyTaskRegistryFactory;
28+
import org.neo4j.gds.core.utils.progress.tasks.ProgressTracker;
29+
import org.neo4j.gds.core.utils.progress.tasks.TaskProgressTracker;
2330
import org.neo4j.gds.extension.GdlExtension;
2431
import org.neo4j.gds.extension.GdlGraph;
2532
import org.neo4j.gds.extension.Inject;
2633
import org.neo4j.gds.extension.TestGraph;
34+
import org.neo4j.gds.logging.GdsTestLog;
2735

2836
import static org.assertj.core.api.Assertions.assertThat;
37+
import static org.neo4j.gds.assertj.Extractors.removingThreadId;
38+
import static org.neo4j.gds.assertj.Extractors.replaceTimings;
2939

3040
@GdlExtension
3141
class KdTreeBuilderTest {
@@ -52,7 +62,9 @@ void shouldCreateKdTree() {
5262
.isZero();
5363

5464
var points = graph.nodeProperties("point");
55-
var kdTree = new KdTreeBuilder(graph, points, 1, 1).build();
65+
var kdTree = new KdTreeBuilder(graph, points, 1, 1, ProgressTracker.NULL_TRACKER)
66+
.build();
67+
5668
assertThat(kdTree).isNotNull();
5769

5870
var root = kdTree.root();
@@ -183,4 +195,31 @@ void shouldCreateKdTree() {
183195
assertThat(kdTree.treeNodeCount()).isEqualTo(11);
184196
}
185197

198+
@Test
199+
void shouldLogProgress(){
200+
201+
var progressTask = HDBScanProgressTrackerCreator.kdBuildingTask("foo",graph.nodeCount());
202+
var log = new GdsTestLog();
203+
var progressTracker = new TaskProgressTracker(progressTask, new LoggerForProgressTrackingAdapter(log), new Concurrency(1), EmptyTaskRegistryFactory.INSTANCE);
204+
var points = graph.nodeProperties("point");
205+
206+
new KdTreeBuilder(graph, points, 1, 1, progressTracker)
207+
.build();
208+
209+
Assertions.assertThat(log.getMessages(TestLog.INFO))
210+
.extracting(removingThreadId())
211+
.extracting(replaceTimings())
212+
.containsExactly(
213+
"foo :: Start",
214+
"foo 16%",
215+
"foo 33%",
216+
"foo 50%",
217+
"foo 66%",
218+
"foo 83%",
219+
"foo 100%",
220+
"foo :: Finished"
221+
);
222+
223+
}
224+
186225
}

algo/src/test/java/org/neo4j/gds/hdbscan/KdTreeNodeBuilderTaskTest.java

Lines changed: 47 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
import org.junit.jupiter.api.Test;
2323
import org.neo4j.gds.api.properties.nodes.DoubleArrayNodePropertyValues;
2424
import org.neo4j.gds.collections.ha.HugeLongArray;
25+
import org.neo4j.gds.core.utils.progress.tasks.ProgressTracker;
2526

2627
import java.util.Random;
2728
import java.util.concurrent.atomic.AtomicInteger;
@@ -49,7 +50,15 @@ public long nodeCount() {
4950
return ids.size();
5051
}
5152
};
52-
var nodeBuilder =new KdTreeNodeBuilderTask(ids,nodePropertyValues,0,10,1,false,null, new AtomicInteger(0));
53+
var nodeBuilder =new KdTreeNodeBuilderTask(ids,
54+
nodePropertyValues,
55+
0,10,1,
56+
false,
57+
null,
58+
new AtomicInteger(0),
59+
ProgressTracker.NULL_TRACKER
60+
);
61+
5362
var median = nodeBuilder.findMedianAndSplit(0);
5463
assertThat(median).isEqualTo(5L);
5564
assertThat(ids.toArray()).containsExactlyInAnyOrder(0,1,2,3,4,5,6,7,8,9);
@@ -73,7 +82,19 @@ public long nodeCount() {
7382
return ids.size();
7483
}
7584
};
76-
var nodeBuilder =new KdTreeNodeBuilderTask(ids,nodePropertyValues,4,8,1,false,null, new AtomicInteger() );
85+
86+
var nodeBuilder =new KdTreeNodeBuilderTask(
87+
ids,
88+
nodePropertyValues,
89+
4,
90+
8,
91+
1,
92+
false,
93+
null,
94+
new AtomicInteger(),
95+
ProgressTracker.NULL_TRACKER
96+
);
97+
7798
var median = nodeBuilder.findMedianAndSplit(0);
7899
assertThat(median).isEqualTo(6L);
79100
for (int i=4;i<6;++i){
@@ -94,7 +115,18 @@ public long nodeCount() {
94115
return ids.size();
95116
}
96117
};
97-
var nodeBuilder =new KdTreeNodeBuilderTask(ids,nodePropertyValues,0,3,3,false,null, new AtomicInteger());
118+
var nodeBuilder =new KdTreeNodeBuilderTask(
119+
ids,
120+
nodePropertyValues,
121+
0,
122+
3,
123+
3,
124+
false,
125+
null,
126+
new AtomicInteger(),
127+
ProgressTracker.NULL_TRACKER
128+
);
129+
98130
nodeBuilder.run();
99131

100132
var node = nodeBuilder.kdNode();
@@ -120,7 +152,18 @@ public long nodeCount() {
120152
return ids.size();
121153
}
122154
};
123-
var nodeBuilder =new KdTreeNodeBuilderTask(ids,nodePropertyValues,0,3,2,false,null, new AtomicInteger(0));
155+
var nodeBuilder =new KdTreeNodeBuilderTask(
156+
ids,
157+
nodePropertyValues,
158+
0,
159+
3,
160+
2,
161+
false,
162+
null,
163+
new AtomicInteger(0),
164+
ProgressTracker.NULL_TRACKER
165+
);
166+
124167
nodeBuilder.run();
125168

126169
var node = nodeBuilder.kdNode();

algo/src/test/java/org/neo4j/gds/hdbscan/KdTreeTest.java

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
package org.neo4j.gds.hdbscan;
2121

2222
import org.junit.jupiter.api.Test;
23+
import org.neo4j.gds.core.utils.progress.tasks.ProgressTracker;
2324
import org.neo4j.gds.extension.GdlExtension;
2425
import org.neo4j.gds.extension.GdlGraph;
2526
import org.neo4j.gds.extension.Inject;
@@ -49,7 +50,9 @@ class KdTreeTest {
4950
void shouldFindNeighbours() {
5051

5152
var points = graph.nodeProperties("point");
52-
var kdTree = new KdTreeBuilder(graph, points, 1, 1).build();
53+
var kdTree = new KdTreeBuilder(graph, points, 1, 1, ProgressTracker.NULL_TRACKER)
54+
.build();
55+
5356
var queryPoint = new double[]{9d, 2d};
5457
var neighbours = kdTree.neighbours(queryPoint, 2);
5558
assertThat(neighbours)
@@ -71,7 +74,9 @@ void shouldFindNeighbours() {
7174
void shouldNotFindItself() {
7275

7376
var points = graph.nodeProperties("point");
74-
var kdTree = new KdTreeBuilder(graph, points, 1, 1).build();
77+
var kdTree = new KdTreeBuilder(graph, points, 1, 1,ProgressTracker.NULL_TRACKER)
78+
.build();
79+
7580
var neighbours = kdTree.neighbours(graph.toMappedNodeId("a"), 2).neighbours();
7681
assertThat(neighbours)
7782
.isNotNull()

0 commit comments

Comments
 (0)