Skip to content

Commit

Permalink
enh: remove simd_nightly feature flag for aarch64 (#63)
Browse files Browse the repository at this point in the history
* remove simd_nightly feature flag for aarch64 for i16

* add aarch64 target_arch to imports

* add aarch64 target_arch to imports in generic.rs

* remove nightly_simd flag for impl_SIMDArgMinMax macro

* remove nightly_simd feature flag for aarch64 for i8 and i32

* wrap up ints and uints

* add f32 support

* fix arm

* 🙈

* update other float types

* restrict cfg to arm and nightly_simd for floats

* restrict imports to arm & nightly_simd

* ♻️

* build aarch64 as stable

* use scalar implementation for 64 bit integers on aarch64

* 🧹 review code

* 🧹 remove fixed TODO

* 🧘 review code
  • Loading branch information
jvdd committed Mar 3, 2024
1 parent 3a9e74d commit 64f4b0b
Show file tree
Hide file tree
Showing 17 changed files with 562 additions and 288 deletions.
8 changes: 4 additions & 4 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,11 @@ exclude = [".git*", "dev_utils/**/*", "tests/**/*"]


[dependencies]
num-traits = { version = "0.2.17", default-features = false }
num-traits = { version = "0.2", default-features = false }
half = { version = "2.3.1", default-features = false, features=["num-traits"], optional = true }
ndarray = { version = "0.15.6", default-features = false, optional = true}
arrow = { version = ">0", default-features = false, optional = true}
arrow2 = { version = ">0.0", default-features = false, optional = true}
ndarray = { version = "0.15.6", default-features = false, optional = true }
arrow = { version = ">0", default-features = false, optional = true }
arrow2 = { version = ">0.0", default-features = false, optional = true }
# once_cell = "1.16.0"

[features]
Expand Down
100 changes: 18 additions & 82 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,6 @@
),
cfg_attr(version("1.78"), feature(stdarch_x86_avx512))
)]
// TODO: Aarch64 is stable now - check if this is under nightly_simd https://github.com/rust-lang/rust/issues/111800
#![cfg_attr(
all(feature = "nightly_simd", target_arch = "arm"),
cfg_attr(
Expand Down Expand Up @@ -118,8 +117,8 @@ pub(crate) use simd::AVX512;
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
pub(crate) use simd::{SIMDArgMinMax, AVX2, SSE};
#[cfg(any(
all(target_arch = "aarch64", feature = "float"), // is stable for f64
all(any(target_arch = "arm", target_arch = "aarch64"), feature = "nightly_simd"),
all(target_arch = "arm", feature = "nightly_simd"),
target_arch = "aarch64"
))]
pub(crate) use simd::{SIMDArgMinMax, NEON};

