From d79f2a37eca5ec6084176e32e59e10dcfbf4f213 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bj=C3=B8rn=20Erik=20Pedersen?= Date: Wed, 22 May 2024 19:12:45 +0200 Subject: [PATCH] Add ShouldCache option --- httpcache.go | 5 ++++- httpcache_test.go | 26 ++++++++++++++++++++++++++ 2 files changed, 30 insertions(+), 1 deletion(-) diff --git a/httpcache.go b/httpcache.go index d662112..9faa754 100644 --- a/httpcache.go +++ b/httpcache.go @@ -105,6 +105,9 @@ type Transport struct { // a successful response from the cache will be returned without connecting to the server. AlwaysUseCachedResponse func(req *http.Request, key string) bool + // ShouldCache is an optional func that when it returns false, the response will not be cached. + ShouldCache func(req *http.Request, resp *http.Response, key string) bool + // Around is an optional func. // If set, the Transport will call Around at the start of RoundTrip // and defer the returned func until the end of RoundTrip. @@ -228,7 +231,7 @@ func (t *Transport) RoundTrip(req *http.Request) (resp *http.Response, err error } } - if cacheable && canStore(parseCacheControl(req.Header), parseCacheControl(resp.Header)) { + if cacheable && (t.ShouldCache == nil || t.ShouldCache(req, resp, cacheKey)) && canStore(parseCacheControl(req.Header), parseCacheControl(resp.Header)) { for _, varyKey := range headerAllCommaSepValues(resp.Header, "vary") { varyKey = http.CanonicalHeaderKey(varyKey) fakeHeader := "X-Varied-" + varyKey diff --git a/httpcache_test.go b/httpcache_test.go index 26df046..118da9e 100644 --- a/httpcache_test.go +++ b/httpcache_test.go @@ -178,6 +178,8 @@ func cacheSize() int { func resetTest() { s.transport.Cache = newMemoryCache() s.transport.CacheKey = nil + s.transport.AlwaysUseCachedResponse = nil + s.transport.ShouldCache = nil s.transport.EnableETagPair = false s.transport.MarkCachedResponses = false clock = &realClock{} @@ -269,6 +271,30 @@ func TestAlwaysUseCachedResponse(t *testing.T) { } } +func TestShouldCache(t *testing.T) { + resetTest() + c := qt.New(t) + s.transport.AlwaysUseCachedResponse = func(req *http.Request, key string) bool { + return true + } + + s.transport.ShouldCache = func(req *http.Request, resp *http.Response, key string) bool { + return req.Header.Get("Hello") == "world2" + } + { + s, _ := doMethod(t, "GET", "/helloheaderasbody", map[string]string{"Hello": "world1"}) + c.Assert(s, qt.Equals, "world1") + } + { + s, _ := doMethod(t, "GET", "/helloheaderasbody", map[string]string{"Hello": "world2"}) + c.Assert(s, qt.Equals, "world2") + } + { + s, _ := doMethod(t, "GET", "/helloheaderasbody", map[string]string{"Hello": "world3"}) + c.Assert(s, qt.Equals, "world2") + } +} + func TestAround(t *testing.T) { resetTest() c := qt.New(t)