diff --git a/codegen/type.go b/codegen/type.go index 7af24b3c83c..689f4e4aa34 100644 --- a/codegen/type.go +++ b/codegen/type.go @@ -26,8 +26,8 @@ type Ref struct { type Type struct { *NamedType - Modifiers []string - CastType *Ref // the type to cast to when unmarshalling + Modifiers []string + AliasedType *Ref } const ( @@ -47,6 +47,9 @@ func (t Ref) PkgDot() string { } func (t Type) Signature() string { + if t.AliasedType != nil { + return strings.Join(t.Modifiers, "") + t.AliasedType.FullName() + } return strings.Join(t.Modifiers, "") + t.FullName() } @@ -125,11 +128,11 @@ func (t Type) unmarshal(result, raw string, remainingMods []string, depth int) s } realResult := result - if t.CastType != nil { + if t.AliasedType != nil { result = "castTmp" } - return tpl(`{{- if .t.CastType }} + return tpl(`{{- if .t.AliasedType }} var castTmp {{.t.FullName}} {{ end }} {{- if eq .t.GoType "map[string]interface{}" }} @@ -139,8 +142,8 @@ func (t Type) unmarshal(result, raw string, remainingMods []string, depth int) s {{- else -}} err = (&{{.result}}).UnmarshalGQL({{.raw}}) {{- end }} - {{- if .t.CastType }} - {{ .realResult }} = {{.t.CastType.FullName}}(castTmp) + {{- if .t.AliasedType }} + {{ .realResult }} = {{.t.AliasedType.FullName}}(castTmp) {{- end }}`, map[string]interface{}{ "realResult": realResult, "result": result, @@ -150,7 +153,7 @@ func (t Type) unmarshal(result, raw string, remainingMods []string, depth int) s } func (t Type) Marshal(val string) string { - if t.CastType != nil { + if t.AliasedType != nil { val = t.GoType + "(" + val + ")" } diff --git a/codegen/util.go b/codegen/util.go index 5ff41074324..5c5bd5cf0e2 100644 --- a/codegen/util.go +++ b/codegen/util.go @@ -276,7 +276,7 @@ func validateTypeBinding(imports *Imports, field *Field, goType types.Type) erro field.Type.Modifiers = modifiersFromGoType(goType) pkg, typ := pkgAndType(goType.String()) imp := imports.findByPath(pkg) - field.CastType = &Ref{GoType: typ, Import: imp} + field.AliasedType = &Ref{GoType: typ, Import: imp} return nil } diff --git a/example/scalars/generated.go b/example/scalars/generated.go index 43506e7e817..f29f6285934 100644 --- a/example/scalars/generated.go +++ b/example/scalars/generated.go @@ -454,7 +454,7 @@ func (ec *executionContext) _User_isBanned(ctx context.Context, field graphql.Co if resTmp == nil { return graphql.Null } - res := resTmp.(bool) + res := resTmp.(model.Banned) return graphql.MarshalBoolean(bool(res)) }