From 32ded857b7161eeead9ffefa4f0bb5f6cc2d8560 Mon Sep 17 00:00:00 2001
From: Ameer Ghani
Date: Mon, 4 Mar 2024 13:10:03 -0500
Subject: [PATCH] Fix race when initializing ringBufferRateLimiter

This closes https://github.com/mholt/caddy-ratelimit/issues/36.

This fixes a race condition between ringBufferRateLimiter creation and its
insertion into a map, by locking the entire map while we get or insert a
ringBufferRateLimiter.

I have replaced the use of sync.Map with a plain
`map[string]*ringBufferRateLimiter` guarded by a `sync.Mutex`. They are passed
around together in a `rateLimitersMap` struct. I've factored the map logic out
into methods of `rateLimitersMap`, which enables careful use of
`defer rlm.limitersMu.Unlock()` so that a `panic()` cannot leave the lock held.

We didn't see a need for sync.Map. Its documentation recommends a plain map
with separate locking for better type safety, and none of its suggested use
cases apply here: https://pkg.go.dev/sync#Map. Let me know if I'm
misunderstanding the use case (very possible!).

I've also removed the sync.Pool, for now. Since ringBufferRateLimiter creation
and insertion are fully synchronized, I didn't see a need for it.

Note that some of the defensive refactoring is not strictly required--I have a
change that preserves the existing data structures
(https://github.com/divviup/caddy-ratelimit/pull/1/commits/65ad951ea012a5410dff297efa9da6f769e20dc0),
but I think the suggested changeset is an overall improvement in
maintainability.

Some discussion of the performance impact, with profiles, is in
https://github.com/divviup/caddy-ratelimit/pull/1. TL;DR: no meaningful impact
on CPU, memory, or latency.

This implementation could be optimized by replacing the plain mutex with an
RWMutex, but that would be a marginal improvement (if any) in exchange for much
more complicated locking semantics.
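For illustration, the gist of the fix is that the lookup and the construction
now happen under one lock, so no request can ever observe a limiter that has
been stored but not yet initialized. Below is a minimal, self-contained sketch
of that pattern; the types and names are simplified stand-ins for this
illustration only, not the actual ringBufferRateLimiter or handler code:

```go
// Minimal sketch of the synchronized get-or-insert pattern (simplified,
// illustrative types -- not the code in this patch).
package main

import (
	"fmt"
	"sync"
	"time"
)

type limiter struct {
	window time.Duration
	ring   []time.Time
}

type limitersMap struct {
	mu       sync.Mutex
	limiters map[string]*limiter
}

// getOrInsert holds the map lock across both the lookup and the construction,
// so no goroutine can see a limiter that was stored but not yet initialized.
func (m *limitersMap) getOrInsert(key string, maxEvents int, window time.Duration) *limiter {
	m.mu.Lock()
	defer m.mu.Unlock() // released even if construction panics

	if l, ok := m.limiters[key]; ok {
		return l
	}
	l := &limiter{window: window, ring: make([]time.Time, maxEvents)}
	m.limiters[key] = l
	return l
}

func main() {
	m := &limitersMap{limiters: make(map[string]*limiter)}

	var wg sync.WaitGroup
	for i := 0; i < 8; i++ {
		wg.Add(1)
		go func() {
			defer wg.Done()
			// concurrent callers all get the same, fully constructed limiter
			l := m.getOrInsert("static", 100, time.Minute)
			_ = l.ring
		}()
	}
	wg.Wait()

	fmt.Println(len(m.limiters)) // 1
}
```

The `defer` in getOrInsert mirrors the real methods in this change: the lock is
released even if limiter construction panics.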
---
 distributed.go      |  27 +-----------
 distributed_test.go |  17 +++++---
 handler.go          |  60 +------------------------
 ratelimit.go        | 104 +++++++++++++++++++++++++++++++++++++++-----
 ringbuffer.go       |  13 +++---
 ringbuffer_test.go  |   3 +-
 6 files changed, 113 insertions(+), 111 deletions(-)

diff --git a/distributed.go b/distributed.go
index bbe286a..6258880 100644
--- a/distributed.go
+++ b/distributed.go
@@ -90,9 +90,9 @@ func (h Handler) syncDistributedWrite(ctx context.Context) error {
 	// iterate all rate limit zones
 	rateLimits.Range(func(zoneName, value interface{}) bool {
 		zoneNameStr := zoneName.(string)
-		zoneLimiters := value.(*sync.Map)
+		zoneLimiters := value.(*rateLimitersMap)
 
-		state.Zones[zoneNameStr] = rlStateForZone(zoneLimiters, state.Timestamp)
+		state.Zones[zoneNameStr] = zoneLimiters.rlStateForZone(state.Timestamp)
 
 		return true
 	})
@@ -100,29 +100,6 @@
 	return writeRateLimitState(ctx, state, h.Distributed.instanceID, h.storage)
 }
 
-func rlStateForZone(zoneLimiters *sync.Map, timestamp time.Time) map[string]rlStateValue {
-	state := make(map[string]rlStateValue)
-
-	// iterate all limiters within zone
-	zoneLimiters.Range(func(key, value interface{}) bool {
-		if value == nil {
-			return true
-		}
-		rl := value.(*ringBufferRateLimiter)
-
-		count, oldestEvent := rl.Count(timestamp)
-
-		state[key.(string)] = rlStateValue{
-			Count:       count,
-			OldestEvent: oldestEvent,
-		}
-
-		return true
-	})
-
-	return state
-}
-
 func writeRateLimitState(ctx context.Context, state rlState, instanceID string, storage certmagic.Storage) error {
 	buf := gobBufPool.Get().(*bytes.Buffer)
 	buf.Reset()
diff --git a/distributed_test.go b/distributed_test.go
index 330229c..9e39697 100644
--- a/distributed_test.go
+++ b/distributed_test.go
@@ -19,7 +19,6 @@ import (
 	"fmt"
 	"os"
 	"strings"
-	"sync"
 	"testing"
 	"time"
 
@@ -96,8 +95,7 @@ func TestDistributed(t *testing.T) {
 			if err != nil {
 				t.Fatal("failed to parse duration")
 			}
-			var simulatedPeer ringBufferRateLimiter
-			simulatedPeer.initialize(maxEvents, parsedDuration)
+			simulatedPeer := newRingBufferRateLimiter(maxEvents, parsedDuration)
 
 			for i := 0; i < testCase.peerRequests; i++ {
 				if when := simulatedPeer.When(); when != 0 {
@@ -105,13 +103,13 @@
 				}
 			}
 
-			zoneLimiters := new(sync.Map)
-			zoneLimiters.Store("static", &simulatedPeer)
+			zoneLimiters := newRateLimiterMap()
+			zoneLimiters.limiters["static"] = simulatedPeer
 
 			rlState := rlState{
 				Timestamp: testCase.peerStateTimeStamp,
 				Zones: map[string]map[string]rlStateValue{
-					zone: rlStateForZone(zoneLimiters, now()),
+					zone: zoneLimiters.rlStateForZone(now()),
 				},
 			}
 
@@ -134,6 +132,13 @@
 			"module": "file_system",
 			"root": "%s"
 		},
+		"logging": {
+			"logs": {
+				"default": {
+					"level": "DEBUG"
+				}
+			}
+		},
 		"apps": {
 			"http": {
 				"servers": {
diff --git a/handler.go b/handler.go
index a7ea0c9..7ad283e 100644
--- a/handler.go
+++ b/handler.go
@@ -23,7 +23,6 @@ import (
 	"net/http"
 	"sort"
 	"strconv"
-	"sync"
 	"time"
 
 	"github.com/caddyserver/caddy/v2"
@@ -166,21 +165,7 @@ func (h Handler) ServeHTTP(w http.ResponseWriter, r *http.Request, next caddyhtt
 
 		// make key for the individual rate limiter in this zone
 		key := repl.ReplaceAll(rl.Key, "")
-
-		// the API for sync.Pool is unfortunate: there is no LoadOrNew() method
-		// which allocates/constructs a value only if needed, so we always need
-		// to pre-allocate the value even if we never use it; we should be able
-		// to relieve some memory pressure by putting unused values back into a
-		// pool...
-		limiter := ringBufPool.Get().(*ringBufferRateLimiter)
-		if val, loaded := rl.limiters.LoadOrStore(key, limiter); loaded {
-			ringBufPool.Put(limiter) // didn't use; save for next time
-			limiter = val.(*ringBufferRateLimiter)
-		} else {
-			// as another side-effect of sync.Map's bad API, avoid all the
-			// work of initializing the ring buffer unless we have to
-			limiter.initialize(rl.MaxEvents, time.Duration(rl.Window))
-		}
+		limiter := rl.limitersMap.getOrInsert(key, rl.MaxEvents, time.Duration(rl.Window))
 
 		if h.Distributed == nil {
 			// internal rate limiter only
@@ -248,42 +233,8 @@ func (h Handler) sweepRateLimiters(ctx context.Context) {
 	for {
 		select {
 		case <-cleanerTicker.C:
-			// iterate all rate limit zones
 			rateLimits.Range(func(key, value interface{}) bool {
-				rlMap := value.(*sync.Map)
-
-				// iterate all static and dynamic rate limiters within zone
-				rlMap.Range(func(key, value interface{}) bool {
-					if value == nil {
-						return true
-					}
-					rl := value.(*ringBufferRateLimiter)
-
-					rl.mu.Lock()
-					// no point in keeping a ring buffer of size 0 around
-					if len(rl.ring) == 0 {
-						rl.mu.Unlock()
-						rlMap.Delete(key)
-						return true
-					}
-					// get newest event in ring (should come right before oldest)
-					cursorNewest := rl.cursor - 1
-					if cursorNewest < 0 {
-						cursorNewest = len(rl.ring) - 1
-					}
-					newest := rl.ring[cursorNewest]
-					window := rl.window
-					rl.mu.Unlock()
-
-					// if newest event in memory is outside the window,
-					// the entire ring has expired and can be forgotten
-					if newest.Add(window).Before(now()) {
-						rlMap.Delete(key)
-					}
-
-					return true
-				})
-
+				value.(*rateLimitersMap).sweep()
 				return true
 			})
 
@@ -296,13 +247,6 @@
 // rateLimits persists RL zones through config changes.
 var rateLimits = caddy.NewUsagePool()
 
-// ringBufPool reduces allocations from unneeded rate limiters.
-var ringBufPool = sync.Pool{
-	New: func() interface{} {
-		return new(ringBufferRateLimiter)
-	},
-}
-
 // Interface guards
 var (
 	_ caddy.Provisioner = (*Handler)(nil)
diff --git a/ratelimit.go b/ratelimit.go
index 6d25100..ca9a24d 100644
--- a/ratelimit.go
+++ b/ratelimit.go
@@ -46,7 +46,7 @@ type RateLimit struct {
 
 	zoneName string
 
-	limiters *sync.Map
+	limitersMap *rateLimitersMap
 }
 
 func (rl *RateLimit) provision(ctx caddy.Context, name string) error {
@@ -69,18 +69,11 @@ func (rl *RateLimit) provision(ctx caddy.Context, name string) error {
 	}
 
 	// ensure rate limiter state endures across config changes
-	rl.limiters = new(sync.Map)
-	if val, loaded := rateLimits.LoadOrStore(name, rl.limiters); loaded {
-		rl.limiters = val.(*sync.Map)
+	rl.limitersMap = newRateLimiterMap()
+	if val, loaded := rateLimits.LoadOrStore(name, rl.limitersMap); loaded {
+		rl.limitersMap = val.(*rateLimitersMap)
 	}
-
-	// update existing rate limiters with new settings
-	rl.limiters.Range(func(key, value interface{}) bool {
-		limiter := value.(*ringBufferRateLimiter)
-		limiter.SetMaxEvents(rl.MaxEvents)
-		limiter.SetWindow(time.Duration(rl.Window))
-		return true
-	})
+	rl.limitersMap.updateAll(rl.MaxEvents, time.Duration(rl.Window))
 
 	return nil
 }
@@ -88,3 +81,90 @@ func (rl *RateLimit) provision(ctx caddy.Context, name string) error {
 func (rl *RateLimit) permissiveness() float64 {
 	return float64(rl.MaxEvents) / float64(rl.Window)
 }
+
+type rateLimitersMap struct {
+	limiters   map[string]*ringBufferRateLimiter
+	limitersMu sync.Mutex
+}
+
+func newRateLimiterMap() *rateLimitersMap {
+	var rlm rateLimitersMap
+	rlm.limiters = make(map[string]*ringBufferRateLimiter)
+	return &rlm
+}
+
+// getOrInsert returns an existing rate limiter from the map, or inserts a new
+// one with the desired settings and returns it.
+func (rlm *rateLimitersMap) getOrInsert(key string, maxEvents int, window time.Duration) *ringBufferRateLimiter {
+	rlm.limitersMu.Lock()
+	defer rlm.limitersMu.Unlock()
+
+	rateLimiter, ok := rlm.limiters[key]
+	if !ok {
+		newRateLimiter := newRingBufferRateLimiter(maxEvents, window)
+		rlm.limiters[key] = newRateLimiter
+		return newRateLimiter
+	}
+	return rateLimiter
+}
+
+// updateAll updates existing rate limiters with new settings.
+func (rlm *rateLimitersMap) updateAll(maxEvents int, window time.Duration) {
+	rlm.limitersMu.Lock()
+	defer rlm.limitersMu.Unlock()
+
+	for _, limiter := range rlm.limiters {
+		limiter.SetMaxEvents(maxEvents)
+		limiter.SetWindow(time.Duration(window))
+	}
+}
+
+// sweep cleans up expired rate limit states.
+func (rlm *rateLimitersMap) sweep() {
+	rlm.limitersMu.Lock()
+	defer rlm.limitersMu.Unlock()
+
+	for key, rl := range rlm.limiters {
+		func(rl *ringBufferRateLimiter) {
+			rl.mu.Lock()
+			defer rl.mu.Unlock()
+
+			// no point in keeping a ring buffer of size 0 around
+			if len(rl.ring) == 0 {
+				delete(rlm.limiters, key)
+				return
+			}
+
+			// get newest event in ring (should come right before oldest)
+			cursorNewest := rl.cursor - 1
+			if cursorNewest < 0 {
+				cursorNewest = len(rl.ring) - 1
+			}
+			newest := rl.ring[cursorNewest]
+			window := rl.window
+
+			// if newest event in memory is outside the window,
+			// the entire ring has expired and can be forgotten
+			if newest.Add(window).Before(now()) {
+				delete(rlm.limiters, key)
+			}
+		}(rl)
+	}
+}
+
+// rlStateForZone returns the state of all rate limiters in the map.
+func (rlm *rateLimitersMap) rlStateForZone(timestamp time.Time) map[string]rlStateValue {
+	state := make(map[string]rlStateValue)
+
+	rlm.limitersMu.Lock()
+	defer rlm.limitersMu.Unlock()
+	for key, rl := range rlm.limiters {
+		count, oldestEvent := rl.Count(timestamp)
+		state[key] = rlStateValue{
+			Count:       count,
+			OldestEvent: oldestEvent,
+		}
+	}
+
+	return state
+}
diff --git a/ringbuffer.go b/ringbuffer.go
index 6319c97..d764643 100644
--- a/ringbuffer.go
+++ b/ringbuffer.go
@@ -30,16 +30,12 @@ type ringBufferRateLimiter struct {
 	cursor int // always points to the oldest timestamp
 }
 
-// initialize sets up the rate limiter if it isn't already, allowing maxEvents
+// newRingBufferRateLimiter sets up a new rate limiter, allowing maxEvents
 // in a sliding window of size window. If maxEvents is 0, no events are
 // allowed. If window is 0, all events are allowed. It panics if maxEvents or
-// window are less than zero. This method is idempotent.
-func (r *ringBufferRateLimiter) initialize(maxEvents int, window time.Duration) {
-	r.mu.Lock()
-	defer r.mu.Unlock()
-	if r.window != 0 || r.ring != nil {
-		return
-	}
+// window are less than zero.
+func newRingBufferRateLimiter(maxEvents int, window time.Duration) *ringBufferRateLimiter {
+	r := new(ringBufferRateLimiter)
 	if maxEvents < 0 {
 		panic("maxEvents cannot be less than zero")
 	}
@@ -48,6 +44,7 @@ func (r *ringBufferRateLimiter) initialize(maxEvents int, window time.Duration)
 	}
 	r.window = window
 	r.ring = make([]time.Time, maxEvents) // TODO: we can probably pool these
+	return r
 }
 
 // When returns the duration before the next allowable event; it does not block.
diff --git a/ringbuffer_test.go b/ringbuffer_test.go
index c87feac..c9ee530 100644
--- a/ringbuffer_test.go
+++ b/ringbuffer_test.go
@@ -23,9 +23,8 @@ func TestCount(t *testing.T) {
 	initTime()
 	var zeroTime time.Time
 
-	var rb ringBufferRateLimiter
 	bufSize := 10
-	rb.initialize(bufSize, time.Duration(bufSize)*time.Second)
+	rb := newRingBufferRateLimiter(bufSize, time.Duration(bufSize)*time.Second)
 
 	startTime := now()
 	count, oldest := rb.Count(now())