From 3b6ee4cb74eb81d76e4607cbb04caa64352d00bd Mon Sep 17 00:00:00 2001 From: Leander Beernaert Date: Tue, 6 Jun 2023 11:48:38 +0200 Subject: [PATCH] feat(GODT-2510): Direct SQLite3 DB implementation using SQLite3. This current version is feature complete with the ent implementation. The only noticeable difference is the addition of the db version table. The new database is now the default, but we will still run tests on CI to compare behavior against old implementation until it is completely removed. --- .github/workflows/pull-request.yml | 6 + builder.go | 4 +- db/ops.go | 1 + db/ops_mailbox.go | 2 +- db/ops_message.go | 2 +- db/types.go | 9 + go.mod | 2 +- go.sum | 2 + internal/backend/connector_updates.go | 2 +- internal/db_impl/db_impl.go | 5 + internal/db_impl/ent_db/ops_read.go | 4 +- internal/db_impl/sqlite3/client.go | 274 +++++++++ internal/db_impl/sqlite3/migration_v0.go | 54 ++ internal/db_impl/sqlite3/migrations.go | 86 +++ internal/db_impl/sqlite3/query_utils.go | 214 +++++++ internal/db_impl/sqlite3/read_ops.go | 598 ++++++++++++++++++++ internal/db_impl/sqlite3/tables.go | 160 ++++++ internal/db_impl/sqlite3/tracer.go | 472 ++++++++++++++++ internal/db_impl/sqlite3/types.go | 23 + internal/db_impl/sqlite3/v0/constants.go | 46 ++ internal/db_impl/sqlite3/wrappers.go | 132 +++++ internal/db_impl/sqlite3/write_ops.go | 676 +++++++++++++++++++++++ internal/state/actions.go | 2 +- internal/state/mailbox.go | 4 + internal/state/mailbox_fetch.go | 2 +- internal/state/state.go | 4 +- tests/db_test.go | 26 + tests/server_test.go | 10 +- 28 files changed, 2809 insertions(+), 13 deletions(-) create mode 100644 internal/db_impl/sqlite3/client.go create mode 100644 internal/db_impl/sqlite3/migration_v0.go create mode 100644 internal/db_impl/sqlite3/migrations.go create mode 100644 internal/db_impl/sqlite3/query_utils.go create mode 100644 internal/db_impl/sqlite3/read_ops.go create mode 100644 internal/db_impl/sqlite3/tables.go create mode 100644 internal/db_impl/sqlite3/tracer.go create mode 100644 internal/db_impl/sqlite3/types.go create mode 100644 internal/db_impl/sqlite3/v0/constants.go create mode 100644 internal/db_impl/sqlite3/wrappers.go create mode 100644 internal/db_impl/sqlite3/write_ops.go diff --git a/.github/workflows/pull-request.yml b/.github/workflows/pull-request.yml index b6c0c09f..f7baecfb 100644 --- a/.github/workflows/pull-request.yml +++ b/.github/workflows/pull-request.yml @@ -30,6 +30,12 @@ jobs: - name: Run tests run: go test -timeout 15m -v ./... + - name: Run tests (Ent) + if: runner.os == 'Linux' + run: go test -timeout 15m -v ./... + env: + GLUON_TEST_FORCE_ENT_DB: true + - name: Run tests with race check if: runner.os != 'Windows' run: go test -race -v ./tests diff --git a/builder.go b/builder.go index 88d38315..f75f03c3 100644 --- a/builder.go +++ b/builder.go @@ -10,7 +10,7 @@ import ( "github.com/ProtonMail/gluon/db" "github.com/ProtonMail/gluon/imap" "github.com/ProtonMail/gluon/internal/backend" - "github.com/ProtonMail/gluon/internal/db_impl/ent_db" + "github.com/ProtonMail/gluon/internal/db_impl/sqlite3" "github.com/ProtonMail/gluon/internal/session" "github.com/ProtonMail/gluon/limits" "github.com/ProtonMail/gluon/profiling" @@ -50,7 +50,7 @@ func newBuilder() (*serverBuilder, error) { imapLimits: limits.DefaultLimits(), uidValidityGenerator: imap.DefaultEpochUIDValidityGenerator(), panicHandler: async.NoopPanicHandler{}, - dbCI: ent_db.NewEntDBBuilder(), + dbCI: sqlite3.NewBuilder(), }, nil } diff --git a/db/ops.go b/db/ops.go index cd950cf2..c04969dc 100644 --- a/db/ops.go +++ b/db/ops.go @@ -7,6 +7,7 @@ type ReadOnly interface { } type Transaction interface { + ReadOnly MailboxWriteOps MessageWriteOps SubscriptionWriteOps diff --git a/db/ops_mailbox.go b/db/ops_mailbox.go index 4c4e8557..f0400a37 100644 --- a/db/ops_mailbox.go +++ b/db/ops_mailbox.go @@ -22,7 +22,7 @@ type MailboxReadOps interface { GetMailboxMessageIDPairs(ctx context.Context, mboxID imap.InternalMailboxID) ([]MessageIDPair, error) - GetAllMailboxes(ctx context.Context) ([]*Mailbox, error) + GetAllMailboxesWithAttr(ctx context.Context) ([]*Mailbox, error) GetAllMailboxesAsRemoteIDs(ctx context.Context) ([]imap.MailboxID, error) diff --git a/db/ops_message.go b/db/ops_message.go index cd67ab17..66d7531e 100644 --- a/db/ops_message.go +++ b/db/ops_message.go @@ -13,7 +13,7 @@ type MessageReadOps interface { MessageExistsWithRemoteID(ctx context.Context, id imap.MessageID) (bool, error) - GetMessage(ctx context.Context, id imap.InternalMessageID) (*Message, error) + GetMessageNoEdges(ctx context.Context, id imap.InternalMessageID) (*Message, error) GetTotalMessageCount(ctx context.Context) (int, error) diff --git a/db/types.go b/db/types.go index 9d7147a0..124fd120 100644 --- a/db/types.go +++ b/db/types.go @@ -5,6 +5,7 @@ import ( "time" "github.com/ProtonMail/gluon/imap" + "github.com/bradenaw/juniper/xslices" ) type MailboxIDPair struct { @@ -101,6 +102,14 @@ type MessageFlag struct { Value string } +func MessageFlagsFromFlagSet(set imap.FlagSet) []*MessageFlag { + return xslices.Map(set.ToSlice(), func(t string) *MessageFlag { + return &MessageFlag{ + Value: t, + } + }) +} + type Message struct { ID imap.InternalMessageID RemoteID imap.MessageID diff --git a/go.mod b/go.mod index 21dd5004..318820c1 100644 --- a/go.mod +++ b/go.mod @@ -10,7 +10,7 @@ require ( github.com/emersion/go-imap-uidplus v0.0.0-20200503180755-e75854c361e9 github.com/golang/mock v1.6.0 github.com/google/uuid v1.3.0 - github.com/mattn/go-sqlite3 v1.14.16 + github.com/mattn/go-sqlite3 v1.14.17 github.com/pierrec/lz4/v4 v4.1.17 github.com/pkg/profile v1.7.0 github.com/sirupsen/logrus v1.9.0 diff --git a/go.sum b/go.sum index 6dcdf870..e6714326 100644 --- a/go.sum +++ b/go.sum @@ -53,6 +53,8 @@ github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= github.com/kylelemons/godebug v0.0.0-20170820004349-d65d576e9348 h1:MtvEpTB6LX3vkb4ax0b5D2DHbNAUsen0Gx5wZoq3lV4= github.com/mattn/go-sqlite3 v1.14.16 h1:yOQRA0RpS5PFz/oikGwBEqvAWhWg5ufRz4ETLjwpU1Y= github.com/mattn/go-sqlite3 v1.14.16/go.mod h1:2eHXhiwb8IkHr+BDWZGa96P6+rkvnG63S2DGjv9HUNg= +github.com/mattn/go-sqlite3 v1.14.17 h1:mCRHCLDUBXgpKAqIKsaAaAsrAlbkeomtRFKXh2L6YIM= +github.com/mattn/go-sqlite3 v1.14.17/go.mod h1:2eHXhiwb8IkHr+BDWZGa96P6+rkvnG63S2DGjv9HUNg= github.com/mitchellh/go-wordwrap v1.0.1 h1:TLuKupo69TCn6TQSyGxwI1EblZZEsQ0vMlAFQflz0v0= github.com/mitchellh/go-wordwrap v1.0.1/go.mod h1:R62XHJLzvMFRBbcrT7m7WgmE1eOyTSsCt+hzestvNj0= github.com/pierrec/lz4/v4 v4.1.17 h1:kV4Ip+/hUBC+8T6+2EgburRtkE9ef4nbY3f4dFhGjMc= diff --git a/internal/backend/connector_updates.go b/internal/backend/connector_updates.go index 68604a7f..35dcdcc3 100644 --- a/internal/backend/connector_updates.go +++ b/internal/backend/connector_updates.go @@ -717,7 +717,7 @@ func (user *user) applyMessageUpdated(ctx context.Context, update *imap.MessageU // applyUIDValidityBumped applies a UIDValidityBumped event to the user. func (user *user) applyUIDValidityBumped(ctx context.Context, update *imap.UIDValidityBumped) error { if err := user.db.Write(ctx, func(ctx context.Context, tx db.Transaction) error { - mailboxes, err := tx.GetAllMailboxes(ctx) + mailboxes, err := tx.GetAllMailboxesWithAttr(ctx) if err != nil { return err } diff --git a/internal/db_impl/db_impl.go b/internal/db_impl/db_impl.go index 1c96616a..92ac7427 100644 --- a/internal/db_impl/db_impl.go +++ b/internal/db_impl/db_impl.go @@ -3,8 +3,13 @@ package db_impl import ( "github.com/ProtonMail/gluon/db" "github.com/ProtonMail/gluon/internal/db_impl/ent_db" + "github.com/ProtonMail/gluon/internal/db_impl/sqlite3" ) func NewEntDB() db.ClientInterface { return ent_db.NewEntDB() } + +func NewSQLiteDB(options ...sqlite3.Option) db.ClientInterface { + return sqlite3.NewBuilder(options...) +} diff --git a/internal/db_impl/ent_db/ops_read.go b/internal/db_impl/ent_db/ops_read.go index bfdb2d7e..bb64462e 100644 --- a/internal/db_impl/ent_db/ops_read.go +++ b/internal/db_impl/ent_db/ops_read.go @@ -68,7 +68,7 @@ func (op *EntOpsRead) GetMailboxMessageIDPairs(ctx context.Context, mboxID imap. }) } -func (op *EntOpsRead) GetAllMailboxes(ctx context.Context) ([]*db.Mailbox, error) { +func (op *EntOpsRead) GetAllMailboxesWithAttr(ctx context.Context) ([]*db.Mailbox, error) { return wrapEntErrFnTyped(func() ([]*db.Mailbox, error) { val, err := GetAllMailboxes(ctx, op.client) @@ -226,7 +226,7 @@ func (op *EntOpsRead) MessageExistsWithRemoteID(ctx context.Context, id imap.Mes }) } -func (op *EntOpsRead) GetMessage(ctx context.Context, id imap.InternalMessageID) (*db.Message, error) { +func (op *EntOpsRead) GetMessageNoEdges(ctx context.Context, id imap.InternalMessageID) (*db.Message, error) { return wrapEntErrFnTyped(func() (*db.Message, error) { msg, err := GetMessage(ctx, op.client, id) diff --git a/internal/db_impl/sqlite3/client.go b/internal/db_impl/sqlite3/client.go new file mode 100644 index 00000000..05e5ac24 --- /dev/null +++ b/internal/db_impl/sqlite3/client.go @@ -0,0 +1,274 @@ +package sqlite3 + +import ( + "context" + "database/sql" + "errors" + "fmt" + "io/fs" + "os" + "path/filepath" + "sync" + + "github.com/ProtonMail/gluon/db" + "github.com/ProtonMail/gluon/internal/utils" + "github.com/ProtonMail/gluon/reporter" + "github.com/google/uuid" + _ "github.com/mattn/go-sqlite3" + "github.com/sirupsen/logrus" +) + +type Client struct { + db *sql.DB + lock sync.RWMutex + debug bool + trace bool +} + +func (c *Client) Init(ctx context.Context) error { + if _, err := c.db.ExecContext(ctx, "PRAGMA foreign_keys = ON"); err != nil { + return fmt.Errorf("failed to enable db pragma: %w", err) + } + + if _, err := c.db.ExecContext(ctx, "PRAGMA journal_mode = WAL"); err != nil { + return fmt.Errorf("failed to enable db pragma: %w", err) + } + + if _, err := c.db.ExecContext(ctx, "PRAGMA journal_mode = WAL"); err != nil { + return fmt.Errorf("failed to enable db pragma: %w", err) + } + + return c.wrapTx(ctx, func(ctx context.Context, tx *sql.Tx, entry *logrus.Entry) error { + entry.Debugf("Running database migrations") + return RunMigrations(ctx, TXWrapper{tx: tx}) + }) +} + +func (c *Client) Read(ctx context.Context, op func(context.Context, db.ReadOnly) error) error { + c.lock.RLock() + defer c.lock.RUnlock() + + rdID := uuid.NewString() + + if c.debug { + logrus.Debugf("Begin Read %v", rdID) + defer logrus.Debugf("End Read %v", rdID) + } + + entry := logrus.WithField("rd", rdID) + + var qw QueryWrapper = &DBWrapper{ + db: c.db, + } + + if c.debug { + qw = &DebugQueryWrapper{ + entry: entry, + qw: qw, + } + } + + var ops db.ReadOnly = &readOps{qw: qw} + + if c.trace { + ops = &ReadTracer{rd: ops, entry: entry} + } + + if err := op(ctx, ops); err != nil { + return err + } + + return nil +} + +func (c *Client) Write(ctx context.Context, op func(context.Context, db.Transaction) error) error { + return c.wrapTx(ctx, func(ctx context.Context, tx *sql.Tx, entry *logrus.Entry) error { + + var qw QueryWrapper = &TXWrapper{ + tx: tx, + } + + if c.debug { + qw = &DebugQueryWrapper{ + qw: qw, + entry: entry, + } + } + + var transaction db.Transaction = &writeOps{ + readOps: readOps{ + qw: qw, + }, + qw: qw, + } + + if c.trace { + transaction = &WriteTracer{tx: transaction, ReadTracer: ReadTracer{rd: transaction, entry: entry}} + } + + return op(ctx, transaction) + }) +} + +func (c *Client) wrapTx(ctx context.Context, op func(context.Context, *sql.Tx, *logrus.Entry) error) error { + c.lock.Lock() + defer c.lock.Unlock() + + var entry *logrus.Entry + + if c.debug { + entry = logrus.WithField("tx", uuid.NewString()) + } else { + entry = logrus.WithField("tx", "tx") + } + + tx, err := c.db.BeginTx(ctx, nil) + if err != nil { + return err + } + + if c.debug { + entry.Debugf("Begin Transaction") + } + + defer func() { + if v := recover(); v != nil { + if c.debug { + entry.Debugf("Panic during Transaction") + } + + if err := tx.Rollback(); err != nil { + panic(fmt.Errorf("rolling back while recovering (%v): %w", v, err)) + } + + panic(v) + } + }() + + if err := op(ctx, tx, entry); err != nil { + if c.debug { + entry.Debugf("Rolling back Transaction") + } + + if rerr := tx.Rollback(); rerr != nil { + return fmt.Errorf("rolling back transaction: %w", rerr) + } + + return err + } + + if err := tx.Commit(); err != nil { + if !errors.Is(err, context.Canceled) { + reporter.MessageWithContext(ctx, + "Failed to commit database transaction", + reporter.Context{"error": err, "type": utils.ErrCause(err)}, + ) + } + + if c.debug { + entry.Debugf("Failed to commit Transaction") + } + + return fmt.Errorf("%v: %w", err, db.ErrTransactionFailed) + } + + if c.debug { + entry.Debugf("Transaction Committed") + } + + return nil +} + +func (c *Client) Close() error { + c.lock.Lock() + defer c.lock.Unlock() + + return c.db.Close() +} + +type Builder struct { + debug bool + trace bool +} + +type Option interface { + apply(builder *Builder) +} + +type dbDebugOption struct{} + +func (dbDebugOption) apply(builder *Builder) { + builder.debug = true +} + +type dbTraceOption struct{} + +func (dbTraceOption) apply(builder *Builder) { + builder.trace = true +} + +// Trace enables db interface call tracing. Name of the called functions will be written to trace log. +func Trace() Option { + return &dbTraceOption{} +} + +// Debug enables logging of the SQL queries and their values. Written to debug log. +func Debug() Option { + return &dbDebugOption{} +} + +func NewBuilder(options ...Option) db.ClientInterface { + builder := &Builder{ + debug: false, + trace: false, + } + + for _, opt := range options { + opt.apply(builder) + } + + return builder +} + +func (b Builder) New(dir string, userID string) (db.Client, bool, error) { + if err := os.MkdirAll(dir, 0o700); err != nil { + return nil, false, err + } + + path := getDatabasePath(dir, userID) + + // Check if the database already exists. + exists, err := pathExists(path) + if err != nil { + return nil, false, err + } + + client, err := sql.Open("sqlite3", getDatabaseConn(dir, userID, path)) + if err != nil { + return nil, false, err + } + + return &Client{db: client, debug: b.debug, trace: b.trace}, !exists, nil +} + +func (Builder) Delete(dir string, userID string) error { + return db.DeleteDB(dir, userID) +} + +func getDatabasePath(dir, userID string) string { + return filepath.Join(dir, fmt.Sprintf("%v.db", userID)) +} + +func pathExists(path string) (bool, error) { + if _, err := os.Stat(path); errors.Is(err, fs.ErrNotExist) { + return false, nil + } else if err != nil { + return false, err + } + + return true, nil +} + +func getDatabaseConn(dir, userID, path string) string { + return fmt.Sprintf("file:%v?cache=shared&_fk=1&_journal=WAL", path) +} diff --git a/internal/db_impl/sqlite3/migration_v0.go b/internal/db_impl/sqlite3/migration_v0.go new file mode 100644 index 00000000..1772f565 --- /dev/null +++ b/internal/db_impl/sqlite3/migration_v0.go @@ -0,0 +1,54 @@ +package sqlite3 + +import ( + "context" + "fmt" + + "github.com/bradenaw/juniper/xmaps" + "github.com/bradenaw/juniper/xslices" + "github.com/sirupsen/logrus" +) + +type MigrationV0 struct{} + +func (m MigrationV0) Run(ctx context.Context, tx TXWrapper) error { + tables := []Table{ + &DeletedSubscriptionsTable{}, + &MailboxesTable{}, + &MailboxFlagsTable{}, + &MailboxAttrTable{}, + &MailboxPermFlagsTable{}, + &MessagesTable{}, + &MessageFlagsTable{}, + &UIDsTable{}, + &GluonVersionTable{}, + } + + tablesNames := xslices.Map(tables, func(t Table) string { + return t.Name() + }) + + query := fmt.Sprintf("SELECT `name` FROM sqlite_master WHERE `type` = 'table' AND `name` NOT LIKE 'sqlite_%%' AND `name` IN (%v)", + GenSQLIn(len(tables))) + + args := MapSliceToAny(tablesNames) + + sqlTables, err := MapQueryRows[string](ctx, tx, query, args...) + if err != nil { + return err + } + + tablesSet := xmaps.SetFromSlice(sqlTables) + + for _, table := range tables { + if !tablesSet.Contains(table.Name()) { + logrus.Debugf("Table '%v' does not exist, creating", table.Name()) + + if err := table.Create(ctx, tx); err != nil { + return err + } + } + } + + return nil +} diff --git a/internal/db_impl/sqlite3/migrations.go b/internal/db_impl/sqlite3/migrations.go new file mode 100644 index 00000000..aa2d777e --- /dev/null +++ b/internal/db_impl/sqlite3/migrations.go @@ -0,0 +1,86 @@ +package sqlite3 + +import ( + "context" + "errors" + "fmt" + "github.com/ProtonMail/gluon/db" + "github.com/sirupsen/logrus" +) + +type Migration interface { + Run(ctx context.Context, tx TXWrapper) error +} + +var migrationList = []Migration{ + &MigrationV0{}, +} + +func RunMigrations(ctx context.Context, tx TXWrapper) error { + dbVersion, err := getDatabaseVersion(ctx, tx) + if err != nil { + return fmt.Errorf("failed to get db version: %w", err) + } + + if dbVersion < 0 { + logrus.Debug("Version table does not exist, running all migrations") + + for idx, m := range migrationList { + logrus.Debugf("Running migration for version %v", idx) + + if err := m.Run(ctx, tx); err != nil { + return fmt.Errorf("failed to run migration %v: %w", idx, err) + } + } + + if err := updateDBVersion(ctx, tx, len(migrationList)-1); err != nil { + return fmt.Errorf("failed to update db version:%w", err) + } + + logrus.Debug("Migrations completed") + + return nil + } + + logrus.Debugf("DB Version is %v", dbVersion) + + for i := dbVersion + 1; i < len(migrationList); i++ { + logrus.Debugf("Running migration for version %v", i) + + if err := migrationList[i].Run(ctx, tx); err != nil { + return err + } + } + + if err := updateDBVersion(ctx, tx, len(migrationList)-1); err != nil { + return fmt.Errorf("failed to update db version:%w", err) + } + + logrus.Debug("Migrations completed") + + return nil +} + +// getDatabaseVersion returns -1 if the version table does not exist or the version information contained within. +func getDatabaseVersion(ctx context.Context, tx TXWrapper) (int, error) { + query := "SELECT `name` FROM sqlite_master WHERE `type` = 'table' AND `name` NOT LIKE 'sqlite_%' AND `name` = 'gluon_version'" + + _, err := MapQueryRow[string](ctx, tx, query) + if err != nil { + if errors.Is(err, db.ErrNotFound) { + return -1, nil + } + + return 0, err + } + + versionQuery := "SELECT `version` FROM gluon_version WHERE `id` = 0" + + return MapQueryRow[int](ctx, tx, versionQuery) +} + +func updateDBVersion(ctx context.Context, tx TXWrapper, version int) error { + query := "UPDATE gluon_version SET `version` = ? WHERE `id` = 0" + + return ExecQueryAndCheckUpdatedNotZero(ctx, tx, query, version) +} diff --git a/internal/db_impl/sqlite3/query_utils.go b/internal/db_impl/sqlite3/query_utils.go new file mode 100644 index 00000000..5d4f5d3a --- /dev/null +++ b/internal/db_impl/sqlite3/query_utils.go @@ -0,0 +1,214 @@ +package sqlite3 + +import ( + "context" + "database/sql" + "errors" + "fmt" + "strings" + + "github.com/ProtonMail/gluon/db" + "github.com/bradenaw/juniper/xslices" + "github.com/sirupsen/logrus" +) + +// Collection of SQL utilities to process SQL Rows and to convert SQL errors to db.Errors. + +type RowScanner interface { + Scan(args ...any) error +} + +func MapStmtRowsFn[T any](ctx context.Context, qw StmtWrapper, m func(RowScanner) (T, error), args ...any) ([]T, error) { + rows, err := qw.QueryContext(ctx, args...) + if err != nil { + return nil, mapSQLError(err) + } + + return mapSQLRowsFn(rows, m) +} + +func MapStmtRows[T any](ctx context.Context, qw StmtWrapper, args ...any) ([]T, error) { + return MapStmtRowsFn(ctx, qw, func(scanner RowScanner) (T, error) { + var v T + + err := scanner.Scan(&v) + + return v, err + }, args...) +} + +func MapStmtRowFn[T any](ctx context.Context, qw StmtWrapper, m func(RowScanner) (T, error), args ...any) (T, error) { + rows := qw.QueryRowContext(ctx, args...) + + return mapSQLRowFn(rows, m) +} + +func MapStmtRow[T any](ctx context.Context, qw StmtWrapper, args ...any) (T, error) { + return MapStmtRowFn(ctx, qw, func(scanner RowScanner) (T, error) { + var v T + + err := scanner.Scan(&v) + + return v, err + }, args...) +} + +func MapQueryRowsFn[T any](ctx context.Context, qw QueryWrapper, query string, m func(RowScanner) (T, error), args ...any) ([]T, error) { + rows, err := qw.QueryContext(ctx, query, args...) + if err != nil { + return nil, mapSQLError(err) + } + + return mapSQLRowsFn(rows, m) +} + +func MapQueryRows[T any](ctx context.Context, qw QueryWrapper, query string, args ...any) ([]T, error) { + return MapQueryRowsFn(ctx, qw, query, func(scanner RowScanner) (T, error) { + var v T + + err := scanner.Scan(&v) + + return v, err + }, args...) +} + +func MapQueryRowFn[T any](ctx context.Context, qw QueryWrapper, query string, m func(RowScanner) (T, error), args ...any) (T, error) { + row := qw.QueryRowContext(ctx, query, args...) + + return mapSQLRowFn(row, m) +} + +func MapQueryRow[T any](ctx context.Context, qw QueryWrapper, query string, args ...any) (T, error) { + return MapQueryRowFn(ctx, qw, query, func(scanner RowScanner) (T, error) { + var v T + + err := scanner.Scan(&v) + + return v, err + }, args...) +} + +func ExecQueryAndCheckUpdatedNotZero(ctx context.Context, wrapper QueryWrapper, query string, args ...any) error { + updated, err := ExecQuery(ctx, wrapper, query, args...) + if err != nil { + return err + } + + if updated == 0 { + return fmt.Errorf("no values changed") + } + + return nil +} + +func ExecStmtAndCheckUpdatedNotZero(ctx context.Context, wrapper StmtWrapper, args ...any) error { + updated, err := ExecStmt(ctx, wrapper, args...) + if err != nil { + return err + } + + if updated == 0 { + return fmt.Errorf("no values changed") + } + + return nil +} + +func ExecQuery(ctx context.Context, wrapper QueryWrapper, query string, args ...any) (int, error) { + r, err := wrapper.ExecContext(ctx, query, args...) + if err != nil { + return 0, err + } + + affected, err := r.RowsAffected() + if err != nil { + panic("affected rows is unsupported") + } + + return int(affected), nil +} + +func ExecStmt(ctx context.Context, wrapper StmtWrapper, args ...any) (int, error) { + r, err := wrapper.ExecContext(ctx, args...) + if err != nil { + return 0, err + } + + affected, err := r.RowsAffected() + if err != nil { + panic("affected rows is unsupported") + } + + return int(affected), nil +} + +func GenSQLIn(count int) string { + if count <= 0 { + panic("count can't be less or equal to 0") + } + + if count == 1 { + return "?" + } + + return strings.Repeat("?,", count-1) + "?" +} + +func MapSliceToAny[T any](v []T) []any { + return xslices.Map(v, func(t T) any { + return t + }) +} + +func QueryExists(ctx context.Context, qw QueryWrapper, query string, args ...any) (bool, error) { + if _, err := MapQueryRow[int](ctx, qw, query, args...); err != nil { + if errors.Is(err, db.ErrNotFound) { + return false, nil + } + + return false, err + } + + return true, nil +} + +func WrapStmtClose(st StmtWrapper) { + if err := st.Close(); err != nil { + logrus.WithError(err).Error("Failed to close statement") + } +} + +func mapSQLError(err error) error { + if err == nil { + return nil + } + + if errors.Is(err, sql.ErrNoRows) { + return db.ErrNotFound + } + + return err +} + +func mapSQLRowsFn[T any](rows *sql.Rows, m func(RowScanner) (T, error)) ([]T, error) { + defer func() { _ = rows.Close() }() + + var result []T + + for rows.Next() { + val, err := m(rows) + if err != nil { + return nil, err + } + + result = append(result, val) + } + + return result, nil +} + +func mapSQLRowFn[T any](row *sql.Row, m func(scanner RowScanner) (T, error)) (T, error) { + v, err := m(row) + + return v, mapSQLError(err) +} diff --git a/internal/db_impl/sqlite3/read_ops.go b/internal/db_impl/sqlite3/read_ops.go new file mode 100644 index 00000000..7b718baa --- /dev/null +++ b/internal/db_impl/sqlite3/read_ops.go @@ -0,0 +1,598 @@ +package sqlite3 + +import ( + "context" + "database/sql" + "fmt" + "time" + + "github.com/ProtonMail/gluon/db" + "github.com/ProtonMail/gluon/imap" + v0 "github.com/ProtonMail/gluon/internal/db_impl/sqlite3/v0" + "github.com/bradenaw/juniper/xmaps" + "github.com/bradenaw/juniper/xslices" +) + +type readOps struct { + qw QueryWrapper +} + +func (r readOps) MailboxExistsWithID(ctx context.Context, mboxID imap.InternalMailboxID) (bool, error) { + query := fmt.Sprintf("SELEC 1 FROM %[1]v WHERE `%[2]v` = ? LIMIT 1", + v0.MailboxesTableName, + v0.MailboxesFieldID, + ) + + return QueryExists(ctx, r.qw, query, mboxID) +} + +func (r readOps) MailboxExistsWithRemoteID(ctx context.Context, mboxID imap.MailboxID) (bool, error) { + query := fmt.Sprintf("SELECT 1 FROM %[1]v WHERE `%[2]v` = ? LIMIT 1", + v0.MailboxesTableName, + v0.MailboxesFieldRemoteID, + v0.MessagesFieldID, + ) + + return QueryExists(ctx, r.qw, query, mboxID) +} + +func (r readOps) MailboxExistsWithName(ctx context.Context, name string) (bool, error) { + query := fmt.Sprintf("SELECT 1 FROM %[1]v WHERE `%[2]v` = ? LIMIT 1", + v0.MailboxesTableName, + v0.MailboxesFieldName, + ) + + return QueryExists(ctx, r.qw, query, name) +} + +func (r readOps) GetMailboxIDFromRemoteID(ctx context.Context, mboxID imap.MailboxID) (imap.InternalMailboxID, error) { + query := fmt.Sprintf("SELECT `%[2]v` FROM %[1]v WHERE `%[3]v` = ?", + v0.MailboxesTableName, + v0.MailboxesFieldID, + v0.MailboxesFieldRemoteID, + ) + + return MapQueryRow[imap.InternalMailboxID](ctx, r.qw, query, mboxID) +} + +func (r readOps) GetMailboxName(ctx context.Context, mboxID imap.InternalMailboxID) (string, error) { + query := fmt.Sprintf("SELECT `%[2]v` FROM %[1]v WHERE `%[3]v` = ?", + v0.MailboxesTableName, + v0.MailboxesFieldName, + v0.MailboxesFieldID, + ) + + return MapQueryRow[string](ctx, r.qw, query, mboxID) +} + +func (r readOps) GetMailboxNameWithRemoteID(ctx context.Context, mboxID imap.MailboxID) (string, error) { + query := fmt.Sprintf("SELECT `%[2]v` FROM %[1]v WHERE `%[3]v` = ?", + v0.MailboxesTableName, + v0.MailboxesFieldName, + v0.MailboxesFieldRemoteID, + ) + + return MapQueryRow[string](ctx, r.qw, query, mboxID) +} + +func (r readOps) GetMailboxMessageIDPairs(ctx context.Context, mboxID imap.InternalMailboxID) ([]db.MessageIDPair, error) { + query := fmt.Sprintf("SELECT `%[2]v`, `%[3]v` FROM %[1]v WHERE `%[1]v`.`%[2]v` IN (SELECT `%[4]v`.`%[5]v` FROM %[4]v WHERE `%[4]v`.`%[6]v` = ?)", + v0.MessagesTableName, + v0.MessagesFieldID, + v0.MessagesFieldRemoteID, + v0.UIDsTableName, + v0.UIDsFieldMessageID, + v0.UIDsFieldMailboxID, + ) + + return MapQueryRowsFn(ctx, r.qw, query, func(scanner RowScanner) (db.MessageIDPair, error) { + var id db.MessageIDPair + + if err := scanner.Scan(&id.InternalID, &id.RemoteID); err != nil { + return db.MessageIDPair{}, err + } + + return id, nil + }, mboxID) +} + +func (r readOps) GetAllMailboxesWithAttr(ctx context.Context) ([]*db.Mailbox, error) { + query := fmt.Sprintf("SELECT * FROM %v", v0.MailboxesTableName) + + mailboxes, err := MapQueryRowsFn(ctx, r.qw, query, ScanMailbox) + if err != nil { + return nil, err + } + + attrQuery := fmt.Sprintf("SELECT `%v` FROM %v WHERE `%v` = ?", + v0.MailboxAttrsFieldValue, + v0.MailboxAttrsTableName, + v0.MailboxAttrsFieldMailboxID, + ) + + stmt, err := r.qw.PrepareStatement(ctx, attrQuery) + if err != nil { + return nil, err + } + + defer WrapStmtClose(stmt) + + for _, mbox := range mailboxes { + attrs, err := MapStmtRows[string](ctx, stmt, mbox.ID) + if err != nil { + return nil, err + } + + mbox.Attributes = xslices.Map(attrs, func(t string) *db.MailboxAttr { + return &db.MailboxAttr{Value: t} + }) + } + + return mailboxes, nil +} + +func (r readOps) GetAllMailboxesAsRemoteIDs(ctx context.Context) ([]imap.MailboxID, error) { + query := fmt.Sprintf("SELECT `%v` FROM %v", v0.MessagesFieldRemoteID, v0.MailboxesTableName) + + return MapQueryRows[imap.MailboxID](ctx, r.qw, query) +} + +func (r readOps) GetMailboxByName(ctx context.Context, name string) (*db.Mailbox, error) { + query := fmt.Sprintf("SELECT * FROM %v WHERE `%v` = ?", v0.MailboxesTableName, v0.MailboxesFieldName) + + return MapQueryRowFn(ctx, r.qw, query, ScanMailbox, name) +} + +func (r readOps) GetMailboxByID(ctx context.Context, mboxID imap.InternalMailboxID) (*db.Mailbox, error) { + query := fmt.Sprintf("SELECT * FROM %v WHERE `%v` = ?", v0.MailboxesTableName, v0.MailboxesFieldID) + + return MapQueryRowFn(ctx, r.qw, query, ScanMailbox, mboxID) +} + +func (r readOps) GetMailboxByRemoteID(ctx context.Context, mboxID imap.MailboxID) (*db.Mailbox, error) { + query := fmt.Sprintf("SELECT * FROM %v WHERE `%v` = ?", v0.MailboxesTableName, v0.MailboxesFieldRemoteID) + + return MapQueryRowFn(ctx, r.qw, query, ScanMailbox, mboxID) +} + +func (r readOps) GetMailboxRecentCount(ctx context.Context, mboxID imap.InternalMailboxID) (int, error) { + query := fmt.Sprintf("SELECT COUNT(*) FROM %v WHERE `%v` = TRUE AND `%v` = ?", + v0.UIDsTableName, + v0.UIDsFieldRecent, + v0.UIDsFieldMailboxID, + ) + + return MapQueryRow[int](ctx, r.qw, query, mboxID) +} + +func (r readOps) GetMailboxMessageCount(ctx context.Context, mboxID imap.InternalMailboxID) (int, error) { + query := fmt.Sprintf("SELECT COUNT(*) FROM %v WHERE `%v` = ?", + v0.UIDsTableName, + v0.UIDsFieldMailboxID, + ) + + return MapQueryRow[int](ctx, r.qw, query, mboxID) +} + +func (r readOps) GetMailboxMessageCountWithRemoteID(ctx context.Context, mboxID imap.MailboxID) (int, error) { + internalID, err := r.GetMailboxIDFromRemoteID(ctx, mboxID) + if err != nil { + return 0, err + } + + query := fmt.Sprintf("SELECT COUNT(*) FROM %v WHERE `%v` = ?", + v0.UIDsTableName, + v0.UIDsFieldMailboxID, + ) + + return MapQueryRow[int](ctx, r.qw, query, internalID) +} + +func (r readOps) GetMailboxFlags(ctx context.Context, mboxID imap.InternalMailboxID) (imap.FlagSet, error) { + query := fmt.Sprintf("SELECT `%v` FROM %v WHERE `%v` = ?", + v0.MailboxFlagsFieldValue, + v0.MailboxFlagsTableName, + v0.MailboxFlagsFieldMailboxID, + ) + + flags, err := MapQueryRows[string](ctx, r.qw, query, mboxID) + if err != nil { + return imap.FlagSet{}, err + } + + return imap.NewFlagSetFromSlice(flags), nil +} + +func (r readOps) GetMailboxPermanentFlags(ctx context.Context, mboxID imap.InternalMailboxID) (imap.FlagSet, error) { + query := fmt.Sprintf("SELECT `%v` FROM %v WHERE `%v` = ?", + v0.MailboxPermFlagsFieldValue, + v0.MailboxPermFlagsTableName, + v0.MailboxPermFlagsFieldMailboxID, + ) + + flags, err := MapQueryRows[string](ctx, r.qw, query, mboxID) + if err != nil { + return imap.FlagSet{}, err + } + + return imap.NewFlagSetFromSlice(flags), nil +} + +func (r readOps) GetMailboxAttributes(ctx context.Context, mboxID imap.InternalMailboxID) (imap.FlagSet, error) { + query := fmt.Sprintf("SELECT `%v` FROM %v WHERE `%v` = ?", + v0.MailboxAttrsFieldValue, + v0.MailboxAttrsTableName, + v0.MailboxAttrsFieldMailboxID, + ) + + flags, err := MapQueryRows[string](ctx, r.qw, query, mboxID) + if err != nil { + return imap.FlagSet{}, err + } + + return imap.NewFlagSetFromSlice(flags), nil +} + +func (r readOps) GetMailboxUID(ctx context.Context, mboxID imap.InternalMailboxID) (imap.UID, error) { + query := fmt.Sprintf("SELECT `%v` FROM %v WHERE `%v` = ? ", + v0.MailboxesFieldUIDNext, + v0.MailboxesTableName, + v0.MailboxesFieldID, + ) + + return MapQueryRow[imap.UID](ctx, r.qw, query, mboxID) +} + +func (r readOps) GetMailboxMessageCountAndUID(ctx context.Context, mboxID imap.InternalMailboxID) (int, imap.UID, error) { + count, err := r.GetMailboxMessageCount(ctx, mboxID) + if err != nil { + return 0, 0, err + } + + uid, err := r.GetMailboxUID(ctx, mboxID) + if err != nil { + return 0, 0, err + } + + return count, uid, nil +} + +func (r readOps) GetMailboxMessageForNewSnapshot(ctx context.Context, mboxID imap.InternalMailboxID) ([]db.SnapshotMessageResult, error) { + query := "SELECT `t1`.`remote_id`, GROUP_CONCAT(`t2`.`value`) AS `flags`, `ui_ds`.`recent`, `ui_ds`.`deleted`, `ui_ds`.`uid`, `ui_ds`.`uid_message` FROM `ui_ds`" + + " JOIN `messages` AS `t1` ON `ui_ds`.`uid_message` = `t1`.`id`" + + " LEFT JOIN `message_flags` AS `t2` ON `ui_ds`.`uid_message` = `t2`.`message_flags` WHERE `mailbox_ui_ds` = ?" + + " GROUP BY `ui_ds`.`uid_message` ORDER BY `ui_ds`.`uid`" + + return MapQueryRowsFn(ctx, r.qw, query, func(scanner RowScanner) (db.SnapshotMessageResult, error) { + var r db.SnapshotMessageResult + var flags sql.NullString + + if err := scanner.Scan(&r.RemoteID, &flags, &r.Recent, &r.Deleted, &r.UID, &r.InternalID); err != nil { + return db.SnapshotMessageResult{}, err + } + + r.Flags = flags.String + + return r, nil + }, mboxID) +} + +func (r readOps) MailboxTranslateRemoteIDs(ctx context.Context, mboxIDs []imap.MailboxID) ([]imap.InternalMailboxID, error) { + result := make([]imap.InternalMailboxID, 0, len(mboxIDs)) + + for _, chunk := range xslices.Chunk(mboxIDs, db.ChunkLimit) { + query := fmt.Sprintf("SELECT `%v` FROM %v WHERE `%v` IN (%v)", + v0.MailboxesFieldID, + v0.MailboxesTableName, + v0.MailboxesFieldRemoteID, + GenSQLIn(len(chunk)), + ) + + r, err := MapQueryRows[imap.InternalMailboxID](ctx, r.qw, query, MapSliceToAny(chunk)...) + if err != nil { + return nil, err + } + + result = append(result, r...) + } + + return result, nil +} + +func (r readOps) MailboxFilterContains(ctx context.Context, mboxID imap.InternalMailboxID, messageIDs []db.MessageIDPair) ([]imap.InternalMessageID, error) { + return r.MailboxFilterContainsInternalID(ctx, mboxID, xslices.Map(messageIDs, func(t db.MessageIDPair) imap.InternalMessageID { + return t.InternalID + })) +} + +func (r readOps) MailboxFilterContainsInternalID(ctx context.Context, mboxID imap.InternalMailboxID, messageIDs []imap.InternalMessageID) ([]imap.InternalMessageID, error) { + result := make([]imap.InternalMessageID, 0, len(messageIDs)) + + for _, chunk := range xslices.Chunk(messageIDs, db.ChunkLimit) { + query := fmt.Sprintf("SELECT `%v` FROM %v WHERE `%v` IN (%v) AND `%v` = ?", + v0.UIDsFieldMessageID, + v0.UIDsTableName, + v0.UIDsFieldMessageID, + GenSQLIn(len(chunk)), + v0.UIDsFieldMailboxID, + ) + + r, err := MapQueryRows[imap.InternalMessageID](ctx, r.qw, query, append(MapSliceToAny(chunk), mboxID)...) + if err != nil { + return nil, err + } + + result = append(result, r...) + } + + return result, nil +} + +func (r readOps) GetMailboxCount(ctx context.Context) (int, error) { + query := fmt.Sprintf("SELECT COUNT(*) FROM %v", v0.MailboxesTableName) + + return MapQueryRow[int](ctx, r.qw, query) +} + +func (r readOps) GetMailboxMessageUIDsWithFlagsAfterAddOrUIDBump(ctx context.Context, mboxID imap.InternalMailboxID, messageIDs []imap.InternalMessageID) ([]db.UIDWithFlags, error) { + result := make([]db.UIDWithFlags, 0, len(messageIDs)) + + for _, chunk := range xslices.Chunk(messageIDs, db.ChunkLimit) { + query := fmt.Sprintf("SELECT `t1`.`remote_id`, GROUP_CONCAT(`t2`.`value`) AS `flags`, `ui_ds`.`recent`, `ui_ds`.`deleted`, `ui_ds`.`uid`, `ui_ds`.`uid_message` FROM `ui_ds`"+ + " JOIN `messages` AS `t1` ON `ui_ds`.`uid_message` = `t1`.`id`"+ + " LEFT JOIN `message_flags` AS `t2` ON `ui_ds`.`uid_message` = `t2`.`message_flags` WHERE `mailbox_ui_ds` = ? AND `uid_message` in (%v)"+ + " GROUP BY `ui_ds`.`uid_message` ORDER BY `ui_ds`.`uid`", + GenSQLIn(len(chunk))) + + args := make([]any, 0, len(chunk)+1) + args = append(args, mboxID) + args = append(args, MapSliceToAny(chunk)...) + + r, err := MapQueryRowsFn(ctx, r.qw, query, func(scanner RowScanner) (db.UIDWithFlags, error) { + var r db.UIDWithFlags + var flags sql.NullString + + if err := scanner.Scan(&r.RemoteID, &flags, &r.Recent, &r.Deleted, &r.UID, &r.InternalID); err != nil { + return db.UIDWithFlags{}, err + } + + r.Flags = flags.String + + return r, nil + }, args...) + if err != nil { + return nil, err + } + + result = append(result, r...) + } + + return result, nil +} + +func (r readOps) MessageExists(ctx context.Context, id imap.InternalMessageID) (bool, error) { + query := fmt.Sprintf("SELECT 1 FROM %v WHERE `%v` = ? LIMIT 1", v0.MessagesTableName, v0.MessagesFieldID) + + return QueryExists(ctx, r.qw, query, id) +} + +func (r readOps) MessageExistsWithRemoteID(ctx context.Context, id imap.MessageID) (bool, error) { + query := fmt.Sprintf("SELECT 1 FROM %v WHERE `%v` = ? LIMIT 1", v0.MessagesTableName, v0.MessagesFieldRemoteID) + + return QueryExists(ctx, r.qw, query, id) +} + +func (r readOps) GetMessageNoEdges(ctx context.Context, id imap.InternalMessageID) (*db.Message, error) { + query := fmt.Sprintf("SELECT * FROM %v WHERE `%v` = ?", v0.MessagesTableName, v0.MessagesFieldID) + + return MapQueryRowFn(ctx, r.qw, query, ScanMessage, id) +} + +func (r readOps) GetTotalMessageCount(ctx context.Context) (int, error) { + query := fmt.Sprintf("SELECT COUNT(*) FROM %v", v0.MessagesTableName) + + return MapQueryRow[int](ctx, r.qw, query) +} + +func (r readOps) GetMessageRemoteID(ctx context.Context, id imap.InternalMessageID) (imap.MessageID, error) { + query := fmt.Sprintf("SELECT `%v` FROM %v WHERE `%v` = ?", v0.MessagesFieldRemoteID, v0.MessagesTableName, v0.MessagesFieldID) + + return MapQueryRow[imap.MessageID](ctx, r.qw, query, id) +} + +func (r readOps) GetImportedMessageData(ctx context.Context, id imap.InternalMessageID) (*db.Message, error) { + flagsQuery := fmt.Sprintf("SELECT `%v` FROM %v WHERE `%v` = ?", + v0.MessageFlagsFieldValue, + v0.MessageFlagsTableName, + v0.MessageFlagsFieldMessageID, + ) + + messageQuery := fmt.Sprintf("SELECT * FROM %v WHERE `%v` = ?", + v0.MessagesTableName, + v0.MessagesFieldID, + ) + + msg, err := MapQueryRowFn(ctx, r.qw, messageQuery, ScanMessage, id) + if err != nil { + return nil, err + } + + flags, err := MapQueryRowsFn(ctx, r.qw, flagsQuery, func(scanner RowScanner) (*db.MessageFlag, error) { + mf := new(db.MessageFlag) + + if err := scanner.Scan(&mf.Value); err != nil { + return nil, err + } + + return mf, nil + }, id) + if err != nil { + return nil, err + } + + msg.Flags = flags + + return msg, nil +} + +func (r readOps) GetMessageDateAndSize(ctx context.Context, id imap.InternalMessageID) (time.Time, int, error) { + query := fmt.Sprintf("SELECT `%v`, `%v` FROM %v WHERE `%v` =?", + v0.MessagesFieldDate, + v0.MessagesFieldSize, + v0.MessagesTableName, + v0.MessagesFieldID, + ) + + type DateSize struct { + Date time.Time + Size int + } + + dt, err := MapQueryRowFn(ctx, r.qw, query, func(scanner RowScanner) (DateSize, error) { + var dt DateSize + + if err := scanner.Scan(&dt.Date, &dt.Size); err != nil { + return DateSize{}, err + } + + return dt, nil + }, id) + if err != nil { + return time.Time{}, 0, err + } + + return dt.Date, dt.Size, nil +} + +func (r readOps) GetMessageMailboxIDs(ctx context.Context, id imap.InternalMessageID) ([]imap.InternalMailboxID, error) { + query := fmt.Sprintf("SELECT `%v` FROM %v WHERE `%v` = ?", + v0.UIDsFieldMailboxID, + v0.UIDsTableName, + v0.UIDsFieldMessageID, + ) + + return MapQueryRows[imap.InternalMailboxID](ctx, r.qw, query, id) +} + +func (r readOps) GetMessagesFlags(ctx context.Context, ids []imap.InternalMessageID) ([]db.MessageFlagSet, error) { + var result = make([]db.MessageFlagSet, 0, len(ids)) + + flagQuery := fmt.Sprintf("SELECT `%v` FROM %v WHERE `%v` = ?", + v0.MessageFlagsFieldValue, + v0.MessageFlagsTableName, + v0.MessageFlagsFieldMessageID, + ) + + remoteIDQuery := fmt.Sprintf("SELECT `%v` FROM %v WHERE `%v` = ?", + v0.MessagesFieldRemoteID, + v0.MessagesTableName, + v0.MessagesFieldID, + ) + + flagStmt, err := r.qw.PrepareStatement(ctx, flagQuery) + if err != nil { + return nil, err + } + + defer WrapStmtClose(flagStmt) + + remoteIDStmt, err := r.qw.PrepareStatement(ctx, remoteIDQuery) + if err != nil { + return nil, err + } + + defer WrapStmtClose(remoteIDStmt) + + // GODT:2522 - Would SELECT GROUP BY id and then reconstructing the flag list over that be faster? + // GODT:2522 - Store remote ID in message flags + + for _, id := range ids { + flags, err := MapStmtRows[string](ctx, flagStmt, id) + if err != nil { + return nil, err + } + + remoteID, err := MapStmtRow[imap.MessageID](ctx, remoteIDStmt, id) + if err != nil { + return nil, err + } + + result = append(result, db.MessageFlagSet{ + ID: id, + RemoteID: remoteID, + FlagSet: imap.NewFlagSetFromSlice(flags), + }) + } + + return result, nil +} + +func (r readOps) GetMessageIDsMarkedAsDelete(ctx context.Context) ([]imap.InternalMessageID, error) { + query := fmt.Sprintf("SELECT `%v` FROM %v WHERE `%v` = TRUE", + v0.MessagesFieldID, + v0.MessagesTableName, + v0.MessagesFieldDeleted, + ) + + return MapQueryRows[imap.InternalMessageID](ctx, r.qw, query) +} + +func (r readOps) GetMessageIDFromRemoteID(ctx context.Context, id imap.MessageID) (imap.InternalMessageID, error) { + query := fmt.Sprintf("SELECT `%v` FROM %v WHERE `%v` = ?", + v0.MessagesFieldID, + v0.MessagesTableName, + v0.MessagesFieldRemoteID, + ) + + return MapQueryRow[imap.InternalMessageID](ctx, r.qw, query, id) +} + +func (r readOps) GetMessageDeletedFlag(ctx context.Context, id imap.InternalMessageID) (bool, error) { + query := fmt.Sprintf("SELECT `%v` FROM %v WHERE `%v` = ?", + v0.MessagesFieldDeleted, + v0.MessagesTableName, + v0.MessagesFieldID, + ) + + return MapQueryRow[bool](ctx, r.qw, query, id) +} + +func (r readOps) GetAllMessagesIDsAsMap(ctx context.Context) (map[imap.InternalMessageID]struct{}, error) { + query := fmt.Sprintf("SELECT `%v` FROM %v", v0.MessagesFieldID, v0.MessagesTableName) + + ids, err := MapQueryRows[imap.InternalMessageID](ctx, r.qw, query) + if err != nil { + return nil, err + } + + return xmaps.SetFromSlice(ids), nil +} + +func (r readOps) GetDeletedSubscriptionSet(ctx context.Context) (map[imap.MailboxID]*db.DeletedSubscription, error) { + query := fmt.Sprintf("SELECT `%v`, `%v` FROM %v", + v0.DeletedSubscriptionsFieldName, + v0.DeletedSubscriptionsFieldRemoteID, + v0.DeletedSubscriptionsTableName, + ) + + deletedSubscriptions, err := MapQueryRowsFn(ctx, r.qw, query, func(scanner RowScanner) (*db.DeletedSubscription, error) { + ds := new(db.DeletedSubscription) + + if err := scanner.Scan(&ds.Name, &ds.RemoteID); err != nil { + return nil, err + } + + return ds, nil + }) + if err != nil { + return nil, err + } + + result := make(map[imap.MailboxID]*db.DeletedSubscription, len(deletedSubscriptions)) + + for _, v := range deletedSubscriptions { + result[v.RemoteID] = v + } + + return result, nil +} diff --git a/internal/db_impl/sqlite3/tables.go b/internal/db_impl/sqlite3/tables.go new file mode 100644 index 00000000..25213e09 --- /dev/null +++ b/internal/db_impl/sqlite3/tables.go @@ -0,0 +1,160 @@ +package sqlite3 + +import ( + "context" +) + +type Table interface { + Name() string + Create(ctx context.Context, tx TXWrapper) error +} + +func execQueries(ctx context.Context, tx TXWrapper, queries []string) error { + for _, q := range queries { + if _, err := ExecQuery(ctx, tx, q); err != nil { + return err + } + } + + return nil +} + +type DeletedSubscriptionsTable struct{} + +func (d DeletedSubscriptionsTable) Name() string { + return "deleted_subscriptions" +} + +func (d DeletedSubscriptionsTable) Create(ctx context.Context, tx TXWrapper) error { + queries := []string{ + "CREATE TABLE `deleted_subscriptions` (`id` integer NOT NULL PRIMARY KEY AUTOINCREMENT, `name` text NOT NULL, `remote_id` text NOT NULL)", + "CREATE UNIQUE INDEX `deleted_subscriptions_name_key` ON `deleted_subscriptions` (`name`)", + "CREATE UNIQUE INDEX `deleted_subscriptions_remote_id_key` ON `deleted_subscriptions` (`remote_id`)", + "CREATE INDEX `deletedsubscription_remote_id` ON `deleted_subscriptions` (`remote_id`)", + "CREATE INDEX `deletedsubscription_name` ON `deleted_subscriptions` (`name`)", + } + + return execQueries(ctx, tx, queries) +} + +type MailboxesTable struct{} + +func (m MailboxesTable) Name() string { + return "mailboxes" +} + +func (m MailboxesTable) Create(ctx context.Context, tx TXWrapper) error { + queries := []string{ + "CREATE TABLE `mailboxes` (`id` integer NOT NULL PRIMARY KEY AUTOINCREMENT, `remote_id` text NULL, `name` text NOT NULL, `uid_next` integer NOT NULL DEFAULT 1, `uid_validity` integer NOT NULL DEFAULT 1, `subscribed` bool NOT NULL DEFAULT true)", + "CREATE UNIQUE INDEX `mailboxes_remote_id_key` ON `mailboxes` (`remote_id`)", + "CREATE UNIQUE INDEX `mailboxes_name_key` ON `mailboxes` (`name`)", + "CREATE INDEX `mailbox_id` ON `mailboxes` (`id`)", + "CREATE INDEX `mailbox_remote_id` ON `mailboxes` (`remote_id`)", + "CREATE INDEX `mailbox_name` ON `mailboxes` (`name`)", + } + + return execQueries(ctx, tx, queries) +} + +type MailboxAttrTable struct{} + +func (m MailboxAttrTable) Name() string { + return "mailbox_attrs" +} + +func (m MailboxAttrTable) Create(ctx context.Context, tx TXWrapper) error { + queries := []string{ + "CREATE TABLE `mailbox_attrs` (`id` integer NOT NULL PRIMARY KEY AUTOINCREMENT, `value` text NOT NULL, `mailbox_attributes` integer NULL, CONSTRAINT `mailbox_attrs_mailboxes_attributes` FOREIGN KEY (`mailbox_attributes`) REFERENCES `mailboxes` (`id`) ON DELETE CASCADE)", + } + + return execQueries(ctx, tx, queries) +} + +type MailboxFlagsTable struct{} + +func (m MailboxFlagsTable) Name() string { + return "mailbox_flags" +} + +func (m MailboxFlagsTable) Create(ctx context.Context, tx TXWrapper) error { + queries := []string{ + "CREATE TABLE `mailbox_flags` (`id` integer NOT NULL PRIMARY KEY AUTOINCREMENT, `value` text NOT NULL, `mailbox_flags` integer NULL, CONSTRAINT `mailbox_flags_mailboxes_flags` FOREIGN KEY (`mailbox_flags`) REFERENCES `mailboxes` (`id`) ON DELETE CASCADE)", + } + + return execQueries(ctx, tx, queries) +} + +type MailboxPermFlagsTable struct{} + +func (m MailboxPermFlagsTable) Name() string { + return "mailbox_perm_flags" +} + +func (m MailboxPermFlagsTable) Create(ctx context.Context, tx TXWrapper) error { + queries := []string{ + "CREATE TABLE `mailbox_perm_flags` (`id` integer NOT NULL PRIMARY KEY AUTOINCREMENT, `value` text NOT NULL, `mailbox_permanent_flags` integer NULL, CONSTRAINT `mailbox_perm_flags_mailboxes_permanent_flags` FOREIGN KEY (`mailbox_permanent_flags`) REFERENCES `mailboxes` (`id`) ON DELETE CASCADE)", + } + + return execQueries(ctx, tx, queries) +} + +type MessagesTable struct{} + +func (m MessagesTable) Name() string { + return "messages" +} + +func (m MessagesTable) Create(ctx context.Context, tx TXWrapper) error { + queries := []string{ + "CREATE TABLE `messages` (`id` uuid NOT NULL, `remote_id` text NULL, `date` datetime NOT NULL, `size` integer NOT NULL, `body` text NOT NULL, `body_structure` text NOT NULL, `envelope` text NOT NULL, `deleted` bool NOT NULL DEFAULT false, PRIMARY KEY (`id`))", + "CREATE UNIQUE INDEX `messages_remote_id_key` ON `messages` (`remote_id`)", + "CREATE INDEX `message_id` ON `messages` (`id`)", + "CREATE INDEX `message_remote_id` ON `messages` (`remote_id`)", + } + + return execQueries(ctx, tx, queries) +} + +type MessageFlagsTable struct{} + +func (m MessageFlagsTable) Name() string { + return "message_flags" +} + +func (m MessageFlagsTable) Create(ctx context.Context, tx TXWrapper) error { + queries := []string{ + "CREATE TABLE `message_flags` (`id` integer NOT NULL PRIMARY KEY AUTOINCREMENT, `value` text NOT NULL, `message_flags` uuid NULL, CONSTRAINT `message_flags_messages_flags` FOREIGN KEY (`message_flags`) REFERENCES `messages` (`id`) ON DELETE CASCADE)", + } + + return execQueries(ctx, tx, queries) +} + +type UIDsTable struct{} + +func (U UIDsTable) Name() string { + return "ui_ds" +} + +func (U UIDsTable) Create(ctx context.Context, tx TXWrapper) error { + queries := []string{ + "CREATE TABLE `ui_ds` (`id` integer NOT NULL PRIMARY KEY AUTOINCREMENT, `uid` integer NOT NULL, `deleted` bool NOT NULL DEFAULT false, `recent` bool NOT NULL DEFAULT true, `mailbox_ui_ds` integer NULL, `uid_message` uuid NULL, CONSTRAINT `ui_ds_mailboxes_UIDs` FOREIGN KEY (`mailbox_ui_ds`) REFERENCES `mailboxes` (`id`) ON DELETE CASCADE, CONSTRAINT `ui_ds_messages_message` FOREIGN KEY (`uid_message`) REFERENCES `messages` (`id`) ON DELETE SET NULL)", + "CREATE INDEX `uid_uid_uid_message` ON `ui_ds` (`uid`, `uid_message`)", + } + + return execQueries(ctx, tx, queries) +} + +type GluonVersionTable struct{} + +func (g GluonVersionTable) Name() string { + return "gluon_version" +} + +func (g GluonVersionTable) Create(ctx context.Context, tx TXWrapper) error { + queries := []string{ + "CREATE TABLE `gluon_version` (`id` integer NOT NULL PRIMARY KEY CHECK(`id` =0), `version` integer NOT NULL)", + "INSERT INTO gluon_version (`id`, `version`) VALUES (0,0)", + } + + return execQueries(ctx, tx, queries) +} diff --git a/internal/db_impl/sqlite3/tracer.go b/internal/db_impl/sqlite3/tracer.go new file mode 100644 index 00000000..015741bc --- /dev/null +++ b/internal/db_impl/sqlite3/tracer.go @@ -0,0 +1,472 @@ +package sqlite3 + +import ( + "context" + "time" + + "github.com/ProtonMail/gluon/db" + "github.com/ProtonMail/gluon/imap" + "github.com/sirupsen/logrus" +) + +// ReadTracer prints all method names to a trace log. +type ReadTracer struct { + rd db.ReadOnly + entry *logrus.Entry +} + +func (r ReadTracer) MailboxExistsWithID(ctx context.Context, mboxID imap.InternalMailboxID) (bool, error) { + r.entry.Tracef("MailboxExistsWithID") + + return r.rd.MailboxExistsWithID(ctx, mboxID) +} + +func (r ReadTracer) MailboxExistsWithRemoteID(ctx context.Context, mboxID imap.MailboxID) (bool, error) { + r.entry.Tracef("MailboxExistsWithRemoteID") + + return r.rd.MailboxExistsWithRemoteID(ctx, mboxID) +} + +func (r ReadTracer) MailboxExistsWithName(ctx context.Context, name string) (bool, error) { + r.entry.Tracef("MailboxExistsWithName") + + return r.rd.MailboxExistsWithName(ctx, name) +} + +func (r ReadTracer) GetMailboxIDFromRemoteID(ctx context.Context, mboxID imap.MailboxID) (imap.InternalMailboxID, error) { + r.entry.Tracef("GetMailboxIDFromRemoteID") + + return r.rd.GetMailboxIDFromRemoteID(ctx, mboxID) +} + +func (r ReadTracer) GetMailboxName(ctx context.Context, mboxID imap.InternalMailboxID) (string, error) { + r.entry.Tracef("GetMailboxName") + + return r.rd.GetMailboxName(ctx, mboxID) +} + +func (r ReadTracer) GetMailboxNameWithRemoteID(ctx context.Context, mboxID imap.MailboxID) (string, error) { + r.entry.Tracef("GetMailboxNameWithRemoteID") + + return r.rd.GetMailboxNameWithRemoteID(ctx, mboxID) +} + +func (r ReadTracer) GetMailboxMessageIDPairs(ctx context.Context, mboxID imap.InternalMailboxID) ([]db.MessageIDPair, error) { + r.entry.Tracef("GetMailboxMessageIDPairs") + + return r.rd.GetMailboxMessageIDPairs(ctx, mboxID) +} + +func (r ReadTracer) GetAllMailboxesWithAttr(ctx context.Context) ([]*db.Mailbox, error) { + r.entry.Tracef("GetAllMailboxesWithAttr") + + return r.rd.GetAllMailboxesWithAttr(ctx) +} + +func (r ReadTracer) GetAllMailboxesAsRemoteIDs(ctx context.Context) ([]imap.MailboxID, error) { + r.entry.Tracef("GetAllMailboxesAsRemoteIDs") + + return r.rd.GetAllMailboxesAsRemoteIDs(ctx) +} + +func (r ReadTracer) GetMailboxByName(ctx context.Context, name string) (*db.Mailbox, error) { + r.entry.Tracef("GetMailboxByName") + + return r.rd.GetMailboxByName(ctx, name) +} + +func (r ReadTracer) GetMailboxByID(ctx context.Context, mboxID imap.InternalMailboxID) (*db.Mailbox, error) { + r.entry.Tracef("GetMailboxByID") + + return r.rd.GetMailboxByID(ctx, mboxID) +} + +func (r ReadTracer) GetMailboxByRemoteID(ctx context.Context, mboxID imap.MailboxID) (*db.Mailbox, error) { + r.entry.Tracef("GetMailboxByRemoteID") + + return r.rd.GetMailboxByRemoteID(ctx, mboxID) +} + +func (r ReadTracer) GetMailboxRecentCount(ctx context.Context, mboxID imap.InternalMailboxID) (int, error) { + r.entry.Tracef("GetMailboxRecentCount") + + return r.rd.GetMailboxRecentCount(ctx, mboxID) +} + +func (r ReadTracer) GetMailboxMessageCount(ctx context.Context, mboxID imap.InternalMailboxID) (int, error) { + r.entry.Tracef("GetMailboxMessageCount") + + return r.rd.GetMailboxMessageCount(ctx, mboxID) +} + +func (r ReadTracer) GetMailboxMessageCountWithRemoteID(ctx context.Context, mboxID imap.MailboxID) (int, error) { + r.entry.Tracef("GetMailboxMessageCountWithRemoteID") + + return r.rd.GetMailboxMessageCountWithRemoteID(ctx, mboxID) +} + +func (r ReadTracer) GetMailboxFlags(ctx context.Context, mboxID imap.InternalMailboxID) (imap.FlagSet, error) { + r.entry.Tracef("GetMailboxFlags") + + return r.rd.GetMailboxFlags(ctx, mboxID) +} + +func (r ReadTracer) GetMailboxPermanentFlags(ctx context.Context, mboxID imap.InternalMailboxID) (imap.FlagSet, error) { + r.entry.Tracef("GetMailboxPermanentFlags") + + return r.rd.GetMailboxPermanentFlags(ctx, mboxID) +} + +func (r ReadTracer) GetMailboxAttributes(ctx context.Context, mboxID imap.InternalMailboxID) (imap.FlagSet, error) { + r.entry.Tracef("GetMailboxAttributes") + + return r.rd.GetMailboxAttributes(ctx, mboxID) +} + +func (r ReadTracer) GetMailboxUID(ctx context.Context, mboxID imap.InternalMailboxID) (imap.UID, error) { + r.entry.Tracef("GetMailboxUID") + + return r.rd.GetMailboxUID(ctx, mboxID) +} + +func (r ReadTracer) GetMailboxMessageCountAndUID(ctx context.Context, mboxID imap.InternalMailboxID) (int, imap.UID, error) { + r.entry.Tracef("GetMailboxMessageCountAndUID") + + return r.rd.GetMailboxMessageCountAndUID(ctx, mboxID) +} + +func (r ReadTracer) GetMailboxMessageForNewSnapshot(ctx context.Context, mboxID imap.InternalMailboxID) ([]db.SnapshotMessageResult, error) { + r.entry.Tracef("GetMailboxMessagesForNewSnapshot") + + return r.rd.GetMailboxMessageForNewSnapshot(ctx, mboxID) +} + +func (r ReadTracer) MailboxTranslateRemoteIDs(ctx context.Context, mboxIDs []imap.MailboxID) ([]imap.InternalMailboxID, error) { + r.entry.Tracef("MailboxTranslateRemoteIDs") + + return r.rd.MailboxTranslateRemoteIDs(ctx, mboxIDs) +} + +func (r ReadTracer) MailboxFilterContains(ctx context.Context, mboxID imap.InternalMailboxID, messageIDs []db.MessageIDPair) ([]imap.InternalMessageID, error) { + r.entry.Tracef("MailboxFilterContains") + + return r.rd.MailboxFilterContains(ctx, mboxID, messageIDs) +} + +func (r ReadTracer) MailboxFilterContainsInternalID(ctx context.Context, mboxID imap.InternalMailboxID, messageIDs []imap.InternalMessageID) ([]imap.InternalMessageID, error) { + r.entry.Tracef("MailboxFilterContainsInternalID") + + return r.rd.MailboxFilterContainsInternalID(ctx, mboxID, messageIDs) +} + +func (r ReadTracer) GetMailboxCount(ctx context.Context) (int, error) { + r.entry.Tracef("GetMailboxCount") + + return r.rd.GetMailboxCount(ctx) +} + +func (r ReadTracer) GetMailboxMessageUIDsWithFlagsAfterAddOrUIDBump(ctx context.Context, mboxID imap.InternalMailboxID, messageIDs []imap.InternalMessageID) ([]db.UIDWithFlags, error) { + r.entry.Tracef("GetMailboxMessageUIDsWithFlagsAfterAddOrUIDBump") + + return r.rd.GetMailboxMessageUIDsWithFlagsAfterAddOrUIDBump(ctx, mboxID, messageIDs) +} + +func (r ReadTracer) MessageExists(ctx context.Context, id imap.InternalMessageID) (bool, error) { + r.entry.Tracef("MessageExists") + + return r.rd.MessageExists(ctx, id) +} + +func (r ReadTracer) MessageExistsWithRemoteID(ctx context.Context, id imap.MessageID) (bool, error) { + r.entry.Tracef("MessageExistsWithRemoteID") + + return r.rd.MessageExistsWithRemoteID(ctx, id) +} + +func (r ReadTracer) GetMessageNoEdges(ctx context.Context, id imap.InternalMessageID) (*db.Message, error) { + r.entry.Tracef("GetMessagesNoEdges") + + return r.rd.GetMessageNoEdges(ctx, id) +} + +func (r ReadTracer) GetTotalMessageCount(ctx context.Context) (int, error) { + r.entry.Tracef("GetTotalMessagecount") + + return r.rd.GetTotalMessageCount(ctx) +} + +func (r ReadTracer) GetMessageRemoteID(ctx context.Context, id imap.InternalMessageID) (imap.MessageID, error) { + r.entry.Tracef("GetMessageRemoteID") + + return r.rd.GetMessageRemoteID(ctx, id) +} + +func (r ReadTracer) GetImportedMessageData(ctx context.Context, id imap.InternalMessageID) (*db.Message, error) { + r.entry.Tracef("GetImportedMessageData") + + return r.rd.GetImportedMessageData(ctx, id) +} + +func (r ReadTracer) GetMessageDateAndSize(ctx context.Context, id imap.InternalMessageID) (time.Time, int, error) { + r.entry.Tracef("GetMessageDateAndSize") + + return r.rd.GetMessageDateAndSize(ctx, id) +} + +func (r ReadTracer) GetMessageMailboxIDs(ctx context.Context, id imap.InternalMessageID) ([]imap.InternalMailboxID, error) { + r.entry.Tracef("GetMailboxIDs") + + return r.rd.GetMessageMailboxIDs(ctx, id) +} + +func (r ReadTracer) GetMessagesFlags(ctx context.Context, ids []imap.InternalMessageID) ([]db.MessageFlagSet, error) { + r.entry.Tracef("GetMessageFlags") + + return r.rd.GetMessagesFlags(ctx, ids) +} + +func (r ReadTracer) GetMessageIDsMarkedAsDelete(ctx context.Context) ([]imap.InternalMessageID, error) { + r.entry.Tracef("GetMessageIDsMarkedAsDelete") + + return r.rd.GetMessageIDsMarkedAsDelete(ctx) +} + +func (r ReadTracer) GetMessageIDFromRemoteID(ctx context.Context, id imap.MessageID) (imap.InternalMessageID, error) { + r.entry.Tracef("GetMessageIDFromRemoteID") + + return r.rd.GetMessageIDFromRemoteID(ctx, id) +} + +func (r ReadTracer) GetMessageDeletedFlag(ctx context.Context, id imap.InternalMessageID) (bool, error) { + r.entry.Tracef("GetMessageDeletedFlag") + + return r.rd.GetMessageDeletedFlag(ctx, id) +} + +func (r ReadTracer) GetAllMessagesIDsAsMap(ctx context.Context) (map[imap.InternalMessageID]struct{}, error) { + r.entry.Tracef("GetAllMessagesIDsAsMap") + + return r.rd.GetAllMessagesIDsAsMap(ctx) +} + +func (r ReadTracer) GetDeletedSubscriptionSet(ctx context.Context) (map[imap.MailboxID]*db.DeletedSubscription, error) { + r.entry.Tracef("GetDeletedSubscriptionSet") + + return r.rd.GetDeletedSubscriptionSet(ctx) +} + +// WriteTracer prints all method names to a trace log. +type WriteTracer struct { + ReadTracer + tx db.Transaction +} + +func (w WriteTracer) CreateMailbox( + ctx context.Context, + mboxID imap.MailboxID, + name string, + flags, permFlags, attrs imap.FlagSet, + uidValidity imap.UID, +) (*db.Mailbox, error) { + w.entry.Tracef("CreateMailbox") + + return w.tx.CreateMailbox(ctx, mboxID, name, flags, permFlags, attrs, uidValidity) +} + +func (w WriteTracer) GetOrCreateMailbox( + ctx context.Context, + mboxID imap.MailboxID, + name string, + flags, permFlags, attrs imap.FlagSet, + uidValidity imap.UID, +) (*db.Mailbox, error) { + w.entry.Tracef("GetOrCreateMailbox") + + return w.tx.GetOrCreateMailbox(ctx, mboxID, name, flags, permFlags, attrs, uidValidity) +} + +func (w WriteTracer) GetOrCreateMailboxAlt( + ctx context.Context, + mbox imap.Mailbox, + delimiter string, + uidValidity imap.UID, +) (*db.Mailbox, error) { + w.entry.Tracef("GetOrCreateMailboxAlt") + + return w.tx.GetOrCreateMailboxAlt(ctx, mbox, delimiter, uidValidity) +} + +func (w WriteTracer) RenameMailboxWithRemoteID(ctx context.Context, mboxID imap.MailboxID, name string) error { + w.entry.Tracef("RenameMailboxWithRemoteID") + + return w.tx.RenameMailboxWithRemoteID(ctx, mboxID, name) +} + +func (w WriteTracer) DeleteMailboxWithRemoteID(ctx context.Context, mboxID imap.MailboxID) error { + w.entry.Tracef("DeleteMailboxWithRemoteID") + + return w.tx.DeleteMailboxWithRemoteID(ctx, mboxID) +} + +func (w WriteTracer) BumpMailboxUIDNext(ctx context.Context, mboxID imap.InternalMailboxID, count int) error { + w.entry.Tracef("BumpMailboxUIDNext") + + return w.tx.BumpMailboxUIDNext(ctx, mboxID, count) +} + +func (w WriteTracer) AddMessagesToMailbox( + ctx context.Context, + mboxID imap.InternalMailboxID, + messageIDs []imap.InternalMessageID, +) ([]db.UIDWithFlags, error) { + w.entry.Tracef("AddMessagesToMailbox") + + return w.tx.AddMessagesToMailbox(ctx, mboxID, messageIDs) +} + +func (w WriteTracer) BumpMailboxUIDsForMessage( + ctx context.Context, + mboxID imap.InternalMailboxID, + messageIDs []imap.InternalMessageID, +) ([]db.UIDWithFlags, error) { + w.entry.Tracef("BumpMailboxUIDsForMessage") + + return w.tx.BumpMailboxUIDsForMessage(ctx, mboxID, messageIDs) +} + +func (w WriteTracer) RemoveMessagesFromMailbox( + ctx context.Context, + mboxID imap.InternalMailboxID, + messageIDs []imap.InternalMessageID, +) error { + w.entry.Tracef("RemoveMessagesFromMailbox") + + return w.tx.RemoveMessagesFromMailbox(ctx, mboxID, messageIDs) +} + +func (w WriteTracer) ClearRecentFlagInMailboxOnMessage( + ctx context.Context, + mboxID imap.InternalMailboxID, + messageID imap.InternalMessageID, +) error { + w.entry.Tracef("ClearRecentFlagInMailboxOnMessage") + + return w.tx.ClearRecentFlagInMailboxOnMessage(ctx, mboxID, messageID) +} + +func (w WriteTracer) ClearRecentFlagsInMailbox(ctx context.Context, mboxID imap.InternalMailboxID) error { + w.entry.Tracef("ClearRecentFlagsInMailbox") + + return w.tx.ClearRecentFlagsInMailbox(ctx, mboxID) +} + +func (w WriteTracer) CreateMailboxIfNotExists(ctx context.Context, mbox imap.Mailbox, delimiter string, uidValidity imap.UID) error { + w.entry.Tracef("ClearMailboxIfNotExists") + + return w.tx.CreateMailboxIfNotExists(ctx, mbox, delimiter, uidValidity) +} + +func (w WriteTracer) SetMailboxMessagesDeletedFlag( + ctx context.Context, + mboxID imap.InternalMailboxID, + messageIDs []imap.InternalMessageID, + deleted bool, +) error { + w.entry.Tracef("SetMailboxMessagesDeleteFlag") + + return w.tx.SetMailboxMessagesDeletedFlag(ctx, mboxID, messageIDs, deleted) +} + +func (w WriteTracer) SetMailboxSubscribed(ctx context.Context, mboxID imap.InternalMailboxID, subscribed bool) error { + w.entry.Tracef("SetMailboxSubscribed") + + return w.tx.SetMailboxSubscribed(ctx, mboxID, subscribed) +} + +func (w WriteTracer) UpdateRemoteMailboxID(ctx context.Context, mobxID imap.InternalMailboxID, remoteID imap.MailboxID) error { + w.entry.Tracef("UpdateRemoteMailboxID") + + return w.tx.UpdateRemoteMailboxID(ctx, mobxID, remoteID) +} + +func (w WriteTracer) SetMailboxUIDValidity(ctx context.Context, mboxID imap.InternalMailboxID, uidValidity imap.UID) error { + w.entry.Tracef("SetMailboxUIDValidity") + + return w.tx.SetMailboxUIDValidity(ctx, mboxID, uidValidity) +} + +func (w WriteTracer) CreateMessages(ctx context.Context, reqs ...*db.CreateMessageReq) ([]*db.Message, error) { + w.entry.Tracef("CreateMessages") + + return w.tx.CreateMessages(ctx, reqs...) +} + +func (w WriteTracer) CreateMessageAndAddToMailbox( + ctx context.Context, + mbox imap.InternalMailboxID, + req *db.CreateMessageReq, +) (imap.UID, imap.FlagSet, error) { + w.entry.Tracef("CreateMessageAndAddToMailbox") + + return w.tx.CreateMessageAndAddToMailbox(ctx, mbox, req) +} + +func (w WriteTracer) MarkMessageAsDeleted(ctx context.Context, id imap.InternalMessageID) error { + w.entry.Tracef("MarkMessageAsDeleted") + + return w.tx.MarkMessageAsDeleted(ctx, id) +} + +func (w WriteTracer) MarkMessageAsDeletedAndAssignRandomRemoteID(ctx context.Context, id imap.InternalMessageID) error { + w.entry.Tracef("MarkMessageAsDeletedAndAssignRandomRemoteID") + + return w.tx.MarkMessageAsDeletedAndAssignRandomRemoteID(ctx, id) +} + +func (w WriteTracer) MarkMessageAsDeletedWithRemoteID(ctx context.Context, id imap.MessageID) error { + w.entry.Tracef("MarkMessageAsDeletedWithRemoteID") + + return w.tx.MarkMessageAsDeletedWithRemoteID(ctx, id) +} + +func (w WriteTracer) DeleteMessages(ctx context.Context, ids []imap.InternalMessageID) error { + w.entry.Tracef("DeleteMessages") + + return w.tx.DeleteMessages(ctx, ids) +} + +func (w WriteTracer) UpdateRemoteMessageID(ctx context.Context, internalID imap.InternalMessageID, remoteID imap.MessageID) error { + w.entry.Tracef("UpdateRemoteMessageID") + + return w.tx.UpdateRemoteMessageID(ctx, internalID, remoteID) +} + +func (w WriteTracer) AddFlagToMessages(ctx context.Context, ids []imap.InternalMessageID, flag string) error { + w.entry.Tracef("AddFlagsToMessage") + + return w.tx.AddFlagToMessages(ctx, ids, flag) +} + +func (w WriteTracer) RemoveFlagFromMessages(ctx context.Context, ids []imap.InternalMessageID, flag string) error { + w.entry.Tracef("RemoveFlagsFromMessages") + + return w.tx.RemoveFlagFromMessages(ctx, ids, flag) +} + +func (w WriteTracer) SetFlagsOnMessages(ctx context.Context, ids []imap.InternalMessageID, flags imap.FlagSet) error { + w.entry.Tracef("SetFlagsOnMessages") + + return w.tx.SetFlagsOnMessages(ctx, ids, flags) +} + +func (w WriteTracer) AddDeletedSubscription(ctx context.Context, mboxName string, mboxID imap.MailboxID) error { + w.entry.Tracef("AddDeletedSubscription") + + return w.tx.AddDeletedSubscription(ctx, mboxName, mboxID) +} + +func (w WriteTracer) RemoveDeletedSubscriptionWithName(ctx context.Context, mboxName string) (int, error) { + w.entry.Tracef("RemoveDeletedSubscriptionWithName") + + return w.tx.RemoveDeletedSubscriptionWithName(ctx, mboxName) +} diff --git a/internal/db_impl/sqlite3/types.go b/internal/db_impl/sqlite3/types.go new file mode 100644 index 00000000..62a3699c --- /dev/null +++ b/internal/db_impl/sqlite3/types.go @@ -0,0 +1,23 @@ +package sqlite3 + +import "github.com/ProtonMail/gluon/db" + +func ScanMailbox(scanner RowScanner) (*db.Mailbox, error) { + mbox := new(db.Mailbox) + + if err := scanner.Scan(&mbox.ID, &mbox.RemoteID, &mbox.Name, &mbox.UIDNext, &mbox.UIDValidity, &mbox.Subscribed); err != nil { + return nil, err + } + + return mbox, nil +} + +func ScanMessage(scanner RowScanner) (*db.Message, error) { + msg := new(db.Message) + + if err := scanner.Scan(&msg.ID, &msg.RemoteID, &msg.Date, &msg.Size, &msg.Body, &msg.BodyStructure, &msg.Envelope, &msg.Deleted); err != nil { + return nil, err + } + + return msg, nil +} diff --git a/internal/db_impl/sqlite3/v0/constants.go b/internal/db_impl/sqlite3/v0/constants.go new file mode 100644 index 00000000..343b5f8b --- /dev/null +++ b/internal/db_impl/sqlite3/v0/constants.go @@ -0,0 +1,46 @@ +package v0 + +const DeletedSubscriptionsTableName = "deleted_subscriptions" +const DeletedSubscriptionsFieldName = "name" +const DeletedSubscriptionsFieldRemoteID = "remote_id" + +const MailboxAttrsTableName = "mailbox_attrs" +const MailboxAttrsFieldValue = "value" +const MailboxAttrsFieldMailboxID = "mailbox_attributes" + +const MailboxFlagsTableName = "mailbox_flags" +const MailboxFlagsFieldValue = "value" +const MailboxFlagsFieldMailboxID = "mailbox_flags" + +const MailboxPermFlagsTableName = "mailbox_perm_flags" +const MailboxPermFlagsFieldValue = "value" +const MailboxPermFlagsFieldMailboxID = "mailbox_permanent_flags" + +const MailboxesTableName = "mailboxes" +const MailboxesFieldID = "id" +const MailboxesFieldRemoteID = "remote_id" +const MailboxesFieldName = "name" +const MailboxesFieldUIDNext = "uid_next" +const MailboxesFieldUIDValidity = "uid_validity" +const MailboxesFieldSubscribed = "subscribed" + +const MessageFlagsTableName = "message_flags" +const MessageFlagsFieldValue = "value" +const MessageFlagsFieldMessageID = "message_flags" + +const MessagesTableName = "messages" +const MessagesFieldID = "id" +const MessagesFieldRemoteID = "remote_id" +const MessagesFieldDate = "date" +const MessagesFieldSize = "size" +const MessagesFieldBody = "body" +const MessagesFieldBodyStructure = "body_structure" +const MessagesFieldEnvelope = "envelope" +const MessagesFieldDeleted = "deleted" + +const UIDsTableName = "ui_ds" +const UIDsFieldUID = "uid" +const UIDsFieldDeleted = "deleted" +const UIDsFieldRecent = "recent" +const UIDsFieldMailboxID = "mailbox_ui_ds" +const UIDsFieldMessageID = "uid_message" diff --git a/internal/db_impl/sqlite3/wrappers.go b/internal/db_impl/sqlite3/wrappers.go new file mode 100644 index 00000000..2c12f8a6 --- /dev/null +++ b/internal/db_impl/sqlite3/wrappers.go @@ -0,0 +1,132 @@ +package sqlite3 + +import ( + "context" + "database/sql" + + "github.com/sirupsen/logrus" +) + +// Collection of wrappers to help with tracing and debugging of SQL queries and statements. + +// QueryWrapper is a wrapper around go's sql.DB and sql.Tx types so we can override the calls with trackers (e.g.: +// DebugQueryWrapper). +type QueryWrapper interface { + QueryContext(ctx context.Context, query string, args ...any) (*sql.Rows, error) + QueryRowContext(ctx context.Context, query string, args ...any) *sql.Row + ExecContext(ctx context.Context, query string, args ...any) (sql.Result, error) + PrepareStatement(ctx context.Context, query string) (StmtWrapper, error) +} + +// StmtWrapper is a wrapper around go's sql.Stmt type so we can override the calls with trackers (e.g.: +// DebugStmtWrapper). +type StmtWrapper interface { + QueryContext(ctx context.Context, args ...any) (*sql.Rows, error) + QueryRowContext(ctx context.Context, args ...any) *sql.Row + ExecContext(ctx context.Context, args ...any) (sql.Result, error) + Close() error +} + +type DBWrapper struct { + db *sql.DB +} + +func (d DBWrapper) QueryContext(ctx context.Context, query string, args ...any) (*sql.Rows, error) { + return d.db.QueryContext(ctx, query, args...) +} + +func (d DBWrapper) QueryRowContext(ctx context.Context, query string, args ...any) *sql.Row { + return d.db.QueryRowContext(ctx, query, args...) +} + +func (d DBWrapper) ExecContext(ctx context.Context, query string, args ...any) (sql.Result, error) { + return d.db.ExecContext(ctx, query, args...) +} + +func (d DBWrapper) PrepareStatement(ctx context.Context, query string) (StmtWrapper, error) { + return d.db.PrepareContext(ctx, query) +} + +type TXWrapper struct { + tx *sql.Tx +} + +func (t TXWrapper) QueryContext(ctx context.Context, query string, args ...any) (*sql.Rows, error) { + return t.tx.QueryContext(ctx, query, args...) +} + +func (t TXWrapper) QueryRowContext(ctx context.Context, query string, args ...any) *sql.Row { + return t.tx.QueryRowContext(ctx, query, args...) +} + +func (t TXWrapper) ExecContext(ctx context.Context, query string, args ...any) (sql.Result, error) { + return t.tx.ExecContext(ctx, query, args...) +} + +func (t TXWrapper) PrepareStatement(ctx context.Context, query string) (StmtWrapper, error) { + return t.tx.PrepareContext(ctx, query) +} + +type DebugQueryWrapper struct { + qw QueryWrapper + entry *logrus.Entry +} + +func (d DebugQueryWrapper) QueryContext(ctx context.Context, query string, args ...any) (*sql.Rows, error) { + d.entry.Debugf("query=%v args=%v", query, args) + + return d.qw.QueryContext(ctx, query, args...) +} + +func (d DebugQueryWrapper) QueryRowContext(ctx context.Context, query string, args ...any) *sql.Row { + d.entry.Debugf("query=%v args=%v", query, args) + + return d.qw.QueryRowContext(ctx, query, args...) +} + +func (d DebugQueryWrapper) ExecContext(ctx context.Context, query string, args ...any) (sql.Result, error) { + d.entry.Debugf("Exec=%v args=%v", query, args) + + return d.qw.ExecContext(ctx, query, args...) +} + +func (d DebugQueryWrapper) PrepareStatement(ctx context.Context, query string) (StmtWrapper, error) { + stmt, err := d.qw.PrepareStatement(ctx, query) + if err != nil { + return nil, err + } + + return &DebugStmtWrapper{ + sw: stmt, + entry: d.entry, + query: query, + }, nil +} + +type DebugStmtWrapper struct { + sw StmtWrapper + entry *logrus.Entry + query string +} + +func (d DebugStmtWrapper) QueryContext(ctx context.Context, args ...any) (*sql.Rows, error) { + d.entry.Debugf("query=%v args=%v", d.query, args) + + return d.sw.QueryContext(ctx, args...) +} + +func (d DebugStmtWrapper) QueryRowContext(ctx context.Context, args ...any) *sql.Row { + d.entry.Debugf("query=%v args=%v", d.query, args) + + return d.sw.QueryRowContext(ctx, args...) +} + +func (d DebugStmtWrapper) ExecContext(ctx context.Context, args ...any) (sql.Result, error) { + d.entry.Debugf("query=%v args=%v", d.query, args) + + return d.sw.ExecContext(ctx, args...) +} + +func (d DebugStmtWrapper) Close() error { + return d.sw.Close() +} diff --git a/internal/db_impl/sqlite3/write_ops.go b/internal/db_impl/sqlite3/write_ops.go new file mode 100644 index 00000000..5c7d76f9 --- /dev/null +++ b/internal/db_impl/sqlite3/write_ops.go @@ -0,0 +1,676 @@ +package sqlite3 + +import ( + "context" + "errors" + "fmt" + "strings" + + "github.com/ProtonMail/gluon/db" + "github.com/ProtonMail/gluon/imap" + v0 "github.com/ProtonMail/gluon/internal/db_impl/sqlite3/v0" + "github.com/bradenaw/juniper/xslices" +) + +type writeOps struct { + readOps + qw QueryWrapper +} + +func (w writeOps) CreateMailbox( + ctx context.Context, + mboxID imap.MailboxID, + name string, + flags, permFlags, attrs imap.FlagSet, + uidValidity imap.UID, +) (*db.Mailbox, error) { + createMBoxQuery := fmt.Sprintf("INSERT INTO %v (`%v`, `%v`, `%v`, `%v`, `%v`) VALUES (?,?,?,?,?) RETURNING `%v`", + v0.MailboxesTableName, + v0.MailboxesFieldRemoteID, + v0.MailboxesFieldName, + v0.MailboxesFieldUIDNext, + v0.MailboxesFieldUIDValidity, + v0.MailboxesFieldSubscribed, + v0.MailboxesFieldID, + ) + + internalID, err := MapQueryRow[imap.InternalMailboxID](ctx, w.qw, createMBoxQuery, + mboxID, + name, + imap.UID(1), + uidValidity, + true, + ) + if err != nil { + return nil, err + } + + createFlags := func(tableName, fieldID, fieldValue string, flags imap.FlagSet) error { + query := fmt.Sprintf("INSERT INTO %v (`%v`, `%v`) VALUES (?, ?)", + tableName, + fieldID, + fieldValue, + ) + + stmt, err := w.qw.PrepareStatement(ctx, query) + if err != nil { + return err + } + + defer WrapStmtClose(stmt) + + for _, f := range flags.ToSliceUnsorted() { + if _, err := ExecStmt(ctx, stmt, internalID, f); err != nil { + return err + } + } + + return nil + } + + if err := createFlags(v0.MailboxFlagsTableName, v0.MailboxFlagsFieldMailboxID, v0.MailboxFlagsFieldValue, flags); err != nil { + return nil, err + } + + if err := createFlags(v0.MailboxPermFlagsTableName, v0.MailboxPermFlagsFieldMailboxID, v0.MailboxPermFlagsFieldValue, permFlags); err != nil { + return nil, err + } + + if err := createFlags(v0.MailboxAttrsTableName, v0.MailboxAttrsFieldMailboxID, v0.MailboxAttrsFieldValue, attrs); err != nil { + return nil, err + } + + return &db.Mailbox{ + ID: internalID, + RemoteID: mboxID, + Name: name, + UIDNext: 1, + UIDValidity: uidValidity, + Subscribed: true, + Flags: nil, + PermanentFlags: nil, + Attributes: nil, + }, nil +} + +func (w writeOps) GetOrCreateMailbox( + ctx context.Context, + mboxID imap.MailboxID, + name string, + flags, permFlags, attrs imap.FlagSet, + uidValidity imap.UID, +) (*db.Mailbox, error) { + mbox, err := w.GetMailboxByRemoteID(ctx, mboxID) + if err != nil { + if !errors.Is(err, db.ErrNotFound) { + return nil, err + } + } else { + return mbox, nil + } + + return w.CreateMailbox(ctx, mboxID, name, flags, permFlags, attrs, uidValidity) +} + +func (w writeOps) GetOrCreateMailboxAlt(ctx context.Context, mbox imap.Mailbox, delimiter string, uidValidity imap.UID) (*db.Mailbox, error) { + return w.GetOrCreateMailbox( + ctx, + mbox.ID, + strings.Join(mbox.Name, delimiter), + mbox.Flags, + mbox.PermanentFlags, + mbox.Attributes, + uidValidity, + ) +} + +func (w writeOps) RenameMailboxWithRemoteID(ctx context.Context, mboxID imap.MailboxID, name string) error { + query := fmt.Sprintf("UPDATE %v SET `%v` = ? WHERE `%v` = ?", + v0.MailboxesTableName, + v0.MailboxesFieldName, + v0.MailboxesFieldRemoteID, + ) + + return ExecQueryAndCheckUpdatedNotZero(ctx, w.qw, query, name, mboxID) +} + +func (w writeOps) DeleteMailboxWithRemoteID(ctx context.Context, mboxID imap.MailboxID) error { + mbox, err := w.GetMailboxByRemoteID(ctx, mboxID) + if err != nil { + if errors.Is(err, db.ErrNotFound) { + return nil + } + + return err + } + + if mbox.Subscribed { + if err := w.AddDeletedSubscription(ctx, mbox.Name, mboxID); err != nil { + return err + } + } + + query := fmt.Sprintf("DELETE FROM %v WHERE `%v` = ?", + v0.MailboxesTableName, + v0.MailboxesFieldRemoteID) + + _, err = ExecQuery(ctx, w.qw, query, mboxID) + + return err +} + +func (w writeOps) BumpMailboxUIDNext(ctx context.Context, mboxID imap.InternalMailboxID, count int) error { + mboxUID, err := w.GetMailboxUID(ctx, mboxID) + if err != nil { + return err + } + + query := fmt.Sprintf("UPDATE %v SET `%v` = ? WHERE `%v` = ?", + v0.MailboxesTableName, + v0.MailboxesFieldUIDNext, + v0.MailboxesFieldID, + ) + + return ExecQueryAndCheckUpdatedNotZero(ctx, w.qw, query, mboxUID.Add(uint32(count)), mboxID) +} + +func (w writeOps) AddMessagesToMailbox( + ctx context.Context, + mboxID imap.InternalMailboxID, + messageIDs []imap.InternalMessageID, +) ([]db.UIDWithFlags, error) { + if len(messageIDs) == 0 { + return nil, nil + } + + mboxUID, err := w.GetMailboxUID(ctx, mboxID) + if err != nil { + return nil, err + } + + for chunkIdx, chunk := range xslices.Chunk(messageIDs, db.ChunkLimit/2) { + query := fmt.Sprintf("INSERT INTO %v (`%v`, `%v`, `%v`) VALUES %v", + v0.UIDsTableName, + v0.UIDsFieldMailboxID, + v0.UIDsFieldUID, + v0.UIDsFieldMessageID, + strings.Join(xslices.Repeat("(?,?,?)", len(chunk)), ","), + ) + + args := make([]any, 0, 3*len(chunk)) + + for idIdx, id := range chunk { + nextUID := mboxUID.Add(uint32((chunkIdx * db.ChunkLimit / 2) + idIdx)) + args = append(args, mboxID, nextUID, id) + } + + if _, err := ExecQuery(ctx, w.qw, query, args...); err != nil { + return nil, err + } + } + + if err := w.BumpMailboxUIDNext(ctx, mboxID, len(messageIDs)); err != nil { + return nil, err + } + + return w.GetMailboxMessageUIDsWithFlagsAfterAddOrUIDBump(ctx, mboxID, messageIDs) +} + +func (w writeOps) BumpMailboxUIDsForMessage(ctx context.Context, mboxID imap.InternalMailboxID, messageIDs []imap.InternalMessageID) ([]db.UIDWithFlags, error) { + mboxUID, err := w.GetMailboxUID(ctx, mboxID) + if err != nil { + return nil, err + } + + query := fmt.Sprintf("UPDATE %v SET `%v` = ? WHERE `%v` = ? AND `%v`= ?", + v0.UIDsTableName, + v0.UIDsFieldUID, + v0.UIDsFieldMessageID, + v0.UIDsFieldMailboxID, + ) + + stmt, err := w.qw.PrepareStatement(ctx, query) + if err != nil { + return nil, err + } + + defer WrapStmtClose(stmt) + + for idx, id := range messageIDs { + nextUID := mboxUID.Add(uint32(idx)) + + if err := ExecStmtAndCheckUpdatedNotZero(ctx, stmt, nextUID, id, mboxID); err != nil { + return nil, err + } + } + + if err := w.BumpMailboxUIDNext(ctx, mboxID, len(messageIDs)); err != nil { + return nil, err + } + + return w.GetMailboxMessageUIDsWithFlagsAfterAddOrUIDBump(ctx, mboxID, messageIDs) +} + +func (w writeOps) RemoveMessagesFromMailbox(ctx context.Context, mboxID imap.InternalMailboxID, messageIDs []imap.InternalMessageID) error { + for _, chunk := range xslices.Chunk(messageIDs, db.ChunkLimit) { + query := fmt.Sprintf("DELETE FROM %v WHERE `%v` IN (%v) AND `%v` =?", + v0.UIDsTableName, + v0.UIDsFieldMessageID, + GenSQLIn(len(chunk)), + v0.UIDsFieldMailboxID, + ) + + if _, err := ExecQuery(ctx, w.qw, query, append(MapSliceToAny(messageIDs), mboxID)...); err != nil { + return err + } + } + + return nil +} + +func (w writeOps) ClearRecentFlagInMailboxOnMessage(ctx context.Context, mboxID imap.InternalMailboxID, messageID imap.InternalMessageID) error { + query := fmt.Sprintf("UPDATE %v SET `%v` = FALSE WHERE `%v` = ? AND `%v` =?", + v0.UIDsTableName, + v0.UIDsFieldRecent, + v0.UIDsFieldMailboxID, + v0.UIDsFieldMessageID, + ) + + return ExecQueryAndCheckUpdatedNotZero(ctx, w.qw, query, mboxID, messageID) +} + +func (w writeOps) ClearRecentFlagsInMailbox(ctx context.Context, mboxID imap.InternalMailboxID) error { + query := fmt.Sprintf("UPDATE %v SET `%v` = FALSE WHERE `%v` = ?", + v0.UIDsTableName, + v0.UIDsFieldRecent, + v0.UIDsFieldMailboxID, + ) + + _, err := ExecQuery(ctx, w.qw, query, mboxID) + + return err +} + +func (w writeOps) CreateMailboxIfNotExists(ctx context.Context, mbox imap.Mailbox, delimiter string, uidValidity imap.UID) error { + _, err := w.GetOrCreateMailboxAlt(ctx, mbox, delimiter, uidValidity) + + return err +} + +func (w writeOps) SetMailboxMessagesDeletedFlag(ctx context.Context, mboxID imap.InternalMailboxID, messageIDs []imap.InternalMessageID, deleted bool) error { + for _, chunk := range xslices.Chunk(messageIDs, db.ChunkLimit) { + query := fmt.Sprintf("UPDATE %v SET `%v` =? WHERE `%v` IN (%v) AND `%v` =? ", + v0.UIDsTableName, + v0.UIDsFieldDeleted, + v0.UIDsFieldMessageID, + GenSQLIn(len(chunk)), + v0.UIDsFieldMailboxID, + ) + + args := make([]any, 0, len(chunk)+2) + args = append(args, deleted) + args = append(args, MapSliceToAny(chunk)...) + args = append(args, mboxID) + + if _, err := ExecQuery(ctx, w.qw, query, args...); err != nil { + return err + } + } + + return nil +} + +func (w writeOps) SetMailboxSubscribed(ctx context.Context, mboxID imap.InternalMailboxID, subscribed bool) error { + query := fmt.Sprintf("UPDATE %v SET `%v` = ? WHERE `%v` = ?", + v0.MailboxesTableName, + v0.MailboxesFieldSubscribed, + v0.MailboxesFieldID, + ) + + _, err := ExecQuery(ctx, w.qw, query, subscribed, mboxID) + + return err +} + +func (w writeOps) UpdateRemoteMailboxID(ctx context.Context, mboxID imap.InternalMailboxID, remoteID imap.MailboxID) error { + query := fmt.Sprintf("UPDATE %v SET `%v` = ? WHERE `%v` = ?", + v0.MailboxesTableName, + v0.MailboxesFieldRemoteID, + v0.MailboxesFieldID, + ) + + return ExecQueryAndCheckUpdatedNotZero(ctx, w.qw, query, remoteID, mboxID) +} + +func (w writeOps) SetMailboxUIDValidity(ctx context.Context, mboxID imap.InternalMailboxID, uidValidity imap.UID) error { + query := fmt.Sprintf("UPDATE %v SET `%v` = ? WHERE `%v` = ?", + v0.MailboxesTableName, + v0.MailboxesFieldUIDValidity, + v0.MailboxesFieldID, + ) + + return ExecQueryAndCheckUpdatedNotZero(ctx, w.qw, query, uidValidity, mboxID) +} + +func (w writeOps) CreateMessages(ctx context.Context, reqs ...*db.CreateMessageReq) ([]*db.Message, error) { + result := make([]*db.Message, 0, len(reqs)) + + for _, chunk := range xslices.Chunk(reqs, db.ChunkLimit) { + createMessageQuery := fmt.Sprintf("INSERT INTO %v (`%v`, `%v`, `%v`, `%v`, `%v`, `%v`, `%v`) VALUES %v", + v0.MessagesTableName, + v0.MessagesFieldID, + v0.MessagesFieldRemoteID, + v0.MessagesFieldDate, + v0.MessagesFieldSize, + v0.MessagesFieldBody, + v0.MessagesFieldBodyStructure, + v0.MessagesFieldEnvelope, + strings.Join(xslices.Repeat("(?,?,?,?,?,?,?)", len(chunk)), ","), + ) + + args := make([]any, 0, len(chunk)*6) + flagArgs := make([]any, 0, len(chunk)*2) + + for _, req := range chunk { + args = append(args, + req.InternalID, + req.Message.ID, + req.Message.Date, + req.LiteralSize, + req.Body, + req.Structure, + req.Envelope) + + for _, f := range req.Message.Flags.ToSliceUnsorted() { + flagArgs = append(flagArgs, req.InternalID, f) + } + + result = append(result, &db.Message{ + ID: req.InternalID, + RemoteID: req.Message.ID, + Date: req.Message.Date, + Size: req.LiteralSize, + Body: req.Body, + BodyStructure: req.Structure, + Envelope: req.Envelope, + Deleted: false, + Flags: db.MessageFlagsFromFlagSet(req.Message.Flags), + UIDs: nil, + }) + } + + if _, err := ExecQuery(ctx, w.qw, createMessageQuery, args...); err != nil { + return nil, err + } + + for _, chunk := range xslices.Chunk(flagArgs, db.ChunkLimit) { + createFlagsQuery := fmt.Sprintf("INSERT INTO %v (`%v`, `%v`) VALUES %v", + v0.MessageFlagsTableName, + v0.MessageFlagsFieldMessageID, + v0.MessageFlagsFieldValue, + strings.Join(xslices.Repeat("(?,?)", len(chunk)/2), ","), + ) + + if _, err := ExecQuery(ctx, w.qw, createFlagsQuery, chunk...); err != nil { + return nil, err + } + } + } + + return result, nil +} + +func (w writeOps) CreateMessageAndAddToMailbox(ctx context.Context, mbox imap.InternalMailboxID, req *db.CreateMessageReq) (imap.UID, imap.FlagSet, error) { + mboxUID, err := w.GetMailboxUID(ctx, mbox) + if err != nil { + return 0, imap.FlagSet{}, err + } + + createMessageQuery := fmt.Sprintf("INSERT INTO %v (`%v`, `%v`, `%v`, `%v`, `%v`, `%v`, `%v`) VALUES (?,?,?,?,?,?,?)", + v0.MessagesTableName, + v0.MessagesFieldID, + v0.MessagesFieldRemoteID, + v0.MessagesFieldDate, + v0.MessagesFieldSize, + v0.MessagesFieldBody, + v0.MessagesFieldBodyStructure, + v0.MessagesFieldEnvelope, + ) + + if _, err := ExecQuery(ctx, w.qw, + createMessageQuery, + req.InternalID, + req.Message.ID, + req.Message.Date, + req.LiteralSize, + req.Body, + req.Structure, + req.Envelope, + ); err != nil { + return 0, imap.FlagSet{}, err + } + + if req.Message.Flags.Len() != 0 { + createFlagsQuery := fmt.Sprintf("INSERT INTO %v (`%v`, `%v`) VALUES %v", + v0.MessageFlagsTableName, + v0.MessageFlagsFieldMessageID, + v0.MessageFlagsFieldValue, + strings.Join(xslices.Repeat("(?, ?)", req.Message.Flags.Len()), ","), + ) + + args := make([]any, 0, req.Message.Flags.Len()*2) + for _, f := range req.Message.Flags.ToSliceUnsorted() { + args = append(args, req.InternalID, f) + } + + if _, err := ExecQuery(ctx, w.qw, createFlagsQuery, args...); err != nil { + return 0, imap.FlagSet{}, err + } + } + + addToMboxQuery := fmt.Sprintf("INSERT INTO %v (`%v`, `%v`, `%v`) VALUES (?,?,?)", + v0.UIDsTableName, + v0.UIDsFieldUID, + v0.UIDsFieldMessageID, + v0.UIDsFieldMailboxID, + ) + + if _, err := ExecQuery(ctx, w.qw, addToMboxQuery, mboxUID, req.InternalID, mbox); err != nil { + return 0, imap.FlagSet{}, err + } + + if err := w.BumpMailboxUIDNext(ctx, mbox, 1); err != nil { + return 0, imap.FlagSet{}, err + } + + flags := req.Message.Flags.Add(imap.FlagRecent) + + return mboxUID, flags, nil +} + +func (w writeOps) MarkMessageAsDeleted(ctx context.Context, id imap.InternalMessageID) error { + query := fmt.Sprintf("UPDATE %v SET `%v` = TRUE WHERE `%v` = ?", + v0.MessagesTableName, + v0.MessagesFieldDeleted, + v0.MessagesFieldID, + ) + + _, err := ExecQuery(ctx, w.qw, query, id) + + return err +} + +func (w writeOps) MarkMessageAsDeletedAndAssignRandomRemoteID(ctx context.Context, id imap.InternalMessageID) error { + query := fmt.Sprintf("UPDATE %v SET `%v` = TRUE, `%v` = ? WHERE `%v` = ?", + v0.MessagesTableName, + v0.MessagesFieldDeleted, + v0.MessagesFieldRemoteID, + v0.MessagesFieldID, + ) + + randomID := imap.MessageID(fmt.Sprintf("DELETED-%v", imap.NewInternalMessageID())) + + _, err := ExecQuery(ctx, w.qw, query, randomID, id) + + return err +} + +func (w writeOps) MarkMessageAsDeletedWithRemoteID(ctx context.Context, id imap.MessageID) error { + query := fmt.Sprintf("UPDATE %v SET `%v` = TRUE WHERE `%v` = ?", + v0.MessagesTableName, + v0.MessagesFieldDeleted, + v0.MessagesFieldRemoteID, + ) + + _, err := ExecQuery(ctx, w.qw, query, id) + + return err +} + +func (w writeOps) DeleteMessages(ctx context.Context, ids []imap.InternalMessageID) error { + for _, chunk := range xslices.Chunk(ids, db.ChunkLimit) { + query := fmt.Sprintf("DELETE FROM %v WHERE `%v` IN (%v)", + v0.MessagesTableName, + v0.MessagesFieldID, + GenSQLIn(len(chunk)), + ) + + if _, err := ExecQuery(ctx, w.qw, query, MapSliceToAny(chunk)...); err != nil { + return err + } + } + + return nil +} + +func (w writeOps) UpdateRemoteMessageID(ctx context.Context, internalID imap.InternalMessageID, remoteID imap.MessageID) error { + query := fmt.Sprintf("UPDATE %v SET `%v` = ? WHERE `%v` = ?", + v0.MessagesFieldID, + v0.MessagesFieldRemoteID, + v0.MessagesFieldID, + ) + + return ExecQueryAndCheckUpdatedNotZero(ctx, w.qw, query, remoteID, internalID) +} + +func (w writeOps) AddFlagToMessages(ctx context.Context, ids []imap.InternalMessageID, flag string) error { + for _, chunk := range xslices.Chunk(ids, db.ChunkLimit) { + query := fmt.Sprintf("INSERT INTO %v (`%v`, `%v`) VALUES %v", + v0.MessageFlagsTableName, + v0.MessageFlagsFieldMessageID, + v0.MessageFlagsFieldValue, + strings.Join(xslices.Repeat("(?, ?)", len(chunk)), ","), + ) + + args := make([]any, 0, len(chunk)*2) + + for _, id := range chunk { + args = append(args, id, flag) + } + + if _, err := ExecQuery(ctx, w.qw, query, args...); err != nil { + return err + } + } + + return nil +} + +func (w writeOps) RemoveFlagFromMessages(ctx context.Context, ids []imap.InternalMessageID, flag string) error { + for _, chunk := range xslices.Chunk(ids, db.ChunkLimit) { + query := fmt.Sprintf("DELETE FROM %v WHERE `%v` IN (%v) AND `%v` = ?", + v0.MessageFlagsTableName, + v0.MessageFlagsFieldMessageID, + GenSQLIn(len(chunk)), + v0.MessageFlagsFieldValue, + ) + + if _, err := ExecQuery(ctx, w.qw, query, append(MapSliceToAny(chunk), flag)...); err != nil { + return err + } + } + + return nil +} + +func (w writeOps) SetFlagsOnMessages(ctx context.Context, ids []imap.InternalMessageID, flags imap.FlagSet) error { + // GODT-2522: can silently ignore duplicates with INSERT OR IGNORE INTO ... if constraint exists. + flagSlice := flags.ToSliceUnsorted() + + flagsSQLIn := GenSQLIn(len(flagSlice)) + + for _, chunk := range xslices.Chunk(ids, db.ChunkLimit/2) { + deleteQuery := fmt.Sprintf("DELETE FROM %v WHERE `%v` IN (%v) AND `%v` NOT IN(%v)", + v0.MessageFlagsTableName, + v0.MessageFlagsFieldMessageID, + GenSQLIn(len(chunk)), + v0.MessageFlagsFieldValue, + flagsSQLIn, + ) + + insertQuery := fmt.Sprintf("INSERT OR REPLACE INTO %v (`%v`, `%v`) VALUES %v", + v0.MessageFlagsTableName, + v0.MessageFlagsFieldMessageID, + v0.MessageFlagsFieldValue, + strings.Join(xslices.Repeat("(?,?)", len(flagSlice)), ","), + ) + + deleteArgs := make([]any, 0, len(ids)+len(flagSlice)) + deleteArgs = append(deleteArgs, MapSliceToAny(chunk)...) + deleteArgs = append(deleteArgs, MapSliceToAny(flagSlice)...) + + if _, err := ExecQuery(ctx, w.qw, deleteQuery, deleteArgs...); err != nil { + return err + } + + insertArgs := make([]any, 0, len(flagSlice)*2*len(chunk)) + + for _, id := range chunk { + for _, flag := range flagSlice { + insertArgs = append(insertArgs, id, flag) + } + } + + if _, err := ExecQuery(ctx, w.qw, insertQuery, insertArgs...); err != nil { + return err + } + } + + return nil +} + +func (w writeOps) AddDeletedSubscription(ctx context.Context, mboxName string, mboxID imap.MailboxID) error { + updateQuery := fmt.Sprintf("UPDATE %v SET `%v` = ? WHERE `%v` = ?", + v0.DeletedSubscriptionsTableName, + v0.DeletedSubscriptionsFieldRemoteID, + v0.DeletedSubscriptionsFieldName, + ) + + count, err := ExecQuery(ctx, w.qw, updateQuery, mboxID, mboxName) + if err != nil { + return err + } + + if count == 0 { + createQuery := fmt.Sprintf("INSERT INTO %v (`%v`, `%v`) VALUES (?, ?)", + v0.DeletedSubscriptionsTableName, + v0.DeletedSubscriptionsFieldName, + v0.DeletedSubscriptionsFieldRemoteID, + ) + + if _, err := ExecQuery(ctx, w.qw, createQuery, mboxName, mboxID); err != nil { + return err + } + } + + return nil +} + +func (w writeOps) RemoveDeletedSubscriptionWithName(ctx context.Context, mboxName string) (int, error) { + query := fmt.Sprintf("DELETE FROM %v WHERE `%v` = ?", + v0.DeletedSubscriptionsTableName, + v0.DeletedSubscriptionsFieldName, + ) + + return ExecQuery(ctx, w.qw, query, mboxName) +} diff --git a/internal/state/actions.go b/internal/state/actions.go index 540416d0..091807c8 100644 --- a/internal/state/actions.go +++ b/internal/state/actions.go @@ -105,7 +105,7 @@ func (state *State) actionCreateMessage( if knownErr == nil { // Try to collect the original message date. var existingMessageDate time.Time - if existingMessage, msgErr := tx.GetMessage(ctx, internalID); msgErr == nil { + if existingMessage, msgErr := tx.GetMessageNoEdges(ctx, internalID); msgErr == nil { existingMessageDate = existingMessage.Date } diff --git a/internal/state/mailbox.go b/internal/state/mailbox.go index 92316c0d..e38c0f8f 100644 --- a/internal/state/mailbox.go +++ b/internal/state/mailbox.go @@ -188,6 +188,10 @@ func (m *Mailbox) AppendRegular(ctx context.Context, literal []byte, flags imap. if messageDeleted, err := stateDBReadResult(ctx, m.state, func(ctx context.Context, client db.ReadOnly) (bool, error) { return client.GetMessageDeletedFlag(ctx, msgID) }); err != nil { + if !errors.Is(err, db.ErrNotFound) { + return 0, err + } + logrus.WithError(err).Warn("The message has an unknown internal ID") } else if !messageDeleted { logrus.Debugf("Appending duplicate message with Internal ID:%v", msgID.ShortID()) diff --git a/internal/state/mailbox_fetch.go b/internal/state/mailbox_fetch.go index ec6bef85..0a3f9cd5 100644 --- a/internal/state/mailbox_fetch.go +++ b/internal/state/mailbox_fetch.go @@ -118,7 +118,7 @@ func (m *Mailbox) Fetch(ctx context.Context, cmd *command.Fetch, ch chan respons msg := snapMessages[i] message, err := stateDBReadResult(ctx, m.state, func(ctx context.Context, client db.ReadOnly) (*db.Message, error) { - return client.GetMessage(ctx, msg.ID.InternalID) + return client.GetMessageNoEdges(ctx, msg.ID.InternalID) }) if err != nil { return err diff --git a/internal/state/state.go b/internal/state/state.go index b35d8ea3..68a59e8a 100644 --- a/internal/state/state.go +++ b/internal/state/state.go @@ -77,7 +77,7 @@ func (state *State) UserID() string { func (state *State) List(ctx context.Context, ref, pattern string, lsub bool, fn func(map[string]Match) error) error { return stateDBRead(ctx, state, func(ctx context.Context, client db.ReadOnly) error { - mailboxes, err := client.GetAllMailboxes(ctx) + mailboxes, err := client.GetAllMailboxesWithAttr(ctx) if err != nil { return err } @@ -375,7 +375,7 @@ func (state *State) Rename(ctx context.Context, oldName, newName string) error { } // Locally update all inferiors so we don't wait for update - mailboxes, err := tx.GetAllMailboxes(ctx) + mailboxes, err := tx.GetAllMailboxesWithAttr(ctx) if err != nil { return nil, err } diff --git a/tests/db_test.go b/tests/db_test.go index 93d808bf..ef18f62f 100644 --- a/tests/db_test.go +++ b/tests/db_test.go @@ -2,6 +2,9 @@ package tests import ( "context" + "github.com/ProtonMail/gluon/internal/db_impl" + "os" + "testing" "github.com/ProtonMail/gluon/db" "github.com/stretchr/testify/require" @@ -17,3 +20,26 @@ func dbCheckUserMessageCount(s *testSession, user string, expectedCount int) { }) require.NoError(s.tb, err) } + +func TestRunEntThenSqlite3(t *testing.T) { + if _, ok := os.LookupEnv("GLUON_TEST_FORCE_ENT_DB"); ok { + t.Skip("Does not make sense to run this test under these conditions") + } + + dataDir := t.TempDir() + + dbDir := t.TempDir() + + // Run once with Ent DB. + runServer(t, defaultServerOptions(t, withDatabase(db_impl.NewEntDB()), withDatabaseDir(dbDir), withDataDir(dataDir)), func(session *testSession) { + + }) + + // Run once with SQLite DB. + runServer(t, defaultServerOptions(t, withDatabase(db_impl.NewSQLiteDB()), withDatabaseDir(dbDir), withDataDir(dataDir)), func(session *testSession) { + }) + + // Run second time with SQLite DB. + runServer(t, defaultServerOptions(t, withDatabase(db_impl.NewSQLiteDB()), withDatabaseDir(dbDir), withDataDir(dataDir)), func(session *testSession) { + }) +} diff --git a/tests/server_test.go b/tests/server_test.go index aedd8227..a9648965 100644 --- a/tests/server_test.go +++ b/tests/server_test.go @@ -6,6 +6,7 @@ import ( "encoding/hex" "fmt" "net" + "os" "path/filepath" "testing" "time" @@ -256,13 +257,19 @@ func defaultServerOptions(tb testing.TB, modifiers ...serverOption) *serverOptio storeBuilder: &store.OnDiskStoreBuilder{}, connectorBuilder: &dummyConnectorBuilder{}, imapLimits: limits.DefaultLimits(), - database: db_impl.NewEntDB(), + database: db_impl.NewSQLiteDB(), } for _, op := range modifiers { op.apply(options) } + if _, ok := os.LookupEnv("GLUON_TEST_FORCE_ENT_DB"); ok { + logrus.Info("Forcing database to ent") + + options.database = db_impl.NewEntDB() + } + return options } @@ -305,6 +312,7 @@ func runServer(tb testing.TB, options *serverOptions, tests func(session *testSe gluon.WithStoreBuilder(options.storeBuilder), gluon.WithReporter(reporter), gluon.WithIMAPLimits(options.imapLimits), + gluon.WithDBClient(options.database), } if options.disableParallelism {