Skip to content

Commit

Permalink
Support stale cache entries
Browse files Browse the repository at this point in the history
To support eTag/change detection for entries marked stale outside of this library.

Also switch the values in XETag1 and XETag2 so they are ordered by old/new.
  • Loading branch information
bep committed Jun 4, 2024
1 parent d79f2a3 commit d62773c
Show file tree
Hide file tree
Showing 2 changed files with 91 additions and 26 deletions.
67 changes: 44 additions & 23 deletions httpcache.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,13 +33,16 @@ const (
XETag1 = xEtags + "1"

// XETag2 is the key for the second eTag value.
// Note that in the cache, XETag1 and XETag2 will always be the same.
// In the Response returned from Response, XETag1 will be the cached value (old) and
// XETag2 will be the eTag value from the server (new).
XETag2 = xEtags + "2"
)

// A Cache interface is used by the Transport to store and retrieve responses.
type Cache interface {
// Get returns the []byte representation of a cached response and a bool
// set to true if the value isn't empty
// set to set to false if the key is not found or the value is stale.
Get(key string) (responseBytes []byte, ok bool)
// Set stores the []byte representation of a response against a key
Set(key string, responseBytes []byte)
Expand All @@ -65,16 +68,19 @@ func (t *Transport) cacheKey(req *http.Request) string {
}
}

