Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow ResponseWriters to unwrap writers when flushing/hijacking #2595

Merged
merged 3 commits into from
Mar 9, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 10 additions & 2 deletions middleware/body_dump.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package middleware
import (
"bufio"
"bytes"
"errors"
"io"
"net"
"net/http"
Expand Down Expand Up @@ -98,9 +99,16 @@ func (w *bodyDumpResponseWriter) Write(b []byte) (int, error) {
}

func (w *bodyDumpResponseWriter) Flush() {
w.ResponseWriter.(http.Flusher).Flush()
err := responseControllerFlush(w.ResponseWriter)
if err != nil && errors.Is(err, http.ErrNotSupported) {
panic(errors.New("response writer flushing is not supported"))
}
}

func (w *bodyDumpResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
return w.ResponseWriter.(http.Hijacker).Hijack()
return responseControllerHijack(w.ResponseWriter)
}

func (w *bodyDumpResponseWriter) Unwrap() http.ResponseWriter {
return w.ResponseWriter
}
50 changes: 50 additions & 0 deletions middleware/body_dump_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -87,3 +87,53 @@ func TestBodyDumpFails(t *testing.T) {
}
})
}

func TestBodyDumpResponseWriter_CanNotFlush(t *testing.T) {
bdrw := bodyDumpResponseWriter{
ResponseWriter: new(testResponseWriterNoFlushHijack), // this RW does not support flush
}

assert.PanicsWithError(t, "response writer flushing is not supported", func() {
bdrw.Flush()
})
}

func TestBodyDumpResponseWriter_CanFlush(t *testing.T) {
trwu := testResponseWriterUnwrapperHijack{testResponseWriterUnwrapper: testResponseWriterUnwrapper{rw: httptest.NewRecorder()}}
bdrw := bodyDumpResponseWriter{
ResponseWriter: &trwu,
}

bdrw.Flush()
assert.Equal(t, 1, trwu.unwrapCalled)
}

func TestBodyDumpResponseWriter_CanUnwrap(t *testing.T) {
trwu := &testResponseWriterUnwrapper{rw: httptest.NewRecorder()}
bdrw := bodyDumpResponseWriter{
ResponseWriter: trwu,
}

result := bdrw.Unwrap()
assert.Equal(t, trwu, result)
}

func TestBodyDumpResponseWriter_CanHijack(t *testing.T) {
trwu := testResponseWriterUnwrapperHijack{testResponseWriterUnwrapper: testResponseWriterUnwrapper{rw: httptest.NewRecorder()}}
bdrw := bodyDumpResponseWriter{
ResponseWriter: &trwu, // this RW supports hijacking through unwrapping
}

_, _, err := bdrw.Hijack()
assert.EqualError(t, err, "can hijack")
}

func TestBodyDumpResponseWriter_CanNotHijack(t *testing.T) {
trwu := testResponseWriterUnwrapper{rw: httptest.NewRecorder()}
bdrw := bodyDumpResponseWriter{
ResponseWriter: &trwu, // this RW supports hijacking through unwrapping
}

_, _, err := bdrw.Hijack()
assert.EqualError(t, err, "feature not supported")
}
10 changes: 6 additions & 4 deletions middleware/compress.go
Original file line number Diff line number Diff line change
Expand Up @@ -191,13 +191,15 @@ func (w *gzipResponseWriter) Flush() {
}

w.Writer.(*gzip.Writer).Flush()
if flusher, ok := w.ResponseWriter.(http.Flusher); ok {
flusher.Flush()
}
_ = responseControllerFlush(w.ResponseWriter)
}

func (w *gzipResponseWriter) Unwrap() http.ResponseWriter {
return w.ResponseWriter
}

func (w *gzipResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
return w.ResponseWriter.(http.Hijacker).Hijack()
return responseControllerHijack(w.ResponseWriter)
}

func (w *gzipResponseWriter) Push(target string, opts *http.PushOptions) error {
Expand Down
30 changes: 30 additions & 0 deletions middleware/compress_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -311,6 +311,36 @@ func TestGzipWithStatic(t *testing.T) {
}
}

