Skip to content

Commit

Permalink
feat: allow creating limited use keys by setting 'remaining'
Browse files Browse the repository at this point in the history
  • Loading branch information
chronark committed Jul 11, 2023
1 parent 0e2cdb8 commit f4eafe6
Show file tree
Hide file tree
Showing 25 changed files with 511 additions and 58 deletions.
7 changes: 4 additions & 3 deletions apps/api/pkg/cache/cache_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,12 @@ package cache_test

import (
"context"
"github.com/stretchr/testify/require"
"sync/atomic"
"testing"
"time"

"github.com/stretchr/testify/require"

"github.com/unkeyed/unkey/apps/api/pkg/cache"
"github.com/unkeyed/unkey/apps/api/pkg/logging"
)
Expand All @@ -18,7 +19,7 @@ func TestWriteRead(t *testing.T) {
RefreshFromOrigin: func(ctx context.Context, id string) (string, error) {
return "hello", nil
},
Logger: logging.New(),
Logger: logging.NewNoopLogger(),
})
c.Set(context.Background(), "key", "value")
value, cacheHit := c.Get(context.Background(), "key")
Expand Down Expand Up @@ -57,7 +58,7 @@ func TestRefresh(t *testing.T) {
spy.counter.Add(1)
return "hello", nil
},
Logger: logging.New(),
Logger: logging.NewNoopLogger(),
})
c.Set(context.Background(), "key", "value")
time.Sleep(time.Second * 2)
Expand Down
9 changes: 9 additions & 0 deletions apps/api/pkg/database/conversion.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,11 @@ func keyModelToEntity(model *models.Key) (entities.Key, error) {
RefillInterval: model.RatelimitRefillInterval.Int64,
}
}

if model.RemainingRequests.Valid {
key.Remaining.Enabled = true
key.Remaining.Remaining = model.RemainingRequests.Int64
}
return key, nil
}

Expand Down Expand Up @@ -91,6 +96,10 @@ func keyEntityToModel(e entities.Key) (*models.Key, error) {
key.RatelimitRefillInterval = sql.NullInt64{Int64: e.Ratelimit.RefillInterval, Valid: e.Ratelimit.RefillRate > 0}
}

if e.Remaining.Enabled {
key.RemainingRequests = sql.NullInt64{Int64: e.Remaining.Remaining, Valid: true}
}

return key, nil

}
Expand Down
55 changes: 55 additions & 0 deletions apps/api/pkg/database/conversion_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package database
import (
"database/sql"
"testing"
"time"

"github.com/stretchr/testify/require"
"github.com/unkeyed/unkey/apps/api/pkg/database/models"
Expand Down Expand Up @@ -80,3 +81,57 @@ func Test_apiModelToEntity_WithIpWithlist(t *testing.T) {
require.Equal(t, m.WorkspaceID, e.WorkspaceId)
require.Equal(t, []string{"1.1.1.1", "2.2.2.2"}, e.IpWhitelist)
}

func Test_keyModelToEntity(t *testing.T) {

m := &models.Key{
ID: uid.Key(),
APIID: uid.Api(),
WorkspaceID: uid.Workspace(),
Hash: "abc",
Start: "abc",
CreatedAt: time.Now(),
}

e, err := keyModelToEntity(m)
require.NoError(t, err)
require.Equal(t, m.ID, e.Id)
require.Equal(t, m.APIID, e.ApiId)
require.Equal(t, m.WorkspaceID, e.WorkspaceId)
require.Equal(t, m.Hash, e.Hash)
require.Equal(t, m.Start, e.Start)
require.Equal(t, m.CreatedAt, e.CreatedAt)
require.Nil(t, e.Ratelimit)
}

func Test_keyModelToEntity_WithNullFields(t *testing.T) {

m := &models.Key{
ID: uid.Key(),
APIID: uid.Api(),
WorkspaceID: uid.Workspace(),
Hash: "abc",
Start: "abc",
CreatedAt: time.Now(),
RemainingRequests: sql.NullInt64{Int64: 99, Valid: true},
RatelimitType: sql.NullString{String: "fast", Valid: true},
RatelimitLimit: sql.NullInt64{Int64: 10, Valid: true},
RatelimitRefillRate: sql.NullInt64{Int64: 1, Valid: true},
RatelimitRefillInterval: sql.NullInt64{Int64: 1000, Valid: true},
}

e, err := keyModelToEntity(m)
require.NoError(t, err)
require.Equal(t, m.ID, e.Id)
require.Equal(t, m.APIID, e.ApiId)
require.Equal(t, m.WorkspaceID, e.WorkspaceId)
require.Equal(t, m.Hash, e.Hash)
require.Equal(t, m.Start, e.Start)
require.Equal(t, m.CreatedAt, e.CreatedAt)
require.Equal(t, true, e.Remaining.Enabled)
require.Equal(t, int64(99), e.Remaining.Remaining)
require.Equal(t, "fast", e.Ratelimit.Type)
require.Equal(t, int64(10), e.Ratelimit.Limit)
require.Equal(t, int64(1), e.Ratelimit.RefillRate)
require.Equal(t, int64(1000), e.Ratelimit.RefillInterval)
}
1 change: 1 addition & 0 deletions apps/api/pkg/database/interface.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,4 +20,5 @@ type Database interface {
CreateWorkspace(ctx context.Context, newWorkspace entities.Workspace) error

GetWorkspace(ctx context.Context, workspaceId string) (entities.Workspace, error)
DecrementRemainingKeyUsage(ctx context.Context, keyId string) (int64, error)
}
52 changes: 52 additions & 0 deletions apps/api/pkg/database/key_decrement_remaining.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
package database

import (
"context"
"database/sql"
"fmt"
)

// Decrement the `remaining` field and return the new value
// The returned value is the number of remaining verifications after the current one.
// This means the returned value can be negative, for example when the remaining is 0 and we call this function.
func (db *database) DecrementRemainingKeyUsage(ctx context.Context, keyId string) (int64, error) {
tx, err := db.write().BeginTx(ctx, nil)
if err != nil {
return 0, fmt.Errorf("unable to start transaction: %w", err)
}

_, err = tx.Exec(`UPDATE unkey.keys SET remaining_requests = remaining_requests - 1 WHERE id = ?`, keyId)
if err != nil {
rollbackErr := tx.Rollback()
if rollbackErr != nil {
return 0, fmt.Errorf("unable to roll back: %w", rollbackErr)
}
return 0, fmt.Errorf("unable to decrement: %w", err)
}

row := tx.QueryRow(`SELECT remaining_requests FROM unkey.keys WHERE id = ?`, keyId)
if err != nil {
rollbackErr := tx.Rollback()
if rollbackErr != nil {
return 0, fmt.Errorf("unable to roll back: %w", rollbackErr)
}
return 0, fmt.Errorf("unable to query: %w", err)
}
var remainingAfter sql.NullInt64
err = row.Scan(&remainingAfter)

if err != nil {
return 0, fmt.Errorf("unable to scan result: %w", err)
}
if !remainingAfter.Valid {
return 0, fmt.Errorf("this key did not have a remaining config")
}

err = tx.Commit()
if err != nil {
return 0, fmt.Errorf("unable to commit transaction: %w", err)
}

return remainingAfter.Int64, err

}
13 changes: 13 additions & 0 deletions apps/api/pkg/database/middleware/tracing.go
Original file line number Diff line number Diff line change
Expand Up @@ -163,3 +163,16 @@ func (mw *tracingMiddleware) GetWorkspace(ctx context.Context, workspaceId strin
}
return keys, err
}

