Skip to content

Commit

Permalink
[SPARK-33845][SQL][FOLLOWUP] fix SimplifyConditionals
Browse files Browse the repository at this point in the history
### What changes were proposed in this pull request?

This is a followup of #30849, to fix a correctness issue caused by null value handling.

### Why are the changes needed?

Fix a correctness issue. `If(cond, true, false)` must return false when `cond` evaluates to null, but simplifying it to just `cond` returns null instead.

### Does this PR introduce _any_ user-facing change?

Yes, but the bug only exists in the master branch.

### How was this patch tested?

updated tests.

Closes #30953 from cloud-fan/bug.

Authored-by: Wenchen Fan <wenchen@databricks.com>
Signed-off-by: Dongjoon Hyun <dhyun@apple.com>
  • Loading branch information
cloud-fan authored and dongjoon-hyun committed Dec 29, 2020
1 parent 6497ccb commit c2eac1d
Show file tree
Hide file tree
Showing 5 changed files with 53 additions and 26 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,8 @@

package org.apache.spark.sql.catalyst.optimizer

import org.apache.spark.sql.catalyst.expressions.{And, ArrayExists, ArrayFilter, CaseWhen, Expression, If}
import org.apache.spark.sql.catalyst.expressions.{LambdaFunction, Literal, MapFilter, Or}
import org.apache.spark.sql.catalyst.expressions.Literal.FalseLiteral
import org.apache.spark.sql.catalyst.expressions.{And, ArrayExists, ArrayFilter, CaseWhen, EqualNullSafe, Expression, If, LambdaFunction, Literal, MapFilter, Or}
import org.apache.spark.sql.catalyst.expressions.Literal.{FalseLiteral, TrueLiteral}
import org.apache.spark.sql.catalyst.plans.logical.{DeleteFromTable, Filter, Join, LogicalPlan, UpdateTable}
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.types.BooleanType
Expand Down Expand Up @@ -56,6 +55,12 @@ object ReplaceNullWithFalseInPredicate extends Rule[LogicalPlan] {
case d @ DeleteFromTable(_, Some(cond)) => d.copy(condition = Some(replaceNullWithFalse(cond)))
case u @ UpdateTable(_, _, Some(cond)) => u.copy(condition = Some(replaceNullWithFalse(cond)))
case p: LogicalPlan => p transformExpressions {
// For `EqualNullSafe` with a `TrueLiteral`, it makes no difference whether the other side is
// null or false, as `null <=> true` and `false <=> true` both return false.
case EqualNullSafe(left, TrueLiteral) =>
EqualNullSafe(replaceNullWithFalse(left), TrueLiteral)
case EqualNullSafe(TrueLiteral, right) =>
EqualNullSafe(TrueLiteral, replaceNullWithFalse(right))
case i @ If(pred, _, _) => i.copy(predicate = replaceNullWithFalse(pred))
case cw @ CaseWhen(branches, _) =>
val newBranches = branches.map { case (cond, value) =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -475,8 +475,10 @@ object SimplifyConditionals extends Rule[LogicalPlan] with PredicateHelper {
case If(TrueLiteral, trueValue, _) => trueValue
case If(FalseLiteral, _, falseValue) => falseValue
case If(Literal(null, _), _, falseValue) => falseValue
case If(cond, TrueLiteral, FalseLiteral) => cond
case If(cond, FalseLiteral, TrueLiteral) => Not(cond)
case If(cond, TrueLiteral, FalseLiteral) =>
if (cond.nullable) EqualNullSafe(cond, TrueLiteral) else cond
case If(cond, FalseLiteral, TrueLiteral) =>
if (cond.nullable) Not(EqualNullSafe(cond, TrueLiteral)) else Not(cond)
case If(cond, trueValue, falseValue)
if cond.deterministic && trueValue.semanticEquals(falseValue) => trueValue
case If(cond, l @ Literal(null, _), FalseLiteral) if !cond.nullable => And(cond, l)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ class PushFoldableIntoBranchesSuite

test("Push down EqualTo through If") {
assertEquivalent(EqualTo(ifExp, Literal(4)), FalseLiteral)
assertEquivalent(EqualTo(ifExp, Literal(3)), Not(a))
assertEquivalent(EqualTo(ifExp, Literal(3)), Not(a <=> TrueLiteral))

// Push down at most one non-foldable expression.
assertEquivalent(
Expand Down Expand Up @@ -102,7 +102,7 @@ class PushFoldableIntoBranchesSuite
assertEquivalent(Remainder(ifExp, Literal(4)), If(a, Literal(2), Literal(3)))
assertEquivalent(Divide(If(a, Literal(2.0), Literal(3.0)), Literal(1.0)),
If(a, Literal(2.0), Literal(3.0)))
assertEquivalent(And(If(a, FalseLiteral, TrueLiteral), TrueLiteral), Not(a))
assertEquivalent(And(If(a, FalseLiteral, TrueLiteral), TrueLiteral), Not(a <=> TrueLiteral))
assertEquivalent(Or(If(a, FalseLiteral, TrueLiteral), TrueLiteral), TrueLiteral)
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.expressions.{And, ArrayExists, ArrayFilter, ArrayTransform, CaseWhen, Expression, GreaterThan, If, LambdaFunction, LessThanOrEqual, Literal, MapFilter, NamedExpression, Or, UnresolvedNamedLambdaVariable}
import org.apache.spark.sql.catalyst.expressions.{And, ArrayExists, ArrayFilter, ArrayTransform, CaseWhen, Expression, GreaterThan, If, LambdaFunction, Literal, MapFilter, NamedExpression, Or, UnresolvedNamedLambdaVariable}
import org.apache.spark.sql.catalyst.expressions.Literal.{FalseLiteral, TrueLiteral}
import org.apache.spark.sql.catalyst.plans.{Inner, PlanTest}
import org.apache.spark.sql.catalyst.plans.logical.{DeleteFromTable, LocalRelation, LogicalPlan, UpdateTable}
Expand Down Expand Up @@ -237,8 +237,8 @@ class ReplaceNullWithFalseInPredicateSuite extends PlanTest {
TrueLiteral,
FalseLiteral)
val condition = CaseWhen(Seq((UnresolvedAttribute("i") > Literal(10)) -> branchValue))
val expectedCond =
CaseWhen(Seq((UnresolvedAttribute("i") > Literal(10)) -> (Literal(2) === nestedCaseWhen)))
val expectedCond = CaseWhen(Seq(
(UnresolvedAttribute("i") > Literal(10), (Literal(2) === nestedCaseWhen) <=> TrueLiteral)))
testFilter(originalCond = condition, expectedCond = expectedCond)
testJoin(originalCond = condition, expectedCond = expectedCond)
testDelete(originalCond = condition, expectedCond = expectedCond)
Expand All @@ -253,10 +253,10 @@ class ReplaceNullWithFalseInPredicateSuite extends PlanTest {
Literal(3)),
TrueLiteral,
FalseLiteral)
val expectedCond = Literal(5) > If(
val expectedCond = (Literal(5) > If(
UnresolvedAttribute("i") === Literal(15),
Literal(null, IntegerType),
Literal(3))
Literal(3))) <=> TrueLiteral
testFilter(originalCond = condition, expectedCond = expectedCond)
testJoin(originalCond = condition, expectedCond = expectedCond)
testDelete(originalCond = condition, expectedCond = expectedCond)
Expand Down Expand Up @@ -443,9 +443,9 @@ class ReplaceNullWithFalseInPredicateSuite extends PlanTest {
val lambda1 = LambdaFunction(
function = If(cond, Literal(null, BooleanType), TrueLiteral),
arguments = lambdaArgs)
// the optimized lambda body is: if(arg > 0, false, true) => arg <= 0
// the optimized lambda body is: if(arg > 0, false, true) => !((arg > 0) <=> true)
val lambda2 = LambdaFunction(
function = LessThanOrEqual(condArg, Literal(0)),
function = !(cond <=> TrueLiteral),
arguments = lambdaArgs)
testProjection(
originalExpr = createExpr(argument, lambda1) as 'x,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -201,19 +201,39 @@ class SimplifyConditionalSuite extends PlanTest with ExpressionEvalHelper with P
}

test("SPARK-33845: remove unnecessary if when the outputs are boolean type") {
// Each pair in `exprs` below is (original If expression -> expected rewritten form). The pairs
// are checked twice: structurally through assertEquivalent, and by evaluating both sides on
// concrete input rows.
assertEquivalent(
If(IsNotNull(UnresolvedAttribute("a")), TrueLiteral, FalseLiteral),
IsNotNull(UnresolvedAttribute("a")))
assertEquivalent(
If(IsNotNull(UnresolvedAttribute("a")), FalseLiteral, TrueLiteral),
IsNull(UnresolvedAttribute("a")))
// verify the boolean equivalence of all transformations involved
val fields = Seq(
'cond.boolean.notNull,
'cond_nullable.boolean,
'a.boolean,
'b.boolean
)
// `at(i)` is applied with each field's index in `fields`; that order matches the order of the
// values passed to create_row in the evaluation loop below.
val Seq(cond, cond_nullable, a, b) = fields.zipWithIndex.map { case (f, i) => f.at(i) }

val exprs = Seq(
// actual expressions of the transformations: original -> transformed
If(cond, true, false) -> cond,
If(cond, false, true) -> !cond,
If(cond_nullable, true, false) -> (cond_nullable <=> true),
If(cond_nullable, false, true) -> (!(cond_nullable <=> true)))

// check plans
for ((originalExpr, expectedExpr) <- exprs) {
assertEquivalent(originalExpr, expectedExpr)
}

assertEquivalent(
If(GreaterThan(Rand(0), UnresolvedAttribute("a")), TrueLiteral, FalseLiteral),
GreaterThan(Rand(0), UnresolvedAttribute("a")))
assertEquivalent(
If(GreaterThan(Rand(0), UnresolvedAttribute("a")), FalseLiteral, TrueLiteral),
LessThanOrEqual(Rand(0), UnresolvedAttribute("a")))
// check evaluation
// The non-nullable `cond` column only takes true/false; the remaining columns also take null,
// so every combination of input values is enumerated for every expression pair.
val binaryBooleanValues = Seq(true, false)
val ternaryBooleanValues = Seq(true, false, null)
for (condVal <- binaryBooleanValues;
condNullableVal <- ternaryBooleanValues;
aVal <- ternaryBooleanValues;
bVal <- ternaryBooleanValues;
(originalExpr, expectedExpr) <- exprs) {
val inputRow = create_row(condVal, condNullableVal, aVal, bVal)
// The expected (rewritten) expression supplies the reference value that the original
// expression must also produce for the same row.
val optimizedVal = evaluateWithoutCodegen(expectedExpr, inputRow)
checkEvaluation(originalExpr, optimizedVal, inputRow)
}
}

test("SPARK-33847: Remove the CaseWhen if elseValue is empty and other outputs are null") {
Expand Down

0 comments on commit c2eac1d

Please sign in to comment.