From 41a8fed86cddb0ea381aa387de7da7a24ff3fe4b Mon Sep 17 00:00:00 2001 From: Yuming Wang Date: Tue, 29 Dec 2020 16:27:00 +0800 Subject: [PATCH] fix --- .../optimizer/ReplaceNullWithFalseInPredicate.scala | 8 ++------ .../optimizer/ReplaceNullWithFalseInPredicateSuite.scala | 7 ++++--- 2 files changed, 6 insertions(+), 9 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceNullWithFalseInPredicate.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceNullWithFalseInPredicate.scala index df3da3e8a9982..2f95f242c851c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceNullWithFalseInPredicate.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceNullWithFalseInPredicate.scala @@ -98,12 +98,8 @@ object ReplaceNullWithFalseInPredicate extends Rule[LogicalPlan] { val newBranches = cw.branches.map { case (cond, value) => replaceNullWithFalse(cond) -> replaceNullWithFalse(value) } - if (newBranches.forall(_._2 == FalseLiteral) && cw.elseValue.isEmpty) { - FalseLiteral - } else { - val newElseValue = cw.elseValue.map(replaceNullWithFalse) - CaseWhen(newBranches, newElseValue) - } + val newElseValue = cw.elseValue.map(replaceNullWithFalse).getOrElse(FalseLiteral) + CaseWhen(newBranches, newElseValue) case i @ If(pred, trueVal, falseVal) if i.dataType == BooleanType => If(replaceNullWithFalse(pred), replaceNullWithFalse(trueVal), replaceNullWithFalse(falseVal)) case e if e.dataType == BooleanType => diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceNullWithFalseInPredicateSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceNullWithFalseInPredicateSuite.scala index ae97d53256837..ffab358721e1a 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceNullWithFalseInPredicateSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceNullWithFalseInPredicateSuite.scala @@ -114,7 +114,7 @@ class ReplaceNullWithFalseInPredicateSuite extends PlanTest { val expectedBranches = Seq( (UnresolvedAttribute("i") < Literal(10)) -> FalseLiteral, (UnresolvedAttribute("i") > Literal(40)) -> TrueLiteral) - val expectedCond = CaseWhen(expectedBranches) + val expectedCond = CaseWhen(expectedBranches, FalseLiteral) testFilter(originalCond, expectedCond) testJoin(originalCond, expectedCond) @@ -135,7 +135,7 @@ class ReplaceNullWithFalseInPredicateSuite extends PlanTest { (UnresolvedAttribute("i") < Literal(10)) -> TrueLiteral, (UnresolvedAttribute("i") > Literal(10)) -> FalseLiteral, TrueLiteral -> TrueLiteral) - val expectedCond = CaseWhen(expectedBranches) + val expectedCond = CaseWhen(expectedBranches, FalseLiteral) testFilter(originalCond, expectedCond) testJoin(originalCond, expectedCond) @@ -238,7 +238,8 @@ class ReplaceNullWithFalseInPredicateSuite extends PlanTest { FalseLiteral) val condition = CaseWhen(Seq((UnresolvedAttribute("i") > Literal(10)) -> branchValue)) val expectedCond = CaseWhen(Seq( - (UnresolvedAttribute("i") > Literal(10), (Literal(2) === nestedCaseWhen) <=> TrueLiteral))) + (UnresolvedAttribute("i") > Literal(10), (Literal(2) === nestedCaseWhen) <=> TrueLiteral)), + FalseLiteral) testFilter(originalCond = condition, expectedCond = expectedCond) testJoin(originalCond = condition, expectedCond = expectedCond) testDelete(originalCond = condition, expectedCond = expectedCond)