Skip to content

Commit

Permalink
Add AroundRoundTrip option, trim the API
Browse files Browse the repository at this point in the history
  • Loading branch information
bep committed May 17, 2024
1 parent 10eb476 commit ff2042c
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 38 deletions.
61 changes: 31 additions & 30 deletions httpcache.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,9 +45,9 @@ func cacheKey(req *http.Request) string {
}
}

// CachedResponse returns the cached http.Response for req if present, and nil
// cachedResponse returns the cached http.Response for req if present, and nil
// otherwise.
func CachedResponse(c Cache, req *http.Request) (resp *http.Response, err error) {
func cachedResponse(c Cache, req *http.Request) (resp *http.Response, err error) {
cachedVal, ok := c.Get(cacheKey(req))
if !ok {
return
Expand All @@ -57,37 +57,37 @@ func CachedResponse(c Cache, req *http.Request) (resp *http.Response, err error)
return http.ReadResponse(bufio.NewReader(b), req)
}

// MemoryCache is an implemtation of Cache that stores responses in an in-memory map.
type MemoryCache struct {
// memoryCache is an implemtation of Cache that stores responses in an in-memory map.
type memoryCache struct {
mu sync.RWMutex
items map[string][]byte
}

// Get returns the []byte representation of the response and true if present, false if not
func (c *MemoryCache) Get(key string) (resp []byte, ok bool) {
func (c *memoryCache) Get(key string) (resp []byte, ok bool) {
c.mu.RLock()
resp, ok = c.items[key]
c.mu.RUnlock()
return resp, ok
}

// Set saves response resp to the cache with key
func (c *MemoryCache) Set(key string, resp []byte) {
func (c *memoryCache) Set(key string, resp []byte) {
c.mu.Lock()
c.items[key] = resp
c.mu.Unlock()
}

// Delete removes key from the cache
func (c *MemoryCache) Delete(key string) {
func (c *memoryCache) Delete(key string) {
c.mu.Lock()
delete(c.items, key)
c.mu.Unlock()
}

// NewMemoryCache returns a new Cache that will store items in an in-memory map
func NewMemoryCache() *MemoryCache {
c := &MemoryCache{items: map[string][]byte{}}
// newMemoryCache returns a new Cache that will store items in an in-memory map
func newMemoryCache() *memoryCache {
c := &memoryCache{items: map[string][]byte{}}
return c
}

Expand All @@ -98,20 +98,18 @@ type Transport struct {
// The RoundTripper interface actually used to make requests
// If nil, http.DefaultTransport is used
Transport http.RoundTripper
Cache Cache

// The Cache interface used to store and retrieve responses.
Cache Cache

// If true, responses returned from the cache will be given an extra header, X-From-Cache
MarkCachedResponses bool
}

// NewTransport returns a new Transport with the
// provided Cache implementation and MarkCachedResponses set to true
func NewTransport(c Cache) *Transport {
return &Transport{Cache: c, MarkCachedResponses: true}
}

// Client returns an *http.Client that caches responses.
func (t *Transport) Client() *http.Client {
return &http.Client{Transport: t}
// AroundRoundTrip is an optional func.
// If set, the Transport will call AroundRoundTrip at the start of RoundTrip
// and defer the returned func until the end of RoundTrip.
// Typically used to implement a lock that is held for the duration of the RoundTrip.
AroundRoundTrip func(key string) func()
}

// varyMatches will return false unless all of the cached values for the headers listed in Vary
Expand All @@ -136,10 +134,13 @@ func varyMatches(cachedResp *http.Response, req *http.Request) bool {
// will be returned.
func (t *Transport) RoundTrip(req *http.Request) (resp *http.Response, err error) {
cacheKey := cacheKey(req)
if f := t.AroundRoundTrip; f != nil {
defer f(cacheKey)()
}
cacheable := (req.Method == "GET" || req.Method == "HEAD") && req.Header.Get("range") == ""
var cachedResp *http.Response
if cacheable {
cachedResp, err = CachedResponse(t.Cache, req)
cachedResp, err = cachedResponse(t.Cache, req)
} else {
// Need to invalidate an existing value
t.Cache.Delete(cacheKey)
Expand Down Expand Up @@ -254,8 +255,8 @@ func (t *Transport) RoundTrip(req *http.Request) (resp *http.Response, err error
// ErrNoDateHeader indicates that the HTTP headers contained no Date header.
var ErrNoDateHeader = errors.New("no Date header")

// Date parses and returns the value of the Date header.
func Date(respHeaders http.Header) (date time.Time, err error) {
// date parses and returns the value of the date header.
func date(respHeaders http.Header) (date time.Time, err error) {
dateHeader := respHeaders.Get("date")
if dateHeader == "" {
err = ErrNoDateHeader
Expand Down Expand Up @@ -299,7 +300,7 @@ func getFreshness(respHeaders, reqHeaders http.Header) (freshness int) {
return fresh
}

date, err := Date(respHeaders)
date, err := date(respHeaders)
if err != nil {
return stale
}
Expand Down Expand Up @@ -398,7 +399,7 @@ func canStaleOnError(respHeaders, reqHeaders http.Header) bool {
}

if lifetime >= 0 {
date, err := Date(respHeaders)
date, err := date(respHeaders)
if err != nil {
return false
}
Expand Down Expand Up @@ -541,9 +542,9 @@ func (r *cachingReadCloser) Close() error {
return r.R.Close()
}

// NewMemoryCacheTransport returns a new Transport using the in-memory cache implementation
func NewMemoryCacheTransport() *Transport {
c := NewMemoryCache()
t := NewTransport(c)
// newMemoryCacheTransport returns a new Transport using the in-memory cache implementation
func newMemoryCacheTransport() *Transport {
c := newMemoryCache()
t := &Transport{Cache: c, MarkCachedResponses: true}
return t
}
16 changes: 8 additions & 8 deletions httpcache_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ func TestMain(m *testing.M) {
}

func setup() {
tp := NewMemoryCacheTransport()
tp := newMemoryCacheTransport()
client := http.Client{Transport: tp}
s.transport = tp
s.client = client
Expand Down Expand Up @@ -165,7 +165,7 @@ func teardown() {
}

func resetTest() {
s.transport.Cache = NewMemoryCache()
s.transport.Cache = newMemoryCache()
clock = &realClock{}
}

Expand Down Expand Up @@ -1206,7 +1206,7 @@ func TestStaleIfErrorRequest(t *testing.T) {
},
err: nil,
}
tp := NewMemoryCacheTransport()
tp := newMemoryCacheTransport()
tp.Transport = &tmock

// First time, response is cached on success
Expand Down Expand Up @@ -1251,7 +1251,7 @@ func TestStaleIfErrorRequestLifetime(t *testing.T) {
},
err: nil,
}
tp := NewMemoryCacheTransport()
tp := newMemoryCacheTransport()
tp.Transport = &tmock

// First time, response is cached on success
Expand Down Expand Up @@ -1314,7 +1314,7 @@ func TestStaleIfErrorResponse(t *testing.T) {
},
err: nil,
}
tp := NewMemoryCacheTransport()
tp := newMemoryCacheTransport()
tp.Transport = &tmock

// First time, response is cached on success
Expand Down Expand Up @@ -1358,7 +1358,7 @@ func TestStaleIfErrorResponseLifetime(t *testing.T) {
},
err: nil,
}
tp := NewMemoryCacheTransport()
tp := newMemoryCacheTransport()
tp.Transport = &tmock

// First time, response is cached on success
Expand Down Expand Up @@ -1412,7 +1412,7 @@ func TestStaleIfErrorKeepsStatus(t *testing.T) {
},
err: nil,
}
tp := NewMemoryCacheTransport()
tp := newMemoryCacheTransport()
tp.Transport = &tmock

// First time, response is cached on success
Expand Down Expand Up @@ -1456,7 +1456,7 @@ func TestClientTimeout(t *testing.T) {
}
resetTest()
client := &http.Client{
Transport: NewMemoryCacheTransport(),
Transport: newMemoryCacheTransport(),
Timeout: time.Second,
}
started := time.Now()
Expand Down

0 comments on commit ff2042c

Please sign in to comment.