2222import org .junit .jupiter .api .Test ;
2323import org .junit .jupiter .params .ParameterizedTest ;
2424import org .junit .jupiter .params .provider .ValueSource ;
25- import org .neo4j .gds .TestProgressTracker ;
25+ import org .neo4j .gds .applications .algorithms .machinery .ProgressTrackerCreator ;
26+ import org .neo4j .gds .applications .algorithms .machinery .RequestScopedDependencies ;
27+ import org .neo4j .gds .applications .algorithms .similarity .SimilarityAlgorithms ;
2628import org .neo4j .gds .core .concurrency .Concurrency ;
2729import org .neo4j .gds .core .utils .progress .EmptyTaskRegistryFactory ;
2830import org .neo4j .gds .core .utils .progress .tasks .ProgressTracker ;
31+ import org .neo4j .gds .core .utils .warnings .EmptyUserLogRegistryFactory ;
2932import org .neo4j .gds .extension .GdlExtension ;
3033import org .neo4j .gds .extension .GdlGraph ;
3134import org .neo4j .gds .extension .Inject ;
3235import org .neo4j .gds .extension .TestGraph ;
3336import org .neo4j .gds .logging .GdsTestLog ;
3437import org .neo4j .gds .similarity .filtering .NodeFilterSpecFactory ;
38+ import org .neo4j .gds .termination .TerminationFlag ;
3539
3640import java .util .List ;
3741import java .util .stream .Collectors ;
@@ -70,29 +74,23 @@ class FilteredNodeSimilarityTest {
7074
7175 @ Test
7276 void should () {
77+ var similarityAlgorithms = new SimilarityAlgorithms (null , TerminationFlag .RUNNING_TRUE );
78+
7379 var sourceNodeFilter = Stream .of ("a" , "b" , "c" ).map (graph ::toOriginalNodeId ).collect (Collectors .toList ());
7480
7581 var config = FilteredNodeSimilarityStreamConfigImpl .builder ()
7682 .sourceNodeFilter (NodeFilterSpecFactory .create (sourceNodeFilter ))
7783 .build ();
7884
79- var nodeSimilarity = new FilteredNodeSimilarityFactory <>().build (
80- graph ,
81- config ,
82- ProgressTracker .NULL_TRACKER
83- );
84-
8585 // no results for nodes that are not specified in the node filter -- nice
86- var noOfResultsWithSourceNodeOutsideOfFilter = nodeSimilarity
87- .compute ()
86+ var noOfResultsWithSourceNodeOutsideOfFilter = similarityAlgorithms .filteredNodeSimilarity (graph , config , ProgressTracker .NULL_TRACKER )
8887 .streamResult ()
8988 .filter (res -> !sourceNodeFilter .contains (graph .toOriginalNodeId (res .node1 )))
9089 .count ();
9190 assertThat (noOfResultsWithSourceNodeOutsideOfFilter ).isEqualTo (0L );
9291
9392 // nodes outside of the node filter are not present as target nodes either -- not nice
94- var noOfResultsWithTargetNodeOutSideOfFilter = nodeSimilarity
95- .compute ()
93+ var noOfResultsWithTargetNodeOutSideOfFilter = similarityAlgorithms .filteredNodeSimilarity (graph , config , ProgressTracker .NULL_TRACKER )
9694 .streamResult ()
9795 .filter (res -> !sourceNodeFilter .contains (graph .toOriginalNodeId (res .node2 )))
9896 .count ();
@@ -101,30 +99,24 @@ void should() {
10199
102100 @ Test
103101 void shouldSurviveIoannisObjections () {
102+ var similarityAlgorithms = new SimilarityAlgorithms (null , TerminationFlag .RUNNING_TRUE );
103+
104104 var sourceNodeFilter = List .of (graph .toOriginalNodeId ("d" ));
105105
106106 var config = FilteredNodeSimilarityStreamConfigImpl .builder ()
107107 .sourceNodeFilter (NodeFilterSpecFactory .create (sourceNodeFilter ))
108108 .concurrency (1 )
109109 .build ();
110110
111- var nodeSimilarity = new FilteredNodeSimilarityFactory <>().build (
112- graph ,
113- config ,
114- ProgressTracker .NULL_TRACKER
115- );
116-
117111 // no results for nodes that are not specified in the node filter -- nice
118- var noOfResultsWithSourceNodeOutsideOfFilter = nodeSimilarity
119- .compute ()
112+ var noOfResultsWithSourceNodeOutsideOfFilter = similarityAlgorithms .filteredNodeSimilarity (graph , config , ProgressTracker .NULL_TRACKER )
120113 .streamResult ()
121114 .filter (res -> !sourceNodeFilter .contains (graph .toOriginalNodeId (res .node1 )))
122115 .count ();
123116 assertThat (noOfResultsWithSourceNodeOutsideOfFilter ).isEqualTo (0L );
124117
125118 // nodes outside of the node filter are not present as target nodes either -- not nice
126- var noOfResultsWithTargetNodeOutSideOfFilter = nodeSimilarity
127- .compute ()
119+ var noOfResultsWithTargetNodeOutSideOfFilter = similarityAlgorithms .filteredNodeSimilarity (graph , config , ProgressTracker .NULL_TRACKER )
128120 .streamResult ()
129121 .filter (res -> !sourceNodeFilter .contains (graph .toOriginalNodeId (res .node2 )))
130122 .count ();
@@ -134,6 +126,8 @@ void shouldSurviveIoannisObjections() {
134126 @ ParameterizedTest
135127 @ ValueSource (booleans = {true , false })
136128 void shouldSurviveIoannisFurtherObjections (boolean enableWcc ) {
129+ var similarityAlgorithms = new SimilarityAlgorithms (null , TerminationFlag .RUNNING_TRUE );
130+
137131 var sourceNodeFilter = List .of (graph .toOriginalNodeId ("d" ));
138132
139133 var config = FilteredNodeSimilarityStreamConfigImpl .builder ()
@@ -144,23 +138,15 @@ void shouldSurviveIoannisFurtherObjections(boolean enableWcc) {
144138 .topN (10 )
145139 .build ();
146140
147- var nodeSimilarity = new FilteredNodeSimilarityFactory <>().build (
148- graph ,
149- config ,
150- ProgressTracker .NULL_TRACKER
151- );
152-
153141 // no results for nodes that are not specified in the node filter -- nice
154- var noOfResultsWithSourceNodeOutsideOfFilter = nodeSimilarity
155- .compute ()
142+ var noOfResultsWithSourceNodeOutsideOfFilter = similarityAlgorithms .filteredNodeSimilarity (graph , config , ProgressTracker .NULL_TRACKER )
156143 .streamResult ()
157144 .filter (res -> !sourceNodeFilter .contains (graph .toOriginalNodeId (res .node1 )))
158145 .count ();
159146 assertThat (noOfResultsWithSourceNodeOutsideOfFilter ).isEqualTo (0L );
160147
161148 // nodes outside of the node filter are not present as target nodes either -- not nice
162- var noOfResultsWithTargetNodeOutSideOfFilter = nodeSimilarity
163- .compute ()
149+ var noOfResultsWithTargetNodeOutSideOfFilter = similarityAlgorithms .filteredNodeSimilarity (graph , config , ProgressTracker .NULL_TRACKER )
164150 .streamResult ()
165151 .filter (res -> !sourceNodeFilter .contains (graph .toOriginalNodeId (res .node2 )))
166152 .count ();
@@ -170,6 +156,15 @@ void shouldSurviveIoannisFurtherObjections(boolean enableWcc) {
170156 @ ParameterizedTest
171157 @ ValueSource (ints = {1 , 2 })
172158 void shouldLogProgressAccurately (int concurrencyValue ) {
159+ var log = new GdsTestLog ();
160+ var requestScopedDependencies = RequestScopedDependencies .builder ()
161+ .with (EmptyTaskRegistryFactory .INSTANCE )
162+ .with (TerminationFlag .RUNNING_TRUE )
163+ .with (EmptyUserLogRegistryFactory .INSTANCE )
164+ .build ();
165+ var progressTrackerCreator = new ProgressTrackerCreator (log , requestScopedDependencies );
166+ var similarityAlgorithms = new SimilarityAlgorithms (progressTrackerCreator , requestScopedDependencies .getTerminationFlag ());
167+
173168 var sourceNodeFilter = List .of (graph .toOriginalNodeId ("c" ), graph .toOriginalNodeId ("d" ));
174169 var concurrency = new Concurrency (concurrencyValue );
175170 var config = FilteredNodeSimilarityStreamConfigImpl .builder ()
@@ -178,44 +173,28 @@ void shouldLogProgressAccurately(int concurrencyValue) {
178173 .topK (1 )
179174 .topN (10 )
180175 .build ();
181- var progressTask = new FilteredNodeSimilarityFactory <>().progressTask (graph , config );
182- var log = new GdsTestLog ();
183- var progressTracker = new TestProgressTracker (
184- progressTask ,
185- log ,
186- concurrency ,
187- EmptyTaskRegistryFactory .INSTANCE
188- );
189-
190-
191- new FilteredNodeSimilarityFactory <>().build (
192- graph ,
193- config ,
194- progressTracker
195- ).compute ();
196-
176+ similarityAlgorithms .filteredNodeSimilarity (graph , config );
197177
198178 assertThat (log .getMessages (INFO ))
199179 .extracting (removingThreadId ())
200180 .containsExactly (
201- "FilteredNodeSimilarity :: Start" ,
202- "FilteredNodeSimilarity :: prepare :: Start" ,
203- "FilteredNodeSimilarity :: prepare 33%" ,
204- "FilteredNodeSimilarity :: prepare 55%" ,
205- "FilteredNodeSimilarity :: prepare 66%" ,
206- "FilteredNodeSimilarity :: prepare 100%" ,
207- "FilteredNodeSimilarity :: prepare :: Finished" ,
208- "FilteredNodeSimilarity :: compare node pairs :: Start" ,
209- "FilteredNodeSimilarity :: compare node pairs 12%" ,
210- "FilteredNodeSimilarity :: compare node pairs 25%" ,
211- "FilteredNodeSimilarity :: compare node pairs 37%" ,
212- "FilteredNodeSimilarity :: compare node pairs 50%" ,
213- "FilteredNodeSimilarity :: compare node pairs 62%" ,
214- "FilteredNodeSimilarity :: compare node pairs 75%" ,
215- "FilteredNodeSimilarity :: compare node pairs 100%" ,
216- "FilteredNodeSimilarity :: compare node pairs :: Finished" ,
217- "FilteredNodeSimilarity :: Finished"
181+ "Filtered Node Similarity :: Start" ,
182+ "Filtered Node Similarity :: prepare :: Start" ,
183+ "Filtered Node Similarity :: prepare 33%" ,
184+ "Filtered Node Similarity :: prepare 55%" ,
185+ "Filtered Node Similarity :: prepare 66%" ,
186+ "Filtered Node Similarity :: prepare 100%" ,
187+ "Filtered Node Similarity :: prepare :: Finished" ,
188+ "Filtered Node Similarity :: compare node pairs :: Start" ,
189+ "Filtered Node Similarity :: compare node pairs 12%" ,
190+ "Filtered Node Similarity :: compare node pairs 25%" ,
191+ "Filtered Node Similarity :: compare node pairs 37%" ,
192+ "Filtered Node Similarity :: compare node pairs 50%" ,
193+ "Filtered Node Similarity :: compare node pairs 62%" ,
194+ "Filtered Node Similarity :: compare node pairs 75%" ,
195+ "Filtered Node Similarity :: compare node pairs 100%" ,
196+ "Filtered Node Similarity :: compare node pairs :: Finished" ,
197+ "Filtered Node Similarity :: Finished"
218198 );
219199 }
220-
221200}
0 commit comments