diff --git a/executor/aggregate.go b/executor/aggregate.go
index 61c55b5a3584d..b6595ba8f01ae 100644
--- a/executor/aggregate.go
+++ b/executor/aggregate.go
@@ -90,8 +90,9 @@ type HashAggFinalWorker struct {
 
 // AfFinalResult indicates aggregation functions final result.
 type AfFinalResult struct {
-	chk *chunk.Chunk
-	err error
+	chk        *chunk.Chunk
+	err        error
+	giveBackCh chan *chunk.Chunk
 }
 
 // HashAggExec deals with all the aggregate functions.
@@ -150,7 +151,6 @@ type HashAggExec struct {
 
 	finishCh         chan struct{}
 	finalOutputCh    chan *AfFinalResult
-	finalInputCh     chan *chunk.Chunk
 	partialOutputChs []chan *HashAggIntermData
 	inputCh          chan *HashAggInput
 	partialInputChs  []chan *chunk.Chunk
@@ -271,7 +271,6 @@ func (e *HashAggExec) initForParallelExec(ctx sessionctx.Context) {
 	partialConcurrency := sessionVars.HashAggPartialConcurrency
 	e.isChildReturnEmpty = true
 	e.finalOutputCh = make(chan *AfFinalResult, finalConcurrency)
-	e.finalInputCh = make(chan *chunk.Chunk, finalConcurrency)
 	e.inputCh = make(chan *HashAggInput, partialConcurrency)
 	e.finishCh = make(chan struct{}, 1)
 
@@ -316,11 +315,12 @@ func (e *HashAggExec) initForParallelExec(ctx sessionctx.Context) {
 			groupSet:            set.NewStringSet(),
 			inputCh:             e.partialOutputChs[i],
 			outputCh:            e.finalOutputCh,
-			finalResultHolderCh: e.finalInputCh,
+			finalResultHolderCh: make(chan *chunk.Chunk, 1),
 			rowBuffer:           make([]types.Datum, 0, e.Schema().Len()),
 			mutableRow:          chunk.MutRowFromTypes(retTypes(e)),
 			groupKeys:           make([][]byte, 0, 8),
 		}
+		e.finalWorkers[i].finalResultHolderCh <- newFirstChunk(e)
 	}
 }
 
@@ -540,14 +540,14 @@ func (w *HashAggFinalWorker) getFinalResult(sctx sessionctx.Context) {
 			result.SetNumVirtualRows(result.NumRows() + 1)
 		}
 		if result.IsFull() {
-			w.outputCh <- &AfFinalResult{chk: result}
+			w.outputCh <- &AfFinalResult{chk: result, giveBackCh: w.finalResultHolderCh}
 			result, finished = w.receiveFinalResultHolder()
 			if finished {
 				return
 			}
 		}
 	}
-	w.outputCh <- &AfFinalResult{chk: result}
+	w.outputCh <- &AfFinalResult{chk: result, giveBackCh: w.finalResultHolderCh}
 }
 
 func (w *HashAggFinalWorker) receiveFinalResultHolder() (*chunk.Chunk, bool) {
@@ -668,28 +668,26 @@ func (e *HashAggExec) parallelExec(ctx context.Context, chk *chunk.Chunk) error
 	if e.executed {
 		return nil
 	}
-	for !chk.IsFull() {
-		e.finalInputCh <- chk
+	for {
 		result, ok := <-e.finalOutputCh
-		if !ok { // all finalWorkers exited
+		if !ok {
 			e.executed = true
-			if chk.NumRows() > 0 { // but there are some data left
-				return nil
-			}
 			if e.isChildReturnEmpty && e.defaultVal != nil {
 				chk.Append(e.defaultVal, 0, 1)
 			}
-			e.isChildReturnEmpty = false
 			return nil
 		}
 		if result.err != nil {
 			return result.err
 		}
+		chk.SwapColumns(result.chk)
+		result.chk.Reset()
+		result.giveBackCh <- result.chk
 		if chk.NumRows() > 0 {
 			e.isChildReturnEmpty = false
+			return nil
 		}
 	}
-	return nil
 }
 
 // unparallelExec executes hash aggregation algorithm in single thread.
diff --git a/executor/executor_required_rows_test.go b/executor/executor_required_rows_test.go
index 76dcb742c1796..e3896f4d66ddd 100644
--- a/executor/executor_required_rows_test.go
+++ b/executor/executor_required_rows_test.go
@@ -677,67 +677,6 @@ func (s *testExecSuite) TestStreamAggRequiredRows(c *C) {
 	}
 }
 
-func (s *testExecSuite) TestHashAggParallelRequiredRows(c *C) {
-	maxChunkSize := defaultCtx().GetSessionVars().MaxChunkSize
-	testCases := []struct {
-		totalRows      int
-		aggFunc        string
-		requiredRows   []int
-		expectedRows   []int
-		expectedRowsDS []int
-		gen            func(valType *types.FieldType) interface{}
-	}{
-		{
-			totalRows:      maxChunkSize,
-			aggFunc:        ast.AggFuncSum,
-			requiredRows:   []int{1, 2, 3, 4, 5, 6, 7},
-			expectedRows:   []int{1, 2, 3, 4, 5, 6, 7},
-			expectedRowsDS: []int{maxChunkSize, 0},
-			gen:            divGenerator(1),
-		},
-		{
-			totalRows:      maxChunkSize * 3,
-			aggFunc:        ast.AggFuncAvg,
-			requiredRows:   []int{1, 3},
-			expectedRows:   []int{1, 2},
-			expectedRowsDS: []int{maxChunkSize, maxChunkSize, maxChunkSize, 0},
-			gen:            divGenerator(maxChunkSize),
-		},
-		{
-			totalRows:      maxChunkSize * 3,
-			aggFunc:        ast.AggFuncAvg,
-			requiredRows:   []int{maxChunkSize, maxChunkSize},
-			expectedRows:   []int{maxChunkSize, maxChunkSize / 2},
-			expectedRowsDS: []int{maxChunkSize, maxChunkSize, maxChunkSize, 0},
-			gen:            divGenerator(2),
-		},
-	}
-
-	for _, hasDistinct := range []bool{false, true} {
-		for _, testCase := range testCases {
-			sctx := defaultCtx()
-			ctx := context.Background()
-			ds := newRequiredRowsDataSourceWithGenerator(sctx, testCase.totalRows, testCase.expectedRowsDS, testCase.gen)
-			childCols := ds.Schema().Columns
-			schema := expression.NewSchema(childCols...)
-			groupBy := []expression.Expression{childCols[1]}
-			aggFunc, err := aggregation.NewAggFuncDesc(sctx, testCase.aggFunc, []expression.Expression{childCols[0]}, hasDistinct)
-			c.Assert(err, IsNil)
-			aggFuncs := []*aggregation.AggFuncDesc{aggFunc}
-			exec := buildHashAggExecutor(sctx, ds, schema, aggFuncs, groupBy)
-			c.Assert(exec.Open(ctx), IsNil)
-			chk := newFirstChunk(exec)
-			for i := range testCase.requiredRows {
-				chk.SetRequiredRows(testCase.requiredRows[i], maxChunkSize)
-				c.Assert(exec.Next(ctx, chk), IsNil)
-				c.Assert(chk.NumRows(), Equals, testCase.expectedRows[i])
-			}
-			c.Assert(exec.Close(), IsNil)
-			c.Assert(ds.checkNumNextCalled(), IsNil)
-		}
-	}
-}
-
 func (s *testExecSuite) TestMergeJoinRequiredRows(c *C) {
 	justReturn1 := func(valType *types.FieldType) interface{} {
 		switch valType.Tp {
diff --git a/planner/cascades/implementation_rules.go b/planner/cascades/implementation_rules.go
index b54e12f2f60f8..762708ffed7a8 100644
--- a/planner/cascades/implementation_rules.go
+++ b/planner/cascades/implementation_rules.go
@@ -74,6 +74,12 @@ var defaultImplementationMap = map[memo.Operand][]ImplementationRule{
 	memo.OperandUnionAll: {
 		&ImplUnionAll{},
 	},
+	memo.OperandApply: {
+		&ImplApply{},
+	},
+	memo.OperandMaxOneRow: {
+		&ImplMaxOneRow{},
+	},
 }
 
 // ImplTableDual implements LogicalTableDual as PhysicalTableDual.
@@ -467,3 +473,49 @@ func (r *ImplUnionAll) OnImplement(expr *memo.GroupExpr, reqProp *property.Physi
 	physicalUnion.SetSchema(expr.Group.Prop.Schema)
 	return impl.NewUnionAllImpl(physicalUnion), nil
 }
+
+// ImplApply implements LogicalApply to PhysicalApply
+type ImplApply struct {
+}
+
+// Match implements ImplementationRule Match interface.
+func (r *ImplApply) Match(expr *memo.GroupExpr, prop *property.PhysicalProperty) (matched bool) {
+	return prop.AllColsFromSchema(expr.Children[0].Prop.Schema)
+}
+
+// OnImplement implements ImplementationRule OnImplement interface
+func (r *ImplApply) OnImplement(expr *memo.GroupExpr, reqProp *property.PhysicalProperty) (memo.Implementation, error) {
+	la := expr.ExprNode.(*plannercore.LogicalApply)
+	join := la.GetHashJoin(reqProp)
+	physicalApply := plannercore.PhysicalApply{
+		PhysicalHashJoin: *join,
+		OuterSchema:      la.CorCols,
+	}.Init(
+		la.SCtx(),
+		expr.Group.Prop.Stats.ScaleByExpectCnt(reqProp.ExpectedCnt),
+		la.SelectBlockOffset(),
+		&property.PhysicalProperty{ExpectedCnt: math.MaxFloat64, Items: reqProp.Items},
+		&property.PhysicalProperty{ExpectedCnt: math.MaxFloat64})
+	physicalApply.SetSchema(expr.Group.Prop.Schema)
+	return impl.NewApplyImpl(physicalApply), nil
+}
+
+// ImplMaxOneRow implements LogicalMaxOneRow to PhysicalMaxOneRow.
+type ImplMaxOneRow struct {
+}
+
+// Match implements ImplementationRule Match interface.
+func (r *ImplMaxOneRow) Match(expr *memo.GroupExpr, prop *property.PhysicalProperty) (matched bool) {
+	return prop.IsEmpty()
+}
+
+// OnImplement implements ImplementationRule OnImplement interface
+func (r *ImplMaxOneRow) OnImplement(expr *memo.GroupExpr, reqProp *property.PhysicalProperty) (memo.Implementation, error) {
+	mor := expr.ExprNode.(*plannercore.LogicalMaxOneRow)
+	physicalMaxOneRow := plannercore.PhysicalMaxOneRow{}.Init(
+		mor.SCtx(),
+		expr.Group.Prop.Stats,
+		mor.SelectBlockOffset(),
+		&property.PhysicalProperty{ExpectedCnt: 2})
+	return impl.NewMaxOneRowImpl(physicalMaxOneRow), nil
+}
diff --git a/planner/cascades/integration_test.go b/planner/cascades/integration_test.go
index 8e29d97544001..3561179606516 100644
--- a/planner/cascades/integration_test.go
+++ b/planner/cascades/integration_test.go
@@ -221,3 +221,29 @@ func (s *testIntegrationSuite) TestJoin(c *C) {
 		tk.MustQuery(sql).Check(testkit.Rows(output[i].Result...))
 	}
 }
+
+func (s *testIntegrationSuite) TestApply(c *C) {
+	tk := testkit.NewTestKitWithInit(c, s.store)
+	tk.MustExec("drop table if exists t1, t2")
+	tk.MustExec("create table t1(a int primary key, b int)")
+	tk.MustExec("create table t2(a int primary key, b int)")
+	tk.MustExec("insert into t1 values (1, 11), (4, 44), (2, 22), (3, 33)")
+	tk.MustExec("insert into t2 values (1, 11), (2, 22), (3, 33)")
+	tk.MustExec("set session tidb_enable_cascades_planner = 1")
+	var input []string
+	var output []struct {
+		SQL    string
+		Plan   []string
+		Result []string
+	}
+	s.testData.GetTestCases(c, &input, &output)
+	for i, sql := range input {
+		s.testData.OnRecord(func() {
+			output[i].SQL = sql
+			output[i].Plan = s.testData.ConvertRowsToStrings(tk.MustQuery("explain " + sql).Rows())
+			output[i].Result = s.testData.ConvertRowsToStrings(tk.MustQuery(sql).Rows())
+		})
+		tk.MustQuery("explain " + sql).Check(testkit.Rows(output[i].Plan...))
+		tk.MustQuery(sql).Check(testkit.Rows(output[i].Result...))
+	}
+}
diff --git a/planner/cascades/testdata/integration_suite_in.json b/planner/cascades/testdata/integration_suite_in.json
index e4d537155c0c1..0f3d00c3cc9f0 100644
--- a/planner/cascades/testdata/integration_suite_in.json
+++ b/planner/cascades/testdata/integration_suite_in.json
@@ -56,5 +56,12 @@
       "select t1.a, t1.b from t1, t2 where t1.a = t2.a and t1.a > 2",
       "select t1.a, t1.b from t1, t2 where t1.a > t2.a and t2.b > 200"
     ]
-  }
+  },
+  {
+    "name": "TestApply",
+    "cases": [
+      "select a = (select a from t2 where t1.b = t2.b order by a limit 1) from t1",
+      "select sum(a), (select t1.a from t1 where t1.a = t2.a limit 1), (select t1.b from t1 where t1.b = t2.b limit 1) from t2"
+    ]
+  }
 ]
diff --git a/planner/cascades/testdata/integration_suite_out.json b/planner/cascades/testdata/integration_suite_out.json
index e9bcfcc5bdf8f..4af6962bfda8a 100644
--- a/planner/cascades/testdata/integration_suite_out.json
+++ b/planner/cascades/testdata/integration_suite_out.json
@@ -406,5 +406,56 @@
         ]
       }
     ]
-  }
+  },
+  {
+    "Name": "TestApply",
+    "Cases": [
+      {
+        "SQL": "select a = (select a from t2 where t1.b = t2.b order by a limit 1) from t1",
+        "Plan": [
+          "Projection_17 10000.00 root eq(test.t1.a, test.t2.a)->Column#5",
+          "└─Apply_19 10000.00 root CARTESIAN left outer join, inner:MaxOneRow_22",
+          "  ├─TableReader_20 10000.00 root data:TableScan_21",
+          "  │ └─TableScan_21 10000.00 cop[tikv] table:t1, range:[-inf,+inf], keep order:false, stats:pseudo",
+          "  └─MaxOneRow_22 1.00 root ",
+          "    └─Projection_23 1.00 root test.t2.a",
+          "      └─Limit_25 1.00 root offset:0, count:1",
+          "        └─TableReader_29 1.00 root data:Selection_30",
+          "          └─Selection_30 1.00 cop[tikv] eq(test.t1.b, test.t2.b)",
+          "            └─TableScan_31 1.00 cop[tikv] table:t2, range:[-inf,+inf], keep order:true, stats:pseudo"
+        ],
+        "Result": [
+          "1",
+          "1",
+          "1",
+          ""
+        ]
+      },
+      {
+        "SQL": "select sum(a), (select t1.a from t1 where t1.a = t2.a limit 1), (select t1.b from t1 where t1.b = t2.b limit 1) from t2",
+        "Plan": [
+          "Projection_26 1.00 root Column#3, test.t1.a, test.t1.b",
+          "└─Apply_28 1.00 root CARTESIAN left outer join, inner:MaxOneRow_43",
+          "  ├─Apply_30 1.00 root CARTESIAN left outer join, inner:MaxOneRow_38",
+          "  │ ├─HashAgg_35 1.00 root funcs:sum(Column#8)->Column#3, funcs:firstrow(Column#9)->test.t2.a, funcs:firstrow(Column#10)->test.t2.b",
+          "  │ │ └─TableReader_36 1.00 root data:HashAgg_37",
+          "  │ │   └─HashAgg_37 1.00 cop[tikv] funcs:sum(test.t2.a)->Column#8, funcs:firstrow(test.t2.a)->Column#9, funcs:firstrow(test.t2.b)->Column#10",
+          "  │ │     └─TableScan_33 10000.00 cop[tikv] table:t2, range:[-inf,+inf], keep order:false, stats:pseudo",
+          "  │ └─MaxOneRow_38 1.00 root ",
+          "  │   └─Limit_39 1.00 root offset:0, count:1",
+          "  │     └─TableReader_40 1.00 root data:Selection_41",
+          "  │       └─Selection_41 1.00 cop[tikv] eq(test.t1.a, test.t2.a)",
+          "  │         └─TableScan_42 1.00 cop[tikv] table:t1, range:[-inf,+inf], keep order:false, stats:pseudo",
+          "  └─MaxOneRow_43 1.00 root ",
+          "    └─Limit_44 1.00 root offset:0, count:1",
+          "      └─TableReader_45 1.00 root data:Selection_46",
+          "        └─Selection_46 1.00 cop[tikv] eq(test.t1.b, test.t2.b)",
+          "          └─TableScan_47 1.00 cop[tikv] table:t1, range:[-inf,+inf], keep order:false, stats:pseudo"
+        ],
+        "Result": [
+          "6 1 11"
+        ]
+      }
+    ]
+  }
 ]
diff --git a/planner/cascades/testdata/stringer_suite_in.json b/planner/cascades/testdata/stringer_suite_in.json
index d30220f52efa9..f947b1f2d6cad 100644
--- a/planner/cascades/testdata/stringer_suite_in.json
+++ b/planner/cascades/testdata/stringer_suite_in.json
@@ -21,7 +21,9 @@
       // Order by.
       "select a from t where b > 1 order by c",
       // Union ALL.
-      "select avg(a) from t union all select avg(b) from t"
+      "select avg(a) from t union all select avg(b) from t",
+      // Apply.
+ "select a = (select a from t t2 where t1.b = t2.b order by a limit 1) from t t1" ] } ] diff --git a/planner/cascades/testdata/stringer_suite_out.json b/planner/cascades/testdata/stringer_suite_out.json index 9daac9bdc98d5..61bd297511785 100644 --- a/planner/cascades/testdata/stringer_suite_out.json +++ b/planner/cascades/testdata/stringer_suite_out.json @@ -285,6 +285,33 @@ "Group#16 Schema:[test.t.b]", " TableScan_24 table:t" ] + }, + { + "SQL": "select a = (select a from t t2 where t1.b = t2.b order by a limit 1) from t t1", + "Result": [ + "Group#0 Schema:[Column#25]", + " Projection_2 input:[Group#1], eq(test.t.a, test.t.a)->Column#25", + "Group#1 Schema:[test.t.a,test.t.b,test.t.a]", + " Apply_9 input:[Group#2,Group#3], left outer join", + "Group#2 Schema:[test.t.a,test.t.b], UniqueKey:[test.t.a]", + " TiKVSingleGather_11 input:[Group#4], table:t1", + "Group#4 Schema:[test.t.a,test.t.b], UniqueKey:[test.t.a]", + " TableScan_10 table:t1, pk col:test.t.a", + "Group#3 Schema:[test.t.a], UniqueKey:[test.t.a]", + " MaxOneRow_8 input:[Group#5]", + "Group#5 Schema:[test.t.a], UniqueKey:[test.t.a]", + " Limit_7 input:[Group#6], offset:0, count:1", + "Group#6 Schema:[test.t.a], UniqueKey:[test.t.a]", + " Sort_6 input:[Group#7], test.t.a:asc", + "Group#7 Schema:[test.t.a], UniqueKey:[test.t.a]", + " Projection_5 input:[Group#8], test.t.a", + "Group#8 Schema:[test.t.a,test.t.b], UniqueKey:[test.t.a]", + " TiKVSingleGather_13 input:[Group#9], table:t2", + "Group#9 Schema:[test.t.a,test.t.b], UniqueKey:[test.t.a]", + " Selection_14 input:[Group#10], eq(test.t.b, test.t.b)", + "Group#10 Schema:[test.t.a,test.t.b], UniqueKey:[test.t.a]", + " TableScan_12 table:t2, pk col:test.t.a" + ] } ] } diff --git a/planner/core/exhaust_physical_plans.go b/planner/core/exhaust_physical_plans.go index 9e1cbc3174018..7526eaff48c11 100644 --- a/planner/core/exhaust_physical_plans.go +++ b/planner/core/exhaust_physical_plans.go @@ -1412,15 +1412,19 @@ func (lt *LogicalTopN) exhaustPhysicalPlans(prop *property.PhysicalProperty) []P return nil } +// GetHashJoin is public for cascades planner. 
+func (la *LogicalApply) GetHashJoin(prop *property.PhysicalProperty) *PhysicalHashJoin {
+	return la.LogicalJoin.getHashJoin(prop, 1, false)
+}
+
 func (la *LogicalApply) exhaustPhysicalPlans(prop *property.PhysicalProperty) []PhysicalPlan {
 	if !prop.AllColsFromSchema(la.children[0].Schema()) { // for convenient, we don't pass through any prop
 		return nil
 	}
-	join := la.getHashJoin(prop, 1, false)
+	join := la.GetHashJoin(prop)
 	apply := PhysicalApply{
 		PhysicalHashJoin: *join,
-		OuterSchema:      la.corCols,
-		rightChOffset:    la.children[0].Schema().Len(),
+		OuterSchema:      la.CorCols,
 	}.Init(la.ctx,
 		la.stats.ScaleByExpectCnt(prop.ExpectedCnt),
 		la.blockOffset,
diff --git a/planner/core/logical_plans.go b/planner/core/logical_plans.go
index bb199b8893022..deed7bc477683 100644
--- a/planner/core/logical_plans.go
+++ b/planner/core/logical_plans.go
@@ -334,7 +334,7 @@ func (p *LogicalSelection) extractCorrelatedCols() []*expression.CorrelatedColum
 type LogicalApply struct {
 	LogicalJoin
 
-	corCols []*expression.CorrelatedColumn
+	CorCols []*expression.CorrelatedColumn
 }
 
 func (la *LogicalApply) extractCorrelatedCols() []*expression.CorrelatedColumn {
diff --git a/planner/core/physical_plans.go b/planner/core/physical_plans.go
index 23b17a1e26786..17136487828a4 100644
--- a/planner/core/physical_plans.go
+++ b/planner/core/physical_plans.go
@@ -286,8 +286,7 @@ type PhysicalTopN struct {
 type PhysicalApply struct {
 	PhysicalHashJoin
 
-	OuterSchema   []*expression.CorrelatedColumn
-	rightChOffset int
+	OuterSchema []*expression.CorrelatedColumn
 }
 
 type basePhysicalJoin struct {
diff --git a/planner/core/rule_column_pruning.go b/planner/core/rule_column_pruning.go
index c90ff5cdfe236..870c05f2e0586 100644
--- a/planner/core/rule_column_pruning.go
+++ b/planner/core/rule_column_pruning.go
@@ -316,8 +316,8 @@ func (la *LogicalApply) PruneColumns(parentUsedCols []*expression.Column) error
 		return err
 	}
 
-	la.corCols = extractCorColumnsBySchema(la.children[1], la.children[0].Schema())
-	for _, col := range la.corCols {
+	la.CorCols = extractCorColumnsBySchema(la.children[1], la.children[0].Schema())
+	for _, col := range la.CorCols {
 		leftCols = append(leftCols, &col.Column)
 	}
 
diff --git a/planner/core/rule_decorrelate.go b/planner/core/rule_decorrelate.go
index cd7bc9007f1e2..63173c688a172 100644
--- a/planner/core/rule_decorrelate.go
+++ b/planner/core/rule_decorrelate.go
@@ -102,8 +102,8 @@ func (s *decorrelateSolver) optimize(ctx context.Context, p LogicalPlan) (Logica
 	if apply, ok := p.(*LogicalApply); ok {
 		outerPlan := apply.children[0]
 		innerPlan := apply.children[1]
-		apply.corCols = extractCorColumnsBySchema(apply.children[1], apply.children[0].Schema())
-		if len(apply.corCols) == 0 {
+		apply.CorCols = extractCorColumnsBySchema(apply.children[1], apply.children[0].Schema())
+		if len(apply.CorCols) == 0 {
 			// If the inner plan is non-correlated, the apply will be simplified to join.
 			join := &apply.LogicalJoin
 			join.self = join
@@ -196,10 +196,10 @@ func (s *decorrelateSolver) optimize(ctx context.Context, p LogicalPlan) (Logica
 			if len(eqCondWithCorCol) > 0 {
 				originalExpr := sel.Conditions
 				sel.Conditions = remainedExpr
-				apply.corCols = extractCorColumnsBySchema(apply.children[1], apply.children[0].Schema())
+				apply.CorCols = extractCorColumnsBySchema(apply.children[1], apply.children[0].Schema())
 				// There's no other correlated column.
 				groupByCols := expression.NewSchema(agg.groupByCols...)
-				if len(apply.corCols) == 0 {
+				if len(apply.CorCols) == 0 {
 					join := &apply.LogicalJoin
 					join.EqualConditions = append(join.EqualConditions, eqCondWithCorCol...)
 					for _, eqCond := range eqCondWithCorCol {
@@ -242,7 +242,7 @@ func (s *decorrelateSolver) optimize(ctx context.Context, p LogicalPlan) (Logica
 					return s.optimize(ctx, p)
 				}
 				sel.Conditions = originalExpr
-				apply.corCols = extractCorColumnsBySchema(apply.children[1], apply.children[0].Schema())
+				apply.CorCols = extractCorColumnsBySchema(apply.children[1], apply.children[0].Schema())
 			}
 		}
 	} else if sort, ok := innerPlan.(*LogicalSort); ok {
diff --git a/planner/core/rule_eliminate_projection.go b/planner/core/rule_eliminate_projection.go
index 784e25dcdd62e..46817e560381c 100644
--- a/planner/core/rule_eliminate_projection.go
+++ b/planner/core/rule_eliminate_projection.go
@@ -215,7 +215,7 @@ func (p *LogicalSelection) replaceExprColumns(replace map[string]*expression.Col
 
 func (la *LogicalApply) replaceExprColumns(replace map[string]*expression.Column) {
 	la.LogicalJoin.replaceExprColumns(replace)
-	for _, coCol := range la.corCols {
+	for _, coCol := range la.CorCols {
 		dst := replace[string(coCol.Column.HashCode(nil))]
 		if dst != nil {
 			coCol.Column = *dst
diff --git a/planner/core/task.go b/planner/core/task.go
index b280a09640b13..07ec780fed065 100644
--- a/planner/core/task.go
+++ b/planner/core/task.go
@@ -166,14 +166,20 @@ func (p *PhysicalApply) attach2Task(tasks ...task) task {
 	rTask := finishCopTask(p.ctx, tasks[1].copy())
 	p.SetChildren(lTask.plan(), rTask.plan())
 	p.schema = BuildPhysicalJoinSchema(p.JoinType, p)
+	return &rootTask{
+		p:   p,
+		cst: p.GetCost(lTask.count(), rTask.count()) + lTask.cost(),
+	}
+}
+
+// GetCost computes the cost of apply operator.
+func (p *PhysicalApply) GetCost(lCount float64, rCount float64) float64 {
 	var cpuCost float64
-	lCount := lTask.count()
 	sessVars := p.ctx.GetSessionVars()
 	if len(p.LeftConditions) > 0 {
 		cpuCost += lCount * sessVars.CPUFactor
 		lCount *= selectionFactor
 	}
-	rCount := rTask.count()
 	if len(p.RightConditions) > 0 {
 		cpuCost += lCount * rCount * sessVars.CPUFactor
 		rCount *= selectionFactor
@@ -181,10 +187,7 @@ func (p *PhysicalApply) attach2Task(tasks ...task) task {
 	if len(p.EqualConditions)+len(p.OtherConditions) > 0 {
 		cpuCost += lCount * rCount * sessVars.CPUFactor
 	}
-	return &rootTask{
-		p:   p,
-		cst: cpuCost + lTask.cost(),
-	}
+	return cpuCost
 }
 
 func (p *PhysicalIndexMergeJoin) attach2Task(tasks ...task) task {
diff --git a/planner/implementation/simple_plans.go b/planner/implementation/simple_plans.go
index 7d53d2ea5d839..c32f8c4526290 100644
--- a/planner/implementation/simple_plans.go
+++ b/planner/implementation/simple_plans.go
@@ -162,10 +162,44 @@ func (impl *UnionAllImpl) CalcCost(outCount float64, children ...memo.Implementa
 	selfCost := float64(1+len(children)) * impl.plan.SCtx().GetSessionVars().ConcurrencyFactor
 	// Children of UnionAll are executed in parallel.
 	impl.cost = selfCost + childMaxCost
-	return selfCost
+	return impl.cost
 }
 
 // NewUnionAllImpl creates a new UnionAllImpl.
 func NewUnionAllImpl(union *plannercore.PhysicalUnionAll) *UnionAllImpl {
 	return &UnionAllImpl{baseImpl{plan: union}}
 }
+
+// ApplyImpl is the implementation of PhysicalApply.
+type ApplyImpl struct {
+	baseImpl
+}
+
+// CalcCost implements Implementation CalcCost interface.
+func (impl *ApplyImpl) CalcCost(outCount float64, children ...memo.Implementation) float64 {
+	apply := impl.plan.(*plannercore.PhysicalApply)
+	selfCost := apply.GetCost(children[0].GetPlan().Stats().RowCount, children[1].GetPlan().Stats().RowCount)
+	impl.cost = selfCost + children[0].GetCost()
+	return impl.cost
+}
+
+// NewApplyImpl creates a new ApplyImpl.
+func NewApplyImpl(apply *plannercore.PhysicalApply) *ApplyImpl {
+	return &ApplyImpl{baseImpl{plan: apply}}
+}
+
+// MaxOneRowImpl is the implementation of PhysicalMaxOneRow.
+type MaxOneRowImpl struct {
+	baseImpl
+}
+
+// CalcCost implements Implementation CalcCost interface.
+func (impl *MaxOneRowImpl) CalcCost(outCount float64, children ...memo.Implementation) float64 {
+	impl.cost = children[0].GetCost()
+	return impl.cost
+}
+
+// NewMaxOneRowImpl creates a new MaxOneRowImpl.
+func NewMaxOneRowImpl(maxOneRow *plannercore.PhysicalMaxOneRow) *MaxOneRowImpl {
+	return &MaxOneRowImpl{baseImpl{plan: maxOneRow}}
+}
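
Note on the executor/aggregate.go change above: each final hash-aggregation worker now owns a single result chunk, pre-filled into a capacity-1 finalResultHolderCh, and ships it to HashAggExec.parallelExec together with a giveBackCh; the main goroutine swaps the columns out, resets the chunk, and returns it, so at most one result chunk per worker is ever in flight. Below is a minimal, self-contained sketch of that give-back-channel recycling pattern. The types and names here (result, worker, holderCh) are illustrative stand-ins, not TiDB's actual chunk or executor API.

package main

import (
	"fmt"
	"sync"
)

// result pairs a data buffer with the channel it must be returned to,
// mirroring AfFinalResult.chk / AfFinalResult.giveBackCh.
type result struct {
	rows       []int
	giveBackCh chan []int
}

// worker owns exactly one reusable buffer, parked in holderCh between uses,
// mirroring HashAggFinalWorker.finalResultHolderCh (capacity 1, pre-filled).
type worker struct {
	holderCh chan []int
	outputCh chan<- result
}

func (w *worker) run(batches [][]int, wg *sync.WaitGroup) {
	defer wg.Done()
	for _, batch := range batches {
		buf := <-w.holderCh             // block until the consumer hands the buffer back
		buf = append(buf[:0], batch...) // refill the reset buffer
		w.outputCh <- result{rows: buf, giveBackCh: w.holderCh}
	}
}

func main() {
	outputCh := make(chan result, 2)
	var wg sync.WaitGroup

	data := [][]int{{1, 2}, {3, 4}, {5}}
	for i := 0; i < 2; i++ {
		w := &worker{holderCh: make(chan []int, 1), outputCh: outputCh}
		w.holderCh <- make([]int, 0, 4) // pre-fill, like newFirstChunk(e) in initForParallelExec
		wg.Add(1)
		go w.run(data, &wg)
	}
	go func() { wg.Wait(); close(outputCh) }()

	// Consumer: copy the rows out, then give the (reset) buffer back to its
	// owner, mirroring parallelExec's SwapColumns / Reset / giveBackCh steps.
	for r := range outputCh {
		out := append([]int(nil), r.rows...)
		r.giveBackCh <- r.rows[:0]
		fmt.Println(out)
	}
}

Because a worker blocks on its holder channel until its chunk is given back, the pattern bounds memory to one result buffer per final worker while still letting all workers produce results concurrently.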