[SPARK-26078][SQL] Dedup self-join attributes on subqueries
mgaido91 committed Nov 16, 2018
1 parent 4ac8f9b commit 2af656a
Showing 2 changed files with 59 additions and 8 deletions.
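Background for the diff below: RewritePredicateSubquery turns IN and NOT IN predicates into LeftSemi/LeftAnti joins whose condition equates the outer values with the subquery's output. When the subquery scans a relation that also feeds the outer plan, both sides of that generated join can carry attributes with identical ExprIds, so the join condition degenerates into comparing a column with itself (e.g. a condition like id#5 = id#5, using Spark's attribute notation for illustration). The patch re-aliases the duplicated subquery attributes before the join is built. A minimal query shape that triggers the problem, distilled from the regression test added below (the snippet is illustrative, not part of the commit; table names follow the test):

// `b` feeds both the outer UNION and the IN subquery, so the rewritten
// semi join would otherwise see b's `id` attribute on both of its sides.
spark.sql(
  """
    |SELECT id FROM (SELECT id FROM a UNION ALL SELECT id FROM b) AS c
    |WHERE c.id IN (SELECT id FROM b WHERE num = 2)
  """.stripMargin)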
28 changes: 28 additions & 8 deletions sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala
@@ -54,10 +54,7 @@ object RewritePredicateSubquery extends Rule[LogicalPlan] with PredicateHelper {
         val aliasMap = AttributeMap(duplicates.map { dup =>
           dup -> Alias(dup, dup.toString)()
         }.toSeq)
-        val aliasedExpressions = right.output.map { ref =>
-          aliasMap.getOrElse(ref, ref)
-        }
-        val newRight = Project(aliasedExpressions, right)
+        val newRight = rewriteDedupPlan(right, aliasMap)
         val newJoinCond = joinCond.map { condExpr =>
           condExpr transform {
             case a: Attribute => aliasMap.getOrElse(a, a).toAttribute
@@ -70,6 +67,27 @@ object RewritePredicateSubquery extends Rule[LogicalPlan] with PredicateHelper {
     case _ => joinPlan
   }

+  private def rewriteDedupPlan(plan: LogicalPlan, rewrites: AttributeMap[Alias]): LogicalPlan = {
+    val aliasedExpressions = plan.output.map { ref =>
+      rewrites.getOrElse(ref, ref)
+    }
+    Project(aliasedExpressions, plan)
+  }
+
+  private def dedupSubqueryOnSelfJoin(values: Seq[Expression], sub: LogicalPlan): LogicalPlan = {
+    val leftRefs = AttributeSet.fromAttributeSets(values.map(_.references))
+    val rightRefs = AttributeSet(sub.output)
+    val duplicates = leftRefs.intersect(rightRefs)
+    if (duplicates.isEmpty) {
+      sub
+    } else {
+      val aliasMap = AttributeMap(duplicates.map { dup =>
+        dup -> Alias(dup, dup.toString)()
+      }.toSeq)
+      rewriteDedupPlan(sub, aliasMap)
+    }
+  }
+
   def apply(plan: LogicalPlan): LogicalPlan = plan transform {
     case Filter(condition, child) =>
       val (withSubquery, withoutSubquery) =
@@ -92,18 +110,20 @@ object RewritePredicateSubquery extends Rule[LogicalPlan] with PredicateHelper {
           // Deduplicate conflicting attributes if any.
           dedupJoin(Join(outerPlan, sub, LeftAnti, joinCond))
         case (p, InSubquery(values, ListQuery(sub, conditions, _, _))) =>
-          val inConditions = values.zip(sub.output).map(EqualTo.tupled)
+          val newSub = dedupSubqueryOnSelfJoin(values, sub)
+          val inConditions = values.zip(newSub.output).map(EqualTo.tupled)
           val (joinCond, outerPlan) = rewriteExistentialExpr(inConditions ++ conditions, p)
           // Deduplicate conflicting attributes if any.
-          dedupJoin(Join(outerPlan, sub, LeftSemi, joinCond))
+          dedupJoin(Join(outerPlan, newSub, LeftSemi, joinCond))
         case (p, Not(InSubquery(values, ListQuery(sub, conditions, _, _)))) =>
           // This is a NULL-aware (left) anti join (NAAJ) e.g. col NOT IN expr
           // Construct the condition. A NULL in one of the conditions is regarded as a positive
           // result; such a row will be filtered out by the Anti-Join operator.

           // Note that will almost certainly be planned as a Broadcast Nested Loop join.
           // Use EXISTS if performance matters to you.
-          val inConditions = values.zip(sub.output).map(EqualTo.tupled)
+          val newSub = dedupSubqueryOnSelfJoin(values, sub)
+          val inConditions = values.zip(newSub.output).map(EqualTo.tupled)
           val (joinCond, outerPlan) = rewriteExistentialExpr(inConditions, p)
           // Expand the NOT IN expression with the NULL-aware semantic
           // to its full form. That is from:
@@ -119,7 +139,7 @@ object RewritePredicateSubquery extends Rule[LogicalPlan] with PredicateHelper {
           //   (A.A1 = B.B1 OR ISNULL(A.A1 = B.B1)) AND (B.B2 = A.A2) AND B.B3 > 1
           val finalJoinCond = (nullAwareJoinConds ++ conditions).reduceLeft(And)
           // Deduplicate conflicting attributes if any.
-          dedupJoin(Join(outerPlan, sub, LeftAnti, Option(finalJoinCond)))
+          dedupJoin(Join(outerPlan, newSub, LeftAnti, Option(finalJoinCond)))
         case (p, predicate) =>
           val (newCond, inputPlan) = rewriteExistentialExpr(Seq(predicate), p)
           Project(p.output, Filter(newCond.get, inputPlan))
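The helper pair above applies the same trick dedupJoin already uses for conflicting join attributes: wrapping the plan in a Project of Alias expressions mints fresh ExprIds for the duplicated columns while keeping their names, so the two join sides become distinguishable. A self-contained toy model of that idea (plain Scala, no Catalyst types; every name here is illustrative):

// Toy model: attributes are (name, id) pairs; a join condition that refers
// to the same id on both sides is ambiguous. Re-keying one side with a
// fresh id mirrors what rewriteDedupPlan does via Alias(dup, dup.toString)().
object DedupSketch {
  final case class Attr(name: String, id: Long)

  private var nextId = 100L
  private def freshId(): Long = { nextId += 1; nextId }

  // Re-alias every subquery output that also appears among the outer
  // references, keeping the name but assigning a new id.
  def dedup(outerRefs: Set[Attr], subOutput: Seq[Attr]): Seq[Attr] =
    subOutput.map { a => if (outerRefs.contains(a)) a.copy(id = freshId()) else a }

  def main(args: Array[String]): Unit = {
    val id = Attr("id", 5L)
    val outer = Set(id)                 // the outer plan also reads id#5
    println(dedup(outer, Seq(id)))      // List(Attr(id,101)) -- no longer id#5
  }
}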
31 changes: 31 additions & 0 deletions sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala
@@ -21,6 +21,7 @@ import scala.collection.mutable.ArrayBuffer

 import org.apache.spark.sql.catalyst.expressions.SubqueryExpression
 import org.apache.spark.sql.catalyst.plans.logical.{Join, LogicalPlan, Sort}
+import org.apache.spark.sql.types._
 import org.apache.spark.sql.test.SharedSQLContext

 class SubquerySuite extends QueryTest with SharedSQLContext {
@@ -1280,4 +1281,34 @@ class SubquerySuite extends QueryTest with SharedSQLContext {
       assert(subqueries.length == 1)
     }
   }
+
+  test("SPARK-26078: deduplicate fake self joins for IN subqueries") {
+    withTempView("a", "b") {
+      val a = spark.createDataFrame(spark.sparkContext.parallelize(Seq(Row("a", 2), Row("b", 1))),
+        StructType(Seq(StructField("id", StringType), StructField("num", IntegerType))))
+      val b = spark.createDataFrame(spark.sparkContext.parallelize(Seq(Row("a", 2), Row("b", 1))),
+        StructType(Seq(StructField("id", StringType), StructField("num", IntegerType))))
+      a.createOrReplaceTempView("a")
+      b.createOrReplaceTempView("b")
+
+      val df1 = spark.sql(
+        """
+          |SELECT id,num,source FROM (
+          |  SELECT id, num, 'a' as source FROM a
+          |  UNION ALL
+          |  SELECT id, num, 'b' as source FROM b
+          |) AS c WHERE c.id IN (SELECT id FROM b WHERE num = 2)
+        """.stripMargin)
+      checkAnswer(df1, Seq(Row("a", 2, "a"), Row("a", 2, "b")))
+      val df2 = spark.sql(
+        """
+          |SELECT id,num,source FROM (
+          |  SELECT id, num, 'a' as source FROM a
+          |  UNION ALL
+          |  SELECT id, num, 'b' as source FROM b
+          |) AS c WHERE c.id NOT IN (SELECT id FROM b WHERE num = 2)
+        """.stripMargin)
+      checkAnswer(df2, Seq(Row("b", 1, "a"), Row("b", 1, "b")))
+    }
+  }
 }
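As a sanity check, one can print the optimized plan for a query of this shape; after the fix, the LeftSemi join's right side should be a Project introducing fresh aliases for the duplicated attributes (a sketch assuming the temp views from the test above; exact plan text varies across Spark versions):

// Inspect the optimized plan to confirm the subquery side was re-aliased.
val df = spark.sql("SELECT * FROM b WHERE id IN (SELECT id FROM b WHERE num = 2)")
println(df.queryExecution.optimizedPlan)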
