From 5f8002fec7fc079d5a922f12f5df293f9311503a Mon Sep 17 00:00:00 2001 From: YangKeao Date: Wed, 15 Mar 2023 00:30:39 -0400 Subject: [PATCH] session, com_stmt: store and restore the params for cursor fetch (#41441) close pingcap/tidb#40094 --- server/conn_stmt.go | 8 +++++++ server/conn_stmt_test.go | 50 ++++++++++++++++++++++++++++++++++++++++ server/driver.go | 18 +++++++++++++++ server/driver_tidb.go | 12 ++++++++++ 4 files changed, 88 insertions(+) diff --git a/server/conn_stmt.go b/server/conn_stmt.go index b6f2716c89631..47ad8177c3bce 100644 --- a/server/conn_stmt.go +++ b/server/conn_stmt.go @@ -304,6 +304,11 @@ func (cc *clientConn) executePreparedStmtAndWriteResult(ctx context.Context, stm rs = &rsWithHooks{ResultSet: rs, onClosed: unhold} } stmt.StoreResultSet(rs) + // also store the preparedParams in the stmt, so we could restore the params in the following fetch command + // the params should have been parsed in `(&cc.ctx).ExecuteStmt(ctx, execStmt)`. + stmt.StorePreparedCtx(&PreparedStatementCtx{ + Params: vars.PreparedParams, + }) if err = cc.writeColumnInfo(rs.Columns()); err != nil { return false, err } @@ -346,6 +351,9 @@ func (cc *clientConn) handleStmtFetch(ctx context.Context, data []byte) (err err return errors.Annotate(mysql.NewErr(mysql.ErrUnknownStmtHandler, strconv.FormatUint(uint64(stmtID), 10), "stmt_fetch"), cc.preparedStmt2String(stmtID)) } + + cc.ctx.GetSessionVars().PreparedParams = stmt.GetPreparedCtx().Params + if topsqlstate.TopSQLEnabled() { prepareObj, _ := cc.preparedStmtID2CachePreparedStmt(stmtID) if prepareObj != nil && prepareObj.SQLDigest != nil { diff --git a/server/conn_stmt_test.go b/server/conn_stmt_test.go index 655d6de66a646..61a7d3ad3e82c 100644 --- a/server/conn_stmt_test.go +++ b/server/conn_stmt_test.go @@ -485,3 +485,53 @@ func TestCursorReadWithRCCheckTS(t *testing.T) { tk.MustExec("rollback") } } + +func TestCursorWithParams(t *testing.T) { + store, dom := testkit.CreateMockStoreAndDomain(t) + srv := CreateMockServer(t, store) + srv.SetDomain(dom) + defer srv.Close() + + appendUint32 := binary.LittleEndian.AppendUint32 + ctx := context.Background() + c := CreateMockConn(t, srv).(*mockConn) + + tk := testkit.NewTestKitWithSession(t, store, c.Context().Session) + tk.MustExec("use test") + tk.MustExec("drop table if exists t") + tk.MustExec("create table t(id_1 int, id_2 int)") + tk.MustExec("insert into t values (1, 1), (1, 2)") + + stmt1, _, _, err := c.Context().Prepare("select * from t where id_1 = ? and id_2 = ?") + require.NoError(t, err) + stmt2, _, _, err := c.Context().Prepare("select * from t where id_1 = ?") + require.NoError(t, err) + + // `execute stmt1 using 1,2` with cursor + require.NoError(t, c.Dispatch(ctx, append( + appendUint32([]byte{mysql.ComStmtExecute}, uint32(stmt1.ID())), + mysql.CursorTypeReadOnly, 0x1, 0x0, 0x0, 0x0, + 0x0, 0x1, 0x3, 0x0, 0x3, 0x0, + 0x1, 0x0, 0x0, 0x0, 0x2, 0x0, 0x0, 0x0, + ))) + + // `execute stmt2 using 1` with cursor + require.NoError(t, c.Dispatch(ctx, append( + appendUint32([]byte{mysql.ComStmtExecute}, uint32(stmt2.ID())), + mysql.CursorTypeReadOnly, 0x1, 0x0, 0x0, 0x0, + 0x0, 0x1, 0x3, 0x0, + 0x1, 0x0, 0x0, 0x0, + ))) + + // fetch stmt2 with fetch size 256 + require.NoError(t, c.Dispatch(ctx, append( + appendUint32([]byte{mysql.ComStmtFetch}, uint32(stmt2.ID())), + 0x0, 0x1, 0x0, 0x0, + ))) + + // fetch stmt1 with fetch size 256, as it has more params, if we didn't restore the parameters, it will panic. + require.NoError(t, c.Dispatch(ctx, append( + appendUint32([]byte{mysql.ComStmtFetch}, uint32(stmt1.ID())), + 0x0, 0x1, 0x0, 0x0, + ))) +} diff --git a/server/driver.go b/server/driver.go index a4a59f4cba2b5..a410d79adf763 100644 --- a/server/driver.go +++ b/server/driver.go @@ -20,6 +20,7 @@ import ( "github.com/pingcap/tidb/expression" "github.com/pingcap/tidb/extension" + "github.com/pingcap/tidb/sessionctx/variable" "github.com/pingcap/tidb/util/chunk" ) @@ -29,6 +30,17 @@ type IDriver interface { OpenCtx(connID uint64, capability uint32, collation uint8, dbname string, tlsState *tls.ConnectionState, extensions *extension.SessionExtensions) (*TiDBContext, error) } +// PreparedStatementCtx stores the context generated in `execute` statement for a prepared statement +// subsequent stmt fetching could restore the session variables from this context +type PreparedStatementCtx struct { + // Params is the params used in `execute` statement + Params variable.PreparedParams + // TODO: store and restore variables, but be careful that we'll also need to restore the variables after FETCH + // a cleaner way to solve this problem is to always reading params from a statement scope (but not session scope) + // context. But switching in/out related context is simpler on current code base, and the affected radius is more + // controllable. +} + // PreparedStatement is the interface to use a prepared statement. type PreparedStatement interface { // ID returns statement ID @@ -58,6 +70,12 @@ type PreparedStatement interface { // GetResultSet gets ResultSet associated this statement GetResultSet() ResultSet + // StorePreparedCtx stores context in `execute` statement for subsequent stmt fetching + StorePreparedCtx(ctx *PreparedStatementCtx) + + // GetPreparedParams gets the prepared params associated this statement + GetPreparedCtx() *PreparedStatementCtx + // Reset removes all bound parameters. Reset() diff --git a/server/driver_tidb.go b/server/driver_tidb.go index b37b76a75a889..175e896a07cce 100644 --- a/server/driver_tidb.go +++ b/server/driver_tidb.go @@ -68,6 +68,8 @@ type TiDBStatement struct { ctx *TiDBContext rs ResultSet sql string + + preparedStatementCtx *PreparedStatementCtx } // ID implements PreparedStatement ID method. @@ -142,6 +144,16 @@ func (ts *TiDBStatement) GetResultSet() ResultSet { return ts.rs } +// StorePreparedCtx implements PreparedStatement StorePreparedCtx method. +func (ts *TiDBStatement) StorePreparedCtx(ctx *PreparedStatementCtx) { + ts.preparedStatementCtx = ctx +} + +// GetPreparedCtx implements PreparedStatement GetPreparedCtx method. +func (ts *TiDBStatement) GetPreparedCtx() *PreparedStatementCtx { + return ts.preparedStatementCtx +} + // Reset implements PreparedStatement Reset method. func (ts *TiDBStatement) Reset() { for i := range ts.boundParams {