Skip to content

Commit 8ac4544

Browse files
s1ckknutwalker
authored and committed
Add ShardedLongLongMap
1 parent 95f40a2 commit 8ac4544

File tree

2 files changed

+352
-0
lines changed

2 files changed

+352
-0
lines changed
Lines changed: 210 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,210 @@
1+
/*
2+
* Copyright (c) "Neo4j"
3+
* Neo4j Sweden AB [http://neo4j.com]
4+
*
5+
* This file is part of Neo4j.
6+
*
7+
* Neo4j is free software: you can redistribute it and/or modify
8+
* it under the terms of the GNU General Public License as published by
9+
* the Free Software Foundation, either version 3 of the License, or
10+
* (at your option) any later version.
11+
*
12+
* This program is distributed in the hope that it will be useful,
13+
* but WITHOUT ANY WARRANTY; without even the implied warranty of
14+
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
15+
* GNU General Public License for more details.
16+
*
17+
* You should have received a copy of the GNU General Public License
18+
* along with this program. If not, see <http://www.gnu.org/licenses/>.
19+
*/
20+
package org.neo4j.gds.core.utils.paged;
21+
22+
import org.eclipse.collections.api.map.primitive.MutableObjectLongMap;
23+
import org.eclipse.collections.api.map.primitive.ObjectLongMap;
24+
import org.eclipse.collections.impl.collection.mutable.AbstractMultiReaderMutableCollection;
25+
import org.eclipse.collections.impl.factory.primitive.ObjectLongMaps;
26+
import org.neo4j.gds.api.IdMap;
27+
import org.neo4j.gds.collections.ha.HugeObjectArray;
28+
import org.neo4j.gds.core.concurrency.Concurrency;
29+
import org.neo4j.gds.mem.BitUtil;
30+
31+
import java.util.Arrays;
32+
import java.util.concurrent.atomic.AtomicLong;
33+
import java.util.concurrent.locks.ReentrantLock;
34+
import java.util.stream.IntStream;
35+
36+
public final class ShardedByteArrayLongMap {
37+
38+
private final HugeObjectArray<byte[]> internalNodeMapping;
39+
private final ObjectLongMap<byte[]>[] originalNodeMappingShards;
40+
private final int shardShift;
41+
42+
public static Builder builder(Concurrency concurrency) {
43+
return new Builder(concurrency);
44+
}
45+
46+
private ShardedByteArrayLongMap(
47+
HugeObjectArray<byte[]> internalNodeMapping,
48+
ObjectLongMap<byte[]>[] originalNodeMappingShards,
49+
int shardShift
50+
) {
51+
this.internalNodeMapping = internalNodeMapping;
52+
this.originalNodeMappingShards = originalNodeMappingShards;
53+
this.shardShift = shardShift;
54+
}
55+
56+
public long toMappedNodeId(byte[] nodeId) {
57+
var shard = findShard(nodeId, this.originalNodeMappingShards, this.shardShift);
58+
return shard.getIfAbsent(nodeId, IdMap.NOT_FOUND);
59+
}
60+
61+
public boolean contains(byte[] originalId) {
62+
var shard = findShard(originalId, this.originalNodeMappingShards, this.shardShift);
63+
return shard.containsKey(originalId);
64+
}
65+
66+
public byte[] toOriginalNodeId(long nodeId) {
67+
return internalNodeMapping.get(nodeId);
68+
}
69+
70+
public long size() {
71+
return internalNodeMapping.size();
72+
}
73+
74+
private static <T> T findShard(byte[] key, T[] shards, int shift) {
75+
int idx = shardIdx(key, shift);
76+
return shards[idx];
77+
}
78+
79+
// https://en.wikipedia.org/wiki/Fowler%E2%80%93Noll%E2%80%93Vo_hash_function
80+
private static final long FNV1_64_INIT = 0xcbf29ce484222325L;
81+
private static final long FNV1_64_PRIME = 1099511628211L;
82+
83+
// We use FNV-1a 64-bit hash function to hash the byte array key
84+
// and achieve a somewhat uniform distribution of keys across shards.
85+
private static int shardIdx(byte[] key, int shift) {
86+
long hash = FNV1_64_INIT;
87+
88+
for (int i = 0; i < key.length; i++) {
89+
hash ^= (key[i] & 0xff);
90+
hash *= FNV1_64_PRIME;
91+
}
92+
93+
return (int) (hash >>> shift);
94+
}
95+
96+
private static int numberOfShards(Concurrency concurrency) {
97+
return BitUtil.nextHighestPowerOfTwo(concurrency.value() * 4);
98+
}
99+
100+
@SuppressWarnings("unchecked")
101+
private static <S extends MapShard> ShardedByteArrayLongMap build(
102+
long nodeCount,
103+
S[] shards,
104+
int shardShift
105+
) {
106+
var internalNodeMapping = HugeObjectArray.newArray(byte[].class, nodeCount);
107+
var mapShards = new ObjectLongMap[shards.length];
108+
109+
// ignoring concurrency limitation 🤷
110+
Arrays.parallelSetAll(mapShards, idx -> {
111+
var shard = shards[idx];
112+
var mapping = shard.intoMapping();
113+
mapping.forEachKeyValue((originalId, mappedId) -> {
114+
internalNodeMapping.set(mappedId, originalId);
115+
});
116+
return mapping;
117+
});
118+
119+
return new ShardedByteArrayLongMap(
120+
internalNodeMapping,
121+
mapShards,
122+
shardShift
123+
);
124+
}
125+
126+
abstract static class MapShard {
127+
128+
private final ReentrantLock lock;
129+
private final AbstractMultiReaderMutableCollection.LockWrapper lockWrapper;
130+
final MutableObjectLongMap<byte[]> mapping;
131+
132+
MapShard() {
133+
this.mapping = ObjectLongMaps.mutable.empty();
134+
this.lock = new ReentrantLock();
135+
this.lockWrapper = new AbstractMultiReaderMutableCollection.LockWrapper(lock);
136+
}
137+
138+
final AbstractMultiReaderMutableCollection.LockWrapper acquireLock() {
139+
this.lock.lock();
140+
return this.lockWrapper;
141+
}
142+
143+
void assertIsUnderLock() {
144+
assert this.lock.isHeldByCurrentThread() : "addNode must only be called while holding the lock";
145+
}
146+
147+
MutableObjectLongMap<byte[]> intoMapping() {
148+
return mapping;
149+
}
150+
}
151+
152+
public static final class Builder {
153+
154+
private final AtomicLong nodeCount;
155+
private final Shard[] shards;
156+
private final int shardShift;
157+
private final int shardMask;
158+
159+
Builder(Concurrency concurrency) {
160+
this.nodeCount = new AtomicLong();
161+
int numberOfShards = numberOfShards(concurrency);
162+
this.shardShift = Long.SIZE - Integer.numberOfTrailingZeros(numberOfShards);
163+
this.shardMask = numberOfShards - 1;
164+
this.shards = IntStream.range(0, numberOfShards)
165+
.mapToObj(__ -> new Shard(this.nodeCount))
166+
.toArray(Shard[]::new);
167+
}
168+
169+
/**
170+
* Add a node to the mapping.
171+
*
172+
* @return {@code mappedId >= 0} if the node was added,
173+
* or {@code -(mappedId) - 1} if the node was already mapped.
174+
*/
175+
public long addNode(byte[] nodeId) {
176+
var shard = findShard(nodeId, this.shards, this.shardShift);
177+
try (var ignoredLock = shard.acquireLock()) {
178+
return shard.addNode(nodeId);
179+
}
180+
}
181+
182+
public ShardedByteArrayLongMap build() {
183+
return ShardedByteArrayLongMap.build(
184+
this.nodeCount.get(),
185+
this.shards,
186+
this.shardShift
187+
);
188+
}
189+
190+
private static final class Shard extends MapShard {
191+
private final AtomicLong nextId;
192+
193+
private Shard(AtomicLong nextId) {
194+
super();
195+
this.nextId = nextId;
196+
}
197+
198+
long addNode(byte[] nodeId) {
199+
this.assertIsUnderLock();
200+
long mappedId = mapping.getIfAbsent(nodeId, IdMap.NOT_FOUND);
201+
if (mappedId != IdMap.NOT_FOUND) {
202+
return -mappedId - 1;
203+
}
204+
mappedId = nextId.getAndIncrement();
205+
mapping.put(nodeId, mappedId);
206+
return mappedId;
207+
}
208+
}
209+
}
210+
}
Lines changed: 142 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,142 @@
1+
/*
2+
* Copyright (c) "Neo4j"
3+
* Neo4j Sweden AB [http://neo4j.com]
4+
*
5+
* This file is part of Neo4j.
6+
*
7+
* Neo4j is free software: you can redistribute it and/or modify
8+
* it under the terms of the GNU General Public License as published by
9+
* the Free Software Foundation, either version 3 of the License, or
10+
* (at your option) any later version.
11+
*
12+
* This program is distributed in the hope that it will be useful,
13+
* but WITHOUT ANY WARRANTY; without even the implied warranty of
14+
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
15+
* GNU General Public License for more details.
16+
*
17+
* You should have received a copy of the GNU General Public License
18+
* along with this program. If not, see <http://www.gnu.org/licenses/>.
19+
*/
20+
package org.neo4j.gds.core.utils.paged;
21+
22+
import net.jqwik.api.Arbitraries;
23+
import net.jqwik.api.Arbitrary;
24+
import net.jqwik.api.ForAll;
25+
import net.jqwik.api.Property;
26+
import net.jqwik.api.Provide;
27+
import net.jqwik.api.constraints.Size;
28+
import org.junit.jupiter.api.Test;
29+
import org.junit.jupiter.params.ParameterizedTest;
30+
import org.junit.jupiter.params.provider.ValueSource;
31+
import org.neo4j.gds.core.concurrency.Concurrency;
32+
import org.neo4j.gds.core.concurrency.DefaultPool;
33+
import org.neo4j.gds.core.concurrency.ParallelUtil;
34+
import org.neo4j.gds.core.utils.partition.PartitionUtils;
35+
36+
import java.nio.charset.StandardCharsets;
37+
import java.util.Optional;
38+
39+
import static org.assertj.core.api.Assertions.assertThat;
40+
41+
class ShardedByteArrayLongMapTest {
42+
43+
@Provide
44+
Arbitrary<byte[][]> nodes() {
45+
var idGen = Arbitraries.bytes().array(byte[].class).ofSize(10);
46+
return Arbitraries.create(idGen::sample).array(byte[][].class);
47+
}
48+
49+
@Test
50+
void addSingleNode() {
51+
byte[] node = "foobar".getBytes(StandardCharsets.UTF_8);
52+
var builder = ShardedByteArrayLongMap.builder(new Concurrency(1));
53+
long mapped = builder.addNode(node);
54+
assertThat(mapped).isGreaterThanOrEqualTo(0);
55+
var map = builder.build();
56+
assertThat(map.toMappedNodeId(node)).isEqualTo(mapped);
57+
assertThat(map.toOriginalNodeId(mapped)).isEqualTo(node);
58+
}
59+
60+
@Property
61+
void addNodes(@ForAll("nodes") @Size(100) byte[][] nodes) {
62+
var builder = ShardedByteArrayLongMap.builder(new Concurrency(1));
63+
for (byte[] node : nodes) {
64+
long mapped = builder.addNode(node);
65+
assertThat(mapped).isGreaterThanOrEqualTo(0);
66+
}
67+
var map = builder.build();
68+
69+
assertThat(map.size()).isEqualTo(nodes.length);
70+
for (byte[] node : nodes) {
71+
assertThat(map.toOriginalNodeId(map.toMappedNodeId(node))).isEqualTo(node);
72+
}
73+
}
74+
75+
@ParameterizedTest
76+
@ValueSource(ints = {0, 1024, 4096, 5000, 9999})
77+
void size(int expectedSize) {
78+
var builder = ShardedByteArrayLongMap.builder(new Concurrency(1));
79+
for (int i = 0; i < expectedSize; i++) {
80+
builder.addNode(String.valueOf(i).getBytes(StandardCharsets.UTF_8));
81+
}
82+
assertThat(builder.build().size()).isEqualTo(expectedSize);
83+
}
84+
85+
@Test
86+
void toMappedNodeId() {
87+
byte[] node = "foobar".getBytes(StandardCharsets.UTF_8);
88+
var builder = ShardedByteArrayLongMap.builder(new Concurrency(1));
89+
long mapped = builder.addNode(node);
90+
var map = builder.build();
91+
assertThat(map.toMappedNodeId(node)).isEqualTo(mapped);
92+
}
93+
94+
@Test
95+
void toOriginalNodeId() {
96+
byte[] node = "foobar".getBytes(StandardCharsets.UTF_8);
97+
var builder = ShardedByteArrayLongMap.builder(new Concurrency(1));
98+
long mapped = builder.addNode(node);
99+
var map = builder.build();
100+
assertThat(map.toOriginalNodeId(mapped)).isEqualTo(node);
101+
}
102+
103+
@Test
104+
void contains() {
105+
byte[] node = "foobar".getBytes(StandardCharsets.UTF_8);
106+
var builder = ShardedByteArrayLongMap.builder(new Concurrency(1));
107+
builder.addNode(node);
108+
var map = builder.build();
109+
assertThat(map.contains(node)).isTrue();
110+
assertThat(map.contains("barfoo".getBytes(StandardCharsets.UTF_8))).isFalse();
111+
}
112+
113+
@Property(tries = 1)
114+
void testAddingMultipleNodesInParallel(@ForAll("nodes") @Size(10000) byte[][] originalIds) {
115+
var concurrency = new Concurrency(4);
116+
var builder = ShardedByteArrayLongMap.builder(concurrency);
117+
118+
var tasks = PartitionUtils.rangePartition(
119+
concurrency,
120+
originalIds.length,
121+
partition -> (Runnable) () -> {
122+
byte[][] batch = new byte[(int) partition.nodeCount()][];
123+
System.arraycopy(originalIds, (int) partition.startNode(), batch, 0, batch.length);
124+
for (byte[] node : batch) {
125+
builder.addNode(node);
126+
}
127+
},
128+
Optional.of(100)
129+
);
130+
131+
ParallelUtil.run(tasks, DefaultPool.INSTANCE);
132+
133+
var map = builder.build();
134+
135+
assertThat(map.size()).isEqualTo(originalIds.length);
136+
long[] mappedIds = new long[originalIds.length];
137+
for (int i = 0; i < map.size(); i++) {
138+
mappedIds[i] = map.toMappedNodeId(originalIds[i]);
139+
}
140+
assertThat(mappedIds).doesNotHaveDuplicates();
141+
}
142+
}

0 commit comments

Comments
 (0)