From 503a8bc7fa5135e762ab21f5d3d3107fde9dcf65 Mon Sep 17 00:00:00 2001 From: Ivaylo Novakov Date: Thu, 17 Feb 2022 11:11:14 +0100 Subject: [PATCH 01/25] Base implementation of public API keys. --- api/handlers.go | 2 +- database/apikeys.go | 3 +- database/database.go | 15 +++++ database/download.go | 4 +- database/publicapikeys.go | 130 ++++++++++++++++++++++++++++++++++++++ database/upload.go | 6 +- database/user.go | 2 + 7 files changed, 155 insertions(+), 7 deletions(-) create mode 100644 database/publicapikeys.go diff --git a/api/handlers.go b/api/handlers.go index 684bfecd..096f2111 100644 --- a/api/handlers.go +++ b/api/handlers.go @@ -1048,7 +1048,7 @@ func (api *API) trackRegistryWritePOST(u *database.User, w http.ResponseWriter, func (api *API) userUploadsDELETE(u *database.User, w http.ResponseWriter, req *http.Request, ps httprouter.Params) { sl := ps.ByName("skylink") if !database.ValidSkylinkHash(sl) { - api.WriteError(w, errors.New("invalid skylink"), http.StatusBadRequest) + api.WriteError(w, database.ErrInvalidSkylink, http.StatusBadRequest) return } skylink, err := api.staticDB.Skylink(req.Context(), sl) diff --git a/database/apikeys.go b/database/apikeys.go index 1fa3a31b..48d8f29e 100644 --- a/database/apikeys.go +++ b/database/apikeys.go @@ -33,7 +33,8 @@ var ( type ( // APIKey is a base64URL-encoded representation of []byte with length PubKeySize APIKey string - // APIKeyRecord is a non-expiring authentication token generated on user demand. + // APIKeyRecord is a non-expiring authentication token generated on user + // demand. APIKeyRecord struct { ID primitive.ObjectID `bson:"_id,omitempty" json:"id"` UserID primitive.ObjectID `bson:"user_id" json:"-"` diff --git a/database/database.go b/database/database.go index ca71d16a..30b6d232 100644 --- a/database/database.go +++ b/database/database.go @@ -50,6 +50,9 @@ var ( collConfiguration = "configuration" // collAPIKeys defines the name of the db table with API keys for users. collAPIKeys = "api_keys" + // collPubAPIKeys defines the name of the db table with public API keys for + // users. + collPubAPIKeys = "pub_api_keys" // DefaultPageSize defines the default number of records to return. DefaultPageSize = 10 @@ -97,6 +100,7 @@ type ( staticUnconfirmedUserUpdates *mongo.Collection staticConfiguration *mongo.Collection staticAPIKeys *mongo.Collection + staticPubAPIKeys *mongo.Collection staticDeps lib.Dependencies staticLogger *logrus.Logger } @@ -148,6 +152,7 @@ func NewCustomDB(ctx context.Context, dbName string, creds DBCredentials, logger staticUnconfirmedUserUpdates: database.Collection(collUnconfirmedUserUpdates), staticConfiguration: database.Collection(collConfiguration), staticAPIKeys: database.Collection(collAPIKeys), + staticPubAPIKeys: database.Collection(collPubAPIKeys), staticLogger: logger, } return db, nil @@ -294,6 +299,16 @@ func ensureDBSchema(ctx context.Context, db *mongo.Database, log *logrus.Logger) Options: options.Index().SetName("user_id"), }, }, + collPubAPIKeys: { + { + Keys: bson.D{{"key", 1}}, + Options: options.Index().SetName("key_unique").SetUnique(true), + }, + { + Keys: bson.D{{"user_id", 1}}, + Options: options.Index().SetName("user_id"), + }, + }, } for collName, models := range schema { diff --git a/database/download.go b/database/download.go index 207b8c24..d0bd11dc 100644 --- a/database/download.go +++ b/database/download.go @@ -64,7 +64,7 @@ func (db *DB) DownloadCreate(ctx context.Context, user User, skylink Skylink, by return errors.New("invalid user") } if skylink.ID.IsZero() { - return errors.New("invalid skylink") + return ErrInvalidSkylink } // Check if there exists a download of this skylink by this user, updated @@ -92,7 +92,7 @@ func (db *DB) DownloadCreate(ctx context.Context, user User, skylink Skylink, by // number of such downloads. func (db *DB) DownloadsBySkylink(ctx context.Context, skylink Skylink, offset, pageSize int) ([]DownloadResponse, int, error) { if skylink.ID.IsZero() { - return nil, 0, errors.New("invalid skylink") + return nil, 0, ErrInvalidSkylink } if err := validateOffsetPageSize(offset, pageSize); err != nil { return nil, 0, err diff --git a/database/publicapikeys.go b/database/publicapikeys.go new file mode 100644 index 00000000..d13a4961 --- /dev/null +++ b/database/publicapikeys.go @@ -0,0 +1,130 @@ +package database + +import ( + "context" + "encoding/base64" + "time" + + "gitlab.com/NebulousLabs/errors" + "gitlab.com/NebulousLabs/fastrand" + "go.mongodb.org/mongo-driver/bson" + "go.mongodb.org/mongo-driver/bson/primitive" + "go.mongodb.org/mongo-driver/mongo" + "go.mongodb.org/mongo-driver/mongo/options" +) + +type ( + // TODO: I am still not sure whether we should use separate collections or + // keep all API keys in the same one. + + // PubAPIKey is a base64URL-encoded representation of []byte with length + // PubKeySize + PubAPIKey string + // PubAPIKeyRecord is a non-expiring authentication token generated on user + // demand. This token allows anyone to access a set of pre-determined + // skylinks. The traffic generated by this access is counted towards the + // issuing user's balance. + PubAPIKeyRecord struct { + ID primitive.ObjectID `bson:"_id,omitempty" json:"id"` + UserID primitive.ObjectID `bson:"user_id" json:"userID"` + Key PubAPIKey `bson:"key" json:"key"` + Skylinks []string `bson:"skylinks" json:"skylinks"` + CreatedAt time.Time `bson:"created_at" json:"createdAt"` + } +) + +// PubAPIKeyCreate creates a new public API key. +func (db *DB) PubAPIKeyCreate(ctx context.Context, user User, skylinks []string) (*PubAPIKeyRecord, error) { + if user.ID.IsZero() { + return nil, errors.New("invalid user") + } + n, err := db.staticPubAPIKeys.CountDocuments(ctx, bson.M{"user_id": user.ID}) + if err != nil { + return nil, errors.AddContext(err, "failed to ensure user can create a new API key") + } + if n > int64(MaxNumAPIKeysPerUser) { + return nil, ErrMaxNumAPIKeysExceeded + } + // Validate all given skylinks. + for _, s := range skylinks { + if !ValidSkylinkHash(s) { + return nil, ErrInvalidSkylink + } + } + ak := PubAPIKeyRecord{ + UserID: user.ID, + Key: PubAPIKey(base64.URLEncoding.EncodeToString(fastrand.Bytes(PubKeySize))), + Skylinks: skylinks, + CreatedAt: time.Now().UTC(), + } + ior, err := db.staticAPIKeys.InsertOne(ctx, ak) + if err != nil { + return nil, err + } + ak.ID = ior.InsertedID.(primitive.ObjectID) + return &ak, nil +} + +// PubAPIKeyUpdate updates an existing PubAPIKey. This works by replacing the +// list of Skylinks within the PubAPIKey record. +func (db *DB) PubAPIKeyUpdate(ctx context.Context, user User, keyID primitive.ObjectID, skylinks []string) error { + if user.ID.IsZero() { + return errors.New("invalid user") + } + // Validate all given skylinks. + for _, s := range skylinks { + if !ValidSkylinkHash(s) { + return ErrInvalidSkylink + } + } + filter := bson.M{ + "_id": keyID, + "user_id": user.ID, + } + update := bson.M{"skylinks": skylinks} + opts := options.UpdateOptions{ + Upsert: &False, + } + _, err := db.staticPubAPIKeys.UpdateOne(ctx, filter, update, &opts) + return err +} + +// PubAPIKeyDelete deletes a public API key. +func (db *DB) PubAPIKeyDelete(ctx context.Context, user User, akID string) error { + if user.ID.IsZero() { + return errors.New("invalid user") + } + id, err := primitive.ObjectIDFromHex(akID) + if err != nil { + return errors.AddContext(err, "invalid API key ID") + } + filter := bson.M{ + "_id": id, + "user_id": user.ID, + } + dr, err := db.staticPubAPIKeys.DeleteOne(ctx, filter) + if err != nil { + return err + } + if dr.DeletedCount == 0 { + return mongo.ErrNoDocuments + } + return nil +} + +// PubAPIKeyList lists all public API keys that belong to the user. +func (db *DB) PubAPIKeyList(ctx context.Context, user User) ([]*PubAPIKeyRecord, error) { + if user.ID.IsZero() { + return nil, errors.New("invalid user") + } + c, err := db.staticPubAPIKeys.Find(ctx, bson.M{"user_id": user.ID}) + if err != nil { + return nil, err + } + var aks []*PubAPIKeyRecord + err = c.All(ctx, &aks) + if err != nil { + return nil, err + } + return aks, nil +} diff --git a/database/upload.go b/database/upload.go index 22de084f..6415b7d8 100644 --- a/database/upload.go +++ b/database/upload.go @@ -58,7 +58,7 @@ func (db *DB) UploadCreate(ctx context.Context, user User, skylink Skylink) (*Up return nil, errors.New("invalid user") } if skylink.ID.IsZero() { - return nil, errors.New("invalid skylink") + return nil, ErrInvalidSkylink } up := Upload{ UserID: user.ID, @@ -77,7 +77,7 @@ func (db *DB) UploadCreate(ctx context.Context, user User, skylink Skylink) (*Up // number of such uploads. func (db *DB) UploadsBySkylink(ctx context.Context, skylink Skylink, offset, pageSize int) ([]UploadResponse, int, error) { if skylink.ID.IsZero() { - return nil, 0, errors.New("invalid skylink") + return nil, 0, ErrInvalidSkylink } if err := validateOffsetPageSize(offset, pageSize); err != nil { return nil, 0, err @@ -93,7 +93,7 @@ func (db *DB) UploadsBySkylink(ctx context.Context, skylink Skylink, offset, pag // the number of unpinned uploads. func (db *DB) UnpinUploads(ctx context.Context, skylink Skylink, user User) (int64, error) { if skylink.ID.IsZero() { - return 0, errors.New("invalid skylink") + return 0, ErrInvalidSkylink } if user.ID.IsZero() { return 0, errors.New("invalid user") diff --git a/database/user.go b/database/user.go index efb2c841..6e4e820f 100644 --- a/database/user.go +++ b/database/user.go @@ -48,6 +48,8 @@ const ( var ( // True is a helper for when we need to pass a *bool to MongoDB. True = true + // False is a helper for when we need to pass a *bool to MongoDB. + False = false // UserLimits defines the speed limits for each tier. // RegistryDelay delay is in ms. UserLimits = map[int]TierLimits{ From 3e7747139f190c3e4981d468126c6278c27e8ae0 Mon Sep 17 00:00:00 2001 From: Ivaylo Novakov Date: Fri, 25 Feb 2022 11:30:51 +0100 Subject: [PATCH 02/25] Add some clarity what the `env` image does. --- env/README.md | 4 ++++ 1 file changed, 4 insertions(+) create mode 100644 env/README.md diff --git a/env/README.md b/env/README.md new file mode 100644 index 00000000..7373bc76 --- /dev/null +++ b/env/README.md @@ -0,0 +1,4 @@ +## The `env` image + +The purpose of this image is to generate resources needed by the service. It currently only generates +a [JWKS](https://auth0.com/docs/secure/tokens/json-web-tokens/json-web-key-sets), needed for issuing JWTs. From 3890d4f6d5998d676d3e21a25c5d3ed023f0c621 Mon Sep 17 00:00:00 2001 From: Ivaylo Novakov Date: Fri, 25 Feb 2022 18:40:25 +0100 Subject: [PATCH 03/25] Handle public api keys during authentication. More tests are needed. --- api/auth.go | 149 +++++++++++++++++++++++++++ api/{routes_test.go => auth_test.go} | 6 +- api/cache.go | 12 ++- api/cache_test.go | 27 ++++- api/handlers.go | 31 +++++- api/routes.go | 111 +++----------------- api/stripe.go | 2 +- database/apikeys.go | 16 ++- database/publicapikeys.go | 23 +++++ database/user.go | 14 --- test/api/handlers_test.go | 27 ++++- test/tester.go | 28 ++++- 12 files changed, 314 insertions(+), 132 deletions(-) create mode 100644 api/auth.go rename api/{routes_test.go => auth_test.go} (97%) diff --git a/api/auth.go b/api/auth.go new file mode 100644 index 00000000..9551880b --- /dev/null +++ b/api/auth.go @@ -0,0 +1,149 @@ +package api + +import ( + "context" + "net/http" + "strings" + + "github.com/SkynetLabs/skynet-accounts/database" + "github.com/SkynetLabs/skynet-accounts/jwt" + jwt2 "github.com/lestrrat-go/jwx/jwt" + "gitlab.com/NebulousLabs/errors" + "go.mongodb.org/mongo-driver/bson/primitive" +) + +// userAndTokenByRequestToken scans the request for an authentication token, +// fetches the corresponding user from the database and returns both user and +// token. +func (api *API) userAndTokenByRequestToken(req *http.Request) (*database.User, jwt2.Token, error) { + token, err := tokenFromRequest(req) + if err != nil { + return nil, nil, errors.AddContext(err, "error fetching token from request") + } + sub, _, _, err := jwt.TokenFields(token) + if err != nil { + return nil, nil, errors.AddContext(err, "error decoding token from request") + } + u, err := api.staticDB.UserBySub(req.Context(), sub) + if err != nil { + return nil, nil, errors.AddContext(err, "error fetching user from database") + } + return u, token, nil +} + +// userAndTokenByAPIKey extracts the APIKey or PubAPIKey from the requests and +// validates it. It then returns the user who owns it and a token for that user. +// It first checks the headers and then the query. +// This method accesses the database. +func (api *API) userAndTokenByAPIKey(req *http.Request) (*database.User, jwt2.Token, error) { + akStr, err := apiKeyFromRequest(req) + if err != nil { + return nil, nil, err + } + // We should only check for a PubAPIKey if this is a GET request for a valid + // skylink. We ignore the errors here because the API key might not be a + // public one. + if req.Method == http.MethodGet { + pak := database.PubAPIKey(akStr) + sl, err := database.ExtractSkylinkHash(req.RequestURI) + if err == nil && sl != "" && pak.IsValid() { + uID, err := api.userIDForPubAPIKey(req.Context(), pak, sl) + if err == nil { + return api.userAndTokenByUserID(req.Context(), uID) + } + } + } + // Check if this is a valid APIKey. + ak := database.APIKey(akStr) + if !ak.IsValid() { + return nil, nil, ErrInvalidAPIKey + } + uID, err := api.userIDForAPIKey(req.Context(), ak) + if err != nil { + return nil, nil, ErrInvalidAPIKey + } + return api.userAndTokenByUserID(req.Context(), uID) +} + +// userAndTokenByUserID is a helper method that fetches a given user from the +// database based on their ID, issues a JWT token for them, and returns both +// of those. +func (api *API) userAndTokenByUserID(ctx context.Context, uid primitive.ObjectID) (*database.User, jwt2.Token, error) { + u, err := api.staticDB.UserByID(ctx, uid) + if err != nil { + return nil, nil, err + } + t, err := jwt.TokenForUser(u.Email, u.Sub) + return u, t, err +} + +// userIDForAPIKey looks up the given APIKey and returns the ID of the user that +// issued it. +func (api *API) userIDForAPIKey(ctx context.Context, ak database.APIKey) (primitive.ObjectID, error) { + akRec, err := api.staticDB.APIKeyGetRecord(ctx, ak) + if err != nil { + return primitive.ObjectID{}, err + } + return akRec.UserID, nil +} + +// userIDForPubAPIKey looks up the given PubAPIKey, validates that the target +// skylink is covered by it, and returns the ID of the user that issued the +// PubAPIKey. +func (api *API) userIDForPubAPIKey(ctx context.Context, pak database.PubAPIKey, sl string) (primitive.ObjectID, error) { + pakRec, err := api.staticDB.PubAPIKeyGetRecord(ctx, pak) + if err != nil { + return primitive.ObjectID{}, err + } + for _, s := range pakRec.Skylinks { + if sl == s { + return pakRec.UserID, nil + } + } + return primitive.ObjectID{}, database.ErrUserNotFound +} + +// apiKeyFromRequest extracts the API key from the request and returns it. +// It first checks the headers and then the query. +func apiKeyFromRequest(r *http.Request) (string, error) { + // Check the headers for an API key. + akStr := r.Header.Get(APIKeyHeader) + // If there is no API key in the headers, try the query. + if akStr == "" { + akStr = r.FormValue("apiKey") + } + if akStr == "" { + return "", ErrNoAPIKey + } + return akStr, nil +} + +// tokenFromRequest extracts the JWT token from the request and returns it. +// It first checks the authorization header and then the cookies. +// The token is validated before being returned. +func tokenFromRequest(r *http.Request) (jwt2.Token, error) { + var tokenStr string + // Check the headers for a token. + parts := strings.Split(r.Header.Get("Authorization"), "Bearer") + if len(parts) == 2 { + tokenStr = strings.TrimSpace(parts[1]) + } else { + // Check the cookie for a token. + cookie, err := r.Cookie(CookieName) + if errors.Contains(err, http.ErrNoCookie) { + return nil, ErrNoToken + } + if err != nil { + return nil, errors.AddContext(err, "cookie exists but it's not valid") + } + err = secureCookie.Decode(CookieName, cookie.Value, &tokenStr) + if err != nil { + return nil, errors.AddContext(err, "failed to decode token") + } + } + token, err := jwt.ValidateToken(tokenStr) + if err != nil { + return nil, errors.AddContext(err, "failed to validate token") + } + return token, nil +} diff --git a/api/routes_test.go b/api/auth_test.go similarity index 97% rename from api/routes_test.go rename to api/auth_test.go index 6eac462f..6fa05d41 100644 --- a/api/routes_test.go +++ b/api/auth_test.go @@ -34,7 +34,7 @@ func TestAPIKeyFromRequest(t *testing.T) { if err != nil { t.Fatal(err) } - if string(tk) != token { + if tk != token { t.Fatalf("Expected '%s', got '%s'.", token, tk) } @@ -45,10 +45,10 @@ func TestAPIKeyFromRequest(t *testing.T) { if err != nil { t.Fatal(err) } - if string(tk) == token { + if tk == token { t.Fatal("Form token took precedence over headers token.") } - if string(tk) != token2 { + if tk != token2 { t.Fatalf("Expected '%s', got '%s'.", token2, tk) } } diff --git a/api/cache.go b/api/cache.go index 7885c638..53358fd0 100644 --- a/api/cache.go +++ b/api/cache.go @@ -46,8 +46,10 @@ func (utc *userTierCache) Get(sub string) (int, bool) { return ce.Tier, true } -// Set stores the user's tier in the cache. -func (utc *userTierCache) Set(u *database.User) { +// Set stores the user's tier in the cache. If the customCacheKey is not empty, +// it will be used to store the user in the cache, otherwise the user's sub will +// be used. +func (utc *userTierCache) Set(u *database.User, customCacheKey string) { var ce userTierCacheEntry now := time.Now().UTC() if u.SubscribedUntil.Before(now) { @@ -67,6 +69,10 @@ func (utc *userTierCache) Set(u *database.User) { } } utc.mu.Lock() - utc.cache[u.Sub] = ce + if customCacheKey == "" { + utc.cache[u.Sub] = ce + } else { + utc.cache[customCacheKey] = ce + } utc.mu.Unlock() } diff --git a/api/cache_test.go b/api/cache_test.go index 76c332d1..835651ca 100644 --- a/api/cache_test.go +++ b/api/cache_test.go @@ -1,10 +1,12 @@ package api import ( + "encoding/base64" "testing" "time" "github.com/SkynetLabs/skynet-accounts/database" + "gitlab.com/NebulousLabs/fastrand" ) // TestUserTierCache tests that working with userTierCache works as expected. @@ -22,7 +24,7 @@ func TestUserTierCache(t *testing.T) { t.Fatalf("Expected to get tier %d and %t, got %d and %t.", database.TierAnonymous, false, tier, ok) } // Set the use in the cache. - cache.Set(u) + cache.Set(u, "") // Check again. tier, ok = cache.Get(u.Sub) if !ok || tier != u.Tier { @@ -41,10 +43,31 @@ func TestUserTierCache(t *testing.T) { timeToMonthRollover := 30 * time.Minute u.SubscribedUntil = time.Now().UTC().Add(timeToMonthRollover) // Update the cache. - cache.Set(u) + cache.Set(u, "") // Expect the cache entry's ExpiresAt to be after 30 minutes. timeIn30 := time.Now().UTC().Add(time.Hour - timeToMonthRollover) if ce.ExpiresAt.After(timeIn30) && ce.ExpiresAt.Before(timeIn30.Add(time.Second)) { t.Fatalf("Expected ExpiresAt to be within 1 second of %s, but it was %s (off by %d ns)", timeIn30.String(), ce.ExpiresAt.String(), (time.Hour - timeIn30.Sub(ce.ExpiresAt)).Nanoseconds()) } + + // Create a new API key. + ak := database.APIKey(base64.URLEncoding.EncodeToString(fastrand.Bytes(database.PubKeySize))) + if !ak.IsValid() { + t.Fatal("Invalid API key.") + } + // Try to get a value from the cache. Expect this to fail. + _, ok = cache.Get(string(ak)) + if ok { + t.Fatal("Did not expect to get a cache entry!") + } + // Update the cache with a custom key. + cache.Set(u, string(ak)) + // Fetch the data for the custom key. + tier, ok = cache.Get(string(ak)) + if !ok { + t.Fatal("Expected the entry to exist.") + } + if tier != u.Tier { + t.Fatalf("Expected tier %+v, got %+v", u.Tier, tier) + } } diff --git a/api/handlers.go b/api/handlers.go index 096f2111..1582683b 100644 --- a/api/handlers.go +++ b/api/handlers.go @@ -370,17 +370,40 @@ func (api *API) userGET(u *database.User, w http.ResponseWriter, _ *http.Request // optimise its calls to the DB and the use of caching. func (api *API) userLimitsGET(_ *database.User, w http.ResponseWriter, req *http.Request, _ httprouter.Params) { // First check for an API key. - ak, err := apiKeyFromRequest(req) + akStr, err := apiKeyFromRequest(req) if err == nil { - u, err := api.staticDB.UserByAPIKey(req.Context(), ak) + // Check the cache before going any further. + tier, ok := api.staticUserTierCache.Get(akStr) + if ok { + api.staticLogger.Traceln("Fetching user limits from cache by API key.") + api.WriteJSON(w, database.UserLimits[tier]) + return + } + // Cache is missed, fetch the data from the DB. + ak := database.APIKey(akStr) + if !ak.IsValid() { + api.staticLogger.Traceln("Invalid API key.") + api.WriteJSON(w, database.UserLimits[database.TierAnonymous]) + return + } + uID, err := api.userIDForAPIKey(req.Context(), ak) + if err != nil { + api.staticLogger.Traceln("Error while fetching user by API key:", err) + api.WriteJSON(w, database.UserLimits[database.TierAnonymous]) + return + } + u, err := api.staticDB.UserByID(req.Context(), uID) if err != nil { api.staticLogger.Traceln("Error while fetching user by API key:", err) api.WriteJSON(w, database.UserLimits[database.TierAnonymous]) return } + // Cache the user under the API key they used. + api.staticUserTierCache.Set(u, akStr) api.WriteJSON(w, database.UserLimits[u.Tier]) return } + // Next check for a token. token, err := tokenFromRequest(req) if err != nil { api.WriteJSON(w, database.UserLimits[database.TierAnonymous]) @@ -403,7 +426,7 @@ func (api *API) userLimitsGET(_ *database.User, w http.ResponseWriter, req *http api.WriteJSON(w, database.UserLimits[database.TierAnonymous]) return } - api.staticUserTierCache.Set(u) + api.staticUserTierCache.Set(u, "") } tier, ok = api.staticUserTierCache.Get(sub) if !ok { @@ -1090,7 +1113,7 @@ func (api *API) checkUserQuotas(ctx context.Context, u *database.User) { if err != nil { api.staticLogger.Warnf("Failed to save user. User: %+v, err: %s", u, err.Error()) } - api.staticUserTierCache.Set(u) + api.staticUserTierCache.Set(u, "") } } diff --git a/api/routes.go b/api/routes.go index 07b65ff7..35d61511 100644 --- a/api/routes.go +++ b/api/routes.go @@ -1,14 +1,11 @@ package api import ( - "context" "net/http" "strings" "github.com/SkynetLabs/skynet-accounts/database" "github.com/SkynetLabs/skynet-accounts/jwt" - jwt2 "github.com/lestrrat-go/jwx/jwt" - "github.com/julienschmidt/httprouter" "gitlab.com/NebulousLabs/errors" ) @@ -93,45 +90,24 @@ func (api *API) noAuth(h HandlerWithUser) httprouter.Handle { func (api *API) withAuth(h HandlerWithUser) httprouter.Handle { return func(w http.ResponseWriter, req *http.Request, ps httprouter.Params) { api.logRequest(req) - var u *database.User - var token jwt2.Token - // Check for an API key. We only return an error if an invalid API key - // is provided. - ak, err := apiKeyFromRequest(req) - if err == nil { - // We have an API key. Let's generate a token based on it. - token, err = api.tokenFromAPIKey(req.Context(), ak) - u, err = api.staticDB.UserByAPIKey(req.Context(), ak) - if err != nil { - api.staticLogger.Debugf("Error fetching user for API key %s. Error: %s", ak, err) - api.WriteError(w, errors.AddContext(err, "failed to fetch user by API key"), http.StatusUnauthorized) - return - } - } else { - // No API key. Let's check for a token in the request. - token, err = tokenFromRequest(req) - if err != nil { - api.staticLogger.Debugln("Error fetching token from request:", err) - api.WriteError(w, err, http.StatusUnauthorized) - return - } - sub, _, _, err := jwt.TokenFields(token) + // Check for an API key. + u, token, err := api.userAndTokenByAPIKey(req) + // If there is an unexpected error, that is a 500. + if err != nil && !errors.Contains(err, ErrNoAPIKey) && !errors.Contains(err, ErrInvalidAPIKey) && !errors.Contains(err, database.ErrUserNotFound) { + api.WriteError(w, err, http.StatusInternalServerError) + return + } + if err != nil && (errors.Contains(err, ErrInvalidAPIKey) || errors.Contains(err, database.ErrUserNotFound)) { + api.WriteError(w, errors.AddContext(err, "failed to fetch user by API key"), http.StatusUnauthorized) + return + } + // If there is no API key check for a token. + if errors.Contains(err, ErrNoAPIKey) { + u, token, err = api.userAndTokenByRequestToken(req) if err != nil { - api.staticLogger.Debugln("Error decoding token from request:", err) - api.WriteError(w, err, http.StatusUnauthorized) - return - } - u, err = api.staticDB.UserBySub(req.Context(), sub) - if errors.Contains(err, database.ErrUserNotFound) { - api.staticLogger.Debugln("User that created this token no longer exists:", err) api.WriteError(w, err, http.StatusUnauthorized) return } - if err != nil { - api.staticLogger.Debugln("Error fetching user by token from request:", err) - api.WriteError(w, err, http.StatusInternalServerError) - return - } } // Embed the verified token in the context of the request. ctx := jwt.ContextWithToken(req.Context(), token) @@ -146,62 +122,3 @@ func (api *API) logRequest(r *http.Request) { hasCookie := err == nil && c != nil api.staticLogger.Tracef("Processing request: %v %v, Auth: %v, Skynet Cookie: %v, Referer: %v, Host: %v, RemoreAddr: %v", r.Method, r.URL, hasAuth, hasCookie, r.Referer(), r.Host, r.RemoteAddr) } - -// tokenFromAPIKey returns a token, generated for the owner of the given API key. -func (api *API) tokenFromAPIKey(ctx context.Context, ak database.APIKey) (jwt2.Token, error) { - u, err := api.staticDB.UserByAPIKey(ctx, ak) - if err != nil { - return nil, err - } - tk, err := jwt.TokenForUser(u.Email, u.Sub) - return tk, err -} - -// apiKeyFromRequest extracts the API key from the request and returns it. -// It first checks the headers and then the query. -func apiKeyFromRequest(r *http.Request) (database.APIKey, error) { - // Check the headers for an API key. - akStr := r.Header.Get(APIKeyHeader) - // If there is no API key in the headers, try the query. - if akStr == "" { - akStr = r.FormValue("apiKey") - } - if akStr == "" { - return "", ErrNoAPIKey - } - ak := database.APIKey(akStr) - if !ak.IsValid() { - return "", ErrInvalidAPIKey - } - return ak, nil -} - -// tokenFromRequest extracts the JWT token from the request and returns it. -// It first checks the request headers and then the cookies. -// The token is validated before being returned. -func tokenFromRequest(r *http.Request) (jwt2.Token, error) { - var tokenStr string - // Check the headers for a token. - parts := strings.Split(r.Header.Get("Authorization"), "Bearer") - if len(parts) == 2 { - tokenStr = strings.TrimSpace(parts[1]) - } else { - // Check the cookie for a token. - cookie, err := r.Cookie(CookieName) - if errors.Contains(err, http.ErrNoCookie) { - return nil, ErrNoToken - } - if err != nil { - return nil, errors.AddContext(err, "cookie exists but it's not valid") - } - err = secureCookie.Decode(CookieName, cookie.Value, &tokenStr) - if err != nil { - return nil, errors.AddContext(err, "failed to decode token") - } - } - token, err := jwt.ValidateToken(tokenStr) - if err != nil { - return nil, errors.AddContext(err, "failed to validate token") - } - return token, nil -} diff --git a/api/stripe.go b/api/stripe.go index ba580d12..45e80c34 100644 --- a/api/stripe.go +++ b/api/stripe.go @@ -241,7 +241,7 @@ func (api *API) processStripeSub(ctx context.Context, s *stripe.Subscription) er api.staticLogger.Tracef("Subscribed user id %s, tier %d, until %s.", u.ID, u.Tier, u.SubscribedUntil.String()) } // Re-set the tier cache for this user, in case their tier changed. - api.staticUserTierCache.Set(u) + api.staticUserTierCache.Set(u, "") return err } diff --git a/database/apikeys.go b/database/apikeys.go index 48d8f29e..02e3dc78 100644 --- a/database/apikeys.go +++ b/database/apikeys.go @@ -45,7 +45,7 @@ type ( // IsValid checks whether the underlying string satisfies the type's requirement // to represent a []byte with length PubKeySize which is encoded as base64URL. -// This method does NOT check whether the API exists in the database. +// This method does NOT check whether the API key exists in the database. func (ak APIKey) IsValid() bool { b := make([]byte, PubKeySize) n, err := base64.URLEncoding.Decode(b, []byte(ak)) @@ -100,6 +100,20 @@ func (db *DB) APIKeyDelete(ctx context.Context, user User, akID string) error { return nil } +// APIKeyGetRecord returns a specific API key. +func (db *DB) APIKeyGetRecord(ctx context.Context, ak APIKey) (APIKeyRecord, error) { + sr := db.staticAPIKeys.FindOne(ctx, bson.M{"key": ak}) + if sr.Err() != nil { + return APIKeyRecord{}, sr.Err() + } + var akRecord APIKeyRecord + err := sr.Decode(&akRecord) + if err != nil { + return APIKeyRecord{}, err + } + return akRecord, nil +} + // APIKeyList lists all API keys that belong to the user. func (db *DB) APIKeyList(ctx context.Context, user User) ([]*APIKeyRecord, error) { if user.ID.IsZero() { diff --git a/database/publicapikeys.go b/database/publicapikeys.go index d13a4961..17449288 100644 --- a/database/publicapikeys.go +++ b/database/publicapikeys.go @@ -33,6 +33,15 @@ type ( } ) +// IsValid checks whether the underlying string satisfies the type's requirement +// to represent a []byte with length PubKeySize which is encoded as base64URL. +// This method does NOT check whether the public API key exists in the database. +func (pak PubAPIKey) IsValid() bool { + b := make([]byte, PubKeySize) + n, err := base64.URLEncoding.Decode(b, []byte(pak)) + return err == nil && n == PubKeySize +} + // PubAPIKeyCreate creates a new public API key. func (db *DB) PubAPIKeyCreate(ctx context.Context, user User, skylinks []string) (*PubAPIKeyRecord, error) { if user.ID.IsZero() { @@ -112,6 +121,20 @@ func (db *DB) PubAPIKeyDelete(ctx context.Context, user User, akID string) error return nil } +// PubAPIKeyGetRecord returns a specific public API key. +func (db *DB) PubAPIKeyGetRecord(ctx context.Context, pak PubAPIKey) (PubAPIKeyRecord, error) { + sr := db.staticPubAPIKeys.FindOne(ctx, bson.M{"key": pak}) + if sr.Err() != nil { + return PubAPIKeyRecord{}, sr.Err() + } + var pakRec PubAPIKeyRecord + err := sr.Decode(&pakRec) + if err != nil { + return PubAPIKeyRecord{}, err + } + return pakRec, nil +} + // PubAPIKeyList lists all public API keys that belong to the user. func (db *DB) PubAPIKeyList(ctx context.Context, user User) ([]*PubAPIKeyRecord, error) { if user.ID.IsZero() { diff --git a/database/user.go b/database/user.go index 6e4e820f..35bd5aa7 100644 --- a/database/user.go +++ b/database/user.go @@ -160,20 +160,6 @@ type ( } ) -// UserByAPIKey returns the user who owns the given API key. -func (db *DB) UserByAPIKey(ctx context.Context, ak APIKey) (*User, error) { - sr := db.staticAPIKeys.FindOne(ctx, bson.M{"key": ak}) - if sr.Err() != nil { - return nil, sr.Err() - } - var apiKey APIKeyRecord - err := sr.Decode(&apiKey) - if err != nil { - return nil, err - } - return db.UserByID(ctx, apiKey.UserID) -} - // UserByEmail returns the user with the given username. func (db *DB) UserByEmail(ctx context.Context, email string) (*User, error) { users, err := db.managedUsersByField(ctx, "email", email) diff --git a/test/api/handlers_test.go b/test/api/handlers_test.go index 3cfbf0a8..7314508a 100644 --- a/test/api/handlers_test.go +++ b/test/api/handlers_test.go @@ -11,6 +11,7 @@ import ( "testing" "time" + "github.com/SkynetLabs/skynet-accounts/api" "github.com/SkynetLabs/skynet-accounts/database" "github.com/SkynetLabs/skynet-accounts/email" "github.com/SkynetLabs/skynet-accounts/skynet" @@ -432,8 +433,19 @@ func testUserLimits(t *testing.T, at *test.AccountsTester) { at.Cookie = c defer func() { at.Cookie = nil }() - // Call /user/limits with a cookie. Expect FreeTier response. - _, b, err := at.Get("/user/limits", nil) + // Create an API key for this user. + _, b, err := at.Post("/user/apikeys", nil, nil) + if err != nil { + t.Fatal(err) + } + var akRec database.APIKeyRecord + err = json.Unmarshal(b, &akRec) + if err != nil { + t.Fatal(err) + } + + // Call /user/limits with a cookie. Expect TierFree response. + _, b, err = at.Get("/user/limits", nil) if err != nil { t.Fatal(err) } @@ -446,7 +458,7 @@ func testUserLimits(t *testing.T, at *test.AccountsTester) { t.Fatalf("Expected to get the results for %s, got %s", database.UserLimits[database.TierFree].TierName, tl.TierName) } - // Call /user/limits without a cookie. Expect FreeAnonymous response. + // Call /user/limits without a cookie. Expect TierAnonymous response. at.Cookie = nil _, b, err = at.Get("/user/limits", nil) if err != nil { @@ -459,6 +471,15 @@ func testUserLimits(t *testing.T, at *test.AccountsTester) { if tl.TierName != database.UserLimits[database.TierAnonymous].TierName { t.Fatalf("Expected to get the results for %s, got %s", database.UserLimits[database.TierAnonymous].TierName, tl.TierName) } + + // Call /user/limits with an API key. Expect TierFree response. + tl, _, err = at.UserLimitsGET(nil, map[string]string{api.APIKeyHeader: string(akRec.Key)}) + if err != nil { + t.Fatal(err) + } + if tl.TierName != database.UserLimits[database.TierFree].TierName { + t.Fatalf("Expected to get the results for %s, got %s", database.UserLimits[database.TierFree].TierName, tl.TierName) + } } // testUserUploadsDELETE tests the DELETE /user/uploads/:skylink endpoint. diff --git a/test/tester.go b/test/tester.go index 09f4a278..ef226f49 100644 --- a/test/tester.go +++ b/test/tester.go @@ -127,14 +127,14 @@ func NewAccountsTester(dbName string) (*AccountsTester, error) { // // NOTE: The Body of the returned response is already read and closed. func (at *AccountsTester) Get(endpoint string, params url.Values) (r *http.Response, body []byte, err error) { - return at.request(http.MethodGet, endpoint, params, nil) + return at.request(http.MethodGet, endpoint, params, nil, nil) } // Delete executes a DELETE request against the test service. // // NOTE: The Body of the returned response is already read and closed. func (at *AccountsTester) Delete(endpoint string, params url.Values) (r *http.Response, body []byte, err error) { - return at.request(http.MethodDelete, endpoint, params, nil) + return at.request(http.MethodDelete, endpoint, params, nil, nil) } // Post executes a POST request against the test service. @@ -177,7 +177,7 @@ func (at *AccountsTester) Post(endpoint string, params url.Values, bodyParams ur // // NOTE: The Body of the returned response is already read and closed. func (at *AccountsTester) Put(endpoint string, params url.Values, putParams url.Values) (r *http.Response, body []byte, err error) { - return at.request(http.MethodPut, endpoint, params, putParams) + return at.request(http.MethodPut, endpoint, params, putParams, nil) } // Close performs a graceful shutdown of the AccountsTester service. @@ -228,7 +228,7 @@ func (at *AccountsTester) UserPUT(email, password, stipeID string) (*http.Respon // request. It attaches the current cookie, if one exists. // // NOTE: The Body of the returned response is already read and closed. -func (at *AccountsTester) request(method string, endpoint string, queryParams url.Values, bodyParams url.Values) (*http.Response, []byte, error) { +func (at *AccountsTester) request(method string, endpoint string, queryParams url.Values, bodyParams url.Values, headers map[string]string) (*http.Response, []byte, error) { if queryParams == nil { queryParams = url.Values{} } @@ -241,6 +241,9 @@ func (at *AccountsTester) request(method string, endpoint string, queryParams ur if err != nil { return nil, nil, err } + for name, val := range headers { + req.Header.Set(name, val) + } if at.Cookie != nil { req.Header.Set("Cookie", at.Cookie.String()) } @@ -252,6 +255,23 @@ func (at *AccountsTester) request(method string, endpoint string, queryParams ur return processResponse(r) } +// UserLimitsGET performs a `GET /user/limits` request. +func (at *AccountsTester) UserLimitsGET(params url.Values, headers map[string]string) (database.TierLimits, int, error) { + r, b, err := at.request(http.MethodGet, "/user/limits", params, nil, headers) + if err != nil { + return database.TierLimits{}, r.StatusCode, err + } + if r.StatusCode != http.StatusOK { + return database.TierLimits{}, r.StatusCode, errors.New(string(b)) + } + var result database.TierLimits + err = json.Unmarshal(b, &result) + if err != nil { + return database.TierLimits{}, 0, errors.AddContext(err, "failed to parse response") + } + return result, r.StatusCode, nil +} + // processResponse is a helper method which extracts the body from the response // and handles non-OK status codes. // From 4d1b4e67f9df2ca7924b895669c2e665f92bee91 Mon Sep 17 00:00:00 2001 From: Ivaylo Novakov Date: Fri, 25 Feb 2022 19:19:54 +0100 Subject: [PATCH 04/25] Fix a broken test (needs a custom struct to get a hidden field). --- test/api/handlers_test.go | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/test/api/handlers_test.go b/test/api/handlers_test.go index c66e3e26..38c90117 100644 --- a/test/api/handlers_test.go +++ b/test/api/handlers_test.go @@ -444,7 +444,9 @@ func testUserLimits(t *testing.T, at *test.AccountsTester) { if err != nil { t.Fatal(err) } - var akRec database.APIKeyRecord + var akRec struct { + Key database.APIKey `json:"key"` + } err = json.Unmarshal(b, &akRec) if err != nil { t.Fatal(err) From d4ed0812a956ec49e2ae4159202801ecbccd1f75 Mon Sep 17 00:00:00 2001 From: Ivaylo Novakov Date: Wed, 2 Mar 2022 13:09:08 +0100 Subject: [PATCH 05/25] Move public API keys to their own set of endpoints, as they are sufficiently different from the general (secret) API keys. --- api/auth.go | 9 ++-- api/handlers.go | 3 ++ api/pubapikeys.go | 108 ++++++++++++++++++++++++++++++++++++++ api/routes.go | 7 +++ database/publicapikeys.go | 66 +++++++++++++++++++---- main.go | 2 +- 6 files changed, 180 insertions(+), 15 deletions(-) create mode 100644 api/pubapikeys.go diff --git a/api/auth.go b/api/auth.go index 9551880b..1541af93 100644 --- a/api/auth.go +++ b/api/auth.go @@ -12,6 +12,9 @@ import ( "go.mongodb.org/mongo-driver/bson/primitive" ) +// TODO Test the methods here which are still untested. +// - add integration tests + // userAndTokenByRequestToken scans the request for an authentication token, // fetches the corresponding user from the database and returns both user and // token. @@ -66,7 +69,7 @@ func (api *API) userAndTokenByAPIKey(req *http.Request) (*database.User, jwt2.To } // userAndTokenByUserID is a helper method that fetches a given user from the -// database based on their ID, issues a JWT token for them, and returns both +// database based on their Key, issues a JWT token for them, and returns both // of those. func (api *API) userAndTokenByUserID(ctx context.Context, uid primitive.ObjectID) (*database.User, jwt2.Token, error) { u, err := api.staticDB.UserByID(ctx, uid) @@ -77,7 +80,7 @@ func (api *API) userAndTokenByUserID(ctx context.Context, uid primitive.ObjectID return u, t, err } -// userIDForAPIKey looks up the given APIKey and returns the ID of the user that +// userIDForAPIKey looks up the given APIKey and returns the Key of the user that // issued it. func (api *API) userIDForAPIKey(ctx context.Context, ak database.APIKey) (primitive.ObjectID, error) { akRec, err := api.staticDB.APIKeyGetRecord(ctx, ak) @@ -88,7 +91,7 @@ func (api *API) userIDForAPIKey(ctx context.Context, ak database.APIKey) (primit } // userIDForPubAPIKey looks up the given PubAPIKey, validates that the target -// skylink is covered by it, and returns the ID of the user that issued the +// skylink is covered by it, and returns the Key of the user that issued the // PubAPIKey. func (api *API) userIDForPubAPIKey(ctx context.Context, pak database.PubAPIKey, sl string) (primitive.ObjectID, error) { pakRec, err := api.staticDB.PubAPIKeyGetRecord(ctx, pak) diff --git a/api/handlers.go b/api/handlers.go index b63eb2ae..5e1e6c0d 100644 --- a/api/handlers.go +++ b/api/handlers.go @@ -27,6 +27,9 @@ const ( // LimitBodySizeSmall defines a size limit for requests that we don't expect // to contain a lot of data. LimitBodySizeSmall = 4 * skynet.KiB + // LimitBodySizeLarge defines a size limit for requests that we expect to + // contain a lot of data. + LimitBodySizeLarge = 4 * skynet.MiB ) type ( diff --git a/api/pubapikeys.go b/api/pubapikeys.go new file mode 100644 index 00000000..01cf3393 --- /dev/null +++ b/api/pubapikeys.go @@ -0,0 +1,108 @@ +package api + +import ( + "net/http" + "strconv" + + "github.com/SkynetLabs/skynet-accounts/database" + "github.com/julienschmidt/httprouter" + "gitlab.com/NebulousLabs/errors" + "go.mongodb.org/mongo-driver/mongo" +) + +type ( + // PubAPIKeyPOST describes the request body for creating a new PubAPIKey + PubAPIKeyPOST struct { + Skylinks []string + } + // PubAPIKeyPUT describes the request body for updating a PubAPIKey + PubAPIKeyPUT struct { + Key database.PubAPIKey + Skylinks []string + } + // PubAPIKeyPATCH describes the request body for updating a PubAPIKey by + // providing only the requested changes + PubAPIKeyPATCH struct { + Key database.PubAPIKey + Add []string + Remove []string + } +) + +// userAPIKeyGET lists all PubAPI keys associated with the user. +func (api *API) userPubAPIKeyGET(u *database.User, w http.ResponseWriter, req *http.Request, _ httprouter.Params) { + paks, err := api.staticDB.PubAPIKeyList(req.Context(), *u) + if err != nil { + api.WriteError(w, err, http.StatusInternalServerError) + return + } + api.WriteJSON(w, paks) +} + +// userAPIKeyDELETE removes a PubAPI key. +func (api *API) userPubAPIKeyDELETE(u *database.User, w http.ResponseWriter, req *http.Request, ps httprouter.Params) { + pakID := ps.ByName("id") + err := api.staticDB.PubAPIKeyDelete(req.Context(), *u, pakID) + if err == mongo.ErrNoDocuments { + api.WriteError(w, err, http.StatusBadRequest) + return + } + if err != nil { + api.WriteError(w, err, http.StatusInternalServerError) + return + } + api.WriteSuccess(w) +} + +// userPubAPIKeyPOST creates a new PubAPI key for the user. +func (api *API) userPubAPIKeyPOST(u *database.User, w http.ResponseWriter, req *http.Request, _ httprouter.Params) { + var body PubAPIKeyPOST + err := parseRequestBodyJSON(req.Body, LimitBodySizeLarge, &body) + if err != nil { + api.WriteError(w, err, http.StatusBadRequest) + return + } + pakRec, err := api.staticDB.PubAPIKeyCreate(req.Context(), *u, body.Skylinks) + if errors.Contains(err, database.ErrMaxNumAPIKeysExceeded) { + err = errors.AddContext(err, "the maximum number of API keys a user can create is "+strconv.Itoa(database.MaxNumAPIKeysPerUser)) + api.WriteError(w, err, http.StatusBadRequest) + return + } + if err != nil { + api.WriteError(w, err, http.StatusInternalServerError) + return + } + api.WriteJSON(w, pakRec) +} + +// userPubAPIKeyPUT updates a PubAPI key. +func (api *API) userPubAPIKeyPUT(u *database.User, w http.ResponseWriter, req *http.Request, _ httprouter.Params) { + var body PubAPIKeyPUT + err := parseRequestBodyJSON(req.Body, LimitBodySizeLarge, &body) + if err != nil { + api.WriteError(w, err, http.StatusBadRequest) + return + } + err = api.staticDB.PubAPIKeyUpdate(req.Context(), *u, body.Key, body.Skylinks) + if err != nil { + api.WriteError(w, err, http.StatusInternalServerError) + return + } + api.WriteSuccess(w) +} + +// userPubAPIKeyPATCH patches a PubAPI key. The difference between PUT and PATCH is +func (api *API) userPubAPIKeyPATCH(u *database.User, w http.ResponseWriter, req *http.Request, _ httprouter.Params) { + var body PubAPIKeyPATCH + err := parseRequestBodyJSON(req.Body, LimitBodySizeLarge, &body) + if err != nil { + api.WriteError(w, err, http.StatusBadRequest) + return + } + err = api.staticDB.PubAPIKeyPatch(req.Context(), *u, body.Key, body.Add, body.Remove) + if err != nil { + api.WriteError(w, err, http.StatusInternalServerError) + return + } + api.WriteSuccess(w) +} diff --git a/api/routes.go b/api/routes.go index 34fd3917..9fecdeca 100644 --- a/api/routes.go +++ b/api/routes.go @@ -65,6 +65,13 @@ func (api *API) buildHTTPRoutes() { api.staticRouter.GET("/user/apikeys", api.withAuth(api.userAPIKeyGET)) api.staticRouter.DELETE("/user/apikeys/:id", api.withAuth(api.userAPIKeyDELETE)) + // Endpoints for user public API keys. + api.staticRouter.POST("/user/pubapikeys", api.WithDBSession(api.withAuth(api.userPubAPIKeyPOST))) + api.staticRouter.GET("/user/pubapikeys", api.withAuth(api.userPubAPIKeyGET)) + api.staticRouter.PUT("/user/pubapikeys", api.WithDBSession(api.withAuth(api.userPubAPIKeyPUT))) + api.staticRouter.PATCH("/user/pubapikeys", api.WithDBSession(api.withAuth(api.userPubAPIKeyPATCH))) + api.staticRouter.DELETE("/user/pubapikeys/:id", api.withAuth(api.userPubAPIKeyDELETE)) + // Endpoints for email communication with the user. api.staticRouter.GET("/user/confirm", api.WithDBSession(api.noAuth(api.userConfirmGET))) // TODO POST api.staticRouter.POST("/user/reconfirm", api.WithDBSession(api.withAuth(api.userReconfirmPOST))) diff --git a/database/publicapikeys.go b/database/publicapikeys.go index 17449288..3d686a1c 100644 --- a/database/publicapikeys.go +++ b/database/publicapikeys.go @@ -60,23 +60,23 @@ func (db *DB) PubAPIKeyCreate(ctx context.Context, user User, skylinks []string) return nil, ErrInvalidSkylink } } - ak := PubAPIKeyRecord{ + pakRec := PubAPIKeyRecord{ UserID: user.ID, Key: PubAPIKey(base64.URLEncoding.EncodeToString(fastrand.Bytes(PubKeySize))), Skylinks: skylinks, CreatedAt: time.Now().UTC(), } - ior, err := db.staticAPIKeys.InsertOne(ctx, ak) + ior, err := db.staticAPIKeys.InsertOne(ctx, pakRec) if err != nil { return nil, err } - ak.ID = ior.InsertedID.(primitive.ObjectID) - return &ak, nil + pakRec.ID = ior.InsertedID.(primitive.ObjectID) + return &pakRec, nil } // PubAPIKeyUpdate updates an existing PubAPIKey. This works by replacing the // list of Skylinks within the PubAPIKey record. -func (db *DB) PubAPIKeyUpdate(ctx context.Context, user User, keyID primitive.ObjectID, skylinks []string) error { +func (db *DB) PubAPIKeyUpdate(ctx context.Context, user User, pak PubAPIKey, skylinks []string) error { if user.ID.IsZero() { return errors.New("invalid user") } @@ -87,7 +87,7 @@ func (db *DB) PubAPIKeyUpdate(ctx context.Context, user User, keyID primitive.Ob } } filter := bson.M{ - "_id": keyID, + "key": pak, "user_id": user.ID, } update := bson.M{"skylinks": skylinks} @@ -98,12 +98,56 @@ func (db *DB) PubAPIKeyUpdate(ctx context.Context, user User, keyID primitive.Ob return err } +// PubAPIKeyPatch updates an existing PubAPIKey. This works by adding and +// removing specific elements directly in Mongo. +func (db *DB) PubAPIKeyPatch(ctx context.Context, user User, pak PubAPIKey, addSkylinks, removeSkylinks []string) error { + if user.ID.IsZero() { + return errors.New("invalid user") + } + // Validate all given skylinks. + for _, s := range append(addSkylinks, removeSkylinks...) { + if !ValidSkylinkHash(s) { + return ErrInvalidSkylink + } + } + var filter, update bson.M + // First, all new skylinks to the record. + if len(addSkylinks) > 0 { + filter = bson.M{"key": pak} + update = bson.M{ + "$push": bson.M{"skylinks": bson.M{"$each": addSkylinks}}, + } + opts := options.UpdateOptions{ + Upsert: &False, + } + _, err := db.staticPubAPIKeys.UpdateOne(ctx, filter, update, &opts) + if err != nil { + return err + } + } + // Then, remove all skylinks that need to be removed. + if len(removeSkylinks) > 0 { + filter = bson.M{"key": pak} + update = bson.M{ + "pull": bson.M{"skylinks": bson.M{"$in": addSkylinks}}, + } + opts := options.UpdateOptions{ + Upsert: &False, + } + _, err := db.staticPubAPIKeys.UpdateOne(ctx, filter, update, &opts) + if err != nil { + return err + } + } + return nil +} + // PubAPIKeyDelete deletes a public API key. -func (db *DB) PubAPIKeyDelete(ctx context.Context, user User, akID string) error { +func (db *DB) PubAPIKeyDelete(ctx context.Context, user User, pakID string) error { if user.ID.IsZero() { return errors.New("invalid user") } - id, err := primitive.ObjectIDFromHex(akID) + id, err := primitive.ObjectIDFromHex(pakID) if err != nil { return errors.AddContext(err, "invalid API key ID") } @@ -144,10 +188,10 @@ func (db *DB) PubAPIKeyList(ctx context.Context, user User) ([]*PubAPIKeyRecord, if err != nil { return nil, err } - var aks []*PubAPIKeyRecord - err = c.All(ctx, &aks) + var paks []*PubAPIKeyRecord + err = c.All(ctx, &paks) if err != nil { return nil, err } - return aks, nil + return paks, nil } diff --git a/main.go b/main.go index 4132dc9f..6d8d95e1 100644 --- a/main.go +++ b/main.go @@ -140,7 +140,7 @@ func parseConfiguration(logger *logrus.Logger) (ServiceConfig, error) { config.ServerLockID = config.PortalName logger.Warningf(`Environment variable %s is missing! This server's identity`+ ` is set to the default '%s' value. That is OK only if this server is running on its own`+ - ` and it's not sharing its DB with other nodes.\n`, envServerDomain, config.ServerLockID) + ` and it's not sharing its DB with other nodes.`, envServerDomain, config.ServerLockID) } if sk := os.Getenv(envStripeAPIKey); sk != "" { From 61fed0fdd65417af13758706842cf701ae3ed94a Mon Sep 17 00:00:00 2001 From: Ivaylo Novakov Date: Wed, 2 Mar 2022 15:28:59 +0100 Subject: [PATCH 06/25] Add a custom endpoint for checking speed limits based on public API keys - `GET /user/limits/:skylink`. --- api/auth.go | 1 + api/handlers.go | 67 ++++++++++++++++++++++++++++++++++++++- api/routes.go | 5 ++- database/publicapikeys.go | 3 -- 4 files changed, 71 insertions(+), 5 deletions(-) diff --git a/api/auth.go b/api/auth.go index 1541af93..bf8989af 100644 --- a/api/auth.go +++ b/api/auth.go @@ -107,6 +107,7 @@ func (api *API) userIDForPubAPIKey(ctx context.Context, pak database.PubAPIKey, } // apiKeyFromRequest extracts the API key from the request and returns it. +// This function does not differentiate between APIKey and PubAPIKey. // It first checks the headers and then the query. func apiKeyFromRequest(r *http.Request) (string, error) { // Check the headers for an API key. diff --git a/api/handlers.go b/api/handlers.go index 5e1e6c0d..b5f74f3c 100644 --- a/api/handlers.go +++ b/api/handlers.go @@ -417,7 +417,7 @@ func (api *API) userLimitsGET(_ *database.User, w http.ResponseWriter, req *http api.WriteJSON(w, resp) return } - // Cache is missed, fetch the data from the DB. + // Cache is missed, fetch the owner of this APIKey from the DB. ak := database.APIKey(akStr) if !ak.IsValid() { api.staticLogger.Traceln("Invalid API key.") @@ -481,6 +481,71 @@ func (api *API) userLimitsGET(_ *database.User, w http.ResponseWriter, req *http api.WriteJSON(w, resp) } +// userLimitsSkylinkGET returns the speed limits which apply to a GET call to +// the given skylink. This method exists to accommodate public API keys. +// +// NOTE: This handler needs to use the noAuth middleware in order to be able to +// optimise its calls to the DB and the use of caching. +func (api *API) userLimitsSkylinkGET(u *database.User, w http.ResponseWriter, req *http.Request, ps httprouter.Params) { + respAnon := UserLimitsGET{ + TierID: database.TierAnonymous, + TierLimits: database.UserLimits[database.TierAnonymous], + } + // Validate the skylink. + skylink := ps.ByName("skylink") + if !database.ValidSkylinkHash(skylink) { + api.staticLogger.Tracef("Invalid skylink: %s", skylink) + api.WriteJSON(w, respAnon) + return + } + // Try to fetch an API attached to the request. + pakStr, err := apiKeyFromRequest(req) + if err != nil { + // We failed to fetch an API key from this request but the request might + // be authenticated in another way, so we'll defer to userLimitsGET. + api.userLimitsGET(u, w, req, ps) + return + } + // Check if that is a valid PubAPIKey. + pak := database.PubAPIKey(pakStr) + if !pak.IsValid() { + // This is not a valid PubAPIKey. Defer to userLimitsGET. + api.userLimitsGET(u, w, req, ps) + return + } + // Check the cache before hitting the database. + tier, ok := api.staticUserTierCache.Get(pakStr + skylink) + if ok { + api.staticLogger.Traceln("Fetching user limits from cache by API key.") + resp := UserLimitsGET{ + TierID: tier, + TierLimits: database.UserLimits[tier], + } + api.WriteJSON(w, resp) + return + } + // Get the owner of this PubAPIKey from the database. + uID, err := api.userIDForPubAPIKey(req.Context(), pak, skylink) + if err != nil { + api.staticLogger.Tracef("Failed to get user ID for this PubAPIKey: %v", err) + api.WriteJSON(w, respAnon) + return + } + user, err := api.staticDB.UserByID(req.Context(), uID) + if err != nil { + api.staticLogger.Tracef("Failed to get user for user ID: %v", err) + api.WriteJSON(w, respAnon) + return + } + // Store the user in the cache with a custom key. + api.staticUserTierCache.Set(user, pakStr+skylink) + resp := UserLimitsGET{ + TierID: user.Tier, + TierLimits: database.UserLimits[user.Tier], + } + api.WriteJSON(w, resp) +} + // userStatsGET returns statistics about an existing user. func (api *API) userStatsGET(u *database.User, w http.ResponseWriter, req *http.Request, _ httprouter.Params) { us, err := api.staticDB.UserStats(req.Context(), *u) diff --git a/api/routes.go b/api/routes.go index 9fecdeca..bdf8d65b 100644 --- a/api/routes.go +++ b/api/routes.go @@ -53,6 +53,7 @@ func (api *API) buildHTTPRoutes() { api.staticRouter.PUT("/user", api.WithDBSession(api.withAuth(api.userPUT))) api.staticRouter.DELETE("/user", api.withAuth(api.userDELETE)) api.staticRouter.GET("/user/limits", api.noAuth(api.userLimitsGET)) + api.staticRouter.GET("/user/limits/:skylink", api.noAuth(api.userLimitsSkylinkGET)) api.staticRouter.GET("/user/stats", api.withAuth(api.userStatsGET)) api.staticRouter.GET("/user/pubkey/register", api.WithDBSession(api.withAuth(api.userPubKeyRegisterGET))) api.staticRouter.POST("/user/pubkey/register", api.WithDBSession(api.withAuth(api.userPubKeyRegisterPOST))) @@ -125,7 +126,9 @@ func (api *API) withAuth(h HandlerWithUser) httprouter.Handle { // logRequest logs information about the current request. func (api *API) logRequest(r *http.Request) { hasAuth := strings.HasPrefix(r.Header.Get("Authorization"), "Bearer") + hasAPIKey := r.Header.Get(APIKeyHeader) != "" || r.FormValue("apiKey") != "" c, err := r.Cookie(CookieName) hasCookie := err == nil && c != nil - api.staticLogger.Tracef("Processing request: %v %v, Auth: %v, Skynet Cookie: %v, Referer: %v, Host: %v, RemoreAddr: %v", r.Method, r.URL, hasAuth, hasCookie, r.Referer(), r.Host, r.RemoteAddr) + api.staticLogger.Tracef("Processing request: %v %v, Auth: %v, API Key: %v, Cookie: %v, Referer: %v, Host: %v, RemoreAddr: %v", + r.Method, r.URL, hasAuth, hasAPIKey, hasCookie, r.Referer(), r.Host, r.RemoteAddr) } diff --git a/database/publicapikeys.go b/database/publicapikeys.go index 3d686a1c..92934c9e 100644 --- a/database/publicapikeys.go +++ b/database/publicapikeys.go @@ -14,9 +14,6 @@ import ( ) type ( - // TODO: I am still not sure whether we should use separate collections or - // keep all API keys in the same one. - // PubAPIKey is a base64URL-encoded representation of []byte with length // PubKeySize PubAPIKey string From a68a2f2a96c338a0c5a47c11281b4ef553873206 Mon Sep 17 00:00:00 2001 From: Ivaylo Novakov Date: Wed, 2 Mar 2022 15:39:39 +0100 Subject: [PATCH 07/25] Refactor the user tier cache to always require a key and pass it before the cached value. --- api/cache.go | 12 +++--------- api/cache_test.go | 6 +++--- api/handlers.go | 8 ++++---- api/stripe.go | 2 +- 4 files changed, 11 insertions(+), 17 deletions(-) diff --git a/api/cache.go b/api/cache.go index 8a0606b5..5aa79bfa 100644 --- a/api/cache.go +++ b/api/cache.go @@ -46,10 +46,8 @@ func (utc *userTierCache) Get(sub string) (int, bool) { return ce.Tier, true } -// Set stores the user's tier in the cache. If the customCacheKey is not empty, -// it will be used to store the user in the cache, otherwise the user's sub will -// be used. -func (utc *userTierCache) Set(u *database.User, customCacheKey string) { +// Set stores the user's tier in the cache under the given key. +func (utc *userTierCache) Set(key string, u *database.User) { var ce userTierCacheEntry now := time.Now().UTC() if u.QuotaExceeded { @@ -64,10 +62,6 @@ func (utc *userTierCache) Set(u *database.User, customCacheKey string) { } } utc.mu.Lock() - if customCacheKey == "" { - utc.cache[u.Sub] = ce - } else { - utc.cache[customCacheKey] = ce - } + utc.cache[key] = ce utc.mu.Unlock() } diff --git a/api/cache_test.go b/api/cache_test.go index 835651ca..665c8804 100644 --- a/api/cache_test.go +++ b/api/cache_test.go @@ -24,7 +24,7 @@ func TestUserTierCache(t *testing.T) { t.Fatalf("Expected to get tier %d and %t, got %d and %t.", database.TierAnonymous, false, tier, ok) } // Set the use in the cache. - cache.Set(u, "") + cache.Set(u.Sub, u) // Check again. tier, ok = cache.Get(u.Sub) if !ok || tier != u.Tier { @@ -43,7 +43,7 @@ func TestUserTierCache(t *testing.T) { timeToMonthRollover := 30 * time.Minute u.SubscribedUntil = time.Now().UTC().Add(timeToMonthRollover) // Update the cache. - cache.Set(u, "") + cache.Set(u.Sub, u) // Expect the cache entry's ExpiresAt to be after 30 minutes. timeIn30 := time.Now().UTC().Add(time.Hour - timeToMonthRollover) if ce.ExpiresAt.After(timeIn30) && ce.ExpiresAt.Before(timeIn30.Add(time.Second)) { @@ -61,7 +61,7 @@ func TestUserTierCache(t *testing.T) { t.Fatal("Did not expect to get a cache entry!") } // Update the cache with a custom key. - cache.Set(u, string(ak)) + cache.Set(string(ak), u) // Fetch the data for the custom key. tier, ok = cache.Get(string(ak)) if !ok { diff --git a/api/handlers.go b/api/handlers.go index b5f74f3c..8abfd595 100644 --- a/api/handlers.go +++ b/api/handlers.go @@ -437,7 +437,7 @@ func (api *API) userLimitsGET(_ *database.User, w http.ResponseWriter, req *http return } // Cache the user under the API key they used. - api.staticUserTierCache.Set(u, akStr) + api.staticUserTierCache.Set(akStr, u) resp := UserLimitsGET{ TierID: u.Tier, TierLimits: database.UserLimits[u.Tier], @@ -468,7 +468,7 @@ func (api *API) userLimitsGET(_ *database.User, w http.ResponseWriter, req *http api.WriteJSON(w, respAnon) return } - api.staticUserTierCache.Set(u, "") + api.staticUserTierCache.Set(u.Sub, u) } tier, ok = api.staticUserTierCache.Get(sub) if !ok { @@ -538,7 +538,7 @@ func (api *API) userLimitsSkylinkGET(u *database.User, w http.ResponseWriter, re return } // Store the user in the cache with a custom key. - api.staticUserTierCache.Set(user, pakStr+skylink) + api.staticUserTierCache.Set(pakStr+skylink, user) resp := UserLimitsGET{ TierID: user.Tier, TierLimits: database.UserLimits[user.Tier], @@ -1224,7 +1224,7 @@ func (api *API) checkUserQuotas(ctx context.Context, u *database.User) { if err != nil { api.staticLogger.Warnf("Failed to save user. User: %+v, err: %s", u, err.Error()) } - api.staticUserTierCache.Set(u, "") + api.staticUserTierCache.Set(u.Sub, u) } } diff --git a/api/stripe.go b/api/stripe.go index be549076..3fdd3220 100644 --- a/api/stripe.go +++ b/api/stripe.go @@ -247,7 +247,7 @@ func (api *API) processStripeSub(ctx context.Context, s *stripe.Subscription) er api.staticLogger.Tracef("Subscribed user id '%s', tier %d, until %s.", u.ID, u.Tier, u.SubscribedUntil.String()) } // Re-set the tier cache for this user, in case their tier changed. - api.staticUserTierCache.Set(u, "") + api.staticUserTierCache.Set(u.Sub, u) return err } From e7a4beefcff35a9f57960f03d198d545cdbe6f5c Mon Sep 17 00:00:00 2001 From: Ivaylo Novakov Date: Thu, 3 Mar 2022 18:16:16 +0100 Subject: [PATCH 08/25] Merge public and private API keys into one. --- api/apikeys.go | 178 +++++++++++++++++++++++++++++++--- api/auth.go | 71 +++----------- api/handlers.go | 49 ++++++---- api/pubapikeys.go | 108 --------------------- api/routes.go | 12 +-- database/apikeys.go | 152 ++++++++++++++++++++++++++--- database/database.go | 15 --- database/publicapikeys.go | 194 -------------------------------------- test/api/handlers_test.go | 6 +- 9 files changed, 352 insertions(+), 433 deletions(-) delete mode 100644 api/pubapikeys.go delete mode 100644 database/publicapikeys.go diff --git a/api/apikeys.go b/api/apikeys.go index 063d0518..2ec63690 100644 --- a/api/apikeys.go +++ b/api/apikeys.go @@ -3,16 +3,95 @@ package api import ( "net/http" "strconv" + "time" "github.com/SkynetLabs/skynet-accounts/database" "github.com/julienschmidt/httprouter" "gitlab.com/NebulousLabs/errors" + "go.mongodb.org/mongo-driver/bson/primitive" "go.mongodb.org/mongo-driver/mongo" ) +type ( + // apiKeyPOST describes the body of a POST request that creates an API key + apiKeyPOST struct { + Public bool + Skylinks []string + } + // apiKeyPUT describes the request body for updating an API key + apiKeyPUT struct { + Skylinks []string + } + // apiKeyPATCH describes the request body for updating an API key by + // providing only the requested changes + apiKeyPATCH struct { + Add []string + Remove []string + } + // apiKeyResponse is an API DTO which mirrors database.APIKey. + // TODO Should we reveal the Key each time for public keys? + apiKeyResponse struct { + ID primitive.ObjectID `json:"id"` + UserID primitive.ObjectID `json:"-"` + Public bool `json:"public"` + Key database.APIKey `json:"-"` + Skylinks []string `json:"skylinks"` + CreatedAt time.Time `json:"createdAt"` + } + // apiKeyResponseWithKey is an API DTO which mirrors database.APIKey but + // also reveals the value of the Key field. This should only be used on key + // creation. + // TODO Should we reveal the Key each time for public keys? + apiKeyResponseWithKey struct { + apiKeyResponse + Key database.APIKey `json:"key"` + } +) + +// Valid checks if the request and its parts are valid. +func (akp apiKeyPOST) Valid() bool { + if !akp.Public && len(akp.Skylinks) > 0 { + return false + } + for _, s := range akp.Skylinks { + if !database.ValidSkylinkHash(s) { + return false + } + } + return true +} + +// FromAPIKey populates the struct's fields from the given API key. +// TODO This might be more convenient as a constructor. +func (rwk *apiKeyResponse) FromAPIKey(ak database.APIKeyRecord) { + rwk.ID = ak.ID + rwk.UserID = ak.UserID + rwk.Public = ak.Public + rwk.Key = ak.Key + rwk.Skylinks = ak.Skylinks + rwk.CreatedAt = ak.CreatedAt +} + +// FromAPIKey populates the struct's fields from the given API key. +// TODO This might be more convenient as a constructor. +func (rwk *apiKeyResponseWithKey) FromAPIKey(ak database.APIKeyRecord) { + rwk.ID = ak.ID + rwk.UserID = ak.UserID + rwk.Public = ak.Public + rwk.Key = ak.Key + rwk.Skylinks = ak.Skylinks + rwk.CreatedAt = ak.CreatedAt +} + // userAPIKeyPOST creates a new API key for the user. func (api *API) userAPIKeyPOST(u *database.User, w http.ResponseWriter, req *http.Request, _ httprouter.Params) { - ak, err := api.staticDB.APIKeyCreate(req.Context(), *u) + var body apiKeyPOST + err := parseRequestBodyJSON(req.Body, LimitBodySizeLarge, &body) + if err != nil { + api.WriteError(w, err, http.StatusBadRequest) + return + } + ak, err := api.staticDB.APIKeyCreate(req.Context(), *u, body.Public, body.Skylinks) if errors.Contains(err, database.ErrMaxNumAPIKeysExceeded) { err = errors.AddContext(err, "the maximum number of API keys a user can create is "+strconv.Itoa(database.MaxNumAPIKeysPerUser)) api.WriteError(w, err, http.StatusBadRequest) @@ -22,32 +101,57 @@ func (api *API) userAPIKeyPOST(u *database.User, w http.ResponseWriter, req *htt api.WriteError(w, err, http.StatusInternalServerError) return } - // Make the Key visible in JSON form. We do that with an anonymous struct - // because we don't envision that being needed anywhere else in the project. - akWithKey := struct { - database.APIKeyRecord - Key database.APIKey `bson:"key" json:"key"` - }{ - *ak, - ak.Key, + var resp apiKeyResponseWithKey + resp.FromAPIKey(*ak) + api.WriteJSON(w, resp) +} + +// userAPIKeyGET returns a single API key. +func (api *API) userAPIKeyGET(u *database.User, w http.ResponseWriter, req *http.Request, ps httprouter.Params) { + akID, err := primitive.ObjectIDFromHex(ps.ByName("id")) + if err != nil { + api.WriteError(w, err, http.StatusBadRequest) + return + } + ak, err := api.staticDB.APIKeyGet(req.Context(), akID) + // If there is no such API key or it doesn't exist, return a 404. + if errors.Contains(err, mongo.ErrNoDocuments) || (err == nil && ak.UserID != u.ID) { + api.WriteError(w, nil, http.StatusNotFound) + return + } + if err != nil { + api.WriteError(w, err, http.StatusInternalServerError) + return } - api.WriteJSON(w, akWithKey) + var resp apiKeyResponse + resp.FromAPIKey(ak) + api.WriteJSON(w, resp) } -// userAPIKeyGET lists all API keys associated with the user. -func (api *API) userAPIKeyGET(u *database.User, w http.ResponseWriter, req *http.Request, _ httprouter.Params) { +// userAPIKeyLIST lists all API keys associated with the user. +func (api *API) userAPIKeyLIST(u *database.User, w http.ResponseWriter, req *http.Request, _ httprouter.Params) { aks, err := api.staticDB.APIKeyList(req.Context(), *u) if err != nil { api.WriteError(w, err, http.StatusInternalServerError) return } - api.WriteJSON(w, aks) + resp := make([]apiKeyResponse, 0, len(aks)) + for _, ak := range aks { + var r apiKeyResponse + r.FromAPIKey(ak) + resp = append(resp, r) + } + api.WriteJSON(w, resp) } // userAPIKeyDELETE removes an API key. func (api *API) userAPIKeyDELETE(u *database.User, w http.ResponseWriter, req *http.Request, ps httprouter.Params) { - akID := ps.ByName("id") - err := api.staticDB.APIKeyDelete(req.Context(), *u, akID) + akID, err := primitive.ObjectIDFromHex(ps.ByName("id")) + if err != nil { + api.WriteError(w, err, http.StatusBadRequest) + return + } + err = api.staticDB.APIKeyDelete(req.Context(), *u, akID) if err == mongo.ErrNoDocuments { api.WriteError(w, err, http.StatusBadRequest) return @@ -58,3 +162,47 @@ func (api *API) userAPIKeyDELETE(u *database.User, w http.ResponseWriter, req *h } api.WriteSuccess(w) } + +// userAPIKeyPUT updates an API key. Only possible for public API keys. +func (api *API) userAPIKeyPUT(u *database.User, w http.ResponseWriter, req *http.Request, ps httprouter.Params) { + akID, err := primitive.ObjectIDFromHex(ps.ByName("id")) + if err != nil { + api.WriteError(w, err, http.StatusBadRequest) + return + } + var body apiKeyPUT + err = parseRequestBodyJSON(req.Body, LimitBodySizeLarge, &body) + if err != nil { + api.WriteError(w, err, http.StatusBadRequest) + return + } + err = api.staticDB.APIKeyUpdate(req.Context(), *u, akID, body.Skylinks) + if err != nil { + api.WriteError(w, err, http.StatusInternalServerError) + return + } + api.WriteSuccess(w) +} + +// userAPIKeyPATCH patches an API key. The difference between PUT and PATCH is +// that PATCH only specifies the changes while PUT provides the expected list of +// covered skylinks. Only possible for public API keys. +func (api *API) userAPIKeyPATCH(u *database.User, w http.ResponseWriter, req *http.Request, ps httprouter.Params) { + akID, err := primitive.ObjectIDFromHex(ps.ByName("id")) + if err != nil { + api.WriteError(w, err, http.StatusBadRequest) + return + } + var body apiKeyPATCH + err = parseRequestBodyJSON(req.Body, LimitBodySizeLarge, &body) + if err != nil { + api.WriteError(w, err, http.StatusBadRequest) + return + } + err = api.staticDB.APIKeyPatch(req.Context(), *u, akID, body.Add, body.Remove) + if err != nil { + api.WriteError(w, err, http.StatusInternalServerError) + return + } + api.WriteSuccess(w) +} diff --git a/api/auth.go b/api/auth.go index bf8989af..08824aaa 100644 --- a/api/auth.go +++ b/api/auth.go @@ -1,7 +1,6 @@ package api import ( - "context" "net/http" "strings" @@ -9,12 +8,8 @@ import ( "github.com/SkynetLabs/skynet-accounts/jwt" jwt2 "github.com/lestrrat-go/jwx/jwt" "gitlab.com/NebulousLabs/errors" - "go.mongodb.org/mongo-driver/bson/primitive" ) -// TODO Test the methods here which are still untested. -// - add integration tests - // userAndTokenByRequestToken scans the request for an authentication token, // fetches the corresponding user from the database and returns both user and // token. @@ -34,8 +29,8 @@ func (api *API) userAndTokenByRequestToken(req *http.Request) (*database.User, j return u, token, nil } -// userAndTokenByAPIKey extracts the APIKey or PubAPIKey from the requests and -// validates it. It then returns the user who owns it and a token for that user. +// userAndTokenByAPIKey extracts the APIKey from the request and validates it. +// It then returns the user who owns it and a token for that user. // It first checks the headers and then the query. // This method accesses the database. func (api *API) userAndTokenByAPIKey(req *http.Request) (*database.User, jwt2.Token, error) { @@ -43,36 +38,24 @@ func (api *API) userAndTokenByAPIKey(req *http.Request) (*database.User, jwt2.To if err != nil { return nil, nil, err } - // We should only check for a PubAPIKey if this is a GET request for a valid - // skylink. We ignore the errors here because the API key might not be a - // public one. - if req.Method == http.MethodGet { - pak := database.PubAPIKey(akStr) - sl, err := database.ExtractSkylinkHash(req.RequestURI) - if err == nil && sl != "" && pak.IsValid() { - uID, err := api.userIDForPubAPIKey(req.Context(), pak, sl) - if err == nil { - return api.userAndTokenByUserID(req.Context(), uID) - } - } - } // Check if this is a valid APIKey. ak := database.APIKey(akStr) if !ak.IsValid() { return nil, nil, ErrInvalidAPIKey } - uID, err := api.userIDForAPIKey(req.Context(), ak) + akr, err := api.staticDB.APIKeyByKey(req.Context(), akStr) if err != nil { - return nil, nil, ErrInvalidAPIKey + return nil, nil, err } - return api.userAndTokenByUserID(req.Context(), uID) -} - -// userAndTokenByUserID is a helper method that fetches a given user from the -// database based on their Key, issues a JWT token for them, and returns both -// of those. -func (api *API) userAndTokenByUserID(ctx context.Context, uid primitive.ObjectID) (*database.User, jwt2.Token, error) { - u, err := api.staticDB.UserByID(ctx, uid) + // If we're dealing with a public API key, we need to validate that this + // request is a GET for a covered skylink. + if akr.Public { + sl, err := database.ExtractSkylinkHash(req.RequestURI) + if err != nil || !akr.CoversSkylink(sl) { + return nil, nil, ErrInvalidAPIKey + } + } + u, err := api.staticDB.UserByID(req.Context(), akr.UserID) if err != nil { return nil, nil, err } @@ -80,34 +63,8 @@ func (api *API) userAndTokenByUserID(ctx context.Context, uid primitive.ObjectID return u, t, err } -// userIDForAPIKey looks up the given APIKey and returns the Key of the user that -// issued it. -func (api *API) userIDForAPIKey(ctx context.Context, ak database.APIKey) (primitive.ObjectID, error) { - akRec, err := api.staticDB.APIKeyGetRecord(ctx, ak) - if err != nil { - return primitive.ObjectID{}, err - } - return akRec.UserID, nil -} - -// userIDForPubAPIKey looks up the given PubAPIKey, validates that the target -// skylink is covered by it, and returns the Key of the user that issued the -// PubAPIKey. -func (api *API) userIDForPubAPIKey(ctx context.Context, pak database.PubAPIKey, sl string) (primitive.ObjectID, error) { - pakRec, err := api.staticDB.PubAPIKeyGetRecord(ctx, pak) - if err != nil { - return primitive.ObjectID{}, err - } - for _, s := range pakRec.Skylinks { - if sl == s { - return pakRec.UserID, nil - } - } - return primitive.ObjectID{}, database.ErrUserNotFound -} - // apiKeyFromRequest extracts the API key from the request and returns it. -// This function does not differentiate between APIKey and PubAPIKey. +// This function does not differentiate between APIKey and APIKey. // It first checks the headers and then the query. func apiKeyFromRequest(r *http.Request) (string, error) { // Check the headers for an API key. diff --git a/api/handlers.go b/api/handlers.go index 8abfd595..c81c312f 100644 --- a/api/handlers.go +++ b/api/handlers.go @@ -424,13 +424,20 @@ func (api *API) userLimitsGET(_ *database.User, w http.ResponseWriter, req *http api.WriteJSON(w, respAnon) return } - uID, err := api.userIDForAPIKey(req.Context(), ak) + // Get the API key. + akr, err := api.staticDB.APIKeyByKey(req.Context(), akStr) if err != nil { - api.staticLogger.Traceln("Error while fetching user by API key:", err) + api.staticLogger.Trace("API key doesn't exist in the database.") + api.WriteJSON(w, respAnon) + return + } + if akr.Public { + api.staticLogger.Trace("API key is public, cannot be used for general requests") api.WriteJSON(w, respAnon) return } - u, err := api.staticDB.UserByID(req.Context(), uID) + // Get the owner of this API key from the database. + u, err := api.staticDB.UserByID(req.Context(), akr.UserID) if err != nil { api.staticLogger.Traceln("Error while fetching user by API key:", err) api.WriteJSON(w, respAnon) @@ -499,22 +506,22 @@ func (api *API) userLimitsSkylinkGET(u *database.User, w http.ResponseWriter, re return } // Try to fetch an API attached to the request. - pakStr, err := apiKeyFromRequest(req) + akStr, err := apiKeyFromRequest(req) if err != nil { // We failed to fetch an API key from this request but the request might // be authenticated in another way, so we'll defer to userLimitsGET. api.userLimitsGET(u, w, req, ps) return } - // Check if that is a valid PubAPIKey. - pak := database.PubAPIKey(pakStr) - if !pak.IsValid() { - // This is not a valid PubAPIKey. Defer to userLimitsGET. + // Check if that is a valid API key. + ak := database.APIKey(akStr) + if !ak.IsValid() { + // This is not a valid APIKey. Defer to userLimitsGET. api.userLimitsGET(u, w, req, ps) return } // Check the cache before hitting the database. - tier, ok := api.staticUserTierCache.Get(pakStr + skylink) + tier, ok := api.staticUserTierCache.Get(akStr + skylink) if ok { api.staticLogger.Traceln("Fetching user limits from cache by API key.") resp := UserLimitsGET{ @@ -524,21 +531,27 @@ func (api *API) userLimitsSkylinkGET(u *database.User, w http.ResponseWriter, re api.WriteJSON(w, resp) return } - // Get the owner of this PubAPIKey from the database. - uID, err := api.userIDForPubAPIKey(req.Context(), pak, skylink) + // Get the API key. + akr, err := api.staticDB.APIKeyByKey(req.Context(), akStr) if err != nil { - api.staticLogger.Tracef("Failed to get user ID for this PubAPIKey: %v", err) + api.staticLogger.Trace("API key doesn't exist in the database.") + api.WriteJSON(w, respAnon) + return + } + if !akr.CoversSkylink(skylink) { + api.staticLogger.Trace("API key doesn't cover this skylink.") api.WriteJSON(w, respAnon) return } - user, err := api.staticDB.UserByID(req.Context(), uID) + // Get the owner of this API key from the database. + user, err := api.staticDB.UserByID(req.Context(), akr.UserID) if err != nil { api.staticLogger.Tracef("Failed to get user for user ID: %v", err) api.WriteJSON(w, respAnon) return } // Store the user in the cache with a custom key. - api.staticUserTierCache.Set(pakStr+skylink, user) + api.staticUserTierCache.Set(akStr+skylink, user) resp := UserLimitsGET{ TierID: user.Tier, TierLimits: database.UserLimits[user.Tier], @@ -1267,8 +1280,8 @@ func fetchPageSize(form url.Values) (int, error) { } // parseRequestBodyJSON reads a limited portion of the body and decodes it into -// the given obj. The purpose of this is to prevent DoS attacks that rely on -// excessively large request bodies. -func parseRequestBodyJSON(body io.ReadCloser, maxBodySize int64, objRef interface{}) error { - return json.NewDecoder(io.LimitReader(body, maxBodySize)).Decode(&objRef) +// the given struct v. The purpose of this is to prevent DoS attacks that rely +// on excessively large request bodies. +func parseRequestBodyJSON(body io.ReadCloser, maxBodySize int64, v interface{}) error { + return json.NewDecoder(io.LimitReader(body, maxBodySize)).Decode(&v) } diff --git a/api/pubapikeys.go b/api/pubapikeys.go deleted file mode 100644 index 01cf3393..00000000 --- a/api/pubapikeys.go +++ /dev/null @@ -1,108 +0,0 @@ -package api - -import ( - "net/http" - "strconv" - - "github.com/SkynetLabs/skynet-accounts/database" - "github.com/julienschmidt/httprouter" - "gitlab.com/NebulousLabs/errors" - "go.mongodb.org/mongo-driver/mongo" -) - -type ( - // PubAPIKeyPOST describes the request body for creating a new PubAPIKey - PubAPIKeyPOST struct { - Skylinks []string - } - // PubAPIKeyPUT describes the request body for updating a PubAPIKey - PubAPIKeyPUT struct { - Key database.PubAPIKey - Skylinks []string - } - // PubAPIKeyPATCH describes the request body for updating a PubAPIKey by - // providing only the requested changes - PubAPIKeyPATCH struct { - Key database.PubAPIKey - Add []string - Remove []string - } -) - -// userAPIKeyGET lists all PubAPI keys associated with the user. -func (api *API) userPubAPIKeyGET(u *database.User, w http.ResponseWriter, req *http.Request, _ httprouter.Params) { - paks, err := api.staticDB.PubAPIKeyList(req.Context(), *u) - if err != nil { - api.WriteError(w, err, http.StatusInternalServerError) - return - } - api.WriteJSON(w, paks) -} - -// userAPIKeyDELETE removes a PubAPI key. -func (api *API) userPubAPIKeyDELETE(u *database.User, w http.ResponseWriter, req *http.Request, ps httprouter.Params) { - pakID := ps.ByName("id") - err := api.staticDB.PubAPIKeyDelete(req.Context(), *u, pakID) - if err == mongo.ErrNoDocuments { - api.WriteError(w, err, http.StatusBadRequest) - return - } - if err != nil { - api.WriteError(w, err, http.StatusInternalServerError) - return - } - api.WriteSuccess(w) -} - -// userPubAPIKeyPOST creates a new PubAPI key for the user. -func (api *API) userPubAPIKeyPOST(u *database.User, w http.ResponseWriter, req *http.Request, _ httprouter.Params) { - var body PubAPIKeyPOST - err := parseRequestBodyJSON(req.Body, LimitBodySizeLarge, &body) - if err != nil { - api.WriteError(w, err, http.StatusBadRequest) - return - } - pakRec, err := api.staticDB.PubAPIKeyCreate(req.Context(), *u, body.Skylinks) - if errors.Contains(err, database.ErrMaxNumAPIKeysExceeded) { - err = errors.AddContext(err, "the maximum number of API keys a user can create is "+strconv.Itoa(database.MaxNumAPIKeysPerUser)) - api.WriteError(w, err, http.StatusBadRequest) - return - } - if err != nil { - api.WriteError(w, err, http.StatusInternalServerError) - return - } - api.WriteJSON(w, pakRec) -} - -// userPubAPIKeyPUT updates a PubAPI key. -func (api *API) userPubAPIKeyPUT(u *database.User, w http.ResponseWriter, req *http.Request, _ httprouter.Params) { - var body PubAPIKeyPUT - err := parseRequestBodyJSON(req.Body, LimitBodySizeLarge, &body) - if err != nil { - api.WriteError(w, err, http.StatusBadRequest) - return - } - err = api.staticDB.PubAPIKeyUpdate(req.Context(), *u, body.Key, body.Skylinks) - if err != nil { - api.WriteError(w, err, http.StatusInternalServerError) - return - } - api.WriteSuccess(w) -} - -// userPubAPIKeyPATCH patches a PubAPI key. The difference between PUT and PATCH is -func (api *API) userPubAPIKeyPATCH(u *database.User, w http.ResponseWriter, req *http.Request, _ httprouter.Params) { - var body PubAPIKeyPATCH - err := parseRequestBodyJSON(req.Body, LimitBodySizeLarge, &body) - if err != nil { - api.WriteError(w, err, http.StatusBadRequest) - return - } - err = api.staticDB.PubAPIKeyPatch(req.Context(), *u, body.Key, body.Add, body.Remove) - if err != nil { - api.WriteError(w, err, http.StatusInternalServerError) - return - } - api.WriteSuccess(w) -} diff --git a/api/routes.go b/api/routes.go index bdf8d65b..d9194f62 100644 --- a/api/routes.go +++ b/api/routes.go @@ -63,16 +63,12 @@ func (api *API) buildHTTPRoutes() { // Endpoints for user API keys. api.staticRouter.POST("/user/apikeys", api.WithDBSession(api.withAuth(api.userAPIKeyPOST))) - api.staticRouter.GET("/user/apikeys", api.withAuth(api.userAPIKeyGET)) + api.staticRouter.GET("/user/apikeys", api.withAuth(api.userAPIKeyLIST)) + api.staticRouter.GET("/user/apikeys/:id", api.withAuth(api.userAPIKeyGET)) + api.staticRouter.PUT("/user/apikeys/:id", api.WithDBSession(api.withAuth(api.userAPIKeyPUT))) + api.staticRouter.PATCH("/user/apikeys/:id", api.WithDBSession(api.withAuth(api.userAPIKeyPATCH))) api.staticRouter.DELETE("/user/apikeys/:id", api.withAuth(api.userAPIKeyDELETE)) - // Endpoints for user public API keys. - api.staticRouter.POST("/user/pubapikeys", api.WithDBSession(api.withAuth(api.userPubAPIKeyPOST))) - api.staticRouter.GET("/user/pubapikeys", api.withAuth(api.userPubAPIKeyGET)) - api.staticRouter.PUT("/user/pubapikeys", api.WithDBSession(api.withAuth(api.userPubAPIKeyPUT))) - api.staticRouter.PATCH("/user/pubapikeys", api.WithDBSession(api.withAuth(api.userPubAPIKeyPATCH))) - api.staticRouter.DELETE("/user/pubapikeys/:id", api.withAuth(api.userPubAPIKeyDELETE)) - // Endpoints for email communication with the user. api.staticRouter.GET("/user/confirm", api.WithDBSession(api.noAuth(api.userConfirmGET))) // TODO POST api.staticRouter.POST("/user/reconfirm", api.WithDBSession(api.withAuth(api.userReconfirmPOST))) diff --git a/database/apikeys.go b/database/apikeys.go index 2496a653..3b733a4e 100644 --- a/database/apikeys.go +++ b/database/apikeys.go @@ -10,13 +10,25 @@ import ( "go.mongodb.org/mongo-driver/bson" "go.mongodb.org/mongo-driver/bson/primitive" "go.mongodb.org/mongo-driver/mongo" + "go.mongodb.org/mongo-driver/mongo/options" ) /** API keys are authentication tokens generated by users. They do not expire, thus allowing users to use them for a long time and to embed them in apps and on machines. API keys can be revoked when they are no longer needed or if they get -compromised. This is done by deleting them from this service. +compromised or are no longer needed. This is done by deleting them from this +service. + +There are two kinds of API keys - public and private. We differentiate between +them by the `public` flag. + +Private API keys give full API access - using them is equivalent to using a JWT +token, either via an authorization header or a cookie. + +Public API keys can only be use for downloading skylinks. The list of skylinks +that can be downloaded by a given public API key is stored under the `skylinks` +array within the API key record. */ var ( @@ -28,17 +40,25 @@ var ( // ErrMaxNumAPIKeysExceeded is returned when a user tries to create a new // API key after already having the maximum allowed number. ErrMaxNumAPIKeysExceeded = errors.New("maximum number of api keys exceeded") + // ErrInvalidAPIKeyOperation covers a range of invalid operations on API + // keys. Some examples include: defining a list of skylinks on a private + // API key, editing a private API key. This error should be used with + // additional context, specifying the exact operation that failed. + ErrInvalidAPIKeyOperation = errors.New("invalid api key operation") ) type ( // APIKey is a base64URL-encoded representation of []byte with length PubKeySize APIKey string // APIKeyRecord is a non-expiring authentication token generated on user - // demand. + // demand. Public API keys allow downloading a given set of skylinks, while + // private API keys give full API access. APIKeyRecord struct { ID primitive.ObjectID `bson:"_id,omitempty" json:"id"` UserID primitive.ObjectID `bson:"user_id" json:"-"` + Public bool `bson:"public" json:"public"` Key APIKey `bson:"key" json:"-"` + Skylinks []string `bson:"skylinks" json:"skylinks"` CreatedAt time.Time `bson:"created_at" json:"createdAt"` } ) @@ -52,8 +72,22 @@ func (ak APIKey) IsValid() bool { return err == nil && n == PubKeySize } +// CoversSkylink tells us whether a given API key covers a given skylink. +// Private API keys cover all skylinks while public ones - only a limited set. +func (akr APIKeyRecord) CoversSkylink(sl string) bool { + if akr.Public { + return true + } + for _, s := range akr.Skylinks { + if s == sl { + return true + } + } + return false +} + // APIKeyCreate creates a new API key. -func (db *DB) APIKeyCreate(ctx context.Context, user User) (*APIKeyRecord, error) { +func (db *DB) APIKeyCreate(ctx context.Context, user User, public bool, skylinks []string) (*APIKeyRecord, error) { if user.ID.IsZero() { return nil, errors.New("invalid user") } @@ -64,9 +98,14 @@ func (db *DB) APIKeyCreate(ctx context.Context, user User) (*APIKeyRecord, error if n > int64(MaxNumAPIKeysPerUser) { return nil, ErrMaxNumAPIKeysExceeded } + if !public && len(skylinks) > 0 { + return nil, errors.AddContext(ErrInvalidAPIKeyOperation, "cannot define skylinks for a private api key") + } ak := APIKeyRecord{ UserID: user.ID, + Public: public, Key: APIKey(base64.URLEncoding.EncodeToString(fastrand.Bytes(PubKeySize))), + Skylinks: skylinks, CreatedAt: time.Now().UTC(), } ior, err := db.staticAPIKeys.InsertOne(ctx, ak) @@ -78,16 +117,12 @@ func (db *DB) APIKeyCreate(ctx context.Context, user User) (*APIKeyRecord, error } // APIKeyDelete deletes an API key. -func (db *DB) APIKeyDelete(ctx context.Context, user User, akID string) error { +func (db *DB) APIKeyDelete(ctx context.Context, user User, akID primitive.ObjectID) error { if user.ID.IsZero() { return errors.New("invalid user") } - id, err := primitive.ObjectIDFromHex(akID) - if err != nil { - return errors.AddContext(err, "invalid API key ID") - } filter := bson.M{ - "_id": id, + "_id": akID, "user_id": user.ID, } dr, err := db.staticAPIKeys.DeleteOne(ctx, filter) @@ -100,18 +135,34 @@ func (db *DB) APIKeyDelete(ctx context.Context, user User, akID string) error { return nil } -// APIKeyGetRecord returns a specific API key. -func (db *DB) APIKeyGetRecord(ctx context.Context, ak APIKey) (APIKeyRecord, error) { - sr := db.staticAPIKeys.FindOne(ctx, bson.M{"key": ak}) +// APIKeyByKey returns a specific API key. +func (db *DB) APIKeyByKey(ctx context.Context, key string) (APIKeyRecord, error) { + filter := bson.M{"key": key} + sr := db.staticAPIKeys.FindOne(ctx, filter) if sr.Err() != nil { return APIKeyRecord{}, sr.Err() } - var akRecord APIKeyRecord - err := sr.Decode(&akRecord) + var akr APIKeyRecord + err := sr.Decode(&akr) if err != nil { return APIKeyRecord{}, err } - return akRecord, nil + return akr, nil +} + +// APIKeyGet returns a specific API key. +func (db *DB) APIKeyGet(ctx context.Context, akID primitive.ObjectID) (APIKeyRecord, error) { + filter := bson.M{"_id": akID} + sr := db.staticAPIKeys.FindOne(ctx, filter) + if sr.Err() != nil { + return APIKeyRecord{}, sr.Err() + } + var akr APIKeyRecord + err := sr.Decode(&akr) + if err != nil { + return APIKeyRecord{}, err + } + return akr, nil } // APIKeyList lists all API keys that belong to the user. @@ -132,3 +183,74 @@ func (db *DB) APIKeyList(ctx context.Context, user User) ([]APIKeyRecord, error) } return aks, nil } + +// APIKeyUpdate updates an existing API key. This works by replacing the +// list of Skylinks within the API key record. Only valid for public API keys. +func (db *DB) APIKeyUpdate(ctx context.Context, user User, akID primitive.ObjectID, skylinks []string) error { + if user.ID.IsZero() { + return errors.New("invalid user") + } + // Validate all given skylinks. + for _, s := range skylinks { + if !ValidSkylinkHash(s) { + return ErrInvalidSkylink + } + } + filter := bson.M{ + "_id": akID, + "public": &True, // you can only update public API keys + "user_id": user.ID, + } + update := bson.M{"skylinks": skylinks} + opts := options.UpdateOptions{ + Upsert: &False, + } + _, err := db.staticAPIKeys.UpdateOne(ctx, filter, update, &opts) + return err +} + +// APIKeyPatch updates an existing API key. This works by adding and removing +// skylinks to its record. Only valid for public API keys. +func (db *DB) APIKeyPatch(ctx context.Context, user User, akID primitive.ObjectID, addSkylinks, removeSkylinks []string) error { + if user.ID.IsZero() { + return errors.New("invalid user") + } + // Validate all given skylinks. + for _, s := range append(addSkylinks, removeSkylinks...) { + if !ValidSkylinkHash(s) { + return ErrInvalidSkylink + } + } + filter := bson.M{ + "_id": akID, + "public": &True, // you can only update public API keys + } + var update bson.M + // First, all new skylinks to the record. + if len(addSkylinks) > 0 { + update = bson.M{ + "$push": bson.M{"skylinks": bson.M{"$each": addSkylinks}}, + } + opts := options.UpdateOptions{ + Upsert: &False, + } + _, err := db.staticAPIKeys.UpdateOne(ctx, filter, update, &opts) + if err != nil { + return err + } + } + // Then, remove all skylinks that need to be removed. + if len(removeSkylinks) > 0 { + update = bson.M{ + "pull": bson.M{"skylinks": bson.M{"$in": addSkylinks}}, + } + opts := options.UpdateOptions{ + Upsert: &False, + } + _, err := db.staticAPIKeys.UpdateOne(ctx, filter, update, &opts) + if err != nil { + return err + } + } + return nil +} diff --git a/database/database.go b/database/database.go index 30b6d232..ca71d16a 100644 --- a/database/database.go +++ b/database/database.go @@ -50,9 +50,6 @@ var ( collConfiguration = "configuration" // collAPIKeys defines the name of the db table with API keys for users. collAPIKeys = "api_keys" - // collPubAPIKeys defines the name of the db table with public API keys for - // users. - collPubAPIKeys = "pub_api_keys" // DefaultPageSize defines the default number of records to return. DefaultPageSize = 10 @@ -100,7 +97,6 @@ type ( staticUnconfirmedUserUpdates *mongo.Collection staticConfiguration *mongo.Collection staticAPIKeys *mongo.Collection - staticPubAPIKeys *mongo.Collection staticDeps lib.Dependencies staticLogger *logrus.Logger } @@ -152,7 +148,6 @@ func NewCustomDB(ctx context.Context, dbName string, creds DBCredentials, logger staticUnconfirmedUserUpdates: database.Collection(collUnconfirmedUserUpdates), staticConfiguration: database.Collection(collConfiguration), staticAPIKeys: database.Collection(collAPIKeys), - staticPubAPIKeys: database.Collection(collPubAPIKeys), staticLogger: logger, } return db, nil @@ -299,16 +294,6 @@ func ensureDBSchema(ctx context.Context, db *mongo.Database, log *logrus.Logger) Options: options.Index().SetName("user_id"), }, }, - collPubAPIKeys: { - { - Keys: bson.D{{"key", 1}}, - Options: options.Index().SetName("key_unique").SetUnique(true), - }, - { - Keys: bson.D{{"user_id", 1}}, - Options: options.Index().SetName("user_id"), - }, - }, } for collName, models := range schema { diff --git a/database/publicapikeys.go b/database/publicapikeys.go deleted file mode 100644 index 92934c9e..00000000 --- a/database/publicapikeys.go +++ /dev/null @@ -1,194 +0,0 @@ -package database - -import ( - "context" - "encoding/base64" - "time" - - "gitlab.com/NebulousLabs/errors" - "gitlab.com/NebulousLabs/fastrand" - "go.mongodb.org/mongo-driver/bson" - "go.mongodb.org/mongo-driver/bson/primitive" - "go.mongodb.org/mongo-driver/mongo" - "go.mongodb.org/mongo-driver/mongo/options" -) - -type ( - // PubAPIKey is a base64URL-encoded representation of []byte with length - // PubKeySize - PubAPIKey string - // PubAPIKeyRecord is a non-expiring authentication token generated on user - // demand. This token allows anyone to access a set of pre-determined - // skylinks. The traffic generated by this access is counted towards the - // issuing user's balance. - PubAPIKeyRecord struct { - ID primitive.ObjectID `bson:"_id,omitempty" json:"id"` - UserID primitive.ObjectID `bson:"user_id" json:"userID"` - Key PubAPIKey `bson:"key" json:"key"` - Skylinks []string `bson:"skylinks" json:"skylinks"` - CreatedAt time.Time `bson:"created_at" json:"createdAt"` - } -) - -// IsValid checks whether the underlying string satisfies the type's requirement -// to represent a []byte with length PubKeySize which is encoded as base64URL. -// This method does NOT check whether the public API key exists in the database. -func (pak PubAPIKey) IsValid() bool { - b := make([]byte, PubKeySize) - n, err := base64.URLEncoding.Decode(b, []byte(pak)) - return err == nil && n == PubKeySize -} - -// PubAPIKeyCreate creates a new public API key. -func (db *DB) PubAPIKeyCreate(ctx context.Context, user User, skylinks []string) (*PubAPIKeyRecord, error) { - if user.ID.IsZero() { - return nil, errors.New("invalid user") - } - n, err := db.staticPubAPIKeys.CountDocuments(ctx, bson.M{"user_id": user.ID}) - if err != nil { - return nil, errors.AddContext(err, "failed to ensure user can create a new API key") - } - if n > int64(MaxNumAPIKeysPerUser) { - return nil, ErrMaxNumAPIKeysExceeded - } - // Validate all given skylinks. - for _, s := range skylinks { - if !ValidSkylinkHash(s) { - return nil, ErrInvalidSkylink - } - } - pakRec := PubAPIKeyRecord{ - UserID: user.ID, - Key: PubAPIKey(base64.URLEncoding.EncodeToString(fastrand.Bytes(PubKeySize))), - Skylinks: skylinks, - CreatedAt: time.Now().UTC(), - } - ior, err := db.staticAPIKeys.InsertOne(ctx, pakRec) - if err != nil { - return nil, err - } - pakRec.ID = ior.InsertedID.(primitive.ObjectID) - return &pakRec, nil -} - -// PubAPIKeyUpdate updates an existing PubAPIKey. This works by replacing the -// list of Skylinks within the PubAPIKey record. -func (db *DB) PubAPIKeyUpdate(ctx context.Context, user User, pak PubAPIKey, skylinks []string) error { - if user.ID.IsZero() { - return errors.New("invalid user") - } - // Validate all given skylinks. - for _, s := range skylinks { - if !ValidSkylinkHash(s) { - return ErrInvalidSkylink - } - } - filter := bson.M{ - "key": pak, - "user_id": user.ID, - } - update := bson.M{"skylinks": skylinks} - opts := options.UpdateOptions{ - Upsert: &False, - } - _, err := db.staticPubAPIKeys.UpdateOne(ctx, filter, update, &opts) - return err -} - -// PubAPIKeyPatch updates an existing PubAPIKey. This works by adding and -// removing specific elements directly in Mongo. -func (db *DB) PubAPIKeyPatch(ctx context.Context, user User, pak PubAPIKey, addSkylinks, removeSkylinks []string) error { - if user.ID.IsZero() { - return errors.New("invalid user") - } - // Validate all given skylinks. - for _, s := range append(addSkylinks, removeSkylinks...) { - if !ValidSkylinkHash(s) { - return ErrInvalidSkylink - } - } - var filter, update bson.M - // First, all new skylinks to the record. - if len(addSkylinks) > 0 { - filter = bson.M{"key": pak} - update = bson.M{ - "$push": bson.M{"skylinks": bson.M{"$each": addSkylinks}}, - } - opts := options.UpdateOptions{ - Upsert: &False, - } - _, err := db.staticPubAPIKeys.UpdateOne(ctx, filter, update, &opts) - if err != nil { - return err - } - } - // Then, remove all skylinks that need to be removed. - if len(removeSkylinks) > 0 { - filter = bson.M{"key": pak} - update = bson.M{ - "pull": bson.M{"skylinks": bson.M{"$in": addSkylinks}}, - } - opts := options.UpdateOptions{ - Upsert: &False, - } - _, err := db.staticPubAPIKeys.UpdateOne(ctx, filter, update, &opts) - if err != nil { - return err - } - } - return nil -} - -// PubAPIKeyDelete deletes a public API key. -func (db *DB) PubAPIKeyDelete(ctx context.Context, user User, pakID string) error { - if user.ID.IsZero() { - return errors.New("invalid user") - } - id, err := primitive.ObjectIDFromHex(pakID) - if err != nil { - return errors.AddContext(err, "invalid API key ID") - } - filter := bson.M{ - "_id": id, - "user_id": user.ID, - } - dr, err := db.staticPubAPIKeys.DeleteOne(ctx, filter) - if err != nil { - return err - } - if dr.DeletedCount == 0 { - return mongo.ErrNoDocuments - } - return nil -} - -// PubAPIKeyGetRecord returns a specific public API key. -func (db *DB) PubAPIKeyGetRecord(ctx context.Context, pak PubAPIKey) (PubAPIKeyRecord, error) { - sr := db.staticPubAPIKeys.FindOne(ctx, bson.M{"key": pak}) - if sr.Err() != nil { - return PubAPIKeyRecord{}, sr.Err() - } - var pakRec PubAPIKeyRecord - err := sr.Decode(&pakRec) - if err != nil { - return PubAPIKeyRecord{}, err - } - return pakRec, nil -} - -// PubAPIKeyList lists all public API keys that belong to the user. -func (db *DB) PubAPIKeyList(ctx context.Context, user User) ([]*PubAPIKeyRecord, error) { - if user.ID.IsZero() { - return nil, errors.New("invalid user") - } - c, err := db.staticPubAPIKeys.Find(ctx, bson.M{"user_id": user.ID}) - if err != nil { - return nil, err - } - var paks []*PubAPIKeyRecord - err = c.All(ctx, &paks) - if err != nil { - return nil, err - } - return paks, nil -} diff --git a/test/api/handlers_test.go b/test/api/handlers_test.go index e27675ce..77c1737f 100644 --- a/test/api/handlers_test.go +++ b/test/api/handlers_test.go @@ -444,10 +444,10 @@ func testUserLimits(t *testing.T, at *test.AccountsTester) { if err != nil { t.Fatal(err) } - var akRec struct { + var akr struct { Key database.APIKey `json:"key"` } - err = json.Unmarshal(b, &akRec) + err = json.Unmarshal(b, &akr) if err != nil { t.Fatal(err) } @@ -487,7 +487,7 @@ func testUserLimits(t *testing.T, at *test.AccountsTester) { } // Call /user/limits with an API key. Expect TierFree response. - tl, _, err = at.UserLimitsGET(nil, map[string]string{api.APIKeyHeader: string(akRec.Key)}) + tl, _, err = at.UserLimitsGET(nil, map[string]string{api.APIKeyHeader: string(akr.Key)}) if err != nil { t.Fatal(err) } From f9cbcba9d72ab4d30199a17f133d22b13275c70d Mon Sep 17 00:00:00 2001 From: Ivaylo Novakov Date: Fri, 4 Mar 2022 11:05:24 +0100 Subject: [PATCH 09/25] Add integration tests for public API Keys. --- api/apikeys.go | 48 ++++++------- database/apikeys.go | 6 +- test/api/api_test.go | 24 +------ test/api/apikeys_test.go | 126 ++++++++++++++++++++++++++++++++-- test/api/auth_test.go | 1 + test/api/handlers_test.go | 5 +- test/database/apikeys_test.go | 1 + test/tester.go | 81 +++++++++++++++++++--- test/utils.go | 21 ++++++ 9 files changed, 249 insertions(+), 64 deletions(-) create mode 100644 test/api/auth_test.go create mode 100644 test/database/apikeys_test.go diff --git a/api/apikeys.go b/api/apikeys.go index 2ec63690..f14ef010 100644 --- a/api/apikeys.go +++ b/api/apikeys.go @@ -13,43 +13,43 @@ import ( ) type ( - // apiKeyPOST describes the body of a POST request that creates an API key - apiKeyPOST struct { - Public bool - Skylinks []string + // APIKeyPOST describes the body of a POST request that creates an API key + APIKeyPOST struct { + Public bool `json:"public,string"` + Skylinks []string `json:"skylinks"` } - // apiKeyPUT describes the request body for updating an API key - apiKeyPUT struct { + // APIKeyPUT describes the request body for updating an API key + APIKeyPUT struct { Skylinks []string } - // apiKeyPATCH describes the request body for updating an API key by + // APIKeyPATCH describes the request body for updating an API key by // providing only the requested changes - apiKeyPATCH struct { + APIKeyPATCH struct { Add []string Remove []string } - // apiKeyResponse is an API DTO which mirrors database.APIKey. + // APIKeyResponse is an API DTO which mirrors database.APIKey. // TODO Should we reveal the Key each time for public keys? - apiKeyResponse struct { + APIKeyResponse struct { ID primitive.ObjectID `json:"id"` UserID primitive.ObjectID `json:"-"` - Public bool `json:"public"` + Public bool `json:"public,string"` Key database.APIKey `json:"-"` Skylinks []string `json:"skylinks"` CreatedAt time.Time `json:"createdAt"` } - // apiKeyResponseWithKey is an API DTO which mirrors database.APIKey but + // APIKeyResponseWithKey is an API DTO which mirrors database.APIKey but // also reveals the value of the Key field. This should only be used on key // creation. // TODO Should we reveal the Key each time for public keys? - apiKeyResponseWithKey struct { - apiKeyResponse + APIKeyResponseWithKey struct { + APIKeyResponse Key database.APIKey `json:"key"` } ) // Valid checks if the request and its parts are valid. -func (akp apiKeyPOST) Valid() bool { +func (akp APIKeyPOST) Valid() bool { if !akp.Public && len(akp.Skylinks) > 0 { return false } @@ -63,7 +63,7 @@ func (akp apiKeyPOST) Valid() bool { // FromAPIKey populates the struct's fields from the given API key. // TODO This might be more convenient as a constructor. -func (rwk *apiKeyResponse) FromAPIKey(ak database.APIKeyRecord) { +func (rwk *APIKeyResponse) FromAPIKey(ak database.APIKeyRecord) { rwk.ID = ak.ID rwk.UserID = ak.UserID rwk.Public = ak.Public @@ -74,7 +74,7 @@ func (rwk *apiKeyResponse) FromAPIKey(ak database.APIKeyRecord) { // FromAPIKey populates the struct's fields from the given API key. // TODO This might be more convenient as a constructor. -func (rwk *apiKeyResponseWithKey) FromAPIKey(ak database.APIKeyRecord) { +func (rwk *APIKeyResponseWithKey) FromAPIKey(ak database.APIKeyRecord) { rwk.ID = ak.ID rwk.UserID = ak.UserID rwk.Public = ak.Public @@ -85,7 +85,7 @@ func (rwk *apiKeyResponseWithKey) FromAPIKey(ak database.APIKeyRecord) { // userAPIKeyPOST creates a new API key for the user. func (api *API) userAPIKeyPOST(u *database.User, w http.ResponseWriter, req *http.Request, _ httprouter.Params) { - var body apiKeyPOST + var body APIKeyPOST err := parseRequestBodyJSON(req.Body, LimitBodySizeLarge, &body) if err != nil { api.WriteError(w, err, http.StatusBadRequest) @@ -101,7 +101,7 @@ func (api *API) userAPIKeyPOST(u *database.User, w http.ResponseWriter, req *htt api.WriteError(w, err, http.StatusInternalServerError) return } - var resp apiKeyResponseWithKey + var resp APIKeyResponseWithKey resp.FromAPIKey(*ak) api.WriteJSON(w, resp) } @@ -123,7 +123,7 @@ func (api *API) userAPIKeyGET(u *database.User, w http.ResponseWriter, req *http api.WriteError(w, err, http.StatusInternalServerError) return } - var resp apiKeyResponse + var resp APIKeyResponse resp.FromAPIKey(ak) api.WriteJSON(w, resp) } @@ -135,9 +135,9 @@ func (api *API) userAPIKeyLIST(u *database.User, w http.ResponseWriter, req *htt api.WriteError(w, err, http.StatusInternalServerError) return } - resp := make([]apiKeyResponse, 0, len(aks)) + resp := make([]APIKeyResponse, 0, len(aks)) for _, ak := range aks { - var r apiKeyResponse + var r APIKeyResponse r.FromAPIKey(ak) resp = append(resp, r) } @@ -170,7 +170,7 @@ func (api *API) userAPIKeyPUT(u *database.User, w http.ResponseWriter, req *http api.WriteError(w, err, http.StatusBadRequest) return } - var body apiKeyPUT + var body APIKeyPUT err = parseRequestBodyJSON(req.Body, LimitBodySizeLarge, &body) if err != nil { api.WriteError(w, err, http.StatusBadRequest) @@ -193,7 +193,7 @@ func (api *API) userAPIKeyPATCH(u *database.User, w http.ResponseWriter, req *ht api.WriteError(w, err, http.StatusBadRequest) return } - var body apiKeyPATCH + var body APIKeyPATCH err = parseRequestBodyJSON(req.Body, LimitBodySizeLarge, &body) if err != nil { api.WriteError(w, err, http.StatusBadRequest) diff --git a/database/apikeys.go b/database/apikeys.go index 3b733a4e..26343424 100644 --- a/database/apikeys.go +++ b/database/apikeys.go @@ -56,7 +56,7 @@ type ( APIKeyRecord struct { ID primitive.ObjectID `bson:"_id,omitempty" json:"id"` UserID primitive.ObjectID `bson:"user_id" json:"-"` - Public bool `bson:"public" json:"public"` + Public bool `bson:"public,string" json:"public,string"` Key APIKey `bson:"key" json:"-"` Skylinks []string `bson:"skylinks" json:"skylinks"` CreatedAt time.Time `bson:"created_at" json:"createdAt"` @@ -201,7 +201,7 @@ func (db *DB) APIKeyUpdate(ctx context.Context, user User, akID primitive.Object "public": &True, // you can only update public API keys "user_id": user.ID, } - update := bson.M{"skylinks": skylinks} + update := bson.M{"$set": bson.M{"skylinks": skylinks}} opts := options.UpdateOptions{ Upsert: &False, } @@ -242,7 +242,7 @@ func (db *DB) APIKeyPatch(ctx context.Context, user User, akID primitive.ObjectI // Then, remove all skylinks that need to be removed. if len(removeSkylinks) > 0 { update = bson.M{ - "pull": bson.M{"skylinks": bson.M{"$in": addSkylinks}}, + "$pull": bson.M{"skylinks": bson.M{"$in": removeSkylinks}}, } opts := options.UpdateOptions{ Upsert: &False, diff --git a/test/api/api_test.go b/test/api/api_test.go index fc964f41..b366a51b 100644 --- a/test/api/api_test.go +++ b/test/api/api_test.go @@ -1,7 +1,6 @@ package api import ( - "bytes" "context" "encoding/hex" "encoding/json" @@ -20,27 +19,6 @@ import ( "gitlab.com/NebulousLabs/errors" ) -// TestResponseWriter is a testing ResponseWriter implementation. -type TestResponseWriter struct { - Buffer bytes.Buffer - Status int -} - -// Header implementation. -func (w TestResponseWriter) Header() http.Header { - return http.Header{} -} - -// Write implementation. -func (w TestResponseWriter) Write(b []byte) (int, error) { - return w.Buffer.Write(b) -} - -// WriteHeader implementation. -func (w TestResponseWriter) WriteHeader(statusCode int) { - w.Status = statusCode -} - // TestWithDBSession ensures that database transactions are started, committed, // and aborted properly. func TestWithDBSession(t *testing.T) { @@ -124,7 +102,7 @@ func TestWithDBSession(t *testing.T) { testAPI.WriteError(w, errors.New("error"), http.StatusInternalServerError) } - var rw TestResponseWriter + var rw test.ResponseWriter var ps httprouter.Params req := (&http.Request{}).WithContext(ctx) // Call the success handler. diff --git a/test/api/apikeys_test.go b/test/api/apikeys_test.go index f6443b88..00387eb3 100644 --- a/test/api/apikeys_test.go +++ b/test/api/apikeys_test.go @@ -6,6 +6,7 @@ import ( "net/url" "testing" + "github.com/SkynetLabs/skynet-accounts/api" "github.com/SkynetLabs/skynet-accounts/database" "github.com/SkynetLabs/skynet-accounts/skynet" "github.com/SkynetLabs/skynet-accounts/test" @@ -13,8 +14,9 @@ import ( "go.sia.tech/siad/modules" ) -// testAPIKeysFlow validates the creation, listing, and deletion of API keys. -func testAPIKeysFlow(t *testing.T, at *test.AccountsTester) { +// testPrivateAPIKeysFlow validates the creation, listing, and deletion of private +// API keys. +func testPrivateAPIKeysFlow(t *testing.T, at *test.AccountsTester) { name := test.DBNameForTest(t.Name()) r, body, err := at.CreateUserPost(name+"@siasky.net", name+"_pass") if err != nil { @@ -47,6 +49,10 @@ func testAPIKeysFlow(t *testing.T, at *test.AccountsTester) { if err != nil { t.Fatal(err) } + // Make sure the API key is private. + if ak1.Public { + t.Fatal("Expected the API key to be private.") + } // Create another API key. r, body, err = at.Post("/user/apikeys", nil, nil) @@ -106,8 +112,8 @@ func testAPIKeysFlow(t *testing.T, at *test.AccountsTester) { } } -// testAPIKeysUsage makes sure that we can use API keys to make API calls. -func testAPIKeysUsage(t *testing.T, at *test.AccountsTester) { +// testPrivateAPIKeysUsage makes sure that we can use API keys to make API calls. +func testPrivateAPIKeysUsage(t *testing.T, at *test.AccountsTester) { name := test.DBNameForTest(t.Name()) // Create a test user. email := name + "@siasky.net" @@ -162,3 +168,115 @@ func testAPIKeysUsage(t *testing.T, at *test.AccountsTester) { uploadSize, us.TotalDownloadsSize, us.NumUploads, skynet.BandwidthUploadCost(uploadSize), us.BandwidthUploads) } } + +// TestPublicAPIKeyFlow validates the creation, listing, and deletion of public +// API keys. +func testPublicAPIKeysFlow(t *testing.T, at *test.AccountsTester) { + name := test.DBNameForTest(t.Name()) + r, body, err := at.CreateUserPost(name+"@siasky.net", name+"_pass") + if err != nil { + t.Fatal(err, string(body)) + } + at.Cookie = test.ExtractCookie(r) + + sl1 := "AQAh2vxStoSJ_M9tWcTgqebUWerCAbpMfn9xxa9E29UOuw" + sl2 := "AADDE7_5MJyl1DKyfbuQMY_XBOBC9bR7idiU6isp6LXxEw" + + aks := make([]database.APIKeyRecord, 0) + + // List all API keys this user has. Expect the list to be empty. + r, body, err = at.Get("/user/apikeys", nil) + if err != nil { + t.Fatal(err, string(body)) + } + err = json.Unmarshal(body, &aks) + if err != nil { + t.Fatal(err) + } + if len(aks) > 0 { + t.Fatalf("Expected an empty list of API keys, got %+v.", aks) + } + // Create a public API key. + akPost := api.APIKeyPOST{ + Public: true, + Skylinks: []string{sl1}, + } + akr, s, err := at.UserAPIKeysPOST(akPost) + if err != nil || s != http.StatusOK { + t.Fatal(err) + } + // List all API keys again. Expect to find a key. + r, body, err = at.Get("/user/apikeys", nil) + if err != nil { + t.Fatal(err, string(body)) + } + err = json.Unmarshal(body, &aks) + if err != nil { + t.Fatal(err) + } + if len(aks) != 1 { + t.Fatalf("Expected one API key, got %d.", len(aks)) + } + if aks[0].Skylinks[0] != sl1 { + t.Fatal("Unexpected skylinks list", aks[0].Skylinks) + } + // Update a public API key. Expect to go from sl1 to sl2. + s, err = at.UserAPIKeysPUT(akr.ID, api.APIKeyPUT{Skylinks: []string{sl2}}) + if err != nil { + t.Fatal(err) + } + // Get the key and verify the change. + r, body, err = at.Get("/user/apikeys/"+akr.ID.Hex(), nil) + if err != nil { + t.Fatal(err, string(body)) + } + var akr1 database.APIKeyRecord + err = json.Unmarshal(body, &akr1) + if err != nil { + t.Fatal(err) + } + if akr1.Skylinks[0] != sl2 { + t.Fatal("Unexpected skylinks list", aks[0].Skylinks) + } + // Patch a public API key. Expect to go from sl2 to sl1. + akPatch := api.APIKeyPATCH{ + Add: []string{sl1}, + Remove: []string{sl2}, + } + s, err = at.UserAPIKeysPATCH(akr.ID, akPatch) + if err != nil { + t.Fatal(err) + } + // List and verify the change. + r, body, err = at.Get("/user/apikeys", nil) + if err != nil { + t.Fatal(err, string(body)) + } + err = json.Unmarshal(body, &aks) + if err != nil { + t.Fatal(err) + } + if len(aks) != 1 { + t.Fatalf("Expected one API key, got %d.", len(aks)) + } + if aks[0].Skylinks[0] != sl1 { + t.Fatal("Unexpected skylinks list", aks[0].Skylinks) + } + // Delete a public API key. + r, body, err = at.Delete("/user/apikeys/"+akr.ID.Hex(), nil) + if err != nil { + t.Fatal(err, string(body)) + } + // List and verify the change. + r, body, err = at.Get("/user/apikeys", nil) + if err != nil { + t.Fatal(err, string(body)) + } + err = json.Unmarshal(body, &aks) + if err != nil { + t.Fatal(err) + } + if len(aks) != 0 { + t.Fatalf("Expected no API keys, got %d.", len(aks)) + } +} diff --git a/test/api/auth_test.go b/test/api/auth_test.go new file mode 100644 index 00000000..778f64ec --- /dev/null +++ b/test/api/auth_test.go @@ -0,0 +1 @@ +package api diff --git a/test/api/handlers_test.go b/test/api/handlers_test.go index 77c1737f..660704fb 100644 --- a/test/api/handlers_test.go +++ b/test/api/handlers_test.go @@ -66,8 +66,9 @@ func TestHandlers(t *testing.T) { {name: "StandardUserFlow", test: testUserFlow}, {name: "Challenge-Response/Registration", test: testRegistration}, {name: "Challenge-Response/Login", test: testLogin}, - {name: "APIKeysFlow", test: testAPIKeysFlow}, - {name: "APIKeysUsage", test: testAPIKeysUsage}, + {name: "PrivateAPIKeysFlow", test: testPrivateAPIKeysFlow}, + {name: "PrivateAPIKeysUsage", test: testPrivateAPIKeysUsage}, + {name: "PublicAPIKeysFlow", test: testPublicAPIKeysFlow}, } // Run subtests diff --git a/test/database/apikeys_test.go b/test/database/apikeys_test.go new file mode 100644 index 00000000..636bab89 --- /dev/null +++ b/test/database/apikeys_test.go @@ -0,0 +1 @@ +package database diff --git a/test/tester.go b/test/tester.go index 045fb3d2..86f88780 100644 --- a/test/tester.go +++ b/test/tester.go @@ -17,6 +17,7 @@ import ( "github.com/SkynetLabs/skynet-accounts/metafetcher" "github.com/sirupsen/logrus" "gitlab.com/NebulousLabs/errors" + "go.mongodb.org/mongo-driver/bson/primitive" "go.sia.tech/siad/build" ) @@ -169,8 +170,23 @@ func (at *AccountsTester) Post(endpoint string, params url.Values, bodyParams ur // Put executes a PUT request against the test service. // // NOTE: The Body of the returned response is already read and closed. -func (at *AccountsTester) Put(endpoint string, params url.Values, putParams url.Values) (r *http.Response, body []byte, err error) { - return at.request(http.MethodPut, endpoint, params, putParams, nil) +func (at *AccountsTester) Put(endpoint string, params url.Values, bodyParams url.Values) (r *http.Response, body []byte, err error) { + b, err := json.Marshal(bodyParams) + if err != nil { + return nil, nil, errors.AddContext(err, "failed to marshal the body JSON") + } + return at.request(http.MethodPut, endpoint, params, b, nil) +} + +// Patch executes a PATCH request against the test service. +// +// NOTE: The Body of the returned response is already read and closed. +func (at *AccountsTester) Patch(endpoint string, params url.Values, bodyParams url.Values) (r *http.Response, body []byte, err error) { + b, err := json.Marshal(bodyParams) + if err != nil { + return nil, nil, errors.AddContext(err, "failed to marshal the body JSON") + } + return at.request(http.MethodPatch, endpoint, params, b, nil) } // Close performs a graceful shutdown of the AccountsTester service. @@ -213,16 +229,12 @@ func (at *AccountsTester) UserPUT(email, password, stipeID string) (*http.Respon // request. It attaches the current cookie, if one exists. // // NOTE: The Body of the returned response is already read and closed. -func (at *AccountsTester) request(method string, endpoint string, queryParams url.Values, bodyParams url.Values, headers map[string]string) (*http.Response, []byte, error) { +func (at *AccountsTester) request(method string, endpoint string, queryParams url.Values, body []byte, headers map[string]string) (*http.Response, []byte, error) { if queryParams == nil { queryParams = url.Values{} } serviceURL := testPortalAddr + ":" + testPortalPort + endpoint + "?" + queryParams.Encode() - b, err := json.Marshal(bodyParams) - if err != nil { - return nil, nil, errors.AddContext(err, "failed to marshal the body JSON") - } - req, err := http.NewRequest(method, serviceURL, bytes.NewBuffer(b)) + req, err := http.NewRequest(method, serviceURL, bytes.NewBuffer(body)) if err != nil { return nil, nil, err } @@ -254,6 +266,59 @@ func (at *AccountsTester) executeRequest(req *http.Request) (*http.Response, []b return processResponse(r) } +// UserAPIKeysPOST performs a `POST /user/apikeys` request. +func (at *AccountsTester) UserAPIKeysPOST(body api.APIKeyPOST) (api.APIKeyResponseWithKey, int, error) { + bb, err := json.Marshal(body) + if err != nil { + return api.APIKeyResponseWithKey{}, http.StatusBadRequest, err + } + r, b, err := at.request(http.MethodPost, "/user/apikeys", nil, bb, nil) + if err != nil { + return api.APIKeyResponseWithKey{}, r.StatusCode, err + } + if r.StatusCode != http.StatusOK { + return api.APIKeyResponseWithKey{}, r.StatusCode, errors.New(string(b)) + } + var result api.APIKeyResponseWithKey + err = json.Unmarshal(b, &result) + if err != nil { + return api.APIKeyResponseWithKey{}, 0, errors.AddContext(err, "failed to parse response") + } + return result, r.StatusCode, nil +} + +// UserAPIKeysPUT performs a `PUT /user/apikeys` request. +func (at *AccountsTester) UserAPIKeysPUT(akID primitive.ObjectID, body api.APIKeyPUT) (int, error) { + bb, err := json.Marshal(body) + if err != nil { + return http.StatusBadRequest, err + } + r, b, err := at.request(http.MethodPut, "/user/apikeys/"+akID.Hex(), nil, bb, nil) + if err != nil { + return r.StatusCode, err + } + if r.StatusCode != http.StatusNoContent { + return r.StatusCode, errors.New(string(b)) + } + return r.StatusCode, nil +} + +// UserAPIKeysPATCH performs a `PATH /user/apikeys` request. +func (at *AccountsTester) UserAPIKeysPATCH(akID primitive.ObjectID, body api.APIKeyPATCH) (int, error) { + bb, err := json.Marshal(body) + if err != nil { + return http.StatusBadRequest, err + } + r, b, err := at.request(http.MethodPatch, "/user/apikeys/"+akID.Hex(), nil, bb, nil) + if err != nil { + return r.StatusCode, err + } + if r.StatusCode != http.StatusNoContent { + return r.StatusCode, errors.New(string(b)) + } + return r.StatusCode, nil +} + // UserLimitsGET performs a `GET /user/limits` request. func (at *AccountsTester) UserLimitsGET(params url.Values, headers map[string]string) (api.UserLimitsGET, int, error) { r, b, err := at.request(http.MethodGet, "/user/limits", params, nil, headers) diff --git a/test/utils.go b/test/utils.go index e68850ac..039d31ed 100644 --- a/test/utils.go +++ b/test/utils.go @@ -1,6 +1,7 @@ package test import ( + "bytes" "context" "encoding/hex" "fmt" @@ -34,8 +35,28 @@ type ( *database.User staticDB *database.DB } + // ResponseWriter is a testing ResponseWriter implementation. + ResponseWriter struct { + Buffer bytes.Buffer + Status int + } ) +// Header implementation. +func (w ResponseWriter) Header() http.Header { + return http.Header{} +} + +// Write implementation. +func (w ResponseWriter) Write(b []byte) (int, error) { + return w.Buffer.Write(b) +} + +// WriteHeader implementation. +func (w ResponseWriter) WriteHeader(statusCode int) { + w.Status = statusCode +} + // Delete removes the test user from the DB. func (tu *User) Delete(ctx context.Context) error { return tu.staticDB.UserDelete(ctx, tu.User) From 7f9c4952212157003c909bb779108d9291092cb1 Mon Sep 17 00:00:00 2001 From: Ivaylo Novakov Date: Wed, 9 Mar 2022 17:45:29 +0100 Subject: [PATCH 10/25] Let `userLimitsGetFromTier` handle quota exceeded. --- api/handlers.go | 94 +++++++++++++------------------------------- api/handlers_test.go | 89 +++++++++++++++++++++++++++++++++++++++++ 2 files changed, 116 insertions(+), 67 deletions(-) diff --git a/api/handlers.go b/api/handlers.go index 0faece88..c46486c8 100644 --- a/api/handlers.go +++ b/api/handlers.go @@ -405,7 +405,7 @@ func (api *API) userGET(u *database.User, w http.ResponseWriter, _ *http.Request // NOTE: This handler needs to use the noAuth middleware in order to be able to // optimise its calls to the DB and the use of caching. func (api *API) userLimitsGET(_ *database.User, w http.ResponseWriter, req *http.Request, _ httprouter.Params) { - respAnon := userLimitsGetFromTier(database.TierAnonymous) + respAnon := userLimitsGetFromTier(database.TierAnonymous, false) // First check for an API key. akStr, err := apiKeyFromRequest(req) if err == nil { @@ -413,17 +413,7 @@ func (api *API) userLimitsGET(_ *database.User, w http.ResponseWriter, req *http tier, qe, ok := api.staticUserTierCache.Get(akStr) if ok { api.staticLogger.Traceln("Fetching user limits from cache by API key.") - resp := userLimitsGetFromTier(tier) - // If the quota is exceeded we should keep the user's tier but report - // anonymous-level speeds. - if qe { - // Report the speeds for tier anonymous. - resp = userLimitsGetFromTier(database.TierAnonymous) - // But keep reporting the user's actual tier and it's name. - resp.TierID = tier - resp.TierName = database.UserLimits[tier].TierName - } - api.WriteJSON(w, resp) + api.WriteJSON(w, userLimitsGetFromTier(tier, qe)) return } // Cache is missed, fetch the owner of this APIKey from the DB. @@ -454,17 +444,7 @@ func (api *API) userLimitsGET(_ *database.User, w http.ResponseWriter, req *http } // Cache the user under the API key they used. api.staticUserTierCache.Set(akStr, u) - resp := userLimitsGetFromTier(u.Tier) - // If the quota is exceeded we should keep the user's tier but report - // anonymous-level speeds. - if u.QuotaExceeded { - // Report the speeds for tier anonymous. - resp = userLimitsGetFromTier(database.TierAnonymous) - // But keep reporting the user's actual tier and it's name. - resp.TierID = u.Tier - resp.TierName = database.UserLimits[u.Tier].TierName - } - api.WriteJSON(w, resp) + api.WriteJSON(w, userLimitsGetFromTier(u.Tier, u.QuotaExceeded)) return } // Next check for a token. @@ -498,17 +478,7 @@ func (api *API) userLimitsGET(_ *database.User, w http.ResponseWriter, req *http build.Critical("Failed to fetch user from UserTierCache right after setting it.") } } - resp := userLimitsGetFromTier(tier) - // If the quota is exceeded we should keep the user's tier but report - // anonymous-level speeds. - if qe { - // Report anonymous speeds. - resp = userLimitsGetFromTier(database.TierAnonymous) - // Keep reporting the user's actual tier and tier name. - resp.TierID = tier - resp.TierName = database.UserLimits[tier].TierName - } - api.WriteJSON(w, resp) + api.WriteJSON(w, userLimitsGetFromTier(tier, qe)) } // userLimitsSkylinkGET returns the speed limits which apply to a GET call to @@ -517,7 +487,7 @@ func (api *API) userLimitsGET(_ *database.User, w http.ResponseWriter, req *http // NOTE: This handler needs to use the noAuth middleware in order to be able to // optimise its calls to the DB and the use of caching. func (api *API) userLimitsSkylinkGET(u *database.User, w http.ResponseWriter, req *http.Request, ps httprouter.Params) { - respAnon := userLimitsGetFromTier(database.TierAnonymous) + respAnon := userLimitsGetFromTier(database.TierAnonymous, false) // Validate the skylink. skylink := ps.ByName("skylink") if !database.ValidSkylinkHash(skylink) { @@ -544,17 +514,7 @@ func (api *API) userLimitsSkylinkGET(u *database.User, w http.ResponseWriter, re tier, qe, ok := api.staticUserTierCache.Get(akStr + skylink) if ok { api.staticLogger.Traceln("Fetching user limits from cache by API key.") - resp := userLimitsGetFromTier(tier) - // If the quota is exceeded we should keep the user's tier but report - // anonymous-level speeds. - if qe { - // Report the speeds for tier anonymous. - resp = userLimitsGetFromTier(database.TierAnonymous) - // But keep reporting the user's actual tier and it's name. - resp.TierID = u.Tier - resp.TierName = database.UserLimits[u.Tier].TierName - } - api.WriteJSON(w, resp) + api.WriteJSON(w, userLimitsGetFromTier(tier, qe)) return } // Get the API key. @@ -578,17 +538,7 @@ func (api *API) userLimitsSkylinkGET(u *database.User, w http.ResponseWriter, re } // Store the user in the cache with a custom key. api.staticUserTierCache.Set(akStr+skylink, user) - resp := userLimitsGetFromTier(user.Tier) - // If the quota is exceeded we should keep the user's tier but report - // anonymous-level speeds. - if user.QuotaExceeded { - // Report the speeds for tier anonymous. - resp = userLimitsGetFromTier(database.TierAnonymous) - // But keep reporting the user's actual tier and it's name. - resp.TierID = user.Tier - resp.TierName = database.UserLimits[user.Tier].TierName - } - api.WriteJSON(w, resp) + api.WriteJSON(w, userLimitsGetFromTier(user.Tier, u.QuotaExceeded)) } // userStatsGET returns statistics about an existing user. @@ -1320,16 +1270,26 @@ func parseRequestBodyJSON(body io.ReadCloser, maxBodySize int64, v interface{}) // userLimitsGetFromTier is a helper that lets us succinctly translate // from the database DTO to the API DTO. -func userLimitsGetFromTier(tier int) *UserLimitsGET { - t := database.UserLimits[tier] +func userLimitsGetFromTier(tierID int, quotaExceeded bool) *UserLimitsGET { + t, ok := database.UserLimits[tierID] + if !ok { + build.Critical("userLimitsGetFromTier was called with non-existent tierID: " + strconv.Itoa(tierID)) + t = database.UserLimits[database.TierAnonymous] + } + limitsTier := t + if quotaExceeded { + limitsTier = database.UserLimits[database.TierAnonymous] + } return &UserLimitsGET{ - TierID: tier, - TierName: t.TierName, - UploadBandwidth: t.UploadBandwidth, - DownloadBandwidth: t.DownloadBandwidth, - MaxUploadSize: t.MaxUploadSize, - MaxNumberUploads: t.MaxNumberUploads, - RegistryDelay: t.RegistryDelay, - Storage: t.Storage, + TierID: tierID, + TierName: t.TierName, + Storage: t.Storage, + // If the user exceeds their quota, there will be brought down to + // anonymous levels. + UploadBandwidth: limitsTier.UploadBandwidth, + DownloadBandwidth: limitsTier.DownloadBandwidth, + MaxUploadSize: limitsTier.MaxUploadSize, + MaxNumberUploads: limitsTier.MaxNumberUploads, + RegistryDelay: limitsTier.RegistryDelay, } } diff --git a/api/handlers_test.go b/api/handlers_test.go index 3b31e988..41e17d3b 100644 --- a/api/handlers_test.go +++ b/api/handlers_test.go @@ -1,9 +1,13 @@ package api import ( + "fmt" + "math" + "strings" "testing" "github.com/SkynetLabs/skynet-accounts/database" + "gitlab.com/NebulousLabs/errors" ) // TestUserGETFromUser ensures the UserGETFromUser method correctly converts @@ -39,3 +43,88 @@ func TestUserGETFromUser(t *testing.T) { t.Fatal("Expected EmailConfirmed to be false.") } } + +// TestUserLimitsGetFromTier ensures the proper functioning of +// userLimitsGetFromTier. +func TestUserLimitsGetFromTier(t *testing.T) { + tests := []struct { + name string + tier int + quotaExceeded bool + expectedTier int + expectedStorage int64 + expectedUploadBW int + expectedDownloadBW int + expectedRegistryDelay int + }{ + { + name: "anon", + tier: database.TierAnonymous, + quotaExceeded: false, + expectedTier: database.TierAnonymous, + expectedStorage: database.UserLimits[database.TierAnonymous].Storage, + expectedUploadBW: database.UserLimits[database.TierAnonymous].UploadBandwidth, + expectedDownloadBW: database.UserLimits[database.TierAnonymous].DownloadBandwidth, + expectedRegistryDelay: database.UserLimits[database.TierAnonymous].RegistryDelay, + }, + { + name: "plus, quota not exceeded", + tier: database.TierPremium5, + quotaExceeded: false, + expectedTier: database.TierPremium5, + expectedStorage: database.UserLimits[database.TierPremium5].Storage, + expectedUploadBW: database.UserLimits[database.TierPremium5].UploadBandwidth, + expectedDownloadBW: database.UserLimits[database.TierPremium5].DownloadBandwidth, + expectedRegistryDelay: database.UserLimits[database.TierPremium5].RegistryDelay, + }, + { + name: "plus, quota exceeded", + tier: database.TierPremium5, + quotaExceeded: true, + expectedTier: database.TierPremium5, + expectedStorage: database.UserLimits[database.TierPremium5].Storage, + expectedUploadBW: database.UserLimits[database.TierAnonymous].UploadBandwidth, + expectedDownloadBW: database.UserLimits[database.TierAnonymous].DownloadBandwidth, + expectedRegistryDelay: database.UserLimits[database.TierAnonymous].RegistryDelay, + }, + } + + for _, tt := range tests { + ul := userLimitsGetFromTier(tt.tier, tt.quotaExceeded) + if ul.TierID != tt.expectedTier { + t.Errorf("Test '%s': expected tier %d, got %d", tt.name, tt.expectedTier, ul.TierID) + } + if ul.Storage != tt.expectedStorage { + t.Errorf("Test '%s': expected storage %d, got %d", tt.name, tt.expectedStorage, ul.Storage) + } + if ul.UploadBandwidth != tt.expectedUploadBW { + t.Errorf("Test '%s': expected upload bandwidth %d, got %d", tt.name, tt.expectedUploadBW, ul.UploadBandwidth) + } + if ul.DownloadBandwidth != tt.expectedDownloadBW { + t.Errorf("Test '%s': expected download bandwidth %d, got %d", tt.name, tt.expectedDownloadBW, ul.DownloadBandwidth) + } + if ul.RegistryDelay != tt.expectedRegistryDelay { + t.Errorf("Test '%s': expected registry delay %d, got %d", tt.name, tt.expectedRegistryDelay, ul.RegistryDelay) + } + } + + // Additionally, let us ensure that userLimitsGetFromTier logs a critical + // when called with an invalid tier ID. + err := func() (err error) { + defer func() { + e := recover() + if e == nil { + err = errors.New("expected to panic") + } + if !strings.Contains(fmt.Sprint(e), "userLimitsGetFromTier was called with non-existent tierID") { + err = fmt.Errorf("expected to panic with a specific error message, got '%s'", fmt.Sprint(e)) + } + }() + // The call that we expect to log a critical. + _ = userLimitsGetFromTier(math.MaxInt, false) + return + }() + if err != nil { + t.Fatal(err) + } +} From c0a69049c7cc03bfb4f3d047596b0c813872d494 Mon Sep 17 00:00:00 2001 From: Ivaylo Novakov Date: Wed, 9 Mar 2022 18:05:33 +0100 Subject: [PATCH 11/25] Let `apiKeyFromRequest` return a validated API key, thus simplifying the flow. --- api/auth.go | 15 +++++---------- api/auth_test.go | 18 +++++++++--------- api/handlers.go | 33 ++++++++++++--------------------- database/apikeys.go | 15 +++++++++++++++ 4 files changed, 41 insertions(+), 40 deletions(-) diff --git a/api/auth.go b/api/auth.go index 2fc8a3fa..a378db16 100644 --- a/api/auth.go +++ b/api/auth.go @@ -34,16 +34,11 @@ func (api *API) userAndTokenByRequestToken(req *http.Request) (*database.User, j // It first checks the headers and then the query. // This method accesses the database. func (api *API) userAndTokenByAPIKey(req *http.Request) (*database.User, jwt2.Token, error) { - akStr, err := apiKeyFromRequest(req) + ak, err := apiKeyFromRequest(req) if err != nil { return nil, nil, err } - // Check if this is a valid APIKey. - ak := database.APIKey(akStr) - if !ak.IsValid() { - return nil, nil, database.ErrInvalidAPIKey - } - akr, err := api.staticDB.APIKeyByKey(req.Context(), akStr) + akr, err := api.staticDB.APIKeyByKey(req.Context(), ak.String()) if err != nil { return nil, nil, err } @@ -66,7 +61,7 @@ func (api *API) userAndTokenByAPIKey(req *http.Request) (*database.User, jwt2.To // apiKeyFromRequest extracts the API key from the request and returns it. // This function does not differentiate between APIKey and APIKey. // It first checks the headers and then the query. -func apiKeyFromRequest(r *http.Request) (string, error) { +func apiKeyFromRequest(r *http.Request) (*database.APIKey, error) { // Check the headers for an API key. akStr := r.Header.Get(APIKeyHeader) // If there is no API key in the headers, try the query. @@ -74,9 +69,9 @@ func apiKeyFromRequest(r *http.Request) (string, error) { akStr = r.FormValue("apiKey") } if akStr == "" { - return "", ErrNoAPIKey + return nil, ErrNoAPIKey } - return akStr, nil + return database.NewAPIKeyFromString(akStr) } // tokenFromRequest extracts the JWT token from the request and returns it. diff --git a/api/auth_test.go b/api/auth_test.go index 2fb78071..7615bd49 100644 --- a/api/auth_test.go +++ b/api/auth_test.go @@ -29,28 +29,28 @@ func TestAPIKeyFromRequest(t *testing.T) { } // API key from request form. - token := randomAPIKeyString() - req.Form.Add("apiKey", token) - tk, err := apiKeyFromRequest(req) + akStr := randomAPIKeyString() + req.Form.Add("apiKey", akStr) + ak, err := apiKeyFromRequest(req) if err != nil { t.Fatal(err) } - if tk != token { - t.Fatalf("Expected '%s', got '%s'.", token, tk) + if ak.String() != akStr { + t.Fatalf("Expected '%s', got '%s'.", akStr, ak) } // API key from headers. Expect this to take precedence over request form. token2 := randomAPIKeyString() req.Header.Set(APIKeyHeader, token2) - tk, err = apiKeyFromRequest(req) + ak, err = apiKeyFromRequest(req) if err != nil { t.Fatal(err) } - if tk == token { + if ak.String() == akStr { t.Fatal("Form token took precedence over headers token.") } - if tk != token2 { - t.Fatalf("Expected '%s', got '%s'.", token2, tk) + if ak.String() != token2 { + t.Fatalf("Expected '%s', got '%s'.", token2, ak) } } diff --git a/api/handlers.go b/api/handlers.go index c46486c8..b4eddcee 100644 --- a/api/handlers.go +++ b/api/handlers.go @@ -407,24 +407,17 @@ func (api *API) userGET(u *database.User, w http.ResponseWriter, _ *http.Request func (api *API) userLimitsGET(_ *database.User, w http.ResponseWriter, req *http.Request, _ httprouter.Params) { respAnon := userLimitsGetFromTier(database.TierAnonymous, false) // First check for an API key. - akStr, err := apiKeyFromRequest(req) + ak, err := apiKeyFromRequest(req) if err == nil { // Check the cache before going any further. - tier, qe, ok := api.staticUserTierCache.Get(akStr) + tier, qe, ok := api.staticUserTierCache.Get(ak.String()) if ok { api.staticLogger.Traceln("Fetching user limits from cache by API key.") api.WriteJSON(w, userLimitsGetFromTier(tier, qe)) return } - // Cache is missed, fetch the owner of this APIKey from the DB. - ak := database.APIKey(akStr) - if !ak.IsValid() { - api.staticLogger.Traceln("Invalid API key.") - api.WriteJSON(w, respAnon) - return - } // Get the API key. - akr, err := api.staticDB.APIKeyByKey(req.Context(), akStr) + akr, err := api.staticDB.APIKeyByKey(req.Context(), ak.String()) if err != nil { api.staticLogger.Trace("API key doesn't exist in the database.") api.WriteJSON(w, respAnon) @@ -443,7 +436,7 @@ func (api *API) userLimitsGET(_ *database.User, w http.ResponseWriter, req *http return } // Cache the user under the API key they used. - api.staticUserTierCache.Set(akStr, u) + api.staticUserTierCache.Set(ak.String(), u) api.WriteJSON(w, userLimitsGetFromTier(u.Tier, u.QuotaExceeded)) return } @@ -496,29 +489,27 @@ func (api *API) userLimitsSkylinkGET(u *database.User, w http.ResponseWriter, re return } // Try to fetch an API attached to the request. - akStr, err := apiKeyFromRequest(req) - if err != nil { + ak, err := apiKeyFromRequest(req) + if errors.Contains(err, ErrNoAPIKey) { // We failed to fetch an API key from this request but the request might // be authenticated in another way, so we'll defer to userLimitsGET. api.userLimitsGET(u, w, req, ps) return } - // Check if that is a valid API key. - ak := database.APIKey(akStr) - if !ak.IsValid() { - // This is not a valid APIKey. Defer to userLimitsGET. - api.userLimitsGET(u, w, req, ps) + if err != nil { + api.staticLogger.Debugf("Error while processing API key: %s", err) + api.WriteJSON(w, respAnon) return } // Check the cache before hitting the database. - tier, qe, ok := api.staticUserTierCache.Get(akStr + skylink) + tier, qe, ok := api.staticUserTierCache.Get(ak.String() + skylink) if ok { api.staticLogger.Traceln("Fetching user limits from cache by API key.") api.WriteJSON(w, userLimitsGetFromTier(tier, qe)) return } // Get the API key. - akr, err := api.staticDB.APIKeyByKey(req.Context(), akStr) + akr, err := api.staticDB.APIKeyByKey(req.Context(), ak.String()) if err != nil { api.staticLogger.Trace("API key doesn't exist in the database.") api.WriteJSON(w, respAnon) @@ -537,7 +528,7 @@ func (api *API) userLimitsSkylinkGET(u *database.User, w http.ResponseWriter, re return } // Store the user in the cache with a custom key. - api.staticUserTierCache.Set(akStr+skylink, user) + api.staticUserTierCache.Set(ak.String()+skylink, user) api.WriteJSON(w, userLimitsGetFromTier(user.Tier, u.QuotaExceeded)) } diff --git a/database/apikeys.go b/database/apikeys.go index 415bfd85..7cd6fa15 100644 --- a/database/apikeys.go +++ b/database/apikeys.go @@ -105,6 +105,21 @@ func (ak *APIKey) LoadBytes(b []byte) error { return nil } +// LoadString loads a string into the API key and validates it. +func (ak *APIKey) LoadString(s string) error { + k := APIKey(strings.ToUpper(s)) + if !k.IsValid() { + return ErrInvalidAPIKey + } + *ak = k + return nil +} + +// LoadBytes encodes a []byte of size PubKeySize into an API key. +func (ak APIKey) String() string { + return string(ak) +} + // CoversSkylink tells us whether a given API key covers a given skylink. // Private API keys cover all skylinks while public ones - only a limited set. func (akr APIKeyRecord) CoversSkylink(sl string) bool { From 2a4059b9bd2446bf92dd8b57fda65dcb2b0592e6 Mon Sep 17 00:00:00 2001 From: Ivaylo Novakov Date: Thu, 10 Mar 2022 15:18:15 +0100 Subject: [PATCH 12/25] Add HealthGET tester method. --- test/api/handlers_test.go | 9 +-------- test/tester.go | 16 +++++++++++++++- 2 files changed, 16 insertions(+), 9 deletions(-) diff --git a/test/api/handlers_test.go b/test/api/handlers_test.go index 3219bca7..1060ebdd 100644 --- a/test/api/handlers_test.go +++ b/test/api/handlers_test.go @@ -83,17 +83,10 @@ func TestHandlers(t *testing.T) { // testHandlerHealthGET tests the /health handler. func testHandlerHealthGET(t *testing.T, at *test.AccountsTester) { - _, b, err := at.Get("/health", nil) + status, _, err := at.HealthGet() if err != nil { t.Fatal(err) } - status := struct { - DBAlive bool `json:"dbAlive"` - }{} - err = json.Unmarshal(b, &status) - if err != nil { - t.Fatal("Failed to unmarshal service's response: ", err) - } // DBAlive should never be false because if we couldn't reach the DB, we // wouldn't have made it this far in the test. if !status.DBAlive { diff --git a/test/tester.go b/test/tester.go index 55239db6..5297d680 100644 --- a/test/tester.go +++ b/test/tester.go @@ -116,7 +116,7 @@ func NewAccountsTester(dbName string) (*AccountsTester, error) { } // Wait for the accounts tester to be fully ready. err = build.Retry(50, time.Millisecond, func() error { - _, _, err = at.Get("/health", nil) + _, _, err = at.HealthGet() return err }) if err != nil { @@ -293,6 +293,20 @@ func (at *AccountsTester) executeRequest(req *http.Request) (*http.Response, []b return processResponse(r) } +// HealthGet executes a GET /health. +func (at *AccountsTester) HealthGet() (api.HealthGET, int, error) { + r, b, err := at.request(http.MethodGet, "/health", nil, nil, nil) + if err != nil { + return api.HealthGET{}, r.StatusCode, err + } + var resp api.HealthGET + err = json.Unmarshal(b, &resp) + if err != nil { + return api.HealthGET{}, 0, errors.AddContext(err, "failed to marshal the body JSON") + } + return resp, r.StatusCode, nil +} + // UserAPIKeysPOST performs a `POST /user/apikeys` request. func (at *AccountsTester) UserAPIKeysPOST(body api.APIKeyPOST) (api.APIKeyResponseWithKey, int, error) { bb, err := json.Marshal(body) From 28978d654541865c9ccd6469200363647443a9ae Mon Sep 17 00:00:00 2001 From: Ivaylo Novakov Date: Thu, 10 Mar 2022 16:22:44 +0100 Subject: [PATCH 13/25] Add an integration test for public API keys usage. --- api/handlers.go | 2 +- database/apikeys.go | 2 +- test/api/api_test.go | 6 +-- test/api/apikeys_test.go | 79 ++++++++++++++++++++++++++++++++------- test/api/handlers_test.go | 9 +++-- test/tester.go | 57 ++++++++++++++++++---------- 6 files changed, 113 insertions(+), 42 deletions(-) diff --git a/api/handlers.go b/api/handlers.go index b4eddcee..fff1547e 100644 --- a/api/handlers.go +++ b/api/handlers.go @@ -529,7 +529,7 @@ func (api *API) userLimitsSkylinkGET(u *database.User, w http.ResponseWriter, re } // Store the user in the cache with a custom key. api.staticUserTierCache.Set(ak.String()+skylink, user) - api.WriteJSON(w, userLimitsGetFromTier(user.Tier, u.QuotaExceeded)) + api.WriteJSON(w, userLimitsGetFromTier(user.Tier, user.QuotaExceeded)) } // userStatsGET returns statistics about an existing user. diff --git a/database/apikeys.go b/database/apikeys.go index 7cd6fa15..f31263ef 100644 --- a/database/apikeys.go +++ b/database/apikeys.go @@ -123,7 +123,7 @@ func (ak APIKey) String() string { // CoversSkylink tells us whether a given API key covers a given skylink. // Private API keys cover all skylinks while public ones - only a limited set. func (akr APIKeyRecord) CoversSkylink(sl string) bool { - if akr.Public { + if !akr.Public { return true } for _, s := range akr.Skylinks { diff --git a/test/api/api_test.go b/test/api/api_test.go index 100a68c0..51a08dcf 100644 --- a/test/api/api_test.go +++ b/test/api/api_test.go @@ -179,7 +179,7 @@ func TestUserTierCache(t *testing.T) { } at.SetCookie(test.ExtractCookie(r)) // Get the user's limit. - ul, _, err := at.UserLimits() + ul, _, err := at.UserLimits(nil, nil) if err != nil { t.Fatal(err) } @@ -214,7 +214,7 @@ func TestUserTierCache(t *testing.T) { err = build.Retry(10, 200*time.Millisecond, func() error { // We expect to get tier with name and id matching TierPremium20 but with // speeds matching TierAnonymous. - ul, _, err = at.UserLimits() + ul, _, err = at.UserLimits(nil, nil) if err != nil { t.Fatal(err) } @@ -240,7 +240,7 @@ func TestUserTierCache(t *testing.T) { } err = build.Retry(10, 200*time.Millisecond, func() error { // We expect to get TierPremium20. - ul, _, err = at.UserLimits() + ul, _, err = at.UserLimits(nil, nil) if err != nil { return errors.AddContext(err, "failed to call /user/limits") } diff --git a/test/api/apikeys_test.go b/test/api/apikeys_test.go index b1ed9b41..49803f7e 100644 --- a/test/api/apikeys_test.go +++ b/test/api/apikeys_test.go @@ -133,28 +133,18 @@ func testPrivateAPIKeysUsage(t *testing.T, at *test.AccountsTester) { t.Fatal(err) } // Create a new API key. - _, body, err := at.Post("/user/apikeys", nil, nil) + akWithKey, _, err := at.UserAPIKeysPOST(api.APIKeyPOST{}) if err != nil { t.Fatal(err) } // Stop using the cookie, so we can test the API key. at.ClearCredentials() - // We use a custom struct and not the APIKeyRecord one because that one does - // not render the key in JSON form and therefore it won't unmarshal it, - // either. - var ak struct { - Key database.APIKey - } - err = json.Unmarshal(body, &ak) - if err != nil { - t.Fatal(err) - } // Get user stats without a cookie or headers - pass the API key via a query // variable. The main thing we want to see here is whether we get // an `Unauthorized` error or not but we'll validate the stats as well. params := url.Values{} - params.Set("apiKey", string(ak.Key)) - _, body, err = at.Get("/user/stats", params) + params.Set("apiKey", akWithKey.Key.String()) + _, body, err := at.Get("/user/stats", params) if err != nil { t.Fatal(err, string(body)) } @@ -280,3 +270,66 @@ func testPublicAPIKeysFlow(t *testing.T, at *test.AccountsTester) { t.Fatalf("Expected no API keys, got %d.", len(aks)) } } + +// testPublicAPIKeysUsage makes sure that we can use public API keys to make +// GET requests to covered skylinks and that we cannot use them for other +// requests. +func testPublicAPIKeysUsage(t *testing.T, at *test.AccountsTester) { + name := test.DBNameForTest(t.Name()) + // Create a test user. + email := name + "@siasky.net" + r, _, err := at.CreateUserPost(email, name+"_pass") + if err != nil { + t.Fatal(err) + } + at.SetCookie(test.ExtractCookie(r)) + // Get the user and create a test upload, so the stats won't be all zeros. + u, err := at.DB.UserByEmail(at.Ctx, email) + if err != nil { + t.Fatal(err) + } + uploadSize := int64(fastrand.Intn(int(modules.SectorSize / 2))) + sl, _, err := test.CreateTestUpload(at.Ctx, at.DB, u, uploadSize) + if err != nil { + t.Fatal(err) + } + sl2, _, err := test.CreateTestUpload(at.Ctx, at.DB, u, uploadSize) + if err != nil { + t.Fatal(err) + } + // Create a new public API key. + apiKeyPOST := api.APIKeyPOST{ + Public: true, + Skylinks: []string{sl.Skylink}, + } + akWithKey, _, err := at.UserAPIKeysPOST(apiKeyPOST) + if err != nil { + t.Fatal(err) + } + // Stop using the cookie, use the public API key instead. + at.SetAPIKey(akWithKey.Key.String()) + // Try to fetch the user's stats with the new public API key. + // Expect this to fail. + _, _, err = at.Get("/user/stats", nil) + if err == nil { + t.Fatal("Managed to get user stats with a public API key.") + } + // Get the user's limits for downloading a skylink covered by the public + // API key. Expect to get TierFree values. + ul, _, err := at.UserLimitsSkylink(sl.Skylink, nil, nil) + if err != nil { + t.Fatal(err) + } + if ul.DownloadBandwidth != database.UserLimits[database.TierFree].DownloadBandwidth { + t.Fatalf("Expected to get download bandwidth of %d, got %d", database.UserLimits[database.TierFree].DownloadBandwidth, ul.DownloadBandwidth) + } + // Get the user's limits for downloading a skylink that is not covered by + // the public API key. Expect to get TierAnonymous values. + ul, _, err = at.UserLimitsSkylink(sl2.Skylink, nil, nil) + if err != nil { + t.Fatal(err) + } + if ul.DownloadBandwidth != database.UserLimits[database.TierAnonymous].DownloadBandwidth { + t.Fatalf("Expected to get download bandwidth of %d, got %d", database.UserLimits[database.TierAnonymous].DownloadBandwidth, ul.DownloadBandwidth) + } +} diff --git a/test/api/handlers_test.go b/test/api/handlers_test.go index 1060ebdd..9febfe50 100644 --- a/test/api/handlers_test.go +++ b/test/api/handlers_test.go @@ -71,6 +71,7 @@ func TestHandlers(t *testing.T) { {name: "PrivateAPIKeysFlow", test: testPrivateAPIKeysFlow}, {name: "PrivateAPIKeysUsage", test: testPrivateAPIKeysUsage}, {name: "PublicAPIKeysFlow", test: testPublicAPIKeysFlow}, + {name: "PublicAPIKeysUsage", test: testPublicAPIKeysUsage}, } // Run subtests @@ -448,7 +449,7 @@ func testUserLimits(t *testing.T, at *test.AccountsTester) { } // Call /user/limits with a cookie. Expect TierFree response. - tl, _, err := at.UserLimits() + tl, _, err := at.UserLimits(nil, nil) if err != nil { t.Fatal(err) } @@ -464,7 +465,7 @@ func testUserLimits(t *testing.T, at *test.AccountsTester) { // Call /user/limits without a cookie. Expect FreeAnonymous response. at.ClearCredentials() - tl, _, err = at.UserLimits() + tl, _, err = at.UserLimits(nil, nil) if err != nil { t.Fatal(err) } @@ -479,7 +480,7 @@ func testUserLimits(t *testing.T, at *test.AccountsTester) { } // Call /user/limits with an API key. Expect TierFree response. - tl, _, err = at.UserLimitsGET(nil, map[string]string{api.APIKeyHeader: string(akr.Key)}) + tl, _, err = at.UserLimits(nil, map[string]string{api.APIKeyHeader: string(akr.Key)}) if err != nil { t.Fatal(err) } @@ -528,7 +529,7 @@ func testUserLimits(t *testing.T, at *test.AccountsTester) { err = build.Retry(10, 200*time.Millisecond, func() error { // Check the user's limits. We expect the tier to be Free but the limits to // match Anonymous. - tl, _, err = at.UserLimits() + tl, _, err = at.UserLimits(nil, nil) if err != nil { return errors.AddContext(err, "failed to call /user/limits") } diff --git a/test/tester.go b/test/tester.go index 5297d680..b96c5568 100644 --- a/test/tester.go +++ b/test/tester.go @@ -34,7 +34,7 @@ type ( Ctx context.Context DB *database.DB Logger *logrus.Logger - // If set, this cookie will be attached to all requests. + APIKey string Cookie *http.Cookie Token string @@ -128,6 +128,7 @@ func NewAccountsTester(dbName string) (*AccountsTester, error) { // ClearCredentials removes any credentials stored by this tester, such as a // cookie, token, etc. func (at *AccountsTester) ClearCredentials() { + at.APIKey = "" at.Cookie = nil at.Token = "" } @@ -144,6 +145,13 @@ func (at *AccountsTester) Close() error { return nil } +// SetAPIKey ensures that all subsequent requests are going to use the given +// API key for authentication. +func (at *AccountsTester) SetAPIKey(ak string) { + at.ClearCredentials() + at.APIKey = ak +} + // SetCookie ensures that all subsequent requests are going to use the given // cookie for authentication. func (at *AccountsTester) SetCookie(c *http.Cookie) { @@ -279,6 +287,9 @@ func (at *AccountsTester) executeRequest(req *http.Request) (*http.Response, []b if req == nil { return nil, nil, errors.New("invalid request") } + if at.APIKey != "" { + req.Header.Set(api.APIKeyHeader, at.APIKey) + } if at.Cookie != nil { req.Header.Set("Cookie", at.Cookie.String()) } @@ -360,23 +371,6 @@ func (at *AccountsTester) UserAPIKeysPATCH(akID primitive.ObjectID, body api.API return r.StatusCode, nil } -// UserLimitsGET performs a `GET /user/limits` request. -func (at *AccountsTester) UserLimitsGET(params url.Values, headers map[string]string) (api.UserLimitsGET, int, error) { - r, b, err := at.request(http.MethodGet, "/user/limits", params, nil, headers) - if err != nil { - return api.UserLimitsGET{}, r.StatusCode, err - } - if r.StatusCode != http.StatusOK { - return api.UserLimitsGET{}, r.StatusCode, errors.New(string(b)) - } - var result api.UserLimitsGET - err = json.Unmarshal(b, &result) - if err != nil { - return api.UserLimitsGET{}, 0, errors.AddContext(err, "failed to parse response") - } - return result, r.StatusCode, nil -} - // processResponse is a helper method which extracts the body from the response // and handles non-OK status codes. // @@ -419,11 +413,34 @@ func (at *AccountsTester) TrackRegistryWrite() (int, error) { } // UserLimits performs a `GET /user/limits` request. -func (at *AccountsTester) UserLimits() (api.UserLimitsGET, int, error) { - r, b, err := at.request(http.MethodGet, "/user/limits", nil, nil, nil) +func (at *AccountsTester) UserLimits(params url.Values, headers map[string]string) (api.UserLimitsGET, int, error) { + r, b, err := at.request(http.MethodGet, "/user/limits", params, nil, headers) if err != nil { return api.UserLimitsGET{}, r.StatusCode, err } + if r.StatusCode != http.StatusOK { + return api.UserLimitsGET{}, r.StatusCode, errors.New(string(b)) + } + var resp api.UserLimitsGET + err = json.Unmarshal(b, &resp) + if err != nil { + return api.UserLimitsGET{}, 0, errors.AddContext(err, "failed to marshal the body JSON") + } + return resp, r.StatusCode, nil +} + +// UserLimitsSkylink performs a `GET /user/limits/:skylink` request. +func (at *AccountsTester) UserLimitsSkylink(sl string, params url.Values, headers map[string]string) (api.UserLimitsGET, int, error) { + if !database.ValidSkylinkHash(sl) { + return api.UserLimitsGET{}, 0, database.ErrInvalidSkylink + } + r, b, err := at.request(http.MethodGet, "/user/limits/"+sl, params, nil, headers) + if err != nil { + return api.UserLimitsGET{}, r.StatusCode, err + } + if r.StatusCode != http.StatusOK { + return api.UserLimitsGET{}, r.StatusCode, errors.New(string(b)) + } var resp api.UserLimitsGET err = json.Unmarshal(b, &resp) if err != nil { From f491de47361a0e4b027eb7406b0f8ffbdb8490d3 Mon Sep 17 00:00:00 2001 From: Ivaylo Novakov Date: Thu, 10 Mar 2022 16:46:07 +0100 Subject: [PATCH 14/25] Tester helpers for API keys. --- test/api/apikeys_test.go | 97 ++++++++++----------------------------- test/api/handlers_test.go | 9 +--- test/tester.go | 40 ++++++++++++++++ 3 files changed, 66 insertions(+), 80 deletions(-) diff --git a/test/api/apikeys_test.go b/test/api/apikeys_test.go index 49803f7e..f2d5ba6e 100644 --- a/test/api/apikeys_test.go +++ b/test/api/apikeys_test.go @@ -24,91 +24,67 @@ func testPrivateAPIKeysFlow(t *testing.T, at *test.AccountsTester) { } at.SetCookie(test.ExtractCookie(r)) - aks := make([]database.APIKeyRecord, 0) - // List all API keys this user has. Expect the list to be empty. - r, body, err = at.Get("/user/apikeys", nil) + aks, _, err := at.UserAPIKeysLIST() if err != nil { t.Fatal(err, string(body)) } - err = json.Unmarshal(body, &aks) - if err != nil { - t.Fatal(err) - } if len(aks) > 0 { t.Fatalf("Expected an empty list of API keys, got %+v.", aks) } // Create a new API key. - r, body, err = at.Post("/user/apikeys", nil, nil) + ak1, _, err := at.UserAPIKeysPOST(api.APIKeyPOST{}) if err != nil { t.Fatal(err, string(body)) } - var ak1 database.APIKeyRecord - err = json.Unmarshal(body, &ak1) - if err != nil { - t.Fatal(err) - } // Make sure the API key is private. if ak1.Public { t.Fatal("Expected the API key to be private.") } // Create another API key. - r, body, err = at.Post("/user/apikeys", nil, nil) + ak2, _, err := at.UserAPIKeysPOST(api.APIKeyPOST{}) if err != nil { t.Fatal(err, string(body)) } - var ak2 database.APIKeyRecord - err = json.Unmarshal(body, &ak2) - if err != nil { - t.Fatal(err) - } // List all API keys this user has. Expect to find both keys we created. - r, body, err = at.Get("/user/apikeys", nil) + aks, _, err = at.UserAPIKeysLIST() if err != nil { t.Fatal(err, string(body)) } - err = json.Unmarshal(body, &aks) - if err != nil { - t.Fatal(err) - } if len(aks) != 2 { t.Fatalf("Expected two API keys, got %+v.", aks) } - if ak1.Key != aks[0].Key && ak1.Key != aks[1].Key { - t.Fatalf("Missing key '%s'! Set: %+v", ak1.Key, aks) + if ak1.ID.Hex() != aks[0].ID.Hex() && ak1.ID.Hex() != aks[1].ID.Hex() { + t.Fatalf("Missing key '%s'! Set: %+v", ak1.ID.Hex(), aks) } - if ak2.Key != aks[0].Key && ak2.Key != aks[1].Key { - t.Fatalf("Missing key '%s'! Set: %+v", ak2.Key, aks) + if ak2.ID.Hex() != aks[0].ID.Hex() && ak2.ID.Hex() != aks[1].ID.Hex() { + t.Fatalf("Missing key '%s'! Set: %+v", ak2.ID.Hex(), aks) } // Delete an API key. - r, body, err = at.Delete("/user/apikeys/"+ak1.ID.Hex(), nil) - if err != nil { - t.Fatal(err, string(body)) + status, err := at.UserAPIKeysDELETE(ak1.ID) + if err != nil || status != http.StatusNoContent { + t.Fatal(err, status) } // List all API keys this user has. Expect to find only the second one. - r, body, err = at.Get("/user/apikeys", nil) + aks, _, err = at.UserAPIKeysLIST() if err != nil { t.Fatal(err, string(body)) } - err = json.Unmarshal(body, &aks) - if err != nil { - t.Fatal(err) - } if len(aks) != 1 { t.Fatalf("Expected one API key, got %+v.", aks) } - if ak2.Key != aks[0].Key { - t.Fatalf("Missing key '%s'! Set: %+v", ak2.Key, aks) + if ak2.ID.Hex() != aks[0].ID.Hex() { + t.Fatalf("Missing key '%s'! Set: %+v", ak2.ID.Hex(), aks) } // Try to delete the same key again. Expect a Bad Request. - r, body, err = at.Delete("/user/apikeys/"+ak1.ID.Hex(), nil) - if r.StatusCode != http.StatusBadRequest { - t.Fatalf("Expected status 400, got %d.", r.StatusCode) + status, _ = at.UserAPIKeysDELETE(ak1.ID) + if status != http.StatusBadRequest { + t.Fatalf("Expected status 400, got %d.", status) } } @@ -172,14 +148,8 @@ func testPublicAPIKeysFlow(t *testing.T, at *test.AccountsTester) { sl1 := "AQAh2vxStoSJ_M9tWcTgqebUWerCAbpMfn9xxa9E29UOuw" sl2 := "AADDE7_5MJyl1DKyfbuQMY_XBOBC9bR7idiU6isp6LXxEw" - aks := make([]database.APIKeyRecord, 0) - // List all API keys this user has. Expect the list to be empty. - r, body, err = at.Get("/user/apikeys", nil) - if err != nil { - t.Fatal(err, string(body)) - } - err = json.Unmarshal(body, &aks) + aks, _, err := at.UserAPIKeysLIST() if err != nil { t.Fatal(err) } @@ -196,11 +166,7 @@ func testPublicAPIKeysFlow(t *testing.T, at *test.AccountsTester) { t.Fatal(err) } // List all API keys again. Expect to find a key. - r, body, err = at.Get("/user/apikeys", nil) - if err != nil { - t.Fatal(err, string(body)) - } - err = json.Unmarshal(body, &aks) + aks, _, err = at.UserAPIKeysLIST() if err != nil { t.Fatal(err) } @@ -216,15 +182,10 @@ func testPublicAPIKeysFlow(t *testing.T, at *test.AccountsTester) { t.Fatal(err) } // Get the key and verify the change. - r, body, err = at.Get("/user/apikeys/"+akr.ID.Hex(), nil) + akr1, _, err := at.UserAPIKeysGET(akr.ID) if err != nil { t.Fatal(err, string(body)) } - var akr1 database.APIKeyRecord - err = json.Unmarshal(body, &akr1) - if err != nil { - t.Fatal(err) - } if akr1.Skylinks[0] != sl2 { t.Fatal("Unexpected skylinks list", aks[0].Skylinks) } @@ -238,14 +199,10 @@ func testPublicAPIKeysFlow(t *testing.T, at *test.AccountsTester) { t.Fatal(err) } // List and verify the change. - r, body, err = at.Get("/user/apikeys", nil) + aks, _, err = at.UserAPIKeysLIST() if err != nil { t.Fatal(err, string(body)) } - err = json.Unmarshal(body, &aks) - if err != nil { - t.Fatal(err) - } if len(aks) != 1 { t.Fatalf("Expected one API key, got %d.", len(aks)) } @@ -253,19 +210,15 @@ func testPublicAPIKeysFlow(t *testing.T, at *test.AccountsTester) { t.Fatal("Unexpected skylinks list", aks[0].Skylinks) } // Delete a public API key. - r, body, err = at.Delete("/user/apikeys/"+akr.ID.Hex(), nil) - if err != nil { - t.Fatal(err, string(body)) + status, err := at.UserAPIKeysDELETE(akr.ID) + if err != nil || status != http.StatusNoContent { + t.Fatal(err, status) } // List and verify the change. - r, body, err = at.Get("/user/apikeys", nil) + aks, _, err = at.UserAPIKeysLIST() if err != nil { t.Fatal(err, string(body)) } - err = json.Unmarshal(body, &aks) - if err != nil { - t.Fatal(err) - } if len(aks) != 0 { t.Fatalf("Expected no API keys, got %d.", len(aks)) } diff --git a/test/api/handlers_test.go b/test/api/handlers_test.go index 9febfe50..4a4a7b44 100644 --- a/test/api/handlers_test.go +++ b/test/api/handlers_test.go @@ -436,14 +436,7 @@ func testUserLimits(t *testing.T, at *test.AccountsTester) { defer at.ClearCredentials() // Create an API key for this user. - _, b, err := at.Post("/user/apikeys", nil, nil) - if err != nil { - t.Fatal(err) - } - var akr struct { - Key database.APIKey `json:"key"` - } - err = json.Unmarshal(b, &akr) + akr, _, err := at.UserAPIKeysPOST(api.APIKeyPOST{}) if err != nil { t.Fatal(err) } diff --git a/test/tester.go b/test/tester.go index b96c5568..4b843e7f 100644 --- a/test/tester.go +++ b/test/tester.go @@ -318,6 +318,46 @@ func (at *AccountsTester) HealthGet() (api.HealthGET, int, error) { return resp, r.StatusCode, nil } +// UserAPIKeysDELETE performs a `DELETE /user/apikeys/:id` request. +func (at *AccountsTester) UserAPIKeysDELETE(id primitive.ObjectID) (int, error) { + r, _, err := at.request(http.MethodDelete, "/user/apikeys/"+id.Hex(), nil, nil, nil) + return r.StatusCode, err +} + +// UserAPIKeysGET performs a `GET /user/apikeys/:id` request. +func (at *AccountsTester) UserAPIKeysGET(id primitive.ObjectID) (api.APIKeyResponse, int, error) { + r, b, err := at.request(http.MethodGet, "/user/apikeys/"+id.Hex(), nil, nil, nil) + if err != nil { + return api.APIKeyResponse{}, r.StatusCode, err + } + if r.StatusCode != http.StatusOK { + return api.APIKeyResponse{}, r.StatusCode, errors.New(string(b)) + } + var result api.APIKeyResponse + err = json.Unmarshal(b, &result) + if err != nil { + return api.APIKeyResponse{}, 0, errors.AddContext(err, "failed to parse response") + } + return result, r.StatusCode, nil +} + +// UserAPIKeysLIST performs a `GET /user/apikeys` request. +func (at *AccountsTester) UserAPIKeysLIST() ([]api.APIKeyResponse, int, error) { + r, b, err := at.request(http.MethodGet, "/user/apikeys", nil, nil, nil) + if err != nil { + return nil, r.StatusCode, err + } + if r.StatusCode != http.StatusOK { + return nil, r.StatusCode, errors.New(string(b)) + } + result := make([]api.APIKeyResponse, 0) + err = json.Unmarshal(b, &result) + if err != nil { + return nil, 0, errors.AddContext(err, "failed to parse response") + } + return result, r.StatusCode, nil +} + // UserAPIKeysPOST performs a `POST /user/apikeys` request. func (at *AccountsTester) UserAPIKeysPOST(body api.APIKeyPOST) (api.APIKeyResponseWithKey, int, error) { bb, err := json.Marshal(body) From 87b8eefbdc99b6117105c8182de39fe47a5b2989 Mon Sep 17 00:00:00 2001 From: Ivaylo Novakov Date: Fri, 11 Mar 2022 13:39:17 +0100 Subject: [PATCH 15/25] Add the rest of the API key integration tests. --- database/apikeys.go | 20 +++- database/apikeys_test.go | 59 ++++++++++- test/api/auth_test.go | 1 - test/database/apikeys_test.go | 150 ++++++++++++++++++++++++++++ test/database/challenge_test.go | 4 + test/database/configuration_test.go | 2 + test/database/upload_test.go | 2 + test/database/user_test.go | 24 +++++ 8 files changed, 256 insertions(+), 6 deletions(-) delete mode 100644 test/api/auth_test.go diff --git a/database/apikeys.go b/database/apikeys.go index f31263ef..0dc2b0af 100644 --- a/database/apikeys.go +++ b/database/apikeys.go @@ -253,8 +253,14 @@ func (db *DB) APIKeyUpdate(ctx context.Context, user User, akID primitive.Object opts := options.UpdateOptions{ Upsert: &False, } - _, err := db.staticAPIKeys.UpdateOne(ctx, filter, update, &opts) - return err + ur, err := db.staticAPIKeys.UpdateOne(ctx, filter, update, &opts) + if err != nil { + return err + } + if ur.ModifiedCount == 0 { + return errors.New("public API key not found, no keys updated") + } + return nil } // APIKeyPatch updates an existing API key. This works by adding and removing @@ -282,10 +288,13 @@ func (db *DB) APIKeyPatch(ctx context.Context, user User, akID primitive.ObjectI opts := options.UpdateOptions{ Upsert: &False, } - _, err := db.staticAPIKeys.UpdateOne(ctx, filter, update, &opts) + ur, err := db.staticAPIKeys.UpdateOne(ctx, filter, update, &opts) if err != nil { return err } + if ur.ModifiedCount == 0 { + return errors.New("public API key not found, no keys updated") + } } // Then, remove all skylinks that need to be removed. if len(removeSkylinks) > 0 { @@ -295,10 +304,13 @@ func (db *DB) APIKeyPatch(ctx context.Context, user User, akID primitive.ObjectI opts := options.UpdateOptions{ Upsert: &False, } - _, err := db.staticAPIKeys.UpdateOne(ctx, filter, update, &opts) + ur, err := db.staticAPIKeys.UpdateOne(ctx, filter, update, &opts) if err != nil { return err } + if ur.ModifiedCount == 0 { + return errors.New("public API key not found, no keys updated") + } } return nil } diff --git a/database/apikeys_test.go b/database/apikeys_test.go index e2fc9032..6e24675a 100644 --- a/database/apikeys_test.go +++ b/database/apikeys_test.go @@ -1,6 +1,8 @@ package database -import "testing" +import ( + "testing" +) // TestNewAPIKeyFromString validates that NewAPIKeyFromString properly handles // valid API keys, upper case or lower case. @@ -27,3 +29,58 @@ func TestNewAPIKeyFromString(t *testing.T) { } } } + +// TestCoversSkylink ensures that CoversSkylink works as expected. +func TestCoversSkylink(t *testing.T) { + sl1 := "6TAOK0RVVKKK25PIA33FHDBD1G04DLO015DAAD6OM2J33KCD5CL0" + sl2 := "7TAOK0RVVKKK25PIA33FHDBD1G04DLO015DAAD6OM2J33KCD5CL0" + sl3 := "8TAOK0RVVKKK25PIA33FHDBD1G04DLO015DAAD6OM2J33KCD5CL0" + akr1 := APIKeyRecord{ + Public: false, + Key: NewAPIKey(), + } + akr2 := APIKeyRecord{ + Public: true, + Key: NewAPIKey(), + Skylinks: []string{sl1, sl2}, + } + + tests := []struct { + name string + key APIKeyRecord + skylinkToCheck string + expectedCovered bool + }{ + { + name: "general API key", + key: akr1, + skylinkToCheck: sl3, + expectedCovered: true, + }, + { + name: "public API key 1", + key: akr2, + skylinkToCheck: sl1, + expectedCovered: true, + }, + { + name: "public API key 2", + key: akr2, + skylinkToCheck: sl2, + expectedCovered: true, + }, + { + name: "public API key 3", + key: akr2, + skylinkToCheck: sl3, + expectedCovered: false, + }, + } + + for _, tt := range tests { + covered := tt.key.CoversSkylink(tt.skylinkToCheck) + if covered != tt.expectedCovered { + t.Errorf("Unexpected result for test %s", tt.name) + } + } +} diff --git a/test/api/auth_test.go b/test/api/auth_test.go deleted file mode 100644 index 778f64ec..00000000 --- a/test/api/auth_test.go +++ /dev/null @@ -1 +0,0 @@ -package api diff --git a/test/database/apikeys_test.go b/test/database/apikeys_test.go index 636bab89..5b10151f 100644 --- a/test/database/apikeys_test.go +++ b/test/database/apikeys_test.go @@ -1 +1,151 @@ package database + +import ( + "context" + "testing" + + "github.com/SkynetLabs/skynet-accounts/database" + "github.com/SkynetLabs/skynet-accounts/test" +) + +// TestAPIKeys ensures the DB operations with API keys work as expected. +func TestAPIKeys(t *testing.T) { + t.Parallel() + + ctx := context.Background() + dbName := test.DBNameForTest(t.Name()) + db, err := database.NewCustomDB(ctx, dbName, test.DBTestCredentials(), nil) + if err != nil { + t.Fatal(err) + } + u, err := db.UserCreate(ctx, "", "", t.Name(), database.TierFree) + if err != nil { + t.Fatal("Unexpected error", err) + } + sl1 := test.RandomSkylink() + sl2 := test.RandomSkylink() + + // Create a private API key. + akr1, err := db.APIKeyCreate(ctx, *u, false, nil) + if err != nil { + t.Fatal(err) + } + // Create a private API key with skylinks. Expect to fail. + _, err = db.APIKeyCreate(ctx, *u, false, []string{sl1}) + if err == nil { + t.Fatal("Managed to create a private API key with skylinks.") + } + // Create a public API key + akr2, err := db.APIKeyCreate(ctx, *u, true, []string{sl1}) + if err != nil { + t.Fatal(err) + } + // Create a public API key without any skylinks. + akr3, err := db.APIKeyCreate(ctx, *u, true, nil) + if err != nil { + t.Fatal(err) + } + // Get an API key. + akr1a, err := db.APIKeyGet(ctx, akr1.ID) + if err != nil { + t.Fatal(err) + } + if akr1a.ID.Hex() != akr1.ID.Hex() { + t.Fatal("Did not get the correct API key!") + } + // Get an API key by key. + akr1a, err = db.APIKeyByKey(ctx, akr1.Key.String()) + if err != nil { + t.Fatal(err) + } + if akr1a.ID.Hex() != akr1.ID.Hex() { + t.Fatal("Did not get the correct API key by key!") + } + // List API keys. + akrs, err := db.APIKeyList(ctx, *u) + if err != nil { + t.Fatal(err) + } + if len(akrs) != 3 { + t.Fatalf("Expected %d API keys, got %d", 3, len(akrs)) + } + // Check if all keys we expect to exist actually exist. + found := 0 + for _, akr := range []*database.APIKeyRecord{akr1, akr2, akr3} { + for _, akrFound := range akrs { + if akrFound.ID.Hex() == akr.ID.Hex() { + found++ + } + } + } + if found != 3 { + t.Fatalf("Expected to find %d API keys we expect, found %d", 3, found) + } + + // Try to update a general API key. Expect to fail. + err = db.APIKeyUpdate(ctx, *u, akr1.ID, []string{sl1}) + if err == nil { + t.Fatal("Expected to be unable to update general API key.") + } + err = db.APIKeyPatch(ctx, *u, akr1.ID, []string{sl1}, nil) + if err == nil { + t.Fatal("Expected to be unable to patch general API key.") + } + // Update a public API key. + err = db.APIKeyUpdate(ctx, *u, akr2.ID, []string{sl1, sl2}) + if err != nil { + t.Fatal(err) + } + // Verify. + akr2a, err := db.APIKeyGet(ctx, akr2.ID) + if err != nil { + t.Fatal(err) + } + if !akr2a.CoversSkylink(sl1) || !akr2a.CoversSkylink(sl2) { + t.Fatal("Expected the API to cover both skylinks.") + } + // Patch a public API key. + err = db.APIKeyPatch(ctx, *u, akr2.ID, nil, []string{sl2}) + if err != nil { + t.Fatal(err) + } + // Verify. + akr2b, err := db.APIKeyGet(ctx, akr2.ID) + if err != nil { + t.Fatal(err) + } + if !akr2b.CoversSkylink(sl1) || akr2b.CoversSkylink(sl2) { + t.Fatal("Expected the API to cover one but not the other skylink.") + } + + // Delete a general API key. + err = db.APIKeyDelete(ctx, *u, akr1.ID) + if err != nil { + t.Fatal(err) + } + // Verify. + akrs, err = db.APIKeyList(ctx, *u) + if err != nil { + t.Fatal(err) + } + for _, akr := range akrs { + if akr.ID.Hex() == akr1.ID.Hex() { + t.Fatal("Expected the API key to be gone but it's not.") + } + } + // Delete a public API key. + err = db.APIKeyDelete(ctx, *u, akr2.ID) + if err != nil { + t.Fatal(err) + } + // Verify. + akrs, err = db.APIKeyList(ctx, *u) + if err != nil { + t.Fatal(err) + } + for _, akr := range akrs { + if akr.ID.Hex() == akr2.ID.Hex() { + t.Fatal("Expected the API key to be gone but it's not.") + } + } +} diff --git a/test/database/challenge_test.go b/test/database/challenge_test.go index b1dd3b70..c6ac508a 100644 --- a/test/database/challenge_test.go +++ b/test/database/challenge_test.go @@ -18,6 +18,8 @@ import ( // TestValidateChallengeResponse is a unit test using a database. func TestValidateChallengeResponse(t *testing.T) { + t.Parallel() + ctx := context.Background() dbName := test.DBNameForTest(t.Name()) db, err := database.NewCustomDB(ctx, dbName, test.DBTestCredentials(), nil) @@ -126,6 +128,8 @@ func TestValidateChallengeResponse(t *testing.T) { // TestUnconfirmedUserUpdate ensures the entire flow for unconfirmed user // updates works as expected. func TestUnconfirmedUserUpdate(t *testing.T) { + t.Parallel() + ctx := context.Background() dbName := test.DBNameForTest(t.Name()) db, err := database.NewCustomDB(ctx, dbName, test.DBTestCredentials(), nil) diff --git a/test/database/configuration_test.go b/test/database/configuration_test.go index 14fd3c42..b9de09bd 100644 --- a/test/database/configuration_test.go +++ b/test/database/configuration_test.go @@ -13,6 +13,8 @@ import ( // TestConfiguration ensures we can correctly read and write from/to the // configuration DB table. func TestConfiguration(t *testing.T) { + t.Parallel() + ctx := context.Background() dbName := test.DBNameForTest(t.Name()) db, err := database.NewCustomDB(ctx, dbName, test.DBTestCredentials(), nil) diff --git a/test/database/upload_test.go b/test/database/upload_test.go index 4a6eb26d..8d839846 100644 --- a/test/database/upload_test.go +++ b/test/database/upload_test.go @@ -14,6 +14,8 @@ import ( // TestUploadsByUser ensures UploadsByUser returns the correct uploads, // in the correct order, with the correct sized and so on. func TestUploadsByUser(t *testing.T) { + t.Parallel() + ctx := context.Background() dbName := test.DBNameForTest(t.Name()) db, err := database.NewCustomDB(ctx, dbName, test.DBTestCredentials(), nil) diff --git a/test/database/user_test.go b/test/database/user_test.go index f8caa4d6..b6143894 100644 --- a/test/database/user_test.go +++ b/test/database/user_test.go @@ -19,6 +19,8 @@ import ( // TestUserByEmail ensures UserByEmail works as expected. // This method also tests UserCreate. func TestUserByEmail(t *testing.T) { + t.Parallel() + ctx := context.Background() dbName := test.DBNameForTest(t.Name()) db, err := database.NewCustomDB(ctx, dbName, test.DBTestCredentials(), nil) @@ -62,6 +64,8 @@ func TestUserByEmail(t *testing.T) { // TestUserByID ensures UserByID works as expected. func TestUserByID(t *testing.T) { + t.Parallel() + ctx := context.Background() dbName := test.DBNameForTest(t.Name()) db, err := database.NewCustomDB(ctx, dbName, test.DBTestCredentials(), nil) @@ -101,6 +105,8 @@ func TestUserByID(t *testing.T) { // TestUserByPubKey makes sure UserByPubKey functions correctly, both with a // single and multiple pubkeys attached to a user. func TestUserByPubKey(t *testing.T) { + t.Parallel() + ctx := context.Background() name := test.DBNameForTest(t.Name()) db, err := database.NewCustomDB(ctx, name, test.DBTestCredentials(), nil) @@ -152,6 +158,8 @@ func TestUserByPubKey(t *testing.T) { // TestUserByStripeID ensures UserByStripeID works as expected. // This method also tests UserCreate and UserSetStripeID. func TestUserByStripeID(t *testing.T) { + t.Parallel() + ctx := context.Background() dbName := test.DBNameForTest(t.Name()) db, err := database.NewCustomDB(ctx, dbName, test.DBTestCredentials(), nil) @@ -196,6 +204,8 @@ func TestUserByStripeID(t *testing.T) { // TestUserBySub ensures UserBySub works as expected. // This method also tests UserCreate. func TestUserBySub(t *testing.T) { + t.Parallel() + ctx := context.Background() dbName := test.DBNameForTest(t.Name()) db, err := database.NewCustomDB(ctx, dbName, test.DBTestCredentials(), nil) @@ -238,6 +248,8 @@ func TestUserBySub(t *testing.T) { // TestUserConfirmEmail ensures that email confirmation works as expected, // including resecting the expiration of tokens. func TestUserConfirmEmail(t *testing.T) { + t.Parallel() + ctx := context.Background() dbName := test.DBNameForTest(t.Name()) db, err := database.NewCustomDB(ctx, dbName, test.DBTestCredentials(), nil) @@ -276,6 +288,8 @@ func TestUserConfirmEmail(t *testing.T) { // TestUserCreate ensures UserCreate works as expected. func TestUserCreate(t *testing.T) { + t.Parallel() + ctx := context.Background() dbName := test.DBNameForTest(t.Name()) db, err := database.NewCustomDB(ctx, dbName, test.DBTestCredentials(), nil) @@ -325,6 +339,8 @@ func TestUserCreate(t *testing.T) { // TestUserDelete ensures UserDelete works as expected. func TestUserDelete(t *testing.T) { + t.Parallel() + ctx := context.Background() dbName := test.DBNameForTest(t.Name()) db, err := database.NewCustomDB(ctx, dbName, test.DBTestCredentials(), nil) @@ -363,6 +379,8 @@ func TestUserDelete(t *testing.T) { // TestUserSave ensures that UserSave works as expected. func TestUserSave(t *testing.T) { + t.Parallel() + ctx := context.Background() dbName := test.DBNameForTest(t.Name()) db, err := database.NewCustomDB(ctx, dbName, test.DBTestCredentials(), nil) @@ -409,6 +427,8 @@ func TestUserSave(t *testing.T) { // TestUserSetStripeID ensures that UserSetStripeID works as expected. func TestUserSetStripeID(t *testing.T) { + t.Parallel() + ctx := context.Background() dbName := test.DBNameForTest(t.Name()) db, err := database.NewCustomDB(ctx, dbName, test.DBTestCredentials(), nil) @@ -440,6 +460,8 @@ func TestUserSetStripeID(t *testing.T) { // TestUserSetTier ensures that UserSetTier works as expected. func TestUserSetTier(t *testing.T) { + t.Parallel() + ctx := context.Background() dbName := test.DBNameForTest(t.Name()) db, err := database.NewCustomDB(ctx, dbName, test.DBTestCredentials(), nil) @@ -469,6 +491,8 @@ func TestUserSetTier(t *testing.T) { // TestUserStats ensures we report accurate statistics for users. func TestUserStats(t *testing.T) { + t.Parallel() + ctx := context.Background() dbName := test.DBNameForTest(t.Name()) db, err := database.NewCustomDB(ctx, dbName, test.DBTestCredentials(), nil) From b32108d47e183018b0dcb5e4e440b22cb86bd917 Mon Sep 17 00:00:00 2001 From: Ivaylo Novakov Date: Fri, 11 Mar 2022 13:44:57 +0100 Subject: [PATCH 16/25] Clean up. --- api/apikeys.go | 4 ---- 1 file changed, 4 deletions(-) diff --git a/api/apikeys.go b/api/apikeys.go index f14ef010..fc986ca3 100644 --- a/api/apikeys.go +++ b/api/apikeys.go @@ -29,7 +29,6 @@ type ( Remove []string } // APIKeyResponse is an API DTO which mirrors database.APIKey. - // TODO Should we reveal the Key each time for public keys? APIKeyResponse struct { ID primitive.ObjectID `json:"id"` UserID primitive.ObjectID `json:"-"` @@ -41,7 +40,6 @@ type ( // APIKeyResponseWithKey is an API DTO which mirrors database.APIKey but // also reveals the value of the Key field. This should only be used on key // creation. - // TODO Should we reveal the Key each time for public keys? APIKeyResponseWithKey struct { APIKeyResponse Key database.APIKey `json:"key"` @@ -62,7 +60,6 @@ func (akp APIKeyPOST) Valid() bool { } // FromAPIKey populates the struct's fields from the given API key. -// TODO This might be more convenient as a constructor. func (rwk *APIKeyResponse) FromAPIKey(ak database.APIKeyRecord) { rwk.ID = ak.ID rwk.UserID = ak.UserID @@ -73,7 +70,6 @@ func (rwk *APIKeyResponse) FromAPIKey(ak database.APIKeyRecord) { } // FromAPIKey populates the struct's fields from the given API key. -// TODO This might be more convenient as a constructor. func (rwk *APIKeyResponseWithKey) FromAPIKey(ak database.APIKeyRecord) { rwk.ID = ak.ID rwk.UserID = ak.UserID From fb443e6da7a2e43782f6181478277137bc69a9db Mon Sep 17 00:00:00 2001 From: Ivaylo Novakov Date: Tue, 15 Mar 2022 14:01:54 +0100 Subject: [PATCH 17/25] APIKeyPOST.Validate() returns descriptive errors. --- api/apikeys.go | 18 +++++++++++++----- 1 file changed, 13 insertions(+), 5 deletions(-) diff --git a/api/apikeys.go b/api/apikeys.go index fc986ca3..35b7d14b 100644 --- a/api/apikeys.go +++ b/api/apikeys.go @@ -46,17 +46,21 @@ type ( } ) -// Valid checks if the request and its parts are valid. -func (akp APIKeyPOST) Valid() bool { +// Validate checks if the request and its parts are valid. +func (akp APIKeyPOST) Validate() error { if !akp.Public && len(akp.Skylinks) > 0 { - return false + return errors.New("public API keys cannot refer to skylinlks") } + errs := make([]error, 0) for _, s := range akp.Skylinks { if !database.ValidSkylinkHash(s) { - return false + errs = append(errs, errors.New("invalid skylink:"+s)) } } - return true + if len(errs) > 0 { + return errors.Compose(errs...) + } + return nil } // FromAPIKey populates the struct's fields from the given API key. @@ -87,6 +91,10 @@ func (api *API) userAPIKeyPOST(u *database.User, w http.ResponseWriter, req *htt api.WriteError(w, err, http.StatusBadRequest) return } + if err := body.Validate(); err != nil { + api.WriteError(w, err, http.StatusBadRequest) + return + } ak, err := api.staticDB.APIKeyCreate(req.Context(), *u, body.Public, body.Skylinks) if errors.Contains(err, database.ErrMaxNumAPIKeysExceeded) { err = errors.AddContext(err, "the maximum number of API keys a user can create is "+strconv.Itoa(database.MaxNumAPIKeysPerUser)) From c0100d53c69971fb95327d1fe32417716dbc65e5 Mon Sep 17 00:00:00 2001 From: Ivaylo Novakov Date: Tue, 15 Mar 2022 16:43:12 +0100 Subject: [PATCH 18/25] Address PR comments. --- api/apikeys.go | 54 +++++++++++++++++++++------------------- api/handlers.go | 2 +- database/apikeys.go | 6 ++--- test/api/apikeys_test.go | 2 +- 4 files changed, 33 insertions(+), 31 deletions(-) diff --git a/api/apikeys.go b/api/apikeys.go index 35b7d14b..317cb4e6 100644 --- a/api/apikeys.go +++ b/api/apikeys.go @@ -63,24 +63,32 @@ func (akp APIKeyPOST) Validate() error { return nil } -// FromAPIKey populates the struct's fields from the given API key. -func (rwk *APIKeyResponse) FromAPIKey(ak database.APIKeyRecord) { - rwk.ID = ak.ID - rwk.UserID = ak.UserID - rwk.Public = ak.Public - rwk.Key = ak.Key - rwk.Skylinks = ak.Skylinks - rwk.CreatedAt = ak.CreatedAt +// APIKeyResponseFromAPIKey creates a new APIKeyResponse from the given API key. +func APIKeyResponseFromAPIKey(ak database.APIKeyRecord) *APIKeyResponse { + return &APIKeyResponse{ + ID: ak.ID, + UserID: ak.UserID, + Public: ak.Public, + Key: ak.Key, + Skylinks: ak.Skylinks, + CreatedAt: ak.CreatedAt, + } } -// FromAPIKey populates the struct's fields from the given API key. -func (rwk *APIKeyResponseWithKey) FromAPIKey(ak database.APIKeyRecord) { - rwk.ID = ak.ID - rwk.UserID = ak.UserID - rwk.Public = ak.Public - rwk.Key = ak.Key - rwk.Skylinks = ak.Skylinks - rwk.CreatedAt = ak.CreatedAt +// APIKeyResponseWithKeyFromAPIKey creates a new APIKeyResponseWithKey from the +// given API key. +func APIKeyResponseWithKeyFromAPIKey(ak database.APIKeyRecord) *APIKeyResponseWithKey { + return &APIKeyResponseWithKey{ + APIKeyResponse: APIKeyResponse{ + ID: ak.ID, + UserID: ak.UserID, + Public: ak.Public, + Key: ak.Key, + Skylinks: ak.Skylinks, + CreatedAt: ak.CreatedAt, + }, + Key: ak.Key, + } } // userAPIKeyPOST creates a new API key for the user. @@ -105,9 +113,7 @@ func (api *API) userAPIKeyPOST(u *database.User, w http.ResponseWriter, req *htt api.WriteError(w, err, http.StatusInternalServerError) return } - var resp APIKeyResponseWithKey - resp.FromAPIKey(*ak) - api.WriteJSON(w, resp) + api.WriteJSON(w, APIKeyResponseWithKeyFromAPIKey(*ak)) } // userAPIKeyGET returns a single API key. @@ -127,9 +133,7 @@ func (api *API) userAPIKeyGET(u *database.User, w http.ResponseWriter, req *http api.WriteError(w, err, http.StatusInternalServerError) return } - var resp APIKeyResponse - resp.FromAPIKey(ak) - api.WriteJSON(w, resp) + api.WriteJSON(w, APIKeyResponseFromAPIKey(ak)) } // userAPIKeyLIST lists all API keys associated with the user. @@ -139,11 +143,9 @@ func (api *API) userAPIKeyLIST(u *database.User, w http.ResponseWriter, req *htt api.WriteError(w, err, http.StatusInternalServerError) return } - resp := make([]APIKeyResponse, 0, len(aks)) + resp := make([]*APIKeyResponse, 0, len(aks)) for _, ak := range aks { - var r APIKeyResponse - r.FromAPIKey(ak) - resp = append(resp, r) + resp = append(resp, APIKeyResponseFromAPIKey(ak)) } api.WriteJSON(w, resp) } diff --git a/api/handlers.go b/api/handlers.go index fff1547e..f5af12b5 100644 --- a/api/handlers.go +++ b/api/handlers.go @@ -484,7 +484,7 @@ func (api *API) userLimitsSkylinkGET(u *database.User, w http.ResponseWriter, re // Validate the skylink. skylink := ps.ByName("skylink") if !database.ValidSkylinkHash(skylink) { - api.staticLogger.Tracef("Invalid skylink: %s", skylink) + api.staticLogger.Tracef("Invalid skylink: '%s'", skylink) api.WriteJSON(w, respAnon) return } diff --git a/database/apikeys.go b/database/apikeys.go index 0dc2b0af..03a94730 100644 --- a/database/apikeys.go +++ b/database/apikeys.go @@ -241,7 +241,7 @@ func (db *DB) APIKeyUpdate(ctx context.Context, user User, akID primitive.Object // Validate all given skylinks. for _, s := range skylinks { if !ValidSkylinkHash(s) { - return ErrInvalidSkylink + return errors.AddContext(ErrInvalidSkylink, "offending skylink: "+s) } } filter := bson.M{ @@ -272,7 +272,7 @@ func (db *DB) APIKeyPatch(ctx context.Context, user User, akID primitive.ObjectI // Validate all given skylinks. for _, s := range append(addSkylinks, removeSkylinks...) { if !ValidSkylinkHash(s) { - return ErrInvalidSkylink + return errors.AddContext(ErrInvalidSkylink, "offending skylink: "+s) } } filter := bson.M{ @@ -283,7 +283,7 @@ func (db *DB) APIKeyPatch(ctx context.Context, user User, akID primitive.ObjectI // First, all new skylinks to the record. if len(addSkylinks) > 0 { update = bson.M{ - "$push": bson.M{"skylinks": bson.M{"$each": addSkylinks}}, + "$addToSet": bson.M{"skylinks": bson.M{"$each": addSkylinks}}, } opts := options.UpdateOptions{ Upsert: &False, diff --git a/test/api/apikeys_test.go b/test/api/apikeys_test.go index f2d5ba6e..b323587c 100644 --- a/test/api/apikeys_test.go +++ b/test/api/apikeys_test.go @@ -135,7 +135,7 @@ func testPrivateAPIKeysUsage(t *testing.T, at *test.AccountsTester) { } } -// TestPublicAPIKeyFlow validates the creation, listing, and deletion of public +// testPublicAPIKeysFlow validates the creation, listing, and deletion of public // API keys. func testPublicAPIKeysFlow(t *testing.T, at *test.AccountsTester) { name := test.DBNameForTest(t.Name()) From d79e7430d571577a6795aa93f943e3adf43ded3f Mon Sep 17 00:00:00 2001 From: Ivaylo Novakov Date: Tue, 15 Mar 2022 17:26:00 +0100 Subject: [PATCH 19/25] Move DB schema to a separate file. --- database/database.go | 111 +---------------------------------------- database/schema.go | 116 +++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 118 insertions(+), 109 deletions(-) create mode 100644 database/schema.go diff --git a/database/database.go b/database/database.go index ca71d16a..184d0d30 100644 --- a/database/database.go +++ b/database/database.go @@ -131,7 +131,7 @@ func NewCustomDB(ctx context.Context, dbName string, creds DBCredentials, logger if logger == nil { logger = &logrus.Logger{} } - err = ensureDBSchema(ctx, database, logger) + err = ensureDBSchema(ctx, database, Schema, logger) if err != nil { return nil, err } @@ -188,114 +188,7 @@ func connectionString(creds DBCredentials) string { // creates them if needed. // See https://docs.mongodb.com/manual/indexes/ // See https://docs.mongodb.com/manual/core/index-unique/ -func ensureDBSchema(ctx context.Context, db *mongo.Database, log *logrus.Logger) error { - // schema defines a mapping between a collection name and the indexes that - // must exist for that collection. - schema := map[string][]mongo.IndexModel{ - collUsers: { - { - Keys: bson.D{{"sub", 1}}, - Options: options.Index().SetName("sub_unique").SetUnique(true), - }, - }, - collSkylinks: { - { - Keys: bson.D{{"skylink", 1}}, - Options: options.Index().SetName("skylink_unique").SetUnique(true), - }, - }, - collUploads: { - { - Keys: bson.D{{"user_id", 1}}, - Options: options.Index().SetName("user_id"), - }, - { - Keys: bson.D{{"skylink_id", 1}}, - Options: options.Index().SetName("skylink_id"), - }, - }, - collDownloads: { - { - Keys: bson.D{{"user_id", 1}}, - Options: options.Index().SetName("user_id"), - }, - { - Keys: bson.D{{"skylink_id", 1}}, - Options: options.Index().SetName("skylink_id"), - }, - }, - collRegistryReads: { - { - Keys: bson.D{{"user_id", 1}}, - Options: options.Index().SetName("user_id"), - }, - }, - collRegistryWrites: { - { - Keys: bson.D{{"user_id", 1}}, - Options: options.Index().SetName("user_id"), - }, - }, - collEmails: { - { - Keys: bson.D{{"failed_attempts", 1}}, - Options: options.Index().SetName("failed_attempts"), - }, - { - Keys: bson.D{{"locked_by", 1}}, - Options: options.Index().SetName("locked_by"), - }, - { - Keys: bson.D{{"sent_at", 1}}, - Options: options.Index().SetName("sent_at"), - }, - { - Keys: bson.D{{"sent_by", 1}}, - Options: options.Index().SetName("sent_by"), - }, - }, - collChallenges: { - { - Keys: bson.D{{"challenge", 1}}, - Options: options.Index().SetName("challenge"), - }, - { - Keys: bson.D{{"type", 1}}, - Options: options.Index().SetName("type"), - }, - { - Keys: bson.D{{"expires_at", 1}}, - Options: options.Index().SetName("expires_at"), - }, - }, - collUnconfirmedUserUpdates: { - { - Keys: bson.D{{"challenge_id", 1}}, - Options: options.Index().SetName("challenge_id"), - }, - { - Keys: bson.D{{"expires_at", 1}}, - Options: options.Index().SetName("expires_at"), - }, - }, - collConfiguration: { - { - Keys: bson.D{{"key", 1}}, - Options: options.Index().SetName("key_unique").SetUnique(true), - }, - }, - collAPIKeys: { - { - Keys: bson.D{{"key", 1}}, - Options: options.Index().SetName("key_unique").SetUnique(true), - }, - { - Keys: bson.D{{"user_id", 1}}, - Options: options.Index().SetName("user_id"), - }, - }, - } - +func ensureDBSchema(ctx context.Context, db *mongo.Database, schema map[string][]mongo.IndexModel, log *logrus.Logger) error { for collName, models := range schema { coll, err := ensureCollection(ctx, db, collName) if err != nil { diff --git a/database/schema.go b/database/schema.go new file mode 100644 index 00000000..34138e35 --- /dev/null +++ b/database/schema.go @@ -0,0 +1,116 @@ +package database + +import ( + "go.mongodb.org/mongo-driver/bson" + "go.mongodb.org/mongo-driver/mongo" + "go.mongodb.org/mongo-driver/mongo/options" +) + +var ( + // Schema defines a mapping between a collection name and the indexes that + // must exist for that collection. + Schema = map[string][]mongo.IndexModel{ + collUsers: { + { + Keys: bson.D{{"sub", 1}}, + Options: options.Index().SetName("sub_unique").SetUnique(true), + }, + }, + collSkylinks: { + { + Keys: bson.D{{"skylink", 1}}, + Options: options.Index().SetName("skylink_unique").SetUnique(true), + }, + }, + collUploads: { + { + Keys: bson.D{{"user_id", 1}}, + Options: options.Index().SetName("user_id"), + }, + { + Keys: bson.D{{"skylink_id", 1}}, + Options: options.Index().SetName("skylink_id"), + }, + }, + collDownloads: { + { + Keys: bson.D{{"user_id", 1}}, + Options: options.Index().SetName("user_id"), + }, + { + Keys: bson.D{{"skylink_id", 1}}, + Options: options.Index().SetName("skylink_id"), + }, + }, + collRegistryReads: { + { + Keys: bson.D{{"user_id", 1}}, + Options: options.Index().SetName("user_id"), + }, + }, + collRegistryWrites: { + { + Keys: bson.D{{"user_id", 1}}, + Options: options.Index().SetName("user_id"), + }, + }, + collEmails: { + { + Keys: bson.D{{"failed_attempts", 1}}, + Options: options.Index().SetName("failed_attempts"), + }, + { + Keys: bson.D{{"locked_by", 1}}, + Options: options.Index().SetName("locked_by"), + }, + { + Keys: bson.D{{"sent_at", 1}}, + Options: options.Index().SetName("sent_at"), + }, + { + Keys: bson.D{{"sent_by", 1}}, + Options: options.Index().SetName("sent_by"), + }, + }, + collChallenges: { + { + Keys: bson.D{{"challenge", 1}}, + Options: options.Index().SetName("challenge"), + }, + { + Keys: bson.D{{"type", 1}}, + Options: options.Index().SetName("type"), + }, + { + Keys: bson.D{{"expires_at", 1}}, + Options: options.Index().SetName("expires_at"), + }, + }, + collUnconfirmedUserUpdates: { + { + Keys: bson.D{{"challenge_id", 1}}, + Options: options.Index().SetName("challenge_id"), + }, + { + Keys: bson.D{{"expires_at", 1}}, + Options: options.Index().SetName("expires_at"), + }, + }, + collConfiguration: { + { + Keys: bson.D{{"key", 1}}, + Options: options.Index().SetName("key_unique").SetUnique(true), + }, + }, + collAPIKeys: { + { + Keys: bson.D{{"key", 1}}, + Options: options.Index().SetName("key_unique").SetUnique(true), + }, + { + Keys: bson.D{{"user_id", 1}}, + Options: options.Index().SetName("user_id"), + }, + }, + } +) From f0eaaace2d7ef6eb8918e83f18798869a227ff7d Mon Sep 17 00:00:00 2001 From: Ivaylo Novakov Date: Wed, 16 Mar 2022 11:34:30 +0100 Subject: [PATCH 20/25] Fix a nullpointer in tester. --- test/tester.go | 35 ++++++++++++++++++----------------- 1 file changed, 18 insertions(+), 17 deletions(-) diff --git a/test/tester.go b/test/tester.go index 4b843e7f..e2007c3d 100644 --- a/test/tester.go +++ b/test/tester.go @@ -58,6 +58,7 @@ func ExtractCookie(r *http.Response) *http.Cookie { func NewAccountsTester(dbName string) (*AccountsTester, error) { ctx := context.Background() logger := logrus.New() + logger.Out = ioutil.Discard // Initialise the environment. jwt.PortalName = testPortalAddr @@ -116,8 +117,8 @@ func NewAccountsTester(dbName string) (*AccountsTester, error) { } // Wait for the accounts tester to be fully ready. err = build.Retry(50, time.Millisecond, func() error { - _, _, err = at.HealthGet() - return err + _, _, e := at.HealthGet() + return e }) if err != nil { return nil, errors.AddContext(err, "failed to start accounts tester in the given time") @@ -169,14 +170,14 @@ func (at *AccountsTester) SetToken(t string) { // Get executes a GET request against the test service. // // NOTE: The Body of the returned response is already read and closed. -func (at *AccountsTester) Get(endpoint string, params url.Values) (r *http.Response, body []byte, err error) { +func (at *AccountsTester) Get(endpoint string, params url.Values) (*http.Response, []byte, error) { return at.request(http.MethodGet, endpoint, params, nil, nil) } // Delete executes a DELETE request against the test service. // // NOTE: The Body of the returned response is already read and closed. -func (at *AccountsTester) Delete(endpoint string, params url.Values) (r *http.Response, body []byte, err error) { +func (at *AccountsTester) Delete(endpoint string, params url.Values) (*http.Response, []byte, error) { return at.request(http.MethodDelete, endpoint, params, nil, nil) } @@ -184,7 +185,7 @@ func (at *AccountsTester) Delete(endpoint string, params url.Values) (r *http.Re // // NOTE: The Body of the returned response is already read and closed. // TODO Remove the url.Values in favour of a simple map. -func (at *AccountsTester) Post(endpoint string, params url.Values, bodyParams url.Values) (r *http.Response, body []byte, err error) { +func (at *AccountsTester) Post(endpoint string, params url.Values, bodyParams url.Values) (*http.Response, []byte, error) { if params == nil { params = url.Values{} } @@ -197,12 +198,12 @@ func (at *AccountsTester) Post(endpoint string, params url.Values, bodyParams ur } bodyBytes, err := json.Marshal(bodyMap) if err != nil { - return + return &http.Response{}, nil, err } serviceURL := testPortalAddr + ":" + testPortalPort + endpoint + "?" + params.Encode() req, err := http.NewRequest(http.MethodPost, serviceURL, bytes.NewBuffer(bodyBytes)) if err != nil { - return nil, nil, err + return &http.Response{}, nil, err } req.Header.Set("Content-Type", "application/json") return at.executeRequest(req) @@ -211,10 +212,10 @@ func (at *AccountsTester) Post(endpoint string, params url.Values, bodyParams ur // Put executes a PUT request against the test service. // // NOTE: The Body of the returned response is already read and closed. -func (at *AccountsTester) Put(endpoint string, params url.Values, bodyParams url.Values) (r *http.Response, body []byte, err error) { +func (at *AccountsTester) Put(endpoint string, params url.Values, bodyParams url.Values) (*http.Response, []byte, error) { b, err := json.Marshal(bodyParams) if err != nil { - return nil, nil, errors.AddContext(err, "failed to marshal the body JSON") + return &http.Response{}, nil, errors.AddContext(err, "failed to marshal the body JSON") } return at.request(http.MethodPut, endpoint, params, b, nil) } @@ -222,10 +223,10 @@ func (at *AccountsTester) Put(endpoint string, params url.Values, bodyParams url // Patch executes a PATCH request against the test service. // // NOTE: The Body of the returned response is already read and closed. -func (at *AccountsTester) Patch(endpoint string, params url.Values, bodyParams url.Values) (r *http.Response, body []byte, err error) { +func (at *AccountsTester) Patch(endpoint string, params url.Values, bodyParams url.Values) (*http.Response, []byte, error) { b, err := json.Marshal(bodyParams) if err != nil { - return nil, nil, errors.AddContext(err, "failed to marshal the body JSON") + return &http.Response{}, nil, errors.AddContext(err, "failed to marshal the body JSON") } return at.request(http.MethodPatch, endpoint, params, b, nil) } @@ -233,7 +234,7 @@ func (at *AccountsTester) Patch(endpoint string, params url.Values, bodyParams u // CreateUserPost is a helper method that creates a new user. // // NOTE: The Body of the returned response is already read and closed. -func (at *AccountsTester) CreateUserPost(emailAddr, password string) (r *http.Response, body []byte, err error) { +func (at *AccountsTester) CreateUserPost(emailAddr, password string) (*http.Response, []byte, error) { params := url.Values{} params.Set("email", emailAddr) params.Set("password", password) @@ -251,11 +252,11 @@ func (at *AccountsTester) UserPUT(email, password, stipeID string) (*http.Respon "stripeCustomerId": stipeID, }) if err != nil { - return nil, nil, errors.AddContext(err, "failed to marshal the body JSON") + return &http.Response{}, nil, errors.AddContext(err, "failed to marshal the body JSON") } req, err := http.NewRequest(http.MethodPut, serviceURL, bytes.NewBuffer(b)) if err != nil { - return nil, nil, err + return &http.Response{}, nil, err } return at.executeRequest(req) } @@ -271,7 +272,7 @@ func (at *AccountsTester) request(method string, endpoint string, queryParams ur serviceURL := testPortalAddr + ":" + testPortalPort + endpoint + "?" + queryParams.Encode() req, err := http.NewRequest(method, serviceURL, bytes.NewBuffer(body)) if err != nil { - return nil, nil, err + return &http.Response{}, nil, err } for name, val := range headers { req.Header.Set(name, val) @@ -285,7 +286,7 @@ func (at *AccountsTester) request(method string, endpoint string, queryParams ur // NOTE: The Body of the returned response is already read and closed. func (at *AccountsTester) executeRequest(req *http.Request) (*http.Response, []byte, error) { if req == nil { - return nil, nil, errors.New("invalid request") + return &http.Response{}, nil, errors.New("invalid request") } if at.APIKey != "" { req.Header.Set(api.APIKeyHeader, at.APIKey) @@ -299,7 +300,7 @@ func (at *AccountsTester) executeRequest(req *http.Request) (*http.Response, []b client := http.Client{} r, err := client.Do(req) if err != nil { - return nil, nil, err + return &http.Response{}, nil, err } return processResponse(r) } From 26a2332dddb639ca1d7a5e7617eaac93ce4f7797 Mon Sep 17 00:00:00 2001 From: Ivaylo Novakov Date: Wed, 16 Mar 2022 14:25:16 +0100 Subject: [PATCH 21/25] Unparallelise tests that lead to data races. --- database/database.go | 33 ++++++++++++++--------------- test/database/apikeys_test.go | 2 -- test/database/challenge_test.go | 4 ---- test/database/configuration_test.go | 2 -- test/database/upload_test.go | 2 -- test/database/user_test.go | 24 --------------------- 6 files changed, 16 insertions(+), 51 deletions(-) diff --git a/database/database.go b/database/database.go index 184d0d30..14344844 100644 --- a/database/database.go +++ b/database/database.go @@ -127,30 +127,29 @@ func NewCustomDB(ctx context.Context, dbName string, creds DBCredentials, logger if err != nil { return nil, errors.AddContext(err, "failed to connect to DB") } - database := c.Database(dbName) + db := c.Database(dbName) if logger == nil { logger = &logrus.Logger{} } - err = ensureDBSchema(ctx, database, Schema, logger) + err = ensureDBSchema(ctx, db, Schema, logger) if err != nil { return nil, err } - db := &DB{ - staticDB: database, - staticUsers: database.Collection(collUsers), - staticSkylinks: database.Collection(collSkylinks), - staticUploads: database.Collection(collUploads), - staticDownloads: database.Collection(collDownloads), - staticRegistryReads: database.Collection(collRegistryReads), - staticRegistryWrites: database.Collection(collRegistryWrites), - staticEmails: database.Collection(collEmails), - staticChallenges: database.Collection(collChallenges), - staticUnconfirmedUserUpdates: database.Collection(collUnconfirmedUserUpdates), - staticConfiguration: database.Collection(collConfiguration), - staticAPIKeys: database.Collection(collAPIKeys), + return &DB{ + staticDB: db, + staticUsers: db.Collection(collUsers), + staticSkylinks: db.Collection(collSkylinks), + staticUploads: db.Collection(collUploads), + staticDownloads: db.Collection(collDownloads), + staticRegistryReads: db.Collection(collRegistryReads), + staticRegistryWrites: db.Collection(collRegistryWrites), + staticEmails: db.Collection(collEmails), + staticChallenges: db.Collection(collChallenges), + staticUnconfirmedUserUpdates: db.Collection(collUnconfirmedUserUpdates), + staticConfiguration: db.Collection(collConfiguration), + staticAPIKeys: db.Collection(collAPIKeys), staticLogger: logger, - } - return db, nil + }, nil } // Disconnect closes the connection to the database in an orderly fashion. diff --git a/test/database/apikeys_test.go b/test/database/apikeys_test.go index 5b10151f..1eb0327b 100644 --- a/test/database/apikeys_test.go +++ b/test/database/apikeys_test.go @@ -10,8 +10,6 @@ import ( // TestAPIKeys ensures the DB operations with API keys work as expected. func TestAPIKeys(t *testing.T) { - t.Parallel() - ctx := context.Background() dbName := test.DBNameForTest(t.Name()) db, err := database.NewCustomDB(ctx, dbName, test.DBTestCredentials(), nil) diff --git a/test/database/challenge_test.go b/test/database/challenge_test.go index c6ac508a..b1dd3b70 100644 --- a/test/database/challenge_test.go +++ b/test/database/challenge_test.go @@ -18,8 +18,6 @@ import ( // TestValidateChallengeResponse is a unit test using a database. func TestValidateChallengeResponse(t *testing.T) { - t.Parallel() - ctx := context.Background() dbName := test.DBNameForTest(t.Name()) db, err := database.NewCustomDB(ctx, dbName, test.DBTestCredentials(), nil) @@ -128,8 +126,6 @@ func TestValidateChallengeResponse(t *testing.T) { // TestUnconfirmedUserUpdate ensures the entire flow for unconfirmed user // updates works as expected. func TestUnconfirmedUserUpdate(t *testing.T) { - t.Parallel() - ctx := context.Background() dbName := test.DBNameForTest(t.Name()) db, err := database.NewCustomDB(ctx, dbName, test.DBTestCredentials(), nil) diff --git a/test/database/configuration_test.go b/test/database/configuration_test.go index b9de09bd..14fd3c42 100644 --- a/test/database/configuration_test.go +++ b/test/database/configuration_test.go @@ -13,8 +13,6 @@ import ( // TestConfiguration ensures we can correctly read and write from/to the // configuration DB table. func TestConfiguration(t *testing.T) { - t.Parallel() - ctx := context.Background() dbName := test.DBNameForTest(t.Name()) db, err := database.NewCustomDB(ctx, dbName, test.DBTestCredentials(), nil) diff --git a/test/database/upload_test.go b/test/database/upload_test.go index 8d839846..4a6eb26d 100644 --- a/test/database/upload_test.go +++ b/test/database/upload_test.go @@ -14,8 +14,6 @@ import ( // TestUploadsByUser ensures UploadsByUser returns the correct uploads, // in the correct order, with the correct sized and so on. func TestUploadsByUser(t *testing.T) { - t.Parallel() - ctx := context.Background() dbName := test.DBNameForTest(t.Name()) db, err := database.NewCustomDB(ctx, dbName, test.DBTestCredentials(), nil) diff --git a/test/database/user_test.go b/test/database/user_test.go index b6143894..f8caa4d6 100644 --- a/test/database/user_test.go +++ b/test/database/user_test.go @@ -19,8 +19,6 @@ import ( // TestUserByEmail ensures UserByEmail works as expected. // This method also tests UserCreate. func TestUserByEmail(t *testing.T) { - t.Parallel() - ctx := context.Background() dbName := test.DBNameForTest(t.Name()) db, err := database.NewCustomDB(ctx, dbName, test.DBTestCredentials(), nil) @@ -64,8 +62,6 @@ func TestUserByEmail(t *testing.T) { // TestUserByID ensures UserByID works as expected. func TestUserByID(t *testing.T) { - t.Parallel() - ctx := context.Background() dbName := test.DBNameForTest(t.Name()) db, err := database.NewCustomDB(ctx, dbName, test.DBTestCredentials(), nil) @@ -105,8 +101,6 @@ func TestUserByID(t *testing.T) { // TestUserByPubKey makes sure UserByPubKey functions correctly, both with a // single and multiple pubkeys attached to a user. func TestUserByPubKey(t *testing.T) { - t.Parallel() - ctx := context.Background() name := test.DBNameForTest(t.Name()) db, err := database.NewCustomDB(ctx, name, test.DBTestCredentials(), nil) @@ -158,8 +152,6 @@ func TestUserByPubKey(t *testing.T) { // TestUserByStripeID ensures UserByStripeID works as expected. // This method also tests UserCreate and UserSetStripeID. func TestUserByStripeID(t *testing.T) { - t.Parallel() - ctx := context.Background() dbName := test.DBNameForTest(t.Name()) db, err := database.NewCustomDB(ctx, dbName, test.DBTestCredentials(), nil) @@ -204,8 +196,6 @@ func TestUserByStripeID(t *testing.T) { // TestUserBySub ensures UserBySub works as expected. // This method also tests UserCreate. func TestUserBySub(t *testing.T) { - t.Parallel() - ctx := context.Background() dbName := test.DBNameForTest(t.Name()) db, err := database.NewCustomDB(ctx, dbName, test.DBTestCredentials(), nil) @@ -248,8 +238,6 @@ func TestUserBySub(t *testing.T) { // TestUserConfirmEmail ensures that email confirmation works as expected, // including resecting the expiration of tokens. func TestUserConfirmEmail(t *testing.T) { - t.Parallel() - ctx := context.Background() dbName := test.DBNameForTest(t.Name()) db, err := database.NewCustomDB(ctx, dbName, test.DBTestCredentials(), nil) @@ -288,8 +276,6 @@ func TestUserConfirmEmail(t *testing.T) { // TestUserCreate ensures UserCreate works as expected. func TestUserCreate(t *testing.T) { - t.Parallel() - ctx := context.Background() dbName := test.DBNameForTest(t.Name()) db, err := database.NewCustomDB(ctx, dbName, test.DBTestCredentials(), nil) @@ -339,8 +325,6 @@ func TestUserCreate(t *testing.T) { // TestUserDelete ensures UserDelete works as expected. func TestUserDelete(t *testing.T) { - t.Parallel() - ctx := context.Background() dbName := test.DBNameForTest(t.Name()) db, err := database.NewCustomDB(ctx, dbName, test.DBTestCredentials(), nil) @@ -379,8 +363,6 @@ func TestUserDelete(t *testing.T) { // TestUserSave ensures that UserSave works as expected. func TestUserSave(t *testing.T) { - t.Parallel() - ctx := context.Background() dbName := test.DBNameForTest(t.Name()) db, err := database.NewCustomDB(ctx, dbName, test.DBTestCredentials(), nil) @@ -427,8 +409,6 @@ func TestUserSave(t *testing.T) { // TestUserSetStripeID ensures that UserSetStripeID works as expected. func TestUserSetStripeID(t *testing.T) { - t.Parallel() - ctx := context.Background() dbName := test.DBNameForTest(t.Name()) db, err := database.NewCustomDB(ctx, dbName, test.DBTestCredentials(), nil) @@ -460,8 +440,6 @@ func TestUserSetStripeID(t *testing.T) { // TestUserSetTier ensures that UserSetTier works as expected. func TestUserSetTier(t *testing.T) { - t.Parallel() - ctx := context.Background() dbName := test.DBNameForTest(t.Name()) db, err := database.NewCustomDB(ctx, dbName, test.DBTestCredentials(), nil) @@ -491,8 +469,6 @@ func TestUserSetTier(t *testing.T) { // TestUserStats ensures we report accurate statistics for users. func TestUserStats(t *testing.T) { - t.Parallel() - ctx := context.Background() dbName := test.DBNameForTest(t.Name()) db, err := database.NewCustomDB(ctx, dbName, test.DBTestCredentials(), nil) From 032341206018d436d4a9bfa20b1170ad17f2aa8f Mon Sep 17 00:00:00 2001 From: Ivaylo Novakov Date: Thu, 17 Mar 2022 14:10:53 +0100 Subject: [PATCH 22/25] Update api/apikeys.go Co-authored-by: Peter-Jan Brone --- api/apikeys.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/api/apikeys.go b/api/apikeys.go index 317cb4e6..3b29b56e 100644 --- a/api/apikeys.go +++ b/api/apikeys.go @@ -51,7 +51,7 @@ func (akp APIKeyPOST) Validate() error { if !akp.Public && len(akp.Skylinks) > 0 { return errors.New("public API keys cannot refer to skylinlks") } - errs := make([]error, 0) + var errs []error for _, s := range akp.Skylinks { if !database.ValidSkylinkHash(s) { errs = append(errs, errors.New("invalid skylink:"+s)) From 11e46ae775d04dadfa7d945d421a6376c41c7e69 Mon Sep 17 00:00:00 2001 From: Ivaylo Novakov Date: Thu, 17 Mar 2022 14:10:59 +0100 Subject: [PATCH 23/25] Update api/apikeys.go Co-authored-by: Peter-Jan Brone --- api/apikeys.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/api/apikeys.go b/api/apikeys.go index 3b29b56e..529041f8 100644 --- a/api/apikeys.go +++ b/api/apikeys.go @@ -49,7 +49,7 @@ type ( // Validate checks if the request and its parts are valid. func (akp APIKeyPOST) Validate() error { if !akp.Public && len(akp.Skylinks) > 0 { - return errors.New("public API keys cannot refer to skylinlks") + return errors.New("public API keys cannot refer to skylinks") } var errs []error for _, s := range akp.Skylinks { From d24e0ff228aae97d16de1a8249cdd59b71853925 Mon Sep 17 00:00:00 2001 From: Ivaylo Novakov Date: Thu, 17 Mar 2022 14:11:07 +0100 Subject: [PATCH 24/25] Update api/apikeys.go Co-authored-by: Peter-Jan Brone --- api/apikeys.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/api/apikeys.go b/api/apikeys.go index 529041f8..bbb3ffba 100644 --- a/api/apikeys.go +++ b/api/apikeys.go @@ -54,7 +54,7 @@ func (akp APIKeyPOST) Validate() error { var errs []error for _, s := range akp.Skylinks { if !database.ValidSkylinkHash(s) { - errs = append(errs, errors.New("invalid skylink:"+s)) + errs = append(errs, errors.New("invalid skylink: "+s)) } } if len(errs) > 0 { From f4bc82e0f71c6bb8101fa3767dfb81ffdf4fe8f8 Mon Sep 17 00:00:00 2001 From: Ivaylo Novakov Date: Thu, 17 Mar 2022 14:51:17 +0100 Subject: [PATCH 25/25] Add 404 errors. Cover other PR comments. --- api/apikeys.go | 10 +++++++++- api/auth.go | 4 ++++ database/apikeys.go | 6 +++--- test/api/api_test.go | 4 ---- test/api/apikeys_test.go | 6 +++--- 5 files changed, 19 insertions(+), 11 deletions(-) diff --git a/api/apikeys.go b/api/apikeys.go index bbb3ffba..e0a08d80 100644 --- a/api/apikeys.go +++ b/api/apikeys.go @@ -159,7 +159,7 @@ func (api *API) userAPIKeyDELETE(u *database.User, w http.ResponseWriter, req *h } err = api.staticDB.APIKeyDelete(req.Context(), *u, akID) if err == mongo.ErrNoDocuments { - api.WriteError(w, err, http.StatusBadRequest) + api.WriteError(w, err, http.StatusNotFound) return } if err != nil { @@ -183,6 +183,10 @@ func (api *API) userAPIKeyPUT(u *database.User, w http.ResponseWriter, req *http return } err = api.staticDB.APIKeyUpdate(req.Context(), *u, akID, body.Skylinks) + if errors.Contains(err, mongo.ErrNoDocuments) { + api.WriteError(w, err, http.StatusNotFound) + return + } if err != nil { api.WriteError(w, err, http.StatusInternalServerError) return @@ -206,6 +210,10 @@ func (api *API) userAPIKeyPATCH(u *database.User, w http.ResponseWriter, req *ht return } err = api.staticDB.APIKeyPatch(req.Context(), *u, akID, body.Add, body.Remove) + if errors.Contains(err, mongo.ErrNoDocuments) { + api.WriteError(w, err, http.StatusNotFound) + return + } if err != nil { api.WriteError(w, err, http.StatusInternalServerError) return diff --git a/api/auth.go b/api/auth.go index a378db16..a5ffc5bc 100644 --- a/api/auth.go +++ b/api/auth.go @@ -45,6 +45,10 @@ func (api *API) userAndTokenByAPIKey(req *http.Request) (*database.User, jwt2.To // If we're dealing with a public API key, we need to validate that this // request is a GET for a covered skylink. if akr.Public { + // Public API keys can only be used with GET. + if req.Method != http.MethodGet { + return nil, nil, database.ErrInvalidAPIKey + } sl, err := database.ExtractSkylinkHash(req.RequestURI) if err != nil || !akr.CoversSkylink(sl) { return nil, nil, database.ErrInvalidAPIKey diff --git a/database/apikeys.go b/database/apikeys.go index 03a94730..7bc41ded 100644 --- a/database/apikeys.go +++ b/database/apikeys.go @@ -246,7 +246,7 @@ func (db *DB) APIKeyUpdate(ctx context.Context, user User, akID primitive.Object } filter := bson.M{ "_id": akID, - "public": &True, // you can only update public API keys + "public": true, "user_id": user.ID, } update := bson.M{"$set": bson.M{"skylinks": skylinks}} @@ -258,7 +258,7 @@ func (db *DB) APIKeyUpdate(ctx context.Context, user User, akID primitive.Object return err } if ur.ModifiedCount == 0 { - return errors.New("public API key not found, no keys updated") + return mongo.ErrNoDocuments } return nil } @@ -293,7 +293,7 @@ func (db *DB) APIKeyPatch(ctx context.Context, user User, akID primitive.ObjectI return err } if ur.ModifiedCount == 0 { - return errors.New("public API key not found, no keys updated") + return mongo.ErrNoDocuments } } // Then, remove all skylinks that need to be removed. diff --git a/test/api/api_test.go b/test/api/api_test.go index 51a08dcf..a3f698b3 100644 --- a/test/api/api_test.go +++ b/test/api/api_test.go @@ -23,8 +23,6 @@ import ( // TestWithDBSession ensures that database transactions are started, committed, // and aborted properly. func TestWithDBSession(t *testing.T) { - t.Parallel() - ctx := context.Background() dbName := test.DBNameForTest(t.Name()) db, err := database.NewCustomDB(ctx, dbName, test.DBTestCredentials(), nil) @@ -139,8 +137,6 @@ func TestWithDBSession(t *testing.T) { // TestUserTierCache ensures out tier cache works as expected. func TestUserTierCache(t *testing.T) { - t.Parallel() - dbName := test.DBNameForTest(t.Name()) at, err := test.NewAccountsTester(dbName) if err != nil { diff --git a/test/api/apikeys_test.go b/test/api/apikeys_test.go index b323587c..94a19f0c 100644 --- a/test/api/apikeys_test.go +++ b/test/api/apikeys_test.go @@ -81,10 +81,10 @@ func testPrivateAPIKeysFlow(t *testing.T, at *test.AccountsTester) { t.Fatalf("Missing key '%s'! Set: %+v", ak2.ID.Hex(), aks) } - // Try to delete the same key again. Expect a Bad Request. + // Try to delete the same key again. Expect a 404. status, _ = at.UserAPIKeysDELETE(ak1.ID) - if status != http.StatusBadRequest { - t.Fatalf("Expected status 400, got %d.", status) + if status != http.StatusNotFound { + t.Fatalf("Expected status 404, got %d.", status) } }