func TestGzipResponseWriter_CanUnwrap(t *testing.T) {
trwu := &testResponseWriterUnwrapper{rw: httptest.NewRecorder()}
bdrw := gzipResponseWriter{
ResponseWriter: trwu,
}

result := bdrw.Unwrap()
assert.Equal(t, trwu, result)
}

func TestGzipResponseWriter_CanHijack(t *testing.T) {
trwu := testResponseWriterUnwrapperHijack{testResponseWriterUnwrapper: testResponseWriterUnwrapper{rw: httptest.NewRecorder()}}
bdrw := gzipResponseWriter{
ResponseWriter: &trwu, // this RW supports hijacking through unwrapping
}

_, _, err := bdrw.Hijack()
assert.EqualError(t, err, "can hijack")
}

func TestGzipResponseWriter_CanNotHijack(t *testing.T) {
trwu := testResponseWriterUnwrapper{rw: httptest.NewRecorder()}
bdrw := gzipResponseWriter{
ResponseWriter: &trwu, // this RW supports hijacking through unwrapping
}

_, _, err := bdrw.Hijack()
assert.EqualError(t, err, "feature not supported")
}

func BenchmarkGzip(b *testing.B) {
e := echo.New()

Expand Down
46 changes: 46 additions & 0 deletions middleware/middleware_test.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
package middleware

import (
"bufio"
"errors"
"github.com/stretchr/testify/assert"
"net"
"net/http"
"net/http/httptest"
"regexp"
Expand Down Expand Up @@ -90,3 +93,46 @@ func TestRewriteURL(t *testing.T) {
})
}
}

type testResponseWriterNoFlushHijack struct {
}

func (w *testResponseWriterNoFlushHijack) WriteHeader(statusCode int) {
}

func (w *testResponseWriterNoFlushHijack) Write([]byte) (int, error) {
return 0, nil
}

func (w *testResponseWriterNoFlushHijack) Header() http.Header {
return nil
}

type testResponseWriterUnwrapper struct {
unwrapCalled int
rw http.ResponseWriter
}

func (w *testResponseWriterUnwrapper) WriteHeader(statusCode int) {
}

func (w *testResponseWriterUnwrapper) Write([]byte) (int, error) {
return 0, nil
}

func (w *testResponseWriterUnwrapper) Header() http.Header {
return nil
}

func (w *testResponseWriterUnwrapper) Unwrap() http.ResponseWriter {
w.unwrapCalled++
return w.rw
}

type testResponseWriterUnwrapperHijack struct {
testResponseWriterUnwrapper
}

func (w *testResponseWriterUnwrapperHijack) Hijack() (net.Conn, *bufio.ReadWriter, error) {
return nil, nil, errors.New("can hijack")
}
41 changes: 41 additions & 0 deletions middleware/responsecontroller_1.19.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
//go:build !go1.20

package middleware

import (
"bufio"
"fmt"
"net"
"net/http"
)

// TODO: remove when Go 1.23 is released and we do not support 1.19 anymore
func responseControllerFlush(rw http.ResponseWriter) error {
for {
switch t := rw.(type) {
case interface{ FlushError() error }:
return t.FlushError()
case http.Flusher:
t.Flush()
return nil
case interface{ Unwrap() http.ResponseWriter }:
rw = t.Unwrap()
default:
return fmt.Errorf("%w", http.ErrNotSupported)
}
}
}

// TODO: remove when Go 1.23 is released and we do not support 1.19 anymore
func responseControllerHijack(rw http.ResponseWriter) (net.Conn, *bufio.ReadWriter, error) {
for {
switch t := rw.(type) {
case http.Hijacker:
return t.Hijack()
case interface{ Unwrap() http.ResponseWriter }:
rw = t.Unwrap()
default:
return nil, nil, fmt.Errorf("%w", http.ErrNotSupported)
}
}
}
17 changes: 17 additions & 0 deletions middleware/responsecontroller_1.20.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
//go:build go1.20

