Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

tokio: introduce io::Duplex #2661

Merged
merged 1 commit into from
Jul 22, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions tokio/src/io/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -226,8 +226,8 @@ cfg_io_util! {

pub(crate) mod util;
pub use util::{
copy, empty, repeat, sink, AsyncBufReadExt, AsyncReadExt, AsyncSeekExt, AsyncWriteExt,
BufReader, BufStream, BufWriter, Copy, Empty, Lines, Repeat, Sink, Split, Take,
copy, duplex, empty, repeat, sink, AsyncBufReadExt, AsyncReadExt, AsyncSeekExt, AsyncWriteExt,
BufReader, BufStream, BufWriter, DuplexStream, Copy, Empty, Lines, Repeat, Sink, Split, Take,
};

cfg_stream! {
Expand Down
222 changes: 222 additions & 0 deletions tokio/src/io/util/mem.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,222 @@
//! In-process memory IO types.

use crate::io::{AsyncRead, AsyncWrite};
use crate::loom::sync::Mutex;

use bytes::{Buf, BytesMut};
use std::{
pin::Pin,
sync::Arc,
task::{self, Poll, Waker},
};

/// A bidirectional pipe to read and write bytes in memory.
seanmonstar marked this conversation as resolved.
Show resolved Hide resolved
///
/// A pair of `DuplexStream`s are created together, and they act as a "channel"
/// that can be used as in-memory IO types. Writing to one of the pairs will
/// allow that data to be read from the other, and vice versa.
///
/// # Example
///
/// ```
/// # async fn ex() -> std::io::Result<()> {
/// # use tokio::io::{AsyncReadExt, AsyncWriteExt};
/// let (mut client, mut server) = tokio::io::duplex(64);
///
/// client.write_all(b"ping").await?;
///
/// let mut buf = [0u8; 4];
/// server.read_exact(&mut buf).await?;
/// assert_eq!(&buf, b"ping");
///
/// server.write_all(b"pong").await?;
///
/// client.read_exact(&mut buf).await?;
/// assert_eq!(&buf, b"pong");
/// # Ok(())
/// # }
/// ```
#[derive(Debug)]
pub struct DuplexStream {
read: Arc<Mutex<Pipe>>,
write: Arc<Mutex<Pipe>>,
}

/// A unidirectional IO over a piece of memory.
///
/// Data can be written to the pipe, and reading will return that data.
#[derive(Debug)]
struct Pipe {
/// The buffer storing the bytes written, also read from.
///
/// Using a `BytesMut` because it has efficient `Buf` and `BufMut`
/// functionality already. Additionally, it can try to copy data in the
/// same buffer if there read index has advanced far enough.
buffer: BytesMut,
/// Determines if the write side has been closed.
is_closed: bool,
/// The maximum amount of bytes that can be written before returning
/// `Poll::Pending`.
max_buf_size: usize,
/// If the `read` side has been polled and is pending, this is the waker
/// for that parked task.
read_waker: Option<Waker>,
/// If the `write` side has filled the `max_buf_size` and returned
/// `Poll::Pending`, this is the waker for that parked task.
write_waker: Option<Waker>,
}

// ===== impl DuplexStream =====

/// Create a new pair of `DuplexStream`s that act like a pair of connected sockets.
///
/// The `max_buf_size` argument is the maximum amount of bytes that can be
/// written to a side before the write returns `Poll::Pending`.
pub fn duplex(max_buf_size: usize) -> (DuplexStream, DuplexStream) {
let one = Arc::new(Mutex::new(Pipe::new(max_buf_size)));
let two = Arc::new(Mutex::new(Pipe::new(max_buf_size)));

(
DuplexStream {
read: one.clone(),
write: two.clone(),
},
DuplexStream {
read: two,
write: one,
},
)
}

impl AsyncRead for DuplexStream {
// Previous rustc required this `self` to be `mut`, even though newer
// versions recognize it isn't needed to call `lock()`. So for
// compatibility, we include the `mut` and `allow` the lint.
//
// See https://github.com/rust-lang/rust/issues/73592
#[allow(unused_mut)]
fn poll_read(
mut self: Pin<&mut Self>,
cx: &mut task::Context<'_>,
buf: &mut [u8],
) -> Poll<std::io::Result<usize>> {
Pin::new(&mut *self.read.lock().unwrap()).poll_read(cx, buf)
}
}

impl AsyncWrite for DuplexStream {
#[allow(unused_mut)]
fn poll_write(
mut self: Pin<&mut Self>,
cx: &mut task::Context<'_>,
buf: &[u8],
) -> Poll<std::io::Result<usize>> {
Pin::new(&mut *self.write.lock().unwrap()).poll_write(cx, buf)
}

#[allow(unused_mut)]
fn poll_flush(
mut self: Pin<&mut Self>,
cx: &mut task::Context<'_>,
) -> Poll<std::io::Result<()>> {
Pin::new(&mut *self.write.lock().unwrap()).poll_flush(cx)
}

#[allow(unused_mut)]
fn poll_shutdown(
mut self: Pin<&mut Self>,
cx: &mut task::Context<'_>,
) -> Poll<std::io::Result<()>> {
Pin::new(&mut *self.write.lock().unwrap()).poll_shutdown(cx)
}
}

impl Drop for DuplexStream {
fn drop(&mut self) {
// notify the other side of the closure
self.write.lock().unwrap().close();
}
}

// ===== impl Pipe =====

impl Pipe {
fn new(max_buf_size: usize) -> Self {
Pipe {
buffer: BytesMut::new(),
is_closed: false,
max_buf_size,
read_waker: None,
write_waker: None,
}
}

fn close(&mut self) {
self.is_closed = true;
if let Some(waker) = self.read_waker.take() {
waker.wake();
}
}
}

impl AsyncRead for Pipe {
fn poll_read(
mut self: Pin<&mut Self>,
cx: &mut task::Context<'_>,
buf: &mut [u8],
) -> Poll<std::io::Result<usize>> {
if self.buffer.has_remaining() {
let max = self.buffer.remaining().min(buf.len());
self.buffer.copy_to_slice(&mut buf[..max]);
if max > 0 {
// The passed `buf` might have been empty, don't wake up if
// no bytes have been moved.
if let Some(waker) = self.write_waker.take() {
waker.wake();
}
}
Poll::Ready(Ok(max))
} else if self.is_closed {
Poll::Ready(Ok(0))
} else {
self.read_waker = Some(cx.waker().clone());
Poll::Pending
}
}
}

impl AsyncWrite for Pipe {
fn poll_write(
mut self: Pin<&mut Self>,
cx: &mut task::Context<'_>,
buf: &[u8],
) -> Poll<std::io::Result<usize>> {
if self.is_closed {
return Poll::Ready(Err(std::io::ErrorKind::BrokenPipe.into()));
}
let avail = self.max_buf_size - self.buffer.len();
if avail == 0 {
self.write_waker = Some(cx.waker().clone());
return Poll::Pending;
}

let len = buf.len().min(avail);
self.buffer.extend_from_slice(&buf[..len]);
if let Some(waker) = self.read_waker.take() {
waker.wake();
}
Poll::Ready(Ok(len))
}

fn poll_flush(self: Pin<&mut Self>, _: &mut task::Context<'_>) -> Poll<std::io::Result<()>> {
Poll::Ready(Ok(()))
}

fn poll_shutdown(
mut self: Pin<&mut Self>,
_: &mut task::Context<'_>,
) -> Poll<std::io::Result<()>> {
self.close();
Poll::Ready(Ok(()))
}
}
3 changes: 3 additions & 0 deletions tokio/src/io/util/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,9 @@ cfg_io_util! {
mod lines;
pub use lines::Lines;

mod mem;
pub use mem::{duplex, DuplexStream};

mod read;
mod read_buf;
mod read_exact;
Expand Down
83 changes: 83 additions & 0 deletions tokio/tests/io_mem_stream.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
#![warn(rust_2018_idioms)]
#![cfg(feature = "full")]

use tokio::io::{duplex, AsyncReadExt, AsyncWriteExt};

#[tokio::test]
async fn ping_pong() {
let (mut a, mut b) = duplex(32);

let mut buf = [0u8; 4];

a.write_all(b"ping").await.unwrap();
b.read_exact(&mut buf).await.unwrap();
assert_eq!(&buf, b"ping");

b.write_all(b"pong").await.unwrap();
a.read_exact(&mut buf).await.unwrap();
assert_eq!(&buf, b"pong");
}

#[tokio::test]
async fn across_tasks() {
let (mut a, mut b) = duplex(32);

let t1 = tokio::spawn(async move {
a.write_all(b"ping").await.unwrap();
let mut buf = [0u8; 4];
a.read_exact(&mut buf).await.unwrap();
assert_eq!(&buf, b"pong");
});

let t2 = tokio::spawn(async move {
let mut buf = [0u8; 4];
b.read_exact(&mut buf).await.unwrap();
assert_eq!(&buf, b"ping");
b.write_all(b"pong").await.unwrap();
});

t1.await.unwrap();
t2.await.unwrap();
}

#[tokio::test]
async fn disconnect() {
let (mut a, mut b) = duplex(32);

let t1 = tokio::spawn(async move {
a.write_all(b"ping").await.unwrap();
// and dropped
});

let t2 = tokio::spawn(async move {
let mut buf = [0u8; 32];
let n = b.read(&mut buf).await.unwrap();
assert_eq!(&buf[..n], b"ping");

let n = b.read(&mut buf).await.unwrap();
assert_eq!(n, 0);
});

t1.await.unwrap();
t2.await.unwrap();
}

#[tokio::test]
async fn max_write_size() {
let (mut a, mut b) = duplex(32);

let t1 = tokio::spawn(async move {
let n = a.write(&[0u8; 64]).await.unwrap();
assert_eq!(n, 32);
let n = a.write(&[0u8; 64]).await.unwrap();
assert_eq!(n, 4);
});

let t2 = tokio::spawn(async move {
let mut buf = [0u8; 4];
b.read_exact(&mut buf).await.unwrap();
});

t1.await.unwrap();
t2.await.unwrap();
}