Skip to content

Commit

Permalink
Extract builder object
Browse files Browse the repository at this point in the history
  • Loading branch information
vektah committed Jan 9, 2019
1 parent 87b37b0 commit 6b82903
Show file tree
Hide file tree
Showing 11 changed files with 162 additions and 155 deletions.
20 changes: 5 additions & 15 deletions codegen/resolver.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,8 @@ import (
type ResolverBuild struct {
*unified.Schema

PackageName string
ResolverType string
ResolverFound bool
PackageName string
ResolverType string
}

func GenerateResolver(schema *unified.Schema) error {
Expand All @@ -24,11 +23,6 @@ func GenerateResolver(schema *unified.Schema) error {
}
filename := schema.Config.Resolver.Filename

if resolverBuild.ResolverFound {
log.Printf("Skipped resolver: %s.%s already exists\n", schema.Config.Resolver.ImportPath(), schema.Config.Resolver.Type)
return nil
}

if _, err := os.Stat(filename); os.IsNotExist(errors.Cause(err)) {
if err := templates.RenderToFile("resolver.gotpl", filename, resolverBuild); err != nil {
return err
Expand All @@ -41,13 +35,9 @@ func GenerateResolver(schema *unified.Schema) error {
}

func buildResolver(s *unified.Schema) (*ResolverBuild, error) {
def, _ := s.FindGoType(s.Config.Resolver.ImportPath(), s.Config.Resolver.Type)
resolverFound := def != nil

return &ResolverBuild{
Schema: s,
PackageName: s.Config.Resolver.Package,
ResolverType: s.Config.Resolver.Type,
ResolverFound: resolverFound,
Schema: s,
PackageName: s.Config.Resolver.Package,
ResolverType: s.Config.Resolver.Type,
}, nil
}
115 changes: 81 additions & 34 deletions codegen/unified/build.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,15 +9,25 @@ import (
"github.com/99designs/gqlgen/codegen/config"
"github.com/pkg/errors"
"github.com/vektah/gqlparser/ast"
"golang.org/x/tools/go/loader"
)

func NewSchema(cfg *config.Config) (*Schema, error) {
g := Schema{
type builder struct {
Config *config.Config
Schema *ast.Schema
SchemaStr map[string]string
Program *loader.Program
Directives map[string]*Directive
NamedTypes NamedTypes
}

func buildSchema(cfg *config.Config) (*Schema, error) {
b := builder{
Config: cfg,
}

var err error
g.Schema, g.SchemaStr, err = cfg.LoadSchema()
b.Schema, b.SchemaStr, err = cfg.LoadSchema()
if err != nil {
return nil, err
}
Expand All @@ -27,89 +37,96 @@ func NewSchema(cfg *config.Config) (*Schema, error) {
return nil, err
}

progLoader := g.Config.NewLoaderWithoutErrors()
g.Program, err = progLoader.Load()
progLoader := b.Config.NewLoaderWithoutErrors()
b.Program, err = progLoader.Load()
if err != nil {
return nil, errors.Wrap(err, "loading failed")
}

g.NamedTypes = NamedTypes{}
b.NamedTypes = NamedTypes{}

for _, schemaType := range g.Schema.Types {
g.NamedTypes[schemaType.Name], err = g.buildTypeDef(schemaType)
for _, schemaType := range b.Schema.Types {
b.NamedTypes[schemaType.Name], err = b.buildTypeDef(schemaType)
if err != nil {
return nil, errors.Wrap(err, "unable to build type definition")
}
}

g.Directives, err = g.buildDirectives()
b.Directives, err = b.buildDirectives()
if err != nil {
return nil, err
}

for _, schemaType := range g.Schema.Types {
s := Schema{
Config: cfg,
Directives: b.Directives,
Schema: b.Schema,
SchemaStr: b.SchemaStr,
}

for _, schemaType := range b.Schema.Types {
switch schemaType.Kind {
case ast.Object:
obj, err := g.buildObject(schemaType)
obj, err := b.buildObject(schemaType)
if err != nil {
return nil, errors.Wrap(err, "unable to build object definition")
}

g.Objects = append(g.Objects, obj)
s.Objects = append(s.Objects, obj)
case ast.InputObject:
input, err := g.buildObject(schemaType)
input, err := b.buildObject(schemaType)
if err != nil {
return nil, errors.Wrap(err, "unable to build input definition")
}

g.Inputs = append(g.Inputs, input)
s.Inputs = append(s.Inputs, input)

case ast.Union, ast.Interface:
g.Interfaces = append(g.Interfaces, g.buildInterface(schemaType))
s.Interfaces = append(s.Interfaces, b.buildInterface(schemaType))

case ast.Enum:
if enum := g.buildEnum(schemaType); enum != nil {
g.Enums = append(g.Enums, *enum)
if enum := b.buildEnum(schemaType); enum != nil {
s.Enums = append(s.Enums, *enum)
}
}
}

if err := g.injectIntrospectionRoots(); err != nil {
if err := b.injectIntrospectionRoots(&s); err != nil {
return nil, err
}

sort.Slice(g.Objects, func(i, j int) bool {
return g.Objects[i].Definition.GQLDefinition.Name < g.Objects[j].Definition.GQLDefinition.Name
sort.Slice(s.Objects, func(i, j int) bool {
return s.Objects[i].Definition.GQLDefinition.Name < s.Objects[j].Definition.GQLDefinition.Name
})

sort.Slice(g.Inputs, func(i, j int) bool {
return g.Inputs[i].Definition.GQLDefinition.Name < g.Inputs[j].Definition.GQLDefinition.Name
sort.Slice(s.Inputs, func(i, j int) bool {
return s.Inputs[i].Definition.GQLDefinition.Name < s.Inputs[j].Definition.GQLDefinition.Name
})

sort.Slice(g.Interfaces, func(i, j int) bool {
return g.Interfaces[i].Definition.GQLDefinition.Name < g.Interfaces[j].Definition.GQLDefinition.Name
sort.Slice(s.Interfaces, func(i, j int) bool {
return s.Interfaces[i].Definition.GQLDefinition.Name < s.Interfaces[j].Definition.GQLDefinition.Name
})

sort.Slice(g.Enums, func(i, j int) bool {
return g.Enums[i].Definition.GQLDefinition.Name < g.Enums[j].Definition.GQLDefinition.Name
sort.Slice(s.Enums, func(i, j int) bool {
return s.Enums[i].Definition.GQLDefinition.Name < s.Enums[j].Definition.GQLDefinition.Name
})

return &g, nil
return &s, nil
}

func (g *Schema) injectIntrospectionRoots() error {
obj := g.Objects.ByName(g.Schema.Query.Name)
func (b *builder) injectIntrospectionRoots(s *Schema) error {
obj := s.Objects.ByName(b.Schema.Query.Name)
if obj == nil {
return fmt.Errorf("root query type must be defined")
}

typeType, err := g.FindGoType("github.com/99designs/gqlgen/graphql/introspection", "Type")
typeType, err := b.FindGoType("github.com/99designs/gqlgen/graphql/introspection", "Type")
if err != nil {
return errors.Wrap(err, "unable to find root Type introspection type")
}

obj.Fields = append(obj.Fields, &Field{
TypeReference: &TypeReference{g.NamedTypes["__Type"], types.NewPointer(typeType.Type()), ast.NamedType("__Schema", nil)},
TypeReference: &TypeReference{b.NamedTypes["__Type"], types.NewPointer(typeType.Type()), ast.NamedType("__Schema", nil)},
GQLName: "__type",
GoFieldType: GoFieldMethod,
GoReceiverName: "ec",
Expand All @@ -118,7 +135,7 @@ func (g *Schema) injectIntrospectionRoots() error {
{
GQLName: "name",
TypeReference: &TypeReference{
g.NamedTypes["String"],
b.NamedTypes["String"],
types.Typ[types.String],
ast.NamedType("String", nil),
},
Expand All @@ -128,13 +145,13 @@ func (g *Schema) injectIntrospectionRoots() error {
Object: obj,
})

schemaType, err := g.FindGoType("github.com/99designs/gqlgen/graphql/introspection", "Schema")
schemaType, err := b.FindGoType("github.com/99designs/gqlgen/graphql/introspection", "Schema")
if err != nil {
return errors.Wrap(err, "unable to find root Schema introspection type")
}

obj.Fields = append(obj.Fields, &Field{
TypeReference: &TypeReference{g.NamedTypes["__Schema"], types.NewPointer(schemaType.Type()), ast.NamedType("__Schema", nil)},
TypeReference: &TypeReference{b.NamedTypes["__Schema"], types.NewPointer(schemaType.Type()), ast.NamedType("__Schema", nil)},
GQLName: "__schema",
GoFieldType: GoFieldMethod,
GoReceiverName: "ec",
Expand All @@ -144,3 +161,33 @@ func (g *Schema) injectIntrospectionRoots() error {

return nil
}

func (b *builder) FindGoType(pkgName string, typeName string) (types.Object, error) {
if pkgName == "" {
return nil, nil
}
fullName := typeName
if pkgName != "" {
fullName = pkgName + "." + typeName
}

pkgName, err := resolvePkg(pkgName)
if err != nil {
return nil, errors.Errorf("unable to resolve package for %s: %s\n", fullName, err.Error())
}

pkg := b.Program.Imported[pkgName]
if pkg == nil {
return nil, errors.Errorf("required package was not loaded: %s", fullName)
}

for astNode, def := range pkg.Defs {
if astNode.Name != typeName || def.Parent() == nil || def.Parent() != pkg.Pkg.Scope() {
continue
}

return def, nil
}

return nil, errors.Errorf("unable to find type %s\n", fullName)
}
32 changes: 16 additions & 16 deletions codegen/unified/build_bind.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,21 +39,21 @@ func (b BindErrors) Error() string {
return strings.Join(errs, "\n\n")
}

func (g *Schema) bindObject(object *Object) BindErrors {
func (b *builder) bindObject(object *Object) BindErrors {
var errs BindErrors
for _, field := range object.Fields {
if field.IsResolver {
continue
}

// first try binding to a method
methodErr := g.bindMethod(object.Definition.GoType, field)
methodErr := b.bindMethod(object.Definition.GoType, field)
if methodErr == nil {
continue
}

// otherwise try binding to a var
varErr := g.bindVar(object.Definition.GoType, field)
varErr := b.bindVar(object.Definition.GoType, field)

// if both failed, add a resolver
if varErr != nil {
Expand All @@ -71,13 +71,13 @@ func (g *Schema) bindObject(object *Object) BindErrors {
return errs
}

func (g *Schema) bindMethod(t types.Type, field *Field) error {
func (b *builder) bindMethod(t types.Type, field *Field) error {
namedType, err := findGoNamedType(t)
if err != nil {
return err
}

method := g.findMethod(namedType, field.GoFieldName)
method := b.findMethod(namedType, field.GoFieldName)
if method == nil {
return fmt.Errorf("no method named %s", field.GoFieldName)
}
Expand All @@ -100,7 +100,7 @@ func (g *Schema) bindMethod(t types.Type, field *Field) error {
params = types.NewTuple(vars...)
}

if err := g.bindArgs(field, params); err != nil {
if err := b.bindArgs(field, params); err != nil {
return err
}

Expand All @@ -117,13 +117,13 @@ func (g *Schema) bindMethod(t types.Type, field *Field) error {
return nil
}

func (g *Schema) bindVar(t types.Type, field *Field) error {
func (b *builder) bindVar(t types.Type, field *Field) error {
underlying, ok := t.Underlying().(*types.Struct)
if !ok {
return fmt.Errorf("not a struct")
}

structField, err := g.findField(underlying, field.GoFieldName)
structField, err := b.findField(underlying, field.GoFieldName)
if err != nil {
return err
}
Expand All @@ -140,7 +140,7 @@ func (g *Schema) bindVar(t types.Type, field *Field) error {
return nil
}

func (g *Schema) bindArgs(field *Field, params *types.Tuple) error {
func (b *builder) bindArgs(field *Field, params *types.Tuple) error {
var newArgs []*FieldArgument

nextArg:
Expand Down Expand Up @@ -328,7 +328,7 @@ func normalizeVendor(pkg string) string {
return modifiers + parts[len(parts)-1]
}

func (g *Schema) findMethod(typ *types.Named, name string) *types.Func {
func (b *builder) findMethod(typ *types.Named, name string) *types.Func {
for i := 0; i < typ.NumMethods(); i++ {
method := typ.Method(i)
if !method.Exported() {
Expand All @@ -348,7 +348,7 @@ func (g *Schema) findMethod(typ *types.Named, name string) *types.Func {
}

if named, ok := field.Type().(*types.Named); ok {
if f := g.findMethod(named, name); f != nil {
if f := b.findMethod(named, name); f != nil {
return f
}
}
Expand All @@ -363,18 +363,18 @@ func (g *Schema) findMethod(typ *types.Named, name string) *types.Func {
// 1. If struct tag is passed then struct tag has highest priority
// 2. Actual Field name
// 3. Field in an embedded struct
func (g *Schema) findField(typ *types.Struct, name string) (*types.Var, error) {
if g.Config.StructTag != "" {
func (b *builder) findField(typ *types.Struct, name string) (*types.Var, error) {
if b.Config.StructTag != "" {
var foundField *types.Var
for i := 0; i < typ.NumFields(); i++ {
field := typ.Field(i)
if !field.Exported() {
continue
}
tags := reflect.StructTag(typ.Tag(i))
if val, ok := tags.Lookup(g.Config.StructTag); ok && equalFieldName(val, name) {
if val, ok := tags.Lookup(b.Config.StructTag); ok && equalFieldName(val, name) {
if foundField != nil {
return nil, errors.Errorf("tag %s is ambigious; multiple fields have the same tag value of %s", g.Config.StructTag, val)
return nil, errors.Errorf("tag %s is ambigious; multiple fields have the same tag value of %s", b.Config.StructTag, val)
}

foundField = field
Expand Down Expand Up @@ -411,7 +411,7 @@ func (g *Schema) findField(typ *types.Struct, name string) (*types.Var, error) {
// Type.Underlying() returns itself for all types except types.Named, where it returns a struct type.
// It should be safe to always call.
if named, ok := fieldType.Underlying().(*types.Struct); ok {
f, err := g.findField(named, name)
f, err := b.findField(named, name)
if err != nil && !strings.HasPrefix(err.Error(), "no field named") {
return nil, err
}
Expand Down
4 changes: 2 additions & 2 deletions codegen/unified/build_bind_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -72,8 +72,8 @@ type Embed struct {
}

for _, tt := range tests {
schema := Schema{Config: &config.Config{StructTag: tt.Tag}}
field, err := schema.findField(tt.Struct, tt.Field)
b := builder{Config: &config.Config{StructTag: tt.Tag}}
field, err := b.findField(tt.Struct, tt.Field)
if tt.ShouldError {
require.Nil(t, field, tt.Name)
require.Error(t, err, tt.Name)
Expand Down
Loading

0 comments on commit 6b82903

Please sign in to comment.