From 053057005e55083bf91d95194d5e475f88f09c2f Mon Sep 17 00:00:00 2001 From: djshow832 <873581766@qq.com> Date: Fri, 30 Dec 2022 19:08:24 +0800 Subject: [PATCH 1/3] support prepare --- planner/core/memtable_predicate_extractor.go | 16 +++-- .../core/memtable_predicate_extractor_test.go | 70 +++++++++++++++++++ 2 files changed, 82 insertions(+), 4 deletions(-) diff --git a/planner/core/memtable_predicate_extractor.go b/planner/core/memtable_predicate_extractor.go index 1b8d953518b49..b4fa83d335a36 100644 --- a/planner/core/memtable_predicate_extractor.go +++ b/planner/core/memtable_predicate_extractor.go @@ -83,10 +83,14 @@ func (extractHelper) extractColInConsExpr(extractCols map[int64]*types.FieldName results := make([]types.Datum, 0, len(args[1:])) for _, arg := range args[1:] { constant, ok := arg.(*expression.Constant) - if !ok || constant.DeferredExpr != nil || constant.ParamMarker != nil { + if !ok || constant.DeferredExpr != nil { return "", nil } - results = append(results, constant.Value) + v := constant.Value + if constant.ParamMarker != nil { + v = constant.ParamMarker.GetUserVar() + } + results = append(results, v) } return name.ColName.L, results } @@ -117,10 +121,14 @@ func (extractHelper) extractColBinaryOpConsExpr(extractCols map[int64]*types.Fie // SELECT * FROM t1 WHERE c='rhs' // SELECT * FROM t1 WHERE 'lhs'=c constant, ok := args[1-colIdx].(*expression.Constant) - if !ok || constant.DeferredExpr != nil || constant.ParamMarker != nil { + if !ok || constant.DeferredExpr != nil { return "", nil } - return name.ColName.L, []types.Datum{constant.Value} + v := constant.Value + if constant.ParamMarker != nil { + v = constant.ParamMarker.GetUserVar() + } + return name.ColName.L, []types.Datum{v} } // extract the OR expression, e.g: diff --git a/planner/core/memtable_predicate_extractor_test.go b/planner/core/memtable_predicate_extractor_test.go index 97d55b09ed8ec..6bf62b80eaa59 100644 --- a/planner/core/memtable_predicate_extractor_test.go +++ b/planner/core/memtable_predicate_extractor_test.go @@ -26,6 +26,8 @@ import ( "github.com/pingcap/tidb/domain" "github.com/pingcap/tidb/errno" "github.com/pingcap/tidb/parser" + "github.com/pingcap/tidb/parser/ast" + "github.com/pingcap/tidb/planner" plannercore "github.com/pingcap/tidb/planner/core" "github.com/pingcap/tidb/session" "github.com/pingcap/tidb/testkit" @@ -1749,3 +1751,71 @@ PARTITION BY RANGE COLUMNS ( id ) ( require.Equal(t, ca.tableIDs, tableids) } } + +func TestExtractorInPreparedStmt(t *testing.T) { + store, dom := testkit.CreateMockStoreAndDomain(t) + tk := testkit.NewTestKit(t, store) + tk.MustExec(fmt.Sprintf(`set @a=1, @b=2, @c="a%%", @d=cast('2019-10-10 10:10:10' as datetime)`)) + + var cases = []struct { + prepared string + exec string + checker func(extractor plannercore.MemTablePredicateExtractor) + }{ + { + prepared: "prepare stmt from 'select * from information_schema.TIKV_REGION_STATUS where table_id = ?'", + exec: "execute stmt using @a", + checker: func(extractor plannercore.MemTablePredicateExtractor) { + rse := extractor.(*plannercore.TiKVRegionStatusExtractor) + tableids := rse.GetTablesID() + slices.Sort(tableids) + require.Equal(t, []int64{1}, tableids) + }, + }, + { + prepared: "prepare stmt from 'select * from information_schema.TIKV_REGION_STATUS where table_id = ? or table_id = ?'", + exec: "execute stmt using @a, @b", + checker: func(extractor plannercore.MemTablePredicateExtractor) { + rse := extractor.(*plannercore.TiKVRegionStatusExtractor) + tableids := rse.GetTablesID() + slices.Sort(tableids) + require.Equal(t, []int64{1, 2}, tableids) + }, + }, + { + prepared: "prepare stmt from 'select * from information_schema.TIKV_REGION_STATUS where table_id in (?,?)'", + exec: "execute stmt using @a, @b", + checker: func(extractor plannercore.MemTablePredicateExtractor) { + rse := extractor.(*plannercore.TiKVRegionStatusExtractor) + tableids := rse.GetTablesID() + slices.Sort(tableids) + require.Equal(t, []int64{1, 2}, tableids) + }, + }, + { + prepared: "prepare stmt from 'select * from information_schema.COLUMNS where table_name like ?'", + exec: "execute stmt using @c", + checker: func(extractor plannercore.MemTablePredicateExtractor) { + rse := extractor.(*plannercore.ColumnsTableExtractor) + require.EqualValues(t, []string{"a%"}, rse.TableNamePatterns) + }, + }, + { + prepared: "prepare stmt from 'select * from information_schema.tidb_hot_regions_history where update_time>=?'", + exec: "execute stmt using @d", + checker: func(extractor plannercore.MemTablePredicateExtractor) { + rse := extractor.(*plannercore.HotRegionsHistoryTableExtractor) + require.Equal(t, timestamp(t, "2019-10-10 10:10:10"), rse.StartTime) + }, + }, + } + parser := parser.New() + for _, ca := range cases { + tk.MustExec(ca.prepared) + stmt, err := parser.ParseOneStmt(ca.exec, "", "") + require.NoError(t, err) + plan, _, err := planner.OptimizeExecStmt(context.Background(), tk.Session(), stmt.(*ast.ExecuteStmt), dom.InfoSchema()) + extractor := plan.(*plannercore.Execute).Plan.(*plannercore.PhysicalMemTable).Extractor + ca.checker(extractor) + } +} From 21e90651f8d1e6754ec9c578ea74bf62bc92b129 Mon Sep 17 00:00:00 2001 From: djshow832 <873581766@qq.com> Date: Fri, 30 Dec 2022 20:48:18 +0800 Subject: [PATCH 2/3] add binary test --- expression/expression.go | 2 + .../core/memtable_predicate_extractor_test.go | 69 +++++++++++++++---- 2 files changed, 57 insertions(+), 14 deletions(-) diff --git a/expression/expression.go b/expression/expression.go index 00697c2df68ea..a942842bad5e6 100644 --- a/expression/expression.go +++ b/expression/expression.go @@ -1574,6 +1574,8 @@ func Args2Expressions4Test(args ...interface{}) []Expression { ft = types.NewFieldType(mysql.TypeDouble) case types.KindString: ft = types.NewFieldType(mysql.TypeVarString) + case types.KindMysqlTime: + ft = types.NewFieldType(mysql.TypeTimestamp) default: exprs[i] = nil continue diff --git a/planner/core/memtable_predicate_extractor_test.go b/planner/core/memtable_predicate_extractor_test.go index 6bf62b80eaa59..829eabe91b088 100644 --- a/planner/core/memtable_predicate_extractor_test.go +++ b/planner/core/memtable_predicate_extractor_test.go @@ -25,12 +25,14 @@ import ( "github.com/pingcap/tidb/domain" "github.com/pingcap/tidb/errno" + "github.com/pingcap/tidb/expression" "github.com/pingcap/tidb/parser" "github.com/pingcap/tidb/parser/ast" "github.com/pingcap/tidb/planner" plannercore "github.com/pingcap/tidb/planner/core" "github.com/pingcap/tidb/session" "github.com/pingcap/tidb/testkit" + "github.com/pingcap/tidb/types" "github.com/pingcap/tidb/util/hint" "github.com/pingcap/tidb/util/set" "github.com/stretchr/testify/require" @@ -1755,16 +1757,17 @@ PARTITION BY RANGE COLUMNS ( id ) ( func TestExtractorInPreparedStmt(t *testing.T) { store, dom := testkit.CreateMockStoreAndDomain(t) tk := testkit.NewTestKit(t, store) - tk.MustExec(fmt.Sprintf(`set @a=1, @b=2, @c="a%%", @d=cast('2019-10-10 10:10:10' as datetime)`)) var cases = []struct { prepared string - exec string + userVars []interface{} + params []interface{} checker func(extractor plannercore.MemTablePredicateExtractor) }{ { - prepared: "prepare stmt from 'select * from information_schema.TIKV_REGION_STATUS where table_id = ?'", - exec: "execute stmt using @a", + prepared: "select * from information_schema.TIKV_REGION_STATUS where table_id = ?", + userVars: []interface{}{1}, + params: []interface{}{1}, checker: func(extractor plannercore.MemTablePredicateExtractor) { rse := extractor.(*plannercore.TiKVRegionStatusExtractor) tableids := rse.GetTablesID() @@ -1773,8 +1776,9 @@ func TestExtractorInPreparedStmt(t *testing.T) { }, }, { - prepared: "prepare stmt from 'select * from information_schema.TIKV_REGION_STATUS where table_id = ? or table_id = ?'", - exec: "execute stmt using @a, @b", + prepared: "select * from information_schema.TIKV_REGION_STATUS where table_id = ? or table_id = ?", + userVars: []interface{}{1, 2}, + params: []interface{}{1, 2}, checker: func(extractor plannercore.MemTablePredicateExtractor) { rse := extractor.(*plannercore.TiKVRegionStatusExtractor) tableids := rse.GetTablesID() @@ -1783,8 +1787,9 @@ func TestExtractorInPreparedStmt(t *testing.T) { }, }, { - prepared: "prepare stmt from 'select * from information_schema.TIKV_REGION_STATUS where table_id in (?,?)'", - exec: "execute stmt using @a, @b", + prepared: "select * from information_schema.TIKV_REGION_STATUS where table_id in (?,?)", + userVars: []interface{}{1, 2}, + params: []interface{}{1, 2}, checker: func(extractor plannercore.MemTablePredicateExtractor) { rse := extractor.(*plannercore.TiKVRegionStatusExtractor) tableids := rse.GetTablesID() @@ -1793,29 +1798,65 @@ func TestExtractorInPreparedStmt(t *testing.T) { }, }, { - prepared: "prepare stmt from 'select * from information_schema.COLUMNS where table_name like ?'", - exec: "execute stmt using @c", + prepared: "select * from information_schema.COLUMNS where table_name like ?", + userVars: []interface{}{`"a%"`}, + params: []interface{}{"a%"}, checker: func(extractor plannercore.MemTablePredicateExtractor) { rse := extractor.(*plannercore.ColumnsTableExtractor) require.EqualValues(t, []string{"a%"}, rse.TableNamePatterns) }, }, { - prepared: "prepare stmt from 'select * from information_schema.tidb_hot_regions_history where update_time>=?'", - exec: "execute stmt using @d", + prepared: "select * from information_schema.tidb_hot_regions_history where update_time>=?", + userVars: []interface{}{"cast('2019-10-10 10:10:10' as datetime)"}, + params: []interface{}{func() types.Time { + tt, err := types.ParseTimestamp(tk.Session().GetSessionVars().StmtCtx, "2019-10-10 10:10:10") + require.NoError(t, err) + return tt + }()}, checker: func(extractor plannercore.MemTablePredicateExtractor) { rse := extractor.(*plannercore.HotRegionsHistoryTableExtractor) require.Equal(t, timestamp(t, "2019-10-10 10:10:10"), rse.StartTime) }, }, } + + // text protocol parser := parser.New() for _, ca := range cases { - tk.MustExec(ca.prepared) - stmt, err := parser.ParseOneStmt(ca.exec, "", "") + tk.MustExec(fmt.Sprintf("prepare stmt from '%s'", ca.prepared)) + setStmt := "set " + exec := "execute stmt using " + for i, uv := range ca.userVars { + name := fmt.Sprintf("@a%d", i) + setStmt += fmt.Sprintf("%s=%v", name, uv) + exec += name + if i != len(ca.userVars)-1 { + setStmt += "," + exec += "," + } + } + tk.MustExec(setStmt) + stmt, err := parser.ParseOneStmt(exec, "", "") require.NoError(t, err) plan, _, err := planner.OptimizeExecStmt(context.Background(), tk.Session(), stmt.(*ast.ExecuteStmt), dom.InfoSchema()) extractor := plan.(*plannercore.Execute).Plan.(*plannercore.PhysicalMemTable).Extractor ca.checker(extractor) } + + // binary protocol + for _, ca := range cases { + id, _, _, err := tk.Session().PrepareStmt(ca.prepared) + require.NoError(t, err) + prepStmt, err := tk.Session().GetSessionVars().GetPreparedStmtByID(id) + require.NoError(t, err) + params := expression.Args2Expressions4Test(ca.params...) + execStmt := &ast.ExecuteStmt{ + BinaryArgs: params, + PrepStmt: prepStmt, + } + plan, _, err := planner.OptimizeExecStmt(context.Background(), tk.Session(), execStmt, dom.InfoSchema()) + extractor := plan.(*plannercore.Execute).Plan.(*plannercore.PhysicalMemTable).Extractor + ca.checker(extractor) + } } From 5f493b53f2e5df59180d24f6af66aeeecb77b62c Mon Sep 17 00:00:00 2001 From: djshow832 <873581766@qq.com> Date: Fri, 30 Dec 2022 22:07:03 +0800 Subject: [PATCH 3/3] fix unused error --- planner/core/memtable_predicate_extractor_test.go | 2 ++ 1 file changed, 2 insertions(+) diff --git a/planner/core/memtable_predicate_extractor_test.go b/planner/core/memtable_predicate_extractor_test.go index 829eabe91b088..ecf7c47218ab9 100644 --- a/planner/core/memtable_predicate_extractor_test.go +++ b/planner/core/memtable_predicate_extractor_test.go @@ -1840,6 +1840,7 @@ func TestExtractorInPreparedStmt(t *testing.T) { stmt, err := parser.ParseOneStmt(exec, "", "") require.NoError(t, err) plan, _, err := planner.OptimizeExecStmt(context.Background(), tk.Session(), stmt.(*ast.ExecuteStmt), dom.InfoSchema()) + require.NoError(t, err) extractor := plan.(*plannercore.Execute).Plan.(*plannercore.PhysicalMemTable).Extractor ca.checker(extractor) } @@ -1856,6 +1857,7 @@ func TestExtractorInPreparedStmt(t *testing.T) { PrepStmt: prepStmt, } plan, _, err := planner.OptimizeExecStmt(context.Background(), tk.Session(), execStmt, dom.InfoSchema()) + require.NoError(t, err) extractor := plan.(*plannercore.Execute).Plan.(*plannercore.PhysicalMemTable).Extractor ca.checker(extractor) }