Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SPARK-26078][SQL] Dedup self-join attributes on IN subqueries #23057

Closed
wants to merge 10 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.optimizer

import scala.collection.mutable.ArrayBuffer

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.SubExprUtils._
import org.apache.spark.sql.catalyst.expressions.aggregate._
Expand All @@ -43,31 +43,53 @@ import org.apache.spark.sql.types._
* condition.
*/
object RewritePredicateSubquery extends Rule[LogicalPlan] with PredicateHelper {
private def dedupJoin(joinPlan: LogicalPlan): LogicalPlan = joinPlan match {

private def buildJoin(
outerPlan: LogicalPlan,
subplan: LogicalPlan,
joinType: JoinType,
condition: Option[Expression]): Join = {
// Deduplicate conflicting attributes if any.
val dedupSubplan = dedupSubqueryOnSelfJoin(outerPlan, subplan, None, condition)
Join(outerPlan, dedupSubplan, joinType, condition)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

shall we add an assert to make sure the condition doesn't contain conflicting attributes?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am not sure about this: how do we check it? If the same attribute is present on both sides of a BinaryOperator? Is this always forbidden?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we need to refactor the code a little bit

...
val duplicates = outerRefs.intersect(subplan.outputSet)
condition.foreach {
  case a: Attribute if duplicates.contains(a) => fail
  case _ =>
}
...

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I see what you mean now. I'll do that, thanks.

}

private def dedupSubqueryOnSelfJoin(
outerPlan: LogicalPlan,
subplan: LogicalPlan,
valuesOpt: Option[Seq[Expression]],
condition: Option[Expression] = None): LogicalPlan = {
// SPARK-21835: It is possible that the two sides of the join have conflicting attributes;
// the produced join then becomes unresolved and breaks structural integrity. We should
// de-duplicate conflicting attributes. We don't use a transformation here because we only
// care about the topmost join converted from a correlated predicate subquery.
case j @ Join(left, right, joinType @ (LeftSemi | LeftAnti | ExistenceJoin(_)), joinCond) =>
val duplicates = right.outputSet.intersect(left.outputSet)
if (duplicates.nonEmpty) {
val aliasMap = AttributeMap(duplicates.map { dup =>
dup -> Alias(dup, dup.toString)()
}.toSeq)
val aliasedExpressions = right.output.map { ref =>
aliasMap.getOrElse(ref, ref)
}
val newRight = Project(aliasedExpressions, right)
val newJoinCond = joinCond.map { condExpr =>
condExpr transform {
case a: Attribute => aliasMap.getOrElse(a, a).toAttribute
// de-duplicate conflicting attributes.
// SPARK-26078: it may also happen that the subquery has conflicting attributes with the outer
// values. In this case, the resulting join would contain trivially true conditions (e.g.
// id#3 = id#3) which cannot be de-duplicated afterwards. In this method, if there are conflicting
// attributes in the join condition, the subquery's conflicting attributes are changed using
// a projection which aliases them and resolves the problem.
val outerReferences = valuesOpt.map(values =>
AttributeSet.fromAttributeSets(values.map(_.references))).getOrElse(AttributeSet.empty)
val outerRefs = outerPlan.outputSet ++ outerReferences
val duplicates = outerRefs.intersect(subplan.outputSet)
if (duplicates.nonEmpty) {
condition.foreach { e =>
val conflictingAttrs = e.references.intersect(duplicates)
if (conflictingAttrs.nonEmpty) {
throw new AnalysisException("Found conflicting attributes " +
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just out of curiosity, when can this happen? Or how do we guarantee that this will never happen?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This can happen when the condition is built in advance (e.g. the correlated condition of an EXISTS) and it contains some attribute which has not been deduplicated. I am not sure whether this scenario can actually happen or whether the dedup logic in the previous rules guarantees that it never will, though.

s"${conflictingAttrs.mkString(",")} in the condition joining outer plan:\n " +
s"$outerPlan\nand subplan:\n $subplan")
}
}
Join(left, newRight, joinType, newJoinCond)
} else {
j
}
case _ => joinPlan
val rewrites = AttributeMap(duplicates.map { dup =>
dup -> Alias(dup, dup.toString)()
}.toSeq)
val aliasedExpressions = subplan.output.map { ref =>
rewrites.getOrElse(ref, ref)
}
Project(aliasedExpressions, subplan)
} else {
subplan
}
}

def apply(plan: LogicalPlan): LogicalPlan = plan transform {
Expand All @@ -85,25 +107,27 @@ object RewritePredicateSubquery extends Rule[LogicalPlan] with PredicateHelper {
withSubquery.foldLeft(newFilter) {
case (p, Exists(sub, conditions, _)) =>
val (joinCond, outerPlan) = rewriteExistentialExpr(conditions, p)
// Deduplicate conflicting attributes if any.
dedupJoin(Join(outerPlan, sub, LeftSemi, joinCond))
buildJoin(outerPlan, sub, LeftSemi, joinCond)
case (p, Not(Exists(sub, conditions, _))) =>
val (joinCond, outerPlan) = rewriteExistentialExpr(conditions, p)
// Deduplicate conflicting attributes if any.
dedupJoin(Join(outerPlan, sub, LeftAnti, joinCond))
buildJoin(outerPlan, sub, LeftAnti, joinCond)
case (p, InSubquery(values, ListQuery(sub, conditions, _, _))) =>
val inConditions = values.zip(sub.output).map(EqualTo.tupled)
val (joinCond, outerPlan) = rewriteExistentialExpr(inConditions ++ conditions, p)
// Deduplicate conflicting attributes if any.
dedupJoin(Join(outerPlan, sub, LeftSemi, joinCond))
val newSub = dedupSubqueryOnSelfJoin(p, sub, Some(values))
val inConditions = values.zip(newSub.output).map(EqualTo.tupled)
val (joinCond, outerPlan) = rewriteExistentialExpr(inConditions ++ conditions, p)
Join(outerPlan, newSub, LeftSemi, joinCond)
case (p, Not(InSubquery(values, ListQuery(sub, conditions, _, _)))) =>
// This is a NULL-aware (left) anti join (NAAJ) e.g. col NOT IN expr
// Construct the condition. A NULL in one of the conditions is regarded as a positive
// result; such a row will be filtered out by the Anti-Join operator.

// Note that this will almost certainly be planned as a Broadcast Nested Loop join.
// Use EXISTS if performance matters to you.
val inConditions = values.zip(sub.output).map(EqualTo.tupled)

// Deduplicate conflicting attributes if any.
val newSub = dedupSubqueryOnSelfJoin(p, sub, Some(values))
val inConditions = values.zip(newSub.output).map(EqualTo.tupled)
val (joinCond, outerPlan) = rewriteExistentialExpr(inConditions, p)
// Expand the NOT IN expression with the NULL-aware semantic
// to its full form. That is from:
Expand All @@ -118,8 +142,7 @@ object RewritePredicateSubquery extends Rule[LogicalPlan] with PredicateHelper {
// will have the final conditions in the LEFT ANTI as
// (A.A1 = B.B1 OR ISNULL(A.A1 = B.B1)) AND (B.B2 = A.A2) AND B.B3 > 1
val finalJoinCond = (nullAwareJoinConds ++ conditions).reduceLeft(And)
// Deduplicate conflicting attributes if any.
dedupJoin(Join(outerPlan, sub, LeftAnti, Option(finalJoinCond)))
Join(outerPlan, newSub, LeftAnti, Option(finalJoinCond))
case (p, predicate) =>
val (newCond, inputPlan) = rewriteExistentialExpr(Seq(predicate), p)
Project(p.output, Filter(newCond.get, inputPlan))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In rewriteExistentialExpr, there is a similar logic for InSubquery. Should we also do dedupSubqueryOnSelfJoin for it?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Mmmh... rewriteExistentialExpr operates on the result of the foldLeft, so every InSubquery there was already transformed using dedupSubqueryOnSelfJoin, right? So I don't think it is needed.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you try this test case?

val df1 = spark.sql(
        """
          |SELECT id,num,source FROM (
          |  SELECT id, num, 'a' as source FROM a
          |  UNION ALL
          |  SELECT id, num, 'b' as source FROM b
          |) AS c WHERE c.id IN (SELECT id FROM b WHERE num = 2) OR
          |c.id IN (SELECT id FROM b WHERE num = 3)
        """.stripMargin)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this fails indeed. I'll investigate it, thanks.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

thanks for your help here @viirya. I added the check also to rewriteExistentialExpr. I was missing the case when it is invoked not only on the result of foldLeft. Thanks.

Expand All @@ -140,16 +163,16 @@ object RewritePredicateSubquery extends Rule[LogicalPlan] with PredicateHelper {
e transformUp {
case Exists(sub, conditions, _) =>
val exists = AttributeReference("exists", BooleanType, nullable = false)()
// Deduplicate conflicting attributes if any.
newPlan = dedupJoin(
Join(newPlan, sub, ExistenceJoin(exists), conditions.reduceLeftOption(And)))
newPlan =
buildJoin(newPlan, sub, ExistenceJoin(exists), conditions.reduceLeftOption(And))
exists
case InSubquery(values, ListQuery(sub, conditions, _, _)) =>
val exists = AttributeReference("exists", BooleanType, nullable = false)()
val inConditions = values.zip(sub.output).map(EqualTo.tupled)
val newConditions = (inConditions ++ conditions).reduceLeftOption(And)
// Deduplicate conflicting attributes if any.
newPlan = dedupJoin(Join(newPlan, sub, ExistenceJoin(exists), newConditions))
val newSub = dedupSubqueryOnSelfJoin(newPlan, sub, Some(values))
val inConditions = values.zip(newSub.output).map(EqualTo.tupled)
val newConditions = (inConditions ++ conditions).reduceLeftOption(And)
newPlan = Join(newPlan, newSub, ExistenceJoin(exists), newConditions)
exists
}
}
Expand Down
37 changes: 37 additions & 0 deletions sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import scala.collection.mutable.ArrayBuffer
import org.apache.spark.sql.catalyst.expressions.SubqueryExpression
import org.apache.spark.sql.catalyst.plans.logical.{Join, LogicalPlan, Sort}
import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.sql.types._

class SubquerySuite extends QueryTest with SharedSQLContext {
import testImplicits._
Expand Down Expand Up @@ -1280,4 +1281,40 @@ class SubquerySuite extends QueryTest with SharedSQLContext {
assert(subqueries.length == 1)
}
}

test("SPARK-26078: deduplicate fake self joins for IN subqueries") {
  withTempView("a", "b") {
    // Views `a` and `b` hold the same two rows. Every query below unions `a` and `b` on the
    // outer side and reads `b` again inside the IN subquery.
    val sampleRows = Seq("a" -> 2, "b" -> 1)
    sampleRows.toDF("id", "num").createTempView("a")
    sampleRows.toDF("id", "num").createTempView("b")

    // Expected rows of the union, grouped by id; the third column is the branch of the
    // UNION ALL ('a' or 'b') that produced the row.
    val rowsWithIdA = Seq(Row("a", 2, "a"), Row("a", 2, "b"))
    val rowsWithIdB = Seq(Row("b", 1, "a"), Row("b", 1, "b"))

    // IN: the subquery selects the id whose num is 2, i.e. "a".
    val inQuery =
      """
        |SELECT id,num,source FROM (
        | SELECT id, num, 'a' as source FROM a
        | UNION ALL
        | SELECT id, num, 'b' as source FROM b
        |) AS c WHERE c.id IN (SELECT id FROM b WHERE num = 2)
        |""".stripMargin
    checkAnswer(spark.sql(inQuery), rowsWithIdA)

    // NOT IN: the complement of the previous case, i.e. the rows with id "b".
    val notInQuery =
      """
        |SELECT id,num,source FROM (
        | SELECT id, num, 'a' as source FROM a
        | UNION ALL
        | SELECT id, num, 'b' as source FROM b
        |) AS c WHERE c.id NOT IN (SELECT id FROM b WHERE num = 2)
        |""".stripMargin
    checkAnswer(spark.sql(notInQuery), rowsWithIdB)

    // Two IN subqueries combined with OR: no row of `b` has num = 3, so the result is the
    // same as for the single IN above.
    val disjunctiveInQuery =
      """
        |SELECT id,num,source FROM (
        | SELECT id, num, 'a' as source FROM a
        | UNION ALL
        | SELECT id, num, 'b' as source FROM b
        |) AS c WHERE c.id IN (SELECT id FROM b WHERE num = 2) OR
        |c.id IN (SELECT id FROM b WHERE num = 3)
        |""".stripMargin
    checkAnswer(spark.sql(disjunctiveInQuery), rowsWithIdA)
  }
}
}