From 97150340a7037dd11c4ee9045323cd811761e6c9 Mon Sep 17 00:00:00 2001 From: Nelz Date: Thu, 20 Jul 2017 18:32:41 -0700 Subject: [PATCH 1/2] More generic API to work with x/time/rate --- ratelimit/token_bucket.go | 89 +++++++++++++++++++++++++++++++--- ratelimit/token_bucket_test.go | 28 +++++++++++ 2 files changed, 110 insertions(+), 7 deletions(-) diff --git a/ratelimit/token_bucket.go b/ratelimit/token_bucket.go index f6df9d00e..7bc06b2bd 100644 --- a/ratelimit/token_bucket.go +++ b/ratelimit/token_bucket.go @@ -18,9 +18,37 @@ var ErrLimited = errors.New("rate limit exceeded") // limiter based on a token-bucket algorithm. Requests that would exceed the // maximum request rate are simply rejected with an error. func NewTokenBucketLimiter(tb *ratelimit.Bucket) endpoint.Middleware { + return NewErroringLimiter(NewAllower(tb)) +} + +// NewTokenBucketThrottler returns an endpoint.Middleware that acts as a +// request throttler based on a token-bucket algorithm. Requests that would +// exceed the maximum request rate are delayed via the parameterized sleep +// function. By default you may pass time.Sleep. +func NewTokenBucketThrottler(tb *ratelimit.Bucket, sleep func(time.Duration)) endpoint.Middleware { + // return NewDelayingLimiter(NewWaiter(tb)) return func(next endpoint.Endpoint) endpoint.Endpoint { return func(ctx context.Context, request interface{}) (interface{}, error) { - if tb.TakeAvailable(1) == 0 { + sleep(tb.Take(1)) + return next(ctx, request) + } + } +} + +// Allower dictates whether or not a request is acceptable to run. +// The Limiter from "golang.org/x/time/rate" already implements this interface, +// one is able to use that in NewErroringLimiter without any modifications. +type Allower interface { + Allow() bool +} + +// NewErroringLimiter returns an endpoint.Middleware that acts as a rate +// limiter. Requests that would exceed the +// maximum request rate are simply rejected with an error. +func NewErroringLimiter(limit Allower) endpoint.Middleware { + return func(next endpoint.Endpoint) endpoint.Endpoint { + return func(ctx context.Context, request interface{}) (interface{}, error) { + if !limit.Allow() { return nil, ErrLimited } return next(ctx, request) @@ -28,15 +56,62 @@ func NewTokenBucketLimiter(tb *ratelimit.Bucket) endpoint.Middleware { } } -// NewTokenBucketThrottler returns an endpoint.Middleware that acts as a -// request throttler based on a token-bucket algorithm. Requests that would -// exceed the maximum request rate are delayed via the parameterized sleep -// function. By default you may pass time.Sleep. -func NewTokenBucketThrottler(tb *ratelimit.Bucket, sleep func(time.Duration)) endpoint.Middleware { +// Waiter dictates how long a request must be delayed. +// The Limiter from "golang.org/x/time/rate" already implements this interface, +// one is able to use that in NewDelayingLimiter without any modifications. +type Waiter interface { + Wait(ctx context.Context) error +} + +// NewDelayingLimiter returns an endpoint.Middleware that acts as a +// request throttler. Requests that would +// exceed the maximum request rate are delayed via the Waiter function +func NewDelayingLimiter(limit Waiter) endpoint.Middleware { return func(next endpoint.Endpoint) endpoint.Endpoint { return func(ctx context.Context, request interface{}) (interface{}, error) { - sleep(tb.Take(1)) + if err := limit.Wait(ctx); err != nil { + return nil, err + } return next(ctx, request) } } } + +// AllowerFunc is an adapter that lets a function operate as if +// it implements Allower +type AllowerFunc func() bool + +// Allow makes the adapter implement Allower +func (f AllowerFunc) Allow() bool { + return f() +} + +// NewAllower turns an existing ratelimit.Bucket into an API-compatible form +func NewAllower(tb *ratelimit.Bucket) Allower { + return AllowerFunc(func() bool { + return (tb.TakeAvailable(1) != 0) + }) +} + +// WaiterFunc is an adapter that lets a function operate as if +// it implements Waiter +type WaiterFunc func(ctx context.Context) error + +// Wait makes the adapter implement Waiter +func (f WaiterFunc) Wait(ctx context.Context) error { + return f(ctx) +} + +// NewWaiter turns an existing ratelimit.Bucket into an API-compatible form +func NewWaiter(tb *ratelimit.Bucket) Waiter { + return WaiterFunc(func(ctx context.Context) error { + dur := tb.Take(1) + select { + case <-ctx.Done(): + return ctx.Err() + case <-time.After(dur): + // happy path + } + return nil + }) +} diff --git a/ratelimit/token_bucket_test.go b/ratelimit/token_bucket_test.go index 54225b80e..319f845df 100644 --- a/ratelimit/token_bucket_test.go +++ b/ratelimit/token_bucket_test.go @@ -3,6 +3,7 @@ package ratelimit_test import ( "context" "math" + "strings" "testing" "time" @@ -10,6 +11,7 @@ import ( "github.com/go-kit/kit/endpoint" "github.com/go-kit/kit/ratelimit" + "golang.org/x/time/rate" ) func TestTokenBucketLimiter(t *testing.T) { @@ -53,3 +55,29 @@ func testLimiter(t *testing.T, e endpoint.Endpoint, rate int) { t.Errorf("rate=%d: want %v, have %v", rate, ratelimit.ErrLimited, err) } } + +func TestXRateErroring(t *testing.T) { + limit := rate.NewLimiter(rate.Every(time.Minute), 1) + e := func(context.Context, interface{}) (interface{}, error) { return struct{}{}, nil } + testLimiter(t, ratelimit.NewErroringLimiter(limit)(e), 1) +} + +func TestXRateDelaying(t *testing.T) { + limit := rate.NewLimiter(rate.Every(time.Minute), 1) + e := func(context.Context, interface{}) (interface{}, error) { return struct{}{}, nil } + e = ratelimit.NewDelayingLimiter(limit)(e) + + _, err := e(context.Background(), struct{}{}) + if err != nil { + t.Errorf("unexpected: %v\n", err) + } + + dur := 500 * time.Millisecond + ctx, cxl := context.WithTimeout(context.Background(), dur) + defer cxl() + + _, err = e(ctx, struct{}{}) + if !strings.Contains(err.Error(), "exceed context deadline") { + t.Errorf("expected timeout: %v\n", err) + } +} From 5ea37e7a6b609d6be4342faff91f9fde8f1f540f Mon Sep 17 00:00:00 2001 From: Nelz Date: Fri, 21 Jul 2017 16:07:58 -0700 Subject: [PATCH 2/2] All old expressed in terms of new --- ratelimit/token_bucket.go | 15 +++---- ratelimit/token_bucket_test.go | 82 +++++++++++++--------------------- 2 files changed, 36 insertions(+), 61 deletions(-) diff --git a/ratelimit/token_bucket.go b/ratelimit/token_bucket.go index 7bc06b2bd..b71e50bb1 100644 --- a/ratelimit/token_bucket.go +++ b/ratelimit/token_bucket.go @@ -23,16 +23,11 @@ func NewTokenBucketLimiter(tb *ratelimit.Bucket) endpoint.Middleware { // NewTokenBucketThrottler returns an endpoint.Middleware that acts as a // request throttler based on a token-bucket algorithm. Requests that would -// exceed the maximum request rate are delayed via the parameterized sleep -// function. By default you may pass time.Sleep. -func NewTokenBucketThrottler(tb *ratelimit.Bucket, sleep func(time.Duration)) endpoint.Middleware { - // return NewDelayingLimiter(NewWaiter(tb)) - return func(next endpoint.Endpoint) endpoint.Endpoint { - return func(ctx context.Context, request interface{}) (interface{}, error) { - sleep(tb.Take(1)) - return next(ctx, request) - } - } +// exceed the maximum request rate are delayed. +// The parameterized function "_" is kept for backwards-compatiblity of +// the API, but it is no longer used for anything. You may pass it nil. +func NewTokenBucketThrottler(tb *ratelimit.Bucket, _ func(time.Duration)) endpoint.Middleware { + return NewDelayingLimiter(NewWaiter(tb)) } // Allower dictates whether or not a request is acceptable to run. diff --git a/ratelimit/token_bucket_test.go b/ratelimit/token_bucket_test.go index 319f845df..d444fe992 100644 --- a/ratelimit/token_bucket_test.go +++ b/ratelimit/token_bucket_test.go @@ -2,82 +2,62 @@ package ratelimit_test import ( "context" - "math" "strings" "testing" "time" jujuratelimit "github.com/juju/ratelimit" + "golang.org/x/time/rate" "github.com/go-kit/kit/endpoint" "github.com/go-kit/kit/ratelimit" - "golang.org/x/time/rate" ) +var nopEndpoint = func(context.Context, interface{}) (interface{}, error) { return struct{}{}, nil } + func TestTokenBucketLimiter(t *testing.T) { - e := func(context.Context, interface{}) (interface{}, error) { return struct{}{}, nil } - for _, n := range []int{1, 2, 100} { - tb := jujuratelimit.NewBucketWithRate(float64(n), int64(n)) - testLimiter(t, ratelimit.NewTokenBucketLimiter(tb)(e), n) - } + tb := jujuratelimit.NewBucket(time.Minute, 1) + testSuccessThenFailure( + t, + ratelimit.NewTokenBucketLimiter(tb)(nopEndpoint), + ratelimit.ErrLimited.Error()) } func TestTokenBucketThrottler(t *testing.T) { - d := time.Duration(0) - s := func(d0 time.Duration) { d = d0 } - - e := func(context.Context, interface{}) (interface{}, error) { return struct{}{}, nil } - e = ratelimit.NewTokenBucketThrottler(jujuratelimit.NewBucketWithRate(1, 1), s)(e) - - // First request should go through with no delay. - e(context.Background(), struct{}{}) - if want, have := time.Duration(0), d; want != have { - t.Errorf("want %s, have %s", want, have) - } - - // Next request should request a ~1s sleep. - e(context.Background(), struct{}{}) - if want, have, tol := time.Second, d, time.Millisecond; math.Abs(float64(want-have)) > float64(tol) { - t.Errorf("want %s, have %s", want, have) - } -} - -func testLimiter(t *testing.T, e endpoint.Endpoint, rate int) { - // First requests should succeed. - for i := 0; i < rate; i++ { - if _, err := e(context.Background(), struct{}{}); err != nil { - t.Fatalf("rate=%d: request %d/%d failed: %v", rate, i+1, rate, err) - } - } - - // Next request should fail. - if _, err := e(context.Background(), struct{}{}); err != ratelimit.ErrLimited { - t.Errorf("rate=%d: want %v, have %v", rate, ratelimit.ErrLimited, err) - } + tb := jujuratelimit.NewBucket(time.Minute, 1) + testSuccessThenFailure( + t, + ratelimit.NewTokenBucketThrottler(tb, nil)(nopEndpoint), + "context deadline exceeded") } func TestXRateErroring(t *testing.T) { limit := rate.NewLimiter(rate.Every(time.Minute), 1) - e := func(context.Context, interface{}) (interface{}, error) { return struct{}{}, nil } - testLimiter(t, ratelimit.NewErroringLimiter(limit)(e), 1) + testSuccessThenFailure( + t, + ratelimit.NewErroringLimiter(limit)(nopEndpoint), + ratelimit.ErrLimited.Error()) } func TestXRateDelaying(t *testing.T) { limit := rate.NewLimiter(rate.Every(time.Minute), 1) - e := func(context.Context, interface{}) (interface{}, error) { return struct{}{}, nil } - e = ratelimit.NewDelayingLimiter(limit)(e) + testSuccessThenFailure( + t, + ratelimit.NewDelayingLimiter(limit)(nopEndpoint), + "exceed context deadline") +} - _, err := e(context.Background(), struct{}{}) - if err != nil { +func testSuccessThenFailure(t *testing.T, e endpoint.Endpoint, failContains string) { + ctx, cxl := context.WithTimeout(context.Background(), 500*time.Millisecond) + defer cxl() + + // First request should succeed. + if _, err := e(ctx, struct{}{}); err != nil { t.Errorf("unexpected: %v\n", err) } - dur := 500 * time.Millisecond - ctx, cxl := context.WithTimeout(context.Background(), dur) - defer cxl() - - _, err = e(ctx, struct{}{}) - if !strings.Contains(err.Error(), "exceed context deadline") { - t.Errorf("expected timeout: %v\n", err) + // Next request should fail. + if _, err := e(ctx, struct{}{}); !strings.Contains(err.Error(), failContains) { + t.Errorf("expected `%s`: %v\n", failContains, err) } }