diff --git a/cmdline/workercmd/workercmd.go b/cmdline/workercmd/workercmd.go index 4d3079b..1151d6b 100644 --- a/cmdline/workercmd/workercmd.go +++ b/cmdline/workercmd/workercmd.go @@ -86,6 +86,10 @@ func runWorker(tokenName string) error { log.Logger.UpdateContext(func(c zerolog.Context) zerolog.Context { return c.Str("token", tokenName).Int("pid", os.Getpid()) }) + tconf := tok.Config() + if tconf.RateLimit != 0 { + tok = tokencache.NewLimiter(tok, tconf.RateLimit, tconf.RateBurst) + } expiry := time.Second * time.Duration(cfg.Server.TokenCacheSeconds) handler := &handler{ token: tokencache.New(tok, expiry), diff --git a/config/config.go b/config/config.go index b9666ae..09015ba 100644 --- a/config/config.go +++ b/config/config.go @@ -49,6 +49,8 @@ type TokenConfig struct { Pin *string // PIN to use, otherwise will be prompted. Can be empty. (optional) Timeout int // (server) Terminate command after N seconds (default 60) Retries int // (server) Retry failed commands N times (default 5) + RateLimit float64 // (server) limit token operations per second + RateBurst int // (server) allow burst of operations before limit kicks in User *uint // User argument for PKCS#11 login (optional) UseKeyring bool // Read PIN from system keyring diff --git a/doc/relic.yml b/doc/relic.yml index e71ddf0..c34d4bd 100644 --- a/doc/relic.yml +++ b/doc/relic.yml @@ -26,8 +26,10 @@ tokens: #user: 1 # Optional parameters for server mode - #timeout: 60 # Terminate each attempt after N seconds (default: 60) - #retries: 5 # Retry failed commands N times (default: 5) + #timeout: 60 # Terminate each attempt after N seconds (default: 60) + #retries: 5 # Retry failed commands N times (default: 5) + #ratelimit: 10 # Limit token operations per second + #rateburst: 10 # Allow burst of requests before limit kicks in # Use GnuPG scdaemon as a token myscd: diff --git a/server/server.go b/server/server.go index 7fe07c3..4e5d12d 100644 --- a/server/server.go +++ b/server/server.go @@ -120,6 +120,9 @@ func (s *Server) openTokens() error { if err == nil { // instrument token with metrics and caching tok = tokencache.Metrics{Token: tok} + if tconf.RateLimit != 0 { + tok = tokencache.NewLimiter(tok, tconf.RateLimit, tconf.RateBurst) + } tok = tokencache.New(tok, expiry) } } diff --git a/token/tokencache/ratelimit.go b/token/tokencache/ratelimit.go new file mode 100644 index 0000000..df176b2 --- /dev/null +++ b/token/tokencache/ratelimit.go @@ -0,0 +1,78 @@ +package tokencache + +import ( + "context" + "crypto" + "io" + "time" + + "github.com/prometheus/client_golang/prometheus" + "github.com/prometheus/client_golang/prometheus/promauto" + "github.com/sassoftware/relic/v7/token" + "golang.org/x/time/rate" +) + +var metricRateLimited = promauto.NewCounter(prometheus.CounterOpts{ + Name: "token_operation_limited_seconds", + Help: "Cumulative number of seconds waiting for rate limits", +}) + +type RateLimited struct { + token.Token + limit *rate.Limiter +} + +func NewLimiter(base token.Token, limit float64, burst int) *RateLimited { + if burst < 1 { + burst = 1 + } + return &RateLimited{ + Token: base, + limit: rate.NewLimiter(rate.Limit(limit), burst), + } +} + +type rateLimitedKey struct { + token.Key + limit *rate.Limiter +} + +func (r *RateLimited) GetKey(ctx context.Context, keyName string) (token.Key, error) { + start := time.Now() + if err := r.limit.Wait(ctx); err != nil { + return nil, err + } + if waited := time.Since(start); waited > 1*time.Millisecond { + metricRateLimited.Add(time.Since(start).Seconds()) + } + key, err := r.Token.GetKey(ctx, keyName) + if err != nil { + return nil, err + } + return &rateLimitedKey{ + Key: key, + limit: r.limit, + }, nil +} + +func (k *rateLimitedKey) Sign(rand io.Reader, digest []byte, opts crypto.SignerOpts) (sig []byte, err error) { + start := time.Now() + if err := k.limit.Wait(context.Background()); err != nil { + return nil, err + } + if waited := time.Since(start); waited > 1*time.Millisecond { + metricRateLimited.Add(time.Since(start).Seconds()) + } + return k.Key.Sign(rand, digest, opts) +} + +func (k *rateLimitedKey) SignContext(ctx context.Context, digest []byte, opts crypto.SignerOpts) (sig []byte, err error) { + start := time.Now() + if err := k.limit.Wait(ctx); err != nil { + return nil, err + } + if waited := time.Since(start); waited > 1*time.Millisecond { + metricRateLimited.Add(time.Since(start).Seconds()) + } + return k.Key.SignContext(ctx, digest, opts) +}