Fix race when initializing ringBufferRateLimiter #43

Merged · 1 commit · Mar 13, 2024
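In the code being replaced, ServeHTTP published a pooled, zero-valued ringBufferRateLimiter into a sync.Map via LoadOrStore and only initialized it afterwards, so a concurrent request could load the limiter before its ring buffer existed. The fix constructs the limiter fully before publishing it, under a single mutex. A minimal, self-contained sketch of that construct-then-publish pattern (simplified names; not the module's exact code):

package main

import (
	"fmt"
	"sync"
	"time"
)

// limiter stands in for ringBufferRateLimiter: its ring must be
// allocated before any goroutine is allowed to use it.
type limiter struct {
	window time.Duration
	ring   []time.Time
}

type limiterMap struct {
	mu sync.Mutex
	m  map[string]*limiter
}

// getOrInsert fully constructs the limiter before it becomes visible to
// other goroutines, so no caller can observe a half-built value.
func (lm *limiterMap) getOrInsert(key string, maxEvents int, window time.Duration) *limiter {
	lm.mu.Lock()
	defer lm.mu.Unlock()
	if l, ok := lm.m[key]; ok {
		return l
	}
	l := &limiter{window: window, ring: make([]time.Time, maxEvents)}
	lm.m[key] = l
	return l
}

func main() {
	lm := &limiterMap{m: make(map[string]*limiter)}
	var wg sync.WaitGroup
	for i := 0; i < 8; i++ {
		wg.Add(1)
		go func() {
			defer wg.Done()
			l := lm.getOrInsert("static", 10, time.Second)
			if len(l.ring) != 10 {
				panic("observed uninitialized limiter") // cannot happen here
			}
		}()
	}
	wg.Wait()
	fmt.Println("all goroutines saw a fully constructed limiter")
}

Because the map's mutex is held across both the lookup and the construction, every goroutine either finds a fully built limiter or builds one itself; nothing half-initialized is ever reachable.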
27 changes: 2 additions & 25 deletions distributed.go
@@ -90,39 +90,16 @@ func (h Handler) syncDistributedWrite(ctx context.Context) error {
 	// iterate all rate limit zones
 	rateLimits.Range(func(zoneName, value interface{}) bool {
 		zoneNameStr := zoneName.(string)
-		zoneLimiters := value.(*sync.Map)
+		zoneLimiters := value.(*rateLimitersMap)

-		state.Zones[zoneNameStr] = rlStateForZone(zoneLimiters, state.Timestamp)
+		state.Zones[zoneNameStr] = zoneLimiters.rlStateForZone(state.Timestamp)

 		return true
 	})

 	return writeRateLimitState(ctx, state, h.Distributed.instanceID, h.storage)
 }
-
-func rlStateForZone(zoneLimiters *sync.Map, timestamp time.Time) map[string]rlStateValue {
-	state := make(map[string]rlStateValue)
-
-	// iterate all limiters within zone
-	zoneLimiters.Range(func(key, value interface{}) bool {
-		if value == nil {
-			return true
-		}
-		rl := value.(*ringBufferRateLimiter)
-
-		count, oldestEvent := rl.Count(timestamp)
-
-		state[key.(string)] = rlStateValue{
-			Count:       count,
-			OldestEvent: oldestEvent,
-		}
-
-		return true
-	})
-
-	return state
-}

 func writeRateLimitState(ctx context.Context, state rlState, instanceID string, storage certmagic.Storage) error {
 	buf := gobBufPool.Get().(*bytes.Buffer)
 	buf.Reset()
17 changes: 11 additions & 6 deletions distributed_test.go
@@ -19,7 +19,6 @@ import (
 	"fmt"
 	"os"
 	"strings"
-	"sync"
 	"testing"
 	"time"

@@ -96,22 +95,21 @@ func TestDistributed(t *testing.T) {
 			if err != nil {
 				t.Fatal("failed to parse duration")
 			}
-			var simulatedPeer ringBufferRateLimiter
-			simulatedPeer.initialize(maxEvents, parsedDuration)
+			simulatedPeer := newRingBufferRateLimiter(maxEvents, parsedDuration)

 			for i := 0; i < testCase.peerRequests; i++ {
 				if when := simulatedPeer.When(); when != 0 {
 					t.Fatalf("event should be allowed")
 				}
 			}

-			zoneLimiters := new(sync.Map)
-			zoneLimiters.Store("static", &simulatedPeer)
+			zoneLimiters := newRateLimiterMap()
+			zoneLimiters.limiters["static"] = simulatedPeer

 			rlState := rlState{
 				Timestamp: testCase.peerStateTimeStamp,
 				Zones: map[string]map[string]rlStateValue{
-					zone: rlStateForZone(zoneLimiters, now()),
+					zone: zoneLimiters.rlStateForZone(now()),
 				},
 			}

@@ -134,6 +132,13 @@
 			"module": "file_system",
 			"root": "%s"
 		},
+		"logging": {
+			"logs": {
+				"default": {
+					"level": "DEBUG"
+				}
+			}
+		},
 		"apps": {
 			"http": {
 				"servers": {
60 changes: 2 additions & 58 deletions handler.go
@@ -23,7 +23,6 @@ import (
 	"net/http"
 	"sort"
 	"strconv"
-	"sync"
 	"time"

 	"github.com/caddyserver/caddy/v2"
@@ -166,21 +165,7 @@ func (h Handler) ServeHTTP(w http.ResponseWriter, r *http.Request, next caddyhttp.Handler) error {

 		// make key for the individual rate limiter in this zone
 		key := repl.ReplaceAll(rl.Key, "")
-
-		// the API for sync.Pool is unfortunate: there is no LoadOrNew() method
-		// which allocates/constructs a value only if needed, so we always need
-		// to pre-allocate the value even if we never use it; we should be able
-		// to relieve some memory pressure by putting unused values back into a
-		// pool...
-		limiter := ringBufPool.Get().(*ringBufferRateLimiter)
-		if val, loaded := rl.limiters.LoadOrStore(key, limiter); loaded {
-			ringBufPool.Put(limiter) // didn't use; save for next time
-			limiter = val.(*ringBufferRateLimiter)
-		} else {
-			// as another side-effect of sync.Map's bad API, avoid all the
-			// work of initializing the ring buffer unless we have to
-			limiter.initialize(rl.MaxEvents, time.Duration(rl.Window))
-		}
+		limiter := rl.limitersMap.getOrInsert(key, rl.MaxEvents, time.Duration(rl.Window))

 		if h.Distributed == nil {
 			// internal rate limiter only
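One interleaving that the removed LoadOrStore pattern allowed (illustrative timeline, not taken from the source):

// G1: limiter := ringBufPool.Get().(*ringBufferRateLimiter) // zero-valued
// G1: rl.limiters.LoadOrStore(key, limiter)                 // published; loaded == false
// G2: val, loaded := rl.limiters.LoadOrStore(key, ...)      // loaded == true
// G2: val.(*ringBufferRateLimiter).When()                   // observes uninitialized ring: the race
// G1: limiter.initialize(rl.MaxEvents, window)              // too late
//
// getOrInsert closes this window by holding one lock across lookup and construction.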
@@ -248,42 +233,8 @@ func (h Handler) sweepRateLimiters(ctx context.Context) {
 	for {
 		select {
 		case <-cleanerTicker.C:
 			// iterate all rate limit zones
 			rateLimits.Range(func(key, value interface{}) bool {
-				rlMap := value.(*sync.Map)
-
-				// iterate all static and dynamic rate limiters within zone
-				rlMap.Range(func(key, value interface{}) bool {
-					if value == nil {
-						return true
-					}
-					rl := value.(*ringBufferRateLimiter)
-
-					rl.mu.Lock()
-					// no point in keeping a ring buffer of size 0 around
-					if len(rl.ring) == 0 {
-						rl.mu.Unlock()
-						rlMap.Delete(key)
-						return true
-					}
-					// get newest event in ring (should come right before oldest)
-					cursorNewest := rl.cursor - 1
-					if cursorNewest < 0 {
-						cursorNewest = len(rl.ring) - 1
-					}
-					newest := rl.ring[cursorNewest]
-					window := rl.window
-					rl.mu.Unlock()
-
-					// if newest event in memory is outside the window,
-					// the entire ring has expired and can be forgotten
-					if newest.Add(window).Before(now()) {
-						rlMap.Delete(key)
-					}
-
-					return true
-				})
-
+				value.(*rateLimitersMap).sweep()
 				return true
 			})

@@ -296,13 +247,6 @@ func (h Handler) sweepRateLimiters(ctx context.Context) {

 // rateLimits persists RL zones through config changes.
 var rateLimits = caddy.NewUsagePool()

-// ringBufPool reduces allocations from unneeded rate limiters.
-var ringBufPool = sync.Pool{
-	New: func() interface{} {
-		return new(ringBufferRateLimiter)
-	},
-}
-
 // Interface guards
 var (
 	_ caddy.Provisioner = (*Handler)(nil)
104 changes: 92 additions & 12 deletions ratelimit.go
@@ -46,7 +46,7 @@ type RateLimit struct {

 	zoneName string

-	limiters *sync.Map
+	limitersMap *rateLimitersMap
 }

 func (rl *RateLimit) provision(ctx caddy.Context, name string) error {
@@ -69,22 +69,102 @@ func (rl *RateLimit) provision(ctx caddy.Context, name string) error {
 	}

 	// ensure rate limiter state endures across config changes
-	rl.limiters = new(sync.Map)
-	if val, loaded := rateLimits.LoadOrStore(name, rl.limiters); loaded {
-		rl.limiters = val.(*sync.Map)
+	rl.limitersMap = newRateLimiterMap()
+	if val, loaded := rateLimits.LoadOrStore(name, rl.limitersMap); loaded {
+		rl.limitersMap = val.(*rateLimitersMap)
 	}

 	// update existing rate limiters with new settings
-	rl.limiters.Range(func(key, value interface{}) bool {
-		limiter := value.(*ringBufferRateLimiter)
-		limiter.SetMaxEvents(rl.MaxEvents)
-		limiter.SetWindow(time.Duration(rl.Window))
-		return true
-	})
+	rl.limitersMap.updateAll(rl.MaxEvents, time.Duration(rl.Window))

 	return nil
 }

 func (rl *RateLimit) permissiveness() float64 {
 	return float64(rl.MaxEvents) / float64(rl.Window)
 }
+
+type rateLimitersMap struct {
+	limiters   map[string]*ringBufferRateLimiter
+	limitersMu sync.Mutex
+}
+
+func newRateLimiterMap() *rateLimitersMap {
+	var rlm rateLimitersMap
+	rlm.limiters = make(map[string]*ringBufferRateLimiter)
+	return &rlm
+}
+
+// getOrInsert returns an existing rate limiter from the map, or inserts a new
+// one with the desired settings and returns it.
+func (rlm *rateLimitersMap) getOrInsert(key string, maxEvents int, window time.Duration) *ringBufferRateLimiter {
+	rlm.limitersMu.Lock()
+	defer rlm.limitersMu.Unlock()
+
+	rateLimiter, ok := rlm.limiters[key]
+	if !ok {
+		newRateLimiter := newRingBufferRateLimiter(maxEvents, window)
+		rlm.limiters[key] = newRateLimiter
+		return newRateLimiter
+	}
+	return rateLimiter
+}
+
+// updateAll updates existing rate limiters with new settings.
+func (rlm *rateLimitersMap) updateAll(maxEvents int, window time.Duration) {
+	rlm.limitersMu.Lock()
+	defer rlm.limitersMu.Unlock()
+
+	for _, limiter := range rlm.limiters {
+		limiter.SetMaxEvents(maxEvents)
+		limiter.SetWindow(window)
+	}
+}
+
+// sweep cleans up expired rate limit states.
+func (rlm *rateLimitersMap) sweep() {
+	rlm.limitersMu.Lock()
+	defer rlm.limitersMu.Unlock()
+
+	for key, rl := range rlm.limiters {
+		func(rl *ringBufferRateLimiter) {
+			rl.mu.Lock()
+			defer rl.mu.Unlock()
+
+			// no point in keeping a ring buffer of size 0 around
+			if len(rl.ring) == 0 {
+				delete(rlm.limiters, key)
+				return
+			}
+
+			// get newest event in ring (should come right before oldest)
+			cursorNewest := rl.cursor - 1
+			if cursorNewest < 0 {
+				cursorNewest = len(rl.ring) - 1
+			}
+			newest := rl.ring[cursorNewest]
+			window := rl.window
+
+			// if newest event in memory is outside the window,
+			// the entire ring has expired and can be forgotten
+			if newest.Add(window).Before(now()) {
+				delete(rlm.limiters, key)
+			}
+		}(rl)
+	}
+}
+
+// rlStateForZone returns the state of all rate limiters in the map.
+func (rlm *rateLimitersMap) rlStateForZone(timestamp time.Time) map[string]rlStateValue {
+	state := make(map[string]rlStateValue)
+
+	rlm.limitersMu.Lock()
+	defer rlm.limitersMu.Unlock()
+	for key, rl := range rlm.limiters {
+		count, oldestEvent := rl.Count(timestamp)
+		state[key] = rlStateValue{
+			Count:       count,
+			OldestEvent: oldestEvent,
+		}
+	}
+
+	return state
+}
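Two details of the new type are worth spelling out. First, sweep's newest-event lookup relies on the invariant that cursor always points at the oldest slot, so the most recent event sits one slot behind it, wrapping around the ring. The arithmetic in isolation (hypothetical helper, same logic as the code above):

// newestIndex returns the slot holding the most recent event, given a ring
// of length ringLen whose cursor points at the oldest slot.
func newestIndex(cursor, ringLen int) int {
	newest := cursor - 1
	if newest < 0 {
		newest = ringLen - 1 // cursor at slot 0 wraps to the last slot
	}
	return newest
}

// newestIndex(0, 4) == 3    newestIndex(2, 4) == 1

Second, rlStateForZone holds limitersMu for the whole iteration, so the snapshot it returns is consistent with respect to concurrent insertions and sweeps. In-package usage, as in the test above (assumed values):

state := zoneLimiters.rlStateForZone(now())
// e.g. map[string]rlStateValue{"static": {Count: 7, OldestEvent: ...}}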
13 changes: 5 additions & 8 deletions ringbuffer.go
@@ -30,16 +30,12 @@ type ringBufferRateLimiter struct {
 	cursor int // always points to the oldest timestamp
 }

-// initialize sets up the rate limiter if it isn't already, allowing maxEvents
+// newRingBufferRateLimiter sets up a new rate limiter, allowing maxEvents
 // in a sliding window of size window. If maxEvents is 0, no events are
 // allowed. If window is 0, all events are allowed. It panics if maxEvents or
-// window are less than zero. This method is idempotent.
-func (r *ringBufferRateLimiter) initialize(maxEvents int, window time.Duration) {
-	r.mu.Lock()
-	defer r.mu.Unlock()
-	if r.window != 0 || r.ring != nil {
-		return
-	}
+// window are less than zero.
+func newRingBufferRateLimiter(maxEvents int, window time.Duration) *ringBufferRateLimiter {
+	r := new(ringBufferRateLimiter)
 	if maxEvents < 0 {
 		panic("maxEvents cannot be less than zero")
 	}
@@ -48,6 +44,7 @@ func (r *ringBufferRateLimiter) initialize(maxEvents int, window time.Duration)
 	}
 	r.window = window
 	r.ring = make([]time.Time, maxEvents) // TODO: we can probably pool these
+	return r
 }

 // When returns the duration before the next allowable event; it does not block.
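With construction now returning a ready-to-use limiter, in-package usage reduces to one call. Per the doc comment and the TestDistributed usage above, When reports 0 while an event is admitted and a positive wait once maxEvents is reached within the window (a sketch against this unexported API; return values inferred, not verified output):

rb := newRingBufferRateLimiter(2, time.Minute)
fmt.Println(rb.When()) // 0: first event admitted
fmt.Println(rb.When()) // 0: second event admitted
fmt.Println(rb.When()) // > 0: duration until the next event would be allowed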
3 changes: 1 addition & 2 deletions ringbuffer_test.go
@@ -23,9 +23,8 @@ func TestCount(t *testing.T) {
 	initTime()

 	var zeroTime time.Time
-	var rb ringBufferRateLimiter
 	bufSize := 10
-	rb.initialize(bufSize, time.Duration(bufSize)*time.Second)
+	rb := newRingBufferRateLimiter(bufSize, time.Duration(bufSize)*time.Second)
 	startTime := now()

 	count, oldest := rb.Count(now())