Skip to content

Commit

Permalink
fix(x): fix Websocket code (#367)
Browse files Browse the repository at this point in the history
  • Loading branch information
fortuna authored Jan 30, 2025
1 parent 1f73408 commit 80b6430
Show file tree
Hide file tree
Showing 5 changed files with 98 additions and 135 deletions.
138 changes: 67 additions & 71 deletions x/examples/ws2endpoint/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@ package main

import (
"context"
"errors"
"flag"
"io"
"log"
Expand All @@ -29,7 +28,9 @@ import (

"github.com/Jigsaw-Code/outline-sdk/transport"
"github.com/Jigsaw-Code/outline-sdk/x/configurl"
"golang.org/x/net/websocket"
"github.com/Jigsaw-Code/outline-sdk/x/websocket"
"github.com/lmittmann/tint"
"golang.org/x/term"
)

type natConn struct {
Expand All @@ -43,44 +44,11 @@ func (c *natConn) Write(b []byte) (int, error) {
return c.Conn.Write(b)
}

func websocketToConn(targetConn io.Writer, clientConn *websocket.Conn) {
var buf []byte
for {
err := websocket.Message.Receive(clientConn, &buf)
if err != nil {
if !errors.Is(err, io.EOF) {
slog.Warn("failed to read from client", "error", err)
}
break
}
_, err = targetConn.Write(buf)
if err != nil {
slog.Warn("failed to write to target", "error", err)
break
}
}
}

func connToWebsocket(clientConn *websocket.Conn, targetConn io.Reader) {
// TODO: use a buffer pool
buf := make([]byte, 64*1024)
for {
n, err := targetConn.Read(buf)
if err != nil {
if !errors.Is(err, io.EOF) {
slog.Warn("failed to read from target", "error", err)
}
break
}
err = websocket.Message.Send(clientConn, buf[:n])
if err != nil {
slog.Warn("failed to write to client", "error", err)
break
}
}
}

func main() {
var logLevel slog.LevelVar
slog.SetDefault(slog.New(tint.NewHandler(
os.Stderr,
&tint.Options{NoColor: !term.IsTerminal(int(os.Stderr.Fd())), Level: &logLevel})))
listenFlag := flag.String("listen", "localhost:8080", "Local proxy address to listen on")
transportFlag := flag.String("transport", "", "Transport config")
backendFlag := flag.String("backend", "", "Address of the endpoint to forward traffic to")
Expand All @@ -97,7 +65,7 @@ func main() {
log.Fatalf("Could not listen on address %v: %v", *listenFlag, err)
}
defer listener.Close()
log.Printf("Proxy listening on %v\n", listener.Addr().String())
slog.Info("Proxy listening", "address", listener.Addr().String())

providers := configurl.NewDefaultProviders()
mux := http.NewServeMux()
Expand All @@ -108,24 +76,35 @@ func main() {
}
endpoint := transport.StreamDialerEndpoint{Dialer: dialer, Address: *backendFlag}
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
log.Printf("Got stream request: %v\n", r)
handler := func(wsConn *websocket.Conn) {
defer wsConn.Close()
targetConn, err := endpoint.ConnectStream(r.Context())
slog.Info("Got stream request", "request", r)
clientConn, err := websocket.Upgrade(w, r, http.Header{})
if err != nil {
slog.Error("failed to accept Websocket connection", "error", err)
http.Error(w, "Failed to accept Websocket connection", http.StatusBadGateway)
return
}
defer clientConn.Close()

targetConn, err := endpoint.ConnectStream(r.Context())
if err != nil {
slog.Error("Failed to connect to the origin", "error", err)
w.WriteHeader(http.StatusBadGateway)
return
}
defer targetConn.Close()

go func() {
defer targetConn.CloseWrite()
_, err := io.Copy(targetConn, clientConn)
if err != nil {
log.Printf("Failed to upgrade: %v\n", err)
w.WriteHeader(http.StatusBadGateway)
return
slog.Error("Failed to relay client traffic to target", "error", err)
}
defer targetConn.Close()
// Relay from client to target.
go func() {
defer targetConn.CloseWrite()
websocketToConn(targetConn, wsConn)
}()
connToWebsocket(wsConn, targetConn)
}()
_, err = io.Copy(clientConn, targetConn)
if err != nil {
slog.Error("Failed to relay target traffic to client", "error", err)
}
websocket.Server{Handler: handler}.ServeHTTP(w, r)
clientConn.CloseWrite()
})
mux.Handle(*tcpPathFlag, http.StripPrefix(*tcpPathFlag, handler))
}
Expand All @@ -136,24 +115,41 @@ func main() {
}
endpoint := transport.PacketDialerEndpoint{Dialer: dialer, Address: *backendFlag}
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
log.Printf("Got packet request: %v\n", r)
handler := func(wsConn *websocket.Conn) {
defer wsConn.Close()
targetConn, err := endpoint.ConnectPacket(r.Context())
if err != nil {
log.Printf("Failed to upgrade: %v\n", err)
w.WriteHeader(http.StatusBadGateway)
return
}
// Expire connetion after 5 minutes of idle time, as per
// https://datatracker.ietf.org/doc/html/rfc4787#section-4.3
targetConn = &natConn{targetConn, 5 * time.Minute}
slog.Info("Got packet request", "request", r)
clientConn, err := websocket.Upgrade(w, r, http.Header{})
if err != nil {
slog.Error("failed to accept Websocket connection", "error", err)
http.Error(w, "Failed to accept Websocket connection", http.StatusBadGateway)
return
}
defer clientConn.Close()

targetConn, err := endpoint.ConnectPacket(r.Context())
if err != nil {
slog.Error("Failed to connect to the origin", "error", err)
w.WriteHeader(http.StatusBadGateway)
return
}
// Expire connection after 5 minutes of idle time, as per
// https://datatracker.ietf.org/doc/html/rfc4787#section-4.3
targetConn = &natConn{targetConn, 5 * time.Minute}
defer targetConn.Close()

done := false
go func() {
defer targetConn.Close()
// Relay from client to target.
go websocketToConn(targetConn, wsConn)
connToWebsocket(wsConn, targetConn)
_, err := io.Copy(targetConn, clientConn)
if err != nil && !done {
slog.Error("Failed to relay client traffic to target", "error", err)
}
done = true
}()
_, err = io.Copy(clientConn, targetConn)
if err != nil && !done {
slog.Error("Failed to relay target traffic to client", "error", err)
}
websocket.Server{Handler: handler}.ServeHTTP(w, r)
done = true
clientConn.Close()
})
mux.Handle(*udpPathFlag, http.StripPrefix(*udpPathFlag, handler))
}
Expand Down
1 change: 0 additions & 1 deletion x/go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@ require (
// Use github.com/Psiphon-Labs/psiphon-tunnel-core@staging-client as per
// https://github.com/Psiphon-Labs/psiphon-tunnel-core/?tab=readme-ov-file#using-psiphon-with-go-modules
github.com/Psiphon-Labs/psiphon-tunnel-core v1.0.11-0.20240619172145-03cade11f647
github.com/coder/websocket v1.8.12
github.com/gorilla/websocket v1.5.3
github.com/lmittmann/tint v1.0.5
github.com/quic-go/quic-go v0.48.1
Expand Down
2 changes: 0 additions & 2 deletions x/go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,6 @@ github.com/cheekybits/genny v0.0.0-20170328200008-9127e812e1e9/go.mod h1:+tQajlR
github.com/chzyer/logex v1.1.10/go.mod h1:+Ywpsq7O8HXn0nuIou7OrIPyXbp3wmkHB+jjWRnGsAI=
github.com/chzyer/readline v0.0.0-20180603132655-2972be24d48e/go.mod h1:nSuG5e5PlCu98SY8svDHJxuZscDgtXS6KTTbou5AhLI=
github.com/chzyer/test v0.0.0-20180213035817-a1ea475d72b1/go.mod h1:Q3SI9o4m/ZMnBNeIyt5eFwwo7qiLfzFZmjNmxjkiQlU=
github.com/coder/websocket v1.8.12 h1:5bUXkEPPIbewrnkU8LTCLVaxi4N4J8ahufH2vlo4NAo=
github.com/coder/websocket v1.8.12/go.mod h1:LNVeNrXQZfe5qhS9ALED3uA+l5pPqvwXg3CKoDBB2gs=
github.com/cognusion/go-cache-lru v0.0.0-20170419142635-f73e2280ecea h1:9C2rdYRp8Vzwhm3sbFX0yYfB+70zKFRjn7cnPCucHSw=
github.com/cognusion/go-cache-lru v0.0.0-20170419142635-f73e2280ecea/go.mod h1:MdyNkAe06D7xmJsf+MsLvbZKYNXuOHLKJrvw+x4LlcQ=
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
Expand Down
21 changes: 13 additions & 8 deletions x/websocket/endpoint.go
Original file line number Diff line number Diff line change
Expand Up @@ -97,15 +97,19 @@ func newEndpoint[ConnType net.Conn](urlStr string, sd transport.StreamDialer, ws
if err != nil {
return zero, err
}
gConn := &gorillaConn{wsConn: wsConn}
wsConn.SetCloseHandler(func(code int, text string) error {
gConn.readErr = io.EOF
return nil
})
return wsToConn(gConn), nil
return wsToConn(newGorillaConn(wsConn)), nil
}, nil
}

