diff --git a/.travis.yml b/.travis.yml index 36a003e281..876e0f660c 100644 --- a/.travis.yml +++ b/.travis.yml @@ -131,6 +131,7 @@ services: - docker - postgresql - mysql + - redis-server before_install: - sudo service mysql stop diff --git a/CHANGELOG.md b/CHANGELOG.md index 8dea53434d..9c99b6fd4a 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -40,6 +40,10 @@ the corresponding packages, and are now required to be imported explicitly by the main file in order to be registered. We are including only MySQL and cloudspanner providers by default, since these are the ones that we support. +### Quota + +An experimental Redis-based `quota.Manager` implementation has been added. + ### Tools The `licenses` tool has been moved from "scripts/licenses" to [a dedicated diff --git a/docs/Feature_Implementation_Matrix.md b/docs/Feature_Implementation_Matrix.md index 01ca1bfb5c..659374c641 100644 --- a/docs/Feature_Implementation_Matrix.md +++ b/docs/Feature_Implementation_Matrix.md @@ -153,16 +153,16 @@ Supported frameworks for providing Master Election. ### Quota -Supported frameworks for providing Master Election. +Supported frameworks for quota management. -| Election | Status | Deployed in prod | Notes | +| Implementation | Status | Deployed in prod | Notes | |:--- | :---: | :---: |:--- | | Google internal | GA | ✓ | | | etcd | GA | ✓ | | | MySQL | Beta | ? | | +| Redis | Alpha | ✓ | | | Postgres | NI | | | - ### Key management Supported frameworks for key management and signing. 
diff --git a/go.mod b/go.mod index 7dad026294..f24279146f 100644 --- a/go.mod +++ b/go.mod @@ -19,6 +19,7 @@ require ( github.com/coreos/go-systemd v0.0.0-20190620071333-e64a0ec8b42a // indirect github.com/cpuguy83/go-md2man/v2 v2.0.0 // indirect github.com/emicklei/proto v1.8.0 // indirect + github.com/go-redis/redis v6.15.6+incompatible github.com/go-sql-driver/mysql v1.4.1 github.com/gogo/protobuf v1.3.1 // indirect github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b diff --git a/go.sum b/go.sum index 8785f183f2..27e9ccfc06 100644 --- a/go.sum +++ b/go.sum @@ -141,6 +141,8 @@ github.com/go-lintpack/lintpack v0.5.2/go.mod h1:NwZuYi2nUHho8XEIZ6SIxihrnPoqBTD github.com/go-logfmt/logfmt v0.3.0/go.mod h1:Qt1PoO58o5twSAckw1HlFXLmHsOX5/0LbT9GBnD5lWE= github.com/go-logfmt/logfmt v0.4.0/go.mod h1:3RMwSq7FuexP4Kalkev3ejPJsZTpXXBr9+V4qmtdjCk= github.com/go-ole/go-ole v1.2.1/go.mod h1:7FAglXiTm7HKlQRDeOQ6ZNUHidzCWXuZWq/1dTyBNF8= +github.com/go-redis/redis v6.15.6+incompatible h1:H9evprGPLI8+ci7fxQx6WNZHJSb7be8FqJQRhdQZ5Sg= +github.com/go-redis/redis v6.15.6+incompatible/go.mod h1:NAIEuMOZ/fxfXJIrKDQDz8wamY7mA7PouImQ2Jvg6kA= github.com/go-sql-driver/mysql v1.4.1 h1:g24URVg0OFbNUTx9qqY1IRZ9D9z3iPyi5zKhQZpNwpA= github.com/go-sql-driver/mysql v1.4.1/go.mod h1:zAC/RDZ24gD3HViQzih4MyKcchzm+sOG5ZlKdlhCg5w= github.com/go-stack/stack v1.8.0/go.mod h1:v0f6uXyyMGvRgIKkXu+yp6POWl0qKG85gN/melR3HDY= diff --git a/quota/redis/redisqm/manager.go b/quota/redis/redisqm/manager.go new file mode 100644 index 0000000000..ca44793e24 --- /dev/null +++ b/quota/redis/redisqm/manager.go @@ -0,0 +1,186 @@ +// Copyright 2017 Google Inc. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package redisqm defines a Redis-based quota.Manager implementation. +package redisqm + +import ( + "context" + "fmt" + + "github.com/google/trillian/quota" + "github.com/google/trillian/quota/redis/redistb" +) + +// ParameterFunc is a function that should return a token bucket's parameters +// for a given quota specification. +type ParameterFunc func(spec quota.Spec) (capacity int, rate float64) + +// ManagerOptions holds the parameters for a Manager. +type ManagerOptions struct { + // Parameters should return the parameters for a given quota.Spec. This + // value must not be nil. + Parameters ParameterFunc + + // Prefix is a static prefix to apply to all Redis keys; this is useful + // if running on a multi-tenant Redis cluster. + Prefix string +} + +// Manager implements the quota.Manager interface backed by a Redis-based token +// bucket implementation. +type Manager struct { + tb *redistb.TokenBucket + opts ManagerOptions +} + +var _ quota.Manager = &Manager{} + +// RedisClient is an interface that encompasses the various methods used by +// this quota.Manager, and allows selecting among different Redis client +// implementations (e.g. regular Redis, Redis Cluster, sharded, etc.) +type RedisClient interface { + // Everything required by the redistb.RedisClient interface + redistb.RedisClient +} + +// New returns a new Redis-based quota.Manager. +func New(client RedisClient, opts ManagerOptions) *Manager { + tb := redistb.New(client) + return &Manager{tb: tb, opts: opts} +} + +// GetTokens implements the quota.Manager API. 
func (m *Manager) GetTokens(ctx context.Context, numTokens int, specs []quota.Spec) error {
	for _, spec := range specs {
		if err := m.getTokensSingle(ctx, numTokens, spec); err != nil {
			// NOTE(review): tokens already taken for earlier specs in
			// the slice are not returned on failure; acquisition is
			// not atomic across specs.
			return err
		}
	}

	return nil
}

// getTokensSingle takes numTokens from the token bucket backing a single
// quota.Spec, returning an error if the bucket holds insufficient tokens.
func (m *Manager) getTokensSingle(ctx context.Context, numTokens int, spec quota.Spec) error {
	capacity, rate := m.opts.Parameters(spec)

	// If we get back `MaxTokens` from our parameters call, this indicates
	// that there's no actual limit. We don't need to do anything to "get"
	// them; just ignore.
	if capacity == quota.MaxTokens {
		return nil
	}

	name := specName(m.opts.Prefix, spec)
	allowed, remaining, err := m.tb.Call(
		ctx,
		name,
		int64(capacity),
		rate,
		numTokens,
	)
	if err != nil {
		return err
	}
	if !allowed {
		return fmt.Errorf("insufficient tokens on %v (%v vs %v)", name, remaining, numTokens)
	}

	return nil
}

// PeekTokens implements the quota.Manager API.
func (m *Manager) PeekTokens(ctx context.Context, specs []quota.Spec) (map[quota.Spec]int, error) {
	tokens := make(map[quota.Spec]int)
	for _, spec := range specs {
		// Calling the limiter with 0 tokens requested is equivalent to
		// "peeking", but it will also shrink the token bucket if it
		// has too many tokens.
		capacity, rate := m.opts.Parameters(spec)

		// If we get back `MaxTokens` from our parameters call, this
		// indicates that there's no actual limit. We don't need to do
		// anything to "get" them; just set that value in the returned
		// map as well.
		if capacity == quota.MaxTokens {
			tokens[spec] = quota.MaxTokens
			continue
		}

		_, remaining, err := m.tb.Call(
			ctx,
			specName(m.opts.Prefix, spec),
			int64(capacity),
			rate,
			0,
		)
		if err != nil {
			return nil, err
		}

		tokens[spec] = int(remaining)
	}

	return tokens, nil
}

// PutTokens implements the quota.Manager API.
+func (m *Manager) PutTokens(ctx context.Context, numTokens int, specs []quota.Spec) error { + // Putting tokens into a time-based quota doesn't mean anything (since + // tokens are replenished at the moment they're requested) and since + // that's the only supported mechanism for this package currently, do + // nothing. + return nil +} + +// ResetQuota implements the quota.Manager API. +// +// This function will reset every quota and return the first error encountered, +// if any, but will continue trying to reset every quota even if an error is +// encountered. +func (m *Manager) ResetQuota(ctx context.Context, specs []quota.Spec) error { + var firstErr error + + for _, name := range specNames(m.opts.Prefix, specs) { + if err := m.tb.Reset(ctx, name); err != nil { + if firstErr == nil { + firstErr = err + } + } + } + + return firstErr +} + +// Load attempts to load Redis scripts used by the Manager into the Redis +// cluster. +// +// A Manager will operate successfully if this method is not called or fails, +// but a successful Load will reduce bandwidth to/from the Redis cluster +// substantially. +func (m *Manager) Load(ctx context.Context) error { + return m.tb.Load(ctx) +} + +func specNames(prefix string, specs []quota.Spec) []string { + names := make([]string, 0, len(specs)) + for _, spec := range specs { + names = append(names, specName(prefix, spec)) + } + return names +} + +func specName(prefix string, spec quota.Spec) string { + return prefix + "trillian/" + spec.Name() +} diff --git a/quota/redis/redistb/embed_redis.go b/quota/redis/redistb/embed_redis.go new file mode 100644 index 0000000000..5d57107eb1 --- /dev/null +++ b/quota/redis/redistb/embed_redis.go @@ -0,0 +1,89 @@ +// +build ignore + +// Copyright 2017 Google Inc. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// This is a helper utility to embed the Redis Lua scripts into a Go source +// file. +package main + +import ( + "bytes" + "go/format" + "io/ioutil" + "log" + "os" + "strconv" + "text/template" +) + +var packageTemplate = template.Must(template.New("").Parse(` +// Code generated by quota/redis/redistb/gen.go. DO NOT EDIT. +// source: {{ .Filename }} + +package redistb + +import ( + "github.com/go-redis/redis" +) + +// contents of the '{{ .Prefix }}' Redis Lua script +const {{ .Prefix }}ScriptContents = {{ .Content }} + +// Redis Script type for the '{{ .Prefix }}' Redis lua script +var {{ .Prefix }}Script = redis.NewScript({{ .Prefix }}ScriptContents) +`)) + +type templateData struct { + Prefix string + Filename string + Content string +} + +func main() { + if len(os.Args) != 4 { + log.Fatalf("usage: %s prefix file.lua output.go", os.Args[0]) + } + + data, err := ioutil.ReadFile(os.Args[2]) + if err != nil { + log.Fatalf("error opening input file: %v", err) + } + + vars := templateData{ + Prefix: os.Args[1], + Filename: os.Args[2], + Content: strconv.Quote(string(data)), + } + + var buf bytes.Buffer + if err := packageTemplate.Execute(&buf, vars); err != nil { + log.Fatalf("error rendering template: %v", err) + } + + data, err = format.Source(buf.Bytes()) + if err != nil { + log.Fatalf("error formatting source: %v", err) + } + + out, err := os.Create(os.Args[3]) + if err != nil { + log.Fatalf("error opening output file: %v", err) + } + defer out.Close() + + if _, err := out.Write(data); err != nil { + log.Fatalf("error writing output file: %v", 
err) + } +} diff --git a/quota/redis/redistb/gen.go b/quota/redis/redistb/gen.go new file mode 100644 index 0000000000..44efb0ed8a --- /dev/null +++ b/quota/redis/redistb/gen.go @@ -0,0 +1,17 @@ +// Copyright 2017 Google Inc. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package redistb + +//go:generate go run embed_redis.go updateTokenBucket update_token_bucket.lua update_token_bucket.gen.go diff --git a/quota/redis/redistb/redistb.go b/quota/redis/redistb/redistb.go new file mode 100644 index 0000000000..0c3fcf5de2 --- /dev/null +++ b/quota/redis/redistb/redistb.go @@ -0,0 +1,221 @@ +// Copyright 2017 Google Inc. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package redistb + +import ( + "context" + "fmt" + "time" + + "github.com/go-redis/redis" + "github.com/google/trillian/util/clock" +) + +// RedisClient is an interface that encompasses the various methods used by +// TokenBucket, and allows selecting among different Redis client +// implementations (e.g. regular Redis, Redis Cluster, sharded, etc.) +type RedisClient interface { + // Required to load and execute scripts + Eval(script string, keys []string, args ...interface{}) *redis.Cmd + EvalSha(sha1 string, keys []string, args ...interface{}) *redis.Cmd + ScriptExists(hashes ...string) *redis.BoolSliceCmd + ScriptLoad(script string) *redis.StringCmd +} + +// TokenBucket implements a token-bucket limiter stored in a Redis database. It +// supports atomic operation with concurrent access. +type TokenBucket struct { + c RedisClient + + testing bool + timeSource clock.TimeSource +} + +// New returns a new TokenBucket that uses the provided Redis client. +func New(client RedisClient) *TokenBucket { + ret := &TokenBucket{ + c: client, + timeSource: clock.System, + } + return ret +} + +// Load preloads any required Lua scripts into the Redis database, and updates +// the hash of the resulting script. Calling this function is optional, but +// will greatly reduce the network traffic to the Redis cluster since it only +// needs to pass a hash of the script and not the full script content. +func (tb *TokenBucket) Load(ctx context.Context) error { + client := withClientContext(ctx, tb.c) + return updateTokenBucketScript.Load(client).Err() +} + +// Reset resets the token bucket for the given prefix. +func (tb *TokenBucket) Reset(ctx context.Context, prefix string) error { + client := withClientContext(ctx, tb.c) + + // Use `EVAL` so that deleting all keys is atomic. 
+ resp := client.Eval( + `redis.call("del", KEYS[1]); redis.call("del", KEYS[2]); redis.call("del", KEYS[3])`, + tokenBucketKeys(prefix), + ) + return resp.Err() +} + +// Call implements the actual token bucket algorithm. Given a bucket with +// capacity `capacity` and replenishment rate of `replenishRate` tokens per +// second, it will first ensure that the bucket has the correct number of +// tokens added (up to the maximum capacity) since the last time that this +// function was called. Then, it will attempt to remove `numTokens` from the +// bucket. +// +// This function returns a boolean indicating whether it was able to remove all +// tokens from the bucket, the remaining number of tokens in the bucket, and +// any error that occurs. +func (tb *TokenBucket) Call( + ctx context.Context, + prefix string, + capacity int64, + replenishRate float64, + numTokens int, +) (bool, int64, error) { + client := withClientContext(ctx, tb.c) + + var ( + now int64 + nowUs int64 + ) + if tb.testing { + now, nowUs = timeToRedisPair(tb.timeSource.Now()) + } + + args := []interface{}{ + replenishRate, + capacity, + numTokens, + + // The script allows us to inject the current time for testing, + // but it's superseded by Redis's time in production to protect + // against clock drift. + now, + nowUs, + tb.testing, + } + + resp := updateTokenBucketScript.Run( + client, + tokenBucketKeys(prefix), + args..., + ) + result, err := resp.Result() + if err != nil { + return false, 0, err + } + + returnVals, ok := result.([]interface{}) + if !ok { + return false, 0, fmt.Errorf("redistb: invalid return type %T (expected []interface{})", result) + } + + // The script returns: + // allowed Whether the operation was allowed + // remaining The remaining tokens in the bucket + // now_new The script's view of the current time + // now_new_us The script's view of the current time (microseconds) + // + // We don't use the last two arguments here. 
+ + // Deserializing turns Lua 'true' into '1', and 'false' into 'nil' + var allowed bool + if returnVals[0] == nil { + allowed = false + } else if i, ok := returnVals[0].(int64); ok { + allowed = i == 1 + } else { + return false, 0, fmt.Errorf("redistb: invalid 'allowed' type %T", returnVals[0]) + } + + remaining := returnVals[1].(int64) + return allowed, remaining, nil +} + +// tokenBucketKeys returns the keys used for the token bucket script, given a +// prefix. +func tokenBucketKeys(prefix string) []string { + // Redis Cluster uses a hashing algorithm on keys to determine which slot + // they map to in its backend. Normally this is a problem for EVAL/EVALSHA + // because multiple keys in a script will likely map to different slots + // and cause Redis Cluster to reject the request. + // + // It's addressed with the idea of a "hash tag": + // + // https://redis.io/topics/cluster-tutorial#redis-cluster-data-sharding + // + // If a key name contains a string inside of "{}" then _just_ that string + // is hashed to be used slotting purposes, thereby giving users some + // control over mapping consistently to certain slots. For example, + // `this{foo}key` and `another{foo}key` are guaranteed to both map to the + // same slot. + // + // We take advantage of this idea here by making sure to hash only the + // common identifier in these keys by using "{}". + return []string{ + fmt.Sprintf("{%s}.tokens", prefix), + fmt.Sprintf("{%s}.refreshed", prefix), + fmt.Sprintf("{%s}.refreshed_us", prefix), + } +} + +// timeToRedisPair converts a Go time.Time into a seconds and microseconds +// component, which can be passed to our Redis script. +func timeToRedisPair(t time.Time) (int64, int64) { + // The first number in the pair is the number of seconds since the Unix + // epoch. 
+ timeSec := t.Unix() + + // The second number is any additional number of microseconds; we can + // get this by obtaining any sub-second Nanoseconds and simply dividing + // to get the number in microseconds. + timeMicros := int64(t.Nanosecond()) / int64(time.Microsecond) + + return timeSec, timeMicros +} + +// Because each Redis client type in the Go package has a `WithContext` method +// that returns a concrete type, we can't simply put that method in the +// RedisClient interface. This method performs type assertions to try and call +// the `WithContext` method on the appropriate concrete type. +func withClientContext(ctx context.Context, client RedisClient) RedisClient { + type withContextable interface { + WithContext(context.Context) RedisClient + } + + switch c := client.(type) { + // The three major Redis clients + case *redis.Client: + return c.WithContext(ctx) + case *redis.ClusterClient: + return c.WithContext(ctx) + case *redis.Ring: + return c.WithContext(ctx) + + // Let's also support the case where someone implements a custom client + // that returns the RedisClient interface type (e.g. good for tests). + case withContextable: + return c.WithContext(ctx) + } + + // If we can't determine a type, just return it unchanged. + return client +} diff --git a/quota/redis/redistb/redistb_test.go b/quota/redis/redistb/redistb_test.go new file mode 100644 index 0000000000..fa950f6819 --- /dev/null +++ b/quota/redis/redistb/redistb_test.go @@ -0,0 +1,873 @@ +// Copyright 2017 Google Inc. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + +package redistb + +import ( + "context" + "crypto/rand" + "math/big" + "testing" + "time" + + "github.com/go-redis/redis" + "github.com/google/trillian/util/clock" +) + +const ( + // Total capacity of the bucket + TestCapacity = 5 + + // Replenish rate in tokens per second (so 2 = 2 tokens/second). + TestReplenishRate = 2 + + // The base time for our test cases. Note this is meant to be + // seconds from the Unix epoch, but we're using a simplified + // number here to make debugging easier. + TestBaseTimeSec = 123 + + // The base time of microseconds within the current second + // (TestBaseTimeSec). + TestBaseTimeUs = 100 + + // The number of microseconds in a second, made a constant for + // better readability (10 ** 6). + MicrosecondsInSecond = 1000000 + + // The number of microseconds that it takes to drip a single + // new token at our TestReplenishRate. + SingleTokenDripTimeUs = MicrosecondsInSecond / TestReplenishRate + + // A common TTL we'll pass to every `SETEX` call. + TTL = 60 +) + +var ( + // The base time as a `time.Time` object + TestBaseTime = time.Unix(TestBaseTimeSec, TestBaseTimeUs*int64(time.Microsecond/time.Nanosecond)) +) + +// This test is an end-to-end integration test for the Redis token-bucket +// implementation. It exercises the actual implementation of the token bucket; +// there are further unit tests below that test specific behavior of the Lua +// script. 
+func TestRedisTokenBucketIntegration(t *testing.T) { + ctx := context.Background() + + start := time.Now() + fixedTimeSource := clock.NewFake(start) + + rdb := redis.NewClient(&redis.Options{ + Addr: "localhost:6379", + }) + tb := New(rdb) + tb.testing = true + tb.timeSource = fixedTimeSource + + if err := tb.Load(ctx); err != nil { + t.Fatalf("failed to load script: %v", err) + } + + prefix := stringWithCharset(20, alphanumeric) + + call := func(number int) (bool, int64) { + allowed, remaining, err := tb.Call( + ctx, + prefix, + TestCapacity, + TestReplenishRate, + number, + ) + if err != nil { + t.Fatal(err) + } + + return allowed, remaining + } + + // First, ensure that we can empty the bucket + for i := int64(0); i < TestCapacity; i++ { + allowed, remaining := call(1) + if !allowed { + t.Fatal("expected to be allowed") + } + if expected := TestCapacity - i - 1; remaining != expected { + t.Fatalf("expected remaining to be %d, got %d", expected, remaining) + } + } + + // Within this second, all future requests should fail. + allowed, remaining := call(1) + if allowed { + t.Fatal("expected to be denied") + } + if remaining != 0 { + t.Fatalf("expected remaining to be 0, got %d", remaining) + } + + singleTokenReplenishTime := 1.0 / TestReplenishRate + + // An arbitrary, non-zero number of iterations to ensure that this is repeatable. + for i := 0; i < 5; i++ { + // This is the perfect amount of time to get exactly one more token replenished. + timeStep := float64(i+1) * singleTokenReplenishTime + + // Freeze time *just* before a new token would enter the bucket + // to verify that requests are still blocked. 0.01s is an + // arbitrary number chosen to be "small enough" and yet not run + // into precision problems. 
+ justBefore := start + justBefore = justBefore.Add(time.Duration(timeStep * float64(time.Second))) + justBefore = justBefore.Add(-1 * time.Second / 100) + fixedTimeSource.Set(justBefore) + + allowed, remaining := call(1) + if allowed { + t.Fatalf("expected request before reaching replenish time to be denied on iteration %d", i) + } + if remaining != 0 { + t.Fatalf("expected remaining to be 0, got %d", remaining) + } + + // Freeze time at precisely the right moment that a token has + // entered the bucket. + fixedTimeSource.Set(start.Add(time.Duration(timeStep * float64(time.Second)))) + + // A single request is allowed + allowed, remaining = call(1) + if allowed { + t.Fatalf("expected first request to be allowed on iteration %d", i) + } + if remaining != 0 { + t.Fatalf("expected remaining to be 0, got %d", remaining) + } + + // Requests are blocked again now that we've used the token. + allowed, remaining = call(1) + if allowed { + t.Fatalf("expected second request to be denied on iteration %d", i) + } + if remaining != 0 { + t.Fatalf("expected remaining to be 0, got %d", remaining) + } + } +} + +func TestRedisTokenBucketCases(t *testing.T) { + rdb := redis.NewClient(&redis.Options{ + Addr: "localhost:6379", + }) + + // TestBaseTime as a time.Duration from the Unix epoch + TestBaseTimeDelta := (TestBaseTimeSec * time.Second) + (TestBaseTimeUs * time.Microsecond) + + tests := []struct { + Name string + + // Initial number of tokens in the bucket + SkipInitialSetup bool + InitialTime time.Time + InitialTokens int + + // Arguments to the call + ArgTokens int + TimeDelta time.Duration + + // Expected results + Allowed bool + TokensLeft int64 + ReturnedTime time.Time // defaults to InitialTime + TimeDelta if zero + }{ + { + Name: "new request with tokens left", + + // We have one token in the bucket + InitialTokens: 1, + InitialTime: TestBaseTime, + + // We take one token from the bucket now + ArgTokens: 1, + + // It worked, and none are left + Allowed: true, + 
TokensLeft: 0, + }, + + { + Name: "new request where no values have been previously set", + + // The bucket has not been initialized + SkipInitialSetup: true, + + // We take one token from the bucket now + ArgTokens: 1, + TimeDelta: TestBaseTimeDelta, + + // It worked, and the bucket is full minus the token we just took + Allowed: true, + TokensLeft: TestCapacity - 1, + }, + + { + Name: "allows for a new request where values were set a long time ago", + + // The bucket is empty as of a "long time ago" + InitialTokens: 0, + InitialTime: time.Time{}, + + // We take one token from the bucket now + ArgTokens: 1, + TimeDelta: TestBaseTimeDelta, + + // It worked, and the bucket is full minus the token we just took + Allowed: true, + TokensLeft: TestCapacity - 1, + }, + + { + Name: "allows for a new request that goes back in time but leaves existing values unchanged", + + // We have one token in the bucket + InitialTokens: 1, + InitialTime: TestBaseTime, + + // "Something happens", and we go back in time while taking a token + ArgTokens: 1, + TimeDelta: -TestBaseTimeDelta, + + // We can take a token + Allowed: true, + TokensLeft: 0, + + // The values in Redis did not go back in time + ReturnedTime: TestBaseTime, + }, + + { + Name: "disallows for a new request at same moment without tokens left", + + // The bucket has no tokens in it + InitialTokens: 0, + InitialTime: TestBaseTime, + + // We take one token + ArgTokens: 1, + + // We cannot get a token + Allowed: false, + TokensLeft: 0, + }, + + { + Name: "allows for a new request without tokens left but time to replenish", + + // The bucket has no tokens in it + InitialTokens: 0, + InitialTime: TestBaseTime, + + // We take one token at exactly the moment that one has + // entered the bucket + ArgTokens: 1, + TimeDelta: SingleTokenDripTimeUs * time.Microsecond, + + // It worked, and the bucket is empty again at the new + // time + Allowed: true, + TokensLeft: 0, + }, + + { + Name: "similarly scales up if more time than 
necessary has passed", + + // The bucket has no tokens in it + InitialTokens: 0, + InitialTime: TestBaseTime, + + // We take one token, at the time where the bucket is almost full + ArgTokens: 1, + TimeDelta: (TestCapacity - 1) * SingleTokenDripTimeUs * time.Microsecond, + + // It worked, and the bucket contains the expected + // number of tokens, minus one for the token we just + // took. + Allowed: true, + TokensLeft: (TestCapacity - 1) - 1, + }, + + { + Name: "maxes out at the bucket's capacity", + + // The bucket has no tokens in it + InitialTokens: 0, + InitialTime: TestBaseTime, + + // We take one token at the time where one more token + // has been added to the bucket than should fit + ArgTokens: 1, + TimeDelta: (TestCapacity + 1) * SingleTokenDripTimeUs * time.Microsecond, + + // It worked, and the bucket contains one token less + // than the capacity, since it capped at the capacity + // and we just took one + Allowed: true, + TokensLeft: TestCapacity - 1, + }, + + { + Name: "allows and leaves values unchanged when requesting 0 tokens", + + // We have one token in the bucket + InitialTokens: 1, + InitialTime: TestBaseTime, + + // Take 0 tokens + ArgTokens: 0, + + // It worked, and nothing is changed + Allowed: true, + TokensLeft: 1, + }, + + { + Name: "allows and leaves values unchanged when requesting 0 tokens at 0 remaining", + + // We have one token in the bucket + InitialTokens: 0, + InitialTime: TestBaseTime, + + // Take 0 tokens + ArgTokens: 0, + + // It worked, and nothing is changed + Allowed: true, + TokensLeft: 0, + }, + + { + Name: "denies and returns remaining tokens when requesting more than one", + + // We have one token in the bucket + InitialTokens: 1, + InitialTime: TestBaseTime, + + // Take 2 tokens + ArgTokens: 2, + + // It worked, nothing is changed, and we can see the + // number of tokens in the bucket + Allowed: false, + TokensLeft: 1, + }, + + // Smoke test to ensure that real timestamps are supported + { + Name: "works with real 
timestamps", + + InitialTokens: 0, + InitialTime: time.Now(), + + // We take one token from the bucket "just after" it + // has been added to the bucket + ArgTokens: 1, + TimeDelta: SingleTokenDripTimeUs * time.Microsecond, + + // It worked, and the bucket is empty again + Allowed: true, + TokensLeft: 0, + }, + } + + for _, test := range tests { + t.Run(test.Name, func(t *testing.T) { + keys := makeKeys() + + // Go's zero time starts before the Unix epoch; return the Unix epoch + // if it's not specified + if test.InitialTime.IsZero() { + test.InitialTime = time.Unix(0, 0) + } + + // First, set the initial values in the database + if !test.SkipInitialSetup { + mustInitKeys(t, rdb, keys, int64(test.InitialTokens), test.InitialTime) + } + + // Calculate the time argument + argTimeSec, argTimeUs := timeToRedisPair(test.InitialTime.Add(test.TimeDelta)) + + // Next, call the script + resp := updateTokenBucketScript.Eval( + rdb, + keys.AsSlice(), + + // Args + TestReplenishRate, + TestCapacity, + test.ArgTokens, + argTimeSec, + argTimeUs, + "true", + ) + + // The expected returned time from Redis defaults to + // initial time plus the time delta unless one is + // explicitly specified. + var expectedReturnedTime, expectedReturnedTimeUs int64 + if test.ReturnedTime.IsZero() { + expectedReturnedTime, expectedReturnedTimeUs = timeToRedisPair( + test.InitialTime.Add(test.TimeDelta), + ) + } else { + expectedReturnedTime, expectedReturnedTimeUs = timeToRedisPair(test.ReturnedTime) + } + + // Use our helper function that deserializes the + // results and runs assertions + assertRedisResults(t, resp, + test.Allowed, + test.TokensLeft, + expectedReturnedTime, + expectedReturnedTimeUs, + ) + }) + } +} + +// Ensure that in the case where a request is made at a time between two tokens +// entering the bucket, that time is "returned to the user" and can be used in +// a future request. 
+func TestReturnedTimeMicroseconds(t *testing.T) { + rdb := redis.NewClient(&redis.Options{ + Addr: "localhost:6379", + }) + + // Unlike other tests, start at 0us to make this easier to read + refreshedTime := time.Unix(TestBaseTimeSec, 0) + + keys := makeKeys() + mustInitKeys(t, rdb, keys, 0, refreshedTime) + + // The extra 100us between the refresh time and the current time isn't + // enough to increment a new token, so it's returned to the user as the + // script sets the new values for the refreshed keys. + const extraTime = 100 + argTime := refreshedTime.Add((SingleTokenDripTimeUs + extraTime) * time.Microsecond) + argTimeSec, argTimeUs := timeToRedisPair(argTime) + + resp := updateTokenBucketScript.Eval( + rdb, + keys.AsSlice(), + + // Args + TestReplenishRate, + TestCapacity, + 1, + argTimeSec, + argTimeUs, + "true", + ) + assertRedisResults(t, resp, + true, + 0, + argTimeSec, + + // The microseconds component stored in the database should + // return any extra time to the caller for use with the next + // token. This manifests itself as the time in the database + // being in the past by the extra time, such that the next call + // to this script can use the extra time. + argTimeUs-extraTime, + ) +} + +func TestReturnedTimeAcrossSecondBoundary(t *testing.T) { + rdb := redis.NewClient(&redis.Options{ + Addr: "localhost:6379", + }) + + // Set the initial refreshed microsecond number right at the second + // boundary so that adding extra time below will push us over. + const refreshedTimeUs = MicrosecondsInSecond - 10 + refreshedTime := time.Unix(TestBaseTimeSec, refreshedTimeUs*int64(time.Microsecond/time.Nanosecond)) + + keys := makeKeys() + mustInitKeys(t, rdb, keys, 0, refreshedTime) + + // We call the token bucket at a time that's `extraTime` after the last + // refresh time, but without enough time having passed such that a + // token has entered the bucket. Due to the refreshed time above, this + // crosses a second boundary. 
+	const extraTime = 100
+	argTime := time.Unix(TestBaseTimeSec, 0)
+	argTime = argTime.Add((refreshedTimeUs + extraTime) * time.Microsecond)
+	argTimeSec, argTimeUs := timeToRedisPair(argTime)
+
+	// Sanity-check that we did actually cross a second boundary
+	if refreshedTime.Unix() != argTime.Unix()-1 {
+		t.Errorf("expected refreshed time to be 1 second before the argument time: %d != %d - 1",
+			refreshedTime.Unix(),
+			argTime.Unix()-1,
+		)
+	}
+
+	resp := updateTokenBucketScript.Eval(
+		rdb,
+		keys.AsSlice(),
+
+		// Args
+		TestReplenishRate,
+		TestCapacity,
+		1,
+		argTimeSec,
+		argTimeUs,
+		"true",
+	)
+	assertRedisResults(t, resp,
+		false,
+		0,
+
+		// We "subtracted" a second to add to the microseconds
+		// component
+		argTimeSec-1,
+
+		// The microseconds component includes the extra time
+		// subtracted from above
+		MicrosecondsInSecond+(argTimeUs-extraTime),
+	)
+
+	// This case is a little complicated to get right, so make sure
+	// that waiting the difference between `SingleTokenDripTimeUs`
+	// and `extraTime` does indeed give us one more token.
+	argTimeUs += (SingleTokenDripTimeUs - extraTime)
+
+	resp = updateTokenBucketScript.Eval(
+		rdb,
+		keys.AsSlice(),
+
+		// Args
+		TestReplenishRate,
+		TestCapacity,
+		1,
+		argTimeSec,
+		argTimeUs,
+		"true",
+	)
+	assertRedisResults(t, resp,
+		true,
+		0,
+		argTimeSec,
+		argTimeUs,
+	)
+}
+
+// Ensure that the token bucket works with large replenishment rates.
+func TestHighReplenishRate(t *testing.T) { + rdb := redis.NewClient(&redis.Options{ + Addr: "localhost:6379", + }) + + keys := makeKeys() + mustInitKeys(t, rdb, keys, 0, TestBaseTime) + + const ( + tokensToAdd = 10 + testReplenishRate = 5000 + testCapacity = 5000 + singleTokenReplenishTime = MicrosecondsInSecond / testReplenishRate + ) + argTime := TestBaseTime.Add(singleTokenReplenishTime * tokensToAdd * time.Microsecond) + argTimeSec, argTimeUs := timeToRedisPair(argTime) + + resp := updateTokenBucketScript.Eval( + rdb, + keys.AsSlice(), + + // Args + testReplenishRate, + testCapacity, + 1, + argTimeSec, + argTimeUs, + "true", + ) + assertRedisResults(t, resp, + true, + tokensToAdd-1, + argTimeSec, + argTimeUs, + ) +} + +// Ensure that the token bucket works with very low (i.e. less than one +// token/second) replenishment rates. +func TestLowReplenishRate(t *testing.T) { + rdb := redis.NewClient(&redis.Options{ + Addr: "localhost:6379", + }) + + keys := makeKeys() + mustInitKeys(t, rdb, keys, 0, TestBaseTime) + + const testReplenishRate = 0.5 + + // At a replenish rate of 0.5 it takes 2 seconds to get a full token + // added. Here we try to make a request at 1 second. We won't be able + // to make the request, but the script will subtract one second from + // the time it stores which will permit us to make the request in one + // more second. + argTimeSec, argTimeUs := timeToRedisPair(TestBaseTime.Add(1 * time.Second)) + resp := updateTokenBucketScript.Eval( + rdb, + keys.AsSlice(), + + // Args + testReplenishRate, + TestCapacity, + 1, + argTimeSec, + argTimeUs, + "true", + ) + assertRedisResults(t, resp, + false, + 0, + argTimeSec-1, + argTimeUs, + ) + + // Given one more second, the request is allowed. 
+	argTimeSec, argTimeUs = timeToRedisPair(TestBaseTime.Add(2 * time.Second))
+	resp = updateTokenBucketScript.Eval(
+		rdb,
+		keys.AsSlice(),
+
+		// Args
+		testReplenishRate,
+		TestCapacity,
+		1,
+		argTimeSec,
+		argTimeUs,
+		"true",
+	)
+	assertRedisResults(t, resp,
+		true,
+		0,
+		argTimeSec,
+		argTimeUs,
+	)
+}
+
+// Ensure that the token bucket works if we unexpectedly change the bucket size
+// and replenish rates.
+func TestChangingReplenishRate(t *testing.T) {
+	rdb := redis.NewClient(&redis.Options{
+		Addr: "localhost:6379",
+	})
+
+	tests := []struct {
+		Name      string
+		Requested int64
+		Expected  int64
+	}{
+		// The bucket has been "shrunk" to fit the new size
+		{
+			Name:      "request one token",
+			Requested: 1,
+			Expected:  TestCapacity - 1,
+		},
+
+		// The bucket will be shrunk to its maximum capacity even if
+		// we don't request anything
+		{
+			Name:      "request no tokens",
+			Requested: 0,
+			Expected:  TestCapacity,
+		},
+	}
+
+	for _, test := range tests {
+		t.Run(test.Name, func(t *testing.T) {
+			keys := makeKeys()
+			mustInitKeys(t, rdb, keys, TestCapacity*5, TestBaseTime)
+
+			argTimeSec, argTimeUs := timeToRedisPair(TestBaseTime)
+
+			resp := updateTokenBucketScript.Eval(
+				rdb,
+				keys.AsSlice(),
+
+				// Args
+				TestReplenishRate,
+				TestCapacity, // smaller than above
+				test.Requested,
+				argTimeSec,
+				argTimeUs,
+				"true",
+			)
+			assertRedisResults(t, resp,
+				true,
+				test.Expected,
+				argTimeSec,
+				argTimeUs,
+			)
+		})
+	}
+}
+
+func TestErrorIfMicrosecondsTooLarge(t *testing.T) {
+	rdb := redis.NewClient(&redis.Options{
+		Addr: "localhost:6379",
+	})
+
+	keys := makeKeys()
+	resp := updateTokenBucketScript.Eval(
+		rdb,
+		keys.AsSlice(),
+
+		// Args
+		TestReplenishRate,
+		TestCapacity,
+		1,
+		TestBaseTimeSec,
+		MicrosecondsInSecond,
+		"true",
+	)
+
+	err := resp.Err()
+	if err == nil {
+		t.Fatalf("expected an error, but got none")
+	}
+	if err.Error() != `now_us must be smaller than 10^6 (microseconds in a second)` {
+		t.Errorf("invalid error message: %s", err.Error())
+ } +} + +const alphanumeric = "abcdefghijklmnopqrstuvwxyz" + + "ABCDEFGHIJKLMNOPQRSTUVWXYZ" + + "0123456789" + +func stringWithCharset(length int, charset string) string { + setlen := big.NewInt(int64(len(charset))) + + b := make([]byte, length) + for i := range b { + ch, err := rand.Int(rand.Reader, setlen) + if err != nil { + panic(err) + } + + b[i] = charset[ch.Int64()] + } + return string(b) +} + +// Wrapper type for the keys we use in our call to the script. +type redisKeys struct { + Prefix string + Tokens string + Refreshed string + RefreshedUs string +} + +// Get the keys in a slice format for passing to Eval +func (r redisKeys) AsSlice() []string { + return []string{r.Tokens, r.Refreshed, r.RefreshedUs} +} + +// Helper function to create Redis key names wrapper type +func makeKeys() redisKeys { + var ret redisKeys + + // Ensure each test ends up under a different key + ret.Prefix = stringWithCharset(20, alphanumeric) + + // Use same key generation function as the live code + keys := tokenBucketKeys(ret.Prefix) + ret.Tokens = keys[0] + ret.Refreshed = keys[1] + ret.RefreshedUs = keys[2] + return ret +} + +// Helper function to initialize Redis keys to given values +func mustInitKeys(t *testing.T, c *redis.Client, keys redisKeys, tokens int64, refreshed time.Time) { + t.Helper() + + refreshedSec, refreshedUs := timeToRedisPair(refreshed) + + if err := c.Set(keys.Tokens, tokens, TTL*time.Second).Err(); err != nil { + t.Fatalf("failed to set initial tokens: %v", err) + } + if err := c.Set(keys.Refreshed, refreshedSec, TTL*time.Second).Err(); err != nil { + t.Fatalf("failed to set refreshed time: %v", err) + } + if err := c.Set(keys.RefreshedUs, refreshedUs, TTL*time.Second).Err(); err != nil { + t.Fatalf("failed to set refreshed time (us): %v", err) + } +} + +// Helper function that deserializes the returned values from our Redis script. 
+func deserializeRedisResults(t *testing.T, resp *redis.Cmd) (bool, int64, int64, int64) { + t.Helper() + + results, err := resp.Result() + if err != nil { + t.Fatalf("error calling script: %v", err) + } + + // Deserialize results + returnVals, ok := results.([]interface{}) + if !ok { + t.Fatalf("invalid return type %T (expected []interface{})", results) + } + + var allowed bool + if returnVals[0] == nil { + allowed = false + } else if i, ok := returnVals[0].(int64); ok { + allowed = i == 1 + } else { + t.Fatalf("invalid 'allowed' type %T", returnVals[0]) + } + + remaining := returnVals[1].(int64) + returnedTime := returnVals[2].(int64) + returnedTimeUs := returnVals[3].(int64) + + return allowed, remaining, returnedTime, returnedTimeUs +} + +// Helper function that deserializes the returned values from our Redis script, +// and then asserts that they match the expected values provided. +func assertRedisResults(t *testing.T, resp *redis.Cmd, allowed bool, remaining, returnedTime, returnedTimeUs int64) { + t.Helper() + + actualAllowed, + actualRemaining, + actualReturnedTime, + actualReturnedTimeUs := deserializeRedisResults(t, resp) + + if allowed != actualAllowed { + t.Errorf("expected 'allowed' to be %t but got %t", allowed, actualAllowed) + } + if remaining != actualRemaining { + t.Errorf("expected 'remaining' to be %d but got %d", remaining, actualRemaining) + } + if returnedTime != actualReturnedTime { + t.Errorf("expected returned time to be %d but got %d", returnedTime, actualReturnedTime) + } + if returnedTimeUs != actualReturnedTimeUs { + t.Errorf("expected returned time (us) to be %d but got %d", returnedTimeUs, actualReturnedTimeUs) + } +} diff --git a/quota/redis/redistb/update_token_bucket.gen.go b/quota/redis/redistb/update_token_bucket.gen.go new file mode 100644 index 0000000000..f6b773af7f --- /dev/null +++ b/quota/redis/redistb/update_token_bucket.gen.go @@ -0,0 +1,14 @@ +// Code generated by quota/redis/redistb/gen.go. DO NOT EDIT. 
+// source: update_token_bucket.lua + +package redistb + +import ( + "github.com/go-redis/redis" +) + +// contents of the 'updateTokenBucket' Redis Lua script +const updateTokenBucketScriptContents = "--[[\n\nLICENSE\n===================\n\nCopyright 2017 Google Inc. All Rights Reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\nTOKEN BUCKET\n===================\n\nScript to read and update a token bucket maintained in Redis. This is an\nimplementation of the token bucket algorithm which is a common fixture seen in\nrate limiting:\n\n https://en.wikipedia.org/wiki/Token_bucket\n\nFor each key prefix, we maintain three values:\n\n * `.tokens`: Number of tokens in bucket at refresh time.\n\n * `.refreshed`: Time in epoch seconds when this prefix's bucket was\n last updated.\n\n * `.refreshed_us`: The microsecond component of the last updated\n time above. Stored separately because a Unix epoch with a microsecond\n component brushes up uncomfortably close to integer boundaries.\n\nThe basic strategy is to, at update/read time, fill in all tokens\nthat would have accumulated since the last update, and then if\npossible deduct the number of requested tokens (or disallow the\nrequested action if there are not enough tokens).\n\nThe approach relies on the atomicity of EVAL in redis - only 1 command (EVAL or\notherwise) will be running concurrently per shard in the Redis cluster. 
Redis\nand Lua are very fast, so in practice this works out okay.\n\nA note on units: all times (instants) are measured in epoch seconds with a\nseparate microsecond component, durations in imicroseconds, and rates in\ntokens/second (e.g., a rate of 100 is 100 tokens/second).\n\nFor debugging, I'd recommend adding Redis log statements and then tailing your\nRedis log. Example:\n\n redis.log(redis.LOG_WARNING, string.format(\"rate = %s\", rate))\n\n--]]\n\n--\n-- Constants\n--\n-- Lua doesn't actually have constants, so these are constants by convention\n-- only. Please don't modify them.\n--\n\nlocal MICROSECONDS_IN_SECOND = 1000000.0\n\n--\n-- Functions\n--\n\nlocal function subtract_time (base, base_us, leftover_time_us)\n base = base - math.floor(leftover_time_us / MICROSECONDS_IN_SECOND)\n\n leftover_time_us = leftover_time_us % MICROSECONDS_IN_SECOND\n\n base_us = base_us - leftover_time_us\n if base_us < 0 then\n base = base - 1\n base_us = MICROSECONDS_IN_SECOND + base_us\n end\n\n return base, base_us\nend\n\n--\n-- Keys and arguments\n--\n\nlocal key_tokens = KEYS[1]\n\n-- Unix time since the epoch in microseconds runs up uncomfortably close to\n-- integer boundaries, so we store time as two separate components: (1) seconds\n-- since epoch, and (2) microseconds with the current second.\nlocal key_refreshed = KEYS[2]\nlocal key_refreshed_us = KEYS[3]\n\nlocal rate = tonumber(ARGV[1])\nlocal capacity = tonumber(ARGV[2])\nlocal requested = tonumber(ARGV[3])\n\n-- Callers are allowed to inject the current time into the script, but note\n-- that outside of testing, this will always superseded by the time reported by\n-- the Redis instance so as to protect against clock drift on any particular\n-- local node.\nlocal now = tonumber(ARGV[4])\nlocal now_us = tonumber(ARGV[5])\n\n-- This is ugly, but all values passed in from Ruby get converted to strings\nlocal testing = ARGV[6] == \"true\"\n\n--\n-- Program body\n--\n\n-- See comment above.\nif testing then\n if 
now_us >= MICROSECONDS_IN_SECOND then\n return redis.error_reply(\"now_us must be smaller than 10^6 (microseconds in a second)\")\n end\nelse\n -- Scripts in Redis are pure functions by default which allows Redis to\n -- replicate the entire script rather than the individual commands that it\n -- contains. Because we're about to invoke `TIME` which produces a\n -- non-deterministic result, we need to tell Redis to instead switch to\n -- command-level replication for write operations. It will error if we\n -- don't.\n redis.replicate_commands()\n\n local current_time = redis.call(\"TIME\")\n\n -- Redis `TIME` comes back in two components: (1) seconds since epoch, and\n -- (2) microseconds within the current second.\n now = tonumber(current_time[1])\n now_us = tonumber(current_time[2])\nend\n\nlocal filled_tokens = capacity\n\nlocal last_tokens = redis.call(\"GET\", key_tokens)\n\nlocal last_refreshed = redis.call(\"GET\", key_refreshed)\n\nlocal last_refreshed_us = redis.call(\"GET\", key_refreshed_us)\n\n-- Only bother performing rate calculations if we actually need to. i.e., The\n-- user has made a request recently enough to still be in the system.\nif last_tokens and last_refreshed then\n last_tokens = tonumber(last_tokens)\n last_refreshed = tonumber(last_refreshed)\n\n -- Rejected a `now` that reads before our recorded `last_refreshed` time.\n -- No reversed deltas are allowed.\n if now < last_refreshed then\n now = last_refreshed\n now_us = last_refreshed_us\n end\n\n local delta = now - last_refreshed\n local delta_us = delta * MICROSECONDS_IN_SECOND + (now_us - last_refreshed_us)\n\n -- The time (in microseconds) that it takes to \"drip\" a single token. 
For\n -- example, if our rate is 100 tokens per second, then one token is allowed\n -- every 10^6 / 100 = 10,000 microseconds.\n local single_token_time_us = math.floor(MICROSECONDS_IN_SECOND / rate)\n\n local new_tokens = math.floor(delta_us / single_token_time_us)\n filled_tokens = math.min(capacity, last_tokens + new_tokens)\n\n -- For maximum fairness, modify the last refresh time by any leftover time\n -- that didn't go towards adding a token.\n --\n -- However, only bother with this if the bucket hasn't been replenished to\n -- full capacity. If it was, the user has had more replenishment time than\n -- they can use anyway.\n if filled_tokens ~= capacity then\n local leftover_time_us = delta_us % single_token_time_us\n now, now_us = subtract_time(now, now_us, leftover_time_us)\n end\nend\n\nlocal allowed = filled_tokens >= requested\nlocal new_tokens = filled_tokens\nif allowed then\n new_tokens = filled_tokens - requested\nend\n\n-- Set a TTL on the values we set in Redis that will expire them after the\n-- point in time they would have been fully replenished, which allows us to\n-- manage space more efficiently by removing keys that don't need to be in\n-- there.\n--\n-- Keys that are ~always in use because their owners make frequent requests\n-- will be updated by this script constantly (which sets new TTLs), and\n-- never expire.\nlocal fill_time = math.ceil(capacity / rate)\nlocal ttl = math.floor(fill_time * 2)\n\n-- Redis will reject a expiry of 0 to `SETEX`, so make sure TTL is always at\n-- least 1.\nttl = math.max(ttl, 1)\n\n-- In our tests we freeze time. 
Because we can't freeze Redis' notion of time\n-- and want to make sure that keys we set within test cases don't expire, we\n-- forego the standard TTL that we would have set for just a long one to make\n-- sure anything we set expires well after the test case will have finished.\nif testing then\n ttl = 3600\nend\n\nredis.call(\"SETEX\", key_tokens, ttl, new_tokens)\nredis.call(\"SETEX\", key_refreshed, ttl, now)\nredis.call(\"SETEX\", key_refreshed_us, ttl, now_us)\n\nreturn { allowed, new_tokens, now, now_us }\n" + +// Redis Script type for the 'updateTokenBucket' Redis lua script +var updateTokenBucketScript = redis.NewScript(updateTokenBucketScriptContents) diff --git a/quota/redis/redistb/update_token_bucket.lua b/quota/redis/redistb/update_token_bucket.lua new file mode 100644 index 0000000000..0fea85494a --- /dev/null +++ b/quota/redis/redistb/update_token_bucket.lua @@ -0,0 +1,216 @@ +--[[ + +LICENSE +=================== + +Copyright 2017 Google Inc. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. + +TOKEN BUCKET +=================== + +Script to read and update a token bucket maintained in Redis. This is an +implementation of the token bucket algorithm which is a common fixture seen in +rate limiting: + + https://en.wikipedia.org/wiki/Token_bucket + +For each key prefix, we maintain three values: + + * `.tokens`: Number of tokens in bucket at refresh time. + + * `.refreshed`: Time in epoch seconds when this prefix's bucket was + last updated. 
+ + * `.refreshed_us`: The microsecond component of the last updated + time above. Stored separately because a Unix epoch with a microsecond + component brushes up uncomfortably close to integer boundaries. + +The basic strategy is to, at update/read time, fill in all tokens +that would have accumulated since the last update, and then if +possible deduct the number of requested tokens (or disallow the +requested action if there are not enough tokens). + +The approach relies on the atomicity of EVAL in redis - only 1 command (EVAL or +otherwise) will be running concurrently per shard in the Redis cluster. Redis +and Lua are very fast, so in practice this works out okay. + +A note on units: all times (instants) are measured in epoch seconds with a +separate microsecond component, durations in imicroseconds, and rates in +tokens/second (e.g., a rate of 100 is 100 tokens/second). + +For debugging, I'd recommend adding Redis log statements and then tailing your +Redis log. Example: + + redis.log(redis.LOG_WARNING, string.format("rate = %s", rate)) + +--]] + +-- +-- Constants +-- +-- Lua doesn't actually have constants, so these are constants by convention +-- only. Please don't modify them. +-- + +local MICROSECONDS_IN_SECOND = 1000000.0 + +-- +-- Functions +-- + +local function subtract_time (base, base_us, leftover_time_us) + base = base - math.floor(leftover_time_us / MICROSECONDS_IN_SECOND) + + leftover_time_us = leftover_time_us % MICROSECONDS_IN_SECOND + + base_us = base_us - leftover_time_us + if base_us < 0 then + base = base - 1 + base_us = MICROSECONDS_IN_SECOND + base_us + end + + return base, base_us +end + +-- +-- Keys and arguments +-- + +local key_tokens = KEYS[1] + +-- Unix time since the epoch in microseconds runs up uncomfortably close to +-- integer boundaries, so we store time as two separate components: (1) seconds +-- since epoch, and (2) microseconds with the current second. 
+local key_refreshed = KEYS[2] +local key_refreshed_us = KEYS[3] + +local rate = tonumber(ARGV[1]) +local capacity = tonumber(ARGV[2]) +local requested = tonumber(ARGV[3]) + +-- Callers are allowed to inject the current time into the script, but note +-- that outside of testing, this will always superseded by the time reported by +-- the Redis instance so as to protect against clock drift on any particular +-- local node. +local now = tonumber(ARGV[4]) +local now_us = tonumber(ARGV[5]) + +-- This is ugly, but all values passed in from Ruby get converted to strings +local testing = ARGV[6] == "true" + +-- +-- Program body +-- + +-- See comment above. +if testing then + if now_us >= MICROSECONDS_IN_SECOND then + return redis.error_reply("now_us must be smaller than 10^6 (microseconds in a second)") + end +else + -- Scripts in Redis are pure functions by default which allows Redis to + -- replicate the entire script rather than the individual commands that it + -- contains. Because we're about to invoke `TIME` which produces a + -- non-deterministic result, we need to tell Redis to instead switch to + -- command-level replication for write operations. It will error if we + -- don't. + redis.replicate_commands() + + local current_time = redis.call("TIME") + + -- Redis `TIME` comes back in two components: (1) seconds since epoch, and + -- (2) microseconds within the current second. + now = tonumber(current_time[1]) + now_us = tonumber(current_time[2]) +end + +local filled_tokens = capacity + +local last_tokens = redis.call("GET", key_tokens) + +local last_refreshed = redis.call("GET", key_refreshed) + +local last_refreshed_us = redis.call("GET", key_refreshed_us) + +-- Only bother performing rate calculations if we actually need to. i.e., The +-- user has made a request recently enough to still be in the system. 
+if last_tokens and last_refreshed then + last_tokens = tonumber(last_tokens) + last_refreshed = tonumber(last_refreshed) + + -- Rejected a `now` that reads before our recorded `last_refreshed` time. + -- No reversed deltas are allowed. + if now < last_refreshed then + now = last_refreshed + now_us = last_refreshed_us + end + + local delta = now - last_refreshed + local delta_us = delta * MICROSECONDS_IN_SECOND + (now_us - last_refreshed_us) + + -- The time (in microseconds) that it takes to "drip" a single token. For + -- example, if our rate is 100 tokens per second, then one token is allowed + -- every 10^6 / 100 = 10,000 microseconds. + local single_token_time_us = math.floor(MICROSECONDS_IN_SECOND / rate) + + local new_tokens = math.floor(delta_us / single_token_time_us) + filled_tokens = math.min(capacity, last_tokens + new_tokens) + + -- For maximum fairness, modify the last refresh time by any leftover time + -- that didn't go towards adding a token. + -- + -- However, only bother with this if the bucket hasn't been replenished to + -- full capacity. If it was, the user has had more replenishment time than + -- they can use anyway. + if filled_tokens ~= capacity then + local leftover_time_us = delta_us % single_token_time_us + now, now_us = subtract_time(now, now_us, leftover_time_us) + end +end + +local allowed = filled_tokens >= requested +local new_tokens = filled_tokens +if allowed then + new_tokens = filled_tokens - requested +end + +-- Set a TTL on the values we set in Redis that will expire them after the +-- point in time they would have been fully replenished, which allows us to +-- manage space more efficiently by removing keys that don't need to be in +-- there. +-- +-- Keys that are ~always in use because their owners make frequent requests +-- will be updated by this script constantly (which sets new TTLs), and +-- never expire. 
+local fill_time = math.ceil(capacity / rate) +local ttl = math.floor(fill_time * 2) + +-- Redis will reject a expiry of 0 to `SETEX`, so make sure TTL is always at +-- least 1. +ttl = math.max(ttl, 1) + +-- In our tests we freeze time. Because we can't freeze Redis' notion of time +-- and want to make sure that keys we set within test cases don't expire, we +-- forego the standard TTL that we would have set for just a long one to make +-- sure anything we set expires well after the test case will have finished. +if testing then + ttl = 3600 +end + +redis.call("SETEX", key_tokens, ttl, new_tokens) +redis.call("SETEX", key_refreshed, ttl, now) +redis.call("SETEX", key_refreshed_us, ttl, now_us) + +return { allowed, new_tokens, now, now_us }