Skip to content

Commit 7e7bafa

Browse files
authored
Merge pull request #6600 from DarthMax/cypher-agg-undirected
CypherAggregation undirected
2 parents 91e0142 + ee24d2b commit 7e7bafa

File tree

4 files changed

+66
-7
lines changed

4 files changed

+66
-7
lines changed

cypher-aggregation/src/main/java/org/neo4j/gds/projection/CypherAggregation.java

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,7 @@
8282
import java.util.function.Supplier;
8383

8484
import static org.neo4j.gds.Orientation.NATURAL;
85+
import static org.neo4j.gds.Orientation.UNDIRECTED;
8586
import static org.neo4j.gds.utils.StringFormatting.formatWithLocale;
8687

8788
public final class CypherAggregation {
@@ -192,7 +193,7 @@ public void update(
192193
}
193194
}
194195

195-
var relImporter = this.relImporters.computeIfAbsent(relationshipType, type -> newRelImporter());
196+
var relImporter = this.relImporters.computeIfAbsent(relationshipType, this::newRelImporter);
196197

197198
var intermediateSourceId = loadNode(sourceNode, sourceNodeLabels, sourceNodePropertyValues);
198199

@@ -386,12 +387,17 @@ private RelationshipType typeConfig(
386387
));
387388
}
388389

389-
private RelationshipsBuilder newRelImporter() {
390+
private RelationshipsBuilder newRelImporter(RelationshipType relType) {
390391
assert this.idMapBuilder != null;
391392

393+
var undirectedTypes = config.undirectedRelationshipTypes();
394+
var orientation = undirectedTypes.contains(relType.name) || undirectedTypes.contains("*")
395+
? UNDIRECTED
396+
: NATURAL;
397+
392398
var relationshipsBuilderBuilder = GraphFactory.initRelationshipsBuilder()
393399
.nodes(this.idMapBuilder)
394-
.orientation(NATURAL)
400+
.orientation(orientation)
395401
.aggregation(Aggregation.NONE)
396402
.concurrency(config.readConcurrency());
397403

@@ -692,6 +698,12 @@ default Aggregation aggregation() {
692698
return Aggregation.NONE;
693699
}
694700

701+
@org.immutables.value.Value.Default
702+
@NotNull
703+
default List<String> undirectedRelationshipTypes() {
704+
return List.of();
705+
}
706+
695707
@Configuration.Ignore
696708
@Override
697709
default GraphStoreFactory.Supplier graphStoreFactory() {

cypher-aggregation/src/test/java/org/neo4j/gds/projection/CypherAggregationTest.java

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,9 @@
2424
import org.junit.jupiter.api.Nested;
2525
import org.junit.jupiter.api.Test;
2626
import org.junit.jupiter.params.ParameterizedTest;
27+
import org.junit.jupiter.params.provider.Arguments;
2728
import org.junit.jupiter.params.provider.CsvSource;
29+
import org.junit.jupiter.params.provider.MethodSource;
2830
import org.junit.jupiter.params.provider.ValueSource;
2931
import org.neo4j.gds.BaseProcTest;
3032
import org.neo4j.gds.BaseTest;
@@ -49,6 +51,7 @@
4951
import java.util.Optional;
5052
import java.util.stream.Collectors;
5153
import java.util.stream.LongStream;
54+
import java.util.stream.Stream;
5255
import java.util.stream.StreamSupport;
5356

5457
import static org.assertj.core.api.Assertions.assertThat;
@@ -578,6 +581,48 @@ void testRelationshipPropertiesAggregation() {
578581
});
579582
}
580583

584+
@ParameterizedTest
585+
@MethodSource("undirectedTypes")
586+
void testRespectUndirectedTypes(List<String> undirectedConfig, List<String> expectedUndirectedTypes) {
587+
runQuery(
588+
"UNWIND [" +
589+
" {s: 0, t: 1, type: 'UNDIRECTED'}, " +
590+
" {s: 1, t: 2, type: 'DIRECTED'} " +
591+
"] as d" +
592+
" RETURN gds.alpha.graph.project('g', d.s, d.t, null, {relationshipType: d.type}, {undirectedRelationshipTypes: $undirected})",
593+
Map.of("undirected", undirectedConfig)
594+
);
595+
596+
assertThat(GraphStoreCatalog.exists("", db.databaseName(), "g")).isTrue();
597+
var graphStore = GraphStoreCatalog.get("", db.databaseName(), "g").graphStore();
598+
599+
600+
graphStore
601+
.schema()
602+
.relationshipSchema()
603+
.entries()
604+
.forEach(entry -> {
605+
var expectedDirection = expectedUndirectedTypes.contains(entry.identifier.name) ? org.neo4j.gds.api.schema.Direction.UNDIRECTED : org.neo4j.gds.api.schema.Direction.DIRECTED;
606+
assertThat(entry.direction()).isEqualTo(expectedDirection);
607+
});
608+
}
609+
static Stream<Arguments> undirectedTypes() {
610+
return Stream.of(
611+
Arguments.of(
612+
List.of("UNDIRECTED"),
613+
List.of("UNDIRECTED")
614+
),
615+
Arguments.of(
616+
List.of("UNDIRECTED", "DIRECTED"),
617+
List.of("UNDIRECTED", "DIRECTED")
618+
),
619+
Arguments.of(
620+
List.of("*"),
621+
List.of("UNDIRECTED", "DIRECTED")
622+
)
623+
);
624+
}
625+
581626
@Test
582627
void testPipelinePseudoAnonymous() {
583628
assertCypherResult(

doc/modules/ROOT/pages/management-ops/projections/graph-project-cypher-aggregation.adoc

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -68,8 +68,9 @@ RETURN gds.alpha.graph.project(
6868
.Configuration
6969
[opts="header",cols="1,1,1,4"]
7070
|===
71-
| Name | Type | Default | Description
72-
| readConcurrency | Integer | 4 | The number of concurrent threads used for creating the graph.
71+
| Name | Type | Default | Description
72+
| readConcurrency | Integer | 4 | The number of concurrent threads used for creating the graph.
73+
| undirectedRelationshipTypes | List of String | [] | Declare a number of relationship types as undirected. Relationships with the specified types will be imported as undirected. `*` can be used to declare all relationship types as undirected.
7374
|===
7475

7576

proc/catalog/src/test/java/org/neo4j/gds/catalog/GraphListProcTest.java

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -414,9 +414,10 @@ void listCypherAggregation() {
414414
"configuration", new Condition<>(config -> {
415415
assertThat(config)
416416
.asInstanceOf(stringObjectMapAssertFactory())
417-
.hasSize(2)
417+
.hasSize(3)
418418
.hasEntrySatisfying("creationTime", creationTimeAssertConsumer())
419-
.hasEntrySatisfying("jobId", jobId -> assertThat(jobId).isNotNull());
419+
.hasEntrySatisfying("jobId", jobId -> assertThat(jobId).isNotNull())
420+
.hasEntrySatisfying("undirectedRelationshipTypes", t -> assertThat(t).isEqualTo(List.of()));
420421

421422
return true;
422423
}, "Assert Cypher Aggregation `configuration` map"),

0 commit comments

Comments
 (0)