Skip to content

Commit

Permalink
Optimizer: derive TopN from filter on row number (#41209)
Browse files Browse the repository at this point in the history
ref #39792
  • Loading branch information
ghazalfamilyusa authored Feb 13, 2023
1 parent b5ff518 commit 6f99eba
Show file tree
Hide file tree
Showing 10 changed files with 488 additions and 0 deletions.
25 changes: 25 additions & 0 deletions expression/util.go
Original file line number Diff line number Diff line change
Expand Up @@ -177,6 +177,31 @@ func ExtractEquivalenceColumns(result [][]Expression, exprs []Expression) [][]Ex
return result
}

// FindUpperBound looks for column < constant or column <= constant and returns both the column
// and constant. It return nil, 0 if the expression is not of this form.
// It is used by derived Top N pattern and it is put here since it looks like
// a general purpose routine. Similar routines can be added to find lower bound as well.
func FindUpperBound(expr Expression) (*Column, int64) {
scalarFunction, scalarFunctionOk := expr.(*ScalarFunction)
if scalarFunctionOk {
args := scalarFunction.GetArgs()
if len(args) == 2 {
col, colOk := args[0].(*Column)
constant, constantOk := args[1].(*Constant)
if colOk && constantOk && (scalarFunction.FuncName.L == ast.LT || scalarFunction.FuncName.L == ast.LE) {
value, valueOk := constant.Value.GetValue().(int64)
if valueOk {
if scalarFunction.FuncName.L == ast.LT {
return col, value - 1
}
return col, value
}
}
}
}
return nil, 0
}

