From 0081e17b02ff66ebbcd35b6c3a4eb21e899a729c Mon Sep 17 00:00:00 2001 From: wuudjac Date: Wed, 13 Feb 2019 19:53:28 +0800 Subject: [PATCH] expression, planner: support builtin function benchmark (#9252) --- expression/builtin_info.go | 99 ++++++++++++++++++++++++++++- expression/builtin_info_test.go | 39 ++++++++++-- expression/function_traits.go | 8 +++ expression/integration_test.go | 36 +++++++++++ planner/core/expression_rewriter.go | 72 +++++++++++++-------- planner/core/planbuilder_test.go | 76 ++++++++++++++++++++++ 6 files changed, 299 insertions(+), 31 deletions(-) diff --git a/expression/builtin_info.go b/expression/builtin_info.go index 8c4371bbfe802..a555bd08d26d5 100644 --- a/expression/builtin_info.go +++ b/expression/builtin_info.go @@ -384,7 +384,104 @@ type benchmarkFunctionClass struct { } func (c *benchmarkFunctionClass) getFunction(ctx sessionctx.Context, args []Expression) (builtinFunc, error) { - return nil, errFunctionNotExists.GenWithStackByArgs("FUNCTION", "BENCHMARK") + if err := c.verifyArgs(args); err != nil { + return nil, err + } + + // Syntax: BENCHMARK(loop_count, expression) + // Define with same eval type of input arg to avoid unnecessary cast function. + sameEvalType := args[1].GetType().EvalType() + bf := newBaseBuiltinFuncWithTp(ctx, args, types.ETInt, types.ETInt, sameEvalType) + sig := &builtinBenchmarkSig{bf} + return sig, nil +} + +type builtinBenchmarkSig struct { + baseBuiltinFunc +} + +func (b *builtinBenchmarkSig) Clone() builtinFunc { + newSig := &builtinBenchmarkSig{} + newSig.cloneFrom(&b.baseBuiltinFunc) + return newSig +} + +// evalInt evals a builtinBenchmarkSig. It will execute expression repeatedly count times. +// See https://dev.mysql.com/doc/refman/5.7/en/information-functions.html#function_benchmark +func (b *builtinBenchmarkSig) evalInt(row chunk.Row) (int64, bool, error) { + // Get loop count. + loopCount, isNull, err := b.args[0].EvalInt(b.ctx, row) + if isNull || err != nil { + return 0, isNull, err + } + + // BENCHMARK() will return NULL if loop count < 0, + // behavior observed on MySQL 5.7.24. + if loopCount < 0 { + return 0, true, nil + } + + // Eval loop count times based on arg type. + // BENCHMARK() will pass-through the eval error, + // behavior observed on MySQL 5.7.24. + var i int64 + arg, ctx := b.args[1], b.ctx + switch evalType := arg.GetType().EvalType(); evalType { + case types.ETInt: + for ; i < loopCount; i++ { + _, isNull, err = arg.EvalInt(ctx, row) + if err != nil { + return 0, isNull, err + } + } + case types.ETReal: + for ; i < loopCount; i++ { + _, isNull, err = arg.EvalReal(ctx, row) + if err != nil { + return 0, isNull, err + } + } + case types.ETDecimal: + for ; i < loopCount; i++ { + _, isNull, err = arg.EvalDecimal(ctx, row) + if err != nil { + return 0, isNull, err + } + } + case types.ETString: + for ; i < loopCount; i++ { + _, isNull, err = arg.EvalString(ctx, row) + if err != nil { + return 0, isNull, err + } + } + case types.ETDatetime, types.ETTimestamp: + for ; i < loopCount; i++ { + _, isNull, err = arg.EvalTime(ctx, row) + if err != nil { + return 0, isNull, err + } + } + case types.ETDuration: + for ; i < loopCount; i++ { + _, isNull, err = arg.EvalDuration(ctx, row) + if err != nil { + return 0, isNull, err + } + } + case types.ETJson: + for ; i < loopCount; i++ { + _, isNull, err = arg.EvalJSON(ctx, row) + if err != nil { + return 0, isNull, err + } + } + default: // Should never go into here. + return 0, true, errors.Errorf("EvalType %v not implemented for builtin BENCHMARK()", evalType) + } + + // Return value of BENCHMARK() is always 0. + return 0, false, nil } type charsetFunctionClass struct { diff --git a/expression/builtin_info_test.go b/expression/builtin_info_test.go index 6bc1a8805cca6..0cc6d0f57ed71 100644 --- a/expression/builtin_info_test.go +++ b/expression/builtin_info_test.go @@ -22,6 +22,7 @@ import ( "github.com/pingcap/parser/charset" "github.com/pingcap/parser/mysql" "github.com/pingcap/tidb/types" + "github.com/pingcap/tidb/types/json" "github.com/pingcap/tidb/util/chunk" "github.com/pingcap/tidb/util/mock" "github.com/pingcap/tidb/util/printer" @@ -120,10 +121,40 @@ func (s *testEvaluatorSuite) TestVersion(c *C) { func (s *testEvaluatorSuite) TestBenchMark(c *C) { defer testleak.AfterTest(c)() - fc := funcs[ast.Benchmark] - f, err := fc.getFunction(s.ctx, s.datumsToConstants(types.MakeDatums(nil, nil))) - c.Assert(f, IsNil) - c.Assert(err, ErrorMatches, "*FUNCTION BENCHMARK does not exist") + + cases := []struct { + LoopCount int + Expression interface{} + Expected int64 + IsNil bool + }{ + {-3, 1, 0, true}, + {0, 1, 0, false}, + {3, 1, 0, false}, + {3, 1.234, 0, false}, + {3, types.NewDecFromFloatForTest(1.234), 0, false}, + {3, "abc", 0, false}, + {3, types.CurrentTime(mysql.TypeDatetime), 0, false}, + {3, types.CurrentTime(mysql.TypeTimestamp), 0, false}, + {3, types.CurrentTime(mysql.TypeDuration), 0, false}, + {3, json.CreateBinary("[1]"), 0, false}, + } + + for _, t := range cases { + f, err := newFunctionForTest(s.ctx, ast.Benchmark, s.primitiveValsToConstants([]interface{}{ + t.LoopCount, + t.Expression, + })...) + c.Assert(err, IsNil) + + d, err := f.Eval(chunk.Row{}) + c.Assert(err, IsNil) + if t.IsNil { + c.Assert(d.IsNull(), IsTrue) + } else { + c.Assert(d.GetInt64(), Equals, t.Expected) + } + } } func (s *testEvaluatorSuite) TestCharset(c *C) { diff --git a/expression/function_traits.go b/expression/function_traits.go index de65e3297770f..d7c0ec881f8bd 100644 --- a/expression/function_traits.go +++ b/expression/function_traits.go @@ -40,6 +40,14 @@ var unFoldableFunctions = map[string]struct{}{ ast.SetVar: {}, ast.GetVar: {}, ast.GetParam: {}, + ast.Benchmark: {}, +} + +// DisableFoldFunctions stores functions which prevent child scope functions from being constant folded. +// Typically, these functions shall also exist in unFoldableFunctions, to stop from being folded when they themselves +// are in child scope of an outer function, and the outer function is recursively folding its children. +var DisableFoldFunctions = map[string]struct{}{ + ast.Benchmark: {}, } // DeferredFunctions stores non-deterministic functions, which can be deferred only when the plan cache is enabled. diff --git a/expression/integration_test.go b/expression/integration_test.go index 31e9e778b8b8d..6c6334fe833c4 100644 --- a/expression/integration_test.go +++ b/expression/integration_test.go @@ -2489,6 +2489,42 @@ func (s *testIntegrationSuite) TestInfoBuiltin(c *C) { result.Check(testkit.Rows("1")) result = tk.MustQuery("select row_count();") result.Check(testkit.Rows("-1")) + + // for benchmark + success := testkit.Rows("0") + tk.MustExec("drop table if exists t") + tk.MustExec("create table t (a int, b int)") + result = tk.MustQuery(`select benchmark(3, benchmark(2, length("abc")))`) + result.Check(success) + err := tk.ExecToErr(`select benchmark(3, length("a", "b"))`) + c.Assert(err, NotNil) + // Quoted from https://dev.mysql.com/doc/refman/5.7/en/information-functions.html#function_benchmark + // Although the expression can be a subquery, it must return a single column and at most a single row. + // For example, BENCHMARK(10, (SELECT * FROM t)) will fail if the table t has more than one column or + // more than one row. + oneColumnQuery := "select benchmark(10, (select a from t))" + twoColumnQuery := "select benchmark(10, (select * from t))" + // rows * columns: + // 0 * 1, success; + result = tk.MustQuery(oneColumnQuery) + result.Check(success) + // 0 * 2, error; + err = tk.ExecToErr(twoColumnQuery) + c.Assert(err, NotNil) + // 1 * 1, success; + tk.MustExec("insert t values (1, 2)") + result = tk.MustQuery(oneColumnQuery) + result.Check(success) + // 1 * 2, error; + err = tk.ExecToErr(twoColumnQuery) + c.Assert(err, NotNil) + // 2 * 1, error; + tk.MustExec("insert t values (3, 4)") + err = tk.ExecToErr(oneColumnQuery) + c.Assert(err, NotNil) + // 2 * 2, error. + err = tk.ExecToErr(twoColumnQuery) + c.Assert(err, NotNil) } func (s *testIntegrationSuite) TestControlBuiltin(c *C) { diff --git a/planner/core/expression_rewriter.go b/planner/core/expression_rewriter.go index 4d55f9b48a25c..6b7deb45f5dbc 100644 --- a/planner/core/expression_rewriter.go +++ b/planner/core/expression_rewriter.go @@ -29,7 +29,7 @@ import ( "github.com/pingcap/tidb/sessionctx/variable" "github.com/pingcap/tidb/table" "github.com/pingcap/tidb/types" - "github.com/pingcap/tidb/types/parser_driver" + driver "github.com/pingcap/tidb/types/parser_driver" "github.com/pingcap/tidb/util/chunk" "github.com/pingcap/tidb/util/stringutil" ) @@ -129,6 +129,7 @@ func (b *PlanBuilder) getExpressionRewriter(p LogicalPlan) (rewriter *expression rewriter.aggrMap = nil rewriter.preprocess = nil rewriter.insertPlan = nil + rewriter.disableFoldCounter = 0 rewriter.ctxStack = rewriter.ctxStack[:0] return } @@ -170,6 +171,12 @@ type expressionRewriter struct { // insertPlan is only used to rewrite the expressions inside the assignment // of the "INSERT" statement. insertPlan *Insert + + // disableFoldCounter controls fold-disabled scope. If > 0, rewriter will NOT do constant folding. + // Typically, during visiting AST, while entering the scope(disable), the counter will +1; while + // leaving the scope(enable again), the counter will -1. + // NOTE: This value can be changed during expression rewritten. + disableFoldCounter int } // constructBinaryOpFunction converts binary operator functions @@ -182,7 +189,7 @@ type expressionRewriter struct { func (er *expressionRewriter) constructBinaryOpFunction(l expression.Expression, r expression.Expression, op string) (expression.Expression, error) { lLen, rLen := expression.GetRowLen(l), expression.GetRowLen(r) if lLen == 1 && rLen == 1 { - return expression.NewFunction(er.ctx, op, types.NewFieldType(mysql.TypeTiny), l, r) + return er.newFunction(op, types.NewFieldType(mysql.TypeTiny), l, r) } else if rLen != lLen { return nil, expression.ErrOperandColumns.GenWithStackByArgs(lLen) } @@ -219,11 +226,11 @@ func (er *expressionRewriter) constructBinaryOpFunction(l expression.Expression, if err != nil { return nil, errors.Trace(err) } - expr5, err = expression.NewFunction(er.ctx, ast.If, types.NewFieldType(mysql.TypeTiny), expr3, expression.Null, expr4) + expr5, err = er.newFunction(ast.If, types.NewFieldType(mysql.TypeTiny), expr3, expression.Null, expr4) if err != nil { return nil, errors.Trace(err) } - return expression.NewFunction(er.ctx, ast.If, types.NewFieldType(mysql.TypeTiny), expr1, expr2, expr5) + return er.newFunction(ast.If, types.NewFieldType(mysql.TypeTiny), expr1, expr2, expr5) } } @@ -309,6 +316,10 @@ func (er *expressionRewriter) Enter(inNode ast.Node) (ast.Node, bool) { return inNode, true case *ast.WindowFuncExpr: return er.handleWindowFunction(v) + case *ast.FuncCallExpr: + if _, ok := expression.DisableFoldFunctions[v.FnName.L]; ok { + er.disableFoldCounter++ + } default: er.asScalar = true } @@ -362,7 +373,7 @@ func (er *expressionRewriter) handleCompareSubquery(v *ast.CompareSubqueryExpr) for _, col := range np.Schema().Columns { args = append(args, col) } - rexpr, er.err = expression.NewFunction(er.ctx, ast.RowFunc, args[0].GetType(), args...) + rexpr, er.err = er.newFunction(ast.RowFunc, args[0].GetType(), args...) if er.err != nil { er.err = errors.Trace(er.err) return v, true @@ -651,7 +662,7 @@ func (er *expressionRewriter) handleInSubquery(v *ast.PatternInExpr) (ast.Node, for _, col := range np.Schema().Columns { args = append(args, col) } - rexpr, er.err = expression.NewFunction(er.ctx, ast.RowFunc, args[0].GetType(), args...) + rexpr, er.err = er.newFunction(ast.RowFunc, args[0].GetType(), args...) if er.err != nil { er.err = errors.Trace(er.err) return v, true @@ -718,7 +729,7 @@ func (er *expressionRewriter) handleScalarSubquery(v *ast.SubqueryExpr) (ast.Nod for _, col := range np.Schema().Columns { newCols = append(newCols, col) } - expr, err1 := expression.NewFunction(er.ctx, ast.RowFunc, newCols[0].GetType(), newCols...) + expr, err1 := er.newFunction(ast.RowFunc, newCols[0].GetType(), newCols...) if err1 != nil { er.err = errors.Trace(err1) return v, true @@ -746,7 +757,7 @@ func (er *expressionRewriter) handleScalarSubquery(v *ast.SubqueryExpr) (ast.Nod Value: data, RetType: np.Schema().Columns[i].GetType()}) } - expr, err1 := expression.NewFunction(er.ctx, ast.RowFunc, newCols[0].GetType(), newCols...) + expr, err1 := er.newFunction(ast.RowFunc, newCols[0].GetType(), newCols...) if err1 != nil { er.err = errors.Trace(err1) return v, true @@ -787,6 +798,9 @@ func (er *expressionRewriter) Leave(originInNode ast.Node) (retNode ast.Node, ok er.rewriteVariable(v) case *ast.FuncCallExpr: er.funcCallToExpression(v) + if _, ok := expression.DisableFoldFunctions[v.FnName.L]; ok { + er.disableFoldCounter-- + } case *ast.ColumnName: er.toColumn(v) case *ast.UnaryOperationExpr: @@ -840,6 +854,14 @@ func (er *expressionRewriter) Leave(originInNode ast.Node) (retNode ast.Node, ok return originInNode, true } +// newFunction chooses which expression.NewFunctionImpl() will be used. +func (er *expressionRewriter) newFunction(funcName string, retType *types.FieldType, args ...expression.Expression) (expression.Expression, error) { + if er.disableFoldCounter > 0 { + return expression.NewFunctionBase(er.ctx, funcName, retType, args...) + } + return expression.NewFunction(er.ctx, funcName, retType, args...) +} + func (er *expressionRewriter) checkTimePrecision(ft *types.FieldType) error { if ft.EvalType() == types.ETDuration && ft.Decimal > types.MaxFsp { return errTooBigPrecision.GenWithStackByArgs(ft.Decimal, "CAST", types.MaxFsp) @@ -857,15 +879,13 @@ func (er *expressionRewriter) rewriteVariable(v *ast.VariableExpr) { sessionVars := er.b.ctx.GetSessionVars() if !v.IsSystem { if v.Value != nil { - er.ctxStack[stkLen-1], er.err = expression.NewFunction(er.ctx, - ast.SetVar, + er.ctxStack[stkLen-1], er.err = er.newFunction(ast.SetVar, er.ctxStack[stkLen-1].GetType(), expression.DatumToConstant(types.NewDatum(name), mysql.TypeString), er.ctxStack[stkLen-1]) return } - f, err := expression.NewFunction(er.ctx, - ast.GetVar, + f, err := er.newFunction(ast.GetVar, // TODO: Here is wrong, the sessionVars should store a name -> Datum map. Will fix it later. types.NewFieldType(mysql.TypeString), expression.DatumToConstant(types.NewStringDatum(name), mysql.TypeString)) @@ -927,7 +947,7 @@ func (er *expressionRewriter) unaryOpToExpression(v *ast.UnaryOperationExpr) { er.err = expression.ErrOperandColumns.GenWithStackByArgs(1) return } - er.ctxStack[stkLen-1], er.err = expression.NewFunction(er.ctx, op, &v.Type, er.ctxStack[stkLen-1]) + er.ctxStack[stkLen-1], er.err = er.newFunction(op, &v.Type, er.ctxStack[stkLen-1]) } func (er *expressionRewriter) binaryOpToExpression(v *ast.BinaryOperationExpr) { @@ -944,7 +964,7 @@ func (er *expressionRewriter) binaryOpToExpression(v *ast.BinaryOperationExpr) { er.err = expression.ErrOperandColumns.GenWithStackByArgs(1) return } - function, er.err = expression.NewFunction(er.ctx, v.Op.String(), types.NewFieldType(mysql.TypeUnspecified), er.ctxStack[stkLen-2:]...) + function, er.err = er.newFunction(v.Op.String(), types.NewFieldType(mysql.TypeUnspecified), er.ctxStack[stkLen-2:]...) } if er.err != nil { er.err = errors.Trace(er.err) @@ -956,7 +976,7 @@ func (er *expressionRewriter) binaryOpToExpression(v *ast.BinaryOperationExpr) { func (er *expressionRewriter) notToExpression(hasNot bool, op string, tp *types.FieldType, args ...expression.Expression) expression.Expression { - opFunc, err := expression.NewFunction(er.ctx, op, tp, args...) + opFunc, err := er.newFunction(op, tp, args...) if err != nil { er.err = errors.Trace(err) return nil @@ -965,7 +985,7 @@ func (er *expressionRewriter) notToExpression(hasNot bool, op string, tp *types. return opFunc } - opFunc, err = expression.NewFunction(er.ctx, ast.UnaryNot, tp, opFunc) + opFunc, err = er.newFunction(ast.UnaryNot, tp, opFunc) if err != nil { er.err = errors.Trace(err) return nil @@ -1073,7 +1093,7 @@ func (er *expressionRewriter) inToExpression(lLen int, not bool, tp *types.Field function = expression.ComposeDNFCondition(er.ctx, eqFunctions...) if not { var err error - function, err = expression.NewFunction(er.ctx, ast.UnaryNot, tp, function) + function, err = er.newFunction(ast.UnaryNot, tp, function) if err != nil { er.err = err return @@ -1107,7 +1127,7 @@ func (er *expressionRewriter) caseToExpression(v *ast.CaseExpr) { value := er.ctxStack[stkLen-argsLen-1] args = make([]expression.Expression, 0, argsLen) for i := stkLen - argsLen; i < stkLen-1; i += 2 { - arg, err := expression.NewFunction(er.ctx, ast.EQ, types.NewFieldType(mysql.TypeTiny), value, er.ctxStack[i]) + arg, err := er.newFunction(ast.EQ, types.NewFieldType(mysql.TypeTiny), value, er.ctxStack[i]) if err != nil { er.err = errors.Trace(err) return @@ -1126,7 +1146,7 @@ func (er *expressionRewriter) caseToExpression(v *ast.CaseExpr) { // else clause args = er.ctxStack[stkLen-argsLen:] } - function, err := expression.NewFunction(er.ctx, ast.Case, &v.Type, args...) + function, err := er.newFunction(ast.Case, &v.Type, args...) if err != nil { er.err = errors.Trace(err) return @@ -1196,7 +1216,7 @@ func (er *expressionRewriter) rowToScalarFunc(v *ast.RowExpr) { rows = append(rows, er.ctxStack[i]) } er.ctxStack = er.ctxStack[:stkLen-length] - function, err := expression.NewFunction(er.ctx, ast.RowFunc, rows[0].GetType(), rows...) + function, err := er.newFunction(ast.RowFunc, rows[0].GetType(), rows...) if err != nil { er.err = errors.Trace(err) return @@ -1212,22 +1232,22 @@ func (er *expressionRewriter) betweenToExpression(v *ast.BetweenExpr) { } var op string var l, r expression.Expression - l, er.err = expression.NewFunction(er.ctx, ast.GE, &v.Type, er.ctxStack[stkLen-3], er.ctxStack[stkLen-2]) + l, er.err = er.newFunction(ast.GE, &v.Type, er.ctxStack[stkLen-3], er.ctxStack[stkLen-2]) if er.err == nil { - r, er.err = expression.NewFunction(er.ctx, ast.LE, &v.Type, er.ctxStack[stkLen-3], er.ctxStack[stkLen-1]) + r, er.err = er.newFunction(ast.LE, &v.Type, er.ctxStack[stkLen-3], er.ctxStack[stkLen-1]) } op = ast.LogicAnd if er.err != nil { er.err = errors.Trace(er.err) return } - function, err := expression.NewFunction(er.ctx, op, &v.Type, l, r) + function, err := er.newFunction(op, &v.Type, l, r) if err != nil { er.err = errors.Trace(err) return } if v.Not { - function, err = expression.NewFunction(er.ctx, ast.UnaryNot, &v.Type, function) + function, err = er.newFunction(ast.UnaryNot, &v.Type, function) if err != nil { er.err = errors.Trace(err) return @@ -1284,7 +1304,7 @@ func (er *expressionRewriter) rewriteFuncCall(v *ast.FuncCallExpr) bool { RetType: nullTp, } // if(param1 = param2, NULL, param1) - funcIf, err := expression.NewFunction(er.ctx, ast.If, &v.Type, funcCompare, paramNull, param1) + funcIf, err := er.newFunction(ast.If, &v.Type, funcCompare, paramNull, param1) if err != nil { er.err = err return true @@ -1316,7 +1336,7 @@ func (er *expressionRewriter) funcCallToExpression(v *ast.FuncCallExpr) { c := &expression.Constant{Value: types.NewDatum(nil), RetType: function.GetType().Clone(), DeferredExpr: function} er.ctxStack = append(er.ctxStack, c) } else { - function, er.err = expression.NewFunction(er.ctx, v.FnName.L, &v.Type, args...) + function, er.err = er.newFunction(v.FnName.L, &v.Type, args...) er.ctxStack = append(er.ctxStack, function) } } diff --git a/planner/core/planbuilder_test.go b/planner/core/planbuilder_test.go index 49a029904e21a..8eb2e9623be53 100644 --- a/planner/core/planbuilder_test.go +++ b/planner/core/planbuilder_test.go @@ -15,8 +15,10 @@ package core import ( . "github.com/pingcap/check" + "github.com/pingcap/parser" "github.com/pingcap/parser/ast" "github.com/pingcap/parser/model" + "github.com/pingcap/tidb/expression" ) var _ = Suite(&testPlanBuilderSuite{}) @@ -89,3 +91,77 @@ func (s *testPlanBuilderSuite) TestGetPathByIndexName(c *C) { path = getPathByIndexName(accessPath, model.NewCIStr("primary"), tblInfo) c.Assert(path, IsNil) } + +func (s *testPlanBuilderSuite) TestRewriterPool(c *C) { + builder := &PlanBuilder{ + ctx: MockContext(), + } + + // Make sure PlanBuilder.getExpressionRewriter() provides clean rewriter from pool. + // First, pick one rewriter from the pool and make it dirty. + builder.rewriterCounter++ + dirtyRewriter := builder.getExpressionRewriter(nil) + dirtyRewriter.asScalar = true + dirtyRewriter.aggrMap = make(map[*ast.AggregateFuncExpr]int) + dirtyRewriter.preprocess = func(ast.Node) ast.Node { return nil } + dirtyRewriter.insertPlan = &Insert{} + dirtyRewriter.disableFoldCounter = 1 + dirtyRewriter.ctxStack = make([]expression.Expression, 2) + builder.rewriterCounter-- + // Then, pick again and check if it's cleaned up. + builder.rewriterCounter++ + cleanRewriter := builder.getExpressionRewriter(nil) + c.Assert(cleanRewriter, Equals, dirtyRewriter) // Rewriter should be reused. + c.Assert(cleanRewriter.asScalar, Equals, false) + c.Assert(cleanRewriter.aggrMap, IsNil) + c.Assert(cleanRewriter.preprocess, IsNil) + c.Assert(cleanRewriter.insertPlan, IsNil) + c.Assert(cleanRewriter.disableFoldCounter, Equals, 0) + c.Assert(len(cleanRewriter.ctxStack), Equals, 0) + builder.rewriterCounter-- +} + +func (s *testPlanBuilderSuite) TestDisableFold(c *C) { + // Functions like BENCHMARK() shall not be folded into result 0, + // but normal outer function with constant args should be folded. + // Types of expression and first layer of args will be validated. + cases := []struct { + SQL string + Expected expression.Expression + Args []expression.Expression + }{ + {`select sin(length("abc"))`, &expression.Constant{}, nil}, + {`select benchmark(3, sin(123))`, &expression.ScalarFunction{}, []expression.Expression{ + &expression.Constant{}, + &expression.ScalarFunction{}, + }}, + {`select pow(length("abc"), benchmark(3, sin(123)))`, &expression.ScalarFunction{}, []expression.Expression{ + &expression.Constant{}, + &expression.ScalarFunction{}, + }}, + } + + ctx := MockContext() + for _, t := range cases { + st, err := parser.New().ParseOneStmt(t.SQL, "", "") + c.Assert(err, IsNil) + stmt := st.(*ast.SelectStmt) + expr := stmt.Fields.Fields[0].Expr + + builder := &PlanBuilder{ctx: ctx} + builder.rewriterCounter++ + rewriter := builder.getExpressionRewriter(nil) + c.Assert(rewriter, NotNil) + c.Assert(rewriter.disableFoldCounter, Equals, 0) + rewritenExpression, _, err := builder.rewriteExprNode(rewriter, expr, true) + c.Assert(err, IsNil) + c.Assert(rewriter.disableFoldCounter, Equals, 0) // Make sure the counter is reduced to 0 in the end. + builder.rewriterCounter-- + + c.Assert(rewritenExpression, FitsTypeOf, t.Expected) + for i, expectedArg := range t.Args { + rewritenArg := expression.GetFuncArg(rewritenExpression, i) + c.Assert(rewritenArg, FitsTypeOf, expectedArg) + } + } +}