Skip to content

Commit

Permalink
compress: update compression middleware
Browse files Browse the repository at this point in the history
This updates the compression middleware to use the 1.20
ResponseController scheme, zstd compression, and adds tests.

Signed-off-by: Hank Donnay <hdonnay@redhat.com>
  • Loading branch information
hdonnay committed Oct 20, 2023
1 parent a90ecc4 commit c90a55f
Show file tree
Hide file tree
Showing 2 changed files with 296 additions and 106 deletions.
210 changes: 104 additions & 106 deletions middleware/compress/handler.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,11 @@
// Package compress implements an RFC9110 compliant handler for the
// "Accept-Encoding" header.
//
// This package supports "identity", "gzip", "deflate", and "zstd".
package compress

import (
"fmt"
"errors"
"io"
"mime"
"net/http"
Expand All @@ -12,24 +16,36 @@ import (

"github.com/klauspost/compress/flate"
"github.com/klauspost/compress/gzip"
"github.com/klauspost/compress/snappy"
"github.com/klauspost/compress/zstd"
)

// Handler wraps the provided http.Handler and provides transparent body
// compression based on a Request's "Accept-Encoding" header.
//
// Each handler instance pools its own compressors.
func Handler(next http.Handler) http.Handler {
h := handler{
next: next,
}
h.snappy.New = func() interface{} {
return snappy.NewBufferedWriter(nil)
h.zstd.New = func() interface{} {
w, err := zstd.NewWriter(nil, zstd.WithEncoderLevel(zstd.SpeedFastest))
if err != nil {
panic(err)
}
return w
}
h.gzip.New = func() interface{} {
w, _ := gzip.NewWriterLevel(nil, gzip.BestSpeed)
w, err := gzip.NewWriterLevel(nil, gzip.BestSpeed)
if err != nil {
panic(err)
}
return w
}
h.flate.New = func() interface{} {
w, _ := flate.NewWriter(nil, flate.BestSpeed)
w, err := flate.NewWriter(nil, flate.BestSpeed)
if err != nil {
panic(err)
}
return w
}

Expand All @@ -38,28 +54,17 @@ func Handler(next http.Handler) http.Handler {

var _ http.Handler = (*handler)(nil)

// handler performs transparent HTTP body compression.
// Handler performs transparent HTTP body compression.
type handler struct {
snappy, gzip, flate sync.Pool
next http.Handler
}

// Header is an interface that has the http.ResponseWriter's Header-related
// methods.
type header interface {
Header() http.Header
WriteHeader(int)
zstd, gzip, flate sync.Pool
next http.Handler
}

// ParseAccept parses an "Accept-Encoding" header.
//
// Reports a sorted list of encodings and a map of disallowed encodings.
// Reports nil if no selections were present.
func parseAccept(h string) ([]accept, map[string]struct{}) {
if h == "" {
return nil, nil
}

segs := strings.Split(h, ",")
ret := make([]accept, 0, len(segs))
nok := make(map[string]struct{})
Expand Down Expand Up @@ -98,45 +103,41 @@ type accept struct {

// ServeHTTP implements http.Handler.
func (c *handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
ae, nok := parseAccept(r.Header.Get("accept-encoding"))
if ae == nil {
// If there was no header, play it cool.
v, ok := r.Header["Accept-Encoding"]
if !ok { // No header, use "identity".
c.next.ServeHTTP(w, r)
return
}
var (
flusher http.Flusher
pusher http.Pusher
cw io.WriteCloser
)
flusher, _ = w.(http.Flusher)
pusher, _ = w.(http.Pusher)

// Find the first accept-encoding we support.
// See https://www.w3.org/Protocols/rfc2616/rfc2616-sec14.html#sec14.1 for
// all the sematics.
ae, nok := parseAccept(v[0])
var zw zwriter
// Find the first accept-encoding we support. See
// https://www.rfc-editor.org/rfc/rfc9110.html#section-12.5.3 for all the
// semantics.
//
// NB The "identity" encoding shouldn't show up in the Content-Encoding
// response header.
for _, a := range ae {
switch a.Type {
case "gzip":
case "gzip", "x-gzip":
w.Header().Set("content-encoding", "gzip")
gz := c.gzip.Get().(*gzip.Writer)
gz.Reset(w)
defer c.gzip.Put(gz)
cw = gz
zw = gz
case "deflate":
w.Header().Set("content-encoding", "deflate")
z := c.flate.Get().(*flate.Writer)
z.Reset(w)
defer c.flate.Put(z)
cw = z
case "snappy": // Nonstandard
w.Header().Set("content-encoding", "snappy")
s := c.snappy.Get().(*snappy.Writer)
zw = z
case "zstd":
w.Header().Set("content-encoding", "zstd")
s := c.zstd.Get().(*zstd.Encoder)
s.Reset(w)
defer c.snappy.Put(s)
cw = s
defer c.zstd.Put(s)
zw = s
case "identity":
w.Header().Set("content-encoding", "identity")
w.Header().Set("accept-encoding", acceptable)
case "*":
// If we hit a star, it's technically OK to return any encoding not
// already specified. So, attempt to use gzip and then identity and
Expand All @@ -152,80 +153,77 @@ func (c *handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
gz := c.gzip.Get().(*gzip.Writer)
gz.Reset(w)
defer c.gzip.Put(gz)
cw = gz
zw = gz
case !idnok:
w.Header().Set("content-encoding", "identity")
// "Identity" isn't not OK, so fallthrough.
default:
w.WriteHeader(http.StatusNotAcceptable)
w.Header().Set("accept-encoding", acceptable)
w.WriteHeader(http.StatusUnsupportedMediaType)
return
}
default:
continue
}
break
}
// Now "zw" should be populated if it can be.
if zw == nil {
w.Header().Set("accept-encoding", acceptable)
// If it's not, we need to make sure identity or "any" aren't
// disallowed.
_, idnok := nok["identity"]
_, anynok := nok["*"]
if idnok || anynok {
w.WriteHeader(http.StatusUnsupportedMediaType)
return
}
// Couldn't pick something, fall back to identity.
c.next.ServeHTTP(w, r)
return
}
// Do some setup so we can see the error, albeit as a trailer.
if cw != nil {
const errHeader = `Clair-Error`
w.Header().Add("trailer", errHeader)
defer func() {
if err := cw.Close(); err != nil {
w.Header().Add(errHeader, err.Error())
}
}()
const errHeader = `Clair-Error`
w.Header().Add("trailer", errHeader)
defer func() {
if err := zw.Close(); err != nil {
w.Header().Add(errHeader, err.Error())
}
}()
next := compressWriter{
ResponseWriter: w,
zwriter: zw,
}
c.next.ServeHTTP(&next, r)
}

// Nw is the http.ResponseWriter for our next http.Handler.
var nw http.ResponseWriter
// This is a giant truth table to make anonymous types that satisfy as many
// optional interfaces as possible.
//
// We care about 3 interfaces, so there are 2^3 == 8 combinations.
switch {
case flusher == nil && pusher == nil && cw == nil:
nw = w
case flusher == nil && pusher == nil && cw != nil:
nw = struct {
header
io.Writer
}{w, cw}
case flusher == nil && pusher != nil && cw == nil:
nw = struct {
http.ResponseWriter
http.Pusher
}{w, pusher}
case flusher == nil && pusher != nil && cw != nil:
nw = struct {
header
io.Writer
http.Pusher
}{w, cw, pusher}
case flusher != nil && pusher == nil && cw == nil:
nw = struct {
http.ResponseWriter
http.Flusher
}{w, flusher}
case flusher != nil && pusher == nil && cw != nil:
nw = struct {
header
io.Writer
http.Flusher
}{w, cw, flusher}
case flusher != nil && pusher != nil && cw == nil:
nw = struct {
http.ResponseWriter
http.Flusher
http.Pusher
}{w, flusher, pusher}
case flusher != nil && pusher != nil && cw != nil:
nw = struct {
header
io.Writer
http.Flusher
http.Pusher
}{w, cw, flusher, pusher}
default:
panic(fmt.Sprintf("unexpected type combination: %T/%T/%T", flusher, pusher, cw))
// Acceptable is a preformatted list of acceptable encodings.
const acceptable = `zstd, gzip, deflate`

// CompressWriter is compressing http.ResponseWriter that understands the go1.20
// ResponseController scheme.
type compressWriter struct {
http.ResponseWriter
zwriter
}

type zwriter interface {
io.WriteCloser
Flush() error
}

var _ http.ResponseWriter = (*compressWriter)(nil)

func (c *compressWriter) Unwrap() http.ResponseWriter {
return c.ResponseWriter
}
func (c *compressWriter) Write(b []byte) (int, error) {
return c.zwriter.Write(b)
}
func (c *compressWriter) FlushError() error {
zFlush := c.zwriter.Flush()
httpFlush := http.NewResponseController(c.ResponseWriter).Flush()
if errors.Is(httpFlush, http.ErrNotSupported) {
httpFlush = nil
}
c.next.ServeHTTP(nw, r)
return errors.Join(zFlush, httpFlush)
}
Loading

0 comments on commit c90a55f

Please sign in to comment.