Skip to content

Commit 4e31174

Browse files
Cluster hierarchy progress tracking
Co-authored-by: Veselin Nikolov <veselin.nikolov@neotechnology.com>
1 parent c1211ba commit 4e31174

File tree

4 files changed

+60
-6
lines changed

4 files changed

+60
-6
lines changed

algo/src/main/java/org/neo4j/gds/hdbscan/ClusterHierarchy.java

Lines changed: 6 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -22,6 +22,7 @@
2222
import org.neo4j.gds.collections.ha.HugeDoubleArray;
2323
import org.neo4j.gds.collections.ha.HugeLongArray;
2424
import org.neo4j.gds.collections.ha.HugeObjectArray;
25+
import org.neo4j.gds.core.utils.progress.tasks.ProgressTracker;
2526

2627
import java.util.function.Function;
2728

@@ -49,7 +50,7 @@ final class ClusterHierarchy {
4950
this.nodeCount = nodeCount;
5051
}
5152

52-
static ClusterHierarchy create(long nodeCount, HugeObjectArray<Edge> edges) {
53+
static ClusterHierarchy create(long nodeCount, HugeObjectArray<Edge> edges, ProgressTracker progressTracker) {
5354
var left = HugeLongArray.newArray(nodeCount);
5455
var right = HugeLongArray.newArray(nodeCount);
5556
var lambda = HugeDoubleArray.newArray(nodeCount);
@@ -61,6 +62,7 @@ static ClusterHierarchy create(long nodeCount, HugeObjectArray<Edge> edges) {
6162

6263
var sizeFn = (Function<Long, Long>) n -> n < nodeCount ? 1L : size.get(n - nodeCount);
6364

65+
progressTracker.beginSubTask();
6466
for (var i = 0; i < edges.size(); i++) {
6567
var edge = edges.get(i);
6668
var l = unionFind.find(edge.source());
@@ -76,8 +78,10 @@ static ClusterHierarchy create(long nodeCount, HugeObjectArray<Edge> edges) {
7678
var rightSize = sizeFn.apply(r);
7779

7880
size.set(adaptedIndex, leftSize + rightSize);
79-
}
8081

82+
progressTracker.logProgress();
83+
}
84+
progressTracker.endSubTask();;
8185
return new ClusterHierarchy(currentRoot, left, right, lambda, size, nodeCount);
8286
}
8387

algo/src/main/java/org/neo4j/gds/hdbscan/HDBScan.java

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -100,7 +100,7 @@ DualTreeMSTResult dualTreeMSTPhase(KdTree kdTree, CoreResult coreResult) {
100100
ClusterHierarchy createClusterHierarchy(DualTreeMSTResult dualTreeMSTResult){
101101
var edges = dualTreeMSTResult.edges();
102102
HugeSerialObjectMergeSort.sort(Edge.class, edges, Edge::distance);
103-
return ClusterHierarchy.create(nodes.nodeCount(),edges);
103+
return ClusterHierarchy.create(nodes.nodeCount(),edges,progressTracker);
104104
}
105105

106106
KdTree buildKDTree() {

algo/src/main/java/org/neo4j/gds/hdbscan/HDBScanProgressTrackerCreator.java

Lines changed: 5 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -25,7 +25,11 @@
2525
public class HDBScanProgressTrackerCreator {
2626

2727
static Task kdBuildingTask(String name, long nodeCount){
28-
return Tasks.leaf(name,nodeCount);
28+
return Tasks.leaf(name, nodeCount);
29+
}
30+
31+
// Creates the leaf progress task for the cluster-hierarchy phase.
// The task volume is nodeCount - 1; presumably this matches the number of
// edges processed by ClusterHierarchy.create (one progress unit per edge) —
// TODO confirm against the caller that supplies the edges.
static Task hierarchyTask(String name, long nodeCount){
32+
return Tasks.leaf(name,nodeCount-1);
2933
}
3034

3135
}

algo/src/test/java/org/neo4j/gds/hdbscan/ClusterHierarchyTest.java

Lines changed: 48 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -19,12 +19,23 @@
1919
*/
2020
package org.neo4j.gds.hdbscan;
2121

22+
import org.assertj.core.api.Assertions;
2223
import org.assertj.core.api.SoftAssertions;
2324
import org.assertj.core.api.junit.jupiter.SoftAssertionsExtension;
2425
import org.assertj.core.data.Offset;
2526
import org.junit.jupiter.api.Test;
2627
import org.junit.jupiter.api.extension.ExtendWith;
2728
import org.neo4j.gds.collections.ha.HugeObjectArray;
29+
import org.neo4j.gds.compat.TestLog;
30+
import org.neo4j.gds.core.concurrency.Concurrency;
31+
import org.neo4j.gds.core.utils.logging.LoggerForProgressTrackingAdapter;
32+
import org.neo4j.gds.core.utils.progress.EmptyTaskRegistryFactory;
33+
import org.neo4j.gds.core.utils.progress.tasks.ProgressTracker;
34+
import org.neo4j.gds.core.utils.progress.tasks.TaskProgressTracker;
35+
import org.neo4j.gds.logging.GdsTestLog;
36+
37+
import static org.neo4j.gds.assertj.Extractors.removingThreadId;
38+
import static org.neo4j.gds.assertj.Extractors.replaceTimings;
2839

2940
@ExtendWith(SoftAssertionsExtension.class)
3041
class ClusterHierarchyTest {
@@ -36,7 +47,11 @@ void shouldWorkWithLineGraph(SoftAssertions assertions) {
3647
new Edge(0, 1, 5.)
3748
);
3849

39-
var clusterHierarchy = ClusterHierarchy.create(3, edges);
50+
var clusterHierarchy = ClusterHierarchy.create(
51+
3,
52+
edges,
53+
ProgressTracker.NULL_TRACKER
54+
);
4055

4156
// 1. `1` and `2` are joined and create new id = 3 --> first set
4257
// 2. `0` and `3` are joined and create new id = 4 --> second set
@@ -71,7 +86,11 @@ void shouldWorkOnMoreComplexTree(SoftAssertions assertions) {
7186
new Edge(5, 7, 1.42823558)
7287
);
7388

74-
var clusterHierarchy = ClusterHierarchy.create(edges.size() + 1, edges);
89+
var clusterHierarchy = ClusterHierarchy.create(
90+
edges.size() + 1,
91+
edges,
92+
ProgressTracker.NULL_TRACKER
93+
);
7594

7695

7796
assertions.assertThat(clusterHierarchy.left(11)).isEqualTo(2);
@@ -124,4 +143,31 @@ void shouldWorkOnMoreComplexTree(SoftAssertions assertions) {
124143
assertions.assertThat(clusterHierarchy.lambda(20)).isCloseTo(1.42823558, Offset.offset(1e-9));
125144
assertions.assertThat(clusterHierarchy.size(20)).isEqualTo(11);
126145
}
146+
147+
// Checks that ClusterHierarchy.create reports progress through the tracker it is given.
// The task is built via HDBScanProgressTrackerCreator.hierarchyTask("foo", 3) and the
// hierarchy is created from two edges over three nodes; the INFO log (thread ids removed,
// timings replaced) must then contain exactly: Start, 50%, 100%, Finished for task "foo".
// NOTE(review): the returned clusterHierarchy is not inspected here; only the log is asserted.
@Test
148+
void shouldLogProgress(){
149+
var edges = HugeObjectArray.of(
150+
new Edge(1, 2, 3.),
151+
new Edge(0, 1, 5.)
152+
);
153+
154+
var progressTask = HDBScanProgressTrackerCreator.hierarchyTask("foo",3);
155+
var log = new GdsTestLog();
156+
var progressTracker = new TaskProgressTracker(progressTask, new LoggerForProgressTrackingAdapter(log), new Concurrency(1), EmptyTaskRegistryFactory.INSTANCE);
157+
var clusterHierarchy = ClusterHierarchy.create(
158+
3,
159+
edges,
160+
progressTracker
161+
);
162+
163+
Assertions.assertThat(log.getMessages(TestLog.INFO))
164+
.extracting(removingThreadId())
165+
.extracting(replaceTimings())
166+
.containsExactly(
167+
"foo :: Start",
168+
"foo 50%",
169+
"foo 100%",
170+
"foo :: Finished"
171+
);
172+
}
127173
}

0 commit comments

Comments
 (0)