func (mw *tracingMiddleware) DecrementRemainingKeyUsage(ctx context.Context, keyId string) (int64, error) {
ctx, span := mw.t.Start(ctx, fmt.Sprintf("%s.decrementRemainingKeyUsage", mw.pkg), trace.WithAttributes(
attribute.String("keyId", keyId),
))
defer span.End()

remaining, err := mw.next.DecrementRemainingKeyUsage(ctx, keyId)
if err != nil {
span.RecordError(err)
}
return remaining, err
}
33 changes: 17 additions & 16 deletions apps/api/pkg/database/models/key.xo.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

5 changes: 5 additions & 0 deletions apps/api/pkg/entities/entities.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,11 @@ type Key struct {
Expires time.Time
Ratelimit *Ratelimit
ForWorkspaceId string
Remaining struct {
// Whether or not the value in `Remaining` makes any sense or is just a default
Enabled bool
Remaining int64
}
}

type Ratelimit struct {
Expand Down
1 change: 1 addition & 0 deletions apps/api/pkg/server/errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ const (
INTERNAL_SERVER_ERROR ErrorCode = "INTERNAL_SERVER_ERROR"
RATELIMITED ErrorCode = "RATELIMITED"
FORBIDDEN ErrorCode = "FORBIDDEN"
USAGE_EXCEEDED ErrorCode = "USAGE_EXCEEDED"
)

type ErrorResponse struct {
Expand Down
8 changes: 8 additions & 0 deletions apps/api/pkg/server/key_create.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,10 @@ type CreateKeyRequest struct {
// ForWorkspaceId is used internally when the frontend wants to create a new root key.
// Therefore we might not want to add this field to our docs.
ForWorkspaceId string `json:"forWorkspaceId"`

// How often this key may be used
// `undefined`, `0` or negative to disable
Remaining int64 `json:"remaining,omitempty"`
}

type CreateKeyResponse struct {
Expand Down Expand Up @@ -143,6 +147,10 @@ func (s *Server) createKey(c *fiber.Ctx) error {
if req.Expires > 0 {
newKey.Expires = time.UnixMilli(req.Expires)
}
if req.Remaining > 0 {
newKey.Remaining.Enabled = true
newKey.Remaining.Remaining = req.Remaining
}
if req.Ratelimit != nil {
newKey.Ratelimit = &entities.Ratelimit{
Type: req.Ratelimit.Type,
Expand Down
Loading

0 comments on commit f4eafe6

Please sign in to comment.