diff --git a/arrow-buffer/Cargo.toml b/arrow-buffer/Cargo.toml index 68bfe8ddf732..5bd81d1aac33 100644 --- a/arrow-buffer/Cargo.toml +++ b/arrow-buffer/Cargo.toml @@ -28,11 +28,17 @@ include = { workspace = true } edition = { workspace = true } rust-version = { workspace = true } +[package.metadata.docs.rs] +features = ["pool"] + [lib] name = "arrow_buffer" path = "src/lib.rs" bench = false +[features] +pool = [] + [dependencies] bytes = { version = "1.4" } num = { version = "0.4", default-features = false, features = ["std"] } diff --git a/arrow-buffer/src/buffer/immutable.rs b/arrow-buffer/src/buffer/immutable.rs index 8d1a46583fca..48c98f0b7b24 100644 --- a/arrow-buffer/src/buffer/immutable.rs +++ b/arrow-buffer/src/buffer/immutable.rs @@ -357,6 +357,12 @@ impl Buffer { pub fn ptr_eq(&self, other: &Self) -> bool { self.ptr == other.ptr && self.length == other.length } + + /// Register this [`Buffer`] with the provided [`MemoryPool`], replacing any prior assignment + #[cfg(feature = "pool")] + pub fn claim(&self, pool: &dyn crate::MemoryPool) { + self.data.claim(pool) + } } /// Note that here we deliberately do not implement diff --git a/arrow-buffer/src/bytes.rs b/arrow-buffer/src/bytes.rs index ba61342d8e39..acc955851114 100644 --- a/arrow-buffer/src/bytes.rs +++ b/arrow-buffer/src/bytes.rs @@ -44,6 +44,9 @@ pub struct Bytes { /// how to deallocate this region deallocation: Deallocation, + + #[cfg(feature = "pool")] + reservation: std::sync::Mutex>>, } impl Bytes { @@ -65,6 +68,8 @@ impl Bytes { ptr, len, deallocation, + #[cfg(feature = "pool")] + reservation: std::sync::Mutex::new(None), } } @@ -96,6 +101,12 @@ impl Bytes { } } + /// Register this [`Bytes`] with the provided [`MemoryPool`], replacing any prior assignment + #[cfg(feature = "pool")] + pub fn claim(&self, pool: &dyn crate::MemoryPool) { + *self.reservation.lock().unwrap() = Some(pool.register(self.capacity())); + } + #[inline] pub(crate) fn deallocation(&self) -> &Deallocation { &self.deallocation @@ -152,6 +163,8 @@ impl From for Bytes { len, ptr: NonNull::new(value.as_ptr() as _).unwrap(), deallocation: Deallocation::Custom(std::sync::Arc::new(value), len), + #[cfg(feature = "pool")] + reservation: std::sync::Mutex::new(None), } } } diff --git a/arrow-buffer/src/lib.rs b/arrow-buffer/src/lib.rs index 34e432208ada..40013de25ce7 100644 --- a/arrow-buffer/src/lib.rs +++ b/arrow-buffer/src/lib.rs @@ -43,3 +43,8 @@ mod interval; pub use interval::*; mod arith; + +#[cfg(feature = "pool")] +mod pool; +#[cfg(feature = "pool")] +pub use pool::*; diff --git a/arrow-buffer/src/pool.rs b/arrow-buffer/src/pool.rs new file mode 100644 index 000000000000..ea73b4faa431 --- /dev/null +++ b/arrow-buffer/src/pool.rs @@ -0,0 +1,81 @@ +use std::sync::atomic::{AtomicUsize, Ordering}; +use std::sync::Arc; + +/// A [`MemoryPool`] can be used to track memory usage by [`Buffer`](crate::Buffer) +pub trait MemoryPool { + /// Return a memory reservation of `size` bytes + fn register(&self, size: usize) -> Box; +} + +/// A memory reservation within a [`MemoryPool`] that is freed on drop +pub trait MemoryReservation { + /// Resize this reservation to `new` bytes + fn resize(&mut self, new: usize); +} + +/// A simple [`MemoryPool`] that reports the total memory usage +#[derive(Debug, Default)] +pub struct TrackingMemoryPool(Arc); + +impl TrackingMemoryPool { + /// Returns the total allocated size + pub fn allocated(&self) -> usize { + self.0.load(Ordering::Relaxed) + } +} + +impl MemoryPool for TrackingMemoryPool { + fn register(&self, size: usize) -> Box { + self.0.fetch_add(size, Ordering::Relaxed); + Box::new(Tracker { + size, + shared: Arc::clone(&self.0), + }) + } +} + +#[derive(Debug)] +struct Tracker { + size: usize, + shared: Arc, +} + +impl Drop for Tracker { + fn drop(&mut self) { + self.shared.fetch_sub(self.size, Ordering::Relaxed); + } +} + +impl MemoryReservation for Tracker { + fn resize(&mut self, new: usize) { + match self.size < new { + true => self.shared.fetch_add(new - self.size, Ordering::Relaxed), + false => self.shared.fetch_sub(self.size - new, Ordering::Relaxed), + }; + self.size = new; + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::Buffer; + + #[test] + fn test_memory_pool() { + let pool = TrackingMemoryPool::default(); + let b1 = Buffer::from(vec![0_i64, 1, 2]); + let b2 = Buffer::from(vec![3_u16, 4, 5]); + + let buffers = [b1.clone(), b1.slice(12), b1.clone(), b2.clone()]; + buffers.iter().for_each(|x| x.claim(&pool)); + + assert_eq!(pool.allocated(), b1.capacity() + b2.capacity()); + drop(buffers); + assert_eq!(pool.allocated(), b1.capacity() + b2.capacity()); + drop(b2); + assert_eq!(pool.allocated(), b1.capacity()); + drop(b1); + assert_eq!(pool.allocated(), 0); + } +}