package middleware

import (
"bufio"
"net"
"net/http"
)

func responseControllerFlush(rw http.ResponseWriter) error {
return http.NewResponseController(rw).Flush()
}

func responseControllerHijack(rw http.ResponseWriter) (net.Conn, *bufio.ReadWriter, error) {
return http.NewResponseController(rw).Hijack()
}
8 changes: 6 additions & 2 deletions response.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package echo

import (
"bufio"
"errors"
"net"
"net/http"
)
Expand Down Expand Up @@ -84,14 +85,17 @@ func (r *Response) Write(b []byte) (n int, err error) {
// buffered data to the client.
// See [http.Flusher](https://golang.org/pkg/net/http/#Flusher)
func (r *Response) Flush() {
r.Writer.(http.Flusher).Flush()
err := responseControllerFlush(r.Writer)
if err != nil && errors.Is(err, http.ErrNotSupported) {
panic(errors.New("response writer flushing is not supported"))
}
}

// Hijack implements the http.Hijacker interface to allow an HTTP handler to
// take over the connection.
// See [http.Hijacker](https://golang.org/pkg/net/http/#Hijacker)
func (r *Response) Hijack() (net.Conn, *bufio.ReadWriter, error) {
return r.Writer.(http.Hijacker).Hijack()
return responseControllerHijack(r.Writer)
}

// Unwrap returns the original http.ResponseWriter.
Expand Down
25 changes: 25 additions & 0 deletions response_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,31 @@ func TestResponse_Flush(t *testing.T) {
assert.True(t, rec.Flushed)
}

type testResponseWriter struct {
}

func (w *testResponseWriter) WriteHeader(statusCode int) {
}

func (w *testResponseWriter) Write([]byte) (int, error) {
return 0, nil
}

func (w *testResponseWriter) Header() http.Header {
return nil
}

func TestResponse_FlushPanics(t *testing.T) {
e := New()
rw := new(testResponseWriter)
res := &Response{echo: e, Writer: rw}

// we test that we behave as before unwrapping flushers - flushing writer that does not support it causes panic
assert.PanicsWithError(t, "response writer flushing is not supported", func() {
res.Flush()
})
}

func TestResponse_ChangeStatusCodeBeforeWrite(t *testing.T) {
e := New()
rec := httptest.NewRecorder()
Expand Down
41 changes: 41 additions & 0 deletions responsecontroller_1.19.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
//go:build !go1.20

package echo

import (
"bufio"
"fmt"
"net"
"net/http"
)

// TODO: remove when Go 1.23 is released and we do not support 1.19 anymore
func responseControllerFlush(rw http.ResponseWriter) error {
for {
switch t := rw.(type) {
case interface{ FlushError() error }:
return t.FlushError()
case http.Flusher:
t.Flush()
return nil
case interface{ Unwrap() http.ResponseWriter }:
rw = t.Unwrap()
default:
return fmt.Errorf("%w", http.ErrNotSupported)
}
}
}

// TODO: remove when Go 1.23 is released and we do not support 1.19 anymore
func responseControllerHijack(rw http.ResponseWriter) (net.Conn, *bufio.ReadWriter, error) {
for {
switch t := rw.(type) {
case http.Hijacker:
return t.Hijack()
case interface{ Unwrap() http.ResponseWriter }:
rw = t.Unwrap()
default:
return nil, nil, fmt.Errorf("%w", http.ErrNotSupported)
}
}
}
17 changes: 17 additions & 0 deletions responsecontroller_1.20.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
//go:build go1.20

package echo

import (
"bufio"
"net"
"net/http"
)

func responseControllerFlush(rw http.ResponseWriter) error {
return http.NewResponseController(rw).Flush()
}

func responseControllerHijack(rw http.ResponseWriter) (net.Conn, *bufio.ReadWriter, error) {
return http.NewResponseController(rw).Hijack()
}
Loading