diff --git a/connector/dummy.go b/connector/dummy.go index 23aeff71..e9c7e8e7 100644 --- a/connector/dummy.go +++ b/connector/dummy.go @@ -4,7 +4,10 @@ import ( "bytes" "context" "errors" + "fmt" + "github.com/sirupsen/logrus" "sync" + "sync/atomic" "time" "github.com/ProtonMail/gluon/constants" @@ -57,6 +60,8 @@ type Dummy struct { uidValidity imap.UID allowMessageCreateWithUnknownMailboxID bool + + updatesAllowedToFail int32 } func NewDummy(usernames []string, password []byte, period time.Duration, flags, permFlags, attrs imap.FlagSet) *Dummy { @@ -77,7 +82,16 @@ func NewDummy(usernames []string, password []byte, period time.Duration, flags, go func() { conn.ticker.Tick(func(time.Time) { for _, update := range conn.popUpdates() { - defer update.Wait() + defer func() { + err, ok := update.Wait() + if ok && err != nil { + if atomic.LoadInt32(&conn.updatesAllowedToFail) == 0 { + panic(fmt.Sprintf("Failed to apply update %v: %v", update.String(), err)) + } else { + logrus.Errorf("Failed to apply update %v: %v", update.String(), err) + } + } + }() select { case conn.updateCh <- update: @@ -263,9 +277,13 @@ func (conn *Dummy) SetUIDValidity(newUIDValidity imap.UID) error { func (conn *Dummy) Sync(ctx context.Context) error { for _, mailbox := range conn.state.getMailboxes() { update := imap.NewMailboxCreated(mailbox) - defer update.WaitContext(ctx) conn.updateCh <- update + + err, ok := update.WaitContext(ctx) + if ok && err != nil { + return fmt.Errorf("failed to apply update %v:%w", update.String(), err) + } } var updates []*imap.MessageCreated @@ -280,10 +298,14 @@ func (conn *Dummy) Sync(ctx context.Context) error { } update := imap.NewMessagesCreated(conn.allowMessageCreateWithUnknownMailboxID, updates...) - defer update.WaitContext(ctx) conn.updateCh <- update + err, ok := update.WaitContext(ctx) + if ok && err != nil { + return fmt.Errorf("failed to apply update %v:%w", update.String(), err) + } + return nil } @@ -419,3 +441,14 @@ func (conn *Dummy) validateName(name []string) (bool, error) { return exclusive, nil } + +func (conn *Dummy) SetUpdatesAllowedToFail(value bool) { + var v int32 + if value { + v = 1 + } else { + v = 0 + } + + atomic.StoreInt32(&conn.updatesAllowedToFail, v) +} diff --git a/connector/dummy_test.go b/connector/dummy_test.go index 0e80c0a6..22991ca9 100644 --- a/connector/dummy_test.go +++ b/connector/dummy_test.go @@ -28,7 +28,7 @@ func TestDummyConnector_validateUpdate(t *testing.T) { go func() { for update := range conn.GetUpdates() { - update.Done() + update.Done(nil) } }() diff --git a/imap/update_waiter.go b/imap/update_waiter.go index ef20a8c5..86874ab6 100644 --- a/imap/update_waiter.go +++ b/imap/update_waiter.go @@ -6,36 +6,43 @@ import ( type Waiter interface { // Wait waits until the update has been marked as done. - Wait() + Wait() (error, bool) // WaitContext waits until the update has been marked as done or the context is cancelled. - WaitContext(context.Context) + WaitContext(context.Context) (error, bool) - // Done marks the update as done. - Done() + // Done marks the update as done and report an error (if any). + Done(error) } type updateWaiter struct { - waitCh chan struct{} + waitCh chan error } func newUpdateWaiter() *updateWaiter { return &updateWaiter{ - waitCh: make(chan struct{}), + waitCh: make(chan error, 1), } } -func (w *updateWaiter) Wait() { - <-w.waitCh +func (w *updateWaiter) Wait() (error, bool) { + err, ok := <-w.waitCh + return err, ok } -func (w *updateWaiter) WaitContext(ctx context.Context) { +func (w *updateWaiter) WaitContext(ctx context.Context) (error, bool) { select { case <-ctx.Done(): - case <-w.waitCh: + return nil, false + case err, ok := <-w.waitCh: + return err, ok } } -func (w *updateWaiter) Done() { +func (w *updateWaiter) Done(err error) { + if err != nil { + w.waitCh <- err + } + close(w.waitCh) } diff --git a/internal/backend/connector_updates.go b/internal/backend/connector_updates.go index e58687d2..15f484b4 100644 --- a/internal/backend/connector_updates.go +++ b/internal/backend/connector_updates.go @@ -19,50 +19,54 @@ import ( // apply an incoming update originating from the connector. func (user *user) apply(ctx context.Context, update imap.Update) error { - defer update.Done() - logrus.WithField("update", update).WithField("user-id", user.userID).Debug("Applying update") - switch update := update.(type) { - case *imap.MailboxCreated: - return user.applyMailboxCreated(ctx, update) + err := func() error { + switch update := update.(type) { + case *imap.MailboxCreated: + return user.applyMailboxCreated(ctx, update) - case *imap.MailboxDeleted: - return user.applyMailboxDeleted(ctx, update) + case *imap.MailboxDeleted: + return user.applyMailboxDeleted(ctx, update) - case *imap.MailboxUpdated: - return user.applyMailboxUpdated(ctx, update) + case *imap.MailboxUpdated: + return user.applyMailboxUpdated(ctx, update) - case *imap.MailboxIDChanged: - return user.applyMailboxIDChanged(ctx, update) + case *imap.MailboxIDChanged: + return user.applyMailboxIDChanged(ctx, update) - case *imap.MessagesCreated: - return user.applyMessagesCreated(ctx, update) + case *imap.MessagesCreated: + return user.applyMessagesCreated(ctx, update) - case *imap.MessageMailboxesUpdated: - return user.applyMessageMailboxesUpdated(ctx, update) + case *imap.MessageMailboxesUpdated: + return user.applyMessageMailboxesUpdated(ctx, update) - case *imap.MessageFlagsUpdated: - return user.applyMessageFlagsUpdated(ctx, update) + case *imap.MessageFlagsUpdated: + return user.applyMessageFlagsUpdated(ctx, update) - case *imap.MessageIDChanged: - return user.applyMessageIDChanged(ctx, update) + case *imap.MessageIDChanged: + return user.applyMessageIDChanged(ctx, update) - case *imap.MessageDeleted: - return user.applyMessageDeleted(ctx, update) + case *imap.MessageDeleted: + return user.applyMessageDeleted(ctx, update) - case *imap.MessageUpdated: - return user.applyMessageUpdated(ctx, update) + case *imap.MessageUpdated: + return user.applyMessageUpdated(ctx, update) - case *imap.UIDValidityBumped: - return user.applyUIDValidityBumped(ctx, update) + case *imap.UIDValidityBumped: + return user.applyUIDValidityBumped(ctx, update) - case *imap.Noop: - return nil + case *imap.Noop: + return nil - default: - return fmt.Errorf("bad update") - } + default: + return fmt.Errorf("bad update") + } + }() + + update.Done(err) + + return err } // applyMailboxCreated applies a MailboxCreated update. @@ -71,17 +75,7 @@ func (user *user) applyMailboxCreated(ctx context.Context, update *imap.MailboxC return fmt.Errorf("attempting to create protected mailbox (recovery)") } - if err := user.imapLimits.CheckUIDValidity(user.globalUIDValidity); err != nil { - return err - } - if exists, err := db.ReadResult(ctx, user.db, func(ctx context.Context, client *ent.Client) (bool, error) { - if mailboxCount, err := db.GetMailboxCount(ctx, client); err != nil { - return false, err - } else if err := user.imapLimits.CheckMailBoxCount(mailboxCount); err != nil { - return false, err - } - return db.MailboxExistsWithRemoteID(ctx, client, update.Mailbox.ID) }); err != nil { return err @@ -89,7 +83,17 @@ func (user *user) applyMailboxCreated(ctx context.Context, update *imap.MailboxC return nil } + if err := user.imapLimits.CheckUIDValidity(user.globalUIDValidity); err != nil { + return err + } + return user.db.Write(ctx, func(ctx context.Context, tx *ent.Tx) error { + if mailboxCount, err := db.GetMailboxCount(ctx, tx.Client()); err != nil { + return err + } else if err := user.imapLimits.CheckMailBoxCount(mailboxCount); err != nil { + return err + } + if _, err := db.CreateMailbox( ctx, tx, diff --git a/internal/backend/update_injector.go b/internal/backend/update_injector.go index bfbc0b4c..d8cb1f81 100644 --- a/internal/backend/update_injector.go +++ b/internal/backend/update_injector.go @@ -62,6 +62,8 @@ func (u *updateInjector) forward(ctx context.Context, updateCh <-chan imap.Updat for { select { + case <-ctx.Done(): + return case update, ok := <-updateCh: if !ok { return @@ -76,14 +78,14 @@ func (u *updateInjector) forward(ctx context.Context, updateCh <-chan imap.Updat } // send the update on the updates channel, optionally blocking until it has been processed. -func (u *updateInjector) send(ctx context.Context, update imap.Update, withBlock ...bool) { +func (u *updateInjector) send(ctx context.Context, update imap.Update) { select { case <-u.forwardQuitCh: return case u.updatesCh <- update: - if len(withBlock) > 0 && withBlock[0] { - update.WaitContext(ctx) - } + + case <-ctx.Done(): + return } } diff --git a/tests/imap_limits_test.go b/tests/imap_limits_test.go index 26cb839c..bd57c35d 100644 --- a/tests/imap_limits_test.go +++ b/tests/imap_limits_test.go @@ -46,6 +46,7 @@ func TestMaxUIDLimitRespected_Append(t *testing.T) { func TestMaxMessageLimitRespected_Copy(t *testing.T) { runOneToOneTestClientWithAuth(t, defaultServerOptions(t, withIMAPLimits(testIMAPLimits())), func(client *client.Client, session *testSession) { + session.setUpdatesAllowedToFail("user", true) require.NoError(t, client.Create("mbox1")) require.NoError(t, doAppendWithClient(client, "mbox1", "To: Foo@bar.com", time.Now())) require.NoError(t, doAppendWithClient(client, "INBOX", "To: Bar@bar.com", time.Now())) @@ -57,6 +58,7 @@ func TestMaxMessageLimitRespected_Copy(t *testing.T) { func TestMaxUIDLimitRespected_Copy(t *testing.T) { runOneToOneTestClientWithAuth(t, defaultServerOptions(t, withIMAPLimits(testIMAPLimits())), func(client *client.Client, session *testSession) { + session.setUpdatesAllowedToFail("user", true) require.NoError(t, client.Create("mbox1")) require.NoError(t, doAppendWithClient(client, "mbox1", "To: Foo@bar.com", time.Now())) require.NoError(t, doAppendWithClient(client, "INBOX", "To: Bar@bar.com", time.Now())) @@ -76,6 +78,7 @@ func TestMaxUIDLimitRespected_Copy(t *testing.T) { func TestMaxMessageLimitRespected_Move(t *testing.T) { runOneToOneTestClientWithAuth(t, defaultServerOptions(t, withIMAPLimits(testIMAPLimits())), func(client *client.Client, session *testSession) { + session.setUpdatesAllowedToFail("user", true) require.NoError(t, client.Create("mbox1")) require.NoError(t, doAppendWithClient(client, "mbox1", "To: Foo@bar.com", time.Now())) require.NoError(t, doAppendWithClient(client, "INBOX", "To: Bar@bar.com", time.Now())) @@ -87,6 +90,7 @@ func TestMaxMessageLimitRespected_Move(t *testing.T) { func TestMaxUIDLimitRespected_Move(t *testing.T) { runOneToOneTestClientWithAuth(t, defaultServerOptions(t, withIMAPLimits(testIMAPLimits())), func(client *client.Client, session *testSession) { + session.setUpdatesAllowedToFail("user", true) require.NoError(t, client.Create("mbox1")) require.NoError(t, doAppendWithClient(client, "mbox1", "To: Foo@bar.com", time.Now())) require.NoError(t, doAppendWithClient(client, "INBOX", "To: Bar@bar.com", time.Now())) @@ -106,6 +110,7 @@ func TestMaxUIDLimitRespected_Move(t *testing.T) { func TestMaxUIDValidityLimitRespected(t *testing.T) { runOneToOneTestClientWithAuth(t, defaultServerOptions(t, withIMAPLimits(testIMAPLimits())), func(client *client.Client, session *testSession) { + session.setUpdatesAllowedToFail("user", true) require.NoError(t, client.Create("mbox1")) require.NoError(t, client.Delete("mbox1")) require.Error(t, client.Create("mbox2")) diff --git a/tests/recent_test.go b/tests/recent_test.go index 373d123f..a951258d 100644 --- a/tests/recent_test.go +++ b/tests/recent_test.go @@ -80,9 +80,12 @@ func TestRecentAppend(t *testing.T) { } func TestRecentStore(t *testing.T) { - runManyToOneTestWithAuth(t, defaultServerOptions(t), []int{1, 2}, func(c map[int]*testConnection, _ *testSession) { + runManyToOneTestWithAuth(t, defaultServerOptions(t), []int{1, 2}, func(c map[int]*testConnection, s *testSession) { mbox, done := c[1].doCreateTempDir() - defer done() + defer func() { + s.flush("user") + done() + }() // Create a message in mbox. c[1].doAppend(mbox, `To: 1@pm.me`).expect(`OK`) diff --git a/tests/session_test.go b/tests/session_test.go index 5443b289..b189e36d 100644 --- a/tests/session_test.go +++ b/tests/session_test.go @@ -51,6 +51,8 @@ type Connector interface { Sync(context.Context) error Flush() + + SetUpdatesAllowedToFail(bool) } type testSession struct { @@ -361,6 +363,10 @@ func (s *testSession) flush(user string) { s.conns[s.userIDs[user]].Flush() } +func (s *testSession) setUpdatesAllowedToFail(user string, value bool) { + s.conns[s.userIDs[user]].SetUpdatesAllowedToFail(value) +} + func forMessageInMBox(rr io.Reader, fn func(messageDelimiter, literal []byte)) error { mr := mbox.NewReader(rr) diff --git a/tests/updates_test.go b/tests/updates_test.go index 36d05846..4bbdd332 100644 --- a/tests/updates_test.go +++ b/tests/updates_test.go @@ -313,6 +313,7 @@ func TestBatchMessageAddedWithMultipleFlags(t *testing.T) { func TestMessageCreatedWithIgnoreMissingMailbox(t *testing.T) { runOneToOneTestClientWithAuth(t, defaultServerOptions(t), func(c *client.Client, s *testSession) { mailboxID := s.mailboxCreated("user", []string{"mbox"}) + s.setUpdatesAllowedToFail("user", true) { // First round fails as a missing mailbox is not allowed. s.messageCreatedWithMailboxes("user", []imap.MailboxID{mailboxID, "THIS MAILBOX DOES NOT EXISTS"}, []byte("To: Test"), time.Now())