From 2b135ff9d5a0c406d48fbe211f5b1c9744626867 Mon Sep 17 00:00:00 2001 From: Hellyson Rodrigo Parteka Date: Tue, 5 Jan 2021 09:13:38 -0300 Subject: [PATCH] feat(websocket): change websocket lib to nhooyr.io/websocket (#815) Fixes #713, #543, and #664. --- Gopkg.lock | 41 +++++++++++++++++++++++++-------- Gopkg.toml | 4 ++-- go/grpcweb/websocket_wrapper.go | 40 +++++++++++++++----------------- go/grpcweb/wrapper.go | 31 +++++++++++-------------- 4 files changed, 66 insertions(+), 50 deletions(-) diff --git a/Gopkg.lock b/Gopkg.lock index 82ada0bd..ca67a2e9 100644 --- a/Gopkg.lock +++ b/Gopkg.lock @@ -17,6 +17,14 @@ revision = "8991bc29aa16c548c550c7ff78260e27b9ab7c73" version = "v1.1.1" +[[projects]] + branch = "master" + digest = "1:c28a369c8fec7124b081d49b9097ac3349e0bf73555f05a8b384632b5278c80c" + name = "github.com/desertbit/timer" + packages = ["."] + pruneopts = "" + revision = "c41aec40b27f0eeb2b94300fffcd624c69b02990" + [[projects]] digest = "1:529d738b7976c3848cae5cf3a8036440166835e389c1f617af701eeb12a0518d" name = "github.com/golang/protobuf" @@ -40,14 +48,6 @@ revision = "b5d812f8a3706043e23a9cd5babf2e5423744d30" version = "v1.3.1" -[[projects]] - digest = "1:09aa5dd1332b93c96bde671bafb053249dc813febf7d5ca84e8f382ba255d67d" - name = "github.com/gorilla/websocket" - packages = ["."] - pruneopts = "" - revision = "66b9c49e59c6c48f0ffce28c2d8b8a5678502c6d" - version = "v1.4.0" - [[projects]] branch = "master" digest = "1:ef152d412a8e4fa5d997f3db288beafc591b5c1e619b47584af420d970ebf1ef" @@ -71,6 +71,14 @@ revision = "c225b8c3b01faf2899099b768856a9e916e5087b" version = "v1.2.0" +[[projects]] + digest = "1:13c741134c7da17734277a83d1312be4572751b64aa87189322b8f6a50cb9547" + name = "github.com/klauspost/compress" + packages = ["flate"] + pruneopts = "" + revision = "156c8d0eb96e404e0870865fc7b102b1cccde45a" + version = "v1.11.4" + [[projects]] digest = "1:0f51cee70b0d254dbc93c22666ea2abf211af81c1701a96d04e2284b408621db" name = "github.com/konsorten/go-windows-terminal-sequences" @@ -285,14 +293,28 @@ revision = "25c4f928eaa6d96443009bd842389fb4fa48664e" version = "v1.20.1" +[[projects]] + digest = "1:fe18a04a5fae08be6a3c0090d92f24c110f50eb0229c1da5303e270dbb7c5155" + name = "nhooyr.io/websocket" + packages = [ + ".", + "internal/bpool", + "internal/errd", + "internal/wsjs", + "internal/xsync", + ] + pruneopts = "" + revision = "02861b474d9c29660eff53a3c424d589aaf46d1e" + version = "v1.8.6" + [solve-meta] analyzer-name = "dep" analyzer-version = 1 input-imports = [ + "github.com/desertbit/timer", "github.com/golang/protobuf/proto", "github.com/golang/protobuf/protoc-gen-go", "github.com/golang/protobuf/ptypes/empty", - "github.com/gorilla/websocket", "github.com/grpc-ecosystem/go-grpc-middleware", "github.com/grpc-ecosystem/go-grpc-middleware/logging/logrus", "github.com/grpc-ecosystem/go-grpc-prometheus", @@ -314,6 +336,7 @@ "google.golang.org/grpc/credentials", "google.golang.org/grpc/grpclog", "google.golang.org/grpc/metadata", + "nhooyr.io/websocket", ] solver-name = "gps-cdcl" solver-version = 1 diff --git a/Gopkg.toml b/Gopkg.toml index 577808d4..0b787cb2 100644 --- a/Gopkg.toml +++ b/Gopkg.toml @@ -15,8 +15,8 @@ required = [ version = "1.1.0" [[constraint]] - name = "github.com/gorilla/websocket" - version = "1.2.0" + name = "nhooyr.io/websocket" + version = "1.8.6" [[constraint]] branch = "master" diff --git a/go/grpcweb/websocket_wrapper.go b/go/grpcweb/websocket_wrapper.go index a920d00e..51631e93 100644 --- a/go/grpcweb/websocket_wrapper.go +++ b/go/grpcweb/websocket_wrapper.go @@ -13,8 +13,8 @@ import ( "time" "github.com/desertbit/timer" - "github.com/gorilla/websocket" "golang.org/x/net/http2" + "nhooyr.io/websocket" ) type webSocketResponseWriter struct { @@ -24,40 +24,34 @@ type webSocketResponseWriter struct { flushedHeaders http.Header timeOutInterval time.Duration timer *timer.Timer + context context.Context } -func newWebSocketResponseWriter(wsConn *websocket.Conn) *webSocketResponseWriter { +func newWebSocketResponseWriter(ctx context.Context, wsConn *websocket.Conn) *webSocketResponseWriter { return &webSocketResponseWriter{ writtenHeaders: false, headers: make(http.Header), flushedHeaders: make(http.Header), wsConn: wsConn, + context: ctx, } } func (w *webSocketResponseWriter) enablePing(timeOutInterval time.Duration) { w.timeOutInterval = timeOutInterval w.timer = timer.NewTimer(w.timeOutInterval) - dispose := make(chan bool) - w.wsConn.SetCloseHandler(func(code int, text string) error { - close(dispose) - return nil - }) - go w.ping(dispose) + go w.ping() } -func (w *webSocketResponseWriter) ping(dispose chan bool) { - if dispose == nil { - return - } +func (w *webSocketResponseWriter) ping() { defer w.timer.Stop() for { select { - case <-dispose: + case <-w.context.Done(): return case <-w.timer.C: w.timer.Reset(w.timeOutInterval) - w.wsConn.WriteMessage(websocket.PingMessage, []byte{}) + w.wsConn.Ping(w.context) } } } @@ -73,7 +67,7 @@ func (w *webSocketResponseWriter) Write(b []byte) (int, error) { if w.timeOutInterval > time.Second && w.timer != nil { w.timer.Reset(w.timeOutInterval) } - return len(b), w.wsConn.WriteMessage(websocket.BinaryMessage, b) + return len(b), w.wsConn.Write(w.context, websocket.MessageBinary, b) } func (w *webSocketResponseWriter) writeHeaderFrame(headers http.Header) { @@ -81,8 +75,8 @@ func (w *webSocketResponseWriter) writeHeaderFrame(headers http.Header) { headers.Write(headerBuffer) headerGrpcDataHeader := []byte{1 << 7, 0, 0, 0, 0} // MSB=1 indicates this is a header data frame. binary.BigEndian.PutUint32(headerGrpcDataHeader[1:5], uint32(headerBuffer.Len())) - w.wsConn.WriteMessage(websocket.BinaryMessage, headerGrpcDataHeader) - w.wsConn.WriteMessage(websocket.BinaryMessage, headerBuffer.Bytes()) + w.wsConn.Write(w.context, websocket.MessageBinary, headerGrpcDataHeader) + w.wsConn.Write(w.context, websocket.MessageBinary, headerBuffer.Bytes()) } func (w *webSocketResponseWriter) copyFlushedHeaders() { @@ -127,12 +121,13 @@ type webSocketWrappedReader struct { respWriter *webSocketResponseWriter remainingBuffer []byte remainingError error + context context.Context cancel context.CancelFunc } func (w *webSocketWrappedReader) Close() error { w.respWriter.FlushTrailers() - return w.wsConn.Close() + return w.wsConn.Close(websocket.StatusNormalClosure, "request body closed") } // First byte of a binary WebSocket frame is used for control flow: @@ -167,15 +162,15 @@ func (w *webSocketWrappedReader) Read(p []byte) (int, error) { } // Read a whole frame from the WebSocket connection - messageType, framePayload, err := w.wsConn.ReadMessage() - if err == io.EOF || messageType == -1 { + messageType, framePayload, err := w.wsConn.Read(w.context) + if err == io.EOF || messageType == 0 { // The client has closed the connection. Indicate to the response writer that it should close w.cancel() return 0, io.EOF } // Only Binary frames are valid - if messageType != websocket.BinaryMessage { + if messageType != websocket.MessageBinary { return 0, errors.New("websocket frame was not a binary frame") } @@ -211,12 +206,13 @@ func (w *webSocketWrappedReader) Read(p []byte) (int, error) { return len(p), nil } -func newWebsocketWrappedReader(wsConn *websocket.Conn, respWriter *webSocketResponseWriter, cancel context.CancelFunc) *webSocketWrappedReader { +func newWebsocketWrappedReader(ctx context.Context, wsConn *websocket.Conn, respWriter *webSocketResponseWriter, cancel context.CancelFunc) *webSocketWrappedReader { return &webSocketWrappedReader{ wsConn: wsConn, respWriter: respWriter, remainingBuffer: nil, remainingError: nil, + context: ctx, cancel: cancel, } } diff --git a/go/grpcweb/wrapper.go b/go/grpcweb/wrapper.go index 847f7c0d..eb5cd2dc 100644 --- a/go/grpcweb/wrapper.go +++ b/go/grpcweb/wrapper.go @@ -11,10 +11,10 @@ import ( "strings" "time" - "github.com/gorilla/websocket" "github.com/rs/cors" "google.golang.org/grpc" "google.golang.org/grpc/grpclog" + "nhooyr.io/websocket" ) var ( @@ -147,18 +147,15 @@ func (w *WrappedGrpcServer) HandleGrpcWebRequest(resp http.ResponseWriter, req * intResp.finishRequest(req) } -var websocketUpgrader = websocket.Upgrader{ - ReadBufferSize: 1024, - WriteBufferSize: 1024, - CheckOrigin: func(r *http.Request) bool { return true }, - Subprotocols: []string{"grpc-websockets"}, -} - // HandleGrpcWebsocketRequest takes a HTTP request that is assumed to be a gRPC-Websocket request and wraps it with a // compatibility layer to transform it to a standard gRPC request for the wrapped gRPC server and transforms the // response to comply with the gRPC-Web protocol. func (w *WrappedGrpcServer) HandleGrpcWebsocketRequest(resp http.ResponseWriter, req *http.Request) { - wsConn, err := websocketUpgrader.Upgrade(resp, req, nil) + + wsConn, err := websocket.Accept(resp, req, &websocket.AcceptOptions{ + InsecureSkipVerify: true, // managed by ServeHTTP + Subprotocols: []string{"grpc-websockets"}, + }) if err != nil { grpclog.Errorf("Unable to upgrade websocket request: %v", err) return @@ -170,13 +167,16 @@ func (w *WrappedGrpcServer) HandleGrpcWebsocketRequest(resp http.ResponseWriter, } } - messageType, readBytes, err := wsConn.ReadMessage() + ctx, cancelFunc := context.WithCancel(req.Context()) + defer cancelFunc() + + messageType, readBytes, err := wsConn.Read(ctx) if err != nil { - grpclog.Errorf("Unable to read first websocket message: %v", err) + grpclog.Errorf("Unable to read first websocket message: %v %v %v", messageType, readBytes, err) return } - if messageType != websocket.BinaryMessage { + if messageType != websocket.MessageBinary { grpclog.Errorf("First websocket message is non-binary") return } @@ -187,14 +187,11 @@ func (w *WrappedGrpcServer) HandleGrpcWebsocketRequest(resp http.ResponseWriter, return } - ctx, cancelFunc := context.WithCancel(req.Context()) - defer cancelFunc() - - respWriter := newWebSocketResponseWriter(wsConn) + respWriter := newWebSocketResponseWriter(ctx, wsConn) if w.opts.websocketPingInterval >= time.Second { respWriter.enablePing(w.opts.websocketPingInterval) } - wrappedReader := newWebsocketWrappedReader(wsConn, respWriter, cancelFunc) + wrappedReader := newWebsocketWrappedReader(ctx, wsConn, respWriter, cancelFunc) for name, values := range wsHeaders { headers[name] = values