diff --git a/plan/logical_plan_builder.go b/plan/logical_plan_builder.go index 3e0b0eba9c731..dfcee23e1c067 100644 --- a/plan/logical_plan_builder.go +++ b/plan/logical_plan_builder.go @@ -190,9 +190,39 @@ func extractCorColumns(expr expression.Expression) (cols []*expression.Correlate return } -func extractOnCondition(conditions []expression.Expression, left LogicalPlan, right LogicalPlan) ( +// pushDownConstExpr checks if the condition is from filter condition, if true, push it down to both +// children of join, whatever the join type is; if false, push it down to inner child of outer join, +// and both children of non-outer-join. +func (p *LogicalJoin) pushDownConstExpr(expr expression.Expression, leftCond []expression.Expression, + rightCond []expression.Expression, filterCond bool) ([]expression.Expression, []expression.Expression) { + switch p.JoinType { + case LeftOuterJoin, LeftOuterSemiJoin, AntiLeftOuterSemiJoin: + if filterCond { + leftCond = append(leftCond, expr) + // Append the expr to right join condition instead of `rightCond`, to make it able to be + // pushed down to children of join. + p.RightConditions = append(p.RightConditions, expr) + } else { + rightCond = append(rightCond, expr) + } + case RightOuterJoin: + if filterCond { + rightCond = append(rightCond, expr) + p.LeftConditions = append(p.LeftConditions, expr) + } else { + leftCond = append(leftCond, expr) + } + case SemiJoin, AntiSemiJoin, InnerJoin: + leftCond = append(leftCond, expr) + rightCond = append(rightCond, expr) + } + return leftCond, rightCond +} + +func (p *LogicalJoin) extractOnCondition(conditions []expression.Expression, filterCond bool) ( eqCond []*expression.ScalarFunction, leftCond []expression.Expression, rightCond []expression.Expression, otherCond []expression.Expression) { + left, right := p.children[0], p.children[1] for _, expr := range conditions { binop, ok := expr.(*expression.ScalarFunction) if ok && binop.FuncName.L == ast.EQ { @@ -211,6 +241,12 @@ func extractOnCondition(conditions []expression.Expression, left LogicalPlan, ri } } columns := expression.ExtractColumns(expr) + // `columns` may be empty, if the condition is like `correlated_column op constant`, or `constant`, + // push this kind of constant condition down according to join type. + if len(columns) == 0 { + leftCond, rightCond = p.pushDownConstExpr(expr, leftCond, rightCond, filterCond) + continue + } allFromLeft, allFromRight := true, true for _, col := range columns { if !left.Schema().Contains(col) { diff --git a/plan/logical_plan_test.go b/plan/logical_plan_test.go index 0cba20ace5a4c..5196d1b7fb527 100644 --- a/plan/logical_plan_test.go +++ b/plan/logical_plan_test.go @@ -654,7 +654,7 @@ func (s *testPlanSuite) TestJoinReOrder(c *C) { }, { sql: "select * from t o where o.b in (select t3.c from t t1, t t2, t t3 where t1.a = t3.a and t2.a = t3.a and t2.a = o.a and t1.a = 1)", - best: "Apply{DataScan(o)->Join{Join{DataScan(t1)->DataScan(t3)}->DataScan(t2)->Sel([eq(1, o.a)])}->Projection}->Projection", + best: "Apply{DataScan(o)->Join{Join{DataScan(t1)->Sel([eq(1, o.a)])->DataScan(t3)->Sel([eq(1, o.a)])}->DataScan(t2)->Sel([eq(1, o.a)])}->Projection}->Projection", }, } for _, tt := range tests { diff --git a/plan/logical_plans.go b/plan/logical_plans.go index 2e8d99b3a7c79..2a59c63d9b87e 100644 --- a/plan/logical_plans.go +++ b/plan/logical_plans.go @@ -170,7 +170,7 @@ func (p *LogicalJoin) columnSubstitute(schema *expression.Schema, exprs []expres } func (p *LogicalJoin) attachOnConds(onConds []expression.Expression) { - eq, left, right, other := extractOnCondition(onConds, p.children[0].(LogicalPlan), p.children[1].(LogicalPlan)) + eq, left, right, other := p.extractOnCondition(onConds, false) p.EqualConditions = append(eq, p.EqualConditions...) p.LeftConditions = append(left, p.LeftConditions...) p.RightConditions = append(right, p.RightConditions...) diff --git a/plan/predicate_push_down.go b/plan/predicate_push_down.go index 5d52e236c0a55..a3fb686407441 100644 --- a/plan/predicate_push_down.go +++ b/plan/predicate_push_down.go @@ -108,15 +108,13 @@ func (p *LogicalJoin) PredicatePushDown(predicates []expression.Expression) (ret return newJoin.PredicatePushDown(predicates) } var leftCond, rightCond []expression.Expression - leftPlan := p.children[0] - rightPlan := p.children[1] var ( equalCond []*expression.ScalarFunction leftPushCond, rightPushCond, otherCond []expression.Expression ) if p.JoinType != InnerJoin { predicates = expression.ExtractFiltersFromDNFs(p.ctx, predicates) - equalCond, leftPushCond, rightPushCond, otherCond = extractOnCondition(predicates, leftPlan, rightPlan) + equalCond, leftPushCond, rightPushCond, otherCond = p.extractOnCondition(predicates, true) } else { tempCond := make([]expression.Expression, 0, len(p.LeftConditions)+len(p.RightConditions)+len(p.EqualConditions)+len(p.OtherConditions)+len(predicates)) tempCond = append(tempCond, p.LeftConditions...) @@ -125,7 +123,7 @@ func (p *LogicalJoin) PredicatePushDown(predicates []expression.Expression) (ret tempCond = append(tempCond, p.OtherConditions...) tempCond = append(tempCond, predicates...) tempCond = expression.ExtractFiltersFromDNFs(p.ctx, tempCond) - equalCond, leftPushCond, rightPushCond, otherCond = extractOnCondition(expression.PropagateConstant(p.ctx, tempCond), leftPlan, rightPlan) + equalCond, leftPushCond, rightPushCond, otherCond = p.extractOnCondition(expression.PropagateConstant(p.ctx, tempCond), true) } switch p.JoinType { case LeftOuterJoin, LeftOuterSemiJoin, AntiLeftOuterSemiJoin: @@ -141,7 +139,7 @@ func (p *LogicalJoin) PredicatePushDown(predicates []expression.Expression) (ret ret = append(expression.ScalarFuncs2Exprs(equalCond), otherCond...) ret = append(ret, leftPushCond...) case SemiJoin, AntiSemiJoin: - _, leftPushCond, rightPushCond, _ = extractOnCondition(predicates, leftPlan, rightPlan) + _, leftPushCond, rightPushCond, _ = p.extractOnCondition(predicates, true) leftCond = append(p.LeftConditions, leftPushCond...) rightCond = append(p.RightConditions, rightPushCond...) p.LeftConditions = nil @@ -166,8 +164,8 @@ func (p *LogicalJoin) PredicatePushDown(predicates []expression.Expression) (ret for i := range rightCond { rightCond[i] = rightCond[i].Clone() } - leftRet, lCh := leftPlan.PredicatePushDown(leftCond) - rightRet, rCh := rightPlan.PredicatePushDown(rightCond) + leftRet, lCh := p.children[0].PredicatePushDown(leftCond) + rightRet, rCh := p.children[1].PredicatePushDown(rightCond) addSelection(p, lCh, leftRet, 0) addSelection(p, rCh, rightRet, 1) p.updateEQCond()