Skip to content

Commit

Permalink
Add StringView support for date_part and make_date funcs
Browse files Browse the repository at this point in the history
  • Loading branch information
a10y committed Jul 20, 2024
1 parent efcf5c6 commit 580e64b
Show file tree
Hide file tree
Showing 9 changed files with 124 additions and 41 deletions.
24 changes: 12 additions & 12 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -157,15 +157,15 @@ unused_imports = "deny"
## Temporary arrow-rs patch until 52.2.0 is released

[patch.crates-io]
arrow = { git = "https://github.com/apache/arrow-rs.git", rev = "66390ff8ec15bb6ed585f353f67a19574da4375a" }
arrow-array = { git = "https://github.com/apache/arrow-rs.git", rev = "66390ff8ec15bb6ed585f353f67a19574da4375a" }
arrow-buffer = { git = "https://github.com/apache/arrow-rs.git", rev = "66390ff8ec15bb6ed585f353f67a19574da4375a" }
arrow-cast = { git = "https://github.com/apache/arrow-rs.git", rev = "66390ff8ec15bb6ed585f353f67a19574da4375a" }
arrow-data = { git = "https://github.com/apache/arrow-rs.git", rev = "66390ff8ec15bb6ed585f353f67a19574da4375a" }
arrow-ipc = { git = "https://github.com/apache/arrow-rs.git", rev = "66390ff8ec15bb6ed585f353f67a19574da4375a" }
arrow-schema = { git = "https://github.com/apache/arrow-rs.git", rev = "66390ff8ec15bb6ed585f353f67a19574da4375a" }
arrow-select = { git = "https://github.com/apache/arrow-rs.git", rev = "66390ff8ec15bb6ed585f353f67a19574da4375a" }
arrow-string = { git = "https://github.com/apache/arrow-rs.git", rev = "66390ff8ec15bb6ed585f353f67a19574da4375a" }
arrow-ord = { git = "https://github.com/apache/arrow-rs.git", rev = "66390ff8ec15bb6ed585f353f67a19574da4375a" }
arrow-flight = { git = "https://github.com/apache/arrow-rs.git", rev = "66390ff8ec15bb6ed585f353f67a19574da4375a" }
parquet = { git = "https://github.com/apache/arrow-rs.git", rev = "66390ff8ec15bb6ed585f353f67a19574da4375a" }
arrow = { git = "https://github.com/apache/arrow-rs.git", rev = "8a5be1330e30e6dd7760dba910737550d760e612" }
arrow-array = { git = "https://github.com/apache/arrow-rs.git", rev = "8a5be1330e30e6dd7760dba910737550d760e612" }
arrow-buffer = { git = "https://github.com/apache/arrow-rs.git", rev = "8a5be1330e30e6dd7760dba910737550d760e612" }
arrow-cast = { git = "https://github.com/apache/arrow-rs.git", rev = "8a5be1330e30e6dd7760dba910737550d760e612" }
arrow-data = { git = "https://github.com/apache/arrow-rs.git", rev = "8a5be1330e30e6dd7760dba910737550d760e612" }
arrow-ipc = { git = "https://github.com/apache/arrow-rs.git", rev = "8a5be1330e30e6dd7760dba910737550d760e612" }
arrow-schema = { git = "https://github.com/apache/arrow-rs.git", rev = "8a5be1330e30e6dd7760dba910737550d760e612" }
arrow-select = { git = "https://github.com/apache/arrow-rs.git", rev = "8a5be1330e30e6dd7760dba910737550d760e612" }
arrow-string = { git = "https://github.com/apache/arrow-rs.git", rev = "8a5be1330e30e6dd7760dba910737550d760e612" }
arrow-ord = { git = "https://github.com/apache/arrow-rs.git", rev = "8a5be1330e30e6dd7760dba910737550d760e612" }
arrow-flight = { git = "https://github.com/apache/arrow-rs.git", rev = "8a5be1330e30e6dd7760dba910737550d760e612" }
parquet = { git = "https://github.com/apache/arrow-rs.git", rev = "8a5be1330e30e6dd7760dba910737550d760e612" }
24 changes: 12 additions & 12 deletions datafusion-cli/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -64,15 +64,15 @@ predicates = "3.0"
rstest = "0.17"

