Skip to content

Commit 89f11e6

Browse files
committed
Add label filters to triangle count
1 parent 28c2884 commit 89f11e6

File tree

20 files changed

+334
-82
lines changed

20 files changed

+334
-82
lines changed

algo-params/community-params/src/main/java/org/neo4j/gds/triangle/TriangleCountParameters.java

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,8 @@
2323
import org.neo4j.gds.annotation.Parameters;
2424
import org.neo4j.gds.core.concurrency.Concurrency;
2525

26+
import java.util.Optional;
27+
2628
@Parameters
27-
public record TriangleCountParameters(Concurrency concurrency, long maxDegree) implements AlgorithmParameters {
29+
public record TriangleCountParameters(Concurrency concurrency, long maxDegree, Optional<String> ALabel, Optional<String> BLabel, Optional<String> CLabel) implements AlgorithmParameters {
2830
}

algo/src/main/java/org/neo4j/gds/triangle/IntersectingTriangleCount.java

Lines changed: 36 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
package org.neo4j.gds.triangle;
2121

2222
import org.neo4j.gds.Algorithm;
23+
import org.neo4j.gds.NodeLabel;
2324
import org.neo4j.gds.api.Graph;
2425
import org.neo4j.gds.api.IntersectionConsumer;
2526
import org.neo4j.gds.api.RelationshipIntersect;
@@ -33,6 +34,7 @@
3334
import org.neo4j.gds.triangle.intersect.RelationshipIntersectFactoryLocator;
3435

3536
import java.util.Collection;
37+
import java.util.Optional;
3638
import java.util.concurrent.ExecutorService;
3739
import java.util.concurrent.atomic.AtomicLong;
3840
import java.util.concurrent.atomic.LongAdder;
@@ -63,6 +65,9 @@ public final class IntersectingTriangleCount extends Algorithm<TriangleCountResu
6365
private final HugeAtomicLongArray triangleCounts;
6466
private final long maxDegree;
6567
private final Concurrency concurrency;
68+
private final Optional<NodeLabel> ALabel;
69+
private final Optional<NodeLabel> BLabel;
70+
private final Optional<NodeLabel> CLabel;
6671
private long globalTriangleCount;
6772

6873
private final LongAdder globalTriangleCounter;
@@ -71,6 +76,9 @@ public static IntersectingTriangleCount create(
7176
Graph graph,
7277
Concurrency concurrency,
7378
long maxDegree,
79+
Optional<String> ALabel,
80+
Optional<String> BLabel,
81+
Optional<String> CLabel,
7482
ExecutorService executorService,
7583
ProgressTracker progressTracker,
7684
TerminationFlag terminationFlag
@@ -80,14 +88,28 @@ public static IntersectingTriangleCount create(
8088
.orElseThrow(
8189
() -> new IllegalArgumentException("No relationship intersect factory registered for graph: " + graph.getClass())
8290
);
83-
return new IntersectingTriangleCount(graph, factory, concurrency, maxDegree, executorService, progressTracker, terminationFlag);
91+
return new IntersectingTriangleCount(
92+
graph,
93+
factory,
94+
concurrency,
95+
maxDegree,
96+
ALabel,
97+
BLabel,
98+
CLabel,
99+
executorService,
100+
progressTracker,
101+
terminationFlag
102+
);
84103
}
85104

86105
private IntersectingTriangleCount(
87106
Graph graph,
88107
RelationshipIntersectFactory intersectFactory,
89108
Concurrency concurrency,
90109
long maxDegree,
110+
Optional<String> ALabel,
111+
Optional<String> BLabel,
112+
Optional<String> CLabel,
91113
ExecutorService executorService,
92114
ProgressTracker progressTracker,
93115
TerminationFlag terminationFlag
@@ -97,7 +119,13 @@ private IntersectingTriangleCount(
97119
this.intersectFactory = intersectFactory;
98120
this.concurrency = concurrency;
99121
this.maxDegree = maxDegree;
100-
this.triangleCounts = HugeAtomicLongArray.of(graph.nodeCount(), ParalleLongPageCreator.passThrough(concurrency));
122+
this.ALabel = ALabel.map(NodeLabel::of);
123+
this.BLabel = BLabel.map(NodeLabel::of);
124+
this.CLabel = CLabel.map(NodeLabel::of);
125+
this.triangleCounts = HugeAtomicLongArray.of(
126+
graph.nodeCount(),
127+
ParalleLongPageCreator.passThrough(concurrency)
128+
);
101129
this.executorService = executorService;
102130
this.globalTriangleCounter = new LongAdder();
103131
this.queue = new AtomicLong();
@@ -110,10 +138,12 @@ public TriangleCountResult compute() {
110138
progressTracker.beginSubTask();
111139
queue.set(0);
112140
globalTriangleCounter.reset();
141+
142+
boolean filtered = ALabel.isPresent() || BLabel.isPresent() || CLabel.isPresent();
113143
// create tasks
114144
final Collection<? extends Runnable> tasks = ParallelUtil.tasks(
115145
concurrency,
116-
() -> new IntersectTask(intersectFactory.load(graph, maxDegree))
146+
() -> new IntersectTask(intersectFactory.load(graph, maxDegree, BLabel, CLabel, filtered))
117147
);
118148
// run
119149
ParallelUtil.run(tasks, executorService);
@@ -140,7 +170,9 @@ public void run() {
140170
long node;
141171
while ((node = queue.getAndIncrement()) < graph.nodeCount() && terminationFlag.running()) {
142172
if (graph.degree(node) <= maxDegree) {
143-
intersect.intersectAll(node, this);
173+
if (ALabel.isEmpty() || graph.hasLabel(node, ALabel.get())) {
174+
intersect.intersectAll(node, this);
175+
}
144176
} else {
145177
triangleCounts.set(node, EXCLUDED_NODE_TRIANGLE_COUNT);
146178
}

algo/src/main/java/org/neo4j/gds/triangle/LocalClusteringCoefficient.java

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -136,6 +136,9 @@ private HugeAtomicLongArray computeTriangleCounts() {
136136
graph,
137137
concurrency,
138138
maxDegree,
139+
Optional.empty(),
140+
Optional.empty(),
141+
Optional.empty(),
139142
DefaultPool.INSTANCE,
140143
progressTracker,
141144
TerminationFlag.RUNNING_TRUE

algo/src/main/java/org/neo4j/gds/triangle/TriangleStream.java

Lines changed: 39 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121

2222
import com.carrotsearch.hppc.AbstractIterator;
2323
import org.neo4j.gds.Algorithm;
24+
import org.neo4j.gds.NodeLabel;
2425
import org.neo4j.gds.api.Graph;
2526
import org.neo4j.gds.api.IntersectionConsumer;
2627
import org.neo4j.gds.api.RelationshipIntersect;
@@ -34,6 +35,7 @@
3435
import java.util.Collection;
3536
import java.util.Iterator;
3637
import java.util.Objects;
38+
import java.util.Optional;
3739
import java.util.Spliterators;
3840
import java.util.concurrent.ArrayBlockingQueue;
3941
import java.util.concurrent.BlockingQueue;
@@ -55,6 +57,9 @@ public final class TriangleStream extends Algorithm<Stream<TriangleResult>> {
5557
private final ExecutorService executorService;
5658
private final AtomicInteger queue;
5759
private final Concurrency concurrency;
60+
private final Optional<NodeLabel> ALabel;
61+
private final Optional<NodeLabel> BLabel;
62+
private final Optional<NodeLabel> CLabel;
5863
private final int nodeCount;
5964
private final AtomicInteger runningThreads;
6065
private final BlockingQueue<TriangleResult> resultQueue;
@@ -63,28 +68,46 @@ public static TriangleStream create(
6368
Graph graph,
6469
ExecutorService executorService,
6570
Concurrency concurrency,
71+
Optional<String> ALabel,
72+
Optional<String> BLabel,
73+
Optional<String> CLabel,
6674
TerminationFlag terminationFlag
6775
) {
6876
var factory = RelationshipIntersectFactoryLocator
6977
.lookup(graph)
7078
.orElseThrow(
7179
() -> new IllegalArgumentException("No relationship intersect factory registered for graph: " + graph.getClass())
7280
);
73-
return new TriangleStream(graph, factory, executorService, concurrency, terminationFlag);
81+
return new TriangleStream(
82+
graph,
83+
factory,
84+
executorService,
85+
concurrency,
86+
ALabel,
87+
BLabel,
88+
CLabel,
89+
terminationFlag
90+
);
7491
}
7592

7693
private TriangleStream(
7794
Graph graph,
7895
RelationshipIntersectFactory intersectFactory,
7996
ExecutorService executorService,
8097
Concurrency concurrency,
98+
Optional<String> ALabel,
99+
Optional<String> BLabel,
100+
Optional<String> CLabel,
81101
TerminationFlag terminationFlag
82102
) {
83103
super(ProgressTracker.NULL_TRACKER);
84104
this.graph = graph;
85105
this.intersectFactory = intersectFactory;
86106
this.executorService = executorService;
87107
this.concurrency = concurrency;
108+
this.ALabel = ALabel.map(NodeLabel::of);
109+
this.BLabel = BLabel.map(NodeLabel::of);
110+
this.CLabel = CLabel.map(NodeLabel::of);
88111
this.nodeCount = Math.toIntExact(graph.nodeCount());
89112
this.resultQueue = new ArrayBlockingQueue<>(concurrency.value() << 10);
90113
this.runningThreads = new AtomicInteger();
@@ -111,16 +134,20 @@ protected TriangleResult fetch() {
111134
};
112135

113136
return StreamSupport
114-
.stream(Spliterators.spliteratorUnknownSize(it, 0), false)
115-
.filter(Objects::nonNull)
116-
.onClose(progressTracker::endSubTask);
137+
.stream(Spliterators.spliteratorUnknownSize(it, 0), false)
138+
.filter(Objects::nonNull)
139+
.onClose(progressTracker::endSubTask);
117140
}
118141

119142
private void submitTasks() {
120143
queue.set(0);
121144
runningThreads.set(0);
122145
final Collection<Runnable> tasks;
123-
tasks = ParallelUtil.tasks(concurrency, () -> new IntersectTask(intersectFactory.load(graph, Long.MAX_VALUE)));
146+
boolean filtered = ALabel.isPresent() || BLabel.isPresent() || CLabel.isPresent();
147+
tasks = ParallelUtil.tasks(
148+
concurrency,
149+
() -> new IntersectTask(intersectFactory.load(graph, Long.MAX_VALUE, BLabel, CLabel, filtered))
150+
);
124151
ParallelUtil.run(tasks, false, executorService, null);
125152
}
126153

@@ -135,7 +162,9 @@ public final void run() {
135162
try {
136163
int node;
137164
while ((node = queue.getAndIncrement()) < nodeCount && terminationFlag.running()) {
138-
evaluateNode(node);
165+
if (ALabel.isEmpty() || graph.hasLabel(node, ALabel.get())) {
166+
evaluateNode(node);
167+
}
139168
progressTracker.logProgress();
140169
}
141170
} finally {
@@ -147,9 +176,10 @@ public final void run() {
147176

148177
void emit(long nodeA, long nodeB, long nodeC) {
149178
var result = new TriangleResult(
150-
graph.toOriginalNodeId(nodeA),
151-
graph.toOriginalNodeId(nodeB),
152-
graph.toOriginalNodeId(nodeC));
179+
graph.toOriginalNodeId(nodeA),
180+
graph.toOriginalNodeId(nodeB),
181+
graph.toOriginalNodeId(nodeC)
182+
);
153183
resultQueue.offer(result);
154184
}
155185
}

algo/src/main/java/org/neo4j/gds/triangle/intersect/GraphIntersect.java

Lines changed: 46 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -20,11 +20,14 @@
2020
package org.neo4j.gds.triangle.intersect;
2121

2222
import org.jetbrains.annotations.Nullable;
23+
import org.neo4j.gds.NodeLabel;
2324
import org.neo4j.gds.api.AdjacencyCursor;
2425
import org.neo4j.gds.api.AdjacencyCursorUtils;
2526
import org.neo4j.gds.api.IntersectionConsumer;
2627
import org.neo4j.gds.api.RelationshipIntersect;
2728

29+
import java.util.Optional;
30+
import java.util.function.BiFunction;
2831
import java.util.function.IntPredicate;
2932

3033
import static org.neo4j.gds.api.AdjacencyCursor.NOT_FOUND;
@@ -38,15 +41,29 @@
3841
public abstract class GraphIntersect<CURSOR extends AdjacencyCursor> implements RelationshipIntersect {
3942

4043
private final IntPredicate degreeFilter;
44+
private final Optional<NodeLabel> BLabel;
45+
private final Optional<NodeLabel> CLabel;
46+
private final BiFunction<Long, NodeLabel, Boolean> hasLabel;
47+
private final boolean filtered;
4148
private CURSOR origNeighborsOfa;
4249
private CURSOR helpingCursorOfa;
4350
private CURSOR helpingCursorOfb;
4451

4552

46-
protected GraphIntersect(long maxDegree) {
53+
protected GraphIntersect(
54+
long maxDegree,
55+
Optional<NodeLabel> BLabel,
56+
Optional<NodeLabel> CLabel,
57+
BiFunction<Long, NodeLabel, Boolean> hasLabel,
58+
boolean filtered
59+
) {
4760
this.degreeFilter = maxDegree < Long.MAX_VALUE
4861
? (degree) -> degree <= maxDegree
4962
: (ignore) -> true;
63+
this.BLabel = BLabel;
64+
this.CLabel = CLabel;
65+
this.hasLabel = hasLabel;
66+
this.filtered = filtered;
5067
}
5168

5269
@Override
@@ -69,24 +86,26 @@ private void triangles(
6986
IntersectionConsumer consumer
7087
) {
7188
long b = AdjacencyCursorUtils.next(neighborsOfa);
72-
while (b != NOT_FOUND && b < a) {
73-
var degreeOfb = degree(b);
74-
if (degreeFilter.test(degreeOfb)) {
75-
helpingCursorOfb = cursorForNode(
76-
helpingCursorOfb,
77-
b,
78-
degreeOfb
79-
);
80-
81-
helpingCursorOfa = cursorForNode(helpingCursorOfa, a, degreeOfa);
82-
83-
triangles(
84-
a,
85-
b,
86-
helpingCursorOfa,
87-
helpingCursorOfb,
88-
consumer
89-
); //find all triangles involving the edge (a-b)
89+
while (b != NOT_FOUND && (b < a || filtered)) {
90+
if (BLabel.isEmpty() || hasLabel.apply(b, BLabel.get())) {
91+
var degreeOfb = degree(b);
92+
if (degreeFilter.test(degreeOfb)) {
93+
helpingCursorOfb = cursorForNode(
94+
helpingCursorOfb,
95+
b,
96+
degreeOfb
97+
);
98+
99+
helpingCursorOfa = cursorForNode(helpingCursorOfa, a, degreeOfa);
100+
101+
triangles(
102+
a,
103+
b,
104+
helpingCursorOfa,
105+
helpingCursorOfb,
106+
consumer
107+
); //find all triangles involving the edge (a-b)
108+
}
90109
}
91110

92111
b = AdjacencyCursorUtils.next(neighborsOfa);
@@ -97,13 +116,14 @@ private void triangles(
97116
private void triangles(long a, long b, CURSOR neighborsOfa, CURSOR neighborsOfb, IntersectionConsumer consumer) {
98117
long c = AdjacencyCursorUtils.next(neighborsOfb);
99118
long currentOfa = AdjacencyCursorUtils.next(neighborsOfa);
100-
while (c != NOT_FOUND && currentOfa != NOT_FOUND && c < b) {
101-
var degreeOfc = degree(c);
102-
if (degreeFilter.test(degreeOfc)) {
103-
currentOfa = AdjacencyCursorUtils.advance(neighborsOfa, currentOfa, c);
104-
//now print all triangles a-b-c (taking into consideration the parallel edges of c)
105-
checkForAndEmitTriangle(consumer, a, b, currentOfa, c);
106-
119+
while (c != NOT_FOUND && currentOfa != NOT_FOUND && (c < b || filtered)) {
120+
if (CLabel.isEmpty() || hasLabel.apply(c, CLabel.get())) {
121+
var degreeOfc = degree(c);
122+
if (degreeFilter.test(degreeOfc)) {
123+
currentOfa = AdjacencyCursorUtils.advance(neighborsOfa, currentOfa, c);
124+
//now print all triangles a-b-c (taking into consideration the parallel edges of c)
125+
checkForAndEmitTriangle(consumer, a, b, currentOfa, c);
126+
}
107127
}
108128
c = AdjacencyCursorUtils.next(neighborsOfb);
109129
}

0 commit comments

Comments
 (0)