2121import java .util .ArrayList ;
2222import java .util .Arrays ;
2323import java .util .Collection ;
24+ import java .util .Iterator ;
2425import java .util .LinkedHashSet ;
2526import java .util .List ;
2627import java .util .Map ;
3435import java .util .stream .Collectors ;
3536
3637import graphql .execution .DataFetcherResult ;
38+ import graphql .language .ObjectTypeDefinition ;
39+ import graphql .language .Type ;
40+ import graphql .language .TypeDefinition ;
41+ import graphql .language .TypeName ;
3742import graphql .schema .DataFetcher ;
3843import graphql .schema .DataFetchingEnvironment ;
3944import graphql .schema .FieldCoordinates ;
4045import graphql .schema .GraphQLCodeRegistry ;
4146import graphql .schema .idl .RuntimeWiring ;
47+ import graphql .schema .idl .TypeDefinitionRegistry ;
4248import org .dataloader .DataLoader ;
4349import org .reactivestreams .Publisher ;
4450import reactor .core .publisher .Flux ;
7278import org .springframework .stereotype .Controller ;
7379import org .springframework .util .Assert ;
7480import org .springframework .util .ClassUtils ;
81+ import org .springframework .util .LinkedMultiValueMap ;
82+ import org .springframework .util .MultiValueMap ;
7583import org .springframework .util .StringUtils ;
7684import org .springframework .validation .DataBinder ;
7785
@@ -118,6 +126,8 @@ public class AnnotatedControllerConfigurer
118126
119127 private final List <HandlerMethodArgumentResolver > customArgumentResolvers = new ArrayList <>(8 );
120128
129+ private final InterfaceMappingHelper interfaceMappingHelper = new InterfaceMappingHelper ();
130+
121131 @ Nullable
122132 private ValidationHelper validationHelper ;
123133
@@ -133,6 +143,11 @@ public void addCustomArgumentResolver(HandlerMethodArgumentResolver resolver) {
133143 this .customArgumentResolvers .add (resolver );
134144 }
135145
146+ @ Override
147+ public void setTypeDefinitionRegistry (TypeDefinitionRegistry registry ) {
148+ this .interfaceMappingHelper .setTypeDefinitionRegistry (registry );
149+ }
150+
136151 /**
137152 * Configure an initializer that configures the {@link DataBinder} before the binding process.
138153 * @param consumer the data binder initializer
@@ -228,19 +243,16 @@ private void addSortMethodArgumentResolver(HandlerMethodArgumentResolverComposit
228243
229244 @ Override
230245 public void configure (RuntimeWiring .Builder runtimeWiringBuilder ) {
231- detectHandlerMethods ().forEach ((info ) -> {
232- DataFetcher <?> dataFetcher ;
233- if (!info .isBatchMapping ()) {
234- dataFetcher = new SchemaMappingDataFetcher (
235- info , getArgumentResolvers (), this .validationHelper , getExceptionResolver (), getExecutor ());
236- }
237- else {
238- dataFetcher = registerBatchLoader (info );
239- }
240- FieldCoordinates coordinates = info .getCoordinates ();
241- runtimeWiringBuilder .type (coordinates .getTypeName (), (typeBuilder ) ->
242- typeBuilder .dataFetcher (coordinates .getFieldName (), dataFetcher ));
243- });
246+
247+ Set <DataFetcherMappingInfo > allInfos = detectHandlerMethods ();
248+ Set <DataFetcherMappingInfo > subTypeInfos = this .interfaceMappingHelper .removeInterfaceMappings (allInfos );
249+
250+ allInfos .forEach ((info ) -> registerDataFetcher (info , runtimeWiringBuilder ));
251+
252+ RuntimeWiring wiring = runtimeWiringBuilder .build ();
253+ subTypeInfos = this .interfaceMappingHelper .filterExistingMappings (subTypeInfos , wiring .getDataFetchers ());
254+
255+ subTypeInfos .forEach ((info ) -> registerDataFetcher (info , runtimeWiringBuilder ));
244256 }
245257
246258 @ Override
@@ -313,6 +325,20 @@ protected HandlerMethod getHandlerMethod(DataFetcherMappingInfo mappingInfo) {
313325 return mappingInfo .getHandlerMethod ();
314326 }
315327
328+ private void registerDataFetcher (DataFetcherMappingInfo info , RuntimeWiring .Builder runtimeWiringBuilder ) {
329+ DataFetcher <?> dataFetcher ;
330+ if (!info .isBatchMapping ()) {
331+ dataFetcher = new SchemaMappingDataFetcher (
332+ info , getArgumentResolvers (), this .validationHelper , getExceptionResolver (), getExecutor ());
333+ }
334+ else {
335+ dataFetcher = registerBatchLoader (info );
336+ }
337+ FieldCoordinates coordinates = info .getCoordinates ();
338+ runtimeWiringBuilder .type (coordinates .getTypeName (), (typeBuilder ) ->
339+ typeBuilder .dataFetcher (coordinates .getFieldName (), dataFetcher ));
340+ }
341+
316342 private DataFetcher <Object > registerBatchLoader (DataFetcherMappingInfo info ) {
317343 if (!info .isBatchMapping ()) {
318344 throw new IllegalArgumentException ("Not a @BatchMapping method: " + info );
@@ -506,6 +532,9 @@ public String toString() {
506532 }
507533
508534
535+ /**
536+ * {@link DataFetcher} that uses a DataLoader.
537+ */
509538 static class BatchMappingDataFetcher implements DataFetcher <Object >, SelfDescribingDataFetcher <Object > {
510539
511540 private final DataFetcherMappingInfo mappingInfo ;
@@ -538,4 +567,51 @@ public Object get(DataFetchingEnvironment env) {
538567 }
539568 }
540569
570+
571+ /**
572+ * Helper to expand schema interface mappings into object type mappings.
573+ */
574+ private static final class InterfaceMappingHelper {
575+
576+ private final MultiValueMap <String , String > interfaceMappings = new LinkedMultiValueMap <>();
577+
578+ void setTypeDefinitionRegistry (TypeDefinitionRegistry registry ) {
579+ for (TypeDefinition <?> definition : registry .types ().values ()) {
580+ if (definition instanceof ObjectTypeDefinition objectDefinition ) {
581+ for (Type <?> type : objectDefinition .getImplements ()) {
582+ this .interfaceMappings .add (((TypeName ) type ).getName (), objectDefinition .getName ());
583+ }
584+ }
585+ }
586+ }
587+
588+ Set <DataFetcherMappingInfo > removeInterfaceMappings (Set <DataFetcherMappingInfo > infos ) {
589+ Set <DataFetcherMappingInfo > subTypeMappings = new LinkedHashSet <>();
590+ Iterator <DataFetcherMappingInfo > it = infos .iterator ();
591+ while (it .hasNext ()) {
592+ DataFetcherMappingInfo info = it .next ();
593+ List <String > names = this .interfaceMappings .get (info .getTypeName ());
594+ if (names != null ) {
595+ for (String name : names ) {
596+ subTypeMappings .add (new DataFetcherMappingInfo (name , info ));
597+ }
598+ it .remove ();
599+ }
600+ }
601+ return subTypeMappings ;
602+ }
603+
604+ @ SuppressWarnings ("rawtypes" )
605+ Set <DataFetcherMappingInfo > filterExistingMappings (
606+ Set <DataFetcherMappingInfo > infos , Map <String , Map <String , DataFetcher >> dataFetchers ) {
607+
608+ return infos .stream ()
609+ .filter ((info ) -> {
610+ Map <String , DataFetcher > registrations = dataFetchers .get (info .getTypeName ());
611+ return (registrations == null || !registrations .containsKey (info .getFieldName ()));
612+ })
613+ .collect (Collectors .toSet ());
614+ }
615+ }
616+
541617}
0 commit comments