From 61cb1ab8cc82a6e106df72ac88fa781ab7e5d038 Mon Sep 17 00:00:00 2001 From: Songmu Date: Mon, 11 Jan 2021 01:13:33 +0900 Subject: [PATCH] Define PrePrepare, Prepare and PostPrepare to hook prepare --- conn.go | 34 ++++++++++++------ hooks.go | 84 ++++++++++++++++++++++++++++++++++++++++++++ logging_hook_test.go | 21 +++++++++++ proxy_test.go | 9 +++++ 4 files changed, 138 insertions(+), 10 deletions(-) diff --git a/conn.go b/conn.go index 3e0c963..e777f00 100644 --- a/conn.go +++ b/conn.go @@ -50,17 +50,30 @@ func (conn *Conn) Prepare(query string) (driver.Stmt, error) { // PrepareContext returns a prepared statement which is wrapped by Stmt. func (conn *Conn) PrepareContext(c context.Context, query string) (driver.Stmt, error) { - var stmt driver.Stmt + var ctx interface{} + var stmt = &Stmt{ + QueryString: query, + Proxy: conn.Proxy, + Conn: conn, + } var err error + hooks := conn.Proxy.getHooks(c) + if hooks != nil { + defer func() { hooks.postPrepare(c, ctx, stmt, err) }() + if ctx, err = hooks.prePrepare(c, stmt); err != nil { + return nil, err + } + } + if connCtx, ok := conn.Conn.(driver.ConnPrepareContext); ok { - stmt, err = connCtx.PrepareContext(c, query) + stmt.Stmt, err = connCtx.PrepareContext(c, stmt.QueryString) } else { - stmt, err = conn.Conn.Prepare(query) + stmt.Stmt, err = conn.Conn.Prepare(stmt.QueryString) if err == nil { select { default: case <-c.Done(): - stmt.Close() + stmt.Stmt.Close() return nil, c.Err() } } @@ -68,12 +81,13 @@ func (conn *Conn) PrepareContext(c context.Context, query string) (driver.Stmt, if err != nil { return nil, err } - return &Stmt{ - Stmt: stmt, - QueryString: query, - Proxy: conn.Proxy, - Conn: conn, - }, nil + + if hooks != nil { + if err = hooks.prepare(c, ctx, stmt); err != nil { + return nil, err + } + } + return stmt, nil } // Close calls the original Close method. diff --git a/hooks.go b/hooks.go index 1d0a8f6..084b577 100644 --- a/hooks.go +++ b/hooks.go @@ -17,6 +17,9 @@ type hooks interface { preOpen(c context.Context, name string) (interface{}, error) open(c context.Context, ctx interface{}, conn *Conn) error postOpen(c context.Context, ctx interface{}, conn *Conn, err error) error + prePrepare(c context.Context, stmt *Stmt) (interface{}, error) + prepare(c context.Context, ctx interface{}, stmt *Stmt) error + postPrepare(c context.Context, ctx interface{}, stmt *Stmt, err error) error preExec(c context.Context, stmt *Stmt, args []driver.NamedValue) (interface{}, error) exec(c context.Context, ctx interface{}, stmt *Stmt, args []driver.NamedValue, result driver.Result) error postExec(c context.Context, ctx interface{}, stmt *Stmt, args []driver.NamedValue, result driver.Result, err error) error @@ -109,6 +112,36 @@ type HooksContext struct { // `Hooks.PreOpen` method, and may be nil. PostOpen func(c context.Context, ctx interface{}, conn *Conn, err error) error + // PrePrepare is a callback that gets called prior to calling + // `db.Prepare`, and is ALWAYS called. If this callback returns an + // error, the underlying driver's `db.Exec` and `Hooks.Prepare` methods + // are not called. + // + // The first return value is passed to both `Hooks.Prepare` and + // `Hooks.PostPrepare` callbacks. You may specify anything you want. + // Return nil if you do not need to use it. + // + // The second return value is indicates the error found while + // executing this hook. + PrePrepare func(c context.Context, stmt *Stmt) (interface{}, error) + + // Prepare is called after the underlying driver's `db.Prepare` method + // returns without any errors. + // + // The `ctx` parameter is the return value supplied from the + // `Hooks.PrePrepare` method, and may be nil. + // + // If this callback returns an error, then the error from this + // callback is returned by the `db.Prepare` method. + Prepare func(c context.Context, ctx interface{}, stmt *Stmt) error + + // PostPrepare is a callback that gets called at the end of + // the call to `db.Prepare`. It is ALWAYS called. + // + // The `ctx` parameter is the return value supplied from the + // `Hooks.PrePrepare` method, and may be nil. + PostPrepare func(c context.Context, ctx interface{}, stmt *Stmt, err error) error + // PreExec is a callback that gets called prior to calling // `Stmt.Exec`, and is ALWAYS called. If this callback returns an // error, the underlying driver's `Stmt.Exec` and `Hooks.Exec` methods @@ -405,6 +438,27 @@ func (h *HooksContext) postOpen(c context.Context, ctx interface{}, conn *Conn, return h.PostOpen(c, ctx, conn, err) } +func (h *HooksContext) prePrepare(c context.Context, stmt *Stmt) (interface{}, error) { + if h == nil || h.PrePrepare == nil { + return nil, nil + } + return h.PrePrepare(c, stmt) +} + +func (h *HooksContext) prepare(c context.Context, ctx interface{}, stmt *Stmt) error { + if h == nil || h.Prepare == nil { + return nil + } + return h.Prepare(c, ctx, stmt) +} + +func (h *HooksContext) postPrepare(c context.Context, ctx interface{}, stmt *Stmt, err error) error { + if h == nil || h.PostPrepare == nil { + return nil + } + return h.PostPrepare(c, ctx, stmt, err) +} + func (h *HooksContext) preExec(c context.Context, stmt *Stmt, args []driver.NamedValue) (interface{}, error) { if h == nil || h.PreExec == nil { return nil, nil @@ -929,6 +983,18 @@ func (h *Hooks) postOpen(c context.Context, ctx interface{}, conn *Conn, err err return h.PostOpen(ctx, conn) } +func (h *Hooks) prePrepare(c context.Context, stmt *Stmt) (interface{}, error) { + return nil, nil +} + +func (h *Hooks) prepare(c context.Context, ctx interface{}, stmt *Stmt) error { + return nil +} + +func (h *Hooks) postPrepare(c context.Context, ctx interface{}, stmt *Stmt, err error) error { + return nil +} + func (h *Hooks) preExec(c context.Context, stmt *Stmt, args []driver.NamedValue) (interface{}, error) { if h == nil || h.PreExec == nil { return nil, nil @@ -1187,6 +1253,24 @@ func (h multipleHooks) postOpen(c context.Context, ctx interface{}, conn *Conn, }) } +func (h multipleHooks) prePrepare(c context.Context, stmt *Stmt) (interface{}, error) { + return h.preDo(func(h hooks) (interface{}, error) { + return h.prePrepare(c, stmt) + }) +} + +func (h multipleHooks) prepare(c context.Context, ctx interface{}, stmt *Stmt) error { + return h.do(ctx, func(h hooks, ctx interface{}) error { + return h.prepare(c, ctx, stmt) + }) +} + +func (h multipleHooks) postPrepare(c context.Context, ctx interface{}, stmt *Stmt, err error) error { + return h.postDo(ctx, err, func(h hooks, ctx interface{}, err error) error { + return h.postPrepare(c, ctx, stmt, err) + }) +} + func (h multipleHooks) preExec(c context.Context, stmt *Stmt, args []driver.NamedValue) (interface{}, error) { return h.preDo(func(h hooks) (interface{}, error) { return h.preExec(c, stmt, args) diff --git a/logging_hook_test.go b/logging_hook_test.go index 1dfc86b..c7d20ac 100644 --- a/logging_hook_test.go +++ b/logging_hook_test.go @@ -61,6 +61,27 @@ func (h *loggingHook) postOpen(c context.Context, ctx interface{}, conn *Conn, e return nil } +func (h *loggingHook) prePrepare(c context.Context, stmt *Stmt) (interface{}, error) { + h.mu.Lock() + defer h.mu.Unlock() + fmt.Fprintln(h, "[PrePrepare]") + return nil, nil +} + +func (h *loggingHook) prepare(c context.Context, ctx interface{}, stmt *Stmt) error { + h.mu.Lock() + defer h.mu.Unlock() + fmt.Fprintln(h, "[Prepare]") + return nil +} + +func (h *loggingHook) postPrepare(c context.Context, ctx interface{}, stmt *Stmt, err error) error { + h.mu.Lock() + defer h.mu.Unlock() + fmt.Fprintln(h, "[PostPrepare]") + return nil +} + func (h *loggingHook) preExec(c context.Context, stmt *Stmt, args []driver.NamedValue) (interface{}, error) { h.mu.Lock() defer h.mu.Unlock() diff --git a/proxy_test.go b/proxy_test.go index 1339abc..4a6d887 100644 --- a/proxy_test.go +++ b/proxy_test.go @@ -36,6 +36,7 @@ func TestFakeDB(t *testing.T) { Name: "execAll", }, hooksLog: "[PreOpen]\n[Open]\n[PostOpen]\n" + + "[PrePrepare]\n[Prepare]\n[PostPrepare]\n" + "[PreExec]\n[Exec]\n[PostExec]\n" + "[PreClose]\n[Close]\n[PostClose]\n", f: func(db *sql.DB) error { @@ -49,6 +50,7 @@ func TestFakeDB(t *testing.T) { FailExec: true, }, hooksLog: "[PreOpen]\n[Open]\n[PostOpen]\n" + + "[PrePrepare]\n[Prepare]\n[PostPrepare]\n" + "[PreExec]\n[PostExec]\n" + "[PreClose]\n[Close]\n[PostClose]\n", f: func(db *sql.DB) error { @@ -64,6 +66,7 @@ func TestFakeDB(t *testing.T) { Name: "execError-NamedValue", }, hooksLog: "[PreOpen]\n[Open]\n[PostOpen]\n" + + "[PrePrepare]\n[Prepare]\n[PostPrepare]\n" + "[PreExec]\n[PostExec]\n" + "[PreClose]\n[Close]\n[PostClose]\n", f: func(db *sql.DB) error { @@ -80,6 +83,7 @@ func TestFakeDB(t *testing.T) { Name: "queryAll", }, hooksLog: "[PreOpen]\n[Open]\n[PostOpen]\n" + + "[PrePrepare]\n[Prepare]\n[PostPrepare]\n" + "[PreQuery]\n[Query]\n[PostQuery]\n", f: func(db *sql.DB) error { _, err := db.Query("SELECT * FROM test WHERE id = ?", 123456789) @@ -92,6 +96,7 @@ func TestFakeDB(t *testing.T) { FailQuery: true, }, hooksLog: "[PreOpen]\n[Open]\n[PostOpen]\n" + + "[PrePrepare]\n[Prepare]\n[PostPrepare]\n" + "[PreQuery]\n[PostQuery]\n" + "[PreClose]\n[Close]\n[PostClose]\n", f: func(db *sql.DB) error { @@ -107,6 +112,7 @@ func TestFakeDB(t *testing.T) { Name: "prepare", }, hooksLog: "[PreOpen]\n[Open]\n[PostOpen]\n" + + "[PrePrepare]\n[Prepare]\n[PostPrepare]\n" + "[PreClose]\n[Close]\n[PostClose]\n", f: func(db *sql.DB) error { stmt, err := db.Prepare("SELECT * FROM test WHERE id = ?") @@ -255,6 +261,7 @@ func TestFakeDB(t *testing.T) { ConnType: "fakeConnExt", }, hooksLog: "[PreOpen]\n[Open]\n[PostOpen]\n" + + "[PrePrepare]\n[Prepare]\n[PostPrepare]\n" + "[PreExec]\n[Exec]\n[PostExec]\n" + "[PreClose]\n[Close]\n[PostClose]\n", f: func(db *sql.DB) error { @@ -325,6 +332,7 @@ func TestFakeDB(t *testing.T) { ConnType: "fakeConnCtx", }, hooksLog: "[PreOpen]\n[Open]\n[PostOpen]\n" + + "[PrePrepare]\n[Prepare]\n[PostPrepare]\n" + "[PreExec]\n[Exec]\n[PostExec]\n" + "[PreClose]\n[Close]\n[PostClose]\n", f: func(db *sql.DB) error { @@ -343,6 +351,7 @@ func TestFakeDB(t *testing.T) { ConnType: "fakeConnCtx", }, hooksLog: "[PreOpen]\n[Open]\n[PostOpen]\n" + + "[PrePrepare]\n[Prepare]\n[PostPrepare]\n" + "[PreQuery]\n[Query]\n[PostQuery]\n", f: func(db *sql.DB) error { stmt, err := db.Prepare("SELECT * FROM test WHERE id = ?")