Skip to content

Commit

Permalink
caddyhttp: ensure ResponseWriterWrapper and ResponseRecorder use Read…
Browse files Browse the repository at this point in the history
…From if the underlying response writer implements it. (#5022)

Doing so allows for splice/sendfile optimizations when available.
Fixes #4731

Co-authored-by: flga <flga@users.noreply.github.com>
Co-authored-by: Matthew Holt <mholt@users.noreply.github.com>
  • Loading branch information
3 people authored Sep 7, 2022
1 parent 1c9c8f6 commit dd9813c
Show file tree
Hide file tree
Showing 2 changed files with 200 additions and 2 deletions.
37 changes: 35 additions & 2 deletions modules/caddyhttp/responsewriter.go
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,16 @@ func (rww *ResponseWriterWrapper) Push(target string, opts *http.PushOptions) er
return ErrNotImplemented
}

// ReadFrom implements io.ReaderFrom. It simply calls the underlying
// ResponseWriter's ReadFrom method if there is one, otherwise it defaults
// to io.Copy.
func (rww *ResponseWriterWrapper) ReadFrom(r io.Reader) (n int64, err error) {
if rf, ok := rww.ResponseWriter.(io.ReaderFrom); ok {
return rf.ReadFrom(r)
}
return io.Copy(rww.ResponseWriter, r)
}

// HTTPInterfaces mix all the interfaces that middleware ResponseWriters need to support.
type HTTPInterfaces interface {
http.ResponseWriter
Expand Down Expand Up @@ -188,9 +198,26 @@ func (rr *responseRecorder) Write(data []byte) (int, error) {
} else {
n, err = rr.buf.Write(data)
}
if err == nil {
rr.size += n

rr.size += n
return n, err
}

func (rr *responseRecorder) ReadFrom(r io.Reader) (int64, error) {
rr.WriteHeader(http.StatusOK)
var n int64
var err error
if rr.stream {
if rf, ok := rr.ResponseWriter.(io.ReaderFrom); ok {
n, err = rf.ReadFrom(r)
} else {
n, err = io.Copy(rr.ResponseWriter, r)
}
} else {
n, err = rr.buf.ReadFrom(r)
}

rr.size += int(n)
return n, err
}

Expand Down Expand Up @@ -251,4 +278,10 @@ type ShouldBufferFunc func(status int, header http.Header) bool
var (
_ HTTPInterfaces = (*ResponseWriterWrapper)(nil)
_ ResponseRecorder = (*responseRecorder)(nil)

// Implementing ReaderFrom can be such a significant
// optimization that it should probably be required!
// see PR #5022 (25%-50% speedup)
_ io.ReaderFrom = (*ResponseWriterWrapper)(nil)
_ io.ReaderFrom = (*responseRecorder)(nil)
)
165 changes: 165 additions & 0 deletions modules/caddyhttp/responsewriter_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,165 @@
package caddyhttp

import (
"bytes"
"fmt"
"io"
"net/http"
"strings"
"testing"
)

type responseWriterSpy interface {
http.ResponseWriter
Written() string
CalledReadFrom() bool
}

var (
_ responseWriterSpy = (*baseRespWriter)(nil)
_ responseWriterSpy = (*readFromRespWriter)(nil)
)

// a barebones http.ResponseWriter mock
type baseRespWriter []byte

func (brw *baseRespWriter) Write(d []byte) (int, error) {
*brw = append(*brw, d...)
return len(d), nil
}
func (brw *baseRespWriter) Header() http.Header { return nil }
func (brw *baseRespWriter) WriteHeader(statusCode int) {}
func (brw *baseRespWriter) Written() string { return string(*brw) }
func (brw *baseRespWriter) CalledReadFrom() bool { return false }

// an http.ResponseWriter mock that supports ReadFrom
type readFromRespWriter struct {
baseRespWriter
called bool
}

func (rf *readFromRespWriter) ReadFrom(r io.Reader) (int64, error) {
rf.called = true
return io.Copy(&rf.baseRespWriter, r)
}

func (rf *readFromRespWriter) CalledReadFrom() bool { return rf.called }

func TestResponseWriterWrapperReadFrom(t *testing.T) {
tests := map[string]struct {
responseWriter responseWriterSpy
wantReadFrom bool
}{
"no ReadFrom": {
responseWriter: &baseRespWriter{},
wantReadFrom: false,
},
"has ReadFrom": {
responseWriter: &readFromRespWriter{},
wantReadFrom: true,
},
}
for name, tt := range tests {
t.Run(name, func(t *testing.T) {
// what we expect middlewares to do:
type myWrapper struct {
*ResponseWriterWrapper
}

wrapped := myWrapper{
ResponseWriterWrapper: &ResponseWriterWrapper{ResponseWriter: tt.responseWriter},
}

const srcData = "boo!"
// hides everything but Read, since strings.Reader implements WriteTo it would
// take precedence over our ReadFrom.
src := struct{ io.Reader }{strings.NewReader(srcData)}

fmt.Println(name)
if _, err := io.Copy(wrapped, src); err != nil {
t.Errorf("Copy() err = %v", err)
}

if got := tt.responseWriter.Written(); got != srcData {
t.Errorf("data = %q, want %q", got, srcData)
}

if tt.responseWriter.CalledReadFrom() != tt.wantReadFrom {
if tt.wantReadFrom {
t.Errorf("ReadFrom() should have been called")
} else {
t.Errorf("ReadFrom() should not have been called")
}
}
})
}
}

func TestResponseRecorderReadFrom(t *testing.T) {
tests := map[string]struct {
responseWriter responseWriterSpy
shouldBuffer bool
wantReadFrom bool
}{
"buffered plain": {
responseWriter: &baseRespWriter{},
shouldBuffer: true,
wantReadFrom: false,
},
"streamed plain": {
responseWriter: &baseRespWriter{},
shouldBuffer: false,
wantReadFrom: false,
},
"buffered ReadFrom": {
responseWriter: &readFromRespWriter{},
shouldBuffer: true,
wantReadFrom: false,
},
"streamed ReadFrom": {
responseWriter: &readFromRespWriter{},
shouldBuffer: false,
wantReadFrom: true,
},
}
for name, tt := range tests {
t.Run(name, func(t *testing.T) {
var buf bytes.Buffer

rr := NewResponseRecorder(tt.responseWriter, &buf, func(status int, header http.Header) bool {
return tt.shouldBuffer
})

const srcData = "boo!"
// hides everything but Read, since strings.Reader implements WriteTo it would
// take precedence over our ReadFrom.
src := struct{ io.Reader }{strings.NewReader(srcData)}

if _, err := io.Copy(rr, src); err != nil {
t.Errorf("Copy() err = %v", err)
}

wantStreamed := srcData
wantBuffered := ""
if tt.shouldBuffer {
wantStreamed = ""
wantBuffered = srcData
}

if got := tt.responseWriter.Written(); got != wantStreamed {
t.Errorf("streamed data = %q, want %q", got, wantStreamed)
}
if got := buf.String(); got != wantBuffered {
t.Errorf("buffered data = %q, want %q", got, wantBuffered)
}

if tt.responseWriter.CalledReadFrom() != tt.wantReadFrom {
if tt.wantReadFrom {
t.Errorf("ReadFrom() should have been called")
} else {
t.Errorf("ReadFrom() should not have been called")
}
}
})
}
}

0 comments on commit dd9813c

Please sign in to comment.