Skip to content

Commit

Permalink
Merge pull request 99designs#20 from vektah/codegen-cleanup
Browse files Browse the repository at this point in the history
Codegen cleanup
  • Loading branch information
vektah authored Feb 22, 2018
2 parents d11d7b8 + b7754da commit 2105587
Show file tree
Hide file tree
Showing 21 changed files with 3,253 additions and 2,846 deletions.
49 changes: 12 additions & 37 deletions codegen/object.go
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ func (f *Field) ResolverDeclaration() string {
res := fmt.Sprintf("%s_%s(ctx context.Context", f.Object.GQLType, f.GQLName)

if !f.Object.Root {
res += fmt.Sprintf(", it *%s", f.Object.FullName())
res += fmt.Sprintf(", obj *%s", f.Object.FullName())
}
for _, arg := range f.Args {
res += fmt.Sprintf(", %s %s", arg.GQLName, arg.Signature())
Expand All @@ -86,7 +86,7 @@ func (f *Field) CallArgs() string {
args = append(args, "ec.ctx")

if !f.Object.Root {
args = append(args, "it")
args = append(args, "obj")
}
}

Expand All @@ -98,59 +98,44 @@ func (f *Field) CallArgs() string {
}

// should be in the template, but its recursive and has a bunch of args
func (f *Field) WriteJson(res string) string {
return f.doWriteJson(res, "res", f.Type.Modifiers, false, 1)
func (f *Field) WriteJson() string {
return f.doWriteJson("res", f.Type.Modifiers, false, 1)
}

func (f *Field) doWriteJson(res string, val string, remainingMods []string, isPtr bool, depth int) string {
func (f *Field) doWriteJson(val string, remainingMods []string, isPtr bool, depth int) string {
switch {
case len(remainingMods) > 0 && remainingMods[0] == modPtr:
return tpl(`
if {{.val}} == nil {
{{.res}} = graphql.Null
} else {
{{.next}}
}`, map[string]interface{}{
"res": res,
"val": val,
"next": f.doWriteJson(res, val, remainingMods[1:], true, depth+1),
})
return fmt.Sprintf("if %s == nil { return graphql.Null }\n%s", val, f.doWriteJson(val, remainingMods[1:], true, depth+1))

case len(remainingMods) > 0 && remainingMods[0] == modList:
if isPtr {
val = "*" + val
}
var tmp = "tmp" + strconv.Itoa(depth)
var arr = "arr" + strconv.Itoa(depth)
var index = "idx" + strconv.Itoa(depth)

return tpl(`
{{.arr}} := graphql.Array{}
return tpl(`{{.arr}} := graphql.Array{}
for {{.index}} := range {{.val}} {
var {{.tmp}} graphql.Marshaler
{{.next}}
{{.arr}} = append({{.arr}}, {{.tmp}})
{{.arr}} = append({{.arr}}, func() graphql.Marshaler { {{ .next }} }())
}
{{.res}} = {{.arr}}`, map[string]interface{}{
"res": res,
return {{.arr}}`, map[string]interface{}{
"val": val,
"tmp": tmp,
"arr": arr,
"index": index,
"next": f.doWriteJson(tmp, val+"["+index+"]", remainingMods[1:], false, depth+1),
"next": f.doWriteJson(val+"["+index+"]", remainingMods[1:], false, depth+1),
})

case f.IsScalar:
if isPtr {
val = "*" + val
}
return f.Marshal(res, val)
return f.Marshal(val)

default:
if !isPtr {
val = "&" + val
}
return fmt.Sprintf("%s = ec._%s(field.Selections, %s)", res, lcFirst(f.GQLType), val)
return fmt.Sprintf("return ec._%s(field.Selections, %s)", f.GQLType, val)
}
}

Expand All @@ -169,16 +154,6 @@ func tpl(tpl string, vars map[string]interface{}) string {
return b.String()
}

func lcFirst(s string) string {
if s == "" {
return ""
}

r := []rune(s)
r[0] = unicode.ToLower(r[0])
return string(r)
}

func ucFirst(s string) string {
if s == "" {
return ""
Expand Down
7 changes: 4 additions & 3 deletions codegen/templates/data.go

Large diffs are not rendered by default.

54 changes: 54 additions & 0 deletions codegen/templates/field.gotpl
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
{{ $field := . }}
{{ $object := $field.Object }}

{{- if $object.Stream }}
func (ec *executionContext) _{{$object.GQLType}}_{{$field.GQLName}}(field graphql.CollectedField) func() graphql.Marshaler {
{{- template "args.gotpl" $field.Args }}
results, err := ec.resolvers.{{ $object.GQLType }}_{{ $field.GQLName }}({{ $field.CallArgs }})
if err != nil {
ec.Error(err)
return nil
}
return func() graphql.Marshaler {
res, ok := <-results
if !ok {
return nil
}
var out graphql.OrderedMap
out.Add(field.Alias, func() graphql.Marshaler { {{ $field.WriteJson }} }())
return &out
}
}
{{ else }}
func (ec *executionContext) _{{$object.GQLType}}_{{$field.GQLName}}(field graphql.CollectedField, {{if not $object.Root}}obj *{{$object.FullName}}{{end}}) graphql.Marshaler {
{{- template "args.gotpl" $field.Args }}

{{- if $field.IsConcurrent }}
return graphql.Defer(func() graphql.Marshaler {
{{- end }}

{{- if $field.GoVarName }}
res := obj.{{$field.GoVarName}}
{{- else if $field.GoMethodName }}
{{- if $field.NoErr }}
res := {{$field.GoMethodName}}({{ $field.CallArgs }})
{{- else }}
res, err := {{$field.GoMethodName}}({{ $field.CallArgs }})
if err != nil {
ec.Error(err)
return graphql.Null
}
{{- end }}
{{- else }}
res, err := ec.resolvers.{{ $object.GQLType }}_{{ $field.GQLName }}({{ $field.CallArgs }})
if err != nil {
ec.Error(err)
return graphql.Null
}
{{- end }}
{{ $field.WriteJson }}
{{- if $field.IsConcurrent }}
})
{{- end }}
}
{{ end }}
62 changes: 32 additions & 30 deletions codegen/templates/file.gotpl
Original file line number Diff line number Diff line change
Expand Up @@ -36,11 +36,12 @@ func (e *executableSchema) Query(ctx context.Context, doc *query.Document, varia
{{- if .QueryRoot }}
ec := executionContext{resolvers: e.resolvers, variables: variables, doc: doc, ctx: ctx}

data := ec._{{.QueryRoot.GQLType|lcFirst}}(op.Selections)
ec.wg.Wait()
data := ec._{{.QueryRoot.GQLType}}(op.Selections)
var buf bytes.Buffer
data.MarshalGQL(&buf)

return &graphql.Response{
Data: data,
Data: buf.Bytes(),
Errors: ec.Errors,
}
{{- else }}
Expand All @@ -52,48 +53,46 @@ func (e *executableSchema) Mutation(ctx context.Context, doc *query.Document, va
{{- if .MutationRoot }}
ec := executionContext{resolvers: e.resolvers, variables: variables, doc: doc, ctx: ctx}

data := ec._{{.MutationRoot.GQLType|lcFirst}}(op.Selections)
ec.wg.Wait()
data := ec._{{.MutationRoot.GQLType}}(op.Selections)
var buf bytes.Buffer
data.MarshalGQL(&buf)

return &graphql.Response{
Data: data,
Data: buf.Bytes(),
Errors: ec.Errors,
}
{{- else }}
return &graphql.Response{Errors: []*errors.QueryError{ {Message: "mutations are not supported"} }}
{{- end }}
}

func (e *executableSchema) Subscription(ctx context.Context, doc *query.Document, variables map[string]interface{}, op *query.Operation) <-chan *graphql.Response {
func (e *executableSchema) Subscription(ctx context.Context, doc *query.Document, variables map[string]interface{}, op *query.Operation) func() *graphql.Response {
{{- if .SubscriptionRoot }}
events := make(chan *graphql.Response, 10)

ec := executionContext{resolvers: e.resolvers, variables: variables, doc: doc, ctx: ctx}

eventData := ec._{{.SubscriptionRoot.GQLType|lcFirst}}(op.Selections)
next := ec._{{.SubscriptionRoot.GQLType}}(op.Selections)
if ec.Errors != nil {
events<-&graphql.Response{
Data: graphql.Null,
Errors: ec.Errors,
return graphql.OneShot(&graphql.Response{Data: []byte("null"), Errors: ec.Errors})
}

var buf bytes.Buffer
return func() *graphql.Response {
buf.Reset()
data := next()
if data == nil {
return nil
}
data.MarshalGQL(&buf)

errs := ec.Errors
ec.Errors = nil
return &graphql.Response{
Data: buf.Bytes(),
Errors: errs,
}
close(events)
} else {
go func() {
for data := range eventData {
ec.wg.Wait()
events <- &graphql.Response{
Data: data,
Errors: ec.Errors,
}
time.Sleep(20 * time.Millisecond)
}
}()
}
return events
{{- else }}
events := make(chan *graphql.Response, 1)
events<-&graphql.Response{Errors: []*errors.QueryError{ {Message: "subscriptions are not supported"} }}
return events
return graphql.OneShot(&graphql.Response{Errors: []*errors.QueryError{ {Message: "subscriptions are not supported"} }})
{{- end }}
}

Expand All @@ -103,11 +102,14 @@ type executionContext struct {
variables map[string]interface{}
doc *query.Document
ctx context.Context
wg sync.WaitGroup
}

{{- range $object := .Objects }}
{{ template "object.gotpl" $object }}

{{- range $field := $object.Fields }}
{{ template "field.gotpl" $field }}
{{ end }}
{{- end}}

{{- range $interface := .Interfaces }}
Expand Down
10 changes: 5 additions & 5 deletions codegen/templates/interface.gotpl
Original file line number Diff line number Diff line change
@@ -1,18 +1,18 @@
{{- $interface := . }}

func (ec *executionContext) _{{$interface.GQLType|lcFirst}}(sel []query.Selection, it *{{$interface.FullName}}) graphql.Marshaler {
switch it := (*it).(type) {
func (ec *executionContext) _{{$interface.GQLType}}(sel []query.Selection, obj *{{$interface.FullName}}) graphql.Marshaler {
switch obj := (*obj).(type) {
case nil:
return graphql.Null
{{- range $implementor := $interface.Implementors }}
case {{$implementor.FullName}}:
return ec._{{$implementor.GQLType|lcFirst}}(sel, &it)
return ec._{{$implementor.GQLType}}(sel, &obj)

case *{{$implementor.FullName}}:
return ec._{{$implementor.GQLType|lcFirst}}(sel, it)
return ec._{{$implementor.GQLType}}(sel, obj)

{{- end }}
default:
panic(fmt.Errorf("unexpected type %T", it))
panic(fmt.Errorf("unexpected type %T", obj))
}
}
80 changes: 6 additions & 74 deletions codegen/templates/object.gotpl
Original file line number Diff line number Diff line change
Expand Up @@ -4,104 +4,36 @@ var {{ $object.GQLType|lcFirst}}Implementors = {{$object.Implementors}}

// nolint: gocyclo, errcheck, gas, goconst
{{- if .Stream }}
func (ec *executionContext) _{{$object.GQLType|lcFirst}}(sel []query.Selection{{if not $object.Root}}, it *{{$object.FullName}}{{end}}) <-chan graphql.Marshaler {
func (ec *executionContext) _{{$object.GQLType}}(sel []query.Selection) func() graphql.Marshaler {
fields := graphql.CollectFields(ec.doc, sel, {{$object.GQLType|lcFirst}}Implementors, ec.variables)

if len(fields) != 1 {
ec.Errorf("must subscribe to exactly one stream")
return nil
}

var field = fields[0]
channel := make(chan graphql.Marshaler, 1)
switch field.Name {
switch fields[0].Name {
{{- range $field := $object.Fields }}
case "{{$field.GQLName}}":
{{- template "args.gotpl" $field.Args }}

{{- if $field.GoVarName }}
results := it.{{$field.GoVarName}}
{{- else if $field.GoMethodName }}
{{- if $field.NoErr }}
results := {{$field.GoMethodName}}({{ $field.CallArgs }})
{{- else }}
results, err := {{$field.GoMethodName}}({{ $field.CallArgs }})
if err != nil {
ec.Error(err)
return nil
}
{{- end }}
{{- else }}
results, err := ec.resolvers.{{ $object.GQLType }}_{{ $field.GQLName }}({{ $field.CallArgs }})
if err != nil {
ec.Error(err)
return nil
}
{{- end }}

go func() {
for res := range results {
var out graphql.OrderedMap
var messageRes graphql.Marshaler
{{ $field.WriteJson "messageRes" }}
out.Add(field.Alias, messageRes)
channel <- &out
}
}()

return ec._{{$object.GQLType}}_{{$field.GQLName}}(fields[0])
{{- end }}
default:
panic("unknown field " + strconv.Quote(field.Name))
panic("unknown field " + strconv.Quote(fields[0].Name))
}

return channel
}
{{- else }}
func (ec *executionContext) _{{$object.GQLType|lcFirst}}(sel []query.Selection{{if not $object.Root}}, it *{{$object.FullName}} {{end}}) graphql.Marshaler {
func (ec *executionContext) _{{$object.GQLType}}(sel []query.Selection{{if not $object.Root}}, obj *{{$object.FullName}} {{end}}) graphql.Marshaler {
fields := graphql.CollectFields(ec.doc, sel, {{$object.GQLType|lcFirst}}Implementors, ec.variables)
out := graphql.NewOrderedMap(len(fields))
for i, field := range fields {
out.Keys[i] = field.Alias
out.Values[i] = graphql.Null

switch field.Name {
case "__typename":
out.Values[i] = graphql.MarshalString({{$object.GQLType|quote}})
{{- range $field := $object.Fields }}
case "{{$field.GQLName}}":
{{- template "args.gotpl" $field.Args }}

{{- if $field.IsConcurrent }}
ec.wg.Add(1)
go func(i int, field graphql.CollectedField) {
defer ec.wg.Done()
{{- end }}

{{- if $field.GoVarName }}
res := it.{{$field.GoVarName}}
{{- else if $field.GoMethodName }}
{{- if $field.NoErr }}
res := {{$field.GoMethodName}}({{ $field.CallArgs }})
{{- else }}
res, err := {{$field.GoMethodName}}({{ $field.CallArgs }})
if err != nil {
ec.Error(err)
{{ if $field.IsConcurrent }}return{{ else }}continue{{end}}
}
{{- end }}
{{- else }}
res, err := ec.resolvers.{{ $object.GQLType }}_{{ $field.GQLName }}({{ $field.CallArgs }})
if err != nil {
ec.Error(err)
{{ if $field.IsConcurrent }}return{{ else }}continue{{end}}
}
{{- end }}

{{ $field.WriteJson "out.Values[i]" }}

{{- if $field.IsConcurrent }}
}(i, field)
{{- end }}
out.Values[i] = ec._{{$object.GQLType}}_{{$field.GQLName}}(field{{if not $object.Root}}, obj{{end}})
{{- end }}
default:
panic("unknown field " + strconv.Quote(field.Name))
Expand Down
Loading

0 comments on commit 2105587

Please sign in to comment.