diff --git a/library/alloc/src/collections/binary_heap/mod.rs b/library/alloc/src/collections/binary_heap/mod.rs index 4583bc9a158ef..0b73b1af4eb35 100644 --- a/library/alloc/src/collections/binary_heap/mod.rs +++ b/library/alloc/src/collections/binary_heap/mod.rs @@ -146,6 +146,7 @@ use core::fmt; use core::iter::{FromIterator, FusedIterator, InPlaceIterable, SourceIter, TrustedLen}; use core::mem::{self, swap, ManuallyDrop}; +use core::num::NonZeroUsize; use core::ops::{Deref, DerefMut}; use core::ptr; @@ -165,12 +166,20 @@ mod tests; /// It is a logic error for an item to be modified in such a way that the /// item's ordering relative to any other item, as determined by the [`Ord`] /// trait, changes while it is in the heap. This is normally only possible -/// through [`Cell`], [`RefCell`], global state, I/O, or unsafe code. The +/// through interior mutability, global state, I/O, or unsafe code. The /// behavior resulting from such a logic error is not specified, but will /// be encapsulated to the `BinaryHeap` that observed the logic error and not /// result in undefined behavior. This could include panics, incorrect results, /// aborts, memory leaks, and non-termination. /// +/// As long as no elements change their relative order while being in the heap +/// as described above, the API of `BinaryHeap` guarantees that the heap +/// invariant remains intact i.e. its methods all behave as documented. For +/// example if a method is documented as iterating in sorted order, that's +/// guaranteed to work as long as elements in the heap have not changed order, +/// even in the presence of closures getting unwinded out of, iterators getting +/// leaked, and similar foolishness. +/// /// # Examples /// /// ``` @@ -279,7 +288,9 @@ pub struct BinaryHeap { #[stable(feature = "binary_heap_peek_mut", since = "1.12.0")] pub struct PeekMut<'a, T: 'a + Ord> { heap: &'a mut BinaryHeap, - sift: bool, + // If a set_len + sift_down are required, this is Some. If a &mut T has not + // yet been exposed to peek_mut()'s caller, it's None. + original_len: Option, } #[stable(feature = "collection_debug", since = "1.17.0")] @@ -292,7 +303,14 @@ impl fmt::Debug for PeekMut<'_, T> { #[stable(feature = "binary_heap_peek_mut", since = "1.12.0")] impl Drop for PeekMut<'_, T> { fn drop(&mut self) { - if self.sift { + if let Some(original_len) = self.original_len { + // SAFETY: That's how many elements were in the Vec at the time of + // the PeekMut::deref_mut call, and therefore also at the time of + // the BinaryHeap::peek_mut call. Since the PeekMut did not end up + // getting leaked, we are now undoing the leak amplification that + // the DerefMut prepared for. + unsafe { self.heap.data.set_len(original_len.get()) }; + // SAFETY: PeekMut is only instantiated for non-empty heaps. unsafe { self.heap.sift_down(0) }; } @@ -313,7 +331,26 @@ impl Deref for PeekMut<'_, T> { impl DerefMut for PeekMut<'_, T> { fn deref_mut(&mut self) -> &mut T { debug_assert!(!self.heap.is_empty()); - self.sift = true; + + let len = self.heap.len(); + if len > 1 { + // Here we preemptively leak all the rest of the underlying vector + // after the currently max element. If the caller mutates the &mut T + // we're about to give them, and then leaks the PeekMut, all these + // elements will remain leaked. If they don't leak the PeekMut, then + // either Drop or PeekMut::pop will un-leak the vector elements. + // + // This is technique is described throughout several other places in + // the standard library as "leak amplification". + unsafe { + // SAFETY: len > 1 so len != 0. + self.original_len = Some(NonZeroUsize::new_unchecked(len)); + // SAFETY: len > 1 so all this does for now is leak elements, + // which is safe. + self.heap.data.set_len(1); + } + } + // SAFE: PeekMut is only instantiated for non-empty heaps unsafe { self.heap.data.get_unchecked_mut(0) } } @@ -323,9 +360,16 @@ impl<'a, T: Ord> PeekMut<'a, T> { /// Removes the peeked value from the heap and returns it. #[stable(feature = "binary_heap_peek_mut_pop", since = "1.18.0")] pub fn pop(mut this: PeekMut<'a, T>) -> T { - let value = this.heap.pop().unwrap(); - this.sift = false; - value + if let Some(original_len) = this.original_len.take() { + // SAFETY: This is how many elements were in the Vec at the time of + // the BinaryHeap::peek_mut call. + unsafe { this.heap.data.set_len(original_len.get()) }; + + // Unlike in Drop, here we don't also need to do a sift_down even if + // the caller could've mutated the element. It is removed from the + // heap on the next line and pop() is not sensitive to its value. + } + this.heap.pop().unwrap() } } @@ -398,8 +442,9 @@ impl BinaryHeap { /// Returns a mutable reference to the greatest item in the binary heap, or /// `None` if it is empty. /// - /// Note: If the `PeekMut` value is leaked, the heap may be in an - /// inconsistent state. + /// Note: If the `PeekMut` value is leaked, some heap elements might get + /// leaked along with it, but the remaining elements will remain a valid + /// heap. /// /// # Examples /// @@ -426,7 +471,7 @@ impl BinaryHeap { /// otherwise it's *O*(1). #[stable(feature = "binary_heap_peek_mut", since = "1.12.0")] pub fn peek_mut(&mut self) -> Option> { - if self.is_empty() { None } else { Some(PeekMut { heap: self, sift: false }) } + if self.is_empty() { None } else { Some(PeekMut { heap: self, original_len: None }) } } /// Removes the greatest item from the binary heap and returns it, or `None` if it diff --git a/library/alloc/src/collections/binary_heap/tests.rs b/library/alloc/src/collections/binary_heap/tests.rs index 59c516374c0e2..ffbb6c80ac018 100644 --- a/library/alloc/src/collections/binary_heap/tests.rs +++ b/library/alloc/src/collections/binary_heap/tests.rs @@ -1,6 +1,7 @@ use super::*; use crate::boxed::Box; use crate::testing::crash_test::{CrashTestDummy, Panic}; +use core::mem; use std::iter::TrustedLen; use std::panic::{catch_unwind, AssertUnwindSafe}; @@ -146,6 +147,24 @@ fn test_peek_mut() { assert_eq!(heap.peek(), Some(&9)); } +#[test] +fn test_peek_mut_leek() { + let data = vec![4, 2, 7]; + let mut heap = BinaryHeap::from(data); + let mut max = heap.peek_mut().unwrap(); + *max = -1; + + // The PeekMut object's Drop impl would have been responsible for moving the + // -1 out of the max position of the BinaryHeap, but we don't run it. + mem::forget(max); + + // Absent some mitigation like leak amplification, the -1 would incorrectly + // end up in the last position of the returned Vec, with the rest of the + // heap's original contents in front of it in sorted order. + let sorted_vec = heap.into_sorted_vec(); + assert!(sorted_vec.is_sorted(), "{:?}", sorted_vec); +} + #[test] fn test_peek_mut_pop() { let data = vec![2, 4, 6, 2, 1, 8, 10, 3, 5, 7, 0, 9, 1]; diff --git a/library/alloc/src/lib.rs b/library/alloc/src/lib.rs index 4e812529c2cc8..afc3a3dc6a8c6 100644 --- a/library/alloc/src/lib.rs +++ b/library/alloc/src/lib.rs @@ -125,6 +125,7 @@ #![feature(hasher_prefixfree_extras)] #![feature(inline_const)] #![feature(inplace_iteration)] +#![cfg_attr(test, feature(is_sorted))] #![feature(iter_advance_by)] #![feature(iter_next_chunk)] #![feature(iter_repeat_n)]