Skip to content

Commit 210a3f3

Browse files
Merge pull request #6870 from IoannisPanagiotas/spanning-tree-bugs-23
[2.3] K-Spanning tree bugs
2 parents 7309eb0 + c711bf7 commit 210a3f3

File tree

9 files changed

+659
-138
lines changed

9 files changed

+659
-138
lines changed

algo/src/main/java/org/neo4j/gds/spanningtree/SpanningTree.java

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
package org.neo4j.gds.spanningtree;
2121

2222
import org.apache.commons.lang3.builder.EqualsBuilder;
23-
import org.neo4j.gds.api.RelationshipConsumer;
23+
import org.neo4j.gds.api.RelationshipWithPropertyConsumer;
2424
import org.neo4j.gds.core.utils.paged.HugeDoubleArray;
2525
import org.neo4j.gds.core.utils.paged.HugeLongArray;
2626

@@ -62,21 +62,24 @@ public double totalWeight() {
6262
return totalWeight;
6363
}
6464

65-
public HugeLongArray parentArray() {return parent;}
65+
public HugeLongArray parentArray() {
66+
return parent;
67+
}
6668

6769
public long parent(long nodeId) {return parent.get(nodeId);}
6870

6971
public double costToParent(long nodeId) {
7072
return costToParent.get(nodeId);
7173
}
7274

