Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: ensure websocket conns respect max duration #156

Merged
merged 1 commit into from
Nov 30, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions httpbin/handlers.go
Original file line number Diff line number Diff line change
Expand Up @@ -1152,6 +1152,7 @@ func (h *HTTPBin) WebSocketEcho(w http.ResponseWriter, r *http.Request) {
}

ws := websocket.New(w, r, websocket.Limits{
MaxDuration: h.MaxDuration,
MaxFragmentSize: int(maxFragmentSize),
MaxMessageSize: int(maxMessageSize),
})
Expand Down
8 changes: 8 additions & 0 deletions httpbin/websocket/websocket.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import (
"io"
"net/http"
"strings"
"time"
"unicode/utf8"
)

Expand Down Expand Up @@ -80,6 +81,7 @@ var EchoHandler Handler = func(ctx context.Context, msg *Message) (*Message, err

// Limits define the limits imposed on a websocket connection.
type Limits struct {
MaxDuration time.Duration
MaxFragmentSize int
MaxMessageSize int
}
Expand All @@ -88,6 +90,7 @@ type Limits struct {
type WebSocket struct {
w http.ResponseWriter
r *http.Request
maxDuration time.Duration
maxFragmentSize int
maxMessageSize int
handshook bool
Expand All @@ -98,6 +101,7 @@ func New(w http.ResponseWriter, r *http.Request, limits Limits) *WebSocket {
return &WebSocket{
w: w,
r: r,
maxDuration: limits.MaxDuration,
maxFragmentSize: limits.MaxFragmentSize,
maxMessageSize: limits.MaxMessageSize,
}
Expand Down Expand Up @@ -152,6 +156,10 @@ func (s *WebSocket) Serve(handler Handler) {
}
defer conn.Close()

// best effort attempt to ensure that our websocket conenctions do not
// exceed the maximum request duration
conn.SetDeadline(time.Now().Add(s.maxDuration))

// errors intentionally ignored here. it's serverLoop's responsibility to
// properly close the websocket connection with a useful error message, and
// any unexpected error returned from serverLoop is not actionable.
Expand Down
1 change: 1 addition & 0 deletions httpbin/websocket/websocket_autobahn_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ func TestWebSocketServer(t *testing.T) {

srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
ws := websocket.New(w, r, websocket.Limits{
MaxDuration: 30 * time.Second,
MaxFragmentSize: 1024 * 1024 * 16,
MaxMessageSize: 1024 * 1024 * 16,
})
Expand Down
152 changes: 152 additions & 0 deletions httpbin/websocket/websocket_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,15 @@ package websocket_test
import (
"bufio"
"fmt"
"io"
"net"
"net/http"
"net/http/httptest"
"os"
"strings"
"sync"
"testing"
"time"

"github.com/mccutchen/go-httpbin/v2/httpbin/websocket"
"github.com/mccutchen/go-httpbin/v2/internal/testing/assert"
Expand Down Expand Up @@ -220,6 +225,153 @@ func TestHandshakeOrder(t *testing.T) {
})
}

func TestConnectionLimits(t *testing.T) {
t.Run("maximum request duration is enforced", func(t *testing.T) {
t.Parallel()

maxDuration := 500 * time.Millisecond

srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
ws := websocket.New(w, r, websocket.Limits{
MaxDuration: maxDuration,
// TODO: test these limits as well
MaxFragmentSize: 128,
MaxMessageSize: 256,
})
if err := ws.Handshake(); err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
return
}
ws.Serve(websocket.EchoHandler)
}))
defer srv.Close()

conn, err := net.Dial("tcp", srv.Listener.Addr().String())
assert.NilError(t, err)
defer conn.Close()

reqParts := []string{
"GET /websocket/echo HTTP/1.1",
"Host: test",
"Connection: upgrade",
"Upgrade: websocket",
"Sec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==",
"Sec-WebSocket-Version: 13",
}
reqBytes := []byte(strings.Join(reqParts, "\r\n") + "\r\n\r\n")
t.Logf("raw request:\n%q", reqBytes)

// first, we write the request line and headers, which should cause the
// server to respond with a 101 Switching Protocols response.
{
n, err := conn.Write(reqBytes)
assert.NilError(t, err)
assert.Equal(t, n, len(reqBytes), "incorrect number of bytes written")

resp, err := http.ReadResponse(bufio.NewReader(conn), nil)
assert.NilError(t, err)
assert.StatusCode(t, resp, http.StatusSwitchingProtocols)
}

// next, we try to read from the connection, expecting the connection
// to be closed after roughly maxDuration seconds
{
start := time.Now()
_, err := conn.Read(make([]byte, 1))
elapsed := time.Since(start)

assert.Error(t, err, io.EOF)
assert.RoughDuration(t, elapsed, maxDuration, 25*time.Millisecond)
}
})

t.Run("client closing connection", func(t *testing.T) {
t.Parallel()

// the client will close the connection well before the server closes
// the connection. make sure the server properly handles the client
// closure.
var (
clientTimeout = 100 * time.Millisecond
serverTimeout = time.Hour // should never be reached
elapsedClientTime time.Duration
elapsedServerTime time.Duration
wg sync.WaitGroup
)

wg.Add(1)
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
defer wg.Done()
start := time.Now()
ws := websocket.New(w, r, websocket.Limits{
MaxDuration: serverTimeout,
MaxFragmentSize: 128,
MaxMessageSize: 256,
})
if err := ws.Handshake(); err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
return
}
ws.Serve(websocket.EchoHandler)
elapsedServerTime = time.Since(start)
}))
defer srv.Close()

conn, err := net.Dial("tcp", srv.Listener.Addr().String())
assert.NilError(t, err)
defer conn.Close()

// should cause the client end of the connection to close well before
// the max request time configured above
conn.SetDeadline(time.Now().Add(clientTimeout))

reqParts := []string{
"GET /websocket/echo HTTP/1.1",
"Host: test",
"Connection: upgrade",
"Upgrade: websocket",
"Sec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==",
"Sec-WebSocket-Version: 13",
}
reqBytes := []byte(strings.Join(reqParts, "\r\n") + "\r\n\r\n")
t.Logf("raw request:\n%q", reqBytes)

// first, we write the request line and headers, which should cause the
// server to respond with a 101 Switching Protocols response.
{
n, err := conn.Write(reqBytes)
assert.NilError(t, err)
assert.Equal(t, n, len(reqBytes), "incorrect number of bytes written")

resp, err := http.ReadResponse(bufio.NewReader(conn), nil)
assert.NilError(t, err)
assert.StatusCode(t, resp, http.StatusSwitchingProtocols)
}

// next, we try to read from the connection, expecting the connection
// to be closed after roughly clientTimeout seconds.
//
// the server should detect the closed connection and abort the
// handler, also after roughly clientTimeout seconds.
{
start := time.Now()
_, err := conn.Read(make([]byte, 1))
elapsedClientTime = time.Since(start)

// close client connection, which should interrupt the server's
// blocking read call on the connection
conn.Close()

assert.Equal(t, os.IsTimeout(err), true, "expected timeout error")
assert.RoughDuration(t, elapsedClientTime, clientTimeout, 10*time.Millisecond)

// wait for the server to finish
wg.Wait()
assert.RoughDuration(t, elapsedServerTime, clientTimeout, 10*time.Millisecond)
}
})
}

// brokenHijackResponseWriter implements just enough to satisfy the
// http.ResponseWriter and http.Hijacker interfaces and get through the
// handshake before failing to actually hijack the connection.
Expand Down