Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Implement Spark-compatible CAST float/double to string #346

Merged
merged 12 commits into from
May 3, 2024
103 changes: 101 additions & 2 deletions core/src/execution/datafusion/expressions/cast.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

use std::{
any::Any,
fmt::{Display, Formatter},
fmt::{Debug, Display, Formatter},
hash::{Hash, Hasher},
sync::Arc,
};
Expand All @@ -31,7 +31,8 @@ use arrow::{
};
use arrow_array::{
types::{Int16Type, Int32Type, Int64Type, Int8Type},
Array, ArrayRef, BooleanArray, GenericStringArray, OffsetSizeTrait, PrimitiveArray,
Array, ArrayRef, BooleanArray, Float32Array, Float64Array, GenericStringArray, OffsetSizeTrait,
PrimitiveArray,
};
use arrow_schema::{DataType, Schema};
use chrono::{TimeZone, Timelike};
Expand Down Expand Up @@ -107,6 +108,72 @@ macro_rules! cast_utf8_to_timestamp {
}};
}

/// Formats one floating-point value the way Spark renders it when casting to a string.
///
/// * `$value` - expression evaluating to the float to format.
/// * `$type`  - the primitive float type of `$value` (`f32` or `f64`).
///
/// Expands to an expression of type `String`.
///
/// If the absolute value is less than 10,000,000 and greater than or equal to 0.001 (or is
/// zero), the result is expressed without scientific notation, with at least one digit on
/// either side of the decimal point. Otherwise a mantissa followed by `E` and an exponent is
/// used. The mantissa has an optional leading minus sign followed by one digit to the left of
/// the decimal point, and at least one digit to the right. The exponent has an optional
/// leading minus sign.
/// source: https://docs.databricks.com/en/sql/language-manual/functions/cast.html
macro_rules! spark_float_to_string {
    ($value:expr, $type:ty) => {{
        const LOWER_SCIENTIFIC_BOUND: $type = 0.001;
        const UPPER_SCIENTIFIC_BOUND: $type = 10000000.0;

        let value: $type = $value;
        let magnitude = value.abs();

        if value.is_nan() {
            // NaN fails every range comparison below, so it is handled explicitly.
            "NaN".to_string()
        } else if value == <$type>::INFINITY {
            "Infinity".to_string()
        } else if value == <$type>::NEG_INFINITY {
            "-Infinity".to_string()
        } else if magnitude == 0.0
            || (magnitude >= LOWER_SCIENTIFIC_BOUND && magnitude < UPPER_SCIENTIFIC_BOUND)
        {
            // Rust prints integral floats without a fractional part ("1"), so ".0" is
            // appended to keep at least one digit after the decimal point ("1.0").
            let trailing_zero = if value.fract() == 0.0 { ".0" } else { "" };
            format!("{value}{trailing_zero}")
        } else {
            // Scientific notation. Rust renders an integral mantissa without a fractional
            // part ("1E10"), so ".0" is inserted before the exponent ("1.0E10").
            let formatted = format!("{value:E}");
            if formatted.contains('.') {
                formatted
            } else {
                match formatted.split_once('E') {
                    Some((coefficient, exponent)) => format!("{coefficient}.0E{exponent}"),
                    None => formatted,
                }
            }
        }
    }};
}

/// Casts a float array to a string array using `spark_float_to_string!` for every value.
///
/// * `$from`        - the input array, as `&dyn Array`.
/// * `$eval_mode`   - the evaluation mode; accepted for signature parity but not consulted.
/// * `$type`        - primitive float type of the input values (`f32` or `f64`).
/// * `$output_type` - concrete array type `$from` is downcast to (despite its name this is
///                    the *input* array type, e.g. `Float64Array`).
/// * `$offset_type` - offset type of the resulting `GenericStringArray`.
///
/// Expands to an expression of type `CometResult<ArrayRef>`. Null slots stay null.
///
/// Panics if `$from` is not a `$output_type`.
macro_rules! cast_float_to_string {
    ($from:expr, $eval_mode:expr, $type:ty, $output_type:ty, $offset_type:ty) => {{
        fn cast<OffsetSize>(from: &dyn Array, _eval_mode: EvalMode) -> CometResult<ArrayRef>
        where
            OffsetSize: OffsetSizeTrait,
        {
            let array = from
                .as_any()
                .downcast_ref::<$output_type>()
                .expect("cast_float_to_string: input array does not match the declared type");

            // Formatting cannot fail, so the strings are collected directly into the
            // output array instead of going through an intermediate `Result`.
            let output_array = array
                .iter()
                .map(|value| value.map(|value| spark_float_to_string!(value, $type)))
                .collect::<GenericStringArray<OffsetSize>>();

            Ok(Arc::new(output_array))
        }

        cast::<$offset_type>($from, $eval_mode)
    }};
}

