Skip to content

Commit ecfb59b

Browse files
refactor manage nodeQueue in SelectionStrategy
Co-authored-by: Ioannis Panagiotas <ioannis.panagiotas@neotechnology.com>
1 parent e9ecce9 commit ecfb59b

File tree

3 files changed

+52
-39
lines changed

3 files changed

+52
-39
lines changed

algo/src/main/java/org/neo4j/gds/betweenness/BetweennessCentrality.java

Lines changed: 3 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -33,13 +33,11 @@
3333
import org.neo4j.gds.core.utils.progress.tasks.ProgressTracker;
3434

3535
import java.util.concurrent.ExecutorService;
36-
import java.util.concurrent.atomic.AtomicLong;
3736
import java.util.function.Consumer;
3837

3938
public class BetweennessCentrality extends Algorithm<HugeAtomicDoubleArray> {
4039

4140
private final Graph graph;
42-
private final AtomicLong nodeQueue = new AtomicLong();
4341
private final long nodeCount;
4442
private final double divisor;
4543
private final ForwardTraverser.Factory traverserFactory;
@@ -75,7 +73,6 @@ public BetweennessCentrality(
7573
@Override
7674
public HugeAtomicDoubleArray compute() {
7775
progressTracker.beginSubTask();
78-
nodeQueue.set(0);
7976
ParallelUtil.run(ParallelUtil.tasks(concurrency, BCTask::new), executorService);
8077
progressTracker.endSubTask();
8178
return centrality;
@@ -113,15 +110,11 @@ public void run() {
113110
);
114111

115112
for (;;) {
116-
// take start node from the queue
117-
long startNodeId = nodeQueue.getAndIncrement();
118-
if (startNodeId >= nodeCount || !terminationFlag.running()) {
113+
long startNodeId = selectionStrategy.next();
114+
if (startNodeId == -1 || !terminationFlag.running()) {
119115
return;
120116
}
121-
// check whether the node is part of the subset
122-
if (!selectionStrategy.select(startNodeId)) {
123-
continue;
124-
}
117+
125118
// reset
126119
getProgressTracker().logProgress();
127120

algo/src/main/java/org/neo4j/gds/betweenness/SelectionStrategy.java

Lines changed: 27 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -39,24 +39,36 @@
3939
public interface SelectionStrategy {
4040

4141
SelectionStrategy ALL = new SelectionStrategy() {
42+
private final AtomicLong nodeQueue = new AtomicLong();
43+
private long graphSize;
44+
4245
@Override
43-
public void init(Graph graph, ExecutorService executorService, int concurrency) { }
46+
public void init(Graph graph, ExecutorService executorService, int concurrency) {
47+
this.graphSize = graph.nodeCount();
48+
nodeQueue.set(0);
49+
}
4450

4551
@Override
46-
public boolean select(long nodeId) {
47-
return true;
52+
public long next() {
53+
long next = nodeQueue.getAndIncrement();
54+
if (next >= graphSize) {
55+
return -1;
56+
}
57+
return next;
4858
}
4959
};
5060

5161
void init(Graph graph, ExecutorService executorService, int concurrency);
5262

53-
boolean select(long nodeId);
63+
long next();
5464

5565
class RandomDegree implements SelectionStrategy {
5666

5767
private final long samplingSize;
5868
private final Optional<Long> maybeRandomSeed;
69+
private final AtomicLong nodeQueue = new AtomicLong();
5970

71+
private long graphSize;
6072
private BitSet bitSet;
6173

6274
public RandomDegree(long samplingSize) {
@@ -72,17 +84,25 @@ public RandomDegree(long samplingSize, Optional<Long> maybeRandomSeed) {
7284
public void init(Graph graph, ExecutorService executorService, int concurrency) {
7385
assert samplingSize <= graph.nodeCount();
7486
this.bitSet = new BitSet(graph.nodeCount());
87+
this.graphSize = graph.nodeCount();
88+
nodeQueue.set(0);
7589
var partitions = PartitionUtils.numberAlignedPartitioning(concurrency, graph.nodeCount(), Long.SIZE);
7690
var maxDegree = maxDegree(graph, partitions, executorService, concurrency);
7791
selectNodes(graph, partitions, maxDegree, executorService, concurrency);
7892
}
7993

8094
@Override
81-
public boolean select(long nodeId) {
82-
return bitSet.get(nodeId);
95+
public long next() {
96+
long next;
97+
while ((next = nodeQueue.getAndIncrement()) < graphSize) {
98+
if (bitSet.get(next)) {
99+
return next;
100+
}
101+
}
102+
return -1;
83103
}
84104

85-
private int maxDegree(
105+
private static int maxDegree(
86106
Graph graph,
87107
Collection<Partition> partitions,
88108
ExecutorService executorService,

algo/src/test/java/org/neo4j/gds/betweenness/SelectionStrategyTest.java

Lines changed: 22 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,6 @@
3333
import java.util.Optional;
3434

3535
import static org.junit.jupiter.api.Assertions.assertEquals;
36-
import static org.junit.jupiter.api.Assertions.assertFalse;
3736
import static org.junit.jupiter.api.Assertions.assertTrue;
3837
import static org.neo4j.gds.TestSupport.fromGdl;
3938

@@ -61,15 +60,15 @@ class SelectionStrategyTest {
6160
void selectAll() {
6261
SelectionStrategy selectionStrategy = SelectionStrategy.ALL;
6362
selectionStrategy.init(graph, Pools.DEFAULT, 1);
64-
assertEquals(graph.nodeCount(), samplingSize(graph.nodeCount(), selectionStrategy));
63+
assertEquals(graph.nodeCount(), samplingSize(selectionStrategy));
6564
}
6665

6766
@ParameterizedTest
6867
@ValueSource(longs = {0, 1, 2, 10, 11})
6968
void selectSamplingSize(long samplingSize) {
7069
SelectionStrategy selectionStrategy = new SelectionStrategy.RandomDegree(samplingSize);
7170
selectionStrategy.init(graph, Pools.DEFAULT, 1);
72-
assertEquals(samplingSize, samplingSize(graph.nodeCount(), selectionStrategy));
71+
assertEquals(samplingSize, samplingSize(selectionStrategy));
7372
}
7473

7574
@ParameterizedTest
@@ -83,17 +82,18 @@ void selectSamplingSizeMultiThreaded(long samplingSize) {
8382
.generate();
8483
SelectionStrategy selectionStrategy = new SelectionStrategy.RandomDegree(samplingSize, Optional.of(42L));
8584
selectionStrategy.init(graph, Pools.DEFAULT, 4);
86-
assertEquals(samplingSize, samplingSize(graph.nodeCount(), selectionStrategy));
85+
assertEquals(samplingSize, samplingSize(selectionStrategy));
8786
}
8887

8988
@Test
9089
void selectSamplingSizeWithSeed() {
9190
SelectionStrategy selectionStrategy = new SelectionStrategy.RandomDegree(3, Optional.of(42L));
9291
selectionStrategy.init(graph, Pools.DEFAULT, 1);
93-
assertEquals(3, samplingSize(graph.nodeCount(), selectionStrategy));
94-
assertTrue(selectionStrategy.select(graph.toMappedNodeId("a")));
95-
assertTrue(selectionStrategy.select(graph.toMappedNodeId("b")));
96-
assertTrue(selectionStrategy.select(graph.toMappedNodeId("f")));
92+
assertEquals(3, samplingSize(selectionStrategy));
93+
selectionStrategy.init(graph, Pools.DEFAULT, 1);
94+
assertEquals(graph.toMappedNodeId("a"), selectionStrategy.next());
95+
assertEquals(graph.toMappedNodeId("b"), selectionStrategy.next());
96+
assertEquals(graph.toMappedNodeId("f"), selectionStrategy.next());
9797
}
9898

9999
@Test
@@ -102,8 +102,9 @@ void neverIncludeZeroDegNodesIfBetterChoicesExist() {
102102

103103
SelectionStrategy selectionStrategy = new SelectionStrategy.RandomDegree(1);
104104
selectionStrategy.init(graph, Pools.DEFAULT, 1);
105-
assertEquals(1, samplingSize(graph.nodeCount(), selectionStrategy));
106-
assertTrue(selectionStrategy.select(graph.toMappedNodeId("a")));
105+
assertEquals(1, samplingSize(selectionStrategy));
106+
selectionStrategy.init(graph, Pools.DEFAULT, 1);
107+
assertEquals(graph.toMappedNodeId("a"), selectionStrategy.next());
107108
}
108109

109110
@Test
@@ -112,27 +113,26 @@ void not100PercentLikelyUnlessMaxDegNode() {
112113

113114
SelectionStrategy selectionStrategy = new SelectionStrategy.RandomDegree(1, Optional.of(42L));
114115
selectionStrategy.init(graph, Pools.DEFAULT, 1);
115-
assertEquals(1, samplingSize(graph.nodeCount(), selectionStrategy));
116-
assertFalse(selectionStrategy.select(graph.toMappedNodeId("a")));
117-
assertTrue(selectionStrategy.select(graph.toMappedNodeId("b")));
116+
assertEquals(1, samplingSize(selectionStrategy));
117+
selectionStrategy.init(graph, Pools.DEFAULT, 1);
118+
assertEquals(graph.toMappedNodeId("b"), selectionStrategy.next());
118119
}
119120

120121
@Test
121122
void selectHighDegreeNode() {
122123
SelectionStrategy selectionStrategy = new SelectionStrategy.RandomDegree(1);
123124
selectionStrategy.init(graph, Pools.DEFAULT, 1);
124-
assertEquals(1, samplingSize(graph.nodeCount(), selectionStrategy));
125-
var isA = selectionStrategy.select(graph.toMappedNodeId("a"));
126-
var isB = selectionStrategy.select(graph.toMappedNodeId("b"));
127-
assertTrue(isA || isB);
125+
assertEquals(1, samplingSize(selectionStrategy));
126+
selectionStrategy.init(graph, Pools.DEFAULT, 1);
127+
var isA = selectionStrategy.next();
128+
var isB = selectionStrategy.next();
129+
assertTrue(isA >= 0 || isB >= 0);
128130
}
129131

130-
private static long samplingSize(long nodeCount, SelectionStrategy selectionStrategy) {
132+
private static long samplingSize(SelectionStrategy selectionStrategy) {
131133
long samplingSize = 0;
132-
for (int nodeId = 0; nodeId < nodeCount; nodeId++) {
133-
if (selectionStrategy.select(nodeId)) {
134-
samplingSize++;
135-
}
134+
while (selectionStrategy.next() >= 0) {
135+
samplingSize++;
136136
}
137137
return samplingSize;
138138
}

0 commit comments

Comments
 (0)