Skip to content

Commit 970b80a

Browse files
labelling progress tracking
Co-authored-by: Veselin Nikolov <veselin.nikolov@neotechnology.com>
1 parent bc1fd25 commit 970b80a

File tree

5 files changed

+101
-18
lines changed

5 files changed

+101
-18
lines changed

algo/src/main/java/org/neo4j/gds/hdbscan/HDBScan.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -68,9 +68,9 @@ public HugeLongArray compute() {
6868
var coreResult = computeCores(kdTree, nodeCount);
6969
var dualTreeMST = dualTreeMSTPhase(kdTree, coreResult);
7070
var clusterHierarchy = createClusterHierarchy(dualTreeMST);
71-
var condenseStep = new CondenseStep(nodeCount,progressTracker);
71+
var condenseStep = new CondenseStep(nodeCount, progressTracker);
7272
var condensedTree = condenseStep.condense(clusterHierarchy, minClusterSize);
73-
var labellingStep = new LabellingStep(condensedTree, nodeCount);
73+
var labellingStep = new LabellingStep(condensedTree, nodeCount, progressTracker);
7474
return labellingStep.labels();
7575
}
7676

algo/src/main/java/org/neo4j/gds/hdbscan/HDBScanProgressTrackerCreator.java

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,8 @@
2222
import org.neo4j.gds.core.utils.progress.tasks.Task;
2323
import org.neo4j.gds.core.utils.progress.tasks.Tasks;
2424

