diff --git a/executor/prepared_test.go b/executor/prepared_test.go index 808cbee20a577..1f4771541361a 100644 --- a/executor/prepared_test.go +++ b/executor/prepared_test.go @@ -21,8 +21,10 @@ import ( "github.com/pingcap/errors" "github.com/pingcap/parser/terror" "github.com/pingcap/tidb/executor" + "github.com/pingcap/tidb/metrics" plannercore "github.com/pingcap/tidb/planner/core" "github.com/pingcap/tidb/util/testkit" + dto "github.com/prometheus/client_model/go" "golang.org/x/net/context" ) @@ -389,3 +391,179 @@ func (s *testSuite) TestPreparedIssue7579(c *C) { r.Check(nil) } } + +func (s *testSuite) TestPreparedInsert(c *C) { + orgEnable := plannercore.PreparedPlanCacheEnabled() + orgCapacity := plannercore.PreparedPlanCacheCapacity + defer func() { + plannercore.SetPreparedPlanCache(orgEnable) + plannercore.PreparedPlanCacheCapacity = orgCapacity + }() + metrics.PlanCacheCounter.Reset() + counter := metrics.PlanCacheCounter.WithLabelValues("prepare") + pb := &dto.Metric{} + flags := []bool{false, true} + for _, flag := range flags { + plannercore.SetPreparedPlanCache(flag) + plannercore.PreparedPlanCacheCapacity = 100 + tk := testkit.NewTestKit(c, s.store) + tk.MustExec("use test") + tk.MustExec("drop table if exists prepare_test") + tk.MustExec("create table prepare_test (id int PRIMARY KEY, c1 int)") + tk.MustExec(`prepare stmt_insert from 'insert into prepare_test values (?, ?)'`) + tk.MustExec(`set @a=1,@b=1; execute stmt_insert using @a, @b;`) + if flag { + counter.Write(pb) + hit := pb.GetCounter().GetValue() + c.Check(hit, Equals, float64(0)) + } + tk.MustExec(`set @a=2,@b=2; execute stmt_insert using @a, @b;`) + if flag { + counter.Write(pb) + hit := pb.GetCounter().GetValue() + c.Check(hit, Equals, float64(1)) + } + tk.MustExec(`set @a=3,@b=3; execute stmt_insert using @a, @b;`) + if flag { + counter.Write(pb) + hit := pb.GetCounter().GetValue() + c.Check(hit, Equals, float64(2)) + } + + result := tk.MustQuery("select id, c1 from prepare_test where id = ?", 1) + result.Check(testkit.Rows("1 1")) + result = tk.MustQuery("select id, c1 from prepare_test where id = ?", 2) + result.Check(testkit.Rows("2 2")) + result = tk.MustQuery("select id, c1 from prepare_test where id = ?", 3) + result.Check(testkit.Rows("3 3")) + + tk.MustExec(`prepare stmt_insert_select from 'insert into prepare_test (id, c1) select id + 100, c1 + 100 from prepare_test where id = ?'`) + tk.MustExec(`set @a=1; execute stmt_insert_select using @a;`) + if flag { + counter.Write(pb) + hit := pb.GetCounter().GetValue() + c.Check(hit, Equals, float64(2)) + } + tk.MustExec(`set @a=2; execute stmt_insert_select using @a;`) + if flag { + counter.Write(pb) + hit := pb.GetCounter().GetValue() + c.Check(hit, Equals, float64(3)) + } + tk.MustExec(`set @a=3; execute stmt_insert_select using @a;`) + if flag { + counter.Write(pb) + hit := pb.GetCounter().GetValue() + c.Check(hit, Equals, float64(4)) + } + + result = tk.MustQuery("select id, c1 from prepare_test where id = ?", 101) + result.Check(testkit.Rows("101 101")) + result = tk.MustQuery("select id, c1 from prepare_test where id = ?", 102) + result.Check(testkit.Rows("102 102")) + result = tk.MustQuery("select id, c1 from prepare_test where id = ?", 103) + result.Check(testkit.Rows("103 103")) + } +} + +func (s *testSuite) TestPreparedUpdate(c *C) { + orgEnable := plannercore.PreparedPlanCacheEnabled() + orgCapacity := plannercore.PreparedPlanCacheCapacity + defer func() { + plannercore.SetPreparedPlanCache(orgEnable) + plannercore.PreparedPlanCacheCapacity = orgCapacity + }() + metrics.PlanCacheCounter.Reset() + counter := metrics.PlanCacheCounter.WithLabelValues("prepare") + pb := &dto.Metric{} + flags := []bool{false, true} + for _, flag := range flags { + plannercore.SetPreparedPlanCache(flag) + plannercore.PreparedPlanCacheCapacity = 100 + tk := testkit.NewTestKit(c, s.store) + tk.MustExec("use test") + tk.MustExec("drop table if exists prepare_test") + tk.MustExec("create table prepare_test (id int PRIMARY KEY, c1 int)") + tk.MustExec(`insert into prepare_test values (1, 1)`) + tk.MustExec(`insert into prepare_test values (2, 2)`) + tk.MustExec(`insert into prepare_test values (3, 3)`) + + tk.MustExec(`prepare stmt_update from 'update prepare_test set c1 = c1 + ? where id = ?'`) + tk.MustExec(`set @a=1,@b=100; execute stmt_update using @b,@a;`) + if flag { + counter.Write(pb) + hit := pb.GetCounter().GetValue() + c.Check(hit, Equals, float64(0)) + } + tk.MustExec(`set @a=2,@b=200; execute stmt_update using @b,@a;`) + if flag { + counter.Write(pb) + hit := pb.GetCounter().GetValue() + c.Check(hit, Equals, float64(1)) + } + tk.MustExec(`set @a=3,@b=300; execute stmt_update using @b,@a;`) + if flag { + counter.Write(pb) + hit := pb.GetCounter().GetValue() + c.Check(hit, Equals, float64(2)) + } + + result := tk.MustQuery("select id, c1 from prepare_test where id = ?", 1) + result.Check(testkit.Rows("1 101")) + result = tk.MustQuery("select id, c1 from prepare_test where id = ?", 2) + result.Check(testkit.Rows("2 202")) + result = tk.MustQuery("select id, c1 from prepare_test where id = ?", 3) + result.Check(testkit.Rows("3 303")) + } +} + +func (s *testSuite) TestPreparedDelete(c *C) { + orgEnable := plannercore.PreparedPlanCacheEnabled() + orgCapacity := plannercore.PreparedPlanCacheCapacity + defer func() { + plannercore.SetPreparedPlanCache(orgEnable) + plannercore.PreparedPlanCacheCapacity = orgCapacity + }() + metrics.PlanCacheCounter.Reset() + counter := metrics.PlanCacheCounter.WithLabelValues("prepare") + pb := &dto.Metric{} + flags := []bool{false, true} + for _, flag := range flags { + plannercore.SetPreparedPlanCache(flag) + plannercore.PreparedPlanCacheCapacity = 100 + tk := testkit.NewTestKit(c, s.store) + tk.MustExec("use test") + tk.MustExec("drop table if exists prepare_test") + tk.MustExec("create table prepare_test (id int PRIMARY KEY, c1 int)") + tk.MustExec(`insert into prepare_test values (1, 1)`) + tk.MustExec(`insert into prepare_test values (2, 2)`) + tk.MustExec(`insert into prepare_test values (3, 3)`) + + tk.MustExec(`prepare stmt_delete from 'delete from prepare_test where id = ?'`) + tk.MustExec(`set @a=1; execute stmt_delete using @a;`) + if flag { + counter.Write(pb) + hit := pb.GetCounter().GetValue() + c.Check(hit, Equals, float64(0)) + } + tk.MustExec(`set @a=2; execute stmt_delete using @a;`) + if flag { + counter.Write(pb) + hit := pb.GetCounter().GetValue() + c.Check(hit, Equals, float64(1)) + } + tk.MustExec(`set @a=3; execute stmt_delete using @a;`) + if flag { + counter.Write(pb) + hit := pb.GetCounter().GetValue() + c.Check(hit, Equals, float64(2)) + } + + result := tk.MustQuery("select id, c1 from prepare_test where id = ?", 1) + result.Check(nil) + result = tk.MustQuery("select id, c1 from prepare_test where id = ?", 2) + result.Check(nil) + result = tk.MustQuery("select id, c1 from prepare_test where id = ?", 3) + result.Check(nil) + } +} diff --git a/expression/constant.go b/expression/constant.go index 374bd2a434cea..90b562b7a60dc 100644 --- a/expression/constant.go +++ b/expression/constant.go @@ -106,7 +106,7 @@ func (c *Constant) Eval(_ chunk.Row) (types.Datum, error) { } val, err := dt.ConvertTo(sf.GetCtx().GetSessionVars().StmtCtx, retType) if err != nil { - return c.Value, err + return dt, err } c.Value.SetValue(val.GetValue()) } diff --git a/expression/simple_rewriter.go b/expression/simple_rewriter.go index 879d3e8121516..19bf512ab890f 100644 --- a/expression/simple_rewriter.go +++ b/expression/simple_rewriter.go @@ -150,9 +150,11 @@ func (sr *simpleRewriter) Leave(originInNode ast.Node) (retNode ast.Node, ok boo sr.inToExpression(len(v.List), v.Not, &v.Type) } case *driver.ParamMarkerExpr: - tp := types.NewFieldType(mysql.TypeUnspecified) - types.DefaultParamTypeForValue(v.GetValue(), tp) - value := &Constant{Value: v.ValueExpr.Datum, RetType: tp} + var value Expression + value, sr.err = GetParamExpression(sr.ctx, v, sr.useCache()) + if sr.err != nil { + return retNode, false + } sr.push(value) case *ast.RowExpr: sr.rowToScalarFunc(v) @@ -168,6 +170,10 @@ func (sr *simpleRewriter) Leave(originInNode ast.Node) (retNode ast.Node, ok boo return originInNode, true } +func (sr *simpleRewriter) useCache() bool { + return sr.ctx.GetSessionVars().StmtCtx.UseCache +} + func (sr *simpleRewriter) binaryOpToExpression(v *ast.BinaryOperationExpr) { right := sr.pop() left := sr.pop() diff --git a/expression/util.go b/expression/util.go index 679be67795bca..e5805fb3ddcab 100644 --- a/expression/util.go +++ b/expression/util.go @@ -26,6 +26,7 @@ import ( "github.com/pingcap/parser/terror" "github.com/pingcap/tidb/sessionctx" "github.com/pingcap/tidb/types" + "github.com/pingcap/tidb/types/parser_driver" "github.com/pingcap/tidb/util/chunk" "github.com/pingcap/tidb/util/hack" ) @@ -503,3 +504,25 @@ func ColumnSliceIsIntersect(s1, s2 []*Column) bool { } return false } + +// DatumToConstant generates a Constant expression from a Datum. +func DatumToConstant(d types.Datum, tp byte) *Constant { + return &Constant{Value: d, RetType: types.NewFieldType(tp)} +} + +// GetParamExpression generate a getparam function expression. +func GetParamExpression(ctx sessionctx.Context, v *driver.ParamMarkerExpr, useCache bool) (Expression, error) { + tp := types.NewFieldType(mysql.TypeUnspecified) + types.DefaultParamTypeForValue(v.GetValue(), tp) + value := &Constant{Value: v.Datum, RetType: tp} + if useCache { + f, err := NewFunctionBase(ctx, ast.GetParam, &v.Type, + DatumToConstant(types.NewIntDatum(int64(v.Order)), mysql.TypeLonglong)) + if err != nil { + return nil, errors.Trace(err) + } + f.GetType().Tp = v.Type.Tp + value.DeferredExpr = f + } + return value, nil +} diff --git a/planner/core/cacheable_checker.go b/planner/core/cacheable_checker.go index 38f417d804456..76edbd01a2d12 100644 --- a/planner/core/cacheable_checker.go +++ b/planner/core/cacheable_checker.go @@ -21,7 +21,11 @@ import ( // Cacheable checks whether the input ast is cacheable. func Cacheable(node ast.Node) bool { - if _, isSelect := node.(*ast.SelectStmt); !isSelect { + _, isSelect := node.(*ast.SelectStmt) + _, isUpdate := node.(*ast.UpdateStmt) + _, isInsert := node.(*ast.InsertStmt) + _, isDelete := node.(*ast.DeleteStmt) + if !(isSelect || isUpdate || isInsert || isDelete) { return false } checker := cacheableChecker{ diff --git a/planner/core/cacheable_checker_test.go b/planner/core/cacheable_checker_test.go index 9b6c1367e3042..67278f5cfe99b 100644 --- a/planner/core/cacheable_checker_test.go +++ b/planner/core/cacheable_checker_test.go @@ -27,27 +27,117 @@ type testCacheableSuite struct { } func (s *testCacheableSuite) TestCacheable(c *C) { - // test non-SelectStmt - var stmt ast.Node = &ast.DeleteStmt{} + // test non-SelectStmt/-InsertStmt/-DeleteStmt/-UpdateStmt/-SelectStmt + var stmt ast.Node = &ast.UnionStmt{} c.Assert(Cacheable(stmt), IsFalse) - stmt = &ast.InsertStmt{} + stmt = &ast.ShowStmt{} + c.Assert(Cacheable(stmt), IsFalse) + + stmt = &ast.LoadDataStmt{} c.Assert(Cacheable(stmt), IsFalse) - stmt = &ast.UnionStmt{} + tableRefsClause := &ast.TableRefsClause{TableRefs: &ast.Join{Left: &ast.TableSource{Source: &ast.TableName{}}}} + // test InsertStmt + stmt = &ast.InsertStmt{Table: tableRefsClause} + c.Assert(Cacheable(stmt), IsTrue) + + // test DeleteStmt + whereExpr := &ast.FuncCallExpr{} + stmt = &ast.DeleteStmt{ + TableRefs: tableRefsClause, + Where: whereExpr, + } + c.Assert(Cacheable(stmt), IsTrue) + + for funcName := range expression.UnCacheableFunctions { + whereExpr.FnName = model.NewCIStr(funcName) + c.Assert(Cacheable(stmt), IsFalse) + } + + whereExpr.FnName = model.NewCIStr(ast.Rand) + c.Assert(Cacheable(stmt), IsTrue) + + stmt = &ast.DeleteStmt{ + TableRefs: tableRefsClause, + Where: &ast.ExistsSubqueryExpr{}, + } c.Assert(Cacheable(stmt), IsFalse) - stmt = &ast.UpdateStmt{} + limitStmt := &ast.Limit{ + Count: &driver.ParamMarkerExpr{}, + } + stmt = &ast.DeleteStmt{ + TableRefs: tableRefsClause, + Limit: limitStmt, + } c.Assert(Cacheable(stmt), IsFalse) - stmt = &ast.ShowStmt{} + limitStmt = &ast.Limit{ + Offset: &driver.ParamMarkerExpr{}, + } + stmt = &ast.DeleteStmt{ + TableRefs: tableRefsClause, + Limit: limitStmt, + } c.Assert(Cacheable(stmt), IsFalse) - stmt = &ast.LoadDataStmt{} + limitStmt = &ast.Limit{} + stmt = &ast.DeleteStmt{ + TableRefs: tableRefsClause, + Limit: limitStmt, + } + c.Assert(Cacheable(stmt), IsTrue) + + // test UpdateStmt + whereExpr = &ast.FuncCallExpr{} + stmt = &ast.UpdateStmt{ + TableRefs: tableRefsClause, + Where: whereExpr, + } + c.Assert(Cacheable(stmt), IsTrue) + + for funcName := range expression.UnCacheableFunctions { + whereExpr.FnName = model.NewCIStr(funcName) + c.Assert(Cacheable(stmt), IsFalse) + } + + whereExpr.FnName = model.NewCIStr(ast.Rand) + c.Assert(Cacheable(stmt), IsTrue) + + stmt = &ast.UpdateStmt{ + TableRefs: tableRefsClause, + Where: &ast.ExistsSubqueryExpr{}, + } + c.Assert(Cacheable(stmt), IsFalse) + + limitStmt = &ast.Limit{ + Count: &driver.ParamMarkerExpr{}, + } + stmt = &ast.UpdateStmt{ + TableRefs: tableRefsClause, + Limit: limitStmt, + } + c.Assert(Cacheable(stmt), IsFalse) + + limitStmt = &ast.Limit{ + Offset: &driver.ParamMarkerExpr{}, + } + stmt = &ast.UpdateStmt{ + TableRefs: tableRefsClause, + Limit: limitStmt, + } c.Assert(Cacheable(stmt), IsFalse) + limitStmt = &ast.Limit{} + stmt = &ast.UpdateStmt{ + TableRefs: tableRefsClause, + Limit: limitStmt, + } + c.Assert(Cacheable(stmt), IsTrue) + // test SelectStmt - whereExpr := &ast.FuncCallExpr{} + whereExpr = &ast.FuncCallExpr{} stmt = &ast.SelectStmt{ Where: whereExpr, } @@ -66,7 +156,7 @@ func (s *testCacheableSuite) TestCacheable(c *C) { } c.Assert(Cacheable(stmt), IsFalse) - limitStmt := &ast.Limit{ + limitStmt = &ast.Limit{ Count: &driver.ParamMarkerExpr{}, } stmt = &ast.SelectStmt{ diff --git a/planner/core/common_plans.go b/planner/core/common_plans.go index f38d9594b4899..bf850652505d9 100644 --- a/planner/core/common_plans.go +++ b/planner/core/common_plans.go @@ -259,6 +259,18 @@ func (e *Execute) rebuildRange(p Plan) error { return errors.Trace(err) } } + case *Insert: + if x.SelectPlan != nil { + return e.rebuildRange(x.SelectPlan) + } + case *Update: + if x.SelectPlan != nil { + return e.rebuildRange(x.SelectPlan) + } + case *Delete: + if x.SelectPlan != nil { + return e.rebuildRange(x.SelectPlan) + } } return nil } diff --git a/planner/core/expression_rewriter.go b/planner/core/expression_rewriter.go index 991fcfc997607..9c2bf2b054315 100644 --- a/planner/core/expression_rewriter.go +++ b/planner/core/expression_rewriter.go @@ -757,7 +757,12 @@ func (er *expressionRewriter) Leave(originInNode ast.Node) (retNode ast.Node, ok value := &expression.Constant{Value: v.Datum, RetType: &v.Type} er.ctxStack = append(er.ctxStack, value) case *driver.ParamMarkerExpr: - er.paramToExpression(v) + var value expression.Expression + value, er.err = expression.GetParamExpression(er.ctx, v, er.useCache()) + if er.err != nil { + return retNode, false + } + er.ctxStack = append(er.ctxStack, value) case *ast.VariableExpr: er.rewriteVariable(v) case *ast.FuncCallExpr: @@ -810,24 +815,6 @@ func (er *expressionRewriter) useCache() bool { return er.ctx.GetSessionVars().StmtCtx.UseCache } -func datumToConstant(d types.Datum, tp byte) *expression.Constant { - return &expression.Constant{Value: d, RetType: types.NewFieldType(tp)} -} - -func (er *expressionRewriter) paramToExpression(v *driver.ParamMarkerExpr) { - tp := types.NewFieldType(mysql.TypeUnspecified) - types.DefaultParamTypeForValue(v.GetValue(), tp) - value := &expression.Constant{Value: v.Datum, RetType: tp} - if er.useCache() { - var f expression.Expression - f, er.err = expression.NewFunctionBase(er.ctx, ast.GetParam, &v.Type, - datumToConstant(types.NewIntDatum(int64(v.Order)), mysql.TypeLonglong)) - f.GetType().Tp = v.Type.Tp - value.DeferredExpr = f - } - er.ctxStack = append(er.ctxStack, value) -} - func (er *expressionRewriter) rewriteVariable(v *ast.VariableExpr) { stkLen := len(er.ctxStack) name := strings.ToLower(v.Name) @@ -837,7 +824,7 @@ func (er *expressionRewriter) rewriteVariable(v *ast.VariableExpr) { er.ctxStack[stkLen-1], er.err = expression.NewFunction(er.ctx, ast.SetVar, er.ctxStack[stkLen-1].GetType(), - datumToConstant(types.NewDatum(name), mysql.TypeString), + expression.DatumToConstant(types.NewDatum(name), mysql.TypeString), er.ctxStack[stkLen-1]) return } @@ -845,7 +832,7 @@ func (er *expressionRewriter) rewriteVariable(v *ast.VariableExpr) { ast.GetVar, // TODO: Here is wrong, the sessionVars should store a name -> Datum map. Will fix it later. types.NewFieldType(mysql.TypeString), - datumToConstant(types.NewStringDatum(name), mysql.TypeString)) + expression.DatumToConstant(types.NewStringDatum(name), mysql.TypeString)) if err != nil { er.err = errors.Trace(err) return @@ -871,7 +858,7 @@ func (er *expressionRewriter) rewriteVariable(v *ast.VariableExpr) { er.err = errors.Trace(err) return } - e := datumToConstant(types.NewStringDatum(val), mysql.TypeVarString) + e := expression.DatumToConstant(types.NewStringDatum(val), mysql.TypeVarString) e.RetType.Charset, _ = er.ctx.GetSessionVars().GetSystemVar(variable.CharacterSetConnection) e.RetType.Collate, _ = er.ctx.GetSessionVars().GetSystemVar(variable.CollationConnection) er.ctxStack = append(er.ctxStack, e) diff --git a/planner/core/point_get_plan.go b/planner/core/point_get_plan.go index 62c6c71816229..eb177aabfcc7a 100644 --- a/planner/core/point_get_plan.go +++ b/planner/core/point_get_plan.go @@ -430,10 +430,10 @@ func tryUpdatePointPlan(ctx sessionctx.Context, updateStmt *ast.UpdateStmt) Plan if orderedList == nil { return nil } - updatePlan := &Update{ + updatePlan := Update{ SelectPlan: fastSelect, OrderedList: orderedList, - } + }.Init(ctx) updatePlan.SetSchema(fastSelect.schema) return updatePlan } @@ -480,9 +480,9 @@ func tryDeletePointPlan(ctx sessionctx.Context, delStmt *ast.DeleteStmt) Plan { if checkFastPlanPrivilege(ctx, fastSelect, mysql.SelectPriv, mysql.DeletePriv) != nil { return nil } - delPlan := &Delete{ + delPlan := Delete{ SelectPlan: fastSelect, - } + }.Init(ctx) delPlan.SetSchema(fastSelect.schema) return delPlan }