Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Public API Keys #140

Merged
merged 31 commits into from
Mar 18, 2022
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
Show all changes
31 commits
Select commit Hold shift + click to select a range
503a8bc
Base implementation of public API keys.
ro-tex Feb 17, 2022
3e77471
Add some clarity what the `env` image does.
ro-tex Feb 25, 2022
3890d4f
Handle public api keys during authentication.
ro-tex Feb 25, 2022
d375f02
Merge branch 'main' into ivo/pub_api_keys
ro-tex Feb 25, 2022
4d1b4e6
Fix a broken test (needs a custom struct to get a hidden field).
ro-tex Feb 25, 2022
1fd0fe9
Merge branch 'main' into ivo/pub_api_keys
ro-tex Feb 28, 2022
a6c3a34
Merge branch 'main' into ivo/pub_api_keys
ro-tex Mar 1, 2022
d4ed081
Move public API keys to their own set of endpoints, as they are suffi…
ro-tex Mar 2, 2022
61fed0f
Add a custom endpoint for checking speed limits based on public API k…
ro-tex Mar 2, 2022
a68a2f2
Refactor the user tier cache to always require a key and pass it befo…
ro-tex Mar 2, 2022
e7a4bee
Merge public and private API keys into one.
ro-tex Mar 3, 2022
f9cbcba
Add integration tests for public API Keys.
ro-tex Mar 4, 2022
e40ad87
Merge branch 'main' into ivo/pub_api_keys
ro-tex Mar 9, 2022
7f9c495
Let `userLimitsGetFromTier` handle quota exceeded.
ro-tex Mar 9, 2022
c0a6904
Let `apiKeyFromRequest` return a validated API key, thus simplifying …
ro-tex Mar 9, 2022
2a4059b
Add HealthGET tester method.
ro-tex Mar 10, 2022
28978d6
Add an integration test for public API keys usage.
ro-tex Mar 10, 2022
f491de4
Tester helpers for API keys.
ro-tex Mar 10, 2022
87b8eef
Add the rest of the API key integration tests.
ro-tex Mar 11, 2022
b32108d
Clean up.
ro-tex Mar 11, 2022
fb443e6
APIKeyPOST.Validate() returns descriptive errors.
ro-tex Mar 15, 2022
c0100d5
Address PR comments.
ro-tex Mar 15, 2022
d79e743
Move DB schema to a separate file.
ro-tex Mar 15, 2022
f0eaaac
Fix a nullpointer in tester.
ro-tex Mar 16, 2022
26a2332
Unparallelise tests that lead to data races.
ro-tex Mar 16, 2022
0323412
Update api/apikeys.go
ro-tex Mar 17, 2022
11e46ae
Update api/apikeys.go
ro-tex Mar 17, 2022
d24e0ff
Update api/apikeys.go
ro-tex Mar 17, 2022
26cdf8a
Merge branch 'main' into ivo/pub_api_keys
ro-tex Mar 17, 2022
f4bc82e
Add 404 errors.
ro-tex Mar 17, 2022
614c1ed
Merge branch 'main' into ivo/pub_api_keys
ro-tex Mar 18, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
152 changes: 152 additions & 0 deletions api/auth.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,152 @@
package api

import (
"context"
"net/http"
"strings"

"github.com/SkynetLabs/skynet-accounts/database"
"github.com/SkynetLabs/skynet-accounts/jwt"
jwt2 "github.com/lestrrat-go/jwx/jwt"
"gitlab.com/NebulousLabs/errors"
"go.mongodb.org/mongo-driver/bson/primitive"
)

// TODO Test the methods here which are still untested.
// - add integration tests

// userAndTokenByRequestToken scans the request for an authentication token,
// fetches the corresponding user from the database and returns both user and
// token.
func (api *API) userAndTokenByRequestToken(req *http.Request) (*database.User, jwt2.Token, error) {
token, err := tokenFromRequest(req)
if err != nil {
return nil, nil, errors.AddContext(err, "error fetching token from request")
}
sub, _, _, err := jwt.TokenFields(token)
if err != nil {
return nil, nil, errors.AddContext(err, "error decoding token from request")
}
u, err := api.staticDB.UserBySub(req.Context(), sub)
if err != nil {
return nil, nil, errors.AddContext(err, "error fetching user from database")
}
return u, token, nil
}

