diff --git a/internal/server/data/data.go b/internal/server/data/data.go index cce471f724..7e7b6c681b 100644 --- a/internal/server/data/data.go +++ b/internal/server/data/data.go @@ -137,19 +137,6 @@ func (d *DB) Begin(ctx context.Context, opts *sql.TxOptions) (*Transaction, erro return &Transaction{DB: tx, committed: new(atomic.Bool)}, nil } -type WriteTxn interface { - ReadTxn - Exec(sql string, values ...interface{}) (sql.Result, error) -} - -type ReadTxn interface { - DriverName() string - Query(query string, args ...any) (*sql.Rows, error) - QueryRow(query string, args ...any) *sql.Row - - OrganizationID() uid.ID -} - // GormTxn is used as a shim in preparation for removing gorm. type GormTxn interface { WriteTxn @@ -371,16 +358,12 @@ func add[T models.Modelable](tx GormTxn, model *T) error { setOrg(tx, model) var err error - if tx.DriverName() == "postgres" { - // failures on postgres need to be rolled back in order to - // continue using the same transaction - db.SavePoint("beforeCreate") - err = db.Create(model).Error - if err != nil { - db.RollbackTo("beforeCreate") - } - } else { - err = db.Create(model).Error + // failures on postgres need to be rolled back in order to + // continue using the same transaction + db.SavePoint("beforeCreate") + err = db.Create(model).Error + if err != nil { + db.RollbackTo("beforeCreate") } return handleError(err) } diff --git a/internal/server/data/migrations.go b/internal/server/data/migrations.go index 656abffb0c..d4b4026fc1 100644 --- a/internal/server/data/migrations.go +++ b/internal/server/data/migrations.go @@ -8,8 +8,6 @@ import ( "strings" "time" - "gorm.io/gorm" - "github.com/infrahq/infra/internal" "github.com/infrahq/infra/internal/logging" "github.com/infrahq/infra/internal/server/data/migrator" @@ -25,9 +23,6 @@ func migrations() []*migrator.Migration { ID: "202204281130", Migrate: func(tx migrator.DB) error { stmt := `ALTER TABLE settings DROP COLUMN IF EXISTS signup_enabled` - if tx.DriverName() == "sqlite" { - stmt = `ALTER TABLE settings DROP COLUMN signup_enabled` - } _, err := tx.Exec(stmt) return err }, @@ -37,9 +32,6 @@ func migrations() []*migrator.Migration { ID: "202204291613", Migrate: func(tx migrator.DB) error { stmt := `ALTER TABLE identities DROP COLUMN IF EXISTS kind` - if tx.DriverName() == "sqlite" { - stmt = `ALTER TABLE identities DROP COLUMN kind` - } _, err := tx.Exec(stmt) return err }, @@ -82,54 +74,18 @@ func migrations() []*migrator.Migration { var schemaSQL string func initializeSchema(db migrator.DB) error { - if db.DriverName() == "sqlite" { - dataDB, ok := db.(*DB) - if !ok { - panic("unexpected DB type, remove this with gorm") - } - return autoMigrateSchema(dataDB.DB) - } - if _, err := db.Exec(schemaSQL); err != nil { return fmt.Errorf("failed to exec sql: %w", err) } return nil } -func autoMigrateSchema(db *gorm.DB) error { - tables := []interface{}{ - &models.ProviderUser{}, - &models.Group{}, - &models.Identity{}, - &models.Provider{}, - &models.Grant{}, - &models.Destination{}, - &models.AccessKey{}, - &models.Settings{}, - &models.EncryptionKey{}, - &models.Credential{}, - &models.Organization{}, - &models.PasswordResetToken{}, - } - - for _, table := range tables { - if err := db.AutoMigrate(table); err != nil { - return err - } - } - - return nil -} - // #2294: set the provider kind on existing providers func addKindToProviders() *migrator.Migration { return &migrator.Migration{ ID: "202206151027", Migrate: func(tx migrator.DB) error { stmt := `ALTER TABLE providers ADD COLUMN IF NOT EXISTS kind text` - if tx.DriverName() == "sqlite" { - stmt = `ALTER TABLE providers ADD COLUMN kind text` - } if _, err := tx.Exec(stmt); err != nil { return err } @@ -248,9 +204,6 @@ func setDestinationLastSeenAt() *migrator.Migration { } stmt := `ALTER TABLE destinations ADD COLUMN last_seen_at timestamp with time zone` - if tx.DriverName() == "sqlite" { - stmt = `ALTER TABLE destinations ADD COLUMN last_seen_at datetime` - } if _, err := tx.Exec(stmt); err != nil { return err } diff --git a/internal/server/data/migrator/helpers.go b/internal/server/data/migrator/helpers.go index 01a283ee03..9bfa78bb17 100644 --- a/internal/server/data/migrator/helpers.go +++ b/internal/server/data/migrator/helpers.go @@ -17,10 +17,6 @@ func HasTable(tx DB, name string) bool { WHERE table_schema = CURRENT_SCHEMA() AND table_name = ? AND table_type = 'BASE TABLE' ` - if tx.DriverName() == "sqlite" { - stmt = `SELECT count(*) FROM sqlite_master WHERE type = 'table' AND name = ?` - } - if err := tx.QueryRow(stmt, name).Scan(&count); err != nil { logging.L.Warn().Err(err).Msg("failed to check if table exists") return false @@ -40,17 +36,6 @@ func HasColumn(tx DB, table string, column string) bool { WHERE table_schema = CURRENT_SCHEMA() AND table_name = ? AND column_name = ? ` - - if tx.DriverName() == "sqlite" { - stmt = ` - SELECT count(*) - FROM sqlite_master - WHERE type = 'table' AND name = ? - AND sql LIKE ? - ` - column = "% " + column + " %" - } - if err := tx.QueryRow(stmt, table, column).Scan(&count); err != nil { logging.L.Warn().Err(err).Msg("failed to check if column exists") return false @@ -69,16 +54,6 @@ func HasConstraint(tx DB, table string, constraint string) bool { WHERE table_schema = CURRENT_SCHEMA() AND table_name = ? AND constraint_name = ? ` - if tx.DriverName() == "sqlite" { - stmt = ` - SELECT count(*) - FROM sqlite_master - WHERE type = 'table' AND tbl_name = ? - AND sql LIKE ? - ` - constraint = "%CONSTRAINT `" + constraint + "`%" - } - if err := tx.QueryRow(stmt, table, constraint).Scan(&count); err != nil { logging.L.Warn().Err(err).Msg("failed to check if constraint exists") return false diff --git a/internal/server/data/migrator/helpers_test.go b/internal/server/data/migrator/helpers_test.go index 582347d3b6..a6de2e9f4d 100644 --- a/internal/server/data/migrator/helpers_test.go +++ b/internal/server/data/migrator/helpers_test.go @@ -8,9 +8,6 @@ import ( func setupExampleTable(t *testing.T, db DB) { t.Helper() - if db.DriverName() == "sqlite" { - t.Skip("does not work with sqlite") - } _, _ = db.Exec("DROP TABLE example") diff --git a/internal/server/data/migrator/migrator.go b/internal/server/data/migrator/migrator.go index d17aacfc81..410fea2607 100644 --- a/internal/server/data/migrator/migrator.go +++ b/internal/server/data/migrator/migrator.go @@ -34,9 +34,6 @@ type Migration struct { } type DB interface { - // DriverName returns the name of the database driver. - DriverName() string - Exec(stmt string, args ...any) (sql.Result, error) Query(query string, args ...any) (*sql.Rows, error) QueryRow(query string, args ...any) *sql.Row diff --git a/internal/server/data/query.go b/internal/server/data/query.go index 687b3e8436..7c2baf8d64 100644 --- a/internal/server/data/query.go +++ b/internal/server/data/query.go @@ -9,6 +9,20 @@ import ( "github.com/infrahq/infra/uid" ) +// ReadTxn can perform read queries and contains metadata about the request. +type ReadTxn interface { + Query(query string, args ...any) (*sql.Rows, error) + QueryRow(query string, args ...any) *sql.Row + + OrganizationID() uid.ID +} + +// WriteTxn extends ReadTxn by adding write queries. +type WriteTxn interface { + ReadTxn + Exec(sql string, values ...interface{}) (sql.Result, error) +} + type Table interface { Table() string // Columns returns the names of the table's columns. Columns must return