Skip to content

Commit

Permalink
perf: add new when-then-otherwise kernels (#15089)
Browse files Browse the repository at this point in the history
  • Loading branch information
orlp committed Mar 18, 2024
1 parent 8cf46e6 commit 9f7ec49
Show file tree
Hide file tree
Showing 42 changed files with 1,445 additions and 479 deletions.
1 change: 0 additions & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,6 @@ features = [
"compute_boolean_kleene",
"compute_cast",
"compute_comparison",
"compute_if_then_else",
]

[patch.crates-io]
Expand Down
2 changes: 0 additions & 2 deletions crates/polars-arrow/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,6 @@ compute_boolean_kleene = []
compute_cast = ["compute_take", "ryu", "atoi_simd", "itoa", "fast-float"]
compute_comparison = ["compute_take", "compute_boolean"]
compute_hash = ["multiversion"]
compute_if_then_else = []
compute_take = []
compute_temporal = []
compute = [
Expand All @@ -152,7 +151,6 @@ compute = [
"compute_cast",
"compute_comparison",
"compute_hash",
"compute_if_then_else",
"compute_take",
"compute_temporal",
]
Expand Down
23 changes: 23 additions & 0 deletions crates/polars-arrow/src/array/binview/view.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,29 @@ impl View {
pub fn as_u128(self) -> u128 {
unsafe { std::mem::transmute(self) }
}

#[inline]
pub fn new_from_bytes(bytes: &[u8], buffer_idx: u32, offset: u32) -> Self {
if bytes.len() <= 12 {
let mut ret = Self {
length: bytes.len() as u32,
..Default::default()
};
let ret_ptr = &mut ret as *mut _ as *mut u8;
unsafe {
core::ptr::copy_nonoverlapping(bytes.as_ptr(), ret_ptr.add(4), bytes.len());
}
ret
} else {
let prefix_buf: [u8; 4] = std::array::from_fn(|i| *bytes.get(i).unwrap_or(&0));
Self {
length: bytes.len() as u32,
prefix: u32::from_le_bytes(prefix_buf),
buffer_idx,
offset,
}
}
}
}

impl IsNull for View {
Expand Down
18 changes: 17 additions & 1 deletion crates/polars-arrow/src/array/growable/binview.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ use polars_utils::unwrap::UnwrapUncheckedRelease;

use super::Growable;
use crate::array::binview::{BinaryViewArrayGeneric, View, ViewType};
use crate::array::growable::utils::{extend_validity, prepare_validity};
use crate::array::growable::utils::{extend_validity, extend_validity_copies, prepare_validity};
use crate::array::Array;
use crate::bitmap::MutableBitmap;
use crate::buffer::Buffer;
Expand Down Expand Up @@ -166,6 +166,22 @@ impl<'a, T: ViewType + ?Sized> Growable<'a> for GrowableBinaryViewArray<'a, T> {
unsafe { self.extend_unchecked(index, start, len) }
}

unsafe fn extend_copies(&mut self, index: usize, start: usize, len: usize, copies: usize) {
let orig_view_start = self.views.len();
if copies > 0 {
unsafe { self.extend_unchecked(index, start, len) }
}
if copies > 1 {
let array = *self.arrays.get_unchecked(index);
extend_validity_copies(&mut self.validity, array, start, len, copies - 1);
let extended_view_end = self.views.len();
for _ in 0..copies - 1 {
self.views
.extend_from_within(orig_view_start..extended_view_end)
}
}
}

fn extend_validity(&mut self, additional: usize) {
self.views
.extend(std::iter::repeat(View::default()).take(additional));
Expand Down
12 changes: 10 additions & 2 deletions crates/polars-arrow/src/array/growable/fixed_size_list.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ use std::sync::Arc;
use polars_utils::slice::GetSaferUnchecked;

use super::{make_growable, Growable};
use crate::array::growable::utils::{extend_validity, prepare_validity};
use crate::array::growable::utils::{extend_validity, extend_validity_copies, prepare_validity};
use crate::array::{Array, FixedSizeListArray};
use crate::bitmap::MutableBitmap;
use crate::datatypes::ArrowDataType;
Expand Down Expand Up @@ -55,7 +55,7 @@ impl<'a> GrowableFixedSizeList<'a> {
}
}

fn to(&mut self) -> FixedSizeListArray {
pub fn to(&mut self) -> FixedSizeListArray {
let validity = std::mem::take(&mut self.validity);
let values = self.values.as_box();

Expand All @@ -76,6 +76,14 @@ impl<'a> Growable<'a> for GrowableFixedSizeList<'a> {
.extend(index, start * self.size, len * self.size);
}

unsafe fn extend_copies(&mut self, index: usize, start: usize, len: usize, copies: usize) {
let array = *self.arrays.get_unchecked_release(index);
extend_validity_copies(&mut self.validity, array, start, len, copies);

self.values
.extend_copies(index, start * self.size, len * self.size, copies);
}

fn extend_validity(&mut self, additional: usize) {
self.values.extend_validity(additional * self.size);
if let Some(validity) = &mut self.validity {
Expand Down
2 changes: 1 addition & 1 deletion crates/polars-arrow/src/array/growable/list.rs
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ impl<'a, O: Offset> GrowableList<'a, O> {
}
}

fn to(&mut self) -> ListArray<O> {
pub fn to(&mut self) -> ListArray<O> {
let validity = std::mem::take(&mut self.validity);
let offsets = std::mem::take(&mut self.offsets);
let values = self.values.as_box();
Expand Down
12 changes: 11 additions & 1 deletion crates/polars-arrow/src/array/growable/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,9 +37,19 @@ pub trait Growable<'a> {
/// a slice starting at `start` and length `len`.
///
/// # Safety
/// Doesn't do any bound checks
/// Doesn't do any bound checks.
unsafe fn extend(&mut self, index: usize, start: usize, len: usize);

/// Same as extend, except it repeats the extension `copies` times.
///
/// # Safety
/// Doesn't do any bound checks.
unsafe fn extend_copies(&mut self, index: usize, start: usize, len: usize, copies: usize) {
for _ in 0..copies {
self.extend(index, start, len)
}
}

/// Extends this [`Growable`] with null elements, disregarding the bound arrays
///
/// # Safety
Expand Down
15 changes: 14 additions & 1 deletion crates/polars-arrow/src/array/growable/primitive.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ use std::sync::Arc;
use polars_utils::slice::GetSaferUnchecked;

use super::Growable;
use crate::array::growable::utils::{extend_validity, prepare_validity};
use crate::array::growable::utils::{extend_validity, extend_validity_copies, prepare_validity};
use crate::array::{Array, PrimitiveArray};
use crate::bitmap::MutableBitmap;
use crate::datatypes::ArrowDataType;
Expand Down Expand Up @@ -66,6 +66,19 @@ impl<'a, T: NativeType> Growable<'a> for GrowablePrimitive<'a, T> {
.extend_from_slice(values.get_unchecked_release(start..start + len));
}

#[inline]
unsafe fn extend_copies(&mut self, index: usize, start: usize, len: usize, copies: usize) {
let array = *self.arrays.get_unchecked_release(index);
extend_validity_copies(&mut self.validity, array, start, len, copies);

let values = array.values().as_slice();
self.values.reserve(len * copies);
for _ in 0..copies {
self.values
.extend_from_slice(values.get_unchecked_release(start..start + len));
}
}

#[inline]
fn extend_validity(&mut self, additional: usize) {
self.values
Expand Down
24 changes: 24 additions & 0 deletions crates/polars-arrow/src/array/growable/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -46,3 +46,27 @@ pub(super) fn extend_validity(
}
}
}

pub(super) fn extend_validity_copies(
mutable_validity: &mut Option<MutableBitmap>,
array: &dyn Array,
start: usize,
len: usize,
copies: usize,
) {
if let Some(mutable_validity) = mutable_validity {
match array.validity() {
None => mutable_validity.extend_constant(len * copies, true),
Some(validity) => {
debug_assert!(start + len <= validity.len());
let (slice, offset, _) = validity.as_slice();
// SAFETY: invariant offset + length <= slice.len()
for _ in 0..copies {
unsafe {
mutable_validity.extend_from_slice_unchecked(slice, start + offset, len);
}
}
},
}
}
}
30 changes: 28 additions & 2 deletions crates/polars-arrow/src/array/static_array.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@ use crate::array::binview::BinaryViewValueIter;
use crate::array::static_array_collect::ArrayFromIterDtype;
use crate::array::{
Array, ArrayValuesIter, BinaryArray, BinaryValueIter, BinaryViewArray, BooleanArray,
FixedSizeListArray, ListArray, ListValuesIter, PrimitiveArray, Utf8Array, Utf8ValuesIter,
Utf8ViewArray,
FixedSizeListArray, ListArray, ListValuesIter, MutableBinaryViewArray, PrimitiveArray,
Utf8Array, Utf8ValuesIter, Utf8ViewArray,
};
use crate::bitmap::utils::{BitmapIter, ZipValidity};
use crate::bitmap::Bitmap;
Expand All @@ -18,6 +18,7 @@ pub trait StaticArray:
+ for<'a> ArrayFromIterDtype<Self::ValueT<'a>>
+ for<'a> ArrayFromIterDtype<Self::ZeroableValueT<'a>>
+ for<'a> ArrayFromIterDtype<Option<Self::ValueT<'a>>>
+ Clone
{
type ValueT<'a>: Clone
where
Expand Down Expand Up @@ -82,6 +83,10 @@ pub trait StaticArray:
}

fn full_null(length: usize, dtype: ArrowDataType) -> Self;

fn full(length: usize, value: Self::ValueT<'_>, dtype: ArrowDataType) -> Self {
Self::arr_from_iter_with_dtype(dtype, std::iter::repeat(value).take(length))
}
}

pub trait ParameterFreeDtypeStaticArray: StaticArray {
Expand Down Expand Up @@ -126,6 +131,10 @@ impl<T: NativeType> StaticArray for PrimitiveArray<T> {
fn full_null(length: usize, dtype: ArrowDataType) -> Self {
Self::new_null(dtype, length)
}

fn full(length: usize, value: Self::ValueT<'_>, _dtype: ArrowDataType) -> Self {
PrimitiveArray::from_vec(vec![value; length])
}
}

impl<T: NativeType> ParameterFreeDtypeStaticArray for PrimitiveArray<T> {
Expand Down Expand Up @@ -167,6 +176,10 @@ impl StaticArray for BooleanArray {
fn full_null(length: usize, dtype: ArrowDataType) -> Self {
Self::new_null(dtype, length)
}

fn full(length: usize, value: Self::ValueT<'_>, _dtype: ArrowDataType) -> Self {
Bitmap::new_with_value(value, length).into()
}
}

impl ParameterFreeDtypeStaticArray for BooleanArray {
Expand Down Expand Up @@ -265,6 +278,12 @@ impl StaticArray for BinaryViewArray {
fn full_null(length: usize, dtype: ArrowDataType) -> Self {
Self::new_null(dtype, length)
}

fn full(length: usize, value: Self::ValueT<'_>, _dtype: ArrowDataType) -> Self {
let mut builder = MutableBinaryViewArray::with_capacity(length);
builder.extend_constant(length, Some(value));
builder.into()
}
}

impl ParameterFreeDtypeStaticArray for BinaryViewArray {
Expand Down Expand Up @@ -297,6 +316,13 @@ impl StaticArray for Utf8ViewArray {
fn full_null(length: usize, dtype: ArrowDataType) -> Self {
Self::new_null(dtype, length)
}

fn full(length: usize, value: Self::ValueT<'_>, _dtype: ArrowDataType) -> Self {
unsafe {
BinaryViewArray::full(length, value.as_bytes(), ArrowDataType::BinaryView)
.to_utf8view_unchecked()
}
}
}

impl ParameterFreeDtypeStaticArray for Utf8ViewArray {
Expand Down
32 changes: 28 additions & 4 deletions crates/polars-arrow/src/array/static_array_collect.rs
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,13 @@ pub trait ArrayFromIter<T>: Sized {
impl<T, A: ParameterFreeDtypeStaticArray + ArrayFromIter<T>> ArrayFromIterDtype<T> for A {
#[inline(always)]
fn arr_from_iter_with_dtype<I: IntoIterator<Item = T>>(dtype: ArrowDataType, iter: I) -> Self {
debug_assert!(std::mem::discriminant(&dtype) == std::mem::discriminant(&A::get_dtype()));
// FIXME: currently some Object arrays have Unknown dtype, when this is fixed remove this bypass.
if dtype != ArrowDataType::Unknown {
debug_assert_eq!(
std::mem::discriminant(&dtype),
std::mem::discriminant(&A::get_dtype())
);
}
Self::arr_from_iter(iter)
}

Expand All @@ -80,7 +86,13 @@ impl<T, A: ParameterFreeDtypeStaticArray + ArrayFromIter<T>> ArrayFromIterDtype<
I: IntoIterator<Item = T>,
I::IntoIter: TrustedLen,
{
debug_assert!(std::mem::discriminant(&dtype) == std::mem::discriminant(&A::get_dtype()));
// FIXME: currently some Object arrays have Unknown dtype, when this is fixed remove this bypass.
if dtype != ArrowDataType::Unknown {
debug_assert_eq!(
std::mem::discriminant(&dtype),
std::mem::discriminant(&A::get_dtype())
);
}
Self::arr_from_iter_trusted(iter)
}

Expand All @@ -89,7 +101,13 @@ impl<T, A: ParameterFreeDtypeStaticArray + ArrayFromIter<T>> ArrayFromIterDtype<
dtype: ArrowDataType,
iter: I,
) -> Result<Self, E> {
debug_assert!(std::mem::discriminant(&dtype) == std::mem::discriminant(&A::get_dtype()));
// FIXME: currently some Object arrays have Unknown dtype, when this is fixed remove this bypass.
if dtype != ArrowDataType::Unknown {
debug_assert_eq!(
std::mem::discriminant(&dtype),
std::mem::discriminant(&A::get_dtype())
);
}
Self::try_arr_from_iter(iter)
}

Expand All @@ -99,7 +117,13 @@ impl<T, A: ParameterFreeDtypeStaticArray + ArrayFromIter<T>> ArrayFromIterDtype<
I: IntoIterator<Item = Result<T, E>>,
I::IntoIter: TrustedLen,
{
debug_assert!(std::mem::discriminant(&dtype) == std::mem::discriminant(&A::get_dtype()));
// FIXME: currently some Object arrays have Unknown dtype, when this is fixed remove this bypass.
if dtype != ArrowDataType::Unknown {
debug_assert_eq!(
std::mem::discriminant(&dtype),
std::mem::discriminant(&A::get_dtype())
);
}
Self::try_arr_from_iter_trusted(iter)
}
}
Expand Down
20 changes: 14 additions & 6 deletions crates/polars-arrow/src/bitmap/bitmap_ops.rs
Original file line number Diff line number Diff line change
Expand Up @@ -160,8 +160,7 @@ pub(crate) fn align(bitmap: &Bitmap, new_offset: usize) -> Bitmap {
bitmap.sliced(new_offset, length)
}

#[inline]
/// Compute bitwise AND operation
/// Compute bitwise A AND B operation.
pub fn and(lhs: &Bitmap, rhs: &Bitmap) -> Bitmap {
if lhs.unset_bits() == lhs.len() || rhs.unset_bits() == rhs.len() {
assert_eq!(lhs.len(), rhs.len());
Expand All @@ -171,8 +170,12 @@ pub fn and(lhs: &Bitmap, rhs: &Bitmap) -> Bitmap {
}
}

#[inline]
/// Compute bitwise OR operation
/// Compute bitwise A AND NOT B operation.
pub fn and_not(lhs: &Bitmap, rhs: &Bitmap) -> Bitmap {
binary(lhs, rhs, |x, y| x & !y)
}

/// Compute bitwise A OR B operation.
pub fn or(lhs: &Bitmap, rhs: &Bitmap) -> Bitmap {
if lhs.unset_bits() == 0 || rhs.unset_bits() == 0 {
assert_eq!(lhs.len(), rhs.len());
Expand All @@ -184,8 +187,12 @@ pub fn or(lhs: &Bitmap, rhs: &Bitmap) -> Bitmap {
}
}

#[inline]
/// Compute bitwise XOR operation
/// Compute bitwise A OR NOT B operation.
pub fn or_not(lhs: &Bitmap, rhs: &Bitmap) -> Bitmap {
binary(lhs, rhs, |x, y| x | !y)
}

/// Compute bitwise XOR operation.
pub fn xor(lhs: &Bitmap, rhs: &Bitmap) -> Bitmap {
let lhs_nulls = lhs.unset_bits();
let rhs_nulls = rhs.unset_bits();
Expand All @@ -208,6 +215,7 @@ pub fn xor(lhs: &Bitmap, rhs: &Bitmap) -> Bitmap {
}
}

/// Compute bitwise equality (not XOR) operation.
fn eq(lhs: &Bitmap, rhs: &Bitmap) -> bool {
if lhs.len() != rhs.len() {
return false;
Expand Down
Loading

0 comments on commit 9f7ec49

Please sign in to comment.