Skip to content

Commit 14886b8

Browse files
committed
Validate triangle cout node label filter config
1 parent d22cf28 commit 14886b8

File tree

2 files changed

+132
-64
lines changed

2 files changed

+132
-64
lines changed

proc/community/src/integrationTest/java/org/neo4j/gds/triangle/TriangleCountMutateProcTest.java

Lines changed: 97 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
import java.util.Map;
3434

3535
import static org.assertj.core.api.Assertions.assertThat;
36+
import static org.assertj.core.api.Assertions.assertThatRuntimeException;
3637
import static org.assertj.core.api.InstanceOfAssertFactories.LONG;
3738
import static org.neo4j.gds.TestSupport.assertGraphEquals;
3839
import static org.neo4j.gds.TestSupport.fromGdl;
@@ -41,21 +42,21 @@ class TriangleCountMutateProcTest extends BaseProcTest {
4142

4243
@Neo4jGraph
4344
public static final String DB_CYPHER = "CREATE " +
44-
"(a:A)-[:T]->(b:A), " +
45-
"(b)-[:T]->(c:A), " +
46-
"(c)-[:T]->(a)";
45+
"(a:A)-[:T]->(b:A), " +
46+
"(b)-[:T]->(c:A), " +
47+
"(c)-[:T]->(a)";
4748

4849
String expectedMutatedGraph =
4950
" (a: A { mutatedTriangleCount: 1 })" +
50-
", (b: A { mutatedTriangleCount: 1 })" +
51-
", (c: A { mutatedTriangleCount: 1 })" +
52-
// Graph is UNDIRECTED, e.g. each rel twice
53-
", (a)-[:T]->(b)" +
54-
", (b)-[:T]->(a)" +
55-
", (b)-[:T]->(c)" +
56-
", (c)-[:T]->(b)" +
57-
", (a)-[:T]->(c)" +
58-
", (c)-[:T]->(a)";
51+
", (b: A { mutatedTriangleCount: 1 })" +
52+
", (c: A { mutatedTriangleCount: 1 })" +
53+
// Graph is UNDIRECTED, e.g. each rel twice
54+
", (a)-[:T]->(b)" +
55+
", (b)-[:T]->(a)" +
56+
", (b)-[:T]->(c)" +
57+
", (c)-[:T]->(b)" +
58+
", (a)-[:T]->(c)" +
59+
", (c)-[:T]->(a)";
5960

6061
@BeforeEach
6162
void setup() throws Exception {
@@ -77,41 +78,43 @@ void shouldMutateYield() {
7778
.yields();
7879

7980

80-
var rowCount = runQueryWithRowConsumer(query, row -> {
81-
assertThat(row.getNumber("globalTriangleCount"))
82-
.asInstanceOf(LONG)
83-
.isEqualTo(1L);
81+
var rowCount = runQueryWithRowConsumer(
82+
query, row -> {
83+
assertThat(row.getNumber("globalTriangleCount"))
84+
.asInstanceOf(LONG)
85+
.isEqualTo(1L);
8486

85-
assertThat(row.getNumber("nodeCount"))
86-
.asInstanceOf(LONG)
87-
.isEqualTo(3L);
87+
assertThat(row.getNumber("nodeCount"))
88+
.asInstanceOf(LONG)
89+
.isEqualTo(3L);
8890

89-
assertThat(row.getNumber("preProcessingMillis"))
90-
.asInstanceOf(LONG)
91-
.isGreaterThan(-1L);
91+
assertThat(row.getNumber("preProcessingMillis"))
92+
.asInstanceOf(LONG)
93+
.isGreaterThan(-1L);
9294

93-
assertThat(row.getNumber("computeMillis"))
94-
.asInstanceOf(LONG)
95-
.isGreaterThan(-1L);
95+
assertThat(row.getNumber("computeMillis"))
96+
.asInstanceOf(LONG)
97+
.isGreaterThan(-1L);
9698

97-
assertThat(row.getNumber("postProcessingMillis"))
98-
.asInstanceOf(LONG)
99-
.isGreaterThan(-1L);
99+
assertThat(row.getNumber("postProcessingMillis"))
100+
.asInstanceOf(LONG)
101+
.isGreaterThan(-1L);
100102

101-
assertThat(row.getNumber("mutateMillis"))
102-
.asInstanceOf(LONG)
103-
.isGreaterThan(-1L);
103+
assertThat(row.getNumber("mutateMillis"))
104+
.asInstanceOf(LONG)
105+
.isGreaterThan(-1L);
104106

105-
assertThat(row.get("configuration"))
106-
.isInstanceOf(Map.class);
107+
assertThat(row.get("configuration"))
108+
.isInstanceOf(Map.class);
107109

108-
assertThat(row.getNumber("nodePropertiesWritten"))
109-
.asInstanceOf(LONG)
110-
.isEqualTo(3L);
111-
});
110+
assertThat(row.getNumber("nodePropertiesWritten"))
111+
.asInstanceOf(LONG)
112+
.isEqualTo(3L);
113+
}
114+
);
112115
assertThat(rowCount).isEqualTo(1L);
113116

114-
117+
115118
Graph mutatedGraph = GraphStoreCatalog
116119
.get(getUsername(), DatabaseId.of(db.databaseName()), "graph")
117120
.graphStore()
@@ -125,8 +128,8 @@ void shouldMutateWithMaxDegree() {
125128
// Add a single node and connect it to the triangle
126129
// to be able to apply the maxDegree filter.
127130
runQuery("MATCH (n) " +
128-
"WITH n LIMIT 1 " +
129-
"CREATE (d)-[:REL]->(n)");
131+
"WITH n LIMIT 1 " +
132+
"CREATE (d)-[:REL]->(n)");
130133

131134
var createQuery = GdsCypher.call("testGraph")
132135
.graphProject()
@@ -142,19 +145,21 @@ void shouldMutateWithMaxDegree() {
142145
.addParameter("maxDegree", 2)
143146
.yields("globalTriangleCount", "nodeCount", "nodePropertiesWritten");
144147

145-
var rowCount = runQueryWithRowConsumer(query, row -> {
146-
assertThat(row.getNumber("globalTriangleCount"))
147-
.asInstanceOf(LONG)
148-
.isEqualTo(0L);
148+
var rowCount = runQueryWithRowConsumer(
149+
query, row -> {
150+
assertThat(row.getNumber("globalTriangleCount"))
151+
.asInstanceOf(LONG)
152+
.isEqualTo(0L);
149153

150-
assertThat(row.getNumber("nodeCount"))
151-
.asInstanceOf(LONG)
152-
.isEqualTo(4L);
154+
assertThat(row.getNumber("nodeCount"))
155+
.asInstanceOf(LONG)
156+
.isEqualTo(4L);
153157

154-
assertThat(row.getNumber("nodePropertiesWritten"))
155-
.asInstanceOf(LONG)
156-
.isEqualTo(4L);
157-
});
158+
assertThat(row.getNumber("nodePropertiesWritten"))
159+
.asInstanceOf(LONG)
160+
.isEqualTo(4L);
161+
}
162+
);
158163

159164
assertThat(rowCount).isEqualTo(1L);
160165

@@ -166,19 +171,47 @@ void shouldMutateWithMaxDegree() {
166171
assertGraphEquals(
167172
fromGdl(
168173
" (a { mutatedTriangleCount: -1 })" +
169-
", (b { mutatedTriangleCount: 0 })" +
170-
", (c { mutatedTriangleCount: 0 })" +
171-
", (d { mutatedTriangleCount: 0 })" +
172-
// Graph is UNDIRECTED, e.g. each rel twice
173-
", (a)-->(b)" +
174-
", (b)-->(a)" +
175-
", (b)-->(c)" +
176-
", (c)-->(b)" +
177-
", (a)-->(c)" +
178-
", (c)-->(a)" +
179-
", (d)-->(a)" +
180-
", (a)-->(d)"
181-
), mutatedGraph);
174+
", (b { mutatedTriangleCount: 0 })" +
175+
", (c { mutatedTriangleCount: 0 })" +
176+
", (d { mutatedTriangleCount: 0 })" +
177+
// Graph is UNDIRECTED, e.g. each rel twice
178+
", (a)-->(b)" +
179+
", (b)-->(a)" +
180+
", (b)-->(c)" +
181+
", (c)-->(b)" +
182+
", (a)-->(c)" +
183+
", (c)-->(a)" +
184+
", (d)-->(a)" +
185+
", (a)-->(d)"
186+
), mutatedGraph
187+
);
188+
}
189+
190+
@Test
191+
void shouldAcceptValidNodeLabelFilter() {
192+
String query = GdsCypher
193+
.call(DEFAULT_GRAPH_NAME)
194+
.algo("triangleCount")
195+
.mutateMode()
196+
.addParameter("mutateProperty", "mutatedTriangleCount")
197+
.addParameter("aLabel", "A")
198+
.yields();
199+
}
200+
201+
@Test
202+
void shouldThrowOnInvalidNodeLabelFilter() {
203+
String query = GdsCypher
204+
.call(DEFAULT_GRAPH_NAME)
205+
.algo("triangleCount")
206+
.mutateMode()
207+
.addParameter("mutateProperty", "mutatedTriangleCount")
208+
.addParameter("aLabel", "X")
209+
.yields();
210+
211+
assertThatRuntimeException()
212+
.isThrownBy(() -> runQuery(query))
213+
.withMessage(
214+
"Failed to invoke procedure `gds.triangleCount.mutate`: Caused by: java.lang.IllegalArgumentException: TriangleCount requires the provided 'aLabel' node label 'X' to be present in the graph.");
182215
}
183216
}
184217

procedures/facade-api/configs/community-configs/src/main/java/org/neo4j/gds/triangle/TriangleCountBaseConfig.java

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,41 @@ default void validateTargetRelIsUndirected(
6868
}
6969
}
7070