// userAndTokenByAPIKey extracts the APIKey or PubAPIKey from the requests and
// validates it. It then returns the user who owns it and a token for that user.
// It first checks the headers and then the query.
// This method accesses the database.
func (api *API) userAndTokenByAPIKey(req *http.Request) (*database.User, jwt2.Token, error) {
akStr, err := apiKeyFromRequest(req)
if err != nil {
return nil, nil, err
}
// We should only check for a PubAPIKey if this is a GET request for a valid
// skylink. We ignore the errors here because the API key might not be a
// public one.
if req.Method == http.MethodGet {
pak := database.PubAPIKey(akStr)
sl, err := database.ExtractSkylinkHash(req.RequestURI)
if err == nil && sl != "" && pak.IsValid() {
uID, err := api.userIDForPubAPIKey(req.Context(), pak, sl)
if err == nil {
return api.userAndTokenByUserID(req.Context(), uID)
}
}
}
// Check if this is a valid APIKey.
ak := database.APIKey(akStr)
if !ak.IsValid() {
return nil, nil, ErrInvalidAPIKey
}
uID, err := api.userIDForAPIKey(req.Context(), ak)
if err != nil {
return nil, nil, ErrInvalidAPIKey
}
return api.userAndTokenByUserID(req.Context(), uID)
}

// userAndTokenByUserID is a helper method that fetches a given user from the
// database based on their Key, issues a JWT token for them, and returns both
// of those.
func (api *API) userAndTokenByUserID(ctx context.Context, uid primitive.ObjectID) (*database.User, jwt2.Token, error) {
u, err := api.staticDB.UserByID(ctx, uid)
if err != nil {
return nil, nil, err
}
t, err := jwt.TokenForUser(u.Email, u.Sub)
return u, t, err
}

// userIDForAPIKey looks up the given APIKey and returns the Key of the user that
// issued it.
func (api *API) userIDForAPIKey(ctx context.Context, ak database.APIKey) (primitive.ObjectID, error) {
akRec, err := api.staticDB.APIKeyGetRecord(ctx, ak)
if err != nil {
return primitive.ObjectID{}, err
}
return akRec.UserID, nil
}

// userIDForPubAPIKey looks up the given PubAPIKey, validates that the target
// skylink is covered by it, and returns the Key of the user that issued the
// PubAPIKey.
func (api *API) userIDForPubAPIKey(ctx context.Context, pak database.PubAPIKey, sl string) (primitive.ObjectID, error) {
pakRec, err := api.staticDB.PubAPIKeyGetRecord(ctx, pak)
if err != nil {
return primitive.ObjectID{}, err
}
for _, s := range pakRec.Skylinks {
if sl == s {
return pakRec.UserID, nil
}
}
return primitive.ObjectID{}, database.ErrUserNotFound
}

// apiKeyFromRequest extracts the API key from the request and returns it.
// It first checks the headers and then the query.
func apiKeyFromRequest(r *http.Request) (string, error) {
// Check the headers for an API key.
akStr := r.Header.Get(APIKeyHeader)
// If there is no API key in the headers, try the query.
if akStr == "" {
akStr = r.FormValue("apiKey")
}
if akStr == "" {
return "", ErrNoAPIKey
}
return akStr, nil
}

