Skip to content

Commit

Permalink
upgrades: add ExecForCountInTxns helper function for testing
Browse files Browse the repository at this point in the history
This patch adds a helper function called `ExecForCountInTxns`, which abstracts
out the logic to repeatedly run SQL statements on a database in transactions
of a specified size. It also refactors existing usages of this pattern to use
this new function.

Release note: None
  • Loading branch information
andyyang890 committed Mar 7, 2023
1 parent a95ffcd commit de2f842
Show file tree
Hide file tree
Showing 5 changed files with 42 additions and 65 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -80,27 +80,14 @@ func runTestDatabaseRoleSettingsUserIDMigration(t *testing.T, numUsers int) {
upgrades.InjectLegacyTable(ctx, t, s, systemschema.DatabaseRoleSettingsTable, getTableDescForDatabaseRoleSettingsTableBeforeRoleIDCol)

// Create test users.
tx, err := db.BeginTx(ctx, nil /* opts */)
require.NoError(t, err)
txRunner := sqlutils.MakeSQLRunner(tx)
for i := 0; i < numUsers; i++ {
// Group statements into transactions of 100 users to speed up creation.
if i != 0 && i%100 == 0 {
err := tx.Commit()
require.NoError(t, err)
tx, err = db.BeginTx(ctx, nil /* opts */)
require.NoError(t, err)
txRunner = sqlutils.MakeSQLRunner(tx)
}
upgrades.ExecForCountInTxns(ctx, t, db, numUsers, 100 /* txCount */, func(txRunner *sqlutils.SQLRunner, i int) {
txRunner.Exec(t, fmt.Sprintf("CREATE USER testuser%d", i))
txRunner.Exec(t, fmt.Sprintf(`ALTER USER testuser%d SET application_name = 'roach sql'`, i))
}
err = tx.Commit()
require.NoError(t, err)
})
tdb.CheckQueryResults(t, "SELECT count(*) FROM system.database_role_settings", [][]string{{strconv.Itoa(numUsers)}})

// Run migrations.
_, err = tc.Conns[0].ExecContext(ctx, `SET CLUSTER SETTING version = $1`,
_, err := tc.Conns[0].ExecContext(ctx, `SET CLUSTER SETTING version = $1`,
clusterversion.ByKey(clusterversion.V23_1DatabaseRoleSettingsHasRoleIDColumn).String())
require.NoError(t, err)
_, err = tc.Conns[0].ExecContext(ctx, `SET CLUSTER SETTING version = $1`,
Expand Down
29 changes: 29 additions & 0 deletions pkg/upgrade/upgrades/helpers_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ import (
"github.com/cockroachdb/cockroach/pkg/sql/catalog/systemschema"
"github.com/cockroachdb/cockroach/pkg/sql/catalog/tabledesc"
"github.com/cockroachdb/cockroach/pkg/testutils/serverutils"
"github.com/cockroachdb/cockroach/pkg/testutils/sqlutils"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
Expand Down Expand Up @@ -167,3 +168,31 @@ func GetTable(

// WaitForJobStatement is exported so that it can be detected by a testing knob.
const WaitForJobStatement = waitForJobStatement

// ExecForCountInTxns allows statements to be repeatedly run on a database
// in transactions of a specified size.
func ExecForCountInTxns(
ctx context.Context,
t *testing.T,
db *gosql.DB,
count int,
txCount int,
fn func(txRunner *sqlutils.SQLRunner, i int),
) {
tx, err := db.BeginTx(ctx, nil /* opts */)
require.NoError(t, err)
txRunner := sqlutils.MakeSQLRunner(tx)
// Group statements into transactions of txCount runs to speed up creation.
for i := 0; i < count; i++ {
if i != 0 && i%txCount == 0 {
err := tx.Commit()
require.NoError(t, err)
tx, err = db.BeginTx(ctx, nil /* opts */)
require.NoError(t, err)
txRunner = sqlutils.MakeSQLRunner(tx)
}
fn(txRunner, i)
}
err = tx.Commit()
require.NoError(t, err)
}
21 changes: 4 additions & 17 deletions pkg/upgrade/upgrades/role_members_ids_migration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -94,21 +94,10 @@ func runTestRoleMembersIDMigration(t *testing.T, numUsers int) {

// Create test users.
expectedNumRoleMembersRows := 1
tx, err := db.BeginTx(ctx, nil /* opts */)
require.NoError(t, err)
txRunner := sqlutils.MakeSQLRunner(tx)
for i := 0; i < numUsers; i++ {
// Group statements into transactions of 100 users to speed up creation.
if i != 0 && i%100 == 0 {
err := tx.Commit()
require.NoError(t, err)
tx, err = db.BeginTx(ctx, nil /* opts */)
require.NoError(t, err)
txRunner = sqlutils.MakeSQLRunner(tx)
}
upgrades.ExecForCountInTxns(ctx, t, db, numUsers, 100 /* txCount */, func(txRunner *sqlutils.SQLRunner, i int) {
txRunner.Exec(t, fmt.Sprintf("CREATE USER testuser%d", i))
if i == 0 {
continue
return
}
// Randomly choose an earlier test user to grant to the current test user.
grantStmt := fmt.Sprintf("GRANT testuser%d to testuser%d", rand.Intn(i), i)
Expand All @@ -117,15 +106,13 @@ func runTestRoleMembersIDMigration(t *testing.T, numUsers int) {
}
txRunner.Exec(t, grantStmt)
expectedNumRoleMembersRows += 1
}
err = tx.Commit()
require.NoError(t, err)
})
tdb.CheckQueryResults(t, "SELECT count(*) FROM system.role_members", [][]string{
{fmt.Sprintf("%d", expectedNumRoleMembersRows)},
})

// Run migrations.
_, err = tc.Conns[0].ExecContext(ctx, `SET CLUSTER SETTING version = $1`,
_, err := tc.Conns[0].ExecContext(ctx, `SET CLUSTER SETTING version = $1`,
clusterversion.ByKey(clusterversion.V23_1RoleMembersTableHasIDColumns).String())
require.NoError(t, err)
_, err = tc.Conns[0].ExecContext(ctx, `SET CLUSTER SETTING version = $1`,
Expand Down
19 changes: 3 additions & 16 deletions pkg/upgrade/upgrades/system_privileges_user_id_migration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -84,27 +84,14 @@ func runTestSystemPrivilegesUserIDMigration(t *testing.T, numUsers int) {
upgrades.InjectLegacyTable(ctx, t, s, systemschema.SystemPrivilegeTable, getTableDescForSystemPrivilegesTableBeforeUserIDCol)

// Create test users.
tx, err := db.BeginTx(ctx, nil /* opts */)
require.NoError(t, err)
txRunner := sqlutils.MakeSQLRunner(tx)
for i := 0; i < numUsers; i++ {
// Group statements into transactions of 100 users to speed up creation.
if i != 0 && i%100 == 0 {
err := tx.Commit()
require.NoError(t, err)
tx, err = db.BeginTx(ctx, nil /* opts */)
require.NoError(t, err)
txRunner = sqlutils.MakeSQLRunner(tx)
}
upgrades.ExecForCountInTxns(ctx, t, db, numUsers, 100 /* txCount */, func(txRunner *sqlutils.SQLRunner, i int) {
txRunner.Exec(t, fmt.Sprintf("CREATE USER testuser%d", i))
txRunner.Exec(t, fmt.Sprintf("GRANT SYSTEM MODIFYCLUSTERSETTING TO testuser%d", i))
}
err = tx.Commit()
require.NoError(t, err)
})
tdb.CheckQueryResults(t, "SELECT count(*) FROM system.privileges", [][]string{{strconv.Itoa(numUsers)}})

// Run migrations.
_, err = tc.Conns[0].ExecContext(ctx, `SET CLUSTER SETTING version = $1`,
_, err := tc.Conns[0].ExecContext(ctx, `SET CLUSTER SETTING version = $1`,
clusterversion.ByKey(clusterversion.V23_1SystemPrivilegesTableHasUserIDColumn).String())
require.NoError(t, err)
_, err = tc.Conns[0].ExecContext(ctx, `SET CLUSTER SETTING version = $1`,
Expand Down
19 changes: 3 additions & 16 deletions pkg/upgrade/upgrades/web_sessions_table_user_id_migration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -84,18 +84,7 @@ func runTestWebSessionsUserIDMigration(t *testing.T, numUsers int) {
upgrades.InjectLegacyTable(ctx, t, s, systemschema.WebSessionsTable, getTableDescForSystemWebSessionsTableBeforeUserIDCol)

// Create test users.
tx, err := db.BeginTx(ctx, nil /* opts */)
require.NoError(t, err)
txRunner := sqlutils.MakeSQLRunner(tx)
for i := 0; i < numUsers; i++ {
// Group statements into transactions of 100 users to speed up creation.
if i != 0 && i%100 == 0 {
err := tx.Commit()
require.NoError(t, err)
tx, err = db.BeginTx(ctx, nil /* opts */)
require.NoError(t, err)
txRunner = sqlutils.MakeSQLRunner(tx)
}
upgrades.ExecForCountInTxns(ctx, t, db, numUsers, 100 /* txCount */, func(txRunner *sqlutils.SQLRunner, i int) {
txRunner.Exec(t, fmt.Sprintf("CREATE USER testuser%d", i))

// Simulate the INSERT that happens in the actual authentication code.
Expand All @@ -109,13 +98,11 @@ VALUES (
'2023-02-14 20:56:30.699447'
)
`, i))
}
err = tx.Commit()
require.NoError(t, err)
})
tdb.CheckQueryResults(t, "SELECT count(*) FROM system.web_sessions", [][]string{{strconv.Itoa(numUsers)}})

// Run migrations.
_, err = tc.Conns[0].ExecContext(ctx, `SET CLUSTER SETTING version = $1`,
_, err := tc.Conns[0].ExecContext(ctx, `SET CLUSTER SETTING version = $1`,
clusterversion.ByKey(clusterversion.V23_1WebSessionsTableHasUserIDColumn).String())
require.NoError(t, err)
_, err = tc.Conns[0].ExecContext(ctx, `SET CLUSTER SETTING version = $1`,
Expand Down

0 comments on commit de2f842

Please sign in to comment.