Skip to content

Commit 8aaf93c

Browse files
breakanalysisFlorentinDadamnsch
committed
Fix sampling bugs related to idmaps
Co-Authored-By: Florentin Dörre <florentin.dorre@neotechnology.com> Co-Authored-By: Adam Schill Collberg <adam.schill.collberg@protonmail.com>
1 parent 33da470 commit 8aaf93c

File tree

11 files changed

+163
-26
lines changed

11 files changed

+163
-26
lines changed

ml/ml-algo/src/main/java/org/neo4j/gds/ml/negativeSampling/NegativeSampler.java

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
import org.neo4j.gds.core.loading.construction.RelationshipsBuilder;
2828

2929
import java.util.Collection;
30+
import java.util.List;
3031
import java.util.Optional;
3132

3233
public interface NegativeSampler {
@@ -36,6 +37,7 @@ public interface NegativeSampler {
3637
static NegativeSampler of(
3738
GraphStore graphStore,
3839
Graph graph,
40+
Collection<NodeLabel> sourceAndTargetNodeLabels,
3941
Optional<String> negativeRelationshipType,
4042
double negativeSamplingRatio,
4143
long testPositiveCount,
@@ -47,7 +49,11 @@ static NegativeSampler of(
4749
Optional<Long> randomSeed
4850
) {
4951
if (negativeRelationshipType.isPresent()) {
50-
Graph negativeExampleGraph = graphStore.getGraph(RelationshipType.of(negativeRelationshipType.orElseThrow()));
52+
Graph negativeExampleGraph = graphStore.getGraph(
53+
sourceAndTargetNodeLabels,
54+
List.of(RelationshipType.of(negativeRelationshipType.orElseThrow())),
55+
Optional.empty()
56+
);
5157
double testTrainFraction = testPositiveCount / (double) (testPositiveCount + trainPositiveCount);
5258

5359
return new UserInputNegativeSampler(

ml/ml-algo/src/main/java/org/neo4j/gds/ml/negativeSampling/UserInputNegativeSampler.java

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -64,13 +64,16 @@ public void produceNegativeSamples(
6464

6565
negativeExampleGraph.forEachNode(nodeId -> {
6666
negativeExampleGraph.forEachRelationship(nodeId, (s, t) -> {
67+
// as we are adding the relationships to the GraphStore we need to operate over the rootNodeIds
68+
long rootS = negativeExampleGraph.toRootNodeId(s);
69+
long rootT = negativeExampleGraph.toRootNodeId(t);
6770
if (s < t) {
6871
if (sample(testRelationshipsToAdd.doubleValue()/(testRelationshipsToAdd.doubleValue() + trainRelationshipsToAdd.doubleValue()))) {
6972
testRelationshipsToAdd.decrement();
70-
testSetBuilder.add(s, t, NEGATIVE);
73+
testSetBuilder.addFromInternal(rootS, rootT, NEGATIVE);
7174
} else {
7275
trainRelationshipsToAdd.decrement();
73-
trainSetBuilder.add(s, t, NEGATIVE);
76+
trainSetBuilder.addFromInternal(rootS, rootT, NEGATIVE);
7477
}
7578
}
7679
return true;

ml/ml-algo/src/main/java/org/neo4j/gds/ml/splitting/DirectedEdgeSplitter.java

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,13 +36,22 @@ public class DirectedEdgeSplitter extends EdgeSplitter {
3636

3737
public DirectedEdgeSplitter(
3838
Optional<Long> maybeSeed,
39+
IdMap rootNodes,
3940
IdMap sourceLabels,
4041
IdMap targetLabels,
4142
RelationshipType selectedRelationshipType,
4243
RelationshipType remainingRelationshipType,
4344
int concurrency
4445
) {
45-
super(maybeSeed, sourceLabels, targetLabels, selectedRelationshipType, remainingRelationshipType, concurrency);
46+
super(
47+
maybeSeed,
48+
rootNodes,
49+
sourceLabels,
50+
targetLabels,
51+
selectedRelationshipType,
52+
remainingRelationshipType,
53+
concurrency
54+
);
4655
}
4756

4857
@Override

ml/ml-algo/src/main/java/org/neo4j/gds/ml/splitting/EdgeSplitter.java

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -47,17 +47,20 @@ public abstract class EdgeSplitter {
4747

4848
protected final IdMap sourceNodes;
4949
protected final IdMap targetNodes;
50+
protected final IdMap rootNodes;
5051

5152
protected int concurrency;
5253

5354
EdgeSplitter(
5455
Optional<Long> maybeSeed,
56+
IdMap rootNodes,
5557
IdMap sourceNodes,
5658
IdMap targetNodes,
5759
RelationshipType selectedRelationshipType,
5860
RelationshipType remainingRelationshipType,
5961
int concurrency
6062
) {
63+
this.rootNodes = rootNodes;
6164
this.selectedRelationshipType = selectedRelationshipType;
6265
this.remainingRelationshipType = remainingRelationshipType;
6366
this.rng = new Random();
@@ -78,7 +81,7 @@ public SplitResult splitPositiveExamples(
7881
LongLongPredicate isValidNodePair = (s, t) -> isValidSourceNode.apply(s) && isValidTargetNode.apply(t);
7982

8083
RelationshipsBuilder selectedRelsBuilder = newRelationshipsBuilder(
81-
graph,
84+
rootNodes,
8285
selectedRelationshipType,
8386
Direction.DIRECTED,
8487
Optional.of(EdgeSplitter.RELATIONSHIP_PROPERTY)
@@ -89,7 +92,7 @@ public SplitResult splitPositiveExamples(
8992
RelationshipsBuilder remainingRelsBuilder;
9093
RelationshipWithPropertyConsumer remainingRelsConsumer;
9194

92-
remainingRelsBuilder = newRelationshipsBuilder(graph, remainingRelationshipType, remainingRelDirection, remainingRelPropertyKey);
95+
remainingRelsBuilder = newRelationshipsBuilder(rootNodes, remainingRelationshipType, remainingRelDirection, remainingRelPropertyKey);
9396
remainingRelsConsumer = (s, t, w) -> {
9497
remainingRelsBuilder.addFromInternal(graph.toRootNodeId(s), graph.toRootNodeId(t), w);
9598
return true;
@@ -153,15 +156,15 @@ protected long samplesPerNode(long maxSamples, double remainingSamples, long rem
153156
}
154157

155158
private static RelationshipsBuilder newRelationshipsBuilder(
156-
Graph graph,
159+
IdMap rootNodes,
157160
RelationshipType relationshipType,
158161
Direction direction,
159162
Optional<String> propertyKey
160163
) {
161164
return GraphFactory.initRelationshipsBuilder()
162165
.relationshipType(relationshipType)
163166
.aggregation(Aggregation.SINGLE)
164-
.nodes(graph)
167+
.nodes(rootNodes)
165168
.orientation(direction.toOrientation())
166169
.addAllPropertyConfigs(propertyKey
167170
.map(key -> List.of(GraphFactory.PropertyConfig.of(key, Aggregation.SINGLE, DefaultValue.forDouble())))

ml/ml-algo/src/main/java/org/neo4j/gds/ml/splitting/SplitRelationships.java

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -40,14 +40,19 @@ public final class SplitRelationships extends Algorithm<EdgeSplitter.SplitResult
4040

4141
private final SplitRelationshipsBaseConfig config;
4242

43+
private final IdMap rootNodes;
44+
4345
private final IdMap sourceNodes;
4446

4547
private final IdMap targetNodes;
4648

47-
private SplitRelationships(Graph graph, Graph masterGraph, IdMap sourceNodes, IdMap targetNodes, SplitRelationshipsBaseConfig config) {
49+
private SplitRelationships(Graph graph, Graph masterGraph,
50+
IdMap rootNodes,
51+
IdMap sourceNodes, IdMap targetNodes, SplitRelationshipsBaseConfig config) {
4852
super(ProgressTracker.NULL_TRACKER);
4953
this.graph = graph;
5054
this.masterGraph = masterGraph;
55+
this.rootNodes = rootNodes;
5156
this.config = config;
5257
this.sourceNodes = sourceNodes;
5358
this.targetNodes = targetNodes;
@@ -66,7 +71,7 @@ public static SplitRelationships of(GraphStore graphStore, SplitRelationshipsBas
6671
IdMap sourceNodes = graphStore.getGraph(sourceLabels);
6772
IdMap targetNodes = graphStore.getGraph(targetLabels);
6873

69-
return new SplitRelationships(graph, masterGraph, sourceNodes, targetNodes, config);
74+
return new SplitRelationships(graph, masterGraph, graphStore.nodes(), sourceNodes, targetNodes, config);
7075
}
7176

7277
public static MemoryEstimation estimate(SplitRelationshipsBaseConfig configuration) {
@@ -98,6 +103,7 @@ public EdgeSplitter.SplitResult compute() {
98103
var splitter = isUndirected
99104
? new UndirectedEdgeSplitter(
100105
config.randomSeed(),
106+
rootNodes,
101107
sourceNodes,
102108
targetNodes,
103109
config.holdoutRelationshipType(),
@@ -106,6 +112,7 @@ public EdgeSplitter.SplitResult compute() {
106112
)
107113
: new DirectedEdgeSplitter(
108114
config.randomSeed(),
115+
rootNodes,
109116
sourceNodes,
110117
targetNodes,
111118
config.holdoutRelationshipType(),

ml/ml-algo/src/main/java/org/neo4j/gds/ml/splitting/UndirectedEdgeSplitter.java

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,13 +45,16 @@ public class UndirectedEdgeSplitter extends EdgeSplitter {
4545

4646
public UndirectedEdgeSplitter(
4747
Optional<Long> maybeSeed,
48+
IdMap rootNodes,
4849
IdMap sourceNodes,
4950
IdMap targetNodes,
5051
RelationshipType selectedRelationshipType,
5152
RelationshipType remainingRelationshipType,
5253
int concurrency
5354
) {
54-
super(maybeSeed, sourceNodes, targetNodes, selectedRelationshipType, remainingRelationshipType, concurrency);
55+
super(maybeSeed,
56+
rootNodes,
57+
sourceNodes, targetNodes, selectedRelationshipType, remainingRelationshipType, concurrency);
5558
}
5659

5760
@Override

ml/ml-algo/src/test/java/org/neo4j/gds/ml/splitting/DirectedEdgeSplitterTest.java

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -131,6 +131,7 @@ void splitSkewedGraph() {
131131
Optional.of(-1L),
132132
skewedGraphStore.nodes(),
133133
skewedGraphStore.nodes(),
134+
skewedGraphStore.nodes(),
134135
RelationshipType.of("SELECTED"),
135136
RelationshipType.of("REMAINING"),
136137
4
@@ -150,6 +151,7 @@ void splitMultiGraph() {
150151
Optional.of(-1L),
151152
multiGraphStore.nodes(),
152153
multiGraphStore.nodes(),
154+
multiGraphStore.nodes(),
153155
RelationshipType.of("SELECTED"),
154156
RelationshipType.of("REMAINING"),
155157
4
@@ -170,6 +172,7 @@ void split() {
170172
Optional.of(-1L),
171173
graphStore.nodes(),
172174
graphStore.nodes(),
175+
graphStore.nodes(),
173176
RelationshipType.of("SELECTED"),
174177
RelationshipType.of("REMAINING"),
175178
4
@@ -207,6 +210,7 @@ void negativeEdgesShouldNotOverlapMasterGraph() {
207210
.generate();
208211

209212
var splitter = new DirectedEdgeSplitter(Optional.of(42L),
213+
huuuuugeDenseGraph,
210214
huuuuugeDenseGraph,
211215
huuuuugeDenseGraph,
212216
RelationshipType.of("SELECTED"),
@@ -241,6 +245,7 @@ void negativeEdgeSampling() {
241245
Optional.of(42L),
242246
graphStore.nodes(),
243247
graphStore.nodes(),
248+
graphStore.nodes(),
244249
RelationshipType.of("SELECTED"),
245250
RelationshipType.of("REMAINING"),
246251
4
@@ -261,6 +266,7 @@ void splitWithFilteringWithDifferentSourceTargetLabels() {
261266
Collection<NodeLabel> targetNodeLabels = List.of(NodeLabel.of("C"), NodeLabel.of("D"));
262267
var splitter = new DirectedEdgeSplitter(
263268
Optional.of(1337L),
269+
multiLabelGraphStore.nodes(),
264270
multiLabelGraphStore.getGraph(sourceNodeLabels),
265271
multiLabelGraphStore.getGraph(targetNodeLabels),
266272
RelationshipType.of("SELECTED"),
@@ -280,7 +286,7 @@ void splitWithFilteringWithDifferentSourceTargetLabels() {
280286

281287
var selectedRelationships = result.selectedRels().build();
282288
assertThat(selectedRelationships.topology()).satisfies(topology -> {
283-
assertRelSamplingProperties(selectedRelationships, multiLabelGraph);
289+
assertRelSamplingProperties(selectedRelationships, multiLabelGraphStore);
284290
assertThat(topology.elementCount()).isEqualTo(1);
285291
assertFalse(topology.isMultiGraph());
286292
});
@@ -295,6 +301,7 @@ void samplesWithinBounds() {
295301
Optional.of(42L),
296302
graphStore.nodes(),
297303
graphStore.nodes(),
304+
graphStore.nodes(),
298305
RelationshipType.of("SELECTED"),
299306
RelationshipType.of("REMAINING"),
300307
4
@@ -310,6 +317,7 @@ void shouldPreserveRelationshipWeights() {
310317
Optional.of(42L),
311318
graphStore.nodes(),
312319
graphStore.nodes(),
320+
graphStore.nodes(),
313321
RelationshipType.of("SELECTED"),
314322
RelationshipType.of("REMAINING"),
315323
4

ml/ml-algo/src/test/java/org/neo4j/gds/ml/splitting/UndirectedEdgeSplitterTest.java

Lines changed: 19 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,7 @@ void split() {
9797
Optional.of(1337L),
9898
graphStore.nodes(),
9999
graphStore.nodes(),
100+
graphStore.nodes(),
100101
RelationshipType.of("SELECTED"),
101102
RelationshipType.of("REMAINING"),
102103
4
@@ -127,6 +128,7 @@ void splitMultiGraph() {
127128
Optional.of(-1L),
128129
multiGraphStore.nodes(),
129130
multiGraphStore.nodes(),
131+
multiGraphStore.nodes(),
130132
RelationshipType.of("SELECTED"),
131133
RelationshipType.of("REMAINING"),
132134
4
@@ -157,6 +159,7 @@ void negativeEdgesShouldNotOverlapMasterGraph() {
157159
Optional.of(42L),
158160
huuuuugeDenseGraph,
159161
huuuuugeDenseGraph,
162+
huuuuugeDenseGraph,
160163
RelationshipType.of("SELECTED"),
161164
RelationshipType.of("REMAINING"),
162165
4
@@ -198,29 +201,31 @@ void shouldProduceDeterministicResult() {
198201

199202
var splitResult1 = new UndirectedEdgeSplitter(
200203
Optional.of(12L),
201-
graphStore.nodes(),
202-
graphStore.nodes(),
204+
graph.idMap(),
205+
graph.idMap(),
206+
graph.idMap(),
203207
RelationshipType.of("SELECTED"),
204208
RelationshipType.of("REMAINING"),
205209
4
206210
).splitPositiveExamples(graph, 0.5, Optional.empty());
207211
var splitResult2 = new UndirectedEdgeSplitter(
208212
Optional.of(12L),
209-
graphStore.nodes(),
210-
graphStore.nodes(),
213+
graph.idMap(),
214+
graph.idMap(),
215+
graph.idMap(),
211216
RelationshipType.of("SELECTED"),
212217
RelationshipType.of("REMAINING"),
213218
4
214219
).splitPositiveExamples(graph, 0.5, Optional.empty());
215220
var remainingAreEqual = relationshipsAreEqual(
216-
graph,
221+
graph.idMap(),
217222
splitResult1.remainingRels().build(),
218223
splitResult2.remainingRels().build()
219224
);
220225
assertTrue(remainingAreEqual);
221226

222227
var holdoutAreEqual = relationshipsAreEqual(
223-
graph,
228+
graph.idMap(),
224229
splitResult1.selectedRels().build(),
225230
splitResult2.selectedRels().build()
226231
);
@@ -244,6 +249,7 @@ void shouldProduceNonDeterministicResult() {
244249
Optional.of(42L),
245250
graphStore.nodes(),
246251
graphStore.nodes(),
252+
graphStore.nodes(),
247253
RelationshipType.of("SELECTED"),
248254
RelationshipType.of("REMAINING"),
249255
4
@@ -252,6 +258,7 @@ void shouldProduceNonDeterministicResult() {
252258
Optional.of(117L),
253259
graphStore.nodes(),
254260
graphStore.nodes(),
261+
graphStore.nodes(),
255262
RelationshipType.of("SELECTED"),
256263
RelationshipType.of("REMAINING"),
257264
4
@@ -277,6 +284,7 @@ void negativeEdgeSampling() {
277284
Optional.of(42L),
278285
graphStore.nodes(),
279286
graphStore.nodes(),
287+
graphStore.nodes(),
280288
RelationshipType.of("SELECTED"),
281289
RelationshipType.of("REMAINING"),
282290
4
@@ -299,6 +307,7 @@ void splitWithFilteringWithSameSourceTargetLabels() {
299307
Optional.of(1337L),
300308
graphStore.getGraph(NodeLabel.of("A")),
301309
graphStore.getGraph(NodeLabel.of("A")),
310+
graphStore.getGraph(NodeLabel.of("A")),
302311
RelationshipType.of("SELECTED"),
303312
RelationshipType.of("REMAINING"),
304313
4
@@ -332,6 +341,7 @@ void splitWithFilteringWithDifferentSourceTargetLabels() {
332341
Collection<NodeLabel> targetNodeLabels = List.of(NodeLabel.of("C"), NodeLabel.of("D"));
333342
var splitter = new UndirectedEdgeSplitter(
334343
Optional.of(1337L),
344+
multiLabelGraphStore.nodes(),
335345
multiLabelGraphStore.getGraph(sourceNodeLabels),
336346
multiLabelGraphStore.getGraph(targetNodeLabels),
337347
RelationshipType.of("SELECTED"),
@@ -367,6 +377,7 @@ void samplesWithinBounds() {
367377
Optional.of(42L),
368378
graphStore.nodes(),
369379
graphStore.nodes(),
380+
graphStore.nodes(),
370381
RelationshipType.of("SELECTED"),
371382
RelationshipType.of("REMAINING"),
372383
4
@@ -382,6 +393,7 @@ void shouldPreserveRelationshipWeights() {
382393
Optional.of(42L),
383394
graphStore.nodes(),
384395
graphStore.nodes(),
396+
graphStore.nodes(),
385397
RelationshipType.of("SELECTED"),
386398
RelationshipType.of("REMAINING"),
387399
4
@@ -411,6 +423,7 @@ void zeroNegativeSamples() {
411423
Optional.of(1337L),
412424
graphStore.nodes(),
413425
graphStore.nodes(),
426+
graphStore.nodes(),
414427
RelationshipType.of("SELECTED"),
415428
RelationshipType.of("REMAINING"),
416429
4

0 commit comments

Comments
 (0)