diff --git a/Cargo.toml b/Cargo.toml index 02a44d46..19e794e4 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -49,12 +49,14 @@ nix = "0.29.0" once_cell = "1.18.0" os_pipe = "1.1.4" paste = "1.0.14" +rand = "0.8.5" slab = "0.4.9" socket2 = "0.5.6" tempfile = "3.8.1" tokio = "1.33.0" widestring = "1.0.2" windows-sys = "0.52.0" +pin-project-lite = "0.2.14" [profile.bench] debug = true diff --git a/compio-driver/Cargo.toml b/compio-driver/Cargo.toml index b7daf11e..61d01848 100644 --- a/compio-driver/Cargo.toml +++ b/compio-driver/Cargo.toml @@ -42,6 +42,7 @@ socket2 = { workspace = true } [target.'cfg(windows)'.dependencies] aligned-array = "1.0.1" once_cell = { workspace = true } +pin-project-lite = { workspace = true } windows-sys = { workspace = true, features = [ "Win32_Foundation", "Win32_Networking_WinSock", @@ -58,12 +59,16 @@ windows-sys = { workspace = true, features = [ # Linux specific dependencies [target.'cfg(target_os = "linux")'.dependencies] io-uring = { version = "0.6.2", optional = true } +io_uring_buf_ring = { version = "0.1.0", optional = true } polling = { version = "3.3.0", optional = true } os_pipe = { workspace = true, optional = true } paste = { workspace = true } +pin-project-lite = { workspace = true } +slab = { version = "0.4.9", optional = true } # Other platform dependencies [target.'cfg(all(not(target_os = "linux"), unix))'.dependencies] +pin-project-lite = { workspace = true } polling = "3.3.0" os_pipe = { workspace = true } @@ -77,11 +82,13 @@ compio-buf = { workspace = true, features = ["arrayvec"] } [features] default = ["io-uring"] +io-uring = ["dep:io-uring"] polling = ["dep:polling", "dep:os_pipe"] io-uring-sqe128 = [] io-uring-cqe32 = [] io-uring-socket = [] +io-uring-buf-ring = ["dep:io_uring_buf_ring", "dep:slab"] iocp-global = [] diff --git a/compio-driver/src/fallback_buffer_pool.rs b/compio-driver/src/fallback_buffer_pool.rs new file mode 100644 index 00000000..f274bddc --- /dev/null +++ b/compio-driver/src/fallback_buffer_pool.rs @@ -0,0 +1,117 @@ +use std::{ + borrow::{Borrow, BorrowMut}, + cell::RefCell, + collections::VecDeque, + fmt::{Debug, Formatter}, + mem::ManuallyDrop, + ops::{Deref, DerefMut}, +}; + +use compio_buf::{IntoInner, Slice}; + +/// Buffer pool +/// +/// A buffer pool to allow user no need to specify a specific buffer to do the +/// IO operation +pub struct BufferPool { + buffers: RefCell>>, +} + +impl Debug for BufferPool { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + f.debug_struct("BufferPool").finish_non_exhaustive() + } +} + +impl BufferPool { + pub(crate) fn new(buffer_len: u16, buffer_size: usize) -> Self { + let buffers = (0..buffer_len.next_power_of_two()) + .map(|_| Vec::with_capacity(buffer_size)) + .collect(); + + Self { + buffers: RefCell::new(buffers), + } + } + + pub(crate) fn get_buffer(&self) -> Option> { + self.buffers.borrow_mut().pop_front() + } + + pub(crate) fn add_buffer(&self, mut buffer: Vec) { + buffer.clear(); + self.buffers.borrow_mut().push_back(buffer) + } +} + +/// Buffer borrowed from buffer pool +/// +/// When IO operation finish, user will obtain a `BorrowedBuffer` to access the +/// filled data +pub struct BorrowedBuffer<'a> { + buffer: ManuallyDrop>>, + pool: &'a BufferPool, +} + +impl<'a> BorrowedBuffer<'a> { + pub(crate) fn new(buffer: Slice>, pool: &'a BufferPool) -> Self { + Self { + buffer: ManuallyDrop::new(buffer), + pool, + } + } +} + +impl Debug for BorrowedBuffer<'_> { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + f.debug_struct("BorrowedBuffer").finish_non_exhaustive() + } +} + +impl Drop for BorrowedBuffer<'_> { + fn drop(&mut self) { + let buffer = unsafe { + // Safety: we won't take self.buffer again + ManuallyDrop::take(&mut self.buffer) + }; + self.pool.add_buffer(buffer.into_inner()); + } +} + +impl Deref for BorrowedBuffer<'_> { + type Target = [u8]; + + fn deref(&self) -> &Self::Target { + self.buffer.deref() + } +} + +impl DerefMut for BorrowedBuffer<'_> { + fn deref_mut(&mut self) -> &mut Self::Target { + self.buffer.deref_mut() + } +} + +impl AsRef<[u8]> for BorrowedBuffer<'_> { + fn as_ref(&self) -> &[u8] { + self.deref() + } +} + +impl AsMut<[u8]> for BorrowedBuffer<'_> { + fn as_mut(&mut self) -> &mut [u8] { + self.deref_mut() + } +} + +impl Borrow<[u8]> for BorrowedBuffer<'_> { + fn borrow(&self) -> &[u8] { + self.deref() + } +} + +impl BorrowMut<[u8]> for BorrowedBuffer<'_> { + fn borrow_mut(&mut self) -> &mut [u8] { + self.deref_mut() + } +} diff --git a/compio-driver/src/fusion/buffer_pool.rs b/compio-driver/src/fusion/buffer_pool.rs new file mode 100644 index 00000000..7d47e3b1 --- /dev/null +++ b/compio-driver/src/fusion/buffer_pool.rs @@ -0,0 +1,145 @@ +use std::{ + borrow::{Borrow, BorrowMut}, + fmt::{Debug, Formatter}, + ops::{Deref, DerefMut}, +}; + +/// Buffer pool +/// +/// A buffer pool to allow user no need to specify a specific buffer to do the +/// IO operation +pub struct BufferPool { + inner: BufferPollInner, +} + +impl Debug for BufferPool { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + f.debug_struct("BufferPool").finish_non_exhaustive() + } +} + +impl BufferPool { + pub(crate) fn new_io_uring(buffer_pool: super::iour::buffer_pool::BufferPool) -> Self { + Self { + inner: BufferPollInner::IoUring(buffer_pool), + } + } + + pub(crate) fn as_io_uring(&self) -> &super::iour::buffer_pool::BufferPool { + match &self.inner { + BufferPollInner::IoUring(inner) => inner, + BufferPollInner::Poll(_) => panic!("BufferPool type is not poll type"), + } + } + + pub(crate) fn as_poll(&self) -> &crate::fallback_buffer_pool::BufferPool { + match &self.inner { + BufferPollInner::Poll(inner) => inner, + BufferPollInner::IoUring(_) => panic!("BufferPool type is not io-uring type"), + } + } + + pub(crate) fn new_poll(buffer_pool: crate::fallback_buffer_pool::BufferPool) -> Self { + Self { + inner: BufferPollInner::Poll(buffer_pool), + } + } + + pub(crate) fn into_poll(self) -> crate::fallback_buffer_pool::BufferPool { + match self.inner { + BufferPollInner::IoUring(_) => { + panic!("BufferPool type is not io-uring type") + } + BufferPollInner::Poll(inner) => inner, + } + } + + pub(crate) fn into_io_uring(self) -> super::iour::buffer_pool::BufferPool { + match self.inner { + BufferPollInner::IoUring(inner) => inner, + BufferPollInner::Poll(_) => panic!("BufferPool type is not poll type"), + } + } +} + +enum BufferPollInner { + IoUring(super::iour::buffer_pool::BufferPool), + Poll(crate::fallback_buffer_pool::BufferPool), +} + +/// Buffer borrowed from buffer pool +/// +/// When IO operation finish, user will obtain a `BorrowedBuffer` to access the +/// filled data +pub struct BorrowedBuffer<'a> { + inner: BorrowedBufferInner<'a>, +} + +impl<'a> BorrowedBuffer<'a> { + pub(crate) fn new_io_uring(buffer: super::iour::buffer_pool::BorrowedBuffer<'a>) -> Self { + Self { + inner: BorrowedBufferInner::IoUring(buffer), + } + } + + pub(crate) fn new_poll(buffer: crate::fallback_buffer_pool::BorrowedBuffer<'a>) -> Self { + Self { + inner: BorrowedBufferInner::Poll(buffer), + } + } +} + +enum BorrowedBufferInner<'a> { + IoUring(super::iour::buffer_pool::BorrowedBuffer<'a>), + Poll(crate::fallback_buffer_pool::BorrowedBuffer<'a>), +} + +impl Debug for BorrowedBuffer<'_> { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + f.debug_struct("BorrowedBuffer").finish_non_exhaustive() + } +} + +impl Deref for BorrowedBuffer<'_> { + type Target = [u8]; + + fn deref(&self) -> &Self::Target { + match &self.inner { + BorrowedBufferInner::IoUring(inner) => inner.deref(), + BorrowedBufferInner::Poll(inner) => inner.deref(), + } + } +} + +impl DerefMut for BorrowedBuffer<'_> { + fn deref_mut(&mut self) -> &mut Self::Target { + match &mut self.inner { + BorrowedBufferInner::IoUring(inner) => inner.deref_mut(), + BorrowedBufferInner::Poll(inner) => inner.deref_mut(), + } + } +} + +impl AsRef<[u8]> for BorrowedBuffer<'_> { + fn as_ref(&self) -> &[u8] { + self.deref() + } +} + +impl AsMut<[u8]> for BorrowedBuffer<'_> { + fn as_mut(&mut self) -> &mut [u8] { + self.deref_mut() + } +} + +impl Borrow<[u8]> for BorrowedBuffer<'_> { + fn borrow(&self) -> &[u8] { + self.deref() + } +} + +impl BorrowMut<[u8]> for BorrowedBuffer<'_> { + fn borrow_mut(&mut self) -> &mut [u8] { + self.deref_mut() + } +} diff --git a/compio-driver/src/fusion/mod.rs b/compio-driver/src/fusion/mod.rs index b621e909..d633cd87 100644 --- a/compio-driver/src/fusion/mod.rs +++ b/compio-driver/src/fusion/mod.rs @@ -4,6 +4,7 @@ mod poll; #[path = "../iour/mod.rs"] mod iour; +pub(crate) mod buffer_pool; pub(crate) mod op; #[cfg_attr(all(doc, docsrs), doc(cfg(all())))] @@ -15,7 +16,7 @@ pub(crate) use iour::{sockaddr_storage, socklen_t}; pub use iour::{OpCode as IourOpCode, OpEntry}; pub use poll::{Decision, OpCode as PollOpCode}; -use crate::{Key, OutEntries, ProactorBuilder}; +use crate::{BufferPool, Key, OutEntries, ProactorBuilder}; mod driver_type { use std::sync::atomic::{AtomicU8, Ordering}; @@ -180,6 +181,31 @@ impl Driver { }; Ok(NotifyHandle::from_fuse(fuse)) } + + pub fn create_buffer_pool( + &mut self, + buffer_len: u16, + buffer_size: usize, + ) -> io::Result { + match &mut self.fuse { + FuseDriver::IoUring(driver) => Ok(BufferPool::new_io_uring( + driver.create_buffer_pool(buffer_len, buffer_size)?, + )), + FuseDriver::Poll(driver) => Ok(BufferPool::new_poll( + driver.create_buffer_pool(buffer_len, buffer_size)?, + )), + } + } + + /// # Safety + /// + /// caller must make sure release the buffer pool with correct driver + pub unsafe fn release_buffer_pool(&mut self, buffer_pool: BufferPool) -> io::Result<()> { + match &mut self.fuse { + FuseDriver::Poll(driver) => driver.release_buffer_pool(buffer_pool.into_poll()), + FuseDriver::IoUring(driver) => driver.release_buffer_pool(buffer_pool.into_io_uring()), + } + } } impl AsRawFd for Driver { diff --git a/compio-driver/src/fusion/op.rs b/compio-driver/src/fusion/op.rs index 6ec78845..a6d69602 100644 --- a/compio-driver/src/fusion/op.rs +++ b/compio-driver/src/fusion/op.rs @@ -1,11 +1,14 @@ -use std::ffi::CString; +use std::{ffi::CString, io}; use compio_buf::{IntoInner, IoBuf, IoBufMut, IoVectoredBuf, IoVectoredBufMut}; use socket2::SockAddr; -use super::*; +use super::{ + buffer_pool::{BorrowedBuffer, BufferPool}, + *, +}; pub use crate::unix::op::*; -use crate::SharedFd; +use crate::{SharedFd, TakeBuffer}; macro_rules! op { (<$($ty:ident: $trait:ident),* $(,)?> $name:ident( $($arg:ident: $arg_t:ty),* $(,)? )) => { @@ -87,6 +90,98 @@ macro_rules! op { }; } +macro_rules! buffer_pool_op { + (<$($ty:ident: $trait:ident),* $(,)?> $name:ident( $($arg:ident: $arg_t:ty),* $(,)? )) => { + ::paste::paste!{ + enum [< $name Inner >] <$($ty: $trait),*> { + Poll(poll::$name<$($ty),*>), + IoUring(iour::$name<$($ty),*>), + } + + impl<$($ty: $trait),*> [< $name Inner >]<$($ty),*> { + fn poll(&mut self) -> &mut poll::$name<$($ty),*> { + debug_assert!(DriverType::current() == DriverType::Poll); + + match self { + Self::Poll(ref mut op) => op, + Self::IoUring(_) => unreachable!("Current driver is not `io-uring`"), + } + } + + fn iour(&mut self) -> &mut iour::$name<$($ty),*> { + debug_assert!(DriverType::current() == DriverType::IoUring); + + match self { + Self::IoUring(ref mut op) => op, + Self::Poll(_) => unreachable!("Current driver is not `polling`"), + } + } + } + + #[doc = concat!("A fused `", stringify!($name), "` operation")] + pub struct $name <$($ty: $trait),*> { + inner: [< $name Inner >] <$($ty),*> + } + + impl<$($ty: $trait),*> TakeBuffer for $name <$($ty),*> { + type BufferPool = BufferPool; + type Buffer<'a> = BorrowedBuffer<'a>; + + fn take_buffer( + self, + buffer_pool: &Self::BufferPool, + result: io::Result, + flags: u32, + ) -> io::Result> { + match self.inner { + [< $name Inner >]::Poll(inner) => { + Ok(BorrowedBuffer::new_poll(inner.take_buffer(buffer_pool.as_poll(), result, flags)?)) + } + [< $name Inner >]::IoUring(inner) => { + Ok(BorrowedBuffer::new_io_uring(inner.take_buffer(buffer_pool.as_io_uring(), result, flags)?)) + } + } + } + } + + impl<$($ty: $trait),*> $name <$($ty),*> { + #[doc = concat!("Create a new `", stringify!($name), "`.")] + pub fn new(buffer_pool: &BufferPool, $($arg: $arg_t),*) -> io::Result { + let this = match DriverType::current() { + DriverType::Poll => Self { + inner: [< $name Inner >]::Poll(poll::$name::new(buffer_pool.as_poll(), $($arg),*)?), + }, + DriverType::IoUring => Self { + inner: [< $name Inner >]::IoUring(iour::$name::new(buffer_pool.as_io_uring(), $($arg),*)?), + }, + }; + + Ok(this) + } + } + } + + impl<$($ty: $trait),*> poll::OpCode for $name<$($ty),*> { + fn pre_submit(self: std::pin::Pin<&mut Self>) -> std::io::Result { + unsafe { self.map_unchecked_mut(|x| x.inner.poll() ) }.pre_submit() + } + + fn on_event( + self: std::pin::Pin<&mut Self>, + event: &polling::Event, + ) -> std::task::Poll> { + unsafe { self.map_unchecked_mut(|x| x.inner.poll() ) }.on_event(event) + } + } + + impl<$($ty: $trait),*> iour::OpCode for $name<$($ty),*> { + fn create_entry(self: std::pin::Pin<&mut Self>) -> OpEntry { + unsafe { self.map_unchecked_mut(|x| x.inner.iour() ) }.create_entry() + } + } + }; +} + #[rustfmt::skip] mod iour { pub use crate::sys::iour::{op::*, OpCode}; } #[rustfmt::skip] @@ -98,3 +193,6 @@ op!( RecvFromVectored(fd: SharedFd, buffer: op!( SendToVectored(fd: SharedFd, buffer: T, addr: SockAddr)); op!( FileStat(fd: SharedFd)); op!(<> PathStat(path: CString, follow_symlink: bool)); + +buffer_pool_op!( RecvBufferPool(fd: SharedFd, len: u32)); +buffer_pool_op!( ReadAtBufferPool(fd: SharedFd, offset: u64, len: u32)); diff --git a/compio-driver/src/iocp/mod.rs b/compio-driver/src/iocp/mod.rs index 3b9f0fa8..37528d59 100644 --- a/compio-driver/src/iocp/mod.rs +++ b/compio-driver/src/iocp/mod.rs @@ -28,7 +28,7 @@ use windows_sys::Win32::{ }, }; -use crate::{syscall, AsyncifyPool, Entry, Key, OutEntries, ProactorBuilder}; +use crate::{syscall, AsyncifyPool, BufferPool, Entry, Key, OutEntries, ProactorBuilder}; pub(crate) mod op; @@ -320,6 +320,21 @@ impl Driver { self.notify_overlapped.clone(), )) } + + pub fn create_buffer_pool( + &mut self, + buffer_len: u16, + buffer_size: usize, + ) -> io::Result { + Ok(BufferPool::new(buffer_len, buffer_size)) + } + + /// # Safety + /// + /// caller must make sure release the buffer pool with correct driver + pub unsafe fn release_buffer_pool(&mut self, _: BufferPool) -> io::Result<()> { + Ok(()) + } } impl AsRawFd for Driver { diff --git a/compio-driver/src/iocp/op.rs b/compio-driver/src/iocp/op.rs index a725709e..b1890ab9 100644 --- a/compio-driver/src/iocp/op.rs +++ b/compio-driver/src/iocp/op.rs @@ -2,6 +2,7 @@ use std::sync::OnceLock; use std::{ io, + io::ErrorKind, marker::PhantomPinned, net::Shutdown, os::windows::io::AsRawSocket, @@ -11,9 +12,12 @@ use std::{ }; use aligned_array::{Aligned, A8}; -use compio_buf::{BufResult, IntoInner, IoBuf, IoBufMut, IoVectoredBuf, IoVectoredBufMut}; +use compio_buf::{ + BufResult, IntoInner, IoBuf, IoBufMut, IoVectoredBuf, IoVectoredBufMut, SetBufInit, Slice, +}; #[cfg(not(feature = "once_cell_try"))] use once_cell::sync::OnceCell as OnceLock; +use pin_project_lite::pin_project; use socket2::SockAddr; use windows_sys::{ core::GUID, @@ -38,7 +42,10 @@ use windows_sys::{ }, }; -use crate::{op::*, syscall, AsRawFd, OpCode, OpType, RawFd, SharedFd}; +use crate::{ + op::*, syscall, AsRawFd, BorrowedBuffer, BufferPool, OpCode, OpType, RawFd, SharedFd, + TakeBuffer, +}; #[inline] fn winapi_result(transferred: u32) -> Poll> { @@ -164,6 +171,69 @@ impl OpCode for ReadAt { } } +pin_project! { + /// Read a file at specified position into specified buffer. + pub struct ReadAtBufferPool { + #[pin] + read_at: ReadAt>, S>, + } +} + +impl ReadAtBufferPool { + /// Create [`ReadAtBufferPool`]. + pub fn new( + buffer_pool: &BufferPool, + fd: SharedFd, + offset: u64, + len: u32, + ) -> io::Result { + let buffer = buffer_pool.get_buffer().ok_or_else(|| { + io::Error::new(ErrorKind::Other, "buffer ring has no available buffer") + })?; + let len = if len == 0 { + buffer.capacity() + } else { + buffer.capacity().min(len as _) + }; + + Ok(Self { + read_at: ReadAt::new(fd, offset, buffer.slice(..len)), + }) + } +} + +impl OpCode for ReadAtBufferPool { + unsafe fn operate(self: Pin<&mut Self>, optr: *mut OVERLAPPED) -> Poll> { + self.project().read_at.operate(optr) + } + + unsafe fn cancel(self: Pin<&mut Self>, optr: *mut OVERLAPPED) -> io::Result<()> { + self.project().read_at.cancel(optr) + } +} + +impl TakeBuffer for ReadAtBufferPool { + type Buffer<'a> = BorrowedBuffer<'a>; + type BufferPool = BufferPool; + + fn take_buffer( + self, + buffer_pool: &BufferPool, + result: io::Result, + _: u32, + ) -> io::Result { + let n = result?; + let mut slice = self.read_at.into_inner(); + + // Safety: n is valid + unsafe { + slice.set_buf_init(n); + } + + Ok(BorrowedBuffer::new(slice, buffer_pool)) + } +} + impl OpCode for WriteAt { unsafe fn operate(self: Pin<&mut Self>, optr: *mut OVERLAPPED) -> Poll> { if let Some(overlapped) = optr.as_mut() { @@ -419,6 +489,64 @@ impl OpCode for Recv { } } +pin_project! { + /// Receive data from remote. + pub struct RecvBufferPool { + #[pin] + recv: Recv>, S>, + } +} + +impl RecvBufferPool { + /// Create [`RecvBufferPool`]. + pub fn new(buffer_pool: &BufferPool, fd: SharedFd, len: u32) -> io::Result { + let buffer = buffer_pool.get_buffer().ok_or_else(|| { + io::Error::new(ErrorKind::Other, "buffer ring has no available buffer") + })?; + let len = if len == 0 { + buffer.capacity() + } else { + buffer.capacity().min(len as _) + }; + + Ok(Self { + recv: Recv::new(fd, buffer.slice(..len)), + }) + } +} + +impl OpCode for RecvBufferPool { + unsafe fn operate(self: Pin<&mut Self>, optr: *mut OVERLAPPED) -> Poll> { + self.project().recv.operate(optr) + } + + unsafe fn cancel(self: Pin<&mut Self>, optr: *mut OVERLAPPED) -> io::Result<()> { + self.project().recv.cancel(optr) + } +} + +impl TakeBuffer for RecvBufferPool { + type Buffer<'a> = BorrowedBuffer<'a>; + type BufferPool = BufferPool; + + fn take_buffer( + self, + buffer_pool: &BufferPool, + result: io::Result, + _: u32, + ) -> io::Result { + let n = result?; + let mut slice = self.recv.into_inner(); + + // Safety: n is valid + unsafe { + slice.set_buf_init(n); + } + + Ok(BorrowedBuffer::new(slice, buffer_pool)) + } +} + /// Receive data from remote into vectored buffer. pub struct RecvVectored { pub(crate) fd: SharedFd, diff --git a/compio-driver/src/iour/buffer_pool.rs b/compio-driver/src/iour/buffer_pool.rs new file mode 100644 index 00000000..6cb61789 --- /dev/null +++ b/compio-driver/src/iour/buffer_pool.rs @@ -0,0 +1,104 @@ +use std::{ + borrow::{Borrow, BorrowMut}, + fmt::{Debug, Formatter}, + ops::{Deref, DerefMut}, +}; + +use io_uring::cqueue::buffer_select; +use io_uring_buf_ring::IoUringBufRing; + +/// Buffer pool +/// +/// A buffer pool to allow user no need to specify a specific buffer to do the +/// IO operation +pub struct BufferPool { + buf_ring: IoUringBufRing>, +} + +impl Debug for BufferPool { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + f.debug_struct("BufferPool").finish_non_exhaustive() + } +} + +impl BufferPool { + pub(crate) fn new(buf_ring: IoUringBufRing>) -> Self { + Self { buf_ring } + } + + pub(crate) fn buffer_group(&self) -> u16 { + self.buf_ring.buffer_group() + } + + pub(crate) fn into_inner(self) -> IoUringBufRing> { + self.buf_ring + } + + pub(crate) unsafe fn get_buffer( + &self, + flags: u32, + available_len: usize, + ) -> Option { + let buffer_id = buffer_select(flags)?; + + self.buf_ring + .get_buf(buffer_id, available_len) + .map(BorrowedBuffer) + } + + pub(crate) unsafe fn reuse_buffer(&self, flags: u32) { + if let Some(buffer_id) = buffer_select(flags) { + self.buf_ring.get_buf(buffer_id, 0).map(BorrowedBuffer); + } + } +} + +/// Buffer borrowed from buffer pool +/// +/// When IO operation finish, user will obtain a `BorrowedBuffer` to access the +/// filled data +pub struct BorrowedBuffer<'a>(io_uring_buf_ring::BorrowedBuffer<'a, Vec>); + +impl Debug for BorrowedBuffer<'_> { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + f.debug_struct("BorrowedBuffer").finish_non_exhaustive() + } +} + +impl Deref for BorrowedBuffer<'_> { + type Target = [u8]; + + fn deref(&self) -> &Self::Target { + self.0.deref() + } +} + +impl DerefMut for BorrowedBuffer<'_> { + fn deref_mut(&mut self) -> &mut Self::Target { + self.0.deref_mut() + } +} + +impl AsRef<[u8]> for BorrowedBuffer<'_> { + fn as_ref(&self) -> &[u8] { + self.deref() + } +} + +impl AsMut<[u8]> for BorrowedBuffer<'_> { + fn as_mut(&mut self) -> &mut [u8] { + self.deref_mut() + } +} + +impl Borrow<[u8]> for BorrowedBuffer<'_> { + fn borrow(&self) -> &[u8] { + self.deref() + } +} + +impl BorrowMut<[u8]> for BorrowedBuffer<'_> { + fn borrow_mut(&mut self) -> &mut [u8] { + self.deref_mut() + } +} diff --git a/compio-driver/src/iour/mod.rs b/compio-driver/src/iour/mod.rs index 67670dae..6945b920 100644 --- a/compio-driver/src/iour/mod.rs +++ b/compio-driver/src/iour/mod.rs @@ -1,7 +1,7 @@ #[cfg_attr(all(doc, docsrs), doc(cfg(all())))] #[allow(unused_imports)] pub use std::os::fd::{AsRawFd, OwnedFd, RawFd}; -use std::{io, os::fd::FromRawFd, pin::Pin, sync::Arc, task::Poll, time::Duration}; +use std::{io, io::ErrorKind, os::fd::FromRawFd, pin::Pin, sync::Arc, task::Poll, time::Duration}; use compio_log::{instrument, trace, warn}; use crossbeam_queue::SegQueue; @@ -25,12 +25,26 @@ use io_uring::{ types::{Fd, SubmitArgs, Timespec}, IoUring, }; +#[cfg(feature = "io-uring-buf-ring")] +use io_uring_buf_ring::IoUringBufRing; pub(crate) use libc::{sockaddr_storage, socklen_t}; +#[cfg(feature = "io-uring-buf-ring")] +use slab::Slab; use crate::{syscall, AsyncifyPool, Entry, Key, OutEntries, ProactorBuilder}; +#[cfg(feature = "io-uring-buf-ring")] +pub(crate) mod buffer_pool; pub(crate) mod op; +#[cfg(feature = "io-uring-buf-ring")] +use buffer_pool::BufferPool; + +#[cfg(not(feature = "io-uring-buf-ring"))] +pub(crate) use crate::fallback_buffer_pool as buffer_pool; +#[cfg(not(feature = "io-uring-buf-ring"))] +use crate::fallback_buffer_pool::BufferPool; + /// The created entry of [`OpCode`]. pub enum OpEntry { /// This operation creates an io-uring submission entry. @@ -73,6 +87,9 @@ pub(crate) struct Driver { notifier: Notifier, pool: AsyncifyPool, pool_completed: Arc>, + #[cfg(feature = "io-uring-buf-ring")] + // buffer group id type is u16, we should reuse the buffer group id when BufferPool is drop + buffer_group_ids: Slab<()>, } impl Driver { @@ -106,6 +123,8 @@ impl Driver { notifier, pool: builder.create_or_get_thread_pool(), pool_completed: Arc::new(SegQueue::new()), + #[cfg(feature = "io-uring-buf-ring")] + buffer_group_ids: Default::default(), }) } @@ -126,14 +145,14 @@ impl Driver { match res { Ok(_) => { if self.inner.completion().is_empty() { - Err(io::ErrorKind::TimedOut.into()) + Err(ErrorKind::TimedOut.into()) } else { Ok(()) } } Err(e) => match e.raw_os_error() { - Some(libc::ETIME) => Err(io::ErrorKind::TimedOut.into()), - Some(libc::EBUSY) | Some(libc::EAGAIN) => Err(io::ErrorKind::Interrupted.into()), + Some(libc::ETIME) => Err(ErrorKind::TimedOut.into()), + Some(libc::EBUSY) | Some(libc::EAGAIN) => Err(ErrorKind::Interrupted.into()), _ => Err(e), }, } @@ -274,6 +293,57 @@ impl Driver { pub fn handle(&self) -> io::Result { self.notifier.handle() } + + #[cfg(feature = "io-uring-buf-ring")] + pub fn create_buffer_pool( + &mut self, + buffer_len: u16, + buffer_size: usize, + ) -> io::Result { + let buffer_group = self.buffer_group_ids.insert(()); + if buffer_group > u16::MAX as usize { + self.buffer_group_ids.remove(buffer_group); + + return Err(io::Error::new( + ErrorKind::OutOfMemory, + "too many buffer pool allocated", + )); + } + + let buf_ring = + IoUringBufRing::new(&self.inner, buffer_len, buffer_group as _, buffer_size)?; + + Ok(BufferPool::new(buf_ring)) + } + + #[cfg(not(feature = "io-uring-buf-ring"))] + pub fn create_buffer_pool( + &mut self, + buffer_len: u16, + buffer_size: usize, + ) -> io::Result { + Ok(BufferPool::new(buffer_len, buffer_size)) + } + + #[cfg(feature = "io-uring-buf-ring")] + /// # Safety + /// + /// caller must make sure release the buffer pool with correct driver + pub unsafe fn release_buffer_pool(&mut self, buffer_pool: BufferPool) -> io::Result<()> { + let buffer_group = buffer_pool.buffer_group(); + buffer_pool.into_inner().release(&self.inner)?; + self.buffer_group_ids.remove(buffer_group as _); + + Ok(()) + } + + #[cfg(not(feature = "io-uring-buf-ring"))] + /// # Safety + /// + /// caller must make sure release the buffer pool with correct driver + pub unsafe fn release_buffer_pool(&mut self, _: BufferPool) -> io::Result<()> { + Ok(()) + } } impl AsRawFd for Driver { @@ -333,9 +403,9 @@ impl Notifier { break Ok(()); } // Clear the next time:) - Err(e) if e.kind() == io::ErrorKind::WouldBlock => break Ok(()), + Err(e) if e.kind() == ErrorKind::WouldBlock => break Ok(()), // Just like read_exact - Err(e) if e.kind() == io::ErrorKind::Interrupted => continue, + Err(e) if e.kind() == ErrorKind::Interrupted => continue, Err(e) => break Err(e), } } diff --git a/compio-driver/src/iour/op.rs b/compio-driver/src/iour/op.rs index 75ba51d4..0261f565 100644 --- a/compio-driver/src/iour/op.rs +++ b/compio-driver/src/iour/op.rs @@ -1,8 +1,12 @@ use std::{ffi::CString, io, marker::PhantomPinned, os::fd::AsRawFd, pin::Pin}; +#[cfg(feature = "io-uring-buf-ring")] +pub use buf_ring_op::{ReadAtBufferPool, RecvBufferPool}; use compio_buf::{ BufResult, IntoInner, IoBuf, IoBufMut, IoSlice, IoSliceMut, IoVectoredBuf, IoVectoredBufMut, }; +#[cfg(not(feature = "io-uring-buf-ring"))] +pub use fallback_op::{ReadAtBufferPool, RecvBufferPool}; use io_uring::{ opcode, types::{Fd, FsyncFlags}, @@ -23,7 +27,7 @@ impl< OpEntry::Blocking } - fn call_blocking(self: Pin<&mut Self>) -> std::io::Result { + fn call_blocking(self: Pin<&mut Self>) -> io::Result { // Safety: self won't be moved let this = unsafe { self.get_unchecked_mut() }; let f = this @@ -567,3 +571,263 @@ impl OpCode for PollOnce { .into() } } + +#[cfg(feature = "io-uring-buf-ring")] +mod buf_ring_op { + use std::{io, io::ErrorKind, marker::PhantomPinned, os::fd::AsRawFd, pin::Pin, ptr}; + + use io_uring::{opcode, squeue::Flags, types::Fd}; + + use super::{ + super::buffer_pool::{BorrowedBuffer, BufferPool}, + OpCode, OpEntry, + }; + use crate::{SharedFd, TakeBuffer}; + + /// Receive data from remote. + pub struct RecvBufferPool { + fd: SharedFd, + buffer_group: u16, + len: u32, + _p: PhantomPinned, + } + + impl RecvBufferPool { + /// Create [`RecvBufferPool`]. + pub fn new(buffer_pool: &BufferPool, fd: SharedFd, len: u32) -> io::Result { + Ok(Self { + fd, + buffer_group: buffer_pool.buffer_group(), + len, + _p: PhantomPinned, + }) + } + } + + impl OpCode for RecvBufferPool { + fn create_entry(self: Pin<&mut Self>) -> OpEntry { + let fd = self.fd.as_raw_fd(); + opcode::Read::new(Fd(fd), ptr::null_mut(), self.len) + .buf_group(self.buffer_group) + .build() + .flags(Flags::BUFFER_SELECT) + .into() + } + } + + impl TakeBuffer for RecvBufferPool { + type Buffer<'a> = BorrowedBuffer<'a>; + type BufferPool = BufferPool; + + fn take_buffer( + self, + buffer_pool: &Self::BufferPool, + result: io::Result, + flags: u32, + ) -> io::Result> { + let n = result.inspect_err(|_| unsafe { + // Safety: flags is valid + buffer_pool.reuse_buffer(flags) + })?; + + unsafe { + // Safety: flags is valid + buffer_pool.get_buffer(flags, n).ok_or_else(|| { + io::Error::new(ErrorKind::InvalidInput, format!("flags {flags} is invalid")) + }) + } + } + } + + /// Read a file at specified position into specified buffer. + #[derive(Debug)] + pub struct ReadAtBufferPool { + pub(crate) fd: SharedFd, + pub(crate) offset: u64, + buffer_group: u16, + len: u32, + _p: PhantomPinned, + } + + impl ReadAtBufferPool { + /// Create [`ReadAtBufferPool`]. + pub fn new( + buffer_pool: &BufferPool, + fd: SharedFd, + offset: u64, + len: u32, + ) -> io::Result { + Ok(Self { + fd, + offset, + buffer_group: buffer_pool.buffer_group(), + len, + _p: PhantomPinned, + }) + } + } + + impl OpCode for ReadAtBufferPool { + fn create_entry(self: Pin<&mut Self>) -> OpEntry { + let fd = Fd(self.fd.as_raw_fd()); + let offset = self.offset; + opcode::Read::new(fd, ptr::null_mut(), self.len) + .offset(offset) + .buf_group(self.buffer_group) + .build() + .flags(Flags::BUFFER_SELECT) + .into() + } + } + + impl TakeBuffer for ReadAtBufferPool { + type Buffer<'a> = BorrowedBuffer<'a>; + type BufferPool = BufferPool; + + fn take_buffer( + self, + buffer_pool: &Self::BufferPool, + result: io::Result, + flags: u32, + ) -> io::Result> { + let n = result.inspect_err(|_| unsafe { + // Safety: flags is valid + buffer_pool.reuse_buffer(flags) + })?; + + unsafe { + // Safety: flags is valid + buffer_pool.get_buffer(flags, n).ok_or_else(|| { + io::Error::new(ErrorKind::InvalidInput, format!("flags {flags} is invalid")) + }) + } + } + } +} + +#[cfg(not(feature = "io-uring-buf-ring"))] +mod fallback_op { + use std::{io, io::ErrorKind, os::fd::AsRawFd, pin::Pin}; + + use compio_buf::{IntoInner, IoBuf, SetBufInit, Slice}; + use pin_project_lite::pin_project; + + use super::{OpCode, ReadAt, Recv}; + use crate::{ + fallback_buffer_pool::{BorrowedBuffer, BufferPool}, + OpEntry, SharedFd, TakeBuffer, + }; + + pin_project! { + /// Receive data from remote. + pub struct RecvBufferPool { + #[pin] + recv: Recv>, S>, + } + } + + impl RecvBufferPool { + /// Create [`RecvBufferPool`]. + pub fn new(buffer_pool: &BufferPool, fd: SharedFd, len: u32) -> io::Result { + let buffer = buffer_pool.get_buffer().ok_or_else(|| { + io::Error::new(ErrorKind::Other, "buffer ring has no available buffer") + })?; + let len = if len == 0 { + buffer.capacity() + } else { + buffer.capacity().min(len as _) + }; + + Ok(Self { + recv: Recv::new(fd, buffer.slice(..len)), + }) + } + } + + impl OpCode for RecvBufferPool { + fn create_entry(self: Pin<&mut Self>) -> OpEntry { + self.project().recv.create_entry() + } + } + + impl TakeBuffer for RecvBufferPool { + type Buffer<'a> = BorrowedBuffer<'a>; + type BufferPool = BufferPool; + + fn take_buffer( + self, + buffer_pool: &Self::BufferPool, + result: io::Result, + _: u32, + ) -> io::Result> { + let n = result?; + let mut slice = self.recv.into_inner(); + + // Safety: n is valid + unsafe { + slice.set_buf_init(n); + } + + Ok(BorrowedBuffer::new(slice, buffer_pool)) + } + } + + pin_project! { + /// Read a file at specified position into specified buffer. + pub struct ReadAtBufferPool { + #[pin] + read_at: ReadAt>, S>, + } + } + + impl ReadAtBufferPool { + /// Create [`ReadAtBufferPool`]. + pub fn new( + buffer_pool: &BufferPool, + fd: SharedFd, + offset: u64, + len: u32, + ) -> io::Result { + let buffer = buffer_pool.get_buffer().ok_or_else(|| { + io::Error::new(ErrorKind::Other, "buffer ring has no available buffer") + })?; + let len = if len == 0 { + buffer.capacity() + } else { + buffer.capacity().min(len as _) + }; + + Ok(Self { + read_at: ReadAt::new(fd, offset, buffer.slice(..len)), + }) + } + } + + impl OpCode for ReadAtBufferPool { + fn create_entry(self: Pin<&mut Self>) -> OpEntry { + self.project().read_at.create_entry() + } + } + + impl TakeBuffer for ReadAtBufferPool { + type Buffer<'a> = BorrowedBuffer<'a>; + type BufferPool = BufferPool; + + fn take_buffer( + self, + buffer_pool: &Self::BufferPool, + result: io::Result, + _: u32, + ) -> io::Result> { + let n = result?; + let mut slice = self.read_at.into_inner(); + + // Safety: n is valid + unsafe { + slice.set_buf_init(n); + } + + Ok(BorrowedBuffer::new(slice, buffer_pool)) + } + } +} diff --git a/compio-driver/src/lib.rs b/compio-driver/src/lib.rs index d39d8c04..ef0074e4 100644 --- a/compio-driver/src/lib.rs +++ b/compio-driver/src/lib.rs @@ -35,21 +35,31 @@ mod asyncify; pub use asyncify::*; mod fd; + pub use fd::*; cfg_if::cfg_if! { if #[cfg(windows)] { #[path = "iocp/mod.rs"] mod sys; + mod fallback_buffer_pool; + pub use fallback_buffer_pool::{BufferPool, BorrowedBuffer}; } else if #[cfg(all(target_os = "linux", feature = "polling", feature = "io-uring"))] { #[path = "fusion/mod.rs"] mod sys; + mod fallback_buffer_pool; + pub use sys::buffer_pool::{BufferPool, BorrowedBuffer}; } else if #[cfg(all(target_os = "linux", feature = "io-uring"))] { #[path = "iour/mod.rs"] mod sys; + #[cfg(not(feature = "io-uring-buf-ring"))] + mod fallback_buffer_pool; + pub use sys::buffer_pool::{BufferPool, BorrowedBuffer}; } else if #[cfg(unix)] { #[path = "poll/mod.rs"] mod sys; + mod fallback_buffer_pool; + pub use fallback_buffer_pool::{BufferPool, BorrowedBuffer}; } } @@ -311,6 +321,29 @@ impl Proactor { pub fn handle(&self) -> io::Result { self.driver.handle() } + + /// Create buffer pool with given `buffer_size` and `buffer_len` + /// + /// # Notes + /// + /// If `buffer_len` is not power of 2, it will be upward with + /// [`u16::next_power_of_two`] + pub fn create_buffer_pool( + &mut self, + buffer_len: u16, + buffer_size: usize, + ) -> io::Result { + self.driver.create_buffer_pool(buffer_len, buffer_size) + } + + /// Release the buffer pool + /// + /// # Safety + /// + /// caller must make sure release the buffer pool with correct driver + pub unsafe fn release_buffer_pool(&mut self, buffer_pool: BufferPool) -> io::Result<()> { + self.driver.release_buffer_pool(buffer_pool) + } } impl AsRawFd for Proactor { @@ -503,3 +536,21 @@ impl ProactorBuilder { Proactor::with_builder(self) } } + +/// Trait to get the selected buffer of an io operation. +pub trait TakeBuffer { + /// Selected buffer type + type Buffer<'a>; + + /// Buffer pool type + type BufferPool; + + /// Take the selected buffer with `buffer_pool`, io `result` and `flags`, if + /// io operation is success + fn take_buffer( + self, + buffer_pool: &Self::BufferPool, + result: io::Result, + flags: u32, + ) -> io::Result>; +} diff --git a/compio-driver/src/op.rs b/compio-driver/src/op.rs index b453275d..f5e77169 100644 --- a/compio-driver/src/op.rs +++ b/compio-driver/src/op.rs @@ -11,8 +11,8 @@ use socket2::SockAddr; #[cfg(windows)] pub use crate::sys::op::ConnectNamedPipe; pub use crate::sys::op::{ - Accept, Recv, RecvFrom, RecvFromVectored, RecvVectored, Send, SendTo, SendToVectored, - SendVectored, + Accept, ReadAtBufferPool, Recv, RecvBufferPool, RecvFrom, RecvFromVectored, RecvVectored, Send, + SendTo, SendToVectored, SendVectored, }; #[cfg(unix)] pub use crate::sys::op::{ diff --git a/compio-driver/src/poll/mod.rs b/compio-driver/src/poll/mod.rs index 746d764c..d2446903 100644 --- a/compio-driver/src/poll/mod.rs +++ b/compio-driver/src/poll/mod.rs @@ -17,7 +17,10 @@ use crossbeam_queue::SegQueue; pub(crate) use libc::{sockaddr_storage, socklen_t}; use polling::{Event, Events, PollMode, Poller}; -use crate::{op::Interest, syscall, AsyncifyPool, Entry, Key, OutEntries, ProactorBuilder}; +use crate::{ + fallback_buffer_pool::BufferPool, op::Interest, syscall, AsyncifyPool, Entry, Key, OutEntries, + ProactorBuilder, +}; pub(crate) mod op; @@ -298,6 +301,21 @@ impl Driver { pub fn handle(&self) -> io::Result { self.notifier.handle() } + + pub fn create_buffer_pool( + &mut self, + buffer_len: u16, + buffer_size: usize, + ) -> io::Result { + Ok(BufferPool::new(buffer_len, buffer_size)) + } + + /// # Safety + /// + /// caller must make sure release the buffer pool with correct driver + pub unsafe fn release_buffer_pool(&mut self, _: BufferPool) -> io::Result<()> { + Ok(()) + } } impl AsRawFd for Driver { diff --git a/compio-driver/src/poll/op.rs b/compio-driver/src/poll/op.rs index 3107b1d2..74b66d8c 100644 --- a/compio-driver/src/poll/op.rs +++ b/compio-driver/src/poll/op.rs @@ -1,7 +1,8 @@ -use std::{ffi::CString, io, marker::PhantomPinned, pin::Pin, task::Poll}; +use std::{ffi::CString, io, io::ErrorKind, marker::PhantomPinned, pin::Pin, task::Poll}; use compio_buf::{ BufResult, IntoInner, IoBuf, IoBufMut, IoSlice, IoSliceMut, IoVectoredBuf, IoVectoredBufMut, + SetBufInit, Slice, }; #[cfg(not(all(target_os = "linux", target_env = "gnu")))] use libc::open; @@ -11,12 +12,17 @@ use libc::open64 as open; use libc::{pread, preadv, pwrite, pwritev}; #[cfg(any(target_os = "linux", target_os = "android", target_os = "hurd"))] use libc::{pread64 as pread, preadv64 as preadv, pwrite64 as pwrite, pwritev64 as pwritev}; +use pin_project_lite::pin_project; use polling::Event; use socket2::SockAddr; use super::{sockaddr_storage, socklen_t, syscall, AsRawFd, Decision, OpCode}; pub use crate::unix::op::*; -use crate::{op::*, SharedFd}; +use crate::{ + fallback_buffer_pool::{BorrowedBuffer, BufferPool}, + op::*, + SharedFd, TakeBuffer, +}; impl< D: std::marker::Send + 'static, @@ -193,6 +199,69 @@ impl OpCode for ReadAt { } } +pin_project! { + /// Read a file at specified position into specified buffer. + pub struct ReadAtBufferPool { + #[pin] + read_at: ReadAt>, S>, + } +} + +impl ReadAtBufferPool { + /// Create [`ReadAtBufferPool`]. + pub fn new( + buffer_pool: &BufferPool, + fd: SharedFd, + offset: u64, + len: u32, + ) -> io::Result { + let buffer = buffer_pool.get_buffer().ok_or_else(|| { + io::Error::new(ErrorKind::Other, "buffer ring has no available buffer") + })?; + let len = if len == 0 { + buffer.capacity() + } else { + buffer.capacity().min(len as _) + }; + + Ok(Self { + read_at: ReadAt::new(fd, offset, buffer.slice(..len)), + }) + } +} + +impl OpCode for ReadAtBufferPool { + fn pre_submit(self: Pin<&mut Self>) -> io::Result { + self.project().read_at.pre_submit() + } + + fn on_event(self: Pin<&mut Self>, event: &Event) -> Poll> { + self.project().read_at.on_event(event) + } +} + +impl TakeBuffer for ReadAtBufferPool { + type Buffer<'a> = BorrowedBuffer<'a>; + type BufferPool = BufferPool; + + fn take_buffer( + self, + buffer_pool: &Self::BufferPool, + result: io::Result, + _: u32, + ) -> io::Result> { + let n = result?; + let mut slice = self.read_at.into_inner(); + + // Safety: n is valid + unsafe { + slice.set_buf_init(n); + } + + Ok(BorrowedBuffer::new(slice, buffer_pool)) + } +} + impl OpCode for ReadVectoredAt { fn pre_submit(self: Pin<&mut Self>) -> io::Result { Ok(Decision::blocking_readable(self.fd.as_raw_fd())) @@ -452,6 +521,64 @@ impl OpCode for Recv { } } +pin_project! { + /// Receive data from remote. + pub struct RecvBufferPool { + #[pin] + recv: Recv>, S>, + } +} + +impl RecvBufferPool { + /// Create [`RecvBufferPool`]. + pub fn new(buffer_pool: &BufferPool, fd: SharedFd, len: u32) -> io::Result { + let buffer = buffer_pool.get_buffer().ok_or_else(|| { + io::Error::new(ErrorKind::Other, "buffer ring has no available buffer") + })?; + let len = if len == 0 { + buffer.capacity() + } else { + buffer.capacity().min(len as _) + }; + + Ok(Self { + recv: Recv::new(fd, buffer.slice(..len)), + }) + } +} + +impl OpCode for RecvBufferPool { + fn pre_submit(self: Pin<&mut Self>) -> io::Result { + self.project().recv.pre_submit() + } + + fn on_event(self: Pin<&mut Self>, event: &Event) -> Poll> { + self.project().recv.on_event(event) + } +} + +impl TakeBuffer for RecvBufferPool { + type Buffer<'a> = BorrowedBuffer<'a>; + type BufferPool = BufferPool; + + fn take_buffer( + self, + buffer_pool: &Self::BufferPool, + result: io::Result, + _: u32, + ) -> io::Result> { + let n = result?; + let mut slice = self.recv.into_inner(); + + // Safety: n is valid + unsafe { + slice.set_buf_init(n); + } + + Ok(BorrowedBuffer::new(slice, buffer_pool)) + } +} + impl OpCode for RecvVectored { fn pre_submit(self: Pin<&mut Self>) -> io::Result { Ok(Decision::wait_readable(self.fd.as_raw_fd())) diff --git a/compio-fs/src/async_fd.rs b/compio-fs/src/async_fd.rs index 70d150bd..b2928d79 100644 --- a/compio-fs/src/async_fd.rs +++ b/compio-fs/src/async_fd.rs @@ -8,11 +8,14 @@ use std::os::windows::io::{ use compio_buf::{BufResult, IntoInner, IoBuf, IoBufMut}; use compio_driver::{ - op::{BufResultExt, Recv, Send}, - AsRawFd, SharedFd, ToSharedFd, + op::{BufResultExt, Recv, RecvBufferPool, Send}, + AsRawFd, SharedFd, TakeBuffer, ToSharedFd, +}; +use compio_io::{AsyncRead, AsyncReadBufferPool, AsyncWrite}; +use compio_runtime::{ + buffer_pool::{BorrowedBuffer, BufferPool}, + Attacher, }; -use compio_io::{AsyncRead, AsyncWrite}; -use compio_runtime::Attacher; #[cfg(unix)] use { compio_buf::{IoVectoredBuf, IoVectoredBufMut}, @@ -74,6 +77,36 @@ impl AsyncRead for &AsyncFd { } } +impl AsyncReadBufferPool for AsyncFd { + type Buffer<'a> = BorrowedBuffer<'a>; + type BufferPool = BufferPool; + + async fn read_buffer_pool<'a>( + &mut self, + buffer_pool: &'a Self::BufferPool, + len: usize, + ) -> io::Result> { + (&*self).read_buffer_pool(buffer_pool, len).await + } +} + +impl AsyncReadBufferPool for &AsyncFd { + type Buffer<'a> = BorrowedBuffer<'a>; + type BufferPool = BufferPool; + + async fn read_buffer_pool<'a>( + &mut self, + buffer_pool: &'a Self::BufferPool, + len: usize, + ) -> io::Result> { + let fd = self.to_shared_fd(); + let op = RecvBufferPool::new(buffer_pool.as_driver_buffer_pool(), fd, len as _)?; + let (BufResult(res, op), flags) = compio_runtime::submit_with_flags(op).await; + + op.take_buffer(buffer_pool.as_driver_buffer_pool(), res, flags) + } +} + impl AsyncWrite for AsyncFd { #[inline] async fn write(&mut self, buf: B) -> BufResult { diff --git a/compio-fs/src/file.rs b/compio-fs/src/file.rs index 01557fb0..7096c602 100644 --- a/compio-fs/src/file.rs +++ b/compio-fs/src/file.rs @@ -3,11 +3,14 @@ use std::{future::Future, io, mem::ManuallyDrop, panic::resume_unwind, path::Pat use compio_buf::{BufResult, IntoInner, IoBuf, IoBufMut}; use compio_driver::{ impl_raw_fd, - op::{BufResultExt, CloseFile, ReadAt, Sync, WriteAt}, - ToSharedFd, + op::{BufResultExt, CloseFile, ReadAt, ReadAtBufferPool, Sync, WriteAt}, + TakeBuffer, ToSharedFd, +}; +use compio_io::{AsyncReadAt, AsyncReadAtBufferPool, AsyncWriteAt}; +use compio_runtime::{ + buffer_pool::{BorrowedBuffer, BufferPool}, + Attacher, }; -use compio_io::{AsyncReadAt, AsyncWriteAt}; -use compio_runtime::Attacher; #[cfg(unix)] use { compio_buf::{IoVectoredBuf, IoVectoredBufMut}, @@ -168,6 +171,24 @@ impl AsyncReadAt for File { } } +impl AsyncReadAtBufferPool for File { + type Buffer<'a> = BorrowedBuffer<'a>; + type BufferPool = BufferPool; + + async fn read_at_buffer_pool<'a>( + &self, + buffer_pool: &'a Self::BufferPool, + pos: u64, + len: usize, + ) -> io::Result> { + let fd = self.to_shared_fd(); + let op = ReadAtBufferPool::new(buffer_pool.as_driver_buffer_pool(), fd, pos, len as _)?; + let (BufResult(res, op), flags) = compio_runtime::submit_with_flags(op).await; + + op.take_buffer(buffer_pool.as_driver_buffer_pool(), res, flags) + } +} + impl AsyncWriteAt for File { #[inline] async fn write_at(&mut self, buf: T, pos: u64) -> BufResult { diff --git a/compio-fs/src/named_pipe.rs b/compio-fs/src/named_pipe.rs index cb4f3347..9a58d257 100644 --- a/compio-fs/src/named_pipe.rs +++ b/compio-fs/src/named_pipe.rs @@ -8,7 +8,10 @@ use std::{ffi::OsStr, io, os::windows::io::FromRawHandle, ptr::null}; use compio_buf::{BufResult, IoBuf, IoBufMut}; use compio_driver::{impl_raw_fd, op::ConnectNamedPipe, syscall, AsRawFd, RawFd, ToSharedFd}; -use compio_io::{AsyncRead, AsyncReadAt, AsyncWrite, AsyncWriteAt}; +use compio_io::{ + AsyncRead, AsyncReadAt, AsyncReadAtBufferPool, AsyncReadBufferPool, AsyncWrite, AsyncWriteAt, +}; +use compio_runtime::buffer_pool::{BorrowedBuffer, BufferPool}; use widestring::U16CString; use windows_sys::Win32::{ Storage::FileSystem::{ @@ -192,6 +195,32 @@ impl AsyncRead for &NamedPipeServer { } } +impl AsyncReadBufferPool for NamedPipeServer { + type Buffer<'a> = BorrowedBuffer<'a>; + type BufferPool = BufferPool; + + async fn read_buffer_pool<'a>( + &mut self, + buffer_pool: &'a Self::BufferPool, + len: usize, + ) -> io::Result> { + (&*self).read_buffer_pool(buffer_pool, len).await + } +} + +impl AsyncReadBufferPool for &NamedPipeServer { + type Buffer<'a> = BorrowedBuffer<'a>; + type BufferPool = BufferPool; + + async fn read_buffer_pool<'a>( + &mut self, + buffer_pool: &'a Self::BufferPool, + len: usize, + ) -> io::Result> { + (&self.handle).read_buffer_pool(buffer_pool, len).await + } +} + impl AsyncWrite for NamedPipeServer { #[inline] async fn write(&mut self, buf: T) -> BufResult { @@ -312,6 +341,32 @@ impl AsyncRead for &NamedPipeClient { } } +impl AsyncReadBufferPool for NamedPipeClient { + type Buffer<'a> = BorrowedBuffer<'a>; + type BufferPool = BufferPool; + + async fn read_buffer_pool<'a>( + &mut self, + buffer_pool: &'a Self::BufferPool, + len: usize, + ) -> io::Result> { + (&*self).read_buffer_pool(buffer_pool, len).await + } +} + +impl AsyncReadBufferPool for &NamedPipeClient { + type Buffer<'a> = BorrowedBuffer<'a>; + type BufferPool = BufferPool; + + async fn read_buffer_pool<'a>( + &mut self, + buffer_pool: &'a Self::BufferPool, + len: usize, + ) -> io::Result> { + self.handle.read_at_buffer_pool(buffer_pool, 0, len).await + } +} + impl AsyncWrite for NamedPipeClient { #[inline] async fn write(&mut self, buf: T) -> BufResult { diff --git a/compio-fs/src/pipe.rs b/compio-fs/src/pipe.rs index 1119d7f7..81ce56f8 100644 --- a/compio-fs/src/pipe.rs +++ b/compio-fs/src/pipe.rs @@ -10,10 +10,11 @@ use std::{ use compio_buf::{BufResult, IntoInner, IoBuf, IoBufMut, IoVectoredBuf, IoVectoredBufMut}; use compio_driver::{ impl_raw_fd, - op::{BufResultExt, Recv, RecvVectored, Send, SendVectored}, - syscall, AsRawFd, ToSharedFd, + op::{BufResultExt, Recv, RecvBufferPool, RecvVectored, Send, SendVectored}, + syscall, AsRawFd, TakeBuffer, ToSharedFd, }; -use compio_io::{AsyncRead, AsyncWrite}; +use compio_io::{AsyncRead, AsyncReadBufferPool, AsyncWrite}; +use compio_runtime::buffer_pool::{BorrowedBuffer, BufferPool}; use crate::File; @@ -502,6 +503,36 @@ impl AsyncRead for &Receiver { } } +impl AsyncReadBufferPool for Receiver { + type Buffer<'a> = BorrowedBuffer<'a>; + type BufferPool = BufferPool; + + async fn read_buffer_pool<'a>( + &mut self, + buffer_pool: &'a Self::BufferPool, + len: usize, + ) -> io::Result> { + (&*self).read_buffer_pool(buffer_pool, len).await + } +} + +impl AsyncReadBufferPool for &Receiver { + type Buffer<'a> = BorrowedBuffer<'a>; + type BufferPool = BufferPool; + + async fn read_buffer_pool<'a>( + &mut self, + buffer_pool: &'a Self::BufferPool, + len: usize, + ) -> io::Result> { + let fd = self.to_shared_fd(); + let op = RecvBufferPool::new(buffer_pool.as_driver_buffer_pool(), fd, len as _)?; + let (BufResult(res, op), flags) = compio_runtime::submit_with_flags(op).await; + + op.take_buffer(buffer_pool.as_driver_buffer_pool(), res, flags) + } +} + impl_raw_fd!(Receiver, std::fs::File, file, file); /// Checks if file is a FIFO diff --git a/compio-fs/src/stdio/unix.rs b/compio-fs/src/stdio/unix.rs index faa24262..444b581f 100644 --- a/compio-fs/src/stdio/unix.rs +++ b/compio-fs/src/stdio/unix.rs @@ -2,7 +2,8 @@ use std::io; use compio_buf::{BufResult, IoBuf, IoBufMut, IoVectoredBuf, IoVectoredBufMut}; use compio_driver::{AsRawFd, RawFd}; -use compio_io::{AsyncRead, AsyncWrite}; +use compio_io::{AsyncRead, AsyncReadBufferPool, AsyncWrite}; +use compio_runtime::buffer_pool::{BorrowedBuffer, BufferPool}; #[cfg(doc)] use super::{stderr, stdin, stdout}; @@ -31,6 +32,32 @@ impl AsyncRead for Stdin { } } +impl AsyncReadBufferPool for Stdin { + type Buffer<'a> = BorrowedBuffer<'a>; + type BufferPool = BufferPool; + + async fn read_buffer_pool<'a>( + &mut self, + buffer_pool: &'a Self::BufferPool, + len: usize, + ) -> io::Result> { + (&*self).read_buffer_pool(buffer_pool, len).await + } +} + +impl AsyncReadBufferPool for &Stdin { + type Buffer<'a> = BorrowedBuffer<'a>; + type BufferPool = BufferPool; + + async fn read_buffer_pool<'a>( + &mut self, + buffer_pool: &'a Self::BufferPool, + len: usize, + ) -> io::Result> { + (&self.0).read_buffer_pool(buffer_pool, len).await + } +} + impl AsyncRead for &Stdin { async fn read(&mut self, buf: B) -> BufResult { (&self.0).read(buf).await diff --git a/compio-fs/tests/buffer_pool.rs b/compio-fs/tests/buffer_pool.rs new file mode 100644 index 00000000..3b0d8029 --- /dev/null +++ b/compio-fs/tests/buffer_pool.rs @@ -0,0 +1,42 @@ +use std::io::Write; + +#[cfg(unix)] +use compio_fs::pipe; +use compio_fs::File; +use compio_io::AsyncReadAtBufferPool; +#[cfg(unix)] +use compio_io::{AsyncReadBufferPool, AsyncWriteExt}; +use compio_runtime::buffer_pool::BufferPool; +use tempfile::NamedTempFile; + +const HELLO: &[u8] = b"hello world..."; + +fn tempfile() -> NamedTempFile { + NamedTempFile::new().unwrap() +} + +#[compio_macros::test] +async fn test_read_file() { + let mut tempfile = tempfile(); + tempfile.write_all(HELLO).unwrap(); + + let file = File::open(tempfile.path()).await.unwrap(); + let buffer_pool = BufferPool::new(1, 15).unwrap(); + let buf = file.read_at_buffer_pool(&buffer_pool, 0, 0).await.unwrap(); + + assert_eq!(buf.len(), HELLO.len()); + assert_eq!(buf.as_ref(), HELLO); +} + +#[cfg(unix)] +#[compio_macros::test] +async fn test_read_pipe() { + let (mut rx, mut tx) = pipe::anonymous().unwrap(); + tx.write_all(HELLO).await.unwrap(); + + let buffer_pool = BufferPool::new(1, 15).unwrap(); + let buf = rx.read_buffer_pool(&buffer_pool, 0).await.unwrap(); + + assert_eq!(buf.len(), HELLO.len()); + assert_eq!(buf.as_ref(), HELLO); +} diff --git a/compio-io/src/read/mod.rs b/compio-io/src/read/mod.rs index 69032b96..7c88566f 100644 --- a/compio-io/src/read/mod.rs +++ b/compio-io/src/read/mod.rs @@ -1,6 +1,6 @@ #[cfg(feature = "allocator_api")] use std::alloc::Allocator; -use std::{io::Cursor, rc::Rc, sync::Arc}; +use std::{io, io::Cursor, ops::DerefMut, rc::Rc, sync::Arc}; use compio_buf::{buf_try, t_alloc, BufResult, IntoInner, IoBuf, IoBufMut, IoVectoredBufMut}; @@ -130,6 +130,51 @@ pub trait AsyncReadAt { } } +/// # AsyncReadBufferPool +/// +/// Async read with buffer pool +pub trait AsyncReadBufferPool { + /// Filled buffer type + type Buffer<'a>: DerefMut; + + /// Buffer pool type + type BufferPool; + + /// Read some bytes from this source with [`BufferPool`] and return + /// a [`BorrowedBuffer`]. + /// + /// If `len` == 0, will use [`BufferPool`] inner buffer size as the max len, + /// if `len` > 0, `min(len, inner buffer size)` will be the read max len + async fn read_buffer_pool<'a>( + &mut self, + buffer_pool: &'a Self::BufferPool, + len: usize, + ) -> io::Result>; +} + +/// # AsyncReadAtBufferPool +/// +/// Async read with buffer pool and position +pub trait AsyncReadAtBufferPool { + /// Buffer pool type + type BufferPool; + + /// Filled buffer type + type Buffer<'a>: DerefMut; + + /// Read some bytes from this source at position with [`BufferPool`] and + /// return a [`BorrowedBuffer`]. + /// + /// If `len` == 0, will use [`BufferPool`] inner buffer size as the max len, + /// if `len` > 0, `min(len, inner buffer size)` will be the read max len + async fn read_at_buffer_pool<'a>( + &self, + buffer_pool: &'a Self::BufferPool, + pos: u64, + len: usize, + ) -> io::Result>; +} + macro_rules! impl_read_at { (@ptr $($ty:ty),*) => { $( diff --git a/compio-net/src/socket.rs b/compio-net/src/socket.rs index 523ac36a..d35a2949 100644 --- a/compio-net/src/socket.rs +++ b/compio-net/src/socket.rs @@ -6,12 +6,16 @@ use compio_driver::op::CreateSocket; use compio_driver::{ impl_raw_fd, op::{ - Accept, BufResultExt, CloseSocket, Connect, Recv, RecvFrom, RecvFromVectored, - RecvResultExt, RecvVectored, Send, SendTo, SendToVectored, SendVectored, ShutdownSocket, + Accept, BufResultExt, CloseSocket, Connect, Recv, RecvBufferPool, RecvFrom, + RecvFromVectored, RecvResultExt, RecvVectored, Send, SendTo, SendToVectored, SendVectored, + ShutdownSocket, }, - ToSharedFd, + TakeBuffer, ToSharedFd, +}; +use compio_runtime::{ + buffer_pool::{BorrowedBuffer, BufferPool}, + Attacher, }; -use compio_runtime::Attacher; use socket2::{Domain, Protocol, SockAddr, Socket as Socket2, Type}; use crate::PollFd; @@ -215,6 +219,18 @@ impl Socket { compio_runtime::submit(op).await.into_inner().map_advanced() } + pub async fn recv_buffer_pool<'a>( + &self, + buffer_pool: &'a BufferPool, + len: u32, + ) -> io::Result> { + let fd = self.to_shared_fd(); + let op = RecvBufferPool::new(buffer_pool.as_driver_buffer_pool(), fd, len)?; + let (BufResult(res, op), flags) = compio_runtime::submit_with_flags(op).await; + + op.take_buffer(buffer_pool.as_driver_buffer_pool(), res, flags) + } + pub async fn recv_vectored(&self, buffer: V) -> BufResult { let fd = self.to_shared_fd(); let op = RecvVectored::new(fd, buffer); diff --git a/compio-net/src/tcp.rs b/compio-net/src/tcp.rs index 782591df..592cd6cb 100644 --- a/compio-net/src/tcp.rs +++ b/compio-net/src/tcp.rs @@ -2,7 +2,8 @@ use std::{future::Future, io, net::SocketAddr}; use compio_buf::{BufResult, IoBuf, IoBufMut, IoVectoredBuf, IoVectoredBufMut}; use compio_driver::impl_raw_fd; -use compio_io::{AsyncRead, AsyncWrite}; +use compio_io::{AsyncRead, AsyncReadBufferPool, AsyncWrite}; +use compio_runtime::buffer_pool::{BorrowedBuffer, BufferPool}; use socket2::{Protocol, SockAddr, Socket as Socket2, Type}; use crate::{ @@ -229,6 +230,32 @@ impl AsyncRead for TcpStream { } } +impl AsyncReadBufferPool for TcpStream { + type Buffer<'a> = BorrowedBuffer<'a>; + type BufferPool = BufferPool; + + async fn read_buffer_pool<'a>( + &mut self, + buffer_pool: &'a Self::BufferPool, + len: usize, + ) -> io::Result> { + (&*self).read_buffer_pool(buffer_pool, len).await + } +} + +impl AsyncReadBufferPool for &TcpStream { + type Buffer<'a> = BorrowedBuffer<'a>; + type BufferPool = BufferPool; + + async fn read_buffer_pool<'a>( + &mut self, + buffer_pool: &'a Self::BufferPool, + len: usize, + ) -> io::Result> { + self.inner.recv_buffer_pool(buffer_pool, len as _).await + } +} + impl AsyncRead for &TcpStream { #[inline] async fn read(&mut self, buf: B) -> BufResult { diff --git a/compio-net/src/udp.rs b/compio-net/src/udp.rs index f3b86f9a..93e602c6 100644 --- a/compio-net/src/udp.rs +++ b/compio-net/src/udp.rs @@ -2,6 +2,7 @@ use std::{future::Future, io, net::SocketAddr}; use compio_buf::{BufResult, IoBuf, IoBufMut, IoVectoredBuf, IoVectoredBufMut}; use compio_driver::impl_raw_fd; +use compio_runtime::buffer_pool::{BorrowedBuffer, BufferPool}; use socket2::{Protocol, SockAddr, Type}; use crate::{Socket, ToSocketAddrsAsync}; @@ -249,6 +250,19 @@ impl UdpSocket { }) .await } + + /// Read some bytes from this source with [`BufferPool`] and return + /// a [`BorrowedBuffer`]. + /// + /// If `len` == 0, will use [`BufferPool`] inner buffer size as the max len, + /// if `len` > 0, `min(len, inner buffer size)` will be the read max len + pub async fn recv_buffer_pool<'a>( + &self, + buffer_pool: &'a BufferPool, + len: u32, + ) -> io::Result> { + self.inner.recv_buffer_pool(buffer_pool, len).await + } } impl_raw_fd!(UdpSocket, socket2::Socket, inner, socket); diff --git a/compio-net/src/unix.rs b/compio-net/src/unix.rs index e4ca0f61..0b08d434 100644 --- a/compio-net/src/unix.rs +++ b/compio-net/src/unix.rs @@ -2,7 +2,8 @@ use std::{future::Future, io, path::Path}; use compio_buf::{BufResult, IoBuf, IoBufMut, IoVectoredBuf, IoVectoredBufMut}; use compio_driver::impl_raw_fd; -use compio_io::{AsyncRead, AsyncWrite}; +use compio_io::{AsyncRead, AsyncReadBufferPool, AsyncWrite}; +use compio_runtime::buffer_pool::{BorrowedBuffer, BufferPool}; use socket2::{SockAddr, Socket as Socket2, Type}; use crate::{OwnedReadHalf, OwnedWriteHalf, PollFd, ReadHalf, Socket, WriteHalf}; @@ -223,6 +224,32 @@ impl AsyncRead for &UnixStream { } } +impl AsyncReadBufferPool for UnixStream { + type Buffer<'a> = BorrowedBuffer<'a>; + type BufferPool = BufferPool; + + async fn read_buffer_pool<'a>( + &mut self, + buffer_pool: &'a Self::BufferPool, + len: usize, + ) -> io::Result> { + (&*self).read_buffer_pool(buffer_pool, len).await + } +} + +impl AsyncReadBufferPool for &UnixStream { + type Buffer<'a> = BorrowedBuffer<'a>; + type BufferPool = BufferPool; + + async fn read_buffer_pool<'a>( + &mut self, + buffer_pool: &'a Self::BufferPool, + len: usize, + ) -> io::Result> { + self.inner.recv_buffer_pool(buffer_pool, len as _).await + } +} + impl AsyncWrite for UnixStream { #[inline] async fn write(&mut self, buf: T) -> BufResult { diff --git a/compio-net/tests/buffer_pool.rs b/compio-net/tests/buffer_pool.rs new file mode 100644 index 00000000..13501a3f --- /dev/null +++ b/compio-net/tests/buffer_pool.rs @@ -0,0 +1,98 @@ +use std::net::Ipv6Addr; + +use compio_io::{AsyncReadBufferPool, AsyncWriteExt}; +use compio_net::{TcpListener, TcpStream, UdpSocket, UnixListener, UnixStream}; +use compio_runtime::buffer_pool::BufferPool; + +#[compio_macros::test] +async fn test_tcp_read_buffer_pool() { + let listener = TcpListener::bind((Ipv6Addr::LOCALHOST, 0)).await.unwrap(); + let addr = listener.local_addr().unwrap(); + + compio_runtime::spawn(async move { + let mut stream = listener.accept().await.unwrap().0; + stream.write_all(b"test").await.unwrap(); + }) + .detach(); + + let buffer_pool = BufferPool::new(1, 4).unwrap(); + let mut stream = TcpStream::connect(addr).await.unwrap(); + + assert_eq!( + stream + .read_buffer_pool(&buffer_pool, 0) + .await + .unwrap() + .as_ref(), + b"test" + ); + + assert!( + stream + .read_buffer_pool(&buffer_pool, 0) + .await + .unwrap() + .is_empty() + ); +} + +#[compio_macros::test] +async fn test_udp_read_buffer_pool() { + let listener = UdpSocket::bind((Ipv6Addr::LOCALHOST, 0)).await.unwrap(); + let addr = listener.local_addr().unwrap(); + let connected = UdpSocket::bind((Ipv6Addr::LOCALHOST, 0)).await.unwrap(); + connected.connect(addr).await.unwrap(); + let addr = connected.local_addr().unwrap(); + + compio_runtime::spawn(async move { + listener.send_to(b"test", addr).await.unwrap(); + }) + .detach(); + + let buffer_pool = BufferPool::new(1, 4).unwrap(); + + assert_eq!( + connected + .recv_buffer_pool(&buffer_pool, 0) + .await + .unwrap() + .as_ref(), + b"test" + ); +} + +#[compio_macros::test] +async fn test_uds_recv_buffer_pool() { + let dir = tempfile::Builder::new() + .prefix("compio-uds-buffer-pool-tests") + .tempdir() + .unwrap(); + let sock_path = dir.path().join("connect.sock"); + + let listener = UnixListener::bind(&sock_path).await.unwrap(); + + let (mut client, (mut server, _)) = + futures_util::try_join!(UnixStream::connect(&sock_path), listener.accept()).unwrap(); + + client.write_all("test").await.unwrap(); + drop(client); + + let buffer_pool = BufferPool::new(1, 4).unwrap(); + + assert_eq!( + server + .read_buffer_pool(&buffer_pool, 0) + .await + .unwrap() + .as_ref(), + b"test" + ); + + assert!( + server + .read_buffer_pool(&buffer_pool, 0) + .await + .unwrap() + .is_empty() + ); +} diff --git a/compio-process/src/unix.rs b/compio-process/src/unix.rs index a48abfe8..e9329930 100644 --- a/compio-process/src/unix.rs +++ b/compio-process/src/unix.rs @@ -2,10 +2,11 @@ use std::{io, panic::resume_unwind, process}; use compio_buf::{BufResult, IntoInner, IoBuf, IoBufMut}; use compio_driver::{ - op::{BufResultExt, Recv, Send}, - AsRawFd, RawFd, SharedFd, ToSharedFd, + op::{BufResultExt, Recv, RecvBufferPool, Send}, + AsRawFd, RawFd, SharedFd, TakeBuffer, ToSharedFd, }; -use compio_io::{AsyncRead, AsyncWrite}; +use compio_io::{AsyncRead, AsyncReadBufferPool, AsyncWrite}; +use compio_runtime::buffer_pool::{BorrowedBuffer, BufferPool}; use crate::{ChildStderr, ChildStdin, ChildStdout}; @@ -35,6 +36,23 @@ impl AsyncRead for ChildStdout { } } +impl AsyncReadBufferPool for ChildStdout { + type Buffer<'a> = BorrowedBuffer<'a>; + type BufferPool = BufferPool; + + async fn read_buffer_pool<'a>( + &mut self, + buffer_pool: &'a Self::BufferPool, + len: usize, + ) -> io::Result> { + let fd = self.to_shared_fd(); + let op = RecvBufferPool::new(buffer_pool.as_driver_buffer_pool(), fd, len as _)?; + let (BufResult(res, op), flags) = compio_runtime::submit_with_flags(op).await; + + op.take_buffer(buffer_pool.as_driver_buffer_pool(), res, flags) + } +} + impl AsRawFd for ChildStderr { fn as_raw_fd(&self) -> RawFd { self.0.as_raw_fd() @@ -55,6 +73,23 @@ impl AsyncRead for ChildStderr { } } +impl AsyncReadBufferPool for ChildStderr { + type Buffer<'a> = BorrowedBuffer<'a>; + type BufferPool = BufferPool; + + async fn read_buffer_pool<'a>( + &mut self, + buffer_pool: &'a Self::BufferPool, + len: usize, + ) -> io::Result> { + let fd = self.to_shared_fd(); + let op = RecvBufferPool::new(buffer_pool.as_driver_buffer_pool(), fd, len as _)?; + let (BufResult(res, op), flags) = compio_runtime::submit_with_flags(op).await; + + op.take_buffer(buffer_pool.as_driver_buffer_pool(), res, flags) + } +} + impl AsRawFd for ChildStdin { fn as_raw_fd(&self) -> RawFd { self.0.as_raw_fd() diff --git a/compio-process/src/windows.rs b/compio-process/src/windows.rs index 7722a28c..5dafc333 100644 --- a/compio-process/src/windows.rs +++ b/compio-process/src/windows.rs @@ -8,10 +8,11 @@ use std::{ use compio_buf::{BufResult, IntoInner, IoBuf, IoBufMut}; use compio_driver::{ - op::{BufResultExt, Recv, Send}, - syscall, AsRawFd, OpCode, OpType, RawFd, SharedFd, ToSharedFd, + op::{BufResultExt, Recv, RecvBufferPool, Send}, + syscall, AsRawFd, OpCode, OpType, RawFd, SharedFd, TakeBuffer, ToSharedFd, }; -use compio_io::{AsyncRead, AsyncWrite}; +use compio_io::{AsyncRead, AsyncReadBufferPool, AsyncWrite}; +use compio_runtime::buffer_pool::{BorrowedBuffer, BufferPool}; use windows_sys::Win32::System::{Threading::GetExitCodeProcess, IO::OVERLAPPED}; use crate::{ChildStderr, ChildStdin, ChildStdout}; @@ -67,6 +68,23 @@ impl AsyncRead for ChildStdout { } } +impl AsyncReadBufferPool for ChildStdout { + type Buffer<'a> = BorrowedBuffer<'a>; + type BufferPool = BufferPool; + + async fn read_buffer_pool<'a>( + &mut self, + buffer_pool: &'a Self::BufferPool, + len: usize, + ) -> io::Result> { + let fd = self.to_shared_fd(); + let op = RecvBufferPool::new(buffer_pool.as_driver_buffer_pool(), fd, len as _)?; + let (BufResult(res, op), flags) = compio_runtime::submit_with_flags(op).await; + + op.take_buffer(buffer_pool.as_driver_buffer_pool(), res, flags) + } +} + impl AsRawFd for ChildStderr { fn as_raw_fd(&self) -> RawFd { self.0.as_raw_fd() @@ -87,6 +105,23 @@ impl AsyncRead for ChildStderr { } } +impl AsyncReadBufferPool for ChildStderr { + type Buffer<'a> = BorrowedBuffer<'a>; + type BufferPool = BufferPool; + + async fn read_buffer_pool<'a>( + &mut self, + buffer_pool: &'a Self::BufferPool, + len: usize, + ) -> io::Result> { + let fd = self.to_shared_fd(); + let op = RecvBufferPool::new(buffer_pool.as_driver_buffer_pool(), fd, len as _)?; + let (BufResult(res, op), flags) = compio_runtime::submit_with_flags(op).await; + + op.take_buffer(buffer_pool.as_driver_buffer_pool(), res, flags) + } +} + impl AsRawFd for ChildStdin { fn as_raw_fd(&self) -> RawFd { self.0.as_raw_fd() diff --git a/compio-runtime/Cargo.toml b/compio-runtime/Cargo.toml index d3486b1d..5482b587 100644 --- a/compio-runtime/Cargo.toml +++ b/compio-runtime/Cargo.toml @@ -39,6 +39,7 @@ criterion = { workspace = true, optional = true } crossbeam-queue = { workspace = true } futures-util = { workspace = true } once_cell = { workspace = true } +rand = { workspace = true } scoped-tls = "1.0.1" slab = { workspace = true, optional = true } smallvec = "1.11.1" diff --git a/compio-runtime/src/buffer_pool.rs b/compio-runtime/src/buffer_pool.rs new file mode 100644 index 00000000..9678ee08 --- /dev/null +++ b/compio-runtime/src/buffer_pool.rs @@ -0,0 +1,87 @@ +//! Buffer pool + +use std::{io, marker::PhantomData, mem::ManuallyDrop}; + +use crate::Runtime; + +/// Buffer borrowed from buffer pool +/// +/// When IO operation finish, user will obtain a `BorrowedBuffer` to access the +/// filled data +pub type BorrowedBuffer<'a> = compio_driver::BorrowedBuffer<'a>; + +/// Buffer pool +/// +/// A buffer pool to allow user no need to specify a specific buffer to do the +/// IO operation +/// +/// Drop the `BufferPool` will release the buffer pool automatically +#[derive(Debug)] +pub struct BufferPool { + inner: ManuallyDrop, + runtime_id: i64, + + // make it !Send and !Sync, to prevent user send the buffer pool to other thread + _marker: PhantomData<*const ()>, +} + +impl BufferPool { + /// Create buffer pool with given `buffer_size` and `buffer_len` + /// + /// # Notes + /// + /// If `buffer_len` is not power of 2, it will be upward with + /// [`u16::next_power_of_two`] + pub fn new(buffer_len: u16, buffer_size: usize) -> io::Result { + let (inner, runtime_id) = Runtime::with_current(|runtime| { + let buffer_pool = runtime.create_buffer_pool(buffer_len, buffer_size)?; + let runtime_id = runtime.id(); + + Ok((buffer_pool, runtime_id)) + })?; + + Ok(Self::inner_new(inner, runtime_id)) + } + + fn inner_new(inner: compio_driver::BufferPool, runtime_id: i64) -> Self { + Self { + inner: ManuallyDrop::new(inner), + runtime_id, + _marker: Default::default(), + } + } + + /// Get the inner driver buffer pool reference + /// + /// # Notes + /// + /// You should not use this method unless you are writing your own IO opcode + /// + /// # Panic + /// + /// If call this method in incorrect runtime, will panic + pub fn as_driver_buffer_pool(&self) -> &compio_driver::BufferPool { + let current_runtime_id = Runtime::with_current(|runtime| runtime.id()); + assert_eq!(current_runtime_id, self.runtime_id); + + &self.inner + } +} + +impl Drop for BufferPool { + fn drop(&mut self) { + let _ = Runtime::try_with_current(|runtime| { + if self.runtime_id != runtime.id() { + return; + } + + unsafe { + // Safety: we own the inner + let inner = ManuallyDrop::take(&mut self.inner); + + // Safety: the buffer pool is created by current thread runtime + let _ = runtime.release_buffer_pool(inner); + } + }); + } +} diff --git a/compio-runtime/src/lib.rs b/compio-runtime/src/lib.rs index de82cd56..b6647ec9 100644 --- a/compio-runtime/src/lib.rs +++ b/compio-runtime/src/lib.rs @@ -14,6 +14,7 @@ mod attacher; mod runtime; +pub mod buffer_pool; #[cfg(feature = "event")] pub mod event; #[cfg(feature = "time")] diff --git a/compio-runtime/src/runtime/mod.rs b/compio-runtime/src/runtime/mod.rs index 78c13c43..38e0ab43 100644 --- a/compio-runtime/src/runtime/mod.rs +++ b/compio-runtime/src/runtime/mod.rs @@ -47,6 +47,14 @@ pub struct Runtime { sync_runnables: Arc>, #[cfg(feature = "time")] timer_runtime: RefCell, + // Runtime id is used to check if the buffer pool is belonged to this runtime or not. + // Without this, if user enable `io-uring-buf-ring` feature then: + // 1. Create a buffer pool at runtime1 + // 3. Create another runtime2, then use the exists buffer pool in runtime2, it may cause + // - io-uring report error if the buffer group id is not registered + // - buffer pool will return a wrong buffer which the buffer's data is uninit, that will cause + // UB + id: i64, // Other fields don't make it !Send, but actually `local_runnables` implies it should be !Send, // otherwise it won't be valid if the runtime is sent to other threads. _p: PhantomData>>, @@ -70,6 +78,7 @@ impl Runtime { sync_runnables: Arc::new(SegQueue::new()), #[cfg(feature = "time")] timer_runtime: RefCell::new(TimerRuntime::new()), + id: rand::random(), _p: PhantomData, }) } @@ -348,6 +357,27 @@ impl Runtime { #[cfg(feature = "time")] self.timer_runtime.borrow_mut().wake(); } + + pub(crate) fn create_buffer_pool( + &self, + buffer_len: u16, + buffer_size: usize, + ) -> io::Result { + self.driver + .borrow_mut() + .create_buffer_pool(buffer_len, buffer_size) + } + + pub(crate) unsafe fn release_buffer_pool( + &self, + buffer_pool: compio_driver::BufferPool, + ) -> io::Result<()> { + self.driver.borrow_mut().release_buffer_pool(buffer_pool) + } + + pub(crate) fn id(&self) -> i64 { + self.id + } } impl AsRawFd for Runtime { diff --git a/compio/Cargo.toml b/compio/Cargo.toml index 692d7ac1..30e915d2 100644 --- a/compio/Cargo.toml +++ b/compio/Cargo.toml @@ -81,6 +81,7 @@ io-uring = [ "compio-fs?/io-uring", "compio-net?/io-uring", ] +io-uring-buf-ring = ["compio-driver/io-uring-buf-ring"] polling = ["compio-driver/polling"] io = ["dep:compio-io"] io-compat = ["io", "compio-io/compat"]