Skip to content

Commit

Permalink
Update websocketInitFunc to return error instead of boolean
Browse files Browse the repository at this point in the history
  • Loading branch information
Eddy committed Jun 18, 2019
1 parent a6508b6 commit c397be0
Show file tree
Hide file tree
Showing 3 changed files with 9 additions and 8 deletions.
4 changes: 2 additions & 2 deletions handler/graphql.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ type params struct {
Variables map[string]interface{} `json:"variables"`
}

type websocketInitFunc func(ctx context.Context, initPayload InitPayload) bool
type websocketInitFunc func(ctx context.Context, initPayload InitPayload) error

type Config struct {
cacheSize int
Expand Down Expand Up @@ -255,7 +255,7 @@ func (tw *tracerWrapper) EndOperationExecution(ctx context.Context) {

// WebsocketInitFunc is called when the server receives connection init message from the client.
// This can be used to check initial payload to see whether to accept the websocket connection.
func WebsocketInitFunc(websocketInitFunc func(ctx context.Context, initPayload InitPayload) bool) Option {
func WebsocketInitFunc(websocketInitFunc func(ctx context.Context, initPayload InitPayload) error) Option {
return func(cfg *Config) {
cfg.websocketInitFunc = websocketInitFunc
}
Expand Down
4 changes: 2 additions & 2 deletions handler/websocket.go
Original file line number Diff line number Diff line change
Expand Up @@ -95,8 +95,8 @@ func (c *wsConnection) init() bool {
}

if c.cfg.websocketInitFunc != nil {
if ok := c.cfg.websocketInitFunc(c.ctx, c.initPayload); !ok {
c.sendConnectionError("invalid init payload")
if err := c.cfg.websocketInitFunc(c.ctx, c.initPayload); err != nil {
c.sendConnectionError(err.Error())
c.close(websocket.CloseNormalClosure, "terminated")
return false
}
Expand Down
9 changes: 5 additions & 4 deletions handler/websocket_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package handler
import (
"context"
"encoding/json"
"errors"
"net/http/httptest"
"strings"
"testing"
Expand Down Expand Up @@ -176,8 +177,8 @@ func TestWebsocketInitFunc(t *testing.T) {
})

t.Run("accept connection if WebsocketInitFunc is provided and is accepting connection", func(t *testing.T) {
h := GraphQL(&executableSchemaStub{next}, WebsocketInitFunc(func(ctx context.Context, initPayload InitPayload) bool {
return true
h := GraphQL(&executableSchemaStub{next}, WebsocketInitFunc(func(ctx context.Context, initPayload InitPayload) error {
return nil
}))
srv := httptest.NewServer(h)
defer srv.Close()
Expand All @@ -191,8 +192,8 @@ func TestWebsocketInitFunc(t *testing.T) {
})

t.Run("reject connection if WebsocketInitFunc is provided and is accepting connection", func(t *testing.T) {
h := GraphQL(&executableSchemaStub{next}, WebsocketInitFunc(func(ctx context.Context, initPayload InitPayload) bool {
return false
h := GraphQL(&executableSchemaStub{next}, WebsocketInitFunc(func(ctx context.Context, initPayload InitPayload) error {
return errors.New("invalid init payload")
}))
srv := httptest.NewServer(h)
defer srv.Close()
Expand Down

0 comments on commit c397be0

Please sign in to comment.