Skip to content

Commit

Permalink
Bind to embedded interface method
Browse files Browse the repository at this point in the history
  • Loading branch information
matiasanaya committed Nov 10, 2019
1 parent a745dc7 commit 70e860c
Show file tree
Hide file tree
Showing 7 changed files with 336 additions and 97 deletions.
249 changes: 152 additions & 97 deletions codegen/field.go
Original file line number Diff line number Diff line change
Expand Up @@ -184,107 +184,164 @@ func (b *builder) bindField(obj *Object, f *Field) (errret error) {
}
}

// findBindTarget attempts to match the name to a struct field or method
func (b *builder) findBindTarget(in types.Type, name string) (types.Object, error) {
switch t := in.(type) {
case *types.Named:
if _, ok := t.Underlying().(*types.Interface); ok {
return nil, errors.New("can't bind to an interface at root")
}
case *types.Interface:
return nil, errors.New("can't bind to an interface at root")
}

return b.findBindTargetRecur(in, name)
}

// findBindTargetRecur attempts to match the name to a field or method on a Type
// with the following priorites:
// 1. Any Fields with a struct tag (see config.StructTag). Errors if more than one match is found
// 2. Any method or field with a matching name. Errors if more than one match is found
// 3. Same logic again for embedded fields
func (b *builder) findBindTarget(named *types.Named, name string) (types.Object, error) {
strukt, isStruct := named.Underlying().(*types.Struct)
if isStruct {
// NOTE: a struct tag will override both methods and fields
// Bind to struct tag
found, err := b.findBindStructTagTarget(strukt, name)
if found != nil || err != nil {
return found, err
}
func (b *builder) findBindTargetRecur(t types.Type, name string) (types.Object, error) {
// NOTE: a struct tag will override both methods and fields
// Bind to struct tag
found, err := b.findBindStructTagTarget(t, name)
if found != nil || err != nil {
return found, err
}

// Search for a method to bind to
var foundMethod types.Object
for i := 0; i < named.NumMethods(); i++ {
method := named.Method(i)
if !method.Exported() || !strings.EqualFold(method.Name(), name) {
continue
}

if foundMethod != nil {
return nil, errors.Errorf("found more than one matching method to bind for %s", name)
}

foundMethod = method
foundMethod, err := b.findBindMethodTarget(t, name)
if err != nil {
return nil, err
}

// Search for a field to bind to
if isStruct {
foundField, err := b.findBindFieldTarget(strukt, name)
if err != nil {
return nil, err
}
foundField, err := b.findBindFieldTarget(t, name)
if err != nil {
return nil, err
}

switch {
case foundField == nil && foundMethod == nil:
// Search embeds
return b.findBindEmbedsTarget(strukt, name)
case foundField == nil && foundMethod != nil:
// Bind to method
return foundMethod, nil
case foundField != nil && foundMethod == nil:
// Bind to field
return foundField, nil
case foundField != nil && foundMethod != nil:
// Error
return nil, errors.Errorf("found more than one way to bind for %s", name)
}
switch {
case foundField == nil && foundMethod != nil:
// Bind to method
return foundMethod, nil
case foundField != nil && foundMethod == nil:
// Bind to field
return foundField, nil
case foundField != nil && foundMethod != nil:
// Error
return nil, errors.Errorf("found more than one way to bind for %s", name)
}

// Bind to method or don't bind at all
return foundMethod, nil
// Search embeds
return b.findBindEmbedsTarget(t, name)
}

func (b *builder) findBindStructTagTarget(strukt *types.Struct, name string) (types.Object, error) {
func (b *builder) findBindStructTagTarget(in types.Type, name string) (types.Object, error) {
if b.Config.StructTag == "" {
return nil, nil
}

var found types.Object
for i := 0; i < strukt.NumFields(); i++ {
field := strukt.Field(i)
if !field.Exported() || field.Embedded() {
continue
}
tags := reflect.StructTag(strukt.Tag(i))
if val, ok := tags.Lookup(b.Config.StructTag); ok && equalFieldName(val, name) {
if found != nil {
return nil, errors.Errorf("tag %s is ambigious; multiple fields have the same tag value of %s", b.Config.StructTag, val)
switch t := in.(type) {
case *types.Named:
return b.findBindStructTagTarget(t.Underlying(), name)
case *types.Struct:
var found types.Object
for i := 0; i < t.NumFields(); i++ {
field := t.Field(i)
if !field.Exported() || field.Embedded() {
continue
}
tags := reflect.StructTag(t.Tag(i))
if val, ok := tags.Lookup(b.Config.StructTag); ok && equalFieldName(val, name) {
if found != nil {
return nil, errors.Errorf("tag %s is ambigious; multiple fields have the same tag value of %s", b.Config.StructTag, val)
}

found = field
found = field
}
}

return found, nil
}

return found, nil
return nil, nil
}

func (b *builder) findBindMethodTarget(in types.Type, name string) (types.Object, error) {
switch t := in.(type) {
case *types.Named:
if _, ok := t.Underlying().(*types.Interface); ok {
return b.findBindMethodTarget(t.Underlying(), name)
}

return b.findBindMethoderTarget(t.Method, t.NumMethods(), name)
case *types.Interface:
// FIX-ME: Should use ExplicitMethod here? What's the difference?
return b.findBindMethoderTarget(t.Method, t.NumMethods(), name)
}

return nil, nil
}

func (b *builder) findBindFieldTarget(strukt *types.Struct, name string) (types.Object, error) {
func (b *builder) findBindMethoderTarget(methodFunc func(i int) *types.Func, methodCount int, name string) (types.Object, error) {
var found types.Object
for i := 0; i < strukt.NumFields(); i++ {
field := strukt.Field(i)
if !field.Exported() || !equalFieldName(field.Name(), name) {
for i := 0; i < methodCount; i++ {
method := methodFunc(i)
if !method.Exported() || !strings.EqualFold(method.Name(), name) {
continue
}

if found != nil {
return nil, errors.Errorf("found more than one matching field to bind for %s", name)
return nil, errors.Errorf("found more than one matching method to bind for %s", name)
}

found = field
found = method
}

return found, nil
}

func (b *builder) findBindEmbedsTarget(strukt *types.Struct, name string) (types.Object, error) {
func (b *builder) findBindFieldTarget(in types.Type, name string) (types.Object, error) {
switch t := in.(type) {
case *types.Named:
return b.findBindFieldTarget(t.Underlying(), name)
case *types.Struct:
var found types.Object
for i := 0; i < t.NumFields(); i++ {
field := t.Field(i)
if !field.Exported() || !equalFieldName(field.Name(), name) {
continue
}

if found != nil {
return nil, errors.Errorf("found more than one matching field to bind for %s", name)
}

found = field
}

return found, nil
}

return nil, nil
}

func (b *builder) findBindEmbedsTarget(in types.Type, name string) (types.Object, error) {
switch t := in.(type) {
case *types.Named:
return b.findBindEmbedsTarget(t.Underlying(), name)
case *types.Struct:
return b.findBindStructEmbedsTarget(t, name)
case *types.Interface:
return b.findBindInterfaceEmbedsTarget(t, name)
}

return nil, nil
}

func (b *builder) findBindStructEmbedsTarget(strukt *types.Struct, name string) (types.Object, error) {
var found types.Object
for i := 0; i < strukt.NumFields(); i++ {
field := strukt.Field(i)
Expand All @@ -297,41 +354,39 @@ func (b *builder) findBindEmbedsTarget(strukt *types.Struct, name string) (types
fieldType = ptr.Elem()
}

switch fieldType := fieldType.(type) {
case *types.Named:
f, err := b.findBindTarget(fieldType, name)
if err != nil {
return nil, err
}
if f != nil && found != nil {
return nil, errors.Errorf("found more than one way to bind for %s", name)
}
if f != nil {
found = f
}
case *types.Struct:
f, err := b.findBindStructTagTarget(fieldType, name)
if err != nil {
return nil, err
}
if f != nil && found != nil {
return nil, errors.Errorf("found more than one way to bind for %s", name)
}
if f != nil {
found = f
continue
}
f, err := b.findBindTargetRecur(fieldType, name)
if err != nil {
return nil, err
}

f, err = b.findBindFieldTarget(fieldType, name)
if err != nil {
return nil, err
}
if f != nil && found != nil {
return nil, errors.Errorf("found more than one way to bind for %s", name)
}
if f != nil {
found = f
}
if f != nil && found != nil {
return nil, errors.Errorf("found more than one way to bind for %s", name)
}

if f != nil {
found = f
}
}

return found, nil
}

func (b *builder) findBindInterfaceEmbedsTarget(iface *types.Interface, name string) (types.Object, error) {
var found types.Object
for i := 0; i < iface.NumEmbeddeds(); i++ {
embeddedType := iface.EmbeddedType(i)

f, err := b.findBindTargetRecur(embeddedType, name)
if err != nil {
return nil, err
}

if f != nil && found != nil {
return nil, errors.Errorf("found more than one way to bind for %s", name)
}

if f != nil {
found = f
}
}

Expand Down
13 changes: 13 additions & 0 deletions codegen/testserver/embedded.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,3 +28,16 @@ type unexportedEmbeddedPointer struct{}
func (*unexportedEmbeddedPointer) UnexportedEmbeddedPointerExportedMethod() string {
return "UnexportedEmbeddedPointerExportedMethodResponse"
}

// EmbeddedCase3 model
type EmbeddedCase3 struct {
unexportedEmbeddedInterface
}

type unexportedEmbeddedInterface interface {
nestedInterface
}

type nestedInterface interface {
UnexportedEmbeddedInterfaceExportedMethod() string
}
5 changes: 5 additions & 0 deletions codegen/testserver/embedded.graphql
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
extend type Query {
embeddedCase1: EmbeddedCase1
embeddedCase2: EmbeddedCase2
embeddedCase3: EmbeddedCase3
}

type EmbeddedCase1 @goModel(model:"testserver.EmbeddedCase1") {
Expand All @@ -10,3 +11,7 @@ type EmbeddedCase1 @goModel(model:"testserver.EmbeddedCase1") {
type EmbeddedCase2 @goModel(model:"testserver.EmbeddedCase2") {
unexportedEmbeddedPointerExportedMethod: String!
}

type EmbeddedCase3 @goModel(model:"testserver.EmbeddedCase3") {
unexportedEmbeddedInterfaceExportedMethod: String!
}
20 changes: 20 additions & 0 deletions codegen/testserver/embedded_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,12 @@ import (
"github.com/stretchr/testify/require"
)

type fakeUnexportedEmbeddedInterface struct{}

func (*fakeUnexportedEmbeddedInterface) UnexportedEmbeddedInterfaceExportedMethod() string {
return "UnexportedEmbeddedInterfaceExportedMethod"
}

func TestEmbedded(t *testing.T) {
resolver := &Stub{}
resolver.QueryResolver.EmbeddedCase1 = func(ctx context.Context) (*EmbeddedCase1, error) {
Expand All @@ -17,6 +23,9 @@ func TestEmbedded(t *testing.T) {
resolver.QueryResolver.EmbeddedCase2 = func(ctx context.Context) (*EmbeddedCase2, error) {
return &EmbeddedCase2{&unexportedEmbeddedPointer{}}, nil
}
resolver.QueryResolver.EmbeddedCase3 = func(ctx context.Context) (*EmbeddedCase3, error) {
return &EmbeddedCase3{&fakeUnexportedEmbeddedInterface{}}, nil
}

c := client.New(handler.GraphQL(
NewExecutableSchema(Config{Resolvers: resolver}),
Expand All @@ -43,4 +52,15 @@ func TestEmbedded(t *testing.T) {
require.NoError(t, err)
require.Equal(t, resp.EmbeddedCase2.UnexportedEmbeddedPointerExportedMethod, "UnexportedEmbeddedPointerExportedMethodResponse")
})

t.Run("embedded case 3", func(t *testing.T) {
var resp struct {
EmbeddedCase3 struct {
UnexportedEmbeddedInterfaceExportedMethod string
}
}
err := c.Post(`query { embeddedCase3 { unexportedEmbeddedInterfaceExportedMethod } }`, &resp)
require.NoError(t, err)
require.Equal(t, resp.EmbeddedCase3.UnexportedEmbeddedInterfaceExportedMethod, "UnexportedEmbeddedInterfaceExportedMethod")
})
}
Loading

0 comments on commit 70e860c

Please sign in to comment.