1919 */
2020package org .neo4j .gds .hdbscan ;
2121
22+ import com .carrotsearch .hppc .BitSet ;
2223import org .assertj .core .data .Offset ;
2324import org .junit .jupiter .api .Test ;
2425import org .neo4j .gds .collections .ha .HugeDoubleArray ;
2526import org .neo4j .gds .collections .ha .HugeLongArray ;
2627
2728import static org .assertj .core .api .Assertions .assertThat ;
2829
29- class StabilityStepTest {
30+ class LabellingTest {
3031
3132 @ Test
3233 void clusterStability () {
@@ -39,9 +40,9 @@ void clusterStability() {
3940 var maximumClusterId = 6 ;
4041
4142 var condensedTree = new CondensedTree (root , parent , lambda , size , maximumClusterId , nodeCount );
42- var stabilityStep = new StabilityStep ( );
43+ var stabilityStep = new LabellingStep ( condensedTree , nodeCount );
4344
44- var stabilities = stabilityStep .computeStabilities (condensedTree , nodeCount );
45+ var stabilities = stabilityStep .computeStabilities ();
4546
4647
4748 assertThat (stabilities .toArray ()).containsExactly (
@@ -65,9 +66,9 @@ void clusterStabilityBiggerTest() {
6566
6667 var condensedTree = new CondensedTree (root , parent , lambda , size , maximumClusterId , nodeCount );
6768
68- var stabilityStep = new StabilityStep ( );
69+ var stabilityStep = new LabellingStep ( condensedTree , nodeCount );
6970
70- var stabilities = stabilityStep .computeStabilities (condensedTree , nodeCount );
71+ var stabilities = stabilityStep .computeStabilities ();
7172
7273 assertThat (stabilities .toArray ()).containsExactly (
7374 new double [] {
@@ -101,9 +102,9 @@ void clusterSelectionOfChildClusters() {
101102 var stabilities = HugeDoubleArray .of (3. , 4. , 5. );
102103
103104 var condensedTree = new CondensedTree (root , parent , lambda , size , maximumClusterId , nodeCount );
104- var stabilityStep = new StabilityStep ( );
105+ var stabilityStep = new LabellingStep ( condensedTree , nodeCount );
105106
106- var selectedClusters = stabilityStep .selectedClusters (condensedTree , stabilities , nodeCount );
107+ var selectedClusters = stabilityStep .selectedClusters (stabilities );
107108
108109 assertThat (selectedClusters .get (0 ))
109110 .withFailMessage ("Root should be unselected" )
@@ -132,9 +133,9 @@ void clusterSelectionOfParentCluster() {
132133 var stabilities = HugeDoubleArray .of (10. , 4. , 5. );
133134
134135 var condensedTree = new CondensedTree (root , parent , lambda , size , maximumClusterId , nodeCount );
135- var stabilityStep = new StabilityStep ( );
136+ var stabilityStep = new LabellingStep ( condensedTree , nodeCount );
136137
137- var selectedClusters = stabilityStep .selectedClusters (condensedTree , stabilities , nodeCount );
138+ var selectedClusters = stabilityStep .selectedClusters (stabilities );
138139
139140 assertThat (selectedClusters .get (0 ))
140141 .withFailMessage ("Root should be selected" )
@@ -146,4 +147,59 @@ void clusterSelectionOfParentCluster() {
146147 .withFailMessage ("Second child should be selected" )
147148 .isTrue ();
148149 }
150+
151+ @ Test
152+ void labelling () {
153+ var parent = HugeLongArray .of (8 , 8 , 10 , 10 , 11 , 11 , 11 , 0 , 7 , 7 , 9 , 9 , 0 );
154+ var lambda = HugeDoubleArray .of (11.0 , 11.0 , 9.0 , 9.0 , 8.0 , 7.0 , 7.0 , 0.0 , 12.0 , 12.0 , 10.0 , 10.0 , 0.0 );
155+ var size = HugeLongArray .of (7 , 2 , 5 , 2 , 3 , 0 , 0 );
156+ var maximumClusterId = 11 ;
157+ var nodeCount = 7 ;
158+ var root = 7 ;
159+
160+ var condensedTree = new CondensedTree (root , parent , lambda , size , maximumClusterId , nodeCount );
161+ var selectedClusters = new BitSet (5 );
162+ // selects cluster `8`
163+ selectedClusters .set (1 );
164+ // selects cluster `11`
165+ selectedClusters .set (4 );
166+
167+ var stabilityStep = new LabellingStep (condensedTree , nodeCount );
168+
169+ var labels = stabilityStep .computeLabels (selectedClusters );
170+
171+ assertThat (labels .size ()).isEqualTo (nodeCount );
172+
173+ assertThat (labels .get (0 )).isEqualTo (1L );
174+ assertThat (labels .get (1 )).isEqualTo (1L );
175+
176+ assertThat (labels .get (2 )).isEqualTo (-1L );
177+ assertThat (labels .get (3 )).isEqualTo (-1L );
178+
179+ assertThat (labels .get (4 )).isEqualTo (4L );
180+ assertThat (labels .get (5 )).isEqualTo (4L );
181+ assertThat (labels .get (6 )).isEqualTo (4L );
182+
183+ }
184+
185+ @ Test
186+ void labellingWhenAllClustersAreSelected () {
187+ var parent = HugeLongArray .of (8 , 8 , 10 , 10 , 11 , 11 , 11 , 0 , 7 , 7 , 9 , 9 , 0 );
188+ var lambda = HugeDoubleArray .of (11.0 , 11.0 , 9.0 , 9.0 , 8.0 , 7.0 , 7.0 , 0.0 , 12.0 , 12.0 , 10.0 , 10.0 , 0.0 );
189+ var size = HugeLongArray .of (7 , 2 , 5 , 2 , 3 , 0 , 0 );
190+ var maximumClusterId = 11 ;
191+ var nodeCount = 7 ;
192+ var root = 7 ;
193+
194+ var condensedTree = new CondensedTree (root , parent , lambda , size , maximumClusterId , nodeCount );
195+ var selectedClusters = new BitSet (5 );
196+ selectedClusters .set (0 , 5 );
197+
198+ var stabilityStep = new LabellingStep (condensedTree , nodeCount );
199+
200+ var labels = stabilityStep .computeLabels (selectedClusters );
201+
202+ assertThat (labels .size ()).isEqualTo (nodeCount );
203+ assertThat (labels .toArray ()).containsOnly (0L );
204+ }
149205}
0 commit comments