Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

planner, expression: support builtin function NAME_CONST #9261

Merged
merged 13 commits into from
Feb 19, 2019
Merged
138 changes: 136 additions & 2 deletions expression/builtin_miscellaneous.go
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,14 @@ var (
_ builtinFunc = &builtinIsIPv4MappedSig{}
_ builtinFunc = &builtinIsIPv6Sig{}
_ builtinFunc = &builtinUUIDSig{}

_ builtinFunc = &builtinNameConstIntSig{}
_ builtinFunc = &builtinNameConstRealSig{}
_ builtinFunc = &builtinNameConstDecimalSig{}
_ builtinFunc = &builtinNameConstTimeSig{}
_ builtinFunc = &builtinNameConstDurationSig{}
_ builtinFunc = &builtinNameConstStringSig{}
_ builtinFunc = &builtinNameConstJSONSig{}
)

type sleepFunctionClass struct {
Expand Down Expand Up @@ -228,7 +236,7 @@ func (c *anyValueFunctionClass) getFunction(ctx sessionctx.Context, args []Expre
bf.tp.Charset, bf.tp.Collate, bf.tp.Flag = mysql.DefaultCharset, mysql.DefaultCollationName, 0
sig = &builtinTimeAnyValueSig{bf}
default:
panic("unexpected types.EvalType of builtin function ANY_VALUE")
return nil, errIncorrectArgs.GenWithStackByArgs("ANY_VALUE")
}
return sig, nil
}
Expand Down Expand Up @@ -808,7 +816,133 @@ type nameConstFunctionClass struct {
}

func (c *nameConstFunctionClass) getFunction(ctx sessionctx.Context, args []Expression) (builtinFunc, error) {
return nil, errFunctionNotExists.GenWithStackByArgs("FUNCTION", "NAME_CONST")
if err := c.verifyArgs(args); err != nil {
return nil, err
}
argTp := args[1].GetType().EvalType()
bf := newBaseBuiltinFuncWithTp(ctx, args, argTp, types.ETString, argTp)
*bf.tp = *args[1].GetType()
var sig builtinFunc
switch argTp {
case types.ETDecimal:
sig = &builtinNameConstDecimalSig{bf}
case types.ETDuration:
sig = &builtinNameConstDurationSig{bf}
case types.ETInt:
bf.tp.Decimal = 0
sig = &builtinNameConstIntSig{bf}
case types.ETJson:
sig = &builtinNameConstJSONSig{bf}
case types.ETReal:
sig = &builtinNameConstRealSig{bf}
case types.ETString:
bf.tp.Decimal = types.UnspecifiedLength
sig = &builtinNameConstStringSig{bf}
case types.ETDatetime, types.ETTimestamp:
bf.tp.Charset, bf.tp.Collate, bf.tp.Flag = mysql.DefaultCharset, mysql.DefaultCollationName, 0
sig = &builtinNameConstTimeSig{bf}
default:
return nil, errIncorrectArgs.GenWithStackByArgs("NAME_CONST")
}
return sig, nil
}

type builtinNameConstDecimalSig struct {
baseBuiltinFunc
}

func (b *builtinNameConstDecimalSig) Clone() builtinFunc {
newSig := &builtinNameConstDecimalSig{}
newSig.cloneFrom(&b.baseBuiltinFunc)
return newSig
}

func (b *builtinNameConstDecimalSig) evalDecimal(row chunk.Row) (*types.MyDecimal, bool, error) {
return b.args[1].EvalDecimal(b.ctx, row)
}

type builtinNameConstIntSig struct {
baseBuiltinFunc
}

func (b *builtinNameConstIntSig) Clone() builtinFunc {
newSig := &builtinNameConstIntSig{}
newSig.cloneFrom(&b.baseBuiltinFunc)
return newSig
}

func (b *builtinNameConstIntSig) evalInt(row chunk.Row) (int64, bool, error) {
return b.args[1].EvalInt(b.ctx, row)
}

type builtinNameConstRealSig struct {
baseBuiltinFunc
}

func (b *builtinNameConstRealSig) Clone() builtinFunc {
newSig := &builtinNameConstRealSig{}
newSig.cloneFrom(&b.baseBuiltinFunc)
return newSig
}

func (b *builtinNameConstRealSig) evalReal(row chunk.Row) (float64, bool, error) {
return b.args[1].EvalReal(b.ctx, row)
}

type builtinNameConstStringSig struct {
baseBuiltinFunc
}

func (b *builtinNameConstStringSig) Clone() builtinFunc {
newSig := &builtinNameConstStringSig{}
newSig.cloneFrom(&b.baseBuiltinFunc)
return newSig
}

