diff --git a/src/definitions/input.ts b/src/definitions/input.ts index 0fd416c..8457439 100644 --- a/src/definitions/input.ts +++ b/src/definitions/input.ts @@ -28,7 +28,7 @@ export function buildInputObjectDefinition( return ""; } - const typeWillBeConsolidated = inputTypeHasMatchingOutputType(node, schema); + const typeWillBeConsolidated = inputTypeHasMatchingOutputType(schema, node); if (typeWillBeConsolidated) { return ""; } diff --git a/src/definitions/interface.ts b/src/definitions/interface.ts index bf302c0..3f54f8e 100644 --- a/src/definitions/interface.ts +++ b/src/definitions/interface.ts @@ -36,7 +36,12 @@ export function buildInterfaceDefinition( config, definitionNode: fieldNode, }); - const fieldDefinition = buildFieldDefinition(fieldNode, node, config); + const fieldDefinition = buildFieldDefinition( + fieldNode, + node, + schema, + config, + ); const fieldText = indent( `${fieldDefinition}: ${typeToUse.typeName}${ typeToUse.isNullable ? "?" : "" diff --git a/src/definitions/object.ts b/src/definitions/object.ts index c16d3ef..63ed3ff 100644 --- a/src/definitions/object.ts +++ b/src/definitions/object.ts @@ -59,9 +59,10 @@ ${getDataClassMembers({ node, schema, config, completableFuture: true })} } const potentialMatchingInputType = schema.getType(`${name}Input`)?.astNode; - const typeWillBeConsolidated = - potentialMatchingInputType?.kind === Kind.INPUT_OBJECT_TYPE_DEFINITION && - inputTypeHasMatchingOutputType(potentialMatchingInputType, schema); + const typeWillBeConsolidated = inputTypeHasMatchingOutputType( + schema, + potentialMatchingInputType, + ); const outputRestrictionAnnotation = typeWillBeConsolidated ? "" : "@GraphQLValidObjectLocations(locations = [GraphQLValidObjectLocations.Locations.OBJECT])\n"; @@ -85,7 +86,7 @@ function getDataClassMembers({ return node.fields ?.map((fieldNode) => { - const typeToUse = buildTypeMetadata(fieldNode.type, schema, config); + const typeMetadata = buildTypeMetadata(fieldNode.type, schema, config); const shouldOverrideField = !completableFuture && node.interfaces?.some((i) => { @@ -98,11 +99,12 @@ function getDataClassMembers({ const fieldDefinition = buildFieldDefinition( fieldNode, node, + schema, config, completableFuture, ); - const completableFutureDefinition = `java.util.concurrent.CompletableFuture<${typeToUse.typeName}${typeToUse.isNullable ? "?" : ""}>`; - const defaultDefinition = `${typeToUse.typeName}${isExternalField(fieldNode) ? (typeToUse.isNullable ? "?" : "") : typeToUse.defaultValue}`; + const completableFutureDefinition = `java.util.concurrent.CompletableFuture<${typeMetadata.typeName}${typeMetadata.isNullable ? "?" : ""}>`; + const defaultDefinition = `${typeMetadata.typeName}${isExternalField(fieldNode) ? (typeMetadata.isNullable ? "?" : "") : typeMetadata.defaultValue}`; const field = indent( `${shouldOverrideField ? "override " : ""}${fieldDefinition}: ${completableFuture ? completableFutureDefinition : defaultDefinition}`, 2, @@ -110,7 +112,7 @@ function getDataClassMembers({ const annotations = buildAnnotations({ config, definitionNode: fieldNode, - resolvedType: typeToUse, + typeMetadata, }); return `${annotations}${field}`; }) diff --git a/src/helpers/build-annotations.ts b/src/helpers/build-annotations.ts index b74d737..2c99b5e 100644 --- a/src/helpers/build-annotations.ts +++ b/src/helpers/build-annotations.ts @@ -32,25 +32,25 @@ export type DefinitionNode = export function buildAnnotations({ config, definitionNode, - resolvedType, + typeMetadata, }: { config: CodegenConfigWithDefaults; definitionNode: DefinitionNode; - resolvedType?: TypeMetadata; + typeMetadata?: TypeMetadata; }) { const description = definitionNode?.description?.value ?? ""; const descriptionAnnotation = buildDescriptionAnnotation( description, definitionNode, config, - resolvedType, + typeMetadata, ); const directiveAnnotations = buildDirectiveAnnotations( definitionNode, config, ); - const unionAnnotation = resolvedType?.unionAnnotation - ? `@${resolvedType.unionAnnotation}\n` + const unionAnnotation = typeMetadata?.unionAnnotation + ? `@${typeMetadata.unionAnnotation}\n` : ""; const annotations = [ diff --git a/src/helpers/build-description-annotation.ts b/src/helpers/build-description-annotation.ts index 24129b9..9b59dd8 100644 --- a/src/helpers/build-description-annotation.ts +++ b/src/helpers/build-description-annotation.ts @@ -8,13 +8,13 @@ export function buildDescriptionAnnotation( description: string, definitionNode: DefinitionNode, config: CodegenConfigWithDefaults, - resolvedType?: TypeMetadata, + typeMetadata?: TypeMetadata, ) { const trimmedDescription = trimDescription(description); const isDeprecatedDescription = trimmedDescription.startsWith( deprecatedDescriptionPrefix, ); - if (isDeprecatedDescription && resolvedType?.unionAnnotation) { + if (isDeprecatedDescription && typeMetadata?.unionAnnotation) { return `@GraphQLDescription("${trimmedDescription}")\n`; } else if (isDeprecatedDescription) { const descriptionValue = description.replace( @@ -36,7 +36,7 @@ export function buildDescriptionAnnotation( : ""; const trimmedDeprecatedReason = trimDescription(deprecatedReason); - if (deprecatedDirective && resolvedType?.unionAnnotation) { + if (deprecatedDirective && typeMetadata?.unionAnnotation) { return `@GraphQLDescription("${trimmedDeprecatedReason}")\n`; } else if (deprecatedDirective) { const graphqlDescription = trimmedDescription diff --git a/src/helpers/build-field-definition.ts b/src/helpers/build-field-definition.ts index ca7b22c..8d4ddcb 100644 --- a/src/helpers/build-field-definition.ts +++ b/src/helpers/build-field-definition.ts @@ -11,10 +11,10 @@ See the License for the specific language governing permissions and limitations under the License. */ -import { buildListType } from "./build-type-metadata"; -import { getFieldTypeName } from "./dependent-type-utils"; +import { buildTypeMetadata } from "./build-type-metadata"; import { FieldDefinitionNode, + GraphQLSchema, InterfaceTypeDefinitionNode, Kind, ObjectTypeDefinitionNode, @@ -26,6 +26,7 @@ import { CodegenConfigWithDefaults } from "./build-config-with-defaults"; export function buildFieldDefinition( fieldNode: FieldDefinitionNode, definitionNode: ObjectTypeDefinitionNode | InterfaceTypeDefinitionNode, + schema: GraphQLSchema, config: CodegenConfigWithDefaults, completableFuture?: boolean, ) { @@ -36,10 +37,10 @@ export function buildFieldDefinition( ? "fun" : "suspend fun" : "val"; - const existingFieldArguments = fieldNode.arguments?.map( - (arg) => - `${arg.name.value}: ${buildListType(arg.type, getFieldTypeName(arg.type))}${arg.type.kind === Kind.NON_NULL_TYPE ? "" : "?"}`, - ); + const existingFieldArguments = fieldNode.arguments?.map((arg) => { + const typeMetadata = buildTypeMetadata(arg.type, schema, config); + return `${arg.name.value}: ${typeMetadata.typeName}${arg.type.kind === Kind.NON_NULL_TYPE ? "" : "?"}`; + }); const additionalFieldArguments = config.extraResolverArguments ?.map(({ typeNames, argumentType, argumentName }) => { const shouldIncludeArg = diff --git a/src/helpers/build-type-metadata.ts b/src/helpers/build-type-metadata.ts index a704738..ab72c37 100644 --- a/src/helpers/build-type-metadata.ts +++ b/src/helpers/build-type-metadata.ts @@ -22,6 +22,10 @@ import { import { getBaseTypeNode } from "@graphql-codegen/visitor-plugin-common"; import { wrapTypeWithModifiers } from "@graphql-codegen/java-common"; import { CodegenConfigWithDefaults } from "./build-config-with-defaults"; +import { + getTypeNameWithoutInput, + inputTypeHasMatchingOutputType, +} from "./input-type-has-matching-output-type"; export interface TypeMetadata { typeName: string; @@ -73,14 +77,21 @@ export function buildTypeMetadata( ), }; } else { + const typeWillBeConsolidated = inputTypeHasMatchingOutputType( + schema, + schemaType.astNode, + ); + const typeName = typeWillBeConsolidated + ? getTypeNameWithoutInput(schemaType.name) + : schemaType.name; return { ...commonMetadata, - typeName: buildListType(typeNode, schemaType.name), + typeName: buildListType(typeNode, typeName), }; } } -export function buildListType(typeNode: TypeNode, typeName: string) { +function buildListType(typeNode: TypeNode, typeName: string) { const isNullable = typeNode.kind !== Kind.NON_NULL_TYPE; const listTypeNullableWithNullableMember = typeNode.kind == Kind.LIST_TYPE && diff --git a/src/helpers/dependent-type-utils.ts b/src/helpers/dependent-type-utils.ts index f9da022..4a3099a 100644 --- a/src/helpers/dependent-type-utils.ts +++ b/src/helpers/dependent-type-utils.ts @@ -19,6 +19,7 @@ import { TypeNode, } from "graphql"; import { CodegenConfigWithDefaults } from "./build-config-with-defaults"; +import { getBaseTypeNode } from "@graphql-codegen/visitor-plugin-common"; export function getDependentFieldTypeNames( node: TypeDefinitionNode, @@ -35,20 +36,8 @@ export function getDependentFieldTypeNames( : []; } -export function getFieldTypeName(fieldType: TypeNode) { - switch (fieldType.kind) { - case Kind.NAMED_TYPE: - return fieldType.name.value; - case Kind.LIST_TYPE: - return getFieldTypeName(fieldType.type); - case Kind.NON_NULL_TYPE: - switch (fieldType.type.kind) { - case Kind.NAMED_TYPE: - return fieldType.type.name.value; - case Kind.LIST_TYPE: - return getFieldTypeName(fieldType.type.type); - } - } +function getFieldTypeName(fieldType: TypeNode) { + return getBaseTypeNode(fieldType).name.value; } export function getDependentInterfaceNames(node: TypeDefinitionNode) { diff --git a/src/helpers/input-type-has-matching-output-type.ts b/src/helpers/input-type-has-matching-output-type.ts index 590657e..3ede0cd 100644 --- a/src/helpers/input-type-has-matching-output-type.ts +++ b/src/helpers/input-type-has-matching-output-type.ts @@ -1,19 +1,21 @@ import { Kind, TypeNode } from "graphql/index"; -import { GraphQLSchema, InputObjectTypeDefinitionNode } from "graphql"; +import { GraphQLSchema, TypeDefinitionNode } from "graphql"; export function inputTypeHasMatchingOutputType( - inputTypeNode: InputObjectTypeDefinitionNode, schema: GraphQLSchema, + typeNode?: TypeDefinitionNode | null, ) { - const typeNameWithoutInput = getTypeNameWithoutInput( - inputTypeNode.name.value, - ); + if (typeNode?.kind !== Kind.INPUT_OBJECT_TYPE_DEFINITION) { + return false; + } + + const typeNameWithoutInput = getTypeNameWithoutInput(typeNode.name.value); const matchingType = schema.getType(typeNameWithoutInput)?.astNode; const matchingTypeFields = matchingType?.kind === Kind.OBJECT_TYPE_DEFINITION ? matchingType.fields : []; - const inputFields = inputTypeNode.fields; + const inputFields = typeNode.fields; const fieldsMatch = matchingTypeFields?.every((field) => { const matchingInputField = inputFields?.find( (inputField) => inputField.name.value === field.name.value, @@ -24,7 +26,7 @@ export function inputTypeHasMatchingOutputType( return Boolean(matchingTypeFields?.length && fieldsMatch); } -function getTypeNameWithoutInput(name: string) { +export function getTypeNameWithoutInput(name: string) { return name.endsWith("Input") ? name.replace("Input", "") : name; } diff --git a/test/unit/should_consolidate_input_and_output_types/expected.kt b/test/unit/should_consolidate_input_and_output_types/expected.kt index ecb33e7..1bc73c0 100644 --- a/test/unit/should_consolidate_input_and_output_types/expected.kt +++ b/test/unit/should_consolidate_input_and_output_types/expected.kt @@ -57,3 +57,23 @@ data class MyTypeWhereFieldsDoNotMatchInput( val field: String? = null, val field2: Int? = null ) + +@GraphQLValidObjectLocations(locations = [GraphQLValidObjectLocations.Locations.OBJECT]) +data class MyTypeToConsolidateParent( + val field: MyTypeToConsolidate? = null +) + +@GraphQLValidObjectLocations(locations = [GraphQLValidObjectLocations.Locations.INPUT_OBJECT]) +data class MyTypeToConsolidateInputParent( + val field: MyTypeToConsolidate? = null +) + +@GraphQLIgnore +interface MyTypeToConsolidateParent2 { + suspend fun field(input: MyTypeToConsolidate): String? = null +} + +@GraphQLIgnore +interface MyTypeToConsolidateParent2CompletableFuture { + fun field(input: MyTypeToConsolidate): java.util.concurrent.CompletableFuture +} diff --git a/test/unit/should_consolidate_input_and_output_types/schema.graphql b/test/unit/should_consolidate_input_and_output_types/schema.graphql index 2cacc6a..820160e 100644 --- a/test/unit/should_consolidate_input_and_output_types/schema.graphql +++ b/test/unit/should_consolidate_input_and_output_types/schema.graphql @@ -84,3 +84,17 @@ input MyTypeWhereFieldsDoNotMatchInput { field: String field2: Int } + +# case where parent types reference consolidated types + +type MyTypeToConsolidateParent { + field: MyTypeToConsolidate +} + +input MyTypeToConsolidateInputParent { + field: MyTypeToConsolidateInput +} + +type MyTypeToConsolidateParent2 { + field(input: MyTypeToConsolidateInput!): String +}