diff --git a/library/core/src/slice/mod.rs b/library/core/src/slice/mod.rs index 521c324820446..a46613e7a6b38 100644 --- a/library/core/src/slice/mod.rs +++ b/library/core/src/slice/mod.rs @@ -2787,26 +2787,48 @@ impl [T] { where F: FnMut(&'a T) -> Ordering, { + // If T is a ZST, we assume that f is a constant function. + // We do so because: + // 1. ZSTs can only have one inhabitant + // 2. We assume that f doesn't compare the address of the reference + // passed, but only the value pointed to. + // 3. We assume f's output to be entirely determined by its input + if T::IS_ZST { + let res = if self.len() == 0 { + Err(0) + } else { + match f(&self[0]) { + Less => Err(self.len()), + Equal => Ok(0), + Greater => Err(0), + } + }; + return res; + } + // Now we can assume that T is not a ZST, so self.len() <= isize::MAX // INVARIANTS: - // - 0 <= left <= left + size = right <= self.len() + // - 0 <= left <= right <= self.len() <= isize::MAX // - f returns Less for everything in self[..left] // - f returns Greater for everything in self[right..] - let mut size = self.len(); + let mut right = self.len(); let mut left = 0; - let mut right = size; while left < right { - let mid = left + size / 2; + // left + right <= 2*isize::MAX < usize::MAX + // so the addition won't overflow + let mid = (left + right) / 2; - // SAFETY: the while condition means `size` is strictly positive, so - // `size/2 < size`. Thus `left + size/2 < left + size`, which - // coupled with the `left + size <= self.len()` invariant means - // we have `left + size/2 < self.len()`, and this is in-bounds. + // SAFETY: We have that left < right, so + // 0 <= left <= mid < right <= self.len() + // and the indexing is in-bounds. let cmp = f(unsafe { self.get_unchecked(mid) }); // This control flow produces conditional moves, which results in // fewer branches and instructions than if/else or matching on // cmp::Ordering. // This is x86 asm for u8: https://rust.godbolt.org/z/698eYffTx. + // (Note: the code has slightly changed since this comment but the + // reasoning remains the same.) + left = if cmp == Less { mid + 1 } else { left }; right = if cmp == Greater { mid } else { right }; if cmp == Equal { @@ -2814,10 +2836,7 @@ impl [T] { unsafe { hint::assert_unchecked(mid < self.len()) }; return Ok(mid); } - - size = right - left; } - // SAFETY: directly true from the overall invariant. // Note that this is `<=`, unlike the assume in the `Ok` path. unsafe { hint::assert_unchecked(left <= self.len()) }; diff --git a/library/core/tests/slice.rs b/library/core/tests/slice.rs index 4cbbabb672ba0..c535fd7e8763d 100644 --- a/library/core/tests/slice.rs +++ b/library/core/tests/slice.rs @@ -69,13 +69,13 @@ fn test_binary_search() { assert_eq!(b.binary_search(&8), Err(5)); let b = [(); usize::MAX]; - assert_eq!(b.binary_search(&()), Ok(usize::MAX / 2)); + assert_eq!(b.binary_search(&()), Ok(0)); } #[test] fn test_binary_search_by_overflow() { let b = [(); usize::MAX]; - assert_eq!(b.binary_search_by(|_| Ordering::Equal), Ok(usize::MAX / 2)); + assert_eq!(b.binary_search_by(|_| Ordering::Equal), Ok(0)); assert_eq!(b.binary_search_by(|_| Ordering::Greater), Err(0)); assert_eq!(b.binary_search_by(|_| Ordering::Less), Err(usize::MAX)); }