diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/CostBasedJoinReorder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/CostBasedJoinReorder.scala
index 743d3ce944fe2..6540e95b01e3f 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/CostBasedJoinReorder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/CostBasedJoinReorder.scala
@@ -31,40 +31,6 @@ import org.apache.spark.sql.internal.SQLConf
  * Cost-based join reorder.
  * We may have several join reorder algorithms in the future. This class is the entry of these
  * algorithms, and chooses which one to use.
- *
- * Note that join strategy hints, e.g. the broadcast hint, do not interfere with the reordering.
- * Such hints will be applied on the equivalent counterparts (i.e., join between the same relations
- * regardless of the join order) of the original nodes after reordering.
- * For example, the plan before reordering is like:
- *
- *           Join
- *          /    \
- *      Hint1     t4
- *        /
- *     Join
- *    /    \
- *  Join    t3
- *  /    \
- * Hint2   t2
- *   /
- *  t1
- *
- * The original join order as illustrated above is "((t1 JOIN t2) JOIN t3) JOIN t4", and after
- * reordering, the new join order is "((t1 JOIN t3) JOIN t2) JOIN t4", so the new plan will be like:
- *
- *           Join
- *          /    \
- *      Hint1     t4
- *        /
- *     Join
- *    /    \
- *  Join    t2
- *  /    \
- * t1     t3
- *
- * "Hint1" is applied on "(t1 JOIN t3) JOIN t2" as it is equivalent to the original hinted node,
- * "(t1 JOIN t2) JOIN t3"; while "Hint2" has disappeared from the new plan since there is no
- * equivalent node to "t1 JOIN t2".
  */
 object CostBasedJoinReorder extends Rule[LogicalPlan] with PredicateHelper {
 
@@ -74,30 +40,24 @@ object CostBasedJoinReorder extends Rule[LogicalPlan] with PredicateHelper {
     if (!conf.cboEnabled || !conf.joinReorderEnabled) {
       plan
     } else {
-      // Use a map to track the hints on the join items.
-      val hintMap = new mutable.HashMap[AttributeSet, HintInfo]
       val result = plan transformDown {
         // Start reordering with a joinable item, which is an InnerLike join with conditions.
-        case j @ Join(_, _, _: InnerLike, Some(cond), _) =>
-          reorder(j, j.output, hintMap)
-        case p @ Project(projectList, Join(_, _, _: InnerLike, Some(cond), _))
-          if projectList.forall(_.isInstanceOf[Attribute]) =>
-          reorder(p, p.output, hintMap)
+        // Avoid reordering if a join hint is present.
+        case j @ Join(_, _, _: InnerLike, Some(cond), hint) if hint == JoinHint.NONE =>
+          reorder(j, j.output)
+        case p @ Project(projectList, Join(_, _, _: InnerLike, Some(cond), hint))
+          if projectList.forall(_.isInstanceOf[Attribute]) && hint == JoinHint.NONE =>
+          reorder(p, p.output)
       }
       // After reordering is finished, convert OrderedJoin back to Join.
       result transform {
-        case OrderedJoin(left, right, jt, cond) =>
-          val joinHint = JoinHint(hintMap.get(left.outputSet), hintMap.get(right.outputSet))
-          Join(left, right, jt, cond, joinHint)
+        case OrderedJoin(left, right, jt, cond) => Join(left, right, jt, cond, JoinHint.NONE)
       }
     }
   }
 
-  private def reorder(
-      plan: LogicalPlan,
-      output: Seq[Attribute],
-      hintMap: mutable.HashMap[AttributeSet, HintInfo]): LogicalPlan = {
-    val (items, conditions) = extractInnerJoins(plan, hintMap)
+  private def reorder(plan: LogicalPlan, output: Seq[Attribute]): LogicalPlan = {
+    val (items, conditions) = extractInnerJoins(plan)
     val result =
       // Do reordering if the number of items is appropriate and join conditions exist.
       // We also need to check if costs of all items can be evaluated.
@@ -115,20 +75,16 @@ object CostBasedJoinReorder extends Rule[LogicalPlan] with PredicateHelper {
    * Extracts items of consecutive inner joins and join conditions.
    * This method works for bushy trees and left/right deep trees.
    */
-  private def extractInnerJoins(
-      plan: LogicalPlan,
-      hintMap: mutable.HashMap[AttributeSet, HintInfo]): (Seq[LogicalPlan], Set[Expression]) = {
+  private def extractInnerJoins(plan: LogicalPlan): (Seq[LogicalPlan], Set[Expression]) = {
     plan match {
-      case Join(left, right, _: InnerLike, Some(cond), hint) =>
-        hint.leftHint.foreach(hintMap.put(left.outputSet, _))
-        hint.rightHint.foreach(hintMap.put(right.outputSet, _))
-        val (leftPlans, leftConditions) = extractInnerJoins(left, hintMap)
-        val (rightPlans, rightConditions) = extractInnerJoins(right, hintMap)
+      case Join(left, right, _: InnerLike, Some(cond), _) =>
+        val (leftPlans, leftConditions) = extractInnerJoins(left)
+        val (rightPlans, rightConditions) = extractInnerJoins(right)
         (leftPlans ++ rightPlans, splitConjunctivePredicates(cond).toSet ++
           leftConditions ++ rightConditions)
       case Project(projectList, j @ Join(_, _, _: InnerLike, Some(cond), _))
         if projectList.forall(_.isInstanceOf[Attribute]) =>
-        extractInnerJoins(j, hintMap)
+        extractInnerJoins(j)
      case _ =>
        (Seq(plan), Set())
    }
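
The crux of the change above: CostBasedJoinReorder no longer threads a hintMap through extraction and reconstruction; it simply refuses to touch any join that carries a hint, and every join it rebuilds gets JoinHint.NONE. A minimal sketch of the new guard, stated outside the real rule (the object and method names here are hypothetical; the Catalyst types are the real ones):

import org.apache.spark.sql.catalyst.plans.InnerLike
import org.apache.spark.sql.catalyst.plans.logical.{Join, JoinHint, LogicalPlan}

object ReorderGuardSketch {
  // Mirrors the new pattern guards: a node is a candidate for CBO reordering
  // only if it is an inner-like join that has a condition and an empty hint.
  def isReorderCandidate(plan: LogicalPlan): Boolean = plan match {
    case Join(_, _, _: InnerLike, Some(_), hint) => hint == JoinHint.NONE
    case _ => false
  }
}
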
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/joins.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/joins.scala
index 82aefca8a1af6..251ece315f6a8 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/joins.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/joins.scala
@@ -43,13 +43,11 @@ object ReorderJoin extends Rule[LogicalPlan] with PredicateHelper {
    *
    * @param input a list of LogicalPlans to inner join and the type of inner join.
    * @param conditions a list of condition for join.
-   * @param hintMap a map of relation output attribute sets to their corresponding hints.
    */
   @tailrec
   final def createOrderedJoin(
       input: Seq[(LogicalPlan, InnerLike)],
-      conditions: Seq[Expression],
-      hintMap: Map[AttributeSet, HintInfo]): LogicalPlan = {
+      conditions: Seq[Expression]): LogicalPlan = {
     assert(input.size >= 2)
     if (input.size == 2) {
       val (joinConditions, others) = conditions.partition(canEvaluateWithinJoin)
@@ -58,8 +56,8 @@ object ReorderJoin extends Rule[LogicalPlan] with PredicateHelper {
         case (Inner, Inner) => Inner
         case (_, _) => Cross
       }
-      val join = Join(left, right, innerJoinType, joinConditions.reduceLeftOption(And),
-        JoinHint(hintMap.get(left.outputSet), hintMap.get(right.outputSet)))
+      val join = Join(left, right, innerJoinType,
+        joinConditions.reduceLeftOption(And), JoinHint.NONE)
       if (others.nonEmpty) {
         Filter(others.reduceLeft(And), join)
       } else {
@@ -82,27 +80,27 @@ object ReorderJoin extends Rule[LogicalPlan] with PredicateHelper {
       val joinedRefs = left.outputSet ++ right.outputSet
       val (joinConditions, others) = conditions.partition(
         e => e.references.subsetOf(joinedRefs) && canEvaluateWithinJoin(e))
-      val joined = Join(left, right, innerJoinType, joinConditions.reduceLeftOption(And),
-        JoinHint(hintMap.get(left.outputSet), hintMap.get(right.outputSet)))
+      val joined = Join(left, right, innerJoinType,
+        joinConditions.reduceLeftOption(And), JoinHint.NONE)
 
       // should not have reference to same logical plan
-      createOrderedJoin(Seq((joined, Inner)) ++ rest.filterNot(_._1 eq right), others, hintMap)
+      createOrderedJoin(Seq((joined, Inner)) ++ rest.filterNot(_._1 eq right), others)
     }
   }
 
   def apply(plan: LogicalPlan): LogicalPlan = plan transform {
-    case p @ ExtractFiltersAndInnerJoins(input, conditions, hintMap)
+    case p @ ExtractFiltersAndInnerJoins(input, conditions)
         if input.size > 2 && conditions.nonEmpty =>
       val reordered = if (SQLConf.get.starSchemaDetection && !SQLConf.get.cboEnabled) {
         val starJoinPlan = StarSchemaDetection.reorderStarJoins(input, conditions)
         if (starJoinPlan.nonEmpty) {
           val rest = input.filterNot(starJoinPlan.contains(_))
-          createOrderedJoin(starJoinPlan ++ rest, conditions, hintMap)
+          createOrderedJoin(starJoinPlan ++ rest, conditions)
         } else {
-          createOrderedJoin(input, conditions, hintMap)
+          createOrderedJoin(input, conditions)
         }
       } else {
-        createOrderedJoin(input, conditions, hintMap)
+        createOrderedJoin(input, conditions)
       }
 
       if (p.sameOutput(reordered)) {
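
Because hinted joins now stop extraction upstream, every join that ReorderJoin rebuilds can safely carry JoinHint.NONE. A much-simplified sketch of that construction (the object and method names are made up; the real createOrderedJoin also partitions conditions per pair and distinguishes Inner from Cross):

import org.apache.spark.sql.catalyst.expressions.{And, Expression}
import org.apache.spark.sql.catalyst.plans.Inner
import org.apache.spark.sql.catalyst.plans.logical.{Filter, Join, JoinHint, LogicalPlan}

object OrderedJoinSketch {
  // Chains the inputs into a left-deep tree of hint-free inner joins and
  // leaves all predicates in a single Filter on top; the real rule instead
  // pushes each predicate into the first join that can evaluate it.
  def leftDeepJoin(input: Seq[LogicalPlan], conditions: Seq[Expression]): LogicalPlan = {
    val joined = input.reduceLeft { (left, right) =>
      Join(left, right, Inner, None, JoinHint.NONE)
    }
    conditions.reduceLeftOption(And).map(Filter(_, joined)).getOrElse(joined)
  }
}
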
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala
index 95be0a52cb2ed..a816922f49aee 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala
@@ -166,35 +166,27 @@ object ExtractFiltersAndInnerJoins extends PredicateHelper {
    * was involved in an explicit cross join. Also returns the entire list of join conditions for
    * the left-deep tree.
    */
-  def flattenJoin(
-      plan: LogicalPlan,
-      hintMap: mutable.HashMap[AttributeSet, HintInfo],
-      parentJoinType: InnerLike = Inner)
+  def flattenJoin(plan: LogicalPlan, parentJoinType: InnerLike = Inner)
     : (Seq[(LogicalPlan, InnerLike)], Seq[Expression]) = plan match {
-    case Join(left, right, joinType: InnerLike, cond, hint) =>
-      val (plans, conditions) = flattenJoin(left, hintMap, joinType)
-      hint.leftHint.map(hintMap.put(left.outputSet, _))
-      hint.rightHint.map(hintMap.put(right.outputSet, _))
+    case Join(left, right, joinType: InnerLike, cond, hint) if hint == JoinHint.NONE =>
+      val (plans, conditions) = flattenJoin(left, joinType)
       (plans ++ Seq((right, joinType)),
         conditions ++ cond.toSeq.flatMap(splitConjunctivePredicates))
-    case Filter(filterCondition, j @ Join(_, _, _: InnerLike, _, _)) =>
-      val (plans, conditions) = flattenJoin(j, hintMap)
+    case Filter(filterCondition, j @ Join(_, _, _: InnerLike, _, hint)) if hint == JoinHint.NONE =>
+      val (plans, conditions) = flattenJoin(j)
       (plans, conditions ++ splitConjunctivePredicates(filterCondition))
 
     case _ => (Seq((plan, parentJoinType)), Seq.empty)
   }
 
   def unapply(plan: LogicalPlan)
-    : Option[(Seq[(LogicalPlan, InnerLike)], Seq[Expression], Map[AttributeSet, HintInfo])]
+    : Option[(Seq[(LogicalPlan, InnerLike)], Seq[Expression])]
       = plan match {
-    case f @ Filter(filterCondition, j @ Join(_, _, joinType: InnerLike, _, _)) =>
-      val hintMap = new mutable.HashMap[AttributeSet, HintInfo]
-      val flattened = flattenJoin(f, hintMap)
-      Some((flattened._1, flattened._2, hintMap.toMap))
-    case j @ Join(_, _, joinType, _, _) =>
-      val hintMap = new mutable.HashMap[AttributeSet, HintInfo]
-      val flattened = flattenJoin(j, hintMap)
-      Some((flattened._1, flattened._2, hintMap.toMap))
+    case f @ Filter(filterCondition, j @ Join(_, _, joinType: InnerLike, _, hint))
+        if hint == JoinHint.NONE =>
+      Some(flattenJoin(f))
+    case j @ Join(_, _, joinType, _, hint) if hint == JoinHint.NONE =>
+      Some(flattenJoin(j))
     case _ => None
   }
 }
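
With the extractor's shape changed, call sites match on a pair instead of a triple. A hypothetical call site (the describe function is illustrative only; the extractor is the real one):

import org.apache.spark.sql.catalyst.planning.ExtractFiltersAndInnerJoins
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan

// The extractor now yields only (items, conditions). A hinted join at the
// top fails the match entirely, and a hinted join further down is kept as
// an opaque leaf by flattenJoin instead of being flattened into the list.
def describe(plan: LogicalPlan): String = plan match {
  case ExtractFiltersAndInnerJoins(items, conditions) =>
    s"${items.size} joinable items, ${conditions.size} join conditions"
  case _ => "not a reorderable inner-join tree"
}
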
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinOptimizationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinOptimizationSuite.scala
index 9093d7fecb0f7..c570643c74106 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinOptimizationSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinOptimizationSuite.scala
@@ -66,7 +66,7 @@ class JoinOptimizationSuite extends PlanTest {
     def testExtractCheckCross
         (plan: LogicalPlan, expected: Option[(Seq[(LogicalPlan, InnerLike)], Seq[Expression])]) {
       assert(
-        ExtractFiltersAndInnerJoins.unapply(plan) === expected.map(e => (e._1, e._2, Map.empty)))
+        ExtractFiltersAndInnerJoins.unapply(plan) === expected.map(e => (e._1, e._2)))
     }
 
     testExtract(x, None)
+ test("don't reorder if hints present") { val originalPlan = - t1.join(t2.hint("broadcast")).hint("broadcast").join(t4.join(t3).hint("broadcast")) - .where((nameToAttr("t1.k-1-2") === nameToAttr("t4.k-1-2")) && - (nameToAttr("t1.k-1-2") === nameToAttr("t2.k-1-5")) && - (nameToAttr("t4.v-1-10") === nameToAttr("t3.v-1-100"))) - - val bestPlan = - t1.join(t2.hint("broadcast"), Inner, Some(nameToAttr("t1.k-1-2") === nameToAttr("t2.k-1-5"))) - .hint("broadcast") - .join( - t4.join(t3, Inner, Some(nameToAttr("t4.v-1-10") === nameToAttr("t3.v-1-100"))) - .hint("broadcast"), - Inner, Some(nameToAttr("t1.k-1-2") === nameToAttr("t4.k-1-2"))) + t1.join(t2, Inner, Some(nameToAttr("t1.k-1-2") === nameToAttr("t2.k-1-5"))) + .hint("broadcast") + .join( + t4.join(t3, Inner, Some(nameToAttr("t4.v-1-10") === nameToAttr("t3.v-1-100"))) + .hint("broadcast"), + Inner, + Some(nameToAttr("t1.k-1-2") === nameToAttr("t4.k-1-2"))) - assertEqualPlans(originalPlan, bestPlan) + assertEqualPlans(originalPlan, originalPlan) val originalPlan2 = - t1.join(t2).hint("broadcast").join(t3).hint("broadcast").join(t4.hint("broadcast")) - .where((nameToAttr("t1.k-1-2") === nameToAttr("t4.k-1-2")) && - (nameToAttr("t1.k-1-2") === nameToAttr("t2.k-1-5")) && - (nameToAttr("t4.v-1-10") === nameToAttr("t3.v-1-100"))) - - val bestPlan2 = t1.join(t2, Inner, Some(nameToAttr("t1.k-1-2") === nameToAttr("t2.k-1-5"))) .hint("broadcast") - .join( - t4.hint("broadcast") - .join(t3, Inner, Some(nameToAttr("t4.v-1-10") === nameToAttr("t3.v-1-100"))), - Inner, Some(nameToAttr("t1.k-1-2") === nameToAttr("t4.k-1-2"))) - .select(outputsOf(t1, t2, t3, t4): _*) + .join(t4, Inner, Some(nameToAttr("t4.v-1-10") === nameToAttr("t3.v-1-100"))) + .hint("broadcast") + .join(t3, Inner, Some(nameToAttr("t1.k-1-2") === nameToAttr("t4.k-1-2"))) - assertEqualPlans(originalPlan2, bestPlan2) + assertEqualPlans(originalPlan2, originalPlan2) + } - val originalPlan3 = - t1.join(t4).hint("broadcast") - .join(t2.hint("broadcast")).hint("broadcast") - .join(t3.hint("broadcast")) - .where((nameToAttr("t1.k-1-2") === nameToAttr("t4.k-1-2")) && - (nameToAttr("t1.k-1-2") === nameToAttr("t2.k-1-5")) && - (nameToAttr("t4.v-1-10") === nameToAttr("t3.v-1-100"))) + test("reorder below and above the hint node") { + val originalPlan = + t1.join(t2).join(t3) + .where((nameToAttr("t1.k-1-2") === nameToAttr("t2.k-1-5")) && + (nameToAttr("t1.v-1-10") === nameToAttr("t3.v-1-100"))) + .hint("broadcast").join(t4) - val bestPlan3 = - t1.join(t2.hint("broadcast"), Inner, Some(nameToAttr("t1.k-1-2") === nameToAttr("t2.k-1-5"))) - .join( - t4.join(t3.hint("broadcast"), - Inner, Some(nameToAttr("t4.v-1-10") === nameToAttr("t3.v-1-100"))), - Inner, Some(nameToAttr("t1.k-1-2") === nameToAttr("t4.k-1-2"))) - .select(outputsOf(t1, t4, t2, t3): _*) + val bestPlan = + t1.join(t3, Inner, Some(nameToAttr("t1.v-1-10") === nameToAttr("t3.v-1-100"))) + .join(t2, Inner, Some(nameToAttr("t1.k-1-2") === nameToAttr("t2.k-1-5"))) + .select(outputsOf(t1, t2, t3): _*) + .hint("broadcast").join(t4) - assertEqualPlans(originalPlan3, bestPlan3) + assertEqualPlans(originalPlan, bestPlan) - val originalPlan4 = - t2.hint("broadcast") - .join(t4).hint("broadcast") - .join(t3.hint("broadcast")).hint("broadcast") - .join(t1) - .where((nameToAttr("t1.k-1-2") === nameToAttr("t4.k-1-2")) && - (nameToAttr("t1.k-1-2") === nameToAttr("t2.k-1-5")) && - (nameToAttr("t4.v-1-10") === nameToAttr("t3.v-1-100"))) + val originalPlan2 = + t1.join(t2).join(t3) + .where((nameToAttr("t1.k-1-2") === nameToAttr("t2.k-1-5")) && + 
(nameToAttr("t1.v-1-10") === nameToAttr("t3.v-1-100"))) + .join(t4.hint("broadcast")) - val bestPlan4 = - t1.join(t2.hint("broadcast"), Inner, Some(nameToAttr("t1.k-1-2") === nameToAttr("t2.k-1-5"))) - .join( - t4.join(t3.hint("broadcast"), - Inner, Some(nameToAttr("t4.v-1-10") === nameToAttr("t3.v-1-100"))), - Inner, Some(nameToAttr("t1.k-1-2") === nameToAttr("t4.k-1-2"))) - .select(outputsOf(t2, t4, t3, t1): _*) + val bestPlan2 = + t1.join(t3, Inner, Some(nameToAttr("t1.v-1-10") === nameToAttr("t3.v-1-100"))) + .join(t2, Inner, Some(nameToAttr("t1.k-1-2") === nameToAttr("t2.k-1-5"))) + .select(outputsOf(t1, t2, t3): _*) + .join(t4.hint("broadcast")) - assertEqualPlans(originalPlan4, bestPlan4) + assertEqualPlans(originalPlan2, bestPlan2) } private def assertEqualPlans( diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JoinHintSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JoinHintSuite.scala index 55f210cb04dbf..30a3d54fd833f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/JoinHintSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/JoinHintSuite.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql import org.apache.spark.sql.catalyst.plans.PlanTest import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSQLContext class JoinHintSuite extends PlanTest with SharedSQLContext { @@ -100,7 +101,7 @@ class JoinHintSuite extends PlanTest with SharedSQLContext { } } - test("hint preserved after join reorder") { + test("hints prevent join reorder") { withTempView("a", "b", "c") { df1.createOrReplaceTempView("a") df2.createOrReplaceTempView("b") @@ -118,12 +119,10 @@ class JoinHintSuite extends PlanTest with SharedSQLContext { verifyJoinHint( sql("select /*+ broadcast(a, c)*/ * from a, c, b " + "where a.a1 = b.b1 and b.b1 = c.c1"), - JoinHint( - None, - Some(HintInfo(broadcast = true))) :: + JoinHint.NONE :: JoinHint( Some(HintInfo(broadcast = true)), - None):: Nil + Some(HintInfo(broadcast = true))):: Nil ) verifyJoinHint( sql("select /*+ broadcast(b, c)*/ * from a, c, b " + @@ -199,4 +198,21 @@ class JoinHintSuite extends PlanTest with SharedSQLContext { None) :: Nil ) } + + test("hints prevent cost-based join reorder") { + withSQLConf(SQLConf.CBO_ENABLED.key -> "true", SQLConf.JOIN_REORDER_ENABLED.key -> "true") { + val join = df.join(df, "id") + val broadcasted = join.hint("broadcast") + verifyJoinHint( + join.join(broadcasted, "id").join(broadcasted, "id"), + JoinHint( + None, + Some(HintInfo(broadcast = true))) :: + JoinHint( + None, + Some(HintInfo(broadcast = true))) :: + JoinHint.NONE :: JoinHint.NONE :: JoinHint.NONE :: Nil + ) + } + } }