func newGorillaConn(wsConn *websocket.Conn) *gorillaConn {
gConn := &gorillaConn{wsConn: wsConn}
wsConn.SetCloseHandler(func(code int, text string) error {
gConn.readErr = io.EOF
return nil
})
return gConn
}

type gorillaConn struct {
wsConn *websocket.Conn
writeErr error
Expand Down Expand Up @@ -196,6 +200,8 @@ func (c *gorillaConn) CloseWrite() error {
}

func (c *gorillaConn) Close() error {
c.CloseRead()
c.CloseWrite()
return c.wsConn.Close()
}

Expand All @@ -208,6 +214,5 @@ func Upgrade(w http.ResponseWriter, r *http.Request, responseHeader http.Header)
if err != nil {
return nil, err
}

return &gorillaConn{wsConn: wsConn}, nil
return newGorillaConn(wsConn), nil
}
71 changes: 18 additions & 53 deletions x/websocket/endpoint_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,6 @@ import (
"testing"

"github.com/Jigsaw-Code/outline-sdk/transport"
// TODO(fortuna): Implement the test with gorilla instead.
"github.com/coder/websocket"
"github.com/stretchr/testify/require"
)

Expand All @@ -36,61 +34,28 @@ func Test_NewStreamEndpoint(t *testing.T) {
// TODO(fortuna): support h2 and h3 on the server.
require.Equal(t, "", r.TLS.NegotiatedProtocol)
require.Equal(t, "HTTP/1.1", r.Proto)

t.Log("Got stream request", "request", r)
defer t.Log("Request done")
clientConn, err := websocket.Accept(w, r, nil)

clientConn, err := Upgrade(w, r, http.Header{})
if err != nil {
t.Log("Failed to accept Websocket connection", "error", err)
http.Error(w, "Failed to accept Websocket connection", http.StatusBadGateway)
return
}
clientConn.SetReadLimit(-1)
defer clientConn.CloseNow()
defer clientConn.Close()

// Handle client -> target.
readClientDone := make(chan struct{})
go func() {
defer close(readClientDone)
defer clientConn.CloseRead(r.Context())
for {
msgType, msg, err := clientConn.Read(r.Context())
if err != nil {
if !errors.Is(err, io.EOF) {
t.Log("Failed to read from client", "error", err)
clientConn.Close(websocket.StatusInternalError, "failed to read from client")
}
break
}
require.Equal(t, websocket.MessageBinary, msgType)
if _, err := toTargetWriter.Write(msg); err != nil {
t.Log("Failed to write to target", "error", err)
clientConn.Close(websocket.StatusInternalError, "failed to write message to target")
break
}
}
defer toTargetWriter.Close()
_, err := io.Copy(toTargetWriter, clientConn)
require.NoError(t, err)
}()
// Handle target -> client
func() {
// About 2 MTUs
buf := make([]byte, 3000)
for {
n, err := fromTargetReader.Read(buf)
if err != nil {
if !errors.Is(err, io.EOF) {
t.Log("Failed to read from target", "error", err)
clientConn.Close(websocket.StatusInternalError, "failed to read message from target")
}
break
}
read := buf[:n]
if err := clientConn.Write(r.Context(), websocket.MessageBinary, read); err != nil {
t.Log("Failed to write to client", "error", err)
clientConn.Close(websocket.StatusInternalError, "failed to write message to client")
break
}
}
}()
_, err = io.Copy(clientConn, fromTargetReader)
require.NoError(t, err)
<-readClientDone
})
mux.Handle("/tcp", http.StripPrefix("/tcp", handler))
Expand Down Expand Up @@ -157,19 +122,18 @@ func Test_NewPacketEndpoint(t *testing.T) {
// TODO(fortuna): support h2 and h3 on the server.
require.Equal(t, "", r.TLS.NegotiatedProtocol)
require.Equal(t, "HTTP/1.1", r.Proto)
clientConn, err := websocket.Accept(w, r, nil)
clientConn, err := Upgrade(w, r, http.Header{})
require.NoError(t, err)
defer clientConn.CloseNow()
defer clientConn.Close()

msgType, msg, err := clientConn.Read(r.Context())
buf := make([]byte, 8)
n, err := clientConn.Read(buf)
require.NoError(t, err)
require.Equal(t, websocket.MessageBinary, msgType)
require.Equal(t, []byte("Request"), msg)
require.Equal(t, []byte("Request"), buf[:n])

err = clientConn.Write(r.Context(), websocket.MessageBinary, []byte("Response"))
n, err = clientConn.Write([]byte("Response"))
require.NoError(t, err)

clientConn.Close(websocket.StatusNormalClosure, "")
require.Equal(t, 8, n)
})
mux.Handle("/udp", http.StripPrefix("/udp", handler))
ts := httptest.NewUnstartedServer(mux)
Expand All @@ -190,7 +154,8 @@ func Test_NewPacketEndpoint(t *testing.T) {
require.NoError(t, err)
require.Equal(t, 7, n)

resp, err := io.ReadAll(conn)
buf := make([]byte, 9)
n, err = conn.Read(buf)
require.NoError(t, err)
require.Equal(t, []byte("Response"), resp)
require.Equal(t, []byte("Response"), buf[:n])
}

0 comments on commit 80b6430

Please sign in to comment.