Skip to content

Commit

Permalink
Implement http.Flusher interface on Log middleware (#257)
Browse files Browse the repository at this point in the history
This adds `Flush()` to the `middleware.badResponseLoggingWriter`, making it implement `http.Flusher` if the wrapped `http.ResponseWriter` does.
  • Loading branch information
leizor authored Sep 15, 2022
1 parent e98fcdf commit 7b5f6f3
Show file tree
Hide file tree
Showing 3 changed files with 82 additions and 14 deletions.
2 changes: 1 addition & 1 deletion middleware/logging.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ func (l Log) Wrap(next http.Handler) http.Handler {
wrapped := newBadResponseLoggingWriter(w, &buf)
next.ServeHTTP(wrapped, r)

statusCode, writeErr := wrapped.statusCode, wrapped.writeError
statusCode, writeErr := wrapped.getStatusCode(), wrapped.getWriteError()

if writeErr != nil {
if errors.Is(writeErr, context.Canceled) {
Expand Down
52 changes: 41 additions & 11 deletions middleware/response.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,15 @@ const (
maxResponseBodyInLogs = 4096 // At most 4k bytes from response bodies in our logs.
)

// badResponseLoggingWriter writes the body of "bad" responses (i.e. 5xx
type badResponseLoggingWriter interface {
http.ResponseWriter
getStatusCode() int
getWriteError() error
}

// nonFlushingBadResponseLoggingWriter writes the body of "bad" responses (i.e. 5xx
// responses) to a buffer.
type badResponseLoggingWriter struct {
type nonFlushingBadResponseLoggingWriter struct {
rw http.ResponseWriter
buffer io.Writer
logBody bool
Expand All @@ -23,27 +29,39 @@ type badResponseLoggingWriter struct {
writeError error // The error returned when downstream Write() fails.
}

// newBadResponseLoggingWriter makes a new badResponseLoggingWriter.
func newBadResponseLoggingWriter(rw http.ResponseWriter, buffer io.Writer) *badResponseLoggingWriter {
return &badResponseLoggingWriter{
// flushingBadResponseLoggingWriter is a badResponseLoggingWriter that
// implements http.Flusher.
type flushingBadResponseLoggingWriter struct {
nonFlushingBadResponseLoggingWriter
f http.Flusher
}

func newBadResponseLoggingWriter(rw http.ResponseWriter, buffer io.Writer) badResponseLoggingWriter {
b := nonFlushingBadResponseLoggingWriter{
rw: rw,
buffer: buffer,
logBody: false,
bodyBytesLeft: maxResponseBodyInLogs,
statusCode: http.StatusOK,
}

if f, ok := rw.(http.Flusher); ok {
return &flushingBadResponseLoggingWriter{b, f}
}

return &b
}

// Header returns the header map that will be sent by WriteHeader.
// Implements ResponseWriter.
func (b *badResponseLoggingWriter) Header() http.Header {
func (b *nonFlushingBadResponseLoggingWriter) Header() http.Header {
return b.rw.Header()
}

// Write writes HTTP response data.
func (b *badResponseLoggingWriter) Write(data []byte) (int, error) {
func (b *nonFlushingBadResponseLoggingWriter) Write(data []byte) (int, error) {
if b.statusCode == 0 {
// WriteHeader has (probably) not been called, so we need to call it with StatusOK to fuflil the interface contract.
// WriteHeader has (probably) not been called, so we need to call it with StatusOK to fulfill the interface contract.
// https://godoc.org/net/http#ResponseWriter
b.WriteHeader(http.StatusOK)
}
Expand All @@ -58,7 +76,7 @@ func (b *badResponseLoggingWriter) Write(data []byte) (int, error) {
}

// WriteHeader writes the HTTP response header.
func (b *badResponseLoggingWriter) WriteHeader(statusCode int) {
func (b *nonFlushingBadResponseLoggingWriter) WriteHeader(statusCode int) {
b.statusCode = statusCode
if statusCode >= 500 {
b.logBody = true
Expand All @@ -67,15 +85,23 @@ func (b *badResponseLoggingWriter) WriteHeader(statusCode int) {
}

// Hijack hijacks the first response writer that is a Hijacker.
func (b *badResponseLoggingWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
func (b *nonFlushingBadResponseLoggingWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
hj, ok := b.rw.(http.Hijacker)
if ok {
return hj.Hijack()
}
return nil, nil, fmt.Errorf("badResponseLoggingWriter: can't cast underlying response writer to Hijacker")
}

func (b *badResponseLoggingWriter) captureResponseBody(data []byte) {
func (b *nonFlushingBadResponseLoggingWriter) getStatusCode() int {
return b.statusCode
}

func (b *nonFlushingBadResponseLoggingWriter) getWriteError() error {
return b.writeError
}

func (b *nonFlushingBadResponseLoggingWriter) captureResponseBody(data []byte) {
if len(data) > b.bodyBytesLeft {
b.buffer.Write(data[:b.bodyBytesLeft])
io.WriteString(b.buffer, "...")
Expand All @@ -86,3 +112,7 @@ func (b *badResponseLoggingWriter) captureResponseBody(data []byte) {
b.bodyBytesLeft -= len(data)
}
}

func (b *flushingBadResponseLoggingWriter) Flush() {
b.f.Flush()
}
42 changes: 40 additions & 2 deletions middleware/response_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,12 +31,50 @@ func TestBadResponseLoggingWriter(t *testing.T) {
default:
http.Error(wrapped, tc.data, tc.statusCode)
}
if wrapped.statusCode != tc.statusCode {
t.Errorf("Wrong status code: have %d want %d", wrapped.statusCode, tc.statusCode)
if wrapped.getStatusCode() != tc.statusCode {
t.Errorf("Wrong status code: have %d want %d", wrapped.getStatusCode(), tc.statusCode)
}
data := string(buf.Bytes())
if data != tc.expected {
t.Errorf("Wrong data: have %q want %q", data, tc.expected)
}
}
}

// nonFlushingResponseWriter implements http.ResponseWriter but does not implement http.Flusher
type nonFlushingResponseWriter struct{}

func (rw *nonFlushingResponseWriter) Header() http.Header {
return nil
}

func (rw *nonFlushingResponseWriter) Write(_ []byte) (int, error) {
return -1, nil
}

func (rw *nonFlushingResponseWriter) WriteHeader(_ int) {
}

func TestBadResponseLoggingWriter_WithAndWithoutFlusher(t *testing.T) {
var buf bytes.Buffer

nf := newBadResponseLoggingWriter(&nonFlushingResponseWriter{}, &buf)

_, ok := nf.(http.Flusher)
if ok {
t.Errorf("Should not be able to cast nf as an http.Flusher")
}

rec := httptest.NewRecorder()
f := newBadResponseLoggingWriter(rec, &buf)

ff, ok := f.(http.Flusher)
if !ok {
t.Errorf("Should be able to cast f as an http.Flusher")
}

ff.Flush()
if !rec.Flushed {
t.Errorf("Flush should have worked but did not")
}
}

0 comments on commit 7b5f6f3

Please sign in to comment.