Expand Down Expand Up @@ -304,7 +303,7 @@ macro_rules! impl_argminmax_int {
return unsafe { SSE::<Int>::argminmax(self) }
}
}
#[cfg(all(target_arch = "aarch64", feature = "nightly_simd"))]
#[cfg(target_arch = "aarch64")]
{
if std::arch::is_aarch64_feature_detected!("neon") & (<$int_type>::NB_BITS < 64) {
// Scalar is faster for 64-bit numbers
Expand Down Expand Up @@ -349,7 +348,7 @@ macro_rules! impl_argminmax_int {
return unsafe { SSE::<Int>::argmin(self) }
}
}
#[cfg(all(target_arch = "aarch64", feature = "nightly_simd"))]
#[cfg(target_arch = "aarch64")]
{
if std::arch::is_aarch64_feature_detected!("neon") {
return unsafe { NEON::<Int>::argmin(self) }
Expand Down Expand Up @@ -392,7 +391,7 @@ macro_rules! impl_argminmax_int {
return unsafe { SSE::<Int>::argmax(self) }
}
}
#[cfg(all(target_arch = "aarch64", feature = "nightly_simd"))]
#[cfg(target_arch = "aarch64")]
{
if std::arch::is_aarch64_feature_detected!("neon") {
return unsafe { NEON::<Int>::argmax(self) }
Expand Down Expand Up @@ -443,10 +442,9 @@ macro_rules! impl_argminmax_float {
return unsafe { SSE::<FloatIgnoreNaN>::argminmax(self) }
}
}
#[cfg(all(target_arch = "aarch64", feature = "nightly_simd"))]
#[cfg(target_arch = "aarch64")]
{
if std::arch::is_aarch64_feature_detected!("neon") & (<$float_type>::NB_BITS < 64) {
// NEON f64 is part of stable Rust (see code below this macro)
if std::arch::is_aarch64_feature_detected!("neon") {
return unsafe { NEON::<FloatIgnoreNaN>::argminmax(self) }
}
}
Expand Down Expand Up @@ -483,10 +481,9 @@ macro_rules! impl_argminmax_float {
return unsafe { SSE::<FloatIgnoreNaN>::argmin(self) }
}
}
#[cfg(all(target_arch = "aarch64", feature = "nightly_simd"))]
#[cfg(target_arch = "aarch64")]
{
if std::arch::is_aarch64_feature_detected!("neon") & (<$float_type>::NB_BITS < 64) {
// NEON f64 is part of stable Rust (see code below this macro)
if std::arch::is_aarch64_feature_detected!("neon") {
return unsafe { NEON::<FloatIgnoreNaN>::argmin(self) }
}
}
Expand Down Expand Up @@ -523,10 +520,9 @@ macro_rules! impl_argminmax_float {
return unsafe { SSE::<FloatIgnoreNaN>::argmax(self) }
}
}
#[cfg(all(target_arch = "aarch64", feature = "nightly_simd"))]
#[cfg(target_arch = "aarch64")]
{
if std::arch::is_aarch64_feature_detected!("neon") & (<$float_type>::NB_BITS < 64) {
// NEON f64 is part of stable Rust (see code below this macro)
if std::arch::is_aarch64_feature_detected!("neon") {
return unsafe { NEON::<FloatIgnoreNaN>::argmax(self) }
}
}
Expand Down Expand Up @@ -563,10 +559,9 @@ macro_rules! impl_argminmax_float {
return unsafe { SSE::<FloatReturnNaN>::argminmax(self) }
}
}
#[cfg(all(target_arch = "aarch64", feature = "nightly_simd"))]
#[cfg(target_arch = "aarch64")]
{
if std::arch::is_aarch64_feature_detected!("neon") & (<$float_type>::NB_BITS < 64) {
// We miss some NEON instructions for 64-bit numbers
if std::arch::is_aarch64_feature_detected!("neon") {
return unsafe { NEON::<FloatReturnNaN>::argminmax(self) }
}
}
Expand Down Expand Up @@ -601,10 +596,9 @@ macro_rules! impl_argminmax_float {
return unsafe { SSE::<FloatReturnNaN>::argmin(self) }
}
}
#[cfg(all(target_arch = "aarch64", feature = "nightly_simd"))]
#[cfg(target_arch = "aarch64")]
{
if std::arch::is_aarch64_feature_detected!("neon") & (<$float_type>::NB_BITS < 64) {
// We miss some NEON instructions for 64-bit numbers
if std::arch::is_aarch64_feature_detected!("neon") {
return unsafe { NEON::<FloatReturnNaN>::argmin(self) }
}
}
Expand Down Expand Up @@ -639,10 +633,9 @@ macro_rules! impl_argminmax_float {
return unsafe { SSE::<FloatReturnNaN>::argmax(self) }
}
}
#[cfg(all(target_arch = "aarch64", feature = "nightly_simd"))]
#[cfg(target_arch = "aarch64")]
{
if std::arch::is_aarch64_feature_detected!("neon") & (<$float_type>::NB_BITS < 64) {
// We miss some NEON instructions for 64-bit numbers
if std::arch::is_aarch64_feature_detected!("neon") {
return unsafe { NEON::<FloatReturnNaN>::argmax(self) }
}
}
Expand All @@ -660,68 +653,11 @@ macro_rules! impl_argminmax_float {
};
}

/// Implement ArgMinMax for &[f64] on aarch64 as NEON intrinsics for f64
/// are part of stable Rust on aarch64.
// Note: implementing this in a distinct impl block seemed more clean than
// hacking with unimpl_ macros in the simd_.rs files to avoid complaints
// from the compiler..
#[cfg(all(feature = "float", target_arch = "aarch64"))]
impl ArgMinMax for &[f64] {
fn argminmax(&self) -> (usize, usize) {
unsafe { NEON::<FloatIgnoreNaN>::argminmax(self) }
}
fn argmin(&self) -> usize {
unsafe { NEON::<FloatIgnoreNaN>::argmin(self) }
}
fn argmax(&self) -> usize {
unsafe { NEON::<FloatIgnoreNaN>::argmax(self) }
}
}

/// Implement NaNArgMinMax for &[f64] on aarch64 - the required intrinsics
/// for return nan are not part of stable Rust.
// Note: implementing this in a distinct impl block seemed more clean than
// hacking with unimpl_ macros in the simd_.rs files to avoid complaints
// from the compiler..
#[cfg(all(feature = "float", target_arch = "aarch64"))]
impl NaNArgMinMax for &[f64] {
fn nanargminmax(&self) -> (usize, usize) {
#[cfg(feature = "nightly_simd")]
{
if std::arch::is_aarch64_feature_detected!("neon") {
return unsafe { NEON::<FloatReturnNaN>::argminmax(self) };
}
}
SCALAR::<FloatReturnNaN>::argminmax(self)
}
fn nanargmin(&self) -> usize {
#[cfg(feature = "nightly_simd")]
{
if std::arch::is_aarch64_feature_detected!("neon") {
return unsafe { NEON::<FloatReturnNaN>::argmin(self) };
}
}
SCALAR::<FloatReturnNaN>::argmin(self)
}
fn nanargmax(&self) -> usize {
#[cfg(feature = "nightly_simd")]
{
if std::arch::is_aarch64_feature_detected!("neon") {
return unsafe { NEON::<FloatReturnNaN>::argmax(self) };
}
}
SCALAR::<FloatReturnNaN>::argmax(self)
}
}

// Implement ArgMinMax for (non-optional) integer rust primitive types
impl_argminmax_int!(i8, i16, i32, i64, u8, u16, u32, u64);
// Implement for (optional) float rust primitive types
#[cfg(all(feature = "float", not(target_arch = "aarch64")))]
#[cfg(feature = "float")]
impl_argminmax_float!(f32, f64);
// For aarch64 f64 is implemented in the two impl blocks above
#[cfg(all(feature = "float", target_arch = "aarch64"))]
impl_argminmax_float!(f32);

// Implement ArgMinMax for other data types
#[cfg(feature = "half")]
Expand Down
48 changes: 34 additions & 14 deletions src/simd/generic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -111,9 +111,9 @@ where
/// - `impl_SIMDInit_Int!`
/// - called in the `simd_i*.rs` files
/// - called in the `simd_u*.rs` files
/// - `impl_SIMDInit_FloatIgnoreNaN!`
/// - see the `simd_f*_return_nan.rs` files
/// - `impl_SIMDInit_FloatReturnNaN!`
/// - see the `simd_f*_return_nan.rs` files
/// - `impl_SIMDInit_FloatIgnoreNaN!`
/// - see the `simd_f*_ignore_nan.rs` files
///
/// The current (default) implementation is for the Int case - see `impl_SIMDInit_Int!`
Expand Down Expand Up @@ -181,7 +181,12 @@ where

// --------------- Int (signed and unsigned)

#[cfg(any(target_arch = "x86", target_arch = "x86_64", feature = "nightly_simd"))]
#[cfg(any(
target_arch = "x86",
target_arch = "x86_64",
all(target_arch = "arm", feature = "nightly_simd"),
target_arch = "aarch64",
))]
macro_rules! impl_SIMDInit_Int {
($scalar_dtype:ty, $simd_vec_dtype:ty, $simd_mask_dtype:ty, $lane_size:expr, $simd_struct:ty) => {
impl SIMDInit<$scalar_dtype, $simd_vec_dtype, $simd_mask_dtype, $lane_size>
Expand All @@ -192,13 +197,23 @@ macro_rules! impl_SIMDInit_Int {
};
}

#[cfg(any(target_arch = "x86", target_arch = "x86_64", feature = "nightly_simd"))]
#[cfg(any(
target_arch = "x86",
target_arch = "x86_64",
all(target_arch = "arm", feature = "nightly_simd"),
target_arch = "aarch64",
))]
pub(crate) use impl_SIMDInit_Int; // Now classic paths Just Work™