[patch.crates-io]
arrow = { git = "https://github.com/apache/arrow-rs.git", rev = "66390ff8ec15bb6ed585f353f67a19574da4375a" }
arrow-array = { git = "https://github.com/apache/arrow-rs.git", rev = "66390ff8ec15bb6ed585f353f67a19574da4375a" }
arrow-buffer = { git = "https://github.com/apache/arrow-rs.git", rev = "66390ff8ec15bb6ed585f353f67a19574da4375a" }
arrow-cast = { git = "https://github.com/apache/arrow-rs.git", rev = "66390ff8ec15bb6ed585f353f67a19574da4375a" }
arrow-data = { git = "https://github.com/apache/arrow-rs.git", rev = "66390ff8ec15bb6ed585f353f67a19574da4375a" }
arrow-ipc = { git = "https://github.com/apache/arrow-rs.git", rev = "66390ff8ec15bb6ed585f353f67a19574da4375a" }
arrow-schema = { git = "https://github.com/apache/arrow-rs.git", rev = "66390ff8ec15bb6ed585f353f67a19574da4375a" }
arrow-select = { git = "https://github.com/apache/arrow-rs.git", rev = "66390ff8ec15bb6ed585f353f67a19574da4375a" }
arrow-string = { git = "https://github.com/apache/arrow-rs.git", rev = "66390ff8ec15bb6ed585f353f67a19574da4375a" }
arrow-ord = { git = "https://github.com/apache/arrow-rs.git", rev = "66390ff8ec15bb6ed585f353f67a19574da4375a" }
arrow-flight = { git = "https://github.com/apache/arrow-rs.git", rev = "66390ff8ec15bb6ed585f353f67a19574da4375a" }
parquet = { git = "https://github.com/apache/arrow-rs.git", rev = "66390ff8ec15bb6ed585f353f67a19574da4375a" }
arrow = { git = "https://github.com/apache/arrow-rs.git", rev = "8a5be1330e30e6dd7760dba910737550d760e612" }
arrow-array = { git = "https://github.com/apache/arrow-rs.git", rev = "8a5be1330e30e6dd7760dba910737550d760e612" }
arrow-buffer = { git = "https://github.com/apache/arrow-rs.git", rev = "8a5be1330e30e6dd7760dba910737550d760e612" }
arrow-cast = { git = "https://github.com/apache/arrow-rs.git", rev = "8a5be1330e30e6dd7760dba910737550d760e612" }
arrow-data = { git = "https://github.com/apache/arrow-rs.git", rev = "8a5be1330e30e6dd7760dba910737550d760e612" }
arrow-ipc = { git = "https://github.com/apache/arrow-rs.git", rev = "8a5be1330e30e6dd7760dba910737550d760e612" }
arrow-schema = { git = "https://github.com/apache/arrow-rs.git", rev = "8a5be1330e30e6dd7760dba910737550d760e612" }
arrow-select = { git = "https://github.com/apache/arrow-rs.git", rev = "8a5be1330e30e6dd7760dba910737550d760e612" }
arrow-string = { git = "https://github.com/apache/arrow-rs.git", rev = "8a5be1330e30e6dd7760dba910737550d760e612" }
arrow-ord = { git = "https://github.com/apache/arrow-rs.git", rev = "8a5be1330e30e6dd7760dba910737550d760e612" }
arrow-flight = { git = "https://github.com/apache/arrow-rs.git", rev = "8a5be1330e30e6dd7760dba910737550d760e612" }
parquet = { git = "https://github.com/apache/arrow-rs.git", rev = "8a5be1330e30e6dd7760dba910737550d760e612" }
26 changes: 14 additions & 12 deletions datafusion/expr/src/type_coercion/binary.rs
Original file line number Diff line number Diff line change
Expand Up @@ -527,7 +527,7 @@ fn string_numeric_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option<D
}

