diff --git a/api/cache.go b/api/cache.go index 72013ecc..7b2717fc 100644 --- a/api/cache.go +++ b/api/cache.go @@ -22,8 +22,9 @@ type ( // userTierCacheEntry allows us to cache some basic information about the // user, so we don't need to hit the DB to fetch data that rarely changes. userTierCacheEntry struct { - Tier int - ExpiresAt time.Time + Tier int + QuotaExceeded bool + ExpiresAt time.Time } ) @@ -34,34 +35,25 @@ func newUserTierCache() *userTierCache { } } -// Get returns the user's tier and an OK indicator which is true when the cache -// entry exists and hasn't expired, yet. -func (utc *userTierCache) Get(sub string) (int, bool) { +// Get returns the user's tier, a quota exceeded flag, and an OK indicator +// which is true when the cache entry exists and hasn't expired, yet. +func (utc *userTierCache) Get(sub string) (int, bool, bool) { utc.mu.Lock() ce, exists := utc.cache[sub] utc.mu.Unlock() if !exists || ce.ExpiresAt.Before(time.Now().UTC()) { - return database.TierAnonymous, false + return database.TierAnonymous, false, false } - return ce.Tier, true + return ce.Tier, ce.QuotaExceeded, true } // Set stores the user's tier in the cache. func (utc *userTierCache) Set(u *database.User) { - var ce userTierCacheEntry - now := time.Now().UTC() - if u.QuotaExceeded { - ce = userTierCacheEntry{ - Tier: database.TierAnonymous, - ExpiresAt: now.Add(userTierCacheTTL), - } - } else { - ce = userTierCacheEntry{ - Tier: u.Tier, - ExpiresAt: now.Add(userTierCacheTTL), - } - } utc.mu.Lock() - utc.cache[u.Sub] = ce + utc.cache[u.Sub] = userTierCacheEntry{ + Tier: u.Tier, + QuotaExceeded: u.QuotaExceeded, + ExpiresAt: time.Now().UTC().Add(userTierCacheTTL), + } utc.mu.Unlock() } diff --git a/api/cache_test.go b/api/cache_test.go index 76c332d1..ba255533 100644 --- a/api/cache_test.go +++ b/api/cache_test.go @@ -17,17 +17,29 @@ func TestUserTierCache(t *testing.T) { QuotaExceeded: false, } // Get the user from the empty cache. - tier, ok := cache.Get(u.Sub) + tier, _, ok := cache.Get(u.Sub) if ok || tier != database.TierAnonymous { t.Fatalf("Expected to get tier %d and %t, got %d and %t.", database.TierAnonymous, false, tier, ok) } - // Set the use in the cache. + // Set the user in the cache. cache.Set(u) // Check again. - tier, ok = cache.Get(u.Sub) + tier, qe, ok := cache.Get(u.Sub) if !ok || tier != u.Tier { t.Fatalf("Expected to get tier %d and %t, got %d and %t.", u.Tier, true, tier, ok) } + if qe != u.QuotaExceeded { + t.Fatal("Quota exceeded flag doesn't match.") + } + u.QuotaExceeded = true + cache.Set(u) + tier, qe, ok = cache.Get(u.Sub) + if !ok || tier != u.Tier { + t.Fatalf("Expected to get tier %d and %t, got %d and %t.", u.Tier, true, tier, ok) + } + if qe != u.QuotaExceeded { + t.Fatal("Quota exceeded flag doesn't match.") + } ce, exists := cache.cache[u.Sub] if !exists { t.Fatal("Expected the entry to exist.") diff --git a/api/handlers.go b/api/handlers.go index 79d69b82..5db391da 100644 --- a/api/handlers.go +++ b/api/handlers.go @@ -80,8 +80,14 @@ type ( } // UserLimitsGET is response of GET /user/limits UserLimitsGET struct { - TierID int `json:"tierID"` - database.TierLimits + TierID int `json:"tierID"` + TierName string `json:"tierName"` + UploadBandwidth int `json:"upload"` // bytes per second + DownloadBandwidth int `json:"download"` // bytes per second + MaxUploadSize int64 `json:"maxUploadSize"` // the max size of a single upload in bytes + MaxNumberUploads int `json:"-"` + RegistryDelay int `json:"registry"` // ms delay + Storage int64 `json:"-"` } // accountRecoveryPOST defines the payload we expect when a user is trying @@ -398,10 +404,7 @@ func (api *API) userGET(u *database.User, w http.ResponseWriter, _ *http.Request func (api *API) userLimitsGET(_ *database.User, w http.ResponseWriter, req *http.Request, _ httprouter.Params) { // First check for an API key. ak, err := apiKeyFromRequest(req) - respAnon := UserLimitsGET{ - TierID: database.TierAnonymous, - TierLimits: database.UserLimits[database.TierAnonymous], - } + respAnon := userLimitsGetFromTier(database.TierAnonymous) if err == nil { u, err := api.staticDB.UserByAPIKey(req.Context(), ak) if err != nil { @@ -409,9 +412,15 @@ func (api *API) userLimitsGET(_ *database.User, w http.ResponseWriter, req *http api.WriteJSON(w, respAnon) return } - resp := UserLimitsGET{ - TierID: u.Tier, - TierLimits: database.UserLimits[u.Tier], + resp := userLimitsGetFromTier(u.Tier) + // If the quota is exceeded we should keep the user's tier but report + // anonymous-level speeds. + if u.QuotaExceeded { + // Report the speeds for tier anonymous. + resp = userLimitsGetFromTier(database.TierAnonymous) + // But keep reporting the user's actual tier and it's name. + resp.TierID = u.Tier + resp.TierName = database.UserLimits[u.Tier].TierName } api.WriteJSON(w, resp) return @@ -430,7 +439,7 @@ func (api *API) userLimitsGET(_ *database.User, w http.ResponseWriter, req *http sub := s.(string) // If the user is not cached, or they were cached too long ago we'll fetch // their data from the DB. - tier, ok := api.staticUserTierCache.Get(sub) + tier, qe, ok := api.staticUserTierCache.Get(sub) if !ok { u, err := api.staticDB.UserBySub(req.Context(), sub) if err != nil { @@ -439,14 +448,22 @@ func (api *API) userLimitsGET(_ *database.User, w http.ResponseWriter, req *http return } api.staticUserTierCache.Set(u) + // Populate the tier and qe values, while simultaneously making sure + // that we can read the record from the cache. + tier, qe, ok = api.staticUserTierCache.Get(sub) + if !ok { + build.Critical("Failed to fetch user from UserTierCache right after setting it.") + } } - tier, ok = api.staticUserTierCache.Get(sub) - if !ok { - build.Critical("Failed to fetch user from UserTierCache right after setting it.") - } - resp := UserLimitsGET{ - TierID: tier, - TierLimits: database.UserLimits[tier], + resp := userLimitsGetFromTier(tier) + // If the quota is exceeded we should keep the user's tier but report + // anonymous-level speeds. + if qe { + // Report anonymous speeds. + resp = userLimitsGetFromTier(database.TierAnonymous) + // Keep reporting the user's actual tier and tier name. + resp.TierID = tier + resp.TierName = database.UserLimits[tier].TierName } api.WriteJSON(w, resp) } @@ -1177,3 +1194,19 @@ func fetchPageSize(form url.Values) (int, error) { func parseRequestBodyJSON(body io.ReadCloser, maxBodySize int64, objRef interface{}) error { return json.NewDecoder(io.LimitReader(body, maxBodySize)).Decode(&objRef) } + +// userLimitsGetFromTier is a helper that lets us succinctly translate +// from the database DTO to the API DTO. +func userLimitsGetFromTier(tier int) *UserLimitsGET { + t := database.UserLimits[tier] + return &UserLimitsGET{ + TierID: tier, + TierName: t.TierName, + UploadBandwidth: t.UploadBandwidth, + DownloadBandwidth: t.DownloadBandwidth, + MaxUploadSize: t.MaxUploadSize, + MaxNumberUploads: t.MaxNumberUploads, + RegistryDelay: t.RegistryDelay, + Storage: t.Storage, + } +} diff --git a/database/challenge.go b/database/challenge.go index b8afd98d..fe184887 100644 --- a/database/challenge.go +++ b/database/challenge.go @@ -94,7 +94,7 @@ type ( // NewChallenge creates a new challenge with the given type and pubKey. func (db *DB) NewChallenge(ctx context.Context, pubKey PubKey, cType string) (*Challenge, error) { if cType != ChallengeTypeLogin && cType != ChallengeTypeRegister && cType != ChallengeTypeUpdate { - return nil, errors.New(fmt.Sprintf("invalid challenge type '%s'", cType)) + return nil, fmt.Errorf("invalid challenge type '%s'", cType) } ch := &Challenge{ Challenge: hex.EncodeToString(fastrand.Bytes(ChallengeSize)), diff --git a/test/api/api_test.go b/test/api/api_test.go index fc964f41..0c6ef578 100644 --- a/test/api/api_test.go +++ b/test/api/api_test.go @@ -4,7 +4,7 @@ import ( "bytes" "context" "encoding/hex" - "encoding/json" + "fmt" "net/http" "net/url" "testing" @@ -14,6 +14,7 @@ import ( "github.com/SkynetLabs/skynet-accounts/database" "github.com/SkynetLabs/skynet-accounts/test" "gitlab.com/NebulousLabs/fastrand" + "go.sia.tech/siad/build" "github.com/julienschmidt/httprouter" "github.com/sirupsen/logrus" @@ -198,15 +199,9 @@ func TestUserTierCache(t *testing.T) { if err != nil { t.Fatal(err) } - at.Cookie = test.ExtractCookie(r) - // Get the user's limit. Since they are on a Pro account but their - // SubscribedUntil is set in the past, we expect to get TierFree. - _, b, err := at.Get("/user/limits", nil) - if err != nil { - t.Fatal(err) - } - var ul api.UserLimitsGET - err = json.Unmarshal(b, &ul) + at.SetCookie(test.ExtractCookie(r)) + // Get the user's limit. + ul, _, err := at.UserLimits() if err != nil { t.Fatal(err) } @@ -216,12 +211,11 @@ func TestUserTierCache(t *testing.T) { if ul.TierID != database.TierPremium20 { t.Fatalf("Expected tier id '%d', got '%d'", database.TierPremium20, ul.TierID) } - // Now set their SubscribedUntil in the future, so their subscription tier - // is active. - u.SubscribedUntil = time.Now().UTC().Add(365 * 24 * time.Hour) - err = at.DB.UserSave(at.Ctx, u.User) - if err != nil { - t.Fatal(err) + if ul.TierName != database.UserLimits[database.TierPremium20].TierName { + t.Fatalf("Expected tier name '%s', got '%s'", database.UserLimits[database.TierPremium20].TierName, ul.TierName) + } + if ul.UploadBandwidth != database.UserLimits[database.TierPremium20].UploadBandwidth { + t.Fatalf("Expected upload bandwidth '%d', got '%d'", database.UserLimits[database.TierPremium20].UploadBandwidth, ul.UploadBandwidth) } // Register a test upload that exceeds the user's allowed storage, so their // QuotaExceeded flag will get raised. @@ -232,45 +226,55 @@ func TestUserTierCache(t *testing.T) { // Make a specific call to trackUploadPOST in order to trigger the // checkUserQuotas method. This wil register the upload a second time but // that doesn't affect the test. - _, _, err = at.Post("/track/upload/"+sl.Skylink, nil, nil) - if err != nil { - t.Fatal(err) - } - // Sleep for a short time in order to make sure that the background - // goroutine that updates user's quotas has had time to run. - time.Sleep(2 * time.Second) - // We expect to get TierAnonymous. - _, b, err = at.Get("/user/limits", nil) + _, err = at.TrackUpload(sl.Skylink) if err != nil { t.Fatal(err) } - err = json.Unmarshal(b, &ul) + // We need to try this several times because we'll only get the right result + // after the background goroutine that updates user's quotas has had time to + // run. + err = build.Retry(10, 200*time.Millisecond, func() error { + // We expect to get tier with name and id matching TierPremium20 but with + // speeds matching TierAnonymous. + ul, _, err = at.UserLimits() + if err != nil { + t.Fatal(err) + } + if ul.TierID != database.TierPremium20 { + return fmt.Errorf("Expected tier id '%d', got '%d'", database.TierPremium20, ul.TierID) + } + if ul.TierName != database.UserLimits[database.TierPremium20].TierName { + return fmt.Errorf("Expected tier name '%s', got '%s'", database.UserLimits[database.TierPremium20].TierName, ul.TierName) + } + if ul.UploadBandwidth != database.UserLimits[database.TierAnonymous].UploadBandwidth { + return fmt.Errorf("Expected upload bandwidth '%d', got '%d'", database.UserLimits[database.TierAnonymous].UploadBandwidth, ul.UploadBandwidth) + } + return nil + }) if err != nil { t.Fatal(err) } - if ul.TierID != database.TierAnonymous { - t.Fatalf("Expected tier id '%d', got '%d'", database.TierAnonymous, ul.TierID) - } - if ul.TierName != database.UserLimits[database.TierAnonymous].TierName { - t.Fatalf("Expected tier name '%s', got '%s'", database.UserLimits[database.TierAnonymous].TierName, ul.TierName) - } // Delete the uploaded file, so the user's quota recovers. // This call should invalidate the tier cache. _, _, err = at.Delete("/user/uploads/"+sl.Skylink, nil) - time.Sleep(2 * time.Second) - // We expect to get TierPremium20. - _, b, err = at.Get("/user/limits", nil) if err != nil { t.Fatal(err) } - err = json.Unmarshal(b, &ul) + err = build.Retry(10, 200*time.Millisecond, func() error { + // We expect to get TierPremium20. + ul, _, err = at.UserLimits() + if err != nil { + return errors.AddContext(err, "failed to call /user/limits") + } + if ul.TierID != database.TierPremium20 { + return fmt.Errorf("Expected tier id '%d', got '%d'", database.TierPremium20, ul.TierID) + } + if ul.TierName != database.UserLimits[database.TierPremium20].TierName { + return fmt.Errorf("Expected tier name '%s', got '%s'", database.UserLimits[database.TierPremium20].TierName, ul.TierName) + } + return nil + }) if err != nil { t.Fatal(err) } - if ul.TierID != database.TierPremium20 { - t.Fatalf("Expected tier id '%d', got '%d'", database.TierPremium20, ul.TierID) - } - if ul.TierName != database.UserLimits[database.TierPremium20].TierName { - t.Fatalf("Expected tier name '%s', got '%s'", database.UserLimits[database.TierPremium20].TierName, ul.TierName) - } } diff --git a/test/api/apikeys_test.go b/test/api/apikeys_test.go index f6443b88..4f5c88fd 100644 --- a/test/api/apikeys_test.go +++ b/test/api/apikeys_test.go @@ -20,7 +20,7 @@ func testAPIKeysFlow(t *testing.T, at *test.AccountsTester) { if err != nil { t.Fatal(err, string(body)) } - at.Cookie = test.ExtractCookie(r) + at.SetCookie(test.ExtractCookie(r)) aks := make([]database.APIKeyRecord, 0) @@ -115,7 +115,7 @@ func testAPIKeysUsage(t *testing.T, at *test.AccountsTester) { if err != nil { t.Fatal(err) } - at.Cookie = test.ExtractCookie(r) + at.SetCookie(test.ExtractCookie(r)) // Get the user and create a test upload, so the stats won't be all zeros. u, err := at.DB.UserByEmail(at.Ctx, email) if err != nil { @@ -132,7 +132,7 @@ func testAPIKeysUsage(t *testing.T, at *test.AccountsTester) { t.Fatal(err) } // Stop using the cookie, so we can test the API key. - at.Cookie = nil + at.ClearCredentials() // We use a custom struct and not the APIKeyRecord one because that one does // not render the key in JSON form and therefore it won't unmarshal it, // either. diff --git a/test/api/challenge_test.go b/test/api/challenge_test.go index 3d2372a8..c3a91ab3 100644 --- a/test/api/challenge_test.go +++ b/test/api/challenge_test.go @@ -165,7 +165,7 @@ func testLogin(t *testing.T, at *test.AccountsTester) { t.Fatalf("Failed to login. Status %d, body '%s', error '%s'", r.StatusCode, string(b), err) } // Make sure we have a valid cookie returned and that it's for the same user. - at.Cookie = test.ExtractCookie(r) + at.SetCookie(test.ExtractCookie(r)) _, b, err = at.Get("/user", nil) if err != nil { t.Fatalf("Failed to fetch user with the given cookie: '%s', error '%s'", string(b), err) @@ -192,8 +192,8 @@ func testUserAddPubKey(t *testing.T, at *test.AccountsTester) { t.Error(errors.AddContext(err, "failed to delete user in defer")) } }() - at.Cookie = c - defer func() { at.Cookie = nil }() + at.SetCookie(c) + defer at.ClearCredentials() // Request a challenge without a pubkey. r, b, _ := at.Get("/user/pubkey/register", nil) @@ -248,7 +248,7 @@ func testUserAddPubKey(t *testing.T, at *test.AccountsTester) { } // Try to solve it without being logged in. - at.Cookie = nil + at.ClearCredentials() response := append(chBytes, append([]byte(database.ChallengeTypeUpdate), []byte(database.PortalName)...)...) bodyParams := url.Values{} bodyParams.Set("response", hex.EncodeToString(response)) @@ -266,7 +266,7 @@ func testUserAddPubKey(t *testing.T, at *test.AccountsTester) { if err != nil || r.StatusCode != http.StatusOK { t.Fatal(r.Status, err, string(b)) } - at.Cookie = test.ExtractCookie(r) + at.SetCookie(test.ExtractCookie(r)) r, b, _ = at.Post("/user/pubkey/register", nil, bodyParams) if r.StatusCode != http.StatusBadRequest || !strings.Contains(string(b), "user's sub doesn't match update sub") { t.Fatalf("Expected %d '%s', got %d '%s'", @@ -274,7 +274,7 @@ func testUserAddPubKey(t *testing.T, at *test.AccountsTester) { } // Request a new challenge with the original test user. - at.Cookie = c + at.SetCookie(c) queryParams = url.Values{} queryParams.Set("pubKey", hex.EncodeToString(pk[:])) r, b, err = at.Get("/user/pubkey/register", queryParams) diff --git a/test/api/handlers_test.go b/test/api/handlers_test.go index bd4c5fc5..3626ae34 100644 --- a/test/api/handlers_test.go +++ b/test/api/handlers_test.go @@ -3,6 +3,7 @@ package api import ( "encoding/hex" "encoding/json" + "fmt" "net/http" "net/url" "reflect" @@ -22,6 +23,7 @@ import ( "gitlab.com/SkynetLabs/skyd/skymodules" "go.mongodb.org/mongo-driver/bson" "go.mongodb.org/mongo-driver/mongo/options" + "go.sia.tech/siad/build" "go.sia.tech/siad/crypto" ) @@ -193,8 +195,8 @@ func testHandlerLoginPOST(t *testing.T, at *test.AccountsTester) { t.Fatal("Expected a cookie.") } // Make sure the returned cookie is usable for making requests. - at.Cookie = c - defer func() { at.Cookie = nil }() + at.SetCookie(c) + defer at.ClearCredentials() // Make sure the response contains a valid JWT. _, err = jwt.ValidateToken(r.Header.Get("Skynet-Token")) if err != nil { @@ -210,7 +212,7 @@ func testHandlerLoginPOST(t *testing.T, at *test.AccountsTester) { t.Fatal(err, string(b)) } // Expect the returned cookie to be already expired. - at.Cookie = test.ExtractCookie(r) + at.SetCookie(test.ExtractCookie(r)) if at.Cookie == nil { t.Fatal("Expected to have a cookie.") } @@ -251,16 +253,16 @@ func testUserPUT(t *testing.T, at *test.AccountsTester) { } }() - at.Cookie = c - defer func() { at.Cookie = nil }() + at.SetCookie(c) + defer at.ClearCredentials() // Call unauthorized. - at.Cookie = nil + at.ClearCredentials() _, _, err = at.Put("/user", nil, nil) if err == nil || !strings.Contains(err.Error(), unauthorized) { t.Fatalf("Expected error '%s', got '%s'", unauthorized, err) } - at.Cookie = c + at.SetCookie(c) // Update the user's Stripe ID. stripeID := name + "_stripe_id" _, b, err := at.UserPUT("", "", stripeID) @@ -357,8 +359,8 @@ func testUserDELETE(t *testing.T, at *test.AccountsTester) { t.Fatal("Failed to create a user and log in:", err) } // Delete the user. - at.Cookie = c - defer func() { at.Cookie = nil }() + at.SetCookie(c) + defer at.ClearCredentials() r, _, err := at.Delete("/user", nil) if err != nil || r.StatusCode != http.StatusNoContent { t.Fatalf("Expected %d success, got %d '%s'", http.StatusNoContent, r.StatusCode, err) @@ -391,14 +393,14 @@ func testUserDELETE(t *testing.T, at *test.AccountsTester) { t.Fatal(err) } // Try to delete the user without a cookie. - at.Cookie = nil + at.ClearCredentials() r, _, _ = at.Delete("/user", nil) if r.StatusCode != http.StatusUnauthorized { t.Fatalf("Expected %d, got %d", http.StatusUnauthorized, r.StatusCode) } // Delete the user. - at.Cookie = c - defer func() { at.Cookie = nil }() + at.SetCookie(c) + defer at.ClearCredentials() r, _, err = at.Delete("/user", nil) if err != nil || r.StatusCode != http.StatusNoContent { t.Fatalf("Expected %d success, got %d '%s'", http.StatusNoContent, r.StatusCode, err) @@ -435,17 +437,11 @@ func testUserLimits(t *testing.T, at *test.AccountsTester) { t.Error(errors.AddContext(err, "failed to delete user in defer")) } }() - - at.Cookie = c - defer func() { at.Cookie = nil }() + at.SetCookie(c) + defer at.ClearCredentials() // Call /user/limits with a cookie. Expect FreeTier response. - _, b, err := at.Get("/user/limits", nil) - if err != nil { - t.Fatal(err) - } - var tl api.UserLimitsGET - err = json.Unmarshal(b, &tl) + tl, _, err := at.UserLimits() if err != nil { t.Fatal(err) } @@ -455,14 +451,13 @@ func testUserLimits(t *testing.T, at *test.AccountsTester) { if tl.TierName != database.UserLimits[database.TierFree].TierName { t.Fatalf("Expected tier name '%s', got '%s'", database.UserLimits[database.TierFree].TierName, tl.TierName) } + if tl.DownloadBandwidth != database.UserLimits[database.TierFree].DownloadBandwidth { + t.Fatalf("Expected download bandwidth '%d', got '%d'", database.UserLimits[database.TierFree].DownloadBandwidth, tl.DownloadBandwidth) + } // Call /user/limits without a cookie. Expect FreeAnonymous response. - at.Cookie = nil - _, b, err = at.Get("/user/limits", nil) - if err != nil { - t.Fatal(err) - } - err = json.Unmarshal(b, &tl) + at.ClearCredentials() + tl, _, err = at.UserLimits() if err != nil { t.Fatal(err) } @@ -472,6 +467,63 @@ func testUserLimits(t *testing.T, at *test.AccountsTester) { if tl.TierName != database.UserLimits[database.TierAnonymous].TierName { t.Fatalf("Expected tier name '%s', got '%s'", database.UserLimits[database.TierAnonymous].TierName, tl.TierName) } + if tl.DownloadBandwidth != database.UserLimits[database.TierAnonymous].DownloadBandwidth { + t.Fatalf("Expected download bandwidth '%d', got '%d'", database.UserLimits[database.TierAnonymous].DownloadBandwidth, tl.DownloadBandwidth) + } + + // Create a new user which we'll use to test the quota limits. We can't use + // the existing one because their status is already cached. + u2, c, err := test.CreateUserAndLogin(at, t.Name()+"2") + if err != nil { + t.Fatal("Failed to create a user and log in:", err) + } + defer func() { + if err = u2.Delete(at.Ctx); err != nil { + t.Error(errors.AddContext(err, "failed to delete user in defer")) + } + }() + at.SetCookie(c) + defer at.ClearCredentials() + // Upload a very large file, which exceeds the user's storage limit. This + // should cause their QuotaExceed flag to go up and their speeds to drop to + // anonymous levels. Their tier should remain Free. + dbu2 := *u2.User + filesize := database.UserLimits[database.TierFree].Storage + 1 + sl, _, err := test.CreateTestUpload(at.Ctx, at.DB, &dbu2, filesize) + if err != nil { + t.Fatal(err) + } + // Make a specific call to trackUploadPOST in order to trigger the + // checkUserQuotas method. This wil register the upload a second time but + // that doesn't affect the test. + _, err = at.TrackUpload(sl.Skylink) + if err != nil { + t.Fatal(err) + } + // We need to try this several times because we'll only get the right result + // after the background goroutine that updates user's quotas has had time to + // run. + err = build.Retry(10, 200*time.Millisecond, func() error { + // Check the user's limits. We expect the tier to be Free but the limits to + // match Anonymous. + tl, _, err = at.UserLimits() + if err != nil { + return errors.AddContext(err, "failed to call /user/limits") + } + if tl.TierID != database.TierFree { + return fmt.Errorf("Expected to get the results for tier id %d, got %d", database.TierFree, tl.TierID) + } + if tl.TierName != database.UserLimits[database.TierFree].TierName { + return fmt.Errorf("Expected tier name '%s', got '%s'", database.UserLimits[database.TierFree].TierName, tl.TierName) + } + if tl.DownloadBandwidth != database.UserLimits[database.TierAnonymous].DownloadBandwidth { + return fmt.Errorf("Expected download bandwidth '%d', got '%d'", database.UserLimits[database.TierAnonymous].DownloadBandwidth, tl.DownloadBandwidth) + } + return nil + }) + if err != nil { + t.Fatal(err) + } } // testUserUploadsDELETE tests the DELETE /user/uploads/:skylink endpoint. @@ -486,8 +538,8 @@ func testUserUploadsDELETE(t *testing.T, at *test.AccountsTester) { } }() - at.Cookie = c - defer func() { at.Cookie = nil }() + at.SetCookie(c) + defer at.ClearCredentials() // Create an upload. skylink, _, err := test.CreateTestUpload(at.Ctx, at.DB, u.User, 128%skynet.KiB) @@ -506,12 +558,12 @@ func testUserUploadsDELETE(t *testing.T, at *test.AccountsTester) { t.Fatalf("Expected to have a single upload of %s, got %+v", skylink.Skylink, ups) } // Try to delete the upload without passing a JWT cookie. - at.Cookie = nil + at.ClearCredentials() _, b, err = at.Delete("/user/uploads/"+skylink.Skylink, nil) if err == nil || !strings.Contains(err.Error(), unauthorized) { t.Fatalf("Expected error %s, got %s. Body: %s", unauthorized, err, string(b)) } - at.Cookie = c + at.SetCookie(c) // Delete it. _, b, err = at.Delete("/user/uploads/"+skylink.Skylink, nil) if err != nil { @@ -546,7 +598,7 @@ func testUserConfirmReconfirmEmailGET(t *testing.T, at *test.AccountsTester) { } }() - defer func() { at.Cookie = nil }() + defer at.ClearCredentials() // Confirm the user params := url.Values{} @@ -565,14 +617,14 @@ func testUserConfirmReconfirmEmailGET(t *testing.T, at *test.AccountsTester) { } // Make sure `POST /user/reconfirm` requires a cookie. - at.Cookie = nil + at.ClearCredentials() _, b, err = at.Post("/user/reconfirm", nil, nil) if err == nil || !strings.Contains(err.Error(), unauthorized) { t.Fatalf("Expected '%s', got '%s'. Body: '%s'", unauthorized, err, string(b)) } // Reset the confirmation field, so we can continue testing with the same // user. - at.Cookie = c + at.SetCookie(c) _, b, err = at.Post("/user/reconfirm", nil, nil) if err != nil { t.Fatal(err, string(b)) @@ -624,7 +676,7 @@ func testUserAccountRecovery(t *testing.T, at *test.AccountsTester) { } }() - defer func() { at.Cookie = nil }() + defer at.ClearCredentials() // // TEST REQUESTING RECOVERY // // @@ -807,8 +859,8 @@ func testTrackingAndStats(t *testing.T, at *test.AccountsTester) { } }() - at.Cookie = c - defer func() { at.Cookie = nil }() + at.SetCookie(c) + defer at.ClearCredentials() // Generate a random skylink. skylink, err := skymodules.NewSkylinkV1(crypto.HashBytes(fastrand.Bytes(32)), 0, 32) @@ -818,21 +870,21 @@ func testTrackingAndStats(t *testing.T, at *test.AccountsTester) { expectedStats := database.UserStats{} // Call trackUpload without a cookie. - at.Cookie = nil - _, b, err := at.Post("/track/upload/"+skylink.String(), nil, nil) + at.ClearCredentials() + _, err = at.TrackUpload(skylink.String()) if err == nil || !strings.Contains(err.Error(), unauthorized) { - t.Fatalf("Expected error '%s', got '%s'. Body: '%s'", unauthorized, err, string(b)) + t.Fatalf("Expected error '%s', got '%v'", unauthorized, err) } - at.Cookie = c + at.SetCookie(c) // Call trackUpload with an invalid skylink. - _, b, err = at.Post("/track/upload/INVALID_SKYLINK", nil, nil) + _, err = at.TrackUpload("INVALID_SKYLINK") if err == nil || !strings.Contains(err.Error(), badRequest) { - t.Fatalf("Expected '%s', got '%s'. Body: '%s'", badRequest, err, string(b)) + t.Fatalf("Expected '%s', got '%v'", badRequest, err) } // Call trackUpload with a valid skylink. - _, b, err = at.Post("/track/upload/"+skylink.String(), nil, nil) + _, err = at.TrackUpload(skylink.String()) if err != nil { - t.Fatal(err, string(b)) + t.Fatal(err) } // Adjust the expectations. We won't adjust anything based on size because // the metafetcher won't be running during testing. @@ -841,32 +893,26 @@ func testTrackingAndStats(t *testing.T, at *test.AccountsTester) { expectedStats.RawStorageUsed += skynet.RawStorageUsed(0) // Call trackDownload without a cookie. - at.Cookie = nil - params := url.Values{} - params.Set("bytes", "100") - _, b, err = at.Post("/track/download/"+skylink.String(), params, nil) + at.ClearCredentials() + _, err = at.TrackDownload(skylink.String(), 100) if err == nil || !strings.Contains(err.Error(), unauthorized) { - t.Fatalf("Expected error '%s', got '%s'. Body: '%s", unauthorized, err, string(b)) + t.Fatalf("Expected error '%s', got '%v'", unauthorized, err) } - at.Cookie = c + at.SetCookie(c) // Call trackDownload with an invalid skylink. - _, b, err = at.Post("/track/download/INVALID_SKYLINK", params, nil) + _, err = at.TrackDownload("INVALID_SKYLINK", 100) if err == nil || !strings.Contains(err.Error(), badRequest) { - t.Fatalf("Expected '%s', got '%s'. Body: '%s'", badRequest, err, string(b)) + t.Fatalf("Expected '%s', got '%v'", badRequest, err) } // Call trackDownload with a valid skylink and a negative size download - params = url.Values{} - params.Set("bytes", "-100") - _, b, err = at.Post("/track/download/"+skylink.String(), params, nil) + _, err = at.TrackDownload(skylink.String(), -100) if err == nil || !strings.Contains(err.Error(), badRequest) { - t.Fatalf("Expected '%s', got '%s'. Body: '%s'", badRequest, err, string(b)) + t.Fatalf("Expected '%s', got '%v'", badRequest, err) } // Call trackDownload with a valid skylink. - params = url.Values{} - params.Set("bytes", "100") - _, b, err = at.Post("/track/download/"+skylink.String(), params, nil) + _, err = at.TrackDownload(skylink.String(), 100) if err != nil { - t.Fatal(err, string(b)) + t.Fatal(err) } // Adjust the expectations. expectedStats.NumDownloads++ @@ -874,44 +920,44 @@ func testTrackingAndStats(t *testing.T, at *test.AccountsTester) { expectedStats.TotalDownloadsSize += 100 // Call trackRegistryRead without a cookie. - at.Cookie = nil - _, b, err = at.Post("/track/registry/read", nil, nil) + at.ClearCredentials() + _, err = at.TrackRegistryRead() if err == nil || !strings.Contains(err.Error(), unauthorized) { - t.Fatalf("Expected error '%s', got '%s'. Body: '%s'", unauthorized, err, string(b)) + t.Fatalf("Expected error '%s', got '%v'", unauthorized, err) } - at.Cookie = c + at.SetCookie(c) // Call trackRegistryRead. - _, b, err = at.Post("/track/registry/read", nil, nil) + _, err = at.TrackRegistryRead() if err != nil { - t.Fatal(err, string(b)) + t.Fatal(err) } // Adjust the expectations. expectedStats.NumRegReads++ expectedStats.BandwidthRegReads += skynet.CostBandwidthRegistryRead // Call trackRegistryWrite without a cookie. - at.Cookie = nil - _, b, err = at.Post("/track/registry/write", nil, nil) + at.ClearCredentials() + _, err = at.TrackRegistryWrite() if err == nil || !strings.Contains(err.Error(), unauthorized) { - t.Fatalf("Expected error '%s', got '%s'. Body: '%s'", unauthorized, err, string(b)) + t.Fatalf("Expected error '%s', got '%v'", unauthorized, err) } - at.Cookie = c + at.SetCookie(c) // Call trackRegistryWrite. - _, b, err = at.Post("/track/registry/write", nil, nil) + _, err = at.TrackRegistryWrite() if err != nil { - t.Fatal(err, string(b)) + t.Fatal(err) } // Adjust the expectations. expectedStats.NumRegWrites++ expectedStats.BandwidthRegWrites += skynet.CostBandwidthRegistryWrite // Call userStats without a cookie. - at.Cookie = nil - _, b, err = at.Get("/user/stats", nil) + at.ClearCredentials() + _, b, err := at.Get("/user/stats", nil) if err == nil || !strings.Contains(err.Error(), unauthorized) { - t.Fatalf("Expected error '%s', got '%s'. Body: '%s'", unauthorized, err, string(b)) + t.Fatalf("Expected error '%s', got '%v'", unauthorized, err) } - at.Cookie = c + at.SetCookie(c) // Call userStats. _, b, err = at.Get("/user/stats", nil) if err != nil { @@ -958,8 +1004,8 @@ func testUserFlow(t *testing.T, at *test.AccountsTester) { t.Fatal("Login failed. Error ", err.Error()) } // Grab the Skynet cookie, so we can make authenticated calls. - at.Cookie = test.ExtractCookie(r) - defer func() { at.Cookie = nil }() + at.SetCookie(test.ExtractCookie(r)) + defer at.ClearCredentials() if at.Cookie == nil { t.Fatalf("Failed to extract cookie from request. Cookies found: %+v", r.Cookies()) } @@ -969,15 +1015,13 @@ func testUserFlow(t *testing.T, at *test.AccountsTester) { t.Fatal("Missing or invalid token. Error:", err) } // Make sure we can make calls with this token. - at.Token = tk c := at.Cookie - at.Cookie = nil + at.SetToken(tk) _, _, err = at.Get("/user", nil) if err != nil { t.Fatal("Failed to fetch user data with token:", err.Error()) } - at.Token = "" - at.Cookie = c + at.SetCookie(c) // Change the user's email. newEmail := name + "_new@siasky.net" r, b, err := at.UserPUT(newEmail, "", "") @@ -985,7 +1029,7 @@ func testUserFlow(t *testing.T, at *test.AccountsTester) { t.Fatalf("Failed to update user. Error: %s. Body: %s", err.Error(), string(b)) } // Grab the new cookie. It has changed because of the user edit. - at.Cookie = test.ExtractCookie(r) + at.SetCookie(test.ExtractCookie(r)) if at.Cookie == nil { t.Fatalf("Failed to extract cookie from request. Cookies found: %+v", r.Cookies()) } @@ -1007,7 +1051,7 @@ func testUserFlow(t *testing.T, at *test.AccountsTester) { t.Fatal("Failed to logout:", err.Error()) } // Grab the new cookie. - at.Cookie = test.ExtractCookie(r) + at.SetCookie(test.ExtractCookie(r)) // Try to get the user, expect a 401. _, b, err = at.Get("/user", nil) if err == nil || !strings.Contains(err.Error(), unauthorized) { diff --git a/test/email/sender_test.go b/test/email/sender_test.go index c3eda237..11a33b97 100644 --- a/test/email/sender_test.go +++ b/test/email/sender_test.go @@ -2,6 +2,7 @@ package email import ( "context" + "fmt" "strconv" "sync" "sync/atomic" @@ -14,6 +15,7 @@ import ( "github.com/sirupsen/logrus" "go.mongodb.org/mongo-driver/bson" "go.mongodb.org/mongo-driver/mongo/options" + "go.sia.tech/siad/build" ) // TestSender goes through the standard Sender workflow and ensures that it @@ -62,19 +64,24 @@ func TestSender(t *testing.T) { } // Start the sender and wait for a second. sender.Start() - time.Sleep(2 * time.Second) - // Check that the email has been sent. - _, emails, err = db.FindEmails(ctx, filterTo, &options.FindOptions{}) + err = build.Retry(10, 200*time.Millisecond, func() error { + // Check that the email has been sent. + _, emails, err = db.FindEmails(ctx, filterTo, &options.FindOptions{}) + if err != nil { + return err + } + if len(emails) != 1 { + return fmt.Errorf("expected 1 email in the DB, got %d", len(emails)) + } + if emails[0].SentAt.IsZero() { + emails[0].Body = "<<>>" + return fmt.Errorf("email not sent. Email: %+v", emails[0]) + } + return nil + }) if err != nil { t.Fatal(err) } - if len(emails) != 1 { - t.Fatalf("Expected 1 email in the DB, got %d\n", len(emails)) - } - if emails[0].SentAt.IsZero() { - emails[0].Body = "<<>>" - t.Fatalf("Email not sent. Email: %+v\n", emails[0]) - } } // TestContendingSenders ensures that each email generated by a cluster of diff --git a/test/tester.go b/test/tester.go index 7d69bfdc..64dcb035 100644 --- a/test/tester.go +++ b/test/tester.go @@ -124,6 +124,39 @@ func NewAccountsTester(dbName string) (*AccountsTester, error) { return at, nil } +// ClearCredentials removes any credentials stored by this tester, such as a +// cookie, token, etc. +func (at *AccountsTester) ClearCredentials() { + at.Cookie = nil + at.Token = "" +} + +// Close performs a graceful shutdown of the AccountsTester service. +func (at *AccountsTester) Close() error { + at.cancel() + if at.DB != nil { + err := at.DB.Disconnect(at.Ctx) + if err != nil { + return err + } + } + return nil +} + +// SetCookie ensures that all subsequent requests are going to use the given +// cookie for authentication. +func (at *AccountsTester) SetCookie(c *http.Cookie) { + at.ClearCredentials() + at.Cookie = c +} + +// SetToken ensures that all subsequent requests are going to use the given +// token for authentication. +func (at *AccountsTester) SetToken(t string) { + at.ClearCredentials() + at.Token = t +} + // Get executes a GET request against the test service. // // NOTE: The Body of the returned response is already read and closed. @@ -173,12 +206,6 @@ func (at *AccountsTester) Put(endpoint string, params url.Values, putParams url. return at.request(http.MethodPut, endpoint, params, putParams) } -// Close performs a graceful shutdown of the AccountsTester service. -func (at *AccountsTester) Close() error { - at.cancel() - return nil -} - // CreateUserPost is a helper method that creates a new user. // // NOTE: The Body of the returned response is already read and closed. @@ -257,3 +284,43 @@ func (at *AccountsTester) executeRequest(req *http.Request) (*http.Response, []b } return r, body, err } + +// TrackDownload performs a `POST /track/download/:skylink` request. +func (at *AccountsTester) TrackDownload(skylink string, bytes int64) (int, error) { + form := url.Values{} + form.Set("bytes", fmt.Sprint(bytes)) + r, _, err := at.request(http.MethodPost, "/track/download/"+skylink, form, nil) + return r.StatusCode, err +} + +// TrackUpload performs a `POST /track/upload/:skylink` request. +func (at *AccountsTester) TrackUpload(skylink string) (int, error) { + r, _, err := at.request(http.MethodPost, "/track/upload/"+skylink, nil, nil) + return r.StatusCode, err +} + +// TrackRegistryRead performs a `POST /track/registry/read` request. +func (at *AccountsTester) TrackRegistryRead() (int, error) { + r, _, err := at.request(http.MethodPost, "/track/registry/read", nil, nil) + return r.StatusCode, err +} + +// TrackRegistryWrite performs a `POST /track/registry/write` request. +func (at *AccountsTester) TrackRegistryWrite() (int, error) { + r, _, err := at.request(http.MethodPost, "/track/registry/write", nil, nil) + return r.StatusCode, err +} + +// UserLimits performs a `GET /user/limits` request. +func (at *AccountsTester) UserLimits() (api.UserLimitsGET, int, error) { + r, b, err := at.request(http.MethodGet, "/user/limits", nil, nil) + if err != nil { + return api.UserLimitsGET{}, r.StatusCode, err + } + var resp api.UserLimitsGET + err = json.Unmarshal(b, &resp) + if err != nil { + return api.UserLimitsGET{}, 0, errors.AddContext(err, "failed to marshal the body JSON") + } + return resp, r.StatusCode, nil +}