diff --git a/codegen/config/binder.go b/codegen/config/binder.go index a4f84fed808..514ccc6742e 100644 --- a/codegen/config/binder.go +++ b/codegen/config/binder.go @@ -151,6 +151,7 @@ func (b *Binder) PointerTo(ref *TypeReference) *TypeReference { newRef := &TypeReference{ GO: types.NewPointer(ref.GO), GQL: ref.GQL, + CastType: ref.CastType, Definition: ref.Definition, Unmarshaler: ref.Unmarshaler, Marshaler: ref.Marshaler, @@ -167,6 +168,7 @@ type TypeReference struct { GQL *ast.Type GO types.Type Target types.Type + CastType types.Type // Before calling marshalling functions cast from/to this base type Marshaler *types.Func // When using external marshalling functions this will point to the Marshal function Unmarshaler *types.Func // When using external marshalling functions this will point to the Unmarshal function IsMarshaler bool // Does the type implement graphql.Marshaler and graphql.Unmarshaler @@ -178,6 +180,7 @@ func (ref *TypeReference) Elem() *TypeReference { GO: p.Elem(), Target: ref.Target, GQL: ref.GQL, + CastType: ref.CastType, Definition: ref.Definition, Unmarshaler: ref.Unmarshaler, Marshaler: ref.Marshaler, @@ -190,6 +193,7 @@ func (ref *TypeReference) Elem() *TypeReference { GO: ref.GO.(*types.Slice).Elem(), Target: ref.Target, GQL: ref.GQL.Elem, + CastType: ref.CastType, Definition: ref.Definition, Unmarshaler: ref.Unmarshaler, Marshaler: ref.Marshaler, @@ -345,16 +349,27 @@ func (b *Binder) TypeReference(schemaType *ast.Type, bindTarget types.Type) (ret return nil, err } - fun, isFunc := obj.(*types.Func) - switch { - case isFunc: + if fun, isFunc := obj.(*types.Func); isFunc { ref.GO = fun.Type().(*types.Signature).Params().At(0).Type() ref.Marshaler = fun ref.Unmarshaler = types.NewFunc(0, fun.Pkg(), "Unmarshal"+typeName, nil) - case hasMethod(obj.Type(), "MarshalGQL") && hasMethod(obj.Type(), "UnmarshalGQL"): + } else if hasMethod(obj.Type(), "MarshalGQL") && hasMethod(obj.Type(), "UnmarshalGQL") { ref.GO = obj.Type() ref.IsMarshaler = true - default: + } else if underlying := basicUnderlying(obj.Type()); def.IsLeafType() && underlying != nil && underlying.Kind() == types.String { + // Special case for named types wrapping strings. Used by default enum implementations. + + ref.GO = obj.Type() + ref.CastType = underlying + + underlyingRef, err := b.TypeReference(&ast.Type{NamedType: "String"}, nil) + if err != nil { + return nil, err + } + + ref.Marshaler = underlyingRef.Marshaler + ref.Unmarshaler = underlyingRef.Unmarshaler + } else { ref.GO = obj.Type() } @@ -431,3 +446,19 @@ func hasMethod(it types.Type, name string) bool { } return false } + +func basicUnderlying(it types.Type) *types.Basic { + if ptr, isPtr := it.(*types.Pointer); isPtr { + it = ptr.Elem() + } + namedType, ok := it.(*types.Named) + if !ok { + return nil + } + + if basic, ok := namedType.Underlying().(*types.Basic); ok { + return basic + } + + return nil +} diff --git a/codegen/testserver/generated.go b/codegen/testserver/generated.go index 510e1110ee8..02c55293320 100644 --- a/codegen/testserver/generated.go +++ b/codegen/testserver/generated.go @@ -277,6 +277,7 @@ type ComplexityRoot struct { EnumInInput func(childComplexity int, input *InputWithEnumValue) int ErrorBubble func(childComplexity int) int Errors func(childComplexity int) int + Fallback func(childComplexity int, arg FallbackToStringEncoding) int InputNullableSlice func(childComplexity int, arg []string) int InputSlice func(childComplexity int, arg []string) int InvalidIdentifier func(childComplexity int) int @@ -461,6 +462,7 @@ type QueryResolver interface { DefaultScalar(ctx context.Context, arg string) (string, error) Slices(ctx context.Context) (*Slices, error) ScalarSlice(ctx context.Context) ([]byte, error) + Fallback(ctx context.Context, arg FallbackToStringEncoding) (FallbackToStringEncoding, error) OptionalUnion(ctx context.Context) (TestUnion, error) ValidType(ctx context.Context) (*ValidType, error) WrappedStruct(ctx context.Context) (*WrappedStruct, error) @@ -1201,6 +1203,18 @@ func (e *executableSchema) Complexity(typeName, field string, childComplexity in return e.complexity.Query.Errors(childComplexity), true + case "Query.fallback": + if e.complexity.Query.Fallback == nil { + break + } + + args, err := ec.field_Query_fallback_args(context.TODO(), rawArgs) + if err != nil { + return 0, false + } + + return e.complexity.Query.Fallback(childComplexity, args["arg"].(FallbackToStringEncoding)), true + case "Query.inputNullableSlice": if e.complexity.Query.InputNullableSlice == nil { break @@ -2186,6 +2200,16 @@ type Slices { } scalar Bytes +`, BuiltIn: false}, + {Name: "typefallback.graphql", Input: `extend type Query { + fallback(arg: FallbackToStringEncoding!): FallbackToStringEncoding! +} + +enum FallbackToStringEncoding { + A + B + C +} `, BuiltIn: false}, {Name: "useptr.graphql", Input: `type A { id: ID! @@ -2688,6 +2712,21 @@ func (ec *executionContext) field_Query_enumInInput_args(ctx context.Context, ra return args, nil } +func (ec *executionContext) field_Query_fallback_args(ctx context.Context, rawArgs map[string]interface{}) (map[string]interface{}, error) { + var err error + args := map[string]interface{}{} + var arg0 FallbackToStringEncoding + if tmp, ok := rawArgs["arg"]; ok { + ctx := graphql.WithFieldInputContext(ctx, graphql.NewFieldInputWithField("arg")) + arg0, err = ec.unmarshalNFallbackToStringEncoding2githubᚗcomᚋ99designsᚋgqlgenᚋcodegenᚋtestserverᚐFallbackToStringEncoding(ctx, tmp) + if err != nil { + return nil, err + } + } + args["arg"] = arg0 + return args, nil +} + func (ec *executionContext) field_Query_inputNullableSlice_args(ctx context.Context, rawArgs map[string]interface{}) (map[string]interface{}, error) { var err error args := map[string]interface{}{} @@ -7271,6 +7310,44 @@ func (ec *executionContext) _Query_scalarSlice(ctx context.Context, field graphq return ec.marshalNBytes2ᚕbyte(ctx, field.Selections, res) } +func (ec *executionContext) _Query_fallback(ctx context.Context, field graphql.CollectedField) (ret graphql.Marshaler) { + defer func() { + if r := recover(); r != nil { + ec.Error(ctx, ec.Recover(ctx, r)) + ret = graphql.Null + } + }() + fc := &graphql.FieldContext{ + Object: "Query", + Field: field, + Args: nil, + IsMethod: true, + } + + ctx = graphql.WithFieldContext(ctx, fc) + rawArgs := field.ArgumentMap(ec.Variables) + args, err := ec.field_Query_fallback_args(ctx, rawArgs) + if err != nil { + ec.Error(ctx, err) + return graphql.Null + } + fc.Args = args + resTmp := ec._fieldMiddleware(ctx, nil, func(rctx context.Context) (interface{}, error) { + ctx = rctx // use context from middleware stack in children + return ec.resolvers.Query().Fallback(rctx, args["arg"].(FallbackToStringEncoding)) + }) + + if resTmp == nil { + if !graphql.HasFieldError(ctx, fc) { + ec.Errorf(ctx, "must not be null") + } + return graphql.Null + } + res := resTmp.(FallbackToStringEncoding) + fc.Result = res + return ec.marshalNFallbackToStringEncoding2githubᚗcomᚋ99designsᚋgqlgenᚋcodegenᚋtestserverᚐFallbackToStringEncoding(ctx, field.Selections, res) +} + func (ec *executionContext) _Query_optionalUnion(ctx context.Context, field graphql.CollectedField) (ret graphql.Marshaler) { defer func() { if r := recover(); r != nil { @@ -12033,6 +12110,20 @@ func (ec *executionContext) _Query(ctx context.Context, sel ast.SelectionSet) gr } return res }) + case "fallback": + field := field + out.Concurrently(i, func() (res graphql.Marshaler) { + defer func() { + if r := recover(); r != nil { + ec.Error(ctx, ec.Recover(ctx, r)) + } + }() + res = ec._Query_fallback(ctx, field) + if res == graphql.Null { + atomic.AddUint32(&invalids, 1) + } + return res + }) case "optionalUnion": field := field out.Concurrently(i, func() (res graphql.Marshaler) { @@ -12859,6 +12950,22 @@ func (ec *executionContext) marshalNError2ᚖgithubᚗcomᚋ99designsᚋgqlgen return ec._Error(ctx, sel, v) } +func (ec *executionContext) unmarshalNFallbackToStringEncoding2githubᚗcomᚋ99designsᚋgqlgenᚋcodegenᚋtestserverᚐFallbackToStringEncoding(ctx context.Context, v interface{}) (FallbackToStringEncoding, error) { + tmp, err := graphql.UnmarshalString(v) + res := FallbackToStringEncoding(tmp) + return res, graphql.WrapErrorWithInputPath(ctx, err) +} + +func (ec *executionContext) marshalNFallbackToStringEncoding2githubᚗcomᚋ99designsᚋgqlgenᚋcodegenᚋtestserverᚐFallbackToStringEncoding(ctx context.Context, sel ast.SelectionSet, v FallbackToStringEncoding) graphql.Marshaler { + res := graphql.MarshalString(string(v)) + if res == graphql.Null { + if !graphql.HasFieldError(ctx, graphql.GetFieldContext(ctx)) { + ec.Errorf(ctx, "must not be null") + } + } + return res +} + func (ec *executionContext) unmarshalNID2int(ctx context.Context, v interface{}) (int, error) { res, err := graphql.UnmarshalIntID(v) return res, graphql.WrapErrorWithInputPath(ctx, err) @@ -13329,13 +13436,19 @@ func (ec *executionContext) marshalNWrappedMap2githubᚗcomᚋ99designsᚋgqlgen } func (ec *executionContext) unmarshalNWrappedScalar2githubᚗcomᚋ99designsᚋgqlgenᚋcodegenᚋtestserverᚐWrappedScalar(ctx context.Context, v interface{}) (WrappedScalar, error) { - var res WrappedScalar - err := res.UnmarshalGQL(v) + tmp, err := graphql.UnmarshalString(v) + res := WrappedScalar(tmp) return res, graphql.WrapErrorWithInputPath(ctx, err) } func (ec *executionContext) marshalNWrappedScalar2githubᚗcomᚋ99designsᚋgqlgenᚋcodegenᚋtestserverᚐWrappedScalar(ctx context.Context, sel ast.SelectionSet, v WrappedScalar) graphql.Marshaler { - return v + res := graphql.MarshalString(string(v)) + if res == graphql.Null { + if !graphql.HasFieldError(ctx, graphql.GetFieldContext(ctx)) { + ec.Errorf(ctx, "must not be null") + } + } + return res } func (ec *executionContext) marshalNWrappedSlice2githubᚗcomᚋ99designsᚋgqlgenᚋcodegenᚋtestserverᚐWrappedSlice(ctx context.Context, sel ast.SelectionSet, v WrappedSlice) graphql.Marshaler { diff --git a/codegen/testserver/resolver.go b/codegen/testserver/resolver.go index c66b8b220f2..e8f1a3a8599 100644 --- a/codegen/testserver/resolver.go +++ b/codegen/testserver/resolver.go @@ -263,6 +263,10 @@ func (r *queryResolver) ScalarSlice(ctx context.Context) ([]byte, error) { panic("not implemented") } +func (r *queryResolver) Fallback(ctx context.Context, arg FallbackToStringEncoding) (FallbackToStringEncoding, error) { + panic("not implemented") +} + func (r *queryResolver) OptionalUnion(ctx context.Context) (TestUnion, error) { panic("not implemented") } diff --git a/codegen/testserver/stub.go b/codegen/testserver/stub.go index db82863779e..b8580350bbf 100644 --- a/codegen/testserver/stub.go +++ b/codegen/testserver/stub.go @@ -92,6 +92,7 @@ type Stub struct { DefaultScalar func(ctx context.Context, arg string) (string, error) Slices func(ctx context.Context) (*Slices, error) ScalarSlice func(ctx context.Context) ([]byte, error) + Fallback func(ctx context.Context, arg FallbackToStringEncoding) (FallbackToStringEncoding, error) OptionalUnion func(ctx context.Context) (TestUnion, error) ValidType func(ctx context.Context) (*ValidType, error) WrappedStruct func(ctx context.Context) (*WrappedStruct, error) @@ -380,6 +381,9 @@ func (r *stubQuery) Slices(ctx context.Context) (*Slices, error) { func (r *stubQuery) ScalarSlice(ctx context.Context) ([]byte, error) { return r.QueryResolver.ScalarSlice(ctx) } +func (r *stubQuery) Fallback(ctx context.Context, arg FallbackToStringEncoding) (FallbackToStringEncoding, error) { + return r.QueryResolver.Fallback(ctx, arg) +} func (r *stubQuery) OptionalUnion(ctx context.Context) (TestUnion, error) { return r.QueryResolver.OptionalUnion(ctx) } diff --git a/codegen/testserver/typefallback.graphql b/codegen/testserver/typefallback.graphql new file mode 100644 index 00000000000..e1ff1a59d7c --- /dev/null +++ b/codegen/testserver/typefallback.graphql @@ -0,0 +1,9 @@ +extend type Query { + fallback(arg: FallbackToStringEncoding!): FallbackToStringEncoding! +} + +enum FallbackToStringEncoding { + A + B + C +} diff --git a/codegen/testserver/typefallback_test.go b/codegen/testserver/typefallback_test.go new file mode 100644 index 00000000000..8ebd091e9ef --- /dev/null +++ b/codegen/testserver/typefallback_test.go @@ -0,0 +1,28 @@ +package testserver + +import ( + "context" + "testing" + + "github.com/99designs/gqlgen/client" + "github.com/99designs/gqlgen/graphql/handler" + "github.com/stretchr/testify/require" +) + +func TestTypeFallback(t *testing.T) { + resolvers := &Stub{} + + c := client.New(handler.NewDefaultServer(NewExecutableSchema(Config{Resolvers: resolvers}))) + + resolvers.QueryResolver.Fallback = func(ctx context.Context, arg FallbackToStringEncoding) (FallbackToStringEncoding, error) { + return arg, nil + } + + t.Run("fallback to string passthrough", func(t *testing.T) { + var resp struct { + Fallback string + } + c.MustPost(`query { fallback(arg: A) }`, &resp) + require.Equal(t, "A", resp.Fallback) + }) +} diff --git a/codegen/testserver/wrapped_type.go b/codegen/testserver/wrapped_type.go index bd7ea31006c..d3aa63a79b4 100644 --- a/codegen/testserver/wrapped_type.go +++ b/codegen/testserver/wrapped_type.go @@ -1,28 +1,8 @@ package testserver -import ( - "fmt" - "io" - "strconv" - - "github.com/99designs/gqlgen/codegen/testserver/otherpkg" - "github.com/99designs/gqlgen/graphql" -) +import "github.com/99designs/gqlgen/codegen/testserver/otherpkg" type WrappedScalar otherpkg.Scalar type WrappedStruct otherpkg.Struct type WrappedMap otherpkg.Map type WrappedSlice otherpkg.Slice - -func (e *WrappedScalar) UnmarshalGQL(v interface{}) error { - s, err := graphql.UnmarshalString(v) - if err != nil { - return err - } - *e = WrappedScalar(s) - return nil -} - -func (e WrappedScalar) MarshalGQL(w io.Writer) { - fmt.Fprint(w, strconv.Quote(string(e))) -} diff --git a/codegen/type.gotpl b/codegen/type.gotpl index 60a87699406..cf9e9149fba 100644 --- a/codegen/type.gotpl +++ b/codegen/type.gotpl @@ -25,7 +25,12 @@ return res, nil {{- else }} {{- if $type.Unmarshaler }} - res, err := {{ $type.Unmarshaler | call }}(v) + {{- if $type.CastType }} + tmp, err := {{ $type.Unmarshaler | call }}(v) + res := {{ $type.GO | ref }}(tmp) + {{- else}} + res, err := {{ $type.Unmarshaler | call }}(v) + {{- end }} {{- if and $type.IsTargetNilable (not $type.IsNilable) }} return *res, graphql.WrapErrorWithInputPath(ctx, err) {{- else if and (not $type.IsTargetNilable) $type.IsNilable }} @@ -123,7 +128,7 @@ {{- $v = "*v" }} {{- end }} {{- if $type.GQL.NonNull }} - res := {{ $type.Marshaler | call }}({{ $v }}) + res := {{ $type.Marshaler | call }}({{- if $type.CastType }}{{ $type.CastType | ref }}({{ $v }}){{else}}{{ $v }}{{- end }}) if res == graphql.Null { if !graphql.HasFieldError(ctx, graphql.GetFieldContext(ctx)) { ec.Errorf(ctx, "must not be null") @@ -131,7 +136,7 @@ } return res {{- else }} - return {{ $type.Marshaler | call }}({{ $v }}) + return {{ $type.Marshaler | call }}({{- if $type.CastType }}{{ $type.CastType | ref }}({{ $v }}){{else}}{{ $v }}{{- end }}) {{- end }} {{- else }} return ec._{{$type.Definition.Name}}(ctx, sel, {{ if not $type.IsNilable}}&{{end}} v)