Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Implement Spark-compatible CAST float/double to string #346

Merged
merged 12 commits into from
May 3, 2024
105 changes: 103 additions & 2 deletions core/src/execution/datafusion/expressions/cast.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

use std::{
any::Any,
fmt::{Display, Formatter},
fmt::{Debug, Display, Formatter},
hash::{Hash, Hasher},
sync::Arc,
};
Expand All @@ -31,7 +31,8 @@ use arrow::{
};
use arrow_array::{
types::{Int16Type, Int32Type, Int64Type, Int8Type},
Array, ArrayRef, BooleanArray, GenericStringArray, OffsetSizeTrait, PrimitiveArray,
Array, ArrayRef, BooleanArray, Float32Array, Float64Array, GenericStringArray, OffsetSizeTrait,
PrimitiveArray,
};
use arrow_schema::{DataType, Schema};
use chrono::{TimeZone, Timelike};
Expand Down Expand Up @@ -107,6 +108,74 @@ macro_rules! cast_utf8_to_timestamp {
}};
}

/// Casts a floating-point array (`Float32Array`/`Float64Array`) to a string
/// array using Spark-compatible formatting rules.
///
/// Expands to a call of an inner generic `cast` function so the same body
/// serves both `Utf8` (i32 offsets) and `LargeUtf8` (i64 offsets) outputs.
///
/// - `$from`: `&dyn Array` input (must downcast to `$output_type`)
/// - `$eval_mode`: `EvalMode` (currently unused; float-to-string cannot fail)
/// - `$type`: element type (`f32` or `f64`)
/// - `$output_type`: concrete arrow array type (`Float32Array`/`Float64Array`)
/// - `$offset_type`: string offset type (`i32` or `i64`)
macro_rules! cast_float_to_string {
    ($from:expr, $eval_mode:expr, $type:ty, $output_type:ty, $offset_type:ty) => {{
        fn cast<OffsetSize>(
            from: &dyn Array,
            _eval_mode: EvalMode,
        ) -> CometResult<ArrayRef>
        where
            OffsetSize: OffsetSizeTrait,
        {
            let array = from.as_any().downcast_ref::<$output_type>().unwrap();

            // If the absolute number is less than 10,000,000 and greater or equal than 0.001, the
            // result is expressed without scientific notation with at least one digit on either side of
            // the decimal point. Otherwise, Spark uses a mantissa followed by E and an
            // exponent. The mantissa has an optional leading minus sign followed by one digit to the
            // left of the decimal point, and the minimal number of digits greater than zero to the
            // right. The exponent has an optional leading minus sign.
            // source: https://docs.databricks.com/en/sql/language-manual/functions/cast.html

            const LOWER_SCIENTIFIC_BOUND: $type = 0.001;
            const UPPER_SCIENTIFIC_BOUND: $type = 10000000.0;

            let output_array = array
                .iter()
                .map(|value| match value {
                    Some(value) if value == <$type>::INFINITY => Ok(Some("Infinity".to_string())),
                    Some(value) if value == <$type>::NEG_INFINITY => Ok(Some("-Infinity".to_string())),
                    Some(value)
                        if (value.abs() < UPPER_SCIENTIFIC_BOUND
                            && value.abs() >= LOWER_SCIENTIFIC_BOUND)
                            || value.abs() == 0.0 =>
                    {
                        // Spark always prints a decimal point, so integral values
                        // get a ".0" suffix (e.g. 1 -> "1.0").
                        let trailing_zero = if value.fract() == 0.0 { ".0" } else { "" };

                        Ok(Some(format!("{value}{trailing_zero}")))
                    }
                    Some(value)
                        if value.abs() >= UPPER_SCIENTIFIC_BOUND
                            || value.abs() < LOWER_SCIENTIFIC_BOUND =>
                    {
                        let formatted = format!("{value:E}");

                        if formatted.contains('.') {
                            Ok(Some(formatted))
                        } else {
                            // `formatted` is already in scientific notation and can be split up by E
                            // in order to add the missing trailing 0 which gets removed for numbers
                            // with a fraction of 0.0 (e.g. "1E7" -> "1.0E7").
                            let prepare_number: Vec<&str> = formatted.split('E').collect();

                            let coefficient = prepare_number[0];

                            let exponent = prepare_number[1];

                            Ok(Some(format!("{coefficient}.0E{exponent}")))
                        }
                    }
                    // Covers NaN, which Rust formats as "NaN" — matching Spark.
                    Some(value) => Ok(Some(value.to_string())),
                    _ => Ok(None),
                })
                .collect::<Result<GenericStringArray<OffsetSize>, CometError>>()?;

            Ok(Arc::new(output_array))
        }

        cast::<$offset_type>($from, $eval_mode)
    }};
}

