diff --git a/go.mod b/go.mod index 0a0ea6a4b6..d81ab6c8d2 100644 --- a/go.mod +++ b/go.mod @@ -25,7 +25,6 @@ require ( gopkg.in/square/go-jose.v2 v2.6.0 gopkg.in/yaml.v2 v2.4.0 gorm.io/driver/postgres v1.3.7 - gorm.io/driver/sqlite v1.3.6 gorm.io/gorm v1.23.5 k8s.io/api v0.25.0 k8s.io/apimachinery v0.25.0 @@ -141,7 +140,6 @@ require ( github.com/mattn/go-colorable v0.1.12 // indirect github.com/mattn/go-isatty v0.0.14 // indirect github.com/mattn/go-runewidth v0.0.13 // indirect - github.com/mattn/go-sqlite3 v1.14.12 // indirect github.com/matttproud/golang_protobuf_extensions v1.0.2-0.20181231171920-c182affec369 // indirect github.com/mgutz/ansi v0.0.0-20170206155736-9520e82c474b // indirect github.com/mitchellh/copystructure v1.2.0 // indirect diff --git a/go.sum b/go.sum index de9e42ccf4..c7bf4d6f05 100644 --- a/go.sum +++ b/go.sum @@ -530,8 +530,6 @@ github.com/mattn/go-runewidth v0.0.4/go.mod h1:LwmH8dsx7+W8Uxz3IHJYH5QSwggIsqBzp github.com/mattn/go-runewidth v0.0.10/go.mod h1:RAqKPSqVFrSLVXbA8x7dzmKdmGzieGRCM46jaSJTDAk= github.com/mattn/go-runewidth v0.0.13 h1:lTGmDsbAYt5DmK6OnoV7EuIF1wEIFAcxld6ypU4OSgU= github.com/mattn/go-runewidth v0.0.13/go.mod h1:Jdepj2loyihRzMpdS35Xk/zdY8IAYHsh153qUoGf23w= -github.com/mattn/go-sqlite3 v1.14.12 h1:TJ1bhYJPV44phC+IMu1u2K/i5RriLTPe+yc68XDJ1Z0= -github.com/mattn/go-sqlite3 v1.14.12/go.mod h1:NyWgC/yNuGj7Q9rpYnZvas74GogHl5/Z4A/KQRfk6bU= github.com/matttproud/golang_protobuf_extensions v1.0.1/go.mod h1:D8He9yQNgCq6Z5Ld7szi9bcBfOoFv/3dc6xSMkL2PC0= github.com/matttproud/golang_protobuf_extensions v1.0.2-0.20181231171920-c182affec369 h1:I0XW9+e1XWDxdcEniV4rQAIOPUGDq67JSCiRCgGCZLI= github.com/matttproud/golang_protobuf_extensions v1.0.2-0.20181231171920-c182affec369/go.mod h1:BSXmuO+STAnVfrANrmjBb36TMTDstsz7MSK+HVaYKv4= @@ -1236,8 +1234,6 @@ gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gorm.io/driver/postgres v1.3.7 h1:FKF6sIMDHDEvvMF/XJvbnCl0nu6KSKUaPXevJ4r+VYQ= gorm.io/driver/postgres v1.3.7/go.mod h1:f02ympjIcgtHEGFMZvdgTxODZ9snAHDb4hXfigBVuNI= -gorm.io/driver/sqlite v1.3.6 h1:Fi8xNYCUplOqWiPa3/GuCeowRNBRGTf62DEmhMDHeQQ= -gorm.io/driver/sqlite v1.3.6/go.mod h1:Sg1/pvnKtbQ7jLXxfZa+jSHvoX8hoZA8cn4xllOMTgE= gotest.tools/v3 v3.3.0 h1:MfDY1b1/0xN1CyMlQDac0ziEy9zJQd9CXBRRDHw2jJo= gotest.tools/v3 v3.3.0/go.mod h1:Mcr9QNxkg0uMvy/YElmo4SpXgJKWgQvYrT7Kw5RzJ1A= honnef.co/go/tools v0.0.0-20190102054323-c2f93a96b099/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= diff --git a/internal/access/access_test.go b/internal/access/access_test.go index 2f01707cff..6ecb7568a2 100644 --- a/internal/access/access_test.go +++ b/internal/access/access_test.go @@ -21,11 +21,6 @@ import ( func setupDB(t *testing.T) *data.DB { t.Helper() driver := database.PostgresDriver(t, "_access") - if driver == nil { - lite, err := data.NewSQLiteDriver("file::memory:") - assert.NilError(t, err) - driver = &database.Driver{Dialector: lite} - } patch.ModelsSymmetricKey(t) db, err := data.NewDB(driver.Dialector, nil) diff --git a/internal/cmd/login_test.go b/internal/cmd/login_test.go index 6083c8600a..b6f685b56e 100644 --- a/internal/cmd/login_test.go +++ b/internal/cmd/login_test.go @@ -220,9 +220,8 @@ func setupServerOptions(t *testing.T, opts *server.Options) { // TODO: why do tests fail when the same schemaSuffix is used? suffix := "_cmd_" + t.Name() - if pgDriver := database.PostgresDriver(t, suffix); pgDriver != nil { - opts.DBConnectionString = pgDriver.DSN - } + pgDriver := database.PostgresDriver(t, suffix) + opts.DBConnectionString = pgDriver.DSN } func TestLoginCmd_TLSVerify(t *testing.T) { diff --git a/internal/cmd/server.go b/internal/cmd/server.go index ba7df260f2..b08dea28ce 100644 --- a/internal/cmd/server.go +++ b/internal/cmd/server.go @@ -46,13 +46,6 @@ func newServerCmd() *cobra.Command { options.TLSCache = tlsCache - dbFile, err := canonicalPath(options.DBFile) - if err != nil { - return err - } - - options.DBFile = dbFile - dbEncryptionKey, err := canonicalPath(options.DBEncryptionKey) if err != nil { return err @@ -70,7 +63,6 @@ func newServerCmd() *cobra.Command { cmd.Flags().StringVarP(&configFilename, "config-file", "f", "", "Server configuration file") cmd.Flags().String("tls-cache", "", "Directory to cache TLS certificates") - cmd.Flags().String("db-file", "", "Path to SQLite 3 database") cmd.Flags().String("db-name", "", "Database name") cmd.Flags().String("db-host", "", "Database host") cmd.Flags().Int("db-port", 0, "Database port") @@ -93,7 +85,6 @@ func defaultServerOptions(infraDir string) server.Options { return server.Options{ Version: 0.2, // update this as the config version changes TLSCache: filepath.Join(infraDir, "cache"), - DBFile: filepath.Join(infraDir, "sqlite3.db"), DBEncryptionKey: filepath.Join(infraDir, "sqlite3.db.key"), DBEncryptionKeyProvider: "native", EnableTelemetry: true, diff --git a/internal/cmd/server_test.go b/internal/cmd/server_test.go index 5f1272e172..2c483585ae 100644 --- a/internal/cmd/server_test.go +++ b/internal/cmd/server_test.go @@ -15,6 +15,7 @@ import ( "github.com/infrahq/infra/internal/cmd/types" "github.com/infrahq/infra/internal/server" + "github.com/infrahq/infra/internal/testing/database" ) func TestServerCmd_LoadOptions(t *testing.T) { @@ -126,7 +127,6 @@ enableSignup: false # default is true sessionDuration: 3m sessionExtensionDeadline: 1m -dbFile: /db/file dbEncryptionKey: /this-is-the-path dbEncryptionKeyProvider: the-provider dbHost: the-host @@ -199,7 +199,6 @@ users: DBEncryptionKey: "/this-is-the-path", DBEncryptionKeyProvider: "the-provider", - DBFile: "/db/file", DBHost: "the-host", DBPort: 5432, DBParameters: "sslmode=require", @@ -285,7 +284,6 @@ users: setup: func(t *testing.T, cmd *cobra.Command) { cmd.SetArgs([]string{ "--db-name", "database-name", - "--db-file", "/home/user/database-filename", "--db-port", "12345", "--db-host", "thehostname", "--enable-telemetry=false", @@ -297,7 +295,6 @@ users: expected: func(t *testing.T) server.Options { expected := defaultServerOptions(filepath.Join(dir, ".infra")) expected.DBName = "database-name" - expected.DBFile = "/home/user/database-filename" expected.DBHost = "thehostname" expected.DBPort = 12345 expected.EnableTelemetry = false @@ -317,9 +314,11 @@ users: } func TestServerCmd_WithSecretsConfig(t *testing.T) { + pgDriver := database.PostgresDriver(t, "cmd_server") patchRunServer(t, noServerRun) content := ` + dbConnectionString: ` + pgDriver.DSN + ` addr: http: "127.0.0.1:0" https: "127.0.0.1:0" diff --git a/internal/server/authn/authn_method_test.go b/internal/server/authn/authn_method_test.go index f53debf6a7..b5a1cf4238 100644 --- a/internal/server/authn/authn_method_test.go +++ b/internal/server/authn/authn_method_test.go @@ -17,11 +17,6 @@ import ( func setupDB(t *testing.T) *data.DB { t.Helper() driver := database.PostgresDriver(t, "_authn") - if driver == nil { - lite, err := data.NewSQLiteDriver("file::memory:") - assert.NilError(t, err) - driver = &database.Driver{Dialector: lite} - } patch.ModelsSymmetricKey(t) db, err := data.NewDB(driver.Dialector, nil) diff --git a/internal/server/data/data.go b/internal/server/data/data.go index 6cfa554306..c853b4d69d 100644 --- a/internal/server/data/data.go +++ b/internal/server/data/data.go @@ -6,9 +6,6 @@ import ( "database/sql/driver" "errors" "fmt" - "net/url" - "os" - "path" "reflect" "strings" "time" @@ -16,7 +13,6 @@ import ( "github.com/jackc/pgconn" "github.com/jackc/pgerrcode" - "gorm.io/driver/sqlite" "gorm.io/gorm" "github.com/infrahq/infra/internal" @@ -235,24 +231,6 @@ func initialize(db *DB) error { return nil } -func NewSQLiteDriver(connection string) (gorm.Dialector, error) { - if !strings.HasPrefix(connection, "file::memory") { - if err := os.MkdirAll(path.Dir(connection), os.ModePerm); err != nil { - return nil, err - } - } - uri, err := url.Parse(connection) - if err != nil { - return nil, err - } - query := uri.Query() - query.Add("_journal_mode", "WAL") - uri.RawQuery = query.Encode() - connection = uri.String() - - return sqlite.Open(connection), nil -} - func getDefaultSortFromType(t interface{}) string { ty := reflect.TypeOf(t).Elem() if _, ok := ty.FieldByName("Name"); ok { diff --git a/internal/server/data/data_test.go b/internal/server/data/data_test.go index 0ac4af7973..38a2cfa232 100644 --- a/internal/server/data/data_test.go +++ b/internal/server/data/data_test.go @@ -2,7 +2,6 @@ package data import ( "context" - "os" "testing" "github.com/rs/zerolog" @@ -29,29 +28,13 @@ func setupDB(t *testing.T, driver gorm.Dialector) *DB { return db } -var isEnvironmentCI = os.Getenv("CI") != "" - -// postgresDriver requires postgres to be available in a CI environment, and -// marks the test as skipped when not in CI environment. -func postgresDriver(t *testing.T) gorm.Dialector { - driver := database.PostgresDriver(t, "") - switch { - case driver == nil && isEnvironmentCI: - t.Fatal("CI must test all drivers, set POSTGRESQL_CONNECTION") - case driver == nil: - t.Skip("Set POSTGRESQL_CONNECTION to test against postgresql") - } - return driver.Dialector -} - -// runDBTests against all supported databases. Defaults to only sqlite locally, -// and all supported DBs in CI. +// runDBTests against all supported databases. // Set POSTGRESQL_CONNECTION to a postgresql connection string to run tests // against postgresql. func runDBTests(t *testing.T, run func(t *testing.T, db *DB)) { t.Run("postgres", func(t *testing.T) { - pgsql := postgresDriver(t) - db := setupDB(t, pgsql) + pgsql := database.PostgresDriver(t, "") + db := setupDB(t, pgsql.Dialector) run(t, db) db.Rollback() }) diff --git a/internal/server/data/migrations_test.go b/internal/server/data/migrations_test.go index a42b1dde35..7f6d3b5f99 100644 --- a/internal/server/data/migrations_test.go +++ b/internal/server/data/migrations_test.go @@ -22,6 +22,7 @@ import ( "github.com/infrahq/infra/internal/server/data/migrator" "github.com/infrahq/infra/internal/server/data/schema" "github.com/infrahq/infra/internal/server/models" + "github.com/infrahq/infra/internal/testing/database" "github.com/infrahq/infra/internal/testing/patch" "github.com/infrahq/infra/uid" ) @@ -575,7 +576,7 @@ DELETE FROM settings WHERE id=24567; var initialSchema string runStep(t, "initial schema", func(t *testing.T) { patch.ModelsSymmetricKey(t) - rawDB, err := newRawDB(postgresDriver(t)) + rawDB, err := newRawDB(database.PostgresDriver(t, "").Dialector) assert.NilError(t, err) db := &DB{DB: rawDB} @@ -589,7 +590,7 @@ DELETE FROM settings WHERE id=24567; assert.NilError(t, err) }) - db, err := newRawDB(postgresDriver(t)) + db, err := newRawDB(database.PostgresDriver(t, "").Dialector) assert.NilError(t, err) for i, tc := range testCases { runStep(t, tc.label.Name, func(t *testing.T) { @@ -646,6 +647,8 @@ type testCaseLabel struct { Line string } +var isEnvironmentCI = os.Getenv("CI") != "" + func dumpSchema(t *testing.T, conn string) string { t.Helper() if _, err := exec.LookPath("pg_dump"); err != nil { diff --git a/internal/server/data/migrator/migrator_test.go b/internal/server/data/migrator/migrator_test.go index 212b84c74b..2a7322bc14 100644 --- a/internal/server/data/migrator/migrator_test.go +++ b/internal/server/data/migrator/migrator_test.go @@ -3,17 +3,15 @@ package migrator import ( "database/sql" "database/sql/driver" - "os" - "path/filepath" "testing" - "gorm.io/driver/postgres" - "gorm.io/driver/sqlite" "gorm.io/gorm" "gotest.tools/v3/assert" + + "github.com/infrahq/infra/internal/testing/database" ) -type database struct { +type dbDriver struct { dialect string driver gorm.Dialector } @@ -361,16 +359,10 @@ func migrationCount(t *testing.T, db DB) (count int64) { } func runDBTests(t *testing.T, fn func(t *testing.T, db DB)) { - dir := t.TempDir() - - databases := []database{ - {dialect: "sqlite3", driver: sqlite.Open("file:" + filepath.Join(dir, "sqlite3.db"))}, - } + databases := []dbDriver{} - if pg := os.Getenv("POSTGRESQL_CONNECTION"); pg != "" { - databases = append(databases, database{ - dialect: "postgres", driver: postgres.Open(pg), - }) + if pg := database.PostgresDriver(t, "_migrator"); pg != nil { + databases = append(databases, dbDriver{dialect: "postgres", driver: pg.Dialector}) } for _, database := range databases { diff --git a/internal/server/data/sqlfunc_test.go b/internal/server/data/sqlfunc_test.go index f62081c843..814e20953d 100644 --- a/internal/server/data/sqlfunc_test.go +++ b/internal/server/data/sqlfunc_test.go @@ -5,6 +5,8 @@ import ( "testing" "gotest.tools/v3/assert" + + "github.com/infrahq/infra/internal/testing/database" ) func TestSQLUidStrToIntRoundTrip(t *testing.T) { @@ -13,7 +15,7 @@ func TestSQLUidStrToIntRoundTrip(t *testing.T) { intval int64 err string } - db := setupDB(t, postgresDriver(t)) + db := setupDB(t, database.PostgresDriver(t, "").Dialector) run := func(t *testing.T, tc testCase) { var i int64 diff --git a/internal/server/middleware_test.go b/internal/server/middleware_test.go index d7d42cc409..dce2708887 100644 --- a/internal/server/middleware_test.go +++ b/internal/server/middleware_test.go @@ -29,11 +29,6 @@ import ( func setupDB(t *testing.T) *data.DB { t.Helper() driver := database.PostgresDriver(t, "_server") - if driver == nil { - lite, err := data.NewSQLiteDriver("file::memory:") - assert.NilError(t, err) - driver = &database.Driver{Dialector: lite} - } tpatch.ModelsSymmetricKey(t) db, err := data.NewDB(driver.Dialector, nil) diff --git a/internal/server/models/encryption_test.go b/internal/server/models/encryption_test.go index aecbf2970a..7597185ea7 100644 --- a/internal/server/models/encryption_test.go +++ b/internal/server/models/encryption_test.go @@ -8,6 +8,7 @@ import ( "github.com/infrahq/infra/internal/server/data" "github.com/infrahq/infra/internal/server/models" + "github.com/infrahq/infra/internal/testing/database" "github.com/infrahq/infra/internal/testing/patch" "github.com/infrahq/infra/uid" ) @@ -20,7 +21,7 @@ type StructForTesting struct { func (s StructForTesting) Schema() string { return ` CREATE TABLE struct_for_testings ( - id integer PRIMARY KEY, + id bigint PRIMARY KEY, a_secret text );` } @@ -28,10 +29,8 @@ CREATE TABLE struct_for_testings ( func TestEncryptedAtRest(t *testing.T) { patch.ModelsSymmetricKey(t) - driver, err := data.NewSQLiteDriver("file::memory:") - assert.NilError(t, err) - - db, err := data.NewDB(driver, nil) + pg := database.PostgresDriver(t, "_models") + db, err := data.NewDB(pg.Dialector, nil) assert.NilError(t, err) _, err = db.Exec(StructForTesting{}.Schema()) diff --git a/internal/server/server.go b/internal/server/server.go index ee36028b19..d9ba989b65 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -41,7 +41,6 @@ type Options struct { SessionDuration time.Duration SessionExtensionDeadline time.Duration - DBFile string DBEncryptionKey string DBEncryptionKeyProvider string DBHost string @@ -135,7 +134,7 @@ func New(options Options) (*Server, error) { return nil, fmt.Errorf("key config: %w", err) } - driver, err := server.getDatabaseDriver() + driver, err := getDatabaseDriver(options, server.secrets) if err != nil { return nil, fmt.Errorf("driver: %w", err) } @@ -308,33 +307,31 @@ type routine struct { stop func() } -func (s *Server) getDatabaseDriver() (gorm.Dialector, error) { - pgDSN, err := s.getPostgresConnectionString() - if err != nil { +func getDatabaseDriver(options Options, secretStorage map[string]secrets.SecretStorage) (gorm.Dialector, error) { + pgDSN, err := getPostgresConnectionString(options, secretStorage) + switch { + case err != nil: return nil, fmt.Errorf("postgres: %w", err) + case pgDSN == "": + return nil, fmt.Errorf("missing postgreSQL connection options") } - - if pgDSN != "" { - return postgres.Open(pgDSN), nil - } - - return data.NewSQLiteDriver(s.options.DBFile) + return postgres.Open(pgDSN), nil } // getPostgresConnectionString parses postgres configuration options and returns the connection string -func (s *Server) getPostgresConnectionString() (string, error) { +func getPostgresConnectionString(options Options, secretStorage map[string]secrets.SecretStorage) (string, error) { var pgConn strings.Builder - pgConn.WriteString(s.options.DBConnectionString) + pgConn.WriteString(options.DBConnectionString) - if s.options.DBHost != "" { + if options.DBHost != "" { // config has separate postgres parameters set, combine them into a connection DSN now - fmt.Fprintf(&pgConn, "host=%s ", s.options.DBHost) + fmt.Fprintf(&pgConn, "host=%s ", options.DBHost) - if s.options.DBUsername != "" { - fmt.Fprintf(&pgConn, "user=%s ", s.options.DBUsername) + if options.DBUsername != "" { + fmt.Fprintf(&pgConn, "user=%s ", options.DBUsername) - if s.options.DBPassword != "" { - pass, err := secrets.GetSecret(s.options.DBPassword, s.secrets) + if options.DBPassword != "" { + pass, err := secrets.GetSecret(options.DBPassword, secretStorage) if err != nil { return "", fmt.Errorf("postgres secret: %w", err) } @@ -343,16 +340,16 @@ func (s *Server) getPostgresConnectionString() (string, error) { } } - if s.options.DBPort > 0 { - fmt.Fprintf(&pgConn, "port=%d ", s.options.DBPort) + if options.DBPort > 0 { + fmt.Fprintf(&pgConn, "port=%d ", options.DBPort) } - if s.options.DBName != "" { - fmt.Fprintf(&pgConn, "dbname=%s ", s.options.DBName) + if options.DBName != "" { + fmt.Fprintf(&pgConn, "dbname=%s ", options.DBName) } - if s.options.DBParameters != "" { - fmt.Fprint(&pgConn, s.options.DBParameters) + if options.DBParameters != "" { + fmt.Fprint(&pgConn, options.DBParameters) } } diff --git a/internal/server/server_test.go b/internal/server/server_test.go index 17d2dcdd63..4ca9597d26 100644 --- a/internal/server/server_test.go +++ b/internal/server/server_test.go @@ -52,48 +52,38 @@ func setupServer(t *testing.T, ops ...func(*testing.T, *Options)) *Server { func TestGetPostgresConnectionURL(t *testing.T) { logging.PatchLogger(t, zerolog.NewTestWriter(t)) - r := newServer(Options{}) - - f := secrets.NewPlainSecretProviderFromConfig(secrets.GenericConfig{}) - r.secrets["plaintext"] = f + storage := map[string]secrets.SecretStorage{ + "plaintext": secrets.NewPlainSecretProviderFromConfig(secrets.GenericConfig{}), + } + options := Options{} - url, err := r.getPostgresConnectionString() + url, err := getPostgresConnectionString(options, storage) assert.NilError(t, err) - assert.Assert(t, is.Len(url, 0)) - r.options.DBHost = "localhost" - - url, err = r.getPostgresConnectionString() + options.DBHost = "localhost" + url, err = getPostgresConnectionString(options, storage) assert.NilError(t, err) - assert.Equal(t, "host=localhost", url) - r.options.DBPort = 5432 - - url, err = r.getPostgresConnectionString() + options.DBPort = 5432 + url, err = getPostgresConnectionString(options, storage) assert.NilError(t, err) assert.Equal(t, "host=localhost port=5432", url) - r.options.DBUsername = "user" - - url, err = r.getPostgresConnectionString() + options.DBUsername = "user" + url, err = getPostgresConnectionString(options, storage) assert.NilError(t, err) - assert.Equal(t, "host=localhost user=user port=5432", url) - r.options.DBPassword = "plaintext:secret" - - url, err = r.getPostgresConnectionString() + options.DBPassword = "plaintext:secret" + url, err = getPostgresConnectionString(options, storage) assert.NilError(t, err) - assert.Equal(t, "host=localhost user=user password=secret port=5432", url) - r.options.DBName = "postgres" - - url, err = r.getPostgresConnectionString() + options.DBName = "postgres" + url, err = getPostgresConnectionString(options, storage) assert.NilError(t, err) - assert.Equal(t, "host=localhost user=user password=secret port=5432 dbname=postgres", url) } @@ -115,11 +105,8 @@ func TestServer_Run(t *testing.T) { }, } - if driver := database.PostgresDriver(t, "_server_run"); driver != nil { - opts.DBConnectionString = driver.DSN - } else { - opts.DBFile = filepath.Join(dir, "sqlite3.db") - } + driver := database.PostgresDriver(t, "_server_run") + opts.DBConnectionString = driver.DSN srv, err := New(opts) assert.NilError(t, err) @@ -216,11 +203,8 @@ func TestServer_Run_UIProxy(t *testing.T) { } assert.NilError(t, opts.UI.ProxyURL.Set(uiSrv.URL)) - if driver := database.PostgresDriver(t, "_server_run"); driver != nil { - opts.DBConnectionString = driver.DSN - } else { - opts.DBFile = filepath.Join(dir, "sqlite3.db") - } + driver := database.PostgresDriver(t, "_server_run") + opts.DBConnectionString = driver.DSN srv, err := New(opts) assert.NilError(t, err) diff --git a/internal/testing/database/postgres.go b/internal/testing/database/postgres.go index 28fcbbec8e..41394b5d9e 100644 --- a/internal/testing/database/postgres.go +++ b/internal/testing/database/postgres.go @@ -12,9 +12,13 @@ import ( type TestingT interface { assert.TestingT Cleanup(func()) + Fatal(...any) + Skip(...any) Helper() } +var isEnvironmentCI = os.Getenv("CI") != "" + // PostgresDriver returns a driver for connecting to postgres based on the // POSTGRESQL_CONNECTION environment variable. The value should be a postgres // connection string, see @@ -25,8 +29,11 @@ type TestingT interface { func PostgresDriver(t TestingT, schemaSuffix string) *Driver { t.Helper() pgConn, ok := os.LookupEnv("POSTGRESQL_CONNECTION") - if !ok { - return nil + switch { + case !ok && isEnvironmentCI: + t.Fatal("CI must test all drivers, set POSTGRESQL_CONNECTION") + case !ok: + t.Skip("Set POSTGRESQL_CONNECTION to test against postgresql") } suffix := strings.NewReplacer("--", "", ";", "", "/", "").Replace(schemaSuffix)