func extractEquivalenceColumns(result [][]Expression, expr Expression) [][]Expression {
switch v := expr.(type) {
case *ScalarFunction:
Expand Down
2 changes: 2 additions & 0 deletions planner/core/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ go_library(
"rule_build_key_info.go",
"rule_column_pruning.go",
"rule_decorrelate.go",
"rule_derive_topn_from_window.go",
"rule_eliminate_projection.go",
"rule_generate_column_substitute.go",
"rule_inject_extra_projection.go",
Expand Down Expand Up @@ -202,6 +203,7 @@ go_test(
"predicate_simplification_test.go",
"prepare_test.go",
"preprocess_test.go",
"rule_derive_topn_from_window_test.go",
"rule_inject_extra_projection_test.go",
"rule_join_reorder_dp_test.go",
"rule_join_reorder_test.go",
Expand Down
1 change: 1 addition & 0 deletions planner/core/logical_plan_builder.go
Original file line number Diff line number Diff line change
Expand Up @@ -1031,6 +1031,7 @@ func (b *PlanBuilder) coalesceCommonColumns(p *LogicalJoin, leftPlan, rightPlan

func (b *PlanBuilder) buildSelection(ctx context.Context, p LogicalPlan, where ast.ExprNode, aggMapper map[*ast.AggregateFuncExpr]int) (LogicalPlan, error) {
b.optFlag |= flagPredicatePushDown
b.optFlag |= flagDeriveTopNFromWindow
if b.curClause != havingClause {
b.curClause = whereClause
}
Expand Down
5 changes: 5 additions & 0 deletions planner/core/main_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ func TestMain(m *testing.M) {
testDataMap.LoadTestSuiteData("testdata", "flat_plan_suite")
testDataMap.LoadTestSuiteData("testdata", "binary_plan_suite")
testDataMap.LoadTestSuiteData("testdata", "json_plan_suite")
testDataMap.LoadTestSuiteData("testdata", "derive_topn_from_window")

indexMergeSuiteData = testDataMap["index_merge_suite"]
planSuiteUnexportedData = testDataMap["plan_suite_unexported"]
Expand Down Expand Up @@ -139,3 +140,7 @@ func GetIndexMergeSuiteData() testdata.TestData {
func GetJSONPlanSuiteData() testdata.TestData {
return testDataMap["json_plan_suite"]
}

func GetDerivedTopNSuiteData() testdata.TestData {
return testDataMap["derive_topn_from_window"]
}
2 changes: 2 additions & 0 deletions planner/core/optimizer.go
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@ const (
flagPartitionProcessor
flagCollectPredicateColumnsPoint
flagPushDownAgg
flagDeriveTopNFromWindow
flagPushDownTopN
flagSyncWaitStatsLoadPoint
flagJoinReOrder
Expand All @@ -97,6 +98,7 @@ var optRuleList = []logicalOptRule{
&partitionProcessor{},
&collectPredicateColumnsPoint{},
&aggregationPushDownSolver{},
&deriveTopNFromWindow{},
&pushDownTopNOptimizer{},
&syncWaitStatsLoadPoint{},
&joinReOrderSolver{},
Expand Down
3 changes: 3 additions & 0 deletions planner/core/plan.go
Original file line number Diff line number Diff line change
Expand Up @@ -277,6 +277,9 @@ type LogicalPlan interface {
// pushDownTopN will push down the topN or limit operator during logical optimization.
pushDownTopN(topN *LogicalTopN, opt *logicalOptimizeOp) LogicalPlan

// deriveTopN derives an implicit TopN from a filter on row_number window function..
deriveTopN(opt *logicalOptimizeOp) LogicalPlan

// recursiveDeriveStats derives statistic info between plans.
recursiveDeriveStats(colGroups [][]*expression.Column) (*property.StatsInfo, error)

Expand Down
122 changes: 122 additions & 0 deletions planner/core/rule_derive_topn_from_window.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,122 @@
// Copyright 2023 PingCAP, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package core

import (
"context"
"fmt"

"github.com/pingcap/tidb/expression"
"github.com/pingcap/tidb/parser/ast"
"github.com/pingcap/tidb/planner/util"
)

// deriveTopNFromWindow pushes down the topN or limit. In the future we will remove the limit from `requiredProperty` in CBO phase.
type deriveTopNFromWindow struct {
}

func appendDerivedTopNTrace(topN LogicalPlan, opt *logicalOptimizeOp) {
child := topN.Children()[0]
action := func() string {
return fmt.Sprintf("%v_%v top N added below %v_%v ", topN.TP(), topN.ID(), child.TP(), child.ID())
}
reason := func() string {
return fmt.Sprintf("%v filter on row number", topN.TP())
}
opt.appendStepToCurrent(topN.ID(), topN.TP(), reason, action)
}

/*
Check the following pattern of filter over row number window function:
- Filter is simple condition of row_number < value or row_number <= value
- The window function is a simple row number
- With default frame: rows between current row and current row. Check is not necessary since
current row is only frame applicable to row number
- No partition
- Child is a data source.
*/
func windowIsTopN(p *LogicalSelection) (bool, uint64) {
// Check if child is window function.
child, isLogicalWindow := p.Children()[0].(*LogicalWindow)
if !isLogicalWindow {
return false, 0
}

if len(p.Conditions) != 1 {
return false, 0
}

// Check if filter is column < constant or column <= constant. If it is in this form find column and constant.
column, limitValue := expression.FindUpperBound(p.Conditions[0])
if column == nil || limitValue <= 0 {
return false, 0
}

// Check if filter on window function
windowColumns := child.GetWindowResultColumns()
if len(windowColumns) != 1 || !(column.Equal(p.ctx, windowColumns[0])) {
return false, 0
}

grandChild := child.Children()[0]
_, isDataSource := grandChild.(*DataSource)
if !isDataSource {
return false, 0
}
if len(child.WindowFuncDescs) == 1 && child.WindowFuncDescs[0].Name == "row_number" && len(child.PartitionBy) == 0 &&
child.Frame.Type == ast.Rows && child.Frame.Start.Type == ast.CurrentRow && child.Frame.End.Type == ast.CurrentRow {
return true, uint64(limitValue)
}
return false, 0
}

func (s *deriveTopNFromWindow) optimize(_ context.Context, p LogicalPlan, opt *logicalOptimizeOp) (LogicalPlan, error) {
return p.deriveTopN(opt), nil
}

func (s *baseLogicalPlan) deriveTopN(opt *logicalOptimizeOp) LogicalPlan {
p := s.self
for i, child := range p.Children() {
newChild := child.deriveTopN(opt)
p.SetChild(i, newChild)
}
return p
}

func (s *LogicalSelection) deriveTopN(opt *logicalOptimizeOp) LogicalPlan {
p := s.self.(*LogicalSelection)
windowIsTopN, limitValue := windowIsTopN(p)
if windowIsTopN {
child := p.Children()[0].(*LogicalWindow)
grandChild := child.Children()[0].(*DataSource)
// Build order by for derived Limit
byItems := make([]*util.ByItems, 0, len(child.OrderBy))
for _, col := range child.OrderBy {
byItems = append(byItems, &util.ByItems{Expr: col.Col, Desc: col.Desc})
}
// Build derived Limit
derivedTopN := LogicalTopN{Count: limitValue, ByItems: byItems}.Init(grandChild.ctx, grandChild.blockOffset)
derivedTopN.SetChildren(grandChild)
/* return datasource->topN->window */
child.SetChildren(derivedTopN)
appendDerivedTopNTrace(child, opt)
return child
}
return p
}

func (*deriveTopNFromWindow) name() string {
return "derive_topn_from_window"
}
85 changes: 85 additions & 0 deletions planner/core/rule_derive_topn_from_window_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
// Copyright 2023 PingCAP, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package core_test

import (
"testing"

plannercore "github.com/pingcap/tidb/planner/core"
"github.com/pingcap/tidb/testkit"
"github.com/pingcap/tidb/testkit/testdata"
)

// Rule should bot be applied
func TestPushDerivedTopnNegative(t *testing.T) {
store := testkit.CreateMockStore(t)
tk := testkit.NewTestKit(t, store)
tk.MustExec("use test")
tk.MustExec("drop table if exists employee")
tk.MustExec("create table t(a int, b int)")
tk.MustExec("insert into t values(1,1)")
tk.MustExec("insert into t values(2,1)")
tk.MustExec("insert into t values(3,2)")
tk.MustExec("insert into t values(4,2)")
tk.MustExec("insert into t values(5,2)")
var input Input
var output []struct {
SQL string
Plan []string
}
suiteData := plannercore.GetDerivedTopNSuiteData()
suiteData.LoadTestCases(t, &input, &output)
for i, sql := range input {
plan := tk.MustQuery("explain format = 'brief' " + sql)
testdata.OnRecord(func() {
output[i].SQL = sql
output[i].Plan = testdata.ConvertRowsToStrings(plan.Rows())
})
plan.Check(testkit.Rows(output[i].Plan...))
}
}

// Rule should be applied
func TestPushDerivedTopnPositive(t *testing.T) {
store := testkit.CreateMockStore(t)
tk := testkit.NewTestKit(t, store)
tk.MustExec("use test")
tk.MustExec("drop table if exists employee")
tk.MustExec("create table t(a int, b int)")
tk.MustExec("insert into t values(1,1)")
tk.MustExec("insert into t values(2,1)")
tk.MustExec("insert into t values(3,2)")
tk.MustExec("insert into t values(4,2)")
tk.MustExec("insert into t values(5,2)")
var input Input
var output []struct {
SQL string
Plan []string
Res []string
}
suiteData := plannercore.GetDerivedTopNSuiteData()
suiteData.LoadTestCases(t, &input, &output)
for i, sql := range input {
plan := tk.MustQuery("explain format = 'brief' " + sql)
res := tk.MustQuery(sql)
testdata.OnRecord(func() {
output[i].SQL = sql
output[i].Plan = testdata.ConvertRowsToStrings(plan.Rows())
output[i].Res = testdata.ConvertRowsToStrings(res.Rows())
})
plan.Check(testkit.Rows(output[i].Plan...))
res.Check(testkit.Rows(output[i].Res...))
}
}
29 changes: 29 additions & 0 deletions planner/core/testdata/derive_topn_from_window_in.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
[
{
"name": "TestPushDerivedTopnNegative",
"cases":[
"select row_number() over (partition by a) from t -- pattern missing filter on row number",
"select * from (select rank() over () as rank_order from t) DT where rank_order <= 3 -- rank not supported in pattern yet",
"select * from (select row_number() over (partition by a) as rownumber from t) DT where rownumber <= 3 -- pattern is applicable but partition by is not supported yet",
"select * from (select row_number() over () as rownumber1, row_number() over (partition by a) as rownumber2 from t) DT where rownumber1 <= 3 -- pattern not applicable with multiple window functions",
"select * from (select b, row_number() over () as rownumber from t) DT where rownumber <= 3 and b > 5 -- pattern is not applicable with complex filter on top of window",
"select * from (select b, row_number() over () as rownumber from t) DT where rownumber > 3 -- pattern is not applicable with filter is not < or <=",
"select * from (select a,b, row_number() over () as rownumber from t) DT where a > b -- pattern is not applicable with filter is not < or <=",
"select * from (select a,b, row_number() over () as rownumber from t) DT where a <= 3 -- pattern is not applicable with filter is not on row number",
"select * from (select a,b, row_number() over () as rownumber from t) DT where 3 >= rownumber -- pattern is not applicable with filter is not < or <=",
"select * from (select a,b, row_number() over () as rownumber from t) DT where rownumber <= -4 -- pattern is not applicable with filter constant negative",
"select * from (select row_number() over () as rownumber from t) DT where rownumber <= 3 and rownumber >= 2 -- pattern is not applicable with complex filter"
]
},
{
"name": "TestPushDerivedTopnPositive",
"cases":[
"select * from (select a,b, row_number() over (order by a) as rownumber from t) DT where rownumber <= 3.5 -- pattern is applicable with N rounded down to an integer",
"select * from (select row_number() over (order by a) as rownumber from t) DT where rownumber <= 3 -- pattern is applicable",
"select * from (select row_number() over (order by a) as rownumber from t) DT where rownumber < 3 -- pattern is applicable",
"select * from (select row_number() over(rows between 1 preceding and 1 following) as rownumber from t) DT where rownumber <= 3 -- pattern is applicable",
"select * from (select a,row_number() over (order by a desc) as rownumber,b from t) DT where rownumber <= 3 -- pattern is applicable",
"select count(*) from (select * from (select a,row_number() over (order by b) as rownumber,b from t) DT1 where rownumber <= 1) DT2 -- pattern is applicable"
]
}
]
Loading

0 comments on commit 6f99eba

Please sign in to comment.