diff --git a/handler/graphql.go b/handler/graphql.go index c266d4c946c..8c3882ce9c2 100644 --- a/handler/graphql.go +++ b/handler/graphql.go @@ -30,7 +30,7 @@ type params struct { Variables map[string]interface{} `json:"variables"` } -type websocketInitFunc func(ctx context.Context, initPayload InitPayload) bool +type websocketInitFunc func(ctx context.Context, initPayload InitPayload) error type Config struct { cacheSize int @@ -255,7 +255,7 @@ func (tw *tracerWrapper) EndOperationExecution(ctx context.Context) { // WebsocketInitFunc is called when the server receives connection init message from the client. // This can be used to check initial payload to see whether to accept the websocket connection. -func WebsocketInitFunc(websocketInitFunc func(ctx context.Context, initPayload InitPayload) bool) Option { +func WebsocketInitFunc(websocketInitFunc func(ctx context.Context, initPayload InitPayload) error) Option { return func(cfg *Config) { cfg.websocketInitFunc = websocketInitFunc } diff --git a/handler/websocket.go b/handler/websocket.go index a9a74893d20..07a1a8c2dd8 100644 --- a/handler/websocket.go +++ b/handler/websocket.go @@ -95,8 +95,8 @@ func (c *wsConnection) init() bool { } if c.cfg.websocketInitFunc != nil { - if ok := c.cfg.websocketInitFunc(c.ctx, c.initPayload); !ok { - c.sendConnectionError("invalid init payload") + if err := c.cfg.websocketInitFunc(c.ctx, c.initPayload); err != nil { + c.sendConnectionError(err.Error()) c.close(websocket.CloseNormalClosure, "terminated") return false } diff --git a/handler/websocket_test.go b/handler/websocket_test.go index b7d5e89905a..dc3e656e5fe 100644 --- a/handler/websocket_test.go +++ b/handler/websocket_test.go @@ -3,6 +3,7 @@ package handler import ( "context" "encoding/json" + "errors" "net/http/httptest" "strings" "testing" @@ -176,8 +177,8 @@ func TestWebsocketInitFunc(t *testing.T) { }) t.Run("accept connection if WebsocketInitFunc is provided and is accepting connection", func(t *testing.T) { - h := GraphQL(&executableSchemaStub{next}, WebsocketInitFunc(func(ctx context.Context, initPayload InitPayload) bool { - return true + h := GraphQL(&executableSchemaStub{next}, WebsocketInitFunc(func(ctx context.Context, initPayload InitPayload) error { + return nil })) srv := httptest.NewServer(h) defer srv.Close() @@ -191,8 +192,8 @@ func TestWebsocketInitFunc(t *testing.T) { }) t.Run("reject connection if WebsocketInitFunc is provided and is accepting connection", func(t *testing.T) { - h := GraphQL(&executableSchemaStub{next}, WebsocketInitFunc(func(ctx context.Context, initPayload InitPayload) bool { - return false + h := GraphQL(&executableSchemaStub{next}, WebsocketInitFunc(func(ctx context.Context, initPayload InitPayload) error { + return errors.New("invalid init payload") })) srv := httptest.NewServer(h) defer srv.Close()