1919 */
2020package org .neo4j .gds .kmeans ;
2121
22+ import org .assertj .core .data .Offset ;
2223import org .junit .jupiter .api .AfterEach ;
2324import org .junit .jupiter .api .BeforeEach ;
2425import org .junit .jupiter .params .ParameterizedTest ;
3233import java .util .List ;
3334import java .util .Map ;
3435
36+ import static org .assertj .core .api .Assertions .assertThat ;
37+ import static org .assertj .core .api .InstanceOfAssertFactories .DOUBLE ;
38+ import static org .assertj .core .api .InstanceOfAssertFactories .LONG ;
39+
3540class KmeansStreamProcTest extends BaseProcTest {
3641 @ BeforeEach
3742 void setup () throws Exception {
@@ -48,7 +53,7 @@ void cleanCatalog() {
4853
4954 @ ParameterizedTest
5055 @ ValueSource (strings = {"gds.kmeans" ,"gds.beta.kmeans" })
51- void shouldWork (String procedureName ) {
56+ void shouldStream (String procedureName ) {
5257 String nodeCreateQuery =
5358 "CREATE" +
5459 " (a:Person {kmeans: [1.0, 1.0]} )" +
@@ -75,33 +80,47 @@ void shouldWork(String procedureName) {
7580 .addParameter ("concurrency" , 1 )
7681 .addParameter ("computeSilhouette" , true )
7782 .yields ("nodeId" , "communityId" , "distanceFromCentroid" , "silhouette" );
78- assertCypherResult (algoQuery , List .of (
79- Map .of (
80- "nodeId" , 0L ,
81- "communityId" , 0L ,
82- "distanceFromCentroid" , 0.5 ,
83- "silhouette" , 0.9929292857150108
84- ),
85- Map .of (
86- "nodeId" , 1L ,
87- "communityId" , 1L ,
88- "distanceFromCentroid" , Math .sqrt (2 ),
89- "silhouette" , 0.9799515133128792
90- ),
91- Map .of (
92- "nodeId" , 2L ,
93- "communityId" , 0L ,
94- "distanceFromCentroid" , 0.5 ,
95- "silhouette" , 0.9928938477702276
96- ),
97- Map .of (
98- "nodeId" , 3L ,
99- "communityId" , 1L ,
100- "distanceFromCentroid" , Math .sqrt (2 ),
101- "silhouette" , 0.9799505034216922
102- )
103-
104- ));
83+
84+ var expectedStreamResult = Map .of (
85+ 0L , new KmeansTestStreamResult (0L , 0.5 , 0.9929292857150108 ),
86+ 1L , new KmeansTestStreamResult (1L , Math .sqrt (2 ), 0.9799515133128792 ),
87+ 2L , new KmeansTestStreamResult (0L , 0.5 , 0.9928938477702276 ),
88+ 3L , new KmeansTestStreamResult (1L , Math .sqrt (2 ), 0.9799505034216922 )
89+
90+ );
91+
92+ var rowCount = runQueryWithRowConsumer (algoQuery , (resultRow ) -> {
93+
94+ var nodeId = resultRow .getNumber ("nodeId" );
95+ var expectedCommunity = expectedStreamResult .get (nodeId ).communityId ;
96+ var expectedDistance = expectedStreamResult .get (nodeId ).distanceFromCentroid ;
97+ var expectedsilhouette = expectedStreamResult .get (nodeId ).silhouette ;
98+
99+ assertThat (resultRow .getNumber ("communityId" )).asInstanceOf (LONG ).isEqualTo (expectedCommunity );
100+ assertThat (resultRow .getNumber ("distanceFromCentroid" )).asInstanceOf (DOUBLE ).isCloseTo (
101+ expectedDistance ,
102+ Offset .offset (1e-6 )
103+ );
104+ assertThat (resultRow .getNumber ("silhouette" )).asInstanceOf (DOUBLE ).isCloseTo (
105+ expectedsilhouette ,
106+ Offset .offset (1e-6 )
107+ );
108+
109+ });
110+ assertThat (rowCount ).isEqualTo (4l );
111+ }
112+
113+ class KmeansTestStreamResult {
114+
115+ public long communityId ;
116+ public double distanceFromCentroid ;
117+ public double silhouette ;
118+
119+ public KmeansTestStreamResult (long communityId , double distanceFromCentroid , double silhouette ) {
120+ this .communityId = communityId ;
121+ this .distanceFromCentroid = distanceFromCentroid ;
122+ this .silhouette = silhouette ;
123+ }
105124 }
106125
107126}
0 commit comments