Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor(rust): Refactor AnyValue casting logic #13140

Merged
merged 2 commits into from
Feb 26, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
226 changes: 109 additions & 117 deletions crates/polars-core/src/datatypes/any_value.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,5 @@
#[cfg(feature = "dtype-struct")]
use arrow::legacy::trusted_len::TrustedLenPush;
#[cfg(feature = "dtype-date")]
use arrow::temporal_conversions::{
timestamp_ms_to_datetime, timestamp_ns_to_datetime, timestamp_us_to_datetime,
};
use arrow::types::PrimitiveType;
use polars_utils::format_smartstring;
#[cfg(feature = "dtype-struct")]
Expand Down Expand Up @@ -491,130 +487,126 @@ impl<'a> AnyValue<'a> {
}
}

// NOTE(review): the lines of this function come from a rendered diff whose +/- markers
// were lost, so two implementations of `strict_cast` are interleaved below. Arms written as
// `(value_pattern, DataType::X) => ...` belong to the `match (self, dtype)` form; the nested
// helpers `cast_to_numeric` / `cast_to_boolean` and the `match self { ... match dtype { ... } }`
// arms appear to be the implementation the refactor removed. Confirm against the repository
// before relying on either reading.
/// Cast `AnyValue` to the provided data type and return a new `AnyValue` with type `dtype`,
/// if possible.
///
/// # Errors
///
/// Bails with a `ComputeError` when no arm handles the value/dtype combination. Errors
/// returned by `try_extract` are propagated unchanged via `?`.
pub fn strict_cast(&self, dtype: &'a DataType) -> PolarsResult<AnyValue<'a>> {
// Nested helper (nested-match form): extracts the value as the primitive type named by
// `dtype` and wraps it in the matching variant; any other dtype bails with a ComputeError.
fn cast_to_numeric<'a>(av: &AnyValue, dtype: &'a DataType) -> PolarsResult<AnyValue<'a>> {
let out = match dtype {
DataType::UInt8 => AnyValue::UInt8(av.try_extract::<u8>()?),
DataType::UInt16 => AnyValue::UInt16(av.try_extract::<u16>()?),
DataType::UInt32 => AnyValue::UInt32(av.try_extract::<u32>()?),
DataType::UInt64 => AnyValue::UInt64(av.try_extract::<u64>()?),
DataType::Int8 => AnyValue::Int8(av.try_extract::<i8>()?),
DataType::Int16 => AnyValue::Int16(av.try_extract::<i16>()?),
DataType::Int32 => AnyValue::Int32(av.try_extract::<i32>()?),
DataType::Int64 => AnyValue::Int64(av.try_extract::<i64>()?),
DataType::Float32 => AnyValue::Float32(av.try_extract::<f32>()?),
DataType::Float64 => AnyValue::Float64(av.try_extract::<f64>()?),
_ => {
polars_bail!(ComputeError: "cannot cast any-value {:?} to dtype '{}'", av, dtype)
},
};
Ok(out)
}
// Tuple-match form: dispatches on the (value, target dtype) pair; arms are tried in order.
let new_av = match (self, dtype) {
// to numeric
(av, DataType::UInt8) => AnyValue::UInt8(av.try_extract::<u8>()?),
(av, DataType::UInt16) => AnyValue::UInt16(av.try_extract::<u16>()?),
(av, DataType::UInt32) => AnyValue::UInt32(av.try_extract::<u32>()?),
(av, DataType::UInt64) => AnyValue::UInt64(av.try_extract::<u64>()?),
(av, DataType::Int8) => AnyValue::Int8(av.try_extract::<i8>()?),
(av, DataType::Int16) => AnyValue::Int16(av.try_extract::<i16>()?),
(av, DataType::Int32) => AnyValue::Int32(av.try_extract::<i32>()?),
(av, DataType::Int64) => AnyValue::Int64(av.try_extract::<i64>()?),
(av, DataType::Float32) => AnyValue::Float32(av.try_extract::<f32>()?),
(av, DataType::Float64) => AnyValue::Float64(av.try_extract::<f64>()?),

// Nested helper (nested-match form): a numeric value becomes `Boolean(true)` when it differs
// from its type's `default()` value; any non-numeric variant bails with a ComputeError.
fn cast_to_boolean<'a>(av: &AnyValue) -> PolarsResult<AnyValue<'a>> {
let out = match av {
AnyValue::UInt8(v) => AnyValue::Boolean(*v != u8::default()),
AnyValue::UInt16(v) => AnyValue::Boolean(*v != u16::default()),
AnyValue::UInt32(v) => AnyValue::Boolean(*v != u32::default()),
AnyValue::UInt64(v) => AnyValue::Boolean(*v != u64::default()),
AnyValue::Int8(v) => AnyValue::Boolean(*v != i8::default()),
AnyValue::Int16(v) => AnyValue::Boolean(*v != i16::default()),
AnyValue::Int32(v) => AnyValue::Boolean(*v != i32::default()),
AnyValue::Int64(v) => AnyValue::Boolean(*v != i64::default()),
AnyValue::Float32(v) => AnyValue::Boolean(*v != f32::default()),
AnyValue::Float64(v) => AnyValue::Boolean(*v != f64::default()),
_ => {
polars_bail!(ComputeError: "cannot cast any-value {:?} to boolean", av)
},
};
Ok(out)
}
// to boolean
(AnyValue::UInt8(v), DataType::Boolean) => AnyValue::Boolean(*v != u8::default()),
(AnyValue::UInt16(v), DataType::Boolean) => AnyValue::Boolean(*v != u16::default()),
(AnyValue::UInt32(v), DataType::Boolean) => AnyValue::Boolean(*v != u32::default()),
(AnyValue::UInt64(v), DataType::Boolean) => AnyValue::Boolean(*v != u64::default()),
(AnyValue::Int8(v), DataType::Boolean) => AnyValue::Boolean(*v != i8::default()),
(AnyValue::Int16(v), DataType::Boolean) => AnyValue::Boolean(*v != i16::default()),
(AnyValue::Int32(v), DataType::Boolean) => AnyValue::Boolean(*v != i32::default()),
(AnyValue::Int64(v), DataType::Boolean) => AnyValue::Boolean(*v != i64::default()),
(AnyValue::Float32(v), DataType::Boolean) => AnyValue::Boolean(*v != f32::default()),
(AnyValue::Float64(v), DataType::Boolean) => AnyValue::Boolean(*v != f64::default()),

// Nested-match form: boolean and numeric values are handled first, then dispatched on the
// target dtype; temporal targets reuse the extracted integer as the raw temporal value.
let new_av = match self {
_ if (self.is_boolean() | self.is_numeric()) => match dtype {
#[cfg(feature = "dtype-date")]
DataType::Date => AnyValue::Date(self.try_extract::<i32>()?),
#[cfg(feature = "dtype-datetime")]
DataType::Datetime(tu, tz) => {
AnyValue::Datetime(self.try_extract::<i64>()?, *tu, tz)
},
#[cfg(feature = "dtype-duration")]
DataType::Duration(tu) => AnyValue::Duration(self.try_extract::<i64>()?, *tu),
#[cfg(feature = "dtype-time")]
DataType::Time => AnyValue::Time(self.try_extract::<i64>()?),
DataType::String => {
AnyValue::StringOwned(format_smartstring!("{}", self.try_extract::<i64>()?))
},
DataType::Boolean => return cast_to_boolean(self),
_ => return cast_to_numeric(self, dtype),
// to string
// NOTE(review): in the tuple-match form this arm accepts every value kind and formats the
// result of `try_extract::<i64>`. It is listed before the "to self" arm, so a value that
// is already a string presumably reaches `try_extract` here instead of being cloned.
// Confirm this ordering is intended.
(av, DataType::String) => {
AnyValue::StringOwned(format_smartstring!("{}", av.try_extract::<i64>()?))
},

// to binary
// Reinterprets the string's bytes without copying (`as_bytes` borrows from `v`).
(AnyValue::String(v), DataType::Binary) => AnyValue::Binary(v.as_bytes()),

// to datetime
#[cfg(feature = "dtype-datetime")]
// Nested-match form: a Datetime whose third field is `None` is converted to Date or Time.
AnyValue::Datetime(v, tu, None) => match dtype {
#[cfg(feature = "dtype-date")]
// Datetime to Date
DataType::Date => {
let convert = match tu {
TimeUnit::Nanoseconds => timestamp_ns_to_datetime,
TimeUnit::Microseconds => timestamp_us_to_datetime,
TimeUnit::Milliseconds => timestamp_ms_to_datetime,
};
let ndt = convert(*v);
let date_value = naive_datetime_to_date(ndt);
AnyValue::Date(date_value)
(av, DataType::Datetime(tu, tz)) if av.is_numeric() => {
AnyValue::Datetime(av.try_extract::<i64>()?, *tu, tz)
},
#[cfg(all(feature = "dtype-datetime", feature = "dtype-date"))]
// Date to Datetime: multiplies the date value by the per-day constant of the target unit;
// the resulting Datetime always carries `&None` as its third field.
(AnyValue::Date(v), DataType::Datetime(tu, _)) => AnyValue::Datetime(
match tu {
TimeUnit::Nanoseconds => (*v as i64) * NS_IN_DAY,
TimeUnit::Microseconds => (*v as i64) * US_IN_DAY,
TimeUnit::Milliseconds => (*v as i64) * MS_IN_DAY,
},
#[cfg(feature = "dtype-time")]
// Datetime to Time
DataType::Time => {
let ns_since_midnight = match tu {
TimeUnit::Nanoseconds => *v % NS_IN_DAY,
TimeUnit::Microseconds => (*v % US_IN_DAY) * 1_000i64,
TimeUnit::Milliseconds => (*v % MS_IN_DAY) * 1_000_000i64,
};
AnyValue::Time(ns_since_midnight)
*tu,
&None,
),
#[cfg(feature = "dtype-datetime")]
// Datetime to Datetime: rescales the raw value between time units by a factor of 1_000 or
// 1_000_000 (integer division when going to a coarser unit) and takes unit and third
// field from the target dtype; equal units fall through to `_ => *v`.
(AnyValue::Datetime(v, tu, _), DataType::Datetime(tu_r, tz_r)) => AnyValue::Datetime(
match (tu, tu_r) {
(TimeUnit::Nanoseconds, TimeUnit::Microseconds) => *v / 1_000i64,
(TimeUnit::Nanoseconds, TimeUnit::Milliseconds) => *v / 1_000_000i64,
(TimeUnit::Microseconds, TimeUnit::Nanoseconds) => *v * 1_000i64,
(TimeUnit::Microseconds, TimeUnit::Milliseconds) => *v / 1_000i64,
(TimeUnit::Milliseconds, TimeUnit::Microseconds) => *v * 1_000i64,
(TimeUnit::Milliseconds, TimeUnit::Nanoseconds) => *v * 1_000_000i64,
_ => *v,
},
_ => return cast_to_numeric(self, dtype),
},
*tu_r,
tz_r,
),

// to date
#[cfg(feature = "dtype-date")]
(av, DataType::Date) if av.is_numeric() => AnyValue::Date(av.try_extract::<i32>()?),
#[cfg(all(feature = "dtype-date", feature = "dtype-datetime"))]
// NOTE(review): `/` on i64 truncates toward zero and `as i32` wraps silently. If negative
// raw values are possible here, a value that is not a whole multiple of the per-day
// constant would round toward zero rather than down — confirm whether `div_euclid` and a
// checked conversion are wanted.
(AnyValue::Datetime(v, tu, _), DataType::Date) => AnyValue::Date(match tu {
TimeUnit::Nanoseconds => *v / NS_IN_DAY,
TimeUnit::Microseconds => *v / US_IN_DAY,
TimeUnit::Milliseconds => *v / MS_IN_DAY,
} as i32),

// to time
#[cfg(feature = "dtype-time")]
(av, DataType::Time) if av.is_numeric() => AnyValue::Time(av.try_extract::<i64>()?),
#[cfg(all(feature = "dtype-time", feature = "dtype-datetime"))]
// Keeps the within-day remainder and scales it to nanoseconds.
// NOTE(review): `%` takes the sign of the dividend, so a negative raw value would yield a
// negative remainder — confirm whether `rem_euclid` is wanted.
(AnyValue::Datetime(v, tu, _), DataType::Time) => AnyValue::Time(match tu {
TimeUnit::Nanoseconds => *v % NS_IN_DAY,
TimeUnit::Microseconds => (*v % US_IN_DAY) * 1_000i64,
TimeUnit::Milliseconds => (*v % MS_IN_DAY) * 1_000_000i64,
}),

// to duration
#[cfg(feature = "dtype-duration")]
// Nested-match form: a Duration may not be cast to Time, Date or Datetime; every other
// target goes through `cast_to_numeric`.
AnyValue::Duration(v, _) => match dtype {
DataType::Time | DataType::Date | DataType::Datetime(_, _) => {
polars_bail!(ComputeError: "cannot cast any-value {:?} to dtype '{}'", v, dtype)
},
_ => return cast_to_numeric(self, dtype),
(av, DataType::Duration(tu)) if av.is_numeric() => {
AnyValue::Duration(av.try_extract::<i64>()?, *tu)
},
#[cfg(feature = "dtype-time")]
AnyValue::Time(v) => match dtype {
#[cfg(feature = "dtype-duration")]
// Time to Duration
DataType::Duration(tu) => {
let duration_value = match tu {
TimeUnit::Nanoseconds => *v,
TimeUnit::Microseconds => *v / 1_000i64,
TimeUnit::Milliseconds => *v / 1_000_000i64,
};
AnyValue::Duration(duration_value, *tu)
#[cfg(all(feature = "dtype-duration", feature = "dtype-time"))]
// Time to Duration: the Time value is used as-is for nanoseconds and divided down for the
// coarser units.
(AnyValue::Time(v), DataType::Duration(tu)) => AnyValue::Duration(
match *tu {
TimeUnit::Nanoseconds => *v,
TimeUnit::Microseconds => *v / 1_000i64,
TimeUnit::Milliseconds => *v / 1_000_000i64,
},
_ => return cast_to_numeric(self, dtype),
},
#[cfg(feature = "dtype-date")]
AnyValue::Date(v) => match dtype {
#[cfg(feature = "dtype-datetime")]
// Date to Datetime
DataType::Datetime(tu, None) => {
let ndt = arrow::temporal_conversions::date32_to_datetime(*v);
let func = match tu {
TimeUnit::Nanoseconds => datetime_to_timestamp_ns,
TimeUnit::Microseconds => datetime_to_timestamp_us,
TimeUnit::Milliseconds => datetime_to_timestamp_ms,
};
let value = func(ndt);
AnyValue::Datetime(value, *tu, &None)
*tu,
),
#[cfg(feature = "dtype-duration")]
// Duration to Duration: same rescaling table as the Datetime to Datetime arm.
// NOTE(review): the guarded `tu == tu_r` arm and the trailing `_` arm both return `*v`
// unchanged, so one of them looks redundant — confirm.
(AnyValue::Duration(v, tu), DataType::Duration(tu_r)) => AnyValue::Duration(
match (tu, tu_r) {
(_, _) if tu == tu_r => *v,
(TimeUnit::Nanoseconds, TimeUnit::Microseconds) => *v / 1_000i64,
(TimeUnit::Nanoseconds, TimeUnit::Milliseconds) => *v / 1_000_000i64,
(TimeUnit::Microseconds, TimeUnit::Nanoseconds) => *v * 1_000i64,
(TimeUnit::Microseconds, TimeUnit::Milliseconds) => *v / 1_000i64,
(TimeUnit::Milliseconds, TimeUnit::Microseconds) => *v * 1_000i64,
(TimeUnit::Milliseconds, TimeUnit::Nanoseconds) => *v * 1_000_000i64,
_ => *v,
},
_ => return cast_to_numeric(self, dtype),
},
AnyValue::String(s) if dtype == &DataType::Binary => AnyValue::Binary(s.as_bytes()),
_ => {
polars_bail!(ComputeError: "cannot cast any-value '{:?}' to '{:?}'", self.dtype(), dtype)
},
*tu_r,
),

// to self
// Reached only when no earlier arm matched: a value whose dtype already equals the target
// is returned as a clone.
(av, dtype) if av.dtype() == *dtype => self.clone(),

// NOTE(review): in the tuple-match form `av` here binds the whole `(self, dtype)` pair, so
// the first `{:?}` in the message would print the tuple rather than just the value — confirm.
av => polars_bail!(ComputeError: "cannot cast any-value {:?} to dtype '{}'", av, dtype),
};
Ok(new_av)
}
Expand Down
Loading