Skip to content

Commit

Permalink
use simple expr for partition processor
Browse files Browse the repository at this point in the history
  • Loading branch information
lcwangchao committed Feb 6, 2024
1 parent 23bc742 commit 7f6aafe
Show file tree
Hide file tree
Showing 5 changed files with 54 additions and 86 deletions.
27 changes: 15 additions & 12 deletions pkg/planner/core/expression_rewriter.go
Original file line number Diff line number Diff line change
Expand Up @@ -122,19 +122,22 @@ func buildSimpleExpr(ctx expression.BuildContext, node ast.ExprNode, opts ...exp
return nil, errors.New("InputSchema and InputNames should be the same length")
}

if len(options.InputNames) > 0 {
intest.AssertFunc(func() bool {
dbName := options.InputNames[0].DBName
if options.SourceTableDB.L != "" {
intest.Assert(dbName.L == options.SourceTableDB.L)
}

for _, name := range options.InputNames {
intest.Assert(name.DBName.L == dbName.L)
}
// assert all input db names are the same if specified
intest.AssertFunc(func() bool {
if len(options.InputNames) == 0 {
return true
})
}
}

dbName := options.InputNames[0].DBName
if options.SourceTableDB.L != "" {
intest.Assert(dbName.L == options.SourceTableDB.L)
}

for _, name := range options.InputNames {
intest.Assert(name.DBName.L == dbName.L)
}
return true
})

