Skip to content

Commit

Permalink
[SPARK-33847][SQL][FOLLOWUP] Remove the CaseWhen should consider dete…
Browse files Browse the repository at this point in the history
…rministic

### What changes were proposed in this pull request?

This PR fixes the rule that removes a `CaseWhen` when `elseValue` is empty and all other outputs are null: the removal must take determinism into account, so branches whose conditions are non-deterministic are not removed.

### Why are the changes needed?

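Bug fix: previously a `CaseWhen` with an empty `elseValue` and all-null outputs was removed even when its branch conditions were non-deterministic, which dropped the evaluation of those conditions.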

### Does this PR introduce _any_ user-facing change?

No.

### How was this patch tested?

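Unit tests (updated `PushFoldableIntoBranchesSuite`, `ReplaceNullWithFalseInPredicateSuite`, and `SimplifyConditionalSuite`).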

Closes #30960 from wangyum/SPARK-33847-2.

Authored-by: Yuming Wang <yumwang@ebay.com>
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
  • Loading branch information
wangyum authored and cloud-fan committed Dec 29, 2020
1 parent 16c594d commit c425024
Show file tree
Hide file tree
Showing 5 changed files with 23 additions and 28 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -98,12 +98,8 @@ object ReplaceNullWithFalseInPredicate extends Rule[LogicalPlan] {
val newBranches = cw.branches.map { case (cond, value) =>
replaceNullWithFalse(cond) -> replaceNullWithFalse(value)
}
if (newBranches.forall(_._2 == FalseLiteral) && cw.elseValue.isEmpty) {
FalseLiteral
} else {
val newElseValue = cw.elseValue.map(replaceNullWithFalse)
CaseWhen(newBranches, newElseValue)
}
val newElseValue = cw.elseValue.map(replaceNullWithFalse).getOrElse(FalseLiteral)
CaseWhen(newBranches, newElseValue)
case i @ If(pred, trueVal, falseVal) if i.dataType == BooleanType =>
If(replaceNullWithFalse(pred), replaceNullWithFalse(trueVal), replaceNullWithFalse(falseVal))
case e if e.dataType == BooleanType =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -515,8 +515,9 @@ object SimplifyConditionals extends Rule[LogicalPlan] with PredicateHelper {
val (h, t) = branches.span(_._1 != TrueLiteral)
CaseWhen( h :+ t.head, None)

case e @ CaseWhen(branches, Some(elseValue))
if branches.forall(_._2.semanticEquals(elseValue)) =>
case e @ CaseWhen(branches, elseOpt)
if branches.forall(_._2.semanticEquals(elseOpt.getOrElse(Literal(null, e.dataType)))) =>
val elseValue = elseOpt.getOrElse(Literal(null, e.dataType))
// For non-deterministic conditions with side effect, we can not remove it, or change
// the ordering. As a result, we try to remove the deterministic conditions from the tail.
var hitNonDeterministicCond = false
Expand All @@ -532,10 +533,6 @@ object SimplifyConditionals extends Rule[LogicalPlan] with PredicateHelper {
} else {
e.copy(branches = branches.take(i).map(branch => (branch._1, elseValue)))
}

case e @ CaseWhen(branches, None)
if branches.forall(_._2.semanticEquals(Literal(null, e.dataType))) =>
Literal(null, e.dataType)
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -260,14 +260,13 @@ class PushFoldableIntoBranchesSuite
}

test("SPARK-33847: Remove the CaseWhen if elseValue is empty and other outputs are null") {
Seq(a, LessThan(Rand(1), Literal(0.5))).foreach { condition =>
assertEquivalent(
EqualTo(CaseWhen(Seq((condition, Literal.create(null, IntegerType)))), Literal(2)),
Literal.create(null, BooleanType))
assertEquivalent(
EqualTo(CaseWhen(Seq((condition, Literal("str")))).cast(IntegerType), Literal(2)),
Literal.create(null, BooleanType))
}
assertEquivalent(
EqualTo(CaseWhen(Seq((a, Literal.create(null, IntegerType)))), Literal(2)),
Literal.create(null, BooleanType))
assertEquivalent(
EqualTo(CaseWhen(Seq((LessThan(Rand(1), Literal(0.5)), Literal("str")))).cast(IntegerType),
Literal(2)),
CaseWhen(Seq((LessThan(Rand(1), Literal(0.5)), Literal.create(null, BooleanType)))))
}

test("SPARK-33884: simplify CaseWhen clauses with (true and false) and (false and true)") {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ class ReplaceNullWithFalseInPredicateSuite extends PlanTest {
val expectedBranches = Seq(
(UnresolvedAttribute("i") < Literal(10)) -> FalseLiteral,
(UnresolvedAttribute("i") > Literal(40)) -> TrueLiteral)
val expectedCond = CaseWhen(expectedBranches)
val expectedCond = CaseWhen(expectedBranches, FalseLiteral)

testFilter(originalCond, expectedCond)
testJoin(originalCond, expectedCond)
Expand All @@ -135,7 +135,7 @@ class ReplaceNullWithFalseInPredicateSuite extends PlanTest {
(UnresolvedAttribute("i") < Literal(10)) -> TrueLiteral,
(UnresolvedAttribute("i") > Literal(10)) -> FalseLiteral,
TrueLiteral -> TrueLiteral)
val expectedCond = CaseWhen(expectedBranches)
val expectedCond = CaseWhen(expectedBranches, FalseLiteral)

testFilter(originalCond, expectedCond)
testJoin(originalCond, expectedCond)
Expand Down Expand Up @@ -238,7 +238,8 @@ class ReplaceNullWithFalseInPredicateSuite extends PlanTest {
FalseLiteral)
val condition = CaseWhen(Seq((UnresolvedAttribute("i") > Literal(10)) -> branchValue))
val expectedCond = CaseWhen(Seq(
(UnresolvedAttribute("i") > Literal(10), (Literal(2) === nestedCaseWhen) <=> TrueLiteral)))
(UnresolvedAttribute("i") > Literal(10), (Literal(2) === nestedCaseWhen) <=> TrueLiteral)),
FalseLiteral)
testFilter(originalCond = condition, expectedCond = expectedCond)
testJoin(originalCond = condition, expectedCond = expectedCond)
testDelete(originalCond = condition, expectedCond = expectedCond)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -237,11 +237,13 @@ class SimplifyConditionalSuite extends PlanTest with ExpressionEvalHelper with P
}

test("SPARK-33847: Remove the CaseWhen if elseValue is empty and other outputs are null") {
Seq(GreaterThan('a, 1), GreaterThan(Rand(0), 1)).foreach { condition =>
assertEquivalent(
CaseWhen((condition, Literal.create(null, IntegerType)) :: Nil, None),
Literal.create(null, IntegerType))
}
assertEquivalent(
CaseWhen((GreaterThan('a, 1), Literal.create(null, IntegerType)) :: Nil, None),
Literal.create(null, IntegerType))

assertEquivalent(
CaseWhen((GreaterThan(Rand(0), 0.5), Literal.create(null, IntegerType)) :: Nil, None),
CaseWhen((GreaterThan(Rand(0), 0.5), Literal.create(null, IntegerType)) :: Nil, None))
}

test("SPARK-33884: simplify CaseWhen clauses with (true and false) and (false and true)") {
Expand Down

0 comments on commit c425024

Please sign in to comment.