func (b *builtinNameConstStringSig) evalString(row chunk.Row) (string, bool, error) {
return b.args[1].EvalString(b.ctx, row)
}

type builtinNameConstJSONSig struct {
baseBuiltinFunc
}

func (b *builtinNameConstJSONSig) Clone() builtinFunc {
newSig := &builtinNameConstJSONSig{}
newSig.cloneFrom(&b.baseBuiltinFunc)
return newSig
}

func (b *builtinNameConstJSONSig) evalJSON(row chunk.Row) (json.BinaryJSON, bool, error) {
return b.args[1].EvalJSON(b.ctx, row)
}

type builtinNameConstDurationSig struct {
baseBuiltinFunc
}

func (b *builtinNameConstDurationSig) Clone() builtinFunc {
newSig := &builtinNameConstDurationSig{}
newSig.cloneFrom(&b.baseBuiltinFunc)
return newSig
}

func (b *builtinNameConstDurationSig) evalDuration(row chunk.Row) (types.Duration, bool, error) {
return b.args[1].EvalDuration(b.ctx, row)
}

type builtinNameConstTimeSig struct {
baseBuiltinFunc
}

func (b *builtinNameConstTimeSig) Clone() builtinFunc {
newSig := &builtinNameConstTimeSig{}
newSig.cloneFrom(&b.baseBuiltinFunc)
return newSig
}

func (b *builtinNameConstTimeSig) evalTime(row chunk.Row) (types.Time, bool, error) {
return b.args[1].EvalTime(b.ctx, row)
}

