diff --git a/expression/expression.go b/expression/expression.go index d1493aed92788..30897624b82f3 100644 --- a/expression/expression.go +++ b/expression/expression.go @@ -1565,6 +1565,8 @@ func Args2Expressions4Test(args ...interface{}) []Expression { ft = types.NewFieldType(mysql.TypeDouble) case types.KindString: ft = types.NewFieldType(mysql.TypeVarString) + case types.KindMysqlTime: + ft = types.NewFieldType(mysql.TypeTimestamp) default: exprs[i] = nil continue diff --git a/planner/core/memtable_predicate_extractor.go b/planner/core/memtable_predicate_extractor.go index 1b8d953518b49..b4fa83d335a36 100644 --- a/planner/core/memtable_predicate_extractor.go +++ b/planner/core/memtable_predicate_extractor.go @@ -83,10 +83,14 @@ func (extractHelper) extractColInConsExpr(extractCols map[int64]*types.FieldName results := make([]types.Datum, 0, len(args[1:])) for _, arg := range args[1:] { constant, ok := arg.(*expression.Constant) - if !ok || constant.DeferredExpr != nil || constant.ParamMarker != nil { + if !ok || constant.DeferredExpr != nil { return "", nil } - results = append(results, constant.Value) + v := constant.Value + if constant.ParamMarker != nil { + v = constant.ParamMarker.GetUserVar() + } + results = append(results, v) } return name.ColName.L, results } @@ -117,10 +121,14 @@ func (extractHelper) extractColBinaryOpConsExpr(extractCols map[int64]*types.Fie // SELECT * FROM t1 WHERE c='rhs' // SELECT * FROM t1 WHERE 'lhs'=c constant, ok := args[1-colIdx].(*expression.Constant) - if !ok || constant.DeferredExpr != nil || constant.ParamMarker != nil { + if !ok || constant.DeferredExpr != nil { return "", nil } - return name.ColName.L, []types.Datum{constant.Value} + v := constant.Value + if constant.ParamMarker != nil { + v = constant.ParamMarker.GetUserVar() + } + return name.ColName.L, []types.Datum{v} } // extract the OR expression, e.g: diff --git a/planner/core/memtable_predicate_extractor_test.go b/planner/core/memtable_predicate_extractor_test.go index 97d55b09ed8ec..ecf7c47218ab9 100644 --- a/planner/core/memtable_predicate_extractor_test.go +++ b/planner/core/memtable_predicate_extractor_test.go @@ -25,10 +25,14 @@ import ( "github.com/pingcap/tidb/domain" "github.com/pingcap/tidb/errno" + "github.com/pingcap/tidb/expression" "github.com/pingcap/tidb/parser" + "github.com/pingcap/tidb/parser/ast" + "github.com/pingcap/tidb/planner" plannercore "github.com/pingcap/tidb/planner/core" "github.com/pingcap/tidb/session" "github.com/pingcap/tidb/testkit" + "github.com/pingcap/tidb/types" "github.com/pingcap/tidb/util/hint" "github.com/pingcap/tidb/util/set" "github.com/stretchr/testify/require" @@ -1749,3 +1753,112 @@ PARTITION BY RANGE COLUMNS ( id ) ( require.Equal(t, ca.tableIDs, tableids) } } + +func TestExtractorInPreparedStmt(t *testing.T) { + store, dom := testkit.CreateMockStoreAndDomain(t) + tk := testkit.NewTestKit(t, store) + + var cases = []struct { + prepared string + userVars []interface{} + params []interface{} + checker func(extractor plannercore.MemTablePredicateExtractor) + }{ + { + prepared: "select * from information_schema.TIKV_REGION_STATUS where table_id = ?", + userVars: []interface{}{1}, + params: []interface{}{1}, + checker: func(extractor plannercore.MemTablePredicateExtractor) { + rse := extractor.(*plannercore.TiKVRegionStatusExtractor) + tableids := rse.GetTablesID() + slices.Sort(tableids) + require.Equal(t, []int64{1}, tableids) + }, + }, + { + prepared: "select * from information_schema.TIKV_REGION_STATUS where table_id = ? or table_id = ?", + userVars: []interface{}{1, 2}, + params: []interface{}{1, 2}, + checker: func(extractor plannercore.MemTablePredicateExtractor) { + rse := extractor.(*plannercore.TiKVRegionStatusExtractor) + tableids := rse.GetTablesID() + slices.Sort(tableids) + require.Equal(t, []int64{1, 2}, tableids) + }, + }, + { + prepared: "select * from information_schema.TIKV_REGION_STATUS where table_id in (?,?)", + userVars: []interface{}{1, 2}, + params: []interface{}{1, 2}, + checker: func(extractor plannercore.MemTablePredicateExtractor) { + rse := extractor.(*plannercore.TiKVRegionStatusExtractor) + tableids := rse.GetTablesID() + slices.Sort(tableids) + require.Equal(t, []int64{1, 2}, tableids) + }, + }, + { + prepared: "select * from information_schema.COLUMNS where table_name like ?", + userVars: []interface{}{`"a%"`}, + params: []interface{}{"a%"}, + checker: func(extractor plannercore.MemTablePredicateExtractor) { + rse := extractor.(*plannercore.ColumnsTableExtractor) + require.EqualValues(t, []string{"a%"}, rse.TableNamePatterns) + }, + }, + { + prepared: "select * from information_schema.tidb_hot_regions_history where update_time>=?", + userVars: []interface{}{"cast('2019-10-10 10:10:10' as datetime)"}, + params: []interface{}{func() types.Time { + tt, err := types.ParseTimestamp(tk.Session().GetSessionVars().StmtCtx, "2019-10-10 10:10:10") + require.NoError(t, err) + return tt + }()}, + checker: func(extractor plannercore.MemTablePredicateExtractor) { + rse := extractor.(*plannercore.HotRegionsHistoryTableExtractor) + require.Equal(t, timestamp(t, "2019-10-10 10:10:10"), rse.StartTime) + }, + }, + } + + // text protocol + parser := parser.New() + for _, ca := range cases { + tk.MustExec(fmt.Sprintf("prepare stmt from '%s'", ca.prepared)) + setStmt := "set " + exec := "execute stmt using " + for i, uv := range ca.userVars { + name := fmt.Sprintf("@a%d", i) + setStmt += fmt.Sprintf("%s=%v", name, uv) + exec += name + if i != len(ca.userVars)-1 { + setStmt += "," + exec += "," + } + } + tk.MustExec(setStmt) + stmt, err := parser.ParseOneStmt(exec, "", "") + require.NoError(t, err) + plan, _, err := planner.OptimizeExecStmt(context.Background(), tk.Session(), stmt.(*ast.ExecuteStmt), dom.InfoSchema()) + require.NoError(t, err) + extractor := plan.(*plannercore.Execute).Plan.(*plannercore.PhysicalMemTable).Extractor + ca.checker(extractor) + } + + // binary protocol + for _, ca := range cases { + id, _, _, err := tk.Session().PrepareStmt(ca.prepared) + require.NoError(t, err) + prepStmt, err := tk.Session().GetSessionVars().GetPreparedStmtByID(id) + require.NoError(t, err) + params := expression.Args2Expressions4Test(ca.params...) + execStmt := &ast.ExecuteStmt{ + BinaryArgs: params, + PrepStmt: prepStmt, + } + plan, _, err := planner.OptimizeExecStmt(context.Background(), tk.Session(), execStmt, dom.InfoSchema()) + require.NoError(t, err) + extractor := plan.(*plannercore.Execute).Plan.(*plannercore.PhysicalMemTable).Extractor + ca.checker(extractor) + } +}