Skip to content

Commit 3dfb050

Browse files
s1ckknutwalker
authored andcommitted
Use custom hashing strategy for map shards
1 parent 4dfe318 commit 3dfb050

File tree

2 files changed

+48
-5
lines changed

2 files changed

+48
-5
lines changed

core/src/main/java/org/neo4j/gds/core/utils/paged/ShardedByteArrayLongMap.java

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,10 +19,11 @@
1919
*/
2020
package org.neo4j.gds.core.utils.paged;
2121

22+
import org.eclipse.collections.api.block.HashingStrategy;
2223
import org.eclipse.collections.api.map.primitive.MutableObjectLongMap;
2324
import org.eclipse.collections.api.map.primitive.ObjectLongMap;
2425
import org.eclipse.collections.impl.collection.mutable.AbstractMultiReaderMutableCollection;
25-
import org.eclipse.collections.impl.factory.primitive.ObjectLongMaps;
26+
import org.eclipse.collections.impl.map.mutable.primitive.ObjectLongHashMapWithHashingStrategy;
2627
import org.neo4j.gds.api.IdMap;
2728
import org.neo4j.gds.collections.ha.HugeObjectArray;
2829
import org.neo4j.gds.core.concurrency.Concurrency;
@@ -125,12 +126,24 @@ private static <S extends MapShard> ShardedByteArrayLongMap build(
125126

126127
abstract static class MapShard {
127128

129+
private static class ArrayHashingStrategy implements HashingStrategy<byte[]> {
130+
@Override
131+
public int computeHashCode(byte[] object) {
132+
return Arrays.hashCode(object);
133+
}
134+
135+
@Override
136+
public boolean equals(byte[] object1, byte[] object2) {
137+
return Arrays.equals(object1, object2);
138+
}
139+
}
140+
128141
private final ReentrantLock lock;
129142
private final AbstractMultiReaderMutableCollection.LockWrapper lockWrapper;
130143
final MutableObjectLongMap<byte[]> mapping;
131144

132145
MapShard() {
133-
this.mapping = ObjectLongMaps.mutable.empty();
146+
this.mapping = new ObjectLongHashMapWithHashingStrategy<>(new ArrayHashingStrategy());
134147
this.lock = new ReentrantLock();
135148
this.lockWrapper = new AbstractMultiReaderMutableCollection.LockWrapper(lock);
136149
}

core/src/test/java/org/neo4j/gds/core/utils/paged/ShardedByteArrayLongMapTest.java

Lines changed: 33 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
import org.neo4j.gds.core.utils.partition.PartitionUtils;
3535

3636
import java.nio.charset.StandardCharsets;
37+
import java.util.Arrays;
3738
import java.util.Optional;
3839

3940
import static org.assertj.core.api.Assertions.assertThat;
@@ -43,7 +44,9 @@ class ShardedByteArrayLongMapTest {
4344
@Provide
4445
Arbitrary<byte[][]> nodes() {
4546
var idGen = Arbitraries.bytes().array(byte[].class).ofSize(10);
46-
return Arbitraries.create(idGen::sample).array(byte[][].class);
47+
return Arbitraries
48+
.create(idGen::sample)
49+
.array(byte[][].class);
4750
}
4851

4952
@Test
@@ -61,8 +64,7 @@ void addSingleNode() {
6164
void addNodes(@ForAll("nodes") @Size(100) byte[][] nodes) {
6265
var builder = ShardedByteArrayLongMap.builder(new Concurrency(1));
6366
for (byte[] node : nodes) {
64-
long mapped = builder.addNode(node);
65-
assertThat(mapped).isGreaterThanOrEqualTo(0);
67+
builder.addNode(node);
6668
}
6769
var map = builder.build();
6870

@@ -72,6 +74,34 @@ void addNodes(@ForAll("nodes") @Size(100) byte[][] nodes) {
7274
}
7375
}
7476

77+
@Property
78+
void addNodesDifferentObject(@ForAll("nodes") @Size(100) byte[][] nodes) {
79+
var builder = ShardedByteArrayLongMap.builder(new Concurrency(1));
80+
for (byte[] node : nodes) {
81+
builder.addNode(node);
82+
}
83+
var map = builder.build();
84+
85+
assertThat(map.size()).isEqualTo(nodes.length);
86+
for (byte[] node : nodes) {
87+
// Ensure that hashCode and equals work correctly for byte arrays
88+
// with same elements, but different objects.
89+
var nodeCopy = Arrays.copyOf(node, node.length);
90+
assertThat(map.toOriginalNodeId(map.toMappedNodeId(nodeCopy))).isEqualTo(node);
91+
}
92+
}
93+
94+
@Test
95+
void addExistingNode() {
96+
byte[] node1 = "foobar".getBytes(StandardCharsets.UTF_8);
97+
byte[] node2 = "foobar".getBytes(StandardCharsets.UTF_8);
98+
var builder = ShardedByteArrayLongMap.builder(new Concurrency(1));
99+
long mappedNode1 = builder.addNode(node1);
100+
assertThat(mappedNode1).isGreaterThanOrEqualTo(0);
101+
long mappedNode2 = builder.addNode(node2);
102+
assertThat(mappedNode2).isEqualTo(-(mappedNode1 + 1));
103+
}
104+
75105
@ParameterizedTest
76106
@ValueSource(ints = {0, 1024, 4096, 5000, 9999})
77107
void size(int expectedSize) {

0 commit comments

Comments
 (0)