Skip to content

Commit

Permalink
feat: add QueryEvent.Model
Browse files Browse the repository at this point in the history
  • Loading branch information
vmihailenco committed Sep 27, 2021
1 parent 8fdc713 commit 7688201
Show file tree
Hide file tree
Showing 5 changed files with 17 additions and 13 deletions.
1 change: 1 addition & 0 deletions .github/workflows/build.yml
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ jobs:
name: build
runs-on: ubuntu-latest
strategy:
fail-fast: false
matrix:
go-version: [1.16.x, 1.17.x]

Expand Down
18 changes: 9 additions & 9 deletions db.go
Original file line number Diff line number Diff line change
Expand Up @@ -203,7 +203,7 @@ func (db *DB) Exec(query string, args ...interface{}) (sql.Result, error) {
func (db *DB) ExecContext(
ctx context.Context, query string, args ...interface{},
) (sql.Result, error) {
ctx, event := db.beforeQuery(ctx, nil, query, args)
ctx, event := db.beforeQuery(ctx, nil, query, args, nil)
res, err := db.DB.ExecContext(ctx, db.format(query, args))
db.afterQuery(ctx, event, res, err)
return res, err
Expand All @@ -216,7 +216,7 @@ func (db *DB) Query(query string, args ...interface{}) (*sql.Rows, error) {
func (db *DB) QueryContext(
ctx context.Context, query string, args ...interface{},
) (*sql.Rows, error) {
ctx, event := db.beforeQuery(ctx, nil, query, args)
ctx, event := db.beforeQuery(ctx, nil, query, args, nil)
rows, err := db.DB.QueryContext(ctx, db.format(query, args))
db.afterQuery(ctx, event, nil, err)
return rows, err
Expand All @@ -227,7 +227,7 @@ func (db *DB) QueryRow(query string, args ...interface{}) *sql.Row {
}

func (db *DB) QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row {
ctx, event := db.beforeQuery(ctx, nil, query, args)
ctx, event := db.beforeQuery(ctx, nil, query, args, nil)
row := db.DB.QueryRowContext(ctx, db.format(query, args))
db.afterQuery(ctx, event, nil, row.Err())
return row
Expand Down Expand Up @@ -258,7 +258,7 @@ func (db *DB) Conn(ctx context.Context) (Conn, error) {
func (c Conn) ExecContext(
ctx context.Context, query string, args ...interface{},
) (sql.Result, error) {
ctx, event := c.db.beforeQuery(ctx, nil, query, args)
ctx, event := c.db.beforeQuery(ctx, nil, query, args, nil)
res, err := c.Conn.ExecContext(ctx, c.db.format(query, args))
c.db.afterQuery(ctx, event, res, err)
return res, err
Expand All @@ -267,14 +267,14 @@ func (c Conn) ExecContext(
func (c Conn) QueryContext(
ctx context.Context, query string, args ...interface{},
) (*sql.Rows, error) {
ctx, event := c.db.beforeQuery(ctx, nil, query, args)
ctx, event := c.db.beforeQuery(ctx, nil, query, args, nil)
rows, err := c.Conn.QueryContext(ctx, c.db.format(query, args))
c.db.afterQuery(ctx, event, nil, err)
return rows, err
}

func (c Conn) QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row {
ctx, event := c.db.beforeQuery(ctx, nil, query, args)
ctx, event := c.db.beforeQuery(ctx, nil, query, args, nil)
row := c.Conn.QueryRowContext(ctx, c.db.format(query, args))
c.db.afterQuery(ctx, event, nil, row.Err())
return row
Expand Down Expand Up @@ -392,7 +392,7 @@ func (tx Tx) Exec(query string, args ...interface{}) (sql.Result, error) {
func (tx Tx) ExecContext(
ctx context.Context, query string, args ...interface{},
) (sql.Result, error) {
ctx, event := tx.db.beforeQuery(ctx, nil, query, args)
ctx, event := tx.db.beforeQuery(ctx, nil, query, args, nil)
res, err := tx.Tx.ExecContext(ctx, tx.db.format(query, args))
tx.db.afterQuery(ctx, event, res, err)
return res, err
Expand All @@ -405,7 +405,7 @@ func (tx Tx) Query(query string, args ...interface{}) (*sql.Rows, error) {
func (tx Tx) QueryContext(
ctx context.Context, query string, args ...interface{},
) (*sql.Rows, error) {
ctx, event := tx.db.beforeQuery(ctx, nil, query, args)
ctx, event := tx.db.beforeQuery(ctx, nil, query, args, nil)
rows, err := tx.Tx.QueryContext(ctx, tx.db.format(query, args))
tx.db.afterQuery(ctx, event, nil, err)
return rows, err
Expand All @@ -416,7 +416,7 @@ func (tx Tx) QueryRow(query string, args ...interface{}) *sql.Row {
}

func (tx Tx) QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row {
ctx, event := tx.db.beforeQuery(ctx, nil, query, args)
ctx, event := tx.db.beforeQuery(ctx, nil, query, args, nil)
row := tx.Tx.QueryRowContext(ctx, tx.db.format(query, args))
tx.db.afterQuery(ctx, event, nil, row.Err())
return row
Expand Down
3 changes: 3 additions & 0 deletions hook.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ type QueryEvent struct {
QueryAppender schema.Query
Query string
QueryArgs []interface{}
Model Model

StartTime time.Time
Result sql.Result
Expand Down Expand Up @@ -52,6 +53,7 @@ func (db *DB) beforeQuery(
queryApp schema.Query,
query string,
queryArgs []interface{},
model Model,
) (context.Context, *QueryEvent) {
atomic.AddUint32(&db.stats.Queries, 1)

Expand All @@ -62,6 +64,7 @@ func (db *DB) beforeQuery(
event := &QueryEvent{
DB: db,

Model: model,
QueryAppender: queryApp,
Query: query,
QueryArgs: queryArgs,
Expand Down
4 changes: 2 additions & 2 deletions query_base.go
Original file line number Diff line number Diff line change
Expand Up @@ -461,7 +461,7 @@ func (q *baseQuery) scan(
model Model,
hasDest bool,
) (sql.Result, error) {
ctx, event := q.db.beforeQuery(ctx, queryApp, query, nil)
ctx, event := q.db.beforeQuery(ctx, queryApp, query, nil, q.model)

rows, err := q.conn.QueryContext(ctx, query)
if err != nil {
Expand Down Expand Up @@ -491,7 +491,7 @@ func (q *baseQuery) exec(
queryApp schema.Query,
query string,
) (sql.Result, error) {
ctx, event := q.db.beforeQuery(ctx, queryApp, query, nil)
ctx, event := q.db.beforeQuery(ctx, queryApp, query, nil, q.model)

res, err := q.conn.ExecContext(ctx, query)
if err != nil {
Expand Down
4 changes: 2 additions & 2 deletions query_select.go
Original file line number Diff line number Diff line change
Expand Up @@ -744,7 +744,7 @@ func (q *SelectQuery) Count(ctx context.Context) (int, error) {
}

query := internal.String(queryBytes)
ctx, event := q.db.beforeQuery(ctx, qq, query, nil)
ctx, event := q.db.beforeQuery(ctx, qq, query, nil, q.model)

var num int
err = q.conn.QueryRowContext(ctx, query).Scan(&num)
Expand Down Expand Up @@ -803,7 +803,7 @@ func (q *SelectQuery) Exists(ctx context.Context) (bool, error) {
}

query := internal.String(queryBytes)
ctx, event := q.db.beforeQuery(ctx, qq, query, nil)
ctx, event := q.db.beforeQuery(ctx, qq, query, nil, q.model)

var exists bool
err = q.conn.QueryRowContext(ctx, query).Scan(&exists)
Expand Down

0 comments on commit 7688201

Please sign in to comment.