-
Notifications
You must be signed in to change notification settings - Fork 28.5k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[SPARK-33798][SQL] Add new rule to push down the foldable expressions through CaseWhen/If #30790
Conversation
...talyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyConditionalSuite.scala
Outdated
Show resolved
Hide resolved
assertEquivalent( | ||
EqualTo(CaseWhen(Seq((a, Literal(null, IntegerType))), Some(Literal(1))), Literal(2)), | ||
FalseLiteral) | ||
assertEquivalent( | ||
EqualTo(CaseWhen(Seq((a, Literal(1))), Some(Literal(2))), Literal(null, IntegerType)), | ||
FalseLiteral) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
SPARK-33798 will change these behaviors as expected.
assertEquivalent( | ||
EqualTo(CaseWhen(Seq((a, Literal(null, IntegerType))), Some(Literal(1))), Literal(1)), | ||
EqualTo(CaseWhen(Seq((a, Literal(null, IntegerType))), Some(Literal(1))), Literal(1))) | ||
assertEquivalent( | ||
EqualTo(CaseWhen(Seq((a, Literal(null, IntegerType))), Some(Literal(null, IntegerType))), | ||
Literal(1)), | ||
Literal(null, BooleanType)) | ||
assertEquivalent( | ||
EqualTo(CaseWhen(Seq((a, Literal(null, IntegerType))), Some(Literal(null, IntegerType))), | ||
Literal(null, IntegerType)), | ||
Literal(null, BooleanType)) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
SPARK-33798 will not change these behaviors.
Test build #132846 has finished for PR 30790 at commit
|
FalseLiteral | ||
|
||
case EqualTo(c @ CaseWhen(branches, elseValue), right: Literal) if c.deterministic && | ||
(branches.map(_._2) ++ elseValue).forall(_.isInstanceOf[Literal]) && |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
we can do the "is literal check" inside isAlwaysFalse
, to shorten the code.
assertEquivalent(EqualTo(nonFoldable, Literal(1)), EqualTo(nonFoldable, Literal(1))) | ||
|
||
// Do not simplify if it contains non-deterministic expressions. | ||
val nonDeterministic = If(LessThan(Rand(1), Literal(0.5)), Literal(1), Literal(1)) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is a case that we should optimize. But we need to distinguish between non-deterministic and has-side-effect. While we can skip running non-deterministic expressions, we can't skip running has-side-effect expressions.
This is not related to this PR, but is something worth considering. cc @viirya @dbtsai @maropu @rednaxelafx
|
||
// null check, SPARK-33798 will not change these behaviors. | ||
assertEquivalent( | ||
EqualTo(If(FalseLiteral, Literal(null, IntegerType), Literal(1)), Literal(1)), |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The if condition is a literal, does it really go through the new branch? will it be optimized earlier by some other logics?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'd prefer to still use a === Literal(1)
as the if conditionn, and make sure the expression is not optimized if there are null values.
Kubernetes integration test starting |
Kubernetes integration test status failure |
Kubernetes integration test starting |
...talyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyConditionalSuite.scala
Outdated
Show resolved
Hide resolved
Kubernetes integration test status success |
@@ -470,6 +470,11 @@ object SimplifyConditionals extends Rule[LogicalPlan] with PredicateHelper { | |||
case _ => false | |||
} | |||
|
|||
private def isAlwaysFalse(exps: Seq[Expression], equalTo: Literal): Boolean = { | |||
exps.forall(_.isInstanceOf[Literal]) && | |||
exps.forall(!EqualTo(_, equalTo).eval(EmptyRow).asInstanceOf[Boolean]) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
EqualTo
may return null, we need to take care of it, because null is not false in SQL.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes. Null value should be handled by NullPropagation.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thank you for making a new one, @wangyum .
@@ -523,6 +529,17 @@ object SimplifyConditionals extends Rule[LogicalPlan] with PredicateHelper { | |||
} else { | |||
e.copy(branches = branches.take(i).map(branch => (branch._1, elseValue))) | |||
} | |||
|
|||
// Null value should be handled by NullPropagation. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It's better to make optimizer rules orthogonal and don't rely on each other. Let's make isAlwaysFalse
stricter:
exps.forall { e =>
val res = EqualTo(_, equalTo).eval(EmptyRow)
res != null && !res.asInstanceOf[Boolean]
}
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
How about:
exps.forall {
case l: Literal =>
val res = EqualTo(l, equalTo).eval(EmptyRow)
res != null && !res.asInstanceOf[Boolean]
case _ => false
}
otherwise this test will fail:
val nonFoldable = CaseWhen(Seq(normalBranch, (a, b)), None)
assertEquivalent(EqualTo(nonFoldable, Literal(1)), EqualTo(nonFoldable, Literal(1)))
Kubernetes integration test starting |
Test build #132861 has finished for PR 30790 at commit
|
Kubernetes integration test status success |
Kubernetes integration test starting |
Kubernetes integration test status success |
Test build #132878 has finished for PR 30790 at commit
|
test("SPARK-33798: simplify EqualTo(If, Literal) always false") { | ||
val a = EqualTo(UnresolvedAttribute("a"), Literal(100)) | ||
val b = UnresolvedAttribute("b") | ||
val ifExp = If(a === Literal(1), Literal(2), Literal(3)) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Let's make sure the test expression is valid. a
is an EqualTo
, comparing boolean and Literal(1)
is invalid.
|
||
assertEquivalent(EqualTo(ifExp, Literal(4)), FalseLiteral) | ||
assertEquivalent(EqualTo(ifExp, Literal(3)), EqualTo(ifExp, Literal(3))) | ||
assertEquivalent(EqualTo(ifExp, Literal("4")), FalseLiteral) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
ditto, let's don't test invalid expressions, ifExp.dataType
is int, not string.
assertEquivalent( | ||
EqualTo(If(a, b, Literal(2)), Literal(3)), | ||
EqualTo(If(a, b, Literal(2)), Literal(3))) | ||
val nonFoldable = If(NonFoldableLiteral(true), Literal(1), Literal(2)) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
NonFoldableLiteral
is nothing special comparing to 'a == 100
, as they are both non-foldable. I think we don't need to test with NonFoldableLiteral
* Push the foldable expression into (if / case) branches. | ||
*/ | ||
object PushFoldableIntoBranches extends Rule[LogicalPlan] with PredicateHelper { | ||
def apply(plan: LogicalPlan): LogicalPlan = plan transform { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: we cannot use transformAllExpressions
here?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This copied from
spark/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala
Line 474 in 8ccc3c1
case q: LogicalPlan => q transformExpressionsUp { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ah, I see. I think its okay to simply use transformAllExpressions
here.
Kubernetes integration test starting |
Kubernetes integration test status success |
Test build #132935 has finished for PR 30790 at commit
|
Test build #132939 has finished for PR 30790 at commit
|
falseValue = b.makeCopy(Array(falseValue, right))) | ||
|
||
case b @ BinaryComparison(c @ CaseWhen(branches, elseValue), right) if c.deterministic && | ||
right.foldable && (branches.map(_._2) ++ elseValue).forall(_.foldable) => |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
some more ideas:
- I think it's helpful to pushdown foldables into branches as long as there is at lease one foldable branche. To be conservative, we can do the pushdown if the IF/CASE WHEN has at most one non-foldable branch, so that we never increase the expression tree size, so is the generated code size.
- I think it's also useful to pushdown expressions like Add, e.g.
IF(cond, 1, 2) +1
->IF(cond, 2, 3)
. We can useBinnaryExpression
instead ofBinaryComparison
.
/** | ||
* Push the foldable expression into (if / case) branches. | ||
*/ | ||
object PushFoldableIntoBranches extends Rule[LogicalPlan] with PredicateHelper { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nice idea.
def apply(plan: LogicalPlan): LogicalPlan = plan transform { | ||
case q: LogicalPlan => q transformExpressionsUp { | ||
case b @ BinaryExpression(i @ If(_, trueValue, falseValue), right) if i.deterministic && | ||
right.foldable && (trueValue.foldable || falseValue.foldable) => |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Sorry if I missed something but can you explain why we need this foldable condition?
Shouldn't bin_op((if cond: a else: b), c)
be same as if cond: bin_op(a, c) else: bin_op(b, c)
as long as the results are deterministic?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ohh, okay but we just want to push down foldables only that can have benefits. gotya.
Kubernetes integration test starting |
Kubernetes integration test status success |
trueValue = b.makeCopy(Array(trueValue, right)), | ||
falseValue = b.makeCopy(Array(falseValue, right))) | ||
|
||
case b @ BinaryExpression(left, i @ If(_, trueValue, falseValue)) if i.deterministic && |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Do we care about deterministic
here? This optimization doesn't change the execution order of the branch conditions and/or branch values.
falseValue = b.makeCopy(Array(falseValue, right))) | ||
|
||
case b @ BinaryExpression(left, i @ If(_, trueValue, falseValue)) if i.deterministic && | ||
left.foldable && (trueValue.foldable || falseValue.foldable) => |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
let's create a method to count foldables, with comments to explain why
// be conservative here: it's only a guaranteed win if all but at most only one branch
// end up being not foldable.
private def atMostOneUnfoldable(exprs: Seq[Expression]): Boolean = ...
then here left.foldable && atMostOneUnfoldable(Seq(trueValue, falseValue))
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Do we need to push down this case, only one branch and non foldable:
assertEquivalent(EqualTo(CaseWhen(Seq((a, b)), None), Literal(1)), CaseWhen(Seq((a, EqualTo(b, Literal(1)))), None))
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
How about?
// To be conservative here: it's only a guaranteed win if all but at most only one branch
// end up being not foldable.
private def atMostOneUnfoldable(exprs: Seq[Expression]): Boolean = {
val (foldables, others) = exprs.partition(_.foldable)
foldables.nonEmpty && others.length < 2
}
comparePlans(actual, correctAnswer) | ||
} | ||
|
||
test("SPARK-33798: Push down EqualTo through If") { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: when creating a new test suite, we don't need to add the JIRA ID as prefix.
|
||
test("SPARK-33798: Push down EqualTo through If") { | ||
assertEquivalent(EqualTo(ifExp, Literal(4)), FalseLiteral) | ||
assertEquivalent(EqualTo(ifExp, Literal(3)), If(a, FalseLiteral, TrueLiteral)) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
If(a, FalseLiteral, TrueLiteral)
can be turned into !a
. We can optimize it in followups.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
test("SPARK-33798: Push down EqualTo through If") { | ||
assertEquivalent(EqualTo(ifExp, Literal(4)), FalseLiteral) | ||
assertEquivalent(EqualTo(ifExp, Literal(3)), If(a, FalseLiteral, TrueLiteral)) | ||
assertEquivalent(EqualTo(ifExp, Literal("4")), FalseLiteral) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
let's not test invalid expressions, EqualTo can't compare int and string.
assertEquivalent(EqualTo(caseWhen, Literal(4)), FalseLiteral) | ||
assertEquivalent(EqualTo(caseWhen, Literal(3)), | ||
CaseWhen(Seq((a, FalseLiteral), (c, FalseLiteral)), Some(TrueLiteral))) | ||
assertEquivalent(EqualTo(caseWhen, Literal("4")), FalseLiteral) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
ditto, don't test invalid expressions.
private val b = UnresolvedAttribute("b") | ||
private val c = EqualTo(UnresolvedAttribute("c"), Literal(true)) | ||
private val ifExp = If(a, Literal(2), Literal(3)) | ||
private val caseWhen = CaseWhen(Seq((a, Literal(1)), (c, Literal(2))), Some(Literal(3))) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
without the else value, seems like we can't optimize CASE WHEN cond1 THEN false WHEN cond2 THEN false
into false
. We should fix it in followups.
This comment was marked as spam.
This comment was marked as spam.
Sorry, something went wrong.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
def apply(plan: LogicalPlan): LogicalPlan = plan transform { | ||
case q: LogicalPlan => q transformExpressionsUp { | ||
case b @ BinaryExpression(i @ If(_, trueValue, falseValue), right) if i.deterministic && | ||
right.foldable && (trueValue.foldable || falseValue.foldable) => |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is good idea when both trueValue
and falseValue
are foldable like the example in the description.
But if only trueValue
is foldable? For example, from If(_, 'b', col1) = 'a'
, we get If(_, 'b' = 'a', col1 = 'a')
, are some benefits here too?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Hmm, I'm not sure if the difference means obvious benefits. Looks like it should be very close but I think I'm okay with it.
Test build #132982 has finished for PR 30790 at commit
|
Kubernetes integration test starting |
Kubernetes integration test status success |
Test build #133005 has finished for PR 30790 at commit
|
thanks, merging to master! |
What changes were proposed in this pull request?
This pr add a new rule(
PushFoldableIntoBranches
) to push down the foldable expressions throughCaseWhen/If
. This is a real case from production:Before this PR:
After this PR:
Why are the changes needed?
Improve query performance.
Does this PR introduce any user-facing change?
No.
How was this patch tested?
Unit test.