From 19b0a839a6e30e6beb6690fd7b7b5abef0e8b147 Mon Sep 17 00:00:00 2001 From: Yuming Wang Date: Wed, 16 Dec 2020 09:31:23 +0800 Subject: [PATCH 01/12] Simplify EqualTo(CaseWhen/If, Literal) always false --- .../sql/catalyst/optimizer/expressions.scala | 13 +++ .../optimizer/SimplifyConditionalSuite.scala | 93 +++++++++++++++++++ 2 files changed, 106 insertions(+) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala index 7666c4a53e5dd..17a54dc13f354 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala @@ -470,6 +470,10 @@ object SimplifyConditionals extends Rule[LogicalPlan] with PredicateHelper { case _ => false } + private def isAlwaysFalse(exps: Seq[Expression], equalTo: Literal): Boolean = { + exps.forall(!EqualTo(_, equalTo).eval(EmptyRow).asInstanceOf[Boolean]) + } + def apply(plan: LogicalPlan): LogicalPlan = plan transform { case q: LogicalPlan => q transformExpressionsUp { case If(TrueLiteral, trueValue, _) => trueValue @@ -523,6 +527,15 @@ object SimplifyConditionals extends Rule[LogicalPlan] with PredicateHelper { } else { e.copy(branches = branches.take(i).map(branch => (branch._1, elseValue))) } + + case EqualTo(i @ If(_, trueValue: Literal, falseValue: Literal), right: Literal) + if i.deterministic && isAlwaysFalse(trueValue :: falseValue :: Nil, right) => + FalseLiteral + + case EqualTo(c @ CaseWhen(branches, elseValue), right: Literal) if c.deterministic && + (branches.map(_._2) ++ elseValue).forall(_.isInstanceOf[Literal]) && + isAlwaysFalse(branches.map(_._2) ++ elseValue, right) => + FalseLiteral } } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyConditionalSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyConditionalSuite.scala index bac962ced4618..d4eee541ac053 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyConditionalSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyConditionalSuite.scala @@ -199,4 +199,97 @@ class SimplifyConditionalSuite extends PlanTest with ExpressionEvalHelper with P If(Factorial(5) > 100L, b, nullLiteral).eval(EmptyRow)) } } + + test("SPARK-33798: simplify EqualTo(If, Literal) always false") { + val a = EqualTo(UnresolvedAttribute("a"), Literal(100)) + val ifExp = If(a === Literal(1), Literal(2), Literal(3)) + + assertEquivalent(EqualTo(ifExp, Literal(4)), FalseLiteral) + assertEquivalent(EqualTo(ifExp, Literal(3)), EqualTo(ifExp, Literal(3))) + assertEquivalent(EqualTo(ifExp, Literal("4")), FalseLiteral) + assertEquivalent(EqualTo(ifExp, Literal("3")), EqualTo(ifExp, Literal(3))) + + // Do not simplify if it contains non foldable expressions. + assertEquivalent(EqualTo(ifExp, NonFoldableLiteral(true)), + EqualTo(ifExp, NonFoldableLiteral(true))) + val nonFoldable = If(NonFoldableLiteral(true), Literal(1), Literal(2)) + assertEquivalent(EqualTo(nonFoldable, Literal(1)), EqualTo(nonFoldable, Literal(1))) + + // Do not simplify if it contains non-deterministic expressions. + val nonDeterministic = If(LessThan(Rand(1), Literal(0.5)), Literal(1), Literal(1)) + assert(!nonDeterministic.deterministic) + assertEquivalent(EqualTo(nonDeterministic, Literal(-1)), EqualTo(nonDeterministic, Literal(-1))) + + // null check, SPARK-33798 will not change these behaviors. + assertEquivalent( + EqualTo(If(FalseLiteral, Literal(null, IntegerType), Literal(1)), Literal(1)), + TrueLiteral) + assertEquivalent( + EqualTo(If(TrueLiteral, Literal(null, IntegerType), Literal(1)), Literal(1)), + Literal(null, BooleanType)) + assertEquivalent( + EqualTo(If(FalseLiteral, Literal(null, IntegerType), Literal(null, IntegerType)), Literal(1)), + Literal(null, BooleanType)) + + assertEquivalent( + EqualTo(If(FalseLiteral, Literal(1), Literal(2)), Literal(null, IntegerType)), + Literal(null, BooleanType)) + assertEquivalent( + EqualTo(If(TrueLiteral, Literal(1), Literal(2)), Literal(null, IntegerType)), + Literal(null, BooleanType)) + } + + test("SPARK-33798: simplify EqualTo(CaseWhen, Literal) always false") { + val a = EqualTo(UnresolvedAttribute("a"), Literal(100)) + val b = UnresolvedAttribute("b") + val c = EqualTo(UnresolvedAttribute("c"), Literal(true)) + val caseWhen = CaseWhen(Seq((a, Literal(1)), (c, Literal(2))), Some(Literal(3))) + + assertEquivalent(EqualTo(caseWhen, Literal(4)), FalseLiteral) + assertEquivalent(EqualTo(caseWhen, Literal(3)), EqualTo(caseWhen, Literal(3))) + assertEquivalent(EqualTo(caseWhen, Literal("4")), FalseLiteral) + assertEquivalent(EqualTo(caseWhen, Literal("3")), EqualTo(caseWhen, Literal(3))) + assertEquivalent( + EqualTo(CaseWhen(Seq((a, Literal("1")), (c, Literal("2"))), None), Literal("4")), + FalseLiteral) + + assertEquivalent( + And(EqualTo(caseWhen, Literal(5)), EqualTo(caseWhen, Literal(6))), + FalseLiteral) + + assertEquivalent( + EqualTo(CaseWhen(Seq(normalBranch, (a, Literal(1)), (c, Literal(1))), None), Literal(-1)), + FalseLiteral) + + // Do not simplify if it contains non foldable expressions. + assertEquivalent(EqualTo(caseWhen, NonFoldableLiteral(true)), + EqualTo(caseWhen, NonFoldableLiteral(true))) + val nonFoldable = CaseWhen(Seq(normalBranch, (a, b)), None) + assertEquivalent(EqualTo(nonFoldable, Literal(1)), EqualTo(nonFoldable, Literal(1))) + + // Do not simplify if it contains non-deterministic expressions. + val nonDeterministic = CaseWhen(Seq((LessThan(Rand(1), Literal(0.5)), Literal(1))), Some(b)) + assert(!nonDeterministic.deterministic) + assertEquivalent(EqualTo(nonDeterministic, Literal(-1)), EqualTo(nonDeterministic, Literal(-1))) + + // null check, SPARK-33798 will change the following two behaviors. + assertEquivalent( + EqualTo(CaseWhen(Seq((a, Literal(null, IntegerType))), Some(Literal(1))), Literal(2)), + FalseLiteral) + assertEquivalent( + EqualTo(CaseWhen(Seq((a, Literal(1))), Some(Literal(2))), Literal(null, IntegerType)), + FalseLiteral) + + assertEquivalent( + EqualTo(CaseWhen(Seq((a, Literal(null, IntegerType))), Some(Literal(1))), Literal(1)), + EqualTo(CaseWhen(Seq((a, Literal(null, IntegerType))), Some(Literal(1))), Literal(1))) + assertEquivalent( + EqualTo(CaseWhen(Seq((a, Literal(null, IntegerType))), Some(Literal(null, IntegerType))), + Literal(1)), + Literal(null, BooleanType)) + assertEquivalent( + EqualTo(CaseWhen(Seq((a, Literal(null, IntegerType))), Some(Literal(null, IntegerType))), + Literal(null, IntegerType)), + Literal(null, BooleanType)) + } } From 831bf66b889d72724c072f9b9a20d8a436d4dbb8 Mon Sep 17 00:00:00 2001 From: Yuming Wang Date: Wed, 16 Dec 2020 14:11:30 +0800 Subject: [PATCH 02/12] Address comments --- .../sql/catalyst/optimizer/expressions.scala | 10 ++++---- .../optimizer/SimplifyConditionalSuite.scala | 23 ++++++++++--------- 2 files changed, 17 insertions(+), 16 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala index 17a54dc13f354..e679033530b46 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala @@ -471,7 +471,8 @@ object SimplifyConditionals extends Rule[LogicalPlan] with PredicateHelper { } private def isAlwaysFalse(exps: Seq[Expression], equalTo: Literal): Boolean = { - exps.forall(!EqualTo(_, equalTo).eval(EmptyRow).asInstanceOf[Boolean]) + exps.forall(_.isInstanceOf[Literal]) && + exps.forall(!EqualTo(_, equalTo).eval(EmptyRow).asInstanceOf[Boolean]) } def apply(plan: LogicalPlan): LogicalPlan = plan transform { @@ -528,13 +529,12 @@ object SimplifyConditionals extends Rule[LogicalPlan] with PredicateHelper { e.copy(branches = branches.take(i).map(branch => (branch._1, elseValue))) } - case EqualTo(i @ If(_, trueValue: Literal, falseValue: Literal), right: Literal) + case EqualTo(i @ If(_, trueValue, falseValue), right: Literal) if i.deterministic && isAlwaysFalse(trueValue :: falseValue :: Nil, right) => FalseLiteral - case EqualTo(c @ CaseWhen(branches, elseValue), right: Literal) if c.deterministic && - (branches.map(_._2) ++ elseValue).forall(_.isInstanceOf[Literal]) && - isAlwaysFalse(branches.map(_._2) ++ elseValue, right) => + case EqualTo(c @ CaseWhen(branches, elseValue), right: Literal) + if c.deterministic && isAlwaysFalse(branches.map(_._2) ++ elseValue, right) => FalseLiteral } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyConditionalSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyConditionalSuite.scala index d4eee541ac053..61d31b9c1c9af 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyConditionalSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyConditionalSuite.scala @@ -220,22 +220,23 @@ class SimplifyConditionalSuite extends PlanTest with ExpressionEvalHelper with P assert(!nonDeterministic.deterministic) assertEquivalent(EqualTo(nonDeterministic, Literal(-1)), EqualTo(nonDeterministic, Literal(-1))) - // null check, SPARK-33798 will not change these behaviors. - assertEquivalent( - EqualTo(If(FalseLiteral, Literal(null, IntegerType), Literal(1)), Literal(1)), - TrueLiteral) + // null check, SPARK-33798 will change the following two behaviors. assertEquivalent( - EqualTo(If(TrueLiteral, Literal(null, IntegerType), Literal(1)), Literal(1)), - Literal(null, BooleanType)) + EqualTo(If(a === Literal(1), Literal(null, IntegerType), Literal(1)), Literal(2)), + FalseLiteral) assertEquivalent( - EqualTo(If(FalseLiteral, Literal(null, IntegerType), Literal(null, IntegerType)), Literal(1)), - Literal(null, BooleanType)) + EqualTo(If(a =!= Literal(1), Literal(1), Literal(2)), Literal(null, IntegerType)), + FalseLiteral) assertEquivalent( - EqualTo(If(FalseLiteral, Literal(1), Literal(2)), Literal(null, IntegerType)), - Literal(null, BooleanType)) + EqualTo(If(a === Literal(1), Literal(null, IntegerType), Literal(1)), Literal(1)), + EqualTo(If(a === Literal(1), Literal(null, IntegerType), Literal(1)), Literal(1))) assertEquivalent( - EqualTo(If(TrueLiteral, Literal(1), Literal(2)), Literal(null, IntegerType)), + EqualTo(If(a =!= Literal(1), Literal(null, IntegerType), Literal(1)), Literal(1)), + EqualTo(If(a =!= Literal(1), Literal(null, IntegerType), Literal(1)), Literal(1))) + assertEquivalent( + EqualTo(If(a =!= Literal(1), Literal(null, IntegerType), Literal(null, IntegerType)), + Literal(1)), Literal(null, BooleanType)) } From 859893d0a8629d689ee72db9e9897f772e5ef4a5 Mon Sep 17 00:00:00 2001 From: Yuming Wang Date: Wed, 16 Dec 2020 16:48:47 +0800 Subject: [PATCH 03/12] Null value should be handled by NullPropagation. --- .../spark/sql/catalyst/optimizer/expressions.scala | 8 ++++++-- .../optimizer/SimplifyConditionalSuite.scala | 14 ++++++-------- 2 files changed, 12 insertions(+), 10 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala index e679033530b46..4d572e01f7526 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala @@ -472,6 +472,7 @@ object SimplifyConditionals extends Rule[LogicalPlan] with PredicateHelper { private def isAlwaysFalse(exps: Seq[Expression], equalTo: Literal): Boolean = { exps.forall(_.isInstanceOf[Literal]) && + exps.forall(_.asInstanceOf[Literal].value != null) && exps.forall(!EqualTo(_, equalTo).eval(EmptyRow).asInstanceOf[Boolean]) } @@ -529,12 +530,15 @@ object SimplifyConditionals extends Rule[LogicalPlan] with PredicateHelper { e.copy(branches = branches.take(i).map(branch => (branch._1, elseValue))) } + // Null value should be handled by NullPropagation. case EqualTo(i @ If(_, trueValue, falseValue), right: Literal) - if i.deterministic && isAlwaysFalse(trueValue :: falseValue :: Nil, right) => + if i.deterministic && right.value != null && + isAlwaysFalse(trueValue :: falseValue :: Nil, right) => FalseLiteral case EqualTo(c @ CaseWhen(branches, elseValue), right: Literal) - if c.deterministic && isAlwaysFalse(branches.map(_._2) ++ elseValue, right) => + if c.deterministic && right.value != null && + isAlwaysFalse(branches.map(_._2) ++ elseValue, right) => FalseLiteral } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyConditionalSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyConditionalSuite.scala index 61d31b9c1c9af..b1ddb9d2331f7 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyConditionalSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyConditionalSuite.scala @@ -220,14 +220,13 @@ class SimplifyConditionalSuite extends PlanTest with ExpressionEvalHelper with P assert(!nonDeterministic.deterministic) assertEquivalent(EqualTo(nonDeterministic, Literal(-1)), EqualTo(nonDeterministic, Literal(-1))) - // null check, SPARK-33798 will change the following two behaviors. + // Null value should be handled by NullPropagation. assertEquivalent( EqualTo(If(a === Literal(1), Literal(null, IntegerType), Literal(1)), Literal(2)), - FalseLiteral) + EqualTo(If(a === Literal(1), Literal(null, IntegerType), Literal(1)), Literal(2))) assertEquivalent( EqualTo(If(a =!= Literal(1), Literal(1), Literal(2)), Literal(null, IntegerType)), - FalseLiteral) - + EqualTo(If(a =!= Literal(1), Literal(1), Literal(2)), Literal(null, IntegerType))) assertEquivalent( EqualTo(If(a === Literal(1), Literal(null, IntegerType), Literal(1)), Literal(1)), EqualTo(If(a === Literal(1), Literal(null, IntegerType), Literal(1)), Literal(1))) @@ -273,14 +272,13 @@ class SimplifyConditionalSuite extends PlanTest with ExpressionEvalHelper with P assert(!nonDeterministic.deterministic) assertEquivalent(EqualTo(nonDeterministic, Literal(-1)), EqualTo(nonDeterministic, Literal(-1))) - // null check, SPARK-33798 will change the following two behaviors. + // Null value should be handled by NullPropagation. assertEquivalent( EqualTo(CaseWhen(Seq((a, Literal(null, IntegerType))), Some(Literal(1))), Literal(2)), - FalseLiteral) + EqualTo(CaseWhen(Seq((a, Literal(null, IntegerType))), Some(Literal(1))), Literal(2))) assertEquivalent( EqualTo(CaseWhen(Seq((a, Literal(1))), Some(Literal(2))), Literal(null, IntegerType)), - FalseLiteral) - + EqualTo(CaseWhen(Seq((a, Literal(1))), Some(Literal(2))), Literal(null, IntegerType))) assertEquivalent( EqualTo(CaseWhen(Seq((a, Literal(null, IntegerType))), Some(Literal(1))), Literal(1)), EqualTo(CaseWhen(Seq((a, Literal(null, IntegerType))), Some(Literal(1))), Literal(1))) From db584efb5a0e1f9ba0a8c21dfd4dcdc5737a62ac Mon Sep 17 00:00:00 2001 From: Yuming Wang Date: Wed, 16 Dec 2020 18:31:53 +0800 Subject: [PATCH 04/12] Handle Null values --- .../sql/catalyst/optimizer/expressions.scala | 16 ++++++++-------- .../optimizer/SimplifyConditionalSuite.scala | 9 +++++---- 2 files changed, 13 insertions(+), 12 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala index 4d572e01f7526..3484ad30b5aaa 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala @@ -471,9 +471,12 @@ object SimplifyConditionals extends Rule[LogicalPlan] with PredicateHelper { } private def isAlwaysFalse(exps: Seq[Expression], equalTo: Literal): Boolean = { - exps.forall(_.isInstanceOf[Literal]) && - exps.forall(_.asInstanceOf[Literal].value != null) && - exps.forall(!EqualTo(_, equalTo).eval(EmptyRow).asInstanceOf[Boolean]) + exps.forall { + case l: Literal => + val res = EqualTo(l, equalTo).eval(EmptyRow) + res != null && !res.asInstanceOf[Boolean] + case _ => false + } } def apply(plan: LogicalPlan): LogicalPlan = plan transform { @@ -530,15 +533,12 @@ object SimplifyConditionals extends Rule[LogicalPlan] with PredicateHelper { e.copy(branches = branches.take(i).map(branch => (branch._1, elseValue))) } - // Null value should be handled by NullPropagation. case EqualTo(i @ If(_, trueValue, falseValue), right: Literal) - if i.deterministic && right.value != null && - isAlwaysFalse(trueValue :: falseValue :: Nil, right) => + if i.deterministic && isAlwaysFalse(trueValue :: falseValue :: Nil, right) => FalseLiteral case EqualTo(c @ CaseWhen(branches, elseValue), right: Literal) - if c.deterministic && right.value != null && - isAlwaysFalse(branches.map(_._2) ++ elseValue, right) => + if c.deterministic && isAlwaysFalse(branches.map(_._2) ++ elseValue, right) => FalseLiteral } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyConditionalSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyConditionalSuite.scala index b1ddb9d2331f7..cc18492ce44ce 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyConditionalSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyConditionalSuite.scala @@ -202,6 +202,7 @@ class SimplifyConditionalSuite extends PlanTest with ExpressionEvalHelper with P test("SPARK-33798: simplify EqualTo(If, Literal) always false") { val a = EqualTo(UnresolvedAttribute("a"), Literal(100)) + val b = UnresolvedAttribute("b") val ifExp = If(a === Literal(1), Literal(2), Literal(3)) assertEquivalent(EqualTo(ifExp, Literal(4)), FalseLiteral) @@ -210,8 +211,8 @@ class SimplifyConditionalSuite extends PlanTest with ExpressionEvalHelper with P assertEquivalent(EqualTo(ifExp, Literal("3")), EqualTo(ifExp, Literal(3))) // Do not simplify if it contains non foldable expressions. - assertEquivalent(EqualTo(ifExp, NonFoldableLiteral(true)), - EqualTo(ifExp, NonFoldableLiteral(true))) + assertEquivalent(EqualTo(If(a === Literal(1), b, Literal(2)), Literal(3)), + EqualTo(If(a === Literal(1), b, Literal(2)), Literal(3))) val nonFoldable = If(NonFoldableLiteral(true), Literal(1), Literal(2)) assertEquivalent(EqualTo(nonFoldable, Literal(1)), EqualTo(nonFoldable, Literal(1))) @@ -220,7 +221,7 @@ class SimplifyConditionalSuite extends PlanTest with ExpressionEvalHelper with P assert(!nonDeterministic.deterministic) assertEquivalent(EqualTo(nonDeterministic, Literal(-1)), EqualTo(nonDeterministic, Literal(-1))) - // Null value should be handled by NullPropagation. + // Should not handle Null values. assertEquivalent( EqualTo(If(a === Literal(1), Literal(null, IntegerType), Literal(1)), Literal(2)), EqualTo(If(a === Literal(1), Literal(null, IntegerType), Literal(1)), Literal(2))) @@ -272,7 +273,7 @@ class SimplifyConditionalSuite extends PlanTest with ExpressionEvalHelper with P assert(!nonDeterministic.deterministic) assertEquivalent(EqualTo(nonDeterministic, Literal(-1)), EqualTo(nonDeterministic, Literal(-1))) - // Null value should be handled by NullPropagation. + // Should not handle Null values. assertEquivalent( EqualTo(CaseWhen(Seq((a, Literal(null, IntegerType))), Some(Literal(1))), Literal(2)), EqualTo(CaseWhen(Seq((a, Literal(null, IntegerType))), Some(Literal(1))), Literal(2))) From 55f4528e7e61629e2194adb055f3ba52ae4a875b Mon Sep 17 00:00:00 2001 From: Yuming Wang Date: Wed, 16 Dec 2020 22:08:53 +0800 Subject: [PATCH 05/12] address comments --- .../optimizer/SimplifyConditionalSuite.scala | 26 +++++++++---------- 1 file changed, 13 insertions(+), 13 deletions(-) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyConditionalSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyConditionalSuite.scala index cc18492ce44ce..5826b38ac3b78 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyConditionalSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyConditionalSuite.scala @@ -203,7 +203,7 @@ class SimplifyConditionalSuite extends PlanTest with ExpressionEvalHelper with P test("SPARK-33798: simplify EqualTo(If, Literal) always false") { val a = EqualTo(UnresolvedAttribute("a"), Literal(100)) val b = UnresolvedAttribute("b") - val ifExp = If(a === Literal(1), Literal(2), Literal(3)) + val ifExp = If(a, Literal(2), Literal(3)) assertEquivalent(EqualTo(ifExp, Literal(4)), FalseLiteral) assertEquivalent(EqualTo(ifExp, Literal(3)), EqualTo(ifExp, Literal(3))) @@ -211,8 +211,9 @@ class SimplifyConditionalSuite extends PlanTest with ExpressionEvalHelper with P assertEquivalent(EqualTo(ifExp, Literal("3")), EqualTo(ifExp, Literal(3))) // Do not simplify if it contains non foldable expressions. - assertEquivalent(EqualTo(If(a === Literal(1), b, Literal(2)), Literal(3)), - EqualTo(If(a === Literal(1), b, Literal(2)), Literal(3))) + assertEquivalent( + EqualTo(If(a, b, Literal(2)), Literal(3)), + EqualTo(If(a, b, Literal(2)), Literal(3))) val nonFoldable = If(NonFoldableLiteral(true), Literal(1), Literal(2)) assertEquivalent(EqualTo(nonFoldable, Literal(1)), EqualTo(nonFoldable, Literal(1))) @@ -223,20 +224,19 @@ class SimplifyConditionalSuite extends PlanTest with ExpressionEvalHelper with P // Should not handle Null values. assertEquivalent( - EqualTo(If(a === Literal(1), Literal(null, IntegerType), Literal(1)), Literal(2)), - EqualTo(If(a === Literal(1), Literal(null, IntegerType), Literal(1)), Literal(2))) + EqualTo(If(a, Literal(null, IntegerType), Literal(1)), Literal(2)), + EqualTo(If(a, Literal(null, IntegerType), Literal(1)), Literal(2))) assertEquivalent( - EqualTo(If(a =!= Literal(1), Literal(1), Literal(2)), Literal(null, IntegerType)), - EqualTo(If(a =!= Literal(1), Literal(1), Literal(2)), Literal(null, IntegerType))) + EqualTo(If(!a, Literal(1), Literal(2)), Literal(null, IntegerType)), + EqualTo(If(!a, Literal(1), Literal(2)), Literal(null, IntegerType))) assertEquivalent( - EqualTo(If(a === Literal(1), Literal(null, IntegerType), Literal(1)), Literal(1)), - EqualTo(If(a === Literal(1), Literal(null, IntegerType), Literal(1)), Literal(1))) + EqualTo(If(a, Literal(null, IntegerType), Literal(1)), Literal(1)), + EqualTo(If(a, Literal(null, IntegerType), Literal(1)), Literal(1))) assertEquivalent( - EqualTo(If(a =!= Literal(1), Literal(null, IntegerType), Literal(1)), Literal(1)), - EqualTo(If(a =!= Literal(1), Literal(null, IntegerType), Literal(1)), Literal(1))) + EqualTo(If(!a, Literal(null, IntegerType), Literal(1)), Literal(1)), + EqualTo(If(!a, Literal(null, IntegerType), Literal(1)), Literal(1))) assertEquivalent( - EqualTo(If(a =!= Literal(1), Literal(null, IntegerType), Literal(null, IntegerType)), - Literal(1)), + EqualTo(If(!a, Literal(null, IntegerType), Literal(null, IntegerType)), Literal(1)), Literal(null, BooleanType)) } From 0a48048d90077b6f4c9f0a1125e35c3b0fec861a Mon Sep 17 00:00:00 2001 From: Yuming Wang Date: Thu, 17 Dec 2020 08:38:26 +0800 Subject: [PATCH 06/12] address comments --- .../optimizer/SimplifyConditionalSuite.scala | 21 +++++++------------ 1 file changed, 8 insertions(+), 13 deletions(-) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyConditionalSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyConditionalSuite.scala index 5826b38ac3b78..278abc44e0a9e 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyConditionalSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyConditionalSuite.scala @@ -212,10 +212,8 @@ class SimplifyConditionalSuite extends PlanTest with ExpressionEvalHelper with P // Do not simplify if it contains non foldable expressions. assertEquivalent( - EqualTo(If(a, b, Literal(2)), Literal(3)), - EqualTo(If(a, b, Literal(2)), Literal(3))) - val nonFoldable = If(NonFoldableLiteral(true), Literal(1), Literal(2)) - assertEquivalent(EqualTo(nonFoldable, Literal(1)), EqualTo(nonFoldable, Literal(1))) + EqualTo(If(a, b, Literal(2)), Literal(2)), + EqualTo(If(a, b, Literal(2)), Literal(2))) // Do not simplify if it contains non-deterministic expressions. val nonDeterministic = If(LessThan(Rand(1), Literal(0.5)), Literal(1), Literal(1)) @@ -223,20 +221,17 @@ class SimplifyConditionalSuite extends PlanTest with ExpressionEvalHelper with P assertEquivalent(EqualTo(nonDeterministic, Literal(-1)), EqualTo(nonDeterministic, Literal(-1))) // Should not handle Null values. - assertEquivalent( - EqualTo(If(a, Literal(null, IntegerType), Literal(1)), Literal(2)), - EqualTo(If(a, Literal(null, IntegerType), Literal(1)), Literal(2))) - assertEquivalent( - EqualTo(If(!a, Literal(1), Literal(2)), Literal(null, IntegerType)), - EqualTo(If(!a, Literal(1), Literal(2)), Literal(null, IntegerType))) assertEquivalent( EqualTo(If(a, Literal(null, IntegerType), Literal(1)), Literal(1)), EqualTo(If(a, Literal(null, IntegerType), Literal(1)), Literal(1))) assertEquivalent( - EqualTo(If(!a, Literal(null, IntegerType), Literal(1)), Literal(1)), - EqualTo(If(!a, Literal(null, IntegerType), Literal(1)), Literal(1))) + EqualTo(If(a, Literal(null, IntegerType), Literal(1)), Literal(2)), + EqualTo(If(a, Literal(null, IntegerType), Literal(1)), Literal(2))) + assertEquivalent( + EqualTo(If(a, Literal(1), Literal(2)), Literal(null, IntegerType)), + EqualTo(If(a, Literal(1), Literal(2)), Literal(null, IntegerType))) assertEquivalent( - EqualTo(If(!a, Literal(null, IntegerType), Literal(null, IntegerType)), Literal(1)), + EqualTo(If(a, Literal(null, IntegerType), Literal(null, IntegerType)), Literal(1)), Literal(null, BooleanType)) } From f9f622f96c20d1787488c9ea392f0b857be532b5 Mon Sep 17 00:00:00 2001 From: Yuming Wang Date: Thu, 17 Dec 2020 13:43:15 +0800 Subject: [PATCH 07/12] Push down EqualTo through CaseWhen/If --- .../sql/catalyst/optimizer/expressions.scala | 20 ++++------- .../optimizer/SimplifyConditionalSuite.scala | 36 +++++++++---------- 2 files changed, 23 insertions(+), 33 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala index 3484ad30b5aaa..f7c1d291f78b8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala @@ -470,15 +470,6 @@ object SimplifyConditionals extends Rule[LogicalPlan] with PredicateHelper { case _ => false } - private def isAlwaysFalse(exps: Seq[Expression], equalTo: Literal): Boolean = { - exps.forall { - case l: Literal => - val res = EqualTo(l, equalTo).eval(EmptyRow) - res != null && !res.asInstanceOf[Boolean] - case _ => false - } - } - def apply(plan: LogicalPlan): LogicalPlan = plan transform { case q: LogicalPlan => q transformExpressionsUp { case If(TrueLiteral, trueValue, _) => trueValue @@ -533,13 +524,14 @@ object SimplifyConditionals extends Rule[LogicalPlan] with PredicateHelper { e.copy(branches = branches.take(i).map(branch => (branch._1, elseValue))) } - case EqualTo(i @ If(_, trueValue, falseValue), right: Literal) - if i.deterministic && isAlwaysFalse(trueValue :: falseValue :: Nil, right) => - FalseLiteral + case EqualTo(i @ If(_, trueValue: Literal, falseValue: Literal), right: Literal) + if i.deterministic => + i.copy(trueValue = EqualTo(trueValue, right), falseValue = EqualTo(falseValue, right)) case EqualTo(c @ CaseWhen(branches, elseValue), right: Literal) - if c.deterministic && isAlwaysFalse(branches.map(_._2) ++ elseValue, right) => - FalseLiteral + if c.deterministic && (branches.map(_._2) ++ elseValue).forall(_.isInstanceOf[Literal]) => + c.copy(branches.map(b => b.copy(_2 = EqualTo(b._2, right))), + elseValue.map(EqualTo(_, right))) } } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyConditionalSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyConditionalSuite.scala index 278abc44e0a9e..735ce5da98fe3 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyConditionalSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyConditionalSuite.scala @@ -200,15 +200,15 @@ class SimplifyConditionalSuite extends PlanTest with ExpressionEvalHelper with P } } - test("SPARK-33798: simplify EqualTo(If, Literal) always false") { + test("SPARK-33798: Push down EqualTo through If") { val a = EqualTo(UnresolvedAttribute("a"), Literal(100)) val b = UnresolvedAttribute("b") val ifExp = If(a, Literal(2), Literal(3)) assertEquivalent(EqualTo(ifExp, Literal(4)), FalseLiteral) - assertEquivalent(EqualTo(ifExp, Literal(3)), EqualTo(ifExp, Literal(3))) + assertEquivalent(EqualTo(ifExp, Literal(3)), If(a, FalseLiteral, TrueLiteral)) assertEquivalent(EqualTo(ifExp, Literal("4")), FalseLiteral) - assertEquivalent(EqualTo(ifExp, Literal("3")), EqualTo(ifExp, Literal(3))) + assertEquivalent(EqualTo(ifExp, Literal("3")), If(a, FalseLiteral, TrueLiteral)) // Do not simplify if it contains non foldable expressions. assertEquivalent( @@ -220,43 +220,41 @@ class SimplifyConditionalSuite extends PlanTest with ExpressionEvalHelper with P assert(!nonDeterministic.deterministic) assertEquivalent(EqualTo(nonDeterministic, Literal(-1)), EqualTo(nonDeterministic, Literal(-1))) - // Should not handle Null values. + // Handle Null values. assertEquivalent( EqualTo(If(a, Literal(null, IntegerType), Literal(1)), Literal(1)), - EqualTo(If(a, Literal(null, IntegerType), Literal(1)), Literal(1))) + If(a, Literal(null, BooleanType), TrueLiteral)) assertEquivalent( EqualTo(If(a, Literal(null, IntegerType), Literal(1)), Literal(2)), - EqualTo(If(a, Literal(null, IntegerType), Literal(1)), Literal(2))) + If(a, Literal(null, BooleanType), FalseLiteral)) assertEquivalent( EqualTo(If(a, Literal(1), Literal(2)), Literal(null, IntegerType)), - EqualTo(If(a, Literal(1), Literal(2)), Literal(null, IntegerType))) + Literal(null, BooleanType)) assertEquivalent( EqualTo(If(a, Literal(null, IntegerType), Literal(null, IntegerType)), Literal(1)), Literal(null, BooleanType)) } - test("SPARK-33798: simplify EqualTo(CaseWhen, Literal) always false") { + test("SPARK-33798: Push down EqualTo through CaseWhen") { val a = EqualTo(UnresolvedAttribute("a"), Literal(100)) val b = UnresolvedAttribute("b") val c = EqualTo(UnresolvedAttribute("c"), Literal(true)) val caseWhen = CaseWhen(Seq((a, Literal(1)), (c, Literal(2))), Some(Literal(3))) assertEquivalent(EqualTo(caseWhen, Literal(4)), FalseLiteral) - assertEquivalent(EqualTo(caseWhen, Literal(3)), EqualTo(caseWhen, Literal(3))) + assertEquivalent(EqualTo(caseWhen, Literal(3)), + CaseWhen(Seq((a, FalseLiteral), (c, FalseLiteral)), Some(TrueLiteral))) assertEquivalent(EqualTo(caseWhen, Literal("4")), FalseLiteral) - assertEquivalent(EqualTo(caseWhen, Literal("3")), EqualTo(caseWhen, Literal(3))) + assertEquivalent(EqualTo(caseWhen, Literal("3")), + CaseWhen(Seq((a, FalseLiteral), (c, FalseLiteral)), Some(TrueLiteral))) assertEquivalent( EqualTo(CaseWhen(Seq((a, Literal("1")), (c, Literal("2"))), None), Literal("4")), - FalseLiteral) + CaseWhen(Seq((a, FalseLiteral), (c, FalseLiteral)), None)) assertEquivalent( And(EqualTo(caseWhen, Literal(5)), EqualTo(caseWhen, Literal(6))), FalseLiteral) - assertEquivalent( - EqualTo(CaseWhen(Seq(normalBranch, (a, Literal(1)), (c, Literal(1))), None), Literal(-1)), - FalseLiteral) - // Do not simplify if it contains non foldable expressions. assertEquivalent(EqualTo(caseWhen, NonFoldableLiteral(true)), EqualTo(caseWhen, NonFoldableLiteral(true))) @@ -268,16 +266,16 @@ class SimplifyConditionalSuite extends PlanTest with ExpressionEvalHelper with P assert(!nonDeterministic.deterministic) assertEquivalent(EqualTo(nonDeterministic, Literal(-1)), EqualTo(nonDeterministic, Literal(-1))) - // Should not handle Null values. + // Handle Null values. assertEquivalent( EqualTo(CaseWhen(Seq((a, Literal(null, IntegerType))), Some(Literal(1))), Literal(2)), - EqualTo(CaseWhen(Seq((a, Literal(null, IntegerType))), Some(Literal(1))), Literal(2))) + CaseWhen(Seq((a, Literal(null, BooleanType))), Some(FalseLiteral))) assertEquivalent( EqualTo(CaseWhen(Seq((a, Literal(1))), Some(Literal(2))), Literal(null, IntegerType)), - EqualTo(CaseWhen(Seq((a, Literal(1))), Some(Literal(2))), Literal(null, IntegerType))) + Literal(null, BooleanType)) assertEquivalent( EqualTo(CaseWhen(Seq((a, Literal(null, IntegerType))), Some(Literal(1))), Literal(1)), - EqualTo(CaseWhen(Seq((a, Literal(null, IntegerType))), Some(Literal(1))), Literal(1))) + CaseWhen(Seq((a, Literal(null, BooleanType))), Some(TrueLiteral))) assertEquivalent( EqualTo(CaseWhen(Seq((a, Literal(null, IntegerType))), Some(Literal(null, IntegerType))), Literal(1)), From 21ef0c3b9dc8e984c0d65d92fdbde6e128864bc0 Mon Sep 17 00:00:00 2001 From: Yuming Wang Date: Thu, 17 Dec 2020 14:33:41 +0800 Subject: [PATCH 08/12] Push down EqualTo through CaseWhen/If --- .../sql/catalyst/optimizer/Optimizer.scala | 1 + .../sql/catalyst/optimizer/expressions.scala | 29 +++- .../PushFoldableIntoBranchesSuite.scala | 146 ++++++++++++++++++ .../optimizer/SimplifyConditionalSuite.scala | 86 ----------- 4 files changed, 168 insertions(+), 94 deletions(-) create mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PushFoldableIntoBranchesSuite.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index aa8540fb44556..fdb9c5b4821dd 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -99,6 +99,7 @@ abstract class Optimizer(catalogManager: CatalogManager) LikeSimplification, BooleanSimplification, SimplifyConditionals, + PushFoldableIntoBranches, RemoveDispensableExpressions, SimplifyBinaryComparison, ReplaceNullWithFalseInPredicate, diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala index f7c1d291f78b8..55003872af8e9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala @@ -523,20 +523,33 @@ object SimplifyConditionals extends Rule[LogicalPlan] with PredicateHelper { } else { e.copy(branches = branches.take(i).map(branch => (branch._1, elseValue))) } + } + } +} - case EqualTo(i @ If(_, trueValue: Literal, falseValue: Literal), right: Literal) - if i.deterministic => - i.copy(trueValue = EqualTo(trueValue, right), falseValue = EqualTo(falseValue, right)) - - case EqualTo(c @ CaseWhen(branches, elseValue), right: Literal) - if c.deterministic && (branches.map(_._2) ++ elseValue).forall(_.isInstanceOf[Literal]) => - c.copy(branches.map(b => b.copy(_2 = EqualTo(b._2, right))), - elseValue.map(EqualTo(_, right))) +/** + * Push the foldable expression into (if / case) branches. + */ +object PushFoldableIntoBranches extends Rule[LogicalPlan] with PredicateHelper { + def apply(plan: LogicalPlan): LogicalPlan = plan transform { + case q: LogicalPlan => q transformExpressionsUp { + case b @ BinaryComparison(i @ If(_, trueValue, falseValue), right) + if i.deterministic && trueValue.foldable && falseValue.foldable && right.foldable => + i.copy( + trueValue = b.makeCopy(Array(trueValue, right)), + falseValue = b.makeCopy(Array(falseValue, right))) + + case b @ BinaryComparison(c @ CaseWhen(branches, elseValue), right) if c.deterministic && + right.foldable && (branches.map(_._2) ++ elseValue).forall(_.foldable) => + c.copy( + branches.map(e => e.copy(_2 = b.makeCopy(Array(e._2, right)))), + elseValue.map(e => b.makeCopy(Array(e, right)))) } } } + /** * Simplifies LIKE expressions that do not need full regular expressions to evaluate the condition. * For example, when the expression is just checking to see if a string starts with a given diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PushFoldableIntoBranchesSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PushFoldableIntoBranchesSuite.scala new file mode 100644 index 0000000000000..8ebab8412f716 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PushFoldableIntoBranchesSuite.scala @@ -0,0 +1,146 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.optimizer + +import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.dsl.plans._ +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.Literal.{FalseLiteral, TrueLiteral} +import org.apache.spark.sql.catalyst.plans.PlanTest +import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.catalyst.rules._ +import org.apache.spark.sql.types.{BooleanType, IntegerType} + + +class PushFoldableIntoBranchesSuite + extends PlanTest with ExpressionEvalHelper with PredicateHelper { + + object Optimize extends RuleExecutor[LogicalPlan] { + val batches = Batch("PushFoldableIntoBranches", FixedPoint(50), + BooleanSimplification, ConstantFolding, SimplifyConditionals, PushFoldableIntoBranches) :: Nil + } + + private val relation = LocalRelation('a.int, 'b.int, 'c.boolean) + private val a = EqualTo(UnresolvedAttribute("a"), Literal(100)) + private val b = UnresolvedAttribute("b") + private val c = EqualTo(UnresolvedAttribute("c"), Literal(true)) + private val ifExp = If(a, Literal(2), Literal(3)) + private val caseWhen = CaseWhen(Seq((a, Literal(1)), (c, Literal(2))), Some(Literal(3))) + + protected def assertEquivalent(e1: Expression, e2: Expression): Unit = { + val correctAnswer = Project(Alias(e2, "out")() :: Nil, relation).analyze + val actual = Optimize.execute(Project(Alias(e1, "out")() :: Nil, relation).analyze) + comparePlans(actual, correctAnswer) + } + + private val normalBranch = (NonFoldableLiteral(true), Literal(10)) + + test("SPARK-33798: Push down EqualTo through If") { + assertEquivalent(EqualTo(ifExp, Literal(4)), FalseLiteral) + assertEquivalent(EqualTo(ifExp, Literal(3)), If(a, FalseLiteral, TrueLiteral)) + assertEquivalent(EqualTo(ifExp, Literal("4")), FalseLiteral) + assertEquivalent(EqualTo(ifExp, Literal("3")), If(a, FalseLiteral, TrueLiteral)) + + // Do not simplify if it contains non foldable expressions. + assertEquivalent( + EqualTo(If(a, b, Literal(2)), Literal(2)), + EqualTo(If(a, b, Literal(2)), Literal(2))) + + // Do not simplify if it contains non-deterministic expressions. + val nonDeterministic = If(LessThan(Rand(1), Literal(0.5)), Literal(1), Literal(1)) + assert(!nonDeterministic.deterministic) + assertEquivalent(EqualTo(nonDeterministic, Literal(-1)), EqualTo(nonDeterministic, Literal(-1))) + + // Handle Null values. + assertEquivalent( + EqualTo(If(a, Literal(null, IntegerType), Literal(1)), Literal(1)), + If(a, Literal(null, BooleanType), TrueLiteral)) + assertEquivalent( + EqualTo(If(a, Literal(null, IntegerType), Literal(1)), Literal(2)), + If(a, Literal(null, BooleanType), FalseLiteral)) + assertEquivalent( + EqualTo(If(a, Literal(1), Literal(2)), Literal(null, IntegerType)), + Literal(null, BooleanType)) + assertEquivalent( + EqualTo(If(a, Literal(null, IntegerType), Literal(null, IntegerType)), Literal(1)), + Literal(null, BooleanType)) + } + + test("SPARK-33798: Push down other BinaryComparison through If") { + assertEquivalent(EqualNullSafe(ifExp, Literal(4)), FalseLiteral) + assertEquivalent(GreaterThan(ifExp, Literal(4)), FalseLiteral) + assertEquivalent(GreaterThanOrEqual(ifExp, Literal(4)), FalseLiteral) + assertEquivalent(LessThan(ifExp, Literal(4)), TrueLiteral) + assertEquivalent(LessThanOrEqual(ifExp, Literal(4)), TrueLiteral) + } + + test("SPARK-33798: Push down EqualTo through CaseWhen") { + assertEquivalent(EqualTo(caseWhen, Literal(4)), FalseLiteral) + assertEquivalent(EqualTo(caseWhen, Literal(3)), + CaseWhen(Seq((a, FalseLiteral), (c, FalseLiteral)), Some(TrueLiteral))) + assertEquivalent(EqualTo(caseWhen, Literal("4")), FalseLiteral) + assertEquivalent(EqualTo(caseWhen, Literal("3")), + CaseWhen(Seq((a, FalseLiteral), (c, FalseLiteral)), Some(TrueLiteral))) + assertEquivalent( + EqualTo(CaseWhen(Seq((a, Literal("1")), (c, Literal("2"))), None), Literal("4")), + CaseWhen(Seq((a, FalseLiteral), (c, FalseLiteral)), None)) + + assertEquivalent( + And(EqualTo(caseWhen, Literal(5)), EqualTo(caseWhen, Literal(6))), + FalseLiteral) + + // Do not simplify if it contains non foldable expressions. + assertEquivalent(EqualTo(caseWhen, NonFoldableLiteral(true)), + EqualTo(caseWhen, NonFoldableLiteral(true))) + val nonFoldable = CaseWhen(Seq(normalBranch, (a, b)), None) + assertEquivalent(EqualTo(nonFoldable, Literal(1)), EqualTo(nonFoldable, Literal(1))) + + // Do not simplify if it contains non-deterministic expressions. + val nonDeterministic = CaseWhen(Seq((LessThan(Rand(1), Literal(0.5)), Literal(1))), Some(b)) + assert(!nonDeterministic.deterministic) + assertEquivalent(EqualTo(nonDeterministic, Literal(-1)), EqualTo(nonDeterministic, Literal(-1))) + + // Handle Null values. + assertEquivalent( + EqualTo(CaseWhen(Seq((a, Literal(null, IntegerType))), Some(Literal(1))), Literal(2)), + CaseWhen(Seq((a, Literal(null, BooleanType))), Some(FalseLiteral))) + assertEquivalent( + EqualTo(CaseWhen(Seq((a, Literal(1))), Some(Literal(2))), Literal(null, IntegerType)), + Literal(null, BooleanType)) + assertEquivalent( + EqualTo(CaseWhen(Seq((a, Literal(null, IntegerType))), Some(Literal(1))), Literal(1)), + CaseWhen(Seq((a, Literal(null, BooleanType))), Some(TrueLiteral))) + assertEquivalent( + EqualTo(CaseWhen(Seq((a, Literal(null, IntegerType))), Some(Literal(null, IntegerType))), + Literal(1)), + Literal(null, BooleanType)) + assertEquivalent( + EqualTo(CaseWhen(Seq((a, Literal(null, IntegerType))), Some(Literal(null, IntegerType))), + Literal(null, IntegerType)), + Literal(null, BooleanType)) + } + + test("SPARK-33798: Push down other BinaryComparison through CaseWhen") { + assertEquivalent(EqualNullSafe(caseWhen, Literal(4)), FalseLiteral) + assertEquivalent(GreaterThan(caseWhen, Literal(4)), FalseLiteral) + assertEquivalent(GreaterThanOrEqual(caseWhen, Literal(4)), FalseLiteral) + assertEquivalent(LessThan(caseWhen, Literal(4)), TrueLiteral) + assertEquivalent(LessThanOrEqual(caseWhen, Literal(4)), TrueLiteral) + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyConditionalSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyConditionalSuite.scala index 735ce5da98fe3..bac962ced4618 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyConditionalSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyConditionalSuite.scala @@ -199,90 +199,4 @@ class SimplifyConditionalSuite extends PlanTest with ExpressionEvalHelper with P If(Factorial(5) > 100L, b, nullLiteral).eval(EmptyRow)) } } - - test("SPARK-33798: Push down EqualTo through If") { - val a = EqualTo(UnresolvedAttribute("a"), Literal(100)) - val b = UnresolvedAttribute("b") - val ifExp = If(a, Literal(2), Literal(3)) - - assertEquivalent(EqualTo(ifExp, Literal(4)), FalseLiteral) - assertEquivalent(EqualTo(ifExp, Literal(3)), If(a, FalseLiteral, TrueLiteral)) - assertEquivalent(EqualTo(ifExp, Literal("4")), FalseLiteral) - assertEquivalent(EqualTo(ifExp, Literal("3")), If(a, FalseLiteral, TrueLiteral)) - - // Do not simplify if it contains non foldable expressions. - assertEquivalent( - EqualTo(If(a, b, Literal(2)), Literal(2)), - EqualTo(If(a, b, Literal(2)), Literal(2))) - - // Do not simplify if it contains non-deterministic expressions. - val nonDeterministic = If(LessThan(Rand(1), Literal(0.5)), Literal(1), Literal(1)) - assert(!nonDeterministic.deterministic) - assertEquivalent(EqualTo(nonDeterministic, Literal(-1)), EqualTo(nonDeterministic, Literal(-1))) - - // Handle Null values. - assertEquivalent( - EqualTo(If(a, Literal(null, IntegerType), Literal(1)), Literal(1)), - If(a, Literal(null, BooleanType), TrueLiteral)) - assertEquivalent( - EqualTo(If(a, Literal(null, IntegerType), Literal(1)), Literal(2)), - If(a, Literal(null, BooleanType), FalseLiteral)) - assertEquivalent( - EqualTo(If(a, Literal(1), Literal(2)), Literal(null, IntegerType)), - Literal(null, BooleanType)) - assertEquivalent( - EqualTo(If(a, Literal(null, IntegerType), Literal(null, IntegerType)), Literal(1)), - Literal(null, BooleanType)) - } - - test("SPARK-33798: Push down EqualTo through CaseWhen") { - val a = EqualTo(UnresolvedAttribute("a"), Literal(100)) - val b = UnresolvedAttribute("b") - val c = EqualTo(UnresolvedAttribute("c"), Literal(true)) - val caseWhen = CaseWhen(Seq((a, Literal(1)), (c, Literal(2))), Some(Literal(3))) - - assertEquivalent(EqualTo(caseWhen, Literal(4)), FalseLiteral) - assertEquivalent(EqualTo(caseWhen, Literal(3)), - CaseWhen(Seq((a, FalseLiteral), (c, FalseLiteral)), Some(TrueLiteral))) - assertEquivalent(EqualTo(caseWhen, Literal("4")), FalseLiteral) - assertEquivalent(EqualTo(caseWhen, Literal("3")), - CaseWhen(Seq((a, FalseLiteral), (c, FalseLiteral)), Some(TrueLiteral))) - assertEquivalent( - EqualTo(CaseWhen(Seq((a, Literal("1")), (c, Literal("2"))), None), Literal("4")), - CaseWhen(Seq((a, FalseLiteral), (c, FalseLiteral)), None)) - - assertEquivalent( - And(EqualTo(caseWhen, Literal(5)), EqualTo(caseWhen, Literal(6))), - FalseLiteral) - - // Do not simplify if it contains non foldable expressions. - assertEquivalent(EqualTo(caseWhen, NonFoldableLiteral(true)), - EqualTo(caseWhen, NonFoldableLiteral(true))) - val nonFoldable = CaseWhen(Seq(normalBranch, (a, b)), None) - assertEquivalent(EqualTo(nonFoldable, Literal(1)), EqualTo(nonFoldable, Literal(1))) - - // Do not simplify if it contains non-deterministic expressions. - val nonDeterministic = CaseWhen(Seq((LessThan(Rand(1), Literal(0.5)), Literal(1))), Some(b)) - assert(!nonDeterministic.deterministic) - assertEquivalent(EqualTo(nonDeterministic, Literal(-1)), EqualTo(nonDeterministic, Literal(-1))) - - // Handle Null values. - assertEquivalent( - EqualTo(CaseWhen(Seq((a, Literal(null, IntegerType))), Some(Literal(1))), Literal(2)), - CaseWhen(Seq((a, Literal(null, BooleanType))), Some(FalseLiteral))) - assertEquivalent( - EqualTo(CaseWhen(Seq((a, Literal(1))), Some(Literal(2))), Literal(null, IntegerType)), - Literal(null, BooleanType)) - assertEquivalent( - EqualTo(CaseWhen(Seq((a, Literal(null, IntegerType))), Some(Literal(1))), Literal(1)), - CaseWhen(Seq((a, Literal(null, BooleanType))), Some(TrueLiteral))) - assertEquivalent( - EqualTo(CaseWhen(Seq((a, Literal(null, IntegerType))), Some(Literal(null, IntegerType))), - Literal(1)), - Literal(null, BooleanType)) - assertEquivalent( - EqualTo(CaseWhen(Seq((a, Literal(null, IntegerType))), Some(Literal(null, IntegerType))), - Literal(null, IntegerType)), - Literal(null, BooleanType)) - } } From d5135ec1f7b435d40c42089fa409abb7b73f9d68 Mon Sep 17 00:00:00 2001 From: Yuming Wang Date: Thu, 17 Dec 2020 14:36:25 +0800 Subject: [PATCH 09/12] Fix --- .../org/apache/spark/sql/catalyst/optimizer/expressions.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala index 55003872af8e9..37146b52fbbc8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala @@ -527,6 +527,7 @@ object SimplifyConditionals extends Rule[LogicalPlan] with PredicateHelper { } } + /** * Push the foldable expression into (if / case) branches. */ @@ -549,7 +550,6 @@ object PushFoldableIntoBranches extends Rule[LogicalPlan] with PredicateHelper { } - /** * Simplifies LIKE expressions that do not need full regular expressions to evaluate the condition. * For example, when the expression is just checking to see if a string starts with a given From 8ccc3c1f8804e34b46fd9bec19a017db61b37f54 Mon Sep 17 00:00:00 2001 From: Yuming Wang Date: Thu, 17 Dec 2020 14:42:41 +0800 Subject: [PATCH 10/12] fix --- .../org/apache/spark/sql/catalyst/optimizer/expressions.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala index 37146b52fbbc8..c46aa9b7f13d5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala @@ -534,8 +534,8 @@ object SimplifyConditionals extends Rule[LogicalPlan] with PredicateHelper { object PushFoldableIntoBranches extends Rule[LogicalPlan] with PredicateHelper { def apply(plan: LogicalPlan): LogicalPlan = plan transform { case q: LogicalPlan => q transformExpressionsUp { - case b @ BinaryComparison(i @ If(_, trueValue, falseValue), right) - if i.deterministic && trueValue.foldable && falseValue.foldable && right.foldable => + case b @ BinaryComparison(i @ If(_, trueValue, falseValue), right) if i.deterministic && + right.foldable && trueValue.foldable && falseValue.foldable => i.copy( trueValue = b.makeCopy(Array(trueValue, right)), falseValue = b.makeCopy(Array(falseValue, right))) From 45b56fcc7b08a89893082b7b9b6c6e2f8e4747f7 Mon Sep 17 00:00:00 2001 From: Yuming Wang Date: Fri, 18 Dec 2020 10:55:12 +0800 Subject: [PATCH 11/12] fix --- .../sql/catalyst/expressions/Expression.scala | 5 + .../sql/catalyst/optimizer/expressions.scala | 22 ++++- .../PushFoldableIntoBranchesSuite.scala | 93 ++++++++++++++++--- 3 files changed, 104 insertions(+), 16 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala index 1d23953484046..65f89bbdd0599 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala @@ -636,6 +636,11 @@ abstract class BinaryExpression extends Expression { } +object BinaryExpression { + def unapply(e: BinaryExpression): Option[(Expression, Expression)] = Some((e.left, e.right)) +} + + /** * A [[BinaryExpression]] that is an operator, with two properties: * diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala index c46aa9b7f13d5..c33e39b8a8b7b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala @@ -21,7 +21,7 @@ import scala.collection.immutable.HashSet import scala.collection.mutable.{ArrayBuffer, Stack} import org.apache.spark.sql.catalyst.analysis._ -import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.{BinaryExpression, _} import org.apache.spark.sql.catalyst.expressions.Literal.{FalseLiteral, TrueLiteral} import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.expressions.objects.AssertNotNull @@ -534,17 +534,29 @@ object SimplifyConditionals extends Rule[LogicalPlan] with PredicateHelper { object PushFoldableIntoBranches extends Rule[LogicalPlan] with PredicateHelper { def apply(plan: LogicalPlan): LogicalPlan = plan transform { case q: LogicalPlan => q transformExpressionsUp { - case b @ BinaryComparison(i @ If(_, trueValue, falseValue), right) if i.deterministic && - right.foldable && trueValue.foldable && falseValue.foldable => + case b @ BinaryExpression(i @ If(_, trueValue, falseValue), right) if i.deterministic && + right.foldable && (trueValue.foldable || falseValue.foldable) => i.copy( trueValue = b.makeCopy(Array(trueValue, right)), falseValue = b.makeCopy(Array(falseValue, right))) - case b @ BinaryComparison(c @ CaseWhen(branches, elseValue), right) if c.deterministic && - right.foldable && (branches.map(_._2) ++ elseValue).forall(_.foldable) => + case b @ BinaryExpression(left, i @ If(_, trueValue, falseValue)) if i.deterministic && + left.foldable && (trueValue.foldable || falseValue.foldable) => + i.copy( + trueValue = b.makeCopy(Array(left, trueValue)), + falseValue = b.makeCopy(Array(left, falseValue))) + + case b @ BinaryExpression(c @ CaseWhen(branches, elseValue), right) if c.deterministic && + right.foldable && (branches.map(_._2) ++ elseValue).exists(_.foldable) => c.copy( branches.map(e => e.copy(_2 = b.makeCopy(Array(e._2, right)))), elseValue.map(e => b.makeCopy(Array(e, right)))) + + case b @ BinaryExpression(left, c @ CaseWhen(branches, elseValue)) if c.deterministic && + left.foldable && (branches.map(_._2) ++ elseValue).exists(_.foldable) => + c.copy( + branches.map(e => e.copy(_2 = b.makeCopy(Array(left, e._2)))), + elseValue.map(e => b.makeCopy(Array(left, e)))) } } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PushFoldableIntoBranchesSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PushFoldableIntoBranchesSuite.scala index 8ebab8412f716..de2f84e0659a8 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PushFoldableIntoBranchesSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PushFoldableIntoBranchesSuite.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.catalyst.optimizer +import java.sql.Date + import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ @@ -49,20 +51,18 @@ class PushFoldableIntoBranchesSuite comparePlans(actual, correctAnswer) } - private val normalBranch = (NonFoldableLiteral(true), Literal(10)) - test("SPARK-33798: Push down EqualTo through If") { assertEquivalent(EqualTo(ifExp, Literal(4)), FalseLiteral) assertEquivalent(EqualTo(ifExp, Literal(3)), If(a, FalseLiteral, TrueLiteral)) assertEquivalent(EqualTo(ifExp, Literal("4")), FalseLiteral) assertEquivalent(EqualTo(ifExp, Literal("3")), If(a, FalseLiteral, TrueLiteral)) - // Do not simplify if it contains non foldable expressions. + // Partially push down if it contains non foldable expressions. assertEquivalent( EqualTo(If(a, b, Literal(2)), Literal(2)), - EqualTo(If(a, b, Literal(2)), Literal(2))) + If(a, EqualTo(b, Literal(2)), TrueLiteral)) - // Do not simplify if it contains non-deterministic expressions. + // Do not push down if it contains non-deterministic expressions. val nonDeterministic = If(LessThan(Rand(1), Literal(0.5)), Literal(1), Literal(1)) assert(!nonDeterministic.deterministic) assertEquivalent(EqualTo(nonDeterministic, Literal(-1)), EqualTo(nonDeterministic, Literal(-1))) @@ -90,6 +90,29 @@ class PushFoldableIntoBranchesSuite assertEquivalent(LessThanOrEqual(ifExp, Literal(4)), TrueLiteral) } + test("SPARK-33798: Push down other BinaryOperator through If") { + assertEquivalent(Add(ifExp, Literal(4)), If(a, Literal(6), Literal(7))) + assertEquivalent(Subtract(ifExp, Literal(4)), If(a, Literal(-2), Literal(-1))) + assertEquivalent(Multiply(ifExp, Literal(4)), If(a, Literal(8), Literal(12))) + assertEquivalent(Pmod(ifExp, Literal(4)), If(a, Literal(2), Literal(3))) + assertEquivalent(Remainder(ifExp, Literal(4)), If(a, Literal(2), Literal(3))) + assertEquivalent(Divide(If(a, Literal(2.0), Literal(3.0)), Literal(1.0)), + If(a, Literal(2.0), Literal(3.0))) + assertEquivalent(And(If(a, FalseLiteral, TrueLiteral), TrueLiteral), + If(a, FalseLiteral, TrueLiteral)) + assertEquivalent(Or(If(a, FalseLiteral, TrueLiteral), TrueLiteral), TrueLiteral) + } + + test("SPARK-33798: Push down other BinaryExpression through If") { + assertEquivalent(BRound(If(a, Literal(1.23), Literal(1.24)), Literal(1)), Literal(1.2)) + assertEquivalent(StartsWith(If(a, Literal("ab"), Literal("ac")), Literal("a")), TrueLiteral) + assertEquivalent(FindInSet(If(a, Literal("ab"), Literal("ac")), Literal("a")), Literal(0)) + assertEquivalent( + AddMonths(If(a, Literal(Date.valueOf("2020-01-01")), Literal(Date.valueOf("2021-01-01"))), + Literal(1)), + If(a, Literal(Date.valueOf("2020-02-01")), Literal(Date.valueOf("2021-02-01")))) + } + test("SPARK-33798: Push down EqualTo through CaseWhen") { assertEquivalent(EqualTo(caseWhen, Literal(4)), FalseLiteral) assertEquivalent(EqualTo(caseWhen, Literal(3)), @@ -105,13 +128,12 @@ class PushFoldableIntoBranchesSuite And(EqualTo(caseWhen, Literal(5)), EqualTo(caseWhen, Literal(6))), FalseLiteral) - // Do not simplify if it contains non foldable expressions. - assertEquivalent(EqualTo(caseWhen, NonFoldableLiteral(true)), - EqualTo(caseWhen, NonFoldableLiteral(true))) - val nonFoldable = CaseWhen(Seq(normalBranch, (a, b)), None) - assertEquivalent(EqualTo(nonFoldable, Literal(1)), EqualTo(nonFoldable, Literal(1))) + // Partially push down if it contains non foldable expressions. + val nonFoldable = CaseWhen(Seq((NonFoldableLiteral(true), Literal(10)), (a, b)), None) + assertEquivalent(EqualTo(nonFoldable, Literal(1)), + CaseWhen(Seq((NonFoldableLiteral(true), FalseLiteral), (a, EqualTo(b, Literal(1)))), None)) - // Do not simplify if it contains non-deterministic expressions. + // Do not push down if it contains non-deterministic expressions. val nonDeterministic = CaseWhen(Seq((LessThan(Rand(1), Literal(0.5)), Literal(1))), Some(b)) assert(!nonDeterministic.deterministic) assertEquivalent(EqualTo(nonDeterministic, Literal(-1)), EqualTo(nonDeterministic, Literal(-1))) @@ -143,4 +165,53 @@ class PushFoldableIntoBranchesSuite assertEquivalent(LessThan(caseWhen, Literal(4)), TrueLiteral) assertEquivalent(LessThanOrEqual(caseWhen, Literal(4)), TrueLiteral) } + + test("SPARK-33798: Push down other BinaryOperator through CaseWhen") { + assertEquivalent(Add(caseWhen, Literal(4)), + CaseWhen(Seq((a, Literal(5)), (c, Literal(6))), Some(Literal(7)))) + assertEquivalent(Subtract(caseWhen, Literal(4)), + CaseWhen(Seq((a, Literal(-3)), (c, Literal(-2))), Some(Literal(-1)))) + assertEquivalent(Multiply(caseWhen, Literal(4)), + CaseWhen(Seq((a, Literal(4)), (c, Literal(8))), Some(Literal(12)))) + assertEquivalent(Pmod(caseWhen, Literal(4)), + CaseWhen(Seq((a, Literal(1)), (c, Literal(2))), Some(Literal(3)))) + assertEquivalent(Remainder(caseWhen, Literal(4)), + CaseWhen(Seq((a, Literal(1)), (c, Literal(2))), Some(Literal(3)))) + assertEquivalent(Divide(CaseWhen(Seq((a, Literal(1.0)), (c, Literal(2.0))), Some(Literal(3.0))), + Literal(1.0)), + CaseWhen(Seq((a, Literal(1.0)), (c, Literal(2.0))), Some(Literal(3.0)))) + assertEquivalent(And(CaseWhen(Seq((a, FalseLiteral), (c, TrueLiteral)), Some(TrueLiteral)), + TrueLiteral), + CaseWhen(Seq((a, FalseLiteral), (c, TrueLiteral)), Some(TrueLiteral))) + assertEquivalent(Or(CaseWhen(Seq((a, FalseLiteral), (c, TrueLiteral)), Some(TrueLiteral)), + TrueLiteral), TrueLiteral) + } + + test("SPARK-33798: Push down other BinaryExpression through CaseWhen") { + assertEquivalent( + BRound(CaseWhen(Seq((a, Literal(1.23)), (c, Literal(1.24))), Some(Literal(1.25))), + Literal(1)), + Literal(1.2)) + assertEquivalent( + StartsWith(CaseWhen(Seq((a, Literal("ab")), (c, Literal("ac"))), Some(Literal("ad"))), + Literal("a")), + TrueLiteral) + assertEquivalent( + FindInSet(CaseWhen(Seq((a, Literal("ab")), (c, Literal("ac"))), Some(Literal("ad"))), + Literal("a")), + Literal(0)) + assertEquivalent( + AddMonths(CaseWhen(Seq((a, Literal(Date.valueOf("2020-01-01"))), + (c, Literal(Date.valueOf("2021-01-01")))), + Some(Literal(Date.valueOf("2022-01-01")))), + Literal(1)), + CaseWhen(Seq((a, Literal(Date.valueOf("2020-02-01"))), + (c, Literal(Date.valueOf("2021-02-01")))), + Some(Literal(Date.valueOf("2022-02-01"))))) + } + + test("SPARK-33798: Push down BinaryExpression through If/CaseWhen backwards") { + assertEquivalent(EqualTo(Literal(4), ifExp), FalseLiteral) + assertEquivalent(EqualTo(Literal(4), caseWhen), FalseLiteral) + } } From cfac0e8655755f770f725f7ec5cd2dc2db950ff6 Mon Sep 17 00:00:00 2001 From: Yuming Wang Date: Fri, 18 Dec 2020 16:13:58 +0800 Subject: [PATCH 12/12] fix --- .../sql/catalyst/optimizer/expressions.scala | 24 ++++--- .../PushFoldableIntoBranchesSuite.scala | 62 +++++++++++-------- 2 files changed, 51 insertions(+), 35 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala index c33e39b8a8b7b..e6730c9275a1e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala @@ -532,28 +532,36 @@ object SimplifyConditionals extends Rule[LogicalPlan] with PredicateHelper { * Push the foldable expression into (if / case) branches. */ object PushFoldableIntoBranches extends Rule[LogicalPlan] with PredicateHelper { + + // To be conservative here: it's only a guaranteed win if all but at most only one branch + // end up being not foldable. + private def atMostOneUnfoldable(exprs: Seq[Expression]): Boolean = { + val (foldables, others) = exprs.partition(_.foldable) + foldables.nonEmpty && others.length < 2 + } + def apply(plan: LogicalPlan): LogicalPlan = plan transform { case q: LogicalPlan => q transformExpressionsUp { - case b @ BinaryExpression(i @ If(_, trueValue, falseValue), right) if i.deterministic && - right.foldable && (trueValue.foldable || falseValue.foldable) => + case b @ BinaryExpression(i @ If(_, trueValue, falseValue), right) + if right.foldable && atMostOneUnfoldable(Seq(trueValue, falseValue)) => i.copy( trueValue = b.makeCopy(Array(trueValue, right)), falseValue = b.makeCopy(Array(falseValue, right))) - case b @ BinaryExpression(left, i @ If(_, trueValue, falseValue)) if i.deterministic && - left.foldable && (trueValue.foldable || falseValue.foldable) => + case b @ BinaryExpression(left, i @ If(_, trueValue, falseValue)) + if left.foldable && atMostOneUnfoldable(Seq(trueValue, falseValue)) => i.copy( trueValue = b.makeCopy(Array(left, trueValue)), falseValue = b.makeCopy(Array(left, falseValue))) - case b @ BinaryExpression(c @ CaseWhen(branches, elseValue), right) if c.deterministic && - right.foldable && (branches.map(_._2) ++ elseValue).exists(_.foldable) => + case b @ BinaryExpression(c @ CaseWhen(branches, elseValue), right) + if right.foldable && atMostOneUnfoldable(branches.map(_._2) ++ elseValue) => c.copy( branches.map(e => e.copy(_2 = b.makeCopy(Array(e._2, right)))), elseValue.map(e => b.makeCopy(Array(e, right)))) - case b @ BinaryExpression(left, c @ CaseWhen(branches, elseValue)) if c.deterministic && - left.foldable && (branches.map(_._2) ++ elseValue).exists(_.foldable) => + case b @ BinaryExpression(left, c @ CaseWhen(branches, elseValue)) + if left.foldable && atMostOneUnfoldable(branches.map(_._2) ++ elseValue) => c.copy( branches.map(e => e.copy(_2 = b.makeCopy(Array(left, e._2)))), elseValue.map(e => b.makeCopy(Array(left, e)))) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PushFoldableIntoBranchesSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PushFoldableIntoBranchesSuite.scala index de2f84e0659a8..43360af46ffb3 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PushFoldableIntoBranchesSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PushFoldableIntoBranchesSuite.scala @@ -51,21 +51,25 @@ class PushFoldableIntoBranchesSuite comparePlans(actual, correctAnswer) } - test("SPARK-33798: Push down EqualTo through If") { + test("Push down EqualTo through If") { assertEquivalent(EqualTo(ifExp, Literal(4)), FalseLiteral) assertEquivalent(EqualTo(ifExp, Literal(3)), If(a, FalseLiteral, TrueLiteral)) - assertEquivalent(EqualTo(ifExp, Literal("4")), FalseLiteral) - assertEquivalent(EqualTo(ifExp, Literal("3")), If(a, FalseLiteral, TrueLiteral)) - // Partially push down if it contains non foldable expressions. + // Push down at most one not foldable expressions. assertEquivalent( EqualTo(If(a, b, Literal(2)), Literal(2)), If(a, EqualTo(b, Literal(2)), TrueLiteral)) + assertEquivalent( + EqualTo(If(a, b, b + 1), Literal(2)), + EqualTo(If(a, b, b + 1), Literal(2))) - // Do not push down if it contains non-deterministic expressions. - val nonDeterministic = If(LessThan(Rand(1), Literal(0.5)), Literal(1), Literal(1)) + // Push down non-deterministic expressions. + val nonDeterministic = If(LessThan(Rand(1), Literal(0.5)), Literal(1), Literal(2)) assert(!nonDeterministic.deterministic) - assertEquivalent(EqualTo(nonDeterministic, Literal(-1)), EqualTo(nonDeterministic, Literal(-1))) + assertEquivalent(EqualTo(nonDeterministic, Literal(2)), + If(LessThan(Rand(1), Literal(0.5)), FalseLiteral, TrueLiteral)) + assertEquivalent(EqualTo(nonDeterministic, Literal(3)), + If(LessThan(Rand(1), Literal(0.5)), FalseLiteral, FalseLiteral)) // Handle Null values. assertEquivalent( @@ -82,7 +86,7 @@ class PushFoldableIntoBranchesSuite Literal(null, BooleanType)) } - test("SPARK-33798: Push down other BinaryComparison through If") { + test("Push down other BinaryComparison through If") { assertEquivalent(EqualNullSafe(ifExp, Literal(4)), FalseLiteral) assertEquivalent(GreaterThan(ifExp, Literal(4)), FalseLiteral) assertEquivalent(GreaterThanOrEqual(ifExp, Literal(4)), FalseLiteral) @@ -90,7 +94,7 @@ class PushFoldableIntoBranchesSuite assertEquivalent(LessThanOrEqual(ifExp, Literal(4)), TrueLiteral) } - test("SPARK-33798: Push down other BinaryOperator through If") { + test("Push down other BinaryOperator through If") { assertEquivalent(Add(ifExp, Literal(4)), If(a, Literal(6), Literal(7))) assertEquivalent(Subtract(ifExp, Literal(4)), If(a, Literal(-2), Literal(-1))) assertEquivalent(Multiply(ifExp, Literal(4)), If(a, Literal(8), Literal(12))) @@ -103,7 +107,7 @@ class PushFoldableIntoBranchesSuite assertEquivalent(Or(If(a, FalseLiteral, TrueLiteral), TrueLiteral), TrueLiteral) } - test("SPARK-33798: Push down other BinaryExpression through If") { + test("Push down other BinaryExpression through If") { assertEquivalent(BRound(If(a, Literal(1.23), Literal(1.24)), Literal(1)), Literal(1.2)) assertEquivalent(StartsWith(If(a, Literal("ab"), Literal("ac")), Literal("a")), TrueLiteral) assertEquivalent(FindInSet(If(a, Literal("ab"), Literal("ac")), Literal("a")), Literal(0)) @@ -113,30 +117,34 @@ class PushFoldableIntoBranchesSuite If(a, Literal(Date.valueOf("2020-02-01")), Literal(Date.valueOf("2021-02-01")))) } - test("SPARK-33798: Push down EqualTo through CaseWhen") { + test("Push down EqualTo through CaseWhen") { assertEquivalent(EqualTo(caseWhen, Literal(4)), FalseLiteral) assertEquivalent(EqualTo(caseWhen, Literal(3)), CaseWhen(Seq((a, FalseLiteral), (c, FalseLiteral)), Some(TrueLiteral))) - assertEquivalent(EqualTo(caseWhen, Literal("4")), FalseLiteral) - assertEquivalent(EqualTo(caseWhen, Literal("3")), - CaseWhen(Seq((a, FalseLiteral), (c, FalseLiteral)), Some(TrueLiteral))) assertEquivalent( - EqualTo(CaseWhen(Seq((a, Literal("1")), (c, Literal("2"))), None), Literal("4")), + EqualTo(CaseWhen(Seq((a, Literal(1)), (c, Literal(2))), None), Literal(4)), CaseWhen(Seq((a, FalseLiteral), (c, FalseLiteral)), None)) assertEquivalent( And(EqualTo(caseWhen, Literal(5)), EqualTo(caseWhen, Literal(6))), FalseLiteral) - // Partially push down if it contains non foldable expressions. - val nonFoldable = CaseWhen(Seq((NonFoldableLiteral(true), Literal(10)), (a, b)), None) - assertEquivalent(EqualTo(nonFoldable, Literal(1)), - CaseWhen(Seq((NonFoldableLiteral(true), FalseLiteral), (a, EqualTo(b, Literal(1)))), None)) - - // Do not push down if it contains non-deterministic expressions. - val nonDeterministic = CaseWhen(Seq((LessThan(Rand(1), Literal(0.5)), Literal(1))), Some(b)) + // Push down at most one branch is not foldable expressions. + assertEquivalent(EqualTo(CaseWhen(Seq((a, b), (c, Literal(1))), None), Literal(1)), + CaseWhen(Seq((a, EqualTo(b, Literal(1))), (c, TrueLiteral)), None)) + assertEquivalent(EqualTo(CaseWhen(Seq((a, b), (c, b + 1)), None), Literal(1)), + EqualTo(CaseWhen(Seq((a, b), (c, b + 1)), None), Literal(1))) + assertEquivalent(EqualTo(CaseWhen(Seq((a, b)), None), Literal(1)), + EqualTo(CaseWhen(Seq((a, b)), None), Literal(1))) + + // Push down non-deterministic expressions. + val nonDeterministic = + CaseWhen(Seq((LessThan(Rand(1), Literal(0.5)), Literal(1))), Some(Literal(2))) assert(!nonDeterministic.deterministic) - assertEquivalent(EqualTo(nonDeterministic, Literal(-1)), EqualTo(nonDeterministic, Literal(-1))) + assertEquivalent(EqualTo(nonDeterministic, Literal(2)), + CaseWhen(Seq((LessThan(Rand(1), Literal(0.5)), FalseLiteral)), Some(TrueLiteral))) + assertEquivalent(EqualTo(nonDeterministic, Literal(3)), + CaseWhen(Seq((LessThan(Rand(1), Literal(0.5)), FalseLiteral)), Some(FalseLiteral))) // Handle Null values. assertEquivalent( @@ -158,7 +166,7 @@ class PushFoldableIntoBranchesSuite Literal(null, BooleanType)) } - test("SPARK-33798: Push down other BinaryComparison through CaseWhen") { + test("Push down other BinaryComparison through CaseWhen") { assertEquivalent(EqualNullSafe(caseWhen, Literal(4)), FalseLiteral) assertEquivalent(GreaterThan(caseWhen, Literal(4)), FalseLiteral) assertEquivalent(GreaterThanOrEqual(caseWhen, Literal(4)), FalseLiteral) @@ -166,7 +174,7 @@ class PushFoldableIntoBranchesSuite assertEquivalent(LessThanOrEqual(caseWhen, Literal(4)), TrueLiteral) } - test("SPARK-33798: Push down other BinaryOperator through CaseWhen") { + test("Push down other BinaryOperator through CaseWhen") { assertEquivalent(Add(caseWhen, Literal(4)), CaseWhen(Seq((a, Literal(5)), (c, Literal(6))), Some(Literal(7)))) assertEquivalent(Subtract(caseWhen, Literal(4)), @@ -187,7 +195,7 @@ class PushFoldableIntoBranchesSuite TrueLiteral), TrueLiteral) } - test("SPARK-33798: Push down other BinaryExpression through CaseWhen") { + test("Push down other BinaryExpression through CaseWhen") { assertEquivalent( BRound(CaseWhen(Seq((a, Literal(1.23)), (c, Literal(1.24))), Some(Literal(1.25))), Literal(1)), @@ -210,7 +218,7 @@ class PushFoldableIntoBranchesSuite Some(Literal(Date.valueOf("2022-02-01"))))) } - test("SPARK-33798: Push down BinaryExpression through If/CaseWhen backwards") { + test("Push down BinaryExpression through If/CaseWhen backwards") { assertEquivalent(EqualTo(Literal(4), ifExp), FalseLiteral) assertEquivalent(EqualTo(Literal(4), caseWhen), FalseLiteral) }