Skip to content

Commit

Permalink
support QueryerContext interface
Browse files Browse the repository at this point in the history
  • Loading branch information
Stefan Tudose committed Oct 14, 2021
1 parent 489a697 commit 0580e9d
Showing 1 changed file with 7 additions and 6 deletions.
13 changes: 7 additions & 6 deletions conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -190,9 +190,9 @@ func (conn *Conn) Exec(query string, args []driver.Value) (driver.Result, error)
}

// ExecContext calls the original ExecContext (or Exec as a fallback) method of the connection.
// It will trigger PreExec, PostExec hooks.
// It will trigger PreExec, Exec, PostExec hooks.
//
// If the original connection doesn't satisfy "database/sql/driver".ExecerContext nor "database/sql/driver".Execer, it return ErrSkip error.
// If the original connection does not satisfy "database/sql/driver".ExecerContext nor "database/sql/driver".Execer, it return ErrSkip error.
func (conn *Conn) ExecContext(c context.Context, query string, args []driver.NamedValue) (driver.Result, error) {
execer, exOk := conn.Conn.(driver.Execer)
execerCtx, exCtxOk := conn.Conn.(driver.ExecerContext)
Expand Down Expand Up @@ -257,10 +257,11 @@ func (conn *Conn) Query(query string, args []driver.Value) (driver.Rows, error)
// QueryContext executes a query that may return rows.
// It wil trigger PreQuery, Query, PostQuery hooks.
//
// If the original connection does not satisfy "database/sql/driver".Queryer, it return ErrSkip error.
// If the original connection does not satisfy "database/sql/driver".QueryerContext nor "database/sql/driver".Queryer, it return ErrSkip error.
func (conn *Conn) QueryContext(c context.Context, query string, args []driver.NamedValue) (driver.Rows, error) {
queryer, ok := conn.Conn.(driver.Queryer)
if !ok {
queryer, qok := conn.Conn.(driver.Queryer)
queryerCtx, qCtxOk := conn.Conn.(driver.QueryerContext)
if !qok && !qCtxOk {
return nil, driver.ErrSkip
}

Expand All @@ -281,7 +282,7 @@ func (conn *Conn) QueryContext(c context.Context, query string, args []driver.Na
}

// call the original method.
if queryerCtx, ok := conn.Conn.(driver.QueryerContext); ok {
if queryerCtx != nil {
rows, err = queryerCtx.QueryContext(c, stmt.QueryString, args)
} else {
select {
Expand Down

0 comments on commit 0580e9d

Please sign in to comment.