Skip to content

Commit

Permalink
Misc adjustments
Browse files Browse the repository at this point in the history
* Add EnableETagPair option to add a pair of eTags, even if the server does not provide one
* Add CacheKey option
* Remove redundant nil check
* Optionally allow caching for other HTTP methods than GET and HEAD
  • Loading branch information
bep committed May 20, 2024
1 parent ef54744 commit 2e6edbb
Showing 1 changed file with 92 additions and 22 deletions.
114 changes: 92 additions & 22 deletions httpcache.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,10 @@ package httpcache
import (
"bufio"
"bytes"
"crypto/md5"
"encoding/hex"
"errors"
"hash"
"io"
"net/http"
"net/http/httputil"
Expand All @@ -23,6 +26,15 @@ const (
transparent
// XFromCache is the header added to responses that are returned from the cache
XFromCache = "X-From-Cache"

// xEtags is the prefix for the header with the custom etag pair set in the cached response.
xEtags = "X-Etags-"

// XETag1 is the key for the first eTag value.
XETag1 = xEtags + "1"

// XETag2 is the key for the second eTag value.
XETag2 = xEtags + "2"
)

// A Cache interface is used by the Transport to store and retrieve responses.
Expand All @@ -37,7 +49,16 @@ type Cache interface {
}

// cacheKey returns the cache key for req.
func cacheKey(req *http.Request) string {
func (t *Transport) cacheKey(req *http.Request) string {
if t.CacheKey != nil {
return t.CacheKey(req)
}

cacheable := (req.Method != http.MethodHead || req.Method == "HEAD") && req.Header.Get("range") == ""
if !cacheable {
return ""
}

if req.Method == http.MethodGet {
return req.URL.String()
} else {
Expand All @@ -47,8 +68,8 @@ func cacheKey(req *http.Request) string {

// cachedResponse returns the cached http.Response for req if present, and nil
// otherwise.
func cachedResponse(c Cache, req *http.Request) (resp *http.Response, err error) {
cachedVal, ok := c.Get(cacheKey(req))
func (t *Transport) cachedResponse(req *http.Request) (resp *http.Response, err error) {
cachedVal, ok := t.Cache.Get(t.cacheKey(req))
if !ok {
return
}
Expand Down Expand Up @@ -105,11 +126,21 @@ type Transport struct {
// If true, responses returned from the cache will be given an extra header, X-From-Cache
MarkCachedResponses bool

// if EnableETagPair is true, the Transport will store the pair of eTags in the response header.
// These are stored in the X-Etags-1 and X-Etags-2 headers.
// If these are different, the response has been modified.
// If the server does not return an eTag, the MD5 hash of the response body is used.
EnableETagPair bool

// CacheKey is an optional func that returns the key to use to store the response.
// An empty string signals that this request should not be cached.
CacheKey func(req *http.Request) string

// Around is an optional func.
// If set, the Transport will call Around at the start of RoundTrip
// and defer the returned func until the end of RoundTrip.
// Typically used to implement a lock that is held for the duration of the RoundTrip.
Around func(key string) func()
Around func(req *http.Request, key string) func()
}

// varyMatches will return false unless all of the cached values for the headers listed in Vary
Expand All @@ -133,14 +164,18 @@ func varyMatches(cachedResp *http.Response, req *http.Request) bool {
// to give the server a chance to respond with NotModified. If this happens, then the cached Response
// will be returned.
func (t *Transport) RoundTrip(req *http.Request) (resp *http.Response, err error) {
cacheKey := cacheKey(req)
cacheKey := t.cacheKey(req)
if f := t.Around; f != nil {
defer f(cacheKey)()
defer f(req, cacheKey)()
}
cacheable := (req.Method == "GET" || req.Method == "HEAD") && req.Header.Get("range") == ""

var cachedXEtag string

cacheable := cacheKey != ""

var cachedResp *http.Response
if cacheable {
cachedResp, err = cachedResponse(t.Cache, req)
cachedResp, err = t.cachedResponse(req)
} else {
// Need to invalidate an existing value
t.Cache.Delete(cacheKey)
Expand All @@ -155,6 +190,9 @@ func (t *Transport) RoundTrip(req *http.Request) (resp *http.Response, err error
if t.MarkCachedResponses {
cachedResp.Header.Set(XFromCache, "1")
}
if t.EnableETagPair {
cachedXEtag, _ = getXETags(cachedResp.Header)
}

if varyMatches(cachedResp, req) {
// Can only use cached value if the new request doesn't Vary significantly
Expand Down Expand Up @@ -185,15 +223,16 @@ func (t *Transport) RoundTrip(req *http.Request) (resp *http.Response, err error
}

resp, err = transport.RoundTrip(req)
if err == nil && req.Method == "GET" && resp.StatusCode == http.StatusNotModified {

if err == nil && req.Method != http.MethodHead && resp.StatusCode == http.StatusNotModified {
// Replace the 304 response with the one from cache, but update with some new headers
endToEndHeaders := getEndToEndHeaders(resp.Header)
for _, header := range endToEndHeaders {
cachedResp.Header[header] = resp.Header[header]
}
resp = cachedResp
} else if (err != nil || (cachedResp != nil && resp.StatusCode >= 500)) &&
req.Method == "GET" && canStaleOnError(cachedResp.Header, req.Header) {
} else if (err != nil || resp.StatusCode >= 500) &&
req.Method != http.MethodHead && canStaleOnError(cachedResp.Header, req.Header) {
// In case of transport failure and stale-if-error activated, returns cached content
// when available
return cachedResp, nil
Expand Down Expand Up @@ -227,24 +266,51 @@ func (t *Transport) RoundTrip(req *http.Request) (resp *http.Response, err error
}
}
switch req.Method {
case "GET":
// Delay caching until EOF is reached.
resp.Body = &cachingReadCloser{
R: resp.Body,
case http.MethodHead:
respBytes, err := httputil.DumpResponse(resp, true)
if err == nil {
t.Cache.Set(cacheKey, respBytes)
}
default:
var etagHash hash.Hash
r := resp.Body
if t.EnableETagPair {
if etag := resp.Header.Get("etag"); etag != "" {
resp.Header.Set(XETag1, etag)
resp.Header.Set(XETag2, cachedXEtag)
} else {
etagHash = md5.New()
r = struct {
io.Reader
io.Closer
}{
io.TeeReader(r, etagHash),
resp.Body,
}
}
}

r = &cachingReadCloser{
R: r,
OnEOF: func(r io.Reader) {
if etagHash != nil {
md5Str := hex.EncodeToString(etagHash.Sum(nil))
resp.Header.Set(XETag1, md5Str)
resp.Header.Set(XETag2, cachedXEtag)

}
resp := *resp
resp.Body = io.NopCloser(r)
respBytes, err := httputil.DumpResponse(&resp, true)
if err == nil {
t.Cache.Set(cacheKey, respBytes)
}
},
buf: &bytes.Buffer{},
}
default:
respBytes, err := httputil.DumpResponse(resp, true)
if err == nil {
t.Cache.Set(cacheKey, respBytes)
}
// Delay caching until EOF is reached.
resp.Body = r

}
} else {
t.Cache.Delete(cacheKey)
Expand Down Expand Up @@ -278,6 +344,10 @@ type timer interface {

var clock timer = &realClock{}

func getXETags(h http.Header) (string, string) {
return h.Get(XETag1), h.Get(XETag2)
}

// getFreshness will return one of fresh/stale/transparent based on the cache-control
// values of the request and the response
//
Expand Down Expand Up @@ -522,7 +592,7 @@ type cachingReadCloser struct {
// OnEOF is called with a copy of the content of R when EOF is reached.
OnEOF func(io.Reader)

buf bytes.Buffer // buf stores a copy of the content of R.
buf *bytes.Buffer // buf stores a copy of the content of R.
}

// Read reads the next len(p) bytes from R or until R is drained. The
Expand All @@ -533,7 +603,7 @@ func (r *cachingReadCloser) Read(p []byte) (n int, err error) {
n, err = r.R.Read(p)
r.buf.Write(p[:n])
if err == io.EOF {
r.OnEOF(bytes.NewReader(r.buf.Bytes()))
r.OnEOF(r.buf)
}
return n, err
}
Expand Down

0 comments on commit 2e6edbb

Please sign in to comment.