Skip to content

Commit

Permalink
Merge branch 'main' into ivo/pub_api_keys
Browse files Browse the repository at this point in the history
# Conflicts:
#	api/handlers.go
#	test/api/api_test.go
#	test/api/handlers_test.go
#	test/tester.go
  • Loading branch information
ro-tex committed Mar 18, 2022
2 parents f4bc82e + 9533fae commit 614c1ed
Show file tree
Hide file tree
Showing 6 changed files with 85 additions and 29 deletions.
44 changes: 31 additions & 13 deletions api/handlers.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
"net/mail"
"net/url"
"strconv"
"strings"
"time"

"github.com/SkynetLabs/skynet-accounts/build"
Expand Down Expand Up @@ -82,11 +83,13 @@ type (
EmailConfirmed bool `json:"emailConfirmed"`
}
// UserLimitsGET is response of GET /user/limits
// The returned speeds might be in bits or bytes per second, depending on
// the client's request.
UserLimitsGET struct {
TierID int `json:"tierID"`
TierName string `json:"tierName"`
UploadBandwidth int `json:"upload"` // bytes per second
DownloadBandwidth int `json:"download"` // bytes per second
UploadBandwidth int `json:"upload"` // bits or bytes per second
DownloadBandwidth int `json:"download"` // bits or bytes per second
MaxUploadSize int64 `json:"maxUploadSize"` // the max size of a single upload in bytes
MaxNumberUploads int `json:"-"`
RegistryDelay int `json:"registry"` // ms delay
Expand Down Expand Up @@ -405,15 +408,19 @@ func (api *API) userGET(u *database.User, w http.ResponseWriter, _ *http.Request
// NOTE: This handler needs to use the noAuth middleware in order to be able to
// optimise its calls to the DB and the use of caching.
func (api *API) userLimitsGET(_ *database.User, w http.ResponseWriter, req *http.Request, _ httprouter.Params) {
respAnon := userLimitsGetFromTier(database.TierAnonymous, false)
// inBytes is a flag indicating that the caller wants all bandwidth limits
// to be presented in bytes per second. The default behaviour is to present
// them in bits per second.
inBytes := strings.EqualFold(req.FormValue("unit"), "byte")
respAnon := userLimitsGetFromTier(database.TierAnonymous, false, inBytes)
// First check for an API key.
ak, err := apiKeyFromRequest(req)
if err == nil {
// Check the cache before going any further.
tier, qe, ok := api.staticUserTierCache.Get(ak.String())
if ok {
api.staticLogger.Traceln("Fetching user limits from cache by API key.")
api.WriteJSON(w, userLimitsGetFromTier(tier, qe))
api.WriteJSON(w, userLimitsGetFromTier(tier, qe, inBytes))
return
}
// Get the API key.
Expand All @@ -437,7 +444,7 @@ func (api *API) userLimitsGET(_ *database.User, w http.ResponseWriter, req *http
}
// Cache the user under the API key they used.
api.staticUserTierCache.Set(ak.String(), u)
api.WriteJSON(w, userLimitsGetFromTier(u.Tier, u.QuotaExceeded))
api.WriteJSON(w, userLimitsGetFromTier(u.Tier, u.QuotaExceeded, inBytes))
return
}
// Next check for a token.
Expand Down Expand Up @@ -471,7 +478,7 @@ func (api *API) userLimitsGET(_ *database.User, w http.ResponseWriter, req *http
build.Critical("Failed to fetch user from UserTierCache right after setting it.")
}
}
api.WriteJSON(w, userLimitsGetFromTier(tier, qe))
api.WriteJSON(w, userLimitsGetFromTier(tier, qe, inBytes))
}

// userLimitsSkylinkGET returns the speed limits which apply to a GET call to
Expand All @@ -480,7 +487,11 @@ func (api *API) userLimitsGET(_ *database.User, w http.ResponseWriter, req *http
// NOTE: This handler needs to use the noAuth middleware in order to be able to
// optimise its calls to the DB and the use of caching.
func (api *API) userLimitsSkylinkGET(u *database.User, w http.ResponseWriter, req *http.Request, ps httprouter.Params) {
respAnon := userLimitsGetFromTier(database.TierAnonymous, false)
// inBytes is a flag indicating that the caller wants all bandwidth limits
// to be presented in bytes per second. The default behaviour is to present
// them in bits per second.
inBytes := strings.EqualFold(req.FormValue("unit"), "byte")
respAnon := userLimitsGetFromTier(database.TierAnonymous, false, inBytes)
// Validate the skylink.
skylink := ps.ByName("skylink")
if !database.ValidSkylinkHash(skylink) {
Expand All @@ -505,7 +516,7 @@ func (api *API) userLimitsSkylinkGET(u *database.User, w http.ResponseWriter, re
tier, qe, ok := api.staticUserTierCache.Get(ak.String() + skylink)
if ok {
api.staticLogger.Traceln("Fetching user limits from cache by API key.")
api.WriteJSON(w, userLimitsGetFromTier(tier, qe))
api.WriteJSON(w, userLimitsGetFromTier(tier, qe, inBytes))
return
}
// Get the API key.
Expand All @@ -529,7 +540,7 @@ func (api *API) userLimitsSkylinkGET(u *database.User, w http.ResponseWriter, re
}
// Store the user in the cache with a custom key.
api.staticUserTierCache.Set(ak.String()+skylink, user)
api.WriteJSON(w, userLimitsGetFromTier(user.Tier, user.QuotaExceeded))
api.WriteJSON(w, userLimitsGetFromTier(user.Tier, user.QuotaExceeded, inBytes))
}

// userStatsGET returns statistics about an existing user.
Expand Down Expand Up @@ -1260,8 +1271,9 @@ func parseRequestBodyJSON(body io.ReadCloser, maxBodySize int64, v interface{})
}

// userLimitsGetFromTier is a helper that lets us succinctly translate
// from the database DTO to the API DTO.
func userLimitsGetFromTier(tierID int, quotaExceeded bool) *UserLimitsGET {
// from the database DTO to the API DTO. The `inBytes` parameter determines
// whether the returned speeds will be in Bps or bps.
func userLimitsGetFromTier(tierID int, quotaExceeded, inBytes bool) *UserLimitsGET {
t, ok := database.UserLimits[tierID]
if !ok {
build.Critical("userLimitsGetFromTier was called with non-existent tierID: " + strconv.Itoa(tierID))
Expand All @@ -1271,14 +1283,20 @@ func userLimitsGetFromTier(tierID int, quotaExceeded bool) *UserLimitsGET {
if quotaExceeded {
limitsTier = database.UserLimits[database.TierAnonymous]
}
// If we need to return the result in bits per second, we multiply by 8,
// otherwise, we multiply by 1.
bpsMul := 8
if inBytes {
bpsMul = 1
}
return &UserLimitsGET{
TierID: tierID,
TierName: t.TierName,
Storage: t.Storage,
// If the user exceeds their quota, there will be brought down to
// anonymous levels.
UploadBandwidth: limitsTier.UploadBandwidth,
DownloadBandwidth: limitsTier.DownloadBandwidth,
UploadBandwidth: limitsTier.UploadBandwidth * bpsMul,
DownloadBandwidth: limitsTier.DownloadBandwidth * bpsMul,
MaxUploadSize: limitsTier.MaxUploadSize,
MaxNumberUploads: limitsTier.MaxNumberUploads,
RegistryDelay: limitsTier.RegistryDelay,
Expand Down
4 changes: 2 additions & 2 deletions api/handlers_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ func TestUserLimitsGetFromTier(t *testing.T) {
}

for _, tt := range tests {
ul := userLimitsGetFromTier(tt.tier, tt.quotaExceeded)
ul := userLimitsGetFromTier(tt.tier, tt.quotaExceeded, true)
if ul.TierID != tt.expectedTier {
t.Errorf("Test '%s': expected tier %d, got %d", tt.name, tt.expectedTier, ul.TierID)
}
Expand Down Expand Up @@ -121,7 +121,7 @@ func TestUserLimitsGetFromTier(t *testing.T) {
}
}()
// The call that we expect to log a critical.
_ = userLimitsGetFromTier(math.MaxInt, false)
_ = userLimitsGetFromTier(math.MaxInt, false, true)
return
}()
if err != nil {
Expand Down
6 changes: 3 additions & 3 deletions test/api/api_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -175,7 +175,7 @@ func TestUserTierCache(t *testing.T) {
}
at.SetCookie(test.ExtractCookie(r))
// Get the user's limit.
ul, _, err := at.UserLimits(nil, nil)
ul, _, err := at.UserLimits("byte", nil)
if err != nil {
t.Fatal(err)
}
Expand Down Expand Up @@ -210,7 +210,7 @@ func TestUserTierCache(t *testing.T) {
err = build.Retry(10, 200*time.Millisecond, func() error {
// We expect to get tier with name and id matching TierPremium20 but with
// speeds matching TierAnonymous.
ul, _, err = at.UserLimits(nil, nil)
ul, _, err = at.UserLimits("byte", nil)
if err != nil {
t.Fatal(err)
}
Expand All @@ -236,7 +236,7 @@ func TestUserTierCache(t *testing.T) {
}
err = build.Retry(10, 200*time.Millisecond, func() error {
// We expect to get TierPremium20.
ul, _, err = at.UserLimits(nil, nil)
ul, _, err = at.UserLimits("byte", nil)
if err != nil {
return errors.AddContext(err, "failed to call /user/limits")
}
Expand Down
4 changes: 2 additions & 2 deletions test/api/apikeys_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -269,7 +269,7 @@ func testPublicAPIKeysUsage(t *testing.T, at *test.AccountsTester) {
}
// Get the user's limits for downloading a skylink covered by the public
// API key. Expect to get TierFree values.
ul, _, err := at.UserLimitsSkylink(sl.Skylink, nil, nil)
ul, _, err := at.UserLimitsSkylink(sl.Skylink, "byte", nil)
if err != nil {
t.Fatal(err)
}
Expand All @@ -278,7 +278,7 @@ func testPublicAPIKeysUsage(t *testing.T, at *test.AccountsTester) {
}
// Get the user's limits for downloading a skylink that is not covered by
// the public API key. Expect to get TierAnonymous values.
ul, _, err = at.UserLimitsSkylink(sl2.Skylink, nil, nil)
ul, _, err = at.UserLimitsSkylink(sl2.Skylink, "byte", nil)
if err != nil {
t.Fatal(err)
}
Expand Down
40 changes: 35 additions & 5 deletions test/api/handlers_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -441,8 +441,8 @@ func testUserLimits(t *testing.T, at *test.AccountsTester) {
t.Fatal(err)
}

// Call /user/limits with a cookie. Expect TierFree response.
tl, _, err := at.UserLimits(nil, nil)
// Call /user/limits with a cookie. Expect FreeTier response.
tl, _, err := at.UserLimits("byte", nil)
if err != nil {
t.Fatal(err)
}
Expand All @@ -458,7 +458,7 @@ func testUserLimits(t *testing.T, at *test.AccountsTester) {

// Call /user/limits without a cookie. Expect FreeAnonymous response.
at.ClearCredentials()
tl, _, err = at.UserLimits(nil, nil)
tl, _, err = at.UserLimits("byte", nil)
if err != nil {
t.Fatal(err)
}
Expand All @@ -473,7 +473,7 @@ func testUserLimits(t *testing.T, at *test.AccountsTester) {
}

// Call /user/limits with an API key. Expect TierFree response.
tl, _, err = at.UserLimits(nil, map[string]string{api.APIKeyHeader: string(akr.Key)})
tl, _, err = at.UserLimits("byte", map[string]string{api.APIKeyHeader: string(akr.Key)})
if err != nil {
t.Fatal(err)
}
Expand Down Expand Up @@ -522,7 +522,7 @@ func testUserLimits(t *testing.T, at *test.AccountsTester) {
err = build.Retry(10, 200*time.Millisecond, func() error {
// Check the user's limits. We expect the tier to be Free but the limits to
// match Anonymous.
tl, _, err = at.UserLimits(nil, nil)
tl, _, err = at.UserLimits("byte", nil)
if err != nil {
return errors.AddContext(err, "failed to call /user/limits")
}
Expand All @@ -540,6 +540,36 @@ func testUserLimits(t *testing.T, at *test.AccountsTester) {
if err != nil {
t.Fatal(err)
}

// Test the `unit` parameter. The only valid value is `byte`, anything else
// is ignored and the results are returned in bits per second.
tl, _, err = at.UserLimits("", nil)
if err != nil {
t.Fatal(err)
}
// Request it with an invalid value. Expect it to be ignored.
tlBits, _, err := at.UserLimits("not-a-byte", nil)
if err != nil {
t.Fatal(err)
}
if tlBits.UploadBandwidth != tl.UploadBandwidth || tlBits.DownloadBandwidth != tl.DownloadBandwidth {
t.Fatalf("Expected these to be equal. %+v, %+v", tl, tlBits)
}
tlBytes, _, err := at.UserLimits("byte", nil)
if err != nil {
t.Fatal(err)
}
if tlBytes.UploadBandwidth*8 != tl.UploadBandwidth || tlBytes.DownloadBandwidth*8 != tl.DownloadBandwidth {
t.Fatalf("Invalid values in bytes. Values in bps: %+v, values in Bps: %+v", tl, tlBytes)
}
// Ensure we're not case-sensitive.
tlBytes2, _, err := at.UserLimits("ByTe", nil)
if err != nil {
t.Fatal(err)
}
if tlBytes2.UploadBandwidth != tlBytes.UploadBandwidth || tlBytes2.DownloadBandwidth != tlBytes.DownloadBandwidth {
t.Fatalf("Got different values for different capitalizations of 'byte'.\nValues for 'byte': %+v, values for 'ByTe': %+v", tlBytes, tlBytes2)
}
}

// testUserUploadsDELETE tests the DELETE /user/uploads/:skylink endpoint.
Expand Down
16 changes: 12 additions & 4 deletions test/tester.go
Original file line number Diff line number Diff line change
Expand Up @@ -454,8 +454,12 @@ func (at *AccountsTester) TrackRegistryWrite() (int, error) {
}

// UserLimits performs a `GET /user/limits` request.
func (at *AccountsTester) UserLimits(params url.Values, headers map[string]string) (api.UserLimitsGET, int, error) {
r, b, err := at.request(http.MethodGet, "/user/limits", params, nil, headers)
func (at *AccountsTester) UserLimits(unit string, headers map[string]string) (api.UserLimitsGET, int, error) {
queryParams := url.Values{}
if unit != "" {
queryParams.Set("unit", unit)
}
r, b, err := at.request(http.MethodGet, "/user/limits", queryParams, nil, headers)
if err != nil {
return api.UserLimitsGET{}, r.StatusCode, err
}
Expand All @@ -471,11 +475,15 @@ func (at *AccountsTester) UserLimits(params url.Values, headers map[string]strin
}

// UserLimitsSkylink performs a `GET /user/limits/:skylink` request.
func (at *AccountsTester) UserLimitsSkylink(sl string, params url.Values, headers map[string]string) (api.UserLimitsGET, int, error) {
func (at *AccountsTester) UserLimitsSkylink(sl string, unit string, headers map[string]string) (api.UserLimitsGET, int, error) {
queryParams := url.Values{}
if unit != "" {
queryParams.Set("unit", unit)
}
if !database.ValidSkylinkHash(sl) {
return api.UserLimitsGET{}, 0, database.ErrInvalidSkylink
}
r, b, err := at.request(http.MethodGet, "/user/limits/"+sl, params, nil, headers)
r, b, err := at.request(http.MethodGet, "/user/limits/"+sl, queryParams, nil, headers)
if err != nil {
return api.UserLimitsGET{}, r.StatusCode, err
}
Expand Down

0 comments on commit 614c1ed

Please sign in to comment.