Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Rate limit rule #207

Merged
merged 2 commits into from
Apr 29, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 1 addition & 9 deletions gateway/proxy/context.go → gateway/common/context.go
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
package proxy
package common

import (
"context"
"net/http"
)

type Contextable interface {
Expand All @@ -21,10 +20,3 @@ func FromContext(ctx context.Context, contextkey string) (any, bool) {
return nil, false
}
}

func setupContext(req *http.Request) {
// TODO read ctx from request and make any modifications
ctx := req.Context()
lifecyclectx := UpdateContext(ctx, &Lifecycle{})
*req = *req.WithContext(lifecyclectx)
}
69 changes: 25 additions & 44 deletions gateway/db/cache.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
"time"

"github.com/redis/go-redis/v9"
rl "github.com/go-redis/redis_rate/v10"

"porters/common"
)
Expand All @@ -29,12 +30,12 @@ type refreshable interface {
refresh(ctx context.Context) error
}

type incrementable interface {
type Incrementable interface {
Key() string
Field() string
}

type decrementable interface {
type Decrementable interface {
Key() string
Field() string
}
Expand Down Expand Up @@ -166,9 +167,9 @@ func (p *Paymenttx) cache(ctx context.Context) error {
//}

func (t *Tenant) Lookup(ctx context.Context) error {
fromContext, ok := tenantFromContext(ctx)
fromContext, ok := common.FromContext(ctx, TENANT)
if ok {
*t = fromContext
*t = *fromContext.(*Tenant)
} else {
key := t.Key()
result, err := getCache().HGetAll(ctx, key).Result()
Expand All @@ -186,9 +187,9 @@ func (t *Tenant) Lookup(ctx context.Context) error {
}

func (a *App) Lookup(ctx context.Context) error {
fromContext, ok := appFromContext(ctx)
fromContext, ok := common.FromContext(ctx, APP)
if ok {
*a = fromContext
*a = *fromContext.(*App)
} else {
key := a.Key()
log.Println("checking cache for app", key)
Expand All @@ -212,7 +213,7 @@ func (a *App) Lookup(ctx context.Context) error {
return nil
}

func (a *App) Rules(ctx context.Context) ([]Apprule, error) {
func (a *App) Rules(ctx context.Context) (Apprules, error) {
rules := make([]Apprule, 0)
pattern := fmt.Sprintf("%s:%s", APPRULE, a.Id)

Expand All @@ -232,6 +233,7 @@ func (a *App) Rules(ctx context.Context) ([]Apprule, error) {
Id: id,
Active: active,
Value: result["value"],
RuleType: result["ruleType"],
CachedAt: cachedAt,
}
rules = append(rules, ar)
Expand All @@ -241,9 +243,9 @@ func (a *App) Rules(ctx context.Context) ([]Apprule, error) {

// Lookup by name, p should have a valid "Name" set before lookup
func (p *Product) Lookup(ctx context.Context) error {
fromContext, ok := productFromContext(ctx)
fromContext, ok := common.FromContext(ctx, PRODUCT)
if ok {
*p = fromContext
*p = *fromContext.(*Product)
} else {
key := p.Key()
log.Println("finding product from cache:", key)
Expand Down Expand Up @@ -293,6 +295,15 @@ func (a *App) refresh(ctx context.Context) {
a.Tenant.Lookup(ctx)
}
a.cache(ctx)

rules, err := a.fetchRules(ctx)
if err != nil {
log.Println("error accessing rules", err)
return
}
for _, r := range rules {
r.cache(ctx)
}
}

func (p *Product) refresh(ctx context.Context) {
Expand Down Expand Up @@ -351,7 +362,7 @@ func GetIntVal(ctx context.Context, name string) int {
return intval
}

func IncrementField(ctx context.Context, incr incrementable, amount int) int {
func IncrementField(ctx context.Context, incr Incrementable, amount int) int {
incrBy := int64(amount)
newVal, err := getCache().HIncrBy(ctx, incr.Key(), incr.Field(), incrBy).Result()
if err != nil {
Expand All @@ -360,7 +371,7 @@ func IncrementField(ctx context.Context, incr incrementable, amount int) int {
return int(newVal)
}

func DecrementField(ctx context.Context, decr decrementable, amount int) int {
func DecrementField(ctx context.Context, decr Decrementable, amount int) int {
decrBy := -int64(amount)
newVal, err := getCache().HIncrBy(ctx, decr.Key(), decr.Field(), decrBy).Result()
if err != nil {
Expand Down Expand Up @@ -399,37 +410,7 @@ func ScanKeys(ctx context.Context, key string) *redis.ScanIterator {
return iter
}

// TODO write scan for specific types, don't leak redis specifics outside
// package

// use context to prevent duplicate cache hits in same request

func tenantFromContext(ctx context.Context) (Tenant, bool) {
var tenant Tenant
value := ctx.Value(TENANT)
if value != nil {
tenant = value.(Tenant)
return tenant, true
}
return tenant, false
}

func appFromContext(ctx context.Context) (App, bool) {
var app App
value := ctx.Value(APP)
if value != nil {
app = value.(App)
return app, true
}
return app, false
}

func productFromContext(ctx context.Context) (Product, bool) {
var product Product
value := ctx.Value(PRODUCT)
if value != nil {
product = value.(Product)
return product, true
}
return product, false
func Limiter() *rl.Limiter {
rdb := getCache()
return rl.NewLimiter(rdb)
}
2 changes: 1 addition & 1 deletion gateway/db/canonical.go
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ func (a *App) fetch(ctx context.Context) error {
return nil
}

func (a *App) fetchRules(ctx context.Context) ([]Apprule, error) {
func (a *App) fetchRules(ctx context.Context) (Apprules, error) {
rules := make([]Apprule, 0)
db := getCanonicalDB()
// TODO join with rule types
Expand Down
2 changes: 1 addition & 1 deletion gateway/db/models.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ const (
PRODUCT = "PRODUCT"
)

type model interface {
type Model interface {
Key() string
}

Expand Down
2 changes: 1 addition & 1 deletion gateway/gateway.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ func gateway() {
log.Println("starting gateway")
proxy.Start()

done := make(chan os.Signal)
done := make(chan os.Signal, 1)
signal.Notify(done, syscall.SIGINT, syscall.SIGTERM)
<-done
shutdown()
Expand Down
1 change: 1 addition & 0 deletions gateway/go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ require github.com/redis/go-redis/v9 v9.4.0
require (
github.com/cespare/xxhash/v2 v2.2.0 // indirect
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect
github.com/go-redis/redis_rate/v10 v10.0.1 // indirect
github.com/gorilla/mux v1.8.1 // indirect
github.com/lib/pq v1.10.9 // indirect
)
2 changes: 2 additions & 0 deletions gateway/go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@ github.com/cespare/xxhash/v2 v2.2.0 h1:DC2CZ1Ep5Y4k3ZQ899DldepgrayRUGE6BBZ/cd9Cj
github.com/cespare/xxhash/v2 v2.2.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/rVNCu3HqELle0jiPLLBs70cWOduZpkS1E78=
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc=
github.com/go-redis/redis_rate/v10 v10.0.1 h1:calPxi7tVlxojKunJwQ72kwfozdy25RjA0bCj1h0MUo=
github.com/go-redis/redis_rate/v10 v10.0.1/go.mod h1:EMiuO9+cjRkR7UvdvwMO7vbgqJkltQHtwbdIQvaBKIU=
github.com/gorilla/mux v1.8.1 h1:TuBL49tXwgrFYWhqrNgrUNEY92u81SPhu7sTdzQEiWY=
github.com/gorilla/mux v1.8.1/go.mod h1:AKf9I4AEqPTmMytcMc0KkNouC66V3BtZ4qD5fmWSiMQ=
github.com/lib/pq v1.10.9 h1:YXG7RB+JIjhP29X+OtkiDnYaXQwpS4JEWq7dtCCRUEw=
Expand Down
3 changes: 2 additions & 1 deletion gateway/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,8 @@ func main() {
proxy.Register(&plugins.Counter{})
proxy.Register(&plugins.ApiKeyAuth{"X-API"})
proxy.Register(&plugins.BalanceTracker{})
proxy.Register(&plugins.NoopFilter{proxy.LifecycleMask(proxy.AccountLookup|proxy.RateLimit)})
proxy.Register(&plugins.LeakyBucketPlugin{"APP"})
proxy.Register(&plugins.NoopFilter{proxy.LifecycleMask(proxy.AccountLookup|proxy.RateLimit|proxy.BalanceCheck)})

gateway()
}
Expand Down
8 changes: 5 additions & 3 deletions gateway/plugins/apikeyauth.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"log"
"net/http"

"porters/common"
"porters/proxy"
)

Expand All @@ -28,7 +29,7 @@ func (a ApiKeyAuth) Key() string {
return "API_KEY_AUTH"
}

func (a ApiKeyAuth) HandleRequest(req *http.Request) {
func (a ApiKeyAuth) HandleRequest(req *http.Request) error {
apiKey := req.Header.Get(a.ApiKeyName)
newCtx := context.WithValue(req.Context(), proxy.AUTH_VAL, apiKey)

Expand All @@ -41,11 +42,12 @@ func (a ApiKeyAuth) HandleRequest(req *http.Request) {
//return
//}
} else {
return
return proxy.NewHTTPError(http.StatusBadRequest)
}
lifecycle := proxy.SetStageComplete(newCtx, proxy.Auth)
newCtx = proxy.UpdateContext(newCtx, lifecycle)
newCtx = common.UpdateContext(newCtx, lifecycle)
*req = *req.WithContext(newCtx)
return nil
}

// TODO check api key is in valid format to quickly determine errant requests
Expand Down
21 changes: 10 additions & 11 deletions gateway/plugins/balance.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"log"
"net/http"

"porters/common"
"porters/db"
"porters/proxy"
)
Expand Down Expand Up @@ -39,44 +40,42 @@ func (b *BalanceTracker) Load() {
}

// TODO optim: script this to avoid multi-hops
func (b *BalanceTracker) HandleRequest(req *http.Request) {
func (b *BalanceTracker) HandleRequest(req *http.Request) error {
ctx := req.Context()
appId := proxy.PluckAppId(req)
app := &db.App{Id: appId}
err := app.Lookup(ctx)
log.Println("app:", app)
if err != nil {
// TODO can't find app
return proxy.NewHTTPError(http.StatusNotFound)
}
bal := &balancecache{
tracker: b,
tenant: &app.Tenant,
}
err = bal.Lookup(ctx)
if err != nil {
// TODO can we recover from this?
return proxy.NewHTTPError(http.StatusNotFound)
}
ctx = proxy.UpdateContext(ctx, bal)
ctx = common.UpdateContext(ctx, bal)
// TODO Check that balance is greater than or equal to req weight
if bal.cachedBalance > 0 {
log.Println("balance remaining")
lifecycle := proxy.SetStageComplete(ctx, proxy.BalanceCheck)
ctx = proxy.UpdateContext(ctx, lifecycle)
ctx = common.UpdateContext(ctx, lifecycle)
*req = *req.WithContext(ctx)
} else {
log.Println("none remaining", appId)
var cancel context.CancelCauseFunc
ctx, cancel = context.WithCancelCause(ctx)
err := proxy.BalanceExceededError
cancel(err)
return proxy.BalanceExceededError
}
*req = *req.WithContext(ctx)
return nil
}

func (b *BalanceTracker) HandleResponse(resp *http.Response) error {
// TODO read pokt docs for if there is better way to check response
ctx := resp.Request.Context()
if resp.StatusCode < 400 {
entity, ok := proxy.FromContext(ctx, BALANCE)
entity, ok := common.FromContext(ctx, BALANCE)
if ok {
bal := entity.(*balancecache)
newval := db.DecrementCounter(ctx, bal.Key(), 1)
Expand Down
6 changes: 2 additions & 4 deletions gateway/plugins/blocker.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ import (
"fmt"
"log"
"net/http"
"porters/proxy"
)

type Blocker struct {}
Expand All @@ -22,10 +21,9 @@ func (b Blocker) Key() string {
return "BLOCKER"
}

func (b Blocker) HandleRequest(req *http.Request) {
cancel := proxy.RequestCanceler(req)
func (b Blocker) HandleRequest(req *http.Request) error {
log.Println("logging block")
cancel(errors.New(fmt.Sprint("blocked by prehandler", b.Name())))
return errors.New(fmt.Sprint("blocked by prehandler", b.Name()))
}

func (b Blocker) HandleResponse(resp *http.Response) error {
Expand Down
Loading