Skip to content

Commit

Permalink
Merge pull request #146 from SkynetLabs/ivo/fix_limits
Browse files Browse the repository at this point in the history
Fix /user/limits quotaExceeded
  • Loading branch information
Christopher Schinnerl authored and kwypchlo committed Mar 16, 2022
1 parent e385fc9 commit 6dff5dc
Show file tree
Hide file tree
Showing 10 changed files with 353 additions and 194 deletions.
34 changes: 13 additions & 21 deletions api/cache.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,9 @@ type (
// userTierCacheEntry allows us to cache some basic information about the
// user, so we don't need to hit the DB to fetch data that rarely changes.
userTierCacheEntry struct {
Tier int
ExpiresAt time.Time
Tier int
QuotaExceeded bool
ExpiresAt time.Time
}
)

Expand All @@ -34,34 +35,25 @@ func newUserTierCache() *userTierCache {
}
}

// Get returns the user's tier and an OK indicator which is true when the cache
// entry exists and hasn't expired, yet.
func (utc *userTierCache) Get(sub string) (int, bool) {
// Get returns the user's tier, a quota exceeded flag, and an OK indicator
// which is true when the cache entry exists and hasn't expired, yet.
func (utc *userTierCache) Get(sub string) (int, bool, bool) {
utc.mu.Lock()
ce, exists := utc.cache[sub]
utc.mu.Unlock()
if !exists || ce.ExpiresAt.Before(time.Now().UTC()) {
return database.TierAnonymous, false
return database.TierAnonymous, false, false
}
return ce.Tier, true
return ce.Tier, ce.QuotaExceeded, true
}

// Set stores the user's tier in the cache.
func (utc *userTierCache) Set(u *database.User) {
var ce userTierCacheEntry
now := time.Now().UTC()
if u.QuotaExceeded {
ce = userTierCacheEntry{
Tier: database.TierAnonymous,
ExpiresAt: now.Add(userTierCacheTTL),
}
} else {
ce = userTierCacheEntry{
Tier: u.Tier,
ExpiresAt: now.Add(userTierCacheTTL),
}
}
utc.mu.Lock()
utc.cache[u.Sub] = ce
utc.cache[u.Sub] = userTierCacheEntry{
Tier: u.Tier,
QuotaExceeded: u.QuotaExceeded,
ExpiresAt: time.Now().UTC().Add(userTierCacheTTL),
}
utc.mu.Unlock()
}
18 changes: 15 additions & 3 deletions api/cache_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,17 +17,29 @@ func TestUserTierCache(t *testing.T) {
QuotaExceeded: false,
}
// Get the user from the empty cache.
tier, ok := cache.Get(u.Sub)
tier, _, ok := cache.Get(u.Sub)
if ok || tier != database.TierAnonymous {
t.Fatalf("Expected to get tier %d and %t, got %d and %t.", database.TierAnonymous, false, tier, ok)
}
// Set the use in the cache.
// Set the user in the cache.
cache.Set(u)
// Check again.
tier, ok = cache.Get(u.Sub)
tier, qe, ok := cache.Get(u.Sub)
if !ok || tier != u.Tier {
t.Fatalf("Expected to get tier %d and %t, got %d and %t.", u.Tier, true, tier, ok)
}
if qe != u.QuotaExceeded {
t.Fatal("Quota exceeded flag doesn't match.")
}
u.QuotaExceeded = true
cache.Set(u)
tier, qe, ok = cache.Get(u.Sub)
if !ok || tier != u.Tier {
t.Fatalf("Expected to get tier %d and %t, got %d and %t.", u.Tier, true, tier, ok)
}
if qe != u.QuotaExceeded {
t.Fatal("Quota exceeded flag doesn't match.")
}
ce, exists := cache.cache[u.Sub]
if !exists {
t.Fatal("Expected the entry to exist.")
Expand Down
67 changes: 50 additions & 17 deletions api/handlers.go
Original file line number Diff line number Diff line change
Expand Up @@ -80,8 +80,14 @@ type (
}
// UserLimitsGET is response of GET /user/limits
UserLimitsGET struct {
TierID int `json:"tierID"`
database.TierLimits
TierID int `json:"tierID"`
TierName string `json:"tierName"`
UploadBandwidth int `json:"upload"` // bytes per second
DownloadBandwidth int `json:"download"` // bytes per second
MaxUploadSize int64 `json:"maxUploadSize"` // the max size of a single upload in bytes
MaxNumberUploads int `json:"-"`
RegistryDelay int `json:"registry"` // ms delay
Storage int64 `json:"-"`
}

// accountRecoveryPOST defines the payload we expect when a user is trying
Expand Down Expand Up @@ -398,20 +404,23 @@ func (api *API) userGET(u *database.User, w http.ResponseWriter, _ *http.Request
func (api *API) userLimitsGET(_ *database.User, w http.ResponseWriter, req *http.Request, _ httprouter.Params) {
// First check for an API key.
ak, err := apiKeyFromRequest(req)
respAnon := UserLimitsGET{
TierID: database.TierAnonymous,
TierLimits: database.UserLimits[database.TierAnonymous],
}
respAnon := userLimitsGetFromTier(database.TierAnonymous)
if err == nil {
u, err := api.staticDB.UserByAPIKey(req.Context(), ak)
if err != nil {
api.staticLogger.Traceln("Error while fetching user by API key:", err)
api.WriteJSON(w, respAnon)
return
}
resp := UserLimitsGET{
TierID: u.Tier,
TierLimits: database.UserLimits[u.Tier],
resp := userLimitsGetFromTier(u.Tier)
// If the quota is exceeded we should keep the user's tier but report
// anonymous-level speeds.
if u.QuotaExceeded {
// Report the speeds for tier anonymous.
resp = userLimitsGetFromTier(database.TierAnonymous)
// But keep reporting the user's actual tier and it's name.
resp.TierID = u.Tier
resp.TierName = database.UserLimits[u.Tier].TierName
}
api.WriteJSON(w, resp)
return
Expand All @@ -430,7 +439,7 @@ func (api *API) userLimitsGET(_ *database.User, w http.ResponseWriter, req *http
sub := s.(string)
// If the user is not cached, or they were cached too long ago we'll fetch
// their data from the DB.
tier, ok := api.staticUserTierCache.Get(sub)
tier, qe, ok := api.staticUserTierCache.Get(sub)
if !ok {
u, err := api.staticDB.UserBySub(req.Context(), sub)
if err != nil {
Expand All @@ -439,14 +448,22 @@ func (api *API) userLimitsGET(_ *database.User, w http.ResponseWriter, req *http
return
}
api.staticUserTierCache.Set(u)
// Populate the tier and qe values, while simultaneously making sure
// that we can read the record from the cache.
tier, qe, ok = api.staticUserTierCache.Get(sub)
if !ok {
build.Critical("Failed to fetch user from UserTierCache right after setting it.")
}
}
tier, ok = api.staticUserTierCache.Get(sub)
if !ok {
build.Critical("Failed to fetch user from UserTierCache right after setting it.")
}
resp := UserLimitsGET{
TierID: tier,
TierLimits: database.UserLimits[tier],
resp := userLimitsGetFromTier(tier)
// If the quota is exceeded we should keep the user's tier but report
// anonymous-level speeds.
if qe {
// Report anonymous speeds.
resp = userLimitsGetFromTier(database.TierAnonymous)
// Keep reporting the user's actual tier and tier name.
resp.TierID = tier
resp.TierName = database.UserLimits[tier].TierName
}
api.WriteJSON(w, resp)
}
Expand Down Expand Up @@ -1177,3 +1194,19 @@ func fetchPageSize(form url.Values) (int, error) {
func parseRequestBodyJSON(body io.ReadCloser, maxBodySize int64, objRef interface{}) error {
return json.NewDecoder(io.LimitReader(body, maxBodySize)).Decode(&objRef)
}

// userLimitsGetFromTier is a helper that lets us succinctly translate
// from the database DTO to the API DTO.
func userLimitsGetFromTier(tier int) *UserLimitsGET {
t := database.UserLimits[tier]
return &UserLimitsGET{
TierID: tier,
TierName: t.TierName,
UploadBandwidth: t.UploadBandwidth,
DownloadBandwidth: t.DownloadBandwidth,
MaxUploadSize: t.MaxUploadSize,
MaxNumberUploads: t.MaxNumberUploads,
RegistryDelay: t.RegistryDelay,
Storage: t.Storage,
}
}
2 changes: 1 addition & 1 deletion database/challenge.go
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ type (
// NewChallenge creates a new challenge with the given type and pubKey.
func (db *DB) NewChallenge(ctx context.Context, pubKey PubKey, cType string) (*Challenge, error) {
if cType != ChallengeTypeLogin && cType != ChallengeTypeRegister && cType != ChallengeTypeUpdate {
return nil, errors.New(fmt.Sprintf("invalid challenge type '%s'", cType))
return nil, fmt.Errorf("invalid challenge type '%s'", cType)
}
ch := &Challenge{
Challenge: hex.EncodeToString(fastrand.Bytes(ChallengeSize)),
Expand Down
88 changes: 46 additions & 42 deletions test/api/api_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ import (
"bytes"
"context"
"encoding/hex"
"encoding/json"
"fmt"
"net/http"
"net/url"
"testing"
Expand All @@ -14,6 +14,7 @@ import (
"github.com/SkynetLabs/skynet-accounts/database"
"github.com/SkynetLabs/skynet-accounts/test"
"gitlab.com/NebulousLabs/fastrand"
"go.sia.tech/siad/build"

"github.com/julienschmidt/httprouter"
"github.com/sirupsen/logrus"
Expand Down Expand Up @@ -198,15 +199,9 @@ func TestUserTierCache(t *testing.T) {
if err != nil {
t.Fatal(err)
}
at.Cookie = test.ExtractCookie(r)
// Get the user's limit. Since they are on a Pro account but their
// SubscribedUntil is set in the past, we expect to get TierFree.
_, b, err := at.Get("/user/limits", nil)
if err != nil {
t.Fatal(err)
}
var ul api.UserLimitsGET
err = json.Unmarshal(b, &ul)
at.SetCookie(test.ExtractCookie(r))
// Get the user's limit.
ul, _, err := at.UserLimits()
if err != nil {
t.Fatal(err)
}
Expand All @@ -216,12 +211,11 @@ func TestUserTierCache(t *testing.T) {
if ul.TierID != database.TierPremium20 {
t.Fatalf("Expected tier id '%d', got '%d'", database.TierPremium20, ul.TierID)
}
// Now set their SubscribedUntil in the future, so their subscription tier
// is active.
u.SubscribedUntil = time.Now().UTC().Add(365 * 24 * time.Hour)
err = at.DB.UserSave(at.Ctx, u.User)
if err != nil {
t.Fatal(err)
if ul.TierName != database.UserLimits[database.TierPremium20].TierName {
t.Fatalf("Expected tier name '%s', got '%s'", database.UserLimits[database.TierPremium20].TierName, ul.TierName)
}
if ul.UploadBandwidth != database.UserLimits[database.TierPremium20].UploadBandwidth {
t.Fatalf("Expected upload bandwidth '%d', got '%d'", database.UserLimits[database.TierPremium20].UploadBandwidth, ul.UploadBandwidth)
}
// Register a test upload that exceeds the user's allowed storage, so their
// QuotaExceeded flag will get raised.
Expand All @@ -232,45 +226,55 @@ func TestUserTierCache(t *testing.T) {
// Make a specific call to trackUploadPOST in order to trigger the
// checkUserQuotas method. This wil register the upload a second time but
// that doesn't affect the test.
_, _, err = at.Post("/track/upload/"+sl.Skylink, nil, nil)
if err != nil {
t.Fatal(err)
}
// Sleep for a short time in order to make sure that the background
// goroutine that updates user's quotas has had time to run.
time.Sleep(2 * time.Second)
// We expect to get TierAnonymous.
_, b, err = at.Get("/user/limits", nil)
_, err = at.TrackUpload(sl.Skylink)
if err != nil {
t.Fatal(err)
}
err = json.Unmarshal(b, &ul)
// We need to try this several times because we'll only get the right result
// after the background goroutine that updates user's quotas has had time to
// run.
err = build.Retry(10, 200*time.Millisecond, func() error {
// We expect to get tier with name and id matching TierPremium20 but with
// speeds matching TierAnonymous.
ul, _, err = at.UserLimits()
if err != nil {
t.Fatal(err)
}
if ul.TierID != database.TierPremium20 {
return fmt.Errorf("Expected tier id '%d', got '%d'", database.TierPremium20, ul.TierID)
}
if ul.TierName != database.UserLimits[database.TierPremium20].TierName {
return fmt.Errorf("Expected tier name '%s', got '%s'", database.UserLimits[database.TierPremium20].TierName, ul.TierName)
}
if ul.UploadBandwidth != database.UserLimits[database.TierAnonymous].UploadBandwidth {
return fmt.Errorf("Expected upload bandwidth '%d', got '%d'", database.UserLimits[database.TierAnonymous].UploadBandwidth, ul.UploadBandwidth)
}
return nil
})
if err != nil {
t.Fatal(err)
}
if ul.TierID != database.TierAnonymous {
t.Fatalf("Expected tier id '%d', got '%d'", database.TierAnonymous, ul.TierID)
}
if ul.TierName != database.UserLimits[database.TierAnonymous].TierName {
t.Fatalf("Expected tier name '%s', got '%s'", database.UserLimits[database.TierAnonymous].TierName, ul.TierName)
}
// Delete the uploaded file, so the user's quota recovers.
// This call should invalidate the tier cache.
_, _, err = at.Delete("/user/uploads/"+sl.Skylink, nil)
time.Sleep(2 * time.Second)
// We expect to get TierPremium20.
_, b, err = at.Get("/user/limits", nil)
if err != nil {
t.Fatal(err)
}
err = json.Unmarshal(b, &ul)
err = build.Retry(10, 200*time.Millisecond, func() error {
// We expect to get TierPremium20.
ul, _, err = at.UserLimits()
if err != nil {
return errors.AddContext(err, "failed to call /user/limits")
}
if ul.TierID != database.TierPremium20 {
return fmt.Errorf("Expected tier id '%d', got '%d'", database.TierPremium20, ul.TierID)
}
if ul.TierName != database.UserLimits[database.TierPremium20].TierName {
return fmt.Errorf("Expected tier name '%s', got '%s'", database.UserLimits[database.TierPremium20].TierName, ul.TierName)
}
return nil
})
if err != nil {
t.Fatal(err)
}
if ul.TierID != database.TierPremium20 {
t.Fatalf("Expected tier id '%d', got '%d'", database.TierPremium20, ul.TierID)
}
if ul.TierName != database.UserLimits[database.TierPremium20].TierName {
t.Fatalf("Expected tier name '%s', got '%s'", database.UserLimits[database.TierPremium20].TierName, ul.TierName)
}
}
6 changes: 3 additions & 3 deletions test/api/apikeys_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ func testAPIKeysFlow(t *testing.T, at *test.AccountsTester) {
if err != nil {
t.Fatal(err, string(body))
}
at.Cookie = test.ExtractCookie(r)
at.SetCookie(test.ExtractCookie(r))

aks := make([]database.APIKeyRecord, 0)

Expand Down Expand Up @@ -115,7 +115,7 @@ func testAPIKeysUsage(t *testing.T, at *test.AccountsTester) {
if err != nil {
t.Fatal(err)
}
at.Cookie = test.ExtractCookie(r)
at.SetCookie(test.ExtractCookie(r))
// Get the user and create a test upload, so the stats won't be all zeros.
u, err := at.DB.UserByEmail(at.Ctx, email)
if err != nil {
Expand All @@ -132,7 +132,7 @@ func testAPIKeysUsage(t *testing.T, at *test.AccountsTester) {
t.Fatal(err)
}
// Stop using the cookie, so we can test the API key.
at.Cookie = nil
at.ClearCredentials()
// We use a custom struct and not the APIKeyRecord one because that one does
// not render the key in JSON form and therefore it won't unmarshal it,
// either.
Expand Down
Loading

0 comments on commit 6dff5dc

Please sign in to comment.