From a9f088ce415d865b3b528ba865ec769d6922d4d6 Mon Sep 17 00:00:00 2001 From: Niklas Jonsson Date: Thu, 14 Jul 2022 19:52:18 +0200 Subject: [PATCH] Implement get_many_mut --- src/map.rs | 44 ++++++++++++++++++++++++++++++++++ src/map/tests.rs | 61 ++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 105 insertions(+) diff --git a/src/map.rs b/src/map.rs index e71c799c..2fc2794d 100644 --- a/src/map.rs +++ b/src/map.rs @@ -477,6 +477,50 @@ where } } + #[allow(unsafe_code)] + pub fn get_many_mut<'a, 'b, Q: ?Sized, const N: usize>( + &'a mut self, + keys: [&'b Q; N], + ) -> Option<[&'a mut V; N]> + where + Q: Hash + Equivalent, + { + let indices = keys.map(|key| self.get_index_of(key)); + if indices.iter().any(Option::is_none) { + return None; + } + let indices = indices.map(Option::unwrap); + + // SAFETY: Can't allow duplicate indices as we would return several mutable refs to the same data + for i in 0..N { + let idx = indices[i]; + if indices[i + 1..N].contains(&idx) { + return None; + } + } + + // Replace with MaybeUninit::uninit_array when that is stable + // SAFETY: Creating MaybeUninit from uninit is always safe + let mut out: [std::mem::MaybeUninit<&'a mut V>; N] = + unsafe { std::mem::MaybeUninit::uninit().assume_init() }; + + let entries = self.as_entries_mut(); + for (elem, idx) in out.iter_mut().zip(indices) { + let v: &mut V = &mut entries[idx].value; + // SAFETY: We already checked that each index is unique and we own each instance of V. + // As we know that each index is unique, it is OK to discard the mutable borrow lifetime of v, + // we will never mutably borrow an element twice. + unsafe { std::ptr::write(elem.as_mut_ptr(), &mut *(v as *mut V)) }; + } + + // Can't transmute a const-generic sized array: + // https://github.com/rust-lang/rust/issues/61956 + // This is the workaround. + // SAFETY: This is fine as the references all are from unique entries that we own and all of + // them have been initialized by the above loop. + Some(unsafe { std::mem::transmute_copy::<_, [&'a mut V; N]>(&out) }) + } + /// Remove the key-value pair equivalent to `key` and return /// its value. /// diff --git a/src/map/tests.rs b/src/map/tests.rs index b6c6a42d..cbf856a0 100644 --- a/src/map/tests.rs +++ b/src/map/tests.rs @@ -418,3 +418,64 @@ fn from_array() { assert_eq!(map, expected) } + +#[test] +fn many_mut_empty() { + let mut map: IndexMap = IndexMap::default(); + assert!(map.get_many_mut([&0, &1, &2, &3]).is_none()); +} + +#[test] +fn many_mut_single_fail() { + let mut map: IndexMap = IndexMap::default(); + map.insert(1, 10); + assert!(map.get_many_mut([&0]).is_none()); +} + +#[test] +fn many_mut_single_success() { + let mut map: IndexMap = IndexMap::default(); + map.insert(1, 10); + assert_eq!(map.get_many_mut([&1]), Some([&mut 10])); +} + +#[test] +fn many_mut_multi_success() { + let mut map: IndexMap = IndexMap::default(); + map.insert(1, 10); + map.insert(1123, 100); + map.insert(321, 20); + map.insert(1337, 30); + assert_eq!(map.get_many_mut([&1, &1123]), Some([&mut 10, &mut 100])); + assert_eq!(map.get_many_mut([&1, &1337]), Some([&mut 10, &mut 30])); + assert_eq!( + map.get_many_mut([&1337, &321, &1, &1123]), + Some([&mut 30, &mut 20, &mut 10, &mut 100]) + ); +} + +#[test] +fn many_mut_multi_fail_missing() { + let mut map: IndexMap = IndexMap::default(); + map.insert(1, 10); + map.insert(1123, 100); + map.insert(321, 20); + map.insert(1337, 30); + assert_eq!(map.get_many_mut([&121, &1123]), None); + assert_eq!(map.get_many_mut([&1, &1337, &56]), None); + assert_eq!(map.get_many_mut([&1337, &123, &321, &1, &1123]), None); +} + +#[test] +fn many_mut_multi_fail_duplicate() { + let mut map: IndexMap = IndexMap::default(); + map.insert(1, 10); + map.insert(1123, 100); + map.insert(321, 20); + map.insert(1337, 30); + assert_eq!(map.get_many_mut([&1, &1]), None); + assert_eq!( + map.get_many_mut([&1337, &123, &321, &1337, &1, &1123]), + None + ); +}