Skip to content

Commit

Permalink
io: use vectored io for write_all_buf when possible (#6724)
Browse files Browse the repository at this point in the history
  • Loading branch information
mox692 authored Jul 29, 2024
1 parent 0cbf1a5 commit 1077b0b
Show file tree
Hide file tree
Showing 2 changed files with 59 additions and 2 deletions.
12 changes: 10 additions & 2 deletions tokio/src/io/util/write_all_buf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ use crate::io::AsyncWrite;
use bytes::Buf;
use pin_project_lite::pin_project;
use std::future::Future;
use std::io;
use std::io::{self, IoSlice};
use std::marker::PhantomPinned;
use std::pin::Pin;
use std::task::{Context, Poll};
Expand Down Expand Up @@ -42,9 +42,17 @@ where
type Output = io::Result<()>;

fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
const MAX_VECTOR_ELEMENTS: usize = 64;

let me = self.project();
while me.buf.has_remaining() {
let n = ready!(Pin::new(&mut *me.writer).poll_write(cx, me.buf.chunk())?);
let n = if me.writer.is_write_vectored() {
let mut slices = [IoSlice::new(&[]); MAX_VECTOR_ELEMENTS];
let cnt = me.buf.chunks_vectored(&mut slices);
ready!(Pin::new(&mut *me.writer).poll_write_vectored(cx, &slices[..cnt]))?
} else {
ready!(Pin::new(&mut *me.writer).poll_write(cx, me.buf.chunk())?)
};
me.buf.advance(n);
if n == 0 {
return Poll::Ready(Err(io::ErrorKind::WriteZero.into()));
Expand Down
49 changes: 49 additions & 0 deletions tokio/tests/io_write_all_buf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -94,3 +94,52 @@ async fn write_buf_err() {
Bytes::from_static(b"oworld")
);
}

#[tokio::test]
async fn write_all_buf_vectored() {
struct Wr {
buf: BytesMut,
}
impl AsyncWrite for Wr {
fn poll_write(
self: Pin<&mut Self>,
_cx: &mut Context<'_>,
_buf: &[u8],
) -> Poll<io::Result<usize>> {
// When executing `write_all_buf` with this writer,
// `poll_write` is not called.
panic!("shouldn't be called")
}
fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
Ok(()).into()
}
fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
Ok(()).into()
}
fn poll_write_vectored(
mut self: Pin<&mut Self>,
_cx: &mut Context<'_>,
bufs: &[io::IoSlice<'_>],
) -> Poll<Result<usize, io::Error>> {
for buf in bufs {
self.buf.extend_from_slice(buf);
}
let n = self.buf.len();
Ok(n).into()
}
fn is_write_vectored(&self) -> bool {
// Enable vectored write.
true
}
}

let mut wr = Wr {
buf: BytesMut::with_capacity(64),
};
let mut buf = Bytes::from_static(b"hello")
.chain(Bytes::from_static(b" "))
.chain(Bytes::from_static(b"world"));

wr.write_all_buf(&mut buf).await.unwrap();
assert_eq!(&wr.buf[..], b"hello world");
}

0 comments on commit 1077b0b

Please sign in to comment.