Skip to content

Commit

Permalink
Misc adjustments
Browse files Browse the repository at this point in the history
* Add CacheKey option
* Remove redundant nil check
* Optionally allow caching for other HTTP methods than GET and HEAD
  • Loading branch information
bep committed May 19, 2024
1 parent ef54744 commit 5292b6e
Showing 1 changed file with 34 additions and 17 deletions.
51 changes: 34 additions & 17 deletions httpcache.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,16 @@ type Cache interface {
}

// cacheKey returns the cache key for req.
func cacheKey(req *http.Request) string {
func (t *Transport) cacheKey(req *http.Request) string {
if t.CacheKey != nil {
return t.CacheKey(req)
}

cacheable := (req.Method != http.MethodHead || req.Method == "HEAD") && req.Header.Get("range") == ""
if !cacheable {
return ""
}

if req.Method == http.MethodGet {
return req.URL.String()
} else {
Expand All @@ -47,8 +56,8 @@ func cacheKey(req *http.Request) string {

// cachedResponse returns the cached http.Response for req if present, and nil
// otherwise.
func cachedResponse(c Cache, req *http.Request) (resp *http.Response, err error) {
cachedVal, ok := c.Get(cacheKey(req))
func (t *Transport) cachedResponse(req *http.Request) (resp *http.Response, err error) {
cachedVal, ok := t.Cache.Get(t.cacheKey(req))
if !ok {
return
}
Expand Down Expand Up @@ -105,11 +114,15 @@ type Transport struct {
// If true, responses returned from the cache will be given an extra header, X-From-Cache
MarkCachedResponses bool

// CacheKey is an optional func that returns the key to use to store the response.
// An empty string signals that this request should not be cached.
CacheKey func(req *http.Request) string

// Around is an optional func.
// If set, the Transport will call Around at the start of RoundTrip
// and defer the returned func until the end of RoundTrip.
// Typically used to implement a lock that is held for the duration of the RoundTrip.
Around func(key string) func()
Around func(req *http.Request, key string) func()
}

// varyMatches will return false unless all of the cached values for the headers listed in Vary
Expand All @@ -133,14 +146,16 @@ func varyMatches(cachedResp *http.Response, req *http.Request) bool {
// to give the server a chance to respond with NotModified. If this happens, then the cached Response
// will be returned.
func (t *Transport) RoundTrip(req *http.Request) (resp *http.Response, err error) {
cacheKey := cacheKey(req)
cacheKey := t.cacheKey(req)
if f := t.Around; f != nil {
defer f(cacheKey)()
defer f(req, cacheKey)()
}
cacheable := (req.Method == "GET" || req.Method == "HEAD") && req.Header.Get("range") == ""

cacheable := cacheKey != ""

var cachedResp *http.Response
if cacheable {
cachedResp, err = cachedResponse(t.Cache, req)
cachedResp, err = t.cachedResponse(req)
} else {
// Need to invalidate an existing value
t.Cache.Delete(cacheKey)
Expand Down Expand Up @@ -185,15 +200,16 @@ func (t *Transport) RoundTrip(req *http.Request) (resp *http.Response, err error
}

resp, err = transport.RoundTrip(req)
if err == nil && req.Method == "GET" && resp.StatusCode == http.StatusNotModified {

if err == nil && req.Method != http.MethodHead && resp.StatusCode == http.StatusNotModified {
// Replace the 304 response with the one from cache, but update with some new headers
endToEndHeaders := getEndToEndHeaders(resp.Header)
for _, header := range endToEndHeaders {
cachedResp.Header[header] = resp.Header[header]
}
resp = cachedResp
} else if (err != nil || (cachedResp != nil && resp.StatusCode >= 500)) &&
req.Method == "GET" && canStaleOnError(cachedResp.Header, req.Header) {
} else if (err != nil || resp.StatusCode >= 500) &&
req.Method != http.MethodHead && canStaleOnError(cachedResp.Header, req.Header) {
// In case of transport failure and stale-if-error activated, returns cached content
// when available
return cachedResp, nil
Expand Down Expand Up @@ -227,7 +243,12 @@ func (t *Transport) RoundTrip(req *http.Request) (resp *http.Response, err error
}
}
switch req.Method {
case "GET":
case http.MethodHead:
respBytes, err := httputil.DumpResponse(resp, true)
if err == nil {
t.Cache.Set(cacheKey, respBytes)
}
default:
// Delay caching until EOF is reached.
resp.Body = &cachingReadCloser{
R: resp.Body,
Expand All @@ -240,11 +261,7 @@ func (t *Transport) RoundTrip(req *http.Request) (resp *http.Response, err error
}
},
}
default:
respBytes, err := httputil.DumpResponse(resp, true)
if err == nil {
t.Cache.Set(cacheKey, respBytes)
}

}
} else {
t.Cache.Delete(cacheKey)
Expand Down

0 comments on commit 5292b6e

Please sign in to comment.