From f297afa0c91243b17283be17864f2c48f91127d9 Mon Sep 17 00:00:00 2001 From: Lukas Bergdoll Date: Sun, 22 Jan 2023 12:01:06 +0100 Subject: [PATCH] Flip scanning direction of stable sort Memory pre-fetching prefers forward scanning vs backwards scanning, and the code-gen is usually better. For the most sensitive types such as integers, these are planned to be merged bidirectionally at once. So there is no benefit in scanning backwards. The largest perf gains are seen for full ascending and descending inputs, which see 1.5x speedups. Random inputs benefit too, and some patterns can lose out, but these losses are minimal. --- library/core/src/slice/sort.rs | 112 ++++++++++++++++++++------------- 1 file changed, 67 insertions(+), 45 deletions(-) diff --git a/library/core/src/slice/sort.rs b/library/core/src/slice/sort.rs index 6bb53b16e610..227db51a0b40 100644 --- a/library/core/src/slice/sort.rs +++ b/library/core/src/slice/sort.rs @@ -1196,52 +1196,37 @@ pub fn merge_sort( let mut runs = RunVec::new(run_alloc_fn, run_dealloc_fn); - // In order to identify natural runs in `v`, we traverse it backwards. That might seem like a - // strange decision, but consider the fact that merges more often go in the opposite direction - // (forwards). According to benchmarks, merging forwards is slightly faster than merging - // backwards. To conclude, identifying runs by traversing backwards improves performance. - let mut end = len; - while end > 0 { - // Find the next natural run, and reverse it if it's strictly descending. - let mut start = end - 1; - if start > 0 { - start -= 1; - - // SAFETY: The v.get_unchecked must be fed with correct inbound indicies. 
- unsafe { - if is_less(v.get_unchecked(start + 1), v.get_unchecked(start)) { - while start > 0 && is_less(v.get_unchecked(start), v.get_unchecked(start - 1)) { - start -= 1; - } - v[start..end].reverse(); - } else { - while start > 0 && !is_less(v.get_unchecked(start), v.get_unchecked(start - 1)) - { - start -= 1; - } - } - } + let mut end = 0; + let mut start = 0; + + // Scan forward. Memory pre-fetching prefers forward scanning vs backwards scanning, and the + // code-gen is usually better. For the most sensitive types such as integers, these are merged + // bidirectionally at once. So there is no benefit in scanning backwards. + while end < len { + let (streak_end, was_reversed) = find_streak(&v[start..], is_less); + end += streak_end; + if was_reversed { + v[start..end].reverse(); } // Insert some more elements into the run if it's too short. Insertion sort is faster than // merge sort on short sequences, so this significantly improves performance. - start = provide_sorted_batch(v, start, end, is_less); + end = provide_sorted_batch(v, start, end, is_less); // Push this run onto the stack. runs.push(TimSortRun { start, len: end - start }); - end = start; + start = end; // Merge some pairs of adjacent runs to satisfy the invariants. - while let Some(r) = collapse(runs.as_slice()) { - let left = runs[r + 1]; - let right = runs[r]; - // SAFETY: `buf_ptr` must hold enough capacity for the shorter of the two sides, and - // neither side may be on length 0. 
+ while let Some(r) = collapse(runs.as_slice(), len) { + let left = runs[r]; + let right = runs[r + 1]; + let merge_slice = &mut v[left.start..right.start + right.len]; unsafe { - merge(&mut v[left.start..right.start + right.len], left.len, buf_ptr, is_less); + merge(merge_slice, left.len, buf_ptr, is_less); } - runs[r] = TimSortRun { start: left.start, len: left.len + right.len }; - runs.remove(r + 1); + runs[r + 1] = TimSortRun { start: left.start, len: left.len + right.len }; + runs.remove(r); } } @@ -1263,10 +1248,10 @@ pub fn merge_sort( // run starts at index 0, it will always demand a merge operation until the stack is fully // collapsed, in order to complete the sort. #[inline] - fn collapse(runs: &[TimSortRun]) -> Option { + fn collapse(runs: &[TimSortRun], stop: usize) -> Option { let n = runs.len(); if n >= 2 - && (runs[n - 1].start == 0 + && (runs[n - 1].start + runs[n - 1].len == stop || runs[n - 2].len <= runs[n - 1].len || (n >= 3 && runs[n - 3].len <= runs[n - 2].len + runs[n - 1].len) || (n >= 4 && runs[n - 4].len <= runs[n - 3].len + runs[n - 2].len)) @@ -1454,14 +1439,15 @@ pub struct TimSortRun { start: usize, } -/// Takes a range as denoted by start and end, that is already sorted and extends it to the left if +/// Takes a range as denoted by start and end, that is already sorted and extends it to the right if /// necessary with sorts optimized for smaller ranges such as insertion sort. #[cfg(not(no_global_oom_handling))] -fn provide_sorted_batch(v: &mut [T], mut start: usize, end: usize, is_less: &mut F) -> usize +fn provide_sorted_batch(v: &mut [T], start: usize, mut end: usize, is_less: &mut F) -> usize where F: FnMut(&T, &T) -> bool, { - debug_assert!(end > start); + let len = v.len(); + assert!(end >= start && end <= len); // This value is a balance between least comparisons and best performance, as // influenced by for example cache locality. @@ -1469,18 +1455,54 @@ where // Insert some more elements into the run if it's too short. 
Insertion sort is faster than // merge sort on short sequences, so this significantly improves performance. - let start_found = start; let start_end_diff = end - start; - if start_end_diff < MIN_INSERTION_RUN && start != 0 { + if start_end_diff < MIN_INSERTION_RUN && end < len { // v[start_found..end] are elements that are already sorted in the input. We want to extend // the sorted region to the left, so we push up MIN_INSERTION_RUN - 1 to the right. Which is // more efficient that trying to push those already sorted elements to the left. + end = cmp::min(start + MIN_INSERTION_RUN, len); + let presorted_start = cmp::max(start_end_diff, 1); - start = if end >= MIN_INSERTION_RUN { end - MIN_INSERTION_RUN } else { 0 }; + insertion_sort_shift_left(&mut v[start..end], presorted_start, is_less); + } - insertion_sort_shift_right(&mut v[start..end], start_found - start, is_less); + end +} + +/// Finds a streak of presorted elements starting at the beginning of the slice. Returns the first +/// value that is not part of said streak, and a bool denoting whether the streak was reversed. +/// Streaks can be increasing or decreasing. +fn find_streak(v: &[T], is_less: &mut F) -> (usize, bool) +where + F: FnMut(&T, &T) -> bool, +{ + let len = v.len(); + + if len < 2 { + return (len, false); } - start + let mut end = 2; + + // SAFETY: See below specific. + unsafe { + // SAFETY: We checked that len >= 2, so 0 and 1 are valid indices. + let assume_reverse = is_less(v.get_unchecked(1), v.get_unchecked(0)); + + // SAFETY: We know end >= 2 and check end < len. + // From that follows that accessing v at end and end - 1 is safe. + if assume_reverse { + while end < len && is_less(v.get_unchecked(end), v.get_unchecked(end - 1)) { + end += 1; + } + + (end, true) + } else { + while end < len && !is_less(v.get_unchecked(end), v.get_unchecked(end - 1)) { + end += 1; + } + (end, false) + } + } }