[SPARK-33798][SQL] Add new rule to push down the foldable expressions through CaseWhen/If #30790

Closed
wants to merge 12 commits into from

Member

@wangyum wangyum commented Dec 16, 2020

What changes were proposed in this pull request?

This PR adds a new rule (PushFoldableIntoBranches) that pushes foldable expressions down through CaseWhen/If. This is a real case from production:

create table t1 using parquet as select * from range(100);
create table t2 using parquet as select * from range(200);

create temp view v1 as
select 'a' as event_type, * from t1
union all
select CASE WHEN id = 1 THEN 'b' WHEN id = 3 THEN 'c' end as event_type, * from t2;

explain select * from v1 where event_type = 'a';

Before this PR:

== Physical Plan ==
Union
:- *(1) Project [a AS event_type#30533, id#30535L]
:  +- *(1) ColumnarToRow
:     +- FileScan parquet default.t1[id#30535L] Batched: true, DataFilters: [], Format: Parquet
+- *(2) Project [CASE WHEN (id#30536L = 1) THEN b WHEN (id#30536L = 3) THEN c END AS event_type#30534, id#30536L]
   +- *(2) Filter (CASE WHEN (id#30536L = 1) THEN b WHEN (id#30536L = 3) THEN c END = a)
      +- *(2) ColumnarToRow
         +- FileScan parquet default.t2[id#30536L] Batched: true, DataFilters: [(CASE WHEN (id#30536L = 1) THEN b WHEN (id#30536L = 3) THEN c END = a)], Format: Parquet

After this PR:

== Physical Plan ==
*(1) Project [a AS event_type#8, id#4L]
+- *(1) ColumnarToRow
   +- FileScan parquet default.t1[id#4L] Batched: true, DataFilters: [], Format: Parquet
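
The effect on the example above can be sketched with a toy expression tree. This is a minimal illustration, not the rule added by this PR: `Expr`, `Lit`, `Col`, `Eq`, `BoolLit`, `CaseWhen`, `fold`, and `pushEq` are hypothetical stand-ins for Catalyst classes, and only string literals are modeled.

```scala
// Toy model only: these types are hypothetical stand-ins, not Spark's Catalyst classes.
object PushdownSketch {
  sealed trait Expr
  case class Lit(value: Option[String]) extends Expr       // None models SQL NULL
  case class Col(name: String) extends Expr
  case class Eq(left: Expr, right: Expr) extends Expr
  case class BoolLit(value: Option[Boolean]) extends Expr  // None models SQL NULL
  case class CaseWhen(branches: Seq[(Expr, Expr)], elseValue: Option[Expr]) extends Expr

  // Constant-fold an equality between two literals using SQL semantics.
  def fold(e: Expr): Expr = e match {
    case Eq(Lit(Some(l)), Lit(Some(r))) => BoolLit(Some(l == r))
    case Eq(Lit(_), Lit(_))             => BoolLit(None) // NULL = x evaluates to NULL
    case other                          => other
  }

  // Push `caseWhen = literal` into every branch value, then fold each branch.
  def pushEq(e: Expr): Expr = e match {
    case Eq(CaseWhen(branches, elseValue), r: Lit) =>
      CaseWhen(
        branches.map { case (cond, value) => (cond, fold(Eq(value, r))) },
        elseValue.map(v => fold(Eq(v, r))))
    case other => other
  }

  // CASE WHEN id = 1 THEN 'b' WHEN id = 3 THEN 'c' END = 'a'
  val example: Expr = Eq(
    CaseWhen(Seq(
      (Eq(Col("id"), Lit(Some("1"))), Lit(Some("b"))),
      (Eq(Col("id"), Lit(Some("3"))), Lit(Some("c")))), None),
    Lit(Some("a")))
}
```

For the `event_type = 'a'` filter, `PushdownSketch.pushEq(PushdownSketch.example)` turns both branch values into false literals. A missing ELSE yields NULL, which a filter also rejects, so no row from t2 can pass. That is consistent with the t2 side of the union disappearing in the plan after this PR.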

Why are the changes needed?

Improve query performance.

Does this PR introduce any user-facing change?

No.

How was this patch tested?

Unit test.

Comment on lines 276 to 281
assertEquivalent(
EqualTo(CaseWhen(Seq((a, Literal(null, IntegerType))), Some(Literal(1))), Literal(2)),
FalseLiteral)
assertEquivalent(
EqualTo(CaseWhen(Seq((a, Literal(1))), Some(Literal(2))), Literal(null, IntegerType)),
FalseLiteral)
Member Author

SPARK-33798 will change these behaviors as expected.

Comment on lines 283 to 293
assertEquivalent(
EqualTo(CaseWhen(Seq((a, Literal(null, IntegerType))), Some(Literal(1))), Literal(1)),
EqualTo(CaseWhen(Seq((a, Literal(null, IntegerType))), Some(Literal(1))), Literal(1)))
assertEquivalent(
EqualTo(CaseWhen(Seq((a, Literal(null, IntegerType))), Some(Literal(null, IntegerType))),
Literal(1)),
Literal(null, BooleanType))
assertEquivalent(
EqualTo(CaseWhen(Seq((a, Literal(null, IntegerType))), Some(Literal(null, IntegerType))),
Literal(null, IntegerType)),
Literal(null, BooleanType))
Member Author

SPARK-33798 will not change these behaviors.

@github-actions github-actions bot added the SQL label Dec 16, 2020

SparkQA commented Dec 16, 2020

Test build #132846 has finished for PR 30790 at commit 19b0a83.

  • This patch fails Spark unit tests.
  • This patch merges cleanly.
  • This patch adds no public classes.

FalseLiteral

case EqualTo(c @ CaseWhen(branches, elseValue), right: Literal) if c.deterministic &&
(branches.map(_._2) ++ elseValue).forall(_.isInstanceOf[Literal]) &&
Contributor

We can do the "is literal" check inside isAlwaysFalse to shorten the code.

assertEquivalent(EqualTo(nonFoldable, Literal(1)), EqualTo(nonFoldable, Literal(1)))

// Do not simplify if it contains non-deterministic expressions.
val nonDeterministic = If(LessThan(Rand(1), Literal(0.5)), Literal(1), Literal(1))
Contributor

This is a case that we should optimize. But we need to distinguish between non-deterministic and has-side-effect. While we can skip running non-deterministic expressions, we can't skip running has-side-effect expressions.

This is not related to this PR, but is something worth considering. cc @viirya @dbtsai @maropu @rednaxelafx


// null check, SPARK-33798 will not change these behaviors.
assertEquivalent(
EqualTo(If(FalseLiteral, Literal(null, IntegerType), Literal(1)), Literal(1)),
Contributor

The if condition is a literal. Does it really go through the new branch, or will it be optimized earlier by some other logic?

Contributor

I'd prefer to still use a === Literal(1) as the if condition, and make sure the expression is not optimized if there are null values.


SparkQA commented Dec 16, 2020

Kubernetes integration test starting
URL: https://amplab.cs.berkeley.edu/jenkins/job/SparkPullRequestBuilder-K8s/37448/


SparkQA commented Dec 16, 2020

Kubernetes integration test status failure
URL: https://amplab.cs.berkeley.edu/jenkins/job/SparkPullRequestBuilder-K8s/37448/


SparkQA commented Dec 16, 2020

Kubernetes integration test starting
URL: https://amplab.cs.berkeley.edu/jenkins/job/SparkPullRequestBuilder-K8s/37463/


SparkQA commented Dec 16, 2020

Kubernetes integration test status success
URL: https://amplab.cs.berkeley.edu/jenkins/job/SparkPullRequestBuilder-K8s/37463/

@@ -470,6 +470,11 @@ object SimplifyConditionals extends Rule[LogicalPlan] with PredicateHelper {
case _ => false
}

private def isAlwaysFalse(exps: Seq[Expression], equalTo: Literal): Boolean = {
exps.forall(_.isInstanceOf[Literal]) &&
exps.forall(!EqualTo(_, equalTo).eval(EmptyRow).asInstanceOf[Boolean])
Contributor

@cloud-fan cloud-fan Dec 16, 2020

EqualTo may return null. We need to take care of that, because null is not false in SQL.

Member Author

Yes. Null value should be handled by NullPropagation.

Member

@dongjoon-hyun dongjoon-hyun left a comment

Thank you for making a new one, @wangyum .

@@ -523,6 +529,17 @@ object SimplifyConditionals extends Rule[LogicalPlan] with PredicateHelper {
} else {
e.copy(branches = branches.take(i).map(branch => (branch._1, elseValue)))
}

// Null value should be handled by NullPropagation.
Contributor

It's better to keep optimizer rules orthogonal so they don't rely on each other. Let's make isAlwaysFalse stricter:

exps.forall { e =>
  val res = EqualTo(e, equalTo).eval(EmptyRow)
  res != null && !res.asInstanceOf[Boolean]
}

Member Author

How about:

exps.forall {
  case l: Literal =>
    val res = EqualTo(l, equalTo).eval(EmptyRow)
    res != null && !res.asInstanceOf[Boolean]
  case _ => false
}

otherwise this test will fail:

val nonFoldable = CaseWhen(Seq(normalBranch, (a, b)), None)
assertEquivalent(EqualTo(nonFoldable, Literal(1)), EqualTo(nonFoldable, Literal(1)))
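
The null-handling point in this thread can be sketched outside Catalyst. This is a hedged toy version, not the code in the PR: `Expr`, `IntLit`, `Col`, and `sqlEq` are hypothetical names, and SQL NULL is modeled as `None`. It shows why a NULL comparison result and a non-literal branch must both block the "always false" simplification.

```scala
// Toy model only: hypothetical types, with None standing in for SQL NULL.
object AlwaysFalseSketch {
  sealed trait Expr
  case class IntLit(value: Option[Int]) extends Expr
  case class Col(name: String) extends Expr

  // SQL equality on literals: NULL on either side gives NULL (None), not false.
  def sqlEq(l: IntLit, r: IntLit): Option[Boolean] =
    for (a <- l.value; b <- r.value) yield a == b

  // True only if every branch value is a literal whose comparison is definitely false.
  def isAlwaysFalse(exprs: Seq[Expr], equalTo: IntLit): Boolean =
    exprs.forall {
      case l: IntLit => sqlEq(l, equalTo).contains(false)
      case _         => false // non-literal branch: the result is unknown at plan time
    }
}
```

With this shape, branch values `1` and `2` compared against `3` are always false, while a NULL branch value or a column reference keeps the original expression untouched.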


SparkQA commented Dec 16, 2020

Kubernetes integration test starting
URL: https://amplab.cs.berkeley.edu/jenkins/job/SparkPullRequestBuilder-K8s/37479/


SparkQA commented Dec 16, 2020

Test build #132861 has finished for PR 30790 at commit 831bf66.

  • This patch passes all tests.
  • This patch merges cleanly.
  • This patch adds no public classes.


SparkQA commented Dec 16, 2020

Kubernetes integration test status success
URL: https://amplab.cs.berkeley.edu/jenkins/job/SparkPullRequestBuilder-K8s/37479/


SparkQA commented Dec 16, 2020

Kubernetes integration test starting
URL: https://amplab.cs.berkeley.edu/jenkins/job/SparkPullRequestBuilder-K8s/37485/


SparkQA commented Dec 16, 2020

Kubernetes integration test status success
URL: https://amplab.cs.berkeley.edu/jenkins/job/SparkPullRequestBuilder-K8s/37485/


SparkQA commented Dec 16, 2020

Test build #132878 has finished for PR 30790 at commit 859893d.

  • This patch passes all tests.
  • This patch merges cleanly.
  • This patch adds no public classes.

test("SPARK-33798: simplify EqualTo(If, Literal) always false") {
val a = EqualTo(UnresolvedAttribute("a"), Literal(100))
val b = UnresolvedAttribute("b")
val ifExp = If(a === Literal(1), Literal(2), Literal(3))
Contributor

Let's make sure the test expression is valid. a is an EqualTo, comparing boolean and Literal(1) is invalid.


assertEquivalent(EqualTo(ifExp, Literal(4)), FalseLiteral)
assertEquivalent(EqualTo(ifExp, Literal(3)), EqualTo(ifExp, Literal(3)))
assertEquivalent(EqualTo(ifExp, Literal("4")), FalseLiteral)
Contributor

Ditto, let's not test invalid expressions: ifExp.dataType is int, not string.

assertEquivalent(
EqualTo(If(a, b, Literal(2)), Literal(3)),
EqualTo(If(a, b, Literal(2)), Literal(3)))
val nonFoldable = If(NonFoldableLiteral(true), Literal(1), Literal(2))
Contributor

NonFoldableLiteral is nothing special compared to 'a == 100, as they are both non-foldable. I think we don't need to test with NonFoldableLiteral.

* Push the foldable expression into (if / case) branches.
*/
object PushFoldableIntoBranches extends Rule[LogicalPlan] with PredicateHelper {
def apply(plan: LogicalPlan): LogicalPlan = plan transform {
Member

nit: can't we use transformAllExpressions here?

Member Author

This copied from

Member

Ah, I see. I think it's okay to simply use transformAllExpressions here.


SparkQA commented Dec 17, 2020

Kubernetes integration test starting
URL: https://amplab.cs.berkeley.edu/jenkins/job/SparkPullRequestBuilder-K8s/37542/


SparkQA commented Dec 17, 2020

Kubernetes integration test status success
URL: https://amplab.cs.berkeley.edu/jenkins/job/SparkPullRequestBuilder-K8s/37542/


SparkQA commented Dec 17, 2020

Test build #132935 has finished for PR 30790 at commit f9f622f.

  • This patch passes all tests.
  • This patch merges cleanly.
  • This patch adds no public classes.


SparkQA commented Dec 17, 2020

Test build #132939 has finished for PR 30790 at commit 8ccc3c1.

  • This patch fails Spark unit tests.
  • This patch merges cleanly.
  • This patch adds no public classes.

falseValue = b.makeCopy(Array(falseValue, right)))

case b @ BinaryComparison(c @ CaseWhen(branches, elseValue), right) if c.deterministic &&
right.foldable && (branches.map(_._2) ++ elseValue).forall(_.foldable) =>
Contributor

some more ideas:

  1. I think it's helpful to push foldables down into branches as long as there is at least one foldable branch. To be conservative, we can do the pushdown only if the IF/CASE WHEN has at most one non-foldable branch, so that we never increase the expression tree size or the generated code size.
  2. I think it's also useful to push down expressions like Add, e.g. IF(cond, 1, 2) + 1 -> IF(cond, 2, 3). We can use BinaryExpression instead of BinaryComparison.
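
The second idea, `IF(cond, 1, 2) + 1 -> IF(cond, 2, 3)`, can be sketched on a toy tree. This is only an illustration under assumptions: `Expr`, `IntLit`, `Col`, `Add`, `If`, `foldAdd`, and `pushAdd` are hypothetical names, not Catalyst's `BinaryExpression` machinery.

```scala
// Toy model only: hypothetical types, not Spark's Catalyst classes.
object AddPushdownSketch {
  sealed trait Expr
  case class IntLit(value: Int) extends Expr
  case class Col(name: String) extends Expr
  case class Add(left: Expr, right: Expr) extends Expr
  case class If(cond: Expr, trueValue: Expr, falseValue: Expr) extends Expr

  // Constant-fold an addition of two literals; leave anything else alone.
  def foldAdd(e: Expr): Expr = e match {
    case Add(IntLit(a), IntLit(b)) => IntLit(a + b)
    case other                     => other
  }

  // Push a literal operand into both branches of an If, on either side of the Add.
  def pushAdd(e: Expr): Expr = e match {
    case Add(If(c, t, f), r: IntLit) => If(c, foldAdd(Add(t, r)), foldAdd(Add(f, r)))
    case Add(l: IntLit, If(c, t, f)) => If(c, foldAdd(Add(l, t)), foldAdd(Add(l, f)))
    case other                       => other
  }
}
```

Applying `pushAdd` to `Add(If(Col("cond"), IntLit(1), IntLit(2)), IntLit(1))` gives `If(Col("cond"), IntLit(2), IntLit(3))`: the addition is gone from the per-row work because both branches were foldable.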

/**
* Push the foldable expression into (if / case) branches.
*/
object PushFoldableIntoBranches extends Rule[LogicalPlan] with PredicateHelper {
Member

Nice idea.

def apply(plan: LogicalPlan): LogicalPlan = plan transform {
case q: LogicalPlan => q transformExpressionsUp {
case b @ BinaryExpression(i @ If(_, trueValue, falseValue), right) if i.deterministic &&
right.foldable && (trueValue.foldable || falseValue.foldable) =>
Member

Sorry if I missed something but can you explain why we need this foldable condition?

Shouldn't bin_op((if cond: a else: b), c) be same as if cond: bin_op(a, c) else: bin_op(b, c) as long as the results are deterministic?

Member

Oh, okay, so we only want to push down foldables, since that is where the rewrite can bring a benefit. Got it.
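
The equivalence asked about here can be checked with plain Scala values. This is a small sketch with hypothetical names (`original`, `pushed`); it only covers deterministic, side-effect-free operators on integers.

```scala
// Sketch only: plain Scala functions standing in for SQL expressions.
object DistributeSketch {
  // bin_op((if cond: a else: b), c)
  def original(cond: Boolean, a: Int, b: Int, c: Int)(op: (Int, Int) => Int): Int =
    op(if (cond) a else b, c)

  // if cond: bin_op(a, c) else: bin_op(b, c)
  def pushed(cond: Boolean, a: Int, b: Int, c: Int)(op: (Int, Int) => Int): Int =
    if (cond) op(a, c) else op(b, c)
}
```

Both forms return the same value for either condition. The rewritten form only pays off when `op(a, c)` or `op(b, c)` has constant operands and can be computed once at planning time, which is why the rule restricts itself to foldable operands.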


SparkQA commented Dec 18, 2020

Kubernetes integration test starting
URL: https://amplab.cs.berkeley.edu/jenkins/job/SparkPullRequestBuilder-K8s/37583/


SparkQA commented Dec 18, 2020

Kubernetes integration test status success
URL: https://amplab.cs.berkeley.edu/jenkins/job/SparkPullRequestBuilder-K8s/37583/

trueValue = b.makeCopy(Array(trueValue, right)),
falseValue = b.makeCopy(Array(falseValue, right)))

case b @ BinaryExpression(left, i @ If(_, trueValue, falseValue)) if i.deterministic &&
Contributor

Do we care about deterministic here? This optimization doesn't change the execution order of the branch conditions and/or branch values.

falseValue = b.makeCopy(Array(falseValue, right)))

case b @ BinaryExpression(left, i @ If(_, trueValue, falseValue)) if i.deterministic &&
left.foldable && (trueValue.foldable || falseValue.foldable) =>
Contributor

@cloud-fan cloud-fan Dec 18, 2020

let's create a method to count foldables, with comments to explain why

// Be conservative here: it's only a guaranteed win if at most one branch
// ends up being non-foldable.
private def atMostOneUnfoldable(exprs: Seq[Expression]): Boolean = ...

then here left.foldable && atMostOneUnfoldable(Seq(trueValue, falseValue))

Member Author

@wangyum wangyum Dec 18, 2020

Do we need to push down this case, where there is only one branch and it is non-foldable?

assertEquivalent(EqualTo(CaseWhen(Seq((a, b)), None), Literal(1)), CaseWhen(Seq((a, EqualTo(b, Literal(1)))), None))

Member Author

@wangyum wangyum Dec 18, 2020

How about?

  // Be conservative here: it's only a guaranteed win if at most one branch
  // ends up being non-foldable.
  private def atMostOneUnfoldable(exprs: Seq[Expression]): Boolean = {
    val (foldables, others) = exprs.partition(_.foldable)
    foldables.nonEmpty && others.length < 2
  }
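
How this guard treats the single non-foldable branch case raised above can be tried on a toy model. This is a hedged sketch: `Expr`, `Lit`, and `Col` are hypothetical types that only carry a `foldable` flag, and the helper body mirrors the proposal in this comment.

```scala
// Toy model only: hypothetical types carrying just a foldable flag.
object FoldableGuardSketch {
  sealed trait Expr { def foldable: Boolean }
  case class Lit(value: Int) extends Expr { val foldable = true }
  case class Col(name: String) extends Expr { val foldable = false }

  // Require at least one foldable branch and allow at most one non-foldable branch.
  def atMostOneUnfoldable(exprs: Seq[Expr]): Boolean = {
    val (foldables, others) = exprs.partition(_.foldable)
    foldables.nonEmpty && others.length < 2
  }
}
```

Under this guard, a CaseWhen whose only branch value is the column `b` is not pushed down, because there is no foldable branch to gain from. Two literal branches, or one literal plus one column, both qualify.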

comparePlans(actual, correctAnswer)
}

test("SPARK-33798: Push down EqualTo through If") {
Contributor

nit: when creating a new test suite, we don't need to add the JIRA ID as prefix.


test("SPARK-33798: Push down EqualTo through If") {
assertEquivalent(EqualTo(ifExp, Literal(4)), FalseLiteral)
assertEquivalent(EqualTo(ifExp, Literal(3)), If(a, FalseLiteral, TrueLiteral))
Contributor

If(a, FalseLiteral, TrueLiteral) can be turned into !a. We can optimize it in followups.
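
One detail a follow-up would need to consider is a NULL condition. The sketch below is a toy evaluator with hypothetical names (`ifExpr`, `not`), modeling SQL NULL as `None`. It assumes, as Spark's If does, that a NULL condition selects the false branch.

```scala
// Toy evaluator only: None models SQL NULL.
object IfNotSketch {
  // Assumption: a NULL condition selects the false branch, as Spark's If does.
  def ifExpr(cond: Option[Boolean], t: Option[Boolean], f: Option[Boolean]): Option[Boolean] =
    if (cond.contains(true)) t else f

  // SQL NOT: NULL stays NULL.
  def not(cond: Option[Boolean]): Option[Boolean] = cond.map(!_)
}
```

For a true or false condition, `ifExpr(cond, Some(false), Some(true))` and `not(cond)` agree. For a NULL condition the first yields true and the second yields NULL, so the rewrite to `!a` looks safe only when `a` cannot be NULL.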

Member Author

test("SPARK-33798: Push down EqualTo through If") {
assertEquivalent(EqualTo(ifExp, Literal(4)), FalseLiteral)
assertEquivalent(EqualTo(ifExp, Literal(3)), If(a, FalseLiteral, TrueLiteral))
assertEquivalent(EqualTo(ifExp, Literal("4")), FalseLiteral)
Contributor

let's not test invalid expressions, EqualTo can't compare int and string.

assertEquivalent(EqualTo(caseWhen, Literal(4)), FalseLiteral)
assertEquivalent(EqualTo(caseWhen, Literal(3)),
CaseWhen(Seq((a, FalseLiteral), (c, FalseLiteral)), Some(TrueLiteral)))
assertEquivalent(EqualTo(caseWhen, Literal("4")), FalseLiteral)
Contributor

ditto, don't test invalid expressions.

private val b = UnresolvedAttribute("b")
private val c = EqualTo(UnresolvedAttribute("c"), Literal(true))
private val ifExp = If(a, Literal(2), Literal(3))
private val caseWhen = CaseWhen(Seq((a, Literal(1)), (c, Literal(2))), Some(Literal(3)))
Contributor

Without the else value, it seems we can't optimize CASE WHEN cond1 THEN false WHEN cond2 THEN false into false. We should fix it in follow-ups.
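
One plausible reason this case is harder is that a CASE WHEN without ELSE evaluates to NULL when no condition matches. The sketch below is a toy evaluator, not Spark code: `caseWhen` is a hypothetical helper that takes already-evaluated conditions and values, with `None` modeling SQL NULL.

```scala
// Toy evaluator only: None models SQL NULL.
object MissingElseSketch {
  // Return the value of the first branch whose condition is true;
  // otherwise the ELSE value, or NULL when there is no ELSE.
  def caseWhen(
      branches: Seq[(Option[Boolean], Option[Boolean])],
      elseValue: Option[Option[Boolean]]): Option[Boolean] =
    branches.collectFirst { case (Some(true), v) => v }
      .getOrElse(elseValue.flatten)
}
```

With every branch value false and no ELSE, the result is false when some condition matches and NULL when none does. The expression is therefore never true, but it is not the constant false either, so folding it to false is only valid where NULL and false are treated alike, such as in a filter.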

This comment was marked as spam.

Member Author

def apply(plan: LogicalPlan): LogicalPlan = plan transform {
case q: LogicalPlan => q transformExpressionsUp {
case b @ BinaryExpression(i @ If(_, trueValue, falseValue), right) if i.deterministic &&
right.foldable && (trueValue.foldable || falseValue.foldable) =>
Member

This is a good idea when both trueValue and falseValue are foldable, as in the example in the description.

But what if only trueValue is foldable? For example, from If(_, 'b', col1) = 'a' we get If(_, 'b' = 'a', col1 = 'a'). Is there some benefit here too?

Member Author

The differences in generated code: (screenshot not preserved)

Member

Hmm, I'm not sure if the difference means obvious benefits. Looks like it should be very close but I think I'm okay with it.


SparkQA commented Dec 18, 2020

Test build #132982 has finished for PR 30790 at commit 45b56fc.

  • This patch passes all tests.
  • This patch merges cleanly.
  • This patch adds no public classes.


SparkQA commented Dec 18, 2020

Kubernetes integration test starting
URL: https://amplab.cs.berkeley.edu/jenkins/job/SparkPullRequestBuilder-K8s/37604/


SparkQA commented Dec 18, 2020

Kubernetes integration test status success
URL: https://amplab.cs.berkeley.edu/jenkins/job/SparkPullRequestBuilder-K8s/37604/


SparkQA commented Dec 18, 2020

Test build #133005 has finished for PR 30790 at commit cfac0e8.

  • This patch passes all tests.
  • This patch merges cleanly.
  • This patch adds no public classes.

@cloud-fan
Contributor

thanks, merging to master!

@cloud-fan cloud-fan closed this in 06b1bbb Dec 18, 2020
@wangyum wangyum deleted the SPARK-33798 branch December 18, 2020 16:15
7 participants