From f45166a180d7cb6ba85c90354843251f158de6fa Mon Sep 17 00:00:00 2001 From: Daniel Nephin Date: Thu, 22 Sep 2022 17:51:06 -0400 Subject: [PATCH] improve: add TxOptions to txn.Begin This will allow us to start read only transactions, or transactions with different isolation levels. --- internal/access/access_test.go | 2 +- internal/server/config.go | 2 +- internal/server/data/data.go | 9 ++++----- internal/server/data/data_test.go | 6 +++--- internal/server/data/migrations_test.go | 6 +++--- internal/server/handlers_test.go | 2 +- internal/server/middleware.go | 2 +- internal/server/middleware_test.go | 6 +++--- internal/server/routes.go | 8 +++++--- 9 files changed, 22 insertions(+), 21 deletions(-) diff --git a/internal/access/access_test.go b/internal/access/access_test.go index fb06a2dac4..ca51120827 100644 --- a/internal/access/access_test.go +++ b/internal/access/access_test.go @@ -70,7 +70,7 @@ func setupAccessTestContext(t *testing.T) (*gin.Context, *data.Transaction, *mod func txnForTestCase(t *testing.T, db *data.DB) *data.Transaction { t.Helper() - tx, err := db.Begin(context.Background()) + tx, err := db.Begin(context.Background(), nil) assert.NilError(t, err) t.Cleanup(func() { assert.NilError(t, tx.Rollback()) diff --git a/internal/server/config.go b/internal/server/config.go index d8bf8f1f43..a1e9419f52 100644 --- a/internal/server/config.go +++ b/internal/server/config.go @@ -606,7 +606,7 @@ func (s Server) loadConfig(config Config) error { org := s.db.DefaultOrg - tx, err := s.db.Begin(context.Background()) + tx, err := s.db.Begin(context.Background(), nil) if err != nil { return err } diff --git a/internal/server/data/data.go b/internal/server/data/data.go index fe60915f86..2485e25331 100644 --- a/internal/server/data/data.go +++ b/internal/server/data/data.go @@ -41,7 +41,7 @@ func NewDB(connection gorm.Dialector, dbOpts NewDBOptions) (*DB, error) { return nil, fmt.Errorf("db conn: %w", err) } dataDB := &DB{DB: db} - tx, err := dataDB.Begin(context.TODO()) + tx, err := dataDB.Begin(context.TODO(), nil) if err != nil { return nil, err } @@ -126,9 +126,8 @@ func (d *DB) GormDB() *gorm.DB { return d.DB } -// TODO: accept sql.TxOptions when we remove gorm -func (d *DB) Begin(ctx context.Context) (*Transaction, error) { - tx := d.DB.WithContext(ctx).Begin() +func (d *DB) Begin(ctx context.Context, opts *sql.TxOptions) (*Transaction, error) { + tx := d.DB.WithContext(ctx).Begin(opts) if err := tx.Error; err != nil { return nil, err } @@ -239,7 +238,7 @@ func newRawDB(connection gorm.Dialector, options NewDBOptions) (*gorm.DB, error) const defaultOrganizationID = 1000 func initialize(db *DB) error { - tx, err := db.Begin(context.TODO()) + tx, err := db.Begin(context.TODO(), nil) if err != nil { return err } diff --git a/internal/server/data/data_test.go b/internal/server/data/data_test.go index abf565ecf4..688975c374 100644 --- a/internal/server/data/data_test.go +++ b/internal/server/data/data_test.go @@ -30,7 +30,7 @@ func setupDB(t *testing.T, driver gorm.Dialector) *DB { func txnForTestCase(t *testing.T, db *DB, orgID uid.ID) *Transaction { t.Helper() - tx, err := db.Begin(context.Background()) + tx, err := db.Begin(context.Background(), nil) assert.NilError(t, err) t.Cleanup(func() { _ = tx.Rollback() @@ -173,7 +173,7 @@ func TestDB_Begin(t *testing.T) { runDBTests(t, func(t *testing.T, db *DB) { t.Run("rollback", func(t *testing.T) { ctx := context.Background() - tx, err := db.Begin(ctx) + tx, err := db.Begin(ctx, nil) assert.NilError(t, err) tx = tx.WithOrgID(db.DefaultOrg.ID) @@ -193,7 +193,7 @@ func TestDB_Begin(t *testing.T) { }) t.Run("commit", func(t *testing.T) { ctx := context.Background() - tx, err := db.Begin(ctx) + tx, err := db.Begin(ctx, nil) assert.NilError(t, err) tx = tx.WithOrgID(db.DefaultOrg.ID) diff --git a/internal/server/data/migrations_test.go b/internal/server/data/migrations_test.go index 31fbfe8776..c21fb7ad7b 100644 --- a/internal/server/data/migrations_test.go +++ b/internal/server/data/migrations_test.go @@ -77,7 +77,7 @@ func TestMigrations(t *testing.T) { }, } - tx, err := db.Begin(context.Background()) + tx, err := db.Begin(context.Background(), nil) assert.NilError(t, err) defer tx.Rollback() @@ -87,7 +87,7 @@ func TestMigrations(t *testing.T) { assert.NilError(t, tx.Commit()) t.Run("run again to check idempotency", func(t *testing.T) { - tx, err := db.Begin(context.Background()) + tx, err := db.Begin(context.Background(), nil) assert.NilError(t, err) defer tx.Rollback() @@ -96,7 +96,7 @@ func TestMigrations(t *testing.T) { assert.NilError(t, tx.Commit()) }) - tx, err = db.Begin(context.Background()) + tx, err = db.Begin(context.Background(), nil) assert.NilError(t, err) defer tx.Rollback() tc.expected(t, tx) diff --git a/internal/server/handlers_test.go b/internal/server/handlers_test.go index 5a86c45a19..e1d3bca91d 100644 --- a/internal/server/handlers_test.go +++ b/internal/server/handlers_test.go @@ -94,7 +94,7 @@ func loginAs(tx *data.Transaction, user *models.Identity) *gin.Context { func txnForTestCase(t *testing.T, db *data.DB) *data.Transaction { t.Helper() - tx, err := db.Begin(context.Background()) + tx, err := db.Begin(context.Background(), nil) assert.NilError(t, err) t.Cleanup(func() { assert.NilError(t, tx.Rollback()) diff --git a/internal/server/middleware.go b/internal/server/middleware.go index 66d2fced97..93825a1cbb 100644 --- a/internal/server/middleware.go +++ b/internal/server/middleware.go @@ -74,7 +74,7 @@ func handleInfraDestinationHeader(rCtx access.RequestContext, uniqueID string) e // If the request identifies an organization (which is required for most routes) // a rate limit will be applied to all requests from the same organization. func authenticateRequest(c *gin.Context, route routeSettings, srv *Server) (access.Authenticated, error) { - tx, err := srv.db.Begin(c.Request.Context()) + tx, err := srv.db.Begin(c.Request.Context(), nil) if err != nil { return access.Authenticated{}, err } diff --git a/internal/server/middleware_test.go b/internal/server/middleware_test.go index 4d11b313e8..f9183a7fe0 100644 --- a/internal/server/middleware_test.go +++ b/internal/server/middleware_test.go @@ -99,7 +99,7 @@ func TestDBTimeout(t *testing.T) { c.Request = c.Request.WithContext(ctx) - tx, err := srv.db.Begin(c.Request.Context()) + tx, err := srv.db.Begin(c.Request.Context(), nil) if err != nil { sendAPIError(c, err) return @@ -438,7 +438,7 @@ func TestAuthenticateRequest(t *testing.T) { } createOrgs(t, srv.db, otherOrg, org) - tx, err := srv.db.Begin(context.Background()) + tx, err := srv.db.Begin(context.Background(), nil) assert.NilError(t, err) tx = tx.WithOrgID(org.ID) @@ -567,7 +567,7 @@ func TestValidateRequestOrganization(t *testing.T) { } createOrgs(t, srv.db, otherOrg, org) - tx, err := srv.db.Begin(context.Background()) + tx, err := srv.db.Begin(context.Background(), nil) assert.NilError(t, err) tx = tx.WithOrgID(org.ID) diff --git a/internal/server/routes.go b/internal/server/routes.go index efdba48906..4031bcbb85 100644 --- a/internal/server/routes.go +++ b/internal/server/routes.go @@ -205,7 +205,7 @@ func wrapRoute[Req, Res any](a *API, routeID routeIdentifier, route route[Req, R return err } - tx, err := a.server.db.Begin(c.Request.Context()) + tx, err := a.server.db.Begin(c.Request.Context(), nil) if err != nil { return err } @@ -291,8 +291,10 @@ func responseStatusCode(method string, resp any) int { func get[Req, Res any](a *API, r *routeGroup, path string, handler HandlerFunc[Req, Res]) { add(a, r, http.MethodGet, path, route[Req, Res]{ - handler: handler, - routeSettings: routeSettings{omitFromTelemetry: true}, + handler: handler, + routeSettings: routeSettings{ + omitFromTelemetry: true, + }, }) }