Skip to content

Commit

Permalink
fix: do not crash process on invalid smtp url (#2890)
Browse files Browse the repository at this point in the history
  • Loading branch information
aeneasr committed Nov 14, 2022
1 parent 40e2258 commit c5d3ebc
Show file tree
Hide file tree
Showing 11 changed files with 88 additions and 22 deletions.
7 changes: 6 additions & 1 deletion cmd/courier/watch.go
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,12 @@ func Watch(ctx cx.Context, r driver.Registry) error {

r.Logger().Println("Courier worker started.")
if err := graceful.Graceful(func() error {
return r.Courier(ctx).Work(ctx)
c, err := r.Courier(ctx)
if err != nil {
return err
}

return c.Work(ctx)
}, func(_ cx.Context) error {
cancel()
return nil
Expand Down
12 changes: 8 additions & 4 deletions courier/courier.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ type (
}

Provider interface {
Courier(ctx context.Context) Courier
Courier(ctx context.Context) (Courier, error)
}

ConfigProvider interface {
Expand All @@ -57,13 +57,17 @@ type (
}
)

func NewCourier(ctx context.Context, deps Dependencies) Courier {
func NewCourier(ctx context.Context, deps Dependencies) (Courier, error) {
smtp, err := newSMTP(ctx, deps)
if err != nil {
return nil, err
}
return &courier{
smsClient: newSMS(ctx, deps),
smtpClient: newSMTP(ctx, deps),
smtpClient: smtp,
deps: deps,
backoff: backoff.NewExponentialBackOff(),
}
}, nil
}

func (c *courier) FailOnDispatchError() {
Expand Down
6 changes: 4 additions & 2 deletions courier/courier_dispatcher_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,8 @@ func TestDispatchMessageWithInvalidSMTP(t *testing.T) {
ctx, cancel := context.WithCancel(ctx)
defer cancel()

c := reg.Courier(ctx)
c, err := reg.Courier(ctx)
require.NoError(t, err)

t.Run("case=failed sending", func(t *testing.T) {
id := queueNewMessage(t, ctx, c, reg)
Expand Down Expand Up @@ -76,7 +77,8 @@ func TestDispatchMessage2(t *testing.T) {
conf, reg := internal.NewRegistryDefaultWithDSN(t, "")
conf.MustSet(ctx, config.ViperKeyCourierMessageRetries, 1)

c := reg.Courier(ctx)
c, err := reg.Courier(ctx)
require.NoError(t, err)

ctx, cancel := context.WithCancel(ctx)
defer cancel()
Expand Down
8 changes: 5 additions & 3 deletions courier/sms_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,8 @@ func TestQueueSMS(t *testing.T) {
conf.MustSet(ctx, config.ViperKeyCourierSMTPURL, "http://foo.url")
reg.Logger().Level = logrus.TraceLevel

c := reg.Courier(ctx)
c, err := reg.Courier(ctx)
require.NoError(t, err)

ctx, cancel := context.WithCancel(ctx)
defer t.Cleanup(cancel)
Expand Down Expand Up @@ -132,11 +133,12 @@ func TestDisallowedInternalNetwork(t *testing.T) {
conf.MustSet(ctx, config.ViperKeyClientHTTPNoPrivateIPRanges, true)
reg.Logger().Level = logrus.TraceLevel

c := reg.Courier(ctx)
c, err := reg.Courier(ctx)
require.NoError(t, err)
c.(interface {
FailOnDispatchError()
}).FailOnDispatchError()
_, err := c.QueueSMS(ctx, sms.NewTestStub(reg, &sms.TestStubModel{
_, err = c.QueueSMS(ctx, sms.NewTestStub(reg, &sms.TestStubModel{
To: "+12065550101",
Body: "test-sms-body-1",
}))
Expand Down
10 changes: 7 additions & 3 deletions courier/smtp.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,12 @@ type smtpClient struct {
NewTemplateFromMessage func(d template.Dependencies, msg Message) (EmailTemplate, error)
}

func newSMTP(ctx context.Context, deps Dependencies) *smtpClient {
uri := deps.CourierConfig().CourierSMTPURL(ctx)
func newSMTP(ctx context.Context, deps Dependencies) (*smtpClient, error) {
uri, err := deps.CourierConfig().CourierSMTPURL(ctx)
if err != nil {
return nil, err
}

var tlsCertificates []tls.Certificate
clientCertPath := deps.CourierConfig().CourierSMTPClientCertPath(ctx)
clientKeyPath := deps.CourierConfig().CourierSMTPClientKeyPath(ctx)
Expand Down Expand Up @@ -94,7 +98,7 @@ func newSMTP(ctx context.Context, deps Dependencies) *smtpClient {

GetTemplateType: GetEmailTemplateType,
NewTemplateFromMessage: NewEmailTemplateFromMessage,
}
}, nil
}

func (c *courier) SetGetEmailTemplateType(f func(t EmailTemplate) (TemplateType, error)) {
Expand Down
11 changes: 8 additions & 3 deletions courier/smtp_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,9 +42,13 @@ func TestNewSMTP(t *testing.T) {

setupCourier := func(stringURL string) courier.Courier {
conf.MustSet(ctx, config.ViperKeyCourierSMTPURL, stringURL)
t.Logf("SMTP URL: %s", conf.CourierSMTPURL(ctx).String())
u, err := conf.CourierSMTPURL(ctx)
require.NoError(t, err)
t.Logf("SMTP URL: %s", u.String())

return courier.NewCourier(ctx, reg)
c, err := courier.NewCourier(ctx, reg)
require.NoError(t, err)
return c
}

if testing.Short() {
Expand Down Expand Up @@ -107,7 +111,8 @@ func TestQueueEmail(t *testing.T) {
conf.MustSet(ctx, config.ViperKeyCourierSMTPFrom, "test-stub@ory.sh")
reg.Logger().Level = logrus.TraceLevel

c := reg.Courier(ctx)
c, err := reg.Courier(ctx)
require.NoError(t, err)

ctx, cancel := context.WithCancel(ctx)
defer cancel()
Expand Down
13 changes: 10 additions & 3 deletions driver/config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@ import (
"testing"
"time"

"github.com/ory/herodot"

"github.com/ory/x/contextx"

"github.com/ory/jsonschema/v3/httploader"
Expand Down Expand Up @@ -258,7 +260,7 @@ type (
Config() *Config
}
CourierConfigs interface {
CourierSMTPURL(ctx context.Context) *url.URL
CourierSMTPURL(ctx context.Context) (*url.URL, error)
CourierSMTPClientCertPath(ctx context.Context) string
CourierSMTPClientKeyPath(ctx context.Context) string
CourierSMTPFrom(ctx context.Context) string
Expand Down Expand Up @@ -849,8 +851,13 @@ func (p *Config) SelfAdminURL(ctx context.Context) *url.URL {
return p.baseURL(ctx, ViperKeyAdminBaseURL, ViperKeyAdminHost, ViperKeyAdminPort, 4434)
}

func (p *Config) CourierSMTPURL(ctx context.Context) *url.URL {
return p.ParseURIOrFail(ctx, ViperKeyCourierSMTPURL)
func (p *Config) CourierSMTPURL(ctx context.Context) (*url.URL, error) {
source := p.GetProvider(ctx).String(ViperKeyCourierSMTPURL)
parsed, err := url.Parse(source)
if err != nil {
return nil, errors.WithStack(herodot.ErrInternalServerError.WithReason("Unable to parse the project's SMTP URL. Please ensure that it is properly escaped: https://www.ory.sh/dr/3").WithDebugf("%s", err))
}
return parsed, nil
}

func (p *Config) OAuth2ProviderHeader(ctx context.Context) http.Header {
Expand Down
28 changes: 28 additions & 0 deletions driver/config/config_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1107,6 +1107,34 @@ func TestCourierSMS(t *testing.T) {
})
}

func TestCourierSMTPUrl(t *testing.T) {
ctx := context.Background()

for _, tc := range []string{
"smtp://a:basdasdasda%2Fc@email-smtp.eu-west-3.amazonaws.com:587/",
"smtp://a:b$c@email-smtp.eu-west-3.amazonaws.com:587/",
"smtp://a/a:bc@email-smtp.eu-west-3.amazonaws.com:587",
"smtp://aa:b+c@email-smtp.eu-west-3.amazonaws.com:587/",
"smtp://user?name:password@email-smtp.eu-west-3.amazonaws.com:587/",
"smtp://username:pass%2Fword@email-smtp.eu-west-3.amazonaws.com:587/",
} {
t.Run("case="+tc, func(t *testing.T) {
conf, err := config.New(ctx, logrusx.New("", ""), os.Stderr, configx.WithValue(config.ViperKeyCourierSMTPURL, tc), configx.SkipValidation())
require.NoError(t, err)
parsed, err := conf.CourierSMTPURL(ctx)
require.NoError(t, err)
assert.Equal(t, tc, parsed.String())
})
}

t.Run("invalid", func(t *testing.T) {
conf, err := config.New(ctx, logrusx.New("", ""), os.Stderr, configx.WithValue(config.ViperKeyCourierSMTPURL, "smtp://a:b/c@email-smtp.eu-west-3.amazonaws.com:587/"), configx.SkipValidation())
require.NoError(t, err)
_, err = conf.CourierSMTPURL(ctx)
require.Error(t, err)
})
}

func TestCourierMessageTTL(t *testing.T) {
ctx := context.Background()

Expand Down
2 changes: 1 addition & 1 deletion driver/registry_default.go
Original file line number Diff line number Diff line change
Expand Up @@ -656,7 +656,7 @@ func (m *RegistryDefault) SetPersister(p persistence.Persister) {
m.persister = p
}

func (m *RegistryDefault) Courier(ctx context.Context) courier.Courier {
func (m *RegistryDefault) Courier(ctx context.Context) (courier.Courier, error) {
return courier.NewCourier(ctx, m)
}

Expand Down
7 changes: 6 additions & 1 deletion selfservice/strategy/code/sender.go
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,12 @@ func (s *RecoveryCodeSender) SendRecoveryCodeTo(ctx context.Context, i *identity
func (s *RecoveryCodeSender) send(ctx context.Context, via string, t courier.EmailTemplate) error {
switch f := stringsx.SwitchExact(via); {
case f.AddCase(identity.AddressTypeEmail):
_, err := s.deps.Courier(ctx).QueueEmail(ctx, t)
c, err := s.deps.Courier(ctx)
if err != nil {
return err
}

_, err = c.QueueEmail(ctx, t)
return err
default:
return f.ToUnknownCaseErr()
Expand Down
6 changes: 5 additions & 1 deletion selfservice/strategy/link/sender.go
Original file line number Diff line number Diff line change
Expand Up @@ -191,7 +191,11 @@ func (s *Sender) SendVerificationTokenTo(ctx context.Context, f *verification.Fl
func (s *Sender) send(ctx context.Context, via string, t courier.EmailTemplate) error {
switch via {
case identity.AddressTypeEmail:
_, err := s.r.Courier(ctx).QueueEmail(ctx, t)
c, err := s.r.Courier(ctx)
if err != nil {
return err
}
_, err = c.QueueEmail(ctx, t)
return err
default:
return errors.Errorf("received unexpected via type: %s", via)
Expand Down

0 comments on commit c5d3ebc

Please sign in to comment.