Skip to content

Commit

Permalink
[SPARK-26065][FOLLOW-UP][SQL] Revert hint behavior in join reordering
Browse files Browse the repository at this point in the history
## What changes were proposed in this pull request?

This fixes a bug in #23036 that could cause a join hint to be applied, after join reordering, to a node it was not meant for. For example,
```
  val join = df.join(df, "id")
  val broadcasted = join.hint("broadcast")
  val join2 = join.join(broadcasted, "id").join(broadcasted, "id")
```
There should be only 2 broadcast hints on `join2`, but after join reordering there would be 4. This is because the hint application in join reordering compares attribute sets to test whether two relations are equivalent.
Moreover, it could still be problematic even if the child relations themselves were used to test relation equivalence, due to potential exprId conflicts in nested self-joins.

As a result, this PR simply reverts the join-reorder hint behavior change introduced in #23036. This means that if a join hint is present on a join node, that node itself will not participate in join reordering, while the sub-joins within its children still can.

## How was this patch tested?

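Added new tests to `JoinReorderSuite` and `JoinHintSuite`, and updated the existing hint tests to expect that hinted joins are not reordered.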

Closes #23524 from maryannxue/query-hint-followup-2.

Authored-by: maryannxue <maryannxue@apache.org>
Signed-off-by: gatorsmile <gatorsmile@gmail.com>
  • Loading branch information
maryannxue authored and gatorsmile committed Jan 13, 2019
1 parent 09b0548 commit 985f966
Show file tree
Hide file tree
Showing 6 changed files with 94 additions and 153 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -31,40 +31,6 @@ import org.apache.spark.sql.internal.SQLConf
* Cost-based join reorder.
* We may have several join reorder algorithms in the future. This class is the entry of these
* algorithms, and chooses which one to use.
*
* Note that join strategy hints, e.g. the broadcast hint, do not interfere with the reordering.
* Such hints will be applied on the equivalent counterparts (i.e., join between the same relations
* regardless of the join order) of the original nodes after reordering.
* For example, the plan before reordering is like:
*
* Join
* / \
* Hint1 t4
* /
* Join
* / \
* Join t3
* / \
* Hint2 t2
* /
* t1
*
* The original join order as illustrated above is "((t1 JOIN t2) JOIN t3) JOIN t4", and after
* reordering, the new join order is "((t1 JOIN t3) JOIN t2) JOIN t4", so the new plan will be like:
*
* Join
* / \
* Hint1 t4
* /
* Join
* / \
* Join t2
* / \
* t1 t3
*
* "Hint1" is applied on "(t1 JOIN t3) JOIN t2" as it is equivalent to the original hinted node,
* "(t1 JOIN t2) JOIN t3"; while "Hint2" has disappeared from the new plan since there is no
* equivalent node to "t1 JOIN t2".
*/
object CostBasedJoinReorder extends Rule[LogicalPlan] with PredicateHelper {

Expand All @@ -74,30 +40,24 @@ object CostBasedJoinReorder extends Rule[LogicalPlan] with PredicateHelper {
if (!conf.cboEnabled || !conf.joinReorderEnabled) {
plan
} else {
// Use a map to track the hints on the join items.
val hintMap = new mutable.HashMap[AttributeSet, HintInfo]
val result = plan transformDown {
// Start reordering with a joinable item, which is an InnerLike join with conditions.
case j @ Join(_, _, _: InnerLike, Some(cond), _) =>
reorder(j, j.output, hintMap)
case p @ Project(projectList, Join(_, _, _: InnerLike, Some(cond), _))
if projectList.forall(_.isInstanceOf[Attribute]) =>
reorder(p, p.output, hintMap)
// Avoid reordering if a join hint is present.
case j @ Join(_, _, _: InnerLike, Some(cond), hint) if hint == JoinHint.NONE =>
reorder(j, j.output)
case p @ Project(projectList, Join(_, _, _: InnerLike, Some(cond), hint))
if projectList.forall(_.isInstanceOf[Attribute]) && hint == JoinHint.NONE =>
reorder(p, p.output)
}
// After reordering is finished, convert OrderedJoin back to Join.
result transform {
case OrderedJoin(left, right, jt, cond) =>
val joinHint = JoinHint(hintMap.get(left.outputSet), hintMap.get(right.outputSet))
Join(left, right, jt, cond, joinHint)
case OrderedJoin(left, right, jt, cond) => Join(left, right, jt, cond, JoinHint.NONE)
}
}
}

private def reorder(
plan: LogicalPlan,
output: Seq[Attribute],
hintMap: mutable.HashMap[AttributeSet, HintInfo]): LogicalPlan = {
val (items, conditions) = extractInnerJoins(plan, hintMap)
private def reorder(plan: LogicalPlan, output: Seq[Attribute]): LogicalPlan = {
val (items, conditions) = extractInnerJoins(plan)
val result =
// Do reordering if the number of items is appropriate and join conditions exist.
// We also need to check if costs of all items can be evaluated.
Expand All @@ -115,20 +75,16 @@ object CostBasedJoinReorder extends Rule[LogicalPlan] with PredicateHelper {
* Extracts items of consecutive inner joins and join conditions.
* This method works for bushy trees and left/right deep trees.
*/
private def extractInnerJoins(
plan: LogicalPlan,
hintMap: mutable.HashMap[AttributeSet, HintInfo]): (Seq[LogicalPlan], Set[Expression]) = {
private def extractInnerJoins(plan: LogicalPlan): (Seq[LogicalPlan], Set[Expression]) = {
plan match {
case Join(left, right, _: InnerLike, Some(cond), hint) =>
hint.leftHint.foreach(hintMap.put(left.outputSet, _))
hint.rightHint.foreach(hintMap.put(right.outputSet, _))
val (leftPlans, leftConditions) = extractInnerJoins(left, hintMap)
val (rightPlans, rightConditions) = extractInnerJoins(right, hintMap)
case Join(left, right, _: InnerLike, Some(cond), _) =>
val (leftPlans, leftConditions) = extractInnerJoins(left)
val (rightPlans, rightConditions) = extractInnerJoins(right)
(leftPlans ++ rightPlans, splitConjunctivePredicates(cond).toSet ++
leftConditions ++ rightConditions)
case Project(projectList, j @ Join(_, _, _: InnerLike, Some(cond), _))
if projectList.forall(_.isInstanceOf[Attribute]) =>
extractInnerJoins(j, hintMap)
extractInnerJoins(j)
case _ =>
(Seq(plan), Set())
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,13 +43,11 @@ object ReorderJoin extends Rule[LogicalPlan] with PredicateHelper {
*
* @param input a list of LogicalPlans to inner join and the type of inner join.
* @param conditions a list of conditions for the join.
* @param hintMap a map of relation output attribute sets to their corresponding hints.
*/
@tailrec
final def createOrderedJoin(
input: Seq[(LogicalPlan, InnerLike)],
conditions: Seq[Expression],
hintMap: Map[AttributeSet, HintInfo]): LogicalPlan = {
conditions: Seq[Expression]): LogicalPlan = {
assert(input.size >= 2)
if (input.size == 2) {
val (joinConditions, others) = conditions.partition(canEvaluateWithinJoin)
Expand All @@ -58,8 +56,8 @@ object ReorderJoin extends Rule[LogicalPlan] with PredicateHelper {
case (Inner, Inner) => Inner
case (_, _) => Cross
}
val join = Join(left, right, innerJoinType, joinConditions.reduceLeftOption(And),
JoinHint(hintMap.get(left.outputSet), hintMap.get(right.outputSet)))
val join = Join(left, right, innerJoinType,
joinConditions.reduceLeftOption(And), JoinHint.NONE)
if (others.nonEmpty) {
Filter(others.reduceLeft(And), join)
} else {
Expand All @@ -82,27 +80,27 @@ object ReorderJoin extends Rule[LogicalPlan] with PredicateHelper {
val joinedRefs = left.outputSet ++ right.outputSet
val (joinConditions, others) = conditions.partition(
e => e.references.subsetOf(joinedRefs) && canEvaluateWithinJoin(e))
val joined = Join(left, right, innerJoinType, joinConditions.reduceLeftOption(And),
JoinHint(hintMap.get(left.outputSet), hintMap.get(right.outputSet)))
val joined = Join(left, right, innerJoinType,
joinConditions.reduceLeftOption(And), JoinHint.NONE)

// should not have reference to same logical plan
createOrderedJoin(Seq((joined, Inner)) ++ rest.filterNot(_._1 eq right), others, hintMap)
createOrderedJoin(Seq((joined, Inner)) ++ rest.filterNot(_._1 eq right), others)
}
}

def apply(plan: LogicalPlan): LogicalPlan = plan transform {
case p @ ExtractFiltersAndInnerJoins(input, conditions, hintMap)
case p @ ExtractFiltersAndInnerJoins(input, conditions)
if input.size > 2 && conditions.nonEmpty =>
val reordered = if (SQLConf.get.starSchemaDetection && !SQLConf.get.cboEnabled) {
val starJoinPlan = StarSchemaDetection.reorderStarJoins(input, conditions)
if (starJoinPlan.nonEmpty) {
val rest = input.filterNot(starJoinPlan.contains(_))
createOrderedJoin(starJoinPlan ++ rest, conditions, hintMap)
createOrderedJoin(starJoinPlan ++ rest, conditions)
} else {
createOrderedJoin(input, conditions, hintMap)
createOrderedJoin(input, conditions)
}
} else {
createOrderedJoin(input, conditions, hintMap)
createOrderedJoin(input, conditions)
}

if (p.sameOutput(reordered)) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -166,35 +166,27 @@ object ExtractFiltersAndInnerJoins extends PredicateHelper {
* was involved in an explicit cross join. Also returns the entire list of join conditions for
* the left-deep tree.
*/
def flattenJoin(
plan: LogicalPlan,
hintMap: mutable.HashMap[AttributeSet, HintInfo],
parentJoinType: InnerLike = Inner)
def flattenJoin(plan: LogicalPlan, parentJoinType: InnerLike = Inner)
: (Seq[(LogicalPlan, InnerLike)], Seq[Expression]) = plan match {
case Join(left, right, joinType: InnerLike, cond, hint) =>
val (plans, conditions) = flattenJoin(left, hintMap, joinType)
hint.leftHint.map(hintMap.put(left.outputSet, _))
hint.rightHint.map(hintMap.put(right.outputSet, _))
case Join(left, right, joinType: InnerLike, cond, hint) if hint == JoinHint.NONE =>
val (plans, conditions) = flattenJoin(left, joinType)
(plans ++ Seq((right, joinType)), conditions ++
cond.toSeq.flatMap(splitConjunctivePredicates))
case Filter(filterCondition, j @ Join(_, _, _: InnerLike, _, _)) =>
val (plans, conditions) = flattenJoin(j, hintMap)
case Filter(filterCondition, j @ Join(_, _, _: InnerLike, _, hint)) if hint == JoinHint.NONE =>
val (plans, conditions) = flattenJoin(j)
(plans, conditions ++ splitConjunctivePredicates(filterCondition))

case _ => (Seq((plan, parentJoinType)), Seq.empty)
}

def unapply(plan: LogicalPlan)
: Option[(Seq[(LogicalPlan, InnerLike)], Seq[Expression], Map[AttributeSet, HintInfo])]
: Option[(Seq[(LogicalPlan, InnerLike)], Seq[Expression])]
= plan match {
case f @ Filter(filterCondition, j @ Join(_, _, joinType: InnerLike, _, _)) =>
val hintMap = new mutable.HashMap[AttributeSet, HintInfo]
val flattened = flattenJoin(f, hintMap)
Some((flattened._1, flattened._2, hintMap.toMap))
case j @ Join(_, _, joinType, _, _) =>
val hintMap = new mutable.HashMap[AttributeSet, HintInfo]
val flattened = flattenJoin(j, hintMap)
Some((flattened._1, flattened._2, hintMap.toMap))
case f @ Filter(filterCondition, j @ Join(_, _, joinType: InnerLike, _, hint))
if hint == JoinHint.NONE =>
Some(flattenJoin(f))
case j @ Join(_, _, joinType, _, hint) if hint == JoinHint.NONE =>
Some(flattenJoin(j))
case _ => None
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ class JoinOptimizationSuite extends PlanTest {
def testExtractCheckCross
(plan: LogicalPlan, expected: Option[(Seq[(LogicalPlan, InnerLike)], Seq[Expression])]) {
assert(
ExtractFiltersAndInnerJoins.unapply(plan) === expected.map(e => (e._1, e._2, Map.empty)))
ExtractFiltersAndInnerJoins.unapply(plan) === expected.map(e => (e._1, e._2)))
}

testExtract(x, None)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -292,77 +292,56 @@ class JoinReorderSuite extends PlanTest with StatsEstimationTestBase {
assertEqualPlans(originalPlan, bestPlan)
}

test("hints preservation") {
// Apply hints if we find an equivalent node in the new plan, otherwise discard them.
test("don't reorder if hints present") {
val originalPlan =
t1.join(t2.hint("broadcast")).hint("broadcast").join(t4.join(t3).hint("broadcast"))
.where((nameToAttr("t1.k-1-2") === nameToAttr("t4.k-1-2")) &&
(nameToAttr("t1.k-1-2") === nameToAttr("t2.k-1-5")) &&
(nameToAttr("t4.v-1-10") === nameToAttr("t3.v-1-100")))

val bestPlan =
t1.join(t2.hint("broadcast"), Inner, Some(nameToAttr("t1.k-1-2") === nameToAttr("t2.k-1-5")))
.hint("broadcast")
.join(
t4.join(t3, Inner, Some(nameToAttr("t4.v-1-10") === nameToAttr("t3.v-1-100")))
.hint("broadcast"),
Inner, Some(nameToAttr("t1.k-1-2") === nameToAttr("t4.k-1-2")))
t1.join(t2, Inner, Some(nameToAttr("t1.k-1-2") === nameToAttr("t2.k-1-5")))
.hint("broadcast")
.join(
t4.join(t3, Inner, Some(nameToAttr("t4.v-1-10") === nameToAttr("t3.v-1-100")))
.hint("broadcast"),
Inner,
Some(nameToAttr("t1.k-1-2") === nameToAttr("t4.k-1-2")))

assertEqualPlans(originalPlan, bestPlan)
assertEqualPlans(originalPlan, originalPlan)

val originalPlan2 =
t1.join(t2).hint("broadcast").join(t3).hint("broadcast").join(t4.hint("broadcast"))
.where((nameToAttr("t1.k-1-2") === nameToAttr("t4.k-1-2")) &&
(nameToAttr("t1.k-1-2") === nameToAttr("t2.k-1-5")) &&
(nameToAttr("t4.v-1-10") === nameToAttr("t3.v-1-100")))

val bestPlan2 =
t1.join(t2, Inner, Some(nameToAttr("t1.k-1-2") === nameToAttr("t2.k-1-5")))
.hint("broadcast")
.join(
t4.hint("broadcast")
.join(t3, Inner, Some(nameToAttr("t4.v-1-10") === nameToAttr("t3.v-1-100"))),
Inner, Some(nameToAttr("t1.k-1-2") === nameToAttr("t4.k-1-2")))
.select(outputsOf(t1, t2, t3, t4): _*)
.join(t4, Inner, Some(nameToAttr("t4.v-1-10") === nameToAttr("t3.v-1-100")))
.hint("broadcast")
.join(t3, Inner, Some(nameToAttr("t1.k-1-2") === nameToAttr("t4.k-1-2")))

assertEqualPlans(originalPlan2, bestPlan2)
assertEqualPlans(originalPlan2, originalPlan2)
}

val originalPlan3 =
t1.join(t4).hint("broadcast")
.join(t2.hint("broadcast")).hint("broadcast")
.join(t3.hint("broadcast"))
.where((nameToAttr("t1.k-1-2") === nameToAttr("t4.k-1-2")) &&
(nameToAttr("t1.k-1-2") === nameToAttr("t2.k-1-5")) &&
(nameToAttr("t4.v-1-10") === nameToAttr("t3.v-1-100")))
test("reorder below and above the hint node") {
val originalPlan =
t1.join(t2).join(t3)
.where((nameToAttr("t1.k-1-2") === nameToAttr("t2.k-1-5")) &&
(nameToAttr("t1.v-1-10") === nameToAttr("t3.v-1-100")))
.hint("broadcast").join(t4)

val bestPlan3 =
t1.join(t2.hint("broadcast"), Inner, Some(nameToAttr("t1.k-1-2") === nameToAttr("t2.k-1-5")))
.join(
t4.join(t3.hint("broadcast"),
Inner, Some(nameToAttr("t4.v-1-10") === nameToAttr("t3.v-1-100"))),
Inner, Some(nameToAttr("t1.k-1-2") === nameToAttr("t4.k-1-2")))
.select(outputsOf(t1, t4, t2, t3): _*)
val bestPlan =
t1.join(t3, Inner, Some(nameToAttr("t1.v-1-10") === nameToAttr("t3.v-1-100")))
.join(t2, Inner, Some(nameToAttr("t1.k-1-2") === nameToAttr("t2.k-1-5")))
.select(outputsOf(t1, t2, t3): _*)
.hint("broadcast").join(t4)

assertEqualPlans(originalPlan3, bestPlan3)
assertEqualPlans(originalPlan, bestPlan)

val originalPlan4 =
t2.hint("broadcast")
.join(t4).hint("broadcast")
.join(t3.hint("broadcast")).hint("broadcast")
.join(t1)
.where((nameToAttr("t1.k-1-2") === nameToAttr("t4.k-1-2")) &&
(nameToAttr("t1.k-1-2") === nameToAttr("t2.k-1-5")) &&
(nameToAttr("t4.v-1-10") === nameToAttr("t3.v-1-100")))
val originalPlan2 =
t1.join(t2).join(t3)
.where((nameToAttr("t1.k-1-2") === nameToAttr("t2.k-1-5")) &&
(nameToAttr("t1.v-1-10") === nameToAttr("t3.v-1-100")))
.join(t4.hint("broadcast"))

val bestPlan4 =
t1.join(t2.hint("broadcast"), Inner, Some(nameToAttr("t1.k-1-2") === nameToAttr("t2.k-1-5")))
.join(
t4.join(t3.hint("broadcast"),
Inner, Some(nameToAttr("t4.v-1-10") === nameToAttr("t3.v-1-100"))),
Inner, Some(nameToAttr("t1.k-1-2") === nameToAttr("t4.k-1-2")))
.select(outputsOf(t2, t4, t3, t1): _*)
val bestPlan2 =
t1.join(t3, Inner, Some(nameToAttr("t1.v-1-10") === nameToAttr("t3.v-1-100")))
.join(t2, Inner, Some(nameToAttr("t1.k-1-2") === nameToAttr("t2.k-1-5")))
.select(outputsOf(t1, t2, t3): _*)
.join(t4.hint("broadcast"))

assertEqualPlans(originalPlan4, bestPlan4)
assertEqualPlans(originalPlan2, bestPlan2)
}

private def assertEqualPlans(
Expand Down
26 changes: 21 additions & 5 deletions sql/core/src/test/scala/org/apache/spark/sql/JoinHintSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ package org.apache.spark.sql

import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.test.SharedSQLContext

class JoinHintSuite extends PlanTest with SharedSQLContext {
Expand Down Expand Up @@ -100,7 +101,7 @@ class JoinHintSuite extends PlanTest with SharedSQLContext {
}
}

test("hint preserved after join reorder") {
test("hints prevent join reorder") {
withTempView("a", "b", "c") {
df1.createOrReplaceTempView("a")
df2.createOrReplaceTempView("b")
Expand All @@ -118,12 +119,10 @@ class JoinHintSuite extends PlanTest with SharedSQLContext {
verifyJoinHint(
sql("select /*+ broadcast(a, c)*/ * from a, c, b " +
"where a.a1 = b.b1 and b.b1 = c.c1"),
JoinHint(
None,
Some(HintInfo(broadcast = true))) ::
JoinHint.NONE ::
JoinHint(
Some(HintInfo(broadcast = true)),
None):: Nil
Some(HintInfo(broadcast = true))):: Nil
)
verifyJoinHint(
sql("select /*+ broadcast(b, c)*/ * from a, c, b " +
Expand Down Expand Up @@ -199,4 +198,21 @@ class JoinHintSuite extends PlanTest with SharedSQLContext {
None) :: Nil
)
}

test("hints prevent cost-based join reorder") {
withSQLConf(SQLConf.CBO_ENABLED.key -> "true", SQLConf.JOIN_REORDER_ENABLED.key -> "true") {
val join = df.join(df, "id")
val broadcasted = join.hint("broadcast")
verifyJoinHint(
join.join(broadcasted, "id").join(broadcasted, "id"),
JoinHint(
None,
Some(HintInfo(broadcast = true))) ::
JoinHint(
None,
Some(HintInfo(broadcast = true))) ::
JoinHint.NONE :: JoinHint.NONE :: JoinHint.NONE :: Nil
)
}
}
}

0 comments on commit 985f966

Please sign in to comment.