diff --git a/Cargo.toml b/Cargo.toml index bf71d65c89d9..a93e78953023 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -118,7 +118,6 @@ features = [ "compute_boolean_kleene", "compute_cast", "compute_comparison", - "compute_if_then_else", ] [patch.crates-io] diff --git a/crates/polars-arrow/Cargo.toml b/crates/polars-arrow/Cargo.toml index 5f7ce63ad038..95f4cfb60c3a 100644 --- a/crates/polars-arrow/Cargo.toml +++ b/crates/polars-arrow/Cargo.toml @@ -140,7 +140,6 @@ compute_boolean_kleene = [] compute_cast = ["compute_take", "ryu", "atoi_simd", "itoa", "fast-float"] compute_comparison = ["compute_take", "compute_boolean"] compute_hash = ["multiversion"] -compute_if_then_else = [] compute_take = [] compute_temporal = [] compute = [ @@ -152,7 +151,6 @@ compute = [ "compute_cast", "compute_comparison", "compute_hash", - "compute_if_then_else", "compute_take", "compute_temporal", ] diff --git a/crates/polars-arrow/src/array/binview/view.rs b/crates/polars-arrow/src/array/binview/view.rs index 4c480afdc5c5..4975930ee744 100644 --- a/crates/polars-arrow/src/array/binview/view.rs +++ b/crates/polars-arrow/src/array/binview/view.rs @@ -34,6 +34,29 @@ impl View { pub fn as_u128(self) -> u128 { unsafe { std::mem::transmute(self) } } + + #[inline] + pub fn new_from_bytes(bytes: &[u8], buffer_idx: u32, offset: u32) -> Self { + if bytes.len() <= 12 { + let mut ret = Self { + length: bytes.len() as u32, + ..Default::default() + }; + let ret_ptr = &mut ret as *mut _ as *mut u8; + unsafe { + core::ptr::copy_nonoverlapping(bytes.as_ptr(), ret_ptr.add(4), bytes.len()); + } + ret + } else { + let prefix_buf: [u8; 4] = std::array::from_fn(|i| *bytes.get(i).unwrap_or(&0)); + Self { + length: bytes.len() as u32, + prefix: u32::from_le_bytes(prefix_buf), + buffer_idx, + offset, + } + } + } } impl IsNull for View { diff --git a/crates/polars-arrow/src/array/growable/binview.rs b/crates/polars-arrow/src/array/growable/binview.rs index affcb472cbec..803f99cde2b8 100644 --- a/crates/polars-arrow/src/array/growable/binview.rs +++ b/crates/polars-arrow/src/array/growable/binview.rs @@ -7,7 +7,7 @@ use polars_utils::unwrap::UnwrapUncheckedRelease; use super::Growable; use crate::array::binview::{BinaryViewArrayGeneric, View, ViewType}; -use crate::array::growable::utils::{extend_validity, prepare_validity}; +use crate::array::growable::utils::{extend_validity, extend_validity_copies, prepare_validity}; use crate::array::Array; use crate::bitmap::MutableBitmap; use crate::buffer::Buffer; @@ -166,6 +166,22 @@ impl<'a, T: ViewType + ?Sized> Growable<'a> for GrowableBinaryViewArray<'a, T> { unsafe { self.extend_unchecked(index, start, len) } } + unsafe fn extend_copies(&mut self, index: usize, start: usize, len: usize, copies: usize) { + let orig_view_start = self.views.len(); + if copies > 0 { + unsafe { self.extend_unchecked(index, start, len) } + } + if copies > 1 { + let array = *self.arrays.get_unchecked(index); + extend_validity_copies(&mut self.validity, array, start, len, copies - 1); + let extended_view_end = self.views.len(); + for _ in 0..copies - 1 { + self.views + .extend_from_within(orig_view_start..extended_view_end) + } + } + } + fn extend_validity(&mut self, additional: usize) { self.views .extend(std::iter::repeat(View::default()).take(additional)); diff --git a/crates/polars-arrow/src/array/growable/fixed_size_list.rs b/crates/polars-arrow/src/array/growable/fixed_size_list.rs index 8226f1867b68..1841285f377d 100644 --- a/crates/polars-arrow/src/array/growable/fixed_size_list.rs +++ b/crates/polars-arrow/src/array/growable/fixed_size_list.rs @@ -3,7 +3,7 @@ use std::sync::Arc; use polars_utils::slice::GetSaferUnchecked; use super::{make_growable, Growable}; -use crate::array::growable::utils::{extend_validity, prepare_validity}; +use crate::array::growable::utils::{extend_validity, extend_validity_copies, prepare_validity}; use crate::array::{Array, FixedSizeListArray}; use crate::bitmap::MutableBitmap; use crate::datatypes::ArrowDataType; @@ -55,7 +55,7 @@ impl<'a> GrowableFixedSizeList<'a> { } } - fn to(&mut self) -> FixedSizeListArray { + pub fn to(&mut self) -> FixedSizeListArray { let validity = std::mem::take(&mut self.validity); let values = self.values.as_box(); @@ -76,6 +76,14 @@ impl<'a> Growable<'a> for GrowableFixedSizeList<'a> { .extend(index, start * self.size, len * self.size); } + unsafe fn extend_copies(&mut self, index: usize, start: usize, len: usize, copies: usize) { + let array = *self.arrays.get_unchecked_release(index); + extend_validity_copies(&mut self.validity, array, start, len, copies); + + self.values + .extend_copies(index, start * self.size, len * self.size, copies); + } + fn extend_validity(&mut self, additional: usize) { self.values.extend_validity(additional * self.size); if let Some(validity) = &mut self.validity { diff --git a/crates/polars-arrow/src/array/growable/list.rs b/crates/polars-arrow/src/array/growable/list.rs index 30aa1a2d2c7f..a97518a310e3 100644 --- a/crates/polars-arrow/src/array/growable/list.rs +++ b/crates/polars-arrow/src/array/growable/list.rs @@ -64,7 +64,7 @@ impl<'a, O: Offset> GrowableList<'a, O> { } } - fn to(&mut self) -> ListArray { + pub fn to(&mut self) -> ListArray { let validity = std::mem::take(&mut self.validity); let offsets = std::mem::take(&mut self.offsets); let values = self.values.as_box(); diff --git a/crates/polars-arrow/src/array/growable/mod.rs b/crates/polars-arrow/src/array/growable/mod.rs index ca8fc87a5a86..ada21b71c121 100644 --- a/crates/polars-arrow/src/array/growable/mod.rs +++ b/crates/polars-arrow/src/array/growable/mod.rs @@ -37,9 +37,19 @@ pub trait Growable<'a> { /// a slice starting at `start` and length `len`. /// /// # Safety - /// Doesn't do any bound checks + /// Doesn't do any bound checks. unsafe fn extend(&mut self, index: usize, start: usize, len: usize); + /// Same as extend, except it repeats the extension `copies` times. + /// + /// # Safety + /// Doesn't do any bound checks. + unsafe fn extend_copies(&mut self, index: usize, start: usize, len: usize, copies: usize) { + for _ in 0..copies { + self.extend(index, start, len) + } + } + /// Extends this [`Growable`] with null elements, disregarding the bound arrays /// /// # Safety diff --git a/crates/polars-arrow/src/array/growable/primitive.rs b/crates/polars-arrow/src/array/growable/primitive.rs index 16f72cb868ee..936905ab05fa 100644 --- a/crates/polars-arrow/src/array/growable/primitive.rs +++ b/crates/polars-arrow/src/array/growable/primitive.rs @@ -3,7 +3,7 @@ use std::sync::Arc; use polars_utils::slice::GetSaferUnchecked; use super::Growable; -use crate::array::growable::utils::{extend_validity, prepare_validity}; +use crate::array::growable::utils::{extend_validity, extend_validity_copies, prepare_validity}; use crate::array::{Array, PrimitiveArray}; use crate::bitmap::MutableBitmap; use crate::datatypes::ArrowDataType; @@ -66,6 +66,19 @@ impl<'a, T: NativeType> Growable<'a> for GrowablePrimitive<'a, T> { .extend_from_slice(values.get_unchecked_release(start..start + len)); } + #[inline] + unsafe fn extend_copies(&mut self, index: usize, start: usize, len: usize, copies: usize) { + let array = *self.arrays.get_unchecked_release(index); + extend_validity_copies(&mut self.validity, array, start, len, copies); + + let values = array.values().as_slice(); + self.values.reserve(len * copies); + for _ in 0..copies { + self.values + .extend_from_slice(values.get_unchecked_release(start..start + len)); + } + } + #[inline] fn extend_validity(&mut self, additional: usize) { self.values diff --git a/crates/polars-arrow/src/array/growable/utils.rs b/crates/polars-arrow/src/array/growable/utils.rs index 7cb4b667a5c1..1b9f85e8c801 100644 --- a/crates/polars-arrow/src/array/growable/utils.rs +++ b/crates/polars-arrow/src/array/growable/utils.rs @@ -46,3 +46,27 @@ pub(super) fn extend_validity( } } } + +pub(super) fn extend_validity_copies( + mutable_validity: &mut Option, + array: &dyn Array, + start: usize, + len: usize, + copies: usize, +) { + if let Some(mutable_validity) = mutable_validity { + match array.validity() { + None => mutable_validity.extend_constant(len * copies, true), + Some(validity) => { + debug_assert!(start + len <= validity.len()); + let (slice, offset, _) = validity.as_slice(); + // SAFETY: invariant offset + length <= slice.len() + for _ in 0..copies { + unsafe { + mutable_validity.extend_from_slice_unchecked(slice, start + offset, len); + } + } + }, + } + } +} diff --git a/crates/polars-arrow/src/array/static_array.rs b/crates/polars-arrow/src/array/static_array.rs index ac8fbc4cec32..e9547f9201d2 100644 --- a/crates/polars-arrow/src/array/static_array.rs +++ b/crates/polars-arrow/src/array/static_array.rs @@ -4,8 +4,8 @@ use crate::array::binview::BinaryViewValueIter; use crate::array::static_array_collect::ArrayFromIterDtype; use crate::array::{ Array, ArrayValuesIter, BinaryArray, BinaryValueIter, BinaryViewArray, BooleanArray, - FixedSizeListArray, ListArray, ListValuesIter, PrimitiveArray, Utf8Array, Utf8ValuesIter, - Utf8ViewArray, + FixedSizeListArray, ListArray, ListValuesIter, MutableBinaryViewArray, PrimitiveArray, + Utf8Array, Utf8ValuesIter, Utf8ViewArray, }; use crate::bitmap::utils::{BitmapIter, ZipValidity}; use crate::bitmap::Bitmap; @@ -18,6 +18,7 @@ pub trait StaticArray: + for<'a> ArrayFromIterDtype> + for<'a> ArrayFromIterDtype> + for<'a> ArrayFromIterDtype>> + + Clone { type ValueT<'a>: Clone where @@ -82,6 +83,10 @@ pub trait StaticArray: } fn full_null(length: usize, dtype: ArrowDataType) -> Self; + + fn full(length: usize, value: Self::ValueT<'_>, dtype: ArrowDataType) -> Self { + Self::arr_from_iter_with_dtype(dtype, std::iter::repeat(value).take(length)) + } } pub trait ParameterFreeDtypeStaticArray: StaticArray { @@ -126,6 +131,10 @@ impl StaticArray for PrimitiveArray { fn full_null(length: usize, dtype: ArrowDataType) -> Self { Self::new_null(dtype, length) } + + fn full(length: usize, value: Self::ValueT<'_>, _dtype: ArrowDataType) -> Self { + PrimitiveArray::from_vec(vec![value; length]) + } } impl ParameterFreeDtypeStaticArray for PrimitiveArray { @@ -167,6 +176,10 @@ impl StaticArray for BooleanArray { fn full_null(length: usize, dtype: ArrowDataType) -> Self { Self::new_null(dtype, length) } + + fn full(length: usize, value: Self::ValueT<'_>, _dtype: ArrowDataType) -> Self { + Bitmap::new_with_value(value, length).into() + } } impl ParameterFreeDtypeStaticArray for BooleanArray { @@ -265,6 +278,12 @@ impl StaticArray for BinaryViewArray { fn full_null(length: usize, dtype: ArrowDataType) -> Self { Self::new_null(dtype, length) } + + fn full(length: usize, value: Self::ValueT<'_>, _dtype: ArrowDataType) -> Self { + let mut builder = MutableBinaryViewArray::with_capacity(length); + builder.extend_constant(length, Some(value)); + builder.into() + } } impl ParameterFreeDtypeStaticArray for BinaryViewArray { @@ -297,6 +316,13 @@ impl StaticArray for Utf8ViewArray { fn full_null(length: usize, dtype: ArrowDataType) -> Self { Self::new_null(dtype, length) } + + fn full(length: usize, value: Self::ValueT<'_>, _dtype: ArrowDataType) -> Self { + unsafe { + BinaryViewArray::full(length, value.as_bytes(), ArrowDataType::BinaryView) + .to_utf8view_unchecked() + } + } } impl ParameterFreeDtypeStaticArray for Utf8ViewArray { diff --git a/crates/polars-arrow/src/array/static_array_collect.rs b/crates/polars-arrow/src/array/static_array_collect.rs index 9413b0a16778..e2ecca37f7be 100644 --- a/crates/polars-arrow/src/array/static_array_collect.rs +++ b/crates/polars-arrow/src/array/static_array_collect.rs @@ -70,7 +70,13 @@ pub trait ArrayFromIter: Sized { impl> ArrayFromIterDtype for A { #[inline(always)] fn arr_from_iter_with_dtype>(dtype: ArrowDataType, iter: I) -> Self { - debug_assert!(std::mem::discriminant(&dtype) == std::mem::discriminant(&A::get_dtype())); + // FIXME: currently some Object arrays have Unknown dtype, when this is fixed remove this bypass. + if dtype != ArrowDataType::Unknown { + debug_assert_eq!( + std::mem::discriminant(&dtype), + std::mem::discriminant(&A::get_dtype()) + ); + } Self::arr_from_iter(iter) } @@ -80,7 +86,13 @@ impl> ArrayFromIterDtype< I: IntoIterator, I::IntoIter: TrustedLen, { - debug_assert!(std::mem::discriminant(&dtype) == std::mem::discriminant(&A::get_dtype())); + // FIXME: currently some Object arrays have Unknown dtype, when this is fixed remove this bypass. + if dtype != ArrowDataType::Unknown { + debug_assert_eq!( + std::mem::discriminant(&dtype), + std::mem::discriminant(&A::get_dtype()) + ); + } Self::arr_from_iter_trusted(iter) } @@ -89,7 +101,13 @@ impl> ArrayFromIterDtype< dtype: ArrowDataType, iter: I, ) -> Result { - debug_assert!(std::mem::discriminant(&dtype) == std::mem::discriminant(&A::get_dtype())); + // FIXME: currently some Object arrays have Unknown dtype, when this is fixed remove this bypass. + if dtype != ArrowDataType::Unknown { + debug_assert_eq!( + std::mem::discriminant(&dtype), + std::mem::discriminant(&A::get_dtype()) + ); + } Self::try_arr_from_iter(iter) } @@ -99,7 +117,13 @@ impl> ArrayFromIterDtype< I: IntoIterator>, I::IntoIter: TrustedLen, { - debug_assert!(std::mem::discriminant(&dtype) == std::mem::discriminant(&A::get_dtype())); + // FIXME: currently some Object arrays have Unknown dtype, when this is fixed remove this bypass. + if dtype != ArrowDataType::Unknown { + debug_assert_eq!( + std::mem::discriminant(&dtype), + std::mem::discriminant(&A::get_dtype()) + ); + } Self::try_arr_from_iter_trusted(iter) } } diff --git a/crates/polars-arrow/src/bitmap/bitmap_ops.rs b/crates/polars-arrow/src/bitmap/bitmap_ops.rs index c83e63255093..433fbbc93a71 100644 --- a/crates/polars-arrow/src/bitmap/bitmap_ops.rs +++ b/crates/polars-arrow/src/bitmap/bitmap_ops.rs @@ -160,8 +160,7 @@ pub(crate) fn align(bitmap: &Bitmap, new_offset: usize) -> Bitmap { bitmap.sliced(new_offset, length) } -#[inline] -/// Compute bitwise AND operation +/// Compute bitwise A AND B operation. pub fn and(lhs: &Bitmap, rhs: &Bitmap) -> Bitmap { if lhs.unset_bits() == lhs.len() || rhs.unset_bits() == rhs.len() { assert_eq!(lhs.len(), rhs.len()); @@ -171,8 +170,12 @@ pub fn and(lhs: &Bitmap, rhs: &Bitmap) -> Bitmap { } } -#[inline] -/// Compute bitwise OR operation +/// Compute bitwise A AND NOT B operation. +pub fn and_not(lhs: &Bitmap, rhs: &Bitmap) -> Bitmap { + binary(lhs, rhs, |x, y| x & !y) +} + +/// Compute bitwise A OR B operation. pub fn or(lhs: &Bitmap, rhs: &Bitmap) -> Bitmap { if lhs.unset_bits() == 0 || rhs.unset_bits() == 0 { assert_eq!(lhs.len(), rhs.len()); @@ -184,8 +187,12 @@ pub fn or(lhs: &Bitmap, rhs: &Bitmap) -> Bitmap { } } -#[inline] -/// Compute bitwise XOR operation +/// Compute bitwise A OR NOT B operation. +pub fn or_not(lhs: &Bitmap, rhs: &Bitmap) -> Bitmap { + binary(lhs, rhs, |x, y| x | !y) +} + +/// Compute bitwise XOR operation. pub fn xor(lhs: &Bitmap, rhs: &Bitmap) -> Bitmap { let lhs_nulls = lhs.unset_bits(); let rhs_nulls = rhs.unset_bits(); @@ -208,6 +215,7 @@ pub fn xor(lhs: &Bitmap, rhs: &Bitmap) -> Bitmap { } } +/// Compute bitwise equality (not XOR) operation. fn eq(lhs: &Bitmap, rhs: &Bitmap) -> bool { if lhs.len() != rhs.len() { return false; diff --git a/crates/polars-arrow/src/bitmap/utils/mod.rs b/crates/polars-arrow/src/bitmap/utils/mod.rs index 4ec5786f1c4f..29429ef20d58 100644 --- a/crates/polars-arrow/src/bitmap/utils/mod.rs +++ b/crates/polars-arrow/src/bitmap/utils/mod.rs @@ -140,3 +140,41 @@ pub fn count_zeros(mut slice: &[u8], mut offset: usize, len: usize) -> usize { len - num_ones } + +/// Takes the given slice of bytes plus a bit offset and bit length and returns +/// the slice so that it starts at a byte-aligned boundary. +/// +/// Returns (in order): +/// - the bits of the first byte if it isn't a full byte +/// - the number of bits in the first partial byte +/// - the rest of the bits as a byteslice +/// - the number of bits in the byteslice +#[inline] +pub fn align_bitslice_start_u8( + slice: &[u8], + offset: usize, + len: usize, +) -> (u8, usize, &[u8], usize) { + if len == 0 { + return (0, 0, &[], 0); + } + + // Protects the below get_uncheckeds. + assert!(slice.len() * 8 >= offset + len); + + let mut first_byte_idx = offset / 8; + let partial_offset = offset % 8; + let bits_in_partial_byte = (8 - partial_offset).min(len) % 8; + let mut partial_byte = unsafe { *slice.get_unchecked(first_byte_idx) }; + partial_byte >>= partial_offset; + partial_byte &= (1 << bits_in_partial_byte) - 1; + first_byte_idx += (partial_offset > 0) as usize; + + let rest_slice = unsafe { slice.get_unchecked(first_byte_idx..) }; + ( + partial_byte, + bits_in_partial_byte, + rest_slice, + len - bits_in_partial_byte, + ) +} diff --git a/crates/polars-arrow/src/compute/if_then_else.rs b/crates/polars-arrow/src/compute/if_then_else.rs deleted file mode 100644 index 834a1fefad3a..000000000000 --- a/crates/polars-arrow/src/compute/if_then_else.rs +++ /dev/null @@ -1,58 +0,0 @@ -//! Contains the operator [`if_then_else`]. -use polars_error::{polars_bail, PolarsResult}; - -use crate::array::{growable, Array, BooleanArray}; -use crate::bitmap::utils::SlicesIterator; - -/// Returns the values from `lhs` if the predicate is `true` or from the `rhs` if the predicate is false -/// Returns `None` if the predicate is `None`. -pub fn if_then_else( - predicate: &BooleanArray, - lhs: &dyn Array, - rhs: &dyn Array, -) -> PolarsResult> { - if lhs.data_type() != rhs.data_type() { - polars_bail!(InvalidOperation: - "If then else requires the arguments to have the same datatypes ({:?} != {:?})", - lhs.data_type(), - rhs.data_type() - ) - } - if (lhs.len() != rhs.len()) | (lhs.len() != predicate.len()) { - polars_bail!(ComputeError: - "If then else requires all arguments to have the same length (predicate = {}, lhs = {}, rhs = {})", - predicate.len(), - lhs.len(), - rhs.len() - ); - } - - let result = if predicate.null_count() > 0 { - let mut growable = growable::make_growable(&[lhs, rhs], true, lhs.len()); - for (i, v) in predicate.iter().enumerate() { - match v { - Some(v) => unsafe { growable.extend(!v as usize, i, 1) }, - None => growable.extend_validity(1), - } - } - growable.as_box() - } else { - let mut growable = growable::make_growable(&[lhs, rhs], false, lhs.len()); - let mut start_falsy = 0; - let mut total_len = 0; - for (start, len) in SlicesIterator::new(predicate.values()) { - if start != start_falsy { - unsafe { growable.extend(1, start_falsy, start - start_falsy) }; - total_len += start - start_falsy; - }; - unsafe { growable.extend(0, start, len) }; - total_len += len; - start_falsy = start + len; - } - if total_len != lhs.len() { - unsafe { growable.extend(1, total_len, lhs.len() - total_len) }; - } - growable.as_box() - }; - Ok(result) -} diff --git a/crates/polars-arrow/src/compute/mod.rs b/crates/polars-arrow/src/compute/mod.rs index 6dba6456d7f6..f08a3acb3215 100644 --- a/crates/polars-arrow/src/compute/mod.rs +++ b/crates/polars-arrow/src/compute/mod.rs @@ -29,9 +29,6 @@ pub mod boolean_kleene; #[cfg_attr(docsrs, doc(cfg(feature = "compute_cast")))] pub mod cast; pub mod concatenate; -#[cfg(feature = "compute_if_then_else")] -#[cfg_attr(docsrs, doc(cfg(feature = "compute_if_then_else")))] -pub mod if_then_else; #[cfg(feature = "compute_take")] #[cfg_attr(docsrs, doc(cfg(feature = "compute_take")))] pub mod take; diff --git a/crates/polars-arrow/src/compute/utils.rs b/crates/polars-arrow/src/compute/utils.rs index 744d12d2fe69..3e72d00f55d9 100644 --- a/crates/polars-arrow/src/compute/utils.rs +++ b/crates/polars-arrow/src/compute/utils.rs @@ -3,7 +3,7 @@ use std::ops::{BitAnd, BitOr}; use polars_error::{polars_ensure, PolarsResult}; use crate::array::Array; -use crate::bitmap::{ternary, Bitmap}; +use crate::bitmap::{and_not, ternary, Bitmap}; pub fn combine_validities_and3( opt1: Option<&Bitmap>, @@ -36,6 +36,17 @@ pub fn combine_validities_or(opt_l: Option<&Bitmap>, opt_r: Option<&Bitmap>) -> _ => None, } } +pub fn combine_validities_and_not( + opt_l: Option<&Bitmap>, + opt_r: Option<&Bitmap>, +) -> Option { + match (opt_l, opt_r) { + (Some(l), Some(r)) => Some(and_not(l, r)), + (None, Some(r)) => Some(!r), + (Some(l), None) => Some(l.clone()), + (None, None) => None, + } +} // Errors iff the two arrays have a different length. #[inline] diff --git a/crates/polars-arrow/src/datatypes/mod.rs b/crates/polars-arrow/src/datatypes/mod.rs index 95e64447293f..6325d7a5be4e 100644 --- a/crates/polars-arrow/src/datatypes/mod.rs +++ b/crates/polars-arrow/src/datatypes/mod.rs @@ -169,6 +169,8 @@ pub enum ArrowDataType { /// A string type that inlines small values /// and can intern strings. Utf8View, + /// A type unknown to Arrow. + Unknown, } #[cfg(feature = "arrow_rs")] @@ -233,6 +235,7 @@ impl From for arrow_schema::DataType { ArrowDataType::BinaryView | ArrowDataType::Utf8View => { panic!("view datatypes not supported by arrow-rs") }, + ArrowDataType::Unknown => unimplemented!(), } } } @@ -470,6 +473,7 @@ impl ArrowDataType { Map(_, _) => PhysicalType::Map, Dictionary(key, _, _) => PhysicalType::Dictionary(*key), Extension(_, key, _) => key.to_physical_type(), + Unknown => unimplemented!(), } } diff --git a/crates/polars-arrow/src/ffi/schema.rs b/crates/polars-arrow/src/ffi/schema.rs index 23cf9c8c4a47..f958311d7988 100644 --- a/crates/polars-arrow/src/ffi/schema.rs +++ b/crates/polars-arrow/src/ffi/schema.rs @@ -479,6 +479,7 @@ fn to_format(data_type: &ArrowDataType) -> String { ArrowDataType::Map(_, _) => "+m".to_string(), ArrowDataType::Dictionary(index, _, _) => to_format(&(*index).into()), ArrowDataType::Extension(_, inner, _) => to_format(inner.as_ref()), + ArrowDataType::Unknown => unimplemented!(), } } diff --git a/crates/polars-arrow/src/io/ipc/write/schema.rs b/crates/polars-arrow/src/io/ipc/write/schema.rs index 5aefef3e6684..8243e07a7d04 100644 --- a/crates/polars-arrow/src/io/ipc/write/schema.rs +++ b/crates/polars-arrow/src/io/ipc/write/schema.rs @@ -247,6 +247,7 @@ fn serialize_type(data_type: &ArrowDataType) -> arrow_format::ipc::Type { Extension(_, v, _) => serialize_type(v), Utf8View => ipc::Type::Utf8View(Box::new(ipc::Utf8View {})), BinaryView => ipc::Type::BinaryView(Box::new(ipc::BinaryView {})), + Unknown => unimplemented!(), } } @@ -295,6 +296,7 @@ fn serialize_children( .collect(), Dictionary(_, inner, _) => serialize_children(inner, ipc_field), Extension(_, inner, _) => serialize_children(inner, ipc_field), + Unknown => unimplemented!(), } } diff --git a/crates/polars-compute/src/if_then_else/array.rs b/crates/polars-compute/src/if_then_else/array.rs new file mode 100644 index 000000000000..a15349bf1e2c --- /dev/null +++ b/crates/polars-compute/src/if_then_else/array.rs @@ -0,0 +1,85 @@ +use arrow::array::growable::{Growable, GrowableFixedSizeList}; +use arrow::array::{Array, ArrayCollectIterExt, FixedSizeListArray}; +use arrow::bitmap::Bitmap; + +use super::{if_then_else_extend, IfThenElseKernel}; + +impl IfThenElseKernel for FixedSizeListArray { + type Scalar<'a> = Box; + + fn if_then_else(mask: &Bitmap, if_true: &Self, if_false: &Self) -> Self { + let mut growable = GrowableFixedSizeList::new(vec![if_true, if_false], false, mask.len()); + unsafe { + if_then_else_extend( + &mut growable, + mask, + |g, off, len| g.extend(0, off, len), + |g, off, len| g.extend(1, off, len), + ) + }; + growable.to() + } + + fn if_then_else_broadcast_true( + mask: &Bitmap, + if_true: Self::Scalar<'_>, + if_false: &Self, + ) -> Self { + let if_true_list: FixedSizeListArray = + std::iter::once(if_true).collect_arr_trusted_with_dtype(if_false.data_type().clone()); + let mut growable = + GrowableFixedSizeList::new(vec![&if_true_list, if_false], false, mask.len()); + unsafe { + if_then_else_extend( + &mut growable, + mask, + |g, _, len| g.extend_copies(0, 0, 1, len), + |g, off, len| g.extend(1, off, len), + ) + }; + growable.to() + } + + fn if_then_else_broadcast_false( + mask: &Bitmap, + if_true: &Self, + if_false: Self::Scalar<'_>, + ) -> Self { + let if_false_list: FixedSizeListArray = + std::iter::once(if_false).collect_arr_trusted_with_dtype(if_true.data_type().clone()); + let mut growable = + GrowableFixedSizeList::new(vec![if_true, &if_false_list], false, mask.len()); + unsafe { + if_then_else_extend( + &mut growable, + mask, + |g, off, len| g.extend(0, off, len), + |g, _, len| g.extend_copies(1, 0, 1, len), + ) + }; + growable.to() + } + + fn if_then_else_broadcast_both( + dtype: arrow::datatypes::ArrowDataType, + mask: &Bitmap, + if_true: Self::Scalar<'_>, + if_false: Self::Scalar<'_>, + ) -> Self { + let if_true_list: FixedSizeListArray = + std::iter::once(if_true).collect_arr_trusted_with_dtype(dtype.clone()); + let if_false_list: FixedSizeListArray = + std::iter::once(if_false).collect_arr_trusted_with_dtype(dtype.clone()); + let mut growable = + GrowableFixedSizeList::new(vec![&if_true_list, &if_false_list], false, mask.len()); + unsafe { + if_then_else_extend( + &mut growable, + mask, + |g, _, len| g.extend_copies(0, 0, 1, len), + |g, _, len| g.extend_copies(1, 0, 1, len), + ) + }; + growable.to() + } +} diff --git a/crates/polars-compute/src/if_then_else/boolean.rs b/crates/polars-compute/src/if_then_else/boolean.rs new file mode 100644 index 000000000000..469d6b8d4937 --- /dev/null +++ b/crates/polars-compute/src/if_then_else/boolean.rs @@ -0,0 +1,60 @@ +use arrow::array::BooleanArray; +use arrow::bitmap::{self, Bitmap}; +use arrow::datatypes::ArrowDataType; + +use super::{if_then_else_validity, IfThenElseKernel}; + +impl IfThenElseKernel for BooleanArray { + type Scalar<'a> = bool; + + fn if_then_else(mask: &Bitmap, if_true: &Self, if_false: &Self) -> Self { + let values = bitmap::ternary(mask, if_true.values(), if_false.values(), |m, t, f| { + (m & t) | (!m & f) + }); + let validity = if_then_else_validity(mask, if_true.validity(), if_false.validity()); + BooleanArray::from(values).with_validity(validity) + } + + fn if_then_else_broadcast_true( + mask: &Bitmap, + if_true: Self::Scalar<'_>, + if_false: &Self, + ) -> Self { + let values = if if_true { + bitmap::or(if_false.values(), mask) // (m & true) | (!m & f) -> f | m + } else { + bitmap::and_not(if_false.values(), mask) // (m & false) | (!m & f) -> f & !m + }; + let validity = if_then_else_validity(mask, None, if_false.validity()); + BooleanArray::from(values).with_validity(validity) + } + + fn if_then_else_broadcast_false( + mask: &Bitmap, + if_true: &Self, + if_false: Self::Scalar<'_>, + ) -> Self { + let values = if if_false { + bitmap::or_not(if_true.values(), mask) // (m & t) | (!m & true) -> t | !m + } else { + bitmap::and(if_true.values(), mask) // (m & t) | (!m & false) -> t & m + }; + let validity = if_then_else_validity(mask, if_true.validity(), None); + BooleanArray::from(values).with_validity(validity) + } + + fn if_then_else_broadcast_both( + _dtype: ArrowDataType, + mask: &Bitmap, + if_true: Self::Scalar<'_>, + if_false: Self::Scalar<'_>, + ) -> Self { + let values = match (if_true, if_false) { + (false, false) => Bitmap::new_with_value(false, mask.len()), + (false, true) => !mask, + (true, false) => mask.clone(), + (true, true) => Bitmap::new_with_value(true, mask.len()), + }; + BooleanArray::from(values) + } +} diff --git a/crates/polars-compute/src/if_then_else/list.rs b/crates/polars-compute/src/if_then_else/list.rs new file mode 100644 index 000000000000..aa3096c6f07e --- /dev/null +++ b/crates/polars-compute/src/if_then_else/list.rs @@ -0,0 +1,83 @@ +use arrow::array::growable::{Growable, GrowableList}; +use arrow::array::{Array, ArrayCollectIterExt, ListArray}; +use arrow::bitmap::Bitmap; + +use super::{if_then_else_extend, IfThenElseKernel}; + +impl IfThenElseKernel for ListArray { + type Scalar<'a> = Box; + + fn if_then_else(mask: &Bitmap, if_true: &Self, if_false: &Self) -> Self { + let mut growable = GrowableList::new(vec![if_true, if_false], false, mask.len()); + unsafe { + if_then_else_extend( + &mut growable, + mask, + |g, off, len| g.extend(0, off, len), + |g, off, len| g.extend(1, off, len), + ) + }; + growable.to() + } + + fn if_then_else_broadcast_true( + mask: &Bitmap, + if_true: Self::Scalar<'_>, + if_false: &Self, + ) -> Self { + let if_true_list: ListArray = + std::iter::once(if_true).collect_arr_trusted_with_dtype(if_false.data_type().clone()); + let mut growable = GrowableList::new(vec![&if_true_list, if_false], false, mask.len()); + unsafe { + if_then_else_extend( + &mut growable, + mask, + |g, _, len| g.extend_copies(0, 0, 1, len), + |g, off, len| g.extend(1, off, len), + ) + }; + growable.to() + } + + fn if_then_else_broadcast_false( + mask: &Bitmap, + if_true: &Self, + if_false: Self::Scalar<'_>, + ) -> Self { + let if_false_list: ListArray = + std::iter::once(if_false).collect_arr_trusted_with_dtype(if_true.data_type().clone()); + let mut growable = GrowableList::new(vec![if_true, &if_false_list], false, mask.len()); + unsafe { + if_then_else_extend( + &mut growable, + mask, + |g, off, len| g.extend(0, off, len), + |g, _, len| g.extend_copies(1, 0, 1, len), + ) + }; + growable.to() + } + + fn if_then_else_broadcast_both( + dtype: arrow::datatypes::ArrowDataType, + mask: &Bitmap, + if_true: Self::Scalar<'_>, + if_false: Self::Scalar<'_>, + ) -> Self { + let if_true_list: ListArray = + std::iter::once(if_true).collect_arr_trusted_with_dtype(dtype.clone()); + let if_false_list: ListArray = + std::iter::once(if_false).collect_arr_trusted_with_dtype(dtype.clone()); + let mut growable = + GrowableList::new(vec![&if_true_list, &if_false_list], false, mask.len()); + unsafe { + if_then_else_extend( + &mut growable, + mask, + |g, _, len| g.extend_copies(0, 0, 1, len), + |g, _, len| g.extend_copies(1, 0, 1, len), + ) + }; + growable.to() + } +} diff --git a/crates/polars-compute/src/if_then_else/mod.rs b/crates/polars-compute/src/if_then_else/mod.rs new file mode 100644 index 000000000000..3f704ca1d5f4 --- /dev/null +++ b/crates/polars-compute/src/if_then_else/mod.rs @@ -0,0 +1,326 @@ +use std::mem::MaybeUninit; + +use arrow::array::{Array, PrimitiveArray}; +use arrow::bitmap::utils::{align_bitslice_start_u8, SlicesIterator}; +use arrow::bitmap::{self, Bitmap}; +use arrow::datatypes::ArrowDataType; +use arrow::types::NativeType; +use polars_utils::slice::load_padded_le_u64; + +mod array; +mod boolean; +mod list; +mod scalar; +mod view; + +pub trait IfThenElseKernel: Sized + Array { + type Scalar<'a>; + + fn if_then_else(mask: &Bitmap, if_true: &Self, if_false: &Self) -> Self; + fn if_then_else_broadcast_true( + mask: &Bitmap, + if_true: Self::Scalar<'_>, + if_false: &Self, + ) -> Self; + fn if_then_else_broadcast_false( + mask: &Bitmap, + if_true: &Self, + if_false: Self::Scalar<'_>, + ) -> Self; + fn if_then_else_broadcast_both( + dtype: ArrowDataType, + mask: &Bitmap, + if_true: Self::Scalar<'_>, + if_false: Self::Scalar<'_>, + ) -> Self; +} + +impl IfThenElseKernel for PrimitiveArray { + type Scalar<'a> = T; + + fn if_then_else(mask: &Bitmap, if_true: &Self, if_false: &Self) -> Self { + let values = if_then_else_loop( + mask, + if_true.values(), + if_false.values(), + scalar::if_then_else_scalar_rest, + scalar::if_then_else_scalar_64, + ); + let validity = if_then_else_validity(mask, if_true.validity(), if_false.validity()); + PrimitiveArray::from_vec(values).with_validity(validity) + } + + fn if_then_else_broadcast_true( + mask: &Bitmap, + if_true: Self::Scalar<'_>, + if_false: &Self, + ) -> Self { + let values = if_then_else_loop_broadcast_false( + true, + mask, + if_false.values(), + if_true, + scalar::if_then_else_broadcast_false_scalar_64, + ); + let validity = if_then_else_validity(mask, None, if_false.validity()); + PrimitiveArray::from_vec(values).with_validity(validity) + } + + fn if_then_else_broadcast_false( + mask: &Bitmap, + if_true: &Self, + if_false: Self::Scalar<'_>, + ) -> Self { + let values = if_then_else_loop_broadcast_false( + false, + mask, + if_true.values(), + if_false, + scalar::if_then_else_broadcast_false_scalar_64, + ); + let validity = if_then_else_validity(mask, if_true.validity(), None); + PrimitiveArray::from_vec(values).with_validity(validity) + } + + fn if_then_else_broadcast_both( + _dtype: ArrowDataType, + mask: &Bitmap, + if_true: Self::Scalar<'_>, + if_false: Self::Scalar<'_>, + ) -> Self { + let values = if_then_else_loop_broadcast_both( + mask, + if_true, + if_false, + scalar::if_then_else_broadcast_both_scalar_64, + ); + PrimitiveArray::from_vec(values) + } +} + +fn if_then_else_validity( + mask: &Bitmap, + if_true: Option<&Bitmap>, + if_false: Option<&Bitmap>, +) -> Option { + match (if_true, if_false) { + (None, None) => None, + (None, Some(f)) => Some(mask | f), + (Some(t), None) => Some(bitmap::binary(mask, t, |m, t| !m | t)), + (Some(t), Some(f)) => Some(bitmap::ternary(mask, t, f, |m, t, f| (m & t) | (!m & f))), + } +} + +fn if_then_else_extend( + growable: &mut G, + mask: &Bitmap, + extend_true: ET, + extend_false: EF, +) { + let mut last_true_end = 0; + for (start, len) in SlicesIterator::new(mask) { + if start != last_true_end { + extend_false(growable, last_true_end, start - last_true_end); + }; + extend_true(growable, start, len); + last_true_end = start + len; + } + if last_true_end != mask.len() { + extend_false(growable, last_true_end, mask.len() - last_true_end) + } +} + +fn if_then_else_loop( + mask: &Bitmap, + if_true: &[T], + if_false: &[T], + process_var: F, + process_chunk: F64, +) -> Vec +where + T: Copy, + F: Fn(u64, &[T], &[T], &mut [MaybeUninit]), + F64: Fn(u64, &[T; 64], &[T; 64], &mut [MaybeUninit; 64]), +{ + assert_eq!(mask.len(), if_true.len()); + assert_eq!(mask.len(), if_false.len()); + let (mask_slice, offset, len) = mask.as_slice(); + + let mut ret = Vec::with_capacity(mask.len()); + let out = &mut ret.spare_capacity_mut()[..mask.len()]; + + // Handle offset. + let (start_byte, num_start_bits, bulk_mask, bulk_len) = + align_bitslice_start_u8(mask_slice, offset, len); + let (start_true, rest_true) = if_true.split_at(num_start_bits); + let (start_false, rest_false) = if_false.split_at(num_start_bits); + let (start_out, rest_out) = out.split_at_mut(num_start_bits); + process_var(start_byte as u64, start_true, start_false, start_out); + + // Handle bulk. + let mut true_chunks = rest_true.chunks_exact(64); + let mut false_chunks = rest_false.chunks_exact(64); + let mut out_chunks = rest_out.chunks_exact_mut(64); + let combined = true_chunks + .by_ref() + .zip(false_chunks.by_ref()) + .zip(out_chunks.by_ref()); + for (i, ((tc, fc), oc)) in combined.enumerate() { + let m = unsafe { + u64::from_le_bytes( + bulk_mask + .get_unchecked(8 * i..8 * i + 8) + .try_into() + .unwrap(), + ) + }; + process_chunk( + m, + tc.try_into().unwrap(), + fc.try_into().unwrap(), + oc.try_into().unwrap(), + ); + } + + // Handle remainder. + if !true_chunks.remainder().is_empty() { + let rest_mask_byte_offset = bulk_len / 64 * 8; + let rest_mask = load_padded_le_u64(&bulk_mask[rest_mask_byte_offset..]); + process_var( + rest_mask, + true_chunks.remainder(), + false_chunks.remainder(), + out_chunks.into_remainder(), + ); + } + + unsafe { + ret.set_len(mask.len()); + } + ret +} + +fn if_then_else_loop_broadcast_false( + invert_mask: bool, // Allows code reuse for both false and true broadcasts. + mask: &Bitmap, + if_true: &[T], + if_false: T, + process_chunk: F64, +) -> Vec +where + T: Copy, + F64: Fn(u64, &[T; 64], T, &mut [MaybeUninit; 64]), +{ + assert_eq!(mask.len(), if_true.len()); + let (mask_slice, offset, len) = mask.as_slice(); + + let mut ret = Vec::with_capacity(mask.len()); + let out = &mut ret.spare_capacity_mut()[..mask.len()]; + + // XOR with all 1's inverts the mask. + let xor_inverter = if invert_mask { u64::MAX } else { 0 }; + + // Handle offset. + let (start_byte, num_start_bits, bulk_mask, bulk_len) = + align_bitslice_start_u8(mask_slice, offset, len); + let (start_true, rest_true) = if_true.split_at(num_start_bits); + let (start_out, rest_out) = out.split_at_mut(num_start_bits); + scalar::if_then_else_broadcast_false_scalar_rest( + start_byte as u64 ^ xor_inverter, + start_true, + if_false, + start_out, + ); + + // Handle bulk. + let mut true_chunks = rest_true.chunks_exact(64); + let mut out_chunks = rest_out.chunks_exact_mut(64); + let combined = true_chunks.by_ref().zip(out_chunks.by_ref()); + for (i, (tc, oc)) in combined.enumerate() { + let m = unsafe { + u64::from_le_bytes( + bulk_mask + .get_unchecked(8 * i..8 * i + 8) + .try_into() + .unwrap(), + ) + }; + process_chunk( + m ^ xor_inverter, + tc.try_into().unwrap(), + if_false, + oc.try_into().unwrap(), + ); + } + + // Handle remainder. + if !true_chunks.remainder().is_empty() { + let rest_mask_byte_offset = bulk_len / 64 * 8; + let rest_mask = load_padded_le_u64(&bulk_mask[rest_mask_byte_offset..]); + scalar::if_then_else_broadcast_false_scalar_rest( + rest_mask ^ xor_inverter, + true_chunks.remainder(), + if_false, + out_chunks.into_remainder(), + ); + } + + unsafe { + ret.set_len(mask.len()); + } + ret +} + +fn if_then_else_loop_broadcast_both( + mask: &Bitmap, + if_true: T, + if_false: T, + generate_chunk: F64, +) -> Vec +where + T: Copy, + F64: Fn(u64, T, T, &mut [MaybeUninit; 64]), +{ + let (mask_slice, offset, len) = mask.as_slice(); + + let mut ret = Vec::with_capacity(mask.len()); + let out = &mut ret.spare_capacity_mut()[..mask.len()]; + + // Handle offset. + let (start_byte, num_start_bits, bulk_mask, bulk_len) = + align_bitslice_start_u8(mask_slice, offset, len); + let (start_out, rest_out) = out.split_at_mut(num_start_bits); + scalar::if_then_else_broadcast_both_scalar_rest( + start_byte as u64, + if_true, + if_false, + start_out, + ); + + // Handle bulk. + let mut out_chunks = rest_out.chunks_exact_mut(64); + for (i, oc) in out_chunks.by_ref().enumerate() { + let m = unsafe { + u64::from_le_bytes( + bulk_mask + .get_unchecked(8 * i..8 * i + 8) + .try_into() + .unwrap(), + ) + }; + generate_chunk(m, if_true, if_false, oc.try_into().unwrap()); + } + + // Handle remainder. + let out_chunk = out_chunks.into_remainder(); + if !out_chunk.is_empty() { + let rest_mask_byte_offset = bulk_len / 64 * 8; + let rest_mask = load_padded_le_u64(&bulk_mask[rest_mask_byte_offset..]); + scalar::if_then_else_broadcast_both_scalar_rest(rest_mask, if_true, if_false, out_chunk); + } + + unsafe { + ret.set_len(mask.len()); + } + ret +} diff --git a/crates/polars-compute/src/if_then_else/scalar.rs b/crates/polars-compute/src/if_then_else/scalar.rs new file mode 100644 index 000000000000..2e1d6b396ddb --- /dev/null +++ b/crates/polars-compute/src/if_then_else/scalar.rs @@ -0,0 +1,76 @@ +use std::mem::MaybeUninit; + +pub fn if_then_else_scalar_rest( + mask: u64, + if_true: &[T], + if_false: &[T], + out: &mut [MaybeUninit], +) { + assert!(if_true.len() <= out.len()); // Removes bounds checks in inner loop. + let true_it = if_true.iter().copied(); + let false_it = if_false.iter().copied(); + for (i, (t, f)) in true_it.zip(false_it).enumerate() { + let src = if (mask >> i) & 1 != 0 { t } else { f }; + out[i] = MaybeUninit::new(src); + } +} + +pub fn if_then_else_broadcast_false_scalar_rest( + mask: u64, + if_true: &[T], + if_false: T, + out: &mut [MaybeUninit], +) { + assert!(if_true.len() <= out.len()); // Removes bounds checks in inner loop. + let true_it = if_true.iter().copied(); + for (i, t) in true_it.enumerate() { + let src = if (mask >> i) & 1 != 0 { t } else { if_false }; + out[i] = MaybeUninit::new(src); + } +} + +pub fn if_then_else_broadcast_both_scalar_rest( + mask: u64, + if_true: T, + if_false: T, + out: &mut [MaybeUninit], +) { + for (i, dst) in out.iter_mut().enumerate() { + let src = if (mask >> i) & 1 != 0 { + if_true + } else { + if_false + }; + *dst = MaybeUninit::new(src); + } +} + +pub fn if_then_else_scalar_64( + mask: u64, + if_true: &[T; 64], + if_false: &[T; 64], + out: &mut [MaybeUninit; 64], +) { + // This generated the best autovectorized code on ARM, and branchless everywhere. + if_then_else_scalar_rest(mask, if_true, if_false, out) +} + +pub fn if_then_else_broadcast_false_scalar_64( + mask: u64, + if_true: &[T; 64], + if_false: T, + out: &mut [MaybeUninit; 64], +) { + // This generated the best autovectorized code on ARM, and branchless everywhere. + if_then_else_broadcast_false_scalar_rest(mask, if_true, if_false, out) +} + +pub fn if_then_else_broadcast_both_scalar_64( + mask: u64, + if_true: T, + if_false: T, + out: &mut [MaybeUninit; 64], +) { + // This generated the best autovectorized code on ARM, and branchless everywhere. + if_then_else_broadcast_both_scalar_rest(mask, if_true, if_false, out) +} diff --git a/crates/polars-compute/src/if_then_else/view.rs b/crates/polars-compute/src/if_then_else/view.rs new file mode 100644 index 000000000000..5b1100153b03 --- /dev/null +++ b/crates/polars-compute/src/if_then_else/view.rs @@ -0,0 +1,244 @@ +use std::mem::MaybeUninit; +use std::sync::Arc; + +use arrow::array::{Array, BinaryViewArray, Utf8ViewArray, View}; +use arrow::bitmap::Bitmap; +use arrow::buffer::Buffer; +use arrow::datatypes::ArrowDataType; + +use super::IfThenElseKernel; +use crate::if_then_else::scalar::{ + if_then_else_broadcast_both_scalar_64, if_then_else_broadcast_false_scalar_64, +}; + +// Makes a buffer and a set of views into that buffer from a set of strings. +// Does not allocate a buffer if not necessary. +fn make_buffer_and_views( + strings: [&[u8]; N], + buffer_idx: u32, +) -> ([View; N], Option>) { + let mut buf_data = Vec::new(); + let views = strings.map(|s| { + let offset = buf_data.len().try_into().unwrap(); + if s.len() > 12 { + buf_data.extend(s); + } + View::new_from_bytes(s, buffer_idx, offset) + }); + let buf = (!buf_data.is_empty()).then(|| buf_data.into()); + (views, buf) +} + +impl IfThenElseKernel for BinaryViewArray { + type Scalar<'a> = &'a [u8]; + + fn if_then_else(mask: &Bitmap, if_true: &Self, if_false: &Self) -> Self { + let combined_buffers: Arc<_>; + let combined_buffer_len: usize; + let false_buffer_idx_offset: u32; + if Arc::ptr_eq(if_true.data_buffers(), if_false.data_buffers()) { + // Share exact same buffers, no need to combine. + combined_buffers = if_true.data_buffers().clone(); + combined_buffer_len = if_true.total_buffer_len(); + false_buffer_idx_offset = 0; + } else { + // Put false buffers after true buffers. + let true_buffers = if_true.data_buffers().iter().cloned(); + let false_buffers = if_false.data_buffers().iter().cloned(); + combined_buffers = true_buffers.chain(false_buffers).collect(); + combined_buffer_len = if_true.total_buffer_len() + if_false.total_buffer_len(); + false_buffer_idx_offset = if_true.data_buffers().len() as u32; + } + + let views = super::if_then_else_loop( + mask, + if_true.views(), + if_false.views(), + |m, t, f, o| if_then_else_view_rest(m, t, f, o, false_buffer_idx_offset), + |m, t, f, o| if_then_else_view_64(m, t, f, o, false_buffer_idx_offset), + ); + + let validity = super::if_then_else_validity(mask, if_true.validity(), if_false.validity()); + unsafe { + BinaryViewArray::new_unchecked_unknown_md( + if_true.data_type().clone(), + views.into(), + combined_buffers, + validity, + Some(combined_buffer_len), + ) + } + } + + fn if_then_else_broadcast_true( + mask: &Bitmap, + if_true: Self::Scalar<'_>, + if_false: &Self, + ) -> Self { + // It's cheaper if we put the false buffers first, that way we don't need to modify any views in the loop. + let false_buffers = if_false.data_buffers().iter().cloned(); + let true_buffer_idx_offset: u32 = if_false.data_buffers().len() as u32; + let ([true_view], true_buffer) = make_buffer_and_views([if_true], true_buffer_idx_offset); + let combined_buffers: Arc<_> = false_buffers.chain(true_buffer).collect(); + let combined_buffer_len = if_false.total_buffer_len() + if_true.len(); + + let views = super::if_then_else_loop_broadcast_false( + true, // Invert the mask so we effectively broadcast true. + mask, + if_false.views(), + true_view, + if_then_else_broadcast_false_scalar_64, + ); + + let validity = super::if_then_else_validity(mask, None, if_false.validity()); + unsafe { + BinaryViewArray::new_unchecked_unknown_md( + if_false.data_type().clone(), + views.into(), + combined_buffers, + validity, + Some(combined_buffer_len), + ) + } + } + + fn if_then_else_broadcast_false( + mask: &Bitmap, + if_true: &Self, + if_false: Self::Scalar<'_>, + ) -> Self { + // It's cheaper if we put the true buffers first, that way we don't need to modify any views in the loop. + let true_buffers = if_true.data_buffers().iter().cloned(); + let false_buffer_idx_offset: u32 = if_true.data_buffers().len() as u32; + let ([false_view], false_buffer) = + make_buffer_and_views([if_false], false_buffer_idx_offset); + let combined_buffers: Arc<_> = true_buffers.chain(false_buffer).collect(); + let combined_buffer_len = if_true.total_buffer_len() + if_false.len(); + + let views = super::if_then_else_loop_broadcast_false( + false, + mask, + if_true.views(), + false_view, + if_then_else_broadcast_false_scalar_64, + ); + + let validity = super::if_then_else_validity(mask, if_true.validity(), None); + unsafe { + BinaryViewArray::new_unchecked_unknown_md( + if_true.data_type().clone(), + views.into(), + combined_buffers, + validity, + Some(combined_buffer_len), + ) + } + } + + fn if_then_else_broadcast_both( + dtype: ArrowDataType, + mask: &Bitmap, + if_true: Self::Scalar<'_>, + if_false: Self::Scalar<'_>, + ) -> Self { + let total_len = if_true.len() + if_false.len(); + let ([true_view, false_view], buffer) = make_buffer_and_views([if_true, if_false], 0); + let buffers: Arc<_> = buffer.into_iter().collect(); + let views = super::if_then_else_loop_broadcast_both( + mask, + true_view, + false_view, + if_then_else_broadcast_both_scalar_64, + ); + unsafe { + BinaryViewArray::new_unchecked(dtype, views.into(), buffers, None, total_len, total_len) + } + } +} + +impl IfThenElseKernel for Utf8ViewArray { + type Scalar<'a> = &'a str; + + fn if_then_else(mask: &Bitmap, if_true: &Self, if_false: &Self) -> Self { + let ret = + IfThenElseKernel::if_then_else(mask, &if_true.to_binview(), &if_false.to_binview()); + unsafe { ret.to_utf8view_unchecked() } + } + + fn if_then_else_broadcast_true( + mask: &Bitmap, + if_true: Self::Scalar<'_>, + if_false: &Self, + ) -> Self { + let ret = IfThenElseKernel::if_then_else_broadcast_true( + mask, + if_true.as_bytes(), + &if_false.to_binview(), + ); + unsafe { ret.to_utf8view_unchecked() } + } + + fn if_then_else_broadcast_false( + mask: &Bitmap, + if_true: &Self, + if_false: Self::Scalar<'_>, + ) -> Self { + let ret = IfThenElseKernel::if_then_else_broadcast_false( + mask, + &if_true.to_binview(), + if_false.as_bytes(), + ); + unsafe { ret.to_utf8view_unchecked() } + } + + fn if_then_else_broadcast_both( + dtype: ArrowDataType, + mask: &Bitmap, + if_true: Self::Scalar<'_>, + if_false: Self::Scalar<'_>, + ) -> Self { + let ret: BinaryViewArray = IfThenElseKernel::if_then_else_broadcast_both( + dtype, + mask, + if_true.as_bytes(), + if_false.as_bytes(), + ); + unsafe { ret.to_utf8view_unchecked() } + } +} + +pub fn if_then_else_view_rest( + mask: u64, + if_true: &[View], + if_false: &[View], + out: &mut [MaybeUninit], + false_buffer_idx_offset: u32, +) { + assert!(if_true.len() <= out.len()); // Removes bounds checks in inner loop. + let true_it = if_true.iter().copied(); + let false_it = if_false.iter().copied(); + for (i, (t, f)) in true_it.zip(false_it).enumerate() { + // Written like this, this loop *should* be branchless. + // Unfortunately we're still dependent on the compiler. + let m = (mask >> i) & 1 != 0; + let mut v = if m { t } else { f }; + let offset = if m | (v.length <= 12) { + // Yes, | instead of || is intentional. + 0 + } else { + false_buffer_idx_offset + }; + v.buffer_idx += offset; + out[i] = MaybeUninit::new(v); + } +} + +pub fn if_then_else_view_64( + mask: u64, + if_true: &[View; 64], + if_false: &[View; 64], + out: &mut [MaybeUninit; 64], + false_buffer_idx_offset: u32, +) { + if_then_else_view_rest(mask, if_true, if_false, out, false_buffer_idx_offset) +} diff --git a/crates/polars-compute/src/lib.rs b/crates/polars-compute/src/lib.rs index 797e8a8af9b0..cc477a817739 100644 --- a/crates/polars-compute/src/lib.rs +++ b/crates/polars-compute/src/lib.rs @@ -8,6 +8,7 @@ pub mod arithmetic; pub mod comparisons; pub mod filter; +pub mod if_then_else; pub mod min_max; pub mod arity; diff --git a/crates/polars-core/src/chunked_array/collect.rs b/crates/polars-core/src/chunked_array/collect.rs index 2d0226029236..6131d741eac6 100644 --- a/crates/polars-core/src/chunked_array/collect.rs +++ b/crates/polars-core/src/chunked_array/collect.rs @@ -12,7 +12,6 @@ use std::sync::Arc; -use arrow::datatypes::ArrowDataType; use arrow::trusted_len::TrustedLen; use crate::chunked_array::ChunkedArray; @@ -20,20 +19,6 @@ use crate::datatypes::{ ArrayCollectIterExt, ArrayFromIter, ArrayFromIterDtype, DataType, Field, PolarsDataType, }; -pub(crate) fn prepare_collect_dtype(dtype: &DataType) -> ArrowDataType { - match dtype { - #[cfg(feature = "object")] - DataType::Object(_, reg) => match reg { - Some(reg) => reg.physical_dtype.clone(), - None => { - use crate::chunked_array::object::registry; - registry::get_object_physical_type() - }, - }, - dt => dt.to_arrow(true), - } -} - pub trait ChunkedCollectIterExt: Iterator + Sized { #[inline] fn collect_ca_with_dtype(self, name: &str, dtype: DataType) -> ChunkedArray @@ -41,8 +26,7 @@ pub trait ChunkedCollectIterExt: Iterator + Sized { T::Array: ArrayFromIterDtype, { let field = Arc::new(Field::new(name, dtype.clone())); - let arrow_dtype = prepare_collect_dtype(&field.dtype); - let arr = self.collect_arr_with_dtype(arrow_dtype); + let arr = self.collect_arr_with_dtype(field.dtype.to_arrow(true)); ChunkedArray::from_chunk_iter_and_field(field, [arr]) } @@ -52,8 +36,7 @@ pub trait ChunkedCollectIterExt: Iterator + Sized { T::Array: ArrayFromIterDtype, { let field = Arc::clone(&name_dtype_src.field); - let arrow_dtype = prepare_collect_dtype(&field.dtype); - let arr = self.collect_arr_with_dtype(arrow_dtype); + let arr = self.collect_arr_with_dtype(field.dtype.to_arrow(true)); ChunkedArray::from_chunk_iter_and_field(field, [arr]) } @@ -64,8 +47,7 @@ pub trait ChunkedCollectIterExt: Iterator + Sized { Self: TrustedLen, { let field = Arc::new(Field::new(name, dtype.clone())); - let arrow_dtype = prepare_collect_dtype(&field.dtype); - let arr = self.collect_arr_trusted_with_dtype(arrow_dtype); + let arr = self.collect_arr_trusted_with_dtype(field.dtype.to_arrow(true)); ChunkedArray::from_chunk_iter_and_field(field, [arr]) } @@ -76,8 +58,7 @@ pub trait ChunkedCollectIterExt: Iterator + Sized { Self: TrustedLen, { let field = Arc::clone(&name_dtype_src.field); - let arrow_dtype = prepare_collect_dtype(&field.dtype); - let arr = self.collect_arr_trusted_with_dtype(arrow_dtype); + let arr = self.collect_arr_trusted_with_dtype(field.dtype.to_arrow(true)); ChunkedArray::from_chunk_iter_and_field(field, [arr]) } @@ -92,8 +73,7 @@ pub trait ChunkedCollectIterExt: Iterator + Sized { Self: Iterator>, { let field = Arc::new(Field::new(name, dtype.clone())); - let arrow_dtype = prepare_collect_dtype(&field.dtype); - let arr = self.try_collect_arr_with_dtype(arrow_dtype)?; + let arr = self.try_collect_arr_with_dtype(field.dtype.to_arrow(true))?; Ok(ChunkedArray::from_chunk_iter_and_field(field, [arr])) } @@ -107,8 +87,7 @@ pub trait ChunkedCollectIterExt: Iterator + Sized { Self: Iterator>, { let field = Arc::clone(&name_dtype_src.field); - let arrow_dtype = prepare_collect_dtype(&field.dtype); - let arr = self.try_collect_arr_with_dtype(arrow_dtype)?; + let arr = self.try_collect_arr_with_dtype(field.dtype.to_arrow(true))?; Ok(ChunkedArray::from_chunk_iter_and_field(field, [arr])) } @@ -123,8 +102,7 @@ pub trait ChunkedCollectIterExt: Iterator + Sized { Self: Iterator> + TrustedLen, { let field = Arc::new(Field::new(name, dtype.clone())); - let arrow_dtype = prepare_collect_dtype(&field.dtype); - let arr = self.try_collect_arr_trusted_with_dtype(arrow_dtype)?; + let arr = self.try_collect_arr_trusted_with_dtype(field.dtype.to_arrow(true))?; Ok(ChunkedArray::from_chunk_iter_and_field(field, [arr])) } @@ -138,8 +116,7 @@ pub trait ChunkedCollectIterExt: Iterator + Sized { Self: Iterator> + TrustedLen, { let field = Arc::clone(&name_dtype_src.field); - let arrow_dtype = prepare_collect_dtype(&field.dtype); - let arr = self.try_collect_arr_trusted_with_dtype(arrow_dtype)?; + let arr = self.try_collect_arr_trusted_with_dtype(field.dtype.to_arrow(true))?; Ok(ChunkedArray::from_chunk_iter_and_field(field, [arr])) } } diff --git a/crates/polars-core/src/chunked_array/from.rs b/crates/polars-core/src/chunked_array/from.rs index 9160aa566d10..b65b4b71d9a8 100644 --- a/crates/polars-core/src/chunked_array/from.rs +++ b/crates/polars-core/src/chunked_array/from.rs @@ -281,6 +281,11 @@ where out.compute_len(); out } + + pub fn full_null_like(ca: &Self, length: usize) -> Self { + let chunks = std::iter::once(T::Array::full_null(length, ca.dtype().to_arrow(true))); + Self::from_chunk_iter_like(ca, chunks) + } } impl ChunkedArray diff --git a/crates/polars-core/src/chunked_array/object/builder.rs b/crates/polars-core/src/chunked_array/object/builder.rs index 1bf6727342a9..f16ae12fbfe3 100644 --- a/crates/polars-core/src/chunked_array/object/builder.rs +++ b/crates/polars-core/src/chunked_array/object/builder.rs @@ -1,7 +1,5 @@ use std::marker::PhantomData; -use arrow::bitmap::MutableBitmap; - use super::*; use crate::chunked_array::object::registry::{AnonymousObjectBuilder, ObjectRegistry}; use crate::utils::get_iter_capacity; diff --git a/crates/polars-core/src/chunked_array/object/mod.rs b/crates/polars-core/src/chunked_array/object/mod.rs index 9f17a1d1b434..c834a61fa990 100644 --- a/crates/polars-core/src/chunked_array/object/mod.rs +++ b/crates/polars-core/src/chunked_array/object/mod.rs @@ -3,7 +3,7 @@ use std::fmt::{Debug, Display}; use std::hash::Hash; use arrow::bitmap::utils::{BitmapIter, ZipValidity}; -use arrow::bitmap::Bitmap; +use arrow::bitmap::{Bitmap, MutableBitmap}; use polars_utils::total_ord::TotalHash; use crate::prelude::*; @@ -153,7 +153,7 @@ where } fn data_type(&self) -> &ArrowDataType { - unimplemented!() + &ArrowDataType::FixedSizeBinary(std::mem::size_of::()) } fn slice(&mut self, offset: usize, length: usize) { @@ -199,6 +199,44 @@ where } } +impl StaticArray for ObjectArray { + type ValueT<'a> = &'a T; + type ZeroableValueT<'a> = Option<&'a T>; + type ValueIterT<'a> = ObjectValueIter<'a, T>; + + #[inline] + unsafe fn value_unchecked(&self, idx: usize) -> Self::ValueT<'_> { + self.value_unchecked(idx) + } + + fn values_iter(&self) -> Self::ValueIterT<'_> { + self.values_iter() + } + + fn iter(&self) -> ZipValidity, Self::ValueIterT<'_>, BitmapIter> { + self.iter() + } + + fn with_validity_typed(self, validity: Option) -> Self { + self.with_validity(validity) + } + + fn full_null(length: usize, _dtype: ArrowDataType) -> Self { + ObjectArray { + values: Arc::new(vec![T::default(); length]), + null_bitmap: Some(Bitmap::new_with_value(false, length)), + offset: 0, + len: length, + } + } +} + +impl ParameterFreeDtypeStaticArray for ObjectArray { + fn get_dtype() -> ArrowDataType { + ArrowDataType::FixedSizeBinary(std::mem::size_of::()) + } +} + impl ObjectChunked where T: PolarsObject, diff --git a/crates/polars-core/src/chunked_array/ops/full.rs b/crates/polars-core/src/chunked_array/ops/full.rs index 16ba9d5f0ba8..818cf19f1720 100644 --- a/crates/polars-core/src/chunked_array/ops/full.rs +++ b/crates/polars-core/src/chunked_array/ops/full.rs @@ -134,17 +134,19 @@ impl ArrayChunked { #[cfg(feature = "dtype-array")] impl ChunkFull<&Series> for ArrayChunked { fn full(name: &str, value: &Series, length: usize) -> ArrayChunked { - if !value.dtype().is_numeric() { - todo!("Array only supports numeric data types"); - }; let width = value.len(); - let values = value.tile(length); - let values = values.chunks()[0].clone(); - let data_type = ArrowDataType::FixedSizeList( - Box::new(ArrowField::new("item", values.data_type().clone(), true)), + let dtype = value.dtype(); + let arrow_dtype = ArrowDataType::FixedSizeList( + Box::new(ArrowField::new("item", dtype.to_arrow(true), true)), width, ); - let arr = FixedSizeListArray::new(data_type, values, None); + let arr = if value.dtype().is_numeric() { + let values = value.tile(length); + FixedSizeListArray::new(arrow_dtype, values.chunks()[0].clone(), None) + } else { + let value = value.rechunk().chunks()[0].clone(); + FixedSizeListArray::full(length, value, arrow_dtype) + }; ChunkedArray::with_chunk(name, arr) } } diff --git a/crates/polars-core/src/chunked_array/ops/gather.rs b/crates/polars-core/src/chunked_array/ops/gather.rs index fbc2d827db7f..6bbbc86c11eb 100644 --- a/crates/polars-core/src/chunked_array/ops/gather.rs +++ b/crates/polars-core/src/chunked_array/ops/gather.rs @@ -4,7 +4,6 @@ use arrow::compute::take::take_unchecked; use polars_error::polars_ensure; use polars_utils::index::check_bounds; -use crate::chunked_array::collect::prepare_collect_dtype; use crate::prelude::*; use crate::series::IsSorted; @@ -153,7 +152,7 @@ impl + ?Sized> ChunkTakeUnchecked for } let targets: Vec<_> = ca.downcast_iter().collect(); let arr = gather_idx_array_unchecked( - prepare_collect_dtype(ca.dtype()), + ca.dtype().to_arrow(true), &targets, ca.null_count() > 0, indices.as_ref(), @@ -210,7 +209,7 @@ impl ChunkTakeUnchecked for ChunkedAr let targets: Vec<_> = ca.downcast_iter().collect(); let chunks = indices.downcast_iter().map(|idx_arr| { - let dtype = prepare_collect_dtype(ca.dtype()); + let dtype = ca.dtype().to_arrow(true); if idx_arr.null_count() == 0 { gather_idx_array_unchecked(dtype, &targets, targets_have_nulls, idx_arr.values()) } else if targets.len() == 1 { diff --git a/crates/polars-core/src/chunked_array/ops/mod.rs b/crates/polars-core/src/chunked_array/ops/mod.rs index acd5df24bfb8..de8f78f00e7d 100644 --- a/crates/polars-core/src/chunked_array/ops/mod.rs +++ b/crates/polars-core/src/chunked_array/ops/mod.rs @@ -562,7 +562,12 @@ impl ChunkExpandAtIndex for ArrayChunked { unsafe { ca.to_logical(self.inner_dtype()) }; ca }, - None => ArrayChunked::full_null_with_dtype(self.name(), length, &self.inner_dtype(), 0), + None => ArrayChunked::full_null_with_dtype( + self.name(), + length, + &self.inner_dtype(), + self.width(), + ), } } } diff --git a/crates/polars-core/src/chunked_array/ops/zip.rs b/crates/polars-core/src/chunked_array/ops/zip.rs index 80b3bcdfd815..8319c81d9c3c 100644 --- a/crates/polars-core/src/chunked_array/ops/zip.rs +++ b/crates/polars-core/src/chunked_array/ops/zip.rs @@ -1,272 +1,208 @@ -use arrow::compute::if_then_else::if_then_else; +use arrow::bitmap::Bitmap; +use arrow::compute::utils::{combine_validities_and, combine_validities_and_not}; +use polars_compute::if_then_else::IfThenElseKernel; +#[cfg(feature = "object")] +use crate::chunked_array::object::ObjectArray; use crate::prelude::*; -use crate::utils::align_chunks_ternary; - -fn ternary_apply(predicate: bool, truthy: T, falsy: T) -> T { - if predicate { - truthy - } else { - falsy - } -} +use crate::utils::{align_chunks_binary, align_chunks_ternary}; -fn prepare_mask(mask: &BooleanArray) -> BooleanArray { - // make sure that zip works same as main branch - // that is that null are ignored from mask and that we take from the right array +const SHAPE_MISMATCH_STR: &str = + "shapes of `self`, `mask` and `other` are not suitable for `zip_with` operation"; - match mask.validity() { - // nulls are set to true meaning we take from the right in the zip/ if_then_else kernel - Some(validity) if validity.unset_bits() != 0 => { - let mask = mask.values() & validity; - BooleanArray::from_data_default(mask, None) - }, - _ => mask.clone(), - } +fn if_then_else_broadcast_mask( + mask: bool, + if_true: &ChunkedArray, + if_false: &ChunkedArray, +) -> PolarsResult> +where + ChunkedArray: ChunkExpandAtIndex, +{ + let src = if mask { if_true } else { if_false }; + let other = if mask { if_false } else { if_true }; + let ret = match (src.len(), other.len()) { + (a, b) if a == b => src.clone(), + (_, 1) => src.clone(), + (1, other_len) => src.new_from_index(0, other_len), + _ => polars_bail!(ShapeMismatch: SHAPE_MISMATCH_STR), + }; + Ok(ret.with_name(if_true.name())) } -macro_rules! impl_ternary_broadcast { - ($self:ident, $self_len:expr, $other_len:expr, $mask_len: expr, $other:expr, $mask:expr, $ty:ty) => {{ - match ($self_len, $other_len, $mask_len) { - (1, 1, _) => { - let left = $self.get(0); - let right = $other.get(0); - let mut val: ChunkedArray<$ty> = $mask.apply_generic(|mask| ternary_apply(mask.unwrap_or(false), left, right)); - val.rename($self.name()); - Ok(val) - } - (_, 1, 1) => { - let mut val = if let Some(true) = $mask.get(0) { - $self.clone() - } else { - $other.new_from_index(0, $self_len) - }; - - val.rename($self.name()); - Ok(val) - } - (1, _, 1) => { - let mut val = if let Some(true) = $mask.get(0) { - $self.new_from_index(0, $other_len) - } else { - $other.clone() - }; - - val.rename($self.name()); - Ok(val) - }, - (1, r_len, mask_len) if r_len == mask_len =>{ - let left = $self.get(0); - - let mut val: ChunkedArray<$ty> = $mask - .into_iter() - .zip($other) - .map(|(mask, right)| ternary_apply(mask.unwrap_or(false), left, right)) - .collect_trusted(); - val.rename($self.name()); - Ok(val) - }, - (l_len, 1, mask_len) if l_len == mask_len => { - let mask = $mask.apply_kernel(&|arr| prepare_mask(arr).to_boxed()); - let right = $other.get(0); - - let mut val: ChunkedArray<$ty> = mask - .into_iter() - .zip($self) - .map(|(mask, left)| ternary_apply(mask.unwrap_or(false), left, right)) - .collect_trusted(); - val.rename($self.name()); - Ok(val) - }, - (l_len, r_len, 1) if l_len == r_len => { - let mut val = if let Some(true) = $mask.get(0) { - $self.clone() - } else { - $other.clone() - }; - - val.rename($self.name()); - Ok(val) - }, - (_, _, 0) => { - Ok($self.clear()) - } - (_, _, _) => Err(polars_err!( - ShapeMismatch: "shapes of `self`, `mask` and `other` are not suitable for `zip_with` operation" - )), +fn bool_null_to_false(mask: &BooleanArray) -> Bitmap { + if mask.null_count() == 0 { + mask.values().clone() + } else { + mask.values() & mask.validity().unwrap() } - }}; -} - -macro_rules! expand_unit { - ($ca:ident, $len:ident) => {{ - if $ca.len() == 1 { - $ca.new_from_index(0, $len) - } else { - $ca.clone() - } - }}; -} - -macro_rules! expand_lengths { - ($truthy:ident, $falsy:ident, $mask:ident) => { - if $mask.is_empty() { - ($truthy.clear(), $falsy.clear(), $mask.clone()) - } else { - let len = std::cmp::max(std::cmp::max($truthy.len(), $falsy.len()), $mask.len()); - - let $truthy = expand_unit!($truthy, len); - let $falsy = expand_unit!($falsy, len); - let $mask = expand_unit!($mask, len); - - ($truthy, $falsy, $mask) - } - }; } -fn zip_with( - left: &ChunkedArray, - right: &ChunkedArray, +/// Combines the validities of ca with the bits in mask using the given combiner. +/// +/// If the mask itself has validity, those null bits are converted to false. +fn combine_validities_chunked< + T: PolarsDataType, + F: Fn(Option<&Bitmap>, Option<&Bitmap>) -> Option, +>( + ca: &ChunkedArray, mask: &BooleanChunked, -) -> PolarsResult> { - if left.len() != right.len() || right.len() != mask.len() { - return Err(polars_err!( - ShapeMismatch: "shapes of `left`, `right` and `mask` are not suitable for `zip_with` operation" - )); - }; - - let (left, right, mask) = align_chunks_ternary(left, right, mask); - let chunks = left - .chunks() - .iter() - .zip(right.chunks()) - .zip(mask.downcast_iter()) - .map(|((left_c, right_c), mask_c)| { - let mask_c = prepare_mask(mask_c); - let arr = if_then_else(&mask_c, left_c.as_ref(), right_c.as_ref())?; - Ok(arr) - }) - .collect::>>()?; - unsafe { Ok(left.copy_with_chunks(chunks, false, false)) } + combiner: F, +) -> ChunkedArray { + let (ca_al, mask_al) = align_chunks_binary(ca, mask); + let chunks = ca_al + .downcast_iter() + .zip(mask_al.downcast_iter()) + .map(|(a, m)| { + let bm = bool_null_to_false(m); + let validity = combiner(a.validity(), Some(&bm)); + a.clone().with_validity_typed(validity) + }); + ChunkedArray::from_chunk_iter_like(ca, chunks) } impl ChunkZip for ChunkedArray where - T: PolarsNumericType, + T: PolarsDataType, + T::Array: for<'a> IfThenElseKernel = T::Physical<'a>>, + ChunkedArray: ChunkExpandAtIndex, { fn zip_with( &self, mask: &BooleanChunked, other: &ChunkedArray, ) -> PolarsResult> { - // broadcasting path - if self.len() != mask.len() || other.len() != mask.len() { - impl_ternary_broadcast!(self, self.len(), other.len(), mask.len(), other, mask, T) - } else { - zip_with(self, other, mask) + let if_true = self; + let if_false = other; + + // Broadcast mask. + if mask.len() == 1 { + return if_then_else_broadcast_mask(mask.get(0).unwrap_or(false), if_true, if_false); } - } -} -impl ChunkZip for BooleanChunked { - fn zip_with( - &self, - mask: &BooleanChunked, - other: &BooleanChunked, - ) -> PolarsResult { - // broadcasting path - if self.len() != mask.len() || other.len() != mask.len() { - impl_ternary_broadcast!( - self, - self.len(), - other.len(), - mask.len(), - other, - mask, - BooleanType - ) + // Broadcast both. + let ret = if if_true.len() == 1 && if_false.len() == 1 { + match (if_true.get(0), if_false.get(0)) { + (None, None) => ChunkedArray::full_null_like(if_true, mask.len()), + (None, Some(_)) => combine_validities_chunked( + &if_false.new_from_index(0, mask.len()), + mask, + combine_validities_and_not, + ), + (Some(_), None) => combine_validities_chunked( + &if_true.new_from_index(0, mask.len()), + mask, + combine_validities_and, + ), + (Some(t), Some(f)) => { + let dtype = if_true.downcast_iter().next().unwrap().data_type(); + let chunks = mask.downcast_iter().map(|m| { + let bm = bool_null_to_false(m); + let t = t.clone(); + let f = f.clone(); + IfThenElseKernel::if_then_else_broadcast_both(dtype.clone(), &bm, t, f) + }); + ChunkedArray::from_chunk_iter_like(if_true, chunks) + }, + } + + // Broadcast neither. + } else if if_true.len() == if_false.len() { + polars_ensure!(mask.len() == if_true.len(), ShapeMismatch: SHAPE_MISMATCH_STR); + let (mask_al, if_true_al, if_false_al) = align_chunks_ternary(mask, if_true, if_false); + let chunks = mask_al + .downcast_iter() + .zip(if_true_al.downcast_iter()) + .zip(if_false_al.downcast_iter()) + .map(|((m, t), f)| IfThenElseKernel::if_then_else(&bool_null_to_false(m), t, f)); + ChunkedArray::from_chunk_iter_like(if_true, chunks) + + // Broadcast true value. + } else if if_true.len() == 1 { + polars_ensure!(mask.len() == if_false.len(), ShapeMismatch: SHAPE_MISMATCH_STR); + if let Some(true_scalar) = if_true.get(0) { + let (mask_al, if_false_al) = align_chunks_binary(mask, if_false); + let chunks = mask_al + .downcast_iter() + .zip(if_false_al.downcast_iter()) + .map(|(m, f)| { + let bm = bool_null_to_false(m); + let t = true_scalar.clone(); + IfThenElseKernel::if_then_else_broadcast_true(&bm, t, f) + }); + ChunkedArray::from_chunk_iter_like(if_true, chunks) + } else { + combine_validities_chunked(if_false, mask, combine_validities_and_not) + } + + // Broadcast false value. + } else if if_false.len() == 1 { + polars_ensure!(mask.len() == if_true.len(), ShapeMismatch: SHAPE_MISMATCH_STR); + if let Some(false_scalar) = if_false.get(0) { + let (mask_al, if_true_al) = align_chunks_binary(mask, if_true); + let chunks = + mask_al + .downcast_iter() + .zip(if_true_al.downcast_iter()) + .map(|(m, t)| { + let bm = bool_null_to_false(m); + let f = false_scalar.clone(); + IfThenElseKernel::if_then_else_broadcast_false(&bm, t, f) + }); + ChunkedArray::from_chunk_iter_like(if_false, chunks) + } else { + combine_validities_chunked(if_true, mask, combine_validities_and) + } } else { - zip_with(self, other, mask) - } - } -} + polars_bail!(ShapeMismatch: SHAPE_MISMATCH_STR) + }; -impl ChunkZip for StringChunked { - fn zip_with( - &self, - mask: &BooleanChunked, - other: &StringChunked, - ) -> PolarsResult { - unsafe { - self.as_binary() - .zip_with(mask, &other.as_binary()) - .map(|ca| ca.to_string()) - } + Ok(ret.with_name(if_true.name())) } } -impl ChunkZip for BinaryChunked { - fn zip_with( - &self, - mask: &BooleanChunked, - other: &BinaryChunked, - ) -> PolarsResult { - if self.len() != mask.len() || other.len() != mask.len() { - impl_ternary_broadcast!( - self, - self.len(), - other.len(), - mask.len(), - other, - mask, - BinaryType - ) - } else { - zip_with(self, other, mask) - } +// Basic implementation for ObjectArray. +#[cfg(feature = "object")] +impl IfThenElseKernel for ObjectArray { + type Scalar<'a> = &'a T; + + fn if_then_else(mask: &Bitmap, if_true: &Self, if_false: &Self) -> Self { + mask.iter() + .zip(if_true.iter()) + .zip(if_false.iter()) + .map(|((m, t), f)| if m { t } else { f }) + .collect_arr() } -} -impl ChunkZip for ListChunked { - fn zip_with(&self, mask: &BooleanChunked, other: &ListChunked) -> PolarsResult { - let (truthy, falsy, mask) = (self, other, mask); - let (truthy, falsy, mask) = expand_lengths!(truthy, falsy, mask); - zip_with(&truthy, &falsy, &mask) + fn if_then_else_broadcast_true( + mask: &Bitmap, + if_true: Self::Scalar<'_>, + if_false: &Self, + ) -> Self { + mask.iter() + .zip(if_false.iter()) + .map(|(m, f)| if m { Some(if_true) } else { f }) + .collect_arr() } -} -#[cfg(feature = "dtype-array")] -impl ChunkZip for ArrayChunked { - fn zip_with(&self, mask: &BooleanChunked, other: &ArrayChunked) -> PolarsResult { - let (truthy, falsy, mask) = (self, other, mask); - let (truthy, falsy, mask) = expand_lengths!(truthy, falsy, mask); - zip_with(&truthy, &falsy, &mask) + fn if_then_else_broadcast_false( + mask: &Bitmap, + if_true: &Self, + if_false: Self::Scalar<'_>, + ) -> Self { + mask.iter() + .zip(if_true.iter()) + .map(|(m, t)| if m { t } else { Some(if_false) }) + .collect_arr() } -} - -#[cfg(feature = "object")] -impl ChunkZip> for ObjectChunked { - fn zip_with( - &self, - mask: &BooleanChunked, - other: &ChunkedArray>, - ) -> PolarsResult>> { - let (truthy, falsy, mask) = (self, other, mask); - let (truthy, falsy, mask) = expand_lengths!(truthy, falsy, mask); - let (left, right, mask) = align_chunks_ternary(&truthy, &falsy, &mask); - let mut ca: Self = left - .as_ref() - .into_iter() - .zip(right.as_ref()) - .zip(mask.as_ref()) - .map(|((left_c, right_c), mask_c)| match mask_c { - Some(true) => left_c.cloned(), - Some(false) => right_c.cloned(), - None => None, - }) - .collect(); - ca.rename(self.name()); - Ok(ca) + fn if_then_else_broadcast_both( + _dtype: ArrowDataType, + mask: &Bitmap, + if_true: Self::Scalar<'_>, + if_false: Self::Scalar<'_>, + ) -> Self { + mask.iter() + .map(|m| if m { if_true } else { if_false }) + .collect_arr() } } diff --git a/crates/polars-core/src/datatypes/dtype.rs b/crates/polars-core/src/datatypes/dtype.rs index 7e5947301db4..db1f52497ff6 100644 --- a/crates/polars-core/src/datatypes/dtype.rs +++ b/crates/polars-core/src/datatypes/dtype.rs @@ -406,8 +406,13 @@ impl DataType { ))), Null => Ok(ArrowDataType::Null), #[cfg(feature = "object")] - Object(_, _) => { - polars_bail!(InvalidOperation: "cannot convert Object dtype data to Arrow") + Object(_, Some(reg)) => Ok(reg.physical_dtype.clone()), + #[cfg(feature = "object")] + Object(_, None) => { + // FIXME: find out why we have Objects floating around without a + // known dtype. + // polars_bail!(InvalidOperation: "cannot convert Object dtype without registry to Arrow") + Ok(ArrowDataType::Unknown) }, #[cfg(feature = "dtype-categorical")] Categorical(_, _) | Enum(_, _) => { @@ -428,9 +433,7 @@ impl DataType { Ok(ArrowDataType::Struct(fields)) }, BinaryOffset => Ok(ArrowDataType::LargeBinary), - Unknown => { - polars_bail!(InvalidOperation: "cannot convert Unknown dtype data to Arrow") - }, + Unknown => Ok(ArrowDataType::Unknown), } } diff --git a/crates/polars-core/src/datatypes/mod.rs b/crates/polars-core/src/datatypes/mod.rs index 8c14f78379f8..f2a58b5a66eb 100644 --- a/crates/polars-core/src/datatypes/mod.rs +++ b/crates/polars-core/src/datatypes/mod.rs @@ -13,8 +13,6 @@ mod any_value; mod dtype; mod field; #[cfg(feature = "object")] -mod static_array; -#[cfg(feature = "object")] mod static_array_collect; mod time_unit; diff --git a/crates/polars-core/src/datatypes/static_array.rs b/crates/polars-core/src/datatypes/static_array.rs deleted file mode 100644 index fdd8c3c80097..000000000000 --- a/crates/polars-core/src/datatypes/static_array.rs +++ /dev/null @@ -1,32 +0,0 @@ -use arrow::bitmap::utils::{BitmapIter, ZipValidity}; -use arrow::bitmap::Bitmap; - -use crate::chunked_array::object::{ObjectArray, ObjectValueIter}; -use crate::prelude::*; - -impl StaticArray for ObjectArray { - type ValueT<'a> = &'a T; - type ZeroableValueT<'a> = Option<&'a T>; - type ValueIterT<'a> = ObjectValueIter<'a, T>; - - #[inline] - unsafe fn value_unchecked(&self, idx: usize) -> Self::ValueT<'_> { - self.value_unchecked(idx) - } - - fn values_iter(&self) -> Self::ValueIterT<'_> { - self.values_iter() - } - - fn iter(&self) -> ZipValidity, Self::ValueIterT<'_>, BitmapIter> { - self.iter() - } - - fn with_validity_typed(self, validity: Option) -> Self { - self.with_validity(validity) - } - - fn full_null(_length: usize, _dtype: ArrowDataType) -> Self { - panic!("ObjectArray does not support full_null"); - } -} diff --git a/crates/polars-core/src/datatypes/static_array_collect.rs b/crates/polars-core/src/datatypes/static_array_collect.rs index 758bd6c5b2dd..02974d7b33a8 100644 --- a/crates/polars-core/src/datatypes/static_array_collect.rs +++ b/crates/polars-core/src/datatypes/static_array_collect.rs @@ -1,46 +1,27 @@ use std::sync::Arc; -use arrow::array::ArrayFromIterDtype; +use arrow::array::ArrayFromIter; use arrow::bitmap::Bitmap; -use arrow::datatypes::ArrowDataType; use crate::chunked_array::object::{ObjectArray, PolarsObject}; // TODO: more efficient implementations, I really took the short path here. -impl<'a, T: PolarsObject> ArrayFromIterDtype<&'a T> for ObjectArray { - fn arr_from_iter_with_dtype>( - dtype: ArrowDataType, - iter: I, - ) -> Self { - Self::try_arr_from_iter_with_dtype( - dtype, - iter.into_iter().map(|o| -> Result<_, ()> { Ok(Some(o)) }), - ) - .unwrap() +impl<'a, T: PolarsObject> ArrayFromIter<&'a T> for ObjectArray { + fn arr_from_iter>(iter: I) -> Self { + Self::try_arr_from_iter(iter.into_iter().map(|o| -> Result<_, ()> { Ok(Some(o)) })).unwrap() } - fn try_arr_from_iter_with_dtype>>( - dtype: ArrowDataType, - iter: I, - ) -> Result { - Self::try_arr_from_iter_with_dtype(dtype, iter.into_iter().map(|o| Ok(Some(o?)))) + fn try_arr_from_iter>>(iter: I) -> Result { + Self::try_arr_from_iter(iter.into_iter().map(|o| Ok(Some(o?)))) } } -impl<'a, T: PolarsObject> ArrayFromIterDtype> for ObjectArray { - fn arr_from_iter_with_dtype>>( - dtype: ArrowDataType, - iter: I, - ) -> Self { - Self::try_arr_from_iter_with_dtype( - dtype, - iter.into_iter().map(|o| -> Result<_, ()> { Ok(o) }), - ) - .unwrap() +impl<'a, T: PolarsObject> ArrayFromIter> for ObjectArray { + fn arr_from_iter>>(iter: I) -> Self { + Self::try_arr_from_iter(iter.into_iter().map(|o| -> Result<_, ()> { Ok(o) })).unwrap() } - fn try_arr_from_iter_with_dtype, E>>>( - _dtype: ArrowDataType, + fn try_arr_from_iter, E>>>( iter: I, ) -> Result { let iter = iter.into_iter(); diff --git a/crates/polars-core/src/series/mod.rs b/crates/polars-core/src/series/mod.rs index 32a697f55090..b9fbcb179177 100644 --- a/crates/polars-core/src/series/mod.rs +++ b/crates/polars-core/src/series/mod.rs @@ -876,9 +876,17 @@ where T: 'static + PolarsDataType, { fn as_ref(&self) -> &ChunkedArray { + #[cfg(feature = "dtype-array")] + let is_array = matches!(T::get_dtype(), DataType::Array(_, _)) + && matches!(self.dtype(), DataType::Array(_, _)); + #[cfg(not(feature = "dtype-array"))] + let is_array = false; + if &T::get_dtype() == self.dtype() || // Needed because we want to get ref of List no matter what the inner type is. (matches!(T::get_dtype(), DataType::List(_)) && matches!(self.dtype(), DataType::List(_))) + // Similarly for arrays. + || is_array { unsafe { &*(self as *const dyn SeriesTrait as *const ChunkedArray) } } else { diff --git a/crates/polars/tests/it/arrow/compute/if_then_else.rs b/crates/polars/tests/it/arrow/compute/if_then_else.rs deleted file mode 100644 index e203d831c39f..000000000000 --- a/crates/polars/tests/it/arrow/compute/if_then_else.rs +++ /dev/null @@ -1,42 +0,0 @@ -use arrow::array::*; -use arrow::compute::if_then_else::if_then_else; -use polars_error::PolarsResult; - -#[test] -fn basics() -> PolarsResult<()> { - let lhs = Int32Array::from_slice([1, 2, 3]); - let rhs = Int32Array::from_slice([4, 5, 6]); - let predicate = BooleanArray::from_slice(vec![true, false, true]); - let c = if_then_else(&predicate, &lhs, &rhs)?; - - let expected = Int32Array::from_slice([1, 5, 3]); - - assert_eq!(expected, c.as_ref()); - Ok(()) -} - -#[test] -fn basics_nulls() -> PolarsResult<()> { - let lhs = Int32Array::from(&[Some(1), None, None]); - let rhs = Int32Array::from(&[None, Some(5), Some(6)]); - let predicate = BooleanArray::from_slice(vec![true, false, true]); - let c = if_then_else(&predicate, &lhs, &rhs)?; - - let expected = Int32Array::from(&[Some(1), Some(5), None]); - - assert_eq!(expected, c.as_ref()); - Ok(()) -} - -#[test] -fn basics_nulls_pred() -> PolarsResult<()> { - let lhs = Int32Array::from_slice([1, 2, 3]); - let rhs = Int32Array::from_slice([4, 5, 6]); - let predicate = BooleanArray::from(&[Some(true), None, Some(false)]); - let result = if_then_else(&predicate, &lhs, &rhs)?; - - let expected = Int32Array::from(&[Some(1), None, Some(6)]); - - assert_eq!(expected, result.as_ref()); - Ok(()) -} diff --git a/crates/polars/tests/it/arrow/compute/mod.rs b/crates/polars/tests/it/arrow/compute/mod.rs index 95126a4a3a54..0f1fe99969e4 100644 --- a/crates/polars/tests/it/arrow/compute/mod.rs +++ b/crates/polars/tests/it/arrow/compute/mod.rs @@ -6,7 +6,5 @@ mod bitwise; mod boolean; #[cfg(feature = "compute_boolean_kleene")] mod boolean_kleene; -#[cfg(feature = "compute_if_then_else")] -mod if_then_else; mod arity_assign; diff --git a/py-polars/tests/unit/functions/test_when_then.py b/py-polars/tests/unit/functions/test_when_then.py index b5aec7177072..0d25b286bdfb 100644 --- a/py-polars/tests/unit/functions/test_when_then.py +++ b/py-polars/tests/unit/functions/test_when_then.py @@ -1,4 +1,9 @@ +from __future__ import annotations + +import itertools +import random from datetime import datetime +from typing import Any import pytest @@ -523,3 +528,71 @@ def test_when_then_null_broadcast() -> None: ).height == 2 ) + + +@pytest.mark.slow() +@pytest.mark.parametrize("len", [1, 10, 100, 500]) +@pytest.mark.parametrize( + ("dtype", "vals"), + [ + pytest.param(pl.Boolean, [False, True], id="Boolean"), + pytest.param(pl.UInt8, [0, 1], id="UInt8"), + pytest.param(pl.UInt16, [0, 1], id="UInt16"), + pytest.param(pl.UInt32, [0, 1], id="UInt32"), + pytest.param(pl.UInt64, [0, 1], id="UInt64"), + pytest.param(pl.Float32, [0.0, 1.0], id="Float32"), + pytest.param(pl.Float64, [0.0, 1.0], id="Float64"), + pytest.param(pl.String, ["0", "12"], id="String"), + pytest.param(pl.Array(pl.String, 2), [["0", "1"], ["3", "4"]], id="StrArray"), + pytest.param(pl.Array(pl.Int64, 2), [[0, 1], [3, 4]], id="IntArray"), + pytest.param(pl.List(pl.String), [["0"], ["1", "2"]], id="List"), + pytest.param( + pl.Struct({"foo": pl.Int32, "bar": pl.String}), + [{"foo": 0, "bar": "1"}, {"foo": 1, "bar": "2"}], + id="Struct", + ), + pytest.param(pl.Object, ["x", "y"], id="Object"), + ], +) +@pytest.mark.parametrize("broadcast", list(itertools.product([False, True], repeat=3))) +def test_when_then_parametric( + len: int, dtype: pl.DataType, vals: list[Any], broadcast: list[bool] +) -> None: + # Makes no sense to broadcast all columns. + if all(broadcast): + return + + rng = random.Random(42) + + for _ in range(10): + mask = rng.choices([False, True, None], k=len) + if_true = rng.choices(vals + [None], k=len) + if_false = rng.choices(vals + [None], k=len) + + py_mask, py_true, py_false = ( + [c[0]] * len if b else c + for b, c in zip(broadcast, [mask, if_true, if_false]) + ) + pl_mask, pl_true, pl_false = ( + c.first() if b else c + for b, c in zip(broadcast, [pl.col.mask, pl.col.if_true, pl.col.if_false]) + ) + + ref = pl.DataFrame( + {"if_true": [t if m else f for m, t, f in zip(py_mask, py_true, py_false)]}, + schema={"if_true": dtype}, + ) + df = pl.DataFrame( + { + "mask": mask, + "if_true": if_true, + "if_false": if_false, + }, + schema={"mask": pl.Boolean, "if_true": dtype, "if_false": dtype}, + ) + + ans = df.select(pl.when(pl_mask).then(pl_true).otherwise(pl_false)) + if dtype != pl.Object: + assert_frame_equal(ref, ans) + else: + assert ref["if_true"].to_list() == ans["if_true"].to_list()