From 0d69b97ef8cf5cfe1658b2eab0c6aa617e63661e Mon Sep 17 00:00:00 2001 From: Vojtech Vitek Date: Mon, 5 Aug 2024 11:50:14 +0200 Subject: [PATCH] Fix passing custom redis client --- _example/main.go | 1 + config.go | 4 ++-- httprateredis.go | 46 ++++++++++++++++++++++++---------------------- 3 files changed, 27 insertions(+), 24 deletions(-) diff --git a/_example/main.go b/_example/main.go index bafb886..6d1250e 100644 --- a/_example/main.go +++ b/_example/main.go @@ -37,6 +37,7 @@ func main() { }) }) + // Rate-limit at 50 req/s per IP address. r.Use(httprate.Limit( 50, time.Second, httprate.WithKeyByIP(), diff --git a/config.go b/config.go index 068f99b..7ce25e9 100644 --- a/config.go +++ b/config.go @@ -22,9 +22,9 @@ type Config struct { // the system will use the local counter unless it is explicitly disabled. FallbackTimeout time.Duration `toml:"fallback_timeout"` // default: 50ms - // Client if supplied will be used and below fields will be ignored. + // Client if supplied will be used and the below fields will be ignored. // - // NOTE: It's recommended to set short Dial/Read/Write timeouts and disable + // NOTE: It's recommended to set short dial/read/write timeouts and disable // retries on the client, so the local in-memory fallback can activate quickly. Client *redis.Client `toml:"-"` Host string `toml:"host"` diff --git a/httprateredis.go b/httprateredis.go index 1a3fd50..b3d9183 100644 --- a/httprateredis.go +++ b/httprateredis.go @@ -40,6 +40,7 @@ func NewRedisLimitCounter(cfg *Config) (*redisCounter, error) { cfg.PrefixKey = "httprate" } if cfg.FallbackTimeout == 0 { + // Activate local in-memory fallback fairly quickly, as this would slow down all requests. cfg.FallbackTimeout = 50 * time.Millisecond } @@ -50,29 +51,30 @@ func NewRedisLimitCounter(cfg *Config) (*redisCounter, error) { rc.fallbackCounter = httprate.NewLocalLimitCounter(cfg.WindowLength) } - var maxIdle, maxActive = cfg.MaxIdle, cfg.MaxActive - if maxIdle <= 0 { - maxIdle = 20 - } - if maxActive <= 0 { - maxActive = 50 - } + if cfg.Client == nil { + maxIdle, maxActive := cfg.MaxIdle, cfg.MaxActive + if maxIdle < 1 { + maxIdle = 20 + } + if maxActive < 1 { + maxActive = 50 + } - address := fmt.Sprintf("%s:%d", cfg.Host, cfg.Port) - rc.client = redis.NewClient(&redis.Options{ - Addr: address, - Password: cfg.Password, - DB: cfg.DBIndex, - PoolSize: maxActive, - MaxIdleConns: maxIdle, - ClientName: cfg.ClientName, + rc.client = redis.NewClient(&redis.Options{ + Addr: fmt.Sprintf("%s:%d", cfg.Host, cfg.Port), + Password: cfg.Password, + DB: cfg.DBIndex, + ClientName: cfg.ClientName, - DialTimeout: cfg.FallbackTimeout, - ReadTimeout: cfg.FallbackTimeout, - WriteTimeout: cfg.FallbackTimeout, - MinIdleConns: 1, - MaxRetries: -1, - }) + DialTimeout: cfg.FallbackTimeout, + ReadTimeout: cfg.FallbackTimeout, + WriteTimeout: cfg.FallbackTimeout, + PoolSize: maxActive, + MinIdleConns: 1, + MaxIdleConns: maxIdle, + MaxRetries: -1, // -1 disables retries + }) + } return rc, nil } @@ -109,7 +111,7 @@ func (c *redisCounter) IncrementBy(key string, currentWindow time.Time, amount i var netErr net.Error if errors.As(err, &netErr) || errors.Is(err, redis.ErrClosed) { go c.fallback() - err = c.fallbackCounter.IncrementBy(key, currentWindow, amount) // = nil + err = c.fallbackCounter.IncrementBy(key, currentWindow, amount) } } }()