diff --git a/db/errors.go b/db/errors.go index 0b783bf0..fb6c99b9 100644 --- a/db/errors.go +++ b/db/errors.go @@ -5,6 +5,7 @@ import "errors" var ErrNotFound = errors.New("value not found") var ErrTransactionFailed = errors.New("transaction failed") var ErrMigrationFailed = errors.New("database migration failed") +var ErrInvalidDatabaseVersion = errors.New("invalid database version") func IsErrNotFound(err error) bool { if err == nil { diff --git a/internal/db_impl/sqlite3/client.go b/internal/db_impl/sqlite3/client.go index 1069b1b7..3afb8999 100644 --- a/internal/db_impl/sqlite3/client.go +++ b/internal/db_impl/sqlite3/client.go @@ -71,7 +71,7 @@ func (c *Client) Init(ctx context.Context, generator imap.UIDValidityGenerator) } if err := RunMigrations(ctx, qw, generator); err != nil { - return fmt.Errorf("%w: %v", db.ErrMigrationFailed, err) + return err } return nil diff --git a/internal/db_impl/sqlite3/migration_test.go b/internal/db_impl/sqlite3/migration_test.go index ffe9ab0f..940d86a9 100644 --- a/internal/db_impl/sqlite3/migration_test.go +++ b/internal/db_impl/sqlite3/migration_test.go @@ -3,6 +3,7 @@ package sqlite3 import ( "context" "database/sql" + "errors" "fmt" "github.com/ProtonMail/gluon/db" "github.com/ProtonMail/gluon/imap" @@ -17,6 +18,41 @@ import ( "time" ) +func TestMigration_VersionTooHigh(t *testing.T) { + testDir := t.TempDir() + + setup := func() { + client, _, err := NewClient(testDir, "foo", false, false) + require.NoError(t, err) + + ctx := context.Background() + require.NoError(t, client.Init(ctx, &imap.IncrementalUIDValidityGenerator{})) + + defer func() { + require.NoError(t, client.Close()) + }() + + // For version to very high value + require.NoError(t, client.wrapTx(ctx, func(ctx context.Context, tx *sql.Tx, entry *logrus.Entry) error { + qw := utils.TXWrapper{TX: tx} + return updateDBVersion(ctx, qw, 999999) + })) + } + + setup() + + client, _, err := NewClient(testDir, "foo", false, false) + require.NoError(t, err) + + defer func() { + require.NoError(t, client.Close()) + }() + + err = client.Init(context.Background(), imap.DefaultEpochUIDValidityGenerator()) + require.Error(t, err) + require.True(t, errors.Is(err, db.ErrInvalidDatabaseVersion)) +} + func TestRunMigrations(t *testing.T) { logrus.SetLevel(logrus.DebugLevel) @@ -29,6 +65,14 @@ func TestRunMigrations(t *testing.T) { // Fill v0 database. prepareV0Database(t, testDir, "foo", testData, uidGenerator) + // First run, incurs migration. + runAndValidateDB(t, testDir, "foo", testData, uidGenerator) + + // Second run, no migration. + runAndValidateDB(t, testDir, "foo", testData, uidGenerator) +} + +func runAndValidateDB(t *testing.T, testDir, user string, testData *testData, uidGenerator imap.UIDValidityGenerator) { // create client and run all migrations. client, _, err := NewClient(testDir, "foo", false, false) require.NoError(t, err) diff --git a/internal/db_impl/sqlite3/migrations.go b/internal/db_impl/sqlite3/migrations.go index 694afe4e..8f133b32 100644 --- a/internal/db_impl/sqlite3/migrations.go +++ b/internal/db_impl/sqlite3/migrations.go @@ -35,12 +35,12 @@ func RunMigrations(ctx context.Context, tx utils.QueryWrapper, generator imap.UI logrus.Debugf("Running migration for version %v", idx) if err := m.Run(ctx, tx, generator); err != nil { - return fmt.Errorf("failed to run migration %v: %w", idx, err) + return fmt.Errorf("%w %v: %v", db.ErrMigrationFailed, idx, err) } } if err := updateDBVersion(ctx, tx, len(migrationList)-1); err != nil { - return fmt.Errorf("failed to update db version:%w", err) + return fmt.Errorf("%w: failed to update db version: %v", db.ErrMigrationFailed, err) } logrus.Debug("Migrations completed") @@ -50,16 +50,32 @@ func RunMigrations(ctx context.Context, tx utils.QueryWrapper, generator imap.UI logrus.Debugf("DB Version is %v", dbVersion) - for i := dbVersion + 1; i < len(migrationList); i++ { + dbVersion = dbVersion + 1 + + if dbVersion == len(migrationList) { + logrus.Debugf("No migrations to run") + return nil + } + + if dbVersion > len(migrationList) { + return fmt.Errorf( + "%w: database version is %v, but we only support up to %v", + db.ErrInvalidDatabaseVersion, + dbVersion, + len(migrationList), + ) + } + + for i := dbVersion; i < len(migrationList); i++ { logrus.Debugf("Running migration for version %v", i) if err := migrationList[i].Run(ctx, tx, generator); err != nil { - return err + return fmt.Errorf("%w %v: %v", db.ErrMigrationFailed, i, err) } } if err := updateDBVersion(ctx, tx, len(migrationList)-1); err != nil { - return fmt.Errorf("failed to update db version:%w", err) + return fmt.Errorf("%w: failed to update db version: %v", db.ErrMigrationFailed, err) } logrus.Debug("Migrations completed")