From 4e268306f6ab2c6caf6c7a2e9c69716ede666f27 Mon Sep 17 00:00:00 2001 From: coastalwhite Date: Fri, 21 Jun 2024 16:09:09 +0200 Subject: [PATCH] fix(rust): bubble error when no available bitrepr This PR refactors the code surrounding bit representations and properly bubbles errors up if no bit representation is defined. This resolves the panic in #14826, but does not implement the wanted behavior. I am not sure a test case is useful here as the behavior should be added in the end. --- .../src/chunked_array/ops/bit_repr.rs | 77 +++--- .../polars-core/src/chunked_array/ops/mod.rs | 7 +- .../src/frame/group_by/into_groups.rs | 9 +- .../src/series/implementations/categorical.rs | 7 +- .../src/series/implementations/date.rs | 12 +- .../src/series/implementations/datetime.rs | 7 +- .../src/series/implementations/decimal.rs | 6 +- .../src/series/implementations/duration.rs | 7 +- .../src/series/implementations/mod.rs | 50 ++-- .../src/series/implementations/null.rs | 7 +- .../src/series/implementations/object.rs | 8 +- .../src/series/implementations/struct_.rs | 6 +- .../src/series/implementations/time.rs | 12 +- crates/polars-core/src/series/series_trait.rs | 18 +- .../polars-ops/src/chunked_array/list/hash.rs | 11 +- .../polars-ops/src/frame/join/asof/groups.rs | 51 ++-- .../src/frame/join/hash_join/mod.rs | 2 +- .../join/hash_join/single_keys_dispatch.rs | 227 +++++++++++------- .../polars-ops/src/frame/pivot/positioning.rs | 51 ++-- 19 files changed, 339 insertions(+), 236 deletions(-) diff --git a/crates/polars-core/src/chunked_array/ops/bit_repr.rs b/crates/polars-core/src/chunked_array/ops/bit_repr.rs index 3ccc0624d1c3..37617a1d43ea 100644 --- a/crates/polars-core/src/chunked_array/ops/bit_repr.rs +++ b/crates/polars-core/src/chunked_array/ops/bit_repr.rs @@ -1,6 +1,7 @@ use arrow::buffer::Buffer; use crate::prelude::*; +use crate::series::BitRepr; /// Reinterprets the type of a [`ChunkedArray`]. T and U must have the same size /// and alignment. @@ -103,41 +104,41 @@ impl ToBitRepr for ChunkedArray where T: PolarsNumericType, { - fn bit_repr_is_large() -> bool { - std::mem::size_of::() == 8 - } + fn to_bit_repr(&self) -> BitRepr { + let is_large = std::mem::size_of::() == 8; - fn bit_repr_large(&self) -> UInt64Chunked { - if std::mem::size_of::() == 8 { + if is_large { if matches!(self.dtype(), DataType::UInt64) { let ca = self.clone(); // Convince the compiler we are this type. This keeps flags. - return unsafe { std::mem::transmute::, UInt64Chunked>(ca) }; + return BitRepr::Large(unsafe { + std::mem::transmute::, UInt64Chunked>(ca) + }); } - reinterpret_chunked_array(self) - } else { - unreachable!() - } - } - fn bit_repr_small(&self) -> UInt32Chunked { - if std::mem::size_of::() == 4 { - if matches!(self.dtype(), DataType::UInt32) { - let ca = self.clone(); - // Convince the compiler we are this type. This preserves flags. - return unsafe { std::mem::transmute::, UInt32Chunked>(ca) }; - } - reinterpret_chunked_array(self) + BitRepr::Large(reinterpret_chunked_array(self)) } else { - // SAFETY: an unchecked cast to uint32 (which has no invariants) is - // always sound. - unsafe { - self.cast_unchecked(&DataType::UInt32) - .unwrap() - .u32() - .unwrap() - .clone() - } + BitRepr::Small(if std::mem::size_of::() == 4 { + if matches!(self.dtype(), DataType::UInt32) { + let ca = self.clone(); + // Convince the compiler we are this type. This preserves flags. + return BitRepr::Small(unsafe { + std::mem::transmute::, UInt32Chunked>(ca) + }); + } + + reinterpret_chunked_array(self) + } else { + // SAFETY: an unchecked cast to uint32 (which has no invariants) is + // always sound. + unsafe { + self.cast_unchecked(&DataType::UInt32) + .unwrap() + .u32() + .unwrap() + .clone() + } + }) } } } @@ -160,7 +161,10 @@ impl Reinterpret for Int64Chunked { } fn reinterpret_unsigned(&self) -> Series { - self.bit_repr_large().into_series() + let BitRepr::Large(b) = self.to_bit_repr() else { + unreachable!() + }; + b.into_series() } } @@ -183,7 +187,10 @@ impl Reinterpret for Int32Chunked { } fn reinterpret_unsigned(&self) -> Series { - self.bit_repr_small().into_series() + let BitRepr::Small(b) = self.to_bit_repr() else { + unreachable!() + }; + b.into_series() } } @@ -250,7 +257,10 @@ impl Float32Chunked { where F: Fn(&Series) -> Series, { - let s = self.bit_repr_small().into_series(); + let BitRepr::Small(s) = self.to_bit_repr() else { + unreachable!() + }; + let s = s.into_series(); let out = f(&s); let out = out.u32().unwrap(); out._reinterpret_float().into() @@ -261,7 +271,10 @@ impl Float64Chunked { where F: Fn(&Series) -> Series, { - let s = self.bit_repr_large().into_series(); + let BitRepr::Large(s) = self.to_bit_repr() else { + unreachable!() + }; + let s = s.into_series(); let out = f(&s); let out = out.u64().unwrap(); out._reinterpret_float().into() diff --git a/crates/polars-core/src/chunked_array/ops/mod.rs b/crates/polars-core/src/chunked_array/ops/mod.rs index fb4dafaf3037..42f621404605 100644 --- a/crates/polars-core/src/chunked_array/ops/mod.rs +++ b/crates/polars-core/src/chunked_array/ops/mod.rs @@ -43,7 +43,7 @@ use serde::{Deserialize, Serialize}; pub use sort::options::*; use crate::chunked_array::cast::CastOptions; -use crate::series::IsSorted; +use crate::series::{BitRepr, IsSorted}; #[cfg(feature = "reinterpret")] pub trait Reinterpret { fn reinterpret_signed(&self) -> Series { @@ -59,10 +59,7 @@ pub trait Reinterpret { /// This is useful in hashing context and reduces no. /// of compiled code paths. pub(crate) trait ToBitRepr { - fn bit_repr_is_large() -> bool; - - fn bit_repr_large(&self) -> UInt64Chunked; - fn bit_repr_small(&self) -> UInt32Chunked; + fn to_bit_repr(&self) -> BitRepr; } pub trait ChunkAnyValue { diff --git a/crates/polars-core/src/frame/group_by/into_groups.rs b/crates/polars-core/src/frame/group_by/into_groups.rs index 365545b0bbb5..d507459e7c49 100644 --- a/crates/polars-core/src/frame/group_by/into_groups.rs +++ b/crates/polars-core/src/frame/group_by/into_groups.rs @@ -5,6 +5,7 @@ use super::*; use crate::chunked_array::cast::CastOptions; use crate::config::verbose; use crate::prelude::sort::arg_sort_multiple::_get_rows_encoded_ca_unordered; +use crate::series::BitRepr; use crate::utils::flatten::flatten_par; /// Used to create the tuples for a group_by operation. @@ -163,11 +164,15 @@ where num_groups_proxy(ca, multithreaded, sorted) }, DataType::Int64 => { - let ca = self.bit_repr_large(); + let BitRepr::Large(ca) = self.to_bit_repr() else { + unreachable!() + }; num_groups_proxy(&ca, multithreaded, sorted) }, DataType::Int32 => { - let ca = self.bit_repr_small(); + let BitRepr::Small(ca) = self.to_bit_repr() else { + unreachable!() + }; num_groups_proxy(&ca, multithreaded, sorted) }, DataType::Float64 => { diff --git a/crates/polars-core/src/series/implementations/categorical.rs b/crates/polars-core/src/series/implementations/categorical.rs index 97ac4be0031a..5204048bda81 100644 --- a/crates/polars-core/src/series/implementations/categorical.rs +++ b/crates/polars-core/src/series/implementations/categorical.rs @@ -304,10 +304,7 @@ impl SeriesTrait for SeriesWrap { } impl private::PrivateSeriesNumeric for SeriesWrap { - fn bit_repr_is_large(&self) -> bool { - false - } - fn bit_repr_small(&self) -> UInt32Chunked { - self.0.physical().clone() + fn bit_repr(&self) -> Option { + Some(BitRepr::Small(self.0.physical().clone())) } } diff --git a/crates/polars-core/src/series/implementations/date.rs b/crates/polars-core/src/series/implementations/date.rs index 296122f0ed57..e0658ceb32ce 100644 --- a/crates/polars-core/src/series/implementations/date.rs +++ b/crates/polars-core/src/series/implementations/date.rs @@ -349,15 +349,7 @@ impl SeriesTrait for SeriesWrap { } impl private::PrivateSeriesNumeric for SeriesWrap { - fn bit_repr_is_large(&self) -> bool { - false - } - - fn bit_repr_large(&self) -> UInt64Chunked { - self.0.bit_repr_large() - } - - fn bit_repr_small(&self) -> UInt32Chunked { - self.0.bit_repr_small() + fn bit_repr(&self) -> Option { + Some(self.0.to_bit_repr()) } } diff --git a/crates/polars-core/src/series/implementations/datetime.rs b/crates/polars-core/src/series/implementations/datetime.rs index 42f60bd06c4e..de22804a5ab0 100644 --- a/crates/polars-core/src/series/implementations/datetime.rs +++ b/crates/polars-core/src/series/implementations/datetime.rs @@ -10,11 +10,8 @@ unsafe impl IntoSeries for DatetimeChunked { } impl private::PrivateSeriesNumeric for SeriesWrap { - fn bit_repr_is_large(&self) -> bool { - true - } - fn bit_repr_large(&self) -> UInt64Chunked { - self.0.bit_repr_large() + fn bit_repr(&self) -> Option { + Some(self.0.to_bit_repr()) } } diff --git a/crates/polars-core/src/series/implementations/decimal.rs b/crates/polars-core/src/series/implementations/decimal.rs index bfc79836b618..9f5c382d94c7 100644 --- a/crates/polars-core/src/series/implementations/decimal.rs +++ b/crates/polars-core/src/series/implementations/decimal.rs @@ -7,7 +7,11 @@ unsafe impl IntoSeries for DecimalChunked { } } -impl private::PrivateSeriesNumeric for SeriesWrap {} +impl private::PrivateSeriesNumeric for SeriesWrap { + fn bit_repr(&self) -> Option { + None + } +} impl SeriesWrap { fn apply_physical_to_s Int128Chunked>(&self, f: F) -> Series { diff --git a/crates/polars-core/src/series/implementations/duration.rs b/crates/polars-core/src/series/implementations/duration.rs index c6a4760c8004..6667b69fea41 100644 --- a/crates/polars-core/src/series/implementations/duration.rs +++ b/crates/polars-core/src/series/implementations/duration.rs @@ -13,11 +13,8 @@ unsafe impl IntoSeries for DurationChunked { } impl private::PrivateSeriesNumeric for SeriesWrap { - fn bit_repr_is_large(&self) -> bool { - true - } - fn bit_repr_large(&self) -> UInt64Chunked { - self.0.bit_repr_large() + fn bit_repr(&self) -> Option { + Some(self.0.to_bit_repr()) } } diff --git a/crates/polars-core/src/series/implementations/mod.rs b/crates/polars-core/src/series/implementations/mod.rs index 07e55b7eae3c..3712bdcc3393 100644 --- a/crates/polars-core/src/series/implementations/mod.rs +++ b/crates/polars-core/src/series/implementations/mod.rs @@ -484,33 +484,47 @@ impl_dyn_series!(Int32Chunked); impl_dyn_series!(Int64Chunked); impl private::PrivateSeriesNumeric for SeriesWrap> { - fn bit_repr_is_large(&self) -> bool { - ChunkedArray::::bit_repr_is_large() + fn bit_repr(&self) -> Option { + Some(self.0.to_bit_repr()) } - fn bit_repr_large(&self) -> UInt64Chunked { - self.0.bit_repr_large() +} + +impl private::PrivateSeriesNumeric for SeriesWrap { + fn bit_repr(&self) -> Option { + None + } +} +impl private::PrivateSeriesNumeric for SeriesWrap { + fn bit_repr(&self) -> Option { + None } - fn bit_repr_small(&self) -> UInt32Chunked { - self.0.bit_repr_small() +} +impl private::PrivateSeriesNumeric for SeriesWrap { + fn bit_repr(&self) -> Option { + None + } +} +impl private::PrivateSeriesNumeric for SeriesWrap { + fn bit_repr(&self) -> Option { + None } } - -impl private::PrivateSeriesNumeric for SeriesWrap {} -impl private::PrivateSeriesNumeric for SeriesWrap {} -impl private::PrivateSeriesNumeric for SeriesWrap {} -impl private::PrivateSeriesNumeric for SeriesWrap {} #[cfg(feature = "dtype-array")] -impl private::PrivateSeriesNumeric for SeriesWrap {} -impl private::PrivateSeriesNumeric for SeriesWrap { - fn bit_repr_is_large(&self) -> bool { - false +impl private::PrivateSeriesNumeric for SeriesWrap { + fn bit_repr(&self) -> Option { + None } - fn bit_repr_small(&self) -> UInt32Chunked { - self.0 +} +impl private::PrivateSeriesNumeric for SeriesWrap { + fn bit_repr(&self) -> Option { + let repr = self + .0 .cast_with_options(&DataType::UInt32, CastOptions::NonStrict) .unwrap() .u32() .unwrap() - .clone() + .clone(); + + Some(BitRepr::Small(repr)) } } diff --git a/crates/polars-core/src/series/implementations/null.rs b/crates/polars-core/src/series/implementations/null.rs index 564a8f93669d..ddf5f746cb18 100644 --- a/crates/polars-core/src/series/implementations/null.rs +++ b/crates/polars-core/src/series/implementations/null.rs @@ -36,8 +36,11 @@ impl NullChunked { } } impl PrivateSeriesNumeric for NullChunked { - fn bit_repr_small(&self) -> UInt32Chunked { - UInt32Chunked::full_null(self.name.as_ref(), self.len()) + fn bit_repr(&self) -> Option { + Some(BitRepr::Small(UInt32Chunked::full_null( + self.name.as_ref(), + self.len(), + ))) } } diff --git a/crates/polars-core/src/series/implementations/object.rs b/crates/polars-core/src/series/implementations/object.rs index e9f80a82f4e7..0784098e9460 100644 --- a/crates/polars-core/src/series/implementations/object.rs +++ b/crates/polars-core/src/series/implementations/object.rs @@ -3,7 +3,7 @@ use std::borrow::Cow; use ahash::RandomState; -use super::MetadataFlags; +use super::{BitRepr, MetadataFlags}; use crate::chunked_array::cast::CastOptions; use crate::chunked_array::object::PolarsObjectSafe; use crate::chunked_array::ops::compare_inner::{IntoTotalEqInner, TotalEqInner}; @@ -11,7 +11,11 @@ use crate::prelude::*; use crate::series::implementations::SeriesWrap; use crate::series::private::{PrivateSeries, PrivateSeriesNumeric}; -impl PrivateSeriesNumeric for SeriesWrap> {} +impl PrivateSeriesNumeric for SeriesWrap> { + fn bit_repr(&self) -> Option { + None + } +} impl PrivateSeries for SeriesWrap> where diff --git a/crates/polars-core/src/series/implementations/struct_.rs b/crates/polars-core/src/series/implementations/struct_.rs index d9ed03948fce..9a50bc9a7364 100644 --- a/crates/polars-core/src/series/implementations/struct_.rs +++ b/crates/polars-core/src/series/implementations/struct_.rs @@ -9,7 +9,11 @@ unsafe impl IntoSeries for StructChunked { } } -impl PrivateSeriesNumeric for SeriesWrap {} +impl PrivateSeriesNumeric for SeriesWrap { + fn bit_repr(&self) -> Option { + None + } +} impl private::PrivateSeries for SeriesWrap { fn compute_len(&mut self) { diff --git a/crates/polars-core/src/series/implementations/time.rs b/crates/polars-core/src/series/implementations/time.rs index 7e39af3f271a..159f07a48426 100644 --- a/crates/polars-core/src/series/implementations/time.rs +++ b/crates/polars-core/src/series/implementations/time.rs @@ -314,15 +314,7 @@ impl SeriesTrait for SeriesWrap { } impl private::PrivateSeriesNumeric for SeriesWrap { - fn bit_repr_is_large(&self) -> bool { - true - } - - fn bit_repr_large(&self) -> UInt64Chunked { - self.0.bit_repr_large() - } - - fn bit_repr_small(&self) -> UInt32Chunked { - self.0.bit_repr_small() + fn bit_repr(&self) -> Option { + Some(self.0.to_bit_repr()) } } diff --git a/crates/polars-core/src/series/series_trait.rs b/crates/polars-core/src/series/series_trait.rs index 409ed388b589..16faebe0012b 100644 --- a/crates/polars-core/src/series/series_trait.rs +++ b/crates/polars-core/src/series/series_trait.rs @@ -39,6 +39,11 @@ macro_rules! invalid_operation_panic { }; } +pub enum BitRepr { + Small(UInt32Chunked), + Large(UInt64Chunked), +} + pub(crate) mod private { use ahash::RandomState; @@ -47,15 +52,10 @@ pub(crate) mod private { use crate::chunked_array::ops::compare_inner::{TotalEqInner, TotalOrdInner}; pub trait PrivateSeriesNumeric { - fn bit_repr_is_large(&self) -> bool { - false - } - fn bit_repr_large(&self) -> UInt64Chunked { - unimplemented!() - } - fn bit_repr_small(&self) -> UInt32Chunked { - unimplemented!() - } + /// Return a bit representation + /// + /// If there is no available bit representation this returns `None`. + fn bit_repr(&self) -> Option; } pub trait PrivateSeries { diff --git a/crates/polars-ops/src/chunked_array/list/hash.rs b/crates/polars-ops/src/chunked_array/list/hash.rs index fe00dcdceeb6..67cb61a51273 100644 --- a/crates/polars-ops/src/chunked_array/list/hash.rs +++ b/crates/polars-ops/src/chunked_array/list/hash.rs @@ -2,6 +2,7 @@ use std::hash::Hash; use polars_core::export::_boost_hash_combine; use polars_core::export::rayon::prelude::*; +use polars_core::series::BitRepr; use polars_core::utils::NoNull; use polars_core::{with_match_physical_float_polars_type, POOL}; use polars_utils::total_ord::{ToTotalOrd, TotalHash}; @@ -66,12 +67,12 @@ pub(crate) fn hash(ca: &mut ListChunked, build_hasher: ahash::RandomState) -> UI let ca: &ChunkedArray<$T> = s.as_ref().as_ref().as_ref(); hash_agg(ca, &build_hasher) }) - } else if s.bit_repr_is_large() { - let ca = s.bit_repr_large(); - hash_agg(&ca, &build_hasher) } else { - let ca = s.bit_repr_small(); - hash_agg(&ca, &build_hasher) + match s.bit_repr() { + None => unimplemented!("Hash for lists without bit representation"), + Some(BitRepr::Small(ca)) => hash_agg(&ca, &build_hasher), + Some(BitRepr::Large(ca)) => hash_agg(&ca, &build_hasher), + } } }, }) diff --git a/crates/polars-ops/src/frame/join/asof/groups.rs b/crates/polars-ops/src/frame/join/asof/groups.rs index 918d95596dee..4547fbe4f141 100644 --- a/crates/polars-ops/src/frame/join/asof/groups.rs +++ b/crates/polars-ops/src/frame/join/asof/groups.rs @@ -8,6 +8,7 @@ use polars_core::hashing::{ _HASHMAP_INIT_SIZE, }; use polars_core::prelude::*; +use polars_core::series::BitRepr; use polars_core::utils::flatten::flatten_nullable; use polars_core::utils::{_set_partition_size, split_and_flatten}; use polars_core::{with_match_physical_float_polars_type, IdBuildHasher, POOL}; @@ -402,7 +403,7 @@ where let left_dtype = left_by_s.dtype(); let right_dtype = right_by_s.dtype(); polars_ensure!(left_dtype == right_dtype, - ComputeError: "mismatching dtypes in 'by' parameter of asof-join: `{}` and `{}`", left_dtype, right_dtype + ComputeError: "mismatching dtypes in 'by' parameter of asof-join: `{left_dtype}` and `{right_dtype}`", ); match left_dtype { DataType::String => { @@ -415,27 +416,37 @@ where let right_by = right_by_s.binary().unwrap(); asof_join_by_binary::(left_by, right_by, left_asof, right_asof, filter) }, + x if x.is_float() => { + with_match_physical_float_polars_type!(left_by_s.dtype(), |$T| { + let left_by: &ChunkedArray<$T> = left_by_s.as_ref().as_ref().as_ref(); + let right_by: &ChunkedArray<$T> = right_by_s.as_ref().as_ref().as_ref(); + asof_join_by_numeric::( + left_by, right_by, left_asof, right_asof, filter, + )? + }) + }, _ => { - if left_by_s.dtype().is_float() { - with_match_physical_float_polars_type!(left_by_s.dtype(), |$T| { - let left_by: &ChunkedArray<$T> = left_by_s.as_ref().as_ref().as_ref(); - let right_by: &ChunkedArray<$T> = right_by_s.as_ref().as_ref().as_ref(); - asof_join_by_numeric::( - left_by, right_by, left_asof, right_asof, filter, + let left_by = left_by_s.bit_repr(); + let right_by = right_by_s.bit_repr(); + + let (Some(left_by), Some(right_by)) = (left_by, right_by) else { + polars_bail!(nyi = "Dispatch join for {left_dtype} and {right_dtype}"); + }; + + use BitRepr as B; + match (left_by, right_by) { + (B::Small(left_by), B::Small(right_by)) => { + asof_join_by_numeric::( + &left_by, &right_by, left_asof, right_asof, filter, )? - }) - } else if left_by_s.bit_repr_is_large() { - let left_by = left_by_s.bit_repr_large(); - let right_by = right_by_s.bit_repr_large(); - asof_join_by_numeric::( - &left_by, &right_by, left_asof, right_asof, filter, - )? - } else { - let left_by = left_by_s.bit_repr_small(); - let right_by = right_by_s.bit_repr_small(); - asof_join_by_numeric::( - &left_by, &right_by, left_asof, right_asof, filter, - )? + }, + (B::Large(left_by), B::Large(right_by)) => { + asof_join_by_numeric::( + &left_by, &right_by, left_asof, right_asof, filter, + )? + }, + // We have already asserted that the datatypes are the same. + _ => unreachable!(), } }, } diff --git a/crates/polars-ops/src/frame/join/hash_join/mod.rs b/crates/polars-ops/src/frame/join/hash_join/mod.rs index 4e939e56c0c4..dd970d523757 100644 --- a/crates/polars-ops/src/frame/join/hash_join/mod.rs +++ b/crates/polars-ops/src/frame/join/hash_join/mod.rs @@ -237,7 +237,7 @@ pub trait JoinDispatch: IntoDf { #[cfg(feature = "dtype-categorical")] _check_categorical_src(s_left.dtype(), s_right.dtype())?; - let idx = s_left.hash_join_semi_anti(s_right, anti, join_nulls); + let idx = s_left.hash_join_semi_anti(s_right, anti, join_nulls)?; // SAFETY: // indices are in bounds Ok(unsafe { ca_self._finish_anti_semi_join(&idx, slice) }) diff --git a/crates/polars-ops/src/frame/join/hash_join/single_keys_dispatch.rs b/crates/polars-ops/src/frame/join/hash_join/single_keys_dispatch.rs index bf483c194d72..e399880fb47c 100644 --- a/crates/polars-ops/src/frame/join/hash_join/single_keys_dispatch.rs +++ b/crates/polars-ops/src/frame/join/hash_join/single_keys_dispatch.rs @@ -1,4 +1,5 @@ use arrow::array::PrimitiveArray; +use polars_core::series::BitRepr; use polars_core::utils::split; use polars_core::with_match_physical_float_polars_type; use polars_utils::hashing::DirtyHash; @@ -20,12 +21,14 @@ pub trait SeriesJoin: SeriesSealed + Sized { let (lhs, rhs) = (s_self.to_physical_repr(), other.to_physical_repr()); validate.validate_probe(&lhs, &rhs, false)?; - use DataType::*; + let lhs_dtype = lhs.dtype(); + let rhs_dtype = rhs.dtype(); - match lhs.dtype() { - String | Binary => { - let lhs = lhs.cast(&Binary).unwrap(); - let rhs = rhs.cast(&Binary).unwrap(); + use DataType as T; + match lhs_dtype { + T::String | T::Binary => { + let lhs = lhs.cast(&T::Binary).unwrap(); + let rhs = rhs.cast(&T::Binary).unwrap(); let lhs = lhs.binary().unwrap(); let rhs = rhs.binary().unwrap(); let (lhs, rhs, _, _) = prepare_binary::(lhs, rhs, false); @@ -33,7 +36,7 @@ pub trait SeriesJoin: SeriesSealed + Sized { let rhs = rhs.iter().map(|v| v.as_slice()).collect::>(); hash_join_tuples_left(lhs, rhs, None, None, validate, join_nulls) }, - BinaryOffset => { + T::BinaryOffset => { let lhs = lhs.binary_offset().unwrap(); let rhs = rhs.binary_offset().unwrap(); let (lhs, rhs, _, _) = prepare_binary::(lhs, rhs, false); @@ -42,37 +45,57 @@ pub trait SeriesJoin: SeriesSealed + Sized { let rhs = rhs.iter().map(|k| k.as_slice()).collect::>(); hash_join_tuples_left(lhs, rhs, None, None, validate, join_nulls) }, + x if x.is_float() => { + with_match_physical_float_polars_type!(lhs.dtype(), |$T| { + let lhs: &ChunkedArray<$T> = lhs.as_ref().as_ref().as_ref(); + let rhs: &ChunkedArray<$T> = rhs.as_ref().as_ref().as_ref(); + num_group_join_left(lhs, rhs, validate, join_nulls) + }) + }, _ => { - if lhs.dtype().is_float() { - with_match_physical_float_polars_type!(lhs.dtype(), |$T| { - let lhs: &ChunkedArray<$T> = lhs.as_ref().as_ref().as_ref(); - let rhs: &ChunkedArray<$T> = rhs.as_ref().as_ref().as_ref(); - num_group_join_left(lhs, rhs, validate, join_nulls) - }) - } else if s_self.bit_repr_is_large() { - let lhs = lhs.bit_repr_large(); - let rhs = rhs.bit_repr_large(); - num_group_join_left(&lhs, &rhs, validate, join_nulls) - } else { - let lhs = lhs.bit_repr_small(); - let rhs = rhs.bit_repr_small(); - num_group_join_left(&lhs, &rhs, validate, join_nulls) + let lhs = s_self.bit_repr(); + let rhs = other.bit_repr(); + + let (Some(lhs), Some(rhs)) = (lhs, rhs) else { + polars_bail!(nyi = "Hash Left Join between {lhs_dtype} and {rhs_dtype}"); + }; + + use BitRepr as B; + match (lhs, rhs) { + (B::Small(lhs), B::Small(rhs)) => { + num_group_join_left(&lhs, &rhs, validate, join_nulls) + }, + (B::Large(lhs), B::Large(rhs)) => { + num_group_join_left(&lhs, &rhs, validate, join_nulls) + }, + _ => { + polars_bail!( + nyi = "Mismatch bit repr Hash Left Join between {lhs_dtype} and {rhs_dtype}", + ); + }, } }, } } #[cfg(feature = "semi_anti_join")] - fn hash_join_semi_anti(&self, other: &Series, anti: bool, join_nulls: bool) -> Vec { + fn hash_join_semi_anti( + &self, + other: &Series, + anti: bool, + join_nulls: bool, + ) -> PolarsResult> { let s_self = self.as_series(); let (lhs, rhs) = (s_self.to_physical_repr(), other.to_physical_repr()); - use DataType::*; + let lhs_dtype = lhs.dtype(); + let rhs_dtype = rhs.dtype(); - match lhs.dtype() { - String | Binary => { - let lhs = lhs.cast(&Binary).unwrap(); - let rhs = rhs.cast(&Binary).unwrap(); + use DataType as T; + Ok(match lhs_dtype { + T::String | T::Binary => { + let lhs = lhs.cast(&T::Binary).unwrap(); + let rhs = rhs.cast(&T::Binary).unwrap(); let lhs = lhs.binary().unwrap(); let rhs = rhs.binary().unwrap(); let (lhs, rhs, _, _) = prepare_binary::(lhs, rhs, false); @@ -85,7 +108,7 @@ pub trait SeriesJoin: SeriesSealed + Sized { hash_join_tuples_left_semi(lhs, rhs, join_nulls) } }, - BinaryOffset => { + T::BinaryOffset => { let lhs = lhs.binary_offset().unwrap(); let rhs = rhs.binary_offset().unwrap(); let (lhs, rhs, _, _) = prepare_binary::(lhs, rhs, false); @@ -98,24 +121,37 @@ pub trait SeriesJoin: SeriesSealed + Sized { hash_join_tuples_left_semi(lhs, rhs, join_nulls) } }, + x if x.is_float() => { + with_match_physical_float_polars_type!(lhs.dtype(), |$T| { + let lhs: &ChunkedArray<$T> = lhs.as_ref().as_ref().as_ref(); + let rhs: &ChunkedArray<$T> = rhs.as_ref().as_ref().as_ref(); + num_group_join_anti_semi(lhs, rhs, anti, join_nulls) + }) + }, _ => { - if lhs.dtype().is_float() { - with_match_physical_float_polars_type!(lhs.dtype(), |$T| { - let lhs: &ChunkedArray<$T> = lhs.as_ref().as_ref().as_ref(); - let rhs: &ChunkedArray<$T> = rhs.as_ref().as_ref().as_ref(); - num_group_join_anti_semi(lhs, rhs, anti, join_nulls) - }) - } else if s_self.bit_repr_is_large() { - let lhs = lhs.bit_repr_large(); - let rhs = rhs.bit_repr_large(); - num_group_join_anti_semi(&lhs, &rhs, anti, join_nulls) - } else { - let lhs = lhs.bit_repr_small(); - let rhs = rhs.bit_repr_small(); - num_group_join_anti_semi(&lhs, &rhs, anti, join_nulls) + let lhs = s_self.bit_repr(); + let rhs = other.bit_repr(); + + let (Some(lhs), Some(rhs)) = (lhs, rhs) else { + polars_bail!(nyi = "Hash Semi-Anti Join between {lhs_dtype} and {rhs_dtype}"); + }; + + use BitRepr as B; + match (lhs, rhs) { + (B::Small(lhs), B::Small(rhs)) => { + num_group_join_anti_semi(&lhs, &rhs, anti, join_nulls) + }, + (B::Large(lhs), B::Large(rhs)) => { + num_group_join_anti_semi(&lhs, &rhs, anti, join_nulls) + }, + _ => { + polars_bail!( + nyi = "Mismatch bit repr Hash Semi-Anti Join between {lhs_dtype} and {rhs_dtype}", + ); + }, } }, - } + }) } // returns the join tuples and whether or not the lhs tuples are sorted @@ -129,11 +165,14 @@ pub trait SeriesJoin: SeriesSealed + Sized { let (lhs, rhs) = (s_self.to_physical_repr(), other.to_physical_repr()); validate.validate_probe(&lhs, &rhs, true)?; - use DataType::*; - match lhs.dtype() { - String | Binary => { - let lhs = lhs.cast(&Binary).unwrap(); - let rhs = rhs.cast(&Binary).unwrap(); + let lhs_dtype = lhs.dtype(); + let rhs_dtype = rhs.dtype(); + + use DataType as T; + match lhs_dtype { + T::String | T::Binary => { + let lhs = lhs.cast(&T::Binary).unwrap(); + let rhs = rhs.cast(&T::Binary).unwrap(); let lhs = lhs.binary().unwrap(); let rhs = rhs.binary().unwrap(); let (lhs, rhs, swapped, _) = prepare_binary::(lhs, rhs, true); @@ -145,7 +184,7 @@ pub trait SeriesJoin: SeriesSealed + Sized { !swapped, )) }, - BinaryOffset => { + T::BinaryOffset => { let lhs = lhs.binary_offset().unwrap(); let rhs = rhs.binary_offset()?; let (lhs, rhs, swapped, _) = prepare_binary::(lhs, rhs, true); @@ -157,21 +196,34 @@ pub trait SeriesJoin: SeriesSealed + Sized { !swapped, )) }, + x if x.is_float() => { + with_match_physical_float_polars_type!(lhs.dtype(), |$T| { + let lhs: &ChunkedArray<$T> = lhs.as_ref().as_ref().as_ref(); + let rhs: &ChunkedArray<$T> = rhs.as_ref().as_ref().as_ref(); + group_join_inner::<$T>(lhs, rhs, validate, join_nulls) + }) + }, _ => { - if lhs.dtype().is_float() { - with_match_physical_float_polars_type!(lhs.dtype(), |$T| { - let lhs: &ChunkedArray<$T> = lhs.as_ref().as_ref().as_ref(); - let rhs: &ChunkedArray<$T> = rhs.as_ref().as_ref().as_ref(); - group_join_inner::<$T>(lhs, rhs, validate, join_nulls) - }) - } else if s_self.bit_repr_is_large() { - let lhs = s_self.bit_repr_large(); - let rhs = other.bit_repr_large(); - group_join_inner::(&lhs, &rhs, validate, join_nulls) - } else { - let lhs = s_self.bit_repr_small(); - let rhs = other.bit_repr_small(); - group_join_inner::(&lhs, &rhs, validate, join_nulls) + let lhs = s_self.bit_repr(); + let rhs = other.bit_repr(); + + let (Some(lhs), Some(rhs)) = (lhs, rhs) else { + polars_bail!(nyi = "Hash Inner Join between {lhs_dtype} and {rhs_dtype}"); + }; + + use BitRepr as B; + match (lhs, rhs) { + (B::Small(lhs), B::Small(rhs)) => { + group_join_inner(&lhs, &rhs, validate, join_nulls) + }, + (B::Large(lhs), BitRepr::Large(rhs)) => { + group_join_inner(&lhs, &rhs, validate, join_nulls) + }, + _ => { + polars_bail!( + nyi = "Mismatch bit repr Hash Inner Join between {lhs_dtype} and {rhs_dtype}" + ); + }, } }, } @@ -187,11 +239,14 @@ pub trait SeriesJoin: SeriesSealed + Sized { let (lhs, rhs) = (s_self.to_physical_repr(), other.to_physical_repr()); validate.validate_probe(&lhs, &rhs, true)?; - use DataType::*; - match lhs.dtype() { - String | Binary => { - let lhs = lhs.cast(&Binary).unwrap(); - let rhs = rhs.cast(&Binary).unwrap(); + let lhs_dtype = lhs.dtype(); + let rhs_dtype = rhs.dtype(); + + use DataType as T; + match lhs_dtype { + T::String | T::Binary => { + let lhs = lhs.cast(&T::Binary).unwrap(); + let rhs = rhs.cast(&T::Binary).unwrap(); let lhs = lhs.binary().unwrap(); let rhs = rhs.binary().unwrap(); let (lhs, rhs, swapped, _) = prepare_binary::(lhs, rhs, true); @@ -200,7 +255,7 @@ pub trait SeriesJoin: SeriesSealed + Sized { let rhs = rhs.iter().map(|k| k.as_slice()).collect::>(); hash_join_tuples_outer(lhs, rhs, swapped, validate, join_nulls) }, - BinaryOffset => { + T::BinaryOffset => { let lhs = lhs.binary_offset().unwrap(); let rhs = rhs.binary_offset()?; let (lhs, rhs, swapped, _) = prepare_binary::(lhs, rhs, true); @@ -209,21 +264,29 @@ pub trait SeriesJoin: SeriesSealed + Sized { let rhs = rhs.iter().map(|k| k.as_slice()).collect::>(); hash_join_tuples_outer(lhs, rhs, swapped, validate, join_nulls) }, + x if x.is_float() => { + with_match_physical_float_polars_type!(lhs.dtype(), |$T| { + let lhs: &ChunkedArray<$T> = lhs.as_ref().as_ref().as_ref(); + let rhs: &ChunkedArray<$T> = rhs.as_ref().as_ref().as_ref(); + hash_join_outer(lhs, rhs, validate, join_nulls) + }) + }, _ => { - if lhs.dtype().is_float() { - with_match_physical_float_polars_type!(lhs.dtype(), |$T| { - let lhs: &ChunkedArray<$T> = lhs.as_ref().as_ref().as_ref(); - let rhs: &ChunkedArray<$T> = rhs.as_ref().as_ref().as_ref(); - hash_join_outer(lhs, rhs, validate, join_nulls) - }) - } else if s_self.bit_repr_is_large() { - let lhs = s_self.bit_repr_large(); - let rhs = other.bit_repr_large(); - hash_join_outer(&lhs, &rhs, validate, join_nulls) - } else { - let lhs = s_self.bit_repr_small(); - let rhs = other.bit_repr_small(); - hash_join_outer(&lhs, &rhs, validate, join_nulls) + let (Some(lhs), Some(rhs)) = (s_self.bit_repr(), other.bit_repr()) else { + polars_bail!(nyi = "Hash Join Outer between {lhs_dtype} and {rhs_dtype}"); + }; + + use BitRepr as B; + match (lhs, rhs) { + (B::Small(lhs), B::Small(rhs)) => { + hash_join_outer(&lhs, &rhs, validate, join_nulls) + }, + (B::Large(lhs), B::Large(rhs)) => { + hash_join_outer(&lhs, &rhs, validate, join_nulls) + }, + _ => { + polars_bail!(nyi = "Mismatch bit repr Hash Join Outer between {lhs_dtype} and {rhs_dtype}"); + }, } }, } diff --git a/crates/polars-ops/src/frame/pivot/positioning.rs b/crates/polars-ops/src/frame/pivot/positioning.rs index 5ad0b32f101d..affdf559e02a 100644 --- a/crates/polars-ops/src/frame/pivot/positioning.rs +++ b/crates/polars-ops/src/frame/pivot/positioning.rs @@ -2,6 +2,7 @@ use std::hash::Hash; use arrow::legacy::trusted_len::TrustedLenPush; use polars_core::prelude::*; +use polars_core::series::BitRepr; use polars_utils::sync::SyncPtr; use polars_utils::total_ord::{ToTotalOrd, TotalEq, TotalHash}; @@ -231,39 +232,43 @@ pub(super) fn compute_col_idx( let column_agg = unsafe { column_s.agg_first(groups) }; let column_agg_physical = column_agg.to_physical_repr(); - use DataType::*; + use DataType as T; let col_locations = match column_agg_physical.dtype() { - Int32 | UInt32 => { - let ca = column_agg_physical.bit_repr_small(); + T::Int32 | T::UInt32 => { + let Some(BitRepr::Small(ca)) = column_agg_physical.bit_repr() else { + polars_bail!(ComputeError: "Expected 32-bit bit representation to be available. This should never happen"); + }; compute_col_idx_numeric(&ca) }, - Int64 | UInt64 => { - let ca = column_agg_physical.bit_repr_large(); + T::Int64 | T::UInt64 => { + let Some(BitRepr::Large(ca)) = column_agg_physical.bit_repr() else { + polars_bail!(ComputeError: "Expected 64-bit bit representation to be available. This should never happen"); + }; compute_col_idx_numeric(&ca) }, - Float64 => { + T::Float64 => { let ca: &ChunkedArray = column_agg_physical.as_ref().as_ref().as_ref(); compute_col_idx_numeric(ca) }, - Float32 => { + T::Float32 => { let ca: &ChunkedArray = column_agg_physical.as_ref().as_ref().as_ref(); compute_col_idx_numeric(ca) }, - Struct(_) => { + T::Struct(_) => { let ca = column_agg_physical.struct_().unwrap(); let ca = ca.rows_encode()?; compute_col_idx_gen(&ca) }, - String => { + T::String => { let ca = column_agg_physical.str().unwrap(); let ca = ca.as_binary(); compute_col_idx_gen(&ca) }, - Binary => { + T::Binary => { let ca = column_agg_physical.binary().unwrap(); compute_col_idx_gen(ca) }, - Boolean => { + T::Boolean => { let ca = column_agg_physical.bool().unwrap(); compute_col_idx_gen(ca) }, @@ -393,34 +398,38 @@ pub(super) fn compute_row_idx( let index_agg = unsafe { index_s.agg_first(groups) }; let index_agg_physical = index_agg.to_physical_repr(); - use DataType::*; + use DataType as T; match index_agg_physical.dtype() { - Int32 | UInt32 => { - let ca = index_agg_physical.bit_repr_small(); + T::Int32 | T::UInt32 => { + let Some(BitRepr::Small(ca)) = index_agg_physical.bit_repr() else { + polars_bail!(ComputeError: "Expected 32-bit bit representation to be available. This should never happen"); + }; compute_row_index(index, &ca, count, index_s.dtype()) }, - Int64 | UInt64 => { - let ca = index_agg_physical.bit_repr_large(); + T::Int64 | T::UInt64 => { + let Some(BitRepr::Large(ca)) = index_agg_physical.bit_repr() else { + polars_bail!(ComputeError: "Expected 64-bit bit representation to be available. This should never happen"); + }; compute_row_index(index, &ca, count, index_s.dtype()) }, - Float64 => { + T::Float64 => { let ca: &ChunkedArray = index_agg_physical.as_ref().as_ref().as_ref(); compute_row_index(index, ca, count, index_s.dtype()) }, - Float32 => { + T::Float32 => { let ca: &ChunkedArray = index_agg_physical.as_ref().as_ref().as_ref(); compute_row_index(index, ca, count, index_s.dtype()) }, - Boolean => { + T::Boolean => { let ca = index_agg_physical.bool().unwrap(); compute_row_index(index, ca, count, index_s.dtype()) }, - Struct(_) => { + T::Struct(_) => { let ca = index_agg_physical.struct_().unwrap(); let ca = ca.rows_encode()?; compute_row_index_struct(index, &index_agg, &ca, count) }, - String => { + T::String => { let ca = index_agg_physical.str().unwrap(); compute_row_index(index, ca, count, index_s.dtype()) },