Skip to content

Commit

Permalink
Merge pull request #216 from SkynetLabs/ivo/email_type
Browse files Browse the repository at this point in the history
Introduce an Email type that handles capitalization.
  • Loading branch information
ro-tex committed Jun 8, 2022
2 parents 4cd88bb + 0dcf2b3 commit ad51e76
Show file tree
Hide file tree
Showing 17 changed files with 252 additions and 108 deletions.
5 changes: 3 additions & 2 deletions api/auth_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import (

"github.com/SkynetLabs/skynet-accounts/database"
"github.com/SkynetLabs/skynet-accounts/jwt"
"github.com/SkynetLabs/skynet-accounts/types"
"github.com/sirupsen/logrus"
"gitlab.com/NebulousLabs/errors"
"gitlab.com/NebulousLabs/fastrand"
Expand Down Expand Up @@ -47,7 +48,7 @@ func TestTokenFromRequest(t *testing.T) {
if err != nil {
t.Fatal(err)
}
tk, err := jwt.TokenForUser(t.Name()+"@siasky.net", t.Name()+"_sub")
tk, err := jwt.TokenForUser(types.NewEmail(t.Name()+"@siasky.net"), t.Name()+"_sub")
if err != nil {
t.Fatal(err)
}
Expand Down Expand Up @@ -97,7 +98,7 @@ func TestTokenFromRequest(t *testing.T) {

// Token from request with a header and a cookie. Expect the header to take
// precedence.
tk2, err := jwt.TokenForUser(t.Name()+"2@siasky.net", t.Name()+"2_sub")
tk2, err := jwt.TokenForUser(types.NewEmail(t.Name()+"2@siasky.net"), t.Name()+"2_sub")
if err != nil {
t.Fatal(err)
}
Expand Down
33 changes: 17 additions & 16 deletions api/handlers.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ import (
"github.com/SkynetLabs/skynet-accounts/lib"
"github.com/SkynetLabs/skynet-accounts/metafetcher"
"github.com/SkynetLabs/skynet-accounts/skynet"
"github.com/SkynetLabs/skynet-accounts/types"
"github.com/julienschmidt/httprouter"
jwt2 "github.com/lestrrat-go/jwx/jwt"
"gitlab.com/NebulousLabs/errors"
Expand Down Expand Up @@ -125,16 +126,16 @@ type (

// credentialsPOST defines the standard credentials package we expect.
credentialsPOST struct {
Email string `json:"email"`
Password string `json:"password"`
Email types.Email `json:"email"`
Password string `json:"password"`
}

// userUpdatePUT defines the fields of the User record that can be changed
// externally, e.g. by calling `PUT /user`.
userUpdatePUT struct {
Email string `json:"email,omitempty"`
Password string `json:"password,omitempty"`
StripeID string `json:"stripeCustomerId,omitempty"`
Email types.Email `json:"email,omitempty"`
Password string `json:"password,omitempty"`
StripeID string `json:"stripeCustomerId,omitempty"`
}
)

Expand Down Expand Up @@ -231,7 +232,7 @@ func (api *API) loginPOSTChallengeResponse(w http.ResponseWriter, req *http.Requ
}

// loginPOSTCredentials is a helper that handles logins with credentials.
func (api *API) loginPOSTCredentials(w http.ResponseWriter, req *http.Request, email, password string) {
func (api *API) loginPOSTCredentials(w http.ResponseWriter, req *http.Request, email types.Email, password string) {
// Fetch the user with that email, if they exist.
u, err := api.staticDB.UserByEmail(req.Context(), email)
if err != nil {
Expand Down Expand Up @@ -388,8 +389,8 @@ func (api *API) registerPOST(_ *database.User, w http.ResponseWriter, req *http.
api.WriteError(w, errors.AddContext(err, "failed to parse request body"), http.StatusBadRequest)
return
}
parsed, err := mail.ParseAddress(payload.Email)
if err != nil || payload.Email != parsed.Address {
parsed, err := mail.ParseAddress(payload.Email.String())
if err != nil || payload.Email.String() != parsed.Address {
api.WriteError(w, errors.New("invalid email provided"), http.StatusBadRequest)
return
}
Expand Down Expand Up @@ -616,8 +617,8 @@ func (api *API) userPOST(_ *database.User, w http.ResponseWriter, req *http.Requ
api.WriteError(w, errors.New("email is required"), http.StatusBadRequest)
return
}
parsed, err := mail.ParseAddress(payload.Email)
if err != nil || payload.Email != parsed.Address {
parsed, err := mail.ParseAddress(payload.Email.String())
if err != nil || payload.Email.String() != parsed.Address {
api.WriteError(w, errors.New("invalid email provided"), http.StatusBadRequest)
return
}
Expand Down Expand Up @@ -714,8 +715,8 @@ func (api *API) userPUT(u *database.User, w http.ResponseWriter, req *http.Reque

var changedEmail bool
if payload.Email != "" {
parsed, err := mail.ParseAddress(payload.Email)
if err != nil || payload.Email != parsed.Address {
parsed, err := mail.ParseAddress(payload.Email.String())
if err != nil || payload.Email.String() != parsed.Address {
api.WriteError(w, errors.New("invalid email provided"), http.StatusBadRequest)
return
}
Expand Down Expand Up @@ -995,10 +996,10 @@ func (api *API) userRecoverRequestPOST(_ *database.User, w http.ResponseWriter,
return
}

// Read and parse the request body.
var payload struct {
Email string `json:"email"`
}
// Read and parse the request body. We do not expect a password but we want
// to use the same email parsing approach in all cases where we get an email
// address from the user.
var payload credentialsPOST
err = parseRequestBodyJSON(req.Body, LimitBodySizeSmall, &payload)
if err != nil {
err = errors.AddContext(err, "failed to parse request body")
Expand Down
3 changes: 2 additions & 1 deletion api/upload.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"time"

"github.com/SkynetLabs/skynet-accounts/database"
"github.com/SkynetLabs/skynet-accounts/types"
"github.com/julienschmidt/httprouter"
"gitlab.com/NebulousLabs/errors"
"go.mongodb.org/mongo-driver/bson/primitive"
Expand All @@ -14,7 +15,7 @@ type (
// UploaderInfo gives information about a user who created an upload.
UploaderInfo struct {
UserID primitive.ObjectID
Email string
Email types.Email
Sub string
StripeID string
}
Expand Down
21 changes: 11 additions & 10 deletions database/user.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import (
"github.com/SkynetLabs/skynet-accounts/hash"
"github.com/SkynetLabs/skynet-accounts/lib"
"github.com/SkynetLabs/skynet-accounts/skynet"
"github.com/SkynetLabs/skynet-accounts/types"
"gitlab.com/NebulousLabs/errors"
"gitlab.com/SkynetLabs/skyd/build"
"go.mongodb.org/mongo-driver/bson"
Expand Down Expand Up @@ -109,7 +110,7 @@ type (
// ID is auto-generated by Mongo on insert. We will usually use it in
// its ID.Hex() form.
ID primitive.ObjectID `bson:"_id,omitempty" json:"-"`
Email string `bson:"email" json:"email"`
Email types.Email `bson:"email" json:"email"`
EmailConfirmationToken string `bson:"email_confirmation_token,omitempty" json:"-"`
EmailConfirmationTokenExpiration time.Time `bson:"email_confirmation_token_expiration,omitempty" json:"-"`
PasswordHash string `bson:"password_hash" json:"-"`
Expand Down Expand Up @@ -140,8 +141,8 @@ type (
)

// UserByEmail returns the user with the given username.
func (db *DB) UserByEmail(ctx context.Context, email string) (*User, error) {
users, err := db.managedUsersByField(ctx, "email", email)
func (db *DB) UserByEmail(ctx context.Context, email types.Email) (*User, error) {
users, err := db.managedUsersByField(ctx, "email", email.String())
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -263,14 +264,14 @@ func (db *DB) UserConfirmEmail(ctx context.Context, token string) (*User, error)
//
// The new user is created as "unconfirmed" and a confirmation email is sent to
// the address they provided.
func (db *DB) UserCreate(ctx context.Context, emailAddr, pass, sub string, tier int) (*User, error) {
func (db *DB) UserCreate(ctx context.Context, emailAddr types.Email, pass, sub string, tier int) (*User, error) {
// Ensure the email is valid if it's passed. We allow empty emails.
if emailAddr != "" {
addr, err := mail.ParseAddress(emailAddr)
addr, err := mail.ParseAddress(emailAddr.String())
if err != nil {
return nil, errors.AddContext(err, "invalid email address")
}
emailAddr = addr.Address
emailAddr = types.NewEmail(addr.Address)
}
if sub == "" {
return nil, errors.New("empty sub is not allowed")
Expand Down Expand Up @@ -367,14 +368,14 @@ func (db *DB) UserCreateEmailConfirmation(ctx context.Context, uID primitive.Obj
//
// The new user is created as "unconfirmed" and a confirmation email is sent to
// the address they provided.
func (db *DB) UserCreatePK(ctx context.Context, emailAddr, pass, sub string, pk PubKey, tier int) (*User, error) {
func (db *DB) UserCreatePK(ctx context.Context, emailAddr types.Email, pass, sub string, pk PubKey, tier int) (*User, error) {
// Validate the email.
parsed, err := mail.ParseAddress(emailAddr)
if err != nil || parsed.Address != emailAddr {
parsed, err := mail.ParseAddress(emailAddr.String())
if err != nil || parsed.Address != emailAddr.String() {
return nil, errors.AddContext(err, "invalid email address")
}
// Check for an existing user with this email.
users, err := db.managedUsersByField(ctx, "email", emailAddr)
users, err := db.managedUsersByField(ctx, "email", emailAddr.String())
if err != nil && !errors.Contains(err, ErrUserNotFound) {
return nil, errors.AddContext(err, "failed to query DB")
}
Expand Down
13 changes: 7 additions & 6 deletions email/mailer.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"context"

"github.com/SkynetLabs/skynet-accounts/database"
"github.com/SkynetLabs/skynet-accounts/types"
)

/**
Expand Down Expand Up @@ -35,15 +36,15 @@ func (em Mailer) Send(ctx context.Context, m database.EmailMessage) error {

// SendAddressConfirmationEmail sends a new email to the given email address
// with a link to confirm the ownership of the address.
func (em Mailer) SendAddressConfirmationEmail(ctx context.Context, email, token string) error {
m := confirmEmailEmail(email, token)
func (em Mailer) SendAddressConfirmationEmail(ctx context.Context, email types.Email, token string) error {
m := confirmEmailEmail(email.String(), token)
return em.Send(ctx, *m)
}

// SendRecoverAccountEmail sends a new email to the given email address
// with a link to recover the account.
func (em Mailer) SendRecoverAccountEmail(ctx context.Context, email, token string) error {
m := recoverAccountEmail(email, token)
func (em Mailer) SendRecoverAccountEmail(ctx context.Context, email types.Email, token string) error {
m := recoverAccountEmail(email.String(), token)
return em.Send(ctx, *m)
}

Expand All @@ -52,7 +53,7 @@ func (em Mailer) SendRecoverAccountEmail(ctx context.Context, email, token strin
// recover a Skynet account but their email is not in our system. The main
// reason to do that is because the user might have forgotten which email they
// used for signing up.
func (em Mailer) SendAccountAccessAttemptedEmail(ctx context.Context, email string) error {
m := accountAccessAttemptedEmail(email)
func (em Mailer) SendAccountAccessAttemptedEmail(ctx context.Context, email types.Email) error {
m := accountAccessAttemptedEmail(email.String())
return em.Send(ctx, *m)
}
7 changes: 4 additions & 3 deletions jwt/jwt.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"io/ioutil"
"time"

"github.com/SkynetLabs/skynet-accounts/types"
"github.com/lestrrat-go/jwx/jwa"
"github.com/lestrrat-go/jwx/jwk"
"github.com/lestrrat-go/jwx/jwt"
Expand Down Expand Up @@ -74,7 +75,7 @@ func ContextWithToken(ctx context.Context, token jwt.Token) context.Context {
//
// The tokens generated by this function are a slimmed down version of the ones
// described in ValidateToken's docstring.
func TokenForUser(email, sub string) (jwt.Token, error) {
func TokenForUser(email types.Email, sub string) (jwt.Token, error) {
sigAlgo, key, err := signatureAlgoAndKey()
if err != nil {
return nil, err
Expand Down Expand Up @@ -252,15 +253,15 @@ func signatureAlgoAndKey() (jwa.SignatureAlgorithm, jwk.Key, error) {

// tokenForUser is a helper method that puts together an unsigned token based
// on the provided values.
func tokenForUser(emailAddr, sub string) (jwt.Token, error) {
func tokenForUser(emailAddr types.Email, sub string) (jwt.Token, error) {
if emailAddr == "" || sub == "" {
return nil, errors.New("email and sub cannot be empty")
}
session := tokenSession{
Active: true,
Identity: tokenIdentity{
Traits: tokenTraits{
Email: emailAddr,
Email: emailAddr.String(),
},
},
}
Expand Down
7 changes: 4 additions & 3 deletions jwt/jwt_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"testing"
"time"

"github.com/SkynetLabs/skynet-accounts/types"
"github.com/lestrrat-go/jwx/jwa"
"github.com/lestrrat-go/jwx/jwt"
"github.com/sirupsen/logrus"
Expand All @@ -19,7 +20,7 @@ func TestJWT(t *testing.T) {
if err != nil {
t.Fatal(err)
}
email := t.Name() + "@siasky.net"
email := types.NewEmail(t.Name() + "@siasky.net")
sub := "this is a sub"
fakeSub := "fake sub"
tk, err := TokenForUser(email, sub)
Expand Down Expand Up @@ -59,7 +60,7 @@ func TestValidateToken_Expired(t *testing.T) {
if err != nil {
t.Fatal(err)
}
email := t.Name() + "@siasky.net"
email := types.NewEmail(t.Name() + "@siasky.net")
sub := "this is a sub"
// Fetch the tools we need in order to craft a custom token.
key, found := AccountsJWKS.Get(0)
Expand All @@ -81,7 +82,7 @@ func TestValidateToken_Expired(t *testing.T) {
Active: true,
Identity: tokenIdentity{
Traits: tokenTraits{
Email: email,
Email: email.String(),
},
},
}
Expand Down
13 changes: 7 additions & 6 deletions test/api/api_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import (
"github.com/SkynetLabs/skynet-accounts/api"
"github.com/SkynetLabs/skynet-accounts/database"
"github.com/SkynetLabs/skynet-accounts/test"
"github.com/SkynetLabs/skynet-accounts/types"
"gitlab.com/NebulousLabs/fastrand"
"go.sia.tech/siad/build"

Expand All @@ -33,9 +34,9 @@ func TestWithDBSession(t *testing.T) {
t.Fatal("Failed to instantiate API.", err)
}

emailSuccess := t.Name() + "success@siasky.net"
emailSuccessJSON := t.Name() + "success_json@siasky.net"
emailFailure := t.Name() + "failure@siasky.net"
emailSuccess := types.NewEmail(t.Name() + "success@siasky.net")
emailSuccessJSON := types.NewEmail(t.Name() + "success_json@siasky.net")
emailFailure := types.NewEmail(t.Name() + "failure@siasky.net")

// This handler successfully creates a user in the DB and exits with
// a success status code. We expect the user to exist in the DB after
Expand All @@ -52,7 +53,7 @@ func TestWithDBSession(t *testing.T) {
t.Fatal("Failed to fetch user from DB.", err)
}
if u.Email != emailSuccess {
t.Fatalf("Expected email %s, got %s.", emailSuccess, u.Email)
t.Fatalf("Expected email '%v', got '%v'.", emailSuccess, u.Email)
}
testAPI.WriteSuccess(w)
}
Expand Down Expand Up @@ -147,7 +148,7 @@ func TestUserTierCache(t *testing.T) {
}
}()

emailAddr := test.DBNameForTest(t.Name()) + "@siasky.net"
emailAddr := types.NewEmail(test.DBNameForTest(t.Name()) + "@siasky.net")
password := hex.EncodeToString(fastrand.Bytes(16))
u, err := test.CreateUser(at, emailAddr, password)
if err != nil {
Expand All @@ -165,7 +166,7 @@ func TestUserTierCache(t *testing.T) {
if err != nil {
t.Fatal(err)
}
r, _, err := at.LoginCredentialsPOST(emailAddr, password)
r, _, err := at.LoginCredentialsPOST(emailAddr.String(), password)
if err != nil {
t.Fatal(err)
}
Expand Down
Loading

0 comments on commit ad51e76

Please sign in to comment.