diff --git a/gzhttp/compress.go b/gzhttp/compress.go index a7950b39ad..265e71c062 100644 --- a/gzhttp/compress.go +++ b/gzhttp/compress.go @@ -302,7 +302,15 @@ func (w *GzipResponseWriter) startPlain() error { } // WriteHeader just saves the response code until close or GZIP effective writes. +// In the specific case of 1xx status codes, WriteHeader is directly calling the wrapped ResponseWriter. func (w *GzipResponseWriter) WriteHeader(code int) { + // Handle informational headers + // This is gated to not forward 1xx responses on builds prior to go1.20. + if shouldWrite1xxResponses() && code >= 100 && code <= 199 { + w.ResponseWriter.WriteHeader(code) + return + } + if w.code == 0 { w.code = code } diff --git a/gzhttp/compress_go119.go b/gzhttp/compress_go119.go new file mode 100644 index 0000000000..97fc25acbc --- /dev/null +++ b/gzhttp/compress_go119.go @@ -0,0 +1,9 @@ +//go:build !go1.20 +// +build !go1.20 + +package gzhttp + +// shouldWrite1xxResponses indicates whether the current build supports writes of 1xx status codes. +func shouldWrite1xxResponses() bool { + return false +} diff --git a/gzhttp/compress_go120.go b/gzhttp/compress_go120.go new file mode 100644 index 0000000000..2b65f67c79 --- /dev/null +++ b/gzhttp/compress_go120.go @@ -0,0 +1,9 @@ +//go:build go1.20 +// +build go1.20 + +package gzhttp + +// shouldWrite1xxResponses indicates whether the current build supports writes of 1xx status codes. +func shouldWrite1xxResponses() bool { + return true +} diff --git a/gzhttp/compress_test.go b/gzhttp/compress_test.go index 8d595a9c0d..fc19723892 100644 --- a/gzhttp/compress_test.go +++ b/gzhttp/compress_test.go @@ -2,12 +2,15 @@ package gzhttp import ( "bytes" + "context" "fmt" "io" "math/rand" "net" "net/http" "net/http/httptest" + "net/http/httptrace" + "net/textproto" "net/url" "os" "strconv" @@ -1796,3 +1799,87 @@ func TestGzipHandlerNilContentType(t *testing.T) { assertEqual(t, "", res.Header().Get("Content-Type")) } + +// This test is an adapted version of net/http/httputil.Test1xxResponses test. +func Test1xxResponses(t *testing.T) { + // do not test 1xx responses on builds prior to go1.20. + if !shouldWrite1xxResponses() { + return + } + + wrapper, _ := NewWrapper() + handler := wrapper(http.HandlerFunc( + func(w http.ResponseWriter, r *http.Request) { + h := w.Header() + h.Add("Link", "; rel=preload; as=style") + h.Add("Link", "; rel=preload; as=script") + w.WriteHeader(http.StatusEarlyHints) + + h.Add("Link", "; rel=preload; as=script") + w.WriteHeader(http.StatusProcessing) + + w.Write(testBody) + }, + )) + + frontend := httptest.NewServer(handler) + defer frontend.Close() + frontendClient := frontend.Client() + + checkLinkHeaders := func(t *testing.T, expected, got []string) { + t.Helper() + + if len(expected) != len(got) { + t.Errorf("Expected %d link headers; got %d", len(expected), len(got)) + } + + for i := range expected { + if i >= len(got) { + t.Errorf("Expected %q link header; got nothing", expected[i]) + + continue + } + + if expected[i] != got[i] { + t.Errorf("Expected %q link header; got %q", expected[i], got[i]) + } + } + } + + var respCounter uint8 + trace := &httptrace.ClientTrace{ + Got1xxResponse: func(code int, header textproto.MIMEHeader) error { + switch code { + case http.StatusEarlyHints: + checkLinkHeaders(t, []string{"; rel=preload; as=style", "; rel=preload; as=script"}, header["Link"]) + case http.StatusProcessing: + checkLinkHeaders(t, []string{"; rel=preload; as=style", "; rel=preload; as=script", "; rel=preload; as=script"}, header["Link"]) + default: + t.Error("Unexpected 1xx response") + } + + respCounter++ + + return nil + }, + } + req, _ := http.NewRequestWithContext(httptrace.WithClientTrace(context.Background(), trace), "GET", frontend.URL, nil) + req.Header.Set("Accept-Encoding", "gzip") + + res, err := frontendClient.Do(req) + if err != nil { + t.Fatalf("Get: %v", err) + } + + defer res.Body.Close() + + if respCounter != 2 { + t.Errorf("Expected 2 1xx responses; got %d", respCounter) + } + checkLinkHeaders(t, []string{"; rel=preload; as=style", "; rel=preload; as=script", "; rel=preload; as=script"}, res.Header["Link"]) + + assertEqual(t, "gzip", res.Header.Get("Content-Encoding")) + + body, _ := io.ReadAll(res.Body) + assertEqual(t, gzipStrLevel(testBody, gzip.DefaultCompression), body) +}