Skip to content

Commit 2b3fcc0

Browse files
authored
Merge pull request #10175 from IoannisPanagiotas/boruvka-paralol
Parallel boruvka implementation
2 parents dd02b59 + 5956b99 commit 2b3fcc0

File tree

8 files changed

+708
-20
lines changed

8 files changed

+708
-20
lines changed
Lines changed: 270 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,270 @@
1+
/*
2+
* Copyright (c) "Neo4j"
3+
* Neo4j Sweden AB [http://neo4j.com]
4+
*
5+
* This file is part of Neo4j.
6+
*
7+
* Neo4j is free software: you can redistribute it and/or modify
8+
* it under the terms of the GNU General Public License as published by
9+
* the Free Software Foundation, either version 3 of the License, or
10+
* (at your option) any later version.
11+
*
12+
* This program is distributed in the hope that it will be useful,
13+
* but WITHOUT ANY WARRANTY; without even the implied warranty of
14+
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
15+
* GNU General Public License for more details.
16+
*
17+
* You should have received a copy of the GNU General Public License
18+
* along with this program. If not, see <http://www.gnu.org/licenses/>.
19+
*/
20+
package org.neo4j.gds.hdbscan;
21+
22+
import com.carrotsearch.hppc.BitSet;
23+
import org.neo4j.gds.Algorithm;
24+
import org.neo4j.gds.api.properties.nodes.NodePropertyValues;
25+
import org.neo4j.gds.collections.ha.HugeDoubleArray;
26+
import org.neo4j.gds.collections.ha.HugeObjectArray;
27+
import org.neo4j.gds.core.concurrency.Concurrency;
28+
import org.neo4j.gds.core.concurrency.ParallelUtil;
29+
import org.neo4j.gds.core.utils.Intersections;
30+
import org.neo4j.gds.core.utils.paged.dss.DisjointSetStruct;
31+
import org.neo4j.gds.core.utils.paged.dss.HugeAtomicDisjointSetStruct;
32+
import org.neo4j.gds.core.utils.progress.tasks.ProgressTracker;
33+
34+
public class BoruvkaMST extends Algorithm<GeometricMSTResult> {
35+
36+
private final NodePropertyValues nodePropertyValues;
37+
private final KdTree kdTree;
38+
private final BitSet kdNodeSingleComponent;
39+
private final ClosestDistanceInformationTracker closestDistanceTracker;
40+
private final HugeDoubleArray coreValues;
41+
42+
private final DisjointSetStruct unionFind;
43+
private final long nodeCount;
44+
private final Concurrency concurrency;
45+
46+
private final HugeObjectArray<Edge> edges;
47+
private long edgeCount = 0;
48+
private double totalEdgeSum = 0d;
49+
50+
51+
private BoruvkaMST(
52+
NodePropertyValues nodePropertyValues,
53+
KdTree kdTree,
54+
ClosestDistanceInformationTracker closestDistanceTracker,
55+
HugeDoubleArray coreValues,
56+
long nodeCount, Concurrency concurrency
57+
) {
58+
super(ProgressTracker.NULL_TRACKER);
59+
this.nodePropertyValues = nodePropertyValues;
60+
this.closestDistanceTracker = closestDistanceTracker;
61+
this.kdTree = kdTree;
62+
this.kdNodeSingleComponent = new BitSet(kdTree.treeNodeCount());
63+
64+
this.coreValues = coreValues;
65+
//for now use existing tool
66+
this.unionFind = new HugeAtomicDisjointSetStruct(nodeCount, new Concurrency(1));
67+
68+
this.edges = HugeObjectArray.newArray(Edge.class, nodeCount - 1);
69+
this.nodeCount = nodeCount;
70+
this.concurrency = concurrency;
71+
}
72+
73+
74+
public static BoruvkaMST createWithZeroCores(
75+
NodePropertyValues nodePropertyValues,
76+
KdTree kdTree,
77+
long nodeCount,
78+
Concurrency concurrency
79+
) {
80+
var zeroCores = HugeDoubleArray.newArray(nodeCount);
81+
82+
return new BoruvkaMST(
83+
nodePropertyValues,
84+
kdTree,
85+
ClosestDistanceInformationTracker.create(nodeCount),
86+
zeroCores,
87+
nodeCount,
88+
concurrency
89+
);
90+
}
91+
92+
public static BoruvkaMST create(
93+
NodePropertyValues nodePropertyValues,
94+
KdTree kdTree,
95+
CoreResult coreResult,
96+
long nodeCount,
97+
Concurrency concurrency
98+
) {
99+
var cores = coreResult.createCoreArray();
100+
var closestTracker = ClosestDistanceInformationTracker.create(nodeCount, cores, coreResult);
101+
102+
return new BoruvkaMST(nodePropertyValues, kdTree, closestTracker, cores, nodeCount,concurrency);
103+
}
104+
105+
106+
@Override
107+
public GeometricMSTResult compute() {
108+
var kdRoot = kdTree.root();
109+
var rootId = kdRoot.id();
110+
while (!kdNodeSingleComponent.get(rootId)) {
111+
performIteration();
112+
}
113+
return new GeometricMSTResult(edges, totalEdgeSum);
114+
}
115+
116+
private void performIteration() {
117+
if (closestDistanceTracker.isNotUpdated()) {
118+
119+
ParallelUtil.parallelForEachNode(
120+
nodeCount, concurrency, terminationFlag,
121+
(q) -> {
122+
var qArray = nodePropertyValues.doubleArrayValue(q);
123+
var qComp = unionFind.setIdOf(q);
124+
traversalStep(q,kdTree.root(),qComp,0,qArray);
125+
}
126+
);
127+
}
128+
129+
mergeComponents();
130+
//reset bounds
131+
// TODO: find the component id to reset up to?
132+
closestDistanceTracker.reset(unionFind.size());
133+
updateSingleComponent(kdTree.root());
134+
}
135+
136+
private boolean filterNodesOnCoreValue(long node, long component) {
137+
return coreValues.get(node) < closestDistanceTracker.componentClosestDistance(component);
138+
}
139+
140+
boolean prune(KdNode kdNode, long componentId, double lowerBoundOnDistance){
141+
var nodeComponent = singleComponentOr(kdNode,-1);
142+
if (nodeComponent == componentId) return true;
143+
var currentComponentBest = closestDistanceTracker.componentClosestDistance(componentId);
144+
return currentComponentBest < lowerBoundOnDistance;
145+
}
146+
147+
boolean tryUpdate(long qComp, long q,long r, double distance){
148+
return closestDistanceTracker.tryToAssign(qComp,q,r,distance);
149+
}
150+
151+
double baseCase(long q,long r, double[] qArray, long qComp){
152+
var rComp = unionFind.setIdOf(r);
153+
if (rComp != qComp && filterNodesOnCoreValue(r, qComp)) {
154+
var rArray = nodePropertyValues.doubleArrayValue(r);
155+
var rqDistance = Intersections.sumSquareDelta(qArray, rArray);
156+
var adaptedDistance = Math.max(Math.max(coreValues.get(r), coreValues.get(q)), rqDistance);
157+
if (tryUpdate(qComp,q,r,adaptedDistance)){
158+
return adaptedDistance;
159+
}
160+
}
161+
return -1;
162+
}
163+
164+
void traversalStep(long q, KdNode kdNode, long qComp, double lowerBoundOnDistance,double[] qArray){
165+
if (!prune(kdNode,qComp,lowerBoundOnDistance)) {
166+
if (kdNode.isLeaf()) {
167+
var start = kdNode.start();
168+
var end = kdNode.end();
169+
for (long index = start; index < end; ++index) {
170+
baseCase(q,kdTree.nodeAt(index),qArray,qComp);
171+
}
172+
}else{
173+
var left = kdTree.leftChild(kdNode);
174+
var right = kdTree.rightChild(kdNode);
175+
var lowerBoundLeft = left.aabb().lowerBoundFor(qArray);
176+
lowerBoundLeft*=lowerBoundLeft;
177+
var lowerBoundRight = right.aabb().lowerBoundFor(qArray);
178+
lowerBoundRight*=lowerBoundRight;
179+
180+
if (lowerBoundRight < lowerBoundLeft){
181+
traversalStep(q,right,qComp,lowerBoundRight,qArray);
182+
traversalStep(q,left,qComp,lowerBoundLeft,qArray);
183+
}else{
184+
traversalStep(q,left,qComp,lowerBoundLeft,qArray);
185+
traversalStep(q,right,qComp,lowerBoundRight,qArray);
186+
}
187+
}
188+
}
189+
190+
}
191+
192+
long singleComponentOr(KdNode node, long or) {
193+
long id = node.id();
194+
if (!kdNodeSingleComponent.get(id)) return or; //this is a trick to return distinct id
195+
return unionFind.setIdOf(kdTree.nodeAt(node.start()));
196+
197+
}
198+
199+
void mergeComponents() {
200+
for (var componentId = 0; componentId < nodeCount; componentId++) {
201+
var u = closestDistanceTracker.componentInsideBestNode(componentId);
202+
var v = closestDistanceTracker.componentOutsideBestNode(componentId);
203+
if (u == -1 || v == -1) {
204+
continue;
205+
}
206+
207+
var uComponent = unionFind.setIdOf(u);
208+
var vComponent = unionFind.setIdOf(v);
209+
210+
if (uComponent == vComponent) {
211+
closestDistanceTracker.resetComponent(componentId);
212+
continue;
213+
}
214+
215+
var distance = Math.sqrt(closestDistanceTracker.componentClosestDistance(componentId));
216+
this.edges.set(
217+
edgeCount,
218+
new Edge(u, v, distance)
219+
);
220+
this.edgeCount++;
221+
this.totalEdgeSum += distance;
222+
223+
unionFind.union(uComponent, vComponent);
224+
}
225+
226+
}
227+
228+
boolean updateSingleComponent(KdNode node) {
229+
long id = node.id();
230+
if (kdNodeSingleComponent.get(id)) {
231+
return true;
232+
}
233+
if (node.isLeaf()) {
234+
var start = node.start();
235+
var end = node.end();
236+
long expected = unionFind.setIdOf(kdTree.nodeAt(start));
237+
238+
boolean isSingle = true;
239+
for (var ptr = start + 1; ptr < end; ++ptr) {
240+
if (unionFind.setIdOf(kdTree.nodeAt(ptr)) != expected) {
241+
isSingle = false;
242+
break;
243+
}
244+
}
245+
if (isSingle) {
246+
kdNodeSingleComponent.set(id);
247+
}
248+
return isSingle;
249+
} else {
250+
var left = kdTree.leftChild(node);
251+
var right = kdTree.rightChild(node);
252+
if (updateSingleComponent(left) && updateSingleComponent(right)) {
253+
var singleLeft = singleComponentOr(left, -1);
254+
var singleRight = singleComponentOr(right, -2);
255+
boolean isSingle = singleRight == singleLeft;
256+
if (isSingle) {
257+
kdNodeSingleComponent.set(id);
258+
}
259+
return isSingle;
260+
}
261+
return false;
262+
263+
}
264+
}
265+
266+
void mergeComponents(long comp0, long comp1) {
267+
unionFind.union(comp0, comp1);
268+
}
269+
270+
}

