From c96345877b76dd3c2975fa0e1d2abca16b7529be Mon Sep 17 00:00:00 2001 From: Jun-Seok Heo Date: Tue, 27 Nov 2018 09:54:07 +0900 Subject: [PATCH 01/11] executor,planner: make `order by ?` performs correctly when prepare-cache is enabled --- executor/prepared.go | 4 ++ executor/prepared_test.go | 50 ++++++++++++++++++++++ expression/simple_rewriter.go | 2 +- expression/util.go | 57 +++++++++++++++++++++++++- planner/core/cacheable_checker.go | 14 +++++++ planner/core/cacheable_checker_test.go | 14 +++++++ planner/core/expression_rewriter.go | 24 +++++++++-- planner/core/logical_plan_builder.go | 44 ++++++++++++++++++-- sessionctx/variable/session.go | 1 + 9 files changed, 201 insertions(+), 9 deletions(-) diff --git a/executor/prepared.go b/executor/prepared.go index d2cf1c6b9ba03..be2283ea9dae5 100644 --- a/executor/prepared.go +++ b/executor/prepared.go @@ -97,6 +97,10 @@ func NewPrepareExec(ctx sessionctx.Context, is infoschema.InfoSchema, sqlTxt str // Next implements the Executor Next interface. func (e *PrepareExec) Next(ctx context.Context, chk *chunk.Chunk) error { vars := e.ctx.GetSessionVars() + vars.InPrepare = true + defer func() { + vars.InPrepare = false + }() if e.ID != 0 { // Must be the case when we retry a prepare. // Make sure it is idempotent. diff --git a/executor/prepared_test.go b/executor/prepared_test.go index 07e54ef4ff356..89be1ee961f60 100644 --- a/executor/prepared_test.go +++ b/executor/prepared_test.go @@ -688,3 +688,53 @@ func (s *testSuite) TestPrepareDealloc(c *C) { tk.MustExec("deallocate prepare stmt4") c.Assert(tk.Se.PreparedPlanCache().Size(), Equals, 0) } + +func (s *testSuite) TestPreparedIssue8153(c *C) { + orgEnable := plannercore.PreparedPlanCacheEnabled() + orgCapacity := plannercore.PreparedPlanCacheCapacity + defer func() { + plannercore.SetPreparedPlanCache(orgEnable) + plannercore.PreparedPlanCacheCapacity = orgCapacity + }() + flags := []bool{false, true} + for _, flag := range flags { + plannercore.SetPreparedPlanCache(flag) + plannercore.PreparedPlanCacheCapacity = 100 + tk := testkit.NewTestKit(c, s.store) + tk.MustExec("use test") + tk.MustExec("drop table if exists t") + tk.MustExec("create table t (a int, b int)") + tk.MustExec("insert into t (a, b) values (1,3), (2,2), (3,1)") + + tk.MustExec(`prepare stmt from 'select * from t order by ? asc'`) + r := tk.MustQuery(`execute stmt using @param;`) + r.Check(testkit.Rows("1 3", "2 2", "3 1")) + + tk.MustExec(`set @param = 1`) + r = tk.MustQuery(`execute stmt using @param;`) + r.Check(testkit.Rows("1 3", "2 2", "3 1")) + + tk.MustExec(`set @param = 2`) + r = tk.MustQuery(`execute stmt using @param;`) + r.Check(testkit.Rows("3 1", "2 2", "1 3")) + + tk.MustExec(`set @param = 3`) + _, err := tk.Exec(`execute stmt using @param;`) + c.Assert(err.Error(), Equals, "[planner:1054]Unknown column '?' in 'order clause'") + + tk.MustExec(`set @param = '##'`) + r = tk.MustQuery(`execute stmt using @param;`) + r.Check(testkit.Rows("1 3", "2 2", "3 1")) + + tk.MustExec("insert into t (a, b) values (1,1), (1,2), (2,1), (2,3), (3,2), (3,3)") + tk.MustExec(`prepare stmt from 'select ?, sum(a) from t group by ?'`) + + tk.MustExec(`set @a=1,@b=1`) + r = tk.MustQuery(`execute stmt using @a,@b;`) + r.Check(testkit.Rows("1 18")) + + tk.MustExec(`set @a=1,@b=2`) + _, err = tk.Exec(`execute stmt using @a,@b;`) + c.Assert(err.Error(), Equals, "[planner:1056]Can't group on 'sum(a)'") + } +} diff --git a/expression/simple_rewriter.go b/expression/simple_rewriter.go index 46bdac88a329c..518c833452226 100644 --- a/expression/simple_rewriter.go +++ b/expression/simple_rewriter.go @@ -151,7 +151,7 @@ func (sr *simpleRewriter) Leave(originInNode ast.Node) (retNode ast.Node, ok boo } case *driver.ParamMarkerExpr: var value Expression - value, sr.err = GetParamExpression(sr.ctx, v, sr.useCache()) + value, sr.err = GetParamExpression(sr.ctx, v) if sr.err != nil { return retNode, false } diff --git a/expression/util.go b/expression/util.go index e5805fb3ddcab..ce36a5d494dd2 100644 --- a/expression/util.go +++ b/expression/util.go @@ -511,7 +511,8 @@ func DatumToConstant(d types.Datum, tp byte) *Constant { } // GetParamExpression generate a getparam function expression. -func GetParamExpression(ctx sessionctx.Context, v *driver.ParamMarkerExpr, useCache bool) (Expression, error) { +func GetParamExpression(ctx sessionctx.Context, v *driver.ParamMarkerExpr) (Expression, error) { + useCache := ctx.GetSessionVars().StmtCtx.UseCache tp := types.NewFieldType(mysql.TypeUnspecified) types.DefaultParamTypeForValue(v.GetValue(), tp) value := &Constant{Value: v.Datum, RetType: tp} @@ -526,3 +527,57 @@ func GetParamExpression(ctx sessionctx.Context, v *driver.ParamMarkerExpr, useCa } return value, nil } + +// ParamToByItemNode generate ByItem node from ParamMarkerExpr. +func ParamToByItemNode(ctx sessionctx.Context, v *driver.ParamMarkerExpr) (*ast.ByItem, bool, error) { + value, err := GetParamExpression(ctx, v) + if err != nil { + return nil, true, errors.Trace(err) + } + str, isNull, err := GetStringFromConstant(ctx, value) + if err != nil { + return nil, true, errors.Trace(err) + } + if isNull { + return nil, true, nil + } + pos, err := strconv.Atoi(str) + if err == nil { + byItem := &ast.ByItem{Expr: &ast.PositionExpr{N: pos, P: v}} + return byItem, false, nil + } + return nil, true, nil +} + +// GetStringFromConstant gets a string value from the Constant expression. +func GetStringFromConstant(ctx sessionctx.Context, value Expression) (string, bool, error) { + con, ok := value.(*Constant) + if !ok { + err := errors.Errorf("Not a Constant expression %+v", value) + return "", true, errors.Trace(err) + } + str, isNull, err := con.EvalString(ctx, chunk.Row{}) + if err != nil { + return "", true, errors.Trace(err) + } + if isNull { + return "", true, nil + } + return str, false, nil +} + +// GetIntFromConstant gets an interger value from the Constant expression. +func GetIntFromConstant(ctx sessionctx.Context, value Expression) (int, bool, error) { + str, isNull, err := GetStringFromConstant(ctx, value) + if err != nil { + return 0, true, errors.Trace(err) + } + if isNull { + return 0, true, nil + } + intNum, err := strconv.Atoi(str) + if err != nil { + return 0, true, errors.Trace(err) + } + return intNum, false, nil +} diff --git a/planner/core/cacheable_checker.go b/planner/core/cacheable_checker.go index 76edbd01a2d12..49e08eb8227b1 100644 --- a/planner/core/cacheable_checker.go +++ b/planner/core/cacheable_checker.go @@ -55,6 +55,20 @@ func (checker *cacheableChecker) Enter(in ast.Node) (out ast.Node, skipChildren checker.cacheable = false return in, true } + case *ast.OrderByClause: + for _, item := range node.Items { + if _, isParamMarker := item.Expr.(*driver.ParamMarkerExpr); isParamMarker { + checker.cacheable = false + return in, true + } + } + case *ast.GroupByClause: + for _, item := range node.Items { + if _, isParamMarker := item.Expr.(*driver.ParamMarkerExpr); isParamMarker { + checker.cacheable = false + return in, true + } + } case *ast.Limit: if node.Count != nil { if _, isParamMarker := node.Count.(*driver.ParamMarkerExpr); isParamMarker { diff --git a/planner/core/cacheable_checker_test.go b/planner/core/cacheable_checker_test.go index 67278f5cfe99b..8f3d287701533 100644 --- a/planner/core/cacheable_checker_test.go +++ b/planner/core/cacheable_checker_test.go @@ -177,4 +177,18 @@ func (s *testCacheableSuite) TestCacheable(c *C) { Limit: limitStmt, } c.Assert(Cacheable(stmt), IsTrue) + + paramExpr := &driver.ParamMarkerExpr{} + orderByClause := &ast.OrderByClause{Items: []*ast.ByItem{{Expr: paramExpr}}} + stmt = &ast.SelectStmt{ + OrderBy: orderByClause, + } + c.Assert(Cacheable(stmt), IsFalse) + + valExpr := &driver.ValueExpr{} + orderByClause = &ast.OrderByClause{Items: []*ast.ByItem{{Expr: valExpr}}} + stmt = &ast.SelectStmt{ + OrderBy: orderByClause, + } + c.Assert(Cacheable(stmt), IsTrue) } diff --git a/planner/core/expression_rewriter.go b/planner/core/expression_rewriter.go index 1da8e11ed81ce..8e203f576c0a3 100644 --- a/planner/core/expression_rewriter.go +++ b/planner/core/expression_rewriter.go @@ -756,7 +756,7 @@ func (er *expressionRewriter) Leave(originInNode ast.Node) (retNode ast.Node, ok er.ctxStack = append(er.ctxStack, value) case *driver.ParamMarkerExpr: var value expression.Expression - value, er.err = expression.GetParamExpression(er.ctx, v, er.useCache()) + value, er.err = expression.GetParamExpression(er.ctx, v) if er.err != nil { return retNode, false } @@ -941,10 +941,26 @@ func (er *expressionRewriter) isNullToExpression(v *ast.IsNullExpr) { } func (er *expressionRewriter) positionToScalarFunc(v *ast.PositionExpr) { - if v.N > 0 && v.N <= er.schema.Len() { - er.ctxStack = append(er.ctxStack, er.schema.Columns[v.N-1]) + pos := v.N + str := strconv.Itoa(pos) + stkLen := len(er.ctxStack) + if v.P != nil { + val := er.ctxStack[stkLen-1] + intNum, isNull, err := expression.GetIntFromConstant(er.ctx, val) + er.ctxStack = er.ctxStack[:stkLen-1] + str = "?" + if err == nil { + if isNull { + return + } + pos = intNum + } + er.err = err + } + if er.err == nil && pos > 0 && pos <= er.schema.Len() { + er.ctxStack = append(er.ctxStack, er.schema.Columns[pos-1]) } else { - er.err = ErrUnknownColumn.GenWithStackByArgs(strconv.Itoa(v.N), clauseMsg[er.b.curClause]) + er.err = ErrUnknownColumn.GenWithStackByArgs(str, clauseMsg[er.b.curClause]) } } diff --git a/planner/core/logical_plan_builder.go b/planner/core/logical_plan_builder.go index fd5b5ffd58460..b8c48b0e3559f 100644 --- a/planner/core/logical_plan_builder.go +++ b/planner/core/logical_plan_builder.go @@ -827,6 +827,21 @@ func (b *PlanBuilder) buildSort(p LogicalPlan, byItems []*ast.ByItem, aggMapper sort := LogicalSort{}.Init(b.ctx) exprs := make([]*ByItems, 0, len(byItems)) for _, item := range byItems { + switch x := item.Expr.(type) { + case *driver.ParamMarkerExpr: + if b.ctx.GetSessionVars().InPrepare { + continue + } + newItem, isNull, err := expression.ParamToByItemNode(b.ctx, x) + if err != nil { + err := errors.Errorf("Unknown column '%+v' in 'order clause'", "?") + return nil, errors.Trace(err) + } + if isNull { + continue + } + item = newItem + } it, np, err := b.rewrite(item.Expr, p, aggMapper, true) if err != nil { return nil, errors.Trace(err) @@ -1565,14 +1580,37 @@ func (b *PlanBuilder) resolveGbyExprs(p LogicalPlan, gby *ast.GroupByClause, fie schema: p.Schema(), } for _, item := range gby.Items { + hasParam := false resolver.inExpr = false - retExpr, _ := item.Expr.Accept(resolver) + var retExpr ast.Node + switch x := item.Expr.(type) { + case *driver.ParamMarkerExpr: + hasParam = true + if b.ctx.GetSessionVars().InPrepare { + continue + } + newItem, isNull, err := expression.ParamToByItemNode(b.ctx, x) + if err != nil { + err := errors.Errorf("Unknown column '%+v' in 'group statement'", "?") + return nil, nil, errors.Trace(err) + } + if isNull { + continue + } + retExpr, _ = newItem.Expr.Accept(resolver) + default: + retExpr, _ = item.Expr.Accept(resolver) + } + if resolver.err != nil { return nil, nil, errors.Trace(resolver.err) } + if !hasParam { + item.Expr = retExpr.(ast.ExprNode) + } - item.Expr = retExpr.(ast.ExprNode) - expr, np, err := b.rewrite(item.Expr, p, nil, true) + itemExpr := retExpr.(ast.ExprNode) + expr, np, err := b.rewrite(itemExpr, p, nil, true) if err != nil { return nil, nil, errors.Trace(err) } diff --git a/sessionctx/variable/session.go b/sessionctx/variable/session.go index 4b413bec8d087..7f562438b817f 100644 --- a/sessionctx/variable/session.go +++ b/sessionctx/variable/session.go @@ -174,6 +174,7 @@ type SessionVars struct { MemQuota BatchSize RetryLimit int64 + InPrepare bool DisableTxnAutoRetry bool // UsersLock is a lock for user defined variables. UsersLock sync.RWMutex From c42c2b2297d546e8bc46e2443f7c9a04655f3876 Mon Sep 17 00:00:00 2001 From: Jun-Seok Heo Date: Tue, 27 Nov 2018 18:45:33 +0900 Subject: [PATCH 02/11] address the comments --- expression/util.go | 27 +++++++-------------------- planner/core/expression_rewriter.go | 2 +- planner/core/logical_plan_builder.go | 4 ++-- 3 files changed, 10 insertions(+), 23 deletions(-) diff --git a/expression/util.go b/expression/util.go index ce36a5d494dd2..77399fc2e3df0 100644 --- a/expression/util.go +++ b/expression/util.go @@ -534,19 +534,12 @@ func ParamToByItemNode(ctx sessionctx.Context, v *driver.ParamMarkerExpr) (*ast. if err != nil { return nil, true, errors.Trace(err) } - str, isNull, err := GetStringFromConstant(ctx, value) - if err != nil { + pos, isNull, err := GetIntFromConstant(ctx, value) + if err != nil || isNull { return nil, true, errors.Trace(err) } - if isNull { - return nil, true, nil - } - pos, err := strconv.Atoi(str) - if err == nil { - byItem := &ast.ByItem{Expr: &ast.PositionExpr{N: pos, P: v}} - return byItem, false, nil - } - return nil, true, nil + byItem := &ast.ByItem{Expr: &ast.PositionExpr{N: pos, P: v}} + return byItem, false, nil } // GetStringFromConstant gets a string value from the Constant expression. @@ -557,12 +550,9 @@ func GetStringFromConstant(ctx sessionctx.Context, value Expression) (string, bo return "", true, errors.Trace(err) } str, isNull, err := con.EvalString(ctx, chunk.Row{}) - if err != nil { + if err != nil || isNull { return "", true, errors.Trace(err) } - if isNull { - return "", true, nil - } return str, false, nil } @@ -572,12 +562,9 @@ func GetIntFromConstant(ctx sessionctx.Context, value Expression) (int, bool, er if err != nil { return 0, true, errors.Trace(err) } - if isNull { - return 0, true, nil - } intNum, err := strconv.Atoi(str) - if err != nil { - return 0, true, errors.Trace(err) + if err != nil || isNull { + return 0, true, nil } return intNum, false, nil } diff --git a/planner/core/expression_rewriter.go b/planner/core/expression_rewriter.go index 8e203f576c0a3..db33738ad500c 100644 --- a/planner/core/expression_rewriter.go +++ b/planner/core/expression_rewriter.go @@ -943,8 +943,8 @@ func (er *expressionRewriter) isNullToExpression(v *ast.IsNullExpr) { func (er *expressionRewriter) positionToScalarFunc(v *ast.PositionExpr) { pos := v.N str := strconv.Itoa(pos) - stkLen := len(er.ctxStack) if v.P != nil { + stkLen := len(er.ctxStack) val := er.ctxStack[stkLen-1] intNum, isNull, err := expression.GetIntFromConstant(er.ctx, val) er.ctxStack = er.ctxStack[:stkLen-1] diff --git a/planner/core/logical_plan_builder.go b/planner/core/logical_plan_builder.go index b8c48b0e3559f..93479c0d93746 100644 --- a/planner/core/logical_plan_builder.go +++ b/planner/core/logical_plan_builder.go @@ -834,7 +834,7 @@ func (b *PlanBuilder) buildSort(p LogicalPlan, byItems []*ast.ByItem, aggMapper } newItem, isNull, err := expression.ParamToByItemNode(b.ctx, x) if err != nil { - err := errors.Errorf("Unknown column '%+v' in 'order clause'", "?") + err := ErrUnknownColumn.GenWithStackByArgs("?", clauseMsg[b.curClause]) return nil, errors.Trace(err) } if isNull { @@ -1591,7 +1591,7 @@ func (b *PlanBuilder) resolveGbyExprs(p LogicalPlan, gby *ast.GroupByClause, fie } newItem, isNull, err := expression.ParamToByItemNode(b.ctx, x) if err != nil { - err := errors.Errorf("Unknown column '%+v' in 'group statement'", "?") + err := ErrUnknownColumn.GenWithStackByArgs("?", clauseMsg[b.curClause]) return nil, nil, errors.Trace(err) } if isNull { From 7ff1e1ccaa2dbbc139c0c4c808b95546ffbc0232 Mon Sep 17 00:00:00 2001 From: Jun-Seok Heo Date: Tue, 27 Nov 2018 19:48:46 +0900 Subject: [PATCH 03/11] fix a potential sporadic CI error due to the PR #8339 --- executor/prepared_test.go | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/executor/prepared_test.go b/executor/prepared_test.go index 89be1ee961f60..1e56b37ee94e9 100644 --- a/executor/prepared_test.go +++ b/executor/prepared_test.go @@ -692,14 +692,21 @@ func (s *testSuite) TestPrepareDealloc(c *C) { func (s *testSuite) TestPreparedIssue8153(c *C) { orgEnable := plannercore.PreparedPlanCacheEnabled() orgCapacity := plannercore.PreparedPlanCacheCapacity + orgMemGuardRatio := plannercore.PreparedPlanCacheMemoryGuardRatio + orgMaxMemory := plannercore.PreparedPlanCacheMaxMemory defer func() { plannercore.SetPreparedPlanCache(orgEnable) plannercore.PreparedPlanCacheCapacity = orgCapacity + plannercore.PreparedPlanCacheMemoryGuardRatio = orgMemGuardRatio + plannercore.PreparedPlanCacheMaxMemory = orgMaxMemory }() flags := []bool{false, true} for _, flag := range flags { + var err error plannercore.SetPreparedPlanCache(flag) plannercore.PreparedPlanCacheCapacity = 100 + plannercore.PreparedPlanCacheMemoryGuardRatio = 0.1 + plannercore.PreparedPlanCacheMaxMemory, err = memory.MemTotal() tk := testkit.NewTestKit(c, s.store) tk.MustExec("use test") tk.MustExec("drop table if exists t") @@ -719,7 +726,7 @@ func (s *testSuite) TestPreparedIssue8153(c *C) { r.Check(testkit.Rows("3 1", "2 2", "1 3")) tk.MustExec(`set @param = 3`) - _, err := tk.Exec(`execute stmt using @param;`) + _, err = tk.Exec(`execute stmt using @param;`) c.Assert(err.Error(), Equals, "[planner:1054]Unknown column '?' in 'order clause'") tk.MustExec(`set @param = '##'`) From 2c1e00c586ddfc2daf5aeacb49efe2248af86cfe Mon Sep 17 00:00:00 2001 From: Jun-Seok Heo Date: Wed, 28 Nov 2018 15:26:48 +0900 Subject: [PATCH 04/11] correct the condition --- expression/util.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/expression/util.go b/expression/util.go index 77399fc2e3df0..e71d410e9baa9 100644 --- a/expression/util.go +++ b/expression/util.go @@ -559,11 +559,11 @@ func GetStringFromConstant(ctx sessionctx.Context, value Expression) (string, bo // GetIntFromConstant gets an interger value from the Constant expression. func GetIntFromConstant(ctx sessionctx.Context, value Expression) (int, bool, error) { str, isNull, err := GetStringFromConstant(ctx, value) - if err != nil { + if err != nil || isNull { return 0, true, errors.Trace(err) } intNum, err := strconv.Atoi(str) - if err != nil || isNull { + if err != nil { return 0, true, nil } return intNum, false, nil From 5716434f7308e42c98fff29b2dcedf9667a60580 Mon Sep 17 00:00:00 2001 From: Jun-Seok Heo Date: Thu, 29 Nov 2018 08:39:11 +0900 Subject: [PATCH 05/11] code clean up --- expression/util.go | 16 +++++++----- planner/core/expression_rewriter.go | 2 +- planner/core/logical_plan_builder.go | 39 ++++++++++++---------------- 3 files changed, 26 insertions(+), 31 deletions(-) diff --git a/expression/util.go b/expression/util.go index e71d410e9baa9..3293a48e2780a 100644 --- a/expression/util.go +++ b/expression/util.go @@ -528,18 +528,20 @@ func GetParamExpression(ctx sessionctx.Context, v *driver.ParamMarkerExpr) (Expr return value, nil } -// ParamToByItemNode generate ByItem node from ParamMarkerExpr. -func ParamToByItemNode(ctx sessionctx.Context, v *driver.ParamMarkerExpr) (*ast.ByItem, bool, error) { - value, err := GetParamExpression(ctx, v) +// PosFromPositionExpr generates a position value from PositionExpr. +func PosFromPositionExpr(ctx sessionctx.Context, v *ast.PositionExpr) (int, bool, error) { + if v.P == nil { + return v.N, false, nil + } + value, err := GetParamExpression(ctx, v.P.(*driver.ParamMarkerExpr)) if err != nil { - return nil, true, errors.Trace(err) + return 0, true, errors.Trace(err) } pos, isNull, err := GetIntFromConstant(ctx, value) if err != nil || isNull { - return nil, true, errors.Trace(err) + return 0, true, errors.Trace(err) } - byItem := &ast.ByItem{Expr: &ast.PositionExpr{N: pos, P: v}} - return byItem, false, nil + return pos, false, nil } // GetStringFromConstant gets a string value from the Constant expression. diff --git a/planner/core/expression_rewriter.go b/planner/core/expression_rewriter.go index db33738ad500c..dec7ba06af6cf 100644 --- a/planner/core/expression_rewriter.go +++ b/planner/core/expression_rewriter.go @@ -947,13 +947,13 @@ func (er *expressionRewriter) positionToScalarFunc(v *ast.PositionExpr) { stkLen := len(er.ctxStack) val := er.ctxStack[stkLen-1] intNum, isNull, err := expression.GetIntFromConstant(er.ctx, val) - er.ctxStack = er.ctxStack[:stkLen-1] str = "?" if err == nil { if isNull { return } pos = intNum + er.ctxStack = er.ctxStack[:stkLen-1] } er.err = err } diff --git a/planner/core/logical_plan_builder.go b/planner/core/logical_plan_builder.go index 93479c0d93746..32a3113968c2f 100644 --- a/planner/core/logical_plan_builder.go +++ b/planner/core/logical_plan_builder.go @@ -18,6 +18,7 @@ import ( "math" "math/bits" "reflect" + "strconv" "strings" "unicode" @@ -827,20 +828,12 @@ func (b *PlanBuilder) buildSort(p LogicalPlan, byItems []*ast.ByItem, aggMapper sort := LogicalSort{}.Init(b.ctx) exprs := make([]*ByItems, 0, len(byItems)) for _, item := range byItems { - switch x := item.Expr.(type) { + switch v := item.Expr.(type) { case *driver.ParamMarkerExpr: if b.ctx.GetSessionVars().InPrepare { continue } - newItem, isNull, err := expression.ParamToByItemNode(b.ctx, x) - if err != nil { - err := ErrUnknownColumn.GenWithStackByArgs("?", clauseMsg[b.curClause]) - return nil, errors.Trace(err) - } - if isNull { - continue - } - item = newItem + item = &ast.ByItem{Expr: &ast.PositionExpr{P: v}} } it, np, err := b.rewrite(item.Expr, p, aggMapper, true) if err != nil { @@ -1167,6 +1160,7 @@ func (b *PlanBuilder) extractAggFuncs(fields []*ast.SelectField) ([]*ast.Aggrega // gbyResolver resolves group by items from select fields. type gbyResolver struct { + ctx sessionctx.Context fields []*ast.SelectField schema *expression.Schema err error @@ -1211,14 +1205,19 @@ func (g *gbyResolver) Leave(inNode ast.Node) (ast.Node, bool) { return inNode, false } case *ast.PositionExpr: - if v.N < 1 || v.N > len(g.fields) { - g.err = errors.Errorf("Unknown column '%d' in 'group statement'", v.N) + pos, isNull, err := expression.PosFromPositionExpr(g.ctx, v) + if err != nil || isNull { + g.err = ErrUnknownColumn.GenWithStackByArgs("?", clauseMsg[groupByClause]) + return inNode, false + } + if pos < 1 || pos > len(g.fields) { + g.err = ErrUnknownColumn.GenWithStackByArgs(strconv.Itoa(v.N), clauseMsg[groupByClause]) return inNode, false } - ret := g.fields[v.N-1].Expr + ret := g.fields[pos-1].Expr ret.Accept(extractor) if len(extractor.AggFuncs) != 0 { - g.err = ErrWrongGroupField.GenWithStackByArgs(g.fields[v.N-1].Text()) + g.err = ErrWrongGroupField.GenWithStackByArgs(g.fields[pos-1].Text()) return inNode, false } return ret, true @@ -1576,6 +1575,7 @@ func (b *PlanBuilder) resolveGbyExprs(p LogicalPlan, gby *ast.GroupByClause, fie b.curClause = groupByClause exprs := make([]expression.Expression, 0, len(gby.Items)) resolver := &gbyResolver{ + ctx: b.ctx, fields: fields, schema: p.Schema(), } @@ -1583,20 +1583,13 @@ func (b *PlanBuilder) resolveGbyExprs(p LogicalPlan, gby *ast.GroupByClause, fie hasParam := false resolver.inExpr = false var retExpr ast.Node - switch x := item.Expr.(type) { + switch v := item.Expr.(type) { case *driver.ParamMarkerExpr: hasParam = true if b.ctx.GetSessionVars().InPrepare { continue } - newItem, isNull, err := expression.ParamToByItemNode(b.ctx, x) - if err != nil { - err := ErrUnknownColumn.GenWithStackByArgs("?", clauseMsg[b.curClause]) - return nil, nil, errors.Trace(err) - } - if isNull { - continue - } + newItem := &ast.ByItem{Expr: &ast.PositionExpr{P: v}} retExpr, _ = newItem.Expr.Accept(resolver) default: retExpr, _ = item.Expr.Accept(resolver) From 07ae3cae8c6ff68d2072c9ac07b949e3985f283f Mon Sep 17 00:00:00 2001 From: Jun-Seok Heo Date: Thu, 29 Nov 2018 17:24:15 +0900 Subject: [PATCH 06/11] cleaning the codes --- executor/executor.go | 4 ++ executor/prepared.go | 4 -- expression/util.go | 11 +++++ planner/core/logical_plan_builder.go | 62 ++++++++++++++++------------ sessionctx/stmtctx/stmtctx.go | 1 + sessionctx/variable/session.go | 1 - 6 files changed, 51 insertions(+), 32 deletions(-) diff --git a/executor/executor.go b/executor/executor.go index f561b7b4f8112..809f0bbff29ea 100644 --- a/executor/executor.go +++ b/executor/executor.go @@ -1234,6 +1234,7 @@ func ResetContextOfStmt(ctx sessionctx.Context, s ast.StmtNode) (err error) { sc.MemTracker = memory.NewTracker(s.Text(), vars.MemQuotaQuery) sc.NowTs = time.Time{} sc.SysTs = time.Time{} + sc.InPreparedStmt = vars.StmtCtx.InPreparedStmt switch config.GetGlobalConfig().OOMAction { case config.OOMActionCancel: sc.MemTracker.SetActionOnExceed(&memory.PanicOnExceed{}) @@ -1245,6 +1246,7 @@ func ResetContextOfStmt(ctx sessionctx.Context, s ast.StmtNode) (err error) { if execStmt, ok := s.(*ast.ExecuteStmt); ok { s, err = getPreparedStmt(execStmt, vars) + sc.InPreparedStmt = false } // TODO: Many same bool variables here. // We should set only two variables ( @@ -1305,6 +1307,8 @@ func ResetContextOfStmt(ctx sessionctx.Context, s ast.StmtNode) (err error) { sc.InShowWarning = true sc.SetWarnings(vars.StmtCtx.GetWarnings()) } + case *ast.PrepareStmt: + sc.InPreparedStmt = true default: sc.IgnoreTruncate = true sc.IgnoreZeroInDate = true diff --git a/executor/prepared.go b/executor/prepared.go index be2283ea9dae5..d2cf1c6b9ba03 100644 --- a/executor/prepared.go +++ b/executor/prepared.go @@ -97,10 +97,6 @@ func NewPrepareExec(ctx sessionctx.Context, is infoschema.InfoSchema, sqlTxt str // Next implements the Executor Next interface. func (e *PrepareExec) Next(ctx context.Context, chk *chunk.Chunk) error { vars := e.ctx.GetSessionVars() - vars.InPrepare = true - defer func() { - vars.InPrepare = false - }() if e.ID != 0 { // Must be the case when we retry a prepare. // Make sure it is idempotent. diff --git a/expression/util.go b/expression/util.go index 3293a48e2780a..875be2ceb5449 100644 --- a/expression/util.go +++ b/expression/util.go @@ -528,6 +528,17 @@ func GetParamExpression(ctx sessionctx.Context, v *driver.ParamMarkerExpr) (Expr return value, nil } +// ConvertToByItemExpr rewrites ByItem.ExprNode to a proper ExprNode. +func ConvertToByItemExpr(ctx sessionctx.Context, n ast.Node) ast.Node { + switch v := n.(type) { + case *driver.ParamMarkerExpr: + if !ctx.GetSessionVars().StmtCtx.InPreparedStmt { + return &ast.PositionExpr{P: v} + } + } + return n +} + // PosFromPositionExpr generates a position value from PositionExpr. func PosFromPositionExpr(ctx sessionctx.Context, v *ast.PositionExpr) (int, bool, error) { if v.P == nil { diff --git a/planner/core/logical_plan_builder.go b/planner/core/logical_plan_builder.go index 32a3113968c2f..f42b0e83687ed 100644 --- a/planner/core/logical_plan_builder.go +++ b/planner/core/logical_plan_builder.go @@ -819,6 +819,26 @@ func (by *ByItems) Clone() *ByItems { return &ByItems{Expr: by.Expr.Clone(), Desc: by.Desc} } +// itemTransformer transforms ParamMarkerExpr to PositionExpr in the context of ByItem +type itemTransformer struct { + ctx sessionctx.Context + isParam bool +} + +func (t *itemTransformer) Enter(inNode ast.Node) (ast.Node, bool) { + switch inNode.(type) { + case *driver.ParamMarkerExpr: + newNode := expression.ConvertToByItemExpr(t.ctx, inNode) + t.isParam = true + return newNode, true + } + return inNode, false +} + +func (t *itemTransformer) Leave(inNode ast.Node) (ast.Node, bool) { + return inNode, false +} + func (b *PlanBuilder) buildSort(p LogicalPlan, byItems []*ast.ByItem, aggMapper map[*ast.AggregateFuncExpr]int) (*LogicalSort, error) { if _, isUnion := p.(*LogicalUnionAll); isUnion { b.curClause = globalOrderByClause @@ -827,14 +847,10 @@ func (b *PlanBuilder) buildSort(p LogicalPlan, byItems []*ast.ByItem, aggMapper } sort := LogicalSort{}.Init(b.ctx) exprs := make([]*ByItems, 0, len(byItems)) + transformer := &itemTransformer{ctx: b.ctx} for _, item := range byItems { - switch v := item.Expr.(type) { - case *driver.ParamMarkerExpr: - if b.ctx.GetSessionVars().InPrepare { - continue - } - item = &ast.ByItem{Expr: &ast.PositionExpr{P: v}} - } + newExpr, _ := item.Expr.Accept(transformer) + item.Expr = newExpr.(ast.ExprNode) it, np, err := b.rewrite(item.Expr, p, aggMapper, true) if err != nil { return nil, errors.Trace(err) @@ -1160,17 +1176,22 @@ func (b *PlanBuilder) extractAggFuncs(fields []*ast.SelectField) ([]*ast.Aggrega // gbyResolver resolves group by items from select fields. type gbyResolver struct { - ctx sessionctx.Context - fields []*ast.SelectField - schema *expression.Schema - err error - inExpr bool + ctx sessionctx.Context + fields []*ast.SelectField + schema *expression.Schema + err error + inExpr bool + isParam bool } func (g *gbyResolver) Enter(inNode ast.Node) (ast.Node, bool) { switch inNode.(type) { case *ast.SubqueryExpr, *ast.CompareSubqueryExpr, *ast.ExistsSubqueryExpr: return inNode, true + case *driver.ParamMarkerExpr: + newNode := expression.ConvertToByItemExpr(g.ctx, inNode) + g.isParam = true + return newNode, true case *driver.ValueExpr, *ast.ColumnNameExpr, *ast.ParenthesesExpr, *ast.ColumnName: default: g.inExpr = true @@ -1580,25 +1601,12 @@ func (b *PlanBuilder) resolveGbyExprs(p LogicalPlan, gby *ast.GroupByClause, fie schema: p.Schema(), } for _, item := range gby.Items { - hasParam := false resolver.inExpr = false - var retExpr ast.Node - switch v := item.Expr.(type) { - case *driver.ParamMarkerExpr: - hasParam = true - if b.ctx.GetSessionVars().InPrepare { - continue - } - newItem := &ast.ByItem{Expr: &ast.PositionExpr{P: v}} - retExpr, _ = newItem.Expr.Accept(resolver) - default: - retExpr, _ = item.Expr.Accept(resolver) - } - + retExpr, _ := item.Expr.Accept(resolver) if resolver.err != nil { return nil, nil, errors.Trace(resolver.err) } - if !hasParam { + if !resolver.isParam { item.Expr = retExpr.(ast.ExprNode) } diff --git a/sessionctx/stmtctx/stmtctx.go b/sessionctx/stmtctx/stmtctx.go index ef416642c2c62..2c8966f5ce8fd 100644 --- a/sessionctx/stmtctx/stmtctx.go +++ b/sessionctx/stmtctx/stmtctx.go @@ -61,6 +61,7 @@ type StatementContext struct { PadCharToFullLength bool BatchCheck bool InNullRejectCheck bool + InPreparedStmt bool // mu struct holds variables that change during execution. mu struct { diff --git a/sessionctx/variable/session.go b/sessionctx/variable/session.go index 7f562438b817f..4b413bec8d087 100644 --- a/sessionctx/variable/session.go +++ b/sessionctx/variable/session.go @@ -174,7 +174,6 @@ type SessionVars struct { MemQuota BatchSize RetryLimit int64 - InPrepare bool DisableTxnAutoRetry bool // UsersLock is a lock for user defined variables. UsersLock sync.RWMutex From 14201c7b89126229f22e10434bb8242039515e6a Mon Sep 17 00:00:00 2001 From: Jun-Seok Heo Date: Thu, 29 Nov 2018 01:09:36 +0900 Subject: [PATCH 07/11] fix a potential error --- planner/core/logical_plan_builder.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/planner/core/logical_plan_builder.go b/planner/core/logical_plan_builder.go index f42b0e83687ed..3676ddf5c7881 100644 --- a/planner/core/logical_plan_builder.go +++ b/planner/core/logical_plan_builder.go @@ -1232,7 +1232,7 @@ func (g *gbyResolver) Leave(inNode ast.Node) (ast.Node, bool) { return inNode, false } if pos < 1 || pos > len(g.fields) { - g.err = ErrUnknownColumn.GenWithStackByArgs(strconv.Itoa(v.N), clauseMsg[groupByClause]) + g.err = ErrUnknownColumn.GenWithStackByArgs(strconv.Itoa(pos), clauseMsg[groupByClause]) return inNode, false } ret := g.fields[pos-1].Expr From 3f78986d6263404f33a437161535254f96b8d882 Mon Sep 17 00:00:00 2001 From: Jun-Seok Heo Date: Thu, 29 Nov 2018 01:53:33 +0900 Subject: [PATCH 08/11] fix error messages --- planner/core/errors.go | 2 ++ planner/core/logical_plan_builder.go | 5 ++--- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/planner/core/errors.go b/planner/core/errors.go index 89c81fa47b086..48645a062988a 100644 --- a/planner/core/errors.go +++ b/planner/core/errors.go @@ -28,6 +28,7 @@ const ( codeWrongUsage = mysql.ErrWrongUsage codeAmbiguous = mysql.ErrNonUniq + codeUnknown = mysql.ErrUnknown codeUnknownColumn = mysql.ErrBadField codeUnknownTable = mysql.ErrUnknownTable codeWrongArguments = mysql.ErrWrongArguments @@ -64,6 +65,7 @@ var ( ErrWrongUsage = terror.ClassOptimizer.New(codeWrongUsage, mysql.MySQLErrName[mysql.ErrWrongUsage]) ErrAmbiguous = terror.ClassOptimizer.New(codeAmbiguous, mysql.MySQLErrName[mysql.ErrNonUniq]) + ErrUnknown = terror.ClassOptimizer.New(codeUnknown, mysql.MySQLErrName[mysql.ErrUnknown]) ErrUnknownColumn = terror.ClassOptimizer.New(codeUnknownColumn, mysql.MySQLErrName[mysql.ErrBadField]) ErrUnknownTable = terror.ClassOptimizer.New(codeUnknownTable, mysql.MySQLErrName[mysql.ErrUnknownTable]) ErrWrongArguments = terror.ClassOptimizer.New(codeWrongArguments, mysql.MySQLErrName[mysql.ErrWrongArguments]) diff --git a/planner/core/logical_plan_builder.go b/planner/core/logical_plan_builder.go index 3676ddf5c7881..8594a732f05df 100644 --- a/planner/core/logical_plan_builder.go +++ b/planner/core/logical_plan_builder.go @@ -18,7 +18,6 @@ import ( "math" "math/bits" "reflect" - "strconv" "strings" "unicode" @@ -1228,11 +1227,11 @@ func (g *gbyResolver) Leave(inNode ast.Node) (ast.Node, bool) { case *ast.PositionExpr: pos, isNull, err := expression.PosFromPositionExpr(g.ctx, v) if err != nil || isNull { - g.err = ErrUnknownColumn.GenWithStackByArgs("?", clauseMsg[groupByClause]) + g.err = ErrUnknown.GenWithStackByArgs() return inNode, false } if pos < 1 || pos > len(g.fields) { - g.err = ErrUnknownColumn.GenWithStackByArgs(strconv.Itoa(pos), clauseMsg[groupByClause]) + g.err = errors.Errorf("Unknown column '%d' in 'group statement'", pos) return inNode, false } ret := g.fields[pos-1].Expr From 105d4f4b589e15821eeef62a16323a723d52144e Mon Sep 17 00:00:00 2001 From: Jun-Seok Heo Date: Sat, 1 Dec 2018 07:55:38 +0900 Subject: [PATCH 09/11] remove StmtCtx.InPrepare and set the default parameter value as `NULL` --- executor/executor.go | 4 --- executor/prepared.go | 3 +- expression/util.go | 6 ++-- planner/core/logical_plan_builder.go | 46 +++++++++++++++++++--------- planner/core/point_get_plan.go | 3 +- sessionctx/stmtctx/stmtctx.go | 1 - 6 files changed, 36 insertions(+), 27 deletions(-) diff --git a/executor/executor.go b/executor/executor.go index 809f0bbff29ea..f561b7b4f8112 100644 --- a/executor/executor.go +++ b/executor/executor.go @@ -1234,7 +1234,6 @@ func ResetContextOfStmt(ctx sessionctx.Context, s ast.StmtNode) (err error) { sc.MemTracker = memory.NewTracker(s.Text(), vars.MemQuotaQuery) sc.NowTs = time.Time{} sc.SysTs = time.Time{} - sc.InPreparedStmt = vars.StmtCtx.InPreparedStmt switch config.GetGlobalConfig().OOMAction { case config.OOMActionCancel: sc.MemTracker.SetActionOnExceed(&memory.PanicOnExceed{}) @@ -1246,7 +1245,6 @@ func ResetContextOfStmt(ctx sessionctx.Context, s ast.StmtNode) (err error) { if execStmt, ok := s.(*ast.ExecuteStmt); ok { s, err = getPreparedStmt(execStmt, vars) - sc.InPreparedStmt = false } // TODO: Many same bool variables here. // We should set only two variables ( @@ -1307,8 +1305,6 @@ func ResetContextOfStmt(ctx sessionctx.Context, s ast.StmtNode) (err error) { sc.InShowWarning = true sc.SetWarnings(vars.StmtCtx.GetWarnings()) } - case *ast.PrepareStmt: - sc.InPreparedStmt = true default: sc.IgnoreTruncate = true sc.IgnoreZeroInDate = true diff --git a/executor/prepared.go b/executor/prepared.go index d2cf1c6b9ba03..bdf1e5288d69a 100644 --- a/executor/prepared.go +++ b/executor/prepared.go @@ -26,7 +26,6 @@ import ( plannercore "github.com/pingcap/tidb/planner/core" "github.com/pingcap/tidb/sessionctx" "github.com/pingcap/tidb/sessionctx/variable" - "github.com/pingcap/tidb/types" "github.com/pingcap/tidb/types/parser_driver" "github.com/pingcap/tidb/util/chunk" "github.com/pingcap/tidb/util/sqlexec" @@ -161,7 +160,7 @@ func (e *PrepareExec) Next(ctx context.Context, chk *chunk.Chunk) error { // We try to build the real statement of preparedStmt. for i := range prepared.Params { - prepared.Params[i].(*driver.ParamMarkerExpr).Datum = types.NewIntDatum(0) + prepared.Params[i].(*driver.ParamMarkerExpr).Datum.SetNull() } var p plannercore.Plan p, err = plannercore.BuildLogicalPlan(e.ctx, stmt, e.is) diff --git a/expression/util.go b/expression/util.go index 875be2ceb5449..7fdab614a7030 100644 --- a/expression/util.go +++ b/expression/util.go @@ -529,12 +529,10 @@ func GetParamExpression(ctx sessionctx.Context, v *driver.ParamMarkerExpr) (Expr } // ConvertToByItemExpr rewrites ByItem.ExprNode to a proper ExprNode. -func ConvertToByItemExpr(ctx sessionctx.Context, n ast.Node) ast.Node { +func ConvertToByItemExpr(n ast.Node) ast.Node { switch v := n.(type) { case *driver.ParamMarkerExpr: - if !ctx.GetSessionVars().StmtCtx.InPreparedStmt { - return &ast.PositionExpr{P: v} - } + return &ast.PositionExpr{P: v} } return n } diff --git a/planner/core/logical_plan_builder.go b/planner/core/logical_plan_builder.go index 8594a732f05df..78b15c4a6594d 100644 --- a/planner/core/logical_plan_builder.go +++ b/planner/core/logical_plan_builder.go @@ -34,7 +34,6 @@ import ( "github.com/pingcap/tidb/infoschema" "github.com/pingcap/tidb/metrics" "github.com/pingcap/tidb/sessionctx" - "github.com/pingcap/tidb/sessionctx/stmtctx" "github.com/pingcap/tidb/statistics" "github.com/pingcap/tidb/table" "github.com/pingcap/tidb/types" @@ -820,15 +819,12 @@ func (by *ByItems) Clone() *ByItems { // itemTransformer transforms ParamMarkerExpr to PositionExpr in the context of ByItem type itemTransformer struct { - ctx sessionctx.Context - isParam bool } func (t *itemTransformer) Enter(inNode ast.Node) (ast.Node, bool) { switch inNode.(type) { case *driver.ParamMarkerExpr: - newNode := expression.ConvertToByItemExpr(t.ctx, inNode) - t.isParam = true + newNode := expression.ConvertToByItemExpr(inNode) return newNode, true } return inNode, false @@ -846,7 +842,7 @@ func (b *PlanBuilder) buildSort(p LogicalPlan, byItems []*ast.ByItem, aggMapper } sort := LogicalSort{}.Init(b.ctx) exprs := make([]*ByItems, 0, len(byItems)) - transformer := &itemTransformer{ctx: b.ctx} + transformer := &itemTransformer{} for _, item := range byItems { newExpr, _ := item.Expr.Accept(transformer) item.Expr = newExpr.(ast.ExprNode) @@ -866,7 +862,27 @@ func (b *PlanBuilder) buildSort(p LogicalPlan, byItems []*ast.ByItem, aggMapper // getUintForLimitOffset gets uint64 value for limit/offset. // For ordinary statement, limit/offset should be uint64 constant value. // For prepared statement, limit/offset is string. We should convert it to uint64. -func getUintForLimitOffset(sc *stmtctx.StatementContext, val interface{}) (uint64, error) { +func getUintForLimitOffset(ctx sessionctx.Context, n ast.Node) (uint64, error) { + var val interface{} + switch v := n.(type) { + case *driver.ValueExpr: + val = v.GetValue() + case *driver.ParamMarkerExpr: + param, err := expression.GetParamExpression(ctx, v) + if err != nil { + return 0, errors.Trace(err) + } + str, isNull, err := expression.GetStringFromConstant(ctx, param) + if err != nil { + return 0, errors.Trace(err) + } + if isNull { + return 0, nil + } + val = str + default: + return 0, errors.Errorf("Invalid type %T for LogicalLimit/Offset", v) + } switch v := val.(type) { case uint64: return v, nil @@ -875,22 +891,23 @@ func getUintForLimitOffset(sc *stmtctx.StatementContext, val interface{}) (uint6 return uint64(v), nil } case string: + sc := ctx.GetSessionVars().StmtCtx uVal, err := types.StrToUint(sc, v) return uVal, errors.Trace(err) } return 0, errors.Errorf("Invalid type %T for LogicalLimit/Offset", val) } -func extractLimitCountOffset(sc *stmtctx.StatementContext, limit *ast.Limit) (count uint64, +func extractLimitCountOffset(ctx sessionctx.Context, limit *ast.Limit) (count uint64, offset uint64, err error) { if limit.Count != nil { - count, err = getUintForLimitOffset(sc, limit.Count.(ast.ValueExpr).GetValue()) + count, err = getUintForLimitOffset(ctx, limit.Count) if err != nil { return 0, 0, ErrWrongArguments.GenWithStackByArgs("LIMIT") } } if limit.Offset != nil { - offset, err = getUintForLimitOffset(sc, limit.Offset.(ast.ValueExpr).GetValue()) + offset, err = getUintForLimitOffset(ctx, limit.Offset) if err != nil { return 0, 0, ErrWrongArguments.GenWithStackByArgs("LIMIT") } @@ -904,8 +921,7 @@ func (b *PlanBuilder) buildLimit(src LogicalPlan, limit *ast.Limit) (LogicalPlan offset, count uint64 err error ) - sc := b.ctx.GetSessionVars().StmtCtx - if count, offset, err = extractLimitCountOffset(sc, limit); err != nil { + if count, offset, err = extractLimitCountOffset(b.ctx, limit); err != nil { return nil, err } @@ -1188,7 +1204,7 @@ func (g *gbyResolver) Enter(inNode ast.Node) (ast.Node, bool) { case *ast.SubqueryExpr, *ast.CompareSubqueryExpr, *ast.ExistsSubqueryExpr: return inNode, true case *driver.ParamMarkerExpr: - newNode := expression.ConvertToByItemExpr(g.ctx, inNode) + newNode := expression.ConvertToByItemExpr(inNode) g.isParam = true return newNode, true case *driver.ValueExpr, *ast.ColumnNameExpr, *ast.ParenthesesExpr, *ast.ColumnName: @@ -1226,8 +1242,10 @@ func (g *gbyResolver) Leave(inNode ast.Node) (ast.Node, bool) { } case *ast.PositionExpr: pos, isNull, err := expression.PosFromPositionExpr(g.ctx, v) - if err != nil || isNull { + if err != nil { g.err = ErrUnknown.GenWithStackByArgs() + } + if err != nil || isNull { return inNode, false } if pos < 1 || pos > len(g.fields) { diff --git a/planner/core/point_get_plan.go b/planner/core/point_get_plan.go index eb177aabfcc7a..ea73910f3d645 100644 --- a/planner/core/point_get_plan.go +++ b/planner/core/point_get_plan.go @@ -148,8 +148,7 @@ func tryPointGetPlan(ctx sessionctx.Context, selStmt *ast.SelectStmt) *PointGetP if selStmt.Having != nil || selStmt.LockTp != ast.SelectLockNone { return nil } else if selStmt.Limit != nil { - sc := ctx.GetSessionVars().StmtCtx - count, offset, err := extractLimitCountOffset(sc, selStmt.Limit) + count, offset, err := extractLimitCountOffset(ctx, selStmt.Limit) if err != nil || count == 0 || offset > 0 { return nil } diff --git a/sessionctx/stmtctx/stmtctx.go b/sessionctx/stmtctx/stmtctx.go index 2c8966f5ce8fd..ef416642c2c62 100644 --- a/sessionctx/stmtctx/stmtctx.go +++ b/sessionctx/stmtctx/stmtctx.go @@ -61,7 +61,6 @@ type StatementContext struct { PadCharToFullLength bool BatchCheck bool InNullRejectCheck bool - InPreparedStmt bool // mu struct holds variables that change during execution. mu struct { From 253106d7be6a457ff0bb63832d3538137c1eec8f Mon Sep 17 00:00:00 2001 From: Jun-Seok Heo Date: Mon, 3 Dec 2018 14:57:27 +0900 Subject: [PATCH 10/11] remove the unnecessary trace --- expression/util.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/expression/util.go b/expression/util.go index 7fdab614a7030..0b3d7b5f095cd 100644 --- a/expression/util.go +++ b/expression/util.go @@ -544,7 +544,7 @@ func PosFromPositionExpr(ctx sessionctx.Context, v *ast.PositionExpr) (int, bool } value, err := GetParamExpression(ctx, v.P.(*driver.ParamMarkerExpr)) if err != nil { - return 0, true, errors.Trace(err) + return 0, true, err } pos, isNull, err := GetIntFromConstant(ctx, value) if err != nil || isNull { From 7ac5db04e364ff46f3885296216a1bf6c1abf96c Mon Sep 17 00:00:00 2001 From: Jun-Seok Heo Date: Mon, 3 Dec 2018 22:25:15 +0900 Subject: [PATCH 11/11] clean up the codes --- expression/util.go | 10 +++------- planner/core/logical_plan_builder.go | 8 ++++---- 2 files changed, 7 insertions(+), 11 deletions(-) diff --git a/expression/util.go b/expression/util.go index 0b3d7b5f095cd..c5e301b232d40 100644 --- a/expression/util.go +++ b/expression/util.go @@ -528,13 +528,9 @@ func GetParamExpression(ctx sessionctx.Context, v *driver.ParamMarkerExpr) (Expr return value, nil } -// ConvertToByItemExpr rewrites ByItem.ExprNode to a proper ExprNode. -func ConvertToByItemExpr(n ast.Node) ast.Node { - switch v := n.(type) { - case *driver.ParamMarkerExpr: - return &ast.PositionExpr{P: v} - } - return n +// ConstructPositionExpr constructs PositionExpr with the given ParamMarkerExpr. +func ConstructPositionExpr(p *driver.ParamMarkerExpr) *ast.PositionExpr { + return &ast.PositionExpr{P: p} } // PosFromPositionExpr generates a position value from PositionExpr. diff --git a/planner/core/logical_plan_builder.go b/planner/core/logical_plan_builder.go index 78b15c4a6594d..8179ee33fbf87 100644 --- a/planner/core/logical_plan_builder.go +++ b/planner/core/logical_plan_builder.go @@ -822,9 +822,9 @@ type itemTransformer struct { } func (t *itemTransformer) Enter(inNode ast.Node) (ast.Node, bool) { - switch inNode.(type) { + switch n := inNode.(type) { case *driver.ParamMarkerExpr: - newNode := expression.ConvertToByItemExpr(inNode) + newNode := expression.ConstructPositionExpr(n) return newNode, true } return inNode, false @@ -1200,11 +1200,11 @@ type gbyResolver struct { } func (g *gbyResolver) Enter(inNode ast.Node) (ast.Node, bool) { - switch inNode.(type) { + switch n := inNode.(type) { case *ast.SubqueryExpr, *ast.CompareSubqueryExpr, *ast.ExistsSubqueryExpr: return inNode, true case *driver.ParamMarkerExpr: - newNode := expression.ConvertToByItemExpr(inNode) + newNode := expression.ConstructPositionExpr(n) g.isParam = true return newNode, true case *driver.ValueExpr, *ast.ColumnNameExpr, *ast.ParenthesesExpr, *ast.ColumnName: