diff --git a/common_util/src/partitioned_lock.rs b/common_util/src/partitioned_lock.rs index ae1f15c4a6..36cc3c03b5 100644 --- a/common_util/src/partitioned_lock.rs +++ b/common_util/src/partitioned_lock.rs @@ -5,24 +5,27 @@ use std::{ collections::hash_map::DefaultHasher, hash::{Hash, Hasher}, - num::NonZeroUsize, - sync::{Arc, Mutex, MutexGuard, RwLock, RwLockReadGuard, RwLockWriteGuard}, + sync::{Mutex, MutexGuard, RwLock, RwLockReadGuard, RwLockWriteGuard}, }; /// Simple partitioned `RwLock` pub struct PartitionedRwLock { - partitions: Vec>>, + partitions: Vec>, + partition_mask: usize, } -impl PartitionedRwLock { - // TODO: we should get the nearest 2^n of `partition_num` as real - // `partition_num`. By doing so, we can use "&" to get partition rather than - // "%". - pub fn new(t: T, partition_num: NonZeroUsize) -> Self { - let partition_num = partition_num.get(); - let locked_content = Arc::new(RwLock::new(t)); +impl PartitionedRwLock +where + T: Clone, +{ + pub fn new(t: T, partition_bit: usize) -> Self { + let partition_num = 1 << partition_bit; + let partitions = (0..partition_num) + .map(|_| RwLock::new(t.clone())) + .collect::>(); Self { - partitions: vec![locked_content; partition_num], + partitions, + partition_mask: partition_num - 1, } } @@ -41,26 +44,29 @@ impl PartitionedRwLock { fn get_partition(&self, key: &K) -> &RwLock { let mut hasher = DefaultHasher::new(); key.hash(&mut hasher); - let partition_num = self.partitions.len(); - &self.partitions[(hasher.finish() as usize) % partition_num] + &self.partitions[(hasher.finish() as usize) & self.partition_mask] } } /// Simple partitioned `Mutex` pub struct PartitionedMutex { - partitions: Vec>>, + partitions: Vec>, + partition_mask: usize, } -impl PartitionedMutex { - // TODO: we should get the nearest 2^n of `partition_num` as real - // `partition_num`. By doing so, we can use "&" to get partition rather than - // "%". - pub fn new(t: T, partition_num: NonZeroUsize) -> Self { - let partition_num = partition_num.get(); - let locked_content = Arc::new(Mutex::new(t)); +impl PartitionedMutex +where + T: Clone, +{ + pub fn new(t: T, partition_bit: usize) -> Self { + let partition_num = 1 << partition_bit; + let partitions = (0..partition_num) + .map(|_| Mutex::new(t.clone())) + .collect::>(); Self { - partitions: vec![locked_content; partition_num], + partitions, + partition_mask: partition_num - 1, } } @@ -73,9 +79,8 @@ impl PartitionedMutex { fn get_partition(&self, key: &K) -> &Mutex { let mut hasher = DefaultHasher::new(); key.hash(&mut hasher); - let partition_num = self.partitions.len(); - &self.partitions[(hasher.finish() as usize) % partition_num] + &self.partitions[(hasher.finish() as usize) & self.partition_mask] } } @@ -87,8 +92,7 @@ mod tests { #[test] fn test_partitioned_rwlock() { - let test_locked_map = - PartitionedRwLock::new(HashMap::new(), NonZeroUsize::new(10).unwrap()); + let test_locked_map = PartitionedRwLock::new(HashMap::new(), 4); let test_key = "test_key".to_string(); let test_value = "test_value".to_string(); @@ -105,7 +109,7 @@ mod tests { #[test] fn test_partitioned_mutex() { - let test_locked_map = PartitionedMutex::new(HashMap::new(), NonZeroUsize::new(10).unwrap()); + let test_locked_map = PartitionedMutex::new(HashMap::new(), 4); let test_key = "test_key".to_string(); let test_value = "test_value".to_string(); @@ -119,4 +123,34 @@ mod tests { assert_eq!(map.get(&test_key).unwrap(), &test_value); } } + + #[test] + fn test_partitioned_mutex_vis_different_partition() { + let tmp_vec: Vec = Vec::new(); + let test_locked_map = PartitionedMutex::new(tmp_vec, 4); + let test_key_first = "test_key_first".to_string(); + let mutex_first = test_locked_map.get_partition(&test_key_first); + let mut _tmp_data = mutex_first.lock().unwrap(); + assert!(mutex_first.try_lock().is_err()); + + let test_key_second = "test_key_second".to_string(); + let mutex_second = test_locked_map.get_partition(&test_key_second); + assert!(mutex_second.try_lock().is_ok()); + assert!(mutex_first.try_lock().is_err()); + } + + #[test] + fn test_partitioned_rwmutex_vis_different_partition() { + let tmp_vec: Vec = Vec::new(); + let test_locked_map = PartitionedRwLock::new(tmp_vec, 4); + let test_key_first = "test_key_first".to_string(); + let mutex_first = test_locked_map.get_partition(&test_key_first); + let mut _tmp = mutex_first.write().unwrap(); + assert!(mutex_first.try_write().is_err()); + + let test_key_second = "test_key_second".to_string(); + let mutex_second_try_lock = test_locked_map.get_partition(&test_key_second); + assert!(mutex_second_try_lock.try_write().is_ok()); + assert!(mutex_first.try_write().is_err()); + } }