diff --git a/src/io/sys/windows/miow.rs b/src/io/sys/windows/miow.rs index e9e387c6..8cac75b6 100644 --- a/src/io/sys/windows/miow.rs +++ b/src/io/sys/windows/miow.rs @@ -1,368 +1,528 @@ -//! ported from miow crate which is not maintained anymore - -use std::os::windows::io::{AsRawHandle, AsRawSocket}; -use std::time::Duration; -use std::{io, os::windows::io::RawSocket}; - -use windows_sys::Win32::Foundation::{HANDLE, INVALID_HANDLE_VALUE}; -use windows_sys::Win32::Networking::WinSock::*; -use windows_sys::Win32::System::Threading::INFINITE; -use windows_sys::Win32::System::IO::*; - -/// A handle to an Windows I/O Completion Port. -#[derive(Debug)] -pub struct CompletionPort { - handle: HANDLE, -} - -impl CompletionPort { - /// Creates a new I/O completion port with the specified concurrency value. - /// - /// The number of threads given corresponds to the level of concurrency - /// allowed for threads associated with this port. Consult the Windows - /// documentation for more information about this value. - pub fn new(threads: u32) -> io::Result { - let ret = unsafe { CreateIoCompletionPort(INVALID_HANDLE_VALUE, 0, 0, threads) }; - if ret == 0 { - Err(io::Error::last_os_error()) - } else { - Ok(CompletionPort { handle: ret }) - } - } - - /// Associates a new `HANDLE` to this I/O completion port. - /// - /// This function will associate the given handle to this port with the - /// given `token` to be returned in status messages whenever it receives a - /// notification. - /// - /// Any object which is convertible to a `HANDLE` via the `AsRawHandle` - /// trait can be provided to this function, such as `std::fs::File` and - /// friends. - #[allow(dead_code)] - pub fn add_handle(&self, token: usize, t: &T) -> io::Result<()> { - self._add(token, t.as_raw_handle() as HANDLE) - } - - /// Associates a new `SOCKET` to this I/O completion port. - /// - /// This function will associate the given socket to this port with the - /// given `token` to be returned in status messages whenever it receives a - /// notification. - /// - /// Any object which is convertible to a `SOCKET` via the `AsRawSocket` - /// trait can be provided to this function, such as `std::net::TcpStream` - /// and friends. - pub fn add_socket(&self, token: usize, t: &T) -> io::Result<()> { - self._add(token, t.as_raw_socket() as HANDLE) - } - - fn _add(&self, token: usize, handle: HANDLE) -> io::Result<()> { - assert_eq!(std::mem::size_of_val(&token), std::mem::size_of::()); - let ret = unsafe { CreateIoCompletionPort(handle, self.handle, token, 0) }; - if ret == 0 { - Err(io::Error::last_os_error()) - } else { - debug_assert_eq!(ret, self.handle); - Ok(()) - } - } - - /// Dequeue a completion status from this I/O completion port. - /// - /// This function will associate the calling thread with this completion - /// port and then wait for a status message to become available. The precise - /// semantics on when this function returns depends on the concurrency value - /// specified when the port was created. - /// - /// A timeout can optionally be specified to this function. If `None` is - /// provided this function will not time out, and otherwise it will time out - /// after the specified duration has passed. - /// - /// On success this will return the status message which was dequeued from - /// this completion port. - #[allow(dead_code)] - pub fn get(&self, timeout: Option) -> io::Result { - let mut bytes = 0; - let mut token = 0; - let mut overlapped = std::ptr::null_mut(); - let timeout = dur2ms(timeout); - let ret = unsafe { - GetQueuedCompletionStatus( - self.handle, - &mut bytes, - &mut token, - &mut overlapped, - timeout, - ) - }; - cvt(ret, 0).map(|_| { - CompletionStatus(OVERLAPPED_ENTRY { - dwNumberOfBytesTransferred: bytes, - lpCompletionKey: token, - lpOverlapped: overlapped, - Internal: 0, - }) - }) - } - - /// Dequeues a number of completion statuses from this I/O completion port. - /// - /// This function is the same as `get` except that it may return more than - /// one status. A buffer of "zero" statuses is provided (the contents are - /// not read) and then on success this function will return a sub-slice of - /// statuses which represent those which were dequeued from this port. This - /// function does not wait to fill up the entire list of statuses provided. - /// - /// Like with `get`, a timeout may be specified for this operation. - pub fn get_many<'a>( - &self, - list: &'a mut [CompletionStatus], - timeout: Option, - ) -> io::Result<&'a mut [CompletionStatus]> { - debug_assert_eq!( - std::mem::size_of::(), - std::mem::size_of::() - ); - let mut removed = 0; - let timeout = dur2ms(timeout); - let len = std::cmp::min(list.len(), ::max_value() as usize) as u32; - let ret = unsafe { - GetQueuedCompletionStatusEx( - self.handle, - list.as_ptr() as *mut _, - len, - &mut removed, - timeout, - 0, - ) - }; - match cvt_ret(ret) { - Ok(_) => Ok(&mut list[..removed as usize]), - Err(e) => Err(e), - } - } - - /// Posts a new completion status onto this I/O completion port. - /// - /// This function will post the given status, with custom parameters, to the - /// port. Threads blocked in `get` or `get_many` will eventually receive - /// this status. - pub fn post(&self, status: CompletionStatus) -> io::Result<()> { - let ret = unsafe { - PostQueuedCompletionStatus( - self.handle, - status.0.dwNumberOfBytesTransferred, - status.0.lpCompletionKey, - status.0.lpOverlapped, - ) - }; - cvt_ret(ret).map(|_| ()) - } -} - -// impl AsRawHandle for CompletionPort { -// fn as_raw_handle(&self) -> RawHandle { -// self.handle.raw() as RawHandle -// } -// } - -// impl FromRawHandle for CompletionPort { -// unsafe fn from_raw_handle(handle: RawHandle) -> CompletionPort { -// CompletionPort { -// handle: Handle::new(handle as HANDLE), -// } -// } -// } - -// impl IntoRawHandle for CompletionPort { -// fn into_raw_handle(self) -> RawHandle { -// self.handle.into_raw() as RawHandle -// } -// } - -/// A status message received from an I/O completion port. -/// -/// These statuses can be created via the `new` or `empty` constructors and then -/// provided to a completion port, or they are read out of a completion port. -/// The fields of each status are read through its accessor methods. -#[derive(Clone, Copy)] -#[repr(transparent)] -pub struct CompletionStatus(OVERLAPPED_ENTRY); - -impl std::fmt::Debug for CompletionStatus { - fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { - write!(f, "CompletionStatus(OVERLAPPED_ENTRY)") - } -} - -unsafe impl Send for CompletionStatus {} -unsafe impl Sync for CompletionStatus {} - -impl CompletionStatus { - /// Creates a new completion status with the provided parameters. - /// - /// This function is useful when creating a status to send to a port with - /// the `post` method. The parameters are opaquely passed through and not - /// interpreted by the system at all. - pub fn new(bytes: u32, token: usize, overlapped: *mut OVERLAPPED) -> CompletionStatus { - assert_eq!(std::mem::size_of_val(&token), std::mem::size_of::()); - CompletionStatus(OVERLAPPED_ENTRY { - dwNumberOfBytesTransferred: bytes, - lpCompletionKey: token, - lpOverlapped: overlapped, - Internal: 0, - }) - } - - /// Creates a new borrowed completion status from the borrowed - /// `OVERLAPPED_ENTRY` argument provided. - /// - /// This method will wrap the `OVERLAPPED_ENTRY` in a `CompletionStatus`, - /// returning the wrapped structure. - #[allow(dead_code)] - pub fn from_entry(entry: &OVERLAPPED_ENTRY) -> &CompletionStatus { - // Safety: CompletionStatus is repr(transparent) w/ OVERLAPPED_ENTRY, so - // a reference to one is guaranteed to be layout compatible with the - // reference to another. - unsafe { &*(entry as *const _ as *const _) } - } - - /// Creates a new "zero" completion status. - /// - /// This function is useful when creating a stack buffer or vector of - /// completion statuses to be passed to the `get_many` function. - #[allow(dead_code)] - pub fn zero() -> CompletionStatus { - CompletionStatus::new(0, 0, std::ptr::null_mut()) - } - - /// Returns the number of bytes that were transferred for the I/O operation - /// associated with this completion status. - #[allow(dead_code)] - pub fn bytes_transferred(&self) -> u32 { - self.0.dwNumberOfBytesTransferred - } - - /// Returns the completion key value associated with the file handle whose - /// I/O operation has completed. - /// - /// A completion key is a per-handle key that is specified when it is added - /// to an I/O completion port via `add_handle` or `add_socket`. - #[allow(dead_code)] - pub fn token(&self) -> usize { - self.0.lpCompletionKey - } - - /// Returns a pointer to the `Overlapped` structure that was specified when - /// the I/O operation was started. - pub fn overlapped(&self) -> *mut OVERLAPPED { - self.0.lpOverlapped - } - - /// Returns a pointer to the internal `OVERLAPPED_ENTRY` object. - #[allow(dead_code)] - pub fn entry(&self) -> &OVERLAPPED_ENTRY { - &self.0 - } -} - -fn dur2ms(dur: Option) -> u32 { - let dur = match dur { - Some(dur) => dur, - None => return INFINITE, - }; - let ms = dur.as_secs().checked_mul(1_000); - let ms_extra = dur.subsec_millis(); - ms.and_then(|ms| ms.checked_add(ms_extra as u64)) - .map(|ms| std::cmp::min(u32::max_value() as u64, ms) as u32) - .unwrap_or(INFINITE - 1) -} - -unsafe fn slice2buf(slice: &[u8]) -> WSABUF { - WSABUF { - len: std::cmp::min(slice.len(), ::max_value() as usize) as u32, - buf: slice.as_ptr() as *mut _, - } -} - -fn last_err() -> io::Result> { - let err = unsafe { WSAGetLastError() }; - if err == WSA_IO_PENDING { - Ok(None) - } else { - Err(io::Error::from_raw_os_error(err)) - } -} - -fn cvt_ret(i: i32) -> io::Result { - if i == 0 { - Err(io::Error::last_os_error()) - } else { - Ok(i != 0) - } -} - -fn cvt(i: i32, size: u32) -> io::Result> { - if i == SOCKET_ERROR { - last_err() - } else { - Ok(Some(size as usize)) - } -} - -pub unsafe fn socket_read( - socket: RawSocket, - buf: &mut [u8], - flags: i32, - overlapped: *mut OVERLAPPED, -) -> io::Result> { - let buf = slice2buf(buf); - let mut bytes_read: u32 = 0; - let mut flags = flags as u32; - let r = WSARecv( - socket as SOCKET, - &buf, - 1, - &mut bytes_read, - &mut flags, - overlapped, - None, - ); - cvt(r, bytes_read) -} - -pub unsafe fn socket_write( - socket: RawSocket, - buf: &[u8], - overlapped: *mut OVERLAPPED, -) -> io::Result> { - let mut buf = slice2buf(buf); - let mut bytes_written = 0; - // Note here that we capture the number of bytes written. The - // documentation on MSDN, however, states: - // - // > Use NULL for this parameter if the lpOverlapped parameter is not - // > NULL to avoid potentially erroneous results. This parameter can be - // > NULL only if the lpOverlapped parameter is not NULL. - // - // If we're not passing a null overlapped pointer here, then why are we - // then capturing the number of bytes! Well so it turns out that this is - // clearly faster to learn the bytes here rather than later calling - // `WSAGetOverlappedResult`, and in practice almost all implementations - // use this anyway [1]. - // - // As a result we use this to and report back the result. - // - // [1]: https://github.com/carllerche/mio/pull/520#issuecomment-273983823 - let r = WSASend( - socket as SOCKET, - &mut buf, - 1, - &mut bytes_written, - 0, - overlapped, - None, - ); - cvt(r, bytes_written) -} +//! ported from miow crate which is not maintained anymore + +use std::net::SocketAddr; +use std::os::windows::io::{AsRawHandle, AsRawSocket}; +use std::sync::atomic::{AtomicUsize, Ordering}; +use std::time::Duration; +use std::{io, os::windows::io::RawSocket}; + +use windows_sys::core::GUID; +use windows_sys::Win32::Foundation::{HANDLE, INVALID_HANDLE_VALUE}; +use windows_sys::Win32::Networking::WinSock::*; +use windows_sys::Win32::System::Threading::INFINITE; +use windows_sys::Win32::System::IO::*; + +type BOOL = i32; +const TRUE: BOOL = 1; +const FALSE: BOOL = 0; + +/// A handle to an Windows I/O Completion Port. +#[derive(Debug)] +pub struct CompletionPort { + handle: HANDLE, +} + +impl CompletionPort { + /// Creates a new I/O completion port with the specified concurrency value. + /// + /// The number of threads given corresponds to the level of concurrency + /// allowed for threads associated with this port. Consult the Windows + /// documentation for more information about this value. + pub fn new(threads: u32) -> io::Result { + let ret = unsafe { CreateIoCompletionPort(INVALID_HANDLE_VALUE, 0, 0, threads) }; + if ret == 0 { + Err(io::Error::last_os_error()) + } else { + Ok(CompletionPort { handle: ret }) + } + } + + /// Associates a new `HANDLE` to this I/O completion port. + /// + /// This function will associate the given handle to this port with the + /// given `token` to be returned in status messages whenever it receives a + /// notification. + /// + /// Any object which is convertible to a `HANDLE` via the `AsRawHandle` + /// trait can be provided to this function, such as `std::fs::File` and + /// friends. + #[allow(dead_code)] + pub fn add_handle(&self, token: usize, t: &T) -> io::Result<()> { + self._add(token, t.as_raw_handle() as HANDLE) + } + + /// Associates a new `SOCKET` to this I/O completion port. + /// + /// This function will associate the given socket to this port with the + /// given `token` to be returned in status messages whenever it receives a + /// notification. + /// + /// Any object which is convertible to a `SOCKET` via the `AsRawSocket` + /// trait can be provided to this function, such as `std::net::TcpStream` + /// and friends. + pub fn add_socket(&self, token: usize, t: &T) -> io::Result<()> { + self._add(token, t.as_raw_socket() as HANDLE) + } + + fn _add(&self, token: usize, handle: HANDLE) -> io::Result<()> { + assert_eq!(std::mem::size_of_val(&token), std::mem::size_of::()); + let ret = unsafe { CreateIoCompletionPort(handle, self.handle, token, 0) }; + if ret == 0 { + Err(io::Error::last_os_error()) + } else { + debug_assert_eq!(ret, self.handle); + Ok(()) + } + } + + /// Dequeue a completion status from this I/O completion port. + /// + /// This function will associate the calling thread with this completion + /// port and then wait for a status message to become available. The precise + /// semantics on when this function returns depends on the concurrency value + /// specified when the port was created. + /// + /// A timeout can optionally be specified to this function. If `None` is + /// provided this function will not time out, and otherwise it will time out + /// after the specified duration has passed. + /// + /// On success this will return the status message which was dequeued from + /// this completion port. + #[allow(dead_code)] + pub fn get(&self, timeout: Option) -> io::Result { + let mut bytes = 0; + let mut token = 0; + let mut overlapped = std::ptr::null_mut(); + let timeout = dur2ms(timeout); + let ret = unsafe { + GetQueuedCompletionStatus( + self.handle, + &mut bytes, + &mut token, + &mut overlapped, + timeout, + ) + }; + cvt(ret, 0).map(|_| { + CompletionStatus(OVERLAPPED_ENTRY { + dwNumberOfBytesTransferred: bytes, + lpCompletionKey: token, + lpOverlapped: overlapped, + Internal: 0, + }) + }) + } + + /// Dequeues a number of completion statuses from this I/O completion port. + /// + /// This function is the same as `get` except that it may return more than + /// one status. A buffer of "zero" statuses is provided (the contents are + /// not read) and then on success this function will return a sub-slice of + /// statuses which represent those which were dequeued from this port. This + /// function does not wait to fill up the entire list of statuses provided. + /// + /// Like with `get`, a timeout may be specified for this operation. + pub fn get_many<'a>( + &self, + list: &'a mut [CompletionStatus], + timeout: Option, + ) -> io::Result<&'a mut [CompletionStatus]> { + debug_assert_eq!( + std::mem::size_of::(), + std::mem::size_of::() + ); + let mut removed = 0; + let timeout = dur2ms(timeout); + let len = std::cmp::min(list.len(), ::max_value() as usize) as u32; + let ret = unsafe { + GetQueuedCompletionStatusEx( + self.handle, + list.as_ptr() as *mut _, + len, + &mut removed, + timeout, + 0, + ) + }; + match cvt_ret(ret) { + Ok(_) => Ok(&mut list[..removed as usize]), + Err(e) => Err(e), + } + } + + /// Posts a new completion status onto this I/O completion port. + /// + /// This function will post the given status, with custom parameters, to the + /// port. Threads blocked in `get` or `get_many` will eventually receive + /// this status. + pub fn post(&self, status: CompletionStatus) -> io::Result<()> { + let ret = unsafe { + PostQueuedCompletionStatus( + self.handle, + status.0.dwNumberOfBytesTransferred, + status.0.lpCompletionKey, + status.0.lpOverlapped, + ) + }; + cvt_ret(ret).map(|_| ()) + } +} + +// impl AsRawHandle for CompletionPort { +// fn as_raw_handle(&self) -> RawHandle { +// self.handle.raw() as RawHandle +// } +// } + +// impl FromRawHandle for CompletionPort { +// unsafe fn from_raw_handle(handle: RawHandle) -> CompletionPort { +// CompletionPort { +// handle: Handle::new(handle as HANDLE), +// } +// } +// } + +// impl IntoRawHandle for CompletionPort { +// fn into_raw_handle(self) -> RawHandle { +// self.handle.into_raw() as RawHandle +// } +// } + +/// A status message received from an I/O completion port. +/// +/// These statuses can be created via the `new` or `empty` constructors and then +/// provided to a completion port, or they are read out of a completion port. +/// The fields of each status are read through its accessor methods. +#[derive(Clone, Copy)] +#[repr(transparent)] +pub struct CompletionStatus(OVERLAPPED_ENTRY); + +impl std::fmt::Debug for CompletionStatus { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + write!(f, "CompletionStatus(OVERLAPPED_ENTRY)") + } +} + +unsafe impl Send for CompletionStatus {} +unsafe impl Sync for CompletionStatus {} + +impl CompletionStatus { + /// Creates a new completion status with the provided parameters. + /// + /// This function is useful when creating a status to send to a port with + /// the `post` method. The parameters are opaquely passed through and not + /// interpreted by the system at all. + pub fn new(bytes: u32, token: usize, overlapped: *mut OVERLAPPED) -> CompletionStatus { + assert_eq!(std::mem::size_of_val(&token), std::mem::size_of::()); + CompletionStatus(OVERLAPPED_ENTRY { + dwNumberOfBytesTransferred: bytes, + lpCompletionKey: token, + lpOverlapped: overlapped, + Internal: 0, + }) + } + + /// Creates a new borrowed completion status from the borrowed + /// `OVERLAPPED_ENTRY` argument provided. + /// + /// This method will wrap the `OVERLAPPED_ENTRY` in a `CompletionStatus`, + /// returning the wrapped structure. + #[allow(dead_code)] + pub fn from_entry(entry: &OVERLAPPED_ENTRY) -> &CompletionStatus { + // Safety: CompletionStatus is repr(transparent) w/ OVERLAPPED_ENTRY, so + // a reference to one is guaranteed to be layout compatible with the + // reference to another. + unsafe { &*(entry as *const _ as *const _) } + } + + /// Creates a new "zero" completion status. + /// + /// This function is useful when creating a stack buffer or vector of + /// completion statuses to be passed to the `get_many` function. + #[allow(dead_code)] + pub fn zero() -> CompletionStatus { + CompletionStatus::new(0, 0, std::ptr::null_mut()) + } + + /// Returns the number of bytes that were transferred for the I/O operation + /// associated with this completion status. + #[allow(dead_code)] + pub fn bytes_transferred(&self) -> u32 { + self.0.dwNumberOfBytesTransferred + } + + /// Returns the completion key value associated with the file handle whose + /// I/O operation has completed. + /// + /// A completion key is a per-handle key that is specified when it is added + /// to an I/O completion port via `add_handle` or `add_socket`. + #[allow(dead_code)] + pub fn token(&self) -> usize { + self.0.lpCompletionKey + } + + /// Returns a pointer to the `Overlapped` structure that was specified when + /// the I/O operation was started. + pub fn overlapped(&self) -> *mut OVERLAPPED { + self.0.lpOverlapped + } + + /// Returns a pointer to the internal `OVERLAPPED_ENTRY` object. + #[allow(dead_code)] + pub fn entry(&self) -> &OVERLAPPED_ENTRY { + &self.0 + } +} + +struct WsaExtension { + guid: GUID, + val: AtomicUsize, +} + +impl WsaExtension { + fn get(&self, socket: SOCKET) -> io::Result { + let prev = self.val.load(Ordering::SeqCst); + if prev != 0 && !cfg!(debug_assertions) { + return Ok(prev); + } + let mut ret = 0 as usize; + let mut bytes = 0; + + // https://github.com/microsoft/win32metadata/issues/671 + const SIO_GET_EXTENSION_FUNCTION_POINTER: u32 = 33_5544_3206u32; + + let r = unsafe { + WSAIoctl( + socket, + SIO_GET_EXTENSION_FUNCTION_POINTER, + &self.guid as *const _ as *mut _, + std::mem::size_of_val(&self.guid) as u32, + &mut ret as *mut _ as *mut _, + std::mem::size_of_val(&ret) as u32, + &mut bytes, + 0 as *mut _, + None, + ) + }; + cvt(r, 0).map(|_| { + debug_assert_eq!(bytes as usize, std::mem::size_of_val(&ret)); + debug_assert!(prev == 0 || prev == ret); + self.val.store(ret, Ordering::SeqCst); + ret + }) + } +} + +fn dur2ms(dur: Option) -> u32 { + let dur = match dur { + Some(dur) => dur, + None => return INFINITE, + }; + let ms = dur.as_secs().checked_mul(1_000); + let ms_extra = dur.subsec_millis(); + ms.and_then(|ms| ms.checked_add(ms_extra as u64)) + .map(|ms| std::cmp::min(u32::max_value() as u64, ms) as u32) + .unwrap_or(INFINITE - 1) +} + +unsafe fn slice2buf(slice: &[u8]) -> WSABUF { + WSABUF { + len: std::cmp::min(slice.len(), ::max_value() as usize) as u32, + buf: slice.as_ptr() as *mut _, + } +} + +fn last_err() -> io::Result> { + let err = unsafe { WSAGetLastError() }; + if err == WSA_IO_PENDING { + Ok(None) + } else { + Err(io::Error::from_raw_os_error(err)) + } +} + +fn cvt_ret(i: BOOL) -> io::Result { + if i == FALSE { + Err(io::Error::last_os_error()) + } else { + Ok(i != 0) + } +} + +fn cvt(i: i32, size: u32) -> io::Result> { + if i == SOCKET_ERROR { + last_err() + } else { + Ok(Some(size as usize)) + } +} + +/// A type with the same memory layout as `SOCKADDR`. Used in converting Rust level +/// SocketAddr* types into their system representation. The benefit of this specific +/// type over using `SOCKADDR_STORAGE` is that this type is exactly as large as it +/// needs to be and not a lot larger. And it can be initialized cleaner from Rust. +#[repr(C)] +pub(crate) union SocketAddrCRepr { + v4: SOCKADDR_IN, + v6: SOCKADDR_IN6, +} + +impl SocketAddrCRepr { + pub(crate) fn as_ptr(&self) -> *const SOCKADDR { + self as *const _ as *const SOCKADDR + } +} + +fn socket_addr_to_ptrs(addr: &SocketAddr) -> (SocketAddrCRepr, i32) { + match *addr { + SocketAddr::V4(ref a) => { + let sin_addr = IN_ADDR { + S_un: IN_ADDR_0 { + S_addr: u32::from_ne_bytes(a.ip().octets()), + }, + }; + + let sockaddr_in = SOCKADDR_IN { + sin_family: AF_INET as _, + sin_port: a.port().to_be(), + sin_addr, + sin_zero: [0; 8], + }; + + let sockaddr = SocketAddrCRepr { v4: sockaddr_in }; + (sockaddr, std::mem::size_of::() as i32) + } + SocketAddr::V6(ref a) => { + let sockaddr_in6 = SOCKADDR_IN6 { + sin6_family: AF_INET6 as _, + sin6_port: a.port().to_be(), + sin6_addr: IN6_ADDR { + u: IN6_ADDR_0 { + Byte: a.ip().octets(), + }, + }, + sin6_flowinfo: a.flowinfo(), + Anonymous: SOCKADDR_IN6_0 { + sin6_scope_id: a.scope_id(), + }, + }; + + let sockaddr = SocketAddrCRepr { v6: sockaddr_in6 }; + (sockaddr, std::mem::size_of::() as i32) + } + } +} + +pub unsafe fn socket_read( + socket: RawSocket, + buf: &mut [u8], + flags: i32, + overlapped: *mut OVERLAPPED, +) -> io::Result> { + let buf = slice2buf(buf); + let mut bytes_read: u32 = 0; + let mut flags = flags as u32; + let r = WSARecv( + socket as SOCKET, + &buf, + 1, + &mut bytes_read, + &mut flags, + overlapped, + None, + ); + cvt(r, bytes_read) +} + +pub unsafe fn socket_write( + socket: RawSocket, + buf: &[u8], + overlapped: *mut OVERLAPPED, +) -> io::Result> { + let mut buf = slice2buf(buf); + let mut bytes_written = 0; + // Note here that we capture the number of bytes written. The + // documentation on MSDN, however, states: + // + // > Use NULL for this parameter if the lpOverlapped parameter is not + // > NULL to avoid potentially erroneous results. This parameter can be + // > NULL only if the lpOverlapped parameter is not NULL. + // + // If we're not passing a null overlapped pointer here, then why are we + // then capturing the number of bytes! Well so it turns out that this is + // clearly faster to learn the bytes here rather than later calling + // `WSAGetOverlappedResult`, and in practice almost all implementations + // use this anyway [1]. + // + // As a result we use this to and report back the result. + // + // [1]: https://github.com/carllerche/mio/pull/520#issuecomment-273983823 + let r = WSASend( + socket as SOCKET, + &mut buf, + 1, + &mut bytes_written, + 0, + overlapped, + None, + ); + cvt(r, bytes_written) +} + +pub unsafe fn connect_overlapped( + socket: RawSocket, + addr: &SocketAddr, + buf: &[u8], + overlapped: *mut OVERLAPPED, +) -> io::Result> { + static CONNECTEX: WsaExtension = WsaExtension { + guid: GUID { + data1: 0x25a207b9, + data2: 0xddf3, + data3: 0x4660, + data4: [0x8e, 0xe9, 0x76, 0xe5, 0x8c, 0x74, 0x06, 0x3e], + }, + val: AtomicUsize::new(0), + }; + + let socket = socket as SOCKET; + + let ptr = CONNECTEX.get(socket)?; + assert!(ptr != 0); + let connect_ex = std::mem::transmute::<_, LPFN_CONNECTEX>(ptr).unwrap(); + + let (addr_buf, addr_len) = socket_addr_to_ptrs(addr); + let mut bytes_sent: u32 = 0; + let r = connect_ex( + socket, + addr_buf.as_ptr(), + addr_len, + buf.as_ptr() as *mut _, + buf.len() as u32, + &mut bytes_sent, + overlapped, + ); + if r == TRUE { + Ok(Some(bytes_sent as usize)) + } else { + last_err() + } +} + +pub fn connect_complete(socket: RawSocket) -> io::Result<()> { + const SO_UPDATE_CONNECT_CONTEXT: i32 = 0x7010; + let result = unsafe { + setsockopt( + socket as SOCKET, + SOL_SOCKET as _, + SO_UPDATE_CONNECT_CONTEXT, + 0 as *mut _, + 0, + ) + }; + if result == 0 { + Ok(()) + } else { + Err(io::Error::last_os_error()) + } +} diff --git a/src/io/sys/windows/net/socket_write.rs b/src/io/sys/windows/net/socket_write.rs index eb9602c1..62ba2323 100644 --- a/src/io/sys/windows/net/socket_write.rs +++ b/src/io/sys/windows/net/socket_write.rs @@ -3,8 +3,8 @@ use std::os::windows::io::{AsRawSocket, RawSocket}; #[cfg(feature = "io_timeout")] use std::time::Duration; -use super::super::{co_io_result, EventData}; use super::super::miow::socket_write; +use super::super::{co_io_result, EventData}; use crate::coroutine_impl::{is_coroutine, CoroutineImpl, EventSource}; use crate::scheduler::get_scheduler; use windows_sys::Win32::Foundation::*; @@ -53,11 +53,7 @@ impl<'a> EventSource for SocketWrite<'a> { self.io_data.co = Some(co); // call the overlapped write API co_try!(s, self.io_data.co.take().expect("can't get co"), unsafe { - socket_write( - self.socket, - self.buf, - self.io_data.get_overlapped(), - ) + socket_write(self.socket, self.buf, self.io_data.get_overlapped()) }); } } diff --git a/src/io/sys/windows/net/tcp_stream_connect.rs b/src/io/sys/windows/net/tcp_stream_connect.rs index 780eb2f9..9f09a4c3 100644 --- a/src/io/sys/windows/net/tcp_stream_connect.rs +++ b/src/io/sys/windows/net/tcp_stream_connect.rs @@ -4,6 +4,7 @@ use std::os::windows::io::AsRawSocket; #[cfg(feature = "io_timeout")] use std::time::Duration; +use super::super::miow::{connect_complete, connect_overlapped}; use super::super::{add_socket, co_io_result, EventData, IoData}; #[cfg(feature = "io_cancel")] use crate::coroutine_impl::co_cancel_data; @@ -14,7 +15,6 @@ use crate::io::OptionCell; use crate::net::TcpStream; use crate::scheduler::get_scheduler; use crate::sync::delay_drop::DelayDrop; -use miow::net::TcpStreamExt; use windows_sys::Win32::Foundation::*; pub struct TcpStreamConnect { @@ -83,7 +83,7 @@ impl TcpStreamConnect { pub fn done(&mut self) -> io::Result { co_io_result(&self.io_data, self.is_coroutine)?; let stream = self.stream.take(); - stream.connect_complete()?; + connect_complete(stream.as_raw_socket())?; Ok(TcpStream::from_stream(stream, IoData)) } } @@ -102,8 +102,12 @@ impl EventSource for TcpStreamConnect { // call the overlapped connect API co_try!(s, self.io_data.co.take().expect("can't get co"), unsafe { - self.stream - .connect_overlapped(&self.addr, &[], self.io_data.get_overlapped()) + connect_overlapped( + self.stream.as_raw_socket(), + &self.addr, + &[], + self.io_data.get_overlapped(), + ) }); #[cfg(feature = "io_cancel")]