Skip to content

Commit

Permalink
Merge pull request #3327 from infrahq/dnephin/sql-tx-options
Browse files Browse the repository at this point in the history
Add TxOptions to txn.Begin
  • Loading branch information
dnephin authored Sep 27, 2022
2 parents c96ddb4 + f45166a commit 27804f6
Show file tree
Hide file tree
Showing 9 changed files with 22 additions and 21 deletions.
2 changes: 1 addition & 1 deletion internal/access/access_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ func setupAccessTestContext(t *testing.T) (*gin.Context, *data.Transaction, *mod

func txnForTestCase(t *testing.T, db *data.DB) *data.Transaction {
t.Helper()
tx, err := db.Begin(context.Background())
tx, err := db.Begin(context.Background(), nil)
assert.NilError(t, err)
t.Cleanup(func() {
assert.NilError(t, tx.Rollback())
Expand Down
2 changes: 1 addition & 1 deletion internal/server/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -606,7 +606,7 @@ func (s Server) loadConfig(config Config) error {

org := s.db.DefaultOrg

tx, err := s.db.Begin(context.Background())
tx, err := s.db.Begin(context.Background(), nil)
if err != nil {
return err
}
Expand Down
9 changes: 4 additions & 5 deletions internal/server/data/data.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ func NewDB(connection gorm.Dialector, dbOpts NewDBOptions) (*DB, error) {
return nil, fmt.Errorf("db conn: %w", err)
}
dataDB := &DB{DB: db}
tx, err := dataDB.Begin(context.TODO())
tx, err := dataDB.Begin(context.TODO(), nil)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -126,9 +126,8 @@ func (d *DB) GormDB() *gorm.DB {
return d.DB
}

// TODO: accept sql.TxOptions when we remove gorm
func (d *DB) Begin(ctx context.Context) (*Transaction, error) {
tx := d.DB.WithContext(ctx).Begin()
func (d *DB) Begin(ctx context.Context, opts *sql.TxOptions) (*Transaction, error) {
tx := d.DB.WithContext(ctx).Begin(opts)
if err := tx.Error; err != nil {
return nil, err
}
Expand Down Expand Up @@ -239,7 +238,7 @@ func newRawDB(connection gorm.Dialector, options NewDBOptions) (*gorm.DB, error)
const defaultOrganizationID = 1000

func initialize(db *DB) error {
tx, err := db.Begin(context.TODO())
tx, err := db.Begin(context.TODO(), nil)
if err != nil {
return err
}
Expand Down
6 changes: 3 additions & 3 deletions internal/server/data/data_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ func setupDB(t *testing.T, driver gorm.Dialector) *DB {

func txnForTestCase(t *testing.T, db *DB, orgID uid.ID) *Transaction {
t.Helper()
tx, err := db.Begin(context.Background())
tx, err := db.Begin(context.Background(), nil)
assert.NilError(t, err)
t.Cleanup(func() {
_ = tx.Rollback()
Expand Down Expand Up @@ -173,7 +173,7 @@ func TestDB_Begin(t *testing.T) {
runDBTests(t, func(t *testing.T, db *DB) {
t.Run("rollback", func(t *testing.T) {
ctx := context.Background()
tx, err := db.Begin(ctx)
tx, err := db.Begin(ctx, nil)
assert.NilError(t, err)
tx = tx.WithOrgID(db.DefaultOrg.ID)

Expand All @@ -193,7 +193,7 @@ func TestDB_Begin(t *testing.T) {
})
t.Run("commit", func(t *testing.T) {
ctx := context.Background()
tx, err := db.Begin(ctx)
tx, err := db.Begin(ctx, nil)
assert.NilError(t, err)
tx = tx.WithOrgID(db.DefaultOrg.ID)

Expand Down
6 changes: 3 additions & 3 deletions internal/server/data/migrations_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ func TestMigrations(t *testing.T) {
},
}

tx, err := db.Begin(context.Background())
tx, err := db.Begin(context.Background(), nil)
assert.NilError(t, err)
defer tx.Rollback()

Expand All @@ -87,7 +87,7 @@ func TestMigrations(t *testing.T) {
assert.NilError(t, tx.Commit())

t.Run("run again to check idempotency", func(t *testing.T) {
tx, err := db.Begin(context.Background())
tx, err := db.Begin(context.Background(), nil)
assert.NilError(t, err)
defer tx.Rollback()

Expand All @@ -96,7 +96,7 @@ func TestMigrations(t *testing.T) {
assert.NilError(t, tx.Commit())
})

tx, err = db.Begin(context.Background())
tx, err = db.Begin(context.Background(), nil)
assert.NilError(t, err)
defer tx.Rollback()
tc.expected(t, tx)
Expand Down
2 changes: 1 addition & 1 deletion internal/server/handlers_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ func loginAs(tx *data.Transaction, user *models.Identity) *gin.Context {

func txnForTestCase(t *testing.T, db *data.DB) *data.Transaction {
t.Helper()
tx, err := db.Begin(context.Background())
tx, err := db.Begin(context.Background(), nil)
assert.NilError(t, err)
t.Cleanup(func() {
assert.NilError(t, tx.Rollback())
Expand Down
2 changes: 1 addition & 1 deletion internal/server/middleware.go
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ func handleInfraDestinationHeader(rCtx access.RequestContext, uniqueID string) e
// If the request identifies an organization (which is required for most routes)
// a rate limit will be applied to all requests from the same organization.
func authenticateRequest(c *gin.Context, route routeSettings, srv *Server) (access.Authenticated, error) {
tx, err := srv.db.Begin(c.Request.Context())
tx, err := srv.db.Begin(c.Request.Context(), nil)
if err != nil {
return access.Authenticated{}, err
}
Expand Down
6 changes: 3 additions & 3 deletions internal/server/middleware_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ func TestDBTimeout(t *testing.T) {

c.Request = c.Request.WithContext(ctx)

tx, err := srv.db.Begin(c.Request.Context())
tx, err := srv.db.Begin(c.Request.Context(), nil)
if err != nil {
sendAPIError(c, err)
return
Expand Down Expand Up @@ -438,7 +438,7 @@ func TestAuthenticateRequest(t *testing.T) {
}
createOrgs(t, srv.db, otherOrg, org)

tx, err := srv.db.Begin(context.Background())
tx, err := srv.db.Begin(context.Background(), nil)
assert.NilError(t, err)
tx = tx.WithOrgID(org.ID)

Expand Down Expand Up @@ -567,7 +567,7 @@ func TestValidateRequestOrganization(t *testing.T) {
}
createOrgs(t, srv.db, otherOrg, org)

tx, err := srv.db.Begin(context.Background())
tx, err := srv.db.Begin(context.Background(), nil)
assert.NilError(t, err)
tx = tx.WithOrgID(org.ID)

Expand Down
8 changes: 5 additions & 3 deletions internal/server/routes.go
Original file line number Diff line number Diff line change
Expand Up @@ -205,7 +205,7 @@ func wrapRoute[Req, Res any](a *API, routeID routeIdentifier, route route[Req, R
return err
}

tx, err := a.server.db.Begin(c.Request.Context())
tx, err := a.server.db.Begin(c.Request.Context(), nil)
if err != nil {
return err
}
Expand Down Expand Up @@ -291,8 +291,10 @@ func responseStatusCode(method string, resp any) int {

func get[Req, Res any](a *API, r *routeGroup, path string, handler HandlerFunc[Req, Res]) {
add(a, r, http.MethodGet, path, route[Req, Res]{
handler: handler,
routeSettings: routeSettings{omitFromTelemetry: true},
handler: handler,
routeSettings: routeSettings{
omitFromTelemetry: true,
},
})
}

Expand Down

0 comments on commit 27804f6

Please sign in to comment.