From e8633c1c84c0204f9fab99c724ad126426d043ca Mon Sep 17 00:00:00 2001 From: Leander Beernaert Date: Wed, 24 May 2023 13:59:14 +0200 Subject: [PATCH 1/4] refactor(GODT-2511): Hide ent implementation behind an interface First step towards remove ent, provide an abstract database interface that we use to implement the new implementation and compare behavior with the old version. There is now a new option to specify the database implementation for the Server and the test server. --- .golangci.yml | 1 + builder.go | 6 +- connector/mock_connector/connector.go | 64 ++- db/client.go | 52 +++ db/deferred_delete.go | 47 ++ db/errors.go | 14 + db/ops.go | 13 + db/ops_mailbox.go | 170 ++++++++ db/ops_message.go | 91 ++++ db/ops_subscription.go | 15 + db/types.go | 125 ++++++ go.sum | 5 + internal/backend/backend.go | 30 +- internal/backend/connector_updates.go | 137 +++--- internal/backend/state_user_interface_impl.go | 11 +- internal/backend/user.go | 31 +- internal/data/db.go | 18 + internal/db/ent/hook/hook.go | 291 ------------- internal/db/ent/runtime/runtime.go | 10 - internal/db_impl/db_impl.go | 10 + internal/{db => db_impl/ent_db}/db.go | 121 +++--- .../ent => db_impl/ent_db/internal}/client.go | 339 +++++++++++++-- .../ent => db_impl/ent_db/internal}/config.go | 42 +- .../ent_db/internal}/context.go | 2 +- .../ent_db/internal}/deletedsubscription.go | 20 +- .../deletedsubscription.go | 0 .../internal}/deletedsubscription/where.go | 170 ++------ .../internal}/deletedsubscription_create.go | 77 +--- .../internal}/deletedsubscription_delete.go | 57 +-- .../internal}/deletedsubscription_query.go | 248 +++++------ .../internal}/deletedsubscription_update.go | 128 +----- .../ent => db_impl/ent_db/internal}/ent.go | 201 +++++++-- .../ent_db/internal}/enttest/enttest.go | 26 +- .../ent_db/internal}/generate.go | 2 +- internal/db_impl/ent_db/internal/hook/hook.go | 283 ++++++++++++ .../ent_db/internal}/mailbox.go | 28 +- .../ent_db/internal}/mailbox/mailbox.go | 0 .../ent_db/internal}/mailbox/where.go | 274 +++--------- .../ent_db/internal}/mailbox_create.go | 107 +---- .../ent_db/internal}/mailbox_delete.go | 57 +-- .../ent_db/internal}/mailbox_query.go | 272 ++++++------ .../ent_db/internal}/mailbox_update.go | 206 ++------- .../ent_db/internal}/mailboxattr.go | 20 +- .../internal}/mailboxattr/mailboxattr.go | 0 .../ent_db/internal}/mailboxattr/where.go | 110 ++--- .../ent_db/internal}/mailboxattr_create.go | 69 +-- .../ent_db/internal}/mailboxattr_delete.go | 57 +-- .../ent_db/internal}/mailboxattr_query.go | 248 +++++------ .../ent_db/internal}/mailboxattr_update.go | 116 +---- .../ent_db/internal}/mailboxflag.go | 20 +- .../internal}/mailboxflag/mailboxflag.go | 0 .../ent_db/internal}/mailboxflag/where.go | 110 ++--- .../ent_db/internal}/mailboxflag_create.go | 69 +-- .../ent_db/internal}/mailboxflag_delete.go | 57 +-- .../ent_db/internal}/mailboxflag_query.go | 248 +++++------ .../ent_db/internal}/mailboxflag_update.go | 116 +---- .../ent_db/internal}/mailboxpermflag.go | 20 +- .../mailboxpermflag/mailboxpermflag.go | 0 .../ent_db/internal}/mailboxpermflag/where.go | 110 ++--- .../internal}/mailboxpermflag_create.go | 69 +-- .../internal}/mailboxpermflag_delete.go | 57 +-- .../ent_db/internal}/mailboxpermflag_query.go | 248 +++++------ .../internal}/mailboxpermflag_update.go | 116 +---- .../ent_db/internal}/message.go | 24 +- .../ent_db/internal}/message/message.go | 0 .../ent_db/internal}/message/where.go | 408 ++++-------------- .../ent_db/internal}/message_create.go | 119 ++--- .../ent_db/internal}/message_delete.go | 57 +-- .../ent_db/internal}/message_query.go | 260 +++++------ .../ent_db/internal}/message_update.go | 214 ++------- .../ent_db/internal}/messageflag.go | 24 +- .../internal}/messageflag/messageflag.go | 0 .../ent_db/internal}/messageflag/where.go | 111 ++--- .../ent_db/internal}/messageflag_create.go | 71 +-- .../ent_db/internal}/messageflag_delete.go | 57 +-- .../ent_db/internal}/messageflag_query.go | 257 +++++------ .../ent_db/internal}/messageflag_update.go | 118 +---- .../ent_db/internal}/migrate/migrate.go | 0 .../ent_db/internal}/migrate/schema.go | 0 .../ent_db/internal}/mutation.go | 161 ++++++- .../ent_db/internal}/predicate/predicate.go | 0 .../ent_db/internal}/runtime.go | 10 +- .../ent_db/internal/runtime/runtime.go | 10 + .../ent_db/internal}/schema/mailbox.go | 0 .../ent_db/internal}/schema/mailboxattr.go | 0 .../ent_db/internal}/schema/mailboxflag.go | 0 .../internal}/schema/mailboxpermflag.go | 0 .../ent_db/internal}/schema/message.go | 0 .../ent_db/internal}/schema/messageflag.go | 0 .../ent_db/internal}/schema/subscriptions.go | 0 .../ent_db/internal}/schema/uid.go | 0 .../{db/ent => db_impl/ent_db/internal}/tx.go | 42 +- .../ent => db_impl/ent_db/internal}/uid.go | 28 +- .../ent_db/internal}/uid/uid.go | 0 .../ent_db/internal}/uid/where.go | 112 ++--- .../ent_db/internal}/uid_create.go | 89 +--- .../ent_db/internal}/uid_delete.go | 57 +-- .../ent_db/internal}/uid_query.go | 266 ++++++------ .../ent_db/internal}/uid_update.go | 156 ++----- internal/{db => db_impl/ent_db}/mailbox.go | 132 +++--- internal/{db => db_impl/ent_db}/message.go | 196 ++++----- internal/db_impl/ent_db/ops_read.go | 332 ++++++++++++++ internal/db_impl/ent_db/ops_write.go | 224 ++++++++++ .../{db => db_impl/ent_db}/subscription.go | 14 +- internal/db_impl/ent_db/type_conversions.go | 103 +++++ internal/ids/ids.go | 68 --- internal/state/actions.go | 162 ++++--- internal/state/mailbox.go | 84 ++-- internal/state/mailbox_fetch.go | 33 +- internal/state/mailbox_search.go | 18 +- internal/state/match.go | 15 +- internal/state/responders.go | 18 +- internal/state/snapshot.go | 30 +- internal/state/snapshot_messages.go | 8 +- internal/state/snapshot_messages_test.go | 6 +- internal/state/state.go | 133 +++--- internal/state/updates.go | 59 ++- internal/state/updates_mailbox.go | 31 +- internal/state/updates_remote.go | 6 +- internal/state/user_interface.go | 10 +- option.go | 13 + store/mock_store/store.go | 3 +- tests/account_test.go | 2 +- tests/db_test.go | 8 +- tests/server_test.go | 16 + tests/session_test.go | 9 +- 126 files changed, 4713 insertions(+), 5332 deletions(-) create mode 100644 db/client.go create mode 100644 db/deferred_delete.go create mode 100644 db/errors.go create mode 100644 db/ops.go create mode 100644 db/ops_mailbox.go create mode 100644 db/ops_message.go create mode 100644 db/ops_subscription.go create mode 100644 db/types.go create mode 100644 internal/data/db.go delete mode 100644 internal/db/ent/hook/hook.go delete mode 100644 internal/db/ent/runtime/runtime.go create mode 100644 internal/db_impl/db_impl.go rename internal/{db => db_impl/ent_db}/db.go (55%) rename internal/{db/ent => db_impl/ent_db/internal}/client.go (72%) rename internal/{db/ent => db_impl/ent_db/internal}/config.go (55%) rename internal/{db/ent => db_impl/ent_db/internal}/context.go (98%) rename internal/{db/ent => db_impl/ent_db/internal}/deletedsubscription.go (87%) rename internal/{db/ent => db_impl/ent_db/internal}/deletedsubscription/deletedsubscription.go (100%) rename internal/{db/ent => db_impl/ent_db/internal}/deletedsubscription/where.go (57%) rename internal/{db/ent => db_impl/ent_db/internal}/deletedsubscription_create.go (72%) rename internal/{db/ent => db_impl/ent_db/internal}/deletedsubscription_delete.go (63%) rename internal/{db/ent => db_impl/ent_db/internal}/deletedsubscription_query.go (68%) rename internal/{db/ent => db_impl/ent_db/internal}/deletedsubscription_update.go (64%) rename internal/{db/ent => db_impl/ent_db/internal}/ent.go (63%) rename internal/{db/ent => db_impl/ent_db/internal}/enttest/enttest.go (65%) rename internal/{db/ent => db_impl/ent_db/internal}/generate.go (80%) create mode 100644 internal/db_impl/ent_db/internal/hook/hook.go rename internal/{db/ent => db_impl/ent_db/internal}/mailbox.go (90%) rename internal/{db/ent => db_impl/ent_db/internal}/mailbox/mailbox.go (100%) rename internal/{db/ent => db_impl/ent_db/internal}/mailbox/where.go (65%) rename internal/{db/ent => db_impl/ent_db/internal}/mailbox_create.go (80%) rename internal/{db/ent => db_impl/ent_db/internal}/mailbox_delete.go (62%) rename internal/{db/ent => db_impl/ent_db/internal}/mailbox_query.go (77%) rename internal/{db/ent => db_impl/ent_db/internal}/mailbox_update.go (85%) rename internal/{db/ent => db_impl/ent_db/internal}/mailboxattr.go (85%) rename internal/{db/ent => db_impl/ent_db/internal}/mailboxattr/mailboxattr.go (100%) rename internal/{db/ent => db_impl/ent_db/internal}/mailboxattr/where.go (56%) rename internal/{db/ent => db_impl/ent_db/internal}/mailboxattr_create.go (74%) rename internal/{db/ent => db_impl/ent_db/internal}/mailboxattr_delete.go (62%) rename internal/{db/ent => db_impl/ent_db/internal}/mailboxattr_query.go (67%) rename internal/{db/ent => db_impl/ent_db/internal}/mailboxattr_update.go (63%) rename internal/{db/ent => db_impl/ent_db/internal}/mailboxflag.go (85%) rename internal/{db/ent => db_impl/ent_db/internal}/mailboxflag/mailboxflag.go (100%) rename internal/{db/ent => db_impl/ent_db/internal}/mailboxflag/where.go (56%) rename internal/{db/ent => db_impl/ent_db/internal}/mailboxflag_create.go (74%) rename internal/{db/ent => db_impl/ent_db/internal}/mailboxflag_delete.go (62%) rename internal/{db/ent => db_impl/ent_db/internal}/mailboxflag_query.go (67%) rename internal/{db/ent => db_impl/ent_db/internal}/mailboxflag_update.go (63%) rename internal/{db/ent => db_impl/ent_db/internal}/mailboxpermflag.go (87%) rename internal/{db/ent => db_impl/ent_db/internal}/mailboxpermflag/mailboxpermflag.go (100%) rename internal/{db/ent => db_impl/ent_db/internal}/mailboxpermflag/where.go (56%) rename internal/{db/ent => db_impl/ent_db/internal}/mailboxpermflag_create.go (74%) rename internal/{db/ent => db_impl/ent_db/internal}/mailboxpermflag_delete.go (63%) rename internal/{db/ent => db_impl/ent_db/internal}/mailboxpermflag_query.go (68%) rename internal/{db/ent => db_impl/ent_db/internal}/mailboxpermflag_update.go (64%) rename internal/{db/ent => db_impl/ent_db/internal}/message.go (91%) rename internal/{db/ent => db_impl/ent_db/internal}/message/message.go (100%) rename internal/{db/ent => db_impl/ent_db/internal}/message/where.go (58%) rename internal/{db/ent => db_impl/ent_db/internal}/message_create.go (74%) rename internal/{db/ent => db_impl/ent_db/internal}/message_delete.go (62%) rename internal/{db/ent => db_impl/ent_db/internal}/message_query.go (73%) rename internal/{db/ent => db_impl/ent_db/internal}/message_update.go (78%) rename internal/{db/ent => db_impl/ent_db/internal}/messageflag.go (86%) rename internal/{db/ent => db_impl/ent_db/internal}/messageflag/messageflag.go (100%) rename internal/{db/ent => db_impl/ent_db/internal}/messageflag/where.go (62%) rename internal/{db/ent => db_impl/ent_db/internal}/messageflag_create.go (78%) rename internal/{db/ent => db_impl/ent_db/internal}/messageflag_delete.go (62%) rename internal/{db/ent => db_impl/ent_db/internal}/messageflag_query.go (70%) rename internal/{db/ent => db_impl/ent_db/internal}/messageflag_update.go (75%) rename internal/{db/ent => db_impl/ent_db/internal}/migrate/migrate.go (100%) rename internal/{db/ent => db_impl/ent_db/internal}/migrate/schema.go (100%) rename internal/{db/ent => db_impl/ent_db/internal}/mutation.go (96%) rename internal/{db/ent => db_impl/ent_db/internal}/predicate/predicate.go (100%) rename internal/{db/ent => db_impl/ent_db/internal}/runtime.go (87%) create mode 100644 internal/db_impl/ent_db/internal/runtime/runtime.go rename internal/{db/ent => db_impl/ent_db/internal}/schema/mailbox.go (100%) rename internal/{db/ent => db_impl/ent_db/internal}/schema/mailboxattr.go (100%) rename internal/{db/ent => db_impl/ent_db/internal}/schema/mailboxflag.go (100%) rename internal/{db/ent => db_impl/ent_db/internal}/schema/mailboxpermflag.go (100%) rename internal/{db/ent => db_impl/ent_db/internal}/schema/message.go (100%) rename internal/{db/ent => db_impl/ent_db/internal}/schema/messageflag.go (100%) rename internal/{db/ent => db_impl/ent_db/internal}/schema/subscriptions.go (100%) rename internal/{db/ent => db_impl/ent_db/internal}/schema/uid.go (100%) rename internal/{db/ent => db_impl/ent_db/internal}/tx.go (92%) rename internal/{db/ent => db_impl/ent_db/internal}/uid.go (89%) rename internal/{db/ent => db_impl/ent_db/internal}/uid/uid.go (100%) rename internal/{db/ent => db_impl/ent_db/internal}/uid/where.go (67%) rename internal/{db/ent => db_impl/ent_db/internal}/uid_create.go (78%) rename internal/{db/ent => db_impl/ent_db/internal}/uid_delete.go (61%) rename internal/{db/ent => db_impl/ent_db/internal}/uid_query.go (72%) rename internal/{db/ent => db_impl/ent_db/internal}/uid_update.go (78%) rename internal/{db => db_impl/ent_db}/mailbox.go (62%) rename internal/{db => db_impl/ent_db}/message.go (68%) create mode 100644 internal/db_impl/ent_db/ops_read.go create mode 100644 internal/db_impl/ent_db/ops_write.go rename internal/{db => db_impl/ent_db}/subscription.go (63%) create mode 100644 internal/db_impl/ent_db/type_conversions.go diff --git a/.golangci.yml b/.golangci.yml index 6abe3cd3..cc70d56d 100644 --- a/.golangci.yml +++ b/.golangci.yml @@ -13,6 +13,7 @@ linters: disable: - godox # Annoying, we have too many TODOs at the moment :p - scopelint # Deprecated, replaced by exportloopref, which is enabled by default. + - errorlint # Too many false positives issues: exclude-rules: diff --git a/builder.go b/builder.go index 864388fd..56b52023 100644 --- a/builder.go +++ b/builder.go @@ -2,6 +2,8 @@ package gluon import ( "crypto/tls" + "github.com/ProtonMail/gluon/db" + "github.com/ProtonMail/gluon/internal/db_impl/ent_db" "io" "os" "time" @@ -9,7 +11,6 @@ import ( "github.com/ProtonMail/gluon/async" "github.com/ProtonMail/gluon/imap" "github.com/ProtonMail/gluon/internal/backend" - "github.com/ProtonMail/gluon/internal/db" "github.com/ProtonMail/gluon/internal/session" "github.com/ProtonMail/gluon/limits" "github.com/ProtonMail/gluon/profiling" @@ -36,6 +37,7 @@ type serverBuilder struct { imapLimits limits.IMAP uidValidityGenerator imap.UIDValidityGenerator panicHandler async.PanicHandler + dbCI db.ClientInterface } func newBuilder() (*serverBuilder, error) { @@ -48,6 +50,7 @@ func newBuilder() (*serverBuilder, error) { imapLimits: limits.DefaultLimits(), uidValidityGenerator: imap.DefaultEpochUIDValidityGenerator(), panicHandler: async.NoopPanicHandler{}, + dbCI: ent_db.NewEntDBBuilder(), }, nil } @@ -86,6 +89,7 @@ func (builder *serverBuilder) build() (*Server, error) { builder.loginJailTime, builder.imapLimits, builder.panicHandler, + builder.dbCI, ) if err != nil { return nil, err diff --git a/connector/mock_connector/connector.go b/connector/mock_connector/connector.go index b08a5d17..ef0de4c9 100644 --- a/connector/mock_connector/connector.go +++ b/connector/mock_connector/connector.go @@ -123,18 +123,33 @@ func (mr *MockConnectorMockRecorder) DeleteMailbox(arg0, arg1 interface{}) *gomo return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteMailbox", reflect.TypeOf((*MockConnector)(nil).DeleteMailbox), arg0, arg1) } -// GetUIDValidity mocks base method. -func (m *MockConnector) GetUIDValidity() imap.UID { +// GetMailboxVisibility mocks base method. +func (m *MockConnector) GetMailboxVisibility(arg0 context.Context, arg1 imap.MailboxID) imap.MailboxVisibility { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GetUIDValidity") - ret0, _ := ret[0].(imap.UID) + ret := m.ctrl.Call(m, "GetMailboxVisibility", arg0, arg1) + ret0, _ := ret[0].(imap.MailboxVisibility) return ret0 } -// GetUIDValidity indicates an expected call of GetUIDValidity. -func (mr *MockConnectorMockRecorder) GetUIDValidity() *gomock.Call { +// GetMailboxVisibility indicates an expected call of GetMailboxVisibility. +func (mr *MockConnectorMockRecorder) GetMailboxVisibility(arg0, arg1 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetUIDValidity", reflect.TypeOf((*MockConnector)(nil).GetUIDValidity)) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetMailboxVisibility", reflect.TypeOf((*MockConnector)(nil).GetMailboxVisibility), arg0, arg1) +} + +// GetMessageLiteral mocks base method. +func (m *MockConnector) GetMessageLiteral(arg0 context.Context, arg1 imap.MessageID) ([]byte, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetMessageLiteral", arg0, arg1) + ret0, _ := ret[0].([]byte) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetMessageLiteral indicates an expected call of GetMessageLiteral. +func (mr *MockConnectorMockRecorder) GetMessageLiteral(arg0, arg1 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetMessageLiteral", reflect.TypeOf((*MockConnector)(nil).GetMessageLiteral), arg0, arg1) } // GetUpdates mocks base method. @@ -151,20 +166,6 @@ func (mr *MockConnectorMockRecorder) GetUpdates() *gomock.Call { return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetUpdates", reflect.TypeOf((*MockConnector)(nil).GetUpdates)) } -// IsMailboxVisible mocks base method. -func (m *MockConnector) IsMailboxVisible(arg0 context.Context, arg1 imap.MailboxID) bool { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "IsMailboxVisible", arg0, arg1) - ret0, _ := ret[0].(bool) - return ret0 -} - -// IsMailboxVisible indicates an expected call of IsMailboxVisible. -func (mr *MockConnectorMockRecorder) IsMailboxVisible(arg0, arg1 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "IsMailboxVisible", reflect.TypeOf((*MockConnector)(nil).IsMailboxVisible), arg0, arg1) -} - // MarkMessagesFlagged mocks base method. func (m *MockConnector) MarkMessagesFlagged(arg0 context.Context, arg1 []imap.MessageID, arg2 bool) error { m.ctrl.T.Helper() @@ -194,11 +195,12 @@ func (mr *MockConnectorMockRecorder) MarkMessagesSeen(arg0, arg1, arg2 interface } // MoveMessages mocks base method. -func (m *MockConnector) MoveMessages(arg0 context.Context, arg1 []imap.MessageID, arg2, arg3 imap.MailboxID) error { +func (m *MockConnector) MoveMessages(arg0 context.Context, arg1 []imap.MessageID, arg2, arg3 imap.MailboxID) (bool, error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "MoveMessages", arg0, arg1, arg2, arg3) - ret0, _ := ret[0].(error) - return ret0 + ret0, _ := ret[0].(bool) + ret1, _ := ret[1].(error) + return ret0, ret1 } // MoveMessages indicates an expected call of MoveMessages. @@ -221,20 +223,6 @@ func (mr *MockConnectorMockRecorder) RemoveMessagesFromMailbox(arg0, arg1, arg2 return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RemoveMessagesFromMailbox", reflect.TypeOf((*MockConnector)(nil).RemoveMessagesFromMailbox), arg0, arg1, arg2) } -// SetUIDValidity mocks base method. -func (m *MockConnector) SetUIDValidity(arg0 imap.UID) error { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "SetUIDValidity", arg0) - ret0, _ := ret[0].(error) - return ret0 -} - -// SetUIDValidity indicates an expected call of SetUIDValidity. -func (mr *MockConnectorMockRecorder) SetUIDValidity(arg0 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetUIDValidity", reflect.TypeOf((*MockConnector)(nil).SetUIDValidity), arg0) -} - // UpdateMailboxName mocks base method. func (m *MockConnector) UpdateMailboxName(arg0 context.Context, arg1 imap.MailboxID, arg2 []string) error { m.ctrl.T.Helper() diff --git a/db/client.go b/db/client.go new file mode 100644 index 00000000..ef43ed6c --- /dev/null +++ b/db/client.go @@ -0,0 +1,52 @@ +package db + +import ( + "context" + "path/filepath" +) + +const ChunkLimit = 1000 + +type Client interface { + Init(ctx context.Context) error + Read(ctx context.Context, op func(context.Context, ReadOnly) error) error + Write(ctx context.Context, op func(context.Context, Transaction) error) error + Close() error +} + +type ClientInterface interface { + New(path string, userID string) (Client, bool, error) + Delete(path string, userID string) error +} + +func GetDeferredDeleteDBPath(dir string) string { + return filepath.Join(dir, "deferred_delete") +} + +func ClientReadType[T any](ctx context.Context, c Client, op func(context.Context, ReadOnly) (T, error)) (T, error) { + var result T + + err := c.Read(ctx, func(ctx context.Context, read ReadOnly) error { + var err error + + result, err = op(ctx, read) + + return err + }) + + return result, err +} + +func ClientWriteType[T any](ctx context.Context, c Client, op func(context.Context, Transaction) (T, error)) (T, error) { + var result T + + err := c.Write(ctx, func(ctx context.Context, t Transaction) error { + var err error + + result, err = op(ctx, t) + + return err + }) + + return result, err +} diff --git a/db/deferred_delete.go b/db/deferred_delete.go new file mode 100644 index 00000000..33463e55 --- /dev/null +++ b/db/deferred_delete.go @@ -0,0 +1,47 @@ +package db + +import ( + "errors" + "fmt" + "github.com/google/uuid" + "os" + "path/filepath" +) + +// DeleteDB will rename all the database files for the given user to a directory within the same folder to avoid +// issues with ent not being able to close the database on demand. The database will be cleaned up on the next +// run on the Gluon server. +func DeleteDB(dir, userID string) error { + // Rather than deleting the files immediately move them to a directory to be cleaned up later. + deferredDeletePath := GetDeferredDeleteDBPath(dir) + + if err := os.MkdirAll(deferredDeletePath, 0o700); err != nil { + return fmt.Errorf("failed to create deferred delete dir: %w", err) + } + + matchingFiles, err := filepath.Glob(filepath.Join(dir, userID+"*")) + if err != nil { + return fmt.Errorf("failed to match db files:%w", err) + } + + for _, file := range matchingFiles { + // Use new UUID to avoid conflict with existing files + if err := os.Rename(file, filepath.Join(deferredDeletePath, uuid.NewString())); err != nil { + return fmt.Errorf("failed to move db file '%v' :%w", file, err) + } + } + + return nil +} + +// DeleteDeferredDBFiles deletes all data from previous databases that were scheduled for removal. +func DeleteDeferredDBFiles(dir string) error { + deferredDeleteDir := GetDeferredDeleteDBPath(dir) + if err := os.RemoveAll(deferredDeleteDir); err != nil { + if !errors.Is(err, os.ErrNotExist) { + return err + } + } + + return nil +} diff --git a/db/errors.go b/db/errors.go new file mode 100644 index 00000000..d50f494c --- /dev/null +++ b/db/errors.go @@ -0,0 +1,14 @@ +package db + +import "errors" + +var ErrNotFound = errors.New("value not found") +var ErrTransactionFailed = errors.New("transaction failed") + +func IsErrNotFound(err error) bool { + if err == nil { + return false + } + + return errors.Is(err, ErrNotFound) +} diff --git a/db/ops.go b/db/ops.go new file mode 100644 index 00000000..cd950cf2 --- /dev/null +++ b/db/ops.go @@ -0,0 +1,13 @@ +package db + +type ReadOnly interface { + MailboxReadOps + MessageReadOps + SubscriptionReadOps +} + +type Transaction interface { + MailboxWriteOps + MessageWriteOps + SubscriptionWriteOps +} diff --git a/db/ops_mailbox.go b/db/ops_mailbox.go new file mode 100644 index 00000000..eb22d684 --- /dev/null +++ b/db/ops_mailbox.go @@ -0,0 +1,170 @@ +package db + +import ( + "context" + "github.com/ProtonMail/gluon/imap" + "strings" +) + +type MailboxReadOps interface { + MailboxExistsWithID(ctx context.Context, mboxID imap.InternalMailboxID) (bool, error) + + MailboxExistsWithRemoteID(ctx context.Context, mboxID imap.MailboxID) (bool, error) + + MailboxExistsWithName(ctx context.Context, name string) (bool, error) + + GetMailboxIDFromRemoteID(ctx context.Context, mboxID imap.MailboxID) (imap.InternalMailboxID, error) + + GetMailboxName(ctx context.Context, mboxID imap.InternalMailboxID) (string, error) + + GetMailboxNameWithRemoteID(ctx context.Context, mboxID imap.MailboxID) (string, error) + + GetMailboxMessageIDPairs(ctx context.Context, mboxID imap.InternalMailboxID) ([]MessageIDPair, error) + + GetAllMailboxes(ctx context.Context) ([]*Mailbox, error) + + GetAllMailboxesAsRemoteIDs(ctx context.Context) ([]imap.MailboxID, error) + + GetMailboxByName(ctx context.Context, name string) (*Mailbox, error) + + GetMailboxByID(ctx context.Context, mboxID imap.InternalMailboxID) (*Mailbox, error) + + GetMailboxByRemoteID(ctx context.Context, mboxID imap.MailboxID) (*Mailbox, error) + + GetMailboxRecentCount(ctx context.Context, mboxID imap.InternalMailboxID) (int, error) + + GetMailboxMessageCount(ctx context.Context, mboxID imap.InternalMailboxID) (int, error) + + GetMailboxMessageCountWithRemoteID(ctx context.Context, mboxID imap.MailboxID) (int, error) + + GetMailboxFlags(ctx context.Context, mboxID imap.InternalMailboxID) (imap.FlagSet, error) + + GetMailboxPermanentFlags(ctx context.Context, mboxID imap.InternalMailboxID) (imap.FlagSet, error) + + GetMailboxAttributes(ctx context.Context, mboxID imap.InternalMailboxID) (imap.FlagSet, error) + + GetMailboxUID(ctx context.Context, mboxID imap.InternalMailboxID) (imap.UID, error) + + GetMailboxMessageCountAndUID(ctx context.Context, mboxID imap.InternalMailboxID) (int, imap.UID, error) + + GetMailboxMessageForNewSnapshot(ctx context.Context, mboxID imap.InternalMailboxID) ([]SnapshotMessageResult, error) + + MailboxTranslateRemoteIDs(ctx context.Context, mboxIDs []imap.MailboxID) ([]imap.InternalMailboxID, error) + + MailboxFilterContains(ctx context.Context, mboxID imap.InternalMailboxID, messageIDs []MessageIDPair) ([]imap.InternalMessageID, error) + + MailboxFilterContainsInternalID(ctx context.Context, mboxID imap.InternalMailboxID, messageIDs []imap.InternalMessageID) ([]imap.InternalMessageID, error) + + GetMailboxCount(ctx context.Context) (int, error) + + // GetMessageUIDsWithFlagsAfterAddOrUIDBump exploits a property of adding a message to or bumping the UIDs of existing message in mailbox. It can only be + // used if you can guarantee that the messageID list contains only IDs that have recently added or bumped in the mailbox. + GetMailboxMessageUIDsWithFlagsAfterAddOrUIDBump(ctx context.Context, mboxID imap.InternalMailboxID, messageIDs []imap.InternalMessageID) ([]UIDWithFlags, error) +} + +type MailboxWriteOps interface { + MailboxReadOps + + CreateMailbox(ctx context.Context, + mboxID imap.MailboxID, + name string, + flags, permFlags, attrs imap.FlagSet, + uidValidity imap.UID) (*Mailbox, error) + + GetOrCreateMailbox(ctx context.Context, + mboxID imap.MailboxID, + name string, + flags, permFlags, attrs imap.FlagSet, + uidValidity imap.UID) (*Mailbox, error) + + GetOrCreateMailboxAlt(ctx context.Context, + mbox imap.Mailbox, + delimiter string, + uidValidity imap.UID) (*Mailbox, error) + + RenameMailboxWithRemoteID(ctx context.Context, mboxID imap.MailboxID, name string) error + + DeleteMailboxWithRemoteID(ctx context.Context, mboxID imap.MailboxID) error + + BumpMailboxUIDNext(ctx context.Context, mboxID imap.InternalMailboxID, count int) error + + AddMessagesToMailbox(ctx context.Context, mboxID imap.InternalMailboxID, messageIDs []imap.InternalMessageID) ([]UIDWithFlags, error) + + BumpMailboxUIDsForMessage(ctx context.Context, mboxID imap.InternalMailboxID, messageIDs []imap.InternalMessageID) ([]UIDWithFlags, error) + + RemoveMessagesFromMailbox(ctx context.Context, mboxID imap.InternalMailboxID, messageIDs []imap.InternalMessageID) error + + ClearRecentFlagInMailboxOnMessage(ctx context.Context, mboxID imap.InternalMailboxID, messageID imap.InternalMessageID) error + + ClearRecentFlagsInMailbox(ctx context.Context, mboxID imap.InternalMailboxID) error + + CreateMailboxIfNotExists(ctx context.Context, mbox imap.Mailbox, delimiter string, uidValidity imap.UID) error + + SetMailboxMessagesDeletedFlag(ctx context.Context, mboxID imap.InternalMailboxID, messageIDs []imap.InternalMessageID, deleted bool) error + + SetMailboxSubscribed(ctx context.Context, mboxID imap.InternalMailboxID, subscribed bool) error + + UpdateRemoteMailboxID(ctx context.Context, mobxID imap.InternalMailboxID, remoteID imap.MailboxID) error + + SetMailboxUIDValidity(ctx context.Context, mboxID imap.InternalMailboxID, uidValidity imap.UID) error +} + +type SnapshotMessageResult struct { + InternalID imap.InternalMessageID `json:"uid_message"` + RemoteID imap.MessageID `json:"remote_id"` + UID imap.UID `json:"uid"` + Recent bool `json:"recent"` + Deleted bool `json:"deleted"` + Flags string `json:"flags"` +} + +func (msg *SnapshotMessageResult) GetFlagSet() imap.FlagSet { + var flagSet imap.FlagSet + + if len(msg.Flags) > 0 { + flags := strings.Split(msg.Flags, ",") + flagSet = imap.NewFlagSetFromSlice(flags) + } else { + flagSet = imap.NewFlagSet() + } + + if msg.Deleted { + flagSet.AddToSelf(imap.FlagDeleted) + } + + if msg.Recent { + flagSet.AddToSelf(imap.FlagRecent) + } + + return flagSet +} + +type UIDWithFlags struct { + InternalID imap.InternalMessageID `json:"uid_message"` + RemoteID imap.MessageID `json:"remote_id"` + UID imap.UID `json:"uid"` + Recent bool `json:"recent"` + Deleted bool `json:"deleted"` + Flags string `json:"flags"` +} + +func (u *UIDWithFlags) GetFlagSet() imap.FlagSet { + var flagSet imap.FlagSet + + if len(u.Flags) > 0 { + flags := strings.Split(u.Flags, ",") + flagSet = imap.NewFlagSetFromSlice(flags) + } else { + flagSet = imap.NewFlagSet() + } + + if u.Deleted { + flagSet.AddToSelf(imap.FlagDeleted) + } + + if u.Recent { + flagSet.AddToSelf(imap.FlagRecent) + } + + return flagSet +} diff --git a/db/ops_message.go b/db/ops_message.go new file mode 100644 index 00000000..ee9da379 --- /dev/null +++ b/db/ops_message.go @@ -0,0 +1,91 @@ +package db + +import ( + "context" + "github.com/ProtonMail/gluon/imap" + "github.com/bradenaw/juniper/xslices" + "time" +) + +type MessageReadOps interface { + MessageExists(ctx context.Context, id imap.InternalMessageID) (bool, error) + + MessageExistsWithRemoteID(ctx context.Context, id imap.MessageID) (bool, error) + + GetMessage(ctx context.Context, id imap.InternalMessageID) (*Message, error) + + GetTotalMessageCount(ctx context.Context) (int, error) + + GetMessageRemoteID(ctx context.Context, id imap.InternalMessageID) (imap.MessageID, error) + + GetImportedMessageData(ctx context.Context, id imap.InternalMessageID) (*Message, error) + + GetMessageDateAndSize(ctx context.Context, id imap.InternalMessageID) (time.Time, int, error) + + GetMessageMailboxIDs(ctx context.Context, id imap.InternalMessageID) ([]imap.InternalMailboxID, error) + + GetMessagesFlags(ctx context.Context, ids []imap.InternalMessageID) ([]MessageFlagSet, error) + + GetMessageIDsMarkedAsDelete(ctx context.Context) ([]imap.InternalMessageID, error) + + GetMessageIDFromRemoteID(ctx context.Context, id imap.MessageID) (imap.InternalMessageID, error) + + GetMessageDeletedFlag(ctx context.Context, id imap.InternalMessageID) (bool, error) + + GetAllMessagesIDsAsMap(ctx context.Context) (map[imap.InternalMessageID]struct{}, error) +} + +type MessageWriteOps interface { + MessageReadOps + + CreateMessages(ctx context.Context, reqs ...*CreateMessageReq) ([]*Message, error) + + CreateMessageAndAddToMailbox(ctx context.Context, mbox imap.InternalMailboxID, req *CreateMessageReq) (imap.UID, imap.FlagSet, error) + + MarkMessageAsDeleted(ctx context.Context, id imap.InternalMessageID) error + + MarkMessageAsDeletedAndAssignRandomRemoteID(ctx context.Context, id imap.InternalMessageID) error + + MarkMessageAsDeletedWithRemoteID(ctx context.Context, id imap.MessageID) error + + DeleteMessages(ctx context.Context, ids []imap.InternalMessageID) error + + UpdateRemoteMessageID(ctx context.Context, internalID imap.InternalMessageID, remoteID imap.MessageID) error + + AddFlagToMessages(ctx context.Context, ids []imap.InternalMessageID, flag string) error + + RemoveFlagFromMessages(ctx context.Context, ids []imap.InternalMessageID, flag string) error + + SetFlagsOnMessages(ctx context.Context, ids []imap.InternalMessageID, flags imap.FlagSet) error +} + +type CreateMessageReq struct { + Message imap.Message + InternalID imap.InternalMessageID + LiteralSize int + Body string + Structure string + Envelope string +} + +type MessageFlagSet struct { + ID imap.InternalMessageID + RemoteID imap.MessageID + FlagSet imap.FlagSet +} + +func NewFlagSet(msgUID *UID, flags []*MessageFlag) imap.FlagSet { + flagSet := imap.NewFlagSetFromSlice(xslices.Map(flags, func(flag *MessageFlag) string { + return flag.Value + })) + + if msgUID.Deleted { + flagSet.AddToSelf(imap.FlagDeleted) + } + + if msgUID.Recent { + flagSet.AddToSelf(imap.FlagRecent) + } + + return flagSet +} diff --git a/db/ops_subscription.go b/db/ops_subscription.go new file mode 100644 index 00000000..272342d5 --- /dev/null +++ b/db/ops_subscription.go @@ -0,0 +1,15 @@ +package db + +import ( + "context" + "github.com/ProtonMail/gluon/imap" +) + +type SubscriptionReadOps interface { + GetDeletedSubscriptionSet(ctx context.Context) (map[imap.MailboxID]*DeletedSubscription, error) +} + +type SubscriptionWriteOps interface { + AddDeletedSubscription(ctx context.Context, mboxName string, mboxID imap.MailboxID) error + RemoveDeletedSubscriptionWithName(ctx context.Context, mboxName string) (int, error) +} diff --git a/db/types.go b/db/types.go new file mode 100644 index 00000000..7f4a7b49 --- /dev/null +++ b/db/types.go @@ -0,0 +1,125 @@ +package db + +import ( + "fmt" + "github.com/ProtonMail/gluon/imap" + "time" +) + +type MailboxIDPair struct { + InternalID imap.InternalMailboxID + RemoteID imap.MailboxID +} + +func (m *MailboxIDPair) String() string { + return fmt.Sprintf("%v::%v", m.InternalID, m.RemoteID) +} + +type MessageIDPair struct { + InternalID imap.InternalMessageID + RemoteID imap.MessageID +} + +func (m *MessageIDPair) String() string { + return fmt.Sprintf("%v::%v", m.InternalID, m.RemoteID) +} + +func NewMailboxIDPair(mbox *Mailbox) MailboxIDPair { + return MailboxIDPair{ + InternalID: mbox.ID, + RemoteID: mbox.RemoteID, + } +} + +func NewMailboxIDPairWithoutRemote(internalID imap.InternalMailboxID) MailboxIDPair { + return MailboxIDPair{ + InternalID: internalID, + RemoteID: "", + } +} + +func NewMessageIDPair(msg *Message) MessageIDPair { + return MessageIDPair{ + InternalID: msg.ID, + RemoteID: msg.RemoteID, + } +} + +func SplitMessageIDPairSlice(s []MessageIDPair) ([]imap.InternalMessageID, []imap.MessageID) { + l := len(s) + + internalMessageIDs := make([]imap.InternalMessageID, 0, l) + remoteMessageIDs := make([]imap.MessageID, 0, l) + + for _, v := range s { + internalMessageIDs = append(internalMessageIDs, v.InternalID) + remoteMessageIDs = append(remoteMessageIDs, v.RemoteID) + } + + return internalMessageIDs, remoteMessageIDs +} + +func SplitMailboxIDPairSlice(s []MailboxIDPair) ([]imap.InternalMailboxID, []imap.MailboxID) { + l := len(s) + + internalMailboxIDs := make([]imap.InternalMailboxID, 0, l) + mailboxIDs := make([]imap.MailboxID, 0, l) + + for _, v := range s { + internalMailboxIDs = append(internalMailboxIDs, v.InternalID) + mailboxIDs = append(mailboxIDs, v.RemoteID) + } + + return internalMailboxIDs, mailboxIDs +} + +type MailboxFlag struct { + ID int + Value string +} + +type MailboxAttr struct { + ID int + Value string +} + +type Mailbox struct { + ID imap.InternalMailboxID + RemoteID imap.MailboxID + Name string + UIDNext imap.UID + UIDValidity imap.UID + Subscribed bool + Flags []*MailboxFlag + PermanentFlags []*MailboxFlag + Attributes []*MailboxAttr +} + +type MessageFlag struct { + ID int + Value string +} + +type Message struct { + ID imap.InternalMessageID + RemoteID imap.MessageID + Date time.Time + Size int + Body string + BodyStructure string + Envelope string + Deleted bool + Flags []*MessageFlag + UIDs []*UID +} + +type UID struct { + UID imap.UID + Deleted bool + Recent bool +} + +type DeletedSubscription struct { + Name string + RemoteID imap.MailboxID +} diff --git a/go.sum b/go.sum index 6dcdf870..58fabfd6 100644 --- a/go.sum +++ b/go.sum @@ -51,10 +51,12 @@ github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= github.com/kylelemons/godebug v0.0.0-20170820004349-d65d576e9348 h1:MtvEpTB6LX3vkb4ax0b5D2DHbNAUsen0Gx5wZoq3lV4= +github.com/mattn/go-runewidth v0.0.9 h1:Lm995f3rfxdpd6TSmuVCHVb/QhupuXlYr8sCI/QdE+0= github.com/mattn/go-sqlite3 v1.14.16 h1:yOQRA0RpS5PFz/oikGwBEqvAWhWg5ufRz4ETLjwpU1Y= github.com/mattn/go-sqlite3 v1.14.16/go.mod h1:2eHXhiwb8IkHr+BDWZGa96P6+rkvnG63S2DGjv9HUNg= github.com/mitchellh/go-wordwrap v1.0.1 h1:TLuKupo69TCn6TQSyGxwI1EblZZEsQ0vMlAFQflz0v0= github.com/mitchellh/go-wordwrap v1.0.1/go.mod h1:R62XHJLzvMFRBbcrT7m7WgmE1eOyTSsCt+hzestvNj0= +github.com/olekukonko/tablewriter v0.0.5 h1:P2Ga83D34wi1o9J6Wh1mRuqd4mF/x/lgBS7N7AbDhec= github.com/pierrec/lz4/v4 v4.1.17 h1:kV4Ip+/hUBC+8T6+2EgburRtkE9ef4nbY3f4dFhGjMc= github.com/pierrec/lz4/v4 v4.1.17/go.mod h1:gZWDp/Ze/IJXGXf23ltt2EXimqmTUXEy0GFuRQyBid4= github.com/pkg/diff v0.0.0-20210226163009-20ebb0f2a09e/go.mod h1:pJLUxLENpZxwdsKMEsNbx1VGcRFpLqf3715MtcvvzbA= @@ -67,6 +69,8 @@ github.com/rogpeppe/go-internal v1.9.0/go.mod h1:WtVeX8xhTBvf0smdhujwtBcq4Qrzq/f github.com/sergi/go-diff v1.0.0 h1:Kpca3qRNrduNnOQeazBd0ysaKrUJiIuISHxogkT9RPQ= github.com/sirupsen/logrus v1.9.0 h1:trlNQbNUG3OdDrDil03MCb1H2o9nJ1x4/5LYw7byDE0= github.com/sirupsen/logrus v1.9.0/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ= +github.com/spf13/cobra v1.6.1 h1:o94oiPyS4KD1mPy2fmcYYHHfCxLqYjJOhGsCHFZtEzA= +github.com/spf13/pflag v1.0.5 h1:iy+VFUOCP1a+8yFto/drg2CJ5u0yRoB7fZw3DKv/JXA= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= @@ -113,6 +117,7 @@ golang.org/x/text v0.7.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= golang.org/x/tools v0.1.1/go.mod h1:o0xws9oXOQQZyjljx8fwUC0k7L1pTE6eaCbjGeHmOkk= +golang.org/x/tools v0.3.1-0.20221202221704-aa9f4b2f3d57 h1:/X0t/E4VxbZE7MLS7auvE7YICHeVvbIa9vkOVvYW/24= golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= diff --git a/internal/backend/backend.go b/internal/backend/backend.go index 8bceb802..e2693064 100644 --- a/internal/backend/backend.go +++ b/internal/backend/backend.go @@ -3,6 +3,7 @@ package backend import ( "context" "fmt" + "github.com/ProtonMail/gluon/db" "sync" "sync/atomic" "time" @@ -10,9 +11,6 @@ import ( "github.com/ProtonMail/gluon/async" "github.com/ProtonMail/gluon/connector" "github.com/ProtonMail/gluon/imap" - "github.com/ProtonMail/gluon/internal/db" - "github.com/ProtonMail/gluon/internal/db/ent" - "github.com/ProtonMail/gluon/internal/db/ent/mailbox" "github.com/ProtonMail/gluon/internal/state" "github.com/ProtonMail/gluon/limits" "github.com/ProtonMail/gluon/reporter" @@ -51,10 +49,19 @@ type Backend struct { imapLimits limits.IMAP + database db.ClientInterface + panicHandler async.PanicHandler } -func New(dataDir, databaseDir string, storeBuilder store.Builder, delim string, loginJailTime time.Duration, imapLimits limits.IMAP, panicHandler async.PanicHandler) (*Backend, error) { +func New(dataDir, databaseDir string, + storeBuilder store.Builder, + delim string, + loginJailTime time.Duration, + imapLimits limits.IMAP, + panicHandler async.PanicHandler, + database db.ClientInterface, +) (*Backend, error) { return &Backend{ dataDir: dataDir, databaseDir: databaseDir, @@ -64,6 +71,7 @@ func New(dataDir, databaseDir string, storeBuilder store.Builder, delim string, loginJailTime: loginJailTime, imapLimits: imapLimits, panicHandler: panicHandler, + database: database, }, nil } @@ -86,7 +94,7 @@ func (b *Backend) AddUser(ctx context.Context, userID string, conn connector.Con return false, err } - db, isNew, err := db.NewDB(b.getDBDir(), userID) + db, isNew, err := b.database.New(b.getDBDir(), userID) if err != nil { if err := storeBuilder.Close(); err != nil { logrus.WithError(err).Error("Failed to close store builder") @@ -130,7 +138,7 @@ func (b *Backend) RemoveUser(ctx context.Context, userID string, removeFiles boo return err } - if err := db.DeleteDB(b.getDBDir(), userID); err != nil { + if err := b.database.Delete(b.getDBDir(), userID); err != nil { return err } } @@ -147,21 +155,21 @@ func (b *Backend) GetMailboxMessageCounts(ctx context.Context, userID string) (m return nil, ErrNoSuchUser } - return db.ReadResult(ctx, user.db, func(ctx context.Context, c *ent.Client) (map[imap.MailboxID]int, error) { + return db.ClientReadType(ctx, user.db, func(ctx context.Context, c db.ReadOnly) (map[imap.MailboxID]int, error) { counts := make(map[imap.MailboxID]int) - mailboxes, err := c.Mailbox.Query().Select(mailbox.FieldRemoteID).All(ctx) + mailboxes, err := c.GetAllMailboxesAsRemoteIDs(ctx) if err != nil { return nil, err } - for _, mailbox := range mailboxes { - messageCount, err := mailbox.QueryUIDs().Count(ctx) + for _, mailboxID := range mailboxes { + messageCount, err := c.GetMailboxMessageCountWithRemoteID(ctx, mailboxID) if err != nil { return nil, err } - counts[mailbox.RemoteID] = messageCount + counts[mailboxID] = messageCount } return counts, nil diff --git a/internal/backend/connector_updates.go b/internal/backend/connector_updates.go index bc240645..8fea5ac1 100644 --- a/internal/backend/connector_updates.go +++ b/internal/backend/connector_updates.go @@ -4,14 +4,13 @@ import ( "bytes" "context" "fmt" + "github.com/ProtonMail/gluon/db" "io" "os" "runtime" "strings" "github.com/ProtonMail/gluon/imap" - "github.com/ProtonMail/gluon/internal/db" - "github.com/ProtonMail/gluon/internal/db/ent" "github.com/ProtonMail/gluon/internal/ids" "github.com/ProtonMail/gluon/internal/state" "github.com/ProtonMail/gluon/rfc822" @@ -79,8 +78,8 @@ func (user *user) applyMailboxCreated(ctx context.Context, update *imap.MailboxC return fmt.Errorf("attempting to create protected mailbox (recovery)") } - if exists, err := db.ReadResult(ctx, user.db, func(ctx context.Context, client *ent.Client) (bool, error) { - return db.MailboxExistsWithRemoteID(ctx, client, update.Mailbox.ID) + if exists, err := db.ClientReadType(ctx, user.db, func(ctx context.Context, client db.ReadOnly) (bool, error) { + return client.MailboxExistsWithRemoteID(ctx, update.Mailbox.ID) }); err != nil { return err } else if exists { @@ -96,16 +95,15 @@ func (user *user) applyMailboxCreated(ctx context.Context, update *imap.MailboxC return err } - return user.db.Write(ctx, func(ctx context.Context, tx *ent.Tx) error { - if mailboxCount, err := db.GetMailboxCount(ctx, tx.Client()); err != nil { + return user.db.Write(ctx, func(ctx context.Context, tx db.Transaction) error { + if mailboxCount, err := tx.GetMailboxCount(ctx); err != nil { return err } else if err := user.imapLimits.CheckMailBoxCount(mailboxCount); err != nil { return err } - if _, err := db.CreateMailbox( + if _, err := tx.CreateMailbox( ctx, - tx, update.Mailbox.ID, strings.Join(update.Mailbox.Name, user.delimiter), update.Mailbox.Flags, @@ -126,21 +124,21 @@ func (user *user) applyMailboxDeleted(ctx context.Context, update *imap.MailboxD return fmt.Errorf("attempting to delete protected mailbox (recovery)") } - stateUpdate, err := db.WriteResult(ctx, user.db, func(ctx context.Context, tx *ent.Tx) (state.Update, error) { - mailbox, err := db.GetMailboxByRemoteID(ctx, tx.Client(), update.MailboxID) + stateUpdate, err := db.ClientWriteType(ctx, user.db, func(ctx context.Context, tx db.Transaction) (state.Update, error) { + mailbox, err := tx.GetMailboxByRemoteID(ctx, update.MailboxID) if err != nil { - if ent.IsNotFound(err) { + if db.IsErrNotFound(err) { return nil, nil } return nil, err } - if err := db.DeleteMailboxWithRemoteID(ctx, tx, update.MailboxID); err != nil { + if err := tx.DeleteMailboxWithRemoteID(ctx, update.MailboxID); err != nil { return nil, err } - if _, err := db.RemoveDeletedSubscriptionWithName(ctx, tx, mailbox.Name); err != nil { + if _, err := tx.RemoveDeletedSubscriptionWithName(ctx, mailbox.Name); err != nil { return nil, err } @@ -163,16 +161,14 @@ func (user *user) applyMailboxUpdated(ctx context.Context, update *imap.MailboxU return fmt.Errorf("attempting to rename protected mailbox (recovery)") } - return user.db.Write(ctx, func(ctx context.Context, tx *ent.Tx) error { - client := tx.Client() - - if exists, err := db.MailboxExistsWithRemoteID(ctx, client, update.MailboxID); err != nil { + return user.db.Write(ctx, func(ctx context.Context, tx db.Transaction) error { + if exists, err := tx.MailboxExistsWithRemoteID(ctx, update.MailboxID); err != nil { return err } else if !exists { return nil } - currentName, err := db.GetMailboxNameWithRemoteID(ctx, client, update.MailboxID) + currentName, err := tx.GetMailboxNameWithRemoteID(ctx, update.MailboxID) if err != nil { return err } @@ -181,7 +177,7 @@ func (user *user) applyMailboxUpdated(ctx context.Context, update *imap.MailboxU return nil } - return db.RenameMailboxWithRemoteID(ctx, tx, update.MailboxID, strings.Join(update.MailboxName, user.delimiter)) + return tx.RenameMailboxWithRemoteID(ctx, update.MailboxID, strings.Join(update.MailboxName, user.delimiter)) }) } @@ -191,8 +187,8 @@ func (user *user) applyMailboxIDChanged(ctx context.Context, update *imap.Mailbo return fmt.Errorf("attempting to change protected mailbox (recovery) remote ID") } - return user.db.Write(ctx, func(ctx context.Context, tx *ent.Tx) error { - if err := db.UpdateRemoteMailboxID(ctx, tx, update.InternalID, update.RemoteID); err != nil { + return user.db.Write(ctx, func(ctx context.Context, tx db.Transaction) error { + if err := tx.UpdateRemoteMailboxID(ctx, update.InternalID, update.RemoteID); err != nil { return err } @@ -215,9 +211,7 @@ func (user *user) applyMessagesCreated(ctx context.Context, update *imap.Message messageForMBox := make(map[imap.InternalMailboxID][]imap.InternalMessageID) mboxInternalIDMap := make(map[imap.MailboxID]imap.InternalMailboxID) - err := user.db.Write(ctx, func(ctx context.Context, tx *ent.Tx) error { - client := tx.Client() - + err := user.db.Write(ctx, func(ctx context.Context, tx db.Transaction) error { for _, message := range update.Messages { if slices.Contains(message.MailboxIDs, ids.GluonInternalRecoveryMailboxRemoteID) { logrus.Errorf("attempting to import messages into protected mailbox (recovery), skipping") @@ -226,8 +220,8 @@ func (user *user) applyMessagesCreated(ctx context.Context, update *imap.Message internalID, ok := messagesToCreateFilter[message.Message.ID] if !ok { - messageID, err := db.GetMessageIDFromRemoteID(ctx, client, message.Message.ID) - if ent.IsNotFound(err) { + messageID, err := tx.GetMessageIDFromRemoteID(ctx, message.Message.ID) + if db.IsErrNotFound(err) { internalID = imap.NewInternalMessageID() literalReader, literalSize, err := rfc822.SetHeaderValueNoMemCopy(message.Literal, ids.InternalIDKey, internalID.String()) @@ -259,7 +253,7 @@ func (user *user) applyMessagesCreated(ctx context.Context, update *imap.Message for _, mboxID := range message.MailboxIDs { v, ok := mboxInternalIDMap[mboxID] if !ok { - internalMBoxID, err := db.GetMailboxIDWithRemoteID(ctx, client, mboxID) + internalMBoxID, err := tx.GetMailboxIDFromRemoteID(ctx, mboxID) if err != nil { // If a mailbox doesn't exist and we are allowed to skip move to next mailbox. if update.IgnoreUnknownMailboxIDs { @@ -310,7 +304,7 @@ func (user *user) applyMessagesCreated(ctx context.Context, update *imap.Message } // Create message in the database - if _, err := db.CreateMessages(ctx, tx, xslices.Map(chunk, func(req *DBRequestWithLiteral) *db.CreateMessageReq { + if _, err := tx.CreateMessages(ctx, xslices.Map(chunk, func(req *DBRequestWithLiteral) *db.CreateMessageReq { return &req.CreateMessageReq })...); err != nil { return err @@ -319,7 +313,7 @@ func (user *user) applyMessagesCreated(ctx context.Context, update *imap.Message // Assign all the messages to the mailbox for mboxID, msgList := range messageForMBox { - inMailbox, err := db.FilterMailboxContainsInternalID(ctx, tx.Client(), mboxID, msgList) + inMailbox, err := tx.MailboxFilterContainsInternalID(ctx, mboxID, msgList) if err != nil { return err } @@ -358,23 +352,24 @@ func (user *user) applyMessageMailboxesUpdated(ctx context.Context, update *imap return fmt.Errorf("attempting to move messages into protected mailbox (recovery)") } - if exists, err := db.ReadResult(ctx, user.db, func(ctx context.Context, client *ent.Client) (bool, error) { - return db.MessageExistsWithRemoteID(ctx, client, update.MessageID) + if exists, err := db.ClientReadType(ctx, user.db, func(ctx context.Context, client db.ReadOnly) (bool, error) { + return client.MessageExistsWithRemoteID(ctx, update.MessageID) }); err != nil { return err } else if !exists { return state.ErrNoSuchMessage } - return user.db.Write(ctx, func(ctx context.Context, tx *ent.Tx) error { - client := tx.Client() - - internalMsgID, err := db.GetMessageIDFromRemoteID(ctx, client, update.MessageID) + return user.db.Write(ctx, func(ctx context.Context, tx db.Transaction) error { + internalMsgID, err := tx.GetMessageIDFromRemoteID(ctx, update.MessageID) if err != nil { + if db.IsErrNotFound(err) { + return state.ErrNoSuchMessage + } return err } - internalMBoxIDs, err := db.TranslateRemoteMailboxIDs(ctx, client, update.MailboxIDs) + internalMBoxIDs, err := tx.MailboxTranslateRemoteIDs(ctx, update.MailboxIDs) if err != nil { return err } @@ -393,19 +388,19 @@ func (user *user) applyMessageMailboxesUpdated(ctx context.Context, update *imap // applyMessageFlagsUpdated applies a MessageFlagsUpdated update. func (user *user) applyMessageFlagsUpdated(ctx context.Context, update *imap.MessageFlagsUpdated) error { - if exists, err := db.ReadResult(ctx, user.db, func(ctx context.Context, client *ent.Client) (bool, error) { - return db.MessageExistsWithRemoteID(ctx, client, update.MessageID) + if exists, err := db.ClientReadType(ctx, user.db, func(ctx context.Context, client db.ReadOnly) (bool, error) { + return client.MessageExistsWithRemoteID(ctx, update.MessageID) }); err != nil { return err } else if !exists { return state.ErrNoSuchMessage } - return user.db.Write(ctx, func(ctx context.Context, tx *ent.Tx) error { - internalMsgID, err := db.GetMessageIDFromRemoteID(ctx, tx.Client(), update.MessageID) + return user.db.Write(ctx, func(ctx context.Context, tx db.Transaction) error { + internalMsgID, err := tx.GetMessageIDFromRemoteID(ctx, update.MessageID) if err != nil { - if ent.IsNotFound(err) { + if db.IsErrNotFound(err) { return state.ErrNoSuchMessage } return err @@ -421,8 +416,8 @@ func (user *user) applyMessageFlagsUpdated(ctx context.Context, update *imap.Mes // applyMessageIDChanged applies a MessageIDChanged update. func (user *user) applyMessageIDChanged(ctx context.Context, update *imap.MessageIDChanged) error { - if err := user.db.Write(ctx, func(ctx context.Context, tx *ent.Tx) error { - return db.UpdateRemoteMessageID(ctx, tx, update.InternalID, update.RemoteID) + if err := user.db.Write(ctx, func(ctx context.Context, tx db.Transaction) error { + return tx.UpdateRemoteMessageID(ctx, update.InternalID, update.RemoteID) }); err != nil { return err } @@ -436,8 +431,8 @@ func (user *user) applyMessageIDChanged(ctx context.Context, update *imap.Messag return nil } -func (user *user) setMessageMailboxes(ctx context.Context, tx *ent.Tx, messageID imap.InternalMessageID, mboxIDs []imap.InternalMailboxID) error { - curMailboxIDs, err := db.GetMessageMailboxIDs(ctx, tx.Client(), messageID) +func (user *user) setMessageMailboxes(ctx context.Context, tx db.Transaction, messageID imap.InternalMessageID, mboxIDs []imap.InternalMailboxID) error { + curMailboxIDs, err := tx.GetMessageMailboxIDs(ctx, messageID) if err != nil { return err } @@ -458,7 +453,7 @@ func (user *user) setMessageMailboxes(ctx context.Context, tx *ent.Tx, messageID } // applyMessagesAddedToMailbox adds the messages to the given mailbox. -func (user *user) applyMessagesAddedToMailbox(ctx context.Context, tx *ent.Tx, mboxID imap.InternalMailboxID, messageIDs []imap.InternalMessageID) ([]db.UIDWithFlags, error) { +func (user *user) applyMessagesAddedToMailbox(ctx context.Context, tx db.Transaction, mboxID imap.InternalMailboxID, messageIDs []imap.InternalMessageID) ([]db.UIDWithFlags, error) { messageUIDs, update, err := state.AddMessagesToMailbox(ctx, tx, mboxID, messageIDs, nil, user.imapLimits) if err != nil { return nil, err @@ -470,7 +465,7 @@ func (user *user) applyMessagesAddedToMailbox(ctx context.Context, tx *ent.Tx, m } // applyMessagesRemovedFromMailbox removes the messages from the given mailbox. -func (user *user) applyMessagesRemovedFromMailbox(ctx context.Context, tx *ent.Tx, mboxID imap.InternalMailboxID, messageIDs []imap.InternalMessageID) error { +func (user *user) applyMessagesRemovedFromMailbox(ctx context.Context, tx db.Transaction, mboxID imap.InternalMailboxID, messageIDs []imap.InternalMessageID) error { updates, err := state.RemoveMessagesFromMailbox(ctx, tx, mboxID, messageIDs) if err != nil { return err @@ -483,8 +478,8 @@ func (user *user) applyMessagesRemovedFromMailbox(ctx context.Context, tx *ent.T return nil } -func (user *user) setMessageFlags(ctx context.Context, tx *ent.Tx, messageID imap.InternalMessageID, flags imap.FlagSet) error { - curFlags, err := db.GetMessageFlags(ctx, tx.Client(), []imap.InternalMessageID{messageID}) +func (user *user) setMessageFlags(ctx context.Context, tx db.Transaction, messageID imap.InternalMessageID, flags imap.FlagSet) error { + curFlags, err := tx.GetMessagesFlags(ctx, []imap.InternalMessageID{messageID}) if err != nil { return err } @@ -510,8 +505,8 @@ func (user *user) setMessageFlags(ctx context.Context, tx *ent.Tx, messageID ima return nil } -func (user *user) addMessageFlags(ctx context.Context, tx *ent.Tx, messageID imap.InternalMessageID, flag string) error { - if err := db.AddMessageFlag(ctx, tx, []imap.InternalMessageID{messageID}, flag); err != nil { +func (user *user) addMessageFlags(ctx context.Context, tx db.Transaction, messageID imap.InternalMessageID, flag string) error { + if err := tx.AddFlagToMessages(ctx, []imap.InternalMessageID{messageID}, flag); err != nil { return err } @@ -520,8 +515,8 @@ func (user *user) addMessageFlags(ctx context.Context, tx *ent.Tx, messageID ima return nil } -func (user *user) removeMessageFlags(ctx context.Context, tx *ent.Tx, messageID imap.InternalMessageID, flag string) error { - if err := db.RemoveMessageFlag(ctx, tx, []imap.InternalMessageID{messageID}, flag); err != nil { +func (user *user) removeMessageFlags(ctx context.Context, tx db.Transaction, messageID imap.InternalMessageID, flag string) error { + if err := tx.RemoveFlagFromMessages(ctx, []imap.InternalMessageID{messageID}, flag); err != nil { return err } @@ -533,25 +528,25 @@ func (user *user) removeMessageFlags(ctx context.Context, tx *ent.Tx, messageID func (user *user) applyMessageDeleted(ctx context.Context, update *imap.MessageDeleted) error { var stateUpdates []state.Update - if err := user.db.Write(ctx, func(ctx context.Context, tx *ent.Tx) error { - if err := db.MarkMessageAsDeletedWithRemoteID(ctx, tx, update.MessageID); err != nil { - if ent.IsNotFound(err) { + if err := user.db.Write(ctx, func(ctx context.Context, tx db.Transaction) error { + if err := tx.MarkMessageAsDeletedWithRemoteID(ctx, update.MessageID); err != nil { + if db.IsErrNotFound(err) { return nil } return err } - internalMessageID, err := db.GetMessageIDFromRemoteID(ctx, tx.Client(), update.MessageID) + internalMessageID, err := tx.GetMessageIDFromRemoteID(ctx, update.MessageID) if err != nil { - if ent.IsNotFound(err) { + if db.IsErrNotFound(err) { return nil } return err } - mailboxes, err := db.GetMessageMailboxIDs(ctx, tx.Client(), internalMessageID) + mailboxes, err := tx.GetMessageMailboxIDs(ctx, internalMessageID) if err != nil { return err } @@ -582,10 +577,10 @@ func (user *user) applyMessageDeleted(ctx context.Context, update *imap.MessageD func (user *user) applyMessageUpdated(ctx context.Context, update *imap.MessageUpdated) error { log := logrus.WithField("message updated", update.Message.ID.ShortID()) - internalMessageID, err := db.ReadResult(ctx, user.db, func(ctx context.Context, client *ent.Client) (imap.InternalMessageID, error) { - return db.GetMessageIDFromRemoteID(ctx, client, update.Message.ID) + internalMessageID, err := db.ClientReadType(ctx, user.db, func(ctx context.Context, client db.ReadOnly) (imap.InternalMessageID, error) { + return client.GetMessageIDFromRemoteID(ctx, update.Message.ID) }) - if ent.IsNotFound(err) { + if db.IsErrNotFound(err) { if update.AllowCreate { log.Warn("Message not found, creating it instead") @@ -603,7 +598,7 @@ func (user *user) applyMessageUpdated(ctx context.Context, update *imap.MessageU return err } - return user.db.Write(ctx, func(ctx context.Context, tx *ent.Tx) error { + return user.db.Write(ctx, func(ctx context.Context, tx db.Transaction) error { // compare and see if the literal has changed. onDiskLiteral, err := user.store.Get(internalMessageID) if err != nil { @@ -630,7 +625,7 @@ func (user *user) applyMessageUpdated(ctx context.Context, update *imap.MessageU targetMailboxes := make([]imap.InternalMailboxID, 0, len(update.MailboxIDs)) for _, mbox := range update.MailboxIDs { - internalMBoxID, err := db.GetMailboxIDFromRemoteID(ctx, tx.Client(), mbox) + internalMBoxID, err := tx.GetMailboxIDFromRemoteID(ctx, mbox) if err != nil { return err } @@ -645,7 +640,7 @@ func (user *user) applyMessageUpdated(ctx context.Context, update *imap.MessageU var stateUpdates []state.Update { // delete the message and remove from the mailboxes. - mailboxes, err := db.GetMessageMailboxIDs(ctx, tx.Client(), internalMessageID) + mailboxes, err := tx.GetMessageMailboxIDs(ctx, internalMessageID) if err != nil { return err } @@ -663,7 +658,7 @@ func (user *user) applyMessageUpdated(ctx context.Context, update *imap.MessageU // We need change the old remote id as it will break our table constraint otherwise and everything // will silently fail. - if err := db.MarkMessageAsDeletedAndAssignRandomRemoteID(ctx, tx, internalMessageID); err != nil { + if err := tx.MarkMessageAsDeletedAndAssignRandomRemoteID(ctx, internalMessageID); err != nil { return err } } @@ -685,7 +680,7 @@ func (user *user) applyMessageUpdated(ctx context.Context, update *imap.MessageU InternalID: newInternalID, } - if m, err := db.CreateMessages(ctx, tx, request); err != nil { + if m, err := tx.CreateMessages(ctx, request); err != nil { return err } else if len(m) == 0 { return fmt.Errorf("no messages were inserted") @@ -696,7 +691,7 @@ func (user *user) applyMessageUpdated(ctx context.Context, update *imap.MessageU } for _, mbox := range update.MailboxIDs { - internalMBoxID, err := db.GetMailboxIDFromRemoteID(ctx, tx.Client(), mbox) + internalMBoxID, err := tx.GetMailboxIDFromRemoteID(ctx, mbox) if err != nil { return err } @@ -721,8 +716,8 @@ func (user *user) applyMessageUpdated(ctx context.Context, update *imap.MessageU // applyUIDValidityBumped applies a UIDValidityBumped event to the user. func (user *user) applyUIDValidityBumped(ctx context.Context, update *imap.UIDValidityBumped) error { - if err := user.db.Write(ctx, func(ctx context.Context, tx *ent.Tx) error { - mailboxes, err := db.GetAllMailboxes(ctx, tx.Client()) + if err := user.db.Write(ctx, func(ctx context.Context, tx db.Transaction) error { + mailboxes, err := tx.GetAllMailboxes(ctx) if err != nil { return err } @@ -733,7 +728,7 @@ func (user *user) applyUIDValidityBumped(ctx context.Context, update *imap.UIDVa return err } - if _, err := mailbox.Update().SetUIDValidity(uidValidity).Save(ctx); err != nil { + if err := tx.SetMailboxUIDValidity(ctx, mailbox.ID, uidValidity); err != nil { return err } } diff --git a/internal/backend/state_user_interface_impl.go b/internal/backend/state_user_interface_impl.go index aebe3dd7..a348f518 100644 --- a/internal/backend/state_user_interface_impl.go +++ b/internal/backend/state_user_interface_impl.go @@ -2,11 +2,10 @@ package backend import ( "context" + "github.com/ProtonMail/gluon/db" "github.com/ProtonMail/gluon/internal/utils" "github.com/ProtonMail/gluon/imap" - "github.com/ProtonMail/gluon/internal/db" - "github.com/ProtonMail/gluon/internal/db/ent" "github.com/ProtonMail/gluon/internal/ids" "github.com/ProtonMail/gluon/internal/state" "github.com/ProtonMail/gluon/store" @@ -31,7 +30,7 @@ func (s *StateUserInterfaceImpl) GetDelimiter() string { return s.u.delimiter } -func (s *StateUserInterfaceImpl) GetDB() *db.DB { +func (s *StateUserInterfaceImpl) GetDB() db.Client { return s.u.db } @@ -43,7 +42,7 @@ func (s *StateUserInterfaceImpl) GetStore() *store.WriteControlledStore { return s.u.store } -func (s *StateUserInterfaceImpl) QueueOrApplyStateUpdate(ctx context.Context, tx *ent.Tx, updates ...state.Update) error { +func (s *StateUserInterfaceImpl) QueueOrApplyStateUpdate(ctx context.Context, tx db.Transaction, updates ...state.Update) error { // If we detect a state id in the context, it means this function call is a result of a User interaction. // When that happens the update needs to be applied to the state matching the state ID immediately. If no such // stateID exists or the context information is not present, all updates are queued for later execution. @@ -84,8 +83,8 @@ func (s *StateUserInterfaceImpl) GenerateUIDValidity() (imap.UID, error) { return s.u.uidValidityGenerator.Generate() } -func (s *StateUserInterfaceImpl) GetRecoveryMailboxID() ids.MailboxIDPair { - return ids.MailboxIDPair{ +func (s *StateUserInterfaceImpl) GetRecoveryMailboxID() db.MailboxIDPair { + return db.MailboxIDPair{ InternalID: s.u.recoveryMailboxID, RemoteID: ids.GluonInternalRecoveryMailboxRemoteID, } diff --git a/internal/backend/user.go b/internal/backend/user.go index cb923e7c..4c834850 100644 --- a/internal/backend/user.go +++ b/internal/backend/user.go @@ -7,9 +7,8 @@ import ( "github.com/ProtonMail/gluon/async" "github.com/ProtonMail/gluon/connector" + "github.com/ProtonMail/gluon/db" "github.com/ProtonMail/gluon/imap" - "github.com/ProtonMail/gluon/internal/db" - "github.com/ProtonMail/gluon/internal/db/ent" "github.com/ProtonMail/gluon/internal/ids" "github.com/ProtonMail/gluon/internal/state" "github.com/ProtonMail/gluon/internal/utils" @@ -30,7 +29,7 @@ type user struct { store *store.WriteControlledStore delimiter string - db *db.DB + db db.Client states map[state.StateID]*state.State statesLock sync.RWMutex @@ -53,7 +52,7 @@ type user struct { func newUser( ctx context.Context, userID string, - database *db.DB, + database db.Client, conn connector.Connector, st store.Store, delimiter string, @@ -68,7 +67,7 @@ func newUser( recoveredMessageHashes := utils.NewMessageHashesMap() // Create recovery mailbox if it does not exist - recoveryMBox, err := db.WriteResult(ctx, database, func(ctx context.Context, tx *ent.Tx) (*ent.Mailbox, error) { + recoveryMBox, err := db.ClientWriteType(ctx, database, func(ctx context.Context, tx db.Transaction) (*db.Mailbox, error) { uidValidity, err := uidValidityGenerator.Generate() if err != nil { return nil, err @@ -83,13 +82,13 @@ func newUser( Attributes: imap.NewFlagSet(imap.AttrNoInferiors), } - recoveryMBox, err := db.GetOrCreateMailbox(ctx, tx, mbox, delimiter, uidValidity) + recoveryMBox, err := tx.GetOrCreateMailboxAlt(ctx, mbox, delimiter, uidValidity) if err != nil { return nil, err } // Pre-fill the message hashes map - messages, err := db.GetMailboxMessageIDPairs(ctx, tx.Client(), recoveryMBox.ID) + messages, err := tx.GetMailboxMessageIDPairs(ctx, recoveryMBox.ID) if err != nil { return nil, err } @@ -218,13 +217,13 @@ func (user *user) close(ctx context.Context) error { func (user *user) deleteAllMessagesMarkedDeleted(ctx context.Context) error { // Delete messages in database first before deleting from the storage to avoid data loss. - ids, err := db.WriteResult(ctx, user.db, func(ctx context.Context, tx *ent.Tx) ([]imap.InternalMessageID, error) { - ids, err := db.GetMessageIDsMarkedDeleted(ctx, tx.Client()) + ids, err := db.ClientWriteType(ctx, user.db, func(ctx context.Context, tx db.Transaction) ([]imap.InternalMessageID, error) { + ids, err := tx.GetMessageIDsMarkedAsDelete(ctx) if err != nil { return nil, err } - if err := db.DeleteMessages(ctx, tx, ids...); err != nil { + if err := tx.DeleteMessages(ctx, ids); err != nil { return nil, err } @@ -270,8 +269,8 @@ func (user *user) newState() (*state.State, error) { } func (user *user) removeState(ctx context.Context, st *state.State) error { - messageIDs, err := db.ReadResult(ctx, user.db, func(ctx context.Context, client *ent.Client) ([]imap.InternalMessageID, error) { - return db.GetMessageIDsMarkedDeleted(ctx, client) + messageIDs, err := db.ClientReadType(ctx, user.db, func(ctx context.Context, client db.ReadOnly) ([]imap.InternalMessageID, error) { + return client.GetMessageIDsMarkedAsDelete(ctx) }) if err != nil { return err @@ -309,8 +308,8 @@ func (user *user) removeState(ctx context.Context, st *state.State) error { defer user.statesWG.Done() // Delete messages in database first before deleting from the storage to avoid data loss. - if err := user.db.Write(ctx, func(ctx context.Context, tx *ent.Tx) error { - if err := db.DeleteMessages(ctx, tx, messageIDs...); err != nil { + if err := user.db.Write(ctx, func(ctx context.Context, tx db.Transaction) error { + if err := tx.DeleteMessages(ctx, messageIDs); err != nil { return err } @@ -356,8 +355,8 @@ func (user *user) cleanupStaleStoreData(ctx context.Context) error { return err } - dbIdMap, err := db.ReadResult(ctx, user.db, func(ctx context.Context, client *ent.Client) (map[imap.InternalMessageID]struct{}, error) { - return db.GetAllMessagesIDsAsMap(ctx, client) + dbIdMap, err := db.ClientReadType(ctx, user.db, func(ctx context.Context, client db.ReadOnly) (map[imap.InternalMessageID]struct{}, error) { + return client.GetAllMessagesIDsAsMap(ctx) }) if err != nil { return err diff --git a/internal/data/db.go b/internal/data/db.go new file mode 100644 index 00000000..909c0fd6 --- /dev/null +++ b/internal/data/db.go @@ -0,0 +1,18 @@ +package data + +import ( + "errors" + "io/fs" + "os" +) + +// pathExists returns whether the given file exists. +func pathExists(path string) (bool, error) { + if _, err := os.Stat(path); errors.Is(err, fs.ErrNotExist) { + return false, nil + } else if err != nil { + return false, err + } + + return true, nil +} diff --git a/internal/db/ent/hook/hook.go b/internal/db/ent/hook/hook.go deleted file mode 100644 index 1dc0f3f3..00000000 --- a/internal/db/ent/hook/hook.go +++ /dev/null @@ -1,291 +0,0 @@ -// Code generated by ent, DO NOT EDIT. - -package hook - -import ( - "context" - "fmt" - - "github.com/ProtonMail/gluon/internal/db/ent" -) - -// The DeletedSubscriptionFunc type is an adapter to allow the use of ordinary -// function as DeletedSubscription mutator. -type DeletedSubscriptionFunc func(context.Context, *ent.DeletedSubscriptionMutation) (ent.Value, error) - -// Mutate calls f(ctx, m). -func (f DeletedSubscriptionFunc) Mutate(ctx context.Context, m ent.Mutation) (ent.Value, error) { - mv, ok := m.(*ent.DeletedSubscriptionMutation) - if !ok { - return nil, fmt.Errorf("unexpected mutation type %T. expect *ent.DeletedSubscriptionMutation", m) - } - return f(ctx, mv) -} - -// The MailboxFunc type is an adapter to allow the use of ordinary -// function as Mailbox mutator. -type MailboxFunc func(context.Context, *ent.MailboxMutation) (ent.Value, error) - -// Mutate calls f(ctx, m). -func (f MailboxFunc) Mutate(ctx context.Context, m ent.Mutation) (ent.Value, error) { - mv, ok := m.(*ent.MailboxMutation) - if !ok { - return nil, fmt.Errorf("unexpected mutation type %T. expect *ent.MailboxMutation", m) - } - return f(ctx, mv) -} - -// The MailboxAttrFunc type is an adapter to allow the use of ordinary -// function as MailboxAttr mutator. -type MailboxAttrFunc func(context.Context, *ent.MailboxAttrMutation) (ent.Value, error) - -// Mutate calls f(ctx, m). -func (f MailboxAttrFunc) Mutate(ctx context.Context, m ent.Mutation) (ent.Value, error) { - mv, ok := m.(*ent.MailboxAttrMutation) - if !ok { - return nil, fmt.Errorf("unexpected mutation type %T. expect *ent.MailboxAttrMutation", m) - } - return f(ctx, mv) -} - -// The MailboxFlagFunc type is an adapter to allow the use of ordinary -// function as MailboxFlag mutator. -type MailboxFlagFunc func(context.Context, *ent.MailboxFlagMutation) (ent.Value, error) - -// Mutate calls f(ctx, m). -func (f MailboxFlagFunc) Mutate(ctx context.Context, m ent.Mutation) (ent.Value, error) { - mv, ok := m.(*ent.MailboxFlagMutation) - if !ok { - return nil, fmt.Errorf("unexpected mutation type %T. expect *ent.MailboxFlagMutation", m) - } - return f(ctx, mv) -} - -// The MailboxPermFlagFunc type is an adapter to allow the use of ordinary -// function as MailboxPermFlag mutator. -type MailboxPermFlagFunc func(context.Context, *ent.MailboxPermFlagMutation) (ent.Value, error) - -// Mutate calls f(ctx, m). -func (f MailboxPermFlagFunc) Mutate(ctx context.Context, m ent.Mutation) (ent.Value, error) { - mv, ok := m.(*ent.MailboxPermFlagMutation) - if !ok { - return nil, fmt.Errorf("unexpected mutation type %T. expect *ent.MailboxPermFlagMutation", m) - } - return f(ctx, mv) -} - -// The MessageFunc type is an adapter to allow the use of ordinary -// function as Message mutator. -type MessageFunc func(context.Context, *ent.MessageMutation) (ent.Value, error) - -// Mutate calls f(ctx, m). -func (f MessageFunc) Mutate(ctx context.Context, m ent.Mutation) (ent.Value, error) { - mv, ok := m.(*ent.MessageMutation) - if !ok { - return nil, fmt.Errorf("unexpected mutation type %T. expect *ent.MessageMutation", m) - } - return f(ctx, mv) -} - -// The MessageFlagFunc type is an adapter to allow the use of ordinary -// function as MessageFlag mutator. -type MessageFlagFunc func(context.Context, *ent.MessageFlagMutation) (ent.Value, error) - -// Mutate calls f(ctx, m). -func (f MessageFlagFunc) Mutate(ctx context.Context, m ent.Mutation) (ent.Value, error) { - mv, ok := m.(*ent.MessageFlagMutation) - if !ok { - return nil, fmt.Errorf("unexpected mutation type %T. expect *ent.MessageFlagMutation", m) - } - return f(ctx, mv) -} - -// The UIDFunc type is an adapter to allow the use of ordinary -// function as UID mutator. -type UIDFunc func(context.Context, *ent.UIDMutation) (ent.Value, error) - -// Mutate calls f(ctx, m). -func (f UIDFunc) Mutate(ctx context.Context, m ent.Mutation) (ent.Value, error) { - mv, ok := m.(*ent.UIDMutation) - if !ok { - return nil, fmt.Errorf("unexpected mutation type %T. expect *ent.UIDMutation", m) - } - return f(ctx, mv) -} - -// Condition is a hook condition function. -type Condition func(context.Context, ent.Mutation) bool - -// And groups conditions with the AND operator. -func And(first, second Condition, rest ...Condition) Condition { - return func(ctx context.Context, m ent.Mutation) bool { - if !first(ctx, m) || !second(ctx, m) { - return false - } - for _, cond := range rest { - if !cond(ctx, m) { - return false - } - } - return true - } -} - -// Or groups conditions with the OR operator. -func Or(first, second Condition, rest ...Condition) Condition { - return func(ctx context.Context, m ent.Mutation) bool { - if first(ctx, m) || second(ctx, m) { - return true - } - for _, cond := range rest { - if cond(ctx, m) { - return true - } - } - return false - } -} - -// Not negates a given condition. -func Not(cond Condition) Condition { - return func(ctx context.Context, m ent.Mutation) bool { - return !cond(ctx, m) - } -} - -// HasOp is a condition testing mutation operation. -func HasOp(op ent.Op) Condition { - return func(_ context.Context, m ent.Mutation) bool { - return m.Op().Is(op) - } -} - -// HasAddedFields is a condition validating `.AddedField` on fields. -func HasAddedFields(field string, fields ...string) Condition { - return func(_ context.Context, m ent.Mutation) bool { - if _, exists := m.AddedField(field); !exists { - return false - } - for _, field := range fields { - if _, exists := m.AddedField(field); !exists { - return false - } - } - return true - } -} - -// HasClearedFields is a condition validating `.FieldCleared` on fields. -func HasClearedFields(field string, fields ...string) Condition { - return func(_ context.Context, m ent.Mutation) bool { - if exists := m.FieldCleared(field); !exists { - return false - } - for _, field := range fields { - if exists := m.FieldCleared(field); !exists { - return false - } - } - return true - } -} - -// HasFields is a condition validating `.Field` on fields. -func HasFields(field string, fields ...string) Condition { - return func(_ context.Context, m ent.Mutation) bool { - if _, exists := m.Field(field); !exists { - return false - } - for _, field := range fields { - if _, exists := m.Field(field); !exists { - return false - } - } - return true - } -} - -// If executes the given hook under condition. -// -// hook.If(ComputeAverage, And(HasFields(...), HasAddedFields(...))) -func If(hk ent.Hook, cond Condition) ent.Hook { - return func(next ent.Mutator) ent.Mutator { - return ent.MutateFunc(func(ctx context.Context, m ent.Mutation) (ent.Value, error) { - if cond(ctx, m) { - return hk(next).Mutate(ctx, m) - } - return next.Mutate(ctx, m) - }) - } -} - -// On executes the given hook only for the given operation. -// -// hook.On(Log, ent.Delete|ent.Create) -func On(hk ent.Hook, op ent.Op) ent.Hook { - return If(hk, HasOp(op)) -} - -// Unless skips the given hook only for the given operation. -// -// hook.Unless(Log, ent.Update|ent.UpdateOne) -func Unless(hk ent.Hook, op ent.Op) ent.Hook { - return If(hk, Not(HasOp(op))) -} - -// FixedError is a hook returning a fixed error. -func FixedError(err error) ent.Hook { - return func(ent.Mutator) ent.Mutator { - return ent.MutateFunc(func(context.Context, ent.Mutation) (ent.Value, error) { - return nil, err - }) - } -} - -// Reject returns a hook that rejects all operations that match op. -// -// func (T) Hooks() []ent.Hook { -// return []ent.Hook{ -// Reject(ent.Delete|ent.Update), -// } -// } -func Reject(op ent.Op) ent.Hook { - hk := FixedError(fmt.Errorf("%s operation is not allowed", op)) - return On(hk, op) -} - -// Chain acts as a list of hooks and is effectively immutable. -// Once created, it will always hold the same set of hooks in the same order. -type Chain struct { - hooks []ent.Hook -} - -// NewChain creates a new chain of hooks. -func NewChain(hooks ...ent.Hook) Chain { - return Chain{append([]ent.Hook(nil), hooks...)} -} - -// Hook chains the list of hooks and returns the final hook. -func (c Chain) Hook() ent.Hook { - return func(mutator ent.Mutator) ent.Mutator { - for i := len(c.hooks) - 1; i >= 0; i-- { - mutator = c.hooks[i](mutator) - } - return mutator - } -} - -// Append extends a chain, adding the specified hook -// as the last ones in the mutation flow. -func (c Chain) Append(hooks ...ent.Hook) Chain { - newHooks := make([]ent.Hook, 0, len(c.hooks)+len(hooks)) - newHooks = append(newHooks, c.hooks...) - newHooks = append(newHooks, hooks...) - return Chain{newHooks} -} - -// Extend extends a chain, adding the specified chain -// as the last ones in the mutation flow. -func (c Chain) Extend(chain Chain) Chain { - return c.Append(chain.hooks...) -} diff --git a/internal/db/ent/runtime/runtime.go b/internal/db/ent/runtime/runtime.go deleted file mode 100644 index f6a48303..00000000 --- a/internal/db/ent/runtime/runtime.go +++ /dev/null @@ -1,10 +0,0 @@ -// Code generated by ent, DO NOT EDIT. - -package runtime - -// The schema-stitching logic is generated in github.com/ProtonMail/gluon/internal/db/ent/runtime.go - -const ( - Version = "v0.11.2" // Version of ent codegen. - Sum = "h1:UM2/BUhF2FfsxPHRxLjQbhqJNaDdVlOwNIAMLs2jyto=" // Sum of ent codegen. -) diff --git a/internal/db_impl/db_impl.go b/internal/db_impl/db_impl.go new file mode 100644 index 00000000..1c96616a --- /dev/null +++ b/internal/db_impl/db_impl.go @@ -0,0 +1,10 @@ +package db_impl + +import ( + "github.com/ProtonMail/gluon/db" + "github.com/ProtonMail/gluon/internal/db_impl/ent_db" +) + +func NewEntDB() db.ClientInterface { + return ent_db.NewEntDB() +} diff --git a/internal/db/db.go b/internal/db_impl/ent_db/db.go similarity index 55% rename from internal/db/db.go rename to internal/db_impl/ent_db/db.go index b3885f35..0dbb7b1f 100644 --- a/internal/db/db.go +++ b/internal/db_impl/ent_db/db.go @@ -1,23 +1,23 @@ -package db +package ent_db import ( "context" "errors" "fmt" - "github.com/ProtonMail/gluon/internal/utils" - "github.com/google/uuid" "io/fs" "os" "path/filepath" "sync" "entgo.io/ent/dialect" - "github.com/ProtonMail/gluon/internal/db/ent" + "github.com/ProtonMail/gluon/db" + "github.com/ProtonMail/gluon/internal/db_impl/ent_db/internal" + "github.com/ProtonMail/gluon/internal/utils" "github.com/ProtonMail/gluon/reporter" ) type DB struct { - db *ent.Client + db *internal.Client lock sync.RWMutex } @@ -28,22 +28,36 @@ func (d *DB) Init(ctx context.Context) error { return d.db.Schema.Create(ctx) } -func (d *DB) Read(ctx context.Context, fn func(context.Context, *ent.Client) error) error { - _, err := ReadResult(ctx, d, func(ctx context.Context, client *ent.Client) (struct{}, error) { +func (d *DB) ReadEnt(ctx context.Context, fn func(context.Context, *internal.Client) error) error { + _, err := ReadResult(ctx, d, func(ctx context.Context, client *internal.Client) (struct{}, error) { return struct{}{}, fn(ctx, client) }) return err } -func (d *DB) Write(ctx context.Context, fn func(context.Context, *ent.Tx) error) error { - _, err := WriteResult(ctx, d, func(ctx context.Context, tx *ent.Tx) (struct{}, error) { +func (d *DB) WriteEnt(ctx context.Context, fn func(context.Context, *internal.Tx) error) error { + _, err := WriteResult(ctx, d, func(ctx context.Context, tx *internal.Tx) (struct{}, error) { return struct{}{}, fn(ctx, tx) }) return err } +func (d *DB) Read(ctx context.Context, fn func(context.Context, db.ReadOnly) error) error { + return d.ReadEnt(ctx, func(ctx context.Context, client *internal.Client) error { + rd := newOpsReadFromClient(client) + return fn(ctx, rd) + }) +} + +func (d *DB) Write(ctx context.Context, fn func(context.Context, db.Transaction) error) error { + return d.WriteEnt(ctx, func(ctx context.Context, tx *internal.Tx) error { + op := newEntOpsWrite(tx) + return fn(ctx, op) + }) +} + func (d *DB) Close() error { d.lock.Lock() defer d.lock.Unlock() @@ -51,14 +65,14 @@ func (d *DB) Close() error { return d.db.Close() } -func ReadResult[T any](ctx context.Context, db *DB, fn func(context.Context, *ent.Client) (T, error)) (T, error) { +func ReadResult[T any](ctx context.Context, db *DB, fn func(context.Context, *internal.Client) (T, error)) (T, error) { db.lock.RLock() defer db.lock.RUnlock() return fn(ctx, db.db) } -func WriteResult[T any](ctx context.Context, db *DB, fn func(context.Context, *ent.Tx) (T, error)) (T, error) { +func WriteResult[T any](ctx context.Context, db *DB, fn func(context.Context, *internal.Tx) (T, error)) (T, error) { db.lock.Lock() defer db.lock.Unlock() @@ -102,21 +116,13 @@ func WriteResult[T any](ctx context.Context, db *DB, fn func(context.Context, *e return result, nil } -func getDatabaseConn(dir, userID, path string) string { - return fmt.Sprintf("file:%v?cache=shared&_fk=1&_journal=WAL", path) -} +type EntDBBuilder struct{} -func getDatabasePath(dir, userID string) string { - return filepath.Join(dir, fmt.Sprintf("%v.db", userID)) -} - -func GetDeferredDeleteDBPath(dir string) string { - return filepath.Join(dir, "deferred_delete") +func NewEntDBBuilder() db.ClientInterface { + return &EntDBBuilder{} } -// NewDB creates a new database instance. -// If the database does not exist, it will be created and the second return value will be true. -func NewDB(dir, userID string) (*DB, bool, error) { +func (EntDBBuilder) New(dir string, userID string) (db.Client, bool, error) { if err := os.MkdirAll(dir, 0o700); err != nil { return nil, false, err } @@ -129,7 +135,7 @@ func NewDB(dir, userID string) (*DB, bool, error) { return nil, false, err } - client, err := ent.Open(dialect.SQLite, getDatabaseConn(dir, userID, path)) + client, err := internal.Open(dialect.SQLite, getDatabaseConn(dir, userID, path)) if err != nil { return nil, false, err } @@ -137,42 +143,12 @@ func NewDB(dir, userID string) (*DB, bool, error) { return &DB{db: client}, !exists, nil } -// DeleteDB will rename all the database files for the given user to a directory within the same folder to avoid -// issues with ent not being able to close the database on demand. The database will be cleaned up on the next -// run on the Gluon server. -func DeleteDB(dir, userID string) error { - // Rather than deleting the files immediately move them to a directory to be cleaned up later. - deferredDeletePath := GetDeferredDeleteDBPath(dir) - - if err := os.MkdirAll(deferredDeletePath, 0o700); err != nil { - return fmt.Errorf("failed to create deferred delete dir: %w", err) - } - - matchingFiles, err := filepath.Glob(filepath.Join(dir, userID+"*")) - if err != nil { - return fmt.Errorf("failed to match db files:%w", err) - } - - for _, file := range matchingFiles { - // Use new UUID to avoid conflict with existing files - if err := os.Rename(file, filepath.Join(deferredDeletePath, uuid.NewString())); err != nil { - return fmt.Errorf("failed to move db file '%v' :%w", file, err) - } - } - - return nil +func (EntDBBuilder) Delete(dir string, userID string) error { + return db.DeleteDB(dir, userID) } -// DeleteDeferredDBFiles deletes all data from previous databases that were scheduled for removal. -func DeleteDeferredDBFiles(dir string) error { - deferredDeleteDir := GetDeferredDeleteDBPath(dir) - if err := os.RemoveAll(deferredDeleteDir); err != nil { - if !errors.Is(err, os.ErrNotExist) { - return err - } - } - - return nil +func getDatabaseConn(dir, userID, path string) string { + return fmt.Sprintf("file:%v?cache=shared&_fk=1&_journal=WAL", path) } // pathExists returns whether the given file exists. @@ -185,3 +161,32 @@ func pathExists(path string) (bool, error) { return true, nil } + +func getDatabasePath(dir, userID string) string { + return filepath.Join(dir, fmt.Sprintf("%v.db", userID)) +} + +func NewEntDB() db.ClientInterface { + return &EntDBBuilder{} +} + +func wrapEntError(err error) error { + if err == nil { + return nil + } + + if internal.IsNotFound(err) { + return fmt.Errorf("%v (%w)", err, db.ErrNotFound) + } + + return err +} + +func wrapEntErrFn(fn func() error) error { + return wrapEntError(fn()) +} + +func wrapEntErrFnTyped[T any](fn func() (T, error)) (T, error) { + val, err := fn() + return val, wrapEntError(err) +} diff --git a/internal/db/ent/client.go b/internal/db_impl/ent_db/internal/client.go similarity index 72% rename from internal/db/ent/client.go rename to internal/db_impl/ent_db/internal/client.go index 17d75a15..b1047681 100644 --- a/internal/db/ent/client.go +++ b/internal/db_impl/ent_db/internal/client.go @@ -1,6 +1,6 @@ // Code generated by ent, DO NOT EDIT. -package ent +package internal import ( "context" @@ -9,16 +9,16 @@ import ( "log" "github.com/ProtonMail/gluon/imap" - "github.com/ProtonMail/gluon/internal/db/ent/migrate" + "github.com/ProtonMail/gluon/internal/db_impl/ent_db/internal/migrate" - "github.com/ProtonMail/gluon/internal/db/ent/deletedsubscription" - "github.com/ProtonMail/gluon/internal/db/ent/mailbox" - "github.com/ProtonMail/gluon/internal/db/ent/mailboxattr" - "github.com/ProtonMail/gluon/internal/db/ent/mailboxflag" - "github.com/ProtonMail/gluon/internal/db/ent/mailboxpermflag" - "github.com/ProtonMail/gluon/internal/db/ent/message" - "github.com/ProtonMail/gluon/internal/db/ent/messageflag" - "github.com/ProtonMail/gluon/internal/db/ent/uid" + "github.com/ProtonMail/gluon/internal/db_impl/ent_db/internal/deletedsubscription" + "github.com/ProtonMail/gluon/internal/db_impl/ent_db/internal/mailbox" + "github.com/ProtonMail/gluon/internal/db_impl/ent_db/internal/mailboxattr" + "github.com/ProtonMail/gluon/internal/db_impl/ent_db/internal/mailboxflag" + "github.com/ProtonMail/gluon/internal/db_impl/ent_db/internal/mailboxpermflag" + "github.com/ProtonMail/gluon/internal/db_impl/ent_db/internal/message" + "github.com/ProtonMail/gluon/internal/db_impl/ent_db/internal/messageflag" + "github.com/ProtonMail/gluon/internal/db_impl/ent_db/internal/uid" "entgo.io/ent/dialect" "entgo.io/ent/dialect/sql" @@ -50,7 +50,7 @@ type Client struct { // NewClient creates a new client configured with the given options. func NewClient(opts ...Option) *Client { - cfg := config{log: log.Println, hooks: &hooks{}} + cfg := config{log: log.Println, hooks: &hooks{}, inters: &inters{}} cfg.options(opts...) client := &Client{config: cfg} client.init() @@ -89,11 +89,11 @@ func Open(driverName, dataSourceName string, options ...Option) (*Client, error) // is used until the transaction is committed or rolled back. func (c *Client) Tx(ctx context.Context) (*Tx, error) { if _, ok := c.driver.(*txDriver); ok { - return nil, errors.New("ent: cannot start a transaction within a transaction") + return nil, errors.New("internal: cannot start a transaction within a transaction") } tx, err := newTx(ctx, c.driver) if err != nil { - return nil, fmt.Errorf("ent: starting a transaction: %w", err) + return nil, fmt.Errorf("internal: starting a transaction: %w", err) } cfg := c.config cfg.driver = tx @@ -173,6 +173,43 @@ func (c *Client) Use(hooks ...Hook) { c.UID.Use(hooks...) } +// Intercept adds the query interceptors to all the entity clients. +// In order to add interceptors to a specific client, call: `client.Node.Intercept(...)`. +func (c *Client) Intercept(interceptors ...Interceptor) { + c.DeletedSubscription.Intercept(interceptors...) + c.Mailbox.Intercept(interceptors...) + c.MailboxAttr.Intercept(interceptors...) + c.MailboxFlag.Intercept(interceptors...) + c.MailboxPermFlag.Intercept(interceptors...) + c.Message.Intercept(interceptors...) + c.MessageFlag.Intercept(interceptors...) + c.UID.Intercept(interceptors...) +} + +// Mutate implements the ent.Mutator interface. +func (c *Client) Mutate(ctx context.Context, m Mutation) (Value, error) { + switch m := m.(type) { + case *DeletedSubscriptionMutation: + return c.DeletedSubscription.mutate(ctx, m) + case *MailboxMutation: + return c.Mailbox.mutate(ctx, m) + case *MailboxAttrMutation: + return c.MailboxAttr.mutate(ctx, m) + case *MailboxFlagMutation: + return c.MailboxFlag.mutate(ctx, m) + case *MailboxPermFlagMutation: + return c.MailboxPermFlag.mutate(ctx, m) + case *MessageMutation: + return c.Message.mutate(ctx, m) + case *MessageFlagMutation: + return c.MessageFlag.mutate(ctx, m) + case *UIDMutation: + return c.UID.mutate(ctx, m) + default: + return nil, fmt.Errorf("internal: unknown mutation type %T", m) + } +} + // DeletedSubscriptionClient is a client for the DeletedSubscription schema. type DeletedSubscriptionClient struct { config @@ -189,6 +226,12 @@ func (c *DeletedSubscriptionClient) Use(hooks ...Hook) { c.hooks.DeletedSubscription = append(c.hooks.DeletedSubscription, hooks...) } +// Intercept adds a list of query interceptors to the interceptors stack. +// A call to `Intercept(f, g, h)` equals to `deletedsubscription.Intercept(f(g(h())))`. +func (c *DeletedSubscriptionClient) Intercept(interceptors ...Interceptor) { + c.inters.DeletedSubscription = append(c.inters.DeletedSubscription, interceptors...) +} + // Create returns a builder for creating a DeletedSubscription entity. func (c *DeletedSubscriptionClient) Create() *DeletedSubscriptionCreate { mutation := newDeletedSubscriptionMutation(c.config, OpCreate) @@ -229,7 +272,7 @@ func (c *DeletedSubscriptionClient) DeleteOne(ds *DeletedSubscription) *DeletedS return c.DeleteOneID(ds.ID) } -// DeleteOne returns a builder for deleting the given entity by its id. +// DeleteOneID returns a builder for deleting the given entity by its id. func (c *DeletedSubscriptionClient) DeleteOneID(id int) *DeletedSubscriptionDeleteOne { builder := c.Delete().Where(deletedsubscription.ID(id)) builder.mutation.id = &id @@ -241,6 +284,8 @@ func (c *DeletedSubscriptionClient) DeleteOneID(id int) *DeletedSubscriptionDele func (c *DeletedSubscriptionClient) Query() *DeletedSubscriptionQuery { return &DeletedSubscriptionQuery{ config: c.config, + ctx: &QueryContext{Type: TypeDeletedSubscription}, + inters: c.Interceptors(), } } @@ -263,6 +308,26 @@ func (c *DeletedSubscriptionClient) Hooks() []Hook { return c.hooks.DeletedSubscription } +// Interceptors returns the client interceptors. +func (c *DeletedSubscriptionClient) Interceptors() []Interceptor { + return c.inters.DeletedSubscription +} + +func (c *DeletedSubscriptionClient) mutate(ctx context.Context, m *DeletedSubscriptionMutation) (Value, error) { + switch m.Op() { + case OpCreate: + return (&DeletedSubscriptionCreate{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx) + case OpUpdate: + return (&DeletedSubscriptionUpdate{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx) + case OpUpdateOne: + return (&DeletedSubscriptionUpdateOne{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx) + case OpDelete, OpDeleteOne: + return (&DeletedSubscriptionDelete{config: c.config, hooks: c.Hooks(), mutation: m}).Exec(ctx) + default: + return nil, fmt.Errorf("internal: unknown DeletedSubscription mutation op: %q", m.Op()) + } +} + // MailboxClient is a client for the Mailbox schema. type MailboxClient struct { config @@ -279,6 +344,12 @@ func (c *MailboxClient) Use(hooks ...Hook) { c.hooks.Mailbox = append(c.hooks.Mailbox, hooks...) } +// Intercept adds a list of query interceptors to the interceptors stack. +// A call to `Intercept(f, g, h)` equals to `mailbox.Intercept(f(g(h())))`. +func (c *MailboxClient) Intercept(interceptors ...Interceptor) { + c.inters.Mailbox = append(c.inters.Mailbox, interceptors...) +} + // Create returns a builder for creating a Mailbox entity. func (c *MailboxClient) Create() *MailboxCreate { mutation := newMailboxMutation(c.config, OpCreate) @@ -319,7 +390,7 @@ func (c *MailboxClient) DeleteOne(m *Mailbox) *MailboxDeleteOne { return c.DeleteOneID(m.ID) } -// DeleteOne returns a builder for deleting the given entity by its id. +// DeleteOneID returns a builder for deleting the given entity by its id. func (c *MailboxClient) DeleteOneID(id imap.InternalMailboxID) *MailboxDeleteOne { builder := c.Delete().Where(mailbox.ID(id)) builder.mutation.id = &id @@ -331,6 +402,8 @@ func (c *MailboxClient) DeleteOneID(id imap.InternalMailboxID) *MailboxDeleteOne func (c *MailboxClient) Query() *MailboxQuery { return &MailboxQuery{ config: c.config, + ctx: &QueryContext{Type: TypeMailbox}, + inters: c.Interceptors(), } } @@ -350,8 +423,8 @@ func (c *MailboxClient) GetX(ctx context.Context, id imap.InternalMailboxID) *Ma // QueryUIDs queries the UIDs edge of a Mailbox. func (c *MailboxClient) QueryUIDs(m *Mailbox) *UIDQuery { - query := &UIDQuery{config: c.config} - query.path = func(ctx context.Context) (fromV *sql.Selector, _ error) { + query := (&UIDClient{config: c.config}).Query() + query.path = func(context.Context) (fromV *sql.Selector, _ error) { id := m.ID step := sqlgraph.NewStep( sqlgraph.From(mailbox.Table, mailbox.FieldID, id), @@ -366,8 +439,8 @@ func (c *MailboxClient) QueryUIDs(m *Mailbox) *UIDQuery { // QueryFlags queries the flags edge of a Mailbox. func (c *MailboxClient) QueryFlags(m *Mailbox) *MailboxFlagQuery { - query := &MailboxFlagQuery{config: c.config} - query.path = func(ctx context.Context) (fromV *sql.Selector, _ error) { + query := (&MailboxFlagClient{config: c.config}).Query() + query.path = func(context.Context) (fromV *sql.Selector, _ error) { id := m.ID step := sqlgraph.NewStep( sqlgraph.From(mailbox.Table, mailbox.FieldID, id), @@ -382,8 +455,8 @@ func (c *MailboxClient) QueryFlags(m *Mailbox) *MailboxFlagQuery { // QueryPermanentFlags queries the permanent_flags edge of a Mailbox. func (c *MailboxClient) QueryPermanentFlags(m *Mailbox) *MailboxPermFlagQuery { - query := &MailboxPermFlagQuery{config: c.config} - query.path = func(ctx context.Context) (fromV *sql.Selector, _ error) { + query := (&MailboxPermFlagClient{config: c.config}).Query() + query.path = func(context.Context) (fromV *sql.Selector, _ error) { id := m.ID step := sqlgraph.NewStep( sqlgraph.From(mailbox.Table, mailbox.FieldID, id), @@ -398,8 +471,8 @@ func (c *MailboxClient) QueryPermanentFlags(m *Mailbox) *MailboxPermFlagQuery { // QueryAttributes queries the attributes edge of a Mailbox. func (c *MailboxClient) QueryAttributes(m *Mailbox) *MailboxAttrQuery { - query := &MailboxAttrQuery{config: c.config} - query.path = func(ctx context.Context) (fromV *sql.Selector, _ error) { + query := (&MailboxAttrClient{config: c.config}).Query() + query.path = func(context.Context) (fromV *sql.Selector, _ error) { id := m.ID step := sqlgraph.NewStep( sqlgraph.From(mailbox.Table, mailbox.FieldID, id), @@ -417,6 +490,26 @@ func (c *MailboxClient) Hooks() []Hook { return c.hooks.Mailbox } +// Interceptors returns the client interceptors. +func (c *MailboxClient) Interceptors() []Interceptor { + return c.inters.Mailbox +} + +func (c *MailboxClient) mutate(ctx context.Context, m *MailboxMutation) (Value, error) { + switch m.Op() { + case OpCreate: + return (&MailboxCreate{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx) + case OpUpdate: + return (&MailboxUpdate{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx) + case OpUpdateOne: + return (&MailboxUpdateOne{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx) + case OpDelete, OpDeleteOne: + return (&MailboxDelete{config: c.config, hooks: c.Hooks(), mutation: m}).Exec(ctx) + default: + return nil, fmt.Errorf("internal: unknown Mailbox mutation op: %q", m.Op()) + } +} + // MailboxAttrClient is a client for the MailboxAttr schema. type MailboxAttrClient struct { config @@ -433,6 +526,12 @@ func (c *MailboxAttrClient) Use(hooks ...Hook) { c.hooks.MailboxAttr = append(c.hooks.MailboxAttr, hooks...) } +// Intercept adds a list of query interceptors to the interceptors stack. +// A call to `Intercept(f, g, h)` equals to `mailboxattr.Intercept(f(g(h())))`. +func (c *MailboxAttrClient) Intercept(interceptors ...Interceptor) { + c.inters.MailboxAttr = append(c.inters.MailboxAttr, interceptors...) +} + // Create returns a builder for creating a MailboxAttr entity. func (c *MailboxAttrClient) Create() *MailboxAttrCreate { mutation := newMailboxAttrMutation(c.config, OpCreate) @@ -473,7 +572,7 @@ func (c *MailboxAttrClient) DeleteOne(ma *MailboxAttr) *MailboxAttrDeleteOne { return c.DeleteOneID(ma.ID) } -// DeleteOne returns a builder for deleting the given entity by its id. +// DeleteOneID returns a builder for deleting the given entity by its id. func (c *MailboxAttrClient) DeleteOneID(id int) *MailboxAttrDeleteOne { builder := c.Delete().Where(mailboxattr.ID(id)) builder.mutation.id = &id @@ -485,6 +584,8 @@ func (c *MailboxAttrClient) DeleteOneID(id int) *MailboxAttrDeleteOne { func (c *MailboxAttrClient) Query() *MailboxAttrQuery { return &MailboxAttrQuery{ config: c.config, + ctx: &QueryContext{Type: TypeMailboxAttr}, + inters: c.Interceptors(), } } @@ -507,6 +608,26 @@ func (c *MailboxAttrClient) Hooks() []Hook { return c.hooks.MailboxAttr } +// Interceptors returns the client interceptors. +func (c *MailboxAttrClient) Interceptors() []Interceptor { + return c.inters.MailboxAttr +} + +func (c *MailboxAttrClient) mutate(ctx context.Context, m *MailboxAttrMutation) (Value, error) { + switch m.Op() { + case OpCreate: + return (&MailboxAttrCreate{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx) + case OpUpdate: + return (&MailboxAttrUpdate{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx) + case OpUpdateOne: + return (&MailboxAttrUpdateOne{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx) + case OpDelete, OpDeleteOne: + return (&MailboxAttrDelete{config: c.config, hooks: c.Hooks(), mutation: m}).Exec(ctx) + default: + return nil, fmt.Errorf("internal: unknown MailboxAttr mutation op: %q", m.Op()) + } +} + // MailboxFlagClient is a client for the MailboxFlag schema. type MailboxFlagClient struct { config @@ -523,6 +644,12 @@ func (c *MailboxFlagClient) Use(hooks ...Hook) { c.hooks.MailboxFlag = append(c.hooks.MailboxFlag, hooks...) } +// Intercept adds a list of query interceptors to the interceptors stack. +// A call to `Intercept(f, g, h)` equals to `mailboxflag.Intercept(f(g(h())))`. +func (c *MailboxFlagClient) Intercept(interceptors ...Interceptor) { + c.inters.MailboxFlag = append(c.inters.MailboxFlag, interceptors...) +} + // Create returns a builder for creating a MailboxFlag entity. func (c *MailboxFlagClient) Create() *MailboxFlagCreate { mutation := newMailboxFlagMutation(c.config, OpCreate) @@ -563,7 +690,7 @@ func (c *MailboxFlagClient) DeleteOne(mf *MailboxFlag) *MailboxFlagDeleteOne { return c.DeleteOneID(mf.ID) } -// DeleteOne returns a builder for deleting the given entity by its id. +// DeleteOneID returns a builder for deleting the given entity by its id. func (c *MailboxFlagClient) DeleteOneID(id int) *MailboxFlagDeleteOne { builder := c.Delete().Where(mailboxflag.ID(id)) builder.mutation.id = &id @@ -575,6 +702,8 @@ func (c *MailboxFlagClient) DeleteOneID(id int) *MailboxFlagDeleteOne { func (c *MailboxFlagClient) Query() *MailboxFlagQuery { return &MailboxFlagQuery{ config: c.config, + ctx: &QueryContext{Type: TypeMailboxFlag}, + inters: c.Interceptors(), } } @@ -597,6 +726,26 @@ func (c *MailboxFlagClient) Hooks() []Hook { return c.hooks.MailboxFlag } +// Interceptors returns the client interceptors. +func (c *MailboxFlagClient) Interceptors() []Interceptor { + return c.inters.MailboxFlag +} + +func (c *MailboxFlagClient) mutate(ctx context.Context, m *MailboxFlagMutation) (Value, error) { + switch m.Op() { + case OpCreate: + return (&MailboxFlagCreate{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx) + case OpUpdate: + return (&MailboxFlagUpdate{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx) + case OpUpdateOne: + return (&MailboxFlagUpdateOne{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx) + case OpDelete, OpDeleteOne: + return (&MailboxFlagDelete{config: c.config, hooks: c.Hooks(), mutation: m}).Exec(ctx) + default: + return nil, fmt.Errorf("internal: unknown MailboxFlag mutation op: %q", m.Op()) + } +} + // MailboxPermFlagClient is a client for the MailboxPermFlag schema. type MailboxPermFlagClient struct { config @@ -613,6 +762,12 @@ func (c *MailboxPermFlagClient) Use(hooks ...Hook) { c.hooks.MailboxPermFlag = append(c.hooks.MailboxPermFlag, hooks...) } +// Intercept adds a list of query interceptors to the interceptors stack. +// A call to `Intercept(f, g, h)` equals to `mailboxpermflag.Intercept(f(g(h())))`. +func (c *MailboxPermFlagClient) Intercept(interceptors ...Interceptor) { + c.inters.MailboxPermFlag = append(c.inters.MailboxPermFlag, interceptors...) +} + // Create returns a builder for creating a MailboxPermFlag entity. func (c *MailboxPermFlagClient) Create() *MailboxPermFlagCreate { mutation := newMailboxPermFlagMutation(c.config, OpCreate) @@ -653,7 +808,7 @@ func (c *MailboxPermFlagClient) DeleteOne(mpf *MailboxPermFlag) *MailboxPermFlag return c.DeleteOneID(mpf.ID) } -// DeleteOne returns a builder for deleting the given entity by its id. +// DeleteOneID returns a builder for deleting the given entity by its id. func (c *MailboxPermFlagClient) DeleteOneID(id int) *MailboxPermFlagDeleteOne { builder := c.Delete().Where(mailboxpermflag.ID(id)) builder.mutation.id = &id @@ -665,6 +820,8 @@ func (c *MailboxPermFlagClient) DeleteOneID(id int) *MailboxPermFlagDeleteOne { func (c *MailboxPermFlagClient) Query() *MailboxPermFlagQuery { return &MailboxPermFlagQuery{ config: c.config, + ctx: &QueryContext{Type: TypeMailboxPermFlag}, + inters: c.Interceptors(), } } @@ -687,6 +844,26 @@ func (c *MailboxPermFlagClient) Hooks() []Hook { return c.hooks.MailboxPermFlag } +// Interceptors returns the client interceptors. +func (c *MailboxPermFlagClient) Interceptors() []Interceptor { + return c.inters.MailboxPermFlag +} + +func (c *MailboxPermFlagClient) mutate(ctx context.Context, m *MailboxPermFlagMutation) (Value, error) { + switch m.Op() { + case OpCreate: + return (&MailboxPermFlagCreate{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx) + case OpUpdate: + return (&MailboxPermFlagUpdate{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx) + case OpUpdateOne: + return (&MailboxPermFlagUpdateOne{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx) + case OpDelete, OpDeleteOne: + return (&MailboxPermFlagDelete{config: c.config, hooks: c.Hooks(), mutation: m}).Exec(ctx) + default: + return nil, fmt.Errorf("internal: unknown MailboxPermFlag mutation op: %q", m.Op()) + } +} + // MessageClient is a client for the Message schema. type MessageClient struct { config @@ -703,6 +880,12 @@ func (c *MessageClient) Use(hooks ...Hook) { c.hooks.Message = append(c.hooks.Message, hooks...) } +// Intercept adds a list of query interceptors to the interceptors stack. +// A call to `Intercept(f, g, h)` equals to `message.Intercept(f(g(h())))`. +func (c *MessageClient) Intercept(interceptors ...Interceptor) { + c.inters.Message = append(c.inters.Message, interceptors...) +} + // Create returns a builder for creating a Message entity. func (c *MessageClient) Create() *MessageCreate { mutation := newMessageMutation(c.config, OpCreate) @@ -743,7 +926,7 @@ func (c *MessageClient) DeleteOne(m *Message) *MessageDeleteOne { return c.DeleteOneID(m.ID) } -// DeleteOne returns a builder for deleting the given entity by its id. +// DeleteOneID returns a builder for deleting the given entity by its id. func (c *MessageClient) DeleteOneID(id imap.InternalMessageID) *MessageDeleteOne { builder := c.Delete().Where(message.ID(id)) builder.mutation.id = &id @@ -755,6 +938,8 @@ func (c *MessageClient) DeleteOneID(id imap.InternalMessageID) *MessageDeleteOne func (c *MessageClient) Query() *MessageQuery { return &MessageQuery{ config: c.config, + ctx: &QueryContext{Type: TypeMessage}, + inters: c.Interceptors(), } } @@ -774,8 +959,8 @@ func (c *MessageClient) GetX(ctx context.Context, id imap.InternalMessageID) *Me // QueryFlags queries the flags edge of a Message. func (c *MessageClient) QueryFlags(m *Message) *MessageFlagQuery { - query := &MessageFlagQuery{config: c.config} - query.path = func(ctx context.Context) (fromV *sql.Selector, _ error) { + query := (&MessageFlagClient{config: c.config}).Query() + query.path = func(context.Context) (fromV *sql.Selector, _ error) { id := m.ID step := sqlgraph.NewStep( sqlgraph.From(message.Table, message.FieldID, id), @@ -790,8 +975,8 @@ func (c *MessageClient) QueryFlags(m *Message) *MessageFlagQuery { // QueryUIDs queries the UIDs edge of a Message. func (c *MessageClient) QueryUIDs(m *Message) *UIDQuery { - query := &UIDQuery{config: c.config} - query.path = func(ctx context.Context) (fromV *sql.Selector, _ error) { + query := (&UIDClient{config: c.config}).Query() + query.path = func(context.Context) (fromV *sql.Selector, _ error) { id := m.ID step := sqlgraph.NewStep( sqlgraph.From(message.Table, message.FieldID, id), @@ -809,6 +994,26 @@ func (c *MessageClient) Hooks() []Hook { return c.hooks.Message } +// Interceptors returns the client interceptors. +func (c *MessageClient) Interceptors() []Interceptor { + return c.inters.Message +} + +func (c *MessageClient) mutate(ctx context.Context, m *MessageMutation) (Value, error) { + switch m.Op() { + case OpCreate: + return (&MessageCreate{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx) + case OpUpdate: + return (&MessageUpdate{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx) + case OpUpdateOne: + return (&MessageUpdateOne{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx) + case OpDelete, OpDeleteOne: + return (&MessageDelete{config: c.config, hooks: c.Hooks(), mutation: m}).Exec(ctx) + default: + return nil, fmt.Errorf("internal: unknown Message mutation op: %q", m.Op()) + } +} + // MessageFlagClient is a client for the MessageFlag schema. type MessageFlagClient struct { config @@ -825,6 +1030,12 @@ func (c *MessageFlagClient) Use(hooks ...Hook) { c.hooks.MessageFlag = append(c.hooks.MessageFlag, hooks...) } +// Intercept adds a list of query interceptors to the interceptors stack. +// A call to `Intercept(f, g, h)` equals to `messageflag.Intercept(f(g(h())))`. +func (c *MessageFlagClient) Intercept(interceptors ...Interceptor) { + c.inters.MessageFlag = append(c.inters.MessageFlag, interceptors...) +} + // Create returns a builder for creating a MessageFlag entity. func (c *MessageFlagClient) Create() *MessageFlagCreate { mutation := newMessageFlagMutation(c.config, OpCreate) @@ -865,7 +1076,7 @@ func (c *MessageFlagClient) DeleteOne(mf *MessageFlag) *MessageFlagDeleteOne { return c.DeleteOneID(mf.ID) } -// DeleteOne returns a builder for deleting the given entity by its id. +// DeleteOneID returns a builder for deleting the given entity by its id. func (c *MessageFlagClient) DeleteOneID(id int) *MessageFlagDeleteOne { builder := c.Delete().Where(messageflag.ID(id)) builder.mutation.id = &id @@ -877,6 +1088,8 @@ func (c *MessageFlagClient) DeleteOneID(id int) *MessageFlagDeleteOne { func (c *MessageFlagClient) Query() *MessageFlagQuery { return &MessageFlagQuery{ config: c.config, + ctx: &QueryContext{Type: TypeMessageFlag}, + inters: c.Interceptors(), } } @@ -896,8 +1109,8 @@ func (c *MessageFlagClient) GetX(ctx context.Context, id int) *MessageFlag { // QueryMessages queries the messages edge of a MessageFlag. func (c *MessageFlagClient) QueryMessages(mf *MessageFlag) *MessageQuery { - query := &MessageQuery{config: c.config} - query.path = func(ctx context.Context) (fromV *sql.Selector, _ error) { + query := (&MessageClient{config: c.config}).Query() + query.path = func(context.Context) (fromV *sql.Selector, _ error) { id := mf.ID step := sqlgraph.NewStep( sqlgraph.From(messageflag.Table, messageflag.FieldID, id), @@ -915,6 +1128,26 @@ func (c *MessageFlagClient) Hooks() []Hook { return c.hooks.MessageFlag } +// Interceptors returns the client interceptors. +func (c *MessageFlagClient) Interceptors() []Interceptor { + return c.inters.MessageFlag +} + +func (c *MessageFlagClient) mutate(ctx context.Context, m *MessageFlagMutation) (Value, error) { + switch m.Op() { + case OpCreate: + return (&MessageFlagCreate{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx) + case OpUpdate: + return (&MessageFlagUpdate{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx) + case OpUpdateOne: + return (&MessageFlagUpdateOne{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx) + case OpDelete, OpDeleteOne: + return (&MessageFlagDelete{config: c.config, hooks: c.Hooks(), mutation: m}).Exec(ctx) + default: + return nil, fmt.Errorf("internal: unknown MessageFlag mutation op: %q", m.Op()) + } +} + // UIDClient is a client for the UID schema. type UIDClient struct { config @@ -931,6 +1164,12 @@ func (c *UIDClient) Use(hooks ...Hook) { c.hooks.UID = append(c.hooks.UID, hooks...) } +// Intercept adds a list of query interceptors to the interceptors stack. +// A call to `Intercept(f, g, h)` equals to `uid.Intercept(f(g(h())))`. +func (c *UIDClient) Intercept(interceptors ...Interceptor) { + c.inters.UID = append(c.inters.UID, interceptors...) +} + // Create returns a builder for creating a UID entity. func (c *UIDClient) Create() *UIDCreate { mutation := newUIDMutation(c.config, OpCreate) @@ -971,7 +1210,7 @@ func (c *UIDClient) DeleteOne(u *UID) *UIDDeleteOne { return c.DeleteOneID(u.ID) } -// DeleteOne returns a builder for deleting the given entity by its id. +// DeleteOneID returns a builder for deleting the given entity by its id. func (c *UIDClient) DeleteOneID(id int) *UIDDeleteOne { builder := c.Delete().Where(uid.ID(id)) builder.mutation.id = &id @@ -983,6 +1222,8 @@ func (c *UIDClient) DeleteOneID(id int) *UIDDeleteOne { func (c *UIDClient) Query() *UIDQuery { return &UIDQuery{ config: c.config, + ctx: &QueryContext{Type: TypeUID}, + inters: c.Interceptors(), } } @@ -1002,8 +1243,8 @@ func (c *UIDClient) GetX(ctx context.Context, id int) *UID { // QueryMessage queries the message edge of a UID. func (c *UIDClient) QueryMessage(u *UID) *MessageQuery { - query := &MessageQuery{config: c.config} - query.path = func(ctx context.Context) (fromV *sql.Selector, _ error) { + query := (&MessageClient{config: c.config}).Query() + query.path = func(context.Context) (fromV *sql.Selector, _ error) { id := u.ID step := sqlgraph.NewStep( sqlgraph.From(uid.Table, uid.FieldID, id), @@ -1018,8 +1259,8 @@ func (c *UIDClient) QueryMessage(u *UID) *MessageQuery { // QueryMailbox queries the mailbox edge of a UID. func (c *UIDClient) QueryMailbox(u *UID) *MailboxQuery { - query := &MailboxQuery{config: c.config} - query.path = func(ctx context.Context) (fromV *sql.Selector, _ error) { + query := (&MailboxClient{config: c.config}).Query() + query.path = func(context.Context) (fromV *sql.Selector, _ error) { id := u.ID step := sqlgraph.NewStep( sqlgraph.From(uid.Table, uid.FieldID, id), @@ -1036,3 +1277,23 @@ func (c *UIDClient) QueryMailbox(u *UID) *MailboxQuery { func (c *UIDClient) Hooks() []Hook { return c.hooks.UID } + +// Interceptors returns the client interceptors. +func (c *UIDClient) Interceptors() []Interceptor { + return c.inters.UID +} + +func (c *UIDClient) mutate(ctx context.Context, m *UIDMutation) (Value, error) { + switch m.Op() { + case OpCreate: + return (&UIDCreate{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx) + case OpUpdate: + return (&UIDUpdate{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx) + case OpUpdateOne: + return (&UIDUpdateOne{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx) + case OpDelete, OpDeleteOne: + return (&UIDDelete{config: c.config, hooks: c.Hooks(), mutation: m}).Exec(ctx) + default: + return nil, fmt.Errorf("internal: unknown UID mutation op: %q", m.Op()) + } +} diff --git a/internal/db/ent/config.go b/internal/db_impl/ent_db/internal/config.go similarity index 55% rename from internal/db/ent/config.go rename to internal/db_impl/ent_db/internal/config.go index 3f650be6..68d3879b 100644 --- a/internal/db/ent/config.go +++ b/internal/db_impl/ent_db/internal/config.go @@ -1,6 +1,6 @@ // Code generated by ent, DO NOT EDIT. -package ent +package internal import ( "entgo.io/ent" @@ -17,22 +17,36 @@ type config struct { // debug enable a debug logging. debug bool // log used for logging on debug mode. - log func(...interface{}) + log func(...any) // hooks to execute on mutations. hooks *hooks + // interceptors to execute on queries. + inters *inters } -// hooks per client, for fast access. -type hooks struct { - DeletedSubscription []ent.Hook - Mailbox []ent.Hook - MailboxAttr []ent.Hook - MailboxFlag []ent.Hook - MailboxPermFlag []ent.Hook - Message []ent.Hook - MessageFlag []ent.Hook - UID []ent.Hook -} +// hooks and interceptors per client, for fast access. +type ( + hooks struct { + DeletedSubscription []ent.Hook + Mailbox []ent.Hook + MailboxAttr []ent.Hook + MailboxFlag []ent.Hook + MailboxPermFlag []ent.Hook + Message []ent.Hook + MessageFlag []ent.Hook + UID []ent.Hook + } + inters struct { + DeletedSubscription []ent.Interceptor + Mailbox []ent.Interceptor + MailboxAttr []ent.Interceptor + MailboxFlag []ent.Interceptor + MailboxPermFlag []ent.Interceptor + Message []ent.Interceptor + MessageFlag []ent.Interceptor + UID []ent.Interceptor + } +) // Options applies the options on the config object. func (c *config) options(opts ...Option) { @@ -52,7 +66,7 @@ func Debug() Option { } // Log sets the logging function for debug mode. -func Log(fn func(...interface{})) Option { +func Log(fn func(...any)) Option { return func(c *config) { c.log = fn } diff --git a/internal/db/ent/context.go b/internal/db_impl/ent_db/internal/context.go similarity index 98% rename from internal/db/ent/context.go rename to internal/db_impl/ent_db/internal/context.go index 7811bfa2..da69c313 100644 --- a/internal/db/ent/context.go +++ b/internal/db_impl/ent_db/internal/context.go @@ -1,6 +1,6 @@ // Code generated by ent, DO NOT EDIT. -package ent +package internal import ( "context" diff --git a/internal/db/ent/deletedsubscription.go b/internal/db_impl/ent_db/internal/deletedsubscription.go similarity index 87% rename from internal/db/ent/deletedsubscription.go rename to internal/db_impl/ent_db/internal/deletedsubscription.go index bdf2ce8a..94b50f29 100644 --- a/internal/db/ent/deletedsubscription.go +++ b/internal/db_impl/ent_db/internal/deletedsubscription.go @@ -1,6 +1,6 @@ // Code generated by ent, DO NOT EDIT. -package ent +package internal import ( "fmt" @@ -8,7 +8,7 @@ import ( "entgo.io/ent/dialect/sql" "github.com/ProtonMail/gluon/imap" - "github.com/ProtonMail/gluon/internal/db/ent/deletedsubscription" + "github.com/ProtonMail/gluon/internal/db_impl/ent_db/internal/deletedsubscription" ) // DeletedSubscription is the model entity for the DeletedSubscription schema. @@ -23,8 +23,8 @@ type DeletedSubscription struct { } // scanValues returns the types for scanning values from sql.Rows. -func (*DeletedSubscription) scanValues(columns []string) ([]interface{}, error) { - values := make([]interface{}, len(columns)) +func (*DeletedSubscription) scanValues(columns []string) ([]any, error) { + values := make([]any, len(columns)) for i := range columns { switch columns[i] { case deletedsubscription.FieldID: @@ -40,7 +40,7 @@ func (*DeletedSubscription) scanValues(columns []string) ([]interface{}, error) // assignValues assigns the values that were returned from sql.Rows (after scanning) // to the DeletedSubscription fields. -func (ds *DeletedSubscription) assignValues(columns []string, values []interface{}) error { +func (ds *DeletedSubscription) assignValues(columns []string, values []any) error { if m, n := len(values), len(columns); m < n { return fmt.Errorf("mismatch number of scan values: %d != %d", m, n) } @@ -73,7 +73,7 @@ func (ds *DeletedSubscription) assignValues(columns []string, values []interface // Note that you need to call DeletedSubscription.Unwrap() before calling this method if this DeletedSubscription // was returned from a transaction, and the transaction was committed or rolled back. func (ds *DeletedSubscription) Update() *DeletedSubscriptionUpdateOne { - return (&DeletedSubscriptionClient{config: ds.config}).UpdateOne(ds) + return NewDeletedSubscriptionClient(ds.config).UpdateOne(ds) } // Unwrap unwraps the DeletedSubscription entity that was returned from a transaction after it was closed, @@ -81,7 +81,7 @@ func (ds *DeletedSubscription) Update() *DeletedSubscriptionUpdateOne { func (ds *DeletedSubscription) Unwrap() *DeletedSubscription { _tx, ok := ds.config.driver.(*txDriver) if !ok { - panic("ent: DeletedSubscription is not a transactional entity") + panic("internal: DeletedSubscription is not a transactional entity") } ds.config.driver = _tx.drv return ds @@ -103,9 +103,3 @@ func (ds *DeletedSubscription) String() string { // DeletedSubscriptions is a parsable slice of DeletedSubscription. type DeletedSubscriptions []*DeletedSubscription - -func (ds DeletedSubscriptions) config(cfg config) { - for _i := range ds { - ds[_i].config = cfg - } -} diff --git a/internal/db/ent/deletedsubscription/deletedsubscription.go b/internal/db_impl/ent_db/internal/deletedsubscription/deletedsubscription.go similarity index 100% rename from internal/db/ent/deletedsubscription/deletedsubscription.go rename to internal/db_impl/ent_db/internal/deletedsubscription/deletedsubscription.go diff --git a/internal/db/ent/deletedsubscription/where.go b/internal/db_impl/ent_db/internal/deletedsubscription/where.go similarity index 57% rename from internal/db/ent/deletedsubscription/where.go rename to internal/db_impl/ent_db/internal/deletedsubscription/where.go index d43f6049..66d35e48 100644 --- a/internal/db/ent/deletedsubscription/where.go +++ b/internal/db_impl/ent_db/internal/deletedsubscription/where.go @@ -5,302 +5,212 @@ package deletedsubscription import ( "entgo.io/ent/dialect/sql" "github.com/ProtonMail/gluon/imap" - "github.com/ProtonMail/gluon/internal/db/ent/predicate" + "github.com/ProtonMail/gluon/internal/db_impl/ent_db/internal/predicate" ) // ID filters vertices based on their ID field. func ID(id int) predicate.DeletedSubscription { - return predicate.DeletedSubscription(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldID), id)) - }) + return predicate.DeletedSubscription(sql.FieldEQ(FieldID, id)) } // IDEQ applies the EQ predicate on the ID field. func IDEQ(id int) predicate.DeletedSubscription { - return predicate.DeletedSubscription(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldID), id)) - }) + return predicate.DeletedSubscription(sql.FieldEQ(FieldID, id)) } // IDNEQ applies the NEQ predicate on the ID field. func IDNEQ(id int) predicate.DeletedSubscription { - return predicate.DeletedSubscription(func(s *sql.Selector) { - s.Where(sql.NEQ(s.C(FieldID), id)) - }) + return predicate.DeletedSubscription(sql.FieldNEQ(FieldID, id)) } // IDIn applies the In predicate on the ID field. func IDIn(ids ...int) predicate.DeletedSubscription { - return predicate.DeletedSubscription(func(s *sql.Selector) { - v := make([]interface{}, len(ids)) - for i := range v { - v[i] = ids[i] - } - s.Where(sql.In(s.C(FieldID), v...)) - }) + return predicate.DeletedSubscription(sql.FieldIn(FieldID, ids...)) } // IDNotIn applies the NotIn predicate on the ID field. func IDNotIn(ids ...int) predicate.DeletedSubscription { - return predicate.DeletedSubscription(func(s *sql.Selector) { - v := make([]interface{}, len(ids)) - for i := range v { - v[i] = ids[i] - } - s.Where(sql.NotIn(s.C(FieldID), v...)) - }) + return predicate.DeletedSubscription(sql.FieldNotIn(FieldID, ids...)) } // IDGT applies the GT predicate on the ID field. func IDGT(id int) predicate.DeletedSubscription { - return predicate.DeletedSubscription(func(s *sql.Selector) { - s.Where(sql.GT(s.C(FieldID), id)) - }) + return predicate.DeletedSubscription(sql.FieldGT(FieldID, id)) } // IDGTE applies the GTE predicate on the ID field. func IDGTE(id int) predicate.DeletedSubscription { - return predicate.DeletedSubscription(func(s *sql.Selector) { - s.Where(sql.GTE(s.C(FieldID), id)) - }) + return predicate.DeletedSubscription(sql.FieldGTE(FieldID, id)) } // IDLT applies the LT predicate on the ID field. func IDLT(id int) predicate.DeletedSubscription { - return predicate.DeletedSubscription(func(s *sql.Selector) { - s.Where(sql.LT(s.C(FieldID), id)) - }) + return predicate.DeletedSubscription(sql.FieldLT(FieldID, id)) } // IDLTE applies the LTE predicate on the ID field. func IDLTE(id int) predicate.DeletedSubscription { - return predicate.DeletedSubscription(func(s *sql.Selector) { - s.Where(sql.LTE(s.C(FieldID), id)) - }) + return predicate.DeletedSubscription(sql.FieldLTE(FieldID, id)) } // Name applies equality check predicate on the "Name" field. It's identical to NameEQ. func Name(v string) predicate.DeletedSubscription { - return predicate.DeletedSubscription(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldName), v)) - }) + return predicate.DeletedSubscription(sql.FieldEQ(FieldName, v)) } // RemoteID applies equality check predicate on the "RemoteID" field. It's identical to RemoteIDEQ. func RemoteID(v imap.MailboxID) predicate.DeletedSubscription { vc := string(v) - return predicate.DeletedSubscription(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldRemoteID), vc)) - }) + return predicate.DeletedSubscription(sql.FieldEQ(FieldRemoteID, vc)) } // NameEQ applies the EQ predicate on the "Name" field. func NameEQ(v string) predicate.DeletedSubscription { - return predicate.DeletedSubscription(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldName), v)) - }) + return predicate.DeletedSubscription(sql.FieldEQ(FieldName, v)) } // NameNEQ applies the NEQ predicate on the "Name" field. func NameNEQ(v string) predicate.DeletedSubscription { - return predicate.DeletedSubscription(func(s *sql.Selector) { - s.Where(sql.NEQ(s.C(FieldName), v)) - }) + return predicate.DeletedSubscription(sql.FieldNEQ(FieldName, v)) } // NameIn applies the In predicate on the "Name" field. func NameIn(vs ...string) predicate.DeletedSubscription { - v := make([]interface{}, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.DeletedSubscription(func(s *sql.Selector) { - s.Where(sql.In(s.C(FieldName), v...)) - }) + return predicate.DeletedSubscription(sql.FieldIn(FieldName, vs...)) } // NameNotIn applies the NotIn predicate on the "Name" field. func NameNotIn(vs ...string) predicate.DeletedSubscription { - v := make([]interface{}, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.DeletedSubscription(func(s *sql.Selector) { - s.Where(sql.NotIn(s.C(FieldName), v...)) - }) + return predicate.DeletedSubscription(sql.FieldNotIn(FieldName, vs...)) } // NameGT applies the GT predicate on the "Name" field. func NameGT(v string) predicate.DeletedSubscription { - return predicate.DeletedSubscription(func(s *sql.Selector) { - s.Where(sql.GT(s.C(FieldName), v)) - }) + return predicate.DeletedSubscription(sql.FieldGT(FieldName, v)) } // NameGTE applies the GTE predicate on the "Name" field. func NameGTE(v string) predicate.DeletedSubscription { - return predicate.DeletedSubscription(func(s *sql.Selector) { - s.Where(sql.GTE(s.C(FieldName), v)) - }) + return predicate.DeletedSubscription(sql.FieldGTE(FieldName, v)) } // NameLT applies the LT predicate on the "Name" field. func NameLT(v string) predicate.DeletedSubscription { - return predicate.DeletedSubscription(func(s *sql.Selector) { - s.Where(sql.LT(s.C(FieldName), v)) - }) + return predicate.DeletedSubscription(sql.FieldLT(FieldName, v)) } // NameLTE applies the LTE predicate on the "Name" field. func NameLTE(v string) predicate.DeletedSubscription { - return predicate.DeletedSubscription(func(s *sql.Selector) { - s.Where(sql.LTE(s.C(FieldName), v)) - }) + return predicate.DeletedSubscription(sql.FieldLTE(FieldName, v)) } // NameContains applies the Contains predicate on the "Name" field. func NameContains(v string) predicate.DeletedSubscription { - return predicate.DeletedSubscription(func(s *sql.Selector) { - s.Where(sql.Contains(s.C(FieldName), v)) - }) + return predicate.DeletedSubscription(sql.FieldContains(FieldName, v)) } // NameHasPrefix applies the HasPrefix predicate on the "Name" field. func NameHasPrefix(v string) predicate.DeletedSubscription { - return predicate.DeletedSubscription(func(s *sql.Selector) { - s.Where(sql.HasPrefix(s.C(FieldName), v)) - }) + return predicate.DeletedSubscription(sql.FieldHasPrefix(FieldName, v)) } // NameHasSuffix applies the HasSuffix predicate on the "Name" field. func NameHasSuffix(v string) predicate.DeletedSubscription { - return predicate.DeletedSubscription(func(s *sql.Selector) { - s.Where(sql.HasSuffix(s.C(FieldName), v)) - }) + return predicate.DeletedSubscription(sql.FieldHasSuffix(FieldName, v)) } // NameEqualFold applies the EqualFold predicate on the "Name" field. func NameEqualFold(v string) predicate.DeletedSubscription { - return predicate.DeletedSubscription(func(s *sql.Selector) { - s.Where(sql.EqualFold(s.C(FieldName), v)) - }) + return predicate.DeletedSubscription(sql.FieldEqualFold(FieldName, v)) } // NameContainsFold applies the ContainsFold predicate on the "Name" field. func NameContainsFold(v string) predicate.DeletedSubscription { - return predicate.DeletedSubscription(func(s *sql.Selector) { - s.Where(sql.ContainsFold(s.C(FieldName), v)) - }) + return predicate.DeletedSubscription(sql.FieldContainsFold(FieldName, v)) } // RemoteIDEQ applies the EQ predicate on the "RemoteID" field. func RemoteIDEQ(v imap.MailboxID) predicate.DeletedSubscription { vc := string(v) - return predicate.DeletedSubscription(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldRemoteID), vc)) - }) + return predicate.DeletedSubscription(sql.FieldEQ(FieldRemoteID, vc)) } // RemoteIDNEQ applies the NEQ predicate on the "RemoteID" field. func RemoteIDNEQ(v imap.MailboxID) predicate.DeletedSubscription { vc := string(v) - return predicate.DeletedSubscription(func(s *sql.Selector) { - s.Where(sql.NEQ(s.C(FieldRemoteID), vc)) - }) + return predicate.DeletedSubscription(sql.FieldNEQ(FieldRemoteID, vc)) } // RemoteIDIn applies the In predicate on the "RemoteID" field. func RemoteIDIn(vs ...imap.MailboxID) predicate.DeletedSubscription { - v := make([]interface{}, len(vs)) + v := make([]any, len(vs)) for i := range v { v[i] = string(vs[i]) } - return predicate.DeletedSubscription(func(s *sql.Selector) { - s.Where(sql.In(s.C(FieldRemoteID), v...)) - }) + return predicate.DeletedSubscription(sql.FieldIn(FieldRemoteID, v...)) } // RemoteIDNotIn applies the NotIn predicate on the "RemoteID" field. func RemoteIDNotIn(vs ...imap.MailboxID) predicate.DeletedSubscription { - v := make([]interface{}, len(vs)) + v := make([]any, len(vs)) for i := range v { v[i] = string(vs[i]) } - return predicate.DeletedSubscription(func(s *sql.Selector) { - s.Where(sql.NotIn(s.C(FieldRemoteID), v...)) - }) + return predicate.DeletedSubscription(sql.FieldNotIn(FieldRemoteID, v...)) } // RemoteIDGT applies the GT predicate on the "RemoteID" field. func RemoteIDGT(v imap.MailboxID) predicate.DeletedSubscription { vc := string(v) - return predicate.DeletedSubscription(func(s *sql.Selector) { - s.Where(sql.GT(s.C(FieldRemoteID), vc)) - }) + return predicate.DeletedSubscription(sql.FieldGT(FieldRemoteID, vc)) } // RemoteIDGTE applies the GTE predicate on the "RemoteID" field. func RemoteIDGTE(v imap.MailboxID) predicate.DeletedSubscription { vc := string(v) - return predicate.DeletedSubscription(func(s *sql.Selector) { - s.Where(sql.GTE(s.C(FieldRemoteID), vc)) - }) + return predicate.DeletedSubscription(sql.FieldGTE(FieldRemoteID, vc)) } // RemoteIDLT applies the LT predicate on the "RemoteID" field. func RemoteIDLT(v imap.MailboxID) predicate.DeletedSubscription { vc := string(v) - return predicate.DeletedSubscription(func(s *sql.Selector) { - s.Where(sql.LT(s.C(FieldRemoteID), vc)) - }) + return predicate.DeletedSubscription(sql.FieldLT(FieldRemoteID, vc)) } // RemoteIDLTE applies the LTE predicate on the "RemoteID" field. func RemoteIDLTE(v imap.MailboxID) predicate.DeletedSubscription { vc := string(v) - return predicate.DeletedSubscription(func(s *sql.Selector) { - s.Where(sql.LTE(s.C(FieldRemoteID), vc)) - }) + return predicate.DeletedSubscription(sql.FieldLTE(FieldRemoteID, vc)) } // RemoteIDContains applies the Contains predicate on the "RemoteID" field. func RemoteIDContains(v imap.MailboxID) predicate.DeletedSubscription { vc := string(v) - return predicate.DeletedSubscription(func(s *sql.Selector) { - s.Where(sql.Contains(s.C(FieldRemoteID), vc)) - }) + return predicate.DeletedSubscription(sql.FieldContains(FieldRemoteID, vc)) } // RemoteIDHasPrefix applies the HasPrefix predicate on the "RemoteID" field. func RemoteIDHasPrefix(v imap.MailboxID) predicate.DeletedSubscription { vc := string(v) - return predicate.DeletedSubscription(func(s *sql.Selector) { - s.Where(sql.HasPrefix(s.C(FieldRemoteID), vc)) - }) + return predicate.DeletedSubscription(sql.FieldHasPrefix(FieldRemoteID, vc)) } // RemoteIDHasSuffix applies the HasSuffix predicate on the "RemoteID" field. func RemoteIDHasSuffix(v imap.MailboxID) predicate.DeletedSubscription { vc := string(v) - return predicate.DeletedSubscription(func(s *sql.Selector) { - s.Where(sql.HasSuffix(s.C(FieldRemoteID), vc)) - }) + return predicate.DeletedSubscription(sql.FieldHasSuffix(FieldRemoteID, vc)) } // RemoteIDEqualFold applies the EqualFold predicate on the "RemoteID" field. func RemoteIDEqualFold(v imap.MailboxID) predicate.DeletedSubscription { vc := string(v) - return predicate.DeletedSubscription(func(s *sql.Selector) { - s.Where(sql.EqualFold(s.C(FieldRemoteID), vc)) - }) + return predicate.DeletedSubscription(sql.FieldEqualFold(FieldRemoteID, vc)) } // RemoteIDContainsFold applies the ContainsFold predicate on the "RemoteID" field. func RemoteIDContainsFold(v imap.MailboxID) predicate.DeletedSubscription { vc := string(v) - return predicate.DeletedSubscription(func(s *sql.Selector) { - s.Where(sql.ContainsFold(s.C(FieldRemoteID), vc)) - }) + return predicate.DeletedSubscription(sql.FieldContainsFold(FieldRemoteID, vc)) } // And groups predicates with the AND operator between them. diff --git a/internal/db/ent/deletedsubscription_create.go b/internal/db_impl/ent_db/internal/deletedsubscription_create.go similarity index 72% rename from internal/db/ent/deletedsubscription_create.go rename to internal/db_impl/ent_db/internal/deletedsubscription_create.go index e5c02209..773c24c1 100644 --- a/internal/db/ent/deletedsubscription_create.go +++ b/internal/db_impl/ent_db/internal/deletedsubscription_create.go @@ -1,6 +1,6 @@ // Code generated by ent, DO NOT EDIT. -package ent +package internal import ( "context" @@ -10,7 +10,7 @@ import ( "entgo.io/ent/dialect/sql/sqlgraph" "entgo.io/ent/schema/field" "github.com/ProtonMail/gluon/imap" - "github.com/ProtonMail/gluon/internal/db/ent/deletedsubscription" + "github.com/ProtonMail/gluon/internal/db_impl/ent_db/internal/deletedsubscription" ) // DeletedSubscriptionCreate is the builder for creating a DeletedSubscription entity. @@ -39,49 +39,7 @@ func (dsc *DeletedSubscriptionCreate) Mutation() *DeletedSubscriptionMutation { // Save creates the DeletedSubscription in the database. func (dsc *DeletedSubscriptionCreate) Save(ctx context.Context) (*DeletedSubscription, error) { - var ( - err error - node *DeletedSubscription - ) - if len(dsc.hooks) == 0 { - if err = dsc.check(); err != nil { - return nil, err - } - node, err = dsc.sqlSave(ctx) - } else { - var mut Mutator = MutateFunc(func(ctx context.Context, m Mutation) (Value, error) { - mutation, ok := m.(*DeletedSubscriptionMutation) - if !ok { - return nil, fmt.Errorf("unexpected mutation type %T", m) - } - if err = dsc.check(); err != nil { - return nil, err - } - dsc.mutation = mutation - if node, err = dsc.sqlSave(ctx); err != nil { - return nil, err - } - mutation.id = &node.ID - mutation.done = true - return node, err - }) - for i := len(dsc.hooks) - 1; i >= 0; i-- { - if dsc.hooks[i] == nil { - return nil, fmt.Errorf("ent: uninitialized hook (forgotten import ent/runtime?)") - } - mut = dsc.hooks[i](mut) - } - v, err := mut.Mutate(ctx, dsc.mutation) - if err != nil { - return nil, err - } - nv, ok := v.(*DeletedSubscription) - if !ok { - return nil, fmt.Errorf("unexpected node type %T returned from DeletedSubscriptionMutation", v) - } - node = nv - } - return node, err + return withHooks[*DeletedSubscription, DeletedSubscriptionMutation](ctx, dsc.sqlSave, dsc.mutation, dsc.hooks) } // SaveX calls Save and panics if Save returns an error. @@ -109,15 +67,18 @@ func (dsc *DeletedSubscriptionCreate) ExecX(ctx context.Context) { // check runs all checks and user-defined validators on the builder. func (dsc *DeletedSubscriptionCreate) check() error { if _, ok := dsc.mutation.Name(); !ok { - return &ValidationError{Name: "Name", err: errors.New(`ent: missing required field "DeletedSubscription.Name"`)} + return &ValidationError{Name: "Name", err: errors.New(`internal: missing required field "DeletedSubscription.Name"`)} } if _, ok := dsc.mutation.RemoteID(); !ok { - return &ValidationError{Name: "RemoteID", err: errors.New(`ent: missing required field "DeletedSubscription.RemoteID"`)} + return &ValidationError{Name: "RemoteID", err: errors.New(`internal: missing required field "DeletedSubscription.RemoteID"`)} } return nil } func (dsc *DeletedSubscriptionCreate) sqlSave(ctx context.Context) (*DeletedSubscription, error) { + if err := dsc.check(); err != nil { + return nil, err + } _node, _spec := dsc.createSpec() if err := sqlgraph.CreateNode(ctx, dsc.driver, _spec); err != nil { if sqlgraph.IsConstraintError(err) { @@ -127,34 +88,22 @@ func (dsc *DeletedSubscriptionCreate) sqlSave(ctx context.Context) (*DeletedSubs } id := _spec.ID.Value.(int64) _node.ID = int(id) + dsc.mutation.id = &_node.ID + dsc.mutation.done = true return _node, nil } func (dsc *DeletedSubscriptionCreate) createSpec() (*DeletedSubscription, *sqlgraph.CreateSpec) { var ( _node = &DeletedSubscription{config: dsc.config} - _spec = &sqlgraph.CreateSpec{ - Table: deletedsubscription.Table, - ID: &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Column: deletedsubscription.FieldID, - }, - } + _spec = sqlgraph.NewCreateSpec(deletedsubscription.Table, sqlgraph.NewFieldSpec(deletedsubscription.FieldID, field.TypeInt)) ) if value, ok := dsc.mutation.Name(); ok { - _spec.Fields = append(_spec.Fields, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Value: value, - Column: deletedsubscription.FieldName, - }) + _spec.SetField(deletedsubscription.FieldName, field.TypeString, value) _node.Name = value } if value, ok := dsc.mutation.RemoteID(); ok { - _spec.Fields = append(_spec.Fields, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Value: value, - Column: deletedsubscription.FieldRemoteID, - }) + _spec.SetField(deletedsubscription.FieldRemoteID, field.TypeString, value) _node.RemoteID = value } return _node, _spec diff --git a/internal/db/ent/deletedsubscription_delete.go b/internal/db_impl/ent_db/internal/deletedsubscription_delete.go similarity index 63% rename from internal/db/ent/deletedsubscription_delete.go rename to internal/db_impl/ent_db/internal/deletedsubscription_delete.go index df358765..f55bbfca 100644 --- a/internal/db/ent/deletedsubscription_delete.go +++ b/internal/db_impl/ent_db/internal/deletedsubscription_delete.go @@ -1,16 +1,15 @@ // Code generated by ent, DO NOT EDIT. -package ent +package internal import ( "context" - "fmt" "entgo.io/ent/dialect/sql" "entgo.io/ent/dialect/sql/sqlgraph" "entgo.io/ent/schema/field" - "github.com/ProtonMail/gluon/internal/db/ent/deletedsubscription" - "github.com/ProtonMail/gluon/internal/db/ent/predicate" + "github.com/ProtonMail/gluon/internal/db_impl/ent_db/internal/deletedsubscription" + "github.com/ProtonMail/gluon/internal/db_impl/ent_db/internal/predicate" ) // DeletedSubscriptionDelete is the builder for deleting a DeletedSubscription entity. @@ -28,34 +27,7 @@ func (dsd *DeletedSubscriptionDelete) Where(ps ...predicate.DeletedSubscription) // Exec executes the deletion query and returns how many vertices were deleted. func (dsd *DeletedSubscriptionDelete) Exec(ctx context.Context) (int, error) { - var ( - err error - affected int - ) - if len(dsd.hooks) == 0 { - affected, err = dsd.sqlExec(ctx) - } else { - var mut Mutator = MutateFunc(func(ctx context.Context, m Mutation) (Value, error) { - mutation, ok := m.(*DeletedSubscriptionMutation) - if !ok { - return nil, fmt.Errorf("unexpected mutation type %T", m) - } - dsd.mutation = mutation - affected, err = dsd.sqlExec(ctx) - mutation.done = true - return affected, err - }) - for i := len(dsd.hooks) - 1; i >= 0; i-- { - if dsd.hooks[i] == nil { - return 0, fmt.Errorf("ent: uninitialized hook (forgotten import ent/runtime?)") - } - mut = dsd.hooks[i](mut) - } - if _, err := mut.Mutate(ctx, dsd.mutation); err != nil { - return 0, err - } - } - return affected, err + return withHooks[int, DeletedSubscriptionMutation](ctx, dsd.sqlExec, dsd.mutation, dsd.hooks) } // ExecX is like Exec, but panics if an error occurs. @@ -68,15 +40,7 @@ func (dsd *DeletedSubscriptionDelete) ExecX(ctx context.Context) int { } func (dsd *DeletedSubscriptionDelete) sqlExec(ctx context.Context) (int, error) { - _spec := &sqlgraph.DeleteSpec{ - Node: &sqlgraph.NodeSpec{ - Table: deletedsubscription.Table, - ID: &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Column: deletedsubscription.FieldID, - }, - }, - } + _spec := sqlgraph.NewDeleteSpec(deletedsubscription.Table, sqlgraph.NewFieldSpec(deletedsubscription.FieldID, field.TypeInt)) if ps := dsd.mutation.predicates; len(ps) > 0 { _spec.Predicate = func(selector *sql.Selector) { for i := range ps { @@ -88,6 +52,7 @@ func (dsd *DeletedSubscriptionDelete) sqlExec(ctx context.Context) (int, error) if err != nil && sqlgraph.IsConstraintError(err) { err = &ConstraintError{msg: err.Error(), wrap: err} } + dsd.mutation.done = true return affected, err } @@ -96,6 +61,12 @@ type DeletedSubscriptionDeleteOne struct { dsd *DeletedSubscriptionDelete } +// Where appends a list predicates to the DeletedSubscriptionDelete builder. +func (dsdo *DeletedSubscriptionDeleteOne) Where(ps ...predicate.DeletedSubscription) *DeletedSubscriptionDeleteOne { + dsdo.dsd.mutation.Where(ps...) + return dsdo +} + // Exec executes the deletion query. func (dsdo *DeletedSubscriptionDeleteOne) Exec(ctx context.Context) error { n, err := dsdo.dsd.Exec(ctx) @@ -111,5 +82,7 @@ func (dsdo *DeletedSubscriptionDeleteOne) Exec(ctx context.Context) error { // ExecX is like Exec, but panics if an error occurs. func (dsdo *DeletedSubscriptionDeleteOne) ExecX(ctx context.Context) { - dsdo.dsd.ExecX(ctx) + if err := dsdo.Exec(ctx); err != nil { + panic(err) + } } diff --git a/internal/db/ent/deletedsubscription_query.go b/internal/db_impl/ent_db/internal/deletedsubscription_query.go similarity index 68% rename from internal/db/ent/deletedsubscription_query.go rename to internal/db_impl/ent_db/internal/deletedsubscription_query.go index 93472a0a..78107132 100644 --- a/internal/db/ent/deletedsubscription_query.go +++ b/internal/db_impl/ent_db/internal/deletedsubscription_query.go @@ -1,6 +1,6 @@ // Code generated by ent, DO NOT EDIT. -package ent +package internal import ( "context" @@ -10,18 +10,16 @@ import ( "entgo.io/ent/dialect/sql" "entgo.io/ent/dialect/sql/sqlgraph" "entgo.io/ent/schema/field" - "github.com/ProtonMail/gluon/internal/db/ent/deletedsubscription" - "github.com/ProtonMail/gluon/internal/db/ent/predicate" + "github.com/ProtonMail/gluon/internal/db_impl/ent_db/internal/deletedsubscription" + "github.com/ProtonMail/gluon/internal/db_impl/ent_db/internal/predicate" ) // DeletedSubscriptionQuery is the builder for querying DeletedSubscription entities. type DeletedSubscriptionQuery struct { config - limit *int - offset *int - unique *bool + ctx *QueryContext order []OrderFunc - fields []string + inters []Interceptor predicates []predicate.DeletedSubscription // intermediate query (i.e. traversal path). sql *sql.Selector @@ -34,26 +32,26 @@ func (dsq *DeletedSubscriptionQuery) Where(ps ...predicate.DeletedSubscription) return dsq } -// Limit adds a limit step to the query. +// Limit the number of records to be returned by this query. func (dsq *DeletedSubscriptionQuery) Limit(limit int) *DeletedSubscriptionQuery { - dsq.limit = &limit + dsq.ctx.Limit = &limit return dsq } -// Offset adds an offset step to the query. +// Offset to start from. func (dsq *DeletedSubscriptionQuery) Offset(offset int) *DeletedSubscriptionQuery { - dsq.offset = &offset + dsq.ctx.Offset = &offset return dsq } // Unique configures the query builder to filter duplicate records on query. // By default, unique is set to true, and can be disabled using this method. func (dsq *DeletedSubscriptionQuery) Unique(unique bool) *DeletedSubscriptionQuery { - dsq.unique = &unique + dsq.ctx.Unique = &unique return dsq } -// Order adds an order step to the query. +// Order specifies how the records should be ordered. func (dsq *DeletedSubscriptionQuery) Order(o ...OrderFunc) *DeletedSubscriptionQuery { dsq.order = append(dsq.order, o...) return dsq @@ -62,7 +60,7 @@ func (dsq *DeletedSubscriptionQuery) Order(o ...OrderFunc) *DeletedSubscriptionQ // First returns the first DeletedSubscription entity from the query. // Returns a *NotFoundError when no DeletedSubscription was found. func (dsq *DeletedSubscriptionQuery) First(ctx context.Context) (*DeletedSubscription, error) { - nodes, err := dsq.Limit(1).All(ctx) + nodes, err := dsq.Limit(1).All(setContextOp(ctx, dsq.ctx, "First")) if err != nil { return nil, err } @@ -85,7 +83,7 @@ func (dsq *DeletedSubscriptionQuery) FirstX(ctx context.Context) *DeletedSubscri // Returns a *NotFoundError when no DeletedSubscription ID was found. func (dsq *DeletedSubscriptionQuery) FirstID(ctx context.Context) (id int, err error) { var ids []int - if ids, err = dsq.Limit(1).IDs(ctx); err != nil { + if ids, err = dsq.Limit(1).IDs(setContextOp(ctx, dsq.ctx, "FirstID")); err != nil { return } if len(ids) == 0 { @@ -108,7 +106,7 @@ func (dsq *DeletedSubscriptionQuery) FirstIDX(ctx context.Context) int { // Returns a *NotSingularError when more than one DeletedSubscription entity is found. // Returns a *NotFoundError when no DeletedSubscription entities are found. func (dsq *DeletedSubscriptionQuery) Only(ctx context.Context) (*DeletedSubscription, error) { - nodes, err := dsq.Limit(2).All(ctx) + nodes, err := dsq.Limit(2).All(setContextOp(ctx, dsq.ctx, "Only")) if err != nil { return nil, err } @@ -136,7 +134,7 @@ func (dsq *DeletedSubscriptionQuery) OnlyX(ctx context.Context) *DeletedSubscrip // Returns a *NotFoundError when no entities are found. func (dsq *DeletedSubscriptionQuery) OnlyID(ctx context.Context) (id int, err error) { var ids []int - if ids, err = dsq.Limit(2).IDs(ctx); err != nil { + if ids, err = dsq.Limit(2).IDs(setContextOp(ctx, dsq.ctx, "OnlyID")); err != nil { return } switch len(ids) { @@ -161,10 +159,12 @@ func (dsq *DeletedSubscriptionQuery) OnlyIDX(ctx context.Context) int { // All executes the query and returns a list of DeletedSubscriptions. func (dsq *DeletedSubscriptionQuery) All(ctx context.Context) ([]*DeletedSubscription, error) { + ctx = setContextOp(ctx, dsq.ctx, "All") if err := dsq.prepareQuery(ctx); err != nil { return nil, err } - return dsq.sqlAll(ctx) + qr := querierAll[[]*DeletedSubscription, *DeletedSubscriptionQuery]() + return withInterceptors[[]*DeletedSubscription](ctx, dsq, qr, dsq.inters) } // AllX is like All, but panics if an error occurs. @@ -177,9 +177,12 @@ func (dsq *DeletedSubscriptionQuery) AllX(ctx context.Context) []*DeletedSubscri } // IDs executes the query and returns a list of DeletedSubscription IDs. -func (dsq *DeletedSubscriptionQuery) IDs(ctx context.Context) ([]int, error) { - var ids []int - if err := dsq.Select(deletedsubscription.FieldID).Scan(ctx, &ids); err != nil { +func (dsq *DeletedSubscriptionQuery) IDs(ctx context.Context) (ids []int, err error) { + if dsq.ctx.Unique == nil && dsq.path != nil { + dsq.Unique(true) + } + ctx = setContextOp(ctx, dsq.ctx, "IDs") + if err = dsq.Select(deletedsubscription.FieldID).Scan(ctx, &ids); err != nil { return nil, err } return ids, nil @@ -196,10 +199,11 @@ func (dsq *DeletedSubscriptionQuery) IDsX(ctx context.Context) []int { // Count returns the count of the given query. func (dsq *DeletedSubscriptionQuery) Count(ctx context.Context) (int, error) { + ctx = setContextOp(ctx, dsq.ctx, "Count") if err := dsq.prepareQuery(ctx); err != nil { return 0, err } - return dsq.sqlCount(ctx) + return withInterceptors[int](ctx, dsq, querierCount[*DeletedSubscriptionQuery](), dsq.inters) } // CountX is like Count, but panics if an error occurs. @@ -213,10 +217,15 @@ func (dsq *DeletedSubscriptionQuery) CountX(ctx context.Context) int { // Exist returns true if the query has elements in the graph. func (dsq *DeletedSubscriptionQuery) Exist(ctx context.Context) (bool, error) { - if err := dsq.prepareQuery(ctx); err != nil { - return false, err + ctx = setContextOp(ctx, dsq.ctx, "Exist") + switch _, err := dsq.FirstID(ctx); { + case IsNotFound(err): + return false, nil + case err != nil: + return false, fmt.Errorf("internal: check existence: %w", err) + default: + return true, nil } - return dsq.sqlExist(ctx) } // ExistX is like Exist, but panics if an error occurs. @@ -236,14 +245,13 @@ func (dsq *DeletedSubscriptionQuery) Clone() *DeletedSubscriptionQuery { } return &DeletedSubscriptionQuery{ config: dsq.config, - limit: dsq.limit, - offset: dsq.offset, + ctx: dsq.ctx.Clone(), order: append([]OrderFunc{}, dsq.order...), + inters: append([]Interceptor{}, dsq.inters...), predicates: append([]predicate.DeletedSubscription{}, dsq.predicates...), // clone intermediate query. - sql: dsq.sql.Clone(), - path: dsq.path, - unique: dsq.unique, + sql: dsq.sql.Clone(), + path: dsq.path, } } @@ -259,19 +267,14 @@ func (dsq *DeletedSubscriptionQuery) Clone() *DeletedSubscriptionQuery { // // client.DeletedSubscription.Query(). // GroupBy(deletedsubscription.FieldName). -// Aggregate(ent.Count()). +// Aggregate(internal.Count()). // Scan(ctx, &v) func (dsq *DeletedSubscriptionQuery) GroupBy(field string, fields ...string) *DeletedSubscriptionGroupBy { - grbuild := &DeletedSubscriptionGroupBy{config: dsq.config} - grbuild.fields = append([]string{field}, fields...) - grbuild.path = func(ctx context.Context) (prev *sql.Selector, err error) { - if err := dsq.prepareQuery(ctx); err != nil { - return nil, err - } - return dsq.sqlQuery(ctx), nil - } + dsq.ctx.Fields = append([]string{field}, fields...) + grbuild := &DeletedSubscriptionGroupBy{build: dsq} + grbuild.flds = &dsq.ctx.Fields grbuild.label = deletedsubscription.Label - grbuild.flds, grbuild.scan = &grbuild.fields, grbuild.Scan + grbuild.scan = grbuild.Scan return grbuild } @@ -288,17 +291,32 @@ func (dsq *DeletedSubscriptionQuery) GroupBy(field string, fields ...string) *De // Select(deletedsubscription.FieldName). // Scan(ctx, &v) func (dsq *DeletedSubscriptionQuery) Select(fields ...string) *DeletedSubscriptionSelect { - dsq.fields = append(dsq.fields, fields...) - selbuild := &DeletedSubscriptionSelect{DeletedSubscriptionQuery: dsq} - selbuild.label = deletedsubscription.Label - selbuild.flds, selbuild.scan = &dsq.fields, selbuild.Scan - return selbuild + dsq.ctx.Fields = append(dsq.ctx.Fields, fields...) + sbuild := &DeletedSubscriptionSelect{DeletedSubscriptionQuery: dsq} + sbuild.label = deletedsubscription.Label + sbuild.flds, sbuild.scan = &dsq.ctx.Fields, sbuild.Scan + return sbuild +} + +// Aggregate returns a DeletedSubscriptionSelect configured with the given aggregations. +func (dsq *DeletedSubscriptionQuery) Aggregate(fns ...AggregateFunc) *DeletedSubscriptionSelect { + return dsq.Select().Aggregate(fns...) } func (dsq *DeletedSubscriptionQuery) prepareQuery(ctx context.Context) error { - for _, f := range dsq.fields { + for _, inter := range dsq.inters { + if inter == nil { + return fmt.Errorf("internal: uninitialized interceptor (forgotten import internal/runtime?)") + } + if trv, ok := inter.(Traverser); ok { + if err := trv.Traverse(ctx, dsq); err != nil { + return err + } + } + } + for _, f := range dsq.ctx.Fields { if !deletedsubscription.ValidColumn(f) { - return &ValidationError{Name: f, err: fmt.Errorf("ent: invalid field %q for query", f)} + return &ValidationError{Name: f, err: fmt.Errorf("internal: invalid field %q for query", f)} } } if dsq.path != nil { @@ -316,10 +334,10 @@ func (dsq *DeletedSubscriptionQuery) sqlAll(ctx context.Context, hooks ...queryH nodes = []*DeletedSubscription{} _spec = dsq.querySpec() ) - _spec.ScanValues = func(columns []string) ([]interface{}, error) { + _spec.ScanValues = func(columns []string) ([]any, error) { return (*DeletedSubscription).scanValues(nil, columns) } - _spec.Assign = func(columns []string, values []interface{}) error { + _spec.Assign = func(columns []string, values []any) error { node := &DeletedSubscription{config: dsq.config} nodes = append(nodes, node) return node.assignValues(columns, values) @@ -338,38 +356,22 @@ func (dsq *DeletedSubscriptionQuery) sqlAll(ctx context.Context, hooks ...queryH func (dsq *DeletedSubscriptionQuery) sqlCount(ctx context.Context) (int, error) { _spec := dsq.querySpec() - _spec.Node.Columns = dsq.fields - if len(dsq.fields) > 0 { - _spec.Unique = dsq.unique != nil && *dsq.unique + _spec.Node.Columns = dsq.ctx.Fields + if len(dsq.ctx.Fields) > 0 { + _spec.Unique = dsq.ctx.Unique != nil && *dsq.ctx.Unique } return sqlgraph.CountNodes(ctx, dsq.driver, _spec) } -func (dsq *DeletedSubscriptionQuery) sqlExist(ctx context.Context) (bool, error) { - n, err := dsq.sqlCount(ctx) - if err != nil { - return false, fmt.Errorf("ent: check existence: %w", err) - } - return n > 0, nil -} - func (dsq *DeletedSubscriptionQuery) querySpec() *sqlgraph.QuerySpec { - _spec := &sqlgraph.QuerySpec{ - Node: &sqlgraph.NodeSpec{ - Table: deletedsubscription.Table, - Columns: deletedsubscription.Columns, - ID: &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Column: deletedsubscription.FieldID, - }, - }, - From: dsq.sql, - Unique: true, - } - if unique := dsq.unique; unique != nil { + _spec := sqlgraph.NewQuerySpec(deletedsubscription.Table, deletedsubscription.Columns, sqlgraph.NewFieldSpec(deletedsubscription.FieldID, field.TypeInt)) + _spec.From = dsq.sql + if unique := dsq.ctx.Unique; unique != nil { _spec.Unique = *unique + } else if dsq.path != nil { + _spec.Unique = true } - if fields := dsq.fields; len(fields) > 0 { + if fields := dsq.ctx.Fields; len(fields) > 0 { _spec.Node.Columns = make([]string, 0, len(fields)) _spec.Node.Columns = append(_spec.Node.Columns, deletedsubscription.FieldID) for i := range fields { @@ -385,10 +387,10 @@ func (dsq *DeletedSubscriptionQuery) querySpec() *sqlgraph.QuerySpec { } } } - if limit := dsq.limit; limit != nil { + if limit := dsq.ctx.Limit; limit != nil { _spec.Limit = *limit } - if offset := dsq.offset; offset != nil { + if offset := dsq.ctx.Offset; offset != nil { _spec.Offset = *offset } if ps := dsq.order; len(ps) > 0 { @@ -404,7 +406,7 @@ func (dsq *DeletedSubscriptionQuery) querySpec() *sqlgraph.QuerySpec { func (dsq *DeletedSubscriptionQuery) sqlQuery(ctx context.Context) *sql.Selector { builder := sql.Dialect(dsq.driver.Dialect()) t1 := builder.Table(deletedsubscription.Table) - columns := dsq.fields + columns := dsq.ctx.Fields if len(columns) == 0 { columns = deletedsubscription.Columns } @@ -413,7 +415,7 @@ func (dsq *DeletedSubscriptionQuery) sqlQuery(ctx context.Context) *sql.Selector selector = dsq.sql selector.Select(selector.Columns(columns...)...) } - if dsq.unique != nil && *dsq.unique { + if dsq.ctx.Unique != nil && *dsq.ctx.Unique { selector.Distinct() } for _, p := range dsq.predicates { @@ -422,12 +424,12 @@ func (dsq *DeletedSubscriptionQuery) sqlQuery(ctx context.Context) *sql.Selector for _, p := range dsq.order { p(selector) } - if offset := dsq.offset; offset != nil { + if offset := dsq.ctx.Offset; offset != nil { // limit is mandatory for offset clause. We start // with default value, and override it below if needed. selector.Offset(*offset).Limit(math.MaxInt32) } - if limit := dsq.limit; limit != nil { + if limit := dsq.ctx.Limit; limit != nil { selector.Limit(*limit) } return selector @@ -435,13 +437,8 @@ func (dsq *DeletedSubscriptionQuery) sqlQuery(ctx context.Context) *sql.Selector // DeletedSubscriptionGroupBy is the group-by builder for DeletedSubscription entities. type DeletedSubscriptionGroupBy struct { - config selector - fields []string - fns []AggregateFunc - // intermediate query (i.e. traversal path). - sql *sql.Selector - path func(context.Context) (*sql.Selector, error) + build *DeletedSubscriptionQuery } // Aggregate adds the given aggregation functions to the group-by query. @@ -450,74 +447,77 @@ func (dsgb *DeletedSubscriptionGroupBy) Aggregate(fns ...AggregateFunc) *Deleted return dsgb } -// Scan applies the group-by query and scans the result into the given value. -func (dsgb *DeletedSubscriptionGroupBy) Scan(ctx context.Context, v interface{}) error { - query, err := dsgb.path(ctx) - if err != nil { +// Scan applies the selector query and scans the result into the given value. +func (dsgb *DeletedSubscriptionGroupBy) Scan(ctx context.Context, v any) error { + ctx = setContextOp(ctx, dsgb.build.ctx, "GroupBy") + if err := dsgb.build.prepareQuery(ctx); err != nil { return err } - dsgb.sql = query - return dsgb.sqlScan(ctx, v) + return scanWithInterceptors[*DeletedSubscriptionQuery, *DeletedSubscriptionGroupBy](ctx, dsgb.build, dsgb, dsgb.build.inters, v) } -func (dsgb *DeletedSubscriptionGroupBy) sqlScan(ctx context.Context, v interface{}) error { - for _, f := range dsgb.fields { - if !deletedsubscription.ValidColumn(f) { - return &ValidationError{Name: f, err: fmt.Errorf("invalid field %q for group-by", f)} +func (dsgb *DeletedSubscriptionGroupBy) sqlScan(ctx context.Context, root *DeletedSubscriptionQuery, v any) error { + selector := root.sqlQuery(ctx).Select() + aggregation := make([]string, 0, len(dsgb.fns)) + for _, fn := range dsgb.fns { + aggregation = append(aggregation, fn(selector)) + } + if len(selector.SelectedColumns()) == 0 { + columns := make([]string, 0, len(*dsgb.flds)+len(dsgb.fns)) + for _, f := range *dsgb.flds { + columns = append(columns, selector.C(f)) } + columns = append(columns, aggregation...) + selector.Select(columns...) } - selector := dsgb.sqlQuery() + selector.GroupBy(selector.Columns(*dsgb.flds...)...) if err := selector.Err(); err != nil { return err } rows := &sql.Rows{} query, args := selector.Query() - if err := dsgb.driver.Query(ctx, query, args, rows); err != nil { + if err := dsgb.build.driver.Query(ctx, query, args, rows); err != nil { return err } defer rows.Close() return sql.ScanSlice(rows, v) } -func (dsgb *DeletedSubscriptionGroupBy) sqlQuery() *sql.Selector { - selector := dsgb.sql.Select() - aggregation := make([]string, 0, len(dsgb.fns)) - for _, fn := range dsgb.fns { - aggregation = append(aggregation, fn(selector)) - } - // If no columns were selected in a custom aggregation function, the default - // selection is the fields used for "group-by", and the aggregation functions. - if len(selector.SelectedColumns()) == 0 { - columns := make([]string, 0, len(dsgb.fields)+len(dsgb.fns)) - for _, f := range dsgb.fields { - columns = append(columns, selector.C(f)) - } - columns = append(columns, aggregation...) - selector.Select(columns...) - } - return selector.GroupBy(selector.Columns(dsgb.fields...)...) -} - // DeletedSubscriptionSelect is the builder for selecting fields of DeletedSubscription entities. type DeletedSubscriptionSelect struct { *DeletedSubscriptionQuery selector - // intermediate query (i.e. traversal path). - sql *sql.Selector +} + +// Aggregate adds the given aggregation functions to the selector query. +func (dss *DeletedSubscriptionSelect) Aggregate(fns ...AggregateFunc) *DeletedSubscriptionSelect { + dss.fns = append(dss.fns, fns...) + return dss } // Scan applies the selector query and scans the result into the given value. -func (dss *DeletedSubscriptionSelect) Scan(ctx context.Context, v interface{}) error { +func (dss *DeletedSubscriptionSelect) Scan(ctx context.Context, v any) error { + ctx = setContextOp(ctx, dss.ctx, "Select") if err := dss.prepareQuery(ctx); err != nil { return err } - dss.sql = dss.DeletedSubscriptionQuery.sqlQuery(ctx) - return dss.sqlScan(ctx, v) + return scanWithInterceptors[*DeletedSubscriptionQuery, *DeletedSubscriptionSelect](ctx, dss.DeletedSubscriptionQuery, dss, dss.inters, v) } -func (dss *DeletedSubscriptionSelect) sqlScan(ctx context.Context, v interface{}) error { +func (dss *DeletedSubscriptionSelect) sqlScan(ctx context.Context, root *DeletedSubscriptionQuery, v any) error { + selector := root.sqlQuery(ctx) + aggregation := make([]string, 0, len(dss.fns)) + for _, fn := range dss.fns { + aggregation = append(aggregation, fn(selector)) + } + switch n := len(*dss.selector.flds); { + case n == 0 && len(aggregation) > 0: + selector.Select(aggregation...) + case n != 0 && len(aggregation) > 0: + selector.AppendSelect(aggregation...) + } rows := &sql.Rows{} - query, args := dss.sql.Query() + query, args := selector.Query() if err := dss.driver.Query(ctx, query, args, rows); err != nil { return err } diff --git a/internal/db/ent/deletedsubscription_update.go b/internal/db_impl/ent_db/internal/deletedsubscription_update.go similarity index 64% rename from internal/db/ent/deletedsubscription_update.go rename to internal/db_impl/ent_db/internal/deletedsubscription_update.go index f28054c5..4ec23db6 100644 --- a/internal/db/ent/deletedsubscription_update.go +++ b/internal/db_impl/ent_db/internal/deletedsubscription_update.go @@ -1,6 +1,6 @@ // Code generated by ent, DO NOT EDIT. -package ent +package internal import ( "context" @@ -11,8 +11,8 @@ import ( "entgo.io/ent/dialect/sql/sqlgraph" "entgo.io/ent/schema/field" "github.com/ProtonMail/gluon/imap" - "github.com/ProtonMail/gluon/internal/db/ent/deletedsubscription" - "github.com/ProtonMail/gluon/internal/db/ent/predicate" + "github.com/ProtonMail/gluon/internal/db_impl/ent_db/internal/deletedsubscription" + "github.com/ProtonMail/gluon/internal/db_impl/ent_db/internal/predicate" ) // DeletedSubscriptionUpdate is the builder for updating DeletedSubscription entities. @@ -47,34 +47,7 @@ func (dsu *DeletedSubscriptionUpdate) Mutation() *DeletedSubscriptionMutation { // Save executes the query and returns the number of nodes affected by the update operation. func (dsu *DeletedSubscriptionUpdate) Save(ctx context.Context) (int, error) { - var ( - err error - affected int - ) - if len(dsu.hooks) == 0 { - affected, err = dsu.sqlSave(ctx) - } else { - var mut Mutator = MutateFunc(func(ctx context.Context, m Mutation) (Value, error) { - mutation, ok := m.(*DeletedSubscriptionMutation) - if !ok { - return nil, fmt.Errorf("unexpected mutation type %T", m) - } - dsu.mutation = mutation - affected, err = dsu.sqlSave(ctx) - mutation.done = true - return affected, err - }) - for i := len(dsu.hooks) - 1; i >= 0; i-- { - if dsu.hooks[i] == nil { - return 0, fmt.Errorf("ent: uninitialized hook (forgotten import ent/runtime?)") - } - mut = dsu.hooks[i](mut) - } - if _, err := mut.Mutate(ctx, dsu.mutation); err != nil { - return 0, err - } - } - return affected, err + return withHooks[int, DeletedSubscriptionMutation](ctx, dsu.sqlSave, dsu.mutation, dsu.hooks) } // SaveX is like Save, but panics if an error occurs. @@ -100,16 +73,7 @@ func (dsu *DeletedSubscriptionUpdate) ExecX(ctx context.Context) { } func (dsu *DeletedSubscriptionUpdate) sqlSave(ctx context.Context) (n int, err error) { - _spec := &sqlgraph.UpdateSpec{ - Node: &sqlgraph.NodeSpec{ - Table: deletedsubscription.Table, - Columns: deletedsubscription.Columns, - ID: &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Column: deletedsubscription.FieldID, - }, - }, - } + _spec := sqlgraph.NewUpdateSpec(deletedsubscription.Table, deletedsubscription.Columns, sqlgraph.NewFieldSpec(deletedsubscription.FieldID, field.TypeInt)) if ps := dsu.mutation.predicates; len(ps) > 0 { _spec.Predicate = func(selector *sql.Selector) { for i := range ps { @@ -118,18 +82,10 @@ func (dsu *DeletedSubscriptionUpdate) sqlSave(ctx context.Context) (n int, err e } } if value, ok := dsu.mutation.Name(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Value: value, - Column: deletedsubscription.FieldName, - }) + _spec.SetField(deletedsubscription.FieldName, field.TypeString, value) } if value, ok := dsu.mutation.RemoteID(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Value: value, - Column: deletedsubscription.FieldRemoteID, - }) + _spec.SetField(deletedsubscription.FieldRemoteID, field.TypeString, value) } if n, err = sqlgraph.UpdateNodes(ctx, dsu.driver, _spec); err != nil { if _, ok := err.(*sqlgraph.NotFoundError); ok { @@ -139,6 +95,7 @@ func (dsu *DeletedSubscriptionUpdate) sqlSave(ctx context.Context) (n int, err e } return 0, err } + dsu.mutation.done = true return n, nil } @@ -167,6 +124,12 @@ func (dsuo *DeletedSubscriptionUpdateOne) Mutation() *DeletedSubscriptionMutatio return dsuo.mutation } +// Where appends a list predicates to the DeletedSubscriptionUpdate builder. +func (dsuo *DeletedSubscriptionUpdateOne) Where(ps ...predicate.DeletedSubscription) *DeletedSubscriptionUpdateOne { + dsuo.mutation.Where(ps...) + return dsuo +} + // Select allows selecting one or more fields (columns) of the returned entity. // The default is selecting all fields defined in the entity schema. func (dsuo *DeletedSubscriptionUpdateOne) Select(field string, fields ...string) *DeletedSubscriptionUpdateOne { @@ -176,40 +139,7 @@ func (dsuo *DeletedSubscriptionUpdateOne) Select(field string, fields ...string) // Save executes the query and returns the updated DeletedSubscription entity. func (dsuo *DeletedSubscriptionUpdateOne) Save(ctx context.Context) (*DeletedSubscription, error) { - var ( - err error - node *DeletedSubscription - ) - if len(dsuo.hooks) == 0 { - node, err = dsuo.sqlSave(ctx) - } else { - var mut Mutator = MutateFunc(func(ctx context.Context, m Mutation) (Value, error) { - mutation, ok := m.(*DeletedSubscriptionMutation) - if !ok { - return nil, fmt.Errorf("unexpected mutation type %T", m) - } - dsuo.mutation = mutation - node, err = dsuo.sqlSave(ctx) - mutation.done = true - return node, err - }) - for i := len(dsuo.hooks) - 1; i >= 0; i-- { - if dsuo.hooks[i] == nil { - return nil, fmt.Errorf("ent: uninitialized hook (forgotten import ent/runtime?)") - } - mut = dsuo.hooks[i](mut) - } - v, err := mut.Mutate(ctx, dsuo.mutation) - if err != nil { - return nil, err - } - nv, ok := v.(*DeletedSubscription) - if !ok { - return nil, fmt.Errorf("unexpected node type %T returned from DeletedSubscriptionMutation", v) - } - node = nv - } - return node, err + return withHooks[*DeletedSubscription, DeletedSubscriptionMutation](ctx, dsuo.sqlSave, dsuo.mutation, dsuo.hooks) } // SaveX is like Save, but panics if an error occurs. @@ -235,19 +165,10 @@ func (dsuo *DeletedSubscriptionUpdateOne) ExecX(ctx context.Context) { } func (dsuo *DeletedSubscriptionUpdateOne) sqlSave(ctx context.Context) (_node *DeletedSubscription, err error) { - _spec := &sqlgraph.UpdateSpec{ - Node: &sqlgraph.NodeSpec{ - Table: deletedsubscription.Table, - Columns: deletedsubscription.Columns, - ID: &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Column: deletedsubscription.FieldID, - }, - }, - } + _spec := sqlgraph.NewUpdateSpec(deletedsubscription.Table, deletedsubscription.Columns, sqlgraph.NewFieldSpec(deletedsubscription.FieldID, field.TypeInt)) id, ok := dsuo.mutation.ID() if !ok { - return nil, &ValidationError{Name: "id", err: errors.New(`ent: missing "DeletedSubscription.id" for update`)} + return nil, &ValidationError{Name: "id", err: errors.New(`internal: missing "DeletedSubscription.id" for update`)} } _spec.Node.ID.Value = id if fields := dsuo.fields; len(fields) > 0 { @@ -255,7 +176,7 @@ func (dsuo *DeletedSubscriptionUpdateOne) sqlSave(ctx context.Context) (_node *D _spec.Node.Columns = append(_spec.Node.Columns, deletedsubscription.FieldID) for _, f := range fields { if !deletedsubscription.ValidColumn(f) { - return nil, &ValidationError{Name: f, err: fmt.Errorf("ent: invalid field %q for query", f)} + return nil, &ValidationError{Name: f, err: fmt.Errorf("internal: invalid field %q for query", f)} } if f != deletedsubscription.FieldID { _spec.Node.Columns = append(_spec.Node.Columns, f) @@ -270,18 +191,10 @@ func (dsuo *DeletedSubscriptionUpdateOne) sqlSave(ctx context.Context) (_node *D } } if value, ok := dsuo.mutation.Name(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Value: value, - Column: deletedsubscription.FieldName, - }) + _spec.SetField(deletedsubscription.FieldName, field.TypeString, value) } if value, ok := dsuo.mutation.RemoteID(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Value: value, - Column: deletedsubscription.FieldRemoteID, - }) + _spec.SetField(deletedsubscription.FieldRemoteID, field.TypeString, value) } _node = &DeletedSubscription{config: dsuo.config} _spec.Assign = _node.assignValues @@ -294,5 +207,6 @@ func (dsuo *DeletedSubscriptionUpdateOne) sqlSave(ctx context.Context) (_node *D } return nil, err } + dsuo.mutation.done = true return _node, nil } diff --git a/internal/db/ent/ent.go b/internal/db_impl/ent_db/internal/ent.go similarity index 63% rename from internal/db/ent/ent.go rename to internal/db_impl/ent_db/internal/ent.go index b5c03992..c63f197f 100644 --- a/internal/db/ent/ent.go +++ b/internal/db_impl/ent_db/internal/ent.go @@ -1,35 +1,43 @@ // Code generated by ent, DO NOT EDIT. -package ent +package internal import ( "context" "errors" "fmt" + "reflect" "entgo.io/ent" "entgo.io/ent/dialect/sql" "entgo.io/ent/dialect/sql/sqlgraph" - "github.com/ProtonMail/gluon/internal/db/ent/deletedsubscription" - "github.com/ProtonMail/gluon/internal/db/ent/mailbox" - "github.com/ProtonMail/gluon/internal/db/ent/mailboxattr" - "github.com/ProtonMail/gluon/internal/db/ent/mailboxflag" - "github.com/ProtonMail/gluon/internal/db/ent/mailboxpermflag" - "github.com/ProtonMail/gluon/internal/db/ent/message" - "github.com/ProtonMail/gluon/internal/db/ent/messageflag" - "github.com/ProtonMail/gluon/internal/db/ent/uid" + "github.com/ProtonMail/gluon/internal/db_impl/ent_db/internal/deletedsubscription" + "github.com/ProtonMail/gluon/internal/db_impl/ent_db/internal/mailbox" + "github.com/ProtonMail/gluon/internal/db_impl/ent_db/internal/mailboxattr" + "github.com/ProtonMail/gluon/internal/db_impl/ent_db/internal/mailboxflag" + "github.com/ProtonMail/gluon/internal/db_impl/ent_db/internal/mailboxpermflag" + "github.com/ProtonMail/gluon/internal/db_impl/ent_db/internal/message" + "github.com/ProtonMail/gluon/internal/db_impl/ent_db/internal/messageflag" + "github.com/ProtonMail/gluon/internal/db_impl/ent_db/internal/uid" ) // ent aliases to avoid import conflicts in user's code. type ( - Op = ent.Op - Hook = ent.Hook - Value = ent.Value - Query = ent.Query - Policy = ent.Policy - Mutator = ent.Mutator - Mutation = ent.Mutation - MutateFunc = ent.MutateFunc + Op = ent.Op + Hook = ent.Hook + Value = ent.Value + Query = ent.Query + QueryContext = ent.QueryContext + Querier = ent.Querier + QuerierFunc = ent.QuerierFunc + Interceptor = ent.Interceptor + InterceptFunc = ent.InterceptFunc + Traverser = ent.Traverser + TraverseFunc = ent.TraverseFunc + Policy = ent.Policy + Mutator = ent.Mutator + Mutation = ent.Mutation + MutateFunc = ent.MutateFunc ) // OrderFunc applies an ordering on the sql selector. @@ -67,7 +75,7 @@ func Asc(fields ...string) OrderFunc { check := columnChecker(s.TableName()) for _, f := range fields { if err := check(f); err != nil { - s.AddError(&ValidationError{Name: f, err: fmt.Errorf("ent: %w", err)}) + s.AddError(&ValidationError{Name: f, err: fmt.Errorf("internal: %w", err)}) } s.OrderBy(sql.Asc(s.C(f))) } @@ -80,7 +88,7 @@ func Desc(fields ...string) OrderFunc { check := columnChecker(s.TableName()) for _, f := range fields { if err := check(f); err != nil { - s.AddError(&ValidationError{Name: f, err: fmt.Errorf("ent: %w", err)}) + s.AddError(&ValidationError{Name: f, err: fmt.Errorf("internal: %w", err)}) } s.OrderBy(sql.Desc(s.C(f))) } @@ -93,7 +101,7 @@ type AggregateFunc func(*sql.Selector) string // As is a pseudo aggregation function for renaming another other functions with custom names. For example: // // GroupBy(field1, field2). -// Aggregate(ent.As(ent.Sum(field1), "sum_field1"), (ent.As(ent.Sum(field2), "sum_field2")). +// Aggregate(internal.As(internal.Sum(field1), "sum_field1"), (internal.As(internal.Sum(field2), "sum_field2")). // Scan(ctx, &v) func As(fn AggregateFunc, end string) AggregateFunc { return func(s *sql.Selector) string { @@ -113,7 +121,7 @@ func Max(field string) AggregateFunc { return func(s *sql.Selector) string { check := columnChecker(s.TableName()) if err := check(field); err != nil { - s.AddError(&ValidationError{Name: field, err: fmt.Errorf("ent: %w", err)}) + s.AddError(&ValidationError{Name: field, err: fmt.Errorf("internal: %w", err)}) return "" } return sql.Max(s.C(field)) @@ -125,7 +133,7 @@ func Mean(field string) AggregateFunc { return func(s *sql.Selector) string { check := columnChecker(s.TableName()) if err := check(field); err != nil { - s.AddError(&ValidationError{Name: field, err: fmt.Errorf("ent: %w", err)}) + s.AddError(&ValidationError{Name: field, err: fmt.Errorf("internal: %w", err)}) return "" } return sql.Avg(s.C(field)) @@ -137,7 +145,7 @@ func Min(field string) AggregateFunc { return func(s *sql.Selector) string { check := columnChecker(s.TableName()) if err := check(field); err != nil { - s.AddError(&ValidationError{Name: field, err: fmt.Errorf("ent: %w", err)}) + s.AddError(&ValidationError{Name: field, err: fmt.Errorf("internal: %w", err)}) return "" } return sql.Min(s.C(field)) @@ -149,7 +157,7 @@ func Sum(field string) AggregateFunc { return func(s *sql.Selector) string { check := columnChecker(s.TableName()) if err := check(field); err != nil { - s.AddError(&ValidationError{Name: field, err: fmt.Errorf("ent: %w", err)}) + s.AddError(&ValidationError{Name: field, err: fmt.Errorf("internal: %w", err)}) return "" } return sql.Sum(s.C(field)) @@ -188,7 +196,7 @@ type NotFoundError struct { // Error implements the error interface. func (e *NotFoundError) Error() string { - return "ent: " + e.label + " not found" + return "internal: " + e.label + " not found" } // IsNotFound returns a boolean indicating whether the error is a not found error. @@ -215,7 +223,7 @@ type NotSingularError struct { // Error implements the error interface. func (e *NotSingularError) Error() string { - return "ent: " + e.label + " not singular" + return "internal: " + e.label + " not singular" } // IsNotSingular returns a boolean indicating whether the error is a not singular error. @@ -234,7 +242,7 @@ type NotLoadedError struct { // Error implements the error interface. func (e *NotLoadedError) Error() string { - return "ent: " + e.edge + " edge was not loaded" + return "internal: " + e.edge + " edge was not loaded" } // IsNotLoaded returns a boolean indicating whether the error is a not loaded error. @@ -256,7 +264,7 @@ type ConstraintError struct { // Error implements the error interface. func (e ConstraintError) Error() string { - return "ent: constraint failed: " + e.msg + return "internal: constraint failed: " + e.msg } // Unwrap implements the errors.Wrapper interface. @@ -277,11 +285,12 @@ func IsConstraintError(err error) bool { type selector struct { label string flds *[]string - scan func(context.Context, interface{}) error + fns []AggregateFunc + scan func(context.Context, any) error } // ScanX is like Scan, but panics if an error occurs. -func (s *selector) ScanX(ctx context.Context, v interface{}) { +func (s *selector) ScanX(ctx context.Context, v any) { if err := s.scan(ctx, v); err != nil { panic(err) } @@ -290,7 +299,7 @@ func (s *selector) ScanX(ctx context.Context, v interface{}) { // Strings returns list of strings from a selector. It is only allowed when selecting one field. func (s *selector) Strings(ctx context.Context) ([]string, error) { if len(*s.flds) > 1 { - return nil, errors.New("ent: Strings is not achievable when selecting more than 1 field") + return nil, errors.New("internal: Strings is not achievable when selecting more than 1 field") } var v []string if err := s.scan(ctx, &v); err != nil { @@ -320,7 +329,7 @@ func (s *selector) String(ctx context.Context) (_ string, err error) { case 0: err = &NotFoundError{s.label} default: - err = fmt.Errorf("ent: Strings returned %d results when one was expected", len(v)) + err = fmt.Errorf("internal: Strings returned %d results when one was expected", len(v)) } return } @@ -337,7 +346,7 @@ func (s *selector) StringX(ctx context.Context) string { // Ints returns list of ints from a selector. It is only allowed when selecting one field. func (s *selector) Ints(ctx context.Context) ([]int, error) { if len(*s.flds) > 1 { - return nil, errors.New("ent: Ints is not achievable when selecting more than 1 field") + return nil, errors.New("internal: Ints is not achievable when selecting more than 1 field") } var v []int if err := s.scan(ctx, &v); err != nil { @@ -367,7 +376,7 @@ func (s *selector) Int(ctx context.Context) (_ int, err error) { case 0: err = &NotFoundError{s.label} default: - err = fmt.Errorf("ent: Ints returned %d results when one was expected", len(v)) + err = fmt.Errorf("internal: Ints returned %d results when one was expected", len(v)) } return } @@ -384,7 +393,7 @@ func (s *selector) IntX(ctx context.Context) int { // Float64s returns list of float64s from a selector. It is only allowed when selecting one field. func (s *selector) Float64s(ctx context.Context) ([]float64, error) { if len(*s.flds) > 1 { - return nil, errors.New("ent: Float64s is not achievable when selecting more than 1 field") + return nil, errors.New("internal: Float64s is not achievable when selecting more than 1 field") } var v []float64 if err := s.scan(ctx, &v); err != nil { @@ -414,7 +423,7 @@ func (s *selector) Float64(ctx context.Context) (_ float64, err error) { case 0: err = &NotFoundError{s.label} default: - err = fmt.Errorf("ent: Float64s returned %d results when one was expected", len(v)) + err = fmt.Errorf("internal: Float64s returned %d results when one was expected", len(v)) } return } @@ -431,7 +440,7 @@ func (s *selector) Float64X(ctx context.Context) float64 { // Bools returns list of bools from a selector. It is only allowed when selecting one field. func (s *selector) Bools(ctx context.Context) ([]bool, error) { if len(*s.flds) > 1 { - return nil, errors.New("ent: Bools is not achievable when selecting more than 1 field") + return nil, errors.New("internal: Bools is not achievable when selecting more than 1 field") } var v []bool if err := s.scan(ctx, &v); err != nil { @@ -461,7 +470,7 @@ func (s *selector) Bool(ctx context.Context) (_ bool, err error) { case 0: err = &NotFoundError{s.label} default: - err = fmt.Errorf("ent: Bools returned %d results when one was expected", len(v)) + err = fmt.Errorf("internal: Bools returned %d results when one was expected", len(v)) } return } @@ -475,5 +484,121 @@ func (s *selector) BoolX(ctx context.Context) bool { return v } +// withHooks invokes the builder operation with the given hooks, if any. +func withHooks[V Value, M any, PM interface { + *M + Mutation +}](ctx context.Context, exec func(context.Context) (V, error), mutation PM, hooks []Hook) (value V, err error) { + if len(hooks) == 0 { + return exec(ctx) + } + var mut Mutator = MutateFunc(func(ctx context.Context, m Mutation) (Value, error) { + mutationT, ok := m.(PM) + if !ok { + return nil, fmt.Errorf("unexpected mutation type %T", m) + } + // Set the mutation to the builder. + *mutation = *mutationT + return exec(ctx) + }) + for i := len(hooks) - 1; i >= 0; i-- { + if hooks[i] == nil { + return value, fmt.Errorf("ent: uninitialized hook (forgotten import ent/runtime?)") + } + mut = hooks[i](mut) + } + v, err := mut.Mutate(ctx, mutation) + if err != nil { + return value, err + } + nv, ok := v.(V) + if !ok { + return value, fmt.Errorf("unexpected node type %T returned from %T", v, mutation) + } + return nv, nil +} + +// setContextOp returns a new context with the given QueryContext attached (including its op) in case it does not exist. +func setContextOp(ctx context.Context, qc *QueryContext, op string) context.Context { + if ent.QueryFromContext(ctx) == nil { + qc.Op = op + ctx = ent.NewQueryContext(ctx, qc) + } + return ctx +} + +func querierAll[V Value, Q interface { + sqlAll(context.Context, ...queryHook) (V, error) +}]() Querier { + return QuerierFunc(func(ctx context.Context, q Query) (Value, error) { + query, ok := q.(Q) + if !ok { + return nil, fmt.Errorf("unexpected query type %T", q) + } + return query.sqlAll(ctx) + }) +} + +func querierCount[Q interface { + sqlCount(context.Context) (int, error) +}]() Querier { + return QuerierFunc(func(ctx context.Context, q Query) (Value, error) { + query, ok := q.(Q) + if !ok { + return nil, fmt.Errorf("unexpected query type %T", q) + } + return query.sqlCount(ctx) + }) +} + +func withInterceptors[V Value](ctx context.Context, q Query, qr Querier, inters []Interceptor) (v V, err error) { + for i := len(inters) - 1; i >= 0; i-- { + qr = inters[i].Intercept(qr) + } + rv, err := qr.Query(ctx, q) + if err != nil { + return v, err + } + vt, ok := rv.(V) + if !ok { + return v, fmt.Errorf("unexpected type %T returned from %T. expected type: %T", vt, q, v) + } + return vt, nil +} + +func scanWithInterceptors[Q1 ent.Query, Q2 interface { + sqlScan(context.Context, Q1, any) error +}](ctx context.Context, rootQuery Q1, selectOrGroup Q2, inters []Interceptor, v any) error { + rv := reflect.ValueOf(v) + var qr Querier = QuerierFunc(func(ctx context.Context, q Query) (Value, error) { + query, ok := q.(Q1) + if !ok { + return nil, fmt.Errorf("unexpected query type %T", q) + } + if err := selectOrGroup.sqlScan(ctx, query, v); err != nil { + return nil, err + } + if k := rv.Kind(); k == reflect.Pointer && rv.Elem().CanInterface() { + return rv.Elem().Interface(), nil + } + return v, nil + }) + for i := len(inters) - 1; i >= 0; i-- { + qr = inters[i].Intercept(qr) + } + vv, err := qr.Query(ctx, rootQuery) + if err != nil { + return err + } + switch rv2 := reflect.ValueOf(vv); { + case rv.IsNil(), rv2.IsNil(), rv.Kind() != reflect.Pointer: + case rv.Type() == rv2.Type(): + rv.Elem().Set(rv2.Elem()) + case rv.Elem().Type() == rv2.Type(): + rv.Elem().Set(rv2) + } + return nil +} + // queryHook describes an internal hook for the different sqlAll methods. type queryHook func(context.Context, *sqlgraph.QuerySpec) diff --git a/internal/db/ent/enttest/enttest.go b/internal/db_impl/ent_db/internal/enttest/enttest.go similarity index 65% rename from internal/db/ent/enttest/enttest.go rename to internal/db_impl/ent_db/internal/enttest/enttest.go index 3dcd985b..0b65a31b 100644 --- a/internal/db/ent/enttest/enttest.go +++ b/internal/db_impl/ent_db/internal/enttest/enttest.go @@ -5,12 +5,12 @@ package enttest import ( "context" - "github.com/ProtonMail/gluon/internal/db/ent" + "github.com/ProtonMail/gluon/internal/db_impl/ent_db/internal" // required by schema hooks. - _ "github.com/ProtonMail/gluon/internal/db/ent/runtime" + _ "github.com/ProtonMail/gluon/internal/db_impl/ent_db/internal/runtime" "entgo.io/ent/dialect/sql/schema" - "github.com/ProtonMail/gluon/internal/db/ent/migrate" + "github.com/ProtonMail/gluon/internal/db_impl/ent_db/internal/migrate" ) type ( @@ -18,20 +18,20 @@ type ( // testing.T and testing.B and used by enttest. TestingT interface { FailNow() - Error(...interface{}) + Error(...any) } // Option configures client creation. Option func(*options) options struct { - opts []ent.Option + opts []internal.Option migrateOpts []schema.MigrateOption } ) // WithOptions forwards options to client creation. -func WithOptions(opts ...ent.Option) Option { +func WithOptions(opts ...internal.Option) Option { return func(o *options) { o.opts = append(o.opts, opts...) } @@ -52,10 +52,10 @@ func newOptions(opts []Option) *options { return o } -// Open calls ent.Open and auto-run migration. -func Open(t TestingT, driverName, dataSourceName string, opts ...Option) *ent.Client { +// Open calls internal.Open and auto-run migration. +func Open(t TestingT, driverName, dataSourceName string, opts ...Option) *internal.Client { o := newOptions(opts) - c, err := ent.Open(driverName, dataSourceName, o.opts...) + c, err := internal.Open(driverName, dataSourceName, o.opts...) if err != nil { t.Error(err) t.FailNow() @@ -64,14 +64,14 @@ func Open(t TestingT, driverName, dataSourceName string, opts ...Option) *ent.Cl return c } -// NewClient calls ent.NewClient and auto-run migration. -func NewClient(t TestingT, opts ...Option) *ent.Client { +// NewClient calls internal.NewClient and auto-run migration. +func NewClient(t TestingT, opts ...Option) *internal.Client { o := newOptions(opts) - c := ent.NewClient(o.opts...) + c := internal.NewClient(o.opts...) migrateSchema(t, c, o) return c } -func migrateSchema(t TestingT, c *ent.Client, o *options) { +func migrateSchema(t TestingT, c *internal.Client, o *options) { tables, err := schema.CopyTables(migrate.Tables) if err != nil { t.Error(err) diff --git a/internal/db/ent/generate.go b/internal/db_impl/ent_db/internal/generate.go similarity index 80% rename from internal/db/ent/generate.go rename to internal/db_impl/ent_db/internal/generate.go index 8d3fdfdc..c46b85b3 100644 --- a/internal/db/ent/generate.go +++ b/internal/db_impl/ent_db/internal/generate.go @@ -1,3 +1,3 @@ -package ent +package internal //go:generate go run -mod=mod entgo.io/ent/cmd/ent generate ./schema diff --git a/internal/db_impl/ent_db/internal/hook/hook.go b/internal/db_impl/ent_db/internal/hook/hook.go new file mode 100644 index 00000000..344683cb --- /dev/null +++ b/internal/db_impl/ent_db/internal/hook/hook.go @@ -0,0 +1,283 @@ +// Code generated by ent, DO NOT EDIT. + +package hook + +import ( + "context" + "fmt" + + "github.com/ProtonMail/gluon/internal/db_impl/ent_db/internal" +) + +// The DeletedSubscriptionFunc type is an adapter to allow the use of ordinary +// function as DeletedSubscription mutator. +type DeletedSubscriptionFunc func(context.Context, *internal.DeletedSubscriptionMutation) (internal.Value, error) + +// Mutate calls f(ctx, m). +func (f DeletedSubscriptionFunc) Mutate(ctx context.Context, m internal.Mutation) (internal.Value, error) { + if mv, ok := m.(*internal.DeletedSubscriptionMutation); ok { + return f(ctx, mv) + } + return nil, fmt.Errorf("unexpected mutation type %T. expect *internal.DeletedSubscriptionMutation", m) +} + +// The MailboxFunc type is an adapter to allow the use of ordinary +// function as Mailbox mutator. +type MailboxFunc func(context.Context, *internal.MailboxMutation) (internal.Value, error) + +// Mutate calls f(ctx, m). +func (f MailboxFunc) Mutate(ctx context.Context, m internal.Mutation) (internal.Value, error) { + if mv, ok := m.(*internal.MailboxMutation); ok { + return f(ctx, mv) + } + return nil, fmt.Errorf("unexpected mutation type %T. expect *internal.MailboxMutation", m) +} + +// The MailboxAttrFunc type is an adapter to allow the use of ordinary +// function as MailboxAttr mutator. +type MailboxAttrFunc func(context.Context, *internal.MailboxAttrMutation) (internal.Value, error) + +// Mutate calls f(ctx, m). +func (f MailboxAttrFunc) Mutate(ctx context.Context, m internal.Mutation) (internal.Value, error) { + if mv, ok := m.(*internal.MailboxAttrMutation); ok { + return f(ctx, mv) + } + return nil, fmt.Errorf("unexpected mutation type %T. expect *internal.MailboxAttrMutation", m) +} + +// The MailboxFlagFunc type is an adapter to allow the use of ordinary +// function as MailboxFlag mutator. +type MailboxFlagFunc func(context.Context, *internal.MailboxFlagMutation) (internal.Value, error) + +// Mutate calls f(ctx, m). +func (f MailboxFlagFunc) Mutate(ctx context.Context, m internal.Mutation) (internal.Value, error) { + if mv, ok := m.(*internal.MailboxFlagMutation); ok { + return f(ctx, mv) + } + return nil, fmt.Errorf("unexpected mutation type %T. expect *internal.MailboxFlagMutation", m) +} + +// The MailboxPermFlagFunc type is an adapter to allow the use of ordinary +// function as MailboxPermFlag mutator. +type MailboxPermFlagFunc func(context.Context, *internal.MailboxPermFlagMutation) (internal.Value, error) + +// Mutate calls f(ctx, m). +func (f MailboxPermFlagFunc) Mutate(ctx context.Context, m internal.Mutation) (internal.Value, error) { + if mv, ok := m.(*internal.MailboxPermFlagMutation); ok { + return f(ctx, mv) + } + return nil, fmt.Errorf("unexpected mutation type %T. expect *internal.MailboxPermFlagMutation", m) +} + +// The MessageFunc type is an adapter to allow the use of ordinary +// function as Message mutator. +type MessageFunc func(context.Context, *internal.MessageMutation) (internal.Value, error) + +// Mutate calls f(ctx, m). +func (f MessageFunc) Mutate(ctx context.Context, m internal.Mutation) (internal.Value, error) { + if mv, ok := m.(*internal.MessageMutation); ok { + return f(ctx, mv) + } + return nil, fmt.Errorf("unexpected mutation type %T. expect *internal.MessageMutation", m) +} + +// The MessageFlagFunc type is an adapter to allow the use of ordinary +// function as MessageFlag mutator. +type MessageFlagFunc func(context.Context, *internal.MessageFlagMutation) (internal.Value, error) + +// Mutate calls f(ctx, m). +func (f MessageFlagFunc) Mutate(ctx context.Context, m internal.Mutation) (internal.Value, error) { + if mv, ok := m.(*internal.MessageFlagMutation); ok { + return f(ctx, mv) + } + return nil, fmt.Errorf("unexpected mutation type %T. expect *internal.MessageFlagMutation", m) +} + +// The UIDFunc type is an adapter to allow the use of ordinary +// function as UID mutator. +type UIDFunc func(context.Context, *internal.UIDMutation) (internal.Value, error) + +// Mutate calls f(ctx, m). +func (f UIDFunc) Mutate(ctx context.Context, m internal.Mutation) (internal.Value, error) { + if mv, ok := m.(*internal.UIDMutation); ok { + return f(ctx, mv) + } + return nil, fmt.Errorf("unexpected mutation type %T. expect *internal.UIDMutation", m) +} + +// Condition is a hook condition function. +type Condition func(context.Context, internal.Mutation) bool + +// And groups conditions with the AND operator. +func And(first, second Condition, rest ...Condition) Condition { + return func(ctx context.Context, m internal.Mutation) bool { + if !first(ctx, m) || !second(ctx, m) { + return false + } + for _, cond := range rest { + if !cond(ctx, m) { + return false + } + } + return true + } +} + +// Or groups conditions with the OR operator. +func Or(first, second Condition, rest ...Condition) Condition { + return func(ctx context.Context, m internal.Mutation) bool { + if first(ctx, m) || second(ctx, m) { + return true + } + for _, cond := range rest { + if cond(ctx, m) { + return true + } + } + return false + } +} + +// Not negates a given condition. +func Not(cond Condition) Condition { + return func(ctx context.Context, m internal.Mutation) bool { + return !cond(ctx, m) + } +} + +// HasOp is a condition testing mutation operation. +func HasOp(op internal.Op) Condition { + return func(_ context.Context, m internal.Mutation) bool { + return m.Op().Is(op) + } +} + +// HasAddedFields is a condition validating `.AddedField` on fields. +func HasAddedFields(field string, fields ...string) Condition { + return func(_ context.Context, m internal.Mutation) bool { + if _, exists := m.AddedField(field); !exists { + return false + } + for _, field := range fields { + if _, exists := m.AddedField(field); !exists { + return false + } + } + return true + } +} + +// HasClearedFields is a condition validating `.FieldCleared` on fields. +func HasClearedFields(field string, fields ...string) Condition { + return func(_ context.Context, m internal.Mutation) bool { + if exists := m.FieldCleared(field); !exists { + return false + } + for _, field := range fields { + if exists := m.FieldCleared(field); !exists { + return false + } + } + return true + } +} + +// HasFields is a condition validating `.Field` on fields. +func HasFields(field string, fields ...string) Condition { + return func(_ context.Context, m internal.Mutation) bool { + if _, exists := m.Field(field); !exists { + return false + } + for _, field := range fields { + if _, exists := m.Field(field); !exists { + return false + } + } + return true + } +} + +// If executes the given hook under condition. +// +// hook.If(ComputeAverage, And(HasFields(...), HasAddedFields(...))) +func If(hk internal.Hook, cond Condition) internal.Hook { + return func(next internal.Mutator) internal.Mutator { + return internal.MutateFunc(func(ctx context.Context, m internal.Mutation) (internal.Value, error) { + if cond(ctx, m) { + return hk(next).Mutate(ctx, m) + } + return next.Mutate(ctx, m) + }) + } +} + +// On executes the given hook only for the given operation. +// +// hook.On(Log, internal.Delete|internal.Create) +func On(hk internal.Hook, op internal.Op) internal.Hook { + return If(hk, HasOp(op)) +} + +// Unless skips the given hook only for the given operation. +// +// hook.Unless(Log, internal.Update|internal.UpdateOne) +func Unless(hk internal.Hook, op internal.Op) internal.Hook { + return If(hk, Not(HasOp(op))) +} + +// FixedError is a hook returning a fixed error. +func FixedError(err error) internal.Hook { + return func(internal.Mutator) internal.Mutator { + return internal.MutateFunc(func(context.Context, internal.Mutation) (internal.Value, error) { + return nil, err + }) + } +} + +// Reject returns a hook that rejects all operations that match op. +// +// func (T) Hooks() []internal.Hook { +// return []internal.Hook{ +// Reject(internal.Delete|internal.Update), +// } +// } +func Reject(op internal.Op) internal.Hook { + hk := FixedError(fmt.Errorf("%s operation is not allowed", op)) + return On(hk, op) +} + +// Chain acts as a list of hooks and is effectively immutable. +// Once created, it will always hold the same set of hooks in the same order. +type Chain struct { + hooks []internal.Hook +} + +// NewChain creates a new chain of hooks. +func NewChain(hooks ...internal.Hook) Chain { + return Chain{append([]internal.Hook(nil), hooks...)} +} + +// Hook chains the list of hooks and returns the final hook. +func (c Chain) Hook() internal.Hook { + return func(mutator internal.Mutator) internal.Mutator { + for i := len(c.hooks) - 1; i >= 0; i-- { + mutator = c.hooks[i](mutator) + } + return mutator + } +} + +// Append extends a chain, adding the specified hook +// as the last ones in the mutation flow. +func (c Chain) Append(hooks ...internal.Hook) Chain { + newHooks := make([]internal.Hook, 0, len(c.hooks)+len(hooks)) + newHooks = append(newHooks, c.hooks...) + newHooks = append(newHooks, hooks...) + return Chain{newHooks} +} + +// Extend extends a chain, adding the specified chain +// as the last ones in the mutation flow. +func (c Chain) Extend(chain Chain) Chain { + return c.Append(chain.hooks...) +} diff --git a/internal/db/ent/mailbox.go b/internal/db_impl/ent_db/internal/mailbox.go similarity index 90% rename from internal/db/ent/mailbox.go rename to internal/db_impl/ent_db/internal/mailbox.go index b863a41b..edf61c35 100644 --- a/internal/db/ent/mailbox.go +++ b/internal/db_impl/ent_db/internal/mailbox.go @@ -1,6 +1,6 @@ // Code generated by ent, DO NOT EDIT. -package ent +package internal import ( "fmt" @@ -8,7 +8,7 @@ import ( "entgo.io/ent/dialect/sql" "github.com/ProtonMail/gluon/imap" - "github.com/ProtonMail/gluon/internal/db/ent/mailbox" + "github.com/ProtonMail/gluon/internal/db_impl/ent_db/internal/mailbox" ) // Mailbox is the model entity for the Mailbox schema. @@ -83,8 +83,8 @@ func (e MailboxEdges) AttributesOrErr() ([]*MailboxAttr, error) { } // scanValues returns the types for scanning values from sql.Rows. -func (*Mailbox) scanValues(columns []string) ([]interface{}, error) { - values := make([]interface{}, len(columns)) +func (*Mailbox) scanValues(columns []string) ([]any, error) { + values := make([]any, len(columns)) for i := range columns { switch columns[i] { case mailbox.FieldSubscribed: @@ -102,7 +102,7 @@ func (*Mailbox) scanValues(columns []string) ([]interface{}, error) { // assignValues assigns the values that were returned from sql.Rows (after scanning) // to the Mailbox fields. -func (m *Mailbox) assignValues(columns []string, values []interface{}) error { +func (m *Mailbox) assignValues(columns []string, values []any) error { if m, n := len(values), len(columns); m < n { return fmt.Errorf("mismatch number of scan values: %d != %d", m, n) } @@ -151,29 +151,29 @@ func (m *Mailbox) assignValues(columns []string, values []interface{}) error { // QueryUIDs queries the "UIDs" edge of the Mailbox entity. func (m *Mailbox) QueryUIDs() *UIDQuery { - return (&MailboxClient{config: m.config}).QueryUIDs(m) + return NewMailboxClient(m.config).QueryUIDs(m) } // QueryFlags queries the "flags" edge of the Mailbox entity. func (m *Mailbox) QueryFlags() *MailboxFlagQuery { - return (&MailboxClient{config: m.config}).QueryFlags(m) + return NewMailboxClient(m.config).QueryFlags(m) } // QueryPermanentFlags queries the "permanent_flags" edge of the Mailbox entity. func (m *Mailbox) QueryPermanentFlags() *MailboxPermFlagQuery { - return (&MailboxClient{config: m.config}).QueryPermanentFlags(m) + return NewMailboxClient(m.config).QueryPermanentFlags(m) } // QueryAttributes queries the "attributes" edge of the Mailbox entity. func (m *Mailbox) QueryAttributes() *MailboxAttrQuery { - return (&MailboxClient{config: m.config}).QueryAttributes(m) + return NewMailboxClient(m.config).QueryAttributes(m) } // Update returns a builder for updating this Mailbox. // Note that you need to call Mailbox.Unwrap() before calling this method if this Mailbox // was returned from a transaction, and the transaction was committed or rolled back. func (m *Mailbox) Update() *MailboxUpdateOne { - return (&MailboxClient{config: m.config}).UpdateOne(m) + return NewMailboxClient(m.config).UpdateOne(m) } // Unwrap unwraps the Mailbox entity that was returned from a transaction after it was closed, @@ -181,7 +181,7 @@ func (m *Mailbox) Update() *MailboxUpdateOne { func (m *Mailbox) Unwrap() *Mailbox { _tx, ok := m.config.driver.(*txDriver) if !ok { - panic("ent: Mailbox is not a transactional entity") + panic("internal: Mailbox is not a transactional entity") } m.config.driver = _tx.drv return m @@ -212,9 +212,3 @@ func (m *Mailbox) String() string { // Mailboxes is a parsable slice of Mailbox. type Mailboxes []*Mailbox - -func (m Mailboxes) config(cfg config) { - for _i := range m { - m[_i].config = cfg - } -} diff --git a/internal/db/ent/mailbox/mailbox.go b/internal/db_impl/ent_db/internal/mailbox/mailbox.go similarity index 100% rename from internal/db/ent/mailbox/mailbox.go rename to internal/db_impl/ent_db/internal/mailbox/mailbox.go diff --git a/internal/db/ent/mailbox/where.go b/internal/db_impl/ent_db/internal/mailbox/where.go similarity index 65% rename from internal/db/ent/mailbox/where.go rename to internal/db_impl/ent_db/internal/mailbox/where.go index 64940642..bfba382c 100644 --- a/internal/db/ent/mailbox/where.go +++ b/internal/db_impl/ent_db/internal/mailbox/where.go @@ -6,493 +6,357 @@ import ( "entgo.io/ent/dialect/sql" "entgo.io/ent/dialect/sql/sqlgraph" "github.com/ProtonMail/gluon/imap" - "github.com/ProtonMail/gluon/internal/db/ent/predicate" + "github.com/ProtonMail/gluon/internal/db_impl/ent_db/internal/predicate" ) // ID filters vertices based on their ID field. func ID(id imap.InternalMailboxID) predicate.Mailbox { - return predicate.Mailbox(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldID), id)) - }) + return predicate.Mailbox(sql.FieldEQ(FieldID, id)) } // IDEQ applies the EQ predicate on the ID field. func IDEQ(id imap.InternalMailboxID) predicate.Mailbox { - return predicate.Mailbox(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldID), id)) - }) + return predicate.Mailbox(sql.FieldEQ(FieldID, id)) } // IDNEQ applies the NEQ predicate on the ID field. func IDNEQ(id imap.InternalMailboxID) predicate.Mailbox { - return predicate.Mailbox(func(s *sql.Selector) { - s.Where(sql.NEQ(s.C(FieldID), id)) - }) + return predicate.Mailbox(sql.FieldNEQ(FieldID, id)) } // IDIn applies the In predicate on the ID field. func IDIn(ids ...imap.InternalMailboxID) predicate.Mailbox { - return predicate.Mailbox(func(s *sql.Selector) { - v := make([]interface{}, len(ids)) - for i := range v { - v[i] = ids[i] - } - s.Where(sql.In(s.C(FieldID), v...)) - }) + return predicate.Mailbox(sql.FieldIn(FieldID, ids...)) } // IDNotIn applies the NotIn predicate on the ID field. func IDNotIn(ids ...imap.InternalMailboxID) predicate.Mailbox { - return predicate.Mailbox(func(s *sql.Selector) { - v := make([]interface{}, len(ids)) - for i := range v { - v[i] = ids[i] - } - s.Where(sql.NotIn(s.C(FieldID), v...)) - }) + return predicate.Mailbox(sql.FieldNotIn(FieldID, ids...)) } // IDGT applies the GT predicate on the ID field. func IDGT(id imap.InternalMailboxID) predicate.Mailbox { - return predicate.Mailbox(func(s *sql.Selector) { - s.Where(sql.GT(s.C(FieldID), id)) - }) + return predicate.Mailbox(sql.FieldGT(FieldID, id)) } // IDGTE applies the GTE predicate on the ID field. func IDGTE(id imap.InternalMailboxID) predicate.Mailbox { - return predicate.Mailbox(func(s *sql.Selector) { - s.Where(sql.GTE(s.C(FieldID), id)) - }) + return predicate.Mailbox(sql.FieldGTE(FieldID, id)) } // IDLT applies the LT predicate on the ID field. func IDLT(id imap.InternalMailboxID) predicate.Mailbox { - return predicate.Mailbox(func(s *sql.Selector) { - s.Where(sql.LT(s.C(FieldID), id)) - }) + return predicate.Mailbox(sql.FieldLT(FieldID, id)) } // IDLTE applies the LTE predicate on the ID field. func IDLTE(id imap.InternalMailboxID) predicate.Mailbox { - return predicate.Mailbox(func(s *sql.Selector) { - s.Where(sql.LTE(s.C(FieldID), id)) - }) + return predicate.Mailbox(sql.FieldLTE(FieldID, id)) } // RemoteID applies equality check predicate on the "RemoteID" field. It's identical to RemoteIDEQ. func RemoteID(v imap.MailboxID) predicate.Mailbox { vc := string(v) - return predicate.Mailbox(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldRemoteID), vc)) - }) + return predicate.Mailbox(sql.FieldEQ(FieldRemoteID, vc)) } // Name applies equality check predicate on the "Name" field. It's identical to NameEQ. func Name(v string) predicate.Mailbox { - return predicate.Mailbox(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldName), v)) - }) + return predicate.Mailbox(sql.FieldEQ(FieldName, v)) } // UIDNext applies equality check predicate on the "UIDNext" field. It's identical to UIDNextEQ. func UIDNext(v imap.UID) predicate.Mailbox { vc := uint32(v) - return predicate.Mailbox(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldUIDNext), vc)) - }) + return predicate.Mailbox(sql.FieldEQ(FieldUIDNext, vc)) } // UIDValidity applies equality check predicate on the "UIDValidity" field. It's identical to UIDValidityEQ. func UIDValidity(v imap.UID) predicate.Mailbox { vc := uint32(v) - return predicate.Mailbox(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldUIDValidity), vc)) - }) + return predicate.Mailbox(sql.FieldEQ(FieldUIDValidity, vc)) } // Subscribed applies equality check predicate on the "Subscribed" field. It's identical to SubscribedEQ. func Subscribed(v bool) predicate.Mailbox { - return predicate.Mailbox(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldSubscribed), v)) - }) + return predicate.Mailbox(sql.FieldEQ(FieldSubscribed, v)) } // RemoteIDEQ applies the EQ predicate on the "RemoteID" field. func RemoteIDEQ(v imap.MailboxID) predicate.Mailbox { vc := string(v) - return predicate.Mailbox(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldRemoteID), vc)) - }) + return predicate.Mailbox(sql.FieldEQ(FieldRemoteID, vc)) } // RemoteIDNEQ applies the NEQ predicate on the "RemoteID" field. func RemoteIDNEQ(v imap.MailboxID) predicate.Mailbox { vc := string(v) - return predicate.Mailbox(func(s *sql.Selector) { - s.Where(sql.NEQ(s.C(FieldRemoteID), vc)) - }) + return predicate.Mailbox(sql.FieldNEQ(FieldRemoteID, vc)) } // RemoteIDIn applies the In predicate on the "RemoteID" field. func RemoteIDIn(vs ...imap.MailboxID) predicate.Mailbox { - v := make([]interface{}, len(vs)) + v := make([]any, len(vs)) for i := range v { v[i] = string(vs[i]) } - return predicate.Mailbox(func(s *sql.Selector) { - s.Where(sql.In(s.C(FieldRemoteID), v...)) - }) + return predicate.Mailbox(sql.FieldIn(FieldRemoteID, v...)) } // RemoteIDNotIn applies the NotIn predicate on the "RemoteID" field. func RemoteIDNotIn(vs ...imap.MailboxID) predicate.Mailbox { - v := make([]interface{}, len(vs)) + v := make([]any, len(vs)) for i := range v { v[i] = string(vs[i]) } - return predicate.Mailbox(func(s *sql.Selector) { - s.Where(sql.NotIn(s.C(FieldRemoteID), v...)) - }) + return predicate.Mailbox(sql.FieldNotIn(FieldRemoteID, v...)) } // RemoteIDGT applies the GT predicate on the "RemoteID" field. func RemoteIDGT(v imap.MailboxID) predicate.Mailbox { vc := string(v) - return predicate.Mailbox(func(s *sql.Selector) { - s.Where(sql.GT(s.C(FieldRemoteID), vc)) - }) + return predicate.Mailbox(sql.FieldGT(FieldRemoteID, vc)) } // RemoteIDGTE applies the GTE predicate on the "RemoteID" field. func RemoteIDGTE(v imap.MailboxID) predicate.Mailbox { vc := string(v) - return predicate.Mailbox(func(s *sql.Selector) { - s.Where(sql.GTE(s.C(FieldRemoteID), vc)) - }) + return predicate.Mailbox(sql.FieldGTE(FieldRemoteID, vc)) } // RemoteIDLT applies the LT predicate on the "RemoteID" field. func RemoteIDLT(v imap.MailboxID) predicate.Mailbox { vc := string(v) - return predicate.Mailbox(func(s *sql.Selector) { - s.Where(sql.LT(s.C(FieldRemoteID), vc)) - }) + return predicate.Mailbox(sql.FieldLT(FieldRemoteID, vc)) } // RemoteIDLTE applies the LTE predicate on the "RemoteID" field. func RemoteIDLTE(v imap.MailboxID) predicate.Mailbox { vc := string(v) - return predicate.Mailbox(func(s *sql.Selector) { - s.Where(sql.LTE(s.C(FieldRemoteID), vc)) - }) + return predicate.Mailbox(sql.FieldLTE(FieldRemoteID, vc)) } // RemoteIDContains applies the Contains predicate on the "RemoteID" field. func RemoteIDContains(v imap.MailboxID) predicate.Mailbox { vc := string(v) - return predicate.Mailbox(func(s *sql.Selector) { - s.Where(sql.Contains(s.C(FieldRemoteID), vc)) - }) + return predicate.Mailbox(sql.FieldContains(FieldRemoteID, vc)) } // RemoteIDHasPrefix applies the HasPrefix predicate on the "RemoteID" field. func RemoteIDHasPrefix(v imap.MailboxID) predicate.Mailbox { vc := string(v) - return predicate.Mailbox(func(s *sql.Selector) { - s.Where(sql.HasPrefix(s.C(FieldRemoteID), vc)) - }) + return predicate.Mailbox(sql.FieldHasPrefix(FieldRemoteID, vc)) } // RemoteIDHasSuffix applies the HasSuffix predicate on the "RemoteID" field. func RemoteIDHasSuffix(v imap.MailboxID) predicate.Mailbox { vc := string(v) - return predicate.Mailbox(func(s *sql.Selector) { - s.Where(sql.HasSuffix(s.C(FieldRemoteID), vc)) - }) + return predicate.Mailbox(sql.FieldHasSuffix(FieldRemoteID, vc)) } // RemoteIDIsNil applies the IsNil predicate on the "RemoteID" field. func RemoteIDIsNil() predicate.Mailbox { - return predicate.Mailbox(func(s *sql.Selector) { - s.Where(sql.IsNull(s.C(FieldRemoteID))) - }) + return predicate.Mailbox(sql.FieldIsNull(FieldRemoteID)) } // RemoteIDNotNil applies the NotNil predicate on the "RemoteID" field. func RemoteIDNotNil() predicate.Mailbox { - return predicate.Mailbox(func(s *sql.Selector) { - s.Where(sql.NotNull(s.C(FieldRemoteID))) - }) + return predicate.Mailbox(sql.FieldNotNull(FieldRemoteID)) } // RemoteIDEqualFold applies the EqualFold predicate on the "RemoteID" field. func RemoteIDEqualFold(v imap.MailboxID) predicate.Mailbox { vc := string(v) - return predicate.Mailbox(func(s *sql.Selector) { - s.Where(sql.EqualFold(s.C(FieldRemoteID), vc)) - }) + return predicate.Mailbox(sql.FieldEqualFold(FieldRemoteID, vc)) } // RemoteIDContainsFold applies the ContainsFold predicate on the "RemoteID" field. func RemoteIDContainsFold(v imap.MailboxID) predicate.Mailbox { vc := string(v) - return predicate.Mailbox(func(s *sql.Selector) { - s.Where(sql.ContainsFold(s.C(FieldRemoteID), vc)) - }) + return predicate.Mailbox(sql.FieldContainsFold(FieldRemoteID, vc)) } // NameEQ applies the EQ predicate on the "Name" field. func NameEQ(v string) predicate.Mailbox { - return predicate.Mailbox(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldName), v)) - }) + return predicate.Mailbox(sql.FieldEQ(FieldName, v)) } // NameNEQ applies the NEQ predicate on the "Name" field. func NameNEQ(v string) predicate.Mailbox { - return predicate.Mailbox(func(s *sql.Selector) { - s.Where(sql.NEQ(s.C(FieldName), v)) - }) + return predicate.Mailbox(sql.FieldNEQ(FieldName, v)) } // NameIn applies the In predicate on the "Name" field. func NameIn(vs ...string) predicate.Mailbox { - v := make([]interface{}, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.Mailbox(func(s *sql.Selector) { - s.Where(sql.In(s.C(FieldName), v...)) - }) + return predicate.Mailbox(sql.FieldIn(FieldName, vs...)) } // NameNotIn applies the NotIn predicate on the "Name" field. func NameNotIn(vs ...string) predicate.Mailbox { - v := make([]interface{}, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.Mailbox(func(s *sql.Selector) { - s.Where(sql.NotIn(s.C(FieldName), v...)) - }) + return predicate.Mailbox(sql.FieldNotIn(FieldName, vs...)) } // NameGT applies the GT predicate on the "Name" field. func NameGT(v string) predicate.Mailbox { - return predicate.Mailbox(func(s *sql.Selector) { - s.Where(sql.GT(s.C(FieldName), v)) - }) + return predicate.Mailbox(sql.FieldGT(FieldName, v)) } // NameGTE applies the GTE predicate on the "Name" field. func NameGTE(v string) predicate.Mailbox { - return predicate.Mailbox(func(s *sql.Selector) { - s.Where(sql.GTE(s.C(FieldName), v)) - }) + return predicate.Mailbox(sql.FieldGTE(FieldName, v)) } // NameLT applies the LT predicate on the "Name" field. func NameLT(v string) predicate.Mailbox { - return predicate.Mailbox(func(s *sql.Selector) { - s.Where(sql.LT(s.C(FieldName), v)) - }) + return predicate.Mailbox(sql.FieldLT(FieldName, v)) } // NameLTE applies the LTE predicate on the "Name" field. func NameLTE(v string) predicate.Mailbox { - return predicate.Mailbox(func(s *sql.Selector) { - s.Where(sql.LTE(s.C(FieldName), v)) - }) + return predicate.Mailbox(sql.FieldLTE(FieldName, v)) } // NameContains applies the Contains predicate on the "Name" field. func NameContains(v string) predicate.Mailbox { - return predicate.Mailbox(func(s *sql.Selector) { - s.Where(sql.Contains(s.C(FieldName), v)) - }) + return predicate.Mailbox(sql.FieldContains(FieldName, v)) } // NameHasPrefix applies the HasPrefix predicate on the "Name" field. func NameHasPrefix(v string) predicate.Mailbox { - return predicate.Mailbox(func(s *sql.Selector) { - s.Where(sql.HasPrefix(s.C(FieldName), v)) - }) + return predicate.Mailbox(sql.FieldHasPrefix(FieldName, v)) } // NameHasSuffix applies the HasSuffix predicate on the "Name" field. func NameHasSuffix(v string) predicate.Mailbox { - return predicate.Mailbox(func(s *sql.Selector) { - s.Where(sql.HasSuffix(s.C(FieldName), v)) - }) + return predicate.Mailbox(sql.FieldHasSuffix(FieldName, v)) } // NameEqualFold applies the EqualFold predicate on the "Name" field. func NameEqualFold(v string) predicate.Mailbox { - return predicate.Mailbox(func(s *sql.Selector) { - s.Where(sql.EqualFold(s.C(FieldName), v)) - }) + return predicate.Mailbox(sql.FieldEqualFold(FieldName, v)) } // NameContainsFold applies the ContainsFold predicate on the "Name" field. func NameContainsFold(v string) predicate.Mailbox { - return predicate.Mailbox(func(s *sql.Selector) { - s.Where(sql.ContainsFold(s.C(FieldName), v)) - }) + return predicate.Mailbox(sql.FieldContainsFold(FieldName, v)) } // UIDNextEQ applies the EQ predicate on the "UIDNext" field. func UIDNextEQ(v imap.UID) predicate.Mailbox { vc := uint32(v) - return predicate.Mailbox(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldUIDNext), vc)) - }) + return predicate.Mailbox(sql.FieldEQ(FieldUIDNext, vc)) } // UIDNextNEQ applies the NEQ predicate on the "UIDNext" field. func UIDNextNEQ(v imap.UID) predicate.Mailbox { vc := uint32(v) - return predicate.Mailbox(func(s *sql.Selector) { - s.Where(sql.NEQ(s.C(FieldUIDNext), vc)) - }) + return predicate.Mailbox(sql.FieldNEQ(FieldUIDNext, vc)) } // UIDNextIn applies the In predicate on the "UIDNext" field. func UIDNextIn(vs ...imap.UID) predicate.Mailbox { - v := make([]interface{}, len(vs)) + v := make([]any, len(vs)) for i := range v { v[i] = uint32(vs[i]) } - return predicate.Mailbox(func(s *sql.Selector) { - s.Where(sql.In(s.C(FieldUIDNext), v...)) - }) + return predicate.Mailbox(sql.FieldIn(FieldUIDNext, v...)) } // UIDNextNotIn applies the NotIn predicate on the "UIDNext" field. func UIDNextNotIn(vs ...imap.UID) predicate.Mailbox { - v := make([]interface{}, len(vs)) + v := make([]any, len(vs)) for i := range v { v[i] = uint32(vs[i]) } - return predicate.Mailbox(func(s *sql.Selector) { - s.Where(sql.NotIn(s.C(FieldUIDNext), v...)) - }) + return predicate.Mailbox(sql.FieldNotIn(FieldUIDNext, v...)) } // UIDNextGT applies the GT predicate on the "UIDNext" field. func UIDNextGT(v imap.UID) predicate.Mailbox { vc := uint32(v) - return predicate.Mailbox(func(s *sql.Selector) { - s.Where(sql.GT(s.C(FieldUIDNext), vc)) - }) + return predicate.Mailbox(sql.FieldGT(FieldUIDNext, vc)) } // UIDNextGTE applies the GTE predicate on the "UIDNext" field. func UIDNextGTE(v imap.UID) predicate.Mailbox { vc := uint32(v) - return predicate.Mailbox(func(s *sql.Selector) { - s.Where(sql.GTE(s.C(FieldUIDNext), vc)) - }) + return predicate.Mailbox(sql.FieldGTE(FieldUIDNext, vc)) } // UIDNextLT applies the LT predicate on the "UIDNext" field. func UIDNextLT(v imap.UID) predicate.Mailbox { vc := uint32(v) - return predicate.Mailbox(func(s *sql.Selector) { - s.Where(sql.LT(s.C(FieldUIDNext), vc)) - }) + return predicate.Mailbox(sql.FieldLT(FieldUIDNext, vc)) } // UIDNextLTE applies the LTE predicate on the "UIDNext" field. func UIDNextLTE(v imap.UID) predicate.Mailbox { vc := uint32(v) - return predicate.Mailbox(func(s *sql.Selector) { - s.Where(sql.LTE(s.C(FieldUIDNext), vc)) - }) + return predicate.Mailbox(sql.FieldLTE(FieldUIDNext, vc)) } // UIDValidityEQ applies the EQ predicate on the "UIDValidity" field. func UIDValidityEQ(v imap.UID) predicate.Mailbox { vc := uint32(v) - return predicate.Mailbox(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldUIDValidity), vc)) - }) + return predicate.Mailbox(sql.FieldEQ(FieldUIDValidity, vc)) } // UIDValidityNEQ applies the NEQ predicate on the "UIDValidity" field. func UIDValidityNEQ(v imap.UID) predicate.Mailbox { vc := uint32(v) - return predicate.Mailbox(func(s *sql.Selector) { - s.Where(sql.NEQ(s.C(FieldUIDValidity), vc)) - }) + return predicate.Mailbox(sql.FieldNEQ(FieldUIDValidity, vc)) } // UIDValidityIn applies the In predicate on the "UIDValidity" field. func UIDValidityIn(vs ...imap.UID) predicate.Mailbox { - v := make([]interface{}, len(vs)) + v := make([]any, len(vs)) for i := range v { v[i] = uint32(vs[i]) } - return predicate.Mailbox(func(s *sql.Selector) { - s.Where(sql.In(s.C(FieldUIDValidity), v...)) - }) + return predicate.Mailbox(sql.FieldIn(FieldUIDValidity, v...)) } // UIDValidityNotIn applies the NotIn predicate on the "UIDValidity" field. func UIDValidityNotIn(vs ...imap.UID) predicate.Mailbox { - v := make([]interface{}, len(vs)) + v := make([]any, len(vs)) for i := range v { v[i] = uint32(vs[i]) } - return predicate.Mailbox(func(s *sql.Selector) { - s.Where(sql.NotIn(s.C(FieldUIDValidity), v...)) - }) + return predicate.Mailbox(sql.FieldNotIn(FieldUIDValidity, v...)) } // UIDValidityGT applies the GT predicate on the "UIDValidity" field. func UIDValidityGT(v imap.UID) predicate.Mailbox { vc := uint32(v) - return predicate.Mailbox(func(s *sql.Selector) { - s.Where(sql.GT(s.C(FieldUIDValidity), vc)) - }) + return predicate.Mailbox(sql.FieldGT(FieldUIDValidity, vc)) } // UIDValidityGTE applies the GTE predicate on the "UIDValidity" field. func UIDValidityGTE(v imap.UID) predicate.Mailbox { vc := uint32(v) - return predicate.Mailbox(func(s *sql.Selector) { - s.Where(sql.GTE(s.C(FieldUIDValidity), vc)) - }) + return predicate.Mailbox(sql.FieldGTE(FieldUIDValidity, vc)) } // UIDValidityLT applies the LT predicate on the "UIDValidity" field. func UIDValidityLT(v imap.UID) predicate.Mailbox { vc := uint32(v) - return predicate.Mailbox(func(s *sql.Selector) { - s.Where(sql.LT(s.C(FieldUIDValidity), vc)) - }) + return predicate.Mailbox(sql.FieldLT(FieldUIDValidity, vc)) } // UIDValidityLTE applies the LTE predicate on the "UIDValidity" field. func UIDValidityLTE(v imap.UID) predicate.Mailbox { vc := uint32(v) - return predicate.Mailbox(func(s *sql.Selector) { - s.Where(sql.LTE(s.C(FieldUIDValidity), vc)) - }) + return predicate.Mailbox(sql.FieldLTE(FieldUIDValidity, vc)) } // SubscribedEQ applies the EQ predicate on the "Subscribed" field. func SubscribedEQ(v bool) predicate.Mailbox { - return predicate.Mailbox(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldSubscribed), v)) - }) + return predicate.Mailbox(sql.FieldEQ(FieldSubscribed, v)) } // SubscribedNEQ applies the NEQ predicate on the "Subscribed" field. func SubscribedNEQ(v bool) predicate.Mailbox { - return predicate.Mailbox(func(s *sql.Selector) { - s.Where(sql.NEQ(s.C(FieldSubscribed), v)) - }) + return predicate.Mailbox(sql.FieldNEQ(FieldSubscribed, v)) } // HasUIDs applies the HasEdge predicate on the "UIDs" edge. @@ -500,7 +364,6 @@ func HasUIDs() predicate.Mailbox { return predicate.Mailbox(func(s *sql.Selector) { step := sqlgraph.NewStep( sqlgraph.From(Table, FieldID), - sqlgraph.To(UIDsTable, FieldID), sqlgraph.Edge(sqlgraph.O2M, false, UIDsTable, UIDsColumn), ) sqlgraph.HasNeighbors(s, step) @@ -528,7 +391,6 @@ func HasFlags() predicate.Mailbox { return predicate.Mailbox(func(s *sql.Selector) { step := sqlgraph.NewStep( sqlgraph.From(Table, FieldID), - sqlgraph.To(FlagsTable, FieldID), sqlgraph.Edge(sqlgraph.O2M, false, FlagsTable, FlagsColumn), ) sqlgraph.HasNeighbors(s, step) @@ -556,7 +418,6 @@ func HasPermanentFlags() predicate.Mailbox { return predicate.Mailbox(func(s *sql.Selector) { step := sqlgraph.NewStep( sqlgraph.From(Table, FieldID), - sqlgraph.To(PermanentFlagsTable, FieldID), sqlgraph.Edge(sqlgraph.O2M, false, PermanentFlagsTable, PermanentFlagsColumn), ) sqlgraph.HasNeighbors(s, step) @@ -584,7 +445,6 @@ func HasAttributes() predicate.Mailbox { return predicate.Mailbox(func(s *sql.Selector) { step := sqlgraph.NewStep( sqlgraph.From(Table, FieldID), - sqlgraph.To(AttributesTable, FieldID), sqlgraph.Edge(sqlgraph.O2M, false, AttributesTable, AttributesColumn), ) sqlgraph.HasNeighbors(s, step) diff --git a/internal/db/ent/mailbox_create.go b/internal/db_impl/ent_db/internal/mailbox_create.go similarity index 80% rename from internal/db/ent/mailbox_create.go rename to internal/db_impl/ent_db/internal/mailbox_create.go index b1fa61a3..9b831be9 100644 --- a/internal/db/ent/mailbox_create.go +++ b/internal/db_impl/ent_db/internal/mailbox_create.go @@ -1,6 +1,6 @@ // Code generated by ent, DO NOT EDIT. -package ent +package internal import ( "context" @@ -10,11 +10,11 @@ import ( "entgo.io/ent/dialect/sql/sqlgraph" "entgo.io/ent/schema/field" "github.com/ProtonMail/gluon/imap" - "github.com/ProtonMail/gluon/internal/db/ent/mailbox" - "github.com/ProtonMail/gluon/internal/db/ent/mailboxattr" - "github.com/ProtonMail/gluon/internal/db/ent/mailboxflag" - "github.com/ProtonMail/gluon/internal/db/ent/mailboxpermflag" - "github.com/ProtonMail/gluon/internal/db/ent/uid" + "github.com/ProtonMail/gluon/internal/db_impl/ent_db/internal/mailbox" + "github.com/ProtonMail/gluon/internal/db_impl/ent_db/internal/mailboxattr" + "github.com/ProtonMail/gluon/internal/db_impl/ent_db/internal/mailboxflag" + "github.com/ProtonMail/gluon/internal/db_impl/ent_db/internal/mailboxpermflag" + "github.com/ProtonMail/gluon/internal/db_impl/ent_db/internal/uid" ) // MailboxCreate is the builder for creating a Mailbox entity. @@ -159,50 +159,8 @@ func (mc *MailboxCreate) Mutation() *MailboxMutation { // Save creates the Mailbox in the database. func (mc *MailboxCreate) Save(ctx context.Context) (*Mailbox, error) { - var ( - err error - node *Mailbox - ) mc.defaults() - if len(mc.hooks) == 0 { - if err = mc.check(); err != nil { - return nil, err - } - node, err = mc.sqlSave(ctx) - } else { - var mut Mutator = MutateFunc(func(ctx context.Context, m Mutation) (Value, error) { - mutation, ok := m.(*MailboxMutation) - if !ok { - return nil, fmt.Errorf("unexpected mutation type %T", m) - } - if err = mc.check(); err != nil { - return nil, err - } - mc.mutation = mutation - if node, err = mc.sqlSave(ctx); err != nil { - return nil, err - } - mutation.id = &node.ID - mutation.done = true - return node, err - }) - for i := len(mc.hooks) - 1; i >= 0; i-- { - if mc.hooks[i] == nil { - return nil, fmt.Errorf("ent: uninitialized hook (forgotten import ent/runtime?)") - } - mut = mc.hooks[i](mut) - } - v, err := mut.Mutate(ctx, mc.mutation) - if err != nil { - return nil, err - } - nv, ok := v.(*Mailbox) - if !ok { - return nil, fmt.Errorf("unexpected node type %T returned from MailboxMutation", v) - } - node = nv - } - return node, err + return withHooks[*Mailbox, MailboxMutation](ctx, mc.sqlSave, mc.mutation, mc.hooks) } // SaveX calls Save and panics if Save returns an error. @@ -246,21 +204,24 @@ func (mc *MailboxCreate) defaults() { // check runs all checks and user-defined validators on the builder. func (mc *MailboxCreate) check() error { if _, ok := mc.mutation.Name(); !ok { - return &ValidationError{Name: "Name", err: errors.New(`ent: missing required field "Mailbox.Name"`)} + return &ValidationError{Name: "Name", err: errors.New(`internal: missing required field "Mailbox.Name"`)} } if _, ok := mc.mutation.UIDNext(); !ok { - return &ValidationError{Name: "UIDNext", err: errors.New(`ent: missing required field "Mailbox.UIDNext"`)} + return &ValidationError{Name: "UIDNext", err: errors.New(`internal: missing required field "Mailbox.UIDNext"`)} } if _, ok := mc.mutation.UIDValidity(); !ok { - return &ValidationError{Name: "UIDValidity", err: errors.New(`ent: missing required field "Mailbox.UIDValidity"`)} + return &ValidationError{Name: "UIDValidity", err: errors.New(`internal: missing required field "Mailbox.UIDValidity"`)} } if _, ok := mc.mutation.Subscribed(); !ok { - return &ValidationError{Name: "Subscribed", err: errors.New(`ent: missing required field "Mailbox.Subscribed"`)} + return &ValidationError{Name: "Subscribed", err: errors.New(`internal: missing required field "Mailbox.Subscribed"`)} } return nil } func (mc *MailboxCreate) sqlSave(ctx context.Context) (*Mailbox, error) { + if err := mc.check(); err != nil { + return nil, err + } _node, _spec := mc.createSpec() if err := sqlgraph.CreateNode(ctx, mc.driver, _spec); err != nil { if sqlgraph.IsConstraintError(err) { @@ -272,62 +233,38 @@ func (mc *MailboxCreate) sqlSave(ctx context.Context) (*Mailbox, error) { id := _spec.ID.Value.(int64) _node.ID = imap.InternalMailboxID(id) } + mc.mutation.id = &_node.ID + mc.mutation.done = true return _node, nil } func (mc *MailboxCreate) createSpec() (*Mailbox, *sqlgraph.CreateSpec) { var ( _node = &Mailbox{config: mc.config} - _spec = &sqlgraph.CreateSpec{ - Table: mailbox.Table, - ID: &sqlgraph.FieldSpec{ - Type: field.TypeUint64, - Column: mailbox.FieldID, - }, - } + _spec = sqlgraph.NewCreateSpec(mailbox.Table, sqlgraph.NewFieldSpec(mailbox.FieldID, field.TypeUint64)) ) if id, ok := mc.mutation.ID(); ok { _node.ID = id _spec.ID.Value = id } if value, ok := mc.mutation.RemoteID(); ok { - _spec.Fields = append(_spec.Fields, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Value: value, - Column: mailbox.FieldRemoteID, - }) + _spec.SetField(mailbox.FieldRemoteID, field.TypeString, value) _node.RemoteID = value } if value, ok := mc.mutation.Name(); ok { - _spec.Fields = append(_spec.Fields, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Value: value, - Column: mailbox.FieldName, - }) + _spec.SetField(mailbox.FieldName, field.TypeString, value) _node.Name = value } if value, ok := mc.mutation.UIDNext(); ok { - _spec.Fields = append(_spec.Fields, &sqlgraph.FieldSpec{ - Type: field.TypeUint32, - Value: value, - Column: mailbox.FieldUIDNext, - }) + _spec.SetField(mailbox.FieldUIDNext, field.TypeUint32, value) _node.UIDNext = value } if value, ok := mc.mutation.UIDValidity(); ok { - _spec.Fields = append(_spec.Fields, &sqlgraph.FieldSpec{ - Type: field.TypeUint32, - Value: value, - Column: mailbox.FieldUIDValidity, - }) + _spec.SetField(mailbox.FieldUIDValidity, field.TypeUint32, value) _node.UIDValidity = value } if value, ok := mc.mutation.Subscribed(); ok { - _spec.Fields = append(_spec.Fields, &sqlgraph.FieldSpec{ - Type: field.TypeBool, - Value: value, - Column: mailbox.FieldSubscribed, - }) + _spec.SetField(mailbox.FieldSubscribed, field.TypeBool, value) _node.Subscribed = value } if nodes := mc.mutation.UIDsIDs(); len(nodes) > 0 { diff --git a/internal/db/ent/mailbox_delete.go b/internal/db_impl/ent_db/internal/mailbox_delete.go similarity index 62% rename from internal/db/ent/mailbox_delete.go rename to internal/db_impl/ent_db/internal/mailbox_delete.go index 5fd9ac3d..a8074cc9 100644 --- a/internal/db/ent/mailbox_delete.go +++ b/internal/db_impl/ent_db/internal/mailbox_delete.go @@ -1,16 +1,15 @@ // Code generated by ent, DO NOT EDIT. -package ent +package internal import ( "context" - "fmt" "entgo.io/ent/dialect/sql" "entgo.io/ent/dialect/sql/sqlgraph" "entgo.io/ent/schema/field" - "github.com/ProtonMail/gluon/internal/db/ent/mailbox" - "github.com/ProtonMail/gluon/internal/db/ent/predicate" + "github.com/ProtonMail/gluon/internal/db_impl/ent_db/internal/mailbox" + "github.com/ProtonMail/gluon/internal/db_impl/ent_db/internal/predicate" ) // MailboxDelete is the builder for deleting a Mailbox entity. @@ -28,34 +27,7 @@ func (md *MailboxDelete) Where(ps ...predicate.Mailbox) *MailboxDelete { // Exec executes the deletion query and returns how many vertices were deleted. func (md *MailboxDelete) Exec(ctx context.Context) (int, error) { - var ( - err error - affected int - ) - if len(md.hooks) == 0 { - affected, err = md.sqlExec(ctx) - } else { - var mut Mutator = MutateFunc(func(ctx context.Context, m Mutation) (Value, error) { - mutation, ok := m.(*MailboxMutation) - if !ok { - return nil, fmt.Errorf("unexpected mutation type %T", m) - } - md.mutation = mutation - affected, err = md.sqlExec(ctx) - mutation.done = true - return affected, err - }) - for i := len(md.hooks) - 1; i >= 0; i-- { - if md.hooks[i] == nil { - return 0, fmt.Errorf("ent: uninitialized hook (forgotten import ent/runtime?)") - } - mut = md.hooks[i](mut) - } - if _, err := mut.Mutate(ctx, md.mutation); err != nil { - return 0, err - } - } - return affected, err + return withHooks[int, MailboxMutation](ctx, md.sqlExec, md.mutation, md.hooks) } // ExecX is like Exec, but panics if an error occurs. @@ -68,15 +40,7 @@ func (md *MailboxDelete) ExecX(ctx context.Context) int { } func (md *MailboxDelete) sqlExec(ctx context.Context) (int, error) { - _spec := &sqlgraph.DeleteSpec{ - Node: &sqlgraph.NodeSpec{ - Table: mailbox.Table, - ID: &sqlgraph.FieldSpec{ - Type: field.TypeUint64, - Column: mailbox.FieldID, - }, - }, - } + _spec := sqlgraph.NewDeleteSpec(mailbox.Table, sqlgraph.NewFieldSpec(mailbox.FieldID, field.TypeUint64)) if ps := md.mutation.predicates; len(ps) > 0 { _spec.Predicate = func(selector *sql.Selector) { for i := range ps { @@ -88,6 +52,7 @@ func (md *MailboxDelete) sqlExec(ctx context.Context) (int, error) { if err != nil && sqlgraph.IsConstraintError(err) { err = &ConstraintError{msg: err.Error(), wrap: err} } + md.mutation.done = true return affected, err } @@ -96,6 +61,12 @@ type MailboxDeleteOne struct { md *MailboxDelete } +// Where appends a list predicates to the MailboxDelete builder. +func (mdo *MailboxDeleteOne) Where(ps ...predicate.Mailbox) *MailboxDeleteOne { + mdo.md.mutation.Where(ps...) + return mdo +} + // Exec executes the deletion query. func (mdo *MailboxDeleteOne) Exec(ctx context.Context) error { n, err := mdo.md.Exec(ctx) @@ -111,5 +82,7 @@ func (mdo *MailboxDeleteOne) Exec(ctx context.Context) error { // ExecX is like Exec, but panics if an error occurs. func (mdo *MailboxDeleteOne) ExecX(ctx context.Context) { - mdo.md.ExecX(ctx) + if err := mdo.Exec(ctx); err != nil { + panic(err) + } } diff --git a/internal/db/ent/mailbox_query.go b/internal/db_impl/ent_db/internal/mailbox_query.go similarity index 77% rename from internal/db/ent/mailbox_query.go rename to internal/db_impl/ent_db/internal/mailbox_query.go index 0afa7e5e..a21664be 100644 --- a/internal/db/ent/mailbox_query.go +++ b/internal/db_impl/ent_db/internal/mailbox_query.go @@ -1,6 +1,6 @@ // Code generated by ent, DO NOT EDIT. -package ent +package internal import ( "context" @@ -12,22 +12,20 @@ import ( "entgo.io/ent/dialect/sql/sqlgraph" "entgo.io/ent/schema/field" "github.com/ProtonMail/gluon/imap" - "github.com/ProtonMail/gluon/internal/db/ent/mailbox" - "github.com/ProtonMail/gluon/internal/db/ent/mailboxattr" - "github.com/ProtonMail/gluon/internal/db/ent/mailboxflag" - "github.com/ProtonMail/gluon/internal/db/ent/mailboxpermflag" - "github.com/ProtonMail/gluon/internal/db/ent/predicate" - "github.com/ProtonMail/gluon/internal/db/ent/uid" + "github.com/ProtonMail/gluon/internal/db_impl/ent_db/internal/mailbox" + "github.com/ProtonMail/gluon/internal/db_impl/ent_db/internal/mailboxattr" + "github.com/ProtonMail/gluon/internal/db_impl/ent_db/internal/mailboxflag" + "github.com/ProtonMail/gluon/internal/db_impl/ent_db/internal/mailboxpermflag" + "github.com/ProtonMail/gluon/internal/db_impl/ent_db/internal/predicate" + "github.com/ProtonMail/gluon/internal/db_impl/ent_db/internal/uid" ) // MailboxQuery is the builder for querying Mailbox entities. type MailboxQuery struct { config - limit *int - offset *int - unique *bool + ctx *QueryContext order []OrderFunc - fields []string + inters []Interceptor predicates []predicate.Mailbox withUIDs *UIDQuery withFlags *MailboxFlagQuery @@ -44,26 +42,26 @@ func (mq *MailboxQuery) Where(ps ...predicate.Mailbox) *MailboxQuery { return mq } -// Limit adds a limit step to the query. +// Limit the number of records to be returned by this query. func (mq *MailboxQuery) Limit(limit int) *MailboxQuery { - mq.limit = &limit + mq.ctx.Limit = &limit return mq } -// Offset adds an offset step to the query. +// Offset to start from. func (mq *MailboxQuery) Offset(offset int) *MailboxQuery { - mq.offset = &offset + mq.ctx.Offset = &offset return mq } // Unique configures the query builder to filter duplicate records on query. // By default, unique is set to true, and can be disabled using this method. func (mq *MailboxQuery) Unique(unique bool) *MailboxQuery { - mq.unique = &unique + mq.ctx.Unique = &unique return mq } -// Order adds an order step to the query. +// Order specifies how the records should be ordered. func (mq *MailboxQuery) Order(o ...OrderFunc) *MailboxQuery { mq.order = append(mq.order, o...) return mq @@ -71,7 +69,7 @@ func (mq *MailboxQuery) Order(o ...OrderFunc) *MailboxQuery { // QueryUIDs chains the current query on the "UIDs" edge. func (mq *MailboxQuery) QueryUIDs() *UIDQuery { - query := &UIDQuery{config: mq.config} + query := (&UIDClient{config: mq.config}).Query() query.path = func(ctx context.Context) (fromU *sql.Selector, err error) { if err := mq.prepareQuery(ctx); err != nil { return nil, err @@ -93,7 +91,7 @@ func (mq *MailboxQuery) QueryUIDs() *UIDQuery { // QueryFlags chains the current query on the "flags" edge. func (mq *MailboxQuery) QueryFlags() *MailboxFlagQuery { - query := &MailboxFlagQuery{config: mq.config} + query := (&MailboxFlagClient{config: mq.config}).Query() query.path = func(ctx context.Context) (fromU *sql.Selector, err error) { if err := mq.prepareQuery(ctx); err != nil { return nil, err @@ -115,7 +113,7 @@ func (mq *MailboxQuery) QueryFlags() *MailboxFlagQuery { // QueryPermanentFlags chains the current query on the "permanent_flags" edge. func (mq *MailboxQuery) QueryPermanentFlags() *MailboxPermFlagQuery { - query := &MailboxPermFlagQuery{config: mq.config} + query := (&MailboxPermFlagClient{config: mq.config}).Query() query.path = func(ctx context.Context) (fromU *sql.Selector, err error) { if err := mq.prepareQuery(ctx); err != nil { return nil, err @@ -137,7 +135,7 @@ func (mq *MailboxQuery) QueryPermanentFlags() *MailboxPermFlagQuery { // QueryAttributes chains the current query on the "attributes" edge. func (mq *MailboxQuery) QueryAttributes() *MailboxAttrQuery { - query := &MailboxAttrQuery{config: mq.config} + query := (&MailboxAttrClient{config: mq.config}).Query() query.path = func(ctx context.Context) (fromU *sql.Selector, err error) { if err := mq.prepareQuery(ctx); err != nil { return nil, err @@ -160,7 +158,7 @@ func (mq *MailboxQuery) QueryAttributes() *MailboxAttrQuery { // First returns the first Mailbox entity from the query. // Returns a *NotFoundError when no Mailbox was found. func (mq *MailboxQuery) First(ctx context.Context) (*Mailbox, error) { - nodes, err := mq.Limit(1).All(ctx) + nodes, err := mq.Limit(1).All(setContextOp(ctx, mq.ctx, "First")) if err != nil { return nil, err } @@ -183,7 +181,7 @@ func (mq *MailboxQuery) FirstX(ctx context.Context) *Mailbox { // Returns a *NotFoundError when no Mailbox ID was found. func (mq *MailboxQuery) FirstID(ctx context.Context) (id imap.InternalMailboxID, err error) { var ids []imap.InternalMailboxID - if ids, err = mq.Limit(1).IDs(ctx); err != nil { + if ids, err = mq.Limit(1).IDs(setContextOp(ctx, mq.ctx, "FirstID")); err != nil { return } if len(ids) == 0 { @@ -206,7 +204,7 @@ func (mq *MailboxQuery) FirstIDX(ctx context.Context) imap.InternalMailboxID { // Returns a *NotSingularError when more than one Mailbox entity is found. // Returns a *NotFoundError when no Mailbox entities are found. func (mq *MailboxQuery) Only(ctx context.Context) (*Mailbox, error) { - nodes, err := mq.Limit(2).All(ctx) + nodes, err := mq.Limit(2).All(setContextOp(ctx, mq.ctx, "Only")) if err != nil { return nil, err } @@ -234,7 +232,7 @@ func (mq *MailboxQuery) OnlyX(ctx context.Context) *Mailbox { // Returns a *NotFoundError when no entities are found. func (mq *MailboxQuery) OnlyID(ctx context.Context) (id imap.InternalMailboxID, err error) { var ids []imap.InternalMailboxID - if ids, err = mq.Limit(2).IDs(ctx); err != nil { + if ids, err = mq.Limit(2).IDs(setContextOp(ctx, mq.ctx, "OnlyID")); err != nil { return } switch len(ids) { @@ -259,10 +257,12 @@ func (mq *MailboxQuery) OnlyIDX(ctx context.Context) imap.InternalMailboxID { // All executes the query and returns a list of Mailboxes. func (mq *MailboxQuery) All(ctx context.Context) ([]*Mailbox, error) { + ctx = setContextOp(ctx, mq.ctx, "All") if err := mq.prepareQuery(ctx); err != nil { return nil, err } - return mq.sqlAll(ctx) + qr := querierAll[[]*Mailbox, *MailboxQuery]() + return withInterceptors[[]*Mailbox](ctx, mq, qr, mq.inters) } // AllX is like All, but panics if an error occurs. @@ -275,9 +275,12 @@ func (mq *MailboxQuery) AllX(ctx context.Context) []*Mailbox { } // IDs executes the query and returns a list of Mailbox IDs. -func (mq *MailboxQuery) IDs(ctx context.Context) ([]imap.InternalMailboxID, error) { - var ids []imap.InternalMailboxID - if err := mq.Select(mailbox.FieldID).Scan(ctx, &ids); err != nil { +func (mq *MailboxQuery) IDs(ctx context.Context) (ids []imap.InternalMailboxID, err error) { + if mq.ctx.Unique == nil && mq.path != nil { + mq.Unique(true) + } + ctx = setContextOp(ctx, mq.ctx, "IDs") + if err = mq.Select(mailbox.FieldID).Scan(ctx, &ids); err != nil { return nil, err } return ids, nil @@ -294,10 +297,11 @@ func (mq *MailboxQuery) IDsX(ctx context.Context) []imap.InternalMailboxID { // Count returns the count of the given query. func (mq *MailboxQuery) Count(ctx context.Context) (int, error) { + ctx = setContextOp(ctx, mq.ctx, "Count") if err := mq.prepareQuery(ctx); err != nil { return 0, err } - return mq.sqlCount(ctx) + return withInterceptors[int](ctx, mq, querierCount[*MailboxQuery](), mq.inters) } // CountX is like Count, but panics if an error occurs. @@ -311,10 +315,15 @@ func (mq *MailboxQuery) CountX(ctx context.Context) int { // Exist returns true if the query has elements in the graph. func (mq *MailboxQuery) Exist(ctx context.Context) (bool, error) { - if err := mq.prepareQuery(ctx); err != nil { - return false, err + ctx = setContextOp(ctx, mq.ctx, "Exist") + switch _, err := mq.FirstID(ctx); { + case IsNotFound(err): + return false, nil + case err != nil: + return false, fmt.Errorf("internal: check existence: %w", err) + default: + return true, nil } - return mq.sqlExist(ctx) } // ExistX is like Exist, but panics if an error occurs. @@ -334,25 +343,24 @@ func (mq *MailboxQuery) Clone() *MailboxQuery { } return &MailboxQuery{ config: mq.config, - limit: mq.limit, - offset: mq.offset, + ctx: mq.ctx.Clone(), order: append([]OrderFunc{}, mq.order...), + inters: append([]Interceptor{}, mq.inters...), predicates: append([]predicate.Mailbox{}, mq.predicates...), withUIDs: mq.withUIDs.Clone(), withFlags: mq.withFlags.Clone(), withPermanentFlags: mq.withPermanentFlags.Clone(), withAttributes: mq.withAttributes.Clone(), // clone intermediate query. - sql: mq.sql.Clone(), - path: mq.path, - unique: mq.unique, + sql: mq.sql.Clone(), + path: mq.path, } } // WithUIDs tells the query-builder to eager-load the nodes that are connected to // the "UIDs" edge. The optional arguments are used to configure the query builder of the edge. func (mq *MailboxQuery) WithUIDs(opts ...func(*UIDQuery)) *MailboxQuery { - query := &UIDQuery{config: mq.config} + query := (&UIDClient{config: mq.config}).Query() for _, opt := range opts { opt(query) } @@ -363,7 +371,7 @@ func (mq *MailboxQuery) WithUIDs(opts ...func(*UIDQuery)) *MailboxQuery { // WithFlags tells the query-builder to eager-load the nodes that are connected to // the "flags" edge. The optional arguments are used to configure the query builder of the edge. func (mq *MailboxQuery) WithFlags(opts ...func(*MailboxFlagQuery)) *MailboxQuery { - query := &MailboxFlagQuery{config: mq.config} + query := (&MailboxFlagClient{config: mq.config}).Query() for _, opt := range opts { opt(query) } @@ -374,7 +382,7 @@ func (mq *MailboxQuery) WithFlags(opts ...func(*MailboxFlagQuery)) *MailboxQuery // WithPermanentFlags tells the query-builder to eager-load the nodes that are connected to // the "permanent_flags" edge. The optional arguments are used to configure the query builder of the edge. func (mq *MailboxQuery) WithPermanentFlags(opts ...func(*MailboxPermFlagQuery)) *MailboxQuery { - query := &MailboxPermFlagQuery{config: mq.config} + query := (&MailboxPermFlagClient{config: mq.config}).Query() for _, opt := range opts { opt(query) } @@ -385,7 +393,7 @@ func (mq *MailboxQuery) WithPermanentFlags(opts ...func(*MailboxPermFlagQuery)) // WithAttributes tells the query-builder to eager-load the nodes that are connected to // the "attributes" edge. The optional arguments are used to configure the query builder of the edge. func (mq *MailboxQuery) WithAttributes(opts ...func(*MailboxAttrQuery)) *MailboxQuery { - query := &MailboxAttrQuery{config: mq.config} + query := (&MailboxAttrClient{config: mq.config}).Query() for _, opt := range opts { opt(query) } @@ -405,19 +413,14 @@ func (mq *MailboxQuery) WithAttributes(opts ...func(*MailboxAttrQuery)) *Mailbox // // client.Mailbox.Query(). // GroupBy(mailbox.FieldRemoteID). -// Aggregate(ent.Count()). +// Aggregate(internal.Count()). // Scan(ctx, &v) func (mq *MailboxQuery) GroupBy(field string, fields ...string) *MailboxGroupBy { - grbuild := &MailboxGroupBy{config: mq.config} - grbuild.fields = append([]string{field}, fields...) - grbuild.path = func(ctx context.Context) (prev *sql.Selector, err error) { - if err := mq.prepareQuery(ctx); err != nil { - return nil, err - } - return mq.sqlQuery(ctx), nil - } + mq.ctx.Fields = append([]string{field}, fields...) + grbuild := &MailboxGroupBy{build: mq} + grbuild.flds = &mq.ctx.Fields grbuild.label = mailbox.Label - grbuild.flds, grbuild.scan = &grbuild.fields, grbuild.Scan + grbuild.scan = grbuild.Scan return grbuild } @@ -434,17 +437,32 @@ func (mq *MailboxQuery) GroupBy(field string, fields ...string) *MailboxGroupBy // Select(mailbox.FieldRemoteID). // Scan(ctx, &v) func (mq *MailboxQuery) Select(fields ...string) *MailboxSelect { - mq.fields = append(mq.fields, fields...) - selbuild := &MailboxSelect{MailboxQuery: mq} - selbuild.label = mailbox.Label - selbuild.flds, selbuild.scan = &mq.fields, selbuild.Scan - return selbuild + mq.ctx.Fields = append(mq.ctx.Fields, fields...) + sbuild := &MailboxSelect{MailboxQuery: mq} + sbuild.label = mailbox.Label + sbuild.flds, sbuild.scan = &mq.ctx.Fields, sbuild.Scan + return sbuild +} + +// Aggregate returns a MailboxSelect configured with the given aggregations. +func (mq *MailboxQuery) Aggregate(fns ...AggregateFunc) *MailboxSelect { + return mq.Select().Aggregate(fns...) } func (mq *MailboxQuery) prepareQuery(ctx context.Context) error { - for _, f := range mq.fields { + for _, inter := range mq.inters { + if inter == nil { + return fmt.Errorf("internal: uninitialized interceptor (forgotten import internal/runtime?)") + } + if trv, ok := inter.(Traverser); ok { + if err := trv.Traverse(ctx, mq); err != nil { + return err + } + } + } + for _, f := range mq.ctx.Fields { if !mailbox.ValidColumn(f) { - return &ValidationError{Name: f, err: fmt.Errorf("ent: invalid field %q for query", f)} + return &ValidationError{Name: f, err: fmt.Errorf("internal: invalid field %q for query", f)} } } if mq.path != nil { @@ -468,10 +486,10 @@ func (mq *MailboxQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*Mail mq.withAttributes != nil, } ) - _spec.ScanValues = func(columns []string) ([]interface{}, error) { + _spec.ScanValues = func(columns []string) ([]any, error) { return (*Mailbox).scanValues(nil, columns) } - _spec.Assign = func(columns []string, values []interface{}) error { + _spec.Assign = func(columns []string, values []any) error { node := &Mailbox{config: mq.config} nodes = append(nodes, node) node.Edges.loadedTypes = loadedTypes @@ -644,38 +662,22 @@ func (mq *MailboxQuery) loadAttributes(ctx context.Context, query *MailboxAttrQu func (mq *MailboxQuery) sqlCount(ctx context.Context) (int, error) { _spec := mq.querySpec() - _spec.Node.Columns = mq.fields - if len(mq.fields) > 0 { - _spec.Unique = mq.unique != nil && *mq.unique + _spec.Node.Columns = mq.ctx.Fields + if len(mq.ctx.Fields) > 0 { + _spec.Unique = mq.ctx.Unique != nil && *mq.ctx.Unique } return sqlgraph.CountNodes(ctx, mq.driver, _spec) } -func (mq *MailboxQuery) sqlExist(ctx context.Context) (bool, error) { - n, err := mq.sqlCount(ctx) - if err != nil { - return false, fmt.Errorf("ent: check existence: %w", err) - } - return n > 0, nil -} - func (mq *MailboxQuery) querySpec() *sqlgraph.QuerySpec { - _spec := &sqlgraph.QuerySpec{ - Node: &sqlgraph.NodeSpec{ - Table: mailbox.Table, - Columns: mailbox.Columns, - ID: &sqlgraph.FieldSpec{ - Type: field.TypeUint64, - Column: mailbox.FieldID, - }, - }, - From: mq.sql, - Unique: true, - } - if unique := mq.unique; unique != nil { + _spec := sqlgraph.NewQuerySpec(mailbox.Table, mailbox.Columns, sqlgraph.NewFieldSpec(mailbox.FieldID, field.TypeUint64)) + _spec.From = mq.sql + if unique := mq.ctx.Unique; unique != nil { _spec.Unique = *unique + } else if mq.path != nil { + _spec.Unique = true } - if fields := mq.fields; len(fields) > 0 { + if fields := mq.ctx.Fields; len(fields) > 0 { _spec.Node.Columns = make([]string, 0, len(fields)) _spec.Node.Columns = append(_spec.Node.Columns, mailbox.FieldID) for i := range fields { @@ -691,10 +693,10 @@ func (mq *MailboxQuery) querySpec() *sqlgraph.QuerySpec { } } } - if limit := mq.limit; limit != nil { + if limit := mq.ctx.Limit; limit != nil { _spec.Limit = *limit } - if offset := mq.offset; offset != nil { + if offset := mq.ctx.Offset; offset != nil { _spec.Offset = *offset } if ps := mq.order; len(ps) > 0 { @@ -710,7 +712,7 @@ func (mq *MailboxQuery) querySpec() *sqlgraph.QuerySpec { func (mq *MailboxQuery) sqlQuery(ctx context.Context) *sql.Selector { builder := sql.Dialect(mq.driver.Dialect()) t1 := builder.Table(mailbox.Table) - columns := mq.fields + columns := mq.ctx.Fields if len(columns) == 0 { columns = mailbox.Columns } @@ -719,7 +721,7 @@ func (mq *MailboxQuery) sqlQuery(ctx context.Context) *sql.Selector { selector = mq.sql selector.Select(selector.Columns(columns...)...) } - if mq.unique != nil && *mq.unique { + if mq.ctx.Unique != nil && *mq.ctx.Unique { selector.Distinct() } for _, p := range mq.predicates { @@ -728,12 +730,12 @@ func (mq *MailboxQuery) sqlQuery(ctx context.Context) *sql.Selector { for _, p := range mq.order { p(selector) } - if offset := mq.offset; offset != nil { + if offset := mq.ctx.Offset; offset != nil { // limit is mandatory for offset clause. We start // with default value, and override it below if needed. selector.Offset(*offset).Limit(math.MaxInt32) } - if limit := mq.limit; limit != nil { + if limit := mq.ctx.Limit; limit != nil { selector.Limit(*limit) } return selector @@ -741,13 +743,8 @@ func (mq *MailboxQuery) sqlQuery(ctx context.Context) *sql.Selector { // MailboxGroupBy is the group-by builder for Mailbox entities. type MailboxGroupBy struct { - config selector - fields []string - fns []AggregateFunc - // intermediate query (i.e. traversal path). - sql *sql.Selector - path func(context.Context) (*sql.Selector, error) + build *MailboxQuery } // Aggregate adds the given aggregation functions to the group-by query. @@ -756,74 +753,77 @@ func (mgb *MailboxGroupBy) Aggregate(fns ...AggregateFunc) *MailboxGroupBy { return mgb } -// Scan applies the group-by query and scans the result into the given value. -func (mgb *MailboxGroupBy) Scan(ctx context.Context, v interface{}) error { - query, err := mgb.path(ctx) - if err != nil { +// Scan applies the selector query and scans the result into the given value. +func (mgb *MailboxGroupBy) Scan(ctx context.Context, v any) error { + ctx = setContextOp(ctx, mgb.build.ctx, "GroupBy") + if err := mgb.build.prepareQuery(ctx); err != nil { return err } - mgb.sql = query - return mgb.sqlScan(ctx, v) + return scanWithInterceptors[*MailboxQuery, *MailboxGroupBy](ctx, mgb.build, mgb, mgb.build.inters, v) } -func (mgb *MailboxGroupBy) sqlScan(ctx context.Context, v interface{}) error { - for _, f := range mgb.fields { - if !mailbox.ValidColumn(f) { - return &ValidationError{Name: f, err: fmt.Errorf("invalid field %q for group-by", f)} +func (mgb *MailboxGroupBy) sqlScan(ctx context.Context, root *MailboxQuery, v any) error { + selector := root.sqlQuery(ctx).Select() + aggregation := make([]string, 0, len(mgb.fns)) + for _, fn := range mgb.fns { + aggregation = append(aggregation, fn(selector)) + } + if len(selector.SelectedColumns()) == 0 { + columns := make([]string, 0, len(*mgb.flds)+len(mgb.fns)) + for _, f := range *mgb.flds { + columns = append(columns, selector.C(f)) } + columns = append(columns, aggregation...) + selector.Select(columns...) } - selector := mgb.sqlQuery() + selector.GroupBy(selector.Columns(*mgb.flds...)...) if err := selector.Err(); err != nil { return err } rows := &sql.Rows{} query, args := selector.Query() - if err := mgb.driver.Query(ctx, query, args, rows); err != nil { + if err := mgb.build.driver.Query(ctx, query, args, rows); err != nil { return err } defer rows.Close() return sql.ScanSlice(rows, v) } -func (mgb *MailboxGroupBy) sqlQuery() *sql.Selector { - selector := mgb.sql.Select() - aggregation := make([]string, 0, len(mgb.fns)) - for _, fn := range mgb.fns { - aggregation = append(aggregation, fn(selector)) - } - // If no columns were selected in a custom aggregation function, the default - // selection is the fields used for "group-by", and the aggregation functions. - if len(selector.SelectedColumns()) == 0 { - columns := make([]string, 0, len(mgb.fields)+len(mgb.fns)) - for _, f := range mgb.fields { - columns = append(columns, selector.C(f)) - } - columns = append(columns, aggregation...) - selector.Select(columns...) - } - return selector.GroupBy(selector.Columns(mgb.fields...)...) -} - // MailboxSelect is the builder for selecting fields of Mailbox entities. type MailboxSelect struct { *MailboxQuery selector - // intermediate query (i.e. traversal path). - sql *sql.Selector +} + +// Aggregate adds the given aggregation functions to the selector query. +func (ms *MailboxSelect) Aggregate(fns ...AggregateFunc) *MailboxSelect { + ms.fns = append(ms.fns, fns...) + return ms } // Scan applies the selector query and scans the result into the given value. -func (ms *MailboxSelect) Scan(ctx context.Context, v interface{}) error { +func (ms *MailboxSelect) Scan(ctx context.Context, v any) error { + ctx = setContextOp(ctx, ms.ctx, "Select") if err := ms.prepareQuery(ctx); err != nil { return err } - ms.sql = ms.MailboxQuery.sqlQuery(ctx) - return ms.sqlScan(ctx, v) + return scanWithInterceptors[*MailboxQuery, *MailboxSelect](ctx, ms.MailboxQuery, ms, ms.inters, v) } -func (ms *MailboxSelect) sqlScan(ctx context.Context, v interface{}) error { +func (ms *MailboxSelect) sqlScan(ctx context.Context, root *MailboxQuery, v any) error { + selector := root.sqlQuery(ctx) + aggregation := make([]string, 0, len(ms.fns)) + for _, fn := range ms.fns { + aggregation = append(aggregation, fn(selector)) + } + switch n := len(*ms.selector.flds); { + case n == 0 && len(aggregation) > 0: + selector.Select(aggregation...) + case n != 0 && len(aggregation) > 0: + selector.AppendSelect(aggregation...) + } rows := &sql.Rows{} - query, args := ms.sql.Query() + query, args := selector.Query() if err := ms.driver.Query(ctx, query, args, rows); err != nil { return err } diff --git a/internal/db/ent/mailbox_update.go b/internal/db_impl/ent_db/internal/mailbox_update.go similarity index 85% rename from internal/db/ent/mailbox_update.go rename to internal/db_impl/ent_db/internal/mailbox_update.go index 46ceca2a..150a494c 100644 --- a/internal/db/ent/mailbox_update.go +++ b/internal/db_impl/ent_db/internal/mailbox_update.go @@ -1,6 +1,6 @@ // Code generated by ent, DO NOT EDIT. -package ent +package internal import ( "context" @@ -11,12 +11,12 @@ import ( "entgo.io/ent/dialect/sql/sqlgraph" "entgo.io/ent/schema/field" "github.com/ProtonMail/gluon/imap" - "github.com/ProtonMail/gluon/internal/db/ent/mailbox" - "github.com/ProtonMail/gluon/internal/db/ent/mailboxattr" - "github.com/ProtonMail/gluon/internal/db/ent/mailboxflag" - "github.com/ProtonMail/gluon/internal/db/ent/mailboxpermflag" - "github.com/ProtonMail/gluon/internal/db/ent/predicate" - "github.com/ProtonMail/gluon/internal/db/ent/uid" + "github.com/ProtonMail/gluon/internal/db_impl/ent_db/internal/mailbox" + "github.com/ProtonMail/gluon/internal/db_impl/ent_db/internal/mailboxattr" + "github.com/ProtonMail/gluon/internal/db_impl/ent_db/internal/mailboxflag" + "github.com/ProtonMail/gluon/internal/db_impl/ent_db/internal/mailboxpermflag" + "github.com/ProtonMail/gluon/internal/db_impl/ent_db/internal/predicate" + "github.com/ProtonMail/gluon/internal/db_impl/ent_db/internal/uid" ) // MailboxUpdate is the builder for updating Mailbox entities. @@ -265,34 +265,7 @@ func (mu *MailboxUpdate) RemoveAttributes(m ...*MailboxAttr) *MailboxUpdate { // Save executes the query and returns the number of nodes affected by the update operation. func (mu *MailboxUpdate) Save(ctx context.Context) (int, error) { - var ( - err error - affected int - ) - if len(mu.hooks) == 0 { - affected, err = mu.sqlSave(ctx) - } else { - var mut Mutator = MutateFunc(func(ctx context.Context, m Mutation) (Value, error) { - mutation, ok := m.(*MailboxMutation) - if !ok { - return nil, fmt.Errorf("unexpected mutation type %T", m) - } - mu.mutation = mutation - affected, err = mu.sqlSave(ctx) - mutation.done = true - return affected, err - }) - for i := len(mu.hooks) - 1; i >= 0; i-- { - if mu.hooks[i] == nil { - return 0, fmt.Errorf("ent: uninitialized hook (forgotten import ent/runtime?)") - } - mut = mu.hooks[i](mut) - } - if _, err := mut.Mutate(ctx, mu.mutation); err != nil { - return 0, err - } - } - return affected, err + return withHooks[int, MailboxMutation](ctx, mu.sqlSave, mu.mutation, mu.hooks) } // SaveX is like Save, but panics if an error occurs. @@ -318,16 +291,7 @@ func (mu *MailboxUpdate) ExecX(ctx context.Context) { } func (mu *MailboxUpdate) sqlSave(ctx context.Context) (n int, err error) { - _spec := &sqlgraph.UpdateSpec{ - Node: &sqlgraph.NodeSpec{ - Table: mailbox.Table, - Columns: mailbox.Columns, - ID: &sqlgraph.FieldSpec{ - Type: field.TypeUint64, - Column: mailbox.FieldID, - }, - }, - } + _spec := sqlgraph.NewUpdateSpec(mailbox.Table, mailbox.Columns, sqlgraph.NewFieldSpec(mailbox.FieldID, field.TypeUint64)) if ps := mu.mutation.predicates; len(ps) > 0 { _spec.Predicate = func(selector *sql.Selector) { for i := range ps { @@ -336,59 +300,28 @@ func (mu *MailboxUpdate) sqlSave(ctx context.Context) (n int, err error) { } } if value, ok := mu.mutation.RemoteID(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Value: value, - Column: mailbox.FieldRemoteID, - }) + _spec.SetField(mailbox.FieldRemoteID, field.TypeString, value) } if mu.mutation.RemoteIDCleared() { - _spec.Fields.Clear = append(_spec.Fields.Clear, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Column: mailbox.FieldRemoteID, - }) + _spec.ClearField(mailbox.FieldRemoteID, field.TypeString) } if value, ok := mu.mutation.Name(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Value: value, - Column: mailbox.FieldName, - }) + _spec.SetField(mailbox.FieldName, field.TypeString, value) } if value, ok := mu.mutation.UIDNext(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeUint32, - Value: value, - Column: mailbox.FieldUIDNext, - }) + _spec.SetField(mailbox.FieldUIDNext, field.TypeUint32, value) } if value, ok := mu.mutation.AddedUIDNext(); ok { - _spec.Fields.Add = append(_spec.Fields.Add, &sqlgraph.FieldSpec{ - Type: field.TypeUint32, - Value: value, - Column: mailbox.FieldUIDNext, - }) + _spec.AddField(mailbox.FieldUIDNext, field.TypeUint32, value) } if value, ok := mu.mutation.UIDValidity(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeUint32, - Value: value, - Column: mailbox.FieldUIDValidity, - }) + _spec.SetField(mailbox.FieldUIDValidity, field.TypeUint32, value) } if value, ok := mu.mutation.AddedUIDValidity(); ok { - _spec.Fields.Add = append(_spec.Fields.Add, &sqlgraph.FieldSpec{ - Type: field.TypeUint32, - Value: value, - Column: mailbox.FieldUIDValidity, - }) + _spec.AddField(mailbox.FieldUIDValidity, field.TypeUint32, value) } if value, ok := mu.mutation.Subscribed(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeBool, - Value: value, - Column: mailbox.FieldSubscribed, - }) + _spec.SetField(mailbox.FieldSubscribed, field.TypeBool, value) } if mu.mutation.UIDsCleared() { edge := &sqlgraph.EdgeSpec{ @@ -614,6 +547,7 @@ func (mu *MailboxUpdate) sqlSave(ctx context.Context) (n int, err error) { } return 0, err } + mu.mutation.done = true return n, nil } @@ -856,6 +790,12 @@ func (muo *MailboxUpdateOne) RemoveAttributes(m ...*MailboxAttr) *MailboxUpdateO return muo.RemoveAttributeIDs(ids...) } +// Where appends a list predicates to the MailboxUpdate builder. +func (muo *MailboxUpdateOne) Where(ps ...predicate.Mailbox) *MailboxUpdateOne { + muo.mutation.Where(ps...) + return muo +} + // Select allows selecting one or more fields (columns) of the returned entity. // The default is selecting all fields defined in the entity schema. func (muo *MailboxUpdateOne) Select(field string, fields ...string) *MailboxUpdateOne { @@ -865,40 +805,7 @@ func (muo *MailboxUpdateOne) Select(field string, fields ...string) *MailboxUpda // Save executes the query and returns the updated Mailbox entity. func (muo *MailboxUpdateOne) Save(ctx context.Context) (*Mailbox, error) { - var ( - err error - node *Mailbox - ) - if len(muo.hooks) == 0 { - node, err = muo.sqlSave(ctx) - } else { - var mut Mutator = MutateFunc(func(ctx context.Context, m Mutation) (Value, error) { - mutation, ok := m.(*MailboxMutation) - if !ok { - return nil, fmt.Errorf("unexpected mutation type %T", m) - } - muo.mutation = mutation - node, err = muo.sqlSave(ctx) - mutation.done = true - return node, err - }) - for i := len(muo.hooks) - 1; i >= 0; i-- { - if muo.hooks[i] == nil { - return nil, fmt.Errorf("ent: uninitialized hook (forgotten import ent/runtime?)") - } - mut = muo.hooks[i](mut) - } - v, err := mut.Mutate(ctx, muo.mutation) - if err != nil { - return nil, err - } - nv, ok := v.(*Mailbox) - if !ok { - return nil, fmt.Errorf("unexpected node type %T returned from MailboxMutation", v) - } - node = nv - } - return node, err + return withHooks[*Mailbox, MailboxMutation](ctx, muo.sqlSave, muo.mutation, muo.hooks) } // SaveX is like Save, but panics if an error occurs. @@ -924,19 +831,10 @@ func (muo *MailboxUpdateOne) ExecX(ctx context.Context) { } func (muo *MailboxUpdateOne) sqlSave(ctx context.Context) (_node *Mailbox, err error) { - _spec := &sqlgraph.UpdateSpec{ - Node: &sqlgraph.NodeSpec{ - Table: mailbox.Table, - Columns: mailbox.Columns, - ID: &sqlgraph.FieldSpec{ - Type: field.TypeUint64, - Column: mailbox.FieldID, - }, - }, - } + _spec := sqlgraph.NewUpdateSpec(mailbox.Table, mailbox.Columns, sqlgraph.NewFieldSpec(mailbox.FieldID, field.TypeUint64)) id, ok := muo.mutation.ID() if !ok { - return nil, &ValidationError{Name: "id", err: errors.New(`ent: missing "Mailbox.id" for update`)} + return nil, &ValidationError{Name: "id", err: errors.New(`internal: missing "Mailbox.id" for update`)} } _spec.Node.ID.Value = id if fields := muo.fields; len(fields) > 0 { @@ -944,7 +842,7 @@ func (muo *MailboxUpdateOne) sqlSave(ctx context.Context) (_node *Mailbox, err e _spec.Node.Columns = append(_spec.Node.Columns, mailbox.FieldID) for _, f := range fields { if !mailbox.ValidColumn(f) { - return nil, &ValidationError{Name: f, err: fmt.Errorf("ent: invalid field %q for query", f)} + return nil, &ValidationError{Name: f, err: fmt.Errorf("internal: invalid field %q for query", f)} } if f != mailbox.FieldID { _spec.Node.Columns = append(_spec.Node.Columns, f) @@ -959,59 +857,28 @@ func (muo *MailboxUpdateOne) sqlSave(ctx context.Context) (_node *Mailbox, err e } } if value, ok := muo.mutation.RemoteID(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Value: value, - Column: mailbox.FieldRemoteID, - }) + _spec.SetField(mailbox.FieldRemoteID, field.TypeString, value) } if muo.mutation.RemoteIDCleared() { - _spec.Fields.Clear = append(_spec.Fields.Clear, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Column: mailbox.FieldRemoteID, - }) + _spec.ClearField(mailbox.FieldRemoteID, field.TypeString) } if value, ok := muo.mutation.Name(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Value: value, - Column: mailbox.FieldName, - }) + _spec.SetField(mailbox.FieldName, field.TypeString, value) } if value, ok := muo.mutation.UIDNext(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeUint32, - Value: value, - Column: mailbox.FieldUIDNext, - }) + _spec.SetField(mailbox.FieldUIDNext, field.TypeUint32, value) } if value, ok := muo.mutation.AddedUIDNext(); ok { - _spec.Fields.Add = append(_spec.Fields.Add, &sqlgraph.FieldSpec{ - Type: field.TypeUint32, - Value: value, - Column: mailbox.FieldUIDNext, - }) + _spec.AddField(mailbox.FieldUIDNext, field.TypeUint32, value) } if value, ok := muo.mutation.UIDValidity(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeUint32, - Value: value, - Column: mailbox.FieldUIDValidity, - }) + _spec.SetField(mailbox.FieldUIDValidity, field.TypeUint32, value) } if value, ok := muo.mutation.AddedUIDValidity(); ok { - _spec.Fields.Add = append(_spec.Fields.Add, &sqlgraph.FieldSpec{ - Type: field.TypeUint32, - Value: value, - Column: mailbox.FieldUIDValidity, - }) + _spec.AddField(mailbox.FieldUIDValidity, field.TypeUint32, value) } if value, ok := muo.mutation.Subscribed(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeBool, - Value: value, - Column: mailbox.FieldSubscribed, - }) + _spec.SetField(mailbox.FieldSubscribed, field.TypeBool, value) } if muo.mutation.UIDsCleared() { edge := &sqlgraph.EdgeSpec{ @@ -1240,5 +1107,6 @@ func (muo *MailboxUpdateOne) sqlSave(ctx context.Context) (_node *Mailbox, err e } return nil, err } + muo.mutation.done = true return _node, nil } diff --git a/internal/db/ent/mailboxattr.go b/internal/db_impl/ent_db/internal/mailboxattr.go similarity index 85% rename from internal/db/ent/mailboxattr.go rename to internal/db_impl/ent_db/internal/mailboxattr.go index fa548244..aef2a814 100644 --- a/internal/db/ent/mailboxattr.go +++ b/internal/db_impl/ent_db/internal/mailboxattr.go @@ -1,6 +1,6 @@ // Code generated by ent, DO NOT EDIT. -package ent +package internal import ( "fmt" @@ -8,7 +8,7 @@ import ( "entgo.io/ent/dialect/sql" "github.com/ProtonMail/gluon/imap" - "github.com/ProtonMail/gluon/internal/db/ent/mailboxattr" + "github.com/ProtonMail/gluon/internal/db_impl/ent_db/internal/mailboxattr" ) // MailboxAttr is the model entity for the MailboxAttr schema. @@ -22,8 +22,8 @@ type MailboxAttr struct { } // scanValues returns the types for scanning values from sql.Rows. -func (*MailboxAttr) scanValues(columns []string) ([]interface{}, error) { - values := make([]interface{}, len(columns)) +func (*MailboxAttr) scanValues(columns []string) ([]any, error) { + values := make([]any, len(columns)) for i := range columns { switch columns[i] { case mailboxattr.FieldID: @@ -41,7 +41,7 @@ func (*MailboxAttr) scanValues(columns []string) ([]interface{}, error) { // assignValues assigns the values that were returned from sql.Rows (after scanning) // to the MailboxAttr fields. -func (ma *MailboxAttr) assignValues(columns []string, values []interface{}) error { +func (ma *MailboxAttr) assignValues(columns []string, values []any) error { if m, n := len(values), len(columns); m < n { return fmt.Errorf("mismatch number of scan values: %d != %d", m, n) } @@ -75,7 +75,7 @@ func (ma *MailboxAttr) assignValues(columns []string, values []interface{}) erro // Note that you need to call MailboxAttr.Unwrap() before calling this method if this MailboxAttr // was returned from a transaction, and the transaction was committed or rolled back. func (ma *MailboxAttr) Update() *MailboxAttrUpdateOne { - return (&MailboxAttrClient{config: ma.config}).UpdateOne(ma) + return NewMailboxAttrClient(ma.config).UpdateOne(ma) } // Unwrap unwraps the MailboxAttr entity that was returned from a transaction after it was closed, @@ -83,7 +83,7 @@ func (ma *MailboxAttr) Update() *MailboxAttrUpdateOne { func (ma *MailboxAttr) Unwrap() *MailboxAttr { _tx, ok := ma.config.driver.(*txDriver) if !ok { - panic("ent: MailboxAttr is not a transactional entity") + panic("internal: MailboxAttr is not a transactional entity") } ma.config.driver = _tx.drv return ma @@ -102,9 +102,3 @@ func (ma *MailboxAttr) String() string { // MailboxAttrs is a parsable slice of MailboxAttr. type MailboxAttrs []*MailboxAttr - -func (ma MailboxAttrs) config(cfg config) { - for _i := range ma { - ma[_i].config = cfg - } -} diff --git a/internal/db/ent/mailboxattr/mailboxattr.go b/internal/db_impl/ent_db/internal/mailboxattr/mailboxattr.go similarity index 100% rename from internal/db/ent/mailboxattr/mailboxattr.go rename to internal/db_impl/ent_db/internal/mailboxattr/mailboxattr.go diff --git a/internal/db/ent/mailboxattr/where.go b/internal/db_impl/ent_db/internal/mailboxattr/where.go similarity index 56% rename from internal/db/ent/mailboxattr/where.go rename to internal/db_impl/ent_db/internal/mailboxattr/where.go index 84a16fe4..6ea922dd 100644 --- a/internal/db/ent/mailboxattr/where.go +++ b/internal/db_impl/ent_db/internal/mailboxattr/where.go @@ -4,184 +4,122 @@ package mailboxattr import ( "entgo.io/ent/dialect/sql" - "github.com/ProtonMail/gluon/internal/db/ent/predicate" + "github.com/ProtonMail/gluon/internal/db_impl/ent_db/internal/predicate" ) // ID filters vertices based on their ID field. func ID(id int) predicate.MailboxAttr { - return predicate.MailboxAttr(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldID), id)) - }) + return predicate.MailboxAttr(sql.FieldEQ(FieldID, id)) } // IDEQ applies the EQ predicate on the ID field. func IDEQ(id int) predicate.MailboxAttr { - return predicate.MailboxAttr(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldID), id)) - }) + return predicate.MailboxAttr(sql.FieldEQ(FieldID, id)) } // IDNEQ applies the NEQ predicate on the ID field. func IDNEQ(id int) predicate.MailboxAttr { - return predicate.MailboxAttr(func(s *sql.Selector) { - s.Where(sql.NEQ(s.C(FieldID), id)) - }) + return predicate.MailboxAttr(sql.FieldNEQ(FieldID, id)) } // IDIn applies the In predicate on the ID field. func IDIn(ids ...int) predicate.MailboxAttr { - return predicate.MailboxAttr(func(s *sql.Selector) { - v := make([]interface{}, len(ids)) - for i := range v { - v[i] = ids[i] - } - s.Where(sql.In(s.C(FieldID), v...)) - }) + return predicate.MailboxAttr(sql.FieldIn(FieldID, ids...)) } // IDNotIn applies the NotIn predicate on the ID field. func IDNotIn(ids ...int) predicate.MailboxAttr { - return predicate.MailboxAttr(func(s *sql.Selector) { - v := make([]interface{}, len(ids)) - for i := range v { - v[i] = ids[i] - } - s.Where(sql.NotIn(s.C(FieldID), v...)) - }) + return predicate.MailboxAttr(sql.FieldNotIn(FieldID, ids...)) } // IDGT applies the GT predicate on the ID field. func IDGT(id int) predicate.MailboxAttr { - return predicate.MailboxAttr(func(s *sql.Selector) { - s.Where(sql.GT(s.C(FieldID), id)) - }) + return predicate.MailboxAttr(sql.FieldGT(FieldID, id)) } // IDGTE applies the GTE predicate on the ID field. func IDGTE(id int) predicate.MailboxAttr { - return predicate.MailboxAttr(func(s *sql.Selector) { - s.Where(sql.GTE(s.C(FieldID), id)) - }) + return predicate.MailboxAttr(sql.FieldGTE(FieldID, id)) } // IDLT applies the LT predicate on the ID field. func IDLT(id int) predicate.MailboxAttr { - return predicate.MailboxAttr(func(s *sql.Selector) { - s.Where(sql.LT(s.C(FieldID), id)) - }) + return predicate.MailboxAttr(sql.FieldLT(FieldID, id)) } // IDLTE applies the LTE predicate on the ID field. func IDLTE(id int) predicate.MailboxAttr { - return predicate.MailboxAttr(func(s *sql.Selector) { - s.Where(sql.LTE(s.C(FieldID), id)) - }) + return predicate.MailboxAttr(sql.FieldLTE(FieldID, id)) } // Value applies equality check predicate on the "Value" field. It's identical to ValueEQ. func Value(v string) predicate.MailboxAttr { - return predicate.MailboxAttr(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldValue), v)) - }) + return predicate.MailboxAttr(sql.FieldEQ(FieldValue, v)) } // ValueEQ applies the EQ predicate on the "Value" field. func ValueEQ(v string) predicate.MailboxAttr { - return predicate.MailboxAttr(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldValue), v)) - }) + return predicate.MailboxAttr(sql.FieldEQ(FieldValue, v)) } // ValueNEQ applies the NEQ predicate on the "Value" field. func ValueNEQ(v string) predicate.MailboxAttr { - return predicate.MailboxAttr(func(s *sql.Selector) { - s.Where(sql.NEQ(s.C(FieldValue), v)) - }) + return predicate.MailboxAttr(sql.FieldNEQ(FieldValue, v)) } // ValueIn applies the In predicate on the "Value" field. func ValueIn(vs ...string) predicate.MailboxAttr { - v := make([]interface{}, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.MailboxAttr(func(s *sql.Selector) { - s.Where(sql.In(s.C(FieldValue), v...)) - }) + return predicate.MailboxAttr(sql.FieldIn(FieldValue, vs...)) } // ValueNotIn applies the NotIn predicate on the "Value" field. func ValueNotIn(vs ...string) predicate.MailboxAttr { - v := make([]interface{}, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.MailboxAttr(func(s *sql.Selector) { - s.Where(sql.NotIn(s.C(FieldValue), v...)) - }) + return predicate.MailboxAttr(sql.FieldNotIn(FieldValue, vs...)) } // ValueGT applies the GT predicate on the "Value" field. func ValueGT(v string) predicate.MailboxAttr { - return predicate.MailboxAttr(func(s *sql.Selector) { - s.Where(sql.GT(s.C(FieldValue), v)) - }) + return predicate.MailboxAttr(sql.FieldGT(FieldValue, v)) } // ValueGTE applies the GTE predicate on the "Value" field. func ValueGTE(v string) predicate.MailboxAttr { - return predicate.MailboxAttr(func(s *sql.Selector) { - s.Where(sql.GTE(s.C(FieldValue), v)) - }) + return predicate.MailboxAttr(sql.FieldGTE(FieldValue, v)) } // ValueLT applies the LT predicate on the "Value" field. func ValueLT(v string) predicate.MailboxAttr { - return predicate.MailboxAttr(func(s *sql.Selector) { - s.Where(sql.LT(s.C(FieldValue), v)) - }) + return predicate.MailboxAttr(sql.FieldLT(FieldValue, v)) } // ValueLTE applies the LTE predicate on the "Value" field. func ValueLTE(v string) predicate.MailboxAttr { - return predicate.MailboxAttr(func(s *sql.Selector) { - s.Where(sql.LTE(s.C(FieldValue), v)) - }) + return predicate.MailboxAttr(sql.FieldLTE(FieldValue, v)) } // ValueContains applies the Contains predicate on the "Value" field. func ValueContains(v string) predicate.MailboxAttr { - return predicate.MailboxAttr(func(s *sql.Selector) { - s.Where(sql.Contains(s.C(FieldValue), v)) - }) + return predicate.MailboxAttr(sql.FieldContains(FieldValue, v)) } // ValueHasPrefix applies the HasPrefix predicate on the "Value" field. func ValueHasPrefix(v string) predicate.MailboxAttr { - return predicate.MailboxAttr(func(s *sql.Selector) { - s.Where(sql.HasPrefix(s.C(FieldValue), v)) - }) + return predicate.MailboxAttr(sql.FieldHasPrefix(FieldValue, v)) } // ValueHasSuffix applies the HasSuffix predicate on the "Value" field. func ValueHasSuffix(v string) predicate.MailboxAttr { - return predicate.MailboxAttr(func(s *sql.Selector) { - s.Where(sql.HasSuffix(s.C(FieldValue), v)) - }) + return predicate.MailboxAttr(sql.FieldHasSuffix(FieldValue, v)) } // ValueEqualFold applies the EqualFold predicate on the "Value" field. func ValueEqualFold(v string) predicate.MailboxAttr { - return predicate.MailboxAttr(func(s *sql.Selector) { - s.Where(sql.EqualFold(s.C(FieldValue), v)) - }) + return predicate.MailboxAttr(sql.FieldEqualFold(FieldValue, v)) } // ValueContainsFold applies the ContainsFold predicate on the "Value" field. func ValueContainsFold(v string) predicate.MailboxAttr { - return predicate.MailboxAttr(func(s *sql.Selector) { - s.Where(sql.ContainsFold(s.C(FieldValue), v)) - }) + return predicate.MailboxAttr(sql.FieldContainsFold(FieldValue, v)) } // And groups predicates with the AND operator between them. diff --git a/internal/db/ent/mailboxattr_create.go b/internal/db_impl/ent_db/internal/mailboxattr_create.go similarity index 74% rename from internal/db/ent/mailboxattr_create.go rename to internal/db_impl/ent_db/internal/mailboxattr_create.go index 4908d550..0cb6b4f0 100644 --- a/internal/db/ent/mailboxattr_create.go +++ b/internal/db_impl/ent_db/internal/mailboxattr_create.go @@ -1,6 +1,6 @@ // Code generated by ent, DO NOT EDIT. -package ent +package internal import ( "context" @@ -9,7 +9,7 @@ import ( "entgo.io/ent/dialect/sql/sqlgraph" "entgo.io/ent/schema/field" - "github.com/ProtonMail/gluon/internal/db/ent/mailboxattr" + "github.com/ProtonMail/gluon/internal/db_impl/ent_db/internal/mailboxattr" ) // MailboxAttrCreate is the builder for creating a MailboxAttr entity. @@ -32,49 +32,7 @@ func (mac *MailboxAttrCreate) Mutation() *MailboxAttrMutation { // Save creates the MailboxAttr in the database. func (mac *MailboxAttrCreate) Save(ctx context.Context) (*MailboxAttr, error) { - var ( - err error - node *MailboxAttr - ) - if len(mac.hooks) == 0 { - if err = mac.check(); err != nil { - return nil, err - } - node, err = mac.sqlSave(ctx) - } else { - var mut Mutator = MutateFunc(func(ctx context.Context, m Mutation) (Value, error) { - mutation, ok := m.(*MailboxAttrMutation) - if !ok { - return nil, fmt.Errorf("unexpected mutation type %T", m) - } - if err = mac.check(); err != nil { - return nil, err - } - mac.mutation = mutation - if node, err = mac.sqlSave(ctx); err != nil { - return nil, err - } - mutation.id = &node.ID - mutation.done = true - return node, err - }) - for i := len(mac.hooks) - 1; i >= 0; i-- { - if mac.hooks[i] == nil { - return nil, fmt.Errorf("ent: uninitialized hook (forgotten import ent/runtime?)") - } - mut = mac.hooks[i](mut) - } - v, err := mut.Mutate(ctx, mac.mutation) - if err != nil { - return nil, err - } - nv, ok := v.(*MailboxAttr) - if !ok { - return nil, fmt.Errorf("unexpected node type %T returned from MailboxAttrMutation", v) - } - node = nv - } - return node, err + return withHooks[*MailboxAttr, MailboxAttrMutation](ctx, mac.sqlSave, mac.mutation, mac.hooks) } // SaveX calls Save and panics if Save returns an error. @@ -102,12 +60,15 @@ func (mac *MailboxAttrCreate) ExecX(ctx context.Context) { // check runs all checks and user-defined validators on the builder. func (mac *MailboxAttrCreate) check() error { if _, ok := mac.mutation.Value(); !ok { - return &ValidationError{Name: "Value", err: errors.New(`ent: missing required field "MailboxAttr.Value"`)} + return &ValidationError{Name: "Value", err: errors.New(`internal: missing required field "MailboxAttr.Value"`)} } return nil } func (mac *MailboxAttrCreate) sqlSave(ctx context.Context) (*MailboxAttr, error) { + if err := mac.check(); err != nil { + return nil, err + } _node, _spec := mac.createSpec() if err := sqlgraph.CreateNode(ctx, mac.driver, _spec); err != nil { if sqlgraph.IsConstraintError(err) { @@ -117,26 +78,18 @@ func (mac *MailboxAttrCreate) sqlSave(ctx context.Context) (*MailboxAttr, error) } id := _spec.ID.Value.(int64) _node.ID = int(id) + mac.mutation.id = &_node.ID + mac.mutation.done = true return _node, nil } func (mac *MailboxAttrCreate) createSpec() (*MailboxAttr, *sqlgraph.CreateSpec) { var ( _node = &MailboxAttr{config: mac.config} - _spec = &sqlgraph.CreateSpec{ - Table: mailboxattr.Table, - ID: &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Column: mailboxattr.FieldID, - }, - } + _spec = sqlgraph.NewCreateSpec(mailboxattr.Table, sqlgraph.NewFieldSpec(mailboxattr.FieldID, field.TypeInt)) ) if value, ok := mac.mutation.Value(); ok { - _spec.Fields = append(_spec.Fields, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Value: value, - Column: mailboxattr.FieldValue, - }) + _spec.SetField(mailboxattr.FieldValue, field.TypeString, value) _node.Value = value } return _node, _spec diff --git a/internal/db/ent/mailboxattr_delete.go b/internal/db_impl/ent_db/internal/mailboxattr_delete.go similarity index 62% rename from internal/db/ent/mailboxattr_delete.go rename to internal/db_impl/ent_db/internal/mailboxattr_delete.go index a13092bd..fbaa3e78 100644 --- a/internal/db/ent/mailboxattr_delete.go +++ b/internal/db_impl/ent_db/internal/mailboxattr_delete.go @@ -1,16 +1,15 @@ // Code generated by ent, DO NOT EDIT. -package ent +package internal import ( "context" - "fmt" "entgo.io/ent/dialect/sql" "entgo.io/ent/dialect/sql/sqlgraph" "entgo.io/ent/schema/field" - "github.com/ProtonMail/gluon/internal/db/ent/mailboxattr" - "github.com/ProtonMail/gluon/internal/db/ent/predicate" + "github.com/ProtonMail/gluon/internal/db_impl/ent_db/internal/mailboxattr" + "github.com/ProtonMail/gluon/internal/db_impl/ent_db/internal/predicate" ) // MailboxAttrDelete is the builder for deleting a MailboxAttr entity. @@ -28,34 +27,7 @@ func (mad *MailboxAttrDelete) Where(ps ...predicate.MailboxAttr) *MailboxAttrDel // Exec executes the deletion query and returns how many vertices were deleted. func (mad *MailboxAttrDelete) Exec(ctx context.Context) (int, error) { - var ( - err error - affected int - ) - if len(mad.hooks) == 0 { - affected, err = mad.sqlExec(ctx) - } else { - var mut Mutator = MutateFunc(func(ctx context.Context, m Mutation) (Value, error) { - mutation, ok := m.(*MailboxAttrMutation) - if !ok { - return nil, fmt.Errorf("unexpected mutation type %T", m) - } - mad.mutation = mutation - affected, err = mad.sqlExec(ctx) - mutation.done = true - return affected, err - }) - for i := len(mad.hooks) - 1; i >= 0; i-- { - if mad.hooks[i] == nil { - return 0, fmt.Errorf("ent: uninitialized hook (forgotten import ent/runtime?)") - } - mut = mad.hooks[i](mut) - } - if _, err := mut.Mutate(ctx, mad.mutation); err != nil { - return 0, err - } - } - return affected, err + return withHooks[int, MailboxAttrMutation](ctx, mad.sqlExec, mad.mutation, mad.hooks) } // ExecX is like Exec, but panics if an error occurs. @@ -68,15 +40,7 @@ func (mad *MailboxAttrDelete) ExecX(ctx context.Context) int { } func (mad *MailboxAttrDelete) sqlExec(ctx context.Context) (int, error) { - _spec := &sqlgraph.DeleteSpec{ - Node: &sqlgraph.NodeSpec{ - Table: mailboxattr.Table, - ID: &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Column: mailboxattr.FieldID, - }, - }, - } + _spec := sqlgraph.NewDeleteSpec(mailboxattr.Table, sqlgraph.NewFieldSpec(mailboxattr.FieldID, field.TypeInt)) if ps := mad.mutation.predicates; len(ps) > 0 { _spec.Predicate = func(selector *sql.Selector) { for i := range ps { @@ -88,6 +52,7 @@ func (mad *MailboxAttrDelete) sqlExec(ctx context.Context) (int, error) { if err != nil && sqlgraph.IsConstraintError(err) { err = &ConstraintError{msg: err.Error(), wrap: err} } + mad.mutation.done = true return affected, err } @@ -96,6 +61,12 @@ type MailboxAttrDeleteOne struct { mad *MailboxAttrDelete } +// Where appends a list predicates to the MailboxAttrDelete builder. +func (mado *MailboxAttrDeleteOne) Where(ps ...predicate.MailboxAttr) *MailboxAttrDeleteOne { + mado.mad.mutation.Where(ps...) + return mado +} + // Exec executes the deletion query. func (mado *MailboxAttrDeleteOne) Exec(ctx context.Context) error { n, err := mado.mad.Exec(ctx) @@ -111,5 +82,7 @@ func (mado *MailboxAttrDeleteOne) Exec(ctx context.Context) error { // ExecX is like Exec, but panics if an error occurs. func (mado *MailboxAttrDeleteOne) ExecX(ctx context.Context) { - mado.mad.ExecX(ctx) + if err := mado.Exec(ctx); err != nil { + panic(err) + } } diff --git a/internal/db/ent/mailboxattr_query.go b/internal/db_impl/ent_db/internal/mailboxattr_query.go similarity index 67% rename from internal/db/ent/mailboxattr_query.go rename to internal/db_impl/ent_db/internal/mailboxattr_query.go index bb5fcf54..7974ba3a 100644 --- a/internal/db/ent/mailboxattr_query.go +++ b/internal/db_impl/ent_db/internal/mailboxattr_query.go @@ -1,6 +1,6 @@ // Code generated by ent, DO NOT EDIT. -package ent +package internal import ( "context" @@ -10,18 +10,16 @@ import ( "entgo.io/ent/dialect/sql" "entgo.io/ent/dialect/sql/sqlgraph" "entgo.io/ent/schema/field" - "github.com/ProtonMail/gluon/internal/db/ent/mailboxattr" - "github.com/ProtonMail/gluon/internal/db/ent/predicate" + "github.com/ProtonMail/gluon/internal/db_impl/ent_db/internal/mailboxattr" + "github.com/ProtonMail/gluon/internal/db_impl/ent_db/internal/predicate" ) // MailboxAttrQuery is the builder for querying MailboxAttr entities. type MailboxAttrQuery struct { config - limit *int - offset *int - unique *bool + ctx *QueryContext order []OrderFunc - fields []string + inters []Interceptor predicates []predicate.MailboxAttr withFKs bool // intermediate query (i.e. traversal path). @@ -35,26 +33,26 @@ func (maq *MailboxAttrQuery) Where(ps ...predicate.MailboxAttr) *MailboxAttrQuer return maq } -// Limit adds a limit step to the query. +// Limit the number of records to be returned by this query. func (maq *MailboxAttrQuery) Limit(limit int) *MailboxAttrQuery { - maq.limit = &limit + maq.ctx.Limit = &limit return maq } -// Offset adds an offset step to the query. +// Offset to start from. func (maq *MailboxAttrQuery) Offset(offset int) *MailboxAttrQuery { - maq.offset = &offset + maq.ctx.Offset = &offset return maq } // Unique configures the query builder to filter duplicate records on query. // By default, unique is set to true, and can be disabled using this method. func (maq *MailboxAttrQuery) Unique(unique bool) *MailboxAttrQuery { - maq.unique = &unique + maq.ctx.Unique = &unique return maq } -// Order adds an order step to the query. +// Order specifies how the records should be ordered. func (maq *MailboxAttrQuery) Order(o ...OrderFunc) *MailboxAttrQuery { maq.order = append(maq.order, o...) return maq @@ -63,7 +61,7 @@ func (maq *MailboxAttrQuery) Order(o ...OrderFunc) *MailboxAttrQuery { // First returns the first MailboxAttr entity from the query. // Returns a *NotFoundError when no MailboxAttr was found. func (maq *MailboxAttrQuery) First(ctx context.Context) (*MailboxAttr, error) { - nodes, err := maq.Limit(1).All(ctx) + nodes, err := maq.Limit(1).All(setContextOp(ctx, maq.ctx, "First")) if err != nil { return nil, err } @@ -86,7 +84,7 @@ func (maq *MailboxAttrQuery) FirstX(ctx context.Context) *MailboxAttr { // Returns a *NotFoundError when no MailboxAttr ID was found. func (maq *MailboxAttrQuery) FirstID(ctx context.Context) (id int, err error) { var ids []int - if ids, err = maq.Limit(1).IDs(ctx); err != nil { + if ids, err = maq.Limit(1).IDs(setContextOp(ctx, maq.ctx, "FirstID")); err != nil { return } if len(ids) == 0 { @@ -109,7 +107,7 @@ func (maq *MailboxAttrQuery) FirstIDX(ctx context.Context) int { // Returns a *NotSingularError when more than one MailboxAttr entity is found. // Returns a *NotFoundError when no MailboxAttr entities are found. func (maq *MailboxAttrQuery) Only(ctx context.Context) (*MailboxAttr, error) { - nodes, err := maq.Limit(2).All(ctx) + nodes, err := maq.Limit(2).All(setContextOp(ctx, maq.ctx, "Only")) if err != nil { return nil, err } @@ -137,7 +135,7 @@ func (maq *MailboxAttrQuery) OnlyX(ctx context.Context) *MailboxAttr { // Returns a *NotFoundError when no entities are found. func (maq *MailboxAttrQuery) OnlyID(ctx context.Context) (id int, err error) { var ids []int - if ids, err = maq.Limit(2).IDs(ctx); err != nil { + if ids, err = maq.Limit(2).IDs(setContextOp(ctx, maq.ctx, "OnlyID")); err != nil { return } switch len(ids) { @@ -162,10 +160,12 @@ func (maq *MailboxAttrQuery) OnlyIDX(ctx context.Context) int { // All executes the query and returns a list of MailboxAttrs. func (maq *MailboxAttrQuery) All(ctx context.Context) ([]*MailboxAttr, error) { + ctx = setContextOp(ctx, maq.ctx, "All") if err := maq.prepareQuery(ctx); err != nil { return nil, err } - return maq.sqlAll(ctx) + qr := querierAll[[]*MailboxAttr, *MailboxAttrQuery]() + return withInterceptors[[]*MailboxAttr](ctx, maq, qr, maq.inters) } // AllX is like All, but panics if an error occurs. @@ -178,9 +178,12 @@ func (maq *MailboxAttrQuery) AllX(ctx context.Context) []*MailboxAttr { } // IDs executes the query and returns a list of MailboxAttr IDs. -func (maq *MailboxAttrQuery) IDs(ctx context.Context) ([]int, error) { - var ids []int - if err := maq.Select(mailboxattr.FieldID).Scan(ctx, &ids); err != nil { +func (maq *MailboxAttrQuery) IDs(ctx context.Context) (ids []int, err error) { + if maq.ctx.Unique == nil && maq.path != nil { + maq.Unique(true) + } + ctx = setContextOp(ctx, maq.ctx, "IDs") + if err = maq.Select(mailboxattr.FieldID).Scan(ctx, &ids); err != nil { return nil, err } return ids, nil @@ -197,10 +200,11 @@ func (maq *MailboxAttrQuery) IDsX(ctx context.Context) []int { // Count returns the count of the given query. func (maq *MailboxAttrQuery) Count(ctx context.Context) (int, error) { + ctx = setContextOp(ctx, maq.ctx, "Count") if err := maq.prepareQuery(ctx); err != nil { return 0, err } - return maq.sqlCount(ctx) + return withInterceptors[int](ctx, maq, querierCount[*MailboxAttrQuery](), maq.inters) } // CountX is like Count, but panics if an error occurs. @@ -214,10 +218,15 @@ func (maq *MailboxAttrQuery) CountX(ctx context.Context) int { // Exist returns true if the query has elements in the graph. func (maq *MailboxAttrQuery) Exist(ctx context.Context) (bool, error) { - if err := maq.prepareQuery(ctx); err != nil { - return false, err + ctx = setContextOp(ctx, maq.ctx, "Exist") + switch _, err := maq.FirstID(ctx); { + case IsNotFound(err): + return false, nil + case err != nil: + return false, fmt.Errorf("internal: check existence: %w", err) + default: + return true, nil } - return maq.sqlExist(ctx) } // ExistX is like Exist, but panics if an error occurs. @@ -237,14 +246,13 @@ func (maq *MailboxAttrQuery) Clone() *MailboxAttrQuery { } return &MailboxAttrQuery{ config: maq.config, - limit: maq.limit, - offset: maq.offset, + ctx: maq.ctx.Clone(), order: append([]OrderFunc{}, maq.order...), + inters: append([]Interceptor{}, maq.inters...), predicates: append([]predicate.MailboxAttr{}, maq.predicates...), // clone intermediate query. - sql: maq.sql.Clone(), - path: maq.path, - unique: maq.unique, + sql: maq.sql.Clone(), + path: maq.path, } } @@ -260,19 +268,14 @@ func (maq *MailboxAttrQuery) Clone() *MailboxAttrQuery { // // client.MailboxAttr.Query(). // GroupBy(mailboxattr.FieldValue). -// Aggregate(ent.Count()). +// Aggregate(internal.Count()). // Scan(ctx, &v) func (maq *MailboxAttrQuery) GroupBy(field string, fields ...string) *MailboxAttrGroupBy { - grbuild := &MailboxAttrGroupBy{config: maq.config} - grbuild.fields = append([]string{field}, fields...) - grbuild.path = func(ctx context.Context) (prev *sql.Selector, err error) { - if err := maq.prepareQuery(ctx); err != nil { - return nil, err - } - return maq.sqlQuery(ctx), nil - } + maq.ctx.Fields = append([]string{field}, fields...) + grbuild := &MailboxAttrGroupBy{build: maq} + grbuild.flds = &maq.ctx.Fields grbuild.label = mailboxattr.Label - grbuild.flds, grbuild.scan = &grbuild.fields, grbuild.Scan + grbuild.scan = grbuild.Scan return grbuild } @@ -289,17 +292,32 @@ func (maq *MailboxAttrQuery) GroupBy(field string, fields ...string) *MailboxAtt // Select(mailboxattr.FieldValue). // Scan(ctx, &v) func (maq *MailboxAttrQuery) Select(fields ...string) *MailboxAttrSelect { - maq.fields = append(maq.fields, fields...) - selbuild := &MailboxAttrSelect{MailboxAttrQuery: maq} - selbuild.label = mailboxattr.Label - selbuild.flds, selbuild.scan = &maq.fields, selbuild.Scan - return selbuild + maq.ctx.Fields = append(maq.ctx.Fields, fields...) + sbuild := &MailboxAttrSelect{MailboxAttrQuery: maq} + sbuild.label = mailboxattr.Label + sbuild.flds, sbuild.scan = &maq.ctx.Fields, sbuild.Scan + return sbuild +} + +// Aggregate returns a MailboxAttrSelect configured with the given aggregations. +func (maq *MailboxAttrQuery) Aggregate(fns ...AggregateFunc) *MailboxAttrSelect { + return maq.Select().Aggregate(fns...) } func (maq *MailboxAttrQuery) prepareQuery(ctx context.Context) error { - for _, f := range maq.fields { + for _, inter := range maq.inters { + if inter == nil { + return fmt.Errorf("internal: uninitialized interceptor (forgotten import internal/runtime?)") + } + if trv, ok := inter.(Traverser); ok { + if err := trv.Traverse(ctx, maq); err != nil { + return err + } + } + } + for _, f := range maq.ctx.Fields { if !mailboxattr.ValidColumn(f) { - return &ValidationError{Name: f, err: fmt.Errorf("ent: invalid field %q for query", f)} + return &ValidationError{Name: f, err: fmt.Errorf("internal: invalid field %q for query", f)} } } if maq.path != nil { @@ -321,10 +339,10 @@ func (maq *MailboxAttrQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([] if withFKs { _spec.Node.Columns = append(_spec.Node.Columns, mailboxattr.ForeignKeys...) } - _spec.ScanValues = func(columns []string) ([]interface{}, error) { + _spec.ScanValues = func(columns []string) ([]any, error) { return (*MailboxAttr).scanValues(nil, columns) } - _spec.Assign = func(columns []string, values []interface{}) error { + _spec.Assign = func(columns []string, values []any) error { node := &MailboxAttr{config: maq.config} nodes = append(nodes, node) return node.assignValues(columns, values) @@ -343,38 +361,22 @@ func (maq *MailboxAttrQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([] func (maq *MailboxAttrQuery) sqlCount(ctx context.Context) (int, error) { _spec := maq.querySpec() - _spec.Node.Columns = maq.fields - if len(maq.fields) > 0 { - _spec.Unique = maq.unique != nil && *maq.unique + _spec.Node.Columns = maq.ctx.Fields + if len(maq.ctx.Fields) > 0 { + _spec.Unique = maq.ctx.Unique != nil && *maq.ctx.Unique } return sqlgraph.CountNodes(ctx, maq.driver, _spec) } -func (maq *MailboxAttrQuery) sqlExist(ctx context.Context) (bool, error) { - n, err := maq.sqlCount(ctx) - if err != nil { - return false, fmt.Errorf("ent: check existence: %w", err) - } - return n > 0, nil -} - func (maq *MailboxAttrQuery) querySpec() *sqlgraph.QuerySpec { - _spec := &sqlgraph.QuerySpec{ - Node: &sqlgraph.NodeSpec{ - Table: mailboxattr.Table, - Columns: mailboxattr.Columns, - ID: &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Column: mailboxattr.FieldID, - }, - }, - From: maq.sql, - Unique: true, - } - if unique := maq.unique; unique != nil { + _spec := sqlgraph.NewQuerySpec(mailboxattr.Table, mailboxattr.Columns, sqlgraph.NewFieldSpec(mailboxattr.FieldID, field.TypeInt)) + _spec.From = maq.sql + if unique := maq.ctx.Unique; unique != nil { _spec.Unique = *unique + } else if maq.path != nil { + _spec.Unique = true } - if fields := maq.fields; len(fields) > 0 { + if fields := maq.ctx.Fields; len(fields) > 0 { _spec.Node.Columns = make([]string, 0, len(fields)) _spec.Node.Columns = append(_spec.Node.Columns, mailboxattr.FieldID) for i := range fields { @@ -390,10 +392,10 @@ func (maq *MailboxAttrQuery) querySpec() *sqlgraph.QuerySpec { } } } - if limit := maq.limit; limit != nil { + if limit := maq.ctx.Limit; limit != nil { _spec.Limit = *limit } - if offset := maq.offset; offset != nil { + if offset := maq.ctx.Offset; offset != nil { _spec.Offset = *offset } if ps := maq.order; len(ps) > 0 { @@ -409,7 +411,7 @@ func (maq *MailboxAttrQuery) querySpec() *sqlgraph.QuerySpec { func (maq *MailboxAttrQuery) sqlQuery(ctx context.Context) *sql.Selector { builder := sql.Dialect(maq.driver.Dialect()) t1 := builder.Table(mailboxattr.Table) - columns := maq.fields + columns := maq.ctx.Fields if len(columns) == 0 { columns = mailboxattr.Columns } @@ -418,7 +420,7 @@ func (maq *MailboxAttrQuery) sqlQuery(ctx context.Context) *sql.Selector { selector = maq.sql selector.Select(selector.Columns(columns...)...) } - if maq.unique != nil && *maq.unique { + if maq.ctx.Unique != nil && *maq.ctx.Unique { selector.Distinct() } for _, p := range maq.predicates { @@ -427,12 +429,12 @@ func (maq *MailboxAttrQuery) sqlQuery(ctx context.Context) *sql.Selector { for _, p := range maq.order { p(selector) } - if offset := maq.offset; offset != nil { + if offset := maq.ctx.Offset; offset != nil { // limit is mandatory for offset clause. We start // with default value, and override it below if needed. selector.Offset(*offset).Limit(math.MaxInt32) } - if limit := maq.limit; limit != nil { + if limit := maq.ctx.Limit; limit != nil { selector.Limit(*limit) } return selector @@ -440,13 +442,8 @@ func (maq *MailboxAttrQuery) sqlQuery(ctx context.Context) *sql.Selector { // MailboxAttrGroupBy is the group-by builder for MailboxAttr entities. type MailboxAttrGroupBy struct { - config selector - fields []string - fns []AggregateFunc - // intermediate query (i.e. traversal path). - sql *sql.Selector - path func(context.Context) (*sql.Selector, error) + build *MailboxAttrQuery } // Aggregate adds the given aggregation functions to the group-by query. @@ -455,74 +452,77 @@ func (magb *MailboxAttrGroupBy) Aggregate(fns ...AggregateFunc) *MailboxAttrGrou return magb } -// Scan applies the group-by query and scans the result into the given value. -func (magb *MailboxAttrGroupBy) Scan(ctx context.Context, v interface{}) error { - query, err := magb.path(ctx) - if err != nil { +// Scan applies the selector query and scans the result into the given value. +func (magb *MailboxAttrGroupBy) Scan(ctx context.Context, v any) error { + ctx = setContextOp(ctx, magb.build.ctx, "GroupBy") + if err := magb.build.prepareQuery(ctx); err != nil { return err } - magb.sql = query - return magb.sqlScan(ctx, v) + return scanWithInterceptors[*MailboxAttrQuery, *MailboxAttrGroupBy](ctx, magb.build, magb, magb.build.inters, v) } -func (magb *MailboxAttrGroupBy) sqlScan(ctx context.Context, v interface{}) error { - for _, f := range magb.fields { - if !mailboxattr.ValidColumn(f) { - return &ValidationError{Name: f, err: fmt.Errorf("invalid field %q for group-by", f)} +func (magb *MailboxAttrGroupBy) sqlScan(ctx context.Context, root *MailboxAttrQuery, v any) error { + selector := root.sqlQuery(ctx).Select() + aggregation := make([]string, 0, len(magb.fns)) + for _, fn := range magb.fns { + aggregation = append(aggregation, fn(selector)) + } + if len(selector.SelectedColumns()) == 0 { + columns := make([]string, 0, len(*magb.flds)+len(magb.fns)) + for _, f := range *magb.flds { + columns = append(columns, selector.C(f)) } + columns = append(columns, aggregation...) + selector.Select(columns...) } - selector := magb.sqlQuery() + selector.GroupBy(selector.Columns(*magb.flds...)...) if err := selector.Err(); err != nil { return err } rows := &sql.Rows{} query, args := selector.Query() - if err := magb.driver.Query(ctx, query, args, rows); err != nil { + if err := magb.build.driver.Query(ctx, query, args, rows); err != nil { return err } defer rows.Close() return sql.ScanSlice(rows, v) } -func (magb *MailboxAttrGroupBy) sqlQuery() *sql.Selector { - selector := magb.sql.Select() - aggregation := make([]string, 0, len(magb.fns)) - for _, fn := range magb.fns { - aggregation = append(aggregation, fn(selector)) - } - // If no columns were selected in a custom aggregation function, the default - // selection is the fields used for "group-by", and the aggregation functions. - if len(selector.SelectedColumns()) == 0 { - columns := make([]string, 0, len(magb.fields)+len(magb.fns)) - for _, f := range magb.fields { - columns = append(columns, selector.C(f)) - } - columns = append(columns, aggregation...) - selector.Select(columns...) - } - return selector.GroupBy(selector.Columns(magb.fields...)...) -} - // MailboxAttrSelect is the builder for selecting fields of MailboxAttr entities. type MailboxAttrSelect struct { *MailboxAttrQuery selector - // intermediate query (i.e. traversal path). - sql *sql.Selector +} + +// Aggregate adds the given aggregation functions to the selector query. +func (mas *MailboxAttrSelect) Aggregate(fns ...AggregateFunc) *MailboxAttrSelect { + mas.fns = append(mas.fns, fns...) + return mas } // Scan applies the selector query and scans the result into the given value. -func (mas *MailboxAttrSelect) Scan(ctx context.Context, v interface{}) error { +func (mas *MailboxAttrSelect) Scan(ctx context.Context, v any) error { + ctx = setContextOp(ctx, mas.ctx, "Select") if err := mas.prepareQuery(ctx); err != nil { return err } - mas.sql = mas.MailboxAttrQuery.sqlQuery(ctx) - return mas.sqlScan(ctx, v) + return scanWithInterceptors[*MailboxAttrQuery, *MailboxAttrSelect](ctx, mas.MailboxAttrQuery, mas, mas.inters, v) } -func (mas *MailboxAttrSelect) sqlScan(ctx context.Context, v interface{}) error { +func (mas *MailboxAttrSelect) sqlScan(ctx context.Context, root *MailboxAttrQuery, v any) error { + selector := root.sqlQuery(ctx) + aggregation := make([]string, 0, len(mas.fns)) + for _, fn := range mas.fns { + aggregation = append(aggregation, fn(selector)) + } + switch n := len(*mas.selector.flds); { + case n == 0 && len(aggregation) > 0: + selector.Select(aggregation...) + case n != 0 && len(aggregation) > 0: + selector.AppendSelect(aggregation...) + } rows := &sql.Rows{} - query, args := mas.sql.Query() + query, args := selector.Query() if err := mas.driver.Query(ctx, query, args, rows); err != nil { return err } diff --git a/internal/db/ent/mailboxattr_update.go b/internal/db_impl/ent_db/internal/mailboxattr_update.go similarity index 63% rename from internal/db/ent/mailboxattr_update.go rename to internal/db_impl/ent_db/internal/mailboxattr_update.go index 73f65979..8998256a 100644 --- a/internal/db/ent/mailboxattr_update.go +++ b/internal/db_impl/ent_db/internal/mailboxattr_update.go @@ -1,6 +1,6 @@ // Code generated by ent, DO NOT EDIT. -package ent +package internal import ( "context" @@ -10,8 +10,8 @@ import ( "entgo.io/ent/dialect/sql" "entgo.io/ent/dialect/sql/sqlgraph" "entgo.io/ent/schema/field" - "github.com/ProtonMail/gluon/internal/db/ent/mailboxattr" - "github.com/ProtonMail/gluon/internal/db/ent/predicate" + "github.com/ProtonMail/gluon/internal/db_impl/ent_db/internal/mailboxattr" + "github.com/ProtonMail/gluon/internal/db_impl/ent_db/internal/predicate" ) // MailboxAttrUpdate is the builder for updating MailboxAttr entities. @@ -40,34 +40,7 @@ func (mau *MailboxAttrUpdate) Mutation() *MailboxAttrMutation { // Save executes the query and returns the number of nodes affected by the update operation. func (mau *MailboxAttrUpdate) Save(ctx context.Context) (int, error) { - var ( - err error - affected int - ) - if len(mau.hooks) == 0 { - affected, err = mau.sqlSave(ctx) - } else { - var mut Mutator = MutateFunc(func(ctx context.Context, m Mutation) (Value, error) { - mutation, ok := m.(*MailboxAttrMutation) - if !ok { - return nil, fmt.Errorf("unexpected mutation type %T", m) - } - mau.mutation = mutation - affected, err = mau.sqlSave(ctx) - mutation.done = true - return affected, err - }) - for i := len(mau.hooks) - 1; i >= 0; i-- { - if mau.hooks[i] == nil { - return 0, fmt.Errorf("ent: uninitialized hook (forgotten import ent/runtime?)") - } - mut = mau.hooks[i](mut) - } - if _, err := mut.Mutate(ctx, mau.mutation); err != nil { - return 0, err - } - } - return affected, err + return withHooks[int, MailboxAttrMutation](ctx, mau.sqlSave, mau.mutation, mau.hooks) } // SaveX is like Save, but panics if an error occurs. @@ -93,16 +66,7 @@ func (mau *MailboxAttrUpdate) ExecX(ctx context.Context) { } func (mau *MailboxAttrUpdate) sqlSave(ctx context.Context) (n int, err error) { - _spec := &sqlgraph.UpdateSpec{ - Node: &sqlgraph.NodeSpec{ - Table: mailboxattr.Table, - Columns: mailboxattr.Columns, - ID: &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Column: mailboxattr.FieldID, - }, - }, - } + _spec := sqlgraph.NewUpdateSpec(mailboxattr.Table, mailboxattr.Columns, sqlgraph.NewFieldSpec(mailboxattr.FieldID, field.TypeInt)) if ps := mau.mutation.predicates; len(ps) > 0 { _spec.Predicate = func(selector *sql.Selector) { for i := range ps { @@ -111,11 +75,7 @@ func (mau *MailboxAttrUpdate) sqlSave(ctx context.Context) (n int, err error) { } } if value, ok := mau.mutation.Value(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Value: value, - Column: mailboxattr.FieldValue, - }) + _spec.SetField(mailboxattr.FieldValue, field.TypeString, value) } if n, err = sqlgraph.UpdateNodes(ctx, mau.driver, _spec); err != nil { if _, ok := err.(*sqlgraph.NotFoundError); ok { @@ -125,6 +85,7 @@ func (mau *MailboxAttrUpdate) sqlSave(ctx context.Context) (n int, err error) { } return 0, err } + mau.mutation.done = true return n, nil } @@ -147,6 +108,12 @@ func (mauo *MailboxAttrUpdateOne) Mutation() *MailboxAttrMutation { return mauo.mutation } +// Where appends a list predicates to the MailboxAttrUpdate builder. +func (mauo *MailboxAttrUpdateOne) Where(ps ...predicate.MailboxAttr) *MailboxAttrUpdateOne { + mauo.mutation.Where(ps...) + return mauo +} + // Select allows selecting one or more fields (columns) of the returned entity. // The default is selecting all fields defined in the entity schema. func (mauo *MailboxAttrUpdateOne) Select(field string, fields ...string) *MailboxAttrUpdateOne { @@ -156,40 +123,7 @@ func (mauo *MailboxAttrUpdateOne) Select(field string, fields ...string) *Mailbo // Save executes the query and returns the updated MailboxAttr entity. func (mauo *MailboxAttrUpdateOne) Save(ctx context.Context) (*MailboxAttr, error) { - var ( - err error - node *MailboxAttr - ) - if len(mauo.hooks) == 0 { - node, err = mauo.sqlSave(ctx) - } else { - var mut Mutator = MutateFunc(func(ctx context.Context, m Mutation) (Value, error) { - mutation, ok := m.(*MailboxAttrMutation) - if !ok { - return nil, fmt.Errorf("unexpected mutation type %T", m) - } - mauo.mutation = mutation - node, err = mauo.sqlSave(ctx) - mutation.done = true - return node, err - }) - for i := len(mauo.hooks) - 1; i >= 0; i-- { - if mauo.hooks[i] == nil { - return nil, fmt.Errorf("ent: uninitialized hook (forgotten import ent/runtime?)") - } - mut = mauo.hooks[i](mut) - } - v, err := mut.Mutate(ctx, mauo.mutation) - if err != nil { - return nil, err - } - nv, ok := v.(*MailboxAttr) - if !ok { - return nil, fmt.Errorf("unexpected node type %T returned from MailboxAttrMutation", v) - } - node = nv - } - return node, err + return withHooks[*MailboxAttr, MailboxAttrMutation](ctx, mauo.sqlSave, mauo.mutation, mauo.hooks) } // SaveX is like Save, but panics if an error occurs. @@ -215,19 +149,10 @@ func (mauo *MailboxAttrUpdateOne) ExecX(ctx context.Context) { } func (mauo *MailboxAttrUpdateOne) sqlSave(ctx context.Context) (_node *MailboxAttr, err error) { - _spec := &sqlgraph.UpdateSpec{ - Node: &sqlgraph.NodeSpec{ - Table: mailboxattr.Table, - Columns: mailboxattr.Columns, - ID: &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Column: mailboxattr.FieldID, - }, - }, - } + _spec := sqlgraph.NewUpdateSpec(mailboxattr.Table, mailboxattr.Columns, sqlgraph.NewFieldSpec(mailboxattr.FieldID, field.TypeInt)) id, ok := mauo.mutation.ID() if !ok { - return nil, &ValidationError{Name: "id", err: errors.New(`ent: missing "MailboxAttr.id" for update`)} + return nil, &ValidationError{Name: "id", err: errors.New(`internal: missing "MailboxAttr.id" for update`)} } _spec.Node.ID.Value = id if fields := mauo.fields; len(fields) > 0 { @@ -235,7 +160,7 @@ func (mauo *MailboxAttrUpdateOne) sqlSave(ctx context.Context) (_node *MailboxAt _spec.Node.Columns = append(_spec.Node.Columns, mailboxattr.FieldID) for _, f := range fields { if !mailboxattr.ValidColumn(f) { - return nil, &ValidationError{Name: f, err: fmt.Errorf("ent: invalid field %q for query", f)} + return nil, &ValidationError{Name: f, err: fmt.Errorf("internal: invalid field %q for query", f)} } if f != mailboxattr.FieldID { _spec.Node.Columns = append(_spec.Node.Columns, f) @@ -250,11 +175,7 @@ func (mauo *MailboxAttrUpdateOne) sqlSave(ctx context.Context) (_node *MailboxAt } } if value, ok := mauo.mutation.Value(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Value: value, - Column: mailboxattr.FieldValue, - }) + _spec.SetField(mailboxattr.FieldValue, field.TypeString, value) } _node = &MailboxAttr{config: mauo.config} _spec.Assign = _node.assignValues @@ -267,5 +188,6 @@ func (mauo *MailboxAttrUpdateOne) sqlSave(ctx context.Context) (_node *MailboxAt } return nil, err } + mauo.mutation.done = true return _node, nil } diff --git a/internal/db/ent/mailboxflag.go b/internal/db_impl/ent_db/internal/mailboxflag.go similarity index 85% rename from internal/db/ent/mailboxflag.go rename to internal/db_impl/ent_db/internal/mailboxflag.go index 18ce8cf5..d2a5d39e 100644 --- a/internal/db/ent/mailboxflag.go +++ b/internal/db_impl/ent_db/internal/mailboxflag.go @@ -1,6 +1,6 @@ // Code generated by ent, DO NOT EDIT. -package ent +package internal import ( "fmt" @@ -8,7 +8,7 @@ import ( "entgo.io/ent/dialect/sql" "github.com/ProtonMail/gluon/imap" - "github.com/ProtonMail/gluon/internal/db/ent/mailboxflag" + "github.com/ProtonMail/gluon/internal/db_impl/ent_db/internal/mailboxflag" ) // MailboxFlag is the model entity for the MailboxFlag schema. @@ -22,8 +22,8 @@ type MailboxFlag struct { } // scanValues returns the types for scanning values from sql.Rows. -func (*MailboxFlag) scanValues(columns []string) ([]interface{}, error) { - values := make([]interface{}, len(columns)) +func (*MailboxFlag) scanValues(columns []string) ([]any, error) { + values := make([]any, len(columns)) for i := range columns { switch columns[i] { case mailboxflag.FieldID: @@ -41,7 +41,7 @@ func (*MailboxFlag) scanValues(columns []string) ([]interface{}, error) { // assignValues assigns the values that were returned from sql.Rows (after scanning) // to the MailboxFlag fields. -func (mf *MailboxFlag) assignValues(columns []string, values []interface{}) error { +func (mf *MailboxFlag) assignValues(columns []string, values []any) error { if m, n := len(values), len(columns); m < n { return fmt.Errorf("mismatch number of scan values: %d != %d", m, n) } @@ -75,7 +75,7 @@ func (mf *MailboxFlag) assignValues(columns []string, values []interface{}) erro // Note that you need to call MailboxFlag.Unwrap() before calling this method if this MailboxFlag // was returned from a transaction, and the transaction was committed or rolled back. func (mf *MailboxFlag) Update() *MailboxFlagUpdateOne { - return (&MailboxFlagClient{config: mf.config}).UpdateOne(mf) + return NewMailboxFlagClient(mf.config).UpdateOne(mf) } // Unwrap unwraps the MailboxFlag entity that was returned from a transaction after it was closed, @@ -83,7 +83,7 @@ func (mf *MailboxFlag) Update() *MailboxFlagUpdateOne { func (mf *MailboxFlag) Unwrap() *MailboxFlag { _tx, ok := mf.config.driver.(*txDriver) if !ok { - panic("ent: MailboxFlag is not a transactional entity") + panic("internal: MailboxFlag is not a transactional entity") } mf.config.driver = _tx.drv return mf @@ -102,9 +102,3 @@ func (mf *MailboxFlag) String() string { // MailboxFlags is a parsable slice of MailboxFlag. type MailboxFlags []*MailboxFlag - -func (mf MailboxFlags) config(cfg config) { - for _i := range mf { - mf[_i].config = cfg - } -} diff --git a/internal/db/ent/mailboxflag/mailboxflag.go b/internal/db_impl/ent_db/internal/mailboxflag/mailboxflag.go similarity index 100% rename from internal/db/ent/mailboxflag/mailboxflag.go rename to internal/db_impl/ent_db/internal/mailboxflag/mailboxflag.go diff --git a/internal/db/ent/mailboxflag/where.go b/internal/db_impl/ent_db/internal/mailboxflag/where.go similarity index 56% rename from internal/db/ent/mailboxflag/where.go rename to internal/db_impl/ent_db/internal/mailboxflag/where.go index 004f34f3..dab75500 100644 --- a/internal/db/ent/mailboxflag/where.go +++ b/internal/db_impl/ent_db/internal/mailboxflag/where.go @@ -4,184 +4,122 @@ package mailboxflag import ( "entgo.io/ent/dialect/sql" - "github.com/ProtonMail/gluon/internal/db/ent/predicate" + "github.com/ProtonMail/gluon/internal/db_impl/ent_db/internal/predicate" ) // ID filters vertices based on their ID field. func ID(id int) predicate.MailboxFlag { - return predicate.MailboxFlag(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldID), id)) - }) + return predicate.MailboxFlag(sql.FieldEQ(FieldID, id)) } // IDEQ applies the EQ predicate on the ID field. func IDEQ(id int) predicate.MailboxFlag { - return predicate.MailboxFlag(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldID), id)) - }) + return predicate.MailboxFlag(sql.FieldEQ(FieldID, id)) } // IDNEQ applies the NEQ predicate on the ID field. func IDNEQ(id int) predicate.MailboxFlag { - return predicate.MailboxFlag(func(s *sql.Selector) { - s.Where(sql.NEQ(s.C(FieldID), id)) - }) + return predicate.MailboxFlag(sql.FieldNEQ(FieldID, id)) } // IDIn applies the In predicate on the ID field. func IDIn(ids ...int) predicate.MailboxFlag { - return predicate.MailboxFlag(func(s *sql.Selector) { - v := make([]interface{}, len(ids)) - for i := range v { - v[i] = ids[i] - } - s.Where(sql.In(s.C(FieldID), v...)) - }) + return predicate.MailboxFlag(sql.FieldIn(FieldID, ids...)) } // IDNotIn applies the NotIn predicate on the ID field. func IDNotIn(ids ...int) predicate.MailboxFlag { - return predicate.MailboxFlag(func(s *sql.Selector) { - v := make([]interface{}, len(ids)) - for i := range v { - v[i] = ids[i] - } - s.Where(sql.NotIn(s.C(FieldID), v...)) - }) + return predicate.MailboxFlag(sql.FieldNotIn(FieldID, ids...)) } // IDGT applies the GT predicate on the ID field. func IDGT(id int) predicate.MailboxFlag { - return predicate.MailboxFlag(func(s *sql.Selector) { - s.Where(sql.GT(s.C(FieldID), id)) - }) + return predicate.MailboxFlag(sql.FieldGT(FieldID, id)) } // IDGTE applies the GTE predicate on the ID field. func IDGTE(id int) predicate.MailboxFlag { - return predicate.MailboxFlag(func(s *sql.Selector) { - s.Where(sql.GTE(s.C(FieldID), id)) - }) + return predicate.MailboxFlag(sql.FieldGTE(FieldID, id)) } // IDLT applies the LT predicate on the ID field. func IDLT(id int) predicate.MailboxFlag { - return predicate.MailboxFlag(func(s *sql.Selector) { - s.Where(sql.LT(s.C(FieldID), id)) - }) + return predicate.MailboxFlag(sql.FieldLT(FieldID, id)) } // IDLTE applies the LTE predicate on the ID field. func IDLTE(id int) predicate.MailboxFlag { - return predicate.MailboxFlag(func(s *sql.Selector) { - s.Where(sql.LTE(s.C(FieldID), id)) - }) + return predicate.MailboxFlag(sql.FieldLTE(FieldID, id)) } // Value applies equality check predicate on the "Value" field. It's identical to ValueEQ. func Value(v string) predicate.MailboxFlag { - return predicate.MailboxFlag(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldValue), v)) - }) + return predicate.MailboxFlag(sql.FieldEQ(FieldValue, v)) } // ValueEQ applies the EQ predicate on the "Value" field. func ValueEQ(v string) predicate.MailboxFlag { - return predicate.MailboxFlag(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldValue), v)) - }) + return predicate.MailboxFlag(sql.FieldEQ(FieldValue, v)) } // ValueNEQ applies the NEQ predicate on the "Value" field. func ValueNEQ(v string) predicate.MailboxFlag { - return predicate.MailboxFlag(func(s *sql.Selector) { - s.Where(sql.NEQ(s.C(FieldValue), v)) - }) + return predicate.MailboxFlag(sql.FieldNEQ(FieldValue, v)) } // ValueIn applies the In predicate on the "Value" field. func ValueIn(vs ...string) predicate.MailboxFlag { - v := make([]interface{}, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.MailboxFlag(func(s *sql.Selector) { - s.Where(sql.In(s.C(FieldValue), v...)) - }) + return predicate.MailboxFlag(sql.FieldIn(FieldValue, vs...)) } // ValueNotIn applies the NotIn predicate on the "Value" field. func ValueNotIn(vs ...string) predicate.MailboxFlag { - v := make([]interface{}, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.MailboxFlag(func(s *sql.Selector) { - s.Where(sql.NotIn(s.C(FieldValue), v...)) - }) + return predicate.MailboxFlag(sql.FieldNotIn(FieldValue, vs...)) } // ValueGT applies the GT predicate on the "Value" field. func ValueGT(v string) predicate.MailboxFlag { - return predicate.MailboxFlag(func(s *sql.Selector) { - s.Where(sql.GT(s.C(FieldValue), v)) - }) + return predicate.MailboxFlag(sql.FieldGT(FieldValue, v)) } // ValueGTE applies the GTE predicate on the "Value" field. func ValueGTE(v string) predicate.MailboxFlag { - return predicate.MailboxFlag(func(s *sql.Selector) { - s.Where(sql.GTE(s.C(FieldValue), v)) - }) + return predicate.MailboxFlag(sql.FieldGTE(FieldValue, v)) } // ValueLT applies the LT predicate on the "Value" field. func ValueLT(v string) predicate.MailboxFlag { - return predicate.MailboxFlag(func(s *sql.Selector) { - s.Where(sql.LT(s.C(FieldValue), v)) - }) + return predicate.MailboxFlag(sql.FieldLT(FieldValue, v)) } // ValueLTE applies the LTE predicate on the "Value" field. func ValueLTE(v string) predicate.MailboxFlag { - return predicate.MailboxFlag(func(s *sql.Selector) { - s.Where(sql.LTE(s.C(FieldValue), v)) - }) + return predicate.MailboxFlag(sql.FieldLTE(FieldValue, v)) } // ValueContains applies the Contains predicate on the "Value" field. func ValueContains(v string) predicate.MailboxFlag { - return predicate.MailboxFlag(func(s *sql.Selector) { - s.Where(sql.Contains(s.C(FieldValue), v)) - }) + return predicate.MailboxFlag(sql.FieldContains(FieldValue, v)) } // ValueHasPrefix applies the HasPrefix predicate on the "Value" field. func ValueHasPrefix(v string) predicate.MailboxFlag { - return predicate.MailboxFlag(func(s *sql.Selector) { - s.Where(sql.HasPrefix(s.C(FieldValue), v)) - }) + return predicate.MailboxFlag(sql.FieldHasPrefix(FieldValue, v)) } // ValueHasSuffix applies the HasSuffix predicate on the "Value" field. func ValueHasSuffix(v string) predicate.MailboxFlag { - return predicate.MailboxFlag(func(s *sql.Selector) { - s.Where(sql.HasSuffix(s.C(FieldValue), v)) - }) + return predicate.MailboxFlag(sql.FieldHasSuffix(FieldValue, v)) } // ValueEqualFold applies the EqualFold predicate on the "Value" field. func ValueEqualFold(v string) predicate.MailboxFlag { - return predicate.MailboxFlag(func(s *sql.Selector) { - s.Where(sql.EqualFold(s.C(FieldValue), v)) - }) + return predicate.MailboxFlag(sql.FieldEqualFold(FieldValue, v)) } // ValueContainsFold applies the ContainsFold predicate on the "Value" field. func ValueContainsFold(v string) predicate.MailboxFlag { - return predicate.MailboxFlag(func(s *sql.Selector) { - s.Where(sql.ContainsFold(s.C(FieldValue), v)) - }) + return predicate.MailboxFlag(sql.FieldContainsFold(FieldValue, v)) } // And groups predicates with the AND operator between them. diff --git a/internal/db/ent/mailboxflag_create.go b/internal/db_impl/ent_db/internal/mailboxflag_create.go similarity index 74% rename from internal/db/ent/mailboxflag_create.go rename to internal/db_impl/ent_db/internal/mailboxflag_create.go index fa262a37..dd6fb8e0 100644 --- a/internal/db/ent/mailboxflag_create.go +++ b/internal/db_impl/ent_db/internal/mailboxflag_create.go @@ -1,6 +1,6 @@ // Code generated by ent, DO NOT EDIT. -package ent +package internal import ( "context" @@ -9,7 +9,7 @@ import ( "entgo.io/ent/dialect/sql/sqlgraph" "entgo.io/ent/schema/field" - "github.com/ProtonMail/gluon/internal/db/ent/mailboxflag" + "github.com/ProtonMail/gluon/internal/db_impl/ent_db/internal/mailboxflag" ) // MailboxFlagCreate is the builder for creating a MailboxFlag entity. @@ -32,49 +32,7 @@ func (mfc *MailboxFlagCreate) Mutation() *MailboxFlagMutation { // Save creates the MailboxFlag in the database. func (mfc *MailboxFlagCreate) Save(ctx context.Context) (*MailboxFlag, error) { - var ( - err error - node *MailboxFlag - ) - if len(mfc.hooks) == 0 { - if err = mfc.check(); err != nil { - return nil, err - } - node, err = mfc.sqlSave(ctx) - } else { - var mut Mutator = MutateFunc(func(ctx context.Context, m Mutation) (Value, error) { - mutation, ok := m.(*MailboxFlagMutation) - if !ok { - return nil, fmt.Errorf("unexpected mutation type %T", m) - } - if err = mfc.check(); err != nil { - return nil, err - } - mfc.mutation = mutation - if node, err = mfc.sqlSave(ctx); err != nil { - return nil, err - } - mutation.id = &node.ID - mutation.done = true - return node, err - }) - for i := len(mfc.hooks) - 1; i >= 0; i-- { - if mfc.hooks[i] == nil { - return nil, fmt.Errorf("ent: uninitialized hook (forgotten import ent/runtime?)") - } - mut = mfc.hooks[i](mut) - } - v, err := mut.Mutate(ctx, mfc.mutation) - if err != nil { - return nil, err - } - nv, ok := v.(*MailboxFlag) - if !ok { - return nil, fmt.Errorf("unexpected node type %T returned from MailboxFlagMutation", v) - } - node = nv - } - return node, err + return withHooks[*MailboxFlag, MailboxFlagMutation](ctx, mfc.sqlSave, mfc.mutation, mfc.hooks) } // SaveX calls Save and panics if Save returns an error. @@ -102,12 +60,15 @@ func (mfc *MailboxFlagCreate) ExecX(ctx context.Context) { // check runs all checks and user-defined validators on the builder. func (mfc *MailboxFlagCreate) check() error { if _, ok := mfc.mutation.Value(); !ok { - return &ValidationError{Name: "Value", err: errors.New(`ent: missing required field "MailboxFlag.Value"`)} + return &ValidationError{Name: "Value", err: errors.New(`internal: missing required field "MailboxFlag.Value"`)} } return nil } func (mfc *MailboxFlagCreate) sqlSave(ctx context.Context) (*MailboxFlag, error) { + if err := mfc.check(); err != nil { + return nil, err + } _node, _spec := mfc.createSpec() if err := sqlgraph.CreateNode(ctx, mfc.driver, _spec); err != nil { if sqlgraph.IsConstraintError(err) { @@ -117,26 +78,18 @@ func (mfc *MailboxFlagCreate) sqlSave(ctx context.Context) (*MailboxFlag, error) } id := _spec.ID.Value.(int64) _node.ID = int(id) + mfc.mutation.id = &_node.ID + mfc.mutation.done = true return _node, nil } func (mfc *MailboxFlagCreate) createSpec() (*MailboxFlag, *sqlgraph.CreateSpec) { var ( _node = &MailboxFlag{config: mfc.config} - _spec = &sqlgraph.CreateSpec{ - Table: mailboxflag.Table, - ID: &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Column: mailboxflag.FieldID, - }, - } + _spec = sqlgraph.NewCreateSpec(mailboxflag.Table, sqlgraph.NewFieldSpec(mailboxflag.FieldID, field.TypeInt)) ) if value, ok := mfc.mutation.Value(); ok { - _spec.Fields = append(_spec.Fields, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Value: value, - Column: mailboxflag.FieldValue, - }) + _spec.SetField(mailboxflag.FieldValue, field.TypeString, value) _node.Value = value } return _node, _spec diff --git a/internal/db/ent/mailboxflag_delete.go b/internal/db_impl/ent_db/internal/mailboxflag_delete.go similarity index 62% rename from internal/db/ent/mailboxflag_delete.go rename to internal/db_impl/ent_db/internal/mailboxflag_delete.go index 01ba933c..7fd9e0d9 100644 --- a/internal/db/ent/mailboxflag_delete.go +++ b/internal/db_impl/ent_db/internal/mailboxflag_delete.go @@ -1,16 +1,15 @@ // Code generated by ent, DO NOT EDIT. -package ent +package internal import ( "context" - "fmt" "entgo.io/ent/dialect/sql" "entgo.io/ent/dialect/sql/sqlgraph" "entgo.io/ent/schema/field" - "github.com/ProtonMail/gluon/internal/db/ent/mailboxflag" - "github.com/ProtonMail/gluon/internal/db/ent/predicate" + "github.com/ProtonMail/gluon/internal/db_impl/ent_db/internal/mailboxflag" + "github.com/ProtonMail/gluon/internal/db_impl/ent_db/internal/predicate" ) // MailboxFlagDelete is the builder for deleting a MailboxFlag entity. @@ -28,34 +27,7 @@ func (mfd *MailboxFlagDelete) Where(ps ...predicate.MailboxFlag) *MailboxFlagDel // Exec executes the deletion query and returns how many vertices were deleted. func (mfd *MailboxFlagDelete) Exec(ctx context.Context) (int, error) { - var ( - err error - affected int - ) - if len(mfd.hooks) == 0 { - affected, err = mfd.sqlExec(ctx) - } else { - var mut Mutator = MutateFunc(func(ctx context.Context, m Mutation) (Value, error) { - mutation, ok := m.(*MailboxFlagMutation) - if !ok { - return nil, fmt.Errorf("unexpected mutation type %T", m) - } - mfd.mutation = mutation - affected, err = mfd.sqlExec(ctx) - mutation.done = true - return affected, err - }) - for i := len(mfd.hooks) - 1; i >= 0; i-- { - if mfd.hooks[i] == nil { - return 0, fmt.Errorf("ent: uninitialized hook (forgotten import ent/runtime?)") - } - mut = mfd.hooks[i](mut) - } - if _, err := mut.Mutate(ctx, mfd.mutation); err != nil { - return 0, err - } - } - return affected, err + return withHooks[int, MailboxFlagMutation](ctx, mfd.sqlExec, mfd.mutation, mfd.hooks) } // ExecX is like Exec, but panics if an error occurs. @@ -68,15 +40,7 @@ func (mfd *MailboxFlagDelete) ExecX(ctx context.Context) int { } func (mfd *MailboxFlagDelete) sqlExec(ctx context.Context) (int, error) { - _spec := &sqlgraph.DeleteSpec{ - Node: &sqlgraph.NodeSpec{ - Table: mailboxflag.Table, - ID: &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Column: mailboxflag.FieldID, - }, - }, - } + _spec := sqlgraph.NewDeleteSpec(mailboxflag.Table, sqlgraph.NewFieldSpec(mailboxflag.FieldID, field.TypeInt)) if ps := mfd.mutation.predicates; len(ps) > 0 { _spec.Predicate = func(selector *sql.Selector) { for i := range ps { @@ -88,6 +52,7 @@ func (mfd *MailboxFlagDelete) sqlExec(ctx context.Context) (int, error) { if err != nil && sqlgraph.IsConstraintError(err) { err = &ConstraintError{msg: err.Error(), wrap: err} } + mfd.mutation.done = true return affected, err } @@ -96,6 +61,12 @@ type MailboxFlagDeleteOne struct { mfd *MailboxFlagDelete } +// Where appends a list predicates to the MailboxFlagDelete builder. +func (mfdo *MailboxFlagDeleteOne) Where(ps ...predicate.MailboxFlag) *MailboxFlagDeleteOne { + mfdo.mfd.mutation.Where(ps...) + return mfdo +} + // Exec executes the deletion query. func (mfdo *MailboxFlagDeleteOne) Exec(ctx context.Context) error { n, err := mfdo.mfd.Exec(ctx) @@ -111,5 +82,7 @@ func (mfdo *MailboxFlagDeleteOne) Exec(ctx context.Context) error { // ExecX is like Exec, but panics if an error occurs. func (mfdo *MailboxFlagDeleteOne) ExecX(ctx context.Context) { - mfdo.mfd.ExecX(ctx) + if err := mfdo.Exec(ctx); err != nil { + panic(err) + } } diff --git a/internal/db/ent/mailboxflag_query.go b/internal/db_impl/ent_db/internal/mailboxflag_query.go similarity index 67% rename from internal/db/ent/mailboxflag_query.go rename to internal/db_impl/ent_db/internal/mailboxflag_query.go index 0f04bbbe..bfb150d1 100644 --- a/internal/db/ent/mailboxflag_query.go +++ b/internal/db_impl/ent_db/internal/mailboxflag_query.go @@ -1,6 +1,6 @@ // Code generated by ent, DO NOT EDIT. -package ent +package internal import ( "context" @@ -10,18 +10,16 @@ import ( "entgo.io/ent/dialect/sql" "entgo.io/ent/dialect/sql/sqlgraph" "entgo.io/ent/schema/field" - "github.com/ProtonMail/gluon/internal/db/ent/mailboxflag" - "github.com/ProtonMail/gluon/internal/db/ent/predicate" + "github.com/ProtonMail/gluon/internal/db_impl/ent_db/internal/mailboxflag" + "github.com/ProtonMail/gluon/internal/db_impl/ent_db/internal/predicate" ) // MailboxFlagQuery is the builder for querying MailboxFlag entities. type MailboxFlagQuery struct { config - limit *int - offset *int - unique *bool + ctx *QueryContext order []OrderFunc - fields []string + inters []Interceptor predicates []predicate.MailboxFlag withFKs bool // intermediate query (i.e. traversal path). @@ -35,26 +33,26 @@ func (mfq *MailboxFlagQuery) Where(ps ...predicate.MailboxFlag) *MailboxFlagQuer return mfq } -// Limit adds a limit step to the query. +// Limit the number of records to be returned by this query. func (mfq *MailboxFlagQuery) Limit(limit int) *MailboxFlagQuery { - mfq.limit = &limit + mfq.ctx.Limit = &limit return mfq } -// Offset adds an offset step to the query. +// Offset to start from. func (mfq *MailboxFlagQuery) Offset(offset int) *MailboxFlagQuery { - mfq.offset = &offset + mfq.ctx.Offset = &offset return mfq } // Unique configures the query builder to filter duplicate records on query. // By default, unique is set to true, and can be disabled using this method. func (mfq *MailboxFlagQuery) Unique(unique bool) *MailboxFlagQuery { - mfq.unique = &unique + mfq.ctx.Unique = &unique return mfq } -// Order adds an order step to the query. +// Order specifies how the records should be ordered. func (mfq *MailboxFlagQuery) Order(o ...OrderFunc) *MailboxFlagQuery { mfq.order = append(mfq.order, o...) return mfq @@ -63,7 +61,7 @@ func (mfq *MailboxFlagQuery) Order(o ...OrderFunc) *MailboxFlagQuery { // First returns the first MailboxFlag entity from the query. // Returns a *NotFoundError when no MailboxFlag was found. func (mfq *MailboxFlagQuery) First(ctx context.Context) (*MailboxFlag, error) { - nodes, err := mfq.Limit(1).All(ctx) + nodes, err := mfq.Limit(1).All(setContextOp(ctx, mfq.ctx, "First")) if err != nil { return nil, err } @@ -86,7 +84,7 @@ func (mfq *MailboxFlagQuery) FirstX(ctx context.Context) *MailboxFlag { // Returns a *NotFoundError when no MailboxFlag ID was found. func (mfq *MailboxFlagQuery) FirstID(ctx context.Context) (id int, err error) { var ids []int - if ids, err = mfq.Limit(1).IDs(ctx); err != nil { + if ids, err = mfq.Limit(1).IDs(setContextOp(ctx, mfq.ctx, "FirstID")); err != nil { return } if len(ids) == 0 { @@ -109,7 +107,7 @@ func (mfq *MailboxFlagQuery) FirstIDX(ctx context.Context) int { // Returns a *NotSingularError when more than one MailboxFlag entity is found. // Returns a *NotFoundError when no MailboxFlag entities are found. func (mfq *MailboxFlagQuery) Only(ctx context.Context) (*MailboxFlag, error) { - nodes, err := mfq.Limit(2).All(ctx) + nodes, err := mfq.Limit(2).All(setContextOp(ctx, mfq.ctx, "Only")) if err != nil { return nil, err } @@ -137,7 +135,7 @@ func (mfq *MailboxFlagQuery) OnlyX(ctx context.Context) *MailboxFlag { // Returns a *NotFoundError when no entities are found. func (mfq *MailboxFlagQuery) OnlyID(ctx context.Context) (id int, err error) { var ids []int - if ids, err = mfq.Limit(2).IDs(ctx); err != nil { + if ids, err = mfq.Limit(2).IDs(setContextOp(ctx, mfq.ctx, "OnlyID")); err != nil { return } switch len(ids) { @@ -162,10 +160,12 @@ func (mfq *MailboxFlagQuery) OnlyIDX(ctx context.Context) int { // All executes the query and returns a list of MailboxFlags. func (mfq *MailboxFlagQuery) All(ctx context.Context) ([]*MailboxFlag, error) { + ctx = setContextOp(ctx, mfq.ctx, "All") if err := mfq.prepareQuery(ctx); err != nil { return nil, err } - return mfq.sqlAll(ctx) + qr := querierAll[[]*MailboxFlag, *MailboxFlagQuery]() + return withInterceptors[[]*MailboxFlag](ctx, mfq, qr, mfq.inters) } // AllX is like All, but panics if an error occurs. @@ -178,9 +178,12 @@ func (mfq *MailboxFlagQuery) AllX(ctx context.Context) []*MailboxFlag { } // IDs executes the query and returns a list of MailboxFlag IDs. -func (mfq *MailboxFlagQuery) IDs(ctx context.Context) ([]int, error) { - var ids []int - if err := mfq.Select(mailboxflag.FieldID).Scan(ctx, &ids); err != nil { +func (mfq *MailboxFlagQuery) IDs(ctx context.Context) (ids []int, err error) { + if mfq.ctx.Unique == nil && mfq.path != nil { + mfq.Unique(true) + } + ctx = setContextOp(ctx, mfq.ctx, "IDs") + if err = mfq.Select(mailboxflag.FieldID).Scan(ctx, &ids); err != nil { return nil, err } return ids, nil @@ -197,10 +200,11 @@ func (mfq *MailboxFlagQuery) IDsX(ctx context.Context) []int { // Count returns the count of the given query. func (mfq *MailboxFlagQuery) Count(ctx context.Context) (int, error) { + ctx = setContextOp(ctx, mfq.ctx, "Count") if err := mfq.prepareQuery(ctx); err != nil { return 0, err } - return mfq.sqlCount(ctx) + return withInterceptors[int](ctx, mfq, querierCount[*MailboxFlagQuery](), mfq.inters) } // CountX is like Count, but panics if an error occurs. @@ -214,10 +218,15 @@ func (mfq *MailboxFlagQuery) CountX(ctx context.Context) int { // Exist returns true if the query has elements in the graph. func (mfq *MailboxFlagQuery) Exist(ctx context.Context) (bool, error) { - if err := mfq.prepareQuery(ctx); err != nil { - return false, err + ctx = setContextOp(ctx, mfq.ctx, "Exist") + switch _, err := mfq.FirstID(ctx); { + case IsNotFound(err): + return false, nil + case err != nil: + return false, fmt.Errorf("internal: check existence: %w", err) + default: + return true, nil } - return mfq.sqlExist(ctx) } // ExistX is like Exist, but panics if an error occurs. @@ -237,14 +246,13 @@ func (mfq *MailboxFlagQuery) Clone() *MailboxFlagQuery { } return &MailboxFlagQuery{ config: mfq.config, - limit: mfq.limit, - offset: mfq.offset, + ctx: mfq.ctx.Clone(), order: append([]OrderFunc{}, mfq.order...), + inters: append([]Interceptor{}, mfq.inters...), predicates: append([]predicate.MailboxFlag{}, mfq.predicates...), // clone intermediate query. - sql: mfq.sql.Clone(), - path: mfq.path, - unique: mfq.unique, + sql: mfq.sql.Clone(), + path: mfq.path, } } @@ -260,19 +268,14 @@ func (mfq *MailboxFlagQuery) Clone() *MailboxFlagQuery { // // client.MailboxFlag.Query(). // GroupBy(mailboxflag.FieldValue). -// Aggregate(ent.Count()). +// Aggregate(internal.Count()). // Scan(ctx, &v) func (mfq *MailboxFlagQuery) GroupBy(field string, fields ...string) *MailboxFlagGroupBy { - grbuild := &MailboxFlagGroupBy{config: mfq.config} - grbuild.fields = append([]string{field}, fields...) - grbuild.path = func(ctx context.Context) (prev *sql.Selector, err error) { - if err := mfq.prepareQuery(ctx); err != nil { - return nil, err - } - return mfq.sqlQuery(ctx), nil - } + mfq.ctx.Fields = append([]string{field}, fields...) + grbuild := &MailboxFlagGroupBy{build: mfq} + grbuild.flds = &mfq.ctx.Fields grbuild.label = mailboxflag.Label - grbuild.flds, grbuild.scan = &grbuild.fields, grbuild.Scan + grbuild.scan = grbuild.Scan return grbuild } @@ -289,17 +292,32 @@ func (mfq *MailboxFlagQuery) GroupBy(field string, fields ...string) *MailboxFla // Select(mailboxflag.FieldValue). // Scan(ctx, &v) func (mfq *MailboxFlagQuery) Select(fields ...string) *MailboxFlagSelect { - mfq.fields = append(mfq.fields, fields...) - selbuild := &MailboxFlagSelect{MailboxFlagQuery: mfq} - selbuild.label = mailboxflag.Label - selbuild.flds, selbuild.scan = &mfq.fields, selbuild.Scan - return selbuild + mfq.ctx.Fields = append(mfq.ctx.Fields, fields...) + sbuild := &MailboxFlagSelect{MailboxFlagQuery: mfq} + sbuild.label = mailboxflag.Label + sbuild.flds, sbuild.scan = &mfq.ctx.Fields, sbuild.Scan + return sbuild +} + +// Aggregate returns a MailboxFlagSelect configured with the given aggregations. +func (mfq *MailboxFlagQuery) Aggregate(fns ...AggregateFunc) *MailboxFlagSelect { + return mfq.Select().Aggregate(fns...) } func (mfq *MailboxFlagQuery) prepareQuery(ctx context.Context) error { - for _, f := range mfq.fields { + for _, inter := range mfq.inters { + if inter == nil { + return fmt.Errorf("internal: uninitialized interceptor (forgotten import internal/runtime?)") + } + if trv, ok := inter.(Traverser); ok { + if err := trv.Traverse(ctx, mfq); err != nil { + return err + } + } + } + for _, f := range mfq.ctx.Fields { if !mailboxflag.ValidColumn(f) { - return &ValidationError{Name: f, err: fmt.Errorf("ent: invalid field %q for query", f)} + return &ValidationError{Name: f, err: fmt.Errorf("internal: invalid field %q for query", f)} } } if mfq.path != nil { @@ -321,10 +339,10 @@ func (mfq *MailboxFlagQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([] if withFKs { _spec.Node.Columns = append(_spec.Node.Columns, mailboxflag.ForeignKeys...) } - _spec.ScanValues = func(columns []string) ([]interface{}, error) { + _spec.ScanValues = func(columns []string) ([]any, error) { return (*MailboxFlag).scanValues(nil, columns) } - _spec.Assign = func(columns []string, values []interface{}) error { + _spec.Assign = func(columns []string, values []any) error { node := &MailboxFlag{config: mfq.config} nodes = append(nodes, node) return node.assignValues(columns, values) @@ -343,38 +361,22 @@ func (mfq *MailboxFlagQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([] func (mfq *MailboxFlagQuery) sqlCount(ctx context.Context) (int, error) { _spec := mfq.querySpec() - _spec.Node.Columns = mfq.fields - if len(mfq.fields) > 0 { - _spec.Unique = mfq.unique != nil && *mfq.unique + _spec.Node.Columns = mfq.ctx.Fields + if len(mfq.ctx.Fields) > 0 { + _spec.Unique = mfq.ctx.Unique != nil && *mfq.ctx.Unique } return sqlgraph.CountNodes(ctx, mfq.driver, _spec) } -func (mfq *MailboxFlagQuery) sqlExist(ctx context.Context) (bool, error) { - n, err := mfq.sqlCount(ctx) - if err != nil { - return false, fmt.Errorf("ent: check existence: %w", err) - } - return n > 0, nil -} - func (mfq *MailboxFlagQuery) querySpec() *sqlgraph.QuerySpec { - _spec := &sqlgraph.QuerySpec{ - Node: &sqlgraph.NodeSpec{ - Table: mailboxflag.Table, - Columns: mailboxflag.Columns, - ID: &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Column: mailboxflag.FieldID, - }, - }, - From: mfq.sql, - Unique: true, - } - if unique := mfq.unique; unique != nil { + _spec := sqlgraph.NewQuerySpec(mailboxflag.Table, mailboxflag.Columns, sqlgraph.NewFieldSpec(mailboxflag.FieldID, field.TypeInt)) + _spec.From = mfq.sql + if unique := mfq.ctx.Unique; unique != nil { _spec.Unique = *unique + } else if mfq.path != nil { + _spec.Unique = true } - if fields := mfq.fields; len(fields) > 0 { + if fields := mfq.ctx.Fields; len(fields) > 0 { _spec.Node.Columns = make([]string, 0, len(fields)) _spec.Node.Columns = append(_spec.Node.Columns, mailboxflag.FieldID) for i := range fields { @@ -390,10 +392,10 @@ func (mfq *MailboxFlagQuery) querySpec() *sqlgraph.QuerySpec { } } } - if limit := mfq.limit; limit != nil { + if limit := mfq.ctx.Limit; limit != nil { _spec.Limit = *limit } - if offset := mfq.offset; offset != nil { + if offset := mfq.ctx.Offset; offset != nil { _spec.Offset = *offset } if ps := mfq.order; len(ps) > 0 { @@ -409,7 +411,7 @@ func (mfq *MailboxFlagQuery) querySpec() *sqlgraph.QuerySpec { func (mfq *MailboxFlagQuery) sqlQuery(ctx context.Context) *sql.Selector { builder := sql.Dialect(mfq.driver.Dialect()) t1 := builder.Table(mailboxflag.Table) - columns := mfq.fields + columns := mfq.ctx.Fields if len(columns) == 0 { columns = mailboxflag.Columns } @@ -418,7 +420,7 @@ func (mfq *MailboxFlagQuery) sqlQuery(ctx context.Context) *sql.Selector { selector = mfq.sql selector.Select(selector.Columns(columns...)...) } - if mfq.unique != nil && *mfq.unique { + if mfq.ctx.Unique != nil && *mfq.ctx.Unique { selector.Distinct() } for _, p := range mfq.predicates { @@ -427,12 +429,12 @@ func (mfq *MailboxFlagQuery) sqlQuery(ctx context.Context) *sql.Selector { for _, p := range mfq.order { p(selector) } - if offset := mfq.offset; offset != nil { + if offset := mfq.ctx.Offset; offset != nil { // limit is mandatory for offset clause. We start // with default value, and override it below if needed. selector.Offset(*offset).Limit(math.MaxInt32) } - if limit := mfq.limit; limit != nil { + if limit := mfq.ctx.Limit; limit != nil { selector.Limit(*limit) } return selector @@ -440,13 +442,8 @@ func (mfq *MailboxFlagQuery) sqlQuery(ctx context.Context) *sql.Selector { // MailboxFlagGroupBy is the group-by builder for MailboxFlag entities. type MailboxFlagGroupBy struct { - config selector - fields []string - fns []AggregateFunc - // intermediate query (i.e. traversal path). - sql *sql.Selector - path func(context.Context) (*sql.Selector, error) + build *MailboxFlagQuery } // Aggregate adds the given aggregation functions to the group-by query. @@ -455,74 +452,77 @@ func (mfgb *MailboxFlagGroupBy) Aggregate(fns ...AggregateFunc) *MailboxFlagGrou return mfgb } -// Scan applies the group-by query and scans the result into the given value. -func (mfgb *MailboxFlagGroupBy) Scan(ctx context.Context, v interface{}) error { - query, err := mfgb.path(ctx) - if err != nil { +// Scan applies the selector query and scans the result into the given value. +func (mfgb *MailboxFlagGroupBy) Scan(ctx context.Context, v any) error { + ctx = setContextOp(ctx, mfgb.build.ctx, "GroupBy") + if err := mfgb.build.prepareQuery(ctx); err != nil { return err } - mfgb.sql = query - return mfgb.sqlScan(ctx, v) + return scanWithInterceptors[*MailboxFlagQuery, *MailboxFlagGroupBy](ctx, mfgb.build, mfgb, mfgb.build.inters, v) } -func (mfgb *MailboxFlagGroupBy) sqlScan(ctx context.Context, v interface{}) error { - for _, f := range mfgb.fields { - if !mailboxflag.ValidColumn(f) { - return &ValidationError{Name: f, err: fmt.Errorf("invalid field %q for group-by", f)} +func (mfgb *MailboxFlagGroupBy) sqlScan(ctx context.Context, root *MailboxFlagQuery, v any) error { + selector := root.sqlQuery(ctx).Select() + aggregation := make([]string, 0, len(mfgb.fns)) + for _, fn := range mfgb.fns { + aggregation = append(aggregation, fn(selector)) + } + if len(selector.SelectedColumns()) == 0 { + columns := make([]string, 0, len(*mfgb.flds)+len(mfgb.fns)) + for _, f := range *mfgb.flds { + columns = append(columns, selector.C(f)) } + columns = append(columns, aggregation...) + selector.Select(columns...) } - selector := mfgb.sqlQuery() + selector.GroupBy(selector.Columns(*mfgb.flds...)...) if err := selector.Err(); err != nil { return err } rows := &sql.Rows{} query, args := selector.Query() - if err := mfgb.driver.Query(ctx, query, args, rows); err != nil { + if err := mfgb.build.driver.Query(ctx, query, args, rows); err != nil { return err } defer rows.Close() return sql.ScanSlice(rows, v) } -func (mfgb *MailboxFlagGroupBy) sqlQuery() *sql.Selector { - selector := mfgb.sql.Select() - aggregation := make([]string, 0, len(mfgb.fns)) - for _, fn := range mfgb.fns { - aggregation = append(aggregation, fn(selector)) - } - // If no columns were selected in a custom aggregation function, the default - // selection is the fields used for "group-by", and the aggregation functions. - if len(selector.SelectedColumns()) == 0 { - columns := make([]string, 0, len(mfgb.fields)+len(mfgb.fns)) - for _, f := range mfgb.fields { - columns = append(columns, selector.C(f)) - } - columns = append(columns, aggregation...) - selector.Select(columns...) - } - return selector.GroupBy(selector.Columns(mfgb.fields...)...) -} - // MailboxFlagSelect is the builder for selecting fields of MailboxFlag entities. type MailboxFlagSelect struct { *MailboxFlagQuery selector - // intermediate query (i.e. traversal path). - sql *sql.Selector +} + +// Aggregate adds the given aggregation functions to the selector query. +func (mfs *MailboxFlagSelect) Aggregate(fns ...AggregateFunc) *MailboxFlagSelect { + mfs.fns = append(mfs.fns, fns...) + return mfs } // Scan applies the selector query and scans the result into the given value. -func (mfs *MailboxFlagSelect) Scan(ctx context.Context, v interface{}) error { +func (mfs *MailboxFlagSelect) Scan(ctx context.Context, v any) error { + ctx = setContextOp(ctx, mfs.ctx, "Select") if err := mfs.prepareQuery(ctx); err != nil { return err } - mfs.sql = mfs.MailboxFlagQuery.sqlQuery(ctx) - return mfs.sqlScan(ctx, v) + return scanWithInterceptors[*MailboxFlagQuery, *MailboxFlagSelect](ctx, mfs.MailboxFlagQuery, mfs, mfs.inters, v) } -func (mfs *MailboxFlagSelect) sqlScan(ctx context.Context, v interface{}) error { +func (mfs *MailboxFlagSelect) sqlScan(ctx context.Context, root *MailboxFlagQuery, v any) error { + selector := root.sqlQuery(ctx) + aggregation := make([]string, 0, len(mfs.fns)) + for _, fn := range mfs.fns { + aggregation = append(aggregation, fn(selector)) + } + switch n := len(*mfs.selector.flds); { + case n == 0 && len(aggregation) > 0: + selector.Select(aggregation...) + case n != 0 && len(aggregation) > 0: + selector.AppendSelect(aggregation...) + } rows := &sql.Rows{} - query, args := mfs.sql.Query() + query, args := selector.Query() if err := mfs.driver.Query(ctx, query, args, rows); err != nil { return err } diff --git a/internal/db/ent/mailboxflag_update.go b/internal/db_impl/ent_db/internal/mailboxflag_update.go similarity index 63% rename from internal/db/ent/mailboxflag_update.go rename to internal/db_impl/ent_db/internal/mailboxflag_update.go index 2ce94f03..5c4d5496 100644 --- a/internal/db/ent/mailboxflag_update.go +++ b/internal/db_impl/ent_db/internal/mailboxflag_update.go @@ -1,6 +1,6 @@ // Code generated by ent, DO NOT EDIT. -package ent +package internal import ( "context" @@ -10,8 +10,8 @@ import ( "entgo.io/ent/dialect/sql" "entgo.io/ent/dialect/sql/sqlgraph" "entgo.io/ent/schema/field" - "github.com/ProtonMail/gluon/internal/db/ent/mailboxflag" - "github.com/ProtonMail/gluon/internal/db/ent/predicate" + "github.com/ProtonMail/gluon/internal/db_impl/ent_db/internal/mailboxflag" + "github.com/ProtonMail/gluon/internal/db_impl/ent_db/internal/predicate" ) // MailboxFlagUpdate is the builder for updating MailboxFlag entities. @@ -40,34 +40,7 @@ func (mfu *MailboxFlagUpdate) Mutation() *MailboxFlagMutation { // Save executes the query and returns the number of nodes affected by the update operation. func (mfu *MailboxFlagUpdate) Save(ctx context.Context) (int, error) { - var ( - err error - affected int - ) - if len(mfu.hooks) == 0 { - affected, err = mfu.sqlSave(ctx) - } else { - var mut Mutator = MutateFunc(func(ctx context.Context, m Mutation) (Value, error) { - mutation, ok := m.(*MailboxFlagMutation) - if !ok { - return nil, fmt.Errorf("unexpected mutation type %T", m) - } - mfu.mutation = mutation - affected, err = mfu.sqlSave(ctx) - mutation.done = true - return affected, err - }) - for i := len(mfu.hooks) - 1; i >= 0; i-- { - if mfu.hooks[i] == nil { - return 0, fmt.Errorf("ent: uninitialized hook (forgotten import ent/runtime?)") - } - mut = mfu.hooks[i](mut) - } - if _, err := mut.Mutate(ctx, mfu.mutation); err != nil { - return 0, err - } - } - return affected, err + return withHooks[int, MailboxFlagMutation](ctx, mfu.sqlSave, mfu.mutation, mfu.hooks) } // SaveX is like Save, but panics if an error occurs. @@ -93,16 +66,7 @@ func (mfu *MailboxFlagUpdate) ExecX(ctx context.Context) { } func (mfu *MailboxFlagUpdate) sqlSave(ctx context.Context) (n int, err error) { - _spec := &sqlgraph.UpdateSpec{ - Node: &sqlgraph.NodeSpec{ - Table: mailboxflag.Table, - Columns: mailboxflag.Columns, - ID: &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Column: mailboxflag.FieldID, - }, - }, - } + _spec := sqlgraph.NewUpdateSpec(mailboxflag.Table, mailboxflag.Columns, sqlgraph.NewFieldSpec(mailboxflag.FieldID, field.TypeInt)) if ps := mfu.mutation.predicates; len(ps) > 0 { _spec.Predicate = func(selector *sql.Selector) { for i := range ps { @@ -111,11 +75,7 @@ func (mfu *MailboxFlagUpdate) sqlSave(ctx context.Context) (n int, err error) { } } if value, ok := mfu.mutation.Value(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Value: value, - Column: mailboxflag.FieldValue, - }) + _spec.SetField(mailboxflag.FieldValue, field.TypeString, value) } if n, err = sqlgraph.UpdateNodes(ctx, mfu.driver, _spec); err != nil { if _, ok := err.(*sqlgraph.NotFoundError); ok { @@ -125,6 +85,7 @@ func (mfu *MailboxFlagUpdate) sqlSave(ctx context.Context) (n int, err error) { } return 0, err } + mfu.mutation.done = true return n, nil } @@ -147,6 +108,12 @@ func (mfuo *MailboxFlagUpdateOne) Mutation() *MailboxFlagMutation { return mfuo.mutation } +// Where appends a list predicates to the MailboxFlagUpdate builder. +func (mfuo *MailboxFlagUpdateOne) Where(ps ...predicate.MailboxFlag) *MailboxFlagUpdateOne { + mfuo.mutation.Where(ps...) + return mfuo +} + // Select allows selecting one or more fields (columns) of the returned entity. // The default is selecting all fields defined in the entity schema. func (mfuo *MailboxFlagUpdateOne) Select(field string, fields ...string) *MailboxFlagUpdateOne { @@ -156,40 +123,7 @@ func (mfuo *MailboxFlagUpdateOne) Select(field string, fields ...string) *Mailbo // Save executes the query and returns the updated MailboxFlag entity. func (mfuo *MailboxFlagUpdateOne) Save(ctx context.Context) (*MailboxFlag, error) { - var ( - err error - node *MailboxFlag - ) - if len(mfuo.hooks) == 0 { - node, err = mfuo.sqlSave(ctx) - } else { - var mut Mutator = MutateFunc(func(ctx context.Context, m Mutation) (Value, error) { - mutation, ok := m.(*MailboxFlagMutation) - if !ok { - return nil, fmt.Errorf("unexpected mutation type %T", m) - } - mfuo.mutation = mutation - node, err = mfuo.sqlSave(ctx) - mutation.done = true - return node, err - }) - for i := len(mfuo.hooks) - 1; i >= 0; i-- { - if mfuo.hooks[i] == nil { - return nil, fmt.Errorf("ent: uninitialized hook (forgotten import ent/runtime?)") - } - mut = mfuo.hooks[i](mut) - } - v, err := mut.Mutate(ctx, mfuo.mutation) - if err != nil { - return nil, err - } - nv, ok := v.(*MailboxFlag) - if !ok { - return nil, fmt.Errorf("unexpected node type %T returned from MailboxFlagMutation", v) - } - node = nv - } - return node, err + return withHooks[*MailboxFlag, MailboxFlagMutation](ctx, mfuo.sqlSave, mfuo.mutation, mfuo.hooks) } // SaveX is like Save, but panics if an error occurs. @@ -215,19 +149,10 @@ func (mfuo *MailboxFlagUpdateOne) ExecX(ctx context.Context) { } func (mfuo *MailboxFlagUpdateOne) sqlSave(ctx context.Context) (_node *MailboxFlag, err error) { - _spec := &sqlgraph.UpdateSpec{ - Node: &sqlgraph.NodeSpec{ - Table: mailboxflag.Table, - Columns: mailboxflag.Columns, - ID: &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Column: mailboxflag.FieldID, - }, - }, - } + _spec := sqlgraph.NewUpdateSpec(mailboxflag.Table, mailboxflag.Columns, sqlgraph.NewFieldSpec(mailboxflag.FieldID, field.TypeInt)) id, ok := mfuo.mutation.ID() if !ok { - return nil, &ValidationError{Name: "id", err: errors.New(`ent: missing "MailboxFlag.id" for update`)} + return nil, &ValidationError{Name: "id", err: errors.New(`internal: missing "MailboxFlag.id" for update`)} } _spec.Node.ID.Value = id if fields := mfuo.fields; len(fields) > 0 { @@ -235,7 +160,7 @@ func (mfuo *MailboxFlagUpdateOne) sqlSave(ctx context.Context) (_node *MailboxFl _spec.Node.Columns = append(_spec.Node.Columns, mailboxflag.FieldID) for _, f := range fields { if !mailboxflag.ValidColumn(f) { - return nil, &ValidationError{Name: f, err: fmt.Errorf("ent: invalid field %q for query", f)} + return nil, &ValidationError{Name: f, err: fmt.Errorf("internal: invalid field %q for query", f)} } if f != mailboxflag.FieldID { _spec.Node.Columns = append(_spec.Node.Columns, f) @@ -250,11 +175,7 @@ func (mfuo *MailboxFlagUpdateOne) sqlSave(ctx context.Context) (_node *MailboxFl } } if value, ok := mfuo.mutation.Value(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Value: value, - Column: mailboxflag.FieldValue, - }) + _spec.SetField(mailboxflag.FieldValue, field.TypeString, value) } _node = &MailboxFlag{config: mfuo.config} _spec.Assign = _node.assignValues @@ -267,5 +188,6 @@ func (mfuo *MailboxFlagUpdateOne) sqlSave(ctx context.Context) (_node *MailboxFl } return nil, err } + mfuo.mutation.done = true return _node, nil } diff --git a/internal/db/ent/mailboxpermflag.go b/internal/db_impl/ent_db/internal/mailboxpermflag.go similarity index 87% rename from internal/db/ent/mailboxpermflag.go rename to internal/db_impl/ent_db/internal/mailboxpermflag.go index 1c25ac63..a67eb63a 100644 --- a/internal/db/ent/mailboxpermflag.go +++ b/internal/db_impl/ent_db/internal/mailboxpermflag.go @@ -1,6 +1,6 @@ // Code generated by ent, DO NOT EDIT. -package ent +package internal import ( "fmt" @@ -8,7 +8,7 @@ import ( "entgo.io/ent/dialect/sql" "github.com/ProtonMail/gluon/imap" - "github.com/ProtonMail/gluon/internal/db/ent/mailboxpermflag" + "github.com/ProtonMail/gluon/internal/db_impl/ent_db/internal/mailboxpermflag" ) // MailboxPermFlag is the model entity for the MailboxPermFlag schema. @@ -22,8 +22,8 @@ type MailboxPermFlag struct { } // scanValues returns the types for scanning values from sql.Rows. -func (*MailboxPermFlag) scanValues(columns []string) ([]interface{}, error) { - values := make([]interface{}, len(columns)) +func (*MailboxPermFlag) scanValues(columns []string) ([]any, error) { + values := make([]any, len(columns)) for i := range columns { switch columns[i] { case mailboxpermflag.FieldID: @@ -41,7 +41,7 @@ func (*MailboxPermFlag) scanValues(columns []string) ([]interface{}, error) { // assignValues assigns the values that were returned from sql.Rows (after scanning) // to the MailboxPermFlag fields. -func (mpf *MailboxPermFlag) assignValues(columns []string, values []interface{}) error { +func (mpf *MailboxPermFlag) assignValues(columns []string, values []any) error { if m, n := len(values), len(columns); m < n { return fmt.Errorf("mismatch number of scan values: %d != %d", m, n) } @@ -75,7 +75,7 @@ func (mpf *MailboxPermFlag) assignValues(columns []string, values []interface{}) // Note that you need to call MailboxPermFlag.Unwrap() before calling this method if this MailboxPermFlag // was returned from a transaction, and the transaction was committed or rolled back. func (mpf *MailboxPermFlag) Update() *MailboxPermFlagUpdateOne { - return (&MailboxPermFlagClient{config: mpf.config}).UpdateOne(mpf) + return NewMailboxPermFlagClient(mpf.config).UpdateOne(mpf) } // Unwrap unwraps the MailboxPermFlag entity that was returned from a transaction after it was closed, @@ -83,7 +83,7 @@ func (mpf *MailboxPermFlag) Update() *MailboxPermFlagUpdateOne { func (mpf *MailboxPermFlag) Unwrap() *MailboxPermFlag { _tx, ok := mpf.config.driver.(*txDriver) if !ok { - panic("ent: MailboxPermFlag is not a transactional entity") + panic("internal: MailboxPermFlag is not a transactional entity") } mpf.config.driver = _tx.drv return mpf @@ -102,9 +102,3 @@ func (mpf *MailboxPermFlag) String() string { // MailboxPermFlags is a parsable slice of MailboxPermFlag. type MailboxPermFlags []*MailboxPermFlag - -func (mpf MailboxPermFlags) config(cfg config) { - for _i := range mpf { - mpf[_i].config = cfg - } -} diff --git a/internal/db/ent/mailboxpermflag/mailboxpermflag.go b/internal/db_impl/ent_db/internal/mailboxpermflag/mailboxpermflag.go similarity index 100% rename from internal/db/ent/mailboxpermflag/mailboxpermflag.go rename to internal/db_impl/ent_db/internal/mailboxpermflag/mailboxpermflag.go diff --git a/internal/db/ent/mailboxpermflag/where.go b/internal/db_impl/ent_db/internal/mailboxpermflag/where.go similarity index 56% rename from internal/db/ent/mailboxpermflag/where.go rename to internal/db_impl/ent_db/internal/mailboxpermflag/where.go index 035aec0a..6d959434 100644 --- a/internal/db/ent/mailboxpermflag/where.go +++ b/internal/db_impl/ent_db/internal/mailboxpermflag/where.go @@ -4,184 +4,122 @@ package mailboxpermflag import ( "entgo.io/ent/dialect/sql" - "github.com/ProtonMail/gluon/internal/db/ent/predicate" + "github.com/ProtonMail/gluon/internal/db_impl/ent_db/internal/predicate" ) // ID filters vertices based on their ID field. func ID(id int) predicate.MailboxPermFlag { - return predicate.MailboxPermFlag(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldID), id)) - }) + return predicate.MailboxPermFlag(sql.FieldEQ(FieldID, id)) } // IDEQ applies the EQ predicate on the ID field. func IDEQ(id int) predicate.MailboxPermFlag { - return predicate.MailboxPermFlag(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldID), id)) - }) + return predicate.MailboxPermFlag(sql.FieldEQ(FieldID, id)) } // IDNEQ applies the NEQ predicate on the ID field. func IDNEQ(id int) predicate.MailboxPermFlag { - return predicate.MailboxPermFlag(func(s *sql.Selector) { - s.Where(sql.NEQ(s.C(FieldID), id)) - }) + return predicate.MailboxPermFlag(sql.FieldNEQ(FieldID, id)) } // IDIn applies the In predicate on the ID field. func IDIn(ids ...int) predicate.MailboxPermFlag { - return predicate.MailboxPermFlag(func(s *sql.Selector) { - v := make([]interface{}, len(ids)) - for i := range v { - v[i] = ids[i] - } - s.Where(sql.In(s.C(FieldID), v...)) - }) + return predicate.MailboxPermFlag(sql.FieldIn(FieldID, ids...)) } // IDNotIn applies the NotIn predicate on the ID field. func IDNotIn(ids ...int) predicate.MailboxPermFlag { - return predicate.MailboxPermFlag(func(s *sql.Selector) { - v := make([]interface{}, len(ids)) - for i := range v { - v[i] = ids[i] - } - s.Where(sql.NotIn(s.C(FieldID), v...)) - }) + return predicate.MailboxPermFlag(sql.FieldNotIn(FieldID, ids...)) } // IDGT applies the GT predicate on the ID field. func IDGT(id int) predicate.MailboxPermFlag { - return predicate.MailboxPermFlag(func(s *sql.Selector) { - s.Where(sql.GT(s.C(FieldID), id)) - }) + return predicate.MailboxPermFlag(sql.FieldGT(FieldID, id)) } // IDGTE applies the GTE predicate on the ID field. func IDGTE(id int) predicate.MailboxPermFlag { - return predicate.MailboxPermFlag(func(s *sql.Selector) { - s.Where(sql.GTE(s.C(FieldID), id)) - }) + return predicate.MailboxPermFlag(sql.FieldGTE(FieldID, id)) } // IDLT applies the LT predicate on the ID field. func IDLT(id int) predicate.MailboxPermFlag { - return predicate.MailboxPermFlag(func(s *sql.Selector) { - s.Where(sql.LT(s.C(FieldID), id)) - }) + return predicate.MailboxPermFlag(sql.FieldLT(FieldID, id)) } // IDLTE applies the LTE predicate on the ID field. func IDLTE(id int) predicate.MailboxPermFlag { - return predicate.MailboxPermFlag(func(s *sql.Selector) { - s.Where(sql.LTE(s.C(FieldID), id)) - }) + return predicate.MailboxPermFlag(sql.FieldLTE(FieldID, id)) } // Value applies equality check predicate on the "Value" field. It's identical to ValueEQ. func Value(v string) predicate.MailboxPermFlag { - return predicate.MailboxPermFlag(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldValue), v)) - }) + return predicate.MailboxPermFlag(sql.FieldEQ(FieldValue, v)) } // ValueEQ applies the EQ predicate on the "Value" field. func ValueEQ(v string) predicate.MailboxPermFlag { - return predicate.MailboxPermFlag(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldValue), v)) - }) + return predicate.MailboxPermFlag(sql.FieldEQ(FieldValue, v)) } // ValueNEQ applies the NEQ predicate on the "Value" field. func ValueNEQ(v string) predicate.MailboxPermFlag { - return predicate.MailboxPermFlag(func(s *sql.Selector) { - s.Where(sql.NEQ(s.C(FieldValue), v)) - }) + return predicate.MailboxPermFlag(sql.FieldNEQ(FieldValue, v)) } // ValueIn applies the In predicate on the "Value" field. func ValueIn(vs ...string) predicate.MailboxPermFlag { - v := make([]interface{}, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.MailboxPermFlag(func(s *sql.Selector) { - s.Where(sql.In(s.C(FieldValue), v...)) - }) + return predicate.MailboxPermFlag(sql.FieldIn(FieldValue, vs...)) } // ValueNotIn applies the NotIn predicate on the "Value" field. func ValueNotIn(vs ...string) predicate.MailboxPermFlag { - v := make([]interface{}, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.MailboxPermFlag(func(s *sql.Selector) { - s.Where(sql.NotIn(s.C(FieldValue), v...)) - }) + return predicate.MailboxPermFlag(sql.FieldNotIn(FieldValue, vs...)) } // ValueGT applies the GT predicate on the "Value" field. func ValueGT(v string) predicate.MailboxPermFlag { - return predicate.MailboxPermFlag(func(s *sql.Selector) { - s.Where(sql.GT(s.C(FieldValue), v)) - }) + return predicate.MailboxPermFlag(sql.FieldGT(FieldValue, v)) } // ValueGTE applies the GTE predicate on the "Value" field. func ValueGTE(v string) predicate.MailboxPermFlag { - return predicate.MailboxPermFlag(func(s *sql.Selector) { - s.Where(sql.GTE(s.C(FieldValue), v)) - }) + return predicate.MailboxPermFlag(sql.FieldGTE(FieldValue, v)) } // ValueLT applies the LT predicate on the "Value" field. func ValueLT(v string) predicate.MailboxPermFlag { - return predicate.MailboxPermFlag(func(s *sql.Selector) { - s.Where(sql.LT(s.C(FieldValue), v)) - }) + return predicate.MailboxPermFlag(sql.FieldLT(FieldValue, v)) } // ValueLTE applies the LTE predicate on the "Value" field. func ValueLTE(v string) predicate.MailboxPermFlag { - return predicate.MailboxPermFlag(func(s *sql.Selector) { - s.Where(sql.LTE(s.C(FieldValue), v)) - }) + return predicate.MailboxPermFlag(sql.FieldLTE(FieldValue, v)) } // ValueContains applies the Contains predicate on the "Value" field. func ValueContains(v string) predicate.MailboxPermFlag { - return predicate.MailboxPermFlag(func(s *sql.Selector) { - s.Where(sql.Contains(s.C(FieldValue), v)) - }) + return predicate.MailboxPermFlag(sql.FieldContains(FieldValue, v)) } // ValueHasPrefix applies the HasPrefix predicate on the "Value" field. func ValueHasPrefix(v string) predicate.MailboxPermFlag { - return predicate.MailboxPermFlag(func(s *sql.Selector) { - s.Where(sql.HasPrefix(s.C(FieldValue), v)) - }) + return predicate.MailboxPermFlag(sql.FieldHasPrefix(FieldValue, v)) } // ValueHasSuffix applies the HasSuffix predicate on the "Value" field. func ValueHasSuffix(v string) predicate.MailboxPermFlag { - return predicate.MailboxPermFlag(func(s *sql.Selector) { - s.Where(sql.HasSuffix(s.C(FieldValue), v)) - }) + return predicate.MailboxPermFlag(sql.FieldHasSuffix(FieldValue, v)) } // ValueEqualFold applies the EqualFold predicate on the "Value" field. func ValueEqualFold(v string) predicate.MailboxPermFlag { - return predicate.MailboxPermFlag(func(s *sql.Selector) { - s.Where(sql.EqualFold(s.C(FieldValue), v)) - }) + return predicate.MailboxPermFlag(sql.FieldEqualFold(FieldValue, v)) } // ValueContainsFold applies the ContainsFold predicate on the "Value" field. func ValueContainsFold(v string) predicate.MailboxPermFlag { - return predicate.MailboxPermFlag(func(s *sql.Selector) { - s.Where(sql.ContainsFold(s.C(FieldValue), v)) - }) + return predicate.MailboxPermFlag(sql.FieldContainsFold(FieldValue, v)) } // And groups predicates with the AND operator between them. diff --git a/internal/db/ent/mailboxpermflag_create.go b/internal/db_impl/ent_db/internal/mailboxpermflag_create.go similarity index 74% rename from internal/db/ent/mailboxpermflag_create.go rename to internal/db_impl/ent_db/internal/mailboxpermflag_create.go index aa8089bd..8a9cbbd2 100644 --- a/internal/db/ent/mailboxpermflag_create.go +++ b/internal/db_impl/ent_db/internal/mailboxpermflag_create.go @@ -1,6 +1,6 @@ // Code generated by ent, DO NOT EDIT. -package ent +package internal import ( "context" @@ -9,7 +9,7 @@ import ( "entgo.io/ent/dialect/sql/sqlgraph" "entgo.io/ent/schema/field" - "github.com/ProtonMail/gluon/internal/db/ent/mailboxpermflag" + "github.com/ProtonMail/gluon/internal/db_impl/ent_db/internal/mailboxpermflag" ) // MailboxPermFlagCreate is the builder for creating a MailboxPermFlag entity. @@ -32,49 +32,7 @@ func (mpfc *MailboxPermFlagCreate) Mutation() *MailboxPermFlagMutation { // Save creates the MailboxPermFlag in the database. func (mpfc *MailboxPermFlagCreate) Save(ctx context.Context) (*MailboxPermFlag, error) { - var ( - err error - node *MailboxPermFlag - ) - if len(mpfc.hooks) == 0 { - if err = mpfc.check(); err != nil { - return nil, err - } - node, err = mpfc.sqlSave(ctx) - } else { - var mut Mutator = MutateFunc(func(ctx context.Context, m Mutation) (Value, error) { - mutation, ok := m.(*MailboxPermFlagMutation) - if !ok { - return nil, fmt.Errorf("unexpected mutation type %T", m) - } - if err = mpfc.check(); err != nil { - return nil, err - } - mpfc.mutation = mutation - if node, err = mpfc.sqlSave(ctx); err != nil { - return nil, err - } - mutation.id = &node.ID - mutation.done = true - return node, err - }) - for i := len(mpfc.hooks) - 1; i >= 0; i-- { - if mpfc.hooks[i] == nil { - return nil, fmt.Errorf("ent: uninitialized hook (forgotten import ent/runtime?)") - } - mut = mpfc.hooks[i](mut) - } - v, err := mut.Mutate(ctx, mpfc.mutation) - if err != nil { - return nil, err - } - nv, ok := v.(*MailboxPermFlag) - if !ok { - return nil, fmt.Errorf("unexpected node type %T returned from MailboxPermFlagMutation", v) - } - node = nv - } - return node, err + return withHooks[*MailboxPermFlag, MailboxPermFlagMutation](ctx, mpfc.sqlSave, mpfc.mutation, mpfc.hooks) } // SaveX calls Save and panics if Save returns an error. @@ -102,12 +60,15 @@ func (mpfc *MailboxPermFlagCreate) ExecX(ctx context.Context) { // check runs all checks and user-defined validators on the builder. func (mpfc *MailboxPermFlagCreate) check() error { if _, ok := mpfc.mutation.Value(); !ok { - return &ValidationError{Name: "Value", err: errors.New(`ent: missing required field "MailboxPermFlag.Value"`)} + return &ValidationError{Name: "Value", err: errors.New(`internal: missing required field "MailboxPermFlag.Value"`)} } return nil } func (mpfc *MailboxPermFlagCreate) sqlSave(ctx context.Context) (*MailboxPermFlag, error) { + if err := mpfc.check(); err != nil { + return nil, err + } _node, _spec := mpfc.createSpec() if err := sqlgraph.CreateNode(ctx, mpfc.driver, _spec); err != nil { if sqlgraph.IsConstraintError(err) { @@ -117,26 +78,18 @@ func (mpfc *MailboxPermFlagCreate) sqlSave(ctx context.Context) (*MailboxPermFla } id := _spec.ID.Value.(int64) _node.ID = int(id) + mpfc.mutation.id = &_node.ID + mpfc.mutation.done = true return _node, nil } func (mpfc *MailboxPermFlagCreate) createSpec() (*MailboxPermFlag, *sqlgraph.CreateSpec) { var ( _node = &MailboxPermFlag{config: mpfc.config} - _spec = &sqlgraph.CreateSpec{ - Table: mailboxpermflag.Table, - ID: &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Column: mailboxpermflag.FieldID, - }, - } + _spec = sqlgraph.NewCreateSpec(mailboxpermflag.Table, sqlgraph.NewFieldSpec(mailboxpermflag.FieldID, field.TypeInt)) ) if value, ok := mpfc.mutation.Value(); ok { - _spec.Fields = append(_spec.Fields, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Value: value, - Column: mailboxpermflag.FieldValue, - }) + _spec.SetField(mailboxpermflag.FieldValue, field.TypeString, value) _node.Value = value } return _node, _spec diff --git a/internal/db/ent/mailboxpermflag_delete.go b/internal/db_impl/ent_db/internal/mailboxpermflag_delete.go similarity index 63% rename from internal/db/ent/mailboxpermflag_delete.go rename to internal/db_impl/ent_db/internal/mailboxpermflag_delete.go index 0e2aa06d..2f2b48c9 100644 --- a/internal/db/ent/mailboxpermflag_delete.go +++ b/internal/db_impl/ent_db/internal/mailboxpermflag_delete.go @@ -1,16 +1,15 @@ // Code generated by ent, DO NOT EDIT. -package ent +package internal import ( "context" - "fmt" "entgo.io/ent/dialect/sql" "entgo.io/ent/dialect/sql/sqlgraph" "entgo.io/ent/schema/field" - "github.com/ProtonMail/gluon/internal/db/ent/mailboxpermflag" - "github.com/ProtonMail/gluon/internal/db/ent/predicate" + "github.com/ProtonMail/gluon/internal/db_impl/ent_db/internal/mailboxpermflag" + "github.com/ProtonMail/gluon/internal/db_impl/ent_db/internal/predicate" ) // MailboxPermFlagDelete is the builder for deleting a MailboxPermFlag entity. @@ -28,34 +27,7 @@ func (mpfd *MailboxPermFlagDelete) Where(ps ...predicate.MailboxPermFlag) *Mailb // Exec executes the deletion query and returns how many vertices were deleted. func (mpfd *MailboxPermFlagDelete) Exec(ctx context.Context) (int, error) { - var ( - err error - affected int - ) - if len(mpfd.hooks) == 0 { - affected, err = mpfd.sqlExec(ctx) - } else { - var mut Mutator = MutateFunc(func(ctx context.Context, m Mutation) (Value, error) { - mutation, ok := m.(*MailboxPermFlagMutation) - if !ok { - return nil, fmt.Errorf("unexpected mutation type %T", m) - } - mpfd.mutation = mutation - affected, err = mpfd.sqlExec(ctx) - mutation.done = true - return affected, err - }) - for i := len(mpfd.hooks) - 1; i >= 0; i-- { - if mpfd.hooks[i] == nil { - return 0, fmt.Errorf("ent: uninitialized hook (forgotten import ent/runtime?)") - } - mut = mpfd.hooks[i](mut) - } - if _, err := mut.Mutate(ctx, mpfd.mutation); err != nil { - return 0, err - } - } - return affected, err + return withHooks[int, MailboxPermFlagMutation](ctx, mpfd.sqlExec, mpfd.mutation, mpfd.hooks) } // ExecX is like Exec, but panics if an error occurs. @@ -68,15 +40,7 @@ func (mpfd *MailboxPermFlagDelete) ExecX(ctx context.Context) int { } func (mpfd *MailboxPermFlagDelete) sqlExec(ctx context.Context) (int, error) { - _spec := &sqlgraph.DeleteSpec{ - Node: &sqlgraph.NodeSpec{ - Table: mailboxpermflag.Table, - ID: &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Column: mailboxpermflag.FieldID, - }, - }, - } + _spec := sqlgraph.NewDeleteSpec(mailboxpermflag.Table, sqlgraph.NewFieldSpec(mailboxpermflag.FieldID, field.TypeInt)) if ps := mpfd.mutation.predicates; len(ps) > 0 { _spec.Predicate = func(selector *sql.Selector) { for i := range ps { @@ -88,6 +52,7 @@ func (mpfd *MailboxPermFlagDelete) sqlExec(ctx context.Context) (int, error) { if err != nil && sqlgraph.IsConstraintError(err) { err = &ConstraintError{msg: err.Error(), wrap: err} } + mpfd.mutation.done = true return affected, err } @@ -96,6 +61,12 @@ type MailboxPermFlagDeleteOne struct { mpfd *MailboxPermFlagDelete } +// Where appends a list predicates to the MailboxPermFlagDelete builder. +func (mpfdo *MailboxPermFlagDeleteOne) Where(ps ...predicate.MailboxPermFlag) *MailboxPermFlagDeleteOne { + mpfdo.mpfd.mutation.Where(ps...) + return mpfdo +} + // Exec executes the deletion query. func (mpfdo *MailboxPermFlagDeleteOne) Exec(ctx context.Context) error { n, err := mpfdo.mpfd.Exec(ctx) @@ -111,5 +82,7 @@ func (mpfdo *MailboxPermFlagDeleteOne) Exec(ctx context.Context) error { // ExecX is like Exec, but panics if an error occurs. func (mpfdo *MailboxPermFlagDeleteOne) ExecX(ctx context.Context) { - mpfdo.mpfd.ExecX(ctx) + if err := mpfdo.Exec(ctx); err != nil { + panic(err) + } } diff --git a/internal/db/ent/mailboxpermflag_query.go b/internal/db_impl/ent_db/internal/mailboxpermflag_query.go similarity index 68% rename from internal/db/ent/mailboxpermflag_query.go rename to internal/db_impl/ent_db/internal/mailboxpermflag_query.go index a5b1f983..d955eb89 100644 --- a/internal/db/ent/mailboxpermflag_query.go +++ b/internal/db_impl/ent_db/internal/mailboxpermflag_query.go @@ -1,6 +1,6 @@ // Code generated by ent, DO NOT EDIT. -package ent +package internal import ( "context" @@ -10,18 +10,16 @@ import ( "entgo.io/ent/dialect/sql" "entgo.io/ent/dialect/sql/sqlgraph" "entgo.io/ent/schema/field" - "github.com/ProtonMail/gluon/internal/db/ent/mailboxpermflag" - "github.com/ProtonMail/gluon/internal/db/ent/predicate" + "github.com/ProtonMail/gluon/internal/db_impl/ent_db/internal/mailboxpermflag" + "github.com/ProtonMail/gluon/internal/db_impl/ent_db/internal/predicate" ) // MailboxPermFlagQuery is the builder for querying MailboxPermFlag entities. type MailboxPermFlagQuery struct { config - limit *int - offset *int - unique *bool + ctx *QueryContext order []OrderFunc - fields []string + inters []Interceptor predicates []predicate.MailboxPermFlag withFKs bool // intermediate query (i.e. traversal path). @@ -35,26 +33,26 @@ func (mpfq *MailboxPermFlagQuery) Where(ps ...predicate.MailboxPermFlag) *Mailbo return mpfq } -// Limit adds a limit step to the query. +// Limit the number of records to be returned by this query. func (mpfq *MailboxPermFlagQuery) Limit(limit int) *MailboxPermFlagQuery { - mpfq.limit = &limit + mpfq.ctx.Limit = &limit return mpfq } -// Offset adds an offset step to the query. +// Offset to start from. func (mpfq *MailboxPermFlagQuery) Offset(offset int) *MailboxPermFlagQuery { - mpfq.offset = &offset + mpfq.ctx.Offset = &offset return mpfq } // Unique configures the query builder to filter duplicate records on query. // By default, unique is set to true, and can be disabled using this method. func (mpfq *MailboxPermFlagQuery) Unique(unique bool) *MailboxPermFlagQuery { - mpfq.unique = &unique + mpfq.ctx.Unique = &unique return mpfq } -// Order adds an order step to the query. +// Order specifies how the records should be ordered. func (mpfq *MailboxPermFlagQuery) Order(o ...OrderFunc) *MailboxPermFlagQuery { mpfq.order = append(mpfq.order, o...) return mpfq @@ -63,7 +61,7 @@ func (mpfq *MailboxPermFlagQuery) Order(o ...OrderFunc) *MailboxPermFlagQuery { // First returns the first MailboxPermFlag entity from the query. // Returns a *NotFoundError when no MailboxPermFlag was found. func (mpfq *MailboxPermFlagQuery) First(ctx context.Context) (*MailboxPermFlag, error) { - nodes, err := mpfq.Limit(1).All(ctx) + nodes, err := mpfq.Limit(1).All(setContextOp(ctx, mpfq.ctx, "First")) if err != nil { return nil, err } @@ -86,7 +84,7 @@ func (mpfq *MailboxPermFlagQuery) FirstX(ctx context.Context) *MailboxPermFlag { // Returns a *NotFoundError when no MailboxPermFlag ID was found. func (mpfq *MailboxPermFlagQuery) FirstID(ctx context.Context) (id int, err error) { var ids []int - if ids, err = mpfq.Limit(1).IDs(ctx); err != nil { + if ids, err = mpfq.Limit(1).IDs(setContextOp(ctx, mpfq.ctx, "FirstID")); err != nil { return } if len(ids) == 0 { @@ -109,7 +107,7 @@ func (mpfq *MailboxPermFlagQuery) FirstIDX(ctx context.Context) int { // Returns a *NotSingularError when more than one MailboxPermFlag entity is found. // Returns a *NotFoundError when no MailboxPermFlag entities are found. func (mpfq *MailboxPermFlagQuery) Only(ctx context.Context) (*MailboxPermFlag, error) { - nodes, err := mpfq.Limit(2).All(ctx) + nodes, err := mpfq.Limit(2).All(setContextOp(ctx, mpfq.ctx, "Only")) if err != nil { return nil, err } @@ -137,7 +135,7 @@ func (mpfq *MailboxPermFlagQuery) OnlyX(ctx context.Context) *MailboxPermFlag { // Returns a *NotFoundError when no entities are found. func (mpfq *MailboxPermFlagQuery) OnlyID(ctx context.Context) (id int, err error) { var ids []int - if ids, err = mpfq.Limit(2).IDs(ctx); err != nil { + if ids, err = mpfq.Limit(2).IDs(setContextOp(ctx, mpfq.ctx, "OnlyID")); err != nil { return } switch len(ids) { @@ -162,10 +160,12 @@ func (mpfq *MailboxPermFlagQuery) OnlyIDX(ctx context.Context) int { // All executes the query and returns a list of MailboxPermFlags. func (mpfq *MailboxPermFlagQuery) All(ctx context.Context) ([]*MailboxPermFlag, error) { + ctx = setContextOp(ctx, mpfq.ctx, "All") if err := mpfq.prepareQuery(ctx); err != nil { return nil, err } - return mpfq.sqlAll(ctx) + qr := querierAll[[]*MailboxPermFlag, *MailboxPermFlagQuery]() + return withInterceptors[[]*MailboxPermFlag](ctx, mpfq, qr, mpfq.inters) } // AllX is like All, but panics if an error occurs. @@ -178,9 +178,12 @@ func (mpfq *MailboxPermFlagQuery) AllX(ctx context.Context) []*MailboxPermFlag { } // IDs executes the query and returns a list of MailboxPermFlag IDs. -func (mpfq *MailboxPermFlagQuery) IDs(ctx context.Context) ([]int, error) { - var ids []int - if err := mpfq.Select(mailboxpermflag.FieldID).Scan(ctx, &ids); err != nil { +func (mpfq *MailboxPermFlagQuery) IDs(ctx context.Context) (ids []int, err error) { + if mpfq.ctx.Unique == nil && mpfq.path != nil { + mpfq.Unique(true) + } + ctx = setContextOp(ctx, mpfq.ctx, "IDs") + if err = mpfq.Select(mailboxpermflag.FieldID).Scan(ctx, &ids); err != nil { return nil, err } return ids, nil @@ -197,10 +200,11 @@ func (mpfq *MailboxPermFlagQuery) IDsX(ctx context.Context) []int { // Count returns the count of the given query. func (mpfq *MailboxPermFlagQuery) Count(ctx context.Context) (int, error) { + ctx = setContextOp(ctx, mpfq.ctx, "Count") if err := mpfq.prepareQuery(ctx); err != nil { return 0, err } - return mpfq.sqlCount(ctx) + return withInterceptors[int](ctx, mpfq, querierCount[*MailboxPermFlagQuery](), mpfq.inters) } // CountX is like Count, but panics if an error occurs. @@ -214,10 +218,15 @@ func (mpfq *MailboxPermFlagQuery) CountX(ctx context.Context) int { // Exist returns true if the query has elements in the graph. func (mpfq *MailboxPermFlagQuery) Exist(ctx context.Context) (bool, error) { - if err := mpfq.prepareQuery(ctx); err != nil { - return false, err + ctx = setContextOp(ctx, mpfq.ctx, "Exist") + switch _, err := mpfq.FirstID(ctx); { + case IsNotFound(err): + return false, nil + case err != nil: + return false, fmt.Errorf("internal: check existence: %w", err) + default: + return true, nil } - return mpfq.sqlExist(ctx) } // ExistX is like Exist, but panics if an error occurs. @@ -237,14 +246,13 @@ func (mpfq *MailboxPermFlagQuery) Clone() *MailboxPermFlagQuery { } return &MailboxPermFlagQuery{ config: mpfq.config, - limit: mpfq.limit, - offset: mpfq.offset, + ctx: mpfq.ctx.Clone(), order: append([]OrderFunc{}, mpfq.order...), + inters: append([]Interceptor{}, mpfq.inters...), predicates: append([]predicate.MailboxPermFlag{}, mpfq.predicates...), // clone intermediate query. - sql: mpfq.sql.Clone(), - path: mpfq.path, - unique: mpfq.unique, + sql: mpfq.sql.Clone(), + path: mpfq.path, } } @@ -260,19 +268,14 @@ func (mpfq *MailboxPermFlagQuery) Clone() *MailboxPermFlagQuery { // // client.MailboxPermFlag.Query(). // GroupBy(mailboxpermflag.FieldValue). -// Aggregate(ent.Count()). +// Aggregate(internal.Count()). // Scan(ctx, &v) func (mpfq *MailboxPermFlagQuery) GroupBy(field string, fields ...string) *MailboxPermFlagGroupBy { - grbuild := &MailboxPermFlagGroupBy{config: mpfq.config} - grbuild.fields = append([]string{field}, fields...) - grbuild.path = func(ctx context.Context) (prev *sql.Selector, err error) { - if err := mpfq.prepareQuery(ctx); err != nil { - return nil, err - } - return mpfq.sqlQuery(ctx), nil - } + mpfq.ctx.Fields = append([]string{field}, fields...) + grbuild := &MailboxPermFlagGroupBy{build: mpfq} + grbuild.flds = &mpfq.ctx.Fields grbuild.label = mailboxpermflag.Label - grbuild.flds, grbuild.scan = &grbuild.fields, grbuild.Scan + grbuild.scan = grbuild.Scan return grbuild } @@ -289,17 +292,32 @@ func (mpfq *MailboxPermFlagQuery) GroupBy(field string, fields ...string) *Mailb // Select(mailboxpermflag.FieldValue). // Scan(ctx, &v) func (mpfq *MailboxPermFlagQuery) Select(fields ...string) *MailboxPermFlagSelect { - mpfq.fields = append(mpfq.fields, fields...) - selbuild := &MailboxPermFlagSelect{MailboxPermFlagQuery: mpfq} - selbuild.label = mailboxpermflag.Label - selbuild.flds, selbuild.scan = &mpfq.fields, selbuild.Scan - return selbuild + mpfq.ctx.Fields = append(mpfq.ctx.Fields, fields...) + sbuild := &MailboxPermFlagSelect{MailboxPermFlagQuery: mpfq} + sbuild.label = mailboxpermflag.Label + sbuild.flds, sbuild.scan = &mpfq.ctx.Fields, sbuild.Scan + return sbuild +} + +// Aggregate returns a MailboxPermFlagSelect configured with the given aggregations. +func (mpfq *MailboxPermFlagQuery) Aggregate(fns ...AggregateFunc) *MailboxPermFlagSelect { + return mpfq.Select().Aggregate(fns...) } func (mpfq *MailboxPermFlagQuery) prepareQuery(ctx context.Context) error { - for _, f := range mpfq.fields { + for _, inter := range mpfq.inters { + if inter == nil { + return fmt.Errorf("internal: uninitialized interceptor (forgotten import internal/runtime?)") + } + if trv, ok := inter.(Traverser); ok { + if err := trv.Traverse(ctx, mpfq); err != nil { + return err + } + } + } + for _, f := range mpfq.ctx.Fields { if !mailboxpermflag.ValidColumn(f) { - return &ValidationError{Name: f, err: fmt.Errorf("ent: invalid field %q for query", f)} + return &ValidationError{Name: f, err: fmt.Errorf("internal: invalid field %q for query", f)} } } if mpfq.path != nil { @@ -321,10 +339,10 @@ func (mpfq *MailboxPermFlagQuery) sqlAll(ctx context.Context, hooks ...queryHook if withFKs { _spec.Node.Columns = append(_spec.Node.Columns, mailboxpermflag.ForeignKeys...) } - _spec.ScanValues = func(columns []string) ([]interface{}, error) { + _spec.ScanValues = func(columns []string) ([]any, error) { return (*MailboxPermFlag).scanValues(nil, columns) } - _spec.Assign = func(columns []string, values []interface{}) error { + _spec.Assign = func(columns []string, values []any) error { node := &MailboxPermFlag{config: mpfq.config} nodes = append(nodes, node) return node.assignValues(columns, values) @@ -343,38 +361,22 @@ func (mpfq *MailboxPermFlagQuery) sqlAll(ctx context.Context, hooks ...queryHook func (mpfq *MailboxPermFlagQuery) sqlCount(ctx context.Context) (int, error) { _spec := mpfq.querySpec() - _spec.Node.Columns = mpfq.fields - if len(mpfq.fields) > 0 { - _spec.Unique = mpfq.unique != nil && *mpfq.unique + _spec.Node.Columns = mpfq.ctx.Fields + if len(mpfq.ctx.Fields) > 0 { + _spec.Unique = mpfq.ctx.Unique != nil && *mpfq.ctx.Unique } return sqlgraph.CountNodes(ctx, mpfq.driver, _spec) } -func (mpfq *MailboxPermFlagQuery) sqlExist(ctx context.Context) (bool, error) { - n, err := mpfq.sqlCount(ctx) - if err != nil { - return false, fmt.Errorf("ent: check existence: %w", err) - } - return n > 0, nil -} - func (mpfq *MailboxPermFlagQuery) querySpec() *sqlgraph.QuerySpec { - _spec := &sqlgraph.QuerySpec{ - Node: &sqlgraph.NodeSpec{ - Table: mailboxpermflag.Table, - Columns: mailboxpermflag.Columns, - ID: &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Column: mailboxpermflag.FieldID, - }, - }, - From: mpfq.sql, - Unique: true, - } - if unique := mpfq.unique; unique != nil { + _spec := sqlgraph.NewQuerySpec(mailboxpermflag.Table, mailboxpermflag.Columns, sqlgraph.NewFieldSpec(mailboxpermflag.FieldID, field.TypeInt)) + _spec.From = mpfq.sql + if unique := mpfq.ctx.Unique; unique != nil { _spec.Unique = *unique + } else if mpfq.path != nil { + _spec.Unique = true } - if fields := mpfq.fields; len(fields) > 0 { + if fields := mpfq.ctx.Fields; len(fields) > 0 { _spec.Node.Columns = make([]string, 0, len(fields)) _spec.Node.Columns = append(_spec.Node.Columns, mailboxpermflag.FieldID) for i := range fields { @@ -390,10 +392,10 @@ func (mpfq *MailboxPermFlagQuery) querySpec() *sqlgraph.QuerySpec { } } } - if limit := mpfq.limit; limit != nil { + if limit := mpfq.ctx.Limit; limit != nil { _spec.Limit = *limit } - if offset := mpfq.offset; offset != nil { + if offset := mpfq.ctx.Offset; offset != nil { _spec.Offset = *offset } if ps := mpfq.order; len(ps) > 0 { @@ -409,7 +411,7 @@ func (mpfq *MailboxPermFlagQuery) querySpec() *sqlgraph.QuerySpec { func (mpfq *MailboxPermFlagQuery) sqlQuery(ctx context.Context) *sql.Selector { builder := sql.Dialect(mpfq.driver.Dialect()) t1 := builder.Table(mailboxpermflag.Table) - columns := mpfq.fields + columns := mpfq.ctx.Fields if len(columns) == 0 { columns = mailboxpermflag.Columns } @@ -418,7 +420,7 @@ func (mpfq *MailboxPermFlagQuery) sqlQuery(ctx context.Context) *sql.Selector { selector = mpfq.sql selector.Select(selector.Columns(columns...)...) } - if mpfq.unique != nil && *mpfq.unique { + if mpfq.ctx.Unique != nil && *mpfq.ctx.Unique { selector.Distinct() } for _, p := range mpfq.predicates { @@ -427,12 +429,12 @@ func (mpfq *MailboxPermFlagQuery) sqlQuery(ctx context.Context) *sql.Selector { for _, p := range mpfq.order { p(selector) } - if offset := mpfq.offset; offset != nil { + if offset := mpfq.ctx.Offset; offset != nil { // limit is mandatory for offset clause. We start // with default value, and override it below if needed. selector.Offset(*offset).Limit(math.MaxInt32) } - if limit := mpfq.limit; limit != nil { + if limit := mpfq.ctx.Limit; limit != nil { selector.Limit(*limit) } return selector @@ -440,13 +442,8 @@ func (mpfq *MailboxPermFlagQuery) sqlQuery(ctx context.Context) *sql.Selector { // MailboxPermFlagGroupBy is the group-by builder for MailboxPermFlag entities. type MailboxPermFlagGroupBy struct { - config selector - fields []string - fns []AggregateFunc - // intermediate query (i.e. traversal path). - sql *sql.Selector - path func(context.Context) (*sql.Selector, error) + build *MailboxPermFlagQuery } // Aggregate adds the given aggregation functions to the group-by query. @@ -455,74 +452,77 @@ func (mpfgb *MailboxPermFlagGroupBy) Aggregate(fns ...AggregateFunc) *MailboxPer return mpfgb } -// Scan applies the group-by query and scans the result into the given value. -func (mpfgb *MailboxPermFlagGroupBy) Scan(ctx context.Context, v interface{}) error { - query, err := mpfgb.path(ctx) - if err != nil { +// Scan applies the selector query and scans the result into the given value. +func (mpfgb *MailboxPermFlagGroupBy) Scan(ctx context.Context, v any) error { + ctx = setContextOp(ctx, mpfgb.build.ctx, "GroupBy") + if err := mpfgb.build.prepareQuery(ctx); err != nil { return err } - mpfgb.sql = query - return mpfgb.sqlScan(ctx, v) + return scanWithInterceptors[*MailboxPermFlagQuery, *MailboxPermFlagGroupBy](ctx, mpfgb.build, mpfgb, mpfgb.build.inters, v) } -func (mpfgb *MailboxPermFlagGroupBy) sqlScan(ctx context.Context, v interface{}) error { - for _, f := range mpfgb.fields { - if !mailboxpermflag.ValidColumn(f) { - return &ValidationError{Name: f, err: fmt.Errorf("invalid field %q for group-by", f)} +func (mpfgb *MailboxPermFlagGroupBy) sqlScan(ctx context.Context, root *MailboxPermFlagQuery, v any) error { + selector := root.sqlQuery(ctx).Select() + aggregation := make([]string, 0, len(mpfgb.fns)) + for _, fn := range mpfgb.fns { + aggregation = append(aggregation, fn(selector)) + } + if len(selector.SelectedColumns()) == 0 { + columns := make([]string, 0, len(*mpfgb.flds)+len(mpfgb.fns)) + for _, f := range *mpfgb.flds { + columns = append(columns, selector.C(f)) } + columns = append(columns, aggregation...) + selector.Select(columns...) } - selector := mpfgb.sqlQuery() + selector.GroupBy(selector.Columns(*mpfgb.flds...)...) if err := selector.Err(); err != nil { return err } rows := &sql.Rows{} query, args := selector.Query() - if err := mpfgb.driver.Query(ctx, query, args, rows); err != nil { + if err := mpfgb.build.driver.Query(ctx, query, args, rows); err != nil { return err } defer rows.Close() return sql.ScanSlice(rows, v) } -func (mpfgb *MailboxPermFlagGroupBy) sqlQuery() *sql.Selector { - selector := mpfgb.sql.Select() - aggregation := make([]string, 0, len(mpfgb.fns)) - for _, fn := range mpfgb.fns { - aggregation = append(aggregation, fn(selector)) - } - // If no columns were selected in a custom aggregation function, the default - // selection is the fields used for "group-by", and the aggregation functions. - if len(selector.SelectedColumns()) == 0 { - columns := make([]string, 0, len(mpfgb.fields)+len(mpfgb.fns)) - for _, f := range mpfgb.fields { - columns = append(columns, selector.C(f)) - } - columns = append(columns, aggregation...) - selector.Select(columns...) - } - return selector.GroupBy(selector.Columns(mpfgb.fields...)...) -} - // MailboxPermFlagSelect is the builder for selecting fields of MailboxPermFlag entities. type MailboxPermFlagSelect struct { *MailboxPermFlagQuery selector - // intermediate query (i.e. traversal path). - sql *sql.Selector +} + +// Aggregate adds the given aggregation functions to the selector query. +func (mpfs *MailboxPermFlagSelect) Aggregate(fns ...AggregateFunc) *MailboxPermFlagSelect { + mpfs.fns = append(mpfs.fns, fns...) + return mpfs } // Scan applies the selector query and scans the result into the given value. -func (mpfs *MailboxPermFlagSelect) Scan(ctx context.Context, v interface{}) error { +func (mpfs *MailboxPermFlagSelect) Scan(ctx context.Context, v any) error { + ctx = setContextOp(ctx, mpfs.ctx, "Select") if err := mpfs.prepareQuery(ctx); err != nil { return err } - mpfs.sql = mpfs.MailboxPermFlagQuery.sqlQuery(ctx) - return mpfs.sqlScan(ctx, v) + return scanWithInterceptors[*MailboxPermFlagQuery, *MailboxPermFlagSelect](ctx, mpfs.MailboxPermFlagQuery, mpfs, mpfs.inters, v) } -func (mpfs *MailboxPermFlagSelect) sqlScan(ctx context.Context, v interface{}) error { +func (mpfs *MailboxPermFlagSelect) sqlScan(ctx context.Context, root *MailboxPermFlagQuery, v any) error { + selector := root.sqlQuery(ctx) + aggregation := make([]string, 0, len(mpfs.fns)) + for _, fn := range mpfs.fns { + aggregation = append(aggregation, fn(selector)) + } + switch n := len(*mpfs.selector.flds); { + case n == 0 && len(aggregation) > 0: + selector.Select(aggregation...) + case n != 0 && len(aggregation) > 0: + selector.AppendSelect(aggregation...) + } rows := &sql.Rows{} - query, args := mpfs.sql.Query() + query, args := selector.Query() if err := mpfs.driver.Query(ctx, query, args, rows); err != nil { return err } diff --git a/internal/db/ent/mailboxpermflag_update.go b/internal/db_impl/ent_db/internal/mailboxpermflag_update.go similarity index 64% rename from internal/db/ent/mailboxpermflag_update.go rename to internal/db_impl/ent_db/internal/mailboxpermflag_update.go index bbfd2583..d5a13206 100644 --- a/internal/db/ent/mailboxpermflag_update.go +++ b/internal/db_impl/ent_db/internal/mailboxpermflag_update.go @@ -1,6 +1,6 @@ // Code generated by ent, DO NOT EDIT. -package ent +package internal import ( "context" @@ -10,8 +10,8 @@ import ( "entgo.io/ent/dialect/sql" "entgo.io/ent/dialect/sql/sqlgraph" "entgo.io/ent/schema/field" - "github.com/ProtonMail/gluon/internal/db/ent/mailboxpermflag" - "github.com/ProtonMail/gluon/internal/db/ent/predicate" + "github.com/ProtonMail/gluon/internal/db_impl/ent_db/internal/mailboxpermflag" + "github.com/ProtonMail/gluon/internal/db_impl/ent_db/internal/predicate" ) // MailboxPermFlagUpdate is the builder for updating MailboxPermFlag entities. @@ -40,34 +40,7 @@ func (mpfu *MailboxPermFlagUpdate) Mutation() *MailboxPermFlagMutation { // Save executes the query and returns the number of nodes affected by the update operation. func (mpfu *MailboxPermFlagUpdate) Save(ctx context.Context) (int, error) { - var ( - err error - affected int - ) - if len(mpfu.hooks) == 0 { - affected, err = mpfu.sqlSave(ctx) - } else { - var mut Mutator = MutateFunc(func(ctx context.Context, m Mutation) (Value, error) { - mutation, ok := m.(*MailboxPermFlagMutation) - if !ok { - return nil, fmt.Errorf("unexpected mutation type %T", m) - } - mpfu.mutation = mutation - affected, err = mpfu.sqlSave(ctx) - mutation.done = true - return affected, err - }) - for i := len(mpfu.hooks) - 1; i >= 0; i-- { - if mpfu.hooks[i] == nil { - return 0, fmt.Errorf("ent: uninitialized hook (forgotten import ent/runtime?)") - } - mut = mpfu.hooks[i](mut) - } - if _, err := mut.Mutate(ctx, mpfu.mutation); err != nil { - return 0, err - } - } - return affected, err + return withHooks[int, MailboxPermFlagMutation](ctx, mpfu.sqlSave, mpfu.mutation, mpfu.hooks) } // SaveX is like Save, but panics if an error occurs. @@ -93,16 +66,7 @@ func (mpfu *MailboxPermFlagUpdate) ExecX(ctx context.Context) { } func (mpfu *MailboxPermFlagUpdate) sqlSave(ctx context.Context) (n int, err error) { - _spec := &sqlgraph.UpdateSpec{ - Node: &sqlgraph.NodeSpec{ - Table: mailboxpermflag.Table, - Columns: mailboxpermflag.Columns, - ID: &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Column: mailboxpermflag.FieldID, - }, - }, - } + _spec := sqlgraph.NewUpdateSpec(mailboxpermflag.Table, mailboxpermflag.Columns, sqlgraph.NewFieldSpec(mailboxpermflag.FieldID, field.TypeInt)) if ps := mpfu.mutation.predicates; len(ps) > 0 { _spec.Predicate = func(selector *sql.Selector) { for i := range ps { @@ -111,11 +75,7 @@ func (mpfu *MailboxPermFlagUpdate) sqlSave(ctx context.Context) (n int, err erro } } if value, ok := mpfu.mutation.Value(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Value: value, - Column: mailboxpermflag.FieldValue, - }) + _spec.SetField(mailboxpermflag.FieldValue, field.TypeString, value) } if n, err = sqlgraph.UpdateNodes(ctx, mpfu.driver, _spec); err != nil { if _, ok := err.(*sqlgraph.NotFoundError); ok { @@ -125,6 +85,7 @@ func (mpfu *MailboxPermFlagUpdate) sqlSave(ctx context.Context) (n int, err erro } return 0, err } + mpfu.mutation.done = true return n, nil } @@ -147,6 +108,12 @@ func (mpfuo *MailboxPermFlagUpdateOne) Mutation() *MailboxPermFlagMutation { return mpfuo.mutation } +// Where appends a list predicates to the MailboxPermFlagUpdate builder. +func (mpfuo *MailboxPermFlagUpdateOne) Where(ps ...predicate.MailboxPermFlag) *MailboxPermFlagUpdateOne { + mpfuo.mutation.Where(ps...) + return mpfuo +} + // Select allows selecting one or more fields (columns) of the returned entity. // The default is selecting all fields defined in the entity schema. func (mpfuo *MailboxPermFlagUpdateOne) Select(field string, fields ...string) *MailboxPermFlagUpdateOne { @@ -156,40 +123,7 @@ func (mpfuo *MailboxPermFlagUpdateOne) Select(field string, fields ...string) *M // Save executes the query and returns the updated MailboxPermFlag entity. func (mpfuo *MailboxPermFlagUpdateOne) Save(ctx context.Context) (*MailboxPermFlag, error) { - var ( - err error - node *MailboxPermFlag - ) - if len(mpfuo.hooks) == 0 { - node, err = mpfuo.sqlSave(ctx) - } else { - var mut Mutator = MutateFunc(func(ctx context.Context, m Mutation) (Value, error) { - mutation, ok := m.(*MailboxPermFlagMutation) - if !ok { - return nil, fmt.Errorf("unexpected mutation type %T", m) - } - mpfuo.mutation = mutation - node, err = mpfuo.sqlSave(ctx) - mutation.done = true - return node, err - }) - for i := len(mpfuo.hooks) - 1; i >= 0; i-- { - if mpfuo.hooks[i] == nil { - return nil, fmt.Errorf("ent: uninitialized hook (forgotten import ent/runtime?)") - } - mut = mpfuo.hooks[i](mut) - } - v, err := mut.Mutate(ctx, mpfuo.mutation) - if err != nil { - return nil, err - } - nv, ok := v.(*MailboxPermFlag) - if !ok { - return nil, fmt.Errorf("unexpected node type %T returned from MailboxPermFlagMutation", v) - } - node = nv - } - return node, err + return withHooks[*MailboxPermFlag, MailboxPermFlagMutation](ctx, mpfuo.sqlSave, mpfuo.mutation, mpfuo.hooks) } // SaveX is like Save, but panics if an error occurs. @@ -215,19 +149,10 @@ func (mpfuo *MailboxPermFlagUpdateOne) ExecX(ctx context.Context) { } func (mpfuo *MailboxPermFlagUpdateOne) sqlSave(ctx context.Context) (_node *MailboxPermFlag, err error) { - _spec := &sqlgraph.UpdateSpec{ - Node: &sqlgraph.NodeSpec{ - Table: mailboxpermflag.Table, - Columns: mailboxpermflag.Columns, - ID: &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Column: mailboxpermflag.FieldID, - }, - }, - } + _spec := sqlgraph.NewUpdateSpec(mailboxpermflag.Table, mailboxpermflag.Columns, sqlgraph.NewFieldSpec(mailboxpermflag.FieldID, field.TypeInt)) id, ok := mpfuo.mutation.ID() if !ok { - return nil, &ValidationError{Name: "id", err: errors.New(`ent: missing "MailboxPermFlag.id" for update`)} + return nil, &ValidationError{Name: "id", err: errors.New(`internal: missing "MailboxPermFlag.id" for update`)} } _spec.Node.ID.Value = id if fields := mpfuo.fields; len(fields) > 0 { @@ -235,7 +160,7 @@ func (mpfuo *MailboxPermFlagUpdateOne) sqlSave(ctx context.Context) (_node *Mail _spec.Node.Columns = append(_spec.Node.Columns, mailboxpermflag.FieldID) for _, f := range fields { if !mailboxpermflag.ValidColumn(f) { - return nil, &ValidationError{Name: f, err: fmt.Errorf("ent: invalid field %q for query", f)} + return nil, &ValidationError{Name: f, err: fmt.Errorf("internal: invalid field %q for query", f)} } if f != mailboxpermflag.FieldID { _spec.Node.Columns = append(_spec.Node.Columns, f) @@ -250,11 +175,7 @@ func (mpfuo *MailboxPermFlagUpdateOne) sqlSave(ctx context.Context) (_node *Mail } } if value, ok := mpfuo.mutation.Value(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Value: value, - Column: mailboxpermflag.FieldValue, - }) + _spec.SetField(mailboxpermflag.FieldValue, field.TypeString, value) } _node = &MailboxPermFlag{config: mpfuo.config} _spec.Assign = _node.assignValues @@ -267,5 +188,6 @@ func (mpfuo *MailboxPermFlagUpdateOne) sqlSave(ctx context.Context) (_node *Mail } return nil, err } + mpfuo.mutation.done = true return _node, nil } diff --git a/internal/db/ent/message.go b/internal/db_impl/ent_db/internal/message.go similarity index 91% rename from internal/db/ent/message.go rename to internal/db_impl/ent_db/internal/message.go index 3c8ac699..5b0f483e 100644 --- a/internal/db/ent/message.go +++ b/internal/db_impl/ent_db/internal/message.go @@ -1,6 +1,6 @@ // Code generated by ent, DO NOT EDIT. -package ent +package internal import ( "fmt" @@ -9,7 +9,7 @@ import ( "entgo.io/ent/dialect/sql" "github.com/ProtonMail/gluon/imap" - "github.com/ProtonMail/gluon/internal/db/ent/message" + "github.com/ProtonMail/gluon/internal/db_impl/ent_db/internal/message" ) // Message is the model entity for the Message schema. @@ -66,8 +66,8 @@ func (e MessageEdges) UIDsOrErr() ([]*UID, error) { } // scanValues returns the types for scanning values from sql.Rows. -func (*Message) scanValues(columns []string) ([]interface{}, error) { - values := make([]interface{}, len(columns)) +func (*Message) scanValues(columns []string) ([]any, error) { + values := make([]any, len(columns)) for i := range columns { switch columns[i] { case message.FieldID: @@ -89,7 +89,7 @@ func (*Message) scanValues(columns []string) ([]interface{}, error) { // assignValues assigns the values that were returned from sql.Rows (after scanning) // to the Message fields. -func (m *Message) assignValues(columns []string, values []interface{}) error { +func (m *Message) assignValues(columns []string, values []any) error { if m, n := len(values), len(columns); m < n { return fmt.Errorf("mismatch number of scan values: %d != %d", m, n) } @@ -150,19 +150,19 @@ func (m *Message) assignValues(columns []string, values []interface{}) error { // QueryFlags queries the "flags" edge of the Message entity. func (m *Message) QueryFlags() *MessageFlagQuery { - return (&MessageClient{config: m.config}).QueryFlags(m) + return NewMessageClient(m.config).QueryFlags(m) } // QueryUIDs queries the "UIDs" edge of the Message entity. func (m *Message) QueryUIDs() *UIDQuery { - return (&MessageClient{config: m.config}).QueryUIDs(m) + return NewMessageClient(m.config).QueryUIDs(m) } // Update returns a builder for updating this Message. // Note that you need to call Message.Unwrap() before calling this method if this Message // was returned from a transaction, and the transaction was committed or rolled back. func (m *Message) Update() *MessageUpdateOne { - return (&MessageClient{config: m.config}).UpdateOne(m) + return NewMessageClient(m.config).UpdateOne(m) } // Unwrap unwraps the Message entity that was returned from a transaction after it was closed, @@ -170,7 +170,7 @@ func (m *Message) Update() *MessageUpdateOne { func (m *Message) Unwrap() *Message { _tx, ok := m.config.driver.(*txDriver) if !ok { - panic("ent: Message is not a transactional entity") + panic("internal: Message is not a transactional entity") } m.config.driver = _tx.drv return m @@ -207,9 +207,3 @@ func (m *Message) String() string { // Messages is a parsable slice of Message. type Messages []*Message - -func (m Messages) config(cfg config) { - for _i := range m { - m[_i].config = cfg - } -} diff --git a/internal/db/ent/message/message.go b/internal/db_impl/ent_db/internal/message/message.go similarity index 100% rename from internal/db/ent/message/message.go rename to internal/db_impl/ent_db/internal/message/message.go diff --git a/internal/db/ent/message/where.go b/internal/db_impl/ent_db/internal/message/where.go similarity index 58% rename from internal/db/ent/message/where.go rename to internal/db_impl/ent_db/internal/message/where.go index 28c524f8..c2cc30ef 100644 --- a/internal/db/ent/message/where.go +++ b/internal/db_impl/ent_db/internal/message/where.go @@ -8,691 +8,467 @@ import ( "entgo.io/ent/dialect/sql" "entgo.io/ent/dialect/sql/sqlgraph" "github.com/ProtonMail/gluon/imap" - "github.com/ProtonMail/gluon/internal/db/ent/predicate" + "github.com/ProtonMail/gluon/internal/db_impl/ent_db/internal/predicate" ) // ID filters vertices based on their ID field. func ID(id imap.InternalMessageID) predicate.Message { - return predicate.Message(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldID), id)) - }) + return predicate.Message(sql.FieldEQ(FieldID, id)) } // IDEQ applies the EQ predicate on the ID field. func IDEQ(id imap.InternalMessageID) predicate.Message { - return predicate.Message(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldID), id)) - }) + return predicate.Message(sql.FieldEQ(FieldID, id)) } // IDNEQ applies the NEQ predicate on the ID field. func IDNEQ(id imap.InternalMessageID) predicate.Message { - return predicate.Message(func(s *sql.Selector) { - s.Where(sql.NEQ(s.C(FieldID), id)) - }) + return predicate.Message(sql.FieldNEQ(FieldID, id)) } // IDIn applies the In predicate on the ID field. func IDIn(ids ...imap.InternalMessageID) predicate.Message { - return predicate.Message(func(s *sql.Selector) { - v := make([]interface{}, len(ids)) - for i := range v { - v[i] = ids[i] - } - s.Where(sql.In(s.C(FieldID), v...)) - }) + return predicate.Message(sql.FieldIn(FieldID, ids...)) } // IDNotIn applies the NotIn predicate on the ID field. func IDNotIn(ids ...imap.InternalMessageID) predicate.Message { - return predicate.Message(func(s *sql.Selector) { - v := make([]interface{}, len(ids)) - for i := range v { - v[i] = ids[i] - } - s.Where(sql.NotIn(s.C(FieldID), v...)) - }) + return predicate.Message(sql.FieldNotIn(FieldID, ids...)) } // IDGT applies the GT predicate on the ID field. func IDGT(id imap.InternalMessageID) predicate.Message { - return predicate.Message(func(s *sql.Selector) { - s.Where(sql.GT(s.C(FieldID), id)) - }) + return predicate.Message(sql.FieldGT(FieldID, id)) } // IDGTE applies the GTE predicate on the ID field. func IDGTE(id imap.InternalMessageID) predicate.Message { - return predicate.Message(func(s *sql.Selector) { - s.Where(sql.GTE(s.C(FieldID), id)) - }) + return predicate.Message(sql.FieldGTE(FieldID, id)) } // IDLT applies the LT predicate on the ID field. func IDLT(id imap.InternalMessageID) predicate.Message { - return predicate.Message(func(s *sql.Selector) { - s.Where(sql.LT(s.C(FieldID), id)) - }) + return predicate.Message(sql.FieldLT(FieldID, id)) } // IDLTE applies the LTE predicate on the ID field. func IDLTE(id imap.InternalMessageID) predicate.Message { - return predicate.Message(func(s *sql.Selector) { - s.Where(sql.LTE(s.C(FieldID), id)) - }) + return predicate.Message(sql.FieldLTE(FieldID, id)) } // RemoteID applies equality check predicate on the "RemoteID" field. It's identical to RemoteIDEQ. func RemoteID(v imap.MessageID) predicate.Message { vc := string(v) - return predicate.Message(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldRemoteID), vc)) - }) + return predicate.Message(sql.FieldEQ(FieldRemoteID, vc)) } // Date applies equality check predicate on the "Date" field. It's identical to DateEQ. func Date(v time.Time) predicate.Message { - return predicate.Message(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldDate), v)) - }) + return predicate.Message(sql.FieldEQ(FieldDate, v)) } // Size applies equality check predicate on the "Size" field. It's identical to SizeEQ. func Size(v int) predicate.Message { - return predicate.Message(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldSize), v)) - }) + return predicate.Message(sql.FieldEQ(FieldSize, v)) } // Body applies equality check predicate on the "Body" field. It's identical to BodyEQ. func Body(v string) predicate.Message { - return predicate.Message(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldBody), v)) - }) + return predicate.Message(sql.FieldEQ(FieldBody, v)) } // BodyStructure applies equality check predicate on the "BodyStructure" field. It's identical to BodyStructureEQ. func BodyStructure(v string) predicate.Message { - return predicate.Message(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldBodyStructure), v)) - }) + return predicate.Message(sql.FieldEQ(FieldBodyStructure, v)) } // Envelope applies equality check predicate on the "Envelope" field. It's identical to EnvelopeEQ. func Envelope(v string) predicate.Message { - return predicate.Message(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldEnvelope), v)) - }) + return predicate.Message(sql.FieldEQ(FieldEnvelope, v)) } // Deleted applies equality check predicate on the "Deleted" field. It's identical to DeletedEQ. func Deleted(v bool) predicate.Message { - return predicate.Message(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldDeleted), v)) - }) + return predicate.Message(sql.FieldEQ(FieldDeleted, v)) } // RemoteIDEQ applies the EQ predicate on the "RemoteID" field. func RemoteIDEQ(v imap.MessageID) predicate.Message { vc := string(v) - return predicate.Message(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldRemoteID), vc)) - }) + return predicate.Message(sql.FieldEQ(FieldRemoteID, vc)) } // RemoteIDNEQ applies the NEQ predicate on the "RemoteID" field. func RemoteIDNEQ(v imap.MessageID) predicate.Message { vc := string(v) - return predicate.Message(func(s *sql.Selector) { - s.Where(sql.NEQ(s.C(FieldRemoteID), vc)) - }) + return predicate.Message(sql.FieldNEQ(FieldRemoteID, vc)) } // RemoteIDIn applies the In predicate on the "RemoteID" field. func RemoteIDIn(vs ...imap.MessageID) predicate.Message { - v := make([]interface{}, len(vs)) + v := make([]any, len(vs)) for i := range v { v[i] = string(vs[i]) } - return predicate.Message(func(s *sql.Selector) { - s.Where(sql.In(s.C(FieldRemoteID), v...)) - }) + return predicate.Message(sql.FieldIn(FieldRemoteID, v...)) } // RemoteIDNotIn applies the NotIn predicate on the "RemoteID" field. func RemoteIDNotIn(vs ...imap.MessageID) predicate.Message { - v := make([]interface{}, len(vs)) + v := make([]any, len(vs)) for i := range v { v[i] = string(vs[i]) } - return predicate.Message(func(s *sql.Selector) { - s.Where(sql.NotIn(s.C(FieldRemoteID), v...)) - }) + return predicate.Message(sql.FieldNotIn(FieldRemoteID, v...)) } // RemoteIDGT applies the GT predicate on the "RemoteID" field. func RemoteIDGT(v imap.MessageID) predicate.Message { vc := string(v) - return predicate.Message(func(s *sql.Selector) { - s.Where(sql.GT(s.C(FieldRemoteID), vc)) - }) + return predicate.Message(sql.FieldGT(FieldRemoteID, vc)) } // RemoteIDGTE applies the GTE predicate on the "RemoteID" field. func RemoteIDGTE(v imap.MessageID) predicate.Message { vc := string(v) - return predicate.Message(func(s *sql.Selector) { - s.Where(sql.GTE(s.C(FieldRemoteID), vc)) - }) + return predicate.Message(sql.FieldGTE(FieldRemoteID, vc)) } // RemoteIDLT applies the LT predicate on the "RemoteID" field. func RemoteIDLT(v imap.MessageID) predicate.Message { vc := string(v) - return predicate.Message(func(s *sql.Selector) { - s.Where(sql.LT(s.C(FieldRemoteID), vc)) - }) + return predicate.Message(sql.FieldLT(FieldRemoteID, vc)) } // RemoteIDLTE applies the LTE predicate on the "RemoteID" field. func RemoteIDLTE(v imap.MessageID) predicate.Message { vc := string(v) - return predicate.Message(func(s *sql.Selector) { - s.Where(sql.LTE(s.C(FieldRemoteID), vc)) - }) + return predicate.Message(sql.FieldLTE(FieldRemoteID, vc)) } // RemoteIDContains applies the Contains predicate on the "RemoteID" field. func RemoteIDContains(v imap.MessageID) predicate.Message { vc := string(v) - return predicate.Message(func(s *sql.Selector) { - s.Where(sql.Contains(s.C(FieldRemoteID), vc)) - }) + return predicate.Message(sql.FieldContains(FieldRemoteID, vc)) } // RemoteIDHasPrefix applies the HasPrefix predicate on the "RemoteID" field. func RemoteIDHasPrefix(v imap.MessageID) predicate.Message { vc := string(v) - return predicate.Message(func(s *sql.Selector) { - s.Where(sql.HasPrefix(s.C(FieldRemoteID), vc)) - }) + return predicate.Message(sql.FieldHasPrefix(FieldRemoteID, vc)) } // RemoteIDHasSuffix applies the HasSuffix predicate on the "RemoteID" field. func RemoteIDHasSuffix(v imap.MessageID) predicate.Message { vc := string(v) - return predicate.Message(func(s *sql.Selector) { - s.Where(sql.HasSuffix(s.C(FieldRemoteID), vc)) - }) + return predicate.Message(sql.FieldHasSuffix(FieldRemoteID, vc)) } // RemoteIDIsNil applies the IsNil predicate on the "RemoteID" field. func RemoteIDIsNil() predicate.Message { - return predicate.Message(func(s *sql.Selector) { - s.Where(sql.IsNull(s.C(FieldRemoteID))) - }) + return predicate.Message(sql.FieldIsNull(FieldRemoteID)) } // RemoteIDNotNil applies the NotNil predicate on the "RemoteID" field. func RemoteIDNotNil() predicate.Message { - return predicate.Message(func(s *sql.Selector) { - s.Where(sql.NotNull(s.C(FieldRemoteID))) - }) + return predicate.Message(sql.FieldNotNull(FieldRemoteID)) } // RemoteIDEqualFold applies the EqualFold predicate on the "RemoteID" field. func RemoteIDEqualFold(v imap.MessageID) predicate.Message { vc := string(v) - return predicate.Message(func(s *sql.Selector) { - s.Where(sql.EqualFold(s.C(FieldRemoteID), vc)) - }) + return predicate.Message(sql.FieldEqualFold(FieldRemoteID, vc)) } // RemoteIDContainsFold applies the ContainsFold predicate on the "RemoteID" field. func RemoteIDContainsFold(v imap.MessageID) predicate.Message { vc := string(v) - return predicate.Message(func(s *sql.Selector) { - s.Where(sql.ContainsFold(s.C(FieldRemoteID), vc)) - }) + return predicate.Message(sql.FieldContainsFold(FieldRemoteID, vc)) } // DateEQ applies the EQ predicate on the "Date" field. func DateEQ(v time.Time) predicate.Message { - return predicate.Message(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldDate), v)) - }) + return predicate.Message(sql.FieldEQ(FieldDate, v)) } // DateNEQ applies the NEQ predicate on the "Date" field. func DateNEQ(v time.Time) predicate.Message { - return predicate.Message(func(s *sql.Selector) { - s.Where(sql.NEQ(s.C(FieldDate), v)) - }) + return predicate.Message(sql.FieldNEQ(FieldDate, v)) } // DateIn applies the In predicate on the "Date" field. func DateIn(vs ...time.Time) predicate.Message { - v := make([]interface{}, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.Message(func(s *sql.Selector) { - s.Where(sql.In(s.C(FieldDate), v...)) - }) + return predicate.Message(sql.FieldIn(FieldDate, vs...)) } // DateNotIn applies the NotIn predicate on the "Date" field. func DateNotIn(vs ...time.Time) predicate.Message { - v := make([]interface{}, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.Message(func(s *sql.Selector) { - s.Where(sql.NotIn(s.C(FieldDate), v...)) - }) + return predicate.Message(sql.FieldNotIn(FieldDate, vs...)) } // DateGT applies the GT predicate on the "Date" field. func DateGT(v time.Time) predicate.Message { - return predicate.Message(func(s *sql.Selector) { - s.Where(sql.GT(s.C(FieldDate), v)) - }) + return predicate.Message(sql.FieldGT(FieldDate, v)) } // DateGTE applies the GTE predicate on the "Date" field. func DateGTE(v time.Time) predicate.Message { - return predicate.Message(func(s *sql.Selector) { - s.Where(sql.GTE(s.C(FieldDate), v)) - }) + return predicate.Message(sql.FieldGTE(FieldDate, v)) } // DateLT applies the LT predicate on the "Date" field. func DateLT(v time.Time) predicate.Message { - return predicate.Message(func(s *sql.Selector) { - s.Where(sql.LT(s.C(FieldDate), v)) - }) + return predicate.Message(sql.FieldLT(FieldDate, v)) } // DateLTE applies the LTE predicate on the "Date" field. func DateLTE(v time.Time) predicate.Message { - return predicate.Message(func(s *sql.Selector) { - s.Where(sql.LTE(s.C(FieldDate), v)) - }) + return predicate.Message(sql.FieldLTE(FieldDate, v)) } // SizeEQ applies the EQ predicate on the "Size" field. func SizeEQ(v int) predicate.Message { - return predicate.Message(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldSize), v)) - }) + return predicate.Message(sql.FieldEQ(FieldSize, v)) } // SizeNEQ applies the NEQ predicate on the "Size" field. func SizeNEQ(v int) predicate.Message { - return predicate.Message(func(s *sql.Selector) { - s.Where(sql.NEQ(s.C(FieldSize), v)) - }) + return predicate.Message(sql.FieldNEQ(FieldSize, v)) } // SizeIn applies the In predicate on the "Size" field. func SizeIn(vs ...int) predicate.Message { - v := make([]interface{}, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.Message(func(s *sql.Selector) { - s.Where(sql.In(s.C(FieldSize), v...)) - }) + return predicate.Message(sql.FieldIn(FieldSize, vs...)) } // SizeNotIn applies the NotIn predicate on the "Size" field. func SizeNotIn(vs ...int) predicate.Message { - v := make([]interface{}, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.Message(func(s *sql.Selector) { - s.Where(sql.NotIn(s.C(FieldSize), v...)) - }) + return predicate.Message(sql.FieldNotIn(FieldSize, vs...)) } // SizeGT applies the GT predicate on the "Size" field. func SizeGT(v int) predicate.Message { - return predicate.Message(func(s *sql.Selector) { - s.Where(sql.GT(s.C(FieldSize), v)) - }) + return predicate.Message(sql.FieldGT(FieldSize, v)) } // SizeGTE applies the GTE predicate on the "Size" field. func SizeGTE(v int) predicate.Message { - return predicate.Message(func(s *sql.Selector) { - s.Where(sql.GTE(s.C(FieldSize), v)) - }) + return predicate.Message(sql.FieldGTE(FieldSize, v)) } // SizeLT applies the LT predicate on the "Size" field. func SizeLT(v int) predicate.Message { - return predicate.Message(func(s *sql.Selector) { - s.Where(sql.LT(s.C(FieldSize), v)) - }) + return predicate.Message(sql.FieldLT(FieldSize, v)) } // SizeLTE applies the LTE predicate on the "Size" field. func SizeLTE(v int) predicate.Message { - return predicate.Message(func(s *sql.Selector) { - s.Where(sql.LTE(s.C(FieldSize), v)) - }) + return predicate.Message(sql.FieldLTE(FieldSize, v)) } // BodyEQ applies the EQ predicate on the "Body" field. func BodyEQ(v string) predicate.Message { - return predicate.Message(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldBody), v)) - }) + return predicate.Message(sql.FieldEQ(FieldBody, v)) } // BodyNEQ applies the NEQ predicate on the "Body" field. func BodyNEQ(v string) predicate.Message { - return predicate.Message(func(s *sql.Selector) { - s.Where(sql.NEQ(s.C(FieldBody), v)) - }) + return predicate.Message(sql.FieldNEQ(FieldBody, v)) } // BodyIn applies the In predicate on the "Body" field. func BodyIn(vs ...string) predicate.Message { - v := make([]interface{}, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.Message(func(s *sql.Selector) { - s.Where(sql.In(s.C(FieldBody), v...)) - }) + return predicate.Message(sql.FieldIn(FieldBody, vs...)) } // BodyNotIn applies the NotIn predicate on the "Body" field. func BodyNotIn(vs ...string) predicate.Message { - v := make([]interface{}, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.Message(func(s *sql.Selector) { - s.Where(sql.NotIn(s.C(FieldBody), v...)) - }) + return predicate.Message(sql.FieldNotIn(FieldBody, vs...)) } // BodyGT applies the GT predicate on the "Body" field. func BodyGT(v string) predicate.Message { - return predicate.Message(func(s *sql.Selector) { - s.Where(sql.GT(s.C(FieldBody), v)) - }) + return predicate.Message(sql.FieldGT(FieldBody, v)) } // BodyGTE applies the GTE predicate on the "Body" field. func BodyGTE(v string) predicate.Message { - return predicate.Message(func(s *sql.Selector) { - s.Where(sql.GTE(s.C(FieldBody), v)) - }) + return predicate.Message(sql.FieldGTE(FieldBody, v)) } // BodyLT applies the LT predicate on the "Body" field. func BodyLT(v string) predicate.Message { - return predicate.Message(func(s *sql.Selector) { - s.Where(sql.LT(s.C(FieldBody), v)) - }) + return predicate.Message(sql.FieldLT(FieldBody, v)) } // BodyLTE applies the LTE predicate on the "Body" field. func BodyLTE(v string) predicate.Message { - return predicate.Message(func(s *sql.Selector) { - s.Where(sql.LTE(s.C(FieldBody), v)) - }) + return predicate.Message(sql.FieldLTE(FieldBody, v)) } // BodyContains applies the Contains predicate on the "Body" field. func BodyContains(v string) predicate.Message { - return predicate.Message(func(s *sql.Selector) { - s.Where(sql.Contains(s.C(FieldBody), v)) - }) + return predicate.Message(sql.FieldContains(FieldBody, v)) } // BodyHasPrefix applies the HasPrefix predicate on the "Body" field. func BodyHasPrefix(v string) predicate.Message { - return predicate.Message(func(s *sql.Selector) { - s.Where(sql.HasPrefix(s.C(FieldBody), v)) - }) + return predicate.Message(sql.FieldHasPrefix(FieldBody, v)) } // BodyHasSuffix applies the HasSuffix predicate on the "Body" field. func BodyHasSuffix(v string) predicate.Message { - return predicate.Message(func(s *sql.Selector) { - s.Where(sql.HasSuffix(s.C(FieldBody), v)) - }) + return predicate.Message(sql.FieldHasSuffix(FieldBody, v)) } // BodyEqualFold applies the EqualFold predicate on the "Body" field. func BodyEqualFold(v string) predicate.Message { - return predicate.Message(func(s *sql.Selector) { - s.Where(sql.EqualFold(s.C(FieldBody), v)) - }) + return predicate.Message(sql.FieldEqualFold(FieldBody, v)) } // BodyContainsFold applies the ContainsFold predicate on the "Body" field. func BodyContainsFold(v string) predicate.Message { - return predicate.Message(func(s *sql.Selector) { - s.Where(sql.ContainsFold(s.C(FieldBody), v)) - }) + return predicate.Message(sql.FieldContainsFold(FieldBody, v)) } // BodyStructureEQ applies the EQ predicate on the "BodyStructure" field. func BodyStructureEQ(v string) predicate.Message { - return predicate.Message(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldBodyStructure), v)) - }) + return predicate.Message(sql.FieldEQ(FieldBodyStructure, v)) } // BodyStructureNEQ applies the NEQ predicate on the "BodyStructure" field. func BodyStructureNEQ(v string) predicate.Message { - return predicate.Message(func(s *sql.Selector) { - s.Where(sql.NEQ(s.C(FieldBodyStructure), v)) - }) + return predicate.Message(sql.FieldNEQ(FieldBodyStructure, v)) } // BodyStructureIn applies the In predicate on the "BodyStructure" field. func BodyStructureIn(vs ...string) predicate.Message { - v := make([]interface{}, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.Message(func(s *sql.Selector) { - s.Where(sql.In(s.C(FieldBodyStructure), v...)) - }) + return predicate.Message(sql.FieldIn(FieldBodyStructure, vs...)) } // BodyStructureNotIn applies the NotIn predicate on the "BodyStructure" field. func BodyStructureNotIn(vs ...string) predicate.Message { - v := make([]interface{}, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.Message(func(s *sql.Selector) { - s.Where(sql.NotIn(s.C(FieldBodyStructure), v...)) - }) + return predicate.Message(sql.FieldNotIn(FieldBodyStructure, vs...)) } // BodyStructureGT applies the GT predicate on the "BodyStructure" field. func BodyStructureGT(v string) predicate.Message { - return predicate.Message(func(s *sql.Selector) { - s.Where(sql.GT(s.C(FieldBodyStructure), v)) - }) + return predicate.Message(sql.FieldGT(FieldBodyStructure, v)) } // BodyStructureGTE applies the GTE predicate on the "BodyStructure" field. func BodyStructureGTE(v string) predicate.Message { - return predicate.Message(func(s *sql.Selector) { - s.Where(sql.GTE(s.C(FieldBodyStructure), v)) - }) + return predicate.Message(sql.FieldGTE(FieldBodyStructure, v)) } // BodyStructureLT applies the LT predicate on the "BodyStructure" field. func BodyStructureLT(v string) predicate.Message { - return predicate.Message(func(s *sql.Selector) { - s.Where(sql.LT(s.C(FieldBodyStructure), v)) - }) + return predicate.Message(sql.FieldLT(FieldBodyStructure, v)) } // BodyStructureLTE applies the LTE predicate on the "BodyStructure" field. func BodyStructureLTE(v string) predicate.Message { - return predicate.Message(func(s *sql.Selector) { - s.Where(sql.LTE(s.C(FieldBodyStructure), v)) - }) + return predicate.Message(sql.FieldLTE(FieldBodyStructure, v)) } // BodyStructureContains applies the Contains predicate on the "BodyStructure" field. func BodyStructureContains(v string) predicate.Message { - return predicate.Message(func(s *sql.Selector) { - s.Where(sql.Contains(s.C(FieldBodyStructure), v)) - }) + return predicate.Message(sql.FieldContains(FieldBodyStructure, v)) } // BodyStructureHasPrefix applies the HasPrefix predicate on the "BodyStructure" field. func BodyStructureHasPrefix(v string) predicate.Message { - return predicate.Message(func(s *sql.Selector) { - s.Where(sql.HasPrefix(s.C(FieldBodyStructure), v)) - }) + return predicate.Message(sql.FieldHasPrefix(FieldBodyStructure, v)) } // BodyStructureHasSuffix applies the HasSuffix predicate on the "BodyStructure" field. func BodyStructureHasSuffix(v string) predicate.Message { - return predicate.Message(func(s *sql.Selector) { - s.Where(sql.HasSuffix(s.C(FieldBodyStructure), v)) - }) + return predicate.Message(sql.FieldHasSuffix(FieldBodyStructure, v)) } // BodyStructureEqualFold applies the EqualFold predicate on the "BodyStructure" field. func BodyStructureEqualFold(v string) predicate.Message { - return predicate.Message(func(s *sql.Selector) { - s.Where(sql.EqualFold(s.C(FieldBodyStructure), v)) - }) + return predicate.Message(sql.FieldEqualFold(FieldBodyStructure, v)) } // BodyStructureContainsFold applies the ContainsFold predicate on the "BodyStructure" field. func BodyStructureContainsFold(v string) predicate.Message { - return predicate.Message(func(s *sql.Selector) { - s.Where(sql.ContainsFold(s.C(FieldBodyStructure), v)) - }) + return predicate.Message(sql.FieldContainsFold(FieldBodyStructure, v)) } // EnvelopeEQ applies the EQ predicate on the "Envelope" field. func EnvelopeEQ(v string) predicate.Message { - return predicate.Message(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldEnvelope), v)) - }) + return predicate.Message(sql.FieldEQ(FieldEnvelope, v)) } // EnvelopeNEQ applies the NEQ predicate on the "Envelope" field. func EnvelopeNEQ(v string) predicate.Message { - return predicate.Message(func(s *sql.Selector) { - s.Where(sql.NEQ(s.C(FieldEnvelope), v)) - }) + return predicate.Message(sql.FieldNEQ(FieldEnvelope, v)) } // EnvelopeIn applies the In predicate on the "Envelope" field. func EnvelopeIn(vs ...string) predicate.Message { - v := make([]interface{}, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.Message(func(s *sql.Selector) { - s.Where(sql.In(s.C(FieldEnvelope), v...)) - }) + return predicate.Message(sql.FieldIn(FieldEnvelope, vs...)) } // EnvelopeNotIn applies the NotIn predicate on the "Envelope" field. func EnvelopeNotIn(vs ...string) predicate.Message { - v := make([]interface{}, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.Message(func(s *sql.Selector) { - s.Where(sql.NotIn(s.C(FieldEnvelope), v...)) - }) + return predicate.Message(sql.FieldNotIn(FieldEnvelope, vs...)) } // EnvelopeGT applies the GT predicate on the "Envelope" field. func EnvelopeGT(v string) predicate.Message { - return predicate.Message(func(s *sql.Selector) { - s.Where(sql.GT(s.C(FieldEnvelope), v)) - }) + return predicate.Message(sql.FieldGT(FieldEnvelope, v)) } // EnvelopeGTE applies the GTE predicate on the "Envelope" field. func EnvelopeGTE(v string) predicate.Message { - return predicate.Message(func(s *sql.Selector) { - s.Where(sql.GTE(s.C(FieldEnvelope), v)) - }) + return predicate.Message(sql.FieldGTE(FieldEnvelope, v)) } // EnvelopeLT applies the LT predicate on the "Envelope" field. func EnvelopeLT(v string) predicate.Message { - return predicate.Message(func(s *sql.Selector) { - s.Where(sql.LT(s.C(FieldEnvelope), v)) - }) + return predicate.Message(sql.FieldLT(FieldEnvelope, v)) } // EnvelopeLTE applies the LTE predicate on the "Envelope" field. func EnvelopeLTE(v string) predicate.Message { - return predicate.Message(func(s *sql.Selector) { - s.Where(sql.LTE(s.C(FieldEnvelope), v)) - }) + return predicate.Message(sql.FieldLTE(FieldEnvelope, v)) } // EnvelopeContains applies the Contains predicate on the "Envelope" field. func EnvelopeContains(v string) predicate.Message { - return predicate.Message(func(s *sql.Selector) { - s.Where(sql.Contains(s.C(FieldEnvelope), v)) - }) + return predicate.Message(sql.FieldContains(FieldEnvelope, v)) } // EnvelopeHasPrefix applies the HasPrefix predicate on the "Envelope" field. func EnvelopeHasPrefix(v string) predicate.Message { - return predicate.Message(func(s *sql.Selector) { - s.Where(sql.HasPrefix(s.C(FieldEnvelope), v)) - }) + return predicate.Message(sql.FieldHasPrefix(FieldEnvelope, v)) } // EnvelopeHasSuffix applies the HasSuffix predicate on the "Envelope" field. func EnvelopeHasSuffix(v string) predicate.Message { - return predicate.Message(func(s *sql.Selector) { - s.Where(sql.HasSuffix(s.C(FieldEnvelope), v)) - }) + return predicate.Message(sql.FieldHasSuffix(FieldEnvelope, v)) } // EnvelopeEqualFold applies the EqualFold predicate on the "Envelope" field. func EnvelopeEqualFold(v string) predicate.Message { - return predicate.Message(func(s *sql.Selector) { - s.Where(sql.EqualFold(s.C(FieldEnvelope), v)) - }) + return predicate.Message(sql.FieldEqualFold(FieldEnvelope, v)) } // EnvelopeContainsFold applies the ContainsFold predicate on the "Envelope" field. func EnvelopeContainsFold(v string) predicate.Message { - return predicate.Message(func(s *sql.Selector) { - s.Where(sql.ContainsFold(s.C(FieldEnvelope), v)) - }) + return predicate.Message(sql.FieldContainsFold(FieldEnvelope, v)) } // DeletedEQ applies the EQ predicate on the "Deleted" field. func DeletedEQ(v bool) predicate.Message { - return predicate.Message(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldDeleted), v)) - }) + return predicate.Message(sql.FieldEQ(FieldDeleted, v)) } // DeletedNEQ applies the NEQ predicate on the "Deleted" field. func DeletedNEQ(v bool) predicate.Message { - return predicate.Message(func(s *sql.Selector) { - s.Where(sql.NEQ(s.C(FieldDeleted), v)) - }) + return predicate.Message(sql.FieldNEQ(FieldDeleted, v)) } // HasFlags applies the HasEdge predicate on the "flags" edge. @@ -700,7 +476,6 @@ func HasFlags() predicate.Message { return predicate.Message(func(s *sql.Selector) { step := sqlgraph.NewStep( sqlgraph.From(Table, FieldID), - sqlgraph.To(FlagsTable, FieldID), sqlgraph.Edge(sqlgraph.O2M, false, FlagsTable, FlagsColumn), ) sqlgraph.HasNeighbors(s, step) @@ -728,7 +503,6 @@ func HasUIDs() predicate.Message { return predicate.Message(func(s *sql.Selector) { step := sqlgraph.NewStep( sqlgraph.From(Table, FieldID), - sqlgraph.To(UIDsTable, FieldID), sqlgraph.Edge(sqlgraph.O2M, true, UIDsTable, UIDsColumn), ) sqlgraph.HasNeighbors(s, step) diff --git a/internal/db/ent/message_create.go b/internal/db_impl/ent_db/internal/message_create.go similarity index 74% rename from internal/db/ent/message_create.go rename to internal/db_impl/ent_db/internal/message_create.go index c9d923a0..777b8d1a 100644 --- a/internal/db/ent/message_create.go +++ b/internal/db_impl/ent_db/internal/message_create.go @@ -1,6 +1,6 @@ // Code generated by ent, DO NOT EDIT. -package ent +package internal import ( "context" @@ -11,9 +11,9 @@ import ( "entgo.io/ent/dialect/sql/sqlgraph" "entgo.io/ent/schema/field" "github.com/ProtonMail/gluon/imap" - "github.com/ProtonMail/gluon/internal/db/ent/message" - "github.com/ProtonMail/gluon/internal/db/ent/messageflag" - "github.com/ProtonMail/gluon/internal/db/ent/uid" + "github.com/ProtonMail/gluon/internal/db_impl/ent_db/internal/message" + "github.com/ProtonMail/gluon/internal/db_impl/ent_db/internal/messageflag" + "github.com/ProtonMail/gluon/internal/db_impl/ent_db/internal/uid" ) // MessageCreate is the builder for creating a Message entity. @@ -124,50 +124,8 @@ func (mc *MessageCreate) Mutation() *MessageMutation { // Save creates the Message in the database. func (mc *MessageCreate) Save(ctx context.Context) (*Message, error) { - var ( - err error - node *Message - ) mc.defaults() - if len(mc.hooks) == 0 { - if err = mc.check(); err != nil { - return nil, err - } - node, err = mc.sqlSave(ctx) - } else { - var mut Mutator = MutateFunc(func(ctx context.Context, m Mutation) (Value, error) { - mutation, ok := m.(*MessageMutation) - if !ok { - return nil, fmt.Errorf("unexpected mutation type %T", m) - } - if err = mc.check(); err != nil { - return nil, err - } - mc.mutation = mutation - if node, err = mc.sqlSave(ctx); err != nil { - return nil, err - } - mutation.id = &node.ID - mutation.done = true - return node, err - }) - for i := len(mc.hooks) - 1; i >= 0; i-- { - if mc.hooks[i] == nil { - return nil, fmt.Errorf("ent: uninitialized hook (forgotten import ent/runtime?)") - } - mut = mc.hooks[i](mut) - } - v, err := mut.Mutate(ctx, mc.mutation) - if err != nil { - return nil, err - } - nv, ok := v.(*Message) - if !ok { - return nil, fmt.Errorf("unexpected node type %T returned from MessageMutation", v) - } - node = nv - } - return node, err + return withHooks[*Message, MessageMutation](ctx, mc.sqlSave, mc.mutation, mc.hooks) } // SaveX calls Save and panics if Save returns an error. @@ -203,27 +161,30 @@ func (mc *MessageCreate) defaults() { // check runs all checks and user-defined validators on the builder. func (mc *MessageCreate) check() error { if _, ok := mc.mutation.Date(); !ok { - return &ValidationError{Name: "Date", err: errors.New(`ent: missing required field "Message.Date"`)} + return &ValidationError{Name: "Date", err: errors.New(`internal: missing required field "Message.Date"`)} } if _, ok := mc.mutation.Size(); !ok { - return &ValidationError{Name: "Size", err: errors.New(`ent: missing required field "Message.Size"`)} + return &ValidationError{Name: "Size", err: errors.New(`internal: missing required field "Message.Size"`)} } if _, ok := mc.mutation.Body(); !ok { - return &ValidationError{Name: "Body", err: errors.New(`ent: missing required field "Message.Body"`)} + return &ValidationError{Name: "Body", err: errors.New(`internal: missing required field "Message.Body"`)} } if _, ok := mc.mutation.BodyStructure(); !ok { - return &ValidationError{Name: "BodyStructure", err: errors.New(`ent: missing required field "Message.BodyStructure"`)} + return &ValidationError{Name: "BodyStructure", err: errors.New(`internal: missing required field "Message.BodyStructure"`)} } if _, ok := mc.mutation.Envelope(); !ok { - return &ValidationError{Name: "Envelope", err: errors.New(`ent: missing required field "Message.Envelope"`)} + return &ValidationError{Name: "Envelope", err: errors.New(`internal: missing required field "Message.Envelope"`)} } if _, ok := mc.mutation.Deleted(); !ok { - return &ValidationError{Name: "Deleted", err: errors.New(`ent: missing required field "Message.Deleted"`)} + return &ValidationError{Name: "Deleted", err: errors.New(`internal: missing required field "Message.Deleted"`)} } return nil } func (mc *MessageCreate) sqlSave(ctx context.Context) (*Message, error) { + if err := mc.check(); err != nil { + return nil, err + } _node, _spec := mc.createSpec() if err := sqlgraph.CreateNode(ctx, mc.driver, _spec); err != nil { if sqlgraph.IsConstraintError(err) { @@ -238,78 +199,46 @@ func (mc *MessageCreate) sqlSave(ctx context.Context) (*Message, error) { return nil, err } } + mc.mutation.id = &_node.ID + mc.mutation.done = true return _node, nil } func (mc *MessageCreate) createSpec() (*Message, *sqlgraph.CreateSpec) { var ( _node = &Message{config: mc.config} - _spec = &sqlgraph.CreateSpec{ - Table: message.Table, - ID: &sqlgraph.FieldSpec{ - Type: field.TypeUUID, - Column: message.FieldID, - }, - } + _spec = sqlgraph.NewCreateSpec(message.Table, sqlgraph.NewFieldSpec(message.FieldID, field.TypeUUID)) ) if id, ok := mc.mutation.ID(); ok { _node.ID = id _spec.ID.Value = &id } if value, ok := mc.mutation.RemoteID(); ok { - _spec.Fields = append(_spec.Fields, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Value: value, - Column: message.FieldRemoteID, - }) + _spec.SetField(message.FieldRemoteID, field.TypeString, value) _node.RemoteID = value } if value, ok := mc.mutation.Date(); ok { - _spec.Fields = append(_spec.Fields, &sqlgraph.FieldSpec{ - Type: field.TypeTime, - Value: value, - Column: message.FieldDate, - }) + _spec.SetField(message.FieldDate, field.TypeTime, value) _node.Date = value } if value, ok := mc.mutation.Size(); ok { - _spec.Fields = append(_spec.Fields, &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Value: value, - Column: message.FieldSize, - }) + _spec.SetField(message.FieldSize, field.TypeInt, value) _node.Size = value } if value, ok := mc.mutation.Body(); ok { - _spec.Fields = append(_spec.Fields, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Value: value, - Column: message.FieldBody, - }) + _spec.SetField(message.FieldBody, field.TypeString, value) _node.Body = value } if value, ok := mc.mutation.BodyStructure(); ok { - _spec.Fields = append(_spec.Fields, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Value: value, - Column: message.FieldBodyStructure, - }) + _spec.SetField(message.FieldBodyStructure, field.TypeString, value) _node.BodyStructure = value } if value, ok := mc.mutation.Envelope(); ok { - _spec.Fields = append(_spec.Fields, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Value: value, - Column: message.FieldEnvelope, - }) + _spec.SetField(message.FieldEnvelope, field.TypeString, value) _node.Envelope = value } if value, ok := mc.mutation.Deleted(); ok { - _spec.Fields = append(_spec.Fields, &sqlgraph.FieldSpec{ - Type: field.TypeBool, - Value: value, - Column: message.FieldDeleted, - }) + _spec.SetField(message.FieldDeleted, field.TypeBool, value) _node.Deleted = value } if nodes := mc.mutation.FlagsIDs(); len(nodes) > 0 { diff --git a/internal/db/ent/message_delete.go b/internal/db_impl/ent_db/internal/message_delete.go similarity index 62% rename from internal/db/ent/message_delete.go rename to internal/db_impl/ent_db/internal/message_delete.go index b5094106..60640466 100644 --- a/internal/db/ent/message_delete.go +++ b/internal/db_impl/ent_db/internal/message_delete.go @@ -1,16 +1,15 @@ // Code generated by ent, DO NOT EDIT. -package ent +package internal import ( "context" - "fmt" "entgo.io/ent/dialect/sql" "entgo.io/ent/dialect/sql/sqlgraph" "entgo.io/ent/schema/field" - "github.com/ProtonMail/gluon/internal/db/ent/message" - "github.com/ProtonMail/gluon/internal/db/ent/predicate" + "github.com/ProtonMail/gluon/internal/db_impl/ent_db/internal/message" + "github.com/ProtonMail/gluon/internal/db_impl/ent_db/internal/predicate" ) // MessageDelete is the builder for deleting a Message entity. @@ -28,34 +27,7 @@ func (md *MessageDelete) Where(ps ...predicate.Message) *MessageDelete { // Exec executes the deletion query and returns how many vertices were deleted. func (md *MessageDelete) Exec(ctx context.Context) (int, error) { - var ( - err error - affected int - ) - if len(md.hooks) == 0 { - affected, err = md.sqlExec(ctx) - } else { - var mut Mutator = MutateFunc(func(ctx context.Context, m Mutation) (Value, error) { - mutation, ok := m.(*MessageMutation) - if !ok { - return nil, fmt.Errorf("unexpected mutation type %T", m) - } - md.mutation = mutation - affected, err = md.sqlExec(ctx) - mutation.done = true - return affected, err - }) - for i := len(md.hooks) - 1; i >= 0; i-- { - if md.hooks[i] == nil { - return 0, fmt.Errorf("ent: uninitialized hook (forgotten import ent/runtime?)") - } - mut = md.hooks[i](mut) - } - if _, err := mut.Mutate(ctx, md.mutation); err != nil { - return 0, err - } - } - return affected, err + return withHooks[int, MessageMutation](ctx, md.sqlExec, md.mutation, md.hooks) } // ExecX is like Exec, but panics if an error occurs. @@ -68,15 +40,7 @@ func (md *MessageDelete) ExecX(ctx context.Context) int { } func (md *MessageDelete) sqlExec(ctx context.Context) (int, error) { - _spec := &sqlgraph.DeleteSpec{ - Node: &sqlgraph.NodeSpec{ - Table: message.Table, - ID: &sqlgraph.FieldSpec{ - Type: field.TypeUUID, - Column: message.FieldID, - }, - }, - } + _spec := sqlgraph.NewDeleteSpec(message.Table, sqlgraph.NewFieldSpec(message.FieldID, field.TypeUUID)) if ps := md.mutation.predicates; len(ps) > 0 { _spec.Predicate = func(selector *sql.Selector) { for i := range ps { @@ -88,6 +52,7 @@ func (md *MessageDelete) sqlExec(ctx context.Context) (int, error) { if err != nil && sqlgraph.IsConstraintError(err) { err = &ConstraintError{msg: err.Error(), wrap: err} } + md.mutation.done = true return affected, err } @@ -96,6 +61,12 @@ type MessageDeleteOne struct { md *MessageDelete } +// Where appends a list predicates to the MessageDelete builder. +func (mdo *MessageDeleteOne) Where(ps ...predicate.Message) *MessageDeleteOne { + mdo.md.mutation.Where(ps...) + return mdo +} + // Exec executes the deletion query. func (mdo *MessageDeleteOne) Exec(ctx context.Context) error { n, err := mdo.md.Exec(ctx) @@ -111,5 +82,7 @@ func (mdo *MessageDeleteOne) Exec(ctx context.Context) error { // ExecX is like Exec, but panics if an error occurs. func (mdo *MessageDeleteOne) ExecX(ctx context.Context) { - mdo.md.ExecX(ctx) + if err := mdo.Exec(ctx); err != nil { + panic(err) + } } diff --git a/internal/db/ent/message_query.go b/internal/db_impl/ent_db/internal/message_query.go similarity index 73% rename from internal/db/ent/message_query.go rename to internal/db_impl/ent_db/internal/message_query.go index 84416328..d1ac2c4c 100644 --- a/internal/db/ent/message_query.go +++ b/internal/db_impl/ent_db/internal/message_query.go @@ -1,6 +1,6 @@ // Code generated by ent, DO NOT EDIT. -package ent +package internal import ( "context" @@ -12,20 +12,18 @@ import ( "entgo.io/ent/dialect/sql/sqlgraph" "entgo.io/ent/schema/field" "github.com/ProtonMail/gluon/imap" - "github.com/ProtonMail/gluon/internal/db/ent/message" - "github.com/ProtonMail/gluon/internal/db/ent/messageflag" - "github.com/ProtonMail/gluon/internal/db/ent/predicate" - "github.com/ProtonMail/gluon/internal/db/ent/uid" + "github.com/ProtonMail/gluon/internal/db_impl/ent_db/internal/message" + "github.com/ProtonMail/gluon/internal/db_impl/ent_db/internal/messageflag" + "github.com/ProtonMail/gluon/internal/db_impl/ent_db/internal/predicate" + "github.com/ProtonMail/gluon/internal/db_impl/ent_db/internal/uid" ) // MessageQuery is the builder for querying Message entities. type MessageQuery struct { config - limit *int - offset *int - unique *bool + ctx *QueryContext order []OrderFunc - fields []string + inters []Interceptor predicates []predicate.Message withFlags *MessageFlagQuery withUIDs *UIDQuery @@ -40,26 +38,26 @@ func (mq *MessageQuery) Where(ps ...predicate.Message) *MessageQuery { return mq } -// Limit adds a limit step to the query. +// Limit the number of records to be returned by this query. func (mq *MessageQuery) Limit(limit int) *MessageQuery { - mq.limit = &limit + mq.ctx.Limit = &limit return mq } -// Offset adds an offset step to the query. +// Offset to start from. func (mq *MessageQuery) Offset(offset int) *MessageQuery { - mq.offset = &offset + mq.ctx.Offset = &offset return mq } // Unique configures the query builder to filter duplicate records on query. // By default, unique is set to true, and can be disabled using this method. func (mq *MessageQuery) Unique(unique bool) *MessageQuery { - mq.unique = &unique + mq.ctx.Unique = &unique return mq } -// Order adds an order step to the query. +// Order specifies how the records should be ordered. func (mq *MessageQuery) Order(o ...OrderFunc) *MessageQuery { mq.order = append(mq.order, o...) return mq @@ -67,7 +65,7 @@ func (mq *MessageQuery) Order(o ...OrderFunc) *MessageQuery { // QueryFlags chains the current query on the "flags" edge. func (mq *MessageQuery) QueryFlags() *MessageFlagQuery { - query := &MessageFlagQuery{config: mq.config} + query := (&MessageFlagClient{config: mq.config}).Query() query.path = func(ctx context.Context) (fromU *sql.Selector, err error) { if err := mq.prepareQuery(ctx); err != nil { return nil, err @@ -89,7 +87,7 @@ func (mq *MessageQuery) QueryFlags() *MessageFlagQuery { // QueryUIDs chains the current query on the "UIDs" edge. func (mq *MessageQuery) QueryUIDs() *UIDQuery { - query := &UIDQuery{config: mq.config} + query := (&UIDClient{config: mq.config}).Query() query.path = func(ctx context.Context) (fromU *sql.Selector, err error) { if err := mq.prepareQuery(ctx); err != nil { return nil, err @@ -112,7 +110,7 @@ func (mq *MessageQuery) QueryUIDs() *UIDQuery { // First returns the first Message entity from the query. // Returns a *NotFoundError when no Message was found. func (mq *MessageQuery) First(ctx context.Context) (*Message, error) { - nodes, err := mq.Limit(1).All(ctx) + nodes, err := mq.Limit(1).All(setContextOp(ctx, mq.ctx, "First")) if err != nil { return nil, err } @@ -135,7 +133,7 @@ func (mq *MessageQuery) FirstX(ctx context.Context) *Message { // Returns a *NotFoundError when no Message ID was found. func (mq *MessageQuery) FirstID(ctx context.Context) (id imap.InternalMessageID, err error) { var ids []imap.InternalMessageID - if ids, err = mq.Limit(1).IDs(ctx); err != nil { + if ids, err = mq.Limit(1).IDs(setContextOp(ctx, mq.ctx, "FirstID")); err != nil { return } if len(ids) == 0 { @@ -158,7 +156,7 @@ func (mq *MessageQuery) FirstIDX(ctx context.Context) imap.InternalMessageID { // Returns a *NotSingularError when more than one Message entity is found. // Returns a *NotFoundError when no Message entities are found. func (mq *MessageQuery) Only(ctx context.Context) (*Message, error) { - nodes, err := mq.Limit(2).All(ctx) + nodes, err := mq.Limit(2).All(setContextOp(ctx, mq.ctx, "Only")) if err != nil { return nil, err } @@ -186,7 +184,7 @@ func (mq *MessageQuery) OnlyX(ctx context.Context) *Message { // Returns a *NotFoundError when no entities are found. func (mq *MessageQuery) OnlyID(ctx context.Context) (id imap.InternalMessageID, err error) { var ids []imap.InternalMessageID - if ids, err = mq.Limit(2).IDs(ctx); err != nil { + if ids, err = mq.Limit(2).IDs(setContextOp(ctx, mq.ctx, "OnlyID")); err != nil { return } switch len(ids) { @@ -211,10 +209,12 @@ func (mq *MessageQuery) OnlyIDX(ctx context.Context) imap.InternalMessageID { // All executes the query and returns a list of Messages. func (mq *MessageQuery) All(ctx context.Context) ([]*Message, error) { + ctx = setContextOp(ctx, mq.ctx, "All") if err := mq.prepareQuery(ctx); err != nil { return nil, err } - return mq.sqlAll(ctx) + qr := querierAll[[]*Message, *MessageQuery]() + return withInterceptors[[]*Message](ctx, mq, qr, mq.inters) } // AllX is like All, but panics if an error occurs. @@ -227,9 +227,12 @@ func (mq *MessageQuery) AllX(ctx context.Context) []*Message { } // IDs executes the query and returns a list of Message IDs. -func (mq *MessageQuery) IDs(ctx context.Context) ([]imap.InternalMessageID, error) { - var ids []imap.InternalMessageID - if err := mq.Select(message.FieldID).Scan(ctx, &ids); err != nil { +func (mq *MessageQuery) IDs(ctx context.Context) (ids []imap.InternalMessageID, err error) { + if mq.ctx.Unique == nil && mq.path != nil { + mq.Unique(true) + } + ctx = setContextOp(ctx, mq.ctx, "IDs") + if err = mq.Select(message.FieldID).Scan(ctx, &ids); err != nil { return nil, err } return ids, nil @@ -246,10 +249,11 @@ func (mq *MessageQuery) IDsX(ctx context.Context) []imap.InternalMessageID { // Count returns the count of the given query. func (mq *MessageQuery) Count(ctx context.Context) (int, error) { + ctx = setContextOp(ctx, mq.ctx, "Count") if err := mq.prepareQuery(ctx); err != nil { return 0, err } - return mq.sqlCount(ctx) + return withInterceptors[int](ctx, mq, querierCount[*MessageQuery](), mq.inters) } // CountX is like Count, but panics if an error occurs. @@ -263,10 +267,15 @@ func (mq *MessageQuery) CountX(ctx context.Context) int { // Exist returns true if the query has elements in the graph. func (mq *MessageQuery) Exist(ctx context.Context) (bool, error) { - if err := mq.prepareQuery(ctx); err != nil { - return false, err + ctx = setContextOp(ctx, mq.ctx, "Exist") + switch _, err := mq.FirstID(ctx); { + case IsNotFound(err): + return false, nil + case err != nil: + return false, fmt.Errorf("internal: check existence: %w", err) + default: + return true, nil } - return mq.sqlExist(ctx) } // ExistX is like Exist, but panics if an error occurs. @@ -286,23 +295,22 @@ func (mq *MessageQuery) Clone() *MessageQuery { } return &MessageQuery{ config: mq.config, - limit: mq.limit, - offset: mq.offset, + ctx: mq.ctx.Clone(), order: append([]OrderFunc{}, mq.order...), + inters: append([]Interceptor{}, mq.inters...), predicates: append([]predicate.Message{}, mq.predicates...), withFlags: mq.withFlags.Clone(), withUIDs: mq.withUIDs.Clone(), // clone intermediate query. - sql: mq.sql.Clone(), - path: mq.path, - unique: mq.unique, + sql: mq.sql.Clone(), + path: mq.path, } } // WithFlags tells the query-builder to eager-load the nodes that are connected to // the "flags" edge. The optional arguments are used to configure the query builder of the edge. func (mq *MessageQuery) WithFlags(opts ...func(*MessageFlagQuery)) *MessageQuery { - query := &MessageFlagQuery{config: mq.config} + query := (&MessageFlagClient{config: mq.config}).Query() for _, opt := range opts { opt(query) } @@ -313,7 +321,7 @@ func (mq *MessageQuery) WithFlags(opts ...func(*MessageFlagQuery)) *MessageQuery // WithUIDs tells the query-builder to eager-load the nodes that are connected to // the "UIDs" edge. The optional arguments are used to configure the query builder of the edge. func (mq *MessageQuery) WithUIDs(opts ...func(*UIDQuery)) *MessageQuery { - query := &UIDQuery{config: mq.config} + query := (&UIDClient{config: mq.config}).Query() for _, opt := range opts { opt(query) } @@ -333,19 +341,14 @@ func (mq *MessageQuery) WithUIDs(opts ...func(*UIDQuery)) *MessageQuery { // // client.Message.Query(). // GroupBy(message.FieldRemoteID). -// Aggregate(ent.Count()). +// Aggregate(internal.Count()). // Scan(ctx, &v) func (mq *MessageQuery) GroupBy(field string, fields ...string) *MessageGroupBy { - grbuild := &MessageGroupBy{config: mq.config} - grbuild.fields = append([]string{field}, fields...) - grbuild.path = func(ctx context.Context) (prev *sql.Selector, err error) { - if err := mq.prepareQuery(ctx); err != nil { - return nil, err - } - return mq.sqlQuery(ctx), nil - } + mq.ctx.Fields = append([]string{field}, fields...) + grbuild := &MessageGroupBy{build: mq} + grbuild.flds = &mq.ctx.Fields grbuild.label = message.Label - grbuild.flds, grbuild.scan = &grbuild.fields, grbuild.Scan + grbuild.scan = grbuild.Scan return grbuild } @@ -362,17 +365,32 @@ func (mq *MessageQuery) GroupBy(field string, fields ...string) *MessageGroupBy // Select(message.FieldRemoteID). // Scan(ctx, &v) func (mq *MessageQuery) Select(fields ...string) *MessageSelect { - mq.fields = append(mq.fields, fields...) - selbuild := &MessageSelect{MessageQuery: mq} - selbuild.label = message.Label - selbuild.flds, selbuild.scan = &mq.fields, selbuild.Scan - return selbuild + mq.ctx.Fields = append(mq.ctx.Fields, fields...) + sbuild := &MessageSelect{MessageQuery: mq} + sbuild.label = message.Label + sbuild.flds, sbuild.scan = &mq.ctx.Fields, sbuild.Scan + return sbuild +} + +// Aggregate returns a MessageSelect configured with the given aggregations. +func (mq *MessageQuery) Aggregate(fns ...AggregateFunc) *MessageSelect { + return mq.Select().Aggregate(fns...) } func (mq *MessageQuery) prepareQuery(ctx context.Context) error { - for _, f := range mq.fields { + for _, inter := range mq.inters { + if inter == nil { + return fmt.Errorf("internal: uninitialized interceptor (forgotten import internal/runtime?)") + } + if trv, ok := inter.(Traverser); ok { + if err := trv.Traverse(ctx, mq); err != nil { + return err + } + } + } + for _, f := range mq.ctx.Fields { if !message.ValidColumn(f) { - return &ValidationError{Name: f, err: fmt.Errorf("ent: invalid field %q for query", f)} + return &ValidationError{Name: f, err: fmt.Errorf("internal: invalid field %q for query", f)} } } if mq.path != nil { @@ -394,10 +412,10 @@ func (mq *MessageQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*Mess mq.withUIDs != nil, } ) - _spec.ScanValues = func(columns []string) ([]interface{}, error) { + _spec.ScanValues = func(columns []string) ([]any, error) { return (*Message).scanValues(nil, columns) } - _spec.Assign = func(columns []string, values []interface{}) error { + _spec.Assign = func(columns []string, values []any) error { node := &Message{config: mq.config} nodes = append(nodes, node) node.Edges.loadedTypes = loadedTypes @@ -494,38 +512,22 @@ func (mq *MessageQuery) loadUIDs(ctx context.Context, query *UIDQuery, nodes []* func (mq *MessageQuery) sqlCount(ctx context.Context) (int, error) { _spec := mq.querySpec() - _spec.Node.Columns = mq.fields - if len(mq.fields) > 0 { - _spec.Unique = mq.unique != nil && *mq.unique + _spec.Node.Columns = mq.ctx.Fields + if len(mq.ctx.Fields) > 0 { + _spec.Unique = mq.ctx.Unique != nil && *mq.ctx.Unique } return sqlgraph.CountNodes(ctx, mq.driver, _spec) } -func (mq *MessageQuery) sqlExist(ctx context.Context) (bool, error) { - n, err := mq.sqlCount(ctx) - if err != nil { - return false, fmt.Errorf("ent: check existence: %w", err) - } - return n > 0, nil -} - func (mq *MessageQuery) querySpec() *sqlgraph.QuerySpec { - _spec := &sqlgraph.QuerySpec{ - Node: &sqlgraph.NodeSpec{ - Table: message.Table, - Columns: message.Columns, - ID: &sqlgraph.FieldSpec{ - Type: field.TypeUUID, - Column: message.FieldID, - }, - }, - From: mq.sql, - Unique: true, - } - if unique := mq.unique; unique != nil { + _spec := sqlgraph.NewQuerySpec(message.Table, message.Columns, sqlgraph.NewFieldSpec(message.FieldID, field.TypeUUID)) + _spec.From = mq.sql + if unique := mq.ctx.Unique; unique != nil { _spec.Unique = *unique + } else if mq.path != nil { + _spec.Unique = true } - if fields := mq.fields; len(fields) > 0 { + if fields := mq.ctx.Fields; len(fields) > 0 { _spec.Node.Columns = make([]string, 0, len(fields)) _spec.Node.Columns = append(_spec.Node.Columns, message.FieldID) for i := range fields { @@ -541,10 +543,10 @@ func (mq *MessageQuery) querySpec() *sqlgraph.QuerySpec { } } } - if limit := mq.limit; limit != nil { + if limit := mq.ctx.Limit; limit != nil { _spec.Limit = *limit } - if offset := mq.offset; offset != nil { + if offset := mq.ctx.Offset; offset != nil { _spec.Offset = *offset } if ps := mq.order; len(ps) > 0 { @@ -560,7 +562,7 @@ func (mq *MessageQuery) querySpec() *sqlgraph.QuerySpec { func (mq *MessageQuery) sqlQuery(ctx context.Context) *sql.Selector { builder := sql.Dialect(mq.driver.Dialect()) t1 := builder.Table(message.Table) - columns := mq.fields + columns := mq.ctx.Fields if len(columns) == 0 { columns = message.Columns } @@ -569,7 +571,7 @@ func (mq *MessageQuery) sqlQuery(ctx context.Context) *sql.Selector { selector = mq.sql selector.Select(selector.Columns(columns...)...) } - if mq.unique != nil && *mq.unique { + if mq.ctx.Unique != nil && *mq.ctx.Unique { selector.Distinct() } for _, p := range mq.predicates { @@ -578,12 +580,12 @@ func (mq *MessageQuery) sqlQuery(ctx context.Context) *sql.Selector { for _, p := range mq.order { p(selector) } - if offset := mq.offset; offset != nil { + if offset := mq.ctx.Offset; offset != nil { // limit is mandatory for offset clause. We start // with default value, and override it below if needed. selector.Offset(*offset).Limit(math.MaxInt32) } - if limit := mq.limit; limit != nil { + if limit := mq.ctx.Limit; limit != nil { selector.Limit(*limit) } return selector @@ -591,13 +593,8 @@ func (mq *MessageQuery) sqlQuery(ctx context.Context) *sql.Selector { // MessageGroupBy is the group-by builder for Message entities. type MessageGroupBy struct { - config selector - fields []string - fns []AggregateFunc - // intermediate query (i.e. traversal path). - sql *sql.Selector - path func(context.Context) (*sql.Selector, error) + build *MessageQuery } // Aggregate adds the given aggregation functions to the group-by query. @@ -606,74 +603,77 @@ func (mgb *MessageGroupBy) Aggregate(fns ...AggregateFunc) *MessageGroupBy { return mgb } -// Scan applies the group-by query and scans the result into the given value. -func (mgb *MessageGroupBy) Scan(ctx context.Context, v interface{}) error { - query, err := mgb.path(ctx) - if err != nil { +// Scan applies the selector query and scans the result into the given value. +func (mgb *MessageGroupBy) Scan(ctx context.Context, v any) error { + ctx = setContextOp(ctx, mgb.build.ctx, "GroupBy") + if err := mgb.build.prepareQuery(ctx); err != nil { return err } - mgb.sql = query - return mgb.sqlScan(ctx, v) + return scanWithInterceptors[*MessageQuery, *MessageGroupBy](ctx, mgb.build, mgb, mgb.build.inters, v) } -func (mgb *MessageGroupBy) sqlScan(ctx context.Context, v interface{}) error { - for _, f := range mgb.fields { - if !message.ValidColumn(f) { - return &ValidationError{Name: f, err: fmt.Errorf("invalid field %q for group-by", f)} +func (mgb *MessageGroupBy) sqlScan(ctx context.Context, root *MessageQuery, v any) error { + selector := root.sqlQuery(ctx).Select() + aggregation := make([]string, 0, len(mgb.fns)) + for _, fn := range mgb.fns { + aggregation = append(aggregation, fn(selector)) + } + if len(selector.SelectedColumns()) == 0 { + columns := make([]string, 0, len(*mgb.flds)+len(mgb.fns)) + for _, f := range *mgb.flds { + columns = append(columns, selector.C(f)) } + columns = append(columns, aggregation...) + selector.Select(columns...) } - selector := mgb.sqlQuery() + selector.GroupBy(selector.Columns(*mgb.flds...)...) if err := selector.Err(); err != nil { return err } rows := &sql.Rows{} query, args := selector.Query() - if err := mgb.driver.Query(ctx, query, args, rows); err != nil { + if err := mgb.build.driver.Query(ctx, query, args, rows); err != nil { return err } defer rows.Close() return sql.ScanSlice(rows, v) } -func (mgb *MessageGroupBy) sqlQuery() *sql.Selector { - selector := mgb.sql.Select() - aggregation := make([]string, 0, len(mgb.fns)) - for _, fn := range mgb.fns { - aggregation = append(aggregation, fn(selector)) - } - // If no columns were selected in a custom aggregation function, the default - // selection is the fields used for "group-by", and the aggregation functions. - if len(selector.SelectedColumns()) == 0 { - columns := make([]string, 0, len(mgb.fields)+len(mgb.fns)) - for _, f := range mgb.fields { - columns = append(columns, selector.C(f)) - } - columns = append(columns, aggregation...) - selector.Select(columns...) - } - return selector.GroupBy(selector.Columns(mgb.fields...)...) -} - // MessageSelect is the builder for selecting fields of Message entities. type MessageSelect struct { *MessageQuery selector - // intermediate query (i.e. traversal path). - sql *sql.Selector +} + +// Aggregate adds the given aggregation functions to the selector query. +func (ms *MessageSelect) Aggregate(fns ...AggregateFunc) *MessageSelect { + ms.fns = append(ms.fns, fns...) + return ms } // Scan applies the selector query and scans the result into the given value. -func (ms *MessageSelect) Scan(ctx context.Context, v interface{}) error { +func (ms *MessageSelect) Scan(ctx context.Context, v any) error { + ctx = setContextOp(ctx, ms.ctx, "Select") if err := ms.prepareQuery(ctx); err != nil { return err } - ms.sql = ms.MessageQuery.sqlQuery(ctx) - return ms.sqlScan(ctx, v) + return scanWithInterceptors[*MessageQuery, *MessageSelect](ctx, ms.MessageQuery, ms, ms.inters, v) } -func (ms *MessageSelect) sqlScan(ctx context.Context, v interface{}) error { +func (ms *MessageSelect) sqlScan(ctx context.Context, root *MessageQuery, v any) error { + selector := root.sqlQuery(ctx) + aggregation := make([]string, 0, len(ms.fns)) + for _, fn := range ms.fns { + aggregation = append(aggregation, fn(selector)) + } + switch n := len(*ms.selector.flds); { + case n == 0 && len(aggregation) > 0: + selector.Select(aggregation...) + case n != 0 && len(aggregation) > 0: + selector.AppendSelect(aggregation...) + } rows := &sql.Rows{} - query, args := ms.sql.Query() + query, args := selector.Query() if err := ms.driver.Query(ctx, query, args, rows); err != nil { return err } diff --git a/internal/db/ent/message_update.go b/internal/db_impl/ent_db/internal/message_update.go similarity index 78% rename from internal/db/ent/message_update.go rename to internal/db_impl/ent_db/internal/message_update.go index 15564592..ac56d906 100644 --- a/internal/db/ent/message_update.go +++ b/internal/db_impl/ent_db/internal/message_update.go @@ -1,6 +1,6 @@ // Code generated by ent, DO NOT EDIT. -package ent +package internal import ( "context" @@ -12,10 +12,10 @@ import ( "entgo.io/ent/dialect/sql/sqlgraph" "entgo.io/ent/schema/field" "github.com/ProtonMail/gluon/imap" - "github.com/ProtonMail/gluon/internal/db/ent/message" - "github.com/ProtonMail/gluon/internal/db/ent/messageflag" - "github.com/ProtonMail/gluon/internal/db/ent/predicate" - "github.com/ProtonMail/gluon/internal/db/ent/uid" + "github.com/ProtonMail/gluon/internal/db_impl/ent_db/internal/message" + "github.com/ProtonMail/gluon/internal/db_impl/ent_db/internal/messageflag" + "github.com/ProtonMail/gluon/internal/db_impl/ent_db/internal/predicate" + "github.com/ProtonMail/gluon/internal/db_impl/ent_db/internal/uid" ) // MessageUpdate is the builder for updating Message entities. @@ -181,34 +181,7 @@ func (mu *MessageUpdate) RemoveUIDs(u ...*UID) *MessageUpdate { // Save executes the query and returns the number of nodes affected by the update operation. func (mu *MessageUpdate) Save(ctx context.Context) (int, error) { - var ( - err error - affected int - ) - if len(mu.hooks) == 0 { - affected, err = mu.sqlSave(ctx) - } else { - var mut Mutator = MutateFunc(func(ctx context.Context, m Mutation) (Value, error) { - mutation, ok := m.(*MessageMutation) - if !ok { - return nil, fmt.Errorf("unexpected mutation type %T", m) - } - mu.mutation = mutation - affected, err = mu.sqlSave(ctx) - mutation.done = true - return affected, err - }) - for i := len(mu.hooks) - 1; i >= 0; i-- { - if mu.hooks[i] == nil { - return 0, fmt.Errorf("ent: uninitialized hook (forgotten import ent/runtime?)") - } - mut = mu.hooks[i](mut) - } - if _, err := mut.Mutate(ctx, mu.mutation); err != nil { - return 0, err - } - } - return affected, err + return withHooks[int, MessageMutation](ctx, mu.sqlSave, mu.mutation, mu.hooks) } // SaveX is like Save, but panics if an error occurs. @@ -234,16 +207,7 @@ func (mu *MessageUpdate) ExecX(ctx context.Context) { } func (mu *MessageUpdate) sqlSave(ctx context.Context) (n int, err error) { - _spec := &sqlgraph.UpdateSpec{ - Node: &sqlgraph.NodeSpec{ - Table: message.Table, - Columns: message.Columns, - ID: &sqlgraph.FieldSpec{ - Type: field.TypeUUID, - Column: message.FieldID, - }, - }, - } + _spec := sqlgraph.NewUpdateSpec(message.Table, message.Columns, sqlgraph.NewFieldSpec(message.FieldID, field.TypeUUID)) if ps := mu.mutation.predicates; len(ps) > 0 { _spec.Predicate = func(selector *sql.Selector) { for i := range ps { @@ -252,66 +216,31 @@ func (mu *MessageUpdate) sqlSave(ctx context.Context) (n int, err error) { } } if value, ok := mu.mutation.RemoteID(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Value: value, - Column: message.FieldRemoteID, - }) + _spec.SetField(message.FieldRemoteID, field.TypeString, value) } if mu.mutation.RemoteIDCleared() { - _spec.Fields.Clear = append(_spec.Fields.Clear, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Column: message.FieldRemoteID, - }) + _spec.ClearField(message.FieldRemoteID, field.TypeString) } if value, ok := mu.mutation.Date(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeTime, - Value: value, - Column: message.FieldDate, - }) + _spec.SetField(message.FieldDate, field.TypeTime, value) } if value, ok := mu.mutation.Size(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Value: value, - Column: message.FieldSize, - }) + _spec.SetField(message.FieldSize, field.TypeInt, value) } if value, ok := mu.mutation.AddedSize(); ok { - _spec.Fields.Add = append(_spec.Fields.Add, &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Value: value, - Column: message.FieldSize, - }) + _spec.AddField(message.FieldSize, field.TypeInt, value) } if value, ok := mu.mutation.Body(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Value: value, - Column: message.FieldBody, - }) + _spec.SetField(message.FieldBody, field.TypeString, value) } if value, ok := mu.mutation.BodyStructure(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Value: value, - Column: message.FieldBodyStructure, - }) + _spec.SetField(message.FieldBodyStructure, field.TypeString, value) } if value, ok := mu.mutation.Envelope(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Value: value, - Column: message.FieldEnvelope, - }) + _spec.SetField(message.FieldEnvelope, field.TypeString, value) } if value, ok := mu.mutation.Deleted(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeBool, - Value: value, - Column: message.FieldDeleted, - }) + _spec.SetField(message.FieldDeleted, field.TypeBool, value) } if mu.mutation.FlagsCleared() { edge := &sqlgraph.EdgeSpec{ @@ -429,6 +358,7 @@ func (mu *MessageUpdate) sqlSave(ctx context.Context) (n int, err error) { } return 0, err } + mu.mutation.done = true return n, nil } @@ -588,6 +518,12 @@ func (muo *MessageUpdateOne) RemoveUIDs(u ...*UID) *MessageUpdateOne { return muo.RemoveUIDIDs(ids...) } +// Where appends a list predicates to the MessageUpdate builder. +func (muo *MessageUpdateOne) Where(ps ...predicate.Message) *MessageUpdateOne { + muo.mutation.Where(ps...) + return muo +} + // Select allows selecting one or more fields (columns) of the returned entity. // The default is selecting all fields defined in the entity schema. func (muo *MessageUpdateOne) Select(field string, fields ...string) *MessageUpdateOne { @@ -597,40 +533,7 @@ func (muo *MessageUpdateOne) Select(field string, fields ...string) *MessageUpda // Save executes the query and returns the updated Message entity. func (muo *MessageUpdateOne) Save(ctx context.Context) (*Message, error) { - var ( - err error - node *Message - ) - if len(muo.hooks) == 0 { - node, err = muo.sqlSave(ctx) - } else { - var mut Mutator = MutateFunc(func(ctx context.Context, m Mutation) (Value, error) { - mutation, ok := m.(*MessageMutation) - if !ok { - return nil, fmt.Errorf("unexpected mutation type %T", m) - } - muo.mutation = mutation - node, err = muo.sqlSave(ctx) - mutation.done = true - return node, err - }) - for i := len(muo.hooks) - 1; i >= 0; i-- { - if muo.hooks[i] == nil { - return nil, fmt.Errorf("ent: uninitialized hook (forgotten import ent/runtime?)") - } - mut = muo.hooks[i](mut) - } - v, err := mut.Mutate(ctx, muo.mutation) - if err != nil { - return nil, err - } - nv, ok := v.(*Message) - if !ok { - return nil, fmt.Errorf("unexpected node type %T returned from MessageMutation", v) - } - node = nv - } - return node, err + return withHooks[*Message, MessageMutation](ctx, muo.sqlSave, muo.mutation, muo.hooks) } // SaveX is like Save, but panics if an error occurs. @@ -656,19 +559,10 @@ func (muo *MessageUpdateOne) ExecX(ctx context.Context) { } func (muo *MessageUpdateOne) sqlSave(ctx context.Context) (_node *Message, err error) { - _spec := &sqlgraph.UpdateSpec{ - Node: &sqlgraph.NodeSpec{ - Table: message.Table, - Columns: message.Columns, - ID: &sqlgraph.FieldSpec{ - Type: field.TypeUUID, - Column: message.FieldID, - }, - }, - } + _spec := sqlgraph.NewUpdateSpec(message.Table, message.Columns, sqlgraph.NewFieldSpec(message.FieldID, field.TypeUUID)) id, ok := muo.mutation.ID() if !ok { - return nil, &ValidationError{Name: "id", err: errors.New(`ent: missing "Message.id" for update`)} + return nil, &ValidationError{Name: "id", err: errors.New(`internal: missing "Message.id" for update`)} } _spec.Node.ID.Value = id if fields := muo.fields; len(fields) > 0 { @@ -676,7 +570,7 @@ func (muo *MessageUpdateOne) sqlSave(ctx context.Context) (_node *Message, err e _spec.Node.Columns = append(_spec.Node.Columns, message.FieldID) for _, f := range fields { if !message.ValidColumn(f) { - return nil, &ValidationError{Name: f, err: fmt.Errorf("ent: invalid field %q for query", f)} + return nil, &ValidationError{Name: f, err: fmt.Errorf("internal: invalid field %q for query", f)} } if f != message.FieldID { _spec.Node.Columns = append(_spec.Node.Columns, f) @@ -691,66 +585,31 @@ func (muo *MessageUpdateOne) sqlSave(ctx context.Context) (_node *Message, err e } } if value, ok := muo.mutation.RemoteID(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Value: value, - Column: message.FieldRemoteID, - }) + _spec.SetField(message.FieldRemoteID, field.TypeString, value) } if muo.mutation.RemoteIDCleared() { - _spec.Fields.Clear = append(_spec.Fields.Clear, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Column: message.FieldRemoteID, - }) + _spec.ClearField(message.FieldRemoteID, field.TypeString) } if value, ok := muo.mutation.Date(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeTime, - Value: value, - Column: message.FieldDate, - }) + _spec.SetField(message.FieldDate, field.TypeTime, value) } if value, ok := muo.mutation.Size(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Value: value, - Column: message.FieldSize, - }) + _spec.SetField(message.FieldSize, field.TypeInt, value) } if value, ok := muo.mutation.AddedSize(); ok { - _spec.Fields.Add = append(_spec.Fields.Add, &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Value: value, - Column: message.FieldSize, - }) + _spec.AddField(message.FieldSize, field.TypeInt, value) } if value, ok := muo.mutation.Body(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Value: value, - Column: message.FieldBody, - }) + _spec.SetField(message.FieldBody, field.TypeString, value) } if value, ok := muo.mutation.BodyStructure(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Value: value, - Column: message.FieldBodyStructure, - }) + _spec.SetField(message.FieldBodyStructure, field.TypeString, value) } if value, ok := muo.mutation.Envelope(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Value: value, - Column: message.FieldEnvelope, - }) + _spec.SetField(message.FieldEnvelope, field.TypeString, value) } if value, ok := muo.mutation.Deleted(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeBool, - Value: value, - Column: message.FieldDeleted, - }) + _spec.SetField(message.FieldDeleted, field.TypeBool, value) } if muo.mutation.FlagsCleared() { edge := &sqlgraph.EdgeSpec{ @@ -871,5 +730,6 @@ func (muo *MessageUpdateOne) sqlSave(ctx context.Context) (_node *Message, err e } return nil, err } + muo.mutation.done = true return _node, nil } diff --git a/internal/db/ent/messageflag.go b/internal/db_impl/ent_db/internal/messageflag.go similarity index 86% rename from internal/db/ent/messageflag.go rename to internal/db_impl/ent_db/internal/messageflag.go index f1d82172..331e2cef 100644 --- a/internal/db/ent/messageflag.go +++ b/internal/db_impl/ent_db/internal/messageflag.go @@ -1,6 +1,6 @@ // Code generated by ent, DO NOT EDIT. -package ent +package internal import ( "fmt" @@ -8,8 +8,8 @@ import ( "entgo.io/ent/dialect/sql" "github.com/ProtonMail/gluon/imap" - "github.com/ProtonMail/gluon/internal/db/ent/message" - "github.com/ProtonMail/gluon/internal/db/ent/messageflag" + "github.com/ProtonMail/gluon/internal/db_impl/ent_db/internal/message" + "github.com/ProtonMail/gluon/internal/db_impl/ent_db/internal/messageflag" ) // MessageFlag is the model entity for the MessageFlag schema. @@ -48,8 +48,8 @@ func (e MessageFlagEdges) MessagesOrErr() (*Message, error) { } // scanValues returns the types for scanning values from sql.Rows. -func (*MessageFlag) scanValues(columns []string) ([]interface{}, error) { - values := make([]interface{}, len(columns)) +func (*MessageFlag) scanValues(columns []string) ([]any, error) { + values := make([]any, len(columns)) for i := range columns { switch columns[i] { case messageflag.FieldID: @@ -67,7 +67,7 @@ func (*MessageFlag) scanValues(columns []string) ([]interface{}, error) { // assignValues assigns the values that were returned from sql.Rows (after scanning) // to the MessageFlag fields. -func (mf *MessageFlag) assignValues(columns []string, values []interface{}) error { +func (mf *MessageFlag) assignValues(columns []string, values []any) error { if m, n := len(values), len(columns); m < n { return fmt.Errorf("mismatch number of scan values: %d != %d", m, n) } @@ -99,14 +99,14 @@ func (mf *MessageFlag) assignValues(columns []string, values []interface{}) erro // QueryMessages queries the "messages" edge of the MessageFlag entity. func (mf *MessageFlag) QueryMessages() *MessageQuery { - return (&MessageFlagClient{config: mf.config}).QueryMessages(mf) + return NewMessageFlagClient(mf.config).QueryMessages(mf) } // Update returns a builder for updating this MessageFlag. // Note that you need to call MessageFlag.Unwrap() before calling this method if this MessageFlag // was returned from a transaction, and the transaction was committed or rolled back. func (mf *MessageFlag) Update() *MessageFlagUpdateOne { - return (&MessageFlagClient{config: mf.config}).UpdateOne(mf) + return NewMessageFlagClient(mf.config).UpdateOne(mf) } // Unwrap unwraps the MessageFlag entity that was returned from a transaction after it was closed, @@ -114,7 +114,7 @@ func (mf *MessageFlag) Update() *MessageFlagUpdateOne { func (mf *MessageFlag) Unwrap() *MessageFlag { _tx, ok := mf.config.driver.(*txDriver) if !ok { - panic("ent: MessageFlag is not a transactional entity") + panic("internal: MessageFlag is not a transactional entity") } mf.config.driver = _tx.drv return mf @@ -133,9 +133,3 @@ func (mf *MessageFlag) String() string { // MessageFlags is a parsable slice of MessageFlag. type MessageFlags []*MessageFlag - -func (mf MessageFlags) config(cfg config) { - for _i := range mf { - mf[_i].config = cfg - } -} diff --git a/internal/db/ent/messageflag/messageflag.go b/internal/db_impl/ent_db/internal/messageflag/messageflag.go similarity index 100% rename from internal/db/ent/messageflag/messageflag.go rename to internal/db_impl/ent_db/internal/messageflag/messageflag.go diff --git a/internal/db/ent/messageflag/where.go b/internal/db_impl/ent_db/internal/messageflag/where.go similarity index 62% rename from internal/db/ent/messageflag/where.go rename to internal/db_impl/ent_db/internal/messageflag/where.go index a4377f12..65cc881a 100644 --- a/internal/db/ent/messageflag/where.go +++ b/internal/db_impl/ent_db/internal/messageflag/where.go @@ -5,184 +5,122 @@ package messageflag import ( "entgo.io/ent/dialect/sql" "entgo.io/ent/dialect/sql/sqlgraph" - "github.com/ProtonMail/gluon/internal/db/ent/predicate" + "github.com/ProtonMail/gluon/internal/db_impl/ent_db/internal/predicate" ) // ID filters vertices based on their ID field. func ID(id int) predicate.MessageFlag { - return predicate.MessageFlag(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldID), id)) - }) + return predicate.MessageFlag(sql.FieldEQ(FieldID, id)) } // IDEQ applies the EQ predicate on the ID field. func IDEQ(id int) predicate.MessageFlag { - return predicate.MessageFlag(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldID), id)) - }) + return predicate.MessageFlag(sql.FieldEQ(FieldID, id)) } // IDNEQ applies the NEQ predicate on the ID field. func IDNEQ(id int) predicate.MessageFlag { - return predicate.MessageFlag(func(s *sql.Selector) { - s.Where(sql.NEQ(s.C(FieldID), id)) - }) + return predicate.MessageFlag(sql.FieldNEQ(FieldID, id)) } // IDIn applies the In predicate on the ID field. func IDIn(ids ...int) predicate.MessageFlag { - return predicate.MessageFlag(func(s *sql.Selector) { - v := make([]interface{}, len(ids)) - for i := range v { - v[i] = ids[i] - } - s.Where(sql.In(s.C(FieldID), v...)) - }) + return predicate.MessageFlag(sql.FieldIn(FieldID, ids...)) } // IDNotIn applies the NotIn predicate on the ID field. func IDNotIn(ids ...int) predicate.MessageFlag { - return predicate.MessageFlag(func(s *sql.Selector) { - v := make([]interface{}, len(ids)) - for i := range v { - v[i] = ids[i] - } - s.Where(sql.NotIn(s.C(FieldID), v...)) - }) + return predicate.MessageFlag(sql.FieldNotIn(FieldID, ids...)) } // IDGT applies the GT predicate on the ID field. func IDGT(id int) predicate.MessageFlag { - return predicate.MessageFlag(func(s *sql.Selector) { - s.Where(sql.GT(s.C(FieldID), id)) - }) + return predicate.MessageFlag(sql.FieldGT(FieldID, id)) } // IDGTE applies the GTE predicate on the ID field. func IDGTE(id int) predicate.MessageFlag { - return predicate.MessageFlag(func(s *sql.Selector) { - s.Where(sql.GTE(s.C(FieldID), id)) - }) + return predicate.MessageFlag(sql.FieldGTE(FieldID, id)) } // IDLT applies the LT predicate on the ID field. func IDLT(id int) predicate.MessageFlag { - return predicate.MessageFlag(func(s *sql.Selector) { - s.Where(sql.LT(s.C(FieldID), id)) - }) + return predicate.MessageFlag(sql.FieldLT(FieldID, id)) } // IDLTE applies the LTE predicate on the ID field. func IDLTE(id int) predicate.MessageFlag { - return predicate.MessageFlag(func(s *sql.Selector) { - s.Where(sql.LTE(s.C(FieldID), id)) - }) + return predicate.MessageFlag(sql.FieldLTE(FieldID, id)) } // Value applies equality check predicate on the "Value" field. It's identical to ValueEQ. func Value(v string) predicate.MessageFlag { - return predicate.MessageFlag(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldValue), v)) - }) + return predicate.MessageFlag(sql.FieldEQ(FieldValue, v)) } // ValueEQ applies the EQ predicate on the "Value" field. func ValueEQ(v string) predicate.MessageFlag { - return predicate.MessageFlag(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldValue), v)) - }) + return predicate.MessageFlag(sql.FieldEQ(FieldValue, v)) } // ValueNEQ applies the NEQ predicate on the "Value" field. func ValueNEQ(v string) predicate.MessageFlag { - return predicate.MessageFlag(func(s *sql.Selector) { - s.Where(sql.NEQ(s.C(FieldValue), v)) - }) + return predicate.MessageFlag(sql.FieldNEQ(FieldValue, v)) } // ValueIn applies the In predicate on the "Value" field. func ValueIn(vs ...string) predicate.MessageFlag { - v := make([]interface{}, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.MessageFlag(func(s *sql.Selector) { - s.Where(sql.In(s.C(FieldValue), v...)) - }) + return predicate.MessageFlag(sql.FieldIn(FieldValue, vs...)) } // ValueNotIn applies the NotIn predicate on the "Value" field. func ValueNotIn(vs ...string) predicate.MessageFlag { - v := make([]interface{}, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.MessageFlag(func(s *sql.Selector) { - s.Where(sql.NotIn(s.C(FieldValue), v...)) - }) + return predicate.MessageFlag(sql.FieldNotIn(FieldValue, vs...)) } // ValueGT applies the GT predicate on the "Value" field. func ValueGT(v string) predicate.MessageFlag { - return predicate.MessageFlag(func(s *sql.Selector) { - s.Where(sql.GT(s.C(FieldValue), v)) - }) + return predicate.MessageFlag(sql.FieldGT(FieldValue, v)) } // ValueGTE applies the GTE predicate on the "Value" field. func ValueGTE(v string) predicate.MessageFlag { - return predicate.MessageFlag(func(s *sql.Selector) { - s.Where(sql.GTE(s.C(FieldValue), v)) - }) + return predicate.MessageFlag(sql.FieldGTE(FieldValue, v)) } // ValueLT applies the LT predicate on the "Value" field. func ValueLT(v string) predicate.MessageFlag { - return predicate.MessageFlag(func(s *sql.Selector) { - s.Where(sql.LT(s.C(FieldValue), v)) - }) + return predicate.MessageFlag(sql.FieldLT(FieldValue, v)) } // ValueLTE applies the LTE predicate on the "Value" field. func ValueLTE(v string) predicate.MessageFlag { - return predicate.MessageFlag(func(s *sql.Selector) { - s.Where(sql.LTE(s.C(FieldValue), v)) - }) + return predicate.MessageFlag(sql.FieldLTE(FieldValue, v)) } // ValueContains applies the Contains predicate on the "Value" field. func ValueContains(v string) predicate.MessageFlag { - return predicate.MessageFlag(func(s *sql.Selector) { - s.Where(sql.Contains(s.C(FieldValue), v)) - }) + return predicate.MessageFlag(sql.FieldContains(FieldValue, v)) } // ValueHasPrefix applies the HasPrefix predicate on the "Value" field. func ValueHasPrefix(v string) predicate.MessageFlag { - return predicate.MessageFlag(func(s *sql.Selector) { - s.Where(sql.HasPrefix(s.C(FieldValue), v)) - }) + return predicate.MessageFlag(sql.FieldHasPrefix(FieldValue, v)) } // ValueHasSuffix applies the HasSuffix predicate on the "Value" field. func ValueHasSuffix(v string) predicate.MessageFlag { - return predicate.MessageFlag(func(s *sql.Selector) { - s.Where(sql.HasSuffix(s.C(FieldValue), v)) - }) + return predicate.MessageFlag(sql.FieldHasSuffix(FieldValue, v)) } // ValueEqualFold applies the EqualFold predicate on the "Value" field. func ValueEqualFold(v string) predicate.MessageFlag { - return predicate.MessageFlag(func(s *sql.Selector) { - s.Where(sql.EqualFold(s.C(FieldValue), v)) - }) + return predicate.MessageFlag(sql.FieldEqualFold(FieldValue, v)) } // ValueContainsFold applies the ContainsFold predicate on the "Value" field. func ValueContainsFold(v string) predicate.MessageFlag { - return predicate.MessageFlag(func(s *sql.Selector) { - s.Where(sql.ContainsFold(s.C(FieldValue), v)) - }) + return predicate.MessageFlag(sql.FieldContainsFold(FieldValue, v)) } // HasMessages applies the HasEdge predicate on the "messages" edge. @@ -190,7 +128,6 @@ func HasMessages() predicate.MessageFlag { return predicate.MessageFlag(func(s *sql.Selector) { step := sqlgraph.NewStep( sqlgraph.From(Table, FieldID), - sqlgraph.To(MessagesTable, FieldID), sqlgraph.Edge(sqlgraph.M2O, true, MessagesTable, MessagesColumn), ) sqlgraph.HasNeighbors(s, step) diff --git a/internal/db/ent/messageflag_create.go b/internal/db_impl/ent_db/internal/messageflag_create.go similarity index 78% rename from internal/db/ent/messageflag_create.go rename to internal/db_impl/ent_db/internal/messageflag_create.go index 7b7bec11..c99f9e1a 100644 --- a/internal/db/ent/messageflag_create.go +++ b/internal/db_impl/ent_db/internal/messageflag_create.go @@ -1,6 +1,6 @@ // Code generated by ent, DO NOT EDIT. -package ent +package internal import ( "context" @@ -10,8 +10,8 @@ import ( "entgo.io/ent/dialect/sql/sqlgraph" "entgo.io/ent/schema/field" "github.com/ProtonMail/gluon/imap" - "github.com/ProtonMail/gluon/internal/db/ent/message" - "github.com/ProtonMail/gluon/internal/db/ent/messageflag" + "github.com/ProtonMail/gluon/internal/db_impl/ent_db/internal/message" + "github.com/ProtonMail/gluon/internal/db_impl/ent_db/internal/messageflag" ) // MessageFlagCreate is the builder for creating a MessageFlag entity. @@ -53,49 +53,7 @@ func (mfc *MessageFlagCreate) Mutation() *MessageFlagMutation { // Save creates the MessageFlag in the database. func (mfc *MessageFlagCreate) Save(ctx context.Context) (*MessageFlag, error) { - var ( - err error - node *MessageFlag - ) - if len(mfc.hooks) == 0 { - if err = mfc.check(); err != nil { - return nil, err - } - node, err = mfc.sqlSave(ctx) - } else { - var mut Mutator = MutateFunc(func(ctx context.Context, m Mutation) (Value, error) { - mutation, ok := m.(*MessageFlagMutation) - if !ok { - return nil, fmt.Errorf("unexpected mutation type %T", m) - } - if err = mfc.check(); err != nil { - return nil, err - } - mfc.mutation = mutation - if node, err = mfc.sqlSave(ctx); err != nil { - return nil, err - } - mutation.id = &node.ID - mutation.done = true - return node, err - }) - for i := len(mfc.hooks) - 1; i >= 0; i-- { - if mfc.hooks[i] == nil { - return nil, fmt.Errorf("ent: uninitialized hook (forgotten import ent/runtime?)") - } - mut = mfc.hooks[i](mut) - } - v, err := mut.Mutate(ctx, mfc.mutation) - if err != nil { - return nil, err - } - nv, ok := v.(*MessageFlag) - if !ok { - return nil, fmt.Errorf("unexpected node type %T returned from MessageFlagMutation", v) - } - node = nv - } - return node, err + return withHooks[*MessageFlag, MessageFlagMutation](ctx, mfc.sqlSave, mfc.mutation, mfc.hooks) } // SaveX calls Save and panics if Save returns an error. @@ -123,12 +81,15 @@ func (mfc *MessageFlagCreate) ExecX(ctx context.Context) { // check runs all checks and user-defined validators on the builder. func (mfc *MessageFlagCreate) check() error { if _, ok := mfc.mutation.Value(); !ok { - return &ValidationError{Name: "Value", err: errors.New(`ent: missing required field "MessageFlag.Value"`)} + return &ValidationError{Name: "Value", err: errors.New(`internal: missing required field "MessageFlag.Value"`)} } return nil } func (mfc *MessageFlagCreate) sqlSave(ctx context.Context) (*MessageFlag, error) { + if err := mfc.check(); err != nil { + return nil, err + } _node, _spec := mfc.createSpec() if err := sqlgraph.CreateNode(ctx, mfc.driver, _spec); err != nil { if sqlgraph.IsConstraintError(err) { @@ -138,26 +99,18 @@ func (mfc *MessageFlagCreate) sqlSave(ctx context.Context) (*MessageFlag, error) } id := _spec.ID.Value.(int64) _node.ID = int(id) + mfc.mutation.id = &_node.ID + mfc.mutation.done = true return _node, nil } func (mfc *MessageFlagCreate) createSpec() (*MessageFlag, *sqlgraph.CreateSpec) { var ( _node = &MessageFlag{config: mfc.config} - _spec = &sqlgraph.CreateSpec{ - Table: messageflag.Table, - ID: &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Column: messageflag.FieldID, - }, - } + _spec = sqlgraph.NewCreateSpec(messageflag.Table, sqlgraph.NewFieldSpec(messageflag.FieldID, field.TypeInt)) ) if value, ok := mfc.mutation.Value(); ok { - _spec.Fields = append(_spec.Fields, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Value: value, - Column: messageflag.FieldValue, - }) + _spec.SetField(messageflag.FieldValue, field.TypeString, value) _node.Value = value } if nodes := mfc.mutation.MessagesIDs(); len(nodes) > 0 { diff --git a/internal/db/ent/messageflag_delete.go b/internal/db_impl/ent_db/internal/messageflag_delete.go similarity index 62% rename from internal/db/ent/messageflag_delete.go rename to internal/db_impl/ent_db/internal/messageflag_delete.go index 728d5d8d..f6371b70 100644 --- a/internal/db/ent/messageflag_delete.go +++ b/internal/db_impl/ent_db/internal/messageflag_delete.go @@ -1,16 +1,15 @@ // Code generated by ent, DO NOT EDIT. -package ent +package internal import ( "context" - "fmt" "entgo.io/ent/dialect/sql" "entgo.io/ent/dialect/sql/sqlgraph" "entgo.io/ent/schema/field" - "github.com/ProtonMail/gluon/internal/db/ent/messageflag" - "github.com/ProtonMail/gluon/internal/db/ent/predicate" + "github.com/ProtonMail/gluon/internal/db_impl/ent_db/internal/messageflag" + "github.com/ProtonMail/gluon/internal/db_impl/ent_db/internal/predicate" ) // MessageFlagDelete is the builder for deleting a MessageFlag entity. @@ -28,34 +27,7 @@ func (mfd *MessageFlagDelete) Where(ps ...predicate.MessageFlag) *MessageFlagDel // Exec executes the deletion query and returns how many vertices were deleted. func (mfd *MessageFlagDelete) Exec(ctx context.Context) (int, error) { - var ( - err error - affected int - ) - if len(mfd.hooks) == 0 { - affected, err = mfd.sqlExec(ctx) - } else { - var mut Mutator = MutateFunc(func(ctx context.Context, m Mutation) (Value, error) { - mutation, ok := m.(*MessageFlagMutation) - if !ok { - return nil, fmt.Errorf("unexpected mutation type %T", m) - } - mfd.mutation = mutation - affected, err = mfd.sqlExec(ctx) - mutation.done = true - return affected, err - }) - for i := len(mfd.hooks) - 1; i >= 0; i-- { - if mfd.hooks[i] == nil { - return 0, fmt.Errorf("ent: uninitialized hook (forgotten import ent/runtime?)") - } - mut = mfd.hooks[i](mut) - } - if _, err := mut.Mutate(ctx, mfd.mutation); err != nil { - return 0, err - } - } - return affected, err + return withHooks[int, MessageFlagMutation](ctx, mfd.sqlExec, mfd.mutation, mfd.hooks) } // ExecX is like Exec, but panics if an error occurs. @@ -68,15 +40,7 @@ func (mfd *MessageFlagDelete) ExecX(ctx context.Context) int { } func (mfd *MessageFlagDelete) sqlExec(ctx context.Context) (int, error) { - _spec := &sqlgraph.DeleteSpec{ - Node: &sqlgraph.NodeSpec{ - Table: messageflag.Table, - ID: &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Column: messageflag.FieldID, - }, - }, - } + _spec := sqlgraph.NewDeleteSpec(messageflag.Table, sqlgraph.NewFieldSpec(messageflag.FieldID, field.TypeInt)) if ps := mfd.mutation.predicates; len(ps) > 0 { _spec.Predicate = func(selector *sql.Selector) { for i := range ps { @@ -88,6 +52,7 @@ func (mfd *MessageFlagDelete) sqlExec(ctx context.Context) (int, error) { if err != nil && sqlgraph.IsConstraintError(err) { err = &ConstraintError{msg: err.Error(), wrap: err} } + mfd.mutation.done = true return affected, err } @@ -96,6 +61,12 @@ type MessageFlagDeleteOne struct { mfd *MessageFlagDelete } +// Where appends a list predicates to the MessageFlagDelete builder. +func (mfdo *MessageFlagDeleteOne) Where(ps ...predicate.MessageFlag) *MessageFlagDeleteOne { + mfdo.mfd.mutation.Where(ps...) + return mfdo +} + // Exec executes the deletion query. func (mfdo *MessageFlagDeleteOne) Exec(ctx context.Context) error { n, err := mfdo.mfd.Exec(ctx) @@ -111,5 +82,7 @@ func (mfdo *MessageFlagDeleteOne) Exec(ctx context.Context) error { // ExecX is like Exec, but panics if an error occurs. func (mfdo *MessageFlagDeleteOne) ExecX(ctx context.Context) { - mfdo.mfd.ExecX(ctx) + if err := mfdo.Exec(ctx); err != nil { + panic(err) + } } diff --git a/internal/db/ent/messageflag_query.go b/internal/db_impl/ent_db/internal/messageflag_query.go similarity index 70% rename from internal/db/ent/messageflag_query.go rename to internal/db_impl/ent_db/internal/messageflag_query.go index 39244ca8..5d15eb5d 100644 --- a/internal/db/ent/messageflag_query.go +++ b/internal/db_impl/ent_db/internal/messageflag_query.go @@ -1,6 +1,6 @@ // Code generated by ent, DO NOT EDIT. -package ent +package internal import ( "context" @@ -11,19 +11,17 @@ import ( "entgo.io/ent/dialect/sql/sqlgraph" "entgo.io/ent/schema/field" "github.com/ProtonMail/gluon/imap" - "github.com/ProtonMail/gluon/internal/db/ent/message" - "github.com/ProtonMail/gluon/internal/db/ent/messageflag" - "github.com/ProtonMail/gluon/internal/db/ent/predicate" + "github.com/ProtonMail/gluon/internal/db_impl/ent_db/internal/message" + "github.com/ProtonMail/gluon/internal/db_impl/ent_db/internal/messageflag" + "github.com/ProtonMail/gluon/internal/db_impl/ent_db/internal/predicate" ) // MessageFlagQuery is the builder for querying MessageFlag entities. type MessageFlagQuery struct { config - limit *int - offset *int - unique *bool + ctx *QueryContext order []OrderFunc - fields []string + inters []Interceptor predicates []predicate.MessageFlag withMessages *MessageQuery withFKs bool @@ -38,26 +36,26 @@ func (mfq *MessageFlagQuery) Where(ps ...predicate.MessageFlag) *MessageFlagQuer return mfq } -// Limit adds a limit step to the query. +// Limit the number of records to be returned by this query. func (mfq *MessageFlagQuery) Limit(limit int) *MessageFlagQuery { - mfq.limit = &limit + mfq.ctx.Limit = &limit return mfq } -// Offset adds an offset step to the query. +// Offset to start from. func (mfq *MessageFlagQuery) Offset(offset int) *MessageFlagQuery { - mfq.offset = &offset + mfq.ctx.Offset = &offset return mfq } // Unique configures the query builder to filter duplicate records on query. // By default, unique is set to true, and can be disabled using this method. func (mfq *MessageFlagQuery) Unique(unique bool) *MessageFlagQuery { - mfq.unique = &unique + mfq.ctx.Unique = &unique return mfq } -// Order adds an order step to the query. +// Order specifies how the records should be ordered. func (mfq *MessageFlagQuery) Order(o ...OrderFunc) *MessageFlagQuery { mfq.order = append(mfq.order, o...) return mfq @@ -65,7 +63,7 @@ func (mfq *MessageFlagQuery) Order(o ...OrderFunc) *MessageFlagQuery { // QueryMessages chains the current query on the "messages" edge. func (mfq *MessageFlagQuery) QueryMessages() *MessageQuery { - query := &MessageQuery{config: mfq.config} + query := (&MessageClient{config: mfq.config}).Query() query.path = func(ctx context.Context) (fromU *sql.Selector, err error) { if err := mfq.prepareQuery(ctx); err != nil { return nil, err @@ -88,7 +86,7 @@ func (mfq *MessageFlagQuery) QueryMessages() *MessageQuery { // First returns the first MessageFlag entity from the query. // Returns a *NotFoundError when no MessageFlag was found. func (mfq *MessageFlagQuery) First(ctx context.Context) (*MessageFlag, error) { - nodes, err := mfq.Limit(1).All(ctx) + nodes, err := mfq.Limit(1).All(setContextOp(ctx, mfq.ctx, "First")) if err != nil { return nil, err } @@ -111,7 +109,7 @@ func (mfq *MessageFlagQuery) FirstX(ctx context.Context) *MessageFlag { // Returns a *NotFoundError when no MessageFlag ID was found. func (mfq *MessageFlagQuery) FirstID(ctx context.Context) (id int, err error) { var ids []int - if ids, err = mfq.Limit(1).IDs(ctx); err != nil { + if ids, err = mfq.Limit(1).IDs(setContextOp(ctx, mfq.ctx, "FirstID")); err != nil { return } if len(ids) == 0 { @@ -134,7 +132,7 @@ func (mfq *MessageFlagQuery) FirstIDX(ctx context.Context) int { // Returns a *NotSingularError when more than one MessageFlag entity is found. // Returns a *NotFoundError when no MessageFlag entities are found. func (mfq *MessageFlagQuery) Only(ctx context.Context) (*MessageFlag, error) { - nodes, err := mfq.Limit(2).All(ctx) + nodes, err := mfq.Limit(2).All(setContextOp(ctx, mfq.ctx, "Only")) if err != nil { return nil, err } @@ -162,7 +160,7 @@ func (mfq *MessageFlagQuery) OnlyX(ctx context.Context) *MessageFlag { // Returns a *NotFoundError when no entities are found. func (mfq *MessageFlagQuery) OnlyID(ctx context.Context) (id int, err error) { var ids []int - if ids, err = mfq.Limit(2).IDs(ctx); err != nil { + if ids, err = mfq.Limit(2).IDs(setContextOp(ctx, mfq.ctx, "OnlyID")); err != nil { return } switch len(ids) { @@ -187,10 +185,12 @@ func (mfq *MessageFlagQuery) OnlyIDX(ctx context.Context) int { // All executes the query and returns a list of MessageFlags. func (mfq *MessageFlagQuery) All(ctx context.Context) ([]*MessageFlag, error) { + ctx = setContextOp(ctx, mfq.ctx, "All") if err := mfq.prepareQuery(ctx); err != nil { return nil, err } - return mfq.sqlAll(ctx) + qr := querierAll[[]*MessageFlag, *MessageFlagQuery]() + return withInterceptors[[]*MessageFlag](ctx, mfq, qr, mfq.inters) } // AllX is like All, but panics if an error occurs. @@ -203,9 +203,12 @@ func (mfq *MessageFlagQuery) AllX(ctx context.Context) []*MessageFlag { } // IDs executes the query and returns a list of MessageFlag IDs. -func (mfq *MessageFlagQuery) IDs(ctx context.Context) ([]int, error) { - var ids []int - if err := mfq.Select(messageflag.FieldID).Scan(ctx, &ids); err != nil { +func (mfq *MessageFlagQuery) IDs(ctx context.Context) (ids []int, err error) { + if mfq.ctx.Unique == nil && mfq.path != nil { + mfq.Unique(true) + } + ctx = setContextOp(ctx, mfq.ctx, "IDs") + if err = mfq.Select(messageflag.FieldID).Scan(ctx, &ids); err != nil { return nil, err } return ids, nil @@ -222,10 +225,11 @@ func (mfq *MessageFlagQuery) IDsX(ctx context.Context) []int { // Count returns the count of the given query. func (mfq *MessageFlagQuery) Count(ctx context.Context) (int, error) { + ctx = setContextOp(ctx, mfq.ctx, "Count") if err := mfq.prepareQuery(ctx); err != nil { return 0, err } - return mfq.sqlCount(ctx) + return withInterceptors[int](ctx, mfq, querierCount[*MessageFlagQuery](), mfq.inters) } // CountX is like Count, but panics if an error occurs. @@ -239,10 +243,15 @@ func (mfq *MessageFlagQuery) CountX(ctx context.Context) int { // Exist returns true if the query has elements in the graph. func (mfq *MessageFlagQuery) Exist(ctx context.Context) (bool, error) { - if err := mfq.prepareQuery(ctx); err != nil { - return false, err + ctx = setContextOp(ctx, mfq.ctx, "Exist") + switch _, err := mfq.FirstID(ctx); { + case IsNotFound(err): + return false, nil + case err != nil: + return false, fmt.Errorf("internal: check existence: %w", err) + default: + return true, nil } - return mfq.sqlExist(ctx) } // ExistX is like Exist, but panics if an error occurs. @@ -262,22 +271,21 @@ func (mfq *MessageFlagQuery) Clone() *MessageFlagQuery { } return &MessageFlagQuery{ config: mfq.config, - limit: mfq.limit, - offset: mfq.offset, + ctx: mfq.ctx.Clone(), order: append([]OrderFunc{}, mfq.order...), + inters: append([]Interceptor{}, mfq.inters...), predicates: append([]predicate.MessageFlag{}, mfq.predicates...), withMessages: mfq.withMessages.Clone(), // clone intermediate query. - sql: mfq.sql.Clone(), - path: mfq.path, - unique: mfq.unique, + sql: mfq.sql.Clone(), + path: mfq.path, } } // WithMessages tells the query-builder to eager-load the nodes that are connected to // the "messages" edge. The optional arguments are used to configure the query builder of the edge. func (mfq *MessageFlagQuery) WithMessages(opts ...func(*MessageQuery)) *MessageFlagQuery { - query := &MessageQuery{config: mfq.config} + query := (&MessageClient{config: mfq.config}).Query() for _, opt := range opts { opt(query) } @@ -297,19 +305,14 @@ func (mfq *MessageFlagQuery) WithMessages(opts ...func(*MessageQuery)) *MessageF // // client.MessageFlag.Query(). // GroupBy(messageflag.FieldValue). -// Aggregate(ent.Count()). +// Aggregate(internal.Count()). // Scan(ctx, &v) func (mfq *MessageFlagQuery) GroupBy(field string, fields ...string) *MessageFlagGroupBy { - grbuild := &MessageFlagGroupBy{config: mfq.config} - grbuild.fields = append([]string{field}, fields...) - grbuild.path = func(ctx context.Context) (prev *sql.Selector, err error) { - if err := mfq.prepareQuery(ctx); err != nil { - return nil, err - } - return mfq.sqlQuery(ctx), nil - } + mfq.ctx.Fields = append([]string{field}, fields...) + grbuild := &MessageFlagGroupBy{build: mfq} + grbuild.flds = &mfq.ctx.Fields grbuild.label = messageflag.Label - grbuild.flds, grbuild.scan = &grbuild.fields, grbuild.Scan + grbuild.scan = grbuild.Scan return grbuild } @@ -326,17 +329,32 @@ func (mfq *MessageFlagQuery) GroupBy(field string, fields ...string) *MessageFla // Select(messageflag.FieldValue). // Scan(ctx, &v) func (mfq *MessageFlagQuery) Select(fields ...string) *MessageFlagSelect { - mfq.fields = append(mfq.fields, fields...) - selbuild := &MessageFlagSelect{MessageFlagQuery: mfq} - selbuild.label = messageflag.Label - selbuild.flds, selbuild.scan = &mfq.fields, selbuild.Scan - return selbuild + mfq.ctx.Fields = append(mfq.ctx.Fields, fields...) + sbuild := &MessageFlagSelect{MessageFlagQuery: mfq} + sbuild.label = messageflag.Label + sbuild.flds, sbuild.scan = &mfq.ctx.Fields, sbuild.Scan + return sbuild +} + +// Aggregate returns a MessageFlagSelect configured with the given aggregations. +func (mfq *MessageFlagQuery) Aggregate(fns ...AggregateFunc) *MessageFlagSelect { + return mfq.Select().Aggregate(fns...) } func (mfq *MessageFlagQuery) prepareQuery(ctx context.Context) error { - for _, f := range mfq.fields { + for _, inter := range mfq.inters { + if inter == nil { + return fmt.Errorf("internal: uninitialized interceptor (forgotten import internal/runtime?)") + } + if trv, ok := inter.(Traverser); ok { + if err := trv.Traverse(ctx, mfq); err != nil { + return err + } + } + } + for _, f := range mfq.ctx.Fields { if !messageflag.ValidColumn(f) { - return &ValidationError{Name: f, err: fmt.Errorf("ent: invalid field %q for query", f)} + return &ValidationError{Name: f, err: fmt.Errorf("internal: invalid field %q for query", f)} } } if mfq.path != nil { @@ -364,10 +382,10 @@ func (mfq *MessageFlagQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([] if withFKs { _spec.Node.Columns = append(_spec.Node.Columns, messageflag.ForeignKeys...) } - _spec.ScanValues = func(columns []string) ([]interface{}, error) { + _spec.ScanValues = func(columns []string) ([]any, error) { return (*MessageFlag).scanValues(nil, columns) } - _spec.Assign = func(columns []string, values []interface{}) error { + _spec.Assign = func(columns []string, values []any) error { node := &MessageFlag{config: mfq.config} nodes = append(nodes, node) node.Edges.loadedTypes = loadedTypes @@ -404,6 +422,9 @@ func (mfq *MessageFlagQuery) loadMessages(ctx context.Context, query *MessageQue } nodeids[fk] = append(nodeids[fk], nodes[i]) } + if len(ids) == 0 { + return nil + } query.Where(message.IDIn(ids...)) neighbors, err := query.All(ctx) if err != nil { @@ -423,38 +444,22 @@ func (mfq *MessageFlagQuery) loadMessages(ctx context.Context, query *MessageQue func (mfq *MessageFlagQuery) sqlCount(ctx context.Context) (int, error) { _spec := mfq.querySpec() - _spec.Node.Columns = mfq.fields - if len(mfq.fields) > 0 { - _spec.Unique = mfq.unique != nil && *mfq.unique + _spec.Node.Columns = mfq.ctx.Fields + if len(mfq.ctx.Fields) > 0 { + _spec.Unique = mfq.ctx.Unique != nil && *mfq.ctx.Unique } return sqlgraph.CountNodes(ctx, mfq.driver, _spec) } -func (mfq *MessageFlagQuery) sqlExist(ctx context.Context) (bool, error) { - n, err := mfq.sqlCount(ctx) - if err != nil { - return false, fmt.Errorf("ent: check existence: %w", err) - } - return n > 0, nil -} - func (mfq *MessageFlagQuery) querySpec() *sqlgraph.QuerySpec { - _spec := &sqlgraph.QuerySpec{ - Node: &sqlgraph.NodeSpec{ - Table: messageflag.Table, - Columns: messageflag.Columns, - ID: &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Column: messageflag.FieldID, - }, - }, - From: mfq.sql, - Unique: true, - } - if unique := mfq.unique; unique != nil { + _spec := sqlgraph.NewQuerySpec(messageflag.Table, messageflag.Columns, sqlgraph.NewFieldSpec(messageflag.FieldID, field.TypeInt)) + _spec.From = mfq.sql + if unique := mfq.ctx.Unique; unique != nil { _spec.Unique = *unique + } else if mfq.path != nil { + _spec.Unique = true } - if fields := mfq.fields; len(fields) > 0 { + if fields := mfq.ctx.Fields; len(fields) > 0 { _spec.Node.Columns = make([]string, 0, len(fields)) _spec.Node.Columns = append(_spec.Node.Columns, messageflag.FieldID) for i := range fields { @@ -470,10 +475,10 @@ func (mfq *MessageFlagQuery) querySpec() *sqlgraph.QuerySpec { } } } - if limit := mfq.limit; limit != nil { + if limit := mfq.ctx.Limit; limit != nil { _spec.Limit = *limit } - if offset := mfq.offset; offset != nil { + if offset := mfq.ctx.Offset; offset != nil { _spec.Offset = *offset } if ps := mfq.order; len(ps) > 0 { @@ -489,7 +494,7 @@ func (mfq *MessageFlagQuery) querySpec() *sqlgraph.QuerySpec { func (mfq *MessageFlagQuery) sqlQuery(ctx context.Context) *sql.Selector { builder := sql.Dialect(mfq.driver.Dialect()) t1 := builder.Table(messageflag.Table) - columns := mfq.fields + columns := mfq.ctx.Fields if len(columns) == 0 { columns = messageflag.Columns } @@ -498,7 +503,7 @@ func (mfq *MessageFlagQuery) sqlQuery(ctx context.Context) *sql.Selector { selector = mfq.sql selector.Select(selector.Columns(columns...)...) } - if mfq.unique != nil && *mfq.unique { + if mfq.ctx.Unique != nil && *mfq.ctx.Unique { selector.Distinct() } for _, p := range mfq.predicates { @@ -507,12 +512,12 @@ func (mfq *MessageFlagQuery) sqlQuery(ctx context.Context) *sql.Selector { for _, p := range mfq.order { p(selector) } - if offset := mfq.offset; offset != nil { + if offset := mfq.ctx.Offset; offset != nil { // limit is mandatory for offset clause. We start // with default value, and override it below if needed. selector.Offset(*offset).Limit(math.MaxInt32) } - if limit := mfq.limit; limit != nil { + if limit := mfq.ctx.Limit; limit != nil { selector.Limit(*limit) } return selector @@ -520,13 +525,8 @@ func (mfq *MessageFlagQuery) sqlQuery(ctx context.Context) *sql.Selector { // MessageFlagGroupBy is the group-by builder for MessageFlag entities. type MessageFlagGroupBy struct { - config selector - fields []string - fns []AggregateFunc - // intermediate query (i.e. traversal path). - sql *sql.Selector - path func(context.Context) (*sql.Selector, error) + build *MessageFlagQuery } // Aggregate adds the given aggregation functions to the group-by query. @@ -535,74 +535,77 @@ func (mfgb *MessageFlagGroupBy) Aggregate(fns ...AggregateFunc) *MessageFlagGrou return mfgb } -// Scan applies the group-by query and scans the result into the given value. -func (mfgb *MessageFlagGroupBy) Scan(ctx context.Context, v interface{}) error { - query, err := mfgb.path(ctx) - if err != nil { +// Scan applies the selector query and scans the result into the given value. +func (mfgb *MessageFlagGroupBy) Scan(ctx context.Context, v any) error { + ctx = setContextOp(ctx, mfgb.build.ctx, "GroupBy") + if err := mfgb.build.prepareQuery(ctx); err != nil { return err } - mfgb.sql = query - return mfgb.sqlScan(ctx, v) + return scanWithInterceptors[*MessageFlagQuery, *MessageFlagGroupBy](ctx, mfgb.build, mfgb, mfgb.build.inters, v) } -func (mfgb *MessageFlagGroupBy) sqlScan(ctx context.Context, v interface{}) error { - for _, f := range mfgb.fields { - if !messageflag.ValidColumn(f) { - return &ValidationError{Name: f, err: fmt.Errorf("invalid field %q for group-by", f)} +func (mfgb *MessageFlagGroupBy) sqlScan(ctx context.Context, root *MessageFlagQuery, v any) error { + selector := root.sqlQuery(ctx).Select() + aggregation := make([]string, 0, len(mfgb.fns)) + for _, fn := range mfgb.fns { + aggregation = append(aggregation, fn(selector)) + } + if len(selector.SelectedColumns()) == 0 { + columns := make([]string, 0, len(*mfgb.flds)+len(mfgb.fns)) + for _, f := range *mfgb.flds { + columns = append(columns, selector.C(f)) } + columns = append(columns, aggregation...) + selector.Select(columns...) } - selector := mfgb.sqlQuery() + selector.GroupBy(selector.Columns(*mfgb.flds...)...) if err := selector.Err(); err != nil { return err } rows := &sql.Rows{} query, args := selector.Query() - if err := mfgb.driver.Query(ctx, query, args, rows); err != nil { + if err := mfgb.build.driver.Query(ctx, query, args, rows); err != nil { return err } defer rows.Close() return sql.ScanSlice(rows, v) } -func (mfgb *MessageFlagGroupBy) sqlQuery() *sql.Selector { - selector := mfgb.sql.Select() - aggregation := make([]string, 0, len(mfgb.fns)) - for _, fn := range mfgb.fns { - aggregation = append(aggregation, fn(selector)) - } - // If no columns were selected in a custom aggregation function, the default - // selection is the fields used for "group-by", and the aggregation functions. - if len(selector.SelectedColumns()) == 0 { - columns := make([]string, 0, len(mfgb.fields)+len(mfgb.fns)) - for _, f := range mfgb.fields { - columns = append(columns, selector.C(f)) - } - columns = append(columns, aggregation...) - selector.Select(columns...) - } - return selector.GroupBy(selector.Columns(mfgb.fields...)...) -} - // MessageFlagSelect is the builder for selecting fields of MessageFlag entities. type MessageFlagSelect struct { *MessageFlagQuery selector - // intermediate query (i.e. traversal path). - sql *sql.Selector +} + +// Aggregate adds the given aggregation functions to the selector query. +func (mfs *MessageFlagSelect) Aggregate(fns ...AggregateFunc) *MessageFlagSelect { + mfs.fns = append(mfs.fns, fns...) + return mfs } // Scan applies the selector query and scans the result into the given value. -func (mfs *MessageFlagSelect) Scan(ctx context.Context, v interface{}) error { +func (mfs *MessageFlagSelect) Scan(ctx context.Context, v any) error { + ctx = setContextOp(ctx, mfs.ctx, "Select") if err := mfs.prepareQuery(ctx); err != nil { return err } - mfs.sql = mfs.MessageFlagQuery.sqlQuery(ctx) - return mfs.sqlScan(ctx, v) + return scanWithInterceptors[*MessageFlagQuery, *MessageFlagSelect](ctx, mfs.MessageFlagQuery, mfs, mfs.inters, v) } -func (mfs *MessageFlagSelect) sqlScan(ctx context.Context, v interface{}) error { +func (mfs *MessageFlagSelect) sqlScan(ctx context.Context, root *MessageFlagQuery, v any) error { + selector := root.sqlQuery(ctx) + aggregation := make([]string, 0, len(mfs.fns)) + for _, fn := range mfs.fns { + aggregation = append(aggregation, fn(selector)) + } + switch n := len(*mfs.selector.flds); { + case n == 0 && len(aggregation) > 0: + selector.Select(aggregation...) + case n != 0 && len(aggregation) > 0: + selector.AppendSelect(aggregation...) + } rows := &sql.Rows{} - query, args := mfs.sql.Query() + query, args := selector.Query() if err := mfs.driver.Query(ctx, query, args, rows); err != nil { return err } diff --git a/internal/db/ent/messageflag_update.go b/internal/db_impl/ent_db/internal/messageflag_update.go similarity index 75% rename from internal/db/ent/messageflag_update.go rename to internal/db_impl/ent_db/internal/messageflag_update.go index 923a9270..0c49ae62 100644 --- a/internal/db/ent/messageflag_update.go +++ b/internal/db_impl/ent_db/internal/messageflag_update.go @@ -1,6 +1,6 @@ // Code generated by ent, DO NOT EDIT. -package ent +package internal import ( "context" @@ -11,9 +11,9 @@ import ( "entgo.io/ent/dialect/sql/sqlgraph" "entgo.io/ent/schema/field" "github.com/ProtonMail/gluon/imap" - "github.com/ProtonMail/gluon/internal/db/ent/message" - "github.com/ProtonMail/gluon/internal/db/ent/messageflag" - "github.com/ProtonMail/gluon/internal/db/ent/predicate" + "github.com/ProtonMail/gluon/internal/db_impl/ent_db/internal/message" + "github.com/ProtonMail/gluon/internal/db_impl/ent_db/internal/messageflag" + "github.com/ProtonMail/gluon/internal/db_impl/ent_db/internal/predicate" ) // MessageFlagUpdate is the builder for updating MessageFlag entities. @@ -67,34 +67,7 @@ func (mfu *MessageFlagUpdate) ClearMessages() *MessageFlagUpdate { // Save executes the query and returns the number of nodes affected by the update operation. func (mfu *MessageFlagUpdate) Save(ctx context.Context) (int, error) { - var ( - err error - affected int - ) - if len(mfu.hooks) == 0 { - affected, err = mfu.sqlSave(ctx) - } else { - var mut Mutator = MutateFunc(func(ctx context.Context, m Mutation) (Value, error) { - mutation, ok := m.(*MessageFlagMutation) - if !ok { - return nil, fmt.Errorf("unexpected mutation type %T", m) - } - mfu.mutation = mutation - affected, err = mfu.sqlSave(ctx) - mutation.done = true - return affected, err - }) - for i := len(mfu.hooks) - 1; i >= 0; i-- { - if mfu.hooks[i] == nil { - return 0, fmt.Errorf("ent: uninitialized hook (forgotten import ent/runtime?)") - } - mut = mfu.hooks[i](mut) - } - if _, err := mut.Mutate(ctx, mfu.mutation); err != nil { - return 0, err - } - } - return affected, err + return withHooks[int, MessageFlagMutation](ctx, mfu.sqlSave, mfu.mutation, mfu.hooks) } // SaveX is like Save, but panics if an error occurs. @@ -120,16 +93,7 @@ func (mfu *MessageFlagUpdate) ExecX(ctx context.Context) { } func (mfu *MessageFlagUpdate) sqlSave(ctx context.Context) (n int, err error) { - _spec := &sqlgraph.UpdateSpec{ - Node: &sqlgraph.NodeSpec{ - Table: messageflag.Table, - Columns: messageflag.Columns, - ID: &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Column: messageflag.FieldID, - }, - }, - } + _spec := sqlgraph.NewUpdateSpec(messageflag.Table, messageflag.Columns, sqlgraph.NewFieldSpec(messageflag.FieldID, field.TypeInt)) if ps := mfu.mutation.predicates; len(ps) > 0 { _spec.Predicate = func(selector *sql.Selector) { for i := range ps { @@ -138,11 +102,7 @@ func (mfu *MessageFlagUpdate) sqlSave(ctx context.Context) (n int, err error) { } } if value, ok := mfu.mutation.Value(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Value: value, - Column: messageflag.FieldValue, - }) + _spec.SetField(messageflag.FieldValue, field.TypeString, value) } if mfu.mutation.MessagesCleared() { edge := &sqlgraph.EdgeSpec{ @@ -187,6 +147,7 @@ func (mfu *MessageFlagUpdate) sqlSave(ctx context.Context) (n int, err error) { } return 0, err } + mfu.mutation.done = true return n, nil } @@ -234,6 +195,12 @@ func (mfuo *MessageFlagUpdateOne) ClearMessages() *MessageFlagUpdateOne { return mfuo } +// Where appends a list predicates to the MessageFlagUpdate builder. +func (mfuo *MessageFlagUpdateOne) Where(ps ...predicate.MessageFlag) *MessageFlagUpdateOne { + mfuo.mutation.Where(ps...) + return mfuo +} + // Select allows selecting one or more fields (columns) of the returned entity. // The default is selecting all fields defined in the entity schema. func (mfuo *MessageFlagUpdateOne) Select(field string, fields ...string) *MessageFlagUpdateOne { @@ -243,40 +210,7 @@ func (mfuo *MessageFlagUpdateOne) Select(field string, fields ...string) *Messag // Save executes the query and returns the updated MessageFlag entity. func (mfuo *MessageFlagUpdateOne) Save(ctx context.Context) (*MessageFlag, error) { - var ( - err error - node *MessageFlag - ) - if len(mfuo.hooks) == 0 { - node, err = mfuo.sqlSave(ctx) - } else { - var mut Mutator = MutateFunc(func(ctx context.Context, m Mutation) (Value, error) { - mutation, ok := m.(*MessageFlagMutation) - if !ok { - return nil, fmt.Errorf("unexpected mutation type %T", m) - } - mfuo.mutation = mutation - node, err = mfuo.sqlSave(ctx) - mutation.done = true - return node, err - }) - for i := len(mfuo.hooks) - 1; i >= 0; i-- { - if mfuo.hooks[i] == nil { - return nil, fmt.Errorf("ent: uninitialized hook (forgotten import ent/runtime?)") - } - mut = mfuo.hooks[i](mut) - } - v, err := mut.Mutate(ctx, mfuo.mutation) - if err != nil { - return nil, err - } - nv, ok := v.(*MessageFlag) - if !ok { - return nil, fmt.Errorf("unexpected node type %T returned from MessageFlagMutation", v) - } - node = nv - } - return node, err + return withHooks[*MessageFlag, MessageFlagMutation](ctx, mfuo.sqlSave, mfuo.mutation, mfuo.hooks) } // SaveX is like Save, but panics if an error occurs. @@ -302,19 +236,10 @@ func (mfuo *MessageFlagUpdateOne) ExecX(ctx context.Context) { } func (mfuo *MessageFlagUpdateOne) sqlSave(ctx context.Context) (_node *MessageFlag, err error) { - _spec := &sqlgraph.UpdateSpec{ - Node: &sqlgraph.NodeSpec{ - Table: messageflag.Table, - Columns: messageflag.Columns, - ID: &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Column: messageflag.FieldID, - }, - }, - } + _spec := sqlgraph.NewUpdateSpec(messageflag.Table, messageflag.Columns, sqlgraph.NewFieldSpec(messageflag.FieldID, field.TypeInt)) id, ok := mfuo.mutation.ID() if !ok { - return nil, &ValidationError{Name: "id", err: errors.New(`ent: missing "MessageFlag.id" for update`)} + return nil, &ValidationError{Name: "id", err: errors.New(`internal: missing "MessageFlag.id" for update`)} } _spec.Node.ID.Value = id if fields := mfuo.fields; len(fields) > 0 { @@ -322,7 +247,7 @@ func (mfuo *MessageFlagUpdateOne) sqlSave(ctx context.Context) (_node *MessageFl _spec.Node.Columns = append(_spec.Node.Columns, messageflag.FieldID) for _, f := range fields { if !messageflag.ValidColumn(f) { - return nil, &ValidationError{Name: f, err: fmt.Errorf("ent: invalid field %q for query", f)} + return nil, &ValidationError{Name: f, err: fmt.Errorf("internal: invalid field %q for query", f)} } if f != messageflag.FieldID { _spec.Node.Columns = append(_spec.Node.Columns, f) @@ -337,11 +262,7 @@ func (mfuo *MessageFlagUpdateOne) sqlSave(ctx context.Context) (_node *MessageFl } } if value, ok := mfuo.mutation.Value(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Value: value, - Column: messageflag.FieldValue, - }) + _spec.SetField(messageflag.FieldValue, field.TypeString, value) } if mfuo.mutation.MessagesCleared() { edge := &sqlgraph.EdgeSpec{ @@ -389,5 +310,6 @@ func (mfuo *MessageFlagUpdateOne) sqlSave(ctx context.Context) (_node *MessageFl } return nil, err } + mfuo.mutation.done = true return _node, nil } diff --git a/internal/db/ent/migrate/migrate.go b/internal/db_impl/ent_db/internal/migrate/migrate.go similarity index 100% rename from internal/db/ent/migrate/migrate.go rename to internal/db_impl/ent_db/internal/migrate/migrate.go diff --git a/internal/db/ent/migrate/schema.go b/internal/db_impl/ent_db/internal/migrate/schema.go similarity index 100% rename from internal/db/ent/migrate/schema.go rename to internal/db_impl/ent_db/internal/migrate/schema.go diff --git a/internal/db/ent/mutation.go b/internal/db_impl/ent_db/internal/mutation.go similarity index 96% rename from internal/db/ent/mutation.go rename to internal/db_impl/ent_db/internal/mutation.go index 8bdc9d5f..d2faa1dd 100644 --- a/internal/db/ent/mutation.go +++ b/internal/db_impl/ent_db/internal/mutation.go @@ -1,6 +1,6 @@ // Code generated by ent, DO NOT EDIT. -package ent +package internal import ( "context" @@ -10,17 +10,18 @@ import ( "time" "github.com/ProtonMail/gluon/imap" - "github.com/ProtonMail/gluon/internal/db/ent/deletedsubscription" - "github.com/ProtonMail/gluon/internal/db/ent/mailbox" - "github.com/ProtonMail/gluon/internal/db/ent/mailboxattr" - "github.com/ProtonMail/gluon/internal/db/ent/mailboxflag" - "github.com/ProtonMail/gluon/internal/db/ent/mailboxpermflag" - "github.com/ProtonMail/gluon/internal/db/ent/message" - "github.com/ProtonMail/gluon/internal/db/ent/messageflag" - "github.com/ProtonMail/gluon/internal/db/ent/predicate" - "github.com/ProtonMail/gluon/internal/db/ent/uid" + "github.com/ProtonMail/gluon/internal/db_impl/ent_db/internal/deletedsubscription" + "github.com/ProtonMail/gluon/internal/db_impl/ent_db/internal/mailbox" + "github.com/ProtonMail/gluon/internal/db_impl/ent_db/internal/mailboxattr" + "github.com/ProtonMail/gluon/internal/db_impl/ent_db/internal/mailboxflag" + "github.com/ProtonMail/gluon/internal/db_impl/ent_db/internal/mailboxpermflag" + "github.com/ProtonMail/gluon/internal/db_impl/ent_db/internal/message" + "github.com/ProtonMail/gluon/internal/db_impl/ent_db/internal/messageflag" + "github.com/ProtonMail/gluon/internal/db_impl/ent_db/internal/predicate" + "github.com/ProtonMail/gluon/internal/db_impl/ent_db/internal/uid" "entgo.io/ent" + "entgo.io/ent/dialect/sql" ) const ( @@ -119,7 +120,7 @@ func (m DeletedSubscriptionMutation) Client() *Client { // it returns an error otherwise. func (m DeletedSubscriptionMutation) Tx() (*Tx, error) { if _, ok := m.driver.(*txDriver); !ok { - return nil, errors.New("ent: mutation is not running in a transaction") + return nil, errors.New("internal: mutation is not running in a transaction") } tx := &Tx{config: m.config} tx.init() @@ -231,11 +232,26 @@ func (m *DeletedSubscriptionMutation) Where(ps ...predicate.DeletedSubscription) m.predicates = append(m.predicates, ps...) } +// WhereP appends storage-level predicates to the DeletedSubscriptionMutation builder. Using this method, +// users can use type-assertion to append predicates that do not depend on any generated package. +func (m *DeletedSubscriptionMutation) WhereP(ps ...func(*sql.Selector)) { + p := make([]predicate.DeletedSubscription, len(ps)) + for i := range ps { + p[i] = ps[i] + } + m.Where(p...) +} + // Op returns the operation name. func (m *DeletedSubscriptionMutation) Op() Op { return m.op } +// SetOp allows setting the mutation operation. +func (m *DeletedSubscriptionMutation) SetOp(op Op) { + m.op = op +} + // Type returns the node type of this mutation (DeletedSubscription). func (m *DeletedSubscriptionMutation) Type() string { return m.typ @@ -501,7 +517,7 @@ func (m MailboxMutation) Client() *Client { // it returns an error otherwise. func (m MailboxMutation) Tx() (*Tx, error) { if _, ok := m.driver.(*txDriver); !ok { - return nil, errors.New("ent: mutation is not running in a transaction") + return nil, errors.New("internal: mutation is not running in a transaction") } tx := &Tx{config: m.config} tx.init() @@ -996,11 +1012,26 @@ func (m *MailboxMutation) Where(ps ...predicate.Mailbox) { m.predicates = append(m.predicates, ps...) } +// WhereP appends storage-level predicates to the MailboxMutation builder. Using this method, +// users can use type-assertion to append predicates that do not depend on any generated package. +func (m *MailboxMutation) WhereP(ps ...func(*sql.Selector)) { + p := make([]predicate.Mailbox, len(ps)) + for i := range ps { + p[i] = ps[i] + } + m.Where(p...) +} + // Op returns the operation name. func (m *MailboxMutation) Op() Op { return m.op } +// SetOp allows setting the mutation operation. +func (m *MailboxMutation) SetOp(op Op) { + m.op = op +} + // Type returns the node type of this mutation (Mailbox). func (m *MailboxMutation) Type() string { return m.typ @@ -1449,7 +1480,7 @@ func (m MailboxAttrMutation) Client() *Client { // it returns an error otherwise. func (m MailboxAttrMutation) Tx() (*Tx, error) { if _, ok := m.driver.(*txDriver); !ok { - return nil, errors.New("ent: mutation is not running in a transaction") + return nil, errors.New("internal: mutation is not running in a transaction") } tx := &Tx{config: m.config} tx.init() @@ -1525,11 +1556,26 @@ func (m *MailboxAttrMutation) Where(ps ...predicate.MailboxAttr) { m.predicates = append(m.predicates, ps...) } +// WhereP appends storage-level predicates to the MailboxAttrMutation builder. Using this method, +// users can use type-assertion to append predicates that do not depend on any generated package. +func (m *MailboxAttrMutation) WhereP(ps ...func(*sql.Selector)) { + p := make([]predicate.MailboxAttr, len(ps)) + for i := range ps { + p[i] = ps[i] + } + m.Where(p...) +} + // Op returns the operation name. func (m *MailboxAttrMutation) Op() Op { return m.op } +// SetOp allows setting the mutation operation. +func (m *MailboxAttrMutation) SetOp(op Op) { + m.op = op +} + // Type returns the node type of this mutation (MailboxAttr). func (m *MailboxAttrMutation) Type() string { return m.typ @@ -1760,7 +1806,7 @@ func (m MailboxFlagMutation) Client() *Client { // it returns an error otherwise. func (m MailboxFlagMutation) Tx() (*Tx, error) { if _, ok := m.driver.(*txDriver); !ok { - return nil, errors.New("ent: mutation is not running in a transaction") + return nil, errors.New("internal: mutation is not running in a transaction") } tx := &Tx{config: m.config} tx.init() @@ -1836,11 +1882,26 @@ func (m *MailboxFlagMutation) Where(ps ...predicate.MailboxFlag) { m.predicates = append(m.predicates, ps...) } +// WhereP appends storage-level predicates to the MailboxFlagMutation builder. Using this method, +// users can use type-assertion to append predicates that do not depend on any generated package. +func (m *MailboxFlagMutation) WhereP(ps ...func(*sql.Selector)) { + p := make([]predicate.MailboxFlag, len(ps)) + for i := range ps { + p[i] = ps[i] + } + m.Where(p...) +} + // Op returns the operation name. func (m *MailboxFlagMutation) Op() Op { return m.op } +// SetOp allows setting the mutation operation. +func (m *MailboxFlagMutation) SetOp(op Op) { + m.op = op +} + // Type returns the node type of this mutation (MailboxFlag). func (m *MailboxFlagMutation) Type() string { return m.typ @@ -2071,7 +2132,7 @@ func (m MailboxPermFlagMutation) Client() *Client { // it returns an error otherwise. func (m MailboxPermFlagMutation) Tx() (*Tx, error) { if _, ok := m.driver.(*txDriver); !ok { - return nil, errors.New("ent: mutation is not running in a transaction") + return nil, errors.New("internal: mutation is not running in a transaction") } tx := &Tx{config: m.config} tx.init() @@ -2147,11 +2208,26 @@ func (m *MailboxPermFlagMutation) Where(ps ...predicate.MailboxPermFlag) { m.predicates = append(m.predicates, ps...) } +// WhereP appends storage-level predicates to the MailboxPermFlagMutation builder. Using this method, +// users can use type-assertion to append predicates that do not depend on any generated package. +func (m *MailboxPermFlagMutation) WhereP(ps ...func(*sql.Selector)) { + p := make([]predicate.MailboxPermFlag, len(ps)) + for i := range ps { + p[i] = ps[i] + } + m.Where(p...) +} + // Op returns the operation name. func (m *MailboxPermFlagMutation) Op() Op { return m.op } +// SetOp allows setting the mutation operation. +func (m *MailboxPermFlagMutation) SetOp(op Op) { + m.op = op +} + // Type returns the node type of this mutation (MailboxPermFlag). func (m *MailboxPermFlagMutation) Type() string { return m.typ @@ -2395,7 +2471,7 @@ func (m MessageMutation) Client() *Client { // it returns an error otherwise. func (m MessageMutation) Tx() (*Tx, error) { if _, ok := m.driver.(*txDriver); !ok { - return nil, errors.New("ent: mutation is not running in a transaction") + return nil, errors.New("internal: mutation is not running in a transaction") } tx := &Tx{config: m.config} tx.init() @@ -2834,11 +2910,26 @@ func (m *MessageMutation) Where(ps ...predicate.Message) { m.predicates = append(m.predicates, ps...) } +// WhereP appends storage-level predicates to the MessageMutation builder. Using this method, +// users can use type-assertion to append predicates that do not depend on any generated package. +func (m *MessageMutation) WhereP(ps ...func(*sql.Selector)) { + p := make([]predicate.Message, len(ps)) + for i := range ps { + p[i] = ps[i] + } + m.Where(p...) +} + // Op returns the operation name. func (m *MessageMutation) Op() Op { return m.op } +// SetOp allows setting the mutation operation. +func (m *MessageMutation) SetOp(op Op) { + m.op = op +} + // Type returns the node type of this mutation (Message). func (m *MessageMutation) Type() string { return m.typ @@ -3259,7 +3350,7 @@ func (m MessageFlagMutation) Client() *Client { // it returns an error otherwise. func (m MessageFlagMutation) Tx() (*Tx, error) { if _, ok := m.driver.(*txDriver); !ok { - return nil, errors.New("ent: mutation is not running in a transaction") + return nil, errors.New("internal: mutation is not running in a transaction") } tx := &Tx{config: m.config} tx.init() @@ -3374,11 +3465,26 @@ func (m *MessageFlagMutation) Where(ps ...predicate.MessageFlag) { m.predicates = append(m.predicates, ps...) } +// WhereP appends storage-level predicates to the MessageFlagMutation builder. Using this method, +// users can use type-assertion to append predicates that do not depend on any generated package. +func (m *MessageFlagMutation) WhereP(ps ...func(*sql.Selector)) { + p := make([]predicate.MessageFlag, len(ps)) + for i := range ps { + p[i] = ps[i] + } + m.Where(p...) +} + // Op returns the operation name. func (m *MessageFlagMutation) Op() Op { return m.op } +// SetOp allows setting the mutation operation. +func (m *MessageFlagMutation) SetOp(op Op) { + m.op = op +} + // Type returns the node type of this mutation (MessageFlag). func (m *MessageFlagMutation) Type() string { return m.typ @@ -3515,8 +3621,6 @@ func (m *MessageFlagMutation) RemovedEdges() []string { // RemovedIDs returns all IDs (to other nodes) that were removed for the edge with // the given name in this mutation. func (m *MessageFlagMutation) RemovedIDs(name string) []ent.Value { - switch name { - } return nil } @@ -3644,7 +3748,7 @@ func (m UIDMutation) Client() *Client { // it returns an error otherwise. func (m UIDMutation) Tx() (*Tx, error) { if _, ok := m.driver.(*txDriver); !ok { - return nil, errors.New("ent: mutation is not running in a transaction") + return nil, errors.New("internal: mutation is not running in a transaction") } tx := &Tx{config: m.config} tx.init() @@ -3890,11 +3994,26 @@ func (m *UIDMutation) Where(ps ...predicate.UID) { m.predicates = append(m.predicates, ps...) } +// WhereP appends storage-level predicates to the UIDMutation builder. Using this method, +// users can use type-assertion to append predicates that do not depend on any generated package. +func (m *UIDMutation) WhereP(ps ...func(*sql.Selector)) { + p := make([]predicate.UID, len(ps)) + for i := range ps { + p[i] = ps[i] + } + m.Where(p...) +} + // Op returns the operation name. func (m *UIDMutation) Op() Op { return m.op } +// SetOp allows setting the mutation operation. +func (m *UIDMutation) SetOp(op Op) { + m.op = op +} + // Type returns the node type of this mutation (UID). func (m *UIDMutation) Type() string { return m.typ @@ -4087,8 +4206,6 @@ func (m *UIDMutation) RemovedEdges() []string { // RemovedIDs returns all IDs (to other nodes) that were removed for the edge with // the given name in this mutation. func (m *UIDMutation) RemovedIDs(name string) []ent.Value { - switch name { - } return nil } diff --git a/internal/db/ent/predicate/predicate.go b/internal/db_impl/ent_db/internal/predicate/predicate.go similarity index 100% rename from internal/db/ent/predicate/predicate.go rename to internal/db_impl/ent_db/internal/predicate/predicate.go diff --git a/internal/db/ent/runtime.go b/internal/db_impl/ent_db/internal/runtime.go similarity index 87% rename from internal/db/ent/runtime.go rename to internal/db_impl/ent_db/internal/runtime.go index a36a489b..7c8fa47b 100644 --- a/internal/db/ent/runtime.go +++ b/internal/db_impl/ent_db/internal/runtime.go @@ -1,13 +1,13 @@ // Code generated by ent, DO NOT EDIT. -package ent +package internal import ( "github.com/ProtonMail/gluon/imap" - "github.com/ProtonMail/gluon/internal/db/ent/mailbox" - "github.com/ProtonMail/gluon/internal/db/ent/message" - "github.com/ProtonMail/gluon/internal/db/ent/schema" - "github.com/ProtonMail/gluon/internal/db/ent/uid" + "github.com/ProtonMail/gluon/internal/db_impl/ent_db/internal/mailbox" + "github.com/ProtonMail/gluon/internal/db_impl/ent_db/internal/message" + "github.com/ProtonMail/gluon/internal/db_impl/ent_db/internal/schema" + "github.com/ProtonMail/gluon/internal/db_impl/ent_db/internal/uid" ) // The init function reads all schema descriptors with runtime code diff --git a/internal/db_impl/ent_db/internal/runtime/runtime.go b/internal/db_impl/ent_db/internal/runtime/runtime.go new file mode 100644 index 00000000..5bac1906 --- /dev/null +++ b/internal/db_impl/ent_db/internal/runtime/runtime.go @@ -0,0 +1,10 @@ +// Code generated by ent, DO NOT EDIT. + +package runtime + +// The schema-stitching logic is generated in github.com/ProtonMail/gluon/internal/db_impl/ent_db/internal/runtime.go + +const ( + Version = "v0.11.8" // Version of ent codegen. + Sum = "h1:M/M0QL1CYCUSdqGRXUrXhFYSDRJPsOOrr+RLEej/gyQ=" // Sum of ent codegen. +) diff --git a/internal/db/ent/schema/mailbox.go b/internal/db_impl/ent_db/internal/schema/mailbox.go similarity index 100% rename from internal/db/ent/schema/mailbox.go rename to internal/db_impl/ent_db/internal/schema/mailbox.go diff --git a/internal/db/ent/schema/mailboxattr.go b/internal/db_impl/ent_db/internal/schema/mailboxattr.go similarity index 100% rename from internal/db/ent/schema/mailboxattr.go rename to internal/db_impl/ent_db/internal/schema/mailboxattr.go diff --git a/internal/db/ent/schema/mailboxflag.go b/internal/db_impl/ent_db/internal/schema/mailboxflag.go similarity index 100% rename from internal/db/ent/schema/mailboxflag.go rename to internal/db_impl/ent_db/internal/schema/mailboxflag.go diff --git a/internal/db/ent/schema/mailboxpermflag.go b/internal/db_impl/ent_db/internal/schema/mailboxpermflag.go similarity index 100% rename from internal/db/ent/schema/mailboxpermflag.go rename to internal/db_impl/ent_db/internal/schema/mailboxpermflag.go diff --git a/internal/db/ent/schema/message.go b/internal/db_impl/ent_db/internal/schema/message.go similarity index 100% rename from internal/db/ent/schema/message.go rename to internal/db_impl/ent_db/internal/schema/message.go diff --git a/internal/db/ent/schema/messageflag.go b/internal/db_impl/ent_db/internal/schema/messageflag.go similarity index 100% rename from internal/db/ent/schema/messageflag.go rename to internal/db_impl/ent_db/internal/schema/messageflag.go diff --git a/internal/db/ent/schema/subscriptions.go b/internal/db_impl/ent_db/internal/schema/subscriptions.go similarity index 100% rename from internal/db/ent/schema/subscriptions.go rename to internal/db_impl/ent_db/internal/schema/subscriptions.go diff --git a/internal/db/ent/schema/uid.go b/internal/db_impl/ent_db/internal/schema/uid.go similarity index 100% rename from internal/db/ent/schema/uid.go rename to internal/db_impl/ent_db/internal/schema/uid.go diff --git a/internal/db/ent/tx.go b/internal/db_impl/ent_db/internal/tx.go similarity index 92% rename from internal/db/ent/tx.go rename to internal/db_impl/ent_db/internal/tx.go index d97d7848..b3350afd 100644 --- a/internal/db/ent/tx.go +++ b/internal/db_impl/ent_db/internal/tx.go @@ -1,6 +1,6 @@ // Code generated by ent, DO NOT EDIT. -package ent +package internal import ( "context" @@ -32,12 +32,6 @@ type Tx struct { // lazily loaded. client *Client clientOnce sync.Once - - // completion callbacks. - mu sync.Mutex - onCommit []CommitHook - onRollback []RollbackHook - // ctx lives for the life of the transaction. It is // the same context used by the underlying connection. ctx context.Context @@ -82,9 +76,9 @@ func (tx *Tx) Commit() error { var fn Committer = CommitFunc(func(context.Context, *Tx) error { return txDriver.tx.Commit() }) - tx.mu.Lock() - hooks := append([]CommitHook(nil), tx.onCommit...) - tx.mu.Unlock() + txDriver.mu.Lock() + hooks := append([]CommitHook(nil), txDriver.onCommit...) + txDriver.mu.Unlock() for i := len(hooks) - 1; i >= 0; i-- { fn = hooks[i](fn) } @@ -93,9 +87,10 @@ func (tx *Tx) Commit() error { // OnCommit adds a hook to call on commit. func (tx *Tx) OnCommit(f CommitHook) { - tx.mu.Lock() - defer tx.mu.Unlock() - tx.onCommit = append(tx.onCommit, f) + txDriver := tx.config.driver.(*txDriver) + txDriver.mu.Lock() + txDriver.onCommit = append(txDriver.onCommit, f) + txDriver.mu.Unlock() } type ( @@ -137,9 +132,9 @@ func (tx *Tx) Rollback() error { var fn Rollbacker = RollbackFunc(func(context.Context, *Tx) error { return txDriver.tx.Rollback() }) - tx.mu.Lock() - hooks := append([]RollbackHook(nil), tx.onRollback...) - tx.mu.Unlock() + txDriver.mu.Lock() + hooks := append([]RollbackHook(nil), txDriver.onRollback...) + txDriver.mu.Unlock() for i := len(hooks) - 1; i >= 0; i-- { fn = hooks[i](fn) } @@ -148,9 +143,10 @@ func (tx *Tx) Rollback() error { // OnRollback adds a hook to call on rollback. func (tx *Tx) OnRollback(f RollbackHook) { - tx.mu.Lock() - defer tx.mu.Unlock() - tx.onRollback = append(tx.onRollback, f) + txDriver := tx.config.driver.(*txDriver) + txDriver.mu.Lock() + txDriver.onRollback = append(txDriver.onRollback, f) + txDriver.mu.Unlock() } // Client returns a Client that binds to current transaction. @@ -189,6 +185,10 @@ type txDriver struct { drv dialect.Driver // tx is the underlying transaction. tx dialect.Tx + // completion hooks. + mu sync.Mutex + onCommit []CommitHook + onRollback []RollbackHook } // newTx creates a new transactional driver. @@ -219,12 +219,12 @@ func (*txDriver) Commit() error { return nil } func (*txDriver) Rollback() error { return nil } // Exec calls tx.Exec. -func (tx *txDriver) Exec(ctx context.Context, query string, args, v interface{}) error { +func (tx *txDriver) Exec(ctx context.Context, query string, args, v any) error { return tx.tx.Exec(ctx, query, args, v) } // Query calls tx.Query. -func (tx *txDriver) Query(ctx context.Context, query string, args, v interface{}) error { +func (tx *txDriver) Query(ctx context.Context, query string, args, v any) error { return tx.tx.Query(ctx, query, args, v) } diff --git a/internal/db/ent/uid.go b/internal/db_impl/ent_db/internal/uid.go similarity index 89% rename from internal/db/ent/uid.go rename to internal/db_impl/ent_db/internal/uid.go index 5eb4b9f9..bd1ecd40 100644 --- a/internal/db/ent/uid.go +++ b/internal/db_impl/ent_db/internal/uid.go @@ -1,6 +1,6 @@ // Code generated by ent, DO NOT EDIT. -package ent +package internal import ( "fmt" @@ -8,9 +8,9 @@ import ( "entgo.io/ent/dialect/sql" "github.com/ProtonMail/gluon/imap" - "github.com/ProtonMail/gluon/internal/db/ent/mailbox" - "github.com/ProtonMail/gluon/internal/db/ent/message" - "github.com/ProtonMail/gluon/internal/db/ent/uid" + "github.com/ProtonMail/gluon/internal/db_impl/ent_db/internal/mailbox" + "github.com/ProtonMail/gluon/internal/db_impl/ent_db/internal/message" + "github.com/ProtonMail/gluon/internal/db_impl/ent_db/internal/uid" ) // UID is the model entity for the UID schema. @@ -69,8 +69,8 @@ func (e UIDEdges) MailboxOrErr() (*Mailbox, error) { } // scanValues returns the types for scanning values from sql.Rows. -func (*UID) scanValues(columns []string) ([]interface{}, error) { - values := make([]interface{}, len(columns)) +func (*UID) scanValues(columns []string) ([]any, error) { + values := make([]any, len(columns)) for i := range columns { switch columns[i] { case uid.FieldDeleted, uid.FieldRecent: @@ -90,7 +90,7 @@ func (*UID) scanValues(columns []string) ([]interface{}, error) { // assignValues assigns the values that were returned from sql.Rows (after scanning) // to the UID fields. -func (u *UID) assignValues(columns []string, values []interface{}) error { +func (u *UID) assignValues(columns []string, values []any) error { if m, n := len(values), len(columns); m < n { return fmt.Errorf("mismatch number of scan values: %d != %d", m, n) } @@ -141,19 +141,19 @@ func (u *UID) assignValues(columns []string, values []interface{}) error { // QueryMessage queries the "message" edge of the UID entity. func (u *UID) QueryMessage() *MessageQuery { - return (&UIDClient{config: u.config}).QueryMessage(u) + return NewUIDClient(u.config).QueryMessage(u) } // QueryMailbox queries the "mailbox" edge of the UID entity. func (u *UID) QueryMailbox() *MailboxQuery { - return (&UIDClient{config: u.config}).QueryMailbox(u) + return NewUIDClient(u.config).QueryMailbox(u) } // Update returns a builder for updating this UID. // Note that you need to call UID.Unwrap() before calling this method if this UID // was returned from a transaction, and the transaction was committed or rolled back. func (u *UID) Update() *UIDUpdateOne { - return (&UIDClient{config: u.config}).UpdateOne(u) + return NewUIDClient(u.config).UpdateOne(u) } // Unwrap unwraps the UID entity that was returned from a transaction after it was closed, @@ -161,7 +161,7 @@ func (u *UID) Update() *UIDUpdateOne { func (u *UID) Unwrap() *UID { _tx, ok := u.config.driver.(*txDriver) if !ok { - panic("ent: UID is not a transactional entity") + panic("internal: UID is not a transactional entity") } u.config.driver = _tx.drv return u @@ -186,9 +186,3 @@ func (u *UID) String() string { // UIDs is a parsable slice of UID. type UIDs []*UID - -func (u UIDs) config(cfg config) { - for _i := range u { - u[_i].config = cfg - } -} diff --git a/internal/db/ent/uid/uid.go b/internal/db_impl/ent_db/internal/uid/uid.go similarity index 100% rename from internal/db/ent/uid/uid.go rename to internal/db_impl/ent_db/internal/uid/uid.go diff --git a/internal/db/ent/uid/where.go b/internal/db_impl/ent_db/internal/uid/where.go similarity index 67% rename from internal/db/ent/uid/where.go rename to internal/db_impl/ent_db/internal/uid/where.go index 2fc79f1b..2564df99 100644 --- a/internal/db/ent/uid/where.go +++ b/internal/db_impl/ent_db/internal/uid/where.go @@ -6,198 +6,142 @@ import ( "entgo.io/ent/dialect/sql" "entgo.io/ent/dialect/sql/sqlgraph" "github.com/ProtonMail/gluon/imap" - "github.com/ProtonMail/gluon/internal/db/ent/predicate" + "github.com/ProtonMail/gluon/internal/db_impl/ent_db/internal/predicate" ) // ID filters vertices based on their ID field. func ID(id int) predicate.UID { - return predicate.UID(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldID), id)) - }) + return predicate.UID(sql.FieldEQ(FieldID, id)) } // IDEQ applies the EQ predicate on the ID field. func IDEQ(id int) predicate.UID { - return predicate.UID(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldID), id)) - }) + return predicate.UID(sql.FieldEQ(FieldID, id)) } // IDNEQ applies the NEQ predicate on the ID field. func IDNEQ(id int) predicate.UID { - return predicate.UID(func(s *sql.Selector) { - s.Where(sql.NEQ(s.C(FieldID), id)) - }) + return predicate.UID(sql.FieldNEQ(FieldID, id)) } // IDIn applies the In predicate on the ID field. func IDIn(ids ...int) predicate.UID { - return predicate.UID(func(s *sql.Selector) { - v := make([]interface{}, len(ids)) - for i := range v { - v[i] = ids[i] - } - s.Where(sql.In(s.C(FieldID), v...)) - }) + return predicate.UID(sql.FieldIn(FieldID, ids...)) } // IDNotIn applies the NotIn predicate on the ID field. func IDNotIn(ids ...int) predicate.UID { - return predicate.UID(func(s *sql.Selector) { - v := make([]interface{}, len(ids)) - for i := range v { - v[i] = ids[i] - } - s.Where(sql.NotIn(s.C(FieldID), v...)) - }) + return predicate.UID(sql.FieldNotIn(FieldID, ids...)) } // IDGT applies the GT predicate on the ID field. func IDGT(id int) predicate.UID { - return predicate.UID(func(s *sql.Selector) { - s.Where(sql.GT(s.C(FieldID), id)) - }) + return predicate.UID(sql.FieldGT(FieldID, id)) } // IDGTE applies the GTE predicate on the ID field. func IDGTE(id int) predicate.UID { - return predicate.UID(func(s *sql.Selector) { - s.Where(sql.GTE(s.C(FieldID), id)) - }) + return predicate.UID(sql.FieldGTE(FieldID, id)) } // IDLT applies the LT predicate on the ID field. func IDLT(id int) predicate.UID { - return predicate.UID(func(s *sql.Selector) { - s.Where(sql.LT(s.C(FieldID), id)) - }) + return predicate.UID(sql.FieldLT(FieldID, id)) } // IDLTE applies the LTE predicate on the ID field. func IDLTE(id int) predicate.UID { - return predicate.UID(func(s *sql.Selector) { - s.Where(sql.LTE(s.C(FieldID), id)) - }) + return predicate.UID(sql.FieldLTE(FieldID, id)) } // UID applies equality check predicate on the "UID" field. It's identical to UIDEQ. func UID(v imap.UID) predicate.UID { vc := uint32(v) - return predicate.UID(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldUID), vc)) - }) + return predicate.UID(sql.FieldEQ(FieldUID, vc)) } // Deleted applies equality check predicate on the "Deleted" field. It's identical to DeletedEQ. func Deleted(v bool) predicate.UID { - return predicate.UID(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldDeleted), v)) - }) + return predicate.UID(sql.FieldEQ(FieldDeleted, v)) } // Recent applies equality check predicate on the "Recent" field. It's identical to RecentEQ. func Recent(v bool) predicate.UID { - return predicate.UID(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldRecent), v)) - }) + return predicate.UID(sql.FieldEQ(FieldRecent, v)) } // UIDEQ applies the EQ predicate on the "UID" field. func UIDEQ(v imap.UID) predicate.UID { vc := uint32(v) - return predicate.UID(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldUID), vc)) - }) + return predicate.UID(sql.FieldEQ(FieldUID, vc)) } // UIDNEQ applies the NEQ predicate on the "UID" field. func UIDNEQ(v imap.UID) predicate.UID { vc := uint32(v) - return predicate.UID(func(s *sql.Selector) { - s.Where(sql.NEQ(s.C(FieldUID), vc)) - }) + return predicate.UID(sql.FieldNEQ(FieldUID, vc)) } // UIDIn applies the In predicate on the "UID" field. func UIDIn(vs ...imap.UID) predicate.UID { - v := make([]interface{}, len(vs)) + v := make([]any, len(vs)) for i := range v { v[i] = uint32(vs[i]) } - return predicate.UID(func(s *sql.Selector) { - s.Where(sql.In(s.C(FieldUID), v...)) - }) + return predicate.UID(sql.FieldIn(FieldUID, v...)) } // UIDNotIn applies the NotIn predicate on the "UID" field. func UIDNotIn(vs ...imap.UID) predicate.UID { - v := make([]interface{}, len(vs)) + v := make([]any, len(vs)) for i := range v { v[i] = uint32(vs[i]) } - return predicate.UID(func(s *sql.Selector) { - s.Where(sql.NotIn(s.C(FieldUID), v...)) - }) + return predicate.UID(sql.FieldNotIn(FieldUID, v...)) } // UIDGT applies the GT predicate on the "UID" field. func UIDGT(v imap.UID) predicate.UID { vc := uint32(v) - return predicate.UID(func(s *sql.Selector) { - s.Where(sql.GT(s.C(FieldUID), vc)) - }) + return predicate.UID(sql.FieldGT(FieldUID, vc)) } // UIDGTE applies the GTE predicate on the "UID" field. func UIDGTE(v imap.UID) predicate.UID { vc := uint32(v) - return predicate.UID(func(s *sql.Selector) { - s.Where(sql.GTE(s.C(FieldUID), vc)) - }) + return predicate.UID(sql.FieldGTE(FieldUID, vc)) } // UIDLT applies the LT predicate on the "UID" field. func UIDLT(v imap.UID) predicate.UID { vc := uint32(v) - return predicate.UID(func(s *sql.Selector) { - s.Where(sql.LT(s.C(FieldUID), vc)) - }) + return predicate.UID(sql.FieldLT(FieldUID, vc)) } // UIDLTE applies the LTE predicate on the "UID" field. func UIDLTE(v imap.UID) predicate.UID { vc := uint32(v) - return predicate.UID(func(s *sql.Selector) { - s.Where(sql.LTE(s.C(FieldUID), vc)) - }) + return predicate.UID(sql.FieldLTE(FieldUID, vc)) } // DeletedEQ applies the EQ predicate on the "Deleted" field. func DeletedEQ(v bool) predicate.UID { - return predicate.UID(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldDeleted), v)) - }) + return predicate.UID(sql.FieldEQ(FieldDeleted, v)) } // DeletedNEQ applies the NEQ predicate on the "Deleted" field. func DeletedNEQ(v bool) predicate.UID { - return predicate.UID(func(s *sql.Selector) { - s.Where(sql.NEQ(s.C(FieldDeleted), v)) - }) + return predicate.UID(sql.FieldNEQ(FieldDeleted, v)) } // RecentEQ applies the EQ predicate on the "Recent" field. func RecentEQ(v bool) predicate.UID { - return predicate.UID(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldRecent), v)) - }) + return predicate.UID(sql.FieldEQ(FieldRecent, v)) } // RecentNEQ applies the NEQ predicate on the "Recent" field. func RecentNEQ(v bool) predicate.UID { - return predicate.UID(func(s *sql.Selector) { - s.Where(sql.NEQ(s.C(FieldRecent), v)) - }) + return predicate.UID(sql.FieldNEQ(FieldRecent, v)) } // HasMessage applies the HasEdge predicate on the "message" edge. @@ -205,7 +149,6 @@ func HasMessage() predicate.UID { return predicate.UID(func(s *sql.Selector) { step := sqlgraph.NewStep( sqlgraph.From(Table, FieldID), - sqlgraph.To(MessageTable, FieldID), sqlgraph.Edge(sqlgraph.M2O, false, MessageTable, MessageColumn), ) sqlgraph.HasNeighbors(s, step) @@ -233,7 +176,6 @@ func HasMailbox() predicate.UID { return predicate.UID(func(s *sql.Selector) { step := sqlgraph.NewStep( sqlgraph.From(Table, FieldID), - sqlgraph.To(MailboxTable, FieldID), sqlgraph.Edge(sqlgraph.M2O, true, MailboxTable, MailboxColumn), ) sqlgraph.HasNeighbors(s, step) diff --git a/internal/db/ent/uid_create.go b/internal/db_impl/ent_db/internal/uid_create.go similarity index 78% rename from internal/db/ent/uid_create.go rename to internal/db_impl/ent_db/internal/uid_create.go index 59d5565a..e813b610 100644 --- a/internal/db/ent/uid_create.go +++ b/internal/db_impl/ent_db/internal/uid_create.go @@ -1,6 +1,6 @@ // Code generated by ent, DO NOT EDIT. -package ent +package internal import ( "context" @@ -10,9 +10,9 @@ import ( "entgo.io/ent/dialect/sql/sqlgraph" "entgo.io/ent/schema/field" "github.com/ProtonMail/gluon/imap" - "github.com/ProtonMail/gluon/internal/db/ent/mailbox" - "github.com/ProtonMail/gluon/internal/db/ent/message" - "github.com/ProtonMail/gluon/internal/db/ent/uid" + "github.com/ProtonMail/gluon/internal/db_impl/ent_db/internal/mailbox" + "github.com/ProtonMail/gluon/internal/db_impl/ent_db/internal/message" + "github.com/ProtonMail/gluon/internal/db_impl/ent_db/internal/uid" ) // UIDCreate is the builder for creating a UID entity. @@ -101,50 +101,8 @@ func (uc *UIDCreate) Mutation() *UIDMutation { // Save creates the UID in the database. func (uc *UIDCreate) Save(ctx context.Context) (*UID, error) { - var ( - err error - node *UID - ) uc.defaults() - if len(uc.hooks) == 0 { - if err = uc.check(); err != nil { - return nil, err - } - node, err = uc.sqlSave(ctx) - } else { - var mut Mutator = MutateFunc(func(ctx context.Context, m Mutation) (Value, error) { - mutation, ok := m.(*UIDMutation) - if !ok { - return nil, fmt.Errorf("unexpected mutation type %T", m) - } - if err = uc.check(); err != nil { - return nil, err - } - uc.mutation = mutation - if node, err = uc.sqlSave(ctx); err != nil { - return nil, err - } - mutation.id = &node.ID - mutation.done = true - return node, err - }) - for i := len(uc.hooks) - 1; i >= 0; i-- { - if uc.hooks[i] == nil { - return nil, fmt.Errorf("ent: uninitialized hook (forgotten import ent/runtime?)") - } - mut = uc.hooks[i](mut) - } - v, err := mut.Mutate(ctx, uc.mutation) - if err != nil { - return nil, err - } - nv, ok := v.(*UID) - if !ok { - return nil, fmt.Errorf("unexpected node type %T returned from UIDMutation", v) - } - node = nv - } - return node, err + return withHooks[*UID, UIDMutation](ctx, uc.sqlSave, uc.mutation, uc.hooks) } // SaveX calls Save and panics if Save returns an error. @@ -184,18 +142,21 @@ func (uc *UIDCreate) defaults() { // check runs all checks and user-defined validators on the builder. func (uc *UIDCreate) check() error { if _, ok := uc.mutation.UID(); !ok { - return &ValidationError{Name: "UID", err: errors.New(`ent: missing required field "UID.UID"`)} + return &ValidationError{Name: "UID", err: errors.New(`internal: missing required field "UID.UID"`)} } if _, ok := uc.mutation.Deleted(); !ok { - return &ValidationError{Name: "Deleted", err: errors.New(`ent: missing required field "UID.Deleted"`)} + return &ValidationError{Name: "Deleted", err: errors.New(`internal: missing required field "UID.Deleted"`)} } if _, ok := uc.mutation.Recent(); !ok { - return &ValidationError{Name: "Recent", err: errors.New(`ent: missing required field "UID.Recent"`)} + return &ValidationError{Name: "Recent", err: errors.New(`internal: missing required field "UID.Recent"`)} } return nil } func (uc *UIDCreate) sqlSave(ctx context.Context) (*UID, error) { + if err := uc.check(); err != nil { + return nil, err + } _node, _spec := uc.createSpec() if err := sqlgraph.CreateNode(ctx, uc.driver, _spec); err != nil { if sqlgraph.IsConstraintError(err) { @@ -205,42 +166,26 @@ func (uc *UIDCreate) sqlSave(ctx context.Context) (*UID, error) { } id := _spec.ID.Value.(int64) _node.ID = int(id) + uc.mutation.id = &_node.ID + uc.mutation.done = true return _node, nil } func (uc *UIDCreate) createSpec() (*UID, *sqlgraph.CreateSpec) { var ( _node = &UID{config: uc.config} - _spec = &sqlgraph.CreateSpec{ - Table: uid.Table, - ID: &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Column: uid.FieldID, - }, - } + _spec = sqlgraph.NewCreateSpec(uid.Table, sqlgraph.NewFieldSpec(uid.FieldID, field.TypeInt)) ) if value, ok := uc.mutation.UID(); ok { - _spec.Fields = append(_spec.Fields, &sqlgraph.FieldSpec{ - Type: field.TypeUint32, - Value: value, - Column: uid.FieldUID, - }) + _spec.SetField(uid.FieldUID, field.TypeUint32, value) _node.UID = value } if value, ok := uc.mutation.Deleted(); ok { - _spec.Fields = append(_spec.Fields, &sqlgraph.FieldSpec{ - Type: field.TypeBool, - Value: value, - Column: uid.FieldDeleted, - }) + _spec.SetField(uid.FieldDeleted, field.TypeBool, value) _node.Deleted = value } if value, ok := uc.mutation.Recent(); ok { - _spec.Fields = append(_spec.Fields, &sqlgraph.FieldSpec{ - Type: field.TypeBool, - Value: value, - Column: uid.FieldRecent, - }) + _spec.SetField(uid.FieldRecent, field.TypeBool, value) _node.Recent = value } if nodes := uc.mutation.MessageIDs(); len(nodes) > 0 { diff --git a/internal/db/ent/uid_delete.go b/internal/db_impl/ent_db/internal/uid_delete.go similarity index 61% rename from internal/db/ent/uid_delete.go rename to internal/db_impl/ent_db/internal/uid_delete.go index fb7be4ac..56000ba7 100644 --- a/internal/db/ent/uid_delete.go +++ b/internal/db_impl/ent_db/internal/uid_delete.go @@ -1,16 +1,15 @@ // Code generated by ent, DO NOT EDIT. -package ent +package internal import ( "context" - "fmt" "entgo.io/ent/dialect/sql" "entgo.io/ent/dialect/sql/sqlgraph" "entgo.io/ent/schema/field" - "github.com/ProtonMail/gluon/internal/db/ent/predicate" - "github.com/ProtonMail/gluon/internal/db/ent/uid" + "github.com/ProtonMail/gluon/internal/db_impl/ent_db/internal/predicate" + "github.com/ProtonMail/gluon/internal/db_impl/ent_db/internal/uid" ) // UIDDelete is the builder for deleting a UID entity. @@ -28,34 +27,7 @@ func (ud *UIDDelete) Where(ps ...predicate.UID) *UIDDelete { // Exec executes the deletion query and returns how many vertices were deleted. func (ud *UIDDelete) Exec(ctx context.Context) (int, error) { - var ( - err error - affected int - ) - if len(ud.hooks) == 0 { - affected, err = ud.sqlExec(ctx) - } else { - var mut Mutator = MutateFunc(func(ctx context.Context, m Mutation) (Value, error) { - mutation, ok := m.(*UIDMutation) - if !ok { - return nil, fmt.Errorf("unexpected mutation type %T", m) - } - ud.mutation = mutation - affected, err = ud.sqlExec(ctx) - mutation.done = true - return affected, err - }) - for i := len(ud.hooks) - 1; i >= 0; i-- { - if ud.hooks[i] == nil { - return 0, fmt.Errorf("ent: uninitialized hook (forgotten import ent/runtime?)") - } - mut = ud.hooks[i](mut) - } - if _, err := mut.Mutate(ctx, ud.mutation); err != nil { - return 0, err - } - } - return affected, err + return withHooks[int, UIDMutation](ctx, ud.sqlExec, ud.mutation, ud.hooks) } // ExecX is like Exec, but panics if an error occurs. @@ -68,15 +40,7 @@ func (ud *UIDDelete) ExecX(ctx context.Context) int { } func (ud *UIDDelete) sqlExec(ctx context.Context) (int, error) { - _spec := &sqlgraph.DeleteSpec{ - Node: &sqlgraph.NodeSpec{ - Table: uid.Table, - ID: &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Column: uid.FieldID, - }, - }, - } + _spec := sqlgraph.NewDeleteSpec(uid.Table, sqlgraph.NewFieldSpec(uid.FieldID, field.TypeInt)) if ps := ud.mutation.predicates; len(ps) > 0 { _spec.Predicate = func(selector *sql.Selector) { for i := range ps { @@ -88,6 +52,7 @@ func (ud *UIDDelete) sqlExec(ctx context.Context) (int, error) { if err != nil && sqlgraph.IsConstraintError(err) { err = &ConstraintError{msg: err.Error(), wrap: err} } + ud.mutation.done = true return affected, err } @@ -96,6 +61,12 @@ type UIDDeleteOne struct { ud *UIDDelete } +// Where appends a list predicates to the UIDDelete builder. +func (udo *UIDDeleteOne) Where(ps ...predicate.UID) *UIDDeleteOne { + udo.ud.mutation.Where(ps...) + return udo +} + // Exec executes the deletion query. func (udo *UIDDeleteOne) Exec(ctx context.Context) error { n, err := udo.ud.Exec(ctx) @@ -111,5 +82,7 @@ func (udo *UIDDeleteOne) Exec(ctx context.Context) error { // ExecX is like Exec, but panics if an error occurs. func (udo *UIDDeleteOne) ExecX(ctx context.Context) { - udo.ud.ExecX(ctx) + if err := udo.Exec(ctx); err != nil { + panic(err) + } } diff --git a/internal/db/ent/uid_query.go b/internal/db_impl/ent_db/internal/uid_query.go similarity index 72% rename from internal/db/ent/uid_query.go rename to internal/db_impl/ent_db/internal/uid_query.go index 89e724b5..7e5bef5c 100644 --- a/internal/db/ent/uid_query.go +++ b/internal/db_impl/ent_db/internal/uid_query.go @@ -1,6 +1,6 @@ // Code generated by ent, DO NOT EDIT. -package ent +package internal import ( "context" @@ -11,20 +11,18 @@ import ( "entgo.io/ent/dialect/sql/sqlgraph" "entgo.io/ent/schema/field" "github.com/ProtonMail/gluon/imap" - "github.com/ProtonMail/gluon/internal/db/ent/mailbox" - "github.com/ProtonMail/gluon/internal/db/ent/message" - "github.com/ProtonMail/gluon/internal/db/ent/predicate" - "github.com/ProtonMail/gluon/internal/db/ent/uid" + "github.com/ProtonMail/gluon/internal/db_impl/ent_db/internal/mailbox" + "github.com/ProtonMail/gluon/internal/db_impl/ent_db/internal/message" + "github.com/ProtonMail/gluon/internal/db_impl/ent_db/internal/predicate" + "github.com/ProtonMail/gluon/internal/db_impl/ent_db/internal/uid" ) // UIDQuery is the builder for querying UID entities. type UIDQuery struct { config - limit *int - offset *int - unique *bool + ctx *QueryContext order []OrderFunc - fields []string + inters []Interceptor predicates []predicate.UID withMessage *MessageQuery withMailbox *MailboxQuery @@ -40,26 +38,26 @@ func (uq *UIDQuery) Where(ps ...predicate.UID) *UIDQuery { return uq } -// Limit adds a limit step to the query. +// Limit the number of records to be returned by this query. func (uq *UIDQuery) Limit(limit int) *UIDQuery { - uq.limit = &limit + uq.ctx.Limit = &limit return uq } -// Offset adds an offset step to the query. +// Offset to start from. func (uq *UIDQuery) Offset(offset int) *UIDQuery { - uq.offset = &offset + uq.ctx.Offset = &offset return uq } // Unique configures the query builder to filter duplicate records on query. // By default, unique is set to true, and can be disabled using this method. func (uq *UIDQuery) Unique(unique bool) *UIDQuery { - uq.unique = &unique + uq.ctx.Unique = &unique return uq } -// Order adds an order step to the query. +// Order specifies how the records should be ordered. func (uq *UIDQuery) Order(o ...OrderFunc) *UIDQuery { uq.order = append(uq.order, o...) return uq @@ -67,7 +65,7 @@ func (uq *UIDQuery) Order(o ...OrderFunc) *UIDQuery { // QueryMessage chains the current query on the "message" edge. func (uq *UIDQuery) QueryMessage() *MessageQuery { - query := &MessageQuery{config: uq.config} + query := (&MessageClient{config: uq.config}).Query() query.path = func(ctx context.Context) (fromU *sql.Selector, err error) { if err := uq.prepareQuery(ctx); err != nil { return nil, err @@ -89,7 +87,7 @@ func (uq *UIDQuery) QueryMessage() *MessageQuery { // QueryMailbox chains the current query on the "mailbox" edge. func (uq *UIDQuery) QueryMailbox() *MailboxQuery { - query := &MailboxQuery{config: uq.config} + query := (&MailboxClient{config: uq.config}).Query() query.path = func(ctx context.Context) (fromU *sql.Selector, err error) { if err := uq.prepareQuery(ctx); err != nil { return nil, err @@ -112,7 +110,7 @@ func (uq *UIDQuery) QueryMailbox() *MailboxQuery { // First returns the first UID entity from the query. // Returns a *NotFoundError when no UID was found. func (uq *UIDQuery) First(ctx context.Context) (*UID, error) { - nodes, err := uq.Limit(1).All(ctx) + nodes, err := uq.Limit(1).All(setContextOp(ctx, uq.ctx, "First")) if err != nil { return nil, err } @@ -135,7 +133,7 @@ func (uq *UIDQuery) FirstX(ctx context.Context) *UID { // Returns a *NotFoundError when no UID ID was found. func (uq *UIDQuery) FirstID(ctx context.Context) (id int, err error) { var ids []int - if ids, err = uq.Limit(1).IDs(ctx); err != nil { + if ids, err = uq.Limit(1).IDs(setContextOp(ctx, uq.ctx, "FirstID")); err != nil { return } if len(ids) == 0 { @@ -158,7 +156,7 @@ func (uq *UIDQuery) FirstIDX(ctx context.Context) int { // Returns a *NotSingularError when more than one UID entity is found. // Returns a *NotFoundError when no UID entities are found. func (uq *UIDQuery) Only(ctx context.Context) (*UID, error) { - nodes, err := uq.Limit(2).All(ctx) + nodes, err := uq.Limit(2).All(setContextOp(ctx, uq.ctx, "Only")) if err != nil { return nil, err } @@ -186,7 +184,7 @@ func (uq *UIDQuery) OnlyX(ctx context.Context) *UID { // Returns a *NotFoundError when no entities are found. func (uq *UIDQuery) OnlyID(ctx context.Context) (id int, err error) { var ids []int - if ids, err = uq.Limit(2).IDs(ctx); err != nil { + if ids, err = uq.Limit(2).IDs(setContextOp(ctx, uq.ctx, "OnlyID")); err != nil { return } switch len(ids) { @@ -211,10 +209,12 @@ func (uq *UIDQuery) OnlyIDX(ctx context.Context) int { // All executes the query and returns a list of UIDs. func (uq *UIDQuery) All(ctx context.Context) ([]*UID, error) { + ctx = setContextOp(ctx, uq.ctx, "All") if err := uq.prepareQuery(ctx); err != nil { return nil, err } - return uq.sqlAll(ctx) + qr := querierAll[[]*UID, *UIDQuery]() + return withInterceptors[[]*UID](ctx, uq, qr, uq.inters) } // AllX is like All, but panics if an error occurs. @@ -227,9 +227,12 @@ func (uq *UIDQuery) AllX(ctx context.Context) []*UID { } // IDs executes the query and returns a list of UID IDs. -func (uq *UIDQuery) IDs(ctx context.Context) ([]int, error) { - var ids []int - if err := uq.Select(uid.FieldID).Scan(ctx, &ids); err != nil { +func (uq *UIDQuery) IDs(ctx context.Context) (ids []int, err error) { + if uq.ctx.Unique == nil && uq.path != nil { + uq.Unique(true) + } + ctx = setContextOp(ctx, uq.ctx, "IDs") + if err = uq.Select(uid.FieldID).Scan(ctx, &ids); err != nil { return nil, err } return ids, nil @@ -246,10 +249,11 @@ func (uq *UIDQuery) IDsX(ctx context.Context) []int { // Count returns the count of the given query. func (uq *UIDQuery) Count(ctx context.Context) (int, error) { + ctx = setContextOp(ctx, uq.ctx, "Count") if err := uq.prepareQuery(ctx); err != nil { return 0, err } - return uq.sqlCount(ctx) + return withInterceptors[int](ctx, uq, querierCount[*UIDQuery](), uq.inters) } // CountX is like Count, but panics if an error occurs. @@ -263,10 +267,15 @@ func (uq *UIDQuery) CountX(ctx context.Context) int { // Exist returns true if the query has elements in the graph. func (uq *UIDQuery) Exist(ctx context.Context) (bool, error) { - if err := uq.prepareQuery(ctx); err != nil { - return false, err + ctx = setContextOp(ctx, uq.ctx, "Exist") + switch _, err := uq.FirstID(ctx); { + case IsNotFound(err): + return false, nil + case err != nil: + return false, fmt.Errorf("internal: check existence: %w", err) + default: + return true, nil } - return uq.sqlExist(ctx) } // ExistX is like Exist, but panics if an error occurs. @@ -286,23 +295,22 @@ func (uq *UIDQuery) Clone() *UIDQuery { } return &UIDQuery{ config: uq.config, - limit: uq.limit, - offset: uq.offset, + ctx: uq.ctx.Clone(), order: append([]OrderFunc{}, uq.order...), + inters: append([]Interceptor{}, uq.inters...), predicates: append([]predicate.UID{}, uq.predicates...), withMessage: uq.withMessage.Clone(), withMailbox: uq.withMailbox.Clone(), // clone intermediate query. - sql: uq.sql.Clone(), - path: uq.path, - unique: uq.unique, + sql: uq.sql.Clone(), + path: uq.path, } } // WithMessage tells the query-builder to eager-load the nodes that are connected to // the "message" edge. The optional arguments are used to configure the query builder of the edge. func (uq *UIDQuery) WithMessage(opts ...func(*MessageQuery)) *UIDQuery { - query := &MessageQuery{config: uq.config} + query := (&MessageClient{config: uq.config}).Query() for _, opt := range opts { opt(query) } @@ -313,7 +321,7 @@ func (uq *UIDQuery) WithMessage(opts ...func(*MessageQuery)) *UIDQuery { // WithMailbox tells the query-builder to eager-load the nodes that are connected to // the "mailbox" edge. The optional arguments are used to configure the query builder of the edge. func (uq *UIDQuery) WithMailbox(opts ...func(*MailboxQuery)) *UIDQuery { - query := &MailboxQuery{config: uq.config} + query := (&MailboxClient{config: uq.config}).Query() for _, opt := range opts { opt(query) } @@ -333,19 +341,14 @@ func (uq *UIDQuery) WithMailbox(opts ...func(*MailboxQuery)) *UIDQuery { // // client.UID.Query(). // GroupBy(uid.FieldUID). -// Aggregate(ent.Count()). +// Aggregate(internal.Count()). // Scan(ctx, &v) func (uq *UIDQuery) GroupBy(field string, fields ...string) *UIDGroupBy { - grbuild := &UIDGroupBy{config: uq.config} - grbuild.fields = append([]string{field}, fields...) - grbuild.path = func(ctx context.Context) (prev *sql.Selector, err error) { - if err := uq.prepareQuery(ctx); err != nil { - return nil, err - } - return uq.sqlQuery(ctx), nil - } + uq.ctx.Fields = append([]string{field}, fields...) + grbuild := &UIDGroupBy{build: uq} + grbuild.flds = &uq.ctx.Fields grbuild.label = uid.Label - grbuild.flds, grbuild.scan = &grbuild.fields, grbuild.Scan + grbuild.scan = grbuild.Scan return grbuild } @@ -362,17 +365,32 @@ func (uq *UIDQuery) GroupBy(field string, fields ...string) *UIDGroupBy { // Select(uid.FieldUID). // Scan(ctx, &v) func (uq *UIDQuery) Select(fields ...string) *UIDSelect { - uq.fields = append(uq.fields, fields...) - selbuild := &UIDSelect{UIDQuery: uq} - selbuild.label = uid.Label - selbuild.flds, selbuild.scan = &uq.fields, selbuild.Scan - return selbuild + uq.ctx.Fields = append(uq.ctx.Fields, fields...) + sbuild := &UIDSelect{UIDQuery: uq} + sbuild.label = uid.Label + sbuild.flds, sbuild.scan = &uq.ctx.Fields, sbuild.Scan + return sbuild +} + +// Aggregate returns a UIDSelect configured with the given aggregations. +func (uq *UIDQuery) Aggregate(fns ...AggregateFunc) *UIDSelect { + return uq.Select().Aggregate(fns...) } func (uq *UIDQuery) prepareQuery(ctx context.Context) error { - for _, f := range uq.fields { + for _, inter := range uq.inters { + if inter == nil { + return fmt.Errorf("internal: uninitialized interceptor (forgotten import internal/runtime?)") + } + if trv, ok := inter.(Traverser); ok { + if err := trv.Traverse(ctx, uq); err != nil { + return err + } + } + } + for _, f := range uq.ctx.Fields { if !uid.ValidColumn(f) { - return &ValidationError{Name: f, err: fmt.Errorf("ent: invalid field %q for query", f)} + return &ValidationError{Name: f, err: fmt.Errorf("internal: invalid field %q for query", f)} } } if uq.path != nil { @@ -401,10 +419,10 @@ func (uq *UIDQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*UID, err if withFKs { _spec.Node.Columns = append(_spec.Node.Columns, uid.ForeignKeys...) } - _spec.ScanValues = func(columns []string) ([]interface{}, error) { + _spec.ScanValues = func(columns []string) ([]any, error) { return (*UID).scanValues(nil, columns) } - _spec.Assign = func(columns []string, values []interface{}) error { + _spec.Assign = func(columns []string, values []any) error { node := &UID{config: uq.config} nodes = append(nodes, node) node.Edges.loadedTypes = loadedTypes @@ -447,6 +465,9 @@ func (uq *UIDQuery) loadMessage(ctx context.Context, query *MessageQuery, nodes } nodeids[fk] = append(nodeids[fk], nodes[i]) } + if len(ids) == 0 { + return nil + } query.Where(message.IDIn(ids...)) neighbors, err := query.All(ctx) if err != nil { @@ -476,6 +497,9 @@ func (uq *UIDQuery) loadMailbox(ctx context.Context, query *MailboxQuery, nodes } nodeids[fk] = append(nodeids[fk], nodes[i]) } + if len(ids) == 0 { + return nil + } query.Where(mailbox.IDIn(ids...)) neighbors, err := query.All(ctx) if err != nil { @@ -495,38 +519,22 @@ func (uq *UIDQuery) loadMailbox(ctx context.Context, query *MailboxQuery, nodes func (uq *UIDQuery) sqlCount(ctx context.Context) (int, error) { _spec := uq.querySpec() - _spec.Node.Columns = uq.fields - if len(uq.fields) > 0 { - _spec.Unique = uq.unique != nil && *uq.unique + _spec.Node.Columns = uq.ctx.Fields + if len(uq.ctx.Fields) > 0 { + _spec.Unique = uq.ctx.Unique != nil && *uq.ctx.Unique } return sqlgraph.CountNodes(ctx, uq.driver, _spec) } -func (uq *UIDQuery) sqlExist(ctx context.Context) (bool, error) { - n, err := uq.sqlCount(ctx) - if err != nil { - return false, fmt.Errorf("ent: check existence: %w", err) - } - return n > 0, nil -} - func (uq *UIDQuery) querySpec() *sqlgraph.QuerySpec { - _spec := &sqlgraph.QuerySpec{ - Node: &sqlgraph.NodeSpec{ - Table: uid.Table, - Columns: uid.Columns, - ID: &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Column: uid.FieldID, - }, - }, - From: uq.sql, - Unique: true, - } - if unique := uq.unique; unique != nil { + _spec := sqlgraph.NewQuerySpec(uid.Table, uid.Columns, sqlgraph.NewFieldSpec(uid.FieldID, field.TypeInt)) + _spec.From = uq.sql + if unique := uq.ctx.Unique; unique != nil { _spec.Unique = *unique + } else if uq.path != nil { + _spec.Unique = true } - if fields := uq.fields; len(fields) > 0 { + if fields := uq.ctx.Fields; len(fields) > 0 { _spec.Node.Columns = make([]string, 0, len(fields)) _spec.Node.Columns = append(_spec.Node.Columns, uid.FieldID) for i := range fields { @@ -542,10 +550,10 @@ func (uq *UIDQuery) querySpec() *sqlgraph.QuerySpec { } } } - if limit := uq.limit; limit != nil { + if limit := uq.ctx.Limit; limit != nil { _spec.Limit = *limit } - if offset := uq.offset; offset != nil { + if offset := uq.ctx.Offset; offset != nil { _spec.Offset = *offset } if ps := uq.order; len(ps) > 0 { @@ -561,7 +569,7 @@ func (uq *UIDQuery) querySpec() *sqlgraph.QuerySpec { func (uq *UIDQuery) sqlQuery(ctx context.Context) *sql.Selector { builder := sql.Dialect(uq.driver.Dialect()) t1 := builder.Table(uid.Table) - columns := uq.fields + columns := uq.ctx.Fields if len(columns) == 0 { columns = uid.Columns } @@ -570,7 +578,7 @@ func (uq *UIDQuery) sqlQuery(ctx context.Context) *sql.Selector { selector = uq.sql selector.Select(selector.Columns(columns...)...) } - if uq.unique != nil && *uq.unique { + if uq.ctx.Unique != nil && *uq.ctx.Unique { selector.Distinct() } for _, p := range uq.predicates { @@ -579,12 +587,12 @@ func (uq *UIDQuery) sqlQuery(ctx context.Context) *sql.Selector { for _, p := range uq.order { p(selector) } - if offset := uq.offset; offset != nil { + if offset := uq.ctx.Offset; offset != nil { // limit is mandatory for offset clause. We start // with default value, and override it below if needed. selector.Offset(*offset).Limit(math.MaxInt32) } - if limit := uq.limit; limit != nil { + if limit := uq.ctx.Limit; limit != nil { selector.Limit(*limit) } return selector @@ -592,13 +600,8 @@ func (uq *UIDQuery) sqlQuery(ctx context.Context) *sql.Selector { // UIDGroupBy is the group-by builder for UID entities. type UIDGroupBy struct { - config selector - fields []string - fns []AggregateFunc - // intermediate query (i.e. traversal path). - sql *sql.Selector - path func(context.Context) (*sql.Selector, error) + build *UIDQuery } // Aggregate adds the given aggregation functions to the group-by query. @@ -607,74 +610,77 @@ func (ugb *UIDGroupBy) Aggregate(fns ...AggregateFunc) *UIDGroupBy { return ugb } -// Scan applies the group-by query and scans the result into the given value. -func (ugb *UIDGroupBy) Scan(ctx context.Context, v interface{}) error { - query, err := ugb.path(ctx) - if err != nil { +// Scan applies the selector query and scans the result into the given value. +func (ugb *UIDGroupBy) Scan(ctx context.Context, v any) error { + ctx = setContextOp(ctx, ugb.build.ctx, "GroupBy") + if err := ugb.build.prepareQuery(ctx); err != nil { return err } - ugb.sql = query - return ugb.sqlScan(ctx, v) + return scanWithInterceptors[*UIDQuery, *UIDGroupBy](ctx, ugb.build, ugb, ugb.build.inters, v) } -func (ugb *UIDGroupBy) sqlScan(ctx context.Context, v interface{}) error { - for _, f := range ugb.fields { - if !uid.ValidColumn(f) { - return &ValidationError{Name: f, err: fmt.Errorf("invalid field %q for group-by", f)} +func (ugb *UIDGroupBy) sqlScan(ctx context.Context, root *UIDQuery, v any) error { + selector := root.sqlQuery(ctx).Select() + aggregation := make([]string, 0, len(ugb.fns)) + for _, fn := range ugb.fns { + aggregation = append(aggregation, fn(selector)) + } + if len(selector.SelectedColumns()) == 0 { + columns := make([]string, 0, len(*ugb.flds)+len(ugb.fns)) + for _, f := range *ugb.flds { + columns = append(columns, selector.C(f)) } + columns = append(columns, aggregation...) + selector.Select(columns...) } - selector := ugb.sqlQuery() + selector.GroupBy(selector.Columns(*ugb.flds...)...) if err := selector.Err(); err != nil { return err } rows := &sql.Rows{} query, args := selector.Query() - if err := ugb.driver.Query(ctx, query, args, rows); err != nil { + if err := ugb.build.driver.Query(ctx, query, args, rows); err != nil { return err } defer rows.Close() return sql.ScanSlice(rows, v) } -func (ugb *UIDGroupBy) sqlQuery() *sql.Selector { - selector := ugb.sql.Select() - aggregation := make([]string, 0, len(ugb.fns)) - for _, fn := range ugb.fns { - aggregation = append(aggregation, fn(selector)) - } - // If no columns were selected in a custom aggregation function, the default - // selection is the fields used for "group-by", and the aggregation functions. - if len(selector.SelectedColumns()) == 0 { - columns := make([]string, 0, len(ugb.fields)+len(ugb.fns)) - for _, f := range ugb.fields { - columns = append(columns, selector.C(f)) - } - columns = append(columns, aggregation...) - selector.Select(columns...) - } - return selector.GroupBy(selector.Columns(ugb.fields...)...) -} - // UIDSelect is the builder for selecting fields of UID entities. type UIDSelect struct { *UIDQuery selector - // intermediate query (i.e. traversal path). - sql *sql.Selector +} + +// Aggregate adds the given aggregation functions to the selector query. +func (us *UIDSelect) Aggregate(fns ...AggregateFunc) *UIDSelect { + us.fns = append(us.fns, fns...) + return us } // Scan applies the selector query and scans the result into the given value. -func (us *UIDSelect) Scan(ctx context.Context, v interface{}) error { +func (us *UIDSelect) Scan(ctx context.Context, v any) error { + ctx = setContextOp(ctx, us.ctx, "Select") if err := us.prepareQuery(ctx); err != nil { return err } - us.sql = us.UIDQuery.sqlQuery(ctx) - return us.sqlScan(ctx, v) + return scanWithInterceptors[*UIDQuery, *UIDSelect](ctx, us.UIDQuery, us, us.inters, v) } -func (us *UIDSelect) sqlScan(ctx context.Context, v interface{}) error { +func (us *UIDSelect) sqlScan(ctx context.Context, root *UIDQuery, v any) error { + selector := root.sqlQuery(ctx) + aggregation := make([]string, 0, len(us.fns)) + for _, fn := range us.fns { + aggregation = append(aggregation, fn(selector)) + } + switch n := len(*us.selector.flds); { + case n == 0 && len(aggregation) > 0: + selector.Select(aggregation...) + case n != 0 && len(aggregation) > 0: + selector.AppendSelect(aggregation...) + } rows := &sql.Rows{} - query, args := us.sql.Query() + query, args := selector.Query() if err := us.driver.Query(ctx, query, args, rows); err != nil { return err } diff --git a/internal/db/ent/uid_update.go b/internal/db_impl/ent_db/internal/uid_update.go similarity index 78% rename from internal/db/ent/uid_update.go rename to internal/db_impl/ent_db/internal/uid_update.go index 8eb1a062..2d28c5fd 100644 --- a/internal/db/ent/uid_update.go +++ b/internal/db_impl/ent_db/internal/uid_update.go @@ -1,6 +1,6 @@ // Code generated by ent, DO NOT EDIT. -package ent +package internal import ( "context" @@ -11,10 +11,10 @@ import ( "entgo.io/ent/dialect/sql/sqlgraph" "entgo.io/ent/schema/field" "github.com/ProtonMail/gluon/imap" - "github.com/ProtonMail/gluon/internal/db/ent/mailbox" - "github.com/ProtonMail/gluon/internal/db/ent/message" - "github.com/ProtonMail/gluon/internal/db/ent/predicate" - "github.com/ProtonMail/gluon/internal/db/ent/uid" + "github.com/ProtonMail/gluon/internal/db_impl/ent_db/internal/mailbox" + "github.com/ProtonMail/gluon/internal/db_impl/ent_db/internal/message" + "github.com/ProtonMail/gluon/internal/db_impl/ent_db/internal/predicate" + "github.com/ProtonMail/gluon/internal/db_impl/ent_db/internal/uid" ) // UIDUpdate is the builder for updating UID entities. @@ -128,34 +128,7 @@ func (uu *UIDUpdate) ClearMailbox() *UIDUpdate { // Save executes the query and returns the number of nodes affected by the update operation. func (uu *UIDUpdate) Save(ctx context.Context) (int, error) { - var ( - err error - affected int - ) - if len(uu.hooks) == 0 { - affected, err = uu.sqlSave(ctx) - } else { - var mut Mutator = MutateFunc(func(ctx context.Context, m Mutation) (Value, error) { - mutation, ok := m.(*UIDMutation) - if !ok { - return nil, fmt.Errorf("unexpected mutation type %T", m) - } - uu.mutation = mutation - affected, err = uu.sqlSave(ctx) - mutation.done = true - return affected, err - }) - for i := len(uu.hooks) - 1; i >= 0; i-- { - if uu.hooks[i] == nil { - return 0, fmt.Errorf("ent: uninitialized hook (forgotten import ent/runtime?)") - } - mut = uu.hooks[i](mut) - } - if _, err := mut.Mutate(ctx, uu.mutation); err != nil { - return 0, err - } - } - return affected, err + return withHooks[int, UIDMutation](ctx, uu.sqlSave, uu.mutation, uu.hooks) } // SaveX is like Save, but panics if an error occurs. @@ -181,16 +154,7 @@ func (uu *UIDUpdate) ExecX(ctx context.Context) { } func (uu *UIDUpdate) sqlSave(ctx context.Context) (n int, err error) { - _spec := &sqlgraph.UpdateSpec{ - Node: &sqlgraph.NodeSpec{ - Table: uid.Table, - Columns: uid.Columns, - ID: &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Column: uid.FieldID, - }, - }, - } + _spec := sqlgraph.NewUpdateSpec(uid.Table, uid.Columns, sqlgraph.NewFieldSpec(uid.FieldID, field.TypeInt)) if ps := uu.mutation.predicates; len(ps) > 0 { _spec.Predicate = func(selector *sql.Selector) { for i := range ps { @@ -199,32 +163,16 @@ func (uu *UIDUpdate) sqlSave(ctx context.Context) (n int, err error) { } } if value, ok := uu.mutation.UID(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeUint32, - Value: value, - Column: uid.FieldUID, - }) + _spec.SetField(uid.FieldUID, field.TypeUint32, value) } if value, ok := uu.mutation.AddedUID(); ok { - _spec.Fields.Add = append(_spec.Fields.Add, &sqlgraph.FieldSpec{ - Type: field.TypeUint32, - Value: value, - Column: uid.FieldUID, - }) + _spec.AddField(uid.FieldUID, field.TypeUint32, value) } if value, ok := uu.mutation.Deleted(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeBool, - Value: value, - Column: uid.FieldDeleted, - }) + _spec.SetField(uid.FieldDeleted, field.TypeBool, value) } if value, ok := uu.mutation.Recent(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeBool, - Value: value, - Column: uid.FieldRecent, - }) + _spec.SetField(uid.FieldRecent, field.TypeBool, value) } if uu.mutation.MessageCleared() { edge := &sqlgraph.EdgeSpec{ @@ -304,6 +252,7 @@ func (uu *UIDUpdate) sqlSave(ctx context.Context) (n int, err error) { } return 0, err } + uu.mutation.done = true return n, nil } @@ -411,6 +360,12 @@ func (uuo *UIDUpdateOne) ClearMailbox() *UIDUpdateOne { return uuo } +// Where appends a list predicates to the UIDUpdate builder. +func (uuo *UIDUpdateOne) Where(ps ...predicate.UID) *UIDUpdateOne { + uuo.mutation.Where(ps...) + return uuo +} + // Select allows selecting one or more fields (columns) of the returned entity. // The default is selecting all fields defined in the entity schema. func (uuo *UIDUpdateOne) Select(field string, fields ...string) *UIDUpdateOne { @@ -420,40 +375,7 @@ func (uuo *UIDUpdateOne) Select(field string, fields ...string) *UIDUpdateOne { // Save executes the query and returns the updated UID entity. func (uuo *UIDUpdateOne) Save(ctx context.Context) (*UID, error) { - var ( - err error - node *UID - ) - if len(uuo.hooks) == 0 { - node, err = uuo.sqlSave(ctx) - } else { - var mut Mutator = MutateFunc(func(ctx context.Context, m Mutation) (Value, error) { - mutation, ok := m.(*UIDMutation) - if !ok { - return nil, fmt.Errorf("unexpected mutation type %T", m) - } - uuo.mutation = mutation - node, err = uuo.sqlSave(ctx) - mutation.done = true - return node, err - }) - for i := len(uuo.hooks) - 1; i >= 0; i-- { - if uuo.hooks[i] == nil { - return nil, fmt.Errorf("ent: uninitialized hook (forgotten import ent/runtime?)") - } - mut = uuo.hooks[i](mut) - } - v, err := mut.Mutate(ctx, uuo.mutation) - if err != nil { - return nil, err - } - nv, ok := v.(*UID) - if !ok { - return nil, fmt.Errorf("unexpected node type %T returned from UIDMutation", v) - } - node = nv - } - return node, err + return withHooks[*UID, UIDMutation](ctx, uuo.sqlSave, uuo.mutation, uuo.hooks) } // SaveX is like Save, but panics if an error occurs. @@ -479,19 +401,10 @@ func (uuo *UIDUpdateOne) ExecX(ctx context.Context) { } func (uuo *UIDUpdateOne) sqlSave(ctx context.Context) (_node *UID, err error) { - _spec := &sqlgraph.UpdateSpec{ - Node: &sqlgraph.NodeSpec{ - Table: uid.Table, - Columns: uid.Columns, - ID: &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Column: uid.FieldID, - }, - }, - } + _spec := sqlgraph.NewUpdateSpec(uid.Table, uid.Columns, sqlgraph.NewFieldSpec(uid.FieldID, field.TypeInt)) id, ok := uuo.mutation.ID() if !ok { - return nil, &ValidationError{Name: "id", err: errors.New(`ent: missing "UID.id" for update`)} + return nil, &ValidationError{Name: "id", err: errors.New(`internal: missing "UID.id" for update`)} } _spec.Node.ID.Value = id if fields := uuo.fields; len(fields) > 0 { @@ -499,7 +412,7 @@ func (uuo *UIDUpdateOne) sqlSave(ctx context.Context) (_node *UID, err error) { _spec.Node.Columns = append(_spec.Node.Columns, uid.FieldID) for _, f := range fields { if !uid.ValidColumn(f) { - return nil, &ValidationError{Name: f, err: fmt.Errorf("ent: invalid field %q for query", f)} + return nil, &ValidationError{Name: f, err: fmt.Errorf("internal: invalid field %q for query", f)} } if f != uid.FieldID { _spec.Node.Columns = append(_spec.Node.Columns, f) @@ -514,32 +427,16 @@ func (uuo *UIDUpdateOne) sqlSave(ctx context.Context) (_node *UID, err error) { } } if value, ok := uuo.mutation.UID(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeUint32, - Value: value, - Column: uid.FieldUID, - }) + _spec.SetField(uid.FieldUID, field.TypeUint32, value) } if value, ok := uuo.mutation.AddedUID(); ok { - _spec.Fields.Add = append(_spec.Fields.Add, &sqlgraph.FieldSpec{ - Type: field.TypeUint32, - Value: value, - Column: uid.FieldUID, - }) + _spec.AddField(uid.FieldUID, field.TypeUint32, value) } if value, ok := uuo.mutation.Deleted(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeBool, - Value: value, - Column: uid.FieldDeleted, - }) + _spec.SetField(uid.FieldDeleted, field.TypeBool, value) } if value, ok := uuo.mutation.Recent(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeBool, - Value: value, - Column: uid.FieldRecent, - }) + _spec.SetField(uid.FieldRecent, field.TypeBool, value) } if uuo.mutation.MessageCleared() { edge := &sqlgraph.EdgeSpec{ @@ -622,5 +519,6 @@ func (uuo *UIDUpdateOne) sqlSave(ctx context.Context) (_node *UID, err error) { } return nil, err } + uuo.mutation.done = true return _node, nil } diff --git a/internal/db/mailbox.go b/internal/db_impl/ent_db/mailbox.go similarity index 62% rename from internal/db/mailbox.go rename to internal/db_impl/ent_db/mailbox.go index ddc1a880..3900df73 100644 --- a/internal/db/mailbox.go +++ b/internal/db_impl/ent_db/mailbox.go @@ -1,29 +1,30 @@ -package db +package ent_db import ( "context" "fmt" + "github.com/ProtonMail/gluon/db" "strings" "entgo.io/ent/dialect/sql" "github.com/ProtonMail/gluon/imap" - "github.com/ProtonMail/gluon/internal/db/ent" - "github.com/ProtonMail/gluon/internal/db/ent/mailbox" - "github.com/ProtonMail/gluon/internal/db/ent/message" - "github.com/ProtonMail/gluon/internal/db/ent/messageflag" - "github.com/ProtonMail/gluon/internal/db/ent/uid" + "github.com/ProtonMail/gluon/internal/db_impl/ent_db/internal" + "github.com/ProtonMail/gluon/internal/db_impl/ent_db/internal/mailbox" + "github.com/ProtonMail/gluon/internal/db_impl/ent_db/internal/message" + "github.com/ProtonMail/gluon/internal/db_impl/ent_db/internal/messageflag" + "github.com/ProtonMail/gluon/internal/db_impl/ent_db/internal/uid" "github.com/ProtonMail/gluon/internal/ids" "github.com/bradenaw/juniper/xslices" ) func CreateMailbox( ctx context.Context, - tx *ent.Tx, + tx *internal.Tx, mboxID imap.MailboxID, name string, flags, permFlags, attrs imap.FlagSet, uidValidity imap.UID, -) (*ent.Mailbox, error) { +) (*internal.Mailbox, error) { create := tx.Mailbox.Create(). SetName(name). SetUIDValidity(uidValidity) @@ -73,15 +74,15 @@ func CreateMailbox( return mbox, nil } -func MailboxExistsWithID(ctx context.Context, client *ent.Client, mboxID imap.InternalMailboxID) (bool, error) { +func MailboxExistsWithID(ctx context.Context, client *internal.Client, mboxID imap.InternalMailboxID) (bool, error) { return client.Mailbox.Query().Where(mailbox.ID(mboxID)).Exist(ctx) } -func MailboxExistsWithRemoteID(ctx context.Context, client *ent.Client, mboxID imap.MailboxID) (bool, error) { +func MailboxExistsWithRemoteID(ctx context.Context, client *internal.Client, mboxID imap.MailboxID) (bool, error) { return client.Mailbox.Query().Where(mailbox.RemoteID(mboxID)).Exist(ctx) } -func GetMailboxIDFromRemoteID(ctx context.Context, client *ent.Client, mboxID imap.MailboxID) (imap.InternalMailboxID, error) { +func GetMailboxIDFromRemoteID(ctx context.Context, client *internal.Client, mboxID imap.MailboxID) (imap.InternalMailboxID, error) { mbox, err := client.Mailbox.Query().Where(mailbox.RemoteID(mboxID)).Select(mailbox.FieldID).Only(ctx) if err != nil { return 0, err @@ -90,11 +91,11 @@ func GetMailboxIDFromRemoteID(ctx context.Context, client *ent.Client, mboxID im return mbox.ID, nil } -func MailboxExistsWithName(ctx context.Context, client *ent.Client, name string) (bool, error) { +func MailboxExistsWithName(ctx context.Context, client *internal.Client, name string) (bool, error) { return client.Mailbox.Query().Where(mailbox.Name(name)).Exist(ctx) } -func RenameMailboxWithRemoteID(ctx context.Context, tx *ent.Tx, mboxID imap.MailboxID, name string) error { +func RenameMailboxWithRemoteID(ctx context.Context, tx *internal.Tx, mboxID imap.MailboxID, name string) error { if _, err := tx.Mailbox.Update(). Where(mailbox.RemoteID(mboxID)). SetName(name). @@ -109,7 +110,7 @@ func RenameMailboxWithRemoteID(ctx context.Context, tx *ent.Tx, mboxID imap.Mail // It returns the (potentially new) global UID validity. func DeleteMailboxWithRemoteID( ctx context.Context, - tx *ent.Tx, + tx *internal.Tx, mboxID imap.MailboxID, ) error { mbox, err := tx.Mailbox.Query().Where(mailbox.RemoteID(mboxID)).Select(mailbox.FieldSubscribed, mailbox.FieldName).Only(ctx) @@ -130,7 +131,7 @@ func DeleteMailboxWithRemoteID( return nil } -func UpdateRemoteMailboxID(ctx context.Context, tx *ent.Tx, internalID imap.InternalMailboxID, remoteID imap.MailboxID) error { +func UpdateRemoteMailboxID(ctx context.Context, tx *internal.Tx, internalID imap.InternalMailboxID, remoteID imap.MailboxID) error { if _, err := tx.Mailbox.Update(). Where(mailbox.ID(internalID)). SetRemoteID(remoteID). @@ -141,7 +142,7 @@ func UpdateRemoteMailboxID(ctx context.Context, tx *ent.Tx, internalID imap.Inte return nil } -func BumpMailboxUIDNext(ctx context.Context, tx *ent.Tx, mbox *ent.Mailbox, withCount ...int) error { +func BumpMailboxUIDNext(ctx context.Context, tx *internal.Tx, mbox *internal.Mailbox, withCount ...int) error { var n int if len(withCount) > 0 { @@ -159,7 +160,7 @@ func BumpMailboxUIDNext(ctx context.Context, tx *ent.Tx, mbox *ent.Mailbox, with return nil } -func GetMailboxName(ctx context.Context, client *ent.Client, mboxID imap.InternalMailboxID) (string, error) { +func GetMailboxName(ctx context.Context, client *internal.Client, mboxID imap.InternalMailboxID) (string, error) { mailbox, err := client.Mailbox.Query().Where(mailbox.ID(mboxID)).Select(mailbox.FieldName).Only(ctx) if err != nil { return "", err @@ -168,7 +169,7 @@ func GetMailboxName(ctx context.Context, client *ent.Client, mboxID imap.Interna return mailbox.Name, nil } -func GetMailboxNameWithRemoteID(ctx context.Context, client *ent.Client, mboxID imap.MailboxID) (string, error) { +func GetMailboxNameWithRemoteID(ctx context.Context, client *internal.Client, mboxID imap.MailboxID) (string, error) { mailbox, err := client.Mailbox.Query().Where(mailbox.RemoteID(mboxID)).Select(mailbox.FieldName).Only(ctx) if err != nil { return "", err @@ -177,7 +178,7 @@ func GetMailboxNameWithRemoteID(ctx context.Context, client *ent.Client, mboxID return mailbox.Name, nil } -func GetMailboxMessageIDPairs(ctx context.Context, client *ent.Client, mailboxID imap.InternalMailboxID) ([]ids.MessageIDPair, error) { +func GetMailboxMessageIDPairs(ctx context.Context, client *internal.Client, mailboxID imap.InternalMailboxID) ([]db.MessageIDPair, error) { messages, err := client.Message.Query(). Where(message.HasUIDsWith(uid.HasMailboxWith(mailbox.ID(mailboxID)))). Select(message.FieldID, message.FieldRemoteID). @@ -186,15 +187,18 @@ func GetMailboxMessageIDPairs(ctx context.Context, client *ent.Client, mailboxID return nil, err } - return xslices.Map(messages, func(message *ent.Message) ids.MessageIDPair { - return ids.NewMessageIDPair(message) + return xslices.Map(messages, func(message *internal.Message) db.MessageIDPair { + return db.MessageIDPair{ + InternalID: message.ID, + RemoteID: message.RemoteID, + } }), nil } -func GetAllMailboxes(ctx context.Context, client *ent.Client) ([]*ent.Mailbox, error) { +func GetAllMailboxes(ctx context.Context, client *internal.Client) ([]*internal.Mailbox, error) { const QueryLimit = 16000 - var mailboxes []*ent.Mailbox + var mailboxes []*internal.Mailbox for i := 0; ; i += QueryLimit { result, err := client.Mailbox.Query(). @@ -217,60 +221,30 @@ func GetAllMailboxes(ctx context.Context, client *ent.Client) ([]*ent.Mailbox, e return mailboxes, nil } -func GetMailboxByName(ctx context.Context, client *ent.Client, name string) (*ent.Mailbox, error) { +func GetMailboxByName(ctx context.Context, client *internal.Client, name string) (*internal.Mailbox, error) { return client.Mailbox.Query().Where(mailbox.Name(name)).Only(ctx) } -func GetMailboxByID(ctx context.Context, client *ent.Client, id imap.InternalMailboxID) (*ent.Mailbox, error) { +func GetMailboxByID(ctx context.Context, client *internal.Client, id imap.InternalMailboxID) (*internal.Mailbox, error) { return client.Mailbox.Query().Where(mailbox.ID(id)).Only(ctx) } -func GetMailboxByRemoteID(ctx context.Context, client *ent.Client, id imap.MailboxID) (*ent.Mailbox, error) { +func GetMailboxByRemoteID(ctx context.Context, client *internal.Client, id imap.MailboxID) (*internal.Mailbox, error) { return client.Mailbox.Query().Where(mailbox.RemoteID(id)).Only(ctx) } -func GetMailboxRecentCount(ctx context.Context, client *ent.Client, mbox *ent.Mailbox) (int, error) { +func GetMailboxRecentCount(ctx context.Context, client *internal.Client, mbox *internal.Mailbox) (int, error) { return mbox.QueryUIDs().Where(uid.Recent(true)).Count(ctx) } -func GetMailboxMessageCount(ctx context.Context, client *ent.Client, mboxID imap.InternalMailboxID) (int, error) { +func GetMailboxMessageCount(ctx context.Context, client *internal.Client, mboxID imap.InternalMailboxID) (int, error) { return client.UID.Query().Where(func(s *sql.Selector) { s.Where(sql.EQ(uid.MailboxColumn, mboxID)) }).Count(ctx) } -type SnapshotMessageResult struct { - InternalID imap.InternalMessageID `json:"uid_message"` - RemoteID imap.MessageID `json:"remote_id"` - UID imap.UID `json:"uid"` - Recent bool `json:"recent"` - Deleted bool `json:"deleted"` - Flags string `json:"flags"` -} - -func (msg *SnapshotMessageResult) GetFlagSet() imap.FlagSet { - var flagSet imap.FlagSet - - if len(msg.Flags) > 0 { - flags := strings.Split(msg.Flags, ",") - flagSet = imap.NewFlagSetFromSlice(flags) - } else { - flagSet = imap.NewFlagSet() - } - - if msg.Deleted { - flagSet.AddToSelf(imap.FlagDeleted) - } - - if msg.Recent { - flagSet.AddToSelf(imap.FlagRecent) - } - - return flagSet -} - -func GetMailboxMessagesForNewSnapshot(ctx context.Context, client *ent.Client, mboxID imap.InternalMailboxID) ([]SnapshotMessageResult, error) { - messages := make([]SnapshotMessageResult, 0, 32) +func GetMailboxMessagesForNewSnapshot(ctx context.Context, client *internal.Client, mboxID imap.InternalMailboxID) ([]db.SnapshotMessageResult, error) { + messages := make([]db.SnapshotMessageResult, 0, 32) if err := client.UID.Query().Where(func(s *sql.Selector) { msgTable := sql.Table(message.Table) @@ -288,7 +262,7 @@ func GetMailboxMessagesForNewSnapshot(ctx context.Context, client *ent.Client, m return messages, nil } -func GetMailboxIDWithRemoteID(ctx context.Context, client *ent.Client, mboxID imap.MailboxID) (imap.InternalMailboxID, error) { +func GetMailboxIDWithRemoteID(ctx context.Context, client *internal.Client, mboxID imap.MailboxID) (imap.InternalMailboxID, error) { mbox, err := client.Mailbox.Query().Where(mailbox.RemoteID(mboxID)).Select(mailbox.FieldID).Only(ctx) if err != nil { return 0, err @@ -297,18 +271,18 @@ func GetMailboxIDWithRemoteID(ctx context.Context, client *ent.Client, mboxID im return mbox.ID, nil } -func TranslateRemoteMailboxIDs(ctx context.Context, client *ent.Client, mboxIDs []imap.MailboxID) ([]imap.InternalMailboxID, error) { +func TranslateRemoteMailboxIDs(ctx context.Context, client *internal.Client, mboxIDs []imap.MailboxID) ([]imap.InternalMailboxID, error) { mboxes, err := client.Mailbox.Query().Where(mailbox.RemoteIDIn(mboxIDs...)).Select(mailbox.FieldID).All(ctx) if err != nil { return nil, err } - return xslices.Map(mboxes, func(m *ent.Mailbox) imap.InternalMailboxID { + return xslices.Map(mboxes, func(m *internal.Mailbox) imap.InternalMailboxID { return m.ID }), nil } -func CreateMailboxIfNotExists(ctx context.Context, tx *ent.Tx, mbox imap.Mailbox, delimiter string, uidValidity imap.UID) error { +func CreateMailboxIfNotExists(ctx context.Context, tx *internal.Tx, mbox imap.Mailbox, delimiter string, uidValidity imap.UID) error { exists, err := MailboxExistsWithRemoteID(ctx, tx.Client(), mbox.ID) if err != nil { return err @@ -332,10 +306,10 @@ func CreateMailboxIfNotExists(ctx context.Context, tx *ent.Tx, mbox imap.Mailbox return nil } -func GetOrCreateMailbox(ctx context.Context, tx *ent.Tx, mbox imap.Mailbox, delimiter string, uidValidity imap.UID) (*ent.Mailbox, error) { +func GetOrCreateMailbox(ctx context.Context, tx *internal.Tx, mbox imap.Mailbox, delimiter string, uidValidity imap.UID) (*internal.Mailbox, error) { mailbox, err := tx.Mailbox.Query().Where(mailbox.RemoteID(mbox.ID)).Only(ctx) if err != nil { - if !ent.IsNotFound(err) { + if !internal.IsNotFound(err) { return nil, err } } else { @@ -354,7 +328,7 @@ func GetOrCreateMailbox(ctx context.Context, tx *ent.Tx, mbox imap.Mailbox, deli ) } -func FilterMailboxContains(ctx context.Context, client *ent.Client, mboxID imap.InternalMailboxID, messageIDs []ids.MessageIDPair) ([]imap.InternalMessageID, error) { +func FilterMailboxContains(ctx context.Context, client *internal.Client, mboxID imap.InternalMailboxID, messageIDs []db.MessageIDPair) ([]imap.InternalMessageID, error) { type result struct { InternalID imap.InternalMessageID `json:"uid_message"` } @@ -362,7 +336,7 @@ func FilterMailboxContains(ctx context.Context, client *ent.Client, mboxID imap. var r []result if err := client.UID.Query().Where(func(s *sql.Selector) { - s.Where(sql.And(sql.EQ(uid.MailboxColumn, mboxID), sql.In(uid.MessageColumn, xslices.Map(messageIDs, func(id ids.MessageIDPair) interface{} { + s.Where(sql.And(sql.EQ(uid.MailboxColumn, mboxID), sql.In(uid.MessageColumn, xslices.Map(messageIDs, func(id db.MessageIDPair) interface{} { return id.InternalID })...))) s.Select(uid.MessageColumn) @@ -375,7 +349,7 @@ func FilterMailboxContains(ctx context.Context, client *ent.Client, mboxID imap. }), nil } -func FilterMailboxContainsInternalID(ctx context.Context, client *ent.Client, mboxID imap.InternalMailboxID, messageIDs []imap.InternalMessageID) ([]imap.InternalMessageID, error) { +func FilterMailboxContainsInternalID(ctx context.Context, client *internal.Client, mboxID imap.InternalMailboxID, messageIDs []imap.InternalMessageID) ([]imap.InternalMessageID, error) { type result struct { InternalID imap.InternalMessageID `json:"uid_message"` } @@ -396,51 +370,51 @@ func FilterMailboxContainsInternalID(ctx context.Context, client *ent.Client, mb }), nil } -func GetMailboxFlags(ctx context.Context, client *ent.Client, mboxID imap.InternalMailboxID) (imap.FlagSet, error) { +func GetMailboxFlags(ctx context.Context, client *internal.Client, mboxID imap.InternalMailboxID) (imap.FlagSet, error) { mbox, err := client.Mailbox.Query().Where(mailbox.ID(mboxID)).WithFlags().Only(ctx) if err != nil { return imap.FlagSet{}, err } - return imap.NewFlagSetFromSlice(xslices.Map(mbox.Edges.Flags, func(flag *ent.MailboxFlag) string { + return imap.NewFlagSetFromSlice(xslices.Map(mbox.Edges.Flags, func(flag *internal.MailboxFlag) string { return flag.Value })), nil } -func GetMailboxPermanentFlags(ctx context.Context, client *ent.Client, mboxID imap.InternalMailboxID) (imap.FlagSet, error) { +func GetMailboxPermanentFlags(ctx context.Context, client *internal.Client, mboxID imap.InternalMailboxID) (imap.FlagSet, error) { mbox, err := client.Mailbox.Query().Where(mailbox.ID(mboxID)).WithPermanentFlags().Only(ctx) if err != nil { return imap.FlagSet{}, err } - return imap.NewFlagSetFromSlice(xslices.Map(mbox.Edges.PermanentFlags, func(flag *ent.MailboxPermFlag) string { + return imap.NewFlagSetFromSlice(xslices.Map(mbox.Edges.PermanentFlags, func(flag *internal.MailboxPermFlag) string { return flag.Value })), nil } -func GetMailboxAttributes(ctx context.Context, client *ent.Client, mboxID imap.InternalMailboxID) (imap.FlagSet, error) { +func GetMailboxAttributes(ctx context.Context, client *internal.Client, mboxID imap.InternalMailboxID) (imap.FlagSet, error) { mbox, err := client.Mailbox.Query().Where(mailbox.ID(mboxID)).WithAttributes().Only(ctx) if err != nil { return imap.FlagSet{}, err } - return imap.NewFlagSetFromSlice(xslices.Map(mbox.Edges.Attributes, func(flag *ent.MailboxAttr) string { + return imap.NewFlagSetFromSlice(xslices.Map(mbox.Edges.Attributes, func(flag *internal.MailboxAttr) string { return flag.Value })), nil } -func IsMessageInMailbox(ctx context.Context, client *ent.Client, mboxID imap.InternalMailboxID, messageID imap.InternalMailboxID) (bool, error) { +func IsMessageInMailbox(ctx context.Context, client *internal.Client, mboxID imap.InternalMailboxID, messageID imap.InternalMailboxID) (bool, error) { return client.UID.Query().Where(func(s *sql.Selector) { s.Where(sql.And(sql.EQ(uid.MailboxColumn, mboxID), sql.EQ(uid.MessageColumn, messageID))) s.Select(uid.MessageColumn) }).Exist(ctx) } -func GetMailboxCount(ctx context.Context, client *ent.Client) (int, error) { +func GetMailboxCount(ctx context.Context, client *internal.Client) (int, error) { return client.Mailbox.Query().Count(ctx) } -func GetMailboxUID(ctx context.Context, client *ent.Client, mboxID imap.InternalMailboxID) (imap.UID, error) { +func GetMailboxUID(ctx context.Context, client *internal.Client, mboxID imap.InternalMailboxID) (imap.UID, error) { mbox, err := client.Mailbox.Query().Where(mailbox.ID(mboxID)).Select(mailbox.FieldUIDNext).Only(ctx) if err != nil { return 0, err @@ -449,7 +423,7 @@ func GetMailboxUID(ctx context.Context, client *ent.Client, mboxID imap.Internal return mbox.UIDNext, err } -func GetMailboxMessageCountAndUID(ctx context.Context, client *ent.Client, mboxID imap.InternalMailboxID) (int, imap.UID, error) { +func GetMailboxMessageCountAndUID(ctx context.Context, client *internal.Client, mboxID imap.InternalMailboxID) (int, imap.UID, error) { messageCount, err := GetMailboxMessageCount(ctx, client, mboxID) if err != nil { return 0, 0, err diff --git a/internal/db/message.go b/internal/db_impl/ent_db/message.go similarity index 68% rename from internal/db/message.go rename to internal/db_impl/ent_db/message.go index 38163e10..091151ec 100644 --- a/internal/db/message.go +++ b/internal/db_impl/ent_db/message.go @@ -1,38 +1,28 @@ -package db +package ent_db import ( "context" "fmt" - "strings" + "github.com/ProtonMail/gluon/db" "entgo.io/ent/dialect/sql" "github.com/ProtonMail/gluon/imap" - "github.com/ProtonMail/gluon/internal/db/ent" - "github.com/ProtonMail/gluon/internal/db/ent/mailbox" - "github.com/ProtonMail/gluon/internal/db/ent/message" - "github.com/ProtonMail/gluon/internal/db/ent/messageflag" - "github.com/ProtonMail/gluon/internal/db/ent/uid" - "github.com/ProtonMail/gluon/internal/ids" + "github.com/ProtonMail/gluon/internal/db_impl/ent_db/internal" + "github.com/ProtonMail/gluon/internal/db_impl/ent_db/internal/mailbox" + "github.com/ProtonMail/gluon/internal/db_impl/ent_db/internal/message" + "github.com/ProtonMail/gluon/internal/db_impl/ent_db/internal/messageflag" + "github.com/ProtonMail/gluon/internal/db_impl/ent_db/internal/uid" "github.com/bradenaw/juniper/xslices" "golang.org/x/exp/slices" ) -const ChunkLimit = 1000 +const ChunkLimit = db.ChunkLimit -type CreateMessageReq struct { - Message imap.Message - InternalID imap.InternalMessageID - LiteralSize int - Body string - Structure string - Envelope string -} - -func CreateMessages(ctx context.Context, tx *ent.Tx, reqs ...*CreateMessageReq) ([]*ent.Message, error) { - flags := make(map[imap.InternalMessageID][]*ent.MessageFlag) +func CreateMessages(ctx context.Context, tx *internal.Tx, reqs ...*db.CreateMessageReq) ([]*internal.Message, error) { + flags := make(map[imap.InternalMessageID][]*internal.MessageFlag) for _, req := range reqs { - builders := xslices.Map(req.Message.Flags.ToSlice(), func(flag string) *ent.MessageFlagCreate { + builders := xslices.Map(req.Message.Flags.ToSlice(), func(flag string) *internal.MessageFlagCreate { return tx.MessageFlag.Create().SetValue(flag) }) @@ -44,7 +34,7 @@ func CreateMessages(ctx context.Context, tx *ent.Tx, reqs ...*CreateMessageReq) flags[req.InternalID] = entFlags } - builders := xslices.Map(reqs, func(req *CreateMessageReq) *ent.MessageCreate { + builders := xslices.Map(reqs, func(req *db.CreateMessageReq) *internal.MessageCreate { msgCreate := tx.Message.Create(). SetID(req.InternalID). SetDate(req.Message.Date). @@ -61,7 +51,7 @@ func CreateMessages(ctx context.Context, tx *ent.Tx, reqs ...*CreateMessageReq) return msgCreate }) - messages := make([]*ent.Message, 0, len(builders)) + messages := make([]*internal.Message, 0, len(builders)) // Avoid too many SQL variables error. for _, chunk := range xslices.Chunk(builders, ChunkLimit) { @@ -76,7 +66,7 @@ func CreateMessages(ctx context.Context, tx *ent.Tx, reqs ...*CreateMessageReq) return messages, nil } -func AddMessagesToMailbox(ctx context.Context, tx *ent.Tx, messageIDs []imap.InternalMessageID, mboxID imap.InternalMailboxID) ([]UIDWithFlags, error) { +func AddMessagesToMailbox(ctx context.Context, tx *internal.Tx, messageIDs []imap.InternalMessageID, mboxID imap.InternalMailboxID) ([]db.UIDWithFlags, error) { if len(messageIDs) == 0 { return nil, nil } @@ -88,7 +78,7 @@ func AddMessagesToMailbox(ctx context.Context, tx *ent.Tx, messageIDs []imap.Int return nil, err } - var builders []*ent.UIDCreate + var builders []*internal.UIDCreate for idx, messageID := range messageIDs { nextUID := mbox.UIDNext.Add(uint32(idx)) @@ -117,16 +107,16 @@ func AddMessagesToMailbox(ctx context.Context, tx *ent.Tx, messageIDs []imap.Int return GetMessageUIDsWithFlagsAfterAddOrUIDBump(ctx, tx.Client(), mboxID, messageIDs) } -func CreateAndAddMessageToMailbox(ctx context.Context, tx *ent.Tx, mboxID imap.InternalMailboxID, req *CreateMessageReq) (imap.UID, imap.FlagSet, error) { +func CreateAndAddMessageToMailbox(ctx context.Context, tx *internal.Tx, mboxID imap.InternalMailboxID, req *db.CreateMessageReq) (imap.UID, imap.FlagSet, error) { mbox, err := tx.Mailbox.Query().Where(mailbox.ID(mboxID)).Select(mailbox.FieldID, mailbox.FieldUIDNext).Only(ctx) if err != nil { return 0, imap.FlagSet{}, err } - var flags []*ent.MessageFlag + var flags []*internal.MessageFlag { - builders := xslices.Map(req.Message.Flags.ToSlice(), func(flag string) *ent.MessageFlagCreate { + builders := xslices.Map(req.Message.Flags.ToSlice(), func(flag string) *internal.MessageFlagCreate { return tx.MessageFlag.Create().SetValue(flag) }) @@ -172,13 +162,7 @@ func CreateAndAddMessageToMailbox(ctx context.Context, tx *ent.Tx, mboxID imap.I return uid.UID, NewFlagSet(uid, flags), err } -type CreateAndAddMessagesResult struct { - UID imap.UID - Flags imap.FlagSet - MessageID ids.MessageIDPair -} - -func BumpMailboxUIDsForMessage(ctx context.Context, tx *ent.Tx, messageIDs []imap.InternalMessageID, mboxID imap.InternalMailboxID) ([]UIDWithFlags, error) { +func BumpMailboxUIDsForMessage(ctx context.Context, tx *internal.Tx, messageIDs []imap.InternalMessageID, mboxID imap.InternalMailboxID) ([]db.UIDWithFlags, error) { messageUIDs := make(map[imap.InternalMessageID]imap.UID) mbox, err := tx.Mailbox.Query().Where(mailbox.ID(mboxID)).Only(ctx) @@ -186,7 +170,7 @@ func BumpMailboxUIDsForMessage(ctx context.Context, tx *ent.Tx, messageIDs []ima return nil, err } - var builders []*ent.UIDUpdate + var builders []*internal.UIDUpdate for idx, messageID := range messageIDs { uidNext := mbox.UIDNext.Add(uint32(idx)) @@ -211,7 +195,7 @@ func BumpMailboxUIDsForMessage(ctx context.Context, tx *ent.Tx, messageIDs []ima return GetMessageUIDsWithFlagsAfterAddOrUIDBump(ctx, tx.Client(), mboxID, messageIDs) } -func RemoveMessagesFromMailbox(ctx context.Context, tx *ent.Tx, messageIDs []imap.InternalMessageID, mboxID imap.InternalMailboxID) error { +func RemoveMessagesFromMailbox(ctx context.Context, tx *internal.Tx, messageIDs []imap.InternalMessageID, mboxID imap.InternalMailboxID) error { // Avoid too many SQL variables error. for _, chunk := range xslices.Chunk(messageIDs, ChunkLimit) { if _, err := tx.UID.Delete(). @@ -228,7 +212,7 @@ func RemoveMessagesFromMailbox(ctx context.Context, tx *ent.Tx, messageIDs []ima return nil } -func MessageExistsWithRemoteID(ctx context.Context, client *ent.Client, messageID imap.MessageID) (bool, error) { +func MessageExistsWithRemoteID(ctx context.Context, client *internal.Client, messageID imap.MessageID) (bool, error) { count, err := client.Message.Query().Where(message.RemoteID(messageID)).Count(ctx) if err != nil { return false, err @@ -237,19 +221,19 @@ func MessageExistsWithRemoteID(ctx context.Context, client *ent.Client, messageI return count > 0, nil } -func GetMessage(ctx context.Context, client *ent.Client, messageID imap.InternalMessageID) (*ent.Message, error) { +func GetMessage(ctx context.Context, client *internal.Client, messageID imap.InternalMessageID) (*internal.Message, error) { return client.Message.Query().Where(message.ID(messageID)).Only(ctx) } -func GetImportedMessageData(ctx context.Context, client *ent.Client, messageID imap.InternalMessageID) (*ent.Message, error) { +func GetImportedMessageData(ctx context.Context, client *internal.Client, messageID imap.InternalMessageID) (*internal.Message, error) { return client.Message.Query().Where(message.ID(messageID)).WithFlags().Select(message.FieldDate).Only(ctx) } -func GetMessageDateAndSize(ctx context.Context, client *ent.Client, messageID imap.InternalMessageID) (*ent.Message, error) { +func GetMessageDateAndSize(ctx context.Context, client *internal.Client, messageID imap.InternalMessageID) (*internal.Message, error) { return client.Message.Query().Where(message.ID(messageID)).Select(message.FieldSize, message.FieldDate).Only(ctx) } -func GetMessageMailboxIDs(ctx context.Context, client *ent.Client, messageID imap.InternalMessageID) ([]imap.InternalMailboxID, error) { +func GetMessageMailboxIDs(ctx context.Context, client *internal.Client, messageID imap.InternalMessageID) ([]imap.InternalMailboxID, error) { type tmp struct { MBoxID imap.InternalMailboxID `json:"mailbox_ui_ds"` } @@ -267,40 +251,10 @@ func GetMessageMailboxIDs(ctx context.Context, client *ent.Client, messageID ima }), nil } -type UIDWithFlags struct { - InternalID imap.InternalMessageID `json:"uid_message"` - RemoteID imap.MessageID `json:"remote_id"` - UID imap.UID `json:"uid"` - Recent bool `json:"recent"` - Deleted bool `json:"deleted"` - Flags string `json:"flags"` -} - -func (u *UIDWithFlags) GetFlagSet() imap.FlagSet { - var flagSet imap.FlagSet - - if len(u.Flags) > 0 { - flags := strings.Split(u.Flags, ",") - flagSet = imap.NewFlagSetFromSlice(flags) - } else { - flagSet = imap.NewFlagSet() - } - - if u.Deleted { - flagSet.AddToSelf(imap.FlagDeleted) - } - - if u.Recent { - flagSet.AddToSelf(imap.FlagRecent) - } - - return flagSet -} - // GetMessageUIDsWithFlagsAfterAddOrUIDBump exploits a property of adding a message to or bumping the UIDs of existing message in mailbox. It can only be // used if you can guarantee that the messageID list contains only IDs that have recently added or bumped in the mailbox. -func GetMessageUIDsWithFlagsAfterAddOrUIDBump(ctx context.Context, client *ent.Client, mboxID imap.InternalMailboxID, messageIDs []imap.InternalMessageID) ([]UIDWithFlags, error) { - result := make([]UIDWithFlags, 0, len(messageIDs)) +func GetMessageUIDsWithFlagsAfterAddOrUIDBump(ctx context.Context, client *internal.Client, mboxID imap.InternalMailboxID, messageIDs []imap.InternalMessageID) ([]db.UIDWithFlags, error) { + result := make([]db.UIDWithFlags, 0, len(messageIDs)) // Hav to split this in chunks as this can trigger too many SQL Variables. for _, chunk := range xslices.Chunk(messageIDs, ChunkLimit) { @@ -320,7 +274,7 @@ func GetMessageUIDsWithFlagsAfterAddOrUIDBump(ctx context.Context, client *ent.C } } - slices.SortFunc(result, func(v1 UIDWithFlags, v2 UIDWithFlags) bool { + slices.SortFunc(result, func(v1 db.UIDWithFlags, v2 db.UIDWithFlags) bool { return v1.UID < v2.UID }) @@ -331,21 +285,15 @@ func GetMessageUIDsWithFlagsAfterAddOrUIDBump(ctx context.Context, client *ent.C return result, nil } -type MessageFlagSet struct { - ID imap.InternalMessageID - RemoteID imap.MessageID - FlagSet imap.FlagSet -} - // GetMessageFlags returns the flags of the given messages. // It does not include per-mailbox flags (\Deleted, \Recent)! -func GetMessageFlags(ctx context.Context, client *ent.Client, messageIDs []imap.InternalMessageID) ([]MessageFlagSet, error) { - result := make([]MessageFlagSet, 0, len(messageIDs)) +func GetMessageFlags(ctx context.Context, client *internal.Client, messageIDs []imap.InternalMessageID) ([]db.MessageFlagSet, error) { + result := make([]db.MessageFlagSet, 0, len(messageIDs)) for _, chunk := range xslices.Chunk(messageIDs, ChunkLimit) { chunkMessages, err := client.Message.Query(). Where(message.IDIn(chunk...)). - WithFlags(func(query *ent.MessageFlagQuery) { + WithFlags(func(query *internal.MessageFlagQuery) { query.Select(messageflag.FieldValue) }). Select(message.FieldID, message.FieldRemoteID). @@ -355,7 +303,7 @@ func GetMessageFlags(ctx context.Context, client *ent.Client, messageIDs []imap. } for _, msg := range chunkMessages { - mfs := MessageFlagSet{ + mfs := db.MessageFlagSet{ ID: msg.ID, RemoteID: msg.RemoteID, FlagSet: imap.NewFlagSetWithCapacity(len(msg.Edges.Flags)), @@ -372,7 +320,7 @@ func GetMessageFlags(ctx context.Context, client *ent.Client, messageIDs []imap. return result, nil } -func GetMessageDeleted(ctx context.Context, client *ent.Client, mboxID imap.InternalMailboxID, messageIDs []imap.InternalMessageID) (map[imap.InternalMessageID]bool, error) { +func GetMessageDeleted(ctx context.Context, client *internal.Client, mboxID imap.InternalMailboxID, messageIDs []imap.InternalMessageID) (map[imap.InternalMessageID]bool, error) { type tmp struct { MsgID imap.InternalMessageID `json:"uid_message"` Deleted bool `json:"deleted"` @@ -402,9 +350,9 @@ func GetMessageDeleted(ctx context.Context, client *ent.Client, mboxID imap.Inte return res, nil } -func AddMessageFlag(ctx context.Context, tx *ent.Tx, messageIDs []imap.InternalMessageID, addFlag string) error { +func AddMessageFlag(ctx context.Context, tx *internal.Tx, messageIDs []imap.InternalMessageID, addFlag string) error { for _, chunk := range xslices.Chunk(messageIDs, ChunkLimit) { - builders := xslices.Map(chunk, func(imap.InternalMessageID) *ent.MessageFlagCreate { + builders := xslices.Map(chunk, func(imap.InternalMessageID) *internal.MessageFlagCreate { return tx.MessageFlag.Create().SetValue(addFlag) }) @@ -423,7 +371,7 @@ func AddMessageFlag(ctx context.Context, tx *ent.Tx, messageIDs []imap.InternalM return nil } -func RemoveMessageFlag(ctx context.Context, tx *ent.Tx, messageIDs []imap.InternalMessageID, remFlag string) error { +func RemoveMessageFlag(ctx context.Context, tx *internal.Tx, messageIDs []imap.InternalMessageID, remFlag string) error { remFlagSet := imap.NewFlagSet(remFlag) for _, chunk := range xslices.Chunk(messageIDs, ChunkLimit) { @@ -436,8 +384,8 @@ func RemoveMessageFlag(ctx context.Context, tx *ent.Tx, messageIDs []imap.Intern return err } - flags := xslices.Map(messages, func(message *ent.Message) *ent.MessageFlag { - return message.Edges.Flags[xslices.IndexFunc(message.Edges.Flags, func(flag *ent.MessageFlag) bool { + flags := xslices.Map(messages, func(message *internal.Message) *internal.MessageFlag { + return message.Edges.Flags[xslices.IndexFunc(message.Edges.Flags, func(flag *internal.MessageFlag) bool { return remFlagSet.Contains(flag.Value) })] }) @@ -452,7 +400,7 @@ func RemoveMessageFlag(ctx context.Context, tx *ent.Tx, messageIDs []imap.Intern return nil } -func SetMessageFlags(ctx context.Context, tx *ent.Tx, messageIDs []imap.InternalMessageID, setFlags imap.FlagSet) error { +func SetMessageFlags(ctx context.Context, tx *internal.Tx, messageIDs []imap.InternalMessageID, setFlags imap.FlagSet) error { for _, chunk := range xslices.Chunk(messageIDs, ChunkLimit) { messages, err := tx.Message.Query(). Where(message.IDIn(chunk...)). @@ -464,7 +412,7 @@ func SetMessageFlags(ctx context.Context, tx *ent.Tx, messageIDs []imap.Internal } for _, message := range messages { - curFlagSet := imap.NewFlagSetFromSlice(xslices.Map(message.Edges.Flags, func(flag *ent.MessageFlag) string { + curFlagSet := imap.NewFlagSetFromSlice(xslices.Map(message.Edges.Flags, func(flag *internal.MessageFlag) string { return flag.Value })) @@ -472,7 +420,7 @@ func SetMessageFlags(ctx context.Context, tx *ent.Tx, messageIDs []imap.Internal return !curFlagSet.Contains(flag) }) - builders := xslices.Map(addFlags, func(flag string) *ent.MessageFlagCreate { + builders := xslices.Map(addFlags, func(flag string) *internal.MessageFlagCreate { return tx.MessageFlag.Create().SetValue(flag) }) @@ -481,7 +429,7 @@ func SetMessageFlags(ctx context.Context, tx *ent.Tx, messageIDs []imap.Internal return err } - remEntFlags := xslices.Filter(message.Edges.Flags, func(flag *ent.MessageFlag) bool { + remEntFlags := xslices.Filter(message.Edges.Flags, func(flag *internal.MessageFlag) bool { return !setFlags.Contains(flag.Value) }) @@ -497,7 +445,7 @@ func SetMessageFlags(ctx context.Context, tx *ent.Tx, messageIDs []imap.Internal return nil } -func SetDeletedFlag(ctx context.Context, tx *ent.Tx, mboxID imap.InternalMailboxID, messageIDs []imap.InternalMessageID, deleted bool) error { +func SetDeletedFlag(ctx context.Context, tx *internal.Tx, mboxID imap.InternalMailboxID, messageIDs []imap.InternalMessageID, deleted bool) error { for _, chunk := range xslices.Chunk(messageIDs, ChunkLimit) { if _, err := tx.UID.Update(). Where( @@ -515,7 +463,7 @@ func SetDeletedFlag(ctx context.Context, tx *ent.Tx, mboxID imap.InternalMailbox return nil } -func ClearRecentFlag(ctx context.Context, tx *ent.Tx, mboxID imap.InternalMailboxID, messageID imap.InternalMessageID) error { +func ClearRecentFlag(ctx context.Context, tx *internal.Tx, mboxID imap.InternalMailboxID, messageID imap.InternalMessageID) error { if _, err := tx.UID.Update(). Where( func(s *sql.Selector) { @@ -529,7 +477,7 @@ func ClearRecentFlag(ctx context.Context, tx *ent.Tx, mboxID imap.InternalMailbo return nil } -func ClearRecentFlags(ctx context.Context, tx *ent.Tx, mboxID imap.InternalMailboxID) error { +func ClearRecentFlags(ctx context.Context, tx *internal.Tx, mboxID imap.InternalMailboxID) error { if _, err := tx.UID.Update(). Where( func(s *sql.Selector) { @@ -543,7 +491,7 @@ func ClearRecentFlags(ctx context.Context, tx *ent.Tx, mboxID imap.InternalMailb return nil } -func UpdateRemoteMessageID(ctx context.Context, tx *ent.Tx, internalID imap.InternalMessageID, remoteID imap.MessageID) error { +func UpdateRemoteMessageID(ctx context.Context, tx *internal.Tx, internalID imap.InternalMessageID, remoteID imap.MessageID) error { if _, err := tx.Message.Update(). Where(message.ID(internalID)). SetRemoteID(remoteID). @@ -554,7 +502,7 @@ func UpdateRemoteMessageID(ctx context.Context, tx *ent.Tx, internalID imap.Inte return nil } -func MarkMessageAsDeleted(ctx context.Context, tx *ent.Tx, messageID imap.InternalMessageID) error { +func MarkMessageAsDeleted(ctx context.Context, tx *internal.Tx, messageID imap.InternalMessageID) error { if _, err := tx.Message.Update().Where(message.ID(messageID)).SetDeleted(true).Save(ctx); err != nil { return err } @@ -562,7 +510,7 @@ func MarkMessageAsDeleted(ctx context.Context, tx *ent.Tx, messageID imap.Intern return nil } -func MarkMessageAsDeletedAndAssignRandomRemoteID(ctx context.Context, tx *ent.Tx, messageID imap.InternalMessageID) error { +func MarkMessageAsDeletedAndAssignRandomRemoteID(ctx context.Context, tx *internal.Tx, messageID imap.InternalMessageID) error { randomID := imap.MessageID(fmt.Sprintf("DELETED-%v", imap.NewInternalMessageID())) if _, err := tx.Message.Update().Where(message.ID(messageID)).SetDeleted(true).SetRemoteID(randomID).Save(ctx); err != nil { return err @@ -571,7 +519,7 @@ func MarkMessageAsDeletedAndAssignRandomRemoteID(ctx context.Context, tx *ent.Tx return nil } -func MarkMessageAsDeletedWithRemoteID(ctx context.Context, tx *ent.Tx, messageID imap.MessageID) error { +func MarkMessageAsDeletedWithRemoteID(ctx context.Context, tx *internal.Tx, messageID imap.MessageID) error { if _, err := tx.Message.Update().Where(message.RemoteID(messageID)).SetDeleted(true).Save(ctx); err != nil { return err } @@ -579,7 +527,7 @@ func MarkMessageAsDeletedWithRemoteID(ctx context.Context, tx *ent.Tx, messageID return nil } -func DeleteMessages(ctx context.Context, tx *ent.Tx, messageIDs ...imap.InternalMessageID) error { +func DeleteMessages(ctx context.Context, tx *internal.Tx, messageIDs ...imap.InternalMessageID) error { for _, chunk := range xslices.Chunk(messageIDs, ChunkLimit) { if _, err := tx.Message.Delete().Where(message.IDIn(chunk...)).Exec(ctx); err != nil { return err @@ -589,26 +537,40 @@ func DeleteMessages(ctx context.Context, tx *ent.Tx, messageIDs ...imap.Internal return nil } -func GetMessageIDsMarkedDeleted(ctx context.Context, client *ent.Client) ([]imap.InternalMessageID, error) { +func GetMessageIDsMarkedDeleted(ctx context.Context, client *internal.Client) ([]imap.InternalMessageID, error) { messages, err := client.Message.Query().Where(message.Deleted(true)).Select(message.FieldID).All(ctx) if err != nil { return nil, err } - return xslices.Map(messages, func(t *ent.Message) imap.InternalMessageID { + return xslices.Map(messages, func(t *internal.Message) imap.InternalMessageID { return t.ID }), nil } -func HasMessageWithID(ctx context.Context, client *ent.Client, id imap.InternalMessageID) (bool, error) { +func HasMessageWithID(ctx context.Context, client *internal.Client, id imap.InternalMessageID) (bool, error) { return client.Message.Query().Where(message.ID(id)).Exist(ctx) } -func HasMessageWithRemoteID(ctx context.Context, client *ent.Client, id imap.MessageID) (bool, error) { - return client.Message.Query().Where(message.RemoteID(id)).Exist(ctx) +func HasMessageWithRemoteID(ctx context.Context, client *internal.Client, id imap.MessageID) (bool, error) { + _, err := client.Message.Query().Where(message.RemoteID(id)).Select(message.FieldRemoteID).Only(ctx) + if err != nil { + if internal.IsNotFound(err) { + return false, nil + } + + return false, err + } + + // For whatever weird reason, this stopped working all together. No code changes were made, but reflection started + // failing. + // Now we get error="internal: check existence: sql/scan: missing struct field for column: id (id)". + //return client.Message.Query().Where(message.RemoteID(id)).Exist(ctx) + + return true, nil } -func GetMessageIDFromRemoteID(ctx context.Context, client *ent.Client, id imap.MessageID) (imap.InternalMessageID, error) { +func GetMessageIDFromRemoteID(ctx context.Context, client *internal.Client, id imap.MessageID) (imap.InternalMessageID, error) { message, err := client.Message.Query().Where(message.RemoteID(id)).Select(message.FieldID).Only(ctx) if err != nil { return imap.InternalMessageID{}, err @@ -617,7 +579,7 @@ func GetMessageIDFromRemoteID(ctx context.Context, client *ent.Client, id imap.M return message.ID, nil } -func GetMessageWithIDWithDeletedFlag(ctx context.Context, client *ent.Client, id imap.InternalMessageID) (*ent.Message, error) { +func GetMessageWithIDWithDeletedFlag(ctx context.Context, client *internal.Client, id imap.InternalMessageID) (*internal.Message, error) { message, err := client.Message.Query().Where(message.ID(id)).Select(message.FieldID, message.FieldDeleted).Only(ctx) if err != nil { return nil, err @@ -626,7 +588,7 @@ func GetMessageWithIDWithDeletedFlag(ctx context.Context, client *ent.Client, id return message, nil } -func GetMessageFromRemoteIDWithDeletedFlag(ctx context.Context, client *ent.Client, id imap.MessageID) (*ent.Message, error) { +func GetMessageFromRemoteIDWithDeletedFlag(ctx context.Context, client *internal.Client, id imap.MessageID) (*internal.Message, error) { message, err := client.Message.Query().Where(message.RemoteID(id)).Select(message.FieldID, message.FieldDeleted).Only(ctx) if err != nil { return nil, err @@ -635,7 +597,7 @@ func GetMessageFromRemoteIDWithDeletedFlag(ctx context.Context, client *ent.Clie return message, nil } -func GetMessageRemoteIDFromID(ctx context.Context, client *ent.Client, id imap.InternalMessageID) (imap.MessageID, error) { +func GetMessageRemoteIDFromID(ctx context.Context, client *internal.Client, id imap.InternalMessageID) (imap.MessageID, error) { message, err := client.Message.Query().Where(message.ID(id)).Select(message.FieldRemoteID).Only(ctx) if err != nil { return "", err @@ -644,8 +606,8 @@ func GetMessageRemoteIDFromID(ctx context.Context, client *ent.Client, id imap.I return message.RemoteID, nil } -func NewFlagSet(msgUID *ent.UID, flags []*ent.MessageFlag) imap.FlagSet { - flagSet := imap.NewFlagSetFromSlice(xslices.Map(flags, func(flag *ent.MessageFlag) string { +func NewFlagSet(msgUID *internal.UID, flags []*internal.MessageFlag) imap.FlagSet { + flagSet := imap.NewFlagSetFromSlice(xslices.Map(flags, func(flag *internal.MessageFlag) string { return flag.Value })) @@ -660,8 +622,8 @@ func NewFlagSet(msgUID *ent.UID, flags []*ent.MessageFlag) imap.FlagSet { return flagSet } -func GetHighestMessageID(ctx context.Context, client *ent.Client) (imap.InternalMessageID, error) { - message, err := client.Message.Query().Select(message.FieldID).Order(ent.Desc(message.FieldID)).Limit(1).All(ctx) +func GetHighestMessageID(ctx context.Context, client *internal.Client) (imap.InternalMessageID, error) { + message, err := client.Message.Query().Select(message.FieldID).Order(internal.Desc(message.FieldID)).Limit(1).All(ctx) if err != nil { return imap.InternalMessageID{}, err } @@ -673,7 +635,7 @@ func GetHighestMessageID(ctx context.Context, client *ent.Client) (imap.Internal return message[0].ID, nil } -func GetAllMessagesIDsAsMap(ctx context.Context, client *ent.Client) (map[imap.InternalMessageID]struct{}, error) { +func GetAllMessagesIDsAsMap(ctx context.Context, client *internal.Client) (map[imap.InternalMessageID]struct{}, error) { messages, err := client.Message.Query().Select(message.FieldID).All(ctx) if err != nil { return nil, err diff --git a/internal/db_impl/ent_db/ops_read.go b/internal/db_impl/ent_db/ops_read.go new file mode 100644 index 00000000..16ea70e9 --- /dev/null +++ b/internal/db_impl/ent_db/ops_read.go @@ -0,0 +1,332 @@ +package ent_db + +import ( + "context" + "github.com/ProtonMail/gluon/db" + "github.com/ProtonMail/gluon/imap" + "github.com/ProtonMail/gluon/internal/db_impl/ent_db/internal" + "github.com/bradenaw/juniper/xslices" + "time" +) + +type EntOpsRead struct { + client *internal.Client +} + +func newOpsReadFromClient(client *internal.Client) *EntOpsRead { + return &EntOpsRead{ + client: client, + } +} + +func newOpsReadFromTx(tx *internal.Tx) *EntOpsRead { + return &EntOpsRead{ + client: tx.Client(), + } +} + +func (op *EntOpsRead) MailboxExistsWithID(ctx context.Context, mboxID imap.InternalMailboxID) (bool, error) { + return wrapEntErrFnTyped(func() (bool, error) { + return MailboxExistsWithID(ctx, op.client, mboxID) + }) +} + +func (op *EntOpsRead) MailboxExistsWithRemoteID(ctx context.Context, mboxID imap.MailboxID) (bool, error) { + return wrapEntErrFnTyped(func() (bool, error) { + return MailboxExistsWithRemoteID(ctx, op.client, mboxID) + }) +} + +func (op *EntOpsRead) MailboxExistsWithName(ctx context.Context, name string) (bool, error) { + return wrapEntErrFnTyped(func() (bool, error) { + return MailboxExistsWithName(ctx, op.client, name) + }) +} + +func (op *EntOpsRead) GetMailboxIDFromRemoteID(ctx context.Context, mboxID imap.MailboxID) (imap.InternalMailboxID, error) { + return wrapEntErrFnTyped(func() (imap.InternalMailboxID, error) { + return GetMailboxIDFromRemoteID(ctx, op.client, mboxID) + }) +} + +func (op *EntOpsRead) GetMailboxName(ctx context.Context, mboxID imap.InternalMailboxID) (string, error) { + return wrapEntErrFnTyped(func() (string, error) { + return GetMailboxName(ctx, op.client, mboxID) + }) +} + +func (op *EntOpsRead) GetMailboxNameWithRemoteID(ctx context.Context, mboxID imap.MailboxID) (string, error) { + return wrapEntErrFnTyped(func() (string, error) { + return GetMailboxNameWithRemoteID(ctx, op.client, mboxID) + }) +} + +func (op *EntOpsRead) GetMailboxMessageIDPairs(ctx context.Context, mboxID imap.InternalMailboxID) ([]db.MessageIDPair, error) { + return wrapEntErrFnTyped(func() ([]db.MessageIDPair, error) { + return GetMailboxMessageIDPairs(ctx, op.client, mboxID) + }) +} + +func (op *EntOpsRead) GetAllMailboxes(ctx context.Context) ([]*db.Mailbox, error) { + return wrapEntErrFnTyped(func() ([]*db.Mailbox, error) { + val, err := GetAllMailboxes(ctx, op.client) + + return xslices.Map(val, entMBoxToDB), err + }) +} + +func (op *EntOpsRead) GetAllMailboxesAsRemoteIDs(ctx context.Context) ([]imap.MailboxID, error) { + return wrapEntErrFnTyped(func() ([]imap.MailboxID, error) { + val, err := GetAllMailboxes(ctx, op.client) + + return xslices.Map(val, func(t *internal.Mailbox) imap.MailboxID { + return t.RemoteID + }), err + }) +} + +func (op *EntOpsRead) GetMailboxByName(ctx context.Context, name string) (*db.Mailbox, error) { + return wrapEntErrFnTyped(func() (*db.Mailbox, error) { + val, err := GetMailboxByName(ctx, op.client, name) + + return entMBoxToDB(val), err + }) +} + +func (op *EntOpsRead) GetMailboxByID(ctx context.Context, mboxID imap.InternalMailboxID) (*db.Mailbox, error) { + return wrapEntErrFnTyped(func() (*db.Mailbox, error) { + val, err := GetMailboxByID(ctx, op.client, mboxID) + + return entMBoxToDB(val), err + }) +} + +func (op *EntOpsRead) GetMailboxByRemoteID(ctx context.Context, mboxID imap.MailboxID) (*db.Mailbox, error) { + return wrapEntErrFnTyped(func() (*db.Mailbox, error) { + val, err := GetMailboxByRemoteID(ctx, op.client, mboxID) + + return entMBoxToDB(val), err + }) +} + +func (op *EntOpsRead) GetMailboxRecentCount(ctx context.Context, mboxID imap.InternalMailboxID) (int, error) { + return wrapEntErrFnTyped(func() (int, error) { + mbox, err := GetMailboxByID(ctx, op.client, mboxID) + if err != nil { + return 0, err + } + + return GetMailboxRecentCount(ctx, op.client, mbox) + }) +} + +func (op *EntOpsRead) GetMailboxMessageCount(ctx context.Context, mboxID imap.InternalMailboxID) (int, error) { + return wrapEntErrFnTyped(func() (int, error) { + return GetMailboxMessageCount(ctx, op.client, mboxID) + }) +} + +func (op *EntOpsRead) GetMailboxMessageCountWithRemoteID(ctx context.Context, mboxID imap.MailboxID) (int, error) { + return wrapEntErrFnTyped(func() (int, error) { + mbox, err := GetMailboxByRemoteID(ctx, op.client, mboxID) + if err != nil { + return 0, err + } + + return GetMailboxMessageCount(ctx, op.client, mbox.ID) + }) +} + +func (op *EntOpsRead) GetMailboxFlags(ctx context.Context, mboxID imap.InternalMailboxID) (imap.FlagSet, error) { + return wrapEntErrFnTyped(func() (imap.FlagSet, error) { + return GetMailboxFlags(ctx, op.client, mboxID) + }) +} + +func (op *EntOpsRead) GetMailboxPermanentFlags(ctx context.Context, mboxID imap.InternalMailboxID) (imap.FlagSet, error) { + return wrapEntErrFnTyped(func() (imap.FlagSet, error) { + return GetMailboxPermanentFlags(ctx, op.client, mboxID) + }) +} + +func (op *EntOpsRead) GetMailboxAttributes(ctx context.Context, mboxID imap.InternalMailboxID) (imap.FlagSet, error) { + return wrapEntErrFnTyped(func() (imap.FlagSet, error) { + return GetMailboxAttributes(ctx, op.client, mboxID) + }) +} + +func (op *EntOpsRead) GetMailboxUID(ctx context.Context, mboxID imap.InternalMailboxID) (imap.UID, error) { + return wrapEntErrFnTyped(func() (imap.UID, error) { + return GetMailboxUID(ctx, op.client, mboxID) + }) +} + +func (op *EntOpsRead) GetMailboxMessageCountAndUID(ctx context.Context, mboxID imap.InternalMailboxID) (int, imap.UID, error) { + var count int + + var uid imap.UID + + err := wrapEntErrFn(func() error { + var err error + + count, uid, err = GetMailboxMessageCountAndUID(ctx, op.client, mboxID) + + return err + }) + + return count, uid, err +} + +func (op *EntOpsRead) GetMailboxMessageForNewSnapshot(ctx context.Context, mboxID imap.InternalMailboxID) ([]db.SnapshotMessageResult, error) { + return wrapEntErrFnTyped(func() ([]db.SnapshotMessageResult, error) { + return GetMailboxMessagesForNewSnapshot(ctx, op.client, mboxID) + }) +} + +func (op *EntOpsRead) MailboxTranslateRemoteIDs(ctx context.Context, mboxIDs []imap.MailboxID) ([]imap.InternalMailboxID, error) { + return wrapEntErrFnTyped(func() ([]imap.InternalMailboxID, error) { + return TranslateRemoteMailboxIDs(ctx, op.client, mboxIDs) + }) +} + +func (op *EntOpsRead) MailboxFilterContains(ctx context.Context, mboxID imap.InternalMailboxID, messageIDs []db.MessageIDPair) ([]imap.InternalMessageID, error) { + return wrapEntErrFnTyped(func() ([]imap.InternalMessageID, error) { + return FilterMailboxContains(ctx, op.client, mboxID, messageIDs) + }) +} + +func (op *EntOpsRead) MailboxFilterContainsInternalID(ctx context.Context, mboxID imap.InternalMailboxID, messageIDs []imap.InternalMessageID) ([]imap.InternalMessageID, error) { + return wrapEntErrFnTyped(func() ([]imap.InternalMessageID, error) { + return FilterMailboxContainsInternalID(ctx, op.client, mboxID, messageIDs) + }) +} + +func (op *EntOpsRead) GetMailboxCount(ctx context.Context) (int, error) { + return wrapEntErrFnTyped(func() (int, error) { + return GetMailboxCount(ctx, op.client) + }) +} + +func (op *EntOpsRead) GetMailboxMessageUIDsWithFlagsAfterAddOrUIDBump(ctx context.Context, mboxID imap.InternalMailboxID, messageIDs []imap.InternalMessageID) ([]db.UIDWithFlags, error) { + return wrapEntErrFnTyped(func() ([]db.UIDWithFlags, error) { + return GetMessageUIDsWithFlagsAfterAddOrUIDBump(ctx, op.client, mboxID, messageIDs) + }) +} + +func (op *EntOpsRead) MessageExists(ctx context.Context, id imap.InternalMessageID) (bool, error) { + return wrapEntErrFnTyped(func() (bool, error) { + return HasMessageWithID(ctx, op.client, id) + }) +} + +func (op *EntOpsRead) MessageExistsWithRemoteID(ctx context.Context, id imap.MessageID) (bool, error) { + return wrapEntErrFnTyped(func() (bool, error) { + return HasMessageWithRemoteID(ctx, op.client, id) + }) +} + +func (op *EntOpsRead) GetMessage(ctx context.Context, id imap.InternalMessageID) (*db.Message, error) { + return wrapEntErrFnTyped(func() (*db.Message, error) { + msg, err := GetMessage(ctx, op.client, id) + + return entMessageToDB(msg), err + }) +} + +func (op *EntOpsRead) GetTotalMessageCount(ctx context.Context) (int, error) { + return wrapEntErrFnTyped(func() (int, error) { + return op.client.Message.Query().Count(ctx) + }) +} + +func (op *EntOpsRead) GetMessageRemoteID(ctx context.Context, id imap.InternalMessageID) (imap.MessageID, error) { + return wrapEntErrFnTyped(func() (imap.MessageID, error) { + return GetMessageRemoteIDFromID(ctx, op.client, id) + }) +} + +func (op *EntOpsRead) GetImportedMessageData(ctx context.Context, id imap.InternalMessageID) (*db.Message, error) { + return wrapEntErrFnTyped(func() (*db.Message, error) { + msg, err := GetImportedMessageData(ctx, op.client, id) + + return entMessageToDB(msg), err + }) +} + +func (op *EntOpsRead) GetMessageDateAndSize(ctx context.Context, id imap.InternalMessageID) (time.Time, int, error) { + var date time.Time + + var size int + + err := wrapEntErrFn(func() error { + msg, err := GetMessageDateAndSize(ctx, op.client, id) + if err != nil { + return err + } + + date = msg.Date + size = msg.Size + + return err + }) + + return date, size, err +} + +func (op *EntOpsRead) GetMessageMailboxIDs(ctx context.Context, id imap.InternalMessageID) ([]imap.InternalMailboxID, error) { + return wrapEntErrFnTyped(func() ([]imap.InternalMailboxID, error) { + return GetMessageMailboxIDs(ctx, op.client, id) + }) +} + +func (op *EntOpsRead) GetMessagesFlags(ctx context.Context, ids []imap.InternalMessageID) ([]db.MessageFlagSet, error) { + return wrapEntErrFnTyped(func() ([]db.MessageFlagSet, error) { + return GetMessageFlags(ctx, op.client, ids) + }) +} + +func (op *EntOpsRead) GetMessageIDsMarkedAsDelete(ctx context.Context) ([]imap.InternalMessageID, error) { + return wrapEntErrFnTyped(func() ([]imap.InternalMessageID, error) { + return GetMessageIDsMarkedDeleted(ctx, op.client) + }) +} + +func (op *EntOpsRead) GetMessageIDFromRemoteID(ctx context.Context, id imap.MessageID) (imap.InternalMessageID, error) { + return wrapEntErrFnTyped(func() (imap.InternalMessageID, error) { + return GetMessageIDFromRemoteID(ctx, op.client, id) + }) +} + +func (op *EntOpsRead) GetMessageDeletedFlag(ctx context.Context, id imap.InternalMessageID) (bool, error) { + return wrapEntErrFnTyped(func() (bool, error) { + msg, err := GetMessageWithIDWithDeletedFlag(ctx, op.client, id) + if err != nil { + return false, err + } + + return msg.Deleted, nil + }) +} + +func (op *EntOpsRead) GetAllMessagesIDsAsMap(ctx context.Context) (map[imap.InternalMessageID]struct{}, error) { + return wrapEntErrFnTyped(func() (map[imap.InternalMessageID]struct{}, error) { + return GetAllMessagesIDsAsMap(ctx, op.client) + }) +} + +func (op *EntOpsRead) GetDeletedSubscriptionSet(ctx context.Context) (map[imap.MailboxID]*db.DeletedSubscription, error) { + return wrapEntErrFnTyped(func() (map[imap.MailboxID]*db.DeletedSubscription, error) { + ent, err := GetDeletedSubscriptionSet(ctx, op.client) + if err != nil { + return nil, err + } + + result := make(map[imap.MailboxID]*db.DeletedSubscription, len(ent)) + + for k, v := range ent { + result[k] = entSubscriptionToDB(v) + } + + return result, nil + }) +} diff --git a/internal/db_impl/ent_db/ops_write.go b/internal/db_impl/ent_db/ops_write.go new file mode 100644 index 00000000..8d914950 --- /dev/null +++ b/internal/db_impl/ent_db/ops_write.go @@ -0,0 +1,224 @@ +package ent_db + +import ( + "context" + "github.com/ProtonMail/gluon/db" + "github.com/ProtonMail/gluon/imap" + "github.com/ProtonMail/gluon/internal/db_impl/ent_db/internal" + "github.com/ProtonMail/gluon/internal/db_impl/ent_db/internal/mailbox" + "github.com/bradenaw/juniper/xslices" +) + +type EntOpsWrite struct { + EntOpsRead + tx *internal.Tx +} + +func newEntOpsWrite(tx *internal.Tx) *EntOpsWrite { + return &EntOpsWrite{ + EntOpsRead: EntOpsRead{client: tx.Client()}, + tx: tx, + } +} + +func (op *EntOpsWrite) CreateMailbox(ctx context.Context, + mboxID imap.MailboxID, + name string, + flags, permFlags, attrs imap.FlagSet, + uidValidity imap.UID) (*db.Mailbox, error) { + return wrapEntErrFnTyped(func() (*db.Mailbox, error) { + mbox, err := CreateMailbox(ctx, op.tx, mboxID, name, flags, permFlags, attrs, uidValidity) + + return entMBoxToDB(mbox), err + }) +} + +func (op *EntOpsWrite) GetOrCreateMailbox(ctx context.Context, + mboxID imap.MailboxID, + name string, + flags, permFlags, attrs imap.FlagSet, + uidValidity imap.UID) (*db.Mailbox, error) { + return wrapEntErrFnTyped(func() (*db.Mailbox, error) { + mbox, err := CreateMailbox(ctx, op.tx, mboxID, name, flags, permFlags, attrs, uidValidity) + + return entMBoxToDB(mbox), err + }) +} + +func (op *EntOpsWrite) GetOrCreateMailboxAlt(ctx context.Context, + mbox imap.Mailbox, + delimiter string, + uidValidity imap.UID) (*db.Mailbox, error) { + return wrapEntErrFnTyped(func() (*db.Mailbox, error) { + mbox, err := GetOrCreateMailbox(ctx, op.tx, mbox, delimiter, uidValidity) + + return entMBoxToDB(mbox), err + }) +} + +func (op *EntOpsWrite) RenameMailboxWithRemoteID(ctx context.Context, mboxID imap.MailboxID, name string) error { + return wrapEntErrFn(func() error { + return RenameMailboxWithRemoteID(ctx, op.tx, mboxID, name) + }) +} + +func (op *EntOpsWrite) DeleteMailboxWithRemoteID(ctx context.Context, mboxID imap.MailboxID) error { + return wrapEntErrFn(func() error { + return DeleteMailboxWithRemoteID(ctx, op.tx, mboxID) + }) +} + +func (op *EntOpsWrite) BumpMailboxUIDNext(ctx context.Context, mboxID imap.InternalMailboxID, count int) error { + return wrapEntErrFn(func() error { + mbox, err := op.tx.Mailbox.Query().Where(mailbox.ID(mboxID)).Only(ctx) + if err != nil { + return err + } + + return BumpMailboxUIDNext(ctx, op.tx, mbox, count) + }) +} + +func (op *EntOpsWrite) AddMessagesToMailbox(ctx context.Context, mboxID imap.InternalMailboxID, messageIDs []imap.InternalMessageID) ([]db.UIDWithFlags, error) { + return wrapEntErrFnTyped(func() ([]db.UIDWithFlags, error) { + return AddMessagesToMailbox(ctx, op.tx, messageIDs, mboxID) + }) +} + +func (op *EntOpsWrite) BumpMailboxUIDsForMessage(ctx context.Context, mboxID imap.InternalMailboxID, messageIDs []imap.InternalMessageID) ([]db.UIDWithFlags, error) { + return wrapEntErrFnTyped(func() ([]db.UIDWithFlags, error) { + return BumpMailboxUIDsForMessage(ctx, op.tx, messageIDs, mboxID) + }) +} + +func (op *EntOpsWrite) RemoveMessagesFromMailbox(ctx context.Context, mboxID imap.InternalMailboxID, messageIDs []imap.InternalMessageID) error { + return wrapEntErrFn(func() error { + return RemoveMessagesFromMailbox(ctx, op.tx, messageIDs, mboxID) + }) +} + +func (op *EntOpsWrite) ClearRecentFlagInMailboxOnMessage(ctx context.Context, mboxID imap.InternalMailboxID, messageID imap.InternalMessageID) error { + return wrapEntErrFn(func() error { + return ClearRecentFlag(ctx, op.tx, mboxID, messageID) + }) +} + +func (op *EntOpsWrite) ClearRecentFlagsInMailbox(ctx context.Context, mboxID imap.InternalMailboxID) error { + return wrapEntErrFn(func() error { + return ClearRecentFlags(ctx, op.tx, mboxID) + }) +} + +func (op *EntOpsWrite) CreateMailboxIfNotExists(ctx context.Context, mbox imap.Mailbox, delimiter string, uidValidity imap.UID) error { + return wrapEntErrFn(func() error { + return CreateMailboxIfNotExists(ctx, op.tx, mbox, delimiter, uidValidity) + }) +} + +func (op *EntOpsWrite) SetMailboxMessagesDeletedFlag(ctx context.Context, mboxID imap.InternalMailboxID, messageIDs []imap.InternalMessageID, deleted bool) error { + return wrapEntErrFn(func() error { + return SetDeletedFlag(ctx, op.tx, mboxID, messageIDs, deleted) + }) +} + +func (op *EntOpsWrite) SetMailboxSubscribed(ctx context.Context, mboxID imap.InternalMailboxID, subscribed bool) error { + return wrapEntErrFn(func() error { + return op.tx.Mailbox.Update().Where(mailbox.ID(mboxID)).SetSubscribed(subscribed).Exec(ctx) + }) +} + +func (op *EntOpsWrite) UpdateRemoteMailboxID(ctx context.Context, mobxID imap.InternalMailboxID, remoteID imap.MailboxID) error { + return wrapEntErrFn(func() error { + return UpdateRemoteMailboxID(ctx, op.tx, mobxID, remoteID) + }) +} + +func (op *EntOpsWrite) SetMailboxUIDValidity(ctx context.Context, mboxID imap.InternalMailboxID, uidValidity imap.UID) error { + return wrapEntErrFn(func() error { + return op.tx.Mailbox.Update().Where(mailbox.ID(mboxID)).SetUIDValidity(uidValidity).Exec(ctx) + }) +} + +func (op *EntOpsWrite) CreateMessages(ctx context.Context, reqs ...*db.CreateMessageReq) ([]*db.Message, error) { + return wrapEntErrFnTyped(func() ([]*db.Message, error) { + msgs, err := CreateMessages(ctx, op.tx, reqs...) + + return xslices.Map(msgs, entMessageToDB), err + }) +} + +func (op *EntOpsWrite) CreateMessageAndAddToMailbox(ctx context.Context, mbox imap.InternalMailboxID, req *db.CreateMessageReq) (imap.UID, imap.FlagSet, error) { + var uid imap.UID + + var flagSet imap.FlagSet + + err := wrapEntErrFn(func() error { + var err error + + uid, flagSet, err = CreateAndAddMessageToMailbox(ctx, op.tx, mbox, req) + + return err + }) + + return uid, flagSet, err +} + +func (op *EntOpsWrite) MarkMessageAsDeleted(ctx context.Context, id imap.InternalMessageID) error { + return wrapEntErrFn(func() error { + return MarkMessageAsDeleted(ctx, op.tx, id) + }) +} + +func (op *EntOpsWrite) MarkMessageAsDeletedAndAssignRandomRemoteID(ctx context.Context, id imap.InternalMessageID) error { + return wrapEntErrFn(func() error { + return MarkMessageAsDeletedAndAssignRandomRemoteID(ctx, op.tx, id) + }) +} + +func (op *EntOpsWrite) MarkMessageAsDeletedWithRemoteID(ctx context.Context, id imap.MessageID) error { + return wrapEntErrFn(func() error { + return MarkMessageAsDeletedWithRemoteID(ctx, op.tx, id) + }) +} + +func (op *EntOpsWrite) DeleteMessages(ctx context.Context, ids []imap.InternalMessageID) error { + return wrapEntErrFn(func() error { + return DeleteMessages(ctx, op.tx, ids...) + }) +} + +func (op *EntOpsWrite) UpdateRemoteMessageID(ctx context.Context, internalID imap.InternalMessageID, remoteID imap.MessageID) error { + return wrapEntErrFn(func() error { + return UpdateRemoteMessageID(ctx, op.tx, internalID, remoteID) + }) +} + +func (op *EntOpsWrite) AddFlagToMessages(ctx context.Context, ids []imap.InternalMessageID, flag string) error { + return wrapEntErrFn(func() error { + return AddMessageFlag(ctx, op.tx, ids, flag) + }) +} + +func (op *EntOpsWrite) RemoveFlagFromMessages(ctx context.Context, ids []imap.InternalMessageID, flag string) error { + return wrapEntErrFn(func() error { + return RemoveMessageFlag(ctx, op.tx, ids, flag) + }) +} + +func (op *EntOpsWrite) SetFlagsOnMessages(ctx context.Context, ids []imap.InternalMessageID, flags imap.FlagSet) error { + return wrapEntErrFn(func() error { + return SetMessageFlags(ctx, op.tx, ids, flags) + }) +} + +func (op *EntOpsWrite) AddDeletedSubscription(ctx context.Context, mboxName string, mboxID imap.MailboxID) error { + return wrapEntErrFn(func() error { + return AddDeletedSubscription(ctx, op.tx, mboxName, mboxID) + }) +} + +func (op *EntOpsWrite) RemoveDeletedSubscriptionWithName(ctx context.Context, mboxName string) (int, error) { + return wrapEntErrFnTyped(func() (int, error) { + return RemoveDeletedSubscriptionWithName(ctx, op.tx, mboxName) + }) +} diff --git a/internal/db/subscription.go b/internal/db_impl/ent_db/subscription.go similarity index 63% rename from internal/db/subscription.go rename to internal/db_impl/ent_db/subscription.go index 6fa702ca..79b46747 100644 --- a/internal/db/subscription.go +++ b/internal/db_impl/ent_db/subscription.go @@ -1,14 +1,14 @@ -package db +package ent_db import ( "context" "github.com/ProtonMail/gluon/imap" - "github.com/ProtonMail/gluon/internal/db/ent" - "github.com/ProtonMail/gluon/internal/db/ent/deletedsubscription" + "github.com/ProtonMail/gluon/internal/db_impl/ent_db/internal" + "github.com/ProtonMail/gluon/internal/db_impl/ent_db/internal/deletedsubscription" ) -func AddDeletedSubscription(ctx context.Context, tx *ent.Tx, mboxName string, mboxID imap.MailboxID) error { +func AddDeletedSubscription(ctx context.Context, tx *internal.Tx, mboxName string, mboxID imap.MailboxID) error { count, err := tx.DeletedSubscription.Update().Where(deletedsubscription.NameEqualFold(mboxName)).SetRemoteID(mboxID).Save(ctx) if err != nil { return err @@ -23,14 +23,14 @@ func AddDeletedSubscription(ctx context.Context, tx *ent.Tx, mboxName string, mb return nil } -func RemoveDeletedSubscriptionWithName(ctx context.Context, tx *ent.Tx, mboxName string) (int, error) { +func RemoveDeletedSubscriptionWithName(ctx context.Context, tx *internal.Tx, mboxName string) (int, error) { return tx.DeletedSubscription.Delete().Where(deletedsubscription.NameEqualFold(mboxName)).Exec(ctx) } -func GetDeletedSubscriptionSet(ctx context.Context, client *ent.Client) (map[imap.MailboxID]*ent.DeletedSubscription, error) { +func GetDeletedSubscriptionSet(ctx context.Context, client *internal.Client) (map[imap.MailboxID]*internal.DeletedSubscription, error) { const QueryLimit = 16000 - subscriptions := make(map[imap.MailboxID]*ent.DeletedSubscription) + subscriptions := make(map[imap.MailboxID]*internal.DeletedSubscription) for i := 0; ; i += QueryLimit { result, err := client.DeletedSubscription.Query(). diff --git a/internal/db_impl/ent_db/type_conversions.go b/internal/db_impl/ent_db/type_conversions.go new file mode 100644 index 00000000..a6595fa1 --- /dev/null +++ b/internal/db_impl/ent_db/type_conversions.go @@ -0,0 +1,103 @@ +package ent_db + +import ( + "github.com/ProtonMail/gluon/db" + "github.com/ProtonMail/gluon/internal/db_impl/ent_db/internal" + "github.com/bradenaw/juniper/xslices" +) + +func entMailboxFlagToDB(flag *internal.MailboxFlag) *db.MailboxFlag { + if flag == nil { + return nil + } + + return &db.MailboxFlag{ + ID: flag.ID, + Value: flag.Value, + } +} + +func entMailboxPermFlagsToDB(flag *internal.MailboxPermFlag) *db.MailboxFlag { + if flag == nil { + return nil + } + + return &db.MailboxFlag{ + ID: flag.ID, + Value: flag.Value, + } +} + +func entMailboxAttrToDB(flag *internal.MailboxAttr) *db.MailboxAttr { + if flag == nil { + return nil + } + + return &db.MailboxAttr{ + ID: flag.ID, + Value: flag.Value, + } +} + +func entMBoxToDB(mbox *internal.Mailbox) *db.Mailbox { + if mbox == nil { + return nil + } + + return &db.Mailbox{ + ID: mbox.ID, + RemoteID: mbox.RemoteID, + Name: mbox.Name, + UIDNext: mbox.UIDNext, + UIDValidity: mbox.UIDValidity, + Subscribed: mbox.Subscribed, + Flags: xslices.Map(mbox.Edges.Flags, entMailboxFlagToDB), + PermanentFlags: xslices.Map(mbox.Edges.PermanentFlags, entMailboxPermFlagsToDB), + Attributes: xslices.Map(mbox.Edges.Attributes, entMailboxAttrToDB), + } +} + +func entMessageFlagsToDB(flag *internal.MessageFlag) *db.MessageFlag { + return &db.MessageFlag{ + ID: flag.ID, + Value: flag.Value, + } +} + +func entMessageUIDToDB(uid *internal.UID) *db.UID { + return &db.UID{ + UID: uid.UID, + Deleted: uid.Deleted, + Recent: uid.Recent, + } +} + +func entMessageToDB(msg *internal.Message) *db.Message { + if msg == nil { + return nil + } + + return &db.Message{ + ID: msg.ID, + RemoteID: msg.RemoteID, + Date: msg.Date, + Size: msg.Size, + Body: msg.Body, + BodyStructure: msg.BodyStructure, + Envelope: msg.Envelope, + Deleted: msg.Deleted, + Flags: xslices.Map(msg.Edges.Flags, entMessageFlagsToDB), + UIDs: xslices.Map(msg.Edges.UIDs, entMessageUIDToDB), + } +} + +func entSubscriptionToDB(s *internal.DeletedSubscription) *db.DeletedSubscription { + if s == nil { + return nil + } + + return &db.DeletedSubscription{ + Name: s.Name, + RemoteID: s.RemoteID, + } +} diff --git a/internal/ids/ids.go b/internal/ids/ids.go index cf75a299..234160a3 100644 --- a/internal/ids/ids.go +++ b/internal/ids/ids.go @@ -5,76 +5,8 @@ import ( "strings" "github.com/ProtonMail/gluon/imap" - "github.com/ProtonMail/gluon/internal/db/ent" ) -type MailboxIDPair struct { - InternalID imap.InternalMailboxID - RemoteID imap.MailboxID -} - -func (m *MailboxIDPair) String() string { - return fmt.Sprintf("%v::%v", m.InternalID, m.RemoteID) -} - -type MessageIDPair struct { - InternalID imap.InternalMessageID - RemoteID imap.MessageID -} - -func (m *MessageIDPair) String() string { - return fmt.Sprintf("%v::%v", m.InternalID, m.RemoteID) -} - -func NewMailboxIDPair(mbox *ent.Mailbox) MailboxIDPair { - return MailboxIDPair{ - InternalID: mbox.ID, - RemoteID: mbox.RemoteID, - } -} - -func NewMailboxIDPairWithoutRemote(internalID imap.InternalMailboxID) MailboxIDPair { - return MailboxIDPair{ - InternalID: internalID, - RemoteID: "", - } -} - -func NewMessageIDPair(msg *ent.Message) MessageIDPair { - return MessageIDPair{ - InternalID: msg.ID, - RemoteID: msg.RemoteID, - } -} - -func SplitMessageIDPairSlice(s []MessageIDPair) ([]imap.InternalMessageID, []imap.MessageID) { - l := len(s) - - internalMessageIDs := make([]imap.InternalMessageID, 0, l) - remoteMessageIDs := make([]imap.MessageID, 0, l) - - for _, v := range s { - internalMessageIDs = append(internalMessageIDs, v.InternalID) - remoteMessageIDs = append(remoteMessageIDs, v.RemoteID) - } - - return internalMessageIDs, remoteMessageIDs -} - -func SplitMailboxIDPairSlice(s []MailboxIDPair) ([]imap.InternalMailboxID, []imap.MailboxID) { - l := len(s) - - internalMailboxIDs := make([]imap.InternalMailboxID, 0, l) - mailboxIDs := make([]imap.MailboxID, 0, l) - - for _, v := range s { - internalMailboxIDs = append(internalMailboxIDs, v.InternalID) - mailboxIDs = append(mailboxIDs, v.RemoteID) - } - - return internalMailboxIDs, mailboxIDs -} - const GluonRecoveryMailboxName = "Recovered Messages" const GluonRecoveryMailboxNameLowerCase = "recovered messages" const GluonInternalRecoveryMailboxRemoteID = imap.MailboxID("GLUON-INTERNAL-RECOVERY-MBOX") diff --git a/internal/state/actions.go b/internal/state/actions.go index a9aad84d..540416d0 100644 --- a/internal/state/actions.go +++ b/internal/state/actions.go @@ -7,9 +7,8 @@ import ( "strings" "time" + "github.com/ProtonMail/gluon/db" "github.com/ProtonMail/gluon/imap" - "github.com/ProtonMail/gluon/internal/db" - "github.com/ProtonMail/gluon/internal/db/ent" "github.com/ProtonMail/gluon/internal/ids" "github.com/ProtonMail/gluon/reporter" "github.com/ProtonMail/gluon/rfc822" @@ -18,21 +17,20 @@ import ( "golang.org/x/exp/slices" ) -func (state *State) actionCreateAndGetMailbox(ctx context.Context, tx *ent.Tx, name string, uidValidity imap.UID) (*ent.Mailbox, error) { +func (state *State) actionCreateAndGetMailbox(ctx context.Context, tx db.Transaction, name string, uidValidity imap.UID) (*db.Mailbox, error) { res, err := state.user.GetRemote().CreateMailbox(ctx, strings.Split(name, state.delimiter)) if err != nil { return nil, err } - exists, err := db.MailboxExistsWithRemoteID(ctx, tx.Client(), res.ID) + exists, err := tx.MailboxExistsWithRemoteID(ctx, res.ID) if err != nil { return nil, err } if !exists { - mbox, err := db.CreateMailbox( + mbox, err := tx.CreateMailbox( ctx, - tx, res.ID, strings.Join(res.Name, state.user.GetDelimiter()), res.Flags, @@ -47,31 +45,31 @@ func (state *State) actionCreateAndGetMailbox(ctx context.Context, tx *ent.Tx, n return mbox, nil } - return db.GetMailboxByRemoteID(ctx, tx.Client(), res.ID) + return tx.GetMailboxByRemoteID(ctx, res.ID) } -func (state *State) actionCreateMailbox(ctx context.Context, tx *ent.Tx, name string, uidValidity imap.UID) error { +func (state *State) actionCreateMailbox(ctx context.Context, tx db.Transaction, name string, uidValidity imap.UID) error { res, err := state.user.GetRemote().CreateMailbox(ctx, strings.Split(name, state.delimiter)) if err != nil { return err } - return db.CreateMailboxIfNotExists(ctx, tx, res, state.delimiter, uidValidity) + return tx.CreateMailboxIfNotExists(ctx, res, state.delimiter, uidValidity) } -func (state *State) actionDeleteMailbox(ctx context.Context, tx *ent.Tx, mboxID ids.MailboxIDPair) ([]Update, error) { +func (state *State) actionDeleteMailbox(ctx context.Context, tx db.Transaction, mboxID db.MailboxIDPair) ([]Update, error) { if err := state.user.GetRemote().DeleteMailbox(ctx, mboxID.RemoteID); err != nil { return nil, err } - if err := db.DeleteMailboxWithRemoteID(ctx, tx, mboxID.RemoteID); err != nil { + if err := tx.DeleteMailboxWithRemoteID(ctx, mboxID.RemoteID); err != nil { return nil, err } return []Update{NewMailboxDeletedStateUpdate(mboxID.InternalID)}, nil } -func (state *State) actionUpdateMailbox(ctx context.Context, tx *ent.Tx, mboxID imap.MailboxID, newName string) error { +func (state *State) actionUpdateMailbox(ctx context.Context, tx db.Transaction, mboxID imap.MailboxID, newName string) error { if err := state.user.GetRemote().UpdateMailbox( ctx, mboxID, @@ -80,13 +78,13 @@ func (state *State) actionUpdateMailbox(ctx context.Context, tx *ent.Tx, mboxID return err } - return db.RenameMailboxWithRemoteID(ctx, tx, mboxID, newName) + return tx.RenameMailboxWithRemoteID(ctx, mboxID, newName) } func (state *State) actionCreateMessage( ctx context.Context, - tx *ent.Tx, - mboxID ids.MailboxIDPair, + tx db.Transaction, + mboxID db.MailboxIDPair, literal []byte, flags imap.FlagSet, date time.Time, @@ -100,14 +98,14 @@ func (state *State) actionCreateMessage( { // Handle the case where duplicate messages can return the same remote ID. - knownInternalID, knownErr := db.GetMessageIDFromRemoteID(ctx, tx.Client(), res.ID) - if knownErr != nil && !ent.IsNotFound(knownErr) { + knownInternalID, knownErr := tx.GetMessageIDFromRemoteID(ctx, res.ID) + if knownErr != nil && !db.IsErrNotFound(knownErr) { return nil, 0, knownErr } if knownErr == nil { // Try to collect the original message date. var existingMessageDate time.Time - if existingMessage, msgErr := db.GetMessage(ctx, tx.Client(), internalID); msgErr == nil { + if existingMessage, msgErr := tx.GetMessage(ctx, internalID); msgErr == nil { existingMessageDate = existingMessage.Date } @@ -128,7 +126,7 @@ func (state *State) actionCreateMessage( updates, result, err := state.actionAddMessagesToMailbox(ctx, tx, - []ids.MessageIDPair{{InternalID: knownInternalID, RemoteID: res.ID}}, + []db.MessageIDPair{{InternalID: knownInternalID, RemoteID: res.ID}}, mboxID, isSelectedMailbox, ) @@ -163,7 +161,7 @@ func (state *State) actionCreateMessage( InternalID: internalID, } - messageUID, flagSet, err := db.CreateAndAddMessageToMailbox(ctx, tx, mboxID.InternalID, &req) + messageUID, flagSet, err := tx.CreateMessageAndAddToMailbox(ctx, mboxID.InternalID, &req) if err != nil { return nil, 0, err } @@ -177,7 +175,7 @@ func (state *State) actionCreateMessage( updates := []Update{newExistsStateUpdateWithExists( mboxID.InternalID, - []*exists{newExists(ids.MessageIDPair{InternalID: internalID, RemoteID: res.ID}, messageUID, flagSet)}, + []*exists{newExists(db.MessageIDPair{InternalID: internalID, RemoteID: res.ID}, messageUID, flagSet)}, st, ), } @@ -187,7 +185,7 @@ func (state *State) actionCreateMessage( func (state *State) actionCreateRecoveredMessage( ctx context.Context, - tx *ent.Tx, + tx db.Transaction, literal []byte, flags imap.FlagSet, date time.Time, @@ -225,14 +223,14 @@ func (state *State) actionCreateRecoveredMessage( recoveryMBoxID := state.user.GetRecoveryMailboxID() - messageUID, flagSet, err := db.CreateAndAddMessageToMailbox(ctx, tx, recoveryMBoxID.InternalID, &req) + messageUID, flagSet, err := tx.CreateMessageAndAddToMailbox(ctx, recoveryMBoxID.InternalID, &req) if err != nil { return nil, false, err } var updates = []Update{newExistsStateUpdateWithExists( recoveryMBoxID.InternalID, - []*exists{newExists(ids.MessageIDPair{InternalID: internalID, RemoteID: remoteID}, messageUID, flagSet)}, + []*exists{newExists(db.MessageIDPair{InternalID: internalID, RemoteID: remoteID}, messageUID, flagSet)}, nil, ), } @@ -242,20 +240,20 @@ func (state *State) actionCreateRecoveredMessage( func (state *State) actionAddMessagesToMailbox( ctx context.Context, - tx *ent.Tx, - messageIDs []ids.MessageIDPair, - mboxID ids.MailboxIDPair, + tx db.Transaction, + messageIDs []db.MessageIDPair, + mboxID db.MailboxIDPair, isMailboxSelected bool, ) ([]Update, []db.UIDWithFlags, error) { var allUpdates []Update { - haveMessageIDs, err := db.FilterMailboxContains(ctx, tx.Client(), mboxID.InternalID, messageIDs) + haveMessageIDs, err := tx.MailboxFilterContains(ctx, mboxID.InternalID, messageIDs) if err != nil { return nil, nil, err } - if remMessageIDs := xslices.Filter(messageIDs, func(messageID ids.MessageIDPair) bool { + if remMessageIDs := xslices.Filter(messageIDs, func(messageID db.MessageIDPair) bool { return slices.Contains(haveMessageIDs, messageID.InternalID) }); len(remMessageIDs) > 0 { updates, err := state.actionRemoveMessagesFromMailboxUnchecked(ctx, tx, remMessageIDs, mboxID) @@ -267,7 +265,7 @@ func (state *State) actionAddMessagesToMailbox( } } - internalIDs, remoteIDs := ids.SplitMessageIDPairSlice(messageIDs) + internalIDs, remoteIDs := db.SplitMessageIDPairSlice(messageIDs) if err := state.user.GetRemote().AddMessagesToMailbox(ctx, remoteIDs, mboxID.RemoteID); err != nil { return nil, nil, err @@ -291,11 +289,11 @@ func (state *State) actionAddMessagesToMailbox( func (state *State) actionAddRecoveredMessagesToMailbox( ctx context.Context, - tx *ent.Tx, - messageIDs []ids.MessageIDPair, - mboxID ids.MailboxIDPair, + tx db.Transaction, + messageIDs []db.MessageIDPair, + mboxID db.MailboxIDPair, ) ([]db.UIDWithFlags, Update, error) { - internalIDs, remoteIDs := ids.SplitMessageIDPairSlice(messageIDs) + internalIDs, remoteIDs := db.SplitMessageIDPairSlice(messageIDs) if err := state.user.GetRemote().AddMessagesToMailbox(ctx, remoteIDs, mboxID.RemoteID); err != nil { return nil, nil, err @@ -306,39 +304,39 @@ func (state *State) actionAddRecoveredMessagesToMailbox( func (state *State) actionImportRecoveredMessage( ctx context.Context, - tx *ent.Tx, + tx db.Transaction, id imap.InternalMessageID, mboxID imap.MailboxID, -) (ids.MessageIDPair, bool, error) { - message, err := db.GetImportedMessageData(ctx, tx.Client(), id) +) (db.MessageIDPair, bool, error) { + message, err := tx.GetImportedMessageData(ctx, id) if err != nil { - return ids.MessageIDPair{}, false, err + return db.MessageIDPair{}, false, err } literal, err := state.user.GetStore().Get(id) if err != nil { - return ids.MessageIDPair{}, false, err + return db.MessageIDPair{}, false, err } messageFlags := imap.NewFlagSet() - for _, flag := range message.Edges.Flags { + for _, flag := range message.Flags { messageFlags.AddToSelf(flag.Value) } internalID, res, newLiteral, err := state.user.GetRemote().CreateMessage(ctx, mboxID, literal, messageFlags, message.Date) if err != nil { - return ids.MessageIDPair{}, false, err + return db.MessageIDPair{}, false, err } { // Handle the unlikely case where duplicate messages can return the same remote ID. - internalID, err := db.GetMessageIDFromRemoteID(ctx, tx.Client(), res.ID) - if err != nil && !ent.IsNotFound(err) { - return ids.MessageIDPair{}, false, err + internalID, err := tx.GetMessageIDFromRemoteID(ctx, res.ID) + if err != nil && !db.IsErrNotFound(err) { + return db.MessageIDPair{}, false, err } if err == nil { - return ids.MessageIDPair{ + return db.MessageIDPair{ InternalID: internalID, RemoteID: res.ID, }, true, nil @@ -347,16 +345,16 @@ func (state *State) actionImportRecoveredMessage( parsedMessage, err := imap.NewParsedMessage(newLiteral) if err != nil { - return ids.MessageIDPair{}, false, err + return db.MessageIDPair{}, false, err } literalReader, literalSize, err := rfc822.SetHeaderValueNoMemCopy(newLiteral, ids.InternalIDKey, internalID.String()) if err != nil { - return ids.MessageIDPair{}, false, fmt.Errorf("failed to set internal ID: %w", err) + return db.MessageIDPair{}, false, fmt.Errorf("failed to set internal ID: %w", err) } if err := state.user.GetStore().SetUnchecked(internalID, literalReader); err != nil { - return ids.MessageIDPair{}, false, fmt.Errorf("failed to store message literal: %w", err) + return db.MessageIDPair{}, false, fmt.Errorf("failed to store message literal: %w", err) } req := db.CreateMessageReq{ @@ -368,11 +366,11 @@ func (state *State) actionImportRecoveredMessage( InternalID: internalID, } - if _, err := db.CreateMessages(ctx, tx, &req); err != nil { - return ids.MessageIDPair{}, false, err + if _, err := tx.CreateMessages(ctx, &req); err != nil { + return db.MessageIDPair{}, false, err } - return ids.MessageIDPair{ + return db.MessageIDPair{ InternalID: internalID, RemoteID: res.ID, }, false, nil @@ -380,11 +378,11 @@ func (state *State) actionImportRecoveredMessage( func (state *State) actionCopyMessagesOutOfRecoveryMailbox( ctx context.Context, - tx *ent.Tx, - messageIDs []ids.MessageIDPair, - mboxID ids.MailboxIDPair, + tx db.Transaction, + messageIDs []db.MessageIDPair, + mboxID db.MailboxIDPair, ) ([]Update, []db.UIDWithFlags, error) { - ids := make([]ids.MessageIDPair, 0, len(messageIDs)) + ids := make([]db.MessageIDPair, 0, len(messageIDs)) // Import messages to remote. for _, id := range messageIDs { @@ -407,11 +405,11 @@ func (state *State) actionCopyMessagesOutOfRecoveryMailbox( func (state *State) actionMoveMessagesOutOfRecoveryMailbox( ctx context.Context, - tx *ent.Tx, - messageIDs []ids.MessageIDPair, - mboxID ids.MailboxIDPair, + tx db.Transaction, + messageIDs []db.MessageIDPair, + mboxID db.MailboxIDPair, ) ([]Update, []db.UIDWithFlags, error) { - ids := make([]ids.MessageIDPair, 0, len(messageIDs)) + ids := make([]db.MessageIDPair, 0, len(messageIDs)) oldInternalIDs := make([]imap.InternalMessageID, 0, len(messageIDs)) // Import messages to remote. @@ -422,7 +420,7 @@ func (state *State) actionMoveMessagesOutOfRecoveryMailbox( } if !deduped { - if err := db.MarkMessageAsDeleted(ctx, tx, id.InternalID); err != nil { + if err := tx.MarkMessageAsDeleted(ctx, id.InternalID); err != nil { return nil, nil, err } } @@ -460,11 +458,11 @@ func (state *State) actionMoveMessagesOutOfRecoveryMailbox( // have already validated the input beforehand (e.g.: actionAddMessagesToMailbox and actionRemoveMessagesFromMailbox). func (state *State) actionRemoveMessagesFromMailboxUnchecked( ctx context.Context, - tx *ent.Tx, - messageIDs []ids.MessageIDPair, - mboxID ids.MailboxIDPair, + tx db.Transaction, + messageIDs []db.MessageIDPair, + mboxID db.MailboxIDPair, ) ([]Update, error) { - internalIDs, remoteIDs := ids.SplitMessageIDPairSlice(messageIDs) + internalIDs, remoteIDs := db.SplitMessageIDPairSlice(messageIDs) if mboxID.InternalID != state.user.GetRecoveryMailboxID().InternalID { if err := state.user.GetRemote().RemoveMessagesFromMailbox(ctx, remoteIDs, mboxID.RemoteID); err != nil { @@ -479,16 +477,16 @@ func (state *State) actionRemoveMessagesFromMailboxUnchecked( func (state *State) actionRemoveMessagesFromMailbox( ctx context.Context, - tx *ent.Tx, - messageIDs []ids.MessageIDPair, - mboxID ids.MailboxIDPair, + tx db.Transaction, + messageIDs []db.MessageIDPair, + mboxID db.MailboxIDPair, ) ([]Update, error) { - haveMessageIDs, err := db.FilterMailboxContains(ctx, tx.Client(), mboxID.InternalID, messageIDs) + haveMessageIDs, err := tx.MailboxFilterContains(ctx, mboxID.InternalID, messageIDs) if err != nil { return nil, err } - messageIDs = xslices.Filter(messageIDs, func(messageID ids.MessageIDPair) bool { + messageIDs = xslices.Filter(messageIDs, func(messageID db.MessageIDPair) bool { return slices.Contains(haveMessageIDs, messageID.InternalID) }) @@ -501,16 +499,16 @@ func (state *State) actionRemoveMessagesFromMailbox( func (state *State) actionMoveMessages( ctx context.Context, - tx *ent.Tx, - messageIDs []ids.MessageIDPair, - mboxFromID, mboxToID ids.MailboxIDPair, + tx db.Transaction, + messageIDs []db.MessageIDPair, + mboxFromID, mboxToID db.MailboxIDPair, ) ([]Update, []db.UIDWithFlags, error) { var allUpdates []Update if mboxFromID.InternalID == mboxToID.InternalID { - internalIDs, _ := ids.SplitMessageIDPairSlice(messageIDs) + internalIDs, _ := db.SplitMessageIDPairSlice(messageIDs) - uid, err := db.BumpMailboxUIDsForMessage(ctx, tx, internalIDs, mboxToID.InternalID) + uid, err := tx.BumpMailboxUIDsForMessage(ctx, mboxToID.InternalID, internalIDs) if err != nil { return nil, nil, err } @@ -519,12 +517,12 @@ func (state *State) actionMoveMessages( } { - messageIDsToAdd, err := db.FilterMailboxContains(ctx, tx.Client(), mboxToID.InternalID, messageIDs) + messageIDsToAdd, err := tx.MailboxFilterContains(ctx, mboxToID.InternalID, messageIDs) if err != nil { return nil, nil, err } - if remMessageIDs := xslices.Filter(messageIDs, func(messageID ids.MessageIDPair) bool { + if remMessageIDs := xslices.Filter(messageIDs, func(messageID db.MessageIDPair) bool { return slices.Contains(messageIDsToAdd, messageID.InternalID) }); len(remMessageIDs) > 0 { updates, err := state.actionRemoveMessagesFromMailboxUnchecked(ctx, tx, remMessageIDs, mboxToID) @@ -536,16 +534,16 @@ func (state *State) actionMoveMessages( } } - messageInFromMBox, err := db.FilterMailboxContains(ctx, tx.Client(), mboxFromID.InternalID, messageIDs) + messageInFromMBox, err := tx.MailboxFilterContains(ctx, mboxFromID.InternalID, messageIDs) if err != nil { return nil, nil, err } - messagesIDsToMove := xslices.Filter(messageIDs, func(messageID ids.MessageIDPair) bool { + messagesIDsToMove := xslices.Filter(messageIDs, func(messageID db.MessageIDPair) bool { return slices.Contains(messageInFromMBox, messageID.InternalID) }) - internalIDs, remoteIDs := ids.SplitMessageIDPairSlice(messagesIDsToMove) + internalIDs, remoteIDs := db.SplitMessageIDPairSlice(messagesIDsToMove) shouldRemoveOldMessages, err := state.user.GetRemote().MoveMessagesFromMailbox(ctx, remoteIDs, mboxFromID.RemoteID, mboxToID.RemoteID) if err != nil { @@ -564,7 +562,7 @@ func (state *State) actionMoveMessages( func (state *State) actionAddMessageFlags( ctx context.Context, - tx *ent.Tx, + tx db.Transaction, messages []snapMsgWithSeq, addFlags imap.FlagSet, ) ([]Update, error) { @@ -577,7 +575,7 @@ func (state *State) actionAddMessageFlags( func (state *State) actionRemoveMessageFlags( ctx context.Context, - tx *ent.Tx, + tx db.Transaction, messages []snapMsgWithSeq, remFlags imap.FlagSet, ) ([]Update, error) { @@ -589,7 +587,7 @@ func (state *State) actionRemoveMessageFlags( } func (state *State) actionSetMessageFlags(ctx context.Context, - tx *ent.Tx, + tx db.Transaction, messages []snapMsgWithSeq, setFlags imap.FlagSet) ([]Update, error) { if setFlags.ContainsUnchecked(imap.FlagRecentLowerCase) { diff --git a/internal/state/mailbox.go b/internal/state/mailbox.go index f3ce0cc8..3746584e 100644 --- a/internal/state/mailbox.go +++ b/internal/state/mailbox.go @@ -8,10 +8,9 @@ import ( "time" "github.com/ProtonMail/gluon/connector" + "github.com/ProtonMail/gluon/db" "github.com/ProtonMail/gluon/imap" "github.com/ProtonMail/gluon/imap/command" - "github.com/ProtonMail/gluon/internal/db" - "github.com/ProtonMail/gluon/internal/db/ent" "github.com/ProtonMail/gluon/internal/ids" "github.com/ProtonMail/gluon/internal/response" "github.com/ProtonMail/gluon/reporter" @@ -21,7 +20,7 @@ import ( ) type Mailbox struct { - id ids.MailboxIDPair + id db.MailboxIDPair name string subscribed bool uidValidity imap.UID @@ -40,9 +39,9 @@ type AppendOnlyMailbox interface { UIDValidity() imap.UID } -func newMailbox(mbox *ent.Mailbox, state *State, snap *snapshot) *Mailbox { +func newMailbox(mbox *db.Mailbox, state *State, snap *snapshot) *Mailbox { return &Mailbox{ - id: ids.NewMailboxIDPair(mbox), + id: db.NewMailboxIDPair(mbox), name: mbox.Name, uidValidity: mbox.UIDValidity, uidNext: mbox.UIDNext, @@ -85,20 +84,20 @@ func (m *Mailbox) Count() int { } func (m *Mailbox) Flags(ctx context.Context) (imap.FlagSet, error) { - return stateDBReadResult(ctx, m.state, func(ctx context.Context, client *ent.Client) (imap.FlagSet, error) { - return db.GetMailboxFlags(ctx, client, m.id.InternalID) + return stateDBReadResult(ctx, m.state, func(ctx context.Context, client db.ReadOnly) (imap.FlagSet, error) { + return client.GetMailboxFlags(ctx, m.id.InternalID) }) } func (m *Mailbox) PermanentFlags(ctx context.Context) (imap.FlagSet, error) { - return stateDBReadResult(ctx, m.state, func(ctx context.Context, client *ent.Client) (imap.FlagSet, error) { - return db.GetMailboxPermanentFlags(ctx, client, m.id.InternalID) + return stateDBReadResult(ctx, m.state, func(ctx context.Context, client db.ReadOnly) (imap.FlagSet, error) { + return client.GetMailboxPermanentFlags(ctx, m.id.InternalID) }) } func (m *Mailbox) Attributes(ctx context.Context) (imap.FlagSet, error) { - return stateDBReadResult(ctx, m.state, func(ctx context.Context, client *ent.Client) (imap.FlagSet, error) { - return db.GetMailboxAttributes(ctx, client, m.id.InternalID) + return stateDBReadResult(ctx, m.state, func(ctx context.Context, client db.ReadOnly) (imap.FlagSet, error) { + return client.GetMailboxAttributes(ctx, m.id.InternalID) }) } @@ -147,8 +146,8 @@ func (m *Mailbox) GetMessagesWithoutFlagCount(flag string) int { } func (m *Mailbox) AppendRegular(ctx context.Context, literal []byte, flags imap.FlagSet, date time.Time) (imap.UID, error) { - if err := stateDBRead(ctx, m.state, func(ctx context.Context, client *ent.Client) error { - if messageCount, uid, err := db.GetMailboxMessageCountAndUID(ctx, client, m.snap.mboxID.InternalID); err != nil { + if err := stateDBRead(ctx, m.state, func(ctx context.Context, client db.ReadOnly) error { + if messageCount, uid, err := client.GetMailboxMessageCountAndUID(ctx, m.snap.mboxID.InternalID); err != nil { return err } else { if err := m.state.imapLimits.CheckMailBoxMessageCount(messageCount, 1); err != nil { @@ -185,30 +184,21 @@ func (m *Mailbox) AppendRegular(ctx context.Context, literal []byte, flags imap. return 0, err } - if message, err := stateDBReadResult(ctx, m.state, func(ctx context.Context, client *ent.Client) (*ent.Message, error) { - message, err := db.GetMessageWithIDWithDeletedFlag(ctx, client, msgID) - if err != nil { - if ent.IsNotFound(err) { - return nil, nil - } - - return nil, err - } - - return message, nil - }); err != nil || message == nil { + if messageDeleted, err := stateDBReadResult(ctx, m.state, func(ctx context.Context, client db.ReadOnly) (bool, error) { + return client.GetMessageDeletedFlag(ctx, msgID) + }); err != nil { logrus.WithError(err).Warn("The message has an unknown internal ID") - } else if !message.Deleted { + } else if !messageDeleted { logrus.Debugf("Appending duplicate message with Internal ID:%v", msgID.ShortID()) // Only shuffle around messages that haven't been marked for deletion. - if res, err := stateDBWriteResult(ctx, m.state, func(ctx context.Context, tx *ent.Tx) ([]Update, []db.UIDWithFlags, error) { - remoteID, err := db.GetMessageRemoteIDFromID(ctx, tx.Client(), msgID) + if res, err := stateDBWriteResult(ctx, m.state, func(ctx context.Context, tx db.Transaction) ([]Update, []db.UIDWithFlags, error) { + remoteID, err := tx.GetMessageRemoteID(ctx, msgID) if err != nil { return nil, nil, err } return m.state.actionAddMessagesToMailbox(ctx, tx, - []ids.MessageIDPair{{InternalID: msgID, RemoteID: remoteID}}, + []db.MessageIDPair{{InternalID: msgID, RemoteID: remoteID}}, m.id, m.snap == m.state.snap, ) @@ -229,7 +219,7 @@ func (m *Mailbox) AppendRegular(ctx context.Context, literal []byte, flags imap. } } - return stateDBWriteResult(ctx, m.state, func(ctx context.Context, tx *ent.Tx) ([]Update, imap.UID, error) { + return stateDBWriteResult(ctx, m.state, func(ctx context.Context, tx db.Transaction) ([]Update, imap.UID, error) { return m.state.actionCreateMessage(ctx, tx, m.snap.mboxID, literal, flags, date, m.snap == m.state.snap, appendIntoDrafts) }) } @@ -245,7 +235,7 @@ func (m *Mailbox) Append(ctx context.Context, literal []byte, flags imap.FlagSet } // Failed to append to mailbox attempt to insert into recovery mailbox. - knownMessage, recoverErr := stateDBWriteResult(ctx, m.state, func(ctx context.Context, tx *ent.Tx) ([]Update, bool, error) { + knownMessage, recoverErr := stateDBWriteResult(ctx, m.state, func(ctx context.Context, tx db.Transaction) ([]Update, bool, error) { return m.state.actionCreateRecoveredMessage(ctx, tx, literal, flags, date) }) if recoverErr != nil && !knownMessage { @@ -269,8 +259,8 @@ func (m *Mailbox) Copy(ctx context.Context, seq []command.SeqRange, name string) return nil, ErrOperationNotAllowed } - mbox, err := stateDBReadResult(ctx, m.state, func(ctx context.Context, client *ent.Client) (*ent.Mailbox, error) { - return db.GetMailboxByName(ctx, client, name) + mbox, err := stateDBReadResult(ctx, m.state, func(ctx context.Context, client db.ReadOnly) (*db.Mailbox, error) { + return client.GetMailboxByName(ctx, name) }) if err != nil { return nil, ErrNoSuchMailbox @@ -281,7 +271,7 @@ func (m *Mailbox) Copy(ctx context.Context, seq []command.SeqRange, name string) return nil, err } - msgIDs := make([]ids.MessageIDPair, len(messages)) + msgIDs := make([]db.MessageIDPair, len(messages)) msgUIDs := make([]imap.UID, len(messages)) for i := 0; i < len(messages); i++ { @@ -290,11 +280,11 @@ func (m *Mailbox) Copy(ctx context.Context, seq []command.SeqRange, name string) msgIDs[i] = snapMsg.ID } - destUIDs, err := stateDBWriteResult(ctx, m.state, func(ctx context.Context, tx *ent.Tx) ([]Update, []db.UIDWithFlags, error) { + destUIDs, err := stateDBWriteResult(ctx, m.state, func(ctx context.Context, tx db.Transaction) ([]Update, []db.UIDWithFlags, error) { if m.state.user.GetRecoveryMailboxID().InternalID == m.snap.mboxID.InternalID { - return m.state.actionCopyMessagesOutOfRecoveryMailbox(ctx, tx, msgIDs, ids.NewMailboxIDPair(mbox)) + return m.state.actionCopyMessagesOutOfRecoveryMailbox(ctx, tx, msgIDs, db.NewMailboxIDPair(mbox)) } else { - return m.state.actionAddMessagesToMailbox(ctx, tx, msgIDs, ids.NewMailboxIDPair(mbox), m.snap == m.state.snap) + return m.state.actionAddMessagesToMailbox(ctx, tx, msgIDs, db.NewMailboxIDPair(mbox), m.snap == m.state.snap) } }) if err != nil { @@ -320,8 +310,8 @@ func (m *Mailbox) Move(ctx context.Context, seq []command.SeqRange, name string) return nil, ErrOperationNotAllowed } - mbox, err := stateDBReadResult(ctx, m.state, func(ctx context.Context, client *ent.Client) (*ent.Mailbox, error) { - return db.GetMailboxByName(ctx, client, name) + mbox, err := stateDBReadResult(ctx, m.state, func(ctx context.Context, client db.ReadOnly) (*db.Mailbox, error) { + return client.GetMailboxByName(ctx, name) }) if err != nil { return nil, ErrNoSuchMailbox @@ -332,7 +322,7 @@ func (m *Mailbox) Move(ctx context.Context, seq []command.SeqRange, name string) return nil, err } - msgIDs := make([]ids.MessageIDPair, len(messages)) + msgIDs := make([]db.MessageIDPair, len(messages)) msgUIDs := make([]imap.UID, len(messages)) for i := 0; i < len(messages); i++ { @@ -341,11 +331,11 @@ func (m *Mailbox) Move(ctx context.Context, seq []command.SeqRange, name string) msgIDs[i] = snapMsg.ID } - destUIDs, err := stateDBWriteResult(ctx, m.state, func(ctx context.Context, tx *ent.Tx) ([]Update, []db.UIDWithFlags, error) { + destUIDs, err := stateDBWriteResult(ctx, m.state, func(ctx context.Context, tx db.Transaction) ([]Update, []db.UIDWithFlags, error) { if m.state.user.GetRecoveryMailboxID().InternalID == m.snap.mboxID.InternalID { - return m.state.actionMoveMessagesOutOfRecoveryMailbox(ctx, tx, msgIDs, ids.NewMailboxIDPair(mbox)) + return m.state.actionMoveMessagesOutOfRecoveryMailbox(ctx, tx, msgIDs, db.NewMailboxIDPair(mbox)) } else { - return m.state.actionMoveMessages(ctx, tx, msgIDs, m.snap.mboxID, ids.NewMailboxIDPair(mbox)) + return m.state.actionMoveMessages(ctx, tx, msgIDs, m.snap.mboxID, db.NewMailboxIDPair(mbox)) } }) if err != nil { @@ -369,7 +359,7 @@ func (m *Mailbox) Store(ctx context.Context, seqSet []command.SeqRange, action c return err } - return stateDBWrite(ctx, m.state, func(ctx context.Context, tx *ent.Tx) ([]Update, error) { + return stateDBWrite(ctx, m.state, func(ctx context.Context, tx db.Transaction) ([]Update, error) { switch action { case command.StoreActionAddFlags: return m.state.actionAddMessageFlags(ctx, tx, messages, flags) @@ -386,7 +376,7 @@ func (m *Mailbox) Store(ctx context.Context, seqSet []command.SeqRange, action c } func (m *Mailbox) Expunge(ctx context.Context, seq []command.SeqRange) error { - var msgIDs []ids.MessageIDPair + var msgIDs []db.MessageIDPair if seq != nil { snapMsgs, err := m.snap.getMessagesInRange(ctx, seq) @@ -394,7 +384,7 @@ func (m *Mailbox) Expunge(ctx context.Context, seq []command.SeqRange) error { return err } - msgIDs = make([]ids.MessageIDPair, 0, len(snapMsgs)) + msgIDs = make([]db.MessageIDPair, 0, len(snapMsgs)) for _, v := range snapMsgs { if v.toExpunge { @@ -405,7 +395,7 @@ func (m *Mailbox) Expunge(ctx context.Context, seq []command.SeqRange) error { msgIDs = m.snap.getAllMessagesIDsMarkedDelete() } - return stateDBWrite(ctx, m.state, func(ctx context.Context, tx *ent.Tx) ([]Update, error) { + return stateDBWrite(ctx, m.state, func(ctx context.Context, tx db.Transaction) ([]Update, error) { return m.state.actionRemoveMessagesFromMailbox(ctx, tx, msgIDs, m.snap.mboxID) }) } diff --git a/internal/state/mailbox_fetch.go b/internal/state/mailbox_fetch.go index fc2e8134..87a6e2bc 100644 --- a/internal/state/mailbox_fetch.go +++ b/internal/state/mailbox_fetch.go @@ -3,6 +3,7 @@ package state import ( "context" "fmt" + "github.com/ProtonMail/gluon/db" "runtime" "strconv" "strings" @@ -12,8 +13,6 @@ import ( "github.com/ProtonMail/gluon/imap" "github.com/ProtonMail/gluon/imap/command" "github.com/ProtonMail/gluon/internal/contexts" - "github.com/ProtonMail/gluon/internal/db" - "github.com/ProtonMail/gluon/internal/db/ent" "github.com/ProtonMail/gluon/internal/response" "github.com/ProtonMail/gluon/rfc822" "github.com/bradenaw/juniper/parallel" @@ -28,7 +27,7 @@ func (m *Mailbox) Fetch(ctx context.Context, cmd *command.Fetch, ch chan respons return err } - operations := make([]func(snapMsgWithSeq, *ent.Message, []byte) (response.Item, error), 0, len(cmd.Attributes)) + operations := make([]func(snapMsgWithSeq, *db.Message, []byte) (response.Item, error), 0, len(cmd.Attributes)) var ( needsLiteral bool @@ -84,7 +83,7 @@ func (m *Mailbox) Fetch(ctx context.Context, cmd *command.Fetch, ch chan respons setSeen = true } - op := func(_ snapMsgWithSeq, _ *ent.Message, literal []byte) (response.Item, error) { + op := func(_ snapMsgWithSeq, _ *db.Message, literal []byte) (response.Item, error) { return fetchAttributeBodySection(attribute, literal) } @@ -118,8 +117,8 @@ func (m *Mailbox) Fetch(ctx context.Context, cmd *command.Fetch, ch chan respons defer async.HandlePanic(m.state.panicHandler) msg := snapMessages[i] - message, err := stateDBReadResult(ctx, m.state, func(ctx context.Context, client *ent.Client) (*ent.Message, error) { - return db.GetMessage(ctx, client, msg.ID.InternalID) + message, err := stateDBReadResult(ctx, m.state, func(ctx context.Context, client db.ReadOnly) (*db.Message, error) { + return client.GetMessage(ctx, msg.ID.InternalID) }) if err != nil { return err @@ -175,7 +174,7 @@ func (m *Mailbox) Fetch(ctx context.Context, cmd *command.Fetch, ch chan respons }) if len(msgsToBeMarkedSeen) != 0 { - if err := stateDBWrite(ctx, m.state, func(ctx context.Context, tx *ent.Tx) ([]Update, error) { + if err := stateDBWrite(ctx, m.state, func(ctx context.Context, tx db.Transaction) ([]Update, error) { return m.state.actionAddMessageFlags(ctx, tx, msgsToBeMarkedSeen, imap.NewFlagSet(imap.FlagSeen)) }); err != nil { return err @@ -185,47 +184,47 @@ func (m *Mailbox) Fetch(ctx context.Context, cmd *command.Fetch, ch chan respons return nil } -func fetchEnvelope(_ snapMsgWithSeq, message *ent.Message, _ []byte) (response.Item, error) { +func fetchEnvelope(_ snapMsgWithSeq, message *db.Message, _ []byte) (response.Item, error) { return response.ItemEnvelope(message.Envelope), nil } -func fetchFlags(msg snapMsgWithSeq, message *ent.Message, _ []byte) (response.Item, error) { +func fetchFlags(msg snapMsgWithSeq, message *db.Message, _ []byte) (response.Item, error) { return response.ItemFlags(msg.flags), nil } -func fetchInternalDate(_ snapMsgWithSeq, message *ent.Message, _ []byte) (response.Item, error) { +func fetchInternalDate(_ snapMsgWithSeq, message *db.Message, _ []byte) (response.Item, error) { return response.ItemInternalDate(message.Date), nil } -func fetchRFC822(_ snapMsgWithSeq, _ *ent.Message, literal []byte) (response.Item, error) { +func fetchRFC822(_ snapMsgWithSeq, _ *db.Message, literal []byte) (response.Item, error) { return response.ItemRFC822Literal(literal), nil } -func fetchRFC822Header(_ snapMsgWithSeq, _ *ent.Message, literal []byte) (response.Item, error) { +func fetchRFC822Header(_ snapMsgWithSeq, _ *db.Message, literal []byte) (response.Item, error) { section := rfc822.Parse(literal) return response.ItemRFC822Header(section.Header()), nil } -func fetchRFC822Size(_ snapMsgWithSeq, message *ent.Message, _ []byte) (response.Item, error) { +func fetchRFC822Size(_ snapMsgWithSeq, message *db.Message, _ []byte) (response.Item, error) { return response.ItemRFC822Size(message.Size), nil } -func fetchRFC822Text(_ snapMsgWithSeq, _ *ent.Message, literal []byte) (response.Item, error) { +func fetchRFC822Text(_ snapMsgWithSeq, _ *db.Message, literal []byte) (response.Item, error) { section := rfc822.Parse(literal) return response.ItemRFC822Text(section.Body()), nil } -func fetchBody(_ snapMsgWithSeq, message *ent.Message, _ []byte) (response.Item, error) { +func fetchBody(_ snapMsgWithSeq, message *db.Message, _ []byte) (response.Item, error) { return response.ItemBody(message.Body), nil } -func fetchBodyStructure(_ snapMsgWithSeq, message *ent.Message, _ []byte) (response.Item, error) { +func fetchBodyStructure(_ snapMsgWithSeq, message *db.Message, _ []byte) (response.Item, error) { return response.ItemBodyStructure(message.BodyStructure), nil } -func fetchUID(msg snapMsgWithSeq, _ *ent.Message, _ []byte) (response.Item, error) { +func fetchUID(msg snapMsgWithSeq, _ *db.Message, _ []byte) (response.Item, error) { return response.ItemUID(msg.UID), nil } diff --git a/internal/state/mailbox_search.go b/internal/state/mailbox_search.go index 038a18d6..5ff71126 100644 --- a/internal/state/mailbox_search.go +++ b/internal/state/mailbox_search.go @@ -4,6 +4,7 @@ import ( "bytes" "context" "fmt" + "github.com/ProtonMail/gluon/db" "runtime" "strings" "sync/atomic" @@ -13,8 +14,6 @@ import ( "github.com/ProtonMail/gluon/imap" "github.com/ProtonMail/gluon/imap/command" "github.com/ProtonMail/gluon/internal/contexts" - "github.com/ProtonMail/gluon/internal/db" - "github.com/ProtonMail/gluon/internal/db/ent" "github.com/ProtonMail/gluon/rfc5322" "github.com/ProtonMail/gluon/rfc822" "github.com/bradenaw/juniper/parallel" @@ -88,15 +87,16 @@ func buildSearchData(ctx context.Context, m *Mailbox, op *buildSearchOpResult, m data := searchData{message: message} if op.needsMessage { - dbm, err := stateDBReadResult(ctx, m.state, func(ctx context.Context, client *ent.Client) (*ent.Message, error) { - return db.GetMessageDateAndSize(ctx, client, message.ID.InternalID) - }) - if err != nil { + if err := stateDBRead(ctx, m.state, func(ctx context.Context, client db.ReadOnly) error { + date, size, err := client.GetMessageDateAndSize(ctx, message.ID.InternalID) + + data.dbMessage.size = size + data.dbMessage.date = date + + return err + }); err != nil { return searchData{}, err } - - data.dbMessage.size = dbm.Size - data.dbMessage.date = dbm.Date } if op.needsLiteral { diff --git a/internal/state/match.go b/internal/state/match.go index 81848efe..8131f581 100644 --- a/internal/state/match.go +++ b/internal/state/match.go @@ -3,12 +3,11 @@ package state import ( "context" "fmt" + "github.com/ProtonMail/gluon/db" "regexp" "strings" "github.com/ProtonMail/gluon/imap" - "github.com/ProtonMail/gluon/internal/db" - "github.com/ProtonMail/gluon/internal/db/ent" "github.com/bradenaw/juniper/xslices" ) @@ -22,12 +21,12 @@ type matchMailbox struct { Name string Subscribed bool // EntMBox should be set to nil if there is no such value. - EntMBox *ent.Mailbox + EntMBox *db.Mailbox } func getMatches( ctx context.Context, - client *ent.Client, + client db.ReadOnly, allMailboxes []matchMailbox, ref, pattern, delimiter string, subscribed bool, @@ -79,7 +78,7 @@ func getMatches( func prepareMatch( ctx context.Context, - client *ent.Client, + client db.ReadOnly, matchedName string, mbox *matchMailbox, pattern, delimiter string, @@ -108,13 +107,13 @@ func prepareMatch( if mbox.EntMBox != nil { atts = imap.NewFlagSetFromSlice(xslices.Map( - mbox.EntMBox.Edges.Attributes, - func(flag *ent.MailboxAttr) string { + mbox.EntMBox.Attributes, + func(flag *db.MailboxAttr) string { return flag.Value }, )) - recent, err := db.GetMailboxRecentCount(ctx, client, mbox.EntMBox) + recent, err := client.GetMailboxRecentCount(ctx, mbox.EntMBox.ID) if err != nil { return Match{}, false, err } diff --git a/internal/state/responders.go b/internal/state/responders.go index 4c091d2c..5ef7140c 100644 --- a/internal/state/responders.go +++ b/internal/state/responders.go @@ -3,13 +3,11 @@ package state import ( "context" "fmt" + "github.com/ProtonMail/gluon/db" "sync" "github.com/ProtonMail/gluon/imap" "github.com/ProtonMail/gluon/internal/contexts" - "github.com/ProtonMail/gluon/internal/db" - "github.com/ProtonMail/gluon/internal/db/ent" - "github.com/ProtonMail/gluon/internal/ids" "github.com/ProtonMail/gluon/internal/response" "github.com/ProtonMail/gluon/reporter" "github.com/bradenaw/juniper/xslices" @@ -20,7 +18,7 @@ type responderStateUpdate struct { responders []Responder } -func (r *responderStateUpdate) Apply(ctx context.Context, tx *ent.Tx, s *State) error { +func (r *responderStateUpdate) Apply(ctx context.Context, tx db.Transaction, s *State) error { return s.PushResponder(ctx, tx, r.responders...) } @@ -37,7 +35,7 @@ func (r *responderStateUpdate) String() string { // Seeing as this is only used right now to clear the recent flags, we can avoid a lot of necessary database locking // and transaction overhead. type responderDBUpdate interface { - apply(ctx context.Context, tx *ent.Tx) error + apply(ctx context.Context, tx db.Transaction) error } func NewMailboxIDResponderStateUpdate(id imap.InternalMailboxID, responders ...Responder) Update { @@ -63,12 +61,12 @@ type Responder interface { } type exists struct { - messageID ids.MessageIDPair + messageID db.MessageIDPair messageUID imap.UID flags imap.FlagSet } -func newExists(messageID ids.MessageIDPair, messageUID imap.UID, flags imap.FlagSet) *exists { +func newExists(messageID db.MessageIDPair, messageUID imap.UID, flags imap.FlagSet) *exists { return &exists{messageID: messageID, messageUID: messageUID, flags: flags} } @@ -81,8 +79,8 @@ type clearRecentFlagRespUpdate struct { mboxID imap.InternalMailboxID } -func (u *clearRecentFlagRespUpdate) apply(ctx context.Context, tx *ent.Tx) error { - return db.ClearRecentFlag(ctx, tx, u.mboxID, u.messageID) +func (u *clearRecentFlagRespUpdate) apply(ctx context.Context, tx db.Transaction) error { + return tx.ClearRecentFlagInMailboxOnMessage(ctx, u.mboxID, u.messageID) } // targetedExists needs to be separate so that we update the targetStateID safely when doing concurrent updates @@ -197,7 +195,7 @@ func newExistsStateUpdateWithExists(mailboxID imap.InternalMailboxID, responders } } -func (e *ExistsStateUpdate) Apply(ctx context.Context, tx *ent.Tx, s *State) error { +func (e *ExistsStateUpdate) Apply(ctx context.Context, tx db.Transaction, s *State) error { // This check needs to be thread safe since we don't know when a state update // will be executed. Before each of these updates run we check whether at state // target ID has been set and update for the first state that manages to run this code diff --git a/internal/state/snapshot.go b/internal/state/snapshot.go index c38e3679..7d82be2d 100644 --- a/internal/state/snapshot.go +++ b/internal/state/snapshot.go @@ -2,39 +2,37 @@ package state import ( "context" + "github.com/ProtonMail/gluon/db" "strings" "github.com/ProtonMail/gluon/imap" "github.com/ProtonMail/gluon/imap/command" "github.com/ProtonMail/gluon/internal/contexts" - "github.com/ProtonMail/gluon/internal/db" - "github.com/ProtonMail/gluon/internal/db/ent" - "github.com/ProtonMail/gluon/internal/ids" "github.com/bradenaw/juniper/xslices" ) type snapshot struct { - mboxID ids.MailboxIDPair + mboxID db.MailboxIDPair state *State messages *snapMsgList } -func newSnapshot(ctx context.Context, state *State, client *ent.Client, mbox *ent.Mailbox) (*snapshot, error) { - snapshotMessages, err := db.GetMailboxMessagesForNewSnapshot(ctx, client, mbox.ID) +func newSnapshot(ctx context.Context, state *State, client db.ReadOnly, mbox *db.Mailbox) (*snapshot, error) { + snapshotMessages, err := client.GetMailboxMessageForNewSnapshot(ctx, mbox.ID) if err != nil { return nil, err } snap := &snapshot{ - mboxID: ids.NewMailboxIDPair(mbox), + mboxID: db.NewMailboxIDPair(mbox), state: state, messages: newMsgList(len(snapshotMessages)), } for _, snapshotMessage := range snapshotMessages { if err := snap.messages.insert( - ids.MessageIDPair{InternalID: snapshotMessage.InternalID, RemoteID: snapshotMessage.RemoteID}, + db.MessageIDPair{InternalID: snapshotMessage.InternalID, RemoteID: snapshotMessage.RemoteID}, snapshotMessage.UID, snapshotMessage.GetFlagSet(), ); err != nil { @@ -45,9 +43,9 @@ func newSnapshot(ctx context.Context, state *State, client *ent.Client, mbox *en return snap, nil } -func newEmptySnapshot(state *State, mbox *ent.Mailbox) *snapshot { +func newEmptySnapshot(state *State, mbox *db.Mailbox) *snapshot { return &snapshot{ - mboxID: ids.NewMailboxIDPair(mbox), + mboxID: db.NewMailboxIDPair(mbox), state: state, messages: newMsgList(0), } @@ -119,14 +117,14 @@ func (snap *snapshot) getAllMessages() []snapMsgWithSeq { return result } -func (snap *snapshot) getAllMessageIDs() []ids.MessageIDPair { - return xslices.Map(snap.messages.all(), func(msg *snapMsg) ids.MessageIDPair { +func (snap *snapshot) getAllMessageIDs() []db.MessageIDPair { + return xslices.Map(snap.messages.all(), func(msg *snapMsg) db.MessageIDPair { return msg.ID }) } -func (snap *snapshot) getAllMessagesIDsMarkedDelete() []ids.MessageIDPair { - var msgs []ids.MessageIDPair +func (snap *snapshot) getAllMessagesIDsMarkedDelete() []db.MessageIDPair { + var msgs []db.MessageIDPair for _, v := range snap.messages.all() { if v.toExpunge { @@ -237,7 +235,7 @@ func (snap *snapshot) getMessagesWithoutFlagCount(flag string) int { }) } -func (snap *snapshot) appendMessage(messageID ids.MessageIDPair, uid imap.UID, flags imap.FlagSet) error { +func (snap *snapshot) appendMessage(messageID db.MessageIDPair, uid imap.UID, flags imap.FlagSet) error { return snap.messages.insert( messageID, uid, @@ -245,7 +243,7 @@ func (snap *snapshot) appendMessage(messageID ids.MessageIDPair, uid imap.UID, f ) } -func (snap *snapshot) appendMessageFromOtherState(messageID ids.MessageIDPair, uid imap.UID, flags imap.FlagSet) error { +func (snap *snapshot) appendMessageFromOtherState(messageID db.MessageIDPair, uid imap.UID, flags imap.FlagSet) error { snap.messages.insertOutOfOrder( messageID, uid, diff --git a/internal/state/snapshot_messages.go b/internal/state/snapshot_messages.go index dcfc1b4e..12daafe7 100644 --- a/internal/state/snapshot_messages.go +++ b/internal/state/snapshot_messages.go @@ -2,10 +2,10 @@ package state import ( "fmt" + "github.com/ProtonMail/gluon/db" "github.com/ProtonMail/gluon/imap" "github.com/ProtonMail/gluon/imap/command" - "github.com/ProtonMail/gluon/internal/ids" "github.com/bradenaw/juniper/xslices" "golang.org/x/exp/slices" ) @@ -14,7 +14,7 @@ var ErrOutOfOrderUIDInsertion = fmt.Errorf("UIDs must be strictly ascending") // snapMsg is a single message inside a snapshot. type snapMsg struct { - ID ids.MessageIDPair + ID db.MessageIDPair UID imap.UID flags imap.FlagSet toExpunge bool @@ -48,7 +48,7 @@ func (list *snapMsgList) binarySearchByUID(uid imap.UID) (int, bool) { return index, ok } -func (list *snapMsgList) insert(msgID ids.MessageIDPair, msgUID imap.UID, flags imap.FlagSet) error { +func (list *snapMsgList) insert(msgID db.MessageIDPair, msgUID imap.UID, flags imap.FlagSet) error { if len(list.msg) > 0 && list.msg[len(list.msg)-1].UID >= msgUID { return fmt.Errorf("UID-Last=%v UID-Msg=%v: %w", list.msg[len(list.msg)-1].UID, msgUID, ErrOutOfOrderUIDInsertion) } @@ -67,7 +67,7 @@ func (list *snapMsgList) insert(msgID ids.MessageIDPair, msgUID imap.UID, flags return nil } -func (list *snapMsgList) insertOutOfOrder(msgID ids.MessageIDPair, msgUID imap.UID, flags imap.FlagSet) { +func (list *snapMsgList) insertOutOfOrder(msgID db.MessageIDPair, msgUID imap.UID, flags imap.FlagSet) { index, ok := list.binarySearchByUID(msgUID) if ok { panic("Duplicate UID added") diff --git a/internal/state/snapshot_messages_test.go b/internal/state/snapshot_messages_test.go index 40072ae0..34a7e1d7 100644 --- a/internal/state/snapshot_messages_test.go +++ b/internal/state/snapshot_messages_test.go @@ -1,11 +1,11 @@ package state import ( + "github.com/ProtonMail/gluon/db" "testing" "github.com/ProtonMail/gluon/imap" "github.com/ProtonMail/gluon/imap/command" - "github.com/ProtonMail/gluon/internal/ids" "github.com/stretchr/testify/require" ) @@ -221,8 +221,8 @@ func TestSnapListGetMessages(t *testing.T) { } } -func messageIDPair(internalID imap.InternalMessageID, remoteID imap.MessageID) ids.MessageIDPair { - return ids.MessageIDPair{InternalID: internalID, RemoteID: remoteID} +func messageIDPair(internalID imap.InternalMessageID, remoteID imap.MessageID) db.MessageIDPair { + return db.MessageIDPair{InternalID: internalID, RemoteID: remoteID} } func must[T any](val T, ok bool) T { diff --git a/internal/state/state.go b/internal/state/state.go index 1ab6f6c6..0fc703d4 100644 --- a/internal/state/state.go +++ b/internal/state/state.go @@ -4,13 +4,12 @@ import ( "bytes" "context" "fmt" + "github.com/ProtonMail/gluon/db" "strings" "sync/atomic" "github.com/ProtonMail/gluon/async" "github.com/ProtonMail/gluon/imap" - "github.com/ProtonMail/gluon/internal/db" - "github.com/ProtonMail/gluon/internal/db/ent" "github.com/ProtonMail/gluon/internal/ids" "github.com/ProtonMail/gluon/internal/response" "github.com/ProtonMail/gluon/limits" @@ -77,20 +76,20 @@ func (state *State) UserID() string { } func (state *State) List(ctx context.Context, ref, pattern string, lsub bool, fn func(map[string]Match) error) error { - return stateDBRead(ctx, state, func(ctx context.Context, client *ent.Client) error { - mailboxes, err := db.GetAllMailboxes(ctx, client) + return stateDBRead(ctx, state, func(ctx context.Context, client db.ReadOnly) error { + mailboxes, err := client.GetAllMailboxes(ctx) if err != nil { return err } recoveryMailboxID := state.user.GetRecoveryMailboxID().InternalID - recoveryMBoxMessageCount, err := db.GetMailboxMessageCount(ctx, client, recoveryMailboxID) + recoveryMBoxMessageCount, err := client.GetMailboxMessageCount(ctx, recoveryMailboxID) if err != nil { logrus.WithError(err).Error("Failed to get recovery mailbox message count, assuming empty") recoveryMBoxMessageCount = 0 } - mailboxes = xslices.Filter(mailboxes, func(mailbox *ent.Mailbox) bool { + mailboxes = xslices.Filter(mailboxes, func(mailbox *db.Mailbox) bool { if mailbox.ID == recoveryMailboxID && recoveryMBoxMessageCount == 0 { return false } @@ -101,7 +100,7 @@ func (state *State) List(ctx context.Context, ref, pattern string, lsub bool, fn case imap.Visible: return true case imap.HiddenIfEmpty: - count, err := db.GetMailboxMessageCount(ctx, client, mailbox.ID) + count, err := client.GetMailboxMessageCount(ctx, mailbox.ID) if err != nil { logrus.WithError(err).Error("Failed to get recovery mailbox message count, assuming not empty") return true @@ -113,10 +112,10 @@ func (state *State) List(ctx context.Context, ref, pattern string, lsub bool, fn } }) - var deletedSubscriptions map[imap.MailboxID]*ent.DeletedSubscription + var deletedSubscriptions map[imap.MailboxID]*db.DeletedSubscription if lsub { - deletedSubscriptions, err = db.GetDeletedSubscriptionSet(ctx, client) + deletedSubscriptions, err = client.GetDeletedSubscriptionSet(ctx) if err != nil { return err } @@ -164,8 +163,8 @@ func (state *State) List(ctx context.Context, ref, pattern string, lsub bool, fn } func (state *State) Select(ctx context.Context, name string, fn func(*Mailbox) error) error { - mbox, err := stateDBReadResult(ctx, state, func(ctx context.Context, client *ent.Client) (*ent.Mailbox, error) { - return db.GetMailboxByName(ctx, client, name) + mbox, err := stateDBReadResult(ctx, state, func(ctx context.Context, client db.ReadOnly) (*db.Mailbox, error) { + return client.GetMailboxByName(ctx, name) }) if err != nil { return ErrNoSuchMailbox @@ -177,15 +176,15 @@ func (state *State) Select(ctx context.Context, name string, fn func(*Mailbox) e } } - snap, err := stateDBReadResult(ctx, state, func(ctx context.Context, client *ent.Client) (*snapshot, error) { + snap, err := stateDBReadResult(ctx, state, func(ctx context.Context, client db.ReadOnly) (*snapshot, error) { return newSnapshot(ctx, state, client, mbox) }) if err != nil { return err } - if err := stateDBWrite(ctx, state, func(ctx context.Context, tx *ent.Tx) ([]Update, error) { - return nil, db.ClearRecentFlags(ctx, tx, mbox.ID) + if err := stateDBWrite(ctx, state, func(ctx context.Context, tx db.Transaction) ([]Update, error) { + return nil, tx.ClearRecentFlagsInMailbox(ctx, mbox.ID) }); err != nil { return err } @@ -197,8 +196,8 @@ func (state *State) Select(ctx context.Context, name string, fn func(*Mailbox) e } func (state *State) Examine(ctx context.Context, name string, fn func(*Mailbox) error) error { - mbox, err := stateDBReadResult(ctx, state, func(ctx context.Context, client *ent.Client) (*ent.Mailbox, error) { - return db.GetMailboxByName(ctx, client, name) + mbox, err := stateDBReadResult(ctx, state, func(ctx context.Context, client db.ReadOnly) (*db.Mailbox, error) { + return client.GetMailboxByName(ctx, name) }) if err != nil { return ErrNoSuchMailbox @@ -210,7 +209,7 @@ func (state *State) Examine(ctx context.Context, name string, fn func(*Mailbox) } } - snap, err := stateDBReadResult(ctx, state, func(ctx context.Context, client *ent.Client) (*snapshot, error) { + snap, err := stateDBReadResult(ctx, state, func(ctx context.Context, client db.ReadOnly) (*snapshot, error) { return newSnapshot(ctx, state, client, mbox) }) if err != nil { @@ -247,10 +246,8 @@ func (state *State) Create(ctx context.Context, name string) error { } } - return stateDBWrite(ctx, state, func(ctx context.Context, tx *ent.Tx) ([]Update, error) { - client := tx.Client() - - if mailboxCount, err := db.GetMailboxCount(ctx, client); err != nil { + return stateDBWrite(ctx, state, func(ctx context.Context, tx db.Transaction) ([]Update, error) { + if mailboxCount, err := tx.GetMailboxCount(ctx); err != nil { return nil, err } else if err := state.imapLimits.CheckMailBoxCount(mailboxCount); err != nil { return nil, err @@ -263,14 +260,14 @@ func (state *State) Create(ctx context.Context, name string) error { name = strings.TrimRight(name, state.delimiter) } - if exists, err := db.MailboxExistsWithName(ctx, client, name); err != nil { + if exists, err := tx.MailboxExistsWithName(ctx, name); err != nil { return nil, err } else if exists { return nil, ErrExistingMailbox } for _, superior := range listSuperiors(name, state.delimiter) { - if exists, err := db.MailboxExistsWithName(ctx, client, superior); err != nil { + if exists, err := tx.MailboxExistsWithName(ctx, superior); err != nil { return nil, err } else if exists { continue @@ -297,13 +294,13 @@ func (state *State) Delete(ctx context.Context, name string) (bool, error) { return false, ErrOperationNotAllowed } - mboxID, err := stateDBWriteResult(ctx, state, func(ctx context.Context, tx *ent.Tx) ([]Update, imap.InternalMailboxID, error) { - mbox, err := db.GetMailboxByName(ctx, tx.Client(), name) + mboxID, err := stateDBWriteResult(ctx, state, func(ctx context.Context, tx db.Transaction) ([]Update, imap.InternalMailboxID, error) { + mbox, err := tx.GetMailboxByName(ctx, name) if err != nil { return nil, 0, ErrNoSuchMailbox } - update, err := state.actionDeleteMailbox(ctx, tx, ids.NewMailboxIDPair(mbox)) + update, err := state.actionDeleteMailbox(ctx, tx, db.NewMailboxIDPair(mbox)) if err != nil { return nil, 0, err } @@ -319,7 +316,7 @@ func (state *State) Delete(ctx context.Context, name string) (bool, error) { func (state *State) Rename(ctx context.Context, oldName, newName string) error { type Result struct { - MBox *ent.Mailbox + MBox *db.Mailbox MBoxesToCreate []string } @@ -327,15 +324,13 @@ func (state *State) Rename(ctx context.Context, oldName, newName string) error { return ErrOperationNotAllowed } - return stateDBWrite(ctx, state, func(ctx context.Context, tx *ent.Tx) ([]Update, error) { - client := tx.Client() - - mbox, err := db.GetMailboxByName(ctx, client, oldName) + return stateDBWrite(ctx, state, func(ctx context.Context, tx db.Transaction) ([]Update, error) { + mbox, err := tx.GetMailboxByName(ctx, oldName) if err != nil { return nil, ErrNoSuchMailbox } - if exists, err := db.MailboxExistsWithName(ctx, client, newName); err != nil { + if exists, err := tx.MailboxExistsWithName(ctx, newName); err != nil { return nil, err } else if exists { return nil, ErrExistingMailbox @@ -343,7 +338,7 @@ func (state *State) Rename(ctx context.Context, oldName, newName string) error { var mboxesToCreate []string for _, superior := range listSuperiors(newName, state.delimiter) { - if exists, err := db.MailboxExistsWithName(ctx, client, superior); err != nil { + if exists, err := tx.MailboxExistsWithName(ctx, superior); err != nil { return nil, err } else if exists { if superior == oldName { @@ -366,7 +361,7 @@ func (state *State) Rename(ctx context.Context, oldName, newName string) error { return nil, err } - if err := db.CreateMailboxIfNotExists(ctx, tx, res, state.delimiter, uidValidity); err != nil { + if err := tx.CreateMailboxIfNotExists(ctx, res, state.delimiter, uidValidity); err != nil { return nil, err } } @@ -380,24 +375,24 @@ func (state *State) Rename(ctx context.Context, oldName, newName string) error { } // Locally update all inferiors so we don't wait for update - mailboxes, err := db.GetAllMailboxes(ctx, tx.Client()) + mailboxes, err := tx.GetAllMailboxes(ctx) if err != nil { return nil, err } - inferiors := listInferiors(oldName, state.delimiter, xslices.Map(mailboxes, func(mailbox *ent.Mailbox) string { + inferiors := listInferiors(oldName, state.delimiter, xslices.Map(mailboxes, func(mailbox *db.Mailbox) string { return mailbox.Name })) for _, inferior := range inferiors { - mbox, err := db.GetMailboxByName(ctx, tx.Client(), inferior) + mbox, err := tx.GetMailboxByName(ctx, inferior) if err != nil { return nil, ErrNoSuchMailbox } newInferior := newName + strings.TrimPrefix(inferior, oldName) - if err := db.RenameMailboxWithRemoteID(ctx, tx, mbox.RemoteID, newInferior); err != nil { + if err := tx.RenameMailboxWithRemoteID(ctx, mbox.RemoteID, newInferior); err != nil { return nil, err } } @@ -407,8 +402,8 @@ func (state *State) Rename(ctx context.Context, oldName, newName string) error { } func (state *State) Subscribe(ctx context.Context, name string) error { - return stateDBWrite(ctx, state, func(ctx context.Context, tx *ent.Tx) ([]Update, error) { - mbox, err := db.GetMailboxByName(ctx, tx.Client(), name) + return stateDBWrite(ctx, state, func(ctx context.Context, tx db.Transaction) ([]Update, error) { + mbox, err := tx.GetMailboxByName(ctx, name) if err != nil { return nil, ErrNoSuchMailbox } @@ -417,16 +412,16 @@ func (state *State) Subscribe(ctx context.Context, name string) error { return nil, ErrAlreadySubscribed } - return nil, mbox.Update().SetSubscribed(true).Exec(ctx) + return nil, tx.SetMailboxSubscribed(ctx, mbox.ID, true) }) } func (state *State) Unsubscribe(ctx context.Context, name string) error { - return stateDBWrite(ctx, state, func(ctx context.Context, tx *ent.Tx) ([]Update, error) { - mbox, err := db.GetMailboxByName(ctx, tx.Client(), name) + return stateDBWrite(ctx, state, func(ctx context.Context, tx db.Transaction) ([]Update, error) { + mbox, err := tx.GetMailboxByName(ctx, name) if err != nil { // If mailbox does not exist, check that if it is present in the deleted subscription table - if count, err := db.RemoveDeletedSubscriptionWithName(ctx, tx, name); err != nil { + if count, err := tx.RemoveDeletedSubscriptionWithName(ctx, name); err != nil { return nil, err } else if count == 0 { return nil, ErrNoSuchMailbox @@ -439,7 +434,7 @@ func (state *State) Unsubscribe(ctx context.Context, name string) error { return nil, ErrAlreadyUnsubscribed } - return nil, mbox.Update().SetSubscribed(false).Exec(ctx) + return nil, tx.SetMailboxSubscribed(ctx, mbox.ID, false) }) } @@ -455,8 +450,8 @@ func (state *State) Idle(ctx context.Context, fn func([]response.Response, chan } func (state *State) Mailbox(ctx context.Context, name string, fn func(*Mailbox) error) error { - mbox, err := stateDBReadResult(ctx, state, func(ctx context.Context, client *ent.Client) (*ent.Mailbox, error) { - return db.GetMailboxByName(ctx, client, name) + mbox, err := stateDBReadResult(ctx, state, func(ctx context.Context, client db.ReadOnly) (*db.Mailbox, error) { + return client.GetMailboxByName(ctx, name) }) if err != nil { return ErrNoSuchMailbox @@ -466,7 +461,7 @@ func (state *State) Mailbox(ctx context.Context, name string, fn func(*Mailbox) return fn(newMailbox(mbox, state, state.snap)) } - snap, err := stateDBReadResult(ctx, state, func(ctx context.Context, client *ent.Client) (*snapshot, error) { + snap, err := stateDBReadResult(ctx, state, func(ctx context.Context, client db.ReadOnly) (*snapshot, error) { return newSnapshot(ctx, state, client, mbox) }) if err != nil { @@ -484,8 +479,8 @@ func (state *State) AppendOnlyMailbox(ctx context.Context, name string, fn func( return ErrOperationNotAllowed } - mbox, err := stateDBReadResult(ctx, state, func(ctx context.Context, client *ent.Client) (*ent.Mailbox, error) { - return db.GetMailboxByName(ctx, client, name) + mbox, err := stateDBReadResult(ctx, state, func(ctx context.Context, client db.ReadOnly) (*db.Mailbox, error) { + return client.GetMailboxByName(ctx, name) }) if err != nil { return ErrNoSuchMailbox @@ -505,8 +500,8 @@ func (state *State) Selected(ctx context.Context, fn func(*Mailbox) error) error return ErrSessionNotSelected } - mbox, err := stateDBReadResult(ctx, state, func(ctx context.Context, client *ent.Client) (*ent.Mailbox, error) { - return db.GetMailboxByID(ctx, client, state.snap.mboxID.InternalID) + mbox, err := stateDBReadResult(ctx, state, func(ctx context.Context, client db.ReadOnly) (*db.Mailbox, error) { + return client.GetMailboxByID(ctx, state.snap.mboxID.InternalID) }) if err != nil { return ErrNoSuchMailbox @@ -572,7 +567,7 @@ func (state *State) ApplyUpdate(ctx context.Context, update Update) error { return nil } - if err := state.user.GetDB().Write(ctx, func(ctx context.Context, tx *ent.Tx) error { + if err := state.user.GetDB().Write(ctx, func(ctx context.Context, tx db.Transaction) error { return update.Apply(ctx, tx, state) }); err != nil { reporter.MessageWithContext(ctx, @@ -599,7 +594,7 @@ func (state *State) markInvalid() { } // renameInbox creates a new mailbox and moves everything there. -func (state *State) renameInbox(ctx context.Context, tx *ent.Tx, inbox *ent.Mailbox, newName string) ([]Update, error) { +func (state *State) renameInbox(ctx context.Context, tx db.Transaction, inbox *db.Mailbox, newName string) ([]Update, error) { uidValidity, err := state.user.GenerateUIDValidity() if err != nil { return nil, err @@ -610,14 +605,14 @@ func (state *State) renameInbox(ctx context.Context, tx *ent.Tx, inbox *ent.Mail return nil, err } - messageIDs, err := db.GetMailboxMessageIDPairs(ctx, tx.Client(), inbox.ID) + messageIDs, err := tx.GetMailboxMessageIDPairs(ctx, inbox.ID) if err != nil { return nil, err } - mboxIDPair := ids.NewMailboxIDPair(mbox) + mboxIDPair := db.NewMailboxIDPair(mbox) - updates, _, err := state.actionMoveMessages(ctx, tx, messageIDs, ids.NewMailboxIDPair(inbox), mboxIDPair) + updates, _, err := state.actionMoveMessages(ctx, tx, messageIDs, db.NewMailboxIDPair(inbox), mboxIDPair) if err != nil { return nil, err } @@ -644,7 +639,7 @@ func (state *State) endIdle() { state.idleCh = nil } -func (state *State) getLiteral(ctx context.Context, messageID ids.MessageIDPair) ([]byte, error) { +func (state *State) getLiteral(ctx context.Context, messageID db.MessageIDPair) ([]byte, error) { var literal []byte storeLiteral, firstErr := state.user.GetStore().Get(messageID.InternalID) @@ -710,7 +705,7 @@ func (state *State) flushResponses(ctx context.Context, permitExpunge bool) ([]r } } - if err := state.user.GetDB().Write(ctx, func(ctx context.Context, tx *ent.Tx) error { + if err := state.user.GetDB().Write(ctx, func(ctx context.Context, tx db.Transaction) error { for _, update := range dbUpdates { if err := update.apply(ctx, tx); err != nil { return err @@ -725,7 +720,7 @@ func (state *State) flushResponses(ctx context.Context, permitExpunge bool) ([]r return response.Merge(responses), nil } -func (state *State) PushResponder(ctx context.Context, tx *ent.Tx, responder ...Responder) error { +func (state *State) PushResponder(ctx context.Context, tx db.Transaction, responder ...Responder) error { if state.idleCh == nil { return state.queueResponder(responder...) } @@ -833,18 +828,18 @@ func (state *State) close() error { return nil } -func stateDBRead(ctx context.Context, state *State, fn func(context.Context, *ent.Client) error) error { +func stateDBRead(ctx context.Context, state *State, fn func(context.Context, db.ReadOnly) error) error { return state.user.GetDB().Read(ctx, fn) } -func stateDBReadResult[T any](ctx context.Context, state *State, fn func(context.Context, *ent.Client) (T, error)) (T, error) { - return db.ReadResult(ctx, state.user.GetDB(), fn) +func stateDBReadResult[T any](ctx context.Context, state *State, fn func(context.Context, db.ReadOnly) (T, error)) (T, error) { + return db.ClientReadType(ctx, state.user.GetDB(), fn) } -func stateDBWrite(ctx context.Context, state *State, fn func(context.Context, *ent.Tx) ([]Update, error)) error { +func stateDBWrite(ctx context.Context, state *State, fn func(context.Context, db.Transaction) ([]Update, error)) error { var updates []Update - if err := state.user.GetDB().Write(ctx, func(ctx context.Context, tx *ent.Tx) error { + if err := state.user.GetDB().Write(ctx, func(ctx context.Context, tx db.Transaction) error { up, err := fn(ctx, tx) updates = up return err @@ -854,7 +849,7 @@ func stateDBWrite(ctx context.Context, state *State, fn func(context.Context, *e // need to create a separate transaction for the state updates so that import changes get written first. if len(updates) != 0 { - if err := state.user.GetDB().Write(ctx, func(ctx context.Context, tx *ent.Tx) error { + if err := state.user.GetDB().Write(ctx, func(ctx context.Context, tx db.Transaction) error { return state.user.QueueOrApplyStateUpdate(ctx, tx, updates...) }); err != nil { return err @@ -864,10 +859,10 @@ func stateDBWrite(ctx context.Context, state *State, fn func(context.Context, *e return nil } -func stateDBWriteResult[T any](ctx context.Context, state *State, fn func(context.Context, *ent.Tx) ([]Update, T, error)) (T, error) { +func stateDBWriteResult[T any](ctx context.Context, state *State, fn func(context.Context, db.Transaction) ([]Update, T, error)) (T, error) { var updates []Update - result, err := db.WriteResult(ctx, state.user.GetDB(), func(ctx context.Context, tx *ent.Tx) (T, error) { + result, err := db.ClientWriteType(ctx, state.user.GetDB(), func(ctx context.Context, tx db.Transaction) (T, error) { up, val, err := fn(ctx, tx) updates = up return val, err @@ -879,7 +874,7 @@ func stateDBWriteResult[T any](ctx context.Context, state *State, fn func(contex // need to create a separate transaction for the state updates so that import changes get written first. if len(updates) != 0 { - if err := state.user.GetDB().Write(ctx, func(ctx context.Context, tx *ent.Tx) error { + if err := state.user.GetDB().Write(ctx, func(ctx context.Context, tx db.Transaction) error { return state.user.QueueOrApplyStateUpdate(ctx, tx, updates...) }); err != nil { return result, err diff --git a/internal/state/updates.go b/internal/state/updates.go index d8e0b783..62704fd5 100644 --- a/internal/state/updates.go +++ b/internal/state/updates.go @@ -3,12 +3,11 @@ package state import ( "context" "fmt" + "github.com/ProtonMail/gluon/db" "strings" "github.com/ProtonMail/gluon/imap" "github.com/ProtonMail/gluon/internal/contexts" - "github.com/ProtonMail/gluon/internal/db" - "github.com/ProtonMail/gluon/internal/db/ent" "github.com/ProtonMail/gluon/internal/ids" "github.com/bradenaw/juniper/xslices" ) @@ -17,7 +16,7 @@ type Update interface { // Filter returns true when the state can be passed into A. Filter(s *State) bool // Apply the update to a given state. - Apply(cxt context.Context, tx *ent.Tx, s *State) error + Apply(cxt context.Context, tx db.Transaction, s *State) error String() string } @@ -31,7 +30,7 @@ func newMessageFlagsComboStateUpdate() *messageFlagsComboStateUpdate { return &messageFlagsComboStateUpdate{} } -func (u *messageFlagsComboStateUpdate) Apply(ctx context.Context, tx *ent.Tx, s *State) error { +func (u *messageFlagsComboStateUpdate) Apply(ctx context.Context, tx db.Transaction, s *State) error { for _, v := range u.updates { if err := v.Apply(ctx, tx, s); err != nil { return err @@ -53,11 +52,11 @@ type messageFlagsAddedStateUpdate struct { AllStateFilter messageIDs []imap.InternalMessageID flags imap.FlagSet - mboxID ids.MailboxIDPair + mboxID db.MailboxIDPair stateID StateID } -func newMessageFlagsAddedStateUpdate(flags imap.FlagSet, mboxID ids.MailboxIDPair, messageIDs []imap.InternalMessageID, stateID StateID) Update { +func newMessageFlagsAddedStateUpdate(flags imap.FlagSet, mboxID db.MailboxIDPair, messageIDs []imap.InternalMessageID, stateID StateID) Update { return &messageFlagsAddedStateUpdate{ flags: flags, mboxID: mboxID, @@ -66,7 +65,7 @@ func newMessageFlagsAddedStateUpdate(flags imap.FlagSet, mboxID ids.MailboxIDPai } } -func (u *messageFlagsAddedStateUpdate) Apply(ctx context.Context, tx *ent.Tx, s *State) error { +func (u *messageFlagsAddedStateUpdate) Apply(ctx context.Context, tx db.Transaction, s *State) error { for _, messageID := range u.messageIDs { newFlags := u.flags @@ -96,7 +95,7 @@ func (u *messageFlagsAddedStateUpdate) String() string { // applyMessageFlagsAdded adds the flags to the given messages. func (state *State) applyMessageFlagsAdded(ctx context.Context, - tx *ent.Tx, + tx db.Transaction, messageIDs []imap.InternalMessageID, addFlags imap.FlagSet) ([]Update, error) { if addFlags.ContainsUnchecked(imap.FlagRecentLowerCase) { @@ -106,9 +105,7 @@ func (state *State) applyMessageFlagsAdded(ctx context.Context, // Since DB state can be more up to date then the flag state we should only emit add flag updates for values // that actually changed. - client := tx.Client() - - curFlags, err := db.GetMessageFlags(ctx, client, messageIDs) + curFlags, err := tx.GetMessagesFlags(ctx, messageIDs) if err != nil { return nil, err } @@ -150,7 +147,7 @@ func (state *State) applyMessageFlagsAdded(ctx context.Context, flagStateUpdate := newMessageFlagsComboStateUpdate() if addFlags.ContainsUnchecked(imap.FlagDeletedLowerCase) { - if err := db.SetDeletedFlag(ctx, tx, state.snap.mboxID.InternalID, messageIDs, true); err != nil { + if err := tx.SetMailboxMessagesDeletedFlag(ctx, state.snap.mboxID.InternalID, messageIDs, true); err != nil { return nil, err } @@ -169,7 +166,7 @@ func (state *State) applyMessageFlagsAdded(ctx context.Context, } } - if err := db.AddMessageFlag(ctx, tx, messagesToFlag, flag); err != nil { + if err := tx.AddFlagToMessages(ctx, messagesToFlag, flag); err != nil { return nil, err } @@ -183,11 +180,11 @@ type messageFlagsRemovedStateUpdate struct { AllStateFilter messageIDs []imap.InternalMessageID flags imap.FlagSet - mboxID ids.MailboxIDPair + mboxID db.MailboxIDPair stateID StateID } -func NewMessageFlagsRemovedStateUpdate(flags imap.FlagSet, mboxID ids.MailboxIDPair, messageIDs []imap.InternalMessageID, stateID StateID) Update { +func NewMessageFlagsRemovedStateUpdate(flags imap.FlagSet, mboxID db.MailboxIDPair, messageIDs []imap.InternalMessageID, stateID StateID) Update { return &messageFlagsRemovedStateUpdate{ flags: flags, mboxID: mboxID, @@ -196,7 +193,7 @@ func NewMessageFlagsRemovedStateUpdate(flags imap.FlagSet, mboxID ids.MailboxIDP } } -func (u *messageFlagsRemovedStateUpdate) Apply(ctx context.Context, tx *ent.Tx, s *State) error { +func (u *messageFlagsRemovedStateUpdate) Apply(ctx context.Context, tx db.Transaction, s *State) error { for _, messageID := range u.messageIDs { newFlags := u.flags @@ -226,16 +223,14 @@ func (u *messageFlagsRemovedStateUpdate) String() string { // applyMessageFlagsRemoved removes the flags from the given messages. func (state *State) applyMessageFlagsRemoved(ctx context.Context, - tx *ent.Tx, + tx db.Transaction, messageIDs []imap.InternalMessageID, remFlags imap.FlagSet) ([]Update, error) { if remFlags.ContainsUnchecked(imap.FlagRecentLowerCase) { return nil, fmt.Errorf("the recent flag is read-only") } - client := tx.Client() - - curFlags, err := db.GetMessageFlags(ctx, client, messageIDs) + curFlags, err := tx.GetMessagesFlags(ctx, messageIDs) if err != nil { return nil, err } @@ -276,7 +271,7 @@ func (state *State) applyMessageFlagsRemoved(ctx context.Context, flagStateUpdate := newMessageFlagsComboStateUpdate() if remFlags.ContainsUnchecked(imap.FlagDeletedLowerCase) { - if err := db.SetDeletedFlag(ctx, tx, state.snap.mboxID.InternalID, messageIDs, false); err != nil { + if err := tx.SetMailboxMessagesDeletedFlag(ctx, state.snap.mboxID.InternalID, messageIDs, false); err != nil { return nil, err } @@ -295,7 +290,7 @@ func (state *State) applyMessageFlagsRemoved(ctx context.Context, } } - if err := db.RemoveMessageFlag(ctx, tx, messagesToFlag, flag); err != nil { + if err := tx.RemoveFlagFromMessages(ctx, messagesToFlag, flag); err != nil { return nil, err } @@ -309,11 +304,11 @@ type messageFlagsSetStateUpdate struct { AllStateFilter messageIDs []imap.InternalMessageID flags imap.FlagSet - mboxID ids.MailboxIDPair + mboxID db.MailboxIDPair stateID StateID } -func NewMessageFlagsSetStateUpdate(flags imap.FlagSet, mboxID ids.MailboxIDPair, messageIDs []imap.InternalMessageID, stateID StateID) Update { +func NewMessageFlagsSetStateUpdate(flags imap.FlagSet, mboxID db.MailboxIDPair, messageIDs []imap.InternalMessageID, stateID StateID) Update { return &messageFlagsSetStateUpdate{ flags: flags, mboxID: mboxID, @@ -322,7 +317,7 @@ func NewMessageFlagsSetStateUpdate(flags imap.FlagSet, mboxID ids.MailboxIDPair, } } -func (u *messageFlagsSetStateUpdate) Apply(ctx context.Context, tx *ent.Tx, state *State) error { +func (u *messageFlagsSetStateUpdate) Apply(ctx context.Context, tx db.Transaction, state *State) error { for _, messageID := range u.messageIDs { newFlags := u.flags @@ -352,7 +347,7 @@ func (u *messageFlagsSetStateUpdate) String() string { // applyMessageFlagsSet sets the flags of the given messages. func (state *State) applyMessageFlagsSet(ctx context.Context, - tx *ent.Tx, + tx db.Transaction, messageIDs []imap.InternalMessageID, setFlags imap.FlagSet) ([]Update, error) { if setFlags.ContainsUnchecked(imap.FlagRecentLowerCase) { @@ -363,7 +358,7 @@ func (state *State) applyMessageFlagsSet(ctx context.Context, return nil, nil } - curFlags, err := db.GetMessageFlags(ctx, tx.Client(), messageIDs) + curFlags, err := tx.GetMessagesFlags(ctx, messageIDs) if err != nil { return nil, err } @@ -398,11 +393,11 @@ func (state *State) applyMessageFlagsSet(ctx context.Context, } } - if err := db.SetDeletedFlag(ctx, tx, state.snap.mboxID.InternalID, messageIDs, setFlags.Contains(imap.FlagDeleted)); err != nil { + if err := tx.SetMailboxMessagesDeletedFlag(ctx, state.snap.mboxID.InternalID, messageIDs, setFlags.Contains(imap.FlagDeleted)); err != nil { return nil, err } - if err := db.SetMessageFlags(ctx, tx, messageIDs, setFlags.Remove(imap.FlagDeleted)); err != nil { + if err := tx.SetFlagsOnMessages(ctx, messageIDs, setFlags.Remove(imap.FlagDeleted)); err != nil { return nil, err } @@ -421,7 +416,7 @@ func NewMailboxRemoteIDUpdateStateUpdate(internalID imap.InternalMailboxID, remo } } -func (u *mailboxRemoteIDUpdateStateUpdate) Apply(ctx context.Context, tx *ent.Tx, s *State) error { +func (u *mailboxRemoteIDUpdateStateUpdate) Apply(ctx context.Context, tx db.Transaction, s *State) error { s.snap.mboxID.RemoteID = u.remoteID return nil @@ -439,7 +434,7 @@ func NewMailboxDeletedStateUpdate(mboxID imap.InternalMailboxID) Update { return &mailboxDeletedStateUpdate{MBoxIDStateFilter: MBoxIDStateFilter{MboxID: mboxID}} } -func (u *mailboxDeletedStateUpdate) Apply(ctx context.Context, tx *ent.Tx, s *State) error { +func (u *mailboxDeletedStateUpdate) Apply(ctx context.Context, tx db.Transaction, s *State) error { s.markInvalid() return nil @@ -457,7 +452,7 @@ func NewUIDValidityBumpedStateUpdate() Update { return &uidValidityBumpedStateUpdate{} } -func (u *uidValidityBumpedStateUpdate) Apply(ctx context.Context, tx *ent.Tx, s *State) error { +func (u *uidValidityBumpedStateUpdate) Apply(ctx context.Context, tx db.Transaction, s *State) error { s.markInvalid() return nil diff --git a/internal/state/updates_mailbox.go b/internal/state/updates_mailbox.go index b0321c4e..01604cc8 100644 --- a/internal/state/updates_mailbox.go +++ b/internal/state/updates_mailbox.go @@ -2,11 +2,9 @@ package state import ( "context" + "github.com/ProtonMail/gluon/db" "github.com/ProtonMail/gluon/imap" - "github.com/ProtonMail/gluon/internal/db" - "github.com/ProtonMail/gluon/internal/db/ent" - "github.com/ProtonMail/gluon/internal/ids" "github.com/ProtonMail/gluon/limits" "github.com/bradenaw/juniper/xslices" ) @@ -16,14 +14,14 @@ import ( // MoveMessagesFromMailbox moves messages from one mailbox to the other. func MoveMessagesFromMailbox( ctx context.Context, - tx *ent.Tx, + tx db.Transaction, mboxFromID, mboxToID imap.InternalMailboxID, messageIDs []imap.InternalMessageID, s *State, imapLimits limits.IMAP, removeOldMessages bool, ) ([]db.UIDWithFlags, []Update, error) { - messageCount, uid, err := db.GetMailboxMessageCountAndUID(ctx, tx.Client(), mboxToID) + messageCount, uid, err := tx.GetMailboxMessageCountAndUID(ctx, mboxToID) if err != nil { return nil, nil, err } @@ -37,12 +35,12 @@ func MoveMessagesFromMailbox( } if mboxFromID != mboxToID && removeOldMessages { - if err := db.RemoveMessagesFromMailbox(ctx, tx, messageIDs, mboxFromID); err != nil { + if err := tx.RemoveMessagesFromMailbox(ctx, mboxFromID, messageIDs); err != nil { return nil, nil, err } } - messageUIDs, err := db.AddMessagesToMailbox(ctx, tx, messageIDs, mboxToID) + messageUIDs, err := tx.AddMessagesToMailbox(ctx, mboxToID, messageIDs) if err != nil { return nil, nil, err } @@ -50,7 +48,7 @@ func MoveMessagesFromMailbox( stateUpdates := make([]Update, 0, len(messageIDs)+1) { responders := xslices.Map(messageUIDs, func(uid db.UIDWithFlags) *exists { - return newExists(ids.MessageIDPair{ + return newExists(db.MessageIDPair{ InternalID: uid.InternalID, RemoteID: uid.RemoteID, }, uid.UID, uid.GetFlagSet()) @@ -68,8 +66,13 @@ func MoveMessagesFromMailbox( } // AddMessagesToMailbox adds the messages to the given mailbox. -func AddMessagesToMailbox(ctx context.Context, tx *ent.Tx, mboxID imap.InternalMailboxID, messageIDs []imap.InternalMessageID, s *State, imapLimits limits.IMAP) ([]db.UIDWithFlags, Update, error) { - messageCount, uid, err := db.GetMailboxMessageCountAndUID(ctx, tx.Client(), mboxID) +func AddMessagesToMailbox(ctx context.Context, + tx db.Transaction, + mboxID imap.InternalMailboxID, + messageIDs []imap.InternalMessageID, + s *State, + imapLimits limits.IMAP) ([]db.UIDWithFlags, Update, error) { + messageCount, uid, err := tx.GetMailboxMessageCountAndUID(ctx, mboxID) if err != nil { return nil, nil, err } @@ -82,13 +85,13 @@ func AddMessagesToMailbox(ctx context.Context, tx *ent.Tx, mboxID imap.InternalM return nil, nil, err } - messageUIDs, err := db.AddMessagesToMailbox(ctx, tx, messageIDs, mboxID) + messageUIDs, err := tx.AddMessagesToMailbox(ctx, mboxID, messageIDs) if err != nil { return nil, nil, err } responders := xslices.Map(messageUIDs, func(uid db.UIDWithFlags) *exists { - return newExists(ids.MessageIDPair{ + return newExists(db.MessageIDPair{ InternalID: uid.InternalID, RemoteID: uid.RemoteID, }, uid.UID, uid.GetFlagSet()) @@ -98,9 +101,9 @@ func AddMessagesToMailbox(ctx context.Context, tx *ent.Tx, mboxID imap.InternalM } // RemoveMessagesFromMailbox removes the messages from the given mailbox. -func RemoveMessagesFromMailbox(ctx context.Context, tx *ent.Tx, mboxID imap.InternalMailboxID, messageIDs []imap.InternalMessageID) ([]Update, error) { +func RemoveMessagesFromMailbox(ctx context.Context, tx db.Transaction, mboxID imap.InternalMailboxID, messageIDs []imap.InternalMessageID) ([]Update, error) { if len(messageIDs) > 0 { - if err := db.RemoveMessagesFromMailbox(ctx, tx, messageIDs, mboxID); err != nil { + if err := tx.RemoveMessagesFromMailbox(ctx, mboxID, messageIDs); err != nil { return nil, err } } diff --git a/internal/state/updates_remote.go b/internal/state/updates_remote.go index 2f4e3d5f..83b1d0ea 100644 --- a/internal/state/updates_remote.go +++ b/internal/state/updates_remote.go @@ -3,10 +3,10 @@ package state import ( "context" "fmt" + "github.com/ProtonMail/gluon/db" "github.com/ProtonMail/gluon/imap" "github.com/ProtonMail/gluon/internal/contexts" - "github.com/ProtonMail/gluon/internal/db/ent" ) type RemoteAddMessageFlagsStateUpdate struct { @@ -21,7 +21,7 @@ func NewRemoteAddMessageFlagsStateUpdate(messageID imap.InternalMessageID, flag } } -func (u *RemoteAddMessageFlagsStateUpdate) Apply(ctx context.Context, tx *ent.Tx, s *State) error { +func (u *RemoteAddMessageFlagsStateUpdate) Apply(ctx context.Context, tx db.Transaction, s *State) error { return s.PushResponder(ctx, tx, NewFetch(u.MessageID, imap.NewFlagSet(u.flag), contexts.IsUID(ctx), contexts.IsSilent(ctx), false, FetchFlagOpAdd)) } @@ -41,7 +41,7 @@ func NewRemoteRemoveMessageFlagsStateUpdate(messageID imap.InternalMessageID, fl } } -func (u *RemoteRemoveMessageFlagsStateUpdate) Apply(ctx context.Context, tx *ent.Tx, s *State) error { +func (u *RemoteRemoveMessageFlagsStateUpdate) Apply(ctx context.Context, tx db.Transaction, s *State) error { return s.PushResponder(ctx, tx, NewFetch(u.MessageID, imap.NewFlagSet(u.flag), contexts.IsUID(ctx), contexts.IsSilent(ctx), false, FetchFlagOpRem)) } diff --git a/internal/state/user_interface.go b/internal/state/user_interface.go index f7174f8b..4c59ce41 100644 --- a/internal/state/user_interface.go +++ b/internal/state/user_interface.go @@ -2,12 +2,10 @@ package state import ( "context" + "github.com/ProtonMail/gluon/db" "github.com/ProtonMail/gluon/internal/utils" "github.com/ProtonMail/gluon/imap" - "github.com/ProtonMail/gluon/internal/db" - "github.com/ProtonMail/gluon/internal/db/ent" - "github.com/ProtonMail/gluon/internal/ids" "github.com/ProtonMail/gluon/store" ) @@ -19,17 +17,17 @@ type UserInterface interface { GetDelimiter() string - GetDB() *db.DB + GetDB() db.Client GetRemote() Connector GetStore() *store.WriteControlledStore - QueueOrApplyStateUpdate(ctx context.Context, tx *ent.Tx, update ...Update) error + QueueOrApplyStateUpdate(ctx context.Context, tx db.Transaction, update ...Update) error ReleaseState(ctx context.Context, st *State) error - GetRecoveryMailboxID() ids.MailboxIDPair + GetRecoveryMailboxID() db.MailboxIDPair GenerateUIDValidity() (imap.UID, error) diff --git a/option.go b/option.go index cd160bc9..c4f8b363 100644 --- a/option.go +++ b/option.go @@ -2,6 +2,7 @@ package gluon import ( "crypto/tls" + "github.com/ProtonMail/gluon/db" "io" "time" @@ -230,3 +231,15 @@ func (w withUIDValidityGenerator) config(builder *serverBuilder) { func WithUIDValidityGenerator(generator imap.UIDValidityGenerator) Option { return &withUIDValidityGenerator{generator: generator} } + +type withDBClient struct { + ci db.ClientInterface +} + +func (w withDBClient) config(builder *serverBuilder) { + builder.dbCI = w.ci +} + +func WithDBClient(ci db.ClientInterface) Option { + return &withDBClient{ci: ci} +} diff --git a/store/mock_store/store.go b/store/mock_store/store.go index 0b8f2ad8..5c2046a6 100644 --- a/store/mock_store/store.go +++ b/store/mock_store/store.go @@ -5,6 +5,7 @@ package mock_store import ( + io "io" reflect "reflect" imap "github.com/ProtonMail/gluon/imap" @@ -97,7 +98,7 @@ func (mr *MockStoreMockRecorder) List() *gomock.Call { } // Set mocks base method. -func (m *MockStore) Set(arg0 imap.InternalMessageID, arg1 []byte) error { +func (m *MockStore) Set(arg0 imap.InternalMessageID, arg1 io.Reader) error { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "Set", arg0, arg1) ret0, _ := ret[0].(error) diff --git a/tests/account_test.go b/tests/account_test.go index a2496ae9..751b9423 100644 --- a/tests/account_test.go +++ b/tests/account_test.go @@ -2,7 +2,7 @@ package tests import ( "errors" - "github.com/ProtonMail/gluon/internal/db" + "github.com/ProtonMail/gluon/db" "github.com/stretchr/testify/require" "os" "path/filepath" diff --git a/tests/db_test.go b/tests/db_test.go index ab19ec37..c7dc0995 100644 --- a/tests/db_test.go +++ b/tests/db_test.go @@ -2,14 +2,16 @@ package tests import ( "context" + "github.com/ProtonMail/gluon/db" - "github.com/ProtonMail/gluon/internal/db/ent" "github.com/stretchr/testify/require" ) func dbCheckUserMessageCount(s *testSession, user string, expectedCount int) { - err := s.withUserDB(user, func(ent *ent.Client, ctx context.Context) { - val, err := ent.Message.Query().Count(ctx) + err := s.withUserDB(user, func(ent db.Client, ctx context.Context) { + val, err := db.ClientReadType(ctx, ent, func(ctx context.Context, only db.ReadOnly) (int, error) { + return only.GetTotalMessageCount(ctx) + }) require.NoError(s.tb, err) require.Equal(s.tb, expectedCount, val) }) diff --git a/tests/server_test.go b/tests/server_test.go index 17c4267f..1d866b2b 100644 --- a/tests/server_test.go +++ b/tests/server_test.go @@ -5,6 +5,8 @@ import ( "crypto/tls" "encoding/hex" "fmt" + "github.com/ProtonMail/gluon/db" + "github.com/ProtonMail/gluon/internal/db_impl" "net" "path/filepath" "testing" @@ -77,6 +79,7 @@ type serverOptions struct { imapLimits limits.IMAP reporter reporter.Reporter uidValidityGenerator imap.UIDValidityGenerator + database db.ClientInterface } func (s *serverOptions) defaultUsername() string { @@ -179,6 +182,14 @@ type uidValidityGeneratorOption struct { generator imap.UIDValidityGenerator } +type withDatabaseOption struct { + database db.ClientInterface +} + +func (w withDatabaseOption) apply(options *serverOptions) { + options.database = w.database +} + func (u uidValidityGeneratorOption) apply(options *serverOptions) { options.uidValidityGenerator = u.generator } @@ -227,6 +238,10 @@ func withDatabaseDir(dir string) serverOption { return &databaseDirOption{dir: dir} } +func withDatabase(ci db.ClientInterface) serverOption { + return &withDatabaseOption{database: ci} +} + func defaultServerOptions(tb testing.TB, modifiers ...serverOption) *serverOptions { options := &serverOptions{ credentials: []credentials{{ @@ -241,6 +256,7 @@ func defaultServerOptions(tb testing.TB, modifiers ...serverOption) *serverOptio storeBuilder: &store.OnDiskStoreBuilder{}, connectorBuilder: &dummyConnectorBuilder{}, imapLimits: limits.DefaultLimits(), + database: db_impl.NewEntDB(), } for _, op := range modifiers { diff --git a/tests/session_test.go b/tests/session_test.go index 2dcb9848..9f009334 100644 --- a/tests/session_test.go +++ b/tests/session_test.go @@ -4,6 +4,7 @@ import ( "context" "errors" "fmt" + "github.com/ProtonMail/gluon/db" "io" "net" "os" @@ -11,12 +12,10 @@ import ( "testing" "time" - "entgo.io/ent/dialect" "github.com/ProtonMail/gluon" "github.com/ProtonMail/gluon/connector" "github.com/ProtonMail/gluon/events" "github.com/ProtonMail/gluon/imap" - "github.com/ProtonMail/gluon/internal/db/ent" "github.com/ProtonMail/gluon/internal/utils" "github.com/ProtonMail/go-mbox" "github.com/emersion/go-imap/client" @@ -113,13 +112,13 @@ func (s *testSession) newClient() *client.Client { return client } -func (s *testSession) withUserDB(user string, fn func(client *ent.Client, ctx context.Context)) error { - path, ok := s.userDBPaths[s.userIDs[user]] +func (s *testSession) withUserDB(user string, fn func(client db.Client, ctx context.Context)) error { + userID, ok := s.userIDs[user] if !ok { return fmt.Errorf("User not found") } - client, err := ent.Open(dialect.SQLite, path) + client, _, err := s.options.database.New(s.server.GetDatabasePath(), userID) if err != nil { return err } From 4d0403568c49e47a7deaf4fdfc610a8ee2f6a529 Mon Sep 17 00:00:00 2001 From: Leander Beernaert Date: Fri, 26 May 2023 10:21:56 +0200 Subject: [PATCH 2/4] chore: Fix go-imports --- builder.go | 4 ++-- db/deferred_delete.go | 3 ++- db/ops_mailbox.go | 3 ++- db/ops_message.go | 3 ++- db/ops_subscription.go | 1 + db/types.go | 3 ++- internal/backend/backend.go | 2 +- internal/backend/connector_updates.go | 2 +- internal/backend/state_user_interface_impl.go | 4 ++-- internal/db_impl/ent_db/mailbox.go | 2 +- internal/db_impl/ent_db/message.go | 2 +- internal/db_impl/ent_db/ops_read.go | 3 ++- internal/db_impl/ent_db/ops_write.go | 1 + internal/session/errors.go | 3 ++- internal/session/handle_append.go | 1 + internal/state/mailbox_fetch.go | 2 +- internal/state/mailbox_search.go | 2 +- internal/state/match.go | 2 +- internal/state/responders.go | 2 +- internal/state/snapshot.go | 2 +- internal/state/snapshot_messages.go | 2 +- internal/state/snapshot_messages_test.go | 2 +- internal/state/state.go | 2 +- internal/state/updates.go | 2 +- internal/state/updates_mailbox.go | 2 +- internal/state/updates_remote.go | 2 +- internal/state/user_interface.go | 4 ++-- internal/utils/message_hashmap.go | 3 ++- internal/utils/utils.go | 1 + option.go | 2 +- rfc5322/parser_test.go | 2 +- rfc5322/validation.go | 1 + rfc5322/validation_test.go | 3 ++- rfc822/hash.go | 3 ++- tests/account_test.go | 5 +++-- tests/db_test.go | 2 +- tests/deleted_test.go | 3 ++- tests/server_test.go | 4 ++-- tests/session_test.go | 2 +- 39 files changed, 55 insertions(+), 39 deletions(-) diff --git a/builder.go b/builder.go index 56b52023..88d38315 100644 --- a/builder.go +++ b/builder.go @@ -2,15 +2,15 @@ package gluon import ( "crypto/tls" - "github.com/ProtonMail/gluon/db" - "github.com/ProtonMail/gluon/internal/db_impl/ent_db" "io" "os" "time" "github.com/ProtonMail/gluon/async" + "github.com/ProtonMail/gluon/db" "github.com/ProtonMail/gluon/imap" "github.com/ProtonMail/gluon/internal/backend" + "github.com/ProtonMail/gluon/internal/db_impl/ent_db" "github.com/ProtonMail/gluon/internal/session" "github.com/ProtonMail/gluon/limits" "github.com/ProtonMail/gluon/profiling" diff --git a/db/deferred_delete.go b/db/deferred_delete.go index 33463e55..d110aeae 100644 --- a/db/deferred_delete.go +++ b/db/deferred_delete.go @@ -3,9 +3,10 @@ package db import ( "errors" "fmt" - "github.com/google/uuid" "os" "path/filepath" + + "github.com/google/uuid" ) // DeleteDB will rename all the database files for the given user to a directory within the same folder to avoid diff --git a/db/ops_mailbox.go b/db/ops_mailbox.go index eb22d684..4c4e8557 100644 --- a/db/ops_mailbox.go +++ b/db/ops_mailbox.go @@ -2,8 +2,9 @@ package db import ( "context" - "github.com/ProtonMail/gluon/imap" "strings" + + "github.com/ProtonMail/gluon/imap" ) type MailboxReadOps interface { diff --git a/db/ops_message.go b/db/ops_message.go index ee9da379..cd67ab17 100644 --- a/db/ops_message.go +++ b/db/ops_message.go @@ -2,9 +2,10 @@ package db import ( "context" + "time" + "github.com/ProtonMail/gluon/imap" "github.com/bradenaw/juniper/xslices" - "time" ) type MessageReadOps interface { diff --git a/db/ops_subscription.go b/db/ops_subscription.go index 272342d5..b79d4a83 100644 --- a/db/ops_subscription.go +++ b/db/ops_subscription.go @@ -2,6 +2,7 @@ package db import ( "context" + "github.com/ProtonMail/gluon/imap" ) diff --git a/db/types.go b/db/types.go index 7f4a7b49..9d7147a0 100644 --- a/db/types.go +++ b/db/types.go @@ -2,8 +2,9 @@ package db import ( "fmt" - "github.com/ProtonMail/gluon/imap" "time" + + "github.com/ProtonMail/gluon/imap" ) type MailboxIDPair struct { diff --git a/internal/backend/backend.go b/internal/backend/backend.go index e2693064..5e487583 100644 --- a/internal/backend/backend.go +++ b/internal/backend/backend.go @@ -3,13 +3,13 @@ package backend import ( "context" "fmt" - "github.com/ProtonMail/gluon/db" "sync" "sync/atomic" "time" "github.com/ProtonMail/gluon/async" "github.com/ProtonMail/gluon/connector" + "github.com/ProtonMail/gluon/db" "github.com/ProtonMail/gluon/imap" "github.com/ProtonMail/gluon/internal/state" "github.com/ProtonMail/gluon/limits" diff --git a/internal/backend/connector_updates.go b/internal/backend/connector_updates.go index 8fea5ac1..68604a7f 100644 --- a/internal/backend/connector_updates.go +++ b/internal/backend/connector_updates.go @@ -4,12 +4,12 @@ import ( "bytes" "context" "fmt" - "github.com/ProtonMail/gluon/db" "io" "os" "runtime" "strings" + "github.com/ProtonMail/gluon/db" "github.com/ProtonMail/gluon/imap" "github.com/ProtonMail/gluon/internal/ids" "github.com/ProtonMail/gluon/internal/state" diff --git a/internal/backend/state_user_interface_impl.go b/internal/backend/state_user_interface_impl.go index a348f518..69691358 100644 --- a/internal/backend/state_user_interface_impl.go +++ b/internal/backend/state_user_interface_impl.go @@ -2,12 +2,12 @@ package backend import ( "context" - "github.com/ProtonMail/gluon/db" - "github.com/ProtonMail/gluon/internal/utils" + "github.com/ProtonMail/gluon/db" "github.com/ProtonMail/gluon/imap" "github.com/ProtonMail/gluon/internal/ids" "github.com/ProtonMail/gluon/internal/state" + "github.com/ProtonMail/gluon/internal/utils" "github.com/ProtonMail/gluon/store" ) diff --git a/internal/db_impl/ent_db/mailbox.go b/internal/db_impl/ent_db/mailbox.go index 3900df73..6355a7cf 100644 --- a/internal/db_impl/ent_db/mailbox.go +++ b/internal/db_impl/ent_db/mailbox.go @@ -3,10 +3,10 @@ package ent_db import ( "context" "fmt" - "github.com/ProtonMail/gluon/db" "strings" "entgo.io/ent/dialect/sql" + "github.com/ProtonMail/gluon/db" "github.com/ProtonMail/gluon/imap" "github.com/ProtonMail/gluon/internal/db_impl/ent_db/internal" "github.com/ProtonMail/gluon/internal/db_impl/ent_db/internal/mailbox" diff --git a/internal/db_impl/ent_db/message.go b/internal/db_impl/ent_db/message.go index 091151ec..e2d34505 100644 --- a/internal/db_impl/ent_db/message.go +++ b/internal/db_impl/ent_db/message.go @@ -3,9 +3,9 @@ package ent_db import ( "context" "fmt" - "github.com/ProtonMail/gluon/db" "entgo.io/ent/dialect/sql" + "github.com/ProtonMail/gluon/db" "github.com/ProtonMail/gluon/imap" "github.com/ProtonMail/gluon/internal/db_impl/ent_db/internal" "github.com/ProtonMail/gluon/internal/db_impl/ent_db/internal/mailbox" diff --git a/internal/db_impl/ent_db/ops_read.go b/internal/db_impl/ent_db/ops_read.go index 16ea70e9..bfdb2d7e 100644 --- a/internal/db_impl/ent_db/ops_read.go +++ b/internal/db_impl/ent_db/ops_read.go @@ -2,11 +2,12 @@ package ent_db import ( "context" + "time" + "github.com/ProtonMail/gluon/db" "github.com/ProtonMail/gluon/imap" "github.com/ProtonMail/gluon/internal/db_impl/ent_db/internal" "github.com/bradenaw/juniper/xslices" - "time" ) type EntOpsRead struct { diff --git a/internal/db_impl/ent_db/ops_write.go b/internal/db_impl/ent_db/ops_write.go index 8d914950..b04f3e94 100644 --- a/internal/db_impl/ent_db/ops_write.go +++ b/internal/db_impl/ent_db/ops_write.go @@ -2,6 +2,7 @@ package ent_db import ( "context" + "github.com/ProtonMail/gluon/db" "github.com/ProtonMail/gluon/imap" "github.com/ProtonMail/gluon/internal/db_impl/ent_db/internal" diff --git a/internal/session/errors.go b/internal/session/errors.go index 82a8b540..16140a86 100644 --- a/internal/session/errors.go +++ b/internal/session/errors.go @@ -3,10 +3,11 @@ package session import ( "context" "errors" + "net" + "github.com/ProtonMail/gluon/connector" "github.com/ProtonMail/gluon/internal/state" "github.com/ProtonMail/gluon/rfc822" - "net" ) var ( diff --git a/internal/session/handle_append.go b/internal/session/handle_append.go index 04169321..5b3b0426 100644 --- a/internal/session/handle_append.go +++ b/internal/session/handle_append.go @@ -3,6 +3,7 @@ package session import ( "context" "errors" + "github.com/ProtonMail/gluon/imap/command" "github.com/ProtonMail/gluon/internal/response" "github.com/ProtonMail/gluon/internal/state" diff --git a/internal/state/mailbox_fetch.go b/internal/state/mailbox_fetch.go index 87a6e2bc..ec6bef85 100644 --- a/internal/state/mailbox_fetch.go +++ b/internal/state/mailbox_fetch.go @@ -3,13 +3,13 @@ package state import ( "context" "fmt" - "github.com/ProtonMail/gluon/db" "runtime" "strconv" "strings" "sync/atomic" "github.com/ProtonMail/gluon/async" + "github.com/ProtonMail/gluon/db" "github.com/ProtonMail/gluon/imap" "github.com/ProtonMail/gluon/imap/command" "github.com/ProtonMail/gluon/internal/contexts" diff --git a/internal/state/mailbox_search.go b/internal/state/mailbox_search.go index 5ff71126..a35b5841 100644 --- a/internal/state/mailbox_search.go +++ b/internal/state/mailbox_search.go @@ -4,13 +4,13 @@ import ( "bytes" "context" "fmt" - "github.com/ProtonMail/gluon/db" "runtime" "strings" "sync/atomic" "time" "github.com/ProtonMail/gluon/async" + "github.com/ProtonMail/gluon/db" "github.com/ProtonMail/gluon/imap" "github.com/ProtonMail/gluon/imap/command" "github.com/ProtonMail/gluon/internal/contexts" diff --git a/internal/state/match.go b/internal/state/match.go index 8131f581..fcf819ad 100644 --- a/internal/state/match.go +++ b/internal/state/match.go @@ -3,10 +3,10 @@ package state import ( "context" "fmt" - "github.com/ProtonMail/gluon/db" "regexp" "strings" + "github.com/ProtonMail/gluon/db" "github.com/ProtonMail/gluon/imap" "github.com/bradenaw/juniper/xslices" ) diff --git a/internal/state/responders.go b/internal/state/responders.go index 5ef7140c..1bba1e2f 100644 --- a/internal/state/responders.go +++ b/internal/state/responders.go @@ -3,9 +3,9 @@ package state import ( "context" "fmt" - "github.com/ProtonMail/gluon/db" "sync" + "github.com/ProtonMail/gluon/db" "github.com/ProtonMail/gluon/imap" "github.com/ProtonMail/gluon/internal/contexts" "github.com/ProtonMail/gluon/internal/response" diff --git a/internal/state/snapshot.go b/internal/state/snapshot.go index 7d82be2d..8e497698 100644 --- a/internal/state/snapshot.go +++ b/internal/state/snapshot.go @@ -2,9 +2,9 @@ package state import ( "context" - "github.com/ProtonMail/gluon/db" "strings" + "github.com/ProtonMail/gluon/db" "github.com/ProtonMail/gluon/imap" "github.com/ProtonMail/gluon/imap/command" "github.com/ProtonMail/gluon/internal/contexts" diff --git a/internal/state/snapshot_messages.go b/internal/state/snapshot_messages.go index 12daafe7..a73dc2ec 100644 --- a/internal/state/snapshot_messages.go +++ b/internal/state/snapshot_messages.go @@ -2,8 +2,8 @@ package state import ( "fmt" - "github.com/ProtonMail/gluon/db" + "github.com/ProtonMail/gluon/db" "github.com/ProtonMail/gluon/imap" "github.com/ProtonMail/gluon/imap/command" "github.com/bradenaw/juniper/xslices" diff --git a/internal/state/snapshot_messages_test.go b/internal/state/snapshot_messages_test.go index 34a7e1d7..314b4c5f 100644 --- a/internal/state/snapshot_messages_test.go +++ b/internal/state/snapshot_messages_test.go @@ -1,9 +1,9 @@ package state import ( - "github.com/ProtonMail/gluon/db" "testing" + "github.com/ProtonMail/gluon/db" "github.com/ProtonMail/gluon/imap" "github.com/ProtonMail/gluon/imap/command" "github.com/stretchr/testify/require" diff --git a/internal/state/state.go b/internal/state/state.go index 0fc703d4..b35d8ea3 100644 --- a/internal/state/state.go +++ b/internal/state/state.go @@ -4,11 +4,11 @@ import ( "bytes" "context" "fmt" - "github.com/ProtonMail/gluon/db" "strings" "sync/atomic" "github.com/ProtonMail/gluon/async" + "github.com/ProtonMail/gluon/db" "github.com/ProtonMail/gluon/imap" "github.com/ProtonMail/gluon/internal/ids" "github.com/ProtonMail/gluon/internal/response" diff --git a/internal/state/updates.go b/internal/state/updates.go index 62704fd5..df918eb9 100644 --- a/internal/state/updates.go +++ b/internal/state/updates.go @@ -3,9 +3,9 @@ package state import ( "context" "fmt" - "github.com/ProtonMail/gluon/db" "strings" + "github.com/ProtonMail/gluon/db" "github.com/ProtonMail/gluon/imap" "github.com/ProtonMail/gluon/internal/contexts" "github.com/ProtonMail/gluon/internal/ids" diff --git a/internal/state/updates_mailbox.go b/internal/state/updates_mailbox.go index 01604cc8..940264fc 100644 --- a/internal/state/updates_mailbox.go +++ b/internal/state/updates_mailbox.go @@ -2,8 +2,8 @@ package state import ( "context" - "github.com/ProtonMail/gluon/db" + "github.com/ProtonMail/gluon/db" "github.com/ProtonMail/gluon/imap" "github.com/ProtonMail/gluon/limits" "github.com/bradenaw/juniper/xslices" diff --git a/internal/state/updates_remote.go b/internal/state/updates_remote.go index 83b1d0ea..f279156c 100644 --- a/internal/state/updates_remote.go +++ b/internal/state/updates_remote.go @@ -3,8 +3,8 @@ package state import ( "context" "fmt" - "github.com/ProtonMail/gluon/db" + "github.com/ProtonMail/gluon/db" "github.com/ProtonMail/gluon/imap" "github.com/ProtonMail/gluon/internal/contexts" ) diff --git a/internal/state/user_interface.go b/internal/state/user_interface.go index 4c59ce41..671e11ea 100644 --- a/internal/state/user_interface.go +++ b/internal/state/user_interface.go @@ -2,10 +2,10 @@ package state import ( "context" - "github.com/ProtonMail/gluon/db" - "github.com/ProtonMail/gluon/internal/utils" + "github.com/ProtonMail/gluon/db" "github.com/ProtonMail/gluon/imap" + "github.com/ProtonMail/gluon/internal/utils" "github.com/ProtonMail/gluon/store" ) diff --git a/internal/utils/message_hashmap.go b/internal/utils/message_hashmap.go index 122ee91d..5b190b31 100644 --- a/internal/utils/message_hashmap.go +++ b/internal/utils/message_hashmap.go @@ -1,9 +1,10 @@ package utils import ( + "sync" + "github.com/ProtonMail/gluon/imap" "github.com/ProtonMail/gluon/rfc822" - "sync" ) // MessageHashesMap tracks the hashes for a literal and it's associated internal IMAP ID. diff --git a/internal/utils/utils.go b/internal/utils/utils.go index 8e8cd68b..92997ae7 100644 --- a/internal/utils/utils.go +++ b/internal/utils/utils.go @@ -2,6 +2,7 @@ package utils import ( "errors" + "github.com/google/uuid" ) diff --git a/option.go b/option.go index c4f8b363..49b6a211 100644 --- a/option.go +++ b/option.go @@ -2,11 +2,11 @@ package gluon import ( "crypto/tls" - "github.com/ProtonMail/gluon/db" "io" "time" "github.com/ProtonMail/gluon/async" + "github.com/ProtonMail/gluon/db" "github.com/ProtonMail/gluon/imap" limits2 "github.com/ProtonMail/gluon/limits" "github.com/ProtonMail/gluon/profiling" diff --git a/rfc5322/parser_test.go b/rfc5322/parser_test.go index f061a2cb..c9816075 100644 --- a/rfc5322/parser_test.go +++ b/rfc5322/parser_test.go @@ -2,12 +2,12 @@ package rfc5322 import ( "bytes" - "github.com/stretchr/testify/require" "net/mail" "testing" "github.com/ProtonMail/gluon/rfcparser" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func newTestRFCParser(s string) *rfcparser.Parser { diff --git a/rfc5322/validation.go b/rfc5322/validation.go index 2d3aefdd..136b2d99 100644 --- a/rfc5322/validation.go +++ b/rfc5322/validation.go @@ -3,6 +3,7 @@ package rfc5322 import ( "errors" "fmt" + "github.com/ProtonMail/gluon/rfc822" ) diff --git a/rfc5322/validation_test.go b/rfc5322/validation_test.go index d385c036..453eccba 100644 --- a/rfc5322/validation_test.go +++ b/rfc5322/validation_test.go @@ -1,8 +1,9 @@ package rfc5322 import ( - "github.com/stretchr/testify/require" "testing" + + "github.com/stretchr/testify/require" ) func TestValidateMessageHeaderFields_RequiredFieldsPass(t *testing.T) { diff --git a/rfc822/hash.go b/rfc822/hash.go index 201b675a..9b982acc 100644 --- a/rfc822/hash.go +++ b/rfc822/hash.go @@ -4,8 +4,9 @@ import ( "bytes" "crypto/sha256" "encoding/base64" - "github.com/sirupsen/logrus" "strings" + + "github.com/sirupsen/logrus" ) // GetMessageHash returns the hash of the given message. diff --git a/tests/account_test.go b/tests/account_test.go index 751b9423..49ed0fe4 100644 --- a/tests/account_test.go +++ b/tests/account_test.go @@ -2,11 +2,12 @@ package tests import ( "errors" - "github.com/ProtonMail/gluon/db" - "github.com/stretchr/testify/require" "os" "path/filepath" "testing" + + "github.com/ProtonMail/gluon/db" + "github.com/stretchr/testify/require" ) func TestAccountRemovalMovesDBToDeferredDeleteFolder(t *testing.T) { diff --git a/tests/db_test.go b/tests/db_test.go index c7dc0995..93d808bf 100644 --- a/tests/db_test.go +++ b/tests/db_test.go @@ -2,8 +2,8 @@ package tests import ( "context" - "github.com/ProtonMail/gluon/db" + "github.com/ProtonMail/gluon/db" "github.com/stretchr/testify/require" ) diff --git a/tests/deleted_test.go b/tests/deleted_test.go index 87d4f25e..fc7c434a 100644 --- a/tests/deleted_test.go +++ b/tests/deleted_test.go @@ -1,11 +1,12 @@ package tests import ( - "github.com/stretchr/testify/require" "os" "path/filepath" "testing" "time" + + "github.com/stretchr/testify/require" ) func TestDeleted(t *testing.T) { diff --git a/tests/server_test.go b/tests/server_test.go index 1d866b2b..aedd8227 100644 --- a/tests/server_test.go +++ b/tests/server_test.go @@ -5,8 +5,6 @@ import ( "crypto/tls" "encoding/hex" "fmt" - "github.com/ProtonMail/gluon/db" - "github.com/ProtonMail/gluon/internal/db_impl" "net" "path/filepath" "testing" @@ -14,7 +12,9 @@ import ( "github.com/ProtonMail/gluon" "github.com/ProtonMail/gluon/connector" + "github.com/ProtonMail/gluon/db" "github.com/ProtonMail/gluon/imap" + "github.com/ProtonMail/gluon/internal/db_impl" "github.com/ProtonMail/gluon/internal/hash" "github.com/ProtonMail/gluon/limits" "github.com/ProtonMail/gluon/logging" diff --git a/tests/session_test.go b/tests/session_test.go index 9f009334..6e9fcf3b 100644 --- a/tests/session_test.go +++ b/tests/session_test.go @@ -4,7 +4,6 @@ import ( "context" "errors" "fmt" - "github.com/ProtonMail/gluon/db" "io" "net" "os" @@ -14,6 +13,7 @@ import ( "github.com/ProtonMail/gluon" "github.com/ProtonMail/gluon/connector" + "github.com/ProtonMail/gluon/db" "github.com/ProtonMail/gluon/events" "github.com/ProtonMail/gluon/imap" "github.com/ProtonMail/gluon/internal/utils" From a8afac739f5098944f369cee937395393a973bfb Mon Sep 17 00:00:00 2001 From: Leander Beernaert Date: Fri, 26 May 2023 10:33:10 +0200 Subject: [PATCH 3/4] refactor: Workaround golangci-lint crash Install golangci-lint from source rather than a release for the time being. --- .github/workflows/pull-request.yml | 20 +++----------------- 1 file changed, 3 insertions(+), 17 deletions(-) diff --git a/.github/workflows/pull-request.yml b/.github/workflows/pull-request.yml index 4a1ca858..006faad3 100644 --- a/.github/workflows/pull-request.yml +++ b/.github/workflows/pull-request.yml @@ -19,27 +19,13 @@ jobs: with: go-version: '1.18' - - name: Remove old static libs if modified - if: needs.check.outputs.changed == 'true' - run: | - rm -r internal/parser/lib - - - name: Download new static libs if modified - if: needs.check.outputs.changed == 'true' - uses: actions/download-artifact@v3 - with: - name: ${{ matrix.os }}-libs - path: internal/parser/lib - - name: Run go mod tidy run: go mod tidy - name: Run golangci-lint - uses: golangci/golangci-lint-action@v3 - with: - version: v1.51.1 - args: --timeout=500s - skip-cache: true + run: | + go install github.com/golangci/golangci-lint/cmd/golangci-lint@251ceaa228607dd3e0371694a1ab2c45d21cb744 + golangci-lint run --timeout=500s - name: Run tests run: go test -timeout 15m -v ./... From 64f3b750c8b3eebf33d0c073e38186652ca53298 Mon Sep 17 00:00:00 2001 From: Leander Beernaert Date: Fri, 26 May 2023 10:39:53 +0200 Subject: [PATCH 4/4] chore: Bump Go version to 1.20 --- .github/workflows/pull-request.yml | 4 ++-- go.mod | 2 +- go.sum | 5 ----- 3 files changed, 3 insertions(+), 8 deletions(-) diff --git a/.github/workflows/pull-request.yml b/.github/workflows/pull-request.yml index 006faad3..b6c0c09f 100644 --- a/.github/workflows/pull-request.yml +++ b/.github/workflows/pull-request.yml @@ -14,10 +14,10 @@ jobs: - name: Get sources uses: actions/checkout@v3 - - name: Set up Go 1.18 + - name: Set up Go 1.20 uses: actions/setup-go@v3 with: - go-version: '1.18' + go-version: '1.20' - name: Run go mod tidy run: go mod tidy diff --git a/go.mod b/go.mod index d2fd5a41..21dd5004 100644 --- a/go.mod +++ b/go.mod @@ -1,6 +1,6 @@ module github.com/ProtonMail/gluon -go 1.18 +go 1.20 require ( entgo.io/ent v0.11.8 diff --git a/go.sum b/go.sum index 58fabfd6..6dcdf870 100644 --- a/go.sum +++ b/go.sum @@ -51,12 +51,10 @@ github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= github.com/kylelemons/godebug v0.0.0-20170820004349-d65d576e9348 h1:MtvEpTB6LX3vkb4ax0b5D2DHbNAUsen0Gx5wZoq3lV4= -github.com/mattn/go-runewidth v0.0.9 h1:Lm995f3rfxdpd6TSmuVCHVb/QhupuXlYr8sCI/QdE+0= github.com/mattn/go-sqlite3 v1.14.16 h1:yOQRA0RpS5PFz/oikGwBEqvAWhWg5ufRz4ETLjwpU1Y= github.com/mattn/go-sqlite3 v1.14.16/go.mod h1:2eHXhiwb8IkHr+BDWZGa96P6+rkvnG63S2DGjv9HUNg= github.com/mitchellh/go-wordwrap v1.0.1 h1:TLuKupo69TCn6TQSyGxwI1EblZZEsQ0vMlAFQflz0v0= github.com/mitchellh/go-wordwrap v1.0.1/go.mod h1:R62XHJLzvMFRBbcrT7m7WgmE1eOyTSsCt+hzestvNj0= -github.com/olekukonko/tablewriter v0.0.5 h1:P2Ga83D34wi1o9J6Wh1mRuqd4mF/x/lgBS7N7AbDhec= github.com/pierrec/lz4/v4 v4.1.17 h1:kV4Ip+/hUBC+8T6+2EgburRtkE9ef4nbY3f4dFhGjMc= github.com/pierrec/lz4/v4 v4.1.17/go.mod h1:gZWDp/Ze/IJXGXf23ltt2EXimqmTUXEy0GFuRQyBid4= github.com/pkg/diff v0.0.0-20210226163009-20ebb0f2a09e/go.mod h1:pJLUxLENpZxwdsKMEsNbx1VGcRFpLqf3715MtcvvzbA= @@ -69,8 +67,6 @@ github.com/rogpeppe/go-internal v1.9.0/go.mod h1:WtVeX8xhTBvf0smdhujwtBcq4Qrzq/f github.com/sergi/go-diff v1.0.0 h1:Kpca3qRNrduNnOQeazBd0ysaKrUJiIuISHxogkT9RPQ= github.com/sirupsen/logrus v1.9.0 h1:trlNQbNUG3OdDrDil03MCb1H2o9nJ1x4/5LYw7byDE0= github.com/sirupsen/logrus v1.9.0/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ= -github.com/spf13/cobra v1.6.1 h1:o94oiPyS4KD1mPy2fmcYYHHfCxLqYjJOhGsCHFZtEzA= -github.com/spf13/pflag v1.0.5 h1:iy+VFUOCP1a+8yFto/drg2CJ5u0yRoB7fZw3DKv/JXA= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= @@ -117,7 +113,6 @@ golang.org/x/text v0.7.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= golang.org/x/tools v0.1.1/go.mod h1:o0xws9oXOQQZyjljx8fwUC0k7L1pTE6eaCbjGeHmOkk= -golang.org/x/tools v0.3.1-0.20221202221704-aa9f4b2f3d57 h1:/X0t/E4VxbZE7MLS7auvE7YICHeVvbIa9vkOVvYW/24= golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=