Skip to content

Commit

Permalink
Add a dedicated cache container for the zstd compression algorithm (#…
Browse files Browse the repository at this point in the history
…1828)

Add a dedicated cache container for the zstd compression algorithm to prevent discrepancies between the response content and the implied Content-Encoding in certain scenarios.Fix: #1827 (comment)
  • Loading branch information
newacorn authored Aug 20, 2024
1 parent 5cc0ea1 commit d29a2b9
Show file tree
Hide file tree
Showing 2 changed files with 69 additions and 0 deletions.
5 changes: 5 additions & 0 deletions fs.go
Original file line number Diff line number Diff line change
Expand Up @@ -819,6 +819,7 @@ func newCacheManager(fs *FS) cacheManager {
cache: make(map[string]*fsFile),
cacheBrotli: make(map[string]*fsFile),
cacheGzip: make(map[string]*fsFile),
cacheZstd: make(map[string]*fsFile),
}

go instance.handleCleanCache(fs.CleanStop)
Expand Down Expand Up @@ -850,6 +851,7 @@ type inMemoryCacheManager struct {
cache map[string]*fsFile
cacheBrotli map[string]*fsFile
cacheGzip map[string]*fsFile
cacheZstd map[string]*fsFile
cacheDuration time.Duration
cacheLock sync.Mutex
}
Expand All @@ -869,6 +871,8 @@ func (cm *inMemoryCacheManager) getFsCache(cacheKind CacheKind) map[string]*fsFi
fileCache = cm.cacheBrotli
case gzipCacheKind:
fileCache = cm.cacheGzip
case zstdCacheKind:
fileCache = cm.cacheZstd
}

return fileCache
Expand Down Expand Up @@ -959,6 +963,7 @@ func (cm *inMemoryCacheManager) cleanCache(pendingFiles []*fsFile) []*fsFile {
pendingFiles, filesToRelease = cleanCacheNolock(cm.cache, pendingFiles, filesToRelease, cm.cacheDuration)
pendingFiles, filesToRelease = cleanCacheNolock(cm.cacheBrotli, pendingFiles, filesToRelease, cm.cacheDuration)
pendingFiles, filesToRelease = cleanCacheNolock(cm.cacheGzip, pendingFiles, filesToRelease, cm.cacheDuration)
pendingFiles, filesToRelease = cleanCacheNolock(cm.cacheZstd, pendingFiles, filesToRelease, cm.cacheDuration)

cm.cacheLock.Unlock()

Expand Down
64 changes: 64 additions & 0 deletions fs_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import (
"path/filepath"
"runtime"
"sort"
"strings"
"testing"
"time"
)
Expand Down Expand Up @@ -939,3 +940,66 @@ func TestServeFileDirectoryRedirect(t *testing.T) {
t.Fatalf("Unexpected status code %d for file '/fs.go'. Expecting %d.", ctx.Response.StatusCode(), StatusOK)
}
}

func TestFileCacheForZstd(t *testing.T) {
f, err := os.CreateTemp(os.TempDir(), "test")
if err != nil {
t.Fatal(err)
}
data := bytes.Repeat([]byte("1"), 1000)
changedData := bytes.Repeat([]byte("2"), 1000)
_, err = f.Write(data)
if err != nil {
t.Fatal(err)
}
err = f.Sync()
if err != nil {
t.Fatal(err)
}
fs := FS{Root: os.TempDir(), Compress: true, CacheDuration: time.Second * 60}
h := fs.NewRequestHandler()
var ctx RequestCtx
var req Request
req.Header.Set("Accept-Encoding", "zstd")
req.SetRequestURI("http://foobar.com/" + strings.TrimPrefix(f.Name(), os.TempDir()))
ctx.Init(&req, nil, nil)
h(&ctx)
if !bytes.Equal(ctx.Response.Header.ContentEncoding(), []byte("zstd")) {
t.Fatalf("Unexpected 'Content-Encoding' %q. Expecting %q", ctx.Response.Header.ContentEncoding(), "zstd")
}
ctx.Response.Reset()
_, err = f.Seek(0, io.SeekStart)
if err != nil {
t.Fatal(err)
}
_, err = f.Write(changedData)
if err != nil {
t.Fatal(err)
}
f.Close()
h(&ctx)
if !bytes.Equal(ctx.Response.Header.ContentEncoding(), []byte("zstd")) {
t.Fatalf("Unexpected 'Content-Encoding' %q. Expecting %q", ctx.Response.Header.ContentEncoding(), "zstd")
}
d, err := acquireZstdReader(strings.NewReader(string(ctx.Response.Body())))
if err != nil {
t.Fatalf("invalid zstd reader")
}
plainText, err := io.ReadAll(d)
d.Close()
if err != nil {
t.Fatal(err)
}
if !bytes.Equal(plainText, data) {
t.Fatalf("Unexpected response body %q. Expecting %q . Zstd cache doesn't work", plainText, data)
}
ctx.Request.Header.Del("Accept-Encoding")
ctx.Response.Reset()
h(&ctx)
if !bytes.Equal(ctx.Response.Header.ContentEncoding(), []byte("")) {
t.Fatalf("Unexpected 'Content-Encoding' %q. Expecting %q", ctx.Response.Header.ContentEncoding(), "")
}
if !bytes.Equal(ctx.Response.Body(), changedData) {
t.Fatalf("Unexpected response body %q. Expecting %q", ctx.Response.Body(), data)
}
}

0 comments on commit d29a2b9

Please sign in to comment.