Skip to content

Commit

Permalink
convert APQ to middleware
Browse files Browse the repository at this point in the history
  • Loading branch information
vektah committed Oct 1, 2019
1 parent da98618 commit 311887d
Show file tree
Hide file tree
Showing 10 changed files with 252 additions and 144 deletions.
70 changes: 70 additions & 0 deletions graphql/handler/apq.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
package handler

import (
"context"
"crypto/sha256"
"encoding/hex"

"github.com/99designs/gqlgen/graphql"
"github.com/mitchellh/mapstructure"
)

const (
errPersistedQueryNotSupported = "PersistedQueryNotSupported"
errPersistedQueryNotFound = "PersistedQueryNotFound"
)

// AutomaticPersistedQuery saves client upload by optimistically sending only the hashes of queries, if the server
// does not yet know what the query is for the hash it will respond telling the client to send the query along with the
// hash in the next request.
// see https://github.com/apollographql/apollo-link-persisted-queries
func AutomaticPersistedQuery(cache Cache) Middleware {
return func(next Handler) Handler {
return func(ctx context.Context, writer Writer) {
rc := graphql.GetRequestContext(ctx)

if rc.Extensions["persistedQuery"] == nil {
next(ctx, writer)
return
}

var extension struct {
Sha256 string `json:"sha256Hash"`
Version int64 `json:"version"`
}

if err := mapstructure.Decode(rc.Extensions["persistedQuery"], &extension); err != nil {
writer.Error("Invalid APQ extension data")
return
}

if extension.Version != 1 {
writer.Error("Unsupported APQ version")
return
}

if rc.RawQuery == "" {
// client sent optimistic query hash without query string, get it from the cache
query, ok := cache.Get(extension.Sha256)
if !ok {
writer.Error(errPersistedQueryNotFound)
return
}
rc.RawQuery = query.(string)
} else {
// client sent optimistic query hash with query string, verify and store it
if computeQueryHash(rc.RawQuery) != extension.Sha256 {
writer.Error("Provided APQ hash does not match query")
return
}
cache.Add(extension.Sha256, rc.RawQuery)
}
next(ctx, writer)
}
}
}

func computeQueryHash(query string) string {
b := sha256.Sum256([]byte(query))
return hex.EncodeToString(b[:])
}
128 changes: 128 additions & 0 deletions graphql/handler/apq_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,128 @@
package handler

import (
"testing"

"github.com/99designs/gqlgen/graphql"
"github.com/stretchr/testify/require"
)

func TestAPQ(t *testing.T) {
const query = "{ me { name } }"
const hash = "b8d9506e34c83b0e53c2aa463624fcea354713bc38f95276e6f0bd893ffb5b88"

t.Run("with query and no hash", func(t *testing.T) {
rc := testMiddleware(AutomaticPersistedQuery(MapCache{}), graphql.RequestContext{
RawQuery: "original query",
})

require.True(t, rc.InvokedNext)
require.Equal(t, "original query", rc.ResultContext.RawQuery)
})

t.Run("with hash miss and no query", func(t *testing.T) {
rc := testMiddleware(AutomaticPersistedQuery(MapCache{}), graphql.RequestContext{
RawQuery: "",
Extensions: map[string]interface{}{
"persistedQuery": map[string]interface{}{
"sha256": hash,
"version": 1,
},
},
})

require.False(t, rc.InvokedNext)
require.Equal(t, "PersistedQueryNotFound", rc.Response.Errors[0].Message)
})

t.Run("with hash miss and query", func(t *testing.T) {
cache := MapCache{}
rc := testMiddleware(AutomaticPersistedQuery(cache), graphql.RequestContext{
RawQuery: query,
Extensions: map[string]interface{}{
"persistedQuery": map[string]interface{}{
"sha256": hash,
"version": 1,
},
},
})

require.True(t, rc.InvokedNext, rc.Response.Errors)
require.Equal(t, "{ me { name } }", rc.ResultContext.RawQuery)
require.Equal(t, "{ me { name } }", cache[hash])
})

t.Run("with hash miss and query", func(t *testing.T) {
cache := MapCache{}
rc := testMiddleware(AutomaticPersistedQuery(cache), graphql.RequestContext{
RawQuery: query,
Extensions: map[string]interface{}{
"persistedQuery": map[string]interface{}{
"sha256": hash,
"version": 1,
},
},
})

require.True(t, rc.InvokedNext, rc.Response.Errors)
require.Equal(t, "{ me { name } }", rc.ResultContext.RawQuery)
require.Equal(t, "{ me { name } }", cache[hash])
})

t.Run("with hash hit and no query", func(t *testing.T) {
cache := MapCache{
hash: query,
}
rc := testMiddleware(AutomaticPersistedQuery(cache), graphql.RequestContext{
RawQuery: "",
Extensions: map[string]interface{}{
"persistedQuery": map[string]interface{}{
"sha256": hash,
"version": 1,
},
},
})

require.True(t, rc.InvokedNext, rc.Response.Errors)
require.Equal(t, "{ me { name } }", rc.ResultContext.RawQuery)
})

t.Run("with malformed extension payload", func(t *testing.T) {
rc := testMiddleware(AutomaticPersistedQuery(MapCache{}), graphql.RequestContext{
Extensions: map[string]interface{}{
"persistedQuery": "asdf",
},
})

require.False(t, rc.InvokedNext)
require.Equal(t, "Invalid APQ extension data", rc.Response.Errors[0].Message)
})

t.Run("with invalid extension version", func(t *testing.T) {
rc := testMiddleware(AutomaticPersistedQuery(MapCache{}), graphql.RequestContext{
Extensions: map[string]interface{}{
"persistedQuery": map[string]interface{}{
"version": 2,
},
},
})

require.False(t, rc.InvokedNext)
require.Equal(t, "Unsupported APQ version", rc.Response.Errors[0].Message)
})

t.Run("with hash mismatch", func(t *testing.T) {
rc := testMiddleware(AutomaticPersistedQuery(MapCache{}), graphql.RequestContext{
RawQuery: query,
Extensions: map[string]interface{}{
"persistedQuery": map[string]interface{}{
"sha256": "badhash",
"version": 1,
},
},
})

require.False(t, rc.InvokedNext)
require.Equal(t, "Provided APQ hash does not match query", rc.Response.Errors[0].Message)
})
}
24 changes: 24 additions & 0 deletions graphql/handler/cache.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
package handler

