From e4a10ae406d351a99cf939c997abef09527f567e Mon Sep 17 00:00:00 2001 From: Arthur Carcano Date: Wed, 26 Jun 2024 23:25:54 +0200 Subject: [PATCH] Improve slice::binary_search_by --- library/core/src/slice/mod.rs | 26 ++++++++++++++------------ library/core/tests/slice.rs | 2 +- 2 files changed, 15 insertions(+), 13 deletions(-) diff --git a/library/core/src/slice/mod.rs b/library/core/src/slice/mod.rs index 521c324820446..1921b08fce849 100644 --- a/library/core/src/slice/mod.rs +++ b/library/core/src/slice/mod.rs @@ -2788,13 +2788,12 @@ impl [T] { F: FnMut(&'a T) -> Ordering, { // INVARIANTS: - // - 0 <= left <= left + size = right <= self.len() + // - 0 <= left <= left + size <= self.len() // - f returns Less for everything in self[..left] - // - f returns Greater for everything in self[right..] + // - f returns Greater for everything in self[left + size..] let mut size = self.len(); let mut left = 0; - let mut right = size; - while left < right { + while size > 1 { let mid = left + size / 2; // SAFETY: the while condition means `size` is strictly positive, so @@ -2807,21 +2806,24 @@ impl [T] { // fewer branches and instructions than if/else or matching on // cmp::Ordering. // This is x86 asm for u8: https://rust.godbolt.org/z/698eYffTx. - left = if cmp == Less { mid + 1 } else { left }; - right = if cmp == Greater { mid } else { right }; + + left = if cmp == Less { mid } else { left }; + size = if cmp == Greater { size / 2 } else { size - size / 2 }; if cmp == Equal { // SAFETY: same as the `get_unchecked` above unsafe { hint::assert_unchecked(mid < self.len()) }; return Ok(mid); } - - size = right - left; } - // SAFETY: directly true from the overall invariant. - // Note that this is `<=`, unlike the assume in the `Ok` path. - unsafe { hint::assert_unchecked(left <= self.len()) }; - Err(left) + if size == 0 { + Err(left) + } else { + // SAFETY: allowed per the invariants + let cmp = f(unsafe { self.get_unchecked(left) }); + let res_idx = if cmp == Less { left + 1 } else { left }; + if cmp == Equal { Ok(res_idx) } else { Err(res_idx) } + } } /// Binary searches this slice with a key extraction function. diff --git a/library/core/tests/slice.rs b/library/core/tests/slice.rs index 4cbbabb672ba0..7e4ad6dd46258 100644 --- a/library/core/tests/slice.rs +++ b/library/core/tests/slice.rs @@ -90,7 +90,7 @@ fn test_binary_search_implementation_details() { assert_eq!(b.binary_search(&3), Ok(5)); let b = [1, 1, 1, 1, 1, 3, 3, 3, 3]; assert_eq!(b.binary_search(&1), Ok(4)); - assert_eq!(b.binary_search(&3), Ok(7)); + assert_eq!(b.binary_search(&3), Ok(6)); let b = [1, 1, 1, 1, 3, 3, 3, 3, 3]; assert_eq!(b.binary_search(&1), Ok(2)); assert_eq!(b.binary_search(&3), Ok(4));