Skip to content

Commit fa19830

Browse files
committed
Change approach to label filtered triangle count
1 parent c16aa5d commit fa19830

File tree

17 files changed

+247
-265
lines changed

17 files changed

+247
-265
lines changed

algo-params/community-params/src/main/java/org/neo4j/gds/triangle/TriangleCountParameters.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,8 +23,9 @@
2323
import org.neo4j.gds.annotation.Parameters;
2424
import org.neo4j.gds.core.concurrency.Concurrency;
2525

26+
import java.util.List;
2627
import java.util.Optional;
2728

2829
@Parameters
29-
public record TriangleCountParameters(Concurrency concurrency, long maxDegree, Optional<String> ALabel, Optional<String> BLabel, Optional<String> CLabel) implements AlgorithmParameters {
30+
public record TriangleCountParameters(Concurrency concurrency, long maxDegree, Optional<List<String>> labelFilter) implements AlgorithmParameters {
3031
}

algo/src/main/java/org/neo4j/gds/triangle/IntersectingTriangleCount.java

Lines changed: 20 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
import org.neo4j.gds.triangle.intersect.RelationshipIntersectFactoryLocator;
3535

3636
import java.util.Collection;
37+
import java.util.List;
3738
import java.util.Optional;
3839
import java.util.concurrent.ExecutorService;
3940
import java.util.concurrent.atomic.AtomicLong;
@@ -78,9 +79,7 @@ public static IntersectingTriangleCount create(
7879
Graph graph,
7980
Concurrency concurrency,
8081
long maxDegree,
81-
Optional<String> ALabel,
82-
Optional<String> BLabel,
83-
Optional<String> CLabel,
82+
Optional<List<String>> labelFilter,
8483
ExecutorService executorService,
8584
ProgressTracker progressTracker,
8685
TerminationFlag terminationFlag
@@ -95,9 +94,7 @@ public static IntersectingTriangleCount create(
9594
factory,
9695
concurrency,
9796
maxDegree,
98-
ALabel,
99-
BLabel,
100-
CLabel,
97+
labelFilter,
10198
executorService,
10299
progressTracker,
103100
terminationFlag
@@ -109,9 +106,7 @@ private IntersectingTriangleCount(
109106
RelationshipIntersectFactory intersectFactory,
110107
Concurrency concurrency,
111108
long maxDegree,
112-
Optional<String> ALabel,
113-
Optional<String> BLabel,
114-
Optional<String> CLabel,
109+
Optional<List<String>> labelFilter,
115110
ExecutorService executorService,
116111
ProgressTracker progressTracker,
117112
TerminationFlag terminationFlag
@@ -121,9 +116,6 @@ private IntersectingTriangleCount(
121116
this.intersectFactory = intersectFactory;
122117
this.concurrency = concurrency;
123118
this.maxDegree = maxDegree;
124-
this.ALabel = ALabel.map(NodeLabel::of);
125-
this.BLabel = BLabel.map(NodeLabel::of);
126-
this.CLabel = CLabel.map(NodeLabel::of);
127119
this.triangleCounts = HugeAtomicLongArray.of(
128120
graph.nodeCount(),
129121
ParalleLongPageCreator.passThrough(concurrency)
@@ -132,11 +124,20 @@ private IntersectingTriangleCount(
132124
this.globalTriangleCounter = new LongAdder();
133125
this.queue = new AtomicLong();
134126

135-
this.bTraversal = ALabel.isPresent() && ((BLabel.isPresent() && !ALabel.get()
136-
.equals(BLabel.get())) || BLabel.isEmpty());
137-
this.cTraversal = (ALabel.isPresent() && ((CLabel.isPresent() && !ALabel.get()
138-
.equals(CLabel.get())) || CLabel.isEmpty()) && (BLabel.isPresent() && ((CLabel.isPresent() && !BLabel.get()
139-
.equals(CLabel.get())) || CLabel.isEmpty())));
127+
if (labelFilter.isPresent()) {
128+
var labelFilterList = labelFilter.get();
129+
this.ALabel = Optional.of(NodeLabel.of(labelFilterList.get(0)));
130+
this.BLabel = Optional.of(NodeLabel.of(labelFilterList.get(1)));
131+
this.CLabel = Optional.of(NodeLabel.of(labelFilterList.get(2)));
132+
this.bTraversal = !ALabel.get().equals(BLabel.get());
133+
this.cTraversal = !ALabel.get().equals(CLabel.get()) || !BLabel.get().equals(CLabel.get());
134+
} else {
135+
this.ALabel = Optional.empty();
136+
this.BLabel = Optional.empty();
137+
this.CLabel = Optional.empty();
138+
this.bTraversal = false;
139+
this.cTraversal = false;
140+
}
140141

141142
this.terminationFlag = terminationFlag;
142143
}
@@ -180,10 +181,10 @@ public void run() {
180181
if (ALabel.isEmpty() || graph.hasLabel(node, ALabel.get())) {
181182
intersect.intersectAll(node, this, BLabel, CLabel);
182183
}
183-
if (bTraversal && (BLabel.isEmpty() || graph.hasLabel(node, BLabel.get()))) {
184+
if (bTraversal && graph.hasLabel(node, BLabel.get())) {
184185
intersect.intersectAll(node, this, CLabel, ALabel);
185186
}
186-
if (cTraversal && (CLabel.isEmpty() || graph.hasLabel(node, CLabel.get()))) {
187+
if (cTraversal && graph.hasLabel(node, CLabel.get())) {
187188
intersect.intersectAll(node, this, ALabel, BLabel);
188189
}
189190
} else {

algo/src/main/java/org/neo4j/gds/triangle/LocalClusteringCoefficient.java

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -137,8 +137,6 @@ private HugeAtomicLongArray computeTriangleCounts() {
137137
concurrency,
138138
maxDegree,
139139
Optional.empty(),
140-
Optional.empty(),
141-
Optional.empty(),
142140
DefaultPool.INSTANCE,
143141
progressTracker,
144142
TerminationFlag.RUNNING_TRUE

algo/src/main/java/org/neo4j/gds/triangle/TriangleStream.java

Lines changed: 23 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434

3535
import java.util.Collection;
3636
import java.util.Iterator;
37+
import java.util.List;
3738
import java.util.Objects;
3839
import java.util.Optional;
3940
import java.util.Spliterators;
@@ -63,14 +64,14 @@ public final class TriangleStream extends Algorithm<Stream<TriangleResult>> {
6364
private final int nodeCount;
6465
private final AtomicInteger runningThreads;
6566
private final BlockingQueue<TriangleResult> resultQueue;
67+
private final boolean bTraversal;
68+
private final boolean cTraversal;
6669

6770
public static TriangleStream create(
6871
Graph graph,
6972
ExecutorService executorService,
7073
Concurrency concurrency,
71-
Optional<String> ALabel,
72-
Optional<String> BLabel,
73-
Optional<String> CLabel,
74+
Optional<List<String>> labelFilter,
7475
TerminationFlag terminationFlag
7576
) {
7677
var factory = RelationshipIntersectFactoryLocator
@@ -83,9 +84,7 @@ public static TriangleStream create(
8384
factory,
8485
executorService,
8586
concurrency,
86-
ALabel,
87-
BLabel,
88-
CLabel,
87+
labelFilter,
8988
terminationFlag
9089
);
9190
}
@@ -95,24 +94,34 @@ private TriangleStream(
9594
RelationshipIntersectFactory intersectFactory,
9695
ExecutorService executorService,
9796
Concurrency concurrency,
98-
Optional<String> ALabel,
99-
Optional<String> BLabel,
100-
Optional<String> CLabel,
97+
Optional<List<String>> labelFilter,
10198
TerminationFlag terminationFlag
10299
) {
103100
super(ProgressTracker.NULL_TRACKER);
104101
this.graph = graph;
105102
this.intersectFactory = intersectFactory;
106103
this.executorService = executorService;
107104
this.concurrency = concurrency;
108-
this.ALabel = ALabel.map(NodeLabel::of);
109-
this.BLabel = BLabel.map(NodeLabel::of);
110-
this.CLabel = CLabel.map(NodeLabel::of);
111105
this.nodeCount = Math.toIntExact(graph.nodeCount());
112106
this.resultQueue = new ArrayBlockingQueue<>(concurrency.value() << 10);
113107
this.runningThreads = new AtomicInteger();
114108
this.queue = new AtomicInteger();
115109

110+
if (labelFilter.isPresent()) {
111+
var labelFilterList = labelFilter.get();
112+
this.ALabel = Optional.of(NodeLabel.of(labelFilterList.get(0)));
113+
this.BLabel = Optional.of(NodeLabel.of(labelFilterList.get(1)));
114+
this.CLabel = Optional.of(NodeLabel.of(labelFilterList.get(2)));
115+
this.bTraversal = !ALabel.get().equals(BLabel.get());
116+
this.cTraversal = !ALabel.get().equals(CLabel.get()) || !BLabel.get().equals(CLabel.get());
117+
} else {
118+
this.ALabel = Optional.empty();
119+
this.BLabel = Optional.empty();
120+
this.CLabel = Optional.empty();
121+
this.bTraversal = false;
122+
this.cTraversal = false;
123+
}
124+
116125
this.terminationFlag = terminationFlag;
117126
}
118127

@@ -164,10 +173,10 @@ public final void run() {
164173
if (ALabel.isEmpty() || graph.hasLabel(node, ALabel.get())) {
165174
evaluateNode(node, BLabel, CLabel);
166175
}
167-
if (BLabel.isEmpty() || graph.hasLabel(node, BLabel.get())) {
176+
if (bTraversal && graph.hasLabel(node, BLabel.get())) {
168177
evaluateNode(node, CLabel, ALabel);
169178
}
170-
if (CLabel.isEmpty() || graph.hasLabel(node, CLabel.get())) {
179+
if (cTraversal && graph.hasLabel(node, CLabel.get())) {
171180
evaluateNode(node, ALabel, BLabel);
172181
}
173182
progressTracker.logProgress();

algo/src/main/java/org/neo4j/gds/triangle/intersect/GraphIntersect.java

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -84,8 +84,7 @@ private void triangles(
8484
Optional<NodeLabel> cLabel
8585
) {
8686
long b = AdjacencyCursorUtils.next(neighborsOfa);
87-
boolean cTraversal = bLabel.isPresent() && ((cLabel.isPresent() && !bLabel.get()
88-
.equals(cLabel.get())) || cLabel.isEmpty());
87+
boolean cTraversal = bLabel.isPresent() && !bLabel.get().equals(cLabel.get());
8988

9089
while (b != NOT_FOUND && (b < a)) {
9190
if (bLabel.isEmpty() || hasLabel.apply(b, bLabel.get())) {

algo/src/test/java/org/neo4j/gds/triangle/IntersectingTriangleCountFilteredGraphTest.java

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -78,8 +78,6 @@ void testUnionGraphWithNodeFilter() {
7878
new Concurrency(4),
7979
Long.MAX_VALUE,
8080
Optional.empty(),
81-
Optional.empty(),
82-
Optional.empty(),
8381
DefaultPool.INSTANCE,
8482
ProgressTracker.NULL_TRACKER,
8583
TerminationFlag.RUNNING_TRUE

0 commit comments

Comments
 (0)