From cecd84c698b8ce171e0cd9604215405de248e765 Mon Sep 17 00:00:00 2001 From: Evan Shaw Date: Mon, 27 Aug 2018 10:45:54 +1200 Subject: [PATCH] Add complexity package tests Also some small behavior fixes to complexity calculations. --- complexity/complexity.go | 31 ++--- complexity/complexity_test.go | 228 ++++++++++++++++++++++++++++++++++ 2 files changed, 240 insertions(+), 19 deletions(-) create mode 100644 complexity/complexity_test.go diff --git a/complexity/complexity.go b/complexity/complexity.go index 0868acb2b3b..6e7c671a0cb 100644 --- a/complexity/complexity.go +++ b/complexity/complexity.go @@ -10,16 +10,7 @@ func Calculate(es graphql.ExecutableSchema, op *ast.OperationDefinition, vars ma es: es, vars: vars, } - typeName := "" - switch op.Operation { - case ast.Query: - typeName = es.Schema().Query.Name - case ast.Mutation: - typeName = es.Schema().Mutation.Name - case ast.Subscription: - typeName = es.Schema().Subscription.Name - } - return walker.selectionSetComplexity(typeName, op.SelectionSet) + return walker.selectionSetComplexity(op.SelectionSet) } type complexityWalker struct { @@ -27,7 +18,7 @@ type complexityWalker struct { vars map[string]interface{} } -func (cw complexityWalker) selectionSetComplexity(typeName string, selectionSet ast.SelectionSet) int { +func (cw complexityWalker) selectionSetComplexity(selectionSet ast.SelectionSet) int { var complexity int for _, selection := range selectionSet { switch s := selection.(type) { @@ -35,11 +26,11 @@ func (cw complexityWalker) selectionSetComplexity(typeName string, selectionSet var childComplexity int switch s.ObjectDefinition.Kind { case ast.Object, ast.Interface, ast.Union: - childComplexity = cw.selectionSetComplexity(s.ObjectDefinition.Name, s.SelectionSet) + childComplexity = cw.selectionSetComplexity(s.SelectionSet) } args := s.ArgumentMap(cw.vars) - if customComplexity, ok := cw.es.Complexity(typeName, s.Name, childComplexity, args); ok { + if customComplexity, ok := cw.es.Complexity(s.ObjectDefinition.Name, s.Name, childComplexity, args); ok && customComplexity >= childComplexity { complexity = safeAdd(complexity, customComplexity) } else { // default complexity calculation @@ -47,15 +38,17 @@ func (cw complexityWalker) selectionSetComplexity(typeName string, selectionSet } case *ast.FragmentSpread: - complexity = safeAdd(complexity, cw.selectionSetComplexity(typeName, s.Definition.SelectionSet)) + complexity = safeAdd(complexity, cw.selectionSetComplexity(s.Definition.SelectionSet)) case *ast.InlineFragment: - complexity = safeAdd(complexity, cw.selectionSetComplexity(typeName, s.SelectionSet)) + complexity = safeAdd(complexity, cw.selectionSetComplexity(s.SelectionSet)) } } return complexity } +const maxInt = int(^uint(0) >> 1) + // safeAdd is a saturating add of a and b that ignores negative operands. // If a + b would overflow through normal Go addition, // it returns the maximum integer value instead. @@ -67,19 +60,19 @@ func (cw complexityWalker) selectionSetComplexity(typeName string, selectionSet // return negative values. func safeAdd(a, b int) int { // Ignore negative operands. - if a <= 0 { + if a < 0 { if b < 0 { - return 0 + return 1 } return b - } else if b <= 0 { + } else if b < 0 { return a } c := a + b if c < a { // Set c to maximum integer instead of overflowing. - c = int(^uint(0) >> 1) + c = maxInt } return c } diff --git a/complexity/complexity_test.go b/complexity/complexity_test.go new file mode 100644 index 00000000000..f5a512d758d --- /dev/null +++ b/complexity/complexity_test.go @@ -0,0 +1,228 @@ +package complexity + +import ( + "context" + "math" + "testing" + + "github.com/99designs/gqlgen/graphql" + "github.com/stretchr/testify/require" + "github.com/vektah/gqlparser" + "github.com/vektah/gqlparser/ast" +) + +var schema = gqlparser.MustLoadSchema( + &ast.Source{ + Name: "test.graphql", + Input: ` + interface NameInterface { + name: String + } + + type Item implements NameInterface { + scalar: String + name: String + list(size: Int = 10): [Item] + } + + type Named { + name: String + } + + union NameUnion = Item | Named + + type Query { + scalar: String + object: Item + interface: NameInterface + union: NameUnion + customObject: Item + list(size: Int = 10): [Item] + } + `, + }, +) + +func requireComplexity(t *testing.T, source string, vars map[string]interface{}, complexity int) { + t.Helper() + query := gqlparser.MustLoadQuery(schema, source) + es := &executableSchemaStub{} + actualComplexity := Calculate(es, query.Operations[0], vars) + require.Equal(t, complexity, actualComplexity) +} + +func TestCalculate(t *testing.T) { + t.Run("uses default cost", func(t *testing.T) { + const query = ` + { + scalar + } + ` + requireComplexity(t, query, nil, 1) + }) + + t.Run("adds together fields", func(t *testing.T) { + const query = ` + { + scalar1: scalar + scalar2: scalar + } + ` + requireComplexity(t, query, nil, 2) + }) + + t.Run("a level of nesting adds complexity", func(t *testing.T) { + const query = ` + { + object { + scalar + } + } + ` + requireComplexity(t, query, nil, 2) + }) + + t.Run("adds together children", func(t *testing.T) { + const query = ` + { + scalar + object { + scalar + } + } + ` + requireComplexity(t, query, nil, 3) + }) + + t.Run("adds inline fragments", func(t *testing.T) { + const query = ` + { + ... { + scalar + } + } + ` + requireComplexity(t, query, nil, 1) + }) + + t.Run("adds fragments", func(t *testing.T) { + const query = ` + { + ... Fragment + } + + fragment Fragment on Query { + scalar + } + ` + requireComplexity(t, query, nil, 1) + }) + + t.Run("uses custom complexity", func(t *testing.T) { + const query = ` + { + list { + scalar + } + } + ` + requireComplexity(t, query, nil, 10) + }) + + t.Run("ignores negative custom complexity values", func(t *testing.T) { + const query = ` + { + list(size: -100) { + scalar + } + } + ` + requireComplexity(t, query, nil, 2) + }) + + t.Run("custom complexity must be >= child complexity", func(t *testing.T) { + const query = ` + { + customObject { + list(size: 100) { + scalar + } + } + } + ` + requireComplexity(t, query, nil, 101) + }) + + t.Run("interfaces have different costs than concrete types", func(t *testing.T) { + const query = ` + { + interface { + name + } + } + ` + requireComplexity(t, query, nil, 6) + }) + + t.Run("guards against integer overflow", func(t *testing.T) { + if maxInt == math.MaxInt32 { + // this test is written assuming 64-bit ints + t.Skip() + } + const query = ` + { + list1: list(size: 2147483647) { + list(size: 2147483647) { + list(size: 2) { + scalar + } + } + } + # total cost so far: 2*0x7fffffff*0x7fffffff + # = 0x7ffffffe00000002 + # Adding the same again should cause overflow + list2: list(size: 2147483647) { + list(size: 2147483647) { + list(size: 2) { + scalar + } + } + } + } + ` + requireComplexity(t, query, nil, math.MaxInt64) + }) +} + +type executableSchemaStub struct { +} + +var _ graphql.ExecutableSchema = &executableSchemaStub{} + +func (e *executableSchemaStub) Schema() *ast.Schema { + return schema +} + +func (e *executableSchemaStub) Complexity(typeName, field string, childComplexity int, args map[string]interface{}) (int, bool) { + switch typeName + "." + field { + case "Query.list", "Item.list": + return int(args["size"].(int64)) * childComplexity, true + case "Query.customObject": + return 1, true + case "NameInterface.name": + return 5, true + } + return 0, false +} + +func (e *executableSchemaStub) Query(ctx context.Context, op *ast.OperationDefinition) *graphql.Response { + panic("Query should never be called by complexity calculations") +} + +func (e *executableSchemaStub) Mutation(ctx context.Context, op *ast.OperationDefinition) *graphql.Response { + panic("Mutation should never be called by complexity calculations") +} + +func (e *executableSchemaStub) Subscription(ctx context.Context, op *ast.OperationDefinition) func() *graphql.Response { + panic("Subscription should never be called by complexity calculations") +}