From 467c762a5a8e23ce5bec8bc147e46770e69fb1b0 Mon Sep 17 00:00:00 2001 From: Ibraheem Ahmed Date: Wed, 15 Jan 2025 03:43:42 -0500 Subject: [PATCH] add `HashMap::try_insert_with` --- src/map.rs | 50 ++++++++++++++++++++++++++++++++++++ src/raw/mod.rs | 58 +++++++++++++++++++++++++++++++++++++++-- tests/basic.rs | 70 +++++++++++++++++++++++++++++++++++++++++++++++++- 3 files changed, 175 insertions(+), 3 deletions(-) diff --git a/src/map.rs b/src/map.rs index 5299b54..8d7ab7f 100644 --- a/src/map.rs +++ b/src/map.rs @@ -565,6 +565,42 @@ where } } + /// Tries to insert a key and value computed from a closure into the map, + /// and returns a reference to the value that was inserted. + /// + /// If the map already had this key present, nothing is updated, and + /// the existing value is returned. + /// + /// # Examples + /// + /// ``` + /// use papaya::HashMap; + /// + /// let map = HashMap::new(); + /// let map = map.pin(); + /// + /// assert_eq!(map.try_insert_with(37, || "a").unwrap(), &"a"); + /// + /// let current = map.try_insert_with(37, || "b").unwrap_err(); + /// assert_eq!(current, &"a"); + /// ``` + #[inline] + pub fn try_insert_with<'g, F>( + &self, + key: K, + f: F, + guard: &'g impl Guard, + ) -> Result<&'g V, &'g V> + where + F: FnOnce() -> V, + K: 'g, + { + self.raw.check_guard(guard); + + // Safety: Checked the guard above. + unsafe { self.raw.try_insert_with(key, f, guard) } + } + /// Returns a reference to the value corresponding to the key, or inserts a default value. /// /// If the given key is present, the corresponding value is returned. If it is not present, @@ -1341,6 +1377,20 @@ where } } + /// Tries to insert a key and value computed from a closure into the map, + /// and returns a reference to the value that was inserted. + /// + /// See [`HashMap::try_insert_with`] for details. + /// ``` + #[inline] + pub fn try_insert_with(&self, key: K, f: F) -> Result<&V, &V> + where + F: FnOnce() -> V, + { + // Safety: `self.guard` was created from our map. + unsafe { self.map.raw.try_insert_with(key, f, &self.guard) } + } + /// Returns a reference to the value corresponding to the key, or inserts a default value. /// /// See [`HashMap::get_or_insert`] for details. diff --git a/src/raw/mod.rs b/src/raw/mod.rs index 536681a..0e76fd3 100644 --- a/src/raw/mod.rs +++ b/src/raw/mod.rs @@ -1175,6 +1175,14 @@ where } // Restores the state if an operation fails. + // + // This allows the result of the compute closure with a given input to be memoized. + // This is useful at it avoids calling the closure multiple times if an update needs + // to be retried in a new table. + // + // Additionally, update and insert operations are memoized separately, although this + // is not guaranteed in the public API. This means that internal methods can rely on + // `compute(None)` being called at most once. #[inline] fn restore(&mut self, input: Option<*mut Entry>, output: Operation) { match input { @@ -1238,6 +1246,42 @@ where K: Hash + Eq, S: BuildHasher, { + /// Tries to insert a key and value computed from a closure into the map, + /// and returns a reference to the value that was inserted. + // + // # Safety + // + // The guard must be valid to use with this map. + #[inline] + pub unsafe fn try_insert_with<'g, F>( + &self, + key: K, + f: F, + guard: &'g impl Guard, + ) -> Result<&'g V, &'g V> + where + F: FnOnce() -> V, + K: 'g, + { + let mut f = Some(f); + let compute = |entry| match entry { + // There is already an existing value. + Some((_, current)) => Operation::Abort(current), + + // Insert the initial value. + // + // Note that this case is guaranteed to be executed at most + // once as insert values are memoized, so this can never panic. + None => Operation::Insert((f.take().unwrap())()), + }; + + match self.compute(key, compute, guard) { + Compute::Aborted(current) => Err(current), + Compute::Inserted(_, value) => Ok(value), + _ => unreachable!(), + } + } + /// Returns a reference to the value corresponding to the key, or inserts a default value /// computed from a closure. // @@ -1254,7 +1298,11 @@ where let compute = |entry| match entry { // Return the existing value. Some((_, current)) => Operation::Abort(current), + // Insert the initial value. + // + // Note that this case is guaranteed to be executed at most + // once as insert values are memoized, so this can never panic. None => Operation::Insert((f.take().unwrap())()), }; @@ -1282,8 +1330,10 @@ where K: 'g, { let compute = |entry| match entry { - Some((_, value)) => Operation::Insert(update(value)), + // There is nothing to update. None => Operation::Abort(()), + // Perform the update. + Some((_, value)) => Operation::Insert(update(value)), }; match self.compute(key, compute, guard) { @@ -1317,7 +1367,11 @@ where let compute = |entry| match entry { // Perform the update. Some((_, value)) => Operation::Insert::<_, ()>(update(value)), + // Insert the initial value. + // + // Note that this case is guaranteed to be executed at most + // once as insert values are memoized, so this can never panic. None => Operation::Insert((f.take().unwrap())()), }; @@ -1359,7 +1413,7 @@ where // Deallocate the entry if it was not inserted. if matches!(result, Compute::Removed(..) | Compute::Aborted(_)) { if let LazyEntry::Init(entry) = entry { - // Safety: We allocated this box above and it was not inserted into the map. + // Safety: The entry was allocated but not inserted into the map. let _ = unsafe { Box::from_raw(entry) }; } } diff --git a/tests/basic.rs b/tests/basic.rs index 1efc4a1..59eff2a 100644 --- a/tests/basic.rs +++ b/tests/basic.rs @@ -1,6 +1,6 @@ // Adapted from: https://github.com/jonhoo/flurry/blob/main/tests/basic.rs -use papaya::{Compute, HashMap, Operation}; +use papaya::{Compute, HashMap, OccupiedError, Operation}; use std::hash::{BuildHasher, BuildHasherDefault, Hasher}; use std::sync::Arc; @@ -247,6 +247,74 @@ fn get_or_insert() { }); } +#[test] +fn try_insert() { + with_map::(|map| { + let map = map(); + let guard = map.guard(); + + assert_eq!(map.try_insert(42, 1, &guard), Ok(&1)); + assert_eq!(map.len(), 1); + + { + let guard = map.guard(); + let e = map.get(&42, &guard).unwrap(); + assert_eq!(e, &1); + } + + assert_eq!( + map.try_insert(42, 2, &guard), + Err(OccupiedError { + current: &1, + not_inserted: 2 + }) + ); + assert_eq!(map.len(), 1); + + { + let guard = map.guard(); + let e = map.get(&42, &guard).unwrap(); + assert_eq!(e, &1); + } + + assert_eq!(map.try_insert(43, 2, &guard), Ok(&2)); + }); +} + +#[test] +fn try_insert_with() { + with_map::(|map| { + let map = map(); + let guard = map.guard(); + + map.try_insert_with(42, || 1, &guard).unwrap(); + assert_eq!(map.len(), 1); + + { + let guard = map.guard(); + let e = map.get(&42, &guard).unwrap(); + assert_eq!(e, &1); + } + + let mut called = false; + let insert = || { + called = true; + 2 + }; + assert_eq!(map.try_insert_with(42, insert, &guard), Err(&1)); + assert_eq!(map.len(), 1); + assert!(!called); + + { + let guard = map.guard(); + let e = map.get(&42, &guard).unwrap(); + assert_eq!(e, &1); + } + + assert_eq!(map.try_insert_with(43, || 2, &guard), Ok(&2)); + }); +} + #[test] fn get_or_insert_with() { with_map::(|map| {