Skip to content

Commit

Permalink
feat: Implement Spark-compatible CAST from string to timestamp types (a…
Browse files Browse the repository at this point in the history
…pache#335)

* casting str to timestamp

* fix format

* fixing failed tests, using char as pattern

* bug fixes

* hangling microsecond

* make format

* bug fixes and core refactor

* format code

* removing print statements

* clippy error

* enabling cast timestamp test case

* code refactor

* comet spark test case

* adding all the supported format in test

* fallback spark when timestamp not utc

* bug fix

* bug fix

* adding an explainer commit

* fix test case

* bug fix

* bug fix

* better error handling for unwrap in fn parse_str_to_time_only_timestamp

* remove unwrap from macro

* improving error handling

* adding tests for invalid inputs

* removed all unwraps from timestamp cast functions

* code format
  • Loading branch information
vaibhawvipul authored and Steve Vaughan Jr committed May 2, 2024
1 parent b48f8f3 commit 8298695
Show file tree
Hide file tree
Showing 3 changed files with 391 additions and 8 deletions.
316 changes: 314 additions & 2 deletions core/src/execution/datafusion/expressions/cast.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ use std::{
use crate::errors::{CometError, CometResult};
use arrow::{
compute::{cast_with_options, CastOptions},
datatypes::TimestampMicrosecondType,
record_batch::RecordBatch,
util::display::FormatOptions,
};
Expand All @@ -33,10 +34,12 @@ use arrow_array::{
Array, ArrayRef, BooleanArray, GenericStringArray, OffsetSizeTrait, PrimitiveArray,
};
use arrow_schema::{DataType, Schema};
use chrono::{TimeZone, Timelike};
use datafusion::logical_expr::ColumnarValue;
use datafusion_common::{internal_err, Result as DataFusionResult, ScalarValue};
use datafusion_physical_expr::PhysicalExpr;
use num::{traits::CheckedNeg, CheckedSub, Integer, Num};
use regex::Regex;

use crate::execution::datafusion::expressions::utils::{
array_with_timezone, down_cast_any_ref, spark_cast,
Expand Down Expand Up @@ -86,6 +89,24 @@ macro_rules! cast_utf8_to_int {
}};
}

macro_rules! cast_utf8_to_timestamp {
($array:expr, $eval_mode:expr, $array_type:ty, $cast_method:ident) => {{
let len = $array.len();
let mut cast_array = PrimitiveArray::<$array_type>::builder(len).with_timezone("UTC");
for i in 0..len {
if $array.is_null(i) {
cast_array.append_null()
} else if let Ok(Some(cast_value)) = $cast_method($array.value(i).trim(), $eval_mode) {
cast_array.append_value(cast_value);
} else {
cast_array.append_null()
}
}
let result: ArrayRef = Arc::new(cast_array.finish()) as ArrayRef;
result
}};
}

impl Cast {
pub fn new(
child: Arc<dyn PhysicalExpr>,
Expand Down Expand Up @@ -125,6 +146,9 @@ impl Cast {
(DataType::LargeUtf8, DataType::Boolean) => {
Self::spark_cast_utf8_to_boolean::<i64>(&array, self.eval_mode)?
}
(DataType::Utf8, DataType::Timestamp(_, _)) => {
Self::cast_string_to_timestamp(&array, to_type, self.eval_mode)?
}
(
DataType::Utf8,
DataType::Int8 | DataType::Int16 | DataType::Int32 | DataType::Int64,
Expand Down Expand Up @@ -200,6 +224,30 @@ impl Cast {
Ok(cast_array)
}

fn cast_string_to_timestamp(
array: &ArrayRef,
to_type: &DataType,
eval_mode: EvalMode,
) -> CometResult<ArrayRef> {
let string_array = array
.as_any()
.downcast_ref::<GenericStringArray<i32>>()
.expect("Expected a string array");

let cast_array: ArrayRef = match to_type {
DataType::Timestamp(_, _) => {
cast_utf8_to_timestamp!(
string_array,
eval_mode,
TimestampMicrosecondType,
timestamp_parser
)
}
_ => unreachable!("Invalid data type {:?} in cast from string", to_type),
};
Ok(cast_array)
}

fn spark_cast_utf8_to_boolean<OffsetSize>(
from: &dyn Array,
eval_mode: EvalMode,
Expand Down Expand Up @@ -510,9 +558,273 @@ impl PhysicalExpr for Cast {
}
}

fn timestamp_parser(value: &str, eval_mode: EvalMode) -> CometResult<Option<i64>> {
let value = value.trim();
if value.is_empty() {
return Ok(None);
}
// Define regex patterns and corresponding parsing functions
let patterns = &[
(
Regex::new(r"^\d{4}$").unwrap(),
parse_str_to_year_timestamp as fn(&str) -> CometResult<Option<i64>>,
),
(
Regex::new(r"^\d{4}-\d{2}$").unwrap(),
parse_str_to_month_timestamp,
),
(
Regex::new(r"^\d{4}-\d{2}-\d{2}$").unwrap(),
parse_str_to_day_timestamp,
),
(
Regex::new(r"^\d{4}-\d{2}-\d{2}T\d{1,2}$").unwrap(),
parse_str_to_hour_timestamp,
),
(
Regex::new(r"^\d{4}-\d{2}-\d{2}T\d{2}:\d{2}$").unwrap(),
parse_str_to_minute_timestamp,
),
(
Regex::new(r"^\d{4}-\d{2}-\d{2}T\d{2}:\d{2}:\d{2}$").unwrap(),
parse_str_to_second_timestamp,
),
(
Regex::new(r"^\d{4}-\d{2}-\d{2}T\d{2}:\d{2}:\d{2}\.\d{1,6}$").unwrap(),
parse_str_to_microsecond_timestamp,
),
(
Regex::new(r"^T\d{1,2}$").unwrap(),
parse_str_to_time_only_timestamp,
),
];

let mut timestamp = None;

// Iterate through patterns and try matching
for (pattern, parse_func) in patterns {
if pattern.is_match(value) {
timestamp = parse_func(value)?;
break;
}
}

if timestamp.is_none() {
if eval_mode == EvalMode::Ansi {
return Err(CometError::CastInvalidValue {
value: value.to_string(),
from_type: "STRING".to_string(),
to_type: "TIMESTAMP".to_string(),
});
} else {
return Ok(None);
}
}

match timestamp {
Some(ts) => Ok(Some(ts)),
None => Err(CometError::Internal(
"Failed to parse timestamp".to_string(),
)),
}
}

fn parse_ymd_timestamp(year: i32, month: u32, day: u32) -> CometResult<Option<i64>> {
let datetime = chrono::Utc.with_ymd_and_hms(year, month, day, 0, 0, 0);

// Check if datetime is not None
let utc_datetime = match datetime.single() {
Some(dt) => dt.with_timezone(&chrono::Utc),
None => {
return Err(CometError::Internal(
"Failed to parse timestamp".to_string(),
));
}
};

Ok(Some(utc_datetime.timestamp_micros()))
}

fn parse_hms_timestamp(
year: i32,
month: u32,
day: u32,
hour: u32,
minute: u32,
second: u32,
microsecond: u32,
) -> CometResult<Option<i64>> {
let datetime = chrono::Utc.with_ymd_and_hms(year, month, day, hour, minute, second);

// Check if datetime is not None
let utc_datetime = match datetime.single() {
Some(dt) => dt
.with_timezone(&chrono::Utc)
.with_nanosecond(microsecond * 1000),
None => {
return Err(CometError::Internal(
"Failed to parse timestamp".to_string(),
));
}
};

let result = match utc_datetime {
Some(dt) => dt.timestamp_micros(),
None => {
return Err(CometError::Internal(
"Failed to parse timestamp".to_string(),
));
}
};

Ok(Some(result))
}

fn get_timestamp_values(value: &str, timestamp_type: &str) -> CometResult<Option<i64>> {
let values: Vec<_> = value
.split(|c| c == 'T' || c == '-' || c == ':' || c == '.')
.collect();
let year = values[0].parse::<i32>().unwrap_or_default();
let month = values.get(1).map_or(1, |m| m.parse::<u32>().unwrap_or(1));
let day = values.get(2).map_or(1, |d| d.parse::<u32>().unwrap_or(1));
let hour = values.get(3).map_or(0, |h| h.parse::<u32>().unwrap_or(0));
let minute = values.get(4).map_or(0, |m| m.parse::<u32>().unwrap_or(0));
let second = values.get(5).map_or(0, |s| s.parse::<u32>().unwrap_or(0));
let microsecond = values.get(6).map_or(0, |ms| ms.parse::<u32>().unwrap_or(0));

match timestamp_type {
"year" => parse_ymd_timestamp(year, 1, 1),
"month" => parse_ymd_timestamp(year, month, 1),
"day" => parse_ymd_timestamp(year, month, day),
"hour" => parse_hms_timestamp(year, month, day, hour, 0, 0, 0),
"minute" => parse_hms_timestamp(year, month, day, hour, minute, 0, 0),
"second" => parse_hms_timestamp(year, month, day, hour, minute, second, 0),
"microsecond" => parse_hms_timestamp(year, month, day, hour, minute, second, microsecond),
_ => Err(CometError::CastInvalidValue {
value: value.to_string(),
from_type: "STRING".to_string(),
to_type: "TIMESTAMP".to_string(),
}),
}
}

fn parse_str_to_year_timestamp(value: &str) -> CometResult<Option<i64>> {
get_timestamp_values(value, "year")
}

fn parse_str_to_month_timestamp(value: &str) -> CometResult<Option<i64>> {
get_timestamp_values(value, "month")
}

fn parse_str_to_day_timestamp(value: &str) -> CometResult<Option<i64>> {
get_timestamp_values(value, "day")
}

fn parse_str_to_hour_timestamp(value: &str) -> CometResult<Option<i64>> {
get_timestamp_values(value, "hour")
}

fn parse_str_to_minute_timestamp(value: &str) -> CometResult<Option<i64>> {
get_timestamp_values(value, "minute")
}

fn parse_str_to_second_timestamp(value: &str) -> CometResult<Option<i64>> {
get_timestamp_values(value, "second")
}

fn parse_str_to_microsecond_timestamp(value: &str) -> CometResult<Option<i64>> {
get_timestamp_values(value, "microsecond")
}

fn parse_str_to_time_only_timestamp(value: &str) -> CometResult<Option<i64>> {
let values: Vec<&str> = value.split('T').collect();
let time_values: Vec<u32> = values[1]
.split(':')
.map(|v| v.parse::<u32>().unwrap_or(0))
.collect();

let datetime = chrono::Utc::now();
let timestamp = datetime
.with_hour(time_values.first().copied().unwrap_or_default())
.and_then(|dt| dt.with_minute(*time_values.get(1).unwrap_or(&0)))
.and_then(|dt| dt.with_second(*time_values.get(2).unwrap_or(&0)))
.and_then(|dt| dt.with_nanosecond(*time_values.get(3).unwrap_or(&0) * 1_000))
.map(|dt| dt.to_utc().timestamp_micros())
.unwrap_or_default();

Ok(Some(timestamp))
}

#[cfg(test)]
mod test {
use super::{cast_string_to_i8, EvalMode};
mod tests {
use super::*;
use arrow::datatypes::TimestampMicrosecondType;
use arrow_array::StringArray;
use arrow_schema::TimeUnit;

#[test]
fn timestamp_parser_test() {
// write for all formats
assert_eq!(
timestamp_parser("2020", EvalMode::Legacy).unwrap(),
Some(1577836800000000) // this is in milliseconds
);
assert_eq!(
timestamp_parser("2020-01", EvalMode::Legacy).unwrap(),
Some(1577836800000000)
);
assert_eq!(
timestamp_parser("2020-01-01", EvalMode::Legacy).unwrap(),
Some(1577836800000000)
);
assert_eq!(
timestamp_parser("2020-01-01T12", EvalMode::Legacy).unwrap(),
Some(1577880000000000)
);
assert_eq!(
timestamp_parser("2020-01-01T12:34", EvalMode::Legacy).unwrap(),
Some(1577882040000000)
);
assert_eq!(
timestamp_parser("2020-01-01T12:34:56", EvalMode::Legacy).unwrap(),
Some(1577882096000000)
);
assert_eq!(
timestamp_parser("2020-01-01T12:34:56.123456", EvalMode::Legacy).unwrap(),
Some(1577882096123456)
);
// assert_eq!(
// timestamp_parser("T2", EvalMode::Legacy).unwrap(),
// Some(1714356000000000) // this value needs to change everyday.
// );
}

#[test]
fn test_cast_string_to_timestamp() {
let array: ArrayRef = Arc::new(StringArray::from(vec![
Some("2020-01-01T12:34:56.123456"),
Some("T2"),
]));

let string_array = array
.as_any()
.downcast_ref::<GenericStringArray<i32>>()
.expect("Expected a string array");

let eval_mode = EvalMode::Legacy;
let result = cast_utf8_to_timestamp!(
&string_array,
eval_mode,
TimestampMicrosecondType,
timestamp_parser
);

assert_eq!(
result.data_type(),
&DataType::Timestamp(TimeUnit::Microsecond, Some("UTC".into()))
);
assert_eq!(result.len(), 2);
}

#[test]
fn test_cast_string_as_i8() {
Expand Down
12 changes: 11 additions & 1 deletion spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
Original file line number Diff line number Diff line change
Expand Up @@ -585,6 +585,15 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde {
// Spark 3.4+ has EvalMode enum with values LEGACY, ANSI, and TRY
evalMode.toString
}

val supportedTimezone = (child.dataType, dt) match {
case (DataTypes.StringType, DataTypes.TimestampType)
if !timeZoneId.contains("UTC") =>
withInfo(expr, s"Unsupported timezone ${timeZoneId} for timestamp cast")
false
case _ => true
}

val supportedCast = (child.dataType, dt) match {
case (DataTypes.StringType, DataTypes.TimestampType)
if !CometConf.COMET_CAST_STRING_TO_TIMESTAMP.get() =>
Expand All @@ -593,7 +602,8 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde {
false
case _ => true
}
if (supportedCast) {

if (supportedCast && supportedTimezone) {
castToProto(timeZoneId, dt, childExpr, evalModeStr)
} else {
// no need to call withInfo here since it was called when determining
Expand Down
Loading

0 comments on commit 8298695

Please sign in to comment.