Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Implement Spark-compatible CAST between integer types #340

Merged
merged 16 commits into from
May 3, 2024
Merged
9 changes: 9 additions & 0 deletions core/src/errors.rs
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,15 @@ pub enum CometError {
to_type: String,
},

#[error("[CAST_OVERFLOW] The value {value} of the type \"{from_type}\" cannot be cast to \"{to_type}\" \
due to an overflow. Use `try_cast` to tolerate overflow and return NULL instead. If necessary \
set \"spark.sql.ansi.enabled\" to \"false\" to bypass this error.")]
CastOverFlow {
value: String,
from_type: String,
to_type: String,
},

#[error(transparent)]
Arrow {
#[from]
Expand Down
98 changes: 98 additions & 0 deletions core/src/execution/datafusion/expressions/cast.rs
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,62 @@ macro_rules! cast_utf8_to_timestamp {
}};
}

/// Casts a primitive integer array to a narrower primitive integer array.
///
/// Arguments, in order:
/// * `$array` - the input array; it is downcast to
///   `PrimitiveArray<$from_arrow_primitive_type>`.
/// * `$eval_mode` - the `EvalMode` that selects the overflow behaviour (see below).
/// * `$from_arrow_primitive_type` / `$to_arrow_primitive_type` - Arrow primitive
///   types of the input and output arrays.
/// * `$from_data_type` - the input `DataType`; only used to pick the literal suffix
///   appended to the offending value in the overflow error message.
/// * `$to_native_type` - the Rust native integer type of the output values.
/// * `$spark_from_data_type_name` / `$spark_to_data_type_name` - type names reported
///   in the overflow error message.
///
/// Behaviour per evaluation mode:
/// * `EvalMode::Legacy` - values are converted with `as`, so out-of-range values are
///   silently truncated to the target width.
/// * any other mode - values are converted with `TryFrom`; the first value that does
///   not fit yields `CometError::CastOverFlow`, which is propagated with `?` out of
///   the *enclosing function*.
///
/// Null slots are preserved in both modes. The macro evaluates to a
/// `CometResult<ArrayRef>`.
///
/// # Panics
///
/// Panics if `$array` is not a `PrimitiveArray<$from_arrow_primitive_type>`.
macro_rules! cast_int_to_int_macro {
    (
        $array: expr,
        $eval_mode:expr,
        $from_arrow_primitive_type: ty,
        $to_arrow_primitive_type: ty,
        $from_data_type: expr,
        $to_native_type: ty,
        $spark_from_data_type_name: expr,
        $spark_to_data_type_name: expr
    ) => {{
        let cast_array = $array
            .as_any()
            .downcast_ref::<PrimitiveArray<$from_arrow_primitive_type>>()
            .expect("input array must match the declared source primitive type");

        // Suffix appended to the offending value in the overflow message
        // (e.g. `L` for 64-bit values); types without a suffix get none.
        let spark_int_literal_suffix = match $from_data_type {
            DataType::Int64 => "L",
            DataType::Int16 => "S",
            DataType::Int8 => "T",
            _ => "",
        };

        let output_array: PrimitiveArray<$to_arrow_primitive_type> = match $eval_mode {
            // Legacy mode never fails: `as` truncates out-of-range values.
            EvalMode::Legacy => cast_array
                .iter()
                .map(|value| value.map(|v| v as $to_native_type))
                .collect::<PrimitiveArray<$to_arrow_primitive_type>>(),
            // Every other mode reports the first value that does not fit.
            _ => cast_array
                .iter()
                .map(|value| {
                    value
                        .map(|v| {
                            <$to_native_type>::try_from(v).map_err(|_| {
                                CometError::CastOverFlow {
                                    value: format!("{}{}", v, spark_int_literal_suffix),
                                    from_type: $spark_from_data_type_name.to_string(),
                                    to_type: $spark_to_data_type_name.to_string(),
                                }
                            })
                        })
                        .transpose()
                })
                .collect::<Result<PrimitiveArray<$to_arrow_primitive_type>, CometError>>()?,
        };

        let result: CometResult<ArrayRef> = Ok(Arc::new(output_array) as ArrayRef);
        result
    }};
}