// cachedResponse returns the cached http.Response for req if present, and nil
// otherwise.
func (t *Transport) cachedResponse(req *http.Request) (resp *http.Response, err error) {
// cachedResponse returns the cached http.Response for req if present and
// a bool set to false if the value is stale.
func (t *Transport) cachedResponse(req *http.Request) (*http.Response, bool, error) {
cachedVal, ok := t.Cache.Get(t.cacheKey(req))
if !ok {
return
if !ok && len(cachedVal) == 0 {
return nil, false, nil
}

b := bytes.NewBuffer(cachedVal)
return http.ReadResponse(bufio.NewReader(b), req)
resp, err := http.ReadResponse(bufio.NewReader(b), req)
if err != nil {
return nil, false, err
}
return resp, ok, nil
}

// Transport is an implementation of http.RoundTripper that will return values from a cache
Expand Down Expand Up @@ -145,10 +151,13 @@ func (t *Transport) RoundTrip(req *http.Request) (resp *http.Response, err error

cacheable := cacheKey != ""

var cachedResp *http.Response
var (
cachedResp *http.Response
hasCachedResp bool
)
if cacheable {
cachedResp, err = t.cachedResponse(req)
if err == nil && cachedResp != nil && t.AlwaysUseCachedResponse != nil && t.AlwaysUseCachedResponse(req, cacheKey) {
cachedResp, hasCachedResp, err = t.cachedResponse(req)
if err == nil && hasCachedResp && t.AlwaysUseCachedResponse != nil && t.AlwaysUseCachedResponse(req, cacheKey) {
return cachedResp, nil
}
} else {
Expand All @@ -161,13 +170,16 @@ func (t *Transport) RoundTrip(req *http.Request) (resp *http.Response, err error
transport = http.DefaultTransport
}

if cacheable && cachedResp != nil && err == nil {
if t.MarkCachedResponses {
cachedResp.Header.Set(XFromCache, "1")
}
if cachedResp != nil {
if t.EnableETagPair {
cachedXEtag, _ = getXETags(cachedResp.Header)
}
}

if cacheable && hasCachedResp && err == nil {
if t.MarkCachedResponses {
cachedResp.Header.Set(XFromCache, "1")
}

if varyMatches(cachedResp, req) {
// Can only use cached value if the new request doesn't Vary significantly
Expand Down Expand Up @@ -247,16 +259,19 @@ func (t *Transport) RoundTrip(req *http.Request) (resp *http.Response, err error
t.Cache.Set(cacheKey, respBytes)
}
default:
var etagHash hash.Hash
var (
etagHash hash.Hash
etag1 = cachedXEtag
etag2 string
)

r := resp.Body
if t.EnableETagPair {
if etag := resp.Header.Get("etag"); etag != "" {
resp.Header.Set(XETag1, etag)
etag2 := cachedXEtag
etag1 = etag
if etag2 == "" {
etag2 = etag
}
resp.Header.Set(XETag2, etag2)
} else {
etagHash = md5.New()
r = struct {
Expand All @@ -274,17 +289,23 @@ func (t *Transport) RoundTrip(req *http.Request) (resp *http.Response, err error
OnEOF: func(r io.Reader) {
if etagHash != nil {
md5Str := hex.EncodeToString(etagHash.Sum(nil))
etag2 = md5Str
resp.Header.Set(XETag1, md5Str)
etag2 := cachedXEtag
if etag2 == "" {
etag2 = md5Str
resp.Header.Set(XETag2, md5Str)
if etag1 == "" {
etag1 = md5Str
}
resp.Header.Set(XETag2, etag2)
} else {
resp.Header.Set(XETag1, etag1)
resp.Header.Set(XETag2, etag1)
}

resp := *resp
resp.Body = io.NopCloser(r)
respBytes, err := httputil.DumpResponse(&resp, true)
if err == nil {
// Signal any change back to the caller.
resp.Header.Set(XETag1, etag1)
t.Cache.Set(cacheKey, respBytes)
}
},
Expand Down
50 changes: 47 additions & 3 deletions httpcache_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -245,8 +245,9 @@ func TestEnableETagPair(t *testing.T) {
{
_, resp := doMethod(t, "GET", "/helloheaderasbody", map[string]string{"Hello": "world2"})
c.Assert(resp.StatusCode, qt.Equals, http.StatusOK)
c.Assert(resp.Header.Get(XETag1), qt.Equals, "61b7d44bc024f189195b549bf094fbe8")
c.Assert(resp.Header.Get(XETag2), qt.Equals, "48b21a691481958c34cc165011bdb9bc")
c.Assert(resp.Header.Get(XETag1), qt.Equals, "48b21a691481958c34cc165011bdb9bc")
c.Assert(resp.Header.Get(XETag2), qt.Equals, "61b7d44bc024f189195b549bf094fbe8")

}
}

Expand Down Expand Up @@ -277,7 +278,6 @@ func TestShouldCache(t *testing.T) {
s.transport.AlwaysUseCachedResponse = func(req *http.Request, key string) bool {
return true
}

s.transport.ShouldCache = func(req *http.Request, resp *http.Response, key string) bool {
return req.Header.Get("Hello") == "world2"
}
Expand All @@ -295,6 +295,28 @@ func TestShouldCache(t *testing.T) {
}
}

func TestStaleCachedResponse(t *testing.T) {
resetTest()
s.transport.Cache = &staleCache{}
s.transport.AlwaysUseCachedResponse = func(req *http.Request, key string) bool {
return true
}
s.transport.EnableETagPair = true
c := qt.New(t)
{
_, resp := doMethod(t, "GET", "/helloheaderasbody", map[string]string{"Hello": "world1"})
c.Assert(resp.StatusCode, qt.Equals, http.StatusOK)
c.Assert(resp.Header.Get(XETag1), qt.Equals, "48b21a691481958c34cc165011bdb9bc")
c.Assert(resp.Header.Get(XETag2), qt.Equals, "48b21a691481958c34cc165011bdb9bc")
}
{
_, resp := doMethod(t, "GET", "/helloheaderasbody", map[string]string{"Hello": "world2"})
c.Assert(resp.StatusCode, qt.Equals, http.StatusOK)
c.Assert(resp.Header.Get(XETag1), qt.Equals, "48b21a691481958c34cc165011bdb9bc")
c.Assert(resp.Header.Get(XETag2), qt.Equals, "61b7d44bc024f189195b549bf094fbe8")
}
}

func TestAround(t *testing.T) {
resetTest()
c := qt.New(t)
Expand Down Expand Up @@ -1420,3 +1442,25 @@ func (c *memoryCache) Delete(key string) {
delete(c.items, key)
c.mu.Unlock()
}

var _ Cache = &staleCache{}

type staleCache struct {
val []byte
}

func (c *staleCache) Get(key string) ([]byte, bool) {
return c.val, false
}

func (c *staleCache) Set(key string, resp []byte) {
c.val = resp
}

func (c *staleCache) Delete(key string) {
c.val = nil
}

func (c *staleCache) Size() int {
return 1
}

0 comments on commit d62773c

Please sign in to comment.