Skip to content

Commit

Permalink
Use dedicated DB methods for adding and removing pubkeys from the use…
Browse files Browse the repository at this point in the history
…r record.
  • Loading branch information
ro-tex committed Mar 21, 2022
1 parent fd1d965 commit f1c6a02
Show file tree
Hide file tree
Showing 6 changed files with 141 additions and 31 deletions.
2 changes: 1 addition & 1 deletion api/cache_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ func TestUserTierCache(t *testing.T) {
if ok || tier != database.TierAnonymous {
t.Fatalf("Expected to get tier %d and %t, got %d and %t.", database.TierAnonymous, false, tier, ok)
}
// Set the use in the cache.
// Set the user in the cache.
cache.Set(u.Sub, u)
// Check again.
tier, qe, ok := cache.Get(u.Sub)
Expand Down
26 changes: 8 additions & 18 deletions api/handlers.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ package api

import (
"context"
"crypto/subtle"
"encoding/json"
"io"
"io/ioutil"
Expand Down Expand Up @@ -748,20 +747,7 @@ func (api *API) userPubKeyDELETE(u *database.User, w http.ResponseWriter, req *h
api.WriteError(w, errors.New("the given pubkey is not associated with this user"), http.StatusBadRequest)
return
}
// Find the position of the pubkey in the list.
keyIdx := -1
for i, k := range u.PubKeys {
if subtle.ConstantTimeCompare(pk[:], k[:]) == 1 {
keyIdx = i
break
}
}
if keyIdx == -1 {
build.Critical("Reaching this should be impossible. It would indicate a concurrent change of the user struct.")
}
// Remove the pubkey.
u.PubKeys = append(u.PubKeys[:keyIdx], u.PubKeys[keyIdx+1:]...)
err = api.staticDB.UserSave(ctx, u)
err = api.staticDB.UserPubKeyRemove(ctx, *u, pk)
if err != nil {
api.WriteError(w, err, http.StatusInternalServerError)
return
Expand Down Expand Up @@ -852,8 +838,12 @@ func (api *API) userPubKeyRegisterPOST(u *database.User, w http.ResponseWriter,
api.WriteError(w, errors.New("user's sub doesn't match update sub"), http.StatusBadRequest)
return
}
u.PubKeys = append(u.PubKeys, pk)
err = api.staticDB.UserSave(ctx, u)
err = api.staticDB.UserPubKeyAdd(ctx, *u, pk)
if err != nil {
api.WriteError(w, err, http.StatusInternalServerError)
return
}
updatedUser, err := api.staticDB.UserByID(ctx, u.ID)
if err != nil {
api.WriteError(w, err, http.StatusInternalServerError)
return
Expand All @@ -863,7 +853,7 @@ func (api *API) userPubKeyRegisterPOST(u *database.User, w http.ResponseWriter,
api.WriteError(w, err, http.StatusInternalServerError)
return
}
api.loginUser(w, u, true)
api.loginUser(w, updatedUser, true)
}

// userUploadsGET returns all uploads made by the current user.
Expand Down
38 changes: 34 additions & 4 deletions database/user.go
Original file line number Diff line number Diff line change
Expand Up @@ -128,10 +128,7 @@ type (
SubscriptionCancelAtPeriodEnd bool `bson:"subscription_cancel_at_period_end" json:"subscriptionCancelAtPeriodEnd"`
StripeID string `bson:"stripe_id" json:"stripeCustomerId"`
QuotaExceeded bool `bson:"quota_exceeded" json:"quotaExceeded"`
// The currently active (or default) key is going to be the first one in
// the list. If we want to activate a new pubkey, we'll just move it to
// the first position in the list.
PubKeys []PubKey `bson:"pub_keys" json:"-"`
PubKeys []PubKey `bson:"pub_keys" json:"-"`
}
// UserStats contains statistical information about the user.
UserStats struct {
Expand Down Expand Up @@ -477,6 +474,39 @@ func (db *DB) UserSave(ctx context.Context, u *User) error {
return nil
}

// UserPubKeyAdd adds a new PubKey to the given user's set.
func (db *DB) UserPubKeyAdd(ctx context.Context, u User, pk PubKey) (err error) {
// If the set of pubkeys is not initialised we cannot use mongo's mutation
// operations, such as $addToSet, so we'll save the entire user record.
if u.PubKeys == nil {
u.PubKeys = make([]PubKey, 1)
u.PubKeys[0] = pk
return db.UserSave(ctx, &u)
}
filter := bson.M{"_id": u.ID}
update := bson.M{
"$addToSet": bson.M{"pub_keys": pk},
}
opts := options.UpdateOptions{
Upsert: &False,
}
_, err = db.staticUsers.UpdateOne(ctx, filter, update, &opts)
return err
}

// UserPubKeyRemove removes a PubKey from the given user's set.
func (db *DB) UserPubKeyRemove(ctx context.Context, u User, pk PubKey) error {
filter := bson.M{"_id": u.ID}
update := bson.M{
"$pull": bson.M{"pub_keys": pk},
}
opts := options.UpdateOptions{
Upsert: &False,
}
_, err := db.staticUsers.UpdateOne(ctx, filter, update, &opts)
return err
}

// UserSetStripeID changes the user's stripe id in the DB.
func (db *DB) UserSetStripeID(ctx context.Context, u *User, stripeID string) error {
filter := bson.M{"_id": u.ID}
Expand Down
19 changes: 15 additions & 4 deletions test/api/apikeys_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -255,12 +255,12 @@ func testPublicAPIKeysUsage(t *testing.T, at *test.AccountsTester) {
Public: true,
Skylinks: []string{sl.Skylink},
}
akWithKey, _, err := at.UserAPIKeysPOST(apiKeyPOST)
pakWithKey, _, err := at.UserAPIKeysPOST(apiKeyPOST)
if err != nil {
t.Fatal(err)
}
// Stop using the cookie, use the public API key instead.
at.SetAPIKey(akWithKey.Key.String())
at.SetAPIKey(pakWithKey.Key.String())
// Try to fetch the user's stats with the new public API key.
// Expect this to fail.
_, _, err = at.Get("/user/stats", nil)
Expand All @@ -269,7 +269,7 @@ func testPublicAPIKeysUsage(t *testing.T, at *test.AccountsTester) {
}
// Get the user's limits for downloading a skylink covered by the public
// API key. Expect to get TierFree values.
ul, _, err := at.UserLimitsSkylink(sl.Skylink, "byte", nil)
ul, _, err := at.UserLimitsSkylink(sl.Skylink, "byte", "", nil)
if err != nil {
t.Fatal(err)
}
Expand All @@ -278,11 +278,22 @@ func testPublicAPIKeysUsage(t *testing.T, at *test.AccountsTester) {
}
// Get the user's limits for downloading a skylink that is not covered by
// the public API key. Expect to get TierAnonymous values.
ul, _, err = at.UserLimitsSkylink(sl2.Skylink, "byte", nil)
ul, _, err = at.UserLimitsSkylink(sl2.Skylink, "byte", "", nil)
if err != nil {
t.Fatal(err)
}
if ul.DownloadBandwidth != database.UserLimits[database.TierAnonymous].DownloadBandwidth {
t.Fatalf("Expected to get download bandwidth of %d, got %d", database.UserLimits[database.TierAnonymous].DownloadBandwidth, ul.DownloadBandwidth)
}
// Stop using the header, pass the skylink as a query parameter.
at.ClearCredentials()
// Get the user's limits for downloading a skylink covered by the public
// API key. Expect to get TierFree values.
ul, _, err = at.UserLimitsSkylink(sl.Skylink, "byte", pakWithKey.Key.String(), nil)
if err != nil {
t.Fatal(err)
}
if ul.DownloadBandwidth != database.UserLimits[database.TierFree].DownloadBandwidth {
t.Fatalf("Expected to get download bandwidth of %d, got %d", database.UserLimits[database.TierFree].DownloadBandwidth, ul.DownloadBandwidth)
}
}
80 changes: 80 additions & 0 deletions test/database/user_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package database

import (
"context"
"crypto/subtle"
"reflect"
"testing"
"time"
Expand Down Expand Up @@ -438,6 +439,85 @@ func TestUserSetStripeID(t *testing.T) {
}
}

// TestUserPubKey tests UserPubKeyAdd and UserPubKeyRemove.
func TestUserPubKey(t *testing.T) {
ctx := context.Background()
dbName := test.DBNameForTest(t.Name())
db, err := database.NewCustomDB(ctx, dbName, test.DBTestCredentials(), nil)
if err != nil {
t.Fatal(err)
}
// Create a test user.
u, err := db.UserCreate(ctx, t.Name()+"@siasky.net", t.Name()+"pass", t.Name()+"sub", database.TierFree)
if err != nil {
t.Fatal(err)
}
defer func(user *database.User) {
_ = db.UserDelete(ctx, user)
}(u)
// Add a pubkey.
pk := database.PubKey(make([]byte, database.PubKeySize))
copy(pk[:], fastrand.Bytes(database.PubKeySize))
err = db.UserPubKeyAdd(ctx, *u, pk)
if err != nil {
t.Fatal(err)
}
u1, err := db.UserByID(ctx, u.ID)
if err != nil {
t.Fatal(err)
}
if len(u1.PubKeys) == 1 && subtle.ConstantTimeCompare(u1.PubKeys[0][:], pk[:]) != 1 {
t.Fatalf("Expected the user to have a single pubkey which matches ours. Got %+v, pubkey %+v", u1.PubKeys, pk)
}
// Add another.
var pk1 database.PubKey
copy(pk1[:], fastrand.Bytes(database.PubKeySize))
err = db.UserPubKeyAdd(ctx, *u, pk)
if err != nil {
t.Fatal(err)
}
u2, err := db.UserByID(ctx, u.ID)
if err != nil {
t.Fatal(err)
}
if len(u2.PubKeys) == 2 && subtle.ConstantTimeCompare(u1.PubKeys[1][:], pk1[:]) == 1 {
t.Fatalf("Expected the user to have a single pubkey which matches ours. Got %+v, pubkey %+v", u2.PubKeys, pk1)
}
// Delete a pubkey.
err = db.UserPubKeyRemove(ctx, *u, pk)
if err != nil {
t.Fatal(err)
}
u3, err := db.UserByID(ctx, u.ID)
if err != nil {
t.Fatal(err)
}
if len(u3.PubKeys) == 1 && subtle.ConstantTimeCompare(u3.PubKeys[0][:], pk1[:]) == 1 {
t.Fatalf("Expected the user to have a single pubkey which matches ours. Got %+v, pubkey %+v", u3.PubKeys, pk1)
}
// Make sure UserPubKeyRemove removes all copies of the pubkey from the set.
// We don't expect there to be multiple but we still want to make sure.
u.PubKeys = make([]database.PubKey, 0)
u.PubKeys = append(u.PubKeys, pk)
u.PubKeys = append(u.PubKeys, pk)
u.PubKeys = append(u.PubKeys, pk)
err = db.UserSave(ctx, u)
if err != nil {
t.Fatal(err)
}
err = db.UserPubKeyRemove(ctx, *u, pk)
if err != nil {
t.Fatal(err)
}
u4, err := db.UserByID(ctx, u.ID)
if err != nil {
t.Fatal(err)
}
if len(u4.PubKeys) > 0 {
t.Fatal("Expected zero pubkeys.")
}
}

// TestUserSetTier ensures that UserSetTier works as expected.
func TestUserSetTier(t *testing.T) {
ctx := context.Background()
Expand Down
7 changes: 3 additions & 4 deletions test/tester.go
Original file line number Diff line number Diff line change
Expand Up @@ -475,11 +475,10 @@ func (at *AccountsTester) UserLimits(unit string, headers map[string]string) (ap
}

// UserLimitsSkylink performs a `GET /user/limits/:skylink` request.
func (at *AccountsTester) UserLimitsSkylink(sl string, unit string, headers map[string]string) (api.UserLimitsGET, int, error) {
func (at *AccountsTester) UserLimitsSkylink(sl string, unit, apikey string, headers map[string]string) (api.UserLimitsGET, int, error) {
queryParams := url.Values{}
if unit != "" {
queryParams.Set("unit", unit)
}
queryParams.Set("unit", unit)
queryParams.Set("apiKey", apikey)
if !database.ValidSkylinkHash(sl) {
return api.UserLimitsGET{}, 0, database.ErrInvalidSkylink
}
Expand Down

0 comments on commit f1c6a02

Please sign in to comment.