Skip to content

Commit

Permalink
add rate limit tripperware and middleware (#33)
Browse files Browse the repository at this point in the history
* add rate limit tripperware

* use atomic uint64

* delete useless options

* delete useless options

* add rate limiter from golang.org/x/time/rate

* move config

* code style

* update readme

* refacto leaky_bucket rate limiter

* remove useless test file

* fix example and code style

* remove vendor

* remove bench

* fix typo

* fix port

* fix example addr

* improve and add middleware

* improve

* fix comments

* Update README.md

Co-authored-by: instabledesign <instabledesign@gmail.com>

* update version

* update version

Co-authored-by: Anthony Moutte <instabledesign@gmail.com>
  • Loading branch information
qneyrat and instabledesign authored Sep 3, 2020
1 parent 14daf88 commit 071422a
Show file tree
Hide file tree
Showing 12 changed files with 496 additions and 26 deletions.
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ Package httpware is a collection of middleware (net/http.Handler wrapper) and tr
|**Interceptor**|X|X|
|**Skip**|X|X|
|**Enable**|X|X|
|**RateLimiter**|X|X|

## Installation

Expand Down
26 changes: 13 additions & 13 deletions middleware/interceptor.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@ import (
)

// Interceptor middleware allow multiple req.Body read and allow to set callback before and after roundtrip
func Interceptor(options ...Option) httpware.Middleware {
config := NewConfig(options...)
func Interceptor(options ...InterceptorOption) httpware.Middleware {
config := NewInterceptorConfig(options...)
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(writer http.ResponseWriter, req *http.Request) {
writerInterceptor := NewResponseWriterInterceptor(writer)
Expand All @@ -25,40 +25,40 @@ func Interceptor(options ...Option) httpware.Middleware {
}
}

type Config struct {
type InterceptorConfig struct {
CallbackBefore func(*ResponseWriterInterceptor, *http.Request)
CallbackAfter func(*ResponseWriterInterceptor, *http.Request)
}

func (c *Config) apply(options ...Option) *Config {
func (c *InterceptorConfig) apply(options ...InterceptorOption) *InterceptorConfig {
for _, option := range options {
option(c)
}
return c
}

// NewConfig returns a new interceptor middleware configuration with all options applied
func NewConfig(options ...Option) *Config {
config := &Config{
// NewInterceptorConfig returns a new interceptor middleware configuration with all options applied
func NewInterceptorConfig(options ...InterceptorOption) *InterceptorConfig {
config := &InterceptorConfig{
CallbackBefore: func(_ *ResponseWriterInterceptor, _ *http.Request) {},
CallbackAfter: func(_ *ResponseWriterInterceptor, _ *http.Request) {},
}
return config.apply(options...)
}

// Option defines a interceptor middleware configuration option
type Option func(*Config)
// InterceptorOption defines a interceptor middleware configuration option
type InterceptorOption func(*InterceptorConfig)

// WithBefore will configure CallbackBefore interceptor option
func WithBefore(callbackBefore func(*ResponseWriterInterceptor, *http.Request)) Option {
return func(config *Config) {
func WithBefore(callbackBefore func(*ResponseWriterInterceptor, *http.Request)) InterceptorOption {
return func(config *InterceptorConfig) {
config.CallbackBefore = callbackBefore
}
}

// WithAfter will configure CallbackAfter interceptor option
func WithAfter(callbackAfter func(*ResponseWriterInterceptor, *http.Request)) Option {
return func(config *Config) {
func WithAfter(callbackAfter func(*ResponseWriterInterceptor, *http.Request)) InterceptorOption {
return func(config *InterceptorConfig) {
config.CallbackAfter = callbackAfter
}
}
59 changes: 59 additions & 0 deletions middleware/rate_limit.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
package middleware

import (
"net/http"

"github.com/gol4ng/httpware/v3"
"github.com/gol4ng/httpware/v3/rate_limit"
)

func RateLimit(limiter rate_limit.RateLimiter, options ...RateLimitOption) httpware.Middleware {
config := NewRateLimitConfig(options...)

return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(writer http.ResponseWriter, req *http.Request) {
if err := limiter.Allow(req); err != nil {
if !config.ErrorCallback(err, writer, req) {
return
}
}

limiter.Inc(req)
defer limiter.Dec(req)
next.ServeHTTP(writer, req)
})
}
}

type RateLimitOption func(*RateLimitConfig)

type RateLimitErrorCallback func(err error, writer http.ResponseWriter, req *http.Request) (next bool)

type RateLimitConfig struct {
ErrorCallback RateLimitErrorCallback
}

func (c *RateLimitConfig) apply(options ...RateLimitOption) *RateLimitConfig {
for _, option := range options {
option(c)
}
return c
}

func NewRateLimitConfig(options ...RateLimitOption) *RateLimitConfig {
config := &RateLimitConfig{
ErrorCallback: DefaultRateLimitErrorCallback,
}
return config.apply(options...)
}

func DefaultRateLimitErrorCallback(err error, writer http.ResponseWriter, _ *http.Request) bool {
http.Error(writer, err.Error(), http.StatusTooManyRequests)
return false
}

func WithRateLimitErrorCallback(callback RateLimitErrorCallback) RateLimitOption {
return func(config *RateLimitConfig) {
config.ErrorCallback = callback
}
}
80 changes: 80 additions & 0 deletions middleware/rate_limit_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
package middleware_test

import (
"errors"
"fmt"
"io/ioutil"
"net/http"
"net/http/httptest"
"testing"
"time"

"github.com/gol4ng/httpware/v3"
"github.com/gol4ng/httpware/v3/middleware"
"github.com/gol4ng/httpware/v3/mocks"
"github.com/gol4ng/httpware/v3/rate_limit"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
)

func TestRateLimit(t *testing.T) {
rateLimiterMock := &mocks.RateLimiter{}
rateLimiterMock.On("Allow", mock.AnythingOfType("*http.Request")).Return(errors.New("failed"))

req := httptest.NewRequest(http.MethodGet, "http://fake-addr", nil)
responseWriter := httptest.NewRecorder()

executed := false
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
executed = true
})

middleware.RateLimit(rateLimiterMock)(handler).ServeHTTP(responseWriter, req)

assert.False(t, executed)
assert.Equal(t, http.StatusTooManyRequests, responseWriter.Result().StatusCode)

content, err := ioutil.ReadAll(responseWriter.Result().Body)
assert.NoError(t, err)
assert.Equal(t, "failed\n", string(content))

rateLimiterMock.AssertExpectations(t)
}

// =====================================================================================================================
// ========================================= EXAMPLES ==================================================================
// =====================================================================================================================

func ExampleRateLimit() {
limiter := rate_limit.NewTokenBucket(1*time.Second, 1)
defer limiter.Stop()

port := ":9105"
// we recommend to use MiddlewareStack to simplify managing all wanted middlewares
// caution middleware order matters
stack := httpware.MiddlewareStack(
middleware.RateLimit(limiter),
)

srv := http.NewServeMux()
srv.HandleFunc("/", func(writer http.ResponseWriter, request *http.Request) {})
go func() {
if err := http.ListenAndServe(port, stack.DecorateHandler(srv)); err != nil {
panic(err)
}
}()

resp, _ := http.Get("http://localhost" + port)
fmt.Println(resp.StatusCode)

resp, _ = http.Get("http://localhost" + port)
fmt.Println(resp.StatusCode)

time.Sleep(2 * time.Second)
resp, _ = http.Get("http://localhost" + port)
fmt.Println(resp.StatusCode)
// Output:
//200
//429
//200
}
35 changes: 35 additions & 0 deletions mocks/RateLimiter.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

5 changes: 5 additions & 0 deletions rate_limit/errors.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
package rate_limit

const (
RequestLimitReachedErr = "request limit reached"
)
11 changes: 11 additions & 0 deletions rate_limit/rate_limiter.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
package rate_limit

import (
"net/http"
)

type RateLimiter interface {
Allow(req *http.Request) error
Inc(req *http.Request)
Dec(req *http.Request)
}
67 changes: 67 additions & 0 deletions rate_limit/token_bucket.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
package rate_limit

import (
"errors"
"net/http"
"sync"
"time"
)

type TokenBucket struct {
mutex sync.Mutex
ticker *time.Ticker
done chan struct{}
callLimit uint32
count uint32
}

func (t *TokenBucket) Allow(_ *http.Request) error {
t.mutex.Lock()
res := t.count >= t.callLimit
t.mutex.Unlock()
if res {
return errors.New(RequestLimitReachedErr)
}

return nil
}

func (t *TokenBucket) Inc(_ *http.Request) {
t.mutex.Lock()
t.count++
t.mutex.Unlock()
}

func (t *TokenBucket) Dec(_ *http.Request) {}

func (t *TokenBucket) Stop() {
t.done <- struct{}{}
t.ticker.Stop()
}

func (t *TokenBucket) start() {
go func() {
for {
select {
case <-t.done:
return
case <-t.ticker.C:
t.mutex.Lock()
t.count = 0
t.mutex.Unlock()
}
}
}()
}

func NewTokenBucket(timeBucket time.Duration, callLimit int) *TokenBucket {
t := &TokenBucket{
ticker: time.NewTicker(timeBucket),
done: make(chan struct{}),
callLimit: uint32(callLimit),
}

t.start()

return t
}
23 changes: 23 additions & 0 deletions rate_limit/token_bucket_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
package rate_limit_test

import (
"testing"
"time"

"github.com/gol4ng/httpware/v3/rate_limit"
"github.com/stretchr/testify/assert"
)

func TestTokenBucket_Allow(t *testing.T) {
limiter := rate_limit.NewTokenBucket(1 * time.Millisecond, 1)
defer limiter.Stop()

assert.NoError(t, limiter.Allow(nil))
limiter.Inc(nil)

assert.EqualError(t, limiter.Allow(nil), "request limit reached")
limiter.Inc(nil)

time.Sleep(2 * time.Millisecond)
assert.NoError(t, limiter.Allow(nil))
}
Loading

0 comments on commit 071422a

Please sign in to comment.