Skip to content

Commit aeabbba

Browse files
committed
Add Steiner write procedure
1 parent 6e74f73 commit aeabbba

File tree

9 files changed

+433
-6
lines changed

9 files changed

+433
-6
lines changed

algo/src/main/java/org/neo4j/gds/spanningtree/SpanningGraph.java

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -37,8 +37,8 @@ public SpanningGraph(Graph graph, SpanningTree spanningTree) {
3737

3838
@Override
3939
public int degree(long nodeId) {
40-
if (spanningTree.parent.get(nodeId) == -1) {
41-
return Math.toIntExact(Arrays.stream(spanningTree.parent.toArray()).filter(i -> i == -1).count());
40+
if (spanningTree.parent.get(nodeId) < 0) {
41+
return Math.toIntExact(Arrays.stream(spanningTree.parent.toArray()).filter(i -> i < 0).count());
4242
} else {
4343
return 1;
4444
}
@@ -56,7 +56,7 @@ public void forEachRelationship(long nodeId, RelationshipConsumer consumer) {
5656
@Override
5757
public void forEachRelationship(long nodeId, double fallbackValue, RelationshipWithPropertyConsumer consumer) {
5858
long parent = spanningTree.parent.get(nodeId);
59-
if (parent != -1) {
59+
if (parent >=0) {
6060
consumer.accept(parent, nodeId, spanningTree.costToParent(nodeId));
6161
}
6262
}

algo/src/main/java/org/neo4j/gds/spanningtree/SpanningTree.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ public class SpanningTree {
3838
final HugeLongArray parent;
3939
final double totalWeight;
4040

41-
SpanningTree(
41+
public SpanningTree(
4242
long head,
4343
long nodeCount,
4444
long effectiveNodeCount,
Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
/*
2+
* Copyright (c) "Neo4j"
3+
* Neo4j Sweden AB [http://neo4j.com]
4+
*
5+
* This file is part of Neo4j.
6+
*
7+
* Neo4j is free software: you can redistribute it and/or modify
8+
* it under the terms of the GNU General Public License as published by
9+
* the Free Software Foundation, either version 3 of the License, or
10+
* (at your option) any later version.
11+
*
12+
* This program is distributed in the hope that it will be useful,
13+
* but WITHOUT ANY WARRANTY; without even the implied warranty of
14+
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
15+
* GNU General Public License for more details.
16+
*
17+
* You should have received a copy of the GNU General Public License
18+
* along with this program. If not, see <http://www.gnu.org/licenses/>.
19+
*/
20+
package org.neo4j.gds.steiner;
21+
22+
import org.neo4j.gds.annotation.Configuration;
23+
import org.neo4j.gds.annotation.ValueClass;
24+
import org.neo4j.gds.config.WritePropertyConfig;
25+
import org.neo4j.gds.config.WriteRelationshipConfig;
26+
import org.neo4j.gds.core.CypherMapWrapper;
27+
28+
@ValueClass
29+
@Configuration
30+
public interface SteinerTreeWriteConfig extends SteinerTreeBaseConfig, WriteRelationshipConfig, WritePropertyConfig {
31+
32+
static SteinerTreeWriteConfig of(CypherMapWrapper userInput) {
33+
return new SteinerTreeWriteConfigImpl(userInput);
34+
}
35+
36+
}

doc/modules/ROOT/pages/operations-reference/algorithm-references.adoc

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -360,9 +360,10 @@
360360
.2+<.^| xref:algorithms/alpha/modularity.adoc[Modularity Metric]
361361
| `gds.alpha.modularity.stats`
362362
| `gds.alpha.modularity.stream`
363-
.3+<.^| Directer Steiner Tree
363+
.4+<.^| Directer Steiner Tree
364364
| `gds.alpha.steinerTree.mutate`
365365
| `gds.alpha.steinerTree.stats`
366366
| `gds.alpha.steinerTree.stream`
367+
| `gds.alpha.steinerTree.write`
367368

368369
|===

open-packaging/src/test/java/org/neo4j/gds/OpenGdsProcedureSmokeTest.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -190,6 +190,7 @@ class OpenGdsProcedureSmokeTest extends BaseProcTest {
190190
"gds.alpha.steinerTree.mutate",
191191
"gds.alpha.steinerTree.stats",
192192
"gds.alpha.steinerTree.stream",
193+
"gds.alpha.steinerTree.write",
193194

194195
"gds.alpha.triangles",
195196
"gds.alpha.ml.splitRelationships.mutate",
@@ -526,7 +527,7 @@ void countShouldMatch() {
526527
);
527528

528529
// If you find yourself updating this count, please also update the count in SmokeTest.kt
529-
int expectedCount = 374;
530+
int expectedCount = 375;
530531
assertEquals(
531532
expectedCount,
532533
registeredProcedures.size(),
Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,70 @@
1+
/*
2+
* Copyright (c) "Neo4j"
3+
* Neo4j Sweden AB [http://neo4j.com]
4+
*
5+
* This file is part of Neo4j.
6+
*
7+
* Neo4j is free software: you can redistribute it and/or modify
8+
* it under the terms of the GNU General Public License as published by
9+
* the Free Software Foundation, either version 3 of the License, or
10+
* (at your option) any later version.
11+
*
12+
* This program is distributed in the hope that it will be useful,
13+
* but WITHOUT ANY WARRANTY; without even the implied warranty of
14+
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
15+
* GNU General Public License for more details.
16+
*
17+
* You should have received a copy of the GNU General Public License
18+
* along with this program. If not, see <http://www.gnu.org/licenses/>.
19+
*/
20+
package org.neo4j.gds.paths.steiner;
21+
22+
import org.neo4j.gds.BaseProc;
23+
import org.neo4j.gds.core.write.RelationshipExporter;
24+
import org.neo4j.gds.core.write.RelationshipExporterBuilder;
25+
import org.neo4j.gds.executor.ExecutionContext;
26+
import org.neo4j.gds.executor.ImmutableExecutionContext;
27+
import org.neo4j.gds.executor.ProcedureExecutor;
28+
import org.neo4j.procedure.Context;
29+
import org.neo4j.procedure.Description;
30+
import org.neo4j.procedure.Name;
31+
import org.neo4j.procedure.Procedure;
32+
33+
import java.util.Map;
34+
import java.util.stream.Stream;
35+
36+
import static org.neo4j.procedure.Mode.WRITE;
37+
38+
public class SteinerTreeWriteProc extends BaseProc {
39+
40+
@Context
41+
public RelationshipExporterBuilder<? extends RelationshipExporter> relationshipExporterBuilder;
42+
43+
@Procedure(value = "gds.alpha.steinerTree.write", mode = WRITE)
44+
@Description(SteinerTreeStatsProc.DESCRIPTION)
45+
public Stream<WriteResult> write(
46+
@Name(value = "graphName") String graphName,
47+
@Name(value = "configuration", defaultValue = "{}") Map<String, Object> configuration
48+
) {
49+
return new ProcedureExecutor<>(
50+
new SteinerTreeWriteSpec(),
51+
executionContext()
52+
).compute(graphName, configuration, true, true);
53+
}
54+
55+
@Override
56+
public ExecutionContext executionContext() {
57+
return ImmutableExecutionContext
58+
.builder()
59+
.databaseService(databaseService)
60+
.log(log)
61+
.procedureTransaction(procedureTransaction)
62+
.transaction(transaction)
63+
.callContext(callContext)
64+
.userLogRegistryFactory(userLogRegistryFactory)
65+
.taskRegistryFactory(taskRegistryFactory)
66+
.username(username())
67+
.relationshipExporterBuilder(relationshipExporterBuilder)
68+
.build();
69+
}
70+
}
Lines changed: 108 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,108 @@
1+
/*
2+
* Copyright (c) "Neo4j"
3+
* Neo4j Sweden AB [http://neo4j.com]
4+
*
5+
* This file is part of Neo4j.
6+
*
7+
* Neo4j is free software: you can redistribute it and/or modify
8+
* it under the terms of the GNU General Public License as published by
9+
* the Free Software Foundation, either version 3 of the License, or
10+
* (at your option) any later version.
11+
*
12+
* This program is distributed in the hope that it will be useful,
13+
* but WITHOUT ANY WARRANTY; without even the implied warranty of
14+
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
15+
* GNU General Public License for more details.
16+
*
17+
* You should have received a copy of the GNU General Public License
18+
* along with this program. If not, see <http://www.gnu.org/licenses/>.
19+
*/
20+
package org.neo4j.gds.paths.steiner;
21+
22+
import org.neo4j.gds.core.utils.ProgressTimer;
23+
import org.neo4j.gds.core.utils.progress.tasks.ProgressTracker;
24+
import org.neo4j.gds.core.write.RelationshipExporter;
25+
import org.neo4j.gds.core.write.RelationshipExporterBuilder;
26+
import org.neo4j.gds.executor.AlgorithmSpec;
27+
import org.neo4j.gds.executor.ComputationResultConsumer;
28+
import org.neo4j.gds.executor.GdsCallable;
29+
import org.neo4j.gds.executor.NewConfigFunction;
30+
import org.neo4j.gds.spanningtree.SpanningGraph;
31+
import org.neo4j.gds.spanningtree.SpanningTree;
32+
import org.neo4j.gds.steiner.ShortestPathsSteinerAlgorithm;
33+
import org.neo4j.gds.steiner.SteinerTreeAlgorithmFactory;
34+
import org.neo4j.gds.steiner.SteinerTreeResult;
35+
import org.neo4j.gds.steiner.SteinerTreeWriteConfig;
36+
37+
import java.util.stream.Stream;
38+
39+
import static org.neo4j.gds.executor.ExecutionMode.WRITE_RELATIONSHIP;
40+
41+
@GdsCallable(name = "gds.alpha.SteinerTree.write", description = SteinerTreeStatsProc.DESCRIPTION, executionMode = WRITE_RELATIONSHIP)
42+
public class SteinerTreeWriteSpec implements AlgorithmSpec<ShortestPathsSteinerAlgorithm, SteinerTreeResult, SteinerTreeWriteConfig, Stream<WriteResult>, SteinerTreeAlgorithmFactory<SteinerTreeWriteConfig>> {
43+
44+
@Override
45+
public String name() {
46+
return "SteinerTreeWrite";
47+
}
48+
49+
@Override
50+
public SteinerTreeAlgorithmFactory<SteinerTreeWriteConfig> algorithmFactory() {
51+
return new SteinerTreeAlgorithmFactory<>();
52+
}
53+
54+
@Override
55+
public NewConfigFunction<SteinerTreeWriteConfig> newConfigFunction() {
56+
return (__, config) -> SteinerTreeWriteConfig.of(config);
57+
}
58+
59+
public ComputationResultConsumer<ShortestPathsSteinerAlgorithm, SteinerTreeResult, SteinerTreeWriteConfig, Stream<WriteResult>> computationResultConsumer() {
60+
61+
return (computationResult, executionContext) -> {
62+
var config = computationResult.config();
63+
var terminationFlag = computationResult.algorithm().getTerminationFlag();
64+
var sourceNode = config.sourceNode();
65+
var graph = computationResult.graph();
66+
var steinerTreeResult = computationResult.result();
67+
68+
var builder = new WriteResult.Builder();
69+
70+
builder
71+
.withEffectiveNodeCount(steinerTreeResult.effectiveNodeCount())
72+
.withEffectiveTargetNodeCount(steinerTreeResult.effectiveTargetNodesCount())
73+
.withTotalWeight(steinerTreeResult.totalCost());
74+
75+
try (ProgressTimer ignored = ProgressTimer.start(builder::withWriteMillis)) {
76+
var spanningTree = new SpanningTree(
77+
graph.toMappedNodeId(sourceNode),
78+
graph.nodeCount(),
79+
steinerTreeResult.effectiveNodeCount(),
80+
steinerTreeResult.parentArray(),
81+
steinerTreeResult.relationshipToParentCost(),
82+
steinerTreeResult.totalCost()
83+
);
84+
var spanningGraph = new SpanningGraph(graph, spanningTree);
85+
86+
RelationshipExporterBuilder<? extends RelationshipExporter> relationshipExporterBuilder = executionContext.relationshipExporterBuilder();
87+
relationshipExporterBuilder
88+
.withGraph(spanningGraph)
89+
.withIdMappingOperator(spanningGraph::toOriginalNodeId)
90+
.withTerminationFlag(terminationFlag)
91+
.withProgressTracker(ProgressTracker.NULL_TRACKER)
92+
.build()
93+
.write(
94+
config.writeRelationshipType(),
95+
config.writeProperty()
96+
);
97+
98+
}
99+
builder
100+
.withComputeMillis(computationResult.computeMillis())
101+
.withPreProcessingMillis(computationResult.preProcessingMillis())
102+
.withRelationshipsWritten(steinerTreeResult.effectiveNodeCount() - 1)
103+
.withConfig(config);
104+
105+
return Stream.of(builder.build());
106+
};
107+
}
108+
}
Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,82 @@
1+
/*
2+
* Copyright (c) "Neo4j"
3+
* Neo4j Sweden AB [http://neo4j.com]
4+
*
5+
* This file is part of Neo4j.
6+
*
7+
* Neo4j is free software: you can redistribute it and/or modify
8+
* it under the terms of the GNU General Public License as published by
9+
* the Free Software Foundation, either version 3 of the License, or
10+
* (at your option) any later version.
11+
*
12+
* This program is distributed in the hope that it will be useful,
13+
* but WITHOUT ANY WARRANTY; without even the implied warranty of
14+
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
15+
* GNU General Public License for more details.
16+
*
17+
* You should have received a copy of the GNU General Public License
18+
* along with this program. If not, see <http://www.gnu.org/licenses/>.
19+
*/
20+
package org.neo4j.gds.paths.steiner;
21+
22+
import org.neo4j.gds.result.AbstractResultBuilder;
23+
24+
import java.util.Map;
25+
26+
public final class WriteResult extends StatsResult {
27+
28+
29+
public final long writeMillis;
30+
public final long relationshipsWritten;
31+
32+
public WriteResult(
33+
long preProcessingMillis,
34+
long computeMillis,
35+
long writeMillis,
36+
long effectiveNodeCount,
37+
long effectiveTargetNodesCount,
38+
double totalCost,
39+
long relationshipsWritten,
40+
Map<String, Object> configuration
41+
) {
42+
super(preProcessingMillis, computeMillis, effectiveNodeCount, effectiveTargetNodesCount, totalCost, configuration);
43+
this.writeMillis = writeMillis;
44+
this.relationshipsWritten = relationshipsWritten;
45+
}
46+
47+
public static class Builder extends AbstractResultBuilder<WriteResult> {
48+
49+
long effectiveNodeCount;
50+
long effectiveTargetNodesCount;
51+
double totalWeight;
52+
53+
Builder withEffectiveNodeCount(long effectiveNodeCount) {
54+
this.effectiveNodeCount = effectiveNodeCount;
55+
return this;
56+
}
57+
58+
Builder withEffectiveTargetNodeCount(long effectiveTargetNodesCount) {
59+
this.effectiveTargetNodesCount = effectiveTargetNodesCount;
60+
return this;
61+
}
62+
63+
Builder withTotalWeight(double totalWeight) {
64+
this.totalWeight = totalWeight;
65+
return this;
66+
}
67+
68+
@Override
69+
public WriteResult build() {
70+
return new WriteResult(
71+
preProcessingMillis,
72+
computeMillis,
73+
writeMillis,
74+
effectiveNodeCount,
75+
effectiveTargetNodesCount,
76+
totalWeight,
77+
relationshipsWritten,
78+
config.toMap()
79+
);
80+
}
81+
}
82+
}

0 commit comments

Comments
 (0)