Skip to content

Commit

Permalink
implement write_vectored for DuplexStream
Browse files Browse the repository at this point in the history
  • Loading branch information
maminrayej committed Sep 5, 2023
1 parent 8ea303e commit 0b23fbd
Show file tree
Hide file tree
Showing 2 changed files with 94 additions and 0 deletions.
47 changes: 47 additions & 0 deletions tokio/src/io/util/mem.rs
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,18 @@ impl AsyncWrite for DuplexStream {
Pin::new(&mut *self.write.lock()).poll_write(cx, buf)
}

fn poll_write_vectored(
self: Pin<&mut Self>,
cx: &mut task::Context<'_>,
bufs: &[std::io::IoSlice<'_>],
) -> Poll<Result<usize, std::io::Error>> {
Pin::new(&mut *self.write.lock()).poll_write_vectored(cx, bufs)
}

fn is_write_vectored(&self) -> bool {
true
}

#[allow(unused_mut)]
fn poll_flush(
mut self: Pin<&mut Self>,
Expand Down Expand Up @@ -285,6 +297,41 @@ impl AsyncWrite for Pipe {
}
}

fn poll_write_vectored(
mut self: Pin<&mut Self>,
cx: &mut task::Context<'_>,
bufs: &[std::io::IoSlice<'_>],
) -> Poll<Result<usize, std::io::Error>> {
if self.is_closed {
return Poll::Ready(Err(std::io::ErrorKind::BrokenPipe.into()));
}
let avail = self.max_buf_size - self.buffer.len();
if avail == 0 {
self.write_waker = Some(cx.waker().clone());
return Poll::Pending;
}

let mut rem = avail;
for buf in bufs {
if rem == 0 {
break;
}

let len = buf.len().min(rem);
self.buffer.extend_from_slice(&buf[..len]);
rem -= len;
}

if let Some(waker) = self.read_waker.take() {
waker.wake();
}
Poll::Ready(Ok(avail - rem))
}

fn is_write_vectored(&self) -> bool {
true
}

fn poll_flush(self: Pin<&mut Self>, _: &mut task::Context<'_>) -> Poll<std::io::Result<()>> {
Poll::Ready(Ok(()))
}
Expand Down
47 changes: 47 additions & 0 deletions tokio/tests/duplex_stream.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
#![warn(rust_2018_idioms)]
#![cfg(feature = "full")]

use std::io::IoSlice;
use tokio::io::{AsyncReadExt, AsyncWriteExt};

const HELLO: &[u8] = b"hello world...";

#[tokio::test]
async fn write_vectored() {
let (mut client, mut server) = tokio::io::duplex(64);

let ret = client
.write_vectored(&[IoSlice::new(HELLO), IoSlice::new(HELLO)])
.await
.unwrap();
assert_eq!(ret, HELLO.len() * 2);

client.flush().await.unwrap();
drop(client);

let mut buf = Vec::with_capacity(HELLO.len() * 2);
let bytes_read = server.read_to_end(&mut buf).await.unwrap();

assert_eq!(bytes_read, HELLO.len() * 2);
assert_eq!(buf, [HELLO, HELLO].concat());
}

#[tokio::test]
async fn write_vectored_and_shutdown() {
let (mut client, mut server) = tokio::io::duplex(64);

let ret = client
.write_vectored(&[IoSlice::new(HELLO), IoSlice::new(HELLO)])
.await
.unwrap();
assert_eq!(ret, HELLO.len() * 2);

client.shutdown().await.unwrap();
drop(client);

let mut buf = Vec::with_capacity(HELLO.len() * 2);
let bytes_read = server.read_to_end(&mut buf).await.unwrap();

assert_eq!(bytes_read, HELLO.len() * 2);
assert_eq!(buf, [HELLO, HELLO].concat());
}

0 comments on commit 0b23fbd

Please sign in to comment.