// tokenFromRequest extracts the JWT token from the request and returns it.
// It first checks the authorization header and then the cookies.
// The token is validated before being returned.
func tokenFromRequest(r *http.Request) (jwt2.Token, error) {
var tokenStr string
// Check the headers for a token.
parts := strings.Split(r.Header.Get("Authorization"), "Bearer")
if len(parts) == 2 {
tokenStr = strings.TrimSpace(parts[1])
} else {
// Check the cookie for a token.
cookie, err := r.Cookie(CookieName)
if errors.Contains(err, http.ErrNoCookie) {
return nil, ErrNoToken
}
if err != nil {
return nil, errors.AddContext(err, "cookie exists but it's not valid")
}
err = secureCookie.Decode(CookieName, cookie.Value, &tokenStr)
if err != nil {
return nil, errors.AddContext(err, "failed to decode token")
}
}
token, err := jwt.ValidateToken(tokenStr)
if err != nil {
return nil, errors.AddContext(err, "failed to validate token")
}
return token, nil
}
6 changes: 3 additions & 3 deletions api/routes_test.go → api/auth_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ func TestAPIKeyFromRequest(t *testing.T) {
if err != nil {
t.Fatal(err)
}
if string(tk) != token {
if tk != token {
t.Fatalf("Expected '%s', got '%s'.", token, tk)
}

Expand All @@ -45,10 +45,10 @@ func TestAPIKeyFromRequest(t *testing.T) {
if err != nil {
t.Fatal(err)
}
if string(tk) == token {
if tk == token {
t.Fatal("Form token took precedence over headers token.")
}
if string(tk) != token2 {
if tk != token2 {
t.Fatalf("Expected '%s', got '%s'.", token2, tk)
}
}
Expand Down
12 changes: 9 additions & 3 deletions api/cache.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,8 +46,10 @@ func (utc *userTierCache) Get(sub string) (int, bool) {
return ce.Tier, true
}

// Set stores the user's tier in the cache.
func (utc *userTierCache) Set(u *database.User) {
// Set stores the user's tier in the cache. If the customCacheKey is not empty,
// it will be used to store the user in the cache, otherwise the user's sub will
// be used.
func (utc *userTierCache) Set(u *database.User, customCacheKey string) {
ro-tex marked this conversation as resolved.
Show resolved Hide resolved
var ce userTierCacheEntry
now := time.Now().UTC()
if u.QuotaExceeded {
Expand All @@ -62,6 +64,10 @@ func (utc *userTierCache) Set(u *database.User) {
}
}
utc.mu.Lock()
utc.cache[u.Sub] = ce
if customCacheKey == "" {
utc.cache[u.Sub] = ce
} else {
utc.cache[customCacheKey] = ce
}
utc.mu.Unlock()
}
27 changes: 25 additions & 2 deletions api/cache_test.go
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
package api

import (
"encoding/base64"
"testing"
"time"

"github.com/SkynetLabs/skynet-accounts/database"
"gitlab.com/NebulousLabs/fastrand"
)

// TestUserTierCache tests that working with userTierCache works as expected.
Expand All @@ -22,7 +24,7 @@ func TestUserTierCache(t *testing.T) {
t.Fatalf("Expected to get tier %d and %t, got %d and %t.", database.TierAnonymous, false, tier, ok)
}
// Set the use in the cache.
cache.Set(u)
cache.Set(u, "")
// Check again.
tier, ok = cache.Get(u.Sub)
if !ok || tier != u.Tier {
Expand All @@ -41,10 +43,31 @@ func TestUserTierCache(t *testing.T) {
timeToMonthRollover := 30 * time.Minute
u.SubscribedUntil = time.Now().UTC().Add(timeToMonthRollover)
// Update the cache.
cache.Set(u)
cache.Set(u, "")
// Expect the cache entry's ExpiresAt to be after 30 minutes.
timeIn30 := time.Now().UTC().Add(time.Hour - timeToMonthRollover)
if ce.ExpiresAt.After(timeIn30) && ce.ExpiresAt.Before(timeIn30.Add(time.Second)) {
t.Fatalf("Expected ExpiresAt to be within 1 second of %s, but it was %s (off by %d ns)", timeIn30.String(), ce.ExpiresAt.String(), (time.Hour - timeIn30.Sub(ce.ExpiresAt)).Nanoseconds())
}

// Create a new API key.
ak := database.APIKey(base64.URLEncoding.EncodeToString(fastrand.Bytes(database.PubKeySize)))
if !ak.IsValid() {
t.Fatal("Invalid API key.")
}
// Try to get a value from the cache. Expect this to fail.
_, ok = cache.Get(string(ak))
if ok {
t.Fatal("Did not expect to get a cache entry!")
}
// Update the cache with a custom key.
cache.Set(u, string(ak))
// Fetch the data for the custom key.
tier, ok = cache.Get(string(ak))
if !ok {
t.Fatal("Expected the entry to exist.")
}
if tier != u.Tier {
t.Fatalf("Expected tier %+v, got %+v", u.Tier, tier)
}
}
42 changes: 36 additions & 6 deletions api/handlers.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,9 @@ const (
// LimitBodySizeSmall defines a size limit for requests that we don't expect
// to contain a lot of data.
LimitBodySizeSmall = 4 * skynet.KiB
// LimitBodySizeLarge defines a size limit for requests that we expect to
// contain a lot of data.
LimitBodySizeLarge = 4 * skynet.MiB
peterjan marked this conversation as resolved.
Show resolved Hide resolved
)

type (
Expand Down Expand Up @@ -396,26 +399,53 @@ func (api *API) userGET(u *database.User, w http.ResponseWriter, _ *http.Request
// NOTE: This handler needs to use the noAuth middleware in order to be able to
// optimise its calls to the DB and the use of caching.
func (api *API) userLimitsGET(_ *database.User, w http.ResponseWriter, req *http.Request, _ httprouter.Params) {
// First check for an API key.
ak, err := apiKeyFromRequest(req)
respAnon := UserLimitsGET{
TierID: database.TierAnonymous,
TierLimits: database.UserLimits[database.TierAnonymous],
}
// First check for an API key.
akStr, err := apiKeyFromRequest(req)
if err == nil {
u, err := api.staticDB.UserByAPIKey(req.Context(), ak)
// Check the cache before going any further.
tier, ok := api.staticUserTierCache.Get(akStr)
if ok {
api.staticLogger.Traceln("Fetching user limits from cache by API key.")
resp := UserLimitsGET{
TierID: tier,
TierLimits: database.UserLimits[tier],
}
api.WriteJSON(w, resp)
return
}
// Cache is missed, fetch the data from the DB.
ak := database.APIKey(akStr)
if !ak.IsValid() {
api.staticLogger.Traceln("Invalid API key.")
api.WriteJSON(w, respAnon)
return
}
uID, err := api.userIDForAPIKey(req.Context(), ak)
if err != nil {
api.staticLogger.Traceln("Error while fetching user by API key:", err)
api.WriteJSON(w, respAnon)
return
}
u, err := api.staticDB.UserByID(req.Context(), uID)
if err != nil {
api.staticLogger.Traceln("Error while fetching user by API key:", err)
api.WriteJSON(w, respAnon)
return
}
// Cache the user under the API key they used.
api.staticUserTierCache.Set(u, akStr)
resp := UserLimitsGET{
TierID: u.Tier,
TierLimits: database.UserLimits[u.Tier],
}
api.WriteJSON(w, resp)
return
}
// Next check for a token.
token, err := tokenFromRequest(req)
if err != nil {
api.WriteJSON(w, respAnon)
Expand All @@ -438,7 +468,7 @@ func (api *API) userLimitsGET(_ *database.User, w http.ResponseWriter, req *http
api.WriteJSON(w, respAnon)
return
}
api.staticUserTierCache.Set(u)
api.staticUserTierCache.Set(u, "")
}
tier, ok = api.staticUserTierCache.Get(sub)
if !ok {
Expand Down Expand Up @@ -1087,7 +1117,7 @@ func (api *API) trackRegistryWritePOST(u *database.User, w http.ResponseWriter,
func (api *API) userUploadsDELETE(u *database.User, w http.ResponseWriter, req *http.Request, ps httprouter.Params) {
sl := ps.ByName("skylink")
if !database.ValidSkylinkHash(sl) {
api.WriteError(w, errors.New("invalid skylink"), http.StatusBadRequest)
api.WriteError(w, database.ErrInvalidSkylink, http.StatusBadRequest)
return
}
skylink, err := api.staticDB.Skylink(req.Context(), sl)
Expand Down Expand Up @@ -1129,7 +1159,7 @@ func (api *API) checkUserQuotas(ctx context.Context, u *database.User) {
if err != nil {
api.staticLogger.Warnf("Failed to save user. User: %+v, err: %s", u, err.Error())
}
api.staticUserTierCache.Set(u)
api.staticUserTierCache.Set(u, "")
}
}

Expand Down
Loading