Skip to content

Commit ac146f5

Browse files
committed
Refactor and update API of node filtered triangle count
Co-Authored-By: Ioannis Panagiotas <ioannis.panagiotas@neotechnology.com>
1 parent cb39687 commit ac146f5

File tree

24 files changed

+236
-178
lines changed

24 files changed

+236
-178
lines changed

algo-params/community-params/src/main/java/org/neo4j/gds/triangle/TriangleCountParameters.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,5 +27,5 @@
2727
import java.util.Optional;
2828

2929
@Parameters
30-
public record TriangleCountParameters(Concurrency concurrency, long maxDegree, Optional<List<String>> labelFilter) implements AlgorithmParameters {
30+
public record TriangleCountParameters(Concurrency concurrency, long maxDegree, List<String> labelFilter) implements AlgorithmParameters {
3131
}

algo/src/main/java/org/neo4j/gds/triangle/IntersectingTriangleCount.java

Lines changed: 24 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -66,11 +66,9 @@ public final class IntersectingTriangleCount extends Algorithm<TriangleCountResu
6666
private final HugeAtomicLongArray triangleCounts;
6767
private final long maxDegree;
6868
private final Concurrency concurrency;
69-
private final Optional<NodeLabel> ALabel;
70-
private final Optional<NodeLabel> BLabel;
71-
private final Optional<NodeLabel> CLabel;
72-
private final boolean bTraversal;
73-
private final boolean cTraversal;
69+
private final Optional<NodeLabel> aLabel;
70+
private final Optional<NodeLabel> bLabel;
71+
private final Optional<NodeLabel> cLabel;
7472
private long globalTriangleCount;
7573

7674
private final LongAdder globalTriangleCounter;
@@ -79,7 +77,7 @@ public static IntersectingTriangleCount create(
7977
Graph graph,
8078
Concurrency concurrency,
8179
long maxDegree,
82-
Optional<List<String>> labelFilter,
80+
List<String> labelFilter,
8381
ExecutorService executorService,
8482
ProgressTracker progressTracker,
8583
TerminationFlag terminationFlag
@@ -106,7 +104,7 @@ private IntersectingTriangleCount(
106104
RelationshipIntersectFactory intersectFactory,
107105
Concurrency concurrency,
108106
long maxDegree,
109-
Optional<List<String>> labelFilter,
107+
List<String> labelFilter,
110108
ExecutorService executorService,
111109
ProgressTracker progressTracker,
112110
TerminationFlag terminationFlag
@@ -124,19 +122,20 @@ private IntersectingTriangleCount(
124122
this.globalTriangleCounter = new LongAdder();
125123
this.queue = new AtomicLong();
126124

127-
if (labelFilter.isPresent()) {
128-
var labelFilterList = labelFilter.get();
129-
this.ALabel = Optional.of(NodeLabel.of(labelFilterList.get(0)));
130-
this.BLabel = Optional.of(NodeLabel.of(labelFilterList.get(1)));
131-
this.CLabel = Optional.of(NodeLabel.of(labelFilterList.get(2)));
132-
this.bTraversal = !ALabel.get().equals(BLabel.get());
133-
this.cTraversal = !ALabel.get().equals(CLabel.get()) || !BLabel.get().equals(CLabel.get());
125+
if (!labelFilter.isEmpty()) {
126+
this.aLabel = Optional.of(NodeLabel.of(labelFilter.getFirst()));
134127
} else {
135-
this.ALabel = Optional.empty();
136-
this.BLabel = Optional.empty();
137-
this.CLabel = Optional.empty();
138-
this.bTraversal = false;
139-
this.cTraversal = false;
128+
this.aLabel = Optional.empty();
129+
}
130+
if (labelFilter.size() > 1) {
131+
this.bLabel = Optional.of(NodeLabel.of(labelFilter.get(1)));
132+
} else {
133+
this.bLabel = Optional.empty();
134+
}
135+
if (labelFilter.size() > 2) {
136+
this.cLabel = Optional.of(NodeLabel.of(labelFilter.get(2)));
137+
} else {
138+
this.cLabel = Optional.empty();
140139
}
141140

142141
this.terminationFlag = terminationFlag;
@@ -151,7 +150,7 @@ public TriangleCountResult compute() {
151150
// create tasks
152151
final Collection<? extends Runnable> tasks = ParallelUtil.tasks(
153152
concurrency,
154-
() -> new IntersectTask(intersectFactory.load(graph, maxDegree))
153+
() -> new IntersectTask(intersectFactory.load(graph, maxDegree, this.aLabel, this.bLabel, this.cLabel))
155154
);
156155
// run
157156
ParallelUtil.run(tasks, executorService);
@@ -178,14 +177,11 @@ public void run() {
178177
long node;
179178
while ((node = queue.getAndIncrement()) < graph.nodeCount() && terminationFlag.running()) {
180179
if (graph.degree(node) <= maxDegree) {
181-
if (ALabel.isEmpty() || graph.hasLabel(node, ALabel.get())) {
182-
intersect.intersectAll(node, this, BLabel, CLabel);
183-
}
184-
if (bTraversal && graph.hasLabel(node, BLabel.get())) {
185-
intersect.intersectAll(node, this, CLabel, ALabel);
186-
}
187-
if (cTraversal && graph.hasLabel(node, CLabel.get())) {
188-
intersect.intersectAll(node, this, ALabel, BLabel);
180+
if (cLabel.isEmpty() || bLabel.isEmpty() || aLabel.isEmpty()
181+
|| graph.hasLabel(node, aLabel.get())
182+
|| graph.hasLabel(node, bLabel.get())
183+
|| graph.hasLabel(node, cLabel.get())) {
184+
intersect.intersectAll(node, this);
189185
}
190186
} else {
191187
triangleCounts.set(node, EXCLUDED_NODE_TRIANGLE_COUNT);

algo/src/main/java/org/neo4j/gds/triangle/LocalClusteringCoefficient.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
import org.neo4j.gds.termination.TerminationFlag;
3333
import org.neo4j.gds.utils.CloseableThreadLocal;
3434

35+
import java.util.Collections;
3536
import java.util.Optional;
3637
import java.util.concurrent.atomic.DoubleAdder;
3738
import java.util.function.LongToDoubleFunction;
@@ -136,7 +137,7 @@ private HugeAtomicLongArray computeTriangleCounts() {
136137
graph,
137138
concurrency,
138139
maxDegree,
139-
Optional.empty(),
140+
Collections.emptyList(),
140141
DefaultPool.INSTANCE,
141142
progressTracker,
142143
TerminationFlag.RUNNING_TRUE

algo/src/main/java/org/neo4j/gds/triangle/TriangleStream.java

Lines changed: 27 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -58,20 +58,18 @@ public final class TriangleStream extends Algorithm<Stream<TriangleResult>> {
5858
private final ExecutorService executorService;
5959
private final AtomicInteger queue;
6060
private final Concurrency concurrency;
61-
private final Optional<NodeLabel> ALabel;
62-
private final Optional<NodeLabel> BLabel;
63-
private final Optional<NodeLabel> CLabel;
61+
private final Optional<NodeLabel> aLabel;
62+
private final Optional<NodeLabel> bLabel;
63+
private final Optional<NodeLabel> cLabel;
6464
private final int nodeCount;
6565
private final AtomicInteger runningThreads;
6666
private final BlockingQueue<TriangleResult> resultQueue;
67-
private final boolean bTraversal;
68-
private final boolean cTraversal;
6967

7068
public static TriangleStream create(
7169
Graph graph,
7270
ExecutorService executorService,
7371
Concurrency concurrency,
74-
Optional<List<String>> labelFilter,
72+
List<String> labelFilter,
7573
TerminationFlag terminationFlag
7674
) {
7775
var factory = RelationshipIntersectFactoryLocator
@@ -94,7 +92,7 @@ private TriangleStream(
9492
RelationshipIntersectFactory intersectFactory,
9593
ExecutorService executorService,
9694
Concurrency concurrency,
97-
Optional<List<String>> labelFilter,
95+
List<String> labelFilter,
9896
TerminationFlag terminationFlag
9997
) {
10098
super(ProgressTracker.NULL_TRACKER);
@@ -107,19 +105,20 @@ private TriangleStream(
107105
this.runningThreads = new AtomicInteger();
108106
this.queue = new AtomicInteger();
109107

110-
if (labelFilter.isPresent()) {
111-
var labelFilterList = labelFilter.get();
112-
this.ALabel = Optional.of(NodeLabel.of(labelFilterList.get(0)));
113-
this.BLabel = Optional.of(NodeLabel.of(labelFilterList.get(1)));
114-
this.CLabel = Optional.of(NodeLabel.of(labelFilterList.get(2)));
115-
this.bTraversal = !ALabel.get().equals(BLabel.get());
116-
this.cTraversal = !ALabel.get().equals(CLabel.get()) || !BLabel.get().equals(CLabel.get());
108+
if (!labelFilter.isEmpty()) {
109+
this.aLabel = Optional.of(NodeLabel.of(labelFilter.getFirst()));
117110
} else {
118-
this.ALabel = Optional.empty();
119-
this.BLabel = Optional.empty();
120-
this.CLabel = Optional.empty();
121-
this.bTraversal = false;
122-
this.cTraversal = false;
111+
this.aLabel = Optional.empty();
112+
}
113+
if (labelFilter.size() > 1) {
114+
this.bLabel = Optional.of(NodeLabel.of(labelFilter.get(1)));
115+
} else {
116+
this.bLabel = Optional.empty();
117+
}
118+
if (labelFilter.size() > 2) {
119+
this.cLabel = Optional.of(NodeLabel.of(labelFilter.get(2)));
120+
} else {
121+
this.cLabel = Optional.empty();
123122
}
124123

125124
this.terminationFlag = terminationFlag;
@@ -154,7 +153,7 @@ private void submitTasks() {
154153
final Collection<Runnable> tasks;
155154
tasks = ParallelUtil.tasks(
156155
concurrency,
157-
() -> new IntersectTask(intersectFactory.load(graph, Long.MAX_VALUE))
156+
() -> new IntersectTask(intersectFactory.load(graph, Long.MAX_VALUE, this.aLabel, this.bLabel, this.cLabel))
158157
);
159158
ParallelUtil.run(tasks, false, executorService, null);
160159
}
@@ -170,14 +169,11 @@ public final void run() {
170169
try {
171170
int node;
172171
while ((node = queue.getAndIncrement()) < nodeCount && terminationFlag.running()) {
173-
if (ALabel.isEmpty() || graph.hasLabel(node, ALabel.get())) {
174-
evaluateNode(node, BLabel, CLabel);
175-
}
176-
if (bTraversal && graph.hasLabel(node, BLabel.get())) {
177-
evaluateNode(node, CLabel, ALabel);
178-
}
179-
if (cTraversal && graph.hasLabel(node, CLabel.get())) {
180-
evaluateNode(node, ALabel, BLabel);
172+
if (cLabel.isEmpty() || bLabel.isEmpty() || aLabel.isEmpty()
173+
|| graph.hasLabel(node, aLabel.get())
174+
|| graph.hasLabel(node, bLabel.get())
175+
|| graph.hasLabel(node, cLabel.get())) {
176+
evaluateNode(node);
181177
}
182178
progressTracker.logProgress();
183179
}
@@ -186,7 +182,7 @@ public final void run() {
186182
}
187183
}
188184

189-
abstract void evaluateNode(int nodeId, Optional<NodeLabel> blabel, Optional<NodeLabel> cLabel);
185+
abstract void evaluateNode(int nodeId);
190186

191187
void emit(long nodeA, long nodeB, long nodeC) {
192188
var result = new TriangleResult(
@@ -207,8 +203,8 @@ private final class IntersectTask extends BaseTask implements IntersectionConsum
207203
}
208204

209205
@Override
210-
void evaluateNode(final int nodeId, Optional<NodeLabel> bLabel, Optional<NodeLabel> cLabel) {
211-
intersect.intersectAll(nodeId, this, bLabel, cLabel);
206+
void evaluateNode(final int nodeId) {
207+
intersect.intersectAll(nodeId, this);
212208
}
213209

214210
@Override

algo/src/main/java/org/neo4j/gds/triangle/intersect/GraphIntersect.java

Lines changed: 48 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -42,27 +42,34 @@ public abstract class GraphIntersect<CURSOR extends AdjacencyCursor> implements
4242

4343
private final IntPredicate degreeFilter;
4444
private final BiFunction<Long, NodeLabel, Boolean> hasLabel;
45+
private final Optional<NodeLabel> aLabel;
46+
private final Optional<NodeLabel> bLabel;
47+
private final Optional<NodeLabel> cLabel;
4548
private CURSOR origNeighborsOfa;
4649
private CURSOR helpingCursorOfa;
4750
private CURSOR helpingCursorOfb;
4851

4952

5053
protected GraphIntersect(
5154
long maxDegree,
52-
BiFunction<Long, NodeLabel, Boolean> hasLabel
55+
BiFunction<Long, NodeLabel, Boolean> hasLabel,
56+
Optional<NodeLabel> aLabel,
57+
Optional<NodeLabel> bLabel,
58+
Optional<NodeLabel> cLabel
5359
) {
5460
this.degreeFilter = maxDegree < Long.MAX_VALUE
5561
? (degree) -> degree <= maxDegree
5662
: (ignore) -> true;
5763
this.hasLabel = hasLabel;
64+
this.aLabel = aLabel;
65+
this.bLabel = bLabel;
66+
this.cLabel = cLabel;
5867
}
5968

6069
@Override
6170
public void intersectAll(
6271
long a,
63-
IntersectionConsumer consumer,
64-
Optional<NodeLabel> bLabel,
65-
Optional<NodeLabel> cLabel
72+
IntersectionConsumer consumer
6673
) {
6774
// check the first node's degree
6875
int degreeOfa = degree(a);
@@ -72,22 +79,32 @@ public void intersectAll(
7279

7380
origNeighborsOfa = cursorForNode(origNeighborsOfa, a, degreeOfa);
7481

75-
triangles(a, degreeOfa, origNeighborsOfa, consumer, bLabel, cLabel);
82+
triangles(a, degreeOfa, origNeighborsOfa, consumer);
7683
}
7784

7885
private void triangles(
7986
long a,
8087
int degreeOfa,
8188
CURSOR neighborsOfa,
82-
IntersectionConsumer consumer,
83-
Optional<NodeLabel> bLabel,
84-
Optional<NodeLabel> cLabel
89+
IntersectionConsumer consumer
8590
) {
8691
long b = AdjacencyCursorUtils.next(neighborsOfa);
87-
boolean cTraversal = bLabel.isPresent() && !bLabel.get().equals(cLabel.get());
88-
8992
while (b != NOT_FOUND && (b < a)) {
90-
if (bLabel.isEmpty() || hasLabel.apply(b, bLabel.get())) {
93+
// No filters
94+
if ((aLabel.isEmpty() && bLabel.isEmpty() && cLabel.isEmpty())
95+
// One filter
96+
|| (bLabel.isEmpty() && cLabel.isEmpty() && (hasLabel.apply(a, aLabel.get()) || hasLabel.apply(b, aLabel.get())))
97+
// Two filters
98+
|| (cLabel.isEmpty() && aLabel.isPresent() && bLabel.isPresent()
99+
&& (hasLabel.apply(a, aLabel.get()) || hasLabel.apply(a, bLabel.get()) || hasLabel.apply(b, aLabel.get()) && hasLabel.apply(b, bLabel.get())))
100+
// Three filters
101+
|| (aLabel.isPresent() && bLabel.isPresent() && cLabel.isPresent()
102+
&& ((hasLabel.apply(a, aLabel.get()) && hasLabel.apply(b, bLabel.get()))
103+
|| (hasLabel.apply(a, aLabel.get()) && hasLabel.apply(b, cLabel.get()))
104+
|| (hasLabel.apply(a, bLabel.get()) && hasLabel.apply(b, aLabel.get()))
105+
|| (hasLabel.apply(a, bLabel.get()) && hasLabel.apply(b, cLabel.get()))
106+
|| (hasLabel.apply(a, cLabel.get()) && hasLabel.apply(b, aLabel.get()))
107+
|| (hasLabel.apply(a, cLabel.get()) && hasLabel.apply(b, bLabel.get()))))) {
91108
var degreeOfb = degree(b);
92109
if (degreeFilter.test(degreeOfb)) {
93110
helpingCursorOfb = cursorForNode(
@@ -103,34 +120,10 @@ private void triangles(
103120
b,
104121
helpingCursorOfa,
105122
helpingCursorOfb,
106-
consumer,
107-
cLabel
123+
consumer
108124
); //find all triangles involving the edge (a-b)
109125
}
110126
}
111-
if (cTraversal) {
112-
if (cLabel.isEmpty() || hasLabel.apply(b, cLabel.get())) {
113-
var degreeOfb = degree(b);
114-
if (degreeFilter.test(degreeOfb)) {
115-
helpingCursorOfb = cursorForNode(
116-
helpingCursorOfb,
117-
b,
118-
degreeOfb
119-
);
120-
121-
helpingCursorOfa = cursorForNode(helpingCursorOfa, a, degreeOfa);
122-
123-
triangles(
124-
a,
125-
b,
126-
helpingCursorOfa,
127-
helpingCursorOfb,
128-
consumer,
129-
bLabel
130-
); //find all triangles involving the edge (a-b)
131-
}
132-
}
133-
}
134127

135128
b = AdjacencyCursorUtils.next(neighborsOfa);
136129
}
@@ -142,13 +135,29 @@ private void triangles(
142135
long b,
143136
CURSOR neighborsOfa,
144137
CURSOR neighborsOfb,
145-
IntersectionConsumer consumer,
146-
Optional<NodeLabel> cLabel
138+
IntersectionConsumer consumer
147139
) {
148140
long c = AdjacencyCursorUtils.next(neighborsOfb);
149141
long currentOfa = AdjacencyCursorUtils.next(neighborsOfa);
150142
while (c != NOT_FOUND && currentOfa != NOT_FOUND && (c < b)) {
151-
if (cLabel.isEmpty() || hasLabel.apply(c, cLabel.get())) {
143+
// No filters
144+
if ((aLabel.isEmpty() && bLabel.isEmpty() && cLabel.isEmpty())
145+
// One filter
146+
|| (bLabel.isEmpty() && cLabel.isEmpty() && (hasLabel.apply(a, aLabel.get()) || hasLabel.apply(b, aLabel.get()) || hasLabel.apply(c, aLabel.get())))
147+
// Two filters
148+
|| (cLabel.isEmpty() && aLabel.isPresent() && bLabel.isPresent()
149+
&& ((hasLabel.apply(a, bLabel.get()) && (hasLabel.apply(b, aLabel.get()) || hasLabel.apply(c, aLabel.get())))
150+
|| (hasLabel.apply(b, bLabel.get()) && (hasLabel.apply(a, aLabel.get()) || hasLabel.apply(c, aLabel.get())))
151+
|| (hasLabel.apply(c, bLabel.get()) && (hasLabel.apply(a, aLabel.get()) || hasLabel.apply(b, aLabel.get())))))
152+
// Three filters
153+
|| (aLabel.isPresent() && bLabel.isPresent() && cLabel.isPresent())
154+
&& ((hasLabel.apply(a, aLabel.get()) && hasLabel.apply(b, bLabel.get()) && hasLabel.apply(c, cLabel.get()))
155+
|| (hasLabel.apply(a, aLabel.get()) && hasLabel.apply(c, aLabel.get()) && hasLabel.apply(b, cLabel.get()))
156+
|| (hasLabel.apply(b, aLabel.get()) && hasLabel.apply(a, bLabel.get()) && hasLabel.apply(c, cLabel.get()))
157+
|| (hasLabel.apply(b, aLabel.get()) && hasLabel.apply(c, bLabel.get()) && hasLabel.apply(a, cLabel.get()))
158+
|| (hasLabel.apply(c, aLabel.get()) && hasLabel.apply(a, bLabel.get()) && hasLabel.apply(b, cLabel.get()))
159+
|| (hasLabel.apply(c, aLabel.get()) && hasLabel.apply(b, bLabel.get()) && hasLabel.apply(a, cLabel.get())))) {
160+
152161
var degreeOfc = degree(c);
153162
if (degreeFilter.test(degreeOfc)) {
154163
currentOfa = AdjacencyCursorUtils.advance(neighborsOfa, currentOfa, c);

0 commit comments

Comments
 (0)