diff --git a/library/alloc/src/slice.rs b/library/alloc/src/slice.rs index a92d22b1c309e..e3c7835f1d10b 100644 --- a/library/alloc/src/slice.rs +++ b/library/alloc/src/slice.rs @@ -19,20 +19,6 @@ use core::cmp::Ordering::{self, Less}; use core::mem::{self, MaybeUninit}; #[cfg(not(no_global_oom_handling))] use core::ptr; -#[cfg(not(no_global_oom_handling))] -use core::slice::sort; - -use crate::alloc::Allocator; -#[cfg(not(no_global_oom_handling))] -use crate::alloc::Global; -#[cfg(not(no_global_oom_handling))] -use crate::borrow::ToOwned; -use crate::boxed::Box; -use crate::vec::Vec; - -#[cfg(test)] -mod tests; - #[unstable(feature = "array_chunks", issue = "74985")] pub use core::slice::ArrayChunks; #[unstable(feature = "array_chunks", issue = "74985")] @@ -43,6 +29,8 @@ pub use core::slice::ArrayWindows; pub use core::slice::EscapeAscii; #[stable(feature = "slice_get_slice", since = "1.28.0")] pub use core::slice::SliceIndex; +#[cfg(not(no_global_oom_handling))] +use core::slice::sort; #[stable(feature = "slice_group_by", since = "1.77.0")] pub use core::slice::{ChunkBy, ChunkByMut}; #[stable(feature = "rust1", since = "1.0.0")] @@ -83,6 +71,14 @@ pub use hack::into_vec; #[cfg(test)] pub use hack::to_vec; +use crate::alloc::Allocator; +#[cfg(not(no_global_oom_handling))] +use crate::alloc::Global; +#[cfg(not(no_global_oom_handling))] +use crate::borrow::ToOwned; +use crate::boxed::Box; +use crate::vec::Vec; + // HACK(japaric): With cfg(test) `impl [T]` is not available, these three // functions are actually methods that are in `impl [T]` but not in // `core::slice::SliceExt` - we need to supply these functions for the diff --git a/library/alloc/src/slice/tests.rs b/library/alloc/src/slice/tests.rs deleted file mode 100644 index 786704caeb0ad..0000000000000 --- a/library/alloc/src/slice/tests.rs +++ /dev/null @@ -1,369 +0,0 @@ -use core::cell::Cell; -use core::cmp::Ordering::{self, Equal, Greater, Less}; -use core::convert::identity; -use core::sync::atomic::AtomicUsize; -use core::sync::atomic::Ordering::Relaxed; -use core::{fmt, mem}; -use std::panic; - -use rand::distributions::Standard; -use rand::prelude::*; -use rand::{Rng, RngCore}; - -use crate::borrow::ToOwned; -use crate::rc::Rc; -use crate::string::ToString; -use crate::test_helpers::test_rng; -use crate::vec::Vec; - -macro_rules! do_test { - ($input:ident, $func:ident) => { - let len = $input.len(); - - // Work out the total number of comparisons required to sort - // this array... - let mut count = 0usize; - $input.to_owned().$func(|a, b| { - count += 1; - a.cmp(b) - }); - - // ... and then panic on each and every single one. - for panic_countdown in 0..count { - // Refresh the counters. - VERSIONS.store(0, Relaxed); - for i in 0..len { - DROP_COUNTS[i].store(0, Relaxed); - } - - let v = $input.to_owned(); - let _ = panic::catch_unwind(move || { - let mut v = v; - let mut panic_countdown = panic_countdown; - v.$func(|a, b| { - if panic_countdown == 0 { - SILENCE_PANIC.with(|s| s.set(true)); - panic!(); - } - panic_countdown -= 1; - a.cmp(b) - }) - }); - - // Check that the number of things dropped is exactly - // what we expect (i.e., the contents of `v`). - for (i, c) in DROP_COUNTS.iter().enumerate().take(len) { - let count = c.load(Relaxed); - assert!(count == 1, "found drop count == {} for i == {}, len == {}", count, i, len); - } - - // Check that the most recent versions of values were dropped. 
- assert_eq!(VERSIONS.load(Relaxed), 0); - } - }; -} - -const MAX_LEN: usize = 80; - -static DROP_COUNTS: [AtomicUsize; MAX_LEN] = [ - // FIXME(RFC 1109): AtomicUsize is not Copy. - AtomicUsize::new(0), - AtomicUsize::new(0), - AtomicUsize::new(0), - AtomicUsize::new(0), - AtomicUsize::new(0), - AtomicUsize::new(0), - AtomicUsize::new(0), - AtomicUsize::new(0), - AtomicUsize::new(0), - AtomicUsize::new(0), - AtomicUsize::new(0), - AtomicUsize::new(0), - AtomicUsize::new(0), - AtomicUsize::new(0), - AtomicUsize::new(0), - AtomicUsize::new(0), - AtomicUsize::new(0), - AtomicUsize::new(0), - AtomicUsize::new(0), - AtomicUsize::new(0), - AtomicUsize::new(0), - AtomicUsize::new(0), - AtomicUsize::new(0), - AtomicUsize::new(0), - AtomicUsize::new(0), - AtomicUsize::new(0), - AtomicUsize::new(0), - AtomicUsize::new(0), - AtomicUsize::new(0), - AtomicUsize::new(0), - AtomicUsize::new(0), - AtomicUsize::new(0), - AtomicUsize::new(0), - AtomicUsize::new(0), - AtomicUsize::new(0), - AtomicUsize::new(0), - AtomicUsize::new(0), - AtomicUsize::new(0), - AtomicUsize::new(0), - AtomicUsize::new(0), - AtomicUsize::new(0), - AtomicUsize::new(0), - AtomicUsize::new(0), - AtomicUsize::new(0), - AtomicUsize::new(0), - AtomicUsize::new(0), - AtomicUsize::new(0), - AtomicUsize::new(0), - AtomicUsize::new(0), - AtomicUsize::new(0), - AtomicUsize::new(0), - AtomicUsize::new(0), - AtomicUsize::new(0), - AtomicUsize::new(0), - AtomicUsize::new(0), - AtomicUsize::new(0), - AtomicUsize::new(0), - AtomicUsize::new(0), - AtomicUsize::new(0), - AtomicUsize::new(0), - AtomicUsize::new(0), - AtomicUsize::new(0), - AtomicUsize::new(0), - AtomicUsize::new(0), - AtomicUsize::new(0), - AtomicUsize::new(0), - AtomicUsize::new(0), - AtomicUsize::new(0), - AtomicUsize::new(0), - AtomicUsize::new(0), - AtomicUsize::new(0), - AtomicUsize::new(0), - AtomicUsize::new(0), - AtomicUsize::new(0), - AtomicUsize::new(0), - AtomicUsize::new(0), - AtomicUsize::new(0), - AtomicUsize::new(0), - AtomicUsize::new(0), - AtomicUsize::new(0), -]; - -static VERSIONS: AtomicUsize = AtomicUsize::new(0); - -#[derive(Clone, Eq)] -struct DropCounter { - x: u32, - id: usize, - version: Cell, -} - -impl PartialEq for DropCounter { - fn eq(&self, other: &Self) -> bool { - self.partial_cmp(other) == Some(Ordering::Equal) - } -} - -impl PartialOrd for DropCounter { - fn partial_cmp(&self, other: &Self) -> Option { - self.version.set(self.version.get() + 1); - other.version.set(other.version.get() + 1); - VERSIONS.fetch_add(2, Relaxed); - self.x.partial_cmp(&other.x) - } -} - -impl Ord for DropCounter { - fn cmp(&self, other: &Self) -> Ordering { - self.partial_cmp(other).unwrap() - } -} - -impl Drop for DropCounter { - fn drop(&mut self) { - DROP_COUNTS[self.id].fetch_add(1, Relaxed); - VERSIONS.fetch_sub(self.version.get(), Relaxed); - } -} - -std::thread_local!(static SILENCE_PANIC: Cell = Cell::new(false)); - -#[test] -#[cfg_attr(target_os = "emscripten", ignore)] // no threads -#[cfg_attr(not(panic = "unwind"), ignore = "test requires unwinding support")] -fn panic_safe() { - panic::update_hook(move |prev, info| { - if !SILENCE_PANIC.with(|s| s.get()) { - prev(info); - } - }); - - let mut rng = test_rng(); - - // Miri is too slow (but still need to `chain` to make the types match) - let lens = if cfg!(miri) { (1..10).chain(0..0) } else { (1..20).chain(70..MAX_LEN) }; - let moduli: &[u32] = if cfg!(miri) { &[5] } else { &[5, 20, 50] }; - - for len in lens { - for &modulus in moduli { - for &has_runs in &[false, true] { - let mut input = (0..len) - 
.map(|id| DropCounter { - x: rng.next_u32() % modulus, - id: id, - version: Cell::new(0), - }) - .collect::>(); - - if has_runs { - for c in &mut input { - c.x = c.id as u32; - } - - for _ in 0..5 { - let a = rng.gen::() % len; - let b = rng.gen::() % len; - if a < b { - input[a..b].reverse(); - } else { - input.swap(a, b); - } - } - } - - do_test!(input, sort_by); - do_test!(input, sort_unstable_by); - } - } - } - - // Set default panic hook again. - drop(panic::take_hook()); -} - -#[test] -#[cfg_attr(miri, ignore)] // Miri is too slow -#[cfg_attr(not(panic = "unwind"), ignore = "test requires unwinding support")] -fn test_sort() { - let mut rng = test_rng(); - - for len in (2..25).chain(500..510) { - for &modulus in &[5, 10, 100, 1000] { - for _ in 0..10 { - let orig: Vec<_> = (&mut rng) - .sample_iter::(&Standard) - .map(|x| x % modulus) - .take(len) - .collect(); - - // Sort in default order. - let mut v = orig.clone(); - v.sort(); - assert!(v.windows(2).all(|w| w[0] <= w[1])); - - // Sort in ascending order. - let mut v = orig.clone(); - v.sort_by(|a, b| a.cmp(b)); - assert!(v.windows(2).all(|w| w[0] <= w[1])); - - // Sort in descending order. - let mut v = orig.clone(); - v.sort_by(|a, b| b.cmp(a)); - assert!(v.windows(2).all(|w| w[0] >= w[1])); - - // Sort in lexicographic order. - let mut v1 = orig.clone(); - let mut v2 = orig.clone(); - v1.sort_by_key(|x| x.to_string()); - v2.sort_by_cached_key(|x| x.to_string()); - assert!(v1.windows(2).all(|w| w[0].to_string() <= w[1].to_string())); - assert!(v1 == v2); - - // Sort with many pre-sorted runs. - let mut v = orig.clone(); - v.sort(); - v.reverse(); - for _ in 0..5 { - let a = rng.gen::() % len; - let b = rng.gen::() % len; - if a < b { - v[a..b].reverse(); - } else { - v.swap(a, b); - } - } - v.sort(); - assert!(v.windows(2).all(|w| w[0] <= w[1])); - } - } - } - - const ORD_VIOLATION_MAX_LEN: usize = 500; - let mut v = [0; ORD_VIOLATION_MAX_LEN]; - for i in 0..ORD_VIOLATION_MAX_LEN { - v[i] = i as i32; - } - - // Sort using a completely random comparison function. This will reorder the elements *somehow*, - // it may panic but the original elements must still be present. - let _ = panic::catch_unwind(move || { - v.sort_by(|_, _| *[Less, Equal, Greater].choose(&mut rng).unwrap()); - }); - - v.sort(); - for i in 0..ORD_VIOLATION_MAX_LEN { - assert_eq!(v[i], i as i32); - } - - // Should not panic. - [0i32; 0].sort(); - [(); 10].sort(); - [(); 100].sort(); - - let mut v = [0xDEADBEEFu64]; - v.sort(); - assert!(v == [0xDEADBEEF]); -} - -#[test] -fn test_sort_stability() { - // Miri is too slow - let large_range = if cfg!(miri) { 0..0 } else { 500..510 }; - let rounds = if cfg!(miri) { 1 } else { 10 }; - - let mut rng = test_rng(); - for len in (2..25).chain(large_range) { - for _ in 0..rounds { - let mut counts = [0; 10]; - - // create a vector like [(6, 1), (5, 1), (6, 2), ...], - // where the first item of each tuple is random, but - // the second item represents which occurrence of that - // number this element is, i.e., the second elements - // will occur in sorted order. - let orig: Vec<_> = (0..len) - .map(|_| { - let n = rng.gen::() % 10; - counts[n] += 1; - (n, counts[n]) - }) - .collect(); - - let mut v = orig.clone(); - // Only sort on the first element, so an unstable sort - // may mix up the counts. - v.sort_by(|&(a, _), &(b, _)| a.cmp(&b)); - - // This comparison includes the count (the second item - // of the tuple), so elements with equal first items - // will need to be ordered with increasing - // counts... 
i.e., exactly asserting that this sort is - // stable. - assert!(v.windows(2).all(|w| w[0] <= w[1])); - - let mut v = orig.clone(); - v.sort_by_cached_key(|&(x, _)| x); - assert!(v.windows(2).all(|w| w[0] <= w[1])); - } - } -} diff --git a/library/alloc/tests/lib.rs b/library/alloc/tests/lib.rs index 1d07a7690da43..23efd605ff9ee 100644 --- a/library/alloc/tests/lib.rs +++ b/library/alloc/tests/lib.rs @@ -41,6 +41,7 @@ #![feature(local_waker)] #![feature(vec_pop_if)] #![feature(unique_rc_arc)] +#![feature(macro_metavar_expr_concat)] #![allow(internal_features)] #![deny(fuzzy_provenance_casts)] #![deny(unsafe_op_in_unsafe_fn)] @@ -60,6 +61,7 @@ mod heap; mod linked_list; mod rc; mod slice; +mod sort; mod str; mod string; mod task; diff --git a/library/alloc/tests/sort/ffi_types.rs b/library/alloc/tests/sort/ffi_types.rs new file mode 100644 index 0000000000000..11515ea476971 --- /dev/null +++ b/library/alloc/tests/sort/ffi_types.rs @@ -0,0 +1,82 @@ +use std::cmp::Ordering; + +// Very large stack value. +#[repr(C)] +#[derive(PartialEq, Eq, Debug, Clone)] +pub struct FFIOneKibiByte { + values: [i64; 128], +} + +impl FFIOneKibiByte { + pub fn new(val: i32) -> Self { + let mut values = [0i64; 128]; + let mut val_i64 = val as i64; + + for elem in &mut values { + *elem = val_i64; + val_i64 = std::hint::black_box(val_i64 + 1); + } + Self { values } + } + + fn as_i64(&self) -> i64 { + self.values[11] + self.values[55] + self.values[77] + } +} + +impl PartialOrd for FFIOneKibiByte { + fn partial_cmp(&self, other: &Self) -> Option { + Some(self.cmp(other)) + } +} + +impl Ord for FFIOneKibiByte { + fn cmp(&self, other: &Self) -> Ordering { + self.as_i64().cmp(&other.as_i64()) + } +} + +// 16 byte stack value, with more expensive comparison. +#[repr(C)] +#[derive(PartialEq, Debug, Clone, Copy)] +pub struct F128 { + x: f64, + y: f64, +} + +impl F128 { + pub fn new(val: i32) -> Self { + let val_f = (val as f64) + (i32::MAX as f64) + 10.0; + + let x = val_f + 0.1; + let y = val_f.log(4.1); + + assert!(y < x); + assert!(x.is_normal() && y.is_normal()); + + Self { x, y } + } +} + +// This is kind of hacky, but we know we only have normal comparable floats in there. +impl Eq for F128 {} + +impl PartialOrd for F128 { + fn partial_cmp(&self, other: &Self) -> Option { + Some(self.cmp(other)) + } +} + +// Goal is similar code-gen between Rust and C++ +// - Rust https://godbolt.org/z/3YM3xenPP +// - C++ https://godbolt.org/z/178M6j1zz +impl Ord for F128 { + fn cmp(&self, other: &Self) -> Ordering { + // Simulate expensive comparison function. + let this_div = self.x / self.y; + let other_div = other.x / other.y; + + // SAFETY: We checked in the ctor that both are normal. + unsafe { this_div.partial_cmp(&other_div).unwrap_unchecked() } + } +} diff --git a/library/alloc/tests/sort/known_good_stable_sort.rs b/library/alloc/tests/sort/known_good_stable_sort.rs new file mode 100644 index 0000000000000..f8615435fc2a7 --- /dev/null +++ b/library/alloc/tests/sort/known_good_stable_sort.rs @@ -0,0 +1,192 @@ +// This module implements a known good stable sort implementation that helps provide better error +// messages when the correctness tests fail, we can't use the stdlib sort functions because we are +// testing them for correctness. +// +// Based on https://github.com/voultapher/tiny-sort-rs. + +use alloc::alloc::{Layout, alloc, dealloc}; +use std::{mem, ptr}; + +/// Sort `v` preserving initial order of equal elements. 
+///
+/// - Guaranteed O(N * log(N)) worst case perf
+/// - No adaptiveness
+/// - Branch misprediction not affected by outcome of comparison function
+/// - Uses `v.len()` auxiliary memory.
+///
+/// If `T: Ord` does not implement a total order the resulting order is
+/// unspecified. All original elements will remain in `v` and any possible modifications via
+/// interior mutability will be observable. Same is true if `T: Ord` panics.
+///
+/// Panics if allocating the auxiliary memory fails.
+#[inline(always)]
+pub fn sort<T: Ord>(v: &mut [T]) {
+    stable_sort(v, |a, b| a.lt(b))
+}
+
+#[inline(always)]
+fn stable_sort<T, F: FnMut(&T, &T) -> bool>(v: &mut [T], mut is_less: F) {
+    if mem::size_of::<T>() == 0 {
+        return;
+    }
+
+    let len = v.len();
+
+    // Inline the check for len < 2. This happens a lot, instrumenting the Rust compiler suggests
+    // len < 2 accounts for 94% of its calls to `slice::sort`.
+    if len < 2 {
+        return;
+    }
+
+    // SAFETY: We checked that len is > 0 and that T is not a ZST.
+    unsafe {
+        mergesort_main(v, &mut is_less);
+    }
+}
+
+/// The core logic should not be inlined.
+///
+/// SAFETY: The caller has to ensure that len is > 0 and that T is not a ZST.
+#[inline(never)]
+unsafe fn mergesort_main<T, F: FnMut(&T, &T) -> bool>(v: &mut [T], is_less: &mut F) {
+    // While it would be nice to have a merge implementation that only requires N / 2 auxiliary
+    // memory, doing so would make the merge implementation significantly more complex, which is
+    // at odds with the goal of a simple known-good implementation.
+
+    // SAFETY: See function safety description.
+    let buf = unsafe { BufGuard::new(v.len()) };
+
+    // SAFETY: `scratch` has space for `v.len()` writes. And does not alias `v`.
+    unsafe {
+        mergesort_core(v, buf.buf_ptr.as_ptr(), is_less);
+    }
+}
+
+/// Tiny recursive top-down merge sort optimized for binary size. It has no adaptiveness whatsoever,
+/// no run detection, etc.
+///
+/// The buffer pointed to by `scratch_ptr` must have space for `v.len()` writes. And must not
+/// alias `v`.
+#[inline(always)]
+unsafe fn mergesort_core<T, F: FnMut(&T, &T) -> bool>(
+    v: &mut [T],
+    scratch_ptr: *mut T,
+    is_less: &mut F,
+) {
+    let len = v.len();
+
+    if len > 2 {
+        // SAFETY: `mid` is guaranteed in-bounds. And caller has to ensure that `scratch_ptr` can
+        // hold `v.len()` values.
+        unsafe {
+            let mid = len / 2;
+            // Sort the left half recursively.
+            mergesort_core(v.get_unchecked_mut(..mid), scratch_ptr, is_less);
+            // Sort the right half recursively.
+            mergesort_core(v.get_unchecked_mut(mid..), scratch_ptr, is_less);
+            // Combine the two halves.
+            merge(v, scratch_ptr, is_less, mid);
+        }
+    } else if len == 2 {
+        if is_less(&v[1], &v[0]) {
+            v.swap(0, 1);
+        }
+    }
+}
+
+/// Branchless merge function.
+///
+/// SAFETY: The caller must ensure that `scratch_ptr` is valid for `v.len()` writes. And that mid is
+/// in-bounds.
+#[inline(always)]
+unsafe fn merge<T, F>(v: &mut [T], scratch_ptr: *mut T, is_less: &mut F, mid: usize)
+where
+    F: FnMut(&T, &T) -> bool,
+{
+    let len = v.len();
+    debug_assert!(mid > 0 && mid < len);
+
+    // Indexes to track the positions while merging.
+    let mut l = 0;
+    let mut r = mid;
+
+    // SAFETY: No matter what the result of `is_less` is, we check that `l` and `r` remain
+    // in-bounds, and if `is_less` panics the original elements remain in `v`.
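+    // Note: on ties (`is_less(right, left)` is false) the element from the left half is taken
+    // first, which is exactly what makes this merge, and thus the whole sort, stable.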
+    unsafe {
+        let arr_ptr = v.as_ptr();
+
+        for i in 0..len {
+            let left_ptr = arr_ptr.add(l);
+            let right_ptr = arr_ptr.add(r);
+
+            let is_lt = !is_less(&*right_ptr, &*left_ptr);
+            let copy_ptr = if is_lt { left_ptr } else { right_ptr };
+            ptr::copy_nonoverlapping(copy_ptr, scratch_ptr.add(i), 1);
+
+            l += is_lt as usize;
+            r += !is_lt as usize;
+
+            // Stop the element-wise merge as soon as either side is exhausted.
+            if ((l == mid) as u8 + (r == len) as u8) != 0 {
+                break;
+            }
+        }
+
+        // One side is exhausted, drain the remaining side in one go.
+        let copy_ptr = if l == mid { arr_ptr.add(r) } else { arr_ptr.add(l) };
+        let i = l + (r - mid);
+        ptr::copy_nonoverlapping(copy_ptr, scratch_ptr.add(i), len - i);
+
+        // Now that scratch_ptr holds the full merged content, write it back on top of v.
+        ptr::copy_nonoverlapping(scratch_ptr, v.as_mut_ptr(), len);
+    }
+}
+
+// SAFETY: The caller has to ensure that the `Option` is `Some`, UB otherwise.
+unsafe fn unwrap_unchecked<T>(opt_val: Option<T>) -> T {
+    match opt_val {
+        Some(val) => val,
+        None => {
+            // SAFETY: See function safety description.
+            unsafe {
+                core::hint::unreachable_unchecked();
+            }
+        }
+    }
+}
+
+// An extremely basic version of Vec.
+// Its use is super limited and by having the code here, it allows reuse between the sort
+// implementations.
+struct BufGuard<T> {
+    buf_ptr: ptr::NonNull<T>,
+    capacity: usize,
+}
+
+impl<T> BufGuard<T> {
+    // SAFETY: The caller has to ensure that len is not 0 and that T is not a ZST.
+    unsafe fn new(len: usize) -> Self {
+        debug_assert!(len > 0 && mem::size_of::<T>() > 0);
+
+        // SAFETY: See function safety description.
+        let layout = unsafe { unwrap_unchecked(Layout::array::<T>(len).ok()) };
+
+        // SAFETY: We checked that T is not a ZST.
+        let buf_ptr = unsafe { alloc(layout) as *mut T };
+
+        if buf_ptr.is_null() {
+            panic!("allocation failure");
+        }
+
+        Self { buf_ptr: ptr::NonNull::new(buf_ptr).unwrap(), capacity: len }
+    }
+}
+
+impl<T> Drop for BufGuard<T> {
+    fn drop(&mut self) {
+        // SAFETY: We checked that T is not a ZST.
+        unsafe {
+            dealloc(self.buf_ptr.as_ptr() as *mut u8, Layout::array::<T>(self.capacity).unwrap());
+        }
+    }
+}
diff --git a/library/alloc/tests/sort/mod.rs b/library/alloc/tests/sort/mod.rs
new file mode 100644
index 0000000000000..0e2494ca9d34e
--- /dev/null
+++ b/library/alloc/tests/sort/mod.rs
@@ -0,0 +1,17 @@
+pub trait Sort {
+    fn name() -> String;
+
+    fn sort<T>(v: &mut [T])
+    where
+        T: Ord;
+
+    fn sort_by<T, F>(v: &mut [T], compare: F)
+    where
+        F: FnMut(&T, &T) -> std::cmp::Ordering;
+}
+
+mod ffi_types;
+mod known_good_stable_sort;
+mod patterns;
+mod tests;
+mod zipf;
diff --git a/library/alloc/tests/sort/patterns.rs b/library/alloc/tests/sort/patterns.rs
new file mode 100644
index 0000000000000..e5d31d868b251
--- /dev/null
+++ b/library/alloc/tests/sort/patterns.rs
@@ -0,0 +1,211 @@
+use std::env;
+use std::hash::Hash;
+use std::str::FromStr;
+use std::sync::OnceLock;
+
+use rand::prelude::*;
+use rand_xorshift::XorShiftRng;
+
+use crate::sort::zipf::ZipfDistribution;
+
+/// Provides a set of patterns useful for testing and benchmarking sorting algorithms.
+/// Currently limited to i32 values.
+
+// --- Public ---
+
+pub fn random(len: usize) -> Vec<i32> {
+    //     .
+    //  : .  : :
+    // :.:::.::
+
+    random_vec(len)
+}
+
+pub fn random_uniform<R>(len: usize, range: R) -> Vec<i32>
+where
+    R: Into<rand::distributions::Uniform<i32>> + Hash,
+{
+    // :.:.:.::
+
+    let mut rng: XorShiftRng = rand::SeedableRng::seed_from_u64(get_or_init_rand_seed());
+
+    // Abstracting over ranges in Rust :(
+    let dist: rand::distributions::Uniform<i32> = range.into();
+    (0..len).map(|_| dist.sample(&mut rng)).collect()
+}
+
+pub fn random_zipf(len: usize, exponent: f64) -> Vec<i32> {
+    // https://en.wikipedia.org/wiki/Zipf's_law
+
+    let mut rng: XorShiftRng = rand::SeedableRng::seed_from_u64(get_or_init_rand_seed());
+
+    // Abstracting over ranges in Rust :(
+    let dist = ZipfDistribution::new(len, exponent).unwrap();
+    (0..len).map(|_| dist.sample(&mut rng) as i32).collect()
+}
+
+pub fn random_sorted(len: usize, sorted_percent: f64) -> Vec<i32> {
+    //     .:
+    //   .:::. :
+    // .::::::.::
+    // [----][--]
+    //  ^      ^
+    //  |      |
+    // sorted  |
+    //     unsorted
+
+    // Simulate a pre-existing sorted slice, where len - sorted_percent are the new unsorted values
+    // and part of the overall distribution.
+    let mut v = random_vec(len);
+    let sorted_len = ((len as f64) * (sorted_percent / 100.0)).round() as usize;
+
+    v[0..sorted_len].sort_unstable();
+
+    v
+}
+
+pub fn all_equal(len: usize) -> Vec<i32> {
+    // ......
+    // ::::::
+
+    (0..len).map(|_| 66).collect::<Vec<_>>()
+}
+
+pub fn ascending(len: usize) -> Vec<i32> {
+    //     .:
+    //   .:::
+    // .:::::
+
+    (0..len as i32).collect::<Vec<_>>()
+}
+
+pub fn descending(len: usize) -> Vec<i32> {
+    // :.
+    // :::.
+    // :::::.
+
+    (0..len as i32).rev().collect::<Vec<_>>()
+}
+
+pub fn saw_mixed(len: usize, saw_count: usize) -> Vec<i32> {
+    // :.  :.    .::.    .:
+    // :::.:::..::::::..:::
+
+    if len == 0 {
+        return Vec::new();
+    }
+
+    let mut vals = random_vec(len);
+    let chunks_size = len / saw_count.max(1);
+    let saw_directions = random_uniform((len / chunks_size) + 1, 0..=1);
+
+    for (i, chunk) in vals.chunks_mut(chunks_size).enumerate() {
+        if saw_directions[i] == 0 {
+            chunk.sort_unstable();
+        } else if saw_directions[i] == 1 {
+            chunk.sort_unstable_by_key(|&e| std::cmp::Reverse(e));
+        } else {
+            unreachable!();
+        }
+    }
+
+    vals
+}
+
+pub fn saw_mixed_range(len: usize, range: std::ops::Range<usize>) -> Vec<i32> {
+    //     :.
+    // :.  :::.    .::.      .:
+    // :::.:::::..::::::..:.:::
+
+    // Ascending and descending randomly picked, with length in `range`.
+
+    if len == 0 {
+        return Vec::new();
+    }
+
+    let mut vals = random_vec(len);
+
+    let max_chunks = len / range.start;
+    let saw_directions = random_uniform(max_chunks + 1, 0..=1);
+    let chunk_sizes = random_uniform(max_chunks + 1, (range.start as i32)..(range.end as i32));
+
+    let mut i = 0;
+    let mut l = 0;
+    while l < len {
+        let chunk_size = chunk_sizes[i] as usize;
+        let chunk_end = std::cmp::min(l + chunk_size, len);
+        let chunk = &mut vals[l..chunk_end];
+
+        if saw_directions[i] == 0 {
+            chunk.sort_unstable();
+        } else if saw_directions[i] == 1 {
+            chunk.sort_unstable_by_key(|&e| std::cmp::Reverse(e));
+        } else {
+            unreachable!();
+        }
+
+        i += 1;
+        l += chunk_size;
+    }
+
+    vals
+}
+
+pub fn pipe_organ(len: usize) -> Vec<i32> {
+    //  .:.
+    // .:::::.
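+    //
+    // (First half ascending, second half descending, as built below.)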
+
+    let mut vals = random_vec(len);
+
+    let first_half = &mut vals[0..(len / 2)];
+    first_half.sort_unstable();
+
+    let second_half = &mut vals[(len / 2)..len];
+    second_half.sort_unstable_by_key(|&e| std::cmp::Reverse(e));
+
+    vals
+}
+
+pub fn get_or_init_rand_seed() -> u64 {
+    *SEED_VALUE.get_or_init(|| {
+        env::var("OVERRIDE_SEED")
+            .ok()
+            .map(|seed| u64::from_str(&seed).unwrap())
+            .unwrap_or_else(rand_root_seed)
+    })
+}
+
+// --- Private ---
+
+static SEED_VALUE: OnceLock<u64> = OnceLock::new();
+
+#[cfg(not(miri))]
+fn rand_root_seed() -> u64 {
+    // Other test code hashes `panic::Location::caller()` and constructs a seed from that. In these
+    // tests we want a fuzzer-like exploration of the test space; if we used the same caller-based
+    // construction we would always test the same values.
+    //
+    // Instead we use the seconds since UNIX epoch / 10. Given CI log output this value should be
+    // reasonably easy to re-construct.
+
+    use std::time::{SystemTime, UNIX_EPOCH};
+
+    let epoch_seconds = SystemTime::now().duration_since(UNIX_EPOCH).unwrap().as_secs();
+
+    epoch_seconds / 10
+}
+
+#[cfg(miri)]
+fn rand_root_seed() -> u64 {
+    // Miri is usually run with isolation, which gives us repeatability, but also permutations
+    // based on other code that runs before.
+    use core::hash::{BuildHasher, Hash, Hasher};
+    let mut hasher = std::hash::RandomState::new().build_hasher();
+    core::panic::Location::caller().hash(&mut hasher);
+    hasher.finish()
+}
+
+fn random_vec(len: usize) -> Vec<i32> {
+    let mut rng: XorShiftRng = rand::SeedableRng::seed_from_u64(get_or_init_rand_seed());
+    (0..len).map(|_| rng.gen::<i32>()).collect()
+}
diff --git a/library/alloc/tests/sort/tests.rs b/library/alloc/tests/sort/tests.rs
new file mode 100644
index 0000000000000..14e6013f965d8
--- /dev/null
+++ b/library/alloc/tests/sort/tests.rs
@@ -0,0 +1,1233 @@
+use std::cell::Cell;
+use std::cmp::Ordering;
+use std::fmt::Debug;
+use std::panic::{self, AssertUnwindSafe};
+use std::rc::Rc;
+use std::{env, fs};
+
+use crate::sort::ffi_types::{F128, FFIOneKibiByte};
+use crate::sort::{Sort, known_good_stable_sort, patterns};
+
+#[cfg(miri)]
+const TEST_LENGTHS: &[usize] = &[2, 3, 4, 7, 10, 15, 20, 24, 33, 50, 100, 171, 300];
+
+#[cfg(not(miri))]
+const TEST_LENGTHS: &[usize] = &[
+    2, 3, 4, 5, 6, 7, 8, 9, 10, 15, 16, 17, 20, 24, 30, 32, 33, 35, 50, 100, 200, 500, 1_000,
+    2_048, 5_000, 10_000, 100_000, 1_100_000,
+];
+
+fn check_is_sorted<T: Ord + Clone + Debug, S: Sort>(v: &mut [T]) {
+    let seed = patterns::get_or_init_rand_seed();
+
+    let is_small_test = v.len() <= 100;
+    let v_orig = v.to_vec();
+
+    <S as Sort>::sort(v);
+
+    assert_eq!(v.len(), v_orig.len());
+
+    for window in v.windows(2) {
+        if window[0] > window[1] {
+            let mut known_good_sorted_vec = v_orig.clone();
+            known_good_stable_sort::sort(known_good_sorted_vec.as_mut_slice());
+
+            if is_small_test {
+                eprintln!("Original: {:?}", v_orig);
+                eprintln!("Expected: {:?}", known_good_sorted_vec);
+                eprintln!("Got:      {:?}", v);
+            } else {
+                if env::var("WRITE_LARGE_FAILURE").is_ok() {
+                    // For large arrays, output them as files.
+ let original_name = format!("original_{}.txt", seed); + let std_name = format!("known_good_sorted_{}.txt", seed); + let testsort_name = format!("{}_sorted_{}.txt", S::name(), seed); + + fs::write(&original_name, format!("{:?}", v_orig)).unwrap(); + fs::write(&std_name, format!("{:?}", known_good_sorted_vec)).unwrap(); + fs::write(&testsort_name, format!("{:?}", v)).unwrap(); + + eprintln!( + "Failed comparison, see files {original_name}, {std_name}, and {testsort_name}" + ); + } else { + eprintln!( + "Failed comparison, re-run with WRITE_LARGE_FAILURE env var set, to get output." + ); + } + } + + panic!("Test assertion failed!") + } + } +} + +fn test_is_sorted( + test_len: usize, + map_fn: impl Fn(i32) -> T, + pattern_fn: impl Fn(usize) -> Vec, +) { + let mut test_data: Vec = pattern_fn(test_len).into_iter().map(map_fn).collect(); + check_is_sorted::(test_data.as_mut_slice()); +} + +trait DynTrait: Debug { + fn get_val(&self) -> i32; +} + +#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord)] +struct DynValA { + value: i32, +} + +#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord)] +struct DynValB { + value: u64, +} + +impl DynTrait for DynValA { + fn get_val(&self) -> i32 { + self.value + } +} +impl DynTrait for DynValB { + fn get_val(&self) -> i32 { + let bytes = self.value.to_ne_bytes(); + i32::from_ne_bytes([bytes[0], bytes[1], bytes[6], bytes[7]]) + } +} + +impl PartialOrd for dyn DynTrait { + fn partial_cmp(&self, other: &Self) -> Option { + Some(self.cmp(other)) + } +} + +impl Ord for dyn DynTrait { + fn cmp(&self, other: &Self) -> Ordering { + self.get_val().cmp(&other.get_val()) + } +} + +impl PartialEq for dyn DynTrait { + fn eq(&self, other: &Self) -> bool { + self.get_val() == other.get_val() + } +} + +impl Eq for dyn DynTrait {} + +fn shift_i32_to_u32(val: i32) -> u32 { + (val as i64 + (i32::MAX as i64 + 1)) as u32 +} + +fn reverse_shift_i32_to_u32(val: u32) -> i32 { + (val as i64 - (i32::MAX as i64 + 1)) as i32 +} + +fn extend_i32_to_u64(val: i32) -> u64 { + // Extends the value into the 64 bit range, + // while preserving input order. + (shift_i32_to_u32(val) as u64) * i32::MAX as u64 +} + +fn extend_i32_to_u128(val: i32) -> u128 { + // Extends the value into the 64 bit range, + // while preserving input order. + (shift_i32_to_u32(val) as u128) * i64::MAX as u128 +} + +fn dyn_trait_from_i32(val: i32) -> Rc { + if val % 2 == 0 { + Rc::new(DynValA { value: val }) + } else { + Rc::new(DynValB { value: extend_i32_to_u64(val) }) + } +} + +fn i32_from_i32(val: i32) -> i32 { + val +} + +fn i32_from_i32_ref(val: &i32) -> i32 { + *val +} + +fn string_from_i32(val: i32) -> String { + format!("{:010}", shift_i32_to_u32(val)) +} + +fn i32_from_string(val: &String) -> i32 { + reverse_shift_i32_to_u32(val.parse::().unwrap()) +} + +fn cell_i32_from_i32(val: i32) -> Cell { + Cell::new(val) +} + +fn i32_from_cell_i32(val: &Cell) -> i32 { + val.get() +} + +fn calc_comps_required(v: &mut [T], mut cmp_fn: impl FnMut(&T, &T) -> Ordering) -> u32 { + let mut comp_counter = 0u32; + + ::sort_by(v, |a, b| { + comp_counter += 1; + + cmp_fn(a, b) + }); + + comp_counter +} + +#[derive(PartialEq, Eq, Debug, Clone)] +#[repr(C)] +struct CompCount { + val: i32, + comp_count: Cell, +} + +impl CompCount { + fn new(val: i32) -> Self { + Self { val, comp_count: Cell::new(0) } + } +} + +/// Generates $base_name_pattern_name_impl functions calling the test_fns for all test_len. +macro_rules! 
gen_sort_test_fns { + ( + $base_name:ident, + $test_fn:expr, + $test_lengths:expr, + [$(($pattern_name:ident, $pattern_fn:expr)),* $(,)?] $(,)? + ) => { + $(fn ${concat($base_name, _, $pattern_name, _impl)}() { + for test_len in $test_lengths { + $test_fn(*test_len, $pattern_fn); + } + })* + }; +} + +/// Generates $base_name_pattern_name_impl functions calling the test_fns for all test_len, +/// with a default set of patterns that can be extended by the caller. +macro_rules! gen_sort_test_fns_with_default_patterns { + ( + $base_name:ident, + $test_fn:expr, + $test_lengths:expr, + [$(($pattern_name:ident, $pattern_fn:expr)),* $(,)?] $(,)? + ) => { + gen_sort_test_fns!( + $base_name, + $test_fn, + $test_lengths, + [ + (random, patterns::random), + (random_z1, |len| patterns::random_zipf(len, 1.0)), + (random_d2, |len| patterns::random_uniform(len, 0..2)), + (random_d20, |len| patterns::random_uniform(len, 0..16)), + (random_s95, |len| patterns::random_sorted(len, 95.0)), + (ascending, patterns::ascending), + (descending, patterns::descending), + (saw_mixed, |len| patterns::saw_mixed( + len, + ((len as f64).log2().round()) as usize + )), + $(($pattern_name, $pattern_fn),)* + ] + ); + }; +} + +/// Generates $base_name_type_pattern_name_impl functions calling the test_fns for all test_len for +/// three types that cover the core specialization differences in the sort implementations, with a +/// default set of patterns that can be extended by the caller. +macro_rules! gen_sort_test_fns_with_default_patterns_3_ty { + ( + $base_name:ident, + $test_fn:ident, + [$(($pattern_name:ident, $pattern_fn:expr)),* $(,)?] $(,)? + ) => { + gen_sort_test_fns_with_default_patterns!( + ${concat($base_name, _i32)}, + |len, pattern_fn| $test_fn::(len, i32_from_i32, i32_from_i32_ref, pattern_fn), + &TEST_LENGTHS[..TEST_LENGTHS.len() - 2], + [$(($pattern_name, $pattern_fn),)*], + ); + + gen_sort_test_fns_with_default_patterns!( + ${concat($base_name, _cell_i32)}, + |len, pattern_fn| $test_fn::, S>(len, cell_i32_from_i32, i32_from_cell_i32, pattern_fn), + &TEST_LENGTHS[..TEST_LENGTHS.len() - 3], + [$(($pattern_name, $pattern_fn),)*], + ); + + gen_sort_test_fns_with_default_patterns!( + ${concat($base_name, _string)}, + |len, pattern_fn| $test_fn::(len, string_from_i32, i32_from_string, pattern_fn), + &TEST_LENGTHS[..TEST_LENGTHS.len() - 3], + [$(($pattern_name, $pattern_fn),)*], + ); + }; +} + +// --- TESTS --- + +pub fn basic_impl() { + check_is_sorted::(&mut []); + check_is_sorted::<(), S>(&mut []); + check_is_sorted::<(), S>(&mut [()]); + check_is_sorted::<(), S>(&mut [(), ()]); + check_is_sorted::<(), S>(&mut [(), (), ()]); + check_is_sorted::(&mut []); + check_is_sorted::(&mut [77]); + check_is_sorted::(&mut [2, 3]); + check_is_sorted::(&mut [2, 3, 6]); + check_is_sorted::(&mut [2, 3, 99, 6]); + check_is_sorted::(&mut [2, 7709, 400, 90932]); + check_is_sorted::(&mut [15, -1, 3, -1, -3, -1, 7]); +} + +fn fixed_seed_impl() { + let fixed_seed_a = patterns::get_or_init_rand_seed(); + let fixed_seed_b = patterns::get_or_init_rand_seed(); + + assert_eq!(fixed_seed_a, fixed_seed_b); +} + +fn fixed_seed_rand_vec_prefix_impl() { + let vec_rand_len_5 = patterns::random(5); + let vec_rand_len_7 = patterns::random(7); + + assert_eq!(vec_rand_len_5, vec_rand_len_7[..5]); +} + +fn int_edge_impl() { + // Ensure that the sort can handle integer edge cases. 
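+    // (E.g. a comparison implemented as `a - b` would overflow near i32::MIN and i32::MAX.)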
+ check_is_sorted::(&mut [i32::MIN, i32::MAX]); + check_is_sorted::(&mut [i32::MAX, i32::MIN]); + check_is_sorted::(&mut [i32::MIN, 3]); + check_is_sorted::(&mut [i32::MIN, -3]); + check_is_sorted::(&mut [i32::MIN, -3, i32::MAX]); + check_is_sorted::(&mut [i32::MIN, -3, i32::MAX, i32::MIN, 5]); + check_is_sorted::(&mut [i32::MAX, 3, i32::MIN, 5, i32::MIN, -3, 60, 200, 50, 7, 10]); + + check_is_sorted::(&mut [u64::MIN, u64::MAX]); + check_is_sorted::(&mut [u64::MAX, u64::MIN]); + check_is_sorted::(&mut [u64::MIN, 3]); + check_is_sorted::(&mut [u64::MIN, u64::MAX - 3]); + check_is_sorted::(&mut [u64::MIN, u64::MAX - 3, u64::MAX]); + check_is_sorted::(&mut [u64::MIN, u64::MAX - 3, u64::MAX, u64::MIN, 5]); + check_is_sorted::(&mut [ + u64::MAX, + 3, + u64::MIN, + 5, + u64::MIN, + u64::MAX - 3, + 60, + 200, + 50, + 7, + 10, + ]); + + let mut large = patterns::random(TEST_LENGTHS[TEST_LENGTHS.len() - 2]); + large.push(i32::MAX); + large.push(i32::MIN); + large.push(i32::MAX); + check_is_sorted::(&mut large); +} + +fn sort_vs_sort_by_impl() { + // Ensure that sort and sort_by produce the same result. + let mut input_normal = [800, 3, -801, 5, -801, -3, 60, 200, 50, 7, 10]; + let expected = [-801, -801, -3, 3, 5, 7, 10, 50, 60, 200, 800]; + + let mut input_sort_by = input_normal.to_vec(); + + ::sort(&mut input_normal); + ::sort_by(&mut input_sort_by, |a, b| a.cmp(b)); + + assert_eq!(input_normal, expected); + assert_eq!(input_sort_by, expected); +} + +gen_sort_test_fns_with_default_patterns!( + correct_i32, + |len, pattern_fn| test_is_sorted::(len, |val| val, pattern_fn), + TEST_LENGTHS, + [ + (random_d4, |len| patterns::random_uniform(len, 0..4)), + (random_d8, |len| patterns::random_uniform(len, 0..8)), + (random_d311, |len| patterns::random_uniform(len, 0..311)), + (random_d1024, |len| patterns::random_uniform(len, 0..1024)), + (random_z1_03, |len| patterns::random_zipf(len, 1.03)), + (random_z2, |len| patterns::random_zipf(len, 2.0)), + (random_s50, |len| patterns::random_sorted(len, 50.0)), + (narrow, |len| patterns::random_uniform( + len, + 0..=(((len as f64).log2().round()) as i32) * 100 + )), + (all_equal, patterns::all_equal), + (saw_mixed_range, |len| patterns::saw_mixed_range(len, 20..50)), + (pipe_organ, patterns::pipe_organ), + ] +); + +gen_sort_test_fns_with_default_patterns!( + correct_u64, + |len, pattern_fn| test_is_sorted::(len, extend_i32_to_u64, pattern_fn), + TEST_LENGTHS, + [] +); + +gen_sort_test_fns_with_default_patterns!( + correct_u128, + |len, pattern_fn| test_is_sorted::(len, extend_i32_to_u128, pattern_fn), + &TEST_LENGTHS[..TEST_LENGTHS.len() - 2], + [] +); + +gen_sort_test_fns_with_default_patterns!( + correct_cell_i32, + |len, pattern_fn| test_is_sorted::, S>(len, Cell::new, pattern_fn), + &TEST_LENGTHS[..TEST_LENGTHS.len() - 2], + [] +); + +gen_sort_test_fns_with_default_patterns!( + correct_string, + |len, pattern_fn| test_is_sorted::( + len, + |val| format!("{:010}", shift_i32_to_u32(val)), + pattern_fn + ), + &TEST_LENGTHS[..TEST_LENGTHS.len() - 2], + [] +); + +gen_sort_test_fns_with_default_patterns!( + correct_f128, + |len, pattern_fn| test_is_sorted::(len, F128::new, pattern_fn), + &TEST_LENGTHS[..TEST_LENGTHS.len() - 2], + [] +); + +gen_sort_test_fns_with_default_patterns!( + correct_1k, + |len, pattern_fn| test_is_sorted::(len, FFIOneKibiByte::new, pattern_fn), + &TEST_LENGTHS[..TEST_LENGTHS.len() - 2], + [] +); + +// Dyn values are fat pointers, something the implementation might have overlooked. 
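+// (An `Rc<dyn DynTrait>` is two words on 64-bit targets: the data pointer plus the vtable pointer.)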
+gen_sort_test_fns_with_default_patterns!( + correct_dyn_val, + |len, pattern_fn| test_is_sorted::, S>(len, dyn_trait_from_i32, pattern_fn), + &TEST_LENGTHS[..TEST_LENGTHS.len() - 2], + [] +); + +fn stability_legacy_impl() { + // This non pattern variant has proven to catch some bugs the pattern version of this function + // doesn't catch, so it remains in conjunction with the other one. + + if ::name().contains("unstable") { + // It would be great to mark the test as skipped, but that isn't possible as of now. + return; + } + + let large_range = if cfg!(miri) { 100..110 } else { 3000..3010 }; + let rounds = if cfg!(miri) { 1 } else { 10 }; + + let rand_vals = patterns::random_uniform(5_000, 0..=9); + let mut rand_idx = 0; + + for len in (2..55).chain(large_range) { + for _ in 0..rounds { + let mut counts = [0; 10]; + + // create a vector like [(6, 1), (5, 1), (6, 2), ...], + // where the first item of each tuple is random, but + // the second item represents which occurrence of that + // number this element is, i.e., the second elements + // will occur in sorted order. + let orig: Vec<_> = (0..len) + .map(|_| { + let n = rand_vals[rand_idx]; + rand_idx += 1; + if rand_idx >= rand_vals.len() { + rand_idx = 0; + } + + counts[n as usize] += 1; + i32_tup_as_u64((n, counts[n as usize])) + }) + .collect(); + + let mut v = orig.clone(); + // Only sort on the first element, so an unstable sort + // may mix up the counts. + ::sort_by(&mut v, |a_packed, b_packed| { + let a = i32_tup_from_u64(*a_packed).0; + let b = i32_tup_from_u64(*b_packed).0; + + a.cmp(&b) + }); + + // This comparison includes the count (the second item + // of the tuple), so elements with equal first items + // will need to be ordered with increasing + // counts... i.e., exactly asserting that this sort is + // stable. + assert!(v.windows(2).all(|w| i32_tup_from_u64(w[0]) <= i32_tup_from_u64(w[1]))); + } + } + + // For cpp_sorts that only support u64 we can pack the two i32 inside a u64. + fn i32_tup_as_u64(val: (i32, i32)) -> u64 { + let a_bytes = val.0.to_le_bytes(); + let b_bytes = val.1.to_le_bytes(); + + u64::from_le_bytes([a_bytes, b_bytes].concat().try_into().unwrap()) + } + + fn i32_tup_from_u64(val: u64) -> (i32, i32) { + let bytes = val.to_le_bytes(); + + let a = i32::from_le_bytes(bytes[0..4].try_into().unwrap()); + let b = i32::from_le_bytes(bytes[4..8].try_into().unwrap()); + + (a, b) + } +} + +fn stability_with_patterns( + len: usize, + type_into_fn: impl Fn(i32) -> T, + _type_from_fn: impl Fn(&T) -> i32, + pattern_fn: fn(usize) -> Vec, +) { + if ::name().contains("unstable") { + // It would be great to mark the test as skipped, but that isn't possible as of now. + return; + } + + let pattern = pattern_fn(len); + + let mut counts = [0i32; 128]; + + // create a vector like [(6, 1), (5, 1), (6, 2), ...], + // where the first item of each tuple is random, but + // the second item represents which occurrence of that + // number this element is, i.e., the second elements + // will occur in sorted order. + let orig: Vec<_> = pattern + .iter() + .map(|val| { + let n = val.saturating_abs() % counts.len() as i32; + counts[n as usize] += 1; + (type_into_fn(n), counts[n as usize]) + }) + .collect(); + + let mut v = orig.clone(); + // Only sort on the first element, so an unstable sort + // may mix up the counts. + ::sort(&mut v); + + // This comparison includes the count (the second item + // of the tuple), so elements with equal first items + // will need to be ordered with increasing + // counts... 
i.e., exactly asserting that this sort is
+    // stable.
+    assert!(v.windows(2).all(|w| w[0] <= w[1]));
+}
+
+gen_sort_test_fns_with_default_patterns_3_ty!(stability, stability_with_patterns, []);
+
+fn observable_is_less<S: Sort>(len: usize, pattern_fn: fn(usize) -> Vec<i32>) {
+    // This test checks that every call to is_less is actually observable. I.e. this can go wrong
+    // if a hole is created using temporary memory and the hole is used for comparison but not
+    // copied back.
+    //
+    // If this is not upheld a custom type + comparison function could yield UB in otherwise safe
+    // code. E.g. T == Mutex<Option<Box<str>>> which replaces the pointer with None in the
+    // comparison function, which would not be observed in the original slice and would lead to a
+    // double free.
+
+    let pattern = pattern_fn(len);
+    let mut test_input = pattern.into_iter().map(|val| CompCount::new(val)).collect::<Vec<_>>();
+
+    let mut comp_count_global = 0;
+
+    <S as Sort>::sort_by(&mut test_input, |a, b| {
+        a.comp_count.replace(a.comp_count.get() + 1);
+        b.comp_count.replace(b.comp_count.get() + 1);
+        comp_count_global += 1;
+
+        a.val.cmp(&b.val)
+    });
+
+    let total_inner: u64 = test_input.iter().map(|c| c.comp_count.get() as u64).sum();
+
+    assert_eq!(total_inner, comp_count_global * 2);
+}
+
+gen_sort_test_fns_with_default_patterns!(
+    observable_is_less,
+    observable_is_less::<S>,
+    &TEST_LENGTHS[..TEST_LENGTHS.len() - 2],
+    []
+);
+
+fn panic_retain_orig_set<T: Ord + Clone, S: Sort>(
+    len: usize,
+    type_into_fn: impl Fn(i32) -> T + Copy,
+    type_from_fn: impl Fn(&T) -> i32,
+    pattern_fn: fn(usize) -> Vec<i32>,
+) {
+    let mut test_data: Vec<T> = pattern_fn(len).into_iter().map(type_into_fn).collect();
+
+    let sum_before: i64 = test_data.iter().map(|x| type_from_fn(x) as i64).sum();
+
+    // Calculate a specific comparison that should panic.
+    // Ensure that it can be any of the possible comparisons and that it always panics.
+    let required_comps = calc_comps_required::<T, S>(&mut test_data.clone(), |a, b| a.cmp(b));
+    let panic_threshold = patterns::random_uniform(1, 1..=required_comps as i32)[0] as usize - 1;
+
+    let mut comp_counter = 0;
+
+    let res = panic::catch_unwind(AssertUnwindSafe(|| {
+        <S as Sort>::sort_by(&mut test_data, |a, b| {
+            if comp_counter == panic_threshold {
+                // Make the panic dependent on the test len and some random factor. We want to
+                // make sure that panicking may also happen when comparing elements a second
+                // time.
+                panic!();
+            }
+            comp_counter += 1;
+
+            a.cmp(b)
+        });
+    }));
+
+    assert!(res.is_err());
+
+    // If the sum before and after don't match, it means the set of elements hasn't remained the
+    // same.
+    let sum_after: i64 = test_data.iter().map(|x| type_from_fn(x) as i64).sum();
+    assert_eq!(sum_before, sum_after);
+}
+
+gen_sort_test_fns_with_default_patterns_3_ty!(panic_retain_orig_set, panic_retain_orig_set, []);
+
+fn panic_observable_is_less<S: Sort>(len: usize, pattern_fn: fn(usize) -> Vec<i32>) {
+    // This test checks that every call to is_less is actually observable. I.e. this can go wrong
+    // if a hole is created using temporary memory and the hole is used for comparison but not
+    // copied back. This property must also hold if the user-provided comparison panics.
+    //
+    // If this is not upheld a custom type + comparison function could yield UB in otherwise safe
+    // code. E.g. T == Mutex<Option<Box<str>>> which replaces the pointer with None in the
+    // comparison function, which would not be observed in the original slice and would lead to a
+    // double free.
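+    //
+    // For illustration only (this test instead tracks comparison counts via `CompCount` below),
+    // such an adversarial comparison could look like:
+    //
+    //     let cmp = |a: &Mutex<Option<Box<str>>>, b: &Mutex<Option<Box<str>>>| {
+    //         let stolen = a.lock().unwrap().take(); // side effect that must stay observable
+    //         drop(stolen); // if the sort compared a stale copy of `a`, the slice element still
+    //                       // "owns" this Box and frees it again on drop: a double free
+    //         a.lock().unwrap().is_some().cmp(&b.lock().unwrap().is_some())
+    //     };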
+ + let mut test_input = + pattern_fn(len).into_iter().map(|val| CompCount::new(val)).collect::>(); + + let sum_before: i64 = test_input.iter().map(|x| x.val as i64).sum(); + + // Calculate a specific comparison that should panic. + // Ensure that it can be any of the possible comparisons and that it always panics. + let required_comps = + calc_comps_required::(&mut test_input.clone(), |a, b| a.val.cmp(&b.val)); + + let panic_threshold = patterns::random_uniform(1, 1..=required_comps as i32)[0] as u64 - 1; + + let mut comp_count_global = 0; + + let res = panic::catch_unwind(AssertUnwindSafe(|| { + ::sort_by(&mut test_input, |a, b| { + if comp_count_global == panic_threshold { + // Make the panic dependent on the test len and some random factor. We want to + // make sure that panicking may also happen when comparing elements a second + // time. + panic!(); + } + + a.comp_count.replace(a.comp_count.get() + 1); + b.comp_count.replace(b.comp_count.get() + 1); + comp_count_global += 1; + + a.val.cmp(&b.val) + }); + })); + + assert!(res.is_err()); + + let total_inner: u64 = test_input.iter().map(|c| c.comp_count.get() as u64).sum(); + + assert_eq!(total_inner, comp_count_global * 2); + + // If the sum before and after don't match, it means the set of elements hasn't remained the + // same. + let sum_after: i64 = test_input.iter().map(|x| x.val as i64).sum(); + assert_eq!(sum_before, sum_after); +} + +gen_sort_test_fns_with_default_patterns!( + panic_observable_is_less, + panic_observable_is_less::, + &TEST_LENGTHS[..TEST_LENGTHS.len() - 2], + [] +); + +fn deterministic( + len: usize, + type_into_fn: impl Fn(i32) -> T + Copy, + type_from_fn: impl Fn(&T) -> i32, + pattern_fn: fn(usize) -> Vec, +) { + // A property similar to stability is deterministic output order. If the entire value is used as + // the comparison key a lack of determinism has no effect. But if only a part of the value is + // used as comparison key, a lack of determinism can manifest itself in the order of values + // considered equal by the comparison predicate. + // + // This test only tests that results are deterministic across runs, it does not test determinism + // on different platforms and with different toolchains. + + let mut test_input = + pattern_fn(len).into_iter().map(|val| type_into_fn(val)).collect::>(); + + let mut test_input_clone = test_input.clone(); + + let comparison_fn = |a: &T, b: &T| { + let a_i32 = type_from_fn(a); + let b_i32 = type_from_fn(b); + + let a_i32_key_space_reduced = a_i32 % 10_000; + let b_i32_key_space_reduced = b_i32 % 10_000; + + a_i32_key_space_reduced.cmp(&b_i32_key_space_reduced) + }; + + ::sort_by(&mut test_input, comparison_fn); + ::sort_by(&mut test_input_clone, comparison_fn); + + assert_eq!(test_input, test_input_clone); +} + +gen_sort_test_fns_with_default_patterns_3_ty!(deterministic, deterministic, []); + +fn self_cmp( + len: usize, + type_into_fn: impl Fn(i32) -> T + Copy, + _type_from_fn: impl Fn(&T) -> i32, + pattern_fn: fn(usize) -> Vec, +) { + // It's possible for comparisons to run into problems if the values of `a` and `b` passed into + // the comparison function are the same reference. So this tests that they never are. + + let mut test_input = + pattern_fn(len).into_iter().map(|val| type_into_fn(val)).collect::>(); + + let comparison_fn = |a: &T, b: &T| { + assert_ne!(a as *const T as usize, b as *const T as usize); + a.cmp(b) + }; + + ::sort_by(&mut test_input, comparison_fn); + + // Check that the output is actually sorted and wasn't stopped by the assert. 
+ for window in test_input.windows(2) { + assert!(window[0] <= window[1]); + } +} + +gen_sort_test_fns_with_default_patterns_3_ty!(self_cmp, self_cmp, []); + +fn violate_ord_retain_orig_set( + len: usize, + type_into_fn: impl Fn(i32) -> T + Copy, + type_from_fn: impl Fn(&T) -> i32, + pattern_fn: fn(usize) -> Vec, +) { + // A user may implement Ord incorrectly for a type or violate it by calling sort_by with a + // comparison function that violates Ord with the orderings it returns. Even under such + // circumstances the input must retain its original set of elements. + + // Ord implies a strict total order see https://en.wikipedia.org/wiki/Total_order. + + // Generating random numbers with miri is quite expensive. + let random_orderings_len = if cfg!(miri) { 200 } else { 10_000 }; + + // Make sure we get a good distribution of random orderings, that are repeatable with the seed. + // Just using random_uniform with the same len and range will always yield the same value. + let random_orderings = patterns::random_uniform(random_orderings_len, 0..2); + + let get_random_0_1_or_2 = |random_idx: &mut usize| { + let ridx = *random_idx; + *random_idx += 1; + if ridx + 1 == random_orderings.len() { + *random_idx = 0; + } + + random_orderings[ridx] as usize + }; + + let mut random_idx_a = 0; + let mut random_idx_b = 0; + let mut random_idx_c = 0; + + let mut last_element_a = -1; + let mut last_element_b = -1; + + let mut rand_counter_b = 0; + let mut rand_counter_c = 0; + + let mut streak_counter_a = 0; + let mut streak_counter_b = 0; + + // Examples, a = 3, b = 5, c = 9. + // Correct Ord -> 10010 | is_less(a, b) is_less(a, a) is_less(b, a) is_less(a, c) is_less(c, a) + let mut invalid_ord_comp_functions: Vec Ordering>> = vec![ + Box::new(|_a, _b| -> Ordering { + // random + // Eg. is_less(3, 5) == true, is_less(3, 5) == false + + let idx = get_random_0_1_or_2(&mut random_idx_a); + [Ordering::Less, Ordering::Equal, Ordering::Greater][idx] + }), + Box::new(|_a, _b| -> Ordering { + // everything is less -> 11111 + Ordering::Less + }), + Box::new(|_a, _b| -> Ordering { + // everything is equal -> 00000 + Ordering::Equal + }), + Box::new(|_a, _b| -> Ordering { + // everything is greater -> 00000 + // Eg. is_less(3, 5) == false, is_less(5, 3) == false, is_less(3, 3) == false + Ordering::Greater + }), + Box::new(|a, b| -> Ordering { + // equal means less else greater -> 01000 + if a == b { Ordering::Less } else { Ordering::Greater } + }), + Box::new(|a, b| -> Ordering { + // Transitive breaker. remember last element -> 10001 + let lea = last_element_a; + let leb = last_element_b; + + let a_as_i32 = type_from_fn(a); + let b_as_i32 = type_from_fn(b); + + last_element_a = a_as_i32; + last_element_b = b_as_i32; + + if a_as_i32 == lea && b_as_i32 != leb { b.cmp(a) } else { a.cmp(b) } + }), + Box::new(|a, b| -> Ordering { + // Sampled random 1% of comparisons are reversed. + rand_counter_b += get_random_0_1_or_2(&mut random_idx_b); + if rand_counter_b >= 100 { + rand_counter_b = 0; + b.cmp(a) + } else { + a.cmp(b) + } + }), + Box::new(|a, b| -> Ordering { + // Sampled random 33% of comparisons are reversed. + rand_counter_c += get_random_0_1_or_2(&mut random_idx_c); + if rand_counter_c >= 3 { + rand_counter_c = 0; + b.cmp(a) + } else { + a.cmp(b) + } + }), + Box::new(|a, b| -> Ordering { + // STREAK_LEN comparisons yield a.cmp(b) then STREAK_LEN comparisons less. This can + // discover bugs that neither, random Ord, or just Less or Greater can find. Because it + // can push a pointer further than expected. 
Random Ord will average out how far a + // comparison based pointer travels. Just Less or Greater will be caught by pattern + // analysis and never enter interesting code. + const STREAK_LEN: usize = 50; + + streak_counter_a += 1; + if streak_counter_a <= STREAK_LEN { + a.cmp(b) + } else { + if streak_counter_a == STREAK_LEN * 2 { + streak_counter_a = 0; + } + Ordering::Less + } + }), + Box::new(|a, b| -> Ordering { + // See above. + const STREAK_LEN: usize = 50; + + streak_counter_b += 1; + if streak_counter_b <= STREAK_LEN { + a.cmp(b) + } else { + if streak_counter_b == STREAK_LEN * 2 { + streak_counter_b = 0; + } + Ordering::Greater + } + }), + ]; + + for comp_func in &mut invalid_ord_comp_functions { + let mut test_data: Vec = pattern_fn(len).into_iter().map(type_into_fn).collect(); + let sum_before: i64 = test_data.iter().map(|x| type_from_fn(x) as i64).sum(); + + // It's ok to panic on Ord violation or to complete. + // In both cases the original elements must still be present. + let _ = panic::catch_unwind(AssertUnwindSafe(|| { + ::sort_by(&mut test_data, &mut *comp_func); + })); + + // If the sum before and after don't match, it means the set of elements hasn't remained the + // same. + let sum_after: i64 = test_data.iter().map(|x| type_from_fn(x) as i64).sum(); + assert_eq!(sum_before, sum_after); + + if cfg!(miri) { + // This test is prohibitively expensive in miri, so only run one of the comparison + // functions. This test is not expected to yield direct UB, but rather surface potential + // UB by showing that the sum is different now. + break; + } + } +} + +gen_sort_test_fns_with_default_patterns_3_ty!( + violate_ord_retain_orig_set, + violate_ord_retain_orig_set, + [] +); + +macro_rules! instantiate_sort_test_inner { + ($sort_impl:ty, miri_yes, $test_fn_name:ident) => { + #[test] + fn $test_fn_name() { + $crate::sort::tests::$test_fn_name::<$sort_impl>(); + } + }; + ($sort_impl:ty, miri_no, $test_fn_name:ident) => { + #[test] + #[cfg_attr(miri, ignore)] + fn $test_fn_name() { + $crate::sort::tests::$test_fn_name::<$sort_impl>(); + } + }; +} + +// Using this construct allows us to get warnings for unused test functions. +macro_rules! define_instantiate_sort_tests { + ($([$miri_use:ident, $test_fn_name:ident]),*,) => { + $(pub fn $test_fn_name() { + ${concat($test_fn_name, _impl)}::(); + })* + + + macro_rules! instantiate_sort_tests_gen { + ($sort_impl:ty) => { + $( + instantiate_sort_test_inner!( + $sort_impl, + $miri_use, + $test_fn_name + ); + )* + } + } + }; +} + +// Some tests are not tested with miri to avoid prohibitively long test times. This leaves coverage +// holes, but the way they are selected should make for relatively small holes. Many properties that +// can lead to UB are tested directly, for example that the original set of elements is retained +// even when a panic occurs or Ord is implemented incorrectly. 
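+//
+// Entries marked `miri_no` below are instantiated with `#[cfg_attr(miri, ignore)]` (see
+// `instantiate_sort_test_inner` above), while `miri_yes` entries also run under miri.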
+define_instantiate_sort_tests!(
+    [miri_yes, basic],
+    [miri_yes, fixed_seed],
+    [miri_yes, fixed_seed_rand_vec_prefix],
+    [miri_yes, int_edge],
+    [miri_yes, sort_vs_sort_by],
+    [miri_yes, correct_i32_random],
+    [miri_yes, correct_i32_random_z1],
+    [miri_yes, correct_i32_random_d2],
+    [miri_yes, correct_i32_random_d20],
+    [miri_yes, correct_i32_random_s95],
+    [miri_yes, correct_i32_ascending],
+    [miri_yes, correct_i32_descending],
+    [miri_yes, correct_i32_saw_mixed],
+    [miri_no, correct_i32_random_d4],
+    [miri_no, correct_i32_random_d8],
+    [miri_no, correct_i32_random_d311],
+    [miri_no, correct_i32_random_d1024],
+    [miri_no, correct_i32_random_z1_03],
+    [miri_no, correct_i32_random_z2],
+    [miri_no, correct_i32_random_s50],
+    [miri_no, correct_i32_narrow],
+    [miri_no, correct_i32_all_equal],
+    [miri_no, correct_i32_saw_mixed_range],
+    [miri_yes, correct_i32_pipe_organ],
+    [miri_no, correct_u64_random],
+    [miri_yes, correct_u64_random_z1],
+    [miri_no, correct_u64_random_d2],
+    [miri_no, correct_u64_random_d20],
+    [miri_no, correct_u64_random_s95],
+    [miri_no, correct_u64_ascending],
+    [miri_no, correct_u64_descending],
+    [miri_no, correct_u64_saw_mixed],
+    [miri_no, correct_u128_random],
+    [miri_yes, correct_u128_random_z1],
+    [miri_no, correct_u128_random_d2],
+    [miri_no, correct_u128_random_d20],
+    [miri_no, correct_u128_random_s95],
+    [miri_no, correct_u128_ascending],
+    [miri_no, correct_u128_descending],
+    [miri_no, correct_u128_saw_mixed],
+    [miri_no, correct_cell_i32_random],
+    [miri_yes, correct_cell_i32_random_z1],
+    [miri_no, correct_cell_i32_random_d2],
+    [miri_no, correct_cell_i32_random_d20],
+    [miri_no, correct_cell_i32_random_s95],
+    [miri_no, correct_cell_i32_ascending],
+    [miri_no, correct_cell_i32_descending],
+    [miri_no, correct_cell_i32_saw_mixed],
+    [miri_no, correct_string_random],
+    [miri_yes, correct_string_random_z1],
+    [miri_no, correct_string_random_d2],
+    [miri_no, correct_string_random_d20],
+    [miri_no, correct_string_random_s95],
+    [miri_no, correct_string_ascending],
+    [miri_no, correct_string_descending],
+    [miri_no, correct_string_saw_mixed],
+    [miri_no, correct_f128_random],
+    [miri_yes, correct_f128_random_z1],
+    [miri_no, correct_f128_random_d2],
+    [miri_no, correct_f128_random_d20],
+    [miri_no, correct_f128_random_s95],
+    [miri_no, correct_f128_ascending],
+    [miri_no, correct_f128_descending],
+    [miri_no, correct_f128_saw_mixed],
+    [miri_no, correct_1k_random],
+    [miri_yes, correct_1k_random_z1],
+    [miri_no, correct_1k_random_d2],
+    [miri_no, correct_1k_random_d20],
+    [miri_no, correct_1k_random_s95],
+    [miri_no, correct_1k_ascending],
+    [miri_no, correct_1k_descending],
+    [miri_no, correct_1k_saw_mixed],
+    [miri_no, correct_dyn_val_random],
+    [miri_yes, correct_dyn_val_random_z1],
+    [miri_no, correct_dyn_val_random_d2],
+    [miri_no, correct_dyn_val_random_d20],
+    [miri_no, correct_dyn_val_random_s95],
+    [miri_no, correct_dyn_val_ascending],
+    [miri_no, correct_dyn_val_descending],
+    [miri_no, correct_dyn_val_saw_mixed],
+    [miri_no, stability_legacy],
+    [miri_no, stability_i32_random],
+    [miri_yes, stability_i32_random_z1],
+    [miri_no, stability_i32_random_d2],
+    [miri_no, stability_i32_random_d20],
+    [miri_no, stability_i32_random_s95],
+    [miri_no, stability_i32_ascending],
+    [miri_no, stability_i32_descending],
+    [miri_no, stability_i32_saw_mixed],
+    [miri_no, stability_cell_i32_random],
+    [miri_yes, stability_cell_i32_random_z1],
+    [miri_no, stability_cell_i32_random_d2],
+    [miri_no, stability_cell_i32_random_d20],
+    [miri_no, stability_cell_i32_random_s95],
+    [miri_no, stability_cell_i32_ascending],
+    [miri_no, stability_cell_i32_descending],
+    [miri_no, stability_cell_i32_saw_mixed],
+    [miri_no, stability_string_random],
+    [miri_yes, stability_string_random_z1],
+    [miri_no, stability_string_random_d2],
+    [miri_no, stability_string_random_d20],
+    [miri_no, stability_string_random_s95],
+    [miri_no, stability_string_ascending],
+    [miri_no, stability_string_descending],
+    [miri_no, stability_string_saw_mixed],
+    [miri_no, observable_is_less_random],
+    [miri_yes, observable_is_less_random_z1],
+    [miri_no, observable_is_less_random_d2],
+    [miri_no, observable_is_less_random_d20],
+    [miri_no, observable_is_less_random_s95],
+    [miri_no, observable_is_less_ascending],
+    [miri_no, observable_is_less_descending],
+    [miri_no, observable_is_less_saw_mixed],
+    [miri_no, panic_retain_orig_set_i32_random],
+    [miri_yes, panic_retain_orig_set_i32_random_z1],
+    [miri_no, panic_retain_orig_set_i32_random_d2],
+    [miri_no, panic_retain_orig_set_i32_random_d20],
+    [miri_no, panic_retain_orig_set_i32_random_s95],
+    [miri_no, panic_retain_orig_set_i32_ascending],
+    [miri_no, panic_retain_orig_set_i32_descending],
+    [miri_no, panic_retain_orig_set_i32_saw_mixed],
+    [miri_no, panic_retain_orig_set_cell_i32_random],
+    [miri_yes, panic_retain_orig_set_cell_i32_random_z1],
+    [miri_no, panic_retain_orig_set_cell_i32_random_d2],
+    [miri_no, panic_retain_orig_set_cell_i32_random_d20],
+    [miri_no, panic_retain_orig_set_cell_i32_random_s95],
+    [miri_no, panic_retain_orig_set_cell_i32_ascending],
+    [miri_no, panic_retain_orig_set_cell_i32_descending],
+    [miri_no, panic_retain_orig_set_cell_i32_saw_mixed],
+    [miri_no, panic_retain_orig_set_string_random],
+    [miri_yes, panic_retain_orig_set_string_random_z1],
+    [miri_no, panic_retain_orig_set_string_random_d2],
+    [miri_no, panic_retain_orig_set_string_random_d20],
+    [miri_no, panic_retain_orig_set_string_random_s95],
+    [miri_no, panic_retain_orig_set_string_ascending],
+    [miri_no, panic_retain_orig_set_string_descending],
+    [miri_no, panic_retain_orig_set_string_saw_mixed],
+    [miri_no, panic_observable_is_less_random],
+    [miri_yes, panic_observable_is_less_random_z1],
+    [miri_no, panic_observable_is_less_random_d2],
+    [miri_no, panic_observable_is_less_random_d20],
+    [miri_no, panic_observable_is_less_random_s95],
+    [miri_no, panic_observable_is_less_ascending],
+    [miri_no, panic_observable_is_less_descending],
+    [miri_no, panic_observable_is_less_saw_mixed],
+    [miri_no, deterministic_i32_random],
+    [miri_yes, deterministic_i32_random_z1],
+    [miri_no, deterministic_i32_random_d2],
+    [miri_no, deterministic_i32_random_d20],
+    [miri_no, deterministic_i32_random_s95],
+    [miri_no, deterministic_i32_ascending],
+    [miri_no, deterministic_i32_descending],
+    [miri_no, deterministic_i32_saw_mixed],
+    [miri_no, deterministic_cell_i32_random],
+    [miri_yes, deterministic_cell_i32_random_z1],
+    [miri_no, deterministic_cell_i32_random_d2],
+    [miri_no, deterministic_cell_i32_random_d20],
+    [miri_no, deterministic_cell_i32_random_s95],
+    [miri_no, deterministic_cell_i32_ascending],
+    [miri_no, deterministic_cell_i32_descending],
+    [miri_no, deterministic_cell_i32_saw_mixed],
+    [miri_no, deterministic_string_random],
+    [miri_yes, deterministic_string_random_z1],
+    [miri_no, deterministic_string_random_d2],
+    [miri_no, deterministic_string_random_d20],
+    [miri_no, deterministic_string_random_s95],
+    [miri_no, deterministic_string_ascending],
+    [miri_no, deterministic_string_descending],
+    [miri_no, deterministic_string_saw_mixed],
+    [miri_no, self_cmp_i32_random],
+    [miri_yes, self_cmp_i32_random_z1],
+    [miri_no, self_cmp_i32_random_d2],
+    [miri_no, self_cmp_i32_random_d20],
+    [miri_no, self_cmp_i32_random_s95],
+    [miri_no, self_cmp_i32_ascending],
+    [miri_no, self_cmp_i32_descending],
+    [miri_no, self_cmp_i32_saw_mixed],
+    [miri_no, self_cmp_cell_i32_random],
+    [miri_yes, self_cmp_cell_i32_random_z1],
+    [miri_no, self_cmp_cell_i32_random_d2],
+    [miri_no, self_cmp_cell_i32_random_d20],
+    [miri_no, self_cmp_cell_i32_random_s95],
+    [miri_no, self_cmp_cell_i32_ascending],
+    [miri_no, self_cmp_cell_i32_descending],
+    [miri_no, self_cmp_cell_i32_saw_mixed],
+    [miri_no, self_cmp_string_random],
+    [miri_yes, self_cmp_string_random_z1],
+    [miri_no, self_cmp_string_random_d2],
+    [miri_no, self_cmp_string_random_d20],
+    [miri_no, self_cmp_string_random_s95],
+    [miri_no, self_cmp_string_ascending],
+    [miri_no, self_cmp_string_descending],
+    [miri_no, self_cmp_string_saw_mixed],
+    [miri_no, violate_ord_retain_orig_set_i32_random],
+    [miri_yes, violate_ord_retain_orig_set_i32_random_z1],
+    [miri_no, violate_ord_retain_orig_set_i32_random_d2],
+    [miri_no, violate_ord_retain_orig_set_i32_random_d20],
+    [miri_no, violate_ord_retain_orig_set_i32_random_s95],
+    [miri_no, violate_ord_retain_orig_set_i32_ascending],
+    [miri_no, violate_ord_retain_orig_set_i32_descending],
+    [miri_no, violate_ord_retain_orig_set_i32_saw_mixed],
+    [miri_no, violate_ord_retain_orig_set_cell_i32_random],
+    [miri_yes, violate_ord_retain_orig_set_cell_i32_random_z1],
+    [miri_no, violate_ord_retain_orig_set_cell_i32_random_d2],
+    [miri_no, violate_ord_retain_orig_set_cell_i32_random_d20],
+    [miri_no, violate_ord_retain_orig_set_cell_i32_random_s95],
+    [miri_no, violate_ord_retain_orig_set_cell_i32_ascending],
+    [miri_no, violate_ord_retain_orig_set_cell_i32_descending],
+    [miri_no, violate_ord_retain_orig_set_cell_i32_saw_mixed],
+    [miri_no, violate_ord_retain_orig_set_string_random],
+    [miri_yes, violate_ord_retain_orig_set_string_random_z1],
+    [miri_no, violate_ord_retain_orig_set_string_random_d2],
+    [miri_no, violate_ord_retain_orig_set_string_random_d20],
+    [miri_no, violate_ord_retain_orig_set_string_random_s95],
+    [miri_no, violate_ord_retain_orig_set_string_ascending],
+    [miri_no, violate_ord_retain_orig_set_string_descending],
+    [miri_no, violate_ord_retain_orig_set_string_saw_mixed],
+);
+
+macro_rules! instantiate_sort_tests {
+    ($sort_impl:ty) => {
+        instantiate_sort_tests_gen!($sort_impl);
+    };
+}
+
+mod unstable {
+    struct SortImpl {}
+
+    impl crate::sort::Sort for SortImpl {
+        fn name() -> String {
+            "rust_std_unstable".into()
+        }
+
+        fn sort<T>(v: &mut [T])
+        where
+            T: Ord,
+        {
+            v.sort_unstable();
+        }
+
+        fn sort_by<T, F>(v: &mut [T], mut compare: F)
+        where
+            F: FnMut(&T, &T) -> std::cmp::Ordering,
+        {
+            v.sort_unstable_by(|a, b| compare(a, b));
+        }
+    }
+
+    instantiate_sort_tests!(SortImpl);
+}
+
+mod stable {
+    struct SortImpl {}
+
+    impl crate::sort::Sort for SortImpl {
+        fn name() -> String {
+            "rust_std_stable".into()
+        }
+
+        fn sort<T>(v: &mut [T])
+        where
+            T: Ord,
+        {
+            v.sort();
+        }
+
+        fn sort_by<T, F>(v: &mut [T], mut compare: F)
+        where
+            F: FnMut(&T, &T) -> std::cmp::Ordering,
+        {
+            v.sort_by(|a, b| compare(a, b));
+        }
+    }
+
+    instantiate_sort_tests!(SortImpl);
+}
diff --git a/library/alloc/tests/sort/zipf.rs b/library/alloc/tests/sort/zipf.rs
new file mode 100644
index 0000000000000..cc774ee5c43bf
--- /dev/null
+++ b/library/alloc/tests/sort/zipf.rs
@@ -0,0 +1,208 @@
+// This module implements a Zipfian distribution generator.
+//
+// Based on https://github.com/jonhoo/rust-zipf.
+
+use rand::Rng;
+
+/// Random number generator that generates Zipf-distributed random numbers using rejection
+/// inversion.
+#[derive(Clone, Copy)]
+pub struct ZipfDistribution {
+    /// Number of elements
+    num_elements: f64,
+    /// Exponent parameter of the distribution
+    exponent: f64,
+    /// `hIntegral(1.5) - 1`
+    h_integral_x1: f64,
+    /// `hIntegral(num_elements + 0.5)`
+    h_integral_num_elements: f64,
+    /// `2 - hIntegralInverse(hIntegral(2.5) - h(2))`
+    s: f64,
+}
+
+impl ZipfDistribution {
+    /// Creates a new [Zipf-distributed](https://en.wikipedia.org/wiki/Zipf's_law)
+    /// random number generator.
+    ///
+    /// Note that both the number of elements and the exponent must be greater than 0.
+    pub fn new(num_elements: usize, exponent: f64) -> Result<Self, ()> {
+        if num_elements == 0 {
+            return Err(());
+        }
+        if exponent <= 0f64 {
+            return Err(());
+        }
+
+        let z = ZipfDistribution {
+            num_elements: num_elements as f64,
+            exponent,
+            h_integral_x1: ZipfDistribution::h_integral(1.5, exponent) - 1f64,
+            h_integral_num_elements: ZipfDistribution::h_integral(
+                num_elements as f64 + 0.5,
+                exponent,
+            ),
+            s: 2f64
+                - ZipfDistribution::h_integral_inv(
+                    ZipfDistribution::h_integral(2.5, exponent)
+                        - ZipfDistribution::h(2f64, exponent),
+                    exponent,
+                ),
+        };
+
+        Ok(z)
+    }
+}
+
+impl ZipfDistribution {
+    fn next<R: Rng + ?Sized>(&self, rng: &mut R) -> usize {
+        // The paper describes an algorithm for exponents larger than 1 (Algorithm ZRI).
+        //
+        // The original method uses
+        //   H(x) = (v + x)^(1 - q) / (1 - q)
+        // as the integral of the hat function.
+        //
+        // This function is undefined for q = 1, which is the reason for the limitation of the
+        // exponent.
+        //
+        // If instead the integral function
+        //   H(x) = ((v + x)^(1 - q) - 1) / (1 - q)
+        // is used, for which a meaningful limit exists for q = 1, the method works for all
+        // positive exponents.
+        //
+        // The following implementation uses v = 0 and generates integers in the range [1,
+        // num_elements]. This is different from the original method, where v is defined to
+        // be positive and numbers are taken from [0, i_max]. This explains why the implementation
+        // looks slightly different.
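+        //
+        // As a concrete check of the limit argument above (illustrative note, not from the
+        // paper): with v = 0, H(x) = (x^(1 - q) - 1) / (1 - q) tends to ln(x) as q -> 1, and
+        // `h_integral` below evaluates this expression via `helper2`, so the q = 1 case needs
+        // no special handling.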
+
+        let hnum = self.h_integral_num_elements;
+
+        loop {
+            use std::cmp;
+            let u: f64 = hnum + rng.gen::<f64>() * (self.h_integral_x1 - hnum);
+            // u is uniformly distributed in (h_integral_x1, h_integral_num_elements]
+
+            let x: f64 = ZipfDistribution::h_integral_inv(u, self.exponent);
+
+            // Limit k to the range [1, num_elements] if it would be outside
+            // due to numerical inaccuracies.
+            let k64 = x.max(1.0).min(self.num_elements);
+            // float -> integer rounds towards zero, so we add 0.5
+            // to prevent bias towards k == 1
+            let k = cmp::max(1, (k64 + 0.5) as usize);
+
+            // Here, the distribution of k is given by:
+            //
+            //   P(k = 1) = C * (hIntegral(1.5) - h_integral_x1) = C
+            //   P(k = m) = C * (hIntegral(m + 1/2) - hIntegral(m - 1/2)) for m >= 2
+            //
+            //   where C = 1 / (h_integral_num_elements - h_integral_x1)
+            if k64 - x <= self.s
+                || u >= ZipfDistribution::h_integral(k64 + 0.5, self.exponent)
+                    - ZipfDistribution::h(k64, self.exponent)
+            {
+                // Case k = 1:
+                //
+                // The right inequality is always true, because replacing k by 1 gives
+                // u >= hIntegral(1.5) - h(1) = h_integral_x1 and u is taken from
+                // (h_integral_x1, h_integral_num_elements].
+                //
+                // Therefore, the acceptance rate for k = 1 is P(accepted | k = 1) = 1
+                // and the probability that 1 is returned as random value is
+                // P(k = 1 and accepted) = P(accepted | k = 1) * P(k = 1) = C = C / 1^exponent
+                //
+                // Case k >= 2:
+                //
+                // The left inequality (k - x <= s) is just a shortcut
+                // to avoid the more expensive evaluation of the right inequality
+                // (u >= hIntegral(k + 0.5) - h(k)) in many cases.
+                //
+                // If the left inequality is true, the right inequality is also true:
+                // Theorem 2 in the paper is valid for all positive exponents, because
+                // the requirements h'(x) = -exponent/x^(exponent + 1) < 0 and
+                // (-1/hInverse'(x))'' = (1+1/exponent) * x^(1/exponent-1) >= 0
+                // are both fulfilled.
+                // Therefore, f(x) = x - hIntegralInverse(hIntegral(x + 0.5) - h(x))
+                // is a non-decreasing function. If k - x <= s holds,
+                // k - x <= s + f(k) - f(2) is obviously also true which is equivalent to
+                // -x <= -hIntegralInverse(hIntegral(k + 0.5) - h(k)),
+                // -hIntegralInverse(u) <= -hIntegralInverse(hIntegral(k + 0.5) - h(k)),
+                // and finally u >= hIntegral(k + 0.5) - h(k).
+                //
+                // Hence, the right inequality determines the acceptance rate:
+                // P(accepted | k = m) = h(m) / (hIntegral(m+1/2) - hIntegral(m-1/2))
+                // The probability that m is returned is given by
+                // P(k = m and accepted) = P(accepted | k = m) * P(k = m)
+                //                       = C * h(m) = C / m^exponent.
+                //
+                // In both cases the probabilities are proportional to the probability mass
+                // function of the Zipf distribution.
+
+                return k;
+            }
+        }
+    }
+}
+
+impl rand::distributions::Distribution<usize> for ZipfDistribution {
+    fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> usize {
+        self.next(rng)
+    }
+}
+
+use std::fmt;
+impl fmt::Debug for ZipfDistribution {
+    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> Result<(), fmt::Error> {
+        f.debug_struct("ZipfDistribution")
+            .field("e", &self.exponent)
+            .field("n", &self.num_elements)
+            .finish()
+    }
+}
+
+impl ZipfDistribution {
+    /// Computes `H(x)`, defined as
+    ///
+    /// - `(x^(1 - exponent) - 1) / (1 - exponent)`, if `exponent != 1`
+    /// - `log(x)`, if `exponent == 1`
+    ///
+    /// `H(x)` is an antiderivative of `h(x)`; the derivative of `H(x)` is `h(x)`.
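+    ///
+    /// Note: both branches are computed by the single expression below, since with
+    /// `t = (1 - exponent) * ln(x)` we have
+    /// `helper2(t) * ln(x) = (x^(1 - exponent) - 1) / (1 - exponent)`,
+    /// and `helper2(0) = 1` recovers `ln(x)` for `exponent == 1`.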
+    fn h_integral(x: f64, exponent: f64) -> f64 {
+        let log_x = x.ln();
+        helper2((1f64 - exponent) * log_x) * log_x
+    }
+
+    /// Computes `h(x) = 1 / x^exponent`
+    fn h(x: f64, exponent: f64) -> f64 {
+        (-exponent * x.ln()).exp()
+    }
+
+    /// The inverse function of `H(x)`.
+    /// Returns the `y` for which `H(y) = x`.
+    fn h_integral_inv(x: f64, exponent: f64) -> f64 {
+        let mut t: f64 = x * (1f64 - exponent);
+        if t < -1f64 {
+            // Limit value to the range [-1, +inf).
+            // t could be smaller than -1 in some rare cases due to numerical errors.
+            t = -1f64;
+        }
+        (helper1(t) * x).exp()
+    }
+}
+
+/// Helper function that calculates `log(1 + x) / x`.
+/// A Taylor series expansion is used if `x` is close to 0.
+fn helper1(x: f64) -> f64 {
+    if x.abs() > 1e-8 { x.ln_1p() / x } else { 1f64 - x * (0.5 - x * (1.0 / 3.0 - 0.25 * x)) }
+}
+
+/// Helper function to calculate `(exp(x) - 1) / x`.
+/// A Taylor series expansion is used if `x` is close to 0.
+fn helper2(x: f64) -> f64 {
+    if x.abs() > 1e-8 {
+        x.exp_m1() / x
+    } else {
+        1f64 + x * 0.5 * (1f64 + x * 1.0 / 3.0 * (1f64 + 0.25 * x))
+    }
+}
diff --git a/library/core/tests/slice.rs b/library/core/tests/slice.rs
index 7197f3812e542..9ae2bcc852649 100644
--- a/library/core/tests/slice.rs
+++ b/library/core/tests/slice.rs
@@ -1800,57 +1800,6 @@ fn brute_force_rotate_test_1() {
     }
 }
 
-#[test]
-#[cfg(not(target_arch = "wasm32"))]
-fn sort_unstable() {
-    use rand::Rng;
-
-    // Miri is too slow (but still need to `chain` to make the types match)
-    let lens = if cfg!(miri) { (2..20).chain(0..0) } else { (2..25).chain(500..510) };
-    let rounds = if cfg!(miri) { 1 } else { 100 };
-
-    let mut v = [0; 600];
-    let mut tmp = [0; 600];
-    let mut rng = crate::test_rng();
-
-    for len in lens {
-        let v = &mut v[0..len];
-        let tmp = &mut tmp[0..len];
-
-        for &modulus in &[5, 10, 100, 1000] {
-            for _ in 0..rounds {
-                for i in 0..len {
-                    v[i] = rng.gen::<i32>() % modulus;
-                }
-
-                // Sort in default order.
-                tmp.copy_from_slice(v);
-                tmp.sort_unstable();
-                assert!(tmp.windows(2).all(|w| w[0] <= w[1]));
-
-                // Sort in ascending order.
-                tmp.copy_from_slice(v);
-                tmp.sort_unstable_by(|a, b| a.cmp(b));
-                assert!(tmp.windows(2).all(|w| w[0] <= w[1]));
-
-                // Sort in descending order.
-                tmp.copy_from_slice(v);
-                tmp.sort_unstable_by(|a, b| b.cmp(a));
-                assert!(tmp.windows(2).all(|w| w[0] >= w[1]));
-            }
-        }
-    }
-
-    // Should not panic.
-    [0i32; 0].sort_unstable();
-    [(); 10].sort_unstable();
-    [(); 100].sort_unstable();
-
-    let mut v = [0xDEADBEEFu64];
-    v.sort_unstable();
-    assert!(v == [0xDEADBEEF]);
-}
-
 #[test]
 #[cfg(not(target_arch = "wasm32"))]
 #[cfg_attr(miri, ignore)] // Miri is too slow