diff --git a/middleware/cache/README.md b/middleware/cache/README.md index 9f76418f1a..10972e2f01 100644 --- a/middleware/cache/README.md +++ b/middleware/cache/README.md @@ -118,6 +118,13 @@ type Config struct { // // Default: false StoreResponseHeaders bool + + // Max number of bytes of response bodies simultaneously stored in cache. When limit is reached, + // entries with the nearest expiration are deleted to make room for new. + // 0 means no limit + // + // Default: 0 + MaxBytes uint } ``` @@ -133,8 +140,9 @@ var ConfigDefault = Config{ KeyGenerator: func(c *fiber.Ctx) string { return utils.CopyString(c.Path()) }, - ExpirationGenerator : nil, + ExpirationGenerator: nil, StoreResponseHeaders: false, - Storage: nil, + Storage: nil, + MaxBytes: 0, } ``` diff --git a/middleware/cache/cache.go b/middleware/cache/cache.go index 367baef694..e761ff1731 100644 --- a/middleware/cache/cache.go +++ b/middleware/cache/cache.go @@ -59,6 +59,10 @@ func New(config ...Config) fiber.Handler { ) // Create manager to simplify storage operations ( see manager.go ) manager := newManager(cfg.Storage) + // Create indexed heap for tracking expirations ( see heap.go ) + heap := &indexedHeap{} + // count stored bytes (sizes of response bodies) + var storedBytes uint = 0 // Update timestamp in the configured interval go func() { @@ -68,6 +72,15 @@ func New(config ...Config) fiber.Handler { } }() + // Delete key from both manager and storage + deleteKey := func(dkey string) { + manager.delete(dkey) + // External storage saves body data with different key + if cfg.Storage != nil { + manager.delete(dkey + "_body") + } + } + // Return new handler return func(c *fiber.Ctx) error { // Only cache GET and HEAD methods @@ -89,12 +102,12 @@ func New(config ...Config) fiber.Handler { // Get timestamp ts := atomic.LoadUint64(×tamp) + // Check if entry is expired if e.exp != 0 && ts >= e.exp { - // Check if entry is expired - manager.delete(key) - // External storage saves body data with different key - if cfg.Storage != nil { - manager.delete(key + "_body") + deleteKey(key) + if cfg.MaxBytes > 0 { + _, size := heap.remove(e.heapidx) + storedBytes -= size } } else if e.exp != 0 { // Separate body value to avoid msgp serialization @@ -146,6 +159,22 @@ func New(config ...Config) fiber.Handler { return nil } + // Don't try to cache if body won't fit into cache + bodySize := uint(len(c.Response().Body())) + if cfg.MaxBytes > 0 && bodySize > cfg.MaxBytes { + c.Set(cfg.CacheHeader, cacheUnreachable) + return nil + } + + // Remove oldest to make room for new + if cfg.MaxBytes > 0 { + for storedBytes+bodySize > cfg.MaxBytes { + key, size := heap.removeFirst() + deleteKey(key) + storedBytes -= size + } + } + // Cache response e.body = utils.CopyBytes(c.Response().Body()) e.status = c.Response().StatusCode() @@ -175,6 +204,12 @@ func New(config ...Config) fiber.Handler { } e.exp = ts + uint64(expiration.Seconds()) + // Store entry in heap + if cfg.MaxBytes > 0 { + e.heapidx = heap.put(key, e.exp, bodySize) + storedBytes += bodySize + } + // For external Storage we store raw body separated if cfg.Storage != nil { manager.setRaw(key+"_body", e.body, expiration) diff --git a/middleware/cache/cache_test.go b/middleware/cache/cache_test.go index f29fd3650c..4c7e155a36 100644 --- a/middleware/cache/cache_test.go +++ b/middleware/cache/cache_test.go @@ -6,6 +6,7 @@ import ( "bytes" "fmt" "io/ioutil" + "math" "net/http" "net/http/httptest" "strconv" @@ -493,6 +494,88 @@ func Test_CustomCacheHeader(t *testing.T) { utils.AssertEqual(t, cacheMiss, resp.Header.Get("Cache-Status")) } +// Because time points are updated once every X milliseconds, entries in tests can often have +// equal expiration times and thus be in an random order. This closure hands out increasing +// time intervals to maintain strong ascending order of expiration +func stableAscendingExpiration() func(c1 *fiber.Ctx, c2 *Config) time.Duration { + i := 0 + return func(c1 *fiber.Ctx, c2 *Config) time.Duration { + i += 1 + return time.Hour * time.Duration(i) + } +} + +func Test_Cache_MaxBytesOrder(t *testing.T) { + t.Parallel() + + app := fiber.New() + app.Use(New(Config{ + MaxBytes: 2, + ExpirationGenerator: stableAscendingExpiration(), + })) + + app.Get("/*", func(c *fiber.Ctx) error { + return c.SendString("1") + }) + + cases := [][]string{ + // Insert a, b into cache of size 2 bytes (responses are 1 byte) + {"/a", cacheMiss}, + {"/b", cacheMiss}, + {"/a", cacheHit}, + {"/b", cacheHit}, + // Add c -> a evicted + {"/c", cacheMiss}, + {"/b", cacheHit}, + // Add a again -> b evicted + {"/a", cacheMiss}, + {"/c", cacheHit}, + // Add b -> c evicted + {"/b", cacheMiss}, + {"/c", cacheMiss}, + } + + for idx, tcase := range cases { + rsp, err := app.Test(httptest.NewRequest("GET", tcase[0], nil)) + utils.AssertEqual(t, nil, err) + utils.AssertEqual(t, tcase[1], rsp.Header.Get("X-Cache"), fmt.Sprintf("Case %v", idx)) + } +} + +func Test_Cache_MaxBytesSizes(t *testing.T) { + t.Parallel() + + app := fiber.New() + + app.Use(New(Config{ + MaxBytes: 7, + ExpirationGenerator: stableAscendingExpiration(), + })) + + app.Get("/*", func(c *fiber.Ctx) error { + path := c.Context().URI().LastPathSegment() + size, _ := strconv.Atoi(string(path)) + return c.Send(make([]byte, size)) + }) + + cases := [][]string{ + {"/1", cacheMiss}, + {"/2", cacheMiss}, + {"/3", cacheMiss}, + {"/4", cacheMiss}, // 1+2+3+4 > 7 => 1,2 are evicted now + {"/3", cacheHit}, + {"/1", cacheMiss}, + {"/2", cacheMiss}, + {"/8", cacheUnreachable}, // too big to cache -> unreachable + } + + for idx, tcase := range cases { + rsp, err := app.Test(httptest.NewRequest("GET", tcase[0], nil)) + utils.AssertEqual(t, nil, err) + utils.AssertEqual(t, tcase[1], rsp.Header.Get("X-Cache"), fmt.Sprintf("Case %v", idx)) + } +} + // go test -v -run=^$ -bench=Benchmark_Cache -benchmem -count=4 func Benchmark_Cache(b *testing.B) { app := fiber.New() @@ -578,3 +661,36 @@ func Benchmark_Cache_AdditionalHeaders(b *testing.B) { utils.AssertEqual(b, fiber.StatusTeapot, fctx.Response.Header.StatusCode()) utils.AssertEqual(b, []byte("foobar"), fctx.Response.Header.Peek("X-Foobar")) } + +func Benchmark_Cache_MaxSize(b *testing.B) { + // The benchmark is run with three different MaxSize parameters + // 1) 0: Tracking is disabled = no overhead + // 2) MaxInt32: Enough to store all entries = no removals + // 3) 100: Small size = constant insertions and removals + cases := []uint{0, math.MaxUint32, 100} + names := []string{"Disabled", "Unlim", "LowBounded"} + for i, size := range cases { + b.Run(names[i], func(b *testing.B) { + app := fiber.New() + app.Use(New(Config{MaxBytes: size})) + + app.Get("/*", func(c *fiber.Ctx) error { + return c.Status(fiber.StatusTeapot).SendString("1") + }) + + h := app.Handler() + fctx := &fasthttp.RequestCtx{} + fctx.Request.Header.SetMethod("GET") + + b.ReportAllocs() + b.ResetTimer() + + for n := 0; n < b.N; n++ { + fctx.Request.SetRequestURI(fmt.Sprintf("/%v", n)) + h(fctx) + } + + utils.AssertEqual(b, fiber.StatusTeapot, fctx.Response.Header.StatusCode()) + }) + } +} diff --git a/middleware/cache/config.go b/middleware/cache/config.go index 636eb0141e..625d1c478b 100644 --- a/middleware/cache/config.go +++ b/middleware/cache/config.go @@ -59,6 +59,13 @@ type Config struct { // // Default: false StoreResponseHeaders bool + + // Max number of bytes of response bodies simultaneously stored in cache. When limit is reached, + // entries with the nearest expiration are deleted to make room for new. + // 0 means no limit + // + // Default: 0 + MaxBytes uint } // ConfigDefault is the default config @@ -73,6 +80,7 @@ var ConfigDefault = Config{ ExpirationGenerator: nil, StoreResponseHeaders: false, Storage: nil, + MaxBytes: 0, } // Helper function to set default values diff --git a/middleware/cache/heap.go b/middleware/cache/heap.go new file mode 100644 index 0000000000..70271b84b8 --- /dev/null +++ b/middleware/cache/heap.go @@ -0,0 +1,92 @@ +package cache + +import ( + "container/heap" +) + +type heapEntry struct { + key string + exp uint64 + bytes uint + idx int +} + +// indexedHeap is a regular min-heap that allows finding +// elements in constant time. It does so by handing out special indices +// and tracking entry movement. +// +// indexdedHeap is used for quickly finding entries with the lowest +// expiration timestamp and deleting arbitrary entries. +type indexedHeap struct { + // Slice the heap is built on + entries []heapEntry + // Mapping "index" to position in heap slice + indices []int + // Max index handed out + maxidx int +} + +func (h indexedHeap) Len() int { + return len(h.entries) +} + +func (h indexedHeap) Less(i, j int) bool { + return h.entries[i].exp < h.entries[j].exp +} + +func (h indexedHeap) Swap(i, j int) { + h.entries[i], h.entries[j] = h.entries[j], h.entries[i] + h.indices[h.entries[i].idx] = i + h.indices[h.entries[j].idx] = j +} + +func (h *indexedHeap) Push(x interface{}) { + h.pushInternal(x.(heapEntry)) +} + +func (h *indexedHeap) Pop() interface{} { + n := len(h.entries) + h.entries = h.entries[0 : n-1] + return h.entries[0:n][n-1] +} + +func (h *indexedHeap) pushInternal(entry heapEntry) { + h.indices[entry.idx] = len(h.entries) + h.entries = append(h.entries, entry) +} + +// Returns index to track entry +func (h *indexedHeap) put(key string, exp uint64, bytes uint) int { + idx := 0 + if len(h.entries) < h.maxidx { + // Steal index from previously removed entry + // capacity > size is guaranteed + n := len(h.entries) + idx = h.entries[:n+1][n].idx + } else { + idx = h.maxidx + h.maxidx += 1 + h.indices = append(h.indices, idx) + } + // Push manually to avoid allocation + h.pushInternal(heapEntry{ + key: key, exp: exp, idx: idx, bytes: bytes, + }) + heap.Fix(h, h.Len()-1) + return idx +} + +func (h *indexedHeap) removeInternal(realIdx int) (string, uint) { + x := heap.Remove(h, realIdx).(heapEntry) + return x.key, x.bytes +} + +// Remove entry by index +func (h *indexedHeap) remove(idx int) (string, uint) { + return h.removeInternal(h.indices[idx]) +} + +// Remove entry with lowest expiration time +func (h *indexedHeap) removeFirst() (string, uint) { + return h.removeInternal(0) +} diff --git a/middleware/cache/manager.go b/middleware/cache/manager.go index 5ba9260c09..6b9256fd23 100644 --- a/middleware/cache/manager.go +++ b/middleware/cache/manager.go @@ -19,6 +19,8 @@ type item struct { status int exp uint64 headers map[string][]byte + // used for finding the item in an indexed heap + heapidx int } //msgp:ignore manager diff --git a/middleware/cache/manager_msgp.go b/middleware/cache/manager_msgp.go index 9d373dae05..6d8fd8536d 100644 --- a/middleware/cache/manager_msgp.go +++ b/middleware/cache/manager_msgp.go @@ -1,6 +1,8 @@ package cache -// Code generated by github.com/tinylib/msgp DO NOT EDIT. +// NOTE: THIS FILE WAS PRODUCED BY THE +// MSGP CODE GENERATION TOOL (github.com/tinylib/msgp) +// DO NOT EDIT import ( "github.com/gofiber/fiber/v2/internal/msgp" @@ -10,84 +12,78 @@ import ( func (z *item) DecodeMsg(dc *msgp.Reader) (err error) { var field []byte _ = field - var zb0001 uint32 - zb0001, err = dc.ReadMapHeader() + var zbai uint32 + zbai, err = dc.ReadMapHeader() if err != nil { - err = msgp.WrapError(err) return } - for zb0001 > 0 { - zb0001-- + for zbai > 0 { + zbai-- field, err = dc.ReadMapKeyPtr() if err != nil { - err = msgp.WrapError(err) return } switch msgp.UnsafeString(field) { case "body": z.body, err = dc.ReadBytes(z.body) if err != nil { - err = msgp.WrapError(err, "body") return } case "ctype": z.ctype, err = dc.ReadBytes(z.ctype) if err != nil { - err = msgp.WrapError(err, "ctype") return } case "cencoding": z.cencoding, err = dc.ReadBytes(z.cencoding) if err != nil { - err = msgp.WrapError(err, "cencoding") return } case "status": z.status, err = dc.ReadInt() if err != nil { - err = msgp.WrapError(err, "status") return } case "exp": z.exp, err = dc.ReadUint64() if err != nil { - err = msgp.WrapError(err, "exp") return } case "headers": - var zb0002 uint32 - zb0002, err = dc.ReadMapHeader() + var zcmr uint32 + zcmr, err = dc.ReadMapHeader() if err != nil { - err = msgp.WrapError(err, "headers") return } - if z.headers == nil { - z.headers = make(map[string][]byte, zb0002) + if z.headers == nil && zcmr > 0 { + z.headers = make(map[string][]byte, zcmr) } else if len(z.headers) > 0 { - for key := range z.headers { + for key, _ := range z.headers { delete(z.headers, key) } } - for zb0002 > 0 { - zb0002-- - var za0001 string - var za0002 []byte - za0001, err = dc.ReadString() + for zcmr > 0 { + zcmr-- + var zxvk string + var zbzg []byte + zxvk, err = dc.ReadString() if err != nil { - err = msgp.WrapError(err, "headers") return } - za0002, err = dc.ReadBytes(za0002) + zbzg, err = dc.ReadBytes(zbzg) if err != nil { - err = msgp.WrapError(err, "headers", za0001) return } - z.headers[za0001] = za0002 + z.headers[zxvk] = zbzg + } + case "heapidx": + z.heapidx, err = dc.ReadInt() + if err != nil { + return } default: err = dc.Skip() if err != nil { - err = msgp.WrapError(err) return } } @@ -97,88 +93,89 @@ func (z *item) DecodeMsg(dc *msgp.Reader) (err error) { // EncodeMsg implements msgp.Encodable func (z *item) EncodeMsg(en *msgp.Writer) (err error) { - // map header, size 6 + // map header, size 7 // write "body" - err = en.Append(0x86, 0xa4, 0x62, 0x6f, 0x64, 0x79) + err = en.Append(0x87, 0xa4, 0x62, 0x6f, 0x64, 0x79) if err != nil { - return + return err } err = en.WriteBytes(z.body) if err != nil { - err = msgp.WrapError(err, "body") return } // write "ctype" err = en.Append(0xa5, 0x63, 0x74, 0x79, 0x70, 0x65) if err != nil { - return + return err } err = en.WriteBytes(z.ctype) if err != nil { - err = msgp.WrapError(err, "ctype") return } // write "cencoding" err = en.Append(0xa9, 0x63, 0x65, 0x6e, 0x63, 0x6f, 0x64, 0x69, 0x6e, 0x67) if err != nil { - return + return err } err = en.WriteBytes(z.cencoding) if err != nil { - err = msgp.WrapError(err, "cencoding") return } // write "status" err = en.Append(0xa6, 0x73, 0x74, 0x61, 0x74, 0x75, 0x73) if err != nil { - return + return err } err = en.WriteInt(z.status) if err != nil { - err = msgp.WrapError(err, "status") return } // write "exp" err = en.Append(0xa3, 0x65, 0x78, 0x70) if err != nil { - return + return err } err = en.WriteUint64(z.exp) if err != nil { - err = msgp.WrapError(err, "exp") return } // write "headers" - err = en.Append(0xaa, 0x65, 0x32, 0x65, 0x48, 0x65, 0x61, 0x64, 0x65, 0x72, 0x73) + err = en.Append(0xa7, 0x68, 0x65, 0x61, 0x64, 0x65, 0x72, 0x73) if err != nil { - return + return err } err = en.WriteMapHeader(uint32(len(z.headers))) if err != nil { - err = msgp.WrapError(err, "headers") return } - for za0001, za0002 := range z.headers { - err = en.WriteString(za0001) + for zxvk, zbzg := range z.headers { + err = en.WriteString(zxvk) if err != nil { - err = msgp.WrapError(err, "headers") return } - err = en.WriteBytes(za0002) + err = en.WriteBytes(zbzg) if err != nil { - err = msgp.WrapError(err, "headers", za0001) return } } + // write "heapidx" + err = en.Append(0xa7, 0x68, 0x65, 0x61, 0x70, 0x69, 0x64, 0x78) + if err != nil { + return err + } + err = en.WriteInt(z.heapidx) + if err != nil { + return + } return } // MarshalMsg implements msgp.Marshaler func (z *item) MarshalMsg(b []byte) (o []byte, err error) { o = msgp.Require(b, z.Msgsize()) - // map header, size 6 + // map header, size 7 // string "body" - o = append(o, 0x86, 0xa4, 0x62, 0x6f, 0x64, 0x79) + o = append(o, 0x87, 0xa4, 0x62, 0x6f, 0x64, 0x79) o = msgp.AppendBytes(o, z.body) // string "ctype" o = append(o, 0xa5, 0x63, 0x74, 0x79, 0x70, 0x65) @@ -193,12 +190,15 @@ func (z *item) MarshalMsg(b []byte) (o []byte, err error) { o = append(o, 0xa3, 0x65, 0x78, 0x70) o = msgp.AppendUint64(o, z.exp) // string "headers" - o = append(o, 0xaa, 0x65, 0x32, 0x65, 0x48, 0x65, 0x61, 0x64, 0x65, 0x72, 0x73) + o = append(o, 0xa7, 0x68, 0x65, 0x61, 0x64, 0x65, 0x72, 0x73) o = msgp.AppendMapHeader(o, uint32(len(z.headers))) - for za0001, za0002 := range z.headers { - o = msgp.AppendString(o, za0001) - o = msgp.AppendBytes(o, za0002) + for zxvk, zbzg := range z.headers { + o = msgp.AppendString(o, zxvk) + o = msgp.AppendBytes(o, zbzg) } + // string "heapidx" + o = append(o, 0xa7, 0x68, 0x65, 0x61, 0x70, 0x69, 0x64, 0x78) + o = msgp.AppendInt(o, z.heapidx) return } @@ -206,84 +206,78 @@ func (z *item) MarshalMsg(b []byte) (o []byte, err error) { func (z *item) UnmarshalMsg(bts []byte) (o []byte, err error) { var field []byte _ = field - var zb0001 uint32 - zb0001, bts, err = msgp.ReadMapHeaderBytes(bts) + var zajw uint32 + zajw, bts, err = msgp.ReadMapHeaderBytes(bts) if err != nil { - err = msgp.WrapError(err) return } - for zb0001 > 0 { - zb0001-- + for zajw > 0 { + zajw-- field, bts, err = msgp.ReadMapKeyZC(bts) if err != nil { - err = msgp.WrapError(err) return } switch msgp.UnsafeString(field) { case "body": z.body, bts, err = msgp.ReadBytesBytes(bts, z.body) if err != nil { - err = msgp.WrapError(err, "body") return } case "ctype": z.ctype, bts, err = msgp.ReadBytesBytes(bts, z.ctype) if err != nil { - err = msgp.WrapError(err, "ctype") return } case "cencoding": z.cencoding, bts, err = msgp.ReadBytesBytes(bts, z.cencoding) if err != nil { - err = msgp.WrapError(err, "cencoding") return } case "status": z.status, bts, err = msgp.ReadIntBytes(bts) if err != nil { - err = msgp.WrapError(err, "status") return } case "exp": z.exp, bts, err = msgp.ReadUint64Bytes(bts) if err != nil { - err = msgp.WrapError(err, "exp") return } case "headers": - var zb0002 uint32 - zb0002, bts, err = msgp.ReadMapHeaderBytes(bts) + var zwht uint32 + zwht, bts, err = msgp.ReadMapHeaderBytes(bts) if err != nil { - err = msgp.WrapError(err, "headers") return } - if z.headers == nil { - z.headers = make(map[string][]byte, zb0002) + if z.headers == nil && zwht > 0 { + z.headers = make(map[string][]byte, zwht) } else if len(z.headers) > 0 { - for key := range z.headers { + for key, _ := range z.headers { delete(z.headers, key) } } - for zb0002 > 0 { - var za0001 string - var za0002 []byte - zb0002-- - za0001, bts, err = msgp.ReadStringBytes(bts) + for zwht > 0 { + var zxvk string + var zbzg []byte + zwht-- + zxvk, bts, err = msgp.ReadStringBytes(bts) if err != nil { - err = msgp.WrapError(err, "headers") return } - za0002, bts, err = msgp.ReadBytesBytes(bts, za0002) + zbzg, bts, err = msgp.ReadBytesBytes(bts, zbzg) if err != nil { - err = msgp.WrapError(err, "headers", za0001) return } - z.headers[za0001] = za0002 + z.headers[zxvk] = zbzg + } + case "heapidx": + z.heapidx, bts, err = msgp.ReadIntBytes(bts) + if err != nil { + return } default: bts, err = msgp.Skip(bts) if err != nil { - err = msgp.WrapError(err) return } } @@ -294,12 +288,13 @@ func (z *item) UnmarshalMsg(bts []byte) (o []byte, err error) { // Msgsize returns an upper bound estimate of the number of bytes occupied by the serialized message func (z *item) Msgsize() (s int) { - s = 1 + 5 + msgp.BytesPrefixSize + len(z.body) + 6 + msgp.BytesPrefixSize + len(z.ctype) + 10 + msgp.BytesPrefixSize + len(z.cencoding) + 7 + msgp.IntSize + 4 + msgp.Uint64Size + 11 + msgp.MapHeaderSize + s = 1 + 5 + msgp.BytesPrefixSize + len(z.body) + 6 + msgp.BytesPrefixSize + len(z.ctype) + 10 + msgp.BytesPrefixSize + len(z.cencoding) + 7 + msgp.IntSize + 4 + msgp.Uint64Size + 8 + msgp.MapHeaderSize if z.headers != nil { - for za0001, za0002 := range z.headers { - _ = za0002 - s += msgp.StringPrefixSize + len(za0001) + msgp.BytesPrefixSize + len(za0002) + for zxvk, zbzg := range z.headers { + _ = zbzg + s += msgp.StringPrefixSize + len(zxvk) + msgp.BytesPrefixSize + len(zbzg) } } + s += 8 + msgp.IntSize return }