diff --git a/api/cache_test.go b/api/cache_test.go index fa49f5a6..161099e8 100644 --- a/api/cache_test.go +++ b/api/cache_test.go @@ -21,7 +21,7 @@ func TestUserTierCache(t *testing.T) { if ok || tier != database.TierAnonymous { t.Fatalf("Expected to get tier %d and %t, got %d and %t.", database.TierAnonymous, false, tier, ok) } - // Set the use in the cache. + // Set the user in the cache. cache.Set(u.Sub, u) // Check again. tier, qe, ok := cache.Get(u.Sub) diff --git a/api/handlers.go b/api/handlers.go index 7aa6eb16..18d57f0f 100644 --- a/api/handlers.go +++ b/api/handlers.go @@ -2,7 +2,6 @@ package api import ( "context" - "crypto/subtle" "encoding/json" "io" "io/ioutil" @@ -748,20 +747,7 @@ func (api *API) userPubKeyDELETE(u *database.User, w http.ResponseWriter, req *h api.WriteError(w, errors.New("the given pubkey is not associated with this user"), http.StatusBadRequest) return } - // Find the position of the pubkey in the list. - keyIdx := -1 - for i, k := range u.PubKeys { - if subtle.ConstantTimeCompare(pk[:], k[:]) == 1 { - keyIdx = i - break - } - } - if keyIdx == -1 { - build.Critical("Reaching this should be impossible. It would indicate a concurrent change of the user struct.") - } - // Remove the pubkey. - u.PubKeys = append(u.PubKeys[:keyIdx], u.PubKeys[keyIdx+1:]...) - err = api.staticDB.UserSave(ctx, u) + err = api.staticDB.UserPubKeyRemove(ctx, *u, pk) if err != nil { api.WriteError(w, err, http.StatusInternalServerError) return @@ -852,8 +838,12 @@ func (api *API) userPubKeyRegisterPOST(u *database.User, w http.ResponseWriter, api.WriteError(w, errors.New("user's sub doesn't match update sub"), http.StatusBadRequest) return } - u.PubKeys = append(u.PubKeys, pk) - err = api.staticDB.UserSave(ctx, u) + err = api.staticDB.UserPubKeyAdd(ctx, *u, pk) + if err != nil { + api.WriteError(w, err, http.StatusInternalServerError) + return + } + updatedUser, err := api.staticDB.UserByID(ctx, u.ID) if err != nil { api.WriteError(w, err, http.StatusInternalServerError) return @@ -863,7 +853,7 @@ func (api *API) userPubKeyRegisterPOST(u *database.User, w http.ResponseWriter, api.WriteError(w, err, http.StatusInternalServerError) return } - api.loginUser(w, u, true) + api.loginUser(w, updatedUser, true) } // userUploadsGET returns all uploads made by the current user. diff --git a/database/user.go b/database/user.go index de53d08d..64136872 100644 --- a/database/user.go +++ b/database/user.go @@ -128,10 +128,7 @@ type ( SubscriptionCancelAtPeriodEnd bool `bson:"subscription_cancel_at_period_end" json:"subscriptionCancelAtPeriodEnd"` StripeID string `bson:"stripe_id" json:"stripeCustomerId"` QuotaExceeded bool `bson:"quota_exceeded" json:"quotaExceeded"` - // The currently active (or default) key is going to be the first one in - // the list. If we want to activate a new pubkey, we'll just move it to - // the first position in the list. - PubKeys []PubKey `bson:"pub_keys" json:"-"` + PubKeys []PubKey `bson:"pub_keys" json:"-"` } // UserStats contains statistical information about the user. UserStats struct { @@ -477,6 +474,39 @@ func (db *DB) UserSave(ctx context.Context, u *User) error { return nil } +// UserPubKeyAdd adds a new PubKey to the given user's set. +func (db *DB) UserPubKeyAdd(ctx context.Context, u User, pk PubKey) (err error) { + // If the set of pubkeys is not initialised we cannot use mongo's mutation + // operations, such as $addToSet, so we'll save the entire user record. + if u.PubKeys == nil { + u.PubKeys = make([]PubKey, 1) + u.PubKeys[0] = pk + return db.UserSave(ctx, &u) + } + filter := bson.M{"_id": u.ID} + update := bson.M{ + "$addToSet": bson.M{"pub_keys": pk}, + } + opts := options.UpdateOptions{ + Upsert: &False, + } + _, err = db.staticUsers.UpdateOne(ctx, filter, update, &opts) + return err +} + +// UserPubKeyRemove removes a PubKey from the given user's set. +func (db *DB) UserPubKeyRemove(ctx context.Context, u User, pk PubKey) error { + filter := bson.M{"_id": u.ID} + update := bson.M{ + "$pull": bson.M{"pub_keys": pk}, + } + opts := options.UpdateOptions{ + Upsert: &False, + } + _, err := db.staticUsers.UpdateOne(ctx, filter, update, &opts) + return err +} + // UserSetStripeID changes the user's stripe id in the DB. func (db *DB) UserSetStripeID(ctx context.Context, u *User, stripeID string) error { filter := bson.M{"_id": u.ID} diff --git a/test/api/apikeys_test.go b/test/api/apikeys_test.go index 2f7f2c8e..a08dc71f 100644 --- a/test/api/apikeys_test.go +++ b/test/api/apikeys_test.go @@ -255,12 +255,12 @@ func testPublicAPIKeysUsage(t *testing.T, at *test.AccountsTester) { Public: true, Skylinks: []string{sl.Skylink}, } - akWithKey, _, err := at.UserAPIKeysPOST(apiKeyPOST) + pakWithKey, _, err := at.UserAPIKeysPOST(apiKeyPOST) if err != nil { t.Fatal(err) } // Stop using the cookie, use the public API key instead. - at.SetAPIKey(akWithKey.Key.String()) + at.SetAPIKey(pakWithKey.Key.String()) // Try to fetch the user's stats with the new public API key. // Expect this to fail. _, _, err = at.Get("/user/stats", nil) @@ -269,7 +269,7 @@ func testPublicAPIKeysUsage(t *testing.T, at *test.AccountsTester) { } // Get the user's limits for downloading a skylink covered by the public // API key. Expect to get TierFree values. - ul, _, err := at.UserLimitsSkylink(sl.Skylink, "byte", nil) + ul, _, err := at.UserLimitsSkylink(sl.Skylink, "byte", "", nil) if err != nil { t.Fatal(err) } @@ -278,11 +278,22 @@ func testPublicAPIKeysUsage(t *testing.T, at *test.AccountsTester) { } // Get the user's limits for downloading a skylink that is not covered by // the public API key. Expect to get TierAnonymous values. - ul, _, err = at.UserLimitsSkylink(sl2.Skylink, "byte", nil) + ul, _, err = at.UserLimitsSkylink(sl2.Skylink, "byte", "", nil) if err != nil { t.Fatal(err) } if ul.DownloadBandwidth != database.UserLimits[database.TierAnonymous].DownloadBandwidth { t.Fatalf("Expected to get download bandwidth of %d, got %d", database.UserLimits[database.TierAnonymous].DownloadBandwidth, ul.DownloadBandwidth) } + // Stop using the header, pass the skylink as a query parameter. + at.ClearCredentials() + // Get the user's limits for downloading a skylink covered by the public + // API key. Expect to get TierFree values. + ul, _, err = at.UserLimitsSkylink(sl.Skylink, "byte", pakWithKey.Key.String(), nil) + if err != nil { + t.Fatal(err) + } + if ul.DownloadBandwidth != database.UserLimits[database.TierFree].DownloadBandwidth { + t.Fatalf("Expected to get download bandwidth of %d, got %d", database.UserLimits[database.TierFree].DownloadBandwidth, ul.DownloadBandwidth) + } } diff --git a/test/database/user_test.go b/test/database/user_test.go index f8caa4d6..ac3315c5 100644 --- a/test/database/user_test.go +++ b/test/database/user_test.go @@ -2,6 +2,7 @@ package database import ( "context" + "crypto/subtle" "reflect" "testing" "time" @@ -438,6 +439,85 @@ func TestUserSetStripeID(t *testing.T) { } } +// TestUserPubKey tests UserPubKeyAdd and UserPubKeyRemove. +func TestUserPubKey(t *testing.T) { + ctx := context.Background() + dbName := test.DBNameForTest(t.Name()) + db, err := database.NewCustomDB(ctx, dbName, test.DBTestCredentials(), nil) + if err != nil { + t.Fatal(err) + } + // Create a test user. + u, err := db.UserCreate(ctx, t.Name()+"@siasky.net", t.Name()+"pass", t.Name()+"sub", database.TierFree) + if err != nil { + t.Fatal(err) + } + defer func(user *database.User) { + _ = db.UserDelete(ctx, user) + }(u) + // Add a pubkey. + pk := database.PubKey(make([]byte, database.PubKeySize)) + copy(pk[:], fastrand.Bytes(database.PubKeySize)) + err = db.UserPubKeyAdd(ctx, *u, pk) + if err != nil { + t.Fatal(err) + } + u1, err := db.UserByID(ctx, u.ID) + if err != nil { + t.Fatal(err) + } + if len(u1.PubKeys) == 1 && subtle.ConstantTimeCompare(u1.PubKeys[0][:], pk[:]) != 1 { + t.Fatalf("Expected the user to have a single pubkey which matches ours. Got %+v, pubkey %+v", u1.PubKeys, pk) + } + // Add another. + var pk1 database.PubKey + copy(pk1[:], fastrand.Bytes(database.PubKeySize)) + err = db.UserPubKeyAdd(ctx, *u, pk) + if err != nil { + t.Fatal(err) + } + u2, err := db.UserByID(ctx, u.ID) + if err != nil { + t.Fatal(err) + } + if len(u2.PubKeys) == 2 && subtle.ConstantTimeCompare(u1.PubKeys[1][:], pk1[:]) == 1 { + t.Fatalf("Expected the user to have a single pubkey which matches ours. Got %+v, pubkey %+v", u2.PubKeys, pk1) + } + // Delete a pubkey. + err = db.UserPubKeyRemove(ctx, *u, pk) + if err != nil { + t.Fatal(err) + } + u3, err := db.UserByID(ctx, u.ID) + if err != nil { + t.Fatal(err) + } + if len(u3.PubKeys) == 1 && subtle.ConstantTimeCompare(u3.PubKeys[0][:], pk1[:]) == 1 { + t.Fatalf("Expected the user to have a single pubkey which matches ours. Got %+v, pubkey %+v", u3.PubKeys, pk1) + } + // Make sure UserPubKeyRemove removes all copies of the pubkey from the set. + // We don't expect there to be multiple but we still want to make sure. + u.PubKeys = make([]database.PubKey, 0) + u.PubKeys = append(u.PubKeys, pk) + u.PubKeys = append(u.PubKeys, pk) + u.PubKeys = append(u.PubKeys, pk) + err = db.UserSave(ctx, u) + if err != nil { + t.Fatal(err) + } + err = db.UserPubKeyRemove(ctx, *u, pk) + if err != nil { + t.Fatal(err) + } + u4, err := db.UserByID(ctx, u.ID) + if err != nil { + t.Fatal(err) + } + if len(u4.PubKeys) > 0 { + t.Fatal("Expected zero pubkeys.") + } +} + // TestUserSetTier ensures that UserSetTier works as expected. func TestUserSetTier(t *testing.T) { ctx := context.Background() diff --git a/test/tester.go b/test/tester.go index 0ed24e77..ab4ff637 100644 --- a/test/tester.go +++ b/test/tester.go @@ -475,11 +475,10 @@ func (at *AccountsTester) UserLimits(unit string, headers map[string]string) (ap } // UserLimitsSkylink performs a `GET /user/limits/:skylink` request. -func (at *AccountsTester) UserLimitsSkylink(sl string, unit string, headers map[string]string) (api.UserLimitsGET, int, error) { +func (at *AccountsTester) UserLimitsSkylink(sl string, unit, apikey string, headers map[string]string) (api.UserLimitsGET, int, error) { queryParams := url.Values{} - if unit != "" { - queryParams.Set("unit", unit) - } + queryParams.Set("unit", unit) + queryParams.Set("apiKey", apikey) if !database.ValidSkylinkHash(sl) { return api.UserLimitsGET{}, 0, database.ErrInvalidSkylink }