From a0a9051d5f80a68f66ccf43f0cebd94c1e868077 Mon Sep 17 00:00:00 2001 From: Simon Farnsworth Date: Tue, 20 Sep 2022 10:46:33 +0100 Subject: [PATCH 1/7] tokio_util: add inspection wrapper for `AsyncRead` There are use cases like checking hashes of files that benefit from being able to inspect bytes read as they come in, while still letting the main code process the bytes as normal (e.g. deserializing into objects, knowing that if there's a hash failure, you'll discard the result). As this is non-trivial to get right (e.g. handling a `buf` that's not empty when passed to `poll_read`, add a wrapper `InspectReader` that gets this right, passing all newly read bytes to a supplied `FnMut` closure. Fixes: #4584 --- tokio-util/src/io/inspect.rs | 46 ++++++++++++++++++++++++++++++++++ tokio-util/src/io/mod.rs | 3 +++ tokio-util/tests/io_inspect.rs | 43 +++++++++++++++++++++++++++++++ 3 files changed, 92 insertions(+) create mode 100644 tokio-util/src/io/inspect.rs create mode 100644 tokio-util/tests/io_inspect.rs diff --git a/tokio-util/src/io/inspect.rs b/tokio-util/src/io/inspect.rs new file mode 100644 index 00000000000..8fbd50f3513 --- /dev/null +++ b/tokio-util/src/io/inspect.rs @@ -0,0 +1,46 @@ +use futures_core::ready; +use pin_project_lite::pin_project; +use std::{ + io::Result, + pin::Pin, + task::{Context, Poll}, +}; +use tokio::io::{AsyncRead, ReadBuf}; + +pin_project! { + /// An adapter that lets you inspect the data that's being read. + /// + /// This is useful for things like hashing data as it's read in. + pub struct InspectReader { + #[pin] + reader: R, + f: F, + } +} + +impl InspectReader { + /// Create a new InspectReader, wrapping `reader` and calling `f` for the + /// new data supplied by each read call. + pub fn new(reader: R, f: F) -> InspectReader { + InspectReader { reader, f } + } + + /// Consumes the `InspectReader`, returning the wrapped reader + pub fn into_inner(self) -> R { + self.reader + } +} + +impl AsyncRead for InspectReader { + fn poll_read( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut ReadBuf<'_>, + ) -> Poll> { + let me = self.project(); + let filled_length = buf.filled().len(); + ready!(me.reader.poll_read(cx, buf))?; + (me.f)(&buf.filled()[filled_length..]); + Poll::Ready(Ok(())) + } +} diff --git a/tokio-util/src/io/mod.rs b/tokio-util/src/io/mod.rs index eb48a21fb98..df0dd02fdc0 100644 --- a/tokio-util/src/io/mod.rs +++ b/tokio-util/src/io/mod.rs @@ -10,14 +10,17 @@ //! [`Body`]: https://docs.rs/hyper/0.13/hyper/struct.Body.html //! [`AsyncRead`]: tokio::io::AsyncRead +mod inspect; mod read_buf; mod reader_stream; mod stream_reader; + cfg_io_util! { mod sync_bridge; pub use self::sync_bridge::SyncIoBridge; } +pub use self::inspect::InspectReader; pub use self::read_buf::read_buf; pub use self::reader_stream::ReaderStream; pub use self::stream_reader::StreamReader; diff --git a/tokio-util/tests/io_inspect.rs b/tokio-util/tests/io_inspect.rs new file mode 100644 index 00000000000..aa907c0c148 --- /dev/null +++ b/tokio-util/tests/io_inspect.rs @@ -0,0 +1,43 @@ +use std::{ + pin::Pin, + task::{Context, Poll}, +}; +use tokio::io::{AsyncRead, AsyncReadExt, ReadBuf}; +use tokio_util::io::InspectReader; + +/// An AsyncRead implementation that works byte-by-byte, to catch out callers +/// who don't allow for `buf` being part-filled before the call +struct SmallReader { + contents: Vec, +} + +impl Unpin for SmallReader {} + +impl AsyncRead for SmallReader { + fn poll_read( + mut self: Pin<&mut Self>, + _cx: &mut Context<'_>, + buf: &mut ReadBuf<'_>, + ) -> Poll> { + if let Some(byte) = self.contents.pop() { + buf.put_slice(&[byte]) + } + Poll::Ready(Ok(())) + } +} + +#[tokio::test] +async fn read_tee() { + let contents = b"This could be really long, you know".to_vec(); + let reader = SmallReader { + contents: contents.clone(), + }; + let mut altout: Vec = Vec::new(); + let mut teeout = Vec::new(); + { + let mut tee = InspectReader::new(reader, |bytes| altout.extend(bytes)); + tee.read_to_end(&mut teeout).await.unwrap(); + } + assert_eq!(teeout, altout); + assert_eq!(altout.len(), contents.len()); +} From e9520fe583372923a93c0b43068024b9cbc124bb Mon Sep 17 00:00:00 2001 From: Simon Farnsworth Date: Tue, 20 Sep 2022 10:53:47 +0100 Subject: [PATCH 2/7] tokio_util: add inspection wrapper for `AsyncWrite` When writing things out, it's useful to be able to inspect the bytes that are being written and do things like hash them as they go past. This isn't trivial to get right, due to partial writes and efficiently handling vectored writes (if used). Provide an `InspectWriter` wrapper that gets this right, giving a supplied `FnMut` closure a chance to inspect the buffers that have been successfully written out. Fixes: #4584 --- tokio-util/src/io/inspect.rs | 73 ++++++++++++++- tokio-util/src/io/mod.rs | 2 +- tokio-util/tests/io_inspect.rs | 156 ++++++++++++++++++++++++++++++++- 3 files changed, 226 insertions(+), 5 deletions(-) diff --git a/tokio-util/src/io/inspect.rs b/tokio-util/src/io/inspect.rs index 8fbd50f3513..53b1d5504ad 100644 --- a/tokio-util/src/io/inspect.rs +++ b/tokio-util/src/io/inspect.rs @@ -1,11 +1,11 @@ use futures_core::ready; use pin_project_lite::pin_project; use std::{ - io::Result, + io::{IoSlice, Result}, pin::Pin, task::{Context, Poll}, }; -use tokio::io::{AsyncRead, ReadBuf}; +use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; pin_project! { /// An adapter that lets you inspect the data that's being read. @@ -44,3 +44,72 @@ impl AsyncRead for InspectReader { Poll::Ready(Ok(())) } } + +pin_project! { + /// An adapter that lets you inspect the data that's being written. + /// + /// This is useful for things like hashing data as it's written out. + pub struct InspectWriter { + #[pin] + writer: W, + f: F, + } +} + +impl InspectWriter { + /// Create a new InspectWriter, wrapping `write` and calling `f` for the + /// data successfully written by each write call. + pub fn new(writer: W, f: F) -> InspectWriter { + InspectWriter { writer, f } + } + + /// Consumes the `InspectWriter`, returning the wrapped writer + pub fn into_inner(self) -> W { + self.writer + } +} + +impl AsyncWrite for InspectWriter { + fn poll_write(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll> { + let me = self.project(); + let res = me.writer.poll_write(cx, buf); + if let Poll::Ready(Ok(count)) = res { + (me.f)(&buf[..count]); + } + res + } + + fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let me = self.project(); + me.writer.poll_flush(cx) + } + + fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let me = self.project(); + me.writer.poll_shutdown(cx) + } + + fn poll_write_vectored( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + bufs: &[IoSlice<'_>], + ) -> Poll> { + let me = self.project(); + let res = me.writer.poll_write_vectored(cx, bufs); + if let Poll::Ready(Ok(mut count)) = res { + for buf in bufs { + let size = count.min(buf.len()); + (me.f)(&buf[..size]); + count -= size; + if count == 0 { + break; + } + } + } + res + } + + fn is_write_vectored(&self) -> bool { + self.writer.is_write_vectored() + } +} diff --git a/tokio-util/src/io/mod.rs b/tokio-util/src/io/mod.rs index df0dd02fdc0..317d93b3640 100644 --- a/tokio-util/src/io/mod.rs +++ b/tokio-util/src/io/mod.rs @@ -20,7 +20,7 @@ cfg_io_util! { pub use self::sync_bridge::SyncIoBridge; } -pub use self::inspect::InspectReader; +pub use self::inspect::{InspectReader, InspectWriter}; pub use self::read_buf::read_buf; pub use self::reader_stream::ReaderStream; pub use self::stream_reader::StreamReader; diff --git a/tokio-util/tests/io_inspect.rs b/tokio-util/tests/io_inspect.rs index aa907c0c148..5da13f57cd6 100644 --- a/tokio-util/tests/io_inspect.rs +++ b/tokio-util/tests/io_inspect.rs @@ -1,9 +1,11 @@ +use futures::future::poll_fn; use std::{ + io::IoSlice, pin::Pin, task::{Context, Poll}, }; -use tokio::io::{AsyncRead, AsyncReadExt, ReadBuf}; -use tokio_util::io::InspectReader; +use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt, ReadBuf}; +use tokio_util::io::{InspectReader, InspectWriter}; /// An AsyncRead implementation that works byte-by-byte, to catch out callers /// who don't allow for `buf` being part-filled before the call @@ -41,3 +43,153 @@ async fn read_tee() { assert_eq!(teeout, altout); assert_eq!(altout.len(), contents.len()); } + +/// An AsyncWrite implementation that works byte-by-byte for poll_write, and +/// that reads the whole of the first buffer plus one byte from the second in +/// poll_write_vectored. +/// +/// This is designed to catch bugs in handling partially written buffers +#[derive(Debug)] +struct SmallWriter { + contents: Vec, +} + +impl Unpin for SmallWriter {} + +impl AsyncWrite for SmallWriter { + fn poll_write( + mut self: Pin<&mut Self>, + _cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { + // Just write one byte at a time + if buf.is_empty() { + return Poll::Ready(Ok(0)); + } + self.contents.push(buf[0]); + Poll::Ready(Ok(1)) + } + + fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { + Poll::Ready(Ok(())) + } + + fn poll_shutdown( + self: Pin<&mut Self>, + _cx: &mut Context<'_>, + ) -> Poll> { + Poll::Ready(Ok(())) + } + + fn poll_write_vectored( + mut self: Pin<&mut Self>, + _cx: &mut Context<'_>, + bufs: &[IoSlice<'_>], + ) -> Poll> { + // Write all of the first buffer, then one byte from the second buffer + // This should trip up anything that doesn't correctly handle multiple + // buffers. + if bufs.is_empty() { + return Poll::Ready(Ok(0)); + } + let mut written_len = bufs[0].len(); + self.contents.extend_from_slice(&bufs[0]); + + if bufs.len() > 1 { + let buf = bufs[1]; + if !buf.is_empty() { + written_len += 1; + self.contents.push(buf[0]); + } + } + Poll::Ready(Ok(written_len)) + } + + fn is_write_vectored(&self) -> bool { + true + } +} + +#[tokio::test] +async fn write_tee() { + let mut altout: Vec = Vec::new(); + let mut writeout = SmallWriter { + contents: Vec::new(), + }; + { + let mut tee = InspectWriter::new(&mut writeout, |bytes| altout.extend(bytes)); + tee.write_all(b"A testing string, very testing") + .await + .unwrap(); + } + assert_eq!(altout, writeout.contents); +} + +// This is inefficient, but works well enough for test use. +// If you want something similar for real code, you'll want to avoid all the +// fun of manipulating `bufs` - ideally, by the time you read this, +// IoSlice::advance_slices will be stable, and you can use that. +async fn write_all_vectored( + mut writer: W, + mut bufs: Vec>, +) -> Result { + let mut res = 0; + while !bufs.is_empty() { + let mut written = poll_fn(|cx| { + let bufs: Vec = bufs.iter().map(|v| IoSlice::new(&v)).collect(); + Pin::new(&mut writer).poll_write_vectored(cx, &bufs) + }) + .await?; + res += written; + while written > 0 { + let buf_len = bufs[0].len(); + if buf_len <= written { + bufs.remove(0); + written -= buf_len; + } else { + let buf = &mut bufs[0]; + while written > 0 { + buf.remove(0); + written -= 1; + } + } + } + } + Ok(res) +} + +#[tokio::test] +async fn write_tee_vectored() { + let mut altout: Vec = Vec::new(); + let mut writeout = SmallWriter { + contents: Vec::new(), + }; + let original = b"A very long string split up"; + let bufs: Vec> = original + .split(|b| b.is_ascii_whitespace()) + .map(Vec::from) + .collect(); + assert!(bufs.len() > 1); + let expected: Vec = { + let mut out = Vec::new(); + for item in &bufs { + out.extend_from_slice(item) + } + out + }; + { + let mut bufcount = 0; + let tee = InspectWriter::new(&mut writeout, |bytes| { + bufcount += 1; + altout.extend(bytes) + }); + + assert!(tee.is_write_vectored()); + + write_all_vectored(tee, bufs.clone()).await.unwrap(); + + assert!(bufcount >= bufs.len()); + } + assert_eq!(altout, writeout.contents); + assert_eq!(writeout.contents, expected); +} From cd68c4821901cb985e8486b1367e62858d52396f Mon Sep 17 00:00:00 2001 From: Simon Farnsworth Date: Wed, 28 Sep 2022 11:21:56 +0100 Subject: [PATCH 3/7] update to master, fix @Darksonn's review comments --- tokio-util/src/io/inspect.rs | 29 ++++++++++++++++++----------- 1 file changed, 18 insertions(+), 11 deletions(-) diff --git a/tokio-util/src/io/inspect.rs b/tokio-util/src/io/inspect.rs index 53b1d5504ad..ff6e5898ead 100644 --- a/tokio-util/src/io/inspect.rs +++ b/tokio-util/src/io/inspect.rs @@ -1,27 +1,30 @@ use futures_core::ready; use pin_project_lite::pin_project; -use std::{ - io::{IoSlice, Result}, - pin::Pin, - task::{Context, Poll}, -}; +use std::io::{IoSlice, Result}; +use std::pin::Pin; +use std::task::{Context, Poll}; + use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; pin_project! { /// An adapter that lets you inspect the data that's being read. /// /// This is useful for things like hashing data as it's read in. - pub struct InspectReader { + pub struct InspectReader { #[pin] reader: R, f: F, } } -impl InspectReader { +impl InspectReader { /// Create a new InspectReader, wrapping `reader` and calling `f` for the /// new data supplied by each read call. - pub fn new(reader: R, f: F) -> InspectReader { + pub fn new(reader: R, f: F) -> InspectReader + where + R: AsyncRead, + F: FnMut(&[u8]), + { InspectReader { reader, f } } @@ -49,17 +52,21 @@ pin_project! { /// An adapter that lets you inspect the data that's being written. /// /// This is useful for things like hashing data as it's written out. - pub struct InspectWriter { + pub struct InspectWriter { #[pin] writer: W, f: F, } } -impl InspectWriter { +impl InspectWriter { /// Create a new InspectWriter, wrapping `write` and calling `f` for the /// data successfully written by each write call. - pub fn new(writer: W, f: F) -> InspectWriter { + pub fn new(writer: W, f: F) -> InspectWriter + where + W: AsyncWrite, + F: FnMut(&[u8]), + { InspectWriter { writer, f } } From 9d15d7a6808dc1c8003c05cc3bcd0c94ad723e59 Mon Sep 17 00:00:00 2001 From: Simon Farnsworth Date: Wed, 28 Sep 2022 12:05:47 +0100 Subject: [PATCH 4/7] make clippy happy --- tokio-util/tests/io_inspect.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tokio-util/tests/io_inspect.rs b/tokio-util/tests/io_inspect.rs index 5da13f57cd6..64bb9bd9904 100644 --- a/tokio-util/tests/io_inspect.rs +++ b/tokio-util/tests/io_inspect.rs @@ -136,7 +136,7 @@ async fn write_all_vectored( let mut res = 0; while !bufs.is_empty() { let mut written = poll_fn(|cx| { - let bufs: Vec = bufs.iter().map(|v| IoSlice::new(&v)).collect(); + let bufs: Vec = bufs.iter().map(|v| IoSlice::new(v)).collect(); Pin::new(&mut writer).poll_write_vectored(cx, &bufs) }) .await?; From ed89580a4d977b331fc7303126144a61a4e76a7d Mon Sep 17 00:00:00 2001 From: Simon Farnsworth Date: Wed, 28 Sep 2022 14:02:02 +0100 Subject: [PATCH 5/7] use drain instead of a byte-by-byte loop Co-authored-by: Alice Ryhl --- tokio-util/tests/io_inspect.rs | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/tokio-util/tests/io_inspect.rs b/tokio-util/tests/io_inspect.rs index 64bb9bd9904..e6319afcf1b 100644 --- a/tokio-util/tests/io_inspect.rs +++ b/tokio-util/tests/io_inspect.rs @@ -148,10 +148,9 @@ async fn write_all_vectored( written -= buf_len; } else { let buf = &mut bufs[0]; - while written > 0 { - buf.remove(0); - written -= 1; - } + let drain_len = written.min(buf.len()); + buf.drain(..drain_len); + written -= drain_len; } } } From 073940213d1b8c3d2c59d64f39f5b5dccbec8986 Mon Sep 17 00:00:00 2001 From: Simon Farnsworth Date: Wed, 28 Sep 2022 14:50:09 +0100 Subject: [PATCH 6/7] no empty slices on write --- tokio-util/src/io/inspect.rs | 19 +++++++++++++++---- 1 file changed, 15 insertions(+), 4 deletions(-) diff --git a/tokio-util/src/io/inspect.rs b/tokio-util/src/io/inspect.rs index ff6e5898ead..ac6ba0e4e66 100644 --- a/tokio-util/src/io/inspect.rs +++ b/tokio-util/src/io/inspect.rs @@ -20,6 +20,9 @@ pin_project! { impl InspectReader { /// Create a new InspectReader, wrapping `reader` and calling `f` for the /// new data supplied by each read call. + /// + /// If no new data is supplied by a successful `poll_read`, then `f` will + /// be called with an empty slice. pub fn new(reader: R, f: F) -> InspectReader where R: AsyncRead, @@ -62,6 +65,10 @@ pin_project! { impl InspectWriter { /// Create a new InspectWriter, wrapping `write` and calling `f` for the /// data successfully written by each write call. + /// + /// `f` will never be called with an empty slice; a vectored write will + /// result in multiple calls to `f`, one for each buffer that was used by + /// the write. pub fn new(writer: W, f: F) -> InspectWriter where W: AsyncWrite, @@ -81,7 +88,9 @@ impl AsyncWrite for InspectWriter { let me = self.project(); let res = me.writer.poll_write(cx, buf); if let Poll::Ready(Ok(count)) = res { - (me.f)(&buf[..count]); + if count != 0 { + (me.f)(&buf[..count]); + } } res } @@ -105,12 +114,14 @@ impl AsyncWrite for InspectWriter { let res = me.writer.poll_write_vectored(cx, bufs); if let Poll::Ready(Ok(mut count)) = res { for buf in bufs { - let size = count.min(buf.len()); - (me.f)(&buf[..size]); - count -= size; if count == 0 { break; } + let size = count.min(buf.len()); + if size != 0 { + (me.f)(&buf[..size]); + count -= size; + } } } res From 542ee971164a8c2da5dec3790c3a4b4c6c3f57ad Mon Sep 17 00:00:00 2001 From: Simon Farnsworth Date: Wed, 28 Sep 2022 18:54:06 +0100 Subject: [PATCH 7/7] Take @Darksonn's advice on writing clear docs --- tokio-util/src/io/inspect.rs | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/tokio-util/src/io/inspect.rs b/tokio-util/src/io/inspect.rs index ac6ba0e4e66..ec5bb97e61c 100644 --- a/tokio-util/src/io/inspect.rs +++ b/tokio-util/src/io/inspect.rs @@ -21,8 +21,9 @@ impl InspectReader { /// Create a new InspectReader, wrapping `reader` and calling `f` for the /// new data supplied by each read call. /// - /// If no new data is supplied by a successful `poll_read`, then `f` will - /// be called with an empty slice. + /// The closure will only be called with an empty slice if the inner reader + /// returns without reading data into the buffer. This happens at EOF, or if + /// `poll_read` is called with a zero-size buffer. pub fn new(reader: R, f: F) -> InspectReader where R: AsyncRead, @@ -66,9 +67,9 @@ impl InspectWriter { /// Create a new InspectWriter, wrapping `write` and calling `f` for the /// data successfully written by each write call. /// - /// `f` will never be called with an empty slice; a vectored write will - /// result in multiple calls to `f`, one for each buffer that was used by - /// the write. + /// The closure `f` will never be called with an empty slice. A vectored + /// write can result in multiple calls to `f` - at most one call to `f` per + /// buffer supplied to `poll_write_vectored`. pub fn new(writer: W, f: F) -> InspectWriter where W: AsyncWrite,