From 947c657032c494407c4d50d56d86ead7a155b8a8 Mon Sep 17 00:00:00 2001 From: richieyu Date: Thu, 1 Dec 2022 16:55:12 +0800 Subject: [PATCH] echo back requested Sec-Websocket-Protocol header for better browser compatibility --- golang/websockify.go | 39 +++++++++++++++++++++------------------ 1 file changed, 21 insertions(+), 18 deletions(-) diff --git a/golang/websockify.go b/golang/websockify.go index 5804768..7dc50c0 100644 --- a/golang/websockify.go +++ b/golang/websockify.go @@ -32,14 +32,6 @@ func init() { web = flag.String("web", path, "web root folder") } -var upgrader = websocket.Upgrader{ - ReadBufferSize: 1024, - WriteBufferSize: 1024, - CheckOrigin: func(r *http.Request) bool { - return true - }, -} - func forwardTcp(wsConn *websocket.Conn, conn net.Conn) { var tcpBuffer [1024]byte defer func() { @@ -56,12 +48,12 @@ func forwardTcp(wsConn *websocket.Conn, conn net.Conn) { } n, err := conn.Read(tcpBuffer[0:]) if err != nil { - log.Printf("%s: reading from TCP failed: %s", time.Now().Format(time.Stamp), err) + log.Printf("%s: TCP.Read() failed: %s", time.Now().Format(time.Stamp), err) + return + } + if err := wsConn.WriteMessage(websocket.BinaryMessage, tcpBuffer[0:n]); err != nil { + log.Printf("%s: websocket.WriteMessage() failed: %s", time.Now().Format(time.Stamp), err) return - } else { - if err := wsConn.WriteMessage(websocket.BinaryMessage, tcpBuffer[0:n]); err != nil { - log.Printf("%s: writing to WS failed: %s", time.Now().Format(time.Stamp), err) - } } } } @@ -69,7 +61,7 @@ func forwardTcp(wsConn *websocket.Conn, conn net.Conn) { func forwardWeb(wsConn *websocket.Conn, conn net.Conn) { defer func() { if err := recover(); err != nil { - log.Printf("%s: reading from WS failed: %s", time.Now().Format(time.Stamp), err) + log.Printf("%s: websocket forwarding side panic: %s", time.Now().Format(time.Stamp), err) } if conn != nil { conn.Close() @@ -84,15 +76,26 @@ func forwardWeb(wsConn *websocket.Conn, conn net.Conn) { } _, buffer, err := wsConn.ReadMessage() - if err == nil { - if _, err := conn.Write(buffer); err != nil { - log.Printf("%s: writing to TCP failed: %s", time.Now().Format(time.Stamp), err) - } + if err != nil { + log.Printf("%s: websocket.ReadMessage() failed: %s", time.Now().Format(time.Stamp), err) + return + } + if _, err := conn.Write(buffer); err != nil { + log.Printf("%s: tcp.Write() failed: %s", time.Now().Format(time.Stamp), err) + return } } } func serveWs(w http.ResponseWriter, r *http.Request) { + upgrader := websocket.Upgrader{ + ReadBufferSize: 1024, + WriteBufferSize: 1024, + CheckOrigin: func(r *http.Request) bool { + return true + }, + Subprotocols: websocket.Subprotocols(r), + } ws, err := upgrader.Upgrade(w, r, nil) if err != nil { log.Printf("%s: failed to upgrade to WS: %s", time.Now().Format(time.Stamp), err)