Skip to content

Commit 161e129

Browse files
Do for steiner tree
1 parent e7eda51 commit 161e129

File tree

4 files changed

+223
-3
lines changed

4 files changed

+223
-3
lines changed

procedures/pushback-procedures-facade/src/main/java/org/neo4j/gds/procedures/algorithms/pathfinding/write/SpanningTreeWriteResultTransformerBuilder.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,5 +52,6 @@ public ResultTransformer<TimedAlgorithmResult<SpanningTree>, Stream<SpanningTree
5252
graphResources.resultStore(),
5353
config.jobId(),
5454
config.toMap()
55-
); }
55+
);
56+
}
5657
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,94 @@
1+
/*
2+
* Copyright (c) "Neo4j"
3+
* Neo4j Sweden AB [http://neo4j.com]
4+
*
5+
* This file is part of Neo4j.
6+
*
7+
* Neo4j is free software: you can redistribute it and/or modify
8+
* it under the terms of the GNU General Public License as published by
9+
* the Free Software Foundation, either version 3 of the License, or
10+
* (at your option) any later version.
11+
*
12+
* This program is distributed in the hope that it will be useful,
13+
* but WITHOUT ANY WARRANTY; without even the implied warranty of
14+
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
15+
* GNU General Public License for more details.
16+
*
17+
* You should have received a copy of the GNU General Public License
18+
* along with this program. If not, see <http://www.gnu.org/licenses/>.
19+
*/
20+
package org.neo4j.gds.procedures.algorithms.pathfinding.write;
21+
22+
import org.neo4j.gds.api.Graph;
23+
import org.neo4j.gds.api.GraphStore;
24+
import org.neo4j.gds.api.ResultStore;
25+
import org.neo4j.gds.applications.algorithms.metadata.RelationshipsWritten;
26+
import org.neo4j.gds.core.utils.ProgressTimer;
27+
import org.neo4j.gds.core.utils.progress.JobId;
28+
import org.neo4j.gds.pathfinding.SteinerTreeWriteStep;
29+
import org.neo4j.gds.procedures.algorithms.pathfinding.SteinerWriteResult;
30+
import org.neo4j.gds.result.TimedAlgorithmResult;
31+
import org.neo4j.gds.results.ResultTransformer;
32+
import org.neo4j.gds.steiner.SteinerTreeResult;
33+
34+
import java.util.Map;
35+
import java.util.concurrent.atomic.AtomicLong;
36+
import java.util.stream.Stream;
37+
38+
public class SteinerTreeWriteResultTransformer implements ResultTransformer<TimedAlgorithmResult<SteinerTreeResult>, Stream<SteinerWriteResult>> {
39+
40+
private final SteinerTreeWriteStep writeStep;
41+
private final Graph graph;
42+
private final GraphStore graphStore;
43+
@Deprecated(forRemoval = true)
44+
private final ResultStore resultStore;
45+
private final JobId jobId;
46+
private final Map<String, Object> configuration;
47+
48+
public SteinerTreeWriteResultTransformer(
49+
SteinerTreeWriteStep writeStep,
50+
Graph graph,
51+
GraphStore graphStore,
52+
ResultStore resultStore,
53+
JobId jobId,
54+
Map<String, Object> configuration
55+
) {
56+
this.writeStep = writeStep;
57+
this.graph = graph;
58+
this.graphStore = graphStore;
59+
this.resultStore = resultStore;
60+
this.jobId = jobId;
61+
this.configuration = configuration;
62+
}
63+
64+
@Override
65+
public Stream<SteinerWriteResult> apply(TimedAlgorithmResult<SteinerTreeResult> algorithmResult) {
66+
67+
RelationshipsWritten relationshipsWritten;
68+
var writeMillis = new AtomicLong();
69+
var result = algorithmResult.result();
70+
try (var ignored = ProgressTimer.start(writeMillis::set)) {
71+
relationshipsWritten = writeStep.execute(
72+
graph,
73+
graphStore,
74+
resultStore,
75+
result,
76+
jobId
77+
);
78+
}
79+
80+
return Stream.of(
81+
new SteinerWriteResult(
82+
0,
83+
algorithmResult.computeMillis(),
84+
writeMillis.get(),
85+
result.effectiveNodeCount(),
86+
result.effectiveTargetNodesCount(),
87+
result.totalCost(),
88+
relationshipsWritten.value(),
89+
configuration
90+
)
91+
);
92+
}
93+
94+
}

procedures/pushback-procedures-facade/src/main/java/org/neo4j/gds/procedures/algorithms/pathfinding/write/SteinerTreeWriteResultTransformerBuilder.java

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,10 +31,25 @@
3131
import java.util.stream.Stream;
3232

