Skip to content

Commit

Permalink
graphql/executor: reinit all extension values on every Use() call
Browse files Browse the repository at this point in the history
  • Loading branch information
technoweenie committed Mar 4, 2020
1 parent f3909a8 commit 9ae6bc0
Show file tree
Hide file tree
Showing 3 changed files with 76 additions and 19 deletions.
24 changes: 10 additions & 14 deletions graphql/executor/executor.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,9 @@ import (

// Executor executes graphql queries against a schema.
type Executor struct {
es graphql.ExecutableSchema
extensions []graphql.HandlerExtension
operationMiddleware graphql.OperationMiddleware
responseMiddleware graphql.ResponseMiddleware
fieldMiddleware graphql.FieldMiddleware
operationParameterMutators []graphql.OperationParameterMutator
operationContextMutators []graphql.OperationContextMutator
es graphql.ExecutableSchema
extensions []graphql.HandlerExtension
ext extensions

errorPresenter graphql.ErrorPresenterFunc
recoverFunc graphql.RecoverFunc
Expand All @@ -36,8 +32,8 @@ func New(es graphql.ExecutableSchema) *Executor {
errorPresenter: graphql.DefaultErrorPresenter,
recoverFunc: graphql.DefaultRecover,
queryCache: graphql.NoCache{},
ext: processExtensions(nil),
}
e.setExtensions()
return e
}

Expand All @@ -63,15 +59,15 @@ func (e *Executor) CreateOperationContext(ctx context.Context, params *graphql.R
rc := &graphql.OperationContext{
DisableIntrospection: true,
Recover: e.recoverFunc,
ResolverMiddleware: e.fieldMiddleware,
ResolverMiddleware: e.ext.fieldMiddleware,
Stats: graphql.Stats{
Read: params.ReadTime,
OperationStart: graphql.GetStartTime(ctx),
},
}
ctx = graphql.WithOperationContext(ctx, rc)

for _, p := range e.operationParameterMutators {
for _, p := range e.ext.operationParameterMutators {
if err := p.MutateOperationParameters(ctx, params); err != nil {
return rc, gqlerror.List{err}
}
Expand Down Expand Up @@ -99,7 +95,7 @@ func (e *Executor) CreateOperationContext(ctx context.Context, params *graphql.R
}
rc.Stats.Validation.End = graphql.Now()

for _, p := range e.operationContextMutators {
for _, p := range e.ext.operationContextMutators {
if err := p.MutateOperationContext(ctx, rc); err != nil {
return rc, gqlerror.List{err}
}
Expand All @@ -112,7 +108,7 @@ func (e *Executor) DispatchOperation(ctx context.Context, rc *graphql.OperationC
ctx = graphql.WithOperationContext(ctx, rc)

var innerCtx context.Context
res := e.operationMiddleware(ctx, func(ctx context.Context) graphql.ResponseHandler {
res := e.ext.operationMiddleware(ctx, func(ctx context.Context) graphql.ResponseHandler {
innerCtx = ctx

tmpResponseContext := graphql.WithResponseContext(ctx, e.errorPresenter, e.recoverFunc)
Expand All @@ -123,7 +119,7 @@ func (e *Executor) DispatchOperation(ctx context.Context, rc *graphql.OperationC

return func(ctx context.Context) *graphql.Response {
ctx = graphql.WithResponseContext(ctx, e.errorPresenter, e.recoverFunc)
resp := e.responseMiddleware(ctx, func(ctx context.Context) *graphql.Response {
resp := e.ext.responseMiddleware(ctx, func(ctx context.Context) *graphql.Response {
resp := responses(ctx)
if resp == nil {
return nil
Expand All @@ -149,7 +145,7 @@ func (e *Executor) DispatchError(ctx context.Context, list gqlerror.List) *graph
graphql.AddError(ctx, gErr)
}

resp := e.responseMiddleware(ctx, func(ctx context.Context) *graphql.Response {
resp := e.ext.responseMiddleware(ctx, func(ctx context.Context) *graphql.Response {
resp := &graphql.Response{
Errors: list,
}
Expand Down
51 changes: 51 additions & 0 deletions graphql/executor/executor_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,25 @@ func TestExecutor(t *testing.T) {
assert.Equal(t, []string{"first", "second"}, calls)
})

t.Run("invokes operation mutators", func(t *testing.T) {
var calls []string
exec.Use(&testParamMutator{
Mutate: func(ctx context.Context, req *graphql.RawParams) *gqlerror.Error {
calls = append(calls, "param")
return nil
},
})
exec.Use(&testCtxMutator{
Mutate: func(ctx context.Context, rc *graphql.OperationContext) *gqlerror.Error {
calls = append(calls, "context")
return nil
},
})
resp := query(exec, "", "{name}")
assert.Equal(t, `{"name":"test"}`, string(resp.Data))
assert.Equal(t, []string{"param", "context"}, calls)
})

t.Run("get query parse error in AroundResponses", func(t *testing.T) {
var errors1 gqlerror.List
var errors2 gqlerror.List
Expand Down Expand Up @@ -116,6 +135,38 @@ func TestExecutor(t *testing.T) {
})
}

type testParamMutator struct {
Mutate func(context.Context, *graphql.RawParams) *gqlerror.Error
}

func (m *testParamMutator) ExtensionName() string {
return "Operation: Mutate Parameters"
}

func (m *testParamMutator) Validate(s graphql.ExecutableSchema) error {
return nil
}

func (m *testParamMutator) MutateOperationParameters(ctx context.Context, r *graphql.RawParams) *gqlerror.Error {
return m.Mutate(ctx, r)
}

type testCtxMutator struct {
Mutate func(context.Context, *graphql.OperationContext) *gqlerror.Error
}

func (m *testCtxMutator) ExtensionName() string {
return "Operation: Mutate the Context"
}

func (m *testCtxMutator) Validate(s graphql.ExecutableSchema) error {
return nil
}

func (m *testCtxMutator) MutateOperationContext(ctx context.Context, rc *graphql.OperationContext) *gqlerror.Error {
return m.Mutate(ctx, rc)
}

func TestErrorServer(t *testing.T) {
exec := testexecutor.NewError()

Expand Down
20 changes: 15 additions & 5 deletions graphql/executor/extensions.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ func (e *Executor) Use(extension graphql.HandlerExtension) {
graphql.FieldInterceptor,
graphql.ResponseInterceptor:
e.extensions = append(e.extensions, extension)
e.setExtensions()
e.ext = processExtensions(e.extensions)

default:
panic(fmt.Errorf("cannot Use %T as a gqlgen handler extension because it does not implement any extension hooks", extension))
Expand All @@ -42,7 +42,15 @@ func (e *Executor) AroundResponses(f graphql.ResponseMiddleware) {
e.Use(aroundRespFunc(f))
}

func (e *Executor) setExtensions() {
type extensions struct {
operationMiddleware graphql.OperationMiddleware
responseMiddleware graphql.ResponseMiddleware
fieldMiddleware graphql.FieldMiddleware
operationParameterMutators []graphql.OperationParameterMutator
operationContextMutators []graphql.OperationContextMutator
}

func processExtensions(extensions []graphql.HandlerExtension) (e extensions) {
e.operationMiddleware = func(ctx context.Context, next graphql.OperationHandler) graphql.ResponseHandler {
return next(ctx)
}
Expand All @@ -54,8 +62,8 @@ func (e *Executor) setExtensions() {
}

// this loop goes backwards so the first extension is the outer most middleware and runs first.
for i := len(e.extensions) - 1; i >= 0; i-- {
p := e.extensions[i]
for i := len(extensions) - 1; i >= 0; i-- {
p := extensions[i]
if p, ok := p.(graphql.OperationInterceptor); ok {
previous := e.operationMiddleware
e.operationMiddleware = func(ctx context.Context, next graphql.OperationHandler) graphql.ResponseHandler {
Expand Down Expand Up @@ -84,7 +92,7 @@ func (e *Executor) setExtensions() {
}
}

for _, p := range e.extensions {
for _, p := range extensions {
if p, ok := p.(graphql.OperationParameterMutator); ok {
e.operationParameterMutators = append(e.operationParameterMutators, p)
}
Expand All @@ -93,6 +101,8 @@ func (e *Executor) setExtensions() {
e.operationContextMutators = append(e.operationContextMutators, p)
}
}

return
}

type aroundOpFunc func(ctx context.Context, next graphql.OperationHandler) graphql.ResponseHandler
Expand Down

0 comments on commit 9ae6bc0

Please sign in to comment.