Skip to content

Commit 6856721

Browse files
Merge pull request #8634 from IoannisPanagiotas/clear-ord-proc-test-26
Clear ord proc test 26
2 parents 532bddf + 4af8a55 commit 6856721

File tree

6 files changed

+129
-90
lines changed

6 files changed

+129
-90
lines changed

proc/centrality/src/test/java/org/neo4j/gds/betweenness/BetweennessCentralityStreamProcTest.java

Lines changed: 23 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
*/
2020
package org.neo4j.gds.betweenness;
2121

22+
import org.assertj.core.data.Offset;
2223
import org.junit.jupiter.api.BeforeEach;
2324
import org.junit.jupiter.api.Test;
2425
import org.neo4j.gds.BaseProcTest;
@@ -30,10 +31,11 @@
3031
import org.neo4j.gds.extension.Neo4jGraph;
3132
import org.neo4j.graphdb.QueryExecutionException;
3233

33-
import java.util.List;
3434
import java.util.Map;
3535

36+
import static org.assertj.core.api.Assertions.assertThat;
3637
import static org.assertj.core.api.Assertions.assertThatExceptionOfType;
38+
import static org.assertj.core.api.InstanceOfAssertFactories.DOUBLE;
3739

3840
class BetweennessCentralityStreamProcTest extends BaseProcTest {
3941

@@ -76,15 +78,26 @@ void testStream() {
7678
.streamMode()
7779
.yields();
7880

79-
assertCypherResult(
80-
query,
81-
List.of(
82-
Map.of("nodeId", idFunction.of("a"), "score", 0.0),
83-
Map.of("nodeId", idFunction.of("b"), "score", 3.0),
84-
Map.of("nodeId", idFunction.of("c"), "score", 4.0),
85-
Map.of("nodeId", idFunction.of("d"), "score", 3.0),
86-
Map.of("nodeId", idFunction.of("e"), "score", 0.0)
87-
));
81+
var expectedResultMap = Map.of(
82+
idFunction.of("a"), 0.0,
83+
idFunction.of("b"), 3.0,
84+
idFunction.of("c"), 4.0,
85+
idFunction.of("d"), 3.0,
86+
idFunction.of("e"), 0.0
87+
);
88+
89+
var rowCount = runQueryWithRowConsumer(query, (resultRow) -> {
90+
91+
var nodeId = resultRow.getNumber("nodeId");
92+
var expectedScore = expectedResultMap.get(nodeId);
93+
94+
assertThat(resultRow.getNumber("score")).asInstanceOf(DOUBLE).isCloseTo(
95+
expectedScore,
96+
Offset.offset(1e-6)
97+
);
98+
99+
});
100+
assertThat(rowCount).isEqualTo(5l);
88101
}
89102

90103
// FIXME: This should not be tested here

proc/centrality/src/test/java/org/neo4j/gds/closeness/ClosenessCentralityStreamProcTest.java

Lines changed: 25 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
*/
2020
package org.neo4j.gds.closeness;
2121

22-
import org.hamcrest.Matchers;
22+
import org.assertj.core.data.Offset;
2323
import org.junit.jupiter.api.BeforeEach;
2424
import org.junit.jupiter.api.Test;
2525
import org.neo4j.gds.BaseProcTest;
@@ -30,9 +30,11 @@
3030
import org.neo4j.gds.extension.Inject;
3131
import org.neo4j.gds.extension.Neo4jGraph;
3232

33-
import java.util.List;
3433
import java.util.Map;
3534

35+
import static java.util.Map.entry;
36+
import static org.assertj.core.api.Assertions.assertThat;
37+
3638
class ClosenessCentralityStreamProcTest extends BaseProcTest {
3739

3840
@Neo4jGraph
@@ -86,7 +88,7 @@ class ClosenessCentralityStreamProcTest extends BaseProcTest {
8688
@Inject
8789
private IdFunction idFunction;
8890

89-
private List<Map<String, Object>> expectedCentralityResult;
91+
private Map<Long, Double> expectedCentralityResult;
9092

9193
@BeforeEach
9294
void setupGraph() throws Exception {
@@ -95,19 +97,20 @@ void setupGraph() throws Exception {
9597
GraphProjectProc.class
9698
);
9799

98-
expectedCentralityResult = List.of(
99-
Map.of("nodeId", idFunction.of("n0"), "score", Matchers.closeTo(1.0, 0.01)),
100-
Map.of("nodeId", idFunction.of("n1"), "score", Matchers.closeTo(0.588, 0.01)),
101-
Map.of("nodeId", idFunction.of("n2"), "score", Matchers.closeTo(0.588, 0.01)),
102-
Map.of("nodeId", idFunction.of("n3"), "score", Matchers.closeTo(0.588, 0.01)),
103-
Map.of("nodeId", idFunction.of("n4"), "score", Matchers.closeTo(0.588, 0.01)),
104-
Map.of("nodeId", idFunction.of("n5"), "score", Matchers.closeTo(0.588, 0.01)),
105-
Map.of("nodeId", idFunction.of("n6"), "score", Matchers.closeTo(0.588, 0.01)),
106-
Map.of("nodeId", idFunction.of("n7"), "score", Matchers.closeTo(0.588, 0.01)),
107-
Map.of("nodeId", idFunction.of("n8"), "score", Matchers.closeTo(0.588, 0.01)),
108-
Map.of("nodeId", idFunction.of("n9"), "score", Matchers.closeTo(0.588, 0.01)),
109-
Map.of("nodeId", idFunction.of("n10"), "score", Matchers.closeTo(0.588, 0.01))
100+
expectedCentralityResult = Map.ofEntries(
101+
entry(idFunction.of("n0"), 1.0),
102+
entry(idFunction.of("n1"), 0.588),
103+
entry(idFunction.of("n2"), 0.588),
104+
entry(idFunction.of("n3"), 0.588),
105+
entry(idFunction.of("n4"), 0.588),
106+
entry(idFunction.of("n5"), 0.588),
107+
entry(idFunction.of("n6"), 0.588),
108+
entry(idFunction.of("n7"), 0.588),
109+
entry(idFunction.of("n8"), 0.588),
110+
entry(idFunction.of("n9"), 0.588),
111+
entry(idFunction.of("n10"), 0.588)
110112
);
113+
111114
loadCompleteGraph(DEFAULT_GRAPH_NAME, Orientation.UNDIRECTED);
112115
}
113116

@@ -118,7 +121,13 @@ void shouldStream() {
118121
.streamMode()
119122
.yields("nodeId", "score");
120123

121-
assertCypherResult(query, expectedCentralityResult);
124+
var rowCount = runQueryWithRowConsumer(query, row -> {
125+
var nodeId = row.getNumber("nodeId").longValue();
126+
var property = row.getNumber("score").doubleValue();
127+
assertThat(expectedCentralityResult.get(nodeId)).isCloseTo(property, Offset.offset(0.01));
128+
}
129+
);
130+
assertThat(rowCount).isEqualTo(11);
122131
}
123132

124133
}

proc/community/src/test/java/org/neo4j/gds/kcore/KCoreDecompositionWriteProcTest.java

Lines changed: 19 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,6 @@
2929
import org.neo4j.gds.extension.Inject;
3030
import org.neo4j.gds.extension.Neo4jGraph;
3131

32-
import java.util.List;
3332
import java.util.Map;
3433

3534
import static org.assertj.core.api.Assertions.assertThat;
@@ -83,8 +82,19 @@ void setup() throws Exception {
8382
@Test
8483
void shouldWrite(){
8584

86-
String query="CALL gds.kcore.write('graph', { writeProperty: 'coreValue'})";
8785

86+
String query="CALL gds.kcore.write('graph', { writeProperty: 'coreValue'})";
87+
var expectedResultMap = Map.of(
88+
idFunction.of("z"), 0L,
89+
idFunction.of("a"), 1L,
90+
idFunction.of("b"), 1L,
91+
idFunction.of("c"), 2L,
92+
idFunction.of("d"), 2L,
93+
idFunction.of("e"), 2L,
94+
idFunction.of("f"), 2L,
95+
idFunction.of("g"), 2L,
96+
idFunction.of("h"), 2L
97+
);
8898
var rowCount = runQueryWithRowConsumer(query, row -> {
8999

90100
assertThat(row.getNumber("preProcessingMillis"))
@@ -127,20 +137,15 @@ void shouldWrite(){
127137
.as("`write` mode should always return one row")
128138
.isEqualTo(1);
129139

130-
assertCypherResult(
140+
var verificationRowCount = runQueryWithRowConsumer(
131141
"MATCH (n:node) RETURN id(n) AS nodeId, n.coreValue AS coreValue",
132-
List.of(
133-
Map.of("nodeId", idFunction.of("z"), "coreValue", 0L),
134-
Map.of("nodeId", idFunction.of("a"), "coreValue", 1L),
135-
Map.of("nodeId", idFunction.of("b"), "coreValue", 1L),
136-
Map.of("nodeId", idFunction.of("c"), "coreValue", 2L),
137-
Map.of("nodeId", idFunction.of("d"), "coreValue", 2L),
138-
Map.of("nodeId", idFunction.of("e"), "coreValue", 2L),
139-
Map.of("nodeId", idFunction.of("f"), "coreValue", 2L),
140-
Map.of("nodeId", idFunction.of("g"), "coreValue", 2L),
141-
Map.of("nodeId", idFunction.of("h"), "coreValue", 2L)
142-
)
142+
(resultRow) -> {
143+
var nodeId = resultRow.getNumber("nodeId");
144+
var expected = expectedResultMap.get(nodeId);
145+
assertThat(resultRow.getNumber("coreValue")).asInstanceOf(LONG).isEqualTo(expected);
146+
}
143147
);
148+
assertThat(verificationRowCount).isEqualTo(9L);
144149
}
145150

146151
@Test

proc/community/src/test/java/org/neo4j/gds/kmeans/KmeansStreamProcTest.java

Lines changed: 47 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
*/
2020
package org.neo4j.gds.kmeans;
2121

22+
import org.assertj.core.data.Offset;
2223
import org.junit.jupiter.api.AfterEach;
2324
import org.junit.jupiter.api.BeforeEach;
2425
import org.junit.jupiter.params.ParameterizedTest;
@@ -32,6 +33,10 @@
3233
import java.util.List;
3334
import java.util.Map;
3435

36+
import static org.assertj.core.api.Assertions.assertThat;
37+
import static org.assertj.core.api.InstanceOfAssertFactories.DOUBLE;
38+
import static org.assertj.core.api.InstanceOfAssertFactories.LONG;
39+
3540
class KmeansStreamProcTest extends BaseProcTest {
3641
@BeforeEach
3742
void setup() throws Exception {
@@ -48,7 +53,7 @@ void cleanCatalog() {
4853

4954
@ParameterizedTest
5055
@ValueSource(strings = {"gds.kmeans","gds.beta.kmeans"})
51-
void shouldWork(String procedureName) {
56+
void shouldStream(String procedureName) {
5257
String nodeCreateQuery =
5358
"CREATE" +
5459
" (a:Person {kmeans: [1.0, 1.0]} )" +
@@ -75,33 +80,47 @@ void shouldWork(String procedureName) {
7580
.addParameter("concurrency", 1)
7681
.addParameter("computeSilhouette", true)
7782
.yields("nodeId", "communityId", "distanceFromCentroid", "silhouette");
78-
assertCypherResult(algoQuery, List.of(
79-
Map.of(
80-
"nodeId", 0L,
81-
"communityId", 0L,
82-
"distanceFromCentroid", 0.5,
83-
"silhouette", 0.9929292857150108
84-
),
85-
Map.of(
86-
"nodeId", 1L,
87-
"communityId", 1L,
88-
"distanceFromCentroid", Math.sqrt(2),
89-
"silhouette", 0.9799515133128792
90-
),
91-
Map.of(
92-
"nodeId", 2L,
93-
"communityId", 0L,
94-
"distanceFromCentroid", 0.5,
95-
"silhouette", 0.9928938477702276
96-
),
97-
Map.of(
98-
"nodeId", 3L,
99-
"communityId", 1L,
100-
"distanceFromCentroid", Math.sqrt(2),
101-
"silhouette", 0.9799505034216922
102-
)
103-
104-
));
83+
84+
var expectedStreamResult = Map.of(
85+
0L, new KmeansTestStreamResult(0L, 0.5, 0.9929292857150108),
86+
1L, new KmeansTestStreamResult(1L, Math.sqrt(2), 0.9799515133128792),
87+
2L, new KmeansTestStreamResult(0L, 0.5, 0.9928938477702276),
88+
3L, new KmeansTestStreamResult(1L, Math.sqrt(2), 0.9799505034216922)
89+
90+
);
91+
92+
var rowCount = runQueryWithRowConsumer(algoQuery, (resultRow) -> {
93+
94+
var nodeId = resultRow.getNumber("nodeId");
95+
var expectedCommunity = expectedStreamResult.get(nodeId).communityId;
96+
var expectedDistance = expectedStreamResult.get(nodeId).distanceFromCentroid;
97+
var expectedsilhouette = expectedStreamResult.get(nodeId).silhouette;
98+
99+
assertThat(resultRow.getNumber("communityId")).asInstanceOf(LONG).isEqualTo(expectedCommunity);
100+
assertThat(resultRow.getNumber("distanceFromCentroid")).asInstanceOf(DOUBLE).isCloseTo(
101+
expectedDistance,
102+
Offset.offset(1e-6)
103+
);
104+
assertThat(resultRow.getNumber("silhouette")).asInstanceOf(DOUBLE).isCloseTo(
105+
expectedsilhouette,
106+
Offset.offset(1e-6)
107+
);
108+
109+
});
110+
assertThat(rowCount).isEqualTo(4l);
111+
}
112+
113+
class KmeansTestStreamResult {
114+
115+
public long communityId;
116+
public double distanceFromCentroid;
117+
public double silhouette;
118+
119+
public KmeansTestStreamResult(long communityId, double distanceFromCentroid, double silhouette) {
120+
this.communityId = communityId;
121+
this.distanceFromCentroid = distanceFromCentroid;
122+
this.silhouette = silhouette;
123+
}
105124
}
106125

107126
}

proc/community/src/test/java/org/neo4j/gds/labelpropagation/LabelPropagationWriteProcTest.java

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -680,9 +680,13 @@ void shouldRunLabelPropagationReverse() {
680680
"MATCH (n) RETURN n.%1$s AS community, count(*) AS communitySize",
681681
"community"
682682
);
683-
assertCypherResult(validateQuery, Arrays.asList(
684-
Map.of("community", idFunction.of("a"), "communitySize", 6L),
685-
Map.of("community", idFunction.of("b"), "communitySize", 6L)
686-
));
683+
684+
685+
var validationRowCount = runQueryWithRowConsumer(validateQuery, (resultRow) -> {
686+
assertThat(resultRow.get("community")).asInstanceOf(LONG).isIn(idFunction.of("a"), idFunction.of("b"));
687+
});
688+
689+
assertThat(validationRowCount).isEqualTo(2l);
690+
687691
}
688692
}

proc/embeddings/src/test/java/org/neo4j/gds/embeddings/graphsage/GraphSageStreamProcTest.java

Lines changed: 7 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,6 @@
1919
*/
2020
package org.neo4j.gds.embeddings.graphsage;
2121

22-
import org.hamcrest.Matchers;
2322
import org.junit.jupiter.api.BeforeEach;
2423
import org.junit.jupiter.api.Test;
2524
import org.junit.jupiter.params.ParameterizedTest;
@@ -50,6 +49,7 @@
5049

5150
import static org.assertj.core.api.Assertions.assertThat;
5251
import static org.assertj.core.api.Assertions.assertThatThrownBy;
52+
import static org.assertj.core.api.InstanceOfAssertFactories.LIST;
5353
import static org.neo4j.gds.ElementProjection.PROJECT_ALL;
5454

5555
@Neo4jModelCatalogExtension
@@ -177,24 +177,13 @@ void weightedGraphSage() {
177177
.addParameter("concurrency", 1)
178178
.addParameter("modelName", modelName)
179179
.yields();
180+
long[] nodeIds = idFunction.of("a", "b", "c", "d", "e", "f", "g", "h", "i", "j", "k", "l", "m", "n", "o");
181+
var rowCount = runQueryWithRowConsumer(streamQuery, (resultRow) -> {
182+
assertThat(nodeIds).contains((long) resultRow.getNumber("nodeId"));
183+
assertThat(resultRow.get("embedding")).asInstanceOf(LIST).hasSize(1);
184+
});
180185

181-
assertCypherResult(streamQuery, List.of(
182-
Map.of("nodeId", idFunction.of("a"), "embedding", Matchers.iterableWithSize(1)),
183-
Map.of("nodeId", idFunction.of("b"), "embedding", Matchers.iterableWithSize(1)),
184-
Map.of("nodeId", idFunction.of("c"), "embedding", Matchers.iterableWithSize(1)),
185-
Map.of("nodeId", idFunction.of("d"), "embedding", Matchers.iterableWithSize(1)),
186-
Map.of("nodeId", idFunction.of("e"), "embedding", Matchers.iterableWithSize(1)),
187-
Map.of("nodeId", idFunction.of("f"), "embedding", Matchers.iterableWithSize(1)),
188-
Map.of("nodeId", idFunction.of("g"), "embedding", Matchers.iterableWithSize(1)),
189-
Map.of("nodeId", idFunction.of("h"), "embedding", Matchers.iterableWithSize(1)),
190-
Map.of("nodeId", idFunction.of("i"), "embedding", Matchers.iterableWithSize(1)),
191-
Map.of("nodeId", idFunction.of("j"), "embedding", Matchers.iterableWithSize(1)),
192-
Map.of("nodeId", idFunction.of("k"), "embedding", Matchers.iterableWithSize(1)),
193-
Map.of("nodeId", idFunction.of("l"), "embedding", Matchers.iterableWithSize(1)),
194-
Map.of("nodeId", idFunction.of("m"), "embedding", Matchers.iterableWithSize(1)),
195-
Map.of("nodeId", idFunction.of("n"), "embedding", Matchers.iterableWithSize(1)),
196-
Map.of("nodeId", idFunction.of("o"), "embedding", Matchers.iterableWithSize(1))
197-
));
186+
assertThat(rowCount).isEqualTo(15);
198187
}
199188

200189
@ParameterizedTest(name = "Graph Properties: {2} - Algo Properties: {1}")

0 commit comments

Comments
 (0)