Skip to content

Commit 8e8589d

Browse files
author
Thomas Draier
committed
Fixed reusage of input types
1 parent 6bd6feb commit 8e8589d

File tree

5 files changed

+185
-6
lines changed

5 files changed

+185
-6
lines changed

src/main/java/graphql/annotations/DefaultTypeFunction.java

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -260,7 +260,12 @@ public boolean canBuildType(Class<?> aClass, AnnotatedType annotatedType) {
260260

261261
@Override
262262
public GraphQLType buildType(String typeName, Class<?> aClass, AnnotatedType annotatedType) {
263-
return annotationsProcessor.getOutputTypeOrRef(aClass);
263+
try {
264+
return annotationsProcessor.getOutputTypeOrRef(aClass);
265+
} catch (ClassCastException e) {
266+
// Also try to resolve to input object
267+
return annotationsProcessor.getInputObject(aClass);
268+
}
264269
}
265270
}
266271

src/main/java/graphql/annotations/GraphQLAnnotations.java

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ public class GraphQLAnnotations implements GraphQLAnnotationsProcessor {
5151

5252
private static final List<Class> TYPES_FOR_CONNECTION = Arrays.asList(GraphQLObjectType.class, GraphQLInterfaceType.class, GraphQLUnionType.class, GraphQLTypeReference.class);
5353

54-
private Map<String, graphql.schema.GraphQLOutputType> typeRegistry = new HashMap<>();
54+
private Map<String, graphql.schema.GraphQLType> typeRegistry = new HashMap<>();
5555
private Map<Class<?>, Set<Class<?>>> extensionsTypeRegistry = new HashMap<>();
5656
private final Stack<String> processing = new Stack<>();
5757
private Relay relay = new Relay();
@@ -256,7 +256,7 @@ public GraphQLOutputType getOutputType(Class<?> object) throws GraphQLAnnotation
256256
// all type instances to be unique singletons
257257
String typeName = getTypeName(object);
258258

259-
GraphQLOutputType type = typeRegistry.get(typeName);
259+
GraphQLOutputType type = (GraphQLOutputType) typeRegistry.get(typeName);
260260
if (type != null) { // type already exists, do not build a new new one
261261
return type;
262262
}
@@ -627,8 +627,9 @@ protected GraphQLFieldDefinition getField(Method method) throws GraphQLAnnotatio
627627
Class<?> t = parameter.getType();
628628
graphql.schema.GraphQLType graphQLType = finalTypeFunction.buildType(t, parameter.getAnnotatedType());
629629
if (graphQLType instanceof GraphQLObjectType) {
630-
GraphQLInputObjectType inputObject = getInputObject((GraphQLObjectType) graphQLType, "input");
630+
GraphQLInputObjectType inputObject = getInputObject((GraphQLObjectType) graphQLType, "");
631631
graphQLType = inputObject;
632+
typeRegistry.put(inputObject.getName(), inputObject);
632633
}
633634
return getArgument(parameter, graphQLType);
634635
}).collect(Collectors.toList());
@@ -695,6 +696,18 @@ protected static GraphQLFieldDefinition field(Method method) throws Instantiatio
695696

696697
}
697698

