Skip to content

Commit 996001f

Browse files
Merge pull request #10837 from IoannisPanagiotas/msbfs-refactoring
Msbfs refactoring
2 parents 3c3c6ad + 4b95ef5 commit 996001f

16 files changed

+434
-163
lines changed
Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
/*
2+
* Copyright (c) "Neo4j"
3+
* Neo4j Sweden AB [http://neo4j.com]
4+
*
5+
* This file is part of Neo4j.
6+
*
7+
* Neo4j is free software: you can redistribute it and/or modify
8+
* it under the terms of the GNU General Public License as published by
9+
* the Free Software Foundation, either version 3 of the License, or
10+
* (at your option) any later version.
11+
*
12+
* This program is distributed in the hope that it will be useful,
13+
* but WITHOUT ANY WARRANTY; without even the implied warranty of
14+
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
15+
* GNU General Public License for more details.
16+
*
17+
* You should have received a copy of the GNU General Public License
18+
* along with this program. If not, see <http://www.gnu.org/licenses/>.
19+
*/
20+
package org.neo4j.gds.msbfs;
21+
22+
import org.neo4j.gds.collections.ha.HugeLongArray;
23+
24+
public class EmptySourceNodesSpec implements SourceNodesSpec{
25+
26+
private final int localNodeCount;
27+
private final long nodeOffset;
28+
29+
public EmptySourceNodesSpec(long nodeOffset, int localNodeCount) {
30+
this.localNodeCount = localNodeCount;
31+
this.nodeOffset = nodeOffset;
32+
}
33+
34+
@Override
35+
public SourceNodes setUp(HugeLongArray visitSet, HugeLongArray seenSet, boolean allowStartNodeTraversal) {
36+
37+
for (int i = 0; i < localNodeCount; ++i) {
38+
long currentNode = nodeOffset + i;
39+
40+
if (!allowStartNodeTraversal) {
41+
seenSet.set(currentNode, 1L << i);
42+
}
43+
44+
visitSet.or(currentNode, 1L << i);
45+
}
46+
47+
return new SourceNodes(nodeOffset, localNodeCount);
48+
}
49+
50+
@Override
51+
public long[] nodes() {
52+
var array = new long[localNodeCount];
53+
for (int i=0; i<localNodeCount;++i){
54+
long currentNode = nodeOffset + i;
55+
array[i] = currentNode;
56+
}
57+
return array;
58+
}
59+
60+
}
Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,58 @@
1+
/*
2+
* Copyright (c) "Neo4j"
3+
* Neo4j Sweden AB [http://neo4j.com]
4+
*
5+
* This file is part of Neo4j.
6+
*
7+
* Neo4j is free software: you can redistribute it and/or modify
8+
* it under the terms of the GNU General Public License as published by
9+
* the Free Software Foundation, either version 3 of the License, or
10+
* (at your option) any later version.
11+
*
12+
* This program is distributed in the hope that it will be useful,
13+
* but WITHOUT ANY WARRANTY; without even the implied warranty of
14+
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
15+
* GNU General Public License for more details.
16+
*
17+
* You should have received a copy of the GNU General Public License
18+
* along with this program. If not, see <http://www.gnu.org/licenses/>.
19+
*/
20+
package org.neo4j.gds.msbfs;
21+
22+
import org.neo4j.gds.collections.ha.HugeLongArray;
23+
24+
public class ListSourceNodesSpec implements SourceNodesSpec{
25+
26+
private final long[] sourceNodes;
27+
private final int from;
28+
private final int to;
29+
30+
public ListSourceNodesSpec(long[] sourceNodes, int from, int length) {
31+
this.sourceNodes = sourceNodes;
32+
this.from = from;
33+
this.to = Math.min(from + length, sourceNodes.length);
34+
}
35+
36+
@Override
37+
public SourceNodes setUp(HugeLongArray visitSet, HugeLongArray seenSet, boolean allowStartNodeTraversal) {
38+
39+
for (int i = from; i < to; ++i) {
40+
int adaptedI = i-from;
41+
long nodeId = sourceNodes[i];
42+
if (!allowStartNodeTraversal) {
43+
seenSet.set(nodeId, 1L << adaptedI);
44+
}
45+
visitSet.or(nodeId, 1L << adaptedI);
46+
}
47+
48+
return new SourceNodes(nodes());
49+
}
50+
51+
@Override
52+
public long[] nodes() {
53+
var array = new long[to-from];
54+
System.arraycopy(sourceNodes, from, array, 0, to - from);
55+
return array;
56+
}
57+
58+
}

