Skip to content

Commit 1abe53e

Browse files
committed
Move all dimensions computation to the factory
1 parent d86ec02 commit 1abe53e

File tree

2 files changed

+58
-32
lines changed

2 files changed

+58
-32
lines changed

core/src/main/java/org/neo4j/gds/core/loading/CypherFactory.java

Lines changed: 46 additions & 31 deletions
Original file line number | Diff line number | Diff line change
@@ -23,10 +23,12 @@
2323
import org.neo4j.gds.NodeLabel;
2424
import org.neo4j.gds.NodeProjection;
2525
import org.neo4j.gds.NodeProjections;
26+
import org.neo4j.gds.PropertyMapping;
2627
import org.neo4j.gds.RelationshipProjection;
2728
import org.neo4j.gds.RelationshipProjections;
2829
import org.neo4j.gds.RelationshipType;
2930
import org.neo4j.gds.api.CSRGraphStoreFactory;
31+
import org.neo4j.gds.api.DefaultValue;
3032
import org.neo4j.gds.api.GraphLoaderContext;
3133
import org.neo4j.gds.config.GraphProjectFromCypherConfig;
3234
import org.neo4j.gds.core.GraphDimensions;
@@ -39,30 +41,33 @@
3941
import org.neo4j.gds.core.utils.warnings.EmptyUserLogRegistryFactory;
4042
import org.neo4j.gds.transaction.TransactionContext;
4143

42-
import javax.annotation.Nullable;
44+
import java.util.Collection;
45+
import java.util.Optional;
46+
import java.util.stream.Collectors;
47+
import java.util.stream.LongStream;
4348

4449
import static org.neo4j.gds.core.loading.CypherQueryEstimator.EstimationResult;
4550
import static org.neo4j.internal.kernel.api.security.AccessMode.Static.READ;
4651

