diff --git a/db.go b/db.go index 110df583e..47e654655 100644 --- a/db.go +++ b/db.go @@ -2,7 +2,9 @@ package bun import ( "context" + "crypto/rand" "database/sql" + "encoding/hex" "fmt" "reflect" "strings" @@ -431,6 +433,8 @@ func (db *DB) PrepareContext(ctx context.Context, query string) (Stmt, error) { type Tx struct { ctx context.Context db *DB + // name is the name of a savepoint + name string *sql.Tx } @@ -479,19 +483,51 @@ func (db *DB) BeginTx(ctx context.Context, opts *sql.TxOptions) (Tx, error) { } func (tx Tx) Commit() error { + if tx.name == "" { + return tx.commitTX() + } + return tx.commitSP() +} + +func (tx Tx) commitTX() error { ctx, event := tx.db.beforeQuery(tx.ctx, nil, "COMMIT", nil, "COMMIT", nil) err := tx.Tx.Commit() tx.db.afterQuery(ctx, event, nil, err) return err } +func (tx Tx) commitSP() error { + if tx.Dialect().Features().Has(feature.MSSavepoint) { + return nil + } + query := "RELEASE SAVEPOINT " + tx.name + _, err := tx.ExecContext(tx.ctx, query) + return err +} + func (tx Tx) Rollback() error { + if tx.name == "" { + return tx.rollbackTX() + } + return tx.rollbackSP() +} + +func (tx Tx) rollbackTX() error { ctx, event := tx.db.beforeQuery(tx.ctx, nil, "ROLLBACK", nil, "ROLLBACK", nil) err := tx.Tx.Rollback() tx.db.afterQuery(ctx, event, nil, err) return err } +func (tx Tx) rollbackSP() error { + query := "ROLLBACK TO SAVEPOINT " + tx.name + if tx.Dialect().Features().Has(feature.MSSavepoint) { + query = "ROLLBACK TRANSACTION " + tx.name + } + _, err := tx.ExecContext(tx.ctx, query) + return err +} + func (tx Tx) Exec(query string, args ...interface{}) (sql.Result, error) { return tx.ExecContext(context.TODO(), query, args...) } @@ -534,6 +570,60 @@ func (tx Tx) QueryRowContext(ctx context.Context, query string, args ...interfac //------------------------------------------------------------------------------ +func (tx Tx) Begin() (Tx, error) { + return tx.BeginTx(tx.ctx, nil) +} + +// BeginTx will save a point in the running transaction. +func (tx Tx) BeginTx(ctx context.Context, _ *sql.TxOptions) (Tx, error) { + // mssql savepoint names are limited to 32 characters + sp := make([]byte, 14) + _, err := rand.Read(sp) + if err != nil { + return Tx{}, err + } + + qName := "SP_" + hex.EncodeToString(sp) + query := "SAVEPOINT " + qName + if tx.Dialect().Features().Has(feature.MSSavepoint) { + query = "SAVE TRANSACTION " + qName + } + _, err = tx.ExecContext(ctx, query) + if err != nil { + return Tx{}, err + } + return Tx{ + ctx: ctx, + db: tx.db, + Tx: tx.Tx, + name: qName, + }, nil +} + +func (tx Tx) RunInTx( + ctx context.Context, _ *sql.TxOptions, fn func(ctx context.Context, tx Tx) error, +) error { + sp, err := tx.BeginTx(ctx, nil) + if err != nil { + return err + } + + var done bool + + defer func() { + if !done { + _ = sp.Rollback() + } + }() + + if err := fn(ctx, sp); err != nil { + return err + } + + done = true + return sp.Commit() +} + func (tx Tx) Dialect() schema.Dialect { return tx.db.Dialect() } diff --git a/dialect/feature/feature.go b/dialect/feature/feature.go index 510d6e5de..a2bba2c47 100644 --- a/dialect/feature/feature.go +++ b/dialect/feature/feature.go @@ -29,4 +29,5 @@ const ( OffsetFetch SelectExists UpdateFromTable + MSSavepoint ) diff --git a/dialect/mssqldialect/dialect.go b/dialect/mssqldialect/dialect.go index 68a96342d..f1b05e415 100755 --- a/dialect/mssqldialect/dialect.go +++ b/dialect/mssqldialect/dialect.go @@ -46,7 +46,8 @@ func New() *Dialect { feature.Identity | feature.Output | feature.OffsetFetch | - feature.UpdateFromTable + feature.UpdateFromTable | + feature.MSSavepoint return d } diff --git a/internal/dbtest/db_test.go b/internal/dbtest/db_test.go index 4e620f119..2f60f116c 100644 --- a/internal/dbtest/db_test.go +++ b/internal/dbtest/db_test.go @@ -271,6 +271,7 @@ func TestDB(t *testing.T) { {testEmbedModelPointer}, {testJSONMarshaler}, {testNilDriverValue}, + {testRunInTxAndSavepoint}, } testEachDB(t, func(t *testing.T, dbName string, db *bun.DB) { @@ -1401,3 +1402,99 @@ func testNilDriverValue(t *testing.T, db *bun.DB) { _, err = db.NewInsert().Model(&Model{Value: &DriverValue{s: "hello"}}).Exec(ctx) require.NoError(t, err) } + +func testRunInTxAndSavepoint(t *testing.T, db *bun.DB) { + type Counter struct { + Count int64 + } + + err := db.ResetModel(ctx, (*Counter)(nil)) + require.NoError(t, err) + + _, err = db.NewInsert().Model(&Counter{Count: 0}).Exec(ctx) + require.NoError(t, err) + + err = db.RunInTx(ctx, nil, func(ctx context.Context, tx bun.Tx) error { + err := tx.RunInTx(ctx, nil, func(ctx context.Context, sp bun.Tx) error { + _, err := sp.NewUpdate().Model((*Counter)(nil)). + Set("count = count + 1"). + Where("1 = 1"). + Exec(ctx) + return err + }) + require.NoError(t, err) + // rolling back the transaction should rollback what happened inside savepoint + return errors.New("fake error") + }) + require.Error(t, err) + + var count int + err = db.NewSelect().Model((*Counter)(nil)).Scan(ctx, &count) + require.NoError(t, err) + require.Equal(t, 0, count) + + err = db.RunInTx(ctx, nil, func(ctx context.Context, tx bun.Tx) error { + err := tx.RunInTx(ctx, nil, func(ctx context.Context, sp bun.Tx) error { + _, err := sp.NewInsert().Model(&Counter{Count: 1}). + Exec(ctx) + require.NoError(t, err) + return err + }) + require.NoError(t, err) + + // ignored on purpose this error + // rolling back a savepoint should not affect the transaction + // nor other savepoints on the same level + _ = tx.RunInTx(ctx, nil, func(ctx context.Context, sp bun.Tx) error { + _, err := sp.NewInsert().Model(&Counter{Count: 2}). + Exec(ctx) + require.NoError(t, err) + return errors.New("fake error") + }) + + return err + }) + require.NoError(t, err) + + count, err = db.NewSelect().Model((*Counter)(nil)).Count(ctx) + require.NoError(t, err) + require.Equal(t, 2, count) + + err = db.ResetModel(ctx, (*Counter)(nil)) + require.NoError(t, err) + + // happy path, commit transaction, savepoints and sub-savepoints + err = db.RunInTx(ctx, nil, func(ctx context.Context, tx bun.Tx) error { + _, err := tx.NewInsert().Model(&Counter{Count: 1}). + Exec(ctx) + require.NoError(t, err) + + err = tx.RunInTx(ctx, nil, func(ctx context.Context, sp bun.Tx) error { + _, err := sp.NewInsert().Model(&Counter{Count: 1}). + Exec(ctx) + if err != nil { + return err + } + + return sp.RunInTx(ctx, nil, func(ctx context.Context, subSp bun.Tx) error { + _, err := subSp.NewInsert().Model(&Counter{Count: 1}). + Exec(ctx) + return err + }) + }) + require.NoError(t, err) + + err = tx.RunInTx(ctx, nil, func(ctx context.Context, sp bun.Tx) error { + _, err := sp.NewInsert().Model(&Counter{Count: 2}). + Exec(ctx) + return err + }) + + return err + }) + require.NoError(t, err) + + count, err = db.NewSelect().Model((*Counter)(nil)).Count(ctx) + require.NoError(t, err) + require.Equal(t, 4, count) +} diff --git a/query_base.go b/query_base.go index ea0ffdd45..202e8ec2c 100644 --- a/query_base.go +++ b/query_base.go @@ -57,6 +57,9 @@ type IDB interface { NewTruncateTable() *TruncateTableQuery NewAddColumn() *AddColumnQuery NewDropColumn() *DropColumnQuery + + BeginTx(ctx context.Context, opts *sql.TxOptions) (Tx, error) + RunInTx(ctx context.Context, opts *sql.TxOptions, f func(ctx context.Context, tx Tx) error) error } var (