Skip to content

Commit

Permalink
[SPARK-48016][SQL][3.4] Fix a bug in try_divide function when with de…
Browse files Browse the repository at this point in the history
…cimals

### What changes were proposed in this pull request?

 Currently, the following query will throw DIVIDE_BY_ZERO error instead of returning null
 ```
SELECT try_divide(1, decimal(0));
```

This is caused by the rule `DecimalPrecision`:
```
case b  BinaryOperator(left, right) if left.dataType != right.dataType =>
  (left, right) match {
 ...
    case (l: Literal, r) if r.dataType.isInstanceOf[DecimalType] &&
        l.dataType.isInstanceOf[IntegralType] &&
        literalPickMinimumPrecision =>
      b.makeCopy(Array(Cast(l, DataTypeUtils.fromLiteral(l)), r))
```
The result of the above makeCopy will contain `ANSI` as the `evalMode`, instead of `TRY`.
This PR is to fix this bug by replacing the makeCopy method calls with withNewChildren

### Why are the changes needed?

Bug fix in try_* functions.

### Does this PR introduce _any_ user-facing change?

Yes, it fixes a long-standing bug in the try_divide function.

### How was this patch tested?

New UT

### Was this patch authored or co-authored using generative AI tooling?

No

Closes apache#46289 from gengliangwang/PICK_PR_46286_BRANCH-3.4.

Authored-by: Gengliang Wang <gengliang@apache.org>
Signed-off-by: Gengliang Wang <gengliang@apache.org>
  • Loading branch information
gengliangwang committed Apr 30, 2024
1 parent e2f34c7 commit 2870c76
Show file tree
Hide file tree
Showing 7 changed files with 1,130 additions and 12 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ object DecimalPrecision extends TypeCoercionRule {
val resultType = widerDecimalType(p1, s1, p2, s2)
val newE1 = if (e1.dataType == resultType) e1 else Cast(e1, resultType)
val newE2 = if (e2.dataType == resultType) e2 else Cast(e2, resultType)
b.makeCopy(Array(newE1, newE2))
b.withNewChildren(Seq(newE1, newE2))
}

/**
Expand Down Expand Up @@ -201,21 +201,21 @@ object DecimalPrecision extends TypeCoercionRule {
case (l: Literal, r) if r.dataType.isInstanceOf[DecimalType] &&
l.dataType.isInstanceOf[IntegralType] &&
literalPickMinimumPrecision =>
b.makeCopy(Array(Cast(l, DecimalType.fromLiteral(l)), r))
b.withNewChildren(Seq(Cast(l, DecimalType.fromLiteral(l)), r))
case (l, r: Literal) if l.dataType.isInstanceOf[DecimalType] &&
r.dataType.isInstanceOf[IntegralType] &&
literalPickMinimumPrecision =>
b.makeCopy(Array(l, Cast(r, DecimalType.fromLiteral(r))))
b.withNewChildren(Seq(l, Cast(r, DecimalType.fromLiteral(r))))
// Promote integers inside a binary expression with fixed-precision decimals to decimals,
// and fixed-precision decimals in an expression with floats / doubles to doubles
case (l @ IntegralType(), r @ DecimalType.Expression(_, _)) =>
b.makeCopy(Array(Cast(l, DecimalType.forType(l.dataType)), r))
b.withNewChildren(Seq(Cast(l, DecimalType.forType(l.dataType)), r))
case (l @ DecimalType.Expression(_, _), r @ IntegralType()) =>
b.makeCopy(Array(l, Cast(r, DecimalType.forType(r.dataType))))
b.withNewChildren(Seq(l, Cast(r, DecimalType.forType(r.dataType))))
case (l, r @ DecimalType.Expression(_, _)) if isFloat(l.dataType) =>
b.makeCopy(Array(l, Cast(r, DoubleType)))
b.withNewChildren(Seq(l, Cast(r, DoubleType)))
case (l @ DecimalType.Expression(_, _), r) if isFloat(r.dataType) =>
b.makeCopy(Array(Cast(l, DoubleType), r))
b.withNewChildren(Seq(Cast(l, DoubleType), r))
case _ => b
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1106,22 +1106,22 @@ object TypeCoercion extends TypeCoercionBase {

case a @ BinaryArithmetic(left @ StringType(), right)
if right.dataType != CalendarIntervalType =>
a.makeCopy(Array(Cast(left, DoubleType), right))
a.withNewChildren(Seq(Cast(left, DoubleType), right))
case a @ BinaryArithmetic(left, right @ StringType())
if left.dataType != CalendarIntervalType =>
a.makeCopy(Array(left, Cast(right, DoubleType)))
a.withNewChildren(Seq(left, Cast(right, DoubleType)))

// For equality between string and timestamp we cast the string to a timestamp
// so that things like rounding of subsecond precision does not affect the comparison.
case p @ Equality(left @ StringType(), right @ TimestampType()) =>
p.makeCopy(Array(Cast(left, TimestampType), right))
p.withNewChildren(Seq(Cast(left, TimestampType), right))
case p @ Equality(left @ TimestampType(), right @ StringType()) =>
p.makeCopy(Array(left, Cast(right, TimestampType)))
p.withNewChildren(Seq(left, Cast(right, TimestampType)))

case p @ BinaryComparison(left, right)
if findCommonTypeForBinaryComparison(left.dataType, right.dataType, conf).isDefined =>
val commonType = findCommonTypeForBinaryComparison(left.dataType, right.dataType, conf).get
p.makeCopy(Array(castExpr(left, commonType), castExpr(right, commonType)))
p.withNewChildren(Seq(castExpr(left, commonType), castExpr(right, commonType)))
}
}

Expand Down
Loading

0 comments on commit 2870c76

Please sign in to comment.