From 4519916d1299af08823759f906d7f6405e9edba3 Mon Sep 17 00:00:00 2001 From: Arthur Silva Date: Sun, 18 Aug 2024 19:57:38 +0200 Subject: [PATCH] Add Miri tests and address found issues --- .github/workflows/ci.yml | 14 ++++++++++ README.md | 1 - src/lib.rs | 42 +++++++++++++++++++++++------- src/linked_slab.rs | 20 +++++++++++++-- src/shard.rs | 55 +++++++++++++++++++++++++++++----------- src/sync.rs | 1 + src/unsync.rs | 11 +++++--- 7 files changed, 114 insertions(+), 30 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 1563639..437af40 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -98,3 +98,17 @@ jobs: with: command: clippy args: -- -D warnings + + miri: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v2 + - name: Install Rust toolchain + uses: actions-rs/toolchain@v1 + with: + profile: minimal + toolchain: nightly + override: true + components: miri + - name: Run Miri + run: cargo miri test diff --git a/README.md b/README.md index 5ad47e8..aedfda6 100644 --- a/README.md +++ b/README.md @@ -1,6 +1,5 @@ # Quick Cache - [![Crates.io](https://img.shields.io/crates/v/quick_cache.svg)](https://crates.io/crates/quick_cache) [![Docs](https://docs.rs/quick_cache/badge.svg)](https://docs.rs/quick_cache/latest) [![CI](https://github.com/arthurprs/quick-cache/actions/workflows/ci.yml/badge.svg)](https://github.com/arthurprs/quick-cache/actions/workflows/ci.yml) diff --git a/src/lib.rs b/src/lib.rs index 4ce6d02..99e4802 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -169,6 +169,14 @@ mod tests { }; use super::*; + #[derive(Clone)] + struct StringWeighter; + + impl Weighter for StringWeighter { + fn weight(&self, _key: &u64, val: &String) -> u64 { + val.len() as u64 + } + } #[test] fn test_new() { @@ -186,15 +194,6 @@ mod tests { #[test] fn test_custom_cost() { - #[derive(Clone)] - struct StringWeighter; - - impl Weighter for StringWeighter { - fn weight(&self, _key: &u64, val: &String) -> u64 { - val.len() as u64 - } - } - let cache = sync::Cache::with_weighter(100, 100_000, StringWeighter); cache.insert(1, "1".to_string()); cache.insert(54, "54".to_string()); @@ -202,6 +201,27 @@ mod tests { assert_eq!(cache.get(&1000).unwrap(), "1000"); } + #[test] + fn test_change_get_mut_change_weight() { + let mut cache = unsync::Cache::with_weighter(100, 100_000, StringWeighter); + cache.insert(1, "1".to_string()); + assert_eq!(cache.get(&1).unwrap(), "1"); + assert_eq!(cache.weight(), 1); + let _old = { + cache + .get_mut(&1) + .map(|mut v| std::mem::replace(&mut *v, "11".to_string())) + }; + let _old = { + cache + .get_mut(&1) + .map(|mut v| std::mem::replace(&mut *v, "".to_string())) + }; + assert_eq!(cache.get(&1).unwrap(), ""); + assert_eq!(cache.weight(), 0); + cache.validate(); + } + #[derive(Debug, Hash)] pub struct Pair(pub A, pub B); @@ -243,6 +263,7 @@ mod tests { } #[test] + #[cfg_attr(miri, ignore)] fn test_get_or_insert() { use rand::prelude::*; for _i in 0..2000 { @@ -334,6 +355,7 @@ mod tests { } #[test] + #[cfg_attr(miri, ignore)] fn test_value_or_guard() { use crate::sync::*; use rand::prelude::*; @@ -371,6 +393,7 @@ mod tests { } #[tokio::test(flavor = "multi_thread")] + #[cfg_attr(miri, ignore)] async fn test_get_or_insert_async() { use rand::prelude::*; for _i in 0..5000 { @@ -413,6 +436,7 @@ mod tests { } #[tokio::test(flavor = "multi_thread")] + #[cfg_attr(miri, ignore)] async fn test_value_or_guard_async() { use rand::prelude::*; for _i in 0..5000 { diff --git a/src/linked_slab.rs b/src/linked_slab.rs index 5371245..cb71b9d 100644 --- a/src/linked_slab.rs +++ b/src/linked_slab.rs @@ -31,12 +31,12 @@ impl LinkedSlab { self.entries.iter().filter(|e| e.item.is_some()).count() } - #[cfg(fuzzing)] + #[cfg(any(fuzzing, test))] pub fn iter_entries(&self) -> impl Iterator + '_ { self.entries.iter().filter_map(|e| e.item.as_ref()) } - #[cfg(fuzzing)] + #[cfg(any(fuzzing, test))] pub fn validate(&self) { let mut freelist = std::collections::HashSet::new(); let mut next_free = self.next_free; @@ -104,6 +104,22 @@ impl LinkedSlab { None } + /// Gets an entry and a token to the next entry w/o checking, thus unsafe. + #[inline] + pub unsafe fn get_unchecked(&self, index: Token) -> (&T, Token) { + let entry = self.entries.get_unchecked((index.get() - 1) as usize); + let v = entry.item.as_ref().unwrap_unchecked(); + (v, entry.next) + } + + /// Gets an entry and a token to the next entry w/o checking, thus unsafe. + #[inline] + pub unsafe fn get_mut_unchecked(&mut self, index: Token) -> (&mut T, Token) { + let entry = self.entries.get_unchecked_mut((index.get() - 1) as usize); + let v = entry.item.as_mut().unwrap_unchecked(); + (v, entry.next) + } + /// Links an entry before `target_head`. Returns the item next to the linked item, /// which is either the item itself or `target_head`. /// diff --git a/src/shard.rs b/src/shard.rs index 6fc6628..d41f158 100644 --- a/src/shard.rs +++ b/src/shard.rs @@ -206,7 +206,7 @@ impl< } } - #[cfg(fuzzing)] + #[cfg(any(fuzzing, test))] pub fn validate(&self) { self.entries.validate(); let mut num_hot = 0; @@ -417,7 +417,6 @@ impl< record_miss_mut!(self); return None; }; - let cache = self as *mut _; let Some((Entry::Resident(resident), _)) = self.entries.get_mut(idx) else { unreachable!() }; @@ -428,11 +427,9 @@ impl< let old_weight = self.weighter.weight(&resident.key, &resident.value); Some(RefMut { - key: &resident.key, - value: &mut resident.value, idx, old_weight, - cache, + cache: self, }) } @@ -447,15 +444,12 @@ impl< #[inline] pub fn peek_token_mut(&mut self, token: Token) -> Option> { - let cache = self as *mut _; if let Some((Entry::Resident(resident), _)) = self.entries.get_mut(token) { let old_weight = self.weighter.weight(&resident.key, &resident.value); Some(RefMut { - key: &resident.key, - value: &mut resident.value, old_weight, idx: token, - cache, + cache: self, }) } else { None @@ -992,11 +986,41 @@ impl< /// Structure wrapping a mutable reference to a cached item. pub struct RefMut<'cache, Key, Val, We: Weighter, B, L, Plh: SharedPlaceholder> { - pub key: &'cache Key, - pub value: &'cache mut Val, + cache: &'cache mut CacheShard, idx: Token, old_weight: u64, - cache: *mut CacheShard, +} + +impl<'cache, Key, Val, We: Weighter, B, L, Plh: SharedPlaceholder> + RefMut<'cache, Key, Val, We, B, L, Plh> +{ + pub(crate) fn pair(&self) -> (&Key, &Val) { + // Safety: RefMut was constructed correctly from a Resident entry in get_mut or peek_token_mut + // and it couldn't be modified as we're holding a mutable reference to the cache + unsafe { + if let (Entry::Resident(Resident { key, value, .. }), _) = + self.cache.entries.get_unchecked(self.idx) + { + (key, value) + } else { + core::hint::unreachable_unchecked() + } + } + } + + pub(crate) fn value_mut(&mut self) -> &mut Val { + // Safety: RefMut was constructed correctly from a Resident entry in get_mut or peek_token_mut + // and it couldn't be modified as we're holding a mutable reference to the cache + unsafe { + if let (Entry::Resident(Resident { value, .. }), _) = + self.cache.entries.get_mut_unchecked(self.idx) + { + value + } else { + core::hint::unreachable_unchecked() + } + } + } } impl<'cache, Key, Val, We: Weighter, B, L, Plh: SharedPlaceholder> Drop @@ -1004,10 +1028,11 @@ impl<'cache, Key, Val, We: Weighter, B, L, Plh: SharedPlaceholder> Dro { #[inline] fn drop(&mut self) { - let value = &*self.value; - let new_weight = unsafe { &*self.cache }.weighter.weight(self.key, value); + let (key, value) = self.pair(); + let new_weight = self.cache.weighter.weight(key, value); if self.old_weight != new_weight { - unsafe { &mut *self.cache }.cold_change_weight(self.idx, self.old_weight, new_weight); + self.cache + .cold_change_weight(self.idx, self.old_weight, new_weight); } } } diff --git a/src/sync.rs b/src/sync.rs index dfd0823..6995384 100644 --- a/src/sync.rs +++ b/src/sync.rs @@ -443,6 +443,7 @@ mod tests { }; #[test] + #[cfg_attr(miri, ignore)] fn test_multiple_threads() { const N_THREAD_PAIRS: usize = 8; const N_ROUNDS: usize = 1_000; diff --git a/src/unsync.rs b/src/unsync.rs index fc75dca..d837179 100644 --- a/src/unsync.rs +++ b/src/unsync.rs @@ -140,7 +140,7 @@ impl, B: BuildHasher, L: Lifecycle(&self, key: &Q) -> Option<&Val> where Q: Hash + Equivalent + ?Sized, @@ -356,6 +356,11 @@ impl, B: BuildHasher, L: Lifecycle std::fmt::Debug for Cache { @@ -464,7 +469,7 @@ impl<'cache, Key, Val, We: Weighter, B, L> std::ops::Deref #[inline] fn deref(&self) -> &Self::Target { - self.0.value + self.0.pair().1 } } @@ -473,7 +478,7 @@ impl<'cache, Key, Val, We: Weighter, B, L> std::ops::DerefMut { #[inline] fn deref_mut(&mut self) -> &mut Self::Target { - self.0.value + self.0.value_mut() } }