Skip to content

Commit

Permalink
Merge pull request #1 from ksysoev/small_improvments
Browse files Browse the repository at this point in the history
Improving closing stor
  • Loading branch information
ksysoev authored Apr 10, 2024
2 parents 5a6656e + e717ff0 commit 4f23875
Show file tree
Hide file tree
Showing 2 changed files with 55 additions and 17 deletions.
61 changes: 45 additions & 16 deletions ratestor.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,26 +9,31 @@ import (
)

const (
gcBatchSize = 100
gcDefaultBatchSize = 100
)

// ErrRateLimitExceeded is an error that is returned when the rate limit is exceeded.
// This error indicates that the maximum number of requests allowed within a certain time period has been reached.
var ErrRateLimitExceeded = fmt.Errorf("rate limit exceeded")

// ErrRateStorClosed is an error that indicates the rate stor is closed.
var ErrRateStorClosed = fmt.Errorf("rate stor is closed")

type RateValue struct {
ExpiresAt time.Time
Value uint64
Limit uint64
}

type RateStor struct {
rates map[string]RateValue
lock *sync.Mutex
wg *sync.WaitGroup
stop context.CancelFunc
index expIndex
gcInterval time.Duration
rates map[string]RateValue
lock *sync.Mutex
wg *sync.WaitGroup
stop context.CancelFunc
index expIndex
gcInterval time.Duration
gcBatchSize int
isClosed bool
}

type indexValue struct {
Expand Down Expand Up @@ -61,11 +66,13 @@ func NewRateStor(opts ...Optition) *RateStor {
ctx, cancel := context.WithCancel(context.Background())

stor := &RateStor{
lock: &sync.Mutex{},
rates: make(map[string]RateValue),
gcInterval: 1 * time.Second,
stop: cancel,
wg: &sync.WaitGroup{},
lock: &sync.Mutex{},
rates: make(map[string]RateValue),
gcInterval: 1 * time.Second,
stop: cancel,
wg: &sync.WaitGroup{},
gcBatchSize: gcDefaultBatchSize,
isClosed: false,
}

for _, opt := range opts {
Expand All @@ -88,8 +95,13 @@ func (rs *RateStor) Allow(key string, period time.Duration, limit uint64) error
rs.lock.Lock()
defer rs.lock.Unlock()

if rs.isClosed {
return ErrRateStorClosed
}

now := time.Now()
if rate, ok := rs.rates[key]; ok {
if rate.ExpiresAt.After(time.Now()) {
if rate.ExpiresAt.After(now) {
if rate.Value < rate.Limit {
rate.Value++
rs.rates[key] = rate
Expand All @@ -101,7 +113,7 @@ func (rs *RateStor) Allow(key string, period time.Duration, limit uint64) error
}
}

ExpiresAt := time.Now().Add(period)
ExpiresAt := now.Add(period)
rs.rates[key] = RateValue{
Value: 1,
ExpiresAt: ExpiresAt,
Expand Down Expand Up @@ -134,7 +146,9 @@ func (rs *RateStor) cleaner(ctx context.Context) {

for isRunning {
rs.lock.Lock()
for i := 0; i < gcBatchSize; i++ {
now := time.Now()

for i := 0; i < rs.gcBatchSize; i++ {
if rs.index.Len() == 0 {
isRunning = false

Expand All @@ -146,7 +160,7 @@ func (rs *RateStor) cleaner(ctx context.Context) {
panic("unexpected type" + fmt.Sprintf("%T", item))
}

if item.ExpiresAt.After(time.Now()) {
if item.ExpiresAt.After(now) {
heap.Push(&rs.index, item)

isRunning = false
Expand All @@ -169,6 +183,11 @@ func (rs *RateStor) cleaner(ctx context.Context) {

// Close stops the RateStor instance and waits for all goroutines to complete.
func (rs *RateStor) Close() {
rs.lock.Lock()
defer rs.lock.Unlock()

rs.isClosed = true

rs.stop()
rs.wg.Wait()
}
Expand All @@ -182,3 +201,13 @@ func WithGCInterval(interval time.Duration) Optition {
rs.gcInterval = interval
}
}

// WithGCBatchSize sets the garbage collection batch size for the RateStor instance.
// The garbage collection batch size determines how many expired rate limit entries
// will be removed in each garbage collection cycle.
// The default garbage collection batch size is 100.
func WithGCBatchSize(size int) Optition {
return func(rs *RateStor) {
rs.gcBatchSize = size
}
}
11 changes: 10 additions & 1 deletion ratestor_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ func TestAllow(t *testing.T) {
}

func TestGCRun(t *testing.T) {
rs := NewRateStor(WithGCInterval(2 * time.Millisecond))
rs := NewRateStor(WithGCInterval(2*time.Millisecond), WithGCBatchSize(10))
defer rs.Close()

_ = rs.Allow("key1", 3*time.Millisecond, 1)
Expand Down Expand Up @@ -61,3 +61,12 @@ func TestGCRun(t *testing.T) {

rs.lock.Unlock()
}

func TestClose(t *testing.T) {
rs := NewRateStor()
rs.Close()

if err := rs.Allow("key", time.Millisecond, 1); err != ErrRateStorClosed {
t.Errorf("Expected error %v, but got %v", ErrRateStorClosed, err)
}
}

0 comments on commit 4f23875

Please sign in to comment.