algo/src/main/java/org/neo4j/gds/hdbscan/ClosestDistanceInformationTracker.java

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ final class ClosestDistanceInformationTracker {
2929
private final HugeLongArray componentOutsideBestNode;
3030
private boolean updated = false;
3131

32+
3233
private ClosestDistanceInformationTracker(
3334
HugeDoubleArray componentClosestDistance,
3435
HugeLongArray componentInsideBestNode,
@@ -64,25 +65,29 @@ static ClosestDistanceInformationTracker create(long size, HugeDoubleArray cores
6465
tracker.tryToAssign(u, u, neighbor, adaptedDistance);
6566
}
6667
}
67-
tracker.setUpdated(true);
68+
tracker.updated();
6869

6970
return tracker;
7071

7172
}
7273

73-
private void setUpdated(boolean updated) {
74-
this.updated = updated;
74+
private void updated() {
75+
this.updated = true;
76+
}
77+
78+
private void notUpdated() {
79+
this.updated = false;
7580
}
7681

77-
boolean isUpdated() {
78-
return updated;
82+
boolean isNotUpdated() {
83+
return !updated;
7984
}
8085

8186
void reset(long upTo) {
8287
for (long u = 0; u < upTo; ++u) {
8388
resetComponent(u);
8489
}
85-
setUpdated(false);
90+
notUpdated();
8691
}
8792

8893
void resetComponent(long u) {
@@ -96,7 +101,7 @@ void consider(long comp1, long comp2, long p1, long p2, double distance) {
96101
tryToAssign(comp2, p2, p1, distance);
97102
}
98103

99-
boolean tryToAssign(long comp, long pInside, long pOutside, double distance) {
104+
synchronized boolean tryToAssign(long comp, long pInside, long pOutside, double distance) {
100105
var best = componentClosestDistance.get(comp);
101106
if (best > distance) {
102107
componentClosestDistance.set(comp, distance);

algo/src/main/java/org/neo4j/gds/hdbscan/DualTreeMSTAlgorithm.java

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@
3030
import org.neo4j.gds.core.utils.paged.dss.HugeAtomicDisjointSetStruct;
3131
import org.neo4j.gds.core.utils.progress.tasks.ProgressTracker;
3232

33-
public final class DualTreeMSTAlgorithm extends Algorithm<DualTreeMSTResult> {
33+
public final class DualTreeMSTAlgorithm extends Algorithm<GeometricMSTResult> {
3434

3535
private final NodePropertyValues nodePropertyValues;
3636
private final KdTree kdTree;
@@ -97,14 +97,14 @@ public static DualTreeMSTAlgorithm create(
9797

9898

9999
@Override
100-
public DualTreeMSTResult compute() {
100+
public GeometricMSTResult compute() {
101101

102102
var kdRoot = kdTree.root();
103103
var rootId = kdRoot.id();
104104
while (!kdNodeSingleComponent.get(rootId)) {
105105
performIteration();
106106
}
107-
return new DualTreeMSTResult(edges, totalEdgeSum);
107+
return new GeometricMSTResult(edges, totalEdgeSum);
108108
}
109109

110110
void resetNodeBounds() {
@@ -136,7 +136,7 @@ void resetNodeBounds() {
136136
}
137137

138138
private void performIteration() {
139-
if (!closestDistanceTracker.isUpdated()) {
139+
if (closestDistanceTracker.isNotUpdated()) {
140140
resetNodeBounds();
141141
traversalStep(kdTree.root(), kdTree.root());
142142
}

algo/src/main/java/org/neo4j/gds/hdbscan/DualTreeMSTResult.java renamed to algo/src/main/java/org/neo4j/gds/hdbscan/GeometricMSTResult.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,5 +21,5 @@
2121

2222
import org.neo4j.gds.collections.ha.HugeObjectArray;
2323

24-
public record DualTreeMSTResult(HugeObjectArray<Edge> edges, double totalDistance) {
24+
public record GeometricMSTResult(HugeObjectArray<Edge> edges, double totalDistance) {
2525
}

0 commit comments

Comments
 (0)