From 7b5f6f3e74bca9a1843c04ebc286650f22790a9a Mon Sep 17 00:00:00 2001 From: Justin Lei <97976793+leizor@users.noreply.github.com> Date: Thu, 15 Sep 2022 10:11:48 -0700 Subject: [PATCH] Implement `http.Flusher` interface on Log middleware (#257) This adds `Flush()` to the `middleware.badResponseLoggingWriter`, making it implement `http.Flusher` if the wrapped `http.ResponseWriter` does. --- middleware/logging.go | 2 +- middleware/response.go | 52 +++++++++++++++++++++++++++++-------- middleware/response_test.go | 42 ++++++++++++++++++++++++++++-- 3 files changed, 82 insertions(+), 14 deletions(-) diff --git a/middleware/logging.go b/middleware/logging.go index 015cc3b5..780c72be 100644 --- a/middleware/logging.go +++ b/middleware/logging.go @@ -54,7 +54,7 @@ func (l Log) Wrap(next http.Handler) http.Handler { wrapped := newBadResponseLoggingWriter(w, &buf) next.ServeHTTP(wrapped, r) - statusCode, writeErr := wrapped.statusCode, wrapped.writeError + statusCode, writeErr := wrapped.getStatusCode(), wrapped.getWriteError() if writeErr != nil { if errors.Is(writeErr, context.Canceled) { diff --git a/middleware/response.go b/middleware/response.go index 0192c182..61fc4a72 100644 --- a/middleware/response.go +++ b/middleware/response.go @@ -12,9 +12,15 @@ const ( maxResponseBodyInLogs = 4096 // At most 4k bytes from response bodies in our logs. ) -// badResponseLoggingWriter writes the body of "bad" responses (i.e. 5xx +type badResponseLoggingWriter interface { + http.ResponseWriter + getStatusCode() int + getWriteError() error +} + +// nonFlushingBadResponseLoggingWriter writes the body of "bad" responses (i.e. 5xx // responses) to a buffer. -type badResponseLoggingWriter struct { +type nonFlushingBadResponseLoggingWriter struct { rw http.ResponseWriter buffer io.Writer logBody bool @@ -23,27 +29,39 @@ type badResponseLoggingWriter struct { writeError error // The error returned when downstream Write() fails. } -// newBadResponseLoggingWriter makes a new badResponseLoggingWriter. -func newBadResponseLoggingWriter(rw http.ResponseWriter, buffer io.Writer) *badResponseLoggingWriter { - return &badResponseLoggingWriter{ +// flushingBadResponseLoggingWriter is a badResponseLoggingWriter that +// implements http.Flusher. +type flushingBadResponseLoggingWriter struct { + nonFlushingBadResponseLoggingWriter + f http.Flusher +} + +func newBadResponseLoggingWriter(rw http.ResponseWriter, buffer io.Writer) badResponseLoggingWriter { + b := nonFlushingBadResponseLoggingWriter{ rw: rw, buffer: buffer, logBody: false, bodyBytesLeft: maxResponseBodyInLogs, statusCode: http.StatusOK, } + + if f, ok := rw.(http.Flusher); ok { + return &flushingBadResponseLoggingWriter{b, f} + } + + return &b } // Header returns the header map that will be sent by WriteHeader. // Implements ResponseWriter. -func (b *badResponseLoggingWriter) Header() http.Header { +func (b *nonFlushingBadResponseLoggingWriter) Header() http.Header { return b.rw.Header() } // Write writes HTTP response data. -func (b *badResponseLoggingWriter) Write(data []byte) (int, error) { +func (b *nonFlushingBadResponseLoggingWriter) Write(data []byte) (int, error) { if b.statusCode == 0 { - // WriteHeader has (probably) not been called, so we need to call it with StatusOK to fuflil the interface contract. + // WriteHeader has (probably) not been called, so we need to call it with StatusOK to fulfill the interface contract. // https://godoc.org/net/http#ResponseWriter b.WriteHeader(http.StatusOK) } @@ -58,7 +76,7 @@ func (b *badResponseLoggingWriter) Write(data []byte) (int, error) { } // WriteHeader writes the HTTP response header. -func (b *badResponseLoggingWriter) WriteHeader(statusCode int) { +func (b *nonFlushingBadResponseLoggingWriter) WriteHeader(statusCode int) { b.statusCode = statusCode if statusCode >= 500 { b.logBody = true @@ -67,7 +85,7 @@ func (b *badResponseLoggingWriter) WriteHeader(statusCode int) { } // Hijack hijacks the first response writer that is a Hijacker. -func (b *badResponseLoggingWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) { +func (b *nonFlushingBadResponseLoggingWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) { hj, ok := b.rw.(http.Hijacker) if ok { return hj.Hijack() @@ -75,7 +93,15 @@ func (b *badResponseLoggingWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) return nil, nil, fmt.Errorf("badResponseLoggingWriter: can't cast underlying response writer to Hijacker") } -func (b *badResponseLoggingWriter) captureResponseBody(data []byte) { +func (b *nonFlushingBadResponseLoggingWriter) getStatusCode() int { + return b.statusCode +} + +func (b *nonFlushingBadResponseLoggingWriter) getWriteError() error { + return b.writeError +} + +func (b *nonFlushingBadResponseLoggingWriter) captureResponseBody(data []byte) { if len(data) > b.bodyBytesLeft { b.buffer.Write(data[:b.bodyBytesLeft]) io.WriteString(b.buffer, "...") @@ -86,3 +112,7 @@ func (b *badResponseLoggingWriter) captureResponseBody(data []byte) { b.bodyBytesLeft -= len(data) } } + +func (b *flushingBadResponseLoggingWriter) Flush() { + b.f.Flush() +} diff --git a/middleware/response_test.go b/middleware/response_test.go index 13aba657..c5f018a4 100644 --- a/middleware/response_test.go +++ b/middleware/response_test.go @@ -31,8 +31,8 @@ func TestBadResponseLoggingWriter(t *testing.T) { default: http.Error(wrapped, tc.data, tc.statusCode) } - if wrapped.statusCode != tc.statusCode { - t.Errorf("Wrong status code: have %d want %d", wrapped.statusCode, tc.statusCode) + if wrapped.getStatusCode() != tc.statusCode { + t.Errorf("Wrong status code: have %d want %d", wrapped.getStatusCode(), tc.statusCode) } data := string(buf.Bytes()) if data != tc.expected { @@ -40,3 +40,41 @@ func TestBadResponseLoggingWriter(t *testing.T) { } } } + +// nonFlushingResponseWriter implements http.ResponseWriter but does not implement http.Flusher +type nonFlushingResponseWriter struct{} + +func (rw *nonFlushingResponseWriter) Header() http.Header { + return nil +} + +func (rw *nonFlushingResponseWriter) Write(_ []byte) (int, error) { + return -1, nil +} + +func (rw *nonFlushingResponseWriter) WriteHeader(_ int) { +} + +func TestBadResponseLoggingWriter_WithAndWithoutFlusher(t *testing.T) { + var buf bytes.Buffer + + nf := newBadResponseLoggingWriter(&nonFlushingResponseWriter{}, &buf) + + _, ok := nf.(http.Flusher) + if ok { + t.Errorf("Should not be able to cast nf as an http.Flusher") + } + + rec := httptest.NewRecorder() + f := newBadResponseLoggingWriter(rec, &buf) + + ff, ok := f.(http.Flusher) + if !ok { + t.Errorf("Should be able to cast f as an http.Flusher") + } + + ff.Flush() + if !rec.Flushed { + t.Errorf("Flush should have worked but did not") + } +}