Skip to content

Commit

Permalink
Greatly sped up checked_isqrt and isqrt methods
Browse files Browse the repository at this point in the history
* Uses a lookup table for 8-bit integers and then the Karatsuba square root
  algorithm for larger integers.
* Includes optimization hints that give the compiler the exact numeric range
  of results.
  • Loading branch information
ChaiTRex committed Jul 25, 2024
1 parent d180572 commit fea585d
Show file tree
Hide file tree
Showing 5 changed files with 237 additions and 35 deletions.
36 changes: 27 additions & 9 deletions library/core/src/num/int_macros.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1580,7 +1580,18 @@ macro_rules! int_impl {
if self < 0 {
None
} else {
Some((self as $UnsignedT).isqrt() as Self)
let result = crate::num::int_sqrt::$ActualT(self as $ActualT) as $SelfT;

// SAFETY: Inform the optimizer that square roots of
// nonnegative integers are nonnegative and what the maximum
// result is.
unsafe {
crate::hint::assert_unchecked(result >= 0);
const MAX_RESULT: $SelfT = crate::num::int_sqrt::$ActualT($ActualT::MAX) as $SelfT;
crate::hint::assert_unchecked(result <= MAX_RESULT);
}

Some(result)
}
}

Expand Down Expand Up @@ -2766,14 +2777,21 @@ macro_rules! int_impl {
without modifying the original"]
#[inline]
pub const fn isqrt(self) -> Self {
// I would like to implement it as
// ```
// self.checked_isqrt().expect("argument of integer square root must be non-negative")
// ```
// but `expect` is not yet stable as a `const fn`.
match self.checked_isqrt() {
Some(sqrt) => sqrt,
None => panic!("argument of integer square root must be non-negative"),
if self < 0 {
crate::num::int_sqrt::panic_for_negative_argument();
} else {
let result = crate::num::int_sqrt::$ActualT(self as $ActualT) as $SelfT;

// SAFETY: Inform the optimizer that square roots of
// nonnegative integers are nonnegative and what the maximum
// result is.
unsafe {
crate::hint::assert_unchecked(result >= 0);
const MAX_RESULT: $SelfT = crate::num::int_sqrt::$ActualT($ActualT::MAX) as $SelfT;
crate::hint::assert_unchecked(result <= MAX_RESULT);
}

result
}
}

Expand Down
190 changes: 190 additions & 0 deletions library/core/src/num/int_sqrt.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,190 @@
/// These functions compute the integer square root of their type, assuming
/// that someone has already checked that the value is nonnegative.
const ISQRT_AND_REMAINDER_8_BIT: [(u8, u8); 256] = {
let mut result = [(0, 0); 256];

let mut sqrt = 0;
let mut i = 0;
'outer: loop {
let mut remaining = 2 * sqrt + 1;
while remaining > 0 {
result[i as usize] = (sqrt, 2 * sqrt + 1 - remaining);
i += 1;
if i >= result.len() {
break 'outer;
}
remaining -= 1;
}
sqrt += 1;
}

result
};

// `#[inline(always)]` because the programmer-accessible functions will use
// this internally and the contents of this should be inlined there.
#[inline(always)]
pub const fn u8(n: u8) -> u8 {
ISQRT_AND_REMAINDER_8_BIT[n as usize].0
}

#[inline(always)]
const fn intermediate_u8(n: u8) -> (u8, u8) {
ISQRT_AND_REMAINDER_8_BIT[n as usize]
}

macro_rules! karatsuba_isqrt {
($FullBitsT:ty, $fn:ident, $intermediate_fn:ident, $HalfBitsT:ty, $half_fn:ident, $intermediate_half_fn:ident) => {
// `#[inline(always)]` because the programmer-accessible functions will
// use this internally and the contents of this should be inlined
// there.
#[inline(always)]
pub const fn $fn(mut n: $FullBitsT) -> $FullBitsT {
// Performs a Karatsuba square root.
// https://web.archive.org/web/20230511212802/https://inria.hal.science/inria-00072854v1/file/RR-3805.pdf

const HALF_BITS: u32 = <$FullBitsT>::BITS >> 1;
const QUARTER_BITS: u32 = <$FullBitsT>::BITS >> 2;

let leading_zeros = n.leading_zeros();
let result = if leading_zeros >= HALF_BITS {
$half_fn(n as $HalfBitsT) as $FullBitsT
} else {
// Either the most-significant bit or its neighbor must be a one, so we shift left to make that happen.
let precondition_shift = leading_zeros & (HALF_BITS - 2);
n <<= precondition_shift;

let hi = (n >> HALF_BITS) as $HalfBitsT;
let lo = n & (<$HalfBitsT>::MAX as $FullBitsT);

let (s_prime, r_prime) = $intermediate_half_fn(hi);

let numerator = ((r_prime as $FullBitsT) << QUARTER_BITS) | (lo >> QUARTER_BITS);
let denominator = (s_prime as $FullBitsT) << 1;

let q = numerator / denominator;
let u = numerator % denominator;

let mut s = (s_prime << QUARTER_BITS) as $FullBitsT + q;
if ((u << QUARTER_BITS) | (lo & ((1 << QUARTER_BITS) - 1))) < q * q {
s -= 1;
}
s >> (precondition_shift >> 1)
};

result
}

const fn $intermediate_fn(mut n: $FullBitsT) -> ($FullBitsT, $FullBitsT) {
// Performs a Karatsuba square root.
// https://web.archive.org/web/20230511212802/https://inria.hal.science/inria-00072854v1/file/RR-3805.pdf

const HALF_BITS: u32 = <$FullBitsT>::BITS >> 1;
const QUARTER_BITS: u32 = <$FullBitsT>::BITS >> 2;

let leading_zeros = n.leading_zeros();
let result = if leading_zeros >= HALF_BITS {
let (s, r) = $intermediate_half_fn(n as $HalfBitsT);
(s as $FullBitsT, r as $FullBitsT)
} else {
// Either the most-significant bit or its neighbor must be a one, so we shift left to make that happen.
let precondition_shift = leading_zeros & (HALF_BITS - 2);
n <<= precondition_shift;

let hi = (n >> HALF_BITS) as $HalfBitsT;
let lo = n & (<$HalfBitsT>::MAX as $FullBitsT);

let (s_prime, r_prime) = $intermediate_half_fn(hi);

let numerator = ((r_prime as $FullBitsT) << QUARTER_BITS) | (lo >> QUARTER_BITS);
let denominator = (s_prime as $FullBitsT) << 1;

let q = numerator / denominator;
let u = numerator % denominator;

let mut s = (s_prime << QUARTER_BITS) as $FullBitsT + q;
let (mut r, overflow) =
((u << QUARTER_BITS) | (lo & ((1 << QUARTER_BITS) - 1))).overflowing_sub(q * q);
if overflow {
r = r.wrapping_add((s << 1) - 1);
s -= 1;
}
(s >> (precondition_shift >> 1), r >> (precondition_shift >> 1))
};

result
}
};
}

karatsuba_isqrt!(u16, u16, intermediate_u16, u8, u8, intermediate_u8);
karatsuba_isqrt!(u32, u32, intermediate_u32, u16, u16, intermediate_u16);
karatsuba_isqrt!(u64, u64, intermediate_u64, u32, u32, intermediate_u32);
karatsuba_isqrt!(u128, u128, _intermediate_u128, u64, u64, intermediate_u64);

#[cfg(target_pointer_width = "16")]
#[inline(always)]
pub const fn usize(n: usize) -> usize {
u16(n as u16) as usize
}

#[cfg(target_pointer_width = "32")]
#[inline(always)]
pub const fn usize(n: usize) -> usize {
u32(n as u32) as usize
}

#[cfg(target_pointer_width = "64")]
#[inline(always)]
pub const fn usize(n: usize) -> usize {
u64(n as u64) as usize
}

// 0 <= val <= i8::MAX
#[inline(always)]
pub const fn i8(n: i8) -> i8 {
u8(n as u8) as i8
}

// 0 <= val <= i16::MAX
#[inline(always)]
pub const fn i16(n: i16) -> i16 {
u16(n as u16) as i16
}

// 0 <= val <= i32::MAX
#[inline(always)]
pub const fn i32(n: i32) -> i32 {
u32(n as u32) as i32
}

// 0 <= val <= i64::MAX
#[inline(always)]
pub const fn i64(n: i64) -> i64 {
u64(n as u64) as i64
}

// 0 <= val <= i128::MAX
#[inline(always)]
pub const fn i128(n: i128) -> i128 {
u128(n as u128) as i128
}

/*
This function is not used.
// 0 <= val <= isize::MAX
#[inline(always)]
pub const fn isize(n: isize) -> isize {
usize(n as usize) as isize
}
*/

/// Instantiate this panic logic once, rather than for all the ilog methods
/// on every single primitive type.
#[cold]
#[track_caller]
pub const fn panic_for_negative_argument() -> ! {
panic!("argument of integer square root cannot be negative")
}
1 change: 1 addition & 0 deletions library/core/src/num/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ mod uint_macros; // import uint_impl!

mod error;
mod int_log10;
mod int_sqrt;
mod nonzero;
mod overflow_panic;
mod saturating;
Expand Down
34 changes: 11 additions & 23 deletions library/core/src/num/nonzero.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1247,31 +1247,19 @@ macro_rules! nonzero_integer_signedness_dependent_methods {
without modifying the original"]
#[inline]
pub const fn isqrt(self) -> Self {
// The algorithm is based on the one presented in
// <https://en.wikipedia.org/wiki/Methods_of_computing_square_roots#Binary_numeral_system_(base_2)>
// which cites as source the following C code:
// <https://web.archive.org/web/20120306040058/http://medialab.freaknet.org/martin/src/sqrt/sqrt.c>.

let mut op = self.get();
let mut res = 0;
let mut one = 1 << (self.ilog2() & !1);

while one != 0 {
if op >= res + one {
op -= res + one;
res = (res >> 1) + one;
} else {
res >>= 1;
}
one >>= 2;
let result = super::int_sqrt::$Int(self.get());

// SAFETY: Inform the optimizer that square roots of positive
// integers are positive and what the maximum result is.
unsafe {
hint::assert_unchecked(result > 0);
const MAX_RESULT: $Int = super::int_sqrt::$Int($Int::MAX);
hint::assert_unchecked(result <= MAX_RESULT);
}

// SAFETY: The result fits in an integer with half as many bits.
// Inform the optimizer about it.
unsafe { hint::assert_unchecked(res < 1 << (Self::BITS / 2)) };

// SAFETY: The square root of an integer >= 1 is always >= 1.
unsafe { Self::new_unchecked(res) }
// SAFETY: The square root of a positive integer is always
// positive.
unsafe { Self::new_unchecked(result) }
}
};

Expand Down
11 changes: 8 additions & 3 deletions library/core/src/num/uint_macros.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2588,10 +2588,15 @@ macro_rules! uint_impl {
without modifying the original"]
#[inline]
pub const fn isqrt(self) -> Self {
match NonZero::new(self) {
Some(x) => x.isqrt().get(),
None => 0,
let result = crate::num::int_sqrt::$ActualT(self as $ActualT) as $SelfT;

// SAFETY: Inform the optimizer of what the maximum result is.
unsafe {
const MAX_RESULT: $SelfT = crate::num::int_sqrt::$ActualT($ActualT::MAX) as $SelfT;
crate::hint::assert_unchecked(result <= MAX_RESULT);
}

result
}

/// Performs Euclidean division.
Expand Down

0 comments on commit fea585d

Please sign in to comment.