diff --git a/graphql/handler/apollotracing/tracer_test.go b/graphql/handler/apollotracing/tracer_test.go index ab9dbdc8e18..f0622d471a1 100644 --- a/graphql/handler/apollotracing/tracer_test.go +++ b/graphql/handler/apollotracing/tracer_test.go @@ -1,76 +1,26 @@ package apollotracing_test import ( - "context" "encoding/json" "net/http" "net/http/httptest" "strings" "testing" - "time" - "github.com/99designs/gqlgen/graphql" - "github.com/99designs/gqlgen/graphql/handler" "github.com/99designs/gqlgen/graphql/handler/apollotracing" + "github.com/99designs/gqlgen/graphql/handler/testserver" "github.com/99designs/gqlgen/graphql/handler/transport" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - "github.com/vektah/gqlparser" - "github.com/vektah/gqlparser/ast" ) -// todo: extract out common code for testing handler plugins without requiring a codegenned server. func TestApolloTracing(t *testing.T) { - now := time.Unix(0, 0) - - graphql.Now = func() time.Time { - defer func() { - now = now.Add(100 * time.Nanosecond) - }() - return now - } - - schema := gqlparser.MustLoadSchema(&ast.Source{Input: ` - schema { query: Query } - type Query { - me: User! - user(id: Int): User! - } - type User { name: String! } - `}) - - es := &graphql.ExecutableSchemaMock{ - QueryFunc: func(ctx context.Context, op *ast.OperationDefinition) *graphql.Response { - // Field execution happens inside the generated code, we want just enough to test against right now. - ctx = graphql.WithResolverContext(ctx, &graphql.ResolverContext{ - Object: "Query", - Field: graphql.CollectedField{ - Field: &ast.Field{ - Name: "me", - Alias: "me", - Definition: schema.Types["Query"].Fields.ForName("me"), - }, - }, - }) - res, err := graphql.GetRequestContext(ctx).ResolverMiddleware(ctx, func(ctx context.Context) (interface{}, error) { - return &graphql.Response{Data: []byte(`{"name":"test"}`)}, nil - }) - require.NoError(t, err) - return res.(*graphql.Response) - }, - MutationFunc: func(ctx context.Context, op *ast.OperationDefinition) *graphql.Response { - return graphql.ErrorResponse(ctx, "mutations are not supported") - }, - SchemaFunc: func() *ast.Schema { - return schema - }, - } - h := handler.New(es) + h := testserver.New() h.AddTransport(transport.POST{}) h.Use(apollotracing.New()) - resp := doRequest(h, "POST", "/graphql", `{"query":"{ me { name } }"}`) - assert.Equal(t, http.StatusOK, resp.Code) + resp := doRequest(h, "POST", "/graphql", `{"query":"{ name }"}`) + assert.Equal(t, http.StatusOK, resp.Code, resp.Body.String()) var respData struct { Extensions struct { Tracing apollotracing.TracingExtension `json:"tracing"` @@ -94,10 +44,10 @@ func TestApolloTracing(t *testing.T) { require.EqualValues(t, 500, tracing.Execution.Resolvers[0].StartOffset) require.EqualValues(t, 100, tracing.Execution.Resolvers[0].Duration) - require.EqualValues(t, []interface{}{"me"}, tracing.Execution.Resolvers[0].Path) + require.EqualValues(t, []interface{}{"name"}, tracing.Execution.Resolvers[0].Path) require.EqualValues(t, "Query", tracing.Execution.Resolvers[0].ParentType) - require.EqualValues(t, "me", tracing.Execution.Resolvers[0].FieldName) - require.EqualValues(t, "User!", tracing.Execution.Resolvers[0].ReturnType) + require.EqualValues(t, "name", tracing.Execution.Resolvers[0].FieldName) + require.EqualValues(t, "String!", tracing.Execution.Resolvers[0].ReturnType) } diff --git a/graphql/handler/server_test.go b/graphql/handler/server_test.go index 43e69fb9059..fe8ab220871 100644 --- a/graphql/handler/server_test.go +++ b/graphql/handler/server_test.go @@ -1,4 +1,4 @@ -package handler +package handler_test import ( "context" @@ -8,41 +8,13 @@ import ( "testing" "github.com/99designs/gqlgen/graphql" + "github.com/99designs/gqlgen/graphql/handler/testserver" "github.com/99designs/gqlgen/graphql/handler/transport" "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" - "github.com/vektah/gqlparser/ast" ) func TestServer(t *testing.T) { - es := &graphql.ExecutableSchemaMock{ - QueryFunc: func(ctx context.Context, op *ast.OperationDefinition) *graphql.Response { - // Field execution happens inside the generated code, we want just enough to test against right now. - res, err := graphql.GetRequestContext(ctx).ResolverMiddleware(ctx, func(ctx context.Context) (interface{}, error) { - return &graphql.Response{Data: []byte(`"query resp"`)}, nil - }) - require.NoError(t, err) - - return res.(*graphql.Response) - }, - MutationFunc: func(ctx context.Context, op *ast.OperationDefinition) *graphql.Response { - return &graphql.Response{Data: []byte(`"mutation resp"`)} - }, - SubscriptionFunc: func(ctx context.Context, op *ast.OperationDefinition) func() *graphql.Response { - called := 0 - return func() *graphql.Response { - called++ - if called > 2 { - return nil - } - return &graphql.Response{Data: []byte(`"subscription resp"`)} - } - }, - SchemaFunc: func() *ast.Schema { - return &ast.Schema{} - }, - } - srv := New(es) + srv := testserver.New() srv.AddTransport(&transport.GET{}) t.Run("returns an error if no transport matches", func(t *testing.T) { @@ -52,20 +24,20 @@ func TestServer(t *testing.T) { }) t.Run("calls query on executable schema", func(t *testing.T) { - resp := get(srv, "/foo?query={a}") + resp := get(srv, "/foo?query={name}") assert.Equal(t, http.StatusOK, resp.Code) - assert.Equal(t, `{"data":"query resp"}`, resp.Body.String()) + assert.Equal(t, `{"data":{"name":"test"}}`, resp.Body.String()) }) t.Run("mutations are forbidden", func(t *testing.T) { - resp := get(srv, "/foo?query=mutation{a}") - assert.Equal(t, http.StatusOK, resp.Code) + resp := get(srv, "/foo?query=mutation{name}") + assert.Equal(t, http.StatusNotAcceptable, resp.Code) assert.Equal(t, `{"errors":[{"message":"GET requests only allow query operations"}],"data":null}`, resp.Body.String()) }) t.Run("subscriptions are forbidden", func(t *testing.T) { - resp := get(srv, "/foo?query=subscription{a}") - assert.Equal(t, http.StatusOK, resp.Code) + resp := get(srv, "/foo?query=subscription{name}") + assert.Equal(t, http.StatusNotAcceptable, resp.Code) assert.Equal(t, `{"errors":[{"message":"GET requests only allow query operations"}],"data":null}`, resp.Body.String()) }) @@ -80,8 +52,8 @@ func TestServer(t *testing.T) { next(ctx, writer) })) - resp := get(srv, "/foo?query={a}") - assert.Equal(t, http.StatusOK, resp.Code) + resp := get(srv, "/foo?query={name}") + assert.Equal(t, http.StatusOK, resp.Code, resp.Body.String()) assert.Equal(t, []string{"first", "second"}, calls) }) @@ -98,8 +70,8 @@ func TestServer(t *testing.T) { return next(ctx) })) - resp := get(srv, "/foo?query={a}") - assert.Equal(t, http.StatusOK, resp.Code) + resp := get(srv, "/foo?query={name}") + assert.Equal(t, http.StatusOK, resp.Code, resp.Body.String()) assert.Equal(t, []string{"first", "second"}, calls) }) } diff --git a/graphql/handler/testserver/testserver.go b/graphql/handler/testserver/testserver.go new file mode 100644 index 00000000000..284b8c1f653 --- /dev/null +++ b/graphql/handler/testserver/testserver.go @@ -0,0 +1,99 @@ +package testserver + +import ( + "context" + "fmt" + "time" + + "github.com/99designs/gqlgen/graphql" + "github.com/99designs/gqlgen/graphql/handler" + "github.com/vektah/gqlparser" + "github.com/vektah/gqlparser/ast" +) + +// New provides a server for use in tests that isn't relying on generated code. It isnt a perfect reproduction of +// a generated server, but it aims to be good enough to test the handler package without relying on codegen. +func New() *TestServer { + next := make(chan struct{}) + now := time.Unix(0, 0) + + graphql.Now = func() time.Time { + defer func() { + now = now.Add(100 * time.Nanosecond) + }() + return now + } + + schema := gqlparser.MustLoadSchema(&ast.Source{Input: ` + schema { query: Query } + type Query { + name: String! + find(id: Int!): String! + } + type Mutation { + name: String! + } + type Subscription { + name: String! + } + `}) + + es := &graphql.ExecutableSchemaMock{ + QueryFunc: func(ctx context.Context, op *ast.OperationDefinition) *graphql.Response { + // Field execution happens inside the generated code, lets simulate some of it. + ctx = graphql.WithResolverContext(ctx, &graphql.ResolverContext{ + Object: "Query", + Field: graphql.CollectedField{ + Field: &ast.Field{ + Name: "name", + Alias: "name", + Definition: schema.Types["Query"].Fields.ForName("name"), + }, + }, + }) + res, err := graphql.GetRequestContext(ctx).ResolverMiddleware(ctx, func(ctx context.Context) (interface{}, error) { + return &graphql.Response{Data: []byte(`{"name":"test"}`)}, nil + }) + if err != nil { + panic(err) + } + return res.(*graphql.Response) + }, + MutationFunc: func(ctx context.Context, op *ast.OperationDefinition) *graphql.Response { + return graphql.ErrorResponse(ctx, "mutations are not supported") + }, + SubscriptionFunc: func(ctx context.Context, op *ast.OperationDefinition) func() *graphql.Response { + return func() *graphql.Response { + select { + case <-ctx.Done(): + return nil + case <-next: + return &graphql.Response{ + Data: []byte(`{"name":"test"}`), + } + } + } + }, + SchemaFunc: func() *ast.Schema { + return schema + }, + } + return &TestServer{ + Server: handler.New(es), + next: next, + } +} + +type TestServer struct { + *handler.Server + next chan struct{} +} + +func (s *TestServer) SendNextSubscriptionMessage() { + select { + case s.next <- struct{}{}: + case <-time.After(1 * time.Second): + fmt.Println("WARNING: no active subscription") + } + +} diff --git a/graphql/handler/transport/http_get.go b/graphql/handler/transport/http_get.go index f7d9ac48410..2a8efe44f07 100644 --- a/graphql/handler/transport/http_get.go +++ b/graphql/handler/transport/http_get.go @@ -6,9 +6,8 @@ import ( "net/http" "strings" - "github.com/vektah/gqlparser/ast" - "github.com/99designs/gqlgen/graphql" + "github.com/vektah/gqlparser/ast" ) // GET implements the GET side of the default HTTP transport @@ -47,6 +46,7 @@ func (H GET) Do(w http.ResponseWriter, r *http.Request, exec graphql.GraphExecut if variables := r.URL.Query().Get("variables"); variables != "" { if err := jsonDecode(strings.NewReader(variables), &raw.Variables); err != nil { + w.WriteHeader(http.StatusBadRequest) writer.Errorf("variables could not be decoded") return } @@ -54,6 +54,7 @@ func (H GET) Do(w http.ResponseWriter, r *http.Request, exec graphql.GraphExecut if extensions := r.URL.Query().Get("extensions"); extensions != "" { if err := jsonDecode(strings.NewReader(extensions), &raw.Extensions); err != nil { + w.WriteHeader(http.StatusBadRequest) writer.Errorf("extensions could not be decoded") return } @@ -63,9 +64,11 @@ func (H GET) Do(w http.ResponseWriter, r *http.Request, exec graphql.GraphExecut if err != nil { w.WriteHeader(http.StatusUnprocessableEntity) writer.GraphqlErr(err...) + return } op := rc.Doc.Operations.ForName(rc.OperationName) if op.Operation != ast.Query { + w.WriteHeader(http.StatusNotAcceptable) writer.Errorf("GET requests only allow query operations") return } diff --git a/graphql/handler/transport/http_get_test.go b/graphql/handler/transport/http_get_test.go new file mode 100644 index 00000000000..e0c09625230 --- /dev/null +++ b/graphql/handler/transport/http_get_test.go @@ -0,0 +1,45 @@ +package transport_test + +import ( + "net/http" + "testing" + + "github.com/99designs/gqlgen/graphql/handler/testserver" + "github.com/99designs/gqlgen/graphql/handler/transport" + "github.com/stretchr/testify/assert" +) + +func TestGET(t *testing.T) { + h := testserver.New() + h.AddTransport(transport.GET{}) + + t.Run("success", func(t *testing.T) { + resp := doRequest(h, "GET", "/graphql?query={name}", ``) + assert.Equal(t, http.StatusOK, resp.Code, resp.Body.String()) + assert.Equal(t, `{"data":{"name":"test"}}`, resp.Body.String()) + }) + + t.Run("decode failure", func(t *testing.T) { + resp := doRequest(h, "GET", "/graphql?query={name}&variables=notjson", "") + assert.Equal(t, http.StatusBadRequest, resp.Code, resp.Body.String()) + assert.Equal(t, `{"errors":[{"message":"variables could not be decoded"}],"data":null}`, resp.Body.String()) + }) + + t.Run("invalid variable", func(t *testing.T) { + resp := doRequest(h, "GET", `/graphql?query=query($id:Int!){find(id:$id)}&variables={"id":false}`, "") + assert.Equal(t, http.StatusUnprocessableEntity, resp.Code, resp.Body.String()) + assert.Equal(t, `{"errors":[{"message":"cannot use bool as Int","path":["variable","id"]}],"data":null}`, resp.Body.String()) + }) + + t.Run("parse failure", func(t *testing.T) { + resp := doRequest(h, "GET", "/graphql?query=!", "") + assert.Equal(t, http.StatusUnprocessableEntity, resp.Code, resp.Body.String()) + assert.Equal(t, `{"errors":[{"message":"Unexpected !","locations":[{"line":1,"column":1}]}],"data":null}`, resp.Body.String()) + }) + + t.Run("no mutations", func(t *testing.T) { + resp := doRequest(h, "GET", "/graphql?query=mutation{name}", "") + assert.Equal(t, http.StatusNotAcceptable, resp.Code, resp.Body.String()) + assert.Equal(t, `{"errors":[{"message":"GET requests only allow query operations"}],"data":null}`, resp.Body.String()) + }) +} diff --git a/graphql/handler/transport/http_post_test.go b/graphql/handler/transport/http_post_test.go index ad075e29423..86411767ca7 100644 --- a/graphql/handler/transport/http_post_test.go +++ b/graphql/handler/transport/http_post_test.go @@ -1,45 +1,23 @@ package transport_test import ( - "context" "fmt" "net/http" "net/http/httptest" "strings" "testing" - "github.com/99designs/gqlgen/graphql" - "github.com/99designs/gqlgen/graphql/handler" + "github.com/99designs/gqlgen/graphql/handler/testserver" "github.com/99designs/gqlgen/graphql/handler/transport" "github.com/stretchr/testify/assert" - "github.com/vektah/gqlparser" - "github.com/vektah/gqlparser/ast" ) func TestPOST(t *testing.T) { - es := &graphql.ExecutableSchemaMock{ - QueryFunc: func(ctx context.Context, op *ast.OperationDefinition) *graphql.Response { - return &graphql.Response{Data: []byte(`{"name":"test"}`)} - }, - MutationFunc: func(ctx context.Context, op *ast.OperationDefinition) *graphql.Response { - return graphql.ErrorResponse(ctx, "mutations are not supported") - }, - SchemaFunc: func() *ast.Schema { - return gqlparser.MustLoadSchema(&ast.Source{Input: ` - schema { query: Query } - type Query { - me: User! - user(id: Int): User! - } - type User { name: String! } - `}) - }, - } - h := handler.New(es) + h := testserver.New() h.AddTransport(transport.POST{}) t.Run("success", func(t *testing.T) { - resp := doRequest(h, "POST", "/graphql", `{"query":"{ me { name } }"}`) + resp := doRequest(h, "POST", "/graphql", `{"query":"{ name }"}`) assert.Equal(t, http.StatusOK, resp.Code) assert.Equal(t, `{"data":{"name":"test"}}`, resp.Body.String()) }) @@ -84,21 +62,21 @@ func TestPOST(t *testing.T) { }) t.Run("validation failure", func(t *testing.T) { - resp := doRequest(h, "POST", "/graphql", `{"query": "{ me { title }}"}`) + resp := doRequest(h, "POST", "/graphql", `{"query": "{ title }"}`) assert.Equal(t, http.StatusUnprocessableEntity, resp.Code, resp.Body.String()) assert.Equal(t, resp.Header().Get("Content-Type"), "application/json") - assert.Equal(t, `{"errors":[{"message":"Cannot query field \"title\" on type \"User\".","locations":[{"line":1,"column":8}]}],"data":null}`, resp.Body.String()) + assert.Equal(t, `{"errors":[{"message":"Cannot query field \"title\" on type \"Query\".","locations":[{"line":1,"column":3}]}],"data":null}`, resp.Body.String()) }) t.Run("invalid variable", func(t *testing.T) { - resp := doRequest(h, "POST", "/graphql", `{"query": "query($id:Int!){user(id:$id){name}}","variables":{"id":false}}`) + resp := doRequest(h, "POST", "/graphql", `{"query": "query($id:Int!){find(id:$id)}","variables":{"id":false}}`) assert.Equal(t, http.StatusUnprocessableEntity, resp.Code, resp.Body.String()) assert.Equal(t, resp.Header().Get("Content-Type"), "application/json") assert.Equal(t, `{"errors":[{"message":"cannot use bool as Int","path":["variable","id"]}],"data":null}`, resp.Body.String()) }) t.Run("execution failure", func(t *testing.T) { - resp := doRequest(h, "POST", "/graphql", `{"query": "mutation { me { name } }"}`) + resp := doRequest(h, "POST", "/graphql", `{"query": "mutation { name }"}`) assert.Equal(t, http.StatusOK, resp.Code, resp.Body.String()) assert.Equal(t, resp.Header().Get("Content-Type"), "application/json") assert.Equal(t, `{"errors":[{"message":"mutations are not supported"}],"data":null}`, resp.Body.String()) @@ -123,7 +101,7 @@ func TestPOST(t *testing.T) { for _, contentType := range validContentTypes { t.Run(fmt.Sprintf("allow for content type %s", contentType), func(t *testing.T) { - resp := doReq(h, "POST", "/graphql", `{"query":"{ me { name } }"}`, contentType) + resp := doReq(h, "POST", "/graphql", `{"query":"{ name }"}`, contentType) assert.Equal(t, http.StatusOK, resp.Code, resp.Body.String()) assert.Equal(t, `{"data":{"name":"test"}}`, resp.Body.String()) }) @@ -140,7 +118,7 @@ func TestPOST(t *testing.T) { for _, tc := range invalidContentTypes { t.Run(fmt.Sprintf("reject for content type %s", tc), func(t *testing.T) { - resp := doReq(h, "POST", "/graphql", `{"query":"{ me { name } }"}`, tc) + resp := doReq(h, "POST", "/graphql", `{"query":"{ name }"}`, tc) assert.Equal(t, http.StatusBadRequest, resp.Code, resp.Body.String()) assert.Equal(t, fmt.Sprintf(`{"errors":[{"message":"%s"}],"data":null}`, "transport not supported"), resp.Body.String()) }) diff --git a/graphql/handler/transport/websocket.go b/graphql/handler/transport/websocket.go index 87bc4c1608c..1e40bbc61c5 100644 --- a/graphql/handler/transport/websocket.go +++ b/graphql/handler/transport/websocket.go @@ -29,13 +29,13 @@ const ( ) type ( - WebsocketTransport struct { + Websocket struct { Upgrader websocket.Upgrader InitFunc websocketInitFunc KeepAlivePingInterval time.Duration } wsConnection struct { - WebsocketTransport + Websocket ctx context.Context conn *websocket.Conn active map[string]context.CancelFunc @@ -53,13 +53,13 @@ type ( websocketInitFunc func(ctx context.Context, initPayload InitPayload) (context.Context, error) ) -var _ graphql.Transport = WebsocketTransport{} +var _ graphql.Transport = Websocket{} -func (t WebsocketTransport) Supports(r *http.Request) bool { +func (t Websocket) Supports(r *http.Request) bool { return r.Header.Get("Upgrade") != "" } -func (t WebsocketTransport) Do(w http.ResponseWriter, r *http.Request, exec graphql.GraphExecutor) { +func (t Websocket) Do(w http.ResponseWriter, r *http.Request, exec graphql.GraphExecutor) { ws, err := t.Upgrader.Upgrade(w, r, http.Header{ "Sec-Websocket-Protocol": []string{"graphql-ws"}, }) @@ -70,11 +70,11 @@ func (t WebsocketTransport) Do(w http.ResponseWriter, r *http.Request, exec grap } conn := wsConnection{ - active: map[string]context.CancelFunc{}, - conn: ws, - ctx: r.Context(), - exec: exec, - WebsocketTransport: t, + active: map[string]context.CancelFunc{}, + conn: ws, + ctx: r.Context(), + exec: exec, + Websocket: t, } if !conn.init() { diff --git a/graphql/handler/transport/websocket_test.go b/graphql/handler/transport/websocket_test.go index 1592d196336..e7a5436e486 100644 --- a/graphql/handler/transport/websocket_test.go +++ b/graphql/handler/transport/websocket_test.go @@ -12,6 +12,7 @@ import ( "github.com/99designs/gqlgen/client" "github.com/99designs/gqlgen/graphql" "github.com/99designs/gqlgen/graphql/handler" + "github.com/99designs/gqlgen/graphql/handler/testserver" "github.com/99designs/gqlgen/graphql/handler/transport" "github.com/gorilla/websocket" "github.com/stretchr/testify/assert" @@ -21,9 +22,8 @@ import ( ) func TestWebsocket(t *testing.T) { - next := make(chan struct{}) - handler := newServer(next) - handler.AddTransport(transport.WebsocketTransport{}) + handler := testserver.New() + handler.AddTransport(transport.Websocket{}) srv := httptest.NewServer(handler) defer srv.Close() @@ -114,33 +114,33 @@ func TestWebsocket(t *testing.T) { require.NoError(t, c.WriteJSON(&operationMessage{ Type: startMsg, ID: "test_1", - Payload: json.RawMessage(`{"query": "subscription { user { title } }"}`), + Payload: json.RawMessage(`{"query": "subscription { name }"}`), })) - next <- struct{}{} + handler.SendNextSubscriptionMessage() msg := readOp(c) - assert.Equal(t, dataMsg, msg.Type) - assert.Equal(t, "test_1", msg.ID) - assert.Equal(t, `{"data":{"name":"test"}}`, string(msg.Payload)) + require.Equal(t, dataMsg, msg.Type, string(msg.Payload)) + require.Equal(t, "test_1", msg.ID, string(msg.Payload)) + require.Equal(t, `{"data":{"name":"test"}}`, string(msg.Payload)) - next <- struct{}{} + handler.SendNextSubscriptionMessage() msg = readOp(c) - assert.Equal(t, dataMsg, msg.Type) - assert.Equal(t, "test_1", msg.ID) - assert.Equal(t, `{"data":{"name":"test"}}`, string(msg.Payload)) + require.Equal(t, dataMsg, msg.Type, string(msg.Payload)) + require.Equal(t, "test_1", msg.ID, string(msg.Payload)) + require.Equal(t, `{"data":{"name":"test"}}`, string(msg.Payload)) require.NoError(t, c.WriteJSON(&operationMessage{Type: stopMsg, ID: "test_1"})) msg = readOp(c) - assert.Equal(t, completeMsg, msg.Type) - assert.Equal(t, "test_1", msg.ID) + require.Equal(t, completeMsg, msg.Type) + require.Equal(t, "test_1", msg.ID) }) } func TestWebsocketWithKeepAlive(t *testing.T) { - next := make(chan struct{}) - h := newServer(next) - h.AddTransport(transport.WebsocketTransport{ + + h := testserver.New() + h.AddTransport(transport.Websocket{ KeepAlivePingInterval: 10 * time.Millisecond, }) @@ -157,7 +157,7 @@ func TestWebsocketWithKeepAlive(t *testing.T) { require.NoError(t, c.WriteJSON(&operationMessage{ Type: startMsg, ID: "test_1", - Payload: json.RawMessage(`{"query": "subscription { user { title } }"}`), + Payload: json.RawMessage(`{"query": "subscription { name }"}`), })) // keepalive @@ -165,7 +165,7 @@ func TestWebsocketWithKeepAlive(t *testing.T) { assert.Equal(t, connectionKeepAliveMsg, msg.Type) // server message - next <- struct{}{} + h.SendNextSubscriptionMessage() msg = readOp(c) assert.Equal(t, dataMsg, msg.Type) @@ -175,11 +175,9 @@ func TestWebsocketWithKeepAlive(t *testing.T) { } func TestWebsocketInitFunc(t *testing.T) { - next := make(chan struct{}) - t.Run("accept connection if WebsocketInitFunc is NOT provided", func(t *testing.T) { - h := newServer(next) - h.AddTransport(transport.WebsocketTransport{}) + h := testserver.New() + h.AddTransport(transport.Websocket{}) srv := httptest.NewServer(h) defer srv.Close() @@ -193,8 +191,8 @@ func TestWebsocketInitFunc(t *testing.T) { }) t.Run("accept connection if WebsocketInitFunc is provided and is accepting connection", func(t *testing.T) { - h := newServer(next) - h.AddTransport(transport.WebsocketTransport{ + h := testserver.New() + h.AddTransport(transport.Websocket{ InitFunc: func(ctx context.Context, initPayload transport.InitPayload) (context.Context, error) { return context.WithValue(ctx, "newkey", "newvalue"), nil }, @@ -212,8 +210,8 @@ func TestWebsocketInitFunc(t *testing.T) { }) t.Run("reject connection if WebsocketInitFunc is provided and is accepting connection", func(t *testing.T) { - h := newServer(next) - h.AddTransport(transport.WebsocketTransport{ + h := testserver.New() + h.AddTransport(transport.Websocket{ InitFunc: func(ctx context.Context, initPayload transport.InitPayload) (context.Context, error) { return ctx, errors.New("invalid init payload") }, @@ -248,7 +246,7 @@ func TestWebsocketInitFunc(t *testing.T) { } h := handler.New(es) - h.AddTransport(transport.WebsocketTransport{ + h.AddTransport(transport.Websocket{ InitFunc: func(ctx context.Context, initPayload transport.InitPayload) (context.Context, error) { return context.WithValue(ctx, "newkey", "newvalue"), nil }, @@ -267,40 +265,6 @@ func TestWebsocketInitFunc(t *testing.T) { }) } -func newServer(next chan struct{}) *handler.Server { - es := &graphql.ExecutableSchemaMock{ - QueryFunc: func(ctx context.Context, op *ast.OperationDefinition) *graphql.Response { - return graphql.ErrorResponse(ctx, "queries are not supported") - }, - MutationFunc: func(ctx context.Context, op *ast.OperationDefinition) *graphql.Response { - return graphql.ErrorResponse(ctx, "mutations are not supported") - }, - SubscriptionFunc: func(ctx context.Context, op *ast.OperationDefinition) func() *graphql.Response { - return func() *graphql.Response { - select { - case <-ctx.Done(): - return nil - case <-next: - return &graphql.Response{ - Data: []byte(`{"name":"test"}`), - } - } - } - }, - SchemaFunc: func() *ast.Schema { - return gqlparser.MustLoadSchema(&ast.Source{Input: ` - schema { query: Query } - type Query { - me: User! - user(id: Int): User! - } - type User { name: String! } - `}) - }, - } - return handler.New(es) -} - func wsConnect(url string) *websocket.Conn { c, resp, err := websocket.DefaultDialer.Dial(strings.Replace(url, "http://", "ws://", -1), nil) if err != nil {