Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Anonymous uploads + track upload IPs #167

Merged
merged 9 commits into from
Mar 25, 2022
70 changes: 65 additions & 5 deletions api/handlers.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"net/http"
"net/mail"
"net/url"
"regexp"
"strconv"
"strings"
"time"
Expand All @@ -20,6 +21,7 @@ import (
"github.com/SkynetLabs/skynet-accounts/metafetcher"
"github.com/SkynetLabs/skynet-accounts/skynet"
"github.com/julienschmidt/httprouter"
jwt2 "github.com/lestrrat-go/jwx/jwt"
"gitlab.com/NebulousLabs/errors"
"go.mongodb.org/mongo-driver/mongo"
)
Expand Down Expand Up @@ -1062,7 +1064,7 @@ func (api *API) userRecoverPOST(_ *database.User, w http.ResponseWriter, req *ht
}

// trackUploadPOST registers a new upload in the system.
func (api *API) trackUploadPOST(u *database.User, w http.ResponseWriter, req *http.Request, ps httprouter.Params) {
func (api *API) trackUploadPOST(_ *database.User, w http.ResponseWriter, req *http.Request, ps httprouter.Params) {
sl := ps.ByName("skylink")
if sl == "" {
api.WriteError(w, errors.New("missing parameter 'skylink'"), http.StatusBadRequest)
Expand All @@ -1077,7 +1079,14 @@ func (api *API) trackUploadPOST(u *database.User, w http.ResponseWriter, req *ht
api.WriteError(w, err, http.StatusInternalServerError)
return
}
_, err = api.staticDB.UploadCreate(req.Context(), *u, *skylink)
u, _, err := api.userFromRequest(req)
if err != nil {
// This will be tracked as an anonymous request.
// The assignment below is redundant but adds clarity.
u = nil
}
ip := validateIP(req.Form.Get("ip"))
ro-tex marked this conversation as resolved.
Show resolved Hide resolved
_, err = api.staticDB.UploadCreate(req.Context(), u, ip, *skylink)
if err != nil {
api.WriteError(w, err, http.StatusInternalServerError)
return
Expand All @@ -1096,11 +1105,13 @@ func (api *API) trackUploadPOST(u *database.User, w http.ResponseWriter, req *ht
// administrative details, such as user's quotas check.
// Note that this call is not affected by the request's context, so we use
// a separate one.
go api.checkUserQuotas(context.Background(), u)
if u != nil && !u.ID.IsZero() {
go api.checkUserQuotas(context.Background(), u)
}
}

// trackDownloadPOST registers a new download in the system.
func (api *API) trackDownloadPOST(u *database.User, w http.ResponseWriter, req *http.Request, ps httprouter.Params) {
func (api *API) trackDownloadPOST(_ *database.User, w http.ResponseWriter, req *http.Request, ps httprouter.Params) {
ro-tex marked this conversation as resolved.
Show resolved Hide resolved
err := req.ParseForm()
if err != nil {
api.WriteError(w, err, http.StatusBadRequest)
Expand Down Expand Up @@ -1136,7 +1147,13 @@ func (api *API) trackDownloadPOST(u *database.User, w http.ResponseWriter, req *
api.WriteError(w, err, http.StatusInternalServerError)
return
}
err = api.staticDB.DownloadCreate(req.Context(), *u, *skylink, downloadedBytes)
u, _, err := api.userFromRequest(req)
if err != nil {
// This will be tracked as an anonymous request.
ro-tex marked this conversation as resolved.
Show resolved Hide resolved
// The assignment below is redundant but adds clarity.
u = nil
}
err = api.staticDB.DownloadCreate(req.Context(), u, *skylink, downloadedBytes)
if err != nil {
api.WriteError(w, err, http.StatusInternalServerError)
return
Expand Down Expand Up @@ -1225,6 +1242,25 @@ func (api *API) checkUserQuotas(ctx context.Context, u *database.User) {
}
}

// userFromRequest checks the requests for various forms of authentication (API
// key, cookie, authorization header) and returns user information based on
// those.
func (api *API) userFromRequest(req *http.Request) (*database.User, jwt2.Token, error) {
// Check for an API key.
u, tk, err := api.userAndTokenByAPIKey(req)
if err != nil && !errors.Contains(err, ErrNoAPIKey) {
ro-tex marked this conversation as resolved.
Show resolved Hide resolved
return nil, nil, err
}
// If there is no API key check for a token.
if errors.Contains(err, ErrNoAPIKey) {
u, tk, err = api.userAndTokenByRequestToken(req)
if err != nil {
return nil, nil, err
}
ro-tex marked this conversation as resolved.
Show resolved Hide resolved
}
return u, tk, err
}

// wellKnownJWKSGET returns our public JWKS, so people can use that to verify
// the authenticity of the JWT tokens we issue.
func (api *API) wellKnownJWKSGET(_ *database.User, w http.ResponseWriter, _ *http.Request, _ httprouter.Params) {
Expand Down Expand Up @@ -1302,3 +1338,27 @@ func userLimitsGetFromTier(tierID int, quotaExceeded, inBytes bool) *UserLimitsG
RegistryDelay: limitsTier.RegistryDelay,
}
}

// validateIP is a simple pass-through helper that returns valid IPs as they are
// and returns an empty string for invalid IPs.
func validateIP(ip string) string {
ro-tex marked this conversation as resolved.
Show resolved Hide resolved
ip = strings.ToLower(ip)
reV4 := regexp.MustCompile("^(\\d{1,3})\\.(\\d{1,3})\\.(\\d{1,3})\\.(\\d{1,3})$")
if reV4.MatchString(ip) {
submatches := reV4.FindAllStringSubmatch(ip, -1)
for i := 1; i < len(submatches[0]); i++ {
n, err := strconv.Atoi(submatches[0][i])
if err != nil || n < 0 || n > 255 {
return ""
}
}
return ip
}

// see https://stackoverflow.com/questions/53497/regular-expression-that-matches-valid-ipv6-addresses
reV6 := regexp.MustCompile("^(([0-9a-f]{1,4}:){7,7}[0-9a-f]{1,4}|([0-9a-f]{1,4}:){1,7}:|([0-9a-f]{1,4}:){1,6}:[0-9a-f]{1,4}|([0-9a-f]{1,4}:){1,5}(:[0-9a-f]{1,4}){1,2}|([0-9a-f]{1,4}:){1,4}(:[0-9a-f]{1,4}){1,3}|([0-9a-f]{1,4}:){1,3}(:[0-9a-f]{1,4}){1,4}|([0-9a-f]{1,4}:){1,2}(:[0-9a-f]{1,4}){1,5}|[0-9a-f]{1,4}:((:[0-9a-f]{1,4}){1,6})|:((:[0-9a-f]{1,4}){1,7}|:)|fe80:(:[0-9a-f]{0,4}){0,4}%[0-9a-zA-Z]{1,}|::(ffff(:0{1,4}){0,1}:){0,1}((25[0-5]|(2[0-4]|1{0,1}[0-9]){0,1}[0-9])\\.){3,3}(25[0-5]|(2[0-4]|1{0,1}[0-9]){0,1}[0-9])|([0-9a-f]{1,4}:){1,4}:((25[0-5]|(2[0-4]|1{0,1}[0-9]){0,1}[0-9])\\.){3,3}(25[0-5]|(2[0-4]|1{0,1}[0-9]){0,1}[0-9]))$")
if reV6.MatchString(ip) {
return ip
}
return ""
}
27 changes: 27 additions & 0 deletions api/handlers_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -128,3 +128,30 @@ func TestUserLimitsGetFromTier(t *testing.T) {
t.Fatal(err)
}
}

// TestValidateIP ensures that validateIP works as expected for both IPv4 and
// IPv6 IP addresses.
func TestValidateIP(t *testing.T) {
tests := []struct {
in string
expected string
}{
{in: "", expected: ""},
{in: "12.12.12", expected: ""},
{in: "1.2.3.256", expected: ""},
{in: "1.2.3.4", expected: "1.2.3.4"},
{in: "0.0.0.0", expected: "0.0.0.0"},
{in: "2001:0db8:85a3:0000:0000:8a2e:0370:7334", expected: "2001:0db8:85a3:0000:0000:8a2e:0370:7334"},
{in: "FE80:0000:0000:0000:0202:B3FF:FE1E:8329", expected: strings.ToLower("FE80:0000:0000:0000:0202:B3FF:FE1E:8329")},
{in: "2001:db8:0:0:0:ff00:42:8329", expected: "2001:db8:0:0:0:ff00:42:8329"},
{in: "2001:db8::ff00:42:8329", expected: "2001:db8::ff00:42:8329"},
{in: "::1", expected: "::1"},
}

for _, tt := range tests {
out := validateIP(tt.in)
if out != tt.expected {
t.Errorf("Expected '%s' to get me '%s' but got '%s'", tt.in, tt.expected, out)
}
}
}
24 changes: 5 additions & 19 deletions api/routes.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,8 @@ func (api *API) buildHTTPRoutes() {
api.staticRouter.POST("/register", api.WithDBSession(api.noAuth(api.registerPOST)))

// Endpoints at which Nginx reports portal usage.
api.staticRouter.POST("/track/upload/:skylink", api.withAuth(api.trackUploadPOST))
api.staticRouter.POST("/track/download/:skylink", api.withAuth(api.trackDownloadPOST))
api.staticRouter.POST("/track/upload/:skylink", api.noAuth(api.trackUploadPOST))
api.staticRouter.POST("/track/download/:skylink", api.noAuth(api.trackDownloadPOST))
api.staticRouter.POST("/track/registry/read", api.withAuth(api.trackRegistryReadPOST))
api.staticRouter.POST("/track/registry/write", api.withAuth(api.trackRegistryWritePOST))

Expand Down Expand Up @@ -92,25 +92,11 @@ func (api *API) noAuth(h HandlerWithUser) httprouter.Handle {
func (api *API) withAuth(h HandlerWithUser) httprouter.Handle {
return func(w http.ResponseWriter, req *http.Request, ps httprouter.Params) {
api.logRequest(req)
// Check for an API key.
u, token, err := api.userAndTokenByAPIKey(req)
// If there is an unexpected error, that is a 500.
if err != nil && !errors.Contains(err, ErrNoAPIKey) && !errors.Contains(err, database.ErrInvalidAPIKey) && !errors.Contains(err, database.ErrUserNotFound) {
api.WriteError(w, err, http.StatusInternalServerError)
u, token, err := api.userFromRequest(req)
if err != nil {
api.WriteError(w, err, http.StatusUnauthorized)
return
}
if err != nil && (errors.Contains(err, database.ErrInvalidAPIKey) || errors.Contains(err, database.ErrUserNotFound)) {
api.WriteError(w, errors.AddContext(err, "failed to fetch user by API key"), http.StatusUnauthorized)
return
}
// If there is no API key check for a token.
if errors.Contains(err, ErrNoAPIKey) {
u, token, err = api.userAndTokenByRequestToken(req)
if err != nil {
api.WriteError(w, err, http.StatusUnauthorized)
return
}
}
// Embed the verified token in the context of the request.
ctx := jwt.ContextWithToken(req.Context(), token)
h(u, w, req.WithContext(ctx), ps)
Expand Down
13 changes: 8 additions & 5 deletions database/download.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,17 +51,19 @@ func (db *DB) DownloadByID(ctx context.Context, id primitive.ObjectID) (*Downloa

// DownloadCreate registers a new download. Marks partial downloads by supplying
// the `bytes` param. If `bytes` is 0 we assume a full download.
func (db *DB) DownloadCreate(ctx context.Context, user User, skylink Skylink, bytes int64) error {
if user.ID.IsZero() {
return errors.New("invalid user")
func (db *DB) DownloadCreate(ctx context.Context, user *User, skylink Skylink, bytes int64) error {
if user == nil {
// If there is no user passed, we initialise it with the zero ID in
// order to denote an anonymous download.
user = &User{ID: primitive.ObjectID{}}
}
if skylink.ID.IsZero() {
return ErrInvalidSkylink
}

// Check if there exists a download of this skylink by this user, updated
// within the DownloadUpdateWindow and keep updating that, if so.
down, err := db.DownloadRecent(ctx, skylink.ID)
down, err := db.DownloadRecent(ctx, user.ID, skylink.ID)
if err == nil {
// We found a recent download of this skylink. Let's update it.
return db.DownloadIncrement(ctx, down, bytes)
Expand Down Expand Up @@ -126,9 +128,10 @@ func (db *DB) downloadsBy(ctx context.Context, matchStage bson.D, offset, pageSi
}

// DownloadRecent returns the most recent download of the given skylink.
func (db *DB) DownloadRecent(ctx context.Context, skylinkID primitive.ObjectID) (*Download, error) {
func (db *DB) DownloadRecent(ctx context.Context, uID primitive.ObjectID, skylinkID primitive.ObjectID) (*Download, error) {
updatedAtThreshold := time.Now().UTC().Add(-1 * DownloadUpdateWindow)
filter := bson.D{
{"user_id", uID},
{"skylink_id", skylinkID},
{"updated_at", bson.D{{"$gt", updatedAtThreshold}}},
}
Expand Down
10 changes: 7 additions & 3 deletions database/upload.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ import (
type Upload struct {
ID primitive.ObjectID `bson:"_id,omitempty" json:"id"`
UserID primitive.ObjectID `bson:"user_id,omitempty" json:"userId"`
IP string `bson:"ip" json:"IP"`
ro-tex marked this conversation as resolved.
Show resolved Hide resolved
SkylinkID primitive.ObjectID `bson:"skylink_id,omitempty" json:"skylinkId"`
Timestamp time.Time `bson:"timestamp" json:"timestamp"`
Unpinned bool `bson:"unpinned" json:"-"`
Expand Down Expand Up @@ -45,15 +46,18 @@ func (db *DB) UploadByID(ctx context.Context, id primitive.ObjectID) (*Upload, e

// UploadCreate registers a new upload and counts it towards the user's used
// storage.
func (db *DB) UploadCreate(ctx context.Context, user User, skylink Skylink) (*Upload, error) {
if user.ID.IsZero() {
return nil, errors.New("invalid user")
func (db *DB) UploadCreate(ctx context.Context, user *User, ip string, skylink Skylink) (*Upload, error) {
if user == nil {
// If there is no user passed, we initialise it with the zero ID in
// order to denote an anonymous upload.
user = &User{ID: primitive.ObjectID{}}
ro-tex marked this conversation as resolved.
Show resolved Hide resolved
}
if skylink.ID.IsZero() {
return nil, errors.New("skylink doesn't exist")
}
up := Upload{
UserID: user.ID,
IP: ip,
SkylinkID: skylink.ID,
Timestamp: time.Now().UTC(),
}
Expand Down
2 changes: 1 addition & 1 deletion test/api/api_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -200,7 +200,7 @@ func TestUserTierCache(t *testing.T) {
// Make a specific call to trackUploadPOST in order to trigger the
// checkUserQuotas method. This wil register the upload a second time but
// that doesn't affect the test.
_, err = at.TrackUpload(sl.Skylink)
_, err = at.TrackUpload(sl.Skylink, "")
if err != nil {
t.Fatal(err)
}
Expand Down
32 changes: 18 additions & 14 deletions test/api/handlers_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -375,7 +375,7 @@ func testUserDELETE(t *testing.T, at *test.AccountsTester) {
if err != nil {
t.Fatal(err)
}
err = at.DB.DownloadCreate(at.Ctx, *u.User, *sl, 128)
err = at.DB.DownloadCreate(at.Ctx, u.User, *sl, 128)
if err != nil {
t.Fatal(err)
}
Expand Down Expand Up @@ -512,7 +512,7 @@ func testUserLimits(t *testing.T, at *test.AccountsTester) {
// Make a specific call to trackUploadPOST in order to trigger the
// checkUserQuotas method. This wil register the upload a second time but
// that doesn't affect the test.
_, err = at.TrackUpload(sl.Skylink)
_, err = at.TrackUpload(sl.Skylink, "")
if err != nil {
t.Fatal(err)
}
Expand Down Expand Up @@ -915,20 +915,22 @@ func testTrackingAndStats(t *testing.T, at *test.AccountsTester) {
}
expectedStats := database.UserStats{}

// Call trackUpload without a cookie.
// Call trackUpload without a cookie. We expect this to succeed.
// While we expect this to succeed, it won't be counted towards the user's
// quota, so we don't increment the expected stats.
at.ClearCredentials()
_, err = at.TrackUpload(skylink.String())
if err == nil || !strings.Contains(err.Error(), unauthorized) {
t.Fatalf("Expected error '%s', got '%v'", unauthorized, err)
_, err = at.TrackUpload(skylink.String(), "")
if err != nil {
t.Fatal(err)
}
at.SetCookie(c)
// Call trackUpload with an invalid skylink.
_, err = at.TrackUpload("INVALID_SKYLINK")
_, err = at.TrackUpload("INVALID_SKYLINK", "")
if err == nil || !strings.Contains(err.Error(), badRequest) {
t.Fatalf("Expected '%s', got '%v'", badRequest, err)
}
// Call trackUpload with a valid skylink.
_, err = at.TrackUpload(skylink.String())
_, err = at.TrackUpload(skylink.String(), "")
if err != nil {
t.Fatal(err)
}
Expand All @@ -938,11 +940,13 @@ func testTrackingAndStats(t *testing.T, at *test.AccountsTester) {
expectedStats.BandwidthUploads += skynet.BandwidthUploadCost(0)
expectedStats.RawStorageUsed += skynet.RawStorageUsed(0)

// Call trackDownload without a cookie.
// Call trackDownload without a cookie. Expect this to succeed.
// While we expect this to succeed, it won't be counted towards the user's
// quota, so we don't increment the expected stats.
at.ClearCredentials()
_, err = at.TrackDownload(skylink.String(), 100)
if err == nil || !strings.Contains(err.Error(), unauthorized) {
t.Fatalf("Expected error '%s', got '%v'", unauthorized, err)
if err != nil {
t.Fatal(err)
}
at.SetCookie(c)
// Call trackDownload with an invalid skylink.
Expand All @@ -956,14 +960,14 @@ func testTrackingAndStats(t *testing.T, at *test.AccountsTester) {
t.Fatalf("Expected '%s', got '%v'", badRequest, err)
}
// Call trackDownload with a valid skylink.
_, err = at.TrackDownload(skylink.String(), 100)
_, err = at.TrackDownload(skylink.String(), 200)
if err != nil {
t.Fatal(err)
}
// Adjust the expectations.
expectedStats.NumDownloads++
expectedStats.BandwidthDownloads += skynet.BandwidthDownloadCost(100)
expectedStats.TotalDownloadsSize += 100
expectedStats.BandwidthDownloads += skynet.BandwidthDownloadCost(200)
expectedStats.TotalDownloadsSize += 200

// Call trackRegistryRead without a cookie.
at.ClearCredentials()
Expand Down
29 changes: 29 additions & 0 deletions test/database/download_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
package database

import (
"context"
"testing"

"github.com/SkynetLabs/skynet-accounts/database"
"github.com/SkynetLabs/skynet-accounts/test"
)

// TestDownloadCreateAnon ensures that UploadCreate can create anonymous downloads.
ro-tex marked this conversation as resolved.
Show resolved Hide resolved
func TestDownloadCreateAnon(t *testing.T) {
ctx := context.Background()
ro-tex marked this conversation as resolved.
Show resolved Hide resolved
dbName := test.DBNameForTest(t.Name())
db, err := database.NewCustomDB(ctx, dbName, test.DBTestCredentials(), nil)
if err != nil {
t.Fatal(err)
}
sl := test.RandomSkylink()
skylink, err := db.Skylink(ctx, sl)
if err != nil {
t.Fatal(err)
}
// Register an anonymous download.
err = db.DownloadCreate(ctx, nil, *skylink, 123)
ro-tex marked this conversation as resolved.
Show resolved Hide resolved
if err != nil {
t.Fatal(err)
}
}
Loading