Skip to content

Commit

Permalink
Add ShouldCache option
Browse files Browse the repository at this point in the history
  • Loading branch information
bep committed May 22, 2024
1 parent 7a0d97d commit d79f2a3
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 1 deletion.
5 changes: 4 additions & 1 deletion httpcache.go
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,9 @@ type Transport struct {
// a successful response from the cache will be returned without connecting to the server.
AlwaysUseCachedResponse func(req *http.Request, key string) bool

// ShouldCache is an optional func that when it returns false, the response will not be cached.
ShouldCache func(req *http.Request, resp *http.Response, key string) bool

// Around is an optional func.
// If set, the Transport will call Around at the start of RoundTrip
// and defer the returned func until the end of RoundTrip.
Expand Down Expand Up @@ -228,7 +231,7 @@ func (t *Transport) RoundTrip(req *http.Request) (resp *http.Response, err error
}
}

if cacheable && canStore(parseCacheControl(req.Header), parseCacheControl(resp.Header)) {
if cacheable && (t.ShouldCache == nil || t.ShouldCache(req, resp, cacheKey)) && canStore(parseCacheControl(req.Header), parseCacheControl(resp.Header)) {
for _, varyKey := range headerAllCommaSepValues(resp.Header, "vary") {
varyKey = http.CanonicalHeaderKey(varyKey)
fakeHeader := "X-Varied-" + varyKey
Expand Down
26 changes: 26 additions & 0 deletions httpcache_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -178,6 +178,8 @@ func cacheSize() int {
func resetTest() {
s.transport.Cache = newMemoryCache()
s.transport.CacheKey = nil
s.transport.AlwaysUseCachedResponse = nil
s.transport.ShouldCache = nil
s.transport.EnableETagPair = false
s.transport.MarkCachedResponses = false
clock = &realClock{}
Expand Down Expand Up @@ -269,6 +271,30 @@ func TestAlwaysUseCachedResponse(t *testing.T) {
}
}

func TestShouldCache(t *testing.T) {
resetTest()
c := qt.New(t)
s.transport.AlwaysUseCachedResponse = func(req *http.Request, key string) bool {
return true
}

s.transport.ShouldCache = func(req *http.Request, resp *http.Response, key string) bool {
return req.Header.Get("Hello") == "world2"
}
{
s, _ := doMethod(t, "GET", "/helloheaderasbody", map[string]string{"Hello": "world1"})
c.Assert(s, qt.Equals, "world1")
}
{
s, _ := doMethod(t, "GET", "/helloheaderasbody", map[string]string{"Hello": "world2"})
c.Assert(s, qt.Equals, "world2")
}
{
s, _ := doMethod(t, "GET", "/helloheaderasbody", map[string]string{"Hello": "world3"})
c.Assert(s, qt.Equals, "world2")
}
}

func TestAround(t *testing.T) {
resetTest()
c := qt.New(t)
Expand Down

0 comments on commit d79f2a3

Please sign in to comment.