[SPARK-26706][SQL] Fix illegalNumericPrecedence for ByteType
This PR contains a minor change in `Cast$mayTruncate` that fixes its logic for bytes.

Right now, `mayTruncate(ByteType, LongType)` returns `false` while `mayTruncate(ShortType, LongType)` returns `true`. Consequently, `spark.range(1, 3).as[Byte]` and `spark.range(1, 3).as[Short]` behave differently.

Potentially, this bug can silently corrupt someone's data.
```scala
// executes silently even though Long is converted into Byte
spark.range(Long.MaxValue - 10, Long.MaxValue).as[Byte]
  .map(b => b - 1)
  .show()
+-----+
|value|
+-----+
|  -12|
|  -11|
|  -10|
|   -9|
|   -8|
|   -7|
|   -6|
|   -5|
|   -4|
|   -3|
+-----+
// throws an AnalysisException: Cannot up cast `id` from bigint to smallint as it may truncate
spark.range(Long.MaxValue - 10, Long.MaxValue).as[Short]
  .map(s => s - 1)
  .show()
```
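The root cause is an off-by-one in `illegalNumericPrecedence`: `ByteType` is the first element of `TypeCoercion.numericPrecedence`, so `numericPrecedence.indexOf(ByteType)` is `0` and the old `toPrecedence > 0` guard could never flag a downcast to bytes. Here is a minimal standalone sketch of the faulty predicate (type names are modeled as plain strings so the snippet runs without Spark on the classpath):
```scala
// Mirrors TypeCoercion.numericPrecedence, ordered from narrowest to widest.
val numericPrecedence =
  IndexedSeq("ByteType", "ShortType", "IntegerType", "LongType", "FloatType", "DoubleType")

def illegalNumericPrecedence(from: String, to: String): Boolean = {
  val fromPrecedence = numericPrecedence.indexOf(from)
  val toPrecedence = numericPrecedence.indexOf(to)
  // Bug: ByteType sits at index 0, so `toPrecedence > 0` is false for byte
  // targets and the narrowing cast slips through undetected.
  toPrecedence > 0 && fromPrecedence > toPrecedence
}

illegalNumericPrecedence("LongType", "ByteType")   // false -- downcast to byte allowed (the bug)
illegalNumericPrecedence("LongType", "ShortType")  // true  -- downcast to short rejected
```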

This PR comes with a set of unit tests.
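For reference, once the fix is applied the byte example above should fail analysis just like the short variant; the new `DatasetSuite` test below asserts exactly this error:
```scala
// Expected behavior after the fix (per the new DatasetSuite test):
// org.apache.spark.sql.AnalysisException:
//   Cannot up cast `id` from bigint to tinyint as it may truncate
spark.range(Long.MaxValue - 10, Long.MaxValue).as[Byte]
  .map(b => b - 1)
  .show()
```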

Closes apache#23632 from aokolnychyi/cast-fix.

Authored-by: Anton Okolnychyi <aokolnychyi@apple.com>
Signed-off-by: DB Tsai <d_tsai@apple.com>
Showing 3 changed files with 46 additions and 1 deletion.
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala

```diff
@@ -131,7 +131,7 @@ object Cast {
   private def illegalNumericPrecedence(from: DataType, to: DataType): Boolean = {
     val fromPrecedence = TypeCoercion.numericPrecedence.indexOf(from)
     val toPrecedence = TypeCoercion.numericPrecedence.indexOf(to)
-    toPrecedence > 0 && fromPrecedence > toPrecedence
+    toPrecedence >= 0 && fromPrecedence > toPrecedence
   }
 
   /**
```
sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala

```diff
@@ -23,6 +23,7 @@ import java.util.{Calendar, Locale, TimeZone}
 import org.apache.spark.SparkFunSuite
 import org.apache.spark.sql.Row
 import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.analysis.TypeCoercion.numericPrecedence
 import org.apache.spark.sql.catalyst.expressions.codegen.CodegenContext
 import org.apache.spark.sql.catalyst.util.DateTimeTestUtils._
 import org.apache.spark.sql.catalyst.util.DateTimeUtils
@@ -953,4 +954,39 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper {
     val ret6 = cast(Literal.create((1, Map(1 -> "a", 2 -> "b", 3 -> "c"))), StringType)
     checkEvaluation(ret6, "[1, [1 -> a, 2 -> b, 3 -> c]]")
   }
+
+  test("SPARK-26706: Fix Cast.mayTruncate for bytes") {
+    assert(!Cast.mayTruncate(ByteType, ByteType))
+    assert(!Cast.mayTruncate(DecimalType.ByteDecimal, ByteType))
+    assert(Cast.mayTruncate(ShortType, ByteType))
+    assert(Cast.mayTruncate(IntegerType, ByteType))
+    assert(Cast.mayTruncate(LongType, ByteType))
+    assert(Cast.mayTruncate(FloatType, ByteType))
+    assert(Cast.mayTruncate(DoubleType, ByteType))
+    assert(Cast.mayTruncate(DecimalType.IntDecimal, ByteType))
+  }
+
+  test("canSafeCast and mayTruncate must be consistent for numeric types") {
+    import DataTypeTestUtils._
+
+    def isCastSafe(from: NumericType, to: NumericType): Boolean = (from, to) match {
+      case (_, dt: DecimalType) => dt.isWiderThan(from)
+      case (dt: DecimalType, _) => dt.isTighterThan(to)
+      case _ => numericPrecedence.indexOf(from) <= numericPrecedence.indexOf(to)
+    }
+
+    numericTypes.foreach { from =>
+      val (safeTargetTypes, unsafeTargetTypes) = numericTypes.partition(to => isCastSafe(from, to))
+
+      safeTargetTypes.foreach { to =>
+        assert(Cast.canSafeCast(from, to), s"It should be possible to safely cast $from to $to")
+        assert(!Cast.mayTruncate(from, to), s"No truncation is expected when casting $from to $to")
+      }
+
+      unsafeTargetTypes.foreach { to =>
+        assert(!Cast.canSafeCast(from, to), s"It shouldn't be possible to safely cast $from to $to")
+        assert(Cast.mayTruncate(from, to), s"Truncation is expected when casting $from to $to")
+      }
+    }
+  }
 }
```
sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala

```diff
@@ -1567,6 +1567,15 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
     val exceptDF = inputDF.filter(col("a").isin("0") or col("b") > "c")
     checkAnswer(inputDF.except(exceptDF), Seq(Row("1", null)))
   }
+
+  test("SPARK-26706: Fix Cast.mayTruncate for bytes") {
+    val thrownException = intercept[AnalysisException] {
+      spark.range(Long.MaxValue - 10, Long.MaxValue).as[Byte]
+        .map(b => b - 1)
+        .collect()
+    }
+    assert(thrownException.message.contains("Cannot up cast `id` from bigint to tinyint"))
+  }
 }
 
 case class TestDataUnion(x: Int, y: Int, z: Int)
```
