Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

perf: add SIMD support for if-then-else kernels #15131

Merged
merged 8 commits into from
Mar 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 0 additions & 11 deletions crates/polars-compute/src/comparisons/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -72,17 +72,6 @@ pub trait TotalOrdKernel: Sized + Array {
}
}

// Trait to enable the scalar blanket implementation.
//
// Marker for native types that have no dedicated SIMD comparison kernel and
// therefore fall through to the generic scalar `TotalOrdKernel` impl.
trait NotSimdPrimitive {}

// Without the `simd` feature every type takes the scalar path.
#[cfg(not(feature = "simd"))]
impl<T> NotSimdPrimitive for T {}

// With the `simd` feature, only the 128-bit integers remain scalar-only
// (presumably because they have no portable-SIMD lane support — TODO confirm).
#[cfg(feature = "simd")]
impl NotSimdPrimitive for u128 {}
#[cfg(feature = "simd")]
impl NotSimdPrimitive for i128 {}

mod scalar;
mod view;

Expand Down
6 changes: 3 additions & 3 deletions crates/polars-compute/src/comparisons/scalar.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
use arrow::array::{BinaryArray, BooleanArray, PrimitiveArray, Utf8Array};
use arrow::bitmap::{self, Bitmap};
use arrow::types::NativeType;
use polars_utils::total_ord::{TotalEq, TotalOrd};

use super::{NotSimdPrimitive, TotalOrdKernel};
use super::TotalOrdKernel;
use crate::NotSimdPrimitive;

