diff --git a/graphql/executor/executor.go b/graphql/executor/executor.go index 0ace5ad38b5..d4f7ba7f786 100644 --- a/graphql/executor/executor.go +++ b/graphql/executor/executor.go @@ -13,13 +13,9 @@ import ( // Executor executes graphql queries against a schema. type Executor struct { - es graphql.ExecutableSchema - extensions []graphql.HandlerExtension - operationMiddleware graphql.OperationMiddleware - responseMiddleware graphql.ResponseMiddleware - fieldMiddleware graphql.FieldMiddleware - operationParameterMutators []graphql.OperationParameterMutator - operationContextMutators []graphql.OperationContextMutator + es graphql.ExecutableSchema + extensions []graphql.HandlerExtension + ext extensions errorPresenter graphql.ErrorPresenterFunc recoverFunc graphql.RecoverFunc @@ -36,8 +32,8 @@ func New(es graphql.ExecutableSchema) *Executor { errorPresenter: graphql.DefaultErrorPresenter, recoverFunc: graphql.DefaultRecover, queryCache: graphql.NoCache{}, + ext: processExtensions(nil), } - e.setExtensions() return e } @@ -63,7 +59,7 @@ func (e *Executor) CreateOperationContext(ctx context.Context, params *graphql.R rc := &graphql.OperationContext{ DisableIntrospection: true, Recover: e.recoverFunc, - ResolverMiddleware: e.fieldMiddleware, + ResolverMiddleware: e.ext.fieldMiddleware, Stats: graphql.Stats{ Read: params.ReadTime, OperationStart: graphql.GetStartTime(ctx), @@ -71,7 +67,7 @@ func (e *Executor) CreateOperationContext(ctx context.Context, params *graphql.R } ctx = graphql.WithOperationContext(ctx, rc) - for _, p := range e.operationParameterMutators { + for _, p := range e.ext.operationParameterMutators { if err := p.MutateOperationParameters(ctx, params); err != nil { return rc, gqlerror.List{err} } @@ -99,7 +95,7 @@ func (e *Executor) CreateOperationContext(ctx context.Context, params *graphql.R } rc.Stats.Validation.End = graphql.Now() - for _, p := range e.operationContextMutators { + for _, p := range e.ext.operationContextMutators { if err := p.MutateOperationContext(ctx, rc); err != nil { return rc, gqlerror.List{err} } @@ -112,7 +108,7 @@ func (e *Executor) DispatchOperation(ctx context.Context, rc *graphql.OperationC ctx = graphql.WithOperationContext(ctx, rc) var innerCtx context.Context - res := e.operationMiddleware(ctx, func(ctx context.Context) graphql.ResponseHandler { + res := e.ext.operationMiddleware(ctx, func(ctx context.Context) graphql.ResponseHandler { innerCtx = ctx tmpResponseContext := graphql.WithResponseContext(ctx, e.errorPresenter, e.recoverFunc) @@ -123,7 +119,7 @@ func (e *Executor) DispatchOperation(ctx context.Context, rc *graphql.OperationC return func(ctx context.Context) *graphql.Response { ctx = graphql.WithResponseContext(ctx, e.errorPresenter, e.recoverFunc) - resp := e.responseMiddleware(ctx, func(ctx context.Context) *graphql.Response { + resp := e.ext.responseMiddleware(ctx, func(ctx context.Context) *graphql.Response { resp := responses(ctx) if resp == nil { return nil @@ -149,7 +145,7 @@ func (e *Executor) DispatchError(ctx context.Context, list gqlerror.List) *graph graphql.AddError(ctx, gErr) } - resp := e.responseMiddleware(ctx, func(ctx context.Context) *graphql.Response { + resp := e.ext.responseMiddleware(ctx, func(ctx context.Context) *graphql.Response { resp := &graphql.Response{ Errors: list, } diff --git a/graphql/executor/executor_test.go b/graphql/executor/executor_test.go index 4fb7dce8235..fc4207effbb 100644 --- a/graphql/executor/executor_test.go +++ b/graphql/executor/executor_test.go @@ -69,6 +69,25 @@ func TestExecutor(t *testing.T) { assert.Equal(t, []string{"first", "second"}, calls) }) + t.Run("invokes operation mutators", func(t *testing.T) { + var calls []string + exec.Use(&testParamMutator{ + Mutate: func(ctx context.Context, req *graphql.RawParams) *gqlerror.Error { + calls = append(calls, "param") + return nil + }, + }) + exec.Use(&testCtxMutator{ + Mutate: func(ctx context.Context, rc *graphql.OperationContext) *gqlerror.Error { + calls = append(calls, "context") + return nil + }, + }) + resp := query(exec, "", "{name}") + assert.Equal(t, `{"name":"test"}`, string(resp.Data)) + assert.Equal(t, []string{"param", "context"}, calls) + }) + t.Run("get query parse error in AroundResponses", func(t *testing.T) { var errors1 gqlerror.List var errors2 gqlerror.List @@ -116,6 +135,38 @@ func TestExecutor(t *testing.T) { }) } +type testParamMutator struct { + Mutate func(context.Context, *graphql.RawParams) *gqlerror.Error +} + +func (m *testParamMutator) ExtensionName() string { + return "Operation: Mutate Parameters" +} + +func (m *testParamMutator) Validate(s graphql.ExecutableSchema) error { + return nil +} + +func (m *testParamMutator) MutateOperationParameters(ctx context.Context, r *graphql.RawParams) *gqlerror.Error { + return m.Mutate(ctx, r) +} + +type testCtxMutator struct { + Mutate func(context.Context, *graphql.OperationContext) *gqlerror.Error +} + +func (m *testCtxMutator) ExtensionName() string { + return "Operation: Mutate the Context" +} + +func (m *testCtxMutator) Validate(s graphql.ExecutableSchema) error { + return nil +} + +func (m *testCtxMutator) MutateOperationContext(ctx context.Context, rc *graphql.OperationContext) *gqlerror.Error { + return m.Mutate(ctx, rc) +} + func TestErrorServer(t *testing.T) { exec := testexecutor.NewError() diff --git a/graphql/executor/extensions.go b/graphql/executor/extensions.go index 4593e7e6238..d737f49a789 100644 --- a/graphql/executor/extensions.go +++ b/graphql/executor/extensions.go @@ -20,7 +20,7 @@ func (e *Executor) Use(extension graphql.HandlerExtension) { graphql.FieldInterceptor, graphql.ResponseInterceptor: e.extensions = append(e.extensions, extension) - e.setExtensions() + e.ext = processExtensions(e.extensions) default: panic(fmt.Errorf("cannot Use %T as a gqlgen handler extension because it does not implement any extension hooks", extension)) @@ -42,7 +42,15 @@ func (e *Executor) AroundResponses(f graphql.ResponseMiddleware) { e.Use(aroundRespFunc(f)) } -func (e *Executor) setExtensions() { +type extensions struct { + operationMiddleware graphql.OperationMiddleware + responseMiddleware graphql.ResponseMiddleware + fieldMiddleware graphql.FieldMiddleware + operationParameterMutators []graphql.OperationParameterMutator + operationContextMutators []graphql.OperationContextMutator +} + +func processExtensions(extensions []graphql.HandlerExtension) (e extensions) { e.operationMiddleware = func(ctx context.Context, next graphql.OperationHandler) graphql.ResponseHandler { return next(ctx) } @@ -54,8 +62,8 @@ func (e *Executor) setExtensions() { } // this loop goes backwards so the first extension is the outer most middleware and runs first. - for i := len(e.extensions) - 1; i >= 0; i-- { - p := e.extensions[i] + for i := len(extensions) - 1; i >= 0; i-- { + p := extensions[i] if p, ok := p.(graphql.OperationInterceptor); ok { previous := e.operationMiddleware e.operationMiddleware = func(ctx context.Context, next graphql.OperationHandler) graphql.ResponseHandler { @@ -84,7 +92,7 @@ func (e *Executor) setExtensions() { } } - for _, p := range e.extensions { + for _, p := range extensions { if p, ok := p.(graphql.OperationParameterMutator); ok { e.operationParameterMutators = append(e.operationParameterMutators, p) } @@ -93,6 +101,8 @@ func (e *Executor) setExtensions() { e.operationContextMutators = append(e.operationContextMutators, p) } } + + return } type aroundOpFunc func(ctx context.Context, next graphql.OperationHandler) graphql.ResponseHandler