Skip to content

Commit cb54edd

Browse files
Wire up the HDBSCAN steps
Co-authored-by: Ioannis Panagiotas <ioannis.panagiotas@neotechnology.com>
1 parent b67e716 commit cb54edd

File tree

3 files changed

+130
-32
lines changed

3 files changed

+130
-32
lines changed

algo/src/main/java/org/neo4j/gds/hdbscan/HDBScan.java

Lines changed: 16 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -22,46 +22,56 @@
2222
import org.neo4j.gds.Algorithm;
2323
import org.neo4j.gds.api.IdMap;
2424
import org.neo4j.gds.api.properties.nodes.NodePropertyValues;
25+
import org.neo4j.gds.collections.ha.HugeLongArray;
2526
import org.neo4j.gds.collections.ha.HugeObjectArray;
2627
import org.neo4j.gds.core.concurrency.Concurrency;
2728
import org.neo4j.gds.core.concurrency.ParallelUtil;
2829
import org.neo4j.gds.core.utils.progress.tasks.ProgressTracker;
2930
import org.neo4j.gds.termination.TerminationFlag;
3031

31-
public class HDBScan extends Algorithm<Void> {
32+
public class HDBScan extends Algorithm<HugeLongArray> {
3233

3334
private final IdMap nodes;
3435
private final NodePropertyValues nodePropertyValues;
3536
private final Concurrency concurrency;
3637
private final long leafSize;
3738
private final TerminationFlag terminationFlag;
3839
private final int k;
40+
private final long minClusterSize;
3941

4042
/**
 * Creates the algorithm instance; every argument except {@code progressTracker}
 * is stored in a field of the same name, and {@code progressTracker} is handed
 * to the {@code Algorithm} super constructor.
 *
 * @param nodes              id map of the nodes; {@code compute()} reads {@code nodeCount()} from it
 * @param nodePropertyValues node property values (the tests pass {@code graph.nodeProperties("point")})
 * @param concurrency        concurrency setting
 * @param leafSize           presumably the KD-tree leaf size -- TODO confirm against buildKDTree
 * @param k                  presumably the neighbour count used for core values -- TODO confirm against computeCores
 * @param minClusterSize     forwarded to {@code CondenseStep.condense} by {@code compute()}
 * @param progressTracker    passed to the super constructor
 * @param terminationFlag    termination flag
 */
protected HDBScan(
4143
IdMap nodes,
4244
NodePropertyValues nodePropertyValues,
4345
Concurrency concurrency,
4446
long leafSize,
4547
int k,
46-
TerminationFlag terminationFlag,
47-
ProgressTracker progressTracker
48+
long minClusterSize,
49+
ProgressTracker progressTracker, // now precedes terminationFlag (order swapped by this change)
50+
TerminationFlag terminationFlag
4851
) {
4952
super(progressTracker);
5053
this.nodes = nodes;
5154
this.nodePropertyValues = nodePropertyValues;
5255
this.concurrency = concurrency;
5356
this.leafSize = leafSize;
5457
this.k = k;
58+
this.minClusterSize = minClusterSize;
5559
this.terminationFlag = terminationFlag;
5660
}
5761

5862
/**
 * Runs the HDBSCAN steps in sequence: {@code buildKDTree}, {@code computeCores},
 * {@code dualTreeMSTPhase}, {@code createClusterHierarchy},
 * {@code CondenseStep.condense} (with {@code minClusterSize}) and finally
 * {@code LabellingStep.label}.
 *
 * @return the array produced by {@code LabellingStep.label()}; presumably one
 *         label per node, indexed by mapped node id -- TODO confirm
 */
@Override
59-
public Void compute() {
63+
public HugeLongArray compute() {
6064
var kdTree = buildKDTree();
6165

62-
var coreResult = computeCores(kdTree, nodes.nodeCount());
66+
var nodeCount = nodes.nodeCount(); // read once; reused by the cores, condense and labelling steps below
67+
var coreResult = computeCores(kdTree, nodeCount);
68+
// var dualTreeMST = dualTreeMSTPhase(); // NOTE(review): leftover commented-out call, superseded by the next line; safe to delete
6369
var dualTreeMST = dualTreeMSTPhase(kdTree, coreResult);
64-
return null;
70+
var clusterHierarchy = createClusterHierarchy(dualTreeMST);
71+
var condenseStep = new CondenseStep(nodeCount);
72+
var condensedTree = condenseStep.condense(clusterHierarchy, minClusterSize);
73+
var labellingStep = new LabellingStep(condensedTree, nodeCount);
74+
return labellingStep.label();
6575
}
6676

6777
CoreResult computeCores(KdTree kdTree, long nodeCount) {
algo/src/test/java/org/neo4j/gds/hdbscan/HDBScanE2ETest.java

Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,82 @@
1+
/*
2+
* Copyright (c) "Neo4j"
3+
* Neo4j Sweden AB [http://neo4j.com]
4+
*
5+
* This file is part of Neo4j.
6+
*
7+
* Neo4j is free software: you can redistribute it and/or modify
8+
* it under the terms of the GNU General Public License as published by
9+
* the Free Software Foundation, either version 3 of the License, or
10+
* (at your option) any later version.
11+
*
12+
* This program is distributed in the hope that it will be useful,
13+
* but WITHOUT ANY WARRANTY; without even the implied warranty of
14+
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
15+
* GNU General Public License for more details.
16+
*
17+
* You should have received a copy of the GNU General Public License
18+
* along with this program. If not, see <http://www.gnu.org/licenses/>.
19+
*/
20+
package org.neo4j.gds.hdbscan;
21+
22+
import org.junit.jupiter.api.Test;
23+
import org.neo4j.gds.core.concurrency.Concurrency;
24+
import org.neo4j.gds.core.utils.progress.tasks.ProgressTracker;
25+
import org.neo4j.gds.extension.GdlExtension;
26+
import org.neo4j.gds.extension.GdlGraph;
27+
import org.neo4j.gds.extension.Inject;
28+
import org.neo4j.gds.extension.TestGraph;
29+
import org.neo4j.gds.termination.TerminationFlag;
30+
31+
import static org.assertj.core.api.Assertions.assertThat;
32+
33+
/**
 * End-to-end test for {@link HDBScan}: runs the full {@code compute()} pipeline on
 * ten 2-D points and checks the label assigned to each node.
 *
 * The fixture contains two visibly separated groups: nodes a, b, d, e, f have a
 * second coordinate of roughly 1.5-2.0 and are expected to get label 2, while
 * nodes c, g, h, i, j have a second coordinate of roughly 4.0-4.6 and are
 * expected to get label 1.
 */
@GdlExtension
34+
class HDBScanE2ETest {
35+
36+
@GdlGraph
37+
private static final String DATA =
38+
"""
39+
CREATE
40+
(a:Node {point: [1.17755754, 2.02742572]}),
41+
(b:Node {point: [0.88489682, 1.97328227]}),
42+
(c:Node {point: [1.04192267, 4.34997048]}),
43+
(d:Node {point: [1.25764886, 1.94667762]}),
44+
(e:Node {point: [0.95464318, 1.55300632]}),
45+
(f:Node {point: [0.80617459, 1.60491802]}),
46+
(g:Node {point: [1.26227786, 3.96066446]}),
47+
(h:Node {point: [0.87569985, 4.51938412]}),
48+
(i:Node {point: [0.8028515 , 4.088106 ]}),
49+
(j:Node {point: [0.82954022, 4.63897487]})
50+
""";
51+
52+
@Inject
53+
private TestGraph graph;
54+
55+
@Test
56+
void hdbscan() {
57+
var hdbScan = new HDBScan(
58+
graph,
59+
graph.nodeProperties("point"),
60+
new Concurrency(1),
61+
1, // leafSize
62+
2, // k
63+
2, // minClusterSize
64+
ProgressTracker.NULL_TRACKER,
65+
TerminationFlag.RUNNING_TRUE
66+
);
67+
68+
var labelsWithOffset = hdbScan.compute();
69+
70+
var labels = new long[10]; // one slot per node, ordered a..j
71+
for (char letter='a'; letter<='j';++letter){
72+
var offsetPosition = graph.toMappedNodeId(String.valueOf(letter)); // GDL variable name -> mapped node id
73+
labels[letter-'a'] = labelsWithOffset.get(offsetPosition); // reorder results into a..j order
74+
}
75+
76+
var expectedLabels = new long[] {2, 2, 1, 2, 2, 2, 1, 1, 1, 1}; // labels for a..j, in that order
77+
78+
assertThat(labels).containsExactly(expectedLabels);
79+
}
80+
81+
82+
}

algo/src/test/java/org/neo4j/gds/hdbscan/HDBScanTest.java

Lines changed: 32 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -58,50 +58,54 @@ class HDBScanTest {
5858
private TestGraph graph;
5959

6060
/**
 * Builds the KD-tree, runs {@code computeCores} directly (without calling
 * {@code compute()}) and compares the resulting core array, in any order, with
 * the expected values at a precision of 1e-4.
 *
 * NOTE(review): the inline comments describe each value as a square root
 * (e.g. "sqrt(16)"), but the asserted numbers are the radicands themselves
 * (16.0, 10.0, ...). Either the core array holds squared distances or the
 * comments are stale -- confirm against computeCores.
 */
@Test
61-
void shouldComputeCoresCorrectly(){
61+
void shouldComputeCoresCorrectly() {
6262

63-
var hdbscan =new HDBScan(graph,graph.nodeProperties("point"),
63+
var hdbscan = new HDBScan(
64+
graph, graph.nodeProperties("point"),
6465
new Concurrency(1),
6566
1, // leafSize
6667
2, // k
67-
TerminationFlag.RUNNING_TRUE,
68-
ProgressTracker.NULL_TRACKER
68+
-1L, // minClusterSize; placeholder -- presumably unused here since compute() is never called
69+
ProgressTracker.NULL_TRACKER,
70+
TerminationFlag.RUNNING_TRUE
6971
);
7072

7173
var kdtree = hdbscan.buildKDTree();
7274

73-
var cores = hdbscan.computeCores(kdtree,graph.nodeCount()).createCoreArray();
75+
var cores = hdbscan.computeCores(kdtree, graph.nodeCount()).createCoreArray();
7476

7577
assertThat(cores.toArray())
7678
.usingComparatorWithPrecision(1e-4)
7779
.containsExactlyInAnyOrder(
78-
16.0, //a - d,b (sqrt(16)
79-
10.0, //b - c,d (sqrt(1 + 9)=sqrt(10)
80-
17.0, //c - b,d (sqrt(1 + 16) = sqrt(17)
81-
10.0, //d - a,b
82-
5.0, //e - g,f (sqrt(1 + 4) = sqrt(5)
83-
5.0, //f - g,e
84-
4.0, //g - f,e (sqrt(4)
85-
8.0, //h - g,f (sqrt( 4 + 4) = sqrt(8) = 2sqrt(2)
86-
346.0 //i - h, c (sqrt(11^2 + 15^2) = sqrt(346)
87-
);
80+
16.0, //a - d,b (sqrt(16)
81+
10.0, //b - c,d (sqrt(1 + 9)=sqrt(10)
82+
17.0, //c - b,d (sqrt(1 + 16) = sqrt(17)
83+
10.0, //d - a,b
84+
5.0, //e - g,f (sqrt(1 + 4) = sqrt(5)
85+
5.0, //f - g,e
86+
4.0, //g - f,e (sqrt(4)
87+
8.0, //h - g,f (sqrt( 4 + 4) = sqrt(8) = 2sqrt(2)
88+
346.0 //i - h, c (sqrt(11^2 + 15^2) = sqrt(346)
89+
);
8890

8991
}
9092

9193
@Test
92-
void shouldComputeMSTWithCoreValuesCorrectly(){
94+
void shouldComputeMSTWithCoreValuesCorrectly() {
9395

94-
var hdbscan =new HDBScan(graph,graph.nodeProperties("point"),
96+
var hdbscan = new HDBScan(
97+
graph, graph.nodeProperties("point"),
9598
new Concurrency(1),
9699
1,
97100
2,
98-
TerminationFlag.RUNNING_TRUE,
99-
ProgressTracker.NULL_TRACKER
101+
-1,
102+
ProgressTracker.NULL_TRACKER,
103+
TerminationFlag.RUNNING_TRUE
100104
);
101105

102106
var kdtree = hdbscan.buildKDTree();
103107

104-
var result = hdbscan.dualTreeMSTPhase(kdtree,hdbscan.computeCores(kdtree,graph.nodeCount()));
108+
var result = hdbscan.dualTreeMSTPhase(kdtree, hdbscan.computeCores(kdtree, graph.nodeCount()));
105109

106110
var expected = List.of(
107111
new Edge(graph.toMappedNodeId("i"), graph.toMappedNodeId("h"), Math.sqrt(346)),
@@ -126,25 +130,27 @@ void shouldComputeMSTWithCoreValuesCorrectly(){
126130
}
127131

128132
@Test
129-
void shouldComputeClusterHierarchyCorrectly(){
130-
HugeObjectArray<Edge> edges =HugeObjectArray.of(
133+
void shouldComputeClusterHierarchyCorrectly() {
134+
HugeObjectArray<Edge> edges = HugeObjectArray.of(
131135
new Edge(0, 1, 5.),
132136
new Edge(1, 2, 3.)
133137
);
134138

135139
var graphMock = mock(Graph.class);
136140
when(graphMock.nodeCount()).thenReturn(3L);
137141

138-
var hdbscan =new HDBScan(graphMock,
142+
var hdbscan = new HDBScan(
143+
graphMock,
139144
graph.nodeProperties("point"),
140145
new Concurrency(1),
141146
1,
142147
2,
143-
TerminationFlag.RUNNING_TRUE,
144-
ProgressTracker.NULL_TRACKER
148+
-1,
149+
ProgressTracker.NULL_TRACKER,
150+
TerminationFlag.RUNNING_TRUE
145151
);
146152

147-
var dualTreeResult = new DualTreeMSTResult(edges,-1);
153+
var dualTreeResult = new DualTreeMSTResult(edges, -1);
148154

149155
var clusterHierarchy = hdbscan.createClusterHierarchy(dualTreeResult);
150156

0 commit comments

Comments
 (0)