Skip to content

Commit b13c445

Browse files
committed
Make optionals with converters better supported
The generated @configuration code will now inspect the converter in/out types and do Optional wrapping/unwrapping automatically. This also includes the builder methods.
1 parent ac9dadd commit b13c445

File tree

6 files changed

+89
-31
lines changed

6 files changed

+89
-31
lines changed

algo/src/main/java/org/neo4j/gds/embeddings/hashgnn/HashGNNConfig.java

Lines changed: 10 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -51,16 +51,12 @@ default boolean heterogeneous() {
5151
}
5252

5353
@Configuration.ToMapValue("org.neo4j.gds.embeddings.hashgnn.HashGNNConfig#toMapGenerateFeaturesConfig")
54-
@Configuration.ConvertWith(method = "org.neo4j.gds.embeddings.hashgnn.HashGNNConfig#parseGenerateFeaturesConfig", inverse = Configuration.ConvertWith.INVERSE_IS_TO_MAP)
55-
default Optional<GenerateFeaturesConfig> generateFeatures() {
56-
return Optional.empty();
57-
}
54+
@Configuration.ConvertWith(method = "parseGenerateFeaturesConfig", inverse = Configuration.ConvertWith.INVERSE_IS_TO_MAP)
55+
Optional<GenerateFeaturesConfig> generateFeatures();
5856

5957
@Configuration.ToMapValue("org.neo4j.gds.embeddings.hashgnn.HashGNNConfig#toMapBinarizationConfig")
60-
@Configuration.ConvertWith(method = "org.neo4j.gds.embeddings.hashgnn.HashGNNConfig#parseBinarizationConfig", inverse = Configuration.ConvertWith.INVERSE_IS_TO_MAP)
61-
default Optional<BinarizeFeaturesConfig> binarizeFeatures() {
62-
return Optional.empty();
63-
}
58+
@Configuration.ConvertWith(method = "parseBinarizationConfig", inverse = Configuration.ConvertWith.INVERSE_IS_TO_MAP)
59+
Optional<BinarizeFeaturesConfig> binarizeFeatures();
6460

6561
@Value.Check
6662
default void validate() {
@@ -73,28 +69,22 @@ default void validate() {
7369
}
7470
}
7571

76-
static Optional<BinarizeFeaturesConfig> parseBinarizationConfig(Object o) {
77-
if (o instanceof Optional) {
78-
return (Optional<BinarizeFeaturesConfig>) o;
79-
}
80-
var cypherMapWrapper = CypherMapWrapper.create((Map<String, Object>) o);
72+
static BinarizeFeaturesConfig parseBinarizationConfig(Map<String, Object> parameter) {
73+
var cypherMapWrapper = CypherMapWrapper.create(parameter);
8174
var binarizeFeaturesConfig = new BinarizeFeaturesConfigImpl(cypherMapWrapper);
8275
cypherMapWrapper.requireOnlyKeysFrom(binarizeFeaturesConfig.configKeys());
83-
return Optional.of(binarizeFeaturesConfig);
76+
return binarizeFeaturesConfig;
8477
}
8578

8679
static Map<String, Object> toMapBinarizationConfig(BinarizeFeaturesConfig config) {
8780
return config.toMap();
8881
}
8982

90-
static Optional<GenerateFeaturesConfig> parseGenerateFeaturesConfig(Object o) {
91-
if (o instanceof Optional) {
92-
return (Optional<GenerateFeaturesConfig>) o;
93-
}
94-
var cypherMapWrapper = CypherMapWrapper.create((Map<String, Object>) o);
83+
static GenerateFeaturesConfig parseGenerateFeaturesConfig(Map<String, Object> parameter) {
84+
var cypherMapWrapper = CypherMapWrapper.create(parameter);
9585
var generateFeaturesConfig = new GenerateFeaturesConfigImpl(cypherMapWrapper);
9686
cypherMapWrapper.requireOnlyKeysFrom(generateFeaturesConfig.configKeys());
97-
return Optional.of(generateFeaturesConfig);
87+
return generateFeaturesConfig;
9888
}
9989

10090
static Map<String, Object> toMapGenerateFeaturesConfig(GenerateFeaturesConfig config) {

config-api/src/main/java/org/neo4j/gds/config/BaseConfig.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ default Map<String, Object> toMap() {
6060
return new HashMap<>();
6161
}
6262

63-
static Optional<String> trim(Optional<String> input) {
64-
return input.map(String::trim).filter(s -> !s.isEmpty());
63+
static String trim(String input) {
64+
return input.trim();
6565
}
6666
}

config-generator/src/main/java/org/neo4j/gds/proc/GenerateConfiguration.java

Lines changed: 35 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,6 @@
4444
import javax.lang.model.element.Modifier;
4545
import javax.lang.model.element.PackageElement;
4646
import javax.lang.model.element.TypeElement;
47-
import javax.lang.model.element.VariableElement;
4847
import javax.lang.model.type.DeclaredType;
4948
import javax.lang.model.type.TypeMirror;
5049
import javax.lang.model.util.Elements;
@@ -454,7 +453,7 @@ private Optional<MemberDefinition> memberDefinition(NameAllocator names, ConfigP
454453
TypeMirror targetType = method.getReturnType();
455454
ConvertWith convertWith = method.getAnnotation(ConvertWith.class);
456455
if (convertWith == null) {
457-
return memberDefinition(names, member, targetType);
456+
return memberDefinition(names, member, targetType, Optional.empty());
458457
}
459458

460459
String converter = convertWith.method().trim();
@@ -555,9 +554,22 @@ private Optional<MemberDefinition> memberDefinition(
555554

556555
if (validCandidates.size() == 1) {
557556
ExecutableElement candidate = validCandidates.get(0);
558-
VariableElement parameter = candidate.getParameters().get(0);
559557
TypeMirror currentType = currentClass.asType();
560-
return memberDefinition(names, member, parameter.asType())
558+
TypeMirror converterInputType = candidate.getParameters().get(0).asType();
559+
if (isTypeOf(Optional.class, targetType)) {
560+
return memberDefinition(names, member, targetType, Optional.of(converterInputType))
561+
.map(d -> ImmutableMemberDefinition.builder()
562+
.from(d)
563+
.addConverter(c -> CodeBlock.of(
564+
"$L.map($T::$N)",
565+
c,
566+
currentType,
567+
candidate.getSimpleName().toString()
568+
))
569+
.build()
570+
);
571+
}
572+
return memberDefinition(names, member, converterInputType, Optional.empty())
561573
.map(d -> ImmutableMemberDefinition.builder()
562574
.from(d)
563575
.addConverter(c -> CodeBlock.of(
@@ -612,6 +624,9 @@ private void validateCandidateModifiers(
612624
if (!(candidate.getParameters().size() == 1)) {
613625
invalidCandidates.add(InvalidCandidate.of(candidate, "May only accept one parameter"));
614626
}
627+
if (isTypeOf(Optional.class, targetType)) {
628+
targetType = ((DeclaredType) targetType).getTypeArguments().get(0);
629+
}
615630
if (!typeUtils.isAssignable(candidate.getReturnType(), targetType)) {
616631
invalidCandidates.add(InvalidCandidate.of(
617632
candidate,
@@ -624,7 +639,8 @@ private void validateCandidateModifiers(
624639
private Optional<MemberDefinition> memberDefinition(
625640
NameAllocator names,
626641
ConfigParser.Member member,
627-
TypeMirror targetType
642+
TypeMirror targetType,
643+
Optional<TypeMirror> converterInputType
628644
) {
629645
ImmutableMemberDefinition.Builder builder = ImmutableMemberDefinition
630646
.builder()
@@ -680,7 +696,16 @@ private Optional<MemberDefinition> memberDefinition(
680696
}
681697
}
682698

683-
if (maybeInnerType.isPresent()) {
699+
if (converterInputType.isPresent()) {
700+
TypeMirror expectedType = typeUtils.erasure(converterInputType.get());
701+
var x = typeUtils.getDeclaredType(elementUtils.getTypeElement("java.util.Optional"), expectedType);
702+
builder
703+
.methodPrefix("get")
704+
.methodName("Optional")
705+
.expectedType(CodeBlock.of("$T.class", expectedType))
706+
.expectedTypeRaw(expectedType)
707+
.expectedTypeRawWrappedInOptional(x);
708+
} else if (maybeInnerType.isPresent()) {
684709
builder
685710
.methodPrefix("get")
686711
.methodName("Optional")
@@ -760,6 +785,10 @@ interface MemberDefinition {
760785

761786
Optional<CodeBlock> expectedType();
762787

788+
Optional<TypeMirror> expectedTypeRaw();
789+
790+
Optional<TypeMirror> expectedTypeRawWrappedInOptional();
791+
763792
List<UnaryOperator<CodeBlock>> converters();
764793
}
765794

config-generator/src/main/java/org/neo4j/gds/proc/GenerateConfigurationBuilder.java

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -212,10 +212,11 @@ private static List<MethodSpec> defineConfigMapEntrySetters(
212212
.flatMap(implMember -> {
213213
var setterMethods = Stream.<MethodSpec>builder();
214214
String configKeyName = implMember.member().methodName();
215+
TypeMirror parameterType = implMember.parameterType();
215216

216217
MethodSpec.Builder setMethodBuilder = MethodSpec.methodBuilder(configKeyName)
217218
.addModifiers(Modifier.PUBLIC)
218-
.addParameter(unpackedType(implMember.parameterType()), configKeyName)
219+
.addParameter(unpackedType(implMember.expectedTypeRaw().orElse(parameterType)), configKeyName)
219220
.returns(builderClassName)
220221
.addCode(CodeBlock.builder()
221222
.addStatement(
@@ -230,12 +231,12 @@ private static List<MethodSpec> defineConfigMapEntrySetters(
230231

231232
setterMethods.add(setMethodBuilder.build());
232233

233-
if (isTypeOf(Optional.class, implMember.parameterType())) {
234+
if (isTypeOf(Optional.class, parameterType)) {
234235
String lambdaVarName = "actual" + configKeyName;
235236

236237
var optionalSetterBuilder = MethodSpec.methodBuilder(configKeyName)
237238
.addModifiers(Modifier.PUBLIC)
238-
.addParameter(TypeName.get(implMember.parameterType()), configKeyName)
239+
.addParameter(TypeName.get(implMember.expectedTypeRawWrappedInOptional().orElse(parameterType)), configKeyName)
239240
.returns(builderClassName)
240241
.addStatement(
241242
"$1N.ifPresent($2N -> this.$3N.put(\"$4L\", $2N))",

config-generator/src/test/resources/expected/Conversions.java

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
import java.util.ArrayList;
2323
import java.util.HashMap;
2424
import java.util.Map;
25+
import java.util.Optional;
2526
import java.util.stream.Collectors;
2627
import javax.annotation.processing.Generated;
2728
import org.jetbrains.annotations.NotNull;
@@ -38,6 +39,8 @@ public final class ConversionsConfig implements Conversions.MyConversion {
3839

3940
private String referenceTypeAsResult;
4041

42+
private Optional<Conversions.Foo> optional;
43+
4144
public ConversionsConfig(@NotNull CypherMapAccess config) {
4245
ArrayList<IllegalArgumentException> errors = new ArrayList<>();
4346
try {
@@ -63,6 +66,11 @@ public ConversionsConfig(@NotNull CypherMapAccess config) {
6366
} catch (IllegalArgumentException e) {
6467
errors.add(e);
6568
}
69+
try {
70+
this.optional = CypherMapAccess.failOnNull("optional", config.getOptional("optional", Map.class).map(Conversions.MyConversion::toFoo));
71+
} catch (IllegalArgumentException e) {
72+
errors.add(e);
73+
}
6674
if (!errors.isEmpty()) {
6775
if (errors.size() == 1) {
6876
throw errors.get(0);
@@ -101,6 +109,11 @@ public String referenceTypeAsResult() {
101109
return this.referenceTypeAsResult;
102110
}
103111

112+
@Override
113+
public Optional<Conversions.Foo> optional() {
114+
return this.optional;
115+
}
116+
104117
public static ConversionsConfig.Builder builder() {
105118
return new ConversionsConfig.Builder();
106119
}
@@ -118,6 +131,7 @@ public static ConversionsConfig.Builder from(Conversions.MyConversion baseConfig
118131
builder.inheritedMethod(String.valueOf(baseConfig.inheritedMethod()));
119132
builder.qualifiedMethod(String.valueOf(baseConfig.qualifiedMethod()));
120133
builder.referenceTypeAsResult(positive.Conversions.MyConversion.remove42(baseConfig.referenceTypeAsResult()));
134+
builder.optional(baseConfig.optional().map(v -> positive.Conversions.MyConversion.fromFoo(v)));
121135
return builder;
122136
}
123137

@@ -141,6 +155,16 @@ public ConversionsConfig.Builder referenceTypeAsResult(String referenceTypeAsRes
141155
return this;
142156
}
143157

158+
public ConversionsConfig.Builder optional(Map optional) {
159+
this.config.put("optional", optional));
160+
return this;
161+
}
162+
163+
public ConversionsConfig.Builder optional(Optional<Map> optional) {
164+
optional.ifPresent(actualoptional -> this.config.put("optional", actualoptional));
165+
return this;
166+
}
167+
144168
public Conversions.MyConversion build() {
145169
CypherMapWrapper config = CypherMapWrapper.create(this.config);
146170
return new ConversionsConfig(config);

config-generator/src/test/resources/positive/Conversions.java

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,8 @@
2121

2222
import org.neo4j.gds.annotation.Configuration;
2323

24+
import java.util.Map;
25+
import java.util.Optional;
2426

2527
public interface Conversions {
2628

@@ -51,6 +53,9 @@ interface MyConversion extends BaseConversion {
5153
@Configuration.ConvertWith(method = "add42", inverse = "positive.Conversions.MyConversion#remove42")
5254
String referenceTypeAsResult();
5355

56+
@Configuration.ConvertWith(method = "toFoo", inverse = "positive.Conversions.MyConversion#fromFoo")
57+
Optional<Foo> optional();
58+
5459
static int toInt(String input) {
5560
return Integer.parseInt(input);
5661
}
@@ -62,5 +67,14 @@ static String add42(String input) {
6267
static String remove42(String input) {
6368
return input.substring(0, input.length() - 2);
6469
}
70+
71+
static Foo toFoo(Map<String, Object> parameter) {
72+
return new Foo() {};
73+
}
74+
static Map<String, Object> fromFoo(Foo f) {
75+
return Map.of();
76+
}
6577
}
78+
79+
interface Foo {}
6680
}

0 commit comments

Comments
 (0)