Skip to content

Commit b546a0a

Browse files
Track cluster sizes of the condensed tree
Co-authored-by: Ioannis Panagiotas <ioannis.panagiotas@neotechnology.com>
1 parent ee3af06 commit b546a0a

File tree

3 files changed: 31 additions and 2 deletions.

algo/src/main/java/org/neo4j/gds/hdbscan/CondenseStep.java

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,8 @@ CondensedTree condense(ClusterHierarchy clusterHierarchy, long minClusterSize) {
5050
var clusterHierarchyRoot = clusterHierarchy.root();
5151
var parent = HugeLongArray.newArray(clusterHierarchyRoot + 1);
5252
var lambda = HugeDoubleArray.newArray(clusterHierarchyRoot + 1);
53+
var size = HugeLongArray.newArray(nodeCount);
54+
5355

5456
var relabel = HugeLongArray.newArray(nodeCount);
5557
var currentCondensedRoot = nodeCount;
@@ -59,6 +61,8 @@ CondensedTree condense(ClusterHierarchy clusterHierarchy, long minClusterSize) {
5961
var bfsQueue = HugeLongArrayQueue.newQueue(nodeCount);
6062
var visited = HugeAtomicBitSet.create(clusterHierarchyRoot + 1);
6163

64+
size.set(currentCondensedRoot - nodeCount, nodeCount);
65+
6266
for (var i = clusterHierarchyRoot; i >= nodeCount; i--) {
6367
if (visited.get(i)) {
6468
continue;
@@ -88,15 +92,17 @@ CondensedTree condense(ClusterHierarchy clusterHierarchy, long minClusterSize) {
8892
relabel.set(left - nodeCount, leftClusterId);
8993
parent.set(leftClusterId, currentReLabel);
9094
lambda.set(leftClusterId, fallingOutLambda);
95+
size.set(leftClusterId - nodeCount, leftSize);
9196

9297
var rightClusterId = ++currentCondensedMaxClusterId;
9398
relabel.set(right - nodeCount, rightClusterId);
9499
parent.set(rightClusterId, currentReLabel);
95100
lambda.set(rightClusterId, fallingOutLambda);
101+
size.set(rightClusterId - nodeCount, rightSize);
96102
}
97103
}
98104

99-
return new CondensedTree(currentCondensedRoot, parent, lambda, currentCondensedMaxClusterId);
105+
return new CondensedTree(currentCondensedRoot, parent, lambda, size, currentCondensedMaxClusterId, nodeCount);
100106
}
101107

102108
private void fallOut(

algo/src/main/java/org/neo4j/gds/hdbscan/CondensedTree.java

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,13 +27,24 @@ class CondensedTree {
2727
private final long root;
2828
private final HugeLongArray parent;
2929
private final HugeDoubleArray lambda;
30+
private final HugeLongArray size;
3031
private final long maximumClusterId;
32+
private final long nodeCount;
3133

32-
CondensedTree(long root, HugeLongArray parent, HugeDoubleArray lambda, long maximumClusterId) {
34+
/**
 * Creates a condensed tree from already-computed arrays.
 *
 * The constructor only stores its arguments in the corresponding fields; the
 * arrays are kept by reference and are not copied or validated here.
 *
 * @param root             value returned by {@code root()}
 * @param parent           parent array stored as-is
 * @param lambda           lambda array, read by {@code lambda(long)} at index {@code node}
 * @param size             size array, read by {@code size(long)} at index {@code node - nodeCount}
 * @param maximumClusterId value stored as the maximum cluster id
 * @param nodeCount        offset subtracted from the id passed to {@code size(long)}
 */
CondensedTree(
35+
long root,
36+
HugeLongArray parent,
37+
HugeDoubleArray lambda,
38+
HugeLongArray size,
39+
long maximumClusterId,
40+
long nodeCount
41+
) {
3342
this.root = root;
3443
this.parent = parent;
3544
this.lambda = lambda;
45+
this.size = size;
3646
this.maximumClusterId = maximumClusterId;
47+
this.nodeCount = nodeCount;
3748
}
3849

3950
long root() {
@@ -55,4 +66,8 @@ long maximumClusterId() {
5566
/**
 * Returns the lambda value stored for {@code node}.
 *
 * The lookup uses {@code node} directly as the index into the lambda array,
 * with no offset applied.
 *
 * @param node index into the lambda array
 * @return the value held at that index
 */
double lambda(long node) {
5667
return lambda.get(node);
5768
}
69+
70+
/**
 * Returns the size recorded for the cluster with id {@code node}.
 *
 * The size array is indexed relative to {@code nodeCount}, so the lookup
 * subtracts {@code nodeCount} from the given id before reading.
 *
 * NOTE(review): an id below {@code nodeCount} produces a negative index;
 * presumably callers only pass cluster ids (>= nodeCount) - confirm.
 *
 * @param node cluster id; the value at index {@code node - nodeCount} is returned
 * @return the stored size for that cluster id
 */
long size(long node) {
71+
// Shift the cluster id into the zero-based index space of the size array.
return size.get(node - nodeCount);
72+
}
5873
}

algo/src/test/java/org/neo4j/gds/hdbscan/CondenseStepTest.java

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,16 +43,22 @@ void minClusterSizeTwo() {
4343
assertThat(condensedTree.root()).isEqualTo(7L);
4444
assertThat(condensedTree.maximumClusterId()).isEqualTo(11L);
4545

46+
assertThat(condensedTree.size(7L)).isEqualTo(7L);
4647

4748
assertThat(condensedTree.parent(8L)).isEqualTo(7L);
4849
assertThat(condensedTree.lambda(8L)).isEqualTo(12d);
50+
assertThat(condensedTree.size(8L)).isEqualTo(2L);
4951
assertThat(condensedTree.parent(9L)).isEqualTo(7L);
5052
assertThat(condensedTree.lambda(9L)).isEqualTo(12d);
53+
assertThat(condensedTree.size(9L)).isEqualTo(5L);
5154

5255
assertThat(condensedTree.parent(10L)).isEqualTo(9L);
56+
assertThat(condensedTree.size(10L)).isEqualTo(2L);
5357
assertThat(condensedTree.lambda(10L)).isEqualTo(10d);
58+
5459
assertThat(condensedTree.parent(11L)).isEqualTo(9L);
5560
assertThat(condensedTree.lambda(11L)).isEqualTo(10d);
61+
assertThat(condensedTree.size(11L)).isEqualTo(3L);
5662

5763
assertThat(condensedTree.fellOutOf(0L)).isEqualTo(8L);
5864
assertThat(condensedTree.lambda(0L)).isEqualTo(11d);
@@ -88,6 +94,8 @@ void minClusterSizeThree() {
8894
assertThat(condensedTree.root()).isEqualTo(7L);
8995
assertThat(condensedTree.maximumClusterId()).isEqualTo(7L);
9096

97+
assertThat(condensedTree.size(7L)).isEqualTo(7L);
98+
9199
assertThat(condensedTree.fellOutOf(0L)).isEqualTo(7L);
92100
assertThat(condensedTree.lambda(0L)).isEqualTo(12d);
93101
assertThat(condensedTree.fellOutOf(1L)).isEqualTo(7L);

0 commit comments

Comments
 (0)