From 489a69751d12e3523142f168362c31e26daba818 Mon Sep 17 00:00:00 2001 From: Stefan Tudose Date: Wed, 13 Oct 2021 19:18:48 +0200 Subject: [PATCH 1/2] support ExecerContext interface --- conn.go | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/conn.go b/conn.go index e777f00..31b54d5 100644 --- a/conn.go +++ b/conn.go @@ -189,13 +189,14 @@ func (conn *Conn) Exec(query string, args []driver.Value) (driver.Result, error) panic("not supported") } -// ExecContext calls the original Exec method of the connection. -// It will trigger PreExec, Exec, PostExec hooks. +// ExecContext calls the original ExecContext (or Exec as a fallback) method of the connection. +// It will trigger PreExec, PostExec hooks. // -// If the original connection does not satisfy "database/sql/driver".Execer, it return ErrSkip error. +// If the original connection doesn't satisfy "database/sql/driver".ExecerContext nor "database/sql/driver".Execer, it return ErrSkip error. func (conn *Conn) ExecContext(c context.Context, query string, args []driver.NamedValue) (driver.Result, error) { - execer, ok := conn.Conn.(driver.Execer) - if !ok { + execer, exOk := conn.Conn.(driver.Execer) + execerCtx, exCtxOk := conn.Conn.(driver.ExecerContext) + if !exOk && !exCtxOk { return nil, driver.ErrSkip } @@ -217,7 +218,7 @@ func (conn *Conn) ExecContext(c context.Context, query string, args []driver.Nam } // call the original method. - if execerCtx, ok := execer.(driver.ExecerContext); ok { + if execerCtx != nil { result, err = execerCtx.ExecContext(c, stmt.QueryString, args) } else { select { From 0580e9dfc88ecbacb14c886a773d81cd82a848ce Mon Sep 17 00:00:00 2001 From: Stefan Tudose Date: Wed, 13 Oct 2021 20:21:41 +0200 Subject: [PATCH 2/2] support QueryerContext interface --- conn.go | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/conn.go b/conn.go index 31b54d5..86bef12 100644 --- a/conn.go +++ b/conn.go @@ -190,9 +190,9 @@ func (conn *Conn) Exec(query string, args []driver.Value) (driver.Result, error) } // ExecContext calls the original ExecContext (or Exec as a fallback) method of the connection. -// It will trigger PreExec, PostExec hooks. +// It will trigger PreExec, Exec, PostExec hooks. // -// If the original connection doesn't satisfy "database/sql/driver".ExecerContext nor "database/sql/driver".Execer, it return ErrSkip error. +// If the original connection does not satisfy "database/sql/driver".ExecerContext nor "database/sql/driver".Execer, it return ErrSkip error. func (conn *Conn) ExecContext(c context.Context, query string, args []driver.NamedValue) (driver.Result, error) { execer, exOk := conn.Conn.(driver.Execer) execerCtx, exCtxOk := conn.Conn.(driver.ExecerContext) @@ -257,10 +257,11 @@ func (conn *Conn) Query(query string, args []driver.Value) (driver.Rows, error) // QueryContext executes a query that may return rows. // It wil trigger PreQuery, Query, PostQuery hooks. // -// If the original connection does not satisfy "database/sql/driver".Queryer, it return ErrSkip error. +// If the original connection does not satisfy "database/sql/driver".QueryerContext nor "database/sql/driver".Queryer, it return ErrSkip error. func (conn *Conn) QueryContext(c context.Context, query string, args []driver.NamedValue) (driver.Rows, error) { - queryer, ok := conn.Conn.(driver.Queryer) - if !ok { + queryer, qok := conn.Conn.(driver.Queryer) + queryerCtx, qCtxOk := conn.Conn.(driver.QueryerContext) + if !qok && !qCtxOk { return nil, driver.ErrSkip } @@ -281,7 +282,7 @@ func (conn *Conn) QueryContext(c context.Context, query string, args []driver.Na } // call the original method. - if queryerCtx, ok := conn.Conn.(driver.QueryerContext); ok { + if queryerCtx != nil { rows, err = queryerCtx.QueryContext(c, stmt.QueryString, args) } else { select {