// Cache is a shared store for APQ and query AST caching
type Cache interface {
// Get looks up a key's value from the cache.
Get(key string) (value interface{}, ok bool)

// Add adds a value to the cache.
Add(key, value string)
}

// MapCache is the simplest implementation of a cache, because it can not evict it should only be used in tests
type MapCache map[string]interface{}

// Get looks up a key's value from the cache.
func (m MapCache) Get(key string) (value interface{}, ok bool) {
v, ok := m[key]
return v, ok
}

// Add adds a value to the cache.
func (m MapCache) Add(key, value string) {
m[key] = value
}
4 changes: 2 additions & 2 deletions graphql/handler/complexity_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ func TestComplexityLimit(t *testing.T) {
}))

require.True(t, rc.InvokedNext)
require.Equal(t, 10, rc.ComplexityLimit)
require.Equal(t, 10, rc.ResultContext.ComplexityLimit)
}

func TestComplexityLimitFunc(t *testing.T) {
Expand All @@ -22,5 +22,5 @@ func TestComplexityLimitFunc(t *testing.T) {
}))

require.True(t, rc.InvokedNext)
require.Equal(t, 22, rc.ComplexityLimit)
require.Equal(t, 22, rc.ResultContext.ComplexityLimit)
}
4 changes: 2 additions & 2 deletions graphql/handler/errors_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ func TestErrorPresenter(t *testing.T) {

require.True(t, rc.InvokedNext)
// cant test for function equality in go, so testing the return type instead
require.Equal(t, "boom", rc.ErrorPresenter(nil, nil).Message)
require.Equal(t, "boom", rc.ResultContext.ErrorPresenter(nil, nil).Message)
}

func TestRecoverFunc(t *testing.T) {
Expand All @@ -28,5 +28,5 @@ func TestRecoverFunc(t *testing.T) {

require.True(t, rc.InvokedNext)
// cant test for function equality in go, so testing the return type instead
assert.Equal(t, "boom", rc.Recover(nil, nil).Error())
assert.Equal(t, "boom", rc.ResultContext.Recover(nil, nil).Error())
}
2 changes: 1 addition & 1 deletion graphql/handler/introspection_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,5 +15,5 @@ func TestIntrospection(t *testing.T) {

require.True(t, rc.InvokedNext)
// cant test for function equality in go, so testing the return type instead
assert.False(t, rc.DisableIntrospection)
assert.False(t, rc.ResultContext.DisableIntrospection)
}
12 changes: 12 additions & 0 deletions graphql/handler/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,18 @@ type (
ResponseStream func() *graphql.Response
)

func (w Writer) Errorf(format string, args ...interface{}) {
w(&graphql.Response{
Errors: gqlerror.List{{Message: fmt.Sprintf(format, args...)}},
})
}

func (w Writer) Error(msg string) {
w(&graphql.Response{
Errors: gqlerror.List{{Message: msg}},
})
}

func (s *Server) AddTransport(transport Transport) {
s.transports = append(s.transports, transport)
}
Expand Down
24 changes: 12 additions & 12 deletions graphql/handler/utils_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,24 +7,24 @@ import (
)

type middlewareContext struct {
*graphql.RequestContext
InvokedNext bool
InvokedNext bool
ResultContext graphql.RequestContext
Response graphql.Response
}

func testMiddleware(m Middleware, initialContexts ...graphql.RequestContext) middlewareContext {
rc := &graphql.RequestContext{}
var c middlewareContext
initial := &graphql.RequestContext{}
if len(initialContexts) > 0 {
rc = &initialContexts[0]
initial = &initialContexts[0]
}

m(func(ctx context.Context, writer Writer) {
rc = graphql.GetRequestContext(ctx)
})(graphql.WithRequestContext(context.Background(), rc), noopWriter)
c.ResultContext = *graphql.GetRequestContext(ctx)
c.InvokedNext = true
})(graphql.WithRequestContext(context.Background(), initial), func(response *graphql.Response) {
c.Response = *response
})

return middlewareContext{
InvokedNext: rc != nil,
RequestContext: rc,
}
return c
}

func noopWriter(response *graphql.Response) {}
Loading

0 comments on commit 311887d

Please sign in to comment.