Skip to content

Commit

Permalink
feat: add tx methods to IDB (uptrace#587)
Browse files Browse the repository at this point in the history
  • Loading branch information
isgj committed Jun 28, 2022
1 parent e91543d commit 46eb321
Show file tree
Hide file tree
Showing 3 changed files with 180 additions and 0 deletions.
80 changes: 80 additions & 0 deletions db.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,9 @@ package bun

import (
"context"
"crypto/rand"
"database/sql"
"encoding/hex"
"fmt"
"reflect"
"strings"
Expand Down Expand Up @@ -431,6 +433,8 @@ func (db *DB) PrepareContext(ctx context.Context, query string) (Stmt, error) {
type Tx struct {
ctx context.Context
db *DB
// name is the name of a savepoint
name string
*sql.Tx
}

Expand Down Expand Up @@ -479,19 +483,45 @@ func (db *DB) BeginTx(ctx context.Context, opts *sql.TxOptions) (Tx, error) {
}

func (tx Tx) Commit() error {
if tx.name == "" {
return tx.commitTX()
}
return tx.commitSP()
}

func (tx Tx) commitTX() error {
ctx, event := tx.db.beforeQuery(tx.ctx, nil, "COMMIT", nil, "COMMIT", nil)
err := tx.Tx.Commit()
tx.db.afterQuery(ctx, event, nil, err)
return err
}

func (tx Tx) commitSP() error {
query := "RELEASE SAVEPOINT " + tx.name
_, err := tx.ExecContext(tx.ctx, query)
return err
}

func (tx Tx) Rollback() error {
if tx.name == "" {
return tx.rollbackTX()
}
return tx.rollbackSP()
}

func (tx Tx) rollbackTX() error {
ctx, event := tx.db.beforeQuery(tx.ctx, nil, "ROLLBACK", nil, "ROLLBACK", nil)
err := tx.Tx.Rollback()
tx.db.afterQuery(ctx, event, nil, err)
return err
}

func (tx Tx) rollbackSP() error {
query := "ROLLBACK TO SAVEPOINT " + tx.name
_, err := tx.ExecContext(tx.ctx, query)
return err
}

func (tx Tx) Exec(query string, args ...interface{}) (sql.Result, error) {
return tx.ExecContext(context.TODO(), query, args...)
}
Expand Down Expand Up @@ -534,6 +564,56 @@ func (tx Tx) QueryRowContext(ctx context.Context, query string, args ...interfac

//------------------------------------------------------------------------------

func (tx Tx) Begin() (Tx, error) {
return tx.BeginTx(context.Background(), nil)
}

// BeginTx will save a point in the running transaction.
func (tx Tx) BeginTx(ctx context.Context, _ *sql.TxOptions) (Tx, error) {
sp := make([]byte, 16)
_, err := rand.Read(sp)
if err != nil {
return Tx{}, err
}

qName := "SP_" + hex.EncodeToString(sp)
query := "SAVEPOINT " + qName
_, err = tx.ExecContext(ctx, query)
if err != nil {
return Tx{}, err
}
return Tx{
ctx: ctx,
db: tx.db,
Tx: tx.Tx,
name: qName,
}, nil
}

func (tx Tx) RunInTx(
ctx context.Context, _ *sql.TxOptions, fn func(ctx context.Context, tx Tx) error,
) error {
sp, err := tx.BeginTx(ctx, nil)
if err != nil {
return err
}

var done bool

defer func() {
if !done {
_ = sp.Rollback()
}
}()

if err := fn(ctx, sp); err != nil {
return err
}

done = true
return sp.Commit()
}

func (tx Tx) Dialect() schema.Dialect {
return tx.db.Dialect()
}
Expand Down
97 changes: 97 additions & 0 deletions internal/dbtest/db_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -271,6 +271,7 @@ func TestDB(t *testing.T) {
{testEmbedModelPointer},
{testJSONMarshaler},
{testNilDriverValue},
{testRunInTxAndSavepoint},
}

testEachDB(t, func(t *testing.T, dbName string, db *bun.DB) {
Expand Down Expand Up @@ -1401,3 +1402,99 @@ func testNilDriverValue(t *testing.T, db *bun.DB) {
_, err = db.NewInsert().Model(&Model{Value: &DriverValue{s: "hello"}}).Exec(ctx)
require.NoError(t, err)
}

func testRunInTxAndSavepoint(t *testing.T, db *bun.DB) {
type Counter struct {
Count int64
}

err := db.ResetModel(ctx, (*Counter)(nil))
require.NoError(t, err)

_, err = db.NewInsert().Model(&Counter{Count: 0}).Exec(ctx)
require.NoError(t, err)

err = db.RunInTx(ctx, nil, func(ctx context.Context, tx bun.Tx) error {
err := tx.RunInTx(ctx, nil, func(ctx context.Context, sp bun.Tx) error {
_, err := sp.NewUpdate().Model((*Counter)(nil)).
Set("count = count + 1").
Where("1 = 1").
Exec(ctx)
return err
})
require.NoError(t, err)
// rolling back the transaction should rollback what happened inside savepoint
return errors.New("fake error")
})
require.Error(t, err)

var count int
err = db.NewSelect().Model((*Counter)(nil)).Scan(ctx, &count)
require.NoError(t, err)
require.Equal(t, 0, count)

err = db.RunInTx(ctx, nil, func(ctx context.Context, tx bun.Tx) error {
err := tx.RunInTx(ctx, nil, func(ctx context.Context, sp bun.Tx) error {
_, err := sp.NewInsert().Model(&Counter{Count: 1}).
Exec(ctx)
require.NoError(t, err)
return err
})
require.NoError(t, err)

// ignored on purpose this error
// rolling back a savepoint should not affect the transaction
// nor other savepoints on the same level
_ = tx.RunInTx(ctx, nil, func(ctx context.Context, sp bun.Tx) error {
_, err := sp.NewInsert().Model(&Counter{Count: 2}).
Exec(ctx)
require.NoError(t, err)
return errors.New("fake error")
})

return err
})
require.NoError(t, err)

count, err = db.NewSelect().Model((*Counter)(nil)).Count(ctx)
require.NoError(t, err)
require.Equal(t, 2, count)

err = db.ResetModel(ctx, (*Counter)(nil))
require.NoError(t, err)

// happy path, commit transaction, savepoints and sub-savepoints
err = db.RunInTx(ctx, nil, func(ctx context.Context, tx bun.Tx) error {
_, err := tx.NewInsert().Model(&Counter{Count: 1}).
Exec(ctx)
require.NoError(t, err)

err = tx.RunInTx(ctx, nil, func(ctx context.Context, sp bun.Tx) error {
_, err := sp.NewInsert().Model(&Counter{Count: 1}).
Exec(ctx)
if err != nil {
return err
}

return sp.RunInTx(ctx, nil, func(ctx context.Context, subSp bun.Tx) error {
_, err := subSp.NewInsert().Model(&Counter{Count: 1}).
Exec(ctx)
return err
})
})
require.NoError(t, err)

err = tx.RunInTx(ctx, nil, func(ctx context.Context, sp bun.Tx) error {
_, err := sp.NewInsert().Model(&Counter{Count: 2}).
Exec(ctx)
return err
})

return err
})
require.NoError(t, err)

count, err = db.NewSelect().Model((*Counter)(nil)).Count(ctx)
require.NoError(t, err)
require.Equal(t, 4, count)
}
3 changes: 3 additions & 0 deletions query_base.go
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,9 @@ type IDB interface {
NewTruncateTable() *TruncateTableQuery
NewAddColumn() *AddColumnQuery
NewDropColumn() *DropColumnQuery

BeginTx(ctx context.Context, opts *sql.TxOptions) (Tx, error)
RunInTx(ctx context.Context, opts *sql.TxOptions, f func(ctx context.Context, tx Tx) error) error
}

var (
Expand Down

0 comments on commit 46eb321

Please sign in to comment.