Skip to content

Commit

Permalink
refactor(rust): Refactor AnyValue casting logic
Browse files Browse the repository at this point in the history
  • Loading branch information
stinodego committed Feb 26, 2024
1 parent 6b7eb52 commit 3fde519
Showing 1 changed file with 137 additions and 117 deletions.
254 changes: 137 additions & 117 deletions crates/polars-core/src/datatypes/any_value.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,5 @@
#[cfg(feature = "dtype-struct")]
use arrow::legacy::trusted_len::TrustedLenPush;
#[cfg(feature = "dtype-date")]
use arrow::temporal_conversions::{
timestamp_ms_to_datetime, timestamp_ns_to_datetime, timestamp_us_to_datetime,
};
use arrow::types::PrimitiveType;
use polars_utils::format_smartstring;
#[cfg(feature = "dtype-struct")]
Expand Down Expand Up @@ -491,130 +487,154 @@ impl<'a> AnyValue<'a> {
}
}

/// Cast `AnyValue` to the provided data type and return a new `AnyValue` with type `dtype`,
/// if possible.
///
pub fn strict_cast(&self, dtype: &'a DataType) -> PolarsResult<AnyValue<'a>> {
fn cast_to_numeric<'a>(av: &AnyValue, dtype: &'a DataType) -> PolarsResult<AnyValue<'a>> {
let out = match dtype {
DataType::UInt8 => AnyValue::UInt8(av.try_extract::<u8>()?),
DataType::UInt16 => AnyValue::UInt16(av.try_extract::<u16>()?),
DataType::UInt32 => AnyValue::UInt32(av.try_extract::<u32>()?),
DataType::UInt64 => AnyValue::UInt64(av.try_extract::<u64>()?),
DataType::Int8 => AnyValue::Int8(av.try_extract::<i8>()?),
DataType::Int16 => AnyValue::Int16(av.try_extract::<i16>()?),
DataType::Int32 => AnyValue::Int32(av.try_extract::<i32>()?),
DataType::Int64 => AnyValue::Int64(av.try_extract::<i64>()?),
DataType::Float32 => AnyValue::Float32(av.try_extract::<f32>()?),
DataType::Float64 => AnyValue::Float64(av.try_extract::<f64>()?),
_ => {
polars_bail!(ComputeError: "cannot cast any-value {:?} to dtype '{}'", av, dtype)
},
};
Ok(out)
}
let new_av = match (self, dtype) {
// Casts are organized as follows:
// 1. Casts are of the form (from dtype, to dtype) that require special logic.
// 2. Casts of the form (from dtype, boolean).
// 3. Casts to numeric. These are independent of the current dtype, and will fail if
// the dtype is not inherently castable.
// 4. Cast to logical. These will fail if the underlying dtype is not compatible with
// the physical type.
// 5. Cast to string.

fn cast_to_boolean<'a>(av: &AnyValue) -> PolarsResult<AnyValue<'a>> {
let out = match av {
AnyValue::UInt8(v) => AnyValue::Boolean(*v != u8::default()),
AnyValue::UInt16(v) => AnyValue::Boolean(*v != u16::default()),
AnyValue::UInt32(v) => AnyValue::Boolean(*v != u32::default()),
AnyValue::UInt64(v) => AnyValue::Boolean(*v != u64::default()),
AnyValue::Int8(v) => AnyValue::Boolean(*v != i8::default()),
AnyValue::Int16(v) => AnyValue::Boolean(*v != i16::default()),
AnyValue::Int32(v) => AnyValue::Boolean(*v != i32::default()),
AnyValue::Int64(v) => AnyValue::Boolean(*v != i64::default()),
AnyValue::Float32(v) => AnyValue::Boolean(*v != f32::default()),
AnyValue::Float64(v) => AnyValue::Boolean(*v != f64::default()),
_ => {
polars_bail!(ComputeError: "cannot cast any-value {:?} to boolean", av)
// 1. (from dtype, to dtype)
// Date -> Datetime
// TODO: identity map date -> date
#[cfg(all(feature = "dtype-date", feature = "dtype-datetime"))]
(AnyValue::Date(v), DataType::Datetime(tu, _)) => AnyValue::Datetime(
match tu {
TimeUnit::Nanoseconds => (*v as i64) * NS_IN_DAY,
TimeUnit::Microseconds => (*v as i64) * US_IN_DAY,
TimeUnit::Milliseconds => (*v as i64) * MS_IN_DAY,
},
};
Ok(out)
}
*tu,
&None,
),

let new_av = match self {
_ if (self.is_boolean() | self.is_numeric()) => match dtype {
#[cfg(feature = "dtype-date")]
DataType::Date => AnyValue::Date(self.try_extract::<i32>()?),
#[cfg(feature = "dtype-datetime")]
DataType::Datetime(tu, tz) => {
AnyValue::Datetime(self.try_extract::<i64>()?, *tu, tz)
},
#[cfg(feature = "dtype-duration")]
DataType::Duration(tu) => AnyValue::Duration(self.try_extract::<i64>()?, *tu),
#[cfg(feature = "dtype-time")]
DataType::Time => AnyValue::Time(self.try_extract::<i64>()?),
DataType::String => {
AnyValue::StringOwned(format_smartstring!("{}", self.try_extract::<i64>()?))
},
DataType::Boolean => return cast_to_boolean(self),
_ => return cast_to_numeric(self, dtype),
// TODO: enable the following map (datetime -> datetime)
// Datetime -> Datetime (time unit/time zone change)
// #[cfg(feature = "dtype-datetime")]
// (AnyValue::Datetime(v, tu, _), DataType::Datetime(tu_r, tz_r)) => AnyValue::Datetime(
// match (tu, tu_r) {
// (TimeUnit::Nanoseconds, TimeUnit::Microseconds) => *v / 1_000i64,
// (TimeUnit::Nanoseconds, TimeUnit::Milliseconds) => *v / 1_000_000i64,
// (TimeUnit::Microseconds, TimeUnit::Nanoseconds) => *v * 1_000i64,
// (TimeUnit::Microseconds, TimeUnit::Milliseconds) => *v / 1_000i64,
// (TimeUnit::Milliseconds, TimeUnit::Microseconds) => *v * 1_000i64,
// (TimeUnit::Milliseconds, TimeUnit::Nanoseconds) => *v * 1_000_000i64,
// _ => *v,
// },
// *tu_r,
// tz_r,
// ),

// Datetime -> Date
#[cfg(all(feature = "dtype-date", feature = "dtype-datetime"))]
(AnyValue::Datetime(v, tu, _), DataType::Date) => AnyValue::Date(match tu {
TimeUnit::Nanoseconds => *v / NS_IN_DAY,
TimeUnit::Microseconds => *v / US_IN_DAY,
TimeUnit::Milliseconds => *v / MS_IN_DAY,
} as i32),

// Datetime -> Time
#[cfg(all(feature = "dtype-datetime", feature = "dtype-time"))]
(AnyValue::Datetime(v, tu, _), DataType::Time) => AnyValue::Time(match tu {
TimeUnit::Nanoseconds => *v % NS_IN_DAY,
TimeUnit::Microseconds => (*v % US_IN_DAY) * 1_000i64,
TimeUnit::Milliseconds => (*v % MS_IN_DAY) * 1_000_000i64,
}),

// TODO: enable the following map (duration -> duration)
// Duration -> Duration (time unit change)
// #[cfg(feature = "dtype-duration")]
// (AnyValue::Duration(v, tu), DataType::Duration(tu_r)) => AnyValue::Duration(
// match (tu, tu_r) {
// (_, _) if tu == tu_r => *v,
// (TimeUnit::Nanoseconds, TimeUnit::Microseconds) => *v / 1_000i64,
// (TimeUnit::Nanoseconds, TimeUnit::Milliseconds) => *v / 1_000_000i64,
// (TimeUnit::Microseconds, TimeUnit::Nanoseconds) => *v * 1_000i64,
// (TimeUnit::Microseconds, TimeUnit::Milliseconds) => *v / 1_000i64,
// (TimeUnit::Milliseconds, TimeUnit::Microseconds) => *v * 1_000i64,
// (TimeUnit::Milliseconds, TimeUnit::Nanoseconds) => *v * 1_000_000i64,
// _ => *v,
// },
// *tu_r,
// ),

// Duration -> Time | Date | Datetime
#[cfg(all(feature = "dtype-duration", feature = "dtype-time"))]
(AnyValue::Duration(v, _), DataType::Time) => {
polars_bail!(ComputeError: "cannot cast any-value {:?} to dtype 'Time'", v)
},
#[cfg(feature = "dtype-datetime")]
AnyValue::Datetime(v, tu, None) => match dtype {
#[cfg(feature = "dtype-date")]
// Datetime to Date
DataType::Date => {
let convert = match tu {
TimeUnit::Nanoseconds => timestamp_ns_to_datetime,
TimeUnit::Microseconds => timestamp_us_to_datetime,
TimeUnit::Milliseconds => timestamp_ms_to_datetime,
};
let ndt = convert(*v);
let date_value = naive_datetime_to_date(ndt);
AnyValue::Date(date_value)
},
#[cfg(feature = "dtype-time")]
// Datetime to Time
DataType::Time => {
let ns_since_midnight = match tu {
TimeUnit::Nanoseconds => *v % NS_IN_DAY,
TimeUnit::Microseconds => (*v % US_IN_DAY) * 1_000i64,
TimeUnit::Milliseconds => (*v % MS_IN_DAY) * 1_000_000i64,
};
AnyValue::Time(ns_since_midnight)
},
_ => return cast_to_numeric(self, dtype),
#[cfg(all(feature = "dtype-duration", feature = "dtype-date"))]
(AnyValue::Duration(v, _), DataType::Date) => {
polars_bail!(ComputeError: "cannot cast any-value {:?} to dtype 'Date'", v)
},
#[cfg(feature = "dtype-duration")]
AnyValue::Duration(v, _) => match dtype {
DataType::Time | DataType::Date | DataType::Datetime(_, _) => {
polars_bail!(ComputeError: "cannot cast any-value {:?} to dtype '{}'", v, dtype)
},
_ => return cast_to_numeric(self, dtype),
#[cfg(all(feature = "dtype-duration", feature = "dtype-datetime"))]
(AnyValue::Duration(v, _), DataType::Datetime(_, _)) => {
polars_bail!(ComputeError: "cannot cast any-value {:?} to dtype 'Datetime'", v)
},
#[cfg(feature = "dtype-time")]
AnyValue::Time(v) => match dtype {
#[cfg(feature = "dtype-duration")]
// Time to Duration
DataType::Duration(tu) => {
let duration_value = match tu {
TimeUnit::Nanoseconds => *v,
TimeUnit::Microseconds => *v / 1_000i64,
TimeUnit::Milliseconds => *v / 1_000_000i64,
};
AnyValue::Duration(duration_value, *tu)

// TODO: enable the following map (Time -> Time)
// Time -> Time
// #[cfg(feature = "dtype-time")]
// (AnyValue::Time(_), DataType::Time) => self.clone(),

// Time -> Duration
#[cfg(all(feature = "dtype-time", feature = "dtype-duration"))]
(AnyValue::Time(v), DataType::Duration(tu)) => AnyValue::Duration(
match *tu {
TimeUnit::Nanoseconds => *v,
TimeUnit::Microseconds => *v / 1_000i64,
TimeUnit::Milliseconds => *v / 1_000_000i64,
},
_ => return cast_to_numeric(self, dtype),
},
*tu,
),

// String -> Binary
(AnyValue::String(v), DataType::Binary) => AnyValue::Binary(v.as_bytes()),

// _ -> Boolean
(AnyValue::UInt8(v), DataType::Boolean) => AnyValue::Boolean(*v != u8::default()),
(AnyValue::UInt16(v), DataType::Boolean) => AnyValue::Boolean(*v != u16::default()),
(AnyValue::UInt32(v), DataType::Boolean) => AnyValue::Boolean(*v != u32::default()),
(AnyValue::UInt64(v), DataType::Boolean) => AnyValue::Boolean(*v != u64::default()),
(AnyValue::Int8(v), DataType::Boolean) => AnyValue::Boolean(*v != i8::default()),
(AnyValue::Int16(v), DataType::Boolean) => AnyValue::Boolean(*v != i16::default()),
(AnyValue::Int32(v), DataType::Boolean) => AnyValue::Boolean(*v != i32::default()),
(AnyValue::Int64(v), DataType::Boolean) => AnyValue::Boolean(*v != i64::default()),
(AnyValue::Float32(v), DataType::Boolean) => AnyValue::Boolean(*v != f32::default()),
(AnyValue::Float64(v), DataType::Boolean) => AnyValue::Boolean(*v != f64::default()),

// standard casts
(av, DataType::UInt8) => AnyValue::UInt8(av.try_extract::<u8>()?),
(av, DataType::UInt16) => AnyValue::UInt16(av.try_extract::<u16>()?),
(av, DataType::UInt32) => AnyValue::UInt32(av.try_extract::<u32>()?),
(av, DataType::UInt64) => AnyValue::UInt64(av.try_extract::<u64>()?),
(av, DataType::Int8) => AnyValue::Int8(av.try_extract::<i8>()?),
(av, DataType::Int16) => AnyValue::Int16(av.try_extract::<i16>()?),
(av, DataType::Int32) => AnyValue::Int32(av.try_extract::<i32>()?),
(av, DataType::Int64) => AnyValue::Int64(av.try_extract::<i64>()?),
(av, DataType::Float32) => AnyValue::Float32(av.try_extract::<f32>()?),
(av, DataType::Float64) => AnyValue::Float64(av.try_extract::<f64>()?),
#[cfg(feature = "dtype-date")]
AnyValue::Date(v) => match dtype {
#[cfg(feature = "dtype-datetime")]
// Date to Datetime
DataType::Datetime(tu, None) => {
let ndt = arrow::temporal_conversions::date32_to_datetime(*v);
let func = match tu {
TimeUnit::Nanoseconds => datetime_to_timestamp_ns,
TimeUnit::Microseconds => datetime_to_timestamp_us,
TimeUnit::Milliseconds => datetime_to_timestamp_ms,
};
let value = func(ndt);
AnyValue::Datetime(value, *tu, &None)
},
_ => return cast_to_numeric(self, dtype),
(av, DataType::Date) => AnyValue::Date(av.try_extract::<i32>()?),
#[cfg(feature = "dtype-datetime")]
(av, DataType::Datetime(tu, tz)) => {
AnyValue::Datetime(av.try_extract::<i64>()?, *tu, tz)
},
AnyValue::String(s) if dtype == &DataType::Binary => AnyValue::Binary(s.as_bytes()),
_ => {
polars_bail!(ComputeError: "cannot cast any-value '{:?}' to '{:?}'", self.dtype(), dtype)
#[cfg(feature = "dtype-duration")]
(av, DataType::Duration(tu)) => AnyValue::Duration(av.try_extract::<i64>()?, *tu),
#[cfg(feature = "dtype-time")]
(av, DataType::Time) => AnyValue::Time(av.try_extract::<i64>()?),
(av, DataType::String) => {
AnyValue::StringOwned(format_smartstring!("{}", av.try_extract::<i64>()?))
},
av => polars_bail!(ComputeError: "cannot cast any-value {:?} to dtype '{}'", av, dtype),
};
Ok(new_av)
}
Expand Down

0 comments on commit 3fde519

Please sign in to comment.