699+
public GraphQLInputObjectType getInputObject(Class<?> object) {
700+
String typeName = getTypeName(object);
701+
if (typeRegistry.containsKey(typeName)) {
702+
return (GraphQLInputObjectType) typeRegistry.get(typeName);
703+
} else {
704+
graphql.schema.GraphQLType graphQLType = getObject(object);
705+
GraphQLInputObjectType inputObject = getInputObject((GraphQLObjectType) graphQLType, "");
706+
typeRegistry.put(inputObject.getName(), inputObject);
707+
return inputObject;
708+
}
709+
}
710+
698711
@Override
699712
public GraphQLInputObjectType getInputObject(GraphQLObjectType graphQLType, String newNamePrefix) {
700713
GraphQLObjectType object = graphQLType;
@@ -780,7 +793,7 @@ public static void register(TypeFunction typeFunction) {
780793
getInstance().registerType(typeFunction);
781794
}
782795

783-
public Map<String, graphql.schema.GraphQLOutputType> getTypeRegistry() {
796+
public Map<String, graphql.schema.GraphQLType> getTypeRegistry() {
784797
return typeRegistry;
785798
}
786799

src/main/java/graphql/annotations/GraphQLAnnotationsProcessor.java

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -105,6 +105,17 @@ public interface GraphQLAnnotationsProcessor {
105105
*/
106106
GraphQLObjectType.Builder getObjectBuilder(Class<?> object) throws GraphQLAnnotationsException;
107107

108+
/**
109+
* This will examine the object class and return a {@link GraphQLInputType} representation
110+
*
111+
* @param object the object class to examine
112+
*
113+
* @return a {@link GraphQLInputType} that represents that object class
114+
*
115+
* @throws GraphQLAnnotationsException if the object class cannot be examined
116+
*/
117+
GraphQLInputObjectType getInputObject(Class<?> object) throws GraphQLAnnotationsException;
118+
108119
/**
109120
* This will turn a {@link GraphQLObjectType} into a corresponding {@link GraphQLInputObjectType}
110121
*

src/main/java/graphql/annotations/MethodDataFetcher.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616

1717
import graphql.schema.DataFetcher;
1818
import graphql.schema.DataFetchingEnvironment;
19+
import graphql.schema.GraphQLInputObjectType;
1920
import graphql.schema.GraphQLObjectType;
2021

2122
import java.lang.reflect.Constructor;
@@ -78,7 +79,7 @@ private Object[] invocationArgs(DataFetchingEnvironment environment) {
7879
continue;
7980
}
8081
graphql.schema.GraphQLType graphQLType = typeFunction.buildType(paramType, p.getAnnotatedType());
81-
if (graphQLType instanceof GraphQLObjectType) {
82+
if (graphQLType instanceof GraphQLInputObjectType) {
8283
Constructor<?> constructor = constructor(paramType, HashMap.class);
8384
result.add(constructNewInstance(constructor, envArgs.next()));
8485

Lines changed: 149 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,149 @@
1+
/**
2+
* Copyright 2016 Yurii Rashkovskii
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
*/
15+
package graphql.annotations;
16+
17+
import graphql.ExecutionResult;
18+
import graphql.GraphQL;
19+
import graphql.TypeResolutionEnvironment;
20+
import graphql.schema.*;
21+
import org.testng.annotations.Test;
22+
23+
import java.util.Collections;
24+
import java.util.HashMap;
25+
import java.util.List;
26+
import java.util.Map;
27+
28+
import static graphql.schema.GraphQLSchema.newSchema;
29+
import static org.testng.Assert.assertEquals;
30+
import static org.testng.Assert.assertTrue;
31+
32+
@SuppressWarnings("unchecked")
33+
public class GraphQLInputTest {
34+
35+
public static class Resolver implements TypeResolver {
36+
37+
@Override
38+
public GraphQLObjectType getType(TypeResolutionEnvironment env) {
39+
try {
40+
return GraphQLAnnotations.object(TestObject.class);
41+
} catch (GraphQLAnnotationsException e) {
42+
return null;
43+
}
44+
}
45+
}
46+
47+
static class InputObject {
48+
public InputObject(HashMap map) {
49+
key = (String) map.get("key");
50+
}
51+
52+
@GraphQLField
53+
private String key;
54+
}
55+
56+
static class RecursiveInputObject {
57+
public RecursiveInputObject(HashMap map) {
58+
key = (String) map.get("key");
59+
if (map.containsKey("rec")) {
60+
rec = new RecursiveInputObject((HashMap) map.get("rec"));
61+
}
62+
}
63+
64+
@GraphQLField
65+
private String key;
66+
67+
@GraphQLField
68+
private RecursiveInputObject rec;
69+
}
70+
71+
@GraphQLTypeResolver(Resolver.class)
72+
interface TestIface {
73+
@GraphQLField
74+
String value(InputObject input);
75+
}
76+
77+
static class TestObject implements TestIface {
78+
79+
@Override
80+
public String value(InputObject input) {
81+
return input.key + "a";
82+
}
83+
}
84+
85+
static class TestObjectRec {
86+
@GraphQLField
87+
public String value(RecursiveInputObject input) {
88+
return (input.rec != null ? ("rec"+input.rec.key) : input.key) + "a";
89+
}
90+
}
91+
92+
static class Query {
93+
@GraphQLField
94+
public TestIface object() {
95+
return new TestObject();
96+
};
97+
}
98+
99+
static class QueryRecursion {
100+
@GraphQLField
101+
public TestObjectRec object() {
102+
return new TestObjectRec();
103+
};
104+
}
105+
106+
static class QueryIface {
107+
@GraphQLField
108+
public TestObject iface() {
109+
return new TestObject();
110+
}
111+
}
112+
113+
114+
@Test
115+
public void query() {
116+
GraphQLSchema schema = newSchema().query(GraphQLAnnotations.object(Query.class)).build();
117+
118+
GraphQL graphQL = GraphQL.newGraphQL(schema).build();
119+
ExecutionResult result = graphQL.execute("{ object { value(input:{key:\"test\"}) } }", new Query());
120+
assertTrue(result.getErrors().isEmpty());
121+
assertEquals(((Map<String, Map<String, String>>) result.getData()).get("object").get("value"), "testa");
122+
}
123+
124+
@Test
125+
public void queryWithRecursion() {
126+
GraphQLSchema schema = newSchema().query(GraphQLAnnotations.object(QueryRecursion.class)).build();
127+
128+
GraphQL graphQL = GraphQL.newGraphQL(schema).build();
129+
ExecutionResult result = graphQL.execute("{ object { value(input:{key:\"test\"}) } }", new QueryRecursion());
130+
assertTrue(result.getErrors().isEmpty());
131+
assertEquals(((Map<String, Map<String, String>>) result.getData()).get("object").get("value"), "testa");
132+
133+
result = graphQL.execute("{ object { value(input:{rec:{key:\"test\"}}) } }", new QueryRecursion());
134+
assertTrue(result.getErrors().isEmpty());
135+
assertEquals(((Map<String, Map<String, String>>) result.getData()).get("object").get("value"), "rectesta");
136+
}
137+
138+
@Test
139+
public void queryWithInterface() {
140+
GraphQLSchema schema = newSchema().query(GraphQLAnnotations.object(QueryIface.class)).build(Collections.singleton(GraphQLAnnotations.object(TestObject.class)));
141+
142+
GraphQL graphQL = GraphQL.newGraphQL(schema).build();
143+
ExecutionResult result = graphQL.execute("{ iface { value(input:{key:\"test\"}) } }", new QueryIface());
144+
assertTrue(result.getErrors().isEmpty());
145+
assertEquals(((Map<String, Map<String, String>>) result.getData()).get("iface").get("value"), "testa");
146+
}
147+
148+
149+
}

0 commit comments

Comments
 (0)