Commit fe706b7

Merge branch 'main' into fix/upsert-identity-keys-dashboard

DeepaPrasanna committed Sep 25, 2024
2 parents fd7369c + 0728220
Showing 101 changed files with 1,845 additions and 3,370 deletions.
25 changes: 25 additions & 0 deletions .github/workflows/job_test_agent_local.yaml
@@ -0,0 +1,25 @@
name: Test Agent Local
on:
workflow_call:



jobs:
test_agent_local:
runs-on: ubuntu-latest
timeout-minutes: 60
steps:
- uses: actions/checkout@v4
- name: Install
uses: ./.github/actions/install
with:
go: true


- name: Build
run: task build
working-directory: apps/agent

- name: Test
run: go test -cover -json -timeout=60m -failfast ./pkg/... ./services/... | tparse -all -progress
working-directory: apps/agent
4 changes: 3 additions & 1 deletion .github/workflows/pr.yaml
@@ -20,7 +20,9 @@ jobs:
name: Test API
uses: ./.github/workflows/job_test_api_local.yaml


test_agent_local:
name: Test Agent Local
uses: ./.github/workflows/job_test_agent_local.yaml
# test_agent_integration:
# name: Test Agent Integration
# runs-on: ubuntu-latest
4 changes: 2 additions & 2 deletions Taskfile.yml
@@ -14,15 +14,15 @@ tasks:
- docker compose -f ./deployment/docker-compose.yaml up -d


seed:
migrate:
cmds:
- task: migrate-db
- task: migrate-clickhouse

migrate-clickhouse:
env:
GOOSE_DRIVER: clickhouse
GOOSE_DBSTRING: "tcp://127.0.0.1:9000"
GOOSE_DBSTRING: "tcp://default:password@127.0.0.1:9000"
GOOSE_MIGRATION_DIR: ./apps/agent/pkg/clickhouse/schema
cmds:
- goose up
@@ -32,7 +32,7 @@ func TestIdentitiesRatelimitAccuracy(t *testing.T) {
unkey.WithSecurity(rootKey),
)

for _, nKeys := range []int{1} { //, 3, 10, 1000} {
for _, nKeys := range []int{1, 3, 10, 1000} {
t.Run(fmt.Sprintf("with %d keys", nKeys), func(t *testing.T) {

for _, tc := range []struct {
@@ -133,7 +133,7 @@ func TestIdentitiesRatelimitAccuracy(t *testing.T) {
// ---------------------------------------------------------------------------

exactLimit := int(inferenceLimit.Limit) * int(tc.testDuration/(time.Duration(inferenceLimit.Duration)*time.Millisecond))
upperLimit := int(1.2 * float64(exactLimit))
upperLimit := int(2.5 * float64(exactLimit))
lowerLimit := exactLimit
if total < lowerLimit {
lowerLimit = total
2 changes: 1 addition & 1 deletion apps/agent/integration/keys/ratelimits_test.go
@@ -109,7 +109,7 @@ func TestDefaultRatelimitAccuracy(t *testing.T) {
// ---------------------------------------------------------------------------

exactLimit := int(ratelimit.Limit) * int(tc.testDuration/(time.Duration(*ratelimit.Duration)*time.Millisecond))
upperLimit := int(1.2 * float64(exactLimit))
upperLimit := int(2.5 * float64(exactLimit))
lowerLimit := exactLimit
if total < lowerLimit {
lowerLimit = total
1 change: 0 additions & 1 deletion apps/agent/pkg/circuitbreaker/lib.go
@@ -169,7 +169,6 @@ func (cb *CB[Res]) preflight(ctx context.Context) error {
now := cb.config.clock.Now()

if now.After(cb.resetCountersAt) {
cb.logger.Info().Msg("resetting circuit breaker")
cb.requests = 0
cb.successes = 0
cb.failures = 0
@@ -21,7 +21,13 @@ CREATE TABLE default.raw_api_requests_v1(
response_headers Array(String),
response_body String,
-- internal err.Error() string, empty if no error
error String
error String,

-- milliseconds
service_latency Int64,

user_agent String,
ip_address String

)
ENGINE = MergeTree()
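Note: the agent presumably maps these columns onto a Go row struct before batching inserts. The sketch below is hypothetical — the struct name, field tags, and the `ch` tag convention are assumptions; only the column names, types, and comments come from the schema above.

```go
package schema

// ApiRequestRow is a hypothetical Go mirror of default.raw_api_requests_v1
// after this commit; only the columns shown in the hunk above are included.
type ApiRequestRow struct {
	ResponseHeaders []string `ch:"response_headers"`
	ResponseBody    string   `ch:"response_body"`
	Error           string   `ch:"error"`           // internal err.Error(), empty if no error
	ServiceLatency  int64    `ch:"service_latency"` // milliseconds
	UserAgent       string   `ch:"user_agent"`
	IPAddress       string   `ch:"ip_address"`
}
```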
@@ -12,7 +12,7 @@ CREATE TABLE default.raw_key_verifications_v1(

-- Right now this is a 3 character airport code, but when we move to aws,
-- this will be the region code such as `us-east-1`
region String,
region LowCardinality(String),

-- Examples:
-- - "VALID"
@@ -24,6 +24,8 @@ CREATE TABLE default.raw_key_verifications_v1(
-- Empty string if the key has no identity
identity_id String,



)
ENGINE = MergeTree()
ORDER BY (workspace_id, key_space_id, key_id, time)
28 changes: 6 additions & 22 deletions apps/agent/pkg/clock/real_clock.go
@@ -2,31 +2,15 @@ package clock

import "time"

type TestClock struct {
now time.Time
type RealClock struct {
}

func NewTestClock(now ...time.Time) *TestClock {
if len(now) == 0 {
now = append(now, time.Now())
}
return &TestClock{now: now[0]}
func New() *RealClock {
return &RealClock{}
}

var _ Clock = &TestClock{}
var _ Clock = &RealClock{}

func (c *TestClock) Now() time.Time {
return c.now
}

// Tick advances the clock by the given duration and returns the new time.
func (c *TestClock) Tick(d time.Duration) time.Time {
c.now = c.now.Add(d)
return c.now
}

// Set sets the clock to the given time and returns the new time.
func (c *TestClock) Set(t time.Time) time.Time {
c.now = t
return c.now
func (c *RealClock) Now() time.Time {
return time.Now()
}
28 changes: 22 additions & 6 deletions apps/agent/pkg/clock/test_clock.go
@@ -2,15 +2,31 @@ package clock

import "time"

type RealClock struct {
type TestClock struct {
now time.Time
}

func New() *RealClock {
return &RealClock{}
func NewTestClock(now ...time.Time) *TestClock {
if len(now) == 0 {
now = append(now, time.Now())
}
return &TestClock{now: now[0]}
}

var _ Clock = &RealClock{}
var _ Clock = &TestClock{}

func (c *RealClock) Now() time.Time {
return time.Now()
func (c *TestClock) Now() time.Time {
return c.now
}

// Tick advances the clock by the given duration and returns the new time.
func (c *TestClock) Tick(d time.Duration) time.Time {
c.now = c.now.Add(d)
return c.now
}

// Set sets the clock to the given time and returns the new time.
func (c *TestClock) Set(t time.Time) time.Time {
c.now = t
return c.now
}
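The two clock files above swap their contents so that real_clock.go holds RealClock and test_clock.go holds TestClock. Both assert `var _ Clock = ...`, so the package's Clock interface presumably only requires Now(). A minimal usage sketch, assuming that interface shape and inferring the import path from the other files in this commit:

```go
package clock_test

import (
	"testing"
	"time"

	"github.com/unkeyed/unkey/apps/agent/pkg/clock"
)

// TestClock lets tests control time explicitly instead of sleeping.
func TestTickAdvancesTime(t *testing.T) {
	c := clock.NewTestClock(time.Unix(0, 0))

	start := c.Now()
	c.Tick(time.Minute) // advance the fake clock by one minute

	if got := c.Now().Sub(start); got != time.Minute {
		t.Fatalf("expected the clock to advance by 1m, got %s", got)
	}
}
```

Production code would presumably accept the Clock interface and receive clock.New() (a RealClock), while tests inject a TestClock; the circuit breaker's cb.config.clock.Now() call above reads time through the same abstraction.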
8 changes: 8 additions & 0 deletions apps/agent/services/ratelimit/metrics.go
@@ -23,4 +23,12 @@
Subsystem: "ratelimit",
Name: "ratelimits_total",
}, []string{"passed"})

// forceSync is a counter that increments every time the agent is forced to
// sync with the origin ratelimit service because it doesn't have enough data
forceSync = promauto.NewCounter(prometheus.CounterOpts{
Namespace: "agent",
Subsystem: "ratelimit",
Name: "force_sync",
})
)
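For context, promauto registers the counter with the default registry at package-init time, and the full metric name becomes agent_ratelimit_force_sync. A self-contained illustration of the same pattern, not taken from the agent (the HTTP setup and port are placeholders):

```go
package main

import (
	"net/http"

	"github.com/prometheus/client_golang/prometheus"
	"github.com/prometheus/client_golang/prometheus/promauto"
	"github.com/prometheus/client_golang/prometheus/promhttp"
)

// Exposed as agent_ratelimit_force_sync, mirroring the counter above.
var forceSync = promauto.NewCounter(prometheus.CounterOpts{
	Namespace: "agent",
	Subsystem: "ratelimit",
	Name:      "force_sync",
})

func main() {
	forceSync.Inc() // the agent increments this when it must ask the origin node
	http.Handle("/metrics", promhttp.Handler())
	_ = http.ListenAndServe(":2112", nil)
}
```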
21 changes: 12 additions & 9 deletions apps/agent/services/ratelimit/mitigate.go
@@ -21,7 +21,6 @@ func (s *service) Mitigate(ctx context.Context, req *ratelimitv1.MitigateRequest
bucket, _ := s.getBucket(bucketKey{req.Identifier, req.Limit, duration})
bucket.Lock()
defer bucket.Unlock()

bucket.windows[req.Window.GetSequence()] = req.Window

return &ratelimitv1.MitigateResponse{}, nil
@@ -38,7 +37,7 @@ func (s *service) broadcastMitigation(req mitigateWindowRequest) {
ctx := context.Background()
node, err := s.cluster.FindNode(bucketKey{req.identifier, req.limit, req.duration}.toString())
if err != nil {
s.logger.Err(err).Msg("failed to find node")
s.logger.Warn().Err(err).Msg("failed to find node")
return
}
if node.Id != s.cluster.NodeId() {
@@ -51,16 +50,20 @@ func (s *service) broadcastMitigation(req mitigateWindowRequest) {
return
}
for _, peer := range peers {
_, err := peer.client.Mitigate(ctx, connect.NewRequest(&ratelimitv1.MitigateRequest{
Identifier: req.identifier,
Limit: req.limit,
Duration: req.duration.Milliseconds(),
Window: req.window,
}))
_, err := s.mitigateCircuitBreaker.Do(ctx, func(innerCtx context.Context) (*connect.Response[ratelimitv1.MitigateResponse], error) {
innerCtx, cancel := context.WithTimeout(innerCtx, 10*time.Second)
defer cancel()
return peer.client.Mitigate(innerCtx, connect.NewRequest(&ratelimitv1.MitigateRequest{
Identifier: req.identifier,
Limit: req.limit,
Duration: req.duration.Milliseconds(),
Window: req.window,
}))
})
if err != nil {
s.logger.Err(err).Msg("failed to call mitigate")
} else {
s.logger.Info().Str("peerId", peer.id).Msg("broadcasted mitigation")
s.logger.Debug().Str("peerId", peer.id).Msg("broadcasted mitigation")
}
}
}
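The per-peer call is now routed through s.mitigateCircuitBreaker with a 10-second timeout, so one slow or unreachable peer no longer stalls the broadcast loop. The call shape below is inferred from the two call sites in this commit (mitigateCircuitBreaker here, syncCircuitBreaker in ratelimit.go) and from the CB[Res] receiver in pkg/circuitbreaker/lib.go; it is a sketch, not the package's actual exported API.

```go
package ratelimit

import "context"

// circuitBreaker is an illustrative view of how the breakers are used in
// this commit: Do runs fn unless the breaker is open, and the result feeds
// the request/success/failure counters that decide when to trip or reset.
type circuitBreaker[Res any] interface {
	Do(ctx context.Context, fn func(ctx context.Context) (Res, error)) (Res, error)
}
```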
42 changes: 24 additions & 18 deletions apps/agent/services/ratelimit/ratelimit.go
@@ -38,23 +38,23 @@ func (s *service) Ratelimit(ctx context.Context, req *ratelimitv1.RatelimitReque
}
}

// prevExists, currExists := s.CheckWindows(ctx, ratelimitReq)
// // If neither window existed before, we should do an origin ratelimit check
// // because we likely don't have enough data to make an accurate decision'
// if !prevExists && !currExists {

// originRes, err := s.ratelimitOrigin(ctx, req)
// // The control flow is a bit unusual here because we want to return early on
// // success, rather than on error
// if err == nil && originRes != nil {
// return originRes, nil
// }
// if err != nil {
// // We want to know about the error, but if there is one, we just fall back
// // to local state, so we don't return early
// s.logger.Err(err).Msg("failed to sync with origin, falling back to local state")
// }
// }
prevExists, currExists := s.CheckWindows(ctx, ratelimitReq)
// If neither window existed before, we should do an origin ratelimit check
// because we likely don't have enough data to make an accurate decision
if !prevExists && !currExists {

originRes, err := s.ratelimitOrigin(ctx, req)
// The control flow is a bit unusual here because we want to return early on
// success, rather than on error
if err == nil && originRes != nil {
return originRes, nil
}
if err != nil {
// We want to know about the error, but if there is one, we just fall back
// to local state, so we don't return early
s.logger.Err(err).Msg("failed to sync with origin, falling back to local state")
}
}

taken := s.Take(ctx, ratelimitReq)

@@ -93,6 +93,8 @@ func (s *service) ratelimitOrigin(ctx context.Context, req *ratelimitv1.Ratelimi
ctx, span := tracing.Start(ctx, "ratelimit.RatelimitOrigin")
defer span.End()

forceSync.Inc()

now := time.Now()
if req.Time != nil {
now = time.UnixMilli(req.GetTime())
@@ -116,7 +118,11 @@ func (s *service) ratelimitOrigin(ctx context.Context, req *ratelimitv1.Ratelimi
Time: now.UnixMilli(),
})

res, err := client.PushPull(ctx, connectReq)
res, err := s.syncCircuitBreaker.Do(ctx, func(innerCtx context.Context) (*connect.Response[ratelimitv1.PushPullResponse], error) {
innerCtx, cancel := context.WithTimeout(innerCtx, 10*time.Second)
defer cancel()
return client.PushPull(innerCtx, connectReq)
})
if err != nil {
tracing.RecordError(span, err)
s.logger.Err(err).Msg("failed to call ratelimit")
8 changes: 5 additions & 3 deletions apps/agent/services/ratelimit/ratelimit_mitigation_test.go
@@ -23,7 +23,7 @@ import (
)

func TestExceedingTheLimitShouldNotifyAllNodes(t *testing.T) {
t.Skip()

for _, clusterSize := range []int{1, 3, 5} {
t.Run(fmt.Sprintf("Cluster Size %d", clusterSize), func(t *testing.T) {
logger := logging.New(nil)
@@ -94,23 +94,25 @@ func TestExceedingTheLimitShouldNotifyAllNodes(t *testing.T) {
ctx := context.Background()

// Saturate the window
for i := int64(0); i <= limit; i++ {
for i := int64(0); i < limit; i++ {
rl := util.RandomElement(ratelimiters)
res, err := rl.Ratelimit(ctx, req)
require.NoError(t, err)
t.Logf("saturate res: %+v", res)
require.True(t, res.Success)

}

time.Sleep(time.Second * 5)

// Let's hit every node again
// They should all be mitigated
for i, rl := range ratelimiters {

res, err := rl.Ratelimit(ctx, req)
require.NoError(t, err)
t.Logf("res from %d: %+v", i, res)
// require.False(t, res.Success)
require.False(t, res.Success)
}

})
6 changes: 2 additions & 4 deletions apps/agent/services/ratelimit/ratelimit_replication_test.go
@@ -24,8 +24,7 @@ import (
"github.com/unkeyed/unkey/apps/agent/pkg/util"
)

func TestReplication(t *testing.T) {
t.Skip()
func TestSync(t *testing.T) {
type Node struct {
srv *service
cluster cluster.Cluster
@@ -106,7 +105,7 @@ func TestReplication(t *testing.T) {
}

// Figure out who is the origin
_, err := nodes[1].srv.Ratelimit(ctx, req)
_, err := nodes[0].srv.Ratelimit(ctx, req)
require.NoError(t, err)

time.Sleep(5 * time.Second)
@@ -138,7 +137,6 @@ func TestReplication(t *testing.T) {
require.True(t, ok)
bucket.RLock()
window := bucket.getCurrentWindow(now)
t.Logf("window on origin: %+v", window)
counter := window.Counter
bucket.RUnlock()
