Skip to content

Commit

Permalink
feat(middleware): add Deflate middleware
Browse files Browse the repository at this point in the history
  • Loading branch information
yuanji-dev committed Feb 4, 2022
1 parent 4a1ccdf commit 6e6ae1e
Show file tree
Hide file tree
Showing 2 changed files with 272 additions and 22 deletions.
97 changes: 75 additions & 22 deletions middleware/compress.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package middleware
import (
"bufio"
"compress/gzip"
"compress/zlib"
"io"
"io/ioutil"
"net"
Expand All @@ -14,33 +15,50 @@ import (
)

type (
// GzipConfig defines the config for Gzip middleware.
GzipConfig struct {
compressConfig struct {
// Skipper defines a function to skip middleware.
Skipper Skipper

// Gzip compression level.
// Compression level.
// Optional. Default value -1.
Level int `yaml:"level"`
}

gzipResponseWriter struct {
// GzipConfig defines the config for Gzip middleware.
GzipConfig compressConfig
// DeflateConfig defines the config for Deflate middleware.
DeflateConfig compressConfig

compressResponseWriter struct {
io.Writer
http.ResponseWriter
wroteBody bool
}

resetWriteCloser interface {
Reset(w io.Writer)
io.WriteCloser
}

flusher interface {
Flush() error
}
)

const (
gzipScheme = "gzip"
gzipScheme = "gzip"
deflateScheme = "deflate"
)

var (
// DefaultGzipConfig is the default Gzip middleware config.
DefaultGzipConfig = GzipConfig{
defaultConfig = compressConfig{
Skipper: DefaultSkipper,
Level: -1,
}
// DefaultGzipConfig is the default Gzip middleware config.
DefaultGzipConfig = GzipConfig(defaultConfig)
// DefaultDeflateConfig is the default Deflate middleware config.
DefaultDeflateConfig = DeflateConfig(defaultConfig)
)

// Gzip returns a middleware which compresses HTTP response using gzip compression
Expand All @@ -49,18 +67,41 @@ func Gzip() echo.MiddlewareFunc {
return GzipWithConfig(DefaultGzipConfig)
}

// Deflate returns a middleware which compresses HTTP response using deflate(zlib) compression
func Deflate() echo.MiddlewareFunc {
return DeflateWithConfig(DefaultDeflateConfig)
}

// GzipWithConfig return Gzip middleware with config.
// See: `Gzip()`.
func GzipWithConfig(config GzipConfig) echo.MiddlewareFunc {
return compressWithConfig(compressConfig(config), gzipScheme)
}

// DeflateWithConfig return Deflate middleware with config.
// See: `Deflate()`.
func DeflateWithConfig(config DeflateConfig) echo.MiddlewareFunc {
return compressWithConfig(compressConfig(config), deflateScheme)
}

func compressWithConfig(config compressConfig, encoding string) echo.MiddlewareFunc {
// Defaults
if config.Skipper == nil {
config.Skipper = DefaultGzipConfig.Skipper
config.Skipper = defaultConfig.Skipper
}
if config.Level == 0 {
config.Level = DefaultGzipConfig.Level
config.Level = defaultConfig.Level
}

pool := gzipCompressPool(config)
var pool sync.Pool
switch encoding {
case gzipScheme:
pool = gzipCompressPool(config)
case deflateScheme:
pool = deflateCompressPool(config)
default:
panic("echo: either gzip or deflate is currently supported")
}

return func(next echo.HandlerFunc) echo.HandlerFunc {
return func(c echo.Context) error {
Expand All @@ -70,19 +111,19 @@ func GzipWithConfig(config GzipConfig) echo.MiddlewareFunc {

res := c.Response()
res.Header().Add(echo.HeaderVary, echo.HeaderAcceptEncoding)
if strings.Contains(c.Request().Header.Get(echo.HeaderAcceptEncoding), gzipScheme) {
res.Header().Set(echo.HeaderContentEncoding, gzipScheme) // Issue #806
if strings.Contains(c.Request().Header.Get(echo.HeaderAcceptEncoding), encoding) {
res.Header().Set(echo.HeaderContentEncoding, encoding) // Issue #806
i := pool.Get()
w, ok := i.(*gzip.Writer)
w, ok := i.(resetWriteCloser)
if !ok {
return echo.NewHTTPError(http.StatusInternalServerError, i.(error).Error())
}
rw := res.Writer
w.Reset(rw)
grw := &gzipResponseWriter{Writer: w, ResponseWriter: rw}
grw := &compressResponseWriter{Writer: w, ResponseWriter: rw}
defer func() {
if !grw.wroteBody {
if res.Header().Get(echo.HeaderContentEncoding) == gzipScheme {
if res.Header().Get(echo.HeaderContentEncoding) == encoding {
res.Header().Del(echo.HeaderContentEncoding)
}
// We have to reset response to it's pristine state when
Expand All @@ -101,38 +142,38 @@ func GzipWithConfig(config GzipConfig) echo.MiddlewareFunc {
}
}

func (w *gzipResponseWriter) WriteHeader(code int) {
func (w *compressResponseWriter) WriteHeader(code int) {
w.Header().Del(echo.HeaderContentLength) // Issue #444
w.ResponseWriter.WriteHeader(code)
}

func (w *gzipResponseWriter) Write(b []byte) (int, error) {
func (w *compressResponseWriter) Write(b []byte) (int, error) {
if w.Header().Get(echo.HeaderContentType) == "" {
w.Header().Set(echo.HeaderContentType, http.DetectContentType(b))
}
w.wroteBody = true
return w.Writer.Write(b)
}

func (w *gzipResponseWriter) Flush() {
w.Writer.(*gzip.Writer).Flush()
func (w *compressResponseWriter) Flush() {
w.Writer.(flusher).Flush()
if flusher, ok := w.ResponseWriter.(http.Flusher); ok {
flusher.Flush()
}
}

func (w *gzipResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
func (w *compressResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
return w.ResponseWriter.(http.Hijacker).Hijack()
}

func (w *gzipResponseWriter) Push(target string, opts *http.PushOptions) error {
func (w *compressResponseWriter) Push(target string, opts *http.PushOptions) error {
if p, ok := w.ResponseWriter.(http.Pusher); ok {
return p.Push(target, opts)
}
return http.ErrNotSupported
}

func gzipCompressPool(config GzipConfig) sync.Pool {
func gzipCompressPool(config compressConfig) sync.Pool {
return sync.Pool{
New: func() interface{} {
w, err := gzip.NewWriterLevel(ioutil.Discard, config.Level)
Expand All @@ -143,3 +184,15 @@ func gzipCompressPool(config GzipConfig) sync.Pool {
},
}
}

func deflateCompressPool(config compressConfig) sync.Pool {
return sync.Pool{
New: func() interface{} {
w, err := zlib.NewWriterLevel(ioutil.Discard, config.Level)
if err != nil {
return err
}
return w
},
}
}
Loading

0 comments on commit 6e6ae1e

Please sign in to comment.