Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use httpsnoop to wrap ResponseWriter. #193

Merged
merged 1 commit into from
Aug 20, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
65 changes: 27 additions & 38 deletions compress.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,35 +10,30 @@ import (
"io"
"net/http"
"strings"

"github.com/felixge/httpsnoop"
)

const acceptEncoding string = "Accept-Encoding"

type compressResponseWriter struct {
io.Writer
http.ResponseWriter
http.Hijacker
http.Flusher
http.CloseNotifier
}

func (w *compressResponseWriter) WriteHeader(c int) {
w.ResponseWriter.Header().Del("Content-Length")
w.ResponseWriter.WriteHeader(c)
compressor io.Writer
w http.ResponseWriter
}

func (w *compressResponseWriter) Header() http.Header {
return w.ResponseWriter.Header()
func (cw *compressResponseWriter) WriteHeader(c int) {
cw.w.Header().Del("Content-Length")
cw.w.WriteHeader(c)
}

func (w *compressResponseWriter) Write(b []byte) (int, error) {
h := w.ResponseWriter.Header()
func (cw *compressResponseWriter) Write(b []byte) (int, error) {
h := cw.w.Header()
if h.Get("Content-Type") == "" {
h.Set("Content-Type", http.DetectContentType(b))
}
h.Del("Content-Length")

return w.Writer.Write(b)
return cw.compressor.Write(b)
}

type flusher interface {
Expand All @@ -47,12 +42,12 @@ type flusher interface {

func (w *compressResponseWriter) Flush() {
// Flush compressed data if compressor supports it.
if f, ok := w.Writer.(flusher); ok {
if f, ok := w.compressor.(flusher); ok {
f.Flush()
}
// Flush HTTP response.
if w.Flusher != nil {
w.Flusher.Flush()
if f, ok := w.w.(http.Flusher); ok {
f.Flush()
}
}

Expand Down Expand Up @@ -119,28 +114,22 @@ func CompressHandlerLevel(h http.Handler, level int) http.Handler {
w.Header().Set("Content-Encoding", encoding)
r.Header.Del(acceptEncoding)

hijacker, ok := w.(http.Hijacker)
if !ok { /* w is not Hijacker... oh well... */
hijacker = nil
cw := &compressResponseWriter{
w: w,
compressor: encWriter,
}

flusher, ok := w.(http.Flusher)
if !ok {
flusher = nil
}

closeNotifier, ok := w.(http.CloseNotifier)
if !ok {
closeNotifier = nil
}

w = &compressResponseWriter{
Writer: encWriter,
ResponseWriter: w,
Hijacker: hijacker,
Flusher: flusher,
CloseNotifier: closeNotifier,
}
w = httpsnoop.Wrap(w, httpsnoop.Hooks{
Write: func(httpsnoop.WriteFunc) httpsnoop.WriteFunc {
return cw.Write
},
WriteHeader: func(httpsnoop.WriteHeaderFunc) httpsnoop.WriteHeaderFunc {
return cw.WriteHeader
},
Flush: func(httpsnoop.FlushFunc) httpsnoop.FlushFunc {
return cw.Flush
},
})

h.ServeHTTP(w, r)
})
Expand Down
38 changes: 31 additions & 7 deletions compress_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@ func compressedRequest(w *httptest.ResponseRecorder, compression string) {
acceptEncoding: []string{compression},
},
})

}

func TestCompressHandlerNoCompression(t *testing.T) {
Expand Down Expand Up @@ -165,6 +164,7 @@ type fullyFeaturedResponseWriter struct{}
func (fullyFeaturedResponseWriter) Header() http.Header {
return http.Header{}
}

func (fullyFeaturedResponseWriter) Write([]byte) (int, error) {
return 0, nil
}
Expand Down Expand Up @@ -193,9 +193,6 @@ func TestCompressHandlerPreserveInterfaces(t *testing.T) {
)
var h http.Handler = http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
comp := r.Header.Get(acceptEncoding)
if _, ok := rw.(*compressResponseWriter); !ok {
t.Fatalf("ResponseWriter wasn't wrapped by compressResponseWriter, got %T type", rw)
}
if _, ok := rw.(http.Flusher); !ok {
t.Errorf("ResponseWriter lost http.Flusher interface for %q", comp)
}
Expand All @@ -207,9 +204,7 @@ func TestCompressHandlerPreserveInterfaces(t *testing.T) {
}
})
h = CompressHandler(h)
var (
rw fullyFeaturedResponseWriter
)
var rw fullyFeaturedResponseWriter
r, err := http.NewRequest("GET", "/", nil)
if err != nil {
t.Fatalf("Failed to create test request: %v", err)
Expand All @@ -220,3 +215,32 @@ func TestCompressHandlerPreserveInterfaces(t *testing.T) {
r.Header.Set(acceptEncoding, "deflate")
h.ServeHTTP(rw, r)
}

type paltryResponseWriter struct{}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ha.


func (paltryResponseWriter) Header() http.Header {
return http.Header{}
}

func (paltryResponseWriter) Write([]byte) (int, error) {
return 0, nil
}
func (paltryResponseWriter) WriteHeader(int) {}

func TestCompressHandlerDoesntInventInterfaces(t *testing.T) {
var h http.Handler = http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
if _, ok := rw.(http.Hijacker); ok {
t.Error("ResponseWriter shouldn't implement http.Hijacker")
}
})

h = CompressHandler(h)

var rw paltryResponseWriter
r, err := http.NewRequest("GET", "/", nil)
if err != nil {
t.Fatalf("Failed to create test request: %v", err)
}
r.Header.Set(acceptEncoding, "gzip")
h.ServeHTTP(rw, r)
}
2 changes: 2 additions & 0 deletions go.mod
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
module github.com/gorilla/handlers

go 1.14

require github.com/felixge/httpsnoop v1.0.1
2 changes: 2 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
github.com/felixge/httpsnoop v1.0.1 h1:lvB5Jl89CsZtGIWuTcDM1E/vkVs49/Ml7JJe07l8SPQ=
github.com/felixge/httpsnoop v1.0.1/go.mod h1:m8KPJKqk1gH5J9DgRY2ASl2lWCfGKXixSwevea8zH2U=
35 changes: 4 additions & 31 deletions handlers.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,10 +51,6 @@ type responseLogger struct {
size int
}

func (l *responseLogger) Header() http.Header {
return l.w.Header()
}

func (l *responseLogger) Write(b []byte) (int, error) {
size, err := l.w.Write(b)
l.size += size
Expand All @@ -74,39 +70,16 @@ func (l *responseLogger) Size() int {
return l.size
}

func (l *responseLogger) Flush() {
f, ok := l.w.(http.Flusher)
if ok {
f.Flush()
}
}

type hijackLogger struct {
responseLogger
}

func (l *hijackLogger) Hijack() (net.Conn, *bufio.ReadWriter, error) {
h := l.responseLogger.w.(http.Hijacker)
conn, rw, err := h.Hijack()
if err == nil && l.responseLogger.status == 0 {
func (l *responseLogger) Hijack() (net.Conn, *bufio.ReadWriter, error) {
conn, rw, err := l.w.(http.Hijacker).Hijack()
if err == nil && l.status == 0 {
// The status will be StatusSwitchingProtocols if there was no error and
// WriteHeader has not been called yet
l.responseLogger.status = http.StatusSwitchingProtocols
l.status = http.StatusSwitchingProtocols
}
return conn, rw, err
}

type closeNotifyWriter struct {
loggingResponseWriter
http.CloseNotifier
}

type hijackCloseNotifier struct {
loggingResponseWriter
http.Hijacker
http.CloseNotifier
}

// isContentType validates the Content-Type header matches the supplied
// contentType. That is, its type and subtype match.
func isContentType(h http.Header, contentType string) bool {
Expand Down
29 changes: 0 additions & 29 deletions handlers_go18.go

This file was deleted.

13 changes: 11 additions & 2 deletions handlers_go18_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,15 @@ import (
"testing"
)

// *httptest.ResponseRecorder doesn't implement Pusher, so wrap it.
type pushRecorder struct {
*httptest.ResponseRecorder
}

func (pr pushRecorder) Push(target string, opts *http.PushOptions) error {
return nil
}

func TestLoggingHandlerWithPush(t *testing.T) {
handler := http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
if _, ok := w.(http.Pusher); !ok {
Expand All @@ -18,7 +27,7 @@ func TestLoggingHandlerWithPush(t *testing.T) {
})

logger := LoggingHandler(ioutil.Discard, handler)
logger.ServeHTTP(httptest.NewRecorder(), newRequest("GET", "/"))
logger.ServeHTTP(pushRecorder{httptest.NewRecorder()}, newRequest("GET", "/"))
}

func TestCombinedLoggingHandlerWithPush(t *testing.T) {
Expand All @@ -30,5 +39,5 @@ func TestCombinedLoggingHandlerWithPush(t *testing.T) {
})

logger := CombinedLoggingHandler(ioutil.Discard, handler)
logger.ServeHTTP(httptest.NewRecorder(), newRequest("GET", "/"))
logger.ServeHTTP(pushRecorder{httptest.NewRecorder()}, newRequest("GET", "/"))
}
7 changes: 0 additions & 7 deletions handlers_pre18.go

This file was deleted.

39 changes: 14 additions & 25 deletions logging.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@ import (
"strconv"
"time"
"unicode/utf8"

"github.com/felixge/httpsnoop"
)

// Logging
Expand Down Expand Up @@ -39,10 +41,10 @@ type loggingHandler struct {

func (h loggingHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) {
t := time.Now()
logger := makeLogger(w)
logger, w := makeLogger(w)
url := *req.URL

h.handler.ServeHTTP(logger, req)
h.handler.ServeHTTP(w, req)
if req.MultipartForm != nil {
req.MultipartForm.RemoveAll()
}
Expand All @@ -58,27 +60,16 @@ func (h loggingHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) {
h.formatter(h.writer, params)
}

func makeLogger(w http.ResponseWriter) loggingResponseWriter {
var logger loggingResponseWriter = &responseLogger{w: w, status: http.StatusOK}
if _, ok := w.(http.Hijacker); ok {
logger = &hijackLogger{responseLogger{w: w, status: http.StatusOK}}
}
h, ok1 := logger.(http.Hijacker)
c, ok2 := w.(http.CloseNotifier)
if ok1 && ok2 {
return hijackCloseNotifier{logger, h, c}
}
if ok2 {
return &closeNotifyWriter{logger, c}
}
return logger
}

type commonLoggingResponseWriter interface {
http.ResponseWriter
http.Flusher
Status() int
Size() int
func makeLogger(w http.ResponseWriter) (*responseLogger, http.ResponseWriter) {
logger := &responseLogger{w: w, status: http.StatusOK}
return logger, httpsnoop.Wrap(w, httpsnoop.Hooks{
Write: func(httpsnoop.WriteFunc) httpsnoop.WriteFunc {
return logger.Write
},
WriteHeader: func(httpsnoop.WriteHeaderFunc) httpsnoop.WriteHeaderFunc {
return logger.WriteHeader
},
})
}

const lowerhex = "0123456789abcdef"
Expand Down Expand Up @@ -145,7 +136,6 @@ func appendQuoted(buf []byte, s string) []byte {
}
}
return buf

}

// buildCommonLogLine builds a log entry for req in Apache Common Log Format.
Expand All @@ -160,7 +150,6 @@ func buildCommonLogLine(req *http.Request, url url.URL, ts time.Time, status int
}

host, _, err := net.SplitHostPort(req.RemoteAddr)

if err != nil {
host = req.RemoteAddr
}
Expand Down
Loading