From 053b1fd49efec6add765ff93721b457adcd49c48 Mon Sep 17 00:00:00 2001 From: Michael Fridman Date: Sat, 19 Oct 2024 12:34:32 -0400 Subject: [PATCH] testing: replace check with stretchr/testify (#842) --- database/store_test.go | 124 ++--- globals_test.go | 234 ++++----- goose_cli_test.go | 58 +-- goose_embed_test.go | 30 +- internal/check/check.go | 86 ---- .../migrationstats/migrationstats_test.go | 38 +- internal/sqlparser/parse_test.go | 29 +- internal/sqlparser/parser_test.go | 59 +-- .../integration/postgres_locking_test.go | 69 ++- migrate_test.go | 110 ++-- provider_collect_test.go | 158 +++--- provider_options_test.go | 26 +- provider_run_test.go | 474 +++++++++--------- provider_test.go | 22 +- .../error/gomigrations_error_test.go | 50 +- tests/gomigrations/register/register_test.go | 82 +-- .../success/gomigrations_success_test.go | 36 +- 17 files changed, 799 insertions(+), 886 deletions(-) delete mode 100644 internal/check/check.go diff --git a/database/store_test.go b/database/store_test.go index d63f7f818..8e9d6e89f 100644 --- a/database/store_test.go +++ b/database/store_test.go @@ -8,7 +8,7 @@ import ( "testing" "github.com/pressly/goose/v3/database" - "github.com/pressly/goose/v3/internal/check" + "github.com/stretchr/testify/require" "go.uber.org/multierr" "modernc.org/sqlite" ) @@ -22,47 +22,47 @@ func TestDialectStore(t *testing.T) { t.Run("invalid", func(t *testing.T) { // Test empty table name. _, err := database.NewStore(database.DialectSQLite3, "") - check.HasError(t, err) + require.Error(t, err) // Test unknown dialect. _, err = database.NewStore("unknown-dialect", "foo") - check.HasError(t, err) + require.Error(t, err) // Test empty dialect. _, err = database.NewStore("", "foo") - check.HasError(t, err) + require.Error(t, err) }) // Test generic behavior. t.Run("sqlite3", func(t *testing.T) { db, err := sql.Open("sqlite", ":memory:") - check.NoError(t, err) + require.NoError(t, err) testStore(context.Background(), t, database.DialectSQLite3, db, func(t *testing.T, err error) { var sqliteErr *sqlite.Error ok := errors.As(err, &sqliteErr) - check.Bool(t, ok, true) - check.Equal(t, sqliteErr.Code(), 1) // Generic error (SQLITE_ERROR) - check.Contains(t, sqliteErr.Error(), "table test_goose_db_version already exists") + require.True(t, ok) + require.Equal(t, sqliteErr.Code(), 1) // Generic error (SQLITE_ERROR) + require.Contains(t, sqliteErr.Error(), "table test_goose_db_version already exists") }) }) t.Run("ListMigrations", func(t *testing.T) { dir := t.TempDir() db, err := sql.Open("sqlite", filepath.Join(dir, "sql_embed.db")) - check.NoError(t, err) + require.NoError(t, err) store, err := database.NewStore(database.DialectSQLite3, "foo") - check.NoError(t, err) + require.NoError(t, err) err = store.CreateVersionTable(context.Background(), db) - check.NoError(t, err) + require.NoError(t, err) insert := func(db *sql.DB, version int64) error { return store.Insert(context.Background(), db, database.InsertRequest{Version: version}) } - check.NoError(t, insert(db, 1)) - check.NoError(t, insert(db, 3)) - check.NoError(t, insert(db, 2)) + require.NoError(t, insert(db, 1)) + require.NoError(t, insert(db, 3)) + require.NoError(t, insert(db, 2)) res, err := store.ListMigrations(context.Background(), db) - check.NoError(t, err) - check.Number(t, len(res), 3) + require.NoError(t, err) + require.Equal(t, len(res), 3) // Check versions are in descending order: [2, 3, 1] - check.Number(t, res[0].Version, 2) - check.Number(t, res[1].Version, 3) - check.Number(t, res[2].Version, 1) + require.EqualValues(t, res[0].Version, 2) + require.EqualValues(t, res[1].Version, 3) + require.EqualValues(t, res[2].Version, 1) }) } @@ -81,95 +81,95 @@ func testStore( tablename = "test_goose_db_version" ) store, err := database.NewStore(d, tablename) - check.NoError(t, err) + require.NoError(t, err) // Create the version table. err = runTx(ctx, db, func(tx *sql.Tx) error { return store.CreateVersionTable(ctx, tx) }) - check.NoError(t, err) + require.NoError(t, err) // Create the version table again. This should fail. err = runTx(ctx, db, func(tx *sql.Tx) error { return store.CreateVersionTable(ctx, tx) }) - check.HasError(t, err) + require.Error(t, err) if alreadyExists != nil { alreadyExists(t, err) } // Get the latest version. There should be none. _, err = store.GetLatestVersion(ctx, db) - check.IsError(t, err, database.ErrVersionNotFound) + require.ErrorIs(t, err, database.ErrVersionNotFound) // List migrations. There should be none. err = runConn(ctx, db, func(conn *sql.Conn) error { res, err := store.ListMigrations(ctx, conn) - check.NoError(t, err) - check.Number(t, len(res), 0) + require.NoError(t, err) + require.Equal(t, len(res), 0) return nil }) - check.NoError(t, err) + require.NoError(t, err) // Insert 5 migrations in addition to the zero migration. for i := 0; i < 6; i++ { err = runConn(ctx, db, func(conn *sql.Conn) error { err := store.Insert(ctx, conn, database.InsertRequest{Version: int64(i)}) - check.NoError(t, err) + require.NoError(t, err) latest, err := store.GetLatestVersion(ctx, conn) - check.NoError(t, err) - check.Number(t, latest, int64(i)) + require.NoError(t, err) + require.Equal(t, latest, int64(i)) return nil }) - check.NoError(t, err) + require.NoError(t, err) } // List migrations. There should be 6. err = runConn(ctx, db, func(conn *sql.Conn) error { res, err := store.ListMigrations(ctx, conn) - check.NoError(t, err) - check.Number(t, len(res), 6) + require.NoError(t, err) + require.Equal(t, len(res), 6) // Check versions are in descending order. for i := 0; i < 6; i++ { - check.Number(t, res[i].Version, 5-i) + require.EqualValues(t, res[i].Version, 5-i) } return nil }) - check.NoError(t, err) + require.NoError(t, err) // Delete 3 migrations backwards for i := 5; i >= 3; i-- { err = runConn(ctx, db, func(conn *sql.Conn) error { err := store.Delete(ctx, conn, int64(i)) - check.NoError(t, err) + require.NoError(t, err) latest, err := store.GetLatestVersion(ctx, conn) - check.NoError(t, err) - check.Number(t, latest, int64(i-1)) + require.NoError(t, err) + require.Equal(t, latest, int64(i-1)) return nil }) - check.NoError(t, err) + require.NoError(t, err) } // List migrations. There should be 3. err = runConn(ctx, db, func(conn *sql.Conn) error { res, err := store.ListMigrations(ctx, conn) - check.NoError(t, err) - check.Number(t, len(res), 3) + require.NoError(t, err) + require.Equal(t, len(res), 3) // Check that the remaining versions are in descending order. for i := 0; i < 3; i++ { - check.Number(t, res[i].Version, 2-i) + require.EqualValues(t, res[i].Version, 2-i) } return nil }) - check.NoError(t, err) + require.NoError(t, err) // Get remaining migrations one by one. for i := 0; i < 3; i++ { err = runConn(ctx, db, func(conn *sql.Conn) error { res, err := store.GetMigration(ctx, conn, int64(i)) - check.NoError(t, err) - check.Equal(t, res.IsApplied, true) - check.Equal(t, res.Timestamp.IsZero(), false) + require.NoError(t, err) + require.Equal(t, res.IsApplied, true) + require.Equal(t, res.Timestamp.IsZero(), false) return nil }) - check.NoError(t, err) + require.NoError(t, err) } // Delete remaining migrations one by one and use all 3 connection types: @@ -177,46 +177,46 @@ func testStore( // 1. *sql.Tx err = runTx(ctx, db, func(tx *sql.Tx) error { err := store.Delete(ctx, tx, 2) - check.NoError(t, err) + require.NoError(t, err) latest, err := store.GetLatestVersion(ctx, tx) - check.NoError(t, err) - check.Number(t, latest, 1) + require.NoError(t, err) + require.EqualValues(t, latest, 1) return nil }) - check.NoError(t, err) + require.NoError(t, err) // 2. *sql.Conn err = runConn(ctx, db, func(conn *sql.Conn) error { err := store.Delete(ctx, conn, 1) - check.NoError(t, err) + require.NoError(t, err) latest, err := store.GetLatestVersion(ctx, conn) - check.NoError(t, err) - check.Number(t, latest, 0) + require.NoError(t, err) + require.EqualValues(t, latest, 0) return nil }) - check.NoError(t, err) + require.NoError(t, err) // 3. *sql.DB err = store.Delete(ctx, db, 0) - check.NoError(t, err) + require.NoError(t, err) _, err = store.GetLatestVersion(ctx, db) - check.IsError(t, err, database.ErrVersionNotFound) + require.ErrorIs(t, err, database.ErrVersionNotFound) // List migrations. There should be none. err = runConn(ctx, db, func(conn *sql.Conn) error { res, err := store.ListMigrations(ctx, conn) - check.NoError(t, err) - check.Number(t, len(res), 0) + require.NoError(t, err) + require.Equal(t, len(res), 0) return nil }) - check.NoError(t, err) + require.NoError(t, err) // Try to get a migration that does not exist. err = runConn(ctx, db, func(conn *sql.Conn) error { _, err := store.GetMigration(ctx, conn, 0) - check.HasError(t, err) - check.Bool(t, errors.Is(err, database.ErrVersionNotFound), true) + require.Error(t, err) + require.True(t, errors.Is(err, database.ErrVersionNotFound)) return nil }) - check.NoError(t, err) + require.NoError(t, err) } func runTx(ctx context.Context, db *sql.DB, fn func(*sql.Tx) error) (retErr error) { diff --git a/globals_test.go b/globals_test.go index 5ea5a0024..638e830ba 100644 --- a/globals_test.go +++ b/globals_test.go @@ -5,31 +5,31 @@ import ( "database/sql" "testing" - "github.com/pressly/goose/v3/internal/check" + "github.com/stretchr/testify/require" ) func TestNewGoMigration(t *testing.T) { t.Run("valid_both_nil", func(t *testing.T) { m := NewGoMigration(1, nil, nil) // roundtrip - check.Equal(t, m.Version, int64(1)) - check.Equal(t, m.Type, TypeGo) - check.Equal(t, m.Registered, true) - check.Equal(t, m.Next, int64(-1)) - check.Equal(t, m.Previous, int64(-1)) - check.Equal(t, m.Source, "") - check.Bool(t, m.UpFnNoTxContext == nil, true) - check.Bool(t, m.DownFnNoTxContext == nil, true) - check.Bool(t, m.UpFnContext == nil, true) - check.Bool(t, m.DownFnContext == nil, true) - check.Bool(t, m.UpFn == nil, true) - check.Bool(t, m.DownFn == nil, true) - check.Bool(t, m.UpFnNoTx == nil, true) - check.Bool(t, m.DownFnNoTx == nil, true) - check.Bool(t, m.goUp != nil, true) - check.Bool(t, m.goDown != nil, true) - check.Equal(t, m.goUp.Mode, TransactionEnabled) - check.Equal(t, m.goDown.Mode, TransactionEnabled) + require.Equal(t, m.Version, int64(1)) + require.Equal(t, m.Type, TypeGo) + require.Equal(t, m.Registered, true) + require.Equal(t, m.Next, int64(-1)) + require.Equal(t, m.Previous, int64(-1)) + require.Equal(t, m.Source, "") + require.Nil(t, m.UpFnNoTxContext) + require.Nil(t, m.DownFnNoTxContext) + require.Nil(t, m.UpFnContext) + require.Nil(t, m.DownFnContext) + require.Nil(t, m.UpFn) + require.Nil(t, m.DownFn) + require.Nil(t, m.UpFnNoTx) + require.Nil(t, m.DownFnNoTx) + require.True(t, m.goUp != nil) + require.True(t, m.goDown != nil) + require.Equal(t, m.goUp.Mode, TransactionEnabled) + require.Equal(t, m.goDown.Mode, TransactionEnabled) }) t.Run("all_set", func(t *testing.T) { // This will eventually be an error when registering migrations. @@ -39,14 +39,14 @@ func TestNewGoMigration(t *testing.T) { &GoFunc{RunTx: func(context.Context, *sql.Tx) error { return nil }, RunDB: func(context.Context, *sql.DB) error { return nil }}, ) // check only functions - check.Bool(t, m.UpFn != nil, true) - check.Bool(t, m.UpFnContext != nil, true) - check.Bool(t, m.UpFnNoTx != nil, true) - check.Bool(t, m.UpFnNoTxContext != nil, true) - check.Bool(t, m.DownFn != nil, true) - check.Bool(t, m.DownFnContext != nil, true) - check.Bool(t, m.DownFnNoTx != nil, true) - check.Bool(t, m.DownFnNoTxContext != nil, true) + require.True(t, m.UpFn != nil) + require.True(t, m.UpFnContext != nil) + require.True(t, m.UpFnNoTx != nil) + require.True(t, m.UpFnNoTxContext != nil) + require.True(t, m.DownFn != nil) + require.True(t, m.DownFnContext != nil) + require.True(t, m.DownFnNoTx != nil) + require.True(t, m.DownFnNoTxContext != nil) }) } @@ -59,67 +59,67 @@ func TestTransactionMode(t *testing.T) { err := SetGlobalMigrations( NewGoMigration(1, &GoFunc{RunTx: runTx, RunDB: runDB}, nil), // cannot specify both ) - check.HasError(t, err) - check.Contains(t, err.Error(), "up function: must specify exactly one of RunTx or RunDB") + require.Error(t, err) + require.Contains(t, err.Error(), "up function: must specify exactly one of RunTx or RunDB") err = SetGlobalMigrations( NewGoMigration(1, nil, &GoFunc{RunTx: runTx, RunDB: runDB}), // cannot specify both ) - check.HasError(t, err) - check.Contains(t, err.Error(), "down function: must specify exactly one of RunTx or RunDB") + require.Error(t, err) + require.Contains(t, err.Error(), "down function: must specify exactly one of RunTx or RunDB") err = SetGlobalMigrations( NewGoMigration(1, &GoFunc{RunTx: runTx, Mode: TransactionDisabled}, nil), // invalid explicit mode tx ) - check.HasError(t, err) - check.Contains(t, err.Error(), "up function: transaction mode must be enabled or unspecified when RunTx is set") + require.Error(t, err) + require.Contains(t, err.Error(), "up function: transaction mode must be enabled or unspecified when RunTx is set") err = SetGlobalMigrations( NewGoMigration(1, nil, &GoFunc{RunTx: runTx, Mode: TransactionDisabled}), // invalid explicit mode tx ) - check.HasError(t, err) - check.Contains(t, err.Error(), "down function: transaction mode must be enabled or unspecified when RunTx is set") + require.Error(t, err) + require.Contains(t, err.Error(), "down function: transaction mode must be enabled or unspecified when RunTx is set") err = SetGlobalMigrations( NewGoMigration(1, &GoFunc{RunDB: runDB, Mode: TransactionEnabled}, nil), // invalid explicit mode no-tx ) - check.HasError(t, err) - check.Contains(t, err.Error(), "up function: transaction mode must be disabled or unspecified when RunDB is set") + require.Error(t, err) + require.Contains(t, err.Error(), "up function: transaction mode must be disabled or unspecified when RunDB is set") err = SetGlobalMigrations( NewGoMigration(1, nil, &GoFunc{RunDB: runDB, Mode: TransactionEnabled}), // invalid explicit mode no-tx ) - check.HasError(t, err) - check.Contains(t, err.Error(), "down function: transaction mode must be disabled or unspecified when RunDB is set") + require.Error(t, err) + require.Contains(t, err.Error(), "down function: transaction mode must be disabled or unspecified when RunDB is set") t.Run("default_mode", func(t *testing.T) { t.Cleanup(ResetGlobalMigrations) m := NewGoMigration(1, nil, nil) err = SetGlobalMigrations(m) - check.NoError(t, err) - check.Number(t, len(registeredGoMigrations), 1) + require.NoError(t, err) + require.Equal(t, len(registeredGoMigrations), 1) registered := registeredGoMigrations[1] - check.Bool(t, registered.goUp != nil, true) - check.Bool(t, registered.goDown != nil, true) - check.Equal(t, registered.goUp.Mode, TransactionEnabled) - check.Equal(t, registered.goDown.Mode, TransactionEnabled) + require.True(t, registered.goUp != nil) + require.True(t, registered.goDown != nil) + require.Equal(t, registered.goUp.Mode, TransactionEnabled) + require.Equal(t, registered.goDown.Mode, TransactionEnabled) migration2 := NewGoMigration(2, nil, nil) // reset so we can check the default is set migration2.goUp.Mode, migration2.goDown.Mode = 0, 0 err = SetGlobalMigrations(migration2) - check.HasError(t, err) - check.Contains(t, err.Error(), "invalid go migration: up function: invalid mode: 0") + require.Error(t, err) + require.Contains(t, err.Error(), "invalid go migration: up function: invalid mode: 0") migration3 := NewGoMigration(3, nil, nil) // reset so we can check the default is set migration3.goDown.Mode = 0 err = SetGlobalMigrations(migration3) - check.HasError(t, err) - check.Contains(t, err.Error(), "invalid go migration: down function: invalid mode: 0") + require.Error(t, err) + require.Contains(t, err.Error(), "invalid go migration: down function: invalid mode: 0") }) t.Run("unknown_mode", func(t *testing.T) { m := NewGoMigration(1, nil, nil) m.goUp.Mode, m.goDown.Mode = 3, 3 // reset to default err := SetGlobalMigrations(m) - check.HasError(t, err) - check.Contains(t, err.Error(), "invalid mode: 3") + require.Error(t, err) + require.Contains(t, err.Error(), "invalid mode: 3") }) } @@ -131,12 +131,12 @@ func TestLegacyFunctions(t *testing.T) { assertMigration := func(t *testing.T, m *Migration, version int64) { t.Helper() - check.Equal(t, m.Version, version) - check.Equal(t, m.Type, TypeGo) - check.Equal(t, m.Registered, true) - check.Equal(t, m.Next, int64(-1)) - check.Equal(t, m.Previous, int64(-1)) - check.Equal(t, m.Source, "") + require.Equal(t, m.Version, version) + require.Equal(t, m.Type, TypeGo) + require.Equal(t, m.Registered, true) + require.Equal(t, m.Next, int64(-1)) + require.Equal(t, m.Previous, int64(-1)) + require.Equal(t, m.Source, "") } t.Run("all_tx", func(t *testing.T) { @@ -144,46 +144,46 @@ func TestLegacyFunctions(t *testing.T) { err := SetGlobalMigrations( NewGoMigration(1, &GoFunc{RunTx: runTx}, &GoFunc{RunTx: runTx}), ) - check.NoError(t, err) - check.Number(t, len(registeredGoMigrations), 1) + require.NoError(t, err) + require.Equal(t, len(registeredGoMigrations), 1) m := registeredGoMigrations[1] assertMigration(t, m, 1) // Legacy functions. - check.Bool(t, m.UpFnNoTxContext == nil, true) - check.Bool(t, m.DownFnNoTxContext == nil, true) + require.Nil(t, m.UpFnNoTxContext) + require.Nil(t, m.DownFnNoTxContext) // Context-aware functions. - check.Bool(t, m.goUp == nil, false) - check.Bool(t, m.UpFnContext == nil, false) - check.Bool(t, m.goDown == nil, false) - check.Bool(t, m.DownFnContext == nil, false) + require.NotNil(t, m.goUp) + require.NotNil(t, m.UpFnContext) + require.NotNil(t, m.goDown) + require.NotNil(t, m.DownFnContext) // Always nil - check.Bool(t, m.UpFn == nil, false) - check.Bool(t, m.DownFn == nil, false) - check.Bool(t, m.UpFnNoTx == nil, true) - check.Bool(t, m.DownFnNoTx == nil, true) + require.NotNil(t, m.UpFn) + require.NotNil(t, m.DownFn) + require.Nil(t, m.UpFnNoTx) + require.Nil(t, m.DownFnNoTx) }) t.Run("all_db", func(t *testing.T) { t.Cleanup(ResetGlobalMigrations) err := SetGlobalMigrations( NewGoMigration(2, &GoFunc{RunDB: runDB}, &GoFunc{RunDB: runDB}), ) - check.NoError(t, err) - check.Number(t, len(registeredGoMigrations), 1) + require.NoError(t, err) + require.Equal(t, len(registeredGoMigrations), 1) m := registeredGoMigrations[2] assertMigration(t, m, 2) // Legacy functions. - check.Bool(t, m.UpFnNoTxContext == nil, false) - check.Bool(t, m.goUp == nil, false) - check.Bool(t, m.DownFnNoTxContext == nil, false) - check.Bool(t, m.goDown == nil, false) + require.NotNil(t, m.UpFnNoTxContext) + require.NotNil(t, m.goUp) + require.NotNil(t, m.DownFnNoTxContext) + require.NotNil(t, m.goDown) // Context-aware functions. - check.Bool(t, m.UpFnContext == nil, true) - check.Bool(t, m.DownFnContext == nil, true) + require.Nil(t, m.UpFnContext) + require.Nil(t, m.DownFnContext) // Always nil - check.Bool(t, m.UpFn == nil, true) - check.Bool(t, m.DownFn == nil, true) - check.Bool(t, m.UpFnNoTx == nil, false) - check.Bool(t, m.DownFnNoTx == nil, false) + require.Nil(t, m.UpFn) + require.Nil(t, m.DownFn) + require.NotNil(t, m.UpFnNoTx) + require.NotNil(t, m.DownFnNoTx) }) } @@ -195,91 +195,91 @@ func TestGlobalRegister(t *testing.T) { // Success. err := SetGlobalMigrations([]*Migration{}...) - check.NoError(t, err) + require.NoError(t, err) err = SetGlobalMigrations( NewGoMigration(1, &GoFunc{RunTx: runTx}, nil), ) - check.NoError(t, err) + require.NoError(t, err) // Try to register the same migration again. err = SetGlobalMigrations( NewGoMigration(1, &GoFunc{RunTx: runTx}, nil), ) - check.HasError(t, err) - check.Contains(t, err.Error(), "go migration with version 1 already registered") + require.Error(t, err) + require.Contains(t, err.Error(), "go migration with version 1 already registered") err = SetGlobalMigrations(&Migration{Registered: true, Version: 2, Type: TypeGo}) - check.HasError(t, err) - check.Contains(t, err.Error(), "must use NewGoMigration to construct migrations") + require.Error(t, err) + require.Contains(t, err.Error(), "must use NewGoMigration to construct migrations") } func TestCheckMigration(t *testing.T) { // Success. err := checkGoMigration(NewGoMigration(1, nil, nil)) - check.NoError(t, err) + require.NoError(t, err) // Failures. err = checkGoMigration(&Migration{}) - check.HasError(t, err) - check.Contains(t, err.Error(), "must use NewGoMigration to construct migrations") + require.Error(t, err) + require.Contains(t, err.Error(), "must use NewGoMigration to construct migrations") err = checkGoMigration(&Migration{construct: true}) - check.HasError(t, err) - check.Contains(t, err.Error(), "must be registered") + require.Error(t, err) + require.Contains(t, err.Error(), "must be registered") err = checkGoMigration(&Migration{construct: true, Registered: true}) - check.HasError(t, err) - check.Contains(t, err.Error(), `type must be "go"`) + require.Error(t, err) + require.Contains(t, err.Error(), `type must be "go"`) err = checkGoMigration(&Migration{construct: true, Registered: true, Type: TypeGo}) - check.HasError(t, err) - check.Contains(t, err.Error(), "version must be greater than zero") + require.Error(t, err) + require.Contains(t, err.Error(), "version must be greater than zero") err = checkGoMigration(&Migration{construct: true, Registered: true, Type: TypeGo, Version: 1, goUp: &GoFunc{}, goDown: &GoFunc{}}) - check.HasError(t, err) - check.Contains(t, err.Error(), "up function: invalid mode: 0") + require.Error(t, err) + require.Contains(t, err.Error(), "up function: invalid mode: 0") err = checkGoMigration(&Migration{construct: true, Registered: true, Type: TypeGo, Version: 1, goUp: &GoFunc{Mode: TransactionEnabled}, goDown: &GoFunc{}}) - check.HasError(t, err) - check.Contains(t, err.Error(), "down function: invalid mode: 0") + require.Error(t, err) + require.Contains(t, err.Error(), "down function: invalid mode: 0") // Success. err = checkGoMigration(&Migration{construct: true, Registered: true, Type: TypeGo, Version: 1, goUp: &GoFunc{Mode: TransactionEnabled}, goDown: &GoFunc{Mode: TransactionEnabled}}) - check.NoError(t, err) + require.NoError(t, err) // Failures. err = checkGoMigration(&Migration{construct: true, Registered: true, Type: TypeGo, Version: 1, Source: "foo"}) - check.HasError(t, err) - check.Contains(t, err.Error(), `source must have .go extension: "foo"`) + require.Error(t, err) + require.Contains(t, err.Error(), `source must have .go extension: "foo"`) err = checkGoMigration(&Migration{construct: true, Registered: true, Type: TypeGo, Version: 1, Source: "foo.go"}) - check.HasError(t, err) - check.Contains(t, err.Error(), `no filename separator '_' found`) + require.Error(t, err) + require.Contains(t, err.Error(), `no filename separator '_' found`) err = checkGoMigration(&Migration{construct: true, Registered: true, Type: TypeGo, Version: 2, Source: "00001_foo.sql"}) - check.HasError(t, err) - check.Contains(t, err.Error(), `source must have .go extension: "00001_foo.sql"`) + require.Error(t, err) + require.Contains(t, err.Error(), `source must have .go extension: "00001_foo.sql"`) err = checkGoMigration(&Migration{construct: true, Registered: true, Type: TypeGo, Version: 2, Source: "00001_foo.go"}) - check.HasError(t, err) - check.Contains(t, err.Error(), `version:2 does not match numeric component in source "00001_foo.go"`) + require.Error(t, err) + require.Contains(t, err.Error(), `version:2 does not match numeric component in source "00001_foo.go"`) err = checkGoMigration(&Migration{construct: true, Registered: true, Type: TypeGo, Version: 1, UpFnContext: func(context.Context, *sql.Tx) error { return nil }, UpFnNoTxContext: func(context.Context, *sql.DB) error { return nil }, goUp: &GoFunc{Mode: TransactionEnabled}, goDown: &GoFunc{Mode: TransactionEnabled}, }) - check.HasError(t, err) - check.Contains(t, err.Error(), "must specify exactly one of UpFnContext or UpFnNoTxContext") + require.Error(t, err) + require.Contains(t, err.Error(), "must specify exactly one of UpFnContext or UpFnNoTxContext") err = checkGoMigration(&Migration{construct: true, Registered: true, Type: TypeGo, Version: 1, DownFnContext: func(context.Context, *sql.Tx) error { return nil }, DownFnNoTxContext: func(context.Context, *sql.DB) error { return nil }, goUp: &GoFunc{Mode: TransactionEnabled}, goDown: &GoFunc{Mode: TransactionEnabled}, }) - check.HasError(t, err) - check.Contains(t, err.Error(), "must specify exactly one of DownFnContext or DownFnNoTxContext") + require.Error(t, err) + require.Contains(t, err.Error(), "must specify exactly one of DownFnContext or DownFnNoTxContext") err = checkGoMigration(&Migration{construct: true, Registered: true, Type: TypeGo, Version: 1, UpFn: func(*sql.Tx) error { return nil }, UpFnNoTx: func(*sql.DB) error { return nil }, goUp: &GoFunc{Mode: TransactionEnabled}, goDown: &GoFunc{Mode: TransactionEnabled}, }) - check.HasError(t, err) - check.Contains(t, err.Error(), "must specify exactly one of UpFn or UpFnNoTx") + require.Error(t, err) + require.Contains(t, err.Error(), "must specify exactly one of UpFn or UpFnNoTx") err = checkGoMigration(&Migration{construct: true, Registered: true, Type: TypeGo, Version: 1, DownFn: func(*sql.Tx) error { return nil }, DownFnNoTx: func(*sql.DB) error { return nil }, goUp: &GoFunc{Mode: TransactionEnabled}, goDown: &GoFunc{Mode: TransactionEnabled}, }) - check.HasError(t, err) - check.Contains(t, err.Error(), "must specify exactly one of DownFn or DownFnNoTx") + require.Error(t, err) + require.Contains(t, err.Error(), "must specify exactly one of DownFn or DownFnNoTx") } diff --git a/goose_cli_test.go b/goose_cli_test.go index 020e48023..8b22b4a09 100644 --- a/goose_cli_test.go +++ b/goose_cli_test.go @@ -9,7 +9,7 @@ import ( "testing" "time" - "github.com/pressly/goose/v3/internal/check" + "github.com/stretchr/testify/require" _ "modernc.org/sqlite" ) @@ -23,8 +23,8 @@ func TestFullBinary(t *testing.T) { t.Parallel() cli := buildGooseCLI(t, false) out, err := cli.run("--version") - check.NoError(t, err) - check.Equal(t, out, "goose version: "+gooseTestBinaryVersion+"\n") + require.NoError(t, err) + require.Equal(t, out, "goose version: "+gooseTestBinaryVersion+"\n") } func TestLiteBinary(t *testing.T) { @@ -34,8 +34,8 @@ func TestLiteBinary(t *testing.T) { t.Run("binary_version", func(t *testing.T) { t.Parallel() out, err := cli.run("--version") - check.NoError(t, err) - check.Equal(t, out, "goose version: "+gooseTestBinaryVersion+"\n") + require.NoError(t, err) + require.Equal(t, out, "goose version: "+gooseTestBinaryVersion+"\n") }) t.Run("default_binary", func(t *testing.T) { t.Parallel() @@ -55,8 +55,8 @@ func TestLiteBinary(t *testing.T) { } for _, c := range commands { out, err := cli.run("-dir=testdata/migrations", "sqlite3", filepath.Join(dir, "sql.db"), c.cmd) - check.NoError(t, err) - check.Contains(t, out, c.out) + require.NoError(t, err) + require.Contains(t, out, c.out) } }) t.Run("gh_issue_532", func(t *testing.T) { @@ -65,13 +65,13 @@ func TestLiteBinary(t *testing.T) { dir := t.TempDir() total := countSQLFiles(t, "testdata/migrations") _, err := cli.run("-dir=testdata/migrations", "sqlite3", filepath.Join(dir, "sql.db"), "up") - check.NoError(t, err) + require.NoError(t, err) out, err := cli.run("-dir=testdata/migrations", "sqlite3", filepath.Join(dir, "sql.db"), "up") - check.NoError(t, err) - check.Contains(t, out, "goose: no migrations to run. current version: "+strconv.Itoa(total)) + require.NoError(t, err) + require.Contains(t, out, "goose: no migrations to run. current version: "+strconv.Itoa(total)) out, err = cli.run("-dir=testdata/migrations", "sqlite3", filepath.Join(dir, "sql.db"), "version") - check.NoError(t, err) - check.Contains(t, out, "goose: version "+strconv.Itoa(total)) + require.NoError(t, err) + require.Contains(t, out, "goose: version "+strconv.Itoa(total)) }) t.Run("gh_issue_293", func(t *testing.T) { // https://github.com/pressly/goose/issues/293 @@ -92,8 +92,8 @@ func TestLiteBinary(t *testing.T) { } for _, c := range commands { out, err := cli.run("-dir=testdata/migrations", "sqlite3", filepath.Join(dir, "sql.db"), c.cmd) - check.NoError(t, err) - check.Contains(t, out, c.out) + require.NoError(t, err) + require.Contains(t, out, c.out) } }) t.Run("gh_issue_336", func(t *testing.T) { @@ -101,8 +101,8 @@ func TestLiteBinary(t *testing.T) { t.Parallel() dir := t.TempDir() _, err := cli.run("-dir="+dir, "sqlite3", filepath.Join(dir, "sql.db"), "up") - check.HasError(t, err) - check.Contains(t, err.Error(), "goose run: no migration files found") + require.Error(t, err) + require.Contains(t, err.Error(), "goose run: no migration files found") }) t.Run("create_and_fix", func(t *testing.T) { t.Parallel() @@ -112,8 +112,8 @@ func TestLiteBinary(t *testing.T) { createEmptyFile(t, dir, "20230826163141_charlie.sql") createEmptyFile(t, dir, "20230826163151_delta.go") total, err := os.ReadDir(dir) - check.NoError(t, err) - check.Number(t, len(total), 4) + require.NoError(t, err) + require.Equal(t, len(total), 4) migrationFiles := []struct { name string fileType string @@ -128,22 +128,22 @@ func TestLiteBinary(t *testing.T) { args = append(args, f.fileType) } out, err := cli.run(args...) - check.NoError(t, err) - check.Contains(t, out, "Created new file") + require.NoError(t, err) + require.Contains(t, out, "Created new file") // ensure different timestamps, granularity is 1 second if i < len(migrationFiles)-1 { time.Sleep(1100 * time.Millisecond) } } total, err = os.ReadDir(dir) - check.NoError(t, err) - check.Number(t, len(total), 7) + require.NoError(t, err) + require.Equal(t, len(total), 7) out, err := cli.run("-dir="+dir, "fix") - check.NoError(t, err) - check.Contains(t, out, "RENAMED") + require.NoError(t, err) + require.Contains(t, out, "RENAMED") files, err := os.ReadDir(dir) - check.NoError(t, err) - check.Number(t, len(files), 7) + require.NoError(t, err) + require.Equal(t, len(files), 7) expected := []string{ "00001_alpha.sql", "00003_bravo.sql", @@ -154,7 +154,7 @@ func TestLiteBinary(t *testing.T) { "00008_golf.go", } for i, f := range files { - check.Equal(t, f.Name(), expected[i]) + require.Equal(t, f.Name(), expected[i]) } }) } @@ -201,7 +201,7 @@ func buildGooseCLI(t *testing.T, lite bool) gooseBinary { func countSQLFiles(t *testing.T, dir string) int { t.Helper() files, err := filepath.Glob(filepath.Join(dir, "*.sql")) - check.NoError(t, err) + require.NoError(t, err) return len(files) } @@ -209,6 +209,6 @@ func createEmptyFile(t *testing.T, dir, name string) { t.Helper() path := filepath.Join(dir, name) f, err := os.Create(path) - check.NoError(t, err) + require.NoError(t, err) defer f.Close() } diff --git a/goose_embed_test.go b/goose_embed_test.go index a09f87992..f5055a7bd 100644 --- a/goose_embed_test.go +++ b/goose_embed_test.go @@ -9,7 +9,7 @@ import ( "testing" "github.com/pressly/goose/v3" - "github.com/pressly/goose/v3/internal/check" + "github.com/stretchr/testify/require" _ "modernc.org/sqlite" ) @@ -20,43 +20,43 @@ func TestEmbeddedMigrations(t *testing.T) { dir := t.TempDir() // not using t.Parallel here to avoid races db, err := sql.Open("sqlite", filepath.Join(dir, "sql_embed.db")) - check.NoError(t, err) + require.NoError(t, err) db.SetMaxOpenConns(1) migrationFiles, err := fs.ReadDir(embedMigrations, "testdata/migrations") - check.NoError(t, err) + require.NoError(t, err) total := len(migrationFiles) // decouple from existing structure fsys, err := fs.Sub(embedMigrations, "testdata/migrations") - check.NoError(t, err) + require.NoError(t, err) goose.SetBaseFS(fsys) t.Cleanup(func() { goose.SetBaseFS(nil) }) - check.NoError(t, goose.SetDialect("sqlite3")) + require.NoError(t, goose.SetDialect("sqlite3")) t.Run("migration_cycle", func(t *testing.T) { err := goose.Up(db, ".") - check.NoError(t, err) + require.NoError(t, err) ver, err := goose.GetDBVersion(db) - check.NoError(t, err) - check.Number(t, ver, total) + require.NoError(t, err) + require.EqualValues(t, ver, total) err = goose.Reset(db, ".") - check.NoError(t, err) + require.NoError(t, err) ver, err = goose.GetDBVersion(db) - check.NoError(t, err) - check.Number(t, ver, 0) + require.NoError(t, err) + require.EqualValues(t, ver, 0) }) t.Run("create_uses_os_fs", func(t *testing.T) { dir := t.TempDir() err := goose.Create(db, dir, "test", "sql") - check.NoError(t, err) + require.NoError(t, err) paths, _ := filepath.Glob(filepath.Join(dir, "*test.sql")) - check.NumberNotZero(t, len(paths)) + require.NotZero(t, len(paths)) err = goose.Fix(dir) - check.NoError(t, err) + require.NoError(t, err) _, err = os.Stat(filepath.Join(dir, "00001_test.sql")) - check.NoError(t, err) + require.NoError(t, err) }) } diff --git a/internal/check/check.go b/internal/check/check.go deleted file mode 100644 index 76dfac7d6..000000000 --- a/internal/check/check.go +++ /dev/null @@ -1,86 +0,0 @@ -package check - -import ( - "errors" - "fmt" - "reflect" - "strings" - "testing" -) - -func NoError(t *testing.T, err error) { - t.Helper() - if err != nil { - t.Fatalf("unexpected error: %v", err) - } -} - -func HasError(t *testing.T, err error) { - t.Helper() - if err == nil { - t.Fatal("expecting an error: got nil") - } -} - -func IsError(t *testing.T, err, target error) { - t.Helper() - if !errors.Is(err, target) { - t.Fatalf("expecting specific error:\ngot: %v\nwant: %v", err, target) - - } -} - -func Number(t *testing.T, got, want interface{}) { - t.Helper() - gotNumber, err := reflectToInt64(got) - if err != nil { - t.Fatal(err) - } - wantNumber, err := reflectToInt64(want) - if err != nil { - t.Fatal(err) - } - if gotNumber != wantNumber { - t.Fatalf("unexpected number value: got:%d want:%d ", gotNumber, wantNumber) - } -} - -func Equal(t *testing.T, got, want interface{}) { - t.Helper() - if !reflect.DeepEqual(got, want) { - t.Fatalf("failed deep equal:\ngot:\t%v\nwant:\t%v\v", got, want) - } -} - -func NumberNotZero(t *testing.T, got interface{}) { - t.Helper() - gotNumber, err := reflectToInt64(got) - if err != nil { - t.Fatal(err) - } - if gotNumber == 0 { - t.Fatalf("unexpected number value: got:%d want non-zero ", gotNumber) - } -} - -func Bool(t *testing.T, got, want bool) { - t.Helper() - if got != want { - t.Fatalf("unexpected boolean value: got:%t want:%t", got, want) - } -} - -func Contains(t *testing.T, got, want string) { - t.Helper() - if !strings.Contains(got, want) { - t.Errorf("failed to find substring:\n%s\n\nin string value:\n%s", got, want) - } -} - -func reflectToInt64(v interface{}) (int64, error) { - switch typ := v.(type) { - case int, int8, int16, int32, int64: - return reflect.ValueOf(typ).Int(), nil - } - return 0, fmt.Errorf("invalid number: must be int64 type: got:%T", v) -} diff --git a/internal/migrationstats/migrationstats_test.go b/internal/migrationstats/migrationstats_test.go index 03830727d..c24c83e06 100644 --- a/internal/migrationstats/migrationstats_test.go +++ b/internal/migrationstats/migrationstats_test.go @@ -6,7 +6,7 @@ import ( "strings" "testing" - "github.com/pressly/goose/v3/internal/check" + "github.com/stretchr/testify/require" ) func TestParsingGoMigrations(t *testing.T) { @@ -31,11 +31,11 @@ func TestParsingGoMigrations(t *testing.T) { for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { g, err := parseGoFile(strings.NewReader(tc.input)) - check.NoError(t, err) - check.Equal(t, g.useTx != nil, true) - check.Bool(t, *g.useTx, tc.wantTx) - check.Equal(t, g.downFuncName, tc.wantDownName) - check.Equal(t, g.upFuncName, tc.wantUpName) + require.NoError(t, err) + require.Equal(t, g.useTx != nil, true) + require.Equal(t, *g.useTx, tc.wantTx) + require.Equal(t, g.downFuncName, tc.wantDownName) + require.Equal(t, g.upFuncName, tc.wantUpName) }) } } @@ -45,15 +45,15 @@ func TestGoMigrationStats(t *testing.T) { base := "../../tests/gomigrations/success/testdata" all, err := os.ReadDir(base) - check.NoError(t, err) - check.Equal(t, len(all), 16) + require.NoError(t, err) + require.Equal(t, len(all), 16) files := make([]string, 0, len(all)) for _, f := range all { files = append(files, filepath.Join(base, f.Name())) } stats, err := GatherStats(NewFileWalker(files...), false) - check.NoError(t, err) - check.Equal(t, len(stats), 16) + require.NoError(t, err) + require.Equal(t, len(stats), 16) checkGoStats(t, stats[0], "001_up_down.go", 1, 1, 1, true) checkGoStats(t, stats[1], "002_up_only.go", 2, 1, 0, true) checkGoStats(t, stats[2], "003_down_only.go", 3, 0, 1, true) @@ -74,22 +74,22 @@ func TestGoMigrationStats(t *testing.T) { func checkGoStats(t *testing.T, stats *Stats, filename string, version int64, upCount, downCount int, tx bool) { t.Helper() - check.Equal(t, filepath.Base(stats.FileName), filename) - check.Equal(t, stats.Version, version) - check.Equal(t, stats.UpCount, upCount) - check.Equal(t, stats.DownCount, downCount) - check.Equal(t, stats.Tx, tx) + require.Equal(t, filepath.Base(stats.FileName), filename) + require.Equal(t, stats.Version, version) + require.Equal(t, stats.UpCount, upCount) + require.Equal(t, stats.DownCount, downCount) + require.Equal(t, stats.Tx, tx) } func TestParsingGoMigrationsError(t *testing.T) { t.Parallel() _, err := parseGoFile(strings.NewReader(emptyInit)) - check.HasError(t, err) - check.Contains(t, err.Error(), "no registered goose functions") + require.Error(t, err) + require.Contains(t, err.Error(), "no registered goose functions") _, err = parseGoFile(strings.NewReader(wrongName)) - check.HasError(t, err) - check.Contains(t, err.Error(), "AddMigration, AddMigrationNoTx, AddMigrationContext, AddMigrationNoTxContext") + require.Error(t, err) + require.Contains(t, err.Error(), "AddMigration, AddMigrationNoTx, AddMigrationContext, AddMigrationNoTxContext") } var ( diff --git a/internal/sqlparser/parse_test.go b/internal/sqlparser/parse_test.go index 632bbe13b..0ea1ae406 100644 --- a/internal/sqlparser/parse_test.go +++ b/internal/sqlparser/parse_test.go @@ -1,13 +1,12 @@ package sqlparser_test import ( - "errors" "os" "testing" "testing/fstest" - "github.com/pressly/goose/v3/internal/check" "github.com/pressly/goose/v3/internal/sqlparser" + "github.com/stretchr/testify/require" ) func TestParseAllFromFS(t *testing.T) { @@ -15,17 +14,17 @@ func TestParseAllFromFS(t *testing.T) { t.Run("file_not_exist", func(t *testing.T) { mapFS := fstest.MapFS{} _, err := sqlparser.ParseAllFromFS(mapFS, "001_foo.sql", false) - check.HasError(t, err) - check.Bool(t, errors.Is(err, os.ErrNotExist), true) + require.Error(t, err) + require.ErrorIs(t, err, os.ErrNotExist) }) t.Run("empty_file", func(t *testing.T) { mapFS := fstest.MapFS{ "001_foo.sql": &fstest.MapFile{}, } _, err := sqlparser.ParseAllFromFS(mapFS, "001_foo.sql", false) - check.HasError(t, err) - check.Contains(t, err.Error(), "failed to parse migration") - check.Contains(t, err.Error(), "must start with '-- +goose Up' annotation") + require.Error(t, err) + require.Contains(t, err.Error(), "failed to parse migration") + require.Contains(t, err.Error(), "must start with '-- +goose Up' annotation") }) t.Run("all_statements", func(t *testing.T) { mapFS := fstest.MapFS{ @@ -53,26 +52,26 @@ DROP TABLE foo; `), } parsedSQL, err := sqlparser.ParseAllFromFS(mapFS, "001_foo.sql", false) - check.NoError(t, err) + require.NoError(t, err) assertParsedSQL(t, parsedSQL, true, 0, 0) parsedSQL, err = sqlparser.ParseAllFromFS(mapFS, "002_bar.sql", false) - check.NoError(t, err) + require.NoError(t, err) assertParsedSQL(t, parsedSQL, true, 0, 0) parsedSQL, err = sqlparser.ParseAllFromFS(mapFS, "003_baz.sql", false) - check.NoError(t, err) + require.NoError(t, err) assertParsedSQL(t, parsedSQL, true, 2, 1) parsedSQL, err = sqlparser.ParseAllFromFS(mapFS, "004_qux.sql", false) - check.NoError(t, err) + require.NoError(t, err) assertParsedSQL(t, parsedSQL, false, 1, 1) }) } func assertParsedSQL(t *testing.T, got *sqlparser.ParsedSQL, useTx bool, up, down int) { t.Helper() - check.Bool(t, got != nil, true) - check.Equal(t, len(got.Up), up) - check.Equal(t, len(got.Down), down) - check.Equal(t, got.UseTx, useTx) + require.NotNil(t, got) + require.Equal(t, len(got.Up), up) + require.Equal(t, len(got.Down), down) + require.Equal(t, got.UseTx, useTx) } func newFile(data string) *fstest.MapFile { diff --git a/internal/sqlparser/parser_test.go b/internal/sqlparser/parser_test.go index 155717f5e..a8f24f2ac 100644 --- a/internal/sqlparser/parser_test.go +++ b/internal/sqlparser/parser_test.go @@ -8,7 +8,7 @@ import ( "strings" "testing" - "github.com/pressly/goose/v3/internal/check" + "github.com/stretchr/testify/require" ) var ( @@ -91,14 +91,14 @@ func TestInvalidUp(t *testing.T) { testdataDir := filepath.Join("testdata", "invalid", "up") entries, err := os.ReadDir(testdataDir) - check.NoError(t, err) - check.NumberNotZero(t, len(entries)) + require.NoError(t, err) + require.NotZero(t, len(entries)) for _, entry := range entries { by, err := os.ReadFile(filepath.Join(testdataDir, entry.Name())) - check.NoError(t, err) + require.NoError(t, err) _, _, err = ParseSQLMigration(strings.NewReader(string(by)), DirectionUp, false) - check.HasError(t, err) + require.Error(t, err) } } @@ -410,11 +410,11 @@ func testValid(t *testing.T, dir string, count int, direction Direction) { t.Helper() f, err := os.Open(filepath.Join(dir, "input.sql")) - check.NoError(t, err) + require.NoError(t, err) t.Cleanup(func() { f.Close() }) statements, _, err := ParseSQLMigration(f, direction, debug) - check.NoError(t, err) - check.Number(t, len(statements), count) + require.NoError(t, err) + require.Equal(t, len(statements), count) compareStatements(t, dir, statements, direction) } @@ -422,7 +422,7 @@ func compareStatements(t *testing.T, dir string, statements []string, direction t.Helper() files, err := filepath.Glob(filepath.Join(dir, fmt.Sprintf("*.%s.golden.sql", direction))) - check.NoError(t, err) + require.NoError(t, err) if len(statements) != len(files) { t.Fatalf("mismatch between parsed statements (%d) and golden files (%d), did you check in NN.{up|down}.golden.sql file in %q?", len(statements), len(files), dir) } @@ -433,12 +433,12 @@ func compareStatements(t *testing.T, dir string, statements []string, direction t.Fatal(`failed to cut on file delimiter ".", must be of the format NN.{up|down}.golden.sql`) } index, err := strconv.Atoi(before) - check.NoError(t, err) + require.NoError(t, err) index-- goldenFilePath := filepath.Join(dir, goldenFile) by, err := os.ReadFile(goldenFilePath) - check.NoError(t, err) + require.NoError(t, err) got, want := statements[index], string(by) @@ -452,7 +452,7 @@ func compareStatements(t *testing.T, dir string, statements []string, direction filepath.Join("internal", "sqlparser", goldenFilePath), ) err := os.WriteFile(goldenFilePath+".FAIL", []byte(got), 0644) - check.NoError(t, err) + require.NoError(t, err) } } } @@ -504,8 +504,8 @@ CREATE TABLE post ( ); ` _, _, err := ParseSQLMigration(strings.NewReader(s), DirectionUp, debug) - check.HasError(t, err) - check.Contains(t, err.Error(), "variable substitution failed: $SOME_UNSET_VAR: required env var not set:") + require.Error(t, err) + require.Contains(t, err.Error(), "variable substitution failed: $SOME_UNSET_VAR: required env var not set:") } func Test_extractAnnotation(t *testing.T) { @@ -513,76 +513,79 @@ func Test_extractAnnotation(t *testing.T) { name string input string want annotation - wantErr func(t *testing.T, err error) + wantErr bool }{ { name: "Up", input: "-- +goose Up", want: annotationUp, - wantErr: check.NoError, + wantErr: false, }, { name: "Down", input: "-- +goose Down", want: annotationDown, - wantErr: check.NoError, + wantErr: false, }, { name: "StmtBegin", input: "-- +goose StatementBegin", want: annotationStatementBegin, - wantErr: check.NoError, + wantErr: false, }, { name: "NoTransact", input: "-- +goose NO TRANSACTION", want: annotationNoTransaction, - wantErr: check.NoError, + wantErr: false, }, { name: "Unsupported", input: "-- +goose unsupported", want: "", - wantErr: check.HasError, + wantErr: true, }, { name: "Empty", input: "-- +goose", want: "", - wantErr: check.HasError, + wantErr: true, }, { name: "statement with spaces and Uppercase", input: "-- +goose UP ", want: annotationUp, - wantErr: check.NoError, + wantErr: false, }, { name: "statement with leading whitespace - error", input: " -- +goose UP ", want: "", - wantErr: check.HasError, + wantErr: true, }, { name: "statement with leading \t - error", input: "\t-- +goose UP ", want: "", - wantErr: check.HasError, + wantErr: true, }, { name: "multiple +goose annotations - error", input: "-- +goose +goose Up", want: "", - wantErr: check.HasError, + wantErr: true, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { got, err := extractAnnotation(tt.input) - tt.wantErr(t, err) - - check.Equal(t, got, tt.want) + if tt.wantErr { + require.Error(t, err) + } else { + require.NoError(t, err) + } + require.Equal(t, got, tt.want) }) } } diff --git a/internal/testing/integration/postgres_locking_test.go b/internal/testing/integration/postgres_locking_test.go index 76413ba08..992496f88 100644 --- a/internal/testing/integration/postgres_locking_test.go +++ b/internal/testing/integration/postgres_locking_test.go @@ -15,7 +15,6 @@ import ( "time" "github.com/pressly/goose/v3" - "github.com/pressly/goose/v3/internal/check" "github.com/pressly/goose/v3/internal/testing/testdb" "github.com/pressly/goose/v3/lock" "github.com/stretchr/testify/require" @@ -433,23 +432,23 @@ func TestPostgresPending(t *testing.T) { for i := 0; i < workers; i++ { g.Go(func() error { p, err := goose.NewProvider(goose.DialectPostgres, db, os.DirFS(testDir)) - check.NoError(t, err) + require.NoError(t, err) hasPending, err := p.HasPending(context.Background()) - check.NoError(t, err) + require.NoError(t, err) boolCh <- hasPending current, target, err := p.GetVersions(context.Background()) - check.NoError(t, err) - check.Number(t, current, int64(wantCurrent)) - check.Number(t, target, int64(wantTarget)) + require.NoError(t, err) + require.Equal(t, current, int64(wantCurrent)) + require.Equal(t, target, int64(wantTarget)) return nil }) } - check.NoError(t, g.Wait()) + require.NoError(t, g.Wait()) close(boolCh) // expect all values to be true for hasPending := range boolCh { - check.Bool(t, hasPending, want) + require.Equal(t, hasPending, want) } } t.Run("concurrent_has_pending", func(t *testing.T) { @@ -458,9 +457,9 @@ func TestPostgresPending(t *testing.T) { // apply all migrations p, err := goose.NewProvider(goose.DialectPostgres, db, os.DirFS("testdata/migrations/postgres")) - check.NoError(t, err) + require.NoError(t, err) _, err = p.Up(context.Background()) - check.NoError(t, err) + require.NoError(t, err) t.Run("concurrent_no_pending", func(t *testing.T) { run(t, false, len(files), len(files)) @@ -480,10 +479,10 @@ SELECT pg_sleep_for('4 seconds'); sessionLocker, err := lock.NewPostgresSessionLocker(lock.WithLockTimeout(1, 10), lock.WithLockID(lockID)) // Timeout 5min. Try every 1s up to 10 times. require.NoError(t, err) newProvider, err := goose.NewProvider(goose.DialectPostgres, db, fsys, goose.WithSessionLocker(sessionLocker)) - check.NoError(t, err) - check.Number(t, len(newProvider.ListSources()), 1) + require.NoError(t, err) + require.Equal(t, len(newProvider.ListSources()), 1) oldProvider := p - check.Number(t, len(oldProvider.ListSources()), len(files)) + require.Equal(t, len(oldProvider.ListSources()), len(files)) var g errgroup.Group g.Go(func() error { @@ -491,13 +490,13 @@ SELECT pg_sleep_for('4 seconds'); if err != nil { return err } - check.Bool(t, hasPending, true) + require.True(t, hasPending) current, target, err := newProvider.GetVersions(context.Background()) if err != nil { return err } - check.Number(t, current, lastVersion) - check.Number(t, target, lastVersion+1) + require.EqualValues(t, current, lastVersion) + require.EqualValues(t, target, lastVersion+1) return nil }) g.Go(func() error { @@ -505,16 +504,16 @@ SELECT pg_sleep_for('4 seconds'); if err != nil { return err } - check.Bool(t, hasPending, false) + require.False(t, hasPending) current, target, err := oldProvider.GetVersions(context.Background()) if err != nil { return err } - check.Number(t, current, lastVersion) - check.Number(t, target, lastVersion) + require.EqualValues(t, current, lastVersion) + require.EqualValues(t, target, lastVersion) return nil }) - check.NoError(t, g.Wait()) + require.NoError(t, g.Wait()) // A new provider is running in the background with a session lock to simulate a long running // migration. If older instances come up, they should not have any pending migrations and not be @@ -526,29 +525,29 @@ SELECT pg_sleep_for('4 seconds'); }) time.Sleep(1 * time.Second) isLocked, err := existsPgLock(context.Background(), db, lockID) - check.NoError(t, err) - check.Bool(t, isLocked, true) + require.NoError(t, err) + require.True(t, isLocked) hasPending, err := oldProvider.HasPending(context.Background()) - check.NoError(t, err) - check.Bool(t, hasPending, false) + require.NoError(t, err) + require.False(t, hasPending) current, target, err := oldProvider.GetVersions(context.Background()) - check.NoError(t, err) - check.Number(t, current, lastVersion) - check.Number(t, target, lastVersion) + require.NoError(t, err) + require.EqualValues(t, current, lastVersion) + require.EqualValues(t, target, lastVersion) // Wait for the long running migration to finish - check.NoError(t, g.Wait()) + require.NoError(t, g.Wait()) // Check that the new migration was applied hasPending, err = newProvider.HasPending(context.Background()) - check.NoError(t, err) - check.Bool(t, hasPending, false) + require.NoError(t, err) + require.False(t, hasPending) current, target, err = newProvider.GetVersions(context.Background()) - check.NoError(t, err) - check.Number(t, current, lastVersion+1) - check.Number(t, target, lastVersion+1) + require.NoError(t, err) + require.EqualValues(t, current, lastVersion+1) + require.EqualValues(t, target, lastVersion+1) // The max version should be the new migration currentVersion, err := newProvider.GetDBVersion(context.Background()) - check.NoError(t, err) - check.Number(t, currentVersion, lastVersion+1) + require.NoError(t, err) + require.EqualValues(t, currentVersion, lastVersion+1) } func existsPgLock(ctx context.Context, db *sql.DB, lockID int64) (bool, error) { diff --git a/migrate_test.go b/migrate_test.go index 9158829ea..b73b5187b 100644 --- a/migrate_test.go +++ b/migrate_test.go @@ -7,7 +7,7 @@ import ( "path/filepath" "testing" - "github.com/pressly/goose/v3/internal/check" + "github.com/stretchr/testify/require" ) func TestMigrationSort(t *testing.T) { @@ -68,10 +68,10 @@ func TestCollectMigrations(t *testing.T) { t.Run("no_migration_files_found", func(t *testing.T) { tmp := t.TempDir() err := os.MkdirAll(filepath.Join(tmp, "migrations-test"), 0755) - check.NoError(t, err) + require.NoError(t, err) _, err = collectMigrationsFS(os.DirFS(tmp), "migrations-test", 0, math.MaxInt64, nil) - check.HasError(t, err) - check.Contains(t, err.Error(), "no migration files found") + require.Error(t, err) + require.Contains(t, err.Error(), "no migration files found") }) t.Run("filesystem_registered_with_single_dirpath", func(t *testing.T) { t.Cleanup(func() { clearMap(registeredGoMigrations) }) @@ -79,26 +79,26 @@ func TestCollectMigrations(t *testing.T) { file3, file4 := "19081_a.go", "19082_b.go" AddNamedMigrationContext(file1, nil, nil) AddNamedMigrationContext(file2, nil, nil) - check.Number(t, len(registeredGoMigrations), 2) + require.Equal(t, len(registeredGoMigrations), 2) tmp := t.TempDir() dir := filepath.Join(tmp, "migrations", "dir1") err := os.MkdirAll(dir, 0755) - check.NoError(t, err) + require.NoError(t, err) createEmptyFile(t, dir, file1) createEmptyFile(t, dir, file2) createEmptyFile(t, dir, file3) createEmptyFile(t, dir, file4) fsys := os.DirFS(tmp) files, err := fs.ReadDir(fsys, "migrations/dir1") - check.NoError(t, err) - check.Number(t, len(files), 4) + require.NoError(t, err) + require.Equal(t, len(files), 4) all, err := collectMigrationsFS(fsys, "migrations/dir1", 0, math.MaxInt64, registeredGoMigrations) - check.NoError(t, err) - check.Number(t, len(all), 4) - check.Number(t, all[0].Version, 9081) - check.Number(t, all[1].Version, 9082) - check.Number(t, all[2].Version, 19081) - check.Number(t, all[3].Version, 19082) + require.NoError(t, err) + require.Equal(t, len(all), 4) + require.EqualValues(t, all[0].Version, 9081) + require.EqualValues(t, all[1].Version, 9082) + require.EqualValues(t, all[2].Version, 19081) + require.EqualValues(t, all[3].Version, 19082) }) t.Run("filesystem_registered_with_multiple_dirpath", func(t *testing.T) { t.Cleanup(func() { clearMap(registeredGoMigrations) }) @@ -106,14 +106,14 @@ func TestCollectMigrations(t *testing.T) { AddNamedMigrationContext(file1, nil, nil) AddNamedMigrationContext(file2, nil, nil) AddNamedMigrationContext(file3, nil, nil) - check.Number(t, len(registeredGoMigrations), 3) + require.Equal(t, len(registeredGoMigrations), 3) tmp := t.TempDir() dir1 := filepath.Join(tmp, "migrations", "dir1") dir2 := filepath.Join(tmp, "migrations", "dir2") err := os.MkdirAll(dir1, 0755) - check.NoError(t, err) + require.NoError(t, err) err = os.MkdirAll(dir2, 0755) - check.NoError(t, err) + require.NoError(t, err) createEmptyFile(t, dir1, file1) createEmptyFile(t, dir1, file2) createEmptyFile(t, dir2, file3) @@ -122,33 +122,33 @@ func TestCollectMigrations(t *testing.T) { // even though 3 Go migrations have been registered. { all, err := collectMigrationsFS(fsys, "migrations/dir1", 0, math.MaxInt64, registeredGoMigrations) - check.NoError(t, err) - check.Number(t, len(all), 2) - check.Number(t, all[0].Version, 1) - check.Number(t, all[1].Version, 2) + require.NoError(t, err) + require.Equal(t, len(all), 2) + require.EqualValues(t, all[0].Version, 1) + require.EqualValues(t, all[1].Version, 2) } // Validate if dirpath 2 is specified we only get the one Go migration in migrations/dir2 folder // even though 3 Go migrations have been registered. { all, err := collectMigrationsFS(fsys, "migrations/dir2", 0, math.MaxInt64, registeredGoMigrations) - check.NoError(t, err) - check.Number(t, len(all), 1) - check.Number(t, all[0].Version, 1111) + require.NoError(t, err) + require.Equal(t, len(all), 1) + require.EqualValues(t, all[0].Version, 1111) } }) t.Run("empty_filesystem_registered_manually", func(t *testing.T) { t.Cleanup(func() { clearMap(registeredGoMigrations) }) AddNamedMigrationContext("00101_a.go", nil, nil) AddNamedMigrationContext("00102_b.go", nil, nil) - check.Number(t, len(registeredGoMigrations), 2) + require.Equal(t, len(registeredGoMigrations), 2) tmp := t.TempDir() err := os.MkdirAll(filepath.Join(tmp, "migrations"), 0755) - check.NoError(t, err) + require.NoError(t, err) all, err := collectMigrationsFS(os.DirFS(tmp), "migrations", 0, math.MaxInt64, registeredGoMigrations) - check.NoError(t, err) - check.Number(t, len(all), 2) - check.Number(t, all[0].Version, 101) - check.Number(t, all[1].Version, 102) + require.NoError(t, err) + require.Equal(t, len(all), 2) + require.EqualValues(t, all[0].Version, 101) + require.EqualValues(t, all[1].Version, 102) }) t.Run("unregistered_go_migrations", func(t *testing.T) { t.Cleanup(func() { clearMap(registeredGoMigrations) }) @@ -157,67 +157,67 @@ func TestCollectMigrations(t *testing.T) { // valid looking file2 Go migration AddNamedMigrationContext(file1, nil, nil) AddNamedMigrationContext(file3, nil, nil) - check.Number(t, len(registeredGoMigrations), 2) + require.Equal(t, len(registeredGoMigrations), 2) tmp := t.TempDir() dir1 := filepath.Join(tmp, "migrations", "dir1") err := os.MkdirAll(dir1, 0755) - check.NoError(t, err) + require.NoError(t, err) // Include the valid file2 with file1, file3. But remember, it has NOT been // registered. createEmptyFile(t, dir1, file1) createEmptyFile(t, dir1, file2) createEmptyFile(t, dir1, file3) all, err := collectMigrationsFS(os.DirFS(tmp), "migrations/dir1", 0, math.MaxInt64, registeredGoMigrations) - check.NoError(t, err) - check.Number(t, len(all), 3) - check.Number(t, all[0].Version, 1) - check.Bool(t, all[0].Registered, true) - check.Number(t, all[1].Version, 998) + require.NoError(t, err) + require.Equal(t, len(all), 3) + require.EqualValues(t, all[0].Version, 1) + require.True(t, all[0].Registered) + require.EqualValues(t, all[1].Version, 998) // This migrations is marked unregistered and will lazily raise an error if/when this // migration is run - check.Bool(t, all[1].Registered, false) - check.Number(t, all[2].Version, 999) - check.Bool(t, all[2].Registered, true) + require.False(t, all[1].Registered) + require.EqualValues(t, all[2].Version, 999) + require.True(t, all[2].Registered) }) t.Run("with_skipped_go_files", func(t *testing.T) { t.Cleanup(func() { clearMap(registeredGoMigrations) }) file1, file2, file3, file4 := "00001_a.go", "00002_b.sql", "00999_c_test.go", "embed.go" AddNamedMigrationContext(file1, nil, nil) - check.Number(t, len(registeredGoMigrations), 1) + require.Equal(t, len(registeredGoMigrations), 1) tmp := t.TempDir() dir1 := filepath.Join(tmp, "migrations", "dir1") err := os.MkdirAll(dir1, 0755) - check.NoError(t, err) + require.NoError(t, err) createEmptyFile(t, dir1, file1) createEmptyFile(t, dir1, file2) createEmptyFile(t, dir1, file3) createEmptyFile(t, dir1, file4) all, err := collectMigrationsFS(os.DirFS(tmp), "migrations/dir1", 0, math.MaxInt64, registeredGoMigrations) - check.NoError(t, err) - check.Number(t, len(all), 2) - check.Number(t, all[0].Version, 1) - check.Bool(t, all[0].Registered, true) - check.Number(t, all[1].Version, 2) - check.Bool(t, all[1].Registered, false) + require.NoError(t, err) + require.Equal(t, len(all), 2) + require.EqualValues(t, all[0].Version, 1) + require.True(t, all[0].Registered) + require.EqualValues(t, all[1].Version, 2) + require.False(t, all[1].Registered) }) t.Run("current_and_target", func(t *testing.T) { t.Cleanup(func() { clearMap(registeredGoMigrations) }) file1, file2, file3 := "01001_a.go", "01002_b.sql", "01003_c.go" AddNamedMigrationContext(file1, nil, nil) AddNamedMigrationContext(file3, nil, nil) - check.Number(t, len(registeredGoMigrations), 2) + require.Equal(t, len(registeredGoMigrations), 2) tmp := t.TempDir() dir1 := filepath.Join(tmp, "migrations", "dir1") err := os.MkdirAll(dir1, 0755) - check.NoError(t, err) + require.NoError(t, err) createEmptyFile(t, dir1, file1) createEmptyFile(t, dir1, file2) createEmptyFile(t, dir1, file3) all, err := collectMigrationsFS(os.DirFS(tmp), "migrations/dir1", 1001, 1003, registeredGoMigrations) - check.NoError(t, err) - check.Number(t, len(all), 2) - check.Number(t, all[0].Version, 1002) - check.Number(t, all[1].Version, 1003) + require.NoError(t, err) + require.Equal(t, len(all), 2) + require.EqualValues(t, all[0].Version, 1002) + require.EqualValues(t, all[1].Version, 1003) }) } @@ -254,7 +254,7 @@ func TestVersionFilter(t *testing.T) { func createEmptyFile(t *testing.T, dir, name string) { path := filepath.Join(dir, name) f, err := os.Create(path) - check.NoError(t, err) + require.NoError(t, err) defer f.Close() } diff --git a/provider_collect_test.go b/provider_collect_test.go index 32404f0b7..17b574444 100644 --- a/provider_collect_test.go +++ b/provider_collect_test.go @@ -5,31 +5,31 @@ import ( "testing" "testing/fstest" - "github.com/pressly/goose/v3/internal/check" + "github.com/stretchr/testify/require" ) func TestCollectFileSources(t *testing.T) { t.Parallel() t.Run("nil_fsys", func(t *testing.T) { sources, err := collectFilesystemSources(nil, false, nil, nil) - check.NoError(t, err) - check.Bool(t, sources != nil, true) - check.Number(t, len(sources.goSources), 0) - check.Number(t, len(sources.sqlSources), 0) + require.NoError(t, err) + require.True(t, sources != nil) + require.Equal(t, len(sources.goSources), 0) + require.Equal(t, len(sources.sqlSources), 0) }) t.Run("noop_fsys", func(t *testing.T) { sources, err := collectFilesystemSources(noopFS{}, false, nil, nil) - check.NoError(t, err) - check.Bool(t, sources != nil, true) - check.Number(t, len(sources.goSources), 0) - check.Number(t, len(sources.sqlSources), 0) + require.NoError(t, err) + require.True(t, sources != nil) + require.Equal(t, len(sources.goSources), 0) + require.Equal(t, len(sources.sqlSources), 0) }) t.Run("empty_fsys", func(t *testing.T) { sources, err := collectFilesystemSources(fstest.MapFS{}, false, nil, nil) - check.NoError(t, err) - check.Number(t, len(sources.goSources), 0) - check.Number(t, len(sources.sqlSources), 0) - check.Bool(t, sources != nil, true) + require.NoError(t, err) + require.Equal(t, len(sources.goSources), 0) + require.Equal(t, len(sources.sqlSources), 0) + require.True(t, sources != nil) }) t.Run("incorrect_fsys", func(t *testing.T) { mapFS := fstest.MapFS{ @@ -37,21 +37,21 @@ func TestCollectFileSources(t *testing.T) { } // strict disable - should not error sources, err := collectFilesystemSources(mapFS, false, nil, nil) - check.NoError(t, err) - check.Number(t, len(sources.goSources), 0) - check.Number(t, len(sources.sqlSources), 0) + require.NoError(t, err) + require.Equal(t, len(sources.goSources), 0) + require.Equal(t, len(sources.sqlSources), 0) // strict enabled - should error _, err = collectFilesystemSources(mapFS, true, nil, nil) - check.HasError(t, err) - check.Contains(t, err.Error(), "migration version must be greater than zero") + require.Error(t, err) + require.Contains(t, err.Error(), "migration version must be greater than zero") }) t.Run("collect", func(t *testing.T) { fsys, err := fs.Sub(newSQLOnlyFS(), "migrations") - check.NoError(t, err) + require.NoError(t, err) sources, err := collectFilesystemSources(fsys, false, nil, nil) - check.NoError(t, err) - check.Number(t, len(sources.sqlSources), 4) - check.Number(t, len(sources.goSources), 0) + require.NoError(t, err) + require.Equal(t, len(sources.sqlSources), 4) + require.Equal(t, len(sources.goSources), 0) expected := fileSources{ sqlSources: []Source{ newSource(TypeSQL, "00001_foo.sql", 1), @@ -61,12 +61,12 @@ func TestCollectFileSources(t *testing.T) { }, } for i := 0; i < len(sources.sqlSources); i++ { - check.Equal(t, sources.sqlSources[i], expected.sqlSources[i]) + require.Equal(t, sources.sqlSources[i], expected.sqlSources[i]) } }) t.Run("excludes", func(t *testing.T) { fsys, err := fs.Sub(newSQLOnlyFS(), "migrations") - check.NoError(t, err) + require.NoError(t, err) sources, err := collectFilesystemSources( fsys, false, @@ -77,9 +77,9 @@ func TestCollectFileSources(t *testing.T) { }, nil, ) - check.NoError(t, err) - check.Number(t, len(sources.sqlSources), 2) - check.Number(t, len(sources.goSources), 0) + require.NoError(t, err) + require.Equal(t, len(sources.sqlSources), 2) + require.Equal(t, len(sources.goSources), 0) expected := fileSources{ sqlSources: []Source{ newSource(TypeSQL, "00001_foo.sql", 1), @@ -87,7 +87,7 @@ func TestCollectFileSources(t *testing.T) { }, } for i := 0; i < len(sources.sqlSources); i++ { - check.Equal(t, sources.sqlSources[i], expected.sqlSources[i]) + require.Equal(t, sources.sqlSources[i], expected.sqlSources[i]) } }) t.Run("strict", func(t *testing.T) { @@ -95,10 +95,10 @@ func TestCollectFileSources(t *testing.T) { // Add a file with no version number mapFS["migrations/not_valid.sql"] = &fstest.MapFile{Data: []byte("invalid")} fsys, err := fs.Sub(mapFS, "migrations") - check.NoError(t, err) + require.NoError(t, err) _, err = collectFilesystemSources(fsys, true, nil, nil) - check.HasError(t, err) - check.Contains(t, err.Error(), `failed to parse numeric component from "not_valid.sql"`) + require.Error(t, err) + require.Contains(t, err.Error(), `failed to parse numeric component from "not_valid.sql"`) }) t.Run("skip_go_test_files", func(t *testing.T) { mapFS := fstest.MapFS{ @@ -109,9 +109,9 @@ func TestCollectFileSources(t *testing.T) { "5_foo_test.go": {Data: []byte(`package goose_test`)}, } sources, err := collectFilesystemSources(mapFS, false, nil, nil) - check.NoError(t, err) - check.Number(t, len(sources.sqlSources), 4) - check.Number(t, len(sources.goSources), 0) + require.NoError(t, err) + require.Equal(t, len(sources.sqlSources), 4) + require.Equal(t, len(sources.goSources), 0) }) t.Run("skip_random_files", func(t *testing.T) { mapFS := fstest.MapFS{ @@ -124,18 +124,18 @@ func TestCollectFileSources(t *testing.T) { "some/other/dir/2_foo.sql": {Data: []byte(`SELECT 1;`)}, } sources, err := collectFilesystemSources(mapFS, false, nil, nil) - check.NoError(t, err) - check.Number(t, len(sources.sqlSources), 2) - check.Number(t, len(sources.goSources), 1) + require.NoError(t, err) + require.Equal(t, len(sources.sqlSources), 2) + require.Equal(t, len(sources.goSources), 1) // 1 - check.Equal(t, sources.sqlSources[0].Path, "1_foo.sql") - check.Equal(t, sources.sqlSources[0].Version, int64(1)) + require.Equal(t, sources.sqlSources[0].Path, "1_foo.sql") + require.Equal(t, sources.sqlSources[0].Version, int64(1)) // 2 - check.Equal(t, sources.sqlSources[1].Path, "5_qux.sql") - check.Equal(t, sources.sqlSources[1].Version, int64(5)) + require.Equal(t, sources.sqlSources[1].Path, "5_qux.sql") + require.Equal(t, sources.sqlSources[1].Version, int64(5)) // 3 - check.Equal(t, sources.goSources[0].Path, "4_something.go") - check.Equal(t, sources.goSources[0].Version, int64(4)) + require.Equal(t, sources.goSources[0].Path, "4_something.go") + require.Equal(t, sources.goSources[0].Version, int64(4)) }) t.Run("duplicate_versions", func(t *testing.T) { mapFS := fstest.MapFS{ @@ -143,8 +143,8 @@ func TestCollectFileSources(t *testing.T) { "01_bar.sql": sqlMapFile, } _, err := collectFilesystemSources(mapFS, false, nil, nil) - check.HasError(t, err) - check.Contains(t, err.Error(), "found duplicate migration version 1") + require.Error(t, err) + require.Contains(t, err.Error(), "found duplicate migration version 1") }) t.Run("dirpath", func(t *testing.T) { mapFS := fstest.MapFS{ @@ -157,13 +157,13 @@ func TestCollectFileSources(t *testing.T) { assertDirpath := func(dirpath string, sqlSources []Source) { t.Helper() f, err := fs.Sub(mapFS, dirpath) - check.NoError(t, err) + require.NoError(t, err) got, err := collectFilesystemSources(f, false, nil, nil) - check.NoError(t, err) - check.Number(t, len(got.sqlSources), len(sqlSources)) - check.Number(t, len(got.goSources), 0) + require.NoError(t, err) + require.Equal(t, len(got.sqlSources), len(sqlSources)) + require.Equal(t, len(got.goSources), 0) for i := 0; i < len(got.sqlSources); i++ { - check.Equal(t, got.sqlSources[i], sqlSources[i]) + require.Equal(t, got.sqlSources[i], sqlSources[i]) } } assertDirpath(".", []Source{ @@ -193,35 +193,35 @@ func TestMerge(t *testing.T) { "migrations/00003_baz.go": {Data: []byte(`package migrations`)}, } fsys, err := fs.Sub(mapFS, "migrations") - check.NoError(t, err) + require.NoError(t, err) sources, err := collectFilesystemSources(fsys, false, nil, nil) - check.NoError(t, err) - check.Equal(t, len(sources.sqlSources), 1) - check.Equal(t, len(sources.goSources), 2) + require.NoError(t, err) + require.Equal(t, len(sources.sqlSources), 1) + require.Equal(t, len(sources.goSources), 2) t.Run("valid", func(t *testing.T) { registered := map[int64]*Migration{ 2: NewGoMigration(2, nil, nil), 3: NewGoMigration(3, nil, nil), } migrations, err := merge(sources, registered) - check.NoError(t, err) - check.Number(t, len(migrations), 3) + require.NoError(t, err) + require.Equal(t, len(migrations), 3) assertMigration(t, migrations[0], newSource(TypeSQL, "00001_foo.sql", 1)) assertMigration(t, migrations[1], newSource(TypeGo, "00002_bar.go", 2)) assertMigration(t, migrations[2], newSource(TypeGo, "00003_baz.go", 3)) }) t.Run("unregistered_all", func(t *testing.T) { _, err := merge(sources, nil) - check.HasError(t, err) - check.Contains(t, err.Error(), "error: detected 2 unregistered Go files:") - check.Contains(t, err.Error(), "00002_bar.go") - check.Contains(t, err.Error(), "00003_baz.go") + require.Error(t, err) + require.Contains(t, err.Error(), "error: detected 2 unregistered Go files:") + require.Contains(t, err.Error(), "00002_bar.go") + require.Contains(t, err.Error(), "00003_baz.go") }) t.Run("unregistered_some", func(t *testing.T) { _, err := merge(sources, map[int64]*Migration{2: NewGoMigration(2, nil, nil)}) - check.HasError(t, err) - check.Contains(t, err.Error(), "error: detected 1 unregistered Go file") - check.Contains(t, err.Error(), "00003_baz.go") + require.Error(t, err) + require.Contains(t, err.Error(), "error: detected 1 unregistered Go file") + require.Contains(t, err.Error(), "00003_baz.go") }) t.Run("duplicate_sql", func(t *testing.T) { _, err := merge(sources, map[int64]*Migration{ @@ -229,8 +229,8 @@ func TestMerge(t *testing.T) { 2: NewGoMigration(2, nil, nil), 3: NewGoMigration(3, nil, nil), }) - check.HasError(t, err) - check.Contains(t, err.Error(), "found duplicate migration version 1") + require.Error(t, err) + require.Contains(t, err.Error(), "found duplicate migration version 1") }) }) t.Run("no_go_files_on_disk", func(t *testing.T) { @@ -241,17 +241,17 @@ func TestMerge(t *testing.T) { "migrations/00005_baz.sql": sqlMapFile, } fsys, err := fs.Sub(mapFS, "migrations") - check.NoError(t, err) + require.NoError(t, err) sources, err := collectFilesystemSources(fsys, false, nil, nil) - check.NoError(t, err) + require.NoError(t, err) t.Run("unregistered_all", func(t *testing.T) { migrations, err := merge(sources, map[int64]*Migration{ 3: NewGoMigration(3, nil, nil), // 4 is missing 6: NewGoMigration(6, nil, nil), }) - check.NoError(t, err) - check.Number(t, len(migrations), 5) + require.NoError(t, err) + require.Equal(t, len(migrations), 5) assertMigration(t, migrations[0], newSource(TypeSQL, "00001_foo.sql", 1)) assertMigration(t, migrations[1], newSource(TypeSQL, "00002_bar.sql", 2)) assertMigration(t, migrations[2], newSource(TypeGo, "", 3)) @@ -265,9 +265,9 @@ func TestMerge(t *testing.T) { "migrations/00002_bar.go": &fstest.MapFile{Data: []byte(`package migrations`)}, } fsys, err := fs.Sub(mapFS, "migrations") - check.NoError(t, err) + require.NoError(t, err) sources, err := collectFilesystemSources(fsys, false, nil, nil) - check.NoError(t, err) + require.NoError(t, err) t.Run("unregistered_all", func(t *testing.T) { migrations, err := merge(sources, map[int64]*Migration{ // This is the only Go file on disk. @@ -276,8 +276,8 @@ func TestMerge(t *testing.T) { 3: NewGoMigration(3, nil, nil), 6: NewGoMigration(6, nil, nil), }) - check.NoError(t, err) - check.Number(t, len(migrations), 4) + require.NoError(t, err) + require.Equal(t, len(migrations), 4) assertMigration(t, migrations[0], newSource(TypeSQL, "00001_foo.sql", 1)) assertMigration(t, migrations[1], newSource(TypeGo, "00002_bar.go", 2)) assertMigration(t, migrations[2], newSource(TypeGo, "", 3)) @@ -288,15 +288,15 @@ func TestMerge(t *testing.T) { func assertMigration(t *testing.T, got *Migration, want Source) { t.Helper() - check.Equal(t, got.Type, want.Type) - check.Equal(t, got.Version, want.Version) - check.Equal(t, got.Source, want.Path) + require.Equal(t, got.Type, want.Type) + require.Equal(t, got.Version, want.Version) + require.Equal(t, got.Source, want.Path) switch got.Type { case TypeGo: - check.Bool(t, got.goUp != nil, true) - check.Bool(t, got.goDown != nil, true) + require.True(t, got.goUp != nil) + require.True(t, got.goDown != nil) case TypeSQL: - check.Bool(t, got.sql.Parsed, false) + require.False(t, got.sql.Parsed) default: t.Fatalf("unknown migration type: %s", got.Type) } diff --git a/provider_options_test.go b/provider_options_test.go index b9bff8b3d..f22d3757f 100644 --- a/provider_options_test.go +++ b/provider_options_test.go @@ -8,14 +8,14 @@ import ( "github.com/pressly/goose/v3" "github.com/pressly/goose/v3/database" - "github.com/pressly/goose/v3/internal/check" + "github.com/stretchr/testify/require" _ "modernc.org/sqlite" ) func TestNewProvider(t *testing.T) { dir := t.TempDir() db, err := sql.Open("sqlite", filepath.Join(dir, "sql_embed.db")) - check.NoError(t, err) + require.NoError(t, err) fsys := fstest.MapFS{ "1_foo.sql": {Data: []byte(migration1)}, "2_bar.sql": {Data: []byte(migration2)}, @@ -25,41 +25,41 @@ func TestNewProvider(t *testing.T) { t.Run("invalid", func(t *testing.T) { // Empty dialect not allowed _, err = goose.NewProvider("", db, fsys) - check.HasError(t, err) + require.Error(t, err) // Invalid dialect not allowed _, err = goose.NewProvider("unknown-dialect", db, fsys) - check.HasError(t, err) + require.Error(t, err) // Nil db not allowed _, err = goose.NewProvider(goose.DialectSQLite3, nil, fsys) - check.HasError(t, err) + require.Error(t, err) // Nil store not allowed _, err = goose.NewProvider(goose.DialectSQLite3, db, nil, goose.WithStore(nil)) - check.HasError(t, err) + require.Error(t, err) // Cannot set both dialect and store store, err := database.NewStore(goose.DialectSQLite3, "custom_table") - check.NoError(t, err) + require.NoError(t, err) _, err = goose.NewProvider(goose.DialectSQLite3, db, nil, goose.WithStore(store)) - check.HasError(t, err) + require.Error(t, err) // Multiple stores not allowed _, err = goose.NewProvider(goose.DialectSQLite3, db, nil, goose.WithStore(store), goose.WithStore(store), ) - check.HasError(t, err) + require.Error(t, err) }) t.Run("valid", func(t *testing.T) { // Valid dialect, db, and fsys allowed _, err = goose.NewProvider(goose.DialectSQLite3, db, fsys) - check.NoError(t, err) + require.NoError(t, err) // Valid dialect, db, fsys, and verbose allowed _, err = goose.NewProvider(goose.DialectSQLite3, db, fsys, goose.WithVerbose(testing.Verbose()), ) - check.NoError(t, err) + require.NoError(t, err) // Custom store allowed store, err := database.NewStore(goose.DialectSQLite3, "custom_table") - check.NoError(t, err) + require.NoError(t, err) _, err = goose.NewProvider("", db, nil, goose.WithStore(store)) - check.HasError(t, err) + require.Error(t, err) }) } diff --git a/provider_run_test.go b/provider_run_test.go index 604ba8695..793072ba4 100644 --- a/provider_run_test.go +++ b/provider_run_test.go @@ -16,7 +16,7 @@ import ( "github.com/pressly/goose/v3" "github.com/pressly/goose/v3/database" - "github.com/pressly/goose/v3/internal/check" + "github.com/stretchr/testify/require" ) func TestProviderRun(t *testing.T) { @@ -24,38 +24,38 @@ func TestProviderRun(t *testing.T) { t.Run("closed_db", func(t *testing.T) { p, db := newProviderWithDB(t) - check.NoError(t, db.Close()) + require.NoError(t, db.Close()) _, err := p.Up(context.Background()) - check.HasError(t, err) - check.Equal(t, err.Error(), "failed to initialize: sql: database is closed") + require.Error(t, err) + require.Equal(t, err.Error(), "failed to initialize: sql: database is closed") }) t.Run("ping_and_close", func(t *testing.T) { p, _ := newProviderWithDB(t) t.Cleanup(func() { - check.NoError(t, p.Close()) + require.NoError(t, p.Close()) }) - check.NoError(t, p.Ping(context.Background())) + require.NoError(t, p.Ping(context.Background())) }) t.Run("apply_unknown_version", func(t *testing.T) { p, _ := newProviderWithDB(t) _, err := p.ApplyVersion(context.Background(), 999, true) - check.HasError(t, err) - check.Bool(t, errors.Is(err, goose.ErrVersionNotFound), true) + require.Error(t, err) + require.True(t, errors.Is(err, goose.ErrVersionNotFound)) _, err = p.ApplyVersion(context.Background(), 999, false) - check.HasError(t, err) - check.Bool(t, errors.Is(err, goose.ErrVersionNotFound), true) + require.Error(t, err) + require.True(t, errors.Is(err, goose.ErrVersionNotFound)) }) t.Run("run_zero", func(t *testing.T) { p, _ := newProviderWithDB(t) _, err := p.UpTo(context.Background(), 0) - check.HasError(t, err) - check.Equal(t, err.Error(), "version must be greater than 0") + require.Error(t, err) + require.Equal(t, err.Error(), "version must be greater than 0") _, err = p.DownTo(context.Background(), -1) - check.HasError(t, err) - check.Equal(t, err.Error(), "invalid version: must be a valid number or zero: -1") + require.Error(t, err) + require.Equal(t, err.Error(), "invalid version: must be a valid number or zero: -1") _, err = p.ApplyVersion(context.Background(), 0, true) - check.HasError(t, err) - check.Equal(t, err.Error(), "version must be greater than 0") + require.Error(t, err) + require.Equal(t, err.Error(), "version must be greater than 0") }) t.Run("up_and_down_all", func(t *testing.T) { ctx := context.Background() @@ -64,15 +64,15 @@ func TestProviderRun(t *testing.T) { numCount = 7 ) sources := p.ListSources() - check.Number(t, len(sources), numCount) + require.Equal(t, len(sources), numCount) // Ensure only SQL migrations are returned for _, s := range sources { - check.Equal(t, s.Type, goose.TypeSQL) + require.Equal(t, s.Type, goose.TypeSQL) } // Test Up res, err := p.Up(ctx) - check.NoError(t, err) - check.Number(t, len(res), numCount) + require.NoError(t, err) + require.Equal(t, len(res), numCount) assertResult(t, res[0], newSource(goose.TypeSQL, "00001_users_table.sql", 1), "up", false) assertResult(t, res[1], newSource(goose.TypeSQL, "00002_posts_table.sql", 2), "up", false) assertResult(t, res[2], newSource(goose.TypeSQL, "00003_comments_table.sql", 3), "up", false) @@ -82,8 +82,8 @@ func TestProviderRun(t *testing.T) { assertResult(t, res[6], newSource(goose.TypeSQL, "00007_empty_up_down.sql", 7), "up", true) // Test Down res, err = p.DownTo(ctx, 0) - check.NoError(t, err) - check.Number(t, len(res), numCount) + require.NoError(t, err) + require.Equal(t, len(res), numCount) assertResult(t, res[0], newSource(goose.TypeSQL, "00007_empty_up_down.sql", 7), "down", true) assertResult(t, res[1], newSource(goose.TypeSQL, "00006_empty_up.sql", 6), "down", true) assertResult(t, res[2], newSource(goose.TypeSQL, "00005_posts_view.sql", 5), "down", false) @@ -107,13 +107,13 @@ func TestProviderRun(t *testing.T) { } break } - check.NoError(t, err) - check.Bool(t, res != nil, true) - check.Number(t, res.Source.Version, int64(counter)) + require.NoError(t, err) + require.True(t, res != nil) + require.Equal(t, res.Source.Version, int64(counter)) } currentVersion, err := p.GetDBVersion(ctx) - check.NoError(t, err) - check.Number(t, currentVersion, int64(maxVersion)) + require.NoError(t, err) + require.Equal(t, currentVersion, int64(maxVersion)) // Reset counter counter = 0 // Rollback all migrations one-by-one. @@ -126,14 +126,14 @@ func TestProviderRun(t *testing.T) { } break } - check.NoError(t, err) - check.Bool(t, res != nil, true) - check.Number(t, res.Source.Version, int64(maxVersion-counter+1)) + require.NoError(t, err) + require.True(t, res != nil) + require.Equal(t, res.Source.Version, int64(maxVersion-counter+1)) } // Once everything is tested the version should match the highest testdata version currentVersion, err = p.GetDBVersion(ctx) - check.NoError(t, err) - check.Number(t, currentVersion, 0) + require.NoError(t, err) + require.EqualValues(t, currentVersion, 0) }) t.Run("up_to", func(t *testing.T) { ctx := context.Background() @@ -142,18 +142,18 @@ func TestProviderRun(t *testing.T) { upToVersion int64 = 2 ) results, err := p.UpTo(ctx, upToVersion) - check.NoError(t, err) - check.Number(t, len(results), upToVersion) + require.NoError(t, err) + require.EqualValues(t, len(results), upToVersion) assertResult(t, results[0], newSource(goose.TypeSQL, "00001_users_table.sql", 1), "up", false) assertResult(t, results[1], newSource(goose.TypeSQL, "00002_posts_table.sql", 2), "up", false) // Fetch the goose version from DB currentVersion, err := p.GetDBVersion(ctx) - check.NoError(t, err) - check.Number(t, currentVersion, upToVersion) + require.NoError(t, err) + require.Equal(t, currentVersion, upToVersion) // Validate the version actually matches what goose claims it is gotVersion, err := getMaxVersionID(db, goose.DefaultTablename) - check.NoError(t, err) - check.Number(t, gotVersion, upToVersion) + require.NoError(t, err) + require.Equal(t, gotVersion, upToVersion) }) t.Run("sql_connections", func(t *testing.T) { tt := []struct { @@ -177,26 +177,26 @@ func TestProviderRun(t *testing.T) { db.SetMaxIdleConns(tc.maxIdleConns) } sources := p.ListSources() - check.NumberNotZero(t, len(sources)) + require.NotZero(t, len(sources)) currentVersion, err := p.GetDBVersion(ctx) - check.NoError(t, err) - check.Number(t, currentVersion, 0) + require.NoError(t, err) + require.EqualValues(t, currentVersion, 0) { // Apply all up migrations upResult, err := p.Up(ctx) - check.NoError(t, err) - check.Number(t, len(upResult), len(sources)) + require.NoError(t, err) + require.Equal(t, len(upResult), len(sources)) currentVersion, err := p.GetDBVersion(ctx) - check.NoError(t, err) - check.Number(t, currentVersion, p.ListSources()[len(sources)-1].Version) + require.NoError(t, err) + require.Equal(t, currentVersion, p.ListSources()[len(sources)-1].Version) // Validate the db migration version actually matches what goose claims it is gotVersion, err := getMaxVersionID(db, goose.DefaultTablename) - check.NoError(t, err) - check.Number(t, gotVersion, currentVersion) + require.NoError(t, err) + require.Equal(t, gotVersion, currentVersion) tables, err := getTableNames(db) - check.NoError(t, err) + require.NoError(t, err) if !reflect.DeepEqual(tables, knownTables) { t.Logf("got tables: %v", tables) t.Logf("known tables: %v", knownTables) @@ -206,14 +206,14 @@ func TestProviderRun(t *testing.T) { { // Apply all down migrations downResult, err := p.DownTo(ctx, 0) - check.NoError(t, err) - check.Number(t, len(downResult), len(sources)) + require.NoError(t, err) + require.Equal(t, len(downResult), len(sources)) gotVersion, err := getMaxVersionID(db, goose.DefaultTablename) - check.NoError(t, err) - check.Number(t, gotVersion, 0) + require.NoError(t, err) + require.EqualValues(t, gotVersion, 0) // Should only be left with a single table, the default goose table tables, err := getTableNames(db) - check.NoError(t, err) + require.NoError(t, err) knownTables := []string{goose.DefaultTablename, "sqlite_sequence"} if !reflect.DeepEqual(tables, knownTables) { t.Logf("got tables: %v", tables) @@ -231,7 +231,7 @@ func TestProviderRun(t *testing.T) { // Apply all migrations in the up direction. for _, s := range sources { res, err := p.ApplyVersion(ctx, s.Version, true) - check.NoError(t, err) + require.NoError(t, err) // Round-trip the migration result through the database to ensure it's valid. var empty bool if s.Version == 6 || s.Version == 7 { @@ -243,7 +243,7 @@ func TestProviderRun(t *testing.T) { for i := len(sources) - 1; i >= 0; i-- { s := sources[i] res, err := p.ApplyVersion(ctx, s.Version, false) - check.NoError(t, err) + require.NoError(t, err) // Round-trip the migration result through the database to ensure it's valid. var empty bool if s.Version == 6 || s.Version == 7 { @@ -253,11 +253,11 @@ func TestProviderRun(t *testing.T) { } // Try apply version 1 multiple times _, err := p.ApplyVersion(ctx, 1, true) - check.NoError(t, err) + require.NoError(t, err) _, err = p.ApplyVersion(ctx, 1, true) - check.HasError(t, err) - check.Bool(t, errors.Is(err, goose.ErrAlreadyApplied), true) - check.Contains(t, err.Error(), "version 1: migration already applied") + require.Error(t, err) + require.True(t, errors.Is(err, goose.ErrAlreadyApplied)) + require.Contains(t, err.Error(), "version 1: migration already applied") }) t.Run("status", func(t *testing.T) { ctx := context.Background() @@ -265,8 +265,8 @@ func TestProviderRun(t *testing.T) { numCount := len(p.ListSources()) // Before any migrations are applied, the status should be empty. status, err := p.Status(ctx) - check.NoError(t, err) - check.Number(t, len(status), numCount) + require.NoError(t, err) + require.Equal(t, len(status), numCount) assertStatus(t, status[0], goose.StatePending, newSource(goose.TypeSQL, "00001_users_table.sql", 1), true) assertStatus(t, status[1], goose.StatePending, newSource(goose.TypeSQL, "00002_posts_table.sql", 2), true) assertStatus(t, status[2], goose.StatePending, newSource(goose.TypeSQL, "00003_comments_table.sql", 3), true) @@ -276,10 +276,10 @@ func TestProviderRun(t *testing.T) { assertStatus(t, status[6], goose.StatePending, newSource(goose.TypeSQL, "00007_empty_up_down.sql", 7), true) // Apply all migrations _, err = p.Up(ctx) - check.NoError(t, err) + require.NoError(t, err) status, err = p.Status(ctx) - check.NoError(t, err) - check.Number(t, len(status), numCount) + require.NoError(t, err) + require.Equal(t, len(status), numCount) assertStatus(t, status[0], goose.StateApplied, newSource(goose.TypeSQL, "00001_users_table.sql", 1), false) assertStatus(t, status[1], goose.StateApplied, newSource(goose.TypeSQL, "00002_posts_table.sql", 2), false) assertStatus(t, status[2], goose.StateApplied, newSource(goose.TypeSQL, "00003_comments_table.sql", 3), false) @@ -317,35 +317,35 @@ INSERT INTO owners (owner_name) VALUES ('seed-user-3'); `), } p, err := goose.NewProvider(goose.DialectSQLite3, db, mapFS) - check.NoError(t, err) + require.NoError(t, err) _, err = p.Up(ctx) - check.HasError(t, err) - check.Contains(t, err.Error(), "partial migration error (type:sql,version:2)") + require.Error(t, err) + require.Contains(t, err.Error(), "partial migration error (type:sql,version:2)") var expected *goose.PartialError - check.Bool(t, errors.As(err, &expected), true) + require.True(t, errors.As(err, &expected)) // Check Err field - check.Bool(t, expected.Err != nil, true) - check.Contains(t, expected.Err.Error(), "SQL logic error: no such table: invalid_table (1)") + require.True(t, expected.Err != nil) + require.Contains(t, expected.Err.Error(), "SQL logic error: no such table: invalid_table (1)") // Check Results field - check.Number(t, len(expected.Applied), 1) + require.Equal(t, len(expected.Applied), 1) assertResult(t, expected.Applied[0], newSource(goose.TypeSQL, "00001_users_table.sql", 1), "up", false) // Check Failed field - check.Bool(t, expected.Failed != nil, true) + require.True(t, expected.Failed != nil) assertSource(t, expected.Failed.Source, goose.TypeSQL, "00002_partial_error.sql", 2) - check.Bool(t, expected.Failed.Empty, false) - check.Bool(t, expected.Failed.Error != nil, true) - check.Contains(t, expected.Failed.Error.Error(), "SQL logic error: no such table: invalid_table (1)") - check.Equal(t, expected.Failed.Direction, "up") - check.Bool(t, expected.Failed.Duration > 0, true) + require.False(t, expected.Failed.Empty) + require.True(t, expected.Failed.Error != nil) + require.Contains(t, expected.Failed.Error.Error(), "SQL logic error: no such table: invalid_table (1)") + require.Equal(t, expected.Failed.Direction, "up") + require.True(t, expected.Failed.Duration > 0) // Ensure the partial error did not affect the database. count, err := countOwners(db) - check.NoError(t, err) - check.Number(t, count, 0) + require.NoError(t, err) + require.Equal(t, count, 0) status, err := p.Status(ctx) - check.NoError(t, err) - check.Number(t, len(status), 3) + require.NoError(t, err) + require.Equal(t, len(status), 3) assertStatus(t, status[0], goose.StateApplied, newSource(goose.TypeSQL, "00001_users_table.sql", 1), false) assertStatus(t, status[1], goose.StatePending, newSource(goose.TypeSQL, "00002_partial_error.sql", 2), true) assertStatus(t, status[2], goose.StatePending, newSource(goose.TypeSQL, "00003_insert_data.sql", 3), true) @@ -391,13 +391,13 @@ func TestConcurrentProvider(t *testing.T) { if t.Failed() { return } - check.Number(t, len(versions), maxVersion) + require.Equal(t, len(versions), maxVersion) for i := 0; i < maxVersion; i++ { - check.Number(t, versions[i], int64(i+1)) + require.Equal(t, versions[i], int64(i+1)) } currentVersion, err := p.GetDBVersion(ctx) - check.NoError(t, err) - check.Number(t, currentVersion, maxVersion) + require.NoError(t, err) + require.EqualValues(t, currentVersion, maxVersion) }) t.Run("down", func(t *testing.T) { ctx := context.Background() @@ -405,10 +405,10 @@ func TestConcurrentProvider(t *testing.T) { maxVersion := len(p.ListSources()) // Apply all migrations _, err := p.Up(ctx) - check.NoError(t, err) + require.NoError(t, err) currentVersion, err := p.GetDBVersion(ctx) - check.NoError(t, err) - check.Number(t, currentVersion, maxVersion) + require.NoError(t, err) + require.EqualValues(t, currentVersion, maxVersion) ch := make(chan []*goose.MigrationResult) var wg sync.WaitGroup @@ -444,10 +444,10 @@ func TestConcurrentProvider(t *testing.T) { if t.Failed() { return } - check.Equal(t, len(valid), 1) - check.Equal(t, len(empty), maxVersion-1) + require.Equal(t, len(valid), 1) + require.Equal(t, len(empty), maxVersion-1) // Ensure the valid result is correct. - check.Number(t, len(valid[0]), maxVersion) + require.Equal(t, len(valid[0]), maxVersion) }) } @@ -473,7 +473,7 @@ func TestNoVersioning(t *testing.T) { ctx := context.Background() dbName := fmt.Sprintf("test_%s.db", randomAlphaNumeric(8)) db, err := sql.Open("sqlite", filepath.Join(t.TempDir(), dbName)) - check.NoError(t, err) + require.NoError(t, err) fsys := os.DirFS(filepath.Join("testdata", "no-versioning", "migrations")) const ( // Total owners created by the seed files. @@ -485,50 +485,50 @@ func TestNoVersioning(t *testing.T) { goose.WithVerbose(testing.Verbose()), goose.WithDisableVersioning(false), // This is the default. ) - check.Number(t, len(p.ListSources()), 3) - check.NoError(t, err) + require.Equal(t, len(p.ListSources()), 3) + require.NoError(t, err) _, err = p.Up(ctx) - check.NoError(t, err) + require.NoError(t, err) baseVersion, err := p.GetDBVersion(ctx) - check.NoError(t, err) - check.Number(t, baseVersion, 3) + require.NoError(t, err) + require.EqualValues(t, baseVersion, 3) t.Run("seed-up-down-to-zero", func(t *testing.T) { fsys := os.DirFS(filepath.Join("testdata", "no-versioning", "seed")) p, err := goose.NewProvider(goose.DialectSQLite3, db, fsys, goose.WithVerbose(testing.Verbose()), goose.WithDisableVersioning(true), // Provider with no versioning. ) - check.NoError(t, err) - check.Number(t, len(p.ListSources()), 2) + require.NoError(t, err) + require.Equal(t, len(p.ListSources()), 2) // Run (all) up migrations from the seed dir { upResult, err := p.Up(ctx) - check.NoError(t, err) - check.Number(t, len(upResult), 2) + require.NoError(t, err) + require.Equal(t, len(upResult), 2) // When versioning is disabled, we cannot track the version of the seed files. _, err = p.GetDBVersion(ctx) - check.HasError(t, err) + require.Error(t, err) seedOwnerCount, err := countSeedOwners(db) - check.NoError(t, err) - check.Number(t, seedOwnerCount, wantSeedOwnerCount) + require.NoError(t, err) + require.Equal(t, seedOwnerCount, wantSeedOwnerCount) } // Run (all) down migrations from the seed dir { downResult, err := p.DownTo(ctx, 0) - check.NoError(t, err) - check.Number(t, len(downResult), 2) + require.NoError(t, err) + require.Equal(t, len(downResult), 2) // When versioning is disabled, we cannot track the version of the seed files. _, err = p.GetDBVersion(ctx) - check.HasError(t, err) + require.Error(t, err) seedOwnerCount, err := countSeedOwners(db) - check.NoError(t, err) - check.Number(t, seedOwnerCount, 0) + require.NoError(t, err) + require.Equal(t, seedOwnerCount, 0) } // The migrations added 4 non-seed owners, they must remain in the database afterwards ownerCount, err := countOwners(db) - check.NoError(t, err) - check.Number(t, ownerCount, wantOwnerCount) + require.NoError(t, err) + require.Equal(t, ownerCount, wantOwnerCount) }) } @@ -548,22 +548,22 @@ func TestAllowMissing(t *testing.T) { p, err := goose.NewProvider(goose.DialectSQLite3, db, newFsys(), goose.WithAllowOutofOrder(false), ) - check.NoError(t, err) + require.NoError(t, err) // Create and apply first 3 migrations. _, err = p.UpTo(ctx, 3) - check.NoError(t, err) + require.NoError(t, err) currentVersion, err := p.GetDBVersion(ctx) - check.NoError(t, err) - check.Number(t, currentVersion, 3) + require.NoError(t, err) + require.EqualValues(t, currentVersion, 3) // Developer A - migration 5 (mistakenly applied) result, err := p.ApplyVersion(ctx, 5, true) - check.NoError(t, err) - check.Number(t, result.Source.Version, 5) + require.NoError(t, err) + require.EqualValues(t, result.Source.Version, 5) current, err := p.GetDBVersion(ctx) - check.NoError(t, err) - check.Number(t, current, 5) + require.NoError(t, err) + require.EqualValues(t, current, 5) // The database has migrations 1,2,3,5 applied. @@ -571,31 +571,31 @@ func TestAllowMissing(t *testing.T) { // default goose does not allow missing (out-of-order) migrations, which means halt if a // missing migration is detected. _, err = p.Up(ctx) - check.HasError(t, err) + require.Error(t, err) // found 1 missing (out-of-order) migration: [00004_insert_data.sql] - check.Contains(t, err.Error(), "missing (out-of-order) migration") + require.Contains(t, err.Error(), "missing (out-of-order) migration") // Confirm db version is unchanged. current, err = p.GetDBVersion(ctx) - check.NoError(t, err) - check.Number(t, current, 5) + require.NoError(t, err) + require.EqualValues(t, current, 5) _, err = p.UpByOne(ctx) - check.HasError(t, err) + require.Error(t, err) // found 1 missing (out-of-order) migration: [00004_insert_data.sql] - check.Contains(t, err.Error(), "missing (out-of-order) migration") + require.Contains(t, err.Error(), "missing (out-of-order) migration") // Confirm db version is unchanged. current, err = p.GetDBVersion(ctx) - check.NoError(t, err) - check.Number(t, current, 5) + require.NoError(t, err) + require.EqualValues(t, current, 5) _, err = p.UpTo(ctx, math.MaxInt64) - check.HasError(t, err) + require.Error(t, err) // found 1 missing (out-of-order) migration: [00004_insert_data.sql] - check.Contains(t, err.Error(), "missing (out-of-order) migration") + require.Contains(t, err.Error(), "missing (out-of-order) migration") // Confirm db version is unchanged. current, err = p.GetDBVersion(ctx) - check.NoError(t, err) - check.Number(t, current, 5) + require.NoError(t, err) + require.EqualValues(t, current, 5) }) t.Run("missing_allowed", func(t *testing.T) { @@ -603,43 +603,43 @@ func TestAllowMissing(t *testing.T) { p, err := goose.NewProvider(goose.DialectSQLite3, db, newFsys(), goose.WithAllowOutofOrder(true), ) - check.NoError(t, err) + require.NoError(t, err) // Create and apply first 3 migrations. _, err = p.UpTo(ctx, 3) - check.NoError(t, err) + require.NoError(t, err) currentVersion, err := p.GetDBVersion(ctx) - check.NoError(t, err) - check.Number(t, currentVersion, 3) + require.NoError(t, err) + require.EqualValues(t, currentVersion, 3) // Developer A - migration 5 (mistakenly applied) { _, err = p.ApplyVersion(ctx, 5, true) - check.NoError(t, err) + require.NoError(t, err) current, err := p.GetDBVersion(ctx) - check.NoError(t, err) - check.Number(t, current, 5) + require.NoError(t, err) + require.EqualValues(t, current, 5) } // Developer B - migration 4 (missing) and 6 (new) { // 4 upResult, err := p.UpByOne(ctx) - check.NoError(t, err) - check.Bool(t, upResult != nil, true) - check.Number(t, upResult.Source.Version, 4) + require.NoError(t, err) + require.True(t, upResult != nil) + require.EqualValues(t, upResult.Source.Version, 4) // 6 upResult, err = p.UpByOne(ctx) - check.NoError(t, err) - check.Bool(t, upResult != nil, true) - check.Number(t, upResult.Source.Version, 6) + require.NoError(t, err) + require.True(t, upResult != nil) + require.EqualValues(t, upResult.Source.Version, 6) count, err := getGooseVersionCount(db, goose.DefaultTablename) - check.NoError(t, err) - check.Number(t, count, 6) + require.NoError(t, err) + require.EqualValues(t, count, 6) current, err := p.GetDBVersion(ctx) - check.NoError(t, err) + require.NoError(t, err) // Expecting max(version_id) to be 8 - check.Number(t, current, 6) + require.EqualValues(t, current, 6) } // The applied order in the database is expected to be: @@ -649,12 +649,12 @@ func TestAllowMissing(t *testing.T) { testDownAndVersion := func(wantDBVersion, wantResultVersion int64) { currentVersion, err := p.GetDBVersion(ctx) - check.NoError(t, err) - check.Number(t, currentVersion, wantDBVersion) + require.NoError(t, err) + require.Equal(t, currentVersion, wantDBVersion) downRes, err := p.Down(ctx) - check.NoError(t, err) - check.Bool(t, downRes != nil, true) - check.Number(t, downRes.Source.Version, wantResultVersion) + require.NoError(t, err) + require.True(t, downRes != nil) + require.Equal(t, downRes.Source.Version, wantResultVersion) } // This behaviour may need to change, see the following issues for more details: @@ -668,8 +668,8 @@ func TestAllowMissing(t *testing.T) { testDownAndVersion(2, 2) testDownAndVersion(1, 1) _, err = p.Down(ctx) - check.HasError(t, err) - check.Bool(t, errors.Is(err, goose.ErrNoNextVersion), true) + require.Error(t, err) + require.True(t, errors.Is(err, goose.ErrNoNextVersion)) }) } @@ -690,30 +690,30 @@ func TestSQLiteSharedCache(t *testing.T) { // database connections as follows: file::memory:?cache=shared" t.Run("shared_cache", func(t *testing.T) { db, err := sql.Open("sqlite", "file::memory:?cache=shared") - check.NoError(t, err) + require.NoError(t, err) fsys := fstest.MapFS{"00001_a.sql": newMapFile(`-- +goose Up`)} p, err := goose.NewProvider(goose.DialectSQLite3, db, fsys, goose.WithGoMigrations( goose.NewGoMigration(2, &goose.GoFunc{Mode: goose.TransactionDisabled}, nil), ), ) - check.NoError(t, err) + require.NoError(t, err) _, err = p.Up(context.Background()) - check.NoError(t, err) + require.NoError(t, err) }) t.Run("no_shared_cache", func(t *testing.T) { db, err := sql.Open("sqlite", "file::memory:") - check.NoError(t, err) + require.NoError(t, err) fsys := fstest.MapFS{"00001_a.sql": newMapFile(`-- +goose Up`)} p, err := goose.NewProvider(goose.DialectSQLite3, db, fsys, goose.WithGoMigrations( goose.NewGoMigration(2, &goose.GoFunc{Mode: goose.TransactionDisabled}, nil), ), ) - check.NoError(t, err) + require.NoError(t, err) _, err = p.Up(context.Background()) - check.HasError(t, err) - check.Contains(t, err.Error(), "SQL logic error: no such table: goose_db_version") + require.Error(t, err) + require.Contains(t, err.Error(), "SQL logic error: no such table: goose_db_version") }) } @@ -736,26 +736,26 @@ func TestGoMigrationPanic(t *testing.T) { p, err := goose.NewProvider(goose.DialectSQLite3, newDB(t), nil, goose.WithGoMigrations(migration), // Add a Go migration that panics. ) - check.NoError(t, err) + require.NoError(t, err) _, err = p.Up(ctx) - check.HasError(t, err) - check.Contains(t, err.Error(), wantErrString) + require.Error(t, err) + require.Contains(t, err.Error(), wantErrString) var expected *goose.PartialError - check.Bool(t, errors.As(err, &expected), true) - check.Contains(t, expected.Err.Error(), wantErrString) + require.True(t, errors.As(err, &expected)) + require.Contains(t, expected.Err.Error(), wantErrString) } func TestCustomStoreTableExists(t *testing.T) { t.Parallel() store, err := database.NewStore(database.DialectSQLite3, goose.DefaultTablename) - check.NoError(t, err) + require.NoError(t, err) p, err := goose.NewProvider("", newDB(t), newFsys(), goose.WithStore(&customStoreSQLite3{store}), ) - check.NoError(t, err) + require.NoError(t, err) _, err = p.Up(context.Background()) - check.NoError(t, err) + require.NoError(t, err) } func TestProviderApply(t *testing.T) { @@ -763,13 +763,13 @@ func TestProviderApply(t *testing.T) { ctx := context.Background() p, err := goose.NewProvider(goose.DialectSQLite3, newDB(t), newFsys()) - check.NoError(t, err) + require.NoError(t, err) _, err = p.ApplyVersion(ctx, 1, true) - check.NoError(t, err) + require.NoError(t, err) // This version has a corresponding down migration, but has never been applied. _, err = p.ApplyVersion(ctx, 2, false) - check.HasError(t, err) - check.Bool(t, errors.Is(err, goose.ErrNotApplied), true) + require.Error(t, err) + require.True(t, errors.Is(err, goose.ErrNotApplied)) } func TestPending(t *testing.T) { @@ -780,31 +780,31 @@ func TestPending(t *testing.T) { p, err := goose.NewProvider(goose.DialectSQLite3, newDB(t), fsys, goose.WithAllowOutofOrder(true), ) - check.NoError(t, err) + require.NoError(t, err) // Some migrations have been applied out of order. _, err = p.ApplyVersion(ctx, 1, true) - check.NoError(t, err) + require.NoError(t, err) _, err = p.ApplyVersion(ctx, 3, true) - check.NoError(t, err) + require.NoError(t, err) // Even though the latest migration HAS been applied, there are still pending out-of-order // migrations. current, target, err := p.GetVersions(ctx) - check.NoError(t, err) - check.Number(t, current, 3) - check.Number(t, target, len(fsys)) + require.NoError(t, err) + require.EqualValues(t, current, 3) + require.EqualValues(t, target, len(fsys)) hasPending, err := p.HasPending(ctx) - check.NoError(t, err) - check.Bool(t, hasPending, true) + require.NoError(t, err) + require.True(t, hasPending) // Apply the missing migrations. _, err = p.Up(ctx) - check.NoError(t, err) + require.NoError(t, err) // All migrations have been applied. hasPending, err = p.HasPending(ctx) - check.NoError(t, err) - check.Bool(t, hasPending, false) + require.NoError(t, err) + require.False(t, hasPending) current, target, err = p.GetVersions(ctx) - check.NoError(t, err) - check.Number(t, current, target) + require.NoError(t, err) + require.Equal(t, current, target) }) t.Run("disallow_out_of_order", func(t *testing.T) { ctx := context.Background() @@ -814,24 +814,24 @@ func TestPending(t *testing.T) { p, err := goose.NewProvider(goose.DialectSQLite3, newDB(t), fsys, goose.WithAllowOutofOrder(false), ) - check.NoError(t, err) + require.NoError(t, err) // Some migrations have been applied. _, err = p.ApplyVersion(ctx, 1, true) - check.NoError(t, err) + require.NoError(t, err) _, err = p.ApplyVersion(ctx, versionToApply, true) - check.NoError(t, err) + require.NoError(t, err) // TODO(mf): revisit the pending check behavior in addition to the HasPending // method. current, target, err := p.GetVersions(ctx) - check.NoError(t, err) - check.Number(t, current, versionToApply) - check.Number(t, target, len(fsys)) + require.NoError(t, err) + require.Equal(t, current, versionToApply) + require.EqualValues(t, target, len(fsys)) _, err = p.HasPending(ctx) - check.HasError(t, err) - check.Contains(t, err.Error(), "missing (out-of-order) migration") + require.Error(t, err) + require.Contains(t, err.Error(), "missing (out-of-order) migration") _, err = p.Up(ctx) - check.HasError(t, err) - check.Contains(t, err.Error(), "missing (out-of-order) migration") + require.Error(t, err) + require.Contains(t, err.Error(), "missing (out-of-order) migration") } t.Run("latest_version", func(t *testing.T) { @@ -874,7 +874,7 @@ func TestGoOnly(t *testing.T) { q := `SELECT count(*)FROM users` var count int err := db.QueryRow(q).Scan(&count) - check.NoError(t, err) + require.NoError(t, err) return count } @@ -888,7 +888,7 @@ func TestGoOnly(t *testing.T) { ), } err := goose.SetGlobalMigrations(register...) - check.NoError(t, err) + require.NoError(t, err) t.Cleanup(goose.ResetGlobalMigrations) db := newDB(t) @@ -902,33 +902,33 @@ func TestGoOnly(t *testing.T) { p, err := goose.NewProvider(goose.DialectSQLite3, db, nil, goose.WithGoMigrations(register...), ) - check.NoError(t, err) + require.NoError(t, err) sources := p.ListSources() - check.Number(t, len(p.ListSources()), 2) + require.Equal(t, len(p.ListSources()), 2) assertSource(t, sources[0], goose.TypeGo, "", 1) assertSource(t, sources[1], goose.TypeGo, "", 2) // Apply migration 1 res, err := p.UpByOne(ctx) - check.NoError(t, err) + require.NoError(t, err) assertResult(t, res, newSource(goose.TypeGo, "", 1), "up", false) - check.Number(t, countUser(db), 0) - check.Bool(t, tableExists(t, db, "users"), true) + require.Equal(t, countUser(db), 0) + require.True(t, tableExists(t, db, "users")) // Apply migration 2 res, err = p.UpByOne(ctx) - check.NoError(t, err) + require.NoError(t, err) assertResult(t, res, newSource(goose.TypeGo, "", 2), "up", false) - check.Number(t, countUser(db), 3) + require.Equal(t, countUser(db), 3) // Rollback migration 2 res, err = p.Down(ctx) - check.NoError(t, err) + require.NoError(t, err) assertResult(t, res, newSource(goose.TypeGo, "", 2), "down", false) - check.Number(t, countUser(db), 0) + require.Equal(t, countUser(db), 0) // Rollback migration 1 res, err = p.Down(ctx) - check.NoError(t, err) + require.NoError(t, err) assertResult(t, res, newSource(goose.TypeGo, "", 1), "down", false) // Check table does not exist - check.Bool(t, tableExists(t, db, "users"), false) + require.False(t, tableExists(t, db, "users")) }) t.Run("with_db", func(t *testing.T) { ctx := context.Background() @@ -944,7 +944,7 @@ func TestGoOnly(t *testing.T) { ), } err := goose.SetGlobalMigrations(register...) - check.NoError(t, err) + require.NoError(t, err) t.Cleanup(goose.ResetGlobalMigrations) db := newDB(t) @@ -958,33 +958,33 @@ func TestGoOnly(t *testing.T) { p, err := goose.NewProvider(goose.DialectSQLite3, db, nil, goose.WithGoMigrations(register...), ) - check.NoError(t, err) + require.NoError(t, err) sources := p.ListSources() - check.Number(t, len(p.ListSources()), 2) + require.Equal(t, len(p.ListSources()), 2) assertSource(t, sources[0], goose.TypeGo, "", 1) assertSource(t, sources[1], goose.TypeGo, "", 2) // Apply migration 1 res, err := p.UpByOne(ctx) - check.NoError(t, err) + require.NoError(t, err) assertResult(t, res, newSource(goose.TypeGo, "", 1), "up", false) - check.Number(t, countUser(db), 0) - check.Bool(t, tableExists(t, db, "users"), true) + require.Equal(t, countUser(db), 0) + require.True(t, tableExists(t, db, "users")) // Apply migration 2 res, err = p.UpByOne(ctx) - check.NoError(t, err) + require.NoError(t, err) assertResult(t, res, newSource(goose.TypeGo, "", 2), "up", false) - check.Number(t, countUser(db), 3) + require.Equal(t, countUser(db), 3) // Rollback migration 2 res, err = p.Down(ctx) - check.NoError(t, err) + require.NoError(t, err) assertResult(t, res, newSource(goose.TypeGo, "", 2), "down", false) - check.Number(t, countUser(db), 0) + require.Equal(t, countUser(db), 0) // Rollback migration 1 res, err = p.Down(ctx) - check.NoError(t, err) + require.NoError(t, err) assertResult(t, res, newSource(goose.TypeGo, "", 1), "down", false) // Check table does not exist - check.Bool(t, tableExists(t, db, "users"), false) + require.False(t, tableExists(t, db, "users")) }) } @@ -1006,7 +1006,7 @@ func tableExists(t *testing.T, db *sql.DB, table string) bool { q := fmt.Sprintf(`SELECT CASE WHEN COUNT(*) > 0 THEN 1 ELSE 0 END AS table_exists FROM sqlite_master WHERE type = 'table' AND name = '%s'`, table) var b string err := db.QueryRow(q).Scan(&b) - check.NoError(t, err) + require.NoError(t, err) return b == "1" } @@ -1030,7 +1030,7 @@ func newProviderWithDB(t *testing.T, opts ...goose.ProviderOption) (*goose.Provi goose.WithVerbose(testing.Verbose()), ) p, err := goose.NewProvider(goose.DialectSQLite3, db, newFsys(), opts...) - check.NoError(t, err) + require.NoError(t, err) return p, db } @@ -1038,7 +1038,7 @@ func newDB(t *testing.T) *sql.DB { t.Helper() dbName := fmt.Sprintf("test_%s.db", randomAlphaNumeric(8)) db, err := sql.Open("sqlite", filepath.Join(t.TempDir(), dbName)) - check.NoError(t, err) + require.NoError(t, err) return db } @@ -1074,26 +1074,26 @@ func getTableNames(db *sql.DB) ([]string, error) { func assertStatus(t *testing.T, got *goose.MigrationStatus, state goose.State, source *goose.Source, appliedIsZero bool) { t.Helper() - check.Equal(t, got.State, state) - check.Equal(t, got.Source, source) - check.Bool(t, got.AppliedAt.IsZero(), appliedIsZero) + require.Equal(t, got.State, state) + require.Equal(t, got.Source, source) + require.Equal(t, got.AppliedAt.IsZero(), appliedIsZero) } func assertResult(t *testing.T, got *goose.MigrationResult, source *goose.Source, direction string, isEmpty bool) { t.Helper() - check.Bool(t, got != nil, true) - check.Equal(t, got.Source, source) - check.Equal(t, got.Direction, direction) - check.Equal(t, got.Empty, isEmpty) - check.Bool(t, got.Error == nil, true) - check.Bool(t, got.Duration > 0, true) + require.True(t, got != nil) + require.Equal(t, got.Source, source) + require.Equal(t, got.Direction, direction) + require.Equal(t, got.Empty, isEmpty) + require.Nil(t, got.Error) + require.True(t, got.Duration > 0) } func assertSource(t *testing.T, got *goose.Source, typ goose.MigrationType, name string, version int64) { t.Helper() - check.Equal(t, got.Type, typ) - check.Equal(t, got.Path, name) - check.Equal(t, got.Version, version) + require.Equal(t, got.Type, typ) + require.Equal(t, got.Path, name) + require.Equal(t, got.Version, version) } func newSource(t goose.MigrationType, fullpath string, version int64) *goose.Source { diff --git a/provider_test.go b/provider_test.go index 82676043e..6228b80ad 100644 --- a/provider_test.go +++ b/provider_test.go @@ -9,18 +9,18 @@ import ( "testing/fstest" "github.com/pressly/goose/v3" - "github.com/pressly/goose/v3/internal/check" + "github.com/stretchr/testify/require" _ "modernc.org/sqlite" ) func TestProvider(t *testing.T) { dir := t.TempDir() db, err := sql.Open("sqlite", filepath.Join(dir, "sql_embed.db")) - check.NoError(t, err) + require.NoError(t, err) t.Run("empty", func(t *testing.T) { _, err := goose.NewProvider(goose.DialectSQLite3, db, fstest.MapFS{}) - check.HasError(t, err) - check.Bool(t, errors.Is(err, goose.ErrNoMigrations), true) + require.Error(t, err) + require.True(t, errors.Is(err, goose.ErrNoMigrations)) }) mapFS := fstest.MapFS{ @@ -28,13 +28,13 @@ func TestProvider(t *testing.T) { "migrations/002_bar.sql": {Data: []byte(`-- +goose Up`)}, } fsys, err := fs.Sub(mapFS, "migrations") - check.NoError(t, err) + require.NoError(t, err) p, err := goose.NewProvider(goose.DialectSQLite3, db, fsys) - check.NoError(t, err) + require.NoError(t, err) sources := p.ListSources() - check.Equal(t, len(sources), 2) - check.Equal(t, sources[0], newSource(goose.TypeSQL, "001_foo.sql", 1)) - check.Equal(t, sources[1], newSource(goose.TypeSQL, "002_bar.sql", 2)) + require.Equal(t, len(sources), 2) + require.Equal(t, sources[0], newSource(goose.TypeSQL, "001_foo.sql", 1)) + require.Equal(t, sources[1], newSource(goose.TypeSQL, "002_bar.sql", 2)) } var ( @@ -78,7 +78,5 @@ ALTER TABLE my_foo RENAME TO foo; func TestPartialErrorUnwrap(t *testing.T) { err := &goose.PartialError{Err: goose.ErrNoCurrentVersion} - - got := errors.Is(err, goose.ErrNoCurrentVersion) - check.Bool(t, got, true) + require.ErrorIs(t, err, goose.ErrNoCurrentVersion) } diff --git a/tests/gomigrations/error/gomigrations_error_test.go b/tests/gomigrations/error/gomigrations_error_test.go index e03cbc114..3a62d8f21 100644 --- a/tests/gomigrations/error/gomigrations_error_test.go +++ b/tests/gomigrations/error/gomigrations_error_test.go @@ -6,68 +6,68 @@ import ( "testing" "github.com/pressly/goose/v3" - "github.com/pressly/goose/v3/internal/check" _ "github.com/pressly/goose/v3/tests/gomigrations/error/testdata" + "github.com/stretchr/testify/require" _ "modernc.org/sqlite" ) func TestGoMigrationByOne(t *testing.T) { tempDir := t.TempDir() db, err := sql.Open("sqlite", filepath.Join(tempDir, "test.db")) - check.NoError(t, err) + require.NoError(t, err) err = goose.SetDialect(string(goose.DialectSQLite3)) - check.NoError(t, err) + require.NoError(t, err) // Create goose table. current, err := goose.EnsureDBVersion(db) - check.NoError(t, err) - check.Number(t, current, 0) + require.NoError(t, err) + require.Equal(t, current, 0) // Collect migrations. dir := "testdata" migrations, err := goose.CollectMigrations(dir, 0, goose.MaxVersion) - check.NoError(t, err) - check.Number(t, len(migrations), 4) + require.NoError(t, err) + require.Equal(t, len(migrations), 4) // Setup table. err = migrations[0].Up(db) - check.NoError(t, err) + require.NoError(t, err) version, err := goose.GetDBVersion(db) - check.NoError(t, err) - check.Number(t, version, 1) + require.NoError(t, err) + require.Equal(t, version, 1) // Registered Go migration run outside a goose tx using *sql.DB. err = migrations[1].Up(db) - check.HasError(t, err) - check.Contains(t, err.Error(), "failed to run go migration") + require.Error(t, err) + require.Contains(t, err.Error(), "failed to run go migration") version, err = goose.GetDBVersion(db) - check.NoError(t, err) - check.Number(t, version, 1) + require.NoError(t, err) + require.Equal(t, version, 1) // This migration was inserting 100 rows, but fails at 50, and // because it's run outside a goose tx then we expect 50 rows. var count int err = db.QueryRow("SELECT COUNT(*) FROM foo").Scan(&count) - check.NoError(t, err) - check.Number(t, count, 50) + require.NoError(t, err) + require.Equal(t, count, 50) // Truncate table so we have 0 rows. err = migrations[2].Up(db) - check.NoError(t, err) + require.NoError(t, err) version, err = goose.GetDBVersion(db) - check.NoError(t, err) + require.NoError(t, err) // We're at version 3, but keep in mind 2 was never applied because it failed. - check.Number(t, version, 3) + require.Equal(t, version, 3) // Registered Go migration run within a tx. err = migrations[3].Up(db) - check.HasError(t, err) - check.Contains(t, err.Error(), "failed to run go migration") + require.Error(t, err) + require.Contains(t, err.Error(), "failed to run go migration") version, err = goose.GetDBVersion(db) - check.NoError(t, err) - check.Number(t, version, 3) // This migration failed, so we're still at 3. + require.NoError(t, err) + require.Equal(t, version, 3) // This migration failed, so we're still at 3. // This migration was inserting 100 rows, but fails at 50. However, since it's // running within a tx we expect none of the inserts to persist. err = db.QueryRow("SELECT COUNT(*) FROM foo").Scan(&count) - check.NoError(t, err) - check.Number(t, count, 0) + require.NoError(t, err) + require.Equal(t, count, 0) } diff --git a/tests/gomigrations/register/register_test.go b/tests/gomigrations/register/register_test.go index b79209a20..1d5e08d73 100644 --- a/tests/gomigrations/register/register_test.go +++ b/tests/gomigrations/register/register_test.go @@ -6,14 +6,14 @@ import ( "testing" "github.com/pressly/goose/v3" - "github.com/pressly/goose/v3/internal/check" _ "github.com/pressly/goose/v3/tests/gomigrations/register/testdata" + "github.com/stretchr/testify/require" ) func TestAddFunctions(t *testing.T) { goMigrations, err := goose.CollectMigrations("testdata", 0, math.MaxInt64) - check.NoError(t, err) - check.Number(t, len(goMigrations), 4) + require.NoError(t, err) + require.Equal(t, len(goMigrations), 4) checkMigration(t, goMigrations[0], &goose.Migration{ Version: 1, @@ -51,12 +51,12 @@ func TestAddFunctions(t *testing.T) { func checkMigration(t *testing.T, got *goose.Migration, want *goose.Migration) { t.Helper() - check.Equal(t, got.Version, want.Version) - check.Equal(t, got.Next, want.Next) - check.Equal(t, got.Previous, want.Previous) - check.Equal(t, filepath.Base(got.Source), want.Source) - check.Equal(t, got.Registered, want.Registered) - check.Bool(t, got.UseTx, want.UseTx) + require.Equal(t, got.Version, want.Version) + require.Equal(t, got.Next, want.Next) + require.Equal(t, got.Previous, want.Previous) + require.Equal(t, filepath.Base(got.Source), want.Source) + require.Equal(t, got.Registered, want.Registered) + require.Equal(t, got.UseTx, want.UseTx) checkFunctions(t, got) } @@ -65,48 +65,48 @@ func checkFunctions(t *testing.T, m *goose.Migration) { switch filepath.Base(m.Source) { case "001_addmigration.go": // With transaction - check.Bool(t, m.UpFn == nil, false) - check.Bool(t, m.DownFn == nil, false) - check.Bool(t, m.UpFnContext == nil, false) - check.Bool(t, m.DownFnContext == nil, false) + require.NotNil(t, m.UpFn) + require.NotNil(t, m.DownFn) + require.NotNil(t, m.UpFnContext) + require.NotNil(t, m.DownFnContext) // No transaction - check.Bool(t, m.UpFnNoTx == nil, true) - check.Bool(t, m.DownFnNoTx == nil, true) - check.Bool(t, m.UpFnNoTxContext == nil, true) - check.Bool(t, m.DownFnNoTxContext == nil, true) + require.Nil(t, m.UpFnNoTx) + require.Nil(t, m.DownFnNoTx) + require.Nil(t, m.UpFnNoTxContext) + require.Nil(t, m.DownFnNoTxContext) case "002_addmigrationnotx.go": // With transaction - check.Bool(t, m.UpFn == nil, true) - check.Bool(t, m.DownFn == nil, true) - check.Bool(t, m.UpFnContext == nil, true) - check.Bool(t, m.DownFnContext == nil, true) + require.Nil(t, m.UpFn) + require.Nil(t, m.DownFn) + require.Nil(t, m.UpFnContext) + require.Nil(t, m.DownFnContext) // No transaction - check.Bool(t, m.UpFnNoTx == nil, false) - check.Bool(t, m.DownFnNoTx == nil, false) - check.Bool(t, m.UpFnNoTxContext == nil, false) - check.Bool(t, m.DownFnNoTxContext == nil, false) + require.NotNil(t, m.UpFnNoTx) + require.NotNil(t, m.DownFnNoTx) + require.NotNil(t, m.UpFnNoTxContext) + require.NotNil(t, m.DownFnNoTxContext) case "003_addmigrationcontext.go": // With transaction - check.Bool(t, m.UpFn == nil, false) - check.Bool(t, m.DownFn == nil, false) - check.Bool(t, m.UpFnContext == nil, false) - check.Bool(t, m.DownFnContext == nil, false) + require.NotNil(t, m.UpFn) + require.NotNil(t, m.DownFn) + require.NotNil(t, m.UpFnContext) + require.NotNil(t, m.DownFnContext) // No transaction - check.Bool(t, m.UpFnNoTx == nil, true) - check.Bool(t, m.DownFnNoTx == nil, true) - check.Bool(t, m.UpFnNoTxContext == nil, true) - check.Bool(t, m.DownFnNoTxContext == nil, true) + require.Nil(t, m.UpFnNoTx) + require.Nil(t, m.DownFnNoTx) + require.Nil(t, m.UpFnNoTxContext) + require.Nil(t, m.DownFnNoTxContext) case "004_addmigrationnotxcontext.go": // With transaction - check.Bool(t, m.UpFn == nil, true) - check.Bool(t, m.DownFn == nil, true) - check.Bool(t, m.UpFnContext == nil, true) - check.Bool(t, m.DownFnContext == nil, true) + require.Nil(t, m.UpFn) + require.Nil(t, m.DownFn) + require.Nil(t, m.UpFnContext) + require.Nil(t, m.DownFnContext) // No transaction - check.Bool(t, m.UpFnNoTx == nil, false) - check.Bool(t, m.DownFnNoTx == nil, false) - check.Bool(t, m.UpFnNoTxContext == nil, false) - check.Bool(t, m.DownFnNoTxContext == nil, false) + require.NotNil(t, m.UpFnNoTx) + require.NotNil(t, m.DownFnNoTx) + require.NotNil(t, m.UpFnNoTxContext) + require.NotNil(t, m.DownFnNoTxContext) default: t.Fatalf("unexpected migration: %s", filepath.Base(m.Source)) } diff --git a/tests/gomigrations/success/gomigrations_success_test.go b/tests/gomigrations/success/gomigrations_success_test.go index 306efe29d..e88ec108d 100644 --- a/tests/gomigrations/success/gomigrations_success_test.go +++ b/tests/gomigrations/success/gomigrations_success_test.go @@ -6,7 +6,7 @@ import ( "testing" "github.com/pressly/goose/v3" - "github.com/pressly/goose/v3/internal/check" + "github.com/stretchr/testify/require" _ "github.com/pressly/goose/v3/tests/gomigrations/success/testdata" _ "modernc.org/sqlite" @@ -15,39 +15,39 @@ import ( func TestGoMigrationByOne(t *testing.T) { t.Parallel() - check.NoError(t, goose.SetDialect("sqlite3")) + require.NoError(t, goose.SetDialect("sqlite3")) db, err := sql.Open("sqlite", ":memory:") - check.NoError(t, err) + require.NoError(t, err) dir := "testdata" files, err := filepath.Glob(dir + "/*.go") - check.NoError(t, err) + require.NoError(t, err) upByOne := func(t *testing.T) int64 { err = goose.UpByOne(db, dir) t.Logf("err: %v %s", err, dir) - check.NoError(t, err) + require.NoError(t, err) version, err := goose.GetDBVersion(db) - check.NoError(t, err) + require.NoError(t, err) return version } downByOne := func(t *testing.T) int64 { err = goose.Down(db, dir) - check.NoError(t, err) + require.NoError(t, err) version, err := goose.GetDBVersion(db) - check.NoError(t, err) + require.NoError(t, err) return version } // Migrate all files up-by-one. for i := 1; i <= len(files); i++ { - check.Number(t, upByOne(t), i) + require.Equal(t, upByOne(t), i) } version, err := goose.GetDBVersion(db) - check.NoError(t, err) - check.Number(t, version, len(files)) + require.NoError(t, err) + require.Equal(t, version, len(files)) tables, err := ListTables(db) - check.NoError(t, err) - check.Equal(t, tables, []string{ + require.NoError(t, err) + require.Equal(t, tables, []string{ "alpha", "bravo", "charlie", @@ -62,15 +62,15 @@ func TestGoMigrationByOne(t *testing.T) { // Migrate all files down-by-one. for i := len(files) - 1; i >= 0; i-- { - check.Number(t, downByOne(t), i) + require.Equal(t, downByOne(t), i) } version, err = goose.GetDBVersion(db) - check.NoError(t, err) - check.Number(t, version, 0) + require.NoError(t, err) + require.Equal(t, version, 0) tables, err = ListTables(db) - check.NoError(t, err) - check.Equal(t, tables, []string{ + require.NoError(t, err) + require.Equal(t, tables, []string{ "goose_db_version", "sqlite_sequence", })