Skip to content

Commit

Permalink
simpler way to resolve type refs in schema
Browse files Browse the repository at this point in the history
  • Loading branch information
neelance committed Oct 23, 2016
1 parent 7cbf85f commit 042e306
Showing 1 changed file with 118 additions and 62 deletions.
180 changes: 118 additions & 62 deletions internal/schema/schema.go
Original file line number Diff line number Diff line change
Expand Up @@ -67,14 +67,19 @@ type NonNull struct {
OfType Type
}

func (Scalar) isType() {}
func (Object) isType() {}
func (Interface) isType() {}
func (Union) isType() {}
func (Enum) isType() {}
func (InputObject) isType() {}
func (List) isType() {}
func (NonNull) isType() {}
type typeRef struct {
name string
}

func (*Scalar) isType() {}
func (*Object) isType() {}
func (*Interface) isType() {}
func (*Union) isType() {}
func (*Enum) isType() {}
func (*InputObject) isType() {}
func (*List) isType() {}
func (*NonNull) isType() {}
func (*typeRef) isType() {}

type Field struct {
Name string
Expand All @@ -89,36 +94,24 @@ type InputValue struct {
Default interface{}
}

type typeRef struct {
name string
target *Type
}

type context struct {
typeRefs []*typeRef
}

func Parse(schemaString string) (s *Schema, err *errors.GraphQLError) {
sc := &scanner.Scanner{
Mode: scanner.ScanIdents | scanner.ScanInts | scanner.ScanFloats | scanner.ScanStrings,
}
sc.Init(strings.NewReader(schemaString))

c := &context{}
l := lexer.New(sc)
err = l.CatchSyntaxError(func() {
s = parseSchema(l, c)
s = parseSchema(l)
})
if err != nil {
return nil, err
}

for _, ref := range c.typeRefs {
t, ok := s.Types[ref.name]
if !ok {
return nil, errors.Errorf("type %q not found", ref.name)
for _, t := range s.Types {
if err := resolveType(s, t); err != nil {
return nil, err
}
*ref.target = t
}

for _, obj := range s.objects {
Expand Down Expand Up @@ -155,7 +148,78 @@ func Parse(schemaString string) (s *Schema, err *errors.GraphQLError) {
return s, nil
}

func parseSchema(l *lexer.Lexer, c *context) *Schema {
func resolveType(s *Schema, t Type) *errors.GraphQLError {
var err *errors.GraphQLError
switch t := t.(type) {
case *Scalar:
// nothing
case *Object:
for _, f := range t.Fields {
if err := resolveField(s, f); err != nil {
return err
}
}
case *Interface:
for _, f := range t.Fields {
if err := resolveField(s, f); err != nil {
return err
}
}
case *Union:
// nothing
case *Enum:
// nothing
case *InputObject:
for _, f := range t.InputFields {
f.Type, err = resolveTypeRef(s, f.Type)
if err != nil {
return err
}
}
case *List:
t.OfType, err = resolveTypeRef(s, t.OfType)
if err != nil {
return err
}
case *NonNull:
t.OfType, err = resolveTypeRef(s, t.OfType)
if err != nil {
return err
}
default:
panic("unreachable")
}
return nil
}

func resolveField(s *Schema, f *Field) *errors.GraphQLError {
var err *errors.GraphQLError
f.Type, err = resolveTypeRef(s, f.Type)
if err != nil {
return err
}
for _, arg := range f.Args {
arg.Type, err = resolveTypeRef(s, arg.Type)
if err != nil {
return err
}
}
return nil
}

func resolveTypeRef(s *Schema, t Type) (Type, *errors.GraphQLError) {
if ref, ok := t.(*typeRef); ok {
refT, ok := s.Types[ref.name]
if !ok {
return nil, errors.Errorf("type %q not found", ref.name)
}
return refT, nil
}
resolveType(s, t)
return t, nil
}

func parseSchema(l *lexer.Lexer) *Schema {
s := &Schema{
EntryPoints: make(map[string]string),
Types: map[string]Type{
Expand All @@ -179,21 +243,21 @@ func parseSchema(l *lexer.Lexer, c *context) *Schema {
}
l.ConsumeToken('}')
case "type":
obj := parseObjectDecl(l, c)
obj := parseObjectDecl(l)
s.Types[obj.Name] = obj
s.objects = append(s.objects, obj)
case "interface":
intf := parseInterfaceDecl(l, c)
intf := parseInterfaceDecl(l)
s.Types[intf.Name] = intf
case "union":
union := parseUnionDecl(l, c)
union := parseUnionDecl(l)
s.Types[union.Name] = union
s.unions = append(s.unions, union)
case "enum":
enum := parseEnumDecl(l, c)
enum := parseEnumDecl(l)
s.Types[enum.Name] = enum
case "input":
input := parseInputDecl(l, c)
input := parseInputDecl(l)
s.Types[input.Name] = input
default:
l.SyntaxError(fmt.Sprintf(`unexpected %q, expecting "schema", "type", "enum", "interface", "union" or "input"`, x))
Expand All @@ -203,7 +267,7 @@ func parseSchema(l *lexer.Lexer, c *context) *Schema {
return s
}

func parseObjectDecl(l *lexer.Lexer, c *context) *Object {
func parseObjectDecl(l *lexer.Lexer) *Object {
o := &Object{}
o.Name = l.ConsumeIdent()
if l.Peek() == scanner.Ident {
Expand All @@ -216,21 +280,21 @@ func parseObjectDecl(l *lexer.Lexer, c *context) *Object {
}
}
l.ConsumeToken('{')
o.Fields, o.FieldOrder = parseFields(l, c)
o.Fields, o.FieldOrder = parseFields(l)
l.ConsumeToken('}')
return o
}

func parseInterfaceDecl(l *lexer.Lexer, c *context) *Interface {
func parseInterfaceDecl(l *lexer.Lexer) *Interface {
i := &Interface{}
i.Name = l.ConsumeIdent()
l.ConsumeToken('{')
i.Fields, i.FieldOrder = parseFields(l, c)
i.Fields, i.FieldOrder = parseFields(l)
l.ConsumeToken('}')
return i
}

func parseUnionDecl(l *lexer.Lexer, c *context) *Union {
func parseUnionDecl(l *lexer.Lexer) *Union {
union := &Union{}
union.Name = l.ConsumeIdent()
l.ConsumeToken('=')
Expand All @@ -242,22 +306,22 @@ func parseUnionDecl(l *lexer.Lexer, c *context) *Union {
return union
}

func parseInputDecl(l *lexer.Lexer, c *context) *InputObject {
func parseInputDecl(l *lexer.Lexer) *InputObject {
i := &InputObject{
InputFields: make(map[string]*InputValue),
}
i.Name = l.ConsumeIdent()
l.ConsumeToken('{')
for l.Peek() != '}' {
v := parseInputValue(l, c)
v := parseInputValue(l)
i.InputFields[v.Name] = v
i.InputFieldOrder = append(i.InputFieldOrder, v.Name)
}
l.ConsumeToken('}')
return i
}

func parseEnumDecl(l *lexer.Lexer, c *context) *Enum {
func parseEnumDecl(l *lexer.Lexer) *Enum {
enum := &Enum{}
enum.Name = l.ConsumeIdent()
l.ConsumeToken('{')
Expand All @@ -268,7 +332,7 @@ func parseEnumDecl(l *lexer.Lexer, c *context) *Enum {
return enum
}

func parseFields(l *lexer.Lexer, c *context) (map[string]*Field, []string) {
func parseFields(l *lexer.Lexer) (map[string]*Field, []string) {
fields := make(map[string]*Field)
var fieldOrder []string
for l.Peek() != '}' {
Expand All @@ -278,58 +342,50 @@ func parseFields(l *lexer.Lexer, c *context) (map[string]*Field, []string) {
f.Args = make(map[string]*InputValue)
l.ConsumeToken('(')
for l.Peek() != ')' {
v := parseInputValue(l, c)
v := parseInputValue(l)
f.Args[v.Name] = v
f.ArgOrder = append(f.ArgOrder, v.Name)
}
l.ConsumeToken(')')
}
l.ConsumeToken(':')
parseType(&f.Type, l, c)
f.Type = parseType(l)
fields[f.Name] = f
fieldOrder = append(fieldOrder, f.Name)
}
return fields, fieldOrder
}

func parseInputValue(l *lexer.Lexer, c *context) *InputValue {
func parseInputValue(l *lexer.Lexer) *InputValue {
p := &InputValue{}
p.Name = l.ConsumeIdent()
l.ConsumeToken(':')
parseType(&p.Type, l, c)
p.Type = parseType(l)
if l.Peek() == '=' {
l.ConsumeToken('=')
p.Default = parseValue(l)
}
return p
}

func parseType(target *Type, l *lexer.Lexer, c *context) {
parseNonNil := func() {
if l.Peek() == '!' {
l.ConsumeToken('!')
nn := &NonNull{}
*target = nn
target = &nn.OfType
}
func parseType(l *lexer.Lexer) Type {
t := parseNullType(l)
if l.Peek() == '!' {
l.ConsumeToken('!')
return &NonNull{OfType: t}
}
return t
}

func parseNullType(l *lexer.Lexer) Type {
if l.Peek() == '[' {
l.ConsumeToken('[')
t := &List{}
parseType(&t.OfType, l, c)
ofType := parseType(l)
l.ConsumeToken(']')
parseNonNil()
*target = t
return
return &List{OfType: ofType}
}

name := l.ConsumeIdent()
parseNonNil()
c.typeRefs = append(c.typeRefs, &typeRef{
name: name,
target: target,
})
return &typeRef{name: l.ConsumeIdent()}
}

func parseValue(l *lexer.Lexer) interface{} {
Expand Down

0 comments on commit 042e306

Please sign in to comment.