Skip to content

Commit

Permalink
More refactor
Browse files Browse the repository at this point in the history
  • Loading branch information
stinodego committed Feb 26, 2024
1 parent 3fde519 commit b46916a
Showing 1 changed file with 80 additions and 108 deletions.
188 changes: 80 additions & 108 deletions crates/polars-core/src/datatypes/any_value.rs
Original file line number Diff line number Diff line change
Expand Up @@ -492,19 +492,44 @@ impl<'a> AnyValue<'a> {
///
pub fn strict_cast(&self, dtype: &'a DataType) -> PolarsResult<AnyValue<'a>> {
let new_av = match (self, dtype) {
// Casts are organized as follows:
// 1. Casts of the form (from dtype, to dtype) that require special logic.
// 2. Casts of the form (from dtype, boolean).
// 3. Casts to numeric. These are independent of the current dtype, and will fail if
// the dtype is not inherently castable.
// 4. Casts to logical. These will fail if the underlying dtype is not compatible with
// the physical type.
// 5. Casts to string.

// 1. (from dtype, to dtype)
// Date -> Datetime
// TODO: identity map date -> date
#[cfg(all(feature = "dtype-date", feature = "dtype-datetime"))]
// to numeric
(av, DataType::UInt8) => AnyValue::UInt8(av.try_extract::<u8>()?),
(av, DataType::UInt16) => AnyValue::UInt16(av.try_extract::<u16>()?),
(av, DataType::UInt32) => AnyValue::UInt32(av.try_extract::<u32>()?),
(av, DataType::UInt64) => AnyValue::UInt64(av.try_extract::<u64>()?),
(av, DataType::Int8) => AnyValue::Int8(av.try_extract::<i8>()?),
(av, DataType::Int16) => AnyValue::Int16(av.try_extract::<i16>()?),
(av, DataType::Int32) => AnyValue::Int32(av.try_extract::<i32>()?),
(av, DataType::Int64) => AnyValue::Int64(av.try_extract::<i64>()?),
(av, DataType::Float32) => AnyValue::Float32(av.try_extract::<f32>()?),
(av, DataType::Float64) => AnyValue::Float64(av.try_extract::<f64>()?),

// to boolean
(AnyValue::UInt8(v), DataType::Boolean) => AnyValue::Boolean(*v != u8::default()),
(AnyValue::UInt16(v), DataType::Boolean) => AnyValue::Boolean(*v != u16::default()),
(AnyValue::UInt32(v), DataType::Boolean) => AnyValue::Boolean(*v != u32::default()),
(AnyValue::UInt64(v), DataType::Boolean) => AnyValue::Boolean(*v != u64::default()),
(AnyValue::Int8(v), DataType::Boolean) => AnyValue::Boolean(*v != i8::default()),
(AnyValue::Int16(v), DataType::Boolean) => AnyValue::Boolean(*v != i16::default()),
(AnyValue::Int32(v), DataType::Boolean) => AnyValue::Boolean(*v != i32::default()),
(AnyValue::Int64(v), DataType::Boolean) => AnyValue::Boolean(*v != i64::default()),
(AnyValue::Float32(v), DataType::Boolean) => AnyValue::Boolean(*v != f32::default()),
(AnyValue::Float64(v), DataType::Boolean) => AnyValue::Boolean(*v != f64::default()),

// to string
(av, DataType::String) => {
AnyValue::StringOwned(format_smartstring!("{}", av.try_extract::<i64>()?))
},

// to binary
(AnyValue::String(v), DataType::Binary) => AnyValue::Binary(v.as_bytes()),

// to datetime
#[cfg(feature = "dtype-datetime")]
(av, DataType::Datetime(tu, tz)) if av.is_numeric() => {
AnyValue::Datetime(av.try_extract::<i64>()?, *tu, tz)
},
#[cfg(all(feature = "dtype-datetime", feature = "dtype-date"))]
(AnyValue::Date(v), DataType::Datetime(tu, _)) => AnyValue::Datetime(
match tu {
TimeUnit::Nanoseconds => (*v as i64) * NS_IN_DAY,
Expand All @@ -514,78 +539,47 @@ impl<'a> AnyValue<'a> {
*tu,
&None,
),
#[cfg(feature = "dtype-datetime")]
(AnyValue::Datetime(v, tu, _), DataType::Datetime(tu_r, tz_r)) => AnyValue::Datetime(
match (tu, tu_r) {
(TimeUnit::Nanoseconds, TimeUnit::Microseconds) => *v / 1_000i64,
(TimeUnit::Nanoseconds, TimeUnit::Milliseconds) => *v / 1_000_000i64,
(TimeUnit::Microseconds, TimeUnit::Nanoseconds) => *v * 1_000i64,
(TimeUnit::Microseconds, TimeUnit::Milliseconds) => *v / 1_000i64,
(TimeUnit::Milliseconds, TimeUnit::Microseconds) => *v * 1_000i64,
(TimeUnit::Milliseconds, TimeUnit::Nanoseconds) => *v * 1_000_000i64,
_ => *v,
},
*tu_r,
tz_r,
),

// TODO: enable the following map (datetime -> datetime)
// Datetime -> Datetime (time unit/time zone change)
// #[cfg(feature = "dtype-datetime")]
// (AnyValue::Datetime(v, tu, _), DataType::Datetime(tu_r, tz_r)) => AnyValue::Datetime(
// match (tu, tu_r) {
// (TimeUnit::Nanoseconds, TimeUnit::Microseconds) => *v / 1_000i64,
// (TimeUnit::Nanoseconds, TimeUnit::Milliseconds) => *v / 1_000_000i64,
// (TimeUnit::Microseconds, TimeUnit::Nanoseconds) => *v * 1_000i64,
// (TimeUnit::Microseconds, TimeUnit::Milliseconds) => *v / 1_000i64,
// (TimeUnit::Milliseconds, TimeUnit::Microseconds) => *v * 1_000i64,
// (TimeUnit::Milliseconds, TimeUnit::Nanoseconds) => *v * 1_000_000i64,
// _ => *v,
// },
// *tu_r,
// tz_r,
// ),

// Datetime -> Date
// to date
#[cfg(feature = "dtype-date")]
(av, DataType::Date) if av.is_numeric() => AnyValue::Date(av.try_extract::<i32>()?),
#[cfg(all(feature = "dtype-date", feature = "dtype-datetime"))]
(AnyValue::Datetime(v, tu, _), DataType::Date) => AnyValue::Date(match tu {
TimeUnit::Nanoseconds => *v / NS_IN_DAY,
TimeUnit::Microseconds => *v / US_IN_DAY,
TimeUnit::Milliseconds => *v / MS_IN_DAY,
} as i32),

// Datetime -> Time
#[cfg(all(feature = "dtype-datetime", feature = "dtype-time"))]
// to time
#[cfg(feature = "dtype-time")]
(av, DataType::Time) if av.is_numeric() => AnyValue::Time(av.try_extract::<i64>()?),
#[cfg(all(feature = "dtype-time", feature = "dtype-datetime"))]
(AnyValue::Datetime(v, tu, _), DataType::Time) => AnyValue::Time(match tu {
TimeUnit::Nanoseconds => *v % NS_IN_DAY,
TimeUnit::Microseconds => (*v % US_IN_DAY) * 1_000i64,
TimeUnit::Milliseconds => (*v % MS_IN_DAY) * 1_000_000i64,
}),

// TODO: enable the following map (duration -> duration)
// Duration -> Duration (time unit change)
// #[cfg(feature = "dtype-duration")]
// (AnyValue::Duration(v, tu), DataType::Duration(tu_r)) => AnyValue::Duration(
// match (tu, tu_r) {
// (_, _) if tu == tu_r => *v,
// (TimeUnit::Nanoseconds, TimeUnit::Microseconds) => *v / 1_000i64,
// (TimeUnit::Nanoseconds, TimeUnit::Milliseconds) => *v / 1_000_000i64,
// (TimeUnit::Microseconds, TimeUnit::Nanoseconds) => *v * 1_000i64,
// (TimeUnit::Microseconds, TimeUnit::Milliseconds) => *v / 1_000i64,
// (TimeUnit::Milliseconds, TimeUnit::Microseconds) => *v * 1_000i64,
// (TimeUnit::Milliseconds, TimeUnit::Nanoseconds) => *v * 1_000_000i64,
// _ => *v,
// },
// *tu_r,
// ),

// Duration -> Time | Date | Datetime
#[cfg(all(feature = "dtype-duration", feature = "dtype-time"))]
(AnyValue::Duration(v, _), DataType::Time) => {
polars_bail!(ComputeError: "cannot cast any-value {:?} to dtype 'Time'", v)
},
#[cfg(all(feature = "dtype-duration", feature = "dtype-date"))]
(AnyValue::Duration(v, _), DataType::Date) => {
polars_bail!(ComputeError: "cannot cast any-value {:?} to dtype 'Date'", v)
},
#[cfg(all(feature = "dtype-duration", feature = "dtype-datetime"))]
(AnyValue::Duration(v, _), DataType::Datetime(_, _)) => {
polars_bail!(ComputeError: "cannot cast any-value {:?} to dtype 'Datetime'", v)
// to duration
#[cfg(feature = "dtype-duration")]
(av, DataType::Duration(tu)) if av.is_numeric() => {
AnyValue::Duration(av.try_extract::<i64>()?, *tu)
},

// TODO: enable the following map (Time -> Time)
// Time -> Time
// #[cfg(feature = "dtype-time")]
// (AnyValue::Time(_), DataType::Time) => self.clone(),

// Time -> Duration
#[cfg(all(feature = "dtype-time", feature = "dtype-duration"))]
#[cfg(all(feature = "dtype-duration", feature = "dtype-time"))]
(AnyValue::Time(v), DataType::Duration(tu)) => AnyValue::Duration(
match *tu {
TimeUnit::Nanoseconds => *v,
Expand All @@ -594,46 +588,24 @@ impl<'a> AnyValue<'a> {
},
*tu,
),
#[cfg(feature = "dtype-duration")]
(AnyValue::Duration(v, tu), DataType::Duration(tu_r)) => AnyValue::Duration(
match (tu, tu_r) {
(_, _) if tu == tu_r => *v,
(TimeUnit::Nanoseconds, TimeUnit::Microseconds) => *v / 1_000i64,
(TimeUnit::Nanoseconds, TimeUnit::Milliseconds) => *v / 1_000_000i64,
(TimeUnit::Microseconds, TimeUnit::Nanoseconds) => *v * 1_000i64,
(TimeUnit::Microseconds, TimeUnit::Milliseconds) => *v / 1_000i64,
(TimeUnit::Milliseconds, TimeUnit::Microseconds) => *v * 1_000i64,
(TimeUnit::Milliseconds, TimeUnit::Nanoseconds) => *v * 1_000_000i64,
_ => *v,
},
*tu_r,
),

// String -> Binary
(AnyValue::String(v), DataType::Binary) => AnyValue::Binary(v.as_bytes()),

// _ -> Boolean
(AnyValue::UInt8(v), DataType::Boolean) => AnyValue::Boolean(*v != u8::default()),
(AnyValue::UInt16(v), DataType::Boolean) => AnyValue::Boolean(*v != u16::default()),
(AnyValue::UInt32(v), DataType::Boolean) => AnyValue::Boolean(*v != u32::default()),
(AnyValue::UInt64(v), DataType::Boolean) => AnyValue::Boolean(*v != u64::default()),
(AnyValue::Int8(v), DataType::Boolean) => AnyValue::Boolean(*v != i8::default()),
(AnyValue::Int16(v), DataType::Boolean) => AnyValue::Boolean(*v != i16::default()),
(AnyValue::Int32(v), DataType::Boolean) => AnyValue::Boolean(*v != i32::default()),
(AnyValue::Int64(v), DataType::Boolean) => AnyValue::Boolean(*v != i64::default()),
(AnyValue::Float32(v), DataType::Boolean) => AnyValue::Boolean(*v != f32::default()),
(AnyValue::Float64(v), DataType::Boolean) => AnyValue::Boolean(*v != f64::default()),
// to self
(av, dtype) if av.dtype() == *dtype => self.clone(),

// standard casts
(av, DataType::UInt8) => AnyValue::UInt8(av.try_extract::<u8>()?),
(av, DataType::UInt16) => AnyValue::UInt16(av.try_extract::<u16>()?),
(av, DataType::UInt32) => AnyValue::UInt32(av.try_extract::<u32>()?),
(av, DataType::UInt64) => AnyValue::UInt64(av.try_extract::<u64>()?),
(av, DataType::Int8) => AnyValue::Int8(av.try_extract::<i8>()?),
(av, DataType::Int16) => AnyValue::Int16(av.try_extract::<i16>()?),
(av, DataType::Int32) => AnyValue::Int32(av.try_extract::<i32>()?),
(av, DataType::Int64) => AnyValue::Int64(av.try_extract::<i64>()?),
(av, DataType::Float32) => AnyValue::Float32(av.try_extract::<f32>()?),
(av, DataType::Float64) => AnyValue::Float64(av.try_extract::<f64>()?),
#[cfg(feature = "dtype-date")]
(av, DataType::Date) => AnyValue::Date(av.try_extract::<i32>()?),
#[cfg(feature = "dtype-datetime")]
(av, DataType::Datetime(tu, tz)) => {
AnyValue::Datetime(av.try_extract::<i64>()?, *tu, tz)
},
#[cfg(feature = "dtype-duration")]
(av, DataType::Duration(tu)) => AnyValue::Duration(av.try_extract::<i64>()?, *tu),
#[cfg(feature = "dtype-time")]
(av, DataType::Time) => AnyValue::Time(av.try_extract::<i64>()?),
(av, DataType::String) => {
AnyValue::StringOwned(format_smartstring!("{}", av.try_extract::<i64>()?))
},
av => polars_bail!(ComputeError: "cannot cast any-value {:?} to dtype '{}'", av, dtype),
};
Ok(new_av)
Expand Down

0 comments on commit b46916a

Please sign in to comment.