Skip to content

Commit

Permalink
Implement bc shim for old handler package
Browse files Browse the repository at this point in the history
  • Loading branch information
vektah committed Nov 1, 2019
1 parent 631142c commit 473a0d2
Show file tree
Hide file tree
Showing 5 changed files with 315 additions and 29 deletions.
30 changes: 30 additions & 0 deletions graphql/handler/lru/lru.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
package lru

import (
"github.com/99designs/gqlgen/graphql"
lru "github.com/hashicorp/golang-lru"
)

type LRU struct {
lru *lru.Cache
}

var _ graphql.Cache = &LRU{}

func New(size int) *LRU {
cache, err := lru.New(size)
if err != nil {
// An error is only returned for non-positive cache size
// and we already checked for that.
panic("unexpected error creating cache: " + err.Error())
}
return &LRU{cache}
}

func (l LRU) Get(key string) (value interface{}, ok bool) {
return l.lru.Get(key)
}

func (l LRU) Add(key string, value interface{}) {
l.lru.Add(key, value)
}
41 changes: 34 additions & 7 deletions graphql/handler/server.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package handler

import (
"context"
"encoding/json"
"fmt"
"net/http"
Expand Down Expand Up @@ -64,6 +65,21 @@ func (s *Server) Use(extension graphql.HandlerExtension) {
}
}

// AroundFields is a convenience method for creating an extension that only implements field middleware
func (s *Server) AroundFields(f graphql.FieldMiddleware) {
s.Use(FieldFunc(f))
}

// AroundOperations is a convenience method for creating an extension that only implements operation middleware
func (s *Server) AroundOperations(f graphql.OperationMiddleware) {
s.Use(OperationFunc(f))
}

// AroundResponses is a convenience method for creating an extension that only implements response middleware
func (s *Server) AroundResponses(f graphql.ResponseMiddleware) {
s.Use(ResponseFunc(f))
}

func (s *Server) getTransport(r *http.Request) graphql.Transport {
for _, t := range s.transports {
if t.Supports(r) {
Expand All @@ -85,13 +101,6 @@ func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) {
transport.Do(w, r, s.exec)
}

func getStatus(resp *graphql.Response) graphql.Status {
if len(resp.Errors) > 0 {
return graphql.StatusResolverError
}
return graphql.StatusOk
}

func sendError(w http.ResponseWriter, code int, errors ...*gqlerror.Error) {
w.WriteHeader(code)
b, err := json.Marshal(&graphql.Response{Errors: errors})
Expand All @@ -104,3 +113,21 @@ func sendError(w http.ResponseWriter, code int, errors ...*gqlerror.Error) {
func sendErrorf(w http.ResponseWriter, code int, format string, args ...interface{}) {
sendError(w, code, &gqlerror.Error{Message: fmt.Sprintf(format, args...)})
}

type OperationFunc func(ctx context.Context, next graphql.OperationHandler) graphql.ResponseHandler

func (r OperationFunc) InterceptOperation(ctx context.Context, next graphql.OperationHandler) graphql.ResponseHandler {
return r(ctx, next)
}

type ResponseFunc func(ctx context.Context, next graphql.ResponseHandler) *graphql.Response

func (r ResponseFunc) InterceptResponse(ctx context.Context, next graphql.ResponseHandler) *graphql.Response {
return r(ctx, next)
}

type FieldFunc func(ctx context.Context, next graphql.Resolver) (res interface{}, err error)

func (f FieldFunc) InterceptField(ctx context.Context, next graphql.Resolver) (res interface{}, err error) {
return f(ctx, next)
}
44 changes: 24 additions & 20 deletions graphql/handler/server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,14 +46,30 @@ func TestServer(t *testing.T) {

t.Run("invokes operation middleware in order", func(t *testing.T) {
var calls []string
srv.Use(opFunc(func(ctx context.Context, next graphql.OperationHandler) graphql.ResponseHandler {
srv.AroundOperations(func(ctx context.Context, next graphql.OperationHandler) graphql.ResponseHandler {
calls = append(calls, "first")
return next(ctx)
}))
srv.Use(opFunc(func(ctx context.Context, next graphql.OperationHandler) graphql.ResponseHandler {
})
srv.AroundOperations(func(ctx context.Context, next graphql.OperationHandler) graphql.ResponseHandler {
calls = append(calls, "second")
return next(ctx)
}))
})

resp := get(srv, "/foo?query={name}")
assert.Equal(t, http.StatusOK, resp.Code, resp.Body.String())
assert.Equal(t, []string{"first", "second"}, calls)
})

t.Run("invokes response middleware in order", func(t *testing.T) {
var calls []string
srv.AroundResponses(func(ctx context.Context, next graphql.ResponseHandler) *graphql.Response {
calls = append(calls, "first")
return next(ctx)
})
srv.AroundResponses(func(ctx context.Context, next graphql.ResponseHandler) *graphql.Response {
calls = append(calls, "second")
return next(ctx)
})

resp := get(srv, "/foo?query={name}")
assert.Equal(t, http.StatusOK, resp.Code, resp.Body.String())
Expand All @@ -62,14 +78,14 @@ func TestServer(t *testing.T) {

t.Run("invokes field middleware in order", func(t *testing.T) {
var calls []string
srv.Use(fieldFunc(func(ctx context.Context, next graphql.Resolver) (res interface{}, err error) {
srv.AroundFields(func(ctx context.Context, next graphql.Resolver) (res interface{}, err error) {
calls = append(calls, "first")
return next(ctx)
}))
srv.Use(fieldFunc(func(ctx context.Context, next graphql.Resolver) (res interface{}, err error) {
})
srv.AroundFields(func(ctx context.Context, next graphql.Resolver) (res interface{}, err error) {
calls = append(calls, "second")
return next(ctx)
}))
})

resp := get(srv, "/foo?query={name}")
assert.Equal(t, http.StatusOK, resp.Code, resp.Body.String())
Expand Down Expand Up @@ -108,18 +124,6 @@ func TestServer(t *testing.T) {

}

type opFunc func(ctx context.Context, next graphql.OperationHandler) graphql.ResponseHandler

func (r opFunc) InterceptOperation(ctx context.Context, next graphql.OperationHandler) graphql.ResponseHandler {
return r(ctx, next)
}

type fieldFunc func(ctx context.Context, next graphql.Resolver) (res interface{}, err error)

func (f fieldFunc) InterceptField(ctx context.Context, next graphql.Resolver) (res interface{}, err error) {
return f(ctx, next)
}

func get(handler http.Handler, target string) *httptest.ResponseRecorder {
r := httptest.NewRequest("GET", target, nil)
w := httptest.NewRecorder()
Expand Down
4 changes: 2 additions & 2 deletions graphql/handler/transport/websocket.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ const (
type (
Websocket struct {
Upgrader websocket.Upgrader
InitFunc websocketInitFunc
InitFunc WebsocketInitFunc
KeepAlivePingInterval time.Duration
}
wsConnection struct {
Expand All @@ -50,7 +50,7 @@ type (
ID string `json:"id,omitempty"`
Type string `json:"type"`
}
websocketInitFunc func(ctx context.Context, initPayload InitPayload) (context.Context, error)
WebsocketInitFunc func(ctx context.Context, initPayload InitPayload) (context.Context, error)
)

var _ graphql.Transport = Websocket{}
Expand Down
Loading

0 comments on commit 473a0d2

Please sign in to comment.