Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Client-side rate limit interceptors (#520) #545

Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 41 additions & 4 deletions interceptors/ratelimit/examples_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,21 +26,58 @@ func (*alwaysPassLimiter) Limit(_ context.Context) error {
//
// // Rate limit isn't reached.
// return nil
//}
// }
return nil
}

// Simple example of server initialization code.
func Example() {
// Simple example of a unary server initialization code.
func ExampleUnaryServerInterceptor() {
// Create unary/stream rateLimiters, based on token bucket here.
// You can implement your own ratelimiter for the interface.
// You can implement your own rate-limiter for the interface.
limiter := &alwaysPassLimiter{}
_ = grpc.NewServer(
grpc.ChainUnaryInterceptor(
ratelimit.UnaryServerInterceptor(limiter),
),
)
}

// Simple example of a streaming server initialization code.
func ExampleStreamServerInterceptor() {
// Create unary/stream rateLimiters, based on token bucket here.
// You can implement your own rate-limiter for the interface.
limiter := &alwaysPassLimiter{}
_ = grpc.NewServer(
grpc.ChainStreamInterceptor(
ratelimit.StreamServerInterceptor(limiter),
),
)
}

// Simple example of a unary client initialization code.
func ExampleUnaryClientInterceptor() {
// Create stream rateLimiter, based on token bucket here.
// You can implement your own rate-limiter for the interface.
limiter := &alwaysPassLimiter{}
_, _ = grpc.DialContext(
context.Background(),
":8080",
grpc.WithUnaryInterceptor(
ratelimit.UnaryClientInterceptor(limiter),
),
)
}

// Simple example of a streaming client initialization code.
func ExampleStreamClientInterceptor() {
// Create stream rateLimiter, based on token bucket here.
// You can implement your own rate-limiter for the interface.
limiter := &alwaysPassLimiter{}
_, _ = grpc.DialContext(
context.Background(),
":8080",
grpc.WithChainStreamInterceptor(
ratelimit.StreamClientInterceptor(limiter),
),
)
}
27 changes: 27 additions & 0 deletions interceptors/ratelimit/ratelimit.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,3 +37,30 @@ func StreamServerInterceptor(limiter Limiter) grpc.StreamServerInterceptor {
return handler(srv, stream)
}
}

// UnaryClientInterceptor returns a new unary client interceptor that performs rate limiting on the request on the
// client side.
// This can be helpful for clients that want to limit the number of requests they send in a given time, potentially
// saving cost.
func UnaryClientInterceptor(limiter Limiter) grpc.UnaryClientInterceptor {
return func(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn,
invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error {
if err := limiter.Limit(ctx); err != nil {
return status.Errorf(codes.ResourceExhausted, "%s is rejected by grpc_ratelimit middleware, please retry later. %s", method, err)
}
return invoker(ctx, method, req, reply, cc, opts...)
}
}

// StreamClientInterceptor returns a new stream client interceptor that performs rate limiting on the request on the
// client side.
// This can be helpful for clients that want to limit the number of requests they send in a given time, potentially
// saving cost.
func StreamClientInterceptor(limiter Limiter) grpc.StreamClientInterceptor {
return func(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string, streamer grpc.Streamer, opts ...grpc.CallOption) (grpc.ClientStream, error) {
if err := limiter.Limit(ctx); err != nil {
return nil, status.Errorf(codes.ResourceExhausted, "%s is rejected by grpc_ratelimit middleware, please retry later. %s", method, err)
}
return streamer(ctx, desc, cc, method, opts...)
}
}
112 changes: 103 additions & 9 deletions interceptors/ratelimit/ratelimit_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,15 @@ import (
"github.com/pkg/errors"
"github.com/stretchr/testify/assert"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
)

const errMsgFake = "fake error"

var ctxLimitKey = struct{}{}
type ctxKey string

var ctxKeyShouldLimit = ctxKey("shouldLimit")

type mockGRPCServerStream struct {
grpc.ServerStream
Expand All @@ -29,13 +33,18 @@ func (m *mockGRPCServerStream) Context() context.Context {
type mockContextBasedLimiter struct{}

func (*mockContextBasedLimiter) Limit(ctx context.Context) error {
l, _ := ctx.Value(ctxLimitKey).(error)
return l
shouldLimit, _ := ctx.Value(ctxKeyShouldLimit).(bool)

if shouldLimit {
return errors.New("rate limit exceeded")
}

return nil
}

func TestUnaryServerInterceptor_RateLimitPass(t *testing.T) {
limiter := new(mockContextBasedLimiter)
ctx := context.WithValue(context.Background(), ctxLimitKey, false)
ctx := context.WithValue(context.Background(), ctxKeyShouldLimit, false)

interceptor := UnaryServerInterceptor(limiter)
handler := func(ctx context.Context, req any) (any, error) {
Expand All @@ -51,7 +60,7 @@ func TestUnaryServerInterceptor_RateLimitPass(t *testing.T) {

func TestStreamServerInterceptor_RateLimitPass(t *testing.T) {
limiter := new(mockContextBasedLimiter)
ctx := context.WithValue(context.Background(), ctxLimitKey, false)
ctx := context.WithValue(context.Background(), ctxKeyShouldLimit, false)

interceptor := StreamServerInterceptor(limiter)
handler := func(srv any, stream grpc.ServerStream) error {
Expand All @@ -66,31 +75,116 @@ func TestStreamServerInterceptor_RateLimitPass(t *testing.T) {

func TestUnaryServerInterceptor_RateLimitFail(t *testing.T) {
limiter := new(mockContextBasedLimiter)
ctx := context.WithValue(context.Background(), ctxLimitKey, true)
ctx := context.WithValue(context.Background(), ctxKeyShouldLimit, true)

interceptor := UnaryServerInterceptor(limiter)
called := false
handler := func(ctx context.Context, req any) (any, error) {
called = true
return nil, errors.New(errMsgFake)
}
info := &grpc.UnaryServerInfo{
FullMethod: "FakeMethod",
}
resp, err := interceptor(ctx, nil, info, handler)
expErr := status.Errorf(
codes.ResourceExhausted,
"%s is rejected by grpc_ratelimit middleware, please retry later. %s",
info.FullMethod,
"rate limit exceeded",
)
assert.Nil(t, resp)
assert.EqualError(t, err, errMsgFake)
assert.EqualError(t, err, expErr.Error())
assert.False(t, called)
}

func TestStreamServerInterceptor_RateLimitFail(t *testing.T) {
limiter := new(mockContextBasedLimiter)
ctx := context.WithValue(context.Background(), ctxLimitKey, true)
ctx := context.WithValue(context.Background(), ctxKeyShouldLimit, true)

interceptor := StreamServerInterceptor(limiter)
handler := func(srv interface{}, stream grpc.ServerStream) error {
called := false
handler := func(srv any, stream grpc.ServerStream) error {
called = true
return errors.New(errMsgFake)
}
info := &grpc.StreamServerInfo{
FullMethod: "FakeMethod",
}
err := interceptor(nil, &mockGRPCServerStream{ctx: ctx}, info, handler)
expErr := status.Errorf(
codes.ResourceExhausted,
"%s is rejected by grpc_ratelimit middleware, please retry later. %s",
info.FullMethod,
"rate limit exceeded",
)

assert.EqualError(t, err, expErr.Error())
assert.False(t, called)
}

func TestUnaryClientInterceptor_RateLimitPass(t *testing.T) {
limiter := new(mockContextBasedLimiter)
ctx := context.WithValue(context.Background(), ctxKeyShouldLimit, false)

interceptor := UnaryClientInterceptor(limiter)
invoker := func(ctx context.Context, method string, req, reply any, cc *grpc.ClientConn, opts ...grpc.CallOption) error {
return errors.New(errMsgFake)
}
err := interceptor(ctx, "FakeMethod", nil, nil, nil, invoker)
assert.EqualError(t, err, errMsgFake)
}

func TestStreamClientInterceptor_RateLimitPass(t *testing.T) {
limiter := new(mockContextBasedLimiter)
ctx := context.WithValue(context.Background(), ctxKeyShouldLimit, false)

interceptor := StreamClientInterceptor(limiter)
invoker := func(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string, opts ...grpc.CallOption) (grpc.ClientStream, error) {
return nil, errors.New(errMsgFake)
}
_, err := interceptor(ctx, nil, nil, "FakeMethod", invoker)
assert.EqualError(t, err, errMsgFake)
}

func TestUnaryClientInterceptor_RateLimitFail(t *testing.T) {
limiter := new(mockContextBasedLimiter)
ctx := context.WithValue(context.Background(), ctxKeyShouldLimit, true)

interceptor := UnaryClientInterceptor(limiter)
called := false
invoker := func(ctx context.Context, method string, req, reply any, cc *grpc.ClientConn, opts ...grpc.CallOption) error {
called = true
return errors.New(errMsgFake)
}
err := interceptor(ctx, "FakeMethod", nil, nil, nil, invoker)
expErr := status.Errorf(
codes.ResourceExhausted,
"%s is rejected by grpc_ratelimit middleware, please retry later. %s",
"FakeMethod",
"rate limit exceeded",
)
assert.EqualError(t, err, expErr.Error())
assert.False(t, called)
}

func TestStreamClientInterceptor_RateLimitFail(t *testing.T) {
limiter := new(mockContextBasedLimiter)
ctx := context.WithValue(context.Background(), ctxKeyShouldLimit, true)

interceptor := StreamClientInterceptor(limiter)
called := false
invoker := func(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string, opts ...grpc.CallOption) (grpc.ClientStream, error) {
called = true
return nil, errors.New(errMsgFake)
}
_, err := interceptor(ctx, nil, nil, "FakeMethod", invoker)
expErr := status.Errorf(
codes.ResourceExhausted,
"%s is rejected by grpc_ratelimit middleware, please retry later. %s",
"FakeMethod",
"rate limit exceeded",
)
assert.EqualError(t, err, expErr.Error())
assert.False(t, called)
}