4752
public final class CypherFactory extends CSRGraphStoreFactory<GraphProjectFromCypherConfig> {
4853

4954
private final GraphProjectFromCypherConfig cypherConfig;
50-
private final EstimationResult nodeEstimation;
51-
private final EstimationResult relationshipEstimation;
55+
private final long numberOfNodeProperties;
56+
private final long numberOfRelationshipProperties;
5257
private final ProgressTracker progressTracker;
5358

5459
/**
 * Creates a {@code CypherFactory} whose dimensions start from the supplied {@code graphDimensions}.
 *
 * Thin wrapper around the private {@code create} method. In this diff the removed line passed
 * {@code graphDimensions} directly, while the added line wraps it in {@code Optional.of(...)}.
 *
 * @param graphProjectConfig the Cypher projection configuration, forwarded to {@code create}
 * @param loadingContext     the loader context, forwarded to {@code create}
 * @param graphDimensions    base dimensions to build upon; {@code Optional.of} rejects {@code null}
 * @return the factory returned by {@code create}
 */
public static CypherFactory createWithBaseDimensions(GraphProjectFromCypherConfig graphProjectConfig, GraphLoaderContext loadingContext, GraphDimensions graphDimensions) {
55-
return create(graphProjectConfig, loadingContext, graphDimensions);
60+
return create(graphProjectConfig, loadingContext, Optional.of(graphDimensions));
5661
}
5762

5863
/**
 * Creates a {@code CypherFactory} without any base dimensions.
 *
 * Thin wrapper around the private {@code create} method. In this diff the removed line signalled
 * "no base dimensions" with {@code null}, while the added line uses {@code Optional.empty()}.
 *
 * @param graphProjectConfig the Cypher projection configuration, forwarded to {@code create}
 * @param loadingContext     the loader context, forwarded to {@code create}
 * @return the factory returned by {@code create}
 */
public static CypherFactory createWithDerivedDimensions(GraphProjectFromCypherConfig graphProjectConfig, GraphLoaderContext loadingContext) {
59-
return create(graphProjectConfig, loadingContext, null);
64+
return create(graphProjectConfig, loadingContext, Optional.empty());
6065
}
6166

6267
/**
 * Shared factory method behind {@code createWithBaseDimensions} and {@code createWithDerivedDimensions}.
 *
 * Visible behaviour after this commit: the dimensions builder is seeded from {@code baseDimensions}
 * when present, the three count fields are set to the larger of the base value and the estimated
 * row count, and only the estimated property counts (not the whole estimation results) are handed
 * to the constructor.
 *
 * NOTE(review): the diff hunk boundary below hides the lines where {@code nodeEstimation} and
 * {@code relationEstimation} are assigned, so how they are computed cannot be stated from this view.
 *
 * @param graphProjectConfig the Cypher projection configuration
 * @param loadingContext     the loader context
 * @param baseDimensions     optional dimensions to start from (replaces the former {@code @Nullable} parameter)
 * @return a new {@code CypherFactory} built with the merged dimensions
 */
private static CypherFactory create(
6368
GraphProjectFromCypherConfig graphProjectConfig,
6469
GraphLoaderContext loadingContext,
65-
@Nullable GraphDimensions dimensions
70+
Optional<GraphDimensions> baseDimensions
6671
) {
6772

6873
EstimationResult nodeEstimation;
@@ -79,39 +84,48 @@ private static CypherFactory create(
7984

8085
var dimBuilder = ImmutableGraphDimensions.builder();
8186

82-
if (dimensions != null) {
83-
dimBuilder.from(dimensions);
84-
}
87+
baseDimensions.ifPresent(dimBuilder::from);
88+
89+
// For each count, take the larger of the base value and the estimate. When no base dimensions
// were supplied the fallback is -1L, so the estimate is used unless it is below -1
// (presumably estimatedRows() is never negative; confirm against CypherQueryEstimator).
long highestPossibleNodeCount = Math.max(baseDimensions
90+
.map(GraphDimensions::highestPossibleNodeCount)
91+
.orElse(-1L), nodeEstimation.estimatedRows());
92+
long nodeCount = Math.max(
93+
baseDimensions.map(GraphDimensions::nodeCount).orElse(-1L),
94+
nodeEstimation.estimatedRows()
95+
);
96+
long relCountUpperBound = Math.max(
97+
baseDimensions.map(GraphDimensions::relCountUpperBound).orElse(-1L),
98+
relationEstimation.estimatedRows()
99+
);
85100

86-
GraphDimensions dim = ImmutableGraphDimensions.builder()
87-
.highestPossibleNodeCount(nodeEstimation.estimatedRows())
88-
.nodeCount(nodeEstimation.estimatedRows())
89-
.relCountUpperBound(relationEstimation.estimatedRows())
101+
// The added line reuses dimBuilder (seeded above); the removed line built from a fresh builder.
GraphDimensions dim = dimBuilder
102+
.highestPossibleNodeCount(highestPossibleNodeCount)
103+
.nodeCount(nodeCount)
104+
.relCountUpperBound(relCountUpperBound)
90105
.build();
91106

92107
return new CypherFactory(
93108
graphProjectConfig,
94109
loadingContext,
95110
dim,
96-
nodeEstimation,
97-
relationEstimation
111+
nodeEstimation.propertyCount(),
112+
relationEstimation.propertyCount()
98113
);
99114
}
100115

101116
/**
 * Private constructor; instances are obtained through the static {@code create*} methods.
 *
 * After this commit the constructor no longer keeps the {@code EstimationResult} objects; it
 * stores only the two estimated property counts, which the estimate projection builders pass to
 * {@code propertyMappings}.
 *
 * @param graphProjectConfig              the Cypher projection configuration, also kept as {@code cypherConfig}
 * @param loadingContext                  the loader context, passed to the superclass
 * @param graphDimensions                 the dimensions, passed to the superclass
 * @param estimatedNumberOfNodeProperties estimated number of node properties
 * @param estimatedNumberOfRelProperties  estimated number of relationship properties
 */
private CypherFactory(
102117
GraphProjectFromCypherConfig graphProjectConfig,
103118
GraphLoaderContext loadingContext,
104119
GraphDimensions graphDimensions,
105-
EstimationResult nodeEstimation,
106-
EstimationResult relationshipEstimation
107-
120+
long estimatedNumberOfNodeProperties,
121+
long estimatedNumberOfRelProperties
108122
) {
109123
// TODO: need to pass capabilities from outside?
110124
super(graphProjectConfig, ImmutableStaticCapabilities.of(true), loadingContext, graphDimensions);
111125

112126
this.cypherConfig = graphProjectConfig;
113-
this.nodeEstimation = nodeEstimation;
114-
this.relationshipEstimation = relationshipEstimation;
127+
this.numberOfNodeProperties = estimatedNumberOfNodeProperties;
128+
this.numberOfRelationshipProperties = estimatedNumberOfRelProperties;
115129
// Assigned last: initProgressTracker() reads `dimensions`, presumably set by the super(...) call
// above (verify in CSRGraphStoreFactory).
this.progressTracker = initProgressTracker();
116130
}
117131

@@ -140,12 +154,7 @@ public MemoryEstimation estimateMemoryUsageAfterLoading() {
140154

141155
/**
 * Returns the dimensions used for estimation.
 *
 * After this commit this is simply the {@code dimensions} field. The removed lines merged the
 * stored dimensions with the estimated row counts via {@code Math.max} on every call; the same
 * merging is now done in the static {@code create} method before the factory is constructed.
 *
 * @return the {@code dimensions} field
 */
@Override
142156
public GraphDimensions estimationDimensions() {
143-
return ImmutableGraphDimensions.builder()
144-
.from(dimensions)
145-
.highestPossibleNodeCount(Math.max(dimensions.highestPossibleNodeCount(), nodeEstimation.estimatedRows()))
146-
.nodeCount(Math.max(dimensions.nodeCount(), nodeEstimation.estimatedRows()))
147-
.relCountUpperBound(Math.max(dimensions.relCountUpperBound(), relationshipEstimation.estimatedRows()))
148-
.build();
157+
return dimensions;
149158
}
150159

151160
@Override
@@ -190,11 +199,10 @@ public CSRGraphStore build() {
190199
}
191200

192201
private ProgressTracker initProgressTracker() {
193-
var estimatedDimensions = estimationDimensions();
194202
var task = Tasks.task(
195203
"Loading",
196-
Tasks.leaf("Nodes", estimatedDimensions.highestPossibleNodeCount()),
197-
Tasks.leaf("Relationships", estimatedDimensions.relCountUpperBound())
204+
Tasks.leaf("Nodes", dimensions.highestPossibleNodeCount()),
205+
Tasks.leaf("Relationships", dimensions.relCountUpperBound())
198206
);
199207

200208
if (graphProjectConfig.logProgress()) {
@@ -226,7 +234,7 @@ private NodeProjections buildEstimateNodeProjections() {
226234
var nodeProjection = NodeProjection
227235
.builder()
228236
.label(ElementProjection.PROJECT_ALL)
229-
.addAllProperties(nodeEstimation.propertyMappings())
237+
.addAllProperties(propertyMappings(numberOfNodeProperties))
230238
.build();
231239

232240
return NodeProjections.single(
@@ -239,12 +247,19 @@ private RelationshipProjections buildEstimateRelationshipProjections() {
239247
var relationshipProjection = RelationshipProjection
240248
.builder()
241249
.type(ElementProjection.PROJECT_ALL)
242-
.addAllProperties(relationshipEstimation.propertyMappings())
250+
.addAllProperties(propertyMappings(numberOfRelationshipProperties))
243251
.build();
244252

245253
return RelationshipProjections.single(
246254
RelationshipType.ALL_RELATIONSHIPS,
247255
relationshipProjection
248256
);
249257
}
258+
259+
/**
 * Builds {@code propertyCount} placeholder property mappings for the estimate projections.
 *
 * Each mapping is created from the decimal string of its index ("0", "1", …, up to
 * {@code propertyCount - 1}) and {@code DefaultValue.DEFAULT}. A {@code propertyCount} of zero or
 * less yields an empty list, because {@code LongStream.range} is empty in that case.
 *
 * @param propertyCount number of placeholder mappings to create
 * @return a list with one {@code PropertyMapping} per index
 */
private static Collection<PropertyMapping> propertyMappings(long propertyCount) {
260+
return LongStream
261+
.range(0, propertyCount)
262+
.mapToObj(property -> PropertyMapping.of(Long.toString(property), DefaultValue.DEFAULT))
263+
.collect(Collectors.toList());
264+
}
250265
}

core/src/test/java/org/neo4j/gds/core/GraphLoaderTest.java

Lines changed: 12 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -47,6 +47,8 @@
4747

4848
import java.util.List;
4949
import java.util.Map;
50+
import java.util.regex.Matcher;
51+
import java.util.regex.Pattern;
5052
import java.util.stream.Stream;
5153

5254
import static org.assertj.core.api.Assertions.assertThat;
@@ -212,6 +214,7 @@ public void shouldTrackProgressWithNativeLoadingUsingIndex() {
212214

213215
/**
 * Checks the progress log output produced by Cypher loading.
 *
 * This commit adds a second condition to the existing message assertion: no message may report a
 * progress percentage above 100.
 *
 * NOTE(review): the diff hunk boundary below hides the loader configuration and the start of the
 * assertion, so the exact subject of the assertion is not visible here.
 */
@Test
214216
void shouldLogProgressWithCypherLoading() {
217+
// Captures the integer directly before a '%' that ends the message, e.g. "100" in "… 100%".
var progressRegex = Pattern.compile("(\\d+)%$");
215218
var log = Neo4jProxy.testLog();
216219
new CypherLoaderBuilder()
217220
.databaseService(db)
@@ -234,7 +237,15 @@ void shouldLogProgressWithCypherLoading() {
234237
"Loading :: Relationships 100%",
235238
"Loading :: Relationships :: Finished",
236239
"Loading :: Finished"
237-
);
240+
)
241+
// Fails if any message ends in a percentage greater than 100; messages without a trailing
// percentage never match.
.noneMatch(message -> {
242+
Matcher matcher = progressRegex.matcher(message);
243+
if (matcher.find()) {
244+
int progress = Integer.parseInt(matcher.group(1));
245+
return progress > 100;
246+
}
247+
return false;
248+
});
238249

239250
assertThat(log.getMessages(TestLog.DEBUG)).isEmpty();
240251
}

0 commit comments

Comments
 (0)