Skip to content

Commit

Permalink
fix: move all EmailActionTypes to mailer package (#1510)
Browse files Browse the repository at this point in the history
## What kind of change does this PR introduce?

Moves the mail related refactors from #1496 into a new PR so the hook
related PR is easier to review. This PR moves all EmailActionTypes to
mailer package to establish an explicit link between mailer and the
packages it is used in.

The changes are cosmetic and should not affect underlying functionality.
  • Loading branch information
J0 authored Apr 2, 2024
1 parent 8243e35 commit 765db08
Show file tree
Hide file tree
Showing 7 changed files with 72 additions and 60 deletions.
15 changes: 8 additions & 7 deletions internal/api/mail.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package api

import (
mail "github.com/supabase/auth/internal/mailer"
"net/http"
"strings"
"time"
Expand Down Expand Up @@ -64,8 +65,8 @@ func (a *API) adminGenerateLink(w http.ResponseWriter, r *http.Request) error {
if err != nil {
if models.IsNotFoundError(err) {
switch params.Type {
case magicLinkVerification:
params.Type = signupVerification
case mail.MagicLinkVerification:
params.Type = mail.SignupVerification
params.Password, err = password.Generate(64, 10, 1, false, true)
if err != nil {
// password generation must always succeed
Expand All @@ -91,7 +92,7 @@ func (a *API) adminGenerateLink(w http.ResponseWriter, r *http.Request) error {
hashedToken := crypto.GenerateTokenHash(params.Email, otp)

var signupUser *models.User
if params.Type == signupVerification && user == nil {
if params.Type == mail.SignupVerification && user == nil {
signupParams := &SignupParams{
Email: params.Email,
Password: params.Password,
Expand All @@ -113,7 +114,7 @@ func (a *API) adminGenerateLink(w http.ResponseWriter, r *http.Request) error {
err = db.Transaction(func(tx *storage.Connection) error {
var terr error
switch params.Type {
case magicLinkVerification, recoveryVerification:
case mail.MagicLinkVerification, mail.RecoveryVerification:
if terr = models.NewAuditLogEntry(r, tx, user, models.UserRecoveryRequestedAction, "", nil); terr != nil {
return terr
}
Expand All @@ -123,7 +124,7 @@ func (a *API) adminGenerateLink(w http.ResponseWriter, r *http.Request) error {
if terr != nil {
terr = errors.Wrap(terr, "Database error updating user for recovery")
}
case inviteVerification:
case mail.InviteVerification:
if user != nil {
if user.IsConfirmed() {
return unprocessableEntityError(ErrorCodeEmailExists, DuplicateEmailMsg)
Expand Down Expand Up @@ -170,7 +171,7 @@ func (a *API) adminGenerateLink(w http.ResponseWriter, r *http.Request) error {
if terr != nil {
terr = errors.Wrap(terr, "Database error updating user for invite")
}
case signupVerification:
case mail.SignupVerification:
if user != nil {
if user.IsConfirmed() {
return unprocessableEntityError(ErrorCodeEmailExists, DuplicateEmailMsg)
Expand Down Expand Up @@ -202,7 +203,7 @@ func (a *API) adminGenerateLink(w http.ResponseWriter, r *http.Request) error {
if terr != nil {
terr = errors.Wrap(terr, "Database error updating user for confirmation")
}
case "email_change_current", "email_change_new":
case mail.EmailChangeCurrentVerification, mail.EmailChangeNewVerification:
if !config.Mailer.SecureEmailChangeEnabled && params.Type == "email_change_current" {
return badRequestError(ErrorCodeValidationFailed, "Enable secure email change to generate link for current email")
}
Expand Down
13 changes: 7 additions & 6 deletions internal/api/resend.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (

"github.com/supabase/auth/internal/api/sms_provider"
"github.com/supabase/auth/internal/conf"
mail "github.com/supabase/auth/internal/mailer"
"github.com/supabase/auth/internal/models"
"github.com/supabase/auth/internal/storage"
)
Expand All @@ -20,14 +21,14 @@ type ResendConfirmationParams struct {

func (p *ResendConfirmationParams) Validate(config *conf.GlobalConfiguration) error {
switch p.Type {
case signupVerification, emailChangeVerification, smsVerification, phoneChangeVerification:
case mail.SignupVerification, mail.EmailChangeVerification, smsVerification, phoneChangeVerification:
break
default:
// type does not match one of the above
return badRequestError(ErrorCodeValidationFailed, "Missing one of these types: signup, email_change, sms, phone_change")

}
if p.Email == "" && p.Type == signupVerification {
if p.Email == "" && p.Type == mail.SignupVerification {
return badRequestError(ErrorCodeValidationFailed, "Type provided requires an email address")
}
if p.Phone == "" && p.Type == smsVerification {
Expand Down Expand Up @@ -91,7 +92,7 @@ func (a *API) Resend(w http.ResponseWriter, r *http.Request) error {
}

switch params.Type {
case signupVerification:
case mail.SignupVerification:
if user.IsConfirmed() {
// if the user's email is confirmed already, we don't need to send a confirmation email again
return sendJSON(w, http.StatusOK, map[string]string{})
Expand All @@ -101,7 +102,7 @@ func (a *API) Resend(w http.ResponseWriter, r *http.Request) error {
// if the user's phone is confirmed already, we don't need to send a confirmation sms again
return sendJSON(w, http.StatusOK, map[string]string{})
}
case emailChangeVerification:
case mail.EmailChangeVerification:
// do not resend if user doesn't have a new email address
if user.EmailChange == "" {
return sendJSON(w, http.StatusOK, map[string]string{})
Expand All @@ -116,7 +117,7 @@ func (a *API) Resend(w http.ResponseWriter, r *http.Request) error {
messageID := ""
err = db.Transaction(func(tx *storage.Connection) error {
switch params.Type {
case signupVerification:
case mail.SignupVerification:
if terr := models.NewAuditLogEntry(r, tx, user, models.UserConfirmationRequestedAction, "", nil); terr != nil {
return terr
}
Expand All @@ -135,7 +136,7 @@ func (a *API) Resend(w http.ResponseWriter, r *http.Request) error {
return terr
}
messageID = mID
case emailChangeVerification:
case mail.EmailChangeVerification:
return a.sendEmailChange(r, tx, user, user.EmailChange, models.ImplicitFlow)
case phoneChangeVerification:
smsProvider, terr := sms_provider.GetSmsProvider(*config)
Expand Down
7 changes: 4 additions & 3 deletions internal/api/resend_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import (
"github.com/stretchr/testify/require"
"github.com/stretchr/testify/suite"
"github.com/supabase/auth/internal/conf"
mail "github.com/supabase/auth/internal/mailer"
"github.com/supabase/auth/internal/models"
)

Expand Down Expand Up @@ -194,15 +195,15 @@ func (ts *ResendTestSuite) TestResendSuccess() {
require.Equal(ts.T(), http.StatusOK, w.Code)

switch c.params["type"] {
case signupVerification, emailChangeVerification:
case mail.SignupVerification, mail.EmailChangeVerification:
dbUser, err := models.FindUserByID(ts.API.db, c.user.ID)
require.NoError(ts.T(), err)
require.NotEmpty(ts.T(), dbUser)

if c.params["type"] == signupVerification {
if c.params["type"] == mail.SignupVerification {
require.NotEqual(ts.T(), dbUser.ConfirmationToken, c.user.ConfirmationToken)
require.NotEqual(ts.T(), dbUser.ConfirmationSentAt, c.user.ConfirmationSentAt)
} else if c.params["type"] == emailChangeVerification {
} else if c.params["type"] == mail.EmailChangeVerification {
require.NotEqual(ts.T(), dbUser.EmailChangeTokenNew, c.user.EmailChangeTokenNew)
require.NotEqual(ts.T(), dbUser.EmailChangeSentAt, c.user.EmailChangeSentAt)
}
Expand Down
3 changes: 2 additions & 1 deletion internal/api/signup_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"bytes"
"encoding/json"
"fmt"
mail "github.com/supabase/auth/internal/mailer"
"net/http"
"net/http/httptest"
"net/url"
Expand Down Expand Up @@ -132,7 +133,7 @@ func (ts *SignupTestSuite) TestVerifySignup() {
require.NoError(ts.T(), err)

// Setup request
reqUrl := fmt.Sprintf("http://localhost/verify?type=%s&token=%s", signupVerification, u.ConfirmationToken)
reqUrl := fmt.Sprintf("http://localhost/verify?type=%s&token=%s", mail.SignupVerification, u.ConfirmationToken)
req := httptest.NewRequest(http.MethodGet, reqUrl, nil)

// Setup response recorder
Expand Down
49 changes: 22 additions & 27 deletions internal/api/verify.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,22 +14,17 @@ import (
"github.com/supabase/auth/internal/api/provider"
"github.com/supabase/auth/internal/api/sms_provider"
"github.com/supabase/auth/internal/crypto"
mail "github.com/supabase/auth/internal/mailer"
"github.com/supabase/auth/internal/models"
"github.com/supabase/auth/internal/observability"
"github.com/supabase/auth/internal/storage"
"github.com/supabase/auth/internal/utilities"
)

const (
signupVerification = "signup"
recoveryVerification = "recovery"
inviteVerification = "invite"
magicLinkVerification = "magiclink"
emailChangeVerification = "email_change"
smsVerification = "sms"
phoneChangeVerification = "phone_change"
// includes signupVerification and magicLinkVerification
emailOTPVerification = "email"
)

const (
Expand Down Expand Up @@ -150,11 +145,11 @@ func (a *API) verifyGet(w http.ResponseWriter, r *http.Request, params *VerifyPa
return terr
}
switch params.Type {
case signupVerification, inviteVerification:
case mail.SignupVerification, mail.InviteVerification:
user, terr = a.signupVerify(r, ctx, tx, user)
case recoveryVerification, magicLinkVerification:
case mail.RecoveryVerification, mail.MagicLinkVerification:
user, terr = a.recoverVerify(r, tx, user)
case emailChangeVerification:
case mail.EmailChangeVerification:
user, terr = a.emailChangeVerify(r, tx, params, user)
if user == nil && terr == nil {
// when double confirmation is required
Expand Down Expand Up @@ -254,11 +249,11 @@ func (a *API) verifyPost(w http.ResponseWriter, r *http.Request, params *VerifyP
}

switch params.Type {
case signupVerification, inviteVerification:
case mail.SignupVerification, mail.InviteVerification:
user, terr = a.signupVerify(r, ctx, tx, user)
case recoveryVerification, magicLinkVerification:
case mail.RecoveryVerification, mail.MagicLinkVerification:
user, terr = a.recoverVerify(r, tx, user)
case emailChangeVerification:
case mail.EmailChangeVerification:
user, terr = a.emailChangeVerify(r, tx, params, user)
if user == nil && terr == nil {
isSingleConfirmationResponse = true
Expand Down Expand Up @@ -555,14 +550,14 @@ func (a *API) verifyTokenHash(conn *storage.Connection, params *VerifyParams) (*
var user *models.User
var err error
switch params.Type {
case emailOTPVerification:
case mail.EmailOTPVerification:
// need to find user by confirmation token or recovery token with the token hash
user, err = models.FindUserByConfirmationOrRecoveryToken(conn, params.TokenHash)
case signupVerification, inviteVerification:
case mail.SignupVerification, mail.InviteVerification:
user, err = models.FindUserByConfirmationToken(conn, params.TokenHash)
case recoveryVerification, magicLinkVerification:
case mail.RecoveryVerification, mail.MagicLinkVerification:
user, err = models.FindUserByRecoveryToken(conn, params.TokenHash)
case emailChangeVerification:
case mail.EmailChangeVerification:
user, err = models.FindUserByEmailChangeToken(conn, params.TokenHash)
default:
return nil, badRequestError(ErrorCodeValidationFailed, "Invalid email verification type")
Expand All @@ -581,19 +576,19 @@ func (a *API) verifyTokenHash(conn *storage.Connection, params *VerifyParams) (*

var isExpired bool
switch params.Type {
case emailOTPVerification:
case mail.EmailOTPVerification:
sentAt := user.ConfirmationSentAt
params.Type = "signup"
if user.RecoveryToken == params.TokenHash {
sentAt = user.RecoverySentAt
params.Type = "magiclink"
}
isExpired = isOtpExpired(sentAt, config.Mailer.OtpExp)
case signupVerification, inviteVerification:
case mail.SignupVerification, mail.InviteVerification:
isExpired = isOtpExpired(user.ConfirmationSentAt, config.Mailer.OtpExp)
case recoveryVerification, magicLinkVerification:
case mail.RecoveryVerification, mail.MagicLinkVerification:
isExpired = isOtpExpired(user.RecoverySentAt, config.Mailer.OtpExp)
case emailChangeVerification:
case mail.EmailChangeVerification:
isExpired = isOtpExpired(user.EmailChangeSentAt, config.Mailer.OtpExp)
}

Expand All @@ -617,7 +612,7 @@ func (a *API) verifyUserAndToken(conn *storage.Connection, params *VerifyParams,
user, err = models.FindUserByPhoneChangeAndAudience(conn, params.Phone, aud)
case smsVerification:
user, err = models.FindUserByPhoneAndAudience(conn, params.Phone, aud)
case emailChangeVerification:
case mail.EmailChangeVerification:
// Since the email change could be trigger via the implicit or PKCE flow,
// the query used has to also check if the token saved in the db contains the pkce_ prefix
user, err = models.FindUserForEmailChange(conn, params.Email, tokenHash, aud, config.Mailer.SecureEmailChangeEnabled)
Expand All @@ -640,22 +635,22 @@ func (a *API) verifyUserAndToken(conn *storage.Connection, params *VerifyParams,

smsProvider, _ := sms_provider.GetSmsProvider(*config)
switch params.Type {
case emailOTPVerification:
case mail.EmailOTPVerification:
// if the type is emailOTPVerification, we'll check both the confirmation_token and recovery_token columns
if isOtpValid(tokenHash, user.ConfirmationToken, user.ConfirmationSentAt, config.Mailer.OtpExp) {
isValid = true
params.Type = signupVerification
params.Type = mail.SignupVerification
} else if isOtpValid(tokenHash, user.RecoveryToken, user.RecoverySentAt, config.Mailer.OtpExp) {
isValid = true
params.Type = magicLinkVerification
params.Type = mail.MagicLinkVerification
} else {
isValid = false
}
case signupVerification, inviteVerification:
case mail.SignupVerification, mail.InviteVerification:
isValid = isOtpValid(tokenHash, user.ConfirmationToken, user.ConfirmationSentAt, config.Mailer.OtpExp)
case recoveryVerification, magicLinkVerification:
case mail.RecoveryVerification, mail.MagicLinkVerification:
isValid = isOtpValid(tokenHash, user.RecoveryToken, user.RecoverySentAt, config.Mailer.OtpExp)
case emailChangeVerification:
case mail.EmailChangeVerification:
isValid = isOtpValid(tokenHash, user.EmailChangeTokenCurrent, user.EmailChangeSentAt, config.Mailer.OtpExp) ||
isOtpValid(tokenHash, user.EmailChangeTokenNew, user.EmailChangeSentAt, config.Mailer.OtpExp)
case phoneChangeVerification, smsVerification:
Expand Down
Loading

0 comments on commit 765db08

Please sign in to comment.