From 0cac9152116b5b4e4c5f84a05e0ceb20dc325a0c Mon Sep 17 00:00:00 2001 From: "Chai T. Rex" Date: Mon, 26 Aug 2024 01:34:25 -0400 Subject: [PATCH] Improve `isqrt` tests and add benchmarks * Choose test inputs more thoroughly and systematically. * Check that `isqrt` and `checked_isqrt` have equivalent results for signed types, either equivalent numerically or equivalent as a panic and a `None`. * Check that `isqrt` has numerically-equivalent results for unsigned types and their `NonZero` counterparts. * Reuse `ilog10` benchmarks, plus benchmarks that use a uniform distribution. --- library/core/benches/lib.rs | 1 + library/core/benches/num/int_sqrt/mod.rs | 62 ++++++ library/core/benches/num/mod.rs | 1 + library/core/tests/num/int_macros.rs | 32 --- library/core/tests/num/int_sqrt.rs | 248 +++++++++++++++++++++++ library/core/tests/num/mod.rs | 1 + 6 files changed, 313 insertions(+), 32 deletions(-) create mode 100644 library/core/benches/num/int_sqrt/mod.rs create mode 100644 library/core/tests/num/int_sqrt.rs diff --git a/library/core/benches/lib.rs b/library/core/benches/lib.rs index 32d15c386cb1b..3f1c58bbd7204 100644 --- a/library/core/benches/lib.rs +++ b/library/core/benches/lib.rs @@ -8,6 +8,7 @@ #![feature(iter_array_chunks)] #![feature(iter_next_chunk)] #![feature(iter_advance_by)] +#![feature(isqrt)] extern crate test; diff --git a/library/core/benches/num/int_sqrt/mod.rs b/library/core/benches/num/int_sqrt/mod.rs new file mode 100644 index 0000000000000..3c9d173e456a1 --- /dev/null +++ b/library/core/benches/num/int_sqrt/mod.rs @@ -0,0 +1,62 @@ +use rand::Rng; +use test::{black_box, Bencher}; + +macro_rules! int_sqrt_bench { + ($t:ty, $predictable:ident, $random:ident, $random_small:ident, $random_uniform:ident) => { + #[bench] + fn $predictable(bench: &mut Bencher) { + bench.iter(|| { + for n in 0..(<$t>::BITS / 8) { + for i in 1..=(100 as $t) { + let x = black_box(i << (n * 8)); + black_box(x.isqrt()); + } + } + }); + } + + #[bench] + fn $random(bench: &mut Bencher) { + let mut rng = crate::bench_rng(); + /* Exponentially distributed random numbers from the whole range of the type. */ + let numbers: Vec<$t> = + (0..256).map(|_| rng.gen::<$t>() >> rng.gen_range(0..<$t>::BITS)).collect(); + bench.iter(|| { + for x in &numbers { + black_box(black_box(x).isqrt()); + } + }); + } + + #[bench] + fn $random_small(bench: &mut Bencher) { + let mut rng = crate::bench_rng(); + /* Exponentially distributed random numbers from the range 0..256. */ + let numbers: Vec<$t> = + (0..256).map(|_| (rng.gen::() >> rng.gen_range(0..u8::BITS)) as $t).collect(); + bench.iter(|| { + for x in &numbers { + black_box(black_box(x).isqrt()); + } + }); + } + + #[bench] + fn $random_uniform(bench: &mut Bencher) { + let mut rng = crate::bench_rng(); + /* Exponentially distributed random numbers from the whole range of the type. */ + let numbers: Vec<$t> = (0..256).map(|_| rng.gen::<$t>()).collect(); + bench.iter(|| { + for x in &numbers { + black_box(black_box(x).isqrt()); + } + }); + } + }; +} + +int_sqrt_bench! {u8, u8_sqrt_predictable, u8_sqrt_random, u8_sqrt_random_small, u8_sqrt_uniform} +int_sqrt_bench! {u16, u16_sqrt_predictable, u16_sqrt_random, u16_sqrt_random_small, u16_sqrt_uniform} +int_sqrt_bench! {u32, u32_sqrt_predictable, u32_sqrt_random, u32_sqrt_random_small, u32_sqrt_uniform} +int_sqrt_bench! {u64, u64_sqrt_predictable, u64_sqrt_random, u64_sqrt_random_small, u64_sqrt_uniform} +int_sqrt_bench! {u128, u128_sqrt_predictable, u128_sqrt_random, u128_sqrt_random_small, u128_sqrt_uniform} diff --git a/library/core/benches/num/mod.rs b/library/core/benches/num/mod.rs index c1dc3a3062256..7ff7443cfa7fe 100644 --- a/library/core/benches/num/mod.rs +++ b/library/core/benches/num/mod.rs @@ -2,6 +2,7 @@ mod dec2flt; mod flt2dec; mod int_log; mod int_pow; +mod int_sqrt; use std::str::FromStr; diff --git a/library/core/tests/num/int_macros.rs b/library/core/tests/num/int_macros.rs index 7cd3b54e3f39a..830a96204ca03 100644 --- a/library/core/tests/num/int_macros.rs +++ b/library/core/tests/num/int_macros.rs @@ -288,38 +288,6 @@ macro_rules! int_module { assert_eq!(r.saturating_pow(0), 1 as $T); } - #[test] - fn test_isqrt() { - assert_eq!($T::MIN.checked_isqrt(), None); - assert_eq!((-1 as $T).checked_isqrt(), None); - assert_eq!((0 as $T).isqrt(), 0 as $T); - assert_eq!((1 as $T).isqrt(), 1 as $T); - assert_eq!((2 as $T).isqrt(), 1 as $T); - assert_eq!((99 as $T).isqrt(), 9 as $T); - assert_eq!((100 as $T).isqrt(), 10 as $T); - } - - #[cfg(not(miri))] // Miri is too slow - #[test] - fn test_lots_of_isqrt() { - let n_max: $T = (1024 * 1024).min($T::MAX as u128) as $T; - for n in 0..=n_max { - let isqrt: $T = n.isqrt(); - - assert!(isqrt.pow(2) <= n); - let (square, overflow) = (isqrt + 1).overflowing_pow(2); - assert!(overflow || square > n); - } - - for n in ($T::MAX - 127)..=$T::MAX { - let isqrt: $T = n.isqrt(); - - assert!(isqrt.pow(2) <= n); - let (square, overflow) = (isqrt + 1).overflowing_pow(2); - assert!(overflow || square > n); - } - } - #[test] fn test_div_floor() { let a: $T = 8; diff --git a/library/core/tests/num/int_sqrt.rs b/library/core/tests/num/int_sqrt.rs new file mode 100644 index 0000000000000..d68db0787d22c --- /dev/null +++ b/library/core/tests/num/int_sqrt.rs @@ -0,0 +1,248 @@ +macro_rules! tests { + ($isqrt_consistency_check_fn_macro:ident : $($T:ident)+) => { + $( + mod $T { + $isqrt_consistency_check_fn_macro!($T); + + // Check that the following produce the correct values from + // `isqrt`: + // + // * the first and last 128 nonnegative values + // * powers of two, minus one + // * powers of two + // + // For signed types, check that `checked_isqrt` and `isqrt` + // either produce the same numeric value or respectively + // produce `None` and a panic. Make sure to do a consistency + // check for `<$T>::MIN` as well, as no nonnegative values + // negate to it. + // + // For unsigned types check that `isqrt` produces the same + // numeric value for `$T` and `NonZero<$T>`. + #[test] + fn isqrt() { + isqrt_consistency_check(<$T>::MIN); + + for n in (0..=127) + .chain(<$T>::MAX - 127..=<$T>::MAX) + .chain((0..<$T>::MAX.count_ones()).map(|exponent| (1 << exponent) - 1)) + .chain((0..<$T>::MAX.count_ones()).map(|exponent| 1 << exponent)) + { + isqrt_consistency_check(n); + + let isqrt_n = n.isqrt(); + assert!( + isqrt_n + .checked_mul(isqrt_n) + .map(|isqrt_n_squared| isqrt_n_squared <= n) + .unwrap_or(false), + "`{n}.isqrt()` should be lower than {isqrt_n}." + ); + assert!( + (isqrt_n + 1) + .checked_mul(isqrt_n + 1) + .map(|isqrt_n_plus_1_squared| n < isqrt_n_plus_1_squared) + .unwrap_or(true), + "`{n}.isqrt()` should be higher than {isqrt_n})." + ); + } + } + + // Check the square roots of: + // + // * the first 1,024 perfect squares + // * halfway between each of the first 1,024 perfect squares + // and the next perfect square + // * the next perfect square after the each of the first 1,024 + // perfect squares, minus one + // * the last 1,024 perfect squares + // * the last 1,024 perfect squares, minus one + // * halfway between each of the last 1,024 perfect squares + // and the previous perfect square + #[test] + // Skip this test on Miri, as it takes too long to run. + #[cfg(not(miri))] + fn isqrt_extended() { + // The correct value is worked out by using the fact that + // the nth nonzero perfect square is the sum of the first n + // odd numbers: + // + // 1 = 1 + // 4 = 1 + 3 + // 9 = 1 + 3 + 5 + // 16 = 1 + 3 + 5 + 7 + // + // Note also that the last odd number added in is two times + // the square root of the previous perfect square, plus + // one: + // + // 1 = 2*0 + 1 + // 3 = 2*1 + 1 + // 5 = 2*2 + 1 + // 7 = 2*3 + 1 + // + // That means we can add the square root of this perfect + // square once to get about halfway to the next perfect + // square, then we can add the square root of this perfect + // square again to get to the next perfect square, minus + // one, then we can add one to get to the next perfect + // square. + // + // This allows us to, for each of the first 1,024 perfect + // squares, test that the square roots of the following are + // all correct and equal to each other: + // + // * the current perfect square + // * about halfway to the next perfect square + // * the next perfect square, minus one + let mut n: $T = 0; + for sqrt_n in 0..1_024.min((1_u128 << (<$T>::MAX.count_ones()/2)) - 1) as $T { + isqrt_consistency_check(n); + assert_eq!( + n.isqrt(), + sqrt_n, + "`{sqrt_n}.pow(2).isqrt()` should be {sqrt_n}." + ); + + n += sqrt_n; + isqrt_consistency_check(n); + assert_eq!( + n.isqrt(), + sqrt_n, + "{n} is about halfway between `{sqrt_n}.pow(2)` and `{}.pow(2)`, so `{n}.isqrt()` should be {sqrt_n}.", + sqrt_n + 1 + ); + + n += sqrt_n; + isqrt_consistency_check(n); + assert_eq!( + n.isqrt(), + sqrt_n, + "`({}.pow(2) - 1).isqrt()` should be {sqrt_n}.", + sqrt_n + 1 + ); + + n += 1; + } + + // Similarly, for each of the last 1,024 perfect squares, + // check: + // + // * the current perfect square + // * the current perfect square, minus one + // * about halfway to the previous perfect square + // + // `MAX`'s `isqrt` return value is verified in the `isqrt` + // test function above. + let maximum_sqrt = <$T>::MAX.isqrt(); + let mut n = maximum_sqrt * maximum_sqrt; + + for sqrt_n in (maximum_sqrt - 1_024.min((1_u128 << (<$T>::MAX.count_ones()/2)) - 1) as $T..maximum_sqrt).rev() { + isqrt_consistency_check(n); + assert_eq!( + n.isqrt(), + sqrt_n + 1, + "`{0}.pow(2).isqrt()` should be {0}.", + sqrt_n + 1 + ); + + n -= 1; + isqrt_consistency_check(n); + assert_eq!( + n.isqrt(), + sqrt_n, + "`({}.pow(2) - 1).isqrt()` should be {sqrt_n}.", + sqrt_n + 1 + ); + + n -= sqrt_n; + isqrt_consistency_check(n); + assert_eq!( + n.isqrt(), + sqrt_n, + "{n} is about halfway between `{sqrt_n}.pow(2)` and `{}.pow(2)`, so `{n}.isqrt()` should be {sqrt_n}.", + sqrt_n + 1 + ); + + n -= sqrt_n; + } + } + } + )* + }; +} + +macro_rules! signed_check { + ($T:ident) => { + /// This takes an input and, if it's nonnegative or + #[doc = concat!("`", stringify!($T), "::MIN`,")] + /// checks that `isqrt` and `checked_isqrt` produce equivalent results + /// for that input and for the negative of that input. + /// + /// # Note + /// + /// This cannot check that negative inputs to `isqrt` cause panics if + /// panics abort instead of unwind. + fn isqrt_consistency_check(n: $T) { + // `<$T>::MIN` will be negative, so ignore it in this nonnegative + // section. + if n >= 0 { + assert_eq!( + Some(n.isqrt()), + n.checked_isqrt(), + "`{n}.checked_isqrt()` should match `Some({n}.isqrt())`.", + ); + } + + // `wrapping_neg` so that `<$T>::MIN` will negate to itself rather + // than panicking. + let negative_n = n.wrapping_neg(); + + // Zero negated will still be nonnegative, so ignore it in this + // negative section. + if negative_n < 0 { + assert_eq!( + negative_n.checked_isqrt(), + None, + "`({negative_n}).checked_isqrt()` should be `None`, as {negative_n} is negative.", + ); + + // `catch_unwind` only works when panics unwind rather than abort. + #[cfg(panic = "unwind")] + { + std::panic::catch_unwind(core::panic::AssertUnwindSafe(|| (-n).isqrt())).expect_err( + &format!("`({negative_n}).isqrt()` should have panicked, as {negative_n} is negative.") + ); + } + } + } + }; +} + +macro_rules! unsigned_check { + ($T:ident) => { + /// This takes an input and, if it's nonzero, checks that `isqrt` + /// produces the same numeric value for both + #[doc = concat!("`", stringify!($T), "` and ")] + #[doc = concat!("`NonZero<", stringify!($T), ">`.")] + fn isqrt_consistency_check(n: $T) { + // Zero cannot be turned into a `NonZero` value, so ignore it in + // this nonzero section. + if n > 0 { + assert_eq!( + n.isqrt(), + core::num::NonZero::<$T>::new(n) + .expect( + "Was not able to create a new `NonZero` value from a nonzero number." + ) + .isqrt() + .get(), + "`{n}.isqrt` should match `NonZero`'s `{n}.isqrt().get()`.", + ); + } + } + }; +} + +tests!(signed_check: i8 i16 i32 i64 i128); +tests!(unsigned_check: u8 u16 u32 u64 u128); diff --git a/library/core/tests/num/mod.rs b/library/core/tests/num/mod.rs index dad46ad88fe19..b14fe0b22c311 100644 --- a/library/core/tests/num/mod.rs +++ b/library/core/tests/num/mod.rs @@ -27,6 +27,7 @@ mod const_from; mod dec2flt; mod flt2dec; mod int_log; +mod int_sqrt; mod ops; mod wrapping;