Skip to content

Commit c8ceddf

Browse files
Make KmeansStreamProcTest.java order indifferent
1 parent 56a2b82 commit c8ceddf

File tree

1 file changed

+47
-28
lines changed

1 file changed

+47
-28
lines changed

proc/community/src/test/java/org/neo4j/gds/kmeans/KmeansStreamProcTest.java

Lines changed: 47 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
*/
2020
package org.neo4j.gds.kmeans;
2121

22+
import org.assertj.core.data.Offset;
2223
import org.junit.jupiter.api.AfterEach;
2324
import org.junit.jupiter.api.BeforeEach;
2425
import org.junit.jupiter.params.ParameterizedTest;
@@ -32,6 +33,10 @@
3233
import java.util.List;
3334
import java.util.Map;
3435

36+
import static org.assertj.core.api.Assertions.assertThat;
37+
import static org.assertj.core.api.InstanceOfAssertFactories.DOUBLE;
38+
import static org.assertj.core.api.InstanceOfAssertFactories.LONG;
39+
3540
class KmeansStreamProcTest extends BaseProcTest {
3641
@BeforeEach
3742
void setup() throws Exception {
@@ -48,7 +53,7 @@ void cleanCatalog() {
4853

4954
@ParameterizedTest
5055
@ValueSource(strings = {"gds.kmeans","gds.beta.kmeans"})
51-
void shouldWork(String procedureName) {
56+
void shouldStream(String procedureName) {
5257
String nodeCreateQuery =
5358
"CREATE" +
5459
" (a:Person {kmeans: [1.0, 1.0]} )" +
@@ -75,33 +80,47 @@ void shouldWork(String procedureName) {
7580
.addParameter("concurrency", 1)
7681
.addParameter("computeSilhouette", true)
7782
.yields("nodeId", "communityId", "distanceFromCentroid", "silhouette");
78-
assertCypherResult(algoQuery, List.of(
79-
Map.of(
80-
"nodeId", 0L,
81-
"communityId", 0L,
82-
"distanceFromCentroid", 0.5,
83-
"silhouette", 0.9929292857150108
84-
),
85-
Map.of(
86-
"nodeId", 1L,
87-
"communityId", 1L,
88-
"distanceFromCentroid", Math.sqrt(2),
89-
"silhouette", 0.9799515133128792
90-
),
91-
Map.of(
92-
"nodeId", 2L,
93-
"communityId", 0L,
94-
"distanceFromCentroid", 0.5,
95-
"silhouette", 0.9928938477702276
96-
),
97-
Map.of(
98-
"nodeId", 3L,
99-
"communityId", 1L,
100-
"distanceFromCentroid", Math.sqrt(2),
101-
"silhouette", 0.9799505034216922
102-
)
103-
104-
));
83+
84+
var expectedStreamResult = Map.of(
85+
0L, new KmeansTestStreamResult(0L, 0.5, 0.9929292857150108),
86+
1L, new KmeansTestStreamResult(1L, Math.sqrt(2), 0.9799515133128792),
87+
2L, new KmeansTestStreamResult(0L, 0.5, 0.9928938477702276),
88+
3L, new KmeansTestStreamResult(1L, Math.sqrt(2), 0.9799505034216922)
89+
90+
);
91+
92+
var rowCount = runQueryWithRowConsumer(algoQuery, (resultRow) -> {
93+
94+
var nodeId = resultRow.getNumber("nodeId");
95+
var expectedCommunity = expectedStreamResult.get(nodeId).communityId;
96+
var expectedDistance = expectedStreamResult.get(nodeId).distanceFromCentroid;
97+
var expectedsilhouette = expectedStreamResult.get(nodeId).silhouette;
98+
99+
assertThat(resultRow.getNumber("communityId")).asInstanceOf(LONG).isEqualTo(expectedCommunity);
100+
assertThat(resultRow.getNumber("distanceFromCentroid")).asInstanceOf(DOUBLE).isCloseTo(
101+
expectedDistance,
102+
Offset.offset(1e-6)
103+
);
104+
assertThat(resultRow.getNumber("silhouette")).asInstanceOf(DOUBLE).isCloseTo(
105+
expectedsilhouette,
106+
Offset.offset(1e-6)
107+
);
108+
109+
});
110+
assertThat(rowCount).isEqualTo(4l);
111+
}
112+
113+
class KmeansTestStreamResult {
114+
115+
public long communityId;
116+
public double distanceFromCentroid;
117+
public double silhouette;
118+
119+
public KmeansTestStreamResult(long communityId, double distanceFromCentroid, double silhouette) {
120+
this.communityId = communityId;
121+
this.distanceFromCentroid = distanceFromCentroid;
122+
this.silhouette = silhouette;
123+
}
105124
}
106125

107126
}

0 commit comments

Comments
 (0)