Skip to content

Commit

Permalink
fix
Browse files Browse the repository at this point in the history
  • Loading branch information
wangyum committed Dec 29, 2020
1 parent 7fa0df6 commit 41a8fed
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 9 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -98,12 +98,8 @@ object ReplaceNullWithFalseInPredicate extends Rule[LogicalPlan] {
val newBranches = cw.branches.map { case (cond, value) =>
replaceNullWithFalse(cond) -> replaceNullWithFalse(value)
}
if (newBranches.forall(_._2 == FalseLiteral) && cw.elseValue.isEmpty) {
FalseLiteral
} else {
val newElseValue = cw.elseValue.map(replaceNullWithFalse)
CaseWhen(newBranches, newElseValue)
}
val newElseValue = cw.elseValue.map(replaceNullWithFalse).getOrElse(FalseLiteral)
CaseWhen(newBranches, newElseValue)
case i @ If(pred, trueVal, falseVal) if i.dataType == BooleanType =>
If(replaceNullWithFalse(pred), replaceNullWithFalse(trueVal), replaceNullWithFalse(falseVal))
case e if e.dataType == BooleanType =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ class ReplaceNullWithFalseInPredicateSuite extends PlanTest {
val expectedBranches = Seq(
(UnresolvedAttribute("i") < Literal(10)) -> FalseLiteral,
(UnresolvedAttribute("i") > Literal(40)) -> TrueLiteral)
val expectedCond = CaseWhen(expectedBranches)
val expectedCond = CaseWhen(expectedBranches, FalseLiteral)

testFilter(originalCond, expectedCond)
testJoin(originalCond, expectedCond)
Expand All @@ -135,7 +135,7 @@ class ReplaceNullWithFalseInPredicateSuite extends PlanTest {
(UnresolvedAttribute("i") < Literal(10)) -> TrueLiteral,
(UnresolvedAttribute("i") > Literal(10)) -> FalseLiteral,
TrueLiteral -> TrueLiteral)
val expectedCond = CaseWhen(expectedBranches)
val expectedCond = CaseWhen(expectedBranches, FalseLiteral)

testFilter(originalCond, expectedCond)
testJoin(originalCond, expectedCond)
Expand Down Expand Up @@ -238,7 +238,8 @@ class ReplaceNullWithFalseInPredicateSuite extends PlanTest {
FalseLiteral)
val condition = CaseWhen(Seq((UnresolvedAttribute("i") > Literal(10)) -> branchValue))
val expectedCond = CaseWhen(Seq(
(UnresolvedAttribute("i") > Literal(10), (Literal(2) === nestedCaseWhen) <=> TrueLiteral)))
(UnresolvedAttribute("i") > Literal(10), (Literal(2) === nestedCaseWhen) <=> TrueLiteral)),
FalseLiteral)
testFilter(originalCond = condition, expectedCond = expectedCond)
testJoin(originalCond = condition, expectedCond = expectedCond)
testDelete(originalCond = condition, expectedCond = expectedCond)
Expand Down

0 comments on commit 41a8fed

Please sign in to comment.