Fix race when initializing ringBufferRateLimiter
inahga committed Mar 4, 2024
1 parent a2b3187 commit 65ad951
Showing 4 changed files with 39 additions and 28 deletions.
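
The race being fixed: the old handler.go path below took a zero-valued ringBufferRateLimiter out of ringBufPool, published it into the zone's sync.Map via LoadOrStore, and only then called initialize on it, so a concurrent request for the same key could load the limiter before (or while) it was being initialized. This commit replaces the lazy initialize method with a NewRingBufferRateLimiter constructor and only stores fully constructed limiters. A minimal sketch of the safe pattern (simplified from the handler.go diff below, not verbatim code):

	// Construct the limiter completely before publishing it in the map,
	// so no other goroutine can ever observe a half-initialized value.
	newLimiter := NewRingBufferRateLimiter(maxEvents, window)
	actual, loaded := limiters.LoadOrStore(key, newLimiter)
	if loaded {
		// A concurrent request stored its limiter first; recycle ours.
		ringBufPool.Put(newLimiter)
	}
	limiter := actual.(*ringBufferRateLimiter)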
12 changes: 9 additions & 3 deletions distributed_test.go
@@ -96,8 +96,7 @@ func TestDistributed(t *testing.T) {
 			if err != nil {
 				t.Fatal("failed to parse duration")
 			}
-			var simulatedPeer ringBufferRateLimiter
-			simulatedPeer.initialize(maxEvents, parsedDuration)
+			simulatedPeer := NewRingBufferRateLimiter(maxEvents, parsedDuration)
 
 			for i := 0; i < testCase.peerRequests; i++ {
 				if when := simulatedPeer.When(); when != 0 {
@@ -106,7 +105,7 @@ func TestDistributed(t *testing.T) {
 			}
 
 			zoneLimiters := new(sync.Map)
-			zoneLimiters.Store("static", &simulatedPeer)
+			zoneLimiters.Store("static", simulatedPeer)
 
 			rlState := rlState{
 				Timestamp: testCase.peerStateTimeStamp,
@@ -134,6 +133,13 @@ func TestDistributed(t *testing.T) {
 			"module": "file_system",
 			"root": "%s"
 		},
+		"logging": {
+			"logs": {
+				"default": {
+					"level": "DEBUG"
+				}
+			}
+		},
 		"apps": {
 			"http": {
 				"servers": {
39 changes: 24 additions & 15 deletions handler.go
@@ -166,19 +166,28 @@ func (h Handler) ServeHTTP(w http.ResponseWriter, r *http.Request, next caddyhtt
 		// make key for the individual rate limiter in this zone
 		key := repl.ReplaceAll(rl.Key, "")
 
-		// the API for sync.Pool is unfortunate: there is no LoadOrNew() method
-		// which allocates/constructs a value only if needed, so we always need
-		// to pre-allocate the value even if we never use it; we should be able
-		// to relieve some memory pressure by putting unused values back into a
-		// pool...
-		limiter := ringBufPool.Get().(*ringBufferRateLimiter)
-		if val, loaded := rl.limiters.LoadOrStore(key, limiter); loaded {
-			ringBufPool.Put(limiter) // didn't use; save for next time
-			limiter = val.(*ringBufferRateLimiter)
+		var limiter *ringBufferRateLimiter
+		loadedLimiter, ok := rl.limiters.Load(key)
+		if ok {
+			limiter = loadedLimiter.(*ringBufferRateLimiter)
 		} else {
-			// as another side-effect of sync.Map's bad API, avoid all the
-			// work of initializing the ring buffer unless we have to
-			limiter.initialize(rl.MaxEvents, time.Duration(rl.Window))
+			var newLimiter *ringBufferRateLimiter
+			poolLimiter := ringBufPool.Get()
+			if poolLimiter == nil {
+				// Nothing is in the pool, create a new limiter.
+				newLimiter = NewRingBufferRateLimiter(rl.MaxEvents, time.Duration(rl.Window))
+			} else {
+				newLimiter = poolLimiter.(*ringBufferRateLimiter)
+			}
+
+			loadedLimiter, loaded := rl.limiters.LoadOrStore(key, newLimiter)
+			if loaded {
+				// We didn't end up needing the pool's limiter, since a concurrent request
+				// has loaded its own limiter. Store it for later use.
+				ringBufPool.Put(newLimiter)
+			}
+
+			limiter = loadedLimiter.(*ringBufferRateLimiter)
 		}
 
 		if h.Distributed == nil {
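
Design note on the lookup above: the new code calls rl.limiters.Load(key) first so that the common case, a key whose limiter already exists, never touches the pool or constructs anything; only a miss falls through to construct-then-LoadOrStore. A compressed sketch of that fast-path/slow-path idiom (an illustrative helper, not part of this commit; pooling is omitted):

	// getOrCreate returns the limiter for key, constructing one only on a miss.
	func getOrCreate(m *sync.Map, key string, maxEvents int, window time.Duration) *ringBufferRateLimiter {
		if v, ok := m.Load(key); ok {
			return v.(*ringBufferRateLimiter) // fast path: no allocation
		}
		// Slow path: build a fully initialized limiter, then race to publish it.
		v, _ := m.LoadOrStore(key, NewRingBufferRateLimiter(maxEvents, window))
		return v.(*ringBufferRateLimiter)
	}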
@@ -286,9 +295,9 @@ var rateLimits = caddy.NewUsagePool()
 
 // ringBufPool reduces allocations from unneeded rate limiters.
 var ringBufPool = sync.Pool{
-	New: func() interface{} {
-		return new(ringBufferRateLimiter)
-	},
+	// New: func() interface{} {
+	// 	return new(ringBufferRateLimiter)
+	// },
 }
 
 // Interface guards
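
With New commented out, ringBufPool.Get returns nil whenever the pool is empty, which is what the new branch in ServeHTTP checks before constructing a limiter with the zone's own parameters; a pool New function would have to build a limiter without knowing them. A small standalone sketch of that sync.Pool behavior (standard-library semantics, not code from this repo):

	package main

	import (
		"fmt"
		"sync"
	)

	func main() {
		var pool sync.Pool // nil New, like ringBufPool after this change

		// Get on an empty pool with a nil New returns nil, telling the
		// caller to construct the value itself with the right parameters.
		fmt.Println(pool.Get() == nil) // true

		pool.Put("recycled")
		fmt.Println(pool.Get()) // recycled (unless the GC cleared the pool)
	}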
13 changes: 5 additions & 8 deletions ringbuffer.go
@@ -30,16 +30,12 @@ type ringBufferRateLimiter struct {
 	cursor int // always points to the oldest timestamp
 }
 
-// initialize sets up the rate limiter if it isn't already, allowing maxEvents
+// NewRingBufferRateLimiter sets up a new rate limiter, allowing maxEvents
 // in a sliding window of size window. If maxEvents is 0, no events are
 // allowed. If window is 0, all events are allowed. It panics if maxEvents or
-// window are less than zero. This method is idempotent.
-func (r *ringBufferRateLimiter) initialize(maxEvents int, window time.Duration) {
-	r.mu.Lock()
-	defer r.mu.Unlock()
-	if r.window != 0 || r.ring != nil {
-		return
-	}
+// window are less than zero.
+func NewRingBufferRateLimiter(maxEvents int, window time.Duration) *ringBufferRateLimiter {
+	r := new(ringBufferRateLimiter)
 	if maxEvents < 0 {
 		panic("maxEvents cannot be less than zero")
 	}
@@ -48,6 +44,7 @@ func (r *ringBufferRateLimiter) initialize(maxEvents int, window time.Duration)
 	}
 	r.window = window
 	r.ring = make([]time.Time, maxEvents) // TODO: we can probably pool these
+	return r
 }
 
 // When returns the duration before the next allowable event; it does not block.
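
For reference, a hedged usage sketch of the new constructor, based on the call sites in the tests above and below (it assumes code running inside this package, and treats When() returning 0 as "an event is allowed now", as distributed_test.go does):

	limiter := NewRingBufferRateLimiter(10, time.Minute) // at most 10 events per sliding minute
	if wait := limiter.When(); wait == 0 {
		// allowed: proceed with the event
	} else {
		// rate limited: the next event is allowed after `wait`
	}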
3 changes: 1 addition & 2 deletions ringbuffer_test.go
@@ -23,9 +23,8 @@ func TestCount(t *testing.T) {
 	initTime()
 
 	var zeroTime time.Time
-	var rb ringBufferRateLimiter
 	bufSize := 10
-	rb.initialize(bufSize, time.Duration(bufSize)*time.Second)
+	rb := NewRingBufferRateLimiter(bufSize, time.Duration(bufSize)*time.Second)
 	startTime := now()
 
 	count, oldest := rb.Count(now())
