diff --git a/.github/workflows/job_test_agent_local.yaml b/.github/workflows/job_test_agent_local.yaml new file mode 100644 index 000000000..4c5276ff9 --- /dev/null +++ b/.github/workflows/job_test_agent_local.yaml @@ -0,0 +1,25 @@ +name: Test Agent Local +on: + workflow_call: + + + +jobs: + test_agent_local: + runs-on: ubuntu-latest + timeout-minutes: 60 + steps: + - uses: actions/checkout@v4 + - name: Install + uses: ./.github/actions/install + with: + go: true + + + - name: Build + run: task build + working-directory: apps/agent + + - name: Test + run: go test -cover -json -timeout=60m -failfast ./pkg/... ./services/... | tparse -all -progress + working-directory: apps/agent diff --git a/.github/workflows/pr.yaml b/.github/workflows/pr.yaml index cdc796edc..d3302a778 100644 --- a/.github/workflows/pr.yaml +++ b/.github/workflows/pr.yaml @@ -20,7 +20,9 @@ jobs: name: Test API uses: ./.github/workflows/job_test_api_local.yaml - + test_agent_local: + name: Test Agent Local + uses: ./.github/workflows/job_test_agent_local.yaml # test_agent_integration: # name: Test Agent Integration # runs-on: ubuntu-latest diff --git a/Taskfile.yml b/Taskfile.yml index fb3bbc389..4e44a5160 100644 --- a/Taskfile.yml +++ b/Taskfile.yml @@ -14,7 +14,7 @@ tasks: - docker compose -f ./deployment/docker-compose.yaml up -d - seed: + migrate: cmds: - task: migrate-db - task: migrate-clickhouse @@ -22,7 +22,7 @@ tasks: migrate-clickhouse: env: GOOSE_DRIVER: clickhouse - GOOSE_DBSTRING: "tcp://127.0.0.1:9000" + GOOSE_DBSTRING: "tcp://default:password@127.0.0.1:9000" GOOSE_MIGRATION_DIR: ./apps/agent/pkg/clickhouse/schema cmds: - goose up diff --git a/apps/agent/integration/identities/identities_ratelimits_accuracy_test.go b/apps/agent/integration/identities/identities_ratelimits_accuracy_test.go index bdd0b6d9c..de5f8e001 100644 --- a/apps/agent/integration/identities/identities_ratelimits_accuracy_test.go +++ b/apps/agent/integration/identities/identities_ratelimits_accuracy_test.go @@ -32,7 +32,7 @@ func TestIdentitiesRatelimitAccuracy(t *testing.T) { unkey.WithSecurity(rootKey), ) - for _, nKeys := range []int{1} { //, 3, 10, 1000} { + for _, nKeys := range []int{1, 3, 10, 1000} { t.Run(fmt.Sprintf("with %d keys", nKeys), func(t *testing.T) { for _, tc := range []struct { @@ -133,7 +133,7 @@ func TestIdentitiesRatelimitAccuracy(t *testing.T) { // --------------------------------------------------------------------------- exactLimit := int(inferenceLimit.Limit) * int(tc.testDuration/(time.Duration(inferenceLimit.Duration)*time.Millisecond)) - upperLimit := int(1.2 * float64(exactLimit)) + upperLimit := int(2.5 * float64(exactLimit)) lowerLimit := exactLimit if total < lowerLimit { lowerLimit = total diff --git a/apps/agent/integration/keys/ratelimits_test.go b/apps/agent/integration/keys/ratelimits_test.go index 93ebdaeba..efac5eed6 100644 --- a/apps/agent/integration/keys/ratelimits_test.go +++ b/apps/agent/integration/keys/ratelimits_test.go @@ -109,7 +109,7 @@ func TestDefaultRatelimitAccuracy(t *testing.T) { // --------------------------------------------------------------------------- exactLimit := int(ratelimit.Limit) * int(tc.testDuration/(time.Duration(*ratelimit.Duration)*time.Millisecond)) - upperLimit := int(1.2 * float64(exactLimit)) + upperLimit := int(2.5 * float64(exactLimit)) lowerLimit := exactLimit if total < lowerLimit { lowerLimit = total diff --git a/apps/agent/pkg/circuitbreaker/lib.go b/apps/agent/pkg/circuitbreaker/lib.go index 04ff5fc0a..327ebb2b0 100644 --- a/apps/agent/pkg/circuitbreaker/lib.go +++ b/apps/agent/pkg/circuitbreaker/lib.go @@ -169,7 +169,6 @@ func (cb *CB[Res]) preflight(ctx context.Context) error { now := cb.config.clock.Now() if now.After(cb.resetCountersAt) { - cb.logger.Info().Msg("resetting circuit breaker") cb.requests = 0 cb.successes = 0 cb.failures = 0 diff --git a/apps/agent/pkg/clickhouse/schema/001_create_requests_table.sql b/apps/agent/pkg/clickhouse/schema/001_create_requests_table.sql index c0f22f121..e8db0d3b0 100644 --- a/apps/agent/pkg/clickhouse/schema/001_create_requests_table.sql +++ b/apps/agent/pkg/clickhouse/schema/001_create_requests_table.sql @@ -21,7 +21,13 @@ CREATE TABLE default.raw_api_requests_v1( response_headers Array(String), response_body String, -- internal err.Error() string, empty if no error - error String + error String, + + -- milliseconds + service_latency Int64, + + user_agent String, + ip_address String ) ENGINE = MergeTree() diff --git a/apps/agent/pkg/clickhouse/schema/002_create_key_verifications_table.sql b/apps/agent/pkg/clickhouse/schema/002_create_key_verifications_table.sql index add2e9351..58ead77dc 100644 --- a/apps/agent/pkg/clickhouse/schema/002_create_key_verifications_table.sql +++ b/apps/agent/pkg/clickhouse/schema/002_create_key_verifications_table.sql @@ -12,7 +12,7 @@ CREATE TABLE default.raw_key_verifications_v1( -- Right now this is a 3 character airport code, but when we move to aws, -- this will be the region code such as `us-east-1` - region String, + region LowCardinality(String), -- Examples: -- - "VALID" @@ -24,6 +24,8 @@ CREATE TABLE default.raw_key_verifications_v1( -- Empty string if the key has no identity identity_id String, + + ) ENGINE = MergeTree() ORDER BY (workspace_id, key_space_id, key_id, time) diff --git a/apps/agent/pkg/clock/real_clock.go b/apps/agent/pkg/clock/real_clock.go index 50e33a6a1..580be114e 100644 --- a/apps/agent/pkg/clock/real_clock.go +++ b/apps/agent/pkg/clock/real_clock.go @@ -2,31 +2,15 @@ package clock import "time" -type TestClock struct { - now time.Time +type RealClock struct { } -func NewTestClock(now ...time.Time) *TestClock { - if len(now) == 0 { - now = append(now, time.Now()) - } - return &TestClock{now: now[0]} +func New() *RealClock { + return &RealClock{} } -var _ Clock = &TestClock{} +var _ Clock = &RealClock{} -func (c *TestClock) Now() time.Time { - return c.now -} - -// Tick advances the clock by the given duration and returns the new time. -func (c *TestClock) Tick(d time.Duration) time.Time { - c.now = c.now.Add(d) - return c.now -} - -// Set sets the clock to the given time and returns the new time. -func (c *TestClock) Set(t time.Time) time.Time { - c.now = t - return c.now +func (c *RealClock) Now() time.Time { + return time.Now() } diff --git a/apps/agent/pkg/clock/test_clock.go b/apps/agent/pkg/clock/test_clock.go index 580be114e..50e33a6a1 100644 --- a/apps/agent/pkg/clock/test_clock.go +++ b/apps/agent/pkg/clock/test_clock.go @@ -2,15 +2,31 @@ package clock import "time" -type RealClock struct { +type TestClock struct { + now time.Time } -func New() *RealClock { - return &RealClock{} +func NewTestClock(now ...time.Time) *TestClock { + if len(now) == 0 { + now = append(now, time.Now()) + } + return &TestClock{now: now[0]} } -var _ Clock = &RealClock{} +var _ Clock = &TestClock{} -func (c *RealClock) Now() time.Time { - return time.Now() +func (c *TestClock) Now() time.Time { + return c.now +} + +// Tick advances the clock by the given duration and returns the new time. +func (c *TestClock) Tick(d time.Duration) time.Time { + c.now = c.now.Add(d) + return c.now +} + +// Set sets the clock to the given time and returns the new time. +func (c *TestClock) Set(t time.Time) time.Time { + c.now = t + return c.now } diff --git a/apps/agent/services/ratelimit/metrics.go b/apps/agent/services/ratelimit/metrics.go index bee92c052..595dc0685 100644 --- a/apps/agent/services/ratelimit/metrics.go +++ b/apps/agent/services/ratelimit/metrics.go @@ -23,4 +23,12 @@ var ( Subsystem: "ratelimit", Name: "ratelimits_total", }, []string{"passed"}) + + // forceSync is a counter that increments every time the agent is forced to + // sync with the origin ratelimit service because it doesn't have enough data + forceSync = promauto.NewCounter(prometheus.CounterOpts{ + Namespace: "agent", + Subsystem: "ratelimit", + Name: "force_sync", + }) ) diff --git a/apps/agent/services/ratelimit/mitigate.go b/apps/agent/services/ratelimit/mitigate.go index 2d75b280f..4bc7dd93c 100644 --- a/apps/agent/services/ratelimit/mitigate.go +++ b/apps/agent/services/ratelimit/mitigate.go @@ -21,7 +21,6 @@ func (s *service) Mitigate(ctx context.Context, req *ratelimitv1.MitigateRequest bucket, _ := s.getBucket(bucketKey{req.Identifier, req.Limit, duration}) bucket.Lock() defer bucket.Unlock() - bucket.windows[req.Window.GetSequence()] = req.Window return &ratelimitv1.MitigateResponse{}, nil @@ -38,7 +37,7 @@ func (s *service) broadcastMitigation(req mitigateWindowRequest) { ctx := context.Background() node, err := s.cluster.FindNode(bucketKey{req.identifier, req.limit, req.duration}.toString()) if err != nil { - s.logger.Err(err).Msg("failed to find node") + s.logger.Warn().Err(err).Msg("failed to find node") return } if node.Id != s.cluster.NodeId() { @@ -51,16 +50,20 @@ func (s *service) broadcastMitigation(req mitigateWindowRequest) { return } for _, peer := range peers { - _, err := peer.client.Mitigate(ctx, connect.NewRequest(&ratelimitv1.MitigateRequest{ - Identifier: req.identifier, - Limit: req.limit, - Duration: req.duration.Milliseconds(), - Window: req.window, - })) + _, err := s.mitigateCircuitBreaker.Do(ctx, func(innerCtx context.Context) (*connect.Response[ratelimitv1.MitigateResponse], error) { + innerCtx, cancel := context.WithTimeout(innerCtx, 10*time.Second) + defer cancel() + return peer.client.Mitigate(innerCtx, connect.NewRequest(&ratelimitv1.MitigateRequest{ + Identifier: req.identifier, + Limit: req.limit, + Duration: req.duration.Milliseconds(), + Window: req.window, + })) + }) if err != nil { s.logger.Err(err).Msg("failed to call mitigate") } else { - s.logger.Info().Str("peerId", peer.id).Msg("broadcasted mitigation") + s.logger.Debug().Str("peerId", peer.id).Msg("broadcasted mitigation") } } } diff --git a/apps/agent/services/ratelimit/ratelimit.go b/apps/agent/services/ratelimit/ratelimit.go index da8b89c74..1d6511a2b 100644 --- a/apps/agent/services/ratelimit/ratelimit.go +++ b/apps/agent/services/ratelimit/ratelimit.go @@ -38,23 +38,23 @@ func (s *service) Ratelimit(ctx context.Context, req *ratelimitv1.RatelimitReque } } - // prevExists, currExists := s.CheckWindows(ctx, ratelimitReq) - // // If neither window existed before, we should do an origin ratelimit check - // // because we likely don't have enough data to make an accurate decision' - // if !prevExists && !currExists { - - // originRes, err := s.ratelimitOrigin(ctx, req) - // // The control flow is a bit unusual here because we want to return early on - // // success, rather than on error - // if err == nil && originRes != nil { - // return originRes, nil - // } - // if err != nil { - // // We want to know about the error, but if there is one, we just fall back - // // to local state, so we don't return early - // s.logger.Err(err).Msg("failed to sync with origin, falling back to local state") - // } - // } + prevExists, currExists := s.CheckWindows(ctx, ratelimitReq) + // If neither window existed before, we should do an origin ratelimit check + // because we likely don't have enough data to make an accurate decision' + if !prevExists && !currExists { + + originRes, err := s.ratelimitOrigin(ctx, req) + // The control flow is a bit unusual here because we want to return early on + // success, rather than on error + if err == nil && originRes != nil { + return originRes, nil + } + if err != nil { + // We want to know about the error, but if there is one, we just fall back + // to local state, so we don't return early + s.logger.Err(err).Msg("failed to sync with origin, falling back to local state") + } + } taken := s.Take(ctx, ratelimitReq) @@ -93,6 +93,8 @@ func (s *service) ratelimitOrigin(ctx context.Context, req *ratelimitv1.Ratelimi ctx, span := tracing.Start(ctx, "ratelimit.RatelimitOrigin") defer span.End() + forceSync.Inc() + now := time.Now() if req.Time != nil { now = time.UnixMilli(req.GetTime()) @@ -116,7 +118,11 @@ func (s *service) ratelimitOrigin(ctx context.Context, req *ratelimitv1.Ratelimi Time: now.UnixMilli(), }) - res, err := client.PushPull(ctx, connectReq) + res, err := s.syncCircuitBreaker.Do(ctx, func(innerCtx context.Context) (*connect.Response[ratelimitv1.PushPullResponse], error) { + innerCtx, cancel := context.WithTimeout(innerCtx, 10*time.Second) + defer cancel() + return client.PushPull(innerCtx, connectReq) + }) if err != nil { tracing.RecordError(span, err) s.logger.Err(err).Msg("failed to call ratelimit") diff --git a/apps/agent/services/ratelimit/ratelimit_mitigation_test.go b/apps/agent/services/ratelimit/ratelimit_mitigation_test.go index b517591d5..a76899511 100644 --- a/apps/agent/services/ratelimit/ratelimit_mitigation_test.go +++ b/apps/agent/services/ratelimit/ratelimit_mitigation_test.go @@ -23,7 +23,7 @@ import ( ) func TestExceedingTheLimitShouldNotifyAllNodes(t *testing.T) { - t.Skip() + for _, clusterSize := range []int{1, 3, 5} { t.Run(fmt.Sprintf("Cluster Size %d", clusterSize), func(t *testing.T) { logger := logging.New(nil) @@ -94,12 +94,13 @@ func TestExceedingTheLimitShouldNotifyAllNodes(t *testing.T) { ctx := context.Background() // Saturate the window - for i := int64(0); i <= limit; i++ { + for i := int64(0); i < limit; i++ { rl := util.RandomElement(ratelimiters) res, err := rl.Ratelimit(ctx, req) require.NoError(t, err) t.Logf("saturate res: %+v", res) require.True(t, res.Success) + } time.Sleep(time.Second * 5) @@ -107,10 +108,11 @@ func TestExceedingTheLimitShouldNotifyAllNodes(t *testing.T) { // Let's hit everry node again // They should all be mitigated for i, rl := range ratelimiters { + res, err := rl.Ratelimit(ctx, req) require.NoError(t, err) t.Logf("res from %d: %+v", i, res) - // require.False(t, res.Success) + require.False(t, res.Success) } }) diff --git a/apps/agent/services/ratelimit/ratelimit_replication_test.go b/apps/agent/services/ratelimit/ratelimit_replication_test.go index cae53d6dd..8e93fc19e 100644 --- a/apps/agent/services/ratelimit/ratelimit_replication_test.go +++ b/apps/agent/services/ratelimit/ratelimit_replication_test.go @@ -24,8 +24,7 @@ import ( "github.com/unkeyed/unkey/apps/agent/pkg/util" ) -func TestReplication(t *testing.T) { - t.Skip() +func TestSync(t *testing.T) { type Node struct { srv *service cluster cluster.Cluster @@ -106,7 +105,7 @@ func TestReplication(t *testing.T) { } // Figure out who is the origin - _, err := nodes[1].srv.Ratelimit(ctx, req) + _, err := nodes[0].srv.Ratelimit(ctx, req) require.NoError(t, err) time.Sleep(5 * time.Second) @@ -138,7 +137,6 @@ func TestReplication(t *testing.T) { require.True(t, ok) bucket.RLock() window := bucket.getCurrentWindow(now) - t.Logf("window on origin: %+v", window) counter := window.Counter bucket.RUnlock() diff --git a/apps/agent/services/ratelimit/ratelimit_test.go b/apps/agent/services/ratelimit/ratelimit_test.go index 8d8edc01f..9a764b458 100644 --- a/apps/agent/services/ratelimit/ratelimit_test.go +++ b/apps/agent/services/ratelimit/ratelimit_test.go @@ -5,7 +5,6 @@ import ( "fmt" "net/http" "net/url" - "sync" "testing" "time" @@ -28,134 +27,142 @@ import ( func TestAccuracy_fixed_time(t *testing.T) { - for _, clusterSize := range []int{5} { - t.Run(fmt.Sprintf("Cluster Size %d", clusterSize), func(t *testing.T) { - logger := logging.New(nil) - clusters := []cluster.Cluster{} - ratelimiters := []ratelimit.Service{} - serfAddrs := []string{} - - for i := range clusterSize { - c, serfAddr, rpcAddr := createCluster(t, fmt.Sprintf("node-%d", i), serfAddrs) - serfAddrs = append(serfAddrs, serfAddr) - clusters = append(clusters, c) - - rl, err := ratelimit.New(ratelimit.Config{ - Logger: logger, - Metrics: metrics.NewNoop(), - Cluster: c, - }) - require.NoError(t, err) - ratelimiters = append(ratelimiters, rl) - - srv, err := connectSrv.New(connectSrv.Config{ - Logger: logger, - Metrics: metrics.NewNoop(), - Image: "does not matter", - }) - require.NoError(t, err) - err = srv.AddService(connectSrv.NewRatelimitServer(rl, logger, "test-auth-token")) - require.NoError(t, err) - - require.NoError(t, err) - u, err := url.Parse(rpcAddr) - require.NoError(t, err) - go srv.Listen(u.Host) - - require.Eventually(t, func() bool { - client := ratelimitv1connect.NewRatelimitServiceClient(http.DefaultClient, rpcAddr) - res, livenessErr := client.Liveness(context.Background(), connect.NewRequest(&ratelimitv1.LivenessRequest{})) - require.NoError(t, livenessErr) - return res.Msg.Status == "ok" - - }, - time.Minute, 100*time.Millisecond) - - } - require.Len(t, ratelimiters, clusterSize) - require.Len(t, serfAddrs, clusterSize) - - for _, c := range clusters { - require.Eventually(t, func() bool { - return c.Size() == clusterSize - }, time.Minute, 100*time.Millisecond) - } - + for _, clusterSize := range []int{1, 3, 5} { + t.Run(fmt.Sprintf("ClusterSize:%d", clusterSize), func(t *testing.T) { for _, limit := range []int64{ 5, 10, 100, } { - for _, duration := range []time.Duration{ - 1 * time.Second, - 10 * time.Second, - 1 * time.Minute, - 5 * time.Minute, - 1 * time.Hour, - } { - for _, windows := range []int64{1, 2, 5, 10, 50} { - // Attack the ratelimit with 100x as much as it should let pass - requests := limit * windows * 100 - - for _, nIngressNodes := range []int{1, 3, clusterSize} { - if nIngressNodes > clusterSize { - nIngressNodes = clusterSize - } - t.Run(fmt.Sprintf("%d/%d ingress nodes: rate %d/%s %d requests across %d windows", - nIngressNodes, - clusterSize, - limit, - duration, - requests, - windows, - ), func(t *testing.T) { - - identifier := uid.New("test") - ingressNodes := ratelimiters[:nIngressNodes] - - now := time.Now() - end := now.Add(duration * time.Duration(windows)) - passed := int64(0) - - dt := duration * time.Duration(windows) / time.Duration(requests) - - for i := now; i.Before(end); i = i.Add(dt) { - rl := util.RandomElement(ingressNodes) - - res, err := rl.Ratelimit(context.Background(), &ratelimitv1.RatelimitRequest{ - // random time within one of the windows - Time: util.Pointer(i.UnixMilli()), - Identifier: identifier, - Limit: limit, - Duration: duration.Milliseconds(), - Cost: 1, - }) - require.NoError(t, err) - if res.Success { - passed++ - } - } + type Node struct { + srv ratelimit.Service + cluster cluster.Cluster + } - // At least 95% of the requests should pass - // lower := 0.95 - // At most 150% + 75% per additional ingress node should pass - upper := 1.50 + 1.0*float64(len(ingressNodes)-1) + nodes := []Node{} + logger := logging.New(nil) + serfAddrs := []string{} + + for i := 0; i < clusterSize; i++ { + node := Node{} + c, serfAddr, rpcAddr := createCluster(t, fmt.Sprintf("node-%d", i), serfAddrs) + serfAddrs = append(serfAddrs, serfAddr) + node.cluster = c + + srv, err := ratelimit.New(ratelimit.Config{ + Logger: logger, + Metrics: metrics.NewNoop(), + Cluster: c, + }) + require.NoError(t, err) + node.srv = srv + nodes = append(nodes, node) + + cSrv, err := connectSrv.New(connectSrv.Config{ + Logger: logger, + Metrics: metrics.NewNoop(), + Image: "does not matter", + }) + require.NoError(t, err) + err = cSrv.AddService(connectSrv.NewRatelimitServer(srv, logger, "test-auth-token")) + require.NoError(t, err) + + require.NoError(t, err) + u, err := url.Parse(rpcAddr) + require.NoError(t, err) + + go cSrv.Listen(u.Host) + + require.Eventually(t, func() bool { + client := ratelimitv1connect.NewRatelimitServiceClient(http.DefaultClient, rpcAddr) + res, livenessErr := client.Liveness(context.Background(), connect.NewRequest(&ratelimitv1.LivenessRequest{})) + require.NoError(t, livenessErr) + return res.Msg.Status == "ok" + + }, + time.Minute, 100*time.Millisecond) + } + require.Len(t, nodes, clusterSize) + require.Len(t, serfAddrs, clusterSize) - exactLimit := limit * (windows + 1) - // require.GreaterOrEqual(t, passed, int64(float64(exactLimit)*lower)) - require.LessOrEqual(t, passed, int64(float64(exactLimit)*upper)) + for _, n := range nodes { + require.Eventually(t, func() bool { + return n.cluster.Size() == clusterSize + }, time.Minute, 100*time.Millisecond) + } - }) - } + t.Run(fmt.Sprintf("limit:%d", limit), func(t *testing.T) { + + for _, duration := range []time.Duration{ + 10 * time.Second, + 1 * time.Minute, + 5 * time.Minute, + 1 * time.Hour, + } { + t.Run(fmt.Sprintf("duration:%s", duration), func(t *testing.T) { + + for _, windows := range []int64{1, 2, 5, 10, 50} { + t.Run(fmt.Sprintf("windows:%d", windows), func(t *testing.T) { + + // Attack the ratelimit with 100x as much as it should let pass + requests := limit * windows * 100 + + for _, nIngressNodes := range []int{1, 3, clusterSize} { + if nIngressNodes > clusterSize { + nIngressNodes = clusterSize + } + t.Run(fmt.Sprintf("%d/%d ingress nodes", + nIngressNodes, + clusterSize, + ), func(t *testing.T) { + + identifier := uid.New("test") + ingressNodes := nodes[:nIngressNodes] + + now := time.Now() + end := now.Add(duration * time.Duration(windows)) + passed := int64(0) + + dt := duration * time.Duration(windows) / time.Duration(requests) + + for i := now; i.Before(end); i = i.Add(dt) { + rl := util.RandomElement(ingressNodes) + + res, err := rl.srv.Ratelimit(context.Background(), &ratelimitv1.RatelimitRequest{ + // random time within one of the windows + Time: util.Pointer(i.UnixMilli()), + Identifier: identifier, + Limit: limit, + Duration: duration.Milliseconds(), + Cost: 1, + }) + require.NoError(t, err) + if res.Success { + passed++ + } + } + + lower := limit * windows + // At most 150% + 75% per additional ingress node should pass + upper := 1.50 + 1.0*float64(len(ingressNodes)-1) + + require.GreaterOrEqual(t, passed, lower) + require.LessOrEqual(t, passed, int64(float64(limit*(windows+1))*upper)) + }) + } + }) + } + }) } + }) + for _, n := range nodes { + require.NoError(t, n.cluster.Shutdown()) } - } - for _, c := range clusters { - require.NoError(t, c.Shutdown()) } + }) } } @@ -205,35 +212,3 @@ func createCluster( return c, serfAddr, rpcAddr } - -func loadTest[T any](t *testing.T, rps int64, seconds int64, fn func() T) []T { - t.Helper() - - resultsC := make(chan T) - - var wg sync.WaitGroup - - for range seconds { - for range rps { - time.Sleep(time.Second / time.Duration(rps)) - - wg.Add(1) - go func() { - resultsC <- fn() - }() - } - } - - results := []T{} - go func() { - for res := range resultsC { - results = append(results, res) - wg.Done() - - } - }() - wg.Wait() - - return results - -} diff --git a/apps/agent/services/ratelimit/service.go b/apps/agent/services/ratelimit/service.go index d6e1bff85..c40458a31 100644 --- a/apps/agent/services/ratelimit/service.go +++ b/apps/agent/services/ratelimit/service.go @@ -38,7 +38,8 @@ type service struct { // Store a reference leaseId -> window key leaseIdToKeyMap map[string]string - syncCircuitBreaker circuitbreaker.CircuitBreaker[*connect.Response[ratelimitv1.PushPullResponse]] + syncCircuitBreaker circuitbreaker.CircuitBreaker[*connect.Response[ratelimitv1.PushPullResponse]] + mitigateCircuitBreaker circuitbreaker.CircuitBreaker[*connect.Response[ratelimitv1.MitigateResponse]] } type Config struct { @@ -64,6 +65,15 @@ func New(cfg Config) (*service, error) { buckets: make(map[string]*bucket), leaseIdToKeyMapLock: sync.RWMutex{}, leaseIdToKeyMap: make(map[string]string), + + mitigateCircuitBreaker: circuitbreaker.New[*connect.Response[ratelimitv1.MitigateResponse]]( + "ratelimit.broadcastMitigation", + circuitbreaker.WithLogger(cfg.Logger), + circuitbreaker.WithCyclicPeriod(10*time.Second), + circuitbreaker.WithTimeout(time.Minute), + circuitbreaker.WithMaxRequests(100), + circuitbreaker.WithTripThreshold(50), + ), syncCircuitBreaker: circuitbreaker.New[*connect.Response[ratelimitv1.PushPullResponse]]( "ratelimit.syncWithOrigin", circuitbreaker.WithLogger(cfg.Logger), diff --git a/apps/agent/services/ratelimit/sliding_window.go b/apps/agent/services/ratelimit/sliding_window.go index 78bbb4337..f8cd41b1e 100644 --- a/apps/agent/services/ratelimit/sliding_window.go +++ b/apps/agent/services/ratelimit/sliding_window.go @@ -2,7 +2,6 @@ package ratelimit import ( "context" - "math" "time" ratelimitv1 "github.com/unkeyed/unkey/apps/agent/gen/proto/ratelimit/v1" @@ -109,6 +108,15 @@ func (r *service) CheckWindows(ctx context.Context, req ratelimitRequest) (prev return prev, curr } +// :::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::: +// :::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::: +// Experimentally, we are reverting this to fixed-window until we can get rid +// of the cloudflare cachelayer. +// +// Throughout this function there is commented out and annotated code that we +// need to reenable later. Such code is also marked with the comment "FIXED-WINDOW" +// :::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::: +// :::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::: func (r *service) Take(ctx context.Context, req ratelimitRequest) ratelimitResponse { ctx, span := tracing.Start(ctx, "slidingWindow.Take") defer span.End() @@ -127,13 +135,21 @@ func (r *service) Take(ctx context.Context, req ratelimitRequest) ratelimitRespo currentWindow := bucket.getCurrentWindow(req.Time) previousWindow := bucket.getPreviousWindow(req.Time) - currentWindowPercentage := float64(req.Time.UnixMilli()-currentWindow.Start) / float64(req.Duration.Milliseconds()) - previousWindowPercentage := 1.0 - currentWindowPercentage + // FIXED-WINDOW + // uncomment + // currentWindowPercentage := float64(req.Time.UnixMilli()-currentWindow.Start) / float64(req.Duration.Milliseconds()) + // previousWindowPercentage := 1.0 - currentWindowPercentage // Calculate the current count including all leases - fromPreviousWindow := float64(previousWindow.Counter) * previousWindowPercentage - fromCurrentWindow := float64(currentWindow.Counter) - current := int64(math.Ceil(fromCurrentWindow + fromPreviousWindow)) + // FIXED-WINDOW + // uncomment + // fromPreviousWindow := float64(previousWindow.Counter) * previousWindowPercentage + // fromCurrentWindow := float64(currentWindow.Counter) + + // FIXED-WINDOW + // replace this with the following line + // current := int64(math.Ceil(fromCurrentWindow + fromPreviousWindow)) + current := currentWindow.Counter // r.logger.Info().Int64("fromCurrentWindow", fromCurrentWindow).Int64("fromPreviousWindow", fromPreviousWindow).Time("now", req.Time).Time("currentWindow.start", currentWindow.start).Int64("msSinceStart", msSinceStart).Float64("currentWindowPercentage", currentWindowPercentage).Float64("previousWindowPercentage", previousWindowPercentage).Bool("currentWindowExists", currentWindowExists).Bool("previousWindowExists", previousWindowExists).Int64("current", current).Interface("buckets", r.buckets).Send() // currentWithLeases := id.current @@ -180,12 +196,12 @@ func (r *service) Take(ctx context.Context, req ratelimitRequest) ratelimitRespo currentWindow.Counter += req.Cost if currentWindow.Counter >= req.Limit && !currentWindow.MitigateBroadcasted && r.mitigateBuffer != nil { currentWindow.MitigateBroadcasted = true - // r.mitigateBuffer <- mitigateWindowRequest{ - // identifier: req.Identifier, - // limit: req.Limit, - // duration: req.Duration, - // window: currentWindow, - // } + r.mitigateBuffer <- mitigateWindowRequest{ + identifier: req.Identifier, + limit: req.Limit, + duration: req.Duration, + window: currentWindow, + } } current += req.Cost @@ -264,6 +280,7 @@ func (r *service) SetCounter(ctx context.Context, requests ...setCounterRequest) func newWindow(sequence int64, t time.Time, duration time.Duration) *ratelimitv1.Window { return &ratelimitv1.Window{ + Sequence: sequence, MitigateBroadcasted: false, Start: t.Truncate(duration).UnixMilli(), Duration: duration.Milliseconds(), diff --git a/apps/agent/services/ratelimit/sync_with_origin.go b/apps/agent/services/ratelimit/sync_with_origin.go index fe1edd9ac..332743754 100644 --- a/apps/agent/services/ratelimit/sync_with_origin.go +++ b/apps/agent/services/ratelimit/sync_with_origin.go @@ -43,7 +43,7 @@ func (s *service) syncWithOrigin(req syncWithOriginRequest) { }) if err != nil { s.peersMu.Lock() - s.logger.Warn().Err(err).Msg("resetting peer client due to error") + s.logger.Warn().Str("peerId", peer.Id).Err(err).Msg("resetting peer client due to error") delete(s.peers, peer.Id) s.peersMu.Unlock() tracing.RecordError(span, err) diff --git a/apps/api/package.json b/apps/api/package.json index 040eae0b3..2022dc6a3 100644 --- a/apps/api/package.json +++ b/apps/api/package.json @@ -36,6 +36,7 @@ "@unkey/logs": "workspace:^", "@unkey/metrics": "workspace:^", "@unkey/rbac": "workspace:^", + "@unkey/clickhouse-zod": "workspace:^", "@unkey/schema": "workspace:^", "@unkey/worker-logging": "workspace:^", "hono": "^4.5.8", diff --git a/apps/api/src/pkg/analytics.ts b/apps/api/src/pkg/analytics.ts index bc834fa88..66907d5b0 100644 --- a/apps/api/src/pkg/analytics.ts +++ b/apps/api/src/pkg/analytics.ts @@ -1,4 +1,5 @@ import { NoopTinybird, Tinybird } from "@chronark/zod-bird"; +import * as ch from "@unkey/clickhouse-zod"; import { newId } from "@unkey/id"; import { auditLogSchemaV1, unkeyAuditLogEvents } from "@unkey/schema/src/auditlog"; import { ratelimitSchemaV1 } from "@unkey/schema/src/ratelimit-tinybird"; @@ -17,6 +18,7 @@ const dateToUnixMilli = z.string().transform((t) => new Date(t.split(" ").at(0) export class Analytics { public readonly readClient: Tinybird | NoopTinybird; public readonly writeClient: Tinybird | NoopTinybird; + private clickhouse: ch.Clickhouse; constructor(opts: { tinybirdToken?: string; @@ -24,6 +26,9 @@ export class Analytics { url: string; token: string; }; + clickhouse?: { + url: string; + }; }) { this.readClient = opts.tinybirdToken ? new Tinybird({ token: opts.tinybirdToken }) @@ -32,6 +37,8 @@ export class Analytics { this.writeClient = opts.tinybirdProxy ? new Tinybird({ token: opts.tinybirdProxy.token, baseUrl: opts.tinybirdProxy.url }) : this.readClient; + + this.clickhouse = opts.clickhouse ? new ch.Client({ url: opts.clickhouse.url }) : new ch.Noop(); } public get ingestSdkTelemetry() { @@ -93,6 +100,53 @@ export class Analytics { }); } + public get insertKeyVerification() { + return this.clickhouse.insert({ + table: "default.raw_key_verifications_v1", + schema: z.object({ + request_id: z.string(), + time: z.number().int(), + workspace_id: z.string(), + key_space_id: z.string(), + key_id: z.string(), + region: z.string(), + outcome: z.enum([ + "VALID", + "RATE_LIMITED", + "EXPIRED", + "DISABLED", + "FORBIDDEN", + "USAGE_EXCEEDED", + "INSUFFICIENT_PERMISSIONS", + ]), + identity_id: z.string().optional().default(""), + }), + }); + } + + public get insertApiRequest() { + return this.clickhouse.insert({ + table: "default.raw_api_requests_v1", + schema: z.object({ + request_id: z.string(), + time: z.number().int(), + workspace_id: z.string(), + host: z.string(), + method: z.string(), + path: z.string(), + request_headers: z.array(z.string()), + request_body: z.string(), + response_status: z.number().int(), + response_headers: z.array(z.string()), + response_body: z.string(), + error: z.string().optional().default(""), + service_latency: z.number().int(), + user_agent: z.string(), + ip_address: z.string(), + }), + }); + } + public get ingestKeyVerification() { return this.writeClient.buildIngestEndpoint({ datasource: "key_verifications__v2", diff --git a/apps/api/src/pkg/env.ts b/apps/api/src/pkg/env.ts index 2e0a33bbd..6ce33d7a8 100644 --- a/apps/api/src/pkg/env.ts +++ b/apps/api/src/pkg/env.ts @@ -28,6 +28,9 @@ export const zEnv = z.object({ }), AGENT_URL: z.string().url(), AGENT_TOKEN: z.string(), + + CLICKHOUSE_URL: z.string().optional(), + SYNC_RATELIMIT_ON_NO_DATA: z .string() .optional() diff --git a/apps/api/src/pkg/hono/env.ts b/apps/api/src/pkg/hono/env.ts index 991193f7b..34ccb8dd8 100644 --- a/apps/api/src/pkg/hono/env.ts +++ b/apps/api/src/pkg/hono/env.ts @@ -29,6 +29,8 @@ export type HonoEnv = { isolateId: string; isolateCreatedAt: number; requestId: string; + requestStartedAt: number; + workspaceId?: string; metricsContext: { keyId?: string; [key: string]: unknown; diff --git a/apps/api/src/pkg/keys/service.ts b/apps/api/src/pkg/keys/service.ts index 407c34965..93a8c7e55 100644 --- a/apps/api/src/pkg/keys/service.ts +++ b/apps/api/src/pkg/keys/service.ts @@ -26,11 +26,11 @@ export class DisabledWorkspaceError extends BaseError<{ workspaceId: string }> { export class MissingRatelimitError extends BaseError<{ name: string }> { public readonly retry = false; public readonly name = MissingRatelimitError.name; - constructor(name: string) { + constructor(ratelimitName: string, message: string) { super({ - message: `ratelimit "${name}" does not exist`, + message, context: { - name, + name: ratelimitName, }, }); } @@ -151,6 +151,7 @@ export class KeyService { }); return res; } + c.set("workspaceId", res.val.key?.forWorkspaceId ?? res.val.key?.workspaceId); this.metrics.emit({ metric: "metric.key.verification", @@ -503,7 +504,14 @@ export class KeyService { continue; } - return Err(new MissingRatelimitError(r.name)); + let errorMessage = `ratelimit "${r.name}" was requested but does not exist for key "${data.key.id}"`; + if (data.identity) { + errorMessage += ` nor identity { id: ${data.identity.id}, externalId: ${data.identity.externalId}}`; + } else { + errorMessage += " and there is no identity connected"; + } + + return Err(new MissingRatelimitError(r.name, errorMessage)); } const [pass, ratelimit] = await this.ratelimit(c, data.key, ratelimits); diff --git a/apps/api/src/pkg/middleware/init.ts b/apps/api/src/pkg/middleware/init.ts index d78a0dea7..23289f787 100644 --- a/apps/api/src/pkg/middleware/init.ts +++ b/apps/api/src/pkg/middleware/init.ts @@ -1,5 +1,6 @@ import { Analytics } from "@/pkg/analytics"; import { createConnection } from "@/pkg/db"; + import { KeyService } from "@/pkg/keys/service"; import { AgentRatelimiter } from "@/pkg/ratelimit"; import { DurableUsageLimiter, NoopUsageLimiter } from "@/pkg/usagelimit"; @@ -45,6 +46,8 @@ export function init(): MiddlewareHandler { const requestId = newId("request"); c.set("requestId", requestId); + c.set("requestStartedAt", Date.now()); + c.res.headers.set("Unkey-Request-Id", requestId); const logger = new ConsoleLogger({ @@ -104,6 +107,11 @@ export function init(): MiddlewareHandler { const analytics = new Analytics({ tinybirdProxy, tinybirdToken: c.env.TINYBIRD_TOKEN, + clickhouse: c.env.CLICKHOUSE_URL + ? { + url: c.env.CLICKHOUSE_URL, + } + : undefined, }); const rateLimiter = new AgentRatelimiter({ agent: { url: c.env.AGENT_URL, token: c.env.AGENT_TOKEN }, diff --git a/apps/api/src/pkg/middleware/metrics.ts b/apps/api/src/pkg/middleware/metrics.ts index 72d37be3c..f612c1d72 100644 --- a/apps/api/src/pkg/middleware/metrics.ts +++ b/apps/api/src/pkg/middleware/metrics.ts @@ -8,6 +8,8 @@ export function metrics(): MiddlewareHandler { return async (c, next) => { const { metrics, analytics, logger } = c.get("services"); + let requestBody = await c.req.raw.clone().text(); + requestBody = requestBody.replaceAll(/"key":\s*"[a-zA-Z0-9_]+"/g, '"key": ""'); const start = performance.now(); const m = { isolateId: c.get("isolateId"), @@ -79,6 +81,36 @@ export function metrics(): MiddlewareHandler { c.res.headers.append("Unkey-Version", c.env.VERSION); metrics.emit(m); c.executionCtx.waitUntil(metrics.flush()); + + const responseHeaders: Array = []; + c.res.headers.forEach((v, k) => { + responseHeaders.push(`${k}: ${v}`); + }); + + c.executionCtx.waitUntil( + analytics.insertApiRequest({ + request_id: c.get("requestId"), + time: c.get("requestStartedAt"), + workspace_id: c.get("workspaceId") ?? "", + host: new URL(c.req.url).host, + method: c.req.method, + path: c.req.path, + request_headers: Object.entries(c.req.header()).map(([k, v]) => { + if (k.toLowerCase() === "authorization") { + return `${k}: `; + } + return `${k}: ${v}`; + }), + request_body: requestBody, + response_status: c.res.status, + response_headers: responseHeaders, + response_body: await c.res.clone().text(), + error: m.error ?? "", + service_latency: Date.now() - c.get("requestStartedAt"), + ip_address: c.req.header("True-Client-IP") ?? c.req.header("CF-Connecting-IP") ?? "", + user_agent: c.req.header("User-Agent") ?? "", + }), + ); } }; } diff --git a/apps/api/src/pkg/ratelimit/client.ts b/apps/api/src/pkg/ratelimit/client.ts index 39317fc2e..fc231d9d0 100644 --- a/apps/api/src/pkg/ratelimit/client.ts +++ b/apps/api/src/pkg/ratelimit/client.ts @@ -3,6 +3,7 @@ import type { Logger } from "@unkey/worker-logging"; import type { Metrics } from "../metrics"; import type { Context } from "../hono/app"; +import { retry } from "../util/retry"; import { Agent } from "./agent"; import { type RateLimiter, @@ -14,13 +15,13 @@ import { export class AgentRatelimiter implements RateLimiter { private readonly logger: Logger; private readonly metrics: Metrics; - private readonly cache: Map; + private readonly cache: Map; private readonly agent: Agent; constructor(opts: { agent: { url: string; token: string }; logger: Logger; metrics: Metrics; - cache: Map; + cache: Map; }) { this.logger = opts.logger; this.metrics = opts.metrics; @@ -35,7 +36,7 @@ export class AgentRatelimiter implements RateLimiter { return [req.identifier, window, req.shard].join("::"); } - private setCache(id: string, current: number, reset: number, blocked: boolean) { + private setCacheMax(id: string, current: number, reset: number) { const maxEntries = 10_000; this.metrics.emit({ metric: "metric.cache.size", @@ -54,7 +55,11 @@ export class AgentRatelimiter implements RateLimiter { } } } - this.cache.set(id, { reset, current, blocked }); + const cached = this.cache.get(id) ?? { reset: 0, current: 0 }; + if (current > cached.current) { + this.cache.set(id, { reset, current }); + return current; + } } public async limit( @@ -122,8 +127,8 @@ export class AgentRatelimiter implements RateLimiter { * This might not happen too often, but in extreme cases the cache should hit and we can skip * the request to the durable object entirely, which speeds everything up and is cheaper for us */ - const cached = this.cache.get(id) ?? { current: 0, reset: 0, blocked: false }; - if (cached.blocked) { + const cached = this.cache.get(id) ?? { current: 0, reset: 0 }; + if (cached.current >= req.limit) { return Ok({ pass: false, current: cached.current, @@ -133,31 +138,22 @@ export class AgentRatelimiter implements RateLimiter { }); } - const p = (async () => { - const a = await this.callAgent(c, { + const p = retry(3, async () => + this.callAgent(c, { requestId: c.get("requestId"), identifier: req.identifier, cost, duration: req.interval, limit: req.limit, name: req.name, - }); - if (a.err) { + }).catch((err) => { this.logger.error("error calling agent", { - error: a.err.message, - json: JSON.stringify(a.err), - }); - return await this.callAgent(c, { - requestId: c.get("requestId"), - identifier: req.identifier, - cost, - duration: req.interval, - limit: req.limit, - name: req.name, + error: err.message, + json: JSON.stringify(err), }); - } - return a; - })(); + throw err; + }), + ); // A rollout of the sync rate limiting // Isolates younger than 60s must not sync. It would cause a stampede of requests as the cache is entirely empty @@ -169,7 +165,7 @@ export class AgentRatelimiter implements RateLimiter { if (sync) { const res = await p; if (res.val) { - this.setCache(id, res.val.current, res.val.reset, !res.val.pass); + this.setCacheMax(id, res.val.current, res.val.reset); } return res; } @@ -180,7 +176,7 @@ export class AgentRatelimiter implements RateLimiter { this.logger.error(res.err.message); return; } - this.setCache(id, res.val.current, res.val.reset, !res.val.pass); + this.setCacheMax(id, res.val.current, res.val.reset); this.metrics.emit({ workspaceId: req.workspaceId, @@ -203,7 +199,7 @@ export class AgentRatelimiter implements RateLimiter { }); } cached.current += cost; - this.setCache(id, cached.current, reset, false); + this.setCacheMax(id, cached.current, reset); return Ok({ pass: true, diff --git a/apps/api/src/routes/v1_identities_updateIdentity.happy.test.ts b/apps/api/src/routes/v1_identities_updateIdentity.happy.test.ts index 88260819a..2530e8c0f 100644 --- a/apps/api/src/routes/v1_identities_updateIdentity.happy.test.ts +++ b/apps/api/src/routes/v1_identities_updateIdentity.happy.test.ts @@ -120,6 +120,8 @@ test("sets new ratelimits", async (t) => { expect(res.status, `expected 200, received: ${JSON.stringify(res, null, 2)}`).toBe(200); + await new Promise((r) => setTimeout(r, 2000)); + const found = await h.db.primary.query.ratelimits.findMany({ where: (table, { eq }) => eq(table.identityId, identity.id), }); diff --git a/apps/api/src/routes/v1_identities_updateIdentity.ts b/apps/api/src/routes/v1_identities_updateIdentity.ts index 6e2ff74b3..03a44eafa 100644 --- a/apps/api/src/routes/v1_identities_updateIdentity.ts +++ b/apps/api/src/routes/v1_identities_updateIdentity.ts @@ -253,7 +253,6 @@ export const registerV1IdentitiesUpdateIdentity = (app: App) => /** * Delete undesired ratelimits */ - for (const rl of deleteRatelimits) { await tx.delete(schema.ratelimits).where(eq(schema.ratelimits.id, rl.id)); auditLogs.push({ diff --git a/apps/api/src/routes/v1_keys_createKey.happy.test.ts b/apps/api/src/routes/v1_keys_createKey.happy.test.ts index e067320f0..8b02cec8e 100644 --- a/apps/api/src/routes/v1_keys_createKey.happy.test.ts +++ b/apps/api/src/routes/v1_keys_createKey.happy.test.ts @@ -389,3 +389,82 @@ describe("with ownerId", () => { }); }); }); + +describe("with externalId", () => { + describe("when externalId does not exist yet", () => { + test("should create identity", async (t) => { + const h = await IntegrationHarness.init(t); + + const root = await h.createRootKey([`api.${h.resources.userApi.id}.create_key`]); + + const externalId = newId("test"); + const res = await h.post({ + url: "/v1/keys.createKey", + headers: { + "Content-Type": "application/json", + Authorization: `Bearer ${root.key}`, + }, + body: { + apiId: h.resources.userApi.id, + externalId, + }, + }); + + expect(res.status, `expected 200, received: ${JSON.stringify(res, null, 2)}`).toBe(200); + + const identity = await h.db.primary.query.identities.findFirst({ + where: (table, { eq }) => eq(table.externalId, externalId), + with: { + keys: true, + }, + }); + expect(identity).toBeDefined(); + + const key = identity!.keys.at(0); + expect(key).toBeDefined(); + expect(key!.id).toEqual(res.body.keyId); + }); + }); + + describe("when the identity exists already", () => { + test("should link to the identity", async (t) => { + const h = await IntegrationHarness.init(t); + + const externalId = newId("test"); + + const identity = { + id: newId("test"), + externalId, + workspaceId: h.resources.userWorkspace.id, + }; + + await h.db.primary.insert(schema.identities).values(identity); + + const root = await h.createRootKey([`api.${h.resources.userApi.id}.create_key`]); + + const res = await h.post({ + url: "/v1/keys.createKey", + headers: { + "Content-Type": "application/json", + Authorization: `Bearer ${root.key}`, + }, + body: { + apiId: h.resources.userApi.id, + ownerId: externalId, + }, + }); + + expect(res.status, `expected 200, received: ${JSON.stringify(res, null, 2)}`).toBe(200); + + const key = await h.db.primary.query.keys.findFirst({ + where: (table, { eq }) => eq(table.id, res.body.keyId), + with: { + identity: true, + }, + }); + expect(key).toBeDefined(); + expect(key!.identity).toBeDefined(); + expect(key!.identity!.id).toEqual(identity.id); + }); + }); +}); diff --git a/apps/api/src/routes/v1_keys_createKey.ts b/apps/api/src/routes/v1_keys_createKey.ts index 006feb7ca..8780846e9 100644 --- a/apps/api/src/routes/v1_keys_createKey.ts +++ b/apps/api/src/routes/v1_keys_createKey.ts @@ -569,7 +569,7 @@ async function getRoleIds( return roles.map((r) => r.id); } -async function upsertIdentity( +export async function upsertIdentity( db: Database, workspaceId: string, externalId: string, diff --git a/apps/api/src/routes/v1_keys_updateKey.happy.test.ts b/apps/api/src/routes/v1_keys_updateKey.happy.test.ts index 52e664995..30c1d555e 100644 --- a/apps/api/src/routes/v1_keys_updateKey.happy.test.ts +++ b/apps/api/src/routes/v1_keys_updateKey.happy.test.ts @@ -585,6 +585,207 @@ test("delete expires", async (t) => { expect(found?.expires).toBeNull(); }); +describe("externalId", () => { + test("set externalId connects the identity", async (t) => { + const h = await IntegrationHarness.init(t); + + const root = await h.createRootKey([`api.${h.resources.userApi.id}.update_key`]); + + const key = await h.createKey(); + const externalId = newId("test"); + + const res = await h.post({ + url: "/v1/keys.updateKey", + headers: { + "Content-Type": "application/json", + Authorization: `Bearer ${root.key}`, + }, + body: { + keyId: key.keyId, + externalId, + }, + }); + + expect(res.status, `expected 200, received: ${JSON.stringify(res, null, 2)}`).toBe(200); + + const found = await h.db.primary.query.keys.findFirst({ + where: (table, { eq }) => eq(table.id, key.keyId), + with: { + identity: true, + }, + }); + expect(found).toBeDefined(); + expect(found!.identity).toBeDefined(); + expect(found!.identity!.externalId).toBe(externalId); + }); + + test("omitting the field does not disconnect the identity", async (t) => { + const h = await IntegrationHarness.init(t); + + const root = await h.createRootKey([`api.${h.resources.userApi.id}.update_key`]); + + const identityId = newId("test"); + const externalId = newId("test"); + await h.db.primary.insert(schema.identities).values({ + id: identityId, + workspaceId: h.resources.userWorkspace.id, + externalId, + }); + const key = await h.createKey({ identityId }); + const before = await h.db.primary.query.keys.findFirst({ + where: (table, { eq }) => eq(table.id, key.keyId), + with: { + identity: true, + }, + }); + expect(before?.identity).toBeDefined(); + + const res = await h.post({ + url: "/v1/keys.updateKey", + headers: { + "Content-Type": "application/json", + Authorization: `Bearer ${root.key}`, + }, + body: { + keyId: key.keyId, + externalId: undefined, + }, + }); + + expect(res.status, `expected 200, received: ${JSON.stringify(res, null, 2)}`).toBe(200); + + const found = await h.db.primary.query.keys.findFirst({ + where: (table, { eq }) => eq(table.id, key.keyId), + with: { + identity: true, + }, + }); + expect(found).toBeDefined(); + expect(found!.identity).toBeDefined(); + expect(found!.identity!.externalId).toBe(externalId); + }); + + test("set ownerId connects the identity", async (t) => { + const h = await IntegrationHarness.init(t); + + const root = await h.createRootKey([`api.${h.resources.userApi.id}.update_key`]); + + const key = await h.createKey(); + const ownerId = newId("test"); + + const res = await h.post({ + url: "/v1/keys.updateKey", + headers: { + "Content-Type": "application/json", + Authorization: `Bearer ${root.key}`, + }, + body: { + keyId: key.keyId, + ownerId, + }, + }); + + expect(res.status, `expected 200, received: ${JSON.stringify(res, null, 2)}`).toBe(200); + + const found = await h.db.primary.query.keys.findFirst({ + where: (table, { eq }) => eq(table.id, key.keyId), + with: { + identity: true, + }, + }); + expect(found).toBeDefined(); + expect(found!.identity).toBeDefined(); + expect(found!.identity!.externalId).toBe(ownerId); + }); + + test("set externalId=null disconnects the identity", async (t) => { + const h = await IntegrationHarness.init(t); + + const root = await h.createRootKey([`api.${h.resources.userApi.id}.update_key`]); + + const identityId = newId("test"); + await h.db.primary.insert(schema.identities).values({ + id: identityId, + workspaceId: h.resources.userWorkspace.id, + externalId: newId("test"), + }); + const key = await h.createKey({ identityId }); + const before = await h.db.primary.query.keys.findFirst({ + where: (table, { eq }) => eq(table.id, key.keyId), + with: { + identity: true, + }, + }); + expect(before?.identity).toBeDefined(); + + const res = await h.post({ + url: "/v1/keys.updateKey", + headers: { + "Content-Type": "application/json", + Authorization: `Bearer ${root.key}`, + }, + body: { + keyId: key.keyId, + externalId: null, + }, + }); + + expect(res.status, `expected 200, received: ${JSON.stringify(res, null, 2)}`).toBe(200); + + const found = await h.db.primary.query.keys.findFirst({ + where: (table, { eq }) => eq(table.id, key.keyId), + with: { + identity: true, + }, + }); + expect(found).toBeDefined(); + expect(found!.identity).toBeNull(); + }); + + test("set ownerId=null disconnects the identity", async (t) => { + const h = await IntegrationHarness.init(t); + + const root = await h.createRootKey([`api.${h.resources.userApi.id}.update_key`]); + + const identityId = newId("test"); + await h.db.primary.insert(schema.identities).values({ + id: identityId, + workspaceId: h.resources.userWorkspace.id, + externalId: newId("test"), + }); + const key = await h.createKey({ identityId }); + const before = await h.db.primary.query.keys.findFirst({ + where: (table, { eq }) => eq(table.id, key.keyId), + with: { + identity: true, + }, + }); + expect(before?.identity).toBeDefined(); + + const res = await h.post({ + url: "/v1/keys.updateKey", + headers: { + "Content-Type": "application/json", + Authorization: `Bearer ${root.key}`, + }, + body: { + keyId: key.keyId, + ownerId: null, + }, + }); + + expect(res.status, `expected 200, received: ${JSON.stringify(res, null, 2)}`).toBe(200); + + const found = await h.db.primary.query.keys.findFirst({ + where: (table, { eq }) => eq(table.id, key.keyId), + with: { + identity: true, + }, + }); + expect(found).toBeDefined(); + expect(found!.identity).toBeNull(); + }); +}); test("update should not affect undefined fields", async (t) => { const h = await IntegrationHarness.init(t); diff --git a/apps/api/src/routes/v1_keys_updateKey.ts b/apps/api/src/routes/v1_keys_updateKey.ts index 566176da5..df68785fc 100644 --- a/apps/api/src/routes/v1_keys_updateKey.ts +++ b/apps/api/src/routes/v1_keys_updateKey.ts @@ -6,6 +6,7 @@ import { UnkeyApiError, openApiErrorResponses } from "@/pkg/errors"; import { schema } from "@unkey/db"; import { eq } from "@unkey/db"; import { buildUnkeyQuery } from "@unkey/rbac"; +import { upsertIdentity } from "./v1_keys_createKey"; import { setPermissions } from "./v1_keys_setPermissions"; import { setRoles } from "./v1_keys_setRoles"; @@ -30,11 +31,24 @@ const route = createRoute({ description: "The name of the key", example: "Customer X", }), - ownerId: z.string().nullish().openapi({ - description: - "The id of the tenant associated with this key. Use whatever reference you have in your system to identify the tenant. When verifying the key, we will send this field back to you, so you know who is accessing your API.", - example: "user_123", - }), + ownerId: z + .string() + .nullish() + .openapi({ + deprecated: true, + description: `Deprecated, use \`externalId\` + The id of the tenant associated with this key. Use whatever reference you have in your system to identify the tenant. When verifying the key, we will send this field back to you, so you know who is accessing your API.`, + example: "user_123", + }), + externalId: z + .string() + .nullish() + .openapi({ + description: `The id of the tenant associated with this key. Use whatever reference you have in your system to identify the tenant. When verifying the key, we will send this back to you, so you know who is accessing your API. + Under the hood this upserts and connects an \`ìdentity\` for you. + To disconnect the key from an identity, set \`externalId: null\`.`, + example: "user_123", + }), meta: z .record(z.unknown()) .nullish() @@ -324,12 +338,21 @@ export const registerV1KeysUpdate = (app: App) => const authorizedWorkspaceId = auth.authorizedWorkspaceId; const rootKeyId = auth.key.id; + const externalId = typeof req.externalId !== "undefined" ? req.externalId : req.ownerId; + const identityId = + typeof externalId === "undefined" + ? undefined + : externalId === null + ? null + : (await upsertIdentity(db.primary, authorizedWorkspaceId, externalId)).id; + await db.primary .update(schema.keys) .set({ name: req.name, ownerId: req.ownerId, meta: typeof req.meta === "undefined" ? undefined : JSON.stringify(req.meta ?? {}), + identityId, expires: typeof req.expires === "undefined" ? undefined diff --git a/apps/api/src/routes/v1_keys_verifyKey.error.test.ts b/apps/api/src/routes/v1_keys_verifyKey.error.test.ts new file mode 100644 index 000000000..397e86561 --- /dev/null +++ b/apps/api/src/routes/v1_keys_verifyKey.error.test.ts @@ -0,0 +1,94 @@ +import { describe, expect, test } from "vitest"; + +import type { ErrorResponse } from "@/pkg/errors"; +import { schema } from "@unkey/db"; +import { newId } from "@unkey/id"; +import { IntegrationHarness } from "src/pkg/testutil/integration-harness"; + +describe("with identity", () => { + describe("with ratelimits", () => { + describe("missing ratelimit", () => { + test("returns 400 and a useful error message", async (t) => { + const h = await IntegrationHarness.init(t); + + const identity = { + id: newId("test"), + workspaceId: h.resources.userWorkspace.id, + externalId: newId("test"), + }; + await h.db.primary.insert(schema.identities).values(identity); + await h.db.primary.insert(schema.ratelimits).values({ + id: newId("test"), + workspaceId: h.resources.userWorkspace.id, + name: "existing-ratelimit", + identityId: identity.id, + limit: 100, + duration: 60_000, + }); + + const key = await h.createKey({ identityId: identity.id }); + + const res = await h.post({ + url: "/v1/keys.verifyKey", + headers: { + "Content-Type": "application/json", + }, + body: { + key: key.key, + ratelimits: [ + { + name: "does-not-exist", + }, + ], + }, + }); + + expect(res.status).toEqual(400); + expect(res.body.error.message).toMatchInlineSnapshot( + `"ratelimit "does-not-exist" was requested but does not exist for key "${key.keyId}" nor identity { id: ${identity.id}, externalId: ${identity.externalId}}"`, + ); + }); + }); + }); +}); + +describe("without identity", () => { + describe("with ratelimits", () => { + describe("missing ratelimit", () => { + test("returns 400 and a useful error message", async (t) => { + const h = await IntegrationHarness.init(t); + + const key = await h.createKey(); + + await h.db.primary.insert(schema.ratelimits).values({ + id: newId("test"), + workspaceId: h.resources.userWorkspace.id, + name: "existing-ratelimit", + keyId: key.keyId, + limit: 100, + duration: 60_000, + }); + + const res = await h.post({ + url: "/v1/keys.verifyKey", + headers: { + "Content-Type": "application/json", + }, + body: { + key: key.key, + ratelimits: [ + { + name: "does-not-exist", + }, + ], + }, + }); + + expect(res.status).toEqual(400); + expect(res.body.error.message).toMatchInlineSnapshot( + `"ratelimit "does-not-exist" was requested but does not exist for key "${key.keyId}" and there is no identity connected"`, + ); + }); + }); + }); +}); diff --git a/apps/api/src/routes/v1_keys_verifyKey.ratelimit_accuracy.test.ts b/apps/api/src/routes/v1_keys_verifyKey.ratelimit_accuracy.test.ts index bed2efaf9..8cdf91c43 100644 --- a/apps/api/src/routes/v1_keys_verifyKey.ratelimit_accuracy.test.ts +++ b/apps/api/src/routes/v1_keys_verifyKey.ratelimit_accuracy.test.ts @@ -31,7 +31,7 @@ const testCases: { }, { limit: 20, - duration: 5000, + duration: 30000, rps: 50, seconds: 60, }, diff --git a/apps/api/src/routes/v1_keys_verifyKey.ts b/apps/api/src/routes/v1_keys_verifyKey.ts index 5cd232b3d..da2d90456 100644 --- a/apps/api/src/routes/v1_keys_verifyKey.ts +++ b/apps/api/src/routes/v1_keys_verifyKey.ts @@ -349,6 +349,22 @@ export const registerV1KeysVerifyKey = (app: App) => : undefined, }; c.executionCtx.waitUntil( + // new clickhouse + analytics.insertKeyVerification({ + request_id: c.get("requestId"), + time: Date.now(), + workspace_id: val.key.workspaceId, + key_space_id: val.key.keyAuthId, + key_id: val.key.id, + // @ts-expect-error + region: c.req.raw.cf.colo ?? "", + outcome: val.code ?? "VALID", + identity_id: val.identity?.id, + }), + ); + + c.executionCtx.waitUntil( + // old tinybird analytics.ingestKeyVerification({ workspaceId: val.key.workspaceId, apiId: val.api.id, diff --git a/apps/api/src/routes/v1_ratelimit_limit.accuracy.test.ts b/apps/api/src/routes/v1_ratelimit_limit.accuracy.test.ts index 762450975..b4a5a9a73 100644 --- a/apps/api/src/routes/v1_ratelimit_limit.accuracy.test.ts +++ b/apps/api/src/routes/v1_ratelimit_limit.accuracy.test.ts @@ -96,7 +96,7 @@ for (const { limit, duration, rps, seconds } of testCases) { }, 0); const exactLimit = Math.min(results.length, (limit / (duration / 1000)) * seconds); - const upperLimit = Math.round(exactLimit * 1.5); + const upperLimit = Math.round(exactLimit * 2.5); const lowerLimit = Math.round(exactLimit * 0.95); console.info({ name, passed, exactLimit, upperLimit, lowerLimit }); t.expect(passed).toBeGreaterThanOrEqual(lowerLimit); diff --git a/apps/bounce/.editorconfig b/apps/bounce/.editorconfig deleted file mode 100644 index a727df347..000000000 --- a/apps/bounce/.editorconfig +++ /dev/null @@ -1,12 +0,0 @@ -# http://editorconfig.org -root = true - -[*] -indent_style = tab -end_of_line = lf -charset = utf-8 -trim_trailing_whitespace = true -insert_final_newline = true - -[*.yml] -indent_style = space diff --git a/apps/bounce/.gitignore b/apps/bounce/.gitignore deleted file mode 100644 index 3b0fe33c4..000000000 --- a/apps/bounce/.gitignore +++ /dev/null @@ -1,172 +0,0 @@ -# Logs - -logs -_.log -npm-debug.log_ -yarn-debug.log* -yarn-error.log* -lerna-debug.log* -.pnpm-debug.log* - -# Diagnostic reports (https://nodejs.org/api/report.html) - -report.[0-9]_.[0-9]_.[0-9]_.[0-9]_.json - -# Runtime data - -pids -_.pid -_.seed -\*.pid.lock - -# Directory for instrumented libs generated by jscoverage/JSCover - -lib-cov - -# Coverage directory used by tools like istanbul - -coverage -\*.lcov - -# nyc test coverage - -.nyc_output - -# Grunt intermediate storage (https://gruntjs.com/creating-plugins#storing-task-files) - -.grunt - -# Bower dependency directory (https://bower.io/) - -bower_components - -# node-waf configuration - -.lock-wscript - -# Compiled binary addons (https://nodejs.org/api/addons.html) - -build/Release - -# Dependency directories - -node_modules/ -jspm_packages/ - -# Snowpack dependency directory (https://snowpack.dev/) - -web_modules/ - -# TypeScript cache - -\*.tsbuildinfo - -# Optional npm cache directory - -.npm - -# Optional eslint cache - -.eslintcache - -# Optional stylelint cache - -.stylelintcache - -# Microbundle cache - -.rpt2_cache/ -.rts2_cache_cjs/ -.rts2_cache_es/ -.rts2_cache_umd/ - -# Optional REPL history - -.node_repl_history - -# Output of 'npm pack' - -\*.tgz - -# Yarn Integrity file - -.yarn-integrity - -# dotenv environment variable files - -.env -.env.development.local -.env.test.local -.env.production.local -.env.local - -# parcel-bundler cache (https://parceljs.org/) - -.cache -.parcel-cache - -# Next.js build output - -.next -out - -# Nuxt.js build / generate output - -.nuxt -dist - -# Gatsby files - -.cache/ - -# Comment in the public line in if your project uses Gatsby and not Next.js - -# https://nextjs.org/blog/next-9-1#public-directory-support - -# public - -# vuepress build output - -.vuepress/dist - -# vuepress v2.x temp and cache directory - -.temp -.cache - -# Docusaurus cache and generated files - -.docusaurus - -# Serverless directories - -.serverless/ - -# FuseBox cache - -.fusebox/ - -# DynamoDB Local files - -.dynamodb/ - -# TernJS port file - -.tern-port - -# Stores VSCode versions used for testing VSCode extensions - -.vscode-test - -# yarn v2 - -.yarn/cache -.yarn/unplugged -.yarn/build-state.yml -.yarn/install-state.gz -.pnp.\* - -# wrangler project - -.dev.vars -.wrangler/ diff --git a/apps/bounce/.prettierrc b/apps/bounce/.prettierrc deleted file mode 100644 index 5c7b5d3c7..000000000 --- a/apps/bounce/.prettierrc +++ /dev/null @@ -1,6 +0,0 @@ -{ - "printWidth": 140, - "singleQuote": true, - "semi": true, - "useTabs": true -} diff --git a/apps/bounce/package.json b/apps/bounce/package.json deleted file mode 100644 index 42fc183f4..000000000 --- a/apps/bounce/package.json +++ /dev/null @@ -1,16 +0,0 @@ -{ - "name": "bounce", - "version": "0.0.0", - "private": true, - "scripts": { - "deploy": "wrangler deploy", - "dev": "wrangler dev", - "start": "wrangler dev", - "test": "vitest" - }, - "devDependencies": { - "@cloudflare/vitest-pool-workers": "^0.4.5", - "wrangler": "^3.60.3", - "vitest": "1.5.0" - } -} diff --git a/apps/bounce/src/index.js b/apps/bounce/src/index.js deleted file mode 100644 index 2ffd913b1..000000000 --- a/apps/bounce/src/index.js +++ /dev/null @@ -1,56 +0,0 @@ -/** - * Welcome to Cloudflare Workers! This is your first worker. - * - * - Run `npm run dev` in your terminal to start a development server - * - Open a browser tab at http://localhost:8787/ to see your worker in action - * - Run `npm run deploy` to publish your worker - * - * Learn more at https://developers.cloudflare.com/workers/ - */ - -const ids = new Set(); - -function getIdentifier() { - if (ids.size === 0 || Math.random() > 0.9) { - ids.add(Math.random().toString().slice(0, 4)); - } - - return Array.from(ids.values())[Math.floor(Math.random() * ids.size)]; -} - -export default { - async fetch(request) { - const pathname = new URL(request.url).pathname; - - const identifier = getIdentifier(); - - const url = `https://api.unkey.cloud${pathname}`; - // const res = await fetch(url, { - // method: request.method, - // headers: request.headers, - // body: JSON.stringify({ - // identifier, - // limit: 10000, - // duration: 60000, - // }), - // }); - - const res = await fetch(url, { - method: request.method, - headers: request.headers, - body: JSON.stringify({ - keyring: identifier, - data: Math.random().toString(), - }), - }); - - const body = await res.text(); - - console.info("response", res.status, body); - - return new Response(body, { - headers: res.headers, - status: res.status, - }); - }, -}; diff --git a/apps/bounce/test/index.spec.js b/apps/bounce/test/index.spec.js deleted file mode 100644 index e56d4ae96..000000000 --- a/apps/bounce/test/index.spec.js +++ /dev/null @@ -1,20 +0,0 @@ -import { SELF, createExecutionContext, env, waitOnExecutionContext } from "cloudflare:test"; -import { describe, expect, it } from "vitest"; -import worker from "../src"; - -describe("Hello World worker", () => { - it("responds with Hello World! (unit style)", async () => { - const request = new Request("http://example.com"); - // Create an empty context to pass to `worker.fetch()`. - const ctx = createExecutionContext(); - const response = await worker.fetch(request, env, ctx); - // Wait for all `Promise`s passed to `ctx.waitUntil()` to settle before running test assertions - await waitOnExecutionContext(ctx); - expect(await response.text()).toMatchInlineSnapshot(`"Hello World!"`); - }); - - it("responds with Hello World! (integration style)", async () => { - const response = await SELF.fetch(request, env, ctx); - expect(await response.text()).toMatchInlineSnapshot(`"Hello World!"`); - }); -}); diff --git a/apps/bounce/vitest.config.js b/apps/bounce/vitest.config.js deleted file mode 100644 index 973627c26..000000000 --- a/apps/bounce/vitest.config.js +++ /dev/null @@ -1,11 +0,0 @@ -import { defineWorkersConfig } from "@cloudflare/vitest-pool-workers/config"; - -export default defineWorkersConfig({ - test: { - poolOptions: { - workers: { - wrangler: { configPath: "./wrangler.toml" }, - }, - }, - }, -}); diff --git a/apps/bounce/wrangler.toml b/apps/bounce/wrangler.toml deleted file mode 100644 index cfba83e0c..000000000 --- a/apps/bounce/wrangler.toml +++ /dev/null @@ -1,108 +0,0 @@ -#:schema node_modules/wrangler/config-schema.json -name = "bounce" -main = "src/index.js" -compatibility_date = "2024-06-01" -compatibility_flags = ["nodejs_compat"] - -# Automatically place your workloads in an optimal location to minimize latency. -# If you are running back-end logic in a Worker, running it closer to your back-end infrastructure -# rather than the end user may result in better performance. -# Docs: https://developers.cloudflare.com/workers/configuration/smart-placement/#smart-placement -# [placement] -# mode = "smart" - -# Variable bindings. These are arbitrary, plaintext strings (similar to environment variables) -# Docs: -# - https://developers.cloudflare.com/workers/wrangler/configuration/#environment-variables -# Note: Use secrets to store sensitive data. -# - https://developers.cloudflare.com/workers/configuration/secrets/ -# [vars] -# MY_VARIABLE = "production_value" - -# Bind the Workers AI model catalog. Run machine learning models, powered by serverless GPUs, on Cloudflare’s global network -# Docs: https://developers.cloudflare.com/workers/wrangler/configuration/#workers-ai -# [ai] -# binding = "AI" - -# Bind an Analytics Engine dataset. Use Analytics Engine to write analytics within your Pages Function. -# Docs: https://developers.cloudflare.com/workers/wrangler/configuration/#analytics-engine-datasets -# [[analytics_engine_datasets]] -# binding = "MY_DATASET" - -# Bind a headless browser instance running on Cloudflare's global network. -# Docs: https://developers.cloudflare.com/workers/wrangler/configuration/#browser-rendering -# [browser] -# binding = "MY_BROWSER" - -# Bind a D1 database. D1 is Cloudflare’s native serverless SQL database. -# Docs: https://developers.cloudflare.com/workers/wrangler/configuration/#d1-databases -# [[d1_databases]] -# binding = "MY_DB" -# database_name = "my-database" -# database_id = "xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx" - -# Bind a dispatch namespace. Use Workers for Platforms to deploy serverless functions programmatically on behalf of your customers. -# Docs: https://developers.cloudflare.com/workers/wrangler/configuration/#dispatch-namespace-bindings-workers-for-platforms -# [[dispatch_namespaces]] -# binding = "MY_DISPATCHER" -# namespace = "my-namespace" - -# Bind a Durable Object. Durable objects are a scale-to-zero compute primitive based on the actor model. -# Durable Objects can live for as long as needed. Use these when you need a long-running "server", such as in realtime apps. -# Docs: https://developers.cloudflare.com/workers/wrangler/configuration/#durable-objects -# [[durable_objects.bindings]] -# name = "MY_DURABLE_OBJECT" -# class_name = "MyDurableObject" - -# Durable Object migrations. -# Docs: https://developers.cloudflare.com/workers/wrangler/configuration/#migrations -# [[migrations]] -# tag = "v1" -# new_classes = ["MyDurableObject"] - -# Bind a Hyperdrive configuration. Use to accelerate access to your existing databases from Cloudflare Workers. -# Docs: https://developers.cloudflare.com/workers/wrangler/configuration/#hyperdrive -# [[hyperdrive]] -# binding = "MY_HYPERDRIVE" -# id = "xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx" - -# Bind a KV Namespace. Use KV as persistent storage for small key-value pairs. -# Docs: https://developers.cloudflare.com/workers/wrangler/configuration/#kv-namespaces -# [[kv_namespaces]] -# binding = "MY_KV_NAMESPACE" -# id = "xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx" - -# Bind an mTLS certificate. Use to present a client certificate when communicating with another service. -# Docs: https://developers.cloudflare.com/workers/wrangler/configuration/#mtls-certificates -# [[mtls_certificates]] -# binding = "MY_CERTIFICATE" -# certificate_id = "xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx" - -# Bind a Queue producer. Use this binding to schedule an arbitrary task that may be processed later by a Queue consumer. -# Docs: https://developers.cloudflare.com/workers/wrangler/configuration/#queues -# [[queues.producers]] -# binding = "MY_QUEUE" -# queue = "my-queue" - -# Bind a Queue consumer. Queue Consumers can retrieve tasks scheduled by Producers to act on them. -# Docs: https://developers.cloudflare.com/workers/wrangler/configuration/#queues -# [[queues.consumers]] -# queue = "my-queue" - -# Bind an R2 Bucket. Use R2 to store arbitrarily large blobs of data, such as files. -# Docs: https://developers.cloudflare.com/workers/wrangler/configuration/#r2-buckets -# [[r2_buckets]] -# binding = "MY_BUCKET" -# bucket_name = "my-bucket" - -# Bind another Worker service. Use this binding to call another Worker without network overhead. -# Docs: https://developers.cloudflare.com/workers/wrangler/configuration/#service-bindings -# [[services]] -# binding = "MY_SERVICE" -# service = "my-service" - -# Bind a Vectorize index. Use to store and query vector embeddings for semantic search, classification and other vector search use-cases. -# Docs: https://developers.cloudflare.com/workers/wrangler/configuration/#vectorize-indexes -# [[vectorize]] -# binding = "MY_INDEX" -# index_name = "my-index" diff --git a/apps/dashboard/app/(app)/logs/page.tsx b/apps/dashboard/app/(app)/logs/page.tsx new file mode 100644 index 000000000..a8a40fbe3 --- /dev/null +++ b/apps/dashboard/app/(app)/logs/page.tsx @@ -0,0 +1,28 @@ +import { PageHeader } from "@/components/dashboard/page-header"; +import { getTenantId } from "@/lib/auth"; +import { getLogs } from "@/lib/clickhouse"; +import { db } from "@/lib/db"; + +export const revalidate = 0; + +export default async function Page() { + const tenantId = getTenantId(); + + const workspace = await db.query.workspaces.findFirst({ + where: (table, { and, eq, isNull }) => + and(eq(table.tenantId, tenantId), isNull(table.deletedAt)), + }); + if (!workspace) { + return
Workspace with tenantId: {tenantId} not found
; + } + + const logs = await getLogs({ workspaceId: workspace.id, limit: 10 }); + + return ( +
+ + +
{JSON.stringify(logs, null, 2)}
+
+ ); +} diff --git a/apps/dashboard/app/(app)/settings/billing/plans/button.tsx b/apps/dashboard/app/(app)/settings/billing/plans/button.tsx index fa15a8a15..8b10427ec 100644 --- a/apps/dashboard/app/(app)/settings/billing/plans/button.tsx +++ b/apps/dashboard/app/(app)/settings/billing/plans/button.tsx @@ -45,6 +45,18 @@ export const ChangePlanButton: React.FC = ({ workspace, newPlan, label }) }, }); + const handleClick = () => { + const hasPaymentMethod = !!workspace.stripeCustomerId; + if (!hasPaymentMethod && newPlan === "pro") { + return router.push(`/settings/billing/stripe?new_plan=${newPlan}`); + } + + changePlan.mutateAsync({ + workspaceId: workspace.id, + plan: newPlan === "free" ? "free" : "pro", + }); + }; + const isSamePlan = workspace.plan === newPlan; return ( @@ -90,16 +102,7 @@ export const ChangePlanButton: React.FC = ({ workspace, newPlan, label }) - diff --git a/apps/dashboard/app/(app)/settings/billing/stripe/page.tsx b/apps/dashboard/app/(app)/settings/billing/stripe/page.tsx index 28faa4cb7..cfc9a343a 100644 --- a/apps/dashboard/app/(app)/settings/billing/stripe/page.tsx +++ b/apps/dashboard/app/(app)/settings/billing/stripe/page.tsx @@ -7,7 +7,14 @@ import { headers } from "next/headers"; import { redirect } from "next/navigation"; import Stripe from "stripe"; -export default async function StripeRedirect() { +type Props = { + searchParams: { + new_plan: "free" | "pro" | undefined; + }; +}; + +export default async function StripeRedirect(props: Props) { + const { new_plan } = props.searchParams; const tenantId = getTenantId(); if (!tenantId) { return redirect("/auth/sign-in"); @@ -53,7 +60,12 @@ export default async function StripeRedirect() { const baseUrl = process.env.VERCEL_URL ? "https://app.unkey.com" : "http://localhost:3000"; // do not use `new URL(...).searchParams` here, because it will escape the curly braces and stripe will not replace them with the session id - const successUrl = `${baseUrl}/settings/billing/stripe/success?session_id={CHECKOUT_SESSION_ID}`; + let successUrl = `${baseUrl}/settings/billing/stripe/success?session_id={CHECKOUT_SESSION_ID}`; + + // if they're coming from the change plan flow, pass along the new plan param + if (new_plan && new_plan !== ws.plan) { + successUrl += `&new_plan=${new_plan}`; + } const cancelUrl = headers().get("referer") ?? baseUrl; const session = await stripe.checkout.sessions.create({ diff --git a/apps/dashboard/app/(app)/settings/billing/stripe/success/page.tsx b/apps/dashboard/app/(app)/settings/billing/stripe/success/page.tsx index 9e2492c10..bda76f02a 100644 --- a/apps/dashboard/app/(app)/settings/billing/stripe/success/page.tsx +++ b/apps/dashboard/app/(app)/settings/billing/stripe/success/page.tsx @@ -3,6 +3,7 @@ import { Code } from "@/components/ui/code"; import { getTenantId } from "@/lib/auth"; import { db, eq, schema } from "@/lib/db"; import { stripeEnv } from "@/lib/env"; +import { PostHogClient } from "@/lib/posthog"; import { currentUser } from "@clerk/nextjs"; import { redirect } from "next/navigation"; import Stripe from "stripe"; @@ -10,10 +11,12 @@ import Stripe from "stripe"; type Props = { searchParams: { session_id: string; + new_plan: "free" | "pro" | undefined; }; }; export default async function StripeSuccess(props: Props) { + const { session_id, new_plan } = props.searchParams; const tenantId = getTenantId(); if (!tenantId) { return redirect("/auth/sign-in"); @@ -44,14 +47,14 @@ export default async function StripeSuccess(props: Props) { typescript: true, }); - const session = await stripe.checkout.sessions.retrieve(props.searchParams.session_id); + const session = await stripe.checkout.sessions.retrieve(session_id); if (!session) { return ( Stripe session not found - The Stripe session {props.searchParams.session_id} you are trying to access - does not exist. Please contact support@unkey.dev. + The Stripe session {session_id} you are trying to access does not exist. + Please contact support@unkey.dev. ); @@ -69,14 +72,25 @@ export default async function StripeSuccess(props: Props) { ); } + const isChangingPlan = new_plan && new_plan !== ws.plan; + await db .update(schema.workspaces) .set({ stripeCustomerId: customer.id, stripeSubscriptionId: session.subscription as string, trialEnds: null, + ...(isChangingPlan ? { plan: new_plan } : {}), }) .where(eq(schema.workspaces.id, ws.id)); + if (isChangingPlan) { + PostHogClient.capture({ + distinctId: tenantId, + event: "plan_changed", + properties: { plan: new_plan, workspace: ws.id }, + }); + } + return redirect("/settings/billing"); } diff --git a/apps/dashboard/app/(app)/settings/root-keys/new/client.tsx b/apps/dashboard/app/(app)/settings/root-keys/new/client.tsx index 80ebc511b..d045a59e2 100644 --- a/apps/dashboard/app/(app)/settings/root-keys/new/client.tsx +++ b/apps/dashboard/app/(app)/settings/root-keys/new/client.tsx @@ -8,6 +8,7 @@ import { Code } from "@/components/ui/code"; import { Card, CardContent, CardDescription, CardHeader, CardTitle } from "@/components/ui/card"; import { Checkbox } from "@/components/ui/checkbox"; +import { Collapsible, CollapsibleContent, CollapsibleTrigger } from "@/components/ui/collapsible"; import { Dialog, DialogClose, @@ -22,6 +23,7 @@ import { toast } from "@/components/ui/toaster"; import { Tooltip, TooltipContent, TooltipTrigger } from "@/components/ui/tooltip"; import { trpc } from "@/lib/trpc/client"; import type { UnkeyPermission } from "@unkey/rbac"; +import { ChevronRight } from "lucide-react"; import { useRouter } from "next/navigation"; import { useState } from "react"; import { apiPermissions, workspacePermissions } from "../[keyId]/permissions/permissions"; @@ -38,7 +40,7 @@ export const Client: React.FC = ({ apis }) => { const [selectedPermissions, setSelectedPermissions] = useState([]); const key = trpc.rootKey.create.useMutation({ - onError(err) { + onError(err: { message: string }) { console.error(err); toast.error(err.message); }, @@ -65,6 +67,23 @@ export const Client: React.FC = ({ apis }) => { }); }; + type CardStates = { + [key: string]: boolean; + }; + + const initialCardStates: CardStates = {}; + apis.forEach((api) => { + initialCardStates[api.id] = false; + }); + const [cardStatesMap, setCardStatesMap] = useState(initialCardStates); + + const toggleCard = (apiId: string) => { + setCardStatesMap((prevStates) => ({ + ...prevStates, + [apiId]: !prevStates[apiId], + })); + }; + return (
@@ -132,56 +151,80 @@ export const Client: React.FC = ({ apis }) => { {apis.map((api) => ( - - - {api.name} - - Permissions scoped to this API. Enabling these roles only grants access to this - specific API. - - - -
- {Object.entries(apiPermissions(api.id)).map(([category, roles]) => { - const allPermissionNames = Object.values(roles).map(({ permission }) => permission); - const isAllSelected = allPermissionNames.every((permission) => - selectedPermissions.includes(permission), - ); + { + toggleCard(api.id); + }} + > + + + + {api.name} + + + + Permissions scoped to this API. Enabling these roles only grants access to this + specific API. + + + + + +
+ {Object.entries(apiPermissions(api.id)).map(([category, roles]) => { + const allPermissionNames = Object.values(roles).map( + ({ permission }) => permission, + ); + const isAllSelected = allPermissionNames.every((permission) => + selectedPermissions.includes(permission), + ); - return ( -
-
- {category}} - description={`Select all for ${category} permissions for this API`} - checked={isAllSelected} - setChecked={(isChecked) => { - allPermissionNames.forEach((permission) => { - handleSetChecked(permission, isChecked); - }); - }} - /> -
+ return ( +
+
+ {category}} + description={`Select all for ${category} permissions for this API`} + checked={isAllSelected} + setChecked={(isChecked) => { + allPermissionNames.forEach((permission) => { + handleSetChecked(permission, isChecked); + }); + }} + /> +
-
- {Object.entries(roles).map(([action, { description, permission }]) => ( - handleSetChecked(permission, isChecked)} - /> - ))} -
-
- ); - })} -
- - +
+ {Object.entries(roles).map(([action, { description, permission }]) => ( + handleSetChecked(permission, isChecked)} + /> + ))} +
+
+ ); + })} +
+
+ +
+ ))}