impl<T: NativeType + NotSimdPrimitive + TotalOrd> TotalOrdKernel for PrimitiveArray<T> {
impl<T: NotSimdPrimitive + TotalOrd> TotalOrdKernel for PrimitiveArray<T> {
type Scalar = T;

fn tot_lt_kernel(&self, other: &Self) -> Bitmap {
Expand Down
10 changes: 6 additions & 4 deletions crates/polars-compute/src/filter/primitive.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
use arrow::bitmap::Bitmap;
use bytemuck::{cast_slice, cast_vec, Pod};
#[cfg(all(target_arch = "x86_64", feature = "simd"))]
use polars_utils::cpuid::is_avx512_enabled;

#[cfg(all(target_arch = "x86_64", feature = "simd"))]
use super::avx512;
Expand Down Expand Up @@ -28,7 +30,7 @@ pub fn filter_values<T: Pod>(values: &[T], mask: &Bitmap) -> Vec<T> {

fn filter_values_u8(values: &[u8], mask: &Bitmap) -> Vec<u8> {
#[cfg(all(target_arch = "x86_64", feature = "simd"))]
if std::arch::is_x86_feature_detected!("avx512vbmi2") {
if is_avx512_enabled() && std::arch::is_x86_feature_detected!("avx512vbmi2") {
return filter_values_generic(values, mask, 64, avx512::filter_u8_avx512vbmi2);
}

Expand All @@ -37,7 +39,7 @@ fn filter_values_u8(values: &[u8], mask: &Bitmap) -> Vec<u8> {

fn filter_values_u16(values: &[u16], mask: &Bitmap) -> Vec<u16> {
#[cfg(all(target_arch = "x86_64", feature = "simd"))]
if std::arch::is_x86_feature_detected!("avx512vbmi2") {
if is_avx512_enabled() && std::arch::is_x86_feature_detected!("avx512vbmi2") {
return filter_values_generic(values, mask, 32, avx512::filter_u16_avx512vbmi2);
}

Expand All @@ -46,7 +48,7 @@ fn filter_values_u16(values: &[u16], mask: &Bitmap) -> Vec<u16> {

fn filter_values_u32(values: &[u32], mask: &Bitmap) -> Vec<u32> {
#[cfg(all(target_arch = "x86_64", feature = "simd"))]
if std::arch::is_x86_feature_detected!("avx512f") {
if is_avx512_enabled() {
return filter_values_generic(values, mask, 16, avx512::filter_u32_avx512f);
}

Expand All @@ -55,7 +57,7 @@ fn filter_values_u32(values: &[u32], mask: &Bitmap) -> Vec<u32> {

fn filter_values_u64(values: &[u64], mask: &Bitmap) -> Vec<u64> {
#[cfg(all(target_arch = "x86_64", feature = "simd"))]
if std::arch::is_x86_feature_detected!("avx512f") {
if is_avx512_enabled() {
return filter_values_generic(values, mask, 8, avx512::filter_u64_avx512f);
}

Expand Down
7 changes: 5 additions & 2 deletions crates/polars-compute/src/if_then_else/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,16 @@ use arrow::array::{Array, PrimitiveArray};
use arrow::bitmap::utils::{align_bitslice_start_u8, SlicesIterator};
use arrow::bitmap::{self, Bitmap};
use arrow::datatypes::ArrowDataType;
use arrow::types::NativeType;
use polars_utils::slice::load_padded_le_u64;

use crate::NotSimdPrimitive;

mod array;
mod boolean;
mod list;
mod scalar;
#[cfg(feature = "simd")]
mod simd;
mod view;

pub trait IfThenElseKernel: Sized + Array {
Expand All @@ -35,7 +38,7 @@ pub trait IfThenElseKernel: Sized + Array {
) -> Self;
}

impl<T: NativeType> IfThenElseKernel for PrimitiveArray<T> {
impl<T: NotSimdPrimitive> IfThenElseKernel for PrimitiveArray<T> {
type Scalar<'a> = T;

fn if_then_else(mask: &Bitmap, if_true: &Self, if_false: &Self) -> Self {
Expand Down
4 changes: 2 additions & 2 deletions crates/polars-compute/src/if_then_else/scalar.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ pub fn if_then_else_scalar_rest<T: Copy>(
if_false: &[T],
out: &mut [MaybeUninit<T>],
) {
assert!(if_true.len() <= out.len()); // Removes bounds checks in inner loop.
assert!(if_true.len() == out.len()); // Removes bounds checks in inner loop.
let true_it = if_true.iter().copied();
let false_it = if_false.iter().copied();
for (i, (t, f)) in true_it.zip(false_it).enumerate() {
Expand All @@ -21,7 +21,7 @@ pub fn if_then_else_broadcast_false_scalar_rest<T: Copy>(
if_false: T,
out: &mut [MaybeUninit<T>],
) {
assert!(if_true.len() <= out.len()); // Removes bounds checks in inner loop.
assert!(if_true.len() == out.len()); // Removes bounds checks in inner loop.
let true_it = if_true.iter().copied();
for (i, t) in true_it.enumerate() {
let src = if (mask >> i) & 1 != 0 { t } else { if_false };
Expand Down
155 changes: 155 additions & 0 deletions crates/polars-compute/src/if_then_else/simd.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,155 @@
use std::mem::MaybeUninit;
use std::simd::{Mask, Simd, SimdElement};

use arrow::array::PrimitiveArray;
use arrow::bitmap::Bitmap;
use arrow::datatypes::ArrowDataType;

use super::{
if_then_else_loop, if_then_else_loop_broadcast_both, if_then_else_loop_broadcast_false,
if_then_else_validity, scalar, IfThenElseKernel,
};

// Core 64-lane select: for each lane i, writes `if_true[i]` where bit i of
// `mask` is set and `if_false[i]` otherwise, into `out`.
#[cfg(target_arch = "x86_64")]
fn select_simd_64<T: Copy + SimdElement>(
    mask: u64,
    if_true: Simd<T, 64>,
    if_false: Simd<T, 64>,
    out: &mut [MaybeUninit<T>; 64],
) {
    // Bit i of `mask` picks lane i from `if_true` (set) or `if_false` (unset).
    let mv = Mask::<<T as SimdElement>::Mask, 64>::from_bitmask(mask);
    let ret = mv.select(if_true, if_false);
    // SAFETY: `ret.as_array()` is a `[T; 64]`; `MaybeUninit<T>` has the same
    // layout as `T`, and `out` is a distinct `[MaybeUninit<T>; 64]`, so copying
    // exactly 64 elements is in-bounds and non-overlapping.
    unsafe {
        let src = ret.as_array().as_ptr() as *const MaybeUninit<T>;
        core::ptr::copy_nonoverlapping(src, out.as_mut_ptr(), 64);
    }
}

// 64-element if-then-else over two full chunks: loads both sides into SIMD
// vectors and defers the per-lane selection to `select_simd_64`.
#[cfg(target_arch = "x86_64")]
fn if_then_else_simd_64<T: Copy + SimdElement>(
    mask: u64,
    if_true: &[T; 64],
    if_false: &[T; 64],
    out: &mut [MaybeUninit<T>; 64],
) {
    let truthy = Simd::from_array(*if_true);
    let falsy = Simd::from_array(*if_false);
    select_simd_64(mask, truthy, falsy, out)
}

// Broadcast-false variant: lanes with a set mask bit come from the `if_true`
// chunk; every other lane is filled with the single scalar `if_false`.
#[cfg(target_arch = "x86_64")]
fn if_then_else_broadcast_false_simd_64<T: Copy + SimdElement>(
    mask: u64,
    if_true: &[T; 64],
    if_false: T,
    out: &mut [MaybeUninit<T>; 64],
) {
    let truthy = Simd::from_array(*if_true);
    let falsy = Simd::splat(if_false);
    select_simd_64(mask, truthy, falsy, out)
}

// Both-broadcast variant: every output lane is one of two scalars, chosen
// per mask bit.
#[cfg(target_arch = "x86_64")]
fn if_then_else_broadcast_both_simd_64<T: Copy + SimdElement>(
    mask: u64,
    if_true: T,
    if_false: T,
    out: &mut [MaybeUninit<T>; 64],
) {
    let truthy = Simd::splat(if_true);
    let falsy = Simd::splat(if_false);
    select_simd_64(mask, truthy, falsy, out)
}

// Implements `IfThenElseKernel` for `PrimitiveArray<$T>`, routing the
// 64-element inner loops to the SIMD helpers on x86_64 and to the scalar
// fallbacks elsewhere (hand-written SIMD measured slower on ARM, per the
// inline comments below).
macro_rules! impl_if_then_else {
    ($T: ty) => {
        impl IfThenElseKernel for PrimitiveArray<$T> {
            type Scalar<'a> = $T;

            // Element-wise select between two full arrays.
            fn if_then_else(mask: &Bitmap, if_true: &Self, if_false: &Self) -> Self {
                let values = if_then_else_loop(
                    mask,
                    if_true.values(),
                    if_false.values(),
                    scalar::if_then_else_scalar_rest,
                    // Auto-generated SIMD was slower on ARM.
                    #[cfg(target_arch = "x86_64")]
                    if_then_else_simd_64,
                    #[cfg(not(target_arch = "x86_64"))]
                    scalar::if_then_else_scalar_64,
                );
                let validity = if_then_else_validity(mask, if_true.validity(), if_false.validity());
                PrimitiveArray::from_vec(values).with_validity(validity)
            }

            // Scalar `if_true` vs. array `if_false`. Reuses the broadcast-false
            // loop with its flag set to `true` (swapping the roles of the two
            // sides) instead of a dedicated broadcast-true loop — TODO confirm
            // the flag's exact semantics against `if_then_else_loop_broadcast_false`.
            fn if_then_else_broadcast_true(
                mask: &Bitmap,
                if_true: Self::Scalar<'_>,
                if_false: &Self,
            ) -> Self {
                let values = if_then_else_loop_broadcast_false(
                    true,
                    mask,
                    if_false.values(),
                    if_true,
                    // Auto-generated SIMD was slower on ARM.
                    #[cfg(target_arch = "x86_64")]
                    if_then_else_broadcast_false_simd_64,
                    #[cfg(not(target_arch = "x86_64"))]
                    scalar::if_then_else_broadcast_false_scalar_64,
                );
                // True side is a scalar, so it contributes no validity bitmap.
                let validity = if_then_else_validity(mask, None, if_false.validity());
                PrimitiveArray::from_vec(values).with_validity(validity)
            }

            // Array `if_true` vs. scalar `if_false`.
            fn if_then_else_broadcast_false(
                mask: &Bitmap,
                if_true: &Self,
                if_false: Self::Scalar<'_>,
            ) -> Self {
                let values = if_then_else_loop_broadcast_false(
                    false,
                    mask,
                    if_true.values(),
                    if_false,
                    // Auto-generated SIMD was slower on ARM.
                    #[cfg(target_arch = "x86_64")]
                    if_then_else_broadcast_false_simd_64,
                    #[cfg(not(target_arch = "x86_64"))]
                    scalar::if_then_else_broadcast_false_scalar_64,
                );
                // False side is a scalar, so it contributes no validity bitmap.
                let validity = if_then_else_validity(mask, if_true.validity(), None);
                PrimitiveArray::from_vec(values).with_validity(validity)
            }

            // Both sides are scalars: result is always fully valid, so no
            // validity bitmap is attached.
            fn if_then_else_broadcast_both(
                _dtype: ArrowDataType,
                mask: &Bitmap,
                if_true: Self::Scalar<'_>,
                if_false: Self::Scalar<'_>,
            ) -> Self {
                let values = if_then_else_loop_broadcast_both(
                    mask,
                    if_true,
                    if_false,
                    // Auto-generated SIMD was slower on ARM.
                    #[cfg(target_arch = "x86_64")]
                    if_then_else_broadcast_both_simd_64,
                    #[cfg(not(target_arch = "x86_64"))]
                    scalar::if_then_else_broadcast_both_scalar_64,
                );
                PrimitiveArray::from_vec(values)
            }
        }
    };
}

// Kernels for every SIMD-capable primitive type; 128-bit integers are covered
// by the `NotSimdPrimitive` scalar blanket impl instead.
impl_if_then_else!(i8);
impl_if_then_else!(i16);
impl_if_then_else!(i32);
impl_if_then_else!(i64);
impl_if_then_else!(u8);
impl_if_then_else!(u16);
impl_if_then_else!(u32);
impl_if_then_else!(u64);
impl_if_then_else!(f32);
impl_if_then_else!(f64);
30 changes: 22 additions & 8 deletions crates/polars-compute/src/if_then_else/view.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,7 @@ use arrow::buffer::Buffer;
use arrow::datatypes::ArrowDataType;

use super::IfThenElseKernel;
use crate::if_then_else::scalar::{
if_then_else_broadcast_both_scalar_64, if_then_else_broadcast_false_scalar_64,
};
use crate::if_then_else::scalar::if_then_else_broadcast_both_scalar_64;

// Makes a buffer and a set of views into that buffer from a set of strings.
// Does not allocate a buffer if not necessary.
Expand Down Expand Up @@ -87,7 +85,7 @@ impl IfThenElseKernel for BinaryViewArray {
mask,
if_false.views(),
true_view,
if_then_else_broadcast_false_scalar_64,
if_then_else_broadcast_false_view_64,
);

let validity = super::if_then_else_validity(mask, None, if_false.validity());
Expand Down Expand Up @@ -120,7 +118,7 @@ impl IfThenElseKernel for BinaryViewArray {
mask,
if_true.views(),
false_view,
if_then_else_broadcast_false_scalar_64,
if_then_else_broadcast_false_view_64,
);

let validity = super::if_then_else_validity(mask, if_true.validity(), None);
Expand Down Expand Up @@ -215,13 +213,14 @@ pub fn if_then_else_view_rest(
false_buffer_idx_offset: u32,
) {
assert!(if_true.len() <= out.len()); // Removes bounds checks in inner loop.
let true_it = if_true.iter().copied();
let false_it = if_false.iter().copied();
let true_it = if_true.iter();
let false_it = if_false.iter();
for (i, (t, f)) in true_it.zip(false_it).enumerate() {
// Written like this, this loop *should* be branchless.
// Unfortunately we're still dependent on the compiler.
let m = (mask >> i) & 1 != 0;
let mut v = if m { t } else { f };
let src = if m { t } else { f };
let mut v = *src;
let offset = if m | (v.length <= 12) {
// Yes, | instead of || is intentional.
0
Expand All @@ -242,3 +241,18 @@ pub fn if_then_else_view_64(
) {
if_then_else_view_rest(mask, if_true, if_false, out, false_buffer_idx_offset)
}

// Using the scalar variant of this works, but was slower, we want to select a source pointer and
// then copy it. Using this version for the integers results in branches.
//
// Writes `if_true[i]` to `out[i]` where bit i of `mask` is set, and the
// broadcast scalar `if_false` otherwise. The select-a-reference-then-copy
// shape is intended to let the compiler avoid a per-element branch for
// `View` elements (TODO confirm codegen actually stays branchless).
pub fn if_then_else_broadcast_false_view_64(
    mask: u64,
    if_true: &[View; 64],
    if_false: View,
    out: &mut [MaybeUninit<View>; 64],
) {
    assert!(if_true.len() == out.len()); // Removes bounds checks in inner loop.
    for (i, t) in if_true.iter().enumerate() {
        // Pick a source reference first, then do one unconditional copy.
        let src = if (mask >> i) & 1 != 0 { t } else { &if_false };
        out[i] = MaybeUninit::new(*src);
    }
}
13 changes: 13 additions & 0 deletions crates/polars-compute/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,23 @@
feature(stdarch_x86_avx512)
)]

use arrow::types::NativeType;

pub mod arithmetic;
pub mod comparisons;
pub mod filter;
pub mod if_then_else;
pub mod min_max;

pub mod arity;

// Trait to enable the scalar blanket implementation.
//
// Marker for `NativeType`s that do NOT get dedicated SIMD kernels and so use
// the generic scalar implementations (e.g. the `TotalOrdKernel` and
// `IfThenElseKernel` blanket impls bounded on this trait).
pub trait NotSimdPrimitive: NativeType {}

// Without the `simd` feature every native type takes the scalar path.
#[cfg(not(feature = "simd"))]
impl<T: NativeType> NotSimdPrimitive for T {}

// With the `simd` feature, only the 128-bit integers remain scalar-only
// (presumably because they lack portable-SIMD lane support — TODO confirm).
#[cfg(feature = "simd")]
impl NotSimdPrimitive for u128 {}
#[cfg(feature = "simd")]
impl NotSimdPrimitive for i128 {}
24 changes: 24 additions & 0 deletions crates/polars-utils/src/cpuid.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,3 +37,27 @@ pub fn has_fast_bmi2() -> bool {

false
}

/// Reports whether AVX-512 kernels should be used.
///
/// Returns true only when the CPU advertises `avx512f` AND the user has not
/// opted out by setting the `POLARS_DISABLE_AVX512=1` environment variable.
/// The result is computed once and cached; always false off x86_64.
#[inline]
pub fn is_avx512_enabled() -> bool {
    #[cfg(target_arch = "x86_64")]
    {
        static CACHE: OnceLock<bool> = OnceLock::new();
        return *CACHE.get_or_init(|| {
            let detected = std::arch::is_x86_feature_detected!("avx512f");
            // Escape hatch: POLARS_DISABLE_AVX512=1 force-disables AVX-512.
            let disabled_by_env =
                matches!(std::env::var("POLARS_DISABLE_AVX512").as_deref(), Ok("1"));
            detected && !disabled_by_env
        });
    }

    false
}
Loading