[SPARK-26078][SQL] Dedup self-join attributes on IN subqueries #23057

Closed · wants to merge 10 commits
@@ -54,10 +54,7 @@ object RewritePredicateSubquery extends Rule[LogicalPlan] with PredicateHelper {
      val aliasMap = AttributeMap(duplicates.map { dup =>
        dup -> Alias(dup, dup.toString)()
      }.toSeq)
-     val aliasedExpressions = right.output.map { ref =>
-       aliasMap.getOrElse(ref, ref)
-     }
-     val newRight = Project(aliasedExpressions, right)
+     val newRight = rewriteDedupPlan(right, aliasMap)
      val newJoinCond = joinCond.map { condExpr =>
Contributor: not related to your PR, but is this correct? The duplicated attributes in the join condition may refer to either the left or the right child; how can we blindly replace them with new attributes from the right side?

Contributor (author): yes, I actually think this is useless. Let me try to remove it.

condExpr transform {
case a: Attribute => aliasMap.getOrElse(a, a).toAttribute
@@ -70,6 +67,27 @@ object RewritePredicateSubquery extends Rule[LogicalPlan] with PredicateHelper {
case _ => joinPlan
}

private def rewriteDedupPlan(plan: LogicalPlan, rewrites: AttributeMap[Alias]): LogicalPlan = {
val aliasedExpressions = plan.output.map { ref =>
rewrites.getOrElse(ref, ref)
}
Project(aliasedExpressions, plan)
}

private def dedupSubqueryOnSelfJoin(values: Seq[Expression], sub: LogicalPlan): LogicalPlan = {
Member: Add a simple code comment for this method?
val leftRefs = AttributeSet.fromAttributeSets(values.map(_.references))
val rightRefs = AttributeSet(sub.output)
Member: This is just outputSet?

val duplicates = leftRefs.intersect(rightRefs)
if (duplicates.isEmpty) {
sub
} else {
val aliasMap = AttributeMap(duplicates.map { dup =>
dup -> Alias(dup, dup.toString)()
}.toSeq)
rewriteDedupPlan(sub, aliasMap)
}
}
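For intuition, here is a minimal sketch of what dedupSubqueryOnSelfJoin does (a hypothetical Python model, not Spark code): Catalyst attributes are identified by a name plus a unique expression id, and aliasing a duplicated attribute keeps the name but assigns a fresh id, so the subquery side of the fake self join no longer shares attributes with the outer IN-values.

```python
import itertools

# Hypothetical model of Catalyst attributes: a name plus a unique expression id.
_ids = itertools.count()

def attr(name):
    return (name, next(_ids))

def dedup_subquery_on_self_join(values_refs, sub_output):
    # Attributes of the subquery output that also appear in the outer IN-values
    # are "fake self join" duplicates; alias them so they get fresh ids.
    duplicates = {a for a in sub_output if a in values_refs}
    if not duplicates:
        return sub_output
    # An Alias keeps the name but produces an attribute with a new id.
    return [attr(name) if (name, i) in duplicates else (name, i)
            for (name, i) in sub_output]

# e.g. b.id IN (SELECT id FROM b ...): both sides reference the same attribute.
b_id = attr("id")
new_output = dedup_subquery_on_self_join({b_id}, [b_id])
assert new_output[0][0] == "id"  # same name
assert new_output[0] != b_id     # distinct attribute after deduplication
```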

def apply(plan: LogicalPlan): LogicalPlan = plan transform {
case Filter(condition, child) =>
val (withSubquery, withoutSubquery) =
@@ -92,18 +110,20 @@ object RewritePredicateSubquery extends Rule[LogicalPlan] with PredicateHelper {
// Deduplicate conflicting attributes if any.
dedupJoin(Join(outerPlan, sub, LeftAnti, joinCond))
Contributor: Looks like we don't need dedupJoin; we can always dedup the subquery before putting it in a join.

Contributor (author): I think it makes sense to dedup the subquery only when the join condition has not been created yet (i.e. in the InSubquery case). Here, the condition is already there, so I think we still have to use dedupJoin (for Exists).

Contributor: dedupJoin will eventually dedup the subquery, IIUC.

What I'd like to do is to unify dedupJoin and dedupSubqueryOnSelfJoin, so that the code is consistent for all cases:

val newSub = dedup(sub, values)
// create join condition if any
Join(outerPlan, newSub, ...)

Contributor (author): the main problem is that in the other cases, i.e. when Exists is involved, the condition is already created. We would need to complicate the method quite a lot to handle the two cases, and I am not sure whether that is worth it. For instance, in the Exists case the values would have to be extracted from the conditions (as the expressions referencing attributes from one side), and the join condition would need to be rewritten. So I don't think a common rewrite for both is a good idea: it would be overcomplicated IMHO.

Contributor: ah, thanks for the explanation! Makes sense to me.

      case (p, InSubquery(values, ListQuery(sub, conditions, _, _))) =>
-       val inConditions = values.zip(sub.output).map(EqualTo.tupled)
+       val newSub = dedupSubqueryOnSelfJoin(values, sub)
+       val inConditions = values.zip(newSub.output).map(EqualTo.tupled)
        val (joinCond, outerPlan) = rewriteExistentialExpr(inConditions ++ conditions, p)
        // Deduplicate conflicting attributes if any.
-       dedupJoin(Join(outerPlan, sub, LeftSemi, joinCond))
+       dedupJoin(Join(outerPlan, newSub, LeftSemi, joinCond))
Contributor: do we still need to dedup here?

Contributor (author): I think we don't; let me remove it.

      case (p, Not(InSubquery(values, ListQuery(sub, conditions, _, _)))) =>
        // This is a NULL-aware (left) anti join (NAAJ) e.g. col NOT IN expr
        // Construct the condition. A NULL in one of the conditions is regarded as a positive
        // result; such a row will be filtered out by the Anti-Join operator.

        // Note that this will almost certainly be planned as a Broadcast Nested Loop join.
        // Use EXISTS if performance matters to you.
-       val inConditions = values.zip(sub.output).map(EqualTo.tupled)
+       val newSub = dedupSubqueryOnSelfJoin(values, sub)
+       val inConditions = values.zip(newSub.output).map(EqualTo.tupled)
val (joinCond, outerPlan) = rewriteExistentialExpr(inConditions, p)
// Expand the NOT IN expression with the NULL-aware semantic
// to its full form. That is from:
@@ -119,7 +139,7 @@ object RewritePredicateSubquery extends Rule[LogicalPlan] with PredicateHelper {
// (A.A1 = B.B1 OR ISNULL(A.A1 = B.B1)) AND (B.B2 = A.A2) AND B.B3 > 1
val finalJoinCond = (nullAwareJoinConds ++ conditions).reduceLeft(And)
        // Deduplicate conflicting attributes if any.
-       dedupJoin(Join(outerPlan, sub, LeftAnti, Option(finalJoinCond)))
+       dedupJoin(Join(outerPlan, newSub, LeftAnti, Option(finalJoinCond)))
case (p, predicate) =>
val (newCond, inputPlan) = rewriteExistentialExpr(Seq(predicate), p)
Project(p.output, Filter(newCond.get, inputPlan))
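The NULL-aware expansion described in the code comments above follows SQL's three-valued logic; the sketch below (a hypothetical Python model, not Spark code) shows why an UNKNOWN comparison counts as a positive match and therefore filters the row out of the anti join.

```python
# None models SQL NULL (UNKNOWN) in three-valued logic.

def eq(a, b):
    return None if a is None or b is None else a == b

def or3(a, b):
    if a is True or b is True:
        return True
    if a is None or b is None:
        return None
    return False

# col NOT IN (subquery) is planned as an anti join whose per-value condition is
#   (col = v) OR ISNULL(col = v)
# so an UNKNOWN comparison is regarded as a positive match and the row is dropped.
def not_in(outer_value, subquery_values):
    matched = any(
        or3(eq(outer_value, v), eq(outer_value, v) is None) is True
        for v in subquery_values
    )
    return not matched  # the anti join keeps only rows with no match

assert not_in(1, [2, 3]) is True      # 1 NOT IN (2, 3)
assert not_in(2, [2, 3]) is False     # 2 NOT IN (2, 3)
assert not_in(1, [2, None]) is False  # a NULL in the subquery filters the row
```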
Member: In rewriteExistentialExpr, there is similar logic for InSubquery. Should we also do dedupSubqueryOnSelfJoin for it?

Contributor (author): mmmh... rewriteExistentialExpr operates on the result of the foldLeft, so every InSubquery there has already been transformed using dedupSubqueryOnSelfJoin, right? So I don't think it is needed.

Member: Can you try this test case?

val df1 = spark.sql(
        """
          |SELECT id,num,source FROM (
          |  SELECT id, num, 'a' as source FROM a
          |  UNION ALL
          |  SELECT id, num, 'b' as source FROM b
          |) AS c WHERE c.id IN (SELECT id FROM b WHERE num = 2) OR
          |c.id IN (SELECT id FROM b WHERE num = 3)
        """.stripMargin)

Contributor (author): this fails indeed. I'll investigate it, thanks.

Contributor (author): thanks for your help here @viirya. I added the check to rewriteExistentialExpr as well; I was missing the case where it is invoked not only on the result of the foldLeft. Thanks.

31 changes: 31 additions & 0 deletions sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala
@@ -22,6 +22,7 @@ import scala.collection.mutable.ArrayBuffer
  import org.apache.spark.sql.catalyst.expressions.SubqueryExpression
  import org.apache.spark.sql.catalyst.plans.logical.{Join, LogicalPlan, Sort}
  import org.apache.spark.sql.test.SharedSQLContext
+ import org.apache.spark.sql.types._

class SubquerySuite extends QueryTest with SharedSQLContext {
import testImplicits._
@@ -1280,4 +1281,34 @@ class SubquerySuite extends QueryTest with SharedSQLContext {
assert(subqueries.length == 1)
}
}

test("SPARK-26078: deduplicate fake self joins for IN subqueries") {
withTempView("a", "b") {
val a = spark.createDataFrame(spark.sparkContext.parallelize(Seq(Row("a", 2), Row("b", 1))),
StructType(Seq(StructField("id", StringType), StructField("num", IntegerType))))
val b = spark.createDataFrame(spark.sparkContext.parallelize(Seq(Row("a", 2), Row("b", 1))),
StructType(Seq(StructField("id", StringType), StructField("num", IntegerType))))
Member: The two schemas are the same; can we define it just once?

a.createOrReplaceTempView("a")
b.createOrReplaceTempView("b")

val df1 = spark.sql(
"""
|SELECT id,num,source FROM (
| SELECT id, num, 'a' as source FROM a
| UNION ALL
| SELECT id, num, 'b' as source FROM b
|) AS c WHERE c.id IN (SELECT id FROM b WHERE num = 2)
""".stripMargin)
checkAnswer(df1, Seq(Row("a", 2, "a"), Row("a", 2, "b")))
val df2 = spark.sql(
"""
|SELECT id,num,source FROM (
| SELECT id, num, 'a' as source FROM a
| UNION ALL
| SELECT id, num, 'b' as source FROM b
|) AS c WHERE c.id NOT IN (SELECT id FROM b WHERE num = 2)
""".stripMargin)
checkAnswer(df2, Seq(Row("b", 1, "a"), Row("b", 1, "b")))
}
}
}