Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve response writer safety #30

Merged
merged 3 commits into from
Feb 27, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 0 additions & 7 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -393,13 +393,6 @@ f := fox.New(fox.DefaultOptions())
f.MustHandle(http.MethodGet, "/articles/{id}", fox.WrapH(httpRateLimiter.RateLimit(articles)))
```

### Custom http.ResponseWriter Implementations
When using custom `http.ResponseWriter` implementations, it's important to ensure that these implementations expose the
required http interfaces. For HTTP/1.x requests, Fox expects the `http.ResponseWriter` to implement the `http.Flusher`,
`http.Hijacker`, and `io.ReaderFrom` interfaces. For HTTP/2 requests, the `http.ResponseWriter` should implement the
`http.Flusher` and `http.Pusher` interfaces. Fox will invoke these methods **without any prior assertion**.


## Middleware
Middlewares can be registered globally using the `fox.WithMiddleware` option. The example below demonstrates how
to create and apply automatically a simple logging middleware to all routes (including 404, 405, etc...).
Expand Down
6 changes: 1 addition & 5 deletions context.go
Original file line number Diff line number Diff line change
Expand Up @@ -107,11 +107,7 @@ type context struct {
func (c *context) Reset(fox *Router, w http.ResponseWriter, r *http.Request) {
c.rec.reset(w)
c.req = r
if r.ProtoMajor == 2 {
c.w = h2Writer{&c.rec}
} else {
c.w = h1Writer{&c.rec}
}
c.w = &c.rec
c.fox = fox
c.path = ""
c.cachedQuery = nil
Expand Down
36 changes: 36 additions & 0 deletions context_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,42 @@ import (
"testing"
)

func TestContext_Writer_ReadFrom(t *testing.T) {
req := httptest.NewRequest(http.MethodGet, "https://example.com/foo", nil)
w := httptest.NewRecorder()

c := NewTestContextOnly(New(), w, req)

n, err := c.Writer().ReadFrom(bytes.NewBuffer([]byte("foo bar")))
require.NoError(t, err)
assert.Equal(t, int(n), c.Writer().Size())
assert.True(t, c.Writer().Written())
assert.Equal(t, http.StatusOK, w.Code)
assert.Equal(t, int(n), w.Body.Len())
}

func TestContext_SetWriter(t *testing.T) {
req := httptest.NewRequest(http.MethodGet, "https://example.com/foo", nil)
w := httptest.NewRecorder()

c := NewTestContextOnly(New(), w, req)

newRec := new(recorder)
c.SetWriter(newRec)
assert.Equal(t, newRec, c.Writer())
}

func TestContext_SetRequest(t *testing.T) {
req := httptest.NewRequest(http.MethodGet, "https://example.com/foo", nil)
w := httptest.NewRecorder()

c := NewTestContextOnly(New(), w, req)

newReq := new(http.Request)
c.SetRequest(newReq)
assert.Equal(t, newReq, c.Request())
}

func TestContext_QueryParams(t *testing.T) {
t.Parallel()
wantValues := url.Values{
Expand Down
5 changes: 0 additions & 5 deletions fox.go
Original file line number Diff line number Diff line change
Expand Up @@ -255,11 +255,6 @@ func defaultOptionsHandler(c Context) {

// ServeHTTP is the main entry point to serve a request. It handles all incoming HTTP requests and dispatches them
// to the appropriate handler function based on the request's method and path.
//
// It expects the http.ResponseWriter provided to implement the http.Flusher, http.Hijacker, and io.ReaderFrom
// interfaces for HTTP/1.x requests and the http.Flusher and http.Pusher interfaces for HTTP/2 requests.
// If a custom response writer is used, it is critical to ensure that these methods are properly exposed as Fox
// will invoke them without any prior assertion.
func (fox *Router) ServeHTTP(w http.ResponseWriter, r *http.Request) {

var (
Expand Down
2 changes: 1 addition & 1 deletion helpers.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ func newTextContextOnly(fox *Router, w http.ResponseWriter, r *http.Request) *co
c.fox = fox
c.req = r
c.rec.reset(w)
c.w = flushWriter{&c.rec}
c.w = &c.rec
return c
}

Expand Down
11 changes: 5 additions & 6 deletions helpers_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
package fox

import (
"io"
"net/http"
"net/http/httptest"
"testing"
Expand All @@ -19,7 +18,7 @@ func TestNewTestContext(t *testing.T) {
w := httptest.NewRecorder()
_, c := NewTestContext(w, req)

flusher, ok := c.Writer().(http.Flusher)
flusher, ok := c.Writer().(interface{ Unwrap() http.ResponseWriter }).Unwrap().(http.Flusher)
require.True(t, ok)

n, err := c.Writer().Write([]byte("foo"))
Expand All @@ -33,9 +32,9 @@ func TestNewTestContext(t *testing.T) {

assert.Equal(t, 6, c.Writer().Size())

_, ok = c.Writer().(http.Hijacker)
assert.False(t, ok)
_, _, err = c.Writer().Hijack()
assert.ErrorIs(t, err, http.ErrNotSupported)

_, ok = c.Writer().(io.ReaderFrom)
assert.False(t, ok)
err = c.Writer().Push("foo", nil)
assert.ErrorIs(t, err, http.ErrNotSupported)
}
148 changes: 75 additions & 73 deletions response_writer.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,44 +20,39 @@
"path"
"runtime"
"strings"
"sync"
)

var (
_ http.Flusher = (*h1Writer)(nil)
_ http.Hijacker = (*h1Writer)(nil)
_ io.ReaderFrom = (*h1Writer)(nil)
)

var (
_ http.Pusher = (*h2Writer)(nil)
_ http.Flusher = (*h2Writer)(nil)
)
var _ ResponseWriter = (*recorder)(nil)

var (
_ ResponseWriter = (*flushWriter)(nil)
_ http.Flusher = (*flushWriter)(nil)
)

var (
_ ResponseWriter = (*pushWriter)(nil)
_ http.Pusher = (*pushWriter)(nil)
)
var copyBufPool = sync.Pool{
New: func() any {
b := make([]byte, 32*1024)
return &b
},
}

// ResponseWriter extends http.ResponseWriter and provides methods to retrieve the recorded status code,
// written state, and response size. ResponseWriter object implements additional http.Flusher, http.Hijacker,
// io.ReaderFrom interfaces for HTTP/1.x requests and http.Flusher, http.Pusher interfaces for HTTP/2 requests.
// written state, and response size.
type ResponseWriter interface {
http.ResponseWriter
io.StringWriter
io.ReaderFrom
// Status recorded after Write and WriteHeader.
Status() int
// Written returns true if the response has been written.
Written() bool
// Size returns the size of the written response.
Size() int
// WriteString writes the provided string to the underlying connection
// as part of an HTTP reply. The method returns the number of bytes written
// and an error, if any.
WriteString(s string) (int, error)
// FlushError flushes buffered data to the client. If flush is not supported, FlushError returns an error
// matching http.ErrNotSupported. See http.Flusher for more details.
FlushError() error
// Hijack lets the caller take over the connection. If hijacking the connection is not supported, Hijack returns
// an error matching http.ErrNotSupported. See http.Hijacker for more details.
Hijack() (net.Conn, *bufio.ReadWriter, error)
// Push initiates an HTTP/2 server push. Push returns http.ErrNotSupported if the client has disabled push or if push
// is not supported on the underlying connection. See http.Pusher for more details.
Push(target string, opts *http.PushOptions) error
}

const notWritten = -1
Expand All @@ -74,14 +69,17 @@
r.status = http.StatusOK
}

// Status recorded after Write or WriteHeader.
func (r *recorder) Status() int {
return r.status
}

// Written returns true if the response has been written.
func (r *recorder) Written() bool {
return r.size != notWritten
}

// Size returns the size of the written response.
func (r *recorder) Size() int {
if r.size < 0 {
return 0
Expand All @@ -93,6 +91,8 @@
return r.ResponseWriter
}

// WriteHeader sends an HTTP response header with the provided
// status code. See http.ResponseWriter for more details.
func (r *recorder) WriteHeader(code int) {
if r.Written() {
caller := relevantCaller()
Expand All @@ -105,6 +105,8 @@
r.ResponseWriter.WriteHeader(code)
}

// Write writes the data to the connection as part of an HTTP reply.
// See http.ResponseWriter for more details.
func (r *recorder) Write(buf []byte) (n int, err error) {
if !r.Written() {
r.size = 0
Expand All @@ -115,6 +117,9 @@
return
}

// WriteString writes the provided string to the underlying connection
// as part of an HTTP reply. The method returns the number of bytes written
// and an error, if any.
func (r *recorder) WriteString(s string) (n int, err error) {
if !r.Written() {
r.size = 0
Expand All @@ -126,74 +131,67 @@
return
}

type flushWriter struct {
*recorder
}

func (w flushWriter) Flush() {
if !w.recorder.Written() {
w.recorder.size = 0
// ReadFrom reads data from src until EOF or error. The return value n is the number of bytes read.
// Any error except EOF encountered during the read is also returned.
func (r *recorder) ReadFrom(src io.Reader) (n int64, err error) {
if !r.Written() {
r.size = 0
}
w.recorder.ResponseWriter.(http.Flusher).Flush()
}

type h1Writer struct {
*recorder
}

func (w h1Writer) ReadFrom(src io.Reader) (n int64, err error) {
if !w.recorder.Written() {
w.recorder.size = 0
if rf, ok := r.ResponseWriter.(io.ReaderFrom); ok {
n, err = rf.ReadFrom(src)
r.size += int(n)
return

Check warning on line 144 in response_writer.go

View check run for this annotation

Codecov / codecov/patch

response_writer.go#L142-L144

Added lines #L142 - L144 were not covered by tests
}

rf := w.recorder.ResponseWriter.(io.ReaderFrom)
n, err = rf.ReadFrom(src)
w.recorder.size += int(n)
// Fallback in compatibility mode.
bufp := copyBufPool.Get().(*[]byte)
buf := *bufp
n, err = io.CopyBuffer(onlyWrite{r}, src, buf)
copyBufPool.Put(bufp)
return
}

func (w h1Writer) Hijack() (net.Conn, *bufio.ReadWriter, error) {
if !w.recorder.Written() {
w.recorder.size = 0
// FlushError flushes buffered data to the client. If flush is not supported, FlushError returns an error
// matching http.ErrNotSupported. See http.Flusher for more details.
func (r *recorder) FlushError() error {
switch flusher := r.ResponseWriter.(type) {
case interface{ FlushError() error }:
return flusher.FlushError()
case http.Flusher:
flusher.Flush()
return nil
default:
return errNotSupported()

Check warning on line 165 in response_writer.go

View check run for this annotation

Codecov / codecov/patch

response_writer.go#L157-L165

Added lines #L157 - L165 were not covered by tests
}
return w.recorder.ResponseWriter.(http.Hijacker).Hijack()
}

func (w h1Writer) Flush() {
if !w.recorder.Written() {
w.recorder.size = 0
// Push initiates an HTTP/2 server push. Push returns http.ErrNotSupported if the client has disabled push or if push
// is not supported on the underlying connection. See http.Pusher for more details.
func (r *recorder) Push(target string, opts *http.PushOptions) error {
if pusher, ok := r.ResponseWriter.(http.Pusher); ok {
return pusher.Push(target, opts)

Check warning on line 173 in response_writer.go

View check run for this annotation

Codecov / codecov/patch

response_writer.go#L173

Added line #L173 was not covered by tests
}
w.recorder.ResponseWriter.(http.Flusher).Flush()
}

type h2Writer struct {
*recorder
return http.ErrNotSupported
}

func (w h2Writer) Push(target string, opts *http.PushOptions) error {
return w.recorder.ResponseWriter.(http.Pusher).Push(target, opts)
}

func (w h2Writer) Flush() {
if !w.recorder.Written() {
w.recorder.size = 0
// Hijack lets the caller take over the connection. If hijacking the connection is not supported, Hijack returns
// an error matching http.ErrNotSupported. See http.Hijacker for more details.
func (r *recorder) Hijack() (net.Conn, *bufio.ReadWriter, error) {
if hijacker, ok := r.ResponseWriter.(http.Hijacker); ok {
return hijacker.Hijack()

Check warning on line 182 in response_writer.go

View check run for this annotation

Codecov / codecov/patch

response_writer.go#L182

Added line #L182 was not covered by tests
}
w.recorder.ResponseWriter.(http.Flusher).Flush()
}

type pushWriter struct {
*recorder
return nil, nil, errNotSupported()
}

func (w pushWriter) Push(target string, opts *http.PushOptions) error {
return w.recorder.ResponseWriter.(http.Pusher).Push(target, opts)
}

// noUnwrap hide the Unwrap method of the ResponseWriter.
type noUnwrap struct {
ResponseWriter
}

type onlyWrite struct {
io.Writer
}

type noopWriter struct {
h http.Header
}
Expand Down Expand Up @@ -226,3 +224,7 @@
}
return frame
}

func errNotSupported() error {
return fmt.Errorf("%w", http.ErrNotSupported)
}
Loading