Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Anonymous uploads + track upload IPs #167

Merged
merged 9 commits into from
Mar 25, 2022
46 changes: 42 additions & 4 deletions api/handlers.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"encoding/json"
"io"
"io/ioutil"
"net"
"net/http"
"net/mail"
"net/url"
Expand All @@ -20,6 +21,7 @@ import (
"github.com/SkynetLabs/skynet-accounts/metafetcher"
"github.com/SkynetLabs/skynet-accounts/skynet"
"github.com/julienschmidt/httprouter"
jwt2 "github.com/lestrrat-go/jwx/jwt"
"gitlab.com/NebulousLabs/errors"
"go.mongodb.org/mongo-driver/mongo"
)
Expand Down Expand Up @@ -1062,7 +1064,7 @@ func (api *API) userRecoverPOST(_ *database.User, w http.ResponseWriter, req *ht
}

// trackUploadPOST registers a new upload in the system.
func (api *API) trackUploadPOST(u *database.User, w http.ResponseWriter, req *http.Request, ps httprouter.Params) {
func (api *API) trackUploadPOST(_ *database.User, w http.ResponseWriter, req *http.Request, ps httprouter.Params) {
sl := ps.ByName("skylink")
if sl == "" {
api.WriteError(w, errors.New("missing parameter 'skylink'"), http.StatusBadRequest)
Expand All @@ -1077,7 +1079,13 @@ func (api *API) trackUploadPOST(u *database.User, w http.ResponseWriter, req *ht
api.WriteError(w, err, http.StatusInternalServerError)
return
}
_, err = api.staticDB.UploadCreate(req.Context(), *u, *skylink)
u, _, _ := api.userFromRequest(req)
if u == nil {
// This will be tracked as an anonymous request.
u = &database.AnonUser
}
ip := validateIP(req.Form.Get("ip"))
ro-tex marked this conversation as resolved.
Show resolved Hide resolved
_, err = api.staticDB.UploadCreate(req.Context(), *u, ip, *skylink)
if err != nil {
api.WriteError(w, err, http.StatusInternalServerError)
return
Expand All @@ -1096,7 +1104,9 @@ func (api *API) trackUploadPOST(u *database.User, w http.ResponseWriter, req *ht
// administrative details, such as user's quotas check.
// Note that this call is not affected by the request's context, so we use
// a separate one.
go api.checkUserQuotas(context.Background(), u)
if u != nil && !u.ID.IsZero() {
go api.checkUserQuotas(context.Background(), u)
}
}

// trackDownloadPOST registers a new download in the system.
Expand Down Expand Up @@ -1136,7 +1146,7 @@ func (api *API) trackDownloadPOST(u *database.User, w http.ResponseWriter, req *
api.WriteError(w, err, http.StatusInternalServerError)
return
}
err = api.staticDB.DownloadCreate(req.Context(), *u, *skylink, downloadedBytes)
_, err = api.staticDB.DownloadCreate(req.Context(), *u, *skylink, downloadedBytes)
if err != nil {
api.WriteError(w, err, http.StatusInternalServerError)
return
Expand Down Expand Up @@ -1225,6 +1235,25 @@ func (api *API) checkUserQuotas(ctx context.Context, u *database.User) {
}
}

// userFromRequest checks the requests for various forms of authentication (API
// key, cookie, authorization header) and returns user information based on
// those.
func (api *API) userFromRequest(req *http.Request) (*database.User, jwt2.Token, error) {
// Check for an API key.
u, tk, err := api.userAndTokenByAPIKey(req)
if err != nil && !errors.Contains(err, ErrNoAPIKey) {
ro-tex marked this conversation as resolved.
Show resolved Hide resolved
return nil, nil, err
}
// If there is no API key check for a token.
if errors.Contains(err, ErrNoAPIKey) {
u, tk, err = api.userAndTokenByRequestToken(req)
if err != nil {
return nil, nil, err
}
ro-tex marked this conversation as resolved.
Show resolved Hide resolved
}
return u, tk, err
}

// wellKnownJWKSGET returns our public JWKS, so people can use that to verify
// the authenticity of the JWT tokens we issue.
func (api *API) wellKnownJWKSGET(_ *database.User, w http.ResponseWriter, _ *http.Request, _ httprouter.Params) {
Expand Down Expand Up @@ -1302,3 +1331,12 @@ func userLimitsGetFromTier(tierID int, quotaExceeded, inBytes bool) *UserLimitsG
RegistryDelay: limitsTier.RegistryDelay,
}
}

// validateIP is a simple pass-through helper that returns valid IPs as they are
// and returns an empty string for invalid IPs.
func validateIP(ip string) string {
ro-tex marked this conversation as resolved.
Show resolved Hide resolved
if parsedIP := net.ParseIP(ip); parsedIP != nil {
return parsedIP.String()
}
return ""
}
22 changes: 4 additions & 18 deletions api/routes.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ func (api *API) buildHTTPRoutes() {
api.staticRouter.POST("/register", api.WithDBSession(api.noAuth(api.registerPOST)))

// Endpoints at which Nginx reports portal usage.
api.staticRouter.POST("/track/upload/:skylink", api.withAuth(api.trackUploadPOST))
api.staticRouter.POST("/track/upload/:skylink", api.noAuth(api.trackUploadPOST))
api.staticRouter.POST("/track/download/:skylink", api.withAuth(api.trackDownloadPOST))
api.staticRouter.POST("/track/registry/read", api.withAuth(api.trackRegistryReadPOST))
api.staticRouter.POST("/track/registry/write", api.withAuth(api.trackRegistryWritePOST))
Expand Down Expand Up @@ -92,25 +92,11 @@ func (api *API) noAuth(h HandlerWithUser) httprouter.Handle {
func (api *API) withAuth(h HandlerWithUser) httprouter.Handle {
return func(w http.ResponseWriter, req *http.Request, ps httprouter.Params) {
api.logRequest(req)
// Check for an API key.
u, token, err := api.userAndTokenByAPIKey(req)
// If there is an unexpected error, that is a 500.
if err != nil && !errors.Contains(err, ErrNoAPIKey) && !errors.Contains(err, database.ErrInvalidAPIKey) && !errors.Contains(err, database.ErrUserNotFound) {
api.WriteError(w, err, http.StatusInternalServerError)
u, token, err := api.userFromRequest(req)
if err != nil {
api.WriteError(w, err, http.StatusUnauthorized)
return
}
if err != nil && (errors.Contains(err, database.ErrInvalidAPIKey) || errors.Contains(err, database.ErrUserNotFound)) {
api.WriteError(w, errors.AddContext(err, "failed to fetch user by API key"), http.StatusUnauthorized)
return
}
// If there is no API key check for a token.
if errors.Contains(err, ErrNoAPIKey) {
u, token, err = api.userAndTokenByRequestToken(req)
if err != nil {
api.WriteError(w, err, http.StatusUnauthorized)
return
}
}
// Embed the verified token in the context of the request.
ctx := jwt.ContextWithToken(req.Context(), token)
h(u, w, req.WithContext(ctx), ps)
Expand Down
22 changes: 12 additions & 10 deletions database/download.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,20 +51,17 @@ func (db *DB) DownloadByID(ctx context.Context, id primitive.ObjectID) (*Downloa

// DownloadCreate registers a new download. Marks partial downloads by supplying
// the `bytes` param. If `bytes` is 0 we assume a full download.
func (db *DB) DownloadCreate(ctx context.Context, user User, skylink Skylink, bytes int64) error {
if user.ID.IsZero() {
return errors.New("invalid user")
}
func (db *DB) DownloadCreate(ctx context.Context, user User, skylink Skylink, bytes int64) (*Download, error) {
if skylink.ID.IsZero() {
return ErrInvalidSkylink
return nil, ErrInvalidSkylink
}

// Check if there exists a download of this skylink by this user, updated
// within the DownloadUpdateWindow and keep updating that, if so.
down, err := db.DownloadRecent(ctx, skylink.ID)
down, err := db.DownloadRecent(ctx, user.ID, skylink.ID)
if err == nil {
// We found a recent download of this skylink. Let's update it.
return db.DownloadIncrement(ctx, down, bytes)
return nil, db.DownloadIncrement(ctx, down, bytes)
}

// We couldn't find a recent download of this skylink, updated within
Expand All @@ -76,8 +73,12 @@ func (db *DB) DownloadCreate(ctx context.Context, user User, skylink Skylink, by
CreatedAt: time.Now().UTC().Truncate(time.Millisecond),
UpdatedAt: time.Now().UTC().Truncate(time.Millisecond),
}
_, err = db.staticDownloads.InsertOne(ctx, down)
return err
ior, err := db.staticDownloads.InsertOne(ctx, down)
if err != nil {
return nil, err
}
down.ID = ior.InsertedID.(primitive.ObjectID)
return down, nil
}

// DownloadsBySkylink fetches a page of downloads of this skylink and the total
Expand Down Expand Up @@ -126,9 +127,10 @@ func (db *DB) downloadsBy(ctx context.Context, matchStage bson.D, offset, pageSi
}

// DownloadRecent returns the most recent download of the given skylink.
func (db *DB) DownloadRecent(ctx context.Context, skylinkID primitive.ObjectID) (*Download, error) {
func (db *DB) DownloadRecent(ctx context.Context, uID primitive.ObjectID, skylinkID primitive.ObjectID) (*Download, error) {
updatedAtThreshold := time.Now().UTC().Add(-1 * DownloadUpdateWindow)
filter := bson.D{
{"user_id", uID},
{"skylink_id", skylinkID},
{"updated_at", bson.D{{"$gt", updatedAtThreshold}}},
}
Expand Down
23 changes: 11 additions & 12 deletions database/upload.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,12 @@ import (

// Upload ...
type Upload struct {
ID primitive.ObjectID `bson:"_id,omitempty" json:"id"`
UserID primitive.ObjectID `bson:"user_id,omitempty" json:"userId"`
SkylinkID primitive.ObjectID `bson:"skylink_id,omitempty" json:"skylinkId"`
Timestamp time.Time `bson:"timestamp" json:"timestamp"`
Unpinned bool `bson:"unpinned" json:"-"`
ID primitive.ObjectID `bson:"_id,omitempty" json:"id"`
UserID primitive.ObjectID `bson:"user_id,omitempty" json:"userId"`
UploaderIP string `bson:"uploader_ip" json:"uploaderIP"`
SkylinkID primitive.ObjectID `bson:"skylink_id,omitempty" json:"skylinkId"`
Timestamp time.Time `bson:"timestamp" json:"timestamp"`
Unpinned bool `bson:"unpinned" json:"-"`
}

// UploadResponse is the representation of an upload we send as response to
Expand Down Expand Up @@ -45,17 +46,15 @@ func (db *DB) UploadByID(ctx context.Context, id primitive.ObjectID) (*Upload, e

// UploadCreate registers a new upload and counts it towards the user's used
// storage.
func (db *DB) UploadCreate(ctx context.Context, user User, skylink Skylink) (*Upload, error) {
if user.ID.IsZero() {
return nil, errors.New("invalid user")
}
func (db *DB) UploadCreate(ctx context.Context, user User, ip string, skylink Skylink) (*Upload, error) {
if skylink.ID.IsZero() {
return nil, errors.New("skylink doesn't exist")
}
up := Upload{
UserID: user.ID,
SkylinkID: skylink.ID,
Timestamp: time.Now().UTC(),
UserID: user.ID,
UploaderIP: ip,
SkylinkID: skylink.ID,
Timestamp: time.Now().UTC(),
}
ior, err := db.staticUploads.InsertOne(ctx, up)
if err != nil {
Expand Down
3 changes: 3 additions & 0 deletions database/user.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,9 @@ const (
)

var (
// AnonUser is a helper struct that we can use when we don't have a relevant
// user, e.g. when an upload is made by an anonymous user.
AnonUser = User{}
// True is a helper for when we need to pass a *bool to MongoDB.
True = true
// False is a helper for when we need to pass a *bool to MongoDB.
Expand Down
4 changes: 2 additions & 2 deletions test/api/api_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -193,14 +193,14 @@ func TestUserTierCache(t *testing.T) {
}
// Register a test upload that exceeds the user's allowed storage, so their
// QuotaExceeded flag will get raised.
sl, _, err := test.CreateTestUpload(at.Ctx, at.DB, u.User, database.UserLimits[u.Tier].Storage+1)
sl, _, err := test.CreateTestUpload(at.Ctx, at.DB, *u.User, database.UserLimits[u.Tier].Storage+1)
if err != nil {
t.Fatal(err)
}
// Make a specific call to trackUploadPOST in order to trigger the
// checkUserQuotas method. This wil register the upload a second time but
// that doesn't affect the test.
_, err = at.TrackUpload(sl.Skylink)
_, err = at.TrackUpload(sl.Skylink, "")
if err != nil {
t.Fatal(err)
}
Expand Down
6 changes: 3 additions & 3 deletions test/api/apikeys_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ func testPrivateAPIKeysUsage(t *testing.T, at *test.AccountsTester) {
t.Fatal(err)
}
uploadSize := int64(fastrand.Intn(int(modules.SectorSize / 2)))
_, _, err = test.CreateTestUpload(at.Ctx, at.DB, u, uploadSize)
_, _, err = test.CreateTestUpload(at.Ctx, at.DB, *u, uploadSize)
if err != nil {
t.Fatal(err)
}
Expand Down Expand Up @@ -242,11 +242,11 @@ func testPublicAPIKeysUsage(t *testing.T, at *test.AccountsTester) {
t.Fatal(err)
}
uploadSize := int64(fastrand.Intn(int(modules.SectorSize / 2)))
sl, _, err := test.CreateTestUpload(at.Ctx, at.DB, u, uploadSize)
sl, _, err := test.CreateTestUpload(at.Ctx, at.DB, *u, uploadSize)
if err != nil {
t.Fatal(err)
}
sl2, _, err := test.CreateTestUpload(at.Ctx, at.DB, u, uploadSize)
sl2, _, err := test.CreateTestUpload(at.Ctx, at.DB, *u, uploadSize)
if err != nil {
t.Fatal(err)
}
Expand Down
38 changes: 21 additions & 17 deletions test/api/handlers_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -371,11 +371,11 @@ func testUserDELETE(t *testing.T, at *test.AccountsTester) {
t.Fatal("Failed to create a user and log in:", err)
}
// Create some data for this user.
sl, _, err := test.CreateTestUpload(at.Ctx, at.DB, u.User, 128)
sl, _, err := test.CreateTestUpload(at.Ctx, at.DB, *u.User, 128)
if err != nil {
t.Fatal(err)
}
err = at.DB.DownloadCreate(at.Ctx, *u.User, *sl, 128)
_, err = at.DB.DownloadCreate(at.Ctx, *u.User, *sl, 128)
if err != nil {
t.Fatal(err)
}
Expand Down Expand Up @@ -505,14 +505,14 @@ func testUserLimits(t *testing.T, at *test.AccountsTester) {
// anonymous levels. Their tier should remain Free.
dbu2 := *u2.User
filesize := database.UserLimits[database.TierFree].Storage + 1
sl, _, err := test.CreateTestUpload(at.Ctx, at.DB, &dbu2, filesize)
sl, _, err := test.CreateTestUpload(at.Ctx, at.DB, dbu2, filesize)
if err != nil {
t.Fatal(err)
}
// Make a specific call to trackUploadPOST in order to trigger the
// checkUserQuotas method. This wil register the upload a second time but
// that doesn't affect the test.
_, err = at.TrackUpload(sl.Skylink)
_, err = at.TrackUpload(sl.Skylink, "")
if err != nil {
t.Fatal(err)
}
Expand Down Expand Up @@ -588,7 +588,7 @@ func testUserUploadsDELETE(t *testing.T, at *test.AccountsTester) {
defer at.ClearCredentials()

// Create an upload.
skylink, _, err := test.CreateTestUpload(at.Ctx, at.DB, u.User, 128%skynet.KB)
skylink, _, err := test.CreateTestUpload(at.Ctx, at.DB, *u.User, 128%skynet.KB)
// Make sure it shows up for this user.
_, b, err := at.Get("/user/uploads", nil)
if err != nil {
Expand Down Expand Up @@ -915,20 +915,22 @@ func testTrackingAndStats(t *testing.T, at *test.AccountsTester) {
}
expectedStats := database.UserStats{}

// Call trackUpload without a cookie.
// Call trackUpload without a cookie. We expect this to succeed.
// While we expect this to succeed, it won't be counted towards the user's
// quota, so we don't increment the expected stats.
at.ClearCredentials()
_, err = at.TrackUpload(skylink.String())
if err == nil || !strings.Contains(err.Error(), unauthorized) {
t.Fatalf("Expected error '%s', got '%v'", unauthorized, err)
_, err = at.TrackUpload(skylink.String(), "")
if err != nil {
t.Fatal(err)
}
at.SetCookie(c)
// Call trackUpload with an invalid skylink.
_, err = at.TrackUpload("INVALID_SKYLINK")
_, err = at.TrackUpload("INVALID_SKYLINK", "")
if err == nil || !strings.Contains(err.Error(), badRequest) {
t.Fatalf("Expected '%s', got '%v'", badRequest, err)
}
// Call trackUpload with a valid skylink.
_, err = at.TrackUpload(skylink.String())
_, err = at.TrackUpload(skylink.String(), "")
if err != nil {
t.Fatal(err)
}
Expand All @@ -938,11 +940,13 @@ func testTrackingAndStats(t *testing.T, at *test.AccountsTester) {
expectedStats.BandwidthUploads += skynet.BandwidthUploadCost(0)
expectedStats.RawStorageUsed += skynet.RawStorageUsed(0)

// Call trackDownload without a cookie.
// Call trackDownload without a cookie. Expect this to succeed.
// While we expect this to succeed, it won't be counted towards the user's
// quota, so we don't increment the expected stats.
at.ClearCredentials()
_, err = at.TrackDownload(skylink.String(), 100)
if err == nil || !strings.Contains(err.Error(), unauthorized) {
t.Fatalf("Expected error '%s', got '%v'", unauthorized, err)
if err != nil {
t.Fatal(err)
}
at.SetCookie(c)
// Call trackDownload with an invalid skylink.
Expand All @@ -956,14 +960,14 @@ func testTrackingAndStats(t *testing.T, at *test.AccountsTester) {
t.Fatalf("Expected '%s', got '%v'", badRequest, err)
}
// Call trackDownload with a valid skylink.
_, err = at.TrackDownload(skylink.String(), 100)
_, err = at.TrackDownload(skylink.String(), 200)
if err != nil {
t.Fatal(err)
}
// Adjust the expectations.
expectedStats.NumDownloads++
expectedStats.BandwidthDownloads += skynet.BandwidthDownloadCost(100)
expectedStats.TotalDownloadsSize += 100
expectedStats.BandwidthDownloads += skynet.BandwidthDownloadCost(200)
expectedStats.TotalDownloadsSize += 200

// Call trackRegistryRead without a cookie.
at.ClearCredentials()
Expand Down
Loading