impl Cast {
pub fn new(
child: Arc<dyn PhysicalExpr>,
Expand Down Expand Up @@ -185,6 +254,18 @@ impl Cast {
),
}
}
(DataType::Float64, DataType::Utf8) => {
Self::spark_cast_float64_to_utf8::<i32>(&array, self.eval_mode)?
}
(DataType::Float64, DataType::LargeUtf8) => {
Self::spark_cast_float64_to_utf8::<i64>(&array, self.eval_mode)?
}
(DataType::Float32, DataType::Utf8) => {
Self::spark_cast_float32_to_utf8::<i32>(&array, self.eval_mode)?
}
(DataType::Float32, DataType::LargeUtf8) => {
Self::spark_cast_float32_to_utf8::<i64>(&array, self.eval_mode)?
}
_ => {
// when we have no Spark-specific casting we delegate to DataFusion
cast_with_options(&array, to_type, &CAST_OPTIONS)?
Expand Down Expand Up @@ -248,6 +329,26 @@ impl Cast {
Ok(cast_array)
}

/// Spark-compatible cast of a `Float64Array` to a string array
/// (`Utf8` when `OffsetSize = i32`, `LargeUtf8` when `OffsetSize = i64`).
///
/// `_eval_mode` is accepted for signature consistency with the other
/// `spark_cast_*` helpers but is currently unused: this cast cannot fail.
fn spark_cast_float64_to_utf8<OffsetSize>(
from: &dyn Array,
_eval_mode: EvalMode,
) -> CometResult<ArrayRef>
where
OffsetSize: OffsetSizeTrait,
{
// Delegates to the shared macro with f64-specific type parameters.
cast_float_to_string!(from, _eval_mode, f64, Float64Array, OffsetSize)
}

/// Spark-compatible cast of a `Float32Array` to a string array
/// (`Utf8` when `OffsetSize = i32`, `LargeUtf8` when `OffsetSize = i64`).
///
/// `_eval_mode` is accepted for signature consistency with the other
/// `spark_cast_*` helpers but is currently unused: this cast cannot fail.
fn spark_cast_float32_to_utf8<OffsetSize>(
from: &dyn Array,
_eval_mode: EvalMode,
) -> CometResult<ArrayRef>
where
OffsetSize: OffsetSizeTrait,
{
// Delegates to the shared macro with f32-specific type parameters.
cast_float_to_string!(from, _eval_mode, f32, Float32Array, OffsetSize)
}

fn spark_cast_utf8_to_boolean<OffsetSize>(
from: &dyn Array,
eval_mode: EvalMode,
Expand Down
30 changes: 26 additions & 4 deletions spark/src/test/scala/org/apache/comet/CometCastSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -329,9 +329,22 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper {
castTest(generateFloats(), DataTypes.createDecimalType(10, 2))
}

ignore("cast FloatType to StringType") {
test("cast FloatType to StringType") {
// https://github.com/apache/datafusion-comet/issues/312
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should remove the references to the issue now that we are resolving the issue

Suggested change
// https://github.com/apache/datafusion-comet/issues/312

castTest(generateFloats(), DataTypes.StringType)
val r = new Random(0)
val values = Seq(
Float.MaxValue,
Float.MinValue,
Float.NaN,
Float.PositiveInfinity,
Float.NegativeInfinity,
1.0f,
-1.0f,
Short.MinValue.toFloat,
Short.MaxValue.toFloat,
0.0f) ++
Range(0, dataSize).map(_ => r.nextFloat())
withNulls(values).toDF("a")
}

ignore("cast FloatType to TimestampType") {
Expand Down Expand Up @@ -374,9 +387,18 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper {
castTest(generateDoubles(), DataTypes.createDecimalType(10, 2))
}

ignore("cast DoubleType to StringType") {
test("cast DoubleType to StringType") {
// https://github.com/apache/datafusion-comet/issues/312
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
// https://github.com/apache/datafusion-comet/issues/312

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We can remove these comments in a future PR

castTest(generateDoubles(), DataTypes.StringType)
val r = new Random(0)
val values = Seq(
Double.MaxValue,
Double.MinValue,
Double.NaN,
Double.PositiveInfinity,
Double.NegativeInfinity,
0.0d) ++
Range(0, dataSize).map(_ => r.nextDouble())
withNulls(values).toDF("a")
}

ignore("cast DoubleType to TimestampType") {
Expand Down
Loading