From fea585dfea22d032823aed3e03d69b0712e35fcc Mon Sep 17 00:00:00 2001 From: "Chai T. Rex" Date: Wed, 24 Jul 2024 20:00:45 -0400 Subject: [PATCH] Greatly sped up `checked_isqrt` and `isqrt` methods * Uses a lookup table for 8-bit integers and then the Karatsuba square root algorithm for larger integers. * Includes optimization hints that give the compiler the exact numeric range of results. --- library/core/src/num/int_macros.rs | 36 ++++-- library/core/src/num/int_sqrt.rs | 190 ++++++++++++++++++++++++++++ library/core/src/num/mod.rs | 1 + library/core/src/num/nonzero.rs | 34 ++--- library/core/src/num/uint_macros.rs | 11 +- 5 files changed, 237 insertions(+), 35 deletions(-) create mode 100644 library/core/src/num/int_sqrt.rs diff --git a/library/core/src/num/int_macros.rs b/library/core/src/num/int_macros.rs index d40e02352a1d0..1b3dcf2287fc8 100644 --- a/library/core/src/num/int_macros.rs +++ b/library/core/src/num/int_macros.rs @@ -1580,7 +1580,18 @@ macro_rules! int_impl { if self < 0 { None } else { - Some((self as $UnsignedT).isqrt() as Self) + let result = crate::num::int_sqrt::$ActualT(self as $ActualT) as $SelfT; + + // SAFETY: Inform the optimizer that square roots of + // nonnegative integers are nonnegative and what the maximum + // result is. + unsafe { + crate::hint::assert_unchecked(result >= 0); + const MAX_RESULT: $SelfT = crate::num::int_sqrt::$ActualT($ActualT::MAX) as $SelfT; + crate::hint::assert_unchecked(result <= MAX_RESULT); + } + + Some(result) } } @@ -2766,14 +2777,21 @@ macro_rules! int_impl { without modifying the original"] #[inline] pub const fn isqrt(self) -> Self { - // I would like to implement it as - // ``` - // self.checked_isqrt().expect("argument of integer square root must be non-negative") - // ``` - // but `expect` is not yet stable as a `const fn`. 
- match self.checked_isqrt() { - Some(sqrt) => sqrt, - None => panic!("argument of integer square root must be non-negative"), + if self < 0 { + crate::num::int_sqrt::panic_for_negative_argument(); + } else { + let result = crate::num::int_sqrt::$ActualT(self as $ActualT) as $SelfT; + + // SAFETY: Inform the optimizer that square roots of + // nonnegative integers are nonnegative and what the maximum + // result is. + unsafe { + crate::hint::assert_unchecked(result >= 0); + const MAX_RESULT: $SelfT = crate::num::int_sqrt::$ActualT($ActualT::MAX) as $SelfT; + crate::hint::assert_unchecked(result <= MAX_RESULT); + } + + result } } diff --git a/library/core/src/num/int_sqrt.rs b/library/core/src/num/int_sqrt.rs new file mode 100644 index 0000000000000..e7320d7fbb35d --- /dev/null +++ b/library/core/src/num/int_sqrt.rs @@ -0,0 +1,190 @@ +//! These functions compute the integer square root of their type, assuming +//! that someone has already checked that the value is nonnegative. + +const ISQRT_AND_REMAINDER_8_BIT: [(u8, u8); 256] = { + let mut result = [(0, 0); 256]; + + let mut sqrt = 0; + let mut i = 0; + 'outer: loop { + let mut remaining = 2 * sqrt + 1; + while remaining > 0 { + result[i as usize] = (sqrt, 2 * sqrt + 1 - remaining); + i += 1; + if i >= result.len() { + break 'outer; + } + remaining -= 1; + } + sqrt += 1; + } + + result +}; + +// `#[inline(always)]` because the programmer-accessible functions will use +// this internally and the contents of this should be inlined there. +#[inline(always)] +pub const fn u8(n: u8) -> u8 { + ISQRT_AND_REMAINDER_8_BIT[n as usize].0 +} + +#[inline(always)] +const fn intermediate_u8(n: u8) -> (u8, u8) { + ISQRT_AND_REMAINDER_8_BIT[n as usize] +} + +macro_rules! 
karatsuba_isqrt { + ($FullBitsT:ty, $fn:ident, $intermediate_fn:ident, $HalfBitsT:ty, $half_fn:ident, $intermediate_half_fn:ident) => { + // `#[inline(always)]` because the programmer-accessible functions will + // use this internally and the contents of this should be inlined + // there. + #[inline(always)] + pub const fn $fn(mut n: $FullBitsT) -> $FullBitsT { + // Performs a Karatsuba square root. + // https://web.archive.org/web/20230511212802/https://inria.hal.science/inria-00072854v1/file/RR-3805.pdf + + const HALF_BITS: u32 = <$FullBitsT>::BITS >> 1; + const QUARTER_BITS: u32 = <$FullBitsT>::BITS >> 2; + + let leading_zeros = n.leading_zeros(); + let result = if leading_zeros >= HALF_BITS { + $half_fn(n as $HalfBitsT) as $FullBitsT + } else { + // Either the most-significant bit or its neighbor must be a one, so we shift left to make that happen. + let precondition_shift = leading_zeros & (HALF_BITS - 2); + n <<= precondition_shift; + + let hi = (n >> HALF_BITS) as $HalfBitsT; + let lo = n & (<$HalfBitsT>::MAX as $FullBitsT); + + let (s_prime, r_prime) = $intermediate_half_fn(hi); + + let numerator = ((r_prime as $FullBitsT) << QUARTER_BITS) | (lo >> QUARTER_BITS); + let denominator = (s_prime as $FullBitsT) << 1; + + let q = numerator / denominator; + let u = numerator % denominator; + + let mut s = (s_prime << QUARTER_BITS) as $FullBitsT + q; + if ((u << QUARTER_BITS) | (lo & ((1 << QUARTER_BITS) - 1))) < q * q { + s -= 1; + } + s >> (precondition_shift >> 1) + }; + + result + } + + const fn $intermediate_fn(mut n: $FullBitsT) -> ($FullBitsT, $FullBitsT) { + // Performs a Karatsuba square root. 
+ // https://web.archive.org/web/20230511212802/https://inria.hal.science/inria-00072854v1/file/RR-3805.pdf + + const HALF_BITS: u32 = <$FullBitsT>::BITS >> 1; + const QUARTER_BITS: u32 = <$FullBitsT>::BITS >> 2; + + let leading_zeros = n.leading_zeros(); + let result = if leading_zeros >= HALF_BITS { + let (s, r) = $intermediate_half_fn(n as $HalfBitsT); + (s as $FullBitsT, r as $FullBitsT) + } else { + // Either the most-significant bit or its neighbor must be a one, so we shift left to make that happen. + let precondition_shift = leading_zeros & (HALF_BITS - 2); + n <<= precondition_shift; + + let hi = (n >> HALF_BITS) as $HalfBitsT; + let lo = n & (<$HalfBitsT>::MAX as $FullBitsT); + + let (s_prime, r_prime) = $intermediate_half_fn(hi); + + let numerator = ((r_prime as $FullBitsT) << QUARTER_BITS) | (lo >> QUARTER_BITS); + let denominator = (s_prime as $FullBitsT) << 1; + + let q = numerator / denominator; + let u = numerator % denominator; + + let mut s = (s_prime << QUARTER_BITS) as $FullBitsT + q; + let (mut r, overflow) = + ((u << QUARTER_BITS) | (lo & ((1 << QUARTER_BITS) - 1))).overflowing_sub(q * q); + if overflow { + r = r.wrapping_add((s << 1) - 1); + s -= 1; + } + (s >> (precondition_shift >> 1), r >> (precondition_shift >> 1)) + }; + + result + } + }; +} + +karatsuba_isqrt!(u16, u16, intermediate_u16, u8, u8, intermediate_u8); +karatsuba_isqrt!(u32, u32, intermediate_u32, u16, u16, intermediate_u16); +karatsuba_isqrt!(u64, u64, intermediate_u64, u32, u32, intermediate_u32); +karatsuba_isqrt!(u128, u128, _intermediate_u128, u64, u64, intermediate_u64); + +#[cfg(target_pointer_width = "16")] +#[inline(always)] +pub const fn usize(n: usize) -> usize { + u16(n as u16) as usize +} + +#[cfg(target_pointer_width = "32")] +#[inline(always)] +pub const fn usize(n: usize) -> usize { + u32(n as u32) as usize +} + +#[cfg(target_pointer_width = "64")] +#[inline(always)] +pub const fn usize(n: usize) -> usize { + u64(n as u64) as usize +} + +// 0 <= val <= i8::MAX 
+#[inline(always)] +pub const fn i8(n: i8) -> i8 { + u8(n as u8) as i8 +} + +// 0 <= val <= i16::MAX +#[inline(always)] +pub const fn i16(n: i16) -> i16 { + u16(n as u16) as i16 +} + +// 0 <= val <= i32::MAX +#[inline(always)] +pub const fn i32(n: i32) -> i32 { + u32(n as u32) as i32 +} + +// 0 <= val <= i64::MAX +#[inline(always)] +pub const fn i64(n: i64) -> i64 { + u64(n as u64) as i64 +} + +// 0 <= val <= i128::MAX +#[inline(always)] +pub const fn i128(n: i128) -> i128 { + u128(n as u128) as i128 +} + +/* +This function is not used. + +// 0 <= val <= isize::MAX +#[inline(always)] +pub const fn isize(n: isize) -> isize { + usize(n as usize) as isize +} +*/ + +/// Instantiate this panic logic once, rather than for all the isqrt methods +/// on every single primitive type. +#[cold] +#[track_caller] +pub const fn panic_for_negative_argument() -> ! { + panic!("argument of integer square root cannot be negative") +} diff --git a/library/core/src/num/mod.rs b/library/core/src/num/mod.rs index 4e8e0ecdde998..b059417357b80 100644 --- a/library/core/src/num/mod.rs +++ b/library/core/src/num/mod.rs @@ -43,6 +43,7 @@ mod uint_macros; // import uint_impl! mod error; mod int_log10; +mod int_sqrt; mod nonzero; mod overflow_panic; mod saturating; diff --git a/library/core/src/num/nonzero.rs b/library/core/src/num/nonzero.rs index d80d3241b1eee..6a0323136dc0e 100644 --- a/library/core/src/num/nonzero.rs +++ b/library/core/src/num/nonzero.rs @@ -1247,31 +1247,19 @@ macro_rules! nonzero_integer_signedness_dependent_methods { without modifying the original"] #[inline] pub const fn isqrt(self) -> Self { - // The algorithm is based on the one presented in - // - // which cites as source the following C code: - // . 
- - let mut op = self.get(); - let mut res = 0; - let mut one = 1 << (self.ilog2() & !1); - - while one != 0 { - if op >= res + one { - op -= res + one; - res = (res >> 1) + one; - } else { - res >>= 1; - } - one >>= 2; + let result = super::int_sqrt::$Int(self.get()); + + // SAFETY: Inform the optimizer that square roots of positive + // integers are positive and what the maximum result is. + unsafe { + hint::assert_unchecked(result > 0); + const MAX_RESULT: $Int = super::int_sqrt::$Int($Int::MAX); + hint::assert_unchecked(result <= MAX_RESULT); } - // SAFETY: The result fits in an integer with half as many bits. - // Inform the optimizer about it. - unsafe { hint::assert_unchecked(res < 1 << (Self::BITS / 2)) }; - - // SAFETY: The square root of an integer >= 1 is always >= 1. - unsafe { Self::new_unchecked(res) } + // SAFETY: The square root of a positive integer is always + // positive. + unsafe { Self::new_unchecked(result) } } }; diff --git a/library/core/src/num/uint_macros.rs b/library/core/src/num/uint_macros.rs index d50bcde01571c..756bac4813ab0 100644 --- a/library/core/src/num/uint_macros.rs +++ b/library/core/src/num/uint_macros.rs @@ -2588,10 +2588,15 @@ macro_rules! uint_impl { without modifying the original"] #[inline] pub const fn isqrt(self) -> Self { - match NonZero::new(self) { - Some(x) => x.isqrt().get(), - None => 0, + let result = crate::num::int_sqrt::$ActualT(self as $ActualT) as $SelfT; + + // SAFETY: Inform the optimizer of what the maximum result is. + unsafe { + const MAX_RESULT: $SelfT = crate::num::int_sqrt::$ActualT($ActualT::MAX) as $SelfT; + crate::hint::assert_unchecked(result <= MAX_RESULT); } + + result } /// Performs Euclidean division.