rewriter := &expressionRewriter{
ctx: context.TODO(),
Expand Down
58 changes: 25 additions & 33 deletions pkg/planner/core/expression_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,40 +40,27 @@ func parseExpr(t *testing.T, expr string) ast.ExprNode {
return stmt.Fields.Fields[0].Expr
}

func buildExpr(t *testing.T, ctx sessionctx.Context, exprNode any, opts ...expression.BuildOption) (expression.Expression, expression.Expression, error) {
var node ast.ExprNode
var expr2 expression.Expression
var err2 error
parse := false
func buildExpr(t *testing.T, ctx sessionctx.Context, exprNode any, opts ...expression.BuildOption) (expr expression.Expression, err error) {
switch x := exprNode.(type) {
case string:
parse = true
node = parseExpr(t, x)
expr2, err2 = expression.ParseSimpleExpr(ctx, x, opts...)
node := parseExpr(t, x)
expr, err = expression.BuildSimpleExpr(ctx, node, opts...)
case ast.ExprNode:
node = x
expr, err = expression.BuildSimpleExpr(ctx, x, opts...)
default:
require.FailNow(t, "invalid input type: %T", x)
}

expr, err := expression.BuildSimpleExpr(ctx, node, opts...)
if err != nil {
require.Nil(t, expr)
if parse {
require.Nil(t, expr2)
require.EqualError(t, err2, err.Error())
}
} else {
require.NotNil(t, expr)
if parse {
expr.Equal(ctx, expr2)
}
}
return expr, expr2, err
return
}

func buildExprAndEval(t *testing.T, ctx sessionctx.Context, exprNode any) types.Datum {
expr, _, err := buildExpr(t, ctx, exprNode)
expr, err := buildExpr(t, ctx, exprNode)
require.NoError(t, err)
val, err := expr.Eval(ctx, chunk.Row{})
require.NoError(t, err)
Expand Down Expand Up @@ -428,8 +415,13 @@ func TestBuildExpression(t *testing.T) {
schema := expression.NewSchema(cols...)

// normal build
expr, expr2, err := buildExpr(t, ctx, "(1+a)*(3+b)", expression.WithTableInfo("", tbl))
ctx.GetSessionVars().PlanColumnID.Store(0)
expr, err := buildExpr(t, ctx, "(1+a)*(3+b)", expression.WithTableInfo("", tbl))
require.NoError(t, err)
ctx.GetSessionVars().PlanColumnID.Store(0)
expr2, err := expression.ParseSimpleExpr(ctx, "(1+a)*(3+b)", expression.WithTableInfo("", tbl))
require.NoError(t, err)
require.True(t, expr.Equal(ctx, expr2))
val, _, err := expr.EvalInt(ctx, chunk.MutRowFromValues("", 1, 2).ToRow())
require.NoError(t, err)
require.Equal(t, int64(10), val)
Expand All @@ -443,54 +435,54 @@ func TestBuildExpression(t *testing.T) {
require.NoError(t, err)
require.Equal(t, int64(28), val)

expr, _, err = buildExpr(t, ctx, "(1+a)*(3+b)", expression.WithInputSchemaAndNames(schema, names, nil))
expr, err = buildExpr(t, ctx, "(1+a)*(3+b)", expression.WithInputSchemaAndNames(schema, names, nil))
require.NoError(t, err)
val, _, err = expr.EvalInt(ctx, chunk.MutRowFromValues("", 1, 2).ToRow())
require.NoError(t, err)
require.Equal(t, int64(10), val)

// build expression without enough columns
_, _, err = buildExpr(t, ctx, "1+a")
_, err = buildExpr(t, ctx, "1+a")
require.EqualError(t, err, "[planner:1054]Unknown column 'a' in 'expression'")
_, _, err = buildExpr(t, ctx, "(1+a)*(3+b+c)", expression.WithTableInfo("", tbl))
_, err = buildExpr(t, ctx, "(1+a)*(3+b+c)", expression.WithTableInfo("", tbl))
require.EqualError(t, err, "[planner:1054]Unknown column 'c' in 'expression'")

// cast to array not supported by default
_, _, err = buildExpr(t, ctx, "cast(1 as signed array)")
_, err = buildExpr(t, ctx, "cast(1 as signed array)")
require.EqualError(t, err, "[expression:1235]This version of TiDB doesn't yet support 'Use of CAST( .. AS .. ARRAY) outside of functional index in CREATE(non-SELECT)/ALTER TABLE or in general expressions'")
// use WithAllowCastArray to allow casting to array
expr, _, err = buildExpr(t, ctx, `cast(json_extract('{"a": [1, 2, 3]}', '$.a') as signed array)`, expression.WithAllowCastArray(true))
expr, err = buildExpr(t, ctx, `cast(json_extract('{"a": [1, 2, 3]}', '$.a') as signed array)`, expression.WithAllowCastArray(true))
require.NoError(t, err)
j, _, err := expr.EvalJSON(ctx, chunk.Row{})
require.NoError(t, err)
require.Equal(t, types.JSONTypeCodeArray, j.TypeCode)
require.Equal(t, "[1, 2, 3]", j.String())

// default expr
expr, _, err = buildExpr(t, ctx, "default(id)", expression.WithTableInfo("", tbl))
expr, err = buildExpr(t, ctx, "default(id)", expression.WithTableInfo("", tbl))
require.NoError(t, err)
s, _, err := expr.EvalString(ctx, chunk.MutRowFromValues("", 1, 2).ToRow())
require.NoError(t, err)
require.Equal(t, 36, len(s), s)

expr, _, err = buildExpr(t, ctx, "default(id)", expression.WithInputSchemaAndNames(schema, names, tbl))
expr, err = buildExpr(t, ctx, "default(id)", expression.WithInputSchemaAndNames(schema, names, tbl))
require.NoError(t, err)
s, _, err = expr.EvalString(ctx, chunk.MutRowFromValues("", 1, 2).ToRow())
require.NoError(t, err)
require.Equal(t, 36, len(s), s)

expr, _, err = buildExpr(t, ctx, "default(b)", expression.WithTableInfo("", tbl))
expr, err = buildExpr(t, ctx, "default(b)", expression.WithTableInfo("", tbl))
require.NoError(t, err)
d, err := expr.Eval(ctx, chunk.MutRowFromValues("", 1, 2).ToRow())
require.NoError(t, err)
require.Equal(t, types.NewDatum(int64(123)), d)

// WithCastExprTo
expr, _, err = buildExpr(t, ctx, "1+2+3")
expr, err = buildExpr(t, ctx, "1+2+3")
require.NoError(t, err)
require.Equal(t, mysql.TypeLonglong, expr.GetType().GetType())
castTo := types.NewFieldType(mysql.TypeVarchar)
expr, _, err = buildExpr(t, ctx, "1+2+3", expression.WithCastExprTo(castTo))
expr, err = buildExpr(t, ctx, "1+2+3", expression.WithCastExprTo(castTo))
require.NoError(t, err)
require.Equal(t, mysql.TypeVarchar, expr.GetType().GetType())
v, err := expr.Eval(ctx, chunk.Row{})
Expand All @@ -499,14 +491,14 @@ func TestBuildExpression(t *testing.T) {
require.Equal(t, "6", v.GetString())

// should report error for default expr when source table not provided
_, _, err = buildExpr(t, ctx, "default(b)", expression.WithInputSchemaAndNames(schema, names, nil))
_, err = buildExpr(t, ctx, "default(b)", expression.WithInputSchemaAndNames(schema, names, nil))
require.EqualError(t, err, "Unsupported expr *ast.DefaultExpr when source table not provided")

// subquery not supported
_, _, err = buildExpr(t, ctx, "a + (select b from t)", expression.WithTableInfo("", tbl))
_, err = buildExpr(t, ctx, "a + (select b from t)", expression.WithTableInfo("", tbl))
require.EqualError(t, err, "node '*ast.SubqueryExpr' is not allowed when building an expression without planner")

// param marker not supported
_, _, err = buildExpr(t, ctx, "a + ?", expression.WithTableInfo("", tbl))
_, err = buildExpr(t, ctx, "a + ?", expression.WithTableInfo("", tbl))
require.EqualError(t, err, "node '*driver.ParamMarkerExpr' is not allowed when building an expression without planner")
}
17 changes: 14 additions & 3 deletions pkg/planner/core/rule_partition_processor.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@ import (
"github.com/pingcap/tidb/pkg/parser/ast"
"github.com/pingcap/tidb/pkg/parser/model"
"github.com/pingcap/tidb/pkg/parser/mysql"
"github.com/pingcap/tidb/pkg/planner/util"
"github.com/pingcap/tidb/pkg/sessionctx"
"github.com/pingcap/tidb/pkg/table"
"github.com/pingcap/tidb/pkg/table/tables"
Expand Down Expand Up @@ -121,7 +120,13 @@ type partitionTable interface {

func generateHashPartitionExpr(ctx sessionctx.Context, pi *model.PartitionInfo, columns []*expression.Column, names types.NameSlice) (expression.Expression, error) {
schema := expression.NewSchema(columns...)
expr, err := util.ParseExprWithPlanCtx(ctx, pi.Expr, schema, names)
// Increase the PlanID to make sure some tests will pass. The old implementation to rewrite AST builds a `TableDual`
// that causes the `PlanID` increases, and many test cases hardcoded the output plan in the expected result.
// Considering the new `ParseSimpleExpr` does not do the same thing and to make the test pass,
// we have to increase the `PlanID` here. But it is safe to remove this line without introducing any bug.
// TODO: remove this line after fixing the test cases.
ctx.GetSessionVars().PlanID.Add(1)
expr, err := expression.ParseSimpleExpr(ctx, pi.Expr, expression.WithInputSchemaAndNames(schema, names, nil))
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -1049,7 +1054,13 @@ func (s *partitionProcessor) processListPartition(ds *DataSource, pi *model.Part
func makePartitionByFnCol(sctx sessionctx.Context, columns []*expression.Column, names types.NameSlice, partitionExpr string) (*expression.Column, *expression.ScalarFunction, monotoneMode, error) {
monotonous := monotoneModeInvalid
schema := expression.NewSchema(columns...)
partExpr, err := util.ParseExprWithPlanCtx(sctx, partitionExpr, schema, names)
// Increase the PlanID to make sure some tests will pass. The old implementation to rewrite AST builds a `TableDual`
// that causes the `PlanID` increases, and many test cases hardcoded the output plan in the expected result.
// Considering the new `ParseSimpleExpr` does not do the same thing and to make the test pass,
// we have to increase the `PlanID` here. But it is safe to remove this line without introducing any bug.
// TODO: remove this line after fixing the test cases.
sctx.GetSessionVars().PlanID.Add(1)
partExpr, err := expression.ParseSimpleExpr(sctx, partitionExpr, expression.WithInputSchemaAndNames(schema, names, nil))
if err != nil {
return nil, nil, monotonous, err
}
Expand Down
4 changes: 0 additions & 4 deletions pkg/planner/util/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -13,17 +13,13 @@ go_library(
deps = [
"//pkg/expression",
"//pkg/kv",
"//pkg/parser",
"//pkg/parser/ast",
"//pkg/parser/model",
"//pkg/sessionctx",
"//pkg/types",
"//pkg/util",
"//pkg/util/collate",
"//pkg/util/ranger",
"//pkg/util/size",
"//pkg/util/sqlexec",
"@com_github_pingcap_errors//:errors",
],
)

Expand Down
34 changes: 0 additions & 34 deletions pkg/planner/util/expression.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,16 +15,10 @@
package util

import (
"context"

"github.com/pingcap/errors"
"github.com/pingcap/tidb/pkg/expression"
"github.com/pingcap/tidb/pkg/parser"
"github.com/pingcap/tidb/pkg/parser/ast"
"github.com/pingcap/tidb/pkg/sessionctx"
"github.com/pingcap/tidb/pkg/types"
"github.com/pingcap/tidb/pkg/util"
"github.com/pingcap/tidb/pkg/util/sqlexec"
)

// EvalAstExprWithPlanCtx evaluates ast expression with plan context.
Expand All @@ -39,31 +33,3 @@ var EvalAstExprWithPlanCtx func(ctx sessionctx.Context, expr ast.ExprNode) (type
// If you only want to build simple expressions, use `expression.BuildSimpleExpr` instead.
var RewriteAstExprWithPlanCtx func(ctx sessionctx.Context, expr ast.ExprNode,
schema *expression.Schema, names types.NameSlice, allowCastArray bool) (expression.Expression, error)

// ParseExprWithPlanCtx parses expression string to Expression.
// Different with expression.ParseSimpleExpr, it uses planner context and is more powerful to build
// some special expressions like subquery, window function, etc.
// If you only want to build simple expressions, use `expression.ParseSimpleExpr` instead.
func ParseExprWithPlanCtx(ctx sessionctx.Context, exprStr string,
schema *expression.Schema, names types.NameSlice) (expression.Expression, error) {
exprStr = "select " + exprStr
var stmts []ast.StmtNode
var err error
var warns []error
if p, ok := ctx.(sqlexec.SQLParser); ok {
stmts, warns, err = p.ParseSQL(context.Background(), exprStr)
} else {
stmts, warns, err = parser.New().ParseSQL(exprStr)
}

if err != nil {
return nil, errors.Trace(util.SyntaxWarn(err))
}

for _, warn := range warns {
ctx.GetSessionVars().StmtCtx.AppendWarning(util.SyntaxWarn(warn))
}

expr := stmts[0].(*ast.SelectStmt).Fields.Fields[0].Expr
return RewriteAstExprWithPlanCtx(ctx, expr, schema, names, false)
}

0 comments on commit 7f6aafe

Please sign in to comment.