diff --git a/planner/cascades/testdata/integration_suite_in.json b/planner/cascades/testdata/integration_suite_in.json index bd3e77a2f55b6..aec9d912f423c 100644 --- a/planner/cascades/testdata/integration_suite_in.json +++ b/planner/cascades/testdata/integration_suite_in.json @@ -28,7 +28,8 @@ "select b, sum(a) from t group by b order by b", "select b, avg(a) from t group by b having sum(a) > 1 order by b", "select max(a+b) from t", - "select sum(a) from t group by a, a+b order by a" + "select sum(a) from t group by a, a+b order by a", + "select max(a) from t" ] }, { diff --git a/planner/cascades/testdata/integration_suite_out.json b/planner/cascades/testdata/integration_suite_out.json index b7cd2764969d0..de02b843ae95b 100644 --- a/planner/cascades/testdata/integration_suite_out.json +++ b/planner/cascades/testdata/integration_suite_out.json @@ -212,10 +212,10 @@ { "SQL": "select max(a+b) from t", "Plan": [ - "HashAgg_12 1.00 root funcs:max(Column#4)->Column#3", - "└─TableReader_13 1.00 root data:HashAgg_14", - " └─HashAgg_14 1.00 cop[tikv] funcs:max(plus(test.t.a, test.t.b))->Column#4", - " └─TableScan_10 10000.00 cop[tikv] table:t, range:[-inf,+inf], keep order:false, stats:pseudo" + "HashAgg_16 1.00 root funcs:max(Column#4)->Column#3", + "└─TableReader_17 1.00 root data:HashAgg_18", + " └─HashAgg_18 1.00 cop[tikv] funcs:max(plus(test.t.a, test.t.b))->Column#4", + " └─TableScan_14 10000.00 cop[tikv] table:t, range:[-inf,+inf], keep order:false, stats:pseudo" ], "Result": [ "48" @@ -237,6 +237,19 @@ "3", "4" ] + }, + { + "SQL": "select max(a) from t", + "Plan": [ + "HashAgg_16 1.00 root funcs:max(test.t.a)->Column#3", + "└─Limit_18 1.00 root offset:0, count:1", + " └─TableReader_23 1.00 root data:Limit_24", + " └─Limit_24 1.00 cop[tikv] offset:0, count:1", + " └─TableScan_22 1.00 cop[tikv] table:t, range:[-inf,+inf], keep order:true, desc, stats:pseudo" + ], + "Result": [ + "4" + ] } ] }, diff --git a/planner/cascades/testdata/transformation_rules_suite_in.json b/planner/cascades/testdata/transformation_rules_suite_in.json index 06bfe0314745b..a593df146a9f4 100644 --- a/planner/cascades/testdata/transformation_rules_suite_in.json +++ b/planner/cascades/testdata/transformation_rules_suite_in.json @@ -68,6 +68,13 @@ "select a+c from (select floor(a) as a, b, c from t) as t2" ] }, + { + "name": "TestEliminateMaxMin", + "cases": [ + "select max(a) from t;", + "select min(a) from t;" + ] + }, { "name": "TestMergeAggregationProjection", "cases": [ diff --git a/planner/cascades/testdata/transformation_rules_suite_out.json b/planner/cascades/testdata/transformation_rules_suite_out.json index 704929e5c3723..6280fc2742e14 100644 --- a/planner/cascades/testdata/transformation_rules_suite_out.json +++ b/planner/cascades/testdata/transformation_rules_suite_out.json @@ -1172,6 +1172,39 @@ } ] }, + { + "Name": "TestEliminateMaxMin", + "Cases": [ + { + "SQL": "select max(a) from t;", + "Result": [ + "Group#0 Schema:[Column#13]", + " Projection_3 input:[Group#1], Column#13", + "Group#1 Schema:[Column#13]", + " Aggregation_2 input:[Group#2], funcs:max(test.t.a)", + " Aggregation_2 input:[Group#3], funcs:max(test.t.a)", + "Group#2 Schema:[test.t.a]", + " TableScan_1 table:t", + "Group#3 Schema:[test.t.a]", + " TopN_4 input:[Group#2], test.t.a:desc, offset:0, count:1" + ] + }, + { + "SQL": "select min(a) from t;", + "Result": [ + "Group#0 Schema:[Column#13]", + " Projection_3 input:[Group#1], Column#13", + "Group#1 Schema:[Column#13]", + " Aggregation_2 input:[Group#2], funcs:min(test.t.a)", + " Aggregation_2 input:[Group#3], funcs:min(test.t.a)", + "Group#2 Schema:[test.t.a]", + " TableScan_1 table:t", + "Group#3 Schema:[test.t.a]", + " TopN_4 input:[Group#2], test.t.a:asc, offset:0, count:1" + ] + } + ] + }, { "Name": "TestMergeAggregationProjection", "Cases": [ diff --git a/planner/cascades/transformation_rules.go b/planner/cascades/transformation_rules.go index f8c6fb86e05ee..5b044b48cb6dd 100644 --- a/planner/cascades/transformation_rules.go +++ b/planner/cascades/transformation_rules.go @@ -16,10 +16,13 @@ package cascades import ( "math" + "github.com/pingcap/parser/ast" + "github.com/pingcap/parser/mysql" "github.com/pingcap/tidb/expression" "github.com/pingcap/tidb/expression/aggregation" plannercore "github.com/pingcap/tidb/planner/core" "github.com/pingcap/tidb/planner/memo" + "github.com/pingcap/tidb/types" "github.com/pingcap/tidb/util/ranger" ) @@ -63,6 +66,7 @@ var defaultTransformationMap = map[memo.Operand][]Transformation{ memo.OperandAggregation: { NewRulePushAggDownGather(), NewRuleMergeAggregationProjection(), + NewRuleEliminateSingleMaxMin(), }, memo.OperandLimit: { NewRuleTransformLimitToTopN(), @@ -1339,6 +1343,107 @@ func (r *MergeAggregationProjection) OnTransform(old *memo.ExprIter) (newExprs [ return []*memo.GroupExpr{newAggExpr}, true, false, nil } +// EliminateSingleMaxMin tries to convert a single max/min to Limit+Sort operators. +type EliminateSingleMaxMin struct { + baseRule +} + +// NewRuleEliminateSingleMaxMin creates a new Transformation EliminateSingleMaxMin. +// The pattern of this rule is `max/min->X`. +func NewRuleEliminateSingleMaxMin() Transformation { + rule := &EliminateSingleMaxMin{} + rule.pattern = memo.BuildPattern( + memo.OperandAggregation, + memo.EngineTiDBOnly, + memo.NewPattern(memo.OperandAny, memo.EngineTiDBOnly), + ) + return rule +} + +// Match implements Transformation interface. +func (r *EliminateSingleMaxMin) Match(expr *memo.ExprIter) bool { + // Use appliedRuleSet in GroupExpr to avoid re-apply rules. + if expr.GetExpr().HasAppliedRule(r) { + return false + } + + agg := expr.GetExpr().ExprNode.(*plannercore.LogicalAggregation) + // EliminateSingleMaxMin only works on the complete mode. + if !agg.IsCompleteModeAgg() { + return false + } + if len(agg.GroupByItems) != 0 { + return false + } + + // If there is only one aggFunc, we don't need to guarantee that the child of it is a data + // source, or whether the sort can be eliminated. This transformation won't be worse than previous. + // Make sure that the aggFunc are Max or Min. + // TODO: If there have only one Max or Min aggFunc and the other aggFuncs are FirstRow() can also use this rule. Waiting for the not null prop is maintained. + if len(agg.AggFuncs) != 1 { + return false + } + if agg.AggFuncs[0].Name != ast.AggFuncMax && agg.AggFuncs[0].Name != ast.AggFuncMin { + return false + } + return true +} + +// OnTransform implements Transformation interface. +// It will transform `max/min->X` to `max/min->top1->sel->X`. +func (r *EliminateSingleMaxMin) OnTransform(old *memo.ExprIter) (newExprs []*memo.GroupExpr, eraseOld bool, eraseAll bool, err error) { + agg := old.GetExpr().ExprNode.(*plannercore.LogicalAggregation) + childGroup := old.GetExpr().Children[0] + ctx := agg.SCtx() + f := agg.AggFuncs[0] + + // If there's no column in f.GetArgs()[0], we still need limit and read data from real table because the result should be NULL if the input is empty. + if len(expression.ExtractColumns(f.Args[0])) > 0 { + // If it can be NULL, we need to filter NULL out first. + if !mysql.HasNotNullFlag(f.Args[0].GetType().Flag) { + sel := plannercore.LogicalSelection{}.Init(ctx, agg.SelectBlockOffset()) + isNullFunc := expression.NewFunctionInternal(ctx, ast.IsNull, types.NewFieldType(mysql.TypeTiny), f.Args[0]) + notNullFunc := expression.NewFunctionInternal(ctx, ast.UnaryNot, types.NewFieldType(mysql.TypeTiny), isNullFunc) + sel.Conditions = []expression.Expression{notNullFunc} + selExpr := memo.NewGroupExpr(sel) + selExpr.SetChildren(childGroup) + selGroup := memo.NewGroupWithSchema(selExpr, childGroup.Prop.Schema) + childGroup = selGroup + } + + // Add top(1) operators. + // For max function, the sort order should be desc. + desc := f.Name == ast.AggFuncMax + var byItems []*plannercore.ByItems + byItems = append(byItems, &plannercore.ByItems{ + Expr: f.Args[0], + Desc: desc, + }) + top1 := plannercore.LogicalTopN{ + ByItems: byItems, + Count: 1, + }.Init(ctx, agg.SelectBlockOffset()) + top1Expr := memo.NewGroupExpr(top1) + top1Expr.SetChildren(childGroup) + top1Group := memo.NewGroupWithSchema(top1Expr, childGroup.Prop.Schema) + childGroup = top1Group + } else { + li := plannercore.LogicalLimit{Count: 1}.Init(ctx, agg.SelectBlockOffset()) + liExpr := memo.NewGroupExpr(li) + liExpr.SetChildren(childGroup) + liGroup := memo.NewGroupWithSchema(liExpr, childGroup.Prop.Schema) + childGroup = liGroup + } + + newAgg := agg + newAggExpr := memo.NewGroupExpr(newAgg) + // If no data in the child, we need to return NULL instead of empty. This cannot be done by sort and limit themselves. + // Since now there would be at most one row returned, the remained agg operator is not expensive anymore. + newAggExpr.SetChildren(childGroup) + newAggExpr.AddAppliedRule(r) + return []*memo.GroupExpr{newAggExpr}, false, false, nil +} + // MergeAdjacentSelection merge adjacent selection. type MergeAdjacentSelection struct { baseRule diff --git a/planner/cascades/transformation_rules_test.go b/planner/cascades/transformation_rules_test.go index d3d7afbc235bc..3bff81b6400dd 100644 --- a/planner/cascades/transformation_rules_test.go +++ b/planner/cascades/transformation_rules_test.go @@ -192,6 +192,24 @@ func (s *testTransformationRuleSuite) TestProjectionElimination(c *C) { testGroupToString(input, output, s, c) } +func (s *testTransformationRuleSuite) TestEliminateMaxMin(c *C) { + s.optimizer.ResetTransformationRules(map[memo.Operand][]Transformation{ + memo.OperandAggregation: { + NewRuleEliminateSingleMaxMin(), + }, + }) + defer func() { + s.optimizer.ResetTransformationRules(defaultTransformationMap) + }() + var input []string + var output []struct { + SQL string + Result []string + } + s.testData.GetTestCases(c, &input, &output) + testGroupToString(input, output, s, c) +} + func (s *testTransformationRuleSuite) TestMergeAggregationProjection(c *C) { s.optimizer.ResetTransformationRules(map[memo.Operand][]Transformation{ memo.OperandAggregation: { diff --git a/planner/core/logical_plans.go b/planner/core/logical_plans.go index b13f69c09bee0..fad6948caf4d3 100644 --- a/planner/core/logical_plans.go +++ b/planner/core/logical_plans.go @@ -298,6 +298,12 @@ func (la *LogicalAggregation) IsPartialModeAgg() bool { return la.AggFuncs[0].Mode == aggregation.Partial1Mode } +// IsCompleteModeAgg returns if all of the AggFuncs are CompleteMode. +func (la *LogicalAggregation) IsCompleteModeAgg() bool { + // Since all of the AggFunc share the same AggMode, we only need to check the first one. + return la.AggFuncs[0].Mode == aggregation.CompleteMode +} + // GetGroupByCols returns the groupByCols. If the groupByCols haven't be collected, // this method would collect them at first. If the GroupByItems have been changed, // we should explicitly collect GroupByColumns before this method.