diff --git a/crates/polars-compute/src/comparisons/mod.rs b/crates/polars-compute/src/comparisons/mod.rs
index a0baebad6b7d..c1b7dcb75981 100644
--- a/crates/polars-compute/src/comparisons/mod.rs
+++ b/crates/polars-compute/src/comparisons/mod.rs
@@ -72,17 +72,6 @@ pub trait TotalOrdKernel: Sized + Array {
     }
 }
 
-// Trait to enable the scalar blanket implementation.
-trait NotSimdPrimitive {}
-
-#[cfg(not(feature = "simd"))]
-impl<T: NativeType> NotSimdPrimitive for T {}
-
-#[cfg(feature = "simd")]
-impl NotSimdPrimitive for u128 {}
-#[cfg(feature = "simd")]
-impl NotSimdPrimitive for i128 {}
-
 mod scalar;
 mod view;
 
diff --git a/crates/polars-compute/src/comparisons/scalar.rs b/crates/polars-compute/src/comparisons/scalar.rs
index bc338f3816a8..f2f245ed34c5 100644
--- a/crates/polars-compute/src/comparisons/scalar.rs
+++ b/crates/polars-compute/src/comparisons/scalar.rs
@@ -1,11 +1,11 @@
 use arrow::array::{BinaryArray, BooleanArray, PrimitiveArray, Utf8Array};
 use arrow::bitmap::{self, Bitmap};
-use arrow::types::NativeType;
 use polars_utils::total_ord::{TotalEq, TotalOrd};
 
-use super::{NotSimdPrimitive, TotalOrdKernel};
+use super::TotalOrdKernel;
+use crate::NotSimdPrimitive;
 
-impl<T: NativeType + NotSimdPrimitive + TotalOrd> TotalOrdKernel for PrimitiveArray<T> {
+impl<T: NotSimdPrimitive + TotalOrd> TotalOrdKernel for PrimitiveArray<T> {
     type Scalar = T;
 
     fn tot_lt_kernel(&self, other: &Self) -> Bitmap {
diff --git a/crates/polars-compute/src/filter/primitive.rs b/crates/polars-compute/src/filter/primitive.rs
index 10c00afdff1c..9cc542b60978 100644
--- a/crates/polars-compute/src/filter/primitive.rs
+++ b/crates/polars-compute/src/filter/primitive.rs
@@ -1,5 +1,7 @@
 use arrow::bitmap::Bitmap;
 use bytemuck::{cast_slice, cast_vec, Pod};
+#[cfg(all(target_arch = "x86_64", feature = "simd"))]
+use polars_utils::cpuid::is_avx512_enabled;
 
 #[cfg(all(target_arch = "x86_64", feature = "simd"))]
 use super::avx512;
@@ -28,7 +30,7 @@ pub fn filter_values<T: Pod>(values: &[T], mask: &Bitmap) -> Vec<T> {
 
 fn filter_values_u8(values: &[u8], mask: &Bitmap) -> Vec<u8> {
     #[cfg(all(target_arch = "x86_64", feature = "simd"))]
-    if std::arch::is_x86_feature_detected!("avx512vbmi2") {
+    if is_avx512_enabled() && std::arch::is_x86_feature_detected!("avx512vbmi2") {
         return filter_values_generic(values, mask, 64, avx512::filter_u8_avx512vbmi2);
     }
 
@@ -37,7 +39,7 @@
 
 fn filter_values_u16(values: &[u16], mask: &Bitmap) -> Vec<u16> {
     #[cfg(all(target_arch = "x86_64", feature = "simd"))]
-    if std::arch::is_x86_feature_detected!("avx512vbmi2") {
+    if is_avx512_enabled() && std::arch::is_x86_feature_detected!("avx512vbmi2") {
         return filter_values_generic(values, mask, 32, avx512::filter_u16_avx512vbmi2);
     }
 
@@ -46,7 +48,7 @@
 
 fn filter_values_u32(values: &[u32], mask: &Bitmap) -> Vec<u32> {
     #[cfg(all(target_arch = "x86_64", feature = "simd"))]
-    if std::arch::is_x86_feature_detected!("avx512f") {
+    if is_avx512_enabled() {
         return filter_values_generic(values, mask, 16, avx512::filter_u32_avx512f);
     }
 
@@ -55,7 +57,7 @@
 
 fn filter_values_u64(values: &[u64], mask: &Bitmap) -> Vec<u64> {
     #[cfg(all(target_arch = "x86_64", feature = "simd"))]
-    if std::arch::is_x86_feature_detected!("avx512f") {
+    if is_avx512_enabled() {
         return filter_values_generic(values, mask, 8, avx512::filter_u64_avx512f);
     }
 
diff --git a/crates/polars-compute/src/if_then_else/mod.rs b/crates/polars-compute/src/if_then_else/mod.rs
index 3f704ca1d5f4..f912473594ba 100644
--- a/crates/polars-compute/src/if_then_else/mod.rs
+++ b/crates/polars-compute/src/if_then_else/mod.rs
@@ -4,13 +4,16 @@ use arrow::array::{Array, PrimitiveArray};
 use arrow::bitmap::utils::{align_bitslice_start_u8, SlicesIterator};
 use arrow::bitmap::{self, Bitmap};
 use arrow::datatypes::ArrowDataType;
-use arrow::types::NativeType;
 use polars_utils::slice::load_padded_le_u64;
 
+use crate::NotSimdPrimitive;
+
 mod array;
 mod boolean;
 mod list;
 mod scalar;
+#[cfg(feature = "simd")]
+mod simd;
 mod view;
 
 pub trait IfThenElseKernel: Sized + Array {
@@ -35,7 +38,7 @@
     ) -> Self;
 }
 
-impl<T: NativeType> IfThenElseKernel for PrimitiveArray<T> {
+impl<T: NotSimdPrimitive> IfThenElseKernel for PrimitiveArray<T> {
     type Scalar<'a> = T;
 
     fn if_then_else(mask: &Bitmap, if_true: &Self, if_false: &Self) -> Self {
diff --git a/crates/polars-compute/src/if_then_else/scalar.rs b/crates/polars-compute/src/if_then_else/scalar.rs
index 2e1d6b396ddb..310da8cfb121 100644
--- a/crates/polars-compute/src/if_then_else/scalar.rs
+++ b/crates/polars-compute/src/if_then_else/scalar.rs
@@ -6,7 +6,7 @@ pub fn if_then_else_scalar_rest<T: Copy>(
     if_false: &[T],
     out: &mut [MaybeUninit<T>],
 ) {
-    assert!(if_true.len() <= out.len()); // Removes bounds checks in inner loop.
+    assert!(if_true.len() == out.len()); // Removes bounds checks in inner loop.
     let true_it = if_true.iter().copied();
     let false_it = if_false.iter().copied();
     for (i, (t, f)) in true_it.zip(false_it).enumerate() {
@@ -21,7 +21,7 @@ pub fn if_then_else_broadcast_false_scalar_rest<T: Copy>(
     if_false: T,
     out: &mut [MaybeUninit<T>],
 ) {
-    assert!(if_true.len() <= out.len()); // Removes bounds checks in inner loop.
+    assert!(if_true.len() == out.len()); // Removes bounds checks in inner loop.
     let true_it = if_true.iter().copied();
     for (i, t) in true_it.enumerate() {
         let src = if (mask >> i) & 1 != 0 { t } else { if_false };
diff --git a/crates/polars-compute/src/if_then_else/simd.rs b/crates/polars-compute/src/if_then_else/simd.rs
new file mode 100644
index 000000000000..f17719711eb0
--- /dev/null
+++ b/crates/polars-compute/src/if_then_else/simd.rs
@@ -0,0 +1,155 @@
+use std::mem::MaybeUninit;
+use std::simd::{Mask, Simd, SimdElement};
+
+use arrow::array::PrimitiveArray;
+use arrow::bitmap::Bitmap;
+use arrow::datatypes::ArrowDataType;
+
+use super::{
+    if_then_else_loop, if_then_else_loop_broadcast_both, if_then_else_loop_broadcast_false,
+    if_then_else_validity, scalar, IfThenElseKernel,
+};
+
+#[cfg(target_arch = "x86_64")]
+fn select_simd_64<T: SimdElement>(
+    mask: u64,
+    if_true: Simd<T, 64>,
+    if_false: Simd<T, 64>,
+    out: &mut [MaybeUninit<T>; 64],
+) {
+    let mv = Mask::<<T as SimdElement>::Mask, 64>::from_bitmask(mask);
+    let ret = mv.select(if_true, if_false);
+    unsafe {
+        let src = ret.as_array().as_ptr() as *const MaybeUninit<T>;
+        core::ptr::copy_nonoverlapping(src, out.as_mut_ptr(), 64);
+    }
+}
+
+#[cfg(target_arch = "x86_64")]
+fn if_then_else_simd_64<T: SimdElement>(
+    mask: u64,
+    if_true: &[T; 64],
+    if_false: &[T; 64],
+    out: &mut [MaybeUninit<T>; 64],
+) {
+    select_simd_64(
+        mask,
+        Simd::from_slice(if_true),
+        Simd::from_slice(if_false),
+        out,
+    )
+}
+
+#[cfg(target_arch = "x86_64")]
+fn if_then_else_broadcast_false_simd_64<T: SimdElement>(
+    mask: u64,
+    if_true: &[T; 64],
+    if_false: T,
+    out: &mut [MaybeUninit<T>; 64],
+) {
+    select_simd_64(mask, Simd::from_slice(if_true), Simd::splat(if_false), out)
+}
+
+#[cfg(target_arch = "x86_64")]
+fn if_then_else_broadcast_both_simd_64<T: SimdElement>(
+    mask: u64,
+    if_true: T,
+    if_false: T,
+    out: &mut [MaybeUninit<T>; 64],
+) {
+    select_simd_64(mask, Simd::splat(if_true), Simd::splat(if_false), out)
+}
+
+macro_rules! impl_if_then_else {
+    ($T: ty) => {
+        impl IfThenElseKernel for PrimitiveArray<$T> {
+            type Scalar<'a> = $T;
+
+            fn if_then_else(mask: &Bitmap, if_true: &Self, if_false: &Self) -> Self {
+                let values = if_then_else_loop(
+                    mask,
+                    if_true.values(),
+                    if_false.values(),
+                    scalar::if_then_else_scalar_rest,
+                    // Auto-generated SIMD was slower on ARM.
+                    #[cfg(target_arch = "x86_64")]
+                    if_then_else_simd_64,
+                    #[cfg(not(target_arch = "x86_64"))]
+                    scalar::if_then_else_scalar_64,
+                );
+                let validity = if_then_else_validity(mask, if_true.validity(), if_false.validity());
+                PrimitiveArray::from_vec(values).with_validity(validity)
+            }
+
+            fn if_then_else_broadcast_true(
+                mask: &Bitmap,
+                if_true: Self::Scalar<'_>,
+                if_false: &Self,
+            ) -> Self {
+                let values = if_then_else_loop_broadcast_false(
+                    true,
+                    mask,
+                    if_false.values(),
+                    if_true,
+                    // Auto-generated SIMD was slower on ARM.
+                    #[cfg(target_arch = "x86_64")]
+                    if_then_else_broadcast_false_simd_64,
+                    #[cfg(not(target_arch = "x86_64"))]
+                    scalar::if_then_else_broadcast_false_scalar_64,
+                );
+                let validity = if_then_else_validity(mask, None, if_false.validity());
+                PrimitiveArray::from_vec(values).with_validity(validity)
+            }
+
+            fn if_then_else_broadcast_false(
+                mask: &Bitmap,
+                if_true: &Self,
+                if_false: Self::Scalar<'_>,
+            ) -> Self {
+                let values = if_then_else_loop_broadcast_false(
+                    false,
+                    mask,
+                    if_true.values(),
+                    if_false,
+                    // Auto-generated SIMD was slower on ARM.
+                    #[cfg(target_arch = "x86_64")]
+                    if_then_else_broadcast_false_simd_64,
+                    #[cfg(not(target_arch = "x86_64"))]
+                    scalar::if_then_else_broadcast_false_scalar_64,
+                );
+                let validity = if_then_else_validity(mask, if_true.validity(), None);
+                PrimitiveArray::from_vec(values).with_validity(validity)
+            }
+
+            fn if_then_else_broadcast_both(
+                _dtype: ArrowDataType,
+                mask: &Bitmap,
+                if_true: Self::Scalar<'_>,
+                if_false: Self::Scalar<'_>,
+            ) -> Self {
+                let values = if_then_else_loop_broadcast_both(
+                    mask,
+                    if_true,
+                    if_false,
+                    // Auto-generated SIMD was slower on ARM.
+                    #[cfg(target_arch = "x86_64")]
+                    if_then_else_broadcast_both_simd_64,
+                    #[cfg(not(target_arch = "x86_64"))]
+                    scalar::if_then_else_broadcast_both_scalar_64,
+                );
+                PrimitiveArray::from_vec(values)
+            }
+        }
+    };
+}
+
+impl_if_then_else!(i8);
+impl_if_then_else!(i16);
+impl_if_then_else!(i32);
+impl_if_then_else!(i64);
+impl_if_then_else!(u8);
+impl_if_then_else!(u16);
+impl_if_then_else!(u32);
+impl_if_then_else!(u64);
+impl_if_then_else!(f32);
+impl_if_then_else!(f64);
diff --git a/crates/polars-compute/src/if_then_else/view.rs b/crates/polars-compute/src/if_then_else/view.rs
index 5b1100153b03..b304534064ff 100644
--- a/crates/polars-compute/src/if_then_else/view.rs
+++ b/crates/polars-compute/src/if_then_else/view.rs
@@ -7,9 +7,7 @@ use arrow::buffer::Buffer;
 use arrow::datatypes::ArrowDataType;
 
 use super::IfThenElseKernel;
-use crate::if_then_else::scalar::{
-    if_then_else_broadcast_both_scalar_64, if_then_else_broadcast_false_scalar_64,
-};
+use crate::if_then_else::scalar::if_then_else_broadcast_both_scalar_64;
 
 // Makes a buffer and a set of views into that buffer from a set of strings.
 // Does not allocate a buffer if not necessary.
@@ -87,7 +85,7 @@ impl IfThenElseKernel for BinaryViewArray {
             mask,
             if_false.views(),
             true_view,
-            if_then_else_broadcast_false_scalar_64,
+            if_then_else_broadcast_false_view_64,
         );
 
         let validity = super::if_then_else_validity(mask, None, if_false.validity());
@@ -120,7 +118,7 @@ impl IfThenElseKernel for BinaryViewArray {
             mask,
             if_true.views(),
             false_view,
-            if_then_else_broadcast_false_scalar_64,
+            if_then_else_broadcast_false_view_64,
         );
 
         let validity = super::if_then_else_validity(mask, if_true.validity(), None);
@@ -215,13 +213,14 @@ pub fn if_then_else_view_rest(
     false_buffer_idx_offset: u32,
 ) {
     assert!(if_true.len() <= out.len()); // Removes bounds checks in inner loop.
-    let true_it = if_true.iter().copied();
-    let false_it = if_false.iter().copied();
+    let true_it = if_true.iter();
+    let false_it = if_false.iter();
     for (i, (t, f)) in true_it.zip(false_it).enumerate() {
         // Written like this, this loop *should* be branchless.
         // Unfortunately we're still dependent on the compiler.
         let m = (mask >> i) & 1 != 0;
-        let mut v = if m { t } else { f };
+        let src = if m { t } else { f };
+        let mut v = *src;
         let offset = if m | (v.length <= 12) {
             // Yes, | instead of || is intentional.
             0
@@ -242,3 +241,18 @@
 ) {
     if_then_else_view_rest(mask, if_true, if_false, out, false_buffer_idx_offset)
 }
+
+// Using the scalar variant of this works, but was slower; we want to select a source pointer
+// and then copy it. Using the scalar version for the integers results in branches.
+pub fn if_then_else_broadcast_false_view_64(
+    mask: u64,
+    if_true: &[View; 64],
+    if_false: View,
+    out: &mut [MaybeUninit<View>; 64],
+) {
+    assert!(if_true.len() == out.len()); // Removes bounds checks in inner loop.
+    for (i, t) in if_true.iter().enumerate() {
+        let src = if (mask >> i) & 1 != 0 { t } else { &if_false };
+        out[i] = MaybeUninit::new(*src);
+    }
+}
diff --git a/crates/polars-compute/src/lib.rs b/crates/polars-compute/src/lib.rs
index cc477a817739..0911243e3b4d 100644
--- a/crates/polars-compute/src/lib.rs
+++ b/crates/polars-compute/src/lib.rs
@@ -5,6 +5,8 @@
     feature(stdarch_x86_avx512)
 )]
 
+use arrow::types::NativeType;
+
 pub mod arithmetic;
 pub mod comparisons;
 pub mod filter;
@@ -12,3 +14,14 @@ pub mod if_then_else;
 pub mod min_max;
 
 pub mod arity;
+
+// Trait to enable the scalar blanket implementation.
+pub trait NotSimdPrimitive: NativeType {}
+
+#[cfg(not(feature = "simd"))]
+impl<T: NativeType> NotSimdPrimitive for T {}
+
+#[cfg(feature = "simd")]
+impl NotSimdPrimitive for u128 {}
+#[cfg(feature = "simd")]
+impl NotSimdPrimitive for i128 {}
diff --git a/crates/polars-utils/src/cpuid.rs b/crates/polars-utils/src/cpuid.rs
index 71eda848a878..f7642e3e574c 100644
--- a/crates/polars-utils/src/cpuid.rs
+++ b/crates/polars-utils/src/cpuid.rs
@@ -37,3 +37,27 @@ pub fn has_fast_bmi2() -> bool {
 
     false
 }
+
+#[inline]
+pub fn is_avx512_enabled() -> bool {
+    #[cfg(target_arch = "x86_64")]
+    {
+        static CACHE: OnceLock<bool> = OnceLock::new();
+        return *CACHE.get_or_init(|| {
+            if !std::arch::is_x86_feature_detected!("avx512f") {
+                return false;
+            }
+
+            if std::env::var("POLARS_DISABLE_AVX512")
+                .map(|var| var == "1")
+                .unwrap_or(false)
+            {
+                return false;
+            }
+
+            true
+        });
+    }
+
+    false
+}
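
Note: every kernel in the new simd.rs reduces to one pattern: expand the 64-bit selection mask into a per-lane SIMD mask, then do a single vector select. A minimal standalone sketch of that pattern, shrunk to 4 lanes for readability (nightly-only portable_simd; the values below are illustrative and not part of the patch):

// Build with a nightly toolchain: the std::simd API is unstable.
#![feature(portable_simd)]
use std::simd::{Mask, Simd};

fn main() {
    let mask_bits: u64 = 0b1010; // bit i picks the source for lane i
    let if_true = Simd::<u32, 4>::from_array([1, 2, 3, 4]);
    let if_false = Simd::<u32, 4>::from_array([10, 20, 30, 40]);
    // from_bitmask + select is the branchless blend that select_simd_64
    // performs at 64 lanes before copying into the uninitialized output.
    let m = Mask::<i32, 4>::from_bitmask(mask_bits);
    let out = m.select(if_true, if_false);
    assert_eq!(out.to_array(), [10, 2, 30, 4]);
}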
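
Why the NotSimdPrimitive split is coherent at all: with the "simd" feature enabled only u128/i128 opt into the trait, so the blanket scalar impl provably cannot overlap the per-type impls generated by impl_if_then_else!, and rustc accepts both. A compilable sketch of that coherence trick with stand-in names (NotSimd, Kernel, PrimArr are illustrative, not the real crate items):

// Stand-ins for NotSimdPrimitive / IfThenElseKernel / PrimitiveArray.
trait NotSimd {}
trait Kernel {}
#[allow(dead_code)]
struct PrimArr<T>(Vec<T>);

// Mirrors the cfg(feature = "simd") case: only 128-bit ints opt in.
impl NotSimd for u128 {}
impl NotSimd for i128 {}

// Blanket scalar impl: covers exactly the types that opted in above.
impl<T: NotSimd> Kernel for PrimArr<T> {}
// Dedicated SIMD impl: accepted because u32 can never implement NotSimd
// downstream (local trait, foreign type), so the impls cannot overlap.
impl Kernel for PrimArr<u32> {}

fn main() {}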