Skip to content

Commit

Permalink
Merge pull request #207 from porters-xyz/rate-limit-rule
Browse files Browse the repository at this point in the history
  • Loading branch information
wtfsayo authored Apr 29, 2024
2 parents 43d9b22 + ee2a1b5 commit 29f08a8
Show file tree
Hide file tree
Showing 17 changed files with 117 additions and 119 deletions.
10 changes: 1 addition & 9 deletions gateway/proxy/context.go → gateway/common/context.go
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
package proxy
package common

import (
"context"
"net/http"
)

type Contextable interface {
Expand All @@ -21,10 +20,3 @@ func FromContext(ctx context.Context, contextkey string) (any, bool) {
return nil, false
}
}

func setupContext(req *http.Request) {
// TODO read ctx from request and make any modifications
ctx := req.Context()
lifecyclectx := UpdateContext(ctx, &Lifecycle{})
*req = *req.WithContext(lifecyclectx)
}
69 changes: 25 additions & 44 deletions gateway/db/cache.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
"time"

"github.com/redis/go-redis/v9"
rl "github.com/go-redis/redis_rate/v10"

"porters/common"
)
Expand All @@ -29,12 +30,12 @@ type refreshable interface {
refresh(ctx context.Context) error
}

type incrementable interface {
type Incrementable interface {
Key() string
Field() string
}

type decrementable interface {
type Decrementable interface {
Key() string
Field() string
}
Expand Down Expand Up @@ -166,9 +167,9 @@ func (p *Paymenttx) cache(ctx context.Context) error {
//}

func (t *Tenant) Lookup(ctx context.Context) error {
fromContext, ok := tenantFromContext(ctx)
fromContext, ok := common.FromContext(ctx, TENANT)
if ok {
*t = fromContext
*t = *fromContext.(*Tenant)
} else {
key := t.Key()
result, err := getCache().HGetAll(ctx, key).Result()
Expand All @@ -186,9 +187,9 @@ func (t *Tenant) Lookup(ctx context.Context) error {
}

func (a *App) Lookup(ctx context.Context) error {
fromContext, ok := appFromContext(ctx)
fromContext, ok := common.FromContext(ctx, APP)
if ok {
*a = fromContext
*a = *fromContext.(*App)
} else {
key := a.Key()
log.Println("checking cache for app", key)
Expand All @@ -212,7 +213,7 @@ func (a *App) Lookup(ctx context.Context) error {
return nil
}

func (a *App) Rules(ctx context.Context) ([]Apprule, error) {
func (a *App) Rules(ctx context.Context) (Apprules, error) {
rules := make([]Apprule, 0)
pattern := fmt.Sprintf("%s:%s", APPRULE, a.Id)

Expand All @@ -232,6 +233,7 @@ func (a *App) Rules(ctx context.Context) ([]Apprule, error) {
Id: id,
Active: active,
Value: result["value"],
RuleType: result["ruleType"],
CachedAt: cachedAt,
}
rules = append(rules, ar)
Expand All @@ -241,9 +243,9 @@ func (a *App) Rules(ctx context.Context) ([]Apprule, error) {

// Lookup by name, p should have a valid "Name" set before lookup
func (p *Product) Lookup(ctx context.Context) error {
fromContext, ok := productFromContext(ctx)
fromContext, ok := common.FromContext(ctx, PRODUCT)
if ok {
*p = fromContext
*p = *fromContext.(*Product)
} else {
key := p.Key()
log.Println("finding product from cache:", key)
Expand Down Expand Up @@ -293,6 +295,15 @@ func (a *App) refresh(ctx context.Context) {
a.Tenant.Lookup(ctx)
}
a.cache(ctx)

rules, err := a.fetchRules(ctx)
if err != nil {
log.Println("error accessing rules", err)
return
}
for _, r := range rules {
r.cache(ctx)
}
}

func (p *Product) refresh(ctx context.Context) {
Expand Down Expand Up @@ -351,7 +362,7 @@ func GetIntVal(ctx context.Context, name string) int {
return intval
}

func IncrementField(ctx context.Context, incr incrementable, amount int) int {
func IncrementField(ctx context.Context, incr Incrementable, amount int) int {
incrBy := int64(amount)
newVal, err := getCache().HIncrBy(ctx, incr.Key(), incr.Field(), incrBy).Result()
if err != nil {
Expand All @@ -360,7 +371,7 @@ func IncrementField(ctx context.Context, incr incrementable, amount int) int {
return int(newVal)
}

func DecrementField(ctx context.Context, decr decrementable, amount int) int {
func DecrementField(ctx context.Context, decr Decrementable, amount int) int {
decrBy := -int64(amount)
newVal, err := getCache().HIncrBy(ctx, decr.Key(), decr.Field(), decrBy).Result()
if err != nil {
Expand Down Expand Up @@ -399,37 +410,7 @@ func ScanKeys(ctx context.Context, key string) *redis.ScanIterator {
return iter
}

// TODO write scan for specific types, don't leak redis specifics outside
// package

// use context to prevent duplicate cache hits in same request

func tenantFromContext(ctx context.Context) (Tenant, bool) {
var tenant Tenant
value := ctx.Value(TENANT)
if value != nil {
tenant = value.(Tenant)
return tenant, true
}
return tenant, false
}

func appFromContext(ctx context.Context) (App, bool) {
var app App
value := ctx.Value(APP)
if value != nil {
app = value.(App)
return app, true
}
return app, false
}

func productFromContext(ctx context.Context) (Product, bool) {
var product Product
value := ctx.Value(PRODUCT)
if value != nil {
product = value.(Product)
return product, true
}
return product, false
func Limiter() *rl.Limiter {
rdb := getCache()
return rl.NewLimiter(rdb)
}
2 changes: 1 addition & 1 deletion gateway/db/canonical.go
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ func (a *App) fetch(ctx context.Context) error {
return nil
}

func (a *App) fetchRules(ctx context.Context) ([]Apprule, error) {
func (a *App) fetchRules(ctx context.Context) (Apprules, error) {
rules := make([]Apprule, 0)
db := getCanonicalDB()
// TODO join with rule types
Expand Down
2 changes: 1 addition & 1 deletion gateway/db/models.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ const (
PRODUCT = "PRODUCT"
)

type model interface {
type Model interface {
Key() string
}

Expand Down
2 changes: 1 addition & 1 deletion gateway/gateway.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ func gateway() {
log.Println("starting gateway")
proxy.Start()

done := make(chan os.Signal)
done := make(chan os.Signal, 1)
signal.Notify(done, syscall.SIGINT, syscall.SIGTERM)
<-done
shutdown()
Expand Down
1 change: 1 addition & 0 deletions gateway/go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ require github.com/redis/go-redis/v9 v9.4.0
require (
github.com/cespare/xxhash/v2 v2.2.0 // indirect
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect
github.com/go-redis/redis_rate/v10 v10.0.1 // indirect
github.com/gorilla/mux v1.8.1 // indirect
github.com/lib/pq v1.10.9 // indirect
)
2 changes: 2 additions & 0 deletions gateway/go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@ github.com/cespare/xxhash/v2 v2.2.0 h1:DC2CZ1Ep5Y4k3ZQ899DldepgrayRUGE6BBZ/cd9Cj
github.com/cespare/xxhash/v2 v2.2.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/rVNCu3HqELle0jiPLLBs70cWOduZpkS1E78=
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc=
github.com/go-redis/redis_rate/v10 v10.0.1 h1:calPxi7tVlxojKunJwQ72kwfozdy25RjA0bCj1h0MUo=
github.com/go-redis/redis_rate/v10 v10.0.1/go.mod h1:EMiuO9+cjRkR7UvdvwMO7vbgqJkltQHtwbdIQvaBKIU=
github.com/gorilla/mux v1.8.1 h1:TuBL49tXwgrFYWhqrNgrUNEY92u81SPhu7sTdzQEiWY=
github.com/gorilla/mux v1.8.1/go.mod h1:AKf9I4AEqPTmMytcMc0KkNouC66V3BtZ4qD5fmWSiMQ=
github.com/lib/pq v1.10.9 h1:YXG7RB+JIjhP29X+OtkiDnYaXQwpS4JEWq7dtCCRUEw=
Expand Down
3 changes: 2 additions & 1 deletion gateway/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,8 @@ func main() {
proxy.Register(&plugins.Counter{})
proxy.Register(&plugins.ApiKeyAuth{"X-API"})
proxy.Register(&plugins.BalanceTracker{})
proxy.Register(&plugins.NoopFilter{proxy.LifecycleMask(proxy.AccountLookup|proxy.RateLimit)})
proxy.Register(&plugins.LeakyBucketPlugin{"APP"})
proxy.Register(&plugins.NoopFilter{proxy.LifecycleMask(proxy.AccountLookup|proxy.RateLimit|proxy.BalanceCheck)})

gateway()
}
Expand Down
8 changes: 5 additions & 3 deletions gateway/plugins/apikeyauth.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"log"
"net/http"

"porters/common"
"porters/proxy"
)

Expand All @@ -28,7 +29,7 @@ func (a ApiKeyAuth) Key() string {
return "API_KEY_AUTH"
}

func (a ApiKeyAuth) HandleRequest(req *http.Request) {
func (a ApiKeyAuth) HandleRequest(req *http.Request) error {
apiKey := req.Header.Get(a.ApiKeyName)
newCtx := context.WithValue(req.Context(), proxy.AUTH_VAL, apiKey)

Expand All @@ -41,11 +42,12 @@ func (a ApiKeyAuth) HandleRequest(req *http.Request) {
//return
//}
} else {
return
return proxy.NewHTTPError(http.StatusBadRequest)
}
lifecycle := proxy.SetStageComplete(newCtx, proxy.Auth)
newCtx = proxy.UpdateContext(newCtx, lifecycle)
newCtx = common.UpdateContext(newCtx, lifecycle)
*req = *req.WithContext(newCtx)
return nil
}

// TODO check api key is in valid format to quickly determine errant requests
Expand Down
21 changes: 10 additions & 11 deletions gateway/plugins/balance.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"log"
"net/http"

"porters/common"
"porters/db"
"porters/proxy"
)
Expand Down Expand Up @@ -39,44 +40,42 @@ func (b *BalanceTracker) Load() {
}

// TODO optim: script this to avoid multi-hops
func (b *BalanceTracker) HandleRequest(req *http.Request) {
func (b *BalanceTracker) HandleRequest(req *http.Request) error {
ctx := req.Context()
appId := proxy.PluckAppId(req)
app := &db.App{Id: appId}
err := app.Lookup(ctx)
log.Println("app:", app)
if err != nil {
// TODO can't find app
return proxy.NewHTTPError(http.StatusNotFound)
}
bal := &balancecache{
tracker: b,
tenant: &app.Tenant,
}
err = bal.Lookup(ctx)
if err != nil {
// TODO can we recover from this?
return proxy.NewHTTPError(http.StatusNotFound)
}
ctx = proxy.UpdateContext(ctx, bal)
ctx = common.UpdateContext(ctx, bal)
// TODO Check that balance is greater than or equal to req weight
if bal.cachedBalance > 0 {
log.Println("balance remaining")
lifecycle := proxy.SetStageComplete(ctx, proxy.BalanceCheck)
ctx = proxy.UpdateContext(ctx, lifecycle)
ctx = common.UpdateContext(ctx, lifecycle)
*req = *req.WithContext(ctx)
} else {
log.Println("none remaining", appId)
var cancel context.CancelCauseFunc
ctx, cancel = context.WithCancelCause(ctx)
err := proxy.BalanceExceededError
cancel(err)
return proxy.BalanceExceededError
}
*req = *req.WithContext(ctx)
return nil
}

func (b *BalanceTracker) HandleResponse(resp *http.Response) error {
// TODO read pokt docs for if there is better way to check response
ctx := resp.Request.Context()
if resp.StatusCode < 400 {
entity, ok := proxy.FromContext(ctx, BALANCE)
entity, ok := common.FromContext(ctx, BALANCE)
if ok {
bal := entity.(*balancecache)
newval := db.DecrementCounter(ctx, bal.Key(), 1)
Expand Down
6 changes: 2 additions & 4 deletions gateway/plugins/blocker.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ import (
"fmt"
"log"
"net/http"
"porters/proxy"
)

type Blocker struct {}
Expand All @@ -22,10 +21,9 @@ func (b Blocker) Key() string {
return "BLOCKER"
}

func (b Blocker) HandleRequest(req *http.Request) {
cancel := proxy.RequestCanceler(req)
func (b Blocker) HandleRequest(req *http.Request) error {
log.Println("logging block")
cancel(errors.New(fmt.Sprint("blocked by prehandler", b.Name())))
return errors.New(fmt.Sprint("blocked by prehandler", b.Name()))
}

func (b Blocker) HandleResponse(resp *http.Response) error {
Expand Down
Loading

0 comments on commit 29f08a8

Please sign in to comment.