Skip to content

Commit

Permalink
Add complexity package tests
Browse files Browse the repository at this point in the history
Also some small behavior fixes to complexity calculations.
  • Loading branch information
edsrzf committed Aug 27, 2018
1 parent 556b93a commit cecd84c
Show file tree
Hide file tree
Showing 2 changed files with 240 additions and 19 deletions.
31 changes: 12 additions & 19 deletions complexity/complexity.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,52 +10,45 @@ func Calculate(es graphql.ExecutableSchema, op *ast.OperationDefinition, vars ma
es: es,
vars: vars,
}
typeName := ""
switch op.Operation {
case ast.Query:
typeName = es.Schema().Query.Name
case ast.Mutation:
typeName = es.Schema().Mutation.Name
case ast.Subscription:
typeName = es.Schema().Subscription.Name
}
return walker.selectionSetComplexity(typeName, op.SelectionSet)
return walker.selectionSetComplexity(op.SelectionSet)
}

type complexityWalker struct {
es graphql.ExecutableSchema
vars map[string]interface{}
}

func (cw complexityWalker) selectionSetComplexity(typeName string, selectionSet ast.SelectionSet) int {
func (cw complexityWalker) selectionSetComplexity(selectionSet ast.SelectionSet) int {
var complexity int
for _, selection := range selectionSet {
switch s := selection.(type) {
case *ast.Field:
var childComplexity int
switch s.ObjectDefinition.Kind {
case ast.Object, ast.Interface, ast.Union:
childComplexity = cw.selectionSetComplexity(s.ObjectDefinition.Name, s.SelectionSet)
childComplexity = cw.selectionSetComplexity(s.SelectionSet)
}

args := s.ArgumentMap(cw.vars)
if customComplexity, ok := cw.es.Complexity(typeName, s.Name, childComplexity, args); ok {
if customComplexity, ok := cw.es.Complexity(s.ObjectDefinition.Name, s.Name, childComplexity, args); ok && customComplexity >= childComplexity {
complexity = safeAdd(complexity, customComplexity)
} else {
// default complexity calculation
complexity = safeAdd(complexity, safeAdd(1, childComplexity))
}

case *ast.FragmentSpread:
complexity = safeAdd(complexity, cw.selectionSetComplexity(typeName, s.Definition.SelectionSet))
complexity = safeAdd(complexity, cw.selectionSetComplexity(s.Definition.SelectionSet))

case *ast.InlineFragment:
complexity = safeAdd(complexity, cw.selectionSetComplexity(typeName, s.SelectionSet))
complexity = safeAdd(complexity, cw.selectionSetComplexity(s.SelectionSet))
}
}
return complexity
}

const maxInt = int(^uint(0) >> 1)

// safeAdd is a saturating add of a and b that ignores negative operands.
// If a + b would overflow through normal Go addition,
// it returns the maximum integer value instead.
Expand All @@ -67,19 +60,19 @@ func (cw complexityWalker) selectionSetComplexity(typeName string, selectionSet
// return negative values.
func safeAdd(a, b int) int {
// Ignore negative operands.
if a <= 0 {
if a < 0 {
if b < 0 {
return 0
return 1
}
return b
} else if b <= 0 {
} else if b < 0 {
return a
}

c := a + b
if c < a {
// Set c to maximum integer instead of overflowing.
c = int(^uint(0) >> 1)
c = maxInt
}
return c
}
228 changes: 228 additions & 0 deletions complexity/complexity_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,228 @@
package complexity

import (
"context"
"math"
"testing"

"github.com/99designs/gqlgen/graphql"
"github.com/stretchr/testify/require"
"github.com/vektah/gqlparser"
"github.com/vektah/gqlparser/ast"
)

var schema = gqlparser.MustLoadSchema(
&ast.Source{
Name: "test.graphql",
Input: `
interface NameInterface {
name: String
}
type Item implements NameInterface {
scalar: String
name: String
list(size: Int = 10): [Item]
}
type Named {
name: String
}
union NameUnion = Item | Named
type Query {
scalar: String
object: Item
interface: NameInterface
union: NameUnion
customObject: Item
list(size: Int = 10): [Item]
}
`,
},
)

func requireComplexity(t *testing.T, source string, vars map[string]interface{}, complexity int) {
t.Helper()
query := gqlparser.MustLoadQuery(schema, source)
es := &executableSchemaStub{}
actualComplexity := Calculate(es, query.Operations[0], vars)
require.Equal(t, complexity, actualComplexity)
}

func TestCalculate(t *testing.T) {
t.Run("uses default cost", func(t *testing.T) {
const query = `
{
scalar
}
`
requireComplexity(t, query, nil, 1)
})

t.Run("adds together fields", func(t *testing.T) {
const query = `
{
scalar1: scalar
scalar2: scalar
}
`
requireComplexity(t, query, nil, 2)
})

t.Run("a level of nesting adds complexity", func(t *testing.T) {
const query = `
{
object {
scalar
}
}
`
requireComplexity(t, query, nil, 2)
})

t.Run("adds together children", func(t *testing.T) {
const query = `
{
scalar
object {
scalar
}
}
`
requireComplexity(t, query, nil, 3)
})

t.Run("adds inline fragments", func(t *testing.T) {
const query = `
{
... {
scalar
}
}
`
requireComplexity(t, query, nil, 1)
})

t.Run("adds fragments", func(t *testing.T) {
const query = `
{
... Fragment
}
fragment Fragment on Query {
scalar
}
`
requireComplexity(t, query, nil, 1)
})

t.Run("uses custom complexity", func(t *testing.T) {
const query = `
{
list {
scalar
}
}
`
requireComplexity(t, query, nil, 10)
})

t.Run("ignores negative custom complexity values", func(t *testing.T) {
const query = `
{
list(size: -100) {
scalar
}
}
`
requireComplexity(t, query, nil, 2)
})

t.Run("custom complexity must be >= child complexity", func(t *testing.T) {
const query = `
{
customObject {
list(size: 100) {
scalar
}
}
}
`
requireComplexity(t, query, nil, 101)
})

t.Run("interfaces have different costs than concrete types", func(t *testing.T) {
const query = `
{
interface {
name
}
}
`
requireComplexity(t, query, nil, 6)
})

t.Run("guards against integer overflow", func(t *testing.T) {
if maxInt == math.MaxInt32 {
// this test is written assuming 64-bit ints
t.Skip()
}
const query = `
{
list1: list(size: 2147483647) {
list(size: 2147483647) {
list(size: 2) {
scalar
}
}
}
# total cost so far: 2*0x7fffffff*0x7fffffff
# = 0x7ffffffe00000002
# Adding the same again should cause overflow
list2: list(size: 2147483647) {
list(size: 2147483647) {
list(size: 2) {
scalar
}
}
}
}
`
requireComplexity(t, query, nil, math.MaxInt64)
})
}

type executableSchemaStub struct {
}

var _ graphql.ExecutableSchema = &executableSchemaStub{}

func (e *executableSchemaStub) Schema() *ast.Schema {
return schema
}

func (e *executableSchemaStub) Complexity(typeName, field string, childComplexity int, args map[string]interface{}) (int, bool) {
switch typeName + "." + field {
case "Query.list", "Item.list":
return int(args["size"].(int64)) * childComplexity, true
case "Query.customObject":
return 1, true
case "NameInterface.name":
return 5, true
}
return 0, false
}

func (e *executableSchemaStub) Query(ctx context.Context, op *ast.OperationDefinition) *graphql.Response {
panic("Query should never be called by complexity calculations")
}

func (e *executableSchemaStub) Mutation(ctx context.Context, op *ast.OperationDefinition) *graphql.Response {
panic("Mutation should never be called by complexity calculations")
}

func (e *executableSchemaStub) Subscription(ctx context.Context, op *ast.OperationDefinition) func() *graphql.Response {
panic("Subscription should never be called by complexity calculations")
}

0 comments on commit cecd84c

Please sign in to comment.