Skip to content

Commit

Permalink
Merge remote-tracking branch 'origin/master' into enforce-request-con…
Browse files Browse the repository at this point in the history
…tent-type
  • Loading branch information
vektah committed May 8, 2019
2 parents f7d0b9c + f8ef6d2 commit d4b3de3
Show file tree
Hide file tree
Showing 39 changed files with 4,674 additions and 227 deletions.
6 changes: 3 additions & 3 deletions codegen/complexity.go
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
package codegen

func (o *Object) UniqueFields() map[string]*Field {
m := map[string]*Field{}
func (o *Object) UniqueFields() map[string][]*Field {
m := map[string][]*Field{}

for _, f := range o.Fields {
m[f.GoFieldName] = f
m[f.GoFieldName] = append(m[f.GoFieldName], f)
}

return m
Expand Down
6 changes: 4 additions & 2 deletions codegen/config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -363,8 +363,10 @@ func (c *Config) InjectBuiltins(s *ast.Schema) {

// These are additional types that are injected if defined in the schema as scalars.
extraBuiltins := TypeMap{
"Time": {Model: StringList{"github.com/99designs/gqlgen/graphql.Time"}},
"Map": {Model: StringList{"github.com/99designs/gqlgen/graphql.Map"}},
"Time": {Model: StringList{"github.com/99designs/gqlgen/graphql.Time"}},
"Map": {Model: StringList{"github.com/99designs/gqlgen/graphql.Map"}},
"Upload": {Model: StringList{"github.com/99designs/gqlgen/graphql.Upload"}},
"Any": {Model: StringList{"github.com/99designs/gqlgen/graphql.Any"}},
}

for typeName, entry := range extraBuiltins {
Expand Down
32 changes: 19 additions & 13 deletions codegen/generated!.gotpl
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,8 @@ type ComplexityRoot struct {
{{ range $object := .Objects }}
{{ if not $object.IsReserved -}}
{{ $object.Name|go }} struct {
{{ range $field := $object.UniqueFields -}}
{{ range $_, $fields := $object.UniqueFields }}
{{- $field := index $fields 0 -}}
{{ if not $field.IsReserved -}}
{{ $field.GoFieldName }} {{ $field.ComplexitySignature }}
{{ end }}
Expand Down Expand Up @@ -84,20 +85,25 @@ func (e *executableSchema) Complexity(typeName, field string, childComplexity in
switch typeName + "." + field {
{{ range $object := .Objects }}
{{ if not $object.IsReserved }}
{{ range $field := $object.UniqueFields }}
{{ if not $field.IsReserved }}
case "{{$object.Name}}.{{$field.GoFieldName}}":
if e.complexity.{{$object.Name|go}}.{{$field.GoFieldName}} == nil {
break
}
{{ if $field.Args }}
args, err := ec.{{ $field.ArgsFunc }}(context.TODO(),rawArgs)
if err != nil {
return 0, false
{{ range $_, $fields := $object.UniqueFields }}
{{- $len := len $fields }}
{{- range $i, $field := $fields }}
{{- $last := eq (add $i 1) $len }}
{{- if not $field.IsReserved }}
{{- if eq $i 0 }}case {{ end }}"{{$object.Name}}.{{$field.Name}}"{{ if not $last }},{{ else }}:
if e.complexity.{{$object.Name|go}}.{{$field.GoFieldName}} == nil {
break
}
{{ if $field.Args }}
args, err := ec.{{ $field.ArgsFunc }}(context.TODO(),rawArgs)
if err != nil {
return 0, false
}
{{ end }}
return e.complexity.{{$object.Name|go}}.{{$field.GoFieldName}}(childComplexity{{if $field.Args}}, {{$field.ComplexityArgs}} {{ end }}), true
{{ end }}
return e.complexity.{{$object.Name|go}}.{{$field.GoFieldName}}(childComplexity{{if $field.Args}}, {{$field.ComplexityArgs}} {{end}}), true
{{ end }}
{{- end }}
{{- end }}
{{ end }}
{{ end }}
{{ end }}
Expand Down
71 changes: 71 additions & 0 deletions codegen/testserver/complexity_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,5 +42,76 @@ func TestComplexityCollisions(t *testing.T) {
require.Equal(t, 2, resp.Overlapping.OldFoo)
require.Equal(t, 3, resp.Overlapping.NewFoo)
require.Equal(t, 3, resp.Overlapping.New_foo)
}

func TestComplexityFuncs(t *testing.T) {
resolvers := &Stub{}
cfg := Config{Resolvers: resolvers}
cfg.Complexity.OverlappingFields.Foo = func(childComplexity int) int { return 1000 }
cfg.Complexity.OverlappingFields.NewFoo = func(childComplexity int) int { return 5 }

srv := httptest.NewServer(handler.GraphQL(NewExecutableSchema(cfg), handler.ComplexityLimit(10)))
c := client.New(srv.URL)

resolvers.QueryResolver.Overlapping = func(ctx context.Context) (fields *OverlappingFields, e error) {
return &OverlappingFields{
Foo: 2,
NewFoo: 3,
}, nil
}

t.Run("with high complexity limit will not run", func(t *testing.T) {
ran := false
resolvers.OverlappingFieldsResolver.OldFoo = func(ctx context.Context, obj *OverlappingFields) (i int, e error) {
ran = true
return obj.Foo, nil
}

var resp struct {
Overlapping interface{}
}
err := c.Post(`query { overlapping { oneFoo, twoFoo, oldFoo, newFoo, new_foo } }`, &resp)

require.EqualError(t, err, `http 422: {"errors":[{"message":"operation has complexity 2012, which exceeds the limit of 10"}],"data":null}`)
require.False(t, ran)
})

t.Run("with low complexity will run", func(t *testing.T) {
ran := false
resolvers.QueryResolver.Overlapping = func(ctx context.Context) (fields *OverlappingFields, e error) {
ran = true
return &OverlappingFields{
Foo: 2,
NewFoo: 3,
}, nil
}

var resp struct {
Overlapping interface{}
}
c.MustPost(`query { overlapping { newFoo } }`, &resp)

require.True(t, ran)
})

t.Run("with multiple low complexity will not run", func(t *testing.T) {
ran := false
resolvers.QueryResolver.Overlapping = func(ctx context.Context) (fields *OverlappingFields, e error) {
ran = true
return &OverlappingFields{
Foo: 2,
NewFoo: 3,
}, nil
}

var resp interface{}
err := c.Post(`query {
a: overlapping { newFoo },
b: overlapping { newFoo },
c: overlapping { newFoo },
}`, &resp)

require.EqualError(t, err, `http 422: {"errors":[{"message":"operation has complexity 18, which exceeds the limit of 10"}],"data":null}`)
require.False(t, ran)
})
}
Loading

0 comments on commit d4b3de3

Please sign in to comment.