Skip to content

Commit

Permalink
[SPARK-26078][SQL][BACKPORT-2.4] Dedup self-join attributes on IN sub…
Browse files Browse the repository at this point in the history
…queries

## What changes were proposed in this pull request?

When there is a self-join as result of a IN subquery, the join condition may be invalid, resulting in trivially true predicates and return wrong results.

The PR deduplicates the subquery output in order to avoid the issue.

## How was this patch tested?

added UT

Closes apache#23449 from mgaido91/SPARK-26078_2.4.

Authored-by: Marco Gaido <marcogaido91@gmail.com>
Signed-off-by: Dongjoon Hyun <dongjoon@apache.org>
  • Loading branch information
mgaido91 authored and kai-chi committed Aug 1, 2019
1 parent 5c06e00 commit 1cc9e2a
Show file tree
Hide file tree
Showing 2 changed files with 97 additions and 38 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.optimizer

import scala.collection.mutable.ArrayBuffer

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.SubExprUtils._
import org.apache.spark.sql.catalyst.expressions.aggregate._
Expand All @@ -43,31 +43,53 @@ import org.apache.spark.sql.types._
* condition.
*/
object RewritePredicateSubquery extends Rule[LogicalPlan] with PredicateHelper {
private def dedupJoin(joinPlan: LogicalPlan): LogicalPlan = joinPlan match {

private def buildJoin(
outerPlan: LogicalPlan,
subplan: LogicalPlan,
joinType: JoinType,
condition: Option[Expression]): Join = {
// Deduplicate conflicting attributes if any.
val dedupSubplan = dedupSubqueryOnSelfJoin(outerPlan, subplan, None, condition)
Join(outerPlan, dedupSubplan, joinType, condition)
}

private def dedupSubqueryOnSelfJoin(
outerPlan: LogicalPlan,
subplan: LogicalPlan,
valuesOpt: Option[Seq[Expression]],
condition: Option[Expression] = None): LogicalPlan = {
// SPARK-21835: It is possibly that the two sides of the join have conflicting attributes,
// the produced join then becomes unresolved and break structural integrity. We should
// de-duplicate conflicting attributes. We don't use transformation here because we only
// care about the most top join converted from correlated predicate subquery.
case j @ Join(left, right, joinType @ (LeftSemi | LeftAnti | ExistenceJoin(_)), joinCond) =>
val duplicates = right.outputSet.intersect(left.outputSet)
if (duplicates.nonEmpty) {
val aliasMap = AttributeMap(duplicates.map { dup =>
dup -> Alias(dup, dup.toString)()
}.toSeq)
val aliasedExpressions = right.output.map { ref =>
aliasMap.getOrElse(ref, ref)
}
val newRight = Project(aliasedExpressions, right)
val newJoinCond = joinCond.map { condExpr =>
condExpr transform {
case a: Attribute => aliasMap.getOrElse(a, a).toAttribute
// de-duplicate conflicting attributes.
// SPARK-26078: it may also happen that the subquery has conflicting attributes with the outer
// values. In this case, the resulting join would contain trivially true conditions (eg.
// id#3 = id#3) which cannot be de-duplicated after. In this method, if there are conflicting
// attributes in the join condition, the subquery's conflicting attributes are changed using
// a projection which aliases them and resolves the problem.
val outerReferences = valuesOpt.map(values =>
AttributeSet(values.flatMap(_.references))).getOrElse(AttributeSet.empty)
val outerRefs = outerPlan.outputSet ++ outerReferences
val duplicates = outerRefs.intersect(subplan.outputSet)
if (duplicates.nonEmpty) {
condition.foreach { e =>
val conflictingAttrs = e.references.intersect(duplicates)
if (conflictingAttrs.nonEmpty) {
throw new AnalysisException("Found conflicting attributes " +
s"${conflictingAttrs.mkString(",")} in the condition joining outer plan:\n " +
s"$outerPlan\nand subplan:\n $subplan")
}
}
Join(left, newRight, joinType, newJoinCond)
} else {
j
}
case _ => joinPlan
val rewrites = AttributeMap(duplicates.map { dup =>
dup -> Alias(dup, dup.toString)()
}.toSeq)
val aliasedExpressions = subplan.output.map { ref =>
rewrites.getOrElse(ref, ref)
}
Project(aliasedExpressions, subplan)
} else {
subplan
}
}

def apply(plan: LogicalPlan): LogicalPlan = plan transform {
Expand All @@ -85,25 +107,27 @@ object RewritePredicateSubquery extends Rule[LogicalPlan] with PredicateHelper {
withSubquery.foldLeft(newFilter) {
case (p, Exists(sub, conditions, _)) =>
val (joinCond, outerPlan) = rewriteExistentialExpr(conditions, p)
// Deduplicate conflicting attributes if any.
dedupJoin(Join(outerPlan, sub, LeftSemi, joinCond))
buildJoin(outerPlan, sub, LeftSemi, joinCond)
case (p, Not(Exists(sub, conditions, _))) =>
val (joinCond, outerPlan) = rewriteExistentialExpr(conditions, p)
// Deduplicate conflicting attributes if any.
dedupJoin(Join(outerPlan, sub, LeftAnti, joinCond))
buildJoin(outerPlan, sub, LeftAnti, joinCond)
case (p, InSubquery(values, ListQuery(sub, conditions, _, _))) =>
val inConditions = values.zip(sub.output).map(EqualTo.tupled)
val (joinCond, outerPlan) = rewriteExistentialExpr(inConditions ++ conditions, p)
// Deduplicate conflicting attributes if any.
dedupJoin(Join(outerPlan, sub, LeftSemi, joinCond))
val newSub = dedupSubqueryOnSelfJoin(p, sub, Some(values))
val inConditions = values.zip(newSub.output).map(EqualTo.tupled)
val (joinCond, outerPlan) = rewriteExistentialExpr(inConditions ++ conditions, p)
Join(outerPlan, newSub, LeftSemi, joinCond)
case (p, Not(InSubquery(values, ListQuery(sub, conditions, _, _)))) =>
// This is a NULL-aware (left) anti join (NAAJ) e.g. col NOT IN expr
// Construct the condition. A NULL in one of the conditions is regarded as a positive
// result; such a row will be filtered out by the Anti-Join operator.

// Note that will almost certainly be planned as a Broadcast Nested Loop join.
// Use EXISTS if performance matters to you.
val inConditions = values.zip(sub.output).map(EqualTo.tupled)

// Deduplicate conflicting attributes if any.
val newSub = dedupSubqueryOnSelfJoin(p, sub, Some(values))
val inConditions = values.zip(newSub.output).map(EqualTo.tupled)
val (joinCond, outerPlan) = rewriteExistentialExpr(inConditions, p)
// Expand the NOT IN expression with the NULL-aware semantic
// to its full form. That is from:
Expand All @@ -118,8 +142,7 @@ object RewritePredicateSubquery extends Rule[LogicalPlan] with PredicateHelper {
// will have the final conditions in the LEFT ANTI as
// (A.A1 = B.B1 OR ISNULL(A.A1 = B.B1)) AND (B.B2 = A.A2) AND B.B3 > 1
val finalJoinCond = (nullAwareJoinConds ++ conditions).reduceLeft(And)
// Deduplicate conflicting attributes if any.
dedupJoin(Join(outerPlan, sub, LeftAnti, Option(finalJoinCond)))
Join(outerPlan, newSub, LeftAnti, Option(finalJoinCond))
case (p, predicate) =>
val (newCond, inputPlan) = rewriteExistentialExpr(Seq(predicate), p)
Project(p.output, Filter(newCond.get, inputPlan))
Expand All @@ -140,16 +163,16 @@ object RewritePredicateSubquery extends Rule[LogicalPlan] with PredicateHelper {
e transformUp {
case Exists(sub, conditions, _) =>
val exists = AttributeReference("exists", BooleanType, nullable = false)()
// Deduplicate conflicting attributes if any.
newPlan = dedupJoin(
Join(newPlan, sub, ExistenceJoin(exists), conditions.reduceLeftOption(And)))
newPlan =
buildJoin(newPlan, sub, ExistenceJoin(exists), conditions.reduceLeftOption(And))
exists
case InSubquery(values, ListQuery(sub, conditions, _, _)) =>
val exists = AttributeReference("exists", BooleanType, nullable = false)()
val inConditions = values.zip(sub.output).map(EqualTo.tupled)
val newConditions = (inConditions ++ conditions).reduceLeftOption(And)
// Deduplicate conflicting attributes if any.
newPlan = dedupJoin(Join(newPlan, sub, ExistenceJoin(exists), newConditions))
val newSub = dedupSubqueryOnSelfJoin(newPlan, sub, Some(values))
val inConditions = values.zip(newSub.output).map(EqualTo.tupled)
val newConditions = (inConditions ++ conditions).reduceLeftOption(And)
newPlan = Join(newPlan, newSub, ExistenceJoin(exists), newConditions)
exists
}
}
Expand Down
36 changes: 36 additions & 0 deletions sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1268,4 +1268,40 @@ class SubquerySuite extends QueryTest with SharedSQLContext {
assert(getNumSortsInQuery(query5) == 1)
}
}

test("SPARK-26078: deduplicate fake self joins for IN subqueries") {
withTempView("a", "b") {
Seq("a" -> 2, "b" -> 1).toDF("id", "num").createTempView("a")
Seq("a" -> 2, "b" -> 1).toDF("id", "num").createTempView("b")

val df1 = spark.sql(
"""
|SELECT id,num,source FROM (
| SELECT id, num, 'a' as source FROM a
| UNION ALL
| SELECT id, num, 'b' as source FROM b
|) AS c WHERE c.id IN (SELECT id FROM b WHERE num = 2)
""".stripMargin)
checkAnswer(df1, Seq(Row("a", 2, "a"), Row("a", 2, "b")))
val df2 = spark.sql(
"""
|SELECT id,num,source FROM (
| SELECT id, num, 'a' as source FROM a
| UNION ALL
| SELECT id, num, 'b' as source FROM b
|) AS c WHERE c.id NOT IN (SELECT id FROM b WHERE num = 2)
""".stripMargin)
checkAnswer(df2, Seq(Row("b", 1, "a"), Row("b", 1, "b")))
val df3 = spark.sql(
"""
|SELECT id,num,source FROM (
| SELECT id, num, 'a' as source FROM a
| UNION ALL
| SELECT id, num, 'b' as source FROM b
|) AS c WHERE c.id IN (SELECT id FROM b WHERE num = 2) OR
|c.id IN (SELECT id FROM b WHERE num = 3)
""".stripMargin)
checkAnswer(df3, Seq(Row("a", 2, "a"), Row("a", 2, "b")))
}
}
}

0 comments on commit 1cc9e2a

Please sign in to comment.