73-
public void forEach(RelationshipConsumer consumer) {
75+
public void forEach(RelationshipWithPropertyConsumer consumer) {
7476
for (int i = 0; i < nodeCount; i++) {
75-
final long parent = this.parent.get(i);
77+
long parent = this.parent.get(i);
78+
double cost = this.costToParent(i);
7679
if (parent == -1) {
7780
continue;
7881
}
79-
if (!consumer.accept(parent, i)) {
82+
if (!consumer.accept(parent, i, cost)) {
8083
return;
8184
}
8285
}

alpha/alpha-algo/src/main/java/org/neo4j/gds/impl/spanningtree/KSpanningTree.java

Lines changed: 237 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,11 @@
1919
*/
2020
package org.neo4j.gds.impl.spanningtree;
2121

22+
import com.carrotsearch.hppc.BitSet;
23+
import org.jetbrains.annotations.NotNull;
2224
import org.neo4j.gds.Algorithm;
2325
import org.neo4j.gds.api.Graph;
26+
import org.neo4j.gds.core.utils.paged.HugeDoubleArray;
2427
import org.neo4j.gds.core.utils.paged.HugeLongArray;
2528
import org.neo4j.gds.core.utils.progress.tasks.ProgressTracker;
2629
import org.neo4j.gds.core.utils.queue.HugeLongPriorityQueue;
@@ -45,8 +48,6 @@ public class KSpanningTree extends Algorithm<SpanningTree> {
4548
private final long startNodeId;
4649
private final long k;
4750

48-
private SpanningTree spanningTree;
49-
5051
public KSpanningTree(
5152
Graph graph,
5253
DoubleUnaryOperator minMax,
@@ -71,38 +72,248 @@ public SpanningTree compute() {
7172
startNodeId,
7273
progressTracker
7374
);
75+
7476
prim.setTerminationFlag(getTerminationFlag());
7577
SpanningTree spanningTree = prim.compute();
76-
HugeLongArray parent = spanningTree.parentArray();
77-
long parentSize = parent.size();
78-
HugeLongPriorityQueue priorityQueue = minMax == Prim.MAX_OPERATOR ? HugeLongPriorityQueue.min(parentSize) : HugeLongPriorityQueue.max(
79-
parentSize);
80-
progressTracker.beginSubTask(parentSize);
81-
for (long i = 0; i < parentSize && terminationFlag.running(); i++) {
82-
long p = parent.get(i);
83-
if (p == -1) {
84-
continue;
85-
}
86-
priorityQueue.add(i, graph.relationshipProperty(p, i, 0.0D));
87-
progressTracker.logProgress();
88-
}
78+
79+
var outputTree = growApproach(spanningTree);
8980
progressTracker.endSubTask();
90-
progressTracker.beginSubTask(k - 1);
91-
// remove k-1 relationships
92-
for (long i = 0; i < k - 1 && terminationFlag.running(); i++) {
93-
long cutNode = priorityQueue.pop();
94-
parent.set(cutNode, -1);
95-
progressTracker.logProgress();
81+
return outputTree;
82+
}
83+
84+
@NotNull
85+
private HugeLongPriorityQueue createPriorityQueue(long parentSize, boolean pruning) {
86+
boolean minQueue = minMax == Prim.MIN_OPERATOR;
87+
//if pruning, we remove the worst (max if it's a minimization problem)
88+
//therefore we flip the priority queue
89+
if (pruning) {
90+
minQueue = !minQueue;
9691
}
97-
progressTracker.endSubTask();
98-
this.spanningTree = prim.getSpanningTree();
99-
progressTracker.endSubTask();
100-
return this.spanningTree;
92+
HugeLongPriorityQueue priorityQueue = minQueue
93+
? HugeLongPriorityQueue.min(parentSize)
94+
: HugeLongPriorityQueue.max(parentSize);
95+
return priorityQueue;
10196
}
10297

10398
@Override
10499
public void release() {
105100
graph = null;
106-
spanningTree = null;
107101
}
102+
103+
private double init(HugeLongArray parent, HugeDoubleArray costToParent, SpanningTree spanningTree) {
104+
graph.forEachNode((nodeId) -> {
105+
parent.set(nodeId, spanningTree.parent(nodeId));
106+
costToParent.set(nodeId, spanningTree.costToParent(nodeId));
107+
return true;
108+
});
109+
return spanningTree.totalWeight();
110+
}
111+
112+
113+
private SpanningTree growApproach(SpanningTree spanningTree) {
114+
115+
//this approach grows gradually the MST found in the previous step
116+
//when it is about to get larger than K, we crop the current worst leaf if the new value to be added
117+
// is actually better
118+
if (spanningTree.effectiveNodeCount() < k)
119+
return spanningTree;
120+
121+
HugeLongArray outDegree = HugeLongArray.newArray(graph.nodeCount());
122+
123+
HugeLongArray parent = HugeLongArray.newArray(graph.nodeCount());
124+
HugeDoubleArray costToParent = HugeDoubleArray.newArray(graph.nodeCount());
125+
126+
init(parent, costToParent, spanningTree);
127+
double totalCost = 0;
128+
var priorityQueue = createPriorityQueue(graph.nodeCount(), false);
129+
var toTrim = createPriorityQueue(graph.nodeCount(), true);
130+
131+
//priority-queue does not have a remove method
132+
// so we need something to know if a node is still a leaf or not
133+
BitSet exterior = new BitSet(graph.nodeCount());
134+
//at any point, the tree has a root we mark its neighbors in this bitset to avoid looping to find them
135+
BitSet rootNodeAdjacent = new BitSet(graph.nodeCount());
136+
//we just save which nodes are in the final output and not (just to do clean-up; probably can be avoided)
137+
BitSet included = new BitSet(graph.nodeCount());
138+
139+
priorityQueue.add(startNodeId, 0);
140+
long root = startNodeId; //current root is startNodeId
141+
long nodesInTree = 0;
142+
progressTracker.beginSubTask(graph.nodeCount());
143+
while (!priorityQueue.isEmpty()) {
144+
long node = priorityQueue.top();
145+
progressTracker.logProgress();
146+
double associatedCost = priorityQueue.cost(node);
147+
priorityQueue.pop();
148+
long nodeParent = parent.get(node);
149+
150+
boolean nodeAdded = false;
151+
if (nodesInTree < k) { //if we are smaller, we can just add it no problemo
152+
nodesInTree++;
153+
nodeAdded = true;
154+
} else {
155+
var nodeToTrim = findNextValidLeaf(toTrim, exterior); //a leaf node with currently theworst cost
156+
if (parent.get(node) == nodeToTrim) {
157+
//we cannot add it, if we're supposed to remove its parent
158+
//TODO: should be totally feasible to consider the 2nd worst then.
159+
continue;
160+
}
161+
162+
boolean shouldMove = moveMakesSense(associatedCost, toTrim.cost(nodeToTrim), minMax);
163+
164+
if (shouldMove) {
165+
nodeAdded = true;
166+
167+
double value = toTrim.cost(nodeToTrim);
168+
toTrim.pop();
169+
170+
long parentOfTrimmed = parent.get(nodeToTrim);
171+
included.clear(nodeToTrim); //nodeToTrim is removed from the answer
172+
clearNode(nodeToTrim, parent, costToParent);
173+
totalCost -= value; //as well as its cost from the solution
174+
175+
if (root != nodeToTrim) { //we are not removing the actual root
176+
//reduce degree of parent
177+
outDegree.set(parentOfTrimmed, outDegree.get(parentOfTrimmed) - 1);
178+
long affectedNode = -1;
179+
double affectedCost = -1;
180+
long parentOutDegree = outDegree.get(parentOfTrimmed);
181+
if (parentOfTrimmed == root) { //if its parent is the root
182+
rootNodeAdjacent.clear(nodeToTrim); //remove the trimmed child
183+
if (parentOutDegree == 1) { //root becomes a leaf
184+
assert rootNodeAdjacent.cardinality() == 1;
185+
//get the single sole child of root
186+
var rootChild = rootNodeAdjacent.nextSetBit(0);
187+
affectedNode = root;
188+
affectedCost = costToParent.get(rootChild);
189+
}
190+
} else {
191+
if (parentOutDegree == 0) { //if parent becomes a leaf
192+
affectedNode = parentOfTrimmed;
193+
affectedCost = costToParent.get(parentOfTrimmed);
194+
}
195+
}
196+
if (affectedNode != -1) { //if a node has been converted to a leaf
197+
updateExterior(affectedNode, affectedCost, toTrim, exterior);
198+
}
199+
} else {
200+
//the root is removed, long live the new root!
201+
assert rootNodeAdjacent.cardinality() == 1;
202+
//the new root is the single sole child of old root
203+
var newRoot = rootNodeAdjacent.nextSetBit(0);
204+
rootNodeAdjacent.clear(); //empty everything
205+
//find the children of the new root (this can happen once per node)
206+
207+
fillChildren(newRoot, rootNodeAdjacent, parent, included);
208+
209+
root = newRoot;
210+
//set it as root
211+
clearNode(root, parent, costToParent);
212+
//check if root is a degree-1 to add to exterior
213+
if (outDegree.get(root) == 1) {
214+
//get single child
215+
var rootChild = rootNodeAdjacent.nextSetBit(0);
216+
priorityQueue.add(root, costToParent.get(rootChild));
217+
exterior.set(root);
218+
}
219+
}
220+
}
221+
}
222+
if (nodeAdded) {
223+
included.set(node); // include it in the solution (for now!)
224+
totalCost += associatedCost; //add its associated cost to the weight of tree
225+
if (nodeParent == root) { //if it's parent is the root, update the bitset
226+
rootNodeAdjacent.set(node);
227+
}
228+
if (node != root) { //this only happens for startNode to be fair
229+
//the node's parent gets an update in degree
230+
outDegree.set(nodeParent, outDegree.get(nodeParent) + 1);
231+
exterior.clear(nodeParent); //and remoed from exterior if included
232+
}
233+
//then the node (being a leaf) is added to the trimming priority queu
234+
toTrim.add(node, associatedCost);
235+
exterior.set(node); //and the exterior
236+
relaxNode(node, priorityQueue, parent, spanningTree);
237+
238+
} else {
239+
clearNode(node, parent, costToParent);
240+
}
241+
}
242+
//post-processing step: anything not touched is reset to -1
243+
pruneUntouchedNodes(parent, costToParent, included);
244+
progressTracker.endSubTask();
245+
return new SpanningTree(root, graph.nodeCount(), k, parent, costToParent, totalCost);
246+
247+
}
248+
249+
private void pruneUntouchedNodes(HugeLongArray parent, HugeDoubleArray costToParent, BitSet included) {
250+
graph.forEachNode(nodeId -> {
251+
if (!included.get(nodeId)) {
252+
clearNode(nodeId, parent, costToParent);
253+
}
254+
return true;
255+
});
256+
}
257+
258+
private void clearNode(long node, HugeLongArray parent, HugeDoubleArray costToParent) {
259+
parent.set(node, -1);
260+
costToParent.set(node, -1);
261+
}
262+
263+
private boolean moveMakesSense(double cost1, double cost2, DoubleUnaryOperator minMax) {
264+
if (minMax == Prim.MAX_OPERATOR) {
265+
return cost1 > cost2;
266+
} else {
267+
return cost1 < cost2;
268+
}
269+
}
270+
271+
private void updateExterior(long affectedNode, double affectedCost, HugeLongPriorityQueue toTrim, BitSet exterior) {
272+
if (!toTrim.containsElement(affectedNode)) {
273+
toTrim.add(affectedNode, affectedCost); //add it to pq
274+
} else {
275+
//it is still in the queue, but it is not a leaf anymore, so it's value is obsolete
276+
toTrim.set(affectedNode, affectedCost);
277+
}
278+
exterior.set(affectedNode); //and mark it in the exterior
279+
}
280+
281+
private long findNextValidLeaf(HugeLongPriorityQueue toTrim, BitSet exterior) {
282+
while (!exterior.get(toTrim.top())) { //not valid frontier nodes anymore, just ignore
283+
toTrim.pop(); //as we said, pq does not have a direct remove method
284+
}
285+
return toTrim.top();
286+
}
287+
288+
private void fillChildren(long newRoot, BitSet rootNodeAdjacent, HugeLongArray parent, BitSet included) {
289+
graph.forEachRelationship(newRoot, (s, t) -> {
290+
//relevant are only those nodes which are currently
291+
//in the k-tree
292+
if (parent.get(t) == s && included.get(t)) {
293+
rootNodeAdjacent.set(t);
294+
}
295+
return true;
296+
});
297+
}
298+
299+
private void relaxNode(
300+
long node,
301+
HugeLongPriorityQueue priorityQueue,
302+
HugeLongArray parent,
303+
SpanningTree spanningTree
304+
) {
305+
graph.forEachRelationship(node, (s, t) -> {
306+
if (parent.get(t) == s) {
307+
//TODO: work's only on mst edges for now (should be doable to re-find an k-MST from whole graph)
308+
if (!priorityQueue.containsElement(t)) {
309+
priorityQueue.add(t, spanningTree.costToParent(t));
310+
}
311+
312+
}
313+
return true;
314+
});
315+
}
316+
108317
}
318+
319+

alpha/alpha-algo/src/main/java/org/neo4j/gds/impl/spanningtree/KSpanningTreeAlgorithmFactory.java

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -52,8 +52,7 @@ public Task progressTask(
5252
) {
5353
return Tasks.task(
5454
taskName(),
55-
Tasks.leaf("SpanningTree", graph.nodeCount()),
56-
Tasks.leaf("Add relationship weights"),
55+
Tasks.leaf("SpanningTree", graph.relationshipCount()),
5756
Tasks.leaf("Remove relationships")
5857
);
5958
}

0 commit comments

Comments
 (0)