25+
import java.util.List;
26+
2527
public class HDBScanProgressTrackerCreator {
2628

2729
static Task kdBuildingTask(String name, long nodeCount){
@@ -36,4 +38,15 @@ static Task condenseTask(String name, long nodeCount){
3638
return Tasks.leaf(name,nodeCount - 1);
3739
}
3840

41+
static Task labellingTask(String name, long nodeCount){
42+
return Tasks.task(
43+
name,
44+
List.of(
45+
Tasks.leaf("Stability calculation", nodeCount-1),
46+
Tasks.leaf("cluster selection", nodeCount-1),
47+
Tasks.leaf("labelling", nodeCount + nodeCount-1)
48+
)
49+
);
50+
}
51+
3952
}

algo/src/main/java/org/neo4j/gds/hdbscan/LabellingStep.java

Lines changed: 19 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -22,20 +22,23 @@
2222
import com.carrotsearch.hppc.BitSet;
2323
import org.neo4j.gds.collections.ha.HugeDoubleArray;
2424
import org.neo4j.gds.collections.ha.HugeLongArray;
25+
import org.neo4j.gds.core.utils.progress.tasks.ProgressTracker;
2526

2627
class LabellingStep {
2728

2829
private final CondensedTree condensedTree;
2930
private final long nodeCount;
31+
private final ProgressTracker progressTracker;
3032

31-
LabellingStep(CondensedTree condensedTree, long nodeCount) {
33+
LabellingStep(CondensedTree condensedTree, long nodeCount, ProgressTracker progressTracker) {
3234
this.condensedTree = condensedTree;
3335
this.nodeCount = nodeCount;
36+
this.progressTracker = progressTracker;
3437
}
3538

3639
HugeDoubleArray computeStabilities() {
3740
var result = HugeDoubleArray.newArray(nodeCount - 1);
38-
41+
progressTracker.beginSubTask();
3942
var condensedTreeRoot = this.condensedTree.root();
4043
// process the leaves of the tree
4144
for (var p = 0; p < this.nodeCount; p++) {
@@ -56,8 +59,9 @@ HugeDoubleArray computeStabilities() {
5659
: 1. / condensedTree.lambda(birthPoint);
5760
var sizeP = condensedTree.size(p);
5861
result.addTo(birthPoint - nodeCount, sizeP * (lambdaP - lambdaBirth));
62+
progressTracker.logProgress();
5963
}
60-
64+
progressTracker.endSubTask();
6165
return result;
6266
}
6367

@@ -67,7 +71,7 @@ BitSet selectedClusters(HugeDoubleArray stabilities) {
6771

6872
var condensedTreeRoot = condensedTree.root();
6973
var condensedTreeMaxClusterId = condensedTree.maximumClusterId();
70-
74+
progressTracker.beginSubTask();
7175
var stabilitySums = HugeDoubleArray.newArray(nodeCount);
7276
for (var p = condensedTreeMaxClusterId; p >= condensedTreeRoot; p--) {
7377
var adaptedPIndex = p - nodeCount;
@@ -82,17 +86,19 @@ BitSet selectedClusters(HugeDoubleArray stabilities) {
8286
selectedClusters.set(adaptedPIndex);
8387
// Selected clusters below `p` are implicitly unselected - they will be ignored during- `labeling`
8488
}
89+
progressTracker.logProgress();
8590
if (p == condensedTreeRoot) {
8691
continue;
8792
}
8893
var parent = condensedTree.parent(p);
8994
stabilitySums.addTo(parent - nodeCount, stabilityToAdd);
9095
}
91-
96+
progressTracker.endSubTask();
9297
return selectedClusters;
9398
}
9499

95100
HugeLongArray computeLabels(BitSet selectedClusters) {
101+
progressTracker.beginSubTask();
96102
var labels = HugeLongArray.newArray(nodeCount);
97103
labels.fill(-1L);
98104
var nodeCountLabels = HugeLongArray.newArray(nodeCount);
@@ -107,18 +113,25 @@ HugeLongArray computeLabels(BitSet selectedClusters) {
107113
} else if (selectedClusters.get(adaptedIndex)) {
108114
labels.set(adaptedIndex, adaptedIndex);
109115
}
116+
progressTracker.logProgress();
110117
}
111118

112119
for (var n = 0; n < nodeCount; n++) {
113120
nodeCountLabels.set(n, labels.get(condensedTree.fellOutOf(n) - nodeCount));
121+
progressTracker.logProgress();
114122
}
123+
progressTracker.endSubTask();
124+
115125

116126
return nodeCountLabels;
117127
}
118128

119129
HugeLongArray labels() {
130+
progressTracker.beginSubTask();
120131
var stabilities = computeStabilities();
121132
var selectedClusters = selectedClusters(stabilities);
122-
return computeLabels(selectedClusters);
133+
var labels= computeLabels(selectedClusters);
134+
progressTracker.endSubTask();
135+
return labels;
123136
}
124137
}

algo/src/test/java/org/neo4j/gds/hdbscan/CondenseStepTest.java

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -129,16 +129,17 @@ void minClusterSizeThree() {
129129
@Test
130130
void shouldLogProgress(){
131131
var nodeCount = 7L;
132-
133-
var progressTask = HDBScanProgressTrackerCreator.condenseTask("condense",7);
134-
var log = new GdsTestLog();
135-
var progressTracker = new TaskProgressTracker(progressTask, new LoggerForProgressTrackingAdapter(log), new Concurrency(1), EmptyTaskRegistryFactory.INSTANCE);
136132
var root = 12L;
137133
var left = HugeLongArray.of(5, 4, 2, 9, 0, 11);
138134
var right = HugeLongArray.of(6, 7, 3, 8, 1, 10);
139135
var lambda = HugeDoubleArray.of(7d, 8d, 9d, 10d, 11d, 12d);
140136
var size = HugeLongArray.of(2, 3, 2, 5, 2, 7);
141137

138+
var progressTask = HDBScanProgressTrackerCreator.condenseTask("condense",nodeCount);
139+
var log = new GdsTestLog();
140+
var progressTracker = new TaskProgressTracker(progressTask, new LoggerForProgressTrackingAdapter(log), new Concurrency(1), EmptyTaskRegistryFactory.INSTANCE);
141+
142+
142143
var clusterHierarchy = new ClusterHierarchy(root, left, right, lambda, size, nodeCount);
143144

144145
new CondenseStep(nodeCount,progressTracker).condense(clusterHierarchy, 3L);

algo/src/test/java/org/neo4j/gds/hdbscan/LabellingTest.java

Lines changed: 62 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -20,12 +20,22 @@
2020
package org.neo4j.gds.hdbscan;
2121

2222
import com.carrotsearch.hppc.BitSet;
23+
import org.assertj.core.api.Assertions;
2324
import org.assertj.core.data.Offset;
2425
import org.junit.jupiter.api.Test;
2526
import org.neo4j.gds.collections.ha.HugeDoubleArray;
2627
import org.neo4j.gds.collections.ha.HugeLongArray;
28+
import org.neo4j.gds.compat.TestLog;
29+
import org.neo4j.gds.core.concurrency.Concurrency;
30+
import org.neo4j.gds.core.utils.logging.LoggerForProgressTrackingAdapter;
31+
import org.neo4j.gds.core.utils.progress.EmptyTaskRegistryFactory;
32+
import org.neo4j.gds.core.utils.progress.tasks.ProgressTracker;
33+
import org.neo4j.gds.core.utils.progress.tasks.TaskProgressTracker;
34+
import org.neo4j.gds.logging.GdsTestLog;
2735

2836
import static org.assertj.core.api.Assertions.assertThat;
37+
import static org.neo4j.gds.assertj.Extractors.removingThreadId;
38+
import static org.neo4j.gds.assertj.Extractors.replaceTimings;
2939

3040
class LabellingTest {
3141

@@ -40,7 +50,7 @@ void clusterStability() {
4050
var maximumClusterId = 6;
4151

4252
var condensedTree = new CondensedTree(root, parent, lambda, size, maximumClusterId, nodeCount);
43-
var stabilityStep = new LabellingStep(condensedTree, nodeCount);
53+
var stabilityStep = new LabellingStep(condensedTree, nodeCount, ProgressTracker.NULL_TRACKER);
4454

4555
var stabilities = stabilityStep.computeStabilities();
4656

@@ -66,7 +76,7 @@ void clusterStabilityBiggerTest() {
6676

6777
var condensedTree = new CondensedTree(root, parent, lambda, size, maximumClusterId, nodeCount);
6878

69-
var stabilityStep = new LabellingStep(condensedTree, nodeCount);
79+
var stabilityStep = new LabellingStep(condensedTree, nodeCount,ProgressTracker.NULL_TRACKER);
7080

7181
var stabilities = stabilityStep.computeStabilities();
7282

@@ -102,7 +112,7 @@ void clusterSelectionOfChildClusters() {
102112
var stabilities = HugeDoubleArray.of(3., 4., 5.);
103113

104114
var condensedTree = new CondensedTree(root, parent, lambda, size, maximumClusterId, nodeCount);
105-
var stabilityStep = new LabellingStep(condensedTree, nodeCount);
115+
var stabilityStep = new LabellingStep(condensedTree, nodeCount,ProgressTracker.NULL_TRACKER);
106116

107117
var selectedClusters = stabilityStep.selectedClusters(stabilities);
108118

@@ -133,7 +143,7 @@ void clusterSelectionOfParentCluster() {
133143
var stabilities = HugeDoubleArray.of(10., 4., 5.);
134144

135145
var condensedTree = new CondensedTree(root, parent, lambda, size, maximumClusterId, nodeCount);
136-
var stabilityStep = new LabellingStep(condensedTree, nodeCount);
146+
var stabilityStep = new LabellingStep(condensedTree, nodeCount,ProgressTracker.NULL_TRACKER);
137147

138148
var selectedClusters = stabilityStep.selectedClusters(stabilities);
139149

@@ -164,7 +174,7 @@ void labelling() {
164174
// selects cluster `11`
165175
selectedClusters.set(4);
166176

167-
var stabilityStep = new LabellingStep(condensedTree, nodeCount);
177+
var stabilityStep = new LabellingStep(condensedTree, nodeCount,ProgressTracker.NULL_TRACKER);
168178

169179
var labels = stabilityStep.computeLabels(selectedClusters);
170180

@@ -195,11 +205,57 @@ void labellingWhenAllClustersAreSelected() {
195205
var selectedClusters = new BitSet(5);
196206
selectedClusters.set(0, 5);
197207

198-
var stabilityStep = new LabellingStep(condensedTree, nodeCount);
208+
var stabilityStep = new LabellingStep(condensedTree, nodeCount,ProgressTracker.NULL_TRACKER);
199209

200210
var labels = stabilityStep.computeLabels(selectedClusters);
201211

202212
assertThat(labels.size()).isEqualTo(nodeCount);
203213
assertThat(labels.toArray()).containsOnly(0L);
204214
}
215+
216+
@Test
217+
void shouldLogProgress(){
218+
var nodeCount = 4;
219+
var root = 4;
220+
221+
var parent = HugeLongArray.of(5, 5, 6, 6, 0, 4, 4);
222+
var lambda = HugeDoubleArray.of(10, 10, 11, 11, 0, 12, 12);
223+
var size = HugeLongArray.of(4, 2, 2);
224+
var maximumClusterId = 6;
225+
226+
var condensedTree = new CondensedTree(root, parent, lambda, size, maximumClusterId, nodeCount);
227+
228+
var progressTask = HDBScanProgressTrackerCreator.labellingTask("foo",nodeCount);
229+
var log = new GdsTestLog();
230+
var progressTracker = new TaskProgressTracker(progressTask, new LoggerForProgressTrackingAdapter(log), new Concurrency(1), EmptyTaskRegistryFactory.INSTANCE);
231+
232+
new LabellingStep(condensedTree, nodeCount, progressTracker).labels();
233+
234+
Assertions.assertThat(log.getMessages(TestLog.INFO))
235+
.extracting(removingThreadId())
236+
.extracting(replaceTimings())
237+
.containsExactly(
238+
"foo :: Start",
239+
"foo :: Stability calculation :: Start",
240+
"foo :: Stability calculation 33%",
241+
"foo :: Stability calculation 66%",
242+
"foo :: Stability calculation 100%",
243+
"foo :: Stability calculation :: Finished",
244+
"foo :: cluster selection :: Start",
245+
"foo :: cluster selection 33%",
246+
"foo :: cluster selection 66%",
247+
"foo :: cluster selection 100%",
248+
"foo :: cluster selection :: Finished",
249+
"foo :: labelling :: Start",
250+
"foo :: labelling 14%",
251+
"foo :: labelling 28%",
252+
"foo :: labelling 42%",
253+
"foo :: labelling 57%",
254+
"foo :: labelling 71%",
255+
"foo :: labelling 85%",
256+
"foo :: labelling 100%",
257+
"foo :: labelling :: Finished",
258+
"foo :: Finished"
259+
);
260+
}
205261
}

0 commit comments

Comments
 (0)