Skip to content

Commit

Permalink
refactor
Browse files Browse the repository at this point in the history
  • Loading branch information
neelance committed May 24, 2017
1 parent 0933d24 commit 4aff297
Showing 1 changed file with 67 additions and 58 deletions.
125 changes: 67 additions & 58 deletions internal/validation/validation.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,13 @@ func (c *context) addErrMultiLoc(locs []errors.Location, rule string, format str
})
}

type opContext struct {
*context
ops []*query.Operation
}

func Validate(s *schema.Schema, doc *query.Document) []*errors.QueryError {
c := context{
c := &context{
schema: s,
doc: doc,
opErrs: make(map[*query.Operation][]*errors.QueryError),
Expand All @@ -42,33 +47,35 @@ func Validate(s *schema.Schema, doc *query.Document) []*errors.QueryError {
opNames := make(nameSet)
fragUsedBy := make(map[*query.FragmentDecl][]*query.Operation)
for _, op := range doc.Operations {
opc := &opContext{c, []*query.Operation{op}}

if op.Name.Name == "" && len(doc.Operations) != 1 {
c.addErr(op.Loc, "LoneAnonymousOperation", "This anonymous operation must be the only defined operation.")
}
if op.Name.Name != "" {
c.validateName(opNames, op.Name, "UniqueOperationNames", "operation")
validateName(c, opNames, op.Name, "UniqueOperationNames", "operation")
}

c.validateDirectives(string(op.Type), op.Directives, nil)
validateDirectives(opc, string(op.Type), op.Directives)

varNames := make(nameSet)
for _, v := range op.Vars {
c.validateName(varNames, v.Name, "UniqueVariableNames", "variable")
validateName(c, varNames, v.Name, "UniqueVariableNames", "variable")

t := c.resolveType(v.Type)
t := resolveType(c, v.Type)
if !canBeInput(t) {
c.addErr(v.TypeLoc, "VariablesAreInputTypes", "Variable %q cannot be non-input type %q.", "$"+v.Name.Name, t)
}

if v.Default != nil {
c.validateLiteral(v.Default, nil)
validateLiteral(opc, v.Default)

if t != nil {
if nn, ok := t.(*common.NonNull); ok {
c.addErr(v.Default.Location(), "DefaultValuesOfCorrectType", "Variable %q of type %q is required and will not use the default value. Perhaps you meant to use type %q.", "$"+v.Name.Name, t, nn.OfType)
}

if ok, reason := c.validateValueType(v.Default, t, nil); !ok {
if ok, reason := validateValueType(opc, v.Default, t); !ok {
c.addErr(v.Default.Location(), "DefaultValuesOfCorrectType", "Variable %q of type %q has invalid default value %s.\n%s", "$"+v.Name.Name, t, v.Default, reason)
}
}
Expand All @@ -87,10 +94,10 @@ func Validate(s *schema.Schema, doc *query.Document) []*errors.QueryError {
panic("unreachable")
}

c.validateSelectionSet(op.SelSet, entryPoint, []*query.Operation{op})
validateSelectionSet(opc, op.SelSet, entryPoint)

fragUsed := make(map[*query.FragmentDecl]struct{})
c.markUsedFragments(op.SelSet, fragUsed)
markUsedFragments(c, op.SelSet, fragUsed)
for frag := range fragUsed {
fragUsedBy[frag] = append(fragUsedBy[frag], op)
}
Expand All @@ -99,20 +106,22 @@ func Validate(s *schema.Schema, doc *query.Document) []*errors.QueryError {
fragNames := make(nameSet)
fragVisited := make(map[*query.FragmentDecl]struct{})
for _, frag := range doc.Fragments {
c.validateName(fragNames, frag.Name, "UniqueFragmentNames", "fragment")
c.validateDirectives("FRAGMENT_DEFINITION", frag.Directives, nil)
opc := &opContext{c, fragUsedBy[frag]}

validateName(c, fragNames, frag.Name, "UniqueFragmentNames", "fragment")
validateDirectives(opc, "FRAGMENT_DEFINITION", frag.Directives)

t := c.resolveType(&frag.On)
t := resolveType(c, &frag.On)
// continue even if t is nil
if t != nil && !canBeFragment(t) {
c.addErr(frag.On.Loc, "FragmentsOnCompositeTypes", "Fragment %q cannot condition on non composite type %q.", frag.Name.Name, t)
continue
}

c.validateSelectionSet(frag.SelSet, t, fragUsedBy[frag])
validateSelectionSet(opc, frag.SelSet, t)

if _, ok := fragVisited[frag]; !ok {
c.detectFragmentCycle(frag.SelSet, fragVisited, nil, map[string]int{frag.Name.Name: 0})
detectFragmentCycle(c, frag.SelSet, fragVisited, nil, map[string]int{frag.Name.Name: 0})
}
}

Expand All @@ -129,16 +138,16 @@ func Validate(s *schema.Schema, doc *query.Document) []*errors.QueryError {
return c.errs
}

func (c *context) validateSelectionSet(selSet *query.SelectionSet, t common.Type, ops []*query.Operation) {
func validateSelectionSet(c *opContext, selSet *query.SelectionSet, t common.Type) {
for _, sel := range selSet.Selections {
c.validateSelection(sel, t, ops)
validateSelection(c, sel, t)
}
}

func (c *context) validateSelection(sel query.Selection, t common.Type, ops []*query.Operation) {
func validateSelection(c *opContext, sel query.Selection, t common.Type) {
switch sel := sel.(type) {
case *query.Field:
c.validateDirectives("FIELD", sel.Directives, ops)
validateDirectives(c, "FIELD", sel.Directives)

fieldName := sel.Name.Name
var f *schema.Field
Expand Down Expand Up @@ -172,9 +181,9 @@ func (c *context) validateSelection(sel query.Selection, t common.Type, ops []*q
}
}

c.validateArgumentLiterals(sel.Arguments, ops)
validateArgumentLiterals(c, sel.Arguments)
if f != nil {
c.validateArgumentTypes(sel.Arguments, f.Args, sel.Alias.Loc, ops,
validateArgumentTypes(c, sel.Arguments, f.Args, sel.Alias.Loc,
func() string { return fmt.Sprintf("field %q of type %q", fieldName, t) },
func() string { return fmt.Sprintf("Field %q", fieldName) },
)
Expand All @@ -192,13 +201,13 @@ func (c *context) validateSelection(sel query.Selection, t common.Type, ops []*q
}
}
if sel.SelSet != nil {
c.validateSelectionSet(sel.SelSet, unwrapType(ft), ops)
validateSelectionSet(c, sel.SelSet, unwrapType(ft))
}

case *query.InlineFragment:
c.validateDirectives("INLINE_FRAGMENT", sel.Directives, ops)
validateDirectives(c, "INLINE_FRAGMENT", sel.Directives)
if sel.On.Name != "" {
fragTyp := c.resolveType(&sel.On)
fragTyp := resolveType(c.context, &sel.On)
if fragTyp != nil && !compatible(t, fragTyp) {
c.addErr(sel.Loc, "PossibleFragmentSpreads", "Fragment cannot be spread here as objects of type %q can never be of type %q.", t, fragTyp)
}
Expand All @@ -209,10 +218,10 @@ func (c *context) validateSelection(sel query.Selection, t common.Type, ops []*q
c.addErr(sel.On.Loc, "FragmentsOnCompositeTypes", "Fragment cannot condition on non composite type %q.", t)
return
}
c.validateSelectionSet(sel.SelSet, unwrapType(t), ops)
validateSelectionSet(c, sel.SelSet, unwrapType(t))

case *query.FragmentSpread:
c.validateDirectives("FRAGMENT_SPREAD", sel.Directives, ops)
validateDirectives(c, "FRAGMENT_SPREAD", sel.Directives)
frag := c.doc.Fragments.Get(sel.Name.Name)
if frag == nil {
c.addErr(sel.Name.Loc, "KnownFragmentNames", "Unknown fragment %q.", sel.Name.Name)
Expand Down Expand Up @@ -252,16 +261,16 @@ func possibleTypes(t common.Type) []*schema.Object {
}
}

func (c *context) markUsedFragments(selSet *query.SelectionSet, fragUsed map[*query.FragmentDecl]struct{}) {
func markUsedFragments(c *context, selSet *query.SelectionSet, fragUsed map[*query.FragmentDecl]struct{}) {
for _, sel := range selSet.Selections {
switch sel := sel.(type) {
case *query.Field:
if sel.SelSet != nil {
c.markUsedFragments(sel.SelSet, fragUsed)
markUsedFragments(c, sel.SelSet, fragUsed)
}

case *query.InlineFragment:
c.markUsedFragments(sel.SelSet, fragUsed)
markUsedFragments(c, sel.SelSet, fragUsed)

case *query.FragmentSpread:
frag := c.doc.Fragments.Get(sel.Name.Name)
Expand All @@ -273,29 +282,29 @@ func (c *context) markUsedFragments(selSet *query.SelectionSet, fragUsed map[*qu
return
}
fragUsed[frag] = struct{}{}
c.markUsedFragments(frag.SelSet, fragUsed)
markUsedFragments(c, frag.SelSet, fragUsed)

default:
panic("unreachable")
}
}
}

func (c *context) detectFragmentCycle(selSet *query.SelectionSet, fragVisited map[*query.FragmentDecl]struct{}, spreadPath []*query.FragmentSpread, spreadPathIndex map[string]int) {
func detectFragmentCycle(c *context, selSet *query.SelectionSet, fragVisited map[*query.FragmentDecl]struct{}, spreadPath []*query.FragmentSpread, spreadPathIndex map[string]int) {
for _, sel := range selSet.Selections {
c.detectFragmentCycleSel(sel, fragVisited, spreadPath, spreadPathIndex)
detectFragmentCycleSel(c, sel, fragVisited, spreadPath, spreadPathIndex)
}
}

func (c *context) detectFragmentCycleSel(sel query.Selection, fragVisited map[*query.FragmentDecl]struct{}, spreadPath []*query.FragmentSpread, spreadPathIndex map[string]int) {
func detectFragmentCycleSel(c *context, sel query.Selection, fragVisited map[*query.FragmentDecl]struct{}, spreadPath []*query.FragmentSpread, spreadPathIndex map[string]int) {
switch sel := sel.(type) {
case *query.Field:
if sel.SelSet != nil {
c.detectFragmentCycle(sel.SelSet, fragVisited, spreadPath, spreadPathIndex)
detectFragmentCycle(c, sel.SelSet, fragVisited, spreadPath, spreadPathIndex)
}

case *query.InlineFragment:
c.detectFragmentCycle(sel.SelSet, fragVisited, spreadPath, spreadPathIndex)
detectFragmentCycle(c, sel.SelSet, fragVisited, spreadPath, spreadPathIndex)

case *query.FragmentSpread:
frag := c.doc.Fragments.Get(sel.Name.Name)
Expand Down Expand Up @@ -329,7 +338,7 @@ func (c *context) detectFragmentCycleSel(sel query.Selection, fragVisited map[*q
fragVisited[frag] = struct{}{}

spreadPathIndex[frag.Name.Name] = len(spreadPath)
c.detectFragmentCycle(frag.SelSet, fragVisited, spreadPath, spreadPathIndex)
detectFragmentCycle(c, frag.SelSet, fragVisited, spreadPath, spreadPathIndex)
delete(spreadPathIndex, frag.Name.Name)

default:
Expand Down Expand Up @@ -359,23 +368,23 @@ func unwrapType(t common.Type) common.Type {
}
}

func (c *context) resolveType(t common.Type) common.Type {
func resolveType(c *context, t common.Type) common.Type {
t2, err := common.ResolveType(t, c.schema.Resolve)
if err != nil {
c.errs = append(c.errs, err)
}
return t2
}

func (c *context) validateDirectives(loc string, directives common.DirectiveList, ops []*query.Operation) {
func validateDirectives(c *opContext, loc string, directives common.DirectiveList) {
directiveNames := make(nameSet)
for _, d := range directives {
dirName := d.Name.Name
c.validateNameCustomMsg(directiveNames, d.Name, "UniqueDirectivesPerLocation", func() string {
validateNameCustomMsg(c.context, directiveNames, d.Name, "UniqueDirectivesPerLocation", func() string {
return fmt.Sprintf("The directive %q can only be used once at this location.", dirName)
})

c.validateArgumentLiterals(d.Args, ops)
validateArgumentLiterals(c, d.Args)

dd, ok := c.schema.Directives[dirName]
if !ok {
Expand All @@ -394,7 +403,7 @@ func (c *context) validateDirectives(loc string, directives common.DirectiveList
c.addErr(d.Name.Loc, "KnownDirectives", "Directive %q may not be used on %s.", dirName, loc)
}

c.validateArgumentTypes(d.Args, dd.Args, d.Name.Loc, ops,
validateArgumentTypes(c, d.Args, dd.Args, d.Name.Loc,
func() string { return fmt.Sprintf("directive %q", "@"+dirName) },
func() string { return fmt.Sprintf("Directive %q", "@"+dirName) },
)
Expand All @@ -404,13 +413,13 @@ func (c *context) validateDirectives(loc string, directives common.DirectiveList

type nameSet map[string]errors.Location

func (c *context) validateName(set nameSet, name common.Ident, rule string, kind string) {
c.validateNameCustomMsg(set, name, rule, func() string {
func validateName(c *context, set nameSet, name common.Ident, rule string, kind string) {
validateNameCustomMsg(c, set, name, rule, func() string {
return fmt.Sprintf("There can be only one %s named %q.", kind, name.Name)
})
}

func (c *context) validateNameCustomMsg(set nameSet, name common.Ident, rule string, msg func() string) {
func validateNameCustomMsg(c *context, set nameSet, name common.Ident, rule string, msg func() string) {
if loc, ok := set[name.Name]; ok {
c.addErrMultiLoc([]errors.Location{loc, name.Loc}, rule, msg())
return
Expand All @@ -419,15 +428,15 @@ func (c *context) validateNameCustomMsg(set nameSet, name common.Ident, rule str
return
}

func (c *context) validateArgumentTypes(args common.ArgumentList, argDecls common.InputValueList, loc errors.Location, ops []*query.Operation, owner1, owner2 func() string) {
func validateArgumentTypes(c *opContext, args common.ArgumentList, argDecls common.InputValueList, loc errors.Location, owner1, owner2 func() string) {
for _, selArg := range args {
arg := argDecls.Get(selArg.Name.Name)
if arg == nil {
c.addErr(selArg.Name.Loc, "KnownArgumentNames", "Unknown argument %q on %s.", selArg.Name.Name, owner1())
continue
}
value := selArg.Value
if ok, reason := c.validateValueType(value, arg.Type, ops); !ok {
if ok, reason := validateValueType(c, value, arg.Type); !ok {
c.addErr(value.Location(), "ArgumentsOfCorrectType", "Argument %q has invalid value %s.\n%s", arg.Name.Name, value, reason)
}
}
Expand All @@ -440,28 +449,28 @@ func (c *context) validateArgumentTypes(args common.ArgumentList, argDecls commo
}
}

func (c *context) validateArgumentLiterals(args common.ArgumentList, ops []*query.Operation) {
func validateArgumentLiterals(c *opContext, args common.ArgumentList) {
argNames := make(nameSet)
for _, arg := range args {
c.validateName(argNames, arg.Name, "UniqueArgumentNames", "argument")
c.validateLiteral(arg.Value, ops)
validateName(c.context, argNames, arg.Name, "UniqueArgumentNames", "argument")
validateLiteral(c, arg.Value)
}
}

func (c *context) validateLiteral(l common.Literal, ops []*query.Operation) {
func validateLiteral(c *opContext, l common.Literal) {
switch l := l.(type) {
case *common.ObjectLit:
fieldNames := make(nameSet)
for _, f := range l.Fields {
c.validateName(fieldNames, f.Name, "UniqueInputFieldNames", "input field")
c.validateLiteral(f.Value, ops)
validateName(c.context, fieldNames, f.Name, "UniqueInputFieldNames", "input field")
validateLiteral(c, f.Value)
}
case *common.ListLit:
for _, entry := range l.Entries {
c.validateLiteral(entry, ops)
validateLiteral(c, entry)
}
case *common.Variable:
for _, op := range ops {
for _, op := range c.ops {
if op.Vars.Get(l.Name) == nil {
byOp := ""
if op.Name.Name != "" {
Expand All @@ -477,9 +486,9 @@ func (c *context) validateLiteral(l common.Literal, ops []*query.Operation) {
}
}

func (c *context) validateValueType(v common.Literal, t common.Type, ops []*query.Operation) (bool, string) {
func validateValueType(c *opContext, v common.Literal, t common.Type) (bool, string) {
if v, ok := v.(*common.Variable); ok {
for _, op := range ops {
for _, op := range c.ops {
if v2 := op.Vars.Get(v.Name); v != nil {
t2, err := common.ResolveType(v2.Type, c.schema.Resolve)
if _, ok := t2.(*common.NonNull); !ok && v2.Default != nil {
Expand Down Expand Up @@ -514,10 +523,10 @@ func (c *context) validateValueType(v common.Literal, t common.Type, ops []*quer
case *common.List:
list, ok := v.(*common.ListLit)
if !ok {
return c.validateValueType(v, t.OfType, ops) // single value instead of list
return validateValueType(c, v, t.OfType) // single value instead of list
}
for i, entry := range list.Entries {
if ok, reason := c.validateValueType(entry, t.OfType, ops); !ok {
if ok, reason := validateValueType(c, entry, t.OfType); !ok {
return false, fmt.Sprintf("In element #%d: %s", i, reason)
}
}
Expand All @@ -534,7 +543,7 @@ func (c *context) validateValueType(v common.Literal, t common.Type, ops []*quer
if iv == nil {
return false, fmt.Sprintf("In field %q: Unknown field.", name)
}
if ok, reason := c.validateValueType(f.Value, iv.Type, ops); !ok {
if ok, reason := validateValueType(c, f.Value, iv.Type); !ok {
return false, fmt.Sprintf("In field %q: %s", name, reason)
}
}
Expand Down

0 comments on commit 4aff297

Please sign in to comment.