From 55df21feb695816be8d7076bb7d3f3e653108c5d Mon Sep 17 00:00:00 2001 From: Muir Manders Date: Wed, 19 Aug 2020 21:27:44 -0700 Subject: [PATCH] Use httpsnoop to wrap ResponseWriter. (#193) Wrapping http.ResponseWriter is fraught with danger. Our compress handler made sure to implement all the optional ResponseWriter interfaces, but that made it implement them even if the underlying writer did not. For example, if the underlying ResponseWriter was _not_ an http.Hijacker, the compress writer nonetheless appeared to implement http.Hijacker, but would panic if you called Hijack(). On the other hand, the logging handler checked for certain combinations of optional interfaces and only implemented them as appropriate. However, it didn't check for all optional interfaces or all combinations, so most optional interfaces would still get lost. Fix both problems by using httpsnoop to do the wrapping. It uses code generation to ensure correctness, and it handles std lib changes like the http.Pusher addition in Go 1.8. Fixes #169. --- compress.go | 65 ++++++++++++++++++------------------------- compress_test.go | 38 ++++++++++++++++++++----- go.mod | 2 ++ go.sum | 2 ++ handlers.go | 35 +++-------------------- handlers_go18.go | 29 ------------------- handlers_go18_test.go | 13 +++++++-- handlers_pre18.go | 7 ----- logging.go | 39 ++++++++++---------------- logging_test.go | 13 +++++---- 10 files changed, 98 insertions(+), 145 deletions(-) create mode 100644 go.sum delete mode 100644 handlers_go18.go delete mode 100644 handlers_pre18.go diff --git a/compress.go b/compress.go index a70c044..64e825a 100644 --- a/compress.go +++ b/compress.go @@ -10,35 +10,30 @@ import ( "io" "net/http" "strings" + + "github.com/felixge/httpsnoop" ) const acceptEncoding string = "Accept-Encoding" type compressResponseWriter struct { - io.Writer - http.ResponseWriter - http.Hijacker - http.Flusher - http.CloseNotifier -} - -func (w *compressResponseWriter) WriteHeader(c int) { - w.ResponseWriter.Header().Del("Content-Length") - w.ResponseWriter.WriteHeader(c) + compressor io.Writer + w http.ResponseWriter } -func (w *compressResponseWriter) Header() http.Header { - return w.ResponseWriter.Header() +func (cw *compressResponseWriter) WriteHeader(c int) { + cw.w.Header().Del("Content-Length") + cw.w.WriteHeader(c) } -func (w *compressResponseWriter) Write(b []byte) (int, error) { - h := w.ResponseWriter.Header() +func (cw *compressResponseWriter) Write(b []byte) (int, error) { + h := cw.w.Header() if h.Get("Content-Type") == "" { h.Set("Content-Type", http.DetectContentType(b)) } h.Del("Content-Length") - return w.Writer.Write(b) + return cw.compressor.Write(b) } type flusher interface { @@ -47,12 +42,12 @@ type flusher interface { func (w *compressResponseWriter) Flush() { // Flush compressed data if compressor supports it. - if f, ok := w.Writer.(flusher); ok { + if f, ok := w.compressor.(flusher); ok { f.Flush() } // Flush HTTP response. - if w.Flusher != nil { - w.Flusher.Flush() + if f, ok := w.w.(http.Flusher); ok { + f.Flush() } } @@ -119,28 +114,22 @@ func CompressHandlerLevel(h http.Handler, level int) http.Handler { w.Header().Set("Content-Encoding", encoding) r.Header.Del(acceptEncoding) - hijacker, ok := w.(http.Hijacker) - if !ok { /* w is not Hijacker... oh well... */ - hijacker = nil + cw := &compressResponseWriter{ + w: w, + compressor: encWriter, } - flusher, ok := w.(http.Flusher) - if !ok { - flusher = nil - } - - closeNotifier, ok := w.(http.CloseNotifier) - if !ok { - closeNotifier = nil - } - - w = &compressResponseWriter{ - Writer: encWriter, - ResponseWriter: w, - Hijacker: hijacker, - Flusher: flusher, - CloseNotifier: closeNotifier, - } + w = httpsnoop.Wrap(w, httpsnoop.Hooks{ + Write: func(httpsnoop.WriteFunc) httpsnoop.WriteFunc { + return cw.Write + }, + WriteHeader: func(httpsnoop.WriteHeaderFunc) httpsnoop.WriteHeaderFunc { + return cw.WriteHeader + }, + Flush: func(httpsnoop.FlushFunc) httpsnoop.FlushFunc { + return cw.Flush + }, + }) h.ServeHTTP(w, r) }) diff --git a/compress_test.go b/compress_test.go index dbce929..adc2b8b 100644 --- a/compress_test.go +++ b/compress_test.go @@ -29,7 +29,6 @@ func compressedRequest(w *httptest.ResponseRecorder, compression string) { acceptEncoding: []string{compression}, }, }) - } func TestCompressHandlerNoCompression(t *testing.T) { @@ -165,6 +164,7 @@ type fullyFeaturedResponseWriter struct{} func (fullyFeaturedResponseWriter) Header() http.Header { return http.Header{} } + func (fullyFeaturedResponseWriter) Write([]byte) (int, error) { return 0, nil } @@ -193,9 +193,6 @@ func TestCompressHandlerPreserveInterfaces(t *testing.T) { ) var h http.Handler = http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { comp := r.Header.Get(acceptEncoding) - if _, ok := rw.(*compressResponseWriter); !ok { - t.Fatalf("ResponseWriter wasn't wrapped by compressResponseWriter, got %T type", rw) - } if _, ok := rw.(http.Flusher); !ok { t.Errorf("ResponseWriter lost http.Flusher interface for %q", comp) } @@ -207,9 +204,7 @@ func TestCompressHandlerPreserveInterfaces(t *testing.T) { } }) h = CompressHandler(h) - var ( - rw fullyFeaturedResponseWriter - ) + var rw fullyFeaturedResponseWriter r, err := http.NewRequest("GET", "/", nil) if err != nil { t.Fatalf("Failed to create test request: %v", err) @@ -220,3 +215,32 @@ func TestCompressHandlerPreserveInterfaces(t *testing.T) { r.Header.Set(acceptEncoding, "deflate") h.ServeHTTP(rw, r) } + +type paltryResponseWriter struct{} + +func (paltryResponseWriter) Header() http.Header { + return http.Header{} +} + +func (paltryResponseWriter) Write([]byte) (int, error) { + return 0, nil +} +func (paltryResponseWriter) WriteHeader(int) {} + +func TestCompressHandlerDoesntInventInterfaces(t *testing.T) { + var h http.Handler = http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { + if _, ok := rw.(http.Hijacker); ok { + t.Error("ResponseWriter shouldn't implement http.Hijacker") + } + }) + + h = CompressHandler(h) + + var rw paltryResponseWriter + r, err := http.NewRequest("GET", "/", nil) + if err != nil { + t.Fatalf("Failed to create test request: %v", err) + } + r.Header.Set(acceptEncoding, "gzip") + h.ServeHTTP(rw, r) +} diff --git a/go.mod b/go.mod index 41136b4..58e6a85 100644 --- a/go.mod +++ b/go.mod @@ -1,3 +1,5 @@ module github.com/gorilla/handlers go 1.14 + +require github.com/felixge/httpsnoop v1.0.1 diff --git a/go.sum b/go.sum new file mode 100644 index 0000000..8c26458 --- /dev/null +++ b/go.sum @@ -0,0 +1,2 @@ +github.com/felixge/httpsnoop v1.0.1 h1:lvB5Jl89CsZtGIWuTcDM1E/vkVs49/Ml7JJe07l8SPQ= +github.com/felixge/httpsnoop v1.0.1/go.mod h1:m8KPJKqk1gH5J9DgRY2ASl2lWCfGKXixSwevea8zH2U= diff --git a/handlers.go b/handlers.go index d03f2bf..0509482 100644 --- a/handlers.go +++ b/handlers.go @@ -51,10 +51,6 @@ type responseLogger struct { size int } -func (l *responseLogger) Header() http.Header { - return l.w.Header() -} - func (l *responseLogger) Write(b []byte) (int, error) { size, err := l.w.Write(b) l.size += size @@ -74,39 +70,16 @@ func (l *responseLogger) Size() int { return l.size } -func (l *responseLogger) Flush() { - f, ok := l.w.(http.Flusher) - if ok { - f.Flush() - } -} - -type hijackLogger struct { - responseLogger -} - -func (l *hijackLogger) Hijack() (net.Conn, *bufio.ReadWriter, error) { - h := l.responseLogger.w.(http.Hijacker) - conn, rw, err := h.Hijack() - if err == nil && l.responseLogger.status == 0 { +func (l *responseLogger) Hijack() (net.Conn, *bufio.ReadWriter, error) { + conn, rw, err := l.w.(http.Hijacker).Hijack() + if err == nil && l.status == 0 { // The status will be StatusSwitchingProtocols if there was no error and // WriteHeader has not been called yet - l.responseLogger.status = http.StatusSwitchingProtocols + l.status = http.StatusSwitchingProtocols } return conn, rw, err } -type closeNotifyWriter struct { - loggingResponseWriter - http.CloseNotifier -} - -type hijackCloseNotifier struct { - loggingResponseWriter - http.Hijacker - http.CloseNotifier -} - // isContentType validates the Content-Type header matches the supplied // contentType. That is, its type and subtype match. func isContentType(h http.Header, contentType string) bool { diff --git a/handlers_go18.go b/handlers_go18.go deleted file mode 100644 index 40f6914..0000000 --- a/handlers_go18.go +++ /dev/null @@ -1,29 +0,0 @@ -// +build go1.8 - -package handlers - -import ( - "fmt" - "net/http" -) - -type loggingResponseWriter interface { - commonLoggingResponseWriter - http.Pusher -} - -func (l *responseLogger) Push(target string, opts *http.PushOptions) error { - p, ok := l.w.(http.Pusher) - if !ok { - return fmt.Errorf("responseLogger does not implement http.Pusher") - } - return p.Push(target, opts) -} - -func (c *compressResponseWriter) Push(target string, opts *http.PushOptions) error { - p, ok := c.ResponseWriter.(http.Pusher) - if !ok { - return fmt.Errorf("compressResponseWriter does not implement http.Pusher") - } - return p.Push(target, opts) -} diff --git a/handlers_go18_test.go b/handlers_go18_test.go index c8cfa72..d8e6321 100644 --- a/handlers_go18_test.go +++ b/handlers_go18_test.go @@ -9,6 +9,15 @@ import ( "testing" ) +// *httptest.ResponseRecorder doesn't implement Pusher, so wrap it. +type pushRecorder struct { + *httptest.ResponseRecorder +} + +func (pr pushRecorder) Push(target string, opts *http.PushOptions) error { + return nil +} + func TestLoggingHandlerWithPush(t *testing.T) { handler := http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { if _, ok := w.(http.Pusher); !ok { @@ -18,7 +27,7 @@ func TestLoggingHandlerWithPush(t *testing.T) { }) logger := LoggingHandler(ioutil.Discard, handler) - logger.ServeHTTP(httptest.NewRecorder(), newRequest("GET", "/")) + logger.ServeHTTP(pushRecorder{httptest.NewRecorder()}, newRequest("GET", "/")) } func TestCombinedLoggingHandlerWithPush(t *testing.T) { @@ -30,5 +39,5 @@ func TestCombinedLoggingHandlerWithPush(t *testing.T) { }) logger := CombinedLoggingHandler(ioutil.Discard, handler) - logger.ServeHTTP(httptest.NewRecorder(), newRequest("GET", "/")) + logger.ServeHTTP(pushRecorder{httptest.NewRecorder()}, newRequest("GET", "/")) } diff --git a/handlers_pre18.go b/handlers_pre18.go deleted file mode 100644 index 197836a..0000000 --- a/handlers_pre18.go +++ /dev/null @@ -1,7 +0,0 @@ -// +build !go1.8 - -package handlers - -type loggingResponseWriter interface { - commonLoggingResponseWriter -} diff --git a/logging.go b/logging.go index 88c25e7..228465e 100644 --- a/logging.go +++ b/logging.go @@ -12,6 +12,8 @@ import ( "strconv" "time" "unicode/utf8" + + "github.com/felixge/httpsnoop" ) // Logging @@ -39,10 +41,10 @@ type loggingHandler struct { func (h loggingHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) { t := time.Now() - logger := makeLogger(w) + logger, w := makeLogger(w) url := *req.URL - h.handler.ServeHTTP(logger, req) + h.handler.ServeHTTP(w, req) if req.MultipartForm != nil { req.MultipartForm.RemoveAll() } @@ -58,27 +60,16 @@ func (h loggingHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) { h.formatter(h.writer, params) } -func makeLogger(w http.ResponseWriter) loggingResponseWriter { - var logger loggingResponseWriter = &responseLogger{w: w, status: http.StatusOK} - if _, ok := w.(http.Hijacker); ok { - logger = &hijackLogger{responseLogger{w: w, status: http.StatusOK}} - } - h, ok1 := logger.(http.Hijacker) - c, ok2 := w.(http.CloseNotifier) - if ok1 && ok2 { - return hijackCloseNotifier{logger, h, c} - } - if ok2 { - return &closeNotifyWriter{logger, c} - } - return logger -} - -type commonLoggingResponseWriter interface { - http.ResponseWriter - http.Flusher - Status() int - Size() int +func makeLogger(w http.ResponseWriter) (*responseLogger, http.ResponseWriter) { + logger := &responseLogger{w: w, status: http.StatusOK} + return logger, httpsnoop.Wrap(w, httpsnoop.Hooks{ + Write: func(httpsnoop.WriteFunc) httpsnoop.WriteFunc { + return logger.Write + }, + WriteHeader: func(httpsnoop.WriteHeaderFunc) httpsnoop.WriteHeaderFunc { + return logger.WriteHeader + }, + }) } const lowerhex = "0123456789abcdef" @@ -145,7 +136,6 @@ func appendQuoted(buf []byte, s string) []byte { } } return buf - } // buildCommonLogLine builds a log entry for req in Apache Common Log Format. @@ -160,7 +150,6 @@ func buildCommonLogLine(req *http.Request, url url.URL, ts time.Time, status int } host, _, err := net.SplitHostPort(req.RemoteAddr) - if err != nil { host = req.RemoteAddr } diff --git a/logging_test.go b/logging_test.go index 00f027e..13b8369 100644 --- a/logging_test.go +++ b/logging_test.go @@ -23,24 +23,24 @@ import ( func TestMakeLogger(t *testing.T) { rec := httptest.NewRecorder() - logger := makeLogger(rec) + logger, w := makeLogger(rec) // initial status if logger.Status() != http.StatusOK { t.Fatalf("wrong status, got %d want %d", logger.Status(), http.StatusOK) } // WriteHeader - logger.WriteHeader(http.StatusInternalServerError) + w.WriteHeader(http.StatusInternalServerError) if logger.Status() != http.StatusInternalServerError { t.Fatalf("wrong status, got %d want %d", logger.Status(), http.StatusInternalServerError) } // Write - logger.Write([]byte(ok)) + w.Write([]byte(ok)) if logger.Size() != len(ok) { t.Fatalf("wrong size, got %d want %d", logger.Size(), len(ok)) } // Header - logger.Header().Set("key", "value") - if val := logger.Header().Get("key"); val != "value" { + w.Header().Set("key", "value") + if val := w.Header().Get("key"); val != "value" { t.Fatalf("wrong header, got %s want %s", val, "value") } } @@ -202,6 +202,7 @@ func TestLogFormatterWriteLog_Scenario4(t *testing.T) { expected := "192.168.100.5 - - [26/May/1983:03:30:45 +0200] \"GET /test?abc=hello%20world&a=b%3F HTTP/1.1\" 200 100\n" LoggingScenario4(t, formatter, expected) } + func TestLogFormatterCombinedLog_Scenario5(t *testing.T) { formatter := writeCombinedLog expected := "::1 - kamil [26/May/1983:03:30:45 +0200] \"GET / HTTP/1.1\" 200 100 \"http://example.com\" " + @@ -289,6 +290,7 @@ func LoggingScenario3(t *testing.T, formatter LogFormatter, expected string) { t.Fatalf("wrong log, got %q want %q", log, expected) } } + func LoggingScenario4(t *testing.T, formatter LogFormatter, expected string) { loc, err := time.LoadLocation("Europe/Warsaw") if err != nil { @@ -357,7 +359,6 @@ func constructTypicalRequestOk() *http.Request { // CONNECT request over http/2.0 func constructConnectRequest() *http.Request { - req := &http.Request{ Method: "CONNECT", Host: "www.example.com:443",