feat: Implement basic version of string to float/double/decimal #870

Merged · 5 commits · Aug 28, 2024
3 changes: 3 additions & 0 deletions docs/source/user-guide/compatibility.md
@@ -136,6 +136,9 @@ The following cast operations are not compatible with Spark for all inputs and a
|-|-|-|
| integer | decimal | No overflow check |
| long | decimal | No overflow check |
+ | string | float | Does not support inputs ending with 'd' or 'f'. Does not support 'inf'. Does not support ANSI mode. |
+ | string | double | Does not support inputs ending with 'd' or 'f'. Does not support 'inf'. Does not support ANSI mode. |
+ | string | decimal | Does not support inputs ending with 'd' or 'f'. Does not support 'inf'. Does not support ANSI mode. Returns 0.0 instead of null if input contains no digits |
| string | timestamp | Not all valid formats are supported |
| binary | string | Only works for binary data representing valid UTF-8 strings |
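
As a rough illustration of why these three casts are flagged: Spark tolerates a trailing 'f'/'d' type suffix that plain float parsing rejects, while plain float parsing accepts 'inf', which Spark treats differently. A sketch using Rust's standard parser (a stand-in for intuition only, not the parser Comet actually calls):

```rust
fn main() {
    // Spark accepts a trailing type suffix such as "1.5f"; a plain float
    // parser rejects it, hence "does not support inputs ending with 'd' or 'f'".
    assert!("1.5f".parse::<f32>().is_err());

    // A plain float parser accepts "inf", but Spark handles that input
    // differently, hence "does not support 'inf'".
    assert!("inf".parse::<f32>().unwrap().is_infinite());
}
```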

12 changes: 10 additions & 2 deletions native/core/src/execution/datafusion/planner.rs
@@ -382,8 +382,13 @@ impl PhysicalPlanner {
let datatype = to_arrow_datatype(expr.datatype.as_ref().unwrap());
let timezone = expr.timezone.clone();
let eval_mode = from_protobuf_eval_mode(expr.eval_mode)?;

- Ok(Arc::new(Cast::new(child, datatype, eval_mode, timezone)))
+ Ok(Arc::new(Cast::new(
+     child,
+     datatype,
+     eval_mode,
+     timezone,
+     expr.allow_incompat,
+ )))
}
ExprStruct::Hour(expr) => {
let child = self.create_expr(expr.child.as_ref().unwrap(), input_schema)?;
@@ -723,17 +728,20 @@
left,
DataType::Decimal256(p1, s1),
EvalMode::Legacy,
+     false,
));
let right = Arc::new(Cast::new_without_timezone(
right,
DataType::Decimal256(p2, s2),
EvalMode::Legacy,
+     false,
));
let child = Arc::new(BinaryExpr::new(left, op, right));
Ok(Arc::new(Cast::new_without_timezone(
child,
data_type,
EvalMode::Legacy,
+     false,
)))
}
(
4 changes: 2 additions & 2 deletions native/proto/src/proto/expr.proto
@@ -254,8 +254,8 @@ message Cast {
Expr child = 1;
DataType datatype = 2;
string timezone = 3;
- EvalMode eval_mode = 4;

+ EvalMode eval_mode = 4;
+ bool allow_incompat = 5;
}

message Equal {
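Because proto3 scalar fields default to their zero value, plans serialized before this change decode with allow_incompat == false, preserving the old compatible-only behavior. A hypothetical check against the prost-generated struct (the module name spark_expression is an assumption, not taken from this diff):

```rust
// Hypothetical: the new field absent on the wire decodes to the proto3
// zero value, so older plans remain compatible-only.
let cast = spark_expression::Cast::default();
assert!(!cast.allow_incompat);
```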
5 changes: 4 additions & 1 deletion native/spark-expr/benches/cast_from_string.rs
@@ -31,20 +31,23 @@ fn criterion_benchmark(c: &mut Criterion) {
DataType::Int8,
EvalMode::Legacy,
timezone.clone(),
+     false,
);
let cast_string_to_i16 = Cast::new(
expr.clone(),
DataType::Int16,
EvalMode::Legacy,
timezone.clone(),
+     false,
);
let cast_string_to_i32 = Cast::new(
expr.clone(),
DataType::Int32,
EvalMode::Legacy,
timezone.clone(),
+     false,
);
- let cast_string_to_i64 = Cast::new(expr, DataType::Int64, EvalMode::Legacy, timezone);
+ let cast_string_to_i64 = Cast::new(expr, DataType::Int64, EvalMode::Legacy, timezone, false);

let mut group = c.benchmark_group("cast_string_to_int");
group.bench_function("cast_string_to_i8", |b| {
4 changes: 3 additions & 1 deletion native/spark-expr/benches/cast_numeric.rs
@@ -31,14 +31,16 @@ fn criterion_benchmark(c: &mut Criterion) {
DataType::Int8,
EvalMode::Legacy,
timezone.clone(),
+     false,
);
let cast_i32_to_i16 = Cast::new(
expr.clone(),
DataType::Int16,
EvalMode::Legacy,
timezone.clone(),
+     false,
);
- let cast_i32_to_i64 = Cast::new(expr, DataType::Int64, EvalMode::Legacy, timezone);
+ let cast_i32_to_i64 = Cast::new(expr, DataType::Int64, EvalMode::Legacy, timezone, false);

let mut group = c.benchmark_group("cast_int_to_int");
group.bench_function("cast_i32_to_i8", |b| {
44 changes: 40 additions & 4 deletions native/spark-expr/src/cast.rs
@@ -142,6 +142,9 @@ pub struct Cast {
/// When cast from/to timezone related types, we need timezone, which will be resolved with
/// session local timezone by an analyzer in Spark.
pub timezone: String,

+ /// Whether to allow casts that are known to be incompatible with Spark
+ pub allow_incompat: bool,
Review comment (Contributor): let's have a comment on this field

Reply (Member Author): added
}

macro_rules! cast_utf8_to_int {
@@ -545,25 +548,29 @@
data_type: DataType,
eval_mode: EvalMode,
timezone: String,
+     allow_incompat: bool,
) -> Self {
Self {
child,
data_type,
timezone,
eval_mode,
+     allow_incompat,
}
}

pub fn new_without_timezone(
child: Arc<dyn PhysicalExpr>,
data_type: DataType,
eval_mode: EvalMode,
+     allow_incompat: bool,
) -> Self {
Self {
child,
data_type,
timezone: "".to_string(),
eval_mode,
+     allow_incompat,
}
}
}
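
Both constructors now take the flag as their final argument. A minimal sketch of the new call shapes (crate path assumed from the benchmark imports; argument types and values are illustrative):

```rust
use std::sync::Arc;

use arrow::datatypes::DataType;
use datafusion_physical_expr::PhysicalExpr;
// Assumed import path, mirroring the benchmarks in this PR.
use datafusion_comet_spark_expr::{Cast, EvalMode};

// allow_incompat is now the last parameter of both constructors.
fn build_casts(child: Arc<dyn PhysicalExpr>) -> (Cast, Cast) {
    let with_tz = Cast::new(
        Arc::clone(&child),
        DataType::Int64,
        EvalMode::Legacy,
        "UTC".to_string(),
        false, // keep Spark-compatible behavior only
    );
    let without_tz = Cast::new_without_timezone(
        child,
        DataType::Decimal256(38, 10),
        EvalMode::Legacy,
        false,
    );
    (with_tz, without_tz)
}
```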
@@ -576,21 +583,29 @@ pub fn spark_cast(
data_type: &DataType,
eval_mode: EvalMode,
timezone: String,
+     allow_incompat: bool,
) -> DataFusionResult<ColumnarValue> {
match arg {
ColumnarValue::Array(array) => Ok(ColumnarValue::Array(cast_array(
array,
data_type,
eval_mode,
timezone.to_owned(),
+     allow_incompat,
)?)),
ColumnarValue::Scalar(scalar) => {
// Note that normally CAST(scalar) should be fold in Spark JVM side. However, for
// some cases e.g., scalar subquery, Spark will not fold it, so we need to handle it
// here.
let array = scalar.to_array()?;
let scalar = ScalarValue::try_from_array(
-                 &cast_array(array, data_type, eval_mode, timezone.to_owned())?,
+                 &cast_array(
+                     array,
+                     data_type,
+                     eval_mode,
+                     timezone.to_owned(),
+                     allow_incompat,
+                 )?,
0,
)?;
Ok(ColumnarValue::Scalar(scalar))
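
The scalar branch above in miniature: the scalar is promoted to a one-element array, cast, and row 0 is extracted back out. A self-contained sketch of that round trip, with the cast step itself elided:

```rust
use datafusion_common::{Result, ScalarValue};

fn main() -> Result<()> {
    let scalar = ScalarValue::Utf8(Some("123".to_string()));
    // Promote to a one-element ArrayRef, as spark_cast does before cast_array.
    let array = scalar.to_array()?;
    // The real code casts the array here; this sketch skips that step.
    // Pull row 0 back out, mirroring the ScalarValue::try_from_array call.
    let roundtrip = ScalarValue::try_from_array(&array, 0)?;
    assert_eq!(scalar, roundtrip);
    Ok(())
}
```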
@@ -603,6 +618,7 @@ fn cast_array(
to_type: &DataType,
eval_mode: EvalMode,
timezone: String,
+     allow_incompat: bool,
) -> DataFusionResult<ArrayRef> {
let array = array_with_timezone(array, timezone.clone(), Some(to_type))?;
let from_type = array.data_type().clone();
@@ -624,6 +640,7 @@
to_type,
eval_mode,
timezone,
+     allow_incompat,
)?,
);

@@ -693,7 +710,7 @@
{
spark_cast_nonintegral_numeric_to_integral(&array, eval_mode, from_type, to_type)
}
- _ if is_datafusion_spark_compatible(from_type, to_type) => {
+ _ if is_datafusion_spark_compatible(from_type, to_type, allow_incompat) => {
// use DataFusion cast only when we know that it is compatible with Spark
Ok(cast_with_options(&array, to_type, &CAST_OPTIONS)?)
}
@@ -711,7 +728,11 @@

/// Determines if DataFusion supports the given cast in a way that is
/// compatible with Spark
Review comment (Contributor): nit: Shall we update the documentation comment to reflect the addition of the allow_incompat parameter?

Reply (Member Author): allow_incompat is an internal API so I don't think we need to add anything to the user guide. We do already have documentation for the spark.comet.cast.allowIncompatible config, which is used to populate allow_incompat.

- fn is_datafusion_spark_compatible(from_type: &DataType, to_type: &DataType) -> bool {
+ fn is_datafusion_spark_compatible(
+     from_type: &DataType,
+     to_type: &DataType,
+     allow_incompat: bool,
+ ) -> bool {
if from_type == to_type {
return true;
}
@@ -764,6 +785,10 @@ fn is_datafusion_spark_compatible(from_type: &DataType, to_type: &DataType) -> b
| DataType::Decimal128(_, _)
| DataType::Decimal256(_, _)
),
+ DataType::Utf8 if allow_incompat => matches!(
+     to_type,
+     DataType::Binary | DataType::Float32 | DataType::Float64 | DataType::Decimal128(_, _)
+ ),
DataType::Utf8 => matches!(to_type, DataType::Binary),
DataType::Date32 => matches!(to_type, DataType::Utf8),
DataType::Timestamp(_, _) => {
@@ -1385,7 +1410,13 @@

fn evaluate(&self, batch: &RecordBatch) -> DataFusionResult<ColumnarValue> {
let arg = self.child.evaluate(batch)?;
-         spark_cast(arg, &self.data_type, self.eval_mode, self.timezone.clone())
+         spark_cast(
+             arg,
+             &self.data_type,
+             self.eval_mode,
+             self.timezone.clone(),
+             self.allow_incompat,
+         )
}

fn children(&self) -> Vec<&Arc<dyn PhysicalExpr>> {
@@ -1402,6 +1433,7 @@
self.data_type.clone(),
self.eval_mode,
self.timezone.clone(),
+     self.allow_incompat,
))),
_ => internal_err!("Cast should have exactly one child"),
}
@@ -1413,6 +1445,7 @@
self.data_type.hash(&mut s);
self.timezone.hash(&mut s);
self.eval_mode.hash(&mut s);
+     self.allow_incompat.hash(&mut s);
self.hash(&mut s);
}
}
@@ -1996,6 +2029,7 @@ mod tests {
&DataType::Timestamp(TimeUnit::Microsecond, Some(timezone.clone().into())),
EvalMode::Legacy,
timezone.clone(),
+     false,
)?;
assert_eq!(
*result.data_type(),
@@ -2205,6 +2239,7 @@
&DataType::Date32,
EvalMode::Legacy,
"UTC".to_owned(),
+     false,
);
assert!(result.is_err())
}
@@ -2217,6 +2252,7 @@
&DataType::Date32,
EvalMode::Legacy,
"Not a valid timezone".to_owned(),
+     false,
);
assert!(result.is_err())
}
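
Putting the pieces together, an end-to-end sketch of evaluating the new cast over a record batch. This is not code from the PR; the Comet crate path is assumed from the benchmark imports:

```rust
use std::sync::Arc;

use arrow::array::{ArrayRef, StringArray};
use arrow::datatypes::{DataType, Field, Schema};
use arrow::record_batch::RecordBatch;
use datafusion_common::Result;
use datafusion_physical_expr::{expressions::Column, PhysicalExpr};
// Assumed import path, mirroring the benchmarks in this PR.
use datafusion_comet_spark_expr::{Cast, EvalMode};

fn main() -> Result<()> {
    let schema = Arc::new(Schema::new(vec![Field::new("c0", DataType::Utf8, true)]));
    let array: ArrayRef = Arc::new(StringArray::from(vec![
        Some("1.25"),
        Some("not a number"),
        None,
    ]));
    let batch = RecordBatch::try_new(schema, vec![array])?;

    // allow_incompat = true opts in to string -> double, which this PR marks
    // Incompatible because inputs like "1.5d" or "inf" diverge from Spark.
    let cast = Cast::new(
        Arc::new(Column::new("c0", 0)),
        DataType::Float64,
        EvalMode::Legacy,
        "UTC".to_string(),
        true,
    );
    let result = cast.evaluate(&batch)?;
    println!("{result:?}");
    Ok(())
}
```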
@@ -113,10 +113,15 @@ object CometCast {
Compatible()
case DataTypes.FloatType | DataTypes.DoubleType =>
// https://github.com/apache/datafusion-comet/issues/326
-       Unsupported
+       Incompatible(
+         Some(
+           "Does not support inputs ending with 'd' or 'f'. Does not support 'inf'. " +
+             "Does not support ANSI mode."))
case _: DecimalType =>
// https://github.com/apache/datafusion-comet/issues/325
-       Unsupported
+       Incompatible(
+         Some("Does not support inputs ending with 'd' or 'f'. Does not support 'inf'. " +
+           "Does not support ANSI mode. Returns 0.0 instead of null if input contains no digits"))
case DataTypes.DateType =>
// https://github.com/apache/datafusion-comet/issues/327
Compatible(Some("Only supports years between 262143 BC and 262142 AD"))
@@ -792,7 +792,7 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim
castBuilder.setChild(childExpr.get)
castBuilder.setDatatype(dataType.get)
castBuilder.setEvalMode(evalModeToProto(evalMode))

+ castBuilder.setAllowIncompat(CometConf.COMET_CAST_ALLOW_INCOMPATIBLE.get())
val timeZone = timeZoneId.getOrElse("UTC")
castBuilder.setTimezone(timeZone)

@@ -1506,6 +1506,7 @@
.setChild(e)
.setDatatype(serializeDataType(IntegerType).get)
.setEvalMode(ExprOuterClass.EvalMode.LEGACY)
+     .setAllowIncompat(false)
.build())
.build()
})
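
On the JVM side, the flag serialized in the first hunk mirrors the spark.comet.cast.allowIncompatible session config (read here through CometConf.COMET_CAST_ALLOW_INCOMPATIBLE), while the internally generated cast in the second hunk pins it to false. A user opts in with the existing config, for example:

```
spark.comet.cast.allowIncompatible=true
```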