diff --git a/planner/core/exhaust_physical_plans.go b/planner/core/exhaust_physical_plans.go index 4dc7f55cf63e8..e1d54ba920633 100644 --- a/planner/core/exhaust_physical_plans.go +++ b/planner/core/exhaust_physical_plans.go @@ -2144,6 +2144,10 @@ func (p *baseLogicalPlan) exhaustPhysicalPlans(_ *property.PhysicalProperty) ([] // canPushToCop checks if it can be pushed to some stores. For TiKV, it only checks datasource. // For TiFlash, it will check whether the operator is supported, but note that the check might be inaccurate. func (p *baseLogicalPlan) canPushToCop(storeTp kv.StoreType) bool { + return p.canPushToCopImpl(storeTp, false) +} + +func (p *baseLogicalPlan) canPushToCopImpl(storeTp kv.StoreType, considerDual bool) bool { ret := true for _, ch := range p.children { switch c := ch.(type) { @@ -2155,7 +2159,21 @@ func (p *baseLogicalPlan) canPushToCop(storeTp kv.StoreType) bool { } } ret = ret && validDs - case *LogicalAggregation, *LogicalProjection, *LogicalSelection, *LogicalJoin, *LogicalUnionAll: + case *LogicalUnionAll: + if storeTp == kv.TiFlash { + ret = ret && c.canPushToCopImpl(storeTp, true) + } else { + return false + } + case *LogicalProjection: + if storeTp == kv.TiFlash { + ret = ret && c.canPushToCopImpl(storeTp, considerDual) + } else { + return false + } + case *LogicalTableDual: + return storeTp == kv.TiFlash && considerDual + case *LogicalAggregation, *LogicalSelection, *LogicalJoin: if storeTp == kv.TiFlash { ret = ret && c.canPushToCop(storeTp) } else { @@ -2533,7 +2551,7 @@ func (p *LogicalUnionAll) exhaustPhysicalPlans(prop *property.PhysicalProperty) if prop.TaskTp == property.MppTaskType && prop.PartitionTp != property.AnyType { return nil, true } - canUseMpp := p.ctx.GetSessionVars().AllowMPPExecution && p.canPushToCop(kv.TiFlash) + canUseMpp := p.ctx.GetSessionVars().AllowMPPExecution && p.canPushToCopImpl(kv.TiFlash, true) chReqProps := make([]*property.PhysicalProperty, 0, len(p.children)) for range p.children { if canUseMpp 
&& prop.TaskTp == property.MppTaskType { diff --git a/planner/core/find_best_task.go b/planner/core/find_best_task.go index 8f860b4fb90e1..19dbfbcbf748b 100644 --- a/planner/core/find_best_task.go +++ b/planner/core/find_best_task.go @@ -141,7 +141,7 @@ func (p *LogicalTableDual) findBestTask(prop *property.PhysicalProperty, planCou }.Init(p.ctx, p.stats, p.blockOffset) dual.SetSchema(p.schema) planCounter.Dec(1) - return &rootTask{p: dual}, 1, nil + return &rootTask{p: dual, isEmpty: p.RowCount == 0}, 1, nil } func (p *LogicalShow) findBestTask(prop *property.PhysicalProperty, planCounter *PlanCounterTp) (task, int64, error) { diff --git a/planner/core/integration_test.go b/planner/core/integration_test.go index fc44a6f5f2a4d..c89934a83dfb7 100644 --- a/planner/core/integration_test.go +++ b/planner/core/integration_test.go @@ -3052,6 +3052,45 @@ func (s *testIntegrationSerialSuite) TestPushDownAggForMPP(c *C) { } } +func (s *testIntegrationSerialSuite) TestMppUnionAll(c *C) { + tk := testkit.NewTestKit(c, s.store) + tk.MustExec("use test") + tk.MustExec("drop table if exists t") + tk.MustExec("drop table if exists t1") + tk.MustExec("create table t (a int not null, b int, c varchar(20))") + tk.MustExec("create table t1 (a int, b int not null, c double)") + + // Create virtual tiflash replica info. 
+ dom := domain.GetDomain(tk.Se) + is := dom.InfoSchema() + db, exists := is.SchemaByName(model.NewCIStr("test")) + c.Assert(exists, IsTrue) + for _, tblInfo := range db.Tables { + if tblInfo.Name.L == "t" || tblInfo.Name.L == "t1" { + tblInfo.TiFlashReplica = &model.TiFlashReplicaInfo{ + Count: 1, + Available: true, + } + } + } + + var input []string + var output []struct { + SQL string + Plan []string + } + s.testData.GetTestCases(c, &input, &output) + for i, tt := range input { + s.testData.OnRecord(func() { + output[i].SQL = tt + output[i].Plan = s.testData.ConvertRowsToStrings(tk.MustQuery(tt).Rows()) + }) + res := tk.MustQuery(tt) + res.Check(testkit.Rows(output[i].Plan...)) + } + +} + func (s *testIntegrationSerialSuite) TestMppJoinDecimal(c *C) { tk := testkit.NewTestKit(c, s.store) tk.MustExec("use test") diff --git a/planner/core/task.go b/planner/core/task.go index 835c4c21c70b8..78c0df1cd402e 100644 --- a/planner/core/task.go +++ b/planner/core/task.go @@ -1027,8 +1027,10 @@ func setTableScanToTableRowIDScan(p PhysicalPlan) { // rootTask is the final sink node of a plan graph. It should be a single goroutine on tidb. type rootTask struct { - p PhysicalPlan - cst float64 + p PhysicalPlan + cst float64 + isEmpty bool // isEmpty indicates if this task contains a dual table and returns empty data. + // TODO: The flag 'isEmpty' is only checked by Projection and UnionAll. We should support more cases in the future. 
} func (t *rootTask) copy() task { @@ -1273,6 +1275,9 @@ func (p *PhysicalProjection) attach2Task(tasks ...task) task { t = t.convertToRootTask(p.ctx) t = attachPlan2Task(p, t) t.addCost(p.GetCost(t.count())) + if root, ok := tasks[0].(*rootTask); ok && root.isEmpty { + t.(*rootTask).isEmpty = true + } return t } @@ -1287,18 +1292,25 @@ func (p *PhysicalUnionAll) attach2MppTasks(tasks ...task) task { childMaxCost = childCost } childPlans = append(childPlans, mpp.plan()) + } else if root, ok := tk.(*rootTask); ok && root.isEmpty { + continue } else { return invalidTask } } + if len(childPlans) == 0 { + return invalidTask + } p.SetChildren(childPlans...) t.cst = childMaxCost return t } func (p *PhysicalUnionAll) attach2Task(tasks ...task) task { - if _, ok := tasks[0].(*mppTask); ok { - return p.attach2MppTasks(tasks...) + for _, t := range tasks { + if _, ok := t.(*mppTask); ok { + return p.attach2MppTasks(tasks...) + } } t := &rootTask{p: p} childPlans := make([]PhysicalPlan, 0, len(tasks)) diff --git a/planner/core/testdata/integration_serial_suite_in.json b/planner/core/testdata/integration_serial_suite_in.json index c064aa3b58edc..887a4e9afd390 100644 --- a/planner/core/testdata/integration_serial_suite_in.json +++ b/planner/core/testdata/integration_serial_suite_in.json @@ -209,6 +209,16 @@ "desc format = 'brief' select id from t as A where not exists (select 1 from t where t.id=A.id)" ] }, + { + "name": "TestMppUnionAll", + "cases": [ + "explain format = 'brief' select count(*) from (select a , b from t union all select a , b from t1) tt", + "explain format = 'brief' select count(*) from (select a , b from t union all select a , b from t1 union all select a, b from t where false) tt", + "explain format = 'brief' select count(*) from (select a , b from t union all select a , b from t1) tt", + "explain format = 'brief' select count(*) from (select a , b from t union all select a , b from t1 where false) tt", + "explain format = 'brief' select count(*) from 
(select a , b from t where false union all select a , b from t1 where false) tt" + ] + }, { "name": "TestMppJoinDecimal", "cases": [ diff --git a/planner/core/testdata/integration_serial_suite_out.json b/planner/core/testdata/integration_serial_suite_out.json index df1a3c0a729b1..3356a535d59cd 100644 --- a/planner/core/testdata/integration_serial_suite_out.json +++ b/planner/core/testdata/integration_serial_suite_out.json @@ -1585,6 +1585,74 @@ } ] }, + { + "Name": "TestMppUnionAll", + "Cases": [ + { + "SQL": "explain format = 'brief' select count(*) from (select a , b from t union all select a , b from t1) tt", + "Plan": [ + "HashAgg 1.00 root funcs:count(Column#12)->Column#11", + "└─TableReader 1.00 root data:ExchangeSender", + " └─ExchangeSender 1.00 cop[tiflash] ExchangeType: PassThrough", + " └─HashAgg 1.00 cop[tiflash] funcs:count(1)->Column#12", + " └─Union 20000.00 cop[tiflash] ", + " ├─Projection 10000.00 cop[tiflash] cast(test.t.a, int(11) BINARY)->Column#9, test.t.b", + " │ └─TableFullScan 10000.00 cop[tiflash] table:t keep order:false, stats:pseudo", + " └─Projection 10000.00 cop[tiflash] test.t1.a, cast(test.t1.b, int(11) BINARY)->Column#10", + " └─TableFullScan 10000.00 cop[tiflash] table:t1 keep order:false, stats:pseudo" + ] + }, + { + "SQL": "explain format = 'brief' select count(*) from (select a , b from t union all select a , b from t1 union all select a, b from t where false) tt", + "Plan": [ + "HashAgg 1.00 root funcs:count(Column#16)->Column#15", + "└─TableReader 1.00 root data:ExchangeSender", + " └─ExchangeSender 1.00 cop[tiflash] ExchangeType: PassThrough", + " └─HashAgg 1.00 cop[tiflash] funcs:count(1)->Column#16", + " └─Union 20000.00 cop[tiflash] ", + " ├─Projection 10000.00 cop[tiflash] cast(test.t.a, int(11) BINARY)->Column#13, test.t.b", + " │ └─TableFullScan 10000.00 cop[tiflash] table:t keep order:false, stats:pseudo", + " └─Projection 10000.00 cop[tiflash] test.t1.a, cast(test.t1.b, int(11) BINARY)->Column#14", + " └─TableFullScan 
10000.00 cop[tiflash] table:t1 keep order:false, stats:pseudo" + ] + }, + { + "SQL": "explain format = 'brief' select count(*) from (select a , b from t union all select a , b from t1) tt", + "Plan": [ + "HashAgg 1.00 root funcs:count(Column#12)->Column#11", + "└─TableReader 1.00 root data:ExchangeSender", + " └─ExchangeSender 1.00 cop[tiflash] ExchangeType: PassThrough", + " └─HashAgg 1.00 cop[tiflash] funcs:count(1)->Column#12", + " └─Union 20000.00 cop[tiflash] ", + " ├─Projection 10000.00 cop[tiflash] cast(test.t.a, int(11) BINARY)->Column#9, test.t.b", + " │ └─TableFullScan 10000.00 cop[tiflash] table:t keep order:false, stats:pseudo", + " └─Projection 10000.00 cop[tiflash] test.t1.a, cast(test.t1.b, int(11) BINARY)->Column#10", + " └─TableFullScan 10000.00 cop[tiflash] table:t1 keep order:false, stats:pseudo" + ] + }, + { + "SQL": "explain format = 'brief' select count(*) from (select a , b from t union all select a , b from t1 where false) tt", + "Plan": [ + "HashAgg 1.00 root funcs:count(Column#12)->Column#11", + "└─TableReader 1.00 root data:ExchangeSender", + " └─ExchangeSender 1.00 batchCop[tiflash] ExchangeType: PassThrough", + " └─HashAgg 1.00 batchCop[tiflash] funcs:count(1)->Column#12", + " └─Union 10000.00 batchCop[tiflash] ", + " └─Projection 10000.00 batchCop[tiflash] cast(test.t.a, int(11) BINARY)->Column#9, test.t.b", + " └─TableFullScan 10000.00 batchCop[tiflash] table:t keep order:false, stats:pseudo" + ] + }, + { + "SQL": "explain format = 'brief' select count(*) from (select a , b from t where false union all select a , b from t1 where false) tt", + "Plan": [ + "StreamAgg 1.00 root funcs:count(1)->Column#11", + "└─Union 0.00 root ", + " ├─TableDual 0.00 root rows:0", + " └─TableDual 0.00 root rows:0" + ] + } + ] + }, { "Name": "TestMppJoinDecimal", "Cases": [