From 5292b6e4940978a022107ae6b0a1b723dc9a1300 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bj=C3=B8rn=20Erik=20Pedersen?= Date: Sat, 18 May 2024 11:58:56 +0200 Subject: [PATCH] Misc adjustments * Add CacheKey option * Remove redundant nil check * Optionally allow caching for other HTTP methods than GET and HEAD --- httpcache.go | 51 ++++++++++++++++++++++++++++++++++----------------- 1 file changed, 34 insertions(+), 17 deletions(-) diff --git a/httpcache.go b/httpcache.go index b44c99e..8956215 100644 --- a/httpcache.go +++ b/httpcache.go @@ -37,7 +37,16 @@ type Cache interface { } // cacheKey returns the cache key for req. -func cacheKey(req *http.Request) string { +func (t *Transport) cacheKey(req *http.Request) string { + if t.CacheKey != nil { + return t.CacheKey(req) + } + + cacheable := (req.Method != http.MethodHead || req.Method == "HEAD") && req.Header.Get("range") == "" + if !cacheable { + return "" + } + if req.Method == http.MethodGet { return req.URL.String() } else { @@ -47,8 +56,8 @@ func cacheKey(req *http.Request) string { // cachedResponse returns the cached http.Response for req if present, and nil // otherwise. -func cachedResponse(c Cache, req *http.Request) (resp *http.Response, err error) { - cachedVal, ok := c.Get(cacheKey(req)) +func (t *Transport) cachedResponse(req *http.Request) (resp *http.Response, err error) { + cachedVal, ok := t.Cache.Get(t.cacheKey(req)) if !ok { return } @@ -105,11 +114,15 @@ type Transport struct { // If true, responses returned from the cache will be given an extra header, X-From-Cache MarkCachedResponses bool + // CacheKey is an optional func that returns the key to use to store the response. + // An empty string signals that this request should not be cached. + CacheKey func(req *http.Request) string + // Around is an optional func. // If set, the Transport will call Around at the start of RoundTrip // and defer the returned func until the end of RoundTrip. // Typically used to implement a lock that is held for the duration of the RoundTrip. - Around func(key string) func() + Around func(req *http.Request, key string) func() } // varyMatches will return false unless all of the cached values for the headers listed in Vary @@ -133,14 +146,16 @@ func varyMatches(cachedResp *http.Response, req *http.Request) bool { // to give the server a chance to respond with NotModified. If this happens, then the cached Response // will be returned. func (t *Transport) RoundTrip(req *http.Request) (resp *http.Response, err error) { - cacheKey := cacheKey(req) + cacheKey := t.cacheKey(req) if f := t.Around; f != nil { - defer f(cacheKey)() + defer f(req, cacheKey)() } - cacheable := (req.Method == "GET" || req.Method == "HEAD") && req.Header.Get("range") == "" + + cacheable := cacheKey != "" + var cachedResp *http.Response if cacheable { - cachedResp, err = cachedResponse(t.Cache, req) + cachedResp, err = t.cachedResponse(req) } else { // Need to invalidate an existing value t.Cache.Delete(cacheKey) @@ -185,15 +200,16 @@ func (t *Transport) RoundTrip(req *http.Request) (resp *http.Response, err error } resp, err = transport.RoundTrip(req) - if err == nil && req.Method == "GET" && resp.StatusCode == http.StatusNotModified { + + if err == nil && req.Method != http.MethodHead && resp.StatusCode == http.StatusNotModified { // Replace the 304 response with the one from cache, but update with some new headers endToEndHeaders := getEndToEndHeaders(resp.Header) for _, header := range endToEndHeaders { cachedResp.Header[header] = resp.Header[header] } resp = cachedResp - } else if (err != nil || (cachedResp != nil && resp.StatusCode >= 500)) && - req.Method == "GET" && canStaleOnError(cachedResp.Header, req.Header) { + } else if (err != nil || resp.StatusCode >= 500) && + req.Method != http.MethodHead && canStaleOnError(cachedResp.Header, req.Header) { // In case of transport failure and stale-if-error activated, returns cached content // when available return cachedResp, nil @@ -227,7 +243,12 @@ func (t *Transport) RoundTrip(req *http.Request) (resp *http.Response, err error } } switch req.Method { - case "GET": + case http.MethodHead: + respBytes, err := httputil.DumpResponse(resp, true) + if err == nil { + t.Cache.Set(cacheKey, respBytes) + } + default: // Delay caching until EOF is reached. resp.Body = &cachingReadCloser{ R: resp.Body, @@ -240,11 +261,7 @@ func (t *Transport) RoundTrip(req *http.Request) (resp *http.Response, err error } }, } - default: - respBytes, err := httputil.DumpResponse(resp, true) - if err == nil { - t.Cache.Set(cacheKey, respBytes) - } + } } else { t.Cache.Delete(cacheKey)