Skip to content

Commit

Permalink
Add get_many_mut methods to slice
Browse files Browse the repository at this point in the history
  • Loading branch information
Kimundi authored and Mark-Simulacrum committed Nov 20, 2022
1 parent 9cdfe03 commit 3fe37b8
Show file tree
Hide file tree
Showing 5 changed files with 201 additions and 0 deletions.
3 changes: 3 additions & 0 deletions library/core/src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -506,3 +506,6 @@ impl Error for crate::ffi::FromBytesWithNulError {

#[unstable(feature = "cstr_from_bytes_until_nul", issue = "95027")]
impl Error for crate::ffi::FromBytesUntilNulError {}

#[unstable(feature = "get_many_mut", issue = "104642")]
impl<const N: usize> Error for crate::slice::GetManyMutError<N> {}
136 changes: 136 additions & 0 deletions library/core/src/slice/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
#![stable(feature = "rust1", since = "1.0.0")]

use crate::cmp::Ordering::{self, Greater, Less};
use crate::fmt;
use crate::intrinsics::{assert_unsafe_precondition, exact_div};
use crate::marker::Copy;
use crate::mem::{self, SizedTypeProperties};
Expand Down Expand Up @@ -4082,6 +4083,88 @@ impl<T> [T] {
*self = rem;
Some(last)
}

/// Returns mutable references to many indices at once, without doing any checks.
///
/// For a safe alternative see [`get_many_mut`].
///
/// # Safety
///
/// Calling this method with overlapping or out-of-bounds indices is *[undefined behavior]*
/// even if the resulting references are not used.
///
/// # Examples
///
/// ```
/// #![feature(get_many_mut)]
///
/// let x = &mut [1, 2, 4];
///
/// unsafe {
/// let [a, b] = x.get_many_unchecked_mut([0, 2]);
/// *a *= 10;
/// *b *= 100;
/// }
/// assert_eq!(x, &[10, 2, 400]);
/// ```
///
/// [`get_many_mut`]: slice::get_many_mut
/// [undefined behavior]: https://doc.rust-lang.org/reference/behavior-considered-undefined.html
#[unstable(feature = "get_many_mut", issue = "104642")]
#[inline]
pub unsafe fn get_many_unchecked_mut<const N: usize>(
&mut self,
indices: [usize; N],
) -> [&mut T; N] {
// NB: This implementation is written as it is because any variation of
// `indices.map(|i| self.get_unchecked_mut(i))` would make miri unhappy,
// or generate worse code otherwise. This is also why we need to go
// through a raw pointer here.
let slice: *mut [T] = self;
let mut arr: mem::MaybeUninit<[&mut T; N]> = mem::MaybeUninit::uninit();
let arr_ptr = arr.as_mut_ptr();

// SAFETY: We expect `indices` to contain disjunct values that are
// in bounds of `self`.
unsafe {
for i in 0..N {
let idx = *indices.get_unchecked(i);
*(*arr_ptr).get_unchecked_mut(i) = &mut *slice.get_unchecked_mut(idx);
}
arr.assume_init()
}
}

/// Returns mutable references to many indices at once.
///
/// Returns an error if any index is out-of-bounds, or if the same index was
/// passed more than once.
///
/// # Examples
///
/// ```
/// #![feature(get_many_mut)]
///
/// let v = &mut [1, 2, 3];
/// if let Ok([a, b]) = v.get_many_mut([0, 2]) {
/// *a = 413;
/// *b = 612;
/// }
/// assert_eq!(v, &[413, 2, 612]);
/// ```
#[unstable(feature = "get_many_mut", issue = "104642")]
#[inline]
pub fn get_many_mut<const N: usize>(
&mut self,
indices: [usize; N],
) -> Result<[&mut T; N], GetManyMutError<N>> {
if !get_many_check_valid(&indices, self.len()) {
return Err(GetManyMutError { _private: () });
}
// SAFETY: The `get_many_check_valid()` call checked that all indices
// are disjunct and in bounds.
unsafe { Ok(self.get_many_unchecked_mut(indices)) }
}
}

impl<T, const N: usize> [[T; N]] {
Expand Down Expand Up @@ -4304,3 +4387,56 @@ impl<T, const N: usize> SlicePattern for [T; N] {
self
}
}

