diff --git a/api/apikeys.go b/api/apikeys.go index 063d0518..e0a08d80 100644 --- a/api/apikeys.go +++ b/api/apikeys.go @@ -3,16 +3,107 @@ package api import ( "net/http" "strconv" + "time" "github.com/SkynetLabs/skynet-accounts/database" "github.com/julienschmidt/httprouter" "gitlab.com/NebulousLabs/errors" + "go.mongodb.org/mongo-driver/bson/primitive" "go.mongodb.org/mongo-driver/mongo" ) +type ( + // APIKeyPOST describes the body of a POST request that creates an API key + APIKeyPOST struct { + Public bool `json:"public,string"` + Skylinks []string `json:"skylinks"` + } + // APIKeyPUT describes the request body for updating an API key + APIKeyPUT struct { + Skylinks []string + } + // APIKeyPATCH describes the request body for updating an API key by + // providing only the requested changes + APIKeyPATCH struct { + Add []string + Remove []string + } + // APIKeyResponse is an API DTO which mirrors database.APIKey. + APIKeyResponse struct { + ID primitive.ObjectID `json:"id"` + UserID primitive.ObjectID `json:"-"` + Public bool `json:"public,string"` + Key database.APIKey `json:"-"` + Skylinks []string `json:"skylinks"` + CreatedAt time.Time `json:"createdAt"` + } + // APIKeyResponseWithKey is an API DTO which mirrors database.APIKey but + // also reveals the value of the Key field. This should only be used on key + // creation. + APIKeyResponseWithKey struct { + APIKeyResponse + Key database.APIKey `json:"key"` + } +) + +// Validate checks if the request and its parts are valid. +func (akp APIKeyPOST) Validate() error { + if !akp.Public && len(akp.Skylinks) > 0 { + return errors.New("public API keys cannot refer to skylinks") + } + var errs []error + for _, s := range akp.Skylinks { + if !database.ValidSkylinkHash(s) { + errs = append(errs, errors.New("invalid skylink: "+s)) + } + } + if len(errs) > 0 { + return errors.Compose(errs...) + } + return nil +} + +// APIKeyResponseFromAPIKey creates a new APIKeyResponse from the given API key. +func APIKeyResponseFromAPIKey(ak database.APIKeyRecord) *APIKeyResponse { + return &APIKeyResponse{ + ID: ak.ID, + UserID: ak.UserID, + Public: ak.Public, + Key: ak.Key, + Skylinks: ak.Skylinks, + CreatedAt: ak.CreatedAt, + } +} + +// APIKeyResponseWithKeyFromAPIKey creates a new APIKeyResponseWithKey from the +// given API key. +func APIKeyResponseWithKeyFromAPIKey(ak database.APIKeyRecord) *APIKeyResponseWithKey { + return &APIKeyResponseWithKey{ + APIKeyResponse: APIKeyResponse{ + ID: ak.ID, + UserID: ak.UserID, + Public: ak.Public, + Key: ak.Key, + Skylinks: ak.Skylinks, + CreatedAt: ak.CreatedAt, + }, + Key: ak.Key, + } +} + // userAPIKeyPOST creates a new API key for the user. func (api *API) userAPIKeyPOST(u *database.User, w http.ResponseWriter, req *http.Request, _ httprouter.Params) { - ak, err := api.staticDB.APIKeyCreate(req.Context(), *u) + var body APIKeyPOST + err := parseRequestBodyJSON(req.Body, LimitBodySizeLarge, &body) + if err != nil { + api.WriteError(w, err, http.StatusBadRequest) + return + } + if err := body.Validate(); err != nil { + api.WriteError(w, err, http.StatusBadRequest) + return + } + ak, err := api.staticDB.APIKeyCreate(req.Context(), *u, body.Public, body.Skylinks) if errors.Contains(err, database.ErrMaxNumAPIKeysExceeded) { err = errors.AddContext(err, "the maximum number of API keys a user can create is "+strconv.Itoa(database.MaxNumAPIKeysPerUser)) api.WriteError(w, err, http.StatusBadRequest) @@ -22,36 +113,107 @@ func (api *API) userAPIKeyPOST(u *database.User, w http.ResponseWriter, req *htt api.WriteError(w, err, http.StatusInternalServerError) return } - // Make the Key visible in JSON form. We do that with an anonymous struct - // because we don't envision that being needed anywhere else in the project. - akWithKey := struct { - database.APIKeyRecord - Key database.APIKey `bson:"key" json:"key"` - }{ - *ak, - ak.Key, + api.WriteJSON(w, APIKeyResponseWithKeyFromAPIKey(*ak)) +} + +// userAPIKeyGET returns a single API key. +func (api *API) userAPIKeyGET(u *database.User, w http.ResponseWriter, req *http.Request, ps httprouter.Params) { + akID, err := primitive.ObjectIDFromHex(ps.ByName("id")) + if err != nil { + api.WriteError(w, err, http.StatusBadRequest) + return + } + ak, err := api.staticDB.APIKeyGet(req.Context(), akID) + // If there is no such API key or it doesn't exist, return a 404. + if errors.Contains(err, mongo.ErrNoDocuments) || (err == nil && ak.UserID != u.ID) { + api.WriteError(w, nil, http.StatusNotFound) + return } - api.WriteJSON(w, akWithKey) + if err != nil { + api.WriteError(w, err, http.StatusInternalServerError) + return + } + api.WriteJSON(w, APIKeyResponseFromAPIKey(ak)) } -// userAPIKeyGET lists all API keys associated with the user. -func (api *API) userAPIKeyGET(u *database.User, w http.ResponseWriter, req *http.Request, _ httprouter.Params) { +// userAPIKeyLIST lists all API keys associated with the user. +func (api *API) userAPIKeyLIST(u *database.User, w http.ResponseWriter, req *http.Request, _ httprouter.Params) { aks, err := api.staticDB.APIKeyList(req.Context(), *u) if err != nil { api.WriteError(w, err, http.StatusInternalServerError) return } - api.WriteJSON(w, aks) + resp := make([]*APIKeyResponse, 0, len(aks)) + for _, ak := range aks { + resp = append(resp, APIKeyResponseFromAPIKey(ak)) + } + api.WriteJSON(w, resp) } // userAPIKeyDELETE removes an API key. func (api *API) userAPIKeyDELETE(u *database.User, w http.ResponseWriter, req *http.Request, ps httprouter.Params) { - akID := ps.ByName("id") - err := api.staticDB.APIKeyDelete(req.Context(), *u, akID) + akID, err := primitive.ObjectIDFromHex(ps.ByName("id")) + if err != nil { + api.WriteError(w, err, http.StatusBadRequest) + return + } + err = api.staticDB.APIKeyDelete(req.Context(), *u, akID) if err == mongo.ErrNoDocuments { + api.WriteError(w, err, http.StatusNotFound) + return + } + if err != nil { + api.WriteError(w, err, http.StatusInternalServerError) + return + } + api.WriteSuccess(w) +} + +// userAPIKeyPUT updates an API key. Only possible for public API keys. +func (api *API) userAPIKeyPUT(u *database.User, w http.ResponseWriter, req *http.Request, ps httprouter.Params) { + akID, err := primitive.ObjectIDFromHex(ps.ByName("id")) + if err != nil { api.WriteError(w, err, http.StatusBadRequest) return } + var body APIKeyPUT + err = parseRequestBodyJSON(req.Body, LimitBodySizeLarge, &body) + if err != nil { + api.WriteError(w, err, http.StatusBadRequest) + return + } + err = api.staticDB.APIKeyUpdate(req.Context(), *u, akID, body.Skylinks) + if errors.Contains(err, mongo.ErrNoDocuments) { + api.WriteError(w, err, http.StatusNotFound) + return + } + if err != nil { + api.WriteError(w, err, http.StatusInternalServerError) + return + } + api.WriteSuccess(w) +} + +// userAPIKeyPATCH patches an API key. The difference between PUT and PATCH is +// that PATCH only specifies the changes while PUT provides the expected list of +// covered skylinks. Only possible for public API keys. +func (api *API) userAPIKeyPATCH(u *database.User, w http.ResponseWriter, req *http.Request, ps httprouter.Params) { + akID, err := primitive.ObjectIDFromHex(ps.ByName("id")) + if err != nil { + api.WriteError(w, err, http.StatusBadRequest) + return + } + var body APIKeyPATCH + err = parseRequestBodyJSON(req.Body, LimitBodySizeLarge, &body) + if err != nil { + api.WriteError(w, err, http.StatusBadRequest) + return + } + err = api.staticDB.APIKeyPatch(req.Context(), *u, akID, body.Add, body.Remove) + if errors.Contains(err, mongo.ErrNoDocuments) { + api.WriteError(w, err, http.StatusNotFound) + return + } if err != nil { api.WriteError(w, err, http.StatusInternalServerError) return diff --git a/api/auth.go b/api/auth.go new file mode 100644 index 00000000..a5ffc5bc --- /dev/null +++ b/api/auth.go @@ -0,0 +1,109 @@ +package api + +import ( + "net/http" + "strings" + + "github.com/SkynetLabs/skynet-accounts/database" + "github.com/SkynetLabs/skynet-accounts/jwt" + jwt2 "github.com/lestrrat-go/jwx/jwt" + "gitlab.com/NebulousLabs/errors" +) + +// userAndTokenByRequestToken scans the request for an authentication token, +// fetches the corresponding user from the database and returns both user and +// token. +func (api *API) userAndTokenByRequestToken(req *http.Request) (*database.User, jwt2.Token, error) { + token, err := tokenFromRequest(req) + if err != nil { + return nil, nil, errors.AddContext(err, "error fetching token from request") + } + sub, _, _, err := jwt.TokenFields(token) + if err != nil { + return nil, nil, errors.AddContext(err, "error decoding token from request") + } + u, err := api.staticDB.UserBySub(req.Context(), sub) + if err != nil { + return nil, nil, errors.AddContext(err, "error fetching user from database") + } + return u, token, nil +} + +// userAndTokenByAPIKey extracts the APIKey from the request and validates it. +// It then returns the user who owns it and a token for that user. +// It first checks the headers and then the query. +// This method accesses the database. +func (api *API) userAndTokenByAPIKey(req *http.Request) (*database.User, jwt2.Token, error) { + ak, err := apiKeyFromRequest(req) + if err != nil { + return nil, nil, err + } + akr, err := api.staticDB.APIKeyByKey(req.Context(), ak.String()) + if err != nil { + return nil, nil, err + } + // If we're dealing with a public API key, we need to validate that this + // request is a GET for a covered skylink. + if akr.Public { + // Public API keys can only be used with GET. + if req.Method != http.MethodGet { + return nil, nil, database.ErrInvalidAPIKey + } + sl, err := database.ExtractSkylinkHash(req.RequestURI) + if err != nil || !akr.CoversSkylink(sl) { + return nil, nil, database.ErrInvalidAPIKey + } + } + u, err := api.staticDB.UserByID(req.Context(), akr.UserID) + if err != nil { + return nil, nil, err + } + t, err := jwt.TokenForUser(u.Email, u.Sub) + return u, t, err +} + +// apiKeyFromRequest extracts the API key from the request and returns it. +// This function does not differentiate between APIKey and APIKey. +// It first checks the headers and then the query. +func apiKeyFromRequest(r *http.Request) (*database.APIKey, error) { + // Check the headers for an API key. + akStr := r.Header.Get(APIKeyHeader) + // If there is no API key in the headers, try the query. + if akStr == "" { + akStr = r.FormValue("apiKey") + } + if akStr == "" { + return nil, ErrNoAPIKey + } + return database.NewAPIKeyFromString(akStr) +} + +// tokenFromRequest extracts the JWT token from the request and returns it. +// It first checks the authorization header and then the cookies. +// The token is validated before being returned. +func tokenFromRequest(r *http.Request) (jwt2.Token, error) { + var tokenStr string + // Check the headers for a token. + parts := strings.Split(r.Header.Get("Authorization"), "Bearer") + if len(parts) == 2 { + tokenStr = strings.TrimSpace(parts[1]) + } else { + // Check the cookie for a token. + cookie, err := r.Cookie(CookieName) + if errors.Contains(err, http.ErrNoCookie) { + return nil, ErrNoToken + } + if err != nil { + return nil, errors.AddContext(err, "cookie exists but it's not valid") + } + err = secureCookie.Decode(CookieName, cookie.Value, &tokenStr) + if err != nil { + return nil, errors.AddContext(err, "failed to decode token") + } + } + token, err := jwt.ValidateToken(tokenStr) + if err != nil { + return nil, errors.AddContext(err, "failed to validate token") + } + return token, nil +} diff --git a/api/routes_test.go b/api/auth_test.go similarity index 92% rename from api/routes_test.go rename to api/auth_test.go index 4b645249..7615bd49 100644 --- a/api/routes_test.go +++ b/api/auth_test.go @@ -29,28 +29,28 @@ func TestAPIKeyFromRequest(t *testing.T) { } // API key from request form. - token := randomAPIKeyString() - req.Form.Add("apiKey", token) - tk, err := apiKeyFromRequest(req) + akStr := randomAPIKeyString() + req.Form.Add("apiKey", akStr) + ak, err := apiKeyFromRequest(req) if err != nil { t.Fatal(err) } - if string(tk) != token { - t.Fatalf("Expected '%s', got '%s'.", token, tk) + if ak.String() != akStr { + t.Fatalf("Expected '%s', got '%s'.", akStr, ak) } // API key from headers. Expect this to take precedence over request form. token2 := randomAPIKeyString() req.Header.Set(APIKeyHeader, token2) - tk, err = apiKeyFromRequest(req) + ak, err = apiKeyFromRequest(req) if err != nil { t.Fatal(err) } - if string(tk) == token { + if ak.String() == akStr { t.Fatal("Form token took precedence over headers token.") } - if string(tk) != token2 { - t.Fatalf("Expected '%s', got '%s'.", token2, tk) + if ak.String() != token2 { + t.Fatalf("Expected '%s', got '%s'.", token2, ak) } } diff --git a/api/cache.go b/api/cache.go index 7b2717fc..9d49b186 100644 --- a/api/cache.go +++ b/api/cache.go @@ -47,10 +47,10 @@ func (utc *userTierCache) Get(sub string) (int, bool, bool) { return ce.Tier, ce.QuotaExceeded, true } -// Set stores the user's tier in the cache. -func (utc *userTierCache) Set(u *database.User) { +// Set stores the user's tier in the cache under the given key. +func (utc *userTierCache) Set(key string, u *database.User) { utc.mu.Lock() - utc.cache[u.Sub] = userTierCacheEntry{ + utc.cache[key] = userTierCacheEntry{ Tier: u.Tier, QuotaExceeded: u.QuotaExceeded, ExpiresAt: time.Now().UTC().Add(userTierCacheTTL), diff --git a/api/cache_test.go b/api/cache_test.go index ba255533..fa49f5a6 100644 --- a/api/cache_test.go +++ b/api/cache_test.go @@ -21,8 +21,8 @@ func TestUserTierCache(t *testing.T) { if ok || tier != database.TierAnonymous { t.Fatalf("Expected to get tier %d and %t, got %d and %t.", database.TierAnonymous, false, tier, ok) } - // Set the user in the cache. - cache.Set(u) + // Set the use in the cache. + cache.Set(u.Sub, u) // Check again. tier, qe, ok := cache.Get(u.Sub) if !ok || tier != u.Tier { @@ -32,7 +32,7 @@ func TestUserTierCache(t *testing.T) { t.Fatal("Quota exceeded flag doesn't match.") } u.QuotaExceeded = true - cache.Set(u) + cache.Set(u.Sub, u) tier, qe, ok = cache.Get(u.Sub) if !ok || tier != u.Tier { t.Fatalf("Expected to get tier %d and %t, got %d and %t.", u.Tier, true, tier, ok) @@ -53,10 +53,31 @@ func TestUserTierCache(t *testing.T) { timeToMonthRollover := 30 * time.Minute u.SubscribedUntil = time.Now().UTC().Add(timeToMonthRollover) // Update the cache. - cache.Set(u) + cache.Set(u.Sub, u) // Expect the cache entry's ExpiresAt to be after 30 minutes. timeIn30 := time.Now().UTC().Add(time.Hour - timeToMonthRollover) if ce.ExpiresAt.After(timeIn30) && ce.ExpiresAt.Before(timeIn30.Add(time.Second)) { t.Fatalf("Expected ExpiresAt to be within 1 second of %s, but it was %s (off by %d ns)", timeIn30.String(), ce.ExpiresAt.String(), (time.Hour - timeIn30.Sub(ce.ExpiresAt)).Nanoseconds()) } + + // Create a new API key. + ak := database.NewAPIKey() + if !ak.IsValid() { + t.Fatal("Invalid API key.") + } + // Try to get a value from the cache. Expect this to fail. + _, _, ok = cache.Get(string(ak)) + if ok { + t.Fatal("Did not expect to get a cache entry!") + } + // Update the cache with a custom key. + cache.Set(string(ak), u) + // Fetch the data for the custom key. + tier, _, ok = cache.Get(string(ak)) + if !ok { + t.Fatal("Expected the entry to exist.") + } + if tier != u.Tier { + t.Fatalf("Expected tier %+v, got %+v", u.Tier, tier) + } } diff --git a/api/handlers.go b/api/handlers.go index 77002ee1..6cbab5a5 100644 --- a/api/handlers.go +++ b/api/handlers.go @@ -28,6 +28,9 @@ const ( // LimitBodySizeSmall defines a size limit for requests that we don't expect // to contain a lot of data. LimitBodySizeSmall = 4 * skynet.KiB + // LimitBodySizeLarge defines a size limit for requests that we expect to + // contain a lot of data. + LimitBodySizeLarge = 4 * skynet.MiB ) type ( @@ -409,29 +412,42 @@ func (api *API) userLimitsGET(_ *database.User, w http.ResponseWriter, req *http // to be presented in bytes per second. The default behaviour is to present // them in bits per second. inBytes := strings.EqualFold(req.FormValue("unit"), "byte") + respAnon := userLimitsGetFromTier(database.TierAnonymous, false, inBytes) // First check for an API key. ak, err := apiKeyFromRequest(req) - respAnon := userLimitsGetFromTier(database.TierAnonymous, inBytes) if err == nil { - u, err := api.staticDB.UserByAPIKey(req.Context(), ak) + // Check the cache before going any further. + tier, qe, ok := api.staticUserTierCache.Get(ak.String()) + if ok { + api.staticLogger.Traceln("Fetching user limits from cache by API key.") + api.WriteJSON(w, userLimitsGetFromTier(tier, qe, inBytes)) + return + } + // Get the API key. + akr, err := api.staticDB.APIKeyByKey(req.Context(), ak.String()) if err != nil { - api.staticLogger.Traceln("Error while fetching user by API key:", err) + api.staticLogger.Trace("API key doesn't exist in the database.") + api.WriteJSON(w, respAnon) + return + } + if akr.Public { + api.staticLogger.Trace("API key is public, cannot be used for general requests") api.WriteJSON(w, respAnon) return } - resp := userLimitsGetFromTier(u.Tier, inBytes) - // If the quota is exceeded we should keep the user's tier but report - // anonymous-level speeds. - if u.QuotaExceeded { - // Report the speeds for tier anonymous. - resp = userLimitsGetFromTier(database.TierAnonymous, inBytes) - // But keep reporting the user's actual tier and it's name. - resp.TierID = u.Tier - resp.TierName = database.UserLimits[u.Tier].TierName + // Get the owner of this API key from the database. + u, err := api.staticDB.UserByID(req.Context(), akr.UserID) + if err != nil { + api.staticLogger.Traceln("Error while fetching user by API key:", err) + api.WriteJSON(w, respAnon) + return } - api.WriteJSON(w, resp) + // Cache the user under the API key they used. + api.staticUserTierCache.Set(ak.String(), u) + api.WriteJSON(w, userLimitsGetFromTier(u.Tier, u.QuotaExceeded, inBytes)) return } + // Next check for a token. token, err := tokenFromRequest(req) if err != nil { api.WriteJSON(w, respAnon) @@ -454,25 +470,77 @@ func (api *API) userLimitsGET(_ *database.User, w http.ResponseWriter, req *http api.WriteJSON(w, respAnon) return } - api.staticUserTierCache.Set(u) + api.staticUserTierCache.Set(u.Sub, u) // Populate the tier and qe values, while simultaneously making sure // that we can read the record from the cache. - tier, qe, ok = api.staticUserTierCache.Get(sub) + tier, qe, ok = api.staticUserTierCache.Get(u.Sub) if !ok { build.Critical("Failed to fetch user from UserTierCache right after setting it.") } } - resp := userLimitsGetFromTier(tier, inBytes) - // If the quota is exceeded we should keep the user's tier but report - // anonymous-level speeds. - if qe { - // Report anonymous speeds. - resp = userLimitsGetFromTier(database.TierAnonymous, inBytes) - // Keep reporting the user's actual tier and tier name. - resp.TierID = tier - resp.TierName = database.UserLimits[tier].TierName + api.WriteJSON(w, userLimitsGetFromTier(tier, qe, inBytes)) +} + +// userLimitsSkylinkGET returns the speed limits which apply to a GET call to +// the given skylink. This method exists to accommodate public API keys. +// +// NOTE: This handler needs to use the noAuth middleware in order to be able to +// optimise its calls to the DB and the use of caching. +func (api *API) userLimitsSkylinkGET(u *database.User, w http.ResponseWriter, req *http.Request, ps httprouter.Params) { + // inBytes is a flag indicating that the caller wants all bandwidth limits + // to be presented in bytes per second. The default behaviour is to present + // them in bits per second. + inBytes := strings.EqualFold(req.FormValue("unit"), "byte") + respAnon := userLimitsGetFromTier(database.TierAnonymous, false, inBytes) + // Validate the skylink. + skylink := ps.ByName("skylink") + if !database.ValidSkylinkHash(skylink) { + api.staticLogger.Tracef("Invalid skylink: '%s'", skylink) + api.WriteJSON(w, respAnon) + return } - api.WriteJSON(w, resp) + // Try to fetch an API attached to the request. + ak, err := apiKeyFromRequest(req) + if errors.Contains(err, ErrNoAPIKey) { + // We failed to fetch an API key from this request but the request might + // be authenticated in another way, so we'll defer to userLimitsGET. + api.userLimitsGET(u, w, req, ps) + return + } + if err != nil { + api.staticLogger.Debugf("Error while processing API key: %s", err) + api.WriteJSON(w, respAnon) + return + } + // Check the cache before hitting the database. + tier, qe, ok := api.staticUserTierCache.Get(ak.String() + skylink) + if ok { + api.staticLogger.Traceln("Fetching user limits from cache by API key.") + api.WriteJSON(w, userLimitsGetFromTier(tier, qe, inBytes)) + return + } + // Get the API key. + akr, err := api.staticDB.APIKeyByKey(req.Context(), ak.String()) + if err != nil { + api.staticLogger.Trace("API key doesn't exist in the database.") + api.WriteJSON(w, respAnon) + return + } + if !akr.CoversSkylink(skylink) { + api.staticLogger.Trace("API key doesn't cover this skylink.") + api.WriteJSON(w, respAnon) + return + } + // Get the owner of this API key from the database. + user, err := api.staticDB.UserByID(req.Context(), akr.UserID) + if err != nil { + api.staticLogger.Tracef("Failed to get user for user ID: %v", err) + api.WriteJSON(w, respAnon) + return + } + // Store the user in the cache with a custom key. + api.staticUserTierCache.Set(ak.String()+skylink, user) + api.WriteJSON(w, userLimitsGetFromTier(user.Tier, user.QuotaExceeded, inBytes)) } // userStatsGET returns statistics about an existing user. @@ -1111,7 +1179,7 @@ func (api *API) trackRegistryWritePOST(u *database.User, w http.ResponseWriter, func (api *API) userUploadsDELETE(u *database.User, w http.ResponseWriter, req *http.Request, ps httprouter.Params) { sl := ps.ByName("skylink") if !database.ValidSkylinkHash(sl) { - api.WriteError(w, errors.New("invalid skylink"), http.StatusBadRequest) + api.WriteError(w, database.ErrInvalidSkylink, http.StatusBadRequest) return } skylink, err := api.staticDB.Skylink(req.Context(), sl) @@ -1153,7 +1221,7 @@ func (api *API) checkUserQuotas(ctx context.Context, u *database.User) { if err != nil { api.staticLogger.Warnf("Failed to save user. User: %+v, err: %s", u, err.Error()) } - api.staticUserTierCache.Set(u) + api.staticUserTierCache.Set(u.Sub, u) } } @@ -1196,31 +1264,41 @@ func fetchPageSize(form url.Values) (int, error) { } // parseRequestBodyJSON reads a limited portion of the body and decodes it into -// the given obj. The purpose of this is to prevent DoS attacks that rely on -// excessively large request bodies. -func parseRequestBodyJSON(body io.ReadCloser, maxBodySize int64, objRef interface{}) error { - return json.NewDecoder(io.LimitReader(body, maxBodySize)).Decode(&objRef) +// the given struct v. The purpose of this is to prevent DoS attacks that rely +// on excessively large request bodies. +func parseRequestBodyJSON(body io.ReadCloser, maxBodySize int64, v interface{}) error { + return json.NewDecoder(io.LimitReader(body, maxBodySize)).Decode(&v) } // userLimitsGetFromTier is a helper that lets us succinctly translate // from the database DTO to the API DTO. The `inBytes` parameter determines // whether the returned speeds will be in Bps or bps. -func userLimitsGetFromTier(tier int, inBytes bool) *UserLimitsGET { - t := database.UserLimits[tier] - ul := UserLimitsGET{ - TierID: tier, - TierName: t.TierName, - MaxUploadSize: t.MaxUploadSize, - MaxNumberUploads: t.MaxNumberUploads, - RegistryDelay: t.RegistryDelay, - Storage: t.Storage, +func userLimitsGetFromTier(tierID int, quotaExceeded, inBytes bool) *UserLimitsGET { + t, ok := database.UserLimits[tierID] + if !ok { + build.Critical("userLimitsGetFromTier was called with non-existent tierID: " + strconv.Itoa(tierID)) + t = database.UserLimits[database.TierAnonymous] } + limitsTier := t + if quotaExceeded { + limitsTier = database.UserLimits[database.TierAnonymous] + } + // If we need to return the result in bits per second, we multiply by 8, + // otherwise, we multiply by 1. + bpsMul := 8 if inBytes { - ul.UploadBandwidth = t.UploadBandwidth - ul.DownloadBandwidth = t.DownloadBandwidth - } else { - ul.UploadBandwidth = t.UploadBandwidth * 8 - ul.DownloadBandwidth = t.DownloadBandwidth * 8 + bpsMul = 1 + } + return &UserLimitsGET{ + TierID: tierID, + TierName: t.TierName, + Storage: t.Storage, + // If the user exceeds their quota, there will be brought down to + // anonymous levels. + UploadBandwidth: limitsTier.UploadBandwidth * bpsMul, + DownloadBandwidth: limitsTier.DownloadBandwidth * bpsMul, + MaxUploadSize: limitsTier.MaxUploadSize, + MaxNumberUploads: limitsTier.MaxNumberUploads, + RegistryDelay: limitsTier.RegistryDelay, } - return &ul } diff --git a/api/handlers_test.go b/api/handlers_test.go index 3b31e988..29447225 100644 --- a/api/handlers_test.go +++ b/api/handlers_test.go @@ -1,9 +1,13 @@ package api import ( + "fmt" + "math" + "strings" "testing" "github.com/SkynetLabs/skynet-accounts/database" + "gitlab.com/NebulousLabs/errors" ) // TestUserGETFromUser ensures the UserGETFromUser method correctly converts @@ -39,3 +43,88 @@ func TestUserGETFromUser(t *testing.T) { t.Fatal("Expected EmailConfirmed to be false.") } } + +// TestUserLimitsGetFromTier ensures the proper functioning of +// userLimitsGetFromTier. +func TestUserLimitsGetFromTier(t *testing.T) { + tests := []struct { + name string + tier int + quotaExceeded bool + expectedTier int + expectedStorage int64 + expectedUploadBW int + expectedDownloadBW int + expectedRegistryDelay int + }{ + { + name: "anon", + tier: database.TierAnonymous, + quotaExceeded: false, + expectedTier: database.TierAnonymous, + expectedStorage: database.UserLimits[database.TierAnonymous].Storage, + expectedUploadBW: database.UserLimits[database.TierAnonymous].UploadBandwidth, + expectedDownloadBW: database.UserLimits[database.TierAnonymous].DownloadBandwidth, + expectedRegistryDelay: database.UserLimits[database.TierAnonymous].RegistryDelay, + }, + { + name: "plus, quota not exceeded", + tier: database.TierPremium5, + quotaExceeded: false, + expectedTier: database.TierPremium5, + expectedStorage: database.UserLimits[database.TierPremium5].Storage, + expectedUploadBW: database.UserLimits[database.TierPremium5].UploadBandwidth, + expectedDownloadBW: database.UserLimits[database.TierPremium5].DownloadBandwidth, + expectedRegistryDelay: database.UserLimits[database.TierPremium5].RegistryDelay, + }, + { + name: "plus, quota exceeded", + tier: database.TierPremium5, + quotaExceeded: true, + expectedTier: database.TierPremium5, + expectedStorage: database.UserLimits[database.TierPremium5].Storage, + expectedUploadBW: database.UserLimits[database.TierAnonymous].UploadBandwidth, + expectedDownloadBW: database.UserLimits[database.TierAnonymous].DownloadBandwidth, + expectedRegistryDelay: database.UserLimits[database.TierAnonymous].RegistryDelay, + }, + } + + for _, tt := range tests { + ul := userLimitsGetFromTier(tt.tier, tt.quotaExceeded, true) + if ul.TierID != tt.expectedTier { + t.Errorf("Test '%s': expected tier %d, got %d", tt.name, tt.expectedTier, ul.TierID) + } + if ul.Storage != tt.expectedStorage { + t.Errorf("Test '%s': expected storage %d, got %d", tt.name, tt.expectedStorage, ul.Storage) + } + if ul.UploadBandwidth != tt.expectedUploadBW { + t.Errorf("Test '%s': expected upload bandwidth %d, got %d", tt.name, tt.expectedUploadBW, ul.UploadBandwidth) + } + if ul.DownloadBandwidth != tt.expectedDownloadBW { + t.Errorf("Test '%s': expected download bandwidth %d, got %d", tt.name, tt.expectedDownloadBW, ul.DownloadBandwidth) + } + if ul.RegistryDelay != tt.expectedRegistryDelay { + t.Errorf("Test '%s': expected registry delay %d, got %d", tt.name, tt.expectedRegistryDelay, ul.RegistryDelay) + } + } + + // Additionally, let us ensure that userLimitsGetFromTier logs a critical + // when called with an invalid tier ID. + err := func() (err error) { + defer func() { + e := recover() + if e == nil { + err = errors.New("expected to panic") + } + if !strings.Contains(fmt.Sprint(e), "userLimitsGetFromTier was called with non-existent tierID") { + err = fmt.Errorf("expected to panic with a specific error message, got '%s'", fmt.Sprint(e)) + } + }() + // The call that we expect to log a critical. + _ = userLimitsGetFromTier(math.MaxInt, false, true) + return + }() + if err != nil { + t.Fatal(err) + } +} diff --git a/api/routes.go b/api/routes.go index f0a5c572..1a221197 100644 --- a/api/routes.go +++ b/api/routes.go @@ -1,14 +1,11 @@ package api import ( - "context" "net/http" "strings" "github.com/SkynetLabs/skynet-accounts/database" "github.com/SkynetLabs/skynet-accounts/jwt" - jwt2 "github.com/lestrrat-go/jwx/jwt" - "github.com/julienschmidt/httprouter" "gitlab.com/NebulousLabs/errors" ) @@ -54,6 +51,7 @@ func (api *API) buildHTTPRoutes() { api.staticRouter.PUT("/user", api.WithDBSession(api.withAuth(api.userPUT))) api.staticRouter.DELETE("/user", api.withAuth(api.userDELETE)) api.staticRouter.GET("/user/limits", api.noAuth(api.userLimitsGET)) + api.staticRouter.GET("/user/limits/:skylink", api.noAuth(api.userLimitsSkylinkGET)) api.staticRouter.GET("/user/stats", api.withAuth(api.userStatsGET)) api.staticRouter.GET("/user/pubkey/register", api.WithDBSession(api.withAuth(api.userPubKeyRegisterGET))) api.staticRouter.POST("/user/pubkey/register", api.WithDBSession(api.withAuth(api.userPubKeyRegisterPOST))) @@ -63,7 +61,10 @@ func (api *API) buildHTTPRoutes() { // Endpoints for user API keys. api.staticRouter.POST("/user/apikeys", api.WithDBSession(api.withAuth(api.userAPIKeyPOST))) - api.staticRouter.GET("/user/apikeys", api.withAuth(api.userAPIKeyGET)) + api.staticRouter.GET("/user/apikeys", api.withAuth(api.userAPIKeyLIST)) + api.staticRouter.GET("/user/apikeys/:id", api.withAuth(api.userAPIKeyGET)) + api.staticRouter.PUT("/user/apikeys/:id", api.WithDBSession(api.withAuth(api.userAPIKeyPUT))) + api.staticRouter.PATCH("/user/apikeys/:id", api.WithDBSession(api.withAuth(api.userAPIKeyPATCH))) api.staticRouter.DELETE("/user/apikeys/:id", api.withAuth(api.userAPIKeyDELETE)) // Endpoints for email communication with the user. @@ -91,45 +92,24 @@ func (api *API) noAuth(h HandlerWithUser) httprouter.Handle { func (api *API) withAuth(h HandlerWithUser) httprouter.Handle { return func(w http.ResponseWriter, req *http.Request, ps httprouter.Params) { api.logRequest(req) - var u *database.User - var token jwt2.Token - // Check for an API key. We only return an error if an invalid API key - // is provided. - ak, err := apiKeyFromRequest(req) - if err == nil { - // We have an API key. Let's generate a token based on it. - token, err = api.tokenFromAPIKey(req.Context(), ak) - u, err = api.staticDB.UserByAPIKey(req.Context(), ak) - if err != nil { - api.staticLogger.Debugf("Error fetching user for API key %s. Error: %s", ak, err) - api.WriteError(w, errors.AddContext(err, "failed to fetch user by API key"), http.StatusUnauthorized) - return - } - } else { - // No API key. Let's check for a token in the request. - token, err = tokenFromRequest(req) - if err != nil { - api.staticLogger.Debugln("Error fetching token from request:", err) - api.WriteError(w, err, http.StatusUnauthorized) - return - } - sub, _, _, err := jwt.TokenFields(token) + // Check for an API key. + u, token, err := api.userAndTokenByAPIKey(req) + // If there is an unexpected error, that is a 500. + if err != nil && !errors.Contains(err, ErrNoAPIKey) && !errors.Contains(err, database.ErrInvalidAPIKey) && !errors.Contains(err, database.ErrUserNotFound) { + api.WriteError(w, err, http.StatusInternalServerError) + return + } + if err != nil && (errors.Contains(err, database.ErrInvalidAPIKey) || errors.Contains(err, database.ErrUserNotFound)) { + api.WriteError(w, errors.AddContext(err, "failed to fetch user by API key"), http.StatusUnauthorized) + return + } + // If there is no API key check for a token. + if errors.Contains(err, ErrNoAPIKey) { + u, token, err = api.userAndTokenByRequestToken(req) if err != nil { - api.staticLogger.Debugln("Error decoding token from request:", err) - api.WriteError(w, err, http.StatusUnauthorized) - return - } - u, err = api.staticDB.UserBySub(req.Context(), sub) - if errors.Contains(err, database.ErrUserNotFound) { - api.staticLogger.Debugln("User that created this token no longer exists:", err) api.WriteError(w, err, http.StatusUnauthorized) return } - if err != nil { - api.staticLogger.Debugln("Error fetching user by token from request:", err) - api.WriteError(w, err, http.StatusInternalServerError) - return - } } // Embed the verified token in the context of the request. ctx := jwt.ContextWithToken(req.Context(), token) @@ -140,66 +120,9 @@ func (api *API) withAuth(h HandlerWithUser) httprouter.Handle { // logRequest logs information about the current request. func (api *API) logRequest(r *http.Request) { hasAuth := strings.HasPrefix(r.Header.Get("Authorization"), "Bearer") + hasAPIKey := r.Header.Get(APIKeyHeader) != "" || r.FormValue("apiKey") != "" c, err := r.Cookie(CookieName) hasCookie := err == nil && c != nil - api.staticLogger.Tracef("Processing request: %v %v, Auth: %v, Skynet Cookie: %v, Referer: %v, Host: %v, RemoreAddr: %v", r.Method, r.URL, hasAuth, hasCookie, r.Referer(), r.Host, r.RemoteAddr) -} - -// tokenFromAPIKey returns a token, generated for the owner of the given API key. -func (api *API) tokenFromAPIKey(ctx context.Context, ak database.APIKey) (jwt2.Token, error) { - u, err := api.staticDB.UserByAPIKey(ctx, ak) - if err != nil { - return nil, err - } - tk, err := jwt.TokenForUser(u.Email, u.Sub) - return tk, err -} - -// apiKeyFromRequest extracts the API key from the request and returns it. -// It first checks the headers and then the query. -func apiKeyFromRequest(r *http.Request) (database.APIKey, error) { - // Check the headers for an API key. - akStr := r.Header.Get(APIKeyHeader) - // If there is no API key in the headers, try the query. - if akStr == "" { - akStr = r.FormValue("apiKey") - } - if akStr == "" { - return "", ErrNoAPIKey - } - ak, err := database.NewAPIKeyFromString(akStr) - if err != nil { - return "", err - } - return *ak, nil -} - -// tokenFromRequest extracts the JWT token from the request and returns it. -// It first checks the request headers and then the cookies. -// The token is validated before being returned. -func tokenFromRequest(r *http.Request) (jwt2.Token, error) { - var tokenStr string - // Check the headers for a token. - parts := strings.Split(r.Header.Get("Authorization"), "Bearer") - if len(parts) == 2 { - tokenStr = strings.TrimSpace(parts[1]) - } else { - // Check the cookie for a token. - cookie, err := r.Cookie(CookieName) - if errors.Contains(err, http.ErrNoCookie) { - return nil, ErrNoToken - } - if err != nil { - return nil, errors.AddContext(err, "cookie exists but it's not valid") - } - err = secureCookie.Decode(CookieName, cookie.Value, &tokenStr) - if err != nil { - return nil, errors.AddContext(err, "failed to decode token") - } - } - token, err := jwt.ValidateToken(tokenStr) - if err != nil { - return nil, errors.AddContext(err, "failed to validate token") - } - return token, nil + api.staticLogger.Tracef("Processing request: %v %v, Auth: %v, API Key: %v, Cookie: %v, Referer: %v, Host: %v, RemoreAddr: %v", + r.Method, r.URL, hasAuth, hasAPIKey, hasCookie, r.Referer(), r.Host, r.RemoteAddr) } diff --git a/api/stripe.go b/api/stripe.go index 62391c40..3fdd3220 100644 --- a/api/stripe.go +++ b/api/stripe.go @@ -247,7 +247,7 @@ func (api *API) processStripeSub(ctx context.Context, s *stripe.Subscription) er api.staticLogger.Tracef("Subscribed user id '%s', tier %d, until %s.", u.ID, u.Tier, u.SubscribedUntil.String()) } // Re-set the tier cache for this user, in case their tier changed. - api.staticUserTierCache.Set(u) + api.staticUserTierCache.Set(u.Sub, u) return err } diff --git a/database/apikeys.go b/database/apikeys.go index f564ea07..7bc41ded 100644 --- a/database/apikeys.go +++ b/database/apikeys.go @@ -12,13 +12,25 @@ import ( "go.mongodb.org/mongo-driver/bson" "go.mongodb.org/mongo-driver/bson/primitive" "go.mongodb.org/mongo-driver/mongo" + "go.mongodb.org/mongo-driver/mongo/options" ) /** API keys are authentication tokens generated by users. They do not expire, thus allowing users to use them for a long time and to embed them in apps and on machines. API keys can be revoked when they are no longer needed or if they get -compromised. This is done by deleting them from this service. +compromised or are no longer needed. This is done by deleting them from this +service. + +There are two kinds of API keys - public and private. We differentiate between +them by the `public` flag. + +Private API keys give full API access - using them is equivalent to using a JWT +token, either via an authorization header or a cookie. + +Public API keys can only be use for downloading skylinks. The list of skylinks +that can be downloaded by a given public API key is stored under the `skylinks` +array within the API key record. */ var ( @@ -32,17 +44,26 @@ var ( ErrMaxNumAPIKeysExceeded = errors.New("maximum number of api keys exceeded") // ErrInvalidAPIKey is an error returned when the given API key is invalid. ErrInvalidAPIKey = errors.New("invalid api key") + // ErrInvalidAPIKeyOperation covers a range of invalid operations on API + // keys. Some examples include: defining a list of skylinks on a private + // API key, editing a private API key. This error should be used with + // additional context, specifying the exact operation that failed. + ErrInvalidAPIKeyOperation = errors.New("invalid api key operation") ) type ( // APIKey is the hex representation of a base32-encoded random 32-byte slice // length PubKeySize APIKey string - // APIKeyRecord is a non-expiring authentication token generated on user demand. + // APIKeyRecord is a non-expiring authentication token generated on user + // demand. Public API keys allow downloading a given set of skylinks, while + // private API keys give full API access. APIKeyRecord struct { ID primitive.ObjectID `bson:"_id,omitempty" json:"id"` UserID primitive.ObjectID `bson:"user_id" json:"-"` + Public bool `bson:"public,string" json:"public,string"` Key APIKey `bson:"key" json:"-"` + Skylinks []string `bson:"skylinks" json:"skylinks"` CreatedAt time.Time `bson:"created_at" json:"createdAt"` } ) @@ -84,8 +105,37 @@ func (ak *APIKey) LoadBytes(b []byte) error { return nil } +// LoadString loads a string into the API key and validates it. +func (ak *APIKey) LoadString(s string) error { + k := APIKey(strings.ToUpper(s)) + if !k.IsValid() { + return ErrInvalidAPIKey + } + *ak = k + return nil +} + +// LoadBytes encodes a []byte of size PubKeySize into an API key. +func (ak APIKey) String() string { + return string(ak) +} + +// CoversSkylink tells us whether a given API key covers a given skylink. +// Private API keys cover all skylinks while public ones - only a limited set. +func (akr APIKeyRecord) CoversSkylink(sl string) bool { + if !akr.Public { + return true + } + for _, s := range akr.Skylinks { + if s == sl { + return true + } + } + return false +} + // APIKeyCreate creates a new API key. -func (db *DB) APIKeyCreate(ctx context.Context, user User) (*APIKeyRecord, error) { +func (db *DB) APIKeyCreate(ctx context.Context, user User, public bool, skylinks []string) (*APIKeyRecord, error) { if user.ID.IsZero() { return nil, errors.New("invalid user") } @@ -96,9 +146,14 @@ func (db *DB) APIKeyCreate(ctx context.Context, user User) (*APIKeyRecord, error if n > int64(MaxNumAPIKeysPerUser) { return nil, ErrMaxNumAPIKeysExceeded } + if !public && len(skylinks) > 0 { + return nil, errors.AddContext(ErrInvalidAPIKeyOperation, "cannot define skylinks for a private api key") + } akr := APIKeyRecord{ UserID: user.ID, + Public: public, Key: NewAPIKey(), + Skylinks: skylinks, CreatedAt: time.Now().UTC(), } ior, err := db.staticAPIKeys.InsertOne(ctx, akr) @@ -110,16 +165,12 @@ func (db *DB) APIKeyCreate(ctx context.Context, user User) (*APIKeyRecord, error } // APIKeyDelete deletes an API key. -func (db *DB) APIKeyDelete(ctx context.Context, user User, akID string) error { +func (db *DB) APIKeyDelete(ctx context.Context, user User, akID primitive.ObjectID) error { if user.ID.IsZero() { return errors.New("invalid user") } - id, err := primitive.ObjectIDFromHex(akID) - if err != nil { - return errors.AddContext(err, "invalid API key ID") - } filter := bson.M{ - "_id": id, + "_id": akID, "user_id": user.ID, } dr, err := db.staticAPIKeys.DeleteOne(ctx, filter) @@ -132,6 +183,36 @@ func (db *DB) APIKeyDelete(ctx context.Context, user User, akID string) error { return nil } +// APIKeyByKey returns a specific API key. +func (db *DB) APIKeyByKey(ctx context.Context, key string) (APIKeyRecord, error) { + filter := bson.M{"key": key} + sr := db.staticAPIKeys.FindOne(ctx, filter) + if sr.Err() != nil { + return APIKeyRecord{}, sr.Err() + } + var akr APIKeyRecord + err := sr.Decode(&akr) + if err != nil { + return APIKeyRecord{}, err + } + return akr, nil +} + +// APIKeyGet returns a specific API key. +func (db *DB) APIKeyGet(ctx context.Context, akID primitive.ObjectID) (APIKeyRecord, error) { + filter := bson.M{"_id": akID} + sr := db.staticAPIKeys.FindOne(ctx, filter) + if sr.Err() != nil { + return APIKeyRecord{}, sr.Err() + } + var akr APIKeyRecord + err := sr.Decode(&akr) + if err != nil { + return APIKeyRecord{}, err + } + return akr, nil +} + // APIKeyList lists all API keys that belong to the user. func (db *DB) APIKeyList(ctx context.Context, user User) ([]APIKeyRecord, error) { if user.ID.IsZero() { @@ -150,3 +231,86 @@ func (db *DB) APIKeyList(ctx context.Context, user User) ([]APIKeyRecord, error) } return aks, nil } + +// APIKeyUpdate updates an existing API key. This works by replacing the +// list of Skylinks within the API key record. Only valid for public API keys. +func (db *DB) APIKeyUpdate(ctx context.Context, user User, akID primitive.ObjectID, skylinks []string) error { + if user.ID.IsZero() { + return errors.New("invalid user") + } + // Validate all given skylinks. + for _, s := range skylinks { + if !ValidSkylinkHash(s) { + return errors.AddContext(ErrInvalidSkylink, "offending skylink: "+s) + } + } + filter := bson.M{ + "_id": akID, + "public": true, + "user_id": user.ID, + } + update := bson.M{"$set": bson.M{"skylinks": skylinks}} + opts := options.UpdateOptions{ + Upsert: &False, + } + ur, err := db.staticAPIKeys.UpdateOne(ctx, filter, update, &opts) + if err != nil { + return err + } + if ur.ModifiedCount == 0 { + return mongo.ErrNoDocuments + } + return nil +} + +// APIKeyPatch updates an existing API key. This works by adding and removing +// skylinks to its record. Only valid for public API keys. +func (db *DB) APIKeyPatch(ctx context.Context, user User, akID primitive.ObjectID, addSkylinks, removeSkylinks []string) error { + if user.ID.IsZero() { + return errors.New("invalid user") + } + // Validate all given skylinks. + for _, s := range append(addSkylinks, removeSkylinks...) { + if !ValidSkylinkHash(s) { + return errors.AddContext(ErrInvalidSkylink, "offending skylink: "+s) + } + } + filter := bson.M{ + "_id": akID, + "public": &True, // you can only update public API keys + } + var update bson.M + // First, all new skylinks to the record. + if len(addSkylinks) > 0 { + update = bson.M{ + "$addToSet": bson.M{"skylinks": bson.M{"$each": addSkylinks}}, + } + opts := options.UpdateOptions{ + Upsert: &False, + } + ur, err := db.staticAPIKeys.UpdateOne(ctx, filter, update, &opts) + if err != nil { + return err + } + if ur.ModifiedCount == 0 { + return mongo.ErrNoDocuments + } + } + // Then, remove all skylinks that need to be removed. + if len(removeSkylinks) > 0 { + update = bson.M{ + "$pull": bson.M{"skylinks": bson.M{"$in": removeSkylinks}}, + } + opts := options.UpdateOptions{ + Upsert: &False, + } + ur, err := db.staticAPIKeys.UpdateOne(ctx, filter, update, &opts) + if err != nil { + return err + } + if ur.ModifiedCount == 0 { + return errors.New("public API key not found, no keys updated") + } + } + return nil +} diff --git a/database/apikeys_test.go b/database/apikeys_test.go index e2fc9032..6e24675a 100644 --- a/database/apikeys_test.go +++ b/database/apikeys_test.go @@ -1,6 +1,8 @@ package database -import "testing" +import ( + "testing" +) // TestNewAPIKeyFromString validates that NewAPIKeyFromString properly handles // valid API keys, upper case or lower case. @@ -27,3 +29,58 @@ func TestNewAPIKeyFromString(t *testing.T) { } } } + +// TestCoversSkylink ensures that CoversSkylink works as expected. +func TestCoversSkylink(t *testing.T) { + sl1 := "6TAOK0RVVKKK25PIA33FHDBD1G04DLO015DAAD6OM2J33KCD5CL0" + sl2 := "7TAOK0RVVKKK25PIA33FHDBD1G04DLO015DAAD6OM2J33KCD5CL0" + sl3 := "8TAOK0RVVKKK25PIA33FHDBD1G04DLO015DAAD6OM2J33KCD5CL0" + akr1 := APIKeyRecord{ + Public: false, + Key: NewAPIKey(), + } + akr2 := APIKeyRecord{ + Public: true, + Key: NewAPIKey(), + Skylinks: []string{sl1, sl2}, + } + + tests := []struct { + name string + key APIKeyRecord + skylinkToCheck string + expectedCovered bool + }{ + { + name: "general API key", + key: akr1, + skylinkToCheck: sl3, + expectedCovered: true, + }, + { + name: "public API key 1", + key: akr2, + skylinkToCheck: sl1, + expectedCovered: true, + }, + { + name: "public API key 2", + key: akr2, + skylinkToCheck: sl2, + expectedCovered: true, + }, + { + name: "public API key 3", + key: akr2, + skylinkToCheck: sl3, + expectedCovered: false, + }, + } + + for _, tt := range tests { + covered := tt.key.CoversSkylink(tt.skylinkToCheck) + if covered != tt.expectedCovered { + t.Errorf("Unexpected result for test %s", tt.name) + } + } +} diff --git a/database/database.go b/database/database.go index ca71d16a..14344844 100644 --- a/database/database.go +++ b/database/database.go @@ -127,30 +127,29 @@ func NewCustomDB(ctx context.Context, dbName string, creds DBCredentials, logger if err != nil { return nil, errors.AddContext(err, "failed to connect to DB") } - database := c.Database(dbName) + db := c.Database(dbName) if logger == nil { logger = &logrus.Logger{} } - err = ensureDBSchema(ctx, database, logger) + err = ensureDBSchema(ctx, db, Schema, logger) if err != nil { return nil, err } - db := &DB{ - staticDB: database, - staticUsers: database.Collection(collUsers), - staticSkylinks: database.Collection(collSkylinks), - staticUploads: database.Collection(collUploads), - staticDownloads: database.Collection(collDownloads), - staticRegistryReads: database.Collection(collRegistryReads), - staticRegistryWrites: database.Collection(collRegistryWrites), - staticEmails: database.Collection(collEmails), - staticChallenges: database.Collection(collChallenges), - staticUnconfirmedUserUpdates: database.Collection(collUnconfirmedUserUpdates), - staticConfiguration: database.Collection(collConfiguration), - staticAPIKeys: database.Collection(collAPIKeys), + return &DB{ + staticDB: db, + staticUsers: db.Collection(collUsers), + staticSkylinks: db.Collection(collSkylinks), + staticUploads: db.Collection(collUploads), + staticDownloads: db.Collection(collDownloads), + staticRegistryReads: db.Collection(collRegistryReads), + staticRegistryWrites: db.Collection(collRegistryWrites), + staticEmails: db.Collection(collEmails), + staticChallenges: db.Collection(collChallenges), + staticUnconfirmedUserUpdates: db.Collection(collUnconfirmedUserUpdates), + staticConfiguration: db.Collection(collConfiguration), + staticAPIKeys: db.Collection(collAPIKeys), staticLogger: logger, - } - return db, nil + }, nil } // Disconnect closes the connection to the database in an orderly fashion. @@ -188,114 +187,7 @@ func connectionString(creds DBCredentials) string { // creates them if needed. // See https://docs.mongodb.com/manual/indexes/ // See https://docs.mongodb.com/manual/core/index-unique/ -func ensureDBSchema(ctx context.Context, db *mongo.Database, log *logrus.Logger) error { - // schema defines a mapping between a collection name and the indexes that - // must exist for that collection. - schema := map[string][]mongo.IndexModel{ - collUsers: { - { - Keys: bson.D{{"sub", 1}}, - Options: options.Index().SetName("sub_unique").SetUnique(true), - }, - }, - collSkylinks: { - { - Keys: bson.D{{"skylink", 1}}, - Options: options.Index().SetName("skylink_unique").SetUnique(true), - }, - }, - collUploads: { - { - Keys: bson.D{{"user_id", 1}}, - Options: options.Index().SetName("user_id"), - }, - { - Keys: bson.D{{"skylink_id", 1}}, - Options: options.Index().SetName("skylink_id"), - }, - }, - collDownloads: { - { - Keys: bson.D{{"user_id", 1}}, - Options: options.Index().SetName("user_id"), - }, - { - Keys: bson.D{{"skylink_id", 1}}, - Options: options.Index().SetName("skylink_id"), - }, - }, - collRegistryReads: { - { - Keys: bson.D{{"user_id", 1}}, - Options: options.Index().SetName("user_id"), - }, - }, - collRegistryWrites: { - { - Keys: bson.D{{"user_id", 1}}, - Options: options.Index().SetName("user_id"), - }, - }, - collEmails: { - { - Keys: bson.D{{"failed_attempts", 1}}, - Options: options.Index().SetName("failed_attempts"), - }, - { - Keys: bson.D{{"locked_by", 1}}, - Options: options.Index().SetName("locked_by"), - }, - { - Keys: bson.D{{"sent_at", 1}}, - Options: options.Index().SetName("sent_at"), - }, - { - Keys: bson.D{{"sent_by", 1}}, - Options: options.Index().SetName("sent_by"), - }, - }, - collChallenges: { - { - Keys: bson.D{{"challenge", 1}}, - Options: options.Index().SetName("challenge"), - }, - { - Keys: bson.D{{"type", 1}}, - Options: options.Index().SetName("type"), - }, - { - Keys: bson.D{{"expires_at", 1}}, - Options: options.Index().SetName("expires_at"), - }, - }, - collUnconfirmedUserUpdates: { - { - Keys: bson.D{{"challenge_id", 1}}, - Options: options.Index().SetName("challenge_id"), - }, - { - Keys: bson.D{{"expires_at", 1}}, - Options: options.Index().SetName("expires_at"), - }, - }, - collConfiguration: { - { - Keys: bson.D{{"key", 1}}, - Options: options.Index().SetName("key_unique").SetUnique(true), - }, - }, - collAPIKeys: { - { - Keys: bson.D{{"key", 1}}, - Options: options.Index().SetName("key_unique").SetUnique(true), - }, - { - Keys: bson.D{{"user_id", 1}}, - Options: options.Index().SetName("user_id"), - }, - }, - } - +func ensureDBSchema(ctx context.Context, db *mongo.Database, schema map[string][]mongo.IndexModel, log *logrus.Logger) error { for collName, models := range schema { coll, err := ensureCollection(ctx, db, collName) if err != nil { diff --git a/database/download.go b/database/download.go index b3a5546a..08d54d98 100644 --- a/database/download.go +++ b/database/download.go @@ -56,7 +56,7 @@ func (db *DB) DownloadCreate(ctx context.Context, user User, skylink Skylink, by return errors.New("invalid user") } if skylink.ID.IsZero() { - return errors.New("invalid skylink") + return ErrInvalidSkylink } // Check if there exists a download of this skylink by this user, updated @@ -84,7 +84,7 @@ func (db *DB) DownloadCreate(ctx context.Context, user User, skylink Skylink, by // number of such downloads. func (db *DB) DownloadsBySkylink(ctx context.Context, skylink Skylink, offset, pageSize int) ([]DownloadResponse, int, error) { if skylink.ID.IsZero() { - return nil, 0, errors.New("invalid skylink") + return nil, 0, ErrInvalidSkylink } if err := validateOffsetPageSize(offset, pageSize); err != nil { return nil, 0, err diff --git a/database/schema.go b/database/schema.go new file mode 100644 index 00000000..34138e35 --- /dev/null +++ b/database/schema.go @@ -0,0 +1,116 @@ +package database + +import ( + "go.mongodb.org/mongo-driver/bson" + "go.mongodb.org/mongo-driver/mongo" + "go.mongodb.org/mongo-driver/mongo/options" +) + +var ( + // Schema defines a mapping between a collection name and the indexes that + // must exist for that collection. + Schema = map[string][]mongo.IndexModel{ + collUsers: { + { + Keys: bson.D{{"sub", 1}}, + Options: options.Index().SetName("sub_unique").SetUnique(true), + }, + }, + collSkylinks: { + { + Keys: bson.D{{"skylink", 1}}, + Options: options.Index().SetName("skylink_unique").SetUnique(true), + }, + }, + collUploads: { + { + Keys: bson.D{{"user_id", 1}}, + Options: options.Index().SetName("user_id"), + }, + { + Keys: bson.D{{"skylink_id", 1}}, + Options: options.Index().SetName("skylink_id"), + }, + }, + collDownloads: { + { + Keys: bson.D{{"user_id", 1}}, + Options: options.Index().SetName("user_id"), + }, + { + Keys: bson.D{{"skylink_id", 1}}, + Options: options.Index().SetName("skylink_id"), + }, + }, + collRegistryReads: { + { + Keys: bson.D{{"user_id", 1}}, + Options: options.Index().SetName("user_id"), + }, + }, + collRegistryWrites: { + { + Keys: bson.D{{"user_id", 1}}, + Options: options.Index().SetName("user_id"), + }, + }, + collEmails: { + { + Keys: bson.D{{"failed_attempts", 1}}, + Options: options.Index().SetName("failed_attempts"), + }, + { + Keys: bson.D{{"locked_by", 1}}, + Options: options.Index().SetName("locked_by"), + }, + { + Keys: bson.D{{"sent_at", 1}}, + Options: options.Index().SetName("sent_at"), + }, + { + Keys: bson.D{{"sent_by", 1}}, + Options: options.Index().SetName("sent_by"), + }, + }, + collChallenges: { + { + Keys: bson.D{{"challenge", 1}}, + Options: options.Index().SetName("challenge"), + }, + { + Keys: bson.D{{"type", 1}}, + Options: options.Index().SetName("type"), + }, + { + Keys: bson.D{{"expires_at", 1}}, + Options: options.Index().SetName("expires_at"), + }, + }, + collUnconfirmedUserUpdates: { + { + Keys: bson.D{{"challenge_id", 1}}, + Options: options.Index().SetName("challenge_id"), + }, + { + Keys: bson.D{{"expires_at", 1}}, + Options: options.Index().SetName("expires_at"), + }, + }, + collConfiguration: { + { + Keys: bson.D{{"key", 1}}, + Options: options.Index().SetName("key_unique").SetUnique(true), + }, + }, + collAPIKeys: { + { + Keys: bson.D{{"key", 1}}, + Options: options.Index().SetName("key_unique").SetUnique(true), + }, + { + Keys: bson.D{{"user_id", 1}}, + Options: options.Index().SetName("user_id"), + }, + }, + } +) diff --git a/database/upload.go b/database/upload.go index 2ff5fb88..93c24686 100644 --- a/database/upload.go +++ b/database/upload.go @@ -69,7 +69,7 @@ func (db *DB) UploadCreate(ctx context.Context, user User, skylink Skylink) (*Up // number of such uploads. func (db *DB) UploadsBySkylink(ctx context.Context, skylink Skylink, offset, pageSize int) ([]UploadResponse, int, error) { if skylink.ID.IsZero() { - return nil, 0, errors.New("invalid skylink") + return nil, 0, ErrInvalidSkylink } if err := validateOffsetPageSize(offset, pageSize); err != nil { return nil, 0, err @@ -85,7 +85,7 @@ func (db *DB) UploadsBySkylink(ctx context.Context, skylink Skylink, offset, pag // the number of unpinned uploads. func (db *DB) UnpinUploads(ctx context.Context, skylink Skylink, user User) (int64, error) { if skylink.ID.IsZero() { - return 0, errors.New("invalid skylink") + return 0, ErrInvalidSkylink } if user.ID.IsZero() { return 0, errors.New("invalid user") diff --git a/database/user.go b/database/user.go index 3a43fb5d..de53d08d 100644 --- a/database/user.go +++ b/database/user.go @@ -48,6 +48,8 @@ const ( var ( // True is a helper for when we need to pass a *bool to MongoDB. True = true + // False is a helper for when we need to pass a *bool to MongoDB. + False = false // UserLimits defines the speed limits for each tier. // RegistryDelay delay is in ms. UserLimits = map[int]TierLimits{ @@ -158,20 +160,6 @@ type ( } ) -// UserByAPIKey returns the user who owns the given API key. -func (db *DB) UserByAPIKey(ctx context.Context, ak APIKey) (*User, error) { - sr := db.staticAPIKeys.FindOne(ctx, bson.M{"key": ak}) - if sr.Err() != nil { - return nil, sr.Err() - } - var apiKey APIKeyRecord - err := sr.Decode(&apiKey) - if err != nil { - return nil, err - } - return db.UserByID(ctx, apiKey.UserID) -} - // UserByEmail returns the user with the given username. func (db *DB) UserByEmail(ctx context.Context, email string) (*User, error) { users, err := db.managedUsersByField(ctx, "email", email) diff --git a/env/README.md b/env/README.md new file mode 100644 index 00000000..7373bc76 --- /dev/null +++ b/env/README.md @@ -0,0 +1,4 @@ +## The `env` image + +The purpose of this image is to generate resources needed by the service. It currently only generates +a [JWKS](https://auth0.com/docs/secure/tokens/json-web-tokens/json-web-key-sets), needed for issuing JWTs. diff --git a/main.go b/main.go index aa98e89b..9d696cb3 100644 --- a/main.go +++ b/main.go @@ -140,7 +140,7 @@ func parseConfiguration(logger *logrus.Logger) (ServiceConfig, error) { config.ServerLockID = config.PortalName logger.Warningf(`Environment variable %s is missing! This server's identity`+ ` is set to the default '%s' value. That is OK only if this server is running on its own`+ - ` and it's not sharing its DB with other nodes.\n`, envServerDomain, config.ServerLockID) + ` and it's not sharing its DB with other nodes.`, envServerDomain, config.ServerLockID) } if sk := os.Getenv(envStripeAPIKey); sk != "" { diff --git a/test/api/api_test.go b/test/api/api_test.go index 1a0d17d6..80ec83f5 100644 --- a/test/api/api_test.go +++ b/test/api/api_test.go @@ -1,7 +1,6 @@ package api import ( - "bytes" "context" "encoding/hex" "fmt" @@ -21,32 +20,9 @@ import ( "gitlab.com/NebulousLabs/errors" ) -// TestResponseWriter is a testing ResponseWriter implementation. -type TestResponseWriter struct { - Buffer bytes.Buffer - Status int -} - -// Header implementation. -func (w TestResponseWriter) Header() http.Header { - return http.Header{} -} - -// Write implementation. -func (w TestResponseWriter) Write(b []byte) (int, error) { - return w.Buffer.Write(b) -} - -// WriteHeader implementation. -func (w TestResponseWriter) WriteHeader(statusCode int) { - w.Status = statusCode -} - // TestWithDBSession ensures that database transactions are started, committed, // and aborted properly. func TestWithDBSession(t *testing.T) { - t.Parallel() - ctx := context.Background() dbName := test.DBNameForTest(t.Name()) db, err := database.NewCustomDB(ctx, dbName, test.DBTestCredentials(), nil) @@ -125,7 +101,7 @@ func TestWithDBSession(t *testing.T) { testAPI.WriteError(w, errors.New("error"), http.StatusInternalServerError) } - var rw TestResponseWriter + var rw test.ResponseWriter var ps httprouter.Params req := (&http.Request{}).WithContext(ctx) // Call the success handler. @@ -161,8 +137,6 @@ func TestWithDBSession(t *testing.T) { // TestUserTierCache ensures out tier cache works as expected. func TestUserTierCache(t *testing.T) { - t.Parallel() - dbName := test.DBNameForTest(t.Name()) at, err := test.NewAccountsTester(dbName) if err != nil { @@ -201,7 +175,7 @@ func TestUserTierCache(t *testing.T) { } at.SetCookie(test.ExtractCookie(r)) // Get the user's limit. - ul, _, err := at.UserLimits("byte") + ul, _, err := at.UserLimits("byte", nil) if err != nil { t.Fatal(err) } @@ -236,7 +210,7 @@ func TestUserTierCache(t *testing.T) { err = build.Retry(10, 200*time.Millisecond, func() error { // We expect to get tier with name and id matching TierPremium20 but with // speeds matching TierAnonymous. - ul, _, err = at.UserLimits("byte") + ul, _, err = at.UserLimits("byte", nil) if err != nil { t.Fatal(err) } @@ -262,7 +236,7 @@ func TestUserTierCache(t *testing.T) { } err = build.Retry(10, 200*time.Millisecond, func() error { // We expect to get TierPremium20. - ul, _, err = at.UserLimits("byte") + ul, _, err = at.UserLimits("byte", nil) if err != nil { return errors.AddContext(err, "failed to call /user/limits") } diff --git a/test/api/apikeys_test.go b/test/api/apikeys_test.go index 4f5c88fd..2f7f2c8e 100644 --- a/test/api/apikeys_test.go +++ b/test/api/apikeys_test.go @@ -6,6 +6,7 @@ import ( "net/url" "testing" + "github.com/SkynetLabs/skynet-accounts/api" "github.com/SkynetLabs/skynet-accounts/database" "github.com/SkynetLabs/skynet-accounts/skynet" "github.com/SkynetLabs/skynet-accounts/test" @@ -13,8 +14,9 @@ import ( "go.sia.tech/siad/modules" ) -// testAPIKeysFlow validates the creation, listing, and deletion of API keys. -func testAPIKeysFlow(t *testing.T, at *test.AccountsTester) { +// testPrivateAPIKeysFlow validates the creation, listing, and deletion of private +// API keys. +func testPrivateAPIKeysFlow(t *testing.T, at *test.AccountsTester) { name := test.DBNameForTest(t.Name()) r, body, err := at.CreateUserPost(name+"@siasky.net", name+"_pass") if err != nil { @@ -22,92 +24,72 @@ func testAPIKeysFlow(t *testing.T, at *test.AccountsTester) { } at.SetCookie(test.ExtractCookie(r)) - aks := make([]database.APIKeyRecord, 0) - // List all API keys this user has. Expect the list to be empty. - r, body, err = at.Get("/user/apikeys", nil) + aks, _, err := at.UserAPIKeysLIST() if err != nil { t.Fatal(err, string(body)) } - err = json.Unmarshal(body, &aks) - if err != nil { - t.Fatal(err) - } if len(aks) > 0 { t.Fatalf("Expected an empty list of API keys, got %+v.", aks) } // Create a new API key. - r, body, err = at.Post("/user/apikeys", nil, nil) + ak1, _, err := at.UserAPIKeysPOST(api.APIKeyPOST{}) if err != nil { t.Fatal(err, string(body)) } - var ak1 database.APIKeyRecord - err = json.Unmarshal(body, &ak1) - if err != nil { - t.Fatal(err) + // Make sure the API key is private. + if ak1.Public { + t.Fatal("Expected the API key to be private.") } // Create another API key. - r, body, err = at.Post("/user/apikeys", nil, nil) + ak2, _, err := at.UserAPIKeysPOST(api.APIKeyPOST{}) if err != nil { t.Fatal(err, string(body)) } - var ak2 database.APIKeyRecord - err = json.Unmarshal(body, &ak2) - if err != nil { - t.Fatal(err) - } // List all API keys this user has. Expect to find both keys we created. - r, body, err = at.Get("/user/apikeys", nil) + aks, _, err = at.UserAPIKeysLIST() if err != nil { t.Fatal(err, string(body)) } - err = json.Unmarshal(body, &aks) - if err != nil { - t.Fatal(err) - } if len(aks) != 2 { t.Fatalf("Expected two API keys, got %+v.", aks) } - if ak1.Key != aks[0].Key && ak1.Key != aks[1].Key { - t.Fatalf("Missing key '%s'! Set: %+v", ak1.Key, aks) + if ak1.ID.Hex() != aks[0].ID.Hex() && ak1.ID.Hex() != aks[1].ID.Hex() { + t.Fatalf("Missing key '%s'! Set: %+v", ak1.ID.Hex(), aks) } - if ak2.Key != aks[0].Key && ak2.Key != aks[1].Key { - t.Fatalf("Missing key '%s'! Set: %+v", ak2.Key, aks) + if ak2.ID.Hex() != aks[0].ID.Hex() && ak2.ID.Hex() != aks[1].ID.Hex() { + t.Fatalf("Missing key '%s'! Set: %+v", ak2.ID.Hex(), aks) } // Delete an API key. - r, body, err = at.Delete("/user/apikeys/"+ak1.ID.Hex(), nil) - if err != nil { - t.Fatal(err, string(body)) + status, err := at.UserAPIKeysDELETE(ak1.ID) + if err != nil || status != http.StatusNoContent { + t.Fatal(err, status) } // List all API keys this user has. Expect to find only the second one. - r, body, err = at.Get("/user/apikeys", nil) + aks, _, err = at.UserAPIKeysLIST() if err != nil { t.Fatal(err, string(body)) } - err = json.Unmarshal(body, &aks) - if err != nil { - t.Fatal(err) - } if len(aks) != 1 { t.Fatalf("Expected one API key, got %+v.", aks) } - if ak2.Key != aks[0].Key { - t.Fatalf("Missing key '%s'! Set: %+v", ak2.Key, aks) + if ak2.ID.Hex() != aks[0].ID.Hex() { + t.Fatalf("Missing key '%s'! Set: %+v", ak2.ID.Hex(), aks) } - // Try to delete the same key again. Expect a Bad Request. - r, body, err = at.Delete("/user/apikeys/"+ak1.ID.Hex(), nil) - if r.StatusCode != http.StatusBadRequest { - t.Fatalf("Expected status 400, got %d.", r.StatusCode) + // Try to delete the same key again. Expect a 404. + status, _ = at.UserAPIKeysDELETE(ak1.ID) + if status != http.StatusNotFound { + t.Fatalf("Expected status 404, got %d.", status) } } -// testAPIKeysUsage makes sure that we can use API keys to make API calls. -func testAPIKeysUsage(t *testing.T, at *test.AccountsTester) { +// testPrivateAPIKeysUsage makes sure that we can use API keys to make API calls. +func testPrivateAPIKeysUsage(t *testing.T, at *test.AccountsTester) { name := test.DBNameForTest(t.Name()) // Create a test user. email := name + "@siasky.net" @@ -127,28 +109,18 @@ func testAPIKeysUsage(t *testing.T, at *test.AccountsTester) { t.Fatal(err) } // Create a new API key. - _, body, err := at.Post("/user/apikeys", nil, nil) + akWithKey, _, err := at.UserAPIKeysPOST(api.APIKeyPOST{}) if err != nil { t.Fatal(err) } // Stop using the cookie, so we can test the API key. at.ClearCredentials() - // We use a custom struct and not the APIKeyRecord one because that one does - // not render the key in JSON form and therefore it won't unmarshal it, - // either. - var ak struct { - Key database.APIKey - } - err = json.Unmarshal(body, &ak) - if err != nil { - t.Fatal(err) - } // Get user stats without a cookie or headers - pass the API key via a query // variable. The main thing we want to see here is whether we get // an `Unauthorized` error or not but we'll validate the stats as well. params := url.Values{} - params.Set("apiKey", string(ak.Key)) - _, body, err = at.Get("/user/stats", params) + params.Set("apiKey", akWithKey.Key.String()) + _, body, err := at.Get("/user/stats", params) if err != nil { t.Fatal(err, string(body)) } @@ -162,3 +134,155 @@ func testAPIKeysUsage(t *testing.T, at *test.AccountsTester) { uploadSize, us.TotalDownloadsSize, us.NumUploads, skynet.BandwidthUploadCost(uploadSize), us.BandwidthUploads) } } + +// testPublicAPIKeysFlow validates the creation, listing, and deletion of public +// API keys. +func testPublicAPIKeysFlow(t *testing.T, at *test.AccountsTester) { + name := test.DBNameForTest(t.Name()) + r, body, err := at.CreateUserPost(name+"@siasky.net", name+"_pass") + if err != nil { + t.Fatal(err, string(body)) + } + at.Cookie = test.ExtractCookie(r) + + sl1 := "AQAh2vxStoSJ_M9tWcTgqebUWerCAbpMfn9xxa9E29UOuw" + sl2 := "AADDE7_5MJyl1DKyfbuQMY_XBOBC9bR7idiU6isp6LXxEw" + + // List all API keys this user has. Expect the list to be empty. + aks, _, err := at.UserAPIKeysLIST() + if err != nil { + t.Fatal(err) + } + if len(aks) > 0 { + t.Fatalf("Expected an empty list of API keys, got %+v.", aks) + } + // Create a public API key. + akPost := api.APIKeyPOST{ + Public: true, + Skylinks: []string{sl1}, + } + akr, s, err := at.UserAPIKeysPOST(akPost) + if err != nil || s != http.StatusOK { + t.Fatal(err) + } + // List all API keys again. Expect to find a key. + aks, _, err = at.UserAPIKeysLIST() + if err != nil { + t.Fatal(err) + } + if len(aks) != 1 { + t.Fatalf("Expected one API key, got %d.", len(aks)) + } + if aks[0].Skylinks[0] != sl1 { + t.Fatal("Unexpected skylinks list", aks[0].Skylinks) + } + // Update a public API key. Expect to go from sl1 to sl2. + s, err = at.UserAPIKeysPUT(akr.ID, api.APIKeyPUT{Skylinks: []string{sl2}}) + if err != nil { + t.Fatal(err) + } + // Get the key and verify the change. + akr1, _, err := at.UserAPIKeysGET(akr.ID) + if err != nil { + t.Fatal(err, string(body)) + } + if akr1.Skylinks[0] != sl2 { + t.Fatal("Unexpected skylinks list", aks[0].Skylinks) + } + // Patch a public API key. Expect to go from sl2 to sl1. + akPatch := api.APIKeyPATCH{ + Add: []string{sl1}, + Remove: []string{sl2}, + } + s, err = at.UserAPIKeysPATCH(akr.ID, akPatch) + if err != nil { + t.Fatal(err) + } + // List and verify the change. + aks, _, err = at.UserAPIKeysLIST() + if err != nil { + t.Fatal(err, string(body)) + } + if len(aks) != 1 { + t.Fatalf("Expected one API key, got %d.", len(aks)) + } + if aks[0].Skylinks[0] != sl1 { + t.Fatal("Unexpected skylinks list", aks[0].Skylinks) + } + // Delete a public API key. + status, err := at.UserAPIKeysDELETE(akr.ID) + if err != nil || status != http.StatusNoContent { + t.Fatal(err, status) + } + // List and verify the change. + aks, _, err = at.UserAPIKeysLIST() + if err != nil { + t.Fatal(err, string(body)) + } + if len(aks) != 0 { + t.Fatalf("Expected no API keys, got %d.", len(aks)) + } +} + +// testPublicAPIKeysUsage makes sure that we can use public API keys to make +// GET requests to covered skylinks and that we cannot use them for other +// requests. +func testPublicAPIKeysUsage(t *testing.T, at *test.AccountsTester) { + name := test.DBNameForTest(t.Name()) + // Create a test user. + email := name + "@siasky.net" + r, _, err := at.CreateUserPost(email, name+"_pass") + if err != nil { + t.Fatal(err) + } + at.SetCookie(test.ExtractCookie(r)) + // Get the user and create a test upload, so the stats won't be all zeros. + u, err := at.DB.UserByEmail(at.Ctx, email) + if err != nil { + t.Fatal(err) + } + uploadSize := int64(fastrand.Intn(int(modules.SectorSize / 2))) + sl, _, err := test.CreateTestUpload(at.Ctx, at.DB, u, uploadSize) + if err != nil { + t.Fatal(err) + } + sl2, _, err := test.CreateTestUpload(at.Ctx, at.DB, u, uploadSize) + if err != nil { + t.Fatal(err) + } + // Create a new public API key. + apiKeyPOST := api.APIKeyPOST{ + Public: true, + Skylinks: []string{sl.Skylink}, + } + akWithKey, _, err := at.UserAPIKeysPOST(apiKeyPOST) + if err != nil { + t.Fatal(err) + } + // Stop using the cookie, use the public API key instead. + at.SetAPIKey(akWithKey.Key.String()) + // Try to fetch the user's stats with the new public API key. + // Expect this to fail. + _, _, err = at.Get("/user/stats", nil) + if err == nil { + t.Fatal("Managed to get user stats with a public API key.") + } + // Get the user's limits for downloading a skylink covered by the public + // API key. Expect to get TierFree values. + ul, _, err := at.UserLimitsSkylink(sl.Skylink, "byte", nil) + if err != nil { + t.Fatal(err) + } + if ul.DownloadBandwidth != database.UserLimits[database.TierFree].DownloadBandwidth { + t.Fatalf("Expected to get download bandwidth of %d, got %d", database.UserLimits[database.TierFree].DownloadBandwidth, ul.DownloadBandwidth) + } + // Get the user's limits for downloading a skylink that is not covered by + // the public API key. Expect to get TierAnonymous values. + ul, _, err = at.UserLimitsSkylink(sl2.Skylink, "byte", nil) + if err != nil { + t.Fatal(err) + } + if ul.DownloadBandwidth != database.UserLimits[database.TierAnonymous].DownloadBandwidth { + t.Fatalf("Expected to get download bandwidth of %d, got %d", database.UserLimits[database.TierAnonymous].DownloadBandwidth, ul.DownloadBandwidth) + } +} diff --git a/test/api/handlers_test.go b/test/api/handlers_test.go index 2389233d..2a478feb 100644 --- a/test/api/handlers_test.go +++ b/test/api/handlers_test.go @@ -68,8 +68,10 @@ func TestHandlers(t *testing.T) { {name: "StandardUserFlow", test: testUserFlow}, {name: "Challenge-Response/Registration", test: testRegistration}, {name: "Challenge-Response/Login", test: testLogin}, - {name: "APIKeysFlow", test: testAPIKeysFlow}, - {name: "APIKeysUsage", test: testAPIKeysUsage}, + {name: "PrivateAPIKeysFlow", test: testPrivateAPIKeysFlow}, + {name: "PrivateAPIKeysUsage", test: testPrivateAPIKeysUsage}, + {name: "PublicAPIKeysFlow", test: testPublicAPIKeysFlow}, + {name: "PublicAPIKeysUsage", test: testPublicAPIKeysUsage}, } // Run subtests @@ -82,17 +84,10 @@ func TestHandlers(t *testing.T) { // testHandlerHealthGET tests the /health handler. func testHandlerHealthGET(t *testing.T, at *test.AccountsTester) { - _, b, err := at.Get("/health", nil) + status, _, err := at.HealthGet() if err != nil { t.Fatal(err) } - status := struct { - DBAlive bool `json:"dbAlive"` - }{} - err = json.Unmarshal(b, &status) - if err != nil { - t.Fatal("Failed to unmarshal service's response: ", err) - } // DBAlive should never be false because if we couldn't reach the DB, we // wouldn't have made it this far in the test. if !status.DBAlive { @@ -440,8 +435,14 @@ func testUserLimits(t *testing.T, at *test.AccountsTester) { at.SetCookie(c) defer at.ClearCredentials() + // Create an API key for this user. + akr, _, err := at.UserAPIKeysPOST(api.APIKeyPOST{}) + if err != nil { + t.Fatal(err) + } + // Call /user/limits with a cookie. Expect FreeTier response. - tl, _, err := at.UserLimits("byte") + tl, _, err := at.UserLimits("byte", nil) if err != nil { t.Fatal(err) } @@ -457,7 +458,7 @@ func testUserLimits(t *testing.T, at *test.AccountsTester) { // Call /user/limits without a cookie. Expect FreeAnonymous response. at.ClearCredentials() - tl, _, err = at.UserLimits("byte") + tl, _, err = at.UserLimits("byte", nil) if err != nil { t.Fatal(err) } @@ -471,6 +472,21 @@ func testUserLimits(t *testing.T, at *test.AccountsTester) { t.Fatalf("Expected download bandwidth '%d', got '%d'", database.UserLimits[database.TierAnonymous].DownloadBandwidth, tl.DownloadBandwidth) } + // Call /user/limits with an API key. Expect TierFree response. + tl, _, err = at.UserLimits("byte", map[string]string{api.APIKeyHeader: string(akr.Key)}) + if err != nil { + t.Fatal(err) + } + if tl.TierName != database.UserLimits[database.TierFree].TierName { + t.Fatalf("Expected to get the results for %s, got %s", database.UserLimits[database.TierFree].TierName, tl.TierName) + } + if tl.TierName != database.UserLimits[database.TierFree].TierName { + t.Fatalf("Expected tier name '%s', got '%s'", database.UserLimits[database.TierFree].TierName, tl.TierName) + } + if tl.DownloadBandwidth != database.UserLimits[database.TierFree].DownloadBandwidth { + t.Fatalf("Expected download bandwidth '%d', got '%d'", database.UserLimits[database.TierFree].DownloadBandwidth, tl.DownloadBandwidth) + } + // Create a new user which we'll use to test the quota limits. We can't use // the existing one because their status is already cached. u2, c, err := test.CreateUserAndLogin(at, t.Name()+"2") @@ -506,7 +522,7 @@ func testUserLimits(t *testing.T, at *test.AccountsTester) { err = build.Retry(10, 200*time.Millisecond, func() error { // Check the user's limits. We expect the tier to be Free but the limits to // match Anonymous. - tl, _, err = at.UserLimits("byte") + tl, _, err = at.UserLimits("byte", nil) if err != nil { return errors.AddContext(err, "failed to call /user/limits") } @@ -527,19 +543,19 @@ func testUserLimits(t *testing.T, at *test.AccountsTester) { // Test the `unit` parameter. The only valid value is `byte`, anything else // is ignored and the results are returned in bits per second. - tl, _, err = at.UserLimits("") + tl, _, err = at.UserLimits("", nil) if err != nil { t.Fatal(err) } // Request it with an invalid value. Expect it to be ignored. - tlBits, _, err := at.UserLimits("not-a-byte") + tlBits, _, err := at.UserLimits("not-a-byte", nil) if err != nil { t.Fatal(err) } if tlBits.UploadBandwidth != tl.UploadBandwidth || tlBits.DownloadBandwidth != tl.DownloadBandwidth { t.Fatalf("Expected these to be equal. %+v, %+v", tl, tlBits) } - tlBytes, _, err := at.UserLimits("byte") + tlBytes, _, err := at.UserLimits("byte", nil) if err != nil { t.Fatal(err) } @@ -547,7 +563,7 @@ func testUserLimits(t *testing.T, at *test.AccountsTester) { t.Fatalf("Invalid values in bytes. Values in bps: %+v, values in Bps: %+v", tl, tlBytes) } // Ensure we're not case-sensitive. - tlBytes2, _, err := at.UserLimits("ByTe") + tlBytes2, _, err := at.UserLimits("ByTe", nil) if err != nil { t.Fatal(err) } diff --git a/test/database/apikeys_test.go b/test/database/apikeys_test.go new file mode 100644 index 00000000..1eb0327b --- /dev/null +++ b/test/database/apikeys_test.go @@ -0,0 +1,149 @@ +package database + +import ( + "context" + "testing" + + "github.com/SkynetLabs/skynet-accounts/database" + "github.com/SkynetLabs/skynet-accounts/test" +) + +// TestAPIKeys ensures the DB operations with API keys work as expected. +func TestAPIKeys(t *testing.T) { + ctx := context.Background() + dbName := test.DBNameForTest(t.Name()) + db, err := database.NewCustomDB(ctx, dbName, test.DBTestCredentials(), nil) + if err != nil { + t.Fatal(err) + } + u, err := db.UserCreate(ctx, "", "", t.Name(), database.TierFree) + if err != nil { + t.Fatal("Unexpected error", err) + } + sl1 := test.RandomSkylink() + sl2 := test.RandomSkylink() + + // Create a private API key. + akr1, err := db.APIKeyCreate(ctx, *u, false, nil) + if err != nil { + t.Fatal(err) + } + // Create a private API key with skylinks. Expect to fail. + _, err = db.APIKeyCreate(ctx, *u, false, []string{sl1}) + if err == nil { + t.Fatal("Managed to create a private API key with skylinks.") + } + // Create a public API key + akr2, err := db.APIKeyCreate(ctx, *u, true, []string{sl1}) + if err != nil { + t.Fatal(err) + } + // Create a public API key without any skylinks. + akr3, err := db.APIKeyCreate(ctx, *u, true, nil) + if err != nil { + t.Fatal(err) + } + // Get an API key. + akr1a, err := db.APIKeyGet(ctx, akr1.ID) + if err != nil { + t.Fatal(err) + } + if akr1a.ID.Hex() != akr1.ID.Hex() { + t.Fatal("Did not get the correct API key!") + } + // Get an API key by key. + akr1a, err = db.APIKeyByKey(ctx, akr1.Key.String()) + if err != nil { + t.Fatal(err) + } + if akr1a.ID.Hex() != akr1.ID.Hex() { + t.Fatal("Did not get the correct API key by key!") + } + // List API keys. + akrs, err := db.APIKeyList(ctx, *u) + if err != nil { + t.Fatal(err) + } + if len(akrs) != 3 { + t.Fatalf("Expected %d API keys, got %d", 3, len(akrs)) + } + // Check if all keys we expect to exist actually exist. + found := 0 + for _, akr := range []*database.APIKeyRecord{akr1, akr2, akr3} { + for _, akrFound := range akrs { + if akrFound.ID.Hex() == akr.ID.Hex() { + found++ + } + } + } + if found != 3 { + t.Fatalf("Expected to find %d API keys we expect, found %d", 3, found) + } + + // Try to update a general API key. Expect to fail. + err = db.APIKeyUpdate(ctx, *u, akr1.ID, []string{sl1}) + if err == nil { + t.Fatal("Expected to be unable to update general API key.") + } + err = db.APIKeyPatch(ctx, *u, akr1.ID, []string{sl1}, nil) + if err == nil { + t.Fatal("Expected to be unable to patch general API key.") + } + // Update a public API key. + err = db.APIKeyUpdate(ctx, *u, akr2.ID, []string{sl1, sl2}) + if err != nil { + t.Fatal(err) + } + // Verify. + akr2a, err := db.APIKeyGet(ctx, akr2.ID) + if err != nil { + t.Fatal(err) + } + if !akr2a.CoversSkylink(sl1) || !akr2a.CoversSkylink(sl2) { + t.Fatal("Expected the API to cover both skylinks.") + } + // Patch a public API key. + err = db.APIKeyPatch(ctx, *u, akr2.ID, nil, []string{sl2}) + if err != nil { + t.Fatal(err) + } + // Verify. + akr2b, err := db.APIKeyGet(ctx, akr2.ID) + if err != nil { + t.Fatal(err) + } + if !akr2b.CoversSkylink(sl1) || akr2b.CoversSkylink(sl2) { + t.Fatal("Expected the API to cover one but not the other skylink.") + } + + // Delete a general API key. + err = db.APIKeyDelete(ctx, *u, akr1.ID) + if err != nil { + t.Fatal(err) + } + // Verify. + akrs, err = db.APIKeyList(ctx, *u) + if err != nil { + t.Fatal(err) + } + for _, akr := range akrs { + if akr.ID.Hex() == akr1.ID.Hex() { + t.Fatal("Expected the API key to be gone but it's not.") + } + } + // Delete a public API key. + err = db.APIKeyDelete(ctx, *u, akr2.ID) + if err != nil { + t.Fatal(err) + } + // Verify. + akrs, err = db.APIKeyList(ctx, *u) + if err != nil { + t.Fatal(err) + } + for _, akr := range akrs { + if akr.ID.Hex() == akr2.ID.Hex() { + t.Fatal("Expected the API key to be gone but it's not.") + } + } +} diff --git a/test/tester.go b/test/tester.go index 218c9449..0ed24e77 100644 --- a/test/tester.go +++ b/test/tester.go @@ -17,6 +17,7 @@ import ( "github.com/SkynetLabs/skynet-accounts/metafetcher" "github.com/sirupsen/logrus" "gitlab.com/NebulousLabs/errors" + "go.mongodb.org/mongo-driver/bson/primitive" "go.sia.tech/siad/build" ) @@ -33,7 +34,7 @@ type ( Ctx context.Context DB *database.DB Logger *logrus.Logger - // If set, this cookie will be attached to all requests. + APIKey string Cookie *http.Cookie Token string @@ -116,8 +117,8 @@ func NewAccountsTester(dbName string) (*AccountsTester, error) { } // Wait for the accounts tester to be fully ready. err = build.Retry(50, time.Millisecond, func() error { - _, _, err = at.Get("/health", nil) - return err + _, _, e := at.HealthGet() + return e }) if err != nil { return nil, errors.AddContext(err, "failed to start accounts tester in the given time") @@ -128,6 +129,7 @@ func NewAccountsTester(dbName string) (*AccountsTester, error) { // ClearCredentials removes any credentials stored by this tester, such as a // cookie, token, etc. func (at *AccountsTester) ClearCredentials() { + at.APIKey = "" at.Cookie = nil at.Token = "" } @@ -144,6 +146,13 @@ func (at *AccountsTester) Close() error { return nil } +// SetAPIKey ensures that all subsequent requests are going to use the given +// API key for authentication. +func (at *AccountsTester) SetAPIKey(ak string) { + at.ClearCredentials() + at.APIKey = ak +} + // SetCookie ensures that all subsequent requests are going to use the given // cookie for authentication. func (at *AccountsTester) SetCookie(c *http.Cookie) { @@ -161,22 +170,22 @@ func (at *AccountsTester) SetToken(t string) { // Get executes a GET request against the test service. // // NOTE: The Body of the returned response is already read and closed. -func (at *AccountsTester) Get(endpoint string, params url.Values) (r *http.Response, body []byte, err error) { - return at.request(http.MethodGet, endpoint, params, nil) +func (at *AccountsTester) Get(endpoint string, params url.Values) (*http.Response, []byte, error) { + return at.request(http.MethodGet, endpoint, params, nil, nil) } // Delete executes a DELETE request against the test service. // // NOTE: The Body of the returned response is already read and closed. -func (at *AccountsTester) Delete(endpoint string, params url.Values) (r *http.Response, body []byte, err error) { - return at.request(http.MethodDelete, endpoint, params, nil) +func (at *AccountsTester) Delete(endpoint string, params url.Values) (*http.Response, []byte, error) { + return at.request(http.MethodDelete, endpoint, params, nil, nil) } // Post executes a POST request against the test service. // // NOTE: The Body of the returned response is already read and closed. // TODO Remove the url.Values in favour of a simple map. -func (at *AccountsTester) Post(endpoint string, params url.Values, bodyParams url.Values) (r *http.Response, body []byte, err error) { +func (at *AccountsTester) Post(endpoint string, params url.Values, bodyParams url.Values) (*http.Response, []byte, error) { if params == nil { params = url.Values{} } @@ -189,12 +198,12 @@ func (at *AccountsTester) Post(endpoint string, params url.Values, bodyParams ur } bodyBytes, err := json.Marshal(bodyMap) if err != nil { - return + return &http.Response{}, nil, err } serviceURL := testPortalAddr + ":" + testPortalPort + endpoint + "?" + params.Encode() req, err := http.NewRequest(http.MethodPost, serviceURL, bytes.NewBuffer(bodyBytes)) if err != nil { - return nil, nil, err + return &http.Response{}, nil, err } req.Header.Set("Content-Type", "application/json") return at.executeRequest(req) @@ -203,14 +212,29 @@ func (at *AccountsTester) Post(endpoint string, params url.Values, bodyParams ur // Put executes a PUT request against the test service. // // NOTE: The Body of the returned response is already read and closed. -func (at *AccountsTester) Put(endpoint string, params url.Values, putParams url.Values) (r *http.Response, body []byte, err error) { - return at.request(http.MethodPut, endpoint, params, putParams) +func (at *AccountsTester) Put(endpoint string, params url.Values, bodyParams url.Values) (*http.Response, []byte, error) { + b, err := json.Marshal(bodyParams) + if err != nil { + return &http.Response{}, nil, errors.AddContext(err, "failed to marshal the body JSON") + } + return at.request(http.MethodPut, endpoint, params, b, nil) +} + +// Patch executes a PATCH request against the test service. +// +// NOTE: The Body of the returned response is already read and closed. +func (at *AccountsTester) Patch(endpoint string, params url.Values, bodyParams url.Values) (*http.Response, []byte, error) { + b, err := json.Marshal(bodyParams) + if err != nil { + return &http.Response{}, nil, errors.AddContext(err, "failed to marshal the body JSON") + } + return at.request(http.MethodPatch, endpoint, params, b, nil) } // CreateUserPost is a helper method that creates a new user. // // NOTE: The Body of the returned response is already read and closed. -func (at *AccountsTester) CreateUserPost(emailAddr, password string) (r *http.Response, body []byte, err error) { +func (at *AccountsTester) CreateUserPost(emailAddr, password string) (*http.Response, []byte, error) { params := url.Values{} params.Set("email", emailAddr) params.Set("password", password) @@ -228,11 +252,11 @@ func (at *AccountsTester) UserPUT(email, password, stipeID string) (*http.Respon "stripeCustomerId": stipeID, }) if err != nil { - return nil, nil, errors.AddContext(err, "failed to marshal the body JSON") + return &http.Response{}, nil, errors.AddContext(err, "failed to marshal the body JSON") } req, err := http.NewRequest(http.MethodPut, serviceURL, bytes.NewBuffer(b)) if err != nil { - return nil, nil, err + return &http.Response{}, nil, err } return at.executeRequest(req) } @@ -241,18 +265,17 @@ func (at *AccountsTester) UserPUT(email, password, stipeID string) (*http.Respon // request. It attaches the current cookie, if one exists. // // NOTE: The Body of the returned response is already read and closed. -func (at *AccountsTester) request(method string, endpoint string, queryParams url.Values, bodyParams url.Values) (*http.Response, []byte, error) { +func (at *AccountsTester) request(method string, endpoint string, queryParams url.Values, body []byte, headers map[string]string) (*http.Response, []byte, error) { if queryParams == nil { queryParams = url.Values{} } serviceURL := testPortalAddr + ":" + testPortalPort + endpoint + "?" + queryParams.Encode() - b, err := json.Marshal(bodyParams) + req, err := http.NewRequest(method, serviceURL, bytes.NewBuffer(body)) if err != nil { - return nil, nil, errors.AddContext(err, "failed to marshal the body JSON") + return &http.Response{}, nil, err } - req, err := http.NewRequest(method, serviceURL, bytes.NewBuffer(b)) - if err != nil { - return nil, nil, err + for name, val := range headers { + req.Header.Set(name, val) } return at.executeRequest(req) } @@ -263,7 +286,10 @@ func (at *AccountsTester) request(method string, endpoint string, queryParams ur // NOTE: The Body of the returned response is already read and closed. func (at *AccountsTester) executeRequest(req *http.Request) (*http.Response, []byte, error) { if req == nil { - return nil, nil, errors.New("invalid request") + return &http.Response{}, nil, errors.New("invalid request") + } + if at.APIKey != "" { + req.Header.Set(api.APIKeyHeader, at.APIKey) } if at.Cookie != nil { req.Header.Set("Cookie", at.Cookie.String()) @@ -274,8 +300,123 @@ func (at *AccountsTester) executeRequest(req *http.Request) (*http.Response, []b client := http.Client{} r, err := client.Do(req) if err != nil { - return nil, nil, err + return &http.Response{}, nil, err + } + return processResponse(r) +} + +// HealthGet executes a GET /health. +func (at *AccountsTester) HealthGet() (api.HealthGET, int, error) { + r, b, err := at.request(http.MethodGet, "/health", nil, nil, nil) + if err != nil { + return api.HealthGET{}, r.StatusCode, err + } + var resp api.HealthGET + err = json.Unmarshal(b, &resp) + if err != nil { + return api.HealthGET{}, 0, errors.AddContext(err, "failed to marshal the body JSON") + } + return resp, r.StatusCode, nil +} + +// UserAPIKeysDELETE performs a `DELETE /user/apikeys/:id` request. +func (at *AccountsTester) UserAPIKeysDELETE(id primitive.ObjectID) (int, error) { + r, _, err := at.request(http.MethodDelete, "/user/apikeys/"+id.Hex(), nil, nil, nil) + return r.StatusCode, err +} + +// UserAPIKeysGET performs a `GET /user/apikeys/:id` request. +func (at *AccountsTester) UserAPIKeysGET(id primitive.ObjectID) (api.APIKeyResponse, int, error) { + r, b, err := at.request(http.MethodGet, "/user/apikeys/"+id.Hex(), nil, nil, nil) + if err != nil { + return api.APIKeyResponse{}, r.StatusCode, err + } + if r.StatusCode != http.StatusOK { + return api.APIKeyResponse{}, r.StatusCode, errors.New(string(b)) + } + var result api.APIKeyResponse + err = json.Unmarshal(b, &result) + if err != nil { + return api.APIKeyResponse{}, 0, errors.AddContext(err, "failed to parse response") + } + return result, r.StatusCode, nil +} + +// UserAPIKeysLIST performs a `GET /user/apikeys` request. +func (at *AccountsTester) UserAPIKeysLIST() ([]api.APIKeyResponse, int, error) { + r, b, err := at.request(http.MethodGet, "/user/apikeys", nil, nil, nil) + if err != nil { + return nil, r.StatusCode, err + } + if r.StatusCode != http.StatusOK { + return nil, r.StatusCode, errors.New(string(b)) + } + result := make([]api.APIKeyResponse, 0) + err = json.Unmarshal(b, &result) + if err != nil { + return nil, 0, errors.AddContext(err, "failed to parse response") + } + return result, r.StatusCode, nil +} + +// UserAPIKeysPOST performs a `POST /user/apikeys` request. +func (at *AccountsTester) UserAPIKeysPOST(body api.APIKeyPOST) (api.APIKeyResponseWithKey, int, error) { + bb, err := json.Marshal(body) + if err != nil { + return api.APIKeyResponseWithKey{}, http.StatusBadRequest, err + } + r, b, err := at.request(http.MethodPost, "/user/apikeys", nil, bb, nil) + if err != nil { + return api.APIKeyResponseWithKey{}, r.StatusCode, err + } + if r.StatusCode != http.StatusOK { + return api.APIKeyResponseWithKey{}, r.StatusCode, errors.New(string(b)) + } + var result api.APIKeyResponseWithKey + err = json.Unmarshal(b, &result) + if err != nil { + return api.APIKeyResponseWithKey{}, 0, errors.AddContext(err, "failed to parse response") + } + return result, r.StatusCode, nil +} + +// UserAPIKeysPUT performs a `PUT /user/apikeys` request. +func (at *AccountsTester) UserAPIKeysPUT(akID primitive.ObjectID, body api.APIKeyPUT) (int, error) { + bb, err := json.Marshal(body) + if err != nil { + return http.StatusBadRequest, err } + r, b, err := at.request(http.MethodPut, "/user/apikeys/"+akID.Hex(), nil, bb, nil) + if err != nil { + return r.StatusCode, err + } + if r.StatusCode != http.StatusNoContent { + return r.StatusCode, errors.New(string(b)) + } + return r.StatusCode, nil +} + +// UserAPIKeysPATCH performs a `PATH /user/apikeys` request. +func (at *AccountsTester) UserAPIKeysPATCH(akID primitive.ObjectID, body api.APIKeyPATCH) (int, error) { + bb, err := json.Marshal(body) + if err != nil { + return http.StatusBadRequest, err + } + r, b, err := at.request(http.MethodPatch, "/user/apikeys/"+akID.Hex(), nil, bb, nil) + if err != nil { + return r.StatusCode, err + } + if r.StatusCode != http.StatusNoContent { + return r.StatusCode, errors.New(string(b)) + } + return r.StatusCode, nil +} + +// processResponse is a helper method which extracts the body from the response +// and handles non-OK status codes. +// +// NOTE: The Body of the returned response is already read and closed. +func processResponse(r *http.Response) (*http.Response, []byte, error) { body, err := ioutil.ReadAll(r.Body) _ = r.Body.Close() // For convenience, whenever we have a non-OK status we'll wrap it in an @@ -290,38 +431,65 @@ func (at *AccountsTester) executeRequest(req *http.Request) (*http.Response, []b func (at *AccountsTester) TrackDownload(skylink string, bytes int64) (int, error) { form := url.Values{} form.Set("bytes", fmt.Sprint(bytes)) - r, _, err := at.request(http.MethodPost, "/track/download/"+skylink, form, nil) + r, _, err := at.request(http.MethodPost, "/track/download/"+skylink, form, nil, nil) return r.StatusCode, err } // TrackUpload performs a `POST /track/upload/:skylink` request. func (at *AccountsTester) TrackUpload(skylink string) (int, error) { - r, _, err := at.request(http.MethodPost, "/track/upload/"+skylink, nil, nil) + r, _, err := at.request(http.MethodPost, "/track/upload/"+skylink, nil, nil, nil) return r.StatusCode, err } // TrackRegistryRead performs a `POST /track/registry/read` request. func (at *AccountsTester) TrackRegistryRead() (int, error) { - r, _, err := at.request(http.MethodPost, "/track/registry/read", nil, nil) + r, _, err := at.request(http.MethodPost, "/track/registry/read", nil, nil, nil) return r.StatusCode, err } // TrackRegistryWrite performs a `POST /track/registry/write` request. func (at *AccountsTester) TrackRegistryWrite() (int, error) { - r, _, err := at.request(http.MethodPost, "/track/registry/write", nil, nil) + r, _, err := at.request(http.MethodPost, "/track/registry/write", nil, nil, nil) return r.StatusCode, err } // UserLimits performs a `GET /user/limits` request. -func (at *AccountsTester) UserLimits(unit string) (api.UserLimitsGET, int, error) { +func (at *AccountsTester) UserLimits(unit string, headers map[string]string) (api.UserLimitsGET, int, error) { queryParams := url.Values{} if unit != "" { queryParams.Set("unit", unit) } - r, b, err := at.request(http.MethodGet, "/user/limits", queryParams, nil) + r, b, err := at.request(http.MethodGet, "/user/limits", queryParams, nil, headers) if err != nil { return api.UserLimitsGET{}, r.StatusCode, err } + if r.StatusCode != http.StatusOK { + return api.UserLimitsGET{}, r.StatusCode, errors.New(string(b)) + } + var resp api.UserLimitsGET + err = json.Unmarshal(b, &resp) + if err != nil { + return api.UserLimitsGET{}, 0, errors.AddContext(err, "failed to marshal the body JSON") + } + return resp, r.StatusCode, nil +} + +// UserLimitsSkylink performs a `GET /user/limits/:skylink` request. +func (at *AccountsTester) UserLimitsSkylink(sl string, unit string, headers map[string]string) (api.UserLimitsGET, int, error) { + queryParams := url.Values{} + if unit != "" { + queryParams.Set("unit", unit) + } + if !database.ValidSkylinkHash(sl) { + return api.UserLimitsGET{}, 0, database.ErrInvalidSkylink + } + r, b, err := at.request(http.MethodGet, "/user/limits/"+sl, queryParams, nil, headers) + if err != nil { + return api.UserLimitsGET{}, r.StatusCode, err + } + if r.StatusCode != http.StatusOK { + return api.UserLimitsGET{}, r.StatusCode, errors.New(string(b)) + } var resp api.UserLimitsGET err = json.Unmarshal(b, &resp) if err != nil { diff --git a/test/utils.go b/test/utils.go index e68850ac..039d31ed 100644 --- a/test/utils.go +++ b/test/utils.go @@ -1,6 +1,7 @@ package test import ( + "bytes" "context" "encoding/hex" "fmt" @@ -34,8 +35,28 @@ type ( *database.User staticDB *database.DB } + // ResponseWriter is a testing ResponseWriter implementation. + ResponseWriter struct { + Buffer bytes.Buffer + Status int + } ) +// Header implementation. +func (w ResponseWriter) Header() http.Header { + return http.Header{} +} + +// Write implementation. +func (w ResponseWriter) Write(b []byte) (int, error) { + return w.Buffer.Write(b) +} + +// WriteHeader implementation. +func (w ResponseWriter) WriteHeader(statusCode int) { + w.Status = statusCode +} + // Delete removes the test user from the DB. func (tu *User) Delete(ctx context.Context) error { return tu.staticDB.UserDelete(ctx, tu.User)