71+
@Configuration.GraphStoreValidationCheck
72+
default void validateLabelFilter(
73+
GraphStore ignored,
74+
Collection<NodeLabel> nodeLabels,
75+
Collection<RelationshipType> algoIgnored
76+
) {
77+
if (aLabel().isPresent()) {
78+
var _aLabel = NodeLabel.of(aLabel().get());
79+
if (!nodeLabels.contains(_aLabel)) {
80+
throw new IllegalArgumentException(formatWithLocale(
81+
"TriangleCount requires the provided 'aLabel' node label '%s' to be present in the graph.",
82+
_aLabel.name()
83+
));
84+
}
85+
}
86+
if (bLabel().isPresent()) {
87+
var _bLabel = NodeLabel.of(bLabel().get());
88+
if (!nodeLabels.contains(_bLabel)) {
89+
throw new IllegalArgumentException(formatWithLocale(
90+
"TriangleCount requires the provided 'bLabel' node label '%s' to be present in the graph.",
91+
_bLabel.name()
92+
));
93+
}
94+
}
95+
if (cLabel().isPresent()) {
96+
var _cLabel = NodeLabel.of(cLabel().get());
97+
if (!nodeLabels.contains(_cLabel)) {
98+
throw new IllegalArgumentException(formatWithLocale(
99+
"TriangleCount requires the provided 'cLabel' node label '%s' to be present in the graph.",
100+
_cLabel.name()
101+
));
102+
}
103+
}
104+
}
105+
71106
static TriangleCountBaseConfig of(CypherMapWrapper userInput) {
72107
return new TriangleCountBaseConfigImpl(userInput);
73108
}

0 commit comments

Comments
 (0)