2424import org .neo4j .gds .InspectableTestProgressTracker ;
2525import org .neo4j .gds .Orientation ;
2626import org .neo4j .gds .RelationshipType ;
27- import org .neo4j .gds .api .Graph ;
2827import org .neo4j .gds .api .GraphStore ;
2928import org .neo4j .gds .api .schema .ElementSchemaEntry ;
3029import org .neo4j .gds .assertj .Extractors ;
4140import org .neo4j .gds .extension .Inject ;
4241import org .neo4j .gds .ml .pipeline .linkPipeline .LinkPredictionSplitConfigImpl ;
4342
44- import java .util .ArrayList ;
4543import java .util .List ;
4644import java .util .Map ;
4745import java .util .Optional ;
4846import java .util .Set ;
4947import java .util .stream .Collectors ;
50- import java .util .stream .IntStream ;
5148import java .util .stream .Stream ;
5249
5350import static org .assertj .core .api .Assertions .assertThat ;
@@ -62,40 +59,15 @@ class LinkPredictionRelationshipSamplerTest {
6259 @ GdlGraph (orientation = Orientation .UNDIRECTED , idOffset = 1337 )
6360 private static final String GRAPH =
6461 "CREATE " +
65- "(x1:Ignore {scalar: 2, array: [-3.0, 3.0, -1.0, -1.0, 1.0]}), " +
66- "(x2:Ignore {scalar: 2, array: [-3.0, 3.0, -1.0, -1.0, 1.0]}), " +
67- "(x3:Ignore {scalar: 2, array: [-3.0, 3.0, -1.0, -1.0, 1.0]}), " +
68- "(x4:Ignore {scalar: 2, array: [-3.0, 3.0, -1.0, -1.0, 1.0]}), " +
69- "(x5:Ignore {scalar: 2, array: [-3.0, 3.0, -1.0, -1.0, 1.0]}), " +
70- "(x6:Ignore {scalar: 2, array: [-3.0, 3.0, -1.0, -1.0, 1.0]}), " +
71- "(x7:Ignore {scalar: 2, array: [-3.0, 3.0, -1.0, -1.0, 1.0]}), " +
72- "(x8:Ignore {scalar: 2, array: [-3.0, 3.0, -1.0, -1.0, 1.0]}), " +
73- "(x9:Ignore {scalar: 2, array: [-3.0, 3.0, -1.0, -1.0, 1.0]}), " +
74- "(y1:Ignore {scalar: 2, array: [-3.0, 3.0, -1.0, -1.0, 1.0]}), " +
75- "(y2:Ignore {scalar: 2, array: [-3.0, 3.0, -1.0, -1.0, 1.0]}), " +
76- "(y3:Ignore {scalar: 2, array: [-3.0, 3.0, -1.0, -1.0, 1.0]}), " +
77- "(y4:Ignore {scalar: 2, array: [-3.0, 3.0, -1.0, -1.0, 1.0]}), " +
78- "(y5:Ignore {scalar: 2, array: [-3.0, 3.0, -1.0, -1.0, 1.0]}), " +
79- "(y6:Ignore {scalar: 2, array: [-3.0, 3.0, -1.0, -1.0, 1.0]}), " +
80- "(y7:Ignore {scalar: 2, array: [-3.0, 3.0, -1.0, -1.0, 1.0]}), " +
81- "(y8:Ignore {scalar: 2, array: [-3.0, 3.0, -1.0, -1.0, 1.0]}), " +
82- "(y9:Ignore {scalar: 2, array: [-3.0, 3.0, -1.0, -1.0, 1.0]}), " +
83- "(z1:Ignore {scalar: 2, array: [-3.0, 3.0, -1.0, -1.0, 1.0]}), " +
84- "(z3:Ignore {scalar: 2, array: [-3.0, 3.0, -1.0, -1.0, 1.0]}), " +
85- "(z5:Ignore {scalar: 2, array: [-3.0, 3.0, -1.0, -1.0, 1.0]}), " +
86- "(z6:Ignore {scalar: 2, array: [-3.0, 3.0, -1.0, -1.0, 1.0]}), " +
87- "(z7:Ignore {scalar: 2, array: [-3.0, 3.0, -1.0, -1.0, 1.0]}), " +
88- "(z8:Ignore {scalar: 2, array: [-3.0, 3.0, -1.0, -1.0, 1.0]}), " +
89- "(z9:Ignore {scalar: 2, array: [-3.0, 3.0, -1.0, -1.0, 1.0]}), " +
9062 "(a:N {scalar: 0, array: [-1.0, -2.0, 1.0, 1.0, 3.0]}), " +
91- "(z4:Ignore {scalar: 2, array: [-3.0, 3.0, -1.0, -1.0, 1.0]}), " +
9263 "(b:N {scalar: 4, array: [2.0, 1.0, -2.0, 2.0, 1.0]}), " +
9364 "(c:N {scalar: 0, array: [-3.0, 4.0, 3.0, 3.0, 2.0]}), " +
9465 "(d:N {scalar: 3, array: [1.0, 3.0, 1.0, -1.0, -1.0]}), " +
9566 "(e:N {scalar: 1, array: [-2.0, 1.0, 2.0, 1.0, -1.0]}), " +
9667 "(f:N {scalar: 0, array: [-1.0, -3.0, 1.0, 2.0, 2.0]}), " +
9768 "(g:N {scalar: 1, array: [3.0, 1.0, -3.0, 3.0, 1.0]}), " +
98- "(z2:Ignore {scalar: 2, array: [-3.0, 3.0, -1.0, -1.0, 1.0]}), " +
69+ // leaving some id gap between nodes
70+ "(:Ignore {scalar: 2, array: [-3.0, 3.0, -1.0, -1.0, 1.0]}), " .repeat (20 ) +
9971 "(h:N {scalar: 3, array: [-1.0, 3.0, 2.0, 1.0, -3.0]}), " +
10072 "(i:N {scalar: 3, array: [4.0, 1.0, 1.0, 2.0, 1.0]}), " +
10173 "(j:N {scalar: 4, array: [1.0, -4.0, 2.0, -2.0, 2.0]}), " +
@@ -104,7 +76,7 @@ class LinkPredictionRelationshipSamplerTest {
10476 "(m:N {scalar: 0, array: [4.0, 4.0, 1.0, 1.0, 1.0]}), " +
10577 "(n:N {scalar: 3, array: [1.0, -2.0, 3.0, 2.0, 3.0]}), " +
10678 "(o:N {scalar: 2, array: [-3.0, 3.0, -1.0, -1.0, 1.0]}), " +
107- "" +
79+
10880 "(a)-[:REL {weight: 2.0}]->(b), " +
10981 "(a)-[:REL {weight: 1.0}]->(c), " +
11082 "(b)-[:REL {weight: 3.0}]->(c), " +
@@ -421,44 +393,23 @@ void splitWithSpecifiedNegativeRelationships() {
421393 //8 * 0.5 = 4 positive, 1 negative
422394 assertThat (trainGraphSize ).isEqualTo (5 );
423395 assertThat (featureInputGraphSize ).isEqualTo (8 );
424- Graph outGraph = graphStore .getGraph (trainConfig .nodeLabelIdentifiers (graphStore ), List .of (splitConfig .testRelationshipType (), splitConfig .trainRelationshipType ()), Optional .of ("label" ));
425- var positiveEdgesList = new ArrayList <String >();
426- var negativeEdgesList = new ArrayList <String >();
427- var idsToNames = IntStream
428- .range ('a' , 'o' + 1 )
429- .mapToObj (i -> (char ) i )
430- .collect (Collectors .toMap (c -> idFunction .of (String .valueOf (c )), String ::valueOf ));
396+ var outGraph = graphStore .getGraph (trainConfig .nodeLabelIdentifiers (graphStore ), List .of (splitConfig .testRelationshipType (), splitConfig .trainRelationshipType ()), Optional .of ("label" ));
397+
398+ var negativeRelSpace = graphStore .getGraph (RelationshipType .of ("NEGATIVE" ));
399+ var positiveRelSpace = graphStore .getGraph (RelationshipType .of ("REL" ));
400+
431401 outGraph .forEachNode (nodeId -> {
432- outGraph .forEachRelationship (nodeId , -2 , (s ,t , w ) -> {
433- var relationshipString = "(" + idsToNames .get (outGraph .toOriginalNodeId (s )) + "," + idsToNames .get (outGraph .toOriginalNodeId (t )) + ")" ;
402+ outGraph .forEachRelationship (nodeId , Double .NaN , (s ,t , w ) -> {
434403 if (w == 1.0 ) {
435- positiveEdgesList . add ( relationshipString );
404+ assertThat ( positiveRelSpace . exists ( outGraph . toRootNodeId ( s ), outGraph . toRootNodeId ( t ))). isTrue ( );
436405 }
437406 if (w == 0.0 ) {
438- negativeEdgesList . add ( relationshipString );
407+ assertThat ( negativeRelSpace . exists ( outGraph . toRootNodeId ( s ), outGraph . toRootNodeId ( t ))). isTrue ( );
439408 }
440409 return true ;
441410 });
442411 return true ;
443412 }
444413 );
445- assertThat (String .join ("\n " , positiveEdgesList ))
446- .isEqualTo (
447- "(a,b)\n " +
448- "(a,c)\n " +
449- "(c,d)\n " +
450- "(e,g)\n " +
451- "(f,g)\n " +
452- "(h,i)\n " +
453- "(j,k)\n " +
454- "(j,l)\n " +
455- "(k,l)\n " +
456- "(m,n)\n " +
457- "(m,o)\n " +
458- "(n,o)" );
459- assertThat (String .join ("\n " , negativeEdgesList ))
460- .isEqualTo ("(a,k)\n " +
461- "(b,k)\n " +
462- "(c,k)" );
463414 }
464415}
0 commit comments