Skip to content

Commit

Permalink
Simplify EqualTo(CaseWhen/If, Literal) always false
Browse files Browse the repository at this point in the history
  • Loading branch information
wangyum committed Dec 16, 2020
1 parent 40c37d6 commit 19b0a83
Show file tree
Hide file tree
Showing 2 changed files with 106 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -470,6 +470,10 @@ object SimplifyConditionals extends Rule[LogicalPlan] with PredicateHelper {
case _ => false
}

private def isAlwaysFalse(exps: Seq[Expression], equalTo: Literal): Boolean = {
exps.forall(!EqualTo(_, equalTo).eval(EmptyRow).asInstanceOf[Boolean])
}

def apply(plan: LogicalPlan): LogicalPlan = plan transform {
case q: LogicalPlan => q transformExpressionsUp {
case If(TrueLiteral, trueValue, _) => trueValue
Expand Down Expand Up @@ -523,6 +527,15 @@ object SimplifyConditionals extends Rule[LogicalPlan] with PredicateHelper {
} else {
e.copy(branches = branches.take(i).map(branch => (branch._1, elseValue)))
}

case EqualTo(i @ If(_, trueValue: Literal, falseValue: Literal), right: Literal)
if i.deterministic && isAlwaysFalse(trueValue :: falseValue :: Nil, right) =>
FalseLiteral

case EqualTo(c @ CaseWhen(branches, elseValue), right: Literal) if c.deterministic &&
(branches.map(_._2) ++ elseValue).forall(_.isInstanceOf[Literal]) &&
isAlwaysFalse(branches.map(_._2) ++ elseValue, right) =>
FalseLiteral
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -199,4 +199,97 @@ class SimplifyConditionalSuite extends PlanTest with ExpressionEvalHelper with P
If(Factorial(5) > 100L, b, nullLiteral).eval(EmptyRow))
}
}

test("SPARK-33798: simplify EqualTo(If, Literal) always false") {
val a = EqualTo(UnresolvedAttribute("a"), Literal(100))
val ifExp = If(a === Literal(1), Literal(2), Literal(3))

assertEquivalent(EqualTo(ifExp, Literal(4)), FalseLiteral)
assertEquivalent(EqualTo(ifExp, Literal(3)), EqualTo(ifExp, Literal(3)))
assertEquivalent(EqualTo(ifExp, Literal("4")), FalseLiteral)
assertEquivalent(EqualTo(ifExp, Literal("3")), EqualTo(ifExp, Literal(3)))

// Do not simplify if it contains non foldable expressions.
assertEquivalent(EqualTo(ifExp, NonFoldableLiteral(true)),
EqualTo(ifExp, NonFoldableLiteral(true)))
val nonFoldable = If(NonFoldableLiteral(true), Literal(1), Literal(2))
assertEquivalent(EqualTo(nonFoldable, Literal(1)), EqualTo(nonFoldable, Literal(1)))

// Do not simplify if it contains non-deterministic expressions.
val nonDeterministic = If(LessThan(Rand(1), Literal(0.5)), Literal(1), Literal(1))
assert(!nonDeterministic.deterministic)
assertEquivalent(EqualTo(nonDeterministic, Literal(-1)), EqualTo(nonDeterministic, Literal(-1)))

// null check, SPARK-33798 will not change these behaviors.
assertEquivalent(
EqualTo(If(FalseLiteral, Literal(null, IntegerType), Literal(1)), Literal(1)),
TrueLiteral)
assertEquivalent(
EqualTo(If(TrueLiteral, Literal(null, IntegerType), Literal(1)), Literal(1)),
Literal(null, BooleanType))
assertEquivalent(
EqualTo(If(FalseLiteral, Literal(null, IntegerType), Literal(null, IntegerType)), Literal(1)),
Literal(null, BooleanType))

assertEquivalent(
EqualTo(If(FalseLiteral, Literal(1), Literal(2)), Literal(null, IntegerType)),
Literal(null, BooleanType))
assertEquivalent(
EqualTo(If(TrueLiteral, Literal(1), Literal(2)), Literal(null, IntegerType)),
Literal(null, BooleanType))
}

test("SPARK-33798: simplify EqualTo(CaseWhen, Literal) always false") {
val a = EqualTo(UnresolvedAttribute("a"), Literal(100))
val b = UnresolvedAttribute("b")
val c = EqualTo(UnresolvedAttribute("c"), Literal(true))
val caseWhen = CaseWhen(Seq((a, Literal(1)), (c, Literal(2))), Some(Literal(3)))

assertEquivalent(EqualTo(caseWhen, Literal(4)), FalseLiteral)
assertEquivalent(EqualTo(caseWhen, Literal(3)), EqualTo(caseWhen, Literal(3)))
assertEquivalent(EqualTo(caseWhen, Literal("4")), FalseLiteral)
assertEquivalent(EqualTo(caseWhen, Literal("3")), EqualTo(caseWhen, Literal(3)))
assertEquivalent(
EqualTo(CaseWhen(Seq((a, Literal("1")), (c, Literal("2"))), None), Literal("4")),
FalseLiteral)

assertEquivalent(
And(EqualTo(caseWhen, Literal(5)), EqualTo(caseWhen, Literal(6))),
FalseLiteral)

assertEquivalent(
EqualTo(CaseWhen(Seq(normalBranch, (a, Literal(1)), (c, Literal(1))), None), Literal(-1)),
FalseLiteral)

// Do not simplify if it contains non foldable expressions.
assertEquivalent(EqualTo(caseWhen, NonFoldableLiteral(true)),
EqualTo(caseWhen, NonFoldableLiteral(true)))
val nonFoldable = CaseWhen(Seq(normalBranch, (a, b)), None)
assertEquivalent(EqualTo(nonFoldable, Literal(1)), EqualTo(nonFoldable, Literal(1)))

// Do not simplify if it contains non-deterministic expressions.
val nonDeterministic = CaseWhen(Seq((LessThan(Rand(1), Literal(0.5)), Literal(1))), Some(b))
assert(!nonDeterministic.deterministic)
assertEquivalent(EqualTo(nonDeterministic, Literal(-1)), EqualTo(nonDeterministic, Literal(-1)))

// null check, SPARK-33798 will change the following two behaviors.
assertEquivalent(
EqualTo(CaseWhen(Seq((a, Literal(null, IntegerType))), Some(Literal(1))), Literal(2)),
FalseLiteral)
assertEquivalent(
EqualTo(CaseWhen(Seq((a, Literal(1))), Some(Literal(2))), Literal(null, IntegerType)),
FalseLiteral)

assertEquivalent(
EqualTo(CaseWhen(Seq((a, Literal(null, IntegerType))), Some(Literal(1))), Literal(1)),
EqualTo(CaseWhen(Seq((a, Literal(null, IntegerType))), Some(Literal(1))), Literal(1)))
assertEquivalent(
EqualTo(CaseWhen(Seq((a, Literal(null, IntegerType))), Some(Literal(null, IntegerType))),
Literal(1)),
Literal(null, BooleanType))
assertEquivalent(
EqualTo(CaseWhen(Seq((a, Literal(null, IntegerType))), Some(Literal(null, IntegerType))),
Literal(null, IntegerType)),
Literal(null, BooleanType))
}
}

0 comments on commit 19b0a83

Please sign in to comment.