From 30e85b5366b2e51951ef17a0cf362b58f708dab1 Mon Sep 17 00:00:00 2001 From: Vladimir Mihailenco Date: Sun, 7 Nov 2021 14:50:11 +0200 Subject: [PATCH] fix: call query hook when tx is started, committed, or rolled back --- README.md | 4 ++-- db.go | 24 +++++++++++++++++++++--- extra/bundebug/debug.go | 2 +- extra/bunotel/otel.go | 2 +- 4 files changed, 25 insertions(+), 7 deletions(-) diff --git a/README.md b/README.md index 8f91087c5..f0107c2ad 100644 --- a/README.md +++ b/README.md @@ -94,9 +94,9 @@ Projects using Bun: -## But why? +## Why another database client? -So you can write queries like this: +So you can elegantly write complex queries: ```go regionalSales := db.NewSelect(). diff --git a/db.go b/db.go index 72fe118a9..2d7a20f90 100644 --- a/db.go +++ b/db.go @@ -356,7 +356,8 @@ func (db *DB) PrepareContext(ctx context.Context, query string) (Stmt, error) { //------------------------------------------------------------------------------ type Tx struct { - db *DB + ctx context.Context + db *DB *sql.Tx } @@ -382,16 +383,33 @@ func (db *DB) Begin() (Tx, error) { } func (db *DB) BeginTx(ctx context.Context, opts *sql.TxOptions) (Tx, error) { + ctx, event := db.beforeQuery(ctx, nil, "BEGIN", nil, nil) tx, err := db.DB.BeginTx(ctx, opts) + db.afterQuery(ctx, event, nil, err) if err != nil { return Tx{}, err } return Tx{ - db: db, - Tx: tx, + ctx: ctx, + db: db, + Tx: tx, }, nil } +func (tx Tx) Commit() error { + ctx, event := tx.db.beforeQuery(tx.ctx, nil, "COMMIT", nil, nil) + err := tx.Tx.Commit() + tx.db.afterQuery(ctx, event, nil, err) + return err +} + +func (tx Tx) Rollback() error { + ctx, event := tx.db.beforeQuery(tx.ctx, nil, "ROLLBACK", nil, nil) + err := tx.Tx.Rollback() + tx.db.afterQuery(ctx, event, nil, err) + return err +} + func (tx Tx) Exec(query string, args ...interface{}) (sql.Result, error) { return tx.ExecContext(context.TODO(), query, args...) } diff --git a/extra/bundebug/debug.go b/extra/bundebug/debug.go index 3e5431e2c..11a0c2118 100644 --- a/extra/bundebug/debug.go +++ b/extra/bundebug/debug.go @@ -77,7 +77,7 @@ func (h *QueryHook) AfterQuery(ctx context.Context, event *bun.QueryEvent) { if !h.verbose { switch event.Err { - case nil, sql.ErrNoRows: + case nil, sql.ErrNoRows, sql.ErrTxDone: return } } diff --git a/extra/bunotel/otel.go b/extra/bunotel/otel.go index 774a66861..d09e01dc3 100644 --- a/extra/bunotel/otel.go +++ b/extra/bunotel/otel.go @@ -104,7 +104,7 @@ func (h *QueryHook) AfterQuery(ctx context.Context, event *bun.QueryEvent) { } switch event.Err { - case nil, sql.ErrNoRows: + case nil, sql.ErrNoRows, sql.ErrTxDone: // ignore default: span.RecordError(event.Err)