/// Coerce `lhs_type` and `rhs_type` to a common type for the purposes of a comparison operation
/// where one is temporal and one is `Utf8`/`LargeUtf8`.
/// where one is temporal and one is `Utf8View`/`Utf8`/`LargeUtf8`.
///
/// Note this cannot be performed in case of arithmetic as there is insufficient information
/// to correctly determine the type of argument. Consider
Expand All @@ -547,19 +547,21 @@ fn string_temporal_coercion(

fn match_rule(l: &DataType, r: &DataType) -> Option<DataType> {
match (l, r) {
// Coerce Utf8/LargeUtf8 to Date32/Date64/Time32/Time64/Timestamp
(Utf8, temporal) | (LargeUtf8, temporal) => match temporal {
Date32 | Date64 => Some(temporal.clone()),
Time32(_) | Time64(_) => {
if is_time_with_valid_unit(temporal.to_owned()) {
Some(temporal.to_owned())
} else {
None
// Coerce Utf8View/Utf8/LargeUtf8 to Date32/Date64/Time32/Time64/Timestamp
(Utf8, temporal) | (LargeUtf8, temporal) | (Utf8View, temporal) => {
match temporal {
Date32 | Date64 => Some(temporal.clone()),
Time32(_) | Time64(_) => {
if is_time_with_valid_unit(temporal.to_owned()) {
Some(temporal.to_owned())
} else {
None
}
}
Timestamp(_, tz) => Some(Timestamp(TimeUnit::Nanosecond, tz.clone())),
_ => None,
}
Timestamp(_, tz) => Some(Timestamp(TimeUnit::Nanosecond, tz.clone())),
_ => None,
},
}
_ => None,
}
}
Expand Down
30 changes: 29 additions & 1 deletion datafusion/functions/src/datetime/date_part.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ use std::sync::Arc;
use arrow::array::{Array, ArrayRef, Float64Array};
use arrow::compute::{binary, cast, date_part, DatePart};
use arrow::datatypes::DataType::{
Date32, Date64, Float64, Time32, Time64, Timestamp, Utf8,
Date32, Date64, Float64, Time32, Time64, Timestamp, Utf8, Utf8View,
};
use arrow::datatypes::TimeUnit::{Microsecond, Millisecond, Nanosecond, Second};
use arrow::datatypes::{DataType, TimeUnit};
Expand Down Expand Up @@ -56,31 +56,57 @@ impl DatePartFunc {
signature: Signature::one_of(
vec![
Exact(vec![Utf8, Timestamp(Nanosecond, None)]),
Exact(vec![Utf8View, Timestamp(Nanosecond, None)]),
Exact(vec![
Utf8,
Timestamp(Nanosecond, Some(TIMEZONE_WILDCARD.into())),
]),
Exact(vec![
Utf8View,
Timestamp(Nanosecond, Some(TIMEZONE_WILDCARD.into())),
]),
Exact(vec![Utf8, Timestamp(Millisecond, None)]),
Exact(vec![Utf8View, Timestamp(Millisecond, None)]),
Exact(vec![
Utf8,
Timestamp(Millisecond, Some(TIMEZONE_WILDCARD.into())),
]),
Exact(vec![
Utf8View,
Timestamp(Millisecond, Some(TIMEZONE_WILDCARD.into())),
]),
Exact(vec![Utf8, Timestamp(Microsecond, None)]),
Exact(vec![Utf8View, Timestamp(Microsecond, None)]),
Exact(vec![
Utf8,
Timestamp(Microsecond, Some(TIMEZONE_WILDCARD.into())),
]),
Exact(vec![
Utf8View,
Timestamp(Microsecond, Some(TIMEZONE_WILDCARD.into())),
]),
Exact(vec![Utf8, Timestamp(Second, None)]),
Exact(vec![Utf8View, Timestamp(Second, None)]),
Exact(vec![
Utf8,
Timestamp(Second, Some(TIMEZONE_WILDCARD.into())),
]),
Exact(vec![
Utf8View,
Timestamp(Second, Some(TIMEZONE_WILDCARD.into())),
]),
Exact(vec![Utf8, Date64]),
Exact(vec![Utf8View, Date64]),
Exact(vec![Utf8, Date32]),
Exact(vec![Utf8View, Date32]),
Exact(vec![Utf8, Time32(Second)]),
Exact(vec![Utf8View, Time32(Second)]),
Exact(vec![Utf8, Time32(Millisecond)]),
Exact(vec![Utf8View, Time32(Millisecond)]),
Exact(vec![Utf8, Time64(Microsecond)]),
Exact(vec![Utf8View, Time64(Microsecond)]),
Exact(vec![Utf8, Time64(Nanosecond)]),
Exact(vec![Utf8View, Time64(Nanosecond)]),
],
Volatility::Immutable,
),
Expand Down Expand Up @@ -114,6 +140,8 @@ impl ScalarUDFImpl for DatePartFunc {

let part = if let ColumnarValue::Scalar(ScalarValue::Utf8(Some(v))) = part {
v
} else if let ColumnarValue::Scalar(ScalarValue::Utf8View(Some(v))) = part {
v
} else {
return exec_err!(
"First argument of `DATE_PART` must be non-null scalar Utf8"
Expand Down
25 changes: 24 additions & 1 deletion datafusion/functions/src/datetime/date_trunc.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ use arrow::array::types::{
TimestampNanosecondType, TimestampSecondType,
};
use arrow::array::{Array, PrimitiveArray};
use arrow::datatypes::DataType::{self, Null, Timestamp, Utf8};
use arrow::datatypes::DataType::{self, Null, Timestamp, Utf8, Utf8View};
use arrow::datatypes::TimeUnit::{self, Microsecond, Millisecond, Nanosecond, Second};
use datafusion_common::cast::as_primitive_array;
use datafusion_common::{exec_err, plan_err, DataFusionError, Result, ScalarValue};
Expand Down Expand Up @@ -61,25 +61,45 @@ impl DateTruncFunc {
signature: Signature::one_of(
vec![
Exact(vec![Utf8, Timestamp(Nanosecond, None)]),
Exact(vec![Utf8View, Timestamp(Nanosecond, None)]),
Exact(vec![
Utf8,
Timestamp(Nanosecond, Some(TIMEZONE_WILDCARD.into())),
]),
Exact(vec![
Utf8View,
Timestamp(Nanosecond, Some(TIMEZONE_WILDCARD.into())),
]),
Exact(vec![Utf8, Timestamp(Microsecond, None)]),
Exact(vec![Utf8View, Timestamp(Microsecond, None)]),
Exact(vec![
Utf8,
Timestamp(Microsecond, Some(TIMEZONE_WILDCARD.into())),
]),
Exact(vec![
Utf8View,
Timestamp(Microsecond, Some(TIMEZONE_WILDCARD.into())),
]),
Exact(vec![Utf8, Timestamp(Millisecond, None)]),
Exact(vec![Utf8View, Timestamp(Millisecond, None)]),
Exact(vec![
Utf8,
Timestamp(Millisecond, Some(TIMEZONE_WILDCARD.into())),
]),
Exact(vec![
Utf8View,
Timestamp(Millisecond, Some(TIMEZONE_WILDCARD.into())),
]),
Exact(vec![Utf8, Timestamp(Second, None)]),
Exact(vec![Utf8View, Timestamp(Second, None)]),
Exact(vec![
Utf8,
Timestamp(Second, Some(TIMEZONE_WILDCARD.into())),
]),
Exact(vec![
Utf8View,
Timestamp(Second, Some(TIMEZONE_WILDCARD.into())),
]),
],
Volatility::Immutable,
),
Expand Down Expand Up @@ -121,6 +141,9 @@ impl ScalarUDFImpl for DateTruncFunc {
granularity
{
v.to_lowercase()
} else if let ColumnarValue::Scalar(ScalarValue::Utf8View(Some(v))) =
granularity {
v.to_lowercase()
} else {
return exec_err!("Granularity of `date_trunc` must be non-null scalar Utf8");
};
Expand Down
4 changes: 2 additions & 2 deletions datafusion/functions/src/datetime/make_date.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ use arrow::array::cast::AsArray;
use arrow::array::types::{Date32Type, Int32Type};
use arrow::array::PrimitiveArray;
use arrow::datatypes::DataType;
use arrow::datatypes::DataType::{Date32, Int32, Int64, UInt32, UInt64, Utf8};
use arrow::datatypes::DataType::{Date32, Int32, Int64, UInt32, UInt64, Utf8, Utf8View};
use chrono::prelude::*;

use datafusion_common::{exec_err, Result, ScalarValue};
Expand All @@ -45,7 +45,7 @@ impl MakeDateFunc {
Self {
signature: Signature::uniform(
3,
vec![Int32, Int64, UInt32, UInt64, Utf8],
vec![Int32, Int64, UInt32, UInt64, Utf8, Utf8View],
Volatility::Immutable,
),
}
Expand Down
8 changes: 7 additions & 1 deletion datafusion/functions/src/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,14 @@
// specific language governing permissions and limitations
// under the License.

use std::sync::Arc;

use arrow::array::ArrayRef;
use arrow::datatypes::DataType;

use datafusion_common::{Result, ScalarValue};
use datafusion_expr::function::Hint;
use datafusion_expr::{ColumnarValue, ScalarFunctionImplementation};
use std::sync::Arc;

/// Creates a function to identify the optimal return type of a string function given
/// the type of its first argument.
Expand All @@ -29,6 +31,8 @@ use std::sync::Arc;
/// `$largeUtf8Type`,
///
/// If the input type is `Utf8` or `Binary` the return type is `$utf8Type`,
///
/// If the input type is `Utf8View` the return type is `Utf8View`,
macro_rules! get_optimal_return_type {
($FUNC:ident, $largeUtf8Type:expr, $utf8Type:expr) => {
pub(crate) fn $FUNC(arg_type: &DataType, name: &str) -> Result<DataType> {
Expand All @@ -37,6 +41,8 @@ macro_rules! get_optimal_return_type {
DataType::LargeUtf8 | DataType::LargeBinary => $largeUtf8Type,
// Binary inputs are automatically coerced to Utf8
DataType::Utf8 | DataType::Binary => $utf8Type,
// Utf8View inputs will yield Utf8View outputs
DataType::Utf8View => DataType::Utf8View,
DataType::Null => DataType::Null,
DataType::Dictionary(_, value_type) => match **value_type {
DataType::LargeUtf8 | DataType::LargeBinary => $largeUtf8Type,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -233,6 +233,9 @@ pub fn cell_to_string(col: &ArrayRef, row: usize) -> Result<String> {
DataType::Utf8 => {
Ok(varchar_to_str(get_row_value!(array::StringArray, col, row)))
}
DataType::Utf8View => {
Ok(varchar_to_str(get_row_value!(array::StringViewArray, col, row)))
}
_ => {
let f = ArrayFormatter::try_new(col.as_ref(), &DEFAULT_FORMAT_OPTIONS);
Ok(f.unwrap().value(row).to_string())
Expand Down
21 changes: 21 additions & 0 deletions datafusion/sqllogictest/test_files/string_view.slt
Original file line number Diff line number Diff line change
Expand Up @@ -324,3 +324,24 @@ logical_plan

statement ok
drop table test;

# coercion from stringview to integer, as input to make_date
query D
select make_date(arrow_cast('2024', 'Utf8View'), arrow_cast('01', 'Utf8View'), arrow_cast('23', 'Utf8View'))
----
2024-01-23

# coercions between stringview and date types
statement ok
create table dates (dt date) as values
(date '2024-01-23'),
(date '2023-11-30');

query D
select t.dt from dates t where arrow_cast('2024-01-01', 'Utf8View') < t.dt;
----
2024-01-23


statement ok
drop table dates;

0 comments on commit 580e64b

Please sign in to comment.