diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala index 1d23953484046..65f89bbdd0599 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala @@ -636,6 +636,11 @@ abstract class BinaryExpression extends Expression { } +object BinaryExpression { + def unapply(e: BinaryExpression): Option[(Expression, Expression)] = Some((e.left, e.right)) +} + + /** * A [[BinaryExpression]] that is an operator, with two properties: * diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index aa8540fb44556..fdb9c5b4821dd 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -99,6 +99,7 @@ abstract class Optimizer(catalogManager: CatalogManager) LikeSimplification, BooleanSimplification, SimplifyConditionals, + PushFoldableIntoBranches, RemoveDispensableExpressions, SimplifyBinaryComparison, ReplaceNullWithFalseInPredicate, diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala index 7666c4a53e5dd..e6730c9275a1e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala @@ -21,7 +21,7 @@ import scala.collection.immutable.HashSet import scala.collection.mutable.{ArrayBuffer, Stack} import org.apache.spark.sql.catalyst.analysis._ -import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.{BinaryExpression, _} import org.apache.spark.sql.catalyst.expressions.Literal.{FalseLiteral, TrueLiteral} import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.expressions.objects.AssertNotNull @@ -528,6 +528,48 @@ object SimplifyConditionals extends Rule[LogicalPlan] with PredicateHelper { } +/** + * Push the foldable expression into (if / case) branches. + */ +object PushFoldableIntoBranches extends Rule[LogicalPlan] with PredicateHelper { + + // To be conservative here: it's only a guaranteed win if all but at most only one branch + // end up being not foldable. + private def atMostOneUnfoldable(exprs: Seq[Expression]): Boolean = { + val (foldables, others) = exprs.partition(_.foldable) + foldables.nonEmpty && others.length < 2 + } + + def apply(plan: LogicalPlan): LogicalPlan = plan transform { + case q: LogicalPlan => q transformExpressionsUp { + case b @ BinaryExpression(i @ If(_, trueValue, falseValue), right) + if right.foldable && atMostOneUnfoldable(Seq(trueValue, falseValue)) => + i.copy( + trueValue = b.makeCopy(Array(trueValue, right)), + falseValue = b.makeCopy(Array(falseValue, right))) + + case b @ BinaryExpression(left, i @ If(_, trueValue, falseValue)) + if left.foldable && atMostOneUnfoldable(Seq(trueValue, falseValue)) => + i.copy( + trueValue = b.makeCopy(Array(left, trueValue)), + falseValue = b.makeCopy(Array(left, falseValue))) + + case b @ BinaryExpression(c @ CaseWhen(branches, elseValue), right) + if right.foldable && atMostOneUnfoldable(branches.map(_._2) ++ elseValue) => + c.copy( + branches.map(e => e.copy(_2 = b.makeCopy(Array(e._2, right)))), + elseValue.map(e => b.makeCopy(Array(e, right)))) + + case b @ BinaryExpression(left, c @ CaseWhen(branches, elseValue)) + if left.foldable && atMostOneUnfoldable(branches.map(_._2) ++ elseValue) => + c.copy( + branches.map(e => e.copy(_2 = b.makeCopy(Array(left, e._2)))), + elseValue.map(e => b.makeCopy(Array(left, e)))) + } + } +} + + /** * Simplifies LIKE expressions that do not need full regular expressions to evaluate the condition. * For example, when the expression is just checking to see if a string starts with a given diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PushFoldableIntoBranchesSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PushFoldableIntoBranchesSuite.scala new file mode 100644 index 0000000000000..43360af46ffb3 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PushFoldableIntoBranchesSuite.scala @@ -0,0 +1,225 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.optimizer + +import java.sql.Date + +import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.dsl.plans._ +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.Literal.{FalseLiteral, TrueLiteral} +import org.apache.spark.sql.catalyst.plans.PlanTest +import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.catalyst.rules._ +import org.apache.spark.sql.types.{BooleanType, IntegerType} + + +class PushFoldableIntoBranchesSuite + extends PlanTest with ExpressionEvalHelper with PredicateHelper { + + object Optimize extends RuleExecutor[LogicalPlan] { + val batches = Batch("PushFoldableIntoBranches", FixedPoint(50), + BooleanSimplification, ConstantFolding, SimplifyConditionals, PushFoldableIntoBranches) :: Nil + } + + private val relation = LocalRelation('a.int, 'b.int, 'c.boolean) + private val a = EqualTo(UnresolvedAttribute("a"), Literal(100)) + private val b = UnresolvedAttribute("b") + private val c = EqualTo(UnresolvedAttribute("c"), Literal(true)) + private val ifExp = If(a, Literal(2), Literal(3)) + private val caseWhen = CaseWhen(Seq((a, Literal(1)), (c, Literal(2))), Some(Literal(3))) + + protected def assertEquivalent(e1: Expression, e2: Expression): Unit = { + val correctAnswer = Project(Alias(e2, "out")() :: Nil, relation).analyze + val actual = Optimize.execute(Project(Alias(e1, "out")() :: Nil, relation).analyze) + comparePlans(actual, correctAnswer) + } + + test("Push down EqualTo through If") { + assertEquivalent(EqualTo(ifExp, Literal(4)), FalseLiteral) + assertEquivalent(EqualTo(ifExp, Literal(3)), If(a, FalseLiteral, TrueLiteral)) + + // Push down at most one not foldable expressions. + assertEquivalent( + EqualTo(If(a, b, Literal(2)), Literal(2)), + If(a, EqualTo(b, Literal(2)), TrueLiteral)) + assertEquivalent( + EqualTo(If(a, b, b + 1), Literal(2)), + EqualTo(If(a, b, b + 1), Literal(2))) + + // Push down non-deterministic expressions. + val nonDeterministic = If(LessThan(Rand(1), Literal(0.5)), Literal(1), Literal(2)) + assert(!nonDeterministic.deterministic) + assertEquivalent(EqualTo(nonDeterministic, Literal(2)), + If(LessThan(Rand(1), Literal(0.5)), FalseLiteral, TrueLiteral)) + assertEquivalent(EqualTo(nonDeterministic, Literal(3)), + If(LessThan(Rand(1), Literal(0.5)), FalseLiteral, FalseLiteral)) + + // Handle Null values. + assertEquivalent( + EqualTo(If(a, Literal(null, IntegerType), Literal(1)), Literal(1)), + If(a, Literal(null, BooleanType), TrueLiteral)) + assertEquivalent( + EqualTo(If(a, Literal(null, IntegerType), Literal(1)), Literal(2)), + If(a, Literal(null, BooleanType), FalseLiteral)) + assertEquivalent( + EqualTo(If(a, Literal(1), Literal(2)), Literal(null, IntegerType)), + Literal(null, BooleanType)) + assertEquivalent( + EqualTo(If(a, Literal(null, IntegerType), Literal(null, IntegerType)), Literal(1)), + Literal(null, BooleanType)) + } + + test("Push down other BinaryComparison through If") { + assertEquivalent(EqualNullSafe(ifExp, Literal(4)), FalseLiteral) + assertEquivalent(GreaterThan(ifExp, Literal(4)), FalseLiteral) + assertEquivalent(GreaterThanOrEqual(ifExp, Literal(4)), FalseLiteral) + assertEquivalent(LessThan(ifExp, Literal(4)), TrueLiteral) + assertEquivalent(LessThanOrEqual(ifExp, Literal(4)), TrueLiteral) + } + + test("Push down other BinaryOperator through If") { + assertEquivalent(Add(ifExp, Literal(4)), If(a, Literal(6), Literal(7))) + assertEquivalent(Subtract(ifExp, Literal(4)), If(a, Literal(-2), Literal(-1))) + assertEquivalent(Multiply(ifExp, Literal(4)), If(a, Literal(8), Literal(12))) + assertEquivalent(Pmod(ifExp, Literal(4)), If(a, Literal(2), Literal(3))) + assertEquivalent(Remainder(ifExp, Literal(4)), If(a, Literal(2), Literal(3))) + assertEquivalent(Divide(If(a, Literal(2.0), Literal(3.0)), Literal(1.0)), + If(a, Literal(2.0), Literal(3.0))) + assertEquivalent(And(If(a, FalseLiteral, TrueLiteral), TrueLiteral), + If(a, FalseLiteral, TrueLiteral)) + assertEquivalent(Or(If(a, FalseLiteral, TrueLiteral), TrueLiteral), TrueLiteral) + } + + test("Push down other BinaryExpression through If") { + assertEquivalent(BRound(If(a, Literal(1.23), Literal(1.24)), Literal(1)), Literal(1.2)) + assertEquivalent(StartsWith(If(a, Literal("ab"), Literal("ac")), Literal("a")), TrueLiteral) + assertEquivalent(FindInSet(If(a, Literal("ab"), Literal("ac")), Literal("a")), Literal(0)) + assertEquivalent( + AddMonths(If(a, Literal(Date.valueOf("2020-01-01")), Literal(Date.valueOf("2021-01-01"))), + Literal(1)), + If(a, Literal(Date.valueOf("2020-02-01")), Literal(Date.valueOf("2021-02-01")))) + } + + test("Push down EqualTo through CaseWhen") { + assertEquivalent(EqualTo(caseWhen, Literal(4)), FalseLiteral) + assertEquivalent(EqualTo(caseWhen, Literal(3)), + CaseWhen(Seq((a, FalseLiteral), (c, FalseLiteral)), Some(TrueLiteral))) + assertEquivalent( + EqualTo(CaseWhen(Seq((a, Literal(1)), (c, Literal(2))), None), Literal(4)), + CaseWhen(Seq((a, FalseLiteral), (c, FalseLiteral)), None)) + + assertEquivalent( + And(EqualTo(caseWhen, Literal(5)), EqualTo(caseWhen, Literal(6))), + FalseLiteral) + + // Push down at most one branch is not foldable expressions. + assertEquivalent(EqualTo(CaseWhen(Seq((a, b), (c, Literal(1))), None), Literal(1)), + CaseWhen(Seq((a, EqualTo(b, Literal(1))), (c, TrueLiteral)), None)) + assertEquivalent(EqualTo(CaseWhen(Seq((a, b), (c, b + 1)), None), Literal(1)), + EqualTo(CaseWhen(Seq((a, b), (c, b + 1)), None), Literal(1))) + assertEquivalent(EqualTo(CaseWhen(Seq((a, b)), None), Literal(1)), + EqualTo(CaseWhen(Seq((a, b)), None), Literal(1))) + + // Push down non-deterministic expressions. + val nonDeterministic = + CaseWhen(Seq((LessThan(Rand(1), Literal(0.5)), Literal(1))), Some(Literal(2))) + assert(!nonDeterministic.deterministic) + assertEquivalent(EqualTo(nonDeterministic, Literal(2)), + CaseWhen(Seq((LessThan(Rand(1), Literal(0.5)), FalseLiteral)), Some(TrueLiteral))) + assertEquivalent(EqualTo(nonDeterministic, Literal(3)), + CaseWhen(Seq((LessThan(Rand(1), Literal(0.5)), FalseLiteral)), Some(FalseLiteral))) + + // Handle Null values. + assertEquivalent( + EqualTo(CaseWhen(Seq((a, Literal(null, IntegerType))), Some(Literal(1))), Literal(2)), + CaseWhen(Seq((a, Literal(null, BooleanType))), Some(FalseLiteral))) + assertEquivalent( + EqualTo(CaseWhen(Seq((a, Literal(1))), Some(Literal(2))), Literal(null, IntegerType)), + Literal(null, BooleanType)) + assertEquivalent( + EqualTo(CaseWhen(Seq((a, Literal(null, IntegerType))), Some(Literal(1))), Literal(1)), + CaseWhen(Seq((a, Literal(null, BooleanType))), Some(TrueLiteral))) + assertEquivalent( + EqualTo(CaseWhen(Seq((a, Literal(null, IntegerType))), Some(Literal(null, IntegerType))), + Literal(1)), + Literal(null, BooleanType)) + assertEquivalent( + EqualTo(CaseWhen(Seq((a, Literal(null, IntegerType))), Some(Literal(null, IntegerType))), + Literal(null, IntegerType)), + Literal(null, BooleanType)) + } + + test("Push down other BinaryComparison through CaseWhen") { + assertEquivalent(EqualNullSafe(caseWhen, Literal(4)), FalseLiteral) + assertEquivalent(GreaterThan(caseWhen, Literal(4)), FalseLiteral) + assertEquivalent(GreaterThanOrEqual(caseWhen, Literal(4)), FalseLiteral) + assertEquivalent(LessThan(caseWhen, Literal(4)), TrueLiteral) + assertEquivalent(LessThanOrEqual(caseWhen, Literal(4)), TrueLiteral) + } + + test("Push down other BinaryOperator through CaseWhen") { + assertEquivalent(Add(caseWhen, Literal(4)), + CaseWhen(Seq((a, Literal(5)), (c, Literal(6))), Some(Literal(7)))) + assertEquivalent(Subtract(caseWhen, Literal(4)), + CaseWhen(Seq((a, Literal(-3)), (c, Literal(-2))), Some(Literal(-1)))) + assertEquivalent(Multiply(caseWhen, Literal(4)), + CaseWhen(Seq((a, Literal(4)), (c, Literal(8))), Some(Literal(12)))) + assertEquivalent(Pmod(caseWhen, Literal(4)), + CaseWhen(Seq((a, Literal(1)), (c, Literal(2))), Some(Literal(3)))) + assertEquivalent(Remainder(caseWhen, Literal(4)), + CaseWhen(Seq((a, Literal(1)), (c, Literal(2))), Some(Literal(3)))) + assertEquivalent(Divide(CaseWhen(Seq((a, Literal(1.0)), (c, Literal(2.0))), Some(Literal(3.0))), + Literal(1.0)), + CaseWhen(Seq((a, Literal(1.0)), (c, Literal(2.0))), Some(Literal(3.0)))) + assertEquivalent(And(CaseWhen(Seq((a, FalseLiteral), (c, TrueLiteral)), Some(TrueLiteral)), + TrueLiteral), + CaseWhen(Seq((a, FalseLiteral), (c, TrueLiteral)), Some(TrueLiteral))) + assertEquivalent(Or(CaseWhen(Seq((a, FalseLiteral), (c, TrueLiteral)), Some(TrueLiteral)), + TrueLiteral), TrueLiteral) + } + + test("Push down other BinaryExpression through CaseWhen") { + assertEquivalent( + BRound(CaseWhen(Seq((a, Literal(1.23)), (c, Literal(1.24))), Some(Literal(1.25))), + Literal(1)), + Literal(1.2)) + assertEquivalent( + StartsWith(CaseWhen(Seq((a, Literal("ab")), (c, Literal("ac"))), Some(Literal("ad"))), + Literal("a")), + TrueLiteral) + assertEquivalent( + FindInSet(CaseWhen(Seq((a, Literal("ab")), (c, Literal("ac"))), Some(Literal("ad"))), + Literal("a")), + Literal(0)) + assertEquivalent( + AddMonths(CaseWhen(Seq((a, Literal(Date.valueOf("2020-01-01"))), + (c, Literal(Date.valueOf("2021-01-01")))), + Some(Literal(Date.valueOf("2022-01-01")))), + Literal(1)), + CaseWhen(Seq((a, Literal(Date.valueOf("2020-02-01"))), + (c, Literal(Date.valueOf("2021-02-01")))), + Some(Literal(Date.valueOf("2022-02-01"))))) + } + + test("Push down BinaryExpression through If/CaseWhen backwards") { + assertEquivalent(EqualTo(Literal(4), ifExp), FalseLiteral) + assertEquivalent(EqualTo(Literal(4), caseWhen), FalseLiteral) + } +}