impl Cast {
pub fn new(
child: Arc<dyn PhysicalExpr>,
Expand Down Expand Up @@ -149,6 +205,16 @@ impl Cast {
(DataType::Utf8, DataType::Timestamp(_, _)) => {
Self::cast_string_to_timestamp(&array, to_type, self.eval_mode)?
}
(DataType::Int64, DataType::Int32)
| (DataType::Int64, DataType::Int16)
| (DataType::Int64, DataType::Int8)
| (DataType::Int32, DataType::Int16)
| (DataType::Int32, DataType::Int8)
| (DataType::Int16, DataType::Int8)
if self.eval_mode != EvalMode::Try =>
{
Self::spark_cast_int_to_int(&array, self.eval_mode, from_type, to_type)?
}
(
DataType::Utf8,
DataType::Int8 | DataType::Int16 | DataType::Int32 | DataType::Int64,
Expand Down Expand Up @@ -248,6 +314,38 @@ impl Cast {
Ok(cast_array)
}

/// Casts an integer array to a narrower integer type.
///
/// Supported `(from_type, to_type)` pairs are `Int64 -> Int32 | Int16 | Int8`,
/// `Int32 -> Int16 | Int8` and `Int16 -> Int8`. The per-value conversion is
/// delegated to `cast_int_to_int_macro!`, which truncates in `EvalMode::Legacy`
/// and checks for overflow in every other mode.
///
/// # Errors
///
/// Returns `CometError::CastOverFlow` when `eval_mode` is not `EvalMode::Legacy`
/// and a value does not fit into `to_type`.
///
/// # Panics
///
/// Panics if `(from_type, to_type)` is not one of the supported pairs, or if
/// `array` does not have the primitive type described by `from_type`.
fn spark_cast_int_to_int(
    array: &dyn Array,
    eval_mode: EvalMode,
    from_type: &DataType,
    to_type: &DataType,
) -> CometResult<ArrayRef> {
    match (from_type, to_type) {
        (DataType::Int64, DataType::Int32) => cast_int_to_int_macro!(
            array, eval_mode, Int64Type, Int32Type, from_type, i32, "BIGINT", "INT"
        ),
        (DataType::Int64, DataType::Int16) => cast_int_to_int_macro!(
            array, eval_mode, Int64Type, Int16Type, from_type, i16, "BIGINT", "SMALLINT"
        ),
        (DataType::Int64, DataType::Int8) => cast_int_to_int_macro!(
            array, eval_mode, Int64Type, Int8Type, from_type, i8, "BIGINT", "TINYINT"
        ),
        (DataType::Int32, DataType::Int16) => cast_int_to_int_macro!(
            array, eval_mode, Int32Type, Int16Type, from_type, i16, "INT", "SMALLINT"
        ),
        (DataType::Int32, DataType::Int8) => cast_int_to_int_macro!(
            array, eval_mode, Int32Type, Int8Type, from_type, i8, "INT", "TINYINT"
        ),
        (DataType::Int16, DataType::Int8) => cast_int_to_int_macro!(
            array, eval_mode, Int16Type, Int8Type, from_type, i8, "SMALLINT", "TINYINT"
        ),
        // Callers are expected to dispatch only the pairs listed above; the
        // arguments are passed straight to `unreachable!` to avoid building an
        // intermediate `String`.
        _ => unreachable!(
            "invalid integer type {} in cast from {}",
            to_type, from_type
        ),
    }
}

fn spark_cast_utf8_to_boolean<OffsetSize>(
from: &dyn Array,
eval_mode: EvalMode,
Expand Down
26 changes: 18 additions & 8 deletions spark/src/test/scala/org/apache/comet/CometCastSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -175,7 +175,7 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper {
castTest(generateShorts(), DataTypes.BooleanType)
}

ignore("cast ShortType to ByteType") {
test("cast ShortType to ByteType") {
// https://github.com/apache/datafusion-comet/issues/311
castTest(generateShorts(), DataTypes.ByteType)
}
Expand Down Expand Up @@ -215,12 +215,12 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper {
castTest(generateInts(), DataTypes.BooleanType)
}

ignore("cast IntegerType to ByteType") {
test("cast IntegerType to ByteType") {
// https://github.com/apache/datafusion-comet/issues/311
castTest(generateInts(), DataTypes.ByteType)
}

ignore("cast IntegerType to ShortType") {
test("cast IntegerType to ShortType") {
// https://github.com/apache/datafusion-comet/issues/311
castTest(generateInts(), DataTypes.ShortType)
}
Expand Down Expand Up @@ -257,17 +257,17 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper {
castTest(generateLongs(), DataTypes.BooleanType)
}

ignore("cast LongType to ByteType") {
test("cast LongType to ByteType") {
// https://github.com/apache/datafusion-comet/issues/311
castTest(generateLongs(), DataTypes.ByteType)
}

ignore("cast LongType to ShortType") {
test("cast LongType to ShortType") {
// https://github.com/apache/datafusion-comet/issues/311
castTest(generateLongs(), DataTypes.ShortType)
}

ignore("cast LongType to IntegerType") {
test("cast LongType to IntegerType") {
// https://github.com/apache/datafusion-comet/issues/311
castTest(generateLongs(), DataTypes.IntegerType)
}
Expand Down Expand Up @@ -868,11 +868,21 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper {
} else {
// Spark 3.2 and 3.3 have a different error message format so we can't do a direct
// comparison between Spark and Comet.
// In the case of CAST_INVALID_INPUT
// Spark message is in format `invalid input syntax for type TYPE: VALUE`
// Comet message is in format `The value 'VALUE' of the type FROM_TYPE cannot be cast to TO_TYPE`
// We just check that the comet message contains the same invalid value as the Spark message
val sparkInvalidValue = sparkMessage.substring(sparkMessage.indexOf(':') + 2)
assert(cometMessage.contains(sparkInvalidValue))
// In the case of CAST_OVERFLOW
// Spark message is in format `Casting VALUE to TO_TYPE causes overflow`
// Comet message is in format `The value 'VALUE' of the type FROM_TYPE cannot be cast to TO_TYPE
// due to an overflow`
// We check if the comet message contains 'overflow'.
if (sparkMessage.indexOf(':') == -1) {
assert(cometMessage.contains("overflow"))
} else {
assert(
cometMessage.contains(sparkMessage.substring(sparkMessage.indexOf(':') + 2)))
}
}
}

Expand Down
Loading