algo/src/main/java/org/neo4j/gds/msbfs/MultiSourceBFSAccessMethods.java

Lines changed: 6 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -78,13 +78,8 @@ public final class MultiSourceBFSAccessMethods {
7878
private final boolean allowStartNodeTraversal;
7979
private final long @Nullable [] sourceNodes;
8080

81-
// hypothesis: you supply actual source nodes, or you provide a count - if so that should be rationalised
82-
private final int sourceNodeCount;
83-
private final long nodeOffset;
84-
8581
private final TerminationFlag terminationFlag;
8682

87-
8883
public static MultiSourceBFSAccessMethods aggregatedNeighborProcessing(
8984
long nodeCount,
9085
RelationshipIterator relationships,
@@ -157,8 +152,6 @@ private static MultiSourceBFSAccessMethods createMultiSourceBFS(
157152
strategy,
158153
spec.allowStartNodeTraversal(),
159154
sourceNodes,
160-
0,
161-
0,
162155
terminationFlag
163156
);
164157

@@ -177,8 +170,6 @@ private MultiSourceBFSAccessMethods(
177170
ExecutionStrategy strategy,
178171
boolean allowStartNodeTraversal,
179172
long @Nullable [] sourceNodes,
180-
int sourceNodeCount,
181-
long nodeOffset,
182173
TerminationFlag terminationFlag
183174
) {
184175
this.visits = visits;
@@ -190,8 +181,6 @@ private MultiSourceBFSAccessMethods(
190181
this.strategy = strategy;
191182
this.allowStartNodeTraversal = allowStartNodeTraversal;
192183
this.sourceNodes = sourceNodes;
193-
this.sourceNodeCount = sourceNodeCount;
194-
this.nodeOffset = nodeOffset;
195184
this.terminationFlag = terminationFlag;
196185
}
197186

@@ -213,14 +202,9 @@ public void run(Concurrency concurrency, ExecutorService executor) {
213202
}
214203

215204
private long sourceLength() {
216-
if (sourceNodes != null) {
217-
return sourceNodes.length;
218-
}
219-
if (sourceNodeCount == 0) {
220-
return nodeCount;
221-
}
222-
return sourceNodeCount;
223-
}
205+
if (sourceNodes == null || sourceNodes.length==0) return nodeCount;
206+
return sourceNodes.length;
207+
}
224208

225209
private int numberOfThreads() {
226210
long sourceLength = sourceLength();
@@ -233,30 +217,8 @@ private int numberOfThreads() {
233217

234218
// lazily creates MS-BFS instances for OMEGA sized source chunks
235219
private Collection<MultiSourceBFSRunnable> allSourceBfss(int threads) {
236-
if (sourceNodes == null) {
237-
long sourceLength = nodeCount;
238-
return new ParallelMultiSources(threads, sourceLength) {
239-
@Override
240-
MultiSourceBFSRunnable next(final long from, final int length) {
241-
return new MultiSourceBFSRunnable(
242-
visits,
243-
visitsNext,
244-
seens,
245-
seensNext,
246-
sourceLength,
247-
relationships.concurrentCopy(),
248-
strategy,
249-
allowStartNodeTraversal,
250-
null,
251-
length,
252-
from
253-
);
254-
}
255-
};
256-
}
257-
long[] sourceNodes = this.sourceNodes;
258-
int sourceLength = sourceNodes.length;
259-
return new ParallelMultiSources(threads, sourceLength) {
220+
var producer = SourceNodeSpecFactory.createProducer(sourceNodes);
221+
return new ParallelMultiSources(threads, sourceLength()) {
260222
@Override
261223
MultiSourceBFSRunnable next(final long from, final int length) {
262224
return new MultiSourceBFSRunnable(
@@ -268,25 +230,10 @@ MultiSourceBFSRunnable next(final long from, final int length) {
268230
relationships.concurrentCopy(),
269231
strategy,
270232
allowStartNodeTraversal,
271-
Arrays.copyOfRange(sourceNodes, (int) from, (int) (from + length)),
272-
0,
273-
0
233+
producer.create(from,length)
274234
);
275235
}
276236
};
277237
}
278238

279-
@Override
280-
public String toString() {
281-
if (sourceNodes != null && sourceNodes.length > 0) {
282-
return "MSBFS{" + sourceNodes[0] +
283-
" .. " + (sourceNodes[sourceNodes.length - 1] + 1) +
284-
" (" + sourceNodes.length +
285-
")}";
286-
}
287-
return "MSBFS{" + nodeOffset +
288-
" .. " + (nodeOffset + sourceNodeCount) +
289-
" (" + sourceNodeCount +
290-
")}";
291-
}
292239
}

algo/src/main/java/org/neo4j/gds/msbfs/MultiSourceBFSRunnable.java

Lines changed: 7 additions & 69 deletions
Original file line numberDiff line numberDiff line change
@@ -65,19 +65,14 @@ public final class MultiSourceBFSRunnable implements Runnable {
6565
private final RelationshipIterator relationships;
6666
private final ExecutionStrategy strategy;
6767
private final boolean allowStartNodeTraversal;
68-
private final long @Nullable [] sourceNodes;
69-
70-
// hypothesis: you supply actual source nodes, or you provide a count - if so that should be rationalised
71-
private final int sourceNodeCount;
72-
private final long nodeOffset;
68+
private final SourceNodesSpec sourceNodesSpec;
7369

7470
public static MultiSourceBFSRunnable createWithoutSeensNext(
7571
long nodeCount,
7672
RelationshipIterator relationships,
7773
ExecutionStrategy strategy,
7874
boolean allowStartNodeTraversal,
79-
int sourceNodeCount,
80-
long nodeOffset
75+
SourceNodesSpec sourceNodesSpec
8176
) {
8277
var visits = new LocalHugeLongArray(nodeCount);
8378
var visitsNext = new LocalHugeLongArray(nodeCount);
@@ -92,9 +87,7 @@ public static MultiSourceBFSRunnable createWithoutSeensNext(
9287
relationships,
9388
strategy,
9489
allowStartNodeTraversal,
95-
null,
96-
sourceNodeCount,
97-
nodeOffset
90+
sourceNodesSpec
9891
);
9992
}
10093

@@ -110,9 +103,7 @@ public static MultiSourceBFSRunnable createWithoutSeensNext(
110103
RelationshipIterator relationships,
111104
ExecutionStrategy strategy,
112105
boolean allowStartNodeTraversal,
113-
long @Nullable [] sourceNodes,
114-
int sourceNodeCount,
115-
long nodeOffset
106+
SourceNodesSpec sourceNodesSpec
116107
) {
117108

118109
this.visits = visits;
@@ -123,9 +114,7 @@ public static MultiSourceBFSRunnable createWithoutSeensNext(
123114
this.relationships = relationships;
124115
this.strategy = strategy;
125116
this.allowStartNodeTraversal = allowStartNodeTraversal;
126-
this.sourceNodes = sourceNodes;
127-
this.sourceNodeCount = sourceNodeCount;
128-
this.nodeOffset = nodeOffset;
117+
this.sourceNodesSpec = sourceNodesSpec;
129118
}
130119

131120
/**
@@ -138,62 +127,11 @@ public void run() {
138127
var seenSet = seens.get();
139128
var seenNextSet = seensNext != null ? seensNext.get() : null;
140129

141-
SourceNodes sourceNodes = this.sourceNodes == null
142-
? prepareOffsetSources(visitSet, seenSet, allowStartNodeTraversal)
143-
: prepareSpecifiedSources(visitSet, seenSet, this.sourceNodes, allowStartNodeTraversal);
144-
145-
strategy.run(relationships, nodeCount, sourceNodes, visitSet, visitNextSet, seenSet, seenNextSet);
146-
}
147-
148-
private SourceNodes prepareOffsetSources(
149-
HugeLongArray visitSet,
150-
HugeLongArray seenSet,
151-
boolean allowStartNodeTraversal
152-
) {
153-
var localNodeCount = this.sourceNodeCount;
154-
var nodeOffset = this.nodeOffset;
155-
156-
for (int i = 0; i < localNodeCount; ++i) {
157-
long currentNode = nodeOffset + i;
158130

159-
if (!allowStartNodeTraversal) {
160-
seenSet.set(currentNode, 1L << i);
161-
}
162-
163-
visitSet.or(currentNode, 1L << i);
164-
}
131+
var sourceNodes = sourceNodesSpec.setUp(visitSet,seenSet,allowStartNodeTraversal);
165132

166-
return new SourceNodes(nodeOffset, localNodeCount);
133+
strategy.run(relationships, nodeCount, sourceNodes, visitSet, visitNextSet, seenSet, seenNextSet);
167134
}
168135

169-
private static SourceNodes prepareSpecifiedSources(
170-
HugeLongArray visitSet,
171-
HugeLongArray seenSet,
172-
long[] sourceNodes,
173-
boolean allowStartNodeTraversal
174-
) {
175-
for (int i = 0; i < sourceNodes.length; ++i) {
176-
long nodeId = sourceNodes[i];
177-
if (!allowStartNodeTraversal) {
178-
seenSet.set(nodeId, 1L << i);
179-
}
180-
visitSet.or(nodeId, 1L << i);
181-
}
182136

183-
return new SourceNodes(sourceNodes);
184-
}
185-
186-
@Override
187-
public String toString() {
188-
if (sourceNodes != null && sourceNodes.length > 0) {
189-
return "MSBFS{" + sourceNodes[0] +
190-
" .. " + (sourceNodes[sourceNodes.length - 1] + 1) +
191-
" (" + sourceNodes.length +
192-
")}";
193-
}
194-
return "MSBFS{" + nodeOffset +
195-
" .. " + (nodeOffset + sourceNodeCount) +
196-
" (" + sourceNodeCount +
197-
")}";
198-
}
199137
}
Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
/*
2+
* Copyright (c) "Neo4j"
3+
* Neo4j Sweden AB [http://neo4j.com]
4+
*
5+
* This file is part of Neo4j.
6+
*
7+
* Neo4j is free software: you can redistribute it and/or modify
8+
* it under the terms of the GNU General Public License as published by
9+
* the Free Software Foundation, either version 3 of the License, or
10+
* (at your option) any later version.
11+
*
12+
* This program is distributed in the hope that it will be useful,
13+
* but WITHOUT ANY WARRANTY; without even the implied warranty of
14+
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
15+
* GNU General Public License for more details.
16+
*
17+
* You should have received a copy of the GNU General Public License
18+
* along with this program. If not, see <http://www.gnu.org/licenses/>.
19+
*/
20+
package org.neo4j.gds.msbfs;
21+
22+
import org.jetbrains.annotations.Nullable;
23+
24+
public final class SourceNodeSpecFactory {
25+
26+
private SourceNodeSpecFactory() {}
27+
28+
static SourceNodesSpecProducer createProducer(@Nullable long [] sourceNodes){
29+
30+
if (sourceNodes == null || sourceNodes.length == 0){
31+
return EmptySourceNodesSpec::new;
32+
}else{
33+
return (from,length) -> new ListSourceNodesSpec(sourceNodes,(int)from, length);
34+
}
35+
36+
}
37+
38+
}

0 commit comments

Comments
 (0)