From 02503210b401df5b0735638212aea3ae61d5a839 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bj=C3=B8rn=20Erik=20Pedersen?= Date: Tue, 21 May 2024 17:51:00 +0200 Subject: [PATCH] Add AlwaysUseCachedResponse option func --- httpcache.go | 7 +++++++ httpcache_test.go | 21 +++++++++++++++++++++ 2 files changed, 28 insertions(+) diff --git a/httpcache.go b/httpcache.go index 1e560cc..d662112 100644 --- a/httpcache.go +++ b/httpcache.go @@ -101,6 +101,10 @@ type Transport struct { // An empty string signals that this request should not be cached. CacheKey func(req *http.Request) string + // AlwaysUseCachedResponse is an optional func that when it returns true + // a successful response from the cache will be returned without connecting to the server. + AlwaysUseCachedResponse func(req *http.Request, key string) bool + // Around is an optional func. // If set, the Transport will call Around at the start of RoundTrip // and defer the returned func until the end of RoundTrip. @@ -141,6 +145,9 @@ func (t *Transport) RoundTrip(req *http.Request) (resp *http.Response, err error var cachedResp *http.Response if cacheable { cachedResp, err = t.cachedResponse(req) + if err == nil && cachedResp != nil && t.AlwaysUseCachedResponse != nil && t.AlwaysUseCachedResponse(req, cacheKey) { + return cachedResp, nil + } } else { // Need to invalidate an existing value t.Cache.Delete(cacheKey) diff --git a/httpcache_test.go b/httpcache_test.go index 491dccc..26df046 100644 --- a/httpcache_test.go +++ b/httpcache_test.go @@ -248,6 +248,27 @@ func TestEnableETagPair(t *testing.T) { } } +func TestAlwaysUseCachedResponse(t *testing.T) { + resetTest() + c := qt.New(t) + s.transport.AlwaysUseCachedResponse = func(req *http.Request, key string) bool { + return req.Header.Get("Hello") == "world2" + } + + { + s, _ := doMethod(t, "GET", "/helloheaderasbody", map[string]string{"Hello": "world1"}) + c.Assert(s, qt.Equals, "world1") + } + { + s, _ := doMethod(t, "GET", "/helloheaderasbody", map[string]string{"Hello": "world2"}) + c.Assert(s, qt.Equals, "world1") + } + { + s, _ := doMethod(t, "GET", "/helloheaderasbody", map[string]string{"Hello": "world3"}) + c.Assert(s, qt.Equals, "world3") + } +} + func TestAround(t *testing.T) { resetTest() c := qt.New(t)