// --------------- Float Return NaNs

#[cfg(any(feature = "float", feature = "half"))]
#[cfg(any(target_arch = "x86", target_arch = "x86_64", feature = "nightly_simd"))]
#[cfg(any(
target_arch = "x86",
target_arch = "x86_64",
all(target_arch = "arm", feature = "nightly_simd"),
target_arch = "aarch64",
))]
macro_rules! impl_SIMDInit_FloatReturnNaN {
($scalar_dtype:ty, $simd_vec_dtype:ty, $simd_mask_dtype:ty, $lane_size:expr, $simd_struct:ty) => {
impl SIMDInit<$scalar_dtype, $simd_vec_dtype, $simd_mask_dtype, $lane_size>
Expand All @@ -221,7 +236,12 @@ macro_rules! impl_SIMDInit_FloatReturnNaN {
}

#[cfg(any(feature = "float", feature = "half"))]
#[cfg(any(target_arch = "x86", target_arch = "x86_64", feature = "nightly_simd"))]
#[cfg(any(
target_arch = "x86",
target_arch = "x86_64",
all(target_arch = "arm", feature = "nightly_simd"),
target_arch = "aarch64",
))]
pub(crate) use impl_SIMDInit_FloatReturnNaN; // Now classic paths Just Work™

// --------------- Float Ignore NaNs
Expand All @@ -230,8 +250,8 @@ pub(crate) use impl_SIMDInit_FloatReturnNaN; // Now classic paths Just Work™
#[cfg(any(
target_arch = "x86",
target_arch = "x86_64",
all(target_arch = "aarch64", feature = "float"), // is stable for f64
feature = "nightly_simd"
all(target_arch = "arm", feature = "nightly_simd"),
target_arch = "aarch64",
))]
macro_rules! impl_SIMDInit_FloatIgnoreNaN {
($scalar_dtype:ty, $simd_vec_dtype:ty, $simd_mask_dtype:ty, $lane_size:expr, $simd_struct:ty) => {
Expand Down Expand Up @@ -306,8 +326,8 @@ macro_rules! impl_SIMDInit_FloatIgnoreNaN {
#[cfg(any(
target_arch = "x86",
target_arch = "x86_64",
all(target_arch = "aarch64", feature = "float"), // is stable for f64
feature = "nightly_simd"
all(target_arch = "arm", feature = "nightly_simd"),
target_arch = "aarch64",
))]
pub(crate) use impl_SIMDInit_FloatIgnoreNaN; // Now classic paths Just Work™

Expand Down Expand Up @@ -768,8 +788,8 @@ where
#[cfg(any(
target_arch = "x86",
target_arch = "x86_64",
all(target_arch = "aarch64", feature = "float"), // is stable for f64
feature = "nightly_simd"
all(target_arch = "arm", feature = "nightly_simd"),
target_arch = "aarch64",
))]
macro_rules! impl_SIMDArgMinMax {
($scalar_dtype:ty, $simd_vec_dtype:ty, $simd_mask_dtype:ty, $lane_size:expr, $scalar_struct:ty, $simd_struct:ty, $target:expr) => {
Expand Down Expand Up @@ -806,8 +826,8 @@ macro_rules! impl_SIMDArgMinMax {
#[cfg(any(
target_arch = "x86",
target_arch = "x86_64",
all(target_arch = "aarch64", feature = "float"), // is stable for f64
feature = "nightly_simd"
all(target_arch = "arm", feature = "nightly_simd"),
target_arch = "aarch64",
))]
pub(crate) use impl_SIMDArgMinMax; // Now classic paths Just Work™

Expand Down
Loading

0 comments on commit 64f4b0b

Please sign in to comment.