From ff2042caafc52ca8e48c30def149b6acf8bfe09b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bj=C3=B8rn=20Erik=20Pedersen?= Date: Fri, 17 May 2024 17:48:12 +0200 Subject: [PATCH] Add AroundRoundTrip option, trim the API --- httpcache.go | 61 ++++++++++++++++++++++++----------------------- httpcache_test.go | 16 ++++++------- 2 files changed, 39 insertions(+), 38 deletions(-) diff --git a/httpcache.go b/httpcache.go index be4e7e8..efabe03 100644 --- a/httpcache.go +++ b/httpcache.go @@ -45,9 +45,9 @@ func cacheKey(req *http.Request) string { } } -// CachedResponse returns the cached http.Response for req if present, and nil +// cachedResponse returns the cached http.Response for req if present, and nil // otherwise. -func CachedResponse(c Cache, req *http.Request) (resp *http.Response, err error) { +func cachedResponse(c Cache, req *http.Request) (resp *http.Response, err error) { cachedVal, ok := c.Get(cacheKey(req)) if !ok { return @@ -57,14 +57,14 @@ func CachedResponse(c Cache, req *http.Request) (resp *http.Response, err error) return http.ReadResponse(bufio.NewReader(b), req) } -// MemoryCache is an implemtation of Cache that stores responses in an in-memory map. -type MemoryCache struct { +// memoryCache is an implemtation of Cache that stores responses in an in-memory map. +type memoryCache struct { mu sync.RWMutex items map[string][]byte } // Get returns the []byte representation of the response and true if present, false if not -func (c *MemoryCache) Get(key string) (resp []byte, ok bool) { +func (c *memoryCache) Get(key string) (resp []byte, ok bool) { c.mu.RLock() resp, ok = c.items[key] c.mu.RUnlock() @@ -72,22 +72,22 @@ func (c *MemoryCache) Get(key string) (resp []byte, ok bool) { } // Set saves response resp to the cache with key -func (c *MemoryCache) Set(key string, resp []byte) { +func (c *memoryCache) Set(key string, resp []byte) { c.mu.Lock() c.items[key] = resp c.mu.Unlock() } // Delete removes key from the cache -func (c *MemoryCache) Delete(key string) { +func (c *memoryCache) Delete(key string) { c.mu.Lock() delete(c.items, key) c.mu.Unlock() } -// NewMemoryCache returns a new Cache that will store items in an in-memory map -func NewMemoryCache() *MemoryCache { - c := &MemoryCache{items: map[string][]byte{}} +// newMemoryCache returns a new Cache that will store items in an in-memory map +func newMemoryCache() *memoryCache { + c := &memoryCache{items: map[string][]byte{}} return c } @@ -98,20 +98,18 @@ type Transport struct { // The RoundTripper interface actually used to make requests // If nil, http.DefaultTransport is used Transport http.RoundTripper - Cache Cache + + // The Cache interface used to store and retrieve responses. + Cache Cache + // If true, responses returned from the cache will be given an extra header, X-From-Cache MarkCachedResponses bool -} -// NewTransport returns a new Transport with the -// provided Cache implementation and MarkCachedResponses set to true -func NewTransport(c Cache) *Transport { - return &Transport{Cache: c, MarkCachedResponses: true} -} - -// Client returns an *http.Client that caches responses. -func (t *Transport) Client() *http.Client { - return &http.Client{Transport: t} + // AroundRoundTrip is an optional func. + // If set, the Transport will call AroundRoundTrip at the start of RoundTrip + // and defer the returned func until the end of RoundTrip. + // Typically used to implement a lock that is held for the duration of the RoundTrip. + AroundRoundTrip func(key string) func() } // varyMatches will return false unless all of the cached values for the headers listed in Vary @@ -136,10 +134,13 @@ func varyMatches(cachedResp *http.Response, req *http.Request) bool { // will be returned. func (t *Transport) RoundTrip(req *http.Request) (resp *http.Response, err error) { cacheKey := cacheKey(req) + if f := t.AroundRoundTrip; f != nil { + defer f(cacheKey)() + } cacheable := (req.Method == "GET" || req.Method == "HEAD") && req.Header.Get("range") == "" var cachedResp *http.Response if cacheable { - cachedResp, err = CachedResponse(t.Cache, req) + cachedResp, err = cachedResponse(t.Cache, req) } else { // Need to invalidate an existing value t.Cache.Delete(cacheKey) @@ -254,8 +255,8 @@ func (t *Transport) RoundTrip(req *http.Request) (resp *http.Response, err error // ErrNoDateHeader indicates that the HTTP headers contained no Date header. var ErrNoDateHeader = errors.New("no Date header") -// Date parses and returns the value of the Date header. -func Date(respHeaders http.Header) (date time.Time, err error) { +// date parses and returns the value of the date header. +func date(respHeaders http.Header) (date time.Time, err error) { dateHeader := respHeaders.Get("date") if dateHeader == "" { err = ErrNoDateHeader @@ -299,7 +300,7 @@ func getFreshness(respHeaders, reqHeaders http.Header) (freshness int) { return fresh } - date, err := Date(respHeaders) + date, err := date(respHeaders) if err != nil { return stale } @@ -398,7 +399,7 @@ func canStaleOnError(respHeaders, reqHeaders http.Header) bool { } if lifetime >= 0 { - date, err := Date(respHeaders) + date, err := date(respHeaders) if err != nil { return false } @@ -541,9 +542,9 @@ func (r *cachingReadCloser) Close() error { return r.R.Close() } -// NewMemoryCacheTransport returns a new Transport using the in-memory cache implementation -func NewMemoryCacheTransport() *Transport { - c := NewMemoryCache() - t := NewTransport(c) +// newMemoryCacheTransport returns a new Transport using the in-memory cache implementation +func newMemoryCacheTransport() *Transport { + c := newMemoryCache() + t := &Transport{Cache: c, MarkCachedResponses: true} return t } diff --git a/httpcache_test.go b/httpcache_test.go index 61ed4f8..79167d6 100644 --- a/httpcache_test.go +++ b/httpcache_test.go @@ -37,7 +37,7 @@ func TestMain(m *testing.M) { } func setup() { - tp := NewMemoryCacheTransport() + tp := newMemoryCacheTransport() client := http.Client{Transport: tp} s.transport = tp s.client = client @@ -165,7 +165,7 @@ func teardown() { } func resetTest() { - s.transport.Cache = NewMemoryCache() + s.transport.Cache = newMemoryCache() clock = &realClock{} } @@ -1206,7 +1206,7 @@ func TestStaleIfErrorRequest(t *testing.T) { }, err: nil, } - tp := NewMemoryCacheTransport() + tp := newMemoryCacheTransport() tp.Transport = &tmock // First time, response is cached on success @@ -1251,7 +1251,7 @@ func TestStaleIfErrorRequestLifetime(t *testing.T) { }, err: nil, } - tp := NewMemoryCacheTransport() + tp := newMemoryCacheTransport() tp.Transport = &tmock // First time, response is cached on success @@ -1314,7 +1314,7 @@ func TestStaleIfErrorResponse(t *testing.T) { }, err: nil, } - tp := NewMemoryCacheTransport() + tp := newMemoryCacheTransport() tp.Transport = &tmock // First time, response is cached on success @@ -1358,7 +1358,7 @@ func TestStaleIfErrorResponseLifetime(t *testing.T) { }, err: nil, } - tp := NewMemoryCacheTransport() + tp := newMemoryCacheTransport() tp.Transport = &tmock // First time, response is cached on success @@ -1412,7 +1412,7 @@ func TestStaleIfErrorKeepsStatus(t *testing.T) { }, err: nil, } - tp := NewMemoryCacheTransport() + tp := newMemoryCacheTransport() tp.Transport = &tmock // First time, response is cached on success @@ -1456,7 +1456,7 @@ func TestClientTimeout(t *testing.T) { } resetTest() client := &http.Client{ - Transport: NewMemoryCacheTransport(), + Transport: newMemoryCacheTransport(), Timeout: time.Second, } started := time.Now()