Skip to content

Commit 9e6bb72

Browse files
authored
Merge pull request #6825 from DarthMax/2.3_pregel_inverse_gen
2.3 - Generate inverse indexes for generated Bidirectional Pregel Procedures
2 parents 1cf0ebb + 1fe1d39 commit 9e6bb72

File tree

9 files changed

+459
-26
lines changed

9 files changed

+459
-26
lines changed

pregel-proc-generator/src/main/java/org/neo4j/gds/beta/pregel/PregelValidation.java

Lines changed: 23 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -50,17 +50,23 @@ final class PregelValidation {
5050
private final Elements elementUtils;
5151

5252
// Represents the PregelComputation interface
53-
private final TypeMirror pregelComputation;
53+
private final TypeMirror basePregelComputation;
54+
55+
private final TypeMirror bidirectionalPregelComputation;
56+
5457
// Represents the PregelProcedureConfig interface
5558
private final TypeMirror pregelProcedureConfig;
5659

5760
PregelValidation(Messager messager, Elements elementUtils, Types typeUtils) {
5861
this.messager = messager;
5962
this.typeUtils = typeUtils;
6063
this.elementUtils = elementUtils;
61-
this.pregelComputation = MoreTypes.asDeclared(
64+
this.basePregelComputation = MoreTypes.asDeclared(
6265
typeUtils.erasure(elementUtils.getTypeElement(BasePregelComputation.class.getName()).asType())
6366
);
67+
this.bidirectionalPregelComputation = MoreTypes.asDeclared(
68+
typeUtils.erasure(elementUtils.getTypeElement(BidirectionalPregelComputation.class.getName()).asType())
69+
);
6470
this.pregelProcedureConfig = MoreTypes.asDeclared(elementUtils
6571
.getTypeElement(PregelProcedureConfig.class.getName())
6672
.asType());
@@ -69,7 +75,7 @@ final class PregelValidation {
6975
Optional<Spec> validate(Element pregelElement) {
7076
if (
7177
!isClass(pregelElement) ||
72-
!isPregelComputation(pregelElement) ||
78+
!isBasePregelComputation(pregelElement) ||
7379
!isPregelProcedureConfig(pregelElement) ||
7480
!hasEmptyConstructor(pregelElement) ||
7581
!configHasFactoryMethod(pregelElement)
@@ -92,7 +98,8 @@ Optional<Spec> validate(Element pregelElement) {
9298
configTypeName,
9399
procedure.name(),
94100
procedure.modes(),
95-
maybeDescription
101+
maybeDescription,
102+
requiresInverseIndex(pregelElement)
96103
));
97104
}
98105

@@ -108,17 +115,17 @@ private boolean isClass(Element pregelElement) {
108115
return isClass;
109116
}
110117

111-
private Optional<DeclaredType> pregelComputation(Element pregelElement) {
118+
private Optional<DeclaredType> pregelComputation(Element pregelElement, TypeMirror computationInterface) {
112119
// TODO: this check needs to bubble up the inheritance tree
113120
return MoreElements.asType(pregelElement).getInterfaces().stream()
114121
.map(MoreTypes::asDeclared)
115-
.filter(declaredType -> typeUtils.isSubtype(declaredType, pregelComputation))
122+
.filter(declaredType -> typeUtils.isSubtype(declaredType, computationInterface))
116123
.findFirst();
117124
}
118125

119-
private boolean isPregelComputation(Element pregelElement) {
126+
private boolean isBasePregelComputation(Element pregelElement) {
120127
var pregelTypeElement = MoreElements.asType(pregelElement);
121-
var maybeInterface = pregelComputation(pregelElement);
128+
var maybeInterface = pregelComputation(pregelElement, basePregelComputation);
122129
boolean isPregelComputation = maybeInterface.isPresent();
123130

124131
if (!isPregelComputation) {
@@ -131,6 +138,11 @@ private boolean isPregelComputation(Element pregelElement) {
131138
return isPregelComputation;
132139
}
133140

141+
private boolean requiresInverseIndex(Element pregelElement) {
142+
var maybeInterface = pregelComputation(pregelElement, bidirectionalPregelComputation);
143+
return maybeInterface.isPresent();
144+
}
145+
134146
private boolean isPregelProcedureConfig(Element pregelElement) {
135147
var config = config(pregelElement);
136148

@@ -198,7 +210,7 @@ private boolean configHasFactoryMethod(Element pregelElement) {
198210
}
199211

200212
private TypeMirror config(Element pregelElement) {
201-
return pregelComputation(pregelElement)
213+
return pregelComputation(pregelElement, basePregelComputation)
202214
.map(declaredType -> declaredType.getTypeArguments().get(0))
203215
.orElseThrow(() -> new IllegalStateException("Could not find a pregel computation"));
204216
}
@@ -218,6 +230,8 @@ interface Spec {
218230
GDSMode[] procedureModes();
219231

220232
Optional<String> description();
233+
234+
boolean requiresInverseIndex();
221235
}
222236

223237
}

pregel-proc-generator/src/main/java/org/neo4j/gds/beta/pregel/ProcedureGenerator.java

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,8 @@
3636
import org.neo4j.gds.core.utils.progress.tasks.Task;
3737
import org.neo4j.gds.executor.ExecutionMode;
3838
import org.neo4j.gds.executor.GdsCallable;
39+
import org.neo4j.gds.executor.validation.ValidationConfiguration;
40+
import org.neo4j.gds.pregel.proc.PregelBaseProc;
3941
import org.neo4j.gds.results.MemoryEstimateResult;
4042
import org.neo4j.procedure.Description;
4143
import org.neo4j.procedure.Mode;
@@ -98,6 +100,11 @@ TypeSpec typeSpec() {
98100
typeSpecBuilder.addMethod(procResultMethod());
99101
typeSpecBuilder.addMethod(newConfigMethod());
100102
typeSpecBuilder.addMethod(algorithmFactoryMethod(algorithmClassName));
103+
104+
if (pregelSpec.requiresInverseIndex()) {
105+
typeSpecBuilder.addMethod(validationConfigMethod());
106+
}
107+
101108
return typeSpecBuilder.build();
102109
}
103110

@@ -281,4 +288,16 @@ private MethodSpec algorithmFactoryMethod(ClassName algorithmClassName) {
281288
.addStatement("return $L", anonymousFactoryType)
282289
.build();
283290
}
291+
292+
private MethodSpec validationConfigMethod() {
293+
return MethodSpec.methodBuilder("validationConfig")
294+
.addAnnotation(Override.class)
295+
.addModifiers(Modifier.PUBLIC)
296+
.returns(ParameterizedTypeName.get(
297+
ClassName.get(ValidationConfiguration.class),
298+
pregelSpec.configTypeName()
299+
))
300+
.addStatement("return $T.ensureIndexValidation(log, taskRegistryFactory)", PregelBaseProc.class)
301+
.build();
302+
}
284303
}

pregel-proc-generator/src/test/java/org/neo4j/gds/beta/pregel/PregelProcessorTest.java

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,22 @@ void positiveTest(String className) {
6262
);
6363
}
6464

65+
@ParameterizedTest
66+
@ValueSource(strings = {
67+
"BidirectionalComputation"
68+
})
69+
void positiveBiTest(String className) {
70+
assertAbout(javaSource())
71+
.that(forResource(String.format("positive/%s.java", className)))
72+
.processedWith(new PregelProcessor())
73+
.compilesWithoutError()
74+
.and()
75+
.generatesSources(
76+
loadExpectedFile(formatWithLocale("expected/%sStreamProc.java", className)),
77+
loadExpectedFile(formatWithLocale("expected/%sAlgorithm.java", className))
78+
);
79+
}
80+
6581
@Test
6682
void baseClassMustBeAClass() {
6783
runNegativeTest(
Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
/*
2+
* Copyright (c) "Neo4j"
3+
* Neo4j Sweden AB [http://neo4j.com]
4+
*
5+
* This file is part of Neo4j.
6+
*
7+
* Neo4j is free software: you can redistribute it and/or modify
8+
* it under the terms of the GNU General Public License as published by
9+
* the Free Software Foundation, either version 3 of the License, or
10+
* (at your option) any later version.
11+
*
12+
* This program is distributed in the hope that it will be useful,
13+
* but WITHOUT ANY WARRANTY; without even the implied warranty of
14+
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
15+
* GNU General Public License for more details.
16+
*
17+
* You should have received a copy of the GNU General Public License
18+
* along with this program. If not, see <http://www.gnu.org/licenses/>.
19+
*/
20+
package org.neo4j.gds.beta.pregel.cc;
21+
22+
import javax.annotation.processing.Generated;
23+
import org.neo4j.gds.Algorithm;
24+
import org.neo4j.gds.api.Graph;
25+
import org.neo4j.gds.beta.pregel.Pregel;
26+
import org.neo4j.gds.beta.pregel.PregelProcedureConfig;
27+
import org.neo4j.gds.beta.pregel.PregelResult;
28+
import org.neo4j.gds.core.concurrency.Pools;
29+
import org.neo4j.gds.core.utils.TerminationFlag;
30+
import org.neo4j.gds.core.utils.progress.tasks.ProgressTracker;
31+
32+
@Generated("org.neo4j.gds.beta.pregel.PregelProcessor")
33+
public final class BidirectionalComputationAlgorithm extends Algorithm<PregelResult> {
34+
private final Pregel<PregelProcedureConfig> pregelJob;
35+
36+
BidirectionalComputationAlgorithm(Graph graph, PregelProcedureConfig configuration,
37+
ProgressTracker progressTracker) {
38+
super(progressTracker);
39+
this.pregelJob = Pregel.create(graph, configuration, new BidirectionalComputation(), Pools.DEFAULT, progressTracker);
40+
}
41+
42+
@Override
43+
public void setTerminationFlag(TerminationFlag terminationFlag) {
44+
super.setTerminationFlag(terminationFlag);
45+
pregelJob.setTerminationFlag(terminationFlag);
46+
}
47+
48+
@Override
49+
public PregelResult compute() {
50+
return pregelJob.run();
51+
}
52+
53+
@Override
54+
public void release() {
55+
pregelJob.release();
56+
}
57+
}
Lines changed: 118 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,118 @@
1+
/*
2+
* Copyright (c) "Neo4j"
3+
* Neo4j Sweden AB [http://neo4j.com]
4+
*
5+
* This file is part of Neo4j.
6+
*
7+
* Neo4j is free software: you can redistribute it and/or modify
8+
* it under the terms of the GNU General Public License as published by
9+
* the Free Software Foundation, either version 3 of the License, or
10+
* (at your option) any later version.
11+
*
12+
* This program is distributed in the hope that it will be useful,
13+
* but WITHOUT ANY WARRANTY; without even the implied warranty of
14+
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
15+
* GNU General Public License for more details.
16+
*
17+
* You should have received a copy of the GNU General Public License
18+
* along with this program. If not, see <http://www.gnu.org/licenses/>.
19+
*/
20+
package org.neo4j.gds.beta.pregel.cc;
21+
22+
import java.util.Map;
23+
import java.util.stream.Stream;
24+
import javax.annotation.processing.Generated;
25+
import org.neo4j.gds.BaseProc;
26+
import org.neo4j.gds.GraphAlgorithmFactory;
27+
import org.neo4j.gds.api.Graph;
28+
import org.neo4j.gds.api.properties.nodes.NodePropertyValues;
29+
import org.neo4j.gds.beta.pregel.Pregel;
30+
import org.neo4j.gds.beta.pregel.PregelProcedureConfig;
31+
import org.neo4j.gds.core.CypherMapWrapper;
32+
import org.neo4j.gds.core.utils.mem.MemoryEstimation;
33+
import org.neo4j.gds.core.utils.progress.tasks.ProgressTracker;
34+
import org.neo4j.gds.core.utils.progress.tasks.Task;
35+
import org.neo4j.gds.executor.ExecutionMode;
36+
import org.neo4j.gds.executor.GdsCallable;
37+
import org.neo4j.gds.executor.validation.ValidationConfiguration;
38+
import org.neo4j.gds.pregel.proc.PregelBaseProc;
39+
import org.neo4j.gds.pregel.proc.PregelStreamProc;
40+
import org.neo4j.gds.pregel.proc.PregelStreamResult;
41+
import org.neo4j.gds.results.MemoryEstimateResult;
42+
import org.neo4j.procedure.Description;
43+
import org.neo4j.procedure.Mode;
44+
import org.neo4j.procedure.Name;
45+
import org.neo4j.procedure.Procedure;
46+
47+
@GdsCallable(
48+
name = "gds.pregel.bidirectionalTest.stream",
49+
executionMode = ExecutionMode.STREAM,
50+
description = "Bidirectional Test computation description"
51+
)
52+
@Generated("org.neo4j.gds.beta.pregel.PregelProcessor")
53+
public final class BidirectionalComputationStreamProc extends PregelStreamProc<BidirectionalComputationAlgorithm, PregelProcedureConfig> {
54+
@Procedure(
55+
name = "gds.pregel.bidirectionalTest.stream",
56+
mode = Mode.READ
57+
)
58+
@Description("Bidirectional Test computation description")
59+
public Stream<PregelStreamResult> stream(@Name("graphName") String graphName,
60+
@Name(value = "configuration", defaultValue = "{}") Map<String, Object> configuration) {
61+
return stream(compute(graphName, configuration));
62+
}
63+
64+
@Procedure(
65+
name = "gds.pregel.bidirectionalTest.stream.estimate",
66+
mode = Mode.READ
67+
)
68+
@Description(BaseProc.ESTIMATE_DESCRIPTION)
69+
public Stream<MemoryEstimateResult> estimate(
70+
@Name("graphNameOrConfiguration") Object graphNameOrConfiguration,
71+
@Name("algoConfiguration") Map<String, Object> algoConfiguration) {
72+
return computeEstimate(graphNameOrConfiguration, algoConfiguration);
73+
}
74+
75+
@Override
76+
protected PregelStreamResult streamResult(long originalNodeId, long internalNodeId,
77+
NodePropertyValues nodePropertyValues) {
78+
throw new UnsupportedOperationException();
79+
}
80+
81+
@Override
82+
protected PregelProcedureConfig newConfig(String username, CypherMapWrapper config) {
83+
return PregelProcedureConfig.of(config);
84+
}
85+
86+
@Override
87+
public GraphAlgorithmFactory<BidirectionalComputationAlgorithm, PregelProcedureConfig> algorithmFactory(
88+
) {
89+
return new GraphAlgorithmFactory<BidirectionalComputationAlgorithm, PregelProcedureConfig>() {
90+
@Override
91+
public BidirectionalComputationAlgorithm build(Graph graph,
92+
PregelProcedureConfig configuration, ProgressTracker progressTracker) {
93+
return new BidirectionalComputationAlgorithm(graph, configuration, progressTracker);
94+
}
95+
96+
@Override
97+
public String taskName() {
98+
return BidirectionalComputationAlgorithm.class.getSimpleName();
99+
}
100+
101+
@Override
102+
public Task progressTask(Graph graph, PregelProcedureConfig configuration) {
103+
return Pregel.progressTask(graph, configuration);
104+
}
105+
106+
@Override
107+
public MemoryEstimation memoryEstimation(PregelProcedureConfig configuration) {
108+
var computation = new BidirectionalComputation();
109+
return Pregel.memoryEstimation(computation.schema(configuration), computation.reducer().isEmpty(), configuration.isAsynchronous());
110+
}
111+
};
112+
}
113+
114+
@Override
115+
public ValidationConfiguration<PregelProcedureConfig> validationConfig() {
116+
return PregelBaseProc.ensureIndexValidation(log, taskRegistryFactory);
117+
}
118+
}
Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
/*
2+
* Copyright (c) "Neo4j"
3+
* Neo4j Sweden AB [http://neo4j.com]
4+
*
5+
* This file is part of Neo4j.
6+
*
7+
* Neo4j is free software: you can redistribute it and/or modify
8+
* it under the terms of the GNU General Public License as published by
9+
* the Free Software Foundation, either version 3 of the License, or
10+
* (at your option) any later version.
11+
*
12+
* This program is distributed in the hope that it will be useful,
13+
* but WITHOUT ANY WARRANTY; without even the implied warranty of
14+
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
15+
* GNU General Public License for more details.
16+
*
17+
* You should have received a copy of the GNU General Public License
18+
* along with this program. If not, see <http://www.gnu.org/licenses/>.
19+
*/
20+
package org.neo4j.gds.beta.pregel.cc;
21+
22+
import org.neo4j.gds.beta.pregel.BidirectionalPregelComputation;
23+
import org.neo4j.gds.beta.pregel.Messages;
24+
import org.neo4j.gds.beta.pregel.PregelComputation;
25+
import org.neo4j.gds.beta.pregel.PregelProcedureConfig;
26+
import org.neo4j.gds.beta.pregel.context.ComputeContext;
27+
import org.neo4j.gds.beta.pregel.PregelSchema;
28+
import org.neo4j.gds.beta.pregel.annotation.GDSMode;
29+
import org.neo4j.gds.beta.pregel.annotation.PregelProcedure;
30+
import org.neo4j.gds.beta.pregel.context.ComputeContext.BidirectionalComputeContext;
31+
32+
@PregelProcedure(
33+
name = "gds.pregel.bidirectionalTest",
34+
description = "Bidirectional Test computation description",
35+
modes = {GDSMode.STREAM}
36+
)
37+
public class BidirectionalComputation implements BidirectionalPregelComputation<PregelProcedureConfig> {
38+
39+
@Override
40+
public PregelSchema schema(PregelProcedureConfig config) {
41+
return null;
42+
}
43+
44+
@Override
45+
public void compute(BidirectionalComputeContext<PregelProcedureConfig> context, Messages messages) {
46+
47+
}
48+
}

proc/pregel/build.gradle

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ dependencies {
2020
implementation project(':executor')
2121
implementation project(':proc-common')
2222
implementation project(':string-formatting')
23-
23+
implementation project(':graph-schema-api')
2424
api project(':pregel')
2525

2626
testAnnotationProcessor project(':annotations')

0 commit comments

Comments
 (0)