diff --git a/planner/core/rule_column_pruning.go b/planner/core/rule_column_pruning.go index e118738a9de0c..9eeb6feb2d086 100644 --- a/planner/core/rule_column_pruning.go +++ b/planner/core/rule_column_pruning.go @@ -89,6 +89,7 @@ func (la *LogicalAggregation) PruneColumns(parentUsedCols []*expression.Column) used := expression.GetUsedList(parentUsedCols, la.Schema()) allFirstRow := true + allRemainFirstRow := true for i := len(used) - 1; i >= 0; i-- { if la.AggFuncs[i].Name != ast.AggFuncFirstRow { allFirstRow = false @@ -96,6 +97,8 @@ func (la *LogicalAggregation) PruneColumns(parentUsedCols []*expression.Column) if !used[i] { la.schema.Columns = append(la.schema.Columns[:i], la.schema.Columns[i+1:]...) la.AggFuncs = append(la.AggFuncs[:i], la.AggFuncs[i+1:]...) + } else if la.AggFuncs[i].Name != ast.AggFuncFirstRow { + allRemainFirstRow = false } } var selfUsedCols []*expression.Column @@ -106,7 +109,7 @@ func (la *LogicalAggregation) PruneColumns(parentUsedCols []*expression.Column) aggrFunc.OrderByItems, cols = pruneByItems(aggrFunc.OrderByItems) selfUsedCols = append(selfUsedCols, cols...) } - if len(la.AggFuncs) == 0 { + if len(la.AggFuncs) == 0 || (!allFirstRow && allRemainFirstRow) { // If all the aggregate functions are pruned, we should add an aggregate function to maintain the info of row numbers. // For all the aggregate functions except `first_row`, if we have an empty table defined as t(a,b), // `select agg(a) from t` would always return one row, while `select agg(a) from t group by b` would return empty. @@ -121,12 +124,12 @@ func (la *LogicalAggregation) PruneColumns(parentUsedCols []*expression.Column) if err != nil { return err } - la.AggFuncs = []*aggregation.AggFuncDesc{newAgg} + la.AggFuncs = append(la.AggFuncs, newAgg) col := &expression.Column{ UniqueID: la.ctx.GetSessionVars().AllocPlanColumnID(), RetType: newAgg.RetTp, } - la.schema.Columns = []*expression.Column{col} + la.schema.Columns = append(la.schema.Columns, col) } if len(la.GroupByItems) > 0 { diff --git a/planner/core/testdata/integration_serial_suite_out.json b/planner/core/testdata/integration_serial_suite_out.json index 3356a535d59cd..a00ee8a2dd8cb 100644 --- a/planner/core/testdata/integration_serial_suite_out.json +++ b/planner/core/testdata/integration_serial_suite_out.json @@ -2046,14 +2046,14 @@ " │ └─Selection 9990.00 batchCop[tiflash] not(isnull(test.t.id))", " │ └─TableFullScan 10000.00 batchCop[tiflash] table:t keep order:false, stats:pseudo", " └─Projection(Probe) 7992.00 batchCop[tiflash] Column#11, test.t.id", - " └─HashAgg 7992.00 batchCop[tiflash] group by:Column#32, funcs:sum(Column#30)->Column#11, funcs:firstrow(Column#31)->test.t.id", - " └─Projection 9990.00 batchCop[tiflash] cast(test.t.id, decimal(32,0) BINARY)->Column#30, test.t.id, test.t.id", + " └─HashAgg 7992.00 batchCop[tiflash] group by:Column#39, funcs:sum(Column#37)->Column#11, funcs:firstrow(Column#38)->test.t.id", + " └─Projection 9990.00 batchCop[tiflash] cast(test.t.id, decimal(32,0) BINARY)->Column#37, test.t.id, test.t.id", " └─HashJoin 9990.00 batchCop[tiflash] inner join, equal:[eq(test.t.id, test.t.id)]", - " ├─Projection(Build) 7992.00 batchCop[tiflash] test.t.id", - " │ └─HashAgg 7992.00 batchCop[tiflash] group by:test.t.id, funcs:firstrow(test.t.id)->test.t.id", + " ├─Projection(Build) 7992.00 batchCop[tiflash] test.t.id, Column#13", + " │ └─HashAgg 7992.00 batchCop[tiflash] group by:test.t.id, funcs:firstrow(test.t.id)->test.t.id, funcs:sum(Column#17)->Column#13", " │ └─ExchangeReceiver 7992.00 batchCop[tiflash] ", " │ └─ExchangeSender 7992.00 batchCop[tiflash] ExchangeType: HashPartition, Hash Cols: test.t.id", - " │ └─HashAgg 7992.00 batchCop[tiflash] group by:test.t.id, ", + " │ └─HashAgg 7992.00 batchCop[tiflash] group by:test.t.id, funcs:count(1)->Column#17", " │ └─Selection 9990.00 batchCop[tiflash] not(isnull(test.t.id))", " │ └─TableFullScan 10000.00 batchCop[tiflash] table:t keep order:false, stats:pseudo", " └─ExchangeReceiver(Probe) 9990.00 batchCop[tiflash] ", @@ -2301,19 +2301,19 @@ " │ └─Selection 9990.00 batchCop[tiflash] not(isnull(test.t.id))", " │ └─TableFullScan 10000.00 batchCop[tiflash] table:t keep order:false, stats:pseudo", " └─Projection(Probe) 7992.00 batchCop[tiflash] Column#11, test.t.id", - " └─HashAgg 7992.00 batchCop[tiflash] group by:test.t.id, funcs:sum(Column#14)->Column#11, funcs:firstrow(test.t.id)->test.t.id", + " └─HashAgg 7992.00 batchCop[tiflash] group by:test.t.id, funcs:sum(Column#17)->Column#11, funcs:firstrow(test.t.id)->test.t.id", " └─ExchangeReceiver 7992.00 batchCop[tiflash] ", " └─ExchangeSender 7992.00 batchCop[tiflash] ExchangeType: HashPartition, Hash Cols: test.t.id", - " └─HashAgg 7992.00 batchCop[tiflash] group by:Column#27, funcs:sum(Column#26)->Column#14", - " └─Projection 9990.00 batchCop[tiflash] cast(test.t.id, decimal(32,0) BINARY)->Column#26, test.t.id", + " └─HashAgg 7992.00 batchCop[tiflash] group by:Column#33, funcs:sum(Column#32)->Column#17", + " └─Projection 9990.00 batchCop[tiflash] cast(test.t.id, decimal(32,0) BINARY)->Column#32, test.t.id", " └─HashJoin 9990.00 batchCop[tiflash] inner join, equal:[eq(test.t.id, test.t.id)]", " ├─ExchangeReceiver(Build) 7992.00 batchCop[tiflash] ", " │ └─ExchangeSender 7992.00 batchCop[tiflash] ExchangeType: Broadcast", - " │ └─Projection 7992.00 batchCop[tiflash] test.t.id", - " │ └─HashAgg 7992.00 batchCop[tiflash] group by:test.t.id, funcs:firstrow(test.t.id)->test.t.id", + " │ └─Projection 7992.00 batchCop[tiflash] test.t.id, Column#13", + " │ └─HashAgg 7992.00 batchCop[tiflash] group by:test.t.id, funcs:firstrow(test.t.id)->test.t.id, funcs:sum(Column#16)->Column#13", " │ └─ExchangeReceiver 7992.00 batchCop[tiflash] ", " │ └─ExchangeSender 7992.00 batchCop[tiflash] ExchangeType: HashPartition, Hash Cols: test.t.id", - " │ └─HashAgg 7992.00 batchCop[tiflash] group by:test.t.id, ", + " │ └─HashAgg 7992.00 batchCop[tiflash] group by:test.t.id, funcs:count(1)->Column#16", " │ └─Selection 9990.00 batchCop[tiflash] not(isnull(test.t.id))", " │ └─TableFullScan 10000.00 batchCop[tiflash] table:t keep order:false, stats:pseudo", " └─Selection(Probe) 9990.00 batchCop[tiflash] not(isnull(test.t.id))", diff --git a/planner/core/testdata/integration_suite_in.json b/planner/core/testdata/integration_suite_in.json index 3ba7037a3bc1d..f89bfa5d80059 100644 --- a/planner/core/testdata/integration_suite_in.json +++ b/planner/core/testdata/integration_suite_in.json @@ -31,7 +31,8 @@ "select count(1) from t join (select max(a) from t where false group by a) as tmp", "select count(1) from t join (select min(a) from t where false group by a) as tmp", "select count(1) from t join (select sum(a) from t where false group by a) as tmp", - "select count(1) from t join (select avg(a) from t where false group by a) as tmp" + "select count(1) from t join (select avg(a) from t where false group by a) as tmp", + "SELECT avg(2) FROM(SELECT min(c) FROM t JOIN(SELECT 1 c) d ORDER BY a) e" ] }, { diff --git a/planner/core/testdata/integration_suite_out.json b/planner/core/testdata/integration_suite_out.json index 6cbd743fb6e73..b271e0b07da71 100644 --- a/planner/core/testdata/integration_suite_out.json +++ b/planner/core/testdata/integration_suite_out.json @@ -126,6 +126,12 @@ "Res": [ "0" ] + }, + { + "SQL": "SELECT avg(2) FROM(SELECT min(c) FROM t JOIN(SELECT 1 c) d ORDER BY a) e", + "Res": [ + "2.0000" + ] } ] },