Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

gzhttp: Always delete HeaderNoCompression #683

Merged
merged 2 commits into from
Oct 26, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
78 changes: 72 additions & 6 deletions gzhttp/compress.go
Original file line number Diff line number Diff line change
Expand Up @@ -100,12 +100,13 @@ func (w *GzipResponseWriter) Write(b []byte) (int, error) {
}
w.buf = append(w.buf, b[:toAdd]...)
remain := b[toAdd:]
hdr := w.Header()

// Only continue if they didn't already choose an encoding or a known unhandled content length or type.
if len(w.Header()[HeaderNoCompression]) == 0 && w.Header().Get(contentEncoding) == "" && w.Header().Get(contentRange) == "" {
if len(hdr[HeaderNoCompression]) == 0 && hdr.Get(contentEncoding) == "" && hdr.Get(contentRange) == "" {
// Check more expensive parts now.
cl, _ := atoi(w.Header().Get(contentLength))
ct := w.Header().Get(contentType)
cl, _ := atoi(hdr.Get(contentLength))
ct := hdr.Get(contentType)
if cl == 0 || cl >= w.minSize && (ct == "" || w.contentTypeFilter(ct)) {
// If the current buffer is less than minSize and a Content-Length isn't set, then wait until we have more data.
if len(w.buf) < w.minSize && cl == 0 {
Expand All @@ -121,8 +122,8 @@ func (w *GzipResponseWriter) Write(b []byte) (int, error) {

// Handles the intended case of setting a nil Content-Type (as for http/server or http/fs)
// Set the header only if the key does not exist
if _, ok := w.Header()[contentType]; w.setContentType && !ok {
w.Header().Set(contentType, ct)
if _, ok := hdr[contentType]; w.setContentType && !ok {
hdr.Set(contentType, ct)
}

// If the Content-Type is acceptable to GZIP, initialize the GZIP writer.
Expand Down Expand Up @@ -388,7 +389,8 @@ func NewWrapper(opts ...option) (func(http.Handler) http.HandlerFunc, error) {
h.ServeHTTP(gw, r)
}
} else {
h.ServeHTTP(w, r)
h.ServeHTTP(newNoCompressResponseWriter(w), r)
w.Header().Del(HeaderNoCompression)
}
}
}, nil
Expand Down Expand Up @@ -743,3 +745,67 @@ func atoi(s string) (int, bool) {
i64, err := strconv.ParseInt(s, 10, 0)
return int(i64), err == nil
}

// newNoCompressResponseWriter will return a response writer that
// cleans up compression artifacts.
// Depending on whether http.Hijacker is supported the returned will as well.
func newNoCompressResponseWriter(w http.ResponseWriter) http.ResponseWriter {
n := &noCompressResponseWriter{hw: w}
if hj, ok := w.(http.Hijacker); ok {
x := struct {
http.ResponseWriter
http.Hijacker
http.Flusher
}{
ResponseWriter: n,
Hijacker: hj,
Flusher: n,
}
return x
}

return n
}

// noCompressResponseWriter filters out HeaderNoCompression.
type noCompressResponseWriter struct {
hw http.ResponseWriter
hdrCleaned bool
}

func (n *noCompressResponseWriter) CloseNotify() <-chan bool {
if cn, ok := n.hw.(http.CloseNotifier); ok {
return cn.CloseNotify()
}
return nil
}

func (n *noCompressResponseWriter) Flush() {
if !n.hdrCleaned {
n.hw.Header().Del(HeaderNoCompression)
n.hdrCleaned = true
}
if f, ok := n.hw.(http.Flusher); ok {
f.Flush()
}
}

func (n *noCompressResponseWriter) Header() http.Header {
return n.hw.Header()
}

func (n *noCompressResponseWriter) Write(bytes []byte) (int, error) {
if !n.hdrCleaned {
n.hw.Header().Del(HeaderNoCompression)
n.hdrCleaned = true
}
return n.hw.Write(bytes)
}

func (n *noCompressResponseWriter) WriteHeader(statusCode int) {
if !n.hdrCleaned {
n.hw.Header().Del(HeaderNoCompression)
n.hdrCleaned = true
}
n.hw.WriteHeader(statusCode)
}
66 changes: 65 additions & 1 deletion gzhttp/compress_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -748,9 +748,9 @@ func TestContentTypes(t *testing.T) {
})
t.Run("disable-"+tt.name, func(t *testing.T) {
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
w.Header().Set("Content-Type", tt.contentType)
w.Header().Set(HeaderNoCompression, "plz")
w.WriteHeader(http.StatusOK)
w.Write(testBody)
})

Expand All @@ -765,6 +765,70 @@ func TestContentTypes(t *testing.T) {

assertEqual(t, 200, res.StatusCode)
assertNotEqual(t, "gzip", res.Header.Get("Content-Encoding"))
_, ok := res.Header[HeaderNoCompression]
assertEqual(t, false, ok)
})
t.Run("head-req"+tt.name, func(t *testing.T) {
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", tt.contentType)
w.Header().Set(HeaderNoCompression, "plz")
w.WriteHeader(http.StatusOK)
})

wrapper, err := NewWrapper(ContentTypes(tt.acceptedContentTypes))
assertNil(t, err)

req, _ := http.NewRequest("HEAD", "/whatever", nil)
req.Header.Set("Accept-Encoding", "gzip")
resp := httptest.NewRecorder()
wrapper(handler).ServeHTTP(resp, req)
res := resp.Result()

assertEqual(t, 200, res.StatusCode)
assertNotEqual(t, "gzip", res.Header.Get("Content-Encoding"))
_, ok := res.Header[HeaderNoCompression]
assertEqual(t, false, ok)
})
t.Run("head-req-no-ok"+tt.name, func(t *testing.T) {
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", tt.contentType)
w.Header().Set(HeaderNoCompression, "plz")
})

wrapper, err := NewWrapper(ContentTypes(tt.acceptedContentTypes))
assertNil(t, err)

req, _ := http.NewRequest("HEAD", "/whatever", nil)
req.Header.Set("Accept-Encoding", "gzip")
resp := httptest.NewRecorder()
wrapper(handler).ServeHTTP(resp, req)
res := resp.Result()

assertEqual(t, 200, res.StatusCode)
assertNotEqual(t, "gzip", res.Header.Get("Content-Encoding"))
_, ok := res.Header[HeaderNoCompression]
assertEqual(t, false, ok)
})
t.Run("req-no-ok-write"+tt.name, func(t *testing.T) {
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", tt.contentType)
w.Header().Set(HeaderNoCompression, "plz")
w.Write(testBody)
})

wrapper, err := NewWrapper(ContentTypes(tt.acceptedContentTypes))
assertNil(t, err)

req, _ := http.NewRequest("GET", "/whatever", nil)
req.Header.Set("Accept-Encoding", "")
resp := httptest.NewRecorder()
wrapper(handler).ServeHTTP(resp, req)
res := resp.Result()

assertEqual(t, 200, res.StatusCode)
assertNotEqual(t, "gzip", res.Header.Get("Content-Encoding"))
_, ok := res.Header[HeaderNoCompression]
assertEqual(t, false, ok)
})
}
}
Expand Down