/// This checks every index against each other, and against `len`.
///
/// This will do `binomial(N + 1, 2) = N * (N + 1) / 2 = 0, 1, 3, 6, 10, ..`
/// comparison operations.
fn get_many_check_valid<const N: usize>(indices: &[usize; N], len: usize) -> bool {
// NB: The optimzer should inline the loops into a sequence
// of instructions without additional branching.
let mut valid = true;
for (i, &idx) in indices.iter().enumerate() {
valid &= idx < len;
for &idx2 in &indices[..i] {
valid &= idx != idx2;
}
}
valid
}

/// The error type returned by [`get_many_mut<N>`][`slice::get_many_mut`].
///
/// It indicates one of two possible errors:
/// - An index is out-of-bounds.
/// - The same index appeared multiple times in the array.
///
/// # Examples
///
/// ```
/// #![feature(get_many_mut)]
///
/// let v = &mut [1, 2, 3];
/// assert!(v.get_many_mut([0, 999]).is_err());
/// assert!(v.get_many_mut([1, 1]).is_err());
/// ```
#[unstable(feature = "get_many_mut", issue = "104642")]
// NB: The N here is there to be forward-compatible with adding more details
// to the error type at a later point
pub struct GetManyMutError<const N: usize> {
_private: (),
}

#[unstable(feature = "get_many_mut", issue = "104642")]
impl<const N: usize> fmt::Debug for GetManyMutError<N> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("GetManyMutError").finish_non_exhaustive()
}
}

#[unstable(feature = "get_many_mut", issue = "104642")]
impl<const N: usize> fmt::Display for GetManyMutError<N> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
fmt::Display::fmt("an index is out of bounds or appeared multiple times in the array", f)
}
}
1 change: 1 addition & 0 deletions library/core/tests/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,7 @@
#![feature(provide_any)]
#![feature(utf8_chunks)]
#![feature(is_ascii_octdigit)]
#![feature(get_many_mut)]
#![deny(unsafe_op_in_unsafe_fn)]

extern crate test;
Expand Down
60 changes: 60 additions & 0 deletions library/core/tests/slice.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2595,3 +2595,63 @@ fn test_flatten_mut_size_overflow() {
let x = &mut [[(); usize::MAX]; 2][..];
let _ = x.flatten_mut();
}

#[test]
fn test_get_many_mut_normal_2() {
let mut v = vec![1, 2, 3, 4, 5];
let [a, b] = v.get_many_mut([3, 0]).unwrap();
*a += 10;
*b += 100;
assert_eq!(v, vec![101, 2, 3, 14, 5]);
}

#[test]
fn test_get_many_mut_normal_3() {
let mut v = vec![1, 2, 3, 4, 5];
let [a, b, c] = v.get_many_mut([0, 4, 2]).unwrap();
*a += 10;
*b += 100;
*c += 1000;
assert_eq!(v, vec![11, 2, 1003, 4, 105]);
}

#[test]
fn test_get_many_mut_empty() {
let mut v = vec![1, 2, 3, 4, 5];
let [] = v.get_many_mut([]).unwrap();
assert_eq!(v, vec![1, 2, 3, 4, 5]);
}

#[test]
fn test_get_many_mut_single_first() {
let mut v = vec![1, 2, 3, 4, 5];
let [a] = v.get_many_mut([0]).unwrap();
*a += 10;
assert_eq!(v, vec![11, 2, 3, 4, 5]);
}

#[test]
fn test_get_many_mut_single_last() {
let mut v = vec![1, 2, 3, 4, 5];
let [a] = v.get_many_mut([4]).unwrap();
*a += 10;
assert_eq!(v, vec![1, 2, 3, 4, 15]);
}

#[test]
fn test_get_many_mut_oob_nonempty() {
let mut v = vec![1, 2, 3, 4, 5];
assert!(v.get_many_mut([5]).is_err());
}

#[test]
fn test_get_many_mut_oob_empty() {
let mut v: Vec<i32> = vec![];
assert!(v.get_many_mut([0]).is_err());
}

#[test]
fn test_get_many_mut_duplicate() {
let mut v = vec![1, 2, 3, 4, 5];
assert!(v.get_many_mut([1, 3, 3, 4]).is_err());
}
1 change: 1 addition & 0 deletions library/std/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -347,6 +347,7 @@
#![feature(stdsimd)]
#![feature(test)]
#![feature(trace_macros)]
#![feature(get_many_mut)]
//
// Only used in tests/benchmarks:
//
Expand Down

0 comments on commit 3fe37b8

Please sign in to comment.