Skip to content

Commit

Permalink
Merge pull request #223 from Berrysoft/dev/win32-event
Browse files Browse the repository at this point in the history
feat(driver,windows): add win32 event support
  • Loading branch information
Berrysoft committed Mar 8, 2024
2 parents 4631553 + 338a762 commit 7fbd340
Show file tree
Hide file tree
Showing 4 changed files with 200 additions and 45 deletions.
151 changes: 128 additions & 23 deletions compio-driver/src/iocp/mod.rs
Original file line number Diff line number Diff line change
@@ -1,13 +1,16 @@
use std::{
collections::HashSet,
collections::{HashMap, HashSet},
io,
mem::ManuallyDrop,
os::windows::prelude::{
AsRawHandle, AsRawSocket, FromRawHandle, FromRawSocket, IntoRawHandle, IntoRawSocket,
RawHandle,
os::{
raw::c_void,
windows::prelude::{
AsRawHandle, AsRawSocket, FromRawHandle, FromRawSocket, IntoRawHandle, IntoRawSocket,
RawHandle,
},
},
pin::Pin,
ptr::NonNull,
ptr::{null, NonNull},
sync::Arc,
task::Poll,
time::Duration,
Expand All @@ -17,9 +20,15 @@ use compio_buf::BufResult;
use compio_log::{instrument, trace};
use slab::Slab;
use windows_sys::Win32::{
Foundation::{ERROR_BUSY, ERROR_OPERATION_ABORTED},
Foundation::{ERROR_BUSY, ERROR_OPERATION_ABORTED, ERROR_TIMEOUT, WAIT_OBJECT_0, WAIT_TIMEOUT},
Networking::WinSock::{WSACleanup, WSAStartup, WSADATA},
System::IO::OVERLAPPED,
System::{
Threading::{
CloseThreadpoolWait, CreateThreadpoolWait, SetThreadpoolWait,
WaitForThreadpoolWaitCallbacks, PTP_CALLBACK_INSTANCE, PTP_WAIT,
},
IO::OVERLAPPED,
},
};

use crate::{syscall, AsyncifyPool, Entry, OutEntries, ProactorBuilder};
Expand Down Expand Up @@ -98,12 +107,25 @@ impl IntoRawFd for socket2::Socket {
}
}

/// Operation type.
pub enum OpType {
/// An overlapped operation.
Overlapped,
/// A blocking operation, needs a thread to spawn. The `operate` method
/// should be thread safe.
Blocking,
/// A Win32 event object to be waited. The user should ensure that the
/// handle is valid till operation completes. The `operate` method should be
/// thread safe.
Event(RawFd),
}

/// Abstraction of IOCP operations.
pub trait OpCode {
/// Determines that the operation is really overlapped defined by Windows
/// API. If not, the driver will try to operate it in another thread.
fn is_overlapped(&self) -> bool {
true
fn op_type(&self) -> OpType {
OpType::Overlapped
}

/// Perform Windows API call with given pointer to overlapped struct.
Expand Down Expand Up @@ -133,6 +155,7 @@ pub trait OpCode {
/// Low-level driver of IOCP.
pub(crate) struct Driver {
port: cp::Port,
waits: HashMap<usize, WinThreadpollWait>,
cancelled: HashSet<usize>,
pool: AsyncifyPool,
notify_overlapped: Arc<Overlapped<()>>,
Expand All @@ -150,6 +173,7 @@ impl Driver {
let driver = port.as_raw_handle() as _;
Ok(Self {
port,
waits: HashMap::default(),
cancelled: HashSet::default(),
pool: builder.create_or_get_thread_pool(),
notify_overlapped: Arc::new(Overlapped::new(driver, Self::NOTIFY, ())),
Expand Down Expand Up @@ -188,12 +212,22 @@ impl Driver {
trace!("push RawOp");
let optr = op.as_mut_ptr();
let op_pin = op.as_op_pin();
if op_pin.is_overlapped() {
unsafe { op_pin.operate(optr.cast()) }
} else if self.push_blocking(op)? {
Poll::Pending
} else {
Poll::Ready(Err(io::Error::from_raw_os_error(ERROR_BUSY as _)))
match op_pin.op_type() {
OpType::Overlapped => unsafe { op_pin.operate(optr.cast()) },
OpType::Blocking => {
if self.push_blocking(op)? {
Poll::Pending
} else {
Poll::Ready(Err(io::Error::from_raw_os_error(ERROR_BUSY as _)))
}
}
OpType::Event(e) => {
self.waits.insert(
user_data,
WinThreadpollWait::new(self.port.handle(), e, op)?,
);
Poll::Pending
}
}
}
}
Expand All @@ -213,20 +247,20 @@ impl Driver {
// Safety: the pointer is created from a reference.
let op = unsafe { optr.0.as_mut() };
let optr = op.as_mut_ptr();
let op = op.as_op_pin();
let res = unsafe { op.operate(optr.cast()) };
let res = match res {
Poll::Pending => unreachable!("this operation is not overlapped"),
Poll::Ready(res) => res,
};
let res = op.operate_blocking();
port.post(res, optr).ok();
})
.is_ok())
}

fn create_entry(cancelled: &mut HashSet<usize>, entry: Entry) -> Option<Entry> {
fn create_entry(
cancelled: &mut HashSet<usize>,
waits: &mut HashMap<usize, WinThreadpollWait>,
entry: Entry,
) -> Option<Entry> {
let user_data = entry.user_data();
if user_data != Self::NOTIFY {
waits.remove(&user_data);
let result = if cancelled.remove(&user_data) {
Err(io::Error::from_raw_os_error(ERROR_OPERATION_ABORTED as _))
} else {
Expand All @@ -248,7 +282,7 @@ impl Driver {
entries.extend(
self.port
.poll(timeout)?
.filter_map(|e| Self::create_entry(&mut self.cancelled, e)),
.filter_map(|e| Self::create_entry(&mut self.cancelled, &mut self.waits, e)),
);

Ok(())
Expand Down Expand Up @@ -291,6 +325,67 @@ impl NotifyHandle {
}
}

struct WinThreadpollWait {
wait: PTP_WAIT,
// For memory safety.
#[allow(dead_code)]
context: Box<WinThreadpollWaitContext>,
}

impl WinThreadpollWait {
pub fn new(port: cp::PortHandle, event: RawFd, op: &mut RawOp) -> io::Result<Self> {
let mut context = Box::new(WinThreadpollWaitContext { port, op });
let wait = syscall!(
BOOL,
CreateThreadpoolWait(
Some(Self::wait_callback),
(&mut *context) as *mut WinThreadpollWaitContext as _,
null()
)
)?;
unsafe {
SetThreadpoolWait(wait, event as _, null());
}
Ok(Self { wait, context })
}

unsafe extern "system" fn wait_callback(
_instance: PTP_CALLBACK_INSTANCE,
context: *mut c_void,
_wait: PTP_WAIT,
result: u32,
) {
let context = &*(context as *mut WinThreadpollWaitContext);
let res = match result {
WAIT_OBJECT_0 => Ok(0),
WAIT_TIMEOUT => Err(io::Error::from_raw_os_error(ERROR_TIMEOUT as _)),
_ => Err(io::Error::from_raw_os_error(result as _)),
};
let res = if res.is_err() {
res
} else {
let op = unsafe { &mut *context.op };
op.operate_blocking()
};
context.port.post(res, (*context.op).as_mut_ptr()).ok();
}
}

impl Drop for WinThreadpollWait {
fn drop(&mut self) {
unsafe {
SetThreadpoolWait(self.wait, 0, null());
WaitForThreadpoolWaitCallbacks(self.wait, 1);
CloseThreadpoolWait(self.wait);
}
}
}

struct WinThreadpollWaitContext {
port: cp::PortHandle,
op: *mut RawOp,
}

/// The overlapped struct we actually used for IOCP.
#[repr(C)]
pub struct Overlapped<T: ?Sized> {
Expand Down Expand Up @@ -371,6 +466,16 @@ impl RawOp {
let overlapped: Box<Overlapped<T>> = Box::from_raw(this.op.cast().as_ptr());
BufResult(this.result.take().unwrap(), overlapped.op)
}

fn operate_blocking(&mut self) -> io::Result<usize> {
let optr = self.as_mut_ptr();
let op = self.as_op_pin();
let res = unsafe { op.operate(optr.cast()) };
match res {
Poll::Pending => unreachable!("this operation is not overlapped"),
Poll::Ready(res) => res,
}
}
}

impl Drop for RawOp {
Expand Down
34 changes: 17 additions & 17 deletions compio-driver/src/iocp/op.rs
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ use windows_sys::{
},
};

use crate::{op::*, syscall, OpCode, RawFd};
use crate::{op::*, syscall, OpCode, OpType, RawFd};

#[inline]
fn winapi_result(transferred: u32) -> Poll<io::Result<usize>> {
Expand Down Expand Up @@ -119,8 +119,8 @@ impl<
F: (FnOnce() -> BufResult<usize, D>) + std::marker::Send + std::marker::Sync + 'static,
> OpCode for Asyncify<F, D>
{
fn is_overlapped(&self) -> bool {
false
fn op_type(&self) -> OpType {
OpType::Blocking
}

unsafe fn operate(self: Pin<&mut Self>, _optr: *mut OVERLAPPED) -> Poll<io::Result<usize>> {
Expand Down Expand Up @@ -176,8 +176,8 @@ impl OpenFile {
}

impl OpCode for OpenFile {
fn is_overlapped(&self) -> bool {
false
fn op_type(&self) -> OpType {
OpType::Blocking
}

unsafe fn operate(mut self: Pin<&mut Self>, _optr: *mut OVERLAPPED) -> Poll<io::Result<usize>> {
Expand All @@ -200,8 +200,8 @@ impl OpCode for OpenFile {
}

impl OpCode for CloseFile {
fn is_overlapped(&self) -> bool {
false
fn op_type(&self) -> OpType {
OpType::Blocking
}

unsafe fn operate(self: Pin<&mut Self>, _optr: *mut OVERLAPPED) -> Poll<io::Result<usize>> {
Expand Down Expand Up @@ -293,8 +293,8 @@ impl FileStat {
}

impl OpCode for FileStat {
fn is_overlapped(&self) -> bool {
false
fn op_type(&self) -> OpType {
OpType::Blocking
}

unsafe fn operate(mut self: Pin<&mut Self>, _optr: *mut OVERLAPPED) -> Poll<io::Result<usize>> {
Expand Down Expand Up @@ -378,8 +378,8 @@ impl PathStat {
}

impl OpCode for PathStat {
fn is_overlapped(&self) -> bool {
false
fn op_type(&self) -> OpType {
OpType::Blocking
}

unsafe fn operate(mut self: Pin<&mut Self>, optr: *mut OVERLAPPED) -> Poll<io::Result<usize>> {
Expand Down Expand Up @@ -473,8 +473,8 @@ impl<T: IoBuf> OpCode for WriteAt<T> {
}

impl OpCode for Sync {
fn is_overlapped(&self) -> bool {
false
fn op_type(&self) -> OpType {
OpType::Blocking
}

unsafe fn operate(self: Pin<&mut Self>, _optr: *mut OVERLAPPED) -> Poll<io::Result<usize>> {
Expand All @@ -483,8 +483,8 @@ impl OpCode for Sync {
}

impl OpCode for ShutdownSocket {
fn is_overlapped(&self) -> bool {
false
fn op_type(&self) -> OpType {
OpType::Blocking
}

unsafe fn operate(self: Pin<&mut Self>, _optr: *mut OVERLAPPED) -> Poll<io::Result<usize>> {
Expand All @@ -498,8 +498,8 @@ impl OpCode for ShutdownSocket {
}

impl OpCode for CloseSocket {
fn is_overlapped(&self) -> bool {
false
fn op_type(&self) -> OpType {
OpType::Blocking
}

unsafe fn operate(self: Pin<&mut Self>, _optr: *mut OVERLAPPED) -> Poll<io::Result<usize>> {
Expand Down
10 changes: 5 additions & 5 deletions compio-fs/src/stdio/windows.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ use std::{
use compio_buf::{BufResult, IntoInner, IoBuf, IoBufMut};
use compio_driver::{
op::{BufResultExt, Recv, Send},
AsRawFd, OpCode, RawFd,
AsRawFd, OpCode, OpType, RawFd,
};
use compio_io::{AsyncRead, AsyncWrite};
use compio_runtime::Runtime;
Expand All @@ -30,8 +30,8 @@ impl<R: Read, B: IoBufMut> StdRead<R, B> {
}

impl<R: Read, B: IoBufMut> OpCode for StdRead<R, B> {
fn is_overlapped(&self) -> bool {
false
fn op_type(&self) -> OpType {
OpType::Blocking
}

unsafe fn operate(self: Pin<&mut Self>, _optr: *mut OVERLAPPED) -> Poll<io::Result<usize>> {
Expand Down Expand Up @@ -79,8 +79,8 @@ impl<W: Write, B: IoBuf> StdWrite<W, B> {
}

impl<W: Write, B: IoBuf> OpCode for StdWrite<W, B> {
fn is_overlapped(&self) -> bool {
false
fn op_type(&self) -> OpType {
OpType::Blocking
}

unsafe fn operate(self: Pin<&mut Self>, _optr: *mut OVERLAPPED) -> Poll<io::Result<usize>> {
Expand Down
Loading

0 comments on commit 7fbd340

Please sign in to comment.