diff --git a/tokio/src/net/tcp/stream.rs b/tokio/src/net/tcp/stream.rs index 56948ffa66e..92671a1f64d 100644 --- a/tokio/src/net/tcp/stream.rs +++ b/tokio/src/net/tcp/stream.rs @@ -807,6 +807,61 @@ impl TcpStream { mio_socket.set_linger(dur) } + /// Reads the keepalive duration for this socket by getting the `SO_KEEPALIVE` + /// option along with more system-specific parameters (e.g. TCP_KEEPALIVE + /// or SIO_KEEPALIVE_VALS). + /// + /// For more information about this option, see [`set_keepalive`]. + /// + /// [`set_keepalive`]: TcpStream::set_keepalive + /// + /// # Examples + /// + /// ```no_run + /// use tokio::net::TcpStream; + /// + /// # async fn dox() -> Result<(), Box> { + /// let stream = TcpStream::connect("127.0.0.1:8080").await?; + /// + /// println!("{:?}", stream.keepalive()?); + /// # Ok(()) + /// # } + /// ``` + #[cfg_attr(docsrs, doc(cfg(not(target_os = "windows"))))] + #[cfg(not(target_os = "windows"))] + pub fn keepalive(&self) -> io::Result> { + let mio_socket = std::mem::ManuallyDrop::new(self.to_mio()); + mio_socket.get_keepalive_time() + } + + /// Sets the keepalive duration of this socket by setting the SO_KEEPALIVE option + /// along with more system-specific parameters (e.g. TCP_KEEPALIVE or SIO_KEEPALIVE_VALS). + /// + /// This option controls whether keep-alive TCP packets should be used + /// for a socket connection and what should be their idle interval. + /// + /// # Examples + /// + /// ```no_run + /// use tokio::net::TcpStream; + /// + /// # async fn dox() -> Result<(), Box> { + /// let stream = TcpStream::connect("127.0.0.1:8080").await?; + /// + /// stream.set_keepalive(None)?; + /// # Ok(()) + /// # } + /// ``` + pub fn set_keepalive(&self, dur: Option) -> io::Result<()> { + let mio_socket = std::mem::ManuallyDrop::new(self.to_mio()); + + if let Some(duration) = dur { + mio_socket.set_keepalive_params(mio::net::TcpKeepalive::new().with_time(duration)) + } else { + mio_socket.set_keepalive(false) + } + } + fn to_mio(&self) -> mio::net::TcpSocket { #[cfg(windows)] { diff --git a/tokio/tests/tcp_stream.rs b/tokio/tests/tcp_stream.rs index 58b06ee3233..3f2eda6b152 100644 --- a/tokio/tests/tcp_stream.rs +++ b/tokio/tests/tcp_stream.rs @@ -28,6 +28,22 @@ async fn set_linger() { assert!(stream.linger().unwrap().is_none()); } +#[tokio::test] +#[cfg(not(windows))] +async fn set_keepalive() { + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + + let stream = TcpStream::connect(listener.local_addr().unwrap()) + .await + .unwrap(); + + assert_ok!(stream.set_keepalive(Some(Duration::from_secs(1)))); + assert_eq!(stream.keepalive().unwrap().unwrap().as_secs(), 1); + + assert_ok!(stream.set_keepalive(None)); + assert!(stream.keepalive().unwrap().is_none()); +} + #[tokio::test] async fn try_read_write() { const DATA: &[u8] = b"this is some data to write to the socket";