From d62773ca3b8d72fe53e88e00d3831ccaca8b1ffc Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bj=C3=B8rn=20Erik=20Pedersen?= Date: Tue, 4 Jun 2024 11:28:24 +0200 Subject: [PATCH] Support stale cache entries To support eTag/change detection for entries marked stale outside of this library. Also switch the values in XETag1 and XETag2 so they are ordered by old/new. --- httpcache.go | 67 +++++++++++++++++++++++++++++++---------------- httpcache_test.go | 50 ++++++++++++++++++++++++++++++++--- 2 files changed, 91 insertions(+), 26 deletions(-) diff --git a/httpcache.go b/httpcache.go index 9faa754..cad2f23 100644 --- a/httpcache.go +++ b/httpcache.go @@ -33,13 +33,16 @@ const ( XETag1 = xEtags + "1" // XETag2 is the key for the second eTag value. + // Note that in the cache, XETag1 and XETag2 will always be the same. + // In the Response returned from Response, XETag1 will be the cached value (old) and + // XETag2 will be the eTag value from the server (new). XETag2 = xEtags + "2" ) // A Cache interface is used by the Transport to store and retrieve responses. type Cache interface { // Get returns the []byte representation of a cached response and a bool - // set to true if the value isn't empty + // set to set to false if the key is not found or the value is stale. Get(key string) (responseBytes []byte, ok bool) // Set stores the []byte representation of a response against a key Set(key string, responseBytes []byte) @@ -65,16 +68,19 @@ func (t *Transport) cacheKey(req *http.Request) string { } } -// cachedResponse returns the cached http.Response for req if present, and nil -// otherwise. -func (t *Transport) cachedResponse(req *http.Request) (resp *http.Response, err error) { +// cachedResponse returns the cached http.Response for req if present and +// a bool set to false if the value is stale. +func (t *Transport) cachedResponse(req *http.Request) (*http.Response, bool, error) { cachedVal, ok := t.Cache.Get(t.cacheKey(req)) - if !ok { - return + if !ok && len(cachedVal) == 0 { + return nil, false, nil } - b := bytes.NewBuffer(cachedVal) - return http.ReadResponse(bufio.NewReader(b), req) + resp, err := http.ReadResponse(bufio.NewReader(b), req) + if err != nil { + return nil, false, err + } + return resp, ok, nil } // Transport is an implementation of http.RoundTripper that will return values from a cache @@ -145,10 +151,13 @@ func (t *Transport) RoundTrip(req *http.Request) (resp *http.Response, err error cacheable := cacheKey != "" - var cachedResp *http.Response + var ( + cachedResp *http.Response + hasCachedResp bool + ) if cacheable { - cachedResp, err = t.cachedResponse(req) - if err == nil && cachedResp != nil && t.AlwaysUseCachedResponse != nil && t.AlwaysUseCachedResponse(req, cacheKey) { + cachedResp, hasCachedResp, err = t.cachedResponse(req) + if err == nil && hasCachedResp && t.AlwaysUseCachedResponse != nil && t.AlwaysUseCachedResponse(req, cacheKey) { return cachedResp, nil } } else { @@ -161,13 +170,16 @@ func (t *Transport) RoundTrip(req *http.Request) (resp *http.Response, err error transport = http.DefaultTransport } - if cacheable && cachedResp != nil && err == nil { - if t.MarkCachedResponses { - cachedResp.Header.Set(XFromCache, "1") - } + if cachedResp != nil { if t.EnableETagPair { cachedXEtag, _ = getXETags(cachedResp.Header) } + } + + if cacheable && hasCachedResp && err == nil { + if t.MarkCachedResponses { + cachedResp.Header.Set(XFromCache, "1") + } if varyMatches(cachedResp, req) { // Can only use cached value if the new request doesn't Vary significantly @@ -247,16 +259,19 @@ func (t *Transport) RoundTrip(req *http.Request) (resp *http.Response, err error t.Cache.Set(cacheKey, respBytes) } default: - var etagHash hash.Hash + var ( + etagHash hash.Hash + etag1 = cachedXEtag + etag2 string + ) + r := resp.Body if t.EnableETagPair { if etag := resp.Header.Get("etag"); etag != "" { - resp.Header.Set(XETag1, etag) - etag2 := cachedXEtag + etag1 = etag if etag2 == "" { etag2 = etag } - resp.Header.Set(XETag2, etag2) } else { etagHash = md5.New() r = struct { @@ -274,17 +289,23 @@ func (t *Transport) RoundTrip(req *http.Request) (resp *http.Response, err error OnEOF: func(r io.Reader) { if etagHash != nil { md5Str := hex.EncodeToString(etagHash.Sum(nil)) + etag2 = md5Str resp.Header.Set(XETag1, md5Str) - etag2 := cachedXEtag - if etag2 == "" { - etag2 = md5Str + resp.Header.Set(XETag2, md5Str) + if etag1 == "" { + etag1 = md5Str } - resp.Header.Set(XETag2, etag2) + } else { + resp.Header.Set(XETag1, etag1) + resp.Header.Set(XETag2, etag1) } + resp := *resp resp.Body = io.NopCloser(r) respBytes, err := httputil.DumpResponse(&resp, true) if err == nil { + // Signal any change back to the caller. + resp.Header.Set(XETag1, etag1) t.Cache.Set(cacheKey, respBytes) } }, diff --git a/httpcache_test.go b/httpcache_test.go index 118da9e..106705b 100644 --- a/httpcache_test.go +++ b/httpcache_test.go @@ -245,8 +245,9 @@ func TestEnableETagPair(t *testing.T) { { _, resp := doMethod(t, "GET", "/helloheaderasbody", map[string]string{"Hello": "world2"}) c.Assert(resp.StatusCode, qt.Equals, http.StatusOK) - c.Assert(resp.Header.Get(XETag1), qt.Equals, "61b7d44bc024f189195b549bf094fbe8") - c.Assert(resp.Header.Get(XETag2), qt.Equals, "48b21a691481958c34cc165011bdb9bc") + c.Assert(resp.Header.Get(XETag1), qt.Equals, "48b21a691481958c34cc165011bdb9bc") + c.Assert(resp.Header.Get(XETag2), qt.Equals, "61b7d44bc024f189195b549bf094fbe8") + } } @@ -277,7 +278,6 @@ func TestShouldCache(t *testing.T) { s.transport.AlwaysUseCachedResponse = func(req *http.Request, key string) bool { return true } - s.transport.ShouldCache = func(req *http.Request, resp *http.Response, key string) bool { return req.Header.Get("Hello") == "world2" } @@ -295,6 +295,28 @@ func TestShouldCache(t *testing.T) { } } +func TestStaleCachedResponse(t *testing.T) { + resetTest() + s.transport.Cache = &staleCache{} + s.transport.AlwaysUseCachedResponse = func(req *http.Request, key string) bool { + return true + } + s.transport.EnableETagPair = true + c := qt.New(t) + { + _, resp := doMethod(t, "GET", "/helloheaderasbody", map[string]string{"Hello": "world1"}) + c.Assert(resp.StatusCode, qt.Equals, http.StatusOK) + c.Assert(resp.Header.Get(XETag1), qt.Equals, "48b21a691481958c34cc165011bdb9bc") + c.Assert(resp.Header.Get(XETag2), qt.Equals, "48b21a691481958c34cc165011bdb9bc") + } + { + _, resp := doMethod(t, "GET", "/helloheaderasbody", map[string]string{"Hello": "world2"}) + c.Assert(resp.StatusCode, qt.Equals, http.StatusOK) + c.Assert(resp.Header.Get(XETag1), qt.Equals, "48b21a691481958c34cc165011bdb9bc") + c.Assert(resp.Header.Get(XETag2), qt.Equals, "61b7d44bc024f189195b549bf094fbe8") + } +} + func TestAround(t *testing.T) { resetTest() c := qt.New(t) @@ -1420,3 +1442,25 @@ func (c *memoryCache) Delete(key string) { delete(c.items, key) c.mu.Unlock() } + +var _ Cache = &staleCache{} + +type staleCache struct { + val []byte +} + +func (c *staleCache) Get(key string) ([]byte, bool) { + return c.val, false +} + +func (c *staleCache) Set(key string, resp []byte) { + c.val = resp +} + +func (c *staleCache) Delete(key string) { + c.val = nil +} + +func (c *staleCache) Size() int { + return 1 +}