From 297c74c1ff760919013ed12c0b449a95e2456368 Mon Sep 17 00:00:00 2001 From: Jonathan Date: Wed, 23 Nov 2022 20:41:57 +0000 Subject: [PATCH] Safer poll timeout --- CHANGELOG.md | 2 + src/poll.rs | 256 +++++++++++++++++++++++++++++++++++++++++++++- test/test_poll.rs | 6 +- 3 files changed, 259 insertions(+), 5 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index fec222caae..4c174b2006 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -47,6 +47,8 @@ This project adheres to [Semantic Versioning](https://semver.org/). ([#1870](https://github.com/nix-rust/nix/pull/1870)) - The `length` argument of `sys::mman::mmap` is now of type `NonZeroUsize`. ([#1873](https://github.com/nix-rust/nix/pull/1873)) +- The `timeout` argument of `poll::poll` is now of type `poll::PollTimeout`. + ([#1876](https://github.com/nix-rust/nix/pull/1876)) ### Fixed diff --git a/src/poll.rs b/src/poll.rs index e1baa814f1..3f779f596f 100644 --- a/src/poll.rs +++ b/src/poll.rs @@ -1,5 +1,8 @@ //! Wait for events to trigger on specific file descriptors +use std::convert::TryFrom; +use std::fmt; use std::os::unix::io::{AsRawFd, RawFd}; +use std::time::Duration; use crate::errno::Errno; use crate::Result; @@ -112,6 +115,255 @@ libc_bitflags! { } } +/// Timeout argument for [`poll`]. +#[derive(Debug, Clone, Copy, Eq, PartialEq)] +pub struct PollTimeout(i32); + +/// Error type for [`PollTimeout::try_from::::()`]. +#[derive(Debug, Clone, Copy)] +pub enum TryFromI128Error { + /// Value is less than -1. + Underflow(crate::Errno), + /// Value is greater than [`i32::MAX`]. + Overflow(>::Error), +} +impl fmt::Display for TryFromI128Error { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + match self { + Self::Underflow(err) => write!(f, "Underflow: {}", err), + Self::Overflow(err) => write!(f, "Overflow: {}", err), + } + } +} +impl std::error::Error for TryFromI128Error {} + +/// Error type for [`PollTimeout::try_from::()`]. +#[derive(Debug, Clone, Copy)] +pub enum TryFromI64Error { + /// Value is less than -1. + Underflow(crate::Errno), + /// Value is greater than [`i32::MAX`]. + Overflow(>::Error), +} +impl fmt::Display for TryFromI64Error { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + match self { + Self::Underflow(err) => write!(f, "Underflow: {}", err), + Self::Overflow(err) => write!(f, "Overflow: {}", err), + } + } +} +impl std::error::Error for TryFromI64Error {} + +// These cases implement slightly different conversions that make using generics impossible without +// specialization. +impl PollTimeout { + /// Blocks indefinitely. + pub const NONE: Self = Self(1 << 31); + /// Returns immediately. + pub const ZERO: Self = Self(0); + /// Blocks for at most [`std::i32::MAX`] milliseconds. + pub const MAX: Self = Self(i32::MAX); + /// Returns if `self` equals [`PollTimeout::NONE`]. + pub fn is_none(&self) -> bool { + *self == Self::NONE + } + /// Returns if `self` does not equal [`PollTimeout::NONE`]. + pub fn is_some(&self) -> bool { + !self.is_none() + } + /// Returns the timeout in milliseconds if there is some, otherwise returns `None`. + pub fn timeout(&self) -> Option { + self.is_some().then(|| self.0) + } +} +impl TryFrom for PollTimeout { + type Error = >::Error; + fn try_from(x: Duration) -> std::result::Result { + Ok(Self(i32::try_from(x.as_millis())?)) + } +} +impl TryFrom for PollTimeout { + type Error = >::Error; + fn try_from(x: u128) -> std::result::Result { + Ok(Self(i32::try_from(x)?)) + } +} +impl TryFrom for PollTimeout { + type Error = >::Error; + fn try_from(x: u64) -> std::result::Result { + Ok(Self(i32::try_from(x)?)) + } +} +impl TryFrom for PollTimeout { + type Error = >::Error; + fn try_from(x: u32) -> std::result::Result { + Ok(Self(i32::try_from(x)?)) + } +} +impl From for PollTimeout { + fn from(x: u16) -> Self { + Self(i32::from(x)) + } +} +impl From for PollTimeout { + fn from(x: u8) -> Self { + Self(i32::from(x)) + } +} +impl TryFrom for PollTimeout { + type Error = TryFromI128Error; + fn try_from(x: i128) -> std::result::Result { + match x { + -1 => Ok(Self::NONE), + millis @ 0.. => Ok(Self( + i32::try_from(millis).map_err(TryFromI128Error::Overflow)?, + )), + // EINVAL (ppoll()) The timeout value expressed in *ip is invalid (negative). + _ => Err(TryFromI128Error::Underflow(Errno::EINVAL)), + } + } +} +impl TryFrom for PollTimeout { + type Error = TryFromI64Error; + fn try_from(x: i64) -> std::result::Result { + match x { + -1 => Ok(Self::NONE), + millis @ 0.. => Ok(Self( + i32::try_from(millis).map_err(TryFromI64Error::Overflow)?, + )), + // EINVAL (ppoll()) The timeout value expressed in *ip is invalid (negative). + _ => Err(TryFromI64Error::Underflow(Errno::EINVAL)), + } + } +} +impl TryFrom for PollTimeout { + type Error = Errno; + fn try_from(x: i32) -> Result { + match x { + -1 => Ok(Self::NONE), + millis @ 0.. => Ok(Self(millis)), + // EINVAL (ppoll()) The timeout value expressed in *ip is invalid (negative). + _ => Err(Errno::EINVAL), + } + } +} +impl TryFrom for PollTimeout { + type Error = Errno; + fn try_from(x: i16) -> Result { + match x { + -1 => Ok(Self::NONE), + millis @ 0.. => Ok(Self(millis.into())), + // EINVAL (ppoll()) The timeout value expressed in *ip is invalid (negative). + _ => Err(Errno::EINVAL), + } + } +} +impl TryFrom for PollTimeout { + type Error = Errno; + fn try_from(x: i8) -> Result { + match x { + -1 => Ok(Self::NONE), + millis @ 0.. => Ok(Self(millis.into())), + // EINVAL (ppoll()) The timeout value expressed in *ip is invalid (negative). + _ => Err(Errno::EINVAL), + } + } +} +impl TryFrom for Duration { + type Error = (); + fn try_from(x: PollTimeout) -> std::result::Result { + match x.timeout() { + // SAFETY: `x.0` is always positive. + Some(millis) => Ok(Duration::from_millis(unsafe { + u64::try_from(millis).unwrap_unchecked() + })), + None => Err(()), + } + } +} +impl TryFrom for u128 { + type Error = (); + fn try_from(x: PollTimeout) -> std::result::Result { + match x.timeout() { + // SAFETY: When `x.timeout()` returns `Some(a)`, `a` is always positive. + Some(millis) => { + Ok(unsafe { Self::try_from(millis).unwrap_unchecked() }) + } + None => Err(()), + } + } +} +impl TryFrom for u64 { + type Error = (); + fn try_from(x: PollTimeout) -> std::result::Result { + match x.timeout() { + // SAFETY: When `x.timeout()` returns `Some(a)`, `a` is always positive. + Some(millis) => { + Ok(unsafe { Self::try_from(millis).unwrap_unchecked() }) + } + None => Err(()), + } + } +} +impl TryFrom for u32 { + type Error = (); + fn try_from(x: PollTimeout) -> std::result::Result { + match x.timeout() { + // SAFETY: When `x.timeout()` returns `Some(a)`, `a` is always positive. + Some(millis) => { + Ok(unsafe { Self::try_from(millis).unwrap_unchecked() }) + } + None => Err(()), + } + } +} +impl TryFrom for u16 { + type Error = Option<>::Error>; + fn try_from(x: PollTimeout) -> std::result::Result { + match x.timeout() { + Some(millis) => Ok(Self::try_from(millis).map_err(Some)?), + None => Err(None), + } + } +} +impl TryFrom for u8 { + type Error = Option<>::Error>; + fn try_from(x: PollTimeout) -> std::result::Result { + match x.timeout() { + Some(millis) => Ok(Self::try_from(millis).map_err(Some)?), + None => Err(None), + } + } +} +impl From for i128 { + fn from(x: PollTimeout) -> Self { + x.timeout().unwrap_or(-1).into() + } +} +impl From for i64 { + fn from(x: PollTimeout) -> Self { + x.timeout().unwrap_or(-1).into() + } +} +impl From for i32 { + fn from(x: PollTimeout) -> Self { + x.timeout().unwrap_or(-1) + } +} +impl TryFrom for i16 { + type Error = >::Error; + fn try_from(x: PollTimeout) -> std::result::Result { + Self::try_from(x.timeout().unwrap_or(-1)) + } +} +impl TryFrom for i8 { + type Error = >::Error; + fn try_from(x: PollTimeout) -> std::result::Result { + Self::try_from(x.timeout().unwrap_or(-1)) + } +} + /// `poll` waits for one of a set of file descriptors to become ready to perform I/O. /// ([`poll(2)`](https://pubs.opengroup.org/onlinepubs/9699919799/functions/poll.html)) /// @@ -132,12 +384,12 @@ libc_bitflags! { /// in timeout means an infinite timeout. Specifying a timeout of zero /// causes `poll()` to return immediately, even if no file descriptors are /// ready. -pub fn poll(fds: &mut [PollFd], timeout: libc::c_int) -> Result { +pub fn poll(fds: &mut [PollFd], timeout: PollTimeout) -> Result { let res = unsafe { libc::poll( fds.as_mut_ptr() as *mut libc::pollfd, fds.len() as libc::nfds_t, - timeout, + timeout.into(), ) }; diff --git a/test/test_poll.rs b/test/test_poll.rs index 53964e26bb..ca05745dd4 100644 --- a/test/test_poll.rs +++ b/test/test_poll.rs @@ -1,6 +1,6 @@ use nix::{ errno::Errno, - poll::{poll, PollFd, PollFlags}, + poll::{poll, PollFd, PollFlags, PollTimeout}, unistd::{pipe, write}, }; @@ -22,14 +22,14 @@ fn test_poll() { let mut fds = [PollFd::new(r, PollFlags::POLLIN)]; // Poll an idle pipe. Should timeout - let nfds = loop_while_eintr!(poll(&mut fds, 100)); + let nfds = loop_while_eintr!(poll(&mut fds, PollTimeout::from(100u8))); assert_eq!(nfds, 0); assert!(!fds[0].revents().unwrap().contains(PollFlags::POLLIN)); write(w, b".").unwrap(); // Poll a readable pipe. Should return an event. - let nfds = poll(&mut fds, 100).unwrap(); + let nfds = poll(&mut fds, PollTimeout::from(100u8)).unwrap(); assert_eq!(nfds, 1); assert!(fds[0].revents().unwrap().contains(PollFlags::POLLIN)); }