Skip to content

Commit

Permalink
refactor: no != for nan checks anymore
Browse files Browse the repository at this point in the history
  • Loading branch information
jvdd committed Mar 4, 2023
1 parent 95cdc9c commit db437f5
Show file tree
Hide file tree
Showing 4 changed files with 53 additions and 27 deletions.
2 changes: 1 addition & 1 deletion src/scalar/mod.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
mod generic;
pub use generic::{ScalarArgMinMax, SCALAR};
// Data type specific modules}
// Data type specific modules
mod scalar_f16;
28 changes: 25 additions & 3 deletions src/simd/generic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -162,9 +162,19 @@ where
unsafe { *arr.get_unchecked(0) }
}

// --------------- Checks

/// Return-case hook: reports whether `_v` should trigger the algorithm's return case.
///
/// This default implementation ignores its argument and always answers `false`.
#[inline(always)]
fn _return_check(_v: ScalarDType) -> bool {
false
}

/// Reports whether the value is NaN.
///
/// This default implementation ignores its argument and always answers `false`
/// (presumably intended for data types that cannot hold NaN - confirm against the
/// implementors of this trait).
#[inline(always)]
fn _nan_check(_v: ScalarDType) -> bool {
false
}
}

// --------------- Int (signed and unsigned)
Expand All @@ -190,10 +200,16 @@ macro_rules! impl_SIMDInit_FloatReturnNaN {
{
// Use all initialization methods from the default implementation

/// Return case for the NaN-returning float variant: answers `true` exactly when
/// `v` is NaN, so that a NaN triggers the return.
#[inline(always)]
fn _return_check(v: $scalar_dtype) -> bool {
v.is_nan()
}

/// NaN check: answers `true` exactly when `v` is NaN.
#[inline(always)]
fn _nan_check(v: $scalar_dtype) -> bool {
v.is_nan()
}
}
};
}
Expand Down Expand Up @@ -244,6 +260,11 @@ macro_rules! impl_SIMDInit_FloatIgnoreNaN {
/// Returns `NEG_INFINITY` of the scalar type; the slice argument is ignored.
/// Presumably this seeds the running maximum - confirm at the use site.
fn _initialize_max_value(_: &[$scalar_dtype]) -> $scalar_dtype {
<$scalar_dtype>::NEG_INFINITY
}

/// NaN check: answers `true` exactly when `v` is NaN.
#[inline(always)]
fn _nan_check(v: $scalar_dtype) -> bool {
v.is_nan()
}
}
)*
};
Expand Down Expand Up @@ -442,9 +463,10 @@ where
argminmax_generic(
data,
LANE_SIZE,
Self::_overflow_safe_core_argminmax,
Self::IGNORE_NAN, // TODO: can perhaps cleaner - idk
SCALAR::argminmax,
Self::_overflow_safe_core_argminmax, // SIMD operation
SCALAR::argminmax, // Scalar operation
Self::_nan_check, // NaN check - true if value is NaN
Self::IGNORE_NAN, // Ignore NaNs - if false -> return NaN
)
}
}
Expand Down
2 changes: 1 addition & 1 deletion src/simd/simd_f64_ignore_nan.rs
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,7 @@ mod sse_ignore_nan {
impl SIMDArgMinMax<f64, __m128d, __m128d, LANE_SIZE, SCALAR<FloatIgnoreNaN>>
for SSE<FloatIgnoreNaN>
{
#[target_feature(enable = "sse4.1")] // TODO: check if this is correct
/// Thin entry point that forwards `data` to `Self::_argminmax` and returns its
/// result unchanged (presumably the `(argmin, argmax)` index pair - confirm
/// against the trait definition).
///
/// # Safety
/// Compiled with the `sse4.1` target feature enabled, so the caller must make
/// sure the executing CPU supports SSE4.1 before calling this function.
#[target_feature(enable = "sse4.1")]
unsafe fn argminmax(data: &[f64]) -> (usize, usize) {
Self::_argminmax(data)
}
Expand Down
48 changes: 26 additions & 22 deletions src/simd/task.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,9 @@ pub(crate) fn argminmax_generic<T: Copy + PartialOrd>(
arr: &[T],
lane_size: usize,
core_argminmax: unsafe fn(&[T]) -> (usize, T, usize, T),
ignore_nan: bool, // if false, NaNs will be returned
scalar_argminmax: fn(&[T]) -> (usize, usize),
nan_check: fn(T) -> bool, // returns true if value is NaN
ignore_nan: bool, // if false, NaNs will be returned
) -> (usize, usize) {
assert!(!arr.is_empty()); // split_array should never return (None, None)
match split_array(arr, lane_size) {
Expand All @@ -26,18 +27,24 @@ pub(crate) fn argminmax_generic<T: Copy + PartialOrd>(
let (min_index, min_value) = find_final_index_min(
(simd_result.0, simd_result.1),
(rem_result.0, rem_result.1),
nan_check,
ignore_nan,
);
let (max_index, max_value) = find_final_index_max(
(simd_result.2, simd_result.3),
(rem_result.2, rem_result.3),
nan_check,
ignore_nan,
);
get_correct_argminmax_result(min_index, min_value, max_index, max_value, ignore_nan)
get_correct_argminmax_result(
min_index, min_value, max_index, max_value, nan_check, ignore_nan,
)
}
(Some(simd_arr), None) => {
let (min_index, min_value, max_index, max_value) = unsafe { core_argminmax(simd_arr) };
get_correct_argminmax_result(min_index, min_value, max_index, max_value, ignore_nan)
get_correct_argminmax_result(
min_index, min_value, max_index, max_value, nan_check, ignore_nan,
)
}
(None, Some(rem)) => {
let (rem_min_index, rem_max_index) = scalar_argminmax(rem);
Expand Down Expand Up @@ -80,6 +87,7 @@ fn split_array<T: Copy>(arr: &[T], lane_size: usize) -> (Option<&[T]>, Option<&[
fn find_final_index_min<T: Copy + PartialOrd>(
simd_result: (usize, T),
remainder_result: (usize, T),
nan_check: fn(T) -> bool,
ignore_nan: bool,
) -> (usize, T) {
let (min_index, min_value) = match simd_result.1.partial_cmp(&remainder_result.1) {
Expand All @@ -89,21 +97,18 @@ fn find_final_index_min<T: Copy + PartialOrd>(
None => {
if !ignore_nan {
// --- Return NaNs
// Should prefer the simd result over the remainder result if both are
// NaN
if simd_result.1 != simd_result.1 {
// because NaN != NaN
// Should prefer simd result over remainder result if both are NaN
if nan_check(simd_result.1) {
simd_result
} else {
remainder_result
}
} else {
// --- Ignore NaNs
// If both are NaN raise panic, otherwise return the index of the
// non-NaN value
if simd_result.1 != simd_result.1 && remainder_result.1 != remainder_result.1 {
// If both are NaN raise panic, else return index of the non-NaN value
if nan_check(simd_result.1) && nan_check(remainder_result.1) {
panic!("Data contains only NaNs (or +/- inf)")
} else if remainder_result.1 != remainder_result.1 {
} else if nan_check(remainder_result.1) {
simd_result
} else {
remainder_result
Expand All @@ -128,6 +133,7 @@ fn find_final_index_min<T: Copy + PartialOrd>(
fn find_final_index_max<T: Copy + PartialOrd>(
simd_result: (usize, T),
remainder_result: (usize, T),
nan_check: fn(T) -> bool,
ignore_nan: bool,
) -> (usize, T) {
let (max_index, max_value) = match simd_result.1.partial_cmp(&remainder_result.1) {
Expand All @@ -137,21 +143,18 @@ fn find_final_index_max<T: Copy + PartialOrd>(
None => {
if !ignore_nan {
// --- Return NaNs
// Should prefer the simd result over the remainder result if both are
// NaN
if simd_result.1 != simd_result.1 {
// because NaN != NaN
// Should prefer simd result over remainder result if both are NaN
if nan_check(simd_result.1) {
simd_result
} else {
remainder_result
}
} else {
// --- Ignore NaNs
// If both are NaN raise panic, otherwise return the index of the
// non-NaN value
if simd_result.1 != simd_result.1 && remainder_result.1 != remainder_result.1 {
// If both are NaN raise panic, else return index of the non-NaN value
if nan_check(simd_result.1) && nan_check(remainder_result.1) {
panic!("Data contains only NaNs (or +/- inf)")
} else if remainder_result.1 != remainder_result.1 {
} else if nan_check(remainder_result.1) {
simd_result
} else {
remainder_result
Expand All @@ -173,16 +176,17 @@ fn get_correct_argminmax_result<T: Copy + PartialOrd>(
min_value: T,
max_index: usize,
max_value: T,
nan_check: fn(T) -> bool,
ignore_nan: bool,
) -> (usize, usize) {
if !ignore_nan && (min_value != min_value || max_value != max_value) {
if !ignore_nan && (nan_check(min_value) || nan_check(max_value)) {
// --- Return NaNs
// -> at least one of the values is NaN
if min_value != min_value && max_value != max_value {
if nan_check(min_value) && nan_check(max_value) {
// If both are NaN, return lowest index
let lowest_index = std::cmp::min(min_index, max_index);
return (lowest_index, lowest_index);
} else if min_value != min_value {
} else if nan_check(min_value) {
// If min is the only NaN, return min index
return (min_index, min_index);
} else {
Expand Down

0 comments on commit db437f5

Please sign in to comment.