Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: return SMS ID when possible #1145

Merged
merged 1 commit into from
Jun 19, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 12 additions & 3 deletions internal/api/otp.go
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,10 @@ func (a *API) Otp(w http.ResponseWriter, r *http.Request) error {
return otpError("unsupported_otp_type", "")
}

type SmsOtpResponse struct {
MessageID string `json:"message_id,omitempty"`
}

// SmsOtp sends the user an otp via sms
func (a *API) SmsOtp(w http.ResponseWriter, r *http.Request) error {
ctx := r.Context()
Expand Down Expand Up @@ -188,6 +192,7 @@ func (a *API) SmsOtp(w http.ResponseWriter, r *http.Request) error {
return sendJSON(w, http.StatusOK, make(map[string]string))
}

messageID := ""
err = db.Transaction(func(tx *storage.Connection) error {
if err := models.NewAuditLogEntry(r, tx, user, models.UserRecoveryRequestedAction, "", map[string]interface{}{
"channel": params.Channel,
Expand All @@ -198,17 +203,21 @@ func (a *API) SmsOtp(w http.ResponseWriter, r *http.Request) error {
if terr != nil {
return badRequestError("Error sending sms: %v", terr)
}
if err := a.sendPhoneConfirmation(ctx, tx, user, params.Phone, phoneConfirmationOtp, smsProvider, params.Channel); err != nil {
return badRequestError("Error sending sms otp: %v", err)
mID, serr := a.sendPhoneConfirmation(ctx, tx, user, params.Phone, phoneConfirmationOtp, smsProvider, params.Channel)
if serr != nil {
return badRequestError("Error sending sms OTP: %v", err)
}
messageID = mID
return nil
})

if err != nil {
return err
}

return sendJSON(w, http.StatusOK, make(map[string]string))
return sendJSON(w, http.StatusOK, SmsOtpResponse{
MessageID: messageID,
})
}

func (a *API) shouldCreateUser(r *http.Request, params *OtpParams) (bool, error) {
Expand Down
15 changes: 8 additions & 7 deletions internal/api/phone.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ func formatPhoneNumber(phone string) string {
}

// sendPhoneConfirmation sends an otp to the user's phone number
func (a *API) sendPhoneConfirmation(ctx context.Context, tx *storage.Connection, user *models.User, phone, otpType string, smsProvider sms_provider.SmsProvider, channel string) error {
func (a *API) sendPhoneConfirmation(ctx context.Context, tx *storage.Connection, user *models.User, phone, otpType string, smsProvider sms_provider.SmsProvider, channel string) (string, error) {
config := a.config

var token *string
Expand All @@ -65,17 +65,17 @@ func (a *API) sendPhoneConfirmation(ctx context.Context, tx *storage.Connection,
sentAt = user.ReauthenticationSentAt
includeFields = append(includeFields, "reauthentication_token", "reauthentication_sent_at")
default:
return internalServerError("invalid otp type")
return "", internalServerError("invalid otp type")
}

if sentAt != nil && !sentAt.Add(config.Sms.MaxFrequency).Before(time.Now()) {
return MaxFrequencyLimitError
return "", MaxFrequencyLimitError
}

oldToken := *token
otp, err := crypto.GenerateOtp(config.Sms.OtpLength)
if err != nil {
return internalServerError("error generating otp").WithInternalError(err)
return "", internalServerError("error generating otp").WithInternalError(err)
}
*token = fmt.Sprintf("%x", sha256.Sum224([]byte(phone+otp)))

Expand All @@ -86,9 +86,10 @@ func (a *API) sendPhoneConfirmation(ctx context.Context, tx *storage.Connection,
message = strings.Replace(config.Sms.Template, "{{ .Code }}", otp, -1)
}

if serr := smsProvider.SendMessage(phone, message, channel); serr != nil {
messageID, serr := smsProvider.SendMessage(phone, message, channel)
if serr != nil {
*token = oldToken
return serr
return messageID, serr
}

now := time.Now()
Expand All @@ -102,5 +103,5 @@ func (a *API) sendPhoneConfirmation(ctx context.Context, tx *storage.Connection,
user.ReauthenticationSentAt = &now
}

return errors.Wrap(tx.UpdateOnly(user, includeFields...), "Database error updating user for confirmation")
return messageID, errors.Wrap(tx.UpdateOnly(user, includeFields...), "Database error updating user for confirmation")
}
6 changes: 3 additions & 3 deletions internal/api/phone_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,8 @@ type TestSmsProvider struct {
mock.Mock
}

func (t *TestSmsProvider) SendMessage(phone string, message string, channel string) error {
return nil
func (t *TestSmsProvider) SendMessage(phone string, message string, channel string) (string, error) {
return "", nil
}

func TestPhone(t *testing.T) {
Expand Down Expand Up @@ -99,7 +99,7 @@ func (ts *PhoneTestSuite) TestSendPhoneConfirmation() {

for _, c := range cases {
ts.Run(c.desc, func() {
err = ts.API.sendPhoneConfirmation(ctx, ts.API.db, u, "123456789", c.otpType, &TestSmsProvider{}, sms_provider.SMSProvider)
_, err = ts.API.sendPhoneConfirmation(ctx, ts.API.db, u, "123456789", c.otpType, &TestSmsProvider{}, sms_provider.SMSProvider)
require.Equal(ts.T(), c.expected, err)
u, err = models.FindUserByPhoneAndAudience(ts.API.db, "123456789", ts.Config.JWT.Aud)
require.NoError(ts.T(), err)
Expand Down
16 changes: 14 additions & 2 deletions internal/api/reauthenticate.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ func (a *API) Reauthenticate(w http.ResponseWriter, r *http.Request) error {
}
}

messageID := ""
err := db.Transaction(func(tx *storage.Connection) error {
if terr := models.NewAuditLogEntry(r, tx, user, models.UserReauthenticateAction, "", nil); terr != nil {
return terr
Expand All @@ -49,7 +50,12 @@ func (a *API) Reauthenticate(w http.ResponseWriter, r *http.Request) error {
if terr != nil {
return badRequestError("Error sending sms: %v", terr)
}
return a.sendPhoneConfirmation(ctx, tx, user, phone, phoneReauthenticationOtp, smsProvider, sms_provider.SMSProvider)
mID, err := a.sendPhoneConfirmation(ctx, tx, user, phone, phoneReauthenticationOtp, smsProvider, sms_provider.SMSProvider)
if err != nil {
return err
}

messageID = mID
}
return nil
})
Expand All @@ -60,7 +66,13 @@ func (a *API) Reauthenticate(w http.ResponseWriter, r *http.Request) error {
return err
}

return sendJSON(w, http.StatusOK, make(map[string]string))
ret := map[string]any{}
if messageID != "" {
ret["message_id"] = messageID

}

return sendJSON(w, http.StatusOK, ret)
}

// verifyReauthentication checks if the nonce provided is valid
Expand Down
20 changes: 17 additions & 3 deletions internal/api/resend.go
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,7 @@ func (a *API) Resend(w http.ResponseWriter, r *http.Request) error {
}
}

messageID := ""
err = db.Transaction(func(tx *storage.Connection) error {
mailer := a.Mailer(ctx)
referrer := a.getReferrer(r)
Expand All @@ -131,15 +132,23 @@ func (a *API) Resend(w http.ResponseWriter, r *http.Request) error {
if terr != nil {
return terr
}
return a.sendPhoneConfirmation(ctx, tx, user, params.Phone, phoneConfirmationOtp, smsProvider, sms_provider.SMSProvider)
mID, terr := a.sendPhoneConfirmation(ctx, tx, user, params.Phone, phoneConfirmationOtp, smsProvider, sms_provider.SMSProvider)
if terr != nil {
return terr
}
messageID = mID
case emailChangeVerification:
return a.sendEmailChange(tx, config, user, mailer, params.Email, referrer, externalURL, config.Mailer.OtpLength, models.ImplicitFlow)
case phoneChangeVerification:
smsProvider, terr := sms_provider.GetSmsProvider(*config)
if terr != nil {
return terr
}
return a.sendPhoneConfirmation(ctx, tx, user, user.PhoneChange, phoneChangeVerification, smsProvider, sms_provider.SMSProvider)
mID, terr := a.sendPhoneConfirmation(ctx, tx, user, user.PhoneChange, phoneChangeVerification, smsProvider, sms_provider.SMSProvider)
if terr != nil {
return terr
}
messageID = mID
}
return nil
})
Expand All @@ -151,5 +160,10 @@ func (a *API) Resend(w http.ResponseWriter, r *http.Request) error {
return internalServerError("Unable to process request").WithInternalError(err)
}

return sendJSON(w, http.StatusOK, map[string]string{})
ret := map[string]any{}
if messageID != "" {
ret["message_id"] = messageID
}

return sendJSON(w, http.StatusOK, ret)
}
2 changes: 1 addition & 1 deletion internal/api/signup.go
Original file line number Diff line number Diff line change
Expand Up @@ -223,7 +223,7 @@ func (a *API) Signup(w http.ResponseWriter, r *http.Request) error {
if terr != nil {
return badRequestError("Error sending confirmation sms: %v", terr)
}
if terr = a.sendPhoneConfirmation(ctx, tx, user, params.Phone, phoneConfirmationOtp, smsProvider, params.Channel); terr != nil {
if _, terr := a.sendPhoneConfirmation(ctx, tx, user, params.Phone, phoneConfirmationOtp, smsProvider, params.Channel); terr != nil {
return badRequestError("Error sending confirmation sms: %v", terr)
}
}
Expand Down
21 changes: 11 additions & 10 deletions internal/api/sms_provider/messagebird.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ type MessagebirdResponseRecipients struct {
}

type MessagebirdResponse struct {
ID string `json:"id"`
Recipients MessagebirdResponseRecipients `json:"recipients"`
}

Expand Down Expand Up @@ -55,17 +56,17 @@ func NewMessagebirdProvider(config conf.MessagebirdProviderConfiguration) (SmsPr
}, nil
}

func (t *MessagebirdProvider) SendMessage(phone string, message string, channel string) error {
func (t *MessagebirdProvider) SendMessage(phone string, message string, channel string) (string, error) {
switch channel {
case SMSProvider:
return t.SendSms(phone, message)
default:
return fmt.Errorf("channel type %q is not supported for Messagebird", channel)
return "", fmt.Errorf("channel type %q is not supported for Messagebird", channel)
}
}

// Send an SMS containing the OTP with Messagebird's API
func (t *MessagebirdProvider) SendSms(phone string, message string) error {
func (t *MessagebirdProvider) SendSms(phone string, message string) (string, error) {
body := url.Values{
"originator": {t.Config.Originator},
"body": {message},
Expand All @@ -77,34 +78,34 @@ func (t *MessagebirdProvider) SendSms(phone string, message string) error {
client := &http.Client{Timeout: defaultTimeout}
r, err := http.NewRequest("POST", t.APIPath, strings.NewReader(body.Encode()))
if err != nil {
return err
return "", err
}
r.Header.Add("Content-Type", "application/x-www-form-urlencoded")
r.Header.Add("Authorization", "AccessKey "+t.Config.AccessKey)
res, err := client.Do(r)
if err != nil {
return err
return "", err
}

if res.StatusCode == http.StatusBadRequest || res.StatusCode == http.StatusForbidden || res.StatusCode == http.StatusUnauthorized || res.StatusCode == http.StatusUnprocessableEntity {
resp := &MessagebirdErrResponse{}
if err := json.NewDecoder(res.Body).Decode(resp); err != nil {
return err
return "", err
}
return resp
return "", resp
}
defer utilities.SafeClose(res.Body)

// validate sms status
resp := &MessagebirdResponse{}
derr := json.NewDecoder(res.Body).Decode(resp)
if derr != nil {
return derr
return "", derr
}

if resp.Recipients.TotalSentCount == 0 {
return fmt.Errorf("messagebird error: total sent count is 0")
return "", fmt.Errorf("messagebird error: total sent count is 0")
}

return nil
return resp.ID, nil
}
2 changes: 1 addition & 1 deletion internal/api/sms_provider/sms_provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ func init() {
}

type SmsProvider interface {
SendMessage(phone, message, channel string) error
SendMessage(phone, message, channel string) (string, error)
}

func GetSmsProvider(config conf.GlobalConfiguration) (SmsProvider, error) {
Expand Down
20 changes: 11 additions & 9 deletions internal/api/sms_provider/sms_provider_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -86,10 +86,11 @@ func (ts *SmsProviderTestSuite) TestTwilioSendSms() {
MatchHeader("Authorization", "Basic "+base64.StdEncoding.EncodeToString([]byte(twilioProvider.Config.AccountSid+":"+twilioProvider.Config.AuthToken))).
MatchType("url").BodyString(body.Encode()).
Reply(200).JSON(SmsStatus{
To: "+" + phone,
From: twilioProvider.Config.MessageServiceSid,
Status: "sent",
Body: message,
To: "+" + phone,
From: twilioProvider.Config.MessageServiceSid,
Status: "sent",
Body: message,
MessageSID: "abcdef",
}),
ExpectedError: nil,
},
Expand All @@ -102,8 +103,9 @@ func (ts *SmsProviderTestSuite) TestTwilioSendSms() {
ErrorMessage: "failed to send sms",
ErrorCode: "401",
Status: "failed",
MessageSID: "abcdef",
}),
ExpectedError: fmt.Errorf("twilio error: %v %v", "failed to send sms", "401"),
ExpectedError: fmt.Errorf("twilio error: %v %v for message %v", "failed to send sms", "401", "abcdef"),
},
{
Desc: "Non-2xx status code returned",
Expand All @@ -127,7 +129,7 @@ func (ts *SmsProviderTestSuite) TestTwilioSendSms() {

for _, c := range cases {
ts.Run(c.Desc, func() {
err = twilioProvider.SendSms(phone, message, SMSProvider)
_, err = twilioProvider.SendSms(phone, message, SMSProvider)
require.Equal(ts.T(), c.ExpectedError, err)
})
}
Expand Down Expand Up @@ -156,7 +158,7 @@ func (ts *SmsProviderTestSuite) TestMessagebirdSendSms() {
},
})

err = messagebirdProvider.SendSms(phone, message)
_, err = messagebirdProvider.SendSms(phone, message)
require.NoError(ts.T(), err)
}

Expand Down Expand Up @@ -185,7 +187,7 @@ func (ts *SmsProviderTestSuite) TestVonageSendSms() {
},
})

err = vonageProvider.SendSms(phone, message)
_, err = vonageProvider.SendSms(phone, message)
require.NoError(ts.T(), err)
}

Expand All @@ -211,6 +213,6 @@ func (ts *SmsProviderTestSuite) TestTextLocalSendSms() {
Errors: []TextlocalError{},
})

err = textlocalProvider.SendSms(phone, message)
_, err = textlocalProvider.SendSms(phone, message)
require.NoError(ts.T(), err)
}
Loading