Skip to content

Commit aced8ed

Browse files
committed
Fix concurrent insertion into MultiLabelInformation
1 parent f3e8b29 commit aced8ed

File tree

2 files changed

+27
-1
lines changed

2 files changed

+27
-1
lines changed

core/src/main/java/org/neo4j/gds/core/loading/MultiLabelInformation.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -167,7 +167,7 @@ static Builder of(
167167
Collection<NodeLabel> starNodeLabelMappings
168168
) {
169169
var nodeLabelBitSetMap = availableNodeLabels.stream().collect(
170-
Collectors.toMap(
170+
Collectors.toConcurrentMap(
171171
nodeLabel -> nodeLabel,
172172
ignored -> HugeAtomicGrowingBitSet.create(expectedCapacity)
173173
)

core/src/test/java/org/neo4j/gds/core/loading/MultiLabelInformationTest.java

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,13 +25,16 @@
2525
import org.neo4j.gds.NodeLabel;
2626
import org.neo4j.gds.api.BatchNodeIterable;
2727
import org.neo4j.gds.api.IdMap;
28+
import org.neo4j.gds.core.concurrency.RunWithConcurrency;
2829

2930
import java.util.List;
3031
import java.util.Random;
3132
import java.util.concurrent.atomic.LongAdder;
3233
import java.util.function.LongConsumer;
3334
import java.util.function.LongUnaryOperator;
35+
import java.util.stream.LongStream;
3436

37+
import static java.util.stream.Collectors.toList;
3538
import static org.assertj.core.api.Assertions.assertThat;
3639
import static org.assertj.core.api.Assertions.assertThatExceptionOfType;
3740
import static org.assertj.core.api.Assertions.assertThatNoException;
@@ -367,5 +370,28 @@ void shouldAddNodeIdsToLabel() {
367370
assertThat(labelInformation.nodeLabelsForNodeId(3L)).contains(NodeLabel.of("B"));
368371
}
369372

373+
374+
@Test
375+
void shouldAcceptConcurrentInserts() {
376+
var builder = MultiLabelInformation.Builder.of(110, List.of(), List.of());
377+
378+
// Create 10 tasks that try to insert overlapping node labels
379+
List<Runnable> tasks = LongStream
380+
.range(0, 10)
381+
.mapToObj(i ->
382+
(Runnable) () -> LongStream.range(i * 10, i * 10 + 20).forEach(l -> builder.addNodeIdToLabel(NodeLabel.of("" + l), l))
383+
).collect(toList());
384+
385+
RunWithConcurrency.builder()
386+
.tasks(tasks)
387+
.concurrency(4)
388+
.build()
389+
.run();
390+
391+
var map = builder.build(110, i -> i);
392+
393+
LongStream.range(0, 110).forEach(i -> assertThat(map.hasLabel(i, NodeLabel.of("" + i))).isTrue());
394+
}
395+
370396
}
371397
}

0 commit comments

Comments
 (0)