From 00236142b39cf9312abbd660e6be9a9e0041161d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bj=C3=B8rn=20Erik=20Pedersen?= Date: Wed, 22 May 2024 19:12:45 +0200 Subject: [PATCH] Add ShouldCache option --- httpcache.go | 5 ++++- httpcache_test.go | 25 +++++++++++++++++++++++++ 2 files changed, 29 insertions(+), 1 deletion(-) diff --git a/httpcache.go b/httpcache.go index d662112..9faa754 100644 --- a/httpcache.go +++ b/httpcache.go @@ -105,6 +105,9 @@ type Transport struct { // a successful response from the cache will be returned without connecting to the server. AlwaysUseCachedResponse func(req *http.Request, key string) bool + // ShouldCache is an optional func that when it returns false, the response will not be cached. + ShouldCache func(req *http.Request, resp *http.Response, key string) bool + // Around is an optional func. // If set, the Transport will call Around at the start of RoundTrip // and defer the returned func until the end of RoundTrip. @@ -228,7 +231,7 @@ func (t *Transport) RoundTrip(req *http.Request) (resp *http.Response, err error } } - if cacheable && canStore(parseCacheControl(req.Header), parseCacheControl(resp.Header)) { + if cacheable && (t.ShouldCache == nil || t.ShouldCache(req, resp, cacheKey)) && canStore(parseCacheControl(req.Header), parseCacheControl(resp.Header)) { for _, varyKey := range headerAllCommaSepValues(resp.Header, "vary") { varyKey = http.CanonicalHeaderKey(varyKey) fakeHeader := "X-Varied-" + varyKey diff --git a/httpcache_test.go b/httpcache_test.go index 26df046..60a7e37 100644 --- a/httpcache_test.go +++ b/httpcache_test.go @@ -269,6 +269,31 @@ func TestAlwaysUseCachedResponse(t *testing.T) { } } +// ShouldCache +func TestShouldCache(t *testing.T) { + resetTest() + c := qt.New(t) + s.transport.AlwaysUseCachedResponse = func(req *http.Request, key string) bool { + return true + } + + s.transport.ShouldCache = func(req *http.Request, resp *http.Response, key string) bool { + return req.Header.Get("Hello") == "world2" + } + { + s, _ := doMethod(t, "GET", "/helloheaderasbody", map[string]string{"Hello": "world1"}) + c.Assert(s, qt.Equals, "world1") + } + { + s, _ := doMethod(t, "GET", "/helloheaderasbody", map[string]string{"Hello": "world2"}) + c.Assert(s, qt.Equals, "world2") + } + { + s, _ := doMethod(t, "GET", "/helloheaderasbody", map[string]string{"Hello": "world3"}) + c.Assert(s, qt.Equals, "world2") + } +} + func TestAround(t *testing.T) { resetTest() c := qt.New(t)