diff --git a/compress.go b/compress.go index 64e825a..1e95f1c 100644 --- a/compress.go +++ b/compress.go @@ -36,6 +36,10 @@ func (cw *compressResponseWriter) Write(b []byte) (int, error) { return cw.compressor.Write(b) } +func (cw *compressResponseWriter) ReadFrom(r io.Reader) (int64, error) { + return io.Copy(cw.compressor, r) +} + type flusher interface { Flush() error } @@ -129,6 +133,9 @@ func CompressHandlerLevel(h http.Handler, level int) http.Handler { Flush: func(httpsnoop.FlushFunc) httpsnoop.FlushFunc { return cw.Flush }, + ReadFrom: func(rff httpsnoop.ReadFromFunc) httpsnoop.ReadFromFunc { + return cw.ReadFrom + }, }) h.ServeHTTP(w, r) diff --git a/compress_test.go b/compress_test.go index adc2b8b..b9457cd 100644 --- a/compress_test.go +++ b/compress_test.go @@ -6,10 +6,16 @@ package handlers import ( "bufio" + "bytes" + "compress/gzip" "io" + "io/ioutil" "net" "net/http" "net/http/httptest" + "net/url" + "os" + "path/filepath" "strconv" "testing" ) @@ -158,6 +164,54 @@ func TestCompressHandlerGzipDeflate(t *testing.T) { } } +// Make sure we can compress and serve an *os.File properly. We need +// to use a real http server to trigger the net/http sendfile special +// case. +func TestCompressFile(t *testing.T) { + dir, err := ioutil.TempDir("", "gorilla_compress") + if err != nil { + t.Fatal(err) + } + defer os.RemoveAll(dir) + + err = ioutil.WriteFile(filepath.Join(dir, "hello.txt"), []byte("hello"), 0644) + if err != nil { + t.Fatal(err) + } + + s := httptest.NewServer(CompressHandler(http.FileServer(http.Dir(dir)))) + defer s.Close() + + url := &url.URL{Scheme: "http", Host: s.Listener.Addr().String(), Path: "/hello.txt"} + req, err := http.NewRequest("GET", url.String(), nil) + if err != nil { + t.Fatal(err) + } + req.Header.Set(acceptEncoding, "gzip") + res, err := http.DefaultClient.Do(req) + if err != nil { + t.Fatal(err) + } + + if res.StatusCode != http.StatusOK { + t.Fatalf("expected OK, got %q", res.Status) + } + + var got bytes.Buffer + gr, err := gzip.NewReader(res.Body) + if err != nil { + t.Fatal(err) + } + _, err = io.Copy(&got, gr) + if err != nil { + t.Fatal(err) + } + + if got.String() != "hello" { + t.Errorf("expected hello, got %q", got.String()) + } +} + type fullyFeaturedResponseWriter struct{} // Header/Write/WriteHeader implement the http.ResponseWriter interface.