3333
class SteinerTreeWriteResultTransformerBuilder implements ResultTransformerBuilder<TimedAlgorithmResult<SteinerTreeResult>, Stream<SteinerWriteResult>> {
34-
SteinerTreeWriteResultTransformerBuilder(SteinerTreeWriteStep writeStep, SteinerTreeWriteConfig config) {}
34+
private final SteinerTreeWriteStep writeStep;
35+
private final SteinerTreeWriteConfig config;
36+
37+
SteinerTreeWriteResultTransformerBuilder(SteinerTreeWriteStep writeStep, SteinerTreeWriteConfig config) {
38+
this.writeStep = writeStep;
39+
this.config = config;
40+
}
41+
3542

3643
@Override
3744
public ResultTransformer<TimedAlgorithmResult<SteinerTreeResult>, Stream<SteinerWriteResult>> build(GraphResources graphResources) {
38-
return ar -> Stream.empty();
45+
return new SteinerTreeWriteResultTransformer(
46+
writeStep,
47+
graphResources.graph(),
48+
graphResources.graphStore(),
49+
graphResources.resultStore(),
50+
config.jobId(),
51+
config.toMap()
52+
);
3953
}
54+
4055
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,110 @@
1+
/*
2+
* Copyright (c) "Neo4j"
3+
* Neo4j Sweden AB [http://neo4j.com]
4+
*
5+
* This file is part of Neo4j.
6+
*
7+
* Neo4j is free software: you can redistribute it and/or modify
8+
* it under the terms of the GNU General Public License as published by
9+
* the Free Software Foundation, either version 3 of the License, or
10+
* (at your option) any later version.
11+
*
12+
* This program is distributed in the hope that it will be useful,
13+
* but WITHOUT ANY WARRANTY; without even the implied warranty of
14+
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
15+
* GNU General Public License for more details.
16+
*
17+
* You should have received a copy of the GNU General Public License
18+
* along with this program. If not, see <http://www.gnu.org/licenses/>.
19+
*/
20+
package org.neo4j.gds.procedures.algorithms.pathfinding.write;
21+
22+
import org.junit.jupiter.api.Test;
23+
import org.neo4j.gds.api.Graph;
24+
import org.neo4j.gds.api.GraphStore;
25+
import org.neo4j.gds.api.ResultStore;
26+
import org.neo4j.gds.applications.algorithms.metadata.RelationshipsWritten;
27+
import org.neo4j.gds.core.utils.progress.JobId;
28+
import org.neo4j.gds.pathfinding.SteinerTreeWriteStep;
29+
import org.neo4j.gds.result.TimedAlgorithmResult;
30+
import org.neo4j.gds.steiner.SteinerTreeResult;
31+
32+
import java.util.Map;
33+
34+
import static org.assertj.core.api.Assertions.assertThat;
35+
import static org.mockito.ArgumentMatchers.any;
36+
import static org.mockito.Mockito.mock;
37+
import static org.mockito.Mockito.times;
38+
import static org.mockito.Mockito.verify;
39+
import static org.mockito.Mockito.verifyNoMoreInteractions;
40+
import static org.mockito.Mockito.when;
41+
42+
class SteinerTreeWriteResultTransformerTest {
43+
44+
@Test
45+
void shouldTransformToWriteResult() {
46+
var config = Map.<String, Object>of("foo", "bar");
47+
var graph = mock(Graph.class);
48+
var graphStore = mock(GraphStore.class);
49+
var resultStore = mock(ResultStore.class);
50+
var jobId = new JobId();
51+
var writeStep = mock(SteinerTreeWriteStep.class);
52+
53+
var algoResult = mock(SteinerTreeResult.class);
54+
when(algoResult.effectiveNodeCount()).thenReturn(1L);
55+
when(algoResult.effectiveTargetNodesCount()).thenReturn(2L);
56+
when(algoResult.totalCost()).thenReturn(3d);
57+
58+
var relationshipsWritten = new RelationshipsWritten(5L);
59+
when(writeStep.execute(any(), any(), any(), any(), any())).thenReturn(relationshipsWritten);
60+
61+
var timedResult = new TimedAlgorithmResult<>(algoResult, 123L);
62+
63+
var transformer = new SteinerTreeWriteResultTransformer(writeStep, graph, graphStore, resultStore, jobId, config);
64+
65+
var resultStream = transformer.apply(timedResult);
66+
var result = resultStream.findFirst().orElseThrow();
67+
68+
assertThat(result.preProcessingMillis()).isZero();
69+
assertThat(result.computeMillis()).isEqualTo(123L);
70+
assertThat(result.writeMillis()).isNotNegative();
71+
assertThat(result.configuration()).isEqualTo(config);
72+
assertThat(result.effectiveNodeCount()).isEqualTo(1L);
73+
assertThat(result.effectiveTargetNodesCount()).isEqualTo(2L);
74+
assertThat(result.totalWeight()).isEqualTo(3d);
75+
76+
assertThat(result.relationshipsWritten()).isEqualTo(5L);
77+
78+
verify(writeStep, times(1)).execute(graph, graphStore, resultStore, algoResult, jobId);
79+
verifyNoMoreInteractions(writeStep);
80+
}
81+
82+
@Test
83+
void shouldTransformEmptyResultToWriteResult() {
84+
var config = Map.<String, Object>of("boo", "foo");
85+
var graph = mock(Graph.class);
86+
var graphStore = mock(GraphStore.class);
87+
var resultStore = mock(ResultStore.class);
88+
var jobId = new JobId();
89+
var writeStep = mock(SteinerTreeWriteStep.class);
90+
when(writeStep.execute(any(), any(), any(), any(), any())).thenReturn(new RelationshipsWritten(0L));
91+
92+
var algoResult = SteinerTreeResult.EMPTY;
93+
94+
var timedResult = new TimedAlgorithmResult<>(algoResult, 123L);
95+
96+
var transformer = new SteinerTreeWriteResultTransformer(writeStep, graph, graphStore, resultStore, jobId, config);
97+
98+
var resultStream = transformer.apply(timedResult);
99+
var result = resultStream.findFirst().orElseThrow();
100+
101+
assertThat(result.preProcessingMillis()).isZero();
102+
assertThat(result.computeMillis()).isEqualTo(123L);
103+
assertThat(result.writeMillis()).isNotNegative();
104+
assertThat(result.configuration()).isEqualTo(config);
105+
assertThat(result.effectiveNodeCount()).isEqualTo(0L);
106+
assertThat(result.effectiveTargetNodesCount()).isEqualTo(0L);
107+
assertThat(result.totalWeight()).isEqualTo(0d);
108+
assertThat(result.relationshipsWritten()).isEqualTo(0);
109+
}
110+
}

0 commit comments

Comments
 (0)