type releaseAllLocksFunctionClass struct {
Expand Down
47 changes: 47 additions & 0 deletions expression/builtin_miscellaneous_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,11 @@ package expression
import (
"math"
"strings"
"time"

. "github.com/pingcap/check"
"github.com/pingcap/parser/ast"
"github.com/pingcap/parser/mysql"
"github.com/pingcap/tidb/types"
"github.com/pingcap/tidb/util/chunk"
"github.com/pingcap/tidb/util/testleak"
Expand Down Expand Up @@ -320,3 +322,48 @@ func (s *testEvaluatorSuite) TestIsIPv4Compat(c *C) {
c.Assert(err, IsNil)
c.Assert(r, testutil.DatumEquals, types.NewDatum(0))
}

func (s *testEvaluatorSuite) TestNameConst(c *C) {
defer testleak.AfterTest(c)()
dec := types.NewDecFromFloatForTest(123.123)
tm := types.Time{Time: types.FromGoTime(time.Now()), Fsp: 6, Type: mysql.TypeDatetime}
du := types.Duration{Duration: time.Duration(12*time.Hour + 1*time.Minute + 1*time.Second), Fsp: types.DefaultFsp}
cases := []struct {
colName string
arg interface{}
isNil bool
asserts func(d types.Datum)
}{
{"test_int", 3, false, func(d types.Datum) {
c.Assert(d.GetInt64(), Equals, int64(3))
}},
{"test_float", 3.14159, false, func(d types.Datum) {
c.Assert(d.GetFloat64(), Equals, 3.14159)
}},
{"test_string", "TiDB", false, func(d types.Datum) {
c.Assert(d.GetString(), Equals, "TiDB")
}},

{"test_null", nil, true, func(d types.Datum) {
c.Assert(d.Kind(), Equals, types.KindNull)
}},
{"test_decimal", dec, false, func(d types.Datum) {

spongedu marked this conversation as resolved.
Show resolved Hide resolved
c.Assert(d.GetMysqlDecimal().String(), Equals, dec.String())
}},
{"test_time", tm, false, func(d types.Datum) {
c.Assert(d.GetMysqlTime().String(), Equals, tm.String())
}},
{"test_duration", du, false, func(d types.Datum) {
c.Assert(d.GetMysqlDuration().String(), Equals, du.String())
}},
}

for _, t := range cases {
f, err := newFunctionForTest(s.ctx, ast.NameConst, s.primitiveValsToConstants([]interface{}{t.colName, t.arg})...)
c.Assert(err, IsNil)
d, err := f.Eval(chunk.Row{})
c.Assert(err, IsNil)
t.asserts(d)
}
}
9 changes: 7 additions & 2 deletions expression/constant_fold.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,9 @@ var specialFoldHandler = map[string]func(*ScalarFunction) (Expression, bool){}

func init() {
specialFoldHandler = map[string]func(*ScalarFunction) (Expression, bool){
ast.If: ifFoldHandler,
ast.Ifnull: ifNullFoldHandler,
ast.If: ifFoldHandler,
ast.Ifnull: ifNullFoldHandler,
ast.NameConst: nameConstFoldHandler,
}
}

Expand All @@ -35,6 +36,10 @@ func FoldConstant(expr Expression) Expression {
return e
}

func nameConstFoldHandler(expr *ScalarFunction) (Expression, bool) {
return foldConstant(expr.GetArgs()[1])
eurekaka marked this conversation as resolved.
Show resolved Hide resolved
}

func ifFoldHandler(expr *ScalarFunction) (Expression, bool) {
args := expr.GetArgs()
foldedArg0, _ := foldConstant(args[0])
Expand Down
38 changes: 38 additions & 0 deletions expression/integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ import (
"github.com/pingcap/tidb/table"
"github.com/pingcap/tidb/types"
"github.com/pingcap/tidb/util/mock"
"github.com/pingcap/tidb/util/sqlexec"
"github.com/pingcap/tidb/util/testkit"
"github.com/pingcap/tidb/util/testleak"
"github.com/pingcap/tidb/util/testutil"
Expand Down Expand Up @@ -3882,3 +3883,40 @@ func (s *testIntegrationSuite) TestValuesFloat32(c *C) {
tk.MustExec(`insert into t values (1, 0.02) on duplicate key update j = values (j);`)
tk.MustQuery(`select * from t;`).Check(testkit.Rows(`1 0.02`))
}

func (s *testIntegrationSuite) TestFuncNameConst(c *C) {
tk := testkit.NewTestKit(c, s.store)
defer s.cleanEnv(c)
tk.MustExec("USE test;")
tk.MustExec("DROP TABLE IF EXISTS t;")
tk.MustExec("CREATE TABLE t(a CHAR(20), b VARCHAR(20), c BIGINT);")
tk.MustExec("INSERT INTO t (b, c) values('hello', 1);")

r := tk.MustQuery("SELECT name_const('test_int', 1), name_const('test_float', 3.1415);")
r.Check(testkit.Rows("1 3.1415"))
r = tk.MustQuery("SELECT name_const('test_string', 'hello'), name_const('test_nil', null);")
r.Check(testkit.Rows("hello <nil>"))
r = tk.MustQuery("SELECT name_const('test_string', 1) + c FROM t;")
r.Check(testkit.Rows("2"))
r = tk.MustQuery("SELECT concat('hello', name_const('test_string', 'world')) FROM t;")
r.Check(testkit.Rows("helloworld"))
err := tk.ExecToErr(`select name_const(a,b) from t;`)
c.Assert(err.Error(), Equals, "[planner:1210]Incorrect arguments to NAME_CONST")
err = tk.ExecToErr(`select name_const(a,"hello") from t;`)
c.Assert(err.Error(), Equals, "[planner:1210]Incorrect arguments to NAME_CONST")
err = tk.ExecToErr(`select name_const("hello", b) from t;`)
c.Assert(err.Error(), Equals, "[planner:1210]Incorrect arguments to NAME_CONST")
err = tk.ExecToErr(`select name_const("hello", 1+1) from t;`)
c.Assert(err.Error(), Equals, "[planner:1210]Incorrect arguments to NAME_CONST")
err = tk.ExecToErr(`select name_const(concat('a', 'b'), 555) from t;`)
c.Assert(err.Error(), Equals, "[planner:1210]Incorrect arguments to NAME_CONST")
err = tk.ExecToErr(`select name_const(555) from t;`)
c.Assert(err.Error(), Equals, "[expression:1582]Incorrect parameter count in the call to native function 'name_const'")

var rs sqlexec.RecordSet
rs, err = tk.Exec(`select name_const("hello", 1);`)
c.Assert(err, IsNil)
c.Assert(len(rs.Fields()), Equals, 1)
c.Assert(rs.Fields()[0].Column.Name.L, Equals, "hello")

}
44 changes: 32 additions & 12 deletions planner/core/logical_plan_builder.go
Original file line number Diff line number Diff line change
Expand Up @@ -557,18 +557,29 @@ func (b *PlanBuilder) buildProjectionFieldNameFromColumns(field *ast.SelectField
}

// buildProjectionFieldNameFromExpressions builds the field name when field expression is a normal expression.
func (b *PlanBuilder) buildProjectionFieldNameFromExpressions(field *ast.SelectField) model.CIStr {
func (b *PlanBuilder) buildProjectionFieldNameFromExpressions(field *ast.SelectField) (model.CIStr, error) {
if agg, ok := field.Expr.(*ast.AggregateFuncExpr); ok && agg.F == ast.AggFuncFirstRow {
// When the query is select t.a from t group by a; The Column Name should be a but not t.a;
return agg.Args[0].(*ast.ColumnNameExpr).Name.Name
return agg.Args[0].(*ast.ColumnNameExpr).Name.Name, nil
}

innerExpr := getInnerFromParenthesesAndUnaryPlus(field.Expr)
funcCall, isFuncCall := innerExpr.(*ast.FuncCallExpr)
// When used to produce a result set column, NAME_CONST() causes the column to have the given name.
// See https://dev.mysql.com/doc/refman/5.7/en/miscellaneous-functions.html#function_name-const for details
if isFuncCall && funcCall.FnName.L == ast.NameConst {
if v, err := evalAstExpr(b.ctx, funcCall.Args[0]); err == nil {
zz-jason marked this conversation as resolved.
Show resolved Hide resolved
if s, err := v.ToString(); err == nil {
return model.NewCIStr(s), nil
}
}
return model.NewCIStr(""), ErrWrongArguments.GenWithStackByArgs("NAME_CONST")
}
valueExpr, isValueExpr := innerExpr.(*driver.ValueExpr)

// Non-literal: Output as inputed, except that comments need to be removed.
if !isValueExpr {
return model.NewCIStr(parser.SpecFieldPattern.ReplaceAllStringFunc(field.Text(), parser.TrimComment))
return model.NewCIStr(parser.SpecFieldPattern.ReplaceAllStringFunc(field.Text(), parser.TrimComment)), nil
}

// Literal: Need special processing
Expand All @@ -584,21 +595,21 @@ func (b *PlanBuilder) buildProjectionFieldNameFromExpressions(field *ast.SelectF
fieldName := strings.TrimLeftFunc(projName, func(r rune) bool {
return !unicode.IsOneOf(mysql.RangeGraph, r)
})
return model.NewCIStr(fieldName)
return model.NewCIStr(fieldName), nil
case types.KindNull:
// See #4053, #3685
return model.NewCIStr("NULL")
return model.NewCIStr("NULL"), nil
default:
// Keep as it is.
if innerExpr.Text() != "" {
return model.NewCIStr(innerExpr.Text())
return model.NewCIStr(innerExpr.Text()), nil
}
return model.NewCIStr(field.Text())
return model.NewCIStr(field.Text()), nil
}
}

// buildProjectionField builds the field object according to SelectField in projection.
func (b *PlanBuilder) buildProjectionField(id, position int, field *ast.SelectField, expr expression.Expression) *expression.Column {
func (b *PlanBuilder) buildProjectionField(id, position int, field *ast.SelectField, expr expression.Expression) (*expression.Column, error) {
var origTblName, tblName, origColName, colName, dbName model.CIStr
if c, ok := expr.(*expression.Column); ok && !c.IsReferenced {
// Field is a column reference.
Expand All @@ -608,7 +619,10 @@ func (b *PlanBuilder) buildProjectionField(id, position int, field *ast.SelectFi
colName = field.AsName
} else {
// Other: field is an expression.
colName = b.buildProjectionFieldNameFromExpressions(field)
var err error
if colName, err = b.buildProjectionFieldNameFromExpressions(field); err != nil {
return nil, errors.Trace(err)
eurekaka marked this conversation as resolved.
Show resolved Hide resolved
}
}
return &expression.Column{
UniqueID: b.ctx.GetSessionVars().AllocPlanColumnID(),
Expand All @@ -618,7 +632,7 @@ func (b *PlanBuilder) buildProjectionField(id, position int, field *ast.SelectFi
OrigColName: origColName,
DBName: dbName,
RetType: expr.GetType(),
}
}, nil
}

// buildProjection returns a Projection plan and non-aux columns length.
Expand Down Expand Up @@ -647,7 +661,10 @@ func (b *PlanBuilder) buildProjection(p LogicalPlan, fields []*ast.SelectField,
expr = p.Schema().Columns[i]
}
proj.Exprs = append(proj.Exprs, expr)
col := b.buildProjectionField(proj.id, schema.Len()+1, field, expr)
col, err := b.buildProjectionField(proj.id, schema.Len()+1, field, expr)
if err != nil {
return nil, 0, errors.Trace(err)
spongedu marked this conversation as resolved.
Show resolved Hide resolved
}
schema.Append(col)
continue
}
Expand All @@ -659,7 +676,10 @@ func (b *PlanBuilder) buildProjection(p LogicalPlan, fields []*ast.SelectField,
p = np
proj.Exprs = append(proj.Exprs, newExpr)

col := b.buildProjectionField(proj.id, schema.Len()+1, field, newExpr)
col, err := b.buildProjectionField(proj.id, schema.Len()+1, field, newExpr)
if err != nil {
return nil, 0, errors.Trace(err)
eurekaka marked this conversation as resolved.
Show resolved Hide resolved
}
schema.Append(col)
}
proj.SetSchema(schema)
Expand Down
Loading