Skip to content

Commit 9f912cd

Browse files
authored
Merge pull request #6628 from Mats-SX/optionals-with-converters
Improve support for Optionals with converters as configuration parameters
2 parents 6a39685 + b6921a3 commit 9f912cd

File tree

11 files changed

+118
-54
lines changed

11 files changed

+118
-54
lines changed

algo/src/main/java/org/neo4j/gds/embeddings/hashgnn/HashGNNConfig.java

Lines changed: 10 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,6 @@
2929
import java.util.Map;
3030
import java.util.Optional;
3131

32-
@Configuration
3332
interface HashGNNConfig extends AlgoBaseConfig, FeaturePropertiesConfig, RandomSeedConfig {
3433

3534
@Configuration.IntegerRange(min = 1)
@@ -51,16 +50,12 @@ default boolean heterogeneous() {
5150
}
5251

5352
@Configuration.ToMapValue("org.neo4j.gds.embeddings.hashgnn.HashGNNConfig#toMapGenerateFeaturesConfig")
54-
@Configuration.ConvertWith(method = "org.neo4j.gds.embeddings.hashgnn.HashGNNConfig#parseGenerateFeaturesConfig", inverse = Configuration.ConvertWith.INVERSE_IS_TO_MAP)
55-
default Optional<GenerateFeaturesConfig> generateFeatures() {
56-
return Optional.empty();
57-
}
53+
@Configuration.ConvertWith(method = "parseGenerateFeaturesConfig", inverse = Configuration.ConvertWith.INVERSE_IS_TO_MAP)
54+
Optional<GenerateFeaturesConfig> generateFeatures();
5855

5956
@Configuration.ToMapValue("org.neo4j.gds.embeddings.hashgnn.HashGNNConfig#toMapBinarizationConfig")
60-
@Configuration.ConvertWith(method = "org.neo4j.gds.embeddings.hashgnn.HashGNNConfig#parseBinarizationConfig", inverse = Configuration.ConvertWith.INVERSE_IS_TO_MAP)
61-
default Optional<BinarizeFeaturesConfig> binarizeFeatures() {
62-
return Optional.empty();
63-
}
57+
@Configuration.ConvertWith(method = "parseBinarizationConfig", inverse = Configuration.ConvertWith.INVERSE_IS_TO_MAP)
58+
Optional<BinarizeFeaturesConfig> binarizeFeatures();
6459

6560
@Value.Check
6661
default void validate() {
@@ -73,28 +68,22 @@ default void validate() {
7368
}
7469
}
7570

76-
static Optional<BinarizeFeaturesConfig> parseBinarizationConfig(Object o) {
77-
if (o instanceof Optional) {
78-
return (Optional<BinarizeFeaturesConfig>) o;
79-
}
80-
var cypherMapWrapper = CypherMapWrapper.create((Map<String, Object>) o);
71+
static BinarizeFeaturesConfig parseBinarizationConfig(Map<String, Object> parameter) {
72+
var cypherMapWrapper = CypherMapWrapper.create(parameter);
8173
var binarizeFeaturesConfig = new BinarizeFeaturesConfigImpl(cypherMapWrapper);
8274
cypherMapWrapper.requireOnlyKeysFrom(binarizeFeaturesConfig.configKeys());
83-
return Optional.of(binarizeFeaturesConfig);
75+
return binarizeFeaturesConfig;
8476
}
8577

8678
static Map<String, Object> toMapBinarizationConfig(BinarizeFeaturesConfig config) {
8779
return config.toMap();
8880
}
8981

90-
static Optional<GenerateFeaturesConfig> parseGenerateFeaturesConfig(Object o) {
91-
if (o instanceof Optional) {
92-
return (Optional<GenerateFeaturesConfig>) o;
93-
}
94-
var cypherMapWrapper = CypherMapWrapper.create((Map<String, Object>) o);
82+
static GenerateFeaturesConfig parseGenerateFeaturesConfig(Map<String, Object> parameter) {
83+
var cypherMapWrapper = CypherMapWrapper.create(parameter);
9584
var generateFeaturesConfig = new GenerateFeaturesConfigImpl(cypherMapWrapper);
9685
cypherMapWrapper.requireOnlyKeysFrom(generateFeaturesConfig.configKeys());
97-
return Optional.of(generateFeaturesConfig);
86+
return generateFeaturesConfig;
9887
}
9988

10089
static Map<String, Object> toMapGenerateFeaturesConfig(GenerateFeaturesConfig config) {

algo/src/test/java/org/neo4j/gds/embeddings/hashgnn/DensifyTaskTest.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ void shouldDensify() {
3636
var nodeCount = 3;
3737

3838
var partition = new Partition(0, nodeCount);
39-
var config = HashGNNConfigImpl
39+
var config = HashGNNStreamConfigImpl
4040
.builder()
4141
.featureProperties(List.of("f1", "f2"))
4242
.embeddingDensity(4)

algo/src/test/java/org/neo4j/gds/embeddings/hashgnn/GenerateFeaturesTaskTest.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ void shouldGenerateCorrectNumberOfFeatures() {
5151

5252
var partition = new Partition(0, graph.nodeCount());
5353
var totalFeatureCount = new MutableLong(0);
54-
var config = HashGNNConfigImpl
54+
var config = HashGNNStreamConfigImpl
5555
.builder()
5656
.generateFeatures(Map.of("dimension", embeddingDimension, "densityLevel", densityLevel))
5757
.iterations(1337)

algo/src/test/java/org/neo4j/gds/embeddings/hashgnn/HashGNNConfigTest.java

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@
3131
class HashGNNConfigTest {
3232
@Test
3333
void binarizationConfigCorrectType() {
34-
var config = HashGNNConfigImpl
34+
var config = HashGNNStreamConfigImpl
3535
.builder()
3636
.featureProperties(List.of("x"))
3737
.binarizeFeatures(Map.of("dimension", 100))
@@ -44,7 +44,7 @@ void binarizationConfigCorrectType() {
4444
@Test
4545
void shouldNotAllowGeneratedAndFeatureProperties() {
4646
assertThatThrownBy(() -> {
47-
HashGNNConfigImpl
47+
HashGNNStreamConfigImpl
4848
.builder()
4949
.featureProperties(List.of("x"))
5050
.generateFeatures(Map.of("dimension", 100, "densityLevel", 2))
@@ -57,7 +57,7 @@ void shouldNotAllowGeneratedAndFeatureProperties() {
5757
@Test
5858
void requiresFeaturePropertiesIfNoGeneratedFeatures() {
5959
assertThatThrownBy(() -> {
60-
HashGNNConfigImpl
60+
HashGNNStreamConfigImpl
6161
.builder()
6262
.embeddingDensity(4)
6363
.iterations(100)
@@ -68,7 +68,7 @@ void requiresFeaturePropertiesIfNoGeneratedFeatures() {
6868
@Test
6969
void requiresDensityLevelAtMostDensity() {
7070
assertThatThrownBy(() -> {
71-
HashGNNConfigImpl
71+
HashGNNStreamConfigImpl
7272
.builder()
7373
.embeddingDensity(4)
7474
.generateFeatures(Map.of("dimension", 4, "densityLevel", 5))
@@ -80,7 +80,7 @@ void requiresDensityLevelAtMostDensity() {
8080
@Test
8181
void failsOnInvalidBinarizationKeys() {
8282
assertThatThrownBy(() -> {
83-
new HashGNNConfigImpl(CypherMapWrapper.create(
83+
new HashGNNStreamConfigImpl(CypherMapWrapper.create(
8484
Map.of(
8585
"mutateProperty", "foo",
8686
"featureProperties", List.of("x"),
@@ -97,7 +97,7 @@ void failsOnInvalidBinarizationKeys() {
9797
@Test
9898
void failsOnInvalidGenerateFeaturesKeys() {
9999
assertThatThrownBy(() -> {
100-
new HashGNNConfigImpl(CypherMapWrapper.create(
100+
new HashGNNStreamConfigImpl(CypherMapWrapper.create(
101101
Map.of(
102102
"generateFeatures", Map.of("dimension", 100, "densityElfen", 2),
103103
"embeddingDensity", 4,

algo/src/test/java/org/neo4j/gds/embeddings/hashgnn/HashGNNTest.java

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,7 @@ class HashGNNTest {
9797
@Test
9898
void binaryLowNeighborInfluence() {
9999
int embeddingDensity = 4;
100-
var config = HashGNNConfigImpl
100+
var config = HashGNNStreamConfigImpl
101101
.builder()
102102
.featureProperties(List.of("f1", "f2"))
103103
.embeddingDensity(embeddingDensity)
@@ -114,7 +114,7 @@ void binaryLowNeighborInfluence() {
114114

115115
@Test
116116
void binaryHighEmbeddingDensityHighNeighborInfluence() {
117-
var config = HashGNNConfigImpl
117+
var config = HashGNNStreamConfigImpl
118118
.builder()
119119
.featureProperties(List.of("f1", "f2"))
120120
.embeddingDensity(200)
@@ -143,7 +143,7 @@ static Stream<Arguments> determinismParams() {
143143
@ParameterizedTest
144144
@MethodSource("determinismParams")
145145
void shouldBeDeterministic(int concurrency, boolean binarize, boolean dimReduce) {
146-
var configBuilder = HashGNNConfigImpl
146+
var configBuilder = HashGNNStreamConfigImpl
147147
.builder()
148148
.featureProperties(List.of("f1", "f2"))
149149
.embeddingDensity(2)
@@ -177,7 +177,7 @@ void shouldRunOnDoublesAndBeDeterministicEqualNeighborInfluence() {
177177
// not all random seeds will give b a unique feature
178178
// this intends to test that if b has a unique feature before the first iteration, then it also has it after the first iteration
179179
// however we simulate what is before the first iteration by running with neighborInfluence 0
180-
var configBuilder = HashGNNConfigImpl
180+
var configBuilder = HashGNNStreamConfigImpl
181181
.builder()
182182
.featureProperties(List.of("f1", "f2"))
183183
.embeddingDensity(embeddingDensity)
@@ -216,7 +216,7 @@ void shouldRunOnDoublesAndBeDeterministicHighNeighborInfluence() {
216216
int embeddingDensity = 100;
217217
int binarizationDimension = 8;
218218

219-
var configBuilder = HashGNNConfigImpl
219+
var configBuilder = HashGNNStreamConfigImpl
220220
.builder()
221221
.featureProperties(List.of("f1", "f2"))
222222
.embeddingDensity(embeddingDensity)
@@ -251,7 +251,7 @@ void shouldRunOnDoublesAndBeDeterministicHighNeighborInfluence() {
251251
void outputDimensionIsApplied() {
252252
int embeddingDensity = 200;
253253
double avgDegree = binaryGraph.relationshipCount() / (double) binaryGraph.nodeCount();
254-
var config = HashGNNConfigImpl
254+
var config = HashGNNStreamConfigImpl
255255
.builder()
256256
.featureProperties(List.of("f1", "f2"))
257257
.embeddingDensity(embeddingDensity)
@@ -295,7 +295,7 @@ void shouldEstimateMemory(
295295
int concurrency,
296296
long expectedMemory
297297
) {
298-
var config = HashGNNConfigImpl
298+
var config = HashGNNStreamConfigImpl
299299
.builder()
300300
.featureProperties(List.of("f1", "f2"))
301301
.embeddingDensity(embeddingDensity)
@@ -318,7 +318,7 @@ void shouldLogProgress(boolean dense) {
318318

319319
int embeddingDensity = 200;
320320
double avgDegree = g.relationshipCount() / (double) g.nodeCount();
321-
var configBuilder = HashGNNConfigImpl
321+
var configBuilder = HashGNNStreamConfigImpl
322322
.builder()
323323
.featureProperties(List.of("f1", "f2"))
324324
.embeddingDensity(embeddingDensity)
@@ -414,7 +414,7 @@ void shouldBeDeterministicGivenSameOriginalIds() {
414414
var firstGraph = GraphFactory.create(firstIdMap, firstRelationships);
415415
var secondGraph = GraphFactory.create(secondIdMap, secondRelationships);
416416

417-
var config = HashGNNConfigImpl
417+
var config = HashGNNStreamConfigImpl
418418
.builder()
419419
.embeddingDensity(8)
420420
.generateFeatures(Map.of("dimension", embeddingDimension, "densityLevel", 2))

algo/src/test/java/org/neo4j/gds/embeddings/hashgnn/HashTaskTest.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ void shouldHash() {
3636
int NUMBER_OF_RELATIONSHIPS = 3;
3737
int EMBEDDING_DIMENSION = 10;
3838

39-
var config = HashGNNConfigImpl
39+
var config = HashGNNStreamConfigImpl
4040
.builder()
4141
.featureProperties(List.of("f1", "f2"))
4242
.iterations(ITERATIONS)

config-api/src/main/java/org/neo4j/gds/config/BaseConfig.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ default Map<String, Object> toMap() {
6060
return new HashMap<>();
6161
}
6262

63-
static Optional<String> trim(Optional<String> input) {
64-
return input.map(String::trim).filter(s -> !s.isEmpty());
63+
static String trim(String input) {
64+
return input.trim();
6565
}
6666
}

config-generator/src/main/java/org/neo4j/gds/proc/GenerateConfiguration.java

Lines changed: 43 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,6 @@
4444
import javax.lang.model.element.Modifier;
4545
import javax.lang.model.element.PackageElement;
4646
import javax.lang.model.element.TypeElement;
47-
import javax.lang.model.element.VariableElement;
4847
import javax.lang.model.type.DeclaredType;
4948
import javax.lang.model.type.TypeMirror;
5049
import javax.lang.model.util.Elements;
@@ -309,7 +308,7 @@ private void addConfigFieldToConstructor(
309308
definition.configKey()
310309
);
311310
definition.defaultProvider().ifPresent(d -> code.add(", $L", d));
312-
definition.expectedType().ifPresent(t -> code.add(", $L", t));
311+
definition.expectedTypeCodeBlock().ifPresent(t -> code.add(", $L", t));
313312
CodeBlock codeBlock = code.add(")").build();
314313
for (UnaryOperator<CodeBlock> converter : definition.converters()) {
315314
codeBlock = converter.apply(codeBlock);
@@ -454,7 +453,7 @@ private Optional<MemberDefinition> memberDefinition(NameAllocator names, ConfigP
454453
TypeMirror targetType = method.getReturnType();
455454
ConvertWith convertWith = method.getAnnotation(ConvertWith.class);
456455
if (convertWith == null) {
457-
return memberDefinition(names, member, targetType);
456+
return memberDefinition(names, member, targetType, Optional.empty());
458457
}
459458

460459
String converter = convertWith.method().trim();
@@ -555,9 +554,22 @@ private Optional<MemberDefinition> memberDefinition(
555554

556555
if (validCandidates.size() == 1) {
557556
ExecutableElement candidate = validCandidates.get(0);
558-
VariableElement parameter = candidate.getParameters().get(0);
559557
TypeMirror currentType = currentClass.asType();
560-
return memberDefinition(names, member, parameter.asType())
558+
TypeMirror converterInputType = candidate.getParameters().get(0).asType();
559+
if (isTypeOf(Optional.class, targetType)) {
560+
return memberDefinition(names, member, targetType, Optional.of(converterInputType))
561+
.map(definition -> ImmutableMemberDefinition.builder()
562+
.from(definition)
563+
.addConverter(codeBlock -> CodeBlock.of(
564+
"$L.map($T::$N)",
565+
codeBlock,
566+
currentType,
567+
candidate.getSimpleName().toString()
568+
))
569+
.build()
570+
);
571+
}
572+
return memberDefinition(names, member, converterInputType, Optional.empty())
561573
.map(d -> ImmutableMemberDefinition.builder()
562574
.from(d)
563575
.addConverter(c -> CodeBlock.of(
@@ -612,6 +624,10 @@ private void validateCandidateModifiers(
612624
if (!(candidate.getParameters().size() == 1)) {
613625
invalidCandidates.add(InvalidCandidate.of(candidate, "May only accept one parameter"));
614626
}
627+
if (isTypeOf(Optional.class, targetType)) {
628+
// for a parameter declared as Optional<T>, we want to find a converter that returns T
629+
targetType = ((DeclaredType) targetType).getTypeArguments().get(0);
630+
}
615631
if (!typeUtils.isAssignable(candidate.getReturnType(), targetType)) {
616632
invalidCandidates.add(InvalidCandidate.of(
617633
candidate,
@@ -624,7 +640,8 @@ private void validateCandidateModifiers(
624640
private Optional<MemberDefinition> memberDefinition(
625641
NameAllocator names,
626642
ConfigParser.Member member,
627-
TypeMirror targetType
643+
TypeMirror targetType,
644+
Optional<TypeMirror> converterInputType
628645
) {
629646
ImmutableMemberDefinition.Builder builder = ImmutableMemberDefinition
630647
.builder()
@@ -680,18 +697,30 @@ private Optional<MemberDefinition> memberDefinition(
680697
}
681698
}
682699

683-
if (maybeInnerType.isPresent()) {
700+
if (converterInputType.isPresent()) {
701+
TypeMirror expectedType = typeUtils.erasure(converterInputType.get());
702+
var expectedWrappedInOptional = typeUtils.getDeclaredType(
703+
elementUtils.getTypeElement("java.util.Optional"),
704+
expectedType
705+
);
684706
builder
685707
.methodPrefix("get")
686708
.methodName("Optional")
687-
.expectedType(CodeBlock.of("$T.class", maybeInnerType.get()));
709+
.expectedTypeCodeBlock(CodeBlock.of("$T.class", expectedType))
710+
.expectedType(expectedType)
711+
.expectedTypeWrappedInOptional(expectedWrappedInOptional);
712+
} else if (maybeInnerType.isPresent()) {
713+
builder
714+
.methodPrefix("get")
715+
.methodName("Optional")
716+
.expectedTypeCodeBlock(CodeBlock.of("$T.class", maybeInnerType.get()));
688717
} else {
689718
return Optional.empty();
690719
}
691720
} else {
692721
builder
693722
.methodName("Checked")
694-
.expectedType(CodeBlock.of("$T.class", ClassName.get(asTypeElement(targetType))));
723+
.expectedTypeCodeBlock(CodeBlock.of("$T.class", ClassName.get(asTypeElement(targetType))));
695724
}
696725
break;
697726
default:
@@ -758,7 +787,11 @@ interface MemberDefinition {
758787

759788
Optional<CodeBlock> defaultProvider();
760789

761-
Optional<CodeBlock> expectedType();
790+
Optional<CodeBlock> expectedTypeCodeBlock();
791+
792+
Optional<TypeMirror> expectedType();
793+
794+
Optional<TypeMirror> expectedTypeWrappedInOptional();
762795

763796
List<UnaryOperator<CodeBlock>> converters();
764797
}

config-generator/src/main/java/org/neo4j/gds/proc/GenerateConfigurationBuilder.java

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -212,10 +212,11 @@ private static List<MethodSpec> defineConfigMapEntrySetters(
212212
.flatMap(implMember -> {
213213
var setterMethods = Stream.<MethodSpec>builder();
214214
String configKeyName = implMember.member().methodName();
215+
TypeMirror parameterType = implMember.parameterType();
215216

216217
MethodSpec.Builder setMethodBuilder = MethodSpec.methodBuilder(configKeyName)
217218
.addModifiers(Modifier.PUBLIC)
218-
.addParameter(unpackedType(implMember.parameterType()), configKeyName)
219+
.addParameter(unpackedType(implMember.expectedType().orElse(parameterType)), configKeyName)
219220
.returns(builderClassName)
220221
.addCode(CodeBlock.builder()
221222
.addStatement(
@@ -230,12 +231,15 @@ private static List<MethodSpec> defineConfigMapEntrySetters(
230231

231232
setterMethods.add(setMethodBuilder.build());
232233

233-
if (isTypeOf(Optional.class, implMember.parameterType())) {
234+
if (isTypeOf(Optional.class, parameterType)) {
234235
String lambdaVarName = "actual" + configKeyName;
235236

236237
var optionalSetterBuilder = MethodSpec.methodBuilder(configKeyName)
237238
.addModifiers(Modifier.PUBLIC)
238-
.addParameter(TypeName.get(implMember.parameterType()), configKeyName)
239+
.addParameter(
240+
TypeName.get(implMember.expectedTypeWrappedInOptional().orElse(parameterType)),
241+
configKeyName
242+
)
239243
.returns(builderClassName)
240244
.addStatement(
241245
"$1N.ifPresent($2N -> this.$3N.put(\"$4L\", $2N))",

0 commit comments

Comments
 (0)