diff --git a/internal/api/phone.go b/internal/api/phone.go index a9ffe1ff4..28fcb0e49 100644 --- a/internal/api/phone.go +++ b/internal/api/phone.go @@ -67,30 +67,41 @@ func (a *API) sendPhoneConfirmation(ctx context.Context, tx *storage.Connection, return "", internalServerError("invalid otp type") } + // intentionally keeping this before the test OTP, so that the behavior + // of regular and test OTPs is similar if sentAt != nil && !sentAt.Add(config.Sms.MaxFrequency).Before(time.Now()) { return "", MaxFrequencyLimitError } - oldToken := *token - otp, err := crypto.GenerateOtp(config.Sms.OtpLength) - if err != nil { - return "", internalServerError("error generating otp").WithInternalError(err) - } - *token = crypto.GenerateTokenHash(phone, otp) - var message string - if config.Sms.Template == "" { - message = fmt.Sprintf(defaultSmsMessage, otp) - } else { - message = strings.Replace(config.Sms.Template, "{{ .Code }}", otp, -1) + now := time.Now() + + var otp, messageID string + + if testOTP, ok := config.Sms.GetTestOTP(phone, now); ok { + otp = testOTP + messageID = "test-otp" } - messageID, serr := smsProvider.SendMessage(phone, message, channel) - if serr != nil { - *token = oldToken - return messageID, serr + if otp == "" { // not using test OTPs + otp, err := crypto.GenerateOtp(config.Sms.OtpLength) + if err != nil { + return "", internalServerError("error generating otp").WithInternalError(err) + } + + var message string + if config.Sms.Template == "" { + message = fmt.Sprintf(defaultSmsMessage, otp) + } else { + message = strings.Replace(config.Sms.Template, "{{ .Code }}", otp, -1) + } + + messageID, err = smsProvider.SendMessage(phone, message, channel) + if err != nil { + return messageID, err + } } - now := time.Now() + *token = crypto.GenerateTokenHash(phone, otp) switch otpType { case phoneConfirmationOtp: diff --git a/internal/api/phone_test.go b/internal/api/phone_test.go index 93a7e7d65..7047de896 100644 --- a/internal/api/phone_test.go +++ b/internal/api/phone_test.go @@ -28,9 +28,12 @@ type PhoneTestSuite struct { type TestSmsProvider struct { mock.Mock + + SentMessages int } func (t *TestSmsProvider) SendMessage(phone string, message string, channel string) (string, error) { + t.SentMessages += 1 return "", nil } @@ -66,7 +69,7 @@ func (ts *PhoneTestSuite) TestFormatPhoneNumber() { assert.Equal(ts.T(), "123456789", actual) } -func (ts *PhoneTestSuite) TestSendPhoneConfirmation() { +func doTestSendPhoneConfirmation(ts *PhoneTestSuite, useTestOTP bool) { u, err := models.FindUserByPhoneAndAudience(ts.API.db, "123456789", ts.Config.JWT.Aud) require.NoError(ts.T(), err) ctx := context.Background() @@ -97,13 +100,31 @@ func (ts *PhoneTestSuite) TestSendPhoneConfirmation() { }, } + if useTestOTP { + ts.API.config.Sms.TestOTP = map[string]string{ + "123456789": "123456", + } + } else { + ts.API.config.Sms.TestOTP = nil + } + for _, c := range cases { ts.Run(c.desc, func() { - _, err = ts.API.sendPhoneConfirmation(ctx, ts.API.db, u, "123456789", c.otpType, &TestSmsProvider{}, sms_provider.SMSProvider) + provider := &TestSmsProvider{} + + _, err = ts.API.sendPhoneConfirmation(ctx, ts.API.db, u, "123456789", c.otpType, provider, sms_provider.SMSProvider) require.Equal(ts.T(), c.expected, err) u, err = models.FindUserByPhoneAndAudience(ts.API.db, "123456789", ts.Config.JWT.Aud) require.NoError(ts.T(), err) + if c.expected == nil { + if useTestOTP { + require.Equal(ts.T(), provider.SentMessages, 0) + } else { + require.Equal(ts.T(), provider.SentMessages, 1) + } + } + switch c.otpType { case phoneConfirmationOtp: require.NotEmpty(ts.T(), u.ConfirmationToken) @@ -120,6 +141,14 @@ func (ts *PhoneTestSuite) TestSendPhoneConfirmation() { } } +func (ts *PhoneTestSuite) TestSendPhoneConfirmation() { + doTestSendPhoneConfirmation(ts, false) +} + +func (ts *PhoneTestSuite) TestSendPhoneConfirmationWithTestOTP() { + doTestSendPhoneConfirmation(ts, true) +} + func (ts *PhoneTestSuite) TestMissingSmsProviderConfig() { u, err := models.FindUserByPhoneAndAudience(ts.API.db, "123456789", ts.Config.JWT.Aud) require.NoError(ts.T(), err) @@ -198,6 +227,7 @@ func (ts *PhoneTestSuite) TestMissingSmsProviderConfig() { ts.Config.Sms.Messagebird.AccessKey = "" ts.Config.Sms.Textlocal.ApiKey = "" ts.Config.Sms.Vonage.ApiKey = "" + for _, c := range cases { for _, provider := range smsProviders { ts.Config.Sms.Provider = provider diff --git a/internal/conf/configuration.go b/internal/conf/configuration.go index 3832e65a3..c7086dd60 100644 --- a/internal/conf/configuration.go +++ b/internal/conf/configuration.go @@ -221,12 +221,15 @@ type PhoneProviderConfiguration struct { } type SmsProviderConfiguration struct { - Autoconfirm bool `json:"autoconfirm"` - MaxFrequency time.Duration `json:"max_frequency" split_words:"true"` - OtpExp uint `json:"otp_exp" split_words:"true"` - OtpLength int `json:"otp_length" split_words:"true"` - Provider string `json:"provider"` - Template string `json:"template"` + Autoconfirm bool `json:"autoconfirm"` + MaxFrequency time.Duration `json:"max_frequency" split_words:"true"` + OtpExp uint `json:"otp_exp" split_words:"true"` + OtpLength int `json:"otp_length" split_words:"true"` + Provider string `json:"provider"` + Template string `json:"template"` + TestOTP map[string]string `json:"test_otp" split_words:"true"` + TestOTPValidUntil time.Time `json:"test_otp_valid_until" split_words:"true"` + Twilio TwilioProviderConfiguration `json:"twilio"` TwilioVerify TwilioVerifyProviderConfiguration `json:"twilio_verify" split_words:"true"` Messagebird MessagebirdProviderConfiguration `json:"messagebird"` @@ -234,6 +237,15 @@ type SmsProviderConfiguration struct { Vonage VonageProviderConfiguration `json:"vonage"` } +func (c *SmsProviderConfiguration) GetTestOTP(phone string, now time.Time) (string, bool) { + if c.TestOTP != nil && (c.TestOTPValidUntil.IsZero() || now.Before(c.TestOTPValidUntil)) { + testOTP, ok := c.TestOTP[phone] + return testOTP, ok + } + + return "", false +} + type TwilioProviderConfiguration struct { AccountSid string `json:"account_sid" split_words:"true"` AuthToken string `json:"auth_token" split_words:"true"`