diff --git a/codegen/config/binder.go b/codegen/config/binder.go index a4f84fed808..514ccc6742e 100644 --- a/codegen/config/binder.go +++ b/codegen/config/binder.go @@ -151,6 +151,7 @@ func (b *Binder) PointerTo(ref *TypeReference) *TypeReference { newRef := &TypeReference{ GO: types.NewPointer(ref.GO), GQL: ref.GQL, + CastType: ref.CastType, Definition: ref.Definition, Unmarshaler: ref.Unmarshaler, Marshaler: ref.Marshaler, @@ -167,6 +168,7 @@ type TypeReference struct { GQL *ast.Type GO types.Type Target types.Type + CastType types.Type // Before calling marshalling functions cast from/to this base type Marshaler *types.Func // When using external marshalling functions this will point to the Marshal function Unmarshaler *types.Func // When using external marshalling functions this will point to the Unmarshal function IsMarshaler bool // Does the type implement graphql.Marshaler and graphql.Unmarshaler @@ -178,6 +180,7 @@ func (ref *TypeReference) Elem() *TypeReference { GO: p.Elem(), Target: ref.Target, GQL: ref.GQL, + CastType: ref.CastType, Definition: ref.Definition, Unmarshaler: ref.Unmarshaler, Marshaler: ref.Marshaler, @@ -190,6 +193,7 @@ func (ref *TypeReference) Elem() *TypeReference { GO: ref.GO.(*types.Slice).Elem(), Target: ref.Target, GQL: ref.GQL.Elem, + CastType: ref.CastType, Definition: ref.Definition, Unmarshaler: ref.Unmarshaler, Marshaler: ref.Marshaler, @@ -345,16 +349,27 @@ func (b *Binder) TypeReference(schemaType *ast.Type, bindTarget types.Type) (ret return nil, err } - fun, isFunc := obj.(*types.Func) - switch { - case isFunc: + if fun, isFunc := obj.(*types.Func); isFunc { ref.GO = fun.Type().(*types.Signature).Params().At(0).Type() ref.Marshaler = fun ref.Unmarshaler = types.NewFunc(0, fun.Pkg(), "Unmarshal"+typeName, nil) - case hasMethod(obj.Type(), "MarshalGQL") && hasMethod(obj.Type(), "UnmarshalGQL"): + } else if hasMethod(obj.Type(), "MarshalGQL") && hasMethod(obj.Type(), "UnmarshalGQL") { ref.GO = obj.Type() ref.IsMarshaler = true - default: + } else if underlying := basicUnderlying(obj.Type()); def.IsLeafType() && underlying != nil && underlying.Kind() == types.String { + // Special case for named types wrapping strings. Used by default enum implementations. + + ref.GO = obj.Type() + ref.CastType = underlying + + underlyingRef, err := b.TypeReference(&ast.Type{NamedType: "String"}, nil) + if err != nil { + return nil, err + } + + ref.Marshaler = underlyingRef.Marshaler + ref.Unmarshaler = underlyingRef.Unmarshaler + } else { ref.GO = obj.Type() } @@ -431,3 +446,19 @@ func hasMethod(it types.Type, name string) bool { } return false } + +func basicUnderlying(it types.Type) *types.Basic { + if ptr, isPtr := it.(*types.Pointer); isPtr { + it = ptr.Elem() + } + namedType, ok := it.(*types.Named) + if !ok { + return nil + } + + if basic, ok := namedType.Underlying().(*types.Basic); ok { + return basic + } + + return nil +} diff --git a/codegen/testserver/generated.go b/codegen/testserver/generated.go index 510e1110ee8..bc826d997b8 100644 --- a/codegen/testserver/generated.go +++ b/codegen/testserver/generated.go @@ -53,6 +53,7 @@ type ResolverRoot interface { User() UserResolver WrappedMap() WrappedMapResolver WrappedSlice() WrappedSliceResolver + WrappedStruct() WrappedStructResolver } type DirectiveRoot struct { @@ -277,6 +278,7 @@ type ComplexityRoot struct { EnumInInput func(childComplexity int, input *InputWithEnumValue) int ErrorBubble func(childComplexity int) int Errors func(childComplexity int) int + Fallback func(childComplexity int, arg FallbackToStringEncoding) int InputNullableSlice func(childComplexity int, arg []string) int InputSlice func(childComplexity int, arg []string) int InvalidIdentifier func(childComplexity int) int @@ -357,6 +359,7 @@ type ComplexityRoot struct { } WrappedStruct struct { + Desc func(childComplexity int) int Name func(childComplexity int) int } @@ -461,6 +464,7 @@ type QueryResolver interface { DefaultScalar(ctx context.Context, arg string) (string, error) Slices(ctx context.Context) (*Slices, error) ScalarSlice(ctx context.Context) ([]byte, error) + Fallback(ctx context.Context, arg FallbackToStringEncoding) (FallbackToStringEncoding, error) OptionalUnion(ctx context.Context) (TestUnion, error) ValidType(ctx context.Context) (*ValidType, error) WrappedStruct(ctx context.Context) (*WrappedStruct, error) @@ -486,6 +490,10 @@ type WrappedMapResolver interface { type WrappedSliceResolver interface { Get(ctx context.Context, obj WrappedSlice, idx int) (string, error) } +type WrappedStructResolver interface { + Name(ctx context.Context, obj *WrappedStruct) (WrappedScalar, error) + Desc(ctx context.Context, obj *WrappedStruct) (*WrappedScalar, error) +} type executableSchema struct { resolvers ResolverRoot @@ -1201,6 +1209,18 @@ func (e *executableSchema) Complexity(typeName, field string, childComplexity in return e.complexity.Query.Errors(childComplexity), true + case "Query.fallback": + if e.complexity.Query.Fallback == nil { + break + } + + args, err := ec.field_Query_fallback_args(context.TODO(), rawArgs) + if err != nil { + return 0, false + } + + return e.complexity.Query.Fallback(childComplexity, args["arg"].(FallbackToStringEncoding)), true + case "Query.inputNullableSlice": if e.complexity.Query.InputNullableSlice == nil { break @@ -1668,6 +1688,13 @@ func (e *executableSchema) Complexity(typeName, field string, childComplexity in return e.complexity.WrappedSlice.Get(childComplexity, args["idx"].(int)), true + case "WrappedStruct.desc": + if e.complexity.WrappedStruct.Desc == nil { + break + } + + return e.complexity.WrappedStruct.Desc(childComplexity), true + case "WrappedStruct.name": if e.complexity.WrappedStruct.Name == nil { break @@ -2186,6 +2213,16 @@ type Slices { } scalar Bytes +`, BuiltIn: false}, + {Name: "typefallback.graphql", Input: `extend type Query { + fallback(arg: FallbackToStringEncoding!): FallbackToStringEncoding! +} + +enum FallbackToStringEncoding { + A + B + C +} `, BuiltIn: false}, {Name: "useptr.graphql", Input: `type A { id: ID! @@ -2298,7 +2335,7 @@ extend type Query { wrappedSlice: WrappedSlice! } -type WrappedStruct { name: String! } +type WrappedStruct { name: WrappedScalar!, desc: WrappedScalar } scalar WrappedScalar type WrappedMap { get(key: String!): String! } type WrappedSlice { get(idx: Int!): String! } @@ -2688,6 +2725,21 @@ func (ec *executionContext) field_Query_enumInInput_args(ctx context.Context, ra return args, nil } +func (ec *executionContext) field_Query_fallback_args(ctx context.Context, rawArgs map[string]interface{}) (map[string]interface{}, error) { + var err error + args := map[string]interface{}{} + var arg0 FallbackToStringEncoding + if tmp, ok := rawArgs["arg"]; ok { + ctx := graphql.WithFieldInputContext(ctx, graphql.NewFieldInputWithField("arg")) + arg0, err = ec.unmarshalNFallbackToStringEncoding2githubᚗcomᚋ99designsᚋgqlgenᚋcodegenᚋtestserverᚐFallbackToStringEncoding(ctx, tmp) + if err != nil { + return nil, err + } + } + args["arg"] = arg0 + return args, nil +} + func (ec *executionContext) field_Query_inputNullableSlice_args(ctx context.Context, rawArgs map[string]interface{}) (map[string]interface{}, error) { var err error args := map[string]interface{}{} @@ -7271,6 +7323,44 @@ func (ec *executionContext) _Query_scalarSlice(ctx context.Context, field graphq return ec.marshalNBytes2ᚕbyte(ctx, field.Selections, res) } +func (ec *executionContext) _Query_fallback(ctx context.Context, field graphql.CollectedField) (ret graphql.Marshaler) { + defer func() { + if r := recover(); r != nil { + ec.Error(ctx, ec.Recover(ctx, r)) + ret = graphql.Null + } + }() + fc := &graphql.FieldContext{ + Object: "Query", + Field: field, + Args: nil, + IsMethod: true, + } + + ctx = graphql.WithFieldContext(ctx, fc) + rawArgs := field.ArgumentMap(ec.Variables) + args, err := ec.field_Query_fallback_args(ctx, rawArgs) + if err != nil { + ec.Error(ctx, err) + return graphql.Null + } + fc.Args = args + resTmp := ec._fieldMiddleware(ctx, nil, func(rctx context.Context) (interface{}, error) { + ctx = rctx // use context from middleware stack in children + return ec.resolvers.Query().Fallback(rctx, args["arg"].(FallbackToStringEncoding)) + }) + + if resTmp == nil { + if !graphql.HasFieldError(ctx, fc) { + ec.Errorf(ctx, "must not be null") + } + return graphql.Null + } + res := resTmp.(FallbackToStringEncoding) + fc.Result = res + return ec.marshalNFallbackToStringEncoding2githubᚗcomᚋ99designsᚋgqlgenᚋcodegenᚋtestserverᚐFallbackToStringEncoding(ctx, field.Selections, res) +} + func (ec *executionContext) _Query_optionalUnion(ctx context.Context, field graphql.CollectedField) (ret graphql.Marshaler) { defer func() { if r := recover(); r != nil { @@ -8394,13 +8484,13 @@ func (ec *executionContext) _WrappedStruct_name(ctx context.Context, field graph Object: "WrappedStruct", Field: field, Args: nil, - IsMethod: false, + IsMethod: true, } ctx = graphql.WithFieldContext(ctx, fc) resTmp := ec._fieldMiddleware(ctx, obj, func(rctx context.Context) (interface{}, error) { ctx = rctx // use context from middleware stack in children - return obj.Name, nil + return ec.resolvers.WrappedStruct().Name(rctx, obj) }) if resTmp == nil { @@ -8409,9 +8499,37 @@ func (ec *executionContext) _WrappedStruct_name(ctx context.Context, field graph } return graphql.Null } - res := resTmp.(string) + res := resTmp.(WrappedScalar) fc.Result = res - return ec.marshalNString2string(ctx, field.Selections, res) + return ec.marshalNWrappedScalar2githubᚗcomᚋ99designsᚋgqlgenᚋcodegenᚋtestserverᚐWrappedScalar(ctx, field.Selections, res) +} + +func (ec *executionContext) _WrappedStruct_desc(ctx context.Context, field graphql.CollectedField, obj *WrappedStruct) (ret graphql.Marshaler) { + defer func() { + if r := recover(); r != nil { + ec.Error(ctx, ec.Recover(ctx, r)) + ret = graphql.Null + } + }() + fc := &graphql.FieldContext{ + Object: "WrappedStruct", + Field: field, + Args: nil, + IsMethod: true, + } + + ctx = graphql.WithFieldContext(ctx, fc) + resTmp := ec._fieldMiddleware(ctx, obj, func(rctx context.Context) (interface{}, error) { + ctx = rctx // use context from middleware stack in children + return ec.resolvers.WrappedStruct().Desc(rctx, obj) + }) + + if resTmp == nil { + return graphql.Null + } + res := resTmp.(*WrappedScalar) + fc.Result = res + return ec.marshalOWrappedScalar2ᚖgithubᚗcomᚋ99designsᚋgqlgenᚋcodegenᚋtestserverᚐWrappedScalar(ctx, field.Selections, res) } func (ec *executionContext) _XXIt_id(ctx context.Context, field graphql.CollectedField, obj *XXIt) (ret graphql.Marshaler) { @@ -12033,6 +12151,20 @@ func (ec *executionContext) _Query(ctx context.Context, sel ast.SelectionSet) gr } return res }) + case "fallback": + field := field + out.Concurrently(i, func() (res graphql.Marshaler) { + defer func() { + if r := recover(); r != nil { + ec.Error(ctx, ec.Recover(ctx, r)) + } + }() + res = ec._Query_fallback(ctx, field) + if res == graphql.Null { + atomic.AddUint32(&invalids, 1) + } + return res + }) case "optionalUnion": field := field out.Concurrently(i, func() (res graphql.Marshaler) { @@ -12396,10 +12528,30 @@ func (ec *executionContext) _WrappedStruct(ctx context.Context, sel ast.Selectio case "__typename": out.Values[i] = graphql.MarshalString("WrappedStruct") case "name": - out.Values[i] = ec._WrappedStruct_name(ctx, field, obj) - if out.Values[i] == graphql.Null { - invalids++ - } + field := field + out.Concurrently(i, func() (res graphql.Marshaler) { + defer func() { + if r := recover(); r != nil { + ec.Error(ctx, ec.Recover(ctx, r)) + } + }() + res = ec._WrappedStruct_name(ctx, field, obj) + if res == graphql.Null { + atomic.AddUint32(&invalids, 1) + } + return res + }) + case "desc": + field := field + out.Concurrently(i, func() (res graphql.Marshaler) { + defer func() { + if r := recover(); r != nil { + ec.Error(ctx, ec.Recover(ctx, r)) + } + }() + res = ec._WrappedStruct_desc(ctx, field, obj) + return res + }) default: panic("unknown field " + strconv.Quote(field.Name)) } @@ -12859,6 +13011,22 @@ func (ec *executionContext) marshalNError2ᚖgithubᚗcomᚋ99designsᚋgqlgen return ec._Error(ctx, sel, v) } +func (ec *executionContext) unmarshalNFallbackToStringEncoding2githubᚗcomᚋ99designsᚋgqlgenᚋcodegenᚋtestserverᚐFallbackToStringEncoding(ctx context.Context, v interface{}) (FallbackToStringEncoding, error) { + tmp, err := graphql.UnmarshalString(v) + res := FallbackToStringEncoding(tmp) + return res, graphql.WrapErrorWithInputPath(ctx, err) +} + +func (ec *executionContext) marshalNFallbackToStringEncoding2githubᚗcomᚋ99designsᚋgqlgenᚋcodegenᚋtestserverᚐFallbackToStringEncoding(ctx context.Context, sel ast.SelectionSet, v FallbackToStringEncoding) graphql.Marshaler { + res := graphql.MarshalString(string(v)) + if res == graphql.Null { + if !graphql.HasFieldError(ctx, graphql.GetFieldContext(ctx)) { + ec.Errorf(ctx, "must not be null") + } + } + return res +} + func (ec *executionContext) unmarshalNID2int(ctx context.Context, v interface{}) (int, error) { res, err := graphql.UnmarshalIntID(v) return res, graphql.WrapErrorWithInputPath(ctx, err) @@ -13329,13 +13497,19 @@ func (ec *executionContext) marshalNWrappedMap2githubᚗcomᚋ99designsᚋgqlgen } func (ec *executionContext) unmarshalNWrappedScalar2githubᚗcomᚋ99designsᚋgqlgenᚋcodegenᚋtestserverᚐWrappedScalar(ctx context.Context, v interface{}) (WrappedScalar, error) { - var res WrappedScalar - err := res.UnmarshalGQL(v) + tmp, err := graphql.UnmarshalString(v) + res := WrappedScalar(tmp) return res, graphql.WrapErrorWithInputPath(ctx, err) } func (ec *executionContext) marshalNWrappedScalar2githubᚗcomᚋ99designsᚋgqlgenᚋcodegenᚋtestserverᚐWrappedScalar(ctx context.Context, sel ast.SelectionSet, v WrappedScalar) graphql.Marshaler { - return v + res := graphql.MarshalString(string(v)) + if res == graphql.Null { + if !graphql.HasFieldError(ctx, graphql.GetFieldContext(ctx)) { + ec.Errorf(ctx, "must not be null") + } + } + return res } func (ec *executionContext) marshalNWrappedSlice2githubᚗcomᚋ99designsᚋgqlgenᚋcodegenᚋtestserverᚐWrappedSlice(ctx context.Context, sel ast.SelectionSet, v WrappedSlice) graphql.Marshaler { @@ -14283,6 +14457,22 @@ func (ec *executionContext) marshalOValidType2ᚖgithubᚗcomᚋ99designsᚋgqlg return ec._ValidType(ctx, sel, v) } +func (ec *executionContext) unmarshalOWrappedScalar2ᚖgithubᚗcomᚋ99designsᚋgqlgenᚋcodegenᚋtestserverᚐWrappedScalar(ctx context.Context, v interface{}) (*WrappedScalar, error) { + if v == nil { + return nil, nil + } + tmp, err := graphql.UnmarshalString(v) + res := WrappedScalar(tmp) + return &res, graphql.WrapErrorWithInputPath(ctx, err) +} + +func (ec *executionContext) marshalOWrappedScalar2ᚖgithubᚗcomᚋ99designsᚋgqlgenᚋcodegenᚋtestserverᚐWrappedScalar(ctx context.Context, sel ast.SelectionSet, v *WrappedScalar) graphql.Marshaler { + if v == nil { + return graphql.Null + } + return graphql.MarshalString(string(*v)) +} + func (ec *executionContext) marshalO__EnumValue2ᚕgithubᚗcomᚋ99designsᚋgqlgenᚋgraphqlᚋintrospectionᚐEnumValueᚄ(ctx context.Context, sel ast.SelectionSet, v []introspection.EnumValue) graphql.Marshaler { if v == nil { return graphql.Null diff --git a/codegen/testserver/resolver.go b/codegen/testserver/resolver.go index c66b8b220f2..d5d196934fa 100644 --- a/codegen/testserver/resolver.go +++ b/codegen/testserver/resolver.go @@ -263,6 +263,10 @@ func (r *queryResolver) ScalarSlice(ctx context.Context) ([]byte, error) { panic("not implemented") } +func (r *queryResolver) Fallback(ctx context.Context, arg FallbackToStringEncoding) (FallbackToStringEncoding, error) { + panic("not implemented") +} + func (r *queryResolver) OptionalUnion(ctx context.Context) (TestUnion, error) { panic("not implemented") } @@ -327,6 +331,14 @@ func (r *wrappedSliceResolver) Get(ctx context.Context, obj WrappedSlice, idx in panic("not implemented") } +func (r *wrappedStructResolver) Name(ctx context.Context, obj *WrappedStruct) (WrappedScalar, error) { + panic("not implemented") +} + +func (r *wrappedStructResolver) Desc(ctx context.Context, obj *WrappedStruct) (*WrappedScalar, error) { + panic("not implemented") +} + // BackedByInterface returns BackedByInterfaceResolver implementation. func (r *Resolver) BackedByInterface() BackedByInterfaceResolver { return &backedByInterfaceResolver{r} @@ -373,6 +385,9 @@ func (r *Resolver) WrappedMap() WrappedMapResolver { return &wrappedMapResolver{ // WrappedSlice returns WrappedSliceResolver implementation. func (r *Resolver) WrappedSlice() WrappedSliceResolver { return &wrappedSliceResolver{r} } +// WrappedStruct returns WrappedStructResolver implementation. +func (r *Resolver) WrappedStruct() WrappedStructResolver { return &wrappedStructResolver{r} } + type backedByInterfaceResolver struct{ *Resolver } type errorsResolver struct{ *Resolver } type forcedResolverResolver struct{ *Resolver } @@ -387,3 +402,4 @@ type subscriptionResolver struct{ *Resolver } type userResolver struct{ *Resolver } type wrappedMapResolver struct{ *Resolver } type wrappedSliceResolver struct{ *Resolver } +type wrappedStructResolver struct{ *Resolver } diff --git a/codegen/testserver/stub.go b/codegen/testserver/stub.go index db82863779e..88303f44165 100644 --- a/codegen/testserver/stub.go +++ b/codegen/testserver/stub.go @@ -92,6 +92,7 @@ type Stub struct { DefaultScalar func(ctx context.Context, arg string) (string, error) Slices func(ctx context.Context) (*Slices, error) ScalarSlice func(ctx context.Context) ([]byte, error) + Fallback func(ctx context.Context, arg FallbackToStringEncoding) (FallbackToStringEncoding, error) OptionalUnion func(ctx context.Context) (TestUnion, error) ValidType func(ctx context.Context) (*ValidType, error) WrappedStruct func(ctx context.Context) (*WrappedStruct, error) @@ -117,6 +118,10 @@ type Stub struct { WrappedSliceResolver struct { Get func(ctx context.Context, obj WrappedSlice, idx int) (string, error) } + WrappedStructResolver struct { + Name func(ctx context.Context, obj *WrappedStruct) (WrappedScalar, error) + Desc func(ctx context.Context, obj *WrappedStruct) (*WrappedScalar, error) + } } func (r *Stub) BackedByInterface() BackedByInterfaceResolver { @@ -161,6 +166,9 @@ func (r *Stub) WrappedMap() WrappedMapResolver { func (r *Stub) WrappedSlice() WrappedSliceResolver { return &stubWrappedSlice{r} } +func (r *Stub) WrappedStruct() WrappedStructResolver { + return &stubWrappedStruct{r} +} type stubBackedByInterface struct{ *Stub } @@ -380,6 +388,9 @@ func (r *stubQuery) Slices(ctx context.Context) (*Slices, error) { func (r *stubQuery) ScalarSlice(ctx context.Context) ([]byte, error) { return r.QueryResolver.ScalarSlice(ctx) } +func (r *stubQuery) Fallback(ctx context.Context, arg FallbackToStringEncoding) (FallbackToStringEncoding, error) { + return r.QueryResolver.Fallback(ctx, arg) +} func (r *stubQuery) OptionalUnion(ctx context.Context) (TestUnion, error) { return r.QueryResolver.OptionalUnion(ctx) } @@ -440,3 +451,12 @@ type stubWrappedSlice struct{ *Stub } func (r *stubWrappedSlice) Get(ctx context.Context, obj WrappedSlice, idx int) (string, error) { return r.WrappedSliceResolver.Get(ctx, obj, idx) } + +type stubWrappedStruct struct{ *Stub } + +func (r *stubWrappedStruct) Name(ctx context.Context, obj *WrappedStruct) (WrappedScalar, error) { + return r.WrappedStructResolver.Name(ctx, obj) +} +func (r *stubWrappedStruct) Desc(ctx context.Context, obj *WrappedStruct) (*WrappedScalar, error) { + return r.WrappedStructResolver.Desc(ctx, obj) +} diff --git a/codegen/testserver/typefallback.graphql b/codegen/testserver/typefallback.graphql new file mode 100644 index 00000000000..e1ff1a59d7c --- /dev/null +++ b/codegen/testserver/typefallback.graphql @@ -0,0 +1,9 @@ +extend type Query { + fallback(arg: FallbackToStringEncoding!): FallbackToStringEncoding! +} + +enum FallbackToStringEncoding { + A + B + C +} diff --git a/codegen/testserver/typefallback_test.go b/codegen/testserver/typefallback_test.go new file mode 100644 index 00000000000..8ebd091e9ef --- /dev/null +++ b/codegen/testserver/typefallback_test.go @@ -0,0 +1,28 @@ +package testserver + +import ( + "context" + "testing" + + "github.com/99designs/gqlgen/client" + "github.com/99designs/gqlgen/graphql/handler" + "github.com/stretchr/testify/require" +) + +func TestTypeFallback(t *testing.T) { + resolvers := &Stub{} + + c := client.New(handler.NewDefaultServer(NewExecutableSchema(Config{Resolvers: resolvers}))) + + resolvers.QueryResolver.Fallback = func(ctx context.Context, arg FallbackToStringEncoding) (FallbackToStringEncoding, error) { + return arg, nil + } + + t.Run("fallback to string passthrough", func(t *testing.T) { + var resp struct { + Fallback string + } + c.MustPost(`query { fallback(arg: A) }`, &resp) + require.Equal(t, "A", resp.Fallback) + }) +} diff --git a/codegen/testserver/wrapped_type.go b/codegen/testserver/wrapped_type.go index bd7ea31006c..d3aa63a79b4 100644 --- a/codegen/testserver/wrapped_type.go +++ b/codegen/testserver/wrapped_type.go @@ -1,28 +1,8 @@ package testserver -import ( - "fmt" - "io" - "strconv" - - "github.com/99designs/gqlgen/codegen/testserver/otherpkg" - "github.com/99designs/gqlgen/graphql" -) +import "github.com/99designs/gqlgen/codegen/testserver/otherpkg" type WrappedScalar otherpkg.Scalar type WrappedStruct otherpkg.Struct type WrappedMap otherpkg.Map type WrappedSlice otherpkg.Slice - -func (e *WrappedScalar) UnmarshalGQL(v interface{}) error { - s, err := graphql.UnmarshalString(v) - if err != nil { - return err - } - *e = WrappedScalar(s) - return nil -} - -func (e WrappedScalar) MarshalGQL(w io.Writer) { - fmt.Fprint(w, strconv.Quote(string(e))) -} diff --git a/codegen/testserver/wrapped_type.graphql b/codegen/testserver/wrapped_type.graphql index 4f7df84a755..116147432cb 100644 --- a/codegen/testserver/wrapped_type.graphql +++ b/codegen/testserver/wrapped_type.graphql @@ -7,7 +7,7 @@ extend type Query { wrappedSlice: WrappedSlice! } -type WrappedStruct { name: String! } +type WrappedStruct { name: WrappedScalar!, desc: WrappedScalar } scalar WrappedScalar type WrappedMap { get(key: String!): String! } type WrappedSlice { get(idx: Int!): String! } diff --git a/codegen/type.gotpl b/codegen/type.gotpl index 60a87699406..2bd0c1943a1 100644 --- a/codegen/type.gotpl +++ b/codegen/type.gotpl @@ -25,7 +25,16 @@ return res, nil {{- else }} {{- if $type.Unmarshaler }} - res, err := {{ $type.Unmarshaler | call }}(v) + {{- if $type.CastType }} + tmp, err := {{ $type.Unmarshaler | call }}(v) + {{- if $type.IsNilable }} + res := {{ $type.Elem.GO | ref }}(tmp) + {{- else}} + res := {{ $type.GO | ref }}(tmp) + {{- end }} + {{- else}} + res, err := {{ $type.Unmarshaler | call }}(v) + {{- end }} {{- if and $type.IsTargetNilable (not $type.IsNilable) }} return *res, graphql.WrapErrorWithInputPath(ctx, err) {{- else if and (not $type.IsTargetNilable) $type.IsNilable }} @@ -36,7 +45,7 @@ {{- else if eq ($type.GO | ref) "map[string]interface{}" }} return v.(map[string]interface{}), nil {{- else if $type.IsMarshaler }} - {{- if $type.IsNilable }} + {{- if $type.IsNilable }} var res = new({{ $type.Elem.GO | ref }}) {{- else}} var res {{ $type.GO | ref }} @@ -123,7 +132,7 @@ {{- $v = "*v" }} {{- end }} {{- if $type.GQL.NonNull }} - res := {{ $type.Marshaler | call }}({{ $v }}) + res := {{ $type.Marshaler | call }}({{- if $type.CastType }}{{ $type.CastType | ref }}({{ $v }}){{else}}{{ $v }}{{- end }}) if res == graphql.Null { if !graphql.HasFieldError(ctx, graphql.GetFieldContext(ctx)) { ec.Errorf(ctx, "must not be null") @@ -131,7 +140,7 @@ } return res {{- else }} - return {{ $type.Marshaler | call }}({{ $v }}) + return {{ $type.Marshaler | call }}({{- if $type.CastType }}{{ $type.CastType | ref }}({{ $v }}){{else}}{{ $v }}{{- end }}) {{- end }} {{- else }} return ec._{{$type.Definition.Name}}(ctx, sel, {{ if not $type.IsNilable}}&{{end}} v)