Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

expression: refactor functions to build expressions #50997

Merged
merged 3 commits into from
Feb 7, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion br/pkg/lightning/backend/kv/sql2kv.go
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,12 @@ func CollectGeneratedColumns(se *Session, meta *model.TableInfo, cols []*table.C
var genCols []GeneratedCol
for i, col := range cols {
if col.GeneratedExpr != nil {
expr, err := expression.RewriteAstExpr(se, col.GeneratedExpr.Internal(), schema, names, true)
expr, err := expression.BuildSimpleExpr(
se,
col.GeneratedExpr.Internal(),
expression.WithInputSchemaAndNames(schema, names, meta),
expression.WithAllowCastArray(true),
)
if err != nil {
return nil, err
}
Expand Down
17 changes: 10 additions & 7 deletions pkg/ddl/ddl_api.go
Original file line number Diff line number Diff line change
Expand Up @@ -1341,7 +1341,7 @@ func getDefaultValue(ctx sessionctx.Context, col *table.Column, option *ast.Colu
}

// evaluate the non-function-call expr to a certain value.
v, err := expression.EvalAstExpr(ctx, option.Expr)
v, err := expression.EvalSimpleAst(ctx, option.Expr)
if err != nil {
return nil, false, errors.Trace(err)
}
Expand Down Expand Up @@ -2304,7 +2304,7 @@ func checkTableInfoValidWithStmt(ctx sessionctx.Context, tbInfo *model.TableInfo
return errors.Trace(err)
}
if s.Partition != nil {
if err := checkPartitionFuncType(ctx, s.Partition.Expr, s.Table.Schema, tbInfo); err != nil {
if err := checkPartitionFuncType(ctx, s.Partition.Expr, s.Table.Schema.O, tbInfo); err != nil {
return errors.Trace(err)
}
if err := checkPartitioningKeysConstraints(ctx, s, tbInfo); err != nil {
Expand Down Expand Up @@ -3254,11 +3254,11 @@ func checkTwoRangeColumns(ctx sessionctx.Context, curr, prev *model.PartitionDef
}

func parseAndEvalBoolExpr(ctx sessionctx.Context, l, r string, colInfo *model.ColumnInfo, tbInfo *model.TableInfo) (bool, error) {
lexpr, err := expression.ParseSimpleExprCastWithTableInfo(ctx, l, tbInfo, &colInfo.FieldType)
lexpr, err := expression.ParseSimpleExpr(ctx, l, expression.WithTableInfo("", tbInfo), expression.WithCastExprTo(&colInfo.FieldType))
if err != nil {
return false, err
}
rexpr, err := expression.ParseSimpleExprCastWithTableInfo(ctx, r, tbInfo, &colInfo.FieldType)
rexpr, err := expression.ParseSimpleExpr(ctx, r, expression.WithTableInfo("", tbInfo), expression.WithCastExprTo(&colInfo.FieldType))
if err != nil {
return false, err
}
Expand Down Expand Up @@ -5383,7 +5383,7 @@ func setDefaultValueWithBinaryPadding(col *table.Column, value any) error {
}

func setColumnComment(ctx sessionctx.Context, col *table.Column, option *ast.ColumnOption) error {
value, err := expression.EvalAstExpr(ctx, option.Expr)
value, err := expression.EvalSimpleAst(ctx, option.Expr)
if err != nil {
return errors.Trace(err)
}
Expand Down Expand Up @@ -7309,7 +7309,10 @@ func BuildHiddenColumnInfo(ctx sessionctx.Context, indexPartSpecifications []*as
if err != nil {
return nil, errors.Trace(err)
}
expr, err := expression.RewriteSimpleExprWithTableInfo(ctx, tblInfo, idxPart.Expr, true)
expr, err := expression.BuildSimpleExpr(ctx, idxPart.Expr,
expression.WithTableInfo(ctx.GetSessionVars().CurrentDB, tblInfo),
expression.WithAllowCastArray(true),
)
if err != nil {
// TODO: refine the error message.
return nil, err
Expand Down Expand Up @@ -8003,7 +8006,7 @@ func checkAndGetColumnsTypeAndValuesMatch(ctx sessionctx.Context, colTypes []typ
continue
}
colType := colTypes[i]
val, err := expression.EvalAstExpr(ctx, colExpr)
val, err := expression.EvalSimpleAst(ctx, colExpr)
if err != nil {
return nil, err
}
Expand Down
39 changes: 17 additions & 22 deletions pkg/ddl/partition.go
Original file line number Diff line number Diff line change
Expand Up @@ -698,11 +698,11 @@ func getPartitionIntervalFromTable(ctx sessionctx.Context, tbInfo *model.TableIn
var firstExpr, lastExpr ast.ExprNode
if isIntType {
exprStr := fmt.Sprintf("((%s) - (%s)) DIV %d", lastPartLessThan, firstPartLessThan, endIdx-startIdx)
exprs, err := expression.ParseSimpleExprsWithNames(ctx, exprStr, nil, nil)
expr, err := expression.ParseSimpleExpr(ctx, exprStr)
if err != nil {
return nil
}
val, isNull, err := exprs[0].EvalInt(ctx, chunk.Row{})
val, isNull, err := expr.EvalInt(ctx, chunk.Row{})
if isNull || err != nil || val < 1 {
// If NULL, error or interval < 1 then cannot be an INTERVAL partitioned table
return nil
Expand All @@ -721,11 +721,11 @@ func getPartitionIntervalFromTable(ctx sessionctx.Context, tbInfo *model.TableIn
interval.LastRangeEnd = &lastExpr
} else { // types.ETDatetime
exprStr := fmt.Sprintf("TIMESTAMPDIFF(SECOND, '%s', '%s')", firstPartLessThan, lastPartLessThan)
exprs, err := expression.ParseSimpleExprsWithNames(ctx, exprStr, nil, nil)
expr, err := expression.ParseSimpleExpr(ctx, exprStr)
if err != nil {
return nil
}
val, isNull, err := exprs[0].EvalInt(ctx, chunk.Row{})
val, isNull, err := expr.EvalInt(ctx, chunk.Row{})
if isNull || err != nil || val < 1 {
// If NULL, error or interval < 1 then cannot be an INTERVAL partitioned table
return nil
Expand Down Expand Up @@ -781,7 +781,7 @@ func comparePartitionAstAndModel(ctx sessionctx.Context, pAst *ast.PartitionOpti
}

evalFn := func(expr ast.ExprNode) (types.Datum, error) {
val, err := expression.EvalAstExpr(ctx, ast.NewValueExpr(expr, "", ""))
val, err := expression.EvalSimpleAst(ctx, ast.NewValueExpr(expr, "", ""))
if err != nil || partCol == nil {
return val, err
}
Expand Down Expand Up @@ -865,7 +865,7 @@ func comparePartitionDefinitions(ctx sessionctx.Context, a, b []*ast.PartitionDe
L: definedExpr,
R: generatedExpr,
}
cmp, err := expression.EvalAstExpr(ctx, cmpExpr)
cmp, err := expression.EvalSimpleAst(ctx, cmpExpr)
if err != nil {
return err
}
Expand Down Expand Up @@ -1066,7 +1066,7 @@ func GeneratePartDefsFromInterval(ctx sessionctx.Context, tp ast.AlterTableType,
default:
return dbterror.ErrGeneralUnsupportedDDL.GenWithStackByArgs("INTERVAL partitioning: Internal error during generating altered INTERVAL partitions, no known alter type")
}
lastVal, err := expression.EvalAstExpr(ctx, lastExpr)
lastVal, err := expression.EvalSimpleAst(ctx, lastExpr)
if err != nil {
return err
}
Expand Down Expand Up @@ -1115,7 +1115,7 @@ func GeneratePartDefsFromInterval(ctx sessionctx.Context, tp ast.AlterTableType,
}
}
}
currVal, err = expression.EvalAstExpr(ctx, currExpr)
currVal, err = expression.EvalSimpleAst(ctx, currExpr)
if err != nil {
return err
}
Expand Down Expand Up @@ -1441,7 +1441,7 @@ func checkPartitionValuesIsInt(ctx sessionctx.Context, defName any, exprs []ast.
}
continue
}
val, err := expression.EvalAstExpr(ctx, exp)
val, err := expression.EvalSimpleAst(ctx, exp)
if err != nil {
return err
}
Expand Down Expand Up @@ -1580,21 +1580,16 @@ func checkResultOK(ok bool) error {
}

// checkPartitionFuncType checks partition function return type.
func checkPartitionFuncType(ctx sessionctx.Context, expr ast.ExprNode, dbName model.CIStr, tblInfo *model.TableInfo) error {
func checkPartitionFuncType(ctx sessionctx.Context, expr ast.ExprNode, schema string, tblInfo *model.TableInfo) error {
if expr == nil {
return nil
}

if dbName.L == "" {
dbName = model.NewCIStr(ctx.GetSessionVars().CurrentDB)
if schema == "" {
schema = ctx.GetSessionVars().CurrentDB
}

columns, names, err := expression.ColumnInfos2ColumnsAndNames(ctx, dbName, tblInfo.Name, tblInfo.Cols(), tblInfo)
if err != nil {
return err
}

e, err := expression.RewriteAstExpr(ctx, expr, expression.NewSchema(columns...), names, false)
e, err := expression.BuildSimpleExpr(ctx, expr, expression.WithTableInfo(schema, tblInfo))
if err != nil {
return errors.Trace(err)
}
Expand Down Expand Up @@ -1720,7 +1715,7 @@ func formatListPartitionValue(ctx sessionctx.Context, tblInfo *model.TableInfo)
if strings.EqualFold(v, "MAXVALUE") {
return nil, errors.Trace(dbterror.ErrMaxvalueInValuesIn)
}
expr, err := expression.ParseSimpleExprCastWithTableInfo(ctx, v, &model.TableInfo{}, colTps[k])
expr, err := expression.ParseSimpleExpr(ctx, v, expression.WithCastExprTo(colTps[k]))
if err != nil {
return nil, errors.Trace(err)
}
Expand Down Expand Up @@ -1760,7 +1755,7 @@ func getRangeValue(ctx sessionctx.Context, str string, unsigned bool) (any, bool
return value, false, nil
}

e, err1 := expression.ParseSimpleExprWithTableInfo(ctx, str, &model.TableInfo{})
e, err1 := expression.ParseSimpleExpr(ctx, str)
if err1 != nil {
return 0, false, err1
}
Expand All @@ -1776,7 +1771,7 @@ func getRangeValue(ctx sessionctx.Context, str string, unsigned bool) (any, bool
// For example, the following two cases are the same:
// PARTITION p0 VALUES LESS THAN (TO_SECONDS('2004-01-01'))
// PARTITION p0 VALUES LESS THAN (63340531200)
e, err1 := expression.ParseSimpleExprWithTableInfo(ctx, str, &model.TableInfo{})
e, err1 := expression.ParseSimpleExpr(ctx, str)
if err1 != nil {
return 0, false, err1
}
Expand Down Expand Up @@ -3984,7 +3979,7 @@ func isPartExprUnsigned(tbInfo *model.TableInfo) bool {
// We should not rely on any configuration, system or session variables, so use a mock ctx!
// Same as in tables.newPartitionExpr
ctx := mock.NewContext()
expr, err := expression.ParseSimpleExprWithTableInfo(ctx, tbInfo.Partition.Expr, tbInfo)
expr, err := expression.ParseSimpleExpr(ctx, tbInfo.Partition.Expr, expression.WithTableInfo("", tbInfo))
if err != nil {
logutil.BgLogger().Error("isPartExpr failed parsing expression!", zap.Error(err))
return false
Expand Down
2 changes: 1 addition & 1 deletion pkg/ddl/ttl.go
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ func checkTTLIntervalExpr(ctx sessionctx.Context, ttlInfo *model.TTLInfo) error
return errors.Trace(err)
}
nowAddIntervalExpr = stmts[0].(*ast.SelectStmt).Fields.Fields[0].Expr
_, err = expression.EvalAstExpr(ctx, nowAddIntervalExpr)
_, err = expression.EvalSimpleAst(ctx, nowAddIntervalExpr)
return err
}

Expand Down
2 changes: 2 additions & 0 deletions pkg/executor/importer/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ go_library(
"//pkg/parser/mysql",
"//pkg/parser/terror",
"//pkg/planner/core",
"//pkg/planner/util",
"//pkg/sessionctx",
"//pkg/sessionctx/stmtctx",
"//pkg/sessionctx/variable",
Expand Down Expand Up @@ -131,6 +132,7 @@ go_test(
"//pkg/parser/model",
"//pkg/parser/mysql",
"//pkg/planner/core",
"//pkg/planner/util",
"//pkg/session",
"//pkg/sessionctx/variable",
"//pkg/testkit",
Expand Down
3 changes: 2 additions & 1 deletion pkg/executor/importer/import.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ import (
"github.com/pingcap/tidb/pkg/parser/mysql"
"github.com/pingcap/tidb/pkg/parser/terror"
plannercore "github.com/pingcap/tidb/pkg/planner/core"
plannerutil "github.com/pingcap/tidb/pkg/planner/util"
"github.com/pingcap/tidb/pkg/sessionctx"
"github.com/pingcap/tidb/pkg/sessionctx/stmtctx"
"github.com/pingcap/tidb/pkg/sessionctx/variable"
Expand Down Expand Up @@ -1283,7 +1284,7 @@ func (e *LoadDataController) CreateColAssignExprs(sctx sessionctx.Context) ([]ex
res := make([]expression.Expression, 0, len(e.ColumnAssignments))
allWarnings := []stmtctx.SQLWarn{}
for _, assign := range e.ColumnAssignments {
newExpr, err := expression.RewriteAstExpr(sctx, assign.Expr, nil, nil, false)
newExpr, err := plannerutil.RewriteAstExprWithPlanCtx(sctx, assign.Expr, nil, nil, false)
// col assign expr warnings is static, we should generate it for each row processed.
// so we save it and clear it here.
allWarnings = append(allWarnings, sctx.GetSessionVars().StmtCtx.GetWarnings()...)
Expand Down
3 changes: 2 additions & 1 deletion pkg/executor/importer/import_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ import (
"github.com/pingcap/tidb/pkg/parser"
"github.com/pingcap/tidb/pkg/parser/ast"
plannercore "github.com/pingcap/tidb/pkg/planner/core"
plannerutil "github.com/pingcap/tidb/pkg/planner/util"
"github.com/pingcap/tidb/pkg/sessionctx/variable"
"github.com/pingcap/tidb/pkg/types"
"github.com/pingcap/tidb/pkg/util/dbterror/exeerrors"
Expand Down Expand Up @@ -90,7 +91,7 @@ func TestInitOptionsPositiveCase(t *testing.T) {
for _, opt := range inOptions {
loadDataOpt := plannercore.LoadDataOpt{Name: opt.Name}
if opt.Value != nil {
loadDataOpt.Value, err = expression.RewriteSimpleExprWithNames(sctx, opt.Value, nil, nil)
loadDataOpt.Value, err = plannerutil.RewriteAstExprWithPlanCtx(sctx, opt.Value, nil, nil, false)
require.NoError(t, err)
}
options = append(options, &loadDataOpt)
Expand Down
2 changes: 1 addition & 1 deletion pkg/executor/internal/querywatch/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,11 @@ go_library(
"//pkg/domain",
"//pkg/domain/resourcegroup",
"//pkg/executor/internal/exec",
"//pkg/expression",
"//pkg/infoschema",
"//pkg/parser",
"//pkg/parser/ast",
"//pkg/parser/model",
"//pkg/planner/util",
"//pkg/sessionctx",
"//pkg/util/chunk",
"//pkg/util/sqlexec",
Expand Down
6 changes: 3 additions & 3 deletions pkg/executor/internal/querywatch/query_watch.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,11 +24,11 @@ import (
"github.com/pingcap/tidb/pkg/domain"
"github.com/pingcap/tidb/pkg/domain/resourcegroup"
"github.com/pingcap/tidb/pkg/executor/internal/exec"
"github.com/pingcap/tidb/pkg/expression"
"github.com/pingcap/tidb/pkg/infoschema"
"github.com/pingcap/tidb/pkg/parser"
"github.com/pingcap/tidb/pkg/parser/ast"
"github.com/pingcap/tidb/pkg/parser/model"
plannerutil "github.com/pingcap/tidb/pkg/planner/util"
"github.com/pingcap/tidb/pkg/sessionctx"
"github.com/pingcap/tidb/pkg/util/chunk"
"github.com/pingcap/tidb/pkg/util/sqlexec"
Expand All @@ -44,7 +44,7 @@ func setWatchOption(ctx context.Context,
switch op.Tp {
case ast.QueryWatchResourceGroup:
if op.ExprValue != nil {
expr, err := expression.RewriteAstExpr(sctx, op.ExprValue, nil, nil, false)
expr, err := plannerutil.RewriteAstExprWithPlanCtx(sctx, op.ExprValue, nil, nil, false)
if err != nil {
return err
}
Expand All @@ -62,7 +62,7 @@ func setWatchOption(ctx context.Context,
case ast.QueryWatchAction:
record.Action = rmpb.RunawayAction(op.IntValue)
case ast.QueryWatchType:
expr, err := expression.RewriteAstExpr(sctx, op.ExprValue, nil, nil, false)
expr, err := plannerutil.RewriteAstExprWithPlanCtx(sctx, op.ExprValue, nil, nil, false)
if err != nil {
return err
}
Expand Down
8 changes: 2 additions & 6 deletions pkg/expression/builtin_compare_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@ import (

"github.com/pingcap/errors"
"github.com/pingcap/tidb/pkg/parser/ast"
"github.com/pingcap/tidb/pkg/parser/model"
"github.com/pingcap/tidb/pkg/parser/mysql"
"github.com/pingcap/tidb/pkg/types"
"github.com/pingcap/tidb/pkg/util/chunk"
Expand Down Expand Up @@ -71,13 +70,10 @@ func TestCompareFunctionWithRefine(t *testing.T) {
{"-123456789123456789123456789.12345 < a", "1"},
{"'aaaa'=a", "eq(0, a)"},
}
cols, names, err := ColumnInfos2ColumnsAndNames(ctx, model.NewCIStr(""), tblInfo.Name, tblInfo.Cols(), tblInfo)
require.NoError(t, err)
schema := NewSchema(cols...)
for _, test := range tests {
f, err := ParseSimpleExprsWithNames(ctx, test.exprStr, schema, names)
f, err := ParseSimpleExpr(ctx, test.exprStr, WithTableInfo("", tblInfo))
require.NoError(t, err)
require.Equal(t, test.result, f[0].String())
require.Equal(t, test.result, f.String())
}
}

Expand Down
Loading