From 5222638a1050870037eae1f7e455ee45d1aae642 Mon Sep 17 00:00:00 2001 From: Leander Beernaert Date: Mon, 22 May 2023 16:41:49 +0200 Subject: [PATCH] fix(GODT-2454): Only apply state updates if db transactions succeed Ensure that all changes to the DB are made before publishing state updates for connected sessions. This improves stability of the connected clients. The states still require another transaction to be applied. However, this is currently only related to the handling of the recent flag. This will be handled in GODT-2522 instead. --- internal/state/actions.go | 178 +++++++++++++++---------------- internal/state/mailbox.go | 44 ++++---- internal/state/mailbox_fetch.go | 4 +- internal/state/mailbox_search.go | 2 +- internal/state/state.go | 176 ++++++++++++++++++++---------- internal/state/updates.go | 68 ++++++------ internal/state/updates_remote.go | 19 ---- 7 files changed, 254 insertions(+), 237 deletions(-) diff --git a/internal/state/actions.go b/internal/state/actions.go index 9b11c8f7..a9aad84d 100644 --- a/internal/state/actions.go +++ b/internal/state/actions.go @@ -59,16 +59,16 @@ func (state *State) actionCreateMailbox(ctx context.Context, tx *ent.Tx, name st return db.CreateMailboxIfNotExists(ctx, tx, res, state.delimiter, uidValidity) } -func (state *State) actionDeleteMailbox(ctx context.Context, tx *ent.Tx, mboxID ids.MailboxIDPair) error { +func (state *State) actionDeleteMailbox(ctx context.Context, tx *ent.Tx, mboxID ids.MailboxIDPair) ([]Update, error) { if err := state.user.GetRemote().DeleteMailbox(ctx, mboxID.RemoteID); err != nil { - return err + return nil, err } if err := db.DeleteMailboxWithRemoteID(ctx, tx, mboxID.RemoteID); err != nil { - return err + return nil, err } - return state.user.QueueOrApplyStateUpdate(ctx, tx, NewMailboxDeletedStateUpdate(mboxID.InternalID)) + return []Update{NewMailboxDeletedStateUpdate(mboxID.InternalID)}, nil } func (state *State) actionUpdateMailbox(ctx context.Context, tx *ent.Tx, mboxID imap.MailboxID, newName string) error { @@ -92,17 +92,17 @@ func (state *State) actionCreateMessage( date time.Time, isSelectedMailbox bool, cameFromDrafts bool, -) (imap.UID, error) { +) ([]Update, imap.UID, error) { internalID, res, newLiteral, err := state.user.GetRemote().CreateMessage(ctx, mboxID.RemoteID, literal, flags, date) if err != nil { - return 0, err + return nil, 0, err } { // Handle the case where duplicate messages can return the same remote ID. knownInternalID, knownErr := db.GetMessageIDFromRemoteID(ctx, tx.Client(), res.ID) if knownErr != nil && !ent.IsNotFound(knownErr) { - return 0, knownErr + return nil, 0, knownErr } if knownErr == nil { // Try to collect the original message date. @@ -121,37 +121,37 @@ func (state *State) actionCreateMessage( logrus.Errorf("Append to drafts must not return an existing RemoteID (Remote=%v, Internal=%v)", res.ID, knownInternalID) - return 0, fmt.Errorf("append to drafts returned an existing remote ID") + return nil, 0, fmt.Errorf("append to drafts returned an existing remote ID") } logrus.Debugf("Deduped message detected, adding existing %v message to mailbox instead.", knownInternalID.ShortID()) - result, err := state.actionAddMessagesToMailbox(ctx, + updates, result, err := state.actionAddMessagesToMailbox(ctx, tx, []ids.MessageIDPair{{InternalID: knownInternalID, RemoteID: res.ID}}, mboxID, isSelectedMailbox, ) if err != nil { - return 0, err + return nil, 0, err } - return result[0].UID, nil + return updates, result[0].UID, nil } } parsedMessage, err := imap.NewParsedMessage(newLiteral) if err != nil { - return 0, err + return nil, 0, err } literalWithHeader, literalSize, err := rfc822.SetHeaderValueNoMemCopy(newLiteral, ids.InternalIDKey, internalID.String()) if err != nil { - return 0, fmt.Errorf("failed to set internal ID: %w", err) + return nil, 0, fmt.Errorf("failed to set internal ID: %w", err) } if err := state.user.GetStore().SetUnchecked(internalID, literalWithHeader); err != nil { - return 0, fmt.Errorf("failed to store message literal: %w", err) + return nil, 0, fmt.Errorf("failed to store message literal: %w", err) } req := db.CreateMessageReq{ @@ -165,7 +165,7 @@ func (state *State) actionCreateMessage( messageUID, flagSet, err := db.CreateAndAddMessageToMailbox(ctx, tx, mboxID.InternalID, &req) if err != nil { - return 0, err + return nil, 0, err } // We can append to non-selected mailboxes. @@ -175,15 +175,14 @@ func (state *State) actionCreateMessage( st = state } - if err := state.user.QueueOrApplyStateUpdate(ctx, tx, newExistsStateUpdateWithExists( + updates := []Update{newExistsStateUpdateWithExists( mboxID.InternalID, []*exists{newExists(ids.MessageIDPair{InternalID: internalID, RemoteID: res.ID}, messageUID, flagSet)}, st, - )); err != nil { - return 0, err + ), } - return messageUID, nil + return updates, messageUID, nil } func (state *State) actionCreateRecoveredMessage( @@ -192,23 +191,23 @@ func (state *State) actionCreateRecoveredMessage( literal []byte, flags imap.FlagSet, date time.Time, -) (bool, error) { +) ([]Update, bool, error) { internalID := imap.NewInternalMessageID() remoteID := ids.NewRecoveredRemoteMessageID(internalID) parsedMessage, err := imap.NewParsedMessage(literal) if err != nil { - return false, err + return nil, false, err } alreadyKnown, err := state.user.GetRecoveredMessageHashesMap().Insert(internalID, literal) if err == nil && alreadyKnown { // Message is already known to us, so we ignore it. - return true, nil + return nil, true, nil } if err := state.user.GetStore().SetUnchecked(internalID, bytes.NewReader(literal)); err != nil { - return false, fmt.Errorf("failed to store message literal: %w", err) + return nil, false, fmt.Errorf("failed to store message literal: %w", err) } req := db.CreateMessageReq{ @@ -228,18 +227,17 @@ func (state *State) actionCreateRecoveredMessage( messageUID, flagSet, err := db.CreateAndAddMessageToMailbox(ctx, tx, recoveryMBoxID.InternalID, &req) if err != nil { - return false, err + return nil, false, err } - if err := state.user.QueueOrApplyStateUpdate(ctx, tx, newExistsStateUpdateWithExists( + var updates = []Update{newExistsStateUpdateWithExists( recoveryMBoxID.InternalID, []*exists{newExists(ids.MessageIDPair{InternalID: internalID, RemoteID: remoteID}, messageUID, flagSet)}, nil, - )); err != nil { - return false, err + ), } - return false, nil + return updates, false, nil } func (state *State) actionAddMessagesToMailbox( @@ -248,26 +246,31 @@ func (state *State) actionAddMessagesToMailbox( messageIDs []ids.MessageIDPair, mboxID ids.MailboxIDPair, isMailboxSelected bool, -) ([]db.UIDWithFlags, error) { +) ([]Update, []db.UIDWithFlags, error) { + var allUpdates []Update + { haveMessageIDs, err := db.FilterMailboxContains(ctx, tx.Client(), mboxID.InternalID, messageIDs) if err != nil { - return nil, err + return nil, nil, err } if remMessageIDs := xslices.Filter(messageIDs, func(messageID ids.MessageIDPair) bool { return slices.Contains(haveMessageIDs, messageID.InternalID) }); len(remMessageIDs) > 0 { - if err := state.actionRemoveMessagesFromMailboxUnchecked(ctx, tx, remMessageIDs, mboxID); err != nil { - return nil, err + updates, err := state.actionRemoveMessagesFromMailboxUnchecked(ctx, tx, remMessageIDs, mboxID) + if err != nil { + return nil, nil, err } + + allUpdates = append(allUpdates, updates...) } } internalIDs, remoteIDs := ids.SplitMessageIDPairSlice(messageIDs) if err := state.user.GetRemote().AddMessagesToMailbox(ctx, remoteIDs, mboxID.RemoteID); err != nil { - return nil, err + return nil, nil, err } // Messages can be added to a mailbox that is not selected. @@ -278,14 +281,12 @@ func (state *State) actionAddMessagesToMailbox( messageUIDs, update, err := AddMessagesToMailbox(ctx, tx, mboxID.InternalID, internalIDs, st, state.imapLimits) if err != nil { - return nil, err + return nil, nil, err } - if err := state.user.QueueOrApplyStateUpdate(ctx, tx, update); err != nil { - return nil, err - } + allUpdates = append(allUpdates, update) - return messageUIDs, nil + return allUpdates, messageUIDs, nil } func (state *State) actionAddRecoveredMessagesToMailbox( @@ -382,14 +383,14 @@ func (state *State) actionCopyMessagesOutOfRecoveryMailbox( tx *ent.Tx, messageIDs []ids.MessageIDPair, mboxID ids.MailboxIDPair, -) ([]db.UIDWithFlags, error) { +) ([]Update, []db.UIDWithFlags, error) { ids := make([]ids.MessageIDPair, 0, len(messageIDs)) // Import messages to remote. for _, id := range messageIDs { id, _, err := state.actionImportRecoveredMessage(ctx, tx, id.InternalID, mboxID.RemoteID) if err != nil { - return nil, err + return nil, nil, err } ids = append(ids, id) @@ -398,14 +399,10 @@ func (state *State) actionCopyMessagesOutOfRecoveryMailbox( // Label messages in destination. uidWithFlags, update, err := state.actionAddRecoveredMessagesToMailbox(ctx, tx, ids, mboxID) if err != nil { - return nil, err - } - - if err := state.user.QueueOrApplyStateUpdate(ctx, tx, update); err != nil { - return nil, err + return nil, nil, err } - return uidWithFlags, nil + return []Update{update}, uidWithFlags, nil } func (state *State) actionMoveMessagesOutOfRecoveryMailbox( @@ -413,7 +410,7 @@ func (state *State) actionMoveMessagesOutOfRecoveryMailbox( tx *ent.Tx, messageIDs []ids.MessageIDPair, mboxID ids.MailboxIDPair, -) ([]db.UIDWithFlags, error) { +) ([]Update, []db.UIDWithFlags, error) { ids := make([]ids.MessageIDPair, 0, len(messageIDs)) oldInternalIDs := make([]imap.InternalMessageID, 0, len(messageIDs)) @@ -421,12 +418,12 @@ func (state *State) actionMoveMessagesOutOfRecoveryMailbox( for _, id := range messageIDs { newID, deduped, err := state.actionImportRecoveredMessage(ctx, tx, id.InternalID, mboxID.RemoteID) if err != nil { - return nil, err + return nil, nil, err } if !deduped { if err := db.MarkMessageAsDeleted(ctx, tx, id.InternalID); err != nil { - return nil, err + return nil, nil, err } } @@ -439,7 +436,7 @@ func (state *State) actionMoveMessagesOutOfRecoveryMailbox( { removeUpdates, err := RemoveMessagesFromMailbox(ctx, tx, state.user.GetRecoveryMailboxID().InternalID, oldInternalIDs) if err != nil { - return nil, err + return nil, nil, err } state.user.GetRecoveredMessageHashesMap().Erase(oldInternalIDs...) @@ -450,17 +447,12 @@ func (state *State) actionMoveMessagesOutOfRecoveryMailbox( // Label messages in destination. uidWithFlags, update, err := state.actionAddRecoveredMessagesToMailbox(ctx, tx, ids, mboxID) if err != nil { - return nil, err + return nil, nil, err } - // Publish all updates in unison. updates = append(updates, update) - if err := state.user.QueueOrApplyStateUpdate(ctx, tx, updates...); err != nil { - return nil, err - } - - return uidWithFlags, nil + return updates, uidWithFlags, nil } // actionRemoveMessagesFromMailboxUnchecked is similar to actionRemoveMessagesFromMailbox, but it does not validate @@ -471,23 +463,18 @@ func (state *State) actionRemoveMessagesFromMailboxUnchecked( tx *ent.Tx, messageIDs []ids.MessageIDPair, mboxID ids.MailboxIDPair, -) error { +) ([]Update, error) { internalIDs, remoteIDs := ids.SplitMessageIDPairSlice(messageIDs) if mboxID.InternalID != state.user.GetRecoveryMailboxID().InternalID { if err := state.user.GetRemote().RemoveMessagesFromMailbox(ctx, remoteIDs, mboxID.RemoteID); err != nil { - return err + return nil, err } } else { state.user.GetRecoveredMessageHashesMap().Erase(internalIDs...) } - updates, err := RemoveMessagesFromMailbox(ctx, tx, mboxID.InternalID, internalIDs) - if err != nil { - return err - } - - return state.user.QueueOrApplyStateUpdate(ctx, tx, updates...) + return RemoveMessagesFromMailbox(ctx, tx, mboxID.InternalID, internalIDs) } func (state *State) actionRemoveMessagesFromMailbox( @@ -495,10 +482,10 @@ func (state *State) actionRemoveMessagesFromMailbox( tx *ent.Tx, messageIDs []ids.MessageIDPair, mboxID ids.MailboxIDPair, -) error { +) ([]Update, error) { haveMessageIDs, err := db.FilterMailboxContains(ctx, tx.Client(), mboxID.InternalID, messageIDs) if err != nil { - return err + return nil, err } messageIDs = xslices.Filter(messageIDs, func(messageID ids.MessageIDPair) bool { @@ -506,7 +493,7 @@ func (state *State) actionRemoveMessagesFromMailbox( }) if len(messageIDs) == 0 { - return nil + return nil, nil } return state.actionRemoveMessagesFromMailboxUnchecked(ctx, tx, messageIDs, mboxID) @@ -517,31 +504,41 @@ func (state *State) actionMoveMessages( tx *ent.Tx, messageIDs []ids.MessageIDPair, mboxFromID, mboxToID ids.MailboxIDPair, -) ([]db.UIDWithFlags, error) { +) ([]Update, []db.UIDWithFlags, error) { + var allUpdates []Update + if mboxFromID.InternalID == mboxToID.InternalID { internalIDs, _ := ids.SplitMessageIDPairSlice(messageIDs) - return db.BumpMailboxUIDsForMessage(ctx, tx, internalIDs, mboxToID.InternalID) + uid, err := db.BumpMailboxUIDsForMessage(ctx, tx, internalIDs, mboxToID.InternalID) + if err != nil { + return nil, nil, err + } + + return nil, uid, nil } { messageIDsToAdd, err := db.FilterMailboxContains(ctx, tx.Client(), mboxToID.InternalID, messageIDs) if err != nil { - return nil, err + return nil, nil, err } if remMessageIDs := xslices.Filter(messageIDs, func(messageID ids.MessageIDPair) bool { return slices.Contains(messageIDsToAdd, messageID.InternalID) }); len(remMessageIDs) > 0 { - if err := state.actionRemoveMessagesFromMailboxUnchecked(ctx, tx, remMessageIDs, mboxToID); err != nil { - return nil, err + updates, err := state.actionRemoveMessagesFromMailboxUnchecked(ctx, tx, remMessageIDs, mboxToID) + if err != nil { + return nil, nil, err } + + allUpdates = append(allUpdates, updates...) } } messageInFromMBox, err := db.FilterMailboxContains(ctx, tx.Client(), mboxFromID.InternalID, messageIDs) if err != nil { - return nil, err + return nil, nil, err } messagesIDsToMove := xslices.Filter(messageIDs, func(messageID ids.MessageIDPair) bool { @@ -552,19 +549,17 @@ func (state *State) actionMoveMessages( shouldRemoveOldMessages, err := state.user.GetRemote().MoveMessagesFromMailbox(ctx, remoteIDs, mboxFromID.RemoteID, mboxToID.RemoteID) if err != nil { - return nil, err + return nil, nil, err } messageUIDs, updates, err := MoveMessagesFromMailbox(ctx, tx, mboxFromID.InternalID, mboxToID.InternalID, internalIDs, state, state.imapLimits, shouldRemoveOldMessages) if err != nil { - return nil, err + return nil, nil, err } - if err := state.user.QueueOrApplyStateUpdate(ctx, tx, updates...); err != nil { - return nil, err - } + allUpdates = append(allUpdates, updates...) - return messageUIDs, nil + return allUpdates, messageUIDs, nil } func (state *State) actionAddMessageFlags( @@ -572,16 +567,12 @@ func (state *State) actionAddMessageFlags( tx *ent.Tx, messages []snapMsgWithSeq, addFlags imap.FlagSet, -) error { +) ([]Update, error) { internalMessageIDs := xslices.Map(messages, func(sm snapMsgWithSeq) imap.InternalMessageID { return sm.ID.InternalID }) - if err := state.applyMessageFlagsAdded(ctx, tx, internalMessageIDs, addFlags); err != nil { - return err - } - - return nil + return state.applyMessageFlagsAdded(ctx, tx, internalMessageIDs, addFlags) } func (state *State) actionRemoveMessageFlags( @@ -589,21 +580,20 @@ func (state *State) actionRemoveMessageFlags( tx *ent.Tx, messages []snapMsgWithSeq, remFlags imap.FlagSet, -) error { +) ([]Update, error) { internalMessageIDs := xslices.Map(messages, func(sm snapMsgWithSeq) imap.InternalMessageID { return sm.ID.InternalID }) - if err := state.applyMessageFlagsRemoved(ctx, tx, internalMessageIDs, remFlags); err != nil { - return err - } - - return nil + return state.applyMessageFlagsRemoved(ctx, tx, internalMessageIDs, remFlags) } -func (state *State) actionSetMessageFlags(ctx context.Context, tx *ent.Tx, messages []snapMsgWithSeq, setFlags imap.FlagSet) error { +func (state *State) actionSetMessageFlags(ctx context.Context, + tx *ent.Tx, + messages []snapMsgWithSeq, + setFlags imap.FlagSet) ([]Update, error) { if setFlags.ContainsUnchecked(imap.FlagRecentLowerCase) { - return fmt.Errorf("recent flag is read-only") + return nil, fmt.Errorf("recent flag is read-only") } internalMessageIDs := xslices.Map(messages, func(sm snapMsgWithSeq) imap.InternalMessageID { diff --git a/internal/state/mailbox.go b/internal/state/mailbox.go index f6135b52..f3ce0cc8 100644 --- a/internal/state/mailbox.go +++ b/internal/state/mailbox.go @@ -85,19 +85,19 @@ func (m *Mailbox) Count() int { } func (m *Mailbox) Flags(ctx context.Context) (imap.FlagSet, error) { - return db.ReadResult(ctx, m.state.db(), func(ctx context.Context, client *ent.Client) (imap.FlagSet, error) { + return stateDBReadResult(ctx, m.state, func(ctx context.Context, client *ent.Client) (imap.FlagSet, error) { return db.GetMailboxFlags(ctx, client, m.id.InternalID) }) } func (m *Mailbox) PermanentFlags(ctx context.Context) (imap.FlagSet, error) { - return db.ReadResult(ctx, m.state.db(), func(ctx context.Context, client *ent.Client) (imap.FlagSet, error) { + return stateDBReadResult(ctx, m.state, func(ctx context.Context, client *ent.Client) (imap.FlagSet, error) { return db.GetMailboxPermanentFlags(ctx, client, m.id.InternalID) }) } func (m *Mailbox) Attributes(ctx context.Context) (imap.FlagSet, error) { - return db.ReadResult(ctx, m.state.db(), func(ctx context.Context, client *ent.Client) (imap.FlagSet, error) { + return stateDBReadResult(ctx, m.state, func(ctx context.Context, client *ent.Client) (imap.FlagSet, error) { return db.GetMailboxAttributes(ctx, client, m.id.InternalID) }) } @@ -147,7 +147,7 @@ func (m *Mailbox) GetMessagesWithoutFlagCount(flag string) int { } func (m *Mailbox) AppendRegular(ctx context.Context, literal []byte, flags imap.FlagSet, date time.Time) (imap.UID, error) { - if err := m.state.db().Read(ctx, func(ctx context.Context, client *ent.Client) error { + if err := stateDBRead(ctx, m.state, func(ctx context.Context, client *ent.Client) error { if messageCount, uid, err := db.GetMailboxMessageCountAndUID(ctx, client, m.snap.mboxID.InternalID); err != nil { return err } else { @@ -185,7 +185,7 @@ func (m *Mailbox) AppendRegular(ctx context.Context, literal []byte, flags imap. return 0, err } - if message, err := db.ReadResult(ctx, m.state.db(), func(ctx context.Context, client *ent.Client) (*ent.Message, error) { + if message, err := stateDBReadResult(ctx, m.state, func(ctx context.Context, client *ent.Client) (*ent.Message, error) { message, err := db.GetMessageWithIDWithDeletedFlag(ctx, client, msgID) if err != nil { if ent.IsNotFound(err) { @@ -201,10 +201,10 @@ func (m *Mailbox) AppendRegular(ctx context.Context, literal []byte, flags imap. } else if !message.Deleted { logrus.Debugf("Appending duplicate message with Internal ID:%v", msgID.ShortID()) // Only shuffle around messages that haven't been marked for deletion. - if res, err := db.WriteResult(ctx, m.state.db(), func(ctx context.Context, tx *ent.Tx) ([]db.UIDWithFlags, error) { + if res, err := stateDBWriteResult(ctx, m.state, func(ctx context.Context, tx *ent.Tx) ([]Update, []db.UIDWithFlags, error) { remoteID, err := db.GetMessageRemoteIDFromID(ctx, tx.Client(), msgID) if err != nil { - return nil, err + return nil, nil, err } return m.state.actionAddMessagesToMailbox(ctx, tx, @@ -229,7 +229,7 @@ func (m *Mailbox) AppendRegular(ctx context.Context, literal []byte, flags imap. } } - return db.WriteResult(ctx, m.state.db(), func(ctx context.Context, tx *ent.Tx) (imap.UID, error) { + return stateDBWriteResult(ctx, m.state, func(ctx context.Context, tx *ent.Tx) ([]Update, imap.UID, error) { return m.state.actionCreateMessage(ctx, tx, m.snap.mboxID, literal, flags, date, m.snap == m.state.snap, appendIntoDrafts) }) } @@ -245,7 +245,7 @@ func (m *Mailbox) Append(ctx context.Context, literal []byte, flags imap.FlagSet } // Failed to append to mailbox attempt to insert into recovery mailbox. - knownMessage, recoverErr := db.WriteResult(ctx, m.state.db(), func(ctx context.Context, tx *ent.Tx) (bool, error) { + knownMessage, recoverErr := stateDBWriteResult(ctx, m.state, func(ctx context.Context, tx *ent.Tx) ([]Update, bool, error) { return m.state.actionCreateRecoveredMessage(ctx, tx, literal, flags, date) }) if recoverErr != nil && !knownMessage { @@ -269,7 +269,7 @@ func (m *Mailbox) Copy(ctx context.Context, seq []command.SeqRange, name string) return nil, ErrOperationNotAllowed } - mbox, err := db.ReadResult(ctx, m.state.db(), func(ctx context.Context, client *ent.Client) (*ent.Mailbox, error) { + mbox, err := stateDBReadResult(ctx, m.state, func(ctx context.Context, client *ent.Client) (*ent.Mailbox, error) { return db.GetMailboxByName(ctx, client, name) }) if err != nil { @@ -290,7 +290,7 @@ func (m *Mailbox) Copy(ctx context.Context, seq []command.SeqRange, name string) msgIDs[i] = snapMsg.ID } - destUIDs, err := db.WriteResult(ctx, m.state.db(), func(ctx context.Context, tx *ent.Tx) ([]db.UIDWithFlags, error) { + destUIDs, err := stateDBWriteResult(ctx, m.state, func(ctx context.Context, tx *ent.Tx) ([]Update, []db.UIDWithFlags, error) { if m.state.user.GetRecoveryMailboxID().InternalID == m.snap.mboxID.InternalID { return m.state.actionCopyMessagesOutOfRecoveryMailbox(ctx, tx, msgIDs, ids.NewMailboxIDPair(mbox)) } else { @@ -320,7 +320,7 @@ func (m *Mailbox) Move(ctx context.Context, seq []command.SeqRange, name string) return nil, ErrOperationNotAllowed } - mbox, err := db.ReadResult(ctx, m.state.db(), func(ctx context.Context, client *ent.Client) (*ent.Mailbox, error) { + mbox, err := stateDBReadResult(ctx, m.state, func(ctx context.Context, client *ent.Client) (*ent.Mailbox, error) { return db.GetMailboxByName(ctx, client, name) }) if err != nil { @@ -341,7 +341,7 @@ func (m *Mailbox) Move(ctx context.Context, seq []command.SeqRange, name string) msgIDs[i] = snapMsg.ID } - destUIDs, err := db.WriteResult(ctx, m.state.db(), func(ctx context.Context, tx *ent.Tx) ([]db.UIDWithFlags, error) { + destUIDs, err := stateDBWriteResult(ctx, m.state, func(ctx context.Context, tx *ent.Tx) ([]Update, []db.UIDWithFlags, error) { if m.state.user.GetRecoveryMailboxID().InternalID == m.snap.mboxID.InternalID { return m.state.actionMoveMessagesOutOfRecoveryMailbox(ctx, tx, msgIDs, ids.NewMailboxIDPair(mbox)) } else { @@ -369,25 +369,19 @@ func (m *Mailbox) Store(ctx context.Context, seqSet []command.SeqRange, action c return err } - return m.state.db().Write(ctx, func(ctx context.Context, tx *ent.Tx) error { + return stateDBWrite(ctx, m.state, func(ctx context.Context, tx *ent.Tx) ([]Update, error) { switch action { case command.StoreActionAddFlags: - if err := m.state.actionAddMessageFlags(ctx, tx, messages, flags); err != nil { - return err - } + return m.state.actionAddMessageFlags(ctx, tx, messages, flags) case command.StoreActionRemFlags: - if err := m.state.actionRemoveMessageFlags(ctx, tx, messages, flags); err != nil { - return err - } + return m.state.actionRemoveMessageFlags(ctx, tx, messages, flags) case command.StoreActionSetFlags: - if err := m.state.actionSetMessageFlags(ctx, tx, messages, flags); err != nil { - return err - } + return m.state.actionSetMessageFlags(ctx, tx, messages, flags) } - return nil + return nil, fmt.Errorf("unknown flag action") }) } @@ -411,7 +405,7 @@ func (m *Mailbox) Expunge(ctx context.Context, seq []command.SeqRange) error { msgIDs = m.snap.getAllMessagesIDsMarkedDelete() } - return m.state.db().Write(ctx, func(ctx context.Context, tx *ent.Tx) error { + return stateDBWrite(ctx, m.state, func(ctx context.Context, tx *ent.Tx) ([]Update, error) { return m.state.actionRemoveMessagesFromMailbox(ctx, tx, msgIDs, m.snap.mboxID) }) } diff --git a/internal/state/mailbox_fetch.go b/internal/state/mailbox_fetch.go index 0fa379e0..fc2e8134 100644 --- a/internal/state/mailbox_fetch.go +++ b/internal/state/mailbox_fetch.go @@ -118,7 +118,7 @@ func (m *Mailbox) Fetch(ctx context.Context, cmd *command.Fetch, ch chan respons defer async.HandlePanic(m.state.panicHandler) msg := snapMessages[i] - message, err := db.ReadResult(ctx, m.state.db(), func(ctx context.Context, client *ent.Client) (*ent.Message, error) { + message, err := stateDBReadResult(ctx, m.state, func(ctx context.Context, client *ent.Client) (*ent.Message, error) { return db.GetMessage(ctx, client, msg.ID.InternalID) }) if err != nil { @@ -175,7 +175,7 @@ func (m *Mailbox) Fetch(ctx context.Context, cmd *command.Fetch, ch chan respons }) if len(msgsToBeMarkedSeen) != 0 { - if err := m.state.db().Write(ctx, func(ctx context.Context, tx *ent.Tx) error { + if err := stateDBWrite(ctx, m.state, func(ctx context.Context, tx *ent.Tx) ([]Update, error) { return m.state.actionAddMessageFlags(ctx, tx, msgsToBeMarkedSeen, imap.NewFlagSet(imap.FlagSeen)) }); err != nil { return err diff --git a/internal/state/mailbox_search.go b/internal/state/mailbox_search.go index 6a4cce41..038a18d6 100644 --- a/internal/state/mailbox_search.go +++ b/internal/state/mailbox_search.go @@ -88,7 +88,7 @@ func buildSearchData(ctx context.Context, m *Mailbox, op *buildSearchOpResult, m data := searchData{message: message} if op.needsMessage { - dbm, err := db.ReadResult(ctx, m.state.db(), func(ctx context.Context, client *ent.Client) (*ent.Message, error) { + dbm, err := stateDBReadResult(ctx, m.state, func(ctx context.Context, client *ent.Client) (*ent.Message, error) { return db.GetMessageDateAndSize(ctx, client, message.ID.InternalID) }) if err != nil { diff --git a/internal/state/state.go b/internal/state/state.go index d113fde0..3b559bab 100644 --- a/internal/state/state.go +++ b/internal/state/state.go @@ -76,12 +76,8 @@ func (state *State) UserID() string { return state.user.GetUserID() } -func (state *State) db() *db.DB { - return state.user.GetDB() -} - func (state *State) List(ctx context.Context, ref, pattern string, lsub bool, fn func(map[string]Match) error) error { - return state.db().Read(ctx, func(ctx context.Context, client *ent.Client) error { + return stateDBRead(ctx, state, func(ctx context.Context, client *ent.Client) error { mailboxes, err := db.GetAllMailboxes(ctx, client) if err != nil { return err @@ -168,7 +164,7 @@ func (state *State) List(ctx context.Context, ref, pattern string, lsub bool, fn } func (state *State) Select(ctx context.Context, name string, fn func(*Mailbox) error) error { - mbox, err := db.ReadResult(ctx, state.db(), func(ctx context.Context, client *ent.Client) (*ent.Mailbox, error) { + mbox, err := stateDBReadResult(ctx, state, func(ctx context.Context, client *ent.Client) (*ent.Mailbox, error) { return db.GetMailboxByName(ctx, client, name) }) if err != nil { @@ -181,15 +177,15 @@ func (state *State) Select(ctx context.Context, name string, fn func(*Mailbox) e } } - snap, err := db.ReadResult(ctx, state.db(), func(ctx context.Context, client *ent.Client) (*snapshot, error) { + snap, err := stateDBReadResult(ctx, state, func(ctx context.Context, client *ent.Client) (*snapshot, error) { return newSnapshot(ctx, state, client, mbox) }) if err != nil { return err } - if err := state.db().Write(ctx, func(ctx context.Context, tx *ent.Tx) error { - return db.ClearRecentFlags(ctx, tx, mbox.ID) + if err := stateDBWrite(ctx, state, func(ctx context.Context, tx *ent.Tx) ([]Update, error) { + return nil, db.ClearRecentFlags(ctx, tx, mbox.ID) }); err != nil { return err } @@ -201,7 +197,7 @@ func (state *State) Select(ctx context.Context, name string, fn func(*Mailbox) e } func (state *State) Examine(ctx context.Context, name string, fn func(*Mailbox) error) error { - mbox, err := db.ReadResult(ctx, state.db(), func(ctx context.Context, client *ent.Client) (*ent.Mailbox, error) { + mbox, err := stateDBReadResult(ctx, state, func(ctx context.Context, client *ent.Client) (*ent.Mailbox, error) { return db.GetMailboxByName(ctx, client, name) }) if err != nil { @@ -214,7 +210,7 @@ func (state *State) Examine(ctx context.Context, name string, fn func(*Mailbox) } } - snap, err := db.ReadResult(ctx, state.db(), func(ctx context.Context, client *ent.Client) (*snapshot, error) { + snap, err := stateDBReadResult(ctx, state, func(ctx context.Context, client *ent.Client) (*snapshot, error) { return newSnapshot(ctx, state, client, mbox) }) if err != nil { @@ -251,13 +247,13 @@ func (state *State) Create(ctx context.Context, name string) error { } } - return state.db().Write(ctx, func(ctx context.Context, tx *ent.Tx) error { + return stateDBWrite(ctx, state, func(ctx context.Context, tx *ent.Tx) ([]Update, error) { client := tx.Client() if mailboxCount, err := db.GetMailboxCount(ctx, client); err != nil { - return err + return nil, err } else if err := state.imapLimits.CheckMailBoxCount(mailboxCount); err != nil { - return err + return nil, err } var mboxesToCreate []string @@ -268,14 +264,14 @@ func (state *State) Create(ctx context.Context, name string) error { } if exists, err := db.MailboxExistsWithName(ctx, client, name); err != nil { - return err + return nil, err } else if exists { - return ErrExistingMailbox + return nil, ErrExistingMailbox } for _, superior := range listSuperiors(name, state.delimiter) { if exists, err := db.MailboxExistsWithName(ctx, client, superior); err != nil { - return err + return nil, err } else if exists { continue } @@ -287,11 +283,11 @@ func (state *State) Create(ctx context.Context, name string) error { for _, mboxName := range mboxesToCreate { if err := state.actionCreateMailbox(ctx, tx, mboxName, uidValidity); err != nil { - return err + return nil, err } } - return nil + return nil, nil }) } @@ -301,13 +297,18 @@ func (state *State) Delete(ctx context.Context, name string) (bool, error) { return false, ErrOperationNotAllowed } - mboxID, err := db.WriteResult(ctx, state.db(), func(ctx context.Context, tx *ent.Tx) (imap.InternalMailboxID, error) { + mboxID, err := stateDBWriteResult(ctx, state, func(ctx context.Context, tx *ent.Tx) ([]Update, imap.InternalMailboxID, error) { mbox, err := db.GetMailboxByName(ctx, tx.Client(), name) if err != nil { - return 0, ErrNoSuchMailbox + return nil, 0, ErrNoSuchMailbox + } + + update, err := state.actionDeleteMailbox(ctx, tx, ids.NewMailboxIDPair(mbox)) + if err != nil { + return nil, 0, err } - return mbox.ID, state.actionDeleteMailbox(ctx, tx, ids.NewMailboxIDPair(mbox)) + return update, mbox.ID, nil }) if err != nil { return false, err @@ -326,27 +327,27 @@ func (state *State) Rename(ctx context.Context, oldName, newName string) error { return ErrOperationNotAllowed } - return state.db().Write(ctx, func(ctx context.Context, tx *ent.Tx) error { + return stateDBWrite(ctx, state, func(ctx context.Context, tx *ent.Tx) ([]Update, error) { client := tx.Client() mbox, err := db.GetMailboxByName(ctx, client, oldName) if err != nil { - return ErrNoSuchMailbox + return nil, ErrNoSuchMailbox } if exists, err := db.MailboxExistsWithName(ctx, client, newName); err != nil { - return err + return nil, err } else if exists { - return ErrExistingMailbox + return nil, ErrExistingMailbox } var mboxesToCreate []string for _, superior := range listSuperiors(newName, state.delimiter) { if exists, err := db.MailboxExistsWithName(ctx, client, superior); err != nil { - return err + return nil, err } else if exists { if superior == oldName { - return ErrExistingMailbox + return nil, ErrExistingMailbox } continue } @@ -357,16 +358,16 @@ func (state *State) Rename(ctx context.Context, oldName, newName string) error { for _, m := range mboxesToCreate { uidValidity, err := state.user.GenerateUIDValidity() if err != nil { - return err + return nil, err } res, err := state.user.GetRemote().CreateMailbox(ctx, strings.Split(m, state.delimiter)) if err != nil { - return err + return nil, err } if err := db.CreateMailboxIfNotExists(ctx, tx, res, state.delimiter, uidValidity); err != nil { - return err + return nil, err } } @@ -375,13 +376,13 @@ func (state *State) Rename(ctx context.Context, oldName, newName string) error { } if err := state.actionUpdateMailbox(ctx, tx, mbox.RemoteID, newName); err != nil { - return err + return nil, err } // Locally update all inferiors so we don't wait for update mailboxes, err := db.GetAllMailboxes(ctx, tx.Client()) if err != nil { - return err + return nil, err } inferiors := listInferiors(oldName, state.delimiter, xslices.Map(mailboxes, func(mailbox *ent.Mailbox) string { @@ -391,54 +392,54 @@ func (state *State) Rename(ctx context.Context, oldName, newName string) error { for _, inferior := range inferiors { mbox, err := db.GetMailboxByName(ctx, tx.Client(), inferior) if err != nil { - return ErrNoSuchMailbox + return nil, ErrNoSuchMailbox } newInferior := newName + strings.TrimPrefix(inferior, oldName) if err := db.RenameMailboxWithRemoteID(ctx, tx, mbox.RemoteID, newInferior); err != nil { - return err + return nil, err } } - return nil + return nil, nil }) } func (state *State) Subscribe(ctx context.Context, name string) error { - return state.db().Write(ctx, func(ctx context.Context, tx *ent.Tx) error { + return stateDBWrite(ctx, state, func(ctx context.Context, tx *ent.Tx) ([]Update, error) { mbox, err := db.GetMailboxByName(ctx, tx.Client(), name) if err != nil { - return ErrNoSuchMailbox + return nil, ErrNoSuchMailbox } if mbox.Subscribed { - return ErrAlreadySubscribed + return nil, ErrAlreadySubscribed } - return mbox.Update().SetSubscribed(true).Exec(ctx) + return nil, mbox.Update().SetSubscribed(true).Exec(ctx) }) } func (state *State) Unsubscribe(ctx context.Context, name string) error { - return state.db().Write(ctx, func(ctx context.Context, tx *ent.Tx) error { + return stateDBWrite(ctx, state, func(ctx context.Context, tx *ent.Tx) ([]Update, error) { mbox, err := db.GetMailboxByName(ctx, tx.Client(), name) if err != nil { // If mailbox does not exist, check that if it is present in the deleted subscription table if count, err := db.RemoveDeletedSubscriptionWithName(ctx, tx, name); err != nil { - return err + return nil, err } else if count == 0 { - return ErrNoSuchMailbox + return nil, ErrNoSuchMailbox } else { - return nil + return nil, nil } } if !mbox.Subscribed { - return ErrAlreadyUnsubscribed + return nil, ErrAlreadyUnsubscribed } - return mbox.Update().SetSubscribed(false).Exec(ctx) + return nil, mbox.Update().SetSubscribed(false).Exec(ctx) }) } @@ -454,7 +455,7 @@ func (state *State) Idle(ctx context.Context, fn func([]response.Response, chan } func (state *State) Mailbox(ctx context.Context, name string, fn func(*Mailbox) error) error { - mbox, err := db.ReadResult(ctx, state.db(), func(ctx context.Context, client *ent.Client) (*ent.Mailbox, error) { + mbox, err := stateDBReadResult(ctx, state, func(ctx context.Context, client *ent.Client) (*ent.Mailbox, error) { return db.GetMailboxByName(ctx, client, name) }) if err != nil { @@ -465,7 +466,7 @@ func (state *State) Mailbox(ctx context.Context, name string, fn func(*Mailbox) return fn(newMailbox(mbox, state, state.snap)) } - snap, err := db.ReadResult(ctx, state.db(), func(ctx context.Context, client *ent.Client) (*snapshot, error) { + snap, err := stateDBReadResult(ctx, state, func(ctx context.Context, client *ent.Client) (*snapshot, error) { return newSnapshot(ctx, state, client, mbox) }) if err != nil { @@ -483,7 +484,7 @@ func (state *State) AppendOnlyMailbox(ctx context.Context, name string, fn func( return ErrOperationNotAllowed } - mbox, err := db.ReadResult(ctx, state.db(), func(ctx context.Context, client *ent.Client) (*ent.Mailbox, error) { + mbox, err := stateDBReadResult(ctx, state, func(ctx context.Context, client *ent.Client) (*ent.Mailbox, error) { return db.GetMailboxByName(ctx, client, name) }) if err != nil { @@ -504,7 +505,7 @@ func (state *State) Selected(ctx context.Context, fn func(*Mailbox) error) error return ErrSessionNotSelected } - mbox, err := db.ReadResult(ctx, state.db(), func(ctx context.Context, client *ent.Client) (*ent.Mailbox, error) { + mbox, err := stateDBReadResult(ctx, state, func(ctx context.Context, client *ent.Client) (*ent.Mailbox, error) { return db.GetMailboxByID(ctx, client, state.snap.mboxID.InternalID) }) if err != nil { @@ -571,7 +572,7 @@ func (state *State) ApplyUpdate(ctx context.Context, update Update) error { return nil } - if err := state.db().Write(ctx, func(ctx context.Context, tx *ent.Tx) error { + if err := state.user.GetDB().Write(ctx, func(ctx context.Context, tx *ent.Tx) error { return update.Apply(ctx, tx, state) }); err != nil { reporter.MessageWithContext(ctx, @@ -598,29 +599,30 @@ func (state *State) markInvalid() { } // renameInbox creates a new mailbox and moves everything there. -func (state *State) renameInbox(ctx context.Context, tx *ent.Tx, inbox *ent.Mailbox, newName string) error { +func (state *State) renameInbox(ctx context.Context, tx *ent.Tx, inbox *ent.Mailbox, newName string) ([]Update, error) { uidValidity, err := state.user.GenerateUIDValidity() if err != nil { - return err + return nil, err } mbox, err := state.actionCreateAndGetMailbox(ctx, tx, newName, uidValidity) if err != nil { - return err + return nil, err } messageIDs, err := db.GetMailboxMessageIDPairs(ctx, tx.Client(), inbox.ID) if err != nil { - return err + return nil, err } mboxIDPair := ids.NewMailboxIDPair(mbox) - if _, err := state.actionMoveMessages(ctx, tx, messageIDs, ids.NewMailboxIDPair(inbox), mboxIDPair); err != nil { - return err + updates, _, err := state.actionMoveMessages(ctx, tx, messageIDs, ids.NewMailboxIDPair(inbox), mboxIDPair) + if err != nil { + return nil, err } - return nil + return updates, nil } func (state *State) beginIdle(ctx context.Context) ([]response.Response, error) { @@ -708,7 +710,7 @@ func (state *State) flushResponses(ctx context.Context, permitExpunge bool) ([]r } } - if err := state.db().Write(ctx, func(ctx context.Context, tx *ent.Tx) error { + if err := state.user.GetDB().Write(ctx, func(ctx context.Context, tx *ent.Tx) error { for _, update := range dbUpdates { if err := update.apply(ctx, tx); err != nil { return err @@ -830,3 +832,59 @@ func (state *State) close() error { return nil } + +func stateDBRead(ctx context.Context, state *State, fn func(context.Context, *ent.Client) error) error { + return state.user.GetDB().Read(ctx, fn) +} + +func stateDBReadResult[T any](ctx context.Context, state *State, fn func(context.Context, *ent.Client) (T, error)) (T, error) { + return db.ReadResult(ctx, state.user.GetDB(), fn) +} + +func stateDBWrite(ctx context.Context, state *State, fn func(context.Context, *ent.Tx) ([]Update, error)) error { + var updates []Update + + if err := state.user.GetDB().Write(ctx, func(ctx context.Context, tx *ent.Tx) error { + up, err := fn(ctx, tx) + updates = up + return err + }); err != nil { + return err + } + + // need to create a separate transaction for the state updates so that import changes get written first. + if len(updates) != 0 { + if err := state.user.GetDB().Write(ctx, func(ctx context.Context, tx *ent.Tx) error { + return state.user.QueueOrApplyStateUpdate(ctx, tx, updates...) + }); err != nil { + return err + } + } + + return nil +} + +func stateDBWriteResult[T any](ctx context.Context, state *State, fn func(context.Context, *ent.Tx) ([]Update, T, error)) (T, error) { + var updates []Update + + result, err := db.WriteResult(ctx, state.user.GetDB(), func(ctx context.Context, tx *ent.Tx) (T, error) { + up, val, err := fn(ctx, tx) + updates = up + return val, err + }) + if err != nil { + var t T + return t, err + } + + // need to create a separate transaction for the state updates so that import changes get written first. + if len(updates) != 0 { + if err := state.user.GetDB().Write(ctx, func(ctx context.Context, tx *ent.Tx) error { + return state.user.QueueOrApplyStateUpdate(ctx, tx, updates...) + }); err != nil { + return result, err + } + } + + return result, nil +} diff --git a/internal/state/updates.go b/internal/state/updates.go index 8dcb89db..d8e0b783 100644 --- a/internal/state/updates.go +++ b/internal/state/updates.go @@ -98,9 +98,9 @@ func (u *messageFlagsAddedStateUpdate) String() string { func (state *State) applyMessageFlagsAdded(ctx context.Context, tx *ent.Tx, messageIDs []imap.InternalMessageID, - addFlags imap.FlagSet) error { + addFlags imap.FlagSet) ([]Update, error) { if addFlags.ContainsUnchecked(imap.FlagRecentLowerCase) { - return fmt.Errorf("the recent flag is read-only") + return nil, fmt.Errorf("the recent flag is read-only") } // Since DB state can be more up to date then the flag state we should only emit add flag updates for values @@ -110,7 +110,7 @@ func (state *State) applyMessageFlagsAdded(ctx context.Context, curFlags, err := db.GetMessageFlags(ctx, client, messageIDs) if err != nil { - return err + return nil, err } // If setting messages as seen, only set those messages that aren't currently seen. @@ -125,7 +125,7 @@ func (state *State) applyMessageFlagsAdded(ctx context.Context, if len(messagesToApply) != 0 { if err := state.user.GetRemote().SetMessagesSeen(ctx, messagesToApply, true); err != nil { - return err + return nil, err } } } @@ -142,7 +142,7 @@ func (state *State) applyMessageFlagsAdded(ctx context.Context, if len(messagesToApply) != 0 { if err := state.user.GetRemote().SetMessagesFlagged(ctx, messagesToApply, true); err != nil { - return err + return nil, err } } } @@ -151,7 +151,7 @@ func (state *State) applyMessageFlagsAdded(ctx context.Context, if addFlags.ContainsUnchecked(imap.FlagDeletedLowerCase) { if err := db.SetDeletedFlag(ctx, tx, state.snap.mboxID.InternalID, messageIDs, true); err != nil { - return err + return nil, err } flagStateUpdate.addUpdate(newMessageFlagsAddedStateUpdate(imap.NewFlagSet(imap.FlagDeleted), state.snap.mboxID, messageIDs, state.StateID)) @@ -170,17 +170,13 @@ func (state *State) applyMessageFlagsAdded(ctx context.Context, } if err := db.AddMessageFlag(ctx, tx, messagesToFlag, flag); err != nil { - return err + return nil, err } flagStateUpdate.addUpdate(newMessageFlagsAddedStateUpdate(remainingFlags, state.snap.mboxID, messagesToFlag, state.StateID)) } - if err := state.user.QueueOrApplyStateUpdate(ctx, tx, flagStateUpdate); err != nil { - return err - } - - return nil + return []Update{flagStateUpdate}, nil } type messageFlagsRemovedStateUpdate struct { @@ -229,16 +225,19 @@ func (u *messageFlagsRemovedStateUpdate) String() string { } // applyMessageFlagsRemoved removes the flags from the given messages. -func (state *State) applyMessageFlagsRemoved(ctx context.Context, tx *ent.Tx, messageIDs []imap.InternalMessageID, remFlags imap.FlagSet) error { +func (state *State) applyMessageFlagsRemoved(ctx context.Context, + tx *ent.Tx, + messageIDs []imap.InternalMessageID, + remFlags imap.FlagSet) ([]Update, error) { if remFlags.ContainsUnchecked(imap.FlagRecentLowerCase) { - return fmt.Errorf("the recent flag is read-only") + return nil, fmt.Errorf("the recent flag is read-only") } client := tx.Client() curFlags, err := db.GetMessageFlags(ctx, client, messageIDs) if err != nil { - return err + return nil, err } // If setting messages as unseen, only set those messages that are currently seen. if remFlags.ContainsUnchecked(imap.FlagSeenLowerCase) { @@ -252,7 +251,7 @@ func (state *State) applyMessageFlagsRemoved(ctx context.Context, tx *ent.Tx, me if len(messagesToApply) != 0 { if err := state.user.GetRemote().SetMessagesSeen(ctx, messagesToApply, false); err != nil { - return err + return nil, err } } } @@ -269,7 +268,7 @@ func (state *State) applyMessageFlagsRemoved(ctx context.Context, tx *ent.Tx, me if len(messagesToApply) != 0 { if err := state.user.GetRemote().SetMessagesFlagged(ctx, messagesToApply, false); err != nil { - return err + return nil, err } } } @@ -278,7 +277,7 @@ func (state *State) applyMessageFlagsRemoved(ctx context.Context, tx *ent.Tx, me if remFlags.ContainsUnchecked(imap.FlagDeletedLowerCase) { if err := db.SetDeletedFlag(ctx, tx, state.snap.mboxID.InternalID, messageIDs, false); err != nil { - return err + return nil, err } flagStateUpdate.addUpdate(NewMessageFlagsRemovedStateUpdate(imap.NewFlagSet(imap.FlagDeleted), state.snap.mboxID, messageIDs, state.StateID)) @@ -297,17 +296,13 @@ func (state *State) applyMessageFlagsRemoved(ctx context.Context, tx *ent.Tx, me } if err := db.RemoveMessageFlag(ctx, tx, messagesToFlag, flag); err != nil { - return err + return nil, err } flagStateUpdate.addUpdate(NewMessageFlagsRemovedStateUpdate(remainingFlags, state.snap.mboxID, messagesToFlag, state.StateID)) } - if err := state.user.QueueOrApplyStateUpdate(ctx, tx, flagStateUpdate); err != nil { - return err - } - - return nil + return []Update{flagStateUpdate}, nil } type messageFlagsSetStateUpdate struct { @@ -356,18 +351,21 @@ func (u *messageFlagsSetStateUpdate) String() string { } // applyMessageFlagsSet sets the flags of the given messages. -func (state *State) applyMessageFlagsSet(ctx context.Context, tx *ent.Tx, messageIDs []imap.InternalMessageID, setFlags imap.FlagSet) error { +func (state *State) applyMessageFlagsSet(ctx context.Context, + tx *ent.Tx, + messageIDs []imap.InternalMessageID, + setFlags imap.FlagSet) ([]Update, error) { if setFlags.ContainsUnchecked(imap.FlagRecentLowerCase) { - return fmt.Errorf("the recent flag is read-only") + return nil, fmt.Errorf("the recent flag is read-only") } if state.snap == nil { - return nil + return nil, nil } curFlags, err := db.GetMessageFlags(ctx, tx.Client(), messageIDs) if err != nil { - return err + return nil, err } // If setting messages as seen, only set those messages that aren't currently seen, and vice versa. @@ -381,7 +379,7 @@ func (state *State) applyMessageFlagsSet(ctx context.Context, tx *ent.Tx, messag for seen, messageIDs := range setSeen { if err := state.user.GetRemote().SetMessagesSeen(ctx, messageIDs, seen); err != nil { - return err + return nil, err } } @@ -396,23 +394,19 @@ func (state *State) applyMessageFlagsSet(ctx context.Context, tx *ent.Tx, messag for flagged, messageIDs := range setFlagged { if err := state.user.GetRemote().SetMessagesFlagged(ctx, messageIDs, flagged); err != nil { - return err + return nil, err } } if err := db.SetDeletedFlag(ctx, tx, state.snap.mboxID.InternalID, messageIDs, setFlags.Contains(imap.FlagDeleted)); err != nil { - return err + return nil, err } if err := db.SetMessageFlags(ctx, tx, messageIDs, setFlags.Remove(imap.FlagDeleted)); err != nil { - return err - } - - if err := state.user.QueueOrApplyStateUpdate(ctx, tx, NewMessageFlagsSetStateUpdate(setFlags, state.snap.mboxID, messageIDs, state.StateID)); err != nil { - return err + return nil, err } - return nil + return []Update{NewMessageFlagsSetStateUpdate(setFlags, state.snap.mboxID, messageIDs, state.StateID)}, nil } type mailboxRemoteIDUpdateStateUpdate struct { diff --git a/internal/state/updates_remote.go b/internal/state/updates_remote.go index 8aa4aebb..2f4e3d5f 100644 --- a/internal/state/updates_remote.go +++ b/internal/state/updates_remote.go @@ -7,7 +7,6 @@ import ( "github.com/ProtonMail/gluon/imap" "github.com/ProtonMail/gluon/internal/contexts" "github.com/ProtonMail/gluon/internal/db/ent" - "github.com/ProtonMail/gluon/internal/ids" ) type RemoteAddMessageFlagsStateUpdate struct { @@ -54,21 +53,3 @@ type RemoteMessageDeletedStateUpdate struct { MessageIDStateFilter remoteID imap.MessageID } - -func NewRemoteMessageDeletedStateUpdate(messageID imap.InternalMessageID, remoteID imap.MessageID) Update { - return &RemoteMessageDeletedStateUpdate{ - MessageIDStateFilter: MessageIDStateFilter{MessageID: messageID}, - remoteID: remoteID, - } -} - -func (u *RemoteMessageDeletedStateUpdate) Apply(ctx context.Context, tx *ent.Tx, s *State) error { - return s.actionRemoveMessagesFromMailbox(ctx, tx, []ids.MessageIDPair{{ - InternalID: u.MessageID, - RemoteID: u.remoteID, - }}, s.snap.mboxID) -} - -func (u *RemoteMessageDeletedStateUpdate) String() string { - return fmt.Sprintf("RemoteMessageDeletedStateUpdate %v remote ID = %v", u.MessageIDStateFilter.String(), u.remoteID) -}