3333import java .util .Optional ;
3434
3535import static org .junit .jupiter .api .Assertions .assertEquals ;
36- import static org .junit .jupiter .api .Assertions .assertFalse ;
3736import static org .junit .jupiter .api .Assertions .assertTrue ;
3837import static org .neo4j .gds .TestSupport .fromGdl ;
3938
@@ -61,15 +60,15 @@ class SelectionStrategyTest {
6160 void selectAll () {
6261 SelectionStrategy selectionStrategy = SelectionStrategy .ALL ;
6362 selectionStrategy .init (graph , Pools .DEFAULT , 1 );
64- assertEquals (graph .nodeCount (), samplingSize (graph . nodeCount (), selectionStrategy ));
63+ assertEquals (graph .nodeCount (), samplingSize (selectionStrategy ));
6564 }
6665
6766 @ ParameterizedTest
6867 @ ValueSource (longs = {0 , 1 , 2 , 10 , 11 })
6968 void selectSamplingSize (long samplingSize ) {
7069 SelectionStrategy selectionStrategy = new SelectionStrategy .RandomDegree (samplingSize );
7170 selectionStrategy .init (graph , Pools .DEFAULT , 1 );
72- assertEquals (samplingSize , samplingSize (graph . nodeCount (), selectionStrategy ));
71+ assertEquals (samplingSize , samplingSize (selectionStrategy ));
7372 }
7473
7574 @ ParameterizedTest
@@ -83,17 +82,18 @@ void selectSamplingSizeMultiThreaded(long samplingSize) {
8382 .generate ();
8483 SelectionStrategy selectionStrategy = new SelectionStrategy .RandomDegree (samplingSize , Optional .of (42L ));
8584 selectionStrategy .init (graph , Pools .DEFAULT , 4 );
86- assertEquals (samplingSize , samplingSize (graph . nodeCount (), selectionStrategy ));
85+ assertEquals (samplingSize , samplingSize (selectionStrategy ));
8786 }
8887
8988 @ Test
9089 void selectSamplingSizeWithSeed () {
9190 SelectionStrategy selectionStrategy = new SelectionStrategy .RandomDegree (3 , Optional .of (42L ));
9291 selectionStrategy .init (graph , Pools .DEFAULT , 1 );
93- assertEquals (3 , samplingSize (graph .nodeCount (), selectionStrategy ));
94- assertTrue (selectionStrategy .select (graph .toMappedNodeId ("a" )));
95- assertTrue (selectionStrategy .select (graph .toMappedNodeId ("b" )));
96- assertTrue (selectionStrategy .select (graph .toMappedNodeId ("f" )));
92+ assertEquals (3 , samplingSize (selectionStrategy ));
93+ selectionStrategy .init (graph , Pools .DEFAULT , 1 );
94+ assertEquals (graph .toMappedNodeId ("a" ), selectionStrategy .next ());
95+ assertEquals (graph .toMappedNodeId ("b" ), selectionStrategy .next ());
96+ assertEquals (graph .toMappedNodeId ("f" ), selectionStrategy .next ());
9797 }
9898
9999 @ Test
@@ -102,8 +102,9 @@ void neverIncludeZeroDegNodesIfBetterChoicesExist() {
102102
103103 SelectionStrategy selectionStrategy = new SelectionStrategy .RandomDegree (1 );
104104 selectionStrategy .init (graph , Pools .DEFAULT , 1 );
105- assertEquals (1 , samplingSize (graph .nodeCount (), selectionStrategy ));
106- assertTrue (selectionStrategy .select (graph .toMappedNodeId ("a" )));
105+ assertEquals (1 , samplingSize (selectionStrategy ));
106+ selectionStrategy .init (graph , Pools .DEFAULT , 1 );
107+ assertEquals (graph .toMappedNodeId ("a" ), selectionStrategy .next ());
107108 }
108109
109110 @ Test
@@ -112,27 +113,26 @@ void not100PercentLikelyUnlessMaxDegNode() {
112113
113114 SelectionStrategy selectionStrategy = new SelectionStrategy .RandomDegree (1 , Optional .of (42L ));
114115 selectionStrategy .init (graph , Pools .DEFAULT , 1 );
115- assertEquals (1 , samplingSize (graph . nodeCount (), selectionStrategy ));
116- assertFalse ( selectionStrategy .select (graph . toMappedNodeId ( "a" )) );
117- assertTrue ( selectionStrategy . select ( graph .toMappedNodeId ("b" )));
116+ assertEquals (1 , samplingSize (selectionStrategy ));
117+ selectionStrategy .init (graph , Pools . DEFAULT , 1 );
118+ assertEquals ( graph .toMappedNodeId ("b" ), selectionStrategy . next ( ));
118119 }
119120
120121 @ Test
121122 void selectHighDegreeNode () {
122123 SelectionStrategy selectionStrategy = new SelectionStrategy .RandomDegree (1 );
123124 selectionStrategy .init (graph , Pools .DEFAULT , 1 );
124- assertEquals (1 , samplingSize (graph .nodeCount (), selectionStrategy ));
125- var isA = selectionStrategy .select (graph .toMappedNodeId ("a" ));
126- var isB = selectionStrategy .select (graph .toMappedNodeId ("b" ));
127- assertTrue (isA || isB );
125+ assertEquals (1 , samplingSize (selectionStrategy ));
126+ selectionStrategy .init (graph , Pools .DEFAULT , 1 );
127+ var isA = selectionStrategy .next ();
128+ var isB = selectionStrategy .next ();
129+ assertTrue (isA >= 0 || isB >= 0 );
128130 }
129131
130- private static long samplingSize (long nodeCount , SelectionStrategy selectionStrategy ) {
132+ private static long samplingSize (SelectionStrategy selectionStrategy ) {
131133 long samplingSize = 0 ;
132- for (int nodeId = 0 ; nodeId < nodeCount ; nodeId ++) {
133- if (selectionStrategy .select (nodeId )) {
134- samplingSize ++;
135- }
134+ while (selectionStrategy .next () >= 0 ) {
135+ samplingSize ++;
136136 }
137137 return samplingSize ;
138138 }
0 commit comments