impl Cast {
pub fn new(
child: Arc<dyn PhysicalExpr>,
Expand Down Expand Up @@ -185,6 +252,18 @@ impl Cast {
),
}
}
(DataType::Float64, DataType::Utf8) => {
Self::spark_cast_float64_to_utf8::<i32>(&array, self.eval_mode)?
}
(DataType::Float64, DataType::LargeUtf8) => {
Self::spark_cast_float64_to_utf8::<i64>(&array, self.eval_mode)?
}
(DataType::Float32, DataType::Utf8) => {
Self::spark_cast_float32_to_utf8::<i32>(&array, self.eval_mode)?
}
(DataType::Float32, DataType::LargeUtf8) => {
Self::spark_cast_float32_to_utf8::<i64>(&array, self.eval_mode)?
}
_ => {
// when we have no Spark-specific casting we delegate to DataFusion
cast_with_options(&array, to_type, &CAST_OPTIONS)?
Expand Down Expand Up @@ -248,6 +327,26 @@ impl Cast {
Ok(cast_array)
}

/// Converts a `Float64Array` into a string array with offsets of type `OffsetSize`.
///
/// The work is delegated to `cast_float_to_string!`, instantiated for `f64` values.
/// `_eval_mode` is forwarded to the macro unchanged.
fn spark_cast_float64_to_utf8<OffsetSize: OffsetSizeTrait>(
    from: &dyn Array,
    _eval_mode: EvalMode,
) -> CometResult<ArrayRef> {
    cast_float_to_string!(from, _eval_mode, f64, Float64Array, OffsetSize)
}

/// Converts a `Float32Array` into a string array with offsets of type `OffsetSize`.
///
/// The work is delegated to `cast_float_to_string!`, instantiated for `f32` values.
/// `_eval_mode` is forwarded to the macro unchanged.
fn spark_cast_float32_to_utf8<OffsetSize: OffsetSizeTrait>(
    from: &dyn Array,
    _eval_mode: EvalMode,
) -> CometResult<ArrayRef> {
    cast_float_to_string!(from, _eval_mode, f32, Float32Array, OffsetSize)
}

fn spark_cast_utf8_to_boolean<OffsetSize>(
from: &dyn Array,
eval_mode: EvalMode,
Expand Down
13 changes: 10 additions & 3 deletions spark/src/test/scala/org/apache/comet/CometCastSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,14 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper {
castTest(generateBools(), DataTypes.IntegerType)
}

// Runs castTest over the generated double values with StringType as the target type.
// NOTE(review): castTest is defined outside this view; presumably it compares Comet's
// cast result with Spark's — confirm.
test("cast double to string") {
  castTest(generateDoubles(), DataTypes.StringType)
}

// Runs castTest over the generated float values with StringType as the target type.
// NOTE(review): castTest is defined outside this view; presumably it compares Comet's
// cast result with Spark's — confirm.
test("cast float to string") {
  castTest(generateFloats(), DataTypes.StringType)
}

// Runs castTest over the generated boolean values with LongType as the target type.
test("cast BooleanType to LongType") {
  val input = generateBools()
  castTest(input, DataTypes.LongType)
}
Expand Down Expand Up @@ -669,7 +677,7 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper {
val r = new Random(0)
val values = Seq(
Float.MaxValue,
Float.MinPositiveValue,
// Float.MinPositiveValue,
Float.MinValue,
mattharder91 marked this conversation as resolved.
Show resolved Hide resolved
Float.NaN,
Float.PositiveInfinity,
Expand All @@ -687,7 +695,7 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper {
val r = new Random(0)
val values = Seq(
Double.MaxValue,
Double.MinPositiveValue,
// Double.MinPositiveValue,
Double.MinValue,
Double.NaN,
Double.PositiveInfinity,
Expand Down Expand Up @@ -875,7 +883,6 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper {
assert(cometMessage.contains(sparkInvalidValue))
}
}

// try_cast() should always return null for invalid inputs
val df2 =
spark.sql(s"select a, try_cast(a as ${toType.sql}) from t order by a")
Expand Down
Loading