Skip to content

Commit

Permalink
planner: support pushing down predicates to memory tables in prepared…
Browse files Browse the repository at this point in the history
… mode (pingcap#40262)

close pingcap#39605
  • Loading branch information
djshow832 committed Jun 16, 2023
1 parent 413cbca commit e8d91df
Show file tree
Hide file tree
Showing 3 changed files with 127 additions and 4 deletions.
2 changes: 2 additions & 0 deletions expression/expression.go
Original file line number Diff line number Diff line change
Expand Up @@ -1565,6 +1565,8 @@ func Args2Expressions4Test(args ...interface{}) []Expression {
ft = types.NewFieldType(mysql.TypeDouble)
case types.KindString:
ft = types.NewFieldType(mysql.TypeVarString)
case types.KindMysqlTime:
ft = types.NewFieldType(mysql.TypeTimestamp)
default:
exprs[i] = nil
continue
Expand Down
16 changes: 12 additions & 4 deletions planner/core/memtable_predicate_extractor.go
Original file line number Diff line number Diff line change
Expand Up @@ -83,10 +83,14 @@ func (extractHelper) extractColInConsExpr(extractCols map[int64]*types.FieldName
results := make([]types.Datum, 0, len(args[1:]))
for _, arg := range args[1:] {
constant, ok := arg.(*expression.Constant)
if !ok || constant.DeferredExpr != nil || constant.ParamMarker != nil {
if !ok || constant.DeferredExpr != nil {
return "", nil
}
results = append(results, constant.Value)
v := constant.Value
if constant.ParamMarker != nil {
v = constant.ParamMarker.GetUserVar()
}
results = append(results, v)
}
return name.ColName.L, results
}
Expand Down Expand Up @@ -117,10 +121,14 @@ func (extractHelper) extractColBinaryOpConsExpr(extractCols map[int64]*types.Fie
// SELECT * FROM t1 WHERE c='rhs'
// SELECT * FROM t1 WHERE 'lhs'=c
constant, ok := args[1-colIdx].(*expression.Constant)
if !ok || constant.DeferredExpr != nil || constant.ParamMarker != nil {
if !ok || constant.DeferredExpr != nil {
return "", nil
}
return name.ColName.L, []types.Datum{constant.Value}
v := constant.Value
if constant.ParamMarker != nil {
v = constant.ParamMarker.GetUserVar()
}
return name.ColName.L, []types.Datum{v}
}

// extract the OR expression, e.g:
Expand Down
113 changes: 113 additions & 0 deletions planner/core/memtable_predicate_extractor_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,14 @@ import (

"github.com/pingcap/tidb/domain"
"github.com/pingcap/tidb/errno"
"github.com/pingcap/tidb/expression"
"github.com/pingcap/tidb/parser"
"github.com/pingcap/tidb/parser/ast"
"github.com/pingcap/tidb/planner"
plannercore "github.com/pingcap/tidb/planner/core"
"github.com/pingcap/tidb/session"
"github.com/pingcap/tidb/testkit"
"github.com/pingcap/tidb/types"
"github.com/pingcap/tidb/util/hint"
"github.com/pingcap/tidb/util/set"
"github.com/stretchr/testify/require"
Expand Down Expand Up @@ -1749,3 +1753,112 @@ PARTITION BY RANGE COLUMNS ( id ) (
require.Equal(t, ca.tableIDs, tableids)
}
}

func TestExtractorInPreparedStmt(t *testing.T) {
store, dom := testkit.CreateMockStoreAndDomain(t)
tk := testkit.NewTestKit(t, store)

var cases = []struct {
prepared string
userVars []interface{}
params []interface{}
checker func(extractor plannercore.MemTablePredicateExtractor)
}{
{
prepared: "select * from information_schema.TIKV_REGION_STATUS where table_id = ?",
userVars: []interface{}{1},
params: []interface{}{1},
checker: func(extractor plannercore.MemTablePredicateExtractor) {
rse := extractor.(*plannercore.TiKVRegionStatusExtractor)
tableids := rse.GetTablesID()
slices.Sort(tableids)
require.Equal(t, []int64{1}, tableids)
},
},
{
prepared: "select * from information_schema.TIKV_REGION_STATUS where table_id = ? or table_id = ?",
userVars: []interface{}{1, 2},
params: []interface{}{1, 2},
checker: func(extractor plannercore.MemTablePredicateExtractor) {
rse := extractor.(*plannercore.TiKVRegionStatusExtractor)
tableids := rse.GetTablesID()
slices.Sort(tableids)
require.Equal(t, []int64{1, 2}, tableids)
},
},
{
prepared: "select * from information_schema.TIKV_REGION_STATUS where table_id in (?,?)",
userVars: []interface{}{1, 2},
params: []interface{}{1, 2},
checker: func(extractor plannercore.MemTablePredicateExtractor) {
rse := extractor.(*plannercore.TiKVRegionStatusExtractor)
tableids := rse.GetTablesID()
slices.Sort(tableids)
require.Equal(t, []int64{1, 2}, tableids)
},
},
{
prepared: "select * from information_schema.COLUMNS where table_name like ?",
userVars: []interface{}{`"a%"`},
params: []interface{}{"a%"},
checker: func(extractor plannercore.MemTablePredicateExtractor) {
rse := extractor.(*plannercore.ColumnsTableExtractor)
require.EqualValues(t, []string{"a%"}, rse.TableNamePatterns)
},
},
{
prepared: "select * from information_schema.tidb_hot_regions_history where update_time>=?",
userVars: []interface{}{"cast('2019-10-10 10:10:10' as datetime)"},
params: []interface{}{func() types.Time {
tt, err := types.ParseTimestamp(tk.Session().GetSessionVars().StmtCtx, "2019-10-10 10:10:10")
require.NoError(t, err)
return tt
}()},
checker: func(extractor plannercore.MemTablePredicateExtractor) {
rse := extractor.(*plannercore.HotRegionsHistoryTableExtractor)
require.Equal(t, timestamp(t, "2019-10-10 10:10:10"), rse.StartTime)
},
},
}

// text protocol
parser := parser.New()
for _, ca := range cases {
tk.MustExec(fmt.Sprintf("prepare stmt from '%s'", ca.prepared))
setStmt := "set "
exec := "execute stmt using "
for i, uv := range ca.userVars {
name := fmt.Sprintf("@a%d", i)
setStmt += fmt.Sprintf("%s=%v", name, uv)
exec += name
if i != len(ca.userVars)-1 {
setStmt += ","
exec += ","
}
}
tk.MustExec(setStmt)
stmt, err := parser.ParseOneStmt(exec, "", "")
require.NoError(t, err)
plan, _, err := planner.OptimizeExecStmt(context.Background(), tk.Session(), stmt.(*ast.ExecuteStmt), dom.InfoSchema())
require.NoError(t, err)
extractor := plan.(*plannercore.Execute).Plan.(*plannercore.PhysicalMemTable).Extractor
ca.checker(extractor)
}

// binary protocol
for _, ca := range cases {
id, _, _, err := tk.Session().PrepareStmt(ca.prepared)
require.NoError(t, err)
prepStmt, err := tk.Session().GetSessionVars().GetPreparedStmtByID(id)
require.NoError(t, err)
params := expression.Args2Expressions4Test(ca.params...)
execStmt := &ast.ExecuteStmt{
BinaryArgs: params,
PrepStmt: prepStmt,
}
plan, _, err := planner.OptimizeExecStmt(context.Background(), tk.Session(), execStmt, dom.InfoSchema())
require.NoError(t, err)
extractor := plan.(*plannercore.Execute).Plan.(*plannercore.PhysicalMemTable).Extractor
ca.checker(extractor)
}
}

0 comments on commit e8d91df

Please sign in to comment.