diff --git a/internal/backend/backend.go b/internal/backend/backend.go index 5e487583..112132e0 100644 --- a/internal/backend/backend.go +++ b/internal/backend/backend.go @@ -2,6 +2,7 @@ package backend import ( "context" + "errors" "fmt" "sync" "sync/atomic" @@ -94,16 +95,58 @@ func (b *Backend) AddUser(ctx context.Context, userID string, conn connector.Con return false, err } - db, isNew, err := b.database.New(b.getDBDir(), userID) - if err != nil { + onErrorExit := func() { if err := storeBuilder.Close(); err != nil { logrus.WithError(err).Error("Failed to close store builder") } + } + database, isNew, err := b.database.New(b.getDBDir(), userID) + if err != nil { + onErrorExit() return false, err } - user, err := newUser(ctx, userID, db, conn, storeBuilder, b.delim, b.imapLimits, uidValidityGenerator, b.panicHandler) + if err := database.Init(ctx, uidValidityGenerator); err != nil { + if err := database.Close(); err != nil { + logrus.WithError(err).Errorf("Failed to close db after migration failure") + } + + if !errors.Is(err, db.ErrMigrationFailed) && !errors.Is(err, db.ErrInvalidDatabaseVersion) { + onErrorExit() + return false, err + } + + reporter.ExceptionWithContext(ctx, "database migration failed", reporter.Context{ + "error": err, + }) + + if err := b.database.Delete(b.getDBDir(), userID); err != nil { + onErrorExit() + return false, fmt.Errorf("failed to remove database after migration: %w", err) + } + + database, isNew, err = b.database.New(b.getDBDir(), userID) + if err != nil { + onErrorExit() + return false, err + } + + if !isNew { + if err := database.Close(); err != nil { + logrus.WithError(err).Errorf("failed to closed db") + } + + return false, fmt.Errorf("expected database to be new after failed migration cleanup") + } + + if err := database.Init(ctx, uidValidityGenerator); err != nil { + onErrorExit() + return false, err + } + } + + user, err := newUser(ctx, userID, database, conn, storeBuilder, b.delim, b.imapLimits, uidValidityGenerator, b.panicHandler) if err != nil { return false, err } diff --git a/internal/backend/user.go b/internal/backend/user.go index 3635c595..cab85f33 100644 --- a/internal/backend/user.go +++ b/internal/backend/user.go @@ -60,10 +60,6 @@ func newUser( uidValidityGenerator imap.UIDValidityGenerator, panicHandler async.PanicHandler, ) (*user, error) { - if err := database.Init(ctx, uidValidityGenerator); err != nil { - return nil, err - } - recoveredMessageHashes := utils.NewMessageHashesMap() // Create recovery mailbox if it does not exist diff --git a/internal/db_impl/db_impl.go b/internal/db_impl/db_impl.go index 9d81034f..1fd36649 100644 --- a/internal/db_impl/db_impl.go +++ b/internal/db_impl/db_impl.go @@ -1,6 +1,7 @@ package db_impl import ( + "context" "github.com/ProtonMail/gluon/db" "github.com/ProtonMail/gluon/internal/db_impl/sqlite3" ) @@ -8,3 +9,7 @@ import ( func NewSQLiteDB(options ...sqlite3.Option) db.ClientInterface { return sqlite3.NewBuilder(options...) } + +func TestUpdateDBVersion(ctx context.Context, dbPath, userID string, version int) error { + return sqlite3.TestUpdateDBVersion(ctx, dbPath, userID, version) +} diff --git a/internal/db_impl/sqlite3/client.go b/internal/db_impl/sqlite3/client.go index 9daee010..8c40c210 100644 --- a/internal/db_impl/sqlite3/client.go +++ b/internal/db_impl/sqlite3/client.go @@ -289,3 +289,22 @@ func pathExists(path string) (bool, error) { func getDatabaseConn(dir, userID, path string) string { return fmt.Sprintf("file:%v?cache=shared&_fk=1&_journal=WAL", path) } + +func TestUpdateDBVersion(ctx context.Context, dbPath, userID string, version int) error { + client, _, err := NewClient(dbPath, userID, false, false) + if err != nil { + return err + } + + defer func() { + if err := client.Close(); err != nil { + logrus.Panic("failed to close db") + } + }() + + return client.wrapTx(ctx, func(ctx context.Context, tx *sql.Tx, entry *logrus.Entry) error { + qw := utils.TXWrapper{TX: tx} + + return updateDBVersion(ctx, qw, version) + }) +} diff --git a/tests/migration_test.go b/tests/migration_test.go new file mode 100644 index 00000000..acc3cf21 --- /dev/null +++ b/tests/migration_test.go @@ -0,0 +1,23 @@ +package tests + +import ( + "context" + "github.com/ProtonMail/gluon/internal/db_impl" + "github.com/stretchr/testify/require" + "testing" +) + +func TestFailedMigrationRestsDatabase(t *testing.T) { + dbDir := t.TempDir() + serverOptions := defaultServerOptions(t, withDatabaseDir(dbDir)) + + var userID string + + runServer(t, serverOptions, func(session *testSession) { + userID = session.userIDs["user"] + }) + + require.NoError(t, db_impl.TestUpdateDBVersion(context.Background(), dbDir, userID, 99999)) + + runServer(t, serverOptions, func(session *testSession) {}) +}