Skip to content

Commit e1ad623

Browse files
committed
Cleanup test
Avoid string comparisons which are fixed for a specific seed
1 parent 8aaf93c commit e1ad623

File tree

1 file changed

+11
-60
lines changed

1 file changed

+11
-60
lines changed

pipeline/src/test/java/org/neo4j/gds/ml/pipeline/linkPipeline/train/LinkPredictionRelationshipSamplerTest.java

Lines changed: 11 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,6 @@
2424
import org.neo4j.gds.InspectableTestProgressTracker;
2525
import org.neo4j.gds.Orientation;
2626
import org.neo4j.gds.RelationshipType;
27-
import org.neo4j.gds.api.Graph;
2827
import org.neo4j.gds.api.GraphStore;
2928
import org.neo4j.gds.api.schema.ElementSchemaEntry;
3029
import org.neo4j.gds.assertj.Extractors;
@@ -41,13 +40,11 @@
4140
import org.neo4j.gds.extension.Inject;
4241
import org.neo4j.gds.ml.pipeline.linkPipeline.LinkPredictionSplitConfigImpl;
4342

44-
import java.util.ArrayList;
4543
import java.util.List;
4644
import java.util.Map;
4745
import java.util.Optional;
4846
import java.util.Set;
4947
import java.util.stream.Collectors;
50-
import java.util.stream.IntStream;
5148
import java.util.stream.Stream;
5249

5350
import static org.assertj.core.api.Assertions.assertThat;
@@ -62,40 +59,15 @@ class LinkPredictionRelationshipSamplerTest {
6259
@GdlGraph(orientation = Orientation.UNDIRECTED, idOffset = 1337)
6360
private static final String GRAPH =
6461
"CREATE " +
65-
"(x1:Ignore {scalar: 2, array: [-3.0, 3.0, -1.0, -1.0, 1.0]}), " +
66-
"(x2:Ignore {scalar: 2, array: [-3.0, 3.0, -1.0, -1.0, 1.0]}), " +
67-
"(x3:Ignore {scalar: 2, array: [-3.0, 3.0, -1.0, -1.0, 1.0]}), " +
68-
"(x4:Ignore {scalar: 2, array: [-3.0, 3.0, -1.0, -1.0, 1.0]}), " +
69-
"(x5:Ignore {scalar: 2, array: [-3.0, 3.0, -1.0, -1.0, 1.0]}), " +
70-
"(x6:Ignore {scalar: 2, array: [-3.0, 3.0, -1.0, -1.0, 1.0]}), " +
71-
"(x7:Ignore {scalar: 2, array: [-3.0, 3.0, -1.0, -1.0, 1.0]}), " +
72-
"(x8:Ignore {scalar: 2, array: [-3.0, 3.0, -1.0, -1.0, 1.0]}), " +
73-
"(x9:Ignore {scalar: 2, array: [-3.0, 3.0, -1.0, -1.0, 1.0]}), " +
74-
"(y1:Ignore {scalar: 2, array: [-3.0, 3.0, -1.0, -1.0, 1.0]}), " +
75-
"(y2:Ignore {scalar: 2, array: [-3.0, 3.0, -1.0, -1.0, 1.0]}), " +
76-
"(y3:Ignore {scalar: 2, array: [-3.0, 3.0, -1.0, -1.0, 1.0]}), " +
77-
"(y4:Ignore {scalar: 2, array: [-3.0, 3.0, -1.0, -1.0, 1.0]}), " +
78-
"(y5:Ignore {scalar: 2, array: [-3.0, 3.0, -1.0, -1.0, 1.0]}), " +
79-
"(y6:Ignore {scalar: 2, array: [-3.0, 3.0, -1.0, -1.0, 1.0]}), " +
80-
"(y7:Ignore {scalar: 2, array: [-3.0, 3.0, -1.0, -1.0, 1.0]}), " +
81-
"(y8:Ignore {scalar: 2, array: [-3.0, 3.0, -1.0, -1.0, 1.0]}), " +
82-
"(y9:Ignore {scalar: 2, array: [-3.0, 3.0, -1.0, -1.0, 1.0]}), " +
83-
"(z1:Ignore {scalar: 2, array: [-3.0, 3.0, -1.0, -1.0, 1.0]}), " +
84-
"(z3:Ignore {scalar: 2, array: [-3.0, 3.0, -1.0, -1.0, 1.0]}), " +
85-
"(z5:Ignore {scalar: 2, array: [-3.0, 3.0, -1.0, -1.0, 1.0]}), " +
86-
"(z6:Ignore {scalar: 2, array: [-3.0, 3.0, -1.0, -1.0, 1.0]}), " +
87-
"(z7:Ignore {scalar: 2, array: [-3.0, 3.0, -1.0, -1.0, 1.0]}), " +
88-
"(z8:Ignore {scalar: 2, array: [-3.0, 3.0, -1.0, -1.0, 1.0]}), " +
89-
"(z9:Ignore {scalar: 2, array: [-3.0, 3.0, -1.0, -1.0, 1.0]}), " +
9062
"(a:N {scalar: 0, array: [-1.0, -2.0, 1.0, 1.0, 3.0]}), " +
91-
"(z4:Ignore {scalar: 2, array: [-3.0, 3.0, -1.0, -1.0, 1.0]}), " +
9263
"(b:N {scalar: 4, array: [2.0, 1.0, -2.0, 2.0, 1.0]}), " +
9364
"(c:N {scalar: 0, array: [-3.0, 4.0, 3.0, 3.0, 2.0]}), " +
9465
"(d:N {scalar: 3, array: [1.0, 3.0, 1.0, -1.0, -1.0]}), " +
9566
"(e:N {scalar: 1, array: [-2.0, 1.0, 2.0, 1.0, -1.0]}), " +
9667
"(f:N {scalar: 0, array: [-1.0, -3.0, 1.0, 2.0, 2.0]}), " +
9768
"(g:N {scalar: 1, array: [3.0, 1.0, -3.0, 3.0, 1.0]}), " +
98-
"(z2:Ignore {scalar: 2, array: [-3.0, 3.0, -1.0, -1.0, 1.0]}), " +
69+
// leaving some id gap between nodes
70+
"(:Ignore {scalar: 2, array: [-3.0, 3.0, -1.0, -1.0, 1.0]}), ".repeat(20) +
9971
"(h:N {scalar: 3, array: [-1.0, 3.0, 2.0, 1.0, -3.0]}), " +
10072
"(i:N {scalar: 3, array: [4.0, 1.0, 1.0, 2.0, 1.0]}), " +
10173
"(j:N {scalar: 4, array: [1.0, -4.0, 2.0, -2.0, 2.0]}), " +
@@ -104,7 +76,7 @@ class LinkPredictionRelationshipSamplerTest {
10476
"(m:N {scalar: 0, array: [4.0, 4.0, 1.0, 1.0, 1.0]}), " +
10577
"(n:N {scalar: 3, array: [1.0, -2.0, 3.0, 2.0, 3.0]}), " +
10678
"(o:N {scalar: 2, array: [-3.0, 3.0, -1.0, -1.0, 1.0]}), " +
107-
"" +
79+
10880
"(a)-[:REL {weight: 2.0}]->(b), " +
10981
"(a)-[:REL {weight: 1.0}]->(c), " +
11082
"(b)-[:REL {weight: 3.0}]->(c), " +
@@ -421,44 +393,23 @@ void splitWithSpecifiedNegativeRelationships() {
421393
//8 * 0.5 = 4 positive, 1 negative
422394
assertThat(trainGraphSize).isEqualTo(5);
423395
assertThat(featureInputGraphSize).isEqualTo(8);
424-
Graph outGraph = graphStore.getGraph(trainConfig.nodeLabelIdentifiers(graphStore), List.of(splitConfig.testRelationshipType(), splitConfig.trainRelationshipType()), Optional.of("label"));
425-
var positiveEdgesList = new ArrayList<String>();
426-
var negativeEdgesList = new ArrayList<String>();
427-
var idsToNames = IntStream
428-
.range('a', 'o' + 1)
429-
.mapToObj(i -> (char) i)
430-
.collect(Collectors.toMap(c -> idFunction.of(String.valueOf(c)), String::valueOf));
396+
var outGraph = graphStore.getGraph(trainConfig.nodeLabelIdentifiers(graphStore), List.of(splitConfig.testRelationshipType(), splitConfig.trainRelationshipType()), Optional.of("label"));
397+
398+
var negativeRelSpace = graphStore.getGraph(RelationshipType.of("NEGATIVE"));
399+
var positiveRelSpace = graphStore.getGraph(RelationshipType.of("REL"));
400+
431401
outGraph.forEachNode(nodeId -> {
432-
outGraph.forEachRelationship(nodeId, -2, (s,t, w) -> {
433-
var relationshipString = "(" + idsToNames.get(outGraph.toOriginalNodeId(s)) + "," + idsToNames.get(outGraph.toOriginalNodeId(t)) + ")";
402+
outGraph.forEachRelationship(nodeId, Double.NaN, (s,t, w) -> {
434403
if (w == 1.0) {
435-
positiveEdgesList.add(relationshipString);
404+
assertThat(positiveRelSpace.exists(outGraph.toRootNodeId(s), outGraph.toRootNodeId(t))).isTrue();
436405
}
437406
if (w == 0.0) {
438-
negativeEdgesList.add(relationshipString);
407+
assertThat(negativeRelSpace.exists(outGraph.toRootNodeId(s), outGraph.toRootNodeId(t))).isTrue();
439408
}
440409
return true;
441410
});
442411
return true;
443412
}
444413
);
445-
assertThat(String.join("\n", positiveEdgesList))
446-
.isEqualTo(
447-
"(a,b)\n" +
448-
"(a,c)\n" +
449-
"(c,d)\n" +
450-
"(e,g)\n" +
451-
"(f,g)\n" +
452-
"(h,i)\n" +
453-
"(j,k)\n" +
454-
"(j,l)\n" +
455-
"(k,l)\n" +
456-
"(m,n)\n" +
457-
"(m,o)\n" +
458-
"(n,o)");
459-
assertThat(String.join("\n", negativeEdgesList))
460-
.isEqualTo("(a,k)\n" +
461-
"(b,k)\n" +
462-
"(c,k)");
463414
}
464415
}

0 commit comments

Comments
 (0)