Skip to content

Commit b67e716

Browse files
Label clusters
Co-authored-by: Ioannis Panagiotas <ioannis.panagiotas@neotechnology.com>
1 parent ba0551d commit b67e716

File tree

2 files changed

+110
-14
lines changed

2 files changed

+110
-14
lines changed

algo/src/main/java/org/neo4j/gds/hdbscan/StabilityStep.java renamed to algo/src/main/java/org/neo4j/gds/hdbscan/LabellingStep.java

Lines changed: 45 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -21,14 +21,24 @@
2121

2222
import com.carrotsearch.hppc.BitSet;
2323
import org.neo4j.gds.collections.ha.HugeDoubleArray;
24+
import org.neo4j.gds.collections.ha.HugeLongArray;
2425

25-
class StabilityStep {
26-
HugeDoubleArray computeStabilities(CondensedTree condensedTree, long nodeCount) {
26+
class LabellingStep {
27+
28+
private final CondensedTree condensedTree;
29+
private final long nodeCount;
30+
31+
LabellingStep(CondensedTree condensedTree, long nodeCount) {
32+
this.condensedTree = condensedTree;
33+
this.nodeCount = nodeCount;
34+
}
35+
36+
HugeDoubleArray computeStabilities() {
2737
var result = HugeDoubleArray.newArray(nodeCount - 1);
2838

29-
var condensedTreeRoot = condensedTree.root();
39+
var condensedTreeRoot = this.condensedTree.root();
3040
// process the leaves of the tree
31-
for (var p = 0; p < nodeCount; p++) {
41+
for (var p = 0; p < this.nodeCount; p++) {
3242
var lambdaP = 1. / condensedTree.lambda(p);
3343
var birthPoint = condensedTree.fellOutOf(p);
3444
var lambdaBirth = birthPoint == condensedTreeRoot
@@ -51,7 +61,7 @@ HugeDoubleArray computeStabilities(CondensedTree condensedTree, long nodeCount)
5161
return result;
5262
}
5363

54-
BitSet selectedClusters(CondensedTree condensedTree, HugeDoubleArray stabilities, long nodeCount) {
64+
BitSet selectedClusters(HugeDoubleArray stabilities) {
5565

5666
var selectedClusters = new BitSet(nodeCount);
5767

@@ -81,4 +91,34 @@ BitSet selectedClusters(CondensedTree condensedTree, HugeDoubleArray stabilities
8191

8292
return selectedClusters;
8393
}
94+
95+
HugeLongArray computeLabels(BitSet selectedClusters) {
96+
var labels = HugeLongArray.newArray(nodeCount);
97+
labels.fill(-1L);
98+
var nodeCountLabels = HugeLongArray.newArray(nodeCount);
99+
var root = condensedTree.root();
100+
var maximumClusterId = condensedTree.maximumClusterId();
101+
for (var p = root; p <= maximumClusterId; p++) {
102+
var adaptedIndex = p - nodeCount;
103+
var parent = condensedTree.parent(p);
104+
long parentLabel = p == root ? -1L : labels.get(parent - nodeCount);
105+
if (parentLabel != -1L) {
106+
labels.set(adaptedIndex, parentLabel);
107+
} else if (selectedClusters.get(adaptedIndex)) {
108+
labels.set(adaptedIndex, adaptedIndex);
109+
}
110+
}
111+
112+
for (var n = 0; n < nodeCount; n++) {
113+
nodeCountLabels.set(n, labels.get(condensedTree.fellOutOf(n) - nodeCount));
114+
}
115+
116+
return nodeCountLabels;
117+
}
118+
119+
HugeLongArray label() {
120+
var stabilities = computeStabilities();
121+
var selectedClusters = selectedClusters(stabilities);
122+
return computeLabels(selectedClusters);
123+
}
84124
}

algo/src/test/java/org/neo4j/gds/hdbscan/StabilityStepTest.java renamed to algo/src/test/java/org/neo4j/gds/hdbscan/LabellingTest.java

Lines changed: 65 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -19,14 +19,15 @@
1919
*/
2020
package org.neo4j.gds.hdbscan;
2121

22+
import com.carrotsearch.hppc.BitSet;
2223
import org.assertj.core.data.Offset;
2324
import org.junit.jupiter.api.Test;
2425
import org.neo4j.gds.collections.ha.HugeDoubleArray;
2526
import org.neo4j.gds.collections.ha.HugeLongArray;
2627

2728
import static org.assertj.core.api.Assertions.assertThat;
2829

29-
class StabilityStepTest {
30+
class LabellingTest {
3031

3132
@Test
3233
void clusterStability() {
@@ -39,9 +40,9 @@ void clusterStability() {
3940
var maximumClusterId = 6;
4041

4142
var condensedTree = new CondensedTree(root, parent, lambda, size, maximumClusterId, nodeCount);
42-
var stabilityStep = new StabilityStep();
43+
var stabilityStep = new LabellingStep(condensedTree, nodeCount);
4344

44-
var stabilities = stabilityStep.computeStabilities(condensedTree, nodeCount);
45+
var stabilities = stabilityStep.computeStabilities();
4546

4647

4748
assertThat(stabilities.toArray()).containsExactly(
@@ -65,9 +66,9 @@ void clusterStabilityBiggerTest() {
6566

6667
var condensedTree = new CondensedTree(root, parent, lambda, size, maximumClusterId, nodeCount);
6768

68-
var stabilityStep = new StabilityStep();
69+
var stabilityStep = new LabellingStep(condensedTree, nodeCount);
6970

70-
var stabilities = stabilityStep.computeStabilities(condensedTree, nodeCount);
71+
var stabilities = stabilityStep.computeStabilities();
7172

7273
assertThat(stabilities.toArray()).containsExactly(
7374
new double[] {
@@ -101,9 +102,9 @@ void clusterSelectionOfChildClusters() {
101102
var stabilities = HugeDoubleArray.of(3., 4., 5.);
102103

103104
var condensedTree = new CondensedTree(root, parent, lambda, size, maximumClusterId, nodeCount);
104-
var stabilityStep = new StabilityStep();
105+
var stabilityStep = new LabellingStep(condensedTree, nodeCount);
105106

106-
var selectedClusters = stabilityStep.selectedClusters(condensedTree, stabilities, nodeCount);
107+
var selectedClusters = stabilityStep.selectedClusters(stabilities);
107108

108109
assertThat(selectedClusters.get(0))
109110
.withFailMessage("Root should be unselected")
@@ -132,9 +133,9 @@ void clusterSelectionOfParentCluster() {
132133
var stabilities = HugeDoubleArray.of(10., 4., 5.);
133134

134135
var condensedTree = new CondensedTree(root, parent, lambda, size, maximumClusterId, nodeCount);
135-
var stabilityStep = new StabilityStep();
136+
var stabilityStep = new LabellingStep(condensedTree, nodeCount);
136137

137-
var selectedClusters = stabilityStep.selectedClusters(condensedTree, stabilities, nodeCount);
138+
var selectedClusters = stabilityStep.selectedClusters(stabilities);
138139

139140
assertThat(selectedClusters.get(0))
140141
.withFailMessage("Root should be selected")
@@ -146,4 +147,59 @@ void clusterSelectionOfParentCluster() {
146147
.withFailMessage("Second child should be selected")
147148
.isTrue();
148149
}
150+
151+
@Test
152+
void labelling() {
153+
var parent = HugeLongArray.of(8, 8, 10, 10, 11, 11, 11, 0, 7, 7, 9, 9, 0);
154+
var lambda = HugeDoubleArray.of(11.0, 11.0, 9.0, 9.0, 8.0, 7.0, 7.0, 0.0, 12.0, 12.0, 10.0, 10.0, 0.0);
155+
var size = HugeLongArray.of(7, 2, 5, 2, 3, 0, 0);
156+
var maximumClusterId = 11;
157+
var nodeCount = 7;
158+
var root = 7;
159+
160+
var condensedTree = new CondensedTree(root, parent, lambda, size, maximumClusterId, nodeCount);
161+
var selectedClusters = new BitSet(5);
162+
// selects cluster `8`
163+
selectedClusters.set(1);
164+
// selects cluster `11`
165+
selectedClusters.set(4);
166+
167+
var stabilityStep = new LabellingStep(condensedTree, nodeCount);
168+
169+
var labels = stabilityStep.computeLabels(selectedClusters);
170+
171+
assertThat(labels.size()).isEqualTo(nodeCount);
172+
173+
assertThat(labels.get(0)).isEqualTo(1L);
174+
assertThat(labels.get(1)).isEqualTo(1L);
175+
176+
assertThat(labels.get(2)).isEqualTo(-1L);
177+
assertThat(labels.get(3)).isEqualTo(-1L);
178+
179+
assertThat(labels.get(4)).isEqualTo(4L);
180+
assertThat(labels.get(5)).isEqualTo(4L);
181+
assertThat(labels.get(6)).isEqualTo(4L);
182+
183+
}
184+
185+
@Test
186+
void labellingWhenAllClustersAreSelected() {
187+
var parent = HugeLongArray.of(8, 8, 10, 10, 11, 11, 11, 0, 7, 7, 9, 9, 0);
188+
var lambda = HugeDoubleArray.of(11.0, 11.0, 9.0, 9.0, 8.0, 7.0, 7.0, 0.0, 12.0, 12.0, 10.0, 10.0, 0.0);
189+
var size = HugeLongArray.of(7, 2, 5, 2, 3, 0, 0);
190+
var maximumClusterId = 11;
191+
var nodeCount = 7;
192+
var root = 7;
193+
194+
var condensedTree = new CondensedTree(root, parent, lambda, size, maximumClusterId, nodeCount);
195+
var selectedClusters = new BitSet(5);
196+
selectedClusters.set(0, 5);
197+
198+
var stabilityStep = new LabellingStep(condensedTree, nodeCount);
199+
200+
var labels = stabilityStep.computeLabels(selectedClusters);
201+
202+
assertThat(labels.size()).isEqualTo(nodeCount);
203+
assertThat(labels.toArray()).containsOnly(0L);
204+
}
149205
}

0 commit comments

Comments
 (0)