Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

More generic API to work with x/time/rate #582

Merged
merged 2 commits into from
Jul 26, 2017
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
84 changes: 77 additions & 7 deletions ratelimit/token_bucket.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,25 +18,95 @@ var ErrLimited = errors.New("rate limit exceeded")
// limiter based on a token-bucket algorithm. Requests that would exceed the
// maximum request rate are simply rejected with an error.
func NewTokenBucketLimiter(tb *ratelimit.Bucket) endpoint.Middleware {
return NewErroringLimiter(NewAllower(tb))
}

// NewTokenBucketThrottler returns an endpoint.Middleware that acts as a
// request throttler based on a token-bucket algorithm. Requests that would
// exceed the maximum request rate are delayed.
// The parameterized function "_" is kept for backwards-compatiblity of
// the API, but it is no longer used for anything. You may pass it nil.
func NewTokenBucketThrottler(tb *ratelimit.Bucket, _ func(time.Duration)) endpoint.Middleware {
return NewDelayingLimiter(NewWaiter(tb))
}

// Allower dictates whether or not a request is acceptable to run.
// The Limiter from "golang.org/x/time/rate" already implements this interface,
// one is able to use that in NewErroringLimiter without any modifications.
type Allower interface {
Allow() bool
}

// NewErroringLimiter returns an endpoint.Middleware that acts as a rate
// limiter. Requests that would exceed the
// maximum request rate are simply rejected with an error.
func NewErroringLimiter(limit Allower) endpoint.Middleware {
return func(next endpoint.Endpoint) endpoint.Endpoint {
return func(ctx context.Context, request interface{}) (interface{}, error) {
if tb.TakeAvailable(1) == 0 {
if !limit.Allow() {
return nil, ErrLimited
}
return next(ctx, request)
}
}
}

// NewTokenBucketThrottler returns an endpoint.Middleware that acts as a
// request throttler based on a token-bucket algorithm. Requests that would
// exceed the maximum request rate are delayed via the parameterized sleep
// function. By default you may pass time.Sleep.
func NewTokenBucketThrottler(tb *ratelimit.Bucket, sleep func(time.Duration)) endpoint.Middleware {
// Waiter dictates how long a request must be delayed.
// The Limiter from "golang.org/x/time/rate" already implements this interface,
// one is able to use that in NewDelayingLimiter without any modifications.
type Waiter interface {
Wait(ctx context.Context) error
}

// NewDelayingLimiter returns an endpoint.Middleware that acts as a
// request throttler. Requests that would
// exceed the maximum request rate are delayed via the Waiter function
func NewDelayingLimiter(limit Waiter) endpoint.Middleware {
return func(next endpoint.Endpoint) endpoint.Endpoint {
return func(ctx context.Context, request interface{}) (interface{}, error) {
sleep(tb.Take(1))
if err := limit.Wait(ctx); err != nil {
return nil, err
}
return next(ctx, request)
}
}
}

// AllowerFunc is an adapter that lets a function operate as if
// it implements Allower
type AllowerFunc func() bool

// Allow makes the adapter implement Allower
func (f AllowerFunc) Allow() bool {
return f()
}

// NewAllower turns an existing ratelimit.Bucket into an API-compatible form
func NewAllower(tb *ratelimit.Bucket) Allower {
return AllowerFunc(func() bool {
return (tb.TakeAvailable(1) != 0)
})
}

// WaiterFunc is an adapter that lets a function operate as if
// it implements Waiter
type WaiterFunc func(ctx context.Context) error

// Wait makes the adapter implement Waiter
func (f WaiterFunc) Wait(ctx context.Context) error {
return f(ctx)
}

// NewWaiter turns an existing ratelimit.Bucket into an API-compatible form
func NewWaiter(tb *ratelimit.Bucket) Waiter {
return WaiterFunc(func(ctx context.Context) error {
dur := tb.Take(1)
select {
case <-ctx.Done():
return ctx.Err()
case <-time.After(dur):
// happy path
}
return nil
})
}
66 changes: 37 additions & 29 deletions ratelimit/token_bucket_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,54 +2,62 @@ package ratelimit_test

import (
"context"
"math"
"strings"
"testing"
"time"

jujuratelimit "github.com/juju/ratelimit"
"golang.org/x/time/rate"

"github.com/go-kit/kit/endpoint"
"github.com/go-kit/kit/ratelimit"
)

var nopEndpoint = func(context.Context, interface{}) (interface{}, error) { return struct{}{}, nil }

func TestTokenBucketLimiter(t *testing.T) {
e := func(context.Context, interface{}) (interface{}, error) { return struct{}{}, nil }
for _, n := range []int{1, 2, 100} {
tb := jujuratelimit.NewBucketWithRate(float64(n), int64(n))
testLimiter(t, ratelimit.NewTokenBucketLimiter(tb)(e), n)
}
tb := jujuratelimit.NewBucket(time.Minute, 1)
testSuccessThenFailure(
t,
ratelimit.NewTokenBucketLimiter(tb)(nopEndpoint),
ratelimit.ErrLimited.Error())
}

func TestTokenBucketThrottler(t *testing.T) {
d := time.Duration(0)
s := func(d0 time.Duration) { d = d0 }

e := func(context.Context, interface{}) (interface{}, error) { return struct{}{}, nil }
e = ratelimit.NewTokenBucketThrottler(jujuratelimit.NewBucketWithRate(1, 1), s)(e)
tb := jujuratelimit.NewBucket(time.Minute, 1)
testSuccessThenFailure(
t,
ratelimit.NewTokenBucketThrottler(tb, nil)(nopEndpoint),
"context deadline exceeded")
}

// First request should go through with no delay.
e(context.Background(), struct{}{})
if want, have := time.Duration(0), d; want != have {
t.Errorf("want %s, have %s", want, have)
}
func TestXRateErroring(t *testing.T) {
limit := rate.NewLimiter(rate.Every(time.Minute), 1)
testSuccessThenFailure(
t,
ratelimit.NewErroringLimiter(limit)(nopEndpoint),
ratelimit.ErrLimited.Error())
}

// Next request should request a ~1s sleep.
e(context.Background(), struct{}{})
if want, have, tol := time.Second, d, time.Millisecond; math.Abs(float64(want-have)) > float64(tol) {
t.Errorf("want %s, have %s", want, have)
}
func TestXRateDelaying(t *testing.T) {
limit := rate.NewLimiter(rate.Every(time.Minute), 1)
testSuccessThenFailure(
t,
ratelimit.NewDelayingLimiter(limit)(nopEndpoint),
"exceed context deadline")
}

func testLimiter(t *testing.T, e endpoint.Endpoint, rate int) {
// First <rate> requests should succeed.
for i := 0; i < rate; i++ {
if _, err := e(context.Background(), struct{}{}); err != nil {
t.Fatalf("rate=%d: request %d/%d failed: %v", rate, i+1, rate, err)
}
func testSuccessThenFailure(t *testing.T, e endpoint.Endpoint, failContains string) {
ctx, cxl := context.WithTimeout(context.Background(), 500*time.Millisecond)
defer cxl()

// First request should succeed.
if _, err := e(ctx, struct{}{}); err != nil {
t.Errorf("unexpected: %v\n", err)
}

// Next request should fail.
if _, err := e(context.Background(), struct{}{}); err != ratelimit.ErrLimited {
t.Errorf("rate=%d: want %v, have %v", rate, ratelimit.ErrLimited, err)
if _, err := e(ctx, struct{}{}); !strings.Contains(err.Error(), failContains) {
t.Errorf("expected `%s`: %v\n", failContains, err)
}
}