diff --git a/.github/workflows/pull-request.yml b/.github/workflows/pull-request.yml index 4a1ca858..b6c0c09f 100644 --- a/.github/workflows/pull-request.yml +++ b/.github/workflows/pull-request.yml @@ -14,32 +14,18 @@ jobs: - name: Get sources uses: actions/checkout@v3 - - name: Set up Go 1.18 + - name: Set up Go 1.20 uses: actions/setup-go@v3 with: - go-version: '1.18' - - - name: Remove old static libs if modified - if: needs.check.outputs.changed == 'true' - run: | - rm -r internal/parser/lib - - - name: Download new static libs if modified - if: needs.check.outputs.changed == 'true' - uses: actions/download-artifact@v3 - with: - name: ${{ matrix.os }}-libs - path: internal/parser/lib + go-version: '1.20' - name: Run go mod tidy run: go mod tidy - name: Run golangci-lint - uses: golangci/golangci-lint-action@v3 - with: - version: v1.51.1 - args: --timeout=500s - skip-cache: true + run: | + go install github.com/golangci/golangci-lint/cmd/golangci-lint@251ceaa228607dd3e0371694a1ab2c45d21cb744 + golangci-lint run --timeout=500s - name: Run tests run: go test -timeout 15m -v ./... diff --git a/.golangci.yml b/.golangci.yml index 6abe3cd3..cc70d56d 100644 --- a/.golangci.yml +++ b/.golangci.yml @@ -13,6 +13,7 @@ linters: disable: - godox # Annoying, we have too many TODOs at the moment :p - scopelint # Deprecated, replaced by exportloopref, which is enabled by default. + - errorlint # Too many false positives issues: exclude-rules: diff --git a/builder.go b/builder.go index 864388fd..88d38315 100644 --- a/builder.go +++ b/builder.go @@ -7,9 +7,10 @@ import ( "time" "github.com/ProtonMail/gluon/async" + "github.com/ProtonMail/gluon/db" "github.com/ProtonMail/gluon/imap" "github.com/ProtonMail/gluon/internal/backend" - "github.com/ProtonMail/gluon/internal/db" + "github.com/ProtonMail/gluon/internal/db_impl/ent_db" "github.com/ProtonMail/gluon/internal/session" "github.com/ProtonMail/gluon/limits" "github.com/ProtonMail/gluon/profiling" @@ -36,6 +37,7 @@ type serverBuilder struct { imapLimits limits.IMAP uidValidityGenerator imap.UIDValidityGenerator panicHandler async.PanicHandler + dbCI db.ClientInterface } func newBuilder() (*serverBuilder, error) { @@ -48,6 +50,7 @@ func newBuilder() (*serverBuilder, error) { imapLimits: limits.DefaultLimits(), uidValidityGenerator: imap.DefaultEpochUIDValidityGenerator(), panicHandler: async.NoopPanicHandler{}, + dbCI: ent_db.NewEntDBBuilder(), }, nil } @@ -86,6 +89,7 @@ func (builder *serverBuilder) build() (*Server, error) { builder.loginJailTime, builder.imapLimits, builder.panicHandler, + builder.dbCI, ) if err != nil { return nil, err diff --git a/connector/mock_connector/connector.go b/connector/mock_connector/connector.go index b08a5d17..ef0de4c9 100644 --- a/connector/mock_connector/connector.go +++ b/connector/mock_connector/connector.go @@ -123,18 +123,33 @@ func (mr *MockConnectorMockRecorder) DeleteMailbox(arg0, arg1 interface{}) *gomo return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteMailbox", reflect.TypeOf((*MockConnector)(nil).DeleteMailbox), arg0, arg1) } -// GetUIDValidity mocks base method. -func (m *MockConnector) GetUIDValidity() imap.UID { +// GetMailboxVisibility mocks base method. +func (m *MockConnector) GetMailboxVisibility(arg0 context.Context, arg1 imap.MailboxID) imap.MailboxVisibility { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GetUIDValidity") - ret0, _ := ret[0].(imap.UID) + ret := m.ctrl.Call(m, "GetMailboxVisibility", arg0, arg1) + ret0, _ := ret[0].(imap.MailboxVisibility) return ret0 } -// GetUIDValidity indicates an expected call of GetUIDValidity. -func (mr *MockConnectorMockRecorder) GetUIDValidity() *gomock.Call { +// GetMailboxVisibility indicates an expected call of GetMailboxVisibility. +func (mr *MockConnectorMockRecorder) GetMailboxVisibility(arg0, arg1 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetUIDValidity", reflect.TypeOf((*MockConnector)(nil).GetUIDValidity)) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetMailboxVisibility", reflect.TypeOf((*MockConnector)(nil).GetMailboxVisibility), arg0, arg1) +} + +// GetMessageLiteral mocks base method. +func (m *MockConnector) GetMessageLiteral(arg0 context.Context, arg1 imap.MessageID) ([]byte, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetMessageLiteral", arg0, arg1) + ret0, _ := ret[0].([]byte) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetMessageLiteral indicates an expected call of GetMessageLiteral. +func (mr *MockConnectorMockRecorder) GetMessageLiteral(arg0, arg1 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetMessageLiteral", reflect.TypeOf((*MockConnector)(nil).GetMessageLiteral), arg0, arg1) } // GetUpdates mocks base method. @@ -151,20 +166,6 @@ func (mr *MockConnectorMockRecorder) GetUpdates() *gomock.Call { return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetUpdates", reflect.TypeOf((*MockConnector)(nil).GetUpdates)) } -// IsMailboxVisible mocks base method. -func (m *MockConnector) IsMailboxVisible(arg0 context.Context, arg1 imap.MailboxID) bool { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "IsMailboxVisible", arg0, arg1) - ret0, _ := ret[0].(bool) - return ret0 -} - -// IsMailboxVisible indicates an expected call of IsMailboxVisible. -func (mr *MockConnectorMockRecorder) IsMailboxVisible(arg0, arg1 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "IsMailboxVisible", reflect.TypeOf((*MockConnector)(nil).IsMailboxVisible), arg0, arg1) -} - // MarkMessagesFlagged mocks base method. func (m *MockConnector) MarkMessagesFlagged(arg0 context.Context, arg1 []imap.MessageID, arg2 bool) error { m.ctrl.T.Helper() @@ -194,11 +195,12 @@ func (mr *MockConnectorMockRecorder) MarkMessagesSeen(arg0, arg1, arg2 interface } // MoveMessages mocks base method. -func (m *MockConnector) MoveMessages(arg0 context.Context, arg1 []imap.MessageID, arg2, arg3 imap.MailboxID) error { +func (m *MockConnector) MoveMessages(arg0 context.Context, arg1 []imap.MessageID, arg2, arg3 imap.MailboxID) (bool, error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "MoveMessages", arg0, arg1, arg2, arg3) - ret0, _ := ret[0].(error) - return ret0 + ret0, _ := ret[0].(bool) + ret1, _ := ret[1].(error) + return ret0, ret1 } // MoveMessages indicates an expected call of MoveMessages. @@ -221,20 +223,6 @@ func (mr *MockConnectorMockRecorder) RemoveMessagesFromMailbox(arg0, arg1, arg2 return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RemoveMessagesFromMailbox", reflect.TypeOf((*MockConnector)(nil).RemoveMessagesFromMailbox), arg0, arg1, arg2) } -// SetUIDValidity mocks base method. -func (m *MockConnector) SetUIDValidity(arg0 imap.UID) error { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "SetUIDValidity", arg0) - ret0, _ := ret[0].(error) - return ret0 -} - -// SetUIDValidity indicates an expected call of SetUIDValidity. -func (mr *MockConnectorMockRecorder) SetUIDValidity(arg0 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetUIDValidity", reflect.TypeOf((*MockConnector)(nil).SetUIDValidity), arg0) -} - // UpdateMailboxName mocks base method. func (m *MockConnector) UpdateMailboxName(arg0 context.Context, arg1 imap.MailboxID, arg2 []string) error { m.ctrl.T.Helper() diff --git a/db/client.go b/db/client.go new file mode 100644 index 00000000..ef43ed6c --- /dev/null +++ b/db/client.go @@ -0,0 +1,52 @@ +package db + +import ( + "context" + "path/filepath" +) + +const ChunkLimit = 1000 + +type Client interface { + Init(ctx context.Context) error + Read(ctx context.Context, op func(context.Context, ReadOnly) error) error + Write(ctx context.Context, op func(context.Context, Transaction) error) error + Close() error +} + +type ClientInterface interface { + New(path string, userID string) (Client, bool, error) + Delete(path string, userID string) error +} + +func GetDeferredDeleteDBPath(dir string) string { + return filepath.Join(dir, "deferred_delete") +} + +func ClientReadType[T any](ctx context.Context, c Client, op func(context.Context, ReadOnly) (T, error)) (T, error) { + var result T + + err := c.Read(ctx, func(ctx context.Context, read ReadOnly) error { + var err error + + result, err = op(ctx, read) + + return err + }) + + return result, err +} + +func ClientWriteType[T any](ctx context.Context, c Client, op func(context.Context, Transaction) (T, error)) (T, error) { + var result T + + err := c.Write(ctx, func(ctx context.Context, t Transaction) error { + var err error + + result, err = op(ctx, t) + + return err + }) + + return result, err +} diff --git a/db/deferred_delete.go b/db/deferred_delete.go new file mode 100644 index 00000000..d110aeae --- /dev/null +++ b/db/deferred_delete.go @@ -0,0 +1,48 @@ +package db + +import ( + "errors" + "fmt" + "os" + "path/filepath" + + "github.com/google/uuid" +) + +// DeleteDB will rename all the database files for the given user to a directory within the same folder to avoid +// issues with ent not being able to close the database on demand. The database will be cleaned up on the next +// run on the Gluon server. +func DeleteDB(dir, userID string) error { + // Rather than deleting the files immediately move them to a directory to be cleaned up later. + deferredDeletePath := GetDeferredDeleteDBPath(dir) + + if err := os.MkdirAll(deferredDeletePath, 0o700); err != nil { + return fmt.Errorf("failed to create deferred delete dir: %w", err) + } + + matchingFiles, err := filepath.Glob(filepath.Join(dir, userID+"*")) + if err != nil { + return fmt.Errorf("failed to match db files:%w", err) + } + + for _, file := range matchingFiles { + // Use new UUID to avoid conflict with existing files + if err := os.Rename(file, filepath.Join(deferredDeletePath, uuid.NewString())); err != nil { + return fmt.Errorf("failed to move db file '%v' :%w", file, err) + } + } + + return nil +} + +// DeleteDeferredDBFiles deletes all data from previous databases that were scheduled for removal. +func DeleteDeferredDBFiles(dir string) error { + deferredDeleteDir := GetDeferredDeleteDBPath(dir) + if err := os.RemoveAll(deferredDeleteDir); err != nil { + if !errors.Is(err, os.ErrNotExist) { + return err + } + } + + return nil +} diff --git a/db/errors.go b/db/errors.go new file mode 100644 index 00000000..d50f494c --- /dev/null +++ b/db/errors.go @@ -0,0 +1,14 @@ +package db + +import "errors" + +var ErrNotFound = errors.New("value not found") +var ErrTransactionFailed = errors.New("transaction failed") + +func IsErrNotFound(err error) bool { + if err == nil { + return false + } + + return errors.Is(err, ErrNotFound) +} diff --git a/db/ops.go b/db/ops.go new file mode 100644 index 00000000..cd950cf2 --- /dev/null +++ b/db/ops.go @@ -0,0 +1,13 @@ +package db + +type ReadOnly interface { + MailboxReadOps + MessageReadOps + SubscriptionReadOps +} + +type Transaction interface { + MailboxWriteOps + MessageWriteOps + SubscriptionWriteOps +} diff --git a/db/ops_mailbox.go b/db/ops_mailbox.go new file mode 100644 index 00000000..4c4e8557 --- /dev/null +++ b/db/ops_mailbox.go @@ -0,0 +1,171 @@ +package db + +import ( + "context" + "strings" + + "github.com/ProtonMail/gluon/imap" +) + +type MailboxReadOps interface { + MailboxExistsWithID(ctx context.Context, mboxID imap.InternalMailboxID) (bool, error) + + MailboxExistsWithRemoteID(ctx context.Context, mboxID imap.MailboxID) (bool, error) + + MailboxExistsWithName(ctx context.Context, name string) (bool, error) + + GetMailboxIDFromRemoteID(ctx context.Context, mboxID imap.MailboxID) (imap.InternalMailboxID, error) + + GetMailboxName(ctx context.Context, mboxID imap.InternalMailboxID) (string, error) + + GetMailboxNameWithRemoteID(ctx context.Context, mboxID imap.MailboxID) (string, error) + + GetMailboxMessageIDPairs(ctx context.Context, mboxID imap.InternalMailboxID) ([]MessageIDPair, error) + + GetAllMailboxes(ctx context.Context) ([]*Mailbox, error) + + GetAllMailboxesAsRemoteIDs(ctx context.Context) ([]imap.MailboxID, error) + + GetMailboxByName(ctx context.Context, name string) (*Mailbox, error) + + GetMailboxByID(ctx context.Context, mboxID imap.InternalMailboxID) (*Mailbox, error) + + GetMailboxByRemoteID(ctx context.Context, mboxID imap.MailboxID) (*Mailbox, error) + + GetMailboxRecentCount(ctx context.Context, mboxID imap.InternalMailboxID) (int, error) + + GetMailboxMessageCount(ctx context.Context, mboxID imap.InternalMailboxID) (int, error) + + GetMailboxMessageCountWithRemoteID(ctx context.Context, mboxID imap.MailboxID) (int, error) + + GetMailboxFlags(ctx context.Context, mboxID imap.InternalMailboxID) (imap.FlagSet, error) + + GetMailboxPermanentFlags(ctx context.Context, mboxID imap.InternalMailboxID) (imap.FlagSet, error) + + GetMailboxAttributes(ctx context.Context, mboxID imap.InternalMailboxID) (imap.FlagSet, error) + + GetMailboxUID(ctx context.Context, mboxID imap.InternalMailboxID) (imap.UID, error) + + GetMailboxMessageCountAndUID(ctx context.Context, mboxID imap.InternalMailboxID) (int, imap.UID, error) + + GetMailboxMessageForNewSnapshot(ctx context.Context, mboxID imap.InternalMailboxID) ([]SnapshotMessageResult, error) + + MailboxTranslateRemoteIDs(ctx context.Context, mboxIDs []imap.MailboxID) ([]imap.InternalMailboxID, error) + + MailboxFilterContains(ctx context.Context, mboxID imap.InternalMailboxID, messageIDs []MessageIDPair) ([]imap.InternalMessageID, error) + + MailboxFilterContainsInternalID(ctx context.Context, mboxID imap.InternalMailboxID, messageIDs []imap.InternalMessageID) ([]imap.InternalMessageID, error) + + GetMailboxCount(ctx context.Context) (int, error) + + // GetMessageUIDsWithFlagsAfterAddOrUIDBump exploits a property of adding a message to or bumping the UIDs of existing message in mailbox. It can only be + // used if you can guarantee that the messageID list contains only IDs that have recently added or bumped in the mailbox. + GetMailboxMessageUIDsWithFlagsAfterAddOrUIDBump(ctx context.Context, mboxID imap.InternalMailboxID, messageIDs []imap.InternalMessageID) ([]UIDWithFlags, error) +} + +type MailboxWriteOps interface { + MailboxReadOps + + CreateMailbox(ctx context.Context, + mboxID imap.MailboxID, + name string, + flags, permFlags, attrs imap.FlagSet, + uidValidity imap.UID) (*Mailbox, error) + + GetOrCreateMailbox(ctx context.Context, + mboxID imap.MailboxID, + name string, + flags, permFlags, attrs imap.FlagSet, + uidValidity imap.UID) (*Mailbox, error) + + GetOrCreateMailboxAlt(ctx context.Context, + mbox imap.Mailbox, + delimiter string, + uidValidity imap.UID) (*Mailbox, error) + + RenameMailboxWithRemoteID(ctx context.Context, mboxID imap.MailboxID, name string) error + + DeleteMailboxWithRemoteID(ctx context.Context, mboxID imap.MailboxID) error + + BumpMailboxUIDNext(ctx context.Context, mboxID imap.InternalMailboxID, count int) error + + AddMessagesToMailbox(ctx context.Context, mboxID imap.InternalMailboxID, messageIDs []imap.InternalMessageID) ([]UIDWithFlags, error) + + BumpMailboxUIDsForMessage(ctx context.Context, mboxID imap.InternalMailboxID, messageIDs []imap.InternalMessageID) ([]UIDWithFlags, error) + + RemoveMessagesFromMailbox(ctx context.Context, mboxID imap.InternalMailboxID, messageIDs []imap.InternalMessageID) error + + ClearRecentFlagInMailboxOnMessage(ctx context.Context, mboxID imap.InternalMailboxID, messageID imap.InternalMessageID) error + + ClearRecentFlagsInMailbox(ctx context.Context, mboxID imap.InternalMailboxID) error + + CreateMailboxIfNotExists(ctx context.Context, mbox imap.Mailbox, delimiter string, uidValidity imap.UID) error + + SetMailboxMessagesDeletedFlag(ctx context.Context, mboxID imap.InternalMailboxID, messageIDs []imap.InternalMessageID, deleted bool) error + + SetMailboxSubscribed(ctx context.Context, mboxID imap.InternalMailboxID, subscribed bool) error + + UpdateRemoteMailboxID(ctx context.Context, mobxID imap.InternalMailboxID, remoteID imap.MailboxID) error + + SetMailboxUIDValidity(ctx context.Context, mboxID imap.InternalMailboxID, uidValidity imap.UID) error +} + +type SnapshotMessageResult struct { + InternalID imap.InternalMessageID `json:"uid_message"` + RemoteID imap.MessageID `json:"remote_id"` + UID imap.UID `json:"uid"` + Recent bool `json:"recent"` + Deleted bool `json:"deleted"` + Flags string `json:"flags"` +} + +func (msg *SnapshotMessageResult) GetFlagSet() imap.FlagSet { + var flagSet imap.FlagSet + + if len(msg.Flags) > 0 { + flags := strings.Split(msg.Flags, ",") + flagSet = imap.NewFlagSetFromSlice(flags) + } else { + flagSet = imap.NewFlagSet() + } + + if msg.Deleted { + flagSet.AddToSelf(imap.FlagDeleted) + } + + if msg.Recent { + flagSet.AddToSelf(imap.FlagRecent) + } + + return flagSet +} + +type UIDWithFlags struct { + InternalID imap.InternalMessageID `json:"uid_message"` + RemoteID imap.MessageID `json:"remote_id"` + UID imap.UID `json:"uid"` + Recent bool `json:"recent"` + Deleted bool `json:"deleted"` + Flags string `json:"flags"` +} + +func (u *UIDWithFlags) GetFlagSet() imap.FlagSet { + var flagSet imap.FlagSet + + if len(u.Flags) > 0 { + flags := strings.Split(u.Flags, ",") + flagSet = imap.NewFlagSetFromSlice(flags) + } else { + flagSet = imap.NewFlagSet() + } + + if u.Deleted { + flagSet.AddToSelf(imap.FlagDeleted) + } + + if u.Recent { + flagSet.AddToSelf(imap.FlagRecent) + } + + return flagSet +} diff --git a/db/ops_message.go b/db/ops_message.go new file mode 100644 index 00000000..cd67ab17 --- /dev/null +++ b/db/ops_message.go @@ -0,0 +1,92 @@ +package db + +import ( + "context" + "time" + + "github.com/ProtonMail/gluon/imap" + "github.com/bradenaw/juniper/xslices" +) + +type MessageReadOps interface { + MessageExists(ctx context.Context, id imap.InternalMessageID) (bool, error) + + MessageExistsWithRemoteID(ctx context.Context, id imap.MessageID) (bool, error) + + GetMessage(ctx context.Context, id imap.InternalMessageID) (*Message, error) + + GetTotalMessageCount(ctx context.Context) (int, error) + + GetMessageRemoteID(ctx context.Context, id imap.InternalMessageID) (imap.MessageID, error) + + GetImportedMessageData(ctx context.Context, id imap.InternalMessageID) (*Message, error) + + GetMessageDateAndSize(ctx context.Context, id imap.InternalMessageID) (time.Time, int, error) + + GetMessageMailboxIDs(ctx context.Context, id imap.InternalMessageID) ([]imap.InternalMailboxID, error) + + GetMessagesFlags(ctx context.Context, ids []imap.InternalMessageID) ([]MessageFlagSet, error) + + GetMessageIDsMarkedAsDelete(ctx context.Context) ([]imap.InternalMessageID, error) + + GetMessageIDFromRemoteID(ctx context.Context, id imap.MessageID) (imap.InternalMessageID, error) + + GetMessageDeletedFlag(ctx context.Context, id imap.InternalMessageID) (bool, error) + + GetAllMessagesIDsAsMap(ctx context.Context) (map[imap.InternalMessageID]struct{}, error) +} + +type MessageWriteOps interface { + MessageReadOps + + CreateMessages(ctx context.Context, reqs ...*CreateMessageReq) ([]*Message, error) + + CreateMessageAndAddToMailbox(ctx context.Context, mbox imap.InternalMailboxID, req *CreateMessageReq) (imap.UID, imap.FlagSet, error) + + MarkMessageAsDeleted(ctx context.Context, id imap.InternalMessageID) error + + MarkMessageAsDeletedAndAssignRandomRemoteID(ctx context.Context, id imap.InternalMessageID) error + + MarkMessageAsDeletedWithRemoteID(ctx context.Context, id imap.MessageID) error + + DeleteMessages(ctx context.Context, ids []imap.InternalMessageID) error + + UpdateRemoteMessageID(ctx context.Context, internalID imap.InternalMessageID, remoteID imap.MessageID) error + + AddFlagToMessages(ctx context.Context, ids []imap.InternalMessageID, flag string) error + + RemoveFlagFromMessages(ctx context.Context, ids []imap.InternalMessageID, flag string) error + + SetFlagsOnMessages(ctx context.Context, ids []imap.InternalMessageID, flags imap.FlagSet) error +} + +type CreateMessageReq struct { + Message imap.Message + InternalID imap.InternalMessageID + LiteralSize int + Body string + Structure string + Envelope string +} + +type MessageFlagSet struct { + ID imap.InternalMessageID + RemoteID imap.MessageID + FlagSet imap.FlagSet +} + +func NewFlagSet(msgUID *UID, flags []*MessageFlag) imap.FlagSet { + flagSet := imap.NewFlagSetFromSlice(xslices.Map(flags, func(flag *MessageFlag) string { + return flag.Value + })) + + if msgUID.Deleted { + flagSet.AddToSelf(imap.FlagDeleted) + } + + if msgUID.Recent { + flagSet.AddToSelf(imap.FlagRecent) + } + + return flagSet +} diff --git a/db/ops_subscription.go b/db/ops_subscription.go new file mode 100644 index 00000000..b79d4a83 --- /dev/null +++ b/db/ops_subscription.go @@ -0,0 +1,16 @@ +package db + +import ( + "context" + + "github.com/ProtonMail/gluon/imap" +) + +type SubscriptionReadOps interface { + GetDeletedSubscriptionSet(ctx context.Context) (map[imap.MailboxID]*DeletedSubscription, error) +} + +type SubscriptionWriteOps interface { + AddDeletedSubscription(ctx context.Context, mboxName string, mboxID imap.MailboxID) error + RemoveDeletedSubscriptionWithName(ctx context.Context, mboxName string) (int, error) +} diff --git a/db/types.go b/db/types.go new file mode 100644 index 00000000..9d7147a0 --- /dev/null +++ b/db/types.go @@ -0,0 +1,126 @@ +package db + +import ( + "fmt" + "time" + + "github.com/ProtonMail/gluon/imap" +) + +type MailboxIDPair struct { + InternalID imap.InternalMailboxID + RemoteID imap.MailboxID +} + +func (m *MailboxIDPair) String() string { + return fmt.Sprintf("%v::%v", m.InternalID, m.RemoteID) +} + +type MessageIDPair struct { + InternalID imap.InternalMessageID + RemoteID imap.MessageID +} + +func (m *MessageIDPair) String() string { + return fmt.Sprintf("%v::%v", m.InternalID, m.RemoteID) +} + +func NewMailboxIDPair(mbox *Mailbox) MailboxIDPair { + return MailboxIDPair{ + InternalID: mbox.ID, + RemoteID: mbox.RemoteID, + } +} + +func NewMailboxIDPairWithoutRemote(internalID imap.InternalMailboxID) MailboxIDPair { + return MailboxIDPair{ + InternalID: internalID, + RemoteID: "", + } +} + +func NewMessageIDPair(msg *Message) MessageIDPair { + return MessageIDPair{ + InternalID: msg.ID, + RemoteID: msg.RemoteID, + } +} + +func SplitMessageIDPairSlice(s []MessageIDPair) ([]imap.InternalMessageID, []imap.MessageID) { + l := len(s) + + internalMessageIDs := make([]imap.InternalMessageID, 0, l) + remoteMessageIDs := make([]imap.MessageID, 0, l) + + for _, v := range s { + internalMessageIDs = append(internalMessageIDs, v.InternalID) + remoteMessageIDs = append(remoteMessageIDs, v.RemoteID) + } + + return internalMessageIDs, remoteMessageIDs +} + +func SplitMailboxIDPairSlice(s []MailboxIDPair) ([]imap.InternalMailboxID, []imap.MailboxID) { + l := len(s) + + internalMailboxIDs := make([]imap.InternalMailboxID, 0, l) + mailboxIDs := make([]imap.MailboxID, 0, l) + + for _, v := range s { + internalMailboxIDs = append(internalMailboxIDs, v.InternalID) + mailboxIDs = append(mailboxIDs, v.RemoteID) + } + + return internalMailboxIDs, mailboxIDs +} + +type MailboxFlag struct { + ID int + Value string +} + +type MailboxAttr struct { + ID int + Value string +} + +type Mailbox struct { + ID imap.InternalMailboxID + RemoteID imap.MailboxID + Name string + UIDNext imap.UID + UIDValidity imap.UID + Subscribed bool + Flags []*MailboxFlag + PermanentFlags []*MailboxFlag + Attributes []*MailboxAttr +} + +type MessageFlag struct { + ID int + Value string +} + +type Message struct { + ID imap.InternalMessageID + RemoteID imap.MessageID + Date time.Time + Size int + Body string + BodyStructure string + Envelope string + Deleted bool + Flags []*MessageFlag + UIDs []*UID +} + +type UID struct { + UID imap.UID + Deleted bool + Recent bool +} + +type DeletedSubscription struct { + Name string + RemoteID imap.MailboxID +} diff --git a/go.mod b/go.mod index d2fd5a41..21dd5004 100644 --- a/go.mod +++ b/go.mod @@ -1,6 +1,6 @@ module github.com/ProtonMail/gluon -go 1.18 +go 1.20 require ( entgo.io/ent v0.11.8 diff --git a/internal/backend/backend.go b/internal/backend/backend.go index 8bceb802..5e487583 100644 --- a/internal/backend/backend.go +++ b/internal/backend/backend.go @@ -9,10 +9,8 @@ import ( "github.com/ProtonMail/gluon/async" "github.com/ProtonMail/gluon/connector" + "github.com/ProtonMail/gluon/db" "github.com/ProtonMail/gluon/imap" - "github.com/ProtonMail/gluon/internal/db" - "github.com/ProtonMail/gluon/internal/db/ent" - "github.com/ProtonMail/gluon/internal/db/ent/mailbox" "github.com/ProtonMail/gluon/internal/state" "github.com/ProtonMail/gluon/limits" "github.com/ProtonMail/gluon/reporter" @@ -51,10 +49,19 @@ type Backend struct { imapLimits limits.IMAP + database db.ClientInterface + panicHandler async.PanicHandler } -func New(dataDir, databaseDir string, storeBuilder store.Builder, delim string, loginJailTime time.Duration, imapLimits limits.IMAP, panicHandler async.PanicHandler) (*Backend, error) { +func New(dataDir, databaseDir string, + storeBuilder store.Builder, + delim string, + loginJailTime time.Duration, + imapLimits limits.IMAP, + panicHandler async.PanicHandler, + database db.ClientInterface, +) (*Backend, error) { return &Backend{ dataDir: dataDir, databaseDir: databaseDir, @@ -64,6 +71,7 @@ func New(dataDir, databaseDir string, storeBuilder store.Builder, delim string, loginJailTime: loginJailTime, imapLimits: imapLimits, panicHandler: panicHandler, + database: database, }, nil } @@ -86,7 +94,7 @@ func (b *Backend) AddUser(ctx context.Context, userID string, conn connector.Con return false, err } - db, isNew, err := db.NewDB(b.getDBDir(), userID) + db, isNew, err := b.database.New(b.getDBDir(), userID) if err != nil { if err := storeBuilder.Close(); err != nil { logrus.WithError(err).Error("Failed to close store builder") @@ -130,7 +138,7 @@ func (b *Backend) RemoveUser(ctx context.Context, userID string, removeFiles boo return err } - if err := db.DeleteDB(b.getDBDir(), userID); err != nil { + if err := b.database.Delete(b.getDBDir(), userID); err != nil { return err } } @@ -147,21 +155,21 @@ func (b *Backend) GetMailboxMessageCounts(ctx context.Context, userID string) (m return nil, ErrNoSuchUser } - return db.ReadResult(ctx, user.db, func(ctx context.Context, c *ent.Client) (map[imap.MailboxID]int, error) { + return db.ClientReadType(ctx, user.db, func(ctx context.Context, c db.ReadOnly) (map[imap.MailboxID]int, error) { counts := make(map[imap.MailboxID]int) - mailboxes, err := c.Mailbox.Query().Select(mailbox.FieldRemoteID).All(ctx) + mailboxes, err := c.GetAllMailboxesAsRemoteIDs(ctx) if err != nil { return nil, err } - for _, mailbox := range mailboxes { - messageCount, err := mailbox.QueryUIDs().Count(ctx) + for _, mailboxID := range mailboxes { + messageCount, err := c.GetMailboxMessageCountWithRemoteID(ctx, mailboxID) if err != nil { return nil, err } - counts[mailbox.RemoteID] = messageCount + counts[mailboxID] = messageCount } return counts, nil diff --git a/internal/backend/connector_updates.go b/internal/backend/connector_updates.go index bc240645..68604a7f 100644 --- a/internal/backend/connector_updates.go +++ b/internal/backend/connector_updates.go @@ -9,9 +9,8 @@ import ( "runtime" "strings" + "github.com/ProtonMail/gluon/db" "github.com/ProtonMail/gluon/imap" - "github.com/ProtonMail/gluon/internal/db" - "github.com/ProtonMail/gluon/internal/db/ent" "github.com/ProtonMail/gluon/internal/ids" "github.com/ProtonMail/gluon/internal/state" "github.com/ProtonMail/gluon/rfc822" @@ -79,8 +78,8 @@ func (user *user) applyMailboxCreated(ctx context.Context, update *imap.MailboxC return fmt.Errorf("attempting to create protected mailbox (recovery)") } - if exists, err := db.ReadResult(ctx, user.db, func(ctx context.Context, client *ent.Client) (bool, error) { - return db.MailboxExistsWithRemoteID(ctx, client, update.Mailbox.ID) + if exists, err := db.ClientReadType(ctx, user.db, func(ctx context.Context, client db.ReadOnly) (bool, error) { + return client.MailboxExistsWithRemoteID(ctx, update.Mailbox.ID) }); err != nil { return err } else if exists { @@ -96,16 +95,15 @@ func (user *user) applyMailboxCreated(ctx context.Context, update *imap.MailboxC return err } - return user.db.Write(ctx, func(ctx context.Context, tx *ent.Tx) error { - if mailboxCount, err := db.GetMailboxCount(ctx, tx.Client()); err != nil { + return user.db.Write(ctx, func(ctx context.Context, tx db.Transaction) error { + if mailboxCount, err := tx.GetMailboxCount(ctx); err != nil { return err } else if err := user.imapLimits.CheckMailBoxCount(mailboxCount); err != nil { return err } - if _, err := db.CreateMailbox( + if _, err := tx.CreateMailbox( ctx, - tx, update.Mailbox.ID, strings.Join(update.Mailbox.Name, user.delimiter), update.Mailbox.Flags, @@ -126,21 +124,21 @@ func (user *user) applyMailboxDeleted(ctx context.Context, update *imap.MailboxD return fmt.Errorf("attempting to delete protected mailbox (recovery)") } - stateUpdate, err := db.WriteResult(ctx, user.db, func(ctx context.Context, tx *ent.Tx) (state.Update, error) { - mailbox, err := db.GetMailboxByRemoteID(ctx, tx.Client(), update.MailboxID) + stateUpdate, err := db.ClientWriteType(ctx, user.db, func(ctx context.Context, tx db.Transaction) (state.Update, error) { + mailbox, err := tx.GetMailboxByRemoteID(ctx, update.MailboxID) if err != nil { - if ent.IsNotFound(err) { + if db.IsErrNotFound(err) { return nil, nil } return nil, err } - if err := db.DeleteMailboxWithRemoteID(ctx, tx, update.MailboxID); err != nil { + if err := tx.DeleteMailboxWithRemoteID(ctx, update.MailboxID); err != nil { return nil, err } - if _, err := db.RemoveDeletedSubscriptionWithName(ctx, tx, mailbox.Name); err != nil { + if _, err := tx.RemoveDeletedSubscriptionWithName(ctx, mailbox.Name); err != nil { return nil, err } @@ -163,16 +161,14 @@ func (user *user) applyMailboxUpdated(ctx context.Context, update *imap.MailboxU return fmt.Errorf("attempting to rename protected mailbox (recovery)") } - return user.db.Write(ctx, func(ctx context.Context, tx *ent.Tx) error { - client := tx.Client() - - if exists, err := db.MailboxExistsWithRemoteID(ctx, client, update.MailboxID); err != nil { + return user.db.Write(ctx, func(ctx context.Context, tx db.Transaction) error { + if exists, err := tx.MailboxExistsWithRemoteID(ctx, update.MailboxID); err != nil { return err } else if !exists { return nil } - currentName, err := db.GetMailboxNameWithRemoteID(ctx, client, update.MailboxID) + currentName, err := tx.GetMailboxNameWithRemoteID(ctx, update.MailboxID) if err != nil { return err } @@ -181,7 +177,7 @@ func (user *user) applyMailboxUpdated(ctx context.Context, update *imap.MailboxU return nil } - return db.RenameMailboxWithRemoteID(ctx, tx, update.MailboxID, strings.Join(update.MailboxName, user.delimiter)) + return tx.RenameMailboxWithRemoteID(ctx, update.MailboxID, strings.Join(update.MailboxName, user.delimiter)) }) } @@ -191,8 +187,8 @@ func (user *user) applyMailboxIDChanged(ctx context.Context, update *imap.Mailbo return fmt.Errorf("attempting to change protected mailbox (recovery) remote ID") } - return user.db.Write(ctx, func(ctx context.Context, tx *ent.Tx) error { - if err := db.UpdateRemoteMailboxID(ctx, tx, update.InternalID, update.RemoteID); err != nil { + return user.db.Write(ctx, func(ctx context.Context, tx db.Transaction) error { + if err := tx.UpdateRemoteMailboxID(ctx, update.InternalID, update.RemoteID); err != nil { return err } @@ -215,9 +211,7 @@ func (user *user) applyMessagesCreated(ctx context.Context, update *imap.Message messageForMBox := make(map[imap.InternalMailboxID][]imap.InternalMessageID) mboxInternalIDMap := make(map[imap.MailboxID]imap.InternalMailboxID) - err := user.db.Write(ctx, func(ctx context.Context, tx *ent.Tx) error { - client := tx.Client() - + err := user.db.Write(ctx, func(ctx context.Context, tx db.Transaction) error { for _, message := range update.Messages { if slices.Contains(message.MailboxIDs, ids.GluonInternalRecoveryMailboxRemoteID) { logrus.Errorf("attempting to import messages into protected mailbox (recovery), skipping") @@ -226,8 +220,8 @@ func (user *user) applyMessagesCreated(ctx context.Context, update *imap.Message internalID, ok := messagesToCreateFilter[message.Message.ID] if !ok { - messageID, err := db.GetMessageIDFromRemoteID(ctx, client, message.Message.ID) - if ent.IsNotFound(err) { + messageID, err := tx.GetMessageIDFromRemoteID(ctx, message.Message.ID) + if db.IsErrNotFound(err) { internalID = imap.NewInternalMessageID() literalReader, literalSize, err := rfc822.SetHeaderValueNoMemCopy(message.Literal, ids.InternalIDKey, internalID.String()) @@ -259,7 +253,7 @@ func (user *user) applyMessagesCreated(ctx context.Context, update *imap.Message for _, mboxID := range message.MailboxIDs { v, ok := mboxInternalIDMap[mboxID] if !ok { - internalMBoxID, err := db.GetMailboxIDWithRemoteID(ctx, client, mboxID) + internalMBoxID, err := tx.GetMailboxIDFromRemoteID(ctx, mboxID) if err != nil { // If a mailbox doesn't exist and we are allowed to skip move to next mailbox. if update.IgnoreUnknownMailboxIDs { @@ -310,7 +304,7 @@ func (user *user) applyMessagesCreated(ctx context.Context, update *imap.Message } // Create message in the database - if _, err := db.CreateMessages(ctx, tx, xslices.Map(chunk, func(req *DBRequestWithLiteral) *db.CreateMessageReq { + if _, err := tx.CreateMessages(ctx, xslices.Map(chunk, func(req *DBRequestWithLiteral) *db.CreateMessageReq { return &req.CreateMessageReq })...); err != nil { return err @@ -319,7 +313,7 @@ func (user *user) applyMessagesCreated(ctx context.Context, update *imap.Message // Assign all the messages to the mailbox for mboxID, msgList := range messageForMBox { - inMailbox, err := db.FilterMailboxContainsInternalID(ctx, tx.Client(), mboxID, msgList) + inMailbox, err := tx.MailboxFilterContainsInternalID(ctx, mboxID, msgList) if err != nil { return err } @@ -358,23 +352,24 @@ func (user *user) applyMessageMailboxesUpdated(ctx context.Context, update *imap return fmt.Errorf("attempting to move messages into protected mailbox (recovery)") } - if exists, err := db.ReadResult(ctx, user.db, func(ctx context.Context, client *ent.Client) (bool, error) { - return db.MessageExistsWithRemoteID(ctx, client, update.MessageID) + if exists, err := db.ClientReadType(ctx, user.db, func(ctx context.Context, client db.ReadOnly) (bool, error) { + return client.MessageExistsWithRemoteID(ctx, update.MessageID) }); err != nil { return err } else if !exists { return state.ErrNoSuchMessage } - return user.db.Write(ctx, func(ctx context.Context, tx *ent.Tx) error { - client := tx.Client() - - internalMsgID, err := db.GetMessageIDFromRemoteID(ctx, client, update.MessageID) + return user.db.Write(ctx, func(ctx context.Context, tx db.Transaction) error { + internalMsgID, err := tx.GetMessageIDFromRemoteID(ctx, update.MessageID) if err != nil { + if db.IsErrNotFound(err) { + return state.ErrNoSuchMessage + } return err } - internalMBoxIDs, err := db.TranslateRemoteMailboxIDs(ctx, client, update.MailboxIDs) + internalMBoxIDs, err := tx.MailboxTranslateRemoteIDs(ctx, update.MailboxIDs) if err != nil { return err } @@ -393,19 +388,19 @@ func (user *user) applyMessageMailboxesUpdated(ctx context.Context, update *imap // applyMessageFlagsUpdated applies a MessageFlagsUpdated update. func (user *user) applyMessageFlagsUpdated(ctx context.Context, update *imap.MessageFlagsUpdated) error { - if exists, err := db.ReadResult(ctx, user.db, func(ctx context.Context, client *ent.Client) (bool, error) { - return db.MessageExistsWithRemoteID(ctx, client, update.MessageID) + if exists, err := db.ClientReadType(ctx, user.db, func(ctx context.Context, client db.ReadOnly) (bool, error) { + return client.MessageExistsWithRemoteID(ctx, update.MessageID) }); err != nil { return err } else if !exists { return state.ErrNoSuchMessage } - return user.db.Write(ctx, func(ctx context.Context, tx *ent.Tx) error { - internalMsgID, err := db.GetMessageIDFromRemoteID(ctx, tx.Client(), update.MessageID) + return user.db.Write(ctx, func(ctx context.Context, tx db.Transaction) error { + internalMsgID, err := tx.GetMessageIDFromRemoteID(ctx, update.MessageID) if err != nil { - if ent.IsNotFound(err) { + if db.IsErrNotFound(err) { return state.ErrNoSuchMessage } return err @@ -421,8 +416,8 @@ func (user *user) applyMessageFlagsUpdated(ctx context.Context, update *imap.Mes // applyMessageIDChanged applies a MessageIDChanged update. func (user *user) applyMessageIDChanged(ctx context.Context, update *imap.MessageIDChanged) error { - if err := user.db.Write(ctx, func(ctx context.Context, tx *ent.Tx) error { - return db.UpdateRemoteMessageID(ctx, tx, update.InternalID, update.RemoteID) + if err := user.db.Write(ctx, func(ctx context.Context, tx db.Transaction) error { + return tx.UpdateRemoteMessageID(ctx, update.InternalID, update.RemoteID) }); err != nil { return err } @@ -436,8 +431,8 @@ func (user *user) applyMessageIDChanged(ctx context.Context, update *imap.Messag return nil } -func (user *user) setMessageMailboxes(ctx context.Context, tx *ent.Tx, messageID imap.InternalMessageID, mboxIDs []imap.InternalMailboxID) error { - curMailboxIDs, err := db.GetMessageMailboxIDs(ctx, tx.Client(), messageID) +func (user *user) setMessageMailboxes(ctx context.Context, tx db.Transaction, messageID imap.InternalMessageID, mboxIDs []imap.InternalMailboxID) error { + curMailboxIDs, err := tx.GetMessageMailboxIDs(ctx, messageID) if err != nil { return err } @@ -458,7 +453,7 @@ func (user *user) setMessageMailboxes(ctx context.Context, tx *ent.Tx, messageID } // applyMessagesAddedToMailbox adds the messages to the given mailbox. -func (user *user) applyMessagesAddedToMailbox(ctx context.Context, tx *ent.Tx, mboxID imap.InternalMailboxID, messageIDs []imap.InternalMessageID) ([]db.UIDWithFlags, error) { +func (user *user) applyMessagesAddedToMailbox(ctx context.Context, tx db.Transaction, mboxID imap.InternalMailboxID, messageIDs []imap.InternalMessageID) ([]db.UIDWithFlags, error) { messageUIDs, update, err := state.AddMessagesToMailbox(ctx, tx, mboxID, messageIDs, nil, user.imapLimits) if err != nil { return nil, err @@ -470,7 +465,7 @@ func (user *user) applyMessagesAddedToMailbox(ctx context.Context, tx *ent.Tx, m } // applyMessagesRemovedFromMailbox removes the messages from the given mailbox. -func (user *user) applyMessagesRemovedFromMailbox(ctx context.Context, tx *ent.Tx, mboxID imap.InternalMailboxID, messageIDs []imap.InternalMessageID) error { +func (user *user) applyMessagesRemovedFromMailbox(ctx context.Context, tx db.Transaction, mboxID imap.InternalMailboxID, messageIDs []imap.InternalMessageID) error { updates, err := state.RemoveMessagesFromMailbox(ctx, tx, mboxID, messageIDs) if err != nil { return err @@ -483,8 +478,8 @@ func (user *user) applyMessagesRemovedFromMailbox(ctx context.Context, tx *ent.T return nil } -func (user *user) setMessageFlags(ctx context.Context, tx *ent.Tx, messageID imap.InternalMessageID, flags imap.FlagSet) error { - curFlags, err := db.GetMessageFlags(ctx, tx.Client(), []imap.InternalMessageID{messageID}) +func (user *user) setMessageFlags(ctx context.Context, tx db.Transaction, messageID imap.InternalMessageID, flags imap.FlagSet) error { + curFlags, err := tx.GetMessagesFlags(ctx, []imap.InternalMessageID{messageID}) if err != nil { return err } @@ -510,8 +505,8 @@ func (user *user) setMessageFlags(ctx context.Context, tx *ent.Tx, messageID ima return nil } -func (user *user) addMessageFlags(ctx context.Context, tx *ent.Tx, messageID imap.InternalMessageID, flag string) error { - if err := db.AddMessageFlag(ctx, tx, []imap.InternalMessageID{messageID}, flag); err != nil { +func (user *user) addMessageFlags(ctx context.Context, tx db.Transaction, messageID imap.InternalMessageID, flag string) error { + if err := tx.AddFlagToMessages(ctx, []imap.InternalMessageID{messageID}, flag); err != nil { return err } @@ -520,8 +515,8 @@ func (user *user) addMessageFlags(ctx context.Context, tx *ent.Tx, messageID ima return nil } -func (user *user) removeMessageFlags(ctx context.Context, tx *ent.Tx, messageID imap.InternalMessageID, flag string) error { - if err := db.RemoveMessageFlag(ctx, tx, []imap.InternalMessageID{messageID}, flag); err != nil { +func (user *user) removeMessageFlags(ctx context.Context, tx db.Transaction, messageID imap.InternalMessageID, flag string) error { + if err := tx.RemoveFlagFromMessages(ctx, []imap.InternalMessageID{messageID}, flag); err != nil { return err } @@ -533,25 +528,25 @@ func (user *user) removeMessageFlags(ctx context.Context, tx *ent.Tx, messageID func (user *user) applyMessageDeleted(ctx context.Context, update *imap.MessageDeleted) error { var stateUpdates []state.Update - if err := user.db.Write(ctx, func(ctx context.Context, tx *ent.Tx) error { - if err := db.MarkMessageAsDeletedWithRemoteID(ctx, tx, update.MessageID); err != nil { - if ent.IsNotFound(err) { + if err := user.db.Write(ctx, func(ctx context.Context, tx db.Transaction) error { + if err := tx.MarkMessageAsDeletedWithRemoteID(ctx, update.MessageID); err != nil { + if db.IsErrNotFound(err) { return nil } return err } - internalMessageID, err := db.GetMessageIDFromRemoteID(ctx, tx.Client(), update.MessageID) + internalMessageID, err := tx.GetMessageIDFromRemoteID(ctx, update.MessageID) if err != nil { - if ent.IsNotFound(err) { + if db.IsErrNotFound(err) { return nil } return err } - mailboxes, err := db.GetMessageMailboxIDs(ctx, tx.Client(), internalMessageID) + mailboxes, err := tx.GetMessageMailboxIDs(ctx, internalMessageID) if err != nil { return err } @@ -582,10 +577,10 @@ func (user *user) applyMessageDeleted(ctx context.Context, update *imap.MessageD func (user *user) applyMessageUpdated(ctx context.Context, update *imap.MessageUpdated) error { log := logrus.WithField("message updated", update.Message.ID.ShortID()) - internalMessageID, err := db.ReadResult(ctx, user.db, func(ctx context.Context, client *ent.Client) (imap.InternalMessageID, error) { - return db.GetMessageIDFromRemoteID(ctx, client, update.Message.ID) + internalMessageID, err := db.ClientReadType(ctx, user.db, func(ctx context.Context, client db.ReadOnly) (imap.InternalMessageID, error) { + return client.GetMessageIDFromRemoteID(ctx, update.Message.ID) }) - if ent.IsNotFound(err) { + if db.IsErrNotFound(err) { if update.AllowCreate { log.Warn("Message not found, creating it instead") @@ -603,7 +598,7 @@ func (user *user) applyMessageUpdated(ctx context.Context, update *imap.MessageU return err } - return user.db.Write(ctx, func(ctx context.Context, tx *ent.Tx) error { + return user.db.Write(ctx, func(ctx context.Context, tx db.Transaction) error { // compare and see if the literal has changed. onDiskLiteral, err := user.store.Get(internalMessageID) if err != nil { @@ -630,7 +625,7 @@ func (user *user) applyMessageUpdated(ctx context.Context, update *imap.MessageU targetMailboxes := make([]imap.InternalMailboxID, 0, len(update.MailboxIDs)) for _, mbox := range update.MailboxIDs { - internalMBoxID, err := db.GetMailboxIDFromRemoteID(ctx, tx.Client(), mbox) + internalMBoxID, err := tx.GetMailboxIDFromRemoteID(ctx, mbox) if err != nil { return err } @@ -645,7 +640,7 @@ func (user *user) applyMessageUpdated(ctx context.Context, update *imap.MessageU var stateUpdates []state.Update { // delete the message and remove from the mailboxes. - mailboxes, err := db.GetMessageMailboxIDs(ctx, tx.Client(), internalMessageID) + mailboxes, err := tx.GetMessageMailboxIDs(ctx, internalMessageID) if err != nil { return err } @@ -663,7 +658,7 @@ func (user *user) applyMessageUpdated(ctx context.Context, update *imap.MessageU // We need change the old remote id as it will break our table constraint otherwise and everything // will silently fail. - if err := db.MarkMessageAsDeletedAndAssignRandomRemoteID(ctx, tx, internalMessageID); err != nil { + if err := tx.MarkMessageAsDeletedAndAssignRandomRemoteID(ctx, internalMessageID); err != nil { return err } } @@ -685,7 +680,7 @@ func (user *user) applyMessageUpdated(ctx context.Context, update *imap.MessageU InternalID: newInternalID, } - if m, err := db.CreateMessages(ctx, tx, request); err != nil { + if m, err := tx.CreateMessages(ctx, request); err != nil { return err } else if len(m) == 0 { return fmt.Errorf("no messages were inserted") @@ -696,7 +691,7 @@ func (user *user) applyMessageUpdated(ctx context.Context, update *imap.MessageU } for _, mbox := range update.MailboxIDs { - internalMBoxID, err := db.GetMailboxIDFromRemoteID(ctx, tx.Client(), mbox) + internalMBoxID, err := tx.GetMailboxIDFromRemoteID(ctx, mbox) if err != nil { return err } @@ -721,8 +716,8 @@ func (user *user) applyMessageUpdated(ctx context.Context, update *imap.MessageU // applyUIDValidityBumped applies a UIDValidityBumped event to the user. func (user *user) applyUIDValidityBumped(ctx context.Context, update *imap.UIDValidityBumped) error { - if err := user.db.Write(ctx, func(ctx context.Context, tx *ent.Tx) error { - mailboxes, err := db.GetAllMailboxes(ctx, tx.Client()) + if err := user.db.Write(ctx, func(ctx context.Context, tx db.Transaction) error { + mailboxes, err := tx.GetAllMailboxes(ctx) if err != nil { return err } @@ -733,7 +728,7 @@ func (user *user) applyUIDValidityBumped(ctx context.Context, update *imap.UIDVa return err } - if _, err := mailbox.Update().SetUIDValidity(uidValidity).Save(ctx); err != nil { + if err := tx.SetMailboxUIDValidity(ctx, mailbox.ID, uidValidity); err != nil { return err } } diff --git a/internal/backend/state_user_interface_impl.go b/internal/backend/state_user_interface_impl.go index aebe3dd7..69691358 100644 --- a/internal/backend/state_user_interface_impl.go +++ b/internal/backend/state_user_interface_impl.go @@ -2,13 +2,12 @@ package backend import ( "context" - "github.com/ProtonMail/gluon/internal/utils" + "github.com/ProtonMail/gluon/db" "github.com/ProtonMail/gluon/imap" - "github.com/ProtonMail/gluon/internal/db" - "github.com/ProtonMail/gluon/internal/db/ent" "github.com/ProtonMail/gluon/internal/ids" "github.com/ProtonMail/gluon/internal/state" + "github.com/ProtonMail/gluon/internal/utils" "github.com/ProtonMail/gluon/store" ) @@ -31,7 +30,7 @@ func (s *StateUserInterfaceImpl) GetDelimiter() string { return s.u.delimiter } -func (s *StateUserInterfaceImpl) GetDB() *db.DB { +func (s *StateUserInterfaceImpl) GetDB() db.Client { return s.u.db } @@ -43,7 +42,7 @@ func (s *StateUserInterfaceImpl) GetStore() *store.WriteControlledStore { return s.u.store } -func (s *StateUserInterfaceImpl) QueueOrApplyStateUpdate(ctx context.Context, tx *ent.Tx, updates ...state.Update) error { +func (s *StateUserInterfaceImpl) QueueOrApplyStateUpdate(ctx context.Context, tx db.Transaction, updates ...state.Update) error { // If we detect a state id in the context, it means this function call is a result of a User interaction. // When that happens the update needs to be applied to the state matching the state ID immediately. If no such // stateID exists or the context information is not present, all updates are queued for later execution. @@ -84,8 +83,8 @@ func (s *StateUserInterfaceImpl) GenerateUIDValidity() (imap.UID, error) { return s.u.uidValidityGenerator.Generate() } -func (s *StateUserInterfaceImpl) GetRecoveryMailboxID() ids.MailboxIDPair { - return ids.MailboxIDPair{ +func (s *StateUserInterfaceImpl) GetRecoveryMailboxID() db.MailboxIDPair { + return db.MailboxIDPair{ InternalID: s.u.recoveryMailboxID, RemoteID: ids.GluonInternalRecoveryMailboxRemoteID, } diff --git a/internal/backend/user.go b/internal/backend/user.go index cb923e7c..4c834850 100644 --- a/internal/backend/user.go +++ b/internal/backend/user.go @@ -7,9 +7,8 @@ import ( "github.com/ProtonMail/gluon/async" "github.com/ProtonMail/gluon/connector" + "github.com/ProtonMail/gluon/db" "github.com/ProtonMail/gluon/imap" - "github.com/ProtonMail/gluon/internal/db" - "github.com/ProtonMail/gluon/internal/db/ent" "github.com/ProtonMail/gluon/internal/ids" "github.com/ProtonMail/gluon/internal/state" "github.com/ProtonMail/gluon/internal/utils" @@ -30,7 +29,7 @@ type user struct { store *store.WriteControlledStore delimiter string - db *db.DB + db db.Client states map[state.StateID]*state.State statesLock sync.RWMutex @@ -53,7 +52,7 @@ type user struct { func newUser( ctx context.Context, userID string, - database *db.DB, + database db.Client, conn connector.Connector, st store.Store, delimiter string, @@ -68,7 +67,7 @@ func newUser( recoveredMessageHashes := utils.NewMessageHashesMap() // Create recovery mailbox if it does not exist - recoveryMBox, err := db.WriteResult(ctx, database, func(ctx context.Context, tx *ent.Tx) (*ent.Mailbox, error) { + recoveryMBox, err := db.ClientWriteType(ctx, database, func(ctx context.Context, tx db.Transaction) (*db.Mailbox, error) { uidValidity, err := uidValidityGenerator.Generate() if err != nil { return nil, err @@ -83,13 +82,13 @@ func newUser( Attributes: imap.NewFlagSet(imap.AttrNoInferiors), } - recoveryMBox, err := db.GetOrCreateMailbox(ctx, tx, mbox, delimiter, uidValidity) + recoveryMBox, err := tx.GetOrCreateMailboxAlt(ctx, mbox, delimiter, uidValidity) if err != nil { return nil, err } // Pre-fill the message hashes map - messages, err := db.GetMailboxMessageIDPairs(ctx, tx.Client(), recoveryMBox.ID) + messages, err := tx.GetMailboxMessageIDPairs(ctx, recoveryMBox.ID) if err != nil { return nil, err } @@ -218,13 +217,13 @@ func (user *user) close(ctx context.Context) error { func (user *user) deleteAllMessagesMarkedDeleted(ctx context.Context) error { // Delete messages in database first before deleting from the storage to avoid data loss. - ids, err := db.WriteResult(ctx, user.db, func(ctx context.Context, tx *ent.Tx) ([]imap.InternalMessageID, error) { - ids, err := db.GetMessageIDsMarkedDeleted(ctx, tx.Client()) + ids, err := db.ClientWriteType(ctx, user.db, func(ctx context.Context, tx db.Transaction) ([]imap.InternalMessageID, error) { + ids, err := tx.GetMessageIDsMarkedAsDelete(ctx) if err != nil { return nil, err } - if err := db.DeleteMessages(ctx, tx, ids...); err != nil { + if err := tx.DeleteMessages(ctx, ids); err != nil { return nil, err } @@ -270,8 +269,8 @@ func (user *user) newState() (*state.State, error) { } func (user *user) removeState(ctx context.Context, st *state.State) error { - messageIDs, err := db.ReadResult(ctx, user.db, func(ctx context.Context, client *ent.Client) ([]imap.InternalMessageID, error) { - return db.GetMessageIDsMarkedDeleted(ctx, client) + messageIDs, err := db.ClientReadType(ctx, user.db, func(ctx context.Context, client db.ReadOnly) ([]imap.InternalMessageID, error) { + return client.GetMessageIDsMarkedAsDelete(ctx) }) if err != nil { return err @@ -309,8 +308,8 @@ func (user *user) removeState(ctx context.Context, st *state.State) error { defer user.statesWG.Done() // Delete messages in database first before deleting from the storage to avoid data loss. - if err := user.db.Write(ctx, func(ctx context.Context, tx *ent.Tx) error { - if err := db.DeleteMessages(ctx, tx, messageIDs...); err != nil { + if err := user.db.Write(ctx, func(ctx context.Context, tx db.Transaction) error { + if err := tx.DeleteMessages(ctx, messageIDs); err != nil { return err } @@ -356,8 +355,8 @@ func (user *user) cleanupStaleStoreData(ctx context.Context) error { return err } - dbIdMap, err := db.ReadResult(ctx, user.db, func(ctx context.Context, client *ent.Client) (map[imap.InternalMessageID]struct{}, error) { - return db.GetAllMessagesIDsAsMap(ctx, client) + dbIdMap, err := db.ClientReadType(ctx, user.db, func(ctx context.Context, client db.ReadOnly) (map[imap.InternalMessageID]struct{}, error) { + return client.GetAllMessagesIDsAsMap(ctx) }) if err != nil { return err diff --git a/internal/data/db.go b/internal/data/db.go new file mode 100644 index 00000000..909c0fd6 --- /dev/null +++ b/internal/data/db.go @@ -0,0 +1,18 @@ +package data + +import ( + "errors" + "io/fs" + "os" +) + +// pathExists returns whether the given file exists. +func pathExists(path string) (bool, error) { + if _, err := os.Stat(path); errors.Is(err, fs.ErrNotExist) { + return false, nil + } else if err != nil { + return false, err + } + + return true, nil +} diff --git a/internal/db/ent/hook/hook.go b/internal/db/ent/hook/hook.go deleted file mode 100644 index 1dc0f3f3..00000000 --- a/internal/db/ent/hook/hook.go +++ /dev/null @@ -1,291 +0,0 @@ -// Code generated by ent, DO NOT EDIT. - -package hook - -import ( - "context" - "fmt" - - "github.com/ProtonMail/gluon/internal/db/ent" -) - -// The DeletedSubscriptionFunc type is an adapter to allow the use of ordinary -// function as DeletedSubscription mutator. -type DeletedSubscriptionFunc func(context.Context, *ent.DeletedSubscriptionMutation) (ent.Value, error) - -// Mutate calls f(ctx, m). -func (f DeletedSubscriptionFunc) Mutate(ctx context.Context, m ent.Mutation) (ent.Value, error) { - mv, ok := m.(*ent.DeletedSubscriptionMutation) - if !ok { - return nil, fmt.Errorf("unexpected mutation type %T. expect *ent.DeletedSubscriptionMutation", m) - } - return f(ctx, mv) -} - -// The MailboxFunc type is an adapter to allow the use of ordinary -// function as Mailbox mutator. -type MailboxFunc func(context.Context, *ent.MailboxMutation) (ent.Value, error) - -// Mutate calls f(ctx, m). -func (f MailboxFunc) Mutate(ctx context.Context, m ent.Mutation) (ent.Value, error) { - mv, ok := m.(*ent.MailboxMutation) - if !ok { - return nil, fmt.Errorf("unexpected mutation type %T. expect *ent.MailboxMutation", m) - } - return f(ctx, mv) -} - -// The MailboxAttrFunc type is an adapter to allow the use of ordinary -// function as MailboxAttr mutator. -type MailboxAttrFunc func(context.Context, *ent.MailboxAttrMutation) (ent.Value, error) - -// Mutate calls f(ctx, m). -func (f MailboxAttrFunc) Mutate(ctx context.Context, m ent.Mutation) (ent.Value, error) { - mv, ok := m.(*ent.MailboxAttrMutation) - if !ok { - return nil, fmt.Errorf("unexpected mutation type %T. expect *ent.MailboxAttrMutation", m) - } - return f(ctx, mv) -} - -// The MailboxFlagFunc type is an adapter to allow the use of ordinary -// function as MailboxFlag mutator. -type MailboxFlagFunc func(context.Context, *ent.MailboxFlagMutation) (ent.Value, error) - -// Mutate calls f(ctx, m). -func (f MailboxFlagFunc) Mutate(ctx context.Context, m ent.Mutation) (ent.Value, error) { - mv, ok := m.(*ent.MailboxFlagMutation) - if !ok { - return nil, fmt.Errorf("unexpected mutation type %T. expect *ent.MailboxFlagMutation", m) - } - return f(ctx, mv) -} - -// The MailboxPermFlagFunc type is an adapter to allow the use of ordinary -// function as MailboxPermFlag mutator. -type MailboxPermFlagFunc func(context.Context, *ent.MailboxPermFlagMutation) (ent.Value, error) - -// Mutate calls f(ctx, m). -func (f MailboxPermFlagFunc) Mutate(ctx context.Context, m ent.Mutation) (ent.Value, error) { - mv, ok := m.(*ent.MailboxPermFlagMutation) - if !ok { - return nil, fmt.Errorf("unexpected mutation type %T. expect *ent.MailboxPermFlagMutation", m) - } - return f(ctx, mv) -} - -// The MessageFunc type is an adapter to allow the use of ordinary -// function as Message mutator. -type MessageFunc func(context.Context, *ent.MessageMutation) (ent.Value, error) - -// Mutate calls f(ctx, m). -func (f MessageFunc) Mutate(ctx context.Context, m ent.Mutation) (ent.Value, error) { - mv, ok := m.(*ent.MessageMutation) - if !ok { - return nil, fmt.Errorf("unexpected mutation type %T. expect *ent.MessageMutation", m) - } - return f(ctx, mv) -} - -// The MessageFlagFunc type is an adapter to allow the use of ordinary -// function as MessageFlag mutator. -type MessageFlagFunc func(context.Context, *ent.MessageFlagMutation) (ent.Value, error) - -// Mutate calls f(ctx, m). -func (f MessageFlagFunc) Mutate(ctx context.Context, m ent.Mutation) (ent.Value, error) { - mv, ok := m.(*ent.MessageFlagMutation) - if !ok { - return nil, fmt.Errorf("unexpected mutation type %T. expect *ent.MessageFlagMutation", m) - } - return f(ctx, mv) -} - -// The UIDFunc type is an adapter to allow the use of ordinary -// function as UID mutator. -type UIDFunc func(context.Context, *ent.UIDMutation) (ent.Value, error) - -// Mutate calls f(ctx, m). -func (f UIDFunc) Mutate(ctx context.Context, m ent.Mutation) (ent.Value, error) { - mv, ok := m.(*ent.UIDMutation) - if !ok { - return nil, fmt.Errorf("unexpected mutation type %T. expect *ent.UIDMutation", m) - } - return f(ctx, mv) -} - -// Condition is a hook condition function. -type Condition func(context.Context, ent.Mutation) bool - -// And groups conditions with the AND operator. -func And(first, second Condition, rest ...Condition) Condition { - return func(ctx context.Context, m ent.Mutation) bool { - if !first(ctx, m) || !second(ctx, m) { - return false - } - for _, cond := range rest { - if !cond(ctx, m) { - return false - } - } - return true - } -} - -// Or groups conditions with the OR operator. -func Or(first, second Condition, rest ...Condition) Condition { - return func(ctx context.Context, m ent.Mutation) bool { - if first(ctx, m) || second(ctx, m) { - return true - } - for _, cond := range rest { - if cond(ctx, m) { - return true - } - } - return false - } -} - -// Not negates a given condition. -func Not(cond Condition) Condition { - return func(ctx context.Context, m ent.Mutation) bool { - return !cond(ctx, m) - } -} - -// HasOp is a condition testing mutation operation. -func HasOp(op ent.Op) Condition { - return func(_ context.Context, m ent.Mutation) bool { - return m.Op().Is(op) - } -} - -// HasAddedFields is a condition validating `.AddedField` on fields. -func HasAddedFields(field string, fields ...string) Condition { - return func(_ context.Context, m ent.Mutation) bool { - if _, exists := m.AddedField(field); !exists { - return false - } - for _, field := range fields { - if _, exists := m.AddedField(field); !exists { - return false - } - } - return true - } -} - -// HasClearedFields is a condition validating `.FieldCleared` on fields. -func HasClearedFields(field string, fields ...string) Condition { - return func(_ context.Context, m ent.Mutation) bool { - if exists := m.FieldCleared(field); !exists { - return false - } - for _, field := range fields { - if exists := m.FieldCleared(field); !exists { - return false - } - } - return true - } -} - -// HasFields is a condition validating `.Field` on fields. -func HasFields(field string, fields ...string) Condition { - return func(_ context.Context, m ent.Mutation) bool { - if _, exists := m.Field(field); !exists { - return false - } - for _, field := range fields { - if _, exists := m.Field(field); !exists { - return false - } - } - return true - } -} - -// If executes the given hook under condition. -// -// hook.If(ComputeAverage, And(HasFields(...), HasAddedFields(...))) -func If(hk ent.Hook, cond Condition) ent.Hook { - return func(next ent.Mutator) ent.Mutator { - return ent.MutateFunc(func(ctx context.Context, m ent.Mutation) (ent.Value, error) { - if cond(ctx, m) { - return hk(next).Mutate(ctx, m) - } - return next.Mutate(ctx, m) - }) - } -} - -// On executes the given hook only for the given operation. -// -// hook.On(Log, ent.Delete|ent.Create) -func On(hk ent.Hook, op ent.Op) ent.Hook { - return If(hk, HasOp(op)) -} - -// Unless skips the given hook only for the given operation. -// -// hook.Unless(Log, ent.Update|ent.UpdateOne) -func Unless(hk ent.Hook, op ent.Op) ent.Hook { - return If(hk, Not(HasOp(op))) -} - -// FixedError is a hook returning a fixed error. -func FixedError(err error) ent.Hook { - return func(ent.Mutator) ent.Mutator { - return ent.MutateFunc(func(context.Context, ent.Mutation) (ent.Value, error) { - return nil, err - }) - } -} - -// Reject returns a hook that rejects all operations that match op. -// -// func (T) Hooks() []ent.Hook { -// return []ent.Hook{ -// Reject(ent.Delete|ent.Update), -// } -// } -func Reject(op ent.Op) ent.Hook { - hk := FixedError(fmt.Errorf("%s operation is not allowed", op)) - return On(hk, op) -} - -// Chain acts as a list of hooks and is effectively immutable. -// Once created, it will always hold the same set of hooks in the same order. -type Chain struct { - hooks []ent.Hook -} - -// NewChain creates a new chain of hooks. -func NewChain(hooks ...ent.Hook) Chain { - return Chain{append([]ent.Hook(nil), hooks...)} -} - -// Hook chains the list of hooks and returns the final hook. -func (c Chain) Hook() ent.Hook { - return func(mutator ent.Mutator) ent.Mutator { - for i := len(c.hooks) - 1; i >= 0; i-- { - mutator = c.hooks[i](mutator) - } - return mutator - } -} - -// Append extends a chain, adding the specified hook -// as the last ones in the mutation flow. -func (c Chain) Append(hooks ...ent.Hook) Chain { - newHooks := make([]ent.Hook, 0, len(c.hooks)+len(hooks)) - newHooks = append(newHooks, c.hooks...) - newHooks = append(newHooks, hooks...) - return Chain{newHooks} -} - -// Extend extends a chain, adding the specified chain -// as the last ones in the mutation flow. -func (c Chain) Extend(chain Chain) Chain { - return c.Append(chain.hooks...) -} diff --git a/internal/db/ent/runtime/runtime.go b/internal/db/ent/runtime/runtime.go deleted file mode 100644 index f6a48303..00000000 --- a/internal/db/ent/runtime/runtime.go +++ /dev/null @@ -1,10 +0,0 @@ -// Code generated by ent, DO NOT EDIT. - -package runtime - -// The schema-stitching logic is generated in github.com/ProtonMail/gluon/internal/db/ent/runtime.go - -const ( - Version = "v0.11.2" // Version of ent codegen. - Sum = "h1:UM2/BUhF2FfsxPHRxLjQbhqJNaDdVlOwNIAMLs2jyto=" // Sum of ent codegen. -) diff --git a/internal/db_impl/db_impl.go b/internal/db_impl/db_impl.go new file mode 100644 index 00000000..1c96616a --- /dev/null +++ b/internal/db_impl/db_impl.go @@ -0,0 +1,10 @@ +package db_impl + +import ( + "github.com/ProtonMail/gluon/db" + "github.com/ProtonMail/gluon/internal/db_impl/ent_db" +) + +func NewEntDB() db.ClientInterface { + return ent_db.NewEntDB() +} diff --git a/internal/db/db.go b/internal/db_impl/ent_db/db.go similarity index 55% rename from internal/db/db.go rename to internal/db_impl/ent_db/db.go index b3885f35..0dbb7b1f 100644 --- a/internal/db/db.go +++ b/internal/db_impl/ent_db/db.go @@ -1,23 +1,23 @@ -package db +package ent_db import ( "context" "errors" "fmt" - "github.com/ProtonMail/gluon/internal/utils" - "github.com/google/uuid" "io/fs" "os" "path/filepath" "sync" "entgo.io/ent/dialect" - "github.com/ProtonMail/gluon/internal/db/ent" + "github.com/ProtonMail/gluon/db" + "github.com/ProtonMail/gluon/internal/db_impl/ent_db/internal" + "github.com/ProtonMail/gluon/internal/utils" "github.com/ProtonMail/gluon/reporter" ) type DB struct { - db *ent.Client + db *internal.Client lock sync.RWMutex } @@ -28,22 +28,36 @@ func (d *DB) Init(ctx context.Context) error { return d.db.Schema.Create(ctx) } -func (d *DB) Read(ctx context.Context, fn func(context.Context, *ent.Client) error) error { - _, err := ReadResult(ctx, d, func(ctx context.Context, client *ent.Client) (struct{}, error) { +func (d *DB) ReadEnt(ctx context.Context, fn func(context.Context, *internal.Client) error) error { + _, err := ReadResult(ctx, d, func(ctx context.Context, client *internal.Client) (struct{}, error) { return struct{}{}, fn(ctx, client) }) return err } -func (d *DB) Write(ctx context.Context, fn func(context.Context, *ent.Tx) error) error { - _, err := WriteResult(ctx, d, func(ctx context.Context, tx *ent.Tx) (struct{}, error) { +func (d *DB) WriteEnt(ctx context.Context, fn func(context.Context, *internal.Tx) error) error { + _, err := WriteResult(ctx, d, func(ctx context.Context, tx *internal.Tx) (struct{}, error) { return struct{}{}, fn(ctx, tx) }) return err } +func (d *DB) Read(ctx context.Context, fn func(context.Context, db.ReadOnly) error) error { + return d.ReadEnt(ctx, func(ctx context.Context, client *internal.Client) error { + rd := newOpsReadFromClient(client) + return fn(ctx, rd) + }) +} + +func (d *DB) Write(ctx context.Context, fn func(context.Context, db.Transaction) error) error { + return d.WriteEnt(ctx, func(ctx context.Context, tx *internal.Tx) error { + op := newEntOpsWrite(tx) + return fn(ctx, op) + }) +} + func (d *DB) Close() error { d.lock.Lock() defer d.lock.Unlock() @@ -51,14 +65,14 @@ func (d *DB) Close() error { return d.db.Close() } -func ReadResult[T any](ctx context.Context, db *DB, fn func(context.Context, *ent.Client) (T, error)) (T, error) { +func ReadResult[T any](ctx context.Context, db *DB, fn func(context.Context, *internal.Client) (T, error)) (T, error) { db.lock.RLock() defer db.lock.RUnlock() return fn(ctx, db.db) } -func WriteResult[T any](ctx context.Context, db *DB, fn func(context.Context, *ent.Tx) (T, error)) (T, error) { +func WriteResult[T any](ctx context.Context, db *DB, fn func(context.Context, *internal.Tx) (T, error)) (T, error) { db.lock.Lock() defer db.lock.Unlock() @@ -102,21 +116,13 @@ func WriteResult[T any](ctx context.Context, db *DB, fn func(context.Context, *e return result, nil } -func getDatabaseConn(dir, userID, path string) string { - return fmt.Sprintf("file:%v?cache=shared&_fk=1&_journal=WAL", path) -} +type EntDBBuilder struct{} -func getDatabasePath(dir, userID string) string { - return filepath.Join(dir, fmt.Sprintf("%v.db", userID)) -} - -func GetDeferredDeleteDBPath(dir string) string { - return filepath.Join(dir, "deferred_delete") +func NewEntDBBuilder() db.ClientInterface { + return &EntDBBuilder{} } -// NewDB creates a new database instance. -// If the database does not exist, it will be created and the second return value will be true. -func NewDB(dir, userID string) (*DB, bool, error) { +func (EntDBBuilder) New(dir string, userID string) (db.Client, bool, error) { if err := os.MkdirAll(dir, 0o700); err != nil { return nil, false, err } @@ -129,7 +135,7 @@ func NewDB(dir, userID string) (*DB, bool, error) { return nil, false, err } - client, err := ent.Open(dialect.SQLite, getDatabaseConn(dir, userID, path)) + client, err := internal.Open(dialect.SQLite, getDatabaseConn(dir, userID, path)) if err != nil { return nil, false, err } @@ -137,42 +143,12 @@ func NewDB(dir, userID string) (*DB, bool, error) { return &DB{db: client}, !exists, nil } -// DeleteDB will rename all the database files for the given user to a directory within the same folder to avoid -// issues with ent not being able to close the database on demand. The database will be cleaned up on the next -// run on the Gluon server. -func DeleteDB(dir, userID string) error { - // Rather than deleting the files immediately move them to a directory to be cleaned up later. - deferredDeletePath := GetDeferredDeleteDBPath(dir) - - if err := os.MkdirAll(deferredDeletePath, 0o700); err != nil { - return fmt.Errorf("failed to create deferred delete dir: %w", err) - } - - matchingFiles, err := filepath.Glob(filepath.Join(dir, userID+"*")) - if err != nil { - return fmt.Errorf("failed to match db files:%w", err) - } - - for _, file := range matchingFiles { - // Use new UUID to avoid conflict with existing files - if err := os.Rename(file, filepath.Join(deferredDeletePath, uuid.NewString())); err != nil { - return fmt.Errorf("failed to move db file '%v' :%w", file, err) - } - } - - return nil +func (EntDBBuilder) Delete(dir string, userID string) error { + return db.DeleteDB(dir, userID) } -// DeleteDeferredDBFiles deletes all data from previous databases that were scheduled for removal. -func DeleteDeferredDBFiles(dir string) error { - deferredDeleteDir := GetDeferredDeleteDBPath(dir) - if err := os.RemoveAll(deferredDeleteDir); err != nil { - if !errors.Is(err, os.ErrNotExist) { - return err - } - } - - return nil +func getDatabaseConn(dir, userID, path string) string { + return fmt.Sprintf("file:%v?cache=shared&_fk=1&_journal=WAL", path) } // pathExists returns whether the given file exists. @@ -185,3 +161,32 @@ func pathExists(path string) (bool, error) { return true, nil } + +func getDatabasePath(dir, userID string) string { + return filepath.Join(dir, fmt.Sprintf("%v.db", userID)) +} + +func NewEntDB() db.ClientInterface { + return &EntDBBuilder{} +} + +func wrapEntError(err error) error { + if err == nil { + return nil + } + + if internal.IsNotFound(err) { + return fmt.Errorf("%v (%w)", err, db.ErrNotFound) + } + + return err +} + +func wrapEntErrFn(fn func() error) error { + return wrapEntError(fn()) +} + +func wrapEntErrFnTyped[T any](fn func() (T, error)) (T, error) { + val, err := fn() + return val, wrapEntError(err) +} diff --git a/internal/db/ent/client.go b/internal/db_impl/ent_db/internal/client.go similarity index 72% rename from internal/db/ent/client.go rename to internal/db_impl/ent_db/internal/client.go index 17d75a15..b1047681 100644 --- a/internal/db/ent/client.go +++ b/internal/db_impl/ent_db/internal/client.go @@ -1,6 +1,6 @@ // Code generated by ent, DO NOT EDIT. -package ent +package internal import ( "context" @@ -9,16 +9,16 @@ import ( "log" "github.com/ProtonMail/gluon/imap" - "github.com/ProtonMail/gluon/internal/db/ent/migrate" + "github.com/ProtonMail/gluon/internal/db_impl/ent_db/internal/migrate" - "github.com/ProtonMail/gluon/internal/db/ent/deletedsubscription" - "github.com/ProtonMail/gluon/internal/db/ent/mailbox" - "github.com/ProtonMail/gluon/internal/db/ent/mailboxattr" - "github.com/ProtonMail/gluon/internal/db/ent/mailboxflag" - "github.com/ProtonMail/gluon/internal/db/ent/mailboxpermflag" - "github.com/ProtonMail/gluon/internal/db/ent/message" - "github.com/ProtonMail/gluon/internal/db/ent/messageflag" - "github.com/ProtonMail/gluon/internal/db/ent/uid" + "github.com/ProtonMail/gluon/internal/db_impl/ent_db/internal/deletedsubscription" + "github.com/ProtonMail/gluon/internal/db_impl/ent_db/internal/mailbox" + "github.com/ProtonMail/gluon/internal/db_impl/ent_db/internal/mailboxattr" + "github.com/ProtonMail/gluon/internal/db_impl/ent_db/internal/mailboxflag" + "github.com/ProtonMail/gluon/internal/db_impl/ent_db/internal/mailboxpermflag" + "github.com/ProtonMail/gluon/internal/db_impl/ent_db/internal/message" + "github.com/ProtonMail/gluon/internal/db_impl/ent_db/internal/messageflag" + "github.com/ProtonMail/gluon/internal/db_impl/ent_db/internal/uid" "entgo.io/ent/dialect" "entgo.io/ent/dialect/sql" @@ -50,7 +50,7 @@ type Client struct { // NewClient creates a new client configured with the given options. func NewClient(opts ...Option) *Client { - cfg := config{log: log.Println, hooks: &hooks{}} + cfg := config{log: log.Println, hooks: &hooks{}, inters: &inters{}} cfg.options(opts...) client := &Client{config: cfg} client.init() @@ -89,11 +89,11 @@ func Open(driverName, dataSourceName string, options ...Option) (*Client, error) // is used until the transaction is committed or rolled back. func (c *Client) Tx(ctx context.Context) (*Tx, error) { if _, ok := c.driver.(*txDriver); ok { - return nil, errors.New("ent: cannot start a transaction within a transaction") + return nil, errors.New("internal: cannot start a transaction within a transaction") } tx, err := newTx(ctx, c.driver) if err != nil { - return nil, fmt.Errorf("ent: starting a transaction: %w", err) + return nil, fmt.Errorf("internal: starting a transaction: %w", err) } cfg := c.config cfg.driver = tx @@ -173,6 +173,43 @@ func (c *Client) Use(hooks ...Hook) { c.UID.Use(hooks...) } +// Intercept adds the query interceptors to all the entity clients. +// In order to add interceptors to a specific client, call: `client.Node.Intercept(...)`. +func (c *Client) Intercept(interceptors ...Interceptor) { + c.DeletedSubscription.Intercept(interceptors...) + c.Mailbox.Intercept(interceptors...) + c.MailboxAttr.Intercept(interceptors...) + c.MailboxFlag.Intercept(interceptors...) + c.MailboxPermFlag.Intercept(interceptors...) + c.Message.Intercept(interceptors...) + c.MessageFlag.Intercept(interceptors...) + c.UID.Intercept(interceptors...) +} + +// Mutate implements the ent.Mutator interface. +func (c *Client) Mutate(ctx context.Context, m Mutation) (Value, error) { + switch m := m.(type) { + case *DeletedSubscriptionMutation: + return c.DeletedSubscription.mutate(ctx, m) + case *MailboxMutation: + return c.Mailbox.mutate(ctx, m) + case *MailboxAttrMutation: + return c.MailboxAttr.mutate(ctx, m) + case *MailboxFlagMutation: + return c.MailboxFlag.mutate(ctx, m) + case *MailboxPermFlagMutation: + return c.MailboxPermFlag.mutate(ctx, m) + case *MessageMutation: + return c.Message.mutate(ctx, m) + case *MessageFlagMutation: + return c.MessageFlag.mutate(ctx, m) + case *UIDMutation: + return c.UID.mutate(ctx, m) + default: + return nil, fmt.Errorf("internal: unknown mutation type %T", m) + } +} + // DeletedSubscriptionClient is a client for the DeletedSubscription schema. type DeletedSubscriptionClient struct { config @@ -189,6 +226,12 @@ func (c *DeletedSubscriptionClient) Use(hooks ...Hook) { c.hooks.DeletedSubscription = append(c.hooks.DeletedSubscription, hooks...) } +// Intercept adds a list of query interceptors to the interceptors stack. +// A call to `Intercept(f, g, h)` equals to `deletedsubscription.Intercept(f(g(h())))`. +func (c *DeletedSubscriptionClient) Intercept(interceptors ...Interceptor) { + c.inters.DeletedSubscription = append(c.inters.DeletedSubscription, interceptors...) +} + // Create returns a builder for creating a DeletedSubscription entity. func (c *DeletedSubscriptionClient) Create() *DeletedSubscriptionCreate { mutation := newDeletedSubscriptionMutation(c.config, OpCreate) @@ -229,7 +272,7 @@ func (c *DeletedSubscriptionClient) DeleteOne(ds *DeletedSubscription) *DeletedS return c.DeleteOneID(ds.ID) } -// DeleteOne returns a builder for deleting the given entity by its id. +// DeleteOneID returns a builder for deleting the given entity by its id. func (c *DeletedSubscriptionClient) DeleteOneID(id int) *DeletedSubscriptionDeleteOne { builder := c.Delete().Where(deletedsubscription.ID(id)) builder.mutation.id = &id @@ -241,6 +284,8 @@ func (c *DeletedSubscriptionClient) DeleteOneID(id int) *DeletedSubscriptionDele func (c *DeletedSubscriptionClient) Query() *DeletedSubscriptionQuery { return &DeletedSubscriptionQuery{ config: c.config, + ctx: &QueryContext{Type: TypeDeletedSubscription}, + inters: c.Interceptors(), } } @@ -263,6 +308,26 @@ func (c *DeletedSubscriptionClient) Hooks() []Hook { return c.hooks.DeletedSubscription } +// Interceptors returns the client interceptors. +func (c *DeletedSubscriptionClient) Interceptors() []Interceptor { + return c.inters.DeletedSubscription +} + +func (c *DeletedSubscriptionClient) mutate(ctx context.Context, m *DeletedSubscriptionMutation) (Value, error) { + switch m.Op() { + case OpCreate: + return (&DeletedSubscriptionCreate{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx) + case OpUpdate: + return (&DeletedSubscriptionUpdate{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx) + case OpUpdateOne: + return (&DeletedSubscriptionUpdateOne{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx) + case OpDelete, OpDeleteOne: + return (&DeletedSubscriptionDelete{config: c.config, hooks: c.Hooks(), mutation: m}).Exec(ctx) + default: + return nil, fmt.Errorf("internal: unknown DeletedSubscription mutation op: %q", m.Op()) + } +} + // MailboxClient is a client for the Mailbox schema. type MailboxClient struct { config @@ -279,6 +344,12 @@ func (c *MailboxClient) Use(hooks ...Hook) { c.hooks.Mailbox = append(c.hooks.Mailbox, hooks...) } +// Intercept adds a list of query interceptors to the interceptors stack. +// A call to `Intercept(f, g, h)` equals to `mailbox.Intercept(f(g(h())))`. +func (c *MailboxClient) Intercept(interceptors ...Interceptor) { + c.inters.Mailbox = append(c.inters.Mailbox, interceptors...) +} + // Create returns a builder for creating a Mailbox entity. func (c *MailboxClient) Create() *MailboxCreate { mutation := newMailboxMutation(c.config, OpCreate) @@ -319,7 +390,7 @@ func (c *MailboxClient) DeleteOne(m *Mailbox) *MailboxDeleteOne { return c.DeleteOneID(m.ID) } -// DeleteOne returns a builder for deleting the given entity by its id. +// DeleteOneID returns a builder for deleting the given entity by its id. func (c *MailboxClient) DeleteOneID(id imap.InternalMailboxID) *MailboxDeleteOne { builder := c.Delete().Where(mailbox.ID(id)) builder.mutation.id = &id @@ -331,6 +402,8 @@ func (c *MailboxClient) DeleteOneID(id imap.InternalMailboxID) *MailboxDeleteOne func (c *MailboxClient) Query() *MailboxQuery { return &MailboxQuery{ config: c.config, + ctx: &QueryContext{Type: TypeMailbox}, + inters: c.Interceptors(), } } @@ -350,8 +423,8 @@ func (c *MailboxClient) GetX(ctx context.Context, id imap.InternalMailboxID) *Ma // QueryUIDs queries the UIDs edge of a Mailbox. func (c *MailboxClient) QueryUIDs(m *Mailbox) *UIDQuery { - query := &UIDQuery{config: c.config} - query.path = func(ctx context.Context) (fromV *sql.Selector, _ error) { + query := (&UIDClient{config: c.config}).Query() + query.path = func(context.Context) (fromV *sql.Selector, _ error) { id := m.ID step := sqlgraph.NewStep( sqlgraph.From(mailbox.Table, mailbox.FieldID, id), @@ -366,8 +439,8 @@ func (c *MailboxClient) QueryUIDs(m *Mailbox) *UIDQuery { // QueryFlags queries the flags edge of a Mailbox. func (c *MailboxClient) QueryFlags(m *Mailbox) *MailboxFlagQuery { - query := &MailboxFlagQuery{config: c.config} - query.path = func(ctx context.Context) (fromV *sql.Selector, _ error) { + query := (&MailboxFlagClient{config: c.config}).Query() + query.path = func(context.Context) (fromV *sql.Selector, _ error) { id := m.ID step := sqlgraph.NewStep( sqlgraph.From(mailbox.Table, mailbox.FieldID, id), @@ -382,8 +455,8 @@ func (c *MailboxClient) QueryFlags(m *Mailbox) *MailboxFlagQuery { // QueryPermanentFlags queries the permanent_flags edge of a Mailbox. func (c *MailboxClient) QueryPermanentFlags(m *Mailbox) *MailboxPermFlagQuery { - query := &MailboxPermFlagQuery{config: c.config} - query.path = func(ctx context.Context) (fromV *sql.Selector, _ error) { + query := (&MailboxPermFlagClient{config: c.config}).Query() + query.path = func(context.Context) (fromV *sql.Selector, _ error) { id := m.ID step := sqlgraph.NewStep( sqlgraph.From(mailbox.Table, mailbox.FieldID, id), @@ -398,8 +471,8 @@ func (c *MailboxClient) QueryPermanentFlags(m *Mailbox) *MailboxPermFlagQuery { // QueryAttributes queries the attributes edge of a Mailbox. func (c *MailboxClient) QueryAttributes(m *Mailbox) *MailboxAttrQuery { - query := &MailboxAttrQuery{config: c.config} - query.path = func(ctx context.Context) (fromV *sql.Selector, _ error) { + query := (&MailboxAttrClient{config: c.config}).Query() + query.path = func(context.Context) (fromV *sql.Selector, _ error) { id := m.ID step := sqlgraph.NewStep( sqlgraph.From(mailbox.Table, mailbox.FieldID, id), @@ -417,6 +490,26 @@ func (c *MailboxClient) Hooks() []Hook { return c.hooks.Mailbox } +// Interceptors returns the client interceptors. +func (c *MailboxClient) Interceptors() []Interceptor { + return c.inters.Mailbox +} + +func (c *MailboxClient) mutate(ctx context.Context, m *MailboxMutation) (Value, error) { + switch m.Op() { + case OpCreate: + return (&MailboxCreate{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx) + case OpUpdate: + return (&MailboxUpdate{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx) + case OpUpdateOne: + return (&MailboxUpdateOne{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx) + case OpDelete, OpDeleteOne: + return (&MailboxDelete{config: c.config, hooks: c.Hooks(), mutation: m}).Exec(ctx) + default: + return nil, fmt.Errorf("internal: unknown Mailbox mutation op: %q", m.Op()) + } +} + // MailboxAttrClient is a client for the MailboxAttr schema. type MailboxAttrClient struct { config @@ -433,6 +526,12 @@ func (c *MailboxAttrClient) Use(hooks ...Hook) { c.hooks.MailboxAttr = append(c.hooks.MailboxAttr, hooks...) } +// Intercept adds a list of query interceptors to the interceptors stack. +// A call to `Intercept(f, g, h)` equals to `mailboxattr.Intercept(f(g(h())))`. +func (c *MailboxAttrClient) Intercept(interceptors ...Interceptor) { + c.inters.MailboxAttr = append(c.inters.MailboxAttr, interceptors...) +} + // Create returns a builder for creating a MailboxAttr entity. func (c *MailboxAttrClient) Create() *MailboxAttrCreate { mutation := newMailboxAttrMutation(c.config, OpCreate) @@ -473,7 +572,7 @@ func (c *MailboxAttrClient) DeleteOne(ma *MailboxAttr) *MailboxAttrDeleteOne { return c.DeleteOneID(ma.ID) } -// DeleteOne returns a builder for deleting the given entity by its id. +// DeleteOneID returns a builder for deleting the given entity by its id. func (c *MailboxAttrClient) DeleteOneID(id int) *MailboxAttrDeleteOne { builder := c.Delete().Where(mailboxattr.ID(id)) builder.mutation.id = &id @@ -485,6 +584,8 @@ func (c *MailboxAttrClient) DeleteOneID(id int) *MailboxAttrDeleteOne { func (c *MailboxAttrClient) Query() *MailboxAttrQuery { return &MailboxAttrQuery{ config: c.config, + ctx: &QueryContext{Type: TypeMailboxAttr}, + inters: c.Interceptors(), } } @@ -507,6 +608,26 @@ func (c *MailboxAttrClient) Hooks() []Hook { return c.hooks.MailboxAttr } +// Interceptors returns the client interceptors. +func (c *MailboxAttrClient) Interceptors() []Interceptor { + return c.inters.MailboxAttr +} + +func (c *MailboxAttrClient) mutate(ctx context.Context, m *MailboxAttrMutation) (Value, error) { + switch m.Op() { + case OpCreate: + return (&MailboxAttrCreate{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx) + case OpUpdate: + return (&MailboxAttrUpdate{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx) + case OpUpdateOne: + return (&MailboxAttrUpdateOne{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx) + case OpDelete, OpDeleteOne: + return (&MailboxAttrDelete{config: c.config, hooks: c.Hooks(), mutation: m}).Exec(ctx) + default: + return nil, fmt.Errorf("internal: unknown MailboxAttr mutation op: %q", m.Op()) + } +} + // MailboxFlagClient is a client for the MailboxFlag schema. type MailboxFlagClient struct { config @@ -523,6 +644,12 @@ func (c *MailboxFlagClient) Use(hooks ...Hook) { c.hooks.MailboxFlag = append(c.hooks.MailboxFlag, hooks...) } +// Intercept adds a list of query interceptors to the interceptors stack. +// A call to `Intercept(f, g, h)` equals to `mailboxflag.Intercept(f(g(h())))`. +func (c *MailboxFlagClient) Intercept(interceptors ...Interceptor) { + c.inters.MailboxFlag = append(c.inters.MailboxFlag, interceptors...) +} + // Create returns a builder for creating a MailboxFlag entity. func (c *MailboxFlagClient) Create() *MailboxFlagCreate { mutation := newMailboxFlagMutation(c.config, OpCreate) @@ -563,7 +690,7 @@ func (c *MailboxFlagClient) DeleteOne(mf *MailboxFlag) *MailboxFlagDeleteOne { return c.DeleteOneID(mf.ID) } -// DeleteOne returns a builder for deleting the given entity by its id. +// DeleteOneID returns a builder for deleting the given entity by its id. func (c *MailboxFlagClient) DeleteOneID(id int) *MailboxFlagDeleteOne { builder := c.Delete().Where(mailboxflag.ID(id)) builder.mutation.id = &id @@ -575,6 +702,8 @@ func (c *MailboxFlagClient) DeleteOneID(id int) *MailboxFlagDeleteOne { func (c *MailboxFlagClient) Query() *MailboxFlagQuery { return &MailboxFlagQuery{ config: c.config, + ctx: &QueryContext{Type: TypeMailboxFlag}, + inters: c.Interceptors(), } } @@ -597,6 +726,26 @@ func (c *MailboxFlagClient) Hooks() []Hook { return c.hooks.MailboxFlag } +// Interceptors returns the client interceptors. +func (c *MailboxFlagClient) Interceptors() []Interceptor { + return c.inters.MailboxFlag +} + +func (c *MailboxFlagClient) mutate(ctx context.Context, m *MailboxFlagMutation) (Value, error) { + switch m.Op() { + case OpCreate: + return (&MailboxFlagCreate{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx) + case OpUpdate: + return (&MailboxFlagUpdate{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx) + case OpUpdateOne: + return (&MailboxFlagUpdateOne{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx) + case OpDelete, OpDeleteOne: + return (&MailboxFlagDelete{config: c.config, hooks: c.Hooks(), mutation: m}).Exec(ctx) + default: + return nil, fmt.Errorf("internal: unknown MailboxFlag mutation op: %q", m.Op()) + } +} + // MailboxPermFlagClient is a client for the MailboxPermFlag schema. type MailboxPermFlagClient struct { config @@ -613,6 +762,12 @@ func (c *MailboxPermFlagClient) Use(hooks ...Hook) { c.hooks.MailboxPermFlag = append(c.hooks.MailboxPermFlag, hooks...) } +// Intercept adds a list of query interceptors to the interceptors stack. +// A call to `Intercept(f, g, h)` equals to `mailboxpermflag.Intercept(f(g(h())))`. +func (c *MailboxPermFlagClient) Intercept(interceptors ...Interceptor) { + c.inters.MailboxPermFlag = append(c.inters.MailboxPermFlag, interceptors...) +} + // Create returns a builder for creating a MailboxPermFlag entity. func (c *MailboxPermFlagClient) Create() *MailboxPermFlagCreate { mutation := newMailboxPermFlagMutation(c.config, OpCreate) @@ -653,7 +808,7 @@ func (c *MailboxPermFlagClient) DeleteOne(mpf *MailboxPermFlag) *MailboxPermFlag return c.DeleteOneID(mpf.ID) } -// DeleteOne returns a builder for deleting the given entity by its id. +// DeleteOneID returns a builder for deleting the given entity by its id. func (c *MailboxPermFlagClient) DeleteOneID(id int) *MailboxPermFlagDeleteOne { builder := c.Delete().Where(mailboxpermflag.ID(id)) builder.mutation.id = &id @@ -665,6 +820,8 @@ func (c *MailboxPermFlagClient) DeleteOneID(id int) *MailboxPermFlagDeleteOne { func (c *MailboxPermFlagClient) Query() *MailboxPermFlagQuery { return &MailboxPermFlagQuery{ config: c.config, + ctx: &QueryContext{Type: TypeMailboxPermFlag}, + inters: c.Interceptors(), } } @@ -687,6 +844,26 @@ func (c *MailboxPermFlagClient) Hooks() []Hook { return c.hooks.MailboxPermFlag } +// Interceptors returns the client interceptors. +func (c *MailboxPermFlagClient) Interceptors() []Interceptor { + return c.inters.MailboxPermFlag +} + +func (c *MailboxPermFlagClient) mutate(ctx context.Context, m *MailboxPermFlagMutation) (Value, error) { + switch m.Op() { + case OpCreate: + return (&MailboxPermFlagCreate{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx) + case OpUpdate: + return (&MailboxPermFlagUpdate{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx) + case OpUpdateOne: + return (&MailboxPermFlagUpdateOne{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx) + case OpDelete, OpDeleteOne: + return (&MailboxPermFlagDelete{config: c.config, hooks: c.Hooks(), mutation: m}).Exec(ctx) + default: + return nil, fmt.Errorf("internal: unknown MailboxPermFlag mutation op: %q", m.Op()) + } +} + // MessageClient is a client for the Message schema. type MessageClient struct { config @@ -703,6 +880,12 @@ func (c *MessageClient) Use(hooks ...Hook) { c.hooks.Message = append(c.hooks.Message, hooks...) } +// Intercept adds a list of query interceptors to the interceptors stack. +// A call to `Intercept(f, g, h)` equals to `message.Intercept(f(g(h())))`. +func (c *MessageClient) Intercept(interceptors ...Interceptor) { + c.inters.Message = append(c.inters.Message, interceptors...) +} + // Create returns a builder for creating a Message entity. func (c *MessageClient) Create() *MessageCreate { mutation := newMessageMutation(c.config, OpCreate) @@ -743,7 +926,7 @@ func (c *MessageClient) DeleteOne(m *Message) *MessageDeleteOne { return c.DeleteOneID(m.ID) } -// DeleteOne returns a builder for deleting the given entity by its id. +// DeleteOneID returns a builder for deleting the given entity by its id. func (c *MessageClient) DeleteOneID(id imap.InternalMessageID) *MessageDeleteOne { builder := c.Delete().Where(message.ID(id)) builder.mutation.id = &id @@ -755,6 +938,8 @@ func (c *MessageClient) DeleteOneID(id imap.InternalMessageID) *MessageDeleteOne func (c *MessageClient) Query() *MessageQuery { return &MessageQuery{ config: c.config, + ctx: &QueryContext{Type: TypeMessage}, + inters: c.Interceptors(), } } @@ -774,8 +959,8 @@ func (c *MessageClient) GetX(ctx context.Context, id imap.InternalMessageID) *Me // QueryFlags queries the flags edge of a Message. func (c *MessageClient) QueryFlags(m *Message) *MessageFlagQuery { - query := &MessageFlagQuery{config: c.config} - query.path = func(ctx context.Context) (fromV *sql.Selector, _ error) { + query := (&MessageFlagClient{config: c.config}).Query() + query.path = func(context.Context) (fromV *sql.Selector, _ error) { id := m.ID step := sqlgraph.NewStep( sqlgraph.From(message.Table, message.FieldID, id), @@ -790,8 +975,8 @@ func (c *MessageClient) QueryFlags(m *Message) *MessageFlagQuery { // QueryUIDs queries the UIDs edge of a Message. func (c *MessageClient) QueryUIDs(m *Message) *UIDQuery { - query := &UIDQuery{config: c.config} - query.path = func(ctx context.Context) (fromV *sql.Selector, _ error) { + query := (&UIDClient{config: c.config}).Query() + query.path = func(context.Context) (fromV *sql.Selector, _ error) { id := m.ID step := sqlgraph.NewStep( sqlgraph.From(message.Table, message.FieldID, id), @@ -809,6 +994,26 @@ func (c *MessageClient) Hooks() []Hook { return c.hooks.Message } +// Interceptors returns the client interceptors. +func (c *MessageClient) Interceptors() []Interceptor { + return c.inters.Message +} + +func (c *MessageClient) mutate(ctx context.Context, m *MessageMutation) (Value, error) { + switch m.Op() { + case OpCreate: + return (&MessageCreate{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx) + case OpUpdate: + return (&MessageUpdate{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx) + case OpUpdateOne: + return (&MessageUpdateOne{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx) + case OpDelete, OpDeleteOne: + return (&MessageDelete{config: c.config, hooks: c.Hooks(), mutation: m}).Exec(ctx) + default: + return nil, fmt.Errorf("internal: unknown Message mutation op: %q", m.Op()) + } +} + // MessageFlagClient is a client for the MessageFlag schema. type MessageFlagClient struct { config @@ -825,6 +1030,12 @@ func (c *MessageFlagClient) Use(hooks ...Hook) { c.hooks.MessageFlag = append(c.hooks.MessageFlag, hooks...) } +// Intercept adds a list of query interceptors to the interceptors stack. +// A call to `Intercept(f, g, h)` equals to `messageflag.Intercept(f(g(h())))`. +func (c *MessageFlagClient) Intercept(interceptors ...Interceptor) { + c.inters.MessageFlag = append(c.inters.MessageFlag, interceptors...) +} + // Create returns a builder for creating a MessageFlag entity. func (c *MessageFlagClient) Create() *MessageFlagCreate { mutation := newMessageFlagMutation(c.config, OpCreate) @@ -865,7 +1076,7 @@ func (c *MessageFlagClient) DeleteOne(mf *MessageFlag) *MessageFlagDeleteOne { return c.DeleteOneID(mf.ID) } -// DeleteOne returns a builder for deleting the given entity by its id. +// DeleteOneID returns a builder for deleting the given entity by its id. func (c *MessageFlagClient) DeleteOneID(id int) *MessageFlagDeleteOne { builder := c.Delete().Where(messageflag.ID(id)) builder.mutation.id = &id @@ -877,6 +1088,8 @@ func (c *MessageFlagClient) DeleteOneID(id int) *MessageFlagDeleteOne { func (c *MessageFlagClient) Query() *MessageFlagQuery { return &MessageFlagQuery{ config: c.config, + ctx: &QueryContext{Type: TypeMessageFlag}, + inters: c.Interceptors(), } } @@ -896,8 +1109,8 @@ func (c *MessageFlagClient) GetX(ctx context.Context, id int) *MessageFlag { // QueryMessages queries the messages edge of a MessageFlag. func (c *MessageFlagClient) QueryMessages(mf *MessageFlag) *MessageQuery { - query := &MessageQuery{config: c.config} - query.path = func(ctx context.Context) (fromV *sql.Selector, _ error) { + query := (&MessageClient{config: c.config}).Query() + query.path = func(context.Context) (fromV *sql.Selector, _ error) { id := mf.ID step := sqlgraph.NewStep( sqlgraph.From(messageflag.Table, messageflag.FieldID, id), @@ -915,6 +1128,26 @@ func (c *MessageFlagClient) Hooks() []Hook { return c.hooks.MessageFlag } +// Interceptors returns the client interceptors. +func (c *MessageFlagClient) Interceptors() []Interceptor { + return c.inters.MessageFlag +} + +func (c *MessageFlagClient) mutate(ctx context.Context, m *MessageFlagMutation) (Value, error) { + switch m.Op() { + case OpCreate: + return (&MessageFlagCreate{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx) + case OpUpdate: + return (&MessageFlagUpdate{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx) + case OpUpdateOne: + return (&MessageFlagUpdateOne{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx) + case OpDelete, OpDeleteOne: + return (&MessageFlagDelete{config: c.config, hooks: c.Hooks(), mutation: m}).Exec(ctx) + default: + return nil, fmt.Errorf("internal: unknown MessageFlag mutation op: %q", m.Op()) + } +} + // UIDClient is a client for the UID schema. type UIDClient struct { config @@ -931,6 +1164,12 @@ func (c *UIDClient) Use(hooks ...Hook) { c.hooks.UID = append(c.hooks.UID, hooks...) } +// Intercept adds a list of query interceptors to the interceptors stack. +// A call to `Intercept(f, g, h)` equals to `uid.Intercept(f(g(h())))`. +func (c *UIDClient) Intercept(interceptors ...Interceptor) { + c.inters.UID = append(c.inters.UID, interceptors...) +} + // Create returns a builder for creating a UID entity. func (c *UIDClient) Create() *UIDCreate { mutation := newUIDMutation(c.config, OpCreate) @@ -971,7 +1210,7 @@ func (c *UIDClient) DeleteOne(u *UID) *UIDDeleteOne { return c.DeleteOneID(u.ID) } -// DeleteOne returns a builder for deleting the given entity by its id. +// DeleteOneID returns a builder for deleting the given entity by its id. func (c *UIDClient) DeleteOneID(id int) *UIDDeleteOne { builder := c.Delete().Where(uid.ID(id)) builder.mutation.id = &id @@ -983,6 +1222,8 @@ func (c *UIDClient) DeleteOneID(id int) *UIDDeleteOne { func (c *UIDClient) Query() *UIDQuery { return &UIDQuery{ config: c.config, + ctx: &QueryContext{Type: TypeUID}, + inters: c.Interceptors(), } } @@ -1002,8 +1243,8 @@ func (c *UIDClient) GetX(ctx context.Context, id int) *UID { // QueryMessage queries the message edge of a UID. func (c *UIDClient) QueryMessage(u *UID) *MessageQuery { - query := &MessageQuery{config: c.config} - query.path = func(ctx context.Context) (fromV *sql.Selector, _ error) { + query := (&MessageClient{config: c.config}).Query() + query.path = func(context.Context) (fromV *sql.Selector, _ error) { id := u.ID step := sqlgraph.NewStep( sqlgraph.From(uid.Table, uid.FieldID, id), @@ -1018,8 +1259,8 @@ func (c *UIDClient) QueryMessage(u *UID) *MessageQuery { // QueryMailbox queries the mailbox edge of a UID. func (c *UIDClient) QueryMailbox(u *UID) *MailboxQuery { - query := &MailboxQuery{config: c.config} - query.path = func(ctx context.Context) (fromV *sql.Selector, _ error) { + query := (&MailboxClient{config: c.config}).Query() + query.path = func(context.Context) (fromV *sql.Selector, _ error) { id := u.ID step := sqlgraph.NewStep( sqlgraph.From(uid.Table, uid.FieldID, id), @@ -1036,3 +1277,23 @@ func (c *UIDClient) QueryMailbox(u *UID) *MailboxQuery { func (c *UIDClient) Hooks() []Hook { return c.hooks.UID } + +// Interceptors returns the client interceptors. +func (c *UIDClient) Interceptors() []Interceptor { + return c.inters.UID +} + +func (c *UIDClient) mutate(ctx context.Context, m *UIDMutation) (Value, error) { + switch m.Op() { + case OpCreate: + return (&UIDCreate{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx) + case OpUpdate: + return (&UIDUpdate{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx) + case OpUpdateOne: + return (&UIDUpdateOne{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx) + case OpDelete, OpDeleteOne: + return (&UIDDelete{config: c.config, hooks: c.Hooks(), mutation: m}).Exec(ctx) + default: + return nil, fmt.Errorf("internal: unknown UID mutation op: %q", m.Op()) + } +} diff --git a/internal/db/ent/config.go b/internal/db_impl/ent_db/internal/config.go similarity index 55% rename from internal/db/ent/config.go rename to internal/db_impl/ent_db/internal/config.go index 3f650be6..68d3879b 100644 --- a/internal/db/ent/config.go +++ b/internal/db_impl/ent_db/internal/config.go @@ -1,6 +1,6 @@ // Code generated by ent, DO NOT EDIT. -package ent +package internal import ( "entgo.io/ent" @@ -17,22 +17,36 @@ type config struct { // debug enable a debug logging. debug bool // log used for logging on debug mode. - log func(...interface{}) + log func(...any) // hooks to execute on mutations. hooks *hooks + // interceptors to execute on queries. + inters *inters } -// hooks per client, for fast access. -type hooks struct { - DeletedSubscription []ent.Hook - Mailbox []ent.Hook - MailboxAttr []ent.Hook - MailboxFlag []ent.Hook - MailboxPermFlag []ent.Hook - Message []ent.Hook - MessageFlag []ent.Hook - UID []ent.Hook -} +// hooks and interceptors per client, for fast access. +type ( + hooks struct { + DeletedSubscription []ent.Hook + Mailbox []ent.Hook + MailboxAttr []ent.Hook + MailboxFlag []ent.Hook + MailboxPermFlag []ent.Hook + Message []ent.Hook + MessageFlag []ent.Hook + UID []ent.Hook + } + inters struct { + DeletedSubscription []ent.Interceptor + Mailbox []ent.Interceptor + MailboxAttr []ent.Interceptor + MailboxFlag []ent.Interceptor + MailboxPermFlag []ent.Interceptor + Message []ent.Interceptor + MessageFlag []ent.Interceptor + UID []ent.Interceptor + } +) // Options applies the options on the config object. func (c *config) options(opts ...Option) { @@ -52,7 +66,7 @@ func Debug() Option { } // Log sets the logging function for debug mode. -func Log(fn func(...interface{})) Option { +func Log(fn func(...any)) Option { return func(c *config) { c.log = fn } diff --git a/internal/db/ent/context.go b/internal/db_impl/ent_db/internal/context.go similarity index 98% rename from internal/db/ent/context.go rename to internal/db_impl/ent_db/internal/context.go index 7811bfa2..da69c313 100644 --- a/internal/db/ent/context.go +++ b/internal/db_impl/ent_db/internal/context.go @@ -1,6 +1,6 @@ // Code generated by ent, DO NOT EDIT. -package ent +package internal import ( "context" diff --git a/internal/db/ent/deletedsubscription.go b/internal/db_impl/ent_db/internal/deletedsubscription.go similarity index 87% rename from internal/db/ent/deletedsubscription.go rename to internal/db_impl/ent_db/internal/deletedsubscription.go index bdf2ce8a..94b50f29 100644 --- a/internal/db/ent/deletedsubscription.go +++ b/internal/db_impl/ent_db/internal/deletedsubscription.go @@ -1,6 +1,6 @@ // Code generated by ent, DO NOT EDIT. -package ent +package internal import ( "fmt" @@ -8,7 +8,7 @@ import ( "entgo.io/ent/dialect/sql" "github.com/ProtonMail/gluon/imap" - "github.com/ProtonMail/gluon/internal/db/ent/deletedsubscription" + "github.com/ProtonMail/gluon/internal/db_impl/ent_db/internal/deletedsubscription" ) // DeletedSubscription is the model entity for the DeletedSubscription schema. @@ -23,8 +23,8 @@ type DeletedSubscription struct { } // scanValues returns the types for scanning values from sql.Rows. -func (*DeletedSubscription) scanValues(columns []string) ([]interface{}, error) { - values := make([]interface{}, len(columns)) +func (*DeletedSubscription) scanValues(columns []string) ([]any, error) { + values := make([]any, len(columns)) for i := range columns { switch columns[i] { case deletedsubscription.FieldID: @@ -40,7 +40,7 @@ func (*DeletedSubscription) scanValues(columns []string) ([]interface{}, error) // assignValues assigns the values that were returned from sql.Rows (after scanning) // to the DeletedSubscription fields. -func (ds *DeletedSubscription) assignValues(columns []string, values []interface{}) error { +func (ds *DeletedSubscription) assignValues(columns []string, values []any) error { if m, n := len(values), len(columns); m < n { return fmt.Errorf("mismatch number of scan values: %d != %d", m, n) } @@ -73,7 +73,7 @@ func (ds *DeletedSubscription) assignValues(columns []string, values []interface // Note that you need to call DeletedSubscription.Unwrap() before calling this method if this DeletedSubscription // was returned from a transaction, and the transaction was committed or rolled back. func (ds *DeletedSubscription) Update() *DeletedSubscriptionUpdateOne { - return (&DeletedSubscriptionClient{config: ds.config}).UpdateOne(ds) + return NewDeletedSubscriptionClient(ds.config).UpdateOne(ds) } // Unwrap unwraps the DeletedSubscription entity that was returned from a transaction after it was closed, @@ -81,7 +81,7 @@ func (ds *DeletedSubscription) Update() *DeletedSubscriptionUpdateOne { func (ds *DeletedSubscription) Unwrap() *DeletedSubscription { _tx, ok := ds.config.driver.(*txDriver) if !ok { - panic("ent: DeletedSubscription is not a transactional entity") + panic("internal: DeletedSubscription is not a transactional entity") } ds.config.driver = _tx.drv return ds @@ -103,9 +103,3 @@ func (ds *DeletedSubscription) String() string { // DeletedSubscriptions is a parsable slice of DeletedSubscription. type DeletedSubscriptions []*DeletedSubscription - -func (ds DeletedSubscriptions) config(cfg config) { - for _i := range ds { - ds[_i].config = cfg - } -} diff --git a/internal/db/ent/deletedsubscription/deletedsubscription.go b/internal/db_impl/ent_db/internal/deletedsubscription/deletedsubscription.go similarity index 100% rename from internal/db/ent/deletedsubscription/deletedsubscription.go rename to internal/db_impl/ent_db/internal/deletedsubscription/deletedsubscription.go diff --git a/internal/db/ent/deletedsubscription/where.go b/internal/db_impl/ent_db/internal/deletedsubscription/where.go similarity index 57% rename from internal/db/ent/deletedsubscription/where.go rename to internal/db_impl/ent_db/internal/deletedsubscription/where.go index d43f6049..66d35e48 100644 --- a/internal/db/ent/deletedsubscription/where.go +++ b/internal/db_impl/ent_db/internal/deletedsubscription/where.go @@ -5,302 +5,212 @@ package deletedsubscription import ( "entgo.io/ent/dialect/sql" "github.com/ProtonMail/gluon/imap" - "github.com/ProtonMail/gluon/internal/db/ent/predicate" + "github.com/ProtonMail/gluon/internal/db_impl/ent_db/internal/predicate" ) // ID filters vertices based on their ID field. func ID(id int) predicate.DeletedSubscription { - return predicate.DeletedSubscription(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldID), id)) - }) + return predicate.DeletedSubscription(sql.FieldEQ(FieldID, id)) } // IDEQ applies the EQ predicate on the ID field. func IDEQ(id int) predicate.DeletedSubscription { - return predicate.DeletedSubscription(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldID), id)) - }) + return predicate.DeletedSubscription(sql.FieldEQ(FieldID, id)) } // IDNEQ applies the NEQ predicate on the ID field. func IDNEQ(id int) predicate.DeletedSubscription { - return predicate.DeletedSubscription(func(s *sql.Selector) { - s.Where(sql.NEQ(s.C(FieldID), id)) - }) + return predicate.DeletedSubscription(sql.FieldNEQ(FieldID, id)) } // IDIn applies the In predicate on the ID field. func IDIn(ids ...int) predicate.DeletedSubscription { - return predicate.DeletedSubscription(func(s *sql.Selector) { - v := make([]interface{}, len(ids)) - for i := range v { - v[i] = ids[i] - } - s.Where(sql.In(s.C(FieldID), v...)) - }) + return predicate.DeletedSubscription(sql.FieldIn(FieldID, ids...)) } // IDNotIn applies the NotIn predicate on the ID field. func IDNotIn(ids ...int) predicate.DeletedSubscription { - return predicate.DeletedSubscription(func(s *sql.Selector) { - v := make([]interface{}, len(ids)) - for i := range v { - v[i] = ids[i] - } - s.Where(sql.NotIn(s.C(FieldID), v...)) - }) + return predicate.DeletedSubscription(sql.FieldNotIn(FieldID, ids...)) } // IDGT applies the GT predicate on the ID field. func IDGT(id int) predicate.DeletedSubscription { - return predicate.DeletedSubscription(func(s *sql.Selector) { - s.Where(sql.GT(s.C(FieldID), id)) - }) + return predicate.DeletedSubscription(sql.FieldGT(FieldID, id)) } // IDGTE applies the GTE predicate on the ID field. func IDGTE(id int) predicate.DeletedSubscription { - return predicate.DeletedSubscription(func(s *sql.Selector) { - s.Where(sql.GTE(s.C(FieldID), id)) - }) + return predicate.DeletedSubscription(sql.FieldGTE(FieldID, id)) } // IDLT applies the LT predicate on the ID field. func IDLT(id int) predicate.DeletedSubscription { - return predicate.DeletedSubscription(func(s *sql.Selector) { - s.Where(sql.LT(s.C(FieldID), id)) - }) + return predicate.DeletedSubscription(sql.FieldLT(FieldID, id)) } // IDLTE applies the LTE predicate on the ID field. func IDLTE(id int) predicate.DeletedSubscription { - return predicate.DeletedSubscription(func(s *sql.Selector) { - s.Where(sql.LTE(s.C(FieldID), id)) - }) + return predicate.DeletedSubscription(sql.FieldLTE(FieldID, id)) } // Name applies equality check predicate on the "Name" field. It's identical to NameEQ. func Name(v string) predicate.DeletedSubscription { - return predicate.DeletedSubscription(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldName), v)) - }) + return predicate.DeletedSubscription(sql.FieldEQ(FieldName, v)) } // RemoteID applies equality check predicate on the "RemoteID" field. It's identical to RemoteIDEQ. func RemoteID(v imap.MailboxID) predicate.DeletedSubscription { vc := string(v) - return predicate.DeletedSubscription(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldRemoteID), vc)) - }) + return predicate.DeletedSubscription(sql.FieldEQ(FieldRemoteID, vc)) } // NameEQ applies the EQ predicate on the "Name" field. func NameEQ(v string) predicate.DeletedSubscription { - return predicate.DeletedSubscription(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldName), v)) - }) + return predicate.DeletedSubscription(sql.FieldEQ(FieldName, v)) } // NameNEQ applies the NEQ predicate on the "Name" field. func NameNEQ(v string) predicate.DeletedSubscription { - return predicate.DeletedSubscription(func(s *sql.Selector) { - s.Where(sql.NEQ(s.C(FieldName), v)) - }) + return predicate.DeletedSubscription(sql.FieldNEQ(FieldName, v)) } // NameIn applies the In predicate on the "Name" field. func NameIn(vs ...string) predicate.DeletedSubscription { - v := make([]interface{}, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.DeletedSubscription(func(s *sql.Selector) { - s.Where(sql.In(s.C(FieldName), v...)) - }) + return predicate.DeletedSubscription(sql.FieldIn(FieldName, vs...)) } // NameNotIn applies the NotIn predicate on the "Name" field. func NameNotIn(vs ...string) predicate.DeletedSubscription { - v := make([]interface{}, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.DeletedSubscription(func(s *sql.Selector) { - s.Where(sql.NotIn(s.C(FieldName), v...)) - }) + return predicate.DeletedSubscription(sql.FieldNotIn(FieldName, vs...)) } // NameGT applies the GT predicate on the "Name" field. func NameGT(v string) predicate.DeletedSubscription { - return predicate.DeletedSubscription(func(s *sql.Selector) { - s.Where(sql.GT(s.C(FieldName), v)) - }) + return predicate.DeletedSubscription(sql.FieldGT(FieldName, v)) } // NameGTE applies the GTE predicate on the "Name" field. func NameGTE(v string) predicate.DeletedSubscription { - return predicate.DeletedSubscription(func(s *sql.Selector) { - s.Where(sql.GTE(s.C(FieldName), v)) - }) + return predicate.DeletedSubscription(sql.FieldGTE(FieldName, v)) } // NameLT applies the LT predicate on the "Name" field. func NameLT(v string) predicate.DeletedSubscription { - return predicate.DeletedSubscription(func(s *sql.Selector) { - s.Where(sql.LT(s.C(FieldName), v)) - }) + return predicate.DeletedSubscription(sql.FieldLT(FieldName, v)) } // NameLTE applies the LTE predicate on the "Name" field. func NameLTE(v string) predicate.DeletedSubscription { - return predicate.DeletedSubscription(func(s *sql.Selector) { - s.Where(sql.LTE(s.C(FieldName), v)) - }) + return predicate.DeletedSubscription(sql.FieldLTE(FieldName, v)) } // NameContains applies the Contains predicate on the "Name" field. func NameContains(v string) predicate.DeletedSubscription { - return predicate.DeletedSubscription(func(s *sql.Selector) { - s.Where(sql.Contains(s.C(FieldName), v)) - }) + return predicate.DeletedSubscription(sql.FieldContains(FieldName, v)) } // NameHasPrefix applies the HasPrefix predicate on the "Name" field. func NameHasPrefix(v string) predicate.DeletedSubscription { - return predicate.DeletedSubscription(func(s *sql.Selector) { - s.Where(sql.HasPrefix(s.C(FieldName), v)) - }) + return predicate.DeletedSubscription(sql.FieldHasPrefix(FieldName, v)) } // NameHasSuffix applies the HasSuffix predicate on the "Name" field. func NameHasSuffix(v string) predicate.DeletedSubscription { - return predicate.DeletedSubscription(func(s *sql.Selector) { - s.Where(sql.HasSuffix(s.C(FieldName), v)) - }) + return predicate.DeletedSubscription(sql.FieldHasSuffix(FieldName, v)) } // NameEqualFold applies the EqualFold predicate on the "Name" field. func NameEqualFold(v string) predicate.DeletedSubscription { - return predicate.DeletedSubscription(func(s *sql.Selector) { - s.Where(sql.EqualFold(s.C(FieldName), v)) - }) + return predicate.DeletedSubscription(sql.FieldEqualFold(FieldName, v)) } // NameContainsFold applies the ContainsFold predicate on the "Name" field. func NameContainsFold(v string) predicate.DeletedSubscription { - return predicate.DeletedSubscription(func(s *sql.Selector) { - s.Where(sql.ContainsFold(s.C(FieldName), v)) - }) + return predicate.DeletedSubscription(sql.FieldContainsFold(FieldName, v)) } // RemoteIDEQ applies the EQ predicate on the "RemoteID" field. func RemoteIDEQ(v imap.MailboxID) predicate.DeletedSubscription { vc := string(v) - return predicate.DeletedSubscription(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldRemoteID), vc)) - }) + return predicate.DeletedSubscription(sql.FieldEQ(FieldRemoteID, vc)) } // RemoteIDNEQ applies the NEQ predicate on the "RemoteID" field. func RemoteIDNEQ(v imap.MailboxID) predicate.DeletedSubscription { vc := string(v) - return predicate.DeletedSubscription(func(s *sql.Selector) { - s.Where(sql.NEQ(s.C(FieldRemoteID), vc)) - }) + return predicate.DeletedSubscription(sql.FieldNEQ(FieldRemoteID, vc)) } // RemoteIDIn applies the In predicate on the "RemoteID" field. func RemoteIDIn(vs ...imap.MailboxID) predicate.DeletedSubscription { - v := make([]interface{}, len(vs)) + v := make([]any, len(vs)) for i := range v { v[i] = string(vs[i]) } - return predicate.DeletedSubscription(func(s *sql.Selector) { - s.Where(sql.In(s.C(FieldRemoteID), v...)) - }) + return predicate.DeletedSubscription(sql.FieldIn(FieldRemoteID, v...)) } // RemoteIDNotIn applies the NotIn predicate on the "RemoteID" field. func RemoteIDNotIn(vs ...imap.MailboxID) predicate.DeletedSubscription { - v := make([]interface{}, len(vs)) + v := make([]any, len(vs)) for i := range v { v[i] = string(vs[i]) } - return predicate.DeletedSubscription(func(s *sql.Selector) { - s.Where(sql.NotIn(s.C(FieldRemoteID), v...)) - }) + return predicate.DeletedSubscription(sql.FieldNotIn(FieldRemoteID, v...)) } // RemoteIDGT applies the GT predicate on the "RemoteID" field. func RemoteIDGT(v imap.MailboxID) predicate.DeletedSubscription { vc := string(v) - return predicate.DeletedSubscription(func(s *sql.Selector) { - s.Where(sql.GT(s.C(FieldRemoteID), vc)) - }) + return predicate.DeletedSubscription(sql.FieldGT(FieldRemoteID, vc)) } // RemoteIDGTE applies the GTE predicate on the "RemoteID" field. func RemoteIDGTE(v imap.MailboxID) predicate.DeletedSubscription { vc := string(v) - return predicate.DeletedSubscription(func(s *sql.Selector) { - s.Where(sql.GTE(s.C(FieldRemoteID), vc)) - }) + return predicate.DeletedSubscription(sql.FieldGTE(FieldRemoteID, vc)) } // RemoteIDLT applies the LT predicate on the "RemoteID" field. func RemoteIDLT(v imap.MailboxID) predicate.DeletedSubscription { vc := string(v) - return predicate.DeletedSubscription(func(s *sql.Selector) { - s.Where(sql.LT(s.C(FieldRemoteID), vc)) - }) + return predicate.DeletedSubscription(sql.FieldLT(FieldRemoteID, vc)) } // RemoteIDLTE applies the LTE predicate on the "RemoteID" field. func RemoteIDLTE(v imap.MailboxID) predicate.DeletedSubscription { vc := string(v) - return predicate.DeletedSubscription(func(s *sql.Selector) { - s.Where(sql.LTE(s.C(FieldRemoteID), vc)) - }) + return predicate.DeletedSubscription(sql.FieldLTE(FieldRemoteID, vc)) } // RemoteIDContains applies the Contains predicate on the "RemoteID" field. func RemoteIDContains(v imap.MailboxID) predicate.DeletedSubscription { vc := string(v) - return predicate.DeletedSubscription(func(s *sql.Selector) { - s.Where(sql.Contains(s.C(FieldRemoteID), vc)) - }) + return predicate.DeletedSubscription(sql.FieldContains(FieldRemoteID, vc)) } // RemoteIDHasPrefix applies the HasPrefix predicate on the "RemoteID" field. func RemoteIDHasPrefix(v imap.MailboxID) predicate.DeletedSubscription { vc := string(v) - return predicate.DeletedSubscription(func(s *sql.Selector) { - s.Where(sql.HasPrefix(s.C(FieldRemoteID), vc)) - }) + return predicate.DeletedSubscription(sql.FieldHasPrefix(FieldRemoteID, vc)) } // RemoteIDHasSuffix applies the HasSuffix predicate on the "RemoteID" field. func RemoteIDHasSuffix(v imap.MailboxID) predicate.DeletedSubscription { vc := string(v) - return predicate.DeletedSubscription(func(s *sql.Selector) { - s.Where(sql.HasSuffix(s.C(FieldRemoteID), vc)) - }) + return predicate.DeletedSubscription(sql.FieldHasSuffix(FieldRemoteID, vc)) } // RemoteIDEqualFold applies the EqualFold predicate on the "RemoteID" field. func RemoteIDEqualFold(v imap.MailboxID) predicate.DeletedSubscription { vc := string(v) - return predicate.DeletedSubscription(func(s *sql.Selector) { - s.Where(sql.EqualFold(s.C(FieldRemoteID), vc)) - }) + return predicate.DeletedSubscription(sql.FieldEqualFold(FieldRemoteID, vc)) } // RemoteIDContainsFold applies the ContainsFold predicate on the "RemoteID" field. func RemoteIDContainsFold(v imap.MailboxID) predicate.DeletedSubscription { vc := string(v) - return predicate.DeletedSubscription(func(s *sql.Selector) { - s.Where(sql.ContainsFold(s.C(FieldRemoteID), vc)) - }) + return predicate.DeletedSubscription(sql.FieldContainsFold(FieldRemoteID, vc)) } // And groups predicates with the AND operator between them. diff --git a/internal/db/ent/deletedsubscription_create.go b/internal/db_impl/ent_db/internal/deletedsubscription_create.go similarity index 72% rename from internal/db/ent/deletedsubscription_create.go rename to internal/db_impl/ent_db/internal/deletedsubscription_create.go index e5c02209..773c24c1 100644 --- a/internal/db/ent/deletedsubscription_create.go +++ b/internal/db_impl/ent_db/internal/deletedsubscription_create.go @@ -1,6 +1,6 @@ // Code generated by ent, DO NOT EDIT. -package ent +package internal import ( "context" @@ -10,7 +10,7 @@ import ( "entgo.io/ent/dialect/sql/sqlgraph" "entgo.io/ent/schema/field" "github.com/ProtonMail/gluon/imap" - "github.com/ProtonMail/gluon/internal/db/ent/deletedsubscription" + "github.com/ProtonMail/gluon/internal/db_impl/ent_db/internal/deletedsubscription" ) // DeletedSubscriptionCreate is the builder for creating a DeletedSubscription entity. @@ -39,49 +39,7 @@ func (dsc *DeletedSubscriptionCreate) Mutation() *DeletedSubscriptionMutation { // Save creates the DeletedSubscription in the database. func (dsc *DeletedSubscriptionCreate) Save(ctx context.Context) (*DeletedSubscription, error) { - var ( - err error - node *DeletedSubscription - ) - if len(dsc.hooks) == 0 { - if err = dsc.check(); err != nil { - return nil, err - } - node, err = dsc.sqlSave(ctx) - } else { - var mut Mutator = MutateFunc(func(ctx context.Context, m Mutation) (Value, error) { - mutation, ok := m.(*DeletedSubscriptionMutation) - if !ok { - return nil, fmt.Errorf("unexpected mutation type %T", m) - } - if err = dsc.check(); err != nil { - return nil, err - } - dsc.mutation = mutation - if node, err = dsc.sqlSave(ctx); err != nil { - return nil, err - } - mutation.id = &node.ID - mutation.done = true - return node, err - }) - for i := len(dsc.hooks) - 1; i >= 0; i-- { - if dsc.hooks[i] == nil { - return nil, fmt.Errorf("ent: uninitialized hook (forgotten import ent/runtime?)") - } - mut = dsc.hooks[i](mut) - } - v, err := mut.Mutate(ctx, dsc.mutation) - if err != nil { - return nil, err - } - nv, ok := v.(*DeletedSubscription) - if !ok { - return nil, fmt.Errorf("unexpected node type %T returned from DeletedSubscriptionMutation", v) - } - node = nv - } - return node, err + return withHooks[*DeletedSubscription, DeletedSubscriptionMutation](ctx, dsc.sqlSave, dsc.mutation, dsc.hooks) } // SaveX calls Save and panics if Save returns an error. @@ -109,15 +67,18 @@ func (dsc *DeletedSubscriptionCreate) ExecX(ctx context.Context) { // check runs all checks and user-defined validators on the builder. func (dsc *DeletedSubscriptionCreate) check() error { if _, ok := dsc.mutation.Name(); !ok { - return &ValidationError{Name: "Name", err: errors.New(`ent: missing required field "DeletedSubscription.Name"`)} + return &ValidationError{Name: "Name", err: errors.New(`internal: missing required field "DeletedSubscription.Name"`)} } if _, ok := dsc.mutation.RemoteID(); !ok { - return &ValidationError{Name: "RemoteID", err: errors.New(`ent: missing required field "DeletedSubscription.RemoteID"`)} + return &ValidationError{Name: "RemoteID", err: errors.New(`internal: missing required field "DeletedSubscription.RemoteID"`)} } return nil } func (dsc *DeletedSubscriptionCreate) sqlSave(ctx context.Context) (*DeletedSubscription, error) { + if err := dsc.check(); err != nil { + return nil, err + } _node, _spec := dsc.createSpec() if err := sqlgraph.CreateNode(ctx, dsc.driver, _spec); err != nil { if sqlgraph.IsConstraintError(err) { @@ -127,34 +88,22 @@ func (dsc *DeletedSubscriptionCreate) sqlSave(ctx context.Context) (*DeletedSubs } id := _spec.ID.Value.(int64) _node.ID = int(id) + dsc.mutation.id = &_node.ID + dsc.mutation.done = true return _node, nil } func (dsc *DeletedSubscriptionCreate) createSpec() (*DeletedSubscription, *sqlgraph.CreateSpec) { var ( _node = &DeletedSubscription{config: dsc.config} - _spec = &sqlgraph.CreateSpec{ - Table: deletedsubscription.Table, - ID: &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Column: deletedsubscription.FieldID, - }, - } + _spec = sqlgraph.NewCreateSpec(deletedsubscription.Table, sqlgraph.NewFieldSpec(deletedsubscription.FieldID, field.TypeInt)) ) if value, ok := dsc.mutation.Name(); ok { - _spec.Fields = append(_spec.Fields, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Value: value, - Column: deletedsubscription.FieldName, - }) + _spec.SetField(deletedsubscription.FieldName, field.TypeString, value) _node.Name = value } if value, ok := dsc.mutation.RemoteID(); ok { - _spec.Fields = append(_spec.Fields, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Value: value, - Column: deletedsubscription.FieldRemoteID, - }) + _spec.SetField(deletedsubscription.FieldRemoteID, field.TypeString, value) _node.RemoteID = value } return _node, _spec diff --git a/internal/db/ent/deletedsubscription_delete.go b/internal/db_impl/ent_db/internal/deletedsubscription_delete.go similarity index 63% rename from internal/db/ent/deletedsubscription_delete.go rename to internal/db_impl/ent_db/internal/deletedsubscription_delete.go index df358765..f55bbfca 100644 --- a/internal/db/ent/deletedsubscription_delete.go +++ b/internal/db_impl/ent_db/internal/deletedsubscription_delete.go @@ -1,16 +1,15 @@ // Code generated by ent, DO NOT EDIT. -package ent +package internal import ( "context" - "fmt" "entgo.io/ent/dialect/sql" "entgo.io/ent/dialect/sql/sqlgraph" "entgo.io/ent/schema/field" - "github.com/ProtonMail/gluon/internal/db/ent/deletedsubscription" - "github.com/ProtonMail/gluon/internal/db/ent/predicate" + "github.com/ProtonMail/gluon/internal/db_impl/ent_db/internal/deletedsubscription" + "github.com/ProtonMail/gluon/internal/db_impl/ent_db/internal/predicate" ) // DeletedSubscriptionDelete is the builder for deleting a DeletedSubscription entity. @@ -28,34 +27,7 @@ func (dsd *DeletedSubscriptionDelete) Where(ps ...predicate.DeletedSubscription) // Exec executes the deletion query and returns how many vertices were deleted. func (dsd *DeletedSubscriptionDelete) Exec(ctx context.Context) (int, error) { - var ( - err error - affected int - ) - if len(dsd.hooks) == 0 { - affected, err = dsd.sqlExec(ctx) - } else { - var mut Mutator = MutateFunc(func(ctx context.Context, m Mutation) (Value, error) { - mutation, ok := m.(*DeletedSubscriptionMutation) - if !ok { - return nil, fmt.Errorf("unexpected mutation type %T", m) - } - dsd.mutation = mutation - affected, err = dsd.sqlExec(ctx) - mutation.done = true - return affected, err - }) - for i := len(dsd.hooks) - 1; i >= 0; i-- { - if dsd.hooks[i] == nil { - return 0, fmt.Errorf("ent: uninitialized hook (forgotten import ent/runtime?)") - } - mut = dsd.hooks[i](mut) - } - if _, err := mut.Mutate(ctx, dsd.mutation); err != nil { - return 0, err - } - } - return affected, err + return withHooks[int, DeletedSubscriptionMutation](ctx, dsd.sqlExec, dsd.mutation, dsd.hooks) } // ExecX is like Exec, but panics if an error occurs. @@ -68,15 +40,7 @@ func (dsd *DeletedSubscriptionDelete) ExecX(ctx context.Context) int { } func (dsd *DeletedSubscriptionDelete) sqlExec(ctx context.Context) (int, error) { - _spec := &sqlgraph.DeleteSpec{ - Node: &sqlgraph.NodeSpec{ - Table: deletedsubscription.Table, - ID: &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Column: deletedsubscription.FieldID, - }, - }, - } + _spec := sqlgraph.NewDeleteSpec(deletedsubscription.Table, sqlgraph.NewFieldSpec(deletedsubscription.FieldID, field.TypeInt)) if ps := dsd.mutation.predicates; len(ps) > 0 { _spec.Predicate = func(selector *sql.Selector) { for i := range ps { @@ -88,6 +52,7 @@ func (dsd *DeletedSubscriptionDelete) sqlExec(ctx context.Context) (int, error) if err != nil && sqlgraph.IsConstraintError(err) { err = &ConstraintError{msg: err.Error(), wrap: err} } + dsd.mutation.done = true return affected, err } @@ -96,6 +61,12 @@ type DeletedSubscriptionDeleteOne struct { dsd *DeletedSubscriptionDelete } +// Where appends a list predicates to the DeletedSubscriptionDelete builder. +func (dsdo *DeletedSubscriptionDeleteOne) Where(ps ...predicate.DeletedSubscription) *DeletedSubscriptionDeleteOne { + dsdo.dsd.mutation.Where(ps...) + return dsdo +} + // Exec executes the deletion query. func (dsdo *DeletedSubscriptionDeleteOne) Exec(ctx context.Context) error { n, err := dsdo.dsd.Exec(ctx) @@ -111,5 +82,7 @@ func (dsdo *DeletedSubscriptionDeleteOne) Exec(ctx context.Context) error { // ExecX is like Exec, but panics if an error occurs. func (dsdo *DeletedSubscriptionDeleteOne) ExecX(ctx context.Context) { - dsdo.dsd.ExecX(ctx) + if err := dsdo.Exec(ctx); err != nil { + panic(err) + } } diff --git a/internal/db/ent/deletedsubscription_query.go b/internal/db_impl/ent_db/internal/deletedsubscription_query.go similarity index 68% rename from internal/db/ent/deletedsubscription_query.go rename to internal/db_impl/ent_db/internal/deletedsubscription_query.go index 93472a0a..78107132 100644 --- a/internal/db/ent/deletedsubscription_query.go +++ b/internal/db_impl/ent_db/internal/deletedsubscription_query.go @@ -1,6 +1,6 @@ // Code generated by ent, DO NOT EDIT. -package ent +package internal import ( "context" @@ -10,18 +10,16 @@ import ( "entgo.io/ent/dialect/sql" "entgo.io/ent/dialect/sql/sqlgraph" "entgo.io/ent/schema/field" - "github.com/ProtonMail/gluon/internal/db/ent/deletedsubscription" - "github.com/ProtonMail/gluon/internal/db/ent/predicate" + "github.com/ProtonMail/gluon/internal/db_impl/ent_db/internal/deletedsubscription" + "github.com/ProtonMail/gluon/internal/db_impl/ent_db/internal/predicate" ) // DeletedSubscriptionQuery is the builder for querying DeletedSubscription entities. type DeletedSubscriptionQuery struct { config - limit *int - offset *int - unique *bool + ctx *QueryContext order []OrderFunc - fields []string + inters []Interceptor predicates []predicate.DeletedSubscription // intermediate query (i.e. traversal path). sql *sql.Selector @@ -34,26 +32,26 @@ func (dsq *DeletedSubscriptionQuery) Where(ps ...predicate.DeletedSubscription) return dsq } -// Limit adds a limit step to the query. +// Limit the number of records to be returned by this query. func (dsq *DeletedSubscriptionQuery) Limit(limit int) *DeletedSubscriptionQuery { - dsq.limit = &limit + dsq.ctx.Limit = &limit return dsq } -// Offset adds an offset step to the query. +// Offset to start from. func (dsq *DeletedSubscriptionQuery) Offset(offset int) *DeletedSubscriptionQuery { - dsq.offset = &offset + dsq.ctx.Offset = &offset return dsq } // Unique configures the query builder to filter duplicate records on query. // By default, unique is set to true, and can be disabled using this method. func (dsq *DeletedSubscriptionQuery) Unique(unique bool) *DeletedSubscriptionQuery { - dsq.unique = &unique + dsq.ctx.Unique = &unique return dsq } -// Order adds an order step to the query. +// Order specifies how the records should be ordered. func (dsq *DeletedSubscriptionQuery) Order(o ...OrderFunc) *DeletedSubscriptionQuery { dsq.order = append(dsq.order, o...) return dsq @@ -62,7 +60,7 @@ func (dsq *DeletedSubscriptionQuery) Order(o ...OrderFunc) *DeletedSubscriptionQ // First returns the first DeletedSubscription entity from the query. // Returns a *NotFoundError when no DeletedSubscription was found. func (dsq *DeletedSubscriptionQuery) First(ctx context.Context) (*DeletedSubscription, error) { - nodes, err := dsq.Limit(1).All(ctx) + nodes, err := dsq.Limit(1).All(setContextOp(ctx, dsq.ctx, "First")) if err != nil { return nil, err } @@ -85,7 +83,7 @@ func (dsq *DeletedSubscriptionQuery) FirstX(ctx context.Context) *DeletedSubscri // Returns a *NotFoundError when no DeletedSubscription ID was found. func (dsq *DeletedSubscriptionQuery) FirstID(ctx context.Context) (id int, err error) { var ids []int - if ids, err = dsq.Limit(1).IDs(ctx); err != nil { + if ids, err = dsq.Limit(1).IDs(setContextOp(ctx, dsq.ctx, "FirstID")); err != nil { return } if len(ids) == 0 { @@ -108,7 +106,7 @@ func (dsq *DeletedSubscriptionQuery) FirstIDX(ctx context.Context) int { // Returns a *NotSingularError when more than one DeletedSubscription entity is found. // Returns a *NotFoundError when no DeletedSubscription entities are found. func (dsq *DeletedSubscriptionQuery) Only(ctx context.Context) (*DeletedSubscription, error) { - nodes, err := dsq.Limit(2).All(ctx) + nodes, err := dsq.Limit(2).All(setContextOp(ctx, dsq.ctx, "Only")) if err != nil { return nil, err } @@ -136,7 +134,7 @@ func (dsq *DeletedSubscriptionQuery) OnlyX(ctx context.Context) *DeletedSubscrip // Returns a *NotFoundError when no entities are found. func (dsq *DeletedSubscriptionQuery) OnlyID(ctx context.Context) (id int, err error) { var ids []int - if ids, err = dsq.Limit(2).IDs(ctx); err != nil { + if ids, err = dsq.Limit(2).IDs(setContextOp(ctx, dsq.ctx, "OnlyID")); err != nil { return } switch len(ids) { @@ -161,10 +159,12 @@ func (dsq *DeletedSubscriptionQuery) OnlyIDX(ctx context.Context) int { // All executes the query and returns a list of DeletedSubscriptions. func (dsq *DeletedSubscriptionQuery) All(ctx context.Context) ([]*DeletedSubscription, error) { + ctx = setContextOp(ctx, dsq.ctx, "All") if err := dsq.prepareQuery(ctx); err != nil { return nil, err } - return dsq.sqlAll(ctx) + qr := querierAll[[]*DeletedSubscription, *DeletedSubscriptionQuery]() + return withInterceptors[[]*DeletedSubscription](ctx, dsq, qr, dsq.inters) } // AllX is like All, but panics if an error occurs. @@ -177,9 +177,12 @@ func (dsq *DeletedSubscriptionQuery) AllX(ctx context.Context) []*DeletedSubscri } // IDs executes the query and returns a list of DeletedSubscription IDs. -func (dsq *DeletedSubscriptionQuery) IDs(ctx context.Context) ([]int, error) { - var ids []int - if err := dsq.Select(deletedsubscription.FieldID).Scan(ctx, &ids); err != nil { +func (dsq *DeletedSubscriptionQuery) IDs(ctx context.Context) (ids []int, err error) { + if dsq.ctx.Unique == nil && dsq.path != nil { + dsq.Unique(true) + } + ctx = setContextOp(ctx, dsq.ctx, "IDs") + if err = dsq.Select(deletedsubscription.FieldID).Scan(ctx, &ids); err != nil { return nil, err } return ids, nil @@ -196,10 +199,11 @@ func (dsq *DeletedSubscriptionQuery) IDsX(ctx context.Context) []int { // Count returns the count of the given query. func (dsq *DeletedSubscriptionQuery) Count(ctx context.Context) (int, error) { + ctx = setContextOp(ctx, dsq.ctx, "Count") if err := dsq.prepareQuery(ctx); err != nil { return 0, err } - return dsq.sqlCount(ctx) + return withInterceptors[int](ctx, dsq, querierCount[*DeletedSubscriptionQuery](), dsq.inters) } // CountX is like Count, but panics if an error occurs. @@ -213,10 +217,15 @@ func (dsq *DeletedSubscriptionQuery) CountX(ctx context.Context) int { // Exist returns true if the query has elements in the graph. func (dsq *DeletedSubscriptionQuery) Exist(ctx context.Context) (bool, error) { - if err := dsq.prepareQuery(ctx); err != nil { - return false, err + ctx = setContextOp(ctx, dsq.ctx, "Exist") + switch _, err := dsq.FirstID(ctx); { + case IsNotFound(err): + return false, nil + case err != nil: + return false, fmt.Errorf("internal: check existence: %w", err) + default: + return true, nil } - return dsq.sqlExist(ctx) } // ExistX is like Exist, but panics if an error occurs. @@ -236,14 +245,13 @@ func (dsq *DeletedSubscriptionQuery) Clone() *DeletedSubscriptionQuery { } return &DeletedSubscriptionQuery{ config: dsq.config, - limit: dsq.limit, - offset: dsq.offset, + ctx: dsq.ctx.Clone(), order: append([]OrderFunc{}, dsq.order...), + inters: append([]Interceptor{}, dsq.inters...), predicates: append([]predicate.DeletedSubscription{}, dsq.predicates...), // clone intermediate query. - sql: dsq.sql.Clone(), - path: dsq.path, - unique: dsq.unique, + sql: dsq.sql.Clone(), + path: dsq.path, } } @@ -259,19 +267,14 @@ func (dsq *DeletedSubscriptionQuery) Clone() *DeletedSubscriptionQuery { // // client.DeletedSubscription.Query(). // GroupBy(deletedsubscription.FieldName). -// Aggregate(ent.Count()). +// Aggregate(internal.Count()). // Scan(ctx, &v) func (dsq *DeletedSubscriptionQuery) GroupBy(field string, fields ...string) *DeletedSubscriptionGroupBy { - grbuild := &DeletedSubscriptionGroupBy{config: dsq.config} - grbuild.fields = append([]string{field}, fields...) - grbuild.path = func(ctx context.Context) (prev *sql.Selector, err error) { - if err := dsq.prepareQuery(ctx); err != nil { - return nil, err - } - return dsq.sqlQuery(ctx), nil - } + dsq.ctx.Fields = append([]string{field}, fields...) + grbuild := &DeletedSubscriptionGroupBy{build: dsq} + grbuild.flds = &dsq.ctx.Fields grbuild.label = deletedsubscription.Label - grbuild.flds, grbuild.scan = &grbuild.fields, grbuild.Scan + grbuild.scan = grbuild.Scan return grbuild } @@ -288,17 +291,32 @@ func (dsq *DeletedSubscriptionQuery) GroupBy(field string, fields ...string) *De // Select(deletedsubscription.FieldName). // Scan(ctx, &v) func (dsq *DeletedSubscriptionQuery) Select(fields ...string) *DeletedSubscriptionSelect { - dsq.fields = append(dsq.fields, fields...) - selbuild := &DeletedSubscriptionSelect{DeletedSubscriptionQuery: dsq} - selbuild.label = deletedsubscription.Label - selbuild.flds, selbuild.scan = &dsq.fields, selbuild.Scan - return selbuild + dsq.ctx.Fields = append(dsq.ctx.Fields, fields...) + sbuild := &DeletedSubscriptionSelect{DeletedSubscriptionQuery: dsq} + sbuild.label = deletedsubscription.Label + sbuild.flds, sbuild.scan = &dsq.ctx.Fields, sbuild.Scan + return sbuild +} + +// Aggregate returns a DeletedSubscriptionSelect configured with the given aggregations. +func (dsq *DeletedSubscriptionQuery) Aggregate(fns ...AggregateFunc) *DeletedSubscriptionSelect { + return dsq.Select().Aggregate(fns...) } func (dsq *DeletedSubscriptionQuery) prepareQuery(ctx context.Context) error { - for _, f := range dsq.fields { + for _, inter := range dsq.inters { + if inter == nil { + return fmt.Errorf("internal: uninitialized interceptor (forgotten import internal/runtime?)") + } + if trv, ok := inter.(Traverser); ok { + if err := trv.Traverse(ctx, dsq); err != nil { + return err + } + } + } + for _, f := range dsq.ctx.Fields { if !deletedsubscription.ValidColumn(f) { - return &ValidationError{Name: f, err: fmt.Errorf("ent: invalid field %q for query", f)} + return &ValidationError{Name: f, err: fmt.Errorf("internal: invalid field %q for query", f)} } } if dsq.path != nil { @@ -316,10 +334,10 @@ func (dsq *DeletedSubscriptionQuery) sqlAll(ctx context.Context, hooks ...queryH nodes = []*DeletedSubscription{} _spec = dsq.querySpec() ) - _spec.ScanValues = func(columns []string) ([]interface{}, error) { + _spec.ScanValues = func(columns []string) ([]any, error) { return (*DeletedSubscription).scanValues(nil, columns) } - _spec.Assign = func(columns []string, values []interface{}) error { + _spec.Assign = func(columns []string, values []any) error { node := &DeletedSubscription{config: dsq.config} nodes = append(nodes, node) return node.assignValues(columns, values) @@ -338,38 +356,22 @@ func (dsq *DeletedSubscriptionQuery) sqlAll(ctx context.Context, hooks ...queryH func (dsq *DeletedSubscriptionQuery) sqlCount(ctx context.Context) (int, error) { _spec := dsq.querySpec() - _spec.Node.Columns = dsq.fields - if len(dsq.fields) > 0 { - _spec.Unique = dsq.unique != nil && *dsq.unique + _spec.Node.Columns = dsq.ctx.Fields + if len(dsq.ctx.Fields) > 0 { + _spec.Unique = dsq.ctx.Unique != nil && *dsq.ctx.Unique } return sqlgraph.CountNodes(ctx, dsq.driver, _spec) } -func (dsq *DeletedSubscriptionQuery) sqlExist(ctx context.Context) (bool, error) { - n, err := dsq.sqlCount(ctx) - if err != nil { - return false, fmt.Errorf("ent: check existence: %w", err) - } - return n > 0, nil -} - func (dsq *DeletedSubscriptionQuery) querySpec() *sqlgraph.QuerySpec { - _spec := &sqlgraph.QuerySpec{ - Node: &sqlgraph.NodeSpec{ - Table: deletedsubscription.Table, - Columns: deletedsubscription.Columns, - ID: &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Column: deletedsubscription.FieldID, - }, - }, - From: dsq.sql, - Unique: true, - } - if unique := dsq.unique; unique != nil { + _spec := sqlgraph.NewQuerySpec(deletedsubscription.Table, deletedsubscription.Columns, sqlgraph.NewFieldSpec(deletedsubscription.FieldID, field.TypeInt)) + _spec.From = dsq.sql + if unique := dsq.ctx.Unique; unique != nil { _spec.Unique = *unique + } else if dsq.path != nil { + _spec.Unique = true } - if fields := dsq.fields; len(fields) > 0 { + if fields := dsq.ctx.Fields; len(fields) > 0 { _spec.Node.Columns = make([]string, 0, len(fields)) _spec.Node.Columns = append(_spec.Node.Columns, deletedsubscription.FieldID) for i := range fields { @@ -385,10 +387,10 @@ func (dsq *DeletedSubscriptionQuery) querySpec() *sqlgraph.QuerySpec { } } } - if limit := dsq.limit; limit != nil { + if limit := dsq.ctx.Limit; limit != nil { _spec.Limit = *limit } - if offset := dsq.offset; offset != nil { + if offset := dsq.ctx.Offset; offset != nil { _spec.Offset = *offset } if ps := dsq.order; len(ps) > 0 { @@ -404,7 +406,7 @@ func (dsq *DeletedSubscriptionQuery) querySpec() *sqlgraph.QuerySpec { func (dsq *DeletedSubscriptionQuery) sqlQuery(ctx context.Context) *sql.Selector { builder := sql.Dialect(dsq.driver.Dialect()) t1 := builder.Table(deletedsubscription.Table) - columns := dsq.fields + columns := dsq.ctx.Fields if len(columns) == 0 { columns = deletedsubscription.Columns } @@ -413,7 +415,7 @@ func (dsq *DeletedSubscriptionQuery) sqlQuery(ctx context.Context) *sql.Selector selector = dsq.sql selector.Select(selector.Columns(columns...)...) } - if dsq.unique != nil && *dsq.unique { + if dsq.ctx.Unique != nil && *dsq.ctx.Unique { selector.Distinct() } for _, p := range dsq.predicates { @@ -422,12 +424,12 @@ func (dsq *DeletedSubscriptionQuery) sqlQuery(ctx context.Context) *sql.Selector for _, p := range dsq.order { p(selector) } - if offset := dsq.offset; offset != nil { + if offset := dsq.ctx.Offset; offset != nil { // limit is mandatory for offset clause. We start // with default value, and override it below if needed. selector.Offset(*offset).Limit(math.MaxInt32) } - if limit := dsq.limit; limit != nil { + if limit := dsq.ctx.Limit; limit != nil { selector.Limit(*limit) } return selector @@ -435,13 +437,8 @@ func (dsq *DeletedSubscriptionQuery) sqlQuery(ctx context.Context) *sql.Selector // DeletedSubscriptionGroupBy is the group-by builder for DeletedSubscription entities. type DeletedSubscriptionGroupBy struct { - config selector - fields []string - fns []AggregateFunc - // intermediate query (i.e. traversal path). - sql *sql.Selector - path func(context.Context) (*sql.Selector, error) + build *DeletedSubscriptionQuery } // Aggregate adds the given aggregation functions to the group-by query. @@ -450,74 +447,77 @@ func (dsgb *DeletedSubscriptionGroupBy) Aggregate(fns ...AggregateFunc) *Deleted return dsgb } -// Scan applies the group-by query and scans the result into the given value. -func (dsgb *DeletedSubscriptionGroupBy) Scan(ctx context.Context, v interface{}) error { - query, err := dsgb.path(ctx) - if err != nil { +// Scan applies the selector query and scans the result into the given value. +func (dsgb *DeletedSubscriptionGroupBy) Scan(ctx context.Context, v any) error { + ctx = setContextOp(ctx, dsgb.build.ctx, "GroupBy") + if err := dsgb.build.prepareQuery(ctx); err != nil { return err } - dsgb.sql = query - return dsgb.sqlScan(ctx, v) + return scanWithInterceptors[*DeletedSubscriptionQuery, *DeletedSubscriptionGroupBy](ctx, dsgb.build, dsgb, dsgb.build.inters, v) } -func (dsgb *DeletedSubscriptionGroupBy) sqlScan(ctx context.Context, v interface{}) error { - for _, f := range dsgb.fields { - if !deletedsubscription.ValidColumn(f) { - return &ValidationError{Name: f, err: fmt.Errorf("invalid field %q for group-by", f)} +func (dsgb *DeletedSubscriptionGroupBy) sqlScan(ctx context.Context, root *DeletedSubscriptionQuery, v any) error { + selector := root.sqlQuery(ctx).Select() + aggregation := make([]string, 0, len(dsgb.fns)) + for _, fn := range dsgb.fns { + aggregation = append(aggregation, fn(selector)) + } + if len(selector.SelectedColumns()) == 0 { + columns := make([]string, 0, len(*dsgb.flds)+len(dsgb.fns)) + for _, f := range *dsgb.flds { + columns = append(columns, selector.C(f)) } + columns = append(columns, aggregation...) + selector.Select(columns...) } - selector := dsgb.sqlQuery() + selector.GroupBy(selector.Columns(*dsgb.flds...)...) if err := selector.Err(); err != nil { return err } rows := &sql.Rows{} query, args := selector.Query() - if err := dsgb.driver.Query(ctx, query, args, rows); err != nil { + if err := dsgb.build.driver.Query(ctx, query, args, rows); err != nil { return err } defer rows.Close() return sql.ScanSlice(rows, v) } -func (dsgb *DeletedSubscriptionGroupBy) sqlQuery() *sql.Selector { - selector := dsgb.sql.Select() - aggregation := make([]string, 0, len(dsgb.fns)) - for _, fn := range dsgb.fns { - aggregation = append(aggregation, fn(selector)) - } - // If no columns were selected in a custom aggregation function, the default - // selection is the fields used for "group-by", and the aggregation functions. - if len(selector.SelectedColumns()) == 0 { - columns := make([]string, 0, len(dsgb.fields)+len(dsgb.fns)) - for _, f := range dsgb.fields { - columns = append(columns, selector.C(f)) - } - columns = append(columns, aggregation...) - selector.Select(columns...) - } - return selector.GroupBy(selector.Columns(dsgb.fields...)...) -} - // DeletedSubscriptionSelect is the builder for selecting fields of DeletedSubscription entities. type DeletedSubscriptionSelect struct { *DeletedSubscriptionQuery selector - // intermediate query (i.e. traversal path). - sql *sql.Selector +} + +// Aggregate adds the given aggregation functions to the selector query. +func (dss *DeletedSubscriptionSelect) Aggregate(fns ...AggregateFunc) *DeletedSubscriptionSelect { + dss.fns = append(dss.fns, fns...) + return dss } // Scan applies the selector query and scans the result into the given value. -func (dss *DeletedSubscriptionSelect) Scan(ctx context.Context, v interface{}) error { +func (dss *DeletedSubscriptionSelect) Scan(ctx context.Context, v any) error { + ctx = setContextOp(ctx, dss.ctx, "Select") if err := dss.prepareQuery(ctx); err != nil { return err } - dss.sql = dss.DeletedSubscriptionQuery.sqlQuery(ctx) - return dss.sqlScan(ctx, v) + return scanWithInterceptors[*DeletedSubscriptionQuery, *DeletedSubscriptionSelect](ctx, dss.DeletedSubscriptionQuery, dss, dss.inters, v) } -func (dss *DeletedSubscriptionSelect) sqlScan(ctx context.Context, v interface{}) error { +func (dss *DeletedSubscriptionSelect) sqlScan(ctx context.Context, root *DeletedSubscriptionQuery, v any) error { + selector := root.sqlQuery(ctx) + aggregation := make([]string, 0, len(dss.fns)) + for _, fn := range dss.fns { + aggregation = append(aggregation, fn(selector)) + } + switch n := len(*dss.selector.flds); { + case n == 0 && len(aggregation) > 0: + selector.Select(aggregation...) + case n != 0 && len(aggregation) > 0: + selector.AppendSelect(aggregation...) + } rows := &sql.Rows{} - query, args := dss.sql.Query() + query, args := selector.Query() if err := dss.driver.Query(ctx, query, args, rows); err != nil { return err } diff --git a/internal/db/ent/deletedsubscription_update.go b/internal/db_impl/ent_db/internal/deletedsubscription_update.go similarity index 64% rename from internal/db/ent/deletedsubscription_update.go rename to internal/db_impl/ent_db/internal/deletedsubscription_update.go index f28054c5..4ec23db6 100644 --- a/internal/db/ent/deletedsubscription_update.go +++ b/internal/db_impl/ent_db/internal/deletedsubscription_update.go @@ -1,6 +1,6 @@ // Code generated by ent, DO NOT EDIT. -package ent +package internal import ( "context" @@ -11,8 +11,8 @@ import ( "entgo.io/ent/dialect/sql/sqlgraph" "entgo.io/ent/schema/field" "github.com/ProtonMail/gluon/imap" - "github.com/ProtonMail/gluon/internal/db/ent/deletedsubscription" - "github.com/ProtonMail/gluon/internal/db/ent/predicate" + "github.com/ProtonMail/gluon/internal/db_impl/ent_db/internal/deletedsubscription" + "github.com/ProtonMail/gluon/internal/db_impl/ent_db/internal/predicate" ) // DeletedSubscriptionUpdate is the builder for updating DeletedSubscription entities. @@ -47,34 +47,7 @@ func (dsu *DeletedSubscriptionUpdate) Mutation() *DeletedSubscriptionMutation { // Save executes the query and returns the number of nodes affected by the update operation. func (dsu *DeletedSubscriptionUpdate) Save(ctx context.Context) (int, error) { - var ( - err error - affected int - ) - if len(dsu.hooks) == 0 { - affected, err = dsu.sqlSave(ctx) - } else { - var mut Mutator = MutateFunc(func(ctx context.Context, m Mutation) (Value, error) { - mutation, ok := m.(*DeletedSubscriptionMutation) - if !ok { - return nil, fmt.Errorf("unexpected mutation type %T", m) - } - dsu.mutation = mutation - affected, err = dsu.sqlSave(ctx) - mutation.done = true - return affected, err - }) - for i := len(dsu.hooks) - 1; i >= 0; i-- { - if dsu.hooks[i] == nil { - return 0, fmt.Errorf("ent: uninitialized hook (forgotten import ent/runtime?)") - } - mut = dsu.hooks[i](mut) - } - if _, err := mut.Mutate(ctx, dsu.mutation); err != nil { - return 0, err - } - } - return affected, err + return withHooks[int, DeletedSubscriptionMutation](ctx, dsu.sqlSave, dsu.mutation, dsu.hooks) } // SaveX is like Save, but panics if an error occurs. @@ -100,16 +73,7 @@ func (dsu *DeletedSubscriptionUpdate) ExecX(ctx context.Context) { } func (dsu *DeletedSubscriptionUpdate) sqlSave(ctx context.Context) (n int, err error) { - _spec := &sqlgraph.UpdateSpec{ - Node: &sqlgraph.NodeSpec{ - Table: deletedsubscription.Table, - Columns: deletedsubscription.Columns, - ID: &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Column: deletedsubscription.FieldID, - }, - }, - } + _spec := sqlgraph.NewUpdateSpec(deletedsubscription.Table, deletedsubscription.Columns, sqlgraph.NewFieldSpec(deletedsubscription.FieldID, field.TypeInt)) if ps := dsu.mutation.predicates; len(ps) > 0 { _spec.Predicate = func(selector *sql.Selector) { for i := range ps { @@ -118,18 +82,10 @@ func (dsu *DeletedSubscriptionUpdate) sqlSave(ctx context.Context) (n int, err e } } if value, ok := dsu.mutation.Name(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Value: value, - Column: deletedsubscription.FieldName, - }) + _spec.SetField(deletedsubscription.FieldName, field.TypeString, value) } if value, ok := dsu.mutation.RemoteID(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Value: value, - Column: deletedsubscription.FieldRemoteID, - }) + _spec.SetField(deletedsubscription.FieldRemoteID, field.TypeString, value) } if n, err = sqlgraph.UpdateNodes(ctx, dsu.driver, _spec); err != nil { if _, ok := err.(*sqlgraph.NotFoundError); ok { @@ -139,6 +95,7 @@ func (dsu *DeletedSubscriptionUpdate) sqlSave(ctx context.Context) (n int, err e } return 0, err } + dsu.mutation.done = true return n, nil } @@ -167,6 +124,12 @@ func (dsuo *DeletedSubscriptionUpdateOne) Mutation() *DeletedSubscriptionMutatio return dsuo.mutation } +// Where appends a list predicates to the DeletedSubscriptionUpdate builder. +func (dsuo *DeletedSubscriptionUpdateOne) Where(ps ...predicate.DeletedSubscription) *DeletedSubscriptionUpdateOne { + dsuo.mutation.Where(ps...) + return dsuo +} + // Select allows selecting one or more fields (columns) of the returned entity. // The default is selecting all fields defined in the entity schema. func (dsuo *DeletedSubscriptionUpdateOne) Select(field string, fields ...string) *DeletedSubscriptionUpdateOne { @@ -176,40 +139,7 @@ func (dsuo *DeletedSubscriptionUpdateOne) Select(field string, fields ...string) // Save executes the query and returns the updated DeletedSubscription entity. func (dsuo *DeletedSubscriptionUpdateOne) Save(ctx context.Context) (*DeletedSubscription, error) { - var ( - err error - node *DeletedSubscription - ) - if len(dsuo.hooks) == 0 { - node, err = dsuo.sqlSave(ctx) - } else { - var mut Mutator = MutateFunc(func(ctx context.Context, m Mutation) (Value, error) { - mutation, ok := m.(*DeletedSubscriptionMutation) - if !ok { - return nil, fmt.Errorf("unexpected mutation type %T", m) - } - dsuo.mutation = mutation - node, err = dsuo.sqlSave(ctx) - mutation.done = true - return node, err - }) - for i := len(dsuo.hooks) - 1; i >= 0; i-- { - if dsuo.hooks[i] == nil { - return nil, fmt.Errorf("ent: uninitialized hook (forgotten import ent/runtime?)") - } - mut = dsuo.hooks[i](mut) - } - v, err := mut.Mutate(ctx, dsuo.mutation) - if err != nil { - return nil, err - } - nv, ok := v.(*DeletedSubscription) - if !ok { - return nil, fmt.Errorf("unexpected node type %T returned from DeletedSubscriptionMutation", v) - } - node = nv - } - return node, err + return withHooks[*DeletedSubscription, DeletedSubscriptionMutation](ctx, dsuo.sqlSave, dsuo.mutation, dsuo.hooks) } // SaveX is like Save, but panics if an error occurs. @@ -235,19 +165,10 @@ func (dsuo *DeletedSubscriptionUpdateOne) ExecX(ctx context.Context) { } func (dsuo *DeletedSubscriptionUpdateOne) sqlSave(ctx context.Context) (_node *DeletedSubscription, err error) { - _spec := &sqlgraph.UpdateSpec{ - Node: &sqlgraph.NodeSpec{ - Table: deletedsubscription.Table, - Columns: deletedsubscription.Columns, - ID: &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Column: deletedsubscription.FieldID, - }, - }, - } + _spec := sqlgraph.NewUpdateSpec(deletedsubscription.Table, deletedsubscription.Columns, sqlgraph.NewFieldSpec(deletedsubscription.FieldID, field.TypeInt)) id, ok := dsuo.mutation.ID() if !ok { - return nil, &ValidationError{Name: "id", err: errors.New(`ent: missing "DeletedSubscription.id" for update`)} + return nil, &ValidationError{Name: "id", err: errors.New(`internal: missing "DeletedSubscription.id" for update`)} } _spec.Node.ID.Value = id if fields := dsuo.fields; len(fields) > 0 { @@ -255,7 +176,7 @@ func (dsuo *DeletedSubscriptionUpdateOne) sqlSave(ctx context.Context) (_node *D _spec.Node.Columns = append(_spec.Node.Columns, deletedsubscription.FieldID) for _, f := range fields { if !deletedsubscription.ValidColumn(f) { - return nil, &ValidationError{Name: f, err: fmt.Errorf("ent: invalid field %q for query", f)} + return nil, &ValidationError{Name: f, err: fmt.Errorf("internal: invalid field %q for query", f)} } if f != deletedsubscription.FieldID { _spec.Node.Columns = append(_spec.Node.Columns, f) @@ -270,18 +191,10 @@ func (dsuo *DeletedSubscriptionUpdateOne) sqlSave(ctx context.Context) (_node *D } } if value, ok := dsuo.mutation.Name(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Value: value, - Column: deletedsubscription.FieldName, - }) + _spec.SetField(deletedsubscription.FieldName, field.TypeString, value) } if value, ok := dsuo.mutation.RemoteID(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Value: value, - Column: deletedsubscription.FieldRemoteID, - }) + _spec.SetField(deletedsubscription.FieldRemoteID, field.TypeString, value) } _node = &DeletedSubscription{config: dsuo.config} _spec.Assign = _node.assignValues @@ -294,5 +207,6 @@ func (dsuo *DeletedSubscriptionUpdateOne) sqlSave(ctx context.Context) (_node *D } return nil, err } + dsuo.mutation.done = true return _node, nil } diff --git a/internal/db/ent/ent.go b/internal/db_impl/ent_db/internal/ent.go similarity index 63% rename from internal/db/ent/ent.go rename to internal/db_impl/ent_db/internal/ent.go index b5c03992..c63f197f 100644 --- a/internal/db/ent/ent.go +++ b/internal/db_impl/ent_db/internal/ent.go @@ -1,35 +1,43 @@ // Code generated by ent, DO NOT EDIT. -package ent +package internal import ( "context" "errors" "fmt" + "reflect" "entgo.io/ent" "entgo.io/ent/dialect/sql" "entgo.io/ent/dialect/sql/sqlgraph" - "github.com/ProtonMail/gluon/internal/db/ent/deletedsubscription" - "github.com/ProtonMail/gluon/internal/db/ent/mailbox" - "github.com/ProtonMail/gluon/internal/db/ent/mailboxattr" - "github.com/ProtonMail/gluon/internal/db/ent/mailboxflag" - "github.com/ProtonMail/gluon/internal/db/ent/mailboxpermflag" - "github.com/ProtonMail/gluon/internal/db/ent/message" - "github.com/ProtonMail/gluon/internal/db/ent/messageflag" - "github.com/ProtonMail/gluon/internal/db/ent/uid" + "github.com/ProtonMail/gluon/internal/db_impl/ent_db/internal/deletedsubscription" + "github.com/ProtonMail/gluon/internal/db_impl/ent_db/internal/mailbox" + "github.com/ProtonMail/gluon/internal/db_impl/ent_db/internal/mailboxattr" + "github.com/ProtonMail/gluon/internal/db_impl/ent_db/internal/mailboxflag" + "github.com/ProtonMail/gluon/internal/db_impl/ent_db/internal/mailboxpermflag" + "github.com/ProtonMail/gluon/internal/db_impl/ent_db/internal/message" + "github.com/ProtonMail/gluon/internal/db_impl/ent_db/internal/messageflag" + "github.com/ProtonMail/gluon/internal/db_impl/ent_db/internal/uid" ) // ent aliases to avoid import conflicts in user's code. type ( - Op = ent.Op - Hook = ent.Hook - Value = ent.Value - Query = ent.Query - Policy = ent.Policy - Mutator = ent.Mutator - Mutation = ent.Mutation - MutateFunc = ent.MutateFunc + Op = ent.Op + Hook = ent.Hook + Value = ent.Value + Query = ent.Query + QueryContext = ent.QueryContext + Querier = ent.Querier + QuerierFunc = ent.QuerierFunc + Interceptor = ent.Interceptor + InterceptFunc = ent.InterceptFunc + Traverser = ent.Traverser + TraverseFunc = ent.TraverseFunc + Policy = ent.Policy + Mutator = ent.Mutator + Mutation = ent.Mutation + MutateFunc = ent.MutateFunc ) // OrderFunc applies an ordering on the sql selector. @@ -67,7 +75,7 @@ func Asc(fields ...string) OrderFunc { check := columnChecker(s.TableName()) for _, f := range fields { if err := check(f); err != nil { - s.AddError(&ValidationError{Name: f, err: fmt.Errorf("ent: %w", err)}) + s.AddError(&ValidationError{Name: f, err: fmt.Errorf("internal: %w", err)}) } s.OrderBy(sql.Asc(s.C(f))) } @@ -80,7 +88,7 @@ func Desc(fields ...string) OrderFunc { check := columnChecker(s.TableName()) for _, f := range fields { if err := check(f); err != nil { - s.AddError(&ValidationError{Name: f, err: fmt.Errorf("ent: %w", err)}) + s.AddError(&ValidationError{Name: f, err: fmt.Errorf("internal: %w", err)}) } s.OrderBy(sql.Desc(s.C(f))) } @@ -93,7 +101,7 @@ type AggregateFunc func(*sql.Selector) string // As is a pseudo aggregation function for renaming another other functions with custom names. For example: // // GroupBy(field1, field2). -// Aggregate(ent.As(ent.Sum(field1), "sum_field1"), (ent.As(ent.Sum(field2), "sum_field2")). +// Aggregate(internal.As(internal.Sum(field1), "sum_field1"), (internal.As(internal.Sum(field2), "sum_field2")). // Scan(ctx, &v) func As(fn AggregateFunc, end string) AggregateFunc { return func(s *sql.Selector) string { @@ -113,7 +121,7 @@ func Max(field string) AggregateFunc { return func(s *sql.Selector) string { check := columnChecker(s.TableName()) if err := check(field); err != nil { - s.AddError(&ValidationError{Name: field, err: fmt.Errorf("ent: %w", err)}) + s.AddError(&ValidationError{Name: field, err: fmt.Errorf("internal: %w", err)}) return "" } return sql.Max(s.C(field)) @@ -125,7 +133,7 @@ func Mean(field string) AggregateFunc { return func(s *sql.Selector) string { check := columnChecker(s.TableName()) if err := check(field); err != nil { - s.AddError(&ValidationError{Name: field, err: fmt.Errorf("ent: %w", err)}) + s.AddError(&ValidationError{Name: field, err: fmt.Errorf("internal: %w", err)}) return "" } return sql.Avg(s.C(field)) @@ -137,7 +145,7 @@ func Min(field string) AggregateFunc { return func(s *sql.Selector) string { check := columnChecker(s.TableName()) if err := check(field); err != nil { - s.AddError(&ValidationError{Name: field, err: fmt.Errorf("ent: %w", err)}) + s.AddError(&ValidationError{Name: field, err: fmt.Errorf("internal: %w", err)}) return "" } return sql.Min(s.C(field)) @@ -149,7 +157,7 @@ func Sum(field string) AggregateFunc { return func(s *sql.Selector) string { check := columnChecker(s.TableName()) if err := check(field); err != nil { - s.AddError(&ValidationError{Name: field, err: fmt.Errorf("ent: %w", err)}) + s.AddError(&ValidationError{Name: field, err: fmt.Errorf("internal: %w", err)}) return "" } return sql.Sum(s.C(field)) @@ -188,7 +196,7 @@ type NotFoundError struct { // Error implements the error interface. func (e *NotFoundError) Error() string { - return "ent: " + e.label + " not found" + return "internal: " + e.label + " not found" } // IsNotFound returns a boolean indicating whether the error is a not found error. @@ -215,7 +223,7 @@ type NotSingularError struct { // Error implements the error interface. func (e *NotSingularError) Error() string { - return "ent: " + e.label + " not singular" + return "internal: " + e.label + " not singular" } // IsNotSingular returns a boolean indicating whether the error is a not singular error. @@ -234,7 +242,7 @@ type NotLoadedError struct { // Error implements the error interface. func (e *NotLoadedError) Error() string { - return "ent: " + e.edge + " edge was not loaded" + return "internal: " + e.edge + " edge was not loaded" } // IsNotLoaded returns a boolean indicating whether the error is a not loaded error. @@ -256,7 +264,7 @@ type ConstraintError struct { // Error implements the error interface. func (e ConstraintError) Error() string { - return "ent: constraint failed: " + e.msg + return "internal: constraint failed: " + e.msg } // Unwrap implements the errors.Wrapper interface. @@ -277,11 +285,12 @@ func IsConstraintError(err error) bool { type selector struct { label string flds *[]string - scan func(context.Context, interface{}) error + fns []AggregateFunc + scan func(context.Context, any) error } // ScanX is like Scan, but panics if an error occurs. -func (s *selector) ScanX(ctx context.Context, v interface{}) { +func (s *selector) ScanX(ctx context.Context, v any) { if err := s.scan(ctx, v); err != nil { panic(err) } @@ -290,7 +299,7 @@ func (s *selector) ScanX(ctx context.Context, v interface{}) { // Strings returns list of strings from a selector. It is only allowed when selecting one field. func (s *selector) Strings(ctx context.Context) ([]string, error) { if len(*s.flds) > 1 { - return nil, errors.New("ent: Strings is not achievable when selecting more than 1 field") + return nil, errors.New("internal: Strings is not achievable when selecting more than 1 field") } var v []string if err := s.scan(ctx, &v); err != nil { @@ -320,7 +329,7 @@ func (s *selector) String(ctx context.Context) (_ string, err error) { case 0: err = &NotFoundError{s.label} default: - err = fmt.Errorf("ent: Strings returned %d results when one was expected", len(v)) + err = fmt.Errorf("internal: Strings returned %d results when one was expected", len(v)) } return } @@ -337,7 +346,7 @@ func (s *selector) StringX(ctx context.Context) string { // Ints returns list of ints from a selector. It is only allowed when selecting one field. func (s *selector) Ints(ctx context.Context) ([]int, error) { if len(*s.flds) > 1 { - return nil, errors.New("ent: Ints is not achievable when selecting more than 1 field") + return nil, errors.New("internal: Ints is not achievable when selecting more than 1 field") } var v []int if err := s.scan(ctx, &v); err != nil { @@ -367,7 +376,7 @@ func (s *selector) Int(ctx context.Context) (_ int, err error) { case 0: err = &NotFoundError{s.label} default: - err = fmt.Errorf("ent: Ints returned %d results when one was expected", len(v)) + err = fmt.Errorf("internal: Ints returned %d results when one was expected", len(v)) } return } @@ -384,7 +393,7 @@ func (s *selector) IntX(ctx context.Context) int { // Float64s returns list of float64s from a selector. It is only allowed when selecting one field. func (s *selector) Float64s(ctx context.Context) ([]float64, error) { if len(*s.flds) > 1 { - return nil, errors.New("ent: Float64s is not achievable when selecting more than 1 field") + return nil, errors.New("internal: Float64s is not achievable when selecting more than 1 field") } var v []float64 if err := s.scan(ctx, &v); err != nil { @@ -414,7 +423,7 @@ func (s *selector) Float64(ctx context.Context) (_ float64, err error) { case 0: err = &NotFoundError{s.label} default: - err = fmt.Errorf("ent: Float64s returned %d results when one was expected", len(v)) + err = fmt.Errorf("internal: Float64s returned %d results when one was expected", len(v)) } return } @@ -431,7 +440,7 @@ func (s *selector) Float64X(ctx context.Context) float64 { // Bools returns list of bools from a selector. It is only allowed when selecting one field. func (s *selector) Bools(ctx context.Context) ([]bool, error) { if len(*s.flds) > 1 { - return nil, errors.New("ent: Bools is not achievable when selecting more than 1 field") + return nil, errors.New("internal: Bools is not achievable when selecting more than 1 field") } var v []bool if err := s.scan(ctx, &v); err != nil { @@ -461,7 +470,7 @@ func (s *selector) Bool(ctx context.Context) (_ bool, err error) { case 0: err = &NotFoundError{s.label} default: - err = fmt.Errorf("ent: Bools returned %d results when one was expected", len(v)) + err = fmt.Errorf("internal: Bools returned %d results when one was expected", len(v)) } return } @@ -475,5 +484,121 @@ func (s *selector) BoolX(ctx context.Context) bool { return v } +// withHooks invokes the builder operation with the given hooks, if any. +func withHooks[V Value, M any, PM interface { + *M + Mutation +}](ctx context.Context, exec func(context.Context) (V, error), mutation PM, hooks []Hook) (value V, err error) { + if len(hooks) == 0 { + return exec(ctx) + } + var mut Mutator = MutateFunc(func(ctx context.Context, m Mutation) (Value, error) { + mutationT, ok := m.(PM) + if !ok { + return nil, fmt.Errorf("unexpected mutation type %T", m) + } + // Set the mutation to the builder. + *mutation = *mutationT + return exec(ctx) + }) + for i := len(hooks) - 1; i >= 0; i-- { + if hooks[i] == nil { + return value, fmt.Errorf("ent: uninitialized hook (forgotten import ent/runtime?)") + } + mut = hooks[i](mut) + } + v, err := mut.Mutate(ctx, mutation) + if err != nil { + return value, err + } + nv, ok := v.(V) + if !ok { + return value, fmt.Errorf("unexpected node type %T returned from %T", v, mutation) + } + return nv, nil +} + +// setContextOp returns a new context with the given QueryContext attached (including its op) in case it does not exist. +func setContextOp(ctx context.Context, qc *QueryContext, op string) context.Context { + if ent.QueryFromContext(ctx) == nil { + qc.Op = op + ctx = ent.NewQueryContext(ctx, qc) + } + return ctx +} + +func querierAll[V Value, Q interface { + sqlAll(context.Context, ...queryHook) (V, error) +}]() Querier { + return QuerierFunc(func(ctx context.Context, q Query) (Value, error) { + query, ok := q.(Q) + if !ok { + return nil, fmt.Errorf("unexpected query type %T", q) + } + return query.sqlAll(ctx) + }) +} + +func querierCount[Q interface { + sqlCount(context.Context) (int, error) +}]() Querier { + return QuerierFunc(func(ctx context.Context, q Query) (Value, error) { + query, ok := q.(Q) + if !ok { + return nil, fmt.Errorf("unexpected query type %T", q) + } + return query.sqlCount(ctx) + }) +} + +func withInterceptors[V Value](ctx context.Context, q Query, qr Querier, inters []Interceptor) (v V, err error) { + for i := len(inters) - 1; i >= 0; i-- { + qr = inters[i].Intercept(qr) + } + rv, err := qr.Query(ctx, q) + if err != nil { + return v, err + } + vt, ok := rv.(V) + if !ok { + return v, fmt.Errorf("unexpected type %T returned from %T. expected type: %T", vt, q, v) + } + return vt, nil +} + +func scanWithInterceptors[Q1 ent.Query, Q2 interface { + sqlScan(context.Context, Q1, any) error +}](ctx context.Context, rootQuery Q1, selectOrGroup Q2, inters []Interceptor, v any) error { + rv := reflect.ValueOf(v) + var qr Querier = QuerierFunc(func(ctx context.Context, q Query) (Value, error) { + query, ok := q.(Q1) + if !ok { + return nil, fmt.Errorf("unexpected query type %T", q) + } + if err := selectOrGroup.sqlScan(ctx, query, v); err != nil { + return nil, err + } + if k := rv.Kind(); k == reflect.Pointer && rv.Elem().CanInterface() { + return rv.Elem().Interface(), nil + } + return v, nil + }) + for i := len(inters) - 1; i >= 0; i-- { + qr = inters[i].Intercept(qr) + } + vv, err := qr.Query(ctx, rootQuery) + if err != nil { + return err + } + switch rv2 := reflect.ValueOf(vv); { + case rv.IsNil(), rv2.IsNil(), rv.Kind() != reflect.Pointer: + case rv.Type() == rv2.Type(): + rv.Elem().Set(rv2.Elem()) + case rv.Elem().Type() == rv2.Type(): + rv.Elem().Set(rv2) + } + return nil +} + // queryHook describes an internal hook for the different sqlAll methods. type queryHook func(context.Context, *sqlgraph.QuerySpec) diff --git a/internal/db/ent/enttest/enttest.go b/internal/db_impl/ent_db/internal/enttest/enttest.go similarity index 65% rename from internal/db/ent/enttest/enttest.go rename to internal/db_impl/ent_db/internal/enttest/enttest.go index 3dcd985b..0b65a31b 100644 --- a/internal/db/ent/enttest/enttest.go +++ b/internal/db_impl/ent_db/internal/enttest/enttest.go @@ -5,12 +5,12 @@ package enttest import ( "context" - "github.com/ProtonMail/gluon/internal/db/ent" + "github.com/ProtonMail/gluon/internal/db_impl/ent_db/internal" // required by schema hooks. - _ "github.com/ProtonMail/gluon/internal/db/ent/runtime" + _ "github.com/ProtonMail/gluon/internal/db_impl/ent_db/internal/runtime" "entgo.io/ent/dialect/sql/schema" - "github.com/ProtonMail/gluon/internal/db/ent/migrate" + "github.com/ProtonMail/gluon/internal/db_impl/ent_db/internal/migrate" ) type ( @@ -18,20 +18,20 @@ type ( // testing.T and testing.B and used by enttest. TestingT interface { FailNow() - Error(...interface{}) + Error(...any) } // Option configures client creation. Option func(*options) options struct { - opts []ent.Option + opts []internal.Option migrateOpts []schema.MigrateOption } ) // WithOptions forwards options to client creation. -func WithOptions(opts ...ent.Option) Option { +func WithOptions(opts ...internal.Option) Option { return func(o *options) { o.opts = append(o.opts, opts...) } @@ -52,10 +52,10 @@ func newOptions(opts []Option) *options { return o } -// Open calls ent.Open and auto-run migration. -func Open(t TestingT, driverName, dataSourceName string, opts ...Option) *ent.Client { +// Open calls internal.Open and auto-run migration. +func Open(t TestingT, driverName, dataSourceName string, opts ...Option) *internal.Client { o := newOptions(opts) - c, err := ent.Open(driverName, dataSourceName, o.opts...) + c, err := internal.Open(driverName, dataSourceName, o.opts...) if err != nil { t.Error(err) t.FailNow() @@ -64,14 +64,14 @@ func Open(t TestingT, driverName, dataSourceName string, opts ...Option) *ent.Cl return c } -// NewClient calls ent.NewClient and auto-run migration. -func NewClient(t TestingT, opts ...Option) *ent.Client { +// NewClient calls internal.NewClient and auto-run migration. +func NewClient(t TestingT, opts ...Option) *internal.Client { o := newOptions(opts) - c := ent.NewClient(o.opts...) + c := internal.NewClient(o.opts...) migrateSchema(t, c, o) return c } -func migrateSchema(t TestingT, c *ent.Client, o *options) { +func migrateSchema(t TestingT, c *internal.Client, o *options) { tables, err := schema.CopyTables(migrate.Tables) if err != nil { t.Error(err) diff --git a/internal/db/ent/generate.go b/internal/db_impl/ent_db/internal/generate.go similarity index 80% rename from internal/db/ent/generate.go rename to internal/db_impl/ent_db/internal/generate.go index 8d3fdfdc..c46b85b3 100644 --- a/internal/db/ent/generate.go +++ b/internal/db_impl/ent_db/internal/generate.go @@ -1,3 +1,3 @@ -package ent +package internal //go:generate go run -mod=mod entgo.io/ent/cmd/ent generate ./schema diff --git a/internal/db_impl/ent_db/internal/hook/hook.go b/internal/db_impl/ent_db/internal/hook/hook.go new file mode 100644 index 00000000..344683cb --- /dev/null +++ b/internal/db_impl/ent_db/internal/hook/hook.go @@ -0,0 +1,283 @@ +// Code generated by ent, DO NOT EDIT. + +package hook + +import ( + "context" + "fmt" + + "github.com/ProtonMail/gluon/internal/db_impl/ent_db/internal" +) + +// The DeletedSubscriptionFunc type is an adapter to allow the use of ordinary +// function as DeletedSubscription mutator. +type DeletedSubscriptionFunc func(context.Context, *internal.DeletedSubscriptionMutation) (internal.Value, error) + +// Mutate calls f(ctx, m). +func (f DeletedSubscriptionFunc) Mutate(ctx context.Context, m internal.Mutation) (internal.Value, error) { + if mv, ok := m.(*internal.DeletedSubscriptionMutation); ok { + return f(ctx, mv) + } + return nil, fmt.Errorf("unexpected mutation type %T. expect *internal.DeletedSubscriptionMutation", m) +} + +// The MailboxFunc type is an adapter to allow the use of ordinary +// function as Mailbox mutator. +type MailboxFunc func(context.Context, *internal.MailboxMutation) (internal.Value, error) + +// Mutate calls f(ctx, m). +func (f MailboxFunc) Mutate(ctx context.Context, m internal.Mutation) (internal.Value, error) { + if mv, ok := m.(*internal.MailboxMutation); ok { + return f(ctx, mv) + } + return nil, fmt.Errorf("unexpected mutation type %T. expect *internal.MailboxMutation", m) +} + +// The MailboxAttrFunc type is an adapter to allow the use of ordinary +// function as MailboxAttr mutator. +type MailboxAttrFunc func(context.Context, *internal.MailboxAttrMutation) (internal.Value, error) + +// Mutate calls f(ctx, m). +func (f MailboxAttrFunc) Mutate(ctx context.Context, m internal.Mutation) (internal.Value, error) { + if mv, ok := m.(*internal.MailboxAttrMutation); ok { + return f(ctx, mv) + } + return nil, fmt.Errorf("unexpected mutation type %T. expect *internal.MailboxAttrMutation", m) +} + +// The MailboxFlagFunc type is an adapter to allow the use of ordinary +// function as MailboxFlag mutator. +type MailboxFlagFunc func(context.Context, *internal.MailboxFlagMutation) (internal.Value, error) + +// Mutate calls f(ctx, m). +func (f MailboxFlagFunc) Mutate(ctx context.Context, m internal.Mutation) (internal.Value, error) { + if mv, ok := m.(*internal.MailboxFlagMutation); ok { + return f(ctx, mv) + } + return nil, fmt.Errorf("unexpected mutation type %T. expect *internal.MailboxFlagMutation", m) +} + +// The MailboxPermFlagFunc type is an adapter to allow the use of ordinary +// function as MailboxPermFlag mutator. +type MailboxPermFlagFunc func(context.Context, *internal.MailboxPermFlagMutation) (internal.Value, error) + +// Mutate calls f(ctx, m). +func (f MailboxPermFlagFunc) Mutate(ctx context.Context, m internal.Mutation) (internal.Value, error) { + if mv, ok := m.(*internal.MailboxPermFlagMutation); ok { + return f(ctx, mv) + } + return nil, fmt.Errorf("unexpected mutation type %T. expect *internal.MailboxPermFlagMutation", m) +} + +// The MessageFunc type is an adapter to allow the use of ordinary +// function as Message mutator. +type MessageFunc func(context.Context, *internal.MessageMutation) (internal.Value, error) + +// Mutate calls f(ctx, m). +func (f MessageFunc) Mutate(ctx context.Context, m internal.Mutation) (internal.Value, error) { + if mv, ok := m.(*internal.MessageMutation); ok { + return f(ctx, mv) + } + return nil, fmt.Errorf("unexpected mutation type %T. expect *internal.MessageMutation", m) +} + +// The MessageFlagFunc type is an adapter to allow the use of ordinary +// function as MessageFlag mutator. +type MessageFlagFunc func(context.Context, *internal.MessageFlagMutation) (internal.Value, error) + +// Mutate calls f(ctx, m). +func (f MessageFlagFunc) Mutate(ctx context.Context, m internal.Mutation) (internal.Value, error) { + if mv, ok := m.(*internal.MessageFlagMutation); ok { + return f(ctx, mv) + } + return nil, fmt.Errorf("unexpected mutation type %T. expect *internal.MessageFlagMutation", m) +} + +// The UIDFunc type is an adapter to allow the use of ordinary +// function as UID mutator. +type UIDFunc func(context.Context, *internal.UIDMutation) (internal.Value, error) + +// Mutate calls f(ctx, m). +func (f UIDFunc) Mutate(ctx context.Context, m internal.Mutation) (internal.Value, error) { + if mv, ok := m.(*internal.UIDMutation); ok { + return f(ctx, mv) + } + return nil, fmt.Errorf("unexpected mutation type %T. expect *internal.UIDMutation", m) +} + +// Condition is a hook condition function. +type Condition func(context.Context, internal.Mutation) bool + +// And groups conditions with the AND operator. +func And(first, second Condition, rest ...Condition) Condition { + return func(ctx context.Context, m internal.Mutation) bool { + if !first(ctx, m) || !second(ctx, m) { + return false + } + for _, cond := range rest { + if !cond(ctx, m) { + return false + } + } + return true + } +} + +// Or groups conditions with the OR operator. +func Or(first, second Condition, rest ...Condition) Condition { + return func(ctx context.Context, m internal.Mutation) bool { + if first(ctx, m) || second(ctx, m) { + return true + } + for _, cond := range rest { + if cond(ctx, m) { + return true + } + } + return false + } +} + +// Not negates a given condition. +func Not(cond Condition) Condition { + return func(ctx context.Context, m internal.Mutation) bool { + return !cond(ctx, m) + } +} + +// HasOp is a condition testing mutation operation. +func HasOp(op internal.Op) Condition { + return func(_ context.Context, m internal.Mutation) bool { + return m.Op().Is(op) + } +} + +// HasAddedFields is a condition validating `.AddedField` on fields. +func HasAddedFields(field string, fields ...string) Condition { + return func(_ context.Context, m internal.Mutation) bool { + if _, exists := m.AddedField(field); !exists { + return false + } + for _, field := range fields { + if _, exists := m.AddedField(field); !exists { + return false + } + } + return true + } +} + +// HasClearedFields is a condition validating `.FieldCleared` on fields. +func HasClearedFields(field string, fields ...string) Condition { + return func(_ context.Context, m internal.Mutation) bool { + if exists := m.FieldCleared(field); !exists { + return false + } + for _, field := range fields { + if exists := m.FieldCleared(field); !exists { + return false + } + } + return true + } +} + +// HasFields is a condition validating `.Field` on fields. +func HasFields(field string, fields ...string) Condition { + return func(_ context.Context, m internal.Mutation) bool { + if _, exists := m.Field(field); !exists { + return false + } + for _, field := range fields { + if _, exists := m.Field(field); !exists { + return false + } + } + return true + } +} + +// If executes the given hook under condition. +// +// hook.If(ComputeAverage, And(HasFields(...), HasAddedFields(...))) +func If(hk internal.Hook, cond Condition) internal.Hook { + return func(next internal.Mutator) internal.Mutator { + return internal.MutateFunc(func(ctx context.Context, m internal.Mutation) (internal.Value, error) { + if cond(ctx, m) { + return hk(next).Mutate(ctx, m) + } + return next.Mutate(ctx, m) + }) + } +} + +// On executes the given hook only for the given operation. +// +// hook.On(Log, internal.Delete|internal.Create) +func On(hk internal.Hook, op internal.Op) internal.Hook { + return If(hk, HasOp(op)) +} + +// Unless skips the given hook only for the given operation. +// +// hook.Unless(Log, internal.Update|internal.UpdateOne) +func Unless(hk internal.Hook, op internal.Op) internal.Hook { + return If(hk, Not(HasOp(op))) +} + +// FixedError is a hook returning a fixed error. +func FixedError(err error) internal.Hook { + return func(internal.Mutator) internal.Mutator { + return internal.MutateFunc(func(context.Context, internal.Mutation) (internal.Value, error) { + return nil, err + }) + } +} + +// Reject returns a hook that rejects all operations that match op. +// +// func (T) Hooks() []internal.Hook { +// return []internal.Hook{ +// Reject(internal.Delete|internal.Update), +// } +// } +func Reject(op internal.Op) internal.Hook { + hk := FixedError(fmt.Errorf("%s operation is not allowed", op)) + return On(hk, op) +} + +// Chain acts as a list of hooks and is effectively immutable. +// Once created, it will always hold the same set of hooks in the same order. +type Chain struct { + hooks []internal.Hook +} + +// NewChain creates a new chain of hooks. +func NewChain(hooks ...internal.Hook) Chain { + return Chain{append([]internal.Hook(nil), hooks...)} +} + +// Hook chains the list of hooks and returns the final hook. +func (c Chain) Hook() internal.Hook { + return func(mutator internal.Mutator) internal.Mutator { + for i := len(c.hooks) - 1; i >= 0; i-- { + mutator = c.hooks[i](mutator) + } + return mutator + } +} + +// Append extends a chain, adding the specified hook +// as the last ones in the mutation flow. +func (c Chain) Append(hooks ...internal.Hook) Chain { + newHooks := make([]internal.Hook, 0, len(c.hooks)+len(hooks)) + newHooks = append(newHooks, c.hooks...) + newHooks = append(newHooks, hooks...) + return Chain{newHooks} +} + +// Extend extends a chain, adding the specified chain +// as the last ones in the mutation flow. +func (c Chain) Extend(chain Chain) Chain { + return c.Append(chain.hooks...) +} diff --git a/internal/db/ent/mailbox.go b/internal/db_impl/ent_db/internal/mailbox.go similarity index 90% rename from internal/db/ent/mailbox.go rename to internal/db_impl/ent_db/internal/mailbox.go index b863a41b..edf61c35 100644 --- a/internal/db/ent/mailbox.go +++ b/internal/db_impl/ent_db/internal/mailbox.go @@ -1,6 +1,6 @@ // Code generated by ent, DO NOT EDIT. -package ent +package internal import ( "fmt" @@ -8,7 +8,7 @@ import ( "entgo.io/ent/dialect/sql" "github.com/ProtonMail/gluon/imap" - "github.com/ProtonMail/gluon/internal/db/ent/mailbox" + "github.com/ProtonMail/gluon/internal/db_impl/ent_db/internal/mailbox" ) // Mailbox is the model entity for the Mailbox schema. @@ -83,8 +83,8 @@ func (e MailboxEdges) AttributesOrErr() ([]*MailboxAttr, error) { } // scanValues returns the types for scanning values from sql.Rows. -func (*Mailbox) scanValues(columns []string) ([]interface{}, error) { - values := make([]interface{}, len(columns)) +func (*Mailbox) scanValues(columns []string) ([]any, error) { + values := make([]any, len(columns)) for i := range columns { switch columns[i] { case mailbox.FieldSubscribed: @@ -102,7 +102,7 @@ func (*Mailbox) scanValues(columns []string) ([]interface{}, error) { // assignValues assigns the values that were returned from sql.Rows (after scanning) // to the Mailbox fields. -func (m *Mailbox) assignValues(columns []string, values []interface{}) error { +func (m *Mailbox) assignValues(columns []string, values []any) error { if m, n := len(values), len(columns); m < n { return fmt.Errorf("mismatch number of scan values: %d != %d", m, n) } @@ -151,29 +151,29 @@ func (m *Mailbox) assignValues(columns []string, values []interface{}) error { // QueryUIDs queries the "UIDs" edge of the Mailbox entity. func (m *Mailbox) QueryUIDs() *UIDQuery { - return (&MailboxClient{config: m.config}).QueryUIDs(m) + return NewMailboxClient(m.config).QueryUIDs(m) } // QueryFlags queries the "flags" edge of the Mailbox entity. func (m *Mailbox) QueryFlags() *MailboxFlagQuery { - return (&MailboxClient{config: m.config}).QueryFlags(m) + return NewMailboxClient(m.config).QueryFlags(m) } // QueryPermanentFlags queries the "permanent_flags" edge of the Mailbox entity. func (m *Mailbox) QueryPermanentFlags() *MailboxPermFlagQuery { - return (&MailboxClient{config: m.config}).QueryPermanentFlags(m) + return NewMailboxClient(m.config).QueryPermanentFlags(m) } // QueryAttributes queries the "attributes" edge of the Mailbox entity. func (m *Mailbox) QueryAttributes() *MailboxAttrQuery { - return (&MailboxClient{config: m.config}).QueryAttributes(m) + return NewMailboxClient(m.config).QueryAttributes(m) } // Update returns a builder for updating this Mailbox. // Note that you need to call Mailbox.Unwrap() before calling this method if this Mailbox // was returned from a transaction, and the transaction was committed or rolled back. func (m *Mailbox) Update() *MailboxUpdateOne { - return (&MailboxClient{config: m.config}).UpdateOne(m) + return NewMailboxClient(m.config).UpdateOne(m) } // Unwrap unwraps the Mailbox entity that was returned from a transaction after it was closed, @@ -181,7 +181,7 @@ func (m *Mailbox) Update() *MailboxUpdateOne { func (m *Mailbox) Unwrap() *Mailbox { _tx, ok := m.config.driver.(*txDriver) if !ok { - panic("ent: Mailbox is not a transactional entity") + panic("internal: Mailbox is not a transactional entity") } m.config.driver = _tx.drv return m @@ -212,9 +212,3 @@ func (m *Mailbox) String() string { // Mailboxes is a parsable slice of Mailbox. type Mailboxes []*Mailbox - -func (m Mailboxes) config(cfg config) { - for _i := range m { - m[_i].config = cfg - } -} diff --git a/internal/db/ent/mailbox/mailbox.go b/internal/db_impl/ent_db/internal/mailbox/mailbox.go similarity index 100% rename from internal/db/ent/mailbox/mailbox.go rename to internal/db_impl/ent_db/internal/mailbox/mailbox.go diff --git a/internal/db/ent/mailbox/where.go b/internal/db_impl/ent_db/internal/mailbox/where.go similarity index 65% rename from internal/db/ent/mailbox/where.go rename to internal/db_impl/ent_db/internal/mailbox/where.go index 64940642..bfba382c 100644 --- a/internal/db/ent/mailbox/where.go +++ b/internal/db_impl/ent_db/internal/mailbox/where.go @@ -6,493 +6,357 @@ import ( "entgo.io/ent/dialect/sql" "entgo.io/ent/dialect/sql/sqlgraph" "github.com/ProtonMail/gluon/imap" - "github.com/ProtonMail/gluon/internal/db/ent/predicate" + "github.com/ProtonMail/gluon/internal/db_impl/ent_db/internal/predicate" ) // ID filters vertices based on their ID field. func ID(id imap.InternalMailboxID) predicate.Mailbox { - return predicate.Mailbox(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldID), id)) - }) + return predicate.Mailbox(sql.FieldEQ(FieldID, id)) } // IDEQ applies the EQ predicate on the ID field. func IDEQ(id imap.InternalMailboxID) predicate.Mailbox { - return predicate.Mailbox(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldID), id)) - }) + return predicate.Mailbox(sql.FieldEQ(FieldID, id)) } // IDNEQ applies the NEQ predicate on the ID field. func IDNEQ(id imap.InternalMailboxID) predicate.Mailbox { - return predicate.Mailbox(func(s *sql.Selector) { - s.Where(sql.NEQ(s.C(FieldID), id)) - }) + return predicate.Mailbox(sql.FieldNEQ(FieldID, id)) } // IDIn applies the In predicate on the ID field. func IDIn(ids ...imap.InternalMailboxID) predicate.Mailbox { - return predicate.Mailbox(func(s *sql.Selector) { - v := make([]interface{}, len(ids)) - for i := range v { - v[i] = ids[i] - } - s.Where(sql.In(s.C(FieldID), v...)) - }) + return predicate.Mailbox(sql.FieldIn(FieldID, ids...)) } // IDNotIn applies the NotIn predicate on the ID field. func IDNotIn(ids ...imap.InternalMailboxID) predicate.Mailbox { - return predicate.Mailbox(func(s *sql.Selector) { - v := make([]interface{}, len(ids)) - for i := range v { - v[i] = ids[i] - } - s.Where(sql.NotIn(s.C(FieldID), v...)) - }) + return predicate.Mailbox(sql.FieldNotIn(FieldID, ids...)) } // IDGT applies the GT predicate on the ID field. func IDGT(id imap.InternalMailboxID) predicate.Mailbox { - return predicate.Mailbox(func(s *sql.Selector) { - s.Where(sql.GT(s.C(FieldID), id)) - }) + return predicate.Mailbox(sql.FieldGT(FieldID, id)) } // IDGTE applies the GTE predicate on the ID field. func IDGTE(id imap.InternalMailboxID) predicate.Mailbox { - return predicate.Mailbox(func(s *sql.Selector) { - s.Where(sql.GTE(s.C(FieldID), id)) - }) + return predicate.Mailbox(sql.FieldGTE(FieldID, id)) } // IDLT applies the LT predicate on the ID field. func IDLT(id imap.InternalMailboxID) predicate.Mailbox { - return predicate.Mailbox(func(s *sql.Selector) { - s.Where(sql.LT(s.C(FieldID), id)) - }) + return predicate.Mailbox(sql.FieldLT(FieldID, id)) } // IDLTE applies the LTE predicate on the ID field. func IDLTE(id imap.InternalMailboxID) predicate.Mailbox { - return predicate.Mailbox(func(s *sql.Selector) { - s.Where(sql.LTE(s.C(FieldID), id)) - }) + return predicate.Mailbox(sql.FieldLTE(FieldID, id)) } // RemoteID applies equality check predicate on the "RemoteID" field. It's identical to RemoteIDEQ. func RemoteID(v imap.MailboxID) predicate.Mailbox { vc := string(v) - return predicate.Mailbox(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldRemoteID), vc)) - }) + return predicate.Mailbox(sql.FieldEQ(FieldRemoteID, vc)) } // Name applies equality check predicate on the "Name" field. It's identical to NameEQ. func Name(v string) predicate.Mailbox { - return predicate.Mailbox(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldName), v)) - }) + return predicate.Mailbox(sql.FieldEQ(FieldName, v)) } // UIDNext applies equality check predicate on the "UIDNext" field. It's identical to UIDNextEQ. func UIDNext(v imap.UID) predicate.Mailbox { vc := uint32(v) - return predicate.Mailbox(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldUIDNext), vc)) - }) + return predicate.Mailbox(sql.FieldEQ(FieldUIDNext, vc)) } // UIDValidity applies equality check predicate on the "UIDValidity" field. It's identical to UIDValidityEQ. func UIDValidity(v imap.UID) predicate.Mailbox { vc := uint32(v) - return predicate.Mailbox(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldUIDValidity), vc)) - }) + return predicate.Mailbox(sql.FieldEQ(FieldUIDValidity, vc)) } // Subscribed applies equality check predicate on the "Subscribed" field. It's identical to SubscribedEQ. func Subscribed(v bool) predicate.Mailbox { - return predicate.Mailbox(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldSubscribed), v)) - }) + return predicate.Mailbox(sql.FieldEQ(FieldSubscribed, v)) } // RemoteIDEQ applies the EQ predicate on the "RemoteID" field. func RemoteIDEQ(v imap.MailboxID) predicate.Mailbox { vc := string(v) - return predicate.Mailbox(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldRemoteID), vc)) - }) + return predicate.Mailbox(sql.FieldEQ(FieldRemoteID, vc)) } // RemoteIDNEQ applies the NEQ predicate on the "RemoteID" field. func RemoteIDNEQ(v imap.MailboxID) predicate.Mailbox { vc := string(v) - return predicate.Mailbox(func(s *sql.Selector) { - s.Where(sql.NEQ(s.C(FieldRemoteID), vc)) - }) + return predicate.Mailbox(sql.FieldNEQ(FieldRemoteID, vc)) } // RemoteIDIn applies the In predicate on the "RemoteID" field. func RemoteIDIn(vs ...imap.MailboxID) predicate.Mailbox { - v := make([]interface{}, len(vs)) + v := make([]any, len(vs)) for i := range v { v[i] = string(vs[i]) } - return predicate.Mailbox(func(s *sql.Selector) { - s.Where(sql.In(s.C(FieldRemoteID), v...)) - }) + return predicate.Mailbox(sql.FieldIn(FieldRemoteID, v...)) } // RemoteIDNotIn applies the NotIn predicate on the "RemoteID" field. func RemoteIDNotIn(vs ...imap.MailboxID) predicate.Mailbox { - v := make([]interface{}, len(vs)) + v := make([]any, len(vs)) for i := range v { v[i] = string(vs[i]) } - return predicate.Mailbox(func(s *sql.Selector) { - s.Where(sql.NotIn(s.C(FieldRemoteID), v...)) - }) + return predicate.Mailbox(sql.FieldNotIn(FieldRemoteID, v...)) } // RemoteIDGT applies the GT predicate on the "RemoteID" field. func RemoteIDGT(v imap.MailboxID) predicate.Mailbox { vc := string(v) - return predicate.Mailbox(func(s *sql.Selector) { - s.Where(sql.GT(s.C(FieldRemoteID), vc)) - }) + return predicate.Mailbox(sql.FieldGT(FieldRemoteID, vc)) } // RemoteIDGTE applies the GTE predicate on the "RemoteID" field. func RemoteIDGTE(v imap.MailboxID) predicate.Mailbox { vc := string(v) - return predicate.Mailbox(func(s *sql.Selector) { - s.Where(sql.GTE(s.C(FieldRemoteID), vc)) - }) + return predicate.Mailbox(sql.FieldGTE(FieldRemoteID, vc)) } // RemoteIDLT applies the LT predicate on the "RemoteID" field. func RemoteIDLT(v imap.MailboxID) predicate.Mailbox { vc := string(v) - return predicate.Mailbox(func(s *sql.Selector) { - s.Where(sql.LT(s.C(FieldRemoteID), vc)) - }) + return predicate.Mailbox(sql.FieldLT(FieldRemoteID, vc)) } // RemoteIDLTE applies the LTE predicate on the "RemoteID" field. func RemoteIDLTE(v imap.MailboxID) predicate.Mailbox { vc := string(v) - return predicate.Mailbox(func(s *sql.Selector) { - s.Where(sql.LTE(s.C(FieldRemoteID), vc)) - }) + return predicate.Mailbox(sql.FieldLTE(FieldRemoteID, vc)) } // RemoteIDContains applies the Contains predicate on the "RemoteID" field. func RemoteIDContains(v imap.MailboxID) predicate.Mailbox { vc := string(v) - return predicate.Mailbox(func(s *sql.Selector) { - s.Where(sql.Contains(s.C(FieldRemoteID), vc)) - }) + return predicate.Mailbox(sql.FieldContains(FieldRemoteID, vc)) } // RemoteIDHasPrefix applies the HasPrefix predicate on the "RemoteID" field. func RemoteIDHasPrefix(v imap.MailboxID) predicate.Mailbox { vc := string(v) - return predicate.Mailbox(func(s *sql.Selector) { - s.Where(sql.HasPrefix(s.C(FieldRemoteID), vc)) - }) + return predicate.Mailbox(sql.FieldHasPrefix(FieldRemoteID, vc)) } // RemoteIDHasSuffix applies the HasSuffix predicate on the "RemoteID" field. func RemoteIDHasSuffix(v imap.MailboxID) predicate.Mailbox { vc := string(v) - return predicate.Mailbox(func(s *sql.Selector) { - s.Where(sql.HasSuffix(s.C(FieldRemoteID), vc)) - }) + return predicate.Mailbox(sql.FieldHasSuffix(FieldRemoteID, vc)) } // RemoteIDIsNil applies the IsNil predicate on the "RemoteID" field. func RemoteIDIsNil() predicate.Mailbox { - return predicate.Mailbox(func(s *sql.Selector) { - s.Where(sql.IsNull(s.C(FieldRemoteID))) - }) + return predicate.Mailbox(sql.FieldIsNull(FieldRemoteID)) } // RemoteIDNotNil applies the NotNil predicate on the "RemoteID" field. func RemoteIDNotNil() predicate.Mailbox { - return predicate.Mailbox(func(s *sql.Selector) { - s.Where(sql.NotNull(s.C(FieldRemoteID))) - }) + return predicate.Mailbox(sql.FieldNotNull(FieldRemoteID)) } // RemoteIDEqualFold applies the EqualFold predicate on the "RemoteID" field. func RemoteIDEqualFold(v imap.MailboxID) predicate.Mailbox { vc := string(v) - return predicate.Mailbox(func(s *sql.Selector) { - s.Where(sql.EqualFold(s.C(FieldRemoteID), vc)) - }) + return predicate.Mailbox(sql.FieldEqualFold(FieldRemoteID, vc)) } // RemoteIDContainsFold applies the ContainsFold predicate on the "RemoteID" field. func RemoteIDContainsFold(v imap.MailboxID) predicate.Mailbox { vc := string(v) - return predicate.Mailbox(func(s *sql.Selector) { - s.Where(sql.ContainsFold(s.C(FieldRemoteID), vc)) - }) + return predicate.Mailbox(sql.FieldContainsFold(FieldRemoteID, vc)) } // NameEQ applies the EQ predicate on the "Name" field. func NameEQ(v string) predicate.Mailbox { - return predicate.Mailbox(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldName), v)) - }) + return predicate.Mailbox(sql.FieldEQ(FieldName, v)) } // NameNEQ applies the NEQ predicate on the "Name" field. func NameNEQ(v string) predicate.Mailbox { - return predicate.Mailbox(func(s *sql.Selector) { - s.Where(sql.NEQ(s.C(FieldName), v)) - }) + return predicate.Mailbox(sql.FieldNEQ(FieldName, v)) } // NameIn applies the In predicate on the "Name" field. func NameIn(vs ...string) predicate.Mailbox { - v := make([]interface{}, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.Mailbox(func(s *sql.Selector) { - s.Where(sql.In(s.C(FieldName), v...)) - }) + return predicate.Mailbox(sql.FieldIn(FieldName, vs...)) } // NameNotIn applies the NotIn predicate on the "Name" field. func NameNotIn(vs ...string) predicate.Mailbox { - v := make([]interface{}, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.Mailbox(func(s *sql.Selector) { - s.Where(sql.NotIn(s.C(FieldName), v...)) - }) + return predicate.Mailbox(sql.FieldNotIn(FieldName, vs...)) } // NameGT applies the GT predicate on the "Name" field. func NameGT(v string) predicate.Mailbox { - return predicate.Mailbox(func(s *sql.Selector) { - s.Where(sql.GT(s.C(FieldName), v)) - }) + return predicate.Mailbox(sql.FieldGT(FieldName, v)) } // NameGTE applies the GTE predicate on the "Name" field. func NameGTE(v string) predicate.Mailbox { - return predicate.Mailbox(func(s *sql.Selector) { - s.Where(sql.GTE(s.C(FieldName), v)) - }) + return predicate.Mailbox(sql.FieldGTE(FieldName, v)) } // NameLT applies the LT predicate on the "Name" field. func NameLT(v string) predicate.Mailbox { - return predicate.Mailbox(func(s *sql.Selector) { - s.Where(sql.LT(s.C(FieldName), v)) - }) + return predicate.Mailbox(sql.FieldLT(FieldName, v)) } // NameLTE applies the LTE predicate on the "Name" field. func NameLTE(v string) predicate.Mailbox { - return predicate.Mailbox(func(s *sql.Selector) { - s.Where(sql.LTE(s.C(FieldName), v)) - }) + return predicate.Mailbox(sql.FieldLTE(FieldName, v)) } // NameContains applies the Contains predicate on the "Name" field. func NameContains(v string) predicate.Mailbox { - return predicate.Mailbox(func(s *sql.Selector) { - s.Where(sql.Contains(s.C(FieldName), v)) - }) + return predicate.Mailbox(sql.FieldContains(FieldName, v)) } // NameHasPrefix applies the HasPrefix predicate on the "Name" field. func NameHasPrefix(v string) predicate.Mailbox { - return predicate.Mailbox(func(s *sql.Selector) { - s.Where(sql.HasPrefix(s.C(FieldName), v)) - }) + return predicate.Mailbox(sql.FieldHasPrefix(FieldName, v)) } // NameHasSuffix applies the HasSuffix predicate on the "Name" field. func NameHasSuffix(v string) predicate.Mailbox { - return predicate.Mailbox(func(s *sql.Selector) { - s.Where(sql.HasSuffix(s.C(FieldName), v)) - }) + return predicate.Mailbox(sql.FieldHasSuffix(FieldName, v)) } // NameEqualFold applies the EqualFold predicate on the "Name" field. func NameEqualFold(v string) predicate.Mailbox { - return predicate.Mailbox(func(s *sql.Selector) { - s.Where(sql.EqualFold(s.C(FieldName), v)) - }) + return predicate.Mailbox(sql.FieldEqualFold(FieldName, v)) } // NameContainsFold applies the ContainsFold predicate on the "Name" field. func NameContainsFold(v string) predicate.Mailbox { - return predicate.Mailbox(func(s *sql.Selector) { - s.Where(sql.ContainsFold(s.C(FieldName), v)) - }) + return predicate.Mailbox(sql.FieldContainsFold(FieldName, v)) } // UIDNextEQ applies the EQ predicate on the "UIDNext" field. func UIDNextEQ(v imap.UID) predicate.Mailbox { vc := uint32(v) - return predicate.Mailbox(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldUIDNext), vc)) - }) + return predicate.Mailbox(sql.FieldEQ(FieldUIDNext, vc)) } // UIDNextNEQ applies the NEQ predicate on the "UIDNext" field. func UIDNextNEQ(v imap.UID) predicate.Mailbox { vc := uint32(v) - return predicate.Mailbox(func(s *sql.Selector) { - s.Where(sql.NEQ(s.C(FieldUIDNext), vc)) - }) + return predicate.Mailbox(sql.FieldNEQ(FieldUIDNext, vc)) } // UIDNextIn applies the In predicate on the "UIDNext" field. func UIDNextIn(vs ...imap.UID) predicate.Mailbox { - v := make([]interface{}, len(vs)) + v := make([]any, len(vs)) for i := range v { v[i] = uint32(vs[i]) } - return predicate.Mailbox(func(s *sql.Selector) { - s.Where(sql.In(s.C(FieldUIDNext), v...)) - }) + return predicate.Mailbox(sql.FieldIn(FieldUIDNext, v...)) } // UIDNextNotIn applies the NotIn predicate on the "UIDNext" field. func UIDNextNotIn(vs ...imap.UID) predicate.Mailbox { - v := make([]interface{}, len(vs)) + v := make([]any, len(vs)) for i := range v { v[i] = uint32(vs[i]) } - return predicate.Mailbox(func(s *sql.Selector) { - s.Where(sql.NotIn(s.C(FieldUIDNext), v...)) - }) + return predicate.Mailbox(sql.FieldNotIn(FieldUIDNext, v...)) } // UIDNextGT applies the GT predicate on the "UIDNext" field. func UIDNextGT(v imap.UID) predicate.Mailbox { vc := uint32(v) - return predicate.Mailbox(func(s *sql.Selector) { - s.Where(sql.GT(s.C(FieldUIDNext), vc)) - }) + return predicate.Mailbox(sql.FieldGT(FieldUIDNext, vc)) } // UIDNextGTE applies the GTE predicate on the "UIDNext" field. func UIDNextGTE(v imap.UID) predicate.Mailbox { vc := uint32(v) - return predicate.Mailbox(func(s *sql.Selector) { - s.Where(sql.GTE(s.C(FieldUIDNext), vc)) - }) + return predicate.Mailbox(sql.FieldGTE(FieldUIDNext, vc)) } // UIDNextLT applies the LT predicate on the "UIDNext" field. func UIDNextLT(v imap.UID) predicate.Mailbox { vc := uint32(v) - return predicate.Mailbox(func(s *sql.Selector) { - s.Where(sql.LT(s.C(FieldUIDNext), vc)) - }) + return predicate.Mailbox(sql.FieldLT(FieldUIDNext, vc)) } // UIDNextLTE applies the LTE predicate on the "UIDNext" field. func UIDNextLTE(v imap.UID) predicate.Mailbox { vc := uint32(v) - return predicate.Mailbox(func(s *sql.Selector) { - s.Where(sql.LTE(s.C(FieldUIDNext), vc)) - }) + return predicate.Mailbox(sql.FieldLTE(FieldUIDNext, vc)) } // UIDValidityEQ applies the EQ predicate on the "UIDValidity" field. func UIDValidityEQ(v imap.UID) predicate.Mailbox { vc := uint32(v) - return predicate.Mailbox(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldUIDValidity), vc)) - }) + return predicate.Mailbox(sql.FieldEQ(FieldUIDValidity, vc)) } // UIDValidityNEQ applies the NEQ predicate on the "UIDValidity" field. func UIDValidityNEQ(v imap.UID) predicate.Mailbox { vc := uint32(v) - return predicate.Mailbox(func(s *sql.Selector) { - s.Where(sql.NEQ(s.C(FieldUIDValidity), vc)) - }) + return predicate.Mailbox(sql.FieldNEQ(FieldUIDValidity, vc)) } // UIDValidityIn applies the In predicate on the "UIDValidity" field. func UIDValidityIn(vs ...imap.UID) predicate.Mailbox { - v := make([]interface{}, len(vs)) + v := make([]any, len(vs)) for i := range v { v[i] = uint32(vs[i]) } - return predicate.Mailbox(func(s *sql.Selector) { - s.Where(sql.In(s.C(FieldUIDValidity), v...)) - }) + return predicate.Mailbox(sql.FieldIn(FieldUIDValidity, v...)) } // UIDValidityNotIn applies the NotIn predicate on the "UIDValidity" field. func UIDValidityNotIn(vs ...imap.UID) predicate.Mailbox { - v := make([]interface{}, len(vs)) + v := make([]any, len(vs)) for i := range v { v[i] = uint32(vs[i]) } - return predicate.Mailbox(func(s *sql.Selector) { - s.Where(sql.NotIn(s.C(FieldUIDValidity), v...)) - }) + return predicate.Mailbox(sql.FieldNotIn(FieldUIDValidity, v...)) } // UIDValidityGT applies the GT predicate on the "UIDValidity" field. func UIDValidityGT(v imap.UID) predicate.Mailbox { vc := uint32(v) - return predicate.Mailbox(func(s *sql.Selector) { - s.Where(sql.GT(s.C(FieldUIDValidity), vc)) - }) + return predicate.Mailbox(sql.FieldGT(FieldUIDValidity, vc)) } // UIDValidityGTE applies the GTE predicate on the "UIDValidity" field. func UIDValidityGTE(v imap.UID) predicate.Mailbox { vc := uint32(v) - return predicate.Mailbox(func(s *sql.Selector) { - s.Where(sql.GTE(s.C(FieldUIDValidity), vc)) - }) + return predicate.Mailbox(sql.FieldGTE(FieldUIDValidity, vc)) } // UIDValidityLT applies the LT predicate on the "UIDValidity" field. func UIDValidityLT(v imap.UID) predicate.Mailbox { vc := uint32(v) - return predicate.Mailbox(func(s *sql.Selector) { - s.Where(sql.LT(s.C(FieldUIDValidity), vc)) - }) + return predicate.Mailbox(sql.FieldLT(FieldUIDValidity, vc)) } // UIDValidityLTE applies the LTE predicate on the "UIDValidity" field. func UIDValidityLTE(v imap.UID) predicate.Mailbox { vc := uint32(v) - return predicate.Mailbox(func(s *sql.Selector) { - s.Where(sql.LTE(s.C(FieldUIDValidity), vc)) - }) + return predicate.Mailbox(sql.FieldLTE(FieldUIDValidity, vc)) } // SubscribedEQ applies the EQ predicate on the "Subscribed" field. func SubscribedEQ(v bool) predicate.Mailbox { - return predicate.Mailbox(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldSubscribed), v)) - }) + return predicate.Mailbox(sql.FieldEQ(FieldSubscribed, v)) } // SubscribedNEQ applies the NEQ predicate on the "Subscribed" field. func SubscribedNEQ(v bool) predicate.Mailbox { - return predicate.Mailbox(func(s *sql.Selector) { - s.Where(sql.NEQ(s.C(FieldSubscribed), v)) - }) + return predicate.Mailbox(sql.FieldNEQ(FieldSubscribed, v)) } // HasUIDs applies the HasEdge predicate on the "UIDs" edge. @@ -500,7 +364,6 @@ func HasUIDs() predicate.Mailbox { return predicate.Mailbox(func(s *sql.Selector) { step := sqlgraph.NewStep( sqlgraph.From(Table, FieldID), - sqlgraph.To(UIDsTable, FieldID), sqlgraph.Edge(sqlgraph.O2M, false, UIDsTable, UIDsColumn), ) sqlgraph.HasNeighbors(s, step) @@ -528,7 +391,6 @@ func HasFlags() predicate.Mailbox { return predicate.Mailbox(func(s *sql.Selector) { step := sqlgraph.NewStep( sqlgraph.From(Table, FieldID), - sqlgraph.To(FlagsTable, FieldID), sqlgraph.Edge(sqlgraph.O2M, false, FlagsTable, FlagsColumn), ) sqlgraph.HasNeighbors(s, step) @@ -556,7 +418,6 @@ func HasPermanentFlags() predicate.Mailbox { return predicate.Mailbox(func(s *sql.Selector) { step := sqlgraph.NewStep( sqlgraph.From(Table, FieldID), - sqlgraph.To(PermanentFlagsTable, FieldID), sqlgraph.Edge(sqlgraph.O2M, false, PermanentFlagsTable, PermanentFlagsColumn), ) sqlgraph.HasNeighbors(s, step) @@ -584,7 +445,6 @@ func HasAttributes() predicate.Mailbox { return predicate.Mailbox(func(s *sql.Selector) { step := sqlgraph.NewStep( sqlgraph.From(Table, FieldID), - sqlgraph.To(AttributesTable, FieldID), sqlgraph.Edge(sqlgraph.O2M, false, AttributesTable, AttributesColumn), ) sqlgraph.HasNeighbors(s, step) diff --git a/internal/db/ent/mailbox_create.go b/internal/db_impl/ent_db/internal/mailbox_create.go similarity index 80% rename from internal/db/ent/mailbox_create.go rename to internal/db_impl/ent_db/internal/mailbox_create.go index b1fa61a3..9b831be9 100644 --- a/internal/db/ent/mailbox_create.go +++ b/internal/db_impl/ent_db/internal/mailbox_create.go @@ -1,6 +1,6 @@ // Code generated by ent, DO NOT EDIT. -package ent +package internal import ( "context" @@ -10,11 +10,11 @@ import ( "entgo.io/ent/dialect/sql/sqlgraph" "entgo.io/ent/schema/field" "github.com/ProtonMail/gluon/imap" - "github.com/ProtonMail/gluon/internal/db/ent/mailbox" - "github.com/ProtonMail/gluon/internal/db/ent/mailboxattr" - "github.com/ProtonMail/gluon/internal/db/ent/mailboxflag" - "github.com/ProtonMail/gluon/internal/db/ent/mailboxpermflag" - "github.com/ProtonMail/gluon/internal/db/ent/uid" + "github.com/ProtonMail/gluon/internal/db_impl/ent_db/internal/mailbox" + "github.com/ProtonMail/gluon/internal/db_impl/ent_db/internal/mailboxattr" + "github.com/ProtonMail/gluon/internal/db_impl/ent_db/internal/mailboxflag" + "github.com/ProtonMail/gluon/internal/db_impl/ent_db/internal/mailboxpermflag" + "github.com/ProtonMail/gluon/internal/db_impl/ent_db/internal/uid" ) // MailboxCreate is the builder for creating a Mailbox entity. @@ -159,50 +159,8 @@ func (mc *MailboxCreate) Mutation() *MailboxMutation { // Save creates the Mailbox in the database. func (mc *MailboxCreate) Save(ctx context.Context) (*Mailbox, error) { - var ( - err error - node *Mailbox - ) mc.defaults() - if len(mc.hooks) == 0 { - if err = mc.check(); err != nil { - return nil, err - } - node, err = mc.sqlSave(ctx) - } else { - var mut Mutator = MutateFunc(func(ctx context.Context, m Mutation) (Value, error) { - mutation, ok := m.(*MailboxMutation) - if !ok { - return nil, fmt.Errorf("unexpected mutation type %T", m) - } - if err = mc.check(); err != nil { - return nil, err - } - mc.mutation = mutation - if node, err = mc.sqlSave(ctx); err != nil { - return nil, err - } - mutation.id = &node.ID - mutation.done = true - return node, err - }) - for i := len(mc.hooks) - 1; i >= 0; i-- { - if mc.hooks[i] == nil { - return nil, fmt.Errorf("ent: uninitialized hook (forgotten import ent/runtime?)") - } - mut = mc.hooks[i](mut) - } - v, err := mut.Mutate(ctx, mc.mutation) - if err != nil { - return nil, err - } - nv, ok := v.(*Mailbox) - if !ok { - return nil, fmt.Errorf("unexpected node type %T returned from MailboxMutation", v) - } - node = nv - } - return node, err + return withHooks[*Mailbox, MailboxMutation](ctx, mc.sqlSave, mc.mutation, mc.hooks) } // SaveX calls Save and panics if Save returns an error. @@ -246,21 +204,24 @@ func (mc *MailboxCreate) defaults() { // check runs all checks and user-defined validators on the builder. func (mc *MailboxCreate) check() error { if _, ok := mc.mutation.Name(); !ok { - return &ValidationError{Name: "Name", err: errors.New(`ent: missing required field "Mailbox.Name"`)} + return &ValidationError{Name: "Name", err: errors.New(`internal: missing required field "Mailbox.Name"`)} } if _, ok := mc.mutation.UIDNext(); !ok { - return &ValidationError{Name: "UIDNext", err: errors.New(`ent: missing required field "Mailbox.UIDNext"`)} + return &ValidationError{Name: "UIDNext", err: errors.New(`internal: missing required field "Mailbox.UIDNext"`)} } if _, ok := mc.mutation.UIDValidity(); !ok { - return &ValidationError{Name: "UIDValidity", err: errors.New(`ent: missing required field "Mailbox.UIDValidity"`)} + return &ValidationError{Name: "UIDValidity", err: errors.New(`internal: missing required field "Mailbox.UIDValidity"`)} } if _, ok := mc.mutation.Subscribed(); !ok { - return &ValidationError{Name: "Subscribed", err: errors.New(`ent: missing required field "Mailbox.Subscribed"`)} + return &ValidationError{Name: "Subscribed", err: errors.New(`internal: missing required field "Mailbox.Subscribed"`)} } return nil } func (mc *MailboxCreate) sqlSave(ctx context.Context) (*Mailbox, error) { + if err := mc.check(); err != nil { + return nil, err + } _node, _spec := mc.createSpec() if err := sqlgraph.CreateNode(ctx, mc.driver, _spec); err != nil { if sqlgraph.IsConstraintError(err) { @@ -272,62 +233,38 @@ func (mc *MailboxCreate) sqlSave(ctx context.Context) (*Mailbox, error) { id := _spec.ID.Value.(int64) _node.ID = imap.InternalMailboxID(id) } + mc.mutation.id = &_node.ID + mc.mutation.done = true return _node, nil } func (mc *MailboxCreate) createSpec() (*Mailbox, *sqlgraph.CreateSpec) { var ( _node = &Mailbox{config: mc.config} - _spec = &sqlgraph.CreateSpec{ - Table: mailbox.Table, - ID: &sqlgraph.FieldSpec{ - Type: field.TypeUint64, - Column: mailbox.FieldID, - }, - } + _spec = sqlgraph.NewCreateSpec(mailbox.Table, sqlgraph.NewFieldSpec(mailbox.FieldID, field.TypeUint64)) ) if id, ok := mc.mutation.ID(); ok { _node.ID = id _spec.ID.Value = id } if value, ok := mc.mutation.RemoteID(); ok { - _spec.Fields = append(_spec.Fields, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Value: value, - Column: mailbox.FieldRemoteID, - }) + _spec.SetField(mailbox.FieldRemoteID, field.TypeString, value) _node.RemoteID = value } if value, ok := mc.mutation.Name(); ok { - _spec.Fields = append(_spec.Fields, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Value: value, - Column: mailbox.FieldName, - }) + _spec.SetField(mailbox.FieldName, field.TypeString, value) _node.Name = value } if value, ok := mc.mutation.UIDNext(); ok { - _spec.Fields = append(_spec.Fields, &sqlgraph.FieldSpec{ - Type: field.TypeUint32, - Value: value, - Column: mailbox.FieldUIDNext, - }) + _spec.SetField(mailbox.FieldUIDNext, field.TypeUint32, value) _node.UIDNext = value } if value, ok := mc.mutation.UIDValidity(); ok { - _spec.Fields = append(_spec.Fields, &sqlgraph.FieldSpec{ - Type: field.TypeUint32, - Value: value, - Column: mailbox.FieldUIDValidity, - }) + _spec.SetField(mailbox.FieldUIDValidity, field.TypeUint32, value) _node.UIDValidity = value } if value, ok := mc.mutation.Subscribed(); ok { - _spec.Fields = append(_spec.Fields, &sqlgraph.FieldSpec{ - Type: field.TypeBool, - Value: value, - Column: mailbox.FieldSubscribed, - }) + _spec.SetField(mailbox.FieldSubscribed, field.TypeBool, value) _node.Subscribed = value } if nodes := mc.mutation.UIDsIDs(); len(nodes) > 0 { diff --git a/internal/db/ent/mailbox_delete.go b/internal/db_impl/ent_db/internal/mailbox_delete.go similarity index 62% rename from internal/db/ent/mailbox_delete.go rename to internal/db_impl/ent_db/internal/mailbox_delete.go index 5fd9ac3d..a8074cc9 100644 --- a/internal/db/ent/mailbox_delete.go +++ b/internal/db_impl/ent_db/internal/mailbox_delete.go @@ -1,16 +1,15 @@ // Code generated by ent, DO NOT EDIT. -package ent +package internal import ( "context" - "fmt" "entgo.io/ent/dialect/sql" "entgo.io/ent/dialect/sql/sqlgraph" "entgo.io/ent/schema/field" - "github.com/ProtonMail/gluon/internal/db/ent/mailbox" - "github.com/ProtonMail/gluon/internal/db/ent/predicate" + "github.com/ProtonMail/gluon/internal/db_impl/ent_db/internal/mailbox" + "github.com/ProtonMail/gluon/internal/db_impl/ent_db/internal/predicate" ) // MailboxDelete is the builder for deleting a Mailbox entity. @@ -28,34 +27,7 @@ func (md *MailboxDelete) Where(ps ...predicate.Mailbox) *MailboxDelete { // Exec executes the deletion query and returns how many vertices were deleted. func (md *MailboxDelete) Exec(ctx context.Context) (int, error) { - var ( - err error - affected int - ) - if len(md.hooks) == 0 { - affected, err = md.sqlExec(ctx) - } else { - var mut Mutator = MutateFunc(func(ctx context.Context, m Mutation) (Value, error) { - mutation, ok := m.(*MailboxMutation) - if !ok { - return nil, fmt.Errorf("unexpected mutation type %T", m) - } - md.mutation = mutation - affected, err = md.sqlExec(ctx) - mutation.done = true - return affected, err - }) - for i := len(md.hooks) - 1; i >= 0; i-- { - if md.hooks[i] == nil { - return 0, fmt.Errorf("ent: uninitialized hook (forgotten import ent/runtime?)") - } - mut = md.hooks[i](mut) - } - if _, err := mut.Mutate(ctx, md.mutation); err != nil { - return 0, err - } - } - return affected, err + return withHooks[int, MailboxMutation](ctx, md.sqlExec, md.mutation, md.hooks) } // ExecX is like Exec, but panics if an error occurs. @@ -68,15 +40,7 @@ func (md *MailboxDelete) ExecX(ctx context.Context) int { } func (md *MailboxDelete) sqlExec(ctx context.Context) (int, error) { - _spec := &sqlgraph.DeleteSpec{ - Node: &sqlgraph.NodeSpec{ - Table: mailbox.Table, - ID: &sqlgraph.FieldSpec{ - Type: field.TypeUint64, - Column: mailbox.FieldID, - }, - }, - } + _spec := sqlgraph.NewDeleteSpec(mailbox.Table, sqlgraph.NewFieldSpec(mailbox.FieldID, field.TypeUint64)) if ps := md.mutation.predicates; len(ps) > 0 { _spec.Predicate = func(selector *sql.Selector) { for i := range ps { @@ -88,6 +52,7 @@ func (md *MailboxDelete) sqlExec(ctx context.Context) (int, error) { if err != nil && sqlgraph.IsConstraintError(err) { err = &ConstraintError{msg: err.Error(), wrap: err} } + md.mutation.done = true return affected, err } @@ -96,6 +61,12 @@ type MailboxDeleteOne struct { md *MailboxDelete } +// Where appends a list predicates to the MailboxDelete builder. +func (mdo *MailboxDeleteOne) Where(ps ...predicate.Mailbox) *MailboxDeleteOne { + mdo.md.mutation.Where(ps...) + return mdo +} + // Exec executes the deletion query. func (mdo *MailboxDeleteOne) Exec(ctx context.Context) error { n, err := mdo.md.Exec(ctx) @@ -111,5 +82,7 @@ func (mdo *MailboxDeleteOne) Exec(ctx context.Context) error { // ExecX is like Exec, but panics if an error occurs. func (mdo *MailboxDeleteOne) ExecX(ctx context.Context) { - mdo.md.ExecX(ctx) + if err := mdo.Exec(ctx); err != nil { + panic(err) + } } diff --git a/internal/db/ent/mailbox_query.go b/internal/db_impl/ent_db/internal/mailbox_query.go similarity index 77% rename from internal/db/ent/mailbox_query.go rename to internal/db_impl/ent_db/internal/mailbox_query.go index 0afa7e5e..a21664be 100644 --- a/internal/db/ent/mailbox_query.go +++ b/internal/db_impl/ent_db/internal/mailbox_query.go @@ -1,6 +1,6 @@ // Code generated by ent, DO NOT EDIT. -package ent +package internal import ( "context" @@ -12,22 +12,20 @@ import ( "entgo.io/ent/dialect/sql/sqlgraph" "entgo.io/ent/schema/field" "github.com/ProtonMail/gluon/imap" - "github.com/ProtonMail/gluon/internal/db/ent/mailbox" - "github.com/ProtonMail/gluon/internal/db/ent/mailboxattr" - "github.com/ProtonMail/gluon/internal/db/ent/mailboxflag" - "github.com/ProtonMail/gluon/internal/db/ent/mailboxpermflag" - "github.com/ProtonMail/gluon/internal/db/ent/predicate" - "github.com/ProtonMail/gluon/internal/db/ent/uid" + "github.com/ProtonMail/gluon/internal/db_impl/ent_db/internal/mailbox" + "github.com/ProtonMail/gluon/internal/db_impl/ent_db/internal/mailboxattr" + "github.com/ProtonMail/gluon/internal/db_impl/ent_db/internal/mailboxflag" + "github.com/ProtonMail/gluon/internal/db_impl/ent_db/internal/mailboxpermflag" + "github.com/ProtonMail/gluon/internal/db_impl/ent_db/internal/predicate" + "github.com/ProtonMail/gluon/internal/db_impl/ent_db/internal/uid" ) // MailboxQuery is the builder for querying Mailbox entities. type MailboxQuery struct { config - limit *int - offset *int - unique *bool + ctx *QueryContext order []OrderFunc - fields []string + inters []Interceptor predicates []predicate.Mailbox withUIDs *UIDQuery withFlags *MailboxFlagQuery @@ -44,26 +42,26 @@ func (mq *MailboxQuery) Where(ps ...predicate.Mailbox) *MailboxQuery { return mq } -// Limit adds a limit step to the query. +// Limit the number of records to be returned by this query. func (mq *MailboxQuery) Limit(limit int) *MailboxQuery { - mq.limit = &limit + mq.ctx.Limit = &limit return mq } -// Offset adds an offset step to the query. +// Offset to start from. func (mq *MailboxQuery) Offset(offset int) *MailboxQuery { - mq.offset = &offset + mq.ctx.Offset = &offset return mq } // Unique configures the query builder to filter duplicate records on query. // By default, unique is set to true, and can be disabled using this method. func (mq *MailboxQuery) Unique(unique bool) *MailboxQuery { - mq.unique = &unique + mq.ctx.Unique = &unique return mq } -// Order adds an order step to the query. +// Order specifies how the records should be ordered. func (mq *MailboxQuery) Order(o ...OrderFunc) *MailboxQuery { mq.order = append(mq.order, o...) return mq @@ -71,7 +69,7 @@ func (mq *MailboxQuery) Order(o ...OrderFunc) *MailboxQuery { // QueryUIDs chains the current query on the "UIDs" edge. func (mq *MailboxQuery) QueryUIDs() *UIDQuery { - query := &UIDQuery{config: mq.config} + query := (&UIDClient{config: mq.config}).Query() query.path = func(ctx context.Context) (fromU *sql.Selector, err error) { if err := mq.prepareQuery(ctx); err != nil { return nil, err @@ -93,7 +91,7 @@ func (mq *MailboxQuery) QueryUIDs() *UIDQuery { // QueryFlags chains the current query on the "flags" edge. func (mq *MailboxQuery) QueryFlags() *MailboxFlagQuery { - query := &MailboxFlagQuery{config: mq.config} + query := (&MailboxFlagClient{config: mq.config}).Query() query.path = func(ctx context.Context) (fromU *sql.Selector, err error) { if err := mq.prepareQuery(ctx); err != nil { return nil, err @@ -115,7 +113,7 @@ func (mq *MailboxQuery) QueryFlags() *MailboxFlagQuery { // QueryPermanentFlags chains the current query on the "permanent_flags" edge. func (mq *MailboxQuery) QueryPermanentFlags() *MailboxPermFlagQuery { - query := &MailboxPermFlagQuery{config: mq.config} + query := (&MailboxPermFlagClient{config: mq.config}).Query() query.path = func(ctx context.Context) (fromU *sql.Selector, err error) { if err := mq.prepareQuery(ctx); err != nil { return nil, err @@ -137,7 +135,7 @@ func (mq *MailboxQuery) QueryPermanentFlags() *MailboxPermFlagQuery { // QueryAttributes chains the current query on the "attributes" edge. func (mq *MailboxQuery) QueryAttributes() *MailboxAttrQuery { - query := &MailboxAttrQuery{config: mq.config} + query := (&MailboxAttrClient{config: mq.config}).Query() query.path = func(ctx context.Context) (fromU *sql.Selector, err error) { if err := mq.prepareQuery(ctx); err != nil { return nil, err @@ -160,7 +158,7 @@ func (mq *MailboxQuery) QueryAttributes() *MailboxAttrQuery { // First returns the first Mailbox entity from the query. // Returns a *NotFoundError when no Mailbox was found. func (mq *MailboxQuery) First(ctx context.Context) (*Mailbox, error) { - nodes, err := mq.Limit(1).All(ctx) + nodes, err := mq.Limit(1).All(setContextOp(ctx, mq.ctx, "First")) if err != nil { return nil, err } @@ -183,7 +181,7 @@ func (mq *MailboxQuery) FirstX(ctx context.Context) *Mailbox { // Returns a *NotFoundError when no Mailbox ID was found. func (mq *MailboxQuery) FirstID(ctx context.Context) (id imap.InternalMailboxID, err error) { var ids []imap.InternalMailboxID - if ids, err = mq.Limit(1).IDs(ctx); err != nil { + if ids, err = mq.Limit(1).IDs(setContextOp(ctx, mq.ctx, "FirstID")); err != nil { return } if len(ids) == 0 { @@ -206,7 +204,7 @@ func (mq *MailboxQuery) FirstIDX(ctx context.Context) imap.InternalMailboxID { // Returns a *NotSingularError when more than one Mailbox entity is found. // Returns a *NotFoundError when no Mailbox entities are found. func (mq *MailboxQuery) Only(ctx context.Context) (*Mailbox, error) { - nodes, err := mq.Limit(2).All(ctx) + nodes, err := mq.Limit(2).All(setContextOp(ctx, mq.ctx, "Only")) if err != nil { return nil, err } @@ -234,7 +232,7 @@ func (mq *MailboxQuery) OnlyX(ctx context.Context) *Mailbox { // Returns a *NotFoundError when no entities are found. func (mq *MailboxQuery) OnlyID(ctx context.Context) (id imap.InternalMailboxID, err error) { var ids []imap.InternalMailboxID - if ids, err = mq.Limit(2).IDs(ctx); err != nil { + if ids, err = mq.Limit(2).IDs(setContextOp(ctx, mq.ctx, "OnlyID")); err != nil { return } switch len(ids) { @@ -259,10 +257,12 @@ func (mq *MailboxQuery) OnlyIDX(ctx context.Context) imap.InternalMailboxID { // All executes the query and returns a list of Mailboxes. func (mq *MailboxQuery) All(ctx context.Context) ([]*Mailbox, error) { + ctx = setContextOp(ctx, mq.ctx, "All") if err := mq.prepareQuery(ctx); err != nil { return nil, err } - return mq.sqlAll(ctx) + qr := querierAll[[]*Mailbox, *MailboxQuery]() + return withInterceptors[[]*Mailbox](ctx, mq, qr, mq.inters) } // AllX is like All, but panics if an error occurs. @@ -275,9 +275,12 @@ func (mq *MailboxQuery) AllX(ctx context.Context) []*Mailbox { } // IDs executes the query and returns a list of Mailbox IDs. -func (mq *MailboxQuery) IDs(ctx context.Context) ([]imap.InternalMailboxID, error) { - var ids []imap.InternalMailboxID - if err := mq.Select(mailbox.FieldID).Scan(ctx, &ids); err != nil { +func (mq *MailboxQuery) IDs(ctx context.Context) (ids []imap.InternalMailboxID, err error) { + if mq.ctx.Unique == nil && mq.path != nil { + mq.Unique(true) + } + ctx = setContextOp(ctx, mq.ctx, "IDs") + if err = mq.Select(mailbox.FieldID).Scan(ctx, &ids); err != nil { return nil, err } return ids, nil @@ -294,10 +297,11 @@ func (mq *MailboxQuery) IDsX(ctx context.Context) []imap.InternalMailboxID { // Count returns the count of the given query. func (mq *MailboxQuery) Count(ctx context.Context) (int, error) { + ctx = setContextOp(ctx, mq.ctx, "Count") if err := mq.prepareQuery(ctx); err != nil { return 0, err } - return mq.sqlCount(ctx) + return withInterceptors[int](ctx, mq, querierCount[*MailboxQuery](), mq.inters) } // CountX is like Count, but panics if an error occurs. @@ -311,10 +315,15 @@ func (mq *MailboxQuery) CountX(ctx context.Context) int { // Exist returns true if the query has elements in the graph. func (mq *MailboxQuery) Exist(ctx context.Context) (bool, error) { - if err := mq.prepareQuery(ctx); err != nil { - return false, err + ctx = setContextOp(ctx, mq.ctx, "Exist") + switch _, err := mq.FirstID(ctx); { + case IsNotFound(err): + return false, nil + case err != nil: + return false, fmt.Errorf("internal: check existence: %w", err) + default: + return true, nil } - return mq.sqlExist(ctx) } // ExistX is like Exist, but panics if an error occurs. @@ -334,25 +343,24 @@ func (mq *MailboxQuery) Clone() *MailboxQuery { } return &MailboxQuery{ config: mq.config, - limit: mq.limit, - offset: mq.offset, + ctx: mq.ctx.Clone(), order: append([]OrderFunc{}, mq.order...), + inters: append([]Interceptor{}, mq.inters...), predicates: append([]predicate.Mailbox{}, mq.predicates...), withUIDs: mq.withUIDs.Clone(), withFlags: mq.withFlags.Clone(), withPermanentFlags: mq.withPermanentFlags.Clone(), withAttributes: mq.withAttributes.Clone(), // clone intermediate query. - sql: mq.sql.Clone(), - path: mq.path, - unique: mq.unique, + sql: mq.sql.Clone(), + path: mq.path, } } // WithUIDs tells the query-builder to eager-load the nodes that are connected to // the "UIDs" edge. The optional arguments are used to configure the query builder of the edge. func (mq *MailboxQuery) WithUIDs(opts ...func(*UIDQuery)) *MailboxQuery { - query := &UIDQuery{config: mq.config} + query := (&UIDClient{config: mq.config}).Query() for _, opt := range opts { opt(query) } @@ -363,7 +371,7 @@ func (mq *MailboxQuery) WithUIDs(opts ...func(*UIDQuery)) *MailboxQuery { // WithFlags tells the query-builder to eager-load the nodes that are connected to // the "flags" edge. The optional arguments are used to configure the query builder of the edge. func (mq *MailboxQuery) WithFlags(opts ...func(*MailboxFlagQuery)) *MailboxQuery { - query := &MailboxFlagQuery{config: mq.config} + query := (&MailboxFlagClient{config: mq.config}).Query() for _, opt := range opts { opt(query) } @@ -374,7 +382,7 @@ func (mq *MailboxQuery) WithFlags(opts ...func(*MailboxFlagQuery)) *MailboxQuery // WithPermanentFlags tells the query-builder to eager-load the nodes that are connected to // the "permanent_flags" edge. The optional arguments are used to configure the query builder of the edge. func (mq *MailboxQuery) WithPermanentFlags(opts ...func(*MailboxPermFlagQuery)) *MailboxQuery { - query := &MailboxPermFlagQuery{config: mq.config} + query := (&MailboxPermFlagClient{config: mq.config}).Query() for _, opt := range opts { opt(query) } @@ -385,7 +393,7 @@ func (mq *MailboxQuery) WithPermanentFlags(opts ...func(*MailboxPermFlagQuery)) // WithAttributes tells the query-builder to eager-load the nodes that are connected to // the "attributes" edge. The optional arguments are used to configure the query builder of the edge. func (mq *MailboxQuery) WithAttributes(opts ...func(*MailboxAttrQuery)) *MailboxQuery { - query := &MailboxAttrQuery{config: mq.config} + query := (&MailboxAttrClient{config: mq.config}).Query() for _, opt := range opts { opt(query) } @@ -405,19 +413,14 @@ func (mq *MailboxQuery) WithAttributes(opts ...func(*MailboxAttrQuery)) *Mailbox // // client.Mailbox.Query(). // GroupBy(mailbox.FieldRemoteID). -// Aggregate(ent.Count()). +// Aggregate(internal.Count()). // Scan(ctx, &v) func (mq *MailboxQuery) GroupBy(field string, fields ...string) *MailboxGroupBy { - grbuild := &MailboxGroupBy{config: mq.config} - grbuild.fields = append([]string{field}, fields...) - grbuild.path = func(ctx context.Context) (prev *sql.Selector, err error) { - if err := mq.prepareQuery(ctx); err != nil { - return nil, err - } - return mq.sqlQuery(ctx), nil - } + mq.ctx.Fields = append([]string{field}, fields...) + grbuild := &MailboxGroupBy{build: mq} + grbuild.flds = &mq.ctx.Fields grbuild.label = mailbox.Label - grbuild.flds, grbuild.scan = &grbuild.fields, grbuild.Scan + grbuild.scan = grbuild.Scan return grbuild } @@ -434,17 +437,32 @@ func (mq *MailboxQuery) GroupBy(field string, fields ...string) *MailboxGroupBy // Select(mailbox.FieldRemoteID). // Scan(ctx, &v) func (mq *MailboxQuery) Select(fields ...string) *MailboxSelect { - mq.fields = append(mq.fields, fields...) - selbuild := &MailboxSelect{MailboxQuery: mq} - selbuild.label = mailbox.Label - selbuild.flds, selbuild.scan = &mq.fields, selbuild.Scan - return selbuild + mq.ctx.Fields = append(mq.ctx.Fields, fields...) + sbuild := &MailboxSelect{MailboxQuery: mq} + sbuild.label = mailbox.Label + sbuild.flds, sbuild.scan = &mq.ctx.Fields, sbuild.Scan + return sbuild +} + +// Aggregate returns a MailboxSelect configured with the given aggregations. +func (mq *MailboxQuery) Aggregate(fns ...AggregateFunc) *MailboxSelect { + return mq.Select().Aggregate(fns...) } func (mq *MailboxQuery) prepareQuery(ctx context.Context) error { - for _, f := range mq.fields { + for _, inter := range mq.inters { + if inter == nil { + return fmt.Errorf("internal: uninitialized interceptor (forgotten import internal/runtime?)") + } + if trv, ok := inter.(Traverser); ok { + if err := trv.Traverse(ctx, mq); err != nil { + return err + } + } + } + for _, f := range mq.ctx.Fields { if !mailbox.ValidColumn(f) { - return &ValidationError{Name: f, err: fmt.Errorf("ent: invalid field %q for query", f)} + return &ValidationError{Name: f, err: fmt.Errorf("internal: invalid field %q for query", f)} } } if mq.path != nil { @@ -468,10 +486,10 @@ func (mq *MailboxQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*Mail mq.withAttributes != nil, } ) - _spec.ScanValues = func(columns []string) ([]interface{}, error) { + _spec.ScanValues = func(columns []string) ([]any, error) { return (*Mailbox).scanValues(nil, columns) } - _spec.Assign = func(columns []string, values []interface{}) error { + _spec.Assign = func(columns []string, values []any) error { node := &Mailbox{config: mq.config} nodes = append(nodes, node) node.Edges.loadedTypes = loadedTypes @@ -644,38 +662,22 @@ func (mq *MailboxQuery) loadAttributes(ctx context.Context, query *MailboxAttrQu func (mq *MailboxQuery) sqlCount(ctx context.Context) (int, error) { _spec := mq.querySpec() - _spec.Node.Columns = mq.fields - if len(mq.fields) > 0 { - _spec.Unique = mq.unique != nil && *mq.unique + _spec.Node.Columns = mq.ctx.Fields + if len(mq.ctx.Fields) > 0 { + _spec.Unique = mq.ctx.Unique != nil && *mq.ctx.Unique } return sqlgraph.CountNodes(ctx, mq.driver, _spec) } -func (mq *MailboxQuery) sqlExist(ctx context.Context) (bool, error) { - n, err := mq.sqlCount(ctx) - if err != nil { - return false, fmt.Errorf("ent: check existence: %w", err) - } - return n > 0, nil -} - func (mq *MailboxQuery) querySpec() *sqlgraph.QuerySpec { - _spec := &sqlgraph.QuerySpec{ - Node: &sqlgraph.NodeSpec{ - Table: mailbox.Table, - Columns: mailbox.Columns, - ID: &sqlgraph.FieldSpec{ - Type: field.TypeUint64, - Column: mailbox.FieldID, - }, - }, - From: mq.sql, - Unique: true, - } - if unique := mq.unique; unique != nil { + _spec := sqlgraph.NewQuerySpec(mailbox.Table, mailbox.Columns, sqlgraph.NewFieldSpec(mailbox.FieldID, field.TypeUint64)) + _spec.From = mq.sql + if unique := mq.ctx.Unique; unique != nil { _spec.Unique = *unique + } else if mq.path != nil { + _spec.Unique = true } - if fields := mq.fields; len(fields) > 0 { + if fields := mq.ctx.Fields; len(fields) > 0 { _spec.Node.Columns = make([]string, 0, len(fields)) _spec.Node.Columns = append(_spec.Node.Columns, mailbox.FieldID) for i := range fields { @@ -691,10 +693,10 @@ func (mq *MailboxQuery) querySpec() *sqlgraph.QuerySpec { } } } - if limit := mq.limit; limit != nil { + if limit := mq.ctx.Limit; limit != nil { _spec.Limit = *limit } - if offset := mq.offset; offset != nil { + if offset := mq.ctx.Offset; offset != nil { _spec.Offset = *offset } if ps := mq.order; len(ps) > 0 { @@ -710,7 +712,7 @@ func (mq *MailboxQuery) querySpec() *sqlgraph.QuerySpec { func (mq *MailboxQuery) sqlQuery(ctx context.Context) *sql.Selector { builder := sql.Dialect(mq.driver.Dialect()) t1 := builder.Table(mailbox.Table) - columns := mq.fields + columns := mq.ctx.Fields if len(columns) == 0 { columns = mailbox.Columns } @@ -719,7 +721,7 @@ func (mq *MailboxQuery) sqlQuery(ctx context.Context) *sql.Selector { selector = mq.sql selector.Select(selector.Columns(columns...)...) } - if mq.unique != nil && *mq.unique { + if mq.ctx.Unique != nil && *mq.ctx.Unique { selector.Distinct() } for _, p := range mq.predicates { @@ -728,12 +730,12 @@ func (mq *MailboxQuery) sqlQuery(ctx context.Context) *sql.Selector { for _, p := range mq.order { p(selector) } - if offset := mq.offset; offset != nil { + if offset := mq.ctx.Offset; offset != nil { // limit is mandatory for offset clause. We start // with default value, and override it below if needed. selector.Offset(*offset).Limit(math.MaxInt32) } - if limit := mq.limit; limit != nil { + if limit := mq.ctx.Limit; limit != nil { selector.Limit(*limit) } return selector @@ -741,13 +743,8 @@ func (mq *MailboxQuery) sqlQuery(ctx context.Context) *sql.Selector { // MailboxGroupBy is the group-by builder for Mailbox entities. type MailboxGroupBy struct { - config selector - fields []string - fns []AggregateFunc - // intermediate query (i.e. traversal path). - sql *sql.Selector - path func(context.Context) (*sql.Selector, error) + build *MailboxQuery } // Aggregate adds the given aggregation functions to the group-by query. @@ -756,74 +753,77 @@ func (mgb *MailboxGroupBy) Aggregate(fns ...AggregateFunc) *MailboxGroupBy { return mgb } -// Scan applies the group-by query and scans the result into the given value. -func (mgb *MailboxGroupBy) Scan(ctx context.Context, v interface{}) error { - query, err := mgb.path(ctx) - if err != nil { +// Scan applies the selector query and scans the result into the given value. +func (mgb *MailboxGroupBy) Scan(ctx context.Context, v any) error { + ctx = setContextOp(ctx, mgb.build.ctx, "GroupBy") + if err := mgb.build.prepareQuery(ctx); err != nil { return err } - mgb.sql = query - return mgb.sqlScan(ctx, v) + return scanWithInterceptors[*MailboxQuery, *MailboxGroupBy](ctx, mgb.build, mgb, mgb.build.inters, v) } -func (mgb *MailboxGroupBy) sqlScan(ctx context.Context, v interface{}) error { - for _, f := range mgb.fields { - if !mailbox.ValidColumn(f) { - return &ValidationError{Name: f, err: fmt.Errorf("invalid field %q for group-by", f)} +func (mgb *MailboxGroupBy) sqlScan(ctx context.Context, root *MailboxQuery, v any) error { + selector := root.sqlQuery(ctx).Select() + aggregation := make([]string, 0, len(mgb.fns)) + for _, fn := range mgb.fns { + aggregation = append(aggregation, fn(selector)) + } + if len(selector.SelectedColumns()) == 0 { + columns := make([]string, 0, len(*mgb.flds)+len(mgb.fns)) + for _, f := range *mgb.flds { + columns = append(columns, selector.C(f)) } + columns = append(columns, aggregation...) + selector.Select(columns...) } - selector := mgb.sqlQuery() + selector.GroupBy(selector.Columns(*mgb.flds...)...) if err := selector.Err(); err != nil { return err } rows := &sql.Rows{} query, args := selector.Query() - if err := mgb.driver.Query(ctx, query, args, rows); err != nil { + if err := mgb.build.driver.Query(ctx, query, args, rows); err != nil { return err } defer rows.Close() return sql.ScanSlice(rows, v) } -func (mgb *MailboxGroupBy) sqlQuery() *sql.Selector { - selector := mgb.sql.Select() - aggregation := make([]string, 0, len(mgb.fns)) - for _, fn := range mgb.fns { - aggregation = append(aggregation, fn(selector)) - } - // If no columns were selected in a custom aggregation function, the default - // selection is the fields used for "group-by", and the aggregation functions. - if len(selector.SelectedColumns()) == 0 { - columns := make([]string, 0, len(mgb.fields)+len(mgb.fns)) - for _, f := range mgb.fields { - columns = append(columns, selector.C(f)) - } - columns = append(columns, aggregation...) - selector.Select(columns...) - } - return selector.GroupBy(selector.Columns(mgb.fields...)...) -} - // MailboxSelect is the builder for selecting fields of Mailbox entities. type MailboxSelect struct { *MailboxQuery selector - // intermediate query (i.e. traversal path). - sql *sql.Selector +} + +// Aggregate adds the given aggregation functions to the selector query. +func (ms *MailboxSelect) Aggregate(fns ...AggregateFunc) *MailboxSelect { + ms.fns = append(ms.fns, fns...) + return ms } // Scan applies the selector query and scans the result into the given value. -func (ms *MailboxSelect) Scan(ctx context.Context, v interface{}) error { +func (ms *MailboxSelect) Scan(ctx context.Context, v any) error { + ctx = setContextOp(ctx, ms.ctx, "Select") if err := ms.prepareQuery(ctx); err != nil { return err } - ms.sql = ms.MailboxQuery.sqlQuery(ctx) - return ms.sqlScan(ctx, v) + return scanWithInterceptors[*MailboxQuery, *MailboxSelect](ctx, ms.MailboxQuery, ms, ms.inters, v) } -func (ms *MailboxSelect) sqlScan(ctx context.Context, v interface{}) error { +func (ms *MailboxSelect) sqlScan(ctx context.Context, root *MailboxQuery, v any) error { + selector := root.sqlQuery(ctx) + aggregation := make([]string, 0, len(ms.fns)) + for _, fn := range ms.fns { + aggregation = append(aggregation, fn(selector)) + } + switch n := len(*ms.selector.flds); { + case n == 0 && len(aggregation) > 0: + selector.Select(aggregation...) + case n != 0 && len(aggregation) > 0: + selector.AppendSelect(aggregation...) + } rows := &sql.Rows{} - query, args := ms.sql.Query() + query, args := selector.Query() if err := ms.driver.Query(ctx, query, args, rows); err != nil { return err } diff --git a/internal/db/ent/mailbox_update.go b/internal/db_impl/ent_db/internal/mailbox_update.go similarity index 85% rename from internal/db/ent/mailbox_update.go rename to internal/db_impl/ent_db/internal/mailbox_update.go index 46ceca2a..150a494c 100644 --- a/internal/db/ent/mailbox_update.go +++ b/internal/db_impl/ent_db/internal/mailbox_update.go @@ -1,6 +1,6 @@ // Code generated by ent, DO NOT EDIT. -package ent +package internal import ( "context" @@ -11,12 +11,12 @@ import ( "entgo.io/ent/dialect/sql/sqlgraph" "entgo.io/ent/schema/field" "github.com/ProtonMail/gluon/imap" - "github.com/ProtonMail/gluon/internal/db/ent/mailbox" - "github.com/ProtonMail/gluon/internal/db/ent/mailboxattr" - "github.com/ProtonMail/gluon/internal/db/ent/mailboxflag" - "github.com/ProtonMail/gluon/internal/db/ent/mailboxpermflag" - "github.com/ProtonMail/gluon/internal/db/ent/predicate" - "github.com/ProtonMail/gluon/internal/db/ent/uid" + "github.com/ProtonMail/gluon/internal/db_impl/ent_db/internal/mailbox" + "github.com/ProtonMail/gluon/internal/db_impl/ent_db/internal/mailboxattr" + "github.com/ProtonMail/gluon/internal/db_impl/ent_db/internal/mailboxflag" + "github.com/ProtonMail/gluon/internal/db_impl/ent_db/internal/mailboxpermflag" + "github.com/ProtonMail/gluon/internal/db_impl/ent_db/internal/predicate" + "github.com/ProtonMail/gluon/internal/db_impl/ent_db/internal/uid" ) // MailboxUpdate is the builder for updating Mailbox entities. @@ -265,34 +265,7 @@ func (mu *MailboxUpdate) RemoveAttributes(m ...*MailboxAttr) *MailboxUpdate { // Save executes the query and returns the number of nodes affected by the update operation. func (mu *MailboxUpdate) Save(ctx context.Context) (int, error) { - var ( - err error - affected int - ) - if len(mu.hooks) == 0 { - affected, err = mu.sqlSave(ctx) - } else { - var mut Mutator = MutateFunc(func(ctx context.Context, m Mutation) (Value, error) { - mutation, ok := m.(*MailboxMutation) - if !ok { - return nil, fmt.Errorf("unexpected mutation type %T", m) - } - mu.mutation = mutation - affected, err = mu.sqlSave(ctx) - mutation.done = true - return affected, err - }) - for i := len(mu.hooks) - 1; i >= 0; i-- { - if mu.hooks[i] == nil { - return 0, fmt.Errorf("ent: uninitialized hook (forgotten import ent/runtime?)") - } - mut = mu.hooks[i](mut) - } - if _, err := mut.Mutate(ctx, mu.mutation); err != nil { - return 0, err - } - } - return affected, err + return withHooks[int, MailboxMutation](ctx, mu.sqlSave, mu.mutation, mu.hooks) } // SaveX is like Save, but panics if an error occurs. @@ -318,16 +291,7 @@ func (mu *MailboxUpdate) ExecX(ctx context.Context) { } func (mu *MailboxUpdate) sqlSave(ctx context.Context) (n int, err error) { - _spec := &sqlgraph.UpdateSpec{ - Node: &sqlgraph.NodeSpec{ - Table: mailbox.Table, - Columns: mailbox.Columns, - ID: &sqlgraph.FieldSpec{ - Type: field.TypeUint64, - Column: mailbox.FieldID, - }, - }, - } + _spec := sqlgraph.NewUpdateSpec(mailbox.Table, mailbox.Columns, sqlgraph.NewFieldSpec(mailbox.FieldID, field.TypeUint64)) if ps := mu.mutation.predicates; len(ps) > 0 { _spec.Predicate = func(selector *sql.Selector) { for i := range ps { @@ -336,59 +300,28 @@ func (mu *MailboxUpdate) sqlSave(ctx context.Context) (n int, err error) { } } if value, ok := mu.mutation.RemoteID(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Value: value, - Column: mailbox.FieldRemoteID, - }) + _spec.SetField(mailbox.FieldRemoteID, field.TypeString, value) } if mu.mutation.RemoteIDCleared() { - _spec.Fields.Clear = append(_spec.Fields.Clear, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Column: mailbox.FieldRemoteID, - }) + _spec.ClearField(mailbox.FieldRemoteID, field.TypeString) } if value, ok := mu.mutation.Name(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Value: value, - Column: mailbox.FieldName, - }) + _spec.SetField(mailbox.FieldName, field.TypeString, value) } if value, ok := mu.mutation.UIDNext(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeUint32, - Value: value, - Column: mailbox.FieldUIDNext, - }) + _spec.SetField(mailbox.FieldUIDNext, field.TypeUint32, value) } if value, ok := mu.mutation.AddedUIDNext(); ok { - _spec.Fields.Add = append(_spec.Fields.Add, &sqlgraph.FieldSpec{ - Type: field.TypeUint32, - Value: value, - Column: mailbox.FieldUIDNext, - }) + _spec.AddField(mailbox.FieldUIDNext, field.TypeUint32, value) } if value, ok := mu.mutation.UIDValidity(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeUint32, - Value: value, - Column: mailbox.FieldUIDValidity, - }) + _spec.SetField(mailbox.FieldUIDValidity, field.TypeUint32, value) } if value, ok := mu.mutation.AddedUIDValidity(); ok { - _spec.Fields.Add = append(_spec.Fields.Add, &sqlgraph.FieldSpec{ - Type: field.TypeUint32, - Value: value, - Column: mailbox.FieldUIDValidity, - }) + _spec.AddField(mailbox.FieldUIDValidity, field.TypeUint32, value) } if value, ok := mu.mutation.Subscribed(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeBool, - Value: value, - Column: mailbox.FieldSubscribed, - }) + _spec.SetField(mailbox.FieldSubscribed, field.TypeBool, value) } if mu.mutation.UIDsCleared() { edge := &sqlgraph.EdgeSpec{ @@ -614,6 +547,7 @@ func (mu *MailboxUpdate) sqlSave(ctx context.Context) (n int, err error) { } return 0, err } + mu.mutation.done = true return n, nil } @@ -856,6 +790,12 @@ func (muo *MailboxUpdateOne) RemoveAttributes(m ...*MailboxAttr) *MailboxUpdateO return muo.RemoveAttributeIDs(ids...) } +// Where appends a list predicates to the MailboxUpdate builder. +func (muo *MailboxUpdateOne) Where(ps ...predicate.Mailbox) *MailboxUpdateOne { + muo.mutation.Where(ps...) + return muo +} + // Select allows selecting one or more fields (columns) of the returned entity. // The default is selecting all fields defined in the entity schema. func (muo *MailboxUpdateOne) Select(field string, fields ...string) *MailboxUpdateOne { @@ -865,40 +805,7 @@ func (muo *MailboxUpdateOne) Select(field string, fields ...string) *MailboxUpda // Save executes the query and returns the updated Mailbox entity. func (muo *MailboxUpdateOne) Save(ctx context.Context) (*Mailbox, error) { - var ( - err error - node *Mailbox - ) - if len(muo.hooks) == 0 { - node, err = muo.sqlSave(ctx) - } else { - var mut Mutator = MutateFunc(func(ctx context.Context, m Mutation) (Value, error) { - mutation, ok := m.(*MailboxMutation) - if !ok { - return nil, fmt.Errorf("unexpected mutation type %T", m) - } - muo.mutation = mutation - node, err = muo.sqlSave(ctx) - mutation.done = true - return node, err - }) - for i := len(muo.hooks) - 1; i >= 0; i-- { - if muo.hooks[i] == nil { - return nil, fmt.Errorf("ent: uninitialized hook (forgotten import ent/runtime?)") - } - mut = muo.hooks[i](mut) - } - v, err := mut.Mutate(ctx, muo.mutation) - if err != nil { - return nil, err - } - nv, ok := v.(*Mailbox) - if !ok { - return nil, fmt.Errorf("unexpected node type %T returned from MailboxMutation", v) - } - node = nv - } - return node, err + return withHooks[*Mailbox, MailboxMutation](ctx, muo.sqlSave, muo.mutation, muo.hooks) } // SaveX is like Save, but panics if an error occurs. @@ -924,19 +831,10 @@ func (muo *MailboxUpdateOne) ExecX(ctx context.Context) { } func (muo *MailboxUpdateOne) sqlSave(ctx context.Context) (_node *Mailbox, err error) { - _spec := &sqlgraph.UpdateSpec{ - Node: &sqlgraph.NodeSpec{ - Table: mailbox.Table, - Columns: mailbox.Columns, - ID: &sqlgraph.FieldSpec{ - Type: field.TypeUint64, - Column: mailbox.FieldID, - }, - }, - } + _spec := sqlgraph.NewUpdateSpec(mailbox.Table, mailbox.Columns, sqlgraph.NewFieldSpec(mailbox.FieldID, field.TypeUint64)) id, ok := muo.mutation.ID() if !ok { - return nil, &ValidationError{Name: "id", err: errors.New(`ent: missing "Mailbox.id" for update`)} + return nil, &ValidationError{Name: "id", err: errors.New(`internal: missing "Mailbox.id" for update`)} } _spec.Node.ID.Value = id if fields := muo.fields; len(fields) > 0 { @@ -944,7 +842,7 @@ func (muo *MailboxUpdateOne) sqlSave(ctx context.Context) (_node *Mailbox, err e _spec.Node.Columns = append(_spec.Node.Columns, mailbox.FieldID) for _, f := range fields { if !mailbox.ValidColumn(f) { - return nil, &ValidationError{Name: f, err: fmt.Errorf("ent: invalid field %q for query", f)} + return nil, &ValidationError{Name: f, err: fmt.Errorf("internal: invalid field %q for query", f)} } if f != mailbox.FieldID { _spec.Node.Columns = append(_spec.Node.Columns, f) @@ -959,59 +857,28 @@ func (muo *MailboxUpdateOne) sqlSave(ctx context.Context) (_node *Mailbox, err e } } if value, ok := muo.mutation.RemoteID(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Value: value, - Column: mailbox.FieldRemoteID, - }) + _spec.SetField(mailbox.FieldRemoteID, field.TypeString, value) } if muo.mutation.RemoteIDCleared() { - _spec.Fields.Clear = append(_spec.Fields.Clear, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Column: mailbox.FieldRemoteID, - }) + _spec.ClearField(mailbox.FieldRemoteID, field.TypeString) } if value, ok := muo.mutation.Name(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Value: value, - Column: mailbox.FieldName, - }) + _spec.SetField(mailbox.FieldName, field.TypeString, value) } if value, ok := muo.mutation.UIDNext(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeUint32, - Value: value, - Column: mailbox.FieldUIDNext, - }) + _spec.SetField(mailbox.FieldUIDNext, field.TypeUint32, value) } if value, ok := muo.mutation.AddedUIDNext(); ok { - _spec.Fields.Add = append(_spec.Fields.Add, &sqlgraph.FieldSpec{ - Type: field.TypeUint32, - Value: value, - Column: mailbox.FieldUIDNext, - }) + _spec.AddField(mailbox.FieldUIDNext, field.TypeUint32, value) } if value, ok := muo.mutation.UIDValidity(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeUint32, - Value: value, - Column: mailbox.FieldUIDValidity, - }) + _spec.SetField(mailbox.FieldUIDValidity, field.TypeUint32, value) } if value, ok := muo.mutation.AddedUIDValidity(); ok { - _spec.Fields.Add = append(_spec.Fields.Add, &sqlgraph.FieldSpec{ - Type: field.TypeUint32, - Value: value, - Column: mailbox.FieldUIDValidity, - }) + _spec.AddField(mailbox.FieldUIDValidity, field.TypeUint32, value) } if value, ok := muo.mutation.Subscribed(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeBool, - Value: value, - Column: mailbox.FieldSubscribed, - }) + _spec.SetField(mailbox.FieldSubscribed, field.TypeBool, value) } if muo.mutation.UIDsCleared() { edge := &sqlgraph.EdgeSpec{ @@ -1240,5 +1107,6 @@ func (muo *MailboxUpdateOne) sqlSave(ctx context.Context) (_node *Mailbox, err e } return nil, err } + muo.mutation.done = true return _node, nil } diff --git a/internal/db/ent/mailboxattr.go b/internal/db_impl/ent_db/internal/mailboxattr.go similarity index 85% rename from internal/db/ent/mailboxattr.go rename to internal/db_impl/ent_db/internal/mailboxattr.go index fa548244..aef2a814 100644 --- a/internal/db/ent/mailboxattr.go +++ b/internal/db_impl/ent_db/internal/mailboxattr.go @@ -1,6 +1,6 @@ // Code generated by ent, DO NOT EDIT. -package ent +package internal import ( "fmt" @@ -8,7 +8,7 @@ import ( "entgo.io/ent/dialect/sql" "github.com/ProtonMail/gluon/imap" - "github.com/ProtonMail/gluon/internal/db/ent/mailboxattr" + "github.com/ProtonMail/gluon/internal/db_impl/ent_db/internal/mailboxattr" ) // MailboxAttr is the model entity for the MailboxAttr schema. @@ -22,8 +22,8 @@ type MailboxAttr struct { } // scanValues returns the types for scanning values from sql.Rows. -func (*MailboxAttr) scanValues(columns []string) ([]interface{}, error) { - values := make([]interface{}, len(columns)) +func (*MailboxAttr) scanValues(columns []string) ([]any, error) { + values := make([]any, len(columns)) for i := range columns { switch columns[i] { case mailboxattr.FieldID: @@ -41,7 +41,7 @@ func (*MailboxAttr) scanValues(columns []string) ([]interface{}, error) { // assignValues assigns the values that were returned from sql.Rows (after scanning) // to the MailboxAttr fields. -func (ma *MailboxAttr) assignValues(columns []string, values []interface{}) error { +func (ma *MailboxAttr) assignValues(columns []string, values []any) error { if m, n := len(values), len(columns); m < n { return fmt.Errorf("mismatch number of scan values: %d != %d", m, n) } @@ -75,7 +75,7 @@ func (ma *MailboxAttr) assignValues(columns []string, values []interface{}) erro // Note that you need to call MailboxAttr.Unwrap() before calling this method if this MailboxAttr // was returned from a transaction, and the transaction was committed or rolled back. func (ma *MailboxAttr) Update() *MailboxAttrUpdateOne { - return (&MailboxAttrClient{config: ma.config}).UpdateOne(ma) + return NewMailboxAttrClient(ma.config).UpdateOne(ma) } // Unwrap unwraps the MailboxAttr entity that was returned from a transaction after it was closed, @@ -83,7 +83,7 @@ func (ma *MailboxAttr) Update() *MailboxAttrUpdateOne { func (ma *MailboxAttr) Unwrap() *MailboxAttr { _tx, ok := ma.config.driver.(*txDriver) if !ok { - panic("ent: MailboxAttr is not a transactional entity") + panic("internal: MailboxAttr is not a transactional entity") } ma.config.driver = _tx.drv return ma @@ -102,9 +102,3 @@ func (ma *MailboxAttr) String() string { // MailboxAttrs is a parsable slice of MailboxAttr. type MailboxAttrs []*MailboxAttr - -func (ma MailboxAttrs) config(cfg config) { - for _i := range ma { - ma[_i].config = cfg - } -} diff --git a/internal/db/ent/mailboxattr/mailboxattr.go b/internal/db_impl/ent_db/internal/mailboxattr/mailboxattr.go similarity index 100% rename from internal/db/ent/mailboxattr/mailboxattr.go rename to internal/db_impl/ent_db/internal/mailboxattr/mailboxattr.go diff --git a/internal/db/ent/mailboxattr/where.go b/internal/db_impl/ent_db/internal/mailboxattr/where.go similarity index 56% rename from internal/db/ent/mailboxattr/where.go rename to internal/db_impl/ent_db/internal/mailboxattr/where.go index 84a16fe4..6ea922dd 100644 --- a/internal/db/ent/mailboxattr/where.go +++ b/internal/db_impl/ent_db/internal/mailboxattr/where.go @@ -4,184 +4,122 @@ package mailboxattr import ( "entgo.io/ent/dialect/sql" - "github.com/ProtonMail/gluon/internal/db/ent/predicate" + "github.com/ProtonMail/gluon/internal/db_impl/ent_db/internal/predicate" ) // ID filters vertices based on their ID field. func ID(id int) predicate.MailboxAttr { - return predicate.MailboxAttr(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldID), id)) - }) + return predicate.MailboxAttr(sql.FieldEQ(FieldID, id)) } // IDEQ applies the EQ predicate on the ID field. func IDEQ(id int) predicate.MailboxAttr { - return predicate.MailboxAttr(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldID), id)) - }) + return predicate.MailboxAttr(sql.FieldEQ(FieldID, id)) } // IDNEQ applies the NEQ predicate on the ID field. func IDNEQ(id int) predicate.MailboxAttr { - return predicate.MailboxAttr(func(s *sql.Selector) { - s.Where(sql.NEQ(s.C(FieldID), id)) - }) + return predicate.MailboxAttr(sql.FieldNEQ(FieldID, id)) } // IDIn applies the In predicate on the ID field. func IDIn(ids ...int) predicate.MailboxAttr { - return predicate.MailboxAttr(func(s *sql.Selector) { - v := make([]interface{}, len(ids)) - for i := range v { - v[i] = ids[i] - } - s.Where(sql.In(s.C(FieldID), v...)) - }) + return predicate.MailboxAttr(sql.FieldIn(FieldID, ids...)) } // IDNotIn applies the NotIn predicate on the ID field. func IDNotIn(ids ...int) predicate.MailboxAttr { - return predicate.MailboxAttr(func(s *sql.Selector) { - v := make([]interface{}, len(ids)) - for i := range v { - v[i] = ids[i] - } - s.Where(sql.NotIn(s.C(FieldID), v...)) - }) + return predicate.MailboxAttr(sql.FieldNotIn(FieldID, ids...)) } // IDGT applies the GT predicate on the ID field. func IDGT(id int) predicate.MailboxAttr { - return predicate.MailboxAttr(func(s *sql.Selector) { - s.Where(sql.GT(s.C(FieldID), id)) - }) + return predicate.MailboxAttr(sql.FieldGT(FieldID, id)) } // IDGTE applies the GTE predicate on the ID field. func IDGTE(id int) predicate.MailboxAttr { - return predicate.MailboxAttr(func(s *sql.Selector) { - s.Where(sql.GTE(s.C(FieldID), id)) - }) + return predicate.MailboxAttr(sql.FieldGTE(FieldID, id)) } // IDLT applies the LT predicate on the ID field. func IDLT(id int) predicate.MailboxAttr { - return predicate.MailboxAttr(func(s *sql.Selector) { - s.Where(sql.LT(s.C(FieldID), id)) - }) + return predicate.MailboxAttr(sql.FieldLT(FieldID, id)) } // IDLTE applies the LTE predicate on the ID field. func IDLTE(id int) predicate.MailboxAttr { - return predicate.MailboxAttr(func(s *sql.Selector) { - s.Where(sql.LTE(s.C(FieldID), id)) - }) + return predicate.MailboxAttr(sql.FieldLTE(FieldID, id)) } // Value applies equality check predicate on the "Value" field. It's identical to ValueEQ. func Value(v string) predicate.MailboxAttr { - return predicate.MailboxAttr(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldValue), v)) - }) + return predicate.MailboxAttr(sql.FieldEQ(FieldValue, v)) } // ValueEQ applies the EQ predicate on the "Value" field. func ValueEQ(v string) predicate.MailboxAttr { - return predicate.MailboxAttr(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldValue), v)) - }) + return predicate.MailboxAttr(sql.FieldEQ(FieldValue, v)) } // ValueNEQ applies the NEQ predicate on the "Value" field. func ValueNEQ(v string) predicate.MailboxAttr { - return predicate.MailboxAttr(func(s *sql.Selector) { - s.Where(sql.NEQ(s.C(FieldValue), v)) - }) + return predicate.MailboxAttr(sql.FieldNEQ(FieldValue, v)) } // ValueIn applies the In predicate on the "Value" field. func ValueIn(vs ...string) predicate.MailboxAttr { - v := make([]interface{}, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.MailboxAttr(func(s *sql.Selector) { - s.Where(sql.In(s.C(FieldValue), v...)) - }) + return predicate.MailboxAttr(sql.FieldIn(FieldValue, vs...)) } // ValueNotIn applies the NotIn predicate on the "Value" field. func ValueNotIn(vs ...string) predicate.MailboxAttr { - v := make([]interface{}, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.MailboxAttr(func(s *sql.Selector) { - s.Where(sql.NotIn(s.C(FieldValue), v...)) - }) + return predicate.MailboxAttr(sql.FieldNotIn(FieldValue, vs...)) } // ValueGT applies the GT predicate on the "Value" field. func ValueGT(v string) predicate.MailboxAttr { - return predicate.MailboxAttr(func(s *sql.Selector) { - s.Where(sql.GT(s.C(FieldValue), v)) - }) + return predicate.MailboxAttr(sql.FieldGT(FieldValue, v)) } // ValueGTE applies the GTE predicate on the "Value" field. func ValueGTE(v string) predicate.MailboxAttr { - return predicate.MailboxAttr(func(s *sql.Selector) { - s.Where(sql.GTE(s.C(FieldValue), v)) - }) + return predicate.MailboxAttr(sql.FieldGTE(FieldValue, v)) } // ValueLT applies the LT predicate on the "Value" field. func ValueLT(v string) predicate.MailboxAttr { - return predicate.MailboxAttr(func(s *sql.Selector) { - s.Where(sql.LT(s.C(FieldValue), v)) - }) + return predicate.MailboxAttr(sql.FieldLT(FieldValue, v)) } // ValueLTE applies the LTE predicate on the "Value" field. func ValueLTE(v string) predicate.MailboxAttr { - return predicate.MailboxAttr(func(s *sql.Selector) { - s.Where(sql.LTE(s.C(FieldValue), v)) - }) + return predicate.MailboxAttr(sql.FieldLTE(FieldValue, v)) } // ValueContains applies the Contains predicate on the "Value" field. func ValueContains(v string) predicate.MailboxAttr { - return predicate.MailboxAttr(func(s *sql.Selector) { - s.Where(sql.Contains(s.C(FieldValue), v)) - }) + return predicate.MailboxAttr(sql.FieldContains(FieldValue, v)) } // ValueHasPrefix applies the HasPrefix predicate on the "Value" field. func ValueHasPrefix(v string) predicate.MailboxAttr { - return predicate.MailboxAttr(func(s *sql.Selector) { - s.Where(sql.HasPrefix(s.C(FieldValue), v)) - }) + return predicate.MailboxAttr(sql.FieldHasPrefix(FieldValue, v)) } // ValueHasSuffix applies the HasSuffix predicate on the "Value" field. func ValueHasSuffix(v string) predicate.MailboxAttr { - return predicate.MailboxAttr(func(s *sql.Selector) { - s.Where(sql.HasSuffix(s.C(FieldValue), v)) - }) + return predicate.MailboxAttr(sql.FieldHasSuffix(FieldValue, v)) } // ValueEqualFold applies the EqualFold predicate on the "Value" field. func ValueEqualFold(v string) predicate.MailboxAttr { - return predicate.MailboxAttr(func(s *sql.Selector) { - s.Where(sql.EqualFold(s.C(FieldValue), v)) - }) + return predicate.MailboxAttr(sql.FieldEqualFold(FieldValue, v)) } // ValueContainsFold applies the ContainsFold predicate on the "Value" field. func ValueContainsFold(v string) predicate.MailboxAttr { - return predicate.MailboxAttr(func(s *sql.Selector) { - s.Where(sql.ContainsFold(s.C(FieldValue), v)) - }) + return predicate.MailboxAttr(sql.FieldContainsFold(FieldValue, v)) } // And groups predicates with the AND operator between them. diff --git a/internal/db/ent/mailboxattr_create.go b/internal/db_impl/ent_db/internal/mailboxattr_create.go similarity index 74% rename from internal/db/ent/mailboxattr_create.go rename to internal/db_impl/ent_db/internal/mailboxattr_create.go index 4908d550..0cb6b4f0 100644 --- a/internal/db/ent/mailboxattr_create.go +++ b/internal/db_impl/ent_db/internal/mailboxattr_create.go @@ -1,6 +1,6 @@ // Code generated by ent, DO NOT EDIT. -package ent +package internal import ( "context" @@ -9,7 +9,7 @@ import ( "entgo.io/ent/dialect/sql/sqlgraph" "entgo.io/ent/schema/field" - "github.com/ProtonMail/gluon/internal/db/ent/mailboxattr" + "github.com/ProtonMail/gluon/internal/db_impl/ent_db/internal/mailboxattr" ) // MailboxAttrCreate is the builder for creating a MailboxAttr entity. @@ -32,49 +32,7 @@ func (mac *MailboxAttrCreate) Mutation() *MailboxAttrMutation { // Save creates the MailboxAttr in the database. func (mac *MailboxAttrCreate) Save(ctx context.Context) (*MailboxAttr, error) { - var ( - err error - node *MailboxAttr - ) - if len(mac.hooks) == 0 { - if err = mac.check(); err != nil { - return nil, err - } - node, err = mac.sqlSave(ctx) - } else { - var mut Mutator = MutateFunc(func(ctx context.Context, m Mutation) (Value, error) { - mutation, ok := m.(*MailboxAttrMutation) - if !ok { - return nil, fmt.Errorf("unexpected mutation type %T", m) - } - if err = mac.check(); err != nil { - return nil, err - } - mac.mutation = mutation - if node, err = mac.sqlSave(ctx); err != nil { - return nil, err - } - mutation.id = &node.ID - mutation.done = true - return node, err - }) - for i := len(mac.hooks) - 1; i >= 0; i-- { - if mac.hooks[i] == nil { - return nil, fmt.Errorf("ent: uninitialized hook (forgotten import ent/runtime?)") - } - mut = mac.hooks[i](mut) - } - v, err := mut.Mutate(ctx, mac.mutation) - if err != nil { - return nil, err - } - nv, ok := v.(*MailboxAttr) - if !ok { - return nil, fmt.Errorf("unexpected node type %T returned from MailboxAttrMutation", v) - } - node = nv - } - return node, err + return withHooks[*MailboxAttr, MailboxAttrMutation](ctx, mac.sqlSave, mac.mutation, mac.hooks) } // SaveX calls Save and panics if Save returns an error. @@ -102,12 +60,15 @@ func (mac *MailboxAttrCreate) ExecX(ctx context.Context) { // check runs all checks and user-defined validators on the builder. func (mac *MailboxAttrCreate) check() error { if _, ok := mac.mutation.Value(); !ok { - return &ValidationError{Name: "Value", err: errors.New(`ent: missing required field "MailboxAttr.Value"`)} + return &ValidationError{Name: "Value", err: errors.New(`internal: missing required field "MailboxAttr.Value"`)} } return nil } func (mac *MailboxAttrCreate) sqlSave(ctx context.Context) (*MailboxAttr, error) { + if err := mac.check(); err != nil { + return nil, err + } _node, _spec := mac.createSpec() if err := sqlgraph.CreateNode(ctx, mac.driver, _spec); err != nil { if sqlgraph.IsConstraintError(err) { @@ -117,26 +78,18 @@ func (mac *MailboxAttrCreate) sqlSave(ctx context.Context) (*MailboxAttr, error) } id := _spec.ID.Value.(int64) _node.ID = int(id) + mac.mutation.id = &_node.ID + mac.mutation.done = true return _node, nil } func (mac *MailboxAttrCreate) createSpec() (*MailboxAttr, *sqlgraph.CreateSpec) { var ( _node = &MailboxAttr{config: mac.config} - _spec = &sqlgraph.CreateSpec{ - Table: mailboxattr.Table, - ID: &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Column: mailboxattr.FieldID, - }, - } + _spec = sqlgraph.NewCreateSpec(mailboxattr.Table, sqlgraph.NewFieldSpec(mailboxattr.FieldID, field.TypeInt)) ) if value, ok := mac.mutation.Value(); ok { - _spec.Fields = append(_spec.Fields, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Value: value, - Column: mailboxattr.FieldValue, - }) + _spec.SetField(mailboxattr.FieldValue, field.TypeString, value) _node.Value = value } return _node, _spec diff --git a/internal/db/ent/mailboxattr_delete.go b/internal/db_impl/ent_db/internal/mailboxattr_delete.go similarity index 62% rename from internal/db/ent/mailboxattr_delete.go rename to internal/db_impl/ent_db/internal/mailboxattr_delete.go index a13092bd..fbaa3e78 100644 --- a/internal/db/ent/mailboxattr_delete.go +++ b/internal/db_impl/ent_db/internal/mailboxattr_delete.go @@ -1,16 +1,15 @@ // Code generated by ent, DO NOT EDIT. -package ent +package internal import ( "context" - "fmt" "entgo.io/ent/dialect/sql" "entgo.io/ent/dialect/sql/sqlgraph" "entgo.io/ent/schema/field" - "github.com/ProtonMail/gluon/internal/db/ent/mailboxattr" - "github.com/ProtonMail/gluon/internal/db/ent/predicate" + "github.com/ProtonMail/gluon/internal/db_impl/ent_db/internal/mailboxattr" + "github.com/ProtonMail/gluon/internal/db_impl/ent_db/internal/predicate" ) // MailboxAttrDelete is the builder for deleting a MailboxAttr entity. @@ -28,34 +27,7 @@ func (mad *MailboxAttrDelete) Where(ps ...predicate.MailboxAttr) *MailboxAttrDel // Exec executes the deletion query and returns how many vertices were deleted. func (mad *MailboxAttrDelete) Exec(ctx context.Context) (int, error) { - var ( - err error - affected int - ) - if len(mad.hooks) == 0 { - affected, err = mad.sqlExec(ctx) - } else { - var mut Mutator = MutateFunc(func(ctx context.Context, m Mutation) (Value, error) { - mutation, ok := m.(*MailboxAttrMutation) - if !ok { - return nil, fmt.Errorf("unexpected mutation type %T", m) - } - mad.mutation = mutation - affected, err = mad.sqlExec(ctx) - mutation.done = true - return affected, err - }) - for i := len(mad.hooks) - 1; i >= 0; i-- { - if mad.hooks[i] == nil { - return 0, fmt.Errorf("ent: uninitialized hook (forgotten import ent/runtime?)") - } - mut = mad.hooks[i](mut) - } - if _, err := mut.Mutate(ctx, mad.mutation); err != nil { - return 0, err - } - } - return affected, err + return withHooks[int, MailboxAttrMutation](ctx, mad.sqlExec, mad.mutation, mad.hooks) } // ExecX is like Exec, but panics if an error occurs. @@ -68,15 +40,7 @@ func (mad *MailboxAttrDelete) ExecX(ctx context.Context) int { } func (mad *MailboxAttrDelete) sqlExec(ctx context.Context) (int, error) { - _spec := &sqlgraph.DeleteSpec{ - Node: &sqlgraph.NodeSpec{ - Table: mailboxattr.Table, - ID: &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Column: mailboxattr.FieldID, - }, - }, - } + _spec := sqlgraph.NewDeleteSpec(mailboxattr.Table, sqlgraph.NewFieldSpec(mailboxattr.FieldID, field.TypeInt)) if ps := mad.mutation.predicates; len(ps) > 0 { _spec.Predicate = func(selector *sql.Selector) { for i := range ps { @@ -88,6 +52,7 @@ func (mad *MailboxAttrDelete) sqlExec(ctx context.Context) (int, error) { if err != nil && sqlgraph.IsConstraintError(err) { err = &ConstraintError{msg: err.Error(), wrap: err} } + mad.mutation.done = true return affected, err } @@ -96,6 +61,12 @@ type MailboxAttrDeleteOne struct { mad *MailboxAttrDelete } +// Where appends a list predicates to the MailboxAttrDelete builder. +func (mado *MailboxAttrDeleteOne) Where(ps ...predicate.MailboxAttr) *MailboxAttrDeleteOne { + mado.mad.mutation.Where(ps...) + return mado +} + // Exec executes the deletion query. func (mado *MailboxAttrDeleteOne) Exec(ctx context.Context) error { n, err := mado.mad.Exec(ctx) @@ -111,5 +82,7 @@ func (mado *MailboxAttrDeleteOne) Exec(ctx context.Context) error { // ExecX is like Exec, but panics if an error occurs. func (mado *MailboxAttrDeleteOne) ExecX(ctx context.Context) { - mado.mad.ExecX(ctx) + if err := mado.Exec(ctx); err != nil { + panic(err) + } } diff --git a/internal/db/ent/mailboxattr_query.go b/internal/db_impl/ent_db/internal/mailboxattr_query.go similarity index 67% rename from internal/db/ent/mailboxattr_query.go rename to internal/db_impl/ent_db/internal/mailboxattr_query.go index bb5fcf54..7974ba3a 100644 --- a/internal/db/ent/mailboxattr_query.go +++ b/internal/db_impl/ent_db/internal/mailboxattr_query.go @@ -1,6 +1,6 @@ // Code generated by ent, DO NOT EDIT. -package ent +package internal import ( "context" @@ -10,18 +10,16 @@ import ( "entgo.io/ent/dialect/sql" "entgo.io/ent/dialect/sql/sqlgraph" "entgo.io/ent/schema/field" - "github.com/ProtonMail/gluon/internal/db/ent/mailboxattr" - "github.com/ProtonMail/gluon/internal/db/ent/predicate" + "github.com/ProtonMail/gluon/internal/db_impl/ent_db/internal/mailboxattr" + "github.com/ProtonMail/gluon/internal/db_impl/ent_db/internal/predicate" ) // MailboxAttrQuery is the builder for querying MailboxAttr entities. type MailboxAttrQuery struct { config - limit *int - offset *int - unique *bool + ctx *QueryContext order []OrderFunc - fields []string + inters []Interceptor predicates []predicate.MailboxAttr withFKs bool // intermediate query (i.e. traversal path). @@ -35,26 +33,26 @@ func (maq *MailboxAttrQuery) Where(ps ...predicate.MailboxAttr) *MailboxAttrQuer return maq } -// Limit adds a limit step to the query. +// Limit the number of records to be returned by this query. func (maq *MailboxAttrQuery) Limit(limit int) *MailboxAttrQuery { - maq.limit = &limit + maq.ctx.Limit = &limit return maq } -// Offset adds an offset step to the query. +// Offset to start from. func (maq *MailboxAttrQuery) Offset(offset int) *MailboxAttrQuery { - maq.offset = &offset + maq.ctx.Offset = &offset return maq } // Unique configures the query builder to filter duplicate records on query. // By default, unique is set to true, and can be disabled using this method. func (maq *MailboxAttrQuery) Unique(unique bool) *MailboxAttrQuery { - maq.unique = &unique + maq.ctx.Unique = &unique return maq } -// Order adds an order step to the query. +// Order specifies how the records should be ordered. func (maq *MailboxAttrQuery) Order(o ...OrderFunc) *MailboxAttrQuery { maq.order = append(maq.order, o...) return maq @@ -63,7 +61,7 @@ func (maq *MailboxAttrQuery) Order(o ...OrderFunc) *MailboxAttrQuery { // First returns the first MailboxAttr entity from the query. // Returns a *NotFoundError when no MailboxAttr was found. func (maq *MailboxAttrQuery) First(ctx context.Context) (*MailboxAttr, error) { - nodes, err := maq.Limit(1).All(ctx) + nodes, err := maq.Limit(1).All(setContextOp(ctx, maq.ctx, "First")) if err != nil { return nil, err } @@ -86,7 +84,7 @@ func (maq *MailboxAttrQuery) FirstX(ctx context.Context) *MailboxAttr { // Returns a *NotFoundError when no MailboxAttr ID was found. func (maq *MailboxAttrQuery) FirstID(ctx context.Context) (id int, err error) { var ids []int - if ids, err = maq.Limit(1).IDs(ctx); err != nil { + if ids, err = maq.Limit(1).IDs(setContextOp(ctx, maq.ctx, "FirstID")); err != nil { return } if len(ids) == 0 { @@ -109,7 +107,7 @@ func (maq *MailboxAttrQuery) FirstIDX(ctx context.Context) int { // Returns a *NotSingularError when more than one MailboxAttr entity is found. // Returns a *NotFoundError when no MailboxAttr entities are found. func (maq *MailboxAttrQuery) Only(ctx context.Context) (*MailboxAttr, error) { - nodes, err := maq.Limit(2).All(ctx) + nodes, err := maq.Limit(2).All(setContextOp(ctx, maq.ctx, "Only")) if err != nil { return nil, err } @@ -137,7 +135,7 @@ func (maq *MailboxAttrQuery) OnlyX(ctx context.Context) *MailboxAttr { // Returns a *NotFoundError when no entities are found. func (maq *MailboxAttrQuery) OnlyID(ctx context.Context) (id int, err error) { var ids []int - if ids, err = maq.Limit(2).IDs(ctx); err != nil { + if ids, err = maq.Limit(2).IDs(setContextOp(ctx, maq.ctx, "OnlyID")); err != nil { return } switch len(ids) { @@ -162,10 +160,12 @@ func (maq *MailboxAttrQuery) OnlyIDX(ctx context.Context) int { // All executes the query and returns a list of MailboxAttrs. func (maq *MailboxAttrQuery) All(ctx context.Context) ([]*MailboxAttr, error) { + ctx = setContextOp(ctx, maq.ctx, "All") if err := maq.prepareQuery(ctx); err != nil { return nil, err } - return maq.sqlAll(ctx) + qr := querierAll[[]*MailboxAttr, *MailboxAttrQuery]() + return withInterceptors[[]*MailboxAttr](ctx, maq, qr, maq.inters) } // AllX is like All, but panics if an error occurs. @@ -178,9 +178,12 @@ func (maq *MailboxAttrQuery) AllX(ctx context.Context) []*MailboxAttr { } // IDs executes the query and returns a list of MailboxAttr IDs. -func (maq *MailboxAttrQuery) IDs(ctx context.Context) ([]int, error) { - var ids []int - if err := maq.Select(mailboxattr.FieldID).Scan(ctx, &ids); err != nil { +func (maq *MailboxAttrQuery) IDs(ctx context.Context) (ids []int, err error) { + if maq.ctx.Unique == nil && maq.path != nil { + maq.Unique(true) + } + ctx = setContextOp(ctx, maq.ctx, "IDs") + if err = maq.Select(mailboxattr.FieldID).Scan(ctx, &ids); err != nil { return nil, err } return ids, nil @@ -197,10 +200,11 @@ func (maq *MailboxAttrQuery) IDsX(ctx context.Context) []int { // Count returns the count of the given query. func (maq *MailboxAttrQuery) Count(ctx context.Context) (int, error) { + ctx = setContextOp(ctx, maq.ctx, "Count") if err := maq.prepareQuery(ctx); err != nil { return 0, err } - return maq.sqlCount(ctx) + return withInterceptors[int](ctx, maq, querierCount[*MailboxAttrQuery](), maq.inters) } // CountX is like Count, but panics if an error occurs. @@ -214,10 +218,15 @@ func (maq *MailboxAttrQuery) CountX(ctx context.Context) int { // Exist returns true if the query has elements in the graph. func (maq *MailboxAttrQuery) Exist(ctx context.Context) (bool, error) { - if err := maq.prepareQuery(ctx); err != nil { - return false, err + ctx = setContextOp(ctx, maq.ctx, "Exist") + switch _, err := maq.FirstID(ctx); { + case IsNotFound(err): + return false, nil + case err != nil: + return false, fmt.Errorf("internal: check existence: %w", err) + default: + return true, nil } - return maq.sqlExist(ctx) } // ExistX is like Exist, but panics if an error occurs. @@ -237,14 +246,13 @@ func (maq *MailboxAttrQuery) Clone() *MailboxAttrQuery { } return &MailboxAttrQuery{ config: maq.config, - limit: maq.limit, - offset: maq.offset, + ctx: maq.ctx.Clone(), order: append([]OrderFunc{}, maq.order...), + inters: append([]Interceptor{}, maq.inters...), predicates: append([]predicate.MailboxAttr{}, maq.predicates...), // clone intermediate query. - sql: maq.sql.Clone(), - path: maq.path, - unique: maq.unique, + sql: maq.sql.Clone(), + path: maq.path, } } @@ -260,19 +268,14 @@ func (maq *MailboxAttrQuery) Clone() *MailboxAttrQuery { // // client.MailboxAttr.Query(). // GroupBy(mailboxattr.FieldValue). -// Aggregate(ent.Count()). +// Aggregate(internal.Count()). // Scan(ctx, &v) func (maq *MailboxAttrQuery) GroupBy(field string, fields ...string) *MailboxAttrGroupBy { - grbuild := &MailboxAttrGroupBy{config: maq.config} - grbuild.fields = append([]string{field}, fields...) - grbuild.path = func(ctx context.Context) (prev *sql.Selector, err error) { - if err := maq.prepareQuery(ctx); err != nil { - return nil, err - } - return maq.sqlQuery(ctx), nil - } + maq.ctx.Fields = append([]string{field}, fields...) + grbuild := &MailboxAttrGroupBy{build: maq} + grbuild.flds = &maq.ctx.Fields grbuild.label = mailboxattr.Label - grbuild.flds, grbuild.scan = &grbuild.fields, grbuild.Scan + grbuild.scan = grbuild.Scan return grbuild } @@ -289,17 +292,32 @@ func (maq *MailboxAttrQuery) GroupBy(field string, fields ...string) *MailboxAtt // Select(mailboxattr.FieldValue). // Scan(ctx, &v) func (maq *MailboxAttrQuery) Select(fields ...string) *MailboxAttrSelect { - maq.fields = append(maq.fields, fields...) - selbuild := &MailboxAttrSelect{MailboxAttrQuery: maq} - selbuild.label = mailboxattr.Label - selbuild.flds, selbuild.scan = &maq.fields, selbuild.Scan - return selbuild + maq.ctx.Fields = append(maq.ctx.Fields, fields...) + sbuild := &MailboxAttrSelect{MailboxAttrQuery: maq} + sbuild.label = mailboxattr.Label + sbuild.flds, sbuild.scan = &maq.ctx.Fields, sbuild.Scan + return sbuild +} + +// Aggregate returns a MailboxAttrSelect configured with the given aggregations. +func (maq *MailboxAttrQuery) Aggregate(fns ...AggregateFunc) *MailboxAttrSelect { + return maq.Select().Aggregate(fns...) } func (maq *MailboxAttrQuery) prepareQuery(ctx context.Context) error { - for _, f := range maq.fields { + for _, inter := range maq.inters { + if inter == nil { + return fmt.Errorf("internal: uninitialized interceptor (forgotten import internal/runtime?)") + } + if trv, ok := inter.(Traverser); ok { + if err := trv.Traverse(ctx, maq); err != nil { + return err + } + } + } + for _, f := range maq.ctx.Fields { if !mailboxattr.ValidColumn(f) { - return &ValidationError{Name: f, err: fmt.Errorf("ent: invalid field %q for query", f)} + return &ValidationError{Name: f, err: fmt.Errorf("internal: invalid field %q for query", f)} } } if maq.path != nil { @@ -321,10 +339,10 @@ func (maq *MailboxAttrQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([] if withFKs { _spec.Node.Columns = append(_spec.Node.Columns, mailboxattr.ForeignKeys...) } - _spec.ScanValues = func(columns []string) ([]interface{}, error) { + _spec.ScanValues = func(columns []string) ([]any, error) { return (*MailboxAttr).scanValues(nil, columns) } - _spec.Assign = func(columns []string, values []interface{}) error { + _spec.Assign = func(columns []string, values []any) error { node := &MailboxAttr{config: maq.config} nodes = append(nodes, node) return node.assignValues(columns, values) @@ -343,38 +361,22 @@ func (maq *MailboxAttrQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([] func (maq *MailboxAttrQuery) sqlCount(ctx context.Context) (int, error) { _spec := maq.querySpec() - _spec.Node.Columns = maq.fields - if len(maq.fields) > 0 { - _spec.Unique = maq.unique != nil && *maq.unique + _spec.Node.Columns = maq.ctx.Fields + if len(maq.ctx.Fields) > 0 { + _spec.Unique = maq.ctx.Unique != nil && *maq.ctx.Unique } return sqlgraph.CountNodes(ctx, maq.driver, _spec) } -func (maq *MailboxAttrQuery) sqlExist(ctx context.Context) (bool, error) { - n, err := maq.sqlCount(ctx) - if err != nil { - return false, fmt.Errorf("ent: check existence: %w", err) - } - return n > 0, nil -} - func (maq *MailboxAttrQuery) querySpec() *sqlgraph.QuerySpec { - _spec := &sqlgraph.QuerySpec{ - Node: &sqlgraph.NodeSpec{ - Table: mailboxattr.Table, - Columns: mailboxattr.Columns, - ID: &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Column: mailboxattr.FieldID, - }, - }, - From: maq.sql, - Unique: true, - } - if unique := maq.unique; unique != nil { + _spec := sqlgraph.NewQuerySpec(mailboxattr.Table, mailboxattr.Columns, sqlgraph.NewFieldSpec(mailboxattr.FieldID, field.TypeInt)) + _spec.From = maq.sql + if unique := maq.ctx.Unique; unique != nil { _spec.Unique = *unique + } else if maq.path != nil { + _spec.Unique = true } - if fields := maq.fields; len(fields) > 0 { + if fields := maq.ctx.Fields; len(fields) > 0 { _spec.Node.Columns = make([]string, 0, len(fields)) _spec.Node.Columns = append(_spec.Node.Columns, mailboxattr.FieldID) for i := range fields { @@ -390,10 +392,10 @@ func (maq *MailboxAttrQuery) querySpec() *sqlgraph.QuerySpec { } } } - if limit := maq.limit; limit != nil { + if limit := maq.ctx.Limit; limit != nil { _spec.Limit = *limit } - if offset := maq.offset; offset != nil { + if offset := maq.ctx.Offset; offset != nil { _spec.Offset = *offset } if ps := maq.order; len(ps) > 0 { @@ -409,7 +411,7 @@ func (maq *MailboxAttrQuery) querySpec() *sqlgraph.QuerySpec { func (maq *MailboxAttrQuery) sqlQuery(ctx context.Context) *sql.Selector { builder := sql.Dialect(maq.driver.Dialect()) t1 := builder.Table(mailboxattr.Table) - columns := maq.fields + columns := maq.ctx.Fields if len(columns) == 0 { columns = mailboxattr.Columns } @@ -418,7 +420,7 @@ func (maq *MailboxAttrQuery) sqlQuery(ctx context.Context) *sql.Selector { selector = maq.sql selector.Select(selector.Columns(columns...)...) } - if maq.unique != nil && *maq.unique { + if maq.ctx.Unique != nil && *maq.ctx.Unique { selector.Distinct() } for _, p := range maq.predicates { @@ -427,12 +429,12 @@ func (maq *MailboxAttrQuery) sqlQuery(ctx context.Context) *sql.Selector { for _, p := range maq.order { p(selector) } - if offset := maq.offset; offset != nil { + if offset := maq.ctx.Offset; offset != nil { // limit is mandatory for offset clause. We start // with default value, and override it below if needed. selector.Offset(*offset).Limit(math.MaxInt32) } - if limit := maq.limit; limit != nil { + if limit := maq.ctx.Limit; limit != nil { selector.Limit(*limit) } return selector @@ -440,13 +442,8 @@ func (maq *MailboxAttrQuery) sqlQuery(ctx context.Context) *sql.Selector { // MailboxAttrGroupBy is the group-by builder for MailboxAttr entities. type MailboxAttrGroupBy struct { - config selector - fields []string - fns []AggregateFunc - // intermediate query (i.e. traversal path). - sql *sql.Selector - path func(context.Context) (*sql.Selector, error) + build *MailboxAttrQuery } // Aggregate adds the given aggregation functions to the group-by query. @@ -455,74 +452,77 @@ func (magb *MailboxAttrGroupBy) Aggregate(fns ...AggregateFunc) *MailboxAttrGrou return magb } -// Scan applies the group-by query and scans the result into the given value. -func (magb *MailboxAttrGroupBy) Scan(ctx context.Context, v interface{}) error { - query, err := magb.path(ctx) - if err != nil { +// Scan applies the selector query and scans the result into the given value. +func (magb *MailboxAttrGroupBy) Scan(ctx context.Context, v any) error { + ctx = setContextOp(ctx, magb.build.ctx, "GroupBy") + if err := magb.build.prepareQuery(ctx); err != nil { return err } - magb.sql = query - return magb.sqlScan(ctx, v) + return scanWithInterceptors[*MailboxAttrQuery, *MailboxAttrGroupBy](ctx, magb.build, magb, magb.build.inters, v) } -func (magb *MailboxAttrGroupBy) sqlScan(ctx context.Context, v interface{}) error { - for _, f := range magb.fields { - if !mailboxattr.ValidColumn(f) { - return &ValidationError{Name: f, err: fmt.Errorf("invalid field %q for group-by", f)} +func (magb *MailboxAttrGroupBy) sqlScan(ctx context.Context, root *MailboxAttrQuery, v any) error { + selector := root.sqlQuery(ctx).Select() + aggregation := make([]string, 0, len(magb.fns)) + for _, fn := range magb.fns { + aggregation = append(aggregation, fn(selector)) + } + if len(selector.SelectedColumns()) == 0 { + columns := make([]string, 0, len(*magb.flds)+len(magb.fns)) + for _, f := range *magb.flds { + columns = append(columns, selector.C(f)) } + columns = append(columns, aggregation...) + selector.Select(columns...) } - selector := magb.sqlQuery() + selector.GroupBy(selector.Columns(*magb.flds...)...) if err := selector.Err(); err != nil { return err } rows := &sql.Rows{} query, args := selector.Query() - if err := magb.driver.Query(ctx, query, args, rows); err != nil { + if err := magb.build.driver.Query(ctx, query, args, rows); err != nil { return err } defer rows.Close() return sql.ScanSlice(rows, v) } -func (magb *MailboxAttrGroupBy) sqlQuery() *sql.Selector { - selector := magb.sql.Select() - aggregation := make([]string, 0, len(magb.fns)) - for _, fn := range magb.fns { - aggregation = append(aggregation, fn(selector)) - } - // If no columns were selected in a custom aggregation function, the default - // selection is the fields used for "group-by", and the aggregation functions. - if len(selector.SelectedColumns()) == 0 { - columns := make([]string, 0, len(magb.fields)+len(magb.fns)) - for _, f := range magb.fields { - columns = append(columns, selector.C(f)) - } - columns = append(columns, aggregation...) - selector.Select(columns...) - } - return selector.GroupBy(selector.Columns(magb.fields...)...) -} - // MailboxAttrSelect is the builder for selecting fields of MailboxAttr entities. type MailboxAttrSelect struct { *MailboxAttrQuery selector - // intermediate query (i.e. traversal path). - sql *sql.Selector +} + +// Aggregate adds the given aggregation functions to the selector query. +func (mas *MailboxAttrSelect) Aggregate(fns ...AggregateFunc) *MailboxAttrSelect { + mas.fns = append(mas.fns, fns...) + return mas } // Scan applies the selector query and scans the result into the given value. -func (mas *MailboxAttrSelect) Scan(ctx context.Context, v interface{}) error { +func (mas *MailboxAttrSelect) Scan(ctx context.Context, v any) error { + ctx = setContextOp(ctx, mas.ctx, "Select") if err := mas.prepareQuery(ctx); err != nil { return err } - mas.sql = mas.MailboxAttrQuery.sqlQuery(ctx) - return mas.sqlScan(ctx, v) + return scanWithInterceptors[*MailboxAttrQuery, *MailboxAttrSelect](ctx, mas.MailboxAttrQuery, mas, mas.inters, v) } -func (mas *MailboxAttrSelect) sqlScan(ctx context.Context, v interface{}) error { +func (mas *MailboxAttrSelect) sqlScan(ctx context.Context, root *MailboxAttrQuery, v any) error { + selector := root.sqlQuery(ctx) + aggregation := make([]string, 0, len(mas.fns)) + for _, fn := range mas.fns { + aggregation = append(aggregation, fn(selector)) + } + switch n := len(*mas.selector.flds); { + case n == 0 && len(aggregation) > 0: + selector.Select(aggregation...) + case n != 0 && len(aggregation) > 0: + selector.AppendSelect(aggregation...) + } rows := &sql.Rows{} - query, args := mas.sql.Query() + query, args := selector.Query() if err := mas.driver.Query(ctx, query, args, rows); err != nil { return err } diff --git a/internal/db/ent/mailboxattr_update.go b/internal/db_impl/ent_db/internal/mailboxattr_update.go similarity index 63% rename from internal/db/ent/mailboxattr_update.go rename to internal/db_impl/ent_db/internal/mailboxattr_update.go index 73f65979..8998256a 100644 --- a/internal/db/ent/mailboxattr_update.go +++ b/internal/db_impl/ent_db/internal/mailboxattr_update.go @@ -1,6 +1,6 @@ // Code generated by ent, DO NOT EDIT. -package ent +package internal import ( "context" @@ -10,8 +10,8 @@ import ( "entgo.io/ent/dialect/sql" "entgo.io/ent/dialect/sql/sqlgraph" "entgo.io/ent/schema/field" - "github.com/ProtonMail/gluon/internal/db/ent/mailboxattr" - "github.com/ProtonMail/gluon/internal/db/ent/predicate" + "github.com/ProtonMail/gluon/internal/db_impl/ent_db/internal/mailboxattr" + "github.com/ProtonMail/gluon/internal/db_impl/ent_db/internal/predicate" ) // MailboxAttrUpdate is the builder for updating MailboxAttr entities. @@ -40,34 +40,7 @@ func (mau *MailboxAttrUpdate) Mutation() *MailboxAttrMutation { // Save executes the query and returns the number of nodes affected by the update operation. func (mau *MailboxAttrUpdate) Save(ctx context.Context) (int, error) { - var ( - err error - affected int - ) - if len(mau.hooks) == 0 { - affected, err = mau.sqlSave(ctx) - } else { - var mut Mutator = MutateFunc(func(ctx context.Context, m Mutation) (Value, error) { - mutation, ok := m.(*MailboxAttrMutation) - if !ok { - return nil, fmt.Errorf("unexpected mutation type %T", m) - } - mau.mutation = mutation - affected, err = mau.sqlSave(ctx) - mutation.done = true - return affected, err - }) - for i := len(mau.hooks) - 1; i >= 0; i-- { - if mau.hooks[i] == nil { - return 0, fmt.Errorf("ent: uninitialized hook (forgotten import ent/runtime?)") - } - mut = mau.hooks[i](mut) - } - if _, err := mut.Mutate(ctx, mau.mutation); err != nil { - return 0, err - } - } - return affected, err + return withHooks[int, MailboxAttrMutation](ctx, mau.sqlSave, mau.mutation, mau.hooks) } // SaveX is like Save, but panics if an error occurs. @@ -93,16 +66,7 @@ func (mau *MailboxAttrUpdate) ExecX(ctx context.Context) { } func (mau *MailboxAttrUpdate) sqlSave(ctx context.Context) (n int, err error) { - _spec := &sqlgraph.UpdateSpec{ - Node: &sqlgraph.NodeSpec{ - Table: mailboxattr.Table, - Columns: mailboxattr.Columns, - ID: &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Column: mailboxattr.FieldID, - }, - }, - } + _spec := sqlgraph.NewUpdateSpec(mailboxattr.Table, mailboxattr.Columns, sqlgraph.NewFieldSpec(mailboxattr.FieldID, field.TypeInt)) if ps := mau.mutation.predicates; len(ps) > 0 { _spec.Predicate = func(selector *sql.Selector) { for i := range ps { @@ -111,11 +75,7 @@ func (mau *MailboxAttrUpdate) sqlSave(ctx context.Context) (n int, err error) { } } if value, ok := mau.mutation.Value(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Value: value, - Column: mailboxattr.FieldValue, - }) + _spec.SetField(mailboxattr.FieldValue, field.TypeString, value) } if n, err = sqlgraph.UpdateNodes(ctx, mau.driver, _spec); err != nil { if _, ok := err.(*sqlgraph.NotFoundError); ok { @@ -125,6 +85,7 @@ func (mau *MailboxAttrUpdate) sqlSave(ctx context.Context) (n int, err error) { } return 0, err } + mau.mutation.done = true return n, nil } @@ -147,6 +108,12 @@ func (mauo *MailboxAttrUpdateOne) Mutation() *MailboxAttrMutation { return mauo.mutation } +// Where appends a list predicates to the MailboxAttrUpdate builder. +func (mauo *MailboxAttrUpdateOne) Where(ps ...predicate.MailboxAttr) *MailboxAttrUpdateOne { + mauo.mutation.Where(ps...) + return mauo +} + // Select allows selecting one or more fields (columns) of the returned entity. // The default is selecting all fields defined in the entity schema. func (mauo *MailboxAttrUpdateOne) Select(field string, fields ...string) *MailboxAttrUpdateOne { @@ -156,40 +123,7 @@ func (mauo *MailboxAttrUpdateOne) Select(field string, fields ...string) *Mailbo // Save executes the query and returns the updated MailboxAttr entity. func (mauo *MailboxAttrUpdateOne) Save(ctx context.Context) (*MailboxAttr, error) { - var ( - err error - node *MailboxAttr - ) - if len(mauo.hooks) == 0 { - node, err = mauo.sqlSave(ctx) - } else { - var mut Mutator = MutateFunc(func(ctx context.Context, m Mutation) (Value, error) { - mutation, ok := m.(*MailboxAttrMutation) - if !ok { - return nil, fmt.Errorf("unexpected mutation type %T", m) - } - mauo.mutation = mutation - node, err = mauo.sqlSave(ctx) - mutation.done = true - return node, err - }) - for i := len(mauo.hooks) - 1; i >= 0; i-- { - if mauo.hooks[i] == nil { - return nil, fmt.Errorf("ent: uninitialized hook (forgotten import ent/runtime?)") - } - mut = mauo.hooks[i](mut) - } - v, err := mut.Mutate(ctx, mauo.mutation) - if err != nil { - return nil, err - } - nv, ok := v.(*MailboxAttr) - if !ok { - return nil, fmt.Errorf("unexpected node type %T returned from MailboxAttrMutation", v) - } - node = nv - } - return node, err + return withHooks[*MailboxAttr, MailboxAttrMutation](ctx, mauo.sqlSave, mauo.mutation, mauo.hooks) } // SaveX is like Save, but panics if an error occurs. @@ -215,19 +149,10 @@ func (mauo *MailboxAttrUpdateOne) ExecX(ctx context.Context) { } func (mauo *MailboxAttrUpdateOne) sqlSave(ctx context.Context) (_node *MailboxAttr, err error) { - _spec := &sqlgraph.UpdateSpec{ - Node: &sqlgraph.NodeSpec{ - Table: mailboxattr.Table, - Columns: mailboxattr.Columns, - ID: &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Column: mailboxattr.FieldID, - }, - }, - } + _spec := sqlgraph.NewUpdateSpec(mailboxattr.Table, mailboxattr.Columns, sqlgraph.NewFieldSpec(mailboxattr.FieldID, field.TypeInt)) id, ok := mauo.mutation.ID() if !ok { - return nil, &ValidationError{Name: "id", err: errors.New(`ent: missing "MailboxAttr.id" for update`)} + return nil, &ValidationError{Name: "id", err: errors.New(`internal: missing "MailboxAttr.id" for update`)} } _spec.Node.ID.Value = id if fields := mauo.fields; len(fields) > 0 { @@ -235,7 +160,7 @@ func (mauo *MailboxAttrUpdateOne) sqlSave(ctx context.Context) (_node *MailboxAt _spec.Node.Columns = append(_spec.Node.Columns, mailboxattr.FieldID) for _, f := range fields { if !mailboxattr.ValidColumn(f) { - return nil, &ValidationError{Name: f, err: fmt.Errorf("ent: invalid field %q for query", f)} + return nil, &ValidationError{Name: f, err: fmt.Errorf("internal: invalid field %q for query", f)} } if f != mailboxattr.FieldID { _spec.Node.Columns = append(_spec.Node.Columns, f) @@ -250,11 +175,7 @@ func (mauo *MailboxAttrUpdateOne) sqlSave(ctx context.Context) (_node *MailboxAt } } if value, ok := mauo.mutation.Value(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Value: value, - Column: mailboxattr.FieldValue, - }) + _spec.SetField(mailboxattr.FieldValue, field.TypeString, value) } _node = &MailboxAttr{config: mauo.config} _spec.Assign = _node.assignValues @@ -267,5 +188,6 @@ func (mauo *MailboxAttrUpdateOne) sqlSave(ctx context.Context) (_node *MailboxAt } return nil, err } + mauo.mutation.done = true return _node, nil } diff --git a/internal/db/ent/mailboxflag.go b/internal/db_impl/ent_db/internal/mailboxflag.go similarity index 85% rename from internal/db/ent/mailboxflag.go rename to internal/db_impl/ent_db/internal/mailboxflag.go index 18ce8cf5..d2a5d39e 100644 --- a/internal/db/ent/mailboxflag.go +++ b/internal/db_impl/ent_db/internal/mailboxflag.go @@ -1,6 +1,6 @@ // Code generated by ent, DO NOT EDIT. -package ent +package internal import ( "fmt" @@ -8,7 +8,7 @@ import ( "entgo.io/ent/dialect/sql" "github.com/ProtonMail/gluon/imap" - "github.com/ProtonMail/gluon/internal/db/ent/mailboxflag" + "github.com/ProtonMail/gluon/internal/db_impl/ent_db/internal/mailboxflag" ) // MailboxFlag is the model entity for the MailboxFlag schema. @@ -22,8 +22,8 @@ type MailboxFlag struct { } // scanValues returns the types for scanning values from sql.Rows. -func (*MailboxFlag) scanValues(columns []string) ([]interface{}, error) { - values := make([]interface{}, len(columns)) +func (*MailboxFlag) scanValues(columns []string) ([]any, error) { + values := make([]any, len(columns)) for i := range columns { switch columns[i] { case mailboxflag.FieldID: @@ -41,7 +41,7 @@ func (*MailboxFlag) scanValues(columns []string) ([]interface{}, error) { // assignValues assigns the values that were returned from sql.Rows (after scanning) // to the MailboxFlag fields. -func (mf *MailboxFlag) assignValues(columns []string, values []interface{}) error { +func (mf *MailboxFlag) assignValues(columns []string, values []any) error { if m, n := len(values), len(columns); m < n { return fmt.Errorf("mismatch number of scan values: %d != %d", m, n) } @@ -75,7 +75,7 @@ func (mf *MailboxFlag) assignValues(columns []string, values []interface{}) erro // Note that you need to call MailboxFlag.Unwrap() before calling this method if this MailboxFlag // was returned from a transaction, and the transaction was committed or rolled back. func (mf *MailboxFlag) Update() *MailboxFlagUpdateOne { - return (&MailboxFlagClient{config: mf.config}).UpdateOne(mf) + return NewMailboxFlagClient(mf.config).UpdateOne(mf) } // Unwrap unwraps the MailboxFlag entity that was returned from a transaction after it was closed, @@ -83,7 +83,7 @@ func (mf *MailboxFlag) Update() *MailboxFlagUpdateOne { func (mf *MailboxFlag) Unwrap() *MailboxFlag { _tx, ok := mf.config.driver.(*txDriver) if !ok { - panic("ent: MailboxFlag is not a transactional entity") + panic("internal: MailboxFlag is not a transactional entity") } mf.config.driver = _tx.drv return mf @@ -102,9 +102,3 @@ func (mf *MailboxFlag) String() string { // MailboxFlags is a parsable slice of MailboxFlag. type MailboxFlags []*MailboxFlag - -func (mf MailboxFlags) config(cfg config) { - for _i := range mf { - mf[_i].config = cfg - } -} diff --git a/internal/db/ent/mailboxflag/mailboxflag.go b/internal/db_impl/ent_db/internal/mailboxflag/mailboxflag.go similarity index 100% rename from internal/db/ent/mailboxflag/mailboxflag.go rename to internal/db_impl/ent_db/internal/mailboxflag/mailboxflag.go diff --git a/internal/db/ent/mailboxflag/where.go b/internal/db_impl/ent_db/internal/mailboxflag/where.go similarity index 56% rename from internal/db/ent/mailboxflag/where.go rename to internal/db_impl/ent_db/internal/mailboxflag/where.go index 004f34f3..dab75500 100644 --- a/internal/db/ent/mailboxflag/where.go +++ b/internal/db_impl/ent_db/internal/mailboxflag/where.go @@ -4,184 +4,122 @@ package mailboxflag import ( "entgo.io/ent/dialect/sql" - "github.com/ProtonMail/gluon/internal/db/ent/predicate" + "github.com/ProtonMail/gluon/internal/db_impl/ent_db/internal/predicate" ) // ID filters vertices based on their ID field. func ID(id int) predicate.MailboxFlag { - return predicate.MailboxFlag(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldID), id)) - }) + return predicate.MailboxFlag(sql.FieldEQ(FieldID, id)) } // IDEQ applies the EQ predicate on the ID field. func IDEQ(id int) predicate.MailboxFlag { - return predicate.MailboxFlag(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldID), id)) - }) + return predicate.MailboxFlag(sql.FieldEQ(FieldID, id)) } // IDNEQ applies the NEQ predicate on the ID field. func IDNEQ(id int) predicate.MailboxFlag { - return predicate.MailboxFlag(func(s *sql.Selector) { - s.Where(sql.NEQ(s.C(FieldID), id)) - }) + return predicate.MailboxFlag(sql.FieldNEQ(FieldID, id)) } // IDIn applies the In predicate on the ID field. func IDIn(ids ...int) predicate.MailboxFlag { - return predicate.MailboxFlag(func(s *sql.Selector) { - v := make([]interface{}, len(ids)) - for i := range v { - v[i] = ids[i] - } - s.Where(sql.In(s.C(FieldID), v...)) - }) + return predicate.MailboxFlag(sql.FieldIn(FieldID, ids...)) } // IDNotIn applies the NotIn predicate on the ID field. func IDNotIn(ids ...int) predicate.MailboxFlag { - return predicate.MailboxFlag(func(s *sql.Selector) { - v := make([]interface{}, len(ids)) - for i := range v { - v[i] = ids[i] - } - s.Where(sql.NotIn(s.C(FieldID), v...)) - }) + return predicate.MailboxFlag(sql.FieldNotIn(FieldID, ids...)) } // IDGT applies the GT predicate on the ID field. func IDGT(id int) predicate.MailboxFlag { - return predicate.MailboxFlag(func(s *sql.Selector) { - s.Where(sql.GT(s.C(FieldID), id)) - }) + return predicate.MailboxFlag(sql.FieldGT(FieldID, id)) } // IDGTE applies the GTE predicate on the ID field. func IDGTE(id int) predicate.MailboxFlag { - return predicate.MailboxFlag(func(s *sql.Selector) { - s.Where(sql.GTE(s.C(FieldID), id)) - }) + return predicate.MailboxFlag(sql.FieldGTE(FieldID, id)) } // IDLT applies the LT predicate on the ID field. func IDLT(id int) predicate.MailboxFlag { - return predicate.MailboxFlag(func(s *sql.Selector) { - s.Where(sql.LT(s.C(FieldID), id)) - }) + return predicate.MailboxFlag(sql.FieldLT(FieldID, id)) } // IDLTE applies the LTE predicate on the ID field. func IDLTE(id int) predicate.MailboxFlag { - return predicate.MailboxFlag(func(s *sql.Selector) { - s.Where(sql.LTE(s.C(FieldID), id)) - }) + return predicate.MailboxFlag(sql.FieldLTE(FieldID, id)) } // Value applies equality check predicate on the "Value" field. It's identical to ValueEQ. func Value(v string) predicate.MailboxFlag { - return predicate.MailboxFlag(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldValue), v)) - }) + return predicate.MailboxFlag(sql.FieldEQ(FieldValue, v)) } // ValueEQ applies the EQ predicate on the "Value" field. func ValueEQ(v string) predicate.MailboxFlag { - return predicate.MailboxFlag(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldValue), v)) - }) + return predicate.MailboxFlag(sql.FieldEQ(FieldValue, v)) } // ValueNEQ applies the NEQ predicate on the "Value" field. func ValueNEQ(v string) predicate.MailboxFlag { - return predicate.MailboxFlag(func(s *sql.Selector) { - s.Where(sql.NEQ(s.C(FieldValue), v)) - }) + return predicate.MailboxFlag(sql.FieldNEQ(FieldValue, v)) } // ValueIn applies the In predicate on the "Value" field. func ValueIn(vs ...string) predicate.MailboxFlag { - v := make([]interface{}, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.MailboxFlag(func(s *sql.Selector) { - s.Where(sql.In(s.C(FieldValue), v...)) - }) + return predicate.MailboxFlag(sql.FieldIn(FieldValue, vs...)) } // ValueNotIn applies the NotIn predicate on the "Value" field. func ValueNotIn(vs ...string) predicate.MailboxFlag { - v := make([]interface{}, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.MailboxFlag(func(s *sql.Selector) { - s.Where(sql.NotIn(s.C(FieldValue), v...)) - }) + return predicate.MailboxFlag(sql.FieldNotIn(FieldValue, vs...)) } // ValueGT applies the GT predicate on the "Value" field. func ValueGT(v string) predicate.MailboxFlag { - return predicate.MailboxFlag(func(s *sql.Selector) { - s.Where(sql.GT(s.C(FieldValue), v)) - }) + return predicate.MailboxFlag(sql.FieldGT(FieldValue, v)) } // ValueGTE applies the GTE predicate on the "Value" field. func ValueGTE(v string) predicate.MailboxFlag { - return predicate.MailboxFlag(func(s *sql.Selector) { - s.Where(sql.GTE(s.C(FieldValue), v)) - }) + return predicate.MailboxFlag(sql.FieldGTE(FieldValue, v)) } // ValueLT applies the LT predicate on the "Value" field. func ValueLT(v string) predicate.MailboxFlag { - return predicate.MailboxFlag(func(s *sql.Selector) { - s.Where(sql.LT(s.C(FieldValue), v)) - }) + return predicate.MailboxFlag(sql.FieldLT(FieldValue, v)) } // ValueLTE applies the LTE predicate on the "Value" field. func ValueLTE(v string) predicate.MailboxFlag { - return predicate.MailboxFlag(func(s *sql.Selector) { - s.Where(sql.LTE(s.C(FieldValue), v)) - }) + return predicate.MailboxFlag(sql.FieldLTE(FieldValue, v)) } // ValueContains applies the Contains predicate on the "Value" field. func ValueContains(v string) predicate.MailboxFlag { - return predicate.MailboxFlag(func(s *sql.Selector) { - s.Where(sql.Contains(s.C(FieldValue), v)) - }) + return predicate.MailboxFlag(sql.FieldContains(FieldValue, v)) } // ValueHasPrefix applies the HasPrefix predicate on the "Value" field. func ValueHasPrefix(v string) predicate.MailboxFlag { - return predicate.MailboxFlag(func(s *sql.Selector) { - s.Where(sql.HasPrefix(s.C(FieldValue), v)) - }) + return predicate.MailboxFlag(sql.FieldHasPrefix(FieldValue, v)) } // ValueHasSuffix applies the HasSuffix predicate on the "Value" field. func ValueHasSuffix(v string) predicate.MailboxFlag { - return predicate.MailboxFlag(func(s *sql.Selector) { - s.Where(sql.HasSuffix(s.C(FieldValue), v)) - }) + return predicate.MailboxFlag(sql.FieldHasSuffix(FieldValue, v)) } // ValueEqualFold applies the EqualFold predicate on the "Value" field. func ValueEqualFold(v string) predicate.MailboxFlag { - return predicate.MailboxFlag(func(s *sql.Selector) { - s.Where(sql.EqualFold(s.C(FieldValue), v)) - }) + return predicate.MailboxFlag(sql.FieldEqualFold(FieldValue, v)) } // ValueContainsFold applies the ContainsFold predicate on the "Value" field. func ValueContainsFold(v string) predicate.MailboxFlag { - return predicate.MailboxFlag(func(s *sql.Selector) { - s.Where(sql.ContainsFold(s.C(FieldValue), v)) - }) + return predicate.MailboxFlag(sql.FieldContainsFold(FieldValue, v)) } // And groups predicates with the AND operator between them. diff --git a/internal/db/ent/mailboxflag_create.go b/internal/db_impl/ent_db/internal/mailboxflag_create.go similarity index 74% rename from internal/db/ent/mailboxflag_create.go rename to internal/db_impl/ent_db/internal/mailboxflag_create.go index fa262a37..dd6fb8e0 100644 --- a/internal/db/ent/mailboxflag_create.go +++ b/internal/db_impl/ent_db/internal/mailboxflag_create.go @@ -1,6 +1,6 @@ // Code generated by ent, DO NOT EDIT. -package ent +package internal import ( "context" @@ -9,7 +9,7 @@ import ( "entgo.io/ent/dialect/sql/sqlgraph" "entgo.io/ent/schema/field" - "github.com/ProtonMail/gluon/internal/db/ent/mailboxflag" + "github.com/ProtonMail/gluon/internal/db_impl/ent_db/internal/mailboxflag" ) // MailboxFlagCreate is the builder for creating a MailboxFlag entity. @@ -32,49 +32,7 @@ func (mfc *MailboxFlagCreate) Mutation() *MailboxFlagMutation { // Save creates the MailboxFlag in the database. func (mfc *MailboxFlagCreate) Save(ctx context.Context) (*MailboxFlag, error) { - var ( - err error - node *MailboxFlag - ) - if len(mfc.hooks) == 0 { - if err = mfc.check(); err != nil { - return nil, err - } - node, err = mfc.sqlSave(ctx) - } else { - var mut Mutator = MutateFunc(func(ctx context.Context, m Mutation) (Value, error) { - mutation, ok := m.(*MailboxFlagMutation) - if !ok { - return nil, fmt.Errorf("unexpected mutation type %T", m) - } - if err = mfc.check(); err != nil { - return nil, err - } - mfc.mutation = mutation - if node, err = mfc.sqlSave(ctx); err != nil { - return nil, err - } - mutation.id = &node.ID - mutation.done = true - return node, err - }) - for i := len(mfc.hooks) - 1; i >= 0; i-- { - if mfc.hooks[i] == nil { - return nil, fmt.Errorf("ent: uninitialized hook (forgotten import ent/runtime?)") - } - mut = mfc.hooks[i](mut) - } - v, err := mut.Mutate(ctx, mfc.mutation) - if err != nil { - return nil, err - } - nv, ok := v.(*MailboxFlag) - if !ok { - return nil, fmt.Errorf("unexpected node type %T returned from MailboxFlagMutation", v) - } - node = nv - } - return node, err + return withHooks[*MailboxFlag, MailboxFlagMutation](ctx, mfc.sqlSave, mfc.mutation, mfc.hooks) } // SaveX calls Save and panics if Save returns an error. @@ -102,12 +60,15 @@ func (mfc *MailboxFlagCreate) ExecX(ctx context.Context) { // check runs all checks and user-defined validators on the builder. func (mfc *MailboxFlagCreate) check() error { if _, ok := mfc.mutation.Value(); !ok { - return &ValidationError{Name: "Value", err: errors.New(`ent: missing required field "MailboxFlag.Value"`)} + return &ValidationError{Name: "Value", err: errors.New(`internal: missing required field "MailboxFlag.Value"`)} } return nil } func (mfc *MailboxFlagCreate) sqlSave(ctx context.Context) (*MailboxFlag, error) { + if err := mfc.check(); err != nil { + return nil, err + } _node, _spec := mfc.createSpec() if err := sqlgraph.CreateNode(ctx, mfc.driver, _spec); err != nil { if sqlgraph.IsConstraintError(err) { @@ -117,26 +78,18 @@ func (mfc *MailboxFlagCreate) sqlSave(ctx context.Context) (*MailboxFlag, error) } id := _spec.ID.Value.(int64) _node.ID = int(id) + mfc.mutation.id = &_node.ID + mfc.mutation.done = true return _node, nil } func (mfc *MailboxFlagCreate) createSpec() (*MailboxFlag, *sqlgraph.CreateSpec) { var ( _node = &MailboxFlag{config: mfc.config} - _spec = &sqlgraph.CreateSpec{ - Table: mailboxflag.Table, - ID: &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Column: mailboxflag.FieldID, - }, - } + _spec = sqlgraph.NewCreateSpec(mailboxflag.Table, sqlgraph.NewFieldSpec(mailboxflag.FieldID, field.TypeInt)) ) if value, ok := mfc.mutation.Value(); ok { - _spec.Fields = append(_spec.Fields, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Value: value, - Column: mailboxflag.FieldValue, - }) + _spec.SetField(mailboxflag.FieldValue, field.TypeString, value) _node.Value = value } return _node, _spec diff --git a/internal/db/ent/mailboxflag_delete.go b/internal/db_impl/ent_db/internal/mailboxflag_delete.go similarity index 62% rename from internal/db/ent/mailboxflag_delete.go rename to internal/db_impl/ent_db/internal/mailboxflag_delete.go index 01ba933c..7fd9e0d9 100644 --- a/internal/db/ent/mailboxflag_delete.go +++ b/internal/db_impl/ent_db/internal/mailboxflag_delete.go @@ -1,16 +1,15 @@ // Code generated by ent, DO NOT EDIT. -package ent +package internal import ( "context" - "fmt" "entgo.io/ent/dialect/sql" "entgo.io/ent/dialect/sql/sqlgraph" "entgo.io/ent/schema/field" - "github.com/ProtonMail/gluon/internal/db/ent/mailboxflag" - "github.com/ProtonMail/gluon/internal/db/ent/predicate" + "github.com/ProtonMail/gluon/internal/db_impl/ent_db/internal/mailboxflag" + "github.com/ProtonMail/gluon/internal/db_impl/ent_db/internal/predicate" ) // MailboxFlagDelete is the builder for deleting a MailboxFlag entity. @@ -28,34 +27,7 @@ func (mfd *MailboxFlagDelete) Where(ps ...predicate.MailboxFlag) *MailboxFlagDel // Exec executes the deletion query and returns how many vertices were deleted. func (mfd *MailboxFlagDelete) Exec(ctx context.Context) (int, error) { - var ( - err error - affected int - ) - if len(mfd.hooks) == 0 { - affected, err = mfd.sqlExec(ctx) - } else { - var mut Mutator = MutateFunc(func(ctx context.Context, m Mutation) (Value, error) { - mutation, ok := m.(*MailboxFlagMutation) - if !ok { - return nil, fmt.Errorf("unexpected mutation type %T", m) - } - mfd.mutation = mutation - affected, err = mfd.sqlExec(ctx) - mutation.done = true - return affected, err - }) - for i := len(mfd.hooks) - 1; i >= 0; i-- { - if mfd.hooks[i] == nil { - return 0, fmt.Errorf("ent: uninitialized hook (forgotten import ent/runtime?)") - } - mut = mfd.hooks[i](mut) - } - if _, err := mut.Mutate(ctx, mfd.mutation); err != nil { - return 0, err - } - } - return affected, err + return withHooks[int, MailboxFlagMutation](ctx, mfd.sqlExec, mfd.mutation, mfd.hooks) } // ExecX is like Exec, but panics if an error occurs. @@ -68,15 +40,7 @@ func (mfd *MailboxFlagDelete) ExecX(ctx context.Context) int { } func (mfd *MailboxFlagDelete) sqlExec(ctx context.Context) (int, error) { - _spec := &sqlgraph.DeleteSpec{ - Node: &sqlgraph.NodeSpec{ - Table: mailboxflag.Table, - ID: &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Column: mailboxflag.FieldID, - }, - }, - } + _spec := sqlgraph.NewDeleteSpec(mailboxflag.Table, sqlgraph.NewFieldSpec(mailboxflag.FieldID, field.TypeInt)) if ps := mfd.mutation.predicates; len(ps) > 0 { _spec.Predicate = func(selector *sql.Selector) { for i := range ps { @@ -88,6 +52,7 @@ func (mfd *MailboxFlagDelete) sqlExec(ctx context.Context) (int, error) { if err != nil && sqlgraph.IsConstraintError(err) { err = &ConstraintError{msg: err.Error(), wrap: err} } + mfd.mutation.done = true return affected, err } @@ -96,6 +61,12 @@ type MailboxFlagDeleteOne struct { mfd *MailboxFlagDelete } +// Where appends a list predicates to the MailboxFlagDelete builder. +func (mfdo *MailboxFlagDeleteOne) Where(ps ...predicate.MailboxFlag) *MailboxFlagDeleteOne { + mfdo.mfd.mutation.Where(ps...) + return mfdo +} + // Exec executes the deletion query. func (mfdo *MailboxFlagDeleteOne) Exec(ctx context.Context) error { n, err := mfdo.mfd.Exec(ctx) @@ -111,5 +82,7 @@ func (mfdo *MailboxFlagDeleteOne) Exec(ctx context.Context) error { // ExecX is like Exec, but panics if an error occurs. func (mfdo *MailboxFlagDeleteOne) ExecX(ctx context.Context) { - mfdo.mfd.ExecX(ctx) + if err := mfdo.Exec(ctx); err != nil { + panic(err) + } } diff --git a/internal/db/ent/mailboxflag_query.go b/internal/db_impl/ent_db/internal/mailboxflag_query.go similarity index 67% rename from internal/db/ent/mailboxflag_query.go rename to internal/db_impl/ent_db/internal/mailboxflag_query.go index 0f04bbbe..bfb150d1 100644 --- a/internal/db/ent/mailboxflag_query.go +++ b/internal/db_impl/ent_db/internal/mailboxflag_query.go @@ -1,6 +1,6 @@ // Code generated by ent, DO NOT EDIT. -package ent +package internal import ( "context" @@ -10,18 +10,16 @@ import ( "entgo.io/ent/dialect/sql" "entgo.io/ent/dialect/sql/sqlgraph" "entgo.io/ent/schema/field" - "github.com/ProtonMail/gluon/internal/db/ent/mailboxflag" - "github.com/ProtonMail/gluon/internal/db/ent/predicate" + "github.com/ProtonMail/gluon/internal/db_impl/ent_db/internal/mailboxflag" + "github.com/ProtonMail/gluon/internal/db_impl/ent_db/internal/predicate" ) // MailboxFlagQuery is the builder for querying MailboxFlag entities. type MailboxFlagQuery struct { config - limit *int - offset *int - unique *bool + ctx *QueryContext order []OrderFunc - fields []string + inters []Interceptor predicates []predicate.MailboxFlag withFKs bool // intermediate query (i.e. traversal path). @@ -35,26 +33,26 @@ func (mfq *MailboxFlagQuery) Where(ps ...predicate.MailboxFlag) *MailboxFlagQuer return mfq } -// Limit adds a limit step to the query. +// Limit the number of records to be returned by this query. func (mfq *MailboxFlagQuery) Limit(limit int) *MailboxFlagQuery { - mfq.limit = &limit + mfq.ctx.Limit = &limit return mfq } -// Offset adds an offset step to the query. +// Offset to start from. func (mfq *MailboxFlagQuery) Offset(offset int) *MailboxFlagQuery { - mfq.offset = &offset + mfq.ctx.Offset = &offset return mfq } // Unique configures the query builder to filter duplicate records on query. // By default, unique is set to true, and can be disabled using this method. func (mfq *MailboxFlagQuery) Unique(unique bool) *MailboxFlagQuery { - mfq.unique = &unique + mfq.ctx.Unique = &unique return mfq } -// Order adds an order step to the query. +// Order specifies how the records should be ordered. func (mfq *MailboxFlagQuery) Order(o ...OrderFunc) *MailboxFlagQuery { mfq.order = append(mfq.order, o...) return mfq @@ -63,7 +61,7 @@ func (mfq *MailboxFlagQuery) Order(o ...OrderFunc) *MailboxFlagQuery { // First returns the first MailboxFlag entity from the query. // Returns a *NotFoundError when no MailboxFlag was found. func (mfq *MailboxFlagQuery) First(ctx context.Context) (*MailboxFlag, error) { - nodes, err := mfq.Limit(1).All(ctx) + nodes, err := mfq.Limit(1).All(setContextOp(ctx, mfq.ctx, "First")) if err != nil { return nil, err } @@ -86,7 +84,7 @@ func (mfq *MailboxFlagQuery) FirstX(ctx context.Context) *MailboxFlag { // Returns a *NotFoundError when no MailboxFlag ID was found. func (mfq *MailboxFlagQuery) FirstID(ctx context.Context) (id int, err error) { var ids []int - if ids, err = mfq.Limit(1).IDs(ctx); err != nil { + if ids, err = mfq.Limit(1).IDs(setContextOp(ctx, mfq.ctx, "FirstID")); err != nil { return } if len(ids) == 0 { @@ -109,7 +107,7 @@ func (mfq *MailboxFlagQuery) FirstIDX(ctx context.Context) int { // Returns a *NotSingularError when more than one MailboxFlag entity is found. // Returns a *NotFoundError when no MailboxFlag entities are found. func (mfq *MailboxFlagQuery) Only(ctx context.Context) (*MailboxFlag, error) { - nodes, err := mfq.Limit(2).All(ctx) + nodes, err := mfq.Limit(2).All(setContextOp(ctx, mfq.ctx, "Only")) if err != nil { return nil, err } @@ -137,7 +135,7 @@ func (mfq *MailboxFlagQuery) OnlyX(ctx context.Context) *MailboxFlag { // Returns a *NotFoundError when no entities are found. func (mfq *MailboxFlagQuery) OnlyID(ctx context.Context) (id int, err error) { var ids []int - if ids, err = mfq.Limit(2).IDs(ctx); err != nil { + if ids, err = mfq.Limit(2).IDs(setContextOp(ctx, mfq.ctx, "OnlyID")); err != nil { return } switch len(ids) { @@ -162,10 +160,12 @@ func (mfq *MailboxFlagQuery) OnlyIDX(ctx context.Context) int { // All executes the query and returns a list of MailboxFlags. func (mfq *MailboxFlagQuery) All(ctx context.Context) ([]*MailboxFlag, error) { + ctx = setContextOp(ctx, mfq.ctx, "All") if err := mfq.prepareQuery(ctx); err != nil { return nil, err } - return mfq.sqlAll(ctx) + qr := querierAll[[]*MailboxFlag, *MailboxFlagQuery]() + return withInterceptors[[]*MailboxFlag](ctx, mfq, qr, mfq.inters) } // AllX is like All, but panics if an error occurs. @@ -178,9 +178,12 @@ func (mfq *MailboxFlagQuery) AllX(ctx context.Context) []*MailboxFlag { } // IDs executes the query and returns a list of MailboxFlag IDs. -func (mfq *MailboxFlagQuery) IDs(ctx context.Context) ([]int, error) { - var ids []int - if err := mfq.Select(mailboxflag.FieldID).Scan(ctx, &ids); err != nil { +func (mfq *MailboxFlagQuery) IDs(ctx context.Context) (ids []int, err error) { + if mfq.ctx.Unique == nil && mfq.path != nil { + mfq.Unique(true) + } + ctx = setContextOp(ctx, mfq.ctx, "IDs") + if err = mfq.Select(mailboxflag.FieldID).Scan(ctx, &ids); err != nil { return nil, err } return ids, nil @@ -197,10 +200,11 @@ func (mfq *MailboxFlagQuery) IDsX(ctx context.Context) []int { // Count returns the count of the given query. func (mfq *MailboxFlagQuery) Count(ctx context.Context) (int, error) { + ctx = setContextOp(ctx, mfq.ctx, "Count") if err := mfq.prepareQuery(ctx); err != nil { return 0, err } - return mfq.sqlCount(ctx) + return withInterceptors[int](ctx, mfq, querierCount[*MailboxFlagQuery](), mfq.inters) } // CountX is like Count, but panics if an error occurs. @@ -214,10 +218,15 @@ func (mfq *MailboxFlagQuery) CountX(ctx context.Context) int { // Exist returns true if the query has elements in the graph. func (mfq *MailboxFlagQuery) Exist(ctx context.Context) (bool, error) { - if err := mfq.prepareQuery(ctx); err != nil { - return false, err + ctx = setContextOp(ctx, mfq.ctx, "Exist") + switch _, err := mfq.FirstID(ctx); { + case IsNotFound(err): + return false, nil + case err != nil: + return false, fmt.Errorf("internal: check existence: %w", err) + default: + return true, nil } - return mfq.sqlExist(ctx) } // ExistX is like Exist, but panics if an error occurs. @@ -237,14 +246,13 @@ func (mfq *MailboxFlagQuery) Clone() *MailboxFlagQuery { } return &MailboxFlagQuery{ config: mfq.config, - limit: mfq.limit, - offset: mfq.offset, + ctx: mfq.ctx.Clone(), order: append([]OrderFunc{}, mfq.order...), + inters: append([]Interceptor{}, mfq.inters...), predicates: append([]predicate.MailboxFlag{}, mfq.predicates...), // clone intermediate query. - sql: mfq.sql.Clone(), - path: mfq.path, - unique: mfq.unique, + sql: mfq.sql.Clone(), + path: mfq.path, } } @@ -260,19 +268,14 @@ func (mfq *MailboxFlagQuery) Clone() *MailboxFlagQuery { // // client.MailboxFlag.Query(). // GroupBy(mailboxflag.FieldValue). -// Aggregate(ent.Count()). +// Aggregate(internal.Count()). // Scan(ctx, &v) func (mfq *MailboxFlagQuery) GroupBy(field string, fields ...string) *MailboxFlagGroupBy { - grbuild := &MailboxFlagGroupBy{config: mfq.config} - grbuild.fields = append([]string{field}, fields...) - grbuild.path = func(ctx context.Context) (prev *sql.Selector, err error) { - if err := mfq.prepareQuery(ctx); err != nil { - return nil, err - } - return mfq.sqlQuery(ctx), nil - } + mfq.ctx.Fields = append([]string{field}, fields...) + grbuild := &MailboxFlagGroupBy{build: mfq} + grbuild.flds = &mfq.ctx.Fields grbuild.label = mailboxflag.Label - grbuild.flds, grbuild.scan = &grbuild.fields, grbuild.Scan + grbuild.scan = grbuild.Scan return grbuild } @@ -289,17 +292,32 @@ func (mfq *MailboxFlagQuery) GroupBy(field string, fields ...string) *MailboxFla // Select(mailboxflag.FieldValue). // Scan(ctx, &v) func (mfq *MailboxFlagQuery) Select(fields ...string) *MailboxFlagSelect { - mfq.fields = append(mfq.fields, fields...) - selbuild := &MailboxFlagSelect{MailboxFlagQuery: mfq} - selbuild.label = mailboxflag.Label - selbuild.flds, selbuild.scan = &mfq.fields, selbuild.Scan - return selbuild + mfq.ctx.Fields = append(mfq.ctx.Fields, fields...) + sbuild := &MailboxFlagSelect{MailboxFlagQuery: mfq} + sbuild.label = mailboxflag.Label + sbuild.flds, sbuild.scan = &mfq.ctx.Fields, sbuild.Scan + return sbuild +} + +// Aggregate returns a MailboxFlagSelect configured with the given aggregations. +func (mfq *MailboxFlagQuery) Aggregate(fns ...AggregateFunc) *MailboxFlagSelect { + return mfq.Select().Aggregate(fns...) } func (mfq *MailboxFlagQuery) prepareQuery(ctx context.Context) error { - for _, f := range mfq.fields { + for _, inter := range mfq.inters { + if inter == nil { + return fmt.Errorf("internal: uninitialized interceptor (forgotten import internal/runtime?)") + } + if trv, ok := inter.(Traverser); ok { + if err := trv.Traverse(ctx, mfq); err != nil { + return err + } + } + } + for _, f := range mfq.ctx.Fields { if !mailboxflag.ValidColumn(f) { - return &ValidationError{Name: f, err: fmt.Errorf("ent: invalid field %q for query", f)} + return &ValidationError{Name: f, err: fmt.Errorf("internal: invalid field %q for query", f)} } } if mfq.path != nil { @@ -321,10 +339,10 @@ func (mfq *MailboxFlagQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([] if withFKs { _spec.Node.Columns = append(_spec.Node.Columns, mailboxflag.ForeignKeys...) } - _spec.ScanValues = func(columns []string) ([]interface{}, error) { + _spec.ScanValues = func(columns []string) ([]any, error) { return (*MailboxFlag).scanValues(nil, columns) } - _spec.Assign = func(columns []string, values []interface{}) error { + _spec.Assign = func(columns []string, values []any) error { node := &MailboxFlag{config: mfq.config} nodes = append(nodes, node) return node.assignValues(columns, values) @@ -343,38 +361,22 @@ func (mfq *MailboxFlagQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([] func (mfq *MailboxFlagQuery) sqlCount(ctx context.Context) (int, error) { _spec := mfq.querySpec() - _spec.Node.Columns = mfq.fields - if len(mfq.fields) > 0 { - _spec.Unique = mfq.unique != nil && *mfq.unique + _spec.Node.Columns = mfq.ctx.Fields + if len(mfq.ctx.Fields) > 0 { + _spec.Unique = mfq.ctx.Unique != nil && *mfq.ctx.Unique } return sqlgraph.CountNodes(ctx, mfq.driver, _spec) } -func (mfq *MailboxFlagQuery) sqlExist(ctx context.Context) (bool, error) { - n, err := mfq.sqlCount(ctx) - if err != nil { - return false, fmt.Errorf("ent: check existence: %w", err) - } - return n > 0, nil -} - func (mfq *MailboxFlagQuery) querySpec() *sqlgraph.QuerySpec { - _spec := &sqlgraph.QuerySpec{ - Node: &sqlgraph.NodeSpec{ - Table: mailboxflag.Table, - Columns: mailboxflag.Columns, - ID: &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Column: mailboxflag.FieldID, - }, - }, - From: mfq.sql, - Unique: true, - } - if unique := mfq.unique; unique != nil { + _spec := sqlgraph.NewQuerySpec(mailboxflag.Table, mailboxflag.Columns, sqlgraph.NewFieldSpec(mailboxflag.FieldID, field.TypeInt)) + _spec.From = mfq.sql + if unique := mfq.ctx.Unique; unique != nil { _spec.Unique = *unique + } else if mfq.path != nil { + _spec.Unique = true } - if fields := mfq.fields; len(fields) > 0 { + if fields := mfq.ctx.Fields; len(fields) > 0 { _spec.Node.Columns = make([]string, 0, len(fields)) _spec.Node.Columns = append(_spec.Node.Columns, mailboxflag.FieldID) for i := range fields { @@ -390,10 +392,10 @@ func (mfq *MailboxFlagQuery) querySpec() *sqlgraph.QuerySpec { } } } - if limit := mfq.limit; limit != nil { + if limit := mfq.ctx.Limit; limit != nil { _spec.Limit = *limit } - if offset := mfq.offset; offset != nil { + if offset := mfq.ctx.Offset; offset != nil { _spec.Offset = *offset } if ps := mfq.order; len(ps) > 0 { @@ -409,7 +411,7 @@ func (mfq *MailboxFlagQuery) querySpec() *sqlgraph.QuerySpec { func (mfq *MailboxFlagQuery) sqlQuery(ctx context.Context) *sql.Selector { builder := sql.Dialect(mfq.driver.Dialect()) t1 := builder.Table(mailboxflag.Table) - columns := mfq.fields + columns := mfq.ctx.Fields if len(columns) == 0 { columns = mailboxflag.Columns } @@ -418,7 +420,7 @@ func (mfq *MailboxFlagQuery) sqlQuery(ctx context.Context) *sql.Selector { selector = mfq.sql selector.Select(selector.Columns(columns...)...) } - if mfq.unique != nil && *mfq.unique { + if mfq.ctx.Unique != nil && *mfq.ctx.Unique { selector.Distinct() } for _, p := range mfq.predicates { @@ -427,12 +429,12 @@ func (mfq *MailboxFlagQuery) sqlQuery(ctx context.Context) *sql.Selector { for _, p := range mfq.order { p(selector) } - if offset := mfq.offset; offset != nil { + if offset := mfq.ctx.Offset; offset != nil { // limit is mandatory for offset clause. We start // with default value, and override it below if needed. selector.Offset(*offset).Limit(math.MaxInt32) } - if limit := mfq.limit; limit != nil { + if limit := mfq.ctx.Limit; limit != nil { selector.Limit(*limit) } return selector @@ -440,13 +442,8 @@ func (mfq *MailboxFlagQuery) sqlQuery(ctx context.Context) *sql.Selector { // MailboxFlagGroupBy is the group-by builder for MailboxFlag entities. type MailboxFlagGroupBy struct { - config selector - fields []string - fns []AggregateFunc - // intermediate query (i.e. traversal path). - sql *sql.Selector - path func(context.Context) (*sql.Selector, error) + build *MailboxFlagQuery } // Aggregate adds the given aggregation functions to the group-by query. @@ -455,74 +452,77 @@ func (mfgb *MailboxFlagGroupBy) Aggregate(fns ...AggregateFunc) *MailboxFlagGrou return mfgb } -// Scan applies the group-by query and scans the result into the given value. -func (mfgb *MailboxFlagGroupBy) Scan(ctx context.Context, v interface{}) error { - query, err := mfgb.path(ctx) - if err != nil { +// Scan applies the selector query and scans the result into the given value. +func (mfgb *MailboxFlagGroupBy) Scan(ctx context.Context, v any) error { + ctx = setContextOp(ctx, mfgb.build.ctx, "GroupBy") + if err := mfgb.build.prepareQuery(ctx); err != nil { return err } - mfgb.sql = query - return mfgb.sqlScan(ctx, v) + return scanWithInterceptors[*MailboxFlagQuery, *MailboxFlagGroupBy](ctx, mfgb.build, mfgb, mfgb.build.inters, v) } -func (mfgb *MailboxFlagGroupBy) sqlScan(ctx context.Context, v interface{}) error { - for _, f := range mfgb.fields { - if !mailboxflag.ValidColumn(f) { - return &ValidationError{Name: f, err: fmt.Errorf("invalid field %q for group-by", f)} +func (mfgb *MailboxFlagGroupBy) sqlScan(ctx context.Context, root *MailboxFlagQuery, v any) error { + selector := root.sqlQuery(ctx).Select() + aggregation := make([]string, 0, len(mfgb.fns)) + for _, fn := range mfgb.fns { + aggregation = append(aggregation, fn(selector)) + } + if len(selector.SelectedColumns()) == 0 { + columns := make([]string, 0, len(*mfgb.flds)+len(mfgb.fns)) + for _, f := range *mfgb.flds { + columns = append(columns, selector.C(f)) } + columns = append(columns, aggregation...) + selector.Select(columns...) } - selector := mfgb.sqlQuery() + selector.GroupBy(selector.Columns(*mfgb.flds...)...) if err := selector.Err(); err != nil { return err } rows := &sql.Rows{} query, args := selector.Query() - if err := mfgb.driver.Query(ctx, query, args, rows); err != nil { + if err := mfgb.build.driver.Query(ctx, query, args, rows); err != nil { return err } defer rows.Close() return sql.ScanSlice(rows, v) } -func (mfgb *MailboxFlagGroupBy) sqlQuery() *sql.Selector { - selector := mfgb.sql.Select() - aggregation := make([]string, 0, len(mfgb.fns)) - for _, fn := range mfgb.fns { - aggregation = append(aggregation, fn(selector)) - } - // If no columns were selected in a custom aggregation function, the default - // selection is the fields used for "group-by", and the aggregation functions. - if len(selector.SelectedColumns()) == 0 { - columns := make([]string, 0, len(mfgb.fields)+len(mfgb.fns)) - for _, f := range mfgb.fields { - columns = append(columns, selector.C(f)) - } - columns = append(columns, aggregation...) - selector.Select(columns...) - } - return selector.GroupBy(selector.Columns(mfgb.fields...)...) -} - // MailboxFlagSelect is the builder for selecting fields of MailboxFlag entities. type MailboxFlagSelect struct { *MailboxFlagQuery selector - // intermediate query (i.e. traversal path). - sql *sql.Selector +} + +// Aggregate adds the given aggregation functions to the selector query. +func (mfs *MailboxFlagSelect) Aggregate(fns ...AggregateFunc) *MailboxFlagSelect { + mfs.fns = append(mfs.fns, fns...) + return mfs } // Scan applies the selector query and scans the result into the given value. -func (mfs *MailboxFlagSelect) Scan(ctx context.Context, v interface{}) error { +func (mfs *MailboxFlagSelect) Scan(ctx context.Context, v any) error { + ctx = setContextOp(ctx, mfs.ctx, "Select") if err := mfs.prepareQuery(ctx); err != nil { return err } - mfs.sql = mfs.MailboxFlagQuery.sqlQuery(ctx) - return mfs.sqlScan(ctx, v) + return scanWithInterceptors[*MailboxFlagQuery, *MailboxFlagSelect](ctx, mfs.MailboxFlagQuery, mfs, mfs.inters, v) } -func (mfs *MailboxFlagSelect) sqlScan(ctx context.Context, v interface{}) error { +func (mfs *MailboxFlagSelect) sqlScan(ctx context.Context, root *MailboxFlagQuery, v any) error { + selector := root.sqlQuery(ctx) + aggregation := make([]string, 0, len(mfs.fns)) + for _, fn := range mfs.fns { + aggregation = append(aggregation, fn(selector)) + } + switch n := len(*mfs.selector.flds); { + case n == 0 && len(aggregation) > 0: + selector.Select(aggregation...) + case n != 0 && len(aggregation) > 0: + selector.AppendSelect(aggregation...) + } rows := &sql.Rows{} - query, args := mfs.sql.Query() + query, args := selector.Query() if err := mfs.driver.Query(ctx, query, args, rows); err != nil { return err } diff --git a/internal/db/ent/mailboxflag_update.go b/internal/db_impl/ent_db/internal/mailboxflag_update.go similarity index 63% rename from internal/db/ent/mailboxflag_update.go rename to internal/db_impl/ent_db/internal/mailboxflag_update.go index 2ce94f03..5c4d5496 100644 --- a/internal/db/ent/mailboxflag_update.go +++ b/internal/db_impl/ent_db/internal/mailboxflag_update.go @@ -1,6 +1,6 @@ // Code generated by ent, DO NOT EDIT. -package ent +package internal import ( "context" @@ -10,8 +10,8 @@ import ( "entgo.io/ent/dialect/sql" "entgo.io/ent/dialect/sql/sqlgraph" "entgo.io/ent/schema/field" - "github.com/ProtonMail/gluon/internal/db/ent/mailboxflag" - "github.com/ProtonMail/gluon/internal/db/ent/predicate" + "github.com/ProtonMail/gluon/internal/db_impl/ent_db/internal/mailboxflag" + "github.com/ProtonMail/gluon/internal/db_impl/ent_db/internal/predicate" ) // MailboxFlagUpdate is the builder for updating MailboxFlag entities. @@ -40,34 +40,7 @@ func (mfu *MailboxFlagUpdate) Mutation() *MailboxFlagMutation { // Save executes the query and returns the number of nodes affected by the update operation. func (mfu *MailboxFlagUpdate) Save(ctx context.Context) (int, error) { - var ( - err error - affected int - ) - if len(mfu.hooks) == 0 { - affected, err = mfu.sqlSave(ctx) - } else { - var mut Mutator = MutateFunc(func(ctx context.Context, m Mutation) (Value, error) { - mutation, ok := m.(*MailboxFlagMutation) - if !ok { - return nil, fmt.Errorf("unexpected mutation type %T", m) - } - mfu.mutation = mutation - affected, err = mfu.sqlSave(ctx) - mutation.done = true - return affected, err - }) - for i := len(mfu.hooks) - 1; i >= 0; i-- { - if mfu.hooks[i] == nil { - return 0, fmt.Errorf("ent: uninitialized hook (forgotten import ent/runtime?)") - } - mut = mfu.hooks[i](mut) - } - if _, err := mut.Mutate(ctx, mfu.mutation); err != nil { - return 0, err - } - } - return affected, err + return withHooks[int, MailboxFlagMutation](ctx, mfu.sqlSave, mfu.mutation, mfu.hooks) } // SaveX is like Save, but panics if an error occurs. @@ -93,16 +66,7 @@ func (mfu *MailboxFlagUpdate) ExecX(ctx context.Context) { } func (mfu *MailboxFlagUpdate) sqlSave(ctx context.Context) (n int, err error) { - _spec := &sqlgraph.UpdateSpec{ - Node: &sqlgraph.NodeSpec{ - Table: mailboxflag.Table, - Columns: mailboxflag.Columns, - ID: &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Column: mailboxflag.FieldID, - }, - }, - } + _spec := sqlgraph.NewUpdateSpec(mailboxflag.Table, mailboxflag.Columns, sqlgraph.NewFieldSpec(mailboxflag.FieldID, field.TypeInt)) if ps := mfu.mutation.predicates; len(ps) > 0 { _spec.Predicate = func(selector *sql.Selector) { for i := range ps { @@ -111,11 +75,7 @@ func (mfu *MailboxFlagUpdate) sqlSave(ctx context.Context) (n int, err error) { } } if value, ok := mfu.mutation.Value(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Value: value, - Column: mailboxflag.FieldValue, - }) + _spec.SetField(mailboxflag.FieldValue, field.TypeString, value) } if n, err = sqlgraph.UpdateNodes(ctx, mfu.driver, _spec); err != nil { if _, ok := err.(*sqlgraph.NotFoundError); ok { @@ -125,6 +85,7 @@ func (mfu *MailboxFlagUpdate) sqlSave(ctx context.Context) (n int, err error) { } return 0, err } + mfu.mutation.done = true return n, nil } @@ -147,6 +108,12 @@ func (mfuo *MailboxFlagUpdateOne) Mutation() *MailboxFlagMutation { return mfuo.mutation } +// Where appends a list predicates to the MailboxFlagUpdate builder. +func (mfuo *MailboxFlagUpdateOne) Where(ps ...predicate.MailboxFlag) *MailboxFlagUpdateOne { + mfuo.mutation.Where(ps...) + return mfuo +} + // Select allows selecting one or more fields (columns) of the returned entity. // The default is selecting all fields defined in the entity schema. func (mfuo *MailboxFlagUpdateOne) Select(field string, fields ...string) *MailboxFlagUpdateOne { @@ -156,40 +123,7 @@ func (mfuo *MailboxFlagUpdateOne) Select(field string, fields ...string) *Mailbo // Save executes the query and returns the updated MailboxFlag entity. func (mfuo *MailboxFlagUpdateOne) Save(ctx context.Context) (*MailboxFlag, error) { - var ( - err error - node *MailboxFlag - ) - if len(mfuo.hooks) == 0 { - node, err = mfuo.sqlSave(ctx) - } else { - var mut Mutator = MutateFunc(func(ctx context.Context, m Mutation) (Value, error) { - mutation, ok := m.(*MailboxFlagMutation) - if !ok { - return nil, fmt.Errorf("unexpected mutation type %T", m) - } - mfuo.mutation = mutation - node, err = mfuo.sqlSave(ctx) - mutation.done = true - return node, err - }) - for i := len(mfuo.hooks) - 1; i >= 0; i-- { - if mfuo.hooks[i] == nil { - return nil, fmt.Errorf("ent: uninitialized hook (forgotten import ent/runtime?)") - } - mut = mfuo.hooks[i](mut) - } - v, err := mut.Mutate(ctx, mfuo.mutation) - if err != nil { - return nil, err - } - nv, ok := v.(*MailboxFlag) - if !ok { - return nil, fmt.Errorf("unexpected node type %T returned from MailboxFlagMutation", v) - } - node = nv - } - return node, err + return withHooks[*MailboxFlag, MailboxFlagMutation](ctx, mfuo.sqlSave, mfuo.mutation, mfuo.hooks) } // SaveX is like Save, but panics if an error occurs. @@ -215,19 +149,10 @@ func (mfuo *MailboxFlagUpdateOne) ExecX(ctx context.Context) { } func (mfuo *MailboxFlagUpdateOne) sqlSave(ctx context.Context) (_node *MailboxFlag, err error) { - _spec := &sqlgraph.UpdateSpec{ - Node: &sqlgraph.NodeSpec{ - Table: mailboxflag.Table, - Columns: mailboxflag.Columns, - ID: &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Column: mailboxflag.FieldID, - }, - }, - } + _spec := sqlgraph.NewUpdateSpec(mailboxflag.Table, mailboxflag.Columns, sqlgraph.NewFieldSpec(mailboxflag.FieldID, field.TypeInt)) id, ok := mfuo.mutation.ID() if !ok { - return nil, &ValidationError{Name: "id", err: errors.New(`ent: missing "MailboxFlag.id" for update`)} + return nil, &ValidationError{Name: "id", err: errors.New(`internal: missing "MailboxFlag.id" for update`)} } _spec.Node.ID.Value = id if fields := mfuo.fields; len(fields) > 0 { @@ -235,7 +160,7 @@ func (mfuo *MailboxFlagUpdateOne) sqlSave(ctx context.Context) (_node *MailboxFl _spec.Node.Columns = append(_spec.Node.Columns, mailboxflag.FieldID) for _, f := range fields { if !mailboxflag.ValidColumn(f) { - return nil, &ValidationError{Name: f, err: fmt.Errorf("ent: invalid field %q for query", f)} + return nil, &ValidationError{Name: f, err: fmt.Errorf("internal: invalid field %q for query", f)} } if f != mailboxflag.FieldID { _spec.Node.Columns = append(_spec.Node.Columns, f) @@ -250,11 +175,7 @@ func (mfuo *MailboxFlagUpdateOne) sqlSave(ctx context.Context) (_node *MailboxFl } } if value, ok := mfuo.mutation.Value(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Value: value, - Column: mailboxflag.FieldValue, - }) + _spec.SetField(mailboxflag.FieldValue, field.TypeString, value) } _node = &MailboxFlag{config: mfuo.config} _spec.Assign = _node.assignValues @@ -267,5 +188,6 @@ func (mfuo *MailboxFlagUpdateOne) sqlSave(ctx context.Context) (_node *MailboxFl } return nil, err } + mfuo.mutation.done = true return _node, nil } diff --git a/internal/db/ent/mailboxpermflag.go b/internal/db_impl/ent_db/internal/mailboxpermflag.go similarity index 87% rename from internal/db/ent/mailboxpermflag.go rename to internal/db_impl/ent_db/internal/mailboxpermflag.go index 1c25ac63..a67eb63a 100644 --- a/internal/db/ent/mailboxpermflag.go +++ b/internal/db_impl/ent_db/internal/mailboxpermflag.go @@ -1,6 +1,6 @@ // Code generated by ent, DO NOT EDIT. -package ent +package internal import ( "fmt" @@ -8,7 +8,7 @@ import ( "entgo.io/ent/dialect/sql" "github.com/ProtonMail/gluon/imap" - "github.com/ProtonMail/gluon/internal/db/ent/mailboxpermflag" + "github.com/ProtonMail/gluon/internal/db_impl/ent_db/internal/mailboxpermflag" ) // MailboxPermFlag is the model entity for the MailboxPermFlag schema. @@ -22,8 +22,8 @@ type MailboxPermFlag struct { } // scanValues returns the types for scanning values from sql.Rows. -func (*MailboxPermFlag) scanValues(columns []string) ([]interface{}, error) { - values := make([]interface{}, len(columns)) +func (*MailboxPermFlag) scanValues(columns []string) ([]any, error) { + values := make([]any, len(columns)) for i := range columns { switch columns[i] { case mailboxpermflag.FieldID: @@ -41,7 +41,7 @@ func (*MailboxPermFlag) scanValues(columns []string) ([]interface{}, error) { // assignValues assigns the values that were returned from sql.Rows (after scanning) // to the MailboxPermFlag fields. -func (mpf *MailboxPermFlag) assignValues(columns []string, values []interface{}) error { +func (mpf *MailboxPermFlag) assignValues(columns []string, values []any) error { if m, n := len(values), len(columns); m < n { return fmt.Errorf("mismatch number of scan values: %d != %d", m, n) } @@ -75,7 +75,7 @@ func (mpf *MailboxPermFlag) assignValues(columns []string, values []interface{}) // Note that you need to call MailboxPermFlag.Unwrap() before calling this method if this MailboxPermFlag // was returned from a transaction, and the transaction was committed or rolled back. func (mpf *MailboxPermFlag) Update() *MailboxPermFlagUpdateOne { - return (&MailboxPermFlagClient{config: mpf.config}).UpdateOne(mpf) + return NewMailboxPermFlagClient(mpf.config).UpdateOne(mpf) } // Unwrap unwraps the MailboxPermFlag entity that was returned from a transaction after it was closed, @@ -83,7 +83,7 @@ func (mpf *MailboxPermFlag) Update() *MailboxPermFlagUpdateOne { func (mpf *MailboxPermFlag) Unwrap() *MailboxPermFlag { _tx, ok := mpf.config.driver.(*txDriver) if !ok { - panic("ent: MailboxPermFlag is not a transactional entity") + panic("internal: MailboxPermFlag is not a transactional entity") } mpf.config.driver = _tx.drv return mpf @@ -102,9 +102,3 @@ func (mpf *MailboxPermFlag) String() string { // MailboxPermFlags is a parsable slice of MailboxPermFlag. type MailboxPermFlags []*MailboxPermFlag - -func (mpf MailboxPermFlags) config(cfg config) { - for _i := range mpf { - mpf[_i].config = cfg - } -} diff --git a/internal/db/ent/mailboxpermflag/mailboxpermflag.go b/internal/db_impl/ent_db/internal/mailboxpermflag/mailboxpermflag.go similarity index 100% rename from internal/db/ent/mailboxpermflag/mailboxpermflag.go rename to internal/db_impl/ent_db/internal/mailboxpermflag/mailboxpermflag.go diff --git a/internal/db/ent/mailboxpermflag/where.go b/internal/db_impl/ent_db/internal/mailboxpermflag/where.go similarity index 56% rename from internal/db/ent/mailboxpermflag/where.go rename to internal/db_impl/ent_db/internal/mailboxpermflag/where.go index 035aec0a..6d959434 100644 --- a/internal/db/ent/mailboxpermflag/where.go +++ b/internal/db_impl/ent_db/internal/mailboxpermflag/where.go @@ -4,184 +4,122 @@ package mailboxpermflag import ( "entgo.io/ent/dialect/sql" - "github.com/ProtonMail/gluon/internal/db/ent/predicate" + "github.com/ProtonMail/gluon/internal/db_impl/ent_db/internal/predicate" ) // ID filters vertices based on their ID field. func ID(id int) predicate.MailboxPermFlag { - return predicate.MailboxPermFlag(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldID), id)) - }) + return predicate.MailboxPermFlag(sql.FieldEQ(FieldID, id)) } // IDEQ applies the EQ predicate on the ID field. func IDEQ(id int) predicate.MailboxPermFlag { - return predicate.MailboxPermFlag(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldID), id)) - }) + return predicate.MailboxPermFlag(sql.FieldEQ(FieldID, id)) } // IDNEQ applies the NEQ predicate on the ID field. func IDNEQ(id int) predicate.MailboxPermFlag { - return predicate.MailboxPermFlag(func(s *sql.Selector) { - s.Where(sql.NEQ(s.C(FieldID), id)) - }) + return predicate.MailboxPermFlag(sql.FieldNEQ(FieldID, id)) } // IDIn applies the In predicate on the ID field. func IDIn(ids ...int) predicate.MailboxPermFlag { - return predicate.MailboxPermFlag(func(s *sql.Selector) { - v := make([]interface{}, len(ids)) - for i := range v { - v[i] = ids[i] - } - s.Where(sql.In(s.C(FieldID), v...)) - }) + return predicate.MailboxPermFlag(sql.FieldIn(FieldID, ids...)) } // IDNotIn applies the NotIn predicate on the ID field. func IDNotIn(ids ...int) predicate.MailboxPermFlag { - return predicate.MailboxPermFlag(func(s *sql.Selector) { - v := make([]interface{}, len(ids)) - for i := range v { - v[i] = ids[i] - } - s.Where(sql.NotIn(s.C(FieldID), v...)) - }) + return predicate.MailboxPermFlag(sql.FieldNotIn(FieldID, ids...)) } // IDGT applies the GT predicate on the ID field. func IDGT(id int) predicate.MailboxPermFlag { - return predicate.MailboxPermFlag(func(s *sql.Selector) { - s.Where(sql.GT(s.C(FieldID), id)) - }) + return predicate.MailboxPermFlag(sql.FieldGT(FieldID, id)) } // IDGTE applies the GTE predicate on the ID field. func IDGTE(id int) predicate.MailboxPermFlag { - return predicate.MailboxPermFlag(func(s *sql.Selector) { - s.Where(sql.GTE(s.C(FieldID), id)) - }) + return predicate.MailboxPermFlag(sql.FieldGTE(FieldID, id)) } // IDLT applies the LT predicate on the ID field. func IDLT(id int) predicate.MailboxPermFlag { - return predicate.MailboxPermFlag(func(s *sql.Selector) { - s.Where(sql.LT(s.C(FieldID), id)) - }) + return predicate.MailboxPermFlag(sql.FieldLT(FieldID, id)) } // IDLTE applies the LTE predicate on the ID field. func IDLTE(id int) predicate.MailboxPermFlag { - return predicate.MailboxPermFlag(func(s *sql.Selector) { - s.Where(sql.LTE(s.C(FieldID), id)) - }) + return predicate.MailboxPermFlag(sql.FieldLTE(FieldID, id)) } // Value applies equality check predicate on the "Value" field. It's identical to ValueEQ. func Value(v string) predicate.MailboxPermFlag { - return predicate.MailboxPermFlag(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldValue), v)) - }) + return predicate.MailboxPermFlag(sql.FieldEQ(FieldValue, v)) } // ValueEQ applies the EQ predicate on the "Value" field. func ValueEQ(v string) predicate.MailboxPermFlag { - return predicate.MailboxPermFlag(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldValue), v)) - }) + return predicate.MailboxPermFlag(sql.FieldEQ(FieldValue, v)) } // ValueNEQ applies the NEQ predicate on the "Value" field. func ValueNEQ(v string) predicate.MailboxPermFlag { - return predicate.MailboxPermFlag(func(s *sql.Selector) { - s.Where(sql.NEQ(s.C(FieldValue), v)) - }) + return predicate.MailboxPermFlag(sql.FieldNEQ(FieldValue, v)) } // ValueIn applies the In predicate on the "Value" field. func ValueIn(vs ...string) predicate.MailboxPermFlag { - v := make([]interface{}, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.MailboxPermFlag(func(s *sql.Selector) { - s.Where(sql.In(s.C(FieldValue), v...)) - }) + return predicate.MailboxPermFlag(sql.FieldIn(FieldValue, vs...)) } // ValueNotIn applies the NotIn predicate on the "Value" field. func ValueNotIn(vs ...string) predicate.MailboxPermFlag { - v := make([]interface{}, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.MailboxPermFlag(func(s *sql.Selector) { - s.Where(sql.NotIn(s.C(FieldValue), v...)) - }) + return predicate.MailboxPermFlag(sql.FieldNotIn(FieldValue, vs...)) } // ValueGT applies the GT predicate on the "Value" field. func ValueGT(v string) predicate.MailboxPermFlag { - return predicate.MailboxPermFlag(func(s *sql.Selector) { - s.Where(sql.GT(s.C(FieldValue), v)) - }) + return predicate.MailboxPermFlag(sql.FieldGT(FieldValue, v)) } // ValueGTE applies the GTE predicate on the "Value" field. func ValueGTE(v string) predicate.MailboxPermFlag { - return predicate.MailboxPermFlag(func(s *sql.Selector) { - s.Where(sql.GTE(s.C(FieldValue), v)) - }) + return predicate.MailboxPermFlag(sql.FieldGTE(FieldValue, v)) } // ValueLT applies the LT predicate on the "Value" field. func ValueLT(v string) predicate.MailboxPermFlag { - return predicate.MailboxPermFlag(func(s *sql.Selector) { - s.Where(sql.LT(s.C(FieldValue), v)) - }) + return predicate.MailboxPermFlag(sql.FieldLT(FieldValue, v)) } // ValueLTE applies the LTE predicate on the "Value" field. func ValueLTE(v string) predicate.MailboxPermFlag { - return predicate.MailboxPermFlag(func(s *sql.Selector) { - s.Where(sql.LTE(s.C(FieldValue), v)) - }) + return predicate.MailboxPermFlag(sql.FieldLTE(FieldValue, v)) } // ValueContains applies the Contains predicate on the "Value" field. func ValueContains(v string) predicate.MailboxPermFlag { - return predicate.MailboxPermFlag(func(s *sql.Selector) { - s.Where(sql.Contains(s.C(FieldValue), v)) - }) + return predicate.MailboxPermFlag(sql.FieldContains(FieldValue, v)) } // ValueHasPrefix applies the HasPrefix predicate on the "Value" field. func ValueHasPrefix(v string) predicate.MailboxPermFlag { - return predicate.MailboxPermFlag(func(s *sql.Selector) { - s.Where(sql.HasPrefix(s.C(FieldValue), v)) - }) + return predicate.MailboxPermFlag(sql.FieldHasPrefix(FieldValue, v)) } // ValueHasSuffix applies the HasSuffix predicate on the "Value" field. func ValueHasSuffix(v string) predicate.MailboxPermFlag { - return predicate.MailboxPermFlag(func(s *sql.Selector) { - s.Where(sql.HasSuffix(s.C(FieldValue), v)) - }) + return predicate.MailboxPermFlag(sql.FieldHasSuffix(FieldValue, v)) } // ValueEqualFold applies the EqualFold predicate on the "Value" field. func ValueEqualFold(v string) predicate.MailboxPermFlag { - return predicate.MailboxPermFlag(func(s *sql.Selector) { - s.Where(sql.EqualFold(s.C(FieldValue), v)) - }) + return predicate.MailboxPermFlag(sql.FieldEqualFold(FieldValue, v)) } // ValueContainsFold applies the ContainsFold predicate on the "Value" field. func ValueContainsFold(v string) predicate.MailboxPermFlag { - return predicate.MailboxPermFlag(func(s *sql.Selector) { - s.Where(sql.ContainsFold(s.C(FieldValue), v)) - }) + return predicate.MailboxPermFlag(sql.FieldContainsFold(FieldValue, v)) } // And groups predicates with the AND operator between them. diff --git a/internal/db/ent/mailboxpermflag_create.go b/internal/db_impl/ent_db/internal/mailboxpermflag_create.go similarity index 74% rename from internal/db/ent/mailboxpermflag_create.go rename to internal/db_impl/ent_db/internal/mailboxpermflag_create.go index aa8089bd..8a9cbbd2 100644 --- a/internal/db/ent/mailboxpermflag_create.go +++ b/internal/db_impl/ent_db/internal/mailboxpermflag_create.go @@ -1,6 +1,6 @@ // Code generated by ent, DO NOT EDIT. -package ent +package internal import ( "context" @@ -9,7 +9,7 @@ import ( "entgo.io/ent/dialect/sql/sqlgraph" "entgo.io/ent/schema/field" - "github.com/ProtonMail/gluon/internal/db/ent/mailboxpermflag" + "github.com/ProtonMail/gluon/internal/db_impl/ent_db/internal/mailboxpermflag" ) // MailboxPermFlagCreate is the builder for creating a MailboxPermFlag entity. @@ -32,49 +32,7 @@ func (mpfc *MailboxPermFlagCreate) Mutation() *MailboxPermFlagMutation { // Save creates the MailboxPermFlag in the database. func (mpfc *MailboxPermFlagCreate) Save(ctx context.Context) (*MailboxPermFlag, error) { - var ( - err error - node *MailboxPermFlag - ) - if len(mpfc.hooks) == 0 { - if err = mpfc.check(); err != nil { - return nil, err - } - node, err = mpfc.sqlSave(ctx) - } else { - var mut Mutator = MutateFunc(func(ctx context.Context, m Mutation) (Value, error) { - mutation, ok := m.(*MailboxPermFlagMutation) - if !ok { - return nil, fmt.Errorf("unexpected mutation type %T", m) - } - if err = mpfc.check(); err != nil { - return nil, err - } - mpfc.mutation = mutation - if node, err = mpfc.sqlSave(ctx); err != nil { - return nil, err - } - mutation.id = &node.ID - mutation.done = true - return node, err - }) - for i := len(mpfc.hooks) - 1; i >= 0; i-- { - if mpfc.hooks[i] == nil { - return nil, fmt.Errorf("ent: uninitialized hook (forgotten import ent/runtime?)") - } - mut = mpfc.hooks[i](mut) - } - v, err := mut.Mutate(ctx, mpfc.mutation) - if err != nil { - return nil, err - } - nv, ok := v.(*MailboxPermFlag) - if !ok { - return nil, fmt.Errorf("unexpected node type %T returned from MailboxPermFlagMutation", v) - } - node = nv - } - return node, err + return withHooks[*MailboxPermFlag, MailboxPermFlagMutation](ctx, mpfc.sqlSave, mpfc.mutation, mpfc.hooks) } // SaveX calls Save and panics if Save returns an error. @@ -102,12 +60,15 @@ func (mpfc *MailboxPermFlagCreate) ExecX(ctx context.Context) { // check runs all checks and user-defined validators on the builder. func (mpfc *MailboxPermFlagCreate) check() error { if _, ok := mpfc.mutation.Value(); !ok { - return &ValidationError{Name: "Value", err: errors.New(`ent: missing required field "MailboxPermFlag.Value"`)} + return &ValidationError{Name: "Value", err: errors.New(`internal: missing required field "MailboxPermFlag.Value"`)} } return nil } func (mpfc *MailboxPermFlagCreate) sqlSave(ctx context.Context) (*MailboxPermFlag, error) { + if err := mpfc.check(); err != nil { + return nil, err + } _node, _spec := mpfc.createSpec() if err := sqlgraph.CreateNode(ctx, mpfc.driver, _spec); err != nil { if sqlgraph.IsConstraintError(err) { @@ -117,26 +78,18 @@ func (mpfc *MailboxPermFlagCreate) sqlSave(ctx context.Context) (*MailboxPermFla } id := _spec.ID.Value.(int64) _node.ID = int(id) + mpfc.mutation.id = &_node.ID + mpfc.mutation.done = true return _node, nil } func (mpfc *MailboxPermFlagCreate) createSpec() (*MailboxPermFlag, *sqlgraph.CreateSpec) { var ( _node = &MailboxPermFlag{config: mpfc.config} - _spec = &sqlgraph.CreateSpec{ - Table: mailboxpermflag.Table, - ID: &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Column: mailboxpermflag.FieldID, - }, - } + _spec = sqlgraph.NewCreateSpec(mailboxpermflag.Table, sqlgraph.NewFieldSpec(mailboxpermflag.FieldID, field.TypeInt)) ) if value, ok := mpfc.mutation.Value(); ok { - _spec.Fields = append(_spec.Fields, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Value: value, - Column: mailboxpermflag.FieldValue, - }) + _spec.SetField(mailboxpermflag.FieldValue, field.TypeString, value) _node.Value = value } return _node, _spec diff --git a/internal/db/ent/mailboxpermflag_delete.go b/internal/db_impl/ent_db/internal/mailboxpermflag_delete.go similarity index 63% rename from internal/db/ent/mailboxpermflag_delete.go rename to internal/db_impl/ent_db/internal/mailboxpermflag_delete.go index 0e2aa06d..2f2b48c9 100644 --- a/internal/db/ent/mailboxpermflag_delete.go +++ b/internal/db_impl/ent_db/internal/mailboxpermflag_delete.go @@ -1,16 +1,15 @@ // Code generated by ent, DO NOT EDIT. -package ent +package internal import ( "context" - "fmt" "entgo.io/ent/dialect/sql" "entgo.io/ent/dialect/sql/sqlgraph" "entgo.io/ent/schema/field" - "github.com/ProtonMail/gluon/internal/db/ent/mailboxpermflag" - "github.com/ProtonMail/gluon/internal/db/ent/predicate" + "github.com/ProtonMail/gluon/internal/db_impl/ent_db/internal/mailboxpermflag" + "github.com/ProtonMail/gluon/internal/db_impl/ent_db/internal/predicate" ) // MailboxPermFlagDelete is the builder for deleting a MailboxPermFlag entity. @@ -28,34 +27,7 @@ func (mpfd *MailboxPermFlagDelete) Where(ps ...predicate.MailboxPermFlag) *Mailb // Exec executes the deletion query and returns how many vertices were deleted. func (mpfd *MailboxPermFlagDelete) Exec(ctx context.Context) (int, error) { - var ( - err error - affected int - ) - if len(mpfd.hooks) == 0 { - affected, err = mpfd.sqlExec(ctx) - } else { - var mut Mutator = MutateFunc(func(ctx context.Context, m Mutation) (Value, error) { - mutation, ok := m.(*MailboxPermFlagMutation) - if !ok { - return nil, fmt.Errorf("unexpected mutation type %T", m) - } - mpfd.mutation = mutation - affected, err = mpfd.sqlExec(ctx) - mutation.done = true - return affected, err - }) - for i := len(mpfd.hooks) - 1; i >= 0; i-- { - if mpfd.hooks[i] == nil { - return 0, fmt.Errorf("ent: uninitialized hook (forgotten import ent/runtime?)") - } - mut = mpfd.hooks[i](mut) - } - if _, err := mut.Mutate(ctx, mpfd.mutation); err != nil { - return 0, err - } - } - return affected, err + return withHooks[int, MailboxPermFlagMutation](ctx, mpfd.sqlExec, mpfd.mutation, mpfd.hooks) } // ExecX is like Exec, but panics if an error occurs. @@ -68,15 +40,7 @@ func (mpfd *MailboxPermFlagDelete) ExecX(ctx context.Context) int { } func (mpfd *MailboxPermFlagDelete) sqlExec(ctx context.Context) (int, error) { - _spec := &sqlgraph.DeleteSpec{ - Node: &sqlgraph.NodeSpec{ - Table: mailboxpermflag.Table, - ID: &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Column: mailboxpermflag.FieldID, - }, - }, - } + _spec := sqlgraph.NewDeleteSpec(mailboxpermflag.Table, sqlgraph.NewFieldSpec(mailboxpermflag.FieldID, field.TypeInt)) if ps := mpfd.mutation.predicates; len(ps) > 0 { _spec.Predicate = func(selector *sql.Selector) { for i := range ps { @@ -88,6 +52,7 @@ func (mpfd *MailboxPermFlagDelete) sqlExec(ctx context.Context) (int, error) { if err != nil && sqlgraph.IsConstraintError(err) { err = &ConstraintError{msg: err.Error(), wrap: err} } + mpfd.mutation.done = true return affected, err } @@ -96,6 +61,12 @@ type MailboxPermFlagDeleteOne struct { mpfd *MailboxPermFlagDelete } +// Where appends a list predicates to the MailboxPermFlagDelete builder. +func (mpfdo *MailboxPermFlagDeleteOne) Where(ps ...predicate.MailboxPermFlag) *MailboxPermFlagDeleteOne { + mpfdo.mpfd.mutation.Where(ps...) + return mpfdo +} + // Exec executes the deletion query. func (mpfdo *MailboxPermFlagDeleteOne) Exec(ctx context.Context) error { n, err := mpfdo.mpfd.Exec(ctx) @@ -111,5 +82,7 @@ func (mpfdo *MailboxPermFlagDeleteOne) Exec(ctx context.Context) error { // ExecX is like Exec, but panics if an error occurs. func (mpfdo *MailboxPermFlagDeleteOne) ExecX(ctx context.Context) { - mpfdo.mpfd.ExecX(ctx) + if err := mpfdo.Exec(ctx); err != nil { + panic(err) + } } diff --git a/internal/db/ent/mailboxpermflag_query.go b/internal/db_impl/ent_db/internal/mailboxpermflag_query.go similarity index 68% rename from internal/db/ent/mailboxpermflag_query.go rename to internal/db_impl/ent_db/internal/mailboxpermflag_query.go index a5b1f983..d955eb89 100644 --- a/internal/db/ent/mailboxpermflag_query.go +++ b/internal/db_impl/ent_db/internal/mailboxpermflag_query.go @@ -1,6 +1,6 @@ // Code generated by ent, DO NOT EDIT. -package ent +package internal import ( "context" @@ -10,18 +10,16 @@ import ( "entgo.io/ent/dialect/sql" "entgo.io/ent/dialect/sql/sqlgraph" "entgo.io/ent/schema/field" - "github.com/ProtonMail/gluon/internal/db/ent/mailboxpermflag" - "github.com/ProtonMail/gluon/internal/db/ent/predicate" + "github.com/ProtonMail/gluon/internal/db_impl/ent_db/internal/mailboxpermflag" + "github.com/ProtonMail/gluon/internal/db_impl/ent_db/internal/predicate" ) // MailboxPermFlagQuery is the builder for querying MailboxPermFlag entities. type MailboxPermFlagQuery struct { config - limit *int - offset *int - unique *bool + ctx *QueryContext order []OrderFunc - fields []string + inters []Interceptor predicates []predicate.MailboxPermFlag withFKs bool // intermediate query (i.e. traversal path). @@ -35,26 +33,26 @@ func (mpfq *MailboxPermFlagQuery) Where(ps ...predicate.MailboxPermFlag) *Mailbo return mpfq } -// Limit adds a limit step to the query. +// Limit the number of records to be returned by this query. func (mpfq *MailboxPermFlagQuery) Limit(limit int) *MailboxPermFlagQuery { - mpfq.limit = &limit + mpfq.ctx.Limit = &limit return mpfq } -// Offset adds an offset step to the query. +// Offset to start from. func (mpfq *MailboxPermFlagQuery) Offset(offset int) *MailboxPermFlagQuery { - mpfq.offset = &offset + mpfq.ctx.Offset = &offset return mpfq } // Unique configures the query builder to filter duplicate records on query. // By default, unique is set to true, and can be disabled using this method. func (mpfq *MailboxPermFlagQuery) Unique(unique bool) *MailboxPermFlagQuery { - mpfq.unique = &unique + mpfq.ctx.Unique = &unique return mpfq } -// Order adds an order step to the query. +// Order specifies how the records should be ordered. func (mpfq *MailboxPermFlagQuery) Order(o ...OrderFunc) *MailboxPermFlagQuery { mpfq.order = append(mpfq.order, o...) return mpfq @@ -63,7 +61,7 @@ func (mpfq *MailboxPermFlagQuery) Order(o ...OrderFunc) *MailboxPermFlagQuery { // First returns the first MailboxPermFlag entity from the query. // Returns a *NotFoundError when no MailboxPermFlag was found. func (mpfq *MailboxPermFlagQuery) First(ctx context.Context) (*MailboxPermFlag, error) { - nodes, err := mpfq.Limit(1).All(ctx) + nodes, err := mpfq.Limit(1).All(setContextOp(ctx, mpfq.ctx, "First")) if err != nil { return nil, err } @@ -86,7 +84,7 @@ func (mpfq *MailboxPermFlagQuery) FirstX(ctx context.Context) *MailboxPermFlag { // Returns a *NotFoundError when no MailboxPermFlag ID was found. func (mpfq *MailboxPermFlagQuery) FirstID(ctx context.Context) (id int, err error) { var ids []int - if ids, err = mpfq.Limit(1).IDs(ctx); err != nil { + if ids, err = mpfq.Limit(1).IDs(setContextOp(ctx, mpfq.ctx, "FirstID")); err != nil { return } if len(ids) == 0 { @@ -109,7 +107,7 @@ func (mpfq *MailboxPermFlagQuery) FirstIDX(ctx context.Context) int { // Returns a *NotSingularError when more than one MailboxPermFlag entity is found. // Returns a *NotFoundError when no MailboxPermFlag entities are found. func (mpfq *MailboxPermFlagQuery) Only(ctx context.Context) (*MailboxPermFlag, error) { - nodes, err := mpfq.Limit(2).All(ctx) + nodes, err := mpfq.Limit(2).All(setContextOp(ctx, mpfq.ctx, "Only")) if err != nil { return nil, err } @@ -137,7 +135,7 @@ func (mpfq *MailboxPermFlagQuery) OnlyX(ctx context.Context) *MailboxPermFlag { // Returns a *NotFoundError when no entities are found. func (mpfq *MailboxPermFlagQuery) OnlyID(ctx context.Context) (id int, err error) { var ids []int - if ids, err = mpfq.Limit(2).IDs(ctx); err != nil { + if ids, err = mpfq.Limit(2).IDs(setContextOp(ctx, mpfq.ctx, "OnlyID")); err != nil { return } switch len(ids) { @@ -162,10 +160,12 @@ func (mpfq *MailboxPermFlagQuery) OnlyIDX(ctx context.Context) int { // All executes the query and returns a list of MailboxPermFlags. func (mpfq *MailboxPermFlagQuery) All(ctx context.Context) ([]*MailboxPermFlag, error) { + ctx = setContextOp(ctx, mpfq.ctx, "All") if err := mpfq.prepareQuery(ctx); err != nil { return nil, err } - return mpfq.sqlAll(ctx) + qr := querierAll[[]*MailboxPermFlag, *MailboxPermFlagQuery]() + return withInterceptors[[]*MailboxPermFlag](ctx, mpfq, qr, mpfq.inters) } // AllX is like All, but panics if an error occurs. @@ -178,9 +178,12 @@ func (mpfq *MailboxPermFlagQuery) AllX(ctx context.Context) []*MailboxPermFlag { } // IDs executes the query and returns a list of MailboxPermFlag IDs. -func (mpfq *MailboxPermFlagQuery) IDs(ctx context.Context) ([]int, error) { - var ids []int - if err := mpfq.Select(mailboxpermflag.FieldID).Scan(ctx, &ids); err != nil { +func (mpfq *MailboxPermFlagQuery) IDs(ctx context.Context) (ids []int, err error) { + if mpfq.ctx.Unique == nil && mpfq.path != nil { + mpfq.Unique(true) + } + ctx = setContextOp(ctx, mpfq.ctx, "IDs") + if err = mpfq.Select(mailboxpermflag.FieldID).Scan(ctx, &ids); err != nil { return nil, err } return ids, nil @@ -197,10 +200,11 @@ func (mpfq *MailboxPermFlagQuery) IDsX(ctx context.Context) []int { // Count returns the count of the given query. func (mpfq *MailboxPermFlagQuery) Count(ctx context.Context) (int, error) { + ctx = setContextOp(ctx, mpfq.ctx, "Count") if err := mpfq.prepareQuery(ctx); err != nil { return 0, err } - return mpfq.sqlCount(ctx) + return withInterceptors[int](ctx, mpfq, querierCount[*MailboxPermFlagQuery](), mpfq.inters) } // CountX is like Count, but panics if an error occurs. @@ -214,10 +218,15 @@ func (mpfq *MailboxPermFlagQuery) CountX(ctx context.Context) int { // Exist returns true if the query has elements in the graph. func (mpfq *MailboxPermFlagQuery) Exist(ctx context.Context) (bool, error) { - if err := mpfq.prepareQuery(ctx); err != nil { - return false, err + ctx = setContextOp(ctx, mpfq.ctx, "Exist") + switch _, err := mpfq.FirstID(ctx); { + case IsNotFound(err): + return false, nil + case err != nil: + return false, fmt.Errorf("internal: check existence: %w", err) + default: + return true, nil } - return mpfq.sqlExist(ctx) } // ExistX is like Exist, but panics if an error occurs. @@ -237,14 +246,13 @@ func (mpfq *MailboxPermFlagQuery) Clone() *MailboxPermFlagQuery { } return &MailboxPermFlagQuery{ config: mpfq.config, - limit: mpfq.limit, - offset: mpfq.offset, + ctx: mpfq.ctx.Clone(), order: append([]OrderFunc{}, mpfq.order...), + inters: append([]Interceptor{}, mpfq.inters...), predicates: append([]predicate.MailboxPermFlag{}, mpfq.predicates...), // clone intermediate query. - sql: mpfq.sql.Clone(), - path: mpfq.path, - unique: mpfq.unique, + sql: mpfq.sql.Clone(), + path: mpfq.path, } } @@ -260,19 +268,14 @@ func (mpfq *MailboxPermFlagQuery) Clone() *MailboxPermFlagQuery { // // client.MailboxPermFlag.Query(). // GroupBy(mailboxpermflag.FieldValue). -// Aggregate(ent.Count()). +// Aggregate(internal.Count()). // Scan(ctx, &v) func (mpfq *MailboxPermFlagQuery) GroupBy(field string, fields ...string) *MailboxPermFlagGroupBy { - grbuild := &MailboxPermFlagGroupBy{config: mpfq.config} - grbuild.fields = append([]string{field}, fields...) - grbuild.path = func(ctx context.Context) (prev *sql.Selector, err error) { - if err := mpfq.prepareQuery(ctx); err != nil { - return nil, err - } - return mpfq.sqlQuery(ctx), nil - } + mpfq.ctx.Fields = append([]string{field}, fields...) + grbuild := &MailboxPermFlagGroupBy{build: mpfq} + grbuild.flds = &mpfq.ctx.Fields grbuild.label = mailboxpermflag.Label - grbuild.flds, grbuild.scan = &grbuild.fields, grbuild.Scan + grbuild.scan = grbuild.Scan return grbuild } @@ -289,17 +292,32 @@ func (mpfq *MailboxPermFlagQuery) GroupBy(field string, fields ...string) *Mailb // Select(mailboxpermflag.FieldValue). // Scan(ctx, &v) func (mpfq *MailboxPermFlagQuery) Select(fields ...string) *MailboxPermFlagSelect { - mpfq.fields = append(mpfq.fields, fields...) - selbuild := &MailboxPermFlagSelect{MailboxPermFlagQuery: mpfq} - selbuild.label = mailboxpermflag.Label - selbuild.flds, selbuild.scan = &mpfq.fields, selbuild.Scan - return selbuild + mpfq.ctx.Fields = append(mpfq.ctx.Fields, fields...) + sbuild := &MailboxPermFlagSelect{MailboxPermFlagQuery: mpfq} + sbuild.label = mailboxpermflag.Label + sbuild.flds, sbuild.scan = &mpfq.ctx.Fields, sbuild.Scan + return sbuild +} + +// Aggregate returns a MailboxPermFlagSelect configured with the given aggregations. +func (mpfq *MailboxPermFlagQuery) Aggregate(fns ...AggregateFunc) *MailboxPermFlagSelect { + return mpfq.Select().Aggregate(fns...) } func (mpfq *MailboxPermFlagQuery) prepareQuery(ctx context.Context) error { - for _, f := range mpfq.fields { + for _, inter := range mpfq.inters { + if inter == nil { + return fmt.Errorf("internal: uninitialized interceptor (forgotten import internal/runtime?)") + } + if trv, ok := inter.(Traverser); ok { + if err := trv.Traverse(ctx, mpfq); err != nil { + return err + } + } + } + for _, f := range mpfq.ctx.Fields { if !mailboxpermflag.ValidColumn(f) { - return &ValidationError{Name: f, err: fmt.Errorf("ent: invalid field %q for query", f)} + return &ValidationError{Name: f, err: fmt.Errorf("internal: invalid field %q for query", f)} } } if mpfq.path != nil { @@ -321,10 +339,10 @@ func (mpfq *MailboxPermFlagQuery) sqlAll(ctx context.Context, hooks ...queryHook if withFKs { _spec.Node.Columns = append(_spec.Node.Columns, mailboxpermflag.ForeignKeys...) } - _spec.ScanValues = func(columns []string) ([]interface{}, error) { + _spec.ScanValues = func(columns []string) ([]any, error) { return (*MailboxPermFlag).scanValues(nil, columns) } - _spec.Assign = func(columns []string, values []interface{}) error { + _spec.Assign = func(columns []string, values []any) error { node := &MailboxPermFlag{config: mpfq.config} nodes = append(nodes, node) return node.assignValues(columns, values) @@ -343,38 +361,22 @@ func (mpfq *MailboxPermFlagQuery) sqlAll(ctx context.Context, hooks ...queryHook func (mpfq *MailboxPermFlagQuery) sqlCount(ctx context.Context) (int, error) { _spec := mpfq.querySpec() - _spec.Node.Columns = mpfq.fields - if len(mpfq.fields) > 0 { - _spec.Unique = mpfq.unique != nil && *mpfq.unique + _spec.Node.Columns = mpfq.ctx.Fields + if len(mpfq.ctx.Fields) > 0 { + _spec.Unique = mpfq.ctx.Unique != nil && *mpfq.ctx.Unique } return sqlgraph.CountNodes(ctx, mpfq.driver, _spec) } -func (mpfq *MailboxPermFlagQuery) sqlExist(ctx context.Context) (bool, error) { - n, err := mpfq.sqlCount(ctx) - if err != nil { - return false, fmt.Errorf("ent: check existence: %w", err) - } - return n > 0, nil -} - func (mpfq *MailboxPermFlagQuery) querySpec() *sqlgraph.QuerySpec { - _spec := &sqlgraph.QuerySpec{ - Node: &sqlgraph.NodeSpec{ - Table: mailboxpermflag.Table, - Columns: mailboxpermflag.Columns, - ID: &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Column: mailboxpermflag.FieldID, - }, - }, - From: mpfq.sql, - Unique: true, - } - if unique := mpfq.unique; unique != nil { + _spec := sqlgraph.NewQuerySpec(mailboxpermflag.Table, mailboxpermflag.Columns, sqlgraph.NewFieldSpec(mailboxpermflag.FieldID, field.TypeInt)) + _spec.From = mpfq.sql + if unique := mpfq.ctx.Unique; unique != nil { _spec.Unique = *unique + } else if mpfq.path != nil { + _spec.Unique = true } - if fields := mpfq.fields; len(fields) > 0 { + if fields := mpfq.ctx.Fields; len(fields) > 0 { _spec.Node.Columns = make([]string, 0, len(fields)) _spec.Node.Columns = append(_spec.Node.Columns, mailboxpermflag.FieldID) for i := range fields { @@ -390,10 +392,10 @@ func (mpfq *MailboxPermFlagQuery) querySpec() *sqlgraph.QuerySpec { } } } - if limit := mpfq.limit; limit != nil { + if limit := mpfq.ctx.Limit; limit != nil { _spec.Limit = *limit } - if offset := mpfq.offset; offset != nil { + if offset := mpfq.ctx.Offset; offset != nil { _spec.Offset = *offset } if ps := mpfq.order; len(ps) > 0 { @@ -409,7 +411,7 @@ func (mpfq *MailboxPermFlagQuery) querySpec() *sqlgraph.QuerySpec { func (mpfq *MailboxPermFlagQuery) sqlQuery(ctx context.Context) *sql.Selector { builder := sql.Dialect(mpfq.driver.Dialect()) t1 := builder.Table(mailboxpermflag.Table) - columns := mpfq.fields + columns := mpfq.ctx.Fields if len(columns) == 0 { columns = mailboxpermflag.Columns } @@ -418,7 +420,7 @@ func (mpfq *MailboxPermFlagQuery) sqlQuery(ctx context.Context) *sql.Selector { selector = mpfq.sql selector.Select(selector.Columns(columns...)...) } - if mpfq.unique != nil && *mpfq.unique { + if mpfq.ctx.Unique != nil && *mpfq.ctx.Unique { selector.Distinct() } for _, p := range mpfq.predicates { @@ -427,12 +429,12 @@ func (mpfq *MailboxPermFlagQuery) sqlQuery(ctx context.Context) *sql.Selector { for _, p := range mpfq.order { p(selector) } - if offset := mpfq.offset; offset != nil { + if offset := mpfq.ctx.Offset; offset != nil { // limit is mandatory for offset clause. We start // with default value, and override it below if needed. selector.Offset(*offset).Limit(math.MaxInt32) } - if limit := mpfq.limit; limit != nil { + if limit := mpfq.ctx.Limit; limit != nil { selector.Limit(*limit) } return selector @@ -440,13 +442,8 @@ func (mpfq *MailboxPermFlagQuery) sqlQuery(ctx context.Context) *sql.Selector { // MailboxPermFlagGroupBy is the group-by builder for MailboxPermFlag entities. type MailboxPermFlagGroupBy struct { - config selector - fields []string - fns []AggregateFunc - // intermediate query (i.e. traversal path). - sql *sql.Selector - path func(context.Context) (*sql.Selector, error) + build *MailboxPermFlagQuery } // Aggregate adds the given aggregation functions to the group-by query. @@ -455,74 +452,77 @@ func (mpfgb *MailboxPermFlagGroupBy) Aggregate(fns ...AggregateFunc) *MailboxPer return mpfgb } -// Scan applies the group-by query and scans the result into the given value. -func (mpfgb *MailboxPermFlagGroupBy) Scan(ctx context.Context, v interface{}) error { - query, err := mpfgb.path(ctx) - if err != nil { +// Scan applies the selector query and scans the result into the given value. +func (mpfgb *MailboxPermFlagGroupBy) Scan(ctx context.Context, v any) error { + ctx = setContextOp(ctx, mpfgb.build.ctx, "GroupBy") + if err := mpfgb.build.prepareQuery(ctx); err != nil { return err } - mpfgb.sql = query - return mpfgb.sqlScan(ctx, v) + return scanWithInterceptors[*MailboxPermFlagQuery, *MailboxPermFlagGroupBy](ctx, mpfgb.build, mpfgb, mpfgb.build.inters, v) } -func (mpfgb *MailboxPermFlagGroupBy) sqlScan(ctx context.Context, v interface{}) error { - for _, f := range mpfgb.fields { - if !mailboxpermflag.ValidColumn(f) { - return &ValidationError{Name: f, err: fmt.Errorf("invalid field %q for group-by", f)} +func (mpfgb *MailboxPermFlagGroupBy) sqlScan(ctx context.Context, root *MailboxPermFlagQuery, v any) error { + selector := root.sqlQuery(ctx).Select() + aggregation := make([]string, 0, len(mpfgb.fns)) + for _, fn := range mpfgb.fns { + aggregation = append(aggregation, fn(selector)) + } + if len(selector.SelectedColumns()) == 0 { + columns := make([]string, 0, len(*mpfgb.flds)+len(mpfgb.fns)) + for _, f := range *mpfgb.flds { + columns = append(columns, selector.C(f)) } + columns = append(columns, aggregation...) + selector.Select(columns...) } - selector := mpfgb.sqlQuery() + selector.GroupBy(selector.Columns(*mpfgb.flds...)...) if err := selector.Err(); err != nil { return err } rows := &sql.Rows{} query, args := selector.Query() - if err := mpfgb.driver.Query(ctx, query, args, rows); err != nil { + if err := mpfgb.build.driver.Query(ctx, query, args, rows); err != nil { return err } defer rows.Close() return sql.ScanSlice(rows, v) } -func (mpfgb *MailboxPermFlagGroupBy) sqlQuery() *sql.Selector { - selector := mpfgb.sql.Select() - aggregation := make([]string, 0, len(mpfgb.fns)) - for _, fn := range mpfgb.fns { - aggregation = append(aggregation, fn(selector)) - } - // If no columns were selected in a custom aggregation function, the default - // selection is the fields used for "group-by", and the aggregation functions. - if len(selector.SelectedColumns()) == 0 { - columns := make([]string, 0, len(mpfgb.fields)+len(mpfgb.fns)) - for _, f := range mpfgb.fields { - columns = append(columns, selector.C(f)) - } - columns = append(columns, aggregation...) - selector.Select(columns...) - } - return selector.GroupBy(selector.Columns(mpfgb.fields...)...) -} - // MailboxPermFlagSelect is the builder for selecting fields of MailboxPermFlag entities. type MailboxPermFlagSelect struct { *MailboxPermFlagQuery selector - // intermediate query (i.e. traversal path). - sql *sql.Selector +} + +// Aggregate adds the given aggregation functions to the selector query. +func (mpfs *MailboxPermFlagSelect) Aggregate(fns ...AggregateFunc) *MailboxPermFlagSelect { + mpfs.fns = append(mpfs.fns, fns...) + return mpfs } // Scan applies the selector query and scans the result into the given value. -func (mpfs *MailboxPermFlagSelect) Scan(ctx context.Context, v interface{}) error { +func (mpfs *MailboxPermFlagSelect) Scan(ctx context.Context, v any) error { + ctx = setContextOp(ctx, mpfs.ctx, "Select") if err := mpfs.prepareQuery(ctx); err != nil { return err } - mpfs.sql = mpfs.MailboxPermFlagQuery.sqlQuery(ctx) - return mpfs.sqlScan(ctx, v) + return scanWithInterceptors[*MailboxPermFlagQuery, *MailboxPermFlagSelect](ctx, mpfs.MailboxPermFlagQuery, mpfs, mpfs.inters, v) } -func (mpfs *MailboxPermFlagSelect) sqlScan(ctx context.Context, v interface{}) error { +func (mpfs *MailboxPermFlagSelect) sqlScan(ctx context.Context, root *MailboxPermFlagQuery, v any) error { + selector := root.sqlQuery(ctx) + aggregation := make([]string, 0, len(mpfs.fns)) + for _, fn := range mpfs.fns { + aggregation = append(aggregation, fn(selector)) + } + switch n := len(*mpfs.selector.flds); { + case n == 0 && len(aggregation) > 0: + selector.Select(aggregation...) + case n != 0 && len(aggregation) > 0: + selector.AppendSelect(aggregation...) + } rows := &sql.Rows{} - query, args := mpfs.sql.Query() + query, args := selector.Query() if err := mpfs.driver.Query(ctx, query, args, rows); err != nil { return err } diff --git a/internal/db/ent/mailboxpermflag_update.go b/internal/db_impl/ent_db/internal/mailboxpermflag_update.go similarity index 64% rename from internal/db/ent/mailboxpermflag_update.go rename to internal/db_impl/ent_db/internal/mailboxpermflag_update.go index bbfd2583..d5a13206 100644 --- a/internal/db/ent/mailboxpermflag_update.go +++ b/internal/db_impl/ent_db/internal/mailboxpermflag_update.go @@ -1,6 +1,6 @@ // Code generated by ent, DO NOT EDIT. -package ent +package internal import ( "context" @@ -10,8 +10,8 @@ import ( "entgo.io/ent/dialect/sql" "entgo.io/ent/dialect/sql/sqlgraph" "entgo.io/ent/schema/field" - "github.com/ProtonMail/gluon/internal/db/ent/mailboxpermflag" - "github.com/ProtonMail/gluon/internal/db/ent/predicate" + "github.com/ProtonMail/gluon/internal/db_impl/ent_db/internal/mailboxpermflag" + "github.com/ProtonMail/gluon/internal/db_impl/ent_db/internal/predicate" ) // MailboxPermFlagUpdate is the builder for updating MailboxPermFlag entities. @@ -40,34 +40,7 @@ func (mpfu *MailboxPermFlagUpdate) Mutation() *MailboxPermFlagMutation { // Save executes the query and returns the number of nodes affected by the update operation. func (mpfu *MailboxPermFlagUpdate) Save(ctx context.Context) (int, error) { - var ( - err error - affected int - ) - if len(mpfu.hooks) == 0 { - affected, err = mpfu.sqlSave(ctx) - } else { - var mut Mutator = MutateFunc(func(ctx context.Context, m Mutation) (Value, error) { - mutation, ok := m.(*MailboxPermFlagMutation) - if !ok { - return nil, fmt.Errorf("unexpected mutation type %T", m) - } - mpfu.mutation = mutation - affected, err = mpfu.sqlSave(ctx) - mutation.done = true - return affected, err - }) - for i := len(mpfu.hooks) - 1; i >= 0; i-- { - if mpfu.hooks[i] == nil { - return 0, fmt.Errorf("ent: uninitialized hook (forgotten import ent/runtime?)") - } - mut = mpfu.hooks[i](mut) - } - if _, err := mut.Mutate(ctx, mpfu.mutation); err != nil { - return 0, err - } - } - return affected, err + return withHooks[int, MailboxPermFlagMutation](ctx, mpfu.sqlSave, mpfu.mutation, mpfu.hooks) } // SaveX is like Save, but panics if an error occurs. @@ -93,16 +66,7 @@ func (mpfu *MailboxPermFlagUpdate) ExecX(ctx context.Context) { } func (mpfu *MailboxPermFlagUpdate) sqlSave(ctx context.Context) (n int, err error) { - _spec := &sqlgraph.UpdateSpec{ - Node: &sqlgraph.NodeSpec{ - Table: mailboxpermflag.Table, - Columns: mailboxpermflag.Columns, - ID: &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Column: mailboxpermflag.FieldID, - }, - }, - } + _spec := sqlgraph.NewUpdateSpec(mailboxpermflag.Table, mailboxpermflag.Columns, sqlgraph.NewFieldSpec(mailboxpermflag.FieldID, field.TypeInt)) if ps := mpfu.mutation.predicates; len(ps) > 0 { _spec.Predicate = func(selector *sql.Selector) { for i := range ps { @@ -111,11 +75,7 @@ func (mpfu *MailboxPermFlagUpdate) sqlSave(ctx context.Context) (n int, err erro } } if value, ok := mpfu.mutation.Value(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Value: value, - Column: mailboxpermflag.FieldValue, - }) + _spec.SetField(mailboxpermflag.FieldValue, field.TypeString, value) } if n, err = sqlgraph.UpdateNodes(ctx, mpfu.driver, _spec); err != nil { if _, ok := err.(*sqlgraph.NotFoundError); ok { @@ -125,6 +85,7 @@ func (mpfu *MailboxPermFlagUpdate) sqlSave(ctx context.Context) (n int, err erro } return 0, err } + mpfu.mutation.done = true return n, nil } @@ -147,6 +108,12 @@ func (mpfuo *MailboxPermFlagUpdateOne) Mutation() *MailboxPermFlagMutation { return mpfuo.mutation } +// Where appends a list predicates to the MailboxPermFlagUpdate builder. +func (mpfuo *MailboxPermFlagUpdateOne) Where(ps ...predicate.MailboxPermFlag) *MailboxPermFlagUpdateOne { + mpfuo.mutation.Where(ps...) + return mpfuo +} + // Select allows selecting one or more fields (columns) of the returned entity. // The default is selecting all fields defined in the entity schema. func (mpfuo *MailboxPermFlagUpdateOne) Select(field string, fields ...string) *MailboxPermFlagUpdateOne { @@ -156,40 +123,7 @@ func (mpfuo *MailboxPermFlagUpdateOne) Select(field string, fields ...string) *M // Save executes the query and returns the updated MailboxPermFlag entity. func (mpfuo *MailboxPermFlagUpdateOne) Save(ctx context.Context) (*MailboxPermFlag, error) { - var ( - err error - node *MailboxPermFlag - ) - if len(mpfuo.hooks) == 0 { - node, err = mpfuo.sqlSave(ctx) - } else { - var mut Mutator = MutateFunc(func(ctx context.Context, m Mutation) (Value, error) { - mutation, ok := m.(*MailboxPermFlagMutation) - if !ok { - return nil, fmt.Errorf("unexpected mutation type %T", m) - } - mpfuo.mutation = mutation - node, err = mpfuo.sqlSave(ctx) - mutation.done = true - return node, err - }) - for i := len(mpfuo.hooks) - 1; i >= 0; i-- { - if mpfuo.hooks[i] == nil { - return nil, fmt.Errorf("ent: uninitialized hook (forgotten import ent/runtime?)") - } - mut = mpfuo.hooks[i](mut) - } - v, err := mut.Mutate(ctx, mpfuo.mutation) - if err != nil { - return nil, err - } - nv, ok := v.(*MailboxPermFlag) - if !ok { - return nil, fmt.Errorf("unexpected node type %T returned from MailboxPermFlagMutation", v) - } - node = nv - } - return node, err + return withHooks[*MailboxPermFlag, MailboxPermFlagMutation](ctx, mpfuo.sqlSave, mpfuo.mutation, mpfuo.hooks) } // SaveX is like Save, but panics if an error occurs. @@ -215,19 +149,10 @@ func (mpfuo *MailboxPermFlagUpdateOne) ExecX(ctx context.Context) { } func (mpfuo *MailboxPermFlagUpdateOne) sqlSave(ctx context.Context) (_node *MailboxPermFlag, err error) { - _spec := &sqlgraph.UpdateSpec{ - Node: &sqlgraph.NodeSpec{ - Table: mailboxpermflag.Table, - Columns: mailboxpermflag.Columns, - ID: &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Column: mailboxpermflag.FieldID, - }, - }, - } + _spec := sqlgraph.NewUpdateSpec(mailboxpermflag.Table, mailboxpermflag.Columns, sqlgraph.NewFieldSpec(mailboxpermflag.FieldID, field.TypeInt)) id, ok := mpfuo.mutation.ID() if !ok { - return nil, &ValidationError{Name: "id", err: errors.New(`ent: missing "MailboxPermFlag.id" for update`)} + return nil, &ValidationError{Name: "id", err: errors.New(`internal: missing "MailboxPermFlag.id" for update`)} } _spec.Node.ID.Value = id if fields := mpfuo.fields; len(fields) > 0 { @@ -235,7 +160,7 @@ func (mpfuo *MailboxPermFlagUpdateOne) sqlSave(ctx context.Context) (_node *Mail _spec.Node.Columns = append(_spec.Node.Columns, mailboxpermflag.FieldID) for _, f := range fields { if !mailboxpermflag.ValidColumn(f) { - return nil, &ValidationError{Name: f, err: fmt.Errorf("ent: invalid field %q for query", f)} + return nil, &ValidationError{Name: f, err: fmt.Errorf("internal: invalid field %q for query", f)} } if f != mailboxpermflag.FieldID { _spec.Node.Columns = append(_spec.Node.Columns, f) @@ -250,11 +175,7 @@ func (mpfuo *MailboxPermFlagUpdateOne) sqlSave(ctx context.Context) (_node *Mail } } if value, ok := mpfuo.mutation.Value(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Value: value, - Column: mailboxpermflag.FieldValue, - }) + _spec.SetField(mailboxpermflag.FieldValue, field.TypeString, value) } _node = &MailboxPermFlag{config: mpfuo.config} _spec.Assign = _node.assignValues @@ -267,5 +188,6 @@ func (mpfuo *MailboxPermFlagUpdateOne) sqlSave(ctx context.Context) (_node *Mail } return nil, err } + mpfuo.mutation.done = true return _node, nil } diff --git a/internal/db/ent/message.go b/internal/db_impl/ent_db/internal/message.go similarity index 91% rename from internal/db/ent/message.go rename to internal/db_impl/ent_db/internal/message.go index 3c8ac699..5b0f483e 100644 --- a/internal/db/ent/message.go +++ b/internal/db_impl/ent_db/internal/message.go @@ -1,6 +1,6 @@ // Code generated by ent, DO NOT EDIT. -package ent +package internal import ( "fmt" @@ -9,7 +9,7 @@ import ( "entgo.io/ent/dialect/sql" "github.com/ProtonMail/gluon/imap" - "github.com/ProtonMail/gluon/internal/db/ent/message" + "github.com/ProtonMail/gluon/internal/db_impl/ent_db/internal/message" ) // Message is the model entity for the Message schema. @@ -66,8 +66,8 @@ func (e MessageEdges) UIDsOrErr() ([]*UID, error) { } // scanValues returns the types for scanning values from sql.Rows. -func (*Message) scanValues(columns []string) ([]interface{}, error) { - values := make([]interface{}, len(columns)) +func (*Message) scanValues(columns []string) ([]any, error) { + values := make([]any, len(columns)) for i := range columns { switch columns[i] { case message.FieldID: @@ -89,7 +89,7 @@ func (*Message) scanValues(columns []string) ([]interface{}, error) { // assignValues assigns the values that were returned from sql.Rows (after scanning) // to the Message fields. -func (m *Message) assignValues(columns []string, values []interface{}) error { +func (m *Message) assignValues(columns []string, values []any) error { if m, n := len(values), len(columns); m < n { return fmt.Errorf("mismatch number of scan values: %d != %d", m, n) } @@ -150,19 +150,19 @@ func (m *Message) assignValues(columns []string, values []interface{}) error { // QueryFlags queries the "flags" edge of the Message entity. func (m *Message) QueryFlags() *MessageFlagQuery { - return (&MessageClient{config: m.config}).QueryFlags(m) + return NewMessageClient(m.config).QueryFlags(m) } // QueryUIDs queries the "UIDs" edge of the Message entity. func (m *Message) QueryUIDs() *UIDQuery { - return (&MessageClient{config: m.config}).QueryUIDs(m) + return NewMessageClient(m.config).QueryUIDs(m) } // Update returns a builder for updating this Message. // Note that you need to call Message.Unwrap() before calling this method if this Message // was returned from a transaction, and the transaction was committed or rolled back. func (m *Message) Update() *MessageUpdateOne { - return (&MessageClient{config: m.config}).UpdateOne(m) + return NewMessageClient(m.config).UpdateOne(m) } // Unwrap unwraps the Message entity that was returned from a transaction after it was closed, @@ -170,7 +170,7 @@ func (m *Message) Update() *MessageUpdateOne { func (m *Message) Unwrap() *Message { _tx, ok := m.config.driver.(*txDriver) if !ok { - panic("ent: Message is not a transactional entity") + panic("internal: Message is not a transactional entity") } m.config.driver = _tx.drv return m @@ -207,9 +207,3 @@ func (m *Message) String() string { // Messages is a parsable slice of Message. type Messages []*Message - -func (m Messages) config(cfg config) { - for _i := range m { - m[_i].config = cfg - } -} diff --git a/internal/db/ent/message/message.go b/internal/db_impl/ent_db/internal/message/message.go similarity index 100% rename from internal/db/ent/message/message.go rename to internal/db_impl/ent_db/internal/message/message.go diff --git a/internal/db/ent/message/where.go b/internal/db_impl/ent_db/internal/message/where.go similarity index 58% rename from internal/db/ent/message/where.go rename to internal/db_impl/ent_db/internal/message/where.go index 28c524f8..c2cc30ef 100644 --- a/internal/db/ent/message/where.go +++ b/internal/db_impl/ent_db/internal/message/where.go @@ -8,691 +8,467 @@ import ( "entgo.io/ent/dialect/sql" "entgo.io/ent/dialect/sql/sqlgraph" "github.com/ProtonMail/gluon/imap" - "github.com/ProtonMail/gluon/internal/db/ent/predicate" + "github.com/ProtonMail/gluon/internal/db_impl/ent_db/internal/predicate" ) // ID filters vertices based on their ID field. func ID(id imap.InternalMessageID) predicate.Message { - return predicate.Message(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldID), id)) - }) + return predicate.Message(sql.FieldEQ(FieldID, id)) } // IDEQ applies the EQ predicate on the ID field. func IDEQ(id imap.InternalMessageID) predicate.Message { - return predicate.Message(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldID), id)) - }) + return predicate.Message(sql.FieldEQ(FieldID, id)) } // IDNEQ applies the NEQ predicate on the ID field. func IDNEQ(id imap.InternalMessageID) predicate.Message { - return predicate.Message(func(s *sql.Selector) { - s.Where(sql.NEQ(s.C(FieldID), id)) - }) + return predicate.Message(sql.FieldNEQ(FieldID, id)) } // IDIn applies the In predicate on the ID field. func IDIn(ids ...imap.InternalMessageID) predicate.Message { - return predicate.Message(func(s *sql.Selector) { - v := make([]interface{}, len(ids)) - for i := range v { - v[i] = ids[i] - } - s.Where(sql.In(s.C(FieldID), v...)) - }) + return predicate.Message(sql.FieldIn(FieldID, ids...)) } // IDNotIn applies the NotIn predicate on the ID field. func IDNotIn(ids ...imap.InternalMessageID) predicate.Message { - return predicate.Message(func(s *sql.Selector) { - v := make([]interface{}, len(ids)) - for i := range v { - v[i] = ids[i] - } - s.Where(sql.NotIn(s.C(FieldID), v...)) - }) + return predicate.Message(sql.FieldNotIn(FieldID, ids...)) } // IDGT applies the GT predicate on the ID field. func IDGT(id imap.InternalMessageID) predicate.Message { - return predicate.Message(func(s *sql.Selector) { - s.Where(sql.GT(s.C(FieldID), id)) - }) + return predicate.Message(sql.FieldGT(FieldID, id)) } // IDGTE applies the GTE predicate on the ID field. func IDGTE(id imap.InternalMessageID) predicate.Message { - return predicate.Message(func(s *sql.Selector) { - s.Where(sql.GTE(s.C(FieldID), id)) - }) + return predicate.Message(sql.FieldGTE(FieldID, id)) } // IDLT applies the LT predicate on the ID field. func IDLT(id imap.InternalMessageID) predicate.Message { - return predicate.Message(func(s *sql.Selector) { - s.Where(sql.LT(s.C(FieldID), id)) - }) + return predicate.Message(sql.FieldLT(FieldID, id)) } // IDLTE applies the LTE predicate on the ID field. func IDLTE(id imap.InternalMessageID) predicate.Message { - return predicate.Message(func(s *sql.Selector) { - s.Where(sql.LTE(s.C(FieldID), id)) - }) + return predicate.Message(sql.FieldLTE(FieldID, id)) } // RemoteID applies equality check predicate on the "RemoteID" field. It's identical to RemoteIDEQ. func RemoteID(v imap.MessageID) predicate.Message { vc := string(v) - return predicate.Message(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldRemoteID), vc)) - }) + return predicate.Message(sql.FieldEQ(FieldRemoteID, vc)) } // Date applies equality check predicate on the "Date" field. It's identical to DateEQ. func Date(v time.Time) predicate.Message { - return predicate.Message(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldDate), v)) - }) + return predicate.Message(sql.FieldEQ(FieldDate, v)) } // Size applies equality check predicate on the "Size" field. It's identical to SizeEQ. func Size(v int) predicate.Message { - return predicate.Message(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldSize), v)) - }) + return predicate.Message(sql.FieldEQ(FieldSize, v)) } // Body applies equality check predicate on the "Body" field. It's identical to BodyEQ. func Body(v string) predicate.Message { - return predicate.Message(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldBody), v)) - }) + return predicate.Message(sql.FieldEQ(FieldBody, v)) } // BodyStructure applies equality check predicate on the "BodyStructure" field. It's identical to BodyStructureEQ. func BodyStructure(v string) predicate.Message { - return predicate.Message(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldBodyStructure), v)) - }) + return predicate.Message(sql.FieldEQ(FieldBodyStructure, v)) } // Envelope applies equality check predicate on the "Envelope" field. It's identical to EnvelopeEQ. func Envelope(v string) predicate.Message { - return predicate.Message(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldEnvelope), v)) - }) + return predicate.Message(sql.FieldEQ(FieldEnvelope, v)) } // Deleted applies equality check predicate on the "Deleted" field. It's identical to DeletedEQ. func Deleted(v bool) predicate.Message { - return predicate.Message(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldDeleted), v)) - }) + return predicate.Message(sql.FieldEQ(FieldDeleted, v)) } // RemoteIDEQ applies the EQ predicate on the "RemoteID" field. func RemoteIDEQ(v imap.MessageID) predicate.Message { vc := string(v) - return predicate.Message(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldRemoteID), vc)) - }) + return predicate.Message(sql.FieldEQ(FieldRemoteID, vc)) } // RemoteIDNEQ applies the NEQ predicate on the "RemoteID" field. func RemoteIDNEQ(v imap.MessageID) predicate.Message { vc := string(v) - return predicate.Message(func(s *sql.Selector) { - s.Where(sql.NEQ(s.C(FieldRemoteID), vc)) - }) + return predicate.Message(sql.FieldNEQ(FieldRemoteID, vc)) } // RemoteIDIn applies the In predicate on the "RemoteID" field. func RemoteIDIn(vs ...imap.MessageID) predicate.Message { - v := make([]interface{}, len(vs)) + v := make([]any, len(vs)) for i := range v { v[i] = string(vs[i]) } - return predicate.Message(func(s *sql.Selector) { - s.Where(sql.In(s.C(FieldRemoteID), v...)) - }) + return predicate.Message(sql.FieldIn(FieldRemoteID, v...)) } // RemoteIDNotIn applies the NotIn predicate on the "RemoteID" field. func RemoteIDNotIn(vs ...imap.MessageID) predicate.Message { - v := make([]interface{}, len(vs)) + v := make([]any, len(vs)) for i := range v { v[i] = string(vs[i]) } - return predicate.Message(func(s *sql.Selector) { - s.Where(sql.NotIn(s.C(FieldRemoteID), v...)) - }) + return predicate.Message(sql.FieldNotIn(FieldRemoteID, v...)) } // RemoteIDGT applies the GT predicate on the "RemoteID" field. func RemoteIDGT(v imap.MessageID) predicate.Message { vc := string(v) - return predicate.Message(func(s *sql.Selector) { - s.Where(sql.GT(s.C(FieldRemoteID), vc)) - }) + return predicate.Message(sql.FieldGT(FieldRemoteID, vc)) } // RemoteIDGTE applies the GTE predicate on the "RemoteID" field. func RemoteIDGTE(v imap.MessageID) predicate.Message { vc := string(v) - return predicate.Message(func(s *sql.Selector) { - s.Where(sql.GTE(s.C(FieldRemoteID), vc)) - }) + return predicate.Message(sql.FieldGTE(FieldRemoteID, vc)) } // RemoteIDLT applies the LT predicate on the "RemoteID" field. func RemoteIDLT(v imap.MessageID) predicate.Message { vc := string(v) - return predicate.Message(func(s *sql.Selector) { - s.Where(sql.LT(s.C(FieldRemoteID), vc)) - }) + return predicate.Message(sql.FieldLT(FieldRemoteID, vc)) } // RemoteIDLTE applies the LTE predicate on the "RemoteID" field. func RemoteIDLTE(v imap.MessageID) predicate.Message { vc := string(v) - return predicate.Message(func(s *sql.Selector) { - s.Where(sql.LTE(s.C(FieldRemoteID), vc)) - }) + return predicate.Message(sql.FieldLTE(FieldRemoteID, vc)) } // RemoteIDContains applies the Contains predicate on the "RemoteID" field. func RemoteIDContains(v imap.MessageID) predicate.Message { vc := string(v) - return predicate.Message(func(s *sql.Selector) { - s.Where(sql.Contains(s.C(FieldRemoteID), vc)) - }) + return predicate.Message(sql.FieldContains(FieldRemoteID, vc)) } // RemoteIDHasPrefix applies the HasPrefix predicate on the "RemoteID" field. func RemoteIDHasPrefix(v imap.MessageID) predicate.Message { vc := string(v) - return predicate.Message(func(s *sql.Selector) { - s.Where(sql.HasPrefix(s.C(FieldRemoteID), vc)) - }) + return predicate.Message(sql.FieldHasPrefix(FieldRemoteID, vc)) } // RemoteIDHasSuffix applies the HasSuffix predicate on the "RemoteID" field. func RemoteIDHasSuffix(v imap.MessageID) predicate.Message { vc := string(v) - return predicate.Message(func(s *sql.Selector) { - s.Where(sql.HasSuffix(s.C(FieldRemoteID), vc)) - }) + return predicate.Message(sql.FieldHasSuffix(FieldRemoteID, vc)) } // RemoteIDIsNil applies the IsNil predicate on the "RemoteID" field. func RemoteIDIsNil() predicate.Message { - return predicate.Message(func(s *sql.Selector) { - s.Where(sql.IsNull(s.C(FieldRemoteID))) - }) + return predicate.Message(sql.FieldIsNull(FieldRemoteID)) } // RemoteIDNotNil applies the NotNil predicate on the "RemoteID" field. func RemoteIDNotNil() predicate.Message { - return predicate.Message(func(s *sql.Selector) { - s.Where(sql.NotNull(s.C(FieldRemoteID))) - }) + return predicate.Message(sql.FieldNotNull(FieldRemoteID)) } // RemoteIDEqualFold applies the EqualFold predicate on the "RemoteID" field. func RemoteIDEqualFold(v imap.MessageID) predicate.Message { vc := string(v) - return predicate.Message(func(s *sql.Selector) { - s.Where(sql.EqualFold(s.C(FieldRemoteID), vc)) - }) + return predicate.Message(sql.FieldEqualFold(FieldRemoteID, vc)) } // RemoteIDContainsFold applies the ContainsFold predicate on the "RemoteID" field. func RemoteIDContainsFold(v imap.MessageID) predicate.Message { vc := string(v) - return predicate.Message(func(s *sql.Selector) { - s.Where(sql.ContainsFold(s.C(FieldRemoteID), vc)) - }) + return predicate.Message(sql.FieldContainsFold(FieldRemoteID, vc)) } // DateEQ applies the EQ predicate on the "Date" field. func DateEQ(v time.Time) predicate.Message { - return predicate.Message(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldDate), v)) - }) + return predicate.Message(sql.FieldEQ(FieldDate, v)) } // DateNEQ applies the NEQ predicate on the "Date" field. func DateNEQ(v time.Time) predicate.Message { - return predicate.Message(func(s *sql.Selector) { - s.Where(sql.NEQ(s.C(FieldDate), v)) - }) + return predicate.Message(sql.FieldNEQ(FieldDate, v)) } // DateIn applies the In predicate on the "Date" field. func DateIn(vs ...time.Time) predicate.Message { - v := make([]interface{}, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.Message(func(s *sql.Selector) { - s.Where(sql.In(s.C(FieldDate), v...)) - }) + return predicate.Message(sql.FieldIn(FieldDate, vs...)) } // DateNotIn applies the NotIn predicate on the "Date" field. func DateNotIn(vs ...time.Time) predicate.Message { - v := make([]interface{}, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.Message(func(s *sql.Selector) { - s.Where(sql.NotIn(s.C(FieldDate), v...)) - }) + return predicate.Message(sql.FieldNotIn(FieldDate, vs...)) } // DateGT applies the GT predicate on the "Date" field. func DateGT(v time.Time) predicate.Message { - return predicate.Message(func(s *sql.Selector) { - s.Where(sql.GT(s.C(FieldDate), v)) - }) + return predicate.Message(sql.FieldGT(FieldDate, v)) } // DateGTE applies the GTE predicate on the "Date" field. func DateGTE(v time.Time) predicate.Message { - return predicate.Message(func(s *sql.Selector) { - s.Where(sql.GTE(s.C(FieldDate), v)) - }) + return predicate.Message(sql.FieldGTE(FieldDate, v)) } // DateLT applies the LT predicate on the "Date" field. func DateLT(v time.Time) predicate.Message { - return predicate.Message(func(s *sql.Selector) { - s.Where(sql.LT(s.C(FieldDate), v)) - }) + return predicate.Message(sql.FieldLT(FieldDate, v)) } // DateLTE applies the LTE predicate on the "Date" field. func DateLTE(v time.Time) predicate.Message { - return predicate.Message(func(s *sql.Selector) { - s.Where(sql.LTE(s.C(FieldDate), v)) - }) + return predicate.Message(sql.FieldLTE(FieldDate, v)) } // SizeEQ applies the EQ predicate on the "Size" field. func SizeEQ(v int) predicate.Message { - return predicate.Message(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldSize), v)) - }) + return predicate.Message(sql.FieldEQ(FieldSize, v)) } // SizeNEQ applies the NEQ predicate on the "Size" field. func SizeNEQ(v int) predicate.Message { - return predicate.Message(func(s *sql.Selector) { - s.Where(sql.NEQ(s.C(FieldSize), v)) - }) + return predicate.Message(sql.FieldNEQ(FieldSize, v)) } // SizeIn applies the In predicate on the "Size" field. func SizeIn(vs ...int) predicate.Message { - v := make([]interface{}, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.Message(func(s *sql.Selector) { - s.Where(sql.In(s.C(FieldSize), v...)) - }) + return predicate.Message(sql.FieldIn(FieldSize, vs...)) } // SizeNotIn applies the NotIn predicate on the "Size" field. func SizeNotIn(vs ...int) predicate.Message { - v := make([]interface{}, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.Message(func(s *sql.Selector) { - s.Where(sql.NotIn(s.C(FieldSize), v...)) - }) + return predicate.Message(sql.FieldNotIn(FieldSize, vs...)) } // SizeGT applies the GT predicate on the "Size" field. func SizeGT(v int) predicate.Message { - return predicate.Message(func(s *sql.Selector) { - s.Where(sql.GT(s.C(FieldSize), v)) - }) + return predicate.Message(sql.FieldGT(FieldSize, v)) } // SizeGTE applies the GTE predicate on the "Size" field. func SizeGTE(v int) predicate.Message { - return predicate.Message(func(s *sql.Selector) { - s.Where(sql.GTE(s.C(FieldSize), v)) - }) + return predicate.Message(sql.FieldGTE(FieldSize, v)) } // SizeLT applies the LT predicate on the "Size" field. func SizeLT(v int) predicate.Message { - return predicate.Message(func(s *sql.Selector) { - s.Where(sql.LT(s.C(FieldSize), v)) - }) + return predicate.Message(sql.FieldLT(FieldSize, v)) } // SizeLTE applies the LTE predicate on the "Size" field. func SizeLTE(v int) predicate.Message { - return predicate.Message(func(s *sql.Selector) { - s.Where(sql.LTE(s.C(FieldSize), v)) - }) + return predicate.Message(sql.FieldLTE(FieldSize, v)) } // BodyEQ applies the EQ predicate on the "Body" field. func BodyEQ(v string) predicate.Message { - return predicate.Message(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldBody), v)) - }) + return predicate.Message(sql.FieldEQ(FieldBody, v)) } // BodyNEQ applies the NEQ predicate on the "Body" field. func BodyNEQ(v string) predicate.Message { - return predicate.Message(func(s *sql.Selector) { - s.Where(sql.NEQ(s.C(FieldBody), v)) - }) + return predicate.Message(sql.FieldNEQ(FieldBody, v)) } // BodyIn applies the In predicate on the "Body" field. func BodyIn(vs ...string) predicate.Message { - v := make([]interface{}, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.Message(func(s *sql.Selector) { - s.Where(sql.In(s.C(FieldBody), v...)) - }) + return predicate.Message(sql.FieldIn(FieldBody, vs...)) } // BodyNotIn applies the NotIn predicate on the "Body" field. func BodyNotIn(vs ...string) predicate.Message { - v := make([]interface{}, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.Message(func(s *sql.Selector) { - s.Where(sql.NotIn(s.C(FieldBody), v...)) - }) + return predicate.Message(sql.FieldNotIn(FieldBody, vs...)) } // BodyGT applies the GT predicate on the "Body" field. func BodyGT(v string) predicate.Message { - return predicate.Message(func(s *sql.Selector) { - s.Where(sql.GT(s.C(FieldBody), v)) - }) + return predicate.Message(sql.FieldGT(FieldBody, v)) } // BodyGTE applies the GTE predicate on the "Body" field. func BodyGTE(v string) predicate.Message { - return predicate.Message(func(s *sql.Selector) { - s.Where(sql.GTE(s.C(FieldBody), v)) - }) + return predicate.Message(sql.FieldGTE(FieldBody, v)) } // BodyLT applies the LT predicate on the "Body" field. func BodyLT(v string) predicate.Message { - return predicate.Message(func(s *sql.Selector) { - s.Where(sql.LT(s.C(FieldBody), v)) - }) + return predicate.Message(sql.FieldLT(FieldBody, v)) } // BodyLTE applies the LTE predicate on the "Body" field. func BodyLTE(v string) predicate.Message { - return predicate.Message(func(s *sql.Selector) { - s.Where(sql.LTE(s.C(FieldBody), v)) - }) + return predicate.Message(sql.FieldLTE(FieldBody, v)) } // BodyContains applies the Contains predicate on the "Body" field. func BodyContains(v string) predicate.Message { - return predicate.Message(func(s *sql.Selector) { - s.Where(sql.Contains(s.C(FieldBody), v)) - }) + return predicate.Message(sql.FieldContains(FieldBody, v)) } // BodyHasPrefix applies the HasPrefix predicate on the "Body" field. func BodyHasPrefix(v string) predicate.Message { - return predicate.Message(func(s *sql.Selector) { - s.Where(sql.HasPrefix(s.C(FieldBody), v)) - }) + return predicate.Message(sql.FieldHasPrefix(FieldBody, v)) } // BodyHasSuffix applies the HasSuffix predicate on the "Body" field. func BodyHasSuffix(v string) predicate.Message { - return predicate.Message(func(s *sql.Selector) { - s.Where(sql.HasSuffix(s.C(FieldBody), v)) - }) + return predicate.Message(sql.FieldHasSuffix(FieldBody, v)) } // BodyEqualFold applies the EqualFold predicate on the "Body" field. func BodyEqualFold(v string) predicate.Message { - return predicate.Message(func(s *sql.Selector) { - s.Where(sql.EqualFold(s.C(FieldBody), v)) - }) + return predicate.Message(sql.FieldEqualFold(FieldBody, v)) } // BodyContainsFold applies the ContainsFold predicate on the "Body" field. func BodyContainsFold(v string) predicate.Message { - return predicate.Message(func(s *sql.Selector) { - s.Where(sql.ContainsFold(s.C(FieldBody), v)) - }) + return predicate.Message(sql.FieldContainsFold(FieldBody, v)) } // BodyStructureEQ applies the EQ predicate on the "BodyStructure" field. func BodyStructureEQ(v string) predicate.Message { - return predicate.Message(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldBodyStructure), v)) - }) + return predicate.Message(sql.FieldEQ(FieldBodyStructure, v)) } // BodyStructureNEQ applies the NEQ predicate on the "BodyStructure" field. func BodyStructureNEQ(v string) predicate.Message { - return predicate.Message(func(s *sql.Selector) { - s.Where(sql.NEQ(s.C(FieldBodyStructure), v)) - }) + return predicate.Message(sql.FieldNEQ(FieldBodyStructure, v)) } // BodyStructureIn applies the In predicate on the "BodyStructure" field. func BodyStructureIn(vs ...string) predicate.Message { - v := make([]interface{}, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.Message(func(s *sql.Selector) { - s.Where(sql.In(s.C(FieldBodyStructure), v...)) - }) + return predicate.Message(sql.FieldIn(FieldBodyStructure, vs...)) } // BodyStructureNotIn applies the NotIn predicate on the "BodyStructure" field. func BodyStructureNotIn(vs ...string) predicate.Message { - v := make([]interface{}, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.Message(func(s *sql.Selector) { - s.Where(sql.NotIn(s.C(FieldBodyStructure), v...)) - }) + return predicate.Message(sql.FieldNotIn(FieldBodyStructure, vs...)) } // BodyStructureGT applies the GT predicate on the "BodyStructure" field. func BodyStructureGT(v string) predicate.Message { - return predicate.Message(func(s *sql.Selector) { - s.Where(sql.GT(s.C(FieldBodyStructure), v)) - }) + return predicate.Message(sql.FieldGT(FieldBodyStructure, v)) } // BodyStructureGTE applies the GTE predicate on the "BodyStructure" field. func BodyStructureGTE(v string) predicate.Message { - return predicate.Message(func(s *sql.Selector) { - s.Where(sql.GTE(s.C(FieldBodyStructure), v)) - }) + return predicate.Message(sql.FieldGTE(FieldBodyStructure, v)) } // BodyStructureLT applies the LT predicate on the "BodyStructure" field. func BodyStructureLT(v string) predicate.Message { - return predicate.Message(func(s *sql.Selector) { - s.Where(sql.LT(s.C(FieldBodyStructure), v)) - }) + return predicate.Message(sql.FieldLT(FieldBodyStructure, v)) } // BodyStructureLTE applies the LTE predicate on the "BodyStructure" field. func BodyStructureLTE(v string) predicate.Message { - return predicate.Message(func(s *sql.Selector) { - s.Where(sql.LTE(s.C(FieldBodyStructure), v)) - }) + return predicate.Message(sql.FieldLTE(FieldBodyStructure, v)) } // BodyStructureContains applies the Contains predicate on the "BodyStructure" field. func BodyStructureContains(v string) predicate.Message { - return predicate.Message(func(s *sql.Selector) { - s.Where(sql.Contains(s.C(FieldBodyStructure), v)) - }) + return predicate.Message(sql.FieldContains(FieldBodyStructure, v)) } // BodyStructureHasPrefix applies the HasPrefix predicate on the "BodyStructure" field. func BodyStructureHasPrefix(v string) predicate.Message { - return predicate.Message(func(s *sql.Selector) { - s.Where(sql.HasPrefix(s.C(FieldBodyStructure), v)) - }) + return predicate.Message(sql.FieldHasPrefix(FieldBodyStructure, v)) } // BodyStructureHasSuffix applies the HasSuffix predicate on the "BodyStructure" field. func BodyStructureHasSuffix(v string) predicate.Message { - return predicate.Message(func(s *sql.Selector) { - s.Where(sql.HasSuffix(s.C(FieldBodyStructure), v)) - }) + return predicate.Message(sql.FieldHasSuffix(FieldBodyStructure, v)) } // BodyStructureEqualFold applies the EqualFold predicate on the "BodyStructure" field. func BodyStructureEqualFold(v string) predicate.Message { - return predicate.Message(func(s *sql.Selector) { - s.Where(sql.EqualFold(s.C(FieldBodyStructure), v)) - }) + return predicate.Message(sql.FieldEqualFold(FieldBodyStructure, v)) } // BodyStructureContainsFold applies the ContainsFold predicate on the "BodyStructure" field. func BodyStructureContainsFold(v string) predicate.Message { - return predicate.Message(func(s *sql.Selector) { - s.Where(sql.ContainsFold(s.C(FieldBodyStructure), v)) - }) + return predicate.Message(sql.FieldContainsFold(FieldBodyStructure, v)) } // EnvelopeEQ applies the EQ predicate on the "Envelope" field. func EnvelopeEQ(v string) predicate.Message { - return predicate.Message(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldEnvelope), v)) - }) + return predicate.Message(sql.FieldEQ(FieldEnvelope, v)) } // EnvelopeNEQ applies the NEQ predicate on the "Envelope" field. func EnvelopeNEQ(v string) predicate.Message { - return predicate.Message(func(s *sql.Selector) { - s.Where(sql.NEQ(s.C(FieldEnvelope), v)) - }) + return predicate.Message(sql.FieldNEQ(FieldEnvelope, v)) } // EnvelopeIn applies the In predicate on the "Envelope" field. func EnvelopeIn(vs ...string) predicate.Message { - v := make([]interface{}, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.Message(func(s *sql.Selector) { - s.Where(sql.In(s.C(FieldEnvelope), v...)) - }) + return predicate.Message(sql.FieldIn(FieldEnvelope, vs...)) } // EnvelopeNotIn applies the NotIn predicate on the "Envelope" field. func EnvelopeNotIn(vs ...string) predicate.Message { - v := make([]interface{}, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.Message(func(s *sql.Selector) { - s.Where(sql.NotIn(s.C(FieldEnvelope), v...)) - }) + return predicate.Message(sql.FieldNotIn(FieldEnvelope, vs...)) } // EnvelopeGT applies the GT predicate on the "Envelope" field. func EnvelopeGT(v string) predicate.Message { - return predicate.Message(func(s *sql.Selector) { - s.Where(sql.GT(s.C(FieldEnvelope), v)) - }) + return predicate.Message(sql.FieldGT(FieldEnvelope, v)) } // EnvelopeGTE applies the GTE predicate on the "Envelope" field. func EnvelopeGTE(v string) predicate.Message { - return predicate.Message(func(s *sql.Selector) { - s.Where(sql.GTE(s.C(FieldEnvelope), v)) - }) + return predicate.Message(sql.FieldGTE(FieldEnvelope, v)) } // EnvelopeLT applies the LT predicate on the "Envelope" field. func EnvelopeLT(v string) predicate.Message { - return predicate.Message(func(s *sql.Selector) { - s.Where(sql.LT(s.C(FieldEnvelope), v)) - }) + return predicate.Message(sql.FieldLT(FieldEnvelope, v)) } // EnvelopeLTE applies the LTE predicate on the "Envelope" field. func EnvelopeLTE(v string) predicate.Message { - return predicate.Message(func(s *sql.Selector) { - s.Where(sql.LTE(s.C(FieldEnvelope), v)) - }) + return predicate.Message(sql.FieldLTE(FieldEnvelope, v)) } // EnvelopeContains applies the Contains predicate on the "Envelope" field. func EnvelopeContains(v string) predicate.Message { - return predicate.Message(func(s *sql.Selector) { - s.Where(sql.Contains(s.C(FieldEnvelope), v)) - }) + return predicate.Message(sql.FieldContains(FieldEnvelope, v)) } // EnvelopeHasPrefix applies the HasPrefix predicate on the "Envelope" field. func EnvelopeHasPrefix(v string) predicate.Message { - return predicate.Message(func(s *sql.Selector) { - s.Where(sql.HasPrefix(s.C(FieldEnvelope), v)) - }) + return predicate.Message(sql.FieldHasPrefix(FieldEnvelope, v)) } // EnvelopeHasSuffix applies the HasSuffix predicate on the "Envelope" field. func EnvelopeHasSuffix(v string) predicate.Message { - return predicate.Message(func(s *sql.Selector) { - s.Where(sql.HasSuffix(s.C(FieldEnvelope), v)) - }) + return predicate.Message(sql.FieldHasSuffix(FieldEnvelope, v)) } // EnvelopeEqualFold applies the EqualFold predicate on the "Envelope" field. func EnvelopeEqualFold(v string) predicate.Message { - return predicate.Message(func(s *sql.Selector) { - s.Where(sql.EqualFold(s.C(FieldEnvelope), v)) - }) + return predicate.Message(sql.FieldEqualFold(FieldEnvelope, v)) } // EnvelopeContainsFold applies the ContainsFold predicate on the "Envelope" field. func EnvelopeContainsFold(v string) predicate.Message { - return predicate.Message(func(s *sql.Selector) { - s.Where(sql.ContainsFold(s.C(FieldEnvelope), v)) - }) + return predicate.Message(sql.FieldContainsFold(FieldEnvelope, v)) } // DeletedEQ applies the EQ predicate on the "Deleted" field. func DeletedEQ(v bool) predicate.Message { - return predicate.Message(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldDeleted), v)) - }) + return predicate.Message(sql.FieldEQ(FieldDeleted, v)) } // DeletedNEQ applies the NEQ predicate on the "Deleted" field. func DeletedNEQ(v bool) predicate.Message { - return predicate.Message(func(s *sql.Selector) { - s.Where(sql.NEQ(s.C(FieldDeleted), v)) - }) + return predicate.Message(sql.FieldNEQ(FieldDeleted, v)) } // HasFlags applies the HasEdge predicate on the "flags" edge. @@ -700,7 +476,6 @@ func HasFlags() predicate.Message { return predicate.Message(func(s *sql.Selector) { step := sqlgraph.NewStep( sqlgraph.From(Table, FieldID), - sqlgraph.To(FlagsTable, FieldID), sqlgraph.Edge(sqlgraph.O2M, false, FlagsTable, FlagsColumn), ) sqlgraph.HasNeighbors(s, step) @@ -728,7 +503,6 @@ func HasUIDs() predicate.Message { return predicate.Message(func(s *sql.Selector) { step := sqlgraph.NewStep( sqlgraph.From(Table, FieldID), - sqlgraph.To(UIDsTable, FieldID), sqlgraph.Edge(sqlgraph.O2M, true, UIDsTable, UIDsColumn), ) sqlgraph.HasNeighbors(s, step) diff --git a/internal/db/ent/message_create.go b/internal/db_impl/ent_db/internal/message_create.go similarity index 74% rename from internal/db/ent/message_create.go rename to internal/db_impl/ent_db/internal/message_create.go index c9d923a0..777b8d1a 100644 --- a/internal/db/ent/message_create.go +++ b/internal/db_impl/ent_db/internal/message_create.go @@ -1,6 +1,6 @@ // Code generated by ent, DO NOT EDIT. -package ent +package internal import ( "context" @@ -11,9 +11,9 @@ import ( "entgo.io/ent/dialect/sql/sqlgraph" "entgo.io/ent/schema/field" "github.com/ProtonMail/gluon/imap" - "github.com/ProtonMail/gluon/internal/db/ent/message" - "github.com/ProtonMail/gluon/internal/db/ent/messageflag" - "github.com/ProtonMail/gluon/internal/db/ent/uid" + "github.com/ProtonMail/gluon/internal/db_impl/ent_db/internal/message" + "github.com/ProtonMail/gluon/internal/db_impl/ent_db/internal/messageflag" + "github.com/ProtonMail/gluon/internal/db_impl/ent_db/internal/uid" ) // MessageCreate is the builder for creating a Message entity. @@ -124,50 +124,8 @@ func (mc *MessageCreate) Mutation() *MessageMutation { // Save creates the Message in the database. func (mc *MessageCreate) Save(ctx context.Context) (*Message, error) { - var ( - err error - node *Message - ) mc.defaults() - if len(mc.hooks) == 0 { - if err = mc.check(); err != nil { - return nil, err - } - node, err = mc.sqlSave(ctx) - } else { - var mut Mutator = MutateFunc(func(ctx context.Context, m Mutation) (Value, error) { - mutation, ok := m.(*MessageMutation) - if !ok { - return nil, fmt.Errorf("unexpected mutation type %T", m) - } - if err = mc.check(); err != nil { - return nil, err - } - mc.mutation = mutation - if node, err = mc.sqlSave(ctx); err != nil { - return nil, err - } - mutation.id = &node.ID - mutation.done = true - return node, err - }) - for i := len(mc.hooks) - 1; i >= 0; i-- { - if mc.hooks[i] == nil { - return nil, fmt.Errorf("ent: uninitialized hook (forgotten import ent/runtime?)") - } - mut = mc.hooks[i](mut) - } - v, err := mut.Mutate(ctx, mc.mutation) - if err != nil { - return nil, err - } - nv, ok := v.(*Message) - if !ok { - return nil, fmt.Errorf("unexpected node type %T returned from MessageMutation", v) - } - node = nv - } - return node, err + return withHooks[*Message, MessageMutation](ctx, mc.sqlSave, mc.mutation, mc.hooks) } // SaveX calls Save and panics if Save returns an error. @@ -203,27 +161,30 @@ func (mc *MessageCreate) defaults() { // check runs all checks and user-defined validators on the builder. func (mc *MessageCreate) check() error { if _, ok := mc.mutation.Date(); !ok { - return &ValidationError{Name: "Date", err: errors.New(`ent: missing required field "Message.Date"`)} + return &ValidationError{Name: "Date", err: errors.New(`internal: missing required field "Message.Date"`)} } if _, ok := mc.mutation.Size(); !ok { - return &ValidationError{Name: "Size", err: errors.New(`ent: missing required field "Message.Size"`)} + return &ValidationError{Name: "Size", err: errors.New(`internal: missing required field "Message.Size"`)} } if _, ok := mc.mutation.Body(); !ok { - return &ValidationError{Name: "Body", err: errors.New(`ent: missing required field "Message.Body"`)} + return &ValidationError{Name: "Body", err: errors.New(`internal: missing required field "Message.Body"`)} } if _, ok := mc.mutation.BodyStructure(); !ok { - return &ValidationError{Name: "BodyStructure", err: errors.New(`ent: missing required field "Message.BodyStructure"`)} + return &ValidationError{Name: "BodyStructure", err: errors.New(`internal: missing required field "Message.BodyStructure"`)} } if _, ok := mc.mutation.Envelope(); !ok { - return &ValidationError{Name: "Envelope", err: errors.New(`ent: missing required field "Message.Envelope"`)} + return &ValidationError{Name: "Envelope", err: errors.New(`internal: missing required field "Message.Envelope"`)} } if _, ok := mc.mutation.Deleted(); !ok { - return &ValidationError{Name: "Deleted", err: errors.New(`ent: missing required field "Message.Deleted"`)} + return &ValidationError{Name: "Deleted", err: errors.New(`internal: missing required field "Message.Deleted"`)} } return nil } func (mc *MessageCreate) sqlSave(ctx context.Context) (*Message, error) { + if err := mc.check(); err != nil { + return nil, err + } _node, _spec := mc.createSpec() if err := sqlgraph.CreateNode(ctx, mc.driver, _spec); err != nil { if sqlgraph.IsConstraintError(err) { @@ -238,78 +199,46 @@ func (mc *MessageCreate) sqlSave(ctx context.Context) (*Message, error) { return nil, err } } + mc.mutation.id = &_node.ID + mc.mutation.done = true return _node, nil } func (mc *MessageCreate) createSpec() (*Message, *sqlgraph.CreateSpec) { var ( _node = &Message{config: mc.config} - _spec = &sqlgraph.CreateSpec{ - Table: message.Table, - ID: &sqlgraph.FieldSpec{ - Type: field.TypeUUID, - Column: message.FieldID, - }, - } + _spec = sqlgraph.NewCreateSpec(message.Table, sqlgraph.NewFieldSpec(message.FieldID, field.TypeUUID)) ) if id, ok := mc.mutation.ID(); ok { _node.ID = id _spec.ID.Value = &id } if value, ok := mc.mutation.RemoteID(); ok { - _spec.Fields = append(_spec.Fields, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Value: value, - Column: message.FieldRemoteID, - }) + _spec.SetField(message.FieldRemoteID, field.TypeString, value) _node.RemoteID = value } if value, ok := mc.mutation.Date(); ok { - _spec.Fields = append(_spec.Fields, &sqlgraph.FieldSpec{ - Type: field.TypeTime, - Value: value, - Column: message.FieldDate, - }) + _spec.SetField(message.FieldDate, field.TypeTime, value) _node.Date = value } if value, ok := mc.mutation.Size(); ok { - _spec.Fields = append(_spec.Fields, &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Value: value, - Column: message.FieldSize, - }) + _spec.SetField(message.FieldSize, field.TypeInt, value) _node.Size = value } if value, ok := mc.mutation.Body(); ok { - _spec.Fields = append(_spec.Fields, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Value: value, - Column: message.FieldBody, - }) + _spec.SetField(message.FieldBody, field.TypeString, value) _node.Body = value } if value, ok := mc.mutation.BodyStructure(); ok { - _spec.Fields = append(_spec.Fields, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Value: value, - Column: message.FieldBodyStructure, - }) + _spec.SetField(message.FieldBodyStructure, field.TypeString, value) _node.BodyStructure = value } if value, ok := mc.mutation.Envelope(); ok { - _spec.Fields = append(_spec.Fields, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Value: value, - Column: message.FieldEnvelope, - }) + _spec.SetField(message.FieldEnvelope, field.TypeString, value) _node.Envelope = value } if value, ok := mc.mutation.Deleted(); ok { - _spec.Fields = append(_spec.Fields, &sqlgraph.FieldSpec{ - Type: field.TypeBool, - Value: value, - Column: message.FieldDeleted, - }) + _spec.SetField(message.FieldDeleted, field.TypeBool, value) _node.Deleted = value } if nodes := mc.mutation.FlagsIDs(); len(nodes) > 0 { diff --git a/internal/db/ent/message_delete.go b/internal/db_impl/ent_db/internal/message_delete.go similarity index 62% rename from internal/db/ent/message_delete.go rename to internal/db_impl/ent_db/internal/message_delete.go index b5094106..60640466 100644 --- a/internal/db/ent/message_delete.go +++ b/internal/db_impl/ent_db/internal/message_delete.go @@ -1,16 +1,15 @@ // Code generated by ent, DO NOT EDIT. -package ent +package internal import ( "context" - "fmt" "entgo.io/ent/dialect/sql" "entgo.io/ent/dialect/sql/sqlgraph" "entgo.io/ent/schema/field" - "github.com/ProtonMail/gluon/internal/db/ent/message" - "github.com/ProtonMail/gluon/internal/db/ent/predicate" + "github.com/ProtonMail/gluon/internal/db_impl/ent_db/internal/message" + "github.com/ProtonMail/gluon/internal/db_impl/ent_db/internal/predicate" ) // MessageDelete is the builder for deleting a Message entity. @@ -28,34 +27,7 @@ func (md *MessageDelete) Where(ps ...predicate.Message) *MessageDelete { // Exec executes the deletion query and returns how many vertices were deleted. func (md *MessageDelete) Exec(ctx context.Context) (int, error) { - var ( - err error - affected int - ) - if len(md.hooks) == 0 { - affected, err = md.sqlExec(ctx) - } else { - var mut Mutator = MutateFunc(func(ctx context.Context, m Mutation) (Value, error) { - mutation, ok := m.(*MessageMutation) - if !ok { - return nil, fmt.Errorf("unexpected mutation type %T", m) - } - md.mutation = mutation - affected, err = md.sqlExec(ctx) - mutation.done = true - return affected, err - }) - for i := len(md.hooks) - 1; i >= 0; i-- { - if md.hooks[i] == nil { - return 0, fmt.Errorf("ent: uninitialized hook (forgotten import ent/runtime?)") - } - mut = md.hooks[i](mut) - } - if _, err := mut.Mutate(ctx, md.mutation); err != nil { - return 0, err - } - } - return affected, err + return withHooks[int, MessageMutation](ctx, md.sqlExec, md.mutation, md.hooks) } // ExecX is like Exec, but panics if an error occurs. @@ -68,15 +40,7 @@ func (md *MessageDelete) ExecX(ctx context.Context) int { } func (md *MessageDelete) sqlExec(ctx context.Context) (int, error) { - _spec := &sqlgraph.DeleteSpec{ - Node: &sqlgraph.NodeSpec{ - Table: message.Table, - ID: &sqlgraph.FieldSpec{ - Type: field.TypeUUID, - Column: message.FieldID, - }, - }, - } + _spec := sqlgraph.NewDeleteSpec(message.Table, sqlgraph.NewFieldSpec(message.FieldID, field.TypeUUID)) if ps := md.mutation.predicates; len(ps) > 0 { _spec.Predicate = func(selector *sql.Selector) { for i := range ps { @@ -88,6 +52,7 @@ func (md *MessageDelete) sqlExec(ctx context.Context) (int, error) { if err != nil && sqlgraph.IsConstraintError(err) { err = &ConstraintError{msg: err.Error(), wrap: err} } + md.mutation.done = true return affected, err } @@ -96,6 +61,12 @@ type MessageDeleteOne struct { md *MessageDelete } +// Where appends a list predicates to the MessageDelete builder. +func (mdo *MessageDeleteOne) Where(ps ...predicate.Message) *MessageDeleteOne { + mdo.md.mutation.Where(ps...) + return mdo +} + // Exec executes the deletion query. func (mdo *MessageDeleteOne) Exec(ctx context.Context) error { n, err := mdo.md.Exec(ctx) @@ -111,5 +82,7 @@ func (mdo *MessageDeleteOne) Exec(ctx context.Context) error { // ExecX is like Exec, but panics if an error occurs. func (mdo *MessageDeleteOne) ExecX(ctx context.Context) { - mdo.md.ExecX(ctx) + if err := mdo.Exec(ctx); err != nil { + panic(err) + } } diff --git a/internal/db/ent/message_query.go b/internal/db_impl/ent_db/internal/message_query.go similarity index 73% rename from internal/db/ent/message_query.go rename to internal/db_impl/ent_db/internal/message_query.go index 84416328..d1ac2c4c 100644 --- a/internal/db/ent/message_query.go +++ b/internal/db_impl/ent_db/internal/message_query.go @@ -1,6 +1,6 @@ // Code generated by ent, DO NOT EDIT. -package ent +package internal import ( "context" @@ -12,20 +12,18 @@ import ( "entgo.io/ent/dialect/sql/sqlgraph" "entgo.io/ent/schema/field" "github.com/ProtonMail/gluon/imap" - "github.com/ProtonMail/gluon/internal/db/ent/message" - "github.com/ProtonMail/gluon/internal/db/ent/messageflag" - "github.com/ProtonMail/gluon/internal/db/ent/predicate" - "github.com/ProtonMail/gluon/internal/db/ent/uid" + "github.com/ProtonMail/gluon/internal/db_impl/ent_db/internal/message" + "github.com/ProtonMail/gluon/internal/db_impl/ent_db/internal/messageflag" + "github.com/ProtonMail/gluon/internal/db_impl/ent_db/internal/predicate" + "github.com/ProtonMail/gluon/internal/db_impl/ent_db/internal/uid" ) // MessageQuery is the builder for querying Message entities. type MessageQuery struct { config - limit *int - offset *int - unique *bool + ctx *QueryContext order []OrderFunc - fields []string + inters []Interceptor predicates []predicate.Message withFlags *MessageFlagQuery withUIDs *UIDQuery @@ -40,26 +38,26 @@ func (mq *MessageQuery) Where(ps ...predicate.Message) *MessageQuery { return mq } -// Limit adds a limit step to the query. +// Limit the number of records to be returned by this query. func (mq *MessageQuery) Limit(limit int) *MessageQuery { - mq.limit = &limit + mq.ctx.Limit = &limit return mq } -// Offset adds an offset step to the query. +// Offset to start from. func (mq *MessageQuery) Offset(offset int) *MessageQuery { - mq.offset = &offset + mq.ctx.Offset = &offset return mq } // Unique configures the query builder to filter duplicate records on query. // By default, unique is set to true, and can be disabled using this method. func (mq *MessageQuery) Unique(unique bool) *MessageQuery { - mq.unique = &unique + mq.ctx.Unique = &unique return mq } -// Order adds an order step to the query. +// Order specifies how the records should be ordered. func (mq *MessageQuery) Order(o ...OrderFunc) *MessageQuery { mq.order = append(mq.order, o...) return mq @@ -67,7 +65,7 @@ func (mq *MessageQuery) Order(o ...OrderFunc) *MessageQuery { // QueryFlags chains the current query on the "flags" edge. func (mq *MessageQuery) QueryFlags() *MessageFlagQuery { - query := &MessageFlagQuery{config: mq.config} + query := (&MessageFlagClient{config: mq.config}).Query() query.path = func(ctx context.Context) (fromU *sql.Selector, err error) { if err := mq.prepareQuery(ctx); err != nil { return nil, err @@ -89,7 +87,7 @@ func (mq *MessageQuery) QueryFlags() *MessageFlagQuery { // QueryUIDs chains the current query on the "UIDs" edge. func (mq *MessageQuery) QueryUIDs() *UIDQuery { - query := &UIDQuery{config: mq.config} + query := (&UIDClient{config: mq.config}).Query() query.path = func(ctx context.Context) (fromU *sql.Selector, err error) { if err := mq.prepareQuery(ctx); err != nil { return nil, err @@ -112,7 +110,7 @@ func (mq *MessageQuery) QueryUIDs() *UIDQuery { // First returns the first Message entity from the query. // Returns a *NotFoundError when no Message was found. func (mq *MessageQuery) First(ctx context.Context) (*Message, error) { - nodes, err := mq.Limit(1).All(ctx) + nodes, err := mq.Limit(1).All(setContextOp(ctx, mq.ctx, "First")) if err != nil { return nil, err } @@ -135,7 +133,7 @@ func (mq *MessageQuery) FirstX(ctx context.Context) *Message { // Returns a *NotFoundError when no Message ID was found. func (mq *MessageQuery) FirstID(ctx context.Context) (id imap.InternalMessageID, err error) { var ids []imap.InternalMessageID - if ids, err = mq.Limit(1).IDs(ctx); err != nil { + if ids, err = mq.Limit(1).IDs(setContextOp(ctx, mq.ctx, "FirstID")); err != nil { return } if len(ids) == 0 { @@ -158,7 +156,7 @@ func (mq *MessageQuery) FirstIDX(ctx context.Context) imap.InternalMessageID { // Returns a *NotSingularError when more than one Message entity is found. // Returns a *NotFoundError when no Message entities are found. func (mq *MessageQuery) Only(ctx context.Context) (*Message, error) { - nodes, err := mq.Limit(2).All(ctx) + nodes, err := mq.Limit(2).All(setContextOp(ctx, mq.ctx, "Only")) if err != nil { return nil, err } @@ -186,7 +184,7 @@ func (mq *MessageQuery) OnlyX(ctx context.Context) *Message { // Returns a *NotFoundError when no entities are found. func (mq *MessageQuery) OnlyID(ctx context.Context) (id imap.InternalMessageID, err error) { var ids []imap.InternalMessageID - if ids, err = mq.Limit(2).IDs(ctx); err != nil { + if ids, err = mq.Limit(2).IDs(setContextOp(ctx, mq.ctx, "OnlyID")); err != nil { return } switch len(ids) { @@ -211,10 +209,12 @@ func (mq *MessageQuery) OnlyIDX(ctx context.Context) imap.InternalMessageID { // All executes the query and returns a list of Messages. func (mq *MessageQuery) All(ctx context.Context) ([]*Message, error) { + ctx = setContextOp(ctx, mq.ctx, "All") if err := mq.prepareQuery(ctx); err != nil { return nil, err } - return mq.sqlAll(ctx) + qr := querierAll[[]*Message, *MessageQuery]() + return withInterceptors[[]*Message](ctx, mq, qr, mq.inters) } // AllX is like All, but panics if an error occurs. @@ -227,9 +227,12 @@ func (mq *MessageQuery) AllX(ctx context.Context) []*Message { } // IDs executes the query and returns a list of Message IDs. -func (mq *MessageQuery) IDs(ctx context.Context) ([]imap.InternalMessageID, error) { - var ids []imap.InternalMessageID - if err := mq.Select(message.FieldID).Scan(ctx, &ids); err != nil { +func (mq *MessageQuery) IDs(ctx context.Context) (ids []imap.InternalMessageID, err error) { + if mq.ctx.Unique == nil && mq.path != nil { + mq.Unique(true) + } + ctx = setContextOp(ctx, mq.ctx, "IDs") + if err = mq.Select(message.FieldID).Scan(ctx, &ids); err != nil { return nil, err } return ids, nil @@ -246,10 +249,11 @@ func (mq *MessageQuery) IDsX(ctx context.Context) []imap.InternalMessageID { // Count returns the count of the given query. func (mq *MessageQuery) Count(ctx context.Context) (int, error) { + ctx = setContextOp(ctx, mq.ctx, "Count") if err := mq.prepareQuery(ctx); err != nil { return 0, err } - return mq.sqlCount(ctx) + return withInterceptors[int](ctx, mq, querierCount[*MessageQuery](), mq.inters) } // CountX is like Count, but panics if an error occurs. @@ -263,10 +267,15 @@ func (mq *MessageQuery) CountX(ctx context.Context) int { // Exist returns true if the query has elements in the graph. func (mq *MessageQuery) Exist(ctx context.Context) (bool, error) { - if err := mq.prepareQuery(ctx); err != nil { - return false, err + ctx = setContextOp(ctx, mq.ctx, "Exist") + switch _, err := mq.FirstID(ctx); { + case IsNotFound(err): + return false, nil + case err != nil: + return false, fmt.Errorf("internal: check existence: %w", err) + default: + return true, nil } - return mq.sqlExist(ctx) } // ExistX is like Exist, but panics if an error occurs. @@ -286,23 +295,22 @@ func (mq *MessageQuery) Clone() *MessageQuery { } return &MessageQuery{ config: mq.config, - limit: mq.limit, - offset: mq.offset, + ctx: mq.ctx.Clone(), order: append([]OrderFunc{}, mq.order...), + inters: append([]Interceptor{}, mq.inters...), predicates: append([]predicate.Message{}, mq.predicates...), withFlags: mq.withFlags.Clone(), withUIDs: mq.withUIDs.Clone(), // clone intermediate query. - sql: mq.sql.Clone(), - path: mq.path, - unique: mq.unique, + sql: mq.sql.Clone(), + path: mq.path, } } // WithFlags tells the query-builder to eager-load the nodes that are connected to // the "flags" edge. The optional arguments are used to configure the query builder of the edge. func (mq *MessageQuery) WithFlags(opts ...func(*MessageFlagQuery)) *MessageQuery { - query := &MessageFlagQuery{config: mq.config} + query := (&MessageFlagClient{config: mq.config}).Query() for _, opt := range opts { opt(query) } @@ -313,7 +321,7 @@ func (mq *MessageQuery) WithFlags(opts ...func(*MessageFlagQuery)) *MessageQuery // WithUIDs tells the query-builder to eager-load the nodes that are connected to // the "UIDs" edge. The optional arguments are used to configure the query builder of the edge. func (mq *MessageQuery) WithUIDs(opts ...func(*UIDQuery)) *MessageQuery { - query := &UIDQuery{config: mq.config} + query := (&UIDClient{config: mq.config}).Query() for _, opt := range opts { opt(query) } @@ -333,19 +341,14 @@ func (mq *MessageQuery) WithUIDs(opts ...func(*UIDQuery)) *MessageQuery { // // client.Message.Query(). // GroupBy(message.FieldRemoteID). -// Aggregate(ent.Count()). +// Aggregate(internal.Count()). // Scan(ctx, &v) func (mq *MessageQuery) GroupBy(field string, fields ...string) *MessageGroupBy { - grbuild := &MessageGroupBy{config: mq.config} - grbuild.fields = append([]string{field}, fields...) - grbuild.path = func(ctx context.Context) (prev *sql.Selector, err error) { - if err := mq.prepareQuery(ctx); err != nil { - return nil, err - } - return mq.sqlQuery(ctx), nil - } + mq.ctx.Fields = append([]string{field}, fields...) + grbuild := &MessageGroupBy{build: mq} + grbuild.flds = &mq.ctx.Fields grbuild.label = message.Label - grbuild.flds, grbuild.scan = &grbuild.fields, grbuild.Scan + grbuild.scan = grbuild.Scan return grbuild } @@ -362,17 +365,32 @@ func (mq *MessageQuery) GroupBy(field string, fields ...string) *MessageGroupBy // Select(message.FieldRemoteID). // Scan(ctx, &v) func (mq *MessageQuery) Select(fields ...string) *MessageSelect { - mq.fields = append(mq.fields, fields...) - selbuild := &MessageSelect{MessageQuery: mq} - selbuild.label = message.Label - selbuild.flds, selbuild.scan = &mq.fields, selbuild.Scan - return selbuild + mq.ctx.Fields = append(mq.ctx.Fields, fields...) + sbuild := &MessageSelect{MessageQuery: mq} + sbuild.label = message.Label + sbuild.flds, sbuild.scan = &mq.ctx.Fields, sbuild.Scan + return sbuild +} + +// Aggregate returns a MessageSelect configured with the given aggregations. +func (mq *MessageQuery) Aggregate(fns ...AggregateFunc) *MessageSelect { + return mq.Select().Aggregate(fns...) } func (mq *MessageQuery) prepareQuery(ctx context.Context) error { - for _, f := range mq.fields { + for _, inter := range mq.inters { + if inter == nil { + return fmt.Errorf("internal: uninitialized interceptor (forgotten import internal/runtime?)") + } + if trv, ok := inter.(Traverser); ok { + if err := trv.Traverse(ctx, mq); err != nil { + return err + } + } + } + for _, f := range mq.ctx.Fields { if !message.ValidColumn(f) { - return &ValidationError{Name: f, err: fmt.Errorf("ent: invalid field %q for query", f)} + return &ValidationError{Name: f, err: fmt.Errorf("internal: invalid field %q for query", f)} } } if mq.path != nil { @@ -394,10 +412,10 @@ func (mq *MessageQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*Mess mq.withUIDs != nil, } ) - _spec.ScanValues = func(columns []string) ([]interface{}, error) { + _spec.ScanValues = func(columns []string) ([]any, error) { return (*Message).scanValues(nil, columns) } - _spec.Assign = func(columns []string, values []interface{}) error { + _spec.Assign = func(columns []string, values []any) error { node := &Message{config: mq.config} nodes = append(nodes, node) node.Edges.loadedTypes = loadedTypes @@ -494,38 +512,22 @@ func (mq *MessageQuery) loadUIDs(ctx context.Context, query *UIDQuery, nodes []* func (mq *MessageQuery) sqlCount(ctx context.Context) (int, error) { _spec := mq.querySpec() - _spec.Node.Columns = mq.fields - if len(mq.fields) > 0 { - _spec.Unique = mq.unique != nil && *mq.unique + _spec.Node.Columns = mq.ctx.Fields + if len(mq.ctx.Fields) > 0 { + _spec.Unique = mq.ctx.Unique != nil && *mq.ctx.Unique } return sqlgraph.CountNodes(ctx, mq.driver, _spec) } -func (mq *MessageQuery) sqlExist(ctx context.Context) (bool, error) { - n, err := mq.sqlCount(ctx) - if err != nil { - return false, fmt.Errorf("ent: check existence: %w", err) - } - return n > 0, nil -} - func (mq *MessageQuery) querySpec() *sqlgraph.QuerySpec { - _spec := &sqlgraph.QuerySpec{ - Node: &sqlgraph.NodeSpec{ - Table: message.Table, - Columns: message.Columns, - ID: &sqlgraph.FieldSpec{ - Type: field.TypeUUID, - Column: message.FieldID, - }, - }, - From: mq.sql, - Unique: true, - } - if unique := mq.unique; unique != nil { + _spec := sqlgraph.NewQuerySpec(message.Table, message.Columns, sqlgraph.NewFieldSpec(message.FieldID, field.TypeUUID)) + _spec.From = mq.sql + if unique := mq.ctx.Unique; unique != nil { _spec.Unique = *unique + } else if mq.path != nil { + _spec.Unique = true } - if fields := mq.fields; len(fields) > 0 { + if fields := mq.ctx.Fields; len(fields) > 0 { _spec.Node.Columns = make([]string, 0, len(fields)) _spec.Node.Columns = append(_spec.Node.Columns, message.FieldID) for i := range fields { @@ -541,10 +543,10 @@ func (mq *MessageQuery) querySpec() *sqlgraph.QuerySpec { } } } - if limit := mq.limit; limit != nil { + if limit := mq.ctx.Limit; limit != nil { _spec.Limit = *limit } - if offset := mq.offset; offset != nil { + if offset := mq.ctx.Offset; offset != nil { _spec.Offset = *offset } if ps := mq.order; len(ps) > 0 { @@ -560,7 +562,7 @@ func (mq *MessageQuery) querySpec() *sqlgraph.QuerySpec { func (mq *MessageQuery) sqlQuery(ctx context.Context) *sql.Selector { builder := sql.Dialect(mq.driver.Dialect()) t1 := builder.Table(message.Table) - columns := mq.fields + columns := mq.ctx.Fields if len(columns) == 0 { columns = message.Columns } @@ -569,7 +571,7 @@ func (mq *MessageQuery) sqlQuery(ctx context.Context) *sql.Selector { selector = mq.sql selector.Select(selector.Columns(columns...)...) } - if mq.unique != nil && *mq.unique { + if mq.ctx.Unique != nil && *mq.ctx.Unique { selector.Distinct() } for _, p := range mq.predicates { @@ -578,12 +580,12 @@ func (mq *MessageQuery) sqlQuery(ctx context.Context) *sql.Selector { for _, p := range mq.order { p(selector) } - if offset := mq.offset; offset != nil { + if offset := mq.ctx.Offset; offset != nil { // limit is mandatory for offset clause. We start // with default value, and override it below if needed. selector.Offset(*offset).Limit(math.MaxInt32) } - if limit := mq.limit; limit != nil { + if limit := mq.ctx.Limit; limit != nil { selector.Limit(*limit) } return selector @@ -591,13 +593,8 @@ func (mq *MessageQuery) sqlQuery(ctx context.Context) *sql.Selector { // MessageGroupBy is the group-by builder for Message entities. type MessageGroupBy struct { - config selector - fields []string - fns []AggregateFunc - // intermediate query (i.e. traversal path). - sql *sql.Selector - path func(context.Context) (*sql.Selector, error) + build *MessageQuery } // Aggregate adds the given aggregation functions to the group-by query. @@ -606,74 +603,77 @@ func (mgb *MessageGroupBy) Aggregate(fns ...AggregateFunc) *MessageGroupBy { return mgb } -// Scan applies the group-by query and scans the result into the given value. -func (mgb *MessageGroupBy) Scan(ctx context.Context, v interface{}) error { - query, err := mgb.path(ctx) - if err != nil { +// Scan applies the selector query and scans the result into the given value. +func (mgb *MessageGroupBy) Scan(ctx context.Context, v any) error { + ctx = setContextOp(ctx, mgb.build.ctx, "GroupBy") + if err := mgb.build.prepareQuery(ctx); err != nil { return err } - mgb.sql = query - return mgb.sqlScan(ctx, v) + return scanWithInterceptors[*MessageQuery, *MessageGroupBy](ctx, mgb.build, mgb, mgb.build.inters, v) } -func (mgb *MessageGroupBy) sqlScan(ctx context.Context, v interface{}) error { - for _, f := range mgb.fields { - if !message.ValidColumn(f) { - return &ValidationError{Name: f, err: fmt.Errorf("invalid field %q for group-by", f)} +func (mgb *MessageGroupBy) sqlScan(ctx context.Context, root *MessageQuery, v any) error { + selector := root.sqlQuery(ctx).Select() + aggregation := make([]string, 0, len(mgb.fns)) + for _, fn := range mgb.fns { + aggregation = append(aggregation, fn(selector)) + } + if len(selector.SelectedColumns()) == 0 { + columns := make([]string, 0, len(*mgb.flds)+len(mgb.fns)) + for _, f := range *mgb.flds { + columns = append(columns, selector.C(f)) } + columns = append(columns, aggregation...) + selector.Select(columns...) } - selector := mgb.sqlQuery() + selector.GroupBy(selector.Columns(*mgb.flds...)...) if err := selector.Err(); err != nil { return err } rows := &sql.Rows{} query, args := selector.Query() - if err := mgb.driver.Query(ctx, query, args, rows); err != nil { + if err := mgb.build.driver.Query(ctx, query, args, rows); err != nil { return err } defer rows.Close() return sql.ScanSlice(rows, v) } -func (mgb *MessageGroupBy) sqlQuery() *sql.Selector { - selector := mgb.sql.Select() - aggregation := make([]string, 0, len(mgb.fns)) - for _, fn := range mgb.fns { - aggregation = append(aggregation, fn(selector)) - } - // If no columns were selected in a custom aggregation function, the default - // selection is the fields used for "group-by", and the aggregation functions. - if len(selector.SelectedColumns()) == 0 { - columns := make([]string, 0, len(mgb.fields)+len(mgb.fns)) - for _, f := range mgb.fields { - columns = append(columns, selector.C(f)) - } - columns = append(columns, aggregation...) - selector.Select(columns...) - } - return selector.GroupBy(selector.Columns(mgb.fields...)...) -} - // MessageSelect is the builder for selecting fields of Message entities. type MessageSelect struct { *MessageQuery selector - // intermediate query (i.e. traversal path). - sql *sql.Selector +} + +// Aggregate adds the given aggregation functions to the selector query. +func (ms *MessageSelect) Aggregate(fns ...AggregateFunc) *MessageSelect { + ms.fns = append(ms.fns, fns...) + return ms } // Scan applies the selector query and scans the result into the given value. -func (ms *MessageSelect) Scan(ctx context.Context, v interface{}) error { +func (ms *MessageSelect) Scan(ctx context.Context, v any) error { + ctx = setContextOp(ctx, ms.ctx, "Select") if err := ms.prepareQuery(ctx); err != nil { return err } - ms.sql = ms.MessageQuery.sqlQuery(ctx) - return ms.sqlScan(ctx, v) + return scanWithInterceptors[*MessageQuery, *MessageSelect](ctx, ms.MessageQuery, ms, ms.inters, v) } -func (ms *MessageSelect) sqlScan(ctx context.Context, v interface{}) error { +func (ms *MessageSelect) sqlScan(ctx context.Context, root *MessageQuery, v any) error { + selector := root.sqlQuery(ctx) + aggregation := make([]string, 0, len(ms.fns)) + for _, fn := range ms.fns { + aggregation = append(aggregation, fn(selector)) + } + switch n := len(*ms.selector.flds); { + case n == 0 && len(aggregation) > 0: + selector.Select(aggregation...) + case n != 0 && len(aggregation) > 0: + selector.AppendSelect(aggregation...) + } rows := &sql.Rows{} - query, args := ms.sql.Query() + query, args := selector.Query() if err := ms.driver.Query(ctx, query, args, rows); err != nil { return err } diff --git a/internal/db/ent/message_update.go b/internal/db_impl/ent_db/internal/message_update.go similarity index 78% rename from internal/db/ent/message_update.go rename to internal/db_impl/ent_db/internal/message_update.go index 15564592..ac56d906 100644 --- a/internal/db/ent/message_update.go +++ b/internal/db_impl/ent_db/internal/message_update.go @@ -1,6 +1,6 @@ // Code generated by ent, DO NOT EDIT. -package ent +package internal import ( "context" @@ -12,10 +12,10 @@ import ( "entgo.io/ent/dialect/sql/sqlgraph" "entgo.io/ent/schema/field" "github.com/ProtonMail/gluon/imap" - "github.com/ProtonMail/gluon/internal/db/ent/message" - "github.com/ProtonMail/gluon/internal/db/ent/messageflag" - "github.com/ProtonMail/gluon/internal/db/ent/predicate" - "github.com/ProtonMail/gluon/internal/db/ent/uid" + "github.com/ProtonMail/gluon/internal/db_impl/ent_db/internal/message" + "github.com/ProtonMail/gluon/internal/db_impl/ent_db/internal/messageflag" + "github.com/ProtonMail/gluon/internal/db_impl/ent_db/internal/predicate" + "github.com/ProtonMail/gluon/internal/db_impl/ent_db/internal/uid" ) // MessageUpdate is the builder for updating Message entities. @@ -181,34 +181,7 @@ func (mu *MessageUpdate) RemoveUIDs(u ...*UID) *MessageUpdate { // Save executes the query and returns the number of nodes affected by the update operation. func (mu *MessageUpdate) Save(ctx context.Context) (int, error) { - var ( - err error - affected int - ) - if len(mu.hooks) == 0 { - affected, err = mu.sqlSave(ctx) - } else { - var mut Mutator = MutateFunc(func(ctx context.Context, m Mutation) (Value, error) { - mutation, ok := m.(*MessageMutation) - if !ok { - return nil, fmt.Errorf("unexpected mutation type %T", m) - } - mu.mutation = mutation - affected, err = mu.sqlSave(ctx) - mutation.done = true - return affected, err - }) - for i := len(mu.hooks) - 1; i >= 0; i-- { - if mu.hooks[i] == nil { - return 0, fmt.Errorf("ent: uninitialized hook (forgotten import ent/runtime?)") - } - mut = mu.hooks[i](mut) - } - if _, err := mut.Mutate(ctx, mu.mutation); err != nil { - return 0, err - } - } - return affected, err + return withHooks[int, MessageMutation](ctx, mu.sqlSave, mu.mutation, mu.hooks) } // SaveX is like Save, but panics if an error occurs. @@ -234,16 +207,7 @@ func (mu *MessageUpdate) ExecX(ctx context.Context) { } func (mu *MessageUpdate) sqlSave(ctx context.Context) (n int, err error) { - _spec := &sqlgraph.UpdateSpec{ - Node: &sqlgraph.NodeSpec{ - Table: message.Table, - Columns: message.Columns, - ID: &sqlgraph.FieldSpec{ - Type: field.TypeUUID, - Column: message.FieldID, - }, - }, - } + _spec := sqlgraph.NewUpdateSpec(message.Table, message.Columns, sqlgraph.NewFieldSpec(message.FieldID, field.TypeUUID)) if ps := mu.mutation.predicates; len(ps) > 0 { _spec.Predicate = func(selector *sql.Selector) { for i := range ps { @@ -252,66 +216,31 @@ func (mu *MessageUpdate) sqlSave(ctx context.Context) (n int, err error) { } } if value, ok := mu.mutation.RemoteID(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Value: value, - Column: message.FieldRemoteID, - }) + _spec.SetField(message.FieldRemoteID, field.TypeString, value) } if mu.mutation.RemoteIDCleared() { - _spec.Fields.Clear = append(_spec.Fields.Clear, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Column: message.FieldRemoteID, - }) + _spec.ClearField(message.FieldRemoteID, field.TypeString) } if value, ok := mu.mutation.Date(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeTime, - Value: value, - Column: message.FieldDate, - }) + _spec.SetField(message.FieldDate, field.TypeTime, value) } if value, ok := mu.mutation.Size(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Value: value, - Column: message.FieldSize, - }) + _spec.SetField(message.FieldSize, field.TypeInt, value) } if value, ok := mu.mutation.AddedSize(); ok { - _spec.Fields.Add = append(_spec.Fields.Add, &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Value: value, - Column: message.FieldSize, - }) + _spec.AddField(message.FieldSize, field.TypeInt, value) } if value, ok := mu.mutation.Body(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Value: value, - Column: message.FieldBody, - }) + _spec.SetField(message.FieldBody, field.TypeString, value) } if value, ok := mu.mutation.BodyStructure(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Value: value, - Column: message.FieldBodyStructure, - }) + _spec.SetField(message.FieldBodyStructure, field.TypeString, value) } if value, ok := mu.mutation.Envelope(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Value: value, - Column: message.FieldEnvelope, - }) + _spec.SetField(message.FieldEnvelope, field.TypeString, value) } if value, ok := mu.mutation.Deleted(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeBool, - Value: value, - Column: message.FieldDeleted, - }) + _spec.SetField(message.FieldDeleted, field.TypeBool, value) } if mu.mutation.FlagsCleared() { edge := &sqlgraph.EdgeSpec{ @@ -429,6 +358,7 @@ func (mu *MessageUpdate) sqlSave(ctx context.Context) (n int, err error) { } return 0, err } + mu.mutation.done = true return n, nil } @@ -588,6 +518,12 @@ func (muo *MessageUpdateOne) RemoveUIDs(u ...*UID) *MessageUpdateOne { return muo.RemoveUIDIDs(ids...) } +// Where appends a list predicates to the MessageUpdate builder. +func (muo *MessageUpdateOne) Where(ps ...predicate.Message) *MessageUpdateOne { + muo.mutation.Where(ps...) + return muo +} + // Select allows selecting one or more fields (columns) of the returned entity. // The default is selecting all fields defined in the entity schema. func (muo *MessageUpdateOne) Select(field string, fields ...string) *MessageUpdateOne { @@ -597,40 +533,7 @@ func (muo *MessageUpdateOne) Select(field string, fields ...string) *MessageUpda // Save executes the query and returns the updated Message entity. func (muo *MessageUpdateOne) Save(ctx context.Context) (*Message, error) { - var ( - err error - node *Message - ) - if len(muo.hooks) == 0 { - node, err = muo.sqlSave(ctx) - } else { - var mut Mutator = MutateFunc(func(ctx context.Context, m Mutation) (Value, error) { - mutation, ok := m.(*MessageMutation) - if !ok { - return nil, fmt.Errorf("unexpected mutation type %T", m) - } - muo.mutation = mutation - node, err = muo.sqlSave(ctx) - mutation.done = true - return node, err - }) - for i := len(muo.hooks) - 1; i >= 0; i-- { - if muo.hooks[i] == nil { - return nil, fmt.Errorf("ent: uninitialized hook (forgotten import ent/runtime?)") - } - mut = muo.hooks[i](mut) - } - v, err := mut.Mutate(ctx, muo.mutation) - if err != nil { - return nil, err - } - nv, ok := v.(*Message) - if !ok { - return nil, fmt.Errorf("unexpected node type %T returned from MessageMutation", v) - } - node = nv - } - return node, err + return withHooks[*Message, MessageMutation](ctx, muo.sqlSave, muo.mutation, muo.hooks) } // SaveX is like Save, but panics if an error occurs. @@ -656,19 +559,10 @@ func (muo *MessageUpdateOne) ExecX(ctx context.Context) { } func (muo *MessageUpdateOne) sqlSave(ctx context.Context) (_node *Message, err error) { - _spec := &sqlgraph.UpdateSpec{ - Node: &sqlgraph.NodeSpec{ - Table: message.Table, - Columns: message.Columns, - ID: &sqlgraph.FieldSpec{ - Type: field.TypeUUID, - Column: message.FieldID, - }, - }, - } + _spec := sqlgraph.NewUpdateSpec(message.Table, message.Columns, sqlgraph.NewFieldSpec(message.FieldID, field.TypeUUID)) id, ok := muo.mutation.ID() if !ok { - return nil, &ValidationError{Name: "id", err: errors.New(`ent: missing "Message.id" for update`)} + return nil, &ValidationError{Name: "id", err: errors.New(`internal: missing "Message.id" for update`)} } _spec.Node.ID.Value = id if fields := muo.fields; len(fields) > 0 { @@ -676,7 +570,7 @@ func (muo *MessageUpdateOne) sqlSave(ctx context.Context) (_node *Message, err e _spec.Node.Columns = append(_spec.Node.Columns, message.FieldID) for _, f := range fields { if !message.ValidColumn(f) { - return nil, &ValidationError{Name: f, err: fmt.Errorf("ent: invalid field %q for query", f)} + return nil, &ValidationError{Name: f, err: fmt.Errorf("internal: invalid field %q for query", f)} } if f != message.FieldID { _spec.Node.Columns = append(_spec.Node.Columns, f) @@ -691,66 +585,31 @@ func (muo *MessageUpdateOne) sqlSave(ctx context.Context) (_node *Message, err e } } if value, ok := muo.mutation.RemoteID(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Value: value, - Column: message.FieldRemoteID, - }) + _spec.SetField(message.FieldRemoteID, field.TypeString, value) } if muo.mutation.RemoteIDCleared() { - _spec.Fields.Clear = append(_spec.Fields.Clear, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Column: message.FieldRemoteID, - }) + _spec.ClearField(message.FieldRemoteID, field.TypeString) } if value, ok := muo.mutation.Date(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeTime, - Value: value, - Column: message.FieldDate, - }) + _spec.SetField(message.FieldDate, field.TypeTime, value) } if value, ok := muo.mutation.Size(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Value: value, - Column: message.FieldSize, - }) + _spec.SetField(message.FieldSize, field.TypeInt, value) } if value, ok := muo.mutation.AddedSize(); ok { - _spec.Fields.Add = append(_spec.Fields.Add, &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Value: value, - Column: message.FieldSize, - }) + _spec.AddField(message.FieldSize, field.TypeInt, value) } if value, ok := muo.mutation.Body(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Value: value, - Column: message.FieldBody, - }) + _spec.SetField(message.FieldBody, field.TypeString, value) } if value, ok := muo.mutation.BodyStructure(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Value: value, - Column: message.FieldBodyStructure, - }) + _spec.SetField(message.FieldBodyStructure, field.TypeString, value) } if value, ok := muo.mutation.Envelope(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Value: value, - Column: message.FieldEnvelope, - }) + _spec.SetField(message.FieldEnvelope, field.TypeString, value) } if value, ok := muo.mutation.Deleted(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeBool, - Value: value, - Column: message.FieldDeleted, - }) + _spec.SetField(message.FieldDeleted, field.TypeBool, value) } if muo.mutation.FlagsCleared() { edge := &sqlgraph.EdgeSpec{ @@ -871,5 +730,6 @@ func (muo *MessageUpdateOne) sqlSave(ctx context.Context) (_node *Message, err e } return nil, err } + muo.mutation.done = true return _node, nil } diff --git a/internal/db/ent/messageflag.go b/internal/db_impl/ent_db/internal/messageflag.go similarity index 86% rename from internal/db/ent/messageflag.go rename to internal/db_impl/ent_db/internal/messageflag.go index f1d82172..331e2cef 100644 --- a/internal/db/ent/messageflag.go +++ b/internal/db_impl/ent_db/internal/messageflag.go @@ -1,6 +1,6 @@ // Code generated by ent, DO NOT EDIT. -package ent +package internal import ( "fmt" @@ -8,8 +8,8 @@ import ( "entgo.io/ent/dialect/sql" "github.com/ProtonMail/gluon/imap" - "github.com/ProtonMail/gluon/internal/db/ent/message" - "github.com/ProtonMail/gluon/internal/db/ent/messageflag" + "github.com/ProtonMail/gluon/internal/db_impl/ent_db/internal/message" + "github.com/ProtonMail/gluon/internal/db_impl/ent_db/internal/messageflag" ) // MessageFlag is the model entity for the MessageFlag schema. @@ -48,8 +48,8 @@ func (e MessageFlagEdges) MessagesOrErr() (*Message, error) { } // scanValues returns the types for scanning values from sql.Rows. -func (*MessageFlag) scanValues(columns []string) ([]interface{}, error) { - values := make([]interface{}, len(columns)) +func (*MessageFlag) scanValues(columns []string) ([]any, error) { + values := make([]any, len(columns)) for i := range columns { switch columns[i] { case messageflag.FieldID: @@ -67,7 +67,7 @@ func (*MessageFlag) scanValues(columns []string) ([]interface{}, error) { // assignValues assigns the values that were returned from sql.Rows (after scanning) // to the MessageFlag fields. -func (mf *MessageFlag) assignValues(columns []string, values []interface{}) error { +func (mf *MessageFlag) assignValues(columns []string, values []any) error { if m, n := len(values), len(columns); m < n { return fmt.Errorf("mismatch number of scan values: %d != %d", m, n) } @@ -99,14 +99,14 @@ func (mf *MessageFlag) assignValues(columns []string, values []interface{}) erro // QueryMessages queries the "messages" edge of the MessageFlag entity. func (mf *MessageFlag) QueryMessages() *MessageQuery { - return (&MessageFlagClient{config: mf.config}).QueryMessages(mf) + return NewMessageFlagClient(mf.config).QueryMessages(mf) } // Update returns a builder for updating this MessageFlag. // Note that you need to call MessageFlag.Unwrap() before calling this method if this MessageFlag // was returned from a transaction, and the transaction was committed or rolled back. func (mf *MessageFlag) Update() *MessageFlagUpdateOne { - return (&MessageFlagClient{config: mf.config}).UpdateOne(mf) + return NewMessageFlagClient(mf.config).UpdateOne(mf) } // Unwrap unwraps the MessageFlag entity that was returned from a transaction after it was closed, @@ -114,7 +114,7 @@ func (mf *MessageFlag) Update() *MessageFlagUpdateOne { func (mf *MessageFlag) Unwrap() *MessageFlag { _tx, ok := mf.config.driver.(*txDriver) if !ok { - panic("ent: MessageFlag is not a transactional entity") + panic("internal: MessageFlag is not a transactional entity") } mf.config.driver = _tx.drv return mf @@ -133,9 +133,3 @@ func (mf *MessageFlag) String() string { // MessageFlags is a parsable slice of MessageFlag. type MessageFlags []*MessageFlag - -func (mf MessageFlags) config(cfg config) { - for _i := range mf { - mf[_i].config = cfg - } -} diff --git a/internal/db/ent/messageflag/messageflag.go b/internal/db_impl/ent_db/internal/messageflag/messageflag.go similarity index 100% rename from internal/db/ent/messageflag/messageflag.go rename to internal/db_impl/ent_db/internal/messageflag/messageflag.go diff --git a/internal/db/ent/messageflag/where.go b/internal/db_impl/ent_db/internal/messageflag/where.go similarity index 62% rename from internal/db/ent/messageflag/where.go rename to internal/db_impl/ent_db/internal/messageflag/where.go index a4377f12..65cc881a 100644 --- a/internal/db/ent/messageflag/where.go +++ b/internal/db_impl/ent_db/internal/messageflag/where.go @@ -5,184 +5,122 @@ package messageflag import ( "entgo.io/ent/dialect/sql" "entgo.io/ent/dialect/sql/sqlgraph" - "github.com/ProtonMail/gluon/internal/db/ent/predicate" + "github.com/ProtonMail/gluon/internal/db_impl/ent_db/internal/predicate" ) // ID filters vertices based on their ID field. func ID(id int) predicate.MessageFlag { - return predicate.MessageFlag(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldID), id)) - }) + return predicate.MessageFlag(sql.FieldEQ(FieldID, id)) } // IDEQ applies the EQ predicate on the ID field. func IDEQ(id int) predicate.MessageFlag { - return predicate.MessageFlag(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldID), id)) - }) + return predicate.MessageFlag(sql.FieldEQ(FieldID, id)) } // IDNEQ applies the NEQ predicate on the ID field. func IDNEQ(id int) predicate.MessageFlag { - return predicate.MessageFlag(func(s *sql.Selector) { - s.Where(sql.NEQ(s.C(FieldID), id)) - }) + return predicate.MessageFlag(sql.FieldNEQ(FieldID, id)) } // IDIn applies the In predicate on the ID field. func IDIn(ids ...int) predicate.MessageFlag { - return predicate.MessageFlag(func(s *sql.Selector) { - v := make([]interface{}, len(ids)) - for i := range v { - v[i] = ids[i] - } - s.Where(sql.In(s.C(FieldID), v...)) - }) + return predicate.MessageFlag(sql.FieldIn(FieldID, ids...)) } // IDNotIn applies the NotIn predicate on the ID field. func IDNotIn(ids ...int) predicate.MessageFlag { - return predicate.MessageFlag(func(s *sql.Selector) { - v := make([]interface{}, len(ids)) - for i := range v { - v[i] = ids[i] - } - s.Where(sql.NotIn(s.C(FieldID), v...)) - }) + return predicate.MessageFlag(sql.FieldNotIn(FieldID, ids...)) } // IDGT applies the GT predicate on the ID field. func IDGT(id int) predicate.MessageFlag { - return predicate.MessageFlag(func(s *sql.Selector) { - s.Where(sql.GT(s.C(FieldID), id)) - }) + return predicate.MessageFlag(sql.FieldGT(FieldID, id)) } // IDGTE applies the GTE predicate on the ID field. func IDGTE(id int) predicate.MessageFlag { - return predicate.MessageFlag(func(s *sql.Selector) { - s.Where(sql.GTE(s.C(FieldID), id)) - }) + return predicate.MessageFlag(sql.FieldGTE(FieldID, id)) } // IDLT applies the LT predicate on the ID field. func IDLT(id int) predicate.MessageFlag { - return predicate.MessageFlag(func(s *sql.Selector) { - s.Where(sql.LT(s.C(FieldID), id)) - }) + return predicate.MessageFlag(sql.FieldLT(FieldID, id)) } // IDLTE applies the LTE predicate on the ID field. func IDLTE(id int) predicate.MessageFlag { - return predicate.MessageFlag(func(s *sql.Selector) { - s.Where(sql.LTE(s.C(FieldID), id)) - }) + return predicate.MessageFlag(sql.FieldLTE(FieldID, id)) } // Value applies equality check predicate on the "Value" field. It's identical to ValueEQ. func Value(v string) predicate.MessageFlag { - return predicate.MessageFlag(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldValue), v)) - }) + return predicate.MessageFlag(sql.FieldEQ(FieldValue, v)) } // ValueEQ applies the EQ predicate on the "Value" field. func ValueEQ(v string) predicate.MessageFlag { - return predicate.MessageFlag(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldValue), v)) - }) + return predicate.MessageFlag(sql.FieldEQ(FieldValue, v)) } // ValueNEQ applies the NEQ predicate on the "Value" field. func ValueNEQ(v string) predicate.MessageFlag { - return predicate.MessageFlag(func(s *sql.Selector) { - s.Where(sql.NEQ(s.C(FieldValue), v)) - }) + return predicate.MessageFlag(sql.FieldNEQ(FieldValue, v)) } // ValueIn applies the In predicate on the "Value" field. func ValueIn(vs ...string) predicate.MessageFlag { - v := make([]interface{}, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.MessageFlag(func(s *sql.Selector) { - s.Where(sql.In(s.C(FieldValue), v...)) - }) + return predicate.MessageFlag(sql.FieldIn(FieldValue, vs...)) } // ValueNotIn applies the NotIn predicate on the "Value" field. func ValueNotIn(vs ...string) predicate.MessageFlag { - v := make([]interface{}, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.MessageFlag(func(s *sql.Selector) { - s.Where(sql.NotIn(s.C(FieldValue), v...)) - }) + return predicate.MessageFlag(sql.FieldNotIn(FieldValue, vs...)) } // ValueGT applies the GT predicate on the "Value" field. func ValueGT(v string) predicate.MessageFlag { - return predicate.MessageFlag(func(s *sql.Selector) { - s.Where(sql.GT(s.C(FieldValue), v)) - }) + return predicate.MessageFlag(sql.FieldGT(FieldValue, v)) } // ValueGTE applies the GTE predicate on the "Value" field. func ValueGTE(v string) predicate.MessageFlag { - return predicate.MessageFlag(func(s *sql.Selector) { - s.Where(sql.GTE(s.C(FieldValue), v)) - }) + return predicate.MessageFlag(sql.FieldGTE(FieldValue, v)) } // ValueLT applies the LT predicate on the "Value" field. func ValueLT(v string) predicate.MessageFlag { - return predicate.MessageFlag(func(s *sql.Selector) { - s.Where(sql.LT(s.C(FieldValue), v)) - }) + return predicate.MessageFlag(sql.FieldLT(FieldValue, v)) } // ValueLTE applies the LTE predicate on the "Value" field. func ValueLTE(v string) predicate.MessageFlag { - return predicate.MessageFlag(func(s *sql.Selector) { - s.Where(sql.LTE(s.C(FieldValue), v)) - }) + return predicate.MessageFlag(sql.FieldLTE(FieldValue, v)) } // ValueContains applies the Contains predicate on the "Value" field. func ValueContains(v string) predicate.MessageFlag { - return predicate.MessageFlag(func(s *sql.Selector) { - s.Where(sql.Contains(s.C(FieldValue), v)) - }) + return predicate.MessageFlag(sql.FieldContains(FieldValue, v)) } // ValueHasPrefix applies the HasPrefix predicate on the "Value" field. func ValueHasPrefix(v string) predicate.MessageFlag { - return predicate.MessageFlag(func(s *sql.Selector) { - s.Where(sql.HasPrefix(s.C(FieldValue), v)) - }) + return predicate.MessageFlag(sql.FieldHasPrefix(FieldValue, v)) } // ValueHasSuffix applies the HasSuffix predicate on the "Value" field. func ValueHasSuffix(v string) predicate.MessageFlag { - return predicate.MessageFlag(func(s *sql.Selector) { - s.Where(sql.HasSuffix(s.C(FieldValue), v)) - }) + return predicate.MessageFlag(sql.FieldHasSuffix(FieldValue, v)) } // ValueEqualFold applies the EqualFold predicate on the "Value" field. func ValueEqualFold(v string) predicate.MessageFlag { - return predicate.MessageFlag(func(s *sql.Selector) { - s.Where(sql.EqualFold(s.C(FieldValue), v)) - }) + return predicate.MessageFlag(sql.FieldEqualFold(FieldValue, v)) } // ValueContainsFold applies the ContainsFold predicate on the "Value" field. func ValueContainsFold(v string) predicate.MessageFlag { - return predicate.MessageFlag(func(s *sql.Selector) { - s.Where(sql.ContainsFold(s.C(FieldValue), v)) - }) + return predicate.MessageFlag(sql.FieldContainsFold(FieldValue, v)) } // HasMessages applies the HasEdge predicate on the "messages" edge. @@ -190,7 +128,6 @@ func HasMessages() predicate.MessageFlag { return predicate.MessageFlag(func(s *sql.Selector) { step := sqlgraph.NewStep( sqlgraph.From(Table, FieldID), - sqlgraph.To(MessagesTable, FieldID), sqlgraph.Edge(sqlgraph.M2O, true, MessagesTable, MessagesColumn), ) sqlgraph.HasNeighbors(s, step) diff --git a/internal/db/ent/messageflag_create.go b/internal/db_impl/ent_db/internal/messageflag_create.go similarity index 78% rename from internal/db/ent/messageflag_create.go rename to internal/db_impl/ent_db/internal/messageflag_create.go index 7b7bec11..c99f9e1a 100644 --- a/internal/db/ent/messageflag_create.go +++ b/internal/db_impl/ent_db/internal/messageflag_create.go @@ -1,6 +1,6 @@ // Code generated by ent, DO NOT EDIT. -package ent +package internal import ( "context" @@ -10,8 +10,8 @@ import ( "entgo.io/ent/dialect/sql/sqlgraph" "entgo.io/ent/schema/field" "github.com/ProtonMail/gluon/imap" - "github.com/ProtonMail/gluon/internal/db/ent/message" - "github.com/ProtonMail/gluon/internal/db/ent/messageflag" + "github.com/ProtonMail/gluon/internal/db_impl/ent_db/internal/message" + "github.com/ProtonMail/gluon/internal/db_impl/ent_db/internal/messageflag" ) // MessageFlagCreate is the builder for creating a MessageFlag entity. @@ -53,49 +53,7 @@ func (mfc *MessageFlagCreate) Mutation() *MessageFlagMutation { // Save creates the MessageFlag in the database. func (mfc *MessageFlagCreate) Save(ctx context.Context) (*MessageFlag, error) { - var ( - err error - node *MessageFlag - ) - if len(mfc.hooks) == 0 { - if err = mfc.check(); err != nil { - return nil, err - } - node, err = mfc.sqlSave(ctx) - } else { - var mut Mutator = MutateFunc(func(ctx context.Context, m Mutation) (Value, error) { - mutation, ok := m.(*MessageFlagMutation) - if !ok { - return nil, fmt.Errorf("unexpected mutation type %T", m) - } - if err = mfc.check(); err != nil { - return nil, err - } - mfc.mutation = mutation - if node, err = mfc.sqlSave(ctx); err != nil { - return nil, err - } - mutation.id = &node.ID - mutation.done = true - return node, err - }) - for i := len(mfc.hooks) - 1; i >= 0; i-- { - if mfc.hooks[i] == nil { - return nil, fmt.Errorf("ent: uninitialized hook (forgotten import ent/runtime?)") - } - mut = mfc.hooks[i](mut) - } - v, err := mut.Mutate(ctx, mfc.mutation) - if err != nil { - return nil, err - } - nv, ok := v.(*MessageFlag) - if !ok { - return nil, fmt.Errorf("unexpected node type %T returned from MessageFlagMutation", v) - } - node = nv - } - return node, err + return withHooks[*MessageFlag, MessageFlagMutation](ctx, mfc.sqlSave, mfc.mutation, mfc.hooks) } // SaveX calls Save and panics if Save returns an error. @@ -123,12 +81,15 @@ func (mfc *MessageFlagCreate) ExecX(ctx context.Context) { // check runs all checks and user-defined validators on the builder. func (mfc *MessageFlagCreate) check() error { if _, ok := mfc.mutation.Value(); !ok { - return &ValidationError{Name: "Value", err: errors.New(`ent: missing required field "MessageFlag.Value"`)} + return &ValidationError{Name: "Value", err: errors.New(`internal: missing required field "MessageFlag.Value"`)} } return nil } func (mfc *MessageFlagCreate) sqlSave(ctx context.Context) (*MessageFlag, error) { + if err := mfc.check(); err != nil { + return nil, err + } _node, _spec := mfc.createSpec() if err := sqlgraph.CreateNode(ctx, mfc.driver, _spec); err != nil { if sqlgraph.IsConstraintError(err) { @@ -138,26 +99,18 @@ func (mfc *MessageFlagCreate) sqlSave(ctx context.Context) (*MessageFlag, error) } id := _spec.ID.Value.(int64) _node.ID = int(id) + mfc.mutation.id = &_node.ID + mfc.mutation.done = true return _node, nil } func (mfc *MessageFlagCreate) createSpec() (*MessageFlag, *sqlgraph.CreateSpec) { var ( _node = &MessageFlag{config: mfc.config} - _spec = &sqlgraph.CreateSpec{ - Table: messageflag.Table, - ID: &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Column: messageflag.FieldID, - }, - } + _spec = sqlgraph.NewCreateSpec(messageflag.Table, sqlgraph.NewFieldSpec(messageflag.FieldID, field.TypeInt)) ) if value, ok := mfc.mutation.Value(); ok { - _spec.Fields = append(_spec.Fields, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Value: value, - Column: messageflag.FieldValue, - }) + _spec.SetField(messageflag.FieldValue, field.TypeString, value) _node.Value = value } if nodes := mfc.mutation.MessagesIDs(); len(nodes) > 0 { diff --git a/internal/db/ent/messageflag_delete.go b/internal/db_impl/ent_db/internal/messageflag_delete.go similarity index 62% rename from internal/db/ent/messageflag_delete.go rename to internal/db_impl/ent_db/internal/messageflag_delete.go index 728d5d8d..f6371b70 100644 --- a/internal/db/ent/messageflag_delete.go +++ b/internal/db_impl/ent_db/internal/messageflag_delete.go @@ -1,16 +1,15 @@ // Code generated by ent, DO NOT EDIT. -package ent +package internal import ( "context" - "fmt" "entgo.io/ent/dialect/sql" "entgo.io/ent/dialect/sql/sqlgraph" "entgo.io/ent/schema/field" - "github.com/ProtonMail/gluon/internal/db/ent/messageflag" - "github.com/ProtonMail/gluon/internal/db/ent/predicate" + "github.com/ProtonMail/gluon/internal/db_impl/ent_db/internal/messageflag" + "github.com/ProtonMail/gluon/internal/db_impl/ent_db/internal/predicate" ) // MessageFlagDelete is the builder for deleting a MessageFlag entity. @@ -28,34 +27,7 @@ func (mfd *MessageFlagDelete) Where(ps ...predicate.MessageFlag) *MessageFlagDel // Exec executes the deletion query and returns how many vertices were deleted. func (mfd *MessageFlagDelete) Exec(ctx context.Context) (int, error) { - var ( - err error - affected int - ) - if len(mfd.hooks) == 0 { - affected, err = mfd.sqlExec(ctx) - } else { - var mut Mutator = MutateFunc(func(ctx context.Context, m Mutation) (Value, error) { - mutation, ok := m.(*MessageFlagMutation) - if !ok { - return nil, fmt.Errorf("unexpected mutation type %T", m) - } - mfd.mutation = mutation - affected, err = mfd.sqlExec(ctx) - mutation.done = true - return affected, err - }) - for i := len(mfd.hooks) - 1; i >= 0; i-- { - if mfd.hooks[i] == nil { - return 0, fmt.Errorf("ent: uninitialized hook (forgotten import ent/runtime?)") - } - mut = mfd.hooks[i](mut) - } - if _, err := mut.Mutate(ctx, mfd.mutation); err != nil { - return 0, err - } - } - return affected, err + return withHooks[int, MessageFlagMutation](ctx, mfd.sqlExec, mfd.mutation, mfd.hooks) } // ExecX is like Exec, but panics if an error occurs. @@ -68,15 +40,7 @@ func (mfd *MessageFlagDelete) ExecX(ctx context.Context) int { } func (mfd *MessageFlagDelete) sqlExec(ctx context.Context) (int, error) { - _spec := &sqlgraph.DeleteSpec{ - Node: &sqlgraph.NodeSpec{ - Table: messageflag.Table, - ID: &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Column: messageflag.FieldID, - }, - }, - } + _spec := sqlgraph.NewDeleteSpec(messageflag.Table, sqlgraph.NewFieldSpec(messageflag.FieldID, field.TypeInt)) if ps := mfd.mutation.predicates; len(ps) > 0 { _spec.Predicate = func(selector *sql.Selector) { for i := range ps { @@ -88,6 +52,7 @@ func (mfd *MessageFlagDelete) sqlExec(ctx context.Context) (int, error) { if err != nil && sqlgraph.IsConstraintError(err) { err = &ConstraintError{msg: err.Error(), wrap: err} } + mfd.mutation.done = true return affected, err } @@ -96,6 +61,12 @@ type MessageFlagDeleteOne struct { mfd *MessageFlagDelete } +// Where appends a list predicates to the MessageFlagDelete builder. +func (mfdo *MessageFlagDeleteOne) Where(ps ...predicate.MessageFlag) *MessageFlagDeleteOne { + mfdo.mfd.mutation.Where(ps...) + return mfdo +} + // Exec executes the deletion query. func (mfdo *MessageFlagDeleteOne) Exec(ctx context.Context) error { n, err := mfdo.mfd.Exec(ctx) @@ -111,5 +82,7 @@ func (mfdo *MessageFlagDeleteOne) Exec(ctx context.Context) error { // ExecX is like Exec, but panics if an error occurs. func (mfdo *MessageFlagDeleteOne) ExecX(ctx context.Context) { - mfdo.mfd.ExecX(ctx) + if err := mfdo.Exec(ctx); err != nil { + panic(err) + } } diff --git a/internal/db/ent/messageflag_query.go b/internal/db_impl/ent_db/internal/messageflag_query.go similarity index 70% rename from internal/db/ent/messageflag_query.go rename to internal/db_impl/ent_db/internal/messageflag_query.go index 39244ca8..5d15eb5d 100644 --- a/internal/db/ent/messageflag_query.go +++ b/internal/db_impl/ent_db/internal/messageflag_query.go @@ -1,6 +1,6 @@ // Code generated by ent, DO NOT EDIT. -package ent +package internal import ( "context" @@ -11,19 +11,17 @@ import ( "entgo.io/ent/dialect/sql/sqlgraph" "entgo.io/ent/schema/field" "github.com/ProtonMail/gluon/imap" - "github.com/ProtonMail/gluon/internal/db/ent/message" - "github.com/ProtonMail/gluon/internal/db/ent/messageflag" - "github.com/ProtonMail/gluon/internal/db/ent/predicate" + "github.com/ProtonMail/gluon/internal/db_impl/ent_db/internal/message" + "github.com/ProtonMail/gluon/internal/db_impl/ent_db/internal/messageflag" + "github.com/ProtonMail/gluon/internal/db_impl/ent_db/internal/predicate" ) // MessageFlagQuery is the builder for querying MessageFlag entities. type MessageFlagQuery struct { config - limit *int - offset *int - unique *bool + ctx *QueryContext order []OrderFunc - fields []string + inters []Interceptor predicates []predicate.MessageFlag withMessages *MessageQuery withFKs bool @@ -38,26 +36,26 @@ func (mfq *MessageFlagQuery) Where(ps ...predicate.MessageFlag) *MessageFlagQuer return mfq } -// Limit adds a limit step to the query. +// Limit the number of records to be returned by this query. func (mfq *MessageFlagQuery) Limit(limit int) *MessageFlagQuery { - mfq.limit = &limit + mfq.ctx.Limit = &limit return mfq } -// Offset adds an offset step to the query. +// Offset to start from. func (mfq *MessageFlagQuery) Offset(offset int) *MessageFlagQuery { - mfq.offset = &offset + mfq.ctx.Offset = &offset return mfq } // Unique configures the query builder to filter duplicate records on query. // By default, unique is set to true, and can be disabled using this method. func (mfq *MessageFlagQuery) Unique(unique bool) *MessageFlagQuery { - mfq.unique = &unique + mfq.ctx.Unique = &unique return mfq } -// Order adds an order step to the query. +// Order specifies how the records should be ordered. func (mfq *MessageFlagQuery) Order(o ...OrderFunc) *MessageFlagQuery { mfq.order = append(mfq.order, o...) return mfq @@ -65,7 +63,7 @@ func (mfq *MessageFlagQuery) Order(o ...OrderFunc) *MessageFlagQuery { // QueryMessages chains the current query on the "messages" edge. func (mfq *MessageFlagQuery) QueryMessages() *MessageQuery { - query := &MessageQuery{config: mfq.config} + query := (&MessageClient{config: mfq.config}).Query() query.path = func(ctx context.Context) (fromU *sql.Selector, err error) { if err := mfq.prepareQuery(ctx); err != nil { return nil, err @@ -88,7 +86,7 @@ func (mfq *MessageFlagQuery) QueryMessages() *MessageQuery { // First returns the first MessageFlag entity from the query. // Returns a *NotFoundError when no MessageFlag was found. func (mfq *MessageFlagQuery) First(ctx context.Context) (*MessageFlag, error) { - nodes, err := mfq.Limit(1).All(ctx) + nodes, err := mfq.Limit(1).All(setContextOp(ctx, mfq.ctx, "First")) if err != nil { return nil, err } @@ -111,7 +109,7 @@ func (mfq *MessageFlagQuery) FirstX(ctx context.Context) *MessageFlag { // Returns a *NotFoundError when no MessageFlag ID was found. func (mfq *MessageFlagQuery) FirstID(ctx context.Context) (id int, err error) { var ids []int - if ids, err = mfq.Limit(1).IDs(ctx); err != nil { + if ids, err = mfq.Limit(1).IDs(setContextOp(ctx, mfq.ctx, "FirstID")); err != nil { return } if len(ids) == 0 { @@ -134,7 +132,7 @@ func (mfq *MessageFlagQuery) FirstIDX(ctx context.Context) int { // Returns a *NotSingularError when more than one MessageFlag entity is found. // Returns a *NotFoundError when no MessageFlag entities are found. func (mfq *MessageFlagQuery) Only(ctx context.Context) (*MessageFlag, error) { - nodes, err := mfq.Limit(2).All(ctx) + nodes, err := mfq.Limit(2).All(setContextOp(ctx, mfq.ctx, "Only")) if err != nil { return nil, err } @@ -162,7 +160,7 @@ func (mfq *MessageFlagQuery) OnlyX(ctx context.Context) *MessageFlag { // Returns a *NotFoundError when no entities are found. func (mfq *MessageFlagQuery) OnlyID(ctx context.Context) (id int, err error) { var ids []int - if ids, err = mfq.Limit(2).IDs(ctx); err != nil { + if ids, err = mfq.Limit(2).IDs(setContextOp(ctx, mfq.ctx, "OnlyID")); err != nil { return } switch len(ids) { @@ -187,10 +185,12 @@ func (mfq *MessageFlagQuery) OnlyIDX(ctx context.Context) int { // All executes the query and returns a list of MessageFlags. func (mfq *MessageFlagQuery) All(ctx context.Context) ([]*MessageFlag, error) { + ctx = setContextOp(ctx, mfq.ctx, "All") if err := mfq.prepareQuery(ctx); err != nil { return nil, err } - return mfq.sqlAll(ctx) + qr := querierAll[[]*MessageFlag, *MessageFlagQuery]() + return withInterceptors[[]*MessageFlag](ctx, mfq, qr, mfq.inters) } // AllX is like All, but panics if an error occurs. @@ -203,9 +203,12 @@ func (mfq *MessageFlagQuery) AllX(ctx context.Context) []*MessageFlag { } // IDs executes the query and returns a list of MessageFlag IDs. -func (mfq *MessageFlagQuery) IDs(ctx context.Context) ([]int, error) { - var ids []int - if err := mfq.Select(messageflag.FieldID).Scan(ctx, &ids); err != nil { +func (mfq *MessageFlagQuery) IDs(ctx context.Context) (ids []int, err error) { + if mfq.ctx.Unique == nil && mfq.path != nil { + mfq.Unique(true) + } + ctx = setContextOp(ctx, mfq.ctx, "IDs") + if err = mfq.Select(messageflag.FieldID).Scan(ctx, &ids); err != nil { return nil, err } return ids, nil @@ -222,10 +225,11 @@ func (mfq *MessageFlagQuery) IDsX(ctx context.Context) []int { // Count returns the count of the given query. func (mfq *MessageFlagQuery) Count(ctx context.Context) (int, error) { + ctx = setContextOp(ctx, mfq.ctx, "Count") if err := mfq.prepareQuery(ctx); err != nil { return 0, err } - return mfq.sqlCount(ctx) + return withInterceptors[int](ctx, mfq, querierCount[*MessageFlagQuery](), mfq.inters) } // CountX is like Count, but panics if an error occurs. @@ -239,10 +243,15 @@ func (mfq *MessageFlagQuery) CountX(ctx context.Context) int { // Exist returns true if the query has elements in the graph. func (mfq *MessageFlagQuery) Exist(ctx context.Context) (bool, error) { - if err := mfq.prepareQuery(ctx); err != nil { - return false, err + ctx = setContextOp(ctx, mfq.ctx, "Exist") + switch _, err := mfq.FirstID(ctx); { + case IsNotFound(err): + return false, nil + case err != nil: + return false, fmt.Errorf("internal: check existence: %w", err) + default: + return true, nil } - return mfq.sqlExist(ctx) } // ExistX is like Exist, but panics if an error occurs. @@ -262,22 +271,21 @@ func (mfq *MessageFlagQuery) Clone() *MessageFlagQuery { } return &MessageFlagQuery{ config: mfq.config, - limit: mfq.limit, - offset: mfq.offset, + ctx: mfq.ctx.Clone(), order: append([]OrderFunc{}, mfq.order...), + inters: append([]Interceptor{}, mfq.inters...), predicates: append([]predicate.MessageFlag{}, mfq.predicates...), withMessages: mfq.withMessages.Clone(), // clone intermediate query. - sql: mfq.sql.Clone(), - path: mfq.path, - unique: mfq.unique, + sql: mfq.sql.Clone(), + path: mfq.path, } } // WithMessages tells the query-builder to eager-load the nodes that are connected to // the "messages" edge. The optional arguments are used to configure the query builder of the edge. func (mfq *MessageFlagQuery) WithMessages(opts ...func(*MessageQuery)) *MessageFlagQuery { - query := &MessageQuery{config: mfq.config} + query := (&MessageClient{config: mfq.config}).Query() for _, opt := range opts { opt(query) } @@ -297,19 +305,14 @@ func (mfq *MessageFlagQuery) WithMessages(opts ...func(*MessageQuery)) *MessageF // // client.MessageFlag.Query(). // GroupBy(messageflag.FieldValue). -// Aggregate(ent.Count()). +// Aggregate(internal.Count()). // Scan(ctx, &v) func (mfq *MessageFlagQuery) GroupBy(field string, fields ...string) *MessageFlagGroupBy { - grbuild := &MessageFlagGroupBy{config: mfq.config} - grbuild.fields = append([]string{field}, fields...) - grbuild.path = func(ctx context.Context) (prev *sql.Selector, err error) { - if err := mfq.prepareQuery(ctx); err != nil { - return nil, err - } - return mfq.sqlQuery(ctx), nil - } + mfq.ctx.Fields = append([]string{field}, fields...) + grbuild := &MessageFlagGroupBy{build: mfq} + grbuild.flds = &mfq.ctx.Fields grbuild.label = messageflag.Label - grbuild.flds, grbuild.scan = &grbuild.fields, grbuild.Scan + grbuild.scan = grbuild.Scan return grbuild } @@ -326,17 +329,32 @@ func (mfq *MessageFlagQuery) GroupBy(field string, fields ...string) *MessageFla // Select(messageflag.FieldValue). // Scan(ctx, &v) func (mfq *MessageFlagQuery) Select(fields ...string) *MessageFlagSelect { - mfq.fields = append(mfq.fields, fields...) - selbuild := &MessageFlagSelect{MessageFlagQuery: mfq} - selbuild.label = messageflag.Label - selbuild.flds, selbuild.scan = &mfq.fields, selbuild.Scan - return selbuild + mfq.ctx.Fields = append(mfq.ctx.Fields, fields...) + sbuild := &MessageFlagSelect{MessageFlagQuery: mfq} + sbuild.label = messageflag.Label + sbuild.flds, sbuild.scan = &mfq.ctx.Fields, sbuild.Scan + return sbuild +} + +// Aggregate returns a MessageFlagSelect configured with the given aggregations. +func (mfq *MessageFlagQuery) Aggregate(fns ...AggregateFunc) *MessageFlagSelect { + return mfq.Select().Aggregate(fns...) } func (mfq *MessageFlagQuery) prepareQuery(ctx context.Context) error { - for _, f := range mfq.fields { + for _, inter := range mfq.inters { + if inter == nil { + return fmt.Errorf("internal: uninitialized interceptor (forgotten import internal/runtime?)") + } + if trv, ok := inter.(Traverser); ok { + if err := trv.Traverse(ctx, mfq); err != nil { + return err + } + } + } + for _, f := range mfq.ctx.Fields { if !messageflag.ValidColumn(f) { - return &ValidationError{Name: f, err: fmt.Errorf("ent: invalid field %q for query", f)} + return &ValidationError{Name: f, err: fmt.Errorf("internal: invalid field %q for query", f)} } } if mfq.path != nil { @@ -364,10 +382,10 @@ func (mfq *MessageFlagQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([] if withFKs { _spec.Node.Columns = append(_spec.Node.Columns, messageflag.ForeignKeys...) } - _spec.ScanValues = func(columns []string) ([]interface{}, error) { + _spec.ScanValues = func(columns []string) ([]any, error) { return (*MessageFlag).scanValues(nil, columns) } - _spec.Assign = func(columns []string, values []interface{}) error { + _spec.Assign = func(columns []string, values []any) error { node := &MessageFlag{config: mfq.config} nodes = append(nodes, node) node.Edges.loadedTypes = loadedTypes @@ -404,6 +422,9 @@ func (mfq *MessageFlagQuery) loadMessages(ctx context.Context, query *MessageQue } nodeids[fk] = append(nodeids[fk], nodes[i]) } + if len(ids) == 0 { + return nil + } query.Where(message.IDIn(ids...)) neighbors, err := query.All(ctx) if err != nil { @@ -423,38 +444,22 @@ func (mfq *MessageFlagQuery) loadMessages(ctx context.Context, query *MessageQue func (mfq *MessageFlagQuery) sqlCount(ctx context.Context) (int, error) { _spec := mfq.querySpec() - _spec.Node.Columns = mfq.fields - if len(mfq.fields) > 0 { - _spec.Unique = mfq.unique != nil && *mfq.unique + _spec.Node.Columns = mfq.ctx.Fields + if len(mfq.ctx.Fields) > 0 { + _spec.Unique = mfq.ctx.Unique != nil && *mfq.ctx.Unique } return sqlgraph.CountNodes(ctx, mfq.driver, _spec) } -func (mfq *MessageFlagQuery) sqlExist(ctx context.Context) (bool, error) { - n, err := mfq.sqlCount(ctx) - if err != nil { - return false, fmt.Errorf("ent: check existence: %w", err) - } - return n > 0, nil -} - func (mfq *MessageFlagQuery) querySpec() *sqlgraph.QuerySpec { - _spec := &sqlgraph.QuerySpec{ - Node: &sqlgraph.NodeSpec{ - Table: messageflag.Table, - Columns: messageflag.Columns, - ID: &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Column: messageflag.FieldID, - }, - }, - From: mfq.sql, - Unique: true, - } - if unique := mfq.unique; unique != nil { + _spec := sqlgraph.NewQuerySpec(messageflag.Table, messageflag.Columns, sqlgraph.NewFieldSpec(messageflag.FieldID, field.TypeInt)) + _spec.From = mfq.sql + if unique := mfq.ctx.Unique; unique != nil { _spec.Unique = *unique + } else if mfq.path != nil { + _spec.Unique = true } - if fields := mfq.fields; len(fields) > 0 { + if fields := mfq.ctx.Fields; len(fields) > 0 { _spec.Node.Columns = make([]string, 0, len(fields)) _spec.Node.Columns = append(_spec.Node.Columns, messageflag.FieldID) for i := range fields { @@ -470,10 +475,10 @@ func (mfq *MessageFlagQuery) querySpec() *sqlgraph.QuerySpec { } } } - if limit := mfq.limit; limit != nil { + if limit := mfq.ctx.Limit; limit != nil { _spec.Limit = *limit } - if offset := mfq.offset; offset != nil { + if offset := mfq.ctx.Offset; offset != nil { _spec.Offset = *offset } if ps := mfq.order; len(ps) > 0 { @@ -489,7 +494,7 @@ func (mfq *MessageFlagQuery) querySpec() *sqlgraph.QuerySpec { func (mfq *MessageFlagQuery) sqlQuery(ctx context.Context) *sql.Selector { builder := sql.Dialect(mfq.driver.Dialect()) t1 := builder.Table(messageflag.Table) - columns := mfq.fields + columns := mfq.ctx.Fields if len(columns) == 0 { columns = messageflag.Columns } @@ -498,7 +503,7 @@ func (mfq *MessageFlagQuery) sqlQuery(ctx context.Context) *sql.Selector { selector = mfq.sql selector.Select(selector.Columns(columns...)...) } - if mfq.unique != nil && *mfq.unique { + if mfq.ctx.Unique != nil && *mfq.ctx.Unique { selector.Distinct() } for _, p := range mfq.predicates { @@ -507,12 +512,12 @@ func (mfq *MessageFlagQuery) sqlQuery(ctx context.Context) *sql.Selector { for _, p := range mfq.order { p(selector) } - if offset := mfq.offset; offset != nil { + if offset := mfq.ctx.Offset; offset != nil { // limit is mandatory for offset clause. We start // with default value, and override it below if needed. selector.Offset(*offset).Limit(math.MaxInt32) } - if limit := mfq.limit; limit != nil { + if limit := mfq.ctx.Limit; limit != nil { selector.Limit(*limit) } return selector @@ -520,13 +525,8 @@ func (mfq *MessageFlagQuery) sqlQuery(ctx context.Context) *sql.Selector { // MessageFlagGroupBy is the group-by builder for MessageFlag entities. type MessageFlagGroupBy struct { - config selector - fields []string - fns []AggregateFunc - // intermediate query (i.e. traversal path). - sql *sql.Selector - path func(context.Context) (*sql.Selector, error) + build *MessageFlagQuery } // Aggregate adds the given aggregation functions to the group-by query. @@ -535,74 +535,77 @@ func (mfgb *MessageFlagGroupBy) Aggregate(fns ...AggregateFunc) *MessageFlagGrou return mfgb } -// Scan applies the group-by query and scans the result into the given value. -func (mfgb *MessageFlagGroupBy) Scan(ctx context.Context, v interface{}) error { - query, err := mfgb.path(ctx) - if err != nil { +// Scan applies the selector query and scans the result into the given value. +func (mfgb *MessageFlagGroupBy) Scan(ctx context.Context, v any) error { + ctx = setContextOp(ctx, mfgb.build.ctx, "GroupBy") + if err := mfgb.build.prepareQuery(ctx); err != nil { return err } - mfgb.sql = query - return mfgb.sqlScan(ctx, v) + return scanWithInterceptors[*MessageFlagQuery, *MessageFlagGroupBy](ctx, mfgb.build, mfgb, mfgb.build.inters, v) } -func (mfgb *MessageFlagGroupBy) sqlScan(ctx context.Context, v interface{}) error { - for _, f := range mfgb.fields { - if !messageflag.ValidColumn(f) { - return &ValidationError{Name: f, err: fmt.Errorf("invalid field %q for group-by", f)} +func (mfgb *MessageFlagGroupBy) sqlScan(ctx context.Context, root *MessageFlagQuery, v any) error { + selector := root.sqlQuery(ctx).Select() + aggregation := make([]string, 0, len(mfgb.fns)) + for _, fn := range mfgb.fns { + aggregation = append(aggregation, fn(selector)) + } + if len(selector.SelectedColumns()) == 0 { + columns := make([]string, 0, len(*mfgb.flds)+len(mfgb.fns)) + for _, f := range *mfgb.flds { + columns = append(columns, selector.C(f)) } + columns = append(columns, aggregation...) + selector.Select(columns...) } - selector := mfgb.sqlQuery() + selector.GroupBy(selector.Columns(*mfgb.flds...)...) if err := selector.Err(); err != nil { return err } rows := &sql.Rows{} query, args := selector.Query() - if err := mfgb.driver.Query(ctx, query, args, rows); err != nil { + if err := mfgb.build.driver.Query(ctx, query, args, rows); err != nil { return err } defer rows.Close() return sql.ScanSlice(rows, v) } -func (mfgb *MessageFlagGroupBy) sqlQuery() *sql.Selector { - selector := mfgb.sql.Select() - aggregation := make([]string, 0, len(mfgb.fns)) - for _, fn := range mfgb.fns { - aggregation = append(aggregation, fn(selector)) - } - // If no columns were selected in a custom aggregation function, the default - // selection is the fields used for "group-by", and the aggregation functions. - if len(selector.SelectedColumns()) == 0 { - columns := make([]string, 0, len(mfgb.fields)+len(mfgb.fns)) - for _, f := range mfgb.fields { - columns = append(columns, selector.C(f)) - } - columns = append(columns, aggregation...) - selector.Select(columns...) - } - return selector.GroupBy(selector.Columns(mfgb.fields...)...) -} - // MessageFlagSelect is the builder for selecting fields of MessageFlag entities. type MessageFlagSelect struct { *MessageFlagQuery selector - // intermediate query (i.e. traversal path). - sql *sql.Selector +} + +// Aggregate adds the given aggregation functions to the selector query. +func (mfs *MessageFlagSelect) Aggregate(fns ...AggregateFunc) *MessageFlagSelect { + mfs.fns = append(mfs.fns, fns...) + return mfs } // Scan applies the selector query and scans the result into the given value. -func (mfs *MessageFlagSelect) Scan(ctx context.Context, v interface{}) error { +func (mfs *MessageFlagSelect) Scan(ctx context.Context, v any) error { + ctx = setContextOp(ctx, mfs.ctx, "Select") if err := mfs.prepareQuery(ctx); err != nil { return err } - mfs.sql = mfs.MessageFlagQuery.sqlQuery(ctx) - return mfs.sqlScan(ctx, v) + return scanWithInterceptors[*MessageFlagQuery, *MessageFlagSelect](ctx, mfs.MessageFlagQuery, mfs, mfs.inters, v) } -func (mfs *MessageFlagSelect) sqlScan(ctx context.Context, v interface{}) error { +func (mfs *MessageFlagSelect) sqlScan(ctx context.Context, root *MessageFlagQuery, v any) error { + selector := root.sqlQuery(ctx) + aggregation := make([]string, 0, len(mfs.fns)) + for _, fn := range mfs.fns { + aggregation = append(aggregation, fn(selector)) + } + switch n := len(*mfs.selector.flds); { + case n == 0 && len(aggregation) > 0: + selector.Select(aggregation...) + case n != 0 && len(aggregation) > 0: + selector.AppendSelect(aggregation...) + } rows := &sql.Rows{} - query, args := mfs.sql.Query() + query, args := selector.Query() if err := mfs.driver.Query(ctx, query, args, rows); err != nil { return err } diff --git a/internal/db/ent/messageflag_update.go b/internal/db_impl/ent_db/internal/messageflag_update.go similarity index 75% rename from internal/db/ent/messageflag_update.go rename to internal/db_impl/ent_db/internal/messageflag_update.go index 923a9270..0c49ae62 100644 --- a/internal/db/ent/messageflag_update.go +++ b/internal/db_impl/ent_db/internal/messageflag_update.go @@ -1,6 +1,6 @@ // Code generated by ent, DO NOT EDIT. -package ent +package internal import ( "context" @@ -11,9 +11,9 @@ import ( "entgo.io/ent/dialect/sql/sqlgraph" "entgo.io/ent/schema/field" "github.com/ProtonMail/gluon/imap" - "github.com/ProtonMail/gluon/internal/db/ent/message" - "github.com/ProtonMail/gluon/internal/db/ent/messageflag" - "github.com/ProtonMail/gluon/internal/db/ent/predicate" + "github.com/ProtonMail/gluon/internal/db_impl/ent_db/internal/message" + "github.com/ProtonMail/gluon/internal/db_impl/ent_db/internal/messageflag" + "github.com/ProtonMail/gluon/internal/db_impl/ent_db/internal/predicate" ) // MessageFlagUpdate is the builder for updating MessageFlag entities. @@ -67,34 +67,7 @@ func (mfu *MessageFlagUpdate) ClearMessages() *MessageFlagUpdate { // Save executes the query and returns the number of nodes affected by the update operation. func (mfu *MessageFlagUpdate) Save(ctx context.Context) (int, error) { - var ( - err error - affected int - ) - if len(mfu.hooks) == 0 { - affected, err = mfu.sqlSave(ctx) - } else { - var mut Mutator = MutateFunc(func(ctx context.Context, m Mutation) (Value, error) { - mutation, ok := m.(*MessageFlagMutation) - if !ok { - return nil, fmt.Errorf("unexpected mutation type %T", m) - } - mfu.mutation = mutation - affected, err = mfu.sqlSave(ctx) - mutation.done = true - return affected, err - }) - for i := len(mfu.hooks) - 1; i >= 0; i-- { - if mfu.hooks[i] == nil { - return 0, fmt.Errorf("ent: uninitialized hook (forgotten import ent/runtime?)") - } - mut = mfu.hooks[i](mut) - } - if _, err := mut.Mutate(ctx, mfu.mutation); err != nil { - return 0, err - } - } - return affected, err + return withHooks[int, MessageFlagMutation](ctx, mfu.sqlSave, mfu.mutation, mfu.hooks) } // SaveX is like Save, but panics if an error occurs. @@ -120,16 +93,7 @@ func (mfu *MessageFlagUpdate) ExecX(ctx context.Context) { } func (mfu *MessageFlagUpdate) sqlSave(ctx context.Context) (n int, err error) { - _spec := &sqlgraph.UpdateSpec{ - Node: &sqlgraph.NodeSpec{ - Table: messageflag.Table, - Columns: messageflag.Columns, - ID: &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Column: messageflag.FieldID, - }, - }, - } + _spec := sqlgraph.NewUpdateSpec(messageflag.Table, messageflag.Columns, sqlgraph.NewFieldSpec(messageflag.FieldID, field.TypeInt)) if ps := mfu.mutation.predicates; len(ps) > 0 { _spec.Predicate = func(selector *sql.Selector) { for i := range ps { @@ -138,11 +102,7 @@ func (mfu *MessageFlagUpdate) sqlSave(ctx context.Context) (n int, err error) { } } if value, ok := mfu.mutation.Value(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Value: value, - Column: messageflag.FieldValue, - }) + _spec.SetField(messageflag.FieldValue, field.TypeString, value) } if mfu.mutation.MessagesCleared() { edge := &sqlgraph.EdgeSpec{ @@ -187,6 +147,7 @@ func (mfu *MessageFlagUpdate) sqlSave(ctx context.Context) (n int, err error) { } return 0, err } + mfu.mutation.done = true return n, nil } @@ -234,6 +195,12 @@ func (mfuo *MessageFlagUpdateOne) ClearMessages() *MessageFlagUpdateOne { return mfuo } +// Where appends a list predicates to the MessageFlagUpdate builder. +func (mfuo *MessageFlagUpdateOne) Where(ps ...predicate.MessageFlag) *MessageFlagUpdateOne { + mfuo.mutation.Where(ps...) + return mfuo +} + // Select allows selecting one or more fields (columns) of the returned entity. // The default is selecting all fields defined in the entity schema. func (mfuo *MessageFlagUpdateOne) Select(field string, fields ...string) *MessageFlagUpdateOne { @@ -243,40 +210,7 @@ func (mfuo *MessageFlagUpdateOne) Select(field string, fields ...string) *Messag // Save executes the query and returns the updated MessageFlag entity. func (mfuo *MessageFlagUpdateOne) Save(ctx context.Context) (*MessageFlag, error) { - var ( - err error - node *MessageFlag - ) - if len(mfuo.hooks) == 0 { - node, err = mfuo.sqlSave(ctx) - } else { - var mut Mutator = MutateFunc(func(ctx context.Context, m Mutation) (Value, error) { - mutation, ok := m.(*MessageFlagMutation) - if !ok { - return nil, fmt.Errorf("unexpected mutation type %T", m) - } - mfuo.mutation = mutation - node, err = mfuo.sqlSave(ctx) - mutation.done = true - return node, err - }) - for i := len(mfuo.hooks) - 1; i >= 0; i-- { - if mfuo.hooks[i] == nil { - return nil, fmt.Errorf("ent: uninitialized hook (forgotten import ent/runtime?)") - } - mut = mfuo.hooks[i](mut) - } - v, err := mut.Mutate(ctx, mfuo.mutation) - if err != nil { - return nil, err - } - nv, ok := v.(*MessageFlag) - if !ok { - return nil, fmt.Errorf("unexpected node type %T returned from MessageFlagMutation", v) - } - node = nv - } - return node, err + return withHooks[*MessageFlag, MessageFlagMutation](ctx, mfuo.sqlSave, mfuo.mutation, mfuo.hooks) } // SaveX is like Save, but panics if an error occurs. @@ -302,19 +236,10 @@ func (mfuo *MessageFlagUpdateOne) ExecX(ctx context.Context) { } func (mfuo *MessageFlagUpdateOne) sqlSave(ctx context.Context) (_node *MessageFlag, err error) { - _spec := &sqlgraph.UpdateSpec{ - Node: &sqlgraph.NodeSpec{ - Table: messageflag.Table, - Columns: messageflag.Columns, - ID: &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Column: messageflag.FieldID, - }, - }, - } + _spec := sqlgraph.NewUpdateSpec(messageflag.Table, messageflag.Columns, sqlgraph.NewFieldSpec(messageflag.FieldID, field.TypeInt)) id, ok := mfuo.mutation.ID() if !ok { - return nil, &ValidationError{Name: "id", err: errors.New(`ent: missing "MessageFlag.id" for update`)} + return nil, &ValidationError{Name: "id", err: errors.New(`internal: missing "MessageFlag.id" for update`)} } _spec.Node.ID.Value = id if fields := mfuo.fields; len(fields) > 0 { @@ -322,7 +247,7 @@ func (mfuo *MessageFlagUpdateOne) sqlSave(ctx context.Context) (_node *MessageFl _spec.Node.Columns = append(_spec.Node.Columns, messageflag.FieldID) for _, f := range fields { if !messageflag.ValidColumn(f) { - return nil, &ValidationError{Name: f, err: fmt.Errorf("ent: invalid field %q for query", f)} + return nil, &ValidationError{Name: f, err: fmt.Errorf("internal: invalid field %q for query", f)} } if f != messageflag.FieldID { _spec.Node.Columns = append(_spec.Node.Columns, f) @@ -337,11 +262,7 @@ func (mfuo *MessageFlagUpdateOne) sqlSave(ctx context.Context) (_node *MessageFl } } if value, ok := mfuo.mutation.Value(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Value: value, - Column: messageflag.FieldValue, - }) + _spec.SetField(messageflag.FieldValue, field.TypeString, value) } if mfuo.mutation.MessagesCleared() { edge := &sqlgraph.EdgeSpec{ @@ -389,5 +310,6 @@ func (mfuo *MessageFlagUpdateOne) sqlSave(ctx context.Context) (_node *MessageFl } return nil, err } + mfuo.mutation.done = true return _node, nil } diff --git a/internal/db/ent/migrate/migrate.go b/internal/db_impl/ent_db/internal/migrate/migrate.go similarity index 100% rename from internal/db/ent/migrate/migrate.go rename to internal/db_impl/ent_db/internal/migrate/migrate.go diff --git a/internal/db/ent/migrate/schema.go b/internal/db_impl/ent_db/internal/migrate/schema.go similarity index 100% rename from internal/db/ent/migrate/schema.go rename to internal/db_impl/ent_db/internal/migrate/schema.go diff --git a/internal/db/ent/mutation.go b/internal/db_impl/ent_db/internal/mutation.go similarity index 96% rename from internal/db/ent/mutation.go rename to internal/db_impl/ent_db/internal/mutation.go index 8bdc9d5f..d2faa1dd 100644 --- a/internal/db/ent/mutation.go +++ b/internal/db_impl/ent_db/internal/mutation.go @@ -1,6 +1,6 @@ // Code generated by ent, DO NOT EDIT. -package ent +package internal import ( "context" @@ -10,17 +10,18 @@ import ( "time" "github.com/ProtonMail/gluon/imap" - "github.com/ProtonMail/gluon/internal/db/ent/deletedsubscription" - "github.com/ProtonMail/gluon/internal/db/ent/mailbox" - "github.com/ProtonMail/gluon/internal/db/ent/mailboxattr" - "github.com/ProtonMail/gluon/internal/db/ent/mailboxflag" - "github.com/ProtonMail/gluon/internal/db/ent/mailboxpermflag" - "github.com/ProtonMail/gluon/internal/db/ent/message" - "github.com/ProtonMail/gluon/internal/db/ent/messageflag" - "github.com/ProtonMail/gluon/internal/db/ent/predicate" - "github.com/ProtonMail/gluon/internal/db/ent/uid" + "github.com/ProtonMail/gluon/internal/db_impl/ent_db/internal/deletedsubscription" + "github.com/ProtonMail/gluon/internal/db_impl/ent_db/internal/mailbox" + "github.com/ProtonMail/gluon/internal/db_impl/ent_db/internal/mailboxattr" + "github.com/ProtonMail/gluon/internal/db_impl/ent_db/internal/mailboxflag" + "github.com/ProtonMail/gluon/internal/db_impl/ent_db/internal/mailboxpermflag" + "github.com/ProtonMail/gluon/internal/db_impl/ent_db/internal/message" + "github.com/ProtonMail/gluon/internal/db_impl/ent_db/internal/messageflag" + "github.com/ProtonMail/gluon/internal/db_impl/ent_db/internal/predicate" + "github.com/ProtonMail/gluon/internal/db_impl/ent_db/internal/uid" "entgo.io/ent" + "entgo.io/ent/dialect/sql" ) const ( @@ -119,7 +120,7 @@ func (m DeletedSubscriptionMutation) Client() *Client { // it returns an error otherwise. func (m DeletedSubscriptionMutation) Tx() (*Tx, error) { if _, ok := m.driver.(*txDriver); !ok { - return nil, errors.New("ent: mutation is not running in a transaction") + return nil, errors.New("internal: mutation is not running in a transaction") } tx := &Tx{config: m.config} tx.init() @@ -231,11 +232,26 @@ func (m *DeletedSubscriptionMutation) Where(ps ...predicate.DeletedSubscription) m.predicates = append(m.predicates, ps...) } +// WhereP appends storage-level predicates to the DeletedSubscriptionMutation builder. Using this method, +// users can use type-assertion to append predicates that do not depend on any generated package. +func (m *DeletedSubscriptionMutation) WhereP(ps ...func(*sql.Selector)) { + p := make([]predicate.DeletedSubscription, len(ps)) + for i := range ps { + p[i] = ps[i] + } + m.Where(p...) +} + // Op returns the operation name. func (m *DeletedSubscriptionMutation) Op() Op { return m.op } +// SetOp allows setting the mutation operation. +func (m *DeletedSubscriptionMutation) SetOp(op Op) { + m.op = op +} + // Type returns the node type of this mutation (DeletedSubscription). func (m *DeletedSubscriptionMutation) Type() string { return m.typ @@ -501,7 +517,7 @@ func (m MailboxMutation) Client() *Client { // it returns an error otherwise. func (m MailboxMutation) Tx() (*Tx, error) { if _, ok := m.driver.(*txDriver); !ok { - return nil, errors.New("ent: mutation is not running in a transaction") + return nil, errors.New("internal: mutation is not running in a transaction") } tx := &Tx{config: m.config} tx.init() @@ -996,11 +1012,26 @@ func (m *MailboxMutation) Where(ps ...predicate.Mailbox) { m.predicates = append(m.predicates, ps...) } +// WhereP appends storage-level predicates to the MailboxMutation builder. Using this method, +// users can use type-assertion to append predicates that do not depend on any generated package. +func (m *MailboxMutation) WhereP(ps ...func(*sql.Selector)) { + p := make([]predicate.Mailbox, len(ps)) + for i := range ps { + p[i] = ps[i] + } + m.Where(p...) +} + // Op returns the operation name. func (m *MailboxMutation) Op() Op { return m.op } +// SetOp allows setting the mutation operation. +func (m *MailboxMutation) SetOp(op Op) { + m.op = op +} + // Type returns the node type of this mutation (Mailbox). func (m *MailboxMutation) Type() string { return m.typ @@ -1449,7 +1480,7 @@ func (m MailboxAttrMutation) Client() *Client { // it returns an error otherwise. func (m MailboxAttrMutation) Tx() (*Tx, error) { if _, ok := m.driver.(*txDriver); !ok { - return nil, errors.New("ent: mutation is not running in a transaction") + return nil, errors.New("internal: mutation is not running in a transaction") } tx := &Tx{config: m.config} tx.init() @@ -1525,11 +1556,26 @@ func (m *MailboxAttrMutation) Where(ps ...predicate.MailboxAttr) { m.predicates = append(m.predicates, ps...) } +// WhereP appends storage-level predicates to the MailboxAttrMutation builder. Using this method, +// users can use type-assertion to append predicates that do not depend on any generated package. +func (m *MailboxAttrMutation) WhereP(ps ...func(*sql.Selector)) { + p := make([]predicate.MailboxAttr, len(ps)) + for i := range ps { + p[i] = ps[i] + } + m.Where(p...) +} + // Op returns the operation name. func (m *MailboxAttrMutation) Op() Op { return m.op } +// SetOp allows setting the mutation operation. +func (m *MailboxAttrMutation) SetOp(op Op) { + m.op = op +} + // Type returns the node type of this mutation (MailboxAttr). func (m *MailboxAttrMutation) Type() string { return m.typ @@ -1760,7 +1806,7 @@ func (m MailboxFlagMutation) Client() *Client { // it returns an error otherwise. func (m MailboxFlagMutation) Tx() (*Tx, error) { if _, ok := m.driver.(*txDriver); !ok { - return nil, errors.New("ent: mutation is not running in a transaction") + return nil, errors.New("internal: mutation is not running in a transaction") } tx := &Tx{config: m.config} tx.init() @@ -1836,11 +1882,26 @@ func (m *MailboxFlagMutation) Where(ps ...predicate.MailboxFlag) { m.predicates = append(m.predicates, ps...) } +// WhereP appends storage-level predicates to the MailboxFlagMutation builder. Using this method, +// users can use type-assertion to append predicates that do not depend on any generated package. +func (m *MailboxFlagMutation) WhereP(ps ...func(*sql.Selector)) { + p := make([]predicate.MailboxFlag, len(ps)) + for i := range ps { + p[i] = ps[i] + } + m.Where(p...) +} + // Op returns the operation name. func (m *MailboxFlagMutation) Op() Op { return m.op } +// SetOp allows setting the mutation operation. +func (m *MailboxFlagMutation) SetOp(op Op) { + m.op = op +} + // Type returns the node type of this mutation (MailboxFlag). func (m *MailboxFlagMutation) Type() string { return m.typ @@ -2071,7 +2132,7 @@ func (m MailboxPermFlagMutation) Client() *Client { // it returns an error otherwise. func (m MailboxPermFlagMutation) Tx() (*Tx, error) { if _, ok := m.driver.(*txDriver); !ok { - return nil, errors.New("ent: mutation is not running in a transaction") + return nil, errors.New("internal: mutation is not running in a transaction") } tx := &Tx{config: m.config} tx.init() @@ -2147,11 +2208,26 @@ func (m *MailboxPermFlagMutation) Where(ps ...predicate.MailboxPermFlag) { m.predicates = append(m.predicates, ps...) } +// WhereP appends storage-level predicates to the MailboxPermFlagMutation builder. Using this method, +// users can use type-assertion to append predicates that do not depend on any generated package. +func (m *MailboxPermFlagMutation) WhereP(ps ...func(*sql.Selector)) { + p := make([]predicate.MailboxPermFlag, len(ps)) + for i := range ps { + p[i] = ps[i] + } + m.Where(p...) +} + // Op returns the operation name. func (m *MailboxPermFlagMutation) Op() Op { return m.op } +// SetOp allows setting the mutation operation. +func (m *MailboxPermFlagMutation) SetOp(op Op) { + m.op = op +} + // Type returns the node type of this mutation (MailboxPermFlag). func (m *MailboxPermFlagMutation) Type() string { return m.typ @@ -2395,7 +2471,7 @@ func (m MessageMutation) Client() *Client { // it returns an error otherwise. func (m MessageMutation) Tx() (*Tx, error) { if _, ok := m.driver.(*txDriver); !ok { - return nil, errors.New("ent: mutation is not running in a transaction") + return nil, errors.New("internal: mutation is not running in a transaction") } tx := &Tx{config: m.config} tx.init() @@ -2834,11 +2910,26 @@ func (m *MessageMutation) Where(ps ...predicate.Message) { m.predicates = append(m.predicates, ps...) } +// WhereP appends storage-level predicates to the MessageMutation builder. Using this method, +// users can use type-assertion to append predicates that do not depend on any generated package. +func (m *MessageMutation) WhereP(ps ...func(*sql.Selector)) { + p := make([]predicate.Message, len(ps)) + for i := range ps { + p[i] = ps[i] + } + m.Where(p...) +} + // Op returns the operation name. func (m *MessageMutation) Op() Op { return m.op } +// SetOp allows setting the mutation operation. +func (m *MessageMutation) SetOp(op Op) { + m.op = op +} + // Type returns the node type of this mutation (Message). func (m *MessageMutation) Type() string { return m.typ @@ -3259,7 +3350,7 @@ func (m MessageFlagMutation) Client() *Client { // it returns an error otherwise. func (m MessageFlagMutation) Tx() (*Tx, error) { if _, ok := m.driver.(*txDriver); !ok { - return nil, errors.New("ent: mutation is not running in a transaction") + return nil, errors.New("internal: mutation is not running in a transaction") } tx := &Tx{config: m.config} tx.init() @@ -3374,11 +3465,26 @@ func (m *MessageFlagMutation) Where(ps ...predicate.MessageFlag) { m.predicates = append(m.predicates, ps...) } +// WhereP appends storage-level predicates to the MessageFlagMutation builder. Using this method, +// users can use type-assertion to append predicates that do not depend on any generated package. +func (m *MessageFlagMutation) WhereP(ps ...func(*sql.Selector)) { + p := make([]predicate.MessageFlag, len(ps)) + for i := range ps { + p[i] = ps[i] + } + m.Where(p...) +} + // Op returns the operation name. func (m *MessageFlagMutation) Op() Op { return m.op } +// SetOp allows setting the mutation operation. +func (m *MessageFlagMutation) SetOp(op Op) { + m.op = op +} + // Type returns the node type of this mutation (MessageFlag). func (m *MessageFlagMutation) Type() string { return m.typ @@ -3515,8 +3621,6 @@ func (m *MessageFlagMutation) RemovedEdges() []string { // RemovedIDs returns all IDs (to other nodes) that were removed for the edge with // the given name in this mutation. func (m *MessageFlagMutation) RemovedIDs(name string) []ent.Value { - switch name { - } return nil } @@ -3644,7 +3748,7 @@ func (m UIDMutation) Client() *Client { // it returns an error otherwise. func (m UIDMutation) Tx() (*Tx, error) { if _, ok := m.driver.(*txDriver); !ok { - return nil, errors.New("ent: mutation is not running in a transaction") + return nil, errors.New("internal: mutation is not running in a transaction") } tx := &Tx{config: m.config} tx.init() @@ -3890,11 +3994,26 @@ func (m *UIDMutation) Where(ps ...predicate.UID) { m.predicates = append(m.predicates, ps...) } +// WhereP appends storage-level predicates to the UIDMutation builder. Using this method, +// users can use type-assertion to append predicates that do not depend on any generated package. +func (m *UIDMutation) WhereP(ps ...func(*sql.Selector)) { + p := make([]predicate.UID, len(ps)) + for i := range ps { + p[i] = ps[i] + } + m.Where(p...) +} + // Op returns the operation name. func (m *UIDMutation) Op() Op { return m.op } +// SetOp allows setting the mutation operation. +func (m *UIDMutation) SetOp(op Op) { + m.op = op +} + // Type returns the node type of this mutation (UID). func (m *UIDMutation) Type() string { return m.typ @@ -4087,8 +4206,6 @@ func (m *UIDMutation) RemovedEdges() []string { // RemovedIDs returns all IDs (to other nodes) that were removed for the edge with // the given name in this mutation. func (m *UIDMutation) RemovedIDs(name string) []ent.Value { - switch name { - } return nil } diff --git a/internal/db/ent/predicate/predicate.go b/internal/db_impl/ent_db/internal/predicate/predicate.go similarity index 100% rename from internal/db/ent/predicate/predicate.go rename to internal/db_impl/ent_db/internal/predicate/predicate.go diff --git a/internal/db/ent/runtime.go b/internal/db_impl/ent_db/internal/runtime.go similarity index 87% rename from internal/db/ent/runtime.go rename to internal/db_impl/ent_db/internal/runtime.go index a36a489b..7c8fa47b 100644 --- a/internal/db/ent/runtime.go +++ b/internal/db_impl/ent_db/internal/runtime.go @@ -1,13 +1,13 @@ // Code generated by ent, DO NOT EDIT. -package ent +package internal import ( "github.com/ProtonMail/gluon/imap" - "github.com/ProtonMail/gluon/internal/db/ent/mailbox" - "github.com/ProtonMail/gluon/internal/db/ent/message" - "github.com/ProtonMail/gluon/internal/db/ent/schema" - "github.com/ProtonMail/gluon/internal/db/ent/uid" + "github.com/ProtonMail/gluon/internal/db_impl/ent_db/internal/mailbox" + "github.com/ProtonMail/gluon/internal/db_impl/ent_db/internal/message" + "github.com/ProtonMail/gluon/internal/db_impl/ent_db/internal/schema" + "github.com/ProtonMail/gluon/internal/db_impl/ent_db/internal/uid" ) // The init function reads all schema descriptors with runtime code diff --git a/internal/db_impl/ent_db/internal/runtime/runtime.go b/internal/db_impl/ent_db/internal/runtime/runtime.go new file mode 100644 index 00000000..5bac1906 --- /dev/null +++ b/internal/db_impl/ent_db/internal/runtime/runtime.go @@ -0,0 +1,10 @@ +// Code generated by ent, DO NOT EDIT. + +package runtime + +// The schema-stitching logic is generated in github.com/ProtonMail/gluon/internal/db_impl/ent_db/internal/runtime.go + +const ( + Version = "v0.11.8" // Version of ent codegen. + Sum = "h1:M/M0QL1CYCUSdqGRXUrXhFYSDRJPsOOrr+RLEej/gyQ=" // Sum of ent codegen. +) diff --git a/internal/db/ent/schema/mailbox.go b/internal/db_impl/ent_db/internal/schema/mailbox.go similarity index 100% rename from internal/db/ent/schema/mailbox.go rename to internal/db_impl/ent_db/internal/schema/mailbox.go diff --git a/internal/db/ent/schema/mailboxattr.go b/internal/db_impl/ent_db/internal/schema/mailboxattr.go similarity index 100% rename from internal/db/ent/schema/mailboxattr.go rename to internal/db_impl/ent_db/internal/schema/mailboxattr.go diff --git a/internal/db/ent/schema/mailboxflag.go b/internal/db_impl/ent_db/internal/schema/mailboxflag.go similarity index 100% rename from internal/db/ent/schema/mailboxflag.go rename to internal/db_impl/ent_db/internal/schema/mailboxflag.go diff --git a/internal/db/ent/schema/mailboxpermflag.go b/internal/db_impl/ent_db/internal/schema/mailboxpermflag.go similarity index 100% rename from internal/db/ent/schema/mailboxpermflag.go rename to internal/db_impl/ent_db/internal/schema/mailboxpermflag.go diff --git a/internal/db/ent/schema/message.go b/internal/db_impl/ent_db/internal/schema/message.go similarity index 100% rename from internal/db/ent/schema/message.go rename to internal/db_impl/ent_db/internal/schema/message.go diff --git a/internal/db/ent/schema/messageflag.go b/internal/db_impl/ent_db/internal/schema/messageflag.go similarity index 100% rename from internal/db/ent/schema/messageflag.go rename to internal/db_impl/ent_db/internal/schema/messageflag.go diff --git a/internal/db/ent/schema/subscriptions.go b/internal/db_impl/ent_db/internal/schema/subscriptions.go similarity index 100% rename from internal/db/ent/schema/subscriptions.go rename to internal/db_impl/ent_db/internal/schema/subscriptions.go diff --git a/internal/db/ent/schema/uid.go b/internal/db_impl/ent_db/internal/schema/uid.go similarity index 100% rename from internal/db/ent/schema/uid.go rename to internal/db_impl/ent_db/internal/schema/uid.go diff --git a/internal/db/ent/tx.go b/internal/db_impl/ent_db/internal/tx.go similarity index 92% rename from internal/db/ent/tx.go rename to internal/db_impl/ent_db/internal/tx.go index d97d7848..b3350afd 100644 --- a/internal/db/ent/tx.go +++ b/internal/db_impl/ent_db/internal/tx.go @@ -1,6 +1,6 @@ // Code generated by ent, DO NOT EDIT. -package ent +package internal import ( "context" @@ -32,12 +32,6 @@ type Tx struct { // lazily loaded. client *Client clientOnce sync.Once - - // completion callbacks. - mu sync.Mutex - onCommit []CommitHook - onRollback []RollbackHook - // ctx lives for the life of the transaction. It is // the same context used by the underlying connection. ctx context.Context @@ -82,9 +76,9 @@ func (tx *Tx) Commit() error { var fn Committer = CommitFunc(func(context.Context, *Tx) error { return txDriver.tx.Commit() }) - tx.mu.Lock() - hooks := append([]CommitHook(nil), tx.onCommit...) - tx.mu.Unlock() + txDriver.mu.Lock() + hooks := append([]CommitHook(nil), txDriver.onCommit...) + txDriver.mu.Unlock() for i := len(hooks) - 1; i >= 0; i-- { fn = hooks[i](fn) } @@ -93,9 +87,10 @@ func (tx *Tx) Commit() error { // OnCommit adds a hook to call on commit. func (tx *Tx) OnCommit(f CommitHook) { - tx.mu.Lock() - defer tx.mu.Unlock() - tx.onCommit = append(tx.onCommit, f) + txDriver := tx.config.driver.(*txDriver) + txDriver.mu.Lock() + txDriver.onCommit = append(txDriver.onCommit, f) + txDriver.mu.Unlock() } type ( @@ -137,9 +132,9 @@ func (tx *Tx) Rollback() error { var fn Rollbacker = RollbackFunc(func(context.Context, *Tx) error { return txDriver.tx.Rollback() }) - tx.mu.Lock() - hooks := append([]RollbackHook(nil), tx.onRollback...) - tx.mu.Unlock() + txDriver.mu.Lock() + hooks := append([]RollbackHook(nil), txDriver.onRollback...) + txDriver.mu.Unlock() for i := len(hooks) - 1; i >= 0; i-- { fn = hooks[i](fn) } @@ -148,9 +143,10 @@ func (tx *Tx) Rollback() error { // OnRollback adds a hook to call on rollback. func (tx *Tx) OnRollback(f RollbackHook) { - tx.mu.Lock() - defer tx.mu.Unlock() - tx.onRollback = append(tx.onRollback, f) + txDriver := tx.config.driver.(*txDriver) + txDriver.mu.Lock() + txDriver.onRollback = append(txDriver.onRollback, f) + txDriver.mu.Unlock() } // Client returns a Client that binds to current transaction. @@ -189,6 +185,10 @@ type txDriver struct { drv dialect.Driver // tx is the underlying transaction. tx dialect.Tx + // completion hooks. + mu sync.Mutex + onCommit []CommitHook + onRollback []RollbackHook } // newTx creates a new transactional driver. @@ -219,12 +219,12 @@ func (*txDriver) Commit() error { return nil } func (*txDriver) Rollback() error { return nil } // Exec calls tx.Exec. -func (tx *txDriver) Exec(ctx context.Context, query string, args, v interface{}) error { +func (tx *txDriver) Exec(ctx context.Context, query string, args, v any) error { return tx.tx.Exec(ctx, query, args, v) } // Query calls tx.Query. -func (tx *txDriver) Query(ctx context.Context, query string, args, v interface{}) error { +func (tx *txDriver) Query(ctx context.Context, query string, args, v any) error { return tx.tx.Query(ctx, query, args, v) } diff --git a/internal/db/ent/uid.go b/internal/db_impl/ent_db/internal/uid.go similarity index 89% rename from internal/db/ent/uid.go rename to internal/db_impl/ent_db/internal/uid.go index 5eb4b9f9..bd1ecd40 100644 --- a/internal/db/ent/uid.go +++ b/internal/db_impl/ent_db/internal/uid.go @@ -1,6 +1,6 @@ // Code generated by ent, DO NOT EDIT. -package ent +package internal import ( "fmt" @@ -8,9 +8,9 @@ import ( "entgo.io/ent/dialect/sql" "github.com/ProtonMail/gluon/imap" - "github.com/ProtonMail/gluon/internal/db/ent/mailbox" - "github.com/ProtonMail/gluon/internal/db/ent/message" - "github.com/ProtonMail/gluon/internal/db/ent/uid" + "github.com/ProtonMail/gluon/internal/db_impl/ent_db/internal/mailbox" + "github.com/ProtonMail/gluon/internal/db_impl/ent_db/internal/message" + "github.com/ProtonMail/gluon/internal/db_impl/ent_db/internal/uid" ) // UID is the model entity for the UID schema. @@ -69,8 +69,8 @@ func (e UIDEdges) MailboxOrErr() (*Mailbox, error) { } // scanValues returns the types for scanning values from sql.Rows. -func (*UID) scanValues(columns []string) ([]interface{}, error) { - values := make([]interface{}, len(columns)) +func (*UID) scanValues(columns []string) ([]any, error) { + values := make([]any, len(columns)) for i := range columns { switch columns[i] { case uid.FieldDeleted, uid.FieldRecent: @@ -90,7 +90,7 @@ func (*UID) scanValues(columns []string) ([]interface{}, error) { // assignValues assigns the values that were returned from sql.Rows (after scanning) // to the UID fields. -func (u *UID) assignValues(columns []string, values []interface{}) error { +func (u *UID) assignValues(columns []string, values []any) error { if m, n := len(values), len(columns); m < n { return fmt.Errorf("mismatch number of scan values: %d != %d", m, n) } @@ -141,19 +141,19 @@ func (u *UID) assignValues(columns []string, values []interface{}) error { // QueryMessage queries the "message" edge of the UID entity. func (u *UID) QueryMessage() *MessageQuery { - return (&UIDClient{config: u.config}).QueryMessage(u) + return NewUIDClient(u.config).QueryMessage(u) } // QueryMailbox queries the "mailbox" edge of the UID entity. func (u *UID) QueryMailbox() *MailboxQuery { - return (&UIDClient{config: u.config}).QueryMailbox(u) + return NewUIDClient(u.config).QueryMailbox(u) } // Update returns a builder for updating this UID. // Note that you need to call UID.Unwrap() before calling this method if this UID // was returned from a transaction, and the transaction was committed or rolled back. func (u *UID) Update() *UIDUpdateOne { - return (&UIDClient{config: u.config}).UpdateOne(u) + return NewUIDClient(u.config).UpdateOne(u) } // Unwrap unwraps the UID entity that was returned from a transaction after it was closed, @@ -161,7 +161,7 @@ func (u *UID) Update() *UIDUpdateOne { func (u *UID) Unwrap() *UID { _tx, ok := u.config.driver.(*txDriver) if !ok { - panic("ent: UID is not a transactional entity") + panic("internal: UID is not a transactional entity") } u.config.driver = _tx.drv return u @@ -186,9 +186,3 @@ func (u *UID) String() string { // UIDs is a parsable slice of UID. type UIDs []*UID - -func (u UIDs) config(cfg config) { - for _i := range u { - u[_i].config = cfg - } -} diff --git a/internal/db/ent/uid/uid.go b/internal/db_impl/ent_db/internal/uid/uid.go similarity index 100% rename from internal/db/ent/uid/uid.go rename to internal/db_impl/ent_db/internal/uid/uid.go diff --git a/internal/db/ent/uid/where.go b/internal/db_impl/ent_db/internal/uid/where.go similarity index 67% rename from internal/db/ent/uid/where.go rename to internal/db_impl/ent_db/internal/uid/where.go index 2fc79f1b..2564df99 100644 --- a/internal/db/ent/uid/where.go +++ b/internal/db_impl/ent_db/internal/uid/where.go @@ -6,198 +6,142 @@ import ( "entgo.io/ent/dialect/sql" "entgo.io/ent/dialect/sql/sqlgraph" "github.com/ProtonMail/gluon/imap" - "github.com/ProtonMail/gluon/internal/db/ent/predicate" + "github.com/ProtonMail/gluon/internal/db_impl/ent_db/internal/predicate" ) // ID filters vertices based on their ID field. func ID(id int) predicate.UID { - return predicate.UID(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldID), id)) - }) + return predicate.UID(sql.FieldEQ(FieldID, id)) } // IDEQ applies the EQ predicate on the ID field. func IDEQ(id int) predicate.UID { - return predicate.UID(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldID), id)) - }) + return predicate.UID(sql.FieldEQ(FieldID, id)) } // IDNEQ applies the NEQ predicate on the ID field. func IDNEQ(id int) predicate.UID { - return predicate.UID(func(s *sql.Selector) { - s.Where(sql.NEQ(s.C(FieldID), id)) - }) + return predicate.UID(sql.FieldNEQ(FieldID, id)) } // IDIn applies the In predicate on the ID field. func IDIn(ids ...int) predicate.UID { - return predicate.UID(func(s *sql.Selector) { - v := make([]interface{}, len(ids)) - for i := range v { - v[i] = ids[i] - } - s.Where(sql.In(s.C(FieldID), v...)) - }) + return predicate.UID(sql.FieldIn(FieldID, ids...)) } // IDNotIn applies the NotIn predicate on the ID field. func IDNotIn(ids ...int) predicate.UID { - return predicate.UID(func(s *sql.Selector) { - v := make([]interface{}, len(ids)) - for i := range v { - v[i] = ids[i] - } - s.Where(sql.NotIn(s.C(FieldID), v...)) - }) + return predicate.UID(sql.FieldNotIn(FieldID, ids...)) } // IDGT applies the GT predicate on the ID field. func IDGT(id int) predicate.UID { - return predicate.UID(func(s *sql.Selector) { - s.Where(sql.GT(s.C(FieldID), id)) - }) + return predicate.UID(sql.FieldGT(FieldID, id)) } // IDGTE applies the GTE predicate on the ID field. func IDGTE(id int) predicate.UID { - return predicate.UID(func(s *sql.Selector) { - s.Where(sql.GTE(s.C(FieldID), id)) - }) + return predicate.UID(sql.FieldGTE(FieldID, id)) } // IDLT applies the LT predicate on the ID field. func IDLT(id int) predicate.UID { - return predicate.UID(func(s *sql.Selector) { - s.Where(sql.LT(s.C(FieldID), id)) - }) + return predicate.UID(sql.FieldLT(FieldID, id)) } // IDLTE applies the LTE predicate on the ID field. func IDLTE(id int) predicate.UID { - return predicate.UID(func(s *sql.Selector) { - s.Where(sql.LTE(s.C(FieldID), id)) - }) + return predicate.UID(sql.FieldLTE(FieldID, id)) } // UID applies equality check predicate on the "UID" field. It's identical to UIDEQ. func UID(v imap.UID) predicate.UID { vc := uint32(v) - return predicate.UID(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldUID), vc)) - }) + return predicate.UID(sql.FieldEQ(FieldUID, vc)) } // Deleted applies equality check predicate on the "Deleted" field. It's identical to DeletedEQ. func Deleted(v bool) predicate.UID { - return predicate.UID(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldDeleted), v)) - }) + return predicate.UID(sql.FieldEQ(FieldDeleted, v)) } // Recent applies equality check predicate on the "Recent" field. It's identical to RecentEQ. func Recent(v bool) predicate.UID { - return predicate.UID(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldRecent), v)) - }) + return predicate.UID(sql.FieldEQ(FieldRecent, v)) } // UIDEQ applies the EQ predicate on the "UID" field. func UIDEQ(v imap.UID) predicate.UID { vc := uint32(v) - return predicate.UID(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldUID), vc)) - }) + return predicate.UID(sql.FieldEQ(FieldUID, vc)) } // UIDNEQ applies the NEQ predicate on the "UID" field. func UIDNEQ(v imap.UID) predicate.UID { vc := uint32(v) - return predicate.UID(func(s *sql.Selector) { - s.Where(sql.NEQ(s.C(FieldUID), vc)) - }) + return predicate.UID(sql.FieldNEQ(FieldUID, vc)) } // UIDIn applies the In predicate on the "UID" field. func UIDIn(vs ...imap.UID) predicate.UID { - v := make([]interface{}, len(vs)) + v := make([]any, len(vs)) for i := range v { v[i] = uint32(vs[i]) } - return predicate.UID(func(s *sql.Selector) { - s.Where(sql.In(s.C(FieldUID), v...)) - }) + return predicate.UID(sql.FieldIn(FieldUID, v...)) } // UIDNotIn applies the NotIn predicate on the "UID" field. func UIDNotIn(vs ...imap.UID) predicate.UID { - v := make([]interface{}, len(vs)) + v := make([]any, len(vs)) for i := range v { v[i] = uint32(vs[i]) } - return predicate.UID(func(s *sql.Selector) { - s.Where(sql.NotIn(s.C(FieldUID), v...)) - }) + return predicate.UID(sql.FieldNotIn(FieldUID, v...)) } // UIDGT applies the GT predicate on the "UID" field. func UIDGT(v imap.UID) predicate.UID { vc := uint32(v) - return predicate.UID(func(s *sql.Selector) { - s.Where(sql.GT(s.C(FieldUID), vc)) - }) + return predicate.UID(sql.FieldGT(FieldUID, vc)) } // UIDGTE applies the GTE predicate on the "UID" field. func UIDGTE(v imap.UID) predicate.UID { vc := uint32(v) - return predicate.UID(func(s *sql.Selector) { - s.Where(sql.GTE(s.C(FieldUID), vc)) - }) + return predicate.UID(sql.FieldGTE(FieldUID, vc)) } // UIDLT applies the LT predicate on the "UID" field. func UIDLT(v imap.UID) predicate.UID { vc := uint32(v) - return predicate.UID(func(s *sql.Selector) { - s.Where(sql.LT(s.C(FieldUID), vc)) - }) + return predicate.UID(sql.FieldLT(FieldUID, vc)) } // UIDLTE applies the LTE predicate on the "UID" field. func UIDLTE(v imap.UID) predicate.UID { vc := uint32(v) - return predicate.UID(func(s *sql.Selector) { - s.Where(sql.LTE(s.C(FieldUID), vc)) - }) + return predicate.UID(sql.FieldLTE(FieldUID, vc)) } // DeletedEQ applies the EQ predicate on the "Deleted" field. func DeletedEQ(v bool) predicate.UID { - return predicate.UID(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldDeleted), v)) - }) + return predicate.UID(sql.FieldEQ(FieldDeleted, v)) } // DeletedNEQ applies the NEQ predicate on the "Deleted" field. func DeletedNEQ(v bool) predicate.UID { - return predicate.UID(func(s *sql.Selector) { - s.Where(sql.NEQ(s.C(FieldDeleted), v)) - }) + return predicate.UID(sql.FieldNEQ(FieldDeleted, v)) } // RecentEQ applies the EQ predicate on the "Recent" field. func RecentEQ(v bool) predicate.UID { - return predicate.UID(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldRecent), v)) - }) + return predicate.UID(sql.FieldEQ(FieldRecent, v)) } // RecentNEQ applies the NEQ predicate on the "Recent" field. func RecentNEQ(v bool) predicate.UID { - return predicate.UID(func(s *sql.Selector) { - s.Where(sql.NEQ(s.C(FieldRecent), v)) - }) + return predicate.UID(sql.FieldNEQ(FieldRecent, v)) } // HasMessage applies the HasEdge predicate on the "message" edge. @@ -205,7 +149,6 @@ func HasMessage() predicate.UID { return predicate.UID(func(s *sql.Selector) { step := sqlgraph.NewStep( sqlgraph.From(Table, FieldID), - sqlgraph.To(MessageTable, FieldID), sqlgraph.Edge(sqlgraph.M2O, false, MessageTable, MessageColumn), ) sqlgraph.HasNeighbors(s, step) @@ -233,7 +176,6 @@ func HasMailbox() predicate.UID { return predicate.UID(func(s *sql.Selector) { step := sqlgraph.NewStep( sqlgraph.From(Table, FieldID), - sqlgraph.To(MailboxTable, FieldID), sqlgraph.Edge(sqlgraph.M2O, true, MailboxTable, MailboxColumn), ) sqlgraph.HasNeighbors(s, step) diff --git a/internal/db/ent/uid_create.go b/internal/db_impl/ent_db/internal/uid_create.go similarity index 78% rename from internal/db/ent/uid_create.go rename to internal/db_impl/ent_db/internal/uid_create.go index 59d5565a..e813b610 100644 --- a/internal/db/ent/uid_create.go +++ b/internal/db_impl/ent_db/internal/uid_create.go @@ -1,6 +1,6 @@ // Code generated by ent, DO NOT EDIT. -package ent +package internal import ( "context" @@ -10,9 +10,9 @@ import ( "entgo.io/ent/dialect/sql/sqlgraph" "entgo.io/ent/schema/field" "github.com/ProtonMail/gluon/imap" - "github.com/ProtonMail/gluon/internal/db/ent/mailbox" - "github.com/ProtonMail/gluon/internal/db/ent/message" - "github.com/ProtonMail/gluon/internal/db/ent/uid" + "github.com/ProtonMail/gluon/internal/db_impl/ent_db/internal/mailbox" + "github.com/ProtonMail/gluon/internal/db_impl/ent_db/internal/message" + "github.com/ProtonMail/gluon/internal/db_impl/ent_db/internal/uid" ) // UIDCreate is the builder for creating a UID entity. @@ -101,50 +101,8 @@ func (uc *UIDCreate) Mutation() *UIDMutation { // Save creates the UID in the database. func (uc *UIDCreate) Save(ctx context.Context) (*UID, error) { - var ( - err error - node *UID - ) uc.defaults() - if len(uc.hooks) == 0 { - if err = uc.check(); err != nil { - return nil, err - } - node, err = uc.sqlSave(ctx) - } else { - var mut Mutator = MutateFunc(func(ctx context.Context, m Mutation) (Value, error) { - mutation, ok := m.(*UIDMutation) - if !ok { - return nil, fmt.Errorf("unexpected mutation type %T", m) - } - if err = uc.check(); err != nil { - return nil, err - } - uc.mutation = mutation - if node, err = uc.sqlSave(ctx); err != nil { - return nil, err - } - mutation.id = &node.ID - mutation.done = true - return node, err - }) - for i := len(uc.hooks) - 1; i >= 0; i-- { - if uc.hooks[i] == nil { - return nil, fmt.Errorf("ent: uninitialized hook (forgotten import ent/runtime?)") - } - mut = uc.hooks[i](mut) - } - v, err := mut.Mutate(ctx, uc.mutation) - if err != nil { - return nil, err - } - nv, ok := v.(*UID) - if !ok { - return nil, fmt.Errorf("unexpected node type %T returned from UIDMutation", v) - } - node = nv - } - return node, err + return withHooks[*UID, UIDMutation](ctx, uc.sqlSave, uc.mutation, uc.hooks) } // SaveX calls Save and panics if Save returns an error. @@ -184,18 +142,21 @@ func (uc *UIDCreate) defaults() { // check runs all checks and user-defined validators on the builder. func (uc *UIDCreate) check() error { if _, ok := uc.mutation.UID(); !ok { - return &ValidationError{Name: "UID", err: errors.New(`ent: missing required field "UID.UID"`)} + return &ValidationError{Name: "UID", err: errors.New(`internal: missing required field "UID.UID"`)} } if _, ok := uc.mutation.Deleted(); !ok { - return &ValidationError{Name: "Deleted", err: errors.New(`ent: missing required field "UID.Deleted"`)} + return &ValidationError{Name: "Deleted", err: errors.New(`internal: missing required field "UID.Deleted"`)} } if _, ok := uc.mutation.Recent(); !ok { - return &ValidationError{Name: "Recent", err: errors.New(`ent: missing required field "UID.Recent"`)} + return &ValidationError{Name: "Recent", err: errors.New(`internal: missing required field "UID.Recent"`)} } return nil } func (uc *UIDCreate) sqlSave(ctx context.Context) (*UID, error) { + if err := uc.check(); err != nil { + return nil, err + } _node, _spec := uc.createSpec() if err := sqlgraph.CreateNode(ctx, uc.driver, _spec); err != nil { if sqlgraph.IsConstraintError(err) { @@ -205,42 +166,26 @@ func (uc *UIDCreate) sqlSave(ctx context.Context) (*UID, error) { } id := _spec.ID.Value.(int64) _node.ID = int(id) + uc.mutation.id = &_node.ID + uc.mutation.done = true return _node, nil } func (uc *UIDCreate) createSpec() (*UID, *sqlgraph.CreateSpec) { var ( _node = &UID{config: uc.config} - _spec = &sqlgraph.CreateSpec{ - Table: uid.Table, - ID: &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Column: uid.FieldID, - }, - } + _spec = sqlgraph.NewCreateSpec(uid.Table, sqlgraph.NewFieldSpec(uid.FieldID, field.TypeInt)) ) if value, ok := uc.mutation.UID(); ok { - _spec.Fields = append(_spec.Fields, &sqlgraph.FieldSpec{ - Type: field.TypeUint32, - Value: value, - Column: uid.FieldUID, - }) + _spec.SetField(uid.FieldUID, field.TypeUint32, value) _node.UID = value } if value, ok := uc.mutation.Deleted(); ok { - _spec.Fields = append(_spec.Fields, &sqlgraph.FieldSpec{ - Type: field.TypeBool, - Value: value, - Column: uid.FieldDeleted, - }) + _spec.SetField(uid.FieldDeleted, field.TypeBool, value) _node.Deleted = value } if value, ok := uc.mutation.Recent(); ok { - _spec.Fields = append(_spec.Fields, &sqlgraph.FieldSpec{ - Type: field.TypeBool, - Value: value, - Column: uid.FieldRecent, - }) + _spec.SetField(uid.FieldRecent, field.TypeBool, value) _node.Recent = value } if nodes := uc.mutation.MessageIDs(); len(nodes) > 0 { diff --git a/internal/db/ent/uid_delete.go b/internal/db_impl/ent_db/internal/uid_delete.go similarity index 61% rename from internal/db/ent/uid_delete.go rename to internal/db_impl/ent_db/internal/uid_delete.go index fb7be4ac..56000ba7 100644 --- a/internal/db/ent/uid_delete.go +++ b/internal/db_impl/ent_db/internal/uid_delete.go @@ -1,16 +1,15 @@ // Code generated by ent, DO NOT EDIT. -package ent +package internal import ( "context" - "fmt" "entgo.io/ent/dialect/sql" "entgo.io/ent/dialect/sql/sqlgraph" "entgo.io/ent/schema/field" - "github.com/ProtonMail/gluon/internal/db/ent/predicate" - "github.com/ProtonMail/gluon/internal/db/ent/uid" + "github.com/ProtonMail/gluon/internal/db_impl/ent_db/internal/predicate" + "github.com/ProtonMail/gluon/internal/db_impl/ent_db/internal/uid" ) // UIDDelete is the builder for deleting a UID entity. @@ -28,34 +27,7 @@ func (ud *UIDDelete) Where(ps ...predicate.UID) *UIDDelete { // Exec executes the deletion query and returns how many vertices were deleted. func (ud *UIDDelete) Exec(ctx context.Context) (int, error) { - var ( - err error - affected int - ) - if len(ud.hooks) == 0 { - affected, err = ud.sqlExec(ctx) - } else { - var mut Mutator = MutateFunc(func(ctx context.Context, m Mutation) (Value, error) { - mutation, ok := m.(*UIDMutation) - if !ok { - return nil, fmt.Errorf("unexpected mutation type %T", m) - } - ud.mutation = mutation - affected, err = ud.sqlExec(ctx) - mutation.done = true - return affected, err - }) - for i := len(ud.hooks) - 1; i >= 0; i-- { - if ud.hooks[i] == nil { - return 0, fmt.Errorf("ent: uninitialized hook (forgotten import ent/runtime?)") - } - mut = ud.hooks[i](mut) - } - if _, err := mut.Mutate(ctx, ud.mutation); err != nil { - return 0, err - } - } - return affected, err + return withHooks[int, UIDMutation](ctx, ud.sqlExec, ud.mutation, ud.hooks) } // ExecX is like Exec, but panics if an error occurs. @@ -68,15 +40,7 @@ func (ud *UIDDelete) ExecX(ctx context.Context) int { } func (ud *UIDDelete) sqlExec(ctx context.Context) (int, error) { - _spec := &sqlgraph.DeleteSpec{ - Node: &sqlgraph.NodeSpec{ - Table: uid.Table, - ID: &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Column: uid.FieldID, - }, - }, - } + _spec := sqlgraph.NewDeleteSpec(uid.Table, sqlgraph.NewFieldSpec(uid.FieldID, field.TypeInt)) if ps := ud.mutation.predicates; len(ps) > 0 { _spec.Predicate = func(selector *sql.Selector) { for i := range ps { @@ -88,6 +52,7 @@ func (ud *UIDDelete) sqlExec(ctx context.Context) (int, error) { if err != nil && sqlgraph.IsConstraintError(err) { err = &ConstraintError{msg: err.Error(), wrap: err} } + ud.mutation.done = true return affected, err } @@ -96,6 +61,12 @@ type UIDDeleteOne struct { ud *UIDDelete } +// Where appends a list predicates to the UIDDelete builder. +func (udo *UIDDeleteOne) Where(ps ...predicate.UID) *UIDDeleteOne { + udo.ud.mutation.Where(ps...) + return udo +} + // Exec executes the deletion query. func (udo *UIDDeleteOne) Exec(ctx context.Context) error { n, err := udo.ud.Exec(ctx) @@ -111,5 +82,7 @@ func (udo *UIDDeleteOne) Exec(ctx context.Context) error { // ExecX is like Exec, but panics if an error occurs. func (udo *UIDDeleteOne) ExecX(ctx context.Context) { - udo.ud.ExecX(ctx) + if err := udo.Exec(ctx); err != nil { + panic(err) + } } diff --git a/internal/db/ent/uid_query.go b/internal/db_impl/ent_db/internal/uid_query.go similarity index 72% rename from internal/db/ent/uid_query.go rename to internal/db_impl/ent_db/internal/uid_query.go index 89e724b5..7e5bef5c 100644 --- a/internal/db/ent/uid_query.go +++ b/internal/db_impl/ent_db/internal/uid_query.go @@ -1,6 +1,6 @@ // Code generated by ent, DO NOT EDIT. -package ent +package internal import ( "context" @@ -11,20 +11,18 @@ import ( "entgo.io/ent/dialect/sql/sqlgraph" "entgo.io/ent/schema/field" "github.com/ProtonMail/gluon/imap" - "github.com/ProtonMail/gluon/internal/db/ent/mailbox" - "github.com/ProtonMail/gluon/internal/db/ent/message" - "github.com/ProtonMail/gluon/internal/db/ent/predicate" - "github.com/ProtonMail/gluon/internal/db/ent/uid" + "github.com/ProtonMail/gluon/internal/db_impl/ent_db/internal/mailbox" + "github.com/ProtonMail/gluon/internal/db_impl/ent_db/internal/message" + "github.com/ProtonMail/gluon/internal/db_impl/ent_db/internal/predicate" + "github.com/ProtonMail/gluon/internal/db_impl/ent_db/internal/uid" ) // UIDQuery is the builder for querying UID entities. type UIDQuery struct { config - limit *int - offset *int - unique *bool + ctx *QueryContext order []OrderFunc - fields []string + inters []Interceptor predicates []predicate.UID withMessage *MessageQuery withMailbox *MailboxQuery @@ -40,26 +38,26 @@ func (uq *UIDQuery) Where(ps ...predicate.UID) *UIDQuery { return uq } -// Limit adds a limit step to the query. +// Limit the number of records to be returned by this query. func (uq *UIDQuery) Limit(limit int) *UIDQuery { - uq.limit = &limit + uq.ctx.Limit = &limit return uq } -// Offset adds an offset step to the query. +// Offset to start from. func (uq *UIDQuery) Offset(offset int) *UIDQuery { - uq.offset = &offset + uq.ctx.Offset = &offset return uq } // Unique configures the query builder to filter duplicate records on query. // By default, unique is set to true, and can be disabled using this method. func (uq *UIDQuery) Unique(unique bool) *UIDQuery { - uq.unique = &unique + uq.ctx.Unique = &unique return uq } -// Order adds an order step to the query. +// Order specifies how the records should be ordered. func (uq *UIDQuery) Order(o ...OrderFunc) *UIDQuery { uq.order = append(uq.order, o...) return uq @@ -67,7 +65,7 @@ func (uq *UIDQuery) Order(o ...OrderFunc) *UIDQuery { // QueryMessage chains the current query on the "message" edge. func (uq *UIDQuery) QueryMessage() *MessageQuery { - query := &MessageQuery{config: uq.config} + query := (&MessageClient{config: uq.config}).Query() query.path = func(ctx context.Context) (fromU *sql.Selector, err error) { if err := uq.prepareQuery(ctx); err != nil { return nil, err @@ -89,7 +87,7 @@ func (uq *UIDQuery) QueryMessage() *MessageQuery { // QueryMailbox chains the current query on the "mailbox" edge. func (uq *UIDQuery) QueryMailbox() *MailboxQuery { - query := &MailboxQuery{config: uq.config} + query := (&MailboxClient{config: uq.config}).Query() query.path = func(ctx context.Context) (fromU *sql.Selector, err error) { if err := uq.prepareQuery(ctx); err != nil { return nil, err @@ -112,7 +110,7 @@ func (uq *UIDQuery) QueryMailbox() *MailboxQuery { // First returns the first UID entity from the query. // Returns a *NotFoundError when no UID was found. func (uq *UIDQuery) First(ctx context.Context) (*UID, error) { - nodes, err := uq.Limit(1).All(ctx) + nodes, err := uq.Limit(1).All(setContextOp(ctx, uq.ctx, "First")) if err != nil { return nil, err } @@ -135,7 +133,7 @@ func (uq *UIDQuery) FirstX(ctx context.Context) *UID { // Returns a *NotFoundError when no UID ID was found. func (uq *UIDQuery) FirstID(ctx context.Context) (id int, err error) { var ids []int - if ids, err = uq.Limit(1).IDs(ctx); err != nil { + if ids, err = uq.Limit(1).IDs(setContextOp(ctx, uq.ctx, "FirstID")); err != nil { return } if len(ids) == 0 { @@ -158,7 +156,7 @@ func (uq *UIDQuery) FirstIDX(ctx context.Context) int { // Returns a *NotSingularError when more than one UID entity is found. // Returns a *NotFoundError when no UID entities are found. func (uq *UIDQuery) Only(ctx context.Context) (*UID, error) { - nodes, err := uq.Limit(2).All(ctx) + nodes, err := uq.Limit(2).All(setContextOp(ctx, uq.ctx, "Only")) if err != nil { return nil, err } @@ -186,7 +184,7 @@ func (uq *UIDQuery) OnlyX(ctx context.Context) *UID { // Returns a *NotFoundError when no entities are found. func (uq *UIDQuery) OnlyID(ctx context.Context) (id int, err error) { var ids []int - if ids, err = uq.Limit(2).IDs(ctx); err != nil { + if ids, err = uq.Limit(2).IDs(setContextOp(ctx, uq.ctx, "OnlyID")); err != nil { return } switch len(ids) { @@ -211,10 +209,12 @@ func (uq *UIDQuery) OnlyIDX(ctx context.Context) int { // All executes the query and returns a list of UIDs. func (uq *UIDQuery) All(ctx context.Context) ([]*UID, error) { + ctx = setContextOp(ctx, uq.ctx, "All") if err := uq.prepareQuery(ctx); err != nil { return nil, err } - return uq.sqlAll(ctx) + qr := querierAll[[]*UID, *UIDQuery]() + return withInterceptors[[]*UID](ctx, uq, qr, uq.inters) } // AllX is like All, but panics if an error occurs. @@ -227,9 +227,12 @@ func (uq *UIDQuery) AllX(ctx context.Context) []*UID { } // IDs executes the query and returns a list of UID IDs. -func (uq *UIDQuery) IDs(ctx context.Context) ([]int, error) { - var ids []int - if err := uq.Select(uid.FieldID).Scan(ctx, &ids); err != nil { +func (uq *UIDQuery) IDs(ctx context.Context) (ids []int, err error) { + if uq.ctx.Unique == nil && uq.path != nil { + uq.Unique(true) + } + ctx = setContextOp(ctx, uq.ctx, "IDs") + if err = uq.Select(uid.FieldID).Scan(ctx, &ids); err != nil { return nil, err } return ids, nil @@ -246,10 +249,11 @@ func (uq *UIDQuery) IDsX(ctx context.Context) []int { // Count returns the count of the given query. func (uq *UIDQuery) Count(ctx context.Context) (int, error) { + ctx = setContextOp(ctx, uq.ctx, "Count") if err := uq.prepareQuery(ctx); err != nil { return 0, err } - return uq.sqlCount(ctx) + return withInterceptors[int](ctx, uq, querierCount[*UIDQuery](), uq.inters) } // CountX is like Count, but panics if an error occurs. @@ -263,10 +267,15 @@ func (uq *UIDQuery) CountX(ctx context.Context) int { // Exist returns true if the query has elements in the graph. func (uq *UIDQuery) Exist(ctx context.Context) (bool, error) { - if err := uq.prepareQuery(ctx); err != nil { - return false, err + ctx = setContextOp(ctx, uq.ctx, "Exist") + switch _, err := uq.FirstID(ctx); { + case IsNotFound(err): + return false, nil + case err != nil: + return false, fmt.Errorf("internal: check existence: %w", err) + default: + return true, nil } - return uq.sqlExist(ctx) } // ExistX is like Exist, but panics if an error occurs. @@ -286,23 +295,22 @@ func (uq *UIDQuery) Clone() *UIDQuery { } return &UIDQuery{ config: uq.config, - limit: uq.limit, - offset: uq.offset, + ctx: uq.ctx.Clone(), order: append([]OrderFunc{}, uq.order...), + inters: append([]Interceptor{}, uq.inters...), predicates: append([]predicate.UID{}, uq.predicates...), withMessage: uq.withMessage.Clone(), withMailbox: uq.withMailbox.Clone(), // clone intermediate query. - sql: uq.sql.Clone(), - path: uq.path, - unique: uq.unique, + sql: uq.sql.Clone(), + path: uq.path, } } // WithMessage tells the query-builder to eager-load the nodes that are connected to // the "message" edge. The optional arguments are used to configure the query builder of the edge. func (uq *UIDQuery) WithMessage(opts ...func(*MessageQuery)) *UIDQuery { - query := &MessageQuery{config: uq.config} + query := (&MessageClient{config: uq.config}).Query() for _, opt := range opts { opt(query) } @@ -313,7 +321,7 @@ func (uq *UIDQuery) WithMessage(opts ...func(*MessageQuery)) *UIDQuery { // WithMailbox tells the query-builder to eager-load the nodes that are connected to // the "mailbox" edge. The optional arguments are used to configure the query builder of the edge. func (uq *UIDQuery) WithMailbox(opts ...func(*MailboxQuery)) *UIDQuery { - query := &MailboxQuery{config: uq.config} + query := (&MailboxClient{config: uq.config}).Query() for _, opt := range opts { opt(query) } @@ -333,19 +341,14 @@ func (uq *UIDQuery) WithMailbox(opts ...func(*MailboxQuery)) *UIDQuery { // // client.UID.Query(). // GroupBy(uid.FieldUID). -// Aggregate(ent.Count()). +// Aggregate(internal.Count()). // Scan(ctx, &v) func (uq *UIDQuery) GroupBy(field string, fields ...string) *UIDGroupBy { - grbuild := &UIDGroupBy{config: uq.config} - grbuild.fields = append([]string{field}, fields...) - grbuild.path = func(ctx context.Context) (prev *sql.Selector, err error) { - if err := uq.prepareQuery(ctx); err != nil { - return nil, err - } - return uq.sqlQuery(ctx), nil - } + uq.ctx.Fields = append([]string{field}, fields...) + grbuild := &UIDGroupBy{build: uq} + grbuild.flds = &uq.ctx.Fields grbuild.label = uid.Label - grbuild.flds, grbuild.scan = &grbuild.fields, grbuild.Scan + grbuild.scan = grbuild.Scan return grbuild } @@ -362,17 +365,32 @@ func (uq *UIDQuery) GroupBy(field string, fields ...string) *UIDGroupBy { // Select(uid.FieldUID). // Scan(ctx, &v) func (uq *UIDQuery) Select(fields ...string) *UIDSelect { - uq.fields = append(uq.fields, fields...) - selbuild := &UIDSelect{UIDQuery: uq} - selbuild.label = uid.Label - selbuild.flds, selbuild.scan = &uq.fields, selbuild.Scan - return selbuild + uq.ctx.Fields = append(uq.ctx.Fields, fields...) + sbuild := &UIDSelect{UIDQuery: uq} + sbuild.label = uid.Label + sbuild.flds, sbuild.scan = &uq.ctx.Fields, sbuild.Scan + return sbuild +} + +// Aggregate returns a UIDSelect configured with the given aggregations. +func (uq *UIDQuery) Aggregate(fns ...AggregateFunc) *UIDSelect { + return uq.Select().Aggregate(fns...) } func (uq *UIDQuery) prepareQuery(ctx context.Context) error { - for _, f := range uq.fields { + for _, inter := range uq.inters { + if inter == nil { + return fmt.Errorf("internal: uninitialized interceptor (forgotten import internal/runtime?)") + } + if trv, ok := inter.(Traverser); ok { + if err := trv.Traverse(ctx, uq); err != nil { + return err + } + } + } + for _, f := range uq.ctx.Fields { if !uid.ValidColumn(f) { - return &ValidationError{Name: f, err: fmt.Errorf("ent: invalid field %q for query", f)} + return &ValidationError{Name: f, err: fmt.Errorf("internal: invalid field %q for query", f)} } } if uq.path != nil { @@ -401,10 +419,10 @@ func (uq *UIDQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*UID, err if withFKs { _spec.Node.Columns = append(_spec.Node.Columns, uid.ForeignKeys...) } - _spec.ScanValues = func(columns []string) ([]interface{}, error) { + _spec.ScanValues = func(columns []string) ([]any, error) { return (*UID).scanValues(nil, columns) } - _spec.Assign = func(columns []string, values []interface{}) error { + _spec.Assign = func(columns []string, values []any) error { node := &UID{config: uq.config} nodes = append(nodes, node) node.Edges.loadedTypes = loadedTypes @@ -447,6 +465,9 @@ func (uq *UIDQuery) loadMessage(ctx context.Context, query *MessageQuery, nodes } nodeids[fk] = append(nodeids[fk], nodes[i]) } + if len(ids) == 0 { + return nil + } query.Where(message.IDIn(ids...)) neighbors, err := query.All(ctx) if err != nil { @@ -476,6 +497,9 @@ func (uq *UIDQuery) loadMailbox(ctx context.Context, query *MailboxQuery, nodes } nodeids[fk] = append(nodeids[fk], nodes[i]) } + if len(ids) == 0 { + return nil + } query.Where(mailbox.IDIn(ids...)) neighbors, err := query.All(ctx) if err != nil { @@ -495,38 +519,22 @@ func (uq *UIDQuery) loadMailbox(ctx context.Context, query *MailboxQuery, nodes func (uq *UIDQuery) sqlCount(ctx context.Context) (int, error) { _spec := uq.querySpec() - _spec.Node.Columns = uq.fields - if len(uq.fields) > 0 { - _spec.Unique = uq.unique != nil && *uq.unique + _spec.Node.Columns = uq.ctx.Fields + if len(uq.ctx.Fields) > 0 { + _spec.Unique = uq.ctx.Unique != nil && *uq.ctx.Unique } return sqlgraph.CountNodes(ctx, uq.driver, _spec) } -func (uq *UIDQuery) sqlExist(ctx context.Context) (bool, error) { - n, err := uq.sqlCount(ctx) - if err != nil { - return false, fmt.Errorf("ent: check existence: %w", err) - } - return n > 0, nil -} - func (uq *UIDQuery) querySpec() *sqlgraph.QuerySpec { - _spec := &sqlgraph.QuerySpec{ - Node: &sqlgraph.NodeSpec{ - Table: uid.Table, - Columns: uid.Columns, - ID: &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Column: uid.FieldID, - }, - }, - From: uq.sql, - Unique: true, - } - if unique := uq.unique; unique != nil { + _spec := sqlgraph.NewQuerySpec(uid.Table, uid.Columns, sqlgraph.NewFieldSpec(uid.FieldID, field.TypeInt)) + _spec.From = uq.sql + if unique := uq.ctx.Unique; unique != nil { _spec.Unique = *unique + } else if uq.path != nil { + _spec.Unique = true } - if fields := uq.fields; len(fields) > 0 { + if fields := uq.ctx.Fields; len(fields) > 0 { _spec.Node.Columns = make([]string, 0, len(fields)) _spec.Node.Columns = append(_spec.Node.Columns, uid.FieldID) for i := range fields { @@ -542,10 +550,10 @@ func (uq *UIDQuery) querySpec() *sqlgraph.QuerySpec { } } } - if limit := uq.limit; limit != nil { + if limit := uq.ctx.Limit; limit != nil { _spec.Limit = *limit } - if offset := uq.offset; offset != nil { + if offset := uq.ctx.Offset; offset != nil { _spec.Offset = *offset } if ps := uq.order; len(ps) > 0 { @@ -561,7 +569,7 @@ func (uq *UIDQuery) querySpec() *sqlgraph.QuerySpec { func (uq *UIDQuery) sqlQuery(ctx context.Context) *sql.Selector { builder := sql.Dialect(uq.driver.Dialect()) t1 := builder.Table(uid.Table) - columns := uq.fields + columns := uq.ctx.Fields if len(columns) == 0 { columns = uid.Columns } @@ -570,7 +578,7 @@ func (uq *UIDQuery) sqlQuery(ctx context.Context) *sql.Selector { selector = uq.sql selector.Select(selector.Columns(columns...)...) } - if uq.unique != nil && *uq.unique { + if uq.ctx.Unique != nil && *uq.ctx.Unique { selector.Distinct() } for _, p := range uq.predicates { @@ -579,12 +587,12 @@ func (uq *UIDQuery) sqlQuery(ctx context.Context) *sql.Selector { for _, p := range uq.order { p(selector) } - if offset := uq.offset; offset != nil { + if offset := uq.ctx.Offset; offset != nil { // limit is mandatory for offset clause. We start // with default value, and override it below if needed. selector.Offset(*offset).Limit(math.MaxInt32) } - if limit := uq.limit; limit != nil { + if limit := uq.ctx.Limit; limit != nil { selector.Limit(*limit) } return selector @@ -592,13 +600,8 @@ func (uq *UIDQuery) sqlQuery(ctx context.Context) *sql.Selector { // UIDGroupBy is the group-by builder for UID entities. type UIDGroupBy struct { - config selector - fields []string - fns []AggregateFunc - // intermediate query (i.e. traversal path). - sql *sql.Selector - path func(context.Context) (*sql.Selector, error) + build *UIDQuery } // Aggregate adds the given aggregation functions to the group-by query. @@ -607,74 +610,77 @@ func (ugb *UIDGroupBy) Aggregate(fns ...AggregateFunc) *UIDGroupBy { return ugb } -// Scan applies the group-by query and scans the result into the given value. -func (ugb *UIDGroupBy) Scan(ctx context.Context, v interface{}) error { - query, err := ugb.path(ctx) - if err != nil { +// Scan applies the selector query and scans the result into the given value. +func (ugb *UIDGroupBy) Scan(ctx context.Context, v any) error { + ctx = setContextOp(ctx, ugb.build.ctx, "GroupBy") + if err := ugb.build.prepareQuery(ctx); err != nil { return err } - ugb.sql = query - return ugb.sqlScan(ctx, v) + return scanWithInterceptors[*UIDQuery, *UIDGroupBy](ctx, ugb.build, ugb, ugb.build.inters, v) } -func (ugb *UIDGroupBy) sqlScan(ctx context.Context, v interface{}) error { - for _, f := range ugb.fields { - if !uid.ValidColumn(f) { - return &ValidationError{Name: f, err: fmt.Errorf("invalid field %q for group-by", f)} +func (ugb *UIDGroupBy) sqlScan(ctx context.Context, root *UIDQuery, v any) error { + selector := root.sqlQuery(ctx).Select() + aggregation := make([]string, 0, len(ugb.fns)) + for _, fn := range ugb.fns { + aggregation = append(aggregation, fn(selector)) + } + if len(selector.SelectedColumns()) == 0 { + columns := make([]string, 0, len(*ugb.flds)+len(ugb.fns)) + for _, f := range *ugb.flds { + columns = append(columns, selector.C(f)) } + columns = append(columns, aggregation...) + selector.Select(columns...) } - selector := ugb.sqlQuery() + selector.GroupBy(selector.Columns(*ugb.flds...)...) if err := selector.Err(); err != nil { return err } rows := &sql.Rows{} query, args := selector.Query() - if err := ugb.driver.Query(ctx, query, args, rows); err != nil { + if err := ugb.build.driver.Query(ctx, query, args, rows); err != nil { return err } defer rows.Close() return sql.ScanSlice(rows, v) } -func (ugb *UIDGroupBy) sqlQuery() *sql.Selector { - selector := ugb.sql.Select() - aggregation := make([]string, 0, len(ugb.fns)) - for _, fn := range ugb.fns { - aggregation = append(aggregation, fn(selector)) - } - // If no columns were selected in a custom aggregation function, the default - // selection is the fields used for "group-by", and the aggregation functions. - if len(selector.SelectedColumns()) == 0 { - columns := make([]string, 0, len(ugb.fields)+len(ugb.fns)) - for _, f := range ugb.fields { - columns = append(columns, selector.C(f)) - } - columns = append(columns, aggregation...) - selector.Select(columns...) - } - return selector.GroupBy(selector.Columns(ugb.fields...)...) -} - // UIDSelect is the builder for selecting fields of UID entities. type UIDSelect struct { *UIDQuery selector - // intermediate query (i.e. traversal path). - sql *sql.Selector +} + +// Aggregate adds the given aggregation functions to the selector query. +func (us *UIDSelect) Aggregate(fns ...AggregateFunc) *UIDSelect { + us.fns = append(us.fns, fns...) + return us } // Scan applies the selector query and scans the result into the given value. -func (us *UIDSelect) Scan(ctx context.Context, v interface{}) error { +func (us *UIDSelect) Scan(ctx context.Context, v any) error { + ctx = setContextOp(ctx, us.ctx, "Select") if err := us.prepareQuery(ctx); err != nil { return err } - us.sql = us.UIDQuery.sqlQuery(ctx) - return us.sqlScan(ctx, v) + return scanWithInterceptors[*UIDQuery, *UIDSelect](ctx, us.UIDQuery, us, us.inters, v) } -func (us *UIDSelect) sqlScan(ctx context.Context, v interface{}) error { +func (us *UIDSelect) sqlScan(ctx context.Context, root *UIDQuery, v any) error { + selector := root.sqlQuery(ctx) + aggregation := make([]string, 0, len(us.fns)) + for _, fn := range us.fns { + aggregation = append(aggregation, fn(selector)) + } + switch n := len(*us.selector.flds); { + case n == 0 && len(aggregation) > 0: + selector.Select(aggregation...) + case n != 0 && len(aggregation) > 0: + selector.AppendSelect(aggregation...) + } rows := &sql.Rows{} - query, args := us.sql.Query() + query, args := selector.Query() if err := us.driver.Query(ctx, query, args, rows); err != nil { return err } diff --git a/internal/db/ent/uid_update.go b/internal/db_impl/ent_db/internal/uid_update.go similarity index 78% rename from internal/db/ent/uid_update.go rename to internal/db_impl/ent_db/internal/uid_update.go index 8eb1a062..2d28c5fd 100644 --- a/internal/db/ent/uid_update.go +++ b/internal/db_impl/ent_db/internal/uid_update.go @@ -1,6 +1,6 @@ // Code generated by ent, DO NOT EDIT. -package ent +package internal import ( "context" @@ -11,10 +11,10 @@ import ( "entgo.io/ent/dialect/sql/sqlgraph" "entgo.io/ent/schema/field" "github.com/ProtonMail/gluon/imap" - "github.com/ProtonMail/gluon/internal/db/ent/mailbox" - "github.com/ProtonMail/gluon/internal/db/ent/message" - "github.com/ProtonMail/gluon/internal/db/ent/predicate" - "github.com/ProtonMail/gluon/internal/db/ent/uid" + "github.com/ProtonMail/gluon/internal/db_impl/ent_db/internal/mailbox" + "github.com/ProtonMail/gluon/internal/db_impl/ent_db/internal/message" + "github.com/ProtonMail/gluon/internal/db_impl/ent_db/internal/predicate" + "github.com/ProtonMail/gluon/internal/db_impl/ent_db/internal/uid" ) // UIDUpdate is the builder for updating UID entities. @@ -128,34 +128,7 @@ func (uu *UIDUpdate) ClearMailbox() *UIDUpdate { // Save executes the query and returns the number of nodes affected by the update operation. func (uu *UIDUpdate) Save(ctx context.Context) (int, error) { - var ( - err error - affected int - ) - if len(uu.hooks) == 0 { - affected, err = uu.sqlSave(ctx) - } else { - var mut Mutator = MutateFunc(func(ctx context.Context, m Mutation) (Value, error) { - mutation, ok := m.(*UIDMutation) - if !ok { - return nil, fmt.Errorf("unexpected mutation type %T", m) - } - uu.mutation = mutation - affected, err = uu.sqlSave(ctx) - mutation.done = true - return affected, err - }) - for i := len(uu.hooks) - 1; i >= 0; i-- { - if uu.hooks[i] == nil { - return 0, fmt.Errorf("ent: uninitialized hook (forgotten import ent/runtime?)") - } - mut = uu.hooks[i](mut) - } - if _, err := mut.Mutate(ctx, uu.mutation); err != nil { - return 0, err - } - } - return affected, err + return withHooks[int, UIDMutation](ctx, uu.sqlSave, uu.mutation, uu.hooks) } // SaveX is like Save, but panics if an error occurs. @@ -181,16 +154,7 @@ func (uu *UIDUpdate) ExecX(ctx context.Context) { } func (uu *UIDUpdate) sqlSave(ctx context.Context) (n int, err error) { - _spec := &sqlgraph.UpdateSpec{ - Node: &sqlgraph.NodeSpec{ - Table: uid.Table, - Columns: uid.Columns, - ID: &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Column: uid.FieldID, - }, - }, - } + _spec := sqlgraph.NewUpdateSpec(uid.Table, uid.Columns, sqlgraph.NewFieldSpec(uid.FieldID, field.TypeInt)) if ps := uu.mutation.predicates; len(ps) > 0 { _spec.Predicate = func(selector *sql.Selector) { for i := range ps { @@ -199,32 +163,16 @@ func (uu *UIDUpdate) sqlSave(ctx context.Context) (n int, err error) { } } if value, ok := uu.mutation.UID(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeUint32, - Value: value, - Column: uid.FieldUID, - }) + _spec.SetField(uid.FieldUID, field.TypeUint32, value) } if value, ok := uu.mutation.AddedUID(); ok { - _spec.Fields.Add = append(_spec.Fields.Add, &sqlgraph.FieldSpec{ - Type: field.TypeUint32, - Value: value, - Column: uid.FieldUID, - }) + _spec.AddField(uid.FieldUID, field.TypeUint32, value) } if value, ok := uu.mutation.Deleted(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeBool, - Value: value, - Column: uid.FieldDeleted, - }) + _spec.SetField(uid.FieldDeleted, field.TypeBool, value) } if value, ok := uu.mutation.Recent(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeBool, - Value: value, - Column: uid.FieldRecent, - }) + _spec.SetField(uid.FieldRecent, field.TypeBool, value) } if uu.mutation.MessageCleared() { edge := &sqlgraph.EdgeSpec{ @@ -304,6 +252,7 @@ func (uu *UIDUpdate) sqlSave(ctx context.Context) (n int, err error) { } return 0, err } + uu.mutation.done = true return n, nil } @@ -411,6 +360,12 @@ func (uuo *UIDUpdateOne) ClearMailbox() *UIDUpdateOne { return uuo } +// Where appends a list predicates to the UIDUpdate builder. +func (uuo *UIDUpdateOne) Where(ps ...predicate.UID) *UIDUpdateOne { + uuo.mutation.Where(ps...) + return uuo +} + // Select allows selecting one or more fields (columns) of the returned entity. // The default is selecting all fields defined in the entity schema. func (uuo *UIDUpdateOne) Select(field string, fields ...string) *UIDUpdateOne { @@ -420,40 +375,7 @@ func (uuo *UIDUpdateOne) Select(field string, fields ...string) *UIDUpdateOne { // Save executes the query and returns the updated UID entity. func (uuo *UIDUpdateOne) Save(ctx context.Context) (*UID, error) { - var ( - err error - node *UID - ) - if len(uuo.hooks) == 0 { - node, err = uuo.sqlSave(ctx) - } else { - var mut Mutator = MutateFunc(func(ctx context.Context, m Mutation) (Value, error) { - mutation, ok := m.(*UIDMutation) - if !ok { - return nil, fmt.Errorf("unexpected mutation type %T", m) - } - uuo.mutation = mutation - node, err = uuo.sqlSave(ctx) - mutation.done = true - return node, err - }) - for i := len(uuo.hooks) - 1; i >= 0; i-- { - if uuo.hooks[i] == nil { - return nil, fmt.Errorf("ent: uninitialized hook (forgotten import ent/runtime?)") - } - mut = uuo.hooks[i](mut) - } - v, err := mut.Mutate(ctx, uuo.mutation) - if err != nil { - return nil, err - } - nv, ok := v.(*UID) - if !ok { - return nil, fmt.Errorf("unexpected node type %T returned from UIDMutation", v) - } - node = nv - } - return node, err + return withHooks[*UID, UIDMutation](ctx, uuo.sqlSave, uuo.mutation, uuo.hooks) } // SaveX is like Save, but panics if an error occurs. @@ -479,19 +401,10 @@ func (uuo *UIDUpdateOne) ExecX(ctx context.Context) { } func (uuo *UIDUpdateOne) sqlSave(ctx context.Context) (_node *UID, err error) { - _spec := &sqlgraph.UpdateSpec{ - Node: &sqlgraph.NodeSpec{ - Table: uid.Table, - Columns: uid.Columns, - ID: &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Column: uid.FieldID, - }, - }, - } + _spec := sqlgraph.NewUpdateSpec(uid.Table, uid.Columns, sqlgraph.NewFieldSpec(uid.FieldID, field.TypeInt)) id, ok := uuo.mutation.ID() if !ok { - return nil, &ValidationError{Name: "id", err: errors.New(`ent: missing "UID.id" for update`)} + return nil, &ValidationError{Name: "id", err: errors.New(`internal: missing "UID.id" for update`)} } _spec.Node.ID.Value = id if fields := uuo.fields; len(fields) > 0 { @@ -499,7 +412,7 @@ func (uuo *UIDUpdateOne) sqlSave(ctx context.Context) (_node *UID, err error) { _spec.Node.Columns = append(_spec.Node.Columns, uid.FieldID) for _, f := range fields { if !uid.ValidColumn(f) { - return nil, &ValidationError{Name: f, err: fmt.Errorf("ent: invalid field %q for query", f)} + return nil, &ValidationError{Name: f, err: fmt.Errorf("internal: invalid field %q for query", f)} } if f != uid.FieldID { _spec.Node.Columns = append(_spec.Node.Columns, f) @@ -514,32 +427,16 @@ func (uuo *UIDUpdateOne) sqlSave(ctx context.Context) (_node *UID, err error) { } } if value, ok := uuo.mutation.UID(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeUint32, - Value: value, - Column: uid.FieldUID, - }) + _spec.SetField(uid.FieldUID, field.TypeUint32, value) } if value, ok := uuo.mutation.AddedUID(); ok { - _spec.Fields.Add = append(_spec.Fields.Add, &sqlgraph.FieldSpec{ - Type: field.TypeUint32, - Value: value, - Column: uid.FieldUID, - }) + _spec.AddField(uid.FieldUID, field.TypeUint32, value) } if value, ok := uuo.mutation.Deleted(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeBool, - Value: value, - Column: uid.FieldDeleted, - }) + _spec.SetField(uid.FieldDeleted, field.TypeBool, value) } if value, ok := uuo.mutation.Recent(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeBool, - Value: value, - Column: uid.FieldRecent, - }) + _spec.SetField(uid.FieldRecent, field.TypeBool, value) } if uuo.mutation.MessageCleared() { edge := &sqlgraph.EdgeSpec{ @@ -622,5 +519,6 @@ func (uuo *UIDUpdateOne) sqlSave(ctx context.Context) (_node *UID, err error) { } return nil, err } + uuo.mutation.done = true return _node, nil } diff --git a/internal/db/mailbox.go b/internal/db_impl/ent_db/mailbox.go similarity index 62% rename from internal/db/mailbox.go rename to internal/db_impl/ent_db/mailbox.go index ddc1a880..6355a7cf 100644 --- a/internal/db/mailbox.go +++ b/internal/db_impl/ent_db/mailbox.go @@ -1,4 +1,4 @@ -package db +package ent_db import ( "context" @@ -6,24 +6,25 @@ import ( "strings" "entgo.io/ent/dialect/sql" + "github.com/ProtonMail/gluon/db" "github.com/ProtonMail/gluon/imap" - "github.com/ProtonMail/gluon/internal/db/ent" - "github.com/ProtonMail/gluon/internal/db/ent/mailbox" - "github.com/ProtonMail/gluon/internal/db/ent/message" - "github.com/ProtonMail/gluon/internal/db/ent/messageflag" - "github.com/ProtonMail/gluon/internal/db/ent/uid" + "github.com/ProtonMail/gluon/internal/db_impl/ent_db/internal" + "github.com/ProtonMail/gluon/internal/db_impl/ent_db/internal/mailbox" + "github.com/ProtonMail/gluon/internal/db_impl/ent_db/internal/message" + "github.com/ProtonMail/gluon/internal/db_impl/ent_db/internal/messageflag" + "github.com/ProtonMail/gluon/internal/db_impl/ent_db/internal/uid" "github.com/ProtonMail/gluon/internal/ids" "github.com/bradenaw/juniper/xslices" ) func CreateMailbox( ctx context.Context, - tx *ent.Tx, + tx *internal.Tx, mboxID imap.MailboxID, name string, flags, permFlags, attrs imap.FlagSet, uidValidity imap.UID, -) (*ent.Mailbox, error) { +) (*internal.Mailbox, error) { create := tx.Mailbox.Create(). SetName(name). SetUIDValidity(uidValidity) @@ -73,15 +74,15 @@ func CreateMailbox( return mbox, nil } -func MailboxExistsWithID(ctx context.Context, client *ent.Client, mboxID imap.InternalMailboxID) (bool, error) { +func MailboxExistsWithID(ctx context.Context, client *internal.Client, mboxID imap.InternalMailboxID) (bool, error) { return client.Mailbox.Query().Where(mailbox.ID(mboxID)).Exist(ctx) } -func MailboxExistsWithRemoteID(ctx context.Context, client *ent.Client, mboxID imap.MailboxID) (bool, error) { +func MailboxExistsWithRemoteID(ctx context.Context, client *internal.Client, mboxID imap.MailboxID) (bool, error) { return client.Mailbox.Query().Where(mailbox.RemoteID(mboxID)).Exist(ctx) } -func GetMailboxIDFromRemoteID(ctx context.Context, client *ent.Client, mboxID imap.MailboxID) (imap.InternalMailboxID, error) { +func GetMailboxIDFromRemoteID(ctx context.Context, client *internal.Client, mboxID imap.MailboxID) (imap.InternalMailboxID, error) { mbox, err := client.Mailbox.Query().Where(mailbox.RemoteID(mboxID)).Select(mailbox.FieldID).Only(ctx) if err != nil { return 0, err @@ -90,11 +91,11 @@ func GetMailboxIDFromRemoteID(ctx context.Context, client *ent.Client, mboxID im return mbox.ID, nil } -func MailboxExistsWithName(ctx context.Context, client *ent.Client, name string) (bool, error) { +func MailboxExistsWithName(ctx context.Context, client *internal.Client, name string) (bool, error) { return client.Mailbox.Query().Where(mailbox.Name(name)).Exist(ctx) } -func RenameMailboxWithRemoteID(ctx context.Context, tx *ent.Tx, mboxID imap.MailboxID, name string) error { +func RenameMailboxWithRemoteID(ctx context.Context, tx *internal.Tx, mboxID imap.MailboxID, name string) error { if _, err := tx.Mailbox.Update(). Where(mailbox.RemoteID(mboxID)). SetName(name). @@ -109,7 +110,7 @@ func RenameMailboxWithRemoteID(ctx context.Context, tx *ent.Tx, mboxID imap.Mail // It returns the (potentially new) global UID validity. func DeleteMailboxWithRemoteID( ctx context.Context, - tx *ent.Tx, + tx *internal.Tx, mboxID imap.MailboxID, ) error { mbox, err := tx.Mailbox.Query().Where(mailbox.RemoteID(mboxID)).Select(mailbox.FieldSubscribed, mailbox.FieldName).Only(ctx) @@ -130,7 +131,7 @@ func DeleteMailboxWithRemoteID( return nil } -func UpdateRemoteMailboxID(ctx context.Context, tx *ent.Tx, internalID imap.InternalMailboxID, remoteID imap.MailboxID) error { +func UpdateRemoteMailboxID(ctx context.Context, tx *internal.Tx, internalID imap.InternalMailboxID, remoteID imap.MailboxID) error { if _, err := tx.Mailbox.Update(). Where(mailbox.ID(internalID)). SetRemoteID(remoteID). @@ -141,7 +142,7 @@ func UpdateRemoteMailboxID(ctx context.Context, tx *ent.Tx, internalID imap.Inte return nil } -func BumpMailboxUIDNext(ctx context.Context, tx *ent.Tx, mbox *ent.Mailbox, withCount ...int) error { +func BumpMailboxUIDNext(ctx context.Context, tx *internal.Tx, mbox *internal.Mailbox, withCount ...int) error { var n int if len(withCount) > 0 { @@ -159,7 +160,7 @@ func BumpMailboxUIDNext(ctx context.Context, tx *ent.Tx, mbox *ent.Mailbox, with return nil } -func GetMailboxName(ctx context.Context, client *ent.Client, mboxID imap.InternalMailboxID) (string, error) { +func GetMailboxName(ctx context.Context, client *internal.Client, mboxID imap.InternalMailboxID) (string, error) { mailbox, err := client.Mailbox.Query().Where(mailbox.ID(mboxID)).Select(mailbox.FieldName).Only(ctx) if err != nil { return "", err @@ -168,7 +169,7 @@ func GetMailboxName(ctx context.Context, client *ent.Client, mboxID imap.Interna return mailbox.Name, nil } -func GetMailboxNameWithRemoteID(ctx context.Context, client *ent.Client, mboxID imap.MailboxID) (string, error) { +func GetMailboxNameWithRemoteID(ctx context.Context, client *internal.Client, mboxID imap.MailboxID) (string, error) { mailbox, err := client.Mailbox.Query().Where(mailbox.RemoteID(mboxID)).Select(mailbox.FieldName).Only(ctx) if err != nil { return "", err @@ -177,7 +178,7 @@ func GetMailboxNameWithRemoteID(ctx context.Context, client *ent.Client, mboxID return mailbox.Name, nil } -func GetMailboxMessageIDPairs(ctx context.Context, client *ent.Client, mailboxID imap.InternalMailboxID) ([]ids.MessageIDPair, error) { +func GetMailboxMessageIDPairs(ctx context.Context, client *internal.Client, mailboxID imap.InternalMailboxID) ([]db.MessageIDPair, error) { messages, err := client.Message.Query(). Where(message.HasUIDsWith(uid.HasMailboxWith(mailbox.ID(mailboxID)))). Select(message.FieldID, message.FieldRemoteID). @@ -186,15 +187,18 @@ func GetMailboxMessageIDPairs(ctx context.Context, client *ent.Client, mailboxID return nil, err } - return xslices.Map(messages, func(message *ent.Message) ids.MessageIDPair { - return ids.NewMessageIDPair(message) + return xslices.Map(messages, func(message *internal.Message) db.MessageIDPair { + return db.MessageIDPair{ + InternalID: message.ID, + RemoteID: message.RemoteID, + } }), nil } -func GetAllMailboxes(ctx context.Context, client *ent.Client) ([]*ent.Mailbox, error) { +func GetAllMailboxes(ctx context.Context, client *internal.Client) ([]*internal.Mailbox, error) { const QueryLimit = 16000 - var mailboxes []*ent.Mailbox + var mailboxes []*internal.Mailbox for i := 0; ; i += QueryLimit { result, err := client.Mailbox.Query(). @@ -217,60 +221,30 @@ func GetAllMailboxes(ctx context.Context, client *ent.Client) ([]*ent.Mailbox, e return mailboxes, nil } -func GetMailboxByName(ctx context.Context, client *ent.Client, name string) (*ent.Mailbox, error) { +func GetMailboxByName(ctx context.Context, client *internal.Client, name string) (*internal.Mailbox, error) { return client.Mailbox.Query().Where(mailbox.Name(name)).Only(ctx) } -func GetMailboxByID(ctx context.Context, client *ent.Client, id imap.InternalMailboxID) (*ent.Mailbox, error) { +func GetMailboxByID(ctx context.Context, client *internal.Client, id imap.InternalMailboxID) (*internal.Mailbox, error) { return client.Mailbox.Query().Where(mailbox.ID(id)).Only(ctx) } -func GetMailboxByRemoteID(ctx context.Context, client *ent.Client, id imap.MailboxID) (*ent.Mailbox, error) { +func GetMailboxByRemoteID(ctx context.Context, client *internal.Client, id imap.MailboxID) (*internal.Mailbox, error) { return client.Mailbox.Query().Where(mailbox.RemoteID(id)).Only(ctx) } -func GetMailboxRecentCount(ctx context.Context, client *ent.Client, mbox *ent.Mailbox) (int, error) { +func GetMailboxRecentCount(ctx context.Context, client *internal.Client, mbox *internal.Mailbox) (int, error) { return mbox.QueryUIDs().Where(uid.Recent(true)).Count(ctx) } -func GetMailboxMessageCount(ctx context.Context, client *ent.Client, mboxID imap.InternalMailboxID) (int, error) { +func GetMailboxMessageCount(ctx context.Context, client *internal.Client, mboxID imap.InternalMailboxID) (int, error) { return client.UID.Query().Where(func(s *sql.Selector) { s.Where(sql.EQ(uid.MailboxColumn, mboxID)) }).Count(ctx) } -type SnapshotMessageResult struct { - InternalID imap.InternalMessageID `json:"uid_message"` - RemoteID imap.MessageID `json:"remote_id"` - UID imap.UID `json:"uid"` - Recent bool `json:"recent"` - Deleted bool `json:"deleted"` - Flags string `json:"flags"` -} - -func (msg *SnapshotMessageResult) GetFlagSet() imap.FlagSet { - var flagSet imap.FlagSet - - if len(msg.Flags) > 0 { - flags := strings.Split(msg.Flags, ",") - flagSet = imap.NewFlagSetFromSlice(flags) - } else { - flagSet = imap.NewFlagSet() - } - - if msg.Deleted { - flagSet.AddToSelf(imap.FlagDeleted) - } - - if msg.Recent { - flagSet.AddToSelf(imap.FlagRecent) - } - - return flagSet -} - -func GetMailboxMessagesForNewSnapshot(ctx context.Context, client *ent.Client, mboxID imap.InternalMailboxID) ([]SnapshotMessageResult, error) { - messages := make([]SnapshotMessageResult, 0, 32) +func GetMailboxMessagesForNewSnapshot(ctx context.Context, client *internal.Client, mboxID imap.InternalMailboxID) ([]db.SnapshotMessageResult, error) { + messages := make([]db.SnapshotMessageResult, 0, 32) if err := client.UID.Query().Where(func(s *sql.Selector) { msgTable := sql.Table(message.Table) @@ -288,7 +262,7 @@ func GetMailboxMessagesForNewSnapshot(ctx context.Context, client *ent.Client, m return messages, nil } -func GetMailboxIDWithRemoteID(ctx context.Context, client *ent.Client, mboxID imap.MailboxID) (imap.InternalMailboxID, error) { +func GetMailboxIDWithRemoteID(ctx context.Context, client *internal.Client, mboxID imap.MailboxID) (imap.InternalMailboxID, error) { mbox, err := client.Mailbox.Query().Where(mailbox.RemoteID(mboxID)).Select(mailbox.FieldID).Only(ctx) if err != nil { return 0, err @@ -297,18 +271,18 @@ func GetMailboxIDWithRemoteID(ctx context.Context, client *ent.Client, mboxID im return mbox.ID, nil } -func TranslateRemoteMailboxIDs(ctx context.Context, client *ent.Client, mboxIDs []imap.MailboxID) ([]imap.InternalMailboxID, error) { +func TranslateRemoteMailboxIDs(ctx context.Context, client *internal.Client, mboxIDs []imap.MailboxID) ([]imap.InternalMailboxID, error) { mboxes, err := client.Mailbox.Query().Where(mailbox.RemoteIDIn(mboxIDs...)).Select(mailbox.FieldID).All(ctx) if err != nil { return nil, err } - return xslices.Map(mboxes, func(m *ent.Mailbox) imap.InternalMailboxID { + return xslices.Map(mboxes, func(m *internal.Mailbox) imap.InternalMailboxID { return m.ID }), nil } -func CreateMailboxIfNotExists(ctx context.Context, tx *ent.Tx, mbox imap.Mailbox, delimiter string, uidValidity imap.UID) error { +func CreateMailboxIfNotExists(ctx context.Context, tx *internal.Tx, mbox imap.Mailbox, delimiter string, uidValidity imap.UID) error { exists, err := MailboxExistsWithRemoteID(ctx, tx.Client(), mbox.ID) if err != nil { return err @@ -332,10 +306,10 @@ func CreateMailboxIfNotExists(ctx context.Context, tx *ent.Tx, mbox imap.Mailbox return nil } -func GetOrCreateMailbox(ctx context.Context, tx *ent.Tx, mbox imap.Mailbox, delimiter string, uidValidity imap.UID) (*ent.Mailbox, error) { +func GetOrCreateMailbox(ctx context.Context, tx *internal.Tx, mbox imap.Mailbox, delimiter string, uidValidity imap.UID) (*internal.Mailbox, error) { mailbox, err := tx.Mailbox.Query().Where(mailbox.RemoteID(mbox.ID)).Only(ctx) if err != nil { - if !ent.IsNotFound(err) { + if !internal.IsNotFound(err) { return nil, err } } else { @@ -354,7 +328,7 @@ func GetOrCreateMailbox(ctx context.Context, tx *ent.Tx, mbox imap.Mailbox, deli ) } -func FilterMailboxContains(ctx context.Context, client *ent.Client, mboxID imap.InternalMailboxID, messageIDs []ids.MessageIDPair) ([]imap.InternalMessageID, error) { +func FilterMailboxContains(ctx context.Context, client *internal.Client, mboxID imap.InternalMailboxID, messageIDs []db.MessageIDPair) ([]imap.InternalMessageID, error) { type result struct { InternalID imap.InternalMessageID `json:"uid_message"` } @@ -362,7 +336,7 @@ func FilterMailboxContains(ctx context.Context, client *ent.Client, mboxID imap. var r []result if err := client.UID.Query().Where(func(s *sql.Selector) { - s.Where(sql.And(sql.EQ(uid.MailboxColumn, mboxID), sql.In(uid.MessageColumn, xslices.Map(messageIDs, func(id ids.MessageIDPair) interface{} { + s.Where(sql.And(sql.EQ(uid.MailboxColumn, mboxID), sql.In(uid.MessageColumn, xslices.Map(messageIDs, func(id db.MessageIDPair) interface{} { return id.InternalID })...))) s.Select(uid.MessageColumn) @@ -375,7 +349,7 @@ func FilterMailboxContains(ctx context.Context, client *ent.Client, mboxID imap. }), nil } -func FilterMailboxContainsInternalID(ctx context.Context, client *ent.Client, mboxID imap.InternalMailboxID, messageIDs []imap.InternalMessageID) ([]imap.InternalMessageID, error) { +func FilterMailboxContainsInternalID(ctx context.Context, client *internal.Client, mboxID imap.InternalMailboxID, messageIDs []imap.InternalMessageID) ([]imap.InternalMessageID, error) { type result struct { InternalID imap.InternalMessageID `json:"uid_message"` } @@ -396,51 +370,51 @@ func FilterMailboxContainsInternalID(ctx context.Context, client *ent.Client, mb }), nil } -func GetMailboxFlags(ctx context.Context, client *ent.Client, mboxID imap.InternalMailboxID) (imap.FlagSet, error) { +func GetMailboxFlags(ctx context.Context, client *internal.Client, mboxID imap.InternalMailboxID) (imap.FlagSet, error) { mbox, err := client.Mailbox.Query().Where(mailbox.ID(mboxID)).WithFlags().Only(ctx) if err != nil { return imap.FlagSet{}, err } - return imap.NewFlagSetFromSlice(xslices.Map(mbox.Edges.Flags, func(flag *ent.MailboxFlag) string { + return imap.NewFlagSetFromSlice(xslices.Map(mbox.Edges.Flags, func(flag *internal.MailboxFlag) string { return flag.Value })), nil } -func GetMailboxPermanentFlags(ctx context.Context, client *ent.Client, mboxID imap.InternalMailboxID) (imap.FlagSet, error) { +func GetMailboxPermanentFlags(ctx context.Context, client *internal.Client, mboxID imap.InternalMailboxID) (imap.FlagSet, error) { mbox, err := client.Mailbox.Query().Where(mailbox.ID(mboxID)).WithPermanentFlags().Only(ctx) if err != nil { return imap.FlagSet{}, err } - return imap.NewFlagSetFromSlice(xslices.Map(mbox.Edges.PermanentFlags, func(flag *ent.MailboxPermFlag) string { + return imap.NewFlagSetFromSlice(xslices.Map(mbox.Edges.PermanentFlags, func(flag *internal.MailboxPermFlag) string { return flag.Value })), nil } -func GetMailboxAttributes(ctx context.Context, client *ent.Client, mboxID imap.InternalMailboxID) (imap.FlagSet, error) { +func GetMailboxAttributes(ctx context.Context, client *internal.Client, mboxID imap.InternalMailboxID) (imap.FlagSet, error) { mbox, err := client.Mailbox.Query().Where(mailbox.ID(mboxID)).WithAttributes().Only(ctx) if err != nil { return imap.FlagSet{}, err } - return imap.NewFlagSetFromSlice(xslices.Map(mbox.Edges.Attributes, func(flag *ent.MailboxAttr) string { + return imap.NewFlagSetFromSlice(xslices.Map(mbox.Edges.Attributes, func(flag *internal.MailboxAttr) string { return flag.Value })), nil } -func IsMessageInMailbox(ctx context.Context, client *ent.Client, mboxID imap.InternalMailboxID, messageID imap.InternalMailboxID) (bool, error) { +func IsMessageInMailbox(ctx context.Context, client *internal.Client, mboxID imap.InternalMailboxID, messageID imap.InternalMailboxID) (bool, error) { return client.UID.Query().Where(func(s *sql.Selector) { s.Where(sql.And(sql.EQ(uid.MailboxColumn, mboxID), sql.EQ(uid.MessageColumn, messageID))) s.Select(uid.MessageColumn) }).Exist(ctx) } -func GetMailboxCount(ctx context.Context, client *ent.Client) (int, error) { +func GetMailboxCount(ctx context.Context, client *internal.Client) (int, error) { return client.Mailbox.Query().Count(ctx) } -func GetMailboxUID(ctx context.Context, client *ent.Client, mboxID imap.InternalMailboxID) (imap.UID, error) { +func GetMailboxUID(ctx context.Context, client *internal.Client, mboxID imap.InternalMailboxID) (imap.UID, error) { mbox, err := client.Mailbox.Query().Where(mailbox.ID(mboxID)).Select(mailbox.FieldUIDNext).Only(ctx) if err != nil { return 0, err @@ -449,7 +423,7 @@ func GetMailboxUID(ctx context.Context, client *ent.Client, mboxID imap.Internal return mbox.UIDNext, err } -func GetMailboxMessageCountAndUID(ctx context.Context, client *ent.Client, mboxID imap.InternalMailboxID) (int, imap.UID, error) { +func GetMailboxMessageCountAndUID(ctx context.Context, client *internal.Client, mboxID imap.InternalMailboxID) (int, imap.UID, error) { messageCount, err := GetMailboxMessageCount(ctx, client, mboxID) if err != nil { return 0, 0, err diff --git a/internal/db/message.go b/internal/db_impl/ent_db/message.go similarity index 68% rename from internal/db/message.go rename to internal/db_impl/ent_db/message.go index 38163e10..e2d34505 100644 --- a/internal/db/message.go +++ b/internal/db_impl/ent_db/message.go @@ -1,38 +1,28 @@ -package db +package ent_db import ( "context" "fmt" - "strings" "entgo.io/ent/dialect/sql" + "github.com/ProtonMail/gluon/db" "github.com/ProtonMail/gluon/imap" - "github.com/ProtonMail/gluon/internal/db/ent" - "github.com/ProtonMail/gluon/internal/db/ent/mailbox" - "github.com/ProtonMail/gluon/internal/db/ent/message" - "github.com/ProtonMail/gluon/internal/db/ent/messageflag" - "github.com/ProtonMail/gluon/internal/db/ent/uid" - "github.com/ProtonMail/gluon/internal/ids" + "github.com/ProtonMail/gluon/internal/db_impl/ent_db/internal" + "github.com/ProtonMail/gluon/internal/db_impl/ent_db/internal/mailbox" + "github.com/ProtonMail/gluon/internal/db_impl/ent_db/internal/message" + "github.com/ProtonMail/gluon/internal/db_impl/ent_db/internal/messageflag" + "github.com/ProtonMail/gluon/internal/db_impl/ent_db/internal/uid" "github.com/bradenaw/juniper/xslices" "golang.org/x/exp/slices" ) -const ChunkLimit = 1000 +const ChunkLimit = db.ChunkLimit -type CreateMessageReq struct { - Message imap.Message - InternalID imap.InternalMessageID - LiteralSize int - Body string - Structure string - Envelope string -} - -func CreateMessages(ctx context.Context, tx *ent.Tx, reqs ...*CreateMessageReq) ([]*ent.Message, error) { - flags := make(map[imap.InternalMessageID][]*ent.MessageFlag) +func CreateMessages(ctx context.Context, tx *internal.Tx, reqs ...*db.CreateMessageReq) ([]*internal.Message, error) { + flags := make(map[imap.InternalMessageID][]*internal.MessageFlag) for _, req := range reqs { - builders := xslices.Map(req.Message.Flags.ToSlice(), func(flag string) *ent.MessageFlagCreate { + builders := xslices.Map(req.Message.Flags.ToSlice(), func(flag string) *internal.MessageFlagCreate { return tx.MessageFlag.Create().SetValue(flag) }) @@ -44,7 +34,7 @@ func CreateMessages(ctx context.Context, tx *ent.Tx, reqs ...*CreateMessageReq) flags[req.InternalID] = entFlags } - builders := xslices.Map(reqs, func(req *CreateMessageReq) *ent.MessageCreate { + builders := xslices.Map(reqs, func(req *db.CreateMessageReq) *internal.MessageCreate { msgCreate := tx.Message.Create(). SetID(req.InternalID). SetDate(req.Message.Date). @@ -61,7 +51,7 @@ func CreateMessages(ctx context.Context, tx *ent.Tx, reqs ...*CreateMessageReq) return msgCreate }) - messages := make([]*ent.Message, 0, len(builders)) + messages := make([]*internal.Message, 0, len(builders)) // Avoid too many SQL variables error. for _, chunk := range xslices.Chunk(builders, ChunkLimit) { @@ -76,7 +66,7 @@ func CreateMessages(ctx context.Context, tx *ent.Tx, reqs ...*CreateMessageReq) return messages, nil } -func AddMessagesToMailbox(ctx context.Context, tx *ent.Tx, messageIDs []imap.InternalMessageID, mboxID imap.InternalMailboxID) ([]UIDWithFlags, error) { +func AddMessagesToMailbox(ctx context.Context, tx *internal.Tx, messageIDs []imap.InternalMessageID, mboxID imap.InternalMailboxID) ([]db.UIDWithFlags, error) { if len(messageIDs) == 0 { return nil, nil } @@ -88,7 +78,7 @@ func AddMessagesToMailbox(ctx context.Context, tx *ent.Tx, messageIDs []imap.Int return nil, err } - var builders []*ent.UIDCreate + var builders []*internal.UIDCreate for idx, messageID := range messageIDs { nextUID := mbox.UIDNext.Add(uint32(idx)) @@ -117,16 +107,16 @@ func AddMessagesToMailbox(ctx context.Context, tx *ent.Tx, messageIDs []imap.Int return GetMessageUIDsWithFlagsAfterAddOrUIDBump(ctx, tx.Client(), mboxID, messageIDs) } -func CreateAndAddMessageToMailbox(ctx context.Context, tx *ent.Tx, mboxID imap.InternalMailboxID, req *CreateMessageReq) (imap.UID, imap.FlagSet, error) { +func CreateAndAddMessageToMailbox(ctx context.Context, tx *internal.Tx, mboxID imap.InternalMailboxID, req *db.CreateMessageReq) (imap.UID, imap.FlagSet, error) { mbox, err := tx.Mailbox.Query().Where(mailbox.ID(mboxID)).Select(mailbox.FieldID, mailbox.FieldUIDNext).Only(ctx) if err != nil { return 0, imap.FlagSet{}, err } - var flags []*ent.MessageFlag + var flags []*internal.MessageFlag { - builders := xslices.Map(req.Message.Flags.ToSlice(), func(flag string) *ent.MessageFlagCreate { + builders := xslices.Map(req.Message.Flags.ToSlice(), func(flag string) *internal.MessageFlagCreate { return tx.MessageFlag.Create().SetValue(flag) }) @@ -172,13 +162,7 @@ func CreateAndAddMessageToMailbox(ctx context.Context, tx *ent.Tx, mboxID imap.I return uid.UID, NewFlagSet(uid, flags), err } -type CreateAndAddMessagesResult struct { - UID imap.UID - Flags imap.FlagSet - MessageID ids.MessageIDPair -} - -func BumpMailboxUIDsForMessage(ctx context.Context, tx *ent.Tx, messageIDs []imap.InternalMessageID, mboxID imap.InternalMailboxID) ([]UIDWithFlags, error) { +func BumpMailboxUIDsForMessage(ctx context.Context, tx *internal.Tx, messageIDs []imap.InternalMessageID, mboxID imap.InternalMailboxID) ([]db.UIDWithFlags, error) { messageUIDs := make(map[imap.InternalMessageID]imap.UID) mbox, err := tx.Mailbox.Query().Where(mailbox.ID(mboxID)).Only(ctx) @@ -186,7 +170,7 @@ func BumpMailboxUIDsForMessage(ctx context.Context, tx *ent.Tx, messageIDs []ima return nil, err } - var builders []*ent.UIDUpdate + var builders []*internal.UIDUpdate for idx, messageID := range messageIDs { uidNext := mbox.UIDNext.Add(uint32(idx)) @@ -211,7 +195,7 @@ func BumpMailboxUIDsForMessage(ctx context.Context, tx *ent.Tx, messageIDs []ima return GetMessageUIDsWithFlagsAfterAddOrUIDBump(ctx, tx.Client(), mboxID, messageIDs) } -func RemoveMessagesFromMailbox(ctx context.Context, tx *ent.Tx, messageIDs []imap.InternalMessageID, mboxID imap.InternalMailboxID) error { +func RemoveMessagesFromMailbox(ctx context.Context, tx *internal.Tx, messageIDs []imap.InternalMessageID, mboxID imap.InternalMailboxID) error { // Avoid too many SQL variables error. for _, chunk := range xslices.Chunk(messageIDs, ChunkLimit) { if _, err := tx.UID.Delete(). @@ -228,7 +212,7 @@ func RemoveMessagesFromMailbox(ctx context.Context, tx *ent.Tx, messageIDs []ima return nil } -func MessageExistsWithRemoteID(ctx context.Context, client *ent.Client, messageID imap.MessageID) (bool, error) { +func MessageExistsWithRemoteID(ctx context.Context, client *internal.Client, messageID imap.MessageID) (bool, error) { count, err := client.Message.Query().Where(message.RemoteID(messageID)).Count(ctx) if err != nil { return false, err @@ -237,19 +221,19 @@ func MessageExistsWithRemoteID(ctx context.Context, client *ent.Client, messageI return count > 0, nil } -func GetMessage(ctx context.Context, client *ent.Client, messageID imap.InternalMessageID) (*ent.Message, error) { +func GetMessage(ctx context.Context, client *internal.Client, messageID imap.InternalMessageID) (*internal.Message, error) { return client.Message.Query().Where(message.ID(messageID)).Only(ctx) } -func GetImportedMessageData(ctx context.Context, client *ent.Client, messageID imap.InternalMessageID) (*ent.Message, error) { +func GetImportedMessageData(ctx context.Context, client *internal.Client, messageID imap.InternalMessageID) (*internal.Message, error) { return client.Message.Query().Where(message.ID(messageID)).WithFlags().Select(message.FieldDate).Only(ctx) } -func GetMessageDateAndSize(ctx context.Context, client *ent.Client, messageID imap.InternalMessageID) (*ent.Message, error) { +func GetMessageDateAndSize(ctx context.Context, client *internal.Client, messageID imap.InternalMessageID) (*internal.Message, error) { return client.Message.Query().Where(message.ID(messageID)).Select(message.FieldSize, message.FieldDate).Only(ctx) } -func GetMessageMailboxIDs(ctx context.Context, client *ent.Client, messageID imap.InternalMessageID) ([]imap.InternalMailboxID, error) { +func GetMessageMailboxIDs(ctx context.Context, client *internal.Client, messageID imap.InternalMessageID) ([]imap.InternalMailboxID, error) { type tmp struct { MBoxID imap.InternalMailboxID `json:"mailbox_ui_ds"` } @@ -267,40 +251,10 @@ func GetMessageMailboxIDs(ctx context.Context, client *ent.Client, messageID ima }), nil } -type UIDWithFlags struct { - InternalID imap.InternalMessageID `json:"uid_message"` - RemoteID imap.MessageID `json:"remote_id"` - UID imap.UID `json:"uid"` - Recent bool `json:"recent"` - Deleted bool `json:"deleted"` - Flags string `json:"flags"` -} - -func (u *UIDWithFlags) GetFlagSet() imap.FlagSet { - var flagSet imap.FlagSet - - if len(u.Flags) > 0 { - flags := strings.Split(u.Flags, ",") - flagSet = imap.NewFlagSetFromSlice(flags) - } else { - flagSet = imap.NewFlagSet() - } - - if u.Deleted { - flagSet.AddToSelf(imap.FlagDeleted) - } - - if u.Recent { - flagSet.AddToSelf(imap.FlagRecent) - } - - return flagSet -} - // GetMessageUIDsWithFlagsAfterAddOrUIDBump exploits a property of adding a message to or bumping the UIDs of existing message in mailbox. It can only be // used if you can guarantee that the messageID list contains only IDs that have recently added or bumped in the mailbox. -func GetMessageUIDsWithFlagsAfterAddOrUIDBump(ctx context.Context, client *ent.Client, mboxID imap.InternalMailboxID, messageIDs []imap.InternalMessageID) ([]UIDWithFlags, error) { - result := make([]UIDWithFlags, 0, len(messageIDs)) +func GetMessageUIDsWithFlagsAfterAddOrUIDBump(ctx context.Context, client *internal.Client, mboxID imap.InternalMailboxID, messageIDs []imap.InternalMessageID) ([]db.UIDWithFlags, error) { + result := make([]db.UIDWithFlags, 0, len(messageIDs)) // Hav to split this in chunks as this can trigger too many SQL Variables. for _, chunk := range xslices.Chunk(messageIDs, ChunkLimit) { @@ -320,7 +274,7 @@ func GetMessageUIDsWithFlagsAfterAddOrUIDBump(ctx context.Context, client *ent.C } } - slices.SortFunc(result, func(v1 UIDWithFlags, v2 UIDWithFlags) bool { + slices.SortFunc(result, func(v1 db.UIDWithFlags, v2 db.UIDWithFlags) bool { return v1.UID < v2.UID }) @@ -331,21 +285,15 @@ func GetMessageUIDsWithFlagsAfterAddOrUIDBump(ctx context.Context, client *ent.C return result, nil } -type MessageFlagSet struct { - ID imap.InternalMessageID - RemoteID imap.MessageID - FlagSet imap.FlagSet -} - // GetMessageFlags returns the flags of the given messages. // It does not include per-mailbox flags (\Deleted, \Recent)! -func GetMessageFlags(ctx context.Context, client *ent.Client, messageIDs []imap.InternalMessageID) ([]MessageFlagSet, error) { - result := make([]MessageFlagSet, 0, len(messageIDs)) +func GetMessageFlags(ctx context.Context, client *internal.Client, messageIDs []imap.InternalMessageID) ([]db.MessageFlagSet, error) { + result := make([]db.MessageFlagSet, 0, len(messageIDs)) for _, chunk := range xslices.Chunk(messageIDs, ChunkLimit) { chunkMessages, err := client.Message.Query(). Where(message.IDIn(chunk...)). - WithFlags(func(query *ent.MessageFlagQuery) { + WithFlags(func(query *internal.MessageFlagQuery) { query.Select(messageflag.FieldValue) }). Select(message.FieldID, message.FieldRemoteID). @@ -355,7 +303,7 @@ func GetMessageFlags(ctx context.Context, client *ent.Client, messageIDs []imap. } for _, msg := range chunkMessages { - mfs := MessageFlagSet{ + mfs := db.MessageFlagSet{ ID: msg.ID, RemoteID: msg.RemoteID, FlagSet: imap.NewFlagSetWithCapacity(len(msg.Edges.Flags)), @@ -372,7 +320,7 @@ func GetMessageFlags(ctx context.Context, client *ent.Client, messageIDs []imap. return result, nil } -func GetMessageDeleted(ctx context.Context, client *ent.Client, mboxID imap.InternalMailboxID, messageIDs []imap.InternalMessageID) (map[imap.InternalMessageID]bool, error) { +func GetMessageDeleted(ctx context.Context, client *internal.Client, mboxID imap.InternalMailboxID, messageIDs []imap.InternalMessageID) (map[imap.InternalMessageID]bool, error) { type tmp struct { MsgID imap.InternalMessageID `json:"uid_message"` Deleted bool `json:"deleted"` @@ -402,9 +350,9 @@ func GetMessageDeleted(ctx context.Context, client *ent.Client, mboxID imap.Inte return res, nil } -func AddMessageFlag(ctx context.Context, tx *ent.Tx, messageIDs []imap.InternalMessageID, addFlag string) error { +func AddMessageFlag(ctx context.Context, tx *internal.Tx, messageIDs []imap.InternalMessageID, addFlag string) error { for _, chunk := range xslices.Chunk(messageIDs, ChunkLimit) { - builders := xslices.Map(chunk, func(imap.InternalMessageID) *ent.MessageFlagCreate { + builders := xslices.Map(chunk, func(imap.InternalMessageID) *internal.MessageFlagCreate { return tx.MessageFlag.Create().SetValue(addFlag) }) @@ -423,7 +371,7 @@ func AddMessageFlag(ctx context.Context, tx *ent.Tx, messageIDs []imap.InternalM return nil } -func RemoveMessageFlag(ctx context.Context, tx *ent.Tx, messageIDs []imap.InternalMessageID, remFlag string) error { +func RemoveMessageFlag(ctx context.Context, tx *internal.Tx, messageIDs []imap.InternalMessageID, remFlag string) error { remFlagSet := imap.NewFlagSet(remFlag) for _, chunk := range xslices.Chunk(messageIDs, ChunkLimit) { @@ -436,8 +384,8 @@ func RemoveMessageFlag(ctx context.Context, tx *ent.Tx, messageIDs []imap.Intern return err } - flags := xslices.Map(messages, func(message *ent.Message) *ent.MessageFlag { - return message.Edges.Flags[xslices.IndexFunc(message.Edges.Flags, func(flag *ent.MessageFlag) bool { + flags := xslices.Map(messages, func(message *internal.Message) *internal.MessageFlag { + return message.Edges.Flags[xslices.IndexFunc(message.Edges.Flags, func(flag *internal.MessageFlag) bool { return remFlagSet.Contains(flag.Value) })] }) @@ -452,7 +400,7 @@ func RemoveMessageFlag(ctx context.Context, tx *ent.Tx, messageIDs []imap.Intern return nil } -func SetMessageFlags(ctx context.Context, tx *ent.Tx, messageIDs []imap.InternalMessageID, setFlags imap.FlagSet) error { +func SetMessageFlags(ctx context.Context, tx *internal.Tx, messageIDs []imap.InternalMessageID, setFlags imap.FlagSet) error { for _, chunk := range xslices.Chunk(messageIDs, ChunkLimit) { messages, err := tx.Message.Query(). Where(message.IDIn(chunk...)). @@ -464,7 +412,7 @@ func SetMessageFlags(ctx context.Context, tx *ent.Tx, messageIDs []imap.Internal } for _, message := range messages { - curFlagSet := imap.NewFlagSetFromSlice(xslices.Map(message.Edges.Flags, func(flag *ent.MessageFlag) string { + curFlagSet := imap.NewFlagSetFromSlice(xslices.Map(message.Edges.Flags, func(flag *internal.MessageFlag) string { return flag.Value })) @@ -472,7 +420,7 @@ func SetMessageFlags(ctx context.Context, tx *ent.Tx, messageIDs []imap.Internal return !curFlagSet.Contains(flag) }) - builders := xslices.Map(addFlags, func(flag string) *ent.MessageFlagCreate { + builders := xslices.Map(addFlags, func(flag string) *internal.MessageFlagCreate { return tx.MessageFlag.Create().SetValue(flag) }) @@ -481,7 +429,7 @@ func SetMessageFlags(ctx context.Context, tx *ent.Tx, messageIDs []imap.Internal return err } - remEntFlags := xslices.Filter(message.Edges.Flags, func(flag *ent.MessageFlag) bool { + remEntFlags := xslices.Filter(message.Edges.Flags, func(flag *internal.MessageFlag) bool { return !setFlags.Contains(flag.Value) }) @@ -497,7 +445,7 @@ func SetMessageFlags(ctx context.Context, tx *ent.Tx, messageIDs []imap.Internal return nil } -func SetDeletedFlag(ctx context.Context, tx *ent.Tx, mboxID imap.InternalMailboxID, messageIDs []imap.InternalMessageID, deleted bool) error { +func SetDeletedFlag(ctx context.Context, tx *internal.Tx, mboxID imap.InternalMailboxID, messageIDs []imap.InternalMessageID, deleted bool) error { for _, chunk := range xslices.Chunk(messageIDs, ChunkLimit) { if _, err := tx.UID.Update(). Where( @@ -515,7 +463,7 @@ func SetDeletedFlag(ctx context.Context, tx *ent.Tx, mboxID imap.InternalMailbox return nil } -func ClearRecentFlag(ctx context.Context, tx *ent.Tx, mboxID imap.InternalMailboxID, messageID imap.InternalMessageID) error { +func ClearRecentFlag(ctx context.Context, tx *internal.Tx, mboxID imap.InternalMailboxID, messageID imap.InternalMessageID) error { if _, err := tx.UID.Update(). Where( func(s *sql.Selector) { @@ -529,7 +477,7 @@ func ClearRecentFlag(ctx context.Context, tx *ent.Tx, mboxID imap.InternalMailbo return nil } -func ClearRecentFlags(ctx context.Context, tx *ent.Tx, mboxID imap.InternalMailboxID) error { +func ClearRecentFlags(ctx context.Context, tx *internal.Tx, mboxID imap.InternalMailboxID) error { if _, err := tx.UID.Update(). Where( func(s *sql.Selector) { @@ -543,7 +491,7 @@ func ClearRecentFlags(ctx context.Context, tx *ent.Tx, mboxID imap.InternalMailb return nil } -func UpdateRemoteMessageID(ctx context.Context, tx *ent.Tx, internalID imap.InternalMessageID, remoteID imap.MessageID) error { +func UpdateRemoteMessageID(ctx context.Context, tx *internal.Tx, internalID imap.InternalMessageID, remoteID imap.MessageID) error { if _, err := tx.Message.Update(). Where(message.ID(internalID)). SetRemoteID(remoteID). @@ -554,7 +502,7 @@ func UpdateRemoteMessageID(ctx context.Context, tx *ent.Tx, internalID imap.Inte return nil } -func MarkMessageAsDeleted(ctx context.Context, tx *ent.Tx, messageID imap.InternalMessageID) error { +func MarkMessageAsDeleted(ctx context.Context, tx *internal.Tx, messageID imap.InternalMessageID) error { if _, err := tx.Message.Update().Where(message.ID(messageID)).SetDeleted(true).Save(ctx); err != nil { return err } @@ -562,7 +510,7 @@ func MarkMessageAsDeleted(ctx context.Context, tx *ent.Tx, messageID imap.Intern return nil } -func MarkMessageAsDeletedAndAssignRandomRemoteID(ctx context.Context, tx *ent.Tx, messageID imap.InternalMessageID) error { +func MarkMessageAsDeletedAndAssignRandomRemoteID(ctx context.Context, tx *internal.Tx, messageID imap.InternalMessageID) error { randomID := imap.MessageID(fmt.Sprintf("DELETED-%v", imap.NewInternalMessageID())) if _, err := tx.Message.Update().Where(message.ID(messageID)).SetDeleted(true).SetRemoteID(randomID).Save(ctx); err != nil { return err @@ -571,7 +519,7 @@ func MarkMessageAsDeletedAndAssignRandomRemoteID(ctx context.Context, tx *ent.Tx return nil } -func MarkMessageAsDeletedWithRemoteID(ctx context.Context, tx *ent.Tx, messageID imap.MessageID) error { +func MarkMessageAsDeletedWithRemoteID(ctx context.Context, tx *internal.Tx, messageID imap.MessageID) error { if _, err := tx.Message.Update().Where(message.RemoteID(messageID)).SetDeleted(true).Save(ctx); err != nil { return err } @@ -579,7 +527,7 @@ func MarkMessageAsDeletedWithRemoteID(ctx context.Context, tx *ent.Tx, messageID return nil } -func DeleteMessages(ctx context.Context, tx *ent.Tx, messageIDs ...imap.InternalMessageID) error { +func DeleteMessages(ctx context.Context, tx *internal.Tx, messageIDs ...imap.InternalMessageID) error { for _, chunk := range xslices.Chunk(messageIDs, ChunkLimit) { if _, err := tx.Message.Delete().Where(message.IDIn(chunk...)).Exec(ctx); err != nil { return err @@ -589,26 +537,40 @@ func DeleteMessages(ctx context.Context, tx *ent.Tx, messageIDs ...imap.Internal return nil } -func GetMessageIDsMarkedDeleted(ctx context.Context, client *ent.Client) ([]imap.InternalMessageID, error) { +func GetMessageIDsMarkedDeleted(ctx context.Context, client *internal.Client) ([]imap.InternalMessageID, error) { messages, err := client.Message.Query().Where(message.Deleted(true)).Select(message.FieldID).All(ctx) if err != nil { return nil, err } - return xslices.Map(messages, func(t *ent.Message) imap.InternalMessageID { + return xslices.Map(messages, func(t *internal.Message) imap.InternalMessageID { return t.ID }), nil } -func HasMessageWithID(ctx context.Context, client *ent.Client, id imap.InternalMessageID) (bool, error) { +func HasMessageWithID(ctx context.Context, client *internal.Client, id imap.InternalMessageID) (bool, error) { return client.Message.Query().Where(message.ID(id)).Exist(ctx) } -func HasMessageWithRemoteID(ctx context.Context, client *ent.Client, id imap.MessageID) (bool, error) { - return client.Message.Query().Where(message.RemoteID(id)).Exist(ctx) +func HasMessageWithRemoteID(ctx context.Context, client *internal.Client, id imap.MessageID) (bool, error) { + _, err := client.Message.Query().Where(message.RemoteID(id)).Select(message.FieldRemoteID).Only(ctx) + if err != nil { + if internal.IsNotFound(err) { + return false, nil + } + + return false, err + } + + // For whatever weird reason, this stopped working all together. No code changes were made, but reflection started + // failing. + // Now we get error="internal: check existence: sql/scan: missing struct field for column: id (id)". + //return client.Message.Query().Where(message.RemoteID(id)).Exist(ctx) + + return true, nil } -func GetMessageIDFromRemoteID(ctx context.Context, client *ent.Client, id imap.MessageID) (imap.InternalMessageID, error) { +func GetMessageIDFromRemoteID(ctx context.Context, client *internal.Client, id imap.MessageID) (imap.InternalMessageID, error) { message, err := client.Message.Query().Where(message.RemoteID(id)).Select(message.FieldID).Only(ctx) if err != nil { return imap.InternalMessageID{}, err @@ -617,7 +579,7 @@ func GetMessageIDFromRemoteID(ctx context.Context, client *ent.Client, id imap.M return message.ID, nil } -func GetMessageWithIDWithDeletedFlag(ctx context.Context, client *ent.Client, id imap.InternalMessageID) (*ent.Message, error) { +func GetMessageWithIDWithDeletedFlag(ctx context.Context, client *internal.Client, id imap.InternalMessageID) (*internal.Message, error) { message, err := client.Message.Query().Where(message.ID(id)).Select(message.FieldID, message.FieldDeleted).Only(ctx) if err != nil { return nil, err @@ -626,7 +588,7 @@ func GetMessageWithIDWithDeletedFlag(ctx context.Context, client *ent.Client, id return message, nil } -func GetMessageFromRemoteIDWithDeletedFlag(ctx context.Context, client *ent.Client, id imap.MessageID) (*ent.Message, error) { +func GetMessageFromRemoteIDWithDeletedFlag(ctx context.Context, client *internal.Client, id imap.MessageID) (*internal.Message, error) { message, err := client.Message.Query().Where(message.RemoteID(id)).Select(message.FieldID, message.FieldDeleted).Only(ctx) if err != nil { return nil, err @@ -635,7 +597,7 @@ func GetMessageFromRemoteIDWithDeletedFlag(ctx context.Context, client *ent.Clie return message, nil } -func GetMessageRemoteIDFromID(ctx context.Context, client *ent.Client, id imap.InternalMessageID) (imap.MessageID, error) { +func GetMessageRemoteIDFromID(ctx context.Context, client *internal.Client, id imap.InternalMessageID) (imap.MessageID, error) { message, err := client.Message.Query().Where(message.ID(id)).Select(message.FieldRemoteID).Only(ctx) if err != nil { return "", err @@ -644,8 +606,8 @@ func GetMessageRemoteIDFromID(ctx context.Context, client *ent.Client, id imap.I return message.RemoteID, nil } -func NewFlagSet(msgUID *ent.UID, flags []*ent.MessageFlag) imap.FlagSet { - flagSet := imap.NewFlagSetFromSlice(xslices.Map(flags, func(flag *ent.MessageFlag) string { +func NewFlagSet(msgUID *internal.UID, flags []*internal.MessageFlag) imap.FlagSet { + flagSet := imap.NewFlagSetFromSlice(xslices.Map(flags, func(flag *internal.MessageFlag) string { return flag.Value })) @@ -660,8 +622,8 @@ func NewFlagSet(msgUID *ent.UID, flags []*ent.MessageFlag) imap.FlagSet { return flagSet } -func GetHighestMessageID(ctx context.Context, client *ent.Client) (imap.InternalMessageID, error) { - message, err := client.Message.Query().Select(message.FieldID).Order(ent.Desc(message.FieldID)).Limit(1).All(ctx) +func GetHighestMessageID(ctx context.Context, client *internal.Client) (imap.InternalMessageID, error) { + message, err := client.Message.Query().Select(message.FieldID).Order(internal.Desc(message.FieldID)).Limit(1).All(ctx) if err != nil { return imap.InternalMessageID{}, err } @@ -673,7 +635,7 @@ func GetHighestMessageID(ctx context.Context, client *ent.Client) (imap.Internal return message[0].ID, nil } -func GetAllMessagesIDsAsMap(ctx context.Context, client *ent.Client) (map[imap.InternalMessageID]struct{}, error) { +func GetAllMessagesIDsAsMap(ctx context.Context, client *internal.Client) (map[imap.InternalMessageID]struct{}, error) { messages, err := client.Message.Query().Select(message.FieldID).All(ctx) if err != nil { return nil, err diff --git a/internal/db_impl/ent_db/ops_read.go b/internal/db_impl/ent_db/ops_read.go new file mode 100644 index 00000000..bfdb2d7e --- /dev/null +++ b/internal/db_impl/ent_db/ops_read.go @@ -0,0 +1,333 @@ +package ent_db + +import ( + "context" + "time" + + "github.com/ProtonMail/gluon/db" + "github.com/ProtonMail/gluon/imap" + "github.com/ProtonMail/gluon/internal/db_impl/ent_db/internal" + "github.com/bradenaw/juniper/xslices" +) + +type EntOpsRead struct { + client *internal.Client +} + +func newOpsReadFromClient(client *internal.Client) *EntOpsRead { + return &EntOpsRead{ + client: client, + } +} + +func newOpsReadFromTx(tx *internal.Tx) *EntOpsRead { + return &EntOpsRead{ + client: tx.Client(), + } +} + +func (op *EntOpsRead) MailboxExistsWithID(ctx context.Context, mboxID imap.InternalMailboxID) (bool, error) { + return wrapEntErrFnTyped(func() (bool, error) { + return MailboxExistsWithID(ctx, op.client, mboxID) + }) +} + +func (op *EntOpsRead) MailboxExistsWithRemoteID(ctx context.Context, mboxID imap.MailboxID) (bool, error) { + return wrapEntErrFnTyped(func() (bool, error) { + return MailboxExistsWithRemoteID(ctx, op.client, mboxID) + }) +} + +func (op *EntOpsRead) MailboxExistsWithName(ctx context.Context, name string) (bool, error) { + return wrapEntErrFnTyped(func() (bool, error) { + return MailboxExistsWithName(ctx, op.client, name) + }) +} + +func (op *EntOpsRead) GetMailboxIDFromRemoteID(ctx context.Context, mboxID imap.MailboxID) (imap.InternalMailboxID, error) { + return wrapEntErrFnTyped(func() (imap.InternalMailboxID, error) { + return GetMailboxIDFromRemoteID(ctx, op.client, mboxID) + }) +} + +func (op *EntOpsRead) GetMailboxName(ctx context.Context, mboxID imap.InternalMailboxID) (string, error) { + return wrapEntErrFnTyped(func() (string, error) { + return GetMailboxName(ctx, op.client, mboxID) + }) +} + +func (op *EntOpsRead) GetMailboxNameWithRemoteID(ctx context.Context, mboxID imap.MailboxID) (string, error) { + return wrapEntErrFnTyped(func() (string, error) { + return GetMailboxNameWithRemoteID(ctx, op.client, mboxID) + }) +} + +func (op *EntOpsRead) GetMailboxMessageIDPairs(ctx context.Context, mboxID imap.InternalMailboxID) ([]db.MessageIDPair, error) { + return wrapEntErrFnTyped(func() ([]db.MessageIDPair, error) { + return GetMailboxMessageIDPairs(ctx, op.client, mboxID) + }) +} + +func (op *EntOpsRead) GetAllMailboxes(ctx context.Context) ([]*db.Mailbox, error) { + return wrapEntErrFnTyped(func() ([]*db.Mailbox, error) { + val, err := GetAllMailboxes(ctx, op.client) + + return xslices.Map(val, entMBoxToDB), err + }) +} + +func (op *EntOpsRead) GetAllMailboxesAsRemoteIDs(ctx context.Context) ([]imap.MailboxID, error) { + return wrapEntErrFnTyped(func() ([]imap.MailboxID, error) { + val, err := GetAllMailboxes(ctx, op.client) + + return xslices.Map(val, func(t *internal.Mailbox) imap.MailboxID { + return t.RemoteID + }), err + }) +} + +func (op *EntOpsRead) GetMailboxByName(ctx context.Context, name string) (*db.Mailbox, error) { + return wrapEntErrFnTyped(func() (*db.Mailbox, error) { + val, err := GetMailboxByName(ctx, op.client, name) + + return entMBoxToDB(val), err + }) +} + +func (op *EntOpsRead) GetMailboxByID(ctx context.Context, mboxID imap.InternalMailboxID) (*db.Mailbox, error) { + return wrapEntErrFnTyped(func() (*db.Mailbox, error) { + val, err := GetMailboxByID(ctx, op.client, mboxID) + + return entMBoxToDB(val), err + }) +} + +func (op *EntOpsRead) GetMailboxByRemoteID(ctx context.Context, mboxID imap.MailboxID) (*db.Mailbox, error) { + return wrapEntErrFnTyped(func() (*db.Mailbox, error) { + val, err := GetMailboxByRemoteID(ctx, op.client, mboxID) + + return entMBoxToDB(val), err + }) +} + +func (op *EntOpsRead) GetMailboxRecentCount(ctx context.Context, mboxID imap.InternalMailboxID) (int, error) { + return wrapEntErrFnTyped(func() (int, error) { + mbox, err := GetMailboxByID(ctx, op.client, mboxID) + if err != nil { + return 0, err + } + + return GetMailboxRecentCount(ctx, op.client, mbox) + }) +} + +func (op *EntOpsRead) GetMailboxMessageCount(ctx context.Context, mboxID imap.InternalMailboxID) (int, error) { + return wrapEntErrFnTyped(func() (int, error) { + return GetMailboxMessageCount(ctx, op.client, mboxID) + }) +} + +func (op *EntOpsRead) GetMailboxMessageCountWithRemoteID(ctx context.Context, mboxID imap.MailboxID) (int, error) { + return wrapEntErrFnTyped(func() (int, error) { + mbox, err := GetMailboxByRemoteID(ctx, op.client, mboxID) + if err != nil { + return 0, err + } + + return GetMailboxMessageCount(ctx, op.client, mbox.ID) + }) +} + +func (op *EntOpsRead) GetMailboxFlags(ctx context.Context, mboxID imap.InternalMailboxID) (imap.FlagSet, error) { + return wrapEntErrFnTyped(func() (imap.FlagSet, error) { + return GetMailboxFlags(ctx, op.client, mboxID) + }) +} + +func (op *EntOpsRead) GetMailboxPermanentFlags(ctx context.Context, mboxID imap.InternalMailboxID) (imap.FlagSet, error) { + return wrapEntErrFnTyped(func() (imap.FlagSet, error) { + return GetMailboxPermanentFlags(ctx, op.client, mboxID) + }) +} + +func (op *EntOpsRead) GetMailboxAttributes(ctx context.Context, mboxID imap.InternalMailboxID) (imap.FlagSet, error) { + return wrapEntErrFnTyped(func() (imap.FlagSet, error) { + return GetMailboxAttributes(ctx, op.client, mboxID) + }) +} + +func (op *EntOpsRead) GetMailboxUID(ctx context.Context, mboxID imap.InternalMailboxID) (imap.UID, error) { + return wrapEntErrFnTyped(func() (imap.UID, error) { + return GetMailboxUID(ctx, op.client, mboxID) + }) +} + +func (op *EntOpsRead) GetMailboxMessageCountAndUID(ctx context.Context, mboxID imap.InternalMailboxID) (int, imap.UID, error) { + var count int + + var uid imap.UID + + err := wrapEntErrFn(func() error { + var err error + + count, uid, err = GetMailboxMessageCountAndUID(ctx, op.client, mboxID) + + return err + }) + + return count, uid, err +} + +func (op *EntOpsRead) GetMailboxMessageForNewSnapshot(ctx context.Context, mboxID imap.InternalMailboxID) ([]db.SnapshotMessageResult, error) { + return wrapEntErrFnTyped(func() ([]db.SnapshotMessageResult, error) { + return GetMailboxMessagesForNewSnapshot(ctx, op.client, mboxID) + }) +} + +func (op *EntOpsRead) MailboxTranslateRemoteIDs(ctx context.Context, mboxIDs []imap.MailboxID) ([]imap.InternalMailboxID, error) { + return wrapEntErrFnTyped(func() ([]imap.InternalMailboxID, error) { + return TranslateRemoteMailboxIDs(ctx, op.client, mboxIDs) + }) +} + +func (op *EntOpsRead) MailboxFilterContains(ctx context.Context, mboxID imap.InternalMailboxID, messageIDs []db.MessageIDPair) ([]imap.InternalMessageID, error) { + return wrapEntErrFnTyped(func() ([]imap.InternalMessageID, error) { + return FilterMailboxContains(ctx, op.client, mboxID, messageIDs) + }) +} + +func (op *EntOpsRead) MailboxFilterContainsInternalID(ctx context.Context, mboxID imap.InternalMailboxID, messageIDs []imap.InternalMessageID) ([]imap.InternalMessageID, error) { + return wrapEntErrFnTyped(func() ([]imap.InternalMessageID, error) { + return FilterMailboxContainsInternalID(ctx, op.client, mboxID, messageIDs) + }) +} + +func (op *EntOpsRead) GetMailboxCount(ctx context.Context) (int, error) { + return wrapEntErrFnTyped(func() (int, error) { + return GetMailboxCount(ctx, op.client) + }) +} + +func (op *EntOpsRead) GetMailboxMessageUIDsWithFlagsAfterAddOrUIDBump(ctx context.Context, mboxID imap.InternalMailboxID, messageIDs []imap.InternalMessageID) ([]db.UIDWithFlags, error) { + return wrapEntErrFnTyped(func() ([]db.UIDWithFlags, error) { + return GetMessageUIDsWithFlagsAfterAddOrUIDBump(ctx, op.client, mboxID, messageIDs) + }) +} + +func (op *EntOpsRead) MessageExists(ctx context.Context, id imap.InternalMessageID) (bool, error) { + return wrapEntErrFnTyped(func() (bool, error) { + return HasMessageWithID(ctx, op.client, id) + }) +} + +func (op *EntOpsRead) MessageExistsWithRemoteID(ctx context.Context, id imap.MessageID) (bool, error) { + return wrapEntErrFnTyped(func() (bool, error) { + return HasMessageWithRemoteID(ctx, op.client, id) + }) +} + +func (op *EntOpsRead) GetMessage(ctx context.Context, id imap.InternalMessageID) (*db.Message, error) { + return wrapEntErrFnTyped(func() (*db.Message, error) { + msg, err := GetMessage(ctx, op.client, id) + + return entMessageToDB(msg), err + }) +} + +func (op *EntOpsRead) GetTotalMessageCount(ctx context.Context) (int, error) { + return wrapEntErrFnTyped(func() (int, error) { + return op.client.Message.Query().Count(ctx) + }) +} + +func (op *EntOpsRead) GetMessageRemoteID(ctx context.Context, id imap.InternalMessageID) (imap.MessageID, error) { + return wrapEntErrFnTyped(func() (imap.MessageID, error) { + return GetMessageRemoteIDFromID(ctx, op.client, id) + }) +} + +func (op *EntOpsRead) GetImportedMessageData(ctx context.Context, id imap.InternalMessageID) (*db.Message, error) { + return wrapEntErrFnTyped(func() (*db.Message, error) { + msg, err := GetImportedMessageData(ctx, op.client, id) + + return entMessageToDB(msg), err + }) +} + +func (op *EntOpsRead) GetMessageDateAndSize(ctx context.Context, id imap.InternalMessageID) (time.Time, int, error) { + var date time.Time + + var size int + + err := wrapEntErrFn(func() error { + msg, err := GetMessageDateAndSize(ctx, op.client, id) + if err != nil { + return err + } + + date = msg.Date + size = msg.Size + + return err + }) + + return date, size, err +} + +func (op *EntOpsRead) GetMessageMailboxIDs(ctx context.Context, id imap.InternalMessageID) ([]imap.InternalMailboxID, error) { + return wrapEntErrFnTyped(func() ([]imap.InternalMailboxID, error) { + return GetMessageMailboxIDs(ctx, op.client, id) + }) +} + +func (op *EntOpsRead) GetMessagesFlags(ctx context.Context, ids []imap.InternalMessageID) ([]db.MessageFlagSet, error) { + return wrapEntErrFnTyped(func() ([]db.MessageFlagSet, error) { + return GetMessageFlags(ctx, op.client, ids) + }) +} + +func (op *EntOpsRead) GetMessageIDsMarkedAsDelete(ctx context.Context) ([]imap.InternalMessageID, error) { + return wrapEntErrFnTyped(func() ([]imap.InternalMessageID, error) { + return GetMessageIDsMarkedDeleted(ctx, op.client) + }) +} + +func (op *EntOpsRead) GetMessageIDFromRemoteID(ctx context.Context, id imap.MessageID) (imap.InternalMessageID, error) { + return wrapEntErrFnTyped(func() (imap.InternalMessageID, error) { + return GetMessageIDFromRemoteID(ctx, op.client, id) + }) +} + +func (op *EntOpsRead) GetMessageDeletedFlag(ctx context.Context, id imap.InternalMessageID) (bool, error) { + return wrapEntErrFnTyped(func() (bool, error) { + msg, err := GetMessageWithIDWithDeletedFlag(ctx, op.client, id) + if err != nil { + return false, err + } + + return msg.Deleted, nil + }) +} + +func (op *EntOpsRead) GetAllMessagesIDsAsMap(ctx context.Context) (map[imap.InternalMessageID]struct{}, error) { + return wrapEntErrFnTyped(func() (map[imap.InternalMessageID]struct{}, error) { + return GetAllMessagesIDsAsMap(ctx, op.client) + }) +} + +func (op *EntOpsRead) GetDeletedSubscriptionSet(ctx context.Context) (map[imap.MailboxID]*db.DeletedSubscription, error) { + return wrapEntErrFnTyped(func() (map[imap.MailboxID]*db.DeletedSubscription, error) { + ent, err := GetDeletedSubscriptionSet(ctx, op.client) + if err != nil { + return nil, err + } + + result := make(map[imap.MailboxID]*db.DeletedSubscription, len(ent)) + + for k, v := range ent { + result[k] = entSubscriptionToDB(v) + } + + return result, nil + }) +} diff --git a/internal/db_impl/ent_db/ops_write.go b/internal/db_impl/ent_db/ops_write.go new file mode 100644 index 00000000..b04f3e94 --- /dev/null +++ b/internal/db_impl/ent_db/ops_write.go @@ -0,0 +1,225 @@ +package ent_db + +import ( + "context" + + "github.com/ProtonMail/gluon/db" + "github.com/ProtonMail/gluon/imap" + "github.com/ProtonMail/gluon/internal/db_impl/ent_db/internal" + "github.com/ProtonMail/gluon/internal/db_impl/ent_db/internal/mailbox" + "github.com/bradenaw/juniper/xslices" +) + +type EntOpsWrite struct { + EntOpsRead + tx *internal.Tx +} + +func newEntOpsWrite(tx *internal.Tx) *EntOpsWrite { + return &EntOpsWrite{ + EntOpsRead: EntOpsRead{client: tx.Client()}, + tx: tx, + } +} + +func (op *EntOpsWrite) CreateMailbox(ctx context.Context, + mboxID imap.MailboxID, + name string, + flags, permFlags, attrs imap.FlagSet, + uidValidity imap.UID) (*db.Mailbox, error) { + return wrapEntErrFnTyped(func() (*db.Mailbox, error) { + mbox, err := CreateMailbox(ctx, op.tx, mboxID, name, flags, permFlags, attrs, uidValidity) + + return entMBoxToDB(mbox), err + }) +} + +func (op *EntOpsWrite) GetOrCreateMailbox(ctx context.Context, + mboxID imap.MailboxID, + name string, + flags, permFlags, attrs imap.FlagSet, + uidValidity imap.UID) (*db.Mailbox, error) { + return wrapEntErrFnTyped(func() (*db.Mailbox, error) { + mbox, err := CreateMailbox(ctx, op.tx, mboxID, name, flags, permFlags, attrs, uidValidity) + + return entMBoxToDB(mbox), err + }) +} + +func (op *EntOpsWrite) GetOrCreateMailboxAlt(ctx context.Context, + mbox imap.Mailbox, + delimiter string, + uidValidity imap.UID) (*db.Mailbox, error) { + return wrapEntErrFnTyped(func() (*db.Mailbox, error) { + mbox, err := GetOrCreateMailbox(ctx, op.tx, mbox, delimiter, uidValidity) + + return entMBoxToDB(mbox), err + }) +} + +func (op *EntOpsWrite) RenameMailboxWithRemoteID(ctx context.Context, mboxID imap.MailboxID, name string) error { + return wrapEntErrFn(func() error { + return RenameMailboxWithRemoteID(ctx, op.tx, mboxID, name) + }) +} + +func (op *EntOpsWrite) DeleteMailboxWithRemoteID(ctx context.Context, mboxID imap.MailboxID) error { + return wrapEntErrFn(func() error { + return DeleteMailboxWithRemoteID(ctx, op.tx, mboxID) + }) +} + +func (op *EntOpsWrite) BumpMailboxUIDNext(ctx context.Context, mboxID imap.InternalMailboxID, count int) error { + return wrapEntErrFn(func() error { + mbox, err := op.tx.Mailbox.Query().Where(mailbox.ID(mboxID)).Only(ctx) + if err != nil { + return err + } + + return BumpMailboxUIDNext(ctx, op.tx, mbox, count) + }) +} + +func (op *EntOpsWrite) AddMessagesToMailbox(ctx context.Context, mboxID imap.InternalMailboxID, messageIDs []imap.InternalMessageID) ([]db.UIDWithFlags, error) { + return wrapEntErrFnTyped(func() ([]db.UIDWithFlags, error) { + return AddMessagesToMailbox(ctx, op.tx, messageIDs, mboxID) + }) +} + +func (op *EntOpsWrite) BumpMailboxUIDsForMessage(ctx context.Context, mboxID imap.InternalMailboxID, messageIDs []imap.InternalMessageID) ([]db.UIDWithFlags, error) { + return wrapEntErrFnTyped(func() ([]db.UIDWithFlags, error) { + return BumpMailboxUIDsForMessage(ctx, op.tx, messageIDs, mboxID) + }) +} + +func (op *EntOpsWrite) RemoveMessagesFromMailbox(ctx context.Context, mboxID imap.InternalMailboxID, messageIDs []imap.InternalMessageID) error { + return wrapEntErrFn(func() error { + return RemoveMessagesFromMailbox(ctx, op.tx, messageIDs, mboxID) + }) +} + +func (op *EntOpsWrite) ClearRecentFlagInMailboxOnMessage(ctx context.Context, mboxID imap.InternalMailboxID, messageID imap.InternalMessageID) error { + return wrapEntErrFn(func() error { + return ClearRecentFlag(ctx, op.tx, mboxID, messageID) + }) +} + +func (op *EntOpsWrite) ClearRecentFlagsInMailbox(ctx context.Context, mboxID imap.InternalMailboxID) error { + return wrapEntErrFn(func() error { + return ClearRecentFlags(ctx, op.tx, mboxID) + }) +} + +func (op *EntOpsWrite) CreateMailboxIfNotExists(ctx context.Context, mbox imap.Mailbox, delimiter string, uidValidity imap.UID) error { + return wrapEntErrFn(func() error { + return CreateMailboxIfNotExists(ctx, op.tx, mbox, delimiter, uidValidity) + }) +} + +func (op *EntOpsWrite) SetMailboxMessagesDeletedFlag(ctx context.Context, mboxID imap.InternalMailboxID, messageIDs []imap.InternalMessageID, deleted bool) error { + return wrapEntErrFn(func() error { + return SetDeletedFlag(ctx, op.tx, mboxID, messageIDs, deleted) + }) +} + +func (op *EntOpsWrite) SetMailboxSubscribed(ctx context.Context, mboxID imap.InternalMailboxID, subscribed bool) error { + return wrapEntErrFn(func() error { + return op.tx.Mailbox.Update().Where(mailbox.ID(mboxID)).SetSubscribed(subscribed).Exec(ctx) + }) +} + +func (op *EntOpsWrite) UpdateRemoteMailboxID(ctx context.Context, mobxID imap.InternalMailboxID, remoteID imap.MailboxID) error { + return wrapEntErrFn(func() error { + return UpdateRemoteMailboxID(ctx, op.tx, mobxID, remoteID) + }) +} + +func (op *EntOpsWrite) SetMailboxUIDValidity(ctx context.Context, mboxID imap.InternalMailboxID, uidValidity imap.UID) error { + return wrapEntErrFn(func() error { + return op.tx.Mailbox.Update().Where(mailbox.ID(mboxID)).SetUIDValidity(uidValidity).Exec(ctx) + }) +} + +func (op *EntOpsWrite) CreateMessages(ctx context.Context, reqs ...*db.CreateMessageReq) ([]*db.Message, error) { + return wrapEntErrFnTyped(func() ([]*db.Message, error) { + msgs, err := CreateMessages(ctx, op.tx, reqs...) + + return xslices.Map(msgs, entMessageToDB), err + }) +} + +func (op *EntOpsWrite) CreateMessageAndAddToMailbox(ctx context.Context, mbox imap.InternalMailboxID, req *db.CreateMessageReq) (imap.UID, imap.FlagSet, error) { + var uid imap.UID + + var flagSet imap.FlagSet + + err := wrapEntErrFn(func() error { + var err error + + uid, flagSet, err = CreateAndAddMessageToMailbox(ctx, op.tx, mbox, req) + + return err + }) + + return uid, flagSet, err +} + +func (op *EntOpsWrite) MarkMessageAsDeleted(ctx context.Context, id imap.InternalMessageID) error { + return wrapEntErrFn(func() error { + return MarkMessageAsDeleted(ctx, op.tx, id) + }) +} + +func (op *EntOpsWrite) MarkMessageAsDeletedAndAssignRandomRemoteID(ctx context.Context, id imap.InternalMessageID) error { + return wrapEntErrFn(func() error { + return MarkMessageAsDeletedAndAssignRandomRemoteID(ctx, op.tx, id) + }) +} + +func (op *EntOpsWrite) MarkMessageAsDeletedWithRemoteID(ctx context.Context, id imap.MessageID) error { + return wrapEntErrFn(func() error { + return MarkMessageAsDeletedWithRemoteID(ctx, op.tx, id) + }) +} + +func (op *EntOpsWrite) DeleteMessages(ctx context.Context, ids []imap.InternalMessageID) error { + return wrapEntErrFn(func() error { + return DeleteMessages(ctx, op.tx, ids...) + }) +} + +func (op *EntOpsWrite) UpdateRemoteMessageID(ctx context.Context, internalID imap.InternalMessageID, remoteID imap.MessageID) error { + return wrapEntErrFn(func() error { + return UpdateRemoteMessageID(ctx, op.tx, internalID, remoteID) + }) +} + +func (op *EntOpsWrite) AddFlagToMessages(ctx context.Context, ids []imap.InternalMessageID, flag string) error { + return wrapEntErrFn(func() error { + return AddMessageFlag(ctx, op.tx, ids, flag) + }) +} + +func (op *EntOpsWrite) RemoveFlagFromMessages(ctx context.Context, ids []imap.InternalMessageID, flag string) error { + return wrapEntErrFn(func() error { + return RemoveMessageFlag(ctx, op.tx, ids, flag) + }) +} + +func (op *EntOpsWrite) SetFlagsOnMessages(ctx context.Context, ids []imap.InternalMessageID, flags imap.FlagSet) error { + return wrapEntErrFn(func() error { + return SetMessageFlags(ctx, op.tx, ids, flags) + }) +} + +func (op *EntOpsWrite) AddDeletedSubscription(ctx context.Context, mboxName string, mboxID imap.MailboxID) error { + return wrapEntErrFn(func() error { + return AddDeletedSubscription(ctx, op.tx, mboxName, mboxID) + }) +} + +func (op *EntOpsWrite) RemoveDeletedSubscriptionWithName(ctx context.Context, mboxName string) (int, error) { + return wrapEntErrFnTyped(func() (int, error) { + return RemoveDeletedSubscriptionWithName(ctx, op.tx, mboxName) + }) +} diff --git a/internal/db/subscription.go b/internal/db_impl/ent_db/subscription.go similarity index 63% rename from internal/db/subscription.go rename to internal/db_impl/ent_db/subscription.go index 6fa702ca..79b46747 100644 --- a/internal/db/subscription.go +++ b/internal/db_impl/ent_db/subscription.go @@ -1,14 +1,14 @@ -package db +package ent_db import ( "context" "github.com/ProtonMail/gluon/imap" - "github.com/ProtonMail/gluon/internal/db/ent" - "github.com/ProtonMail/gluon/internal/db/ent/deletedsubscription" + "github.com/ProtonMail/gluon/internal/db_impl/ent_db/internal" + "github.com/ProtonMail/gluon/internal/db_impl/ent_db/internal/deletedsubscription" ) -func AddDeletedSubscription(ctx context.Context, tx *ent.Tx, mboxName string, mboxID imap.MailboxID) error { +func AddDeletedSubscription(ctx context.Context, tx *internal.Tx, mboxName string, mboxID imap.MailboxID) error { count, err := tx.DeletedSubscription.Update().Where(deletedsubscription.NameEqualFold(mboxName)).SetRemoteID(mboxID).Save(ctx) if err != nil { return err @@ -23,14 +23,14 @@ func AddDeletedSubscription(ctx context.Context, tx *ent.Tx, mboxName string, mb return nil } -func RemoveDeletedSubscriptionWithName(ctx context.Context, tx *ent.Tx, mboxName string) (int, error) { +func RemoveDeletedSubscriptionWithName(ctx context.Context, tx *internal.Tx, mboxName string) (int, error) { return tx.DeletedSubscription.Delete().Where(deletedsubscription.NameEqualFold(mboxName)).Exec(ctx) } -func GetDeletedSubscriptionSet(ctx context.Context, client *ent.Client) (map[imap.MailboxID]*ent.DeletedSubscription, error) { +func GetDeletedSubscriptionSet(ctx context.Context, client *internal.Client) (map[imap.MailboxID]*internal.DeletedSubscription, error) { const QueryLimit = 16000 - subscriptions := make(map[imap.MailboxID]*ent.DeletedSubscription) + subscriptions := make(map[imap.MailboxID]*internal.DeletedSubscription) for i := 0; ; i += QueryLimit { result, err := client.DeletedSubscription.Query(). diff --git a/internal/db_impl/ent_db/type_conversions.go b/internal/db_impl/ent_db/type_conversions.go new file mode 100644 index 00000000..a6595fa1 --- /dev/null +++ b/internal/db_impl/ent_db/type_conversions.go @@ -0,0 +1,103 @@ +package ent_db + +import ( + "github.com/ProtonMail/gluon/db" + "github.com/ProtonMail/gluon/internal/db_impl/ent_db/internal" + "github.com/bradenaw/juniper/xslices" +) + +func entMailboxFlagToDB(flag *internal.MailboxFlag) *db.MailboxFlag { + if flag == nil { + return nil + } + + return &db.MailboxFlag{ + ID: flag.ID, + Value: flag.Value, + } +} + +func entMailboxPermFlagsToDB(flag *internal.MailboxPermFlag) *db.MailboxFlag { + if flag == nil { + return nil + } + + return &db.MailboxFlag{ + ID: flag.ID, + Value: flag.Value, + } +} + +func entMailboxAttrToDB(flag *internal.MailboxAttr) *db.MailboxAttr { + if flag == nil { + return nil + } + + return &db.MailboxAttr{ + ID: flag.ID, + Value: flag.Value, + } +} + +func entMBoxToDB(mbox *internal.Mailbox) *db.Mailbox { + if mbox == nil { + return nil + } + + return &db.Mailbox{ + ID: mbox.ID, + RemoteID: mbox.RemoteID, + Name: mbox.Name, + UIDNext: mbox.UIDNext, + UIDValidity: mbox.UIDValidity, + Subscribed: mbox.Subscribed, + Flags: xslices.Map(mbox.Edges.Flags, entMailboxFlagToDB), + PermanentFlags: xslices.Map(mbox.Edges.PermanentFlags, entMailboxPermFlagsToDB), + Attributes: xslices.Map(mbox.Edges.Attributes, entMailboxAttrToDB), + } +} + +func entMessageFlagsToDB(flag *internal.MessageFlag) *db.MessageFlag { + return &db.MessageFlag{ + ID: flag.ID, + Value: flag.Value, + } +} + +func entMessageUIDToDB(uid *internal.UID) *db.UID { + return &db.UID{ + UID: uid.UID, + Deleted: uid.Deleted, + Recent: uid.Recent, + } +} + +func entMessageToDB(msg *internal.Message) *db.Message { + if msg == nil { + return nil + } + + return &db.Message{ + ID: msg.ID, + RemoteID: msg.RemoteID, + Date: msg.Date, + Size: msg.Size, + Body: msg.Body, + BodyStructure: msg.BodyStructure, + Envelope: msg.Envelope, + Deleted: msg.Deleted, + Flags: xslices.Map(msg.Edges.Flags, entMessageFlagsToDB), + UIDs: xslices.Map(msg.Edges.UIDs, entMessageUIDToDB), + } +} + +func entSubscriptionToDB(s *internal.DeletedSubscription) *db.DeletedSubscription { + if s == nil { + return nil + } + + return &db.DeletedSubscription{ + Name: s.Name, + RemoteID: s.RemoteID, + } +} diff --git a/internal/ids/ids.go b/internal/ids/ids.go index cf75a299..234160a3 100644 --- a/internal/ids/ids.go +++ b/internal/ids/ids.go @@ -5,76 +5,8 @@ import ( "strings" "github.com/ProtonMail/gluon/imap" - "github.com/ProtonMail/gluon/internal/db/ent" ) -type MailboxIDPair struct { - InternalID imap.InternalMailboxID - RemoteID imap.MailboxID -} - -func (m *MailboxIDPair) String() string { - return fmt.Sprintf("%v::%v", m.InternalID, m.RemoteID) -} - -type MessageIDPair struct { - InternalID imap.InternalMessageID - RemoteID imap.MessageID -} - -func (m *MessageIDPair) String() string { - return fmt.Sprintf("%v::%v", m.InternalID, m.RemoteID) -} - -func NewMailboxIDPair(mbox *ent.Mailbox) MailboxIDPair { - return MailboxIDPair{ - InternalID: mbox.ID, - RemoteID: mbox.RemoteID, - } -} - -func NewMailboxIDPairWithoutRemote(internalID imap.InternalMailboxID) MailboxIDPair { - return MailboxIDPair{ - InternalID: internalID, - RemoteID: "", - } -} - -func NewMessageIDPair(msg *ent.Message) MessageIDPair { - return MessageIDPair{ - InternalID: msg.ID, - RemoteID: msg.RemoteID, - } -} - -func SplitMessageIDPairSlice(s []MessageIDPair) ([]imap.InternalMessageID, []imap.MessageID) { - l := len(s) - - internalMessageIDs := make([]imap.InternalMessageID, 0, l) - remoteMessageIDs := make([]imap.MessageID, 0, l) - - for _, v := range s { - internalMessageIDs = append(internalMessageIDs, v.InternalID) - remoteMessageIDs = append(remoteMessageIDs, v.RemoteID) - } - - return internalMessageIDs, remoteMessageIDs -} - -func SplitMailboxIDPairSlice(s []MailboxIDPair) ([]imap.InternalMailboxID, []imap.MailboxID) { - l := len(s) - - internalMailboxIDs := make([]imap.InternalMailboxID, 0, l) - mailboxIDs := make([]imap.MailboxID, 0, l) - - for _, v := range s { - internalMailboxIDs = append(internalMailboxIDs, v.InternalID) - mailboxIDs = append(mailboxIDs, v.RemoteID) - } - - return internalMailboxIDs, mailboxIDs -} - const GluonRecoveryMailboxName = "Recovered Messages" const GluonRecoveryMailboxNameLowerCase = "recovered messages" const GluonInternalRecoveryMailboxRemoteID = imap.MailboxID("GLUON-INTERNAL-RECOVERY-MBOX") diff --git a/internal/session/errors.go b/internal/session/errors.go index 82a8b540..16140a86 100644 --- a/internal/session/errors.go +++ b/internal/session/errors.go @@ -3,10 +3,11 @@ package session import ( "context" "errors" + "net" + "github.com/ProtonMail/gluon/connector" "github.com/ProtonMail/gluon/internal/state" "github.com/ProtonMail/gluon/rfc822" - "net" ) var ( diff --git a/internal/session/handle_append.go b/internal/session/handle_append.go index 04169321..5b3b0426 100644 --- a/internal/session/handle_append.go +++ b/internal/session/handle_append.go @@ -3,6 +3,7 @@ package session import ( "context" "errors" + "github.com/ProtonMail/gluon/imap/command" "github.com/ProtonMail/gluon/internal/response" "github.com/ProtonMail/gluon/internal/state" diff --git a/internal/state/actions.go b/internal/state/actions.go index a9aad84d..540416d0 100644 --- a/internal/state/actions.go +++ b/internal/state/actions.go @@ -7,9 +7,8 @@ import ( "strings" "time" + "github.com/ProtonMail/gluon/db" "github.com/ProtonMail/gluon/imap" - "github.com/ProtonMail/gluon/internal/db" - "github.com/ProtonMail/gluon/internal/db/ent" "github.com/ProtonMail/gluon/internal/ids" "github.com/ProtonMail/gluon/reporter" "github.com/ProtonMail/gluon/rfc822" @@ -18,21 +17,20 @@ import ( "golang.org/x/exp/slices" ) -func (state *State) actionCreateAndGetMailbox(ctx context.Context, tx *ent.Tx, name string, uidValidity imap.UID) (*ent.Mailbox, error) { +func (state *State) actionCreateAndGetMailbox(ctx context.Context, tx db.Transaction, name string, uidValidity imap.UID) (*db.Mailbox, error) { res, err := state.user.GetRemote().CreateMailbox(ctx, strings.Split(name, state.delimiter)) if err != nil { return nil, err } - exists, err := db.MailboxExistsWithRemoteID(ctx, tx.Client(), res.ID) + exists, err := tx.MailboxExistsWithRemoteID(ctx, res.ID) if err != nil { return nil, err } if !exists { - mbox, err := db.CreateMailbox( + mbox, err := tx.CreateMailbox( ctx, - tx, res.ID, strings.Join(res.Name, state.user.GetDelimiter()), res.Flags, @@ -47,31 +45,31 @@ func (state *State) actionCreateAndGetMailbox(ctx context.Context, tx *ent.Tx, n return mbox, nil } - return db.GetMailboxByRemoteID(ctx, tx.Client(), res.ID) + return tx.GetMailboxByRemoteID(ctx, res.ID) } -func (state *State) actionCreateMailbox(ctx context.Context, tx *ent.Tx, name string, uidValidity imap.UID) error { +func (state *State) actionCreateMailbox(ctx context.Context, tx db.Transaction, name string, uidValidity imap.UID) error { res, err := state.user.GetRemote().CreateMailbox(ctx, strings.Split(name, state.delimiter)) if err != nil { return err } - return db.CreateMailboxIfNotExists(ctx, tx, res, state.delimiter, uidValidity) + return tx.CreateMailboxIfNotExists(ctx, res, state.delimiter, uidValidity) } -func (state *State) actionDeleteMailbox(ctx context.Context, tx *ent.Tx, mboxID ids.MailboxIDPair) ([]Update, error) { +func (state *State) actionDeleteMailbox(ctx context.Context, tx db.Transaction, mboxID db.MailboxIDPair) ([]Update, error) { if err := state.user.GetRemote().DeleteMailbox(ctx, mboxID.RemoteID); err != nil { return nil, err } - if err := db.DeleteMailboxWithRemoteID(ctx, tx, mboxID.RemoteID); err != nil { + if err := tx.DeleteMailboxWithRemoteID(ctx, mboxID.RemoteID); err != nil { return nil, err } return []Update{NewMailboxDeletedStateUpdate(mboxID.InternalID)}, nil } -func (state *State) actionUpdateMailbox(ctx context.Context, tx *ent.Tx, mboxID imap.MailboxID, newName string) error { +func (state *State) actionUpdateMailbox(ctx context.Context, tx db.Transaction, mboxID imap.MailboxID, newName string) error { if err := state.user.GetRemote().UpdateMailbox( ctx, mboxID, @@ -80,13 +78,13 @@ func (state *State) actionUpdateMailbox(ctx context.Context, tx *ent.Tx, mboxID return err } - return db.RenameMailboxWithRemoteID(ctx, tx, mboxID, newName) + return tx.RenameMailboxWithRemoteID(ctx, mboxID, newName) } func (state *State) actionCreateMessage( ctx context.Context, - tx *ent.Tx, - mboxID ids.MailboxIDPair, + tx db.Transaction, + mboxID db.MailboxIDPair, literal []byte, flags imap.FlagSet, date time.Time, @@ -100,14 +98,14 @@ func (state *State) actionCreateMessage( { // Handle the case where duplicate messages can return the same remote ID. - knownInternalID, knownErr := db.GetMessageIDFromRemoteID(ctx, tx.Client(), res.ID) - if knownErr != nil && !ent.IsNotFound(knownErr) { + knownInternalID, knownErr := tx.GetMessageIDFromRemoteID(ctx, res.ID) + if knownErr != nil && !db.IsErrNotFound(knownErr) { return nil, 0, knownErr } if knownErr == nil { // Try to collect the original message date. var existingMessageDate time.Time - if existingMessage, msgErr := db.GetMessage(ctx, tx.Client(), internalID); msgErr == nil { + if existingMessage, msgErr := tx.GetMessage(ctx, internalID); msgErr == nil { existingMessageDate = existingMessage.Date } @@ -128,7 +126,7 @@ func (state *State) actionCreateMessage( updates, result, err := state.actionAddMessagesToMailbox(ctx, tx, - []ids.MessageIDPair{{InternalID: knownInternalID, RemoteID: res.ID}}, + []db.MessageIDPair{{InternalID: knownInternalID, RemoteID: res.ID}}, mboxID, isSelectedMailbox, ) @@ -163,7 +161,7 @@ func (state *State) actionCreateMessage( InternalID: internalID, } - messageUID, flagSet, err := db.CreateAndAddMessageToMailbox(ctx, tx, mboxID.InternalID, &req) + messageUID, flagSet, err := tx.CreateMessageAndAddToMailbox(ctx, mboxID.InternalID, &req) if err != nil { return nil, 0, err } @@ -177,7 +175,7 @@ func (state *State) actionCreateMessage( updates := []Update{newExistsStateUpdateWithExists( mboxID.InternalID, - []*exists{newExists(ids.MessageIDPair{InternalID: internalID, RemoteID: res.ID}, messageUID, flagSet)}, + []*exists{newExists(db.MessageIDPair{InternalID: internalID, RemoteID: res.ID}, messageUID, flagSet)}, st, ), } @@ -187,7 +185,7 @@ func (state *State) actionCreateMessage( func (state *State) actionCreateRecoveredMessage( ctx context.Context, - tx *ent.Tx, + tx db.Transaction, literal []byte, flags imap.FlagSet, date time.Time, @@ -225,14 +223,14 @@ func (state *State) actionCreateRecoveredMessage( recoveryMBoxID := state.user.GetRecoveryMailboxID() - messageUID, flagSet, err := db.CreateAndAddMessageToMailbox(ctx, tx, recoveryMBoxID.InternalID, &req) + messageUID, flagSet, err := tx.CreateMessageAndAddToMailbox(ctx, recoveryMBoxID.InternalID, &req) if err != nil { return nil, false, err } var updates = []Update{newExistsStateUpdateWithExists( recoveryMBoxID.InternalID, - []*exists{newExists(ids.MessageIDPair{InternalID: internalID, RemoteID: remoteID}, messageUID, flagSet)}, + []*exists{newExists(db.MessageIDPair{InternalID: internalID, RemoteID: remoteID}, messageUID, flagSet)}, nil, ), } @@ -242,20 +240,20 @@ func (state *State) actionCreateRecoveredMessage( func (state *State) actionAddMessagesToMailbox( ctx context.Context, - tx *ent.Tx, - messageIDs []ids.MessageIDPair, - mboxID ids.MailboxIDPair, + tx db.Transaction, + messageIDs []db.MessageIDPair, + mboxID db.MailboxIDPair, isMailboxSelected bool, ) ([]Update, []db.UIDWithFlags, error) { var allUpdates []Update { - haveMessageIDs, err := db.FilterMailboxContains(ctx, tx.Client(), mboxID.InternalID, messageIDs) + haveMessageIDs, err := tx.MailboxFilterContains(ctx, mboxID.InternalID, messageIDs) if err != nil { return nil, nil, err } - if remMessageIDs := xslices.Filter(messageIDs, func(messageID ids.MessageIDPair) bool { + if remMessageIDs := xslices.Filter(messageIDs, func(messageID db.MessageIDPair) bool { return slices.Contains(haveMessageIDs, messageID.InternalID) }); len(remMessageIDs) > 0 { updates, err := state.actionRemoveMessagesFromMailboxUnchecked(ctx, tx, remMessageIDs, mboxID) @@ -267,7 +265,7 @@ func (state *State) actionAddMessagesToMailbox( } } - internalIDs, remoteIDs := ids.SplitMessageIDPairSlice(messageIDs) + internalIDs, remoteIDs := db.SplitMessageIDPairSlice(messageIDs) if err := state.user.GetRemote().AddMessagesToMailbox(ctx, remoteIDs, mboxID.RemoteID); err != nil { return nil, nil, err @@ -291,11 +289,11 @@ func (state *State) actionAddMessagesToMailbox( func (state *State) actionAddRecoveredMessagesToMailbox( ctx context.Context, - tx *ent.Tx, - messageIDs []ids.MessageIDPair, - mboxID ids.MailboxIDPair, + tx db.Transaction, + messageIDs []db.MessageIDPair, + mboxID db.MailboxIDPair, ) ([]db.UIDWithFlags, Update, error) { - internalIDs, remoteIDs := ids.SplitMessageIDPairSlice(messageIDs) + internalIDs, remoteIDs := db.SplitMessageIDPairSlice(messageIDs) if err := state.user.GetRemote().AddMessagesToMailbox(ctx, remoteIDs, mboxID.RemoteID); err != nil { return nil, nil, err @@ -306,39 +304,39 @@ func (state *State) actionAddRecoveredMessagesToMailbox( func (state *State) actionImportRecoveredMessage( ctx context.Context, - tx *ent.Tx, + tx db.Transaction, id imap.InternalMessageID, mboxID imap.MailboxID, -) (ids.MessageIDPair, bool, error) { - message, err := db.GetImportedMessageData(ctx, tx.Client(), id) +) (db.MessageIDPair, bool, error) { + message, err := tx.GetImportedMessageData(ctx, id) if err != nil { - return ids.MessageIDPair{}, false, err + return db.MessageIDPair{}, false, err } literal, err := state.user.GetStore().Get(id) if err != nil { - return ids.MessageIDPair{}, false, err + return db.MessageIDPair{}, false, err } messageFlags := imap.NewFlagSet() - for _, flag := range message.Edges.Flags { + for _, flag := range message.Flags { messageFlags.AddToSelf(flag.Value) } internalID, res, newLiteral, err := state.user.GetRemote().CreateMessage(ctx, mboxID, literal, messageFlags, message.Date) if err != nil { - return ids.MessageIDPair{}, false, err + return db.MessageIDPair{}, false, err } { // Handle the unlikely case where duplicate messages can return the same remote ID. - internalID, err := db.GetMessageIDFromRemoteID(ctx, tx.Client(), res.ID) - if err != nil && !ent.IsNotFound(err) { - return ids.MessageIDPair{}, false, err + internalID, err := tx.GetMessageIDFromRemoteID(ctx, res.ID) + if err != nil && !db.IsErrNotFound(err) { + return db.MessageIDPair{}, false, err } if err == nil { - return ids.MessageIDPair{ + return db.MessageIDPair{ InternalID: internalID, RemoteID: res.ID, }, true, nil @@ -347,16 +345,16 @@ func (state *State) actionImportRecoveredMessage( parsedMessage, err := imap.NewParsedMessage(newLiteral) if err != nil { - return ids.MessageIDPair{}, false, err + return db.MessageIDPair{}, false, err } literalReader, literalSize, err := rfc822.SetHeaderValueNoMemCopy(newLiteral, ids.InternalIDKey, internalID.String()) if err != nil { - return ids.MessageIDPair{}, false, fmt.Errorf("failed to set internal ID: %w", err) + return db.MessageIDPair{}, false, fmt.Errorf("failed to set internal ID: %w", err) } if err := state.user.GetStore().SetUnchecked(internalID, literalReader); err != nil { - return ids.MessageIDPair{}, false, fmt.Errorf("failed to store message literal: %w", err) + return db.MessageIDPair{}, false, fmt.Errorf("failed to store message literal: %w", err) } req := db.CreateMessageReq{ @@ -368,11 +366,11 @@ func (state *State) actionImportRecoveredMessage( InternalID: internalID, } - if _, err := db.CreateMessages(ctx, tx, &req); err != nil { - return ids.MessageIDPair{}, false, err + if _, err := tx.CreateMessages(ctx, &req); err != nil { + return db.MessageIDPair{}, false, err } - return ids.MessageIDPair{ + return db.MessageIDPair{ InternalID: internalID, RemoteID: res.ID, }, false, nil @@ -380,11 +378,11 @@ func (state *State) actionImportRecoveredMessage( func (state *State) actionCopyMessagesOutOfRecoveryMailbox( ctx context.Context, - tx *ent.Tx, - messageIDs []ids.MessageIDPair, - mboxID ids.MailboxIDPair, + tx db.Transaction, + messageIDs []db.MessageIDPair, + mboxID db.MailboxIDPair, ) ([]Update, []db.UIDWithFlags, error) { - ids := make([]ids.MessageIDPair, 0, len(messageIDs)) + ids := make([]db.MessageIDPair, 0, len(messageIDs)) // Import messages to remote. for _, id := range messageIDs { @@ -407,11 +405,11 @@ func (state *State) actionCopyMessagesOutOfRecoveryMailbox( func (state *State) actionMoveMessagesOutOfRecoveryMailbox( ctx context.Context, - tx *ent.Tx, - messageIDs []ids.MessageIDPair, - mboxID ids.MailboxIDPair, + tx db.Transaction, + messageIDs []db.MessageIDPair, + mboxID db.MailboxIDPair, ) ([]Update, []db.UIDWithFlags, error) { - ids := make([]ids.MessageIDPair, 0, len(messageIDs)) + ids := make([]db.MessageIDPair, 0, len(messageIDs)) oldInternalIDs := make([]imap.InternalMessageID, 0, len(messageIDs)) // Import messages to remote. @@ -422,7 +420,7 @@ func (state *State) actionMoveMessagesOutOfRecoveryMailbox( } if !deduped { - if err := db.MarkMessageAsDeleted(ctx, tx, id.InternalID); err != nil { + if err := tx.MarkMessageAsDeleted(ctx, id.InternalID); err != nil { return nil, nil, err } } @@ -460,11 +458,11 @@ func (state *State) actionMoveMessagesOutOfRecoveryMailbox( // have already validated the input beforehand (e.g.: actionAddMessagesToMailbox and actionRemoveMessagesFromMailbox). func (state *State) actionRemoveMessagesFromMailboxUnchecked( ctx context.Context, - tx *ent.Tx, - messageIDs []ids.MessageIDPair, - mboxID ids.MailboxIDPair, + tx db.Transaction, + messageIDs []db.MessageIDPair, + mboxID db.MailboxIDPair, ) ([]Update, error) { - internalIDs, remoteIDs := ids.SplitMessageIDPairSlice(messageIDs) + internalIDs, remoteIDs := db.SplitMessageIDPairSlice(messageIDs) if mboxID.InternalID != state.user.GetRecoveryMailboxID().InternalID { if err := state.user.GetRemote().RemoveMessagesFromMailbox(ctx, remoteIDs, mboxID.RemoteID); err != nil { @@ -479,16 +477,16 @@ func (state *State) actionRemoveMessagesFromMailboxUnchecked( func (state *State) actionRemoveMessagesFromMailbox( ctx context.Context, - tx *ent.Tx, - messageIDs []ids.MessageIDPair, - mboxID ids.MailboxIDPair, + tx db.Transaction, + messageIDs []db.MessageIDPair, + mboxID db.MailboxIDPair, ) ([]Update, error) { - haveMessageIDs, err := db.FilterMailboxContains(ctx, tx.Client(), mboxID.InternalID, messageIDs) + haveMessageIDs, err := tx.MailboxFilterContains(ctx, mboxID.InternalID, messageIDs) if err != nil { return nil, err } - messageIDs = xslices.Filter(messageIDs, func(messageID ids.MessageIDPair) bool { + messageIDs = xslices.Filter(messageIDs, func(messageID db.MessageIDPair) bool { return slices.Contains(haveMessageIDs, messageID.InternalID) }) @@ -501,16 +499,16 @@ func (state *State) actionRemoveMessagesFromMailbox( func (state *State) actionMoveMessages( ctx context.Context, - tx *ent.Tx, - messageIDs []ids.MessageIDPair, - mboxFromID, mboxToID ids.MailboxIDPair, + tx db.Transaction, + messageIDs []db.MessageIDPair, + mboxFromID, mboxToID db.MailboxIDPair, ) ([]Update, []db.UIDWithFlags, error) { var allUpdates []Update if mboxFromID.InternalID == mboxToID.InternalID { - internalIDs, _ := ids.SplitMessageIDPairSlice(messageIDs) + internalIDs, _ := db.SplitMessageIDPairSlice(messageIDs) - uid, err := db.BumpMailboxUIDsForMessage(ctx, tx, internalIDs, mboxToID.InternalID) + uid, err := tx.BumpMailboxUIDsForMessage(ctx, mboxToID.InternalID, internalIDs) if err != nil { return nil, nil, err } @@ -519,12 +517,12 @@ func (state *State) actionMoveMessages( } { - messageIDsToAdd, err := db.FilterMailboxContains(ctx, tx.Client(), mboxToID.InternalID, messageIDs) + messageIDsToAdd, err := tx.MailboxFilterContains(ctx, mboxToID.InternalID, messageIDs) if err != nil { return nil, nil, err } - if remMessageIDs := xslices.Filter(messageIDs, func(messageID ids.MessageIDPair) bool { + if remMessageIDs := xslices.Filter(messageIDs, func(messageID db.MessageIDPair) bool { return slices.Contains(messageIDsToAdd, messageID.InternalID) }); len(remMessageIDs) > 0 { updates, err := state.actionRemoveMessagesFromMailboxUnchecked(ctx, tx, remMessageIDs, mboxToID) @@ -536,16 +534,16 @@ func (state *State) actionMoveMessages( } } - messageInFromMBox, err := db.FilterMailboxContains(ctx, tx.Client(), mboxFromID.InternalID, messageIDs) + messageInFromMBox, err := tx.MailboxFilterContains(ctx, mboxFromID.InternalID, messageIDs) if err != nil { return nil, nil, err } - messagesIDsToMove := xslices.Filter(messageIDs, func(messageID ids.MessageIDPair) bool { + messagesIDsToMove := xslices.Filter(messageIDs, func(messageID db.MessageIDPair) bool { return slices.Contains(messageInFromMBox, messageID.InternalID) }) - internalIDs, remoteIDs := ids.SplitMessageIDPairSlice(messagesIDsToMove) + internalIDs, remoteIDs := db.SplitMessageIDPairSlice(messagesIDsToMove) shouldRemoveOldMessages, err := state.user.GetRemote().MoveMessagesFromMailbox(ctx, remoteIDs, mboxFromID.RemoteID, mboxToID.RemoteID) if err != nil { @@ -564,7 +562,7 @@ func (state *State) actionMoveMessages( func (state *State) actionAddMessageFlags( ctx context.Context, - tx *ent.Tx, + tx db.Transaction, messages []snapMsgWithSeq, addFlags imap.FlagSet, ) ([]Update, error) { @@ -577,7 +575,7 @@ func (state *State) actionAddMessageFlags( func (state *State) actionRemoveMessageFlags( ctx context.Context, - tx *ent.Tx, + tx db.Transaction, messages []snapMsgWithSeq, remFlags imap.FlagSet, ) ([]Update, error) { @@ -589,7 +587,7 @@ func (state *State) actionRemoveMessageFlags( } func (state *State) actionSetMessageFlags(ctx context.Context, - tx *ent.Tx, + tx db.Transaction, messages []snapMsgWithSeq, setFlags imap.FlagSet) ([]Update, error) { if setFlags.ContainsUnchecked(imap.FlagRecentLowerCase) { diff --git a/internal/state/mailbox.go b/internal/state/mailbox.go index f3ce0cc8..3746584e 100644 --- a/internal/state/mailbox.go +++ b/internal/state/mailbox.go @@ -8,10 +8,9 @@ import ( "time" "github.com/ProtonMail/gluon/connector" + "github.com/ProtonMail/gluon/db" "github.com/ProtonMail/gluon/imap" "github.com/ProtonMail/gluon/imap/command" - "github.com/ProtonMail/gluon/internal/db" - "github.com/ProtonMail/gluon/internal/db/ent" "github.com/ProtonMail/gluon/internal/ids" "github.com/ProtonMail/gluon/internal/response" "github.com/ProtonMail/gluon/reporter" @@ -21,7 +20,7 @@ import ( ) type Mailbox struct { - id ids.MailboxIDPair + id db.MailboxIDPair name string subscribed bool uidValidity imap.UID @@ -40,9 +39,9 @@ type AppendOnlyMailbox interface { UIDValidity() imap.UID } -func newMailbox(mbox *ent.Mailbox, state *State, snap *snapshot) *Mailbox { +func newMailbox(mbox *db.Mailbox, state *State, snap *snapshot) *Mailbox { return &Mailbox{ - id: ids.NewMailboxIDPair(mbox), + id: db.NewMailboxIDPair(mbox), name: mbox.Name, uidValidity: mbox.UIDValidity, uidNext: mbox.UIDNext, @@ -85,20 +84,20 @@ func (m *Mailbox) Count() int { } func (m *Mailbox) Flags(ctx context.Context) (imap.FlagSet, error) { - return stateDBReadResult(ctx, m.state, func(ctx context.Context, client *ent.Client) (imap.FlagSet, error) { - return db.GetMailboxFlags(ctx, client, m.id.InternalID) + return stateDBReadResult(ctx, m.state, func(ctx context.Context, client db.ReadOnly) (imap.FlagSet, error) { + return client.GetMailboxFlags(ctx, m.id.InternalID) }) } func (m *Mailbox) PermanentFlags(ctx context.Context) (imap.FlagSet, error) { - return stateDBReadResult(ctx, m.state, func(ctx context.Context, client *ent.Client) (imap.FlagSet, error) { - return db.GetMailboxPermanentFlags(ctx, client, m.id.InternalID) + return stateDBReadResult(ctx, m.state, func(ctx context.Context, client db.ReadOnly) (imap.FlagSet, error) { + return client.GetMailboxPermanentFlags(ctx, m.id.InternalID) }) } func (m *Mailbox) Attributes(ctx context.Context) (imap.FlagSet, error) { - return stateDBReadResult(ctx, m.state, func(ctx context.Context, client *ent.Client) (imap.FlagSet, error) { - return db.GetMailboxAttributes(ctx, client, m.id.InternalID) + return stateDBReadResult(ctx, m.state, func(ctx context.Context, client db.ReadOnly) (imap.FlagSet, error) { + return client.GetMailboxAttributes(ctx, m.id.InternalID) }) } @@ -147,8 +146,8 @@ func (m *Mailbox) GetMessagesWithoutFlagCount(flag string) int { } func (m *Mailbox) AppendRegular(ctx context.Context, literal []byte, flags imap.FlagSet, date time.Time) (imap.UID, error) { - if err := stateDBRead(ctx, m.state, func(ctx context.Context, client *ent.Client) error { - if messageCount, uid, err := db.GetMailboxMessageCountAndUID(ctx, client, m.snap.mboxID.InternalID); err != nil { + if err := stateDBRead(ctx, m.state, func(ctx context.Context, client db.ReadOnly) error { + if messageCount, uid, err := client.GetMailboxMessageCountAndUID(ctx, m.snap.mboxID.InternalID); err != nil { return err } else { if err := m.state.imapLimits.CheckMailBoxMessageCount(messageCount, 1); err != nil { @@ -185,30 +184,21 @@ func (m *Mailbox) AppendRegular(ctx context.Context, literal []byte, flags imap. return 0, err } - if message, err := stateDBReadResult(ctx, m.state, func(ctx context.Context, client *ent.Client) (*ent.Message, error) { - message, err := db.GetMessageWithIDWithDeletedFlag(ctx, client, msgID) - if err != nil { - if ent.IsNotFound(err) { - return nil, nil - } - - return nil, err - } - - return message, nil - }); err != nil || message == nil { + if messageDeleted, err := stateDBReadResult(ctx, m.state, func(ctx context.Context, client db.ReadOnly) (bool, error) { + return client.GetMessageDeletedFlag(ctx, msgID) + }); err != nil { logrus.WithError(err).Warn("The message has an unknown internal ID") - } else if !message.Deleted { + } else if !messageDeleted { logrus.Debugf("Appending duplicate message with Internal ID:%v", msgID.ShortID()) // Only shuffle around messages that haven't been marked for deletion. - if res, err := stateDBWriteResult(ctx, m.state, func(ctx context.Context, tx *ent.Tx) ([]Update, []db.UIDWithFlags, error) { - remoteID, err := db.GetMessageRemoteIDFromID(ctx, tx.Client(), msgID) + if res, err := stateDBWriteResult(ctx, m.state, func(ctx context.Context, tx db.Transaction) ([]Update, []db.UIDWithFlags, error) { + remoteID, err := tx.GetMessageRemoteID(ctx, msgID) if err != nil { return nil, nil, err } return m.state.actionAddMessagesToMailbox(ctx, tx, - []ids.MessageIDPair{{InternalID: msgID, RemoteID: remoteID}}, + []db.MessageIDPair{{InternalID: msgID, RemoteID: remoteID}}, m.id, m.snap == m.state.snap, ) @@ -229,7 +219,7 @@ func (m *Mailbox) AppendRegular(ctx context.Context, literal []byte, flags imap. } } - return stateDBWriteResult(ctx, m.state, func(ctx context.Context, tx *ent.Tx) ([]Update, imap.UID, error) { + return stateDBWriteResult(ctx, m.state, func(ctx context.Context, tx db.Transaction) ([]Update, imap.UID, error) { return m.state.actionCreateMessage(ctx, tx, m.snap.mboxID, literal, flags, date, m.snap == m.state.snap, appendIntoDrafts) }) } @@ -245,7 +235,7 @@ func (m *Mailbox) Append(ctx context.Context, literal []byte, flags imap.FlagSet } // Failed to append to mailbox attempt to insert into recovery mailbox. - knownMessage, recoverErr := stateDBWriteResult(ctx, m.state, func(ctx context.Context, tx *ent.Tx) ([]Update, bool, error) { + knownMessage, recoverErr := stateDBWriteResult(ctx, m.state, func(ctx context.Context, tx db.Transaction) ([]Update, bool, error) { return m.state.actionCreateRecoveredMessage(ctx, tx, literal, flags, date) }) if recoverErr != nil && !knownMessage { @@ -269,8 +259,8 @@ func (m *Mailbox) Copy(ctx context.Context, seq []command.SeqRange, name string) return nil, ErrOperationNotAllowed } - mbox, err := stateDBReadResult(ctx, m.state, func(ctx context.Context, client *ent.Client) (*ent.Mailbox, error) { - return db.GetMailboxByName(ctx, client, name) + mbox, err := stateDBReadResult(ctx, m.state, func(ctx context.Context, client db.ReadOnly) (*db.Mailbox, error) { + return client.GetMailboxByName(ctx, name) }) if err != nil { return nil, ErrNoSuchMailbox @@ -281,7 +271,7 @@ func (m *Mailbox) Copy(ctx context.Context, seq []command.SeqRange, name string) return nil, err } - msgIDs := make([]ids.MessageIDPair, len(messages)) + msgIDs := make([]db.MessageIDPair, len(messages)) msgUIDs := make([]imap.UID, len(messages)) for i := 0; i < len(messages); i++ { @@ -290,11 +280,11 @@ func (m *Mailbox) Copy(ctx context.Context, seq []command.SeqRange, name string) msgIDs[i] = snapMsg.ID } - destUIDs, err := stateDBWriteResult(ctx, m.state, func(ctx context.Context, tx *ent.Tx) ([]Update, []db.UIDWithFlags, error) { + destUIDs, err := stateDBWriteResult(ctx, m.state, func(ctx context.Context, tx db.Transaction) ([]Update, []db.UIDWithFlags, error) { if m.state.user.GetRecoveryMailboxID().InternalID == m.snap.mboxID.InternalID { - return m.state.actionCopyMessagesOutOfRecoveryMailbox(ctx, tx, msgIDs, ids.NewMailboxIDPair(mbox)) + return m.state.actionCopyMessagesOutOfRecoveryMailbox(ctx, tx, msgIDs, db.NewMailboxIDPair(mbox)) } else { - return m.state.actionAddMessagesToMailbox(ctx, tx, msgIDs, ids.NewMailboxIDPair(mbox), m.snap == m.state.snap) + return m.state.actionAddMessagesToMailbox(ctx, tx, msgIDs, db.NewMailboxIDPair(mbox), m.snap == m.state.snap) } }) if err != nil { @@ -320,8 +310,8 @@ func (m *Mailbox) Move(ctx context.Context, seq []command.SeqRange, name string) return nil, ErrOperationNotAllowed } - mbox, err := stateDBReadResult(ctx, m.state, func(ctx context.Context, client *ent.Client) (*ent.Mailbox, error) { - return db.GetMailboxByName(ctx, client, name) + mbox, err := stateDBReadResult(ctx, m.state, func(ctx context.Context, client db.ReadOnly) (*db.Mailbox, error) { + return client.GetMailboxByName(ctx, name) }) if err != nil { return nil, ErrNoSuchMailbox @@ -332,7 +322,7 @@ func (m *Mailbox) Move(ctx context.Context, seq []command.SeqRange, name string) return nil, err } - msgIDs := make([]ids.MessageIDPair, len(messages)) + msgIDs := make([]db.MessageIDPair, len(messages)) msgUIDs := make([]imap.UID, len(messages)) for i := 0; i < len(messages); i++ { @@ -341,11 +331,11 @@ func (m *Mailbox) Move(ctx context.Context, seq []command.SeqRange, name string) msgIDs[i] = snapMsg.ID } - destUIDs, err := stateDBWriteResult(ctx, m.state, func(ctx context.Context, tx *ent.Tx) ([]Update, []db.UIDWithFlags, error) { + destUIDs, err := stateDBWriteResult(ctx, m.state, func(ctx context.Context, tx db.Transaction) ([]Update, []db.UIDWithFlags, error) { if m.state.user.GetRecoveryMailboxID().InternalID == m.snap.mboxID.InternalID { - return m.state.actionMoveMessagesOutOfRecoveryMailbox(ctx, tx, msgIDs, ids.NewMailboxIDPair(mbox)) + return m.state.actionMoveMessagesOutOfRecoveryMailbox(ctx, tx, msgIDs, db.NewMailboxIDPair(mbox)) } else { - return m.state.actionMoveMessages(ctx, tx, msgIDs, m.snap.mboxID, ids.NewMailboxIDPair(mbox)) + return m.state.actionMoveMessages(ctx, tx, msgIDs, m.snap.mboxID, db.NewMailboxIDPair(mbox)) } }) if err != nil { @@ -369,7 +359,7 @@ func (m *Mailbox) Store(ctx context.Context, seqSet []command.SeqRange, action c return err } - return stateDBWrite(ctx, m.state, func(ctx context.Context, tx *ent.Tx) ([]Update, error) { + return stateDBWrite(ctx, m.state, func(ctx context.Context, tx db.Transaction) ([]Update, error) { switch action { case command.StoreActionAddFlags: return m.state.actionAddMessageFlags(ctx, tx, messages, flags) @@ -386,7 +376,7 @@ func (m *Mailbox) Store(ctx context.Context, seqSet []command.SeqRange, action c } func (m *Mailbox) Expunge(ctx context.Context, seq []command.SeqRange) error { - var msgIDs []ids.MessageIDPair + var msgIDs []db.MessageIDPair if seq != nil { snapMsgs, err := m.snap.getMessagesInRange(ctx, seq) @@ -394,7 +384,7 @@ func (m *Mailbox) Expunge(ctx context.Context, seq []command.SeqRange) error { return err } - msgIDs = make([]ids.MessageIDPair, 0, len(snapMsgs)) + msgIDs = make([]db.MessageIDPair, 0, len(snapMsgs)) for _, v := range snapMsgs { if v.toExpunge { @@ -405,7 +395,7 @@ func (m *Mailbox) Expunge(ctx context.Context, seq []command.SeqRange) error { msgIDs = m.snap.getAllMessagesIDsMarkedDelete() } - return stateDBWrite(ctx, m.state, func(ctx context.Context, tx *ent.Tx) ([]Update, error) { + return stateDBWrite(ctx, m.state, func(ctx context.Context, tx db.Transaction) ([]Update, error) { return m.state.actionRemoveMessagesFromMailbox(ctx, tx, msgIDs, m.snap.mboxID) }) } diff --git a/internal/state/mailbox_fetch.go b/internal/state/mailbox_fetch.go index fc2e8134..ec6bef85 100644 --- a/internal/state/mailbox_fetch.go +++ b/internal/state/mailbox_fetch.go @@ -9,11 +9,10 @@ import ( "sync/atomic" "github.com/ProtonMail/gluon/async" + "github.com/ProtonMail/gluon/db" "github.com/ProtonMail/gluon/imap" "github.com/ProtonMail/gluon/imap/command" "github.com/ProtonMail/gluon/internal/contexts" - "github.com/ProtonMail/gluon/internal/db" - "github.com/ProtonMail/gluon/internal/db/ent" "github.com/ProtonMail/gluon/internal/response" "github.com/ProtonMail/gluon/rfc822" "github.com/bradenaw/juniper/parallel" @@ -28,7 +27,7 @@ func (m *Mailbox) Fetch(ctx context.Context, cmd *command.Fetch, ch chan respons return err } - operations := make([]func(snapMsgWithSeq, *ent.Message, []byte) (response.Item, error), 0, len(cmd.Attributes)) + operations := make([]func(snapMsgWithSeq, *db.Message, []byte) (response.Item, error), 0, len(cmd.Attributes)) var ( needsLiteral bool @@ -84,7 +83,7 @@ func (m *Mailbox) Fetch(ctx context.Context, cmd *command.Fetch, ch chan respons setSeen = true } - op := func(_ snapMsgWithSeq, _ *ent.Message, literal []byte) (response.Item, error) { + op := func(_ snapMsgWithSeq, _ *db.Message, literal []byte) (response.Item, error) { return fetchAttributeBodySection(attribute, literal) } @@ -118,8 +117,8 @@ func (m *Mailbox) Fetch(ctx context.Context, cmd *command.Fetch, ch chan respons defer async.HandlePanic(m.state.panicHandler) msg := snapMessages[i] - message, err := stateDBReadResult(ctx, m.state, func(ctx context.Context, client *ent.Client) (*ent.Message, error) { - return db.GetMessage(ctx, client, msg.ID.InternalID) + message, err := stateDBReadResult(ctx, m.state, func(ctx context.Context, client db.ReadOnly) (*db.Message, error) { + return client.GetMessage(ctx, msg.ID.InternalID) }) if err != nil { return err @@ -175,7 +174,7 @@ func (m *Mailbox) Fetch(ctx context.Context, cmd *command.Fetch, ch chan respons }) if len(msgsToBeMarkedSeen) != 0 { - if err := stateDBWrite(ctx, m.state, func(ctx context.Context, tx *ent.Tx) ([]Update, error) { + if err := stateDBWrite(ctx, m.state, func(ctx context.Context, tx db.Transaction) ([]Update, error) { return m.state.actionAddMessageFlags(ctx, tx, msgsToBeMarkedSeen, imap.NewFlagSet(imap.FlagSeen)) }); err != nil { return err @@ -185,47 +184,47 @@ func (m *Mailbox) Fetch(ctx context.Context, cmd *command.Fetch, ch chan respons return nil } -func fetchEnvelope(_ snapMsgWithSeq, message *ent.Message, _ []byte) (response.Item, error) { +func fetchEnvelope(_ snapMsgWithSeq, message *db.Message, _ []byte) (response.Item, error) { return response.ItemEnvelope(message.Envelope), nil } -func fetchFlags(msg snapMsgWithSeq, message *ent.Message, _ []byte) (response.Item, error) { +func fetchFlags(msg snapMsgWithSeq, message *db.Message, _ []byte) (response.Item, error) { return response.ItemFlags(msg.flags), nil } -func fetchInternalDate(_ snapMsgWithSeq, message *ent.Message, _ []byte) (response.Item, error) { +func fetchInternalDate(_ snapMsgWithSeq, message *db.Message, _ []byte) (response.Item, error) { return response.ItemInternalDate(message.Date), nil } -func fetchRFC822(_ snapMsgWithSeq, _ *ent.Message, literal []byte) (response.Item, error) { +func fetchRFC822(_ snapMsgWithSeq, _ *db.Message, literal []byte) (response.Item, error) { return response.ItemRFC822Literal(literal), nil } -func fetchRFC822Header(_ snapMsgWithSeq, _ *ent.Message, literal []byte) (response.Item, error) { +func fetchRFC822Header(_ snapMsgWithSeq, _ *db.Message, literal []byte) (response.Item, error) { section := rfc822.Parse(literal) return response.ItemRFC822Header(section.Header()), nil } -func fetchRFC822Size(_ snapMsgWithSeq, message *ent.Message, _ []byte) (response.Item, error) { +func fetchRFC822Size(_ snapMsgWithSeq, message *db.Message, _ []byte) (response.Item, error) { return response.ItemRFC822Size(message.Size), nil } -func fetchRFC822Text(_ snapMsgWithSeq, _ *ent.Message, literal []byte) (response.Item, error) { +func fetchRFC822Text(_ snapMsgWithSeq, _ *db.Message, literal []byte) (response.Item, error) { section := rfc822.Parse(literal) return response.ItemRFC822Text(section.Body()), nil } -func fetchBody(_ snapMsgWithSeq, message *ent.Message, _ []byte) (response.Item, error) { +func fetchBody(_ snapMsgWithSeq, message *db.Message, _ []byte) (response.Item, error) { return response.ItemBody(message.Body), nil } -func fetchBodyStructure(_ snapMsgWithSeq, message *ent.Message, _ []byte) (response.Item, error) { +func fetchBodyStructure(_ snapMsgWithSeq, message *db.Message, _ []byte) (response.Item, error) { return response.ItemBodyStructure(message.BodyStructure), nil } -func fetchUID(msg snapMsgWithSeq, _ *ent.Message, _ []byte) (response.Item, error) { +func fetchUID(msg snapMsgWithSeq, _ *db.Message, _ []byte) (response.Item, error) { return response.ItemUID(msg.UID), nil } diff --git a/internal/state/mailbox_search.go b/internal/state/mailbox_search.go index 038a18d6..a35b5841 100644 --- a/internal/state/mailbox_search.go +++ b/internal/state/mailbox_search.go @@ -10,11 +10,10 @@ import ( "time" "github.com/ProtonMail/gluon/async" + "github.com/ProtonMail/gluon/db" "github.com/ProtonMail/gluon/imap" "github.com/ProtonMail/gluon/imap/command" "github.com/ProtonMail/gluon/internal/contexts" - "github.com/ProtonMail/gluon/internal/db" - "github.com/ProtonMail/gluon/internal/db/ent" "github.com/ProtonMail/gluon/rfc5322" "github.com/ProtonMail/gluon/rfc822" "github.com/bradenaw/juniper/parallel" @@ -88,15 +87,16 @@ func buildSearchData(ctx context.Context, m *Mailbox, op *buildSearchOpResult, m data := searchData{message: message} if op.needsMessage { - dbm, err := stateDBReadResult(ctx, m.state, func(ctx context.Context, client *ent.Client) (*ent.Message, error) { - return db.GetMessageDateAndSize(ctx, client, message.ID.InternalID) - }) - if err != nil { + if err := stateDBRead(ctx, m.state, func(ctx context.Context, client db.ReadOnly) error { + date, size, err := client.GetMessageDateAndSize(ctx, message.ID.InternalID) + + data.dbMessage.size = size + data.dbMessage.date = date + + return err + }); err != nil { return searchData{}, err } - - data.dbMessage.size = dbm.Size - data.dbMessage.date = dbm.Date } if op.needsLiteral { diff --git a/internal/state/match.go b/internal/state/match.go index 81848efe..fcf819ad 100644 --- a/internal/state/match.go +++ b/internal/state/match.go @@ -6,9 +6,8 @@ import ( "regexp" "strings" + "github.com/ProtonMail/gluon/db" "github.com/ProtonMail/gluon/imap" - "github.com/ProtonMail/gluon/internal/db" - "github.com/ProtonMail/gluon/internal/db/ent" "github.com/bradenaw/juniper/xslices" ) @@ -22,12 +21,12 @@ type matchMailbox struct { Name string Subscribed bool // EntMBox should be set to nil if there is no such value. - EntMBox *ent.Mailbox + EntMBox *db.Mailbox } func getMatches( ctx context.Context, - client *ent.Client, + client db.ReadOnly, allMailboxes []matchMailbox, ref, pattern, delimiter string, subscribed bool, @@ -79,7 +78,7 @@ func getMatches( func prepareMatch( ctx context.Context, - client *ent.Client, + client db.ReadOnly, matchedName string, mbox *matchMailbox, pattern, delimiter string, @@ -108,13 +107,13 @@ func prepareMatch( if mbox.EntMBox != nil { atts = imap.NewFlagSetFromSlice(xslices.Map( - mbox.EntMBox.Edges.Attributes, - func(flag *ent.MailboxAttr) string { + mbox.EntMBox.Attributes, + func(flag *db.MailboxAttr) string { return flag.Value }, )) - recent, err := db.GetMailboxRecentCount(ctx, client, mbox.EntMBox) + recent, err := client.GetMailboxRecentCount(ctx, mbox.EntMBox.ID) if err != nil { return Match{}, false, err } diff --git a/internal/state/responders.go b/internal/state/responders.go index 4c091d2c..1bba1e2f 100644 --- a/internal/state/responders.go +++ b/internal/state/responders.go @@ -5,11 +5,9 @@ import ( "fmt" "sync" + "github.com/ProtonMail/gluon/db" "github.com/ProtonMail/gluon/imap" "github.com/ProtonMail/gluon/internal/contexts" - "github.com/ProtonMail/gluon/internal/db" - "github.com/ProtonMail/gluon/internal/db/ent" - "github.com/ProtonMail/gluon/internal/ids" "github.com/ProtonMail/gluon/internal/response" "github.com/ProtonMail/gluon/reporter" "github.com/bradenaw/juniper/xslices" @@ -20,7 +18,7 @@ type responderStateUpdate struct { responders []Responder } -func (r *responderStateUpdate) Apply(ctx context.Context, tx *ent.Tx, s *State) error { +func (r *responderStateUpdate) Apply(ctx context.Context, tx db.Transaction, s *State) error { return s.PushResponder(ctx, tx, r.responders...) } @@ -37,7 +35,7 @@ func (r *responderStateUpdate) String() string { // Seeing as this is only used right now to clear the recent flags, we can avoid a lot of necessary database locking // and transaction overhead. type responderDBUpdate interface { - apply(ctx context.Context, tx *ent.Tx) error + apply(ctx context.Context, tx db.Transaction) error } func NewMailboxIDResponderStateUpdate(id imap.InternalMailboxID, responders ...Responder) Update { @@ -63,12 +61,12 @@ type Responder interface { } type exists struct { - messageID ids.MessageIDPair + messageID db.MessageIDPair messageUID imap.UID flags imap.FlagSet } -func newExists(messageID ids.MessageIDPair, messageUID imap.UID, flags imap.FlagSet) *exists { +func newExists(messageID db.MessageIDPair, messageUID imap.UID, flags imap.FlagSet) *exists { return &exists{messageID: messageID, messageUID: messageUID, flags: flags} } @@ -81,8 +79,8 @@ type clearRecentFlagRespUpdate struct { mboxID imap.InternalMailboxID } -func (u *clearRecentFlagRespUpdate) apply(ctx context.Context, tx *ent.Tx) error { - return db.ClearRecentFlag(ctx, tx, u.mboxID, u.messageID) +func (u *clearRecentFlagRespUpdate) apply(ctx context.Context, tx db.Transaction) error { + return tx.ClearRecentFlagInMailboxOnMessage(ctx, u.mboxID, u.messageID) } // targetedExists needs to be separate so that we update the targetStateID safely when doing concurrent updates @@ -197,7 +195,7 @@ func newExistsStateUpdateWithExists(mailboxID imap.InternalMailboxID, responders } } -func (e *ExistsStateUpdate) Apply(ctx context.Context, tx *ent.Tx, s *State) error { +func (e *ExistsStateUpdate) Apply(ctx context.Context, tx db.Transaction, s *State) error { // This check needs to be thread safe since we don't know when a state update // will be executed. Before each of these updates run we check whether at state // target ID has been set and update for the first state that manages to run this code diff --git a/internal/state/snapshot.go b/internal/state/snapshot.go index c38e3679..8e497698 100644 --- a/internal/state/snapshot.go +++ b/internal/state/snapshot.go @@ -4,37 +4,35 @@ import ( "context" "strings" + "github.com/ProtonMail/gluon/db" "github.com/ProtonMail/gluon/imap" "github.com/ProtonMail/gluon/imap/command" "github.com/ProtonMail/gluon/internal/contexts" - "github.com/ProtonMail/gluon/internal/db" - "github.com/ProtonMail/gluon/internal/db/ent" - "github.com/ProtonMail/gluon/internal/ids" "github.com/bradenaw/juniper/xslices" ) type snapshot struct { - mboxID ids.MailboxIDPair + mboxID db.MailboxIDPair state *State messages *snapMsgList } -func newSnapshot(ctx context.Context, state *State, client *ent.Client, mbox *ent.Mailbox) (*snapshot, error) { - snapshotMessages, err := db.GetMailboxMessagesForNewSnapshot(ctx, client, mbox.ID) +func newSnapshot(ctx context.Context, state *State, client db.ReadOnly, mbox *db.Mailbox) (*snapshot, error) { + snapshotMessages, err := client.GetMailboxMessageForNewSnapshot(ctx, mbox.ID) if err != nil { return nil, err } snap := &snapshot{ - mboxID: ids.NewMailboxIDPair(mbox), + mboxID: db.NewMailboxIDPair(mbox), state: state, messages: newMsgList(len(snapshotMessages)), } for _, snapshotMessage := range snapshotMessages { if err := snap.messages.insert( - ids.MessageIDPair{InternalID: snapshotMessage.InternalID, RemoteID: snapshotMessage.RemoteID}, + db.MessageIDPair{InternalID: snapshotMessage.InternalID, RemoteID: snapshotMessage.RemoteID}, snapshotMessage.UID, snapshotMessage.GetFlagSet(), ); err != nil { @@ -45,9 +43,9 @@ func newSnapshot(ctx context.Context, state *State, client *ent.Client, mbox *en return snap, nil } -func newEmptySnapshot(state *State, mbox *ent.Mailbox) *snapshot { +func newEmptySnapshot(state *State, mbox *db.Mailbox) *snapshot { return &snapshot{ - mboxID: ids.NewMailboxIDPair(mbox), + mboxID: db.NewMailboxIDPair(mbox), state: state, messages: newMsgList(0), } @@ -119,14 +117,14 @@ func (snap *snapshot) getAllMessages() []snapMsgWithSeq { return result } -func (snap *snapshot) getAllMessageIDs() []ids.MessageIDPair { - return xslices.Map(snap.messages.all(), func(msg *snapMsg) ids.MessageIDPair { +func (snap *snapshot) getAllMessageIDs() []db.MessageIDPair { + return xslices.Map(snap.messages.all(), func(msg *snapMsg) db.MessageIDPair { return msg.ID }) } -func (snap *snapshot) getAllMessagesIDsMarkedDelete() []ids.MessageIDPair { - var msgs []ids.MessageIDPair +func (snap *snapshot) getAllMessagesIDsMarkedDelete() []db.MessageIDPair { + var msgs []db.MessageIDPair for _, v := range snap.messages.all() { if v.toExpunge { @@ -237,7 +235,7 @@ func (snap *snapshot) getMessagesWithoutFlagCount(flag string) int { }) } -func (snap *snapshot) appendMessage(messageID ids.MessageIDPair, uid imap.UID, flags imap.FlagSet) error { +func (snap *snapshot) appendMessage(messageID db.MessageIDPair, uid imap.UID, flags imap.FlagSet) error { return snap.messages.insert( messageID, uid, @@ -245,7 +243,7 @@ func (snap *snapshot) appendMessage(messageID ids.MessageIDPair, uid imap.UID, f ) } -func (snap *snapshot) appendMessageFromOtherState(messageID ids.MessageIDPair, uid imap.UID, flags imap.FlagSet) error { +func (snap *snapshot) appendMessageFromOtherState(messageID db.MessageIDPair, uid imap.UID, flags imap.FlagSet) error { snap.messages.insertOutOfOrder( messageID, uid, diff --git a/internal/state/snapshot_messages.go b/internal/state/snapshot_messages.go index dcfc1b4e..a73dc2ec 100644 --- a/internal/state/snapshot_messages.go +++ b/internal/state/snapshot_messages.go @@ -3,9 +3,9 @@ package state import ( "fmt" + "github.com/ProtonMail/gluon/db" "github.com/ProtonMail/gluon/imap" "github.com/ProtonMail/gluon/imap/command" - "github.com/ProtonMail/gluon/internal/ids" "github.com/bradenaw/juniper/xslices" "golang.org/x/exp/slices" ) @@ -14,7 +14,7 @@ var ErrOutOfOrderUIDInsertion = fmt.Errorf("UIDs must be strictly ascending") // snapMsg is a single message inside a snapshot. type snapMsg struct { - ID ids.MessageIDPair + ID db.MessageIDPair UID imap.UID flags imap.FlagSet toExpunge bool @@ -48,7 +48,7 @@ func (list *snapMsgList) binarySearchByUID(uid imap.UID) (int, bool) { return index, ok } -func (list *snapMsgList) insert(msgID ids.MessageIDPair, msgUID imap.UID, flags imap.FlagSet) error { +func (list *snapMsgList) insert(msgID db.MessageIDPair, msgUID imap.UID, flags imap.FlagSet) error { if len(list.msg) > 0 && list.msg[len(list.msg)-1].UID >= msgUID { return fmt.Errorf("UID-Last=%v UID-Msg=%v: %w", list.msg[len(list.msg)-1].UID, msgUID, ErrOutOfOrderUIDInsertion) } @@ -67,7 +67,7 @@ func (list *snapMsgList) insert(msgID ids.MessageIDPair, msgUID imap.UID, flags return nil } -func (list *snapMsgList) insertOutOfOrder(msgID ids.MessageIDPair, msgUID imap.UID, flags imap.FlagSet) { +func (list *snapMsgList) insertOutOfOrder(msgID db.MessageIDPair, msgUID imap.UID, flags imap.FlagSet) { index, ok := list.binarySearchByUID(msgUID) if ok { panic("Duplicate UID added") diff --git a/internal/state/snapshot_messages_test.go b/internal/state/snapshot_messages_test.go index 40072ae0..314b4c5f 100644 --- a/internal/state/snapshot_messages_test.go +++ b/internal/state/snapshot_messages_test.go @@ -3,9 +3,9 @@ package state import ( "testing" + "github.com/ProtonMail/gluon/db" "github.com/ProtonMail/gluon/imap" "github.com/ProtonMail/gluon/imap/command" - "github.com/ProtonMail/gluon/internal/ids" "github.com/stretchr/testify/require" ) @@ -221,8 +221,8 @@ func TestSnapListGetMessages(t *testing.T) { } } -func messageIDPair(internalID imap.InternalMessageID, remoteID imap.MessageID) ids.MessageIDPair { - return ids.MessageIDPair{InternalID: internalID, RemoteID: remoteID} +func messageIDPair(internalID imap.InternalMessageID, remoteID imap.MessageID) db.MessageIDPair { + return db.MessageIDPair{InternalID: internalID, RemoteID: remoteID} } func must[T any](val T, ok bool) T { diff --git a/internal/state/state.go b/internal/state/state.go index 1ab6f6c6..b35d8ea3 100644 --- a/internal/state/state.go +++ b/internal/state/state.go @@ -8,9 +8,8 @@ import ( "sync/atomic" "github.com/ProtonMail/gluon/async" + "github.com/ProtonMail/gluon/db" "github.com/ProtonMail/gluon/imap" - "github.com/ProtonMail/gluon/internal/db" - "github.com/ProtonMail/gluon/internal/db/ent" "github.com/ProtonMail/gluon/internal/ids" "github.com/ProtonMail/gluon/internal/response" "github.com/ProtonMail/gluon/limits" @@ -77,20 +76,20 @@ func (state *State) UserID() string { } func (state *State) List(ctx context.Context, ref, pattern string, lsub bool, fn func(map[string]Match) error) error { - return stateDBRead(ctx, state, func(ctx context.Context, client *ent.Client) error { - mailboxes, err := db.GetAllMailboxes(ctx, client) + return stateDBRead(ctx, state, func(ctx context.Context, client db.ReadOnly) error { + mailboxes, err := client.GetAllMailboxes(ctx) if err != nil { return err } recoveryMailboxID := state.user.GetRecoveryMailboxID().InternalID - recoveryMBoxMessageCount, err := db.GetMailboxMessageCount(ctx, client, recoveryMailboxID) + recoveryMBoxMessageCount, err := client.GetMailboxMessageCount(ctx, recoveryMailboxID) if err != nil { logrus.WithError(err).Error("Failed to get recovery mailbox message count, assuming empty") recoveryMBoxMessageCount = 0 } - mailboxes = xslices.Filter(mailboxes, func(mailbox *ent.Mailbox) bool { + mailboxes = xslices.Filter(mailboxes, func(mailbox *db.Mailbox) bool { if mailbox.ID == recoveryMailboxID && recoveryMBoxMessageCount == 0 { return false } @@ -101,7 +100,7 @@ func (state *State) List(ctx context.Context, ref, pattern string, lsub bool, fn case imap.Visible: return true case imap.HiddenIfEmpty: - count, err := db.GetMailboxMessageCount(ctx, client, mailbox.ID) + count, err := client.GetMailboxMessageCount(ctx, mailbox.ID) if err != nil { logrus.WithError(err).Error("Failed to get recovery mailbox message count, assuming not empty") return true @@ -113,10 +112,10 @@ func (state *State) List(ctx context.Context, ref, pattern string, lsub bool, fn } }) - var deletedSubscriptions map[imap.MailboxID]*ent.DeletedSubscription + var deletedSubscriptions map[imap.MailboxID]*db.DeletedSubscription if lsub { - deletedSubscriptions, err = db.GetDeletedSubscriptionSet(ctx, client) + deletedSubscriptions, err = client.GetDeletedSubscriptionSet(ctx) if err != nil { return err } @@ -164,8 +163,8 @@ func (state *State) List(ctx context.Context, ref, pattern string, lsub bool, fn } func (state *State) Select(ctx context.Context, name string, fn func(*Mailbox) error) error { - mbox, err := stateDBReadResult(ctx, state, func(ctx context.Context, client *ent.Client) (*ent.Mailbox, error) { - return db.GetMailboxByName(ctx, client, name) + mbox, err := stateDBReadResult(ctx, state, func(ctx context.Context, client db.ReadOnly) (*db.Mailbox, error) { + return client.GetMailboxByName(ctx, name) }) if err != nil { return ErrNoSuchMailbox @@ -177,15 +176,15 @@ func (state *State) Select(ctx context.Context, name string, fn func(*Mailbox) e } } - snap, err := stateDBReadResult(ctx, state, func(ctx context.Context, client *ent.Client) (*snapshot, error) { + snap, err := stateDBReadResult(ctx, state, func(ctx context.Context, client db.ReadOnly) (*snapshot, error) { return newSnapshot(ctx, state, client, mbox) }) if err != nil { return err } - if err := stateDBWrite(ctx, state, func(ctx context.Context, tx *ent.Tx) ([]Update, error) { - return nil, db.ClearRecentFlags(ctx, tx, mbox.ID) + if err := stateDBWrite(ctx, state, func(ctx context.Context, tx db.Transaction) ([]Update, error) { + return nil, tx.ClearRecentFlagsInMailbox(ctx, mbox.ID) }); err != nil { return err } @@ -197,8 +196,8 @@ func (state *State) Select(ctx context.Context, name string, fn func(*Mailbox) e } func (state *State) Examine(ctx context.Context, name string, fn func(*Mailbox) error) error { - mbox, err := stateDBReadResult(ctx, state, func(ctx context.Context, client *ent.Client) (*ent.Mailbox, error) { - return db.GetMailboxByName(ctx, client, name) + mbox, err := stateDBReadResult(ctx, state, func(ctx context.Context, client db.ReadOnly) (*db.Mailbox, error) { + return client.GetMailboxByName(ctx, name) }) if err != nil { return ErrNoSuchMailbox @@ -210,7 +209,7 @@ func (state *State) Examine(ctx context.Context, name string, fn func(*Mailbox) } } - snap, err := stateDBReadResult(ctx, state, func(ctx context.Context, client *ent.Client) (*snapshot, error) { + snap, err := stateDBReadResult(ctx, state, func(ctx context.Context, client db.ReadOnly) (*snapshot, error) { return newSnapshot(ctx, state, client, mbox) }) if err != nil { @@ -247,10 +246,8 @@ func (state *State) Create(ctx context.Context, name string) error { } } - return stateDBWrite(ctx, state, func(ctx context.Context, tx *ent.Tx) ([]Update, error) { - client := tx.Client() - - if mailboxCount, err := db.GetMailboxCount(ctx, client); err != nil { + return stateDBWrite(ctx, state, func(ctx context.Context, tx db.Transaction) ([]Update, error) { + if mailboxCount, err := tx.GetMailboxCount(ctx); err != nil { return nil, err } else if err := state.imapLimits.CheckMailBoxCount(mailboxCount); err != nil { return nil, err @@ -263,14 +260,14 @@ func (state *State) Create(ctx context.Context, name string) error { name = strings.TrimRight(name, state.delimiter) } - if exists, err := db.MailboxExistsWithName(ctx, client, name); err != nil { + if exists, err := tx.MailboxExistsWithName(ctx, name); err != nil { return nil, err } else if exists { return nil, ErrExistingMailbox } for _, superior := range listSuperiors(name, state.delimiter) { - if exists, err := db.MailboxExistsWithName(ctx, client, superior); err != nil { + if exists, err := tx.MailboxExistsWithName(ctx, superior); err != nil { return nil, err } else if exists { continue @@ -297,13 +294,13 @@ func (state *State) Delete(ctx context.Context, name string) (bool, error) { return false, ErrOperationNotAllowed } - mboxID, err := stateDBWriteResult(ctx, state, func(ctx context.Context, tx *ent.Tx) ([]Update, imap.InternalMailboxID, error) { - mbox, err := db.GetMailboxByName(ctx, tx.Client(), name) + mboxID, err := stateDBWriteResult(ctx, state, func(ctx context.Context, tx db.Transaction) ([]Update, imap.InternalMailboxID, error) { + mbox, err := tx.GetMailboxByName(ctx, name) if err != nil { return nil, 0, ErrNoSuchMailbox } - update, err := state.actionDeleteMailbox(ctx, tx, ids.NewMailboxIDPair(mbox)) + update, err := state.actionDeleteMailbox(ctx, tx, db.NewMailboxIDPair(mbox)) if err != nil { return nil, 0, err } @@ -319,7 +316,7 @@ func (state *State) Delete(ctx context.Context, name string) (bool, error) { func (state *State) Rename(ctx context.Context, oldName, newName string) error { type Result struct { - MBox *ent.Mailbox + MBox *db.Mailbox MBoxesToCreate []string } @@ -327,15 +324,13 @@ func (state *State) Rename(ctx context.Context, oldName, newName string) error { return ErrOperationNotAllowed } - return stateDBWrite(ctx, state, func(ctx context.Context, tx *ent.Tx) ([]Update, error) { - client := tx.Client() - - mbox, err := db.GetMailboxByName(ctx, client, oldName) + return stateDBWrite(ctx, state, func(ctx context.Context, tx db.Transaction) ([]Update, error) { + mbox, err := tx.GetMailboxByName(ctx, oldName) if err != nil { return nil, ErrNoSuchMailbox } - if exists, err := db.MailboxExistsWithName(ctx, client, newName); err != nil { + if exists, err := tx.MailboxExistsWithName(ctx, newName); err != nil { return nil, err } else if exists { return nil, ErrExistingMailbox @@ -343,7 +338,7 @@ func (state *State) Rename(ctx context.Context, oldName, newName string) error { var mboxesToCreate []string for _, superior := range listSuperiors(newName, state.delimiter) { - if exists, err := db.MailboxExistsWithName(ctx, client, superior); err != nil { + if exists, err := tx.MailboxExistsWithName(ctx, superior); err != nil { return nil, err } else if exists { if superior == oldName { @@ -366,7 +361,7 @@ func (state *State) Rename(ctx context.Context, oldName, newName string) error { return nil, err } - if err := db.CreateMailboxIfNotExists(ctx, tx, res, state.delimiter, uidValidity); err != nil { + if err := tx.CreateMailboxIfNotExists(ctx, res, state.delimiter, uidValidity); err != nil { return nil, err } } @@ -380,24 +375,24 @@ func (state *State) Rename(ctx context.Context, oldName, newName string) error { } // Locally update all inferiors so we don't wait for update - mailboxes, err := db.GetAllMailboxes(ctx, tx.Client()) + mailboxes, err := tx.GetAllMailboxes(ctx) if err != nil { return nil, err } - inferiors := listInferiors(oldName, state.delimiter, xslices.Map(mailboxes, func(mailbox *ent.Mailbox) string { + inferiors := listInferiors(oldName, state.delimiter, xslices.Map(mailboxes, func(mailbox *db.Mailbox) string { return mailbox.Name })) for _, inferior := range inferiors { - mbox, err := db.GetMailboxByName(ctx, tx.Client(), inferior) + mbox, err := tx.GetMailboxByName(ctx, inferior) if err != nil { return nil, ErrNoSuchMailbox } newInferior := newName + strings.TrimPrefix(inferior, oldName) - if err := db.RenameMailboxWithRemoteID(ctx, tx, mbox.RemoteID, newInferior); err != nil { + if err := tx.RenameMailboxWithRemoteID(ctx, mbox.RemoteID, newInferior); err != nil { return nil, err } } @@ -407,8 +402,8 @@ func (state *State) Rename(ctx context.Context, oldName, newName string) error { } func (state *State) Subscribe(ctx context.Context, name string) error { - return stateDBWrite(ctx, state, func(ctx context.Context, tx *ent.Tx) ([]Update, error) { - mbox, err := db.GetMailboxByName(ctx, tx.Client(), name) + return stateDBWrite(ctx, state, func(ctx context.Context, tx db.Transaction) ([]Update, error) { + mbox, err := tx.GetMailboxByName(ctx, name) if err != nil { return nil, ErrNoSuchMailbox } @@ -417,16 +412,16 @@ func (state *State) Subscribe(ctx context.Context, name string) error { return nil, ErrAlreadySubscribed } - return nil, mbox.Update().SetSubscribed(true).Exec(ctx) + return nil, tx.SetMailboxSubscribed(ctx, mbox.ID, true) }) } func (state *State) Unsubscribe(ctx context.Context, name string) error { - return stateDBWrite(ctx, state, func(ctx context.Context, tx *ent.Tx) ([]Update, error) { - mbox, err := db.GetMailboxByName(ctx, tx.Client(), name) + return stateDBWrite(ctx, state, func(ctx context.Context, tx db.Transaction) ([]Update, error) { + mbox, err := tx.GetMailboxByName(ctx, name) if err != nil { // If mailbox does not exist, check that if it is present in the deleted subscription table - if count, err := db.RemoveDeletedSubscriptionWithName(ctx, tx, name); err != nil { + if count, err := tx.RemoveDeletedSubscriptionWithName(ctx, name); err != nil { return nil, err } else if count == 0 { return nil, ErrNoSuchMailbox @@ -439,7 +434,7 @@ func (state *State) Unsubscribe(ctx context.Context, name string) error { return nil, ErrAlreadyUnsubscribed } - return nil, mbox.Update().SetSubscribed(false).Exec(ctx) + return nil, tx.SetMailboxSubscribed(ctx, mbox.ID, false) }) } @@ -455,8 +450,8 @@ func (state *State) Idle(ctx context.Context, fn func([]response.Response, chan } func (state *State) Mailbox(ctx context.Context, name string, fn func(*Mailbox) error) error { - mbox, err := stateDBReadResult(ctx, state, func(ctx context.Context, client *ent.Client) (*ent.Mailbox, error) { - return db.GetMailboxByName(ctx, client, name) + mbox, err := stateDBReadResult(ctx, state, func(ctx context.Context, client db.ReadOnly) (*db.Mailbox, error) { + return client.GetMailboxByName(ctx, name) }) if err != nil { return ErrNoSuchMailbox @@ -466,7 +461,7 @@ func (state *State) Mailbox(ctx context.Context, name string, fn func(*Mailbox) return fn(newMailbox(mbox, state, state.snap)) } - snap, err := stateDBReadResult(ctx, state, func(ctx context.Context, client *ent.Client) (*snapshot, error) { + snap, err := stateDBReadResult(ctx, state, func(ctx context.Context, client db.ReadOnly) (*snapshot, error) { return newSnapshot(ctx, state, client, mbox) }) if err != nil { @@ -484,8 +479,8 @@ func (state *State) AppendOnlyMailbox(ctx context.Context, name string, fn func( return ErrOperationNotAllowed } - mbox, err := stateDBReadResult(ctx, state, func(ctx context.Context, client *ent.Client) (*ent.Mailbox, error) { - return db.GetMailboxByName(ctx, client, name) + mbox, err := stateDBReadResult(ctx, state, func(ctx context.Context, client db.ReadOnly) (*db.Mailbox, error) { + return client.GetMailboxByName(ctx, name) }) if err != nil { return ErrNoSuchMailbox @@ -505,8 +500,8 @@ func (state *State) Selected(ctx context.Context, fn func(*Mailbox) error) error return ErrSessionNotSelected } - mbox, err := stateDBReadResult(ctx, state, func(ctx context.Context, client *ent.Client) (*ent.Mailbox, error) { - return db.GetMailboxByID(ctx, client, state.snap.mboxID.InternalID) + mbox, err := stateDBReadResult(ctx, state, func(ctx context.Context, client db.ReadOnly) (*db.Mailbox, error) { + return client.GetMailboxByID(ctx, state.snap.mboxID.InternalID) }) if err != nil { return ErrNoSuchMailbox @@ -572,7 +567,7 @@ func (state *State) ApplyUpdate(ctx context.Context, update Update) error { return nil } - if err := state.user.GetDB().Write(ctx, func(ctx context.Context, tx *ent.Tx) error { + if err := state.user.GetDB().Write(ctx, func(ctx context.Context, tx db.Transaction) error { return update.Apply(ctx, tx, state) }); err != nil { reporter.MessageWithContext(ctx, @@ -599,7 +594,7 @@ func (state *State) markInvalid() { } // renameInbox creates a new mailbox and moves everything there. -func (state *State) renameInbox(ctx context.Context, tx *ent.Tx, inbox *ent.Mailbox, newName string) ([]Update, error) { +func (state *State) renameInbox(ctx context.Context, tx db.Transaction, inbox *db.Mailbox, newName string) ([]Update, error) { uidValidity, err := state.user.GenerateUIDValidity() if err != nil { return nil, err @@ -610,14 +605,14 @@ func (state *State) renameInbox(ctx context.Context, tx *ent.Tx, inbox *ent.Mail return nil, err } - messageIDs, err := db.GetMailboxMessageIDPairs(ctx, tx.Client(), inbox.ID) + messageIDs, err := tx.GetMailboxMessageIDPairs(ctx, inbox.ID) if err != nil { return nil, err } - mboxIDPair := ids.NewMailboxIDPair(mbox) + mboxIDPair := db.NewMailboxIDPair(mbox) - updates, _, err := state.actionMoveMessages(ctx, tx, messageIDs, ids.NewMailboxIDPair(inbox), mboxIDPair) + updates, _, err := state.actionMoveMessages(ctx, tx, messageIDs, db.NewMailboxIDPair(inbox), mboxIDPair) if err != nil { return nil, err } @@ -644,7 +639,7 @@ func (state *State) endIdle() { state.idleCh = nil } -func (state *State) getLiteral(ctx context.Context, messageID ids.MessageIDPair) ([]byte, error) { +func (state *State) getLiteral(ctx context.Context, messageID db.MessageIDPair) ([]byte, error) { var literal []byte storeLiteral, firstErr := state.user.GetStore().Get(messageID.InternalID) @@ -710,7 +705,7 @@ func (state *State) flushResponses(ctx context.Context, permitExpunge bool) ([]r } } - if err := state.user.GetDB().Write(ctx, func(ctx context.Context, tx *ent.Tx) error { + if err := state.user.GetDB().Write(ctx, func(ctx context.Context, tx db.Transaction) error { for _, update := range dbUpdates { if err := update.apply(ctx, tx); err != nil { return err @@ -725,7 +720,7 @@ func (state *State) flushResponses(ctx context.Context, permitExpunge bool) ([]r return response.Merge(responses), nil } -func (state *State) PushResponder(ctx context.Context, tx *ent.Tx, responder ...Responder) error { +func (state *State) PushResponder(ctx context.Context, tx db.Transaction, responder ...Responder) error { if state.idleCh == nil { return state.queueResponder(responder...) } @@ -833,18 +828,18 @@ func (state *State) close() error { return nil } -func stateDBRead(ctx context.Context, state *State, fn func(context.Context, *ent.Client) error) error { +func stateDBRead(ctx context.Context, state *State, fn func(context.Context, db.ReadOnly) error) error { return state.user.GetDB().Read(ctx, fn) } -func stateDBReadResult[T any](ctx context.Context, state *State, fn func(context.Context, *ent.Client) (T, error)) (T, error) { - return db.ReadResult(ctx, state.user.GetDB(), fn) +func stateDBReadResult[T any](ctx context.Context, state *State, fn func(context.Context, db.ReadOnly) (T, error)) (T, error) { + return db.ClientReadType(ctx, state.user.GetDB(), fn) } -func stateDBWrite(ctx context.Context, state *State, fn func(context.Context, *ent.Tx) ([]Update, error)) error { +func stateDBWrite(ctx context.Context, state *State, fn func(context.Context, db.Transaction) ([]Update, error)) error { var updates []Update - if err := state.user.GetDB().Write(ctx, func(ctx context.Context, tx *ent.Tx) error { + if err := state.user.GetDB().Write(ctx, func(ctx context.Context, tx db.Transaction) error { up, err := fn(ctx, tx) updates = up return err @@ -854,7 +849,7 @@ func stateDBWrite(ctx context.Context, state *State, fn func(context.Context, *e // need to create a separate transaction for the state updates so that import changes get written first. if len(updates) != 0 { - if err := state.user.GetDB().Write(ctx, func(ctx context.Context, tx *ent.Tx) error { + if err := state.user.GetDB().Write(ctx, func(ctx context.Context, tx db.Transaction) error { return state.user.QueueOrApplyStateUpdate(ctx, tx, updates...) }); err != nil { return err @@ -864,10 +859,10 @@ func stateDBWrite(ctx context.Context, state *State, fn func(context.Context, *e return nil } -func stateDBWriteResult[T any](ctx context.Context, state *State, fn func(context.Context, *ent.Tx) ([]Update, T, error)) (T, error) { +func stateDBWriteResult[T any](ctx context.Context, state *State, fn func(context.Context, db.Transaction) ([]Update, T, error)) (T, error) { var updates []Update - result, err := db.WriteResult(ctx, state.user.GetDB(), func(ctx context.Context, tx *ent.Tx) (T, error) { + result, err := db.ClientWriteType(ctx, state.user.GetDB(), func(ctx context.Context, tx db.Transaction) (T, error) { up, val, err := fn(ctx, tx) updates = up return val, err @@ -879,7 +874,7 @@ func stateDBWriteResult[T any](ctx context.Context, state *State, fn func(contex // need to create a separate transaction for the state updates so that import changes get written first. if len(updates) != 0 { - if err := state.user.GetDB().Write(ctx, func(ctx context.Context, tx *ent.Tx) error { + if err := state.user.GetDB().Write(ctx, func(ctx context.Context, tx db.Transaction) error { return state.user.QueueOrApplyStateUpdate(ctx, tx, updates...) }); err != nil { return result, err diff --git a/internal/state/updates.go b/internal/state/updates.go index d8e0b783..df918eb9 100644 --- a/internal/state/updates.go +++ b/internal/state/updates.go @@ -5,10 +5,9 @@ import ( "fmt" "strings" + "github.com/ProtonMail/gluon/db" "github.com/ProtonMail/gluon/imap" "github.com/ProtonMail/gluon/internal/contexts" - "github.com/ProtonMail/gluon/internal/db" - "github.com/ProtonMail/gluon/internal/db/ent" "github.com/ProtonMail/gluon/internal/ids" "github.com/bradenaw/juniper/xslices" ) @@ -17,7 +16,7 @@ type Update interface { // Filter returns true when the state can be passed into A. Filter(s *State) bool // Apply the update to a given state. - Apply(cxt context.Context, tx *ent.Tx, s *State) error + Apply(cxt context.Context, tx db.Transaction, s *State) error String() string } @@ -31,7 +30,7 @@ func newMessageFlagsComboStateUpdate() *messageFlagsComboStateUpdate { return &messageFlagsComboStateUpdate{} } -func (u *messageFlagsComboStateUpdate) Apply(ctx context.Context, tx *ent.Tx, s *State) error { +func (u *messageFlagsComboStateUpdate) Apply(ctx context.Context, tx db.Transaction, s *State) error { for _, v := range u.updates { if err := v.Apply(ctx, tx, s); err != nil { return err @@ -53,11 +52,11 @@ type messageFlagsAddedStateUpdate struct { AllStateFilter messageIDs []imap.InternalMessageID flags imap.FlagSet - mboxID ids.MailboxIDPair + mboxID db.MailboxIDPair stateID StateID } -func newMessageFlagsAddedStateUpdate(flags imap.FlagSet, mboxID ids.MailboxIDPair, messageIDs []imap.InternalMessageID, stateID StateID) Update { +func newMessageFlagsAddedStateUpdate(flags imap.FlagSet, mboxID db.MailboxIDPair, messageIDs []imap.InternalMessageID, stateID StateID) Update { return &messageFlagsAddedStateUpdate{ flags: flags, mboxID: mboxID, @@ -66,7 +65,7 @@ func newMessageFlagsAddedStateUpdate(flags imap.FlagSet, mboxID ids.MailboxIDPai } } -func (u *messageFlagsAddedStateUpdate) Apply(ctx context.Context, tx *ent.Tx, s *State) error { +func (u *messageFlagsAddedStateUpdate) Apply(ctx context.Context, tx db.Transaction, s *State) error { for _, messageID := range u.messageIDs { newFlags := u.flags @@ -96,7 +95,7 @@ func (u *messageFlagsAddedStateUpdate) String() string { // applyMessageFlagsAdded adds the flags to the given messages. func (state *State) applyMessageFlagsAdded(ctx context.Context, - tx *ent.Tx, + tx db.Transaction, messageIDs []imap.InternalMessageID, addFlags imap.FlagSet) ([]Update, error) { if addFlags.ContainsUnchecked(imap.FlagRecentLowerCase) { @@ -106,9 +105,7 @@ func (state *State) applyMessageFlagsAdded(ctx context.Context, // Since DB state can be more up to date then the flag state we should only emit add flag updates for values // that actually changed. - client := tx.Client() - - curFlags, err := db.GetMessageFlags(ctx, client, messageIDs) + curFlags, err := tx.GetMessagesFlags(ctx, messageIDs) if err != nil { return nil, err } @@ -150,7 +147,7 @@ func (state *State) applyMessageFlagsAdded(ctx context.Context, flagStateUpdate := newMessageFlagsComboStateUpdate() if addFlags.ContainsUnchecked(imap.FlagDeletedLowerCase) { - if err := db.SetDeletedFlag(ctx, tx, state.snap.mboxID.InternalID, messageIDs, true); err != nil { + if err := tx.SetMailboxMessagesDeletedFlag(ctx, state.snap.mboxID.InternalID, messageIDs, true); err != nil { return nil, err } @@ -169,7 +166,7 @@ func (state *State) applyMessageFlagsAdded(ctx context.Context, } } - if err := db.AddMessageFlag(ctx, tx, messagesToFlag, flag); err != nil { + if err := tx.AddFlagToMessages(ctx, messagesToFlag, flag); err != nil { return nil, err } @@ -183,11 +180,11 @@ type messageFlagsRemovedStateUpdate struct { AllStateFilter messageIDs []imap.InternalMessageID flags imap.FlagSet - mboxID ids.MailboxIDPair + mboxID db.MailboxIDPair stateID StateID } -func NewMessageFlagsRemovedStateUpdate(flags imap.FlagSet, mboxID ids.MailboxIDPair, messageIDs []imap.InternalMessageID, stateID StateID) Update { +func NewMessageFlagsRemovedStateUpdate(flags imap.FlagSet, mboxID db.MailboxIDPair, messageIDs []imap.InternalMessageID, stateID StateID) Update { return &messageFlagsRemovedStateUpdate{ flags: flags, mboxID: mboxID, @@ -196,7 +193,7 @@ func NewMessageFlagsRemovedStateUpdate(flags imap.FlagSet, mboxID ids.MailboxIDP } } -func (u *messageFlagsRemovedStateUpdate) Apply(ctx context.Context, tx *ent.Tx, s *State) error { +func (u *messageFlagsRemovedStateUpdate) Apply(ctx context.Context, tx db.Transaction, s *State) error { for _, messageID := range u.messageIDs { newFlags := u.flags @@ -226,16 +223,14 @@ func (u *messageFlagsRemovedStateUpdate) String() string { // applyMessageFlagsRemoved removes the flags from the given messages. func (state *State) applyMessageFlagsRemoved(ctx context.Context, - tx *ent.Tx, + tx db.Transaction, messageIDs []imap.InternalMessageID, remFlags imap.FlagSet) ([]Update, error) { if remFlags.ContainsUnchecked(imap.FlagRecentLowerCase) { return nil, fmt.Errorf("the recent flag is read-only") } - client := tx.Client() - - curFlags, err := db.GetMessageFlags(ctx, client, messageIDs) + curFlags, err := tx.GetMessagesFlags(ctx, messageIDs) if err != nil { return nil, err } @@ -276,7 +271,7 @@ func (state *State) applyMessageFlagsRemoved(ctx context.Context, flagStateUpdate := newMessageFlagsComboStateUpdate() if remFlags.ContainsUnchecked(imap.FlagDeletedLowerCase) { - if err := db.SetDeletedFlag(ctx, tx, state.snap.mboxID.InternalID, messageIDs, false); err != nil { + if err := tx.SetMailboxMessagesDeletedFlag(ctx, state.snap.mboxID.InternalID, messageIDs, false); err != nil { return nil, err } @@ -295,7 +290,7 @@ func (state *State) applyMessageFlagsRemoved(ctx context.Context, } } - if err := db.RemoveMessageFlag(ctx, tx, messagesToFlag, flag); err != nil { + if err := tx.RemoveFlagFromMessages(ctx, messagesToFlag, flag); err != nil { return nil, err } @@ -309,11 +304,11 @@ type messageFlagsSetStateUpdate struct { AllStateFilter messageIDs []imap.InternalMessageID flags imap.FlagSet - mboxID ids.MailboxIDPair + mboxID db.MailboxIDPair stateID StateID } -func NewMessageFlagsSetStateUpdate(flags imap.FlagSet, mboxID ids.MailboxIDPair, messageIDs []imap.InternalMessageID, stateID StateID) Update { +func NewMessageFlagsSetStateUpdate(flags imap.FlagSet, mboxID db.MailboxIDPair, messageIDs []imap.InternalMessageID, stateID StateID) Update { return &messageFlagsSetStateUpdate{ flags: flags, mboxID: mboxID, @@ -322,7 +317,7 @@ func NewMessageFlagsSetStateUpdate(flags imap.FlagSet, mboxID ids.MailboxIDPair, } } -func (u *messageFlagsSetStateUpdate) Apply(ctx context.Context, tx *ent.Tx, state *State) error { +func (u *messageFlagsSetStateUpdate) Apply(ctx context.Context, tx db.Transaction, state *State) error { for _, messageID := range u.messageIDs { newFlags := u.flags @@ -352,7 +347,7 @@ func (u *messageFlagsSetStateUpdate) String() string { // applyMessageFlagsSet sets the flags of the given messages. func (state *State) applyMessageFlagsSet(ctx context.Context, - tx *ent.Tx, + tx db.Transaction, messageIDs []imap.InternalMessageID, setFlags imap.FlagSet) ([]Update, error) { if setFlags.ContainsUnchecked(imap.FlagRecentLowerCase) { @@ -363,7 +358,7 @@ func (state *State) applyMessageFlagsSet(ctx context.Context, return nil, nil } - curFlags, err := db.GetMessageFlags(ctx, tx.Client(), messageIDs) + curFlags, err := tx.GetMessagesFlags(ctx, messageIDs) if err != nil { return nil, err } @@ -398,11 +393,11 @@ func (state *State) applyMessageFlagsSet(ctx context.Context, } } - if err := db.SetDeletedFlag(ctx, tx, state.snap.mboxID.InternalID, messageIDs, setFlags.Contains(imap.FlagDeleted)); err != nil { + if err := tx.SetMailboxMessagesDeletedFlag(ctx, state.snap.mboxID.InternalID, messageIDs, setFlags.Contains(imap.FlagDeleted)); err != nil { return nil, err } - if err := db.SetMessageFlags(ctx, tx, messageIDs, setFlags.Remove(imap.FlagDeleted)); err != nil { + if err := tx.SetFlagsOnMessages(ctx, messageIDs, setFlags.Remove(imap.FlagDeleted)); err != nil { return nil, err } @@ -421,7 +416,7 @@ func NewMailboxRemoteIDUpdateStateUpdate(internalID imap.InternalMailboxID, remo } } -func (u *mailboxRemoteIDUpdateStateUpdate) Apply(ctx context.Context, tx *ent.Tx, s *State) error { +func (u *mailboxRemoteIDUpdateStateUpdate) Apply(ctx context.Context, tx db.Transaction, s *State) error { s.snap.mboxID.RemoteID = u.remoteID return nil @@ -439,7 +434,7 @@ func NewMailboxDeletedStateUpdate(mboxID imap.InternalMailboxID) Update { return &mailboxDeletedStateUpdate{MBoxIDStateFilter: MBoxIDStateFilter{MboxID: mboxID}} } -func (u *mailboxDeletedStateUpdate) Apply(ctx context.Context, tx *ent.Tx, s *State) error { +func (u *mailboxDeletedStateUpdate) Apply(ctx context.Context, tx db.Transaction, s *State) error { s.markInvalid() return nil @@ -457,7 +452,7 @@ func NewUIDValidityBumpedStateUpdate() Update { return &uidValidityBumpedStateUpdate{} } -func (u *uidValidityBumpedStateUpdate) Apply(ctx context.Context, tx *ent.Tx, s *State) error { +func (u *uidValidityBumpedStateUpdate) Apply(ctx context.Context, tx db.Transaction, s *State) error { s.markInvalid() return nil diff --git a/internal/state/updates_mailbox.go b/internal/state/updates_mailbox.go index b0321c4e..940264fc 100644 --- a/internal/state/updates_mailbox.go +++ b/internal/state/updates_mailbox.go @@ -3,10 +3,8 @@ package state import ( "context" + "github.com/ProtonMail/gluon/db" "github.com/ProtonMail/gluon/imap" - "github.com/ProtonMail/gluon/internal/db" - "github.com/ProtonMail/gluon/internal/db/ent" - "github.com/ProtonMail/gluon/internal/ids" "github.com/ProtonMail/gluon/limits" "github.com/bradenaw/juniper/xslices" ) @@ -16,14 +14,14 @@ import ( // MoveMessagesFromMailbox moves messages from one mailbox to the other. func MoveMessagesFromMailbox( ctx context.Context, - tx *ent.Tx, + tx db.Transaction, mboxFromID, mboxToID imap.InternalMailboxID, messageIDs []imap.InternalMessageID, s *State, imapLimits limits.IMAP, removeOldMessages bool, ) ([]db.UIDWithFlags, []Update, error) { - messageCount, uid, err := db.GetMailboxMessageCountAndUID(ctx, tx.Client(), mboxToID) + messageCount, uid, err := tx.GetMailboxMessageCountAndUID(ctx, mboxToID) if err != nil { return nil, nil, err } @@ -37,12 +35,12 @@ func MoveMessagesFromMailbox( } if mboxFromID != mboxToID && removeOldMessages { - if err := db.RemoveMessagesFromMailbox(ctx, tx, messageIDs, mboxFromID); err != nil { + if err := tx.RemoveMessagesFromMailbox(ctx, mboxFromID, messageIDs); err != nil { return nil, nil, err } } - messageUIDs, err := db.AddMessagesToMailbox(ctx, tx, messageIDs, mboxToID) + messageUIDs, err := tx.AddMessagesToMailbox(ctx, mboxToID, messageIDs) if err != nil { return nil, nil, err } @@ -50,7 +48,7 @@ func MoveMessagesFromMailbox( stateUpdates := make([]Update, 0, len(messageIDs)+1) { responders := xslices.Map(messageUIDs, func(uid db.UIDWithFlags) *exists { - return newExists(ids.MessageIDPair{ + return newExists(db.MessageIDPair{ InternalID: uid.InternalID, RemoteID: uid.RemoteID, }, uid.UID, uid.GetFlagSet()) @@ -68,8 +66,13 @@ func MoveMessagesFromMailbox( } // AddMessagesToMailbox adds the messages to the given mailbox. -func AddMessagesToMailbox(ctx context.Context, tx *ent.Tx, mboxID imap.InternalMailboxID, messageIDs []imap.InternalMessageID, s *State, imapLimits limits.IMAP) ([]db.UIDWithFlags, Update, error) { - messageCount, uid, err := db.GetMailboxMessageCountAndUID(ctx, tx.Client(), mboxID) +func AddMessagesToMailbox(ctx context.Context, + tx db.Transaction, + mboxID imap.InternalMailboxID, + messageIDs []imap.InternalMessageID, + s *State, + imapLimits limits.IMAP) ([]db.UIDWithFlags, Update, error) { + messageCount, uid, err := tx.GetMailboxMessageCountAndUID(ctx, mboxID) if err != nil { return nil, nil, err } @@ -82,13 +85,13 @@ func AddMessagesToMailbox(ctx context.Context, tx *ent.Tx, mboxID imap.InternalM return nil, nil, err } - messageUIDs, err := db.AddMessagesToMailbox(ctx, tx, messageIDs, mboxID) + messageUIDs, err := tx.AddMessagesToMailbox(ctx, mboxID, messageIDs) if err != nil { return nil, nil, err } responders := xslices.Map(messageUIDs, func(uid db.UIDWithFlags) *exists { - return newExists(ids.MessageIDPair{ + return newExists(db.MessageIDPair{ InternalID: uid.InternalID, RemoteID: uid.RemoteID, }, uid.UID, uid.GetFlagSet()) @@ -98,9 +101,9 @@ func AddMessagesToMailbox(ctx context.Context, tx *ent.Tx, mboxID imap.InternalM } // RemoveMessagesFromMailbox removes the messages from the given mailbox. -func RemoveMessagesFromMailbox(ctx context.Context, tx *ent.Tx, mboxID imap.InternalMailboxID, messageIDs []imap.InternalMessageID) ([]Update, error) { +func RemoveMessagesFromMailbox(ctx context.Context, tx db.Transaction, mboxID imap.InternalMailboxID, messageIDs []imap.InternalMessageID) ([]Update, error) { if len(messageIDs) > 0 { - if err := db.RemoveMessagesFromMailbox(ctx, tx, messageIDs, mboxID); err != nil { + if err := tx.RemoveMessagesFromMailbox(ctx, mboxID, messageIDs); err != nil { return nil, err } } diff --git a/internal/state/updates_remote.go b/internal/state/updates_remote.go index 2f4e3d5f..f279156c 100644 --- a/internal/state/updates_remote.go +++ b/internal/state/updates_remote.go @@ -4,9 +4,9 @@ import ( "context" "fmt" + "github.com/ProtonMail/gluon/db" "github.com/ProtonMail/gluon/imap" "github.com/ProtonMail/gluon/internal/contexts" - "github.com/ProtonMail/gluon/internal/db/ent" ) type RemoteAddMessageFlagsStateUpdate struct { @@ -21,7 +21,7 @@ func NewRemoteAddMessageFlagsStateUpdate(messageID imap.InternalMessageID, flag } } -func (u *RemoteAddMessageFlagsStateUpdate) Apply(ctx context.Context, tx *ent.Tx, s *State) error { +func (u *RemoteAddMessageFlagsStateUpdate) Apply(ctx context.Context, tx db.Transaction, s *State) error { return s.PushResponder(ctx, tx, NewFetch(u.MessageID, imap.NewFlagSet(u.flag), contexts.IsUID(ctx), contexts.IsSilent(ctx), false, FetchFlagOpAdd)) } @@ -41,7 +41,7 @@ func NewRemoteRemoveMessageFlagsStateUpdate(messageID imap.InternalMessageID, fl } } -func (u *RemoteRemoveMessageFlagsStateUpdate) Apply(ctx context.Context, tx *ent.Tx, s *State) error { +func (u *RemoteRemoveMessageFlagsStateUpdate) Apply(ctx context.Context, tx db.Transaction, s *State) error { return s.PushResponder(ctx, tx, NewFetch(u.MessageID, imap.NewFlagSet(u.flag), contexts.IsUID(ctx), contexts.IsSilent(ctx), false, FetchFlagOpRem)) } diff --git a/internal/state/user_interface.go b/internal/state/user_interface.go index f7174f8b..671e11ea 100644 --- a/internal/state/user_interface.go +++ b/internal/state/user_interface.go @@ -2,12 +2,10 @@ package state import ( "context" - "github.com/ProtonMail/gluon/internal/utils" + "github.com/ProtonMail/gluon/db" "github.com/ProtonMail/gluon/imap" - "github.com/ProtonMail/gluon/internal/db" - "github.com/ProtonMail/gluon/internal/db/ent" - "github.com/ProtonMail/gluon/internal/ids" + "github.com/ProtonMail/gluon/internal/utils" "github.com/ProtonMail/gluon/store" ) @@ -19,17 +17,17 @@ type UserInterface interface { GetDelimiter() string - GetDB() *db.DB + GetDB() db.Client GetRemote() Connector GetStore() *store.WriteControlledStore - QueueOrApplyStateUpdate(ctx context.Context, tx *ent.Tx, update ...Update) error + QueueOrApplyStateUpdate(ctx context.Context, tx db.Transaction, update ...Update) error ReleaseState(ctx context.Context, st *State) error - GetRecoveryMailboxID() ids.MailboxIDPair + GetRecoveryMailboxID() db.MailboxIDPair GenerateUIDValidity() (imap.UID, error) diff --git a/internal/utils/message_hashmap.go b/internal/utils/message_hashmap.go index 122ee91d..5b190b31 100644 --- a/internal/utils/message_hashmap.go +++ b/internal/utils/message_hashmap.go @@ -1,9 +1,10 @@ package utils import ( + "sync" + "github.com/ProtonMail/gluon/imap" "github.com/ProtonMail/gluon/rfc822" - "sync" ) // MessageHashesMap tracks the hashes for a literal and it's associated internal IMAP ID. diff --git a/internal/utils/utils.go b/internal/utils/utils.go index 8e8cd68b..92997ae7 100644 --- a/internal/utils/utils.go +++ b/internal/utils/utils.go @@ -2,6 +2,7 @@ package utils import ( "errors" + "github.com/google/uuid" ) diff --git a/option.go b/option.go index cd160bc9..49b6a211 100644 --- a/option.go +++ b/option.go @@ -6,6 +6,7 @@ import ( "time" "github.com/ProtonMail/gluon/async" + "github.com/ProtonMail/gluon/db" "github.com/ProtonMail/gluon/imap" limits2 "github.com/ProtonMail/gluon/limits" "github.com/ProtonMail/gluon/profiling" @@ -230,3 +231,15 @@ func (w withUIDValidityGenerator) config(builder *serverBuilder) { func WithUIDValidityGenerator(generator imap.UIDValidityGenerator) Option { return &withUIDValidityGenerator{generator: generator} } + +type withDBClient struct { + ci db.ClientInterface +} + +func (w withDBClient) config(builder *serverBuilder) { + builder.dbCI = w.ci +} + +func WithDBClient(ci db.ClientInterface) Option { + return &withDBClient{ci: ci} +} diff --git a/rfc5322/parser_test.go b/rfc5322/parser_test.go index f061a2cb..c9816075 100644 --- a/rfc5322/parser_test.go +++ b/rfc5322/parser_test.go @@ -2,12 +2,12 @@ package rfc5322 import ( "bytes" - "github.com/stretchr/testify/require" "net/mail" "testing" "github.com/ProtonMail/gluon/rfcparser" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func newTestRFCParser(s string) *rfcparser.Parser { diff --git a/rfc5322/validation.go b/rfc5322/validation.go index 2d3aefdd..136b2d99 100644 --- a/rfc5322/validation.go +++ b/rfc5322/validation.go @@ -3,6 +3,7 @@ package rfc5322 import ( "errors" "fmt" + "github.com/ProtonMail/gluon/rfc822" ) diff --git a/rfc5322/validation_test.go b/rfc5322/validation_test.go index d385c036..453eccba 100644 --- a/rfc5322/validation_test.go +++ b/rfc5322/validation_test.go @@ -1,8 +1,9 @@ package rfc5322 import ( - "github.com/stretchr/testify/require" "testing" + + "github.com/stretchr/testify/require" ) func TestValidateMessageHeaderFields_RequiredFieldsPass(t *testing.T) { diff --git a/rfc822/hash.go b/rfc822/hash.go index 201b675a..9b982acc 100644 --- a/rfc822/hash.go +++ b/rfc822/hash.go @@ -4,8 +4,9 @@ import ( "bytes" "crypto/sha256" "encoding/base64" - "github.com/sirupsen/logrus" "strings" + + "github.com/sirupsen/logrus" ) // GetMessageHash returns the hash of the given message. diff --git a/store/mock_store/store.go b/store/mock_store/store.go index 0b8f2ad8..5c2046a6 100644 --- a/store/mock_store/store.go +++ b/store/mock_store/store.go @@ -5,6 +5,7 @@ package mock_store import ( + io "io" reflect "reflect" imap "github.com/ProtonMail/gluon/imap" @@ -97,7 +98,7 @@ func (mr *MockStoreMockRecorder) List() *gomock.Call { } // Set mocks base method. -func (m *MockStore) Set(arg0 imap.InternalMessageID, arg1 []byte) error { +func (m *MockStore) Set(arg0 imap.InternalMessageID, arg1 io.Reader) error { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "Set", arg0, arg1) ret0, _ := ret[0].(error) diff --git a/tests/account_test.go b/tests/account_test.go index a2496ae9..49ed0fe4 100644 --- a/tests/account_test.go +++ b/tests/account_test.go @@ -2,11 +2,12 @@ package tests import ( "errors" - "github.com/ProtonMail/gluon/internal/db" - "github.com/stretchr/testify/require" "os" "path/filepath" "testing" + + "github.com/ProtonMail/gluon/db" + "github.com/stretchr/testify/require" ) func TestAccountRemovalMovesDBToDeferredDeleteFolder(t *testing.T) { diff --git a/tests/db_test.go b/tests/db_test.go index ab19ec37..93d808bf 100644 --- a/tests/db_test.go +++ b/tests/db_test.go @@ -3,13 +3,15 @@ package tests import ( "context" - "github.com/ProtonMail/gluon/internal/db/ent" + "github.com/ProtonMail/gluon/db" "github.com/stretchr/testify/require" ) func dbCheckUserMessageCount(s *testSession, user string, expectedCount int) { - err := s.withUserDB(user, func(ent *ent.Client, ctx context.Context) { - val, err := ent.Message.Query().Count(ctx) + err := s.withUserDB(user, func(ent db.Client, ctx context.Context) { + val, err := db.ClientReadType(ctx, ent, func(ctx context.Context, only db.ReadOnly) (int, error) { + return only.GetTotalMessageCount(ctx) + }) require.NoError(s.tb, err) require.Equal(s.tb, expectedCount, val) }) diff --git a/tests/deleted_test.go b/tests/deleted_test.go index 87d4f25e..fc7c434a 100644 --- a/tests/deleted_test.go +++ b/tests/deleted_test.go @@ -1,11 +1,12 @@ package tests import ( - "github.com/stretchr/testify/require" "os" "path/filepath" "testing" "time" + + "github.com/stretchr/testify/require" ) func TestDeleted(t *testing.T) { diff --git a/tests/server_test.go b/tests/server_test.go index 17c4267f..aedd8227 100644 --- a/tests/server_test.go +++ b/tests/server_test.go @@ -12,7 +12,9 @@ import ( "github.com/ProtonMail/gluon" "github.com/ProtonMail/gluon/connector" + "github.com/ProtonMail/gluon/db" "github.com/ProtonMail/gluon/imap" + "github.com/ProtonMail/gluon/internal/db_impl" "github.com/ProtonMail/gluon/internal/hash" "github.com/ProtonMail/gluon/limits" "github.com/ProtonMail/gluon/logging" @@ -77,6 +79,7 @@ type serverOptions struct { imapLimits limits.IMAP reporter reporter.Reporter uidValidityGenerator imap.UIDValidityGenerator + database db.ClientInterface } func (s *serverOptions) defaultUsername() string { @@ -179,6 +182,14 @@ type uidValidityGeneratorOption struct { generator imap.UIDValidityGenerator } +type withDatabaseOption struct { + database db.ClientInterface +} + +func (w withDatabaseOption) apply(options *serverOptions) { + options.database = w.database +} + func (u uidValidityGeneratorOption) apply(options *serverOptions) { options.uidValidityGenerator = u.generator } @@ -227,6 +238,10 @@ func withDatabaseDir(dir string) serverOption { return &databaseDirOption{dir: dir} } +func withDatabase(ci db.ClientInterface) serverOption { + return &withDatabaseOption{database: ci} +} + func defaultServerOptions(tb testing.TB, modifiers ...serverOption) *serverOptions { options := &serverOptions{ credentials: []credentials{{ @@ -241,6 +256,7 @@ func defaultServerOptions(tb testing.TB, modifiers ...serverOption) *serverOptio storeBuilder: &store.OnDiskStoreBuilder{}, connectorBuilder: &dummyConnectorBuilder{}, imapLimits: limits.DefaultLimits(), + database: db_impl.NewEntDB(), } for _, op := range modifiers { diff --git a/tests/session_test.go b/tests/session_test.go index 2dcb9848..6e9fcf3b 100644 --- a/tests/session_test.go +++ b/tests/session_test.go @@ -11,12 +11,11 @@ import ( "testing" "time" - "entgo.io/ent/dialect" "github.com/ProtonMail/gluon" "github.com/ProtonMail/gluon/connector" + "github.com/ProtonMail/gluon/db" "github.com/ProtonMail/gluon/events" "github.com/ProtonMail/gluon/imap" - "github.com/ProtonMail/gluon/internal/db/ent" "github.com/ProtonMail/gluon/internal/utils" "github.com/ProtonMail/go-mbox" "github.com/emersion/go-imap/client" @@ -113,13 +112,13 @@ func (s *testSession) newClient() *client.Client { return client } -func (s *testSession) withUserDB(user string, fn func(client *ent.Client, ctx context.Context)) error { - path, ok := s.userDBPaths[s.userIDs[user]] +func (s *testSession) withUserDB(user string, fn func(client db.Client, ctx context.Context)) error { + userID, ok := s.userIDs[user] if !ok { return fmt.Errorf("User not found") } - client, err := ent.Open(dialect.SQLite, path) + client, _, err := s.options.database.New(s.server.GetDatabasePath(), userID) if err != nil { return err }