From 80b6430a1fc83e6f8e7d90c7dea4a08d875c43dc Mon Sep 17 00:00:00 2001 From: Vinicius Fortuna Date: Thu, 30 Jan 2025 17:26:46 -0500 Subject: [PATCH] fix(x): fix Websocket code (#367) --- x/examples/ws2endpoint/main.go | 138 ++++++++++++++++----------------- x/go.mod | 1 - x/go.sum | 2 - x/websocket/endpoint.go | 21 +++-- x/websocket/endpoint_test.go | 71 +++++------------ 5 files changed, 98 insertions(+), 135 deletions(-) diff --git a/x/examples/ws2endpoint/main.go b/x/examples/ws2endpoint/main.go index 4a415cc2..75a95632 100644 --- a/x/examples/ws2endpoint/main.go +++ b/x/examples/ws2endpoint/main.go @@ -16,7 +16,6 @@ package main import ( "context" - "errors" "flag" "io" "log" @@ -29,7 +28,9 @@ import ( "github.com/Jigsaw-Code/outline-sdk/transport" "github.com/Jigsaw-Code/outline-sdk/x/configurl" - "golang.org/x/net/websocket" + "github.com/Jigsaw-Code/outline-sdk/x/websocket" + "github.com/lmittmann/tint" + "golang.org/x/term" ) type natConn struct { @@ -43,44 +44,11 @@ func (c *natConn) Write(b []byte) (int, error) { return c.Conn.Write(b) } -func websocketToConn(targetConn io.Writer, clientConn *websocket.Conn) { - var buf []byte - for { - err := websocket.Message.Receive(clientConn, &buf) - if err != nil { - if !errors.Is(err, io.EOF) { - slog.Warn("failed to read from client", "error", err) - } - break - } - _, err = targetConn.Write(buf) - if err != nil { - slog.Warn("failed to write to target", "error", err) - break - } - } -} - -func connToWebsocket(clientConn *websocket.Conn, targetConn io.Reader) { - // TODO: use a buffer pool - buf := make([]byte, 64*1024) - for { - n, err := targetConn.Read(buf) - if err != nil { - if !errors.Is(err, io.EOF) { - slog.Warn("failed to read from target", "error", err) - } - break - } - err = websocket.Message.Send(clientConn, buf[:n]) - if err != nil { - slog.Warn("failed to write to client", "error", err) - break - } - } -} - func main() { + var logLevel slog.LevelVar + slog.SetDefault(slog.New(tint.NewHandler( + os.Stderr, + &tint.Options{NoColor: !term.IsTerminal(int(os.Stderr.Fd())), Level: &logLevel}))) listenFlag := flag.String("listen", "localhost:8080", "Local proxy address to listen on") transportFlag := flag.String("transport", "", "Transport config") backendFlag := flag.String("backend", "", "Address of the endpoint to forward traffic to") @@ -97,7 +65,7 @@ func main() { log.Fatalf("Could not listen on address %v: %v", *listenFlag, err) } defer listener.Close() - log.Printf("Proxy listening on %v\n", listener.Addr().String()) + slog.Info("Proxy listening", "address", listener.Addr().String()) providers := configurl.NewDefaultProviders() mux := http.NewServeMux() @@ -108,24 +76,35 @@ func main() { } endpoint := transport.StreamDialerEndpoint{Dialer: dialer, Address: *backendFlag} handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - log.Printf("Got stream request: %v\n", r) - handler := func(wsConn *websocket.Conn) { - defer wsConn.Close() - targetConn, err := endpoint.ConnectStream(r.Context()) + slog.Info("Got stream request", "request", r) + clientConn, err := websocket.Upgrade(w, r, http.Header{}) + if err != nil { + slog.Error("failed to accept Websocket connection", "error", err) + http.Error(w, "Failed to accept Websocket connection", http.StatusBadGateway) + return + } + defer clientConn.Close() + + targetConn, err := endpoint.ConnectStream(r.Context()) + if err != nil { + slog.Error("Failed to connect to the origin", "error", err) + w.WriteHeader(http.StatusBadGateway) + return + } + defer targetConn.Close() + + go func() { + defer targetConn.CloseWrite() + _, err := io.Copy(targetConn, clientConn) if err != nil { - log.Printf("Failed to upgrade: %v\n", err) - w.WriteHeader(http.StatusBadGateway) - return + slog.Error("Failed to relay client traffic to target", "error", err) } - defer targetConn.Close() - // Relay from client to target. - go func() { - defer targetConn.CloseWrite() - websocketToConn(targetConn, wsConn) - }() - connToWebsocket(wsConn, targetConn) + }() + _, err = io.Copy(clientConn, targetConn) + if err != nil { + slog.Error("Failed to relay target traffic to client", "error", err) } - websocket.Server{Handler: handler}.ServeHTTP(w, r) + clientConn.CloseWrite() }) mux.Handle(*tcpPathFlag, http.StripPrefix(*tcpPathFlag, handler)) } @@ -136,24 +115,41 @@ func main() { } endpoint := transport.PacketDialerEndpoint{Dialer: dialer, Address: *backendFlag} handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - log.Printf("Got packet request: %v\n", r) - handler := func(wsConn *websocket.Conn) { - defer wsConn.Close() - targetConn, err := endpoint.ConnectPacket(r.Context()) - if err != nil { - log.Printf("Failed to upgrade: %v\n", err) - w.WriteHeader(http.StatusBadGateway) - return - } - // Expire connetion after 5 minutes of idle time, as per - // https://datatracker.ietf.org/doc/html/rfc4787#section-4.3 - targetConn = &natConn{targetConn, 5 * time.Minute} + slog.Info("Got packet request", "request", r) + clientConn, err := websocket.Upgrade(w, r, http.Header{}) + if err != nil { + slog.Error("failed to accept Websocket connection", "error", err) + http.Error(w, "Failed to accept Websocket connection", http.StatusBadGateway) + return + } + defer clientConn.Close() + + targetConn, err := endpoint.ConnectPacket(r.Context()) + if err != nil { + slog.Error("Failed to connect to the origin", "error", err) + w.WriteHeader(http.StatusBadGateway) + return + } + // Expire connection after 5 minutes of idle time, as per + // https://datatracker.ietf.org/doc/html/rfc4787#section-4.3 + targetConn = &natConn{targetConn, 5 * time.Minute} + defer targetConn.Close() + + done := false + go func() { defer targetConn.Close() - // Relay from client to target. - go websocketToConn(targetConn, wsConn) - connToWebsocket(wsConn, targetConn) + _, err := io.Copy(targetConn, clientConn) + if err != nil && !done { + slog.Error("Failed to relay client traffic to target", "error", err) + } + done = true + }() + _, err = io.Copy(clientConn, targetConn) + if err != nil && !done { + slog.Error("Failed to relay target traffic to client", "error", err) } - websocket.Server{Handler: handler}.ServeHTTP(w, r) + done = true + clientConn.Close() }) mux.Handle(*udpPathFlag, http.StripPrefix(*udpPathFlag, handler)) } diff --git a/x/go.mod b/x/go.mod index ea9d32a9..522edfa7 100644 --- a/x/go.mod +++ b/x/go.mod @@ -7,7 +7,6 @@ require ( // Use github.com/Psiphon-Labs/psiphon-tunnel-core@staging-client as per // https://github.com/Psiphon-Labs/psiphon-tunnel-core/?tab=readme-ov-file#using-psiphon-with-go-modules github.com/Psiphon-Labs/psiphon-tunnel-core v1.0.11-0.20240619172145-03cade11f647 - github.com/coder/websocket v1.8.12 github.com/gorilla/websocket v1.5.3 github.com/lmittmann/tint v1.0.5 github.com/quic-go/quic-go v0.48.1 diff --git a/x/go.sum b/x/go.sum index 3d19fb51..ecee0264 100644 --- a/x/go.sum +++ b/x/go.sum @@ -37,8 +37,6 @@ github.com/cheekybits/genny v0.0.0-20170328200008-9127e812e1e9/go.mod h1:+tQajlR github.com/chzyer/logex v1.1.10/go.mod h1:+Ywpsq7O8HXn0nuIou7OrIPyXbp3wmkHB+jjWRnGsAI= github.com/chzyer/readline v0.0.0-20180603132655-2972be24d48e/go.mod h1:nSuG5e5PlCu98SY8svDHJxuZscDgtXS6KTTbou5AhLI= github.com/chzyer/test v0.0.0-20180213035817-a1ea475d72b1/go.mod h1:Q3SI9o4m/ZMnBNeIyt5eFwwo7qiLfzFZmjNmxjkiQlU= -github.com/coder/websocket v1.8.12 h1:5bUXkEPPIbewrnkU8LTCLVaxi4N4J8ahufH2vlo4NAo= -github.com/coder/websocket v1.8.12/go.mod h1:LNVeNrXQZfe5qhS9ALED3uA+l5pPqvwXg3CKoDBB2gs= github.com/cognusion/go-cache-lru v0.0.0-20170419142635-f73e2280ecea h1:9C2rdYRp8Vzwhm3sbFX0yYfB+70zKFRjn7cnPCucHSw= github.com/cognusion/go-cache-lru v0.0.0-20170419142635-f73e2280ecea/go.mod h1:MdyNkAe06D7xmJsf+MsLvbZKYNXuOHLKJrvw+x4LlcQ= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= diff --git a/x/websocket/endpoint.go b/x/websocket/endpoint.go index bd2fdd67..dca2c82a 100644 --- a/x/websocket/endpoint.go +++ b/x/websocket/endpoint.go @@ -97,15 +97,19 @@ func newEndpoint[ConnType net.Conn](urlStr string, sd transport.StreamDialer, ws if err != nil { return zero, err } - gConn := &gorillaConn{wsConn: wsConn} - wsConn.SetCloseHandler(func(code int, text string) error { - gConn.readErr = io.EOF - return nil - }) - return wsToConn(gConn), nil + return wsToConn(newGorillaConn(wsConn)), nil }, nil } +func newGorillaConn(wsConn *websocket.Conn) *gorillaConn { + gConn := &gorillaConn{wsConn: wsConn} + wsConn.SetCloseHandler(func(code int, text string) error { + gConn.readErr = io.EOF + return nil + }) + return gConn +} + type gorillaConn struct { wsConn *websocket.Conn writeErr error @@ -196,6 +200,8 @@ func (c *gorillaConn) CloseWrite() error { } func (c *gorillaConn) Close() error { + c.CloseRead() + c.CloseWrite() return c.wsConn.Close() } @@ -208,6 +214,5 @@ func Upgrade(w http.ResponseWriter, r *http.Request, responseHeader http.Header) if err != nil { return nil, err } - - return &gorillaConn{wsConn: wsConn}, nil + return newGorillaConn(wsConn), nil } diff --git a/x/websocket/endpoint_test.go b/x/websocket/endpoint_test.go index 342dfe53..cea2aff9 100644 --- a/x/websocket/endpoint_test.go +++ b/x/websocket/endpoint_test.go @@ -23,8 +23,6 @@ import ( "testing" "github.com/Jigsaw-Code/outline-sdk/transport" - // TODO(fortuna): Implement the test with gorilla instead. - "github.com/coder/websocket" "github.com/stretchr/testify/require" ) @@ -36,61 +34,28 @@ func Test_NewStreamEndpoint(t *testing.T) { // TODO(fortuna): support h2 and h3 on the server. require.Equal(t, "", r.TLS.NegotiatedProtocol) require.Equal(t, "HTTP/1.1", r.Proto) - t.Log("Got stream request", "request", r) defer t.Log("Request done") - clientConn, err := websocket.Accept(w, r, nil) + + clientConn, err := Upgrade(w, r, http.Header{}) if err != nil { t.Log("Failed to accept Websocket connection", "error", err) http.Error(w, "Failed to accept Websocket connection", http.StatusBadGateway) return } - clientConn.SetReadLimit(-1) - defer clientConn.CloseNow() + defer clientConn.Close() // Handle client -> target. readClientDone := make(chan struct{}) go func() { defer close(readClientDone) - defer clientConn.CloseRead(r.Context()) - for { - msgType, msg, err := clientConn.Read(r.Context()) - if err != nil { - if !errors.Is(err, io.EOF) { - t.Log("Failed to read from client", "error", err) - clientConn.Close(websocket.StatusInternalError, "failed to read from client") - } - break - } - require.Equal(t, websocket.MessageBinary, msgType) - if _, err := toTargetWriter.Write(msg); err != nil { - t.Log("Failed to write to target", "error", err) - clientConn.Close(websocket.StatusInternalError, "failed to write message to target") - break - } - } + defer toTargetWriter.Close() + _, err := io.Copy(toTargetWriter, clientConn) + require.NoError(t, err) }() // Handle target -> client - func() { - // About 2 MTUs - buf := make([]byte, 3000) - for { - n, err := fromTargetReader.Read(buf) - if err != nil { - if !errors.Is(err, io.EOF) { - t.Log("Failed to read from target", "error", err) - clientConn.Close(websocket.StatusInternalError, "failed to read message from target") - } - break - } - read := buf[:n] - if err := clientConn.Write(r.Context(), websocket.MessageBinary, read); err != nil { - t.Log("Failed to write to client", "error", err) - clientConn.Close(websocket.StatusInternalError, "failed to write message to client") - break - } - } - }() + _, err = io.Copy(clientConn, fromTargetReader) + require.NoError(t, err) <-readClientDone }) mux.Handle("/tcp", http.StripPrefix("/tcp", handler)) @@ -157,19 +122,18 @@ func Test_NewPacketEndpoint(t *testing.T) { // TODO(fortuna): support h2 and h3 on the server. require.Equal(t, "", r.TLS.NegotiatedProtocol) require.Equal(t, "HTTP/1.1", r.Proto) - clientConn, err := websocket.Accept(w, r, nil) + clientConn, err := Upgrade(w, r, http.Header{}) require.NoError(t, err) - defer clientConn.CloseNow() + defer clientConn.Close() - msgType, msg, err := clientConn.Read(r.Context()) + buf := make([]byte, 8) + n, err := clientConn.Read(buf) require.NoError(t, err) - require.Equal(t, websocket.MessageBinary, msgType) - require.Equal(t, []byte("Request"), msg) + require.Equal(t, []byte("Request"), buf[:n]) - err = clientConn.Write(r.Context(), websocket.MessageBinary, []byte("Response")) + n, err = clientConn.Write([]byte("Response")) require.NoError(t, err) - - clientConn.Close(websocket.StatusNormalClosure, "") + require.Equal(t, 8, n) }) mux.Handle("/udp", http.StripPrefix("/udp", handler)) ts := httptest.NewUnstartedServer(mux) @@ -190,7 +154,8 @@ func Test_NewPacketEndpoint(t *testing.T) { require.NoError(t, err) require.Equal(t, 7, n) - resp, err := io.ReadAll(conn) + buf := make([]byte, 9) + n, err = conn.Read(buf) require.NoError(t, err) - require.Equal(t, []byte("Response"), resp) + require.Equal(t, []byte("Response"), buf[:n]) }