From 206fecdd6c69720201faaa48411c3813bc6dc324 Mon Sep 17 00:00:00 2001 From: Alexandr Yusiuk Date: Mon, 15 Apr 2024 09:14:15 +0300 Subject: [PATCH 1/7] fix: Use 1024 buffer size for ARD VNC session _ARD_ uses _MVS_ video codec which doesn't like buffering, and we need to have the buffer as minimal as possible. Also, this commit adds new `copy_bidirectional` transport that is forked [the one from tokio](https://docs.rs/tokio/latest/tokio/io/fn.copy_bidirectional.html). It's forked because the original function doesn't allow overriding the buffer size (8K is used by default). There is [an issue](https://github.com/tokio-rs/tokio/issues/6454) on tokio side for it. We will be able to replace our fork with the upstream easily when it's ready. Releated to https://github.com/Devolutions/IronVNC/issues/338. --- crates/transport/src/copy_bidirectional.rs | 123 ++++++++++++++++++ crates/transport/src/copy_buffer.rs | 143 +++++++++++++++++++++ crates/transport/src/lib.rs | 3 + devolutions-gateway/src/api/fwd.rs | 11 +- devolutions-gateway/src/proxy.rs | 21 ++- 5 files changed, 296 insertions(+), 5 deletions(-) create mode 100644 crates/transport/src/copy_bidirectional.rs create mode 100644 crates/transport/src/copy_buffer.rs diff --git a/crates/transport/src/copy_bidirectional.rs b/crates/transport/src/copy_bidirectional.rs new file mode 100644 index 000000000..eea7a6e62 --- /dev/null +++ b/crates/transport/src/copy_bidirectional.rs @@ -0,0 +1,123 @@ +//! Fork of https://github.com/tokio-rs/tokio/blob/master/tokio/src/io/util/copy.rs to allow us set +//! variable length `CopyBuffer` size instead of default 8k. +//! See . + +use super::copy_buffer::CopyBuffer; +use futures_core::ready; +use tokio::io::{AsyncRead, AsyncWrite}; + +use std::future::Future; +use std::io::{self}; +use std::pin::Pin; +use std::task::{Context, Poll}; + +enum TransferState { + Running(CopyBuffer), + ShuttingDown(u64), + Done(u64), +} + +struct CopyBidirectional<'a, A: ?Sized, B: ?Sized> { + a: &'a mut A, + b: &'a mut B, + a_to_b: TransferState, + b_to_a: TransferState, +} + +fn transfer_one_direction( + cx: &mut Context<'_>, + state: &mut TransferState, + r: &mut A, + w: &mut B, +) -> Poll> + where + A: AsyncRead + AsyncWrite + Unpin + ?Sized, + B: AsyncRead + AsyncWrite + Unpin + ?Sized, +{ + let mut r = Pin::new(r); + let mut w = Pin::new(w); + + loop { + match state { + TransferState::Running(buf) => { + let count = ready!(buf.poll_copy(cx, r.as_mut(), w.as_mut()))?; + *state = TransferState::ShuttingDown(count); + } + TransferState::ShuttingDown(count) => { + ready!(w.as_mut().poll_shutdown(cx))?; + + *state = TransferState::Done(*count); + } + TransferState::Done(count) => return Poll::Ready(Ok(*count)), + } + } +} + +impl<'a, A, B> Future for CopyBidirectional<'a, A, B> + where + A: AsyncRead + AsyncWrite + Unpin + ?Sized, + B: AsyncRead + AsyncWrite + Unpin + ?Sized, +{ + type Output = io::Result<(u64, u64)>; + + fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + // Unpack self into mut refs to each field to avoid borrow check issues. + let CopyBidirectional { + a, + b, + a_to_b, + b_to_a, + } = &mut *self; + + let a_to_b = transfer_one_direction(cx, a_to_b, &mut *a, &mut *b)?; + let b_to_a = transfer_one_direction(cx, b_to_a, &mut *b, &mut *a)?; + + // It is not a problem if ready! returns early because transfer_one_direction for the + // other direction will keep returning TransferState::Done(count) in future calls to poll + let a_to_b = ready!(a_to_b); + let b_to_a = ready!(b_to_a); + + Poll::Ready(Ok((a_to_b, b_to_a))) + } +} + +/// Copies data in both directions between `a` and `b`. +/// +/// This function returns a future that will read from both streams, +/// writing any data read to the opposing stream. +/// This happens in both directions concurrently. +/// +/// If an EOF is observed on one stream, [`shutdown()`] will be invoked on +/// the other, and reading from that stream will stop. Copying of data in +/// the other direction will continue. +/// +/// The future will complete successfully once both directions of communication has been shut down. +/// A direction is shut down when the reader reports EOF, +/// at which point [`shutdown()`] is called on the corresponding writer. When finished, +/// it will return a tuple of the number of bytes copied from a to b +/// and the number of bytes copied from b to a, in that order. +/// +/// [`shutdown()`]: crate::io::AsyncWriteExt::shutdown +/// +/// # Errors +/// +/// The future will immediately return an error if any IO operation on `a` +/// or `b` returns an error. Some data read from either stream may be lost (not +/// written to the other stream) in this case. +/// +/// # Return value +/// +/// Returns a tuple of bytes copied `a` to `b` and bytes copied `b` to `a`. +pub async fn copy_bidirectional(a: &mut A, b: &mut B, send_buffer_size: usize, recv_buffer_size: usize) -> Result<(u64, u64), std::io::Error> + where + A: AsyncRead + AsyncWrite + Unpin + ?Sized, + B: AsyncRead + AsyncWrite + Unpin + ?Sized, +{ + CopyBidirectional { + a, + b, + a_to_b: TransferState::Running(CopyBuffer::new(send_buffer_size)), + b_to_a: TransferState::Running(CopyBuffer::new(recv_buffer_size)), + } + .await +} \ No newline at end of file diff --git a/crates/transport/src/copy_buffer.rs b/crates/transport/src/copy_buffer.rs new file mode 100644 index 000000000..24879b421 --- /dev/null +++ b/crates/transport/src/copy_buffer.rs @@ -0,0 +1,143 @@ +//! Fork of https://github.com/tokio-rs/tokio/blob/master/tokio/src/io/util/copy.rs to allow us set +//! variable length `CopyBuffer` size instead of default 8k. +//! See . +use futures_core::ready; +use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; + +use std::io::{self}; +use std::pin::Pin; +use std::task::{Context, Poll}; + + +#[derive(Debug)] +pub(super) struct CopyBuffer { + read_done: bool, + need_flush: bool, + pos: usize, + cap: usize, + amt: u64, + buf: Box<[u8]>, +} + +impl CopyBuffer { + pub(super) fn new(buffer_size: usize) -> Self { // <- This is our change + Self { + read_done: false, + need_flush: false, + pos: 0, + cap: 0, + amt: 0, + buf: vec![0; buffer_size].into_boxed_slice(), + } + } + + fn poll_fill_buf( + &mut self, + cx: &mut Context<'_>, + reader: Pin<&mut R>, + ) -> Poll> + where + R: AsyncRead + ?Sized, + { + let me = &mut *self; + let mut buf = ReadBuf::new(&mut me.buf); + buf.set_filled(me.cap); + + let res = reader.poll_read(cx, &mut buf); + if let Poll::Ready(Ok(_)) = res { + let filled_len = buf.filled().len(); + me.read_done = me.cap == filled_len; + me.cap = filled_len; + } + res + } + + fn poll_write_buf( + &mut self, + cx: &mut Context<'_>, + mut reader: Pin<&mut R>, + mut writer: Pin<&mut W>, + ) -> Poll> + where + R: AsyncRead + ?Sized, + W: AsyncWrite + ?Sized, + { + let me = &mut *self; + match writer.as_mut().poll_write(cx, &me.buf[me.pos..me.cap]) { + Poll::Pending => { + // Top up the buffer towards full if we can read a bit more + // data - this should improve the chances of a large write + if !me.read_done && me.cap < me.buf.len() { + ready!(me.poll_fill_buf(cx, reader.as_mut()))?; + } + Poll::Pending + } + res => res, + } + } + + pub(super) fn poll_copy( + &mut self, + cx: &mut Context<'_>, + mut reader: Pin<&mut R>, + mut writer: Pin<&mut W>, + ) -> Poll> + where + R: AsyncRead + ?Sized, + W: AsyncWrite + ?Sized, + { + loop { + // If our buffer is empty, then we need to read some data to + // continue. + if self.pos == self.cap && !self.read_done { + self.pos = 0; + self.cap = 0; + + match self.poll_fill_buf(cx, reader.as_mut()) { + Poll::Ready(Ok(_)) => (), + Poll::Ready(Err(err)) => return Poll::Ready(Err(err)), + Poll::Pending => { + // Try flushing when the reader has no progress to avoid deadlock + // when the reader depends on buffered writer. + if self.need_flush { + ready!(writer.as_mut().poll_flush(cx))?; + self.need_flush = false; + } + + return Poll::Pending; + } + } + } + + // If our buffer has some data, let's write it out! + while self.pos < self.cap { + let i = ready!(self.poll_write_buf(cx, reader.as_mut(), writer.as_mut()))?; + if i == 0 { + return Poll::Ready(Err(io::Error::new( + io::ErrorKind::WriteZero, + "write zero byte into writer", + ))); + } else { + self.pos += i; + self.amt += i as u64; + self.need_flush = true; + } + } + + // If pos larger than cap, this loop will never stop. + // In particular, user's wrong poll_write implementation returning + // incorrect written length may lead to thread blocking. + debug_assert!( + self.pos <= self.cap, + "writer returned length larger than input slice" + ); + + // If we've written all the data and we've seen EOF, flush out the + // data and finish the transfer. + if self.pos == self.cap && self.read_done { + ready!(writer.as_mut().poll_flush(cx))?; + return Poll::Ready(Ok(self.amt)); + } + } + } +} \ No newline at end of file diff --git a/crates/transport/src/lib.rs b/crates/transport/src/lib.rs index 6a35fce4c..ea518f610 100644 --- a/crates/transport/src/lib.rs +++ b/crates/transport/src/lib.rs @@ -1,6 +1,9 @@ mod forward; mod ws; +mod copy_bidirectional; +mod copy_buffer; +pub use copy_bidirectional::*; pub use self::forward::*; pub use self::ws::*; diff --git a/devolutions-gateway/src/api/fwd.rs b/devolutions-gateway/src/api/fwd.rs index a79b6c341..6bc2b64ab 100644 --- a/devolutions-gateway/src/api/fwd.rs +++ b/devolutions-gateway/src/api/fwd.rs @@ -18,7 +18,7 @@ use crate::http::HttpError; use crate::proxy::Proxy; use crate::session::{ConnectionModeDetails, SessionInfo, SessionMessageSender}; use crate::subscriber::SubscriberSender; -use crate::token::{AssociationTokenClaims, ConnectionMode}; +use crate::token::{ApplicationProtocol, AssociationTokenClaims, ConnectionMode, Protocol}; use crate::{utils, DgwState}; pub fn make_router(state: DgwState) -> Router { @@ -162,6 +162,13 @@ where trace!(%selected_target, "Connected"); span.record("target", selected_target.to_string()); + // ARD uses MVS codec which doesn't like buffering. + let buffer_size = if claims.jet_ap == ApplicationProtocol::Known(Protocol::Ard) { + Some(1024) + } else { + None + }; + if with_tls { trace!("Establishing TLS connection with server"); @@ -193,6 +200,7 @@ where .transport_b(server_stream) .sessions(sessions) .subscriber_tx(subscriber_tx) + .buffer_size(buffer_size) .build() .select_dissector_and_forward() .await @@ -220,6 +228,7 @@ where .transport_b(server_stream) .sessions(sessions) .subscriber_tx(subscriber_tx) + .buffer_size(buffer_size) .build() .select_dissector_and_forward() .await diff --git a/devolutions-gateway/src/proxy.rs b/devolutions-gateway/src/proxy.rs index d50666df6..c50ff2c1c 100644 --- a/devolutions-gateway/src/proxy.rs +++ b/devolutions-gateway/src/proxy.rs @@ -24,6 +24,8 @@ pub struct Proxy { address_b: SocketAddr, sessions: SessionMessageSender, subscriber_tx: SubscriberSender, + #[builder(default = None)] + buffer_size: Option, } impl Proxy @@ -95,6 +97,7 @@ where address_b: self.address_b, sessions: self.sessions, subscriber_tx: self.subscriber_tx, + buffer_size: self.buffer_size, } .forward() .await @@ -121,12 +124,22 @@ where // NOTE(DGW-86): when recording is required, should we wait for it to start before we forward, or simply spawn // a timer to check if the recording is started within a few seconds? - let forward_fut = tokio::io::copy_bidirectional(&mut transport_a, &mut transport_b); let kill_notified = notify_kill.notified(); - let res = match futures::future::select(pin!(forward_fut), pin!(kill_notified)).await { - Either::Left((res, _)) => res.map(|_| ()), - Either::Right(_) => Ok(()), + let res = if let Some(buffer_size) = self.buffer_size { + // Use our for of copy_bidirectional because tokio doesn't have an API to set the buffer size. + // See https://github.com/tokio-rs/tokio/issues/6454. + let forward_fut = transport::copy_bidirectional(&mut transport_a, &mut transport_b, buffer_size, buffer_size); + match futures::future::select(pin!(forward_fut), pin!(kill_notified)).await { + Either::Left((res, _)) => res.map(|_| ()), + Either::Right(_) => Ok(()), + } + } else { + let forward_fut = tokio::io::copy_bidirectional(&mut transport_a, &mut transport_b); + match futures::future::select(pin!(forward_fut), pin!(kill_notified)).await { + Either::Left((res, _)) => res.map(|_| ()), + Either::Right(_) => Ok(()), + } }; // Ensure we close the transports cleanly at the end (ignore errors at this point) From 3c59de067458d640deea8deeaf81239777245ca9 Mon Sep 17 00:00:00 2001 From: Alexandr Yusiuk Date: Mon, 15 Apr 2024 10:15:57 +0300 Subject: [PATCH 2/7] fix(ci): run cargo fmt --- crates/transport/src/copy_bidirectional.rs | 36 +++++++++++----------- crates/transport/src/copy_buffer.rs | 33 ++++++++------------ crates/transport/src/lib.rs | 6 ++-- devolutions-gateway/src/proxy.rs | 3 +- 4 files changed, 36 insertions(+), 42 deletions(-) diff --git a/crates/transport/src/copy_bidirectional.rs b/crates/transport/src/copy_bidirectional.rs index eea7a6e62..231f6e9fa 100644 --- a/crates/transport/src/copy_bidirectional.rs +++ b/crates/transport/src/copy_bidirectional.rs @@ -30,9 +30,9 @@ fn transfer_one_direction( r: &mut A, w: &mut B, ) -> Poll> - where - A: AsyncRead + AsyncWrite + Unpin + ?Sized, - B: AsyncRead + AsyncWrite + Unpin + ?Sized, +where + A: AsyncRead + AsyncWrite + Unpin + ?Sized, + B: AsyncRead + AsyncWrite + Unpin + ?Sized, { let mut r = Pin::new(r); let mut w = Pin::new(w); @@ -54,20 +54,15 @@ fn transfer_one_direction( } impl<'a, A, B> Future for CopyBidirectional<'a, A, B> - where - A: AsyncRead + AsyncWrite + Unpin + ?Sized, - B: AsyncRead + AsyncWrite + Unpin + ?Sized, +where + A: AsyncRead + AsyncWrite + Unpin + ?Sized, + B: AsyncRead + AsyncWrite + Unpin + ?Sized, { type Output = io::Result<(u64, u64)>; fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { // Unpack self into mut refs to each field to avoid borrow check issues. - let CopyBidirectional { - a, - b, - a_to_b, - b_to_a, - } = &mut *self; + let CopyBidirectional { a, b, a_to_b, b_to_a } = &mut *self; let a_to_b = transfer_one_direction(cx, a_to_b, &mut *a, &mut *b)?; let b_to_a = transfer_one_direction(cx, b_to_a, &mut *b, &mut *a)?; @@ -108,10 +103,15 @@ impl<'a, A, B> Future for CopyBidirectional<'a, A, B> /// # Return value /// /// Returns a tuple of bytes copied `a` to `b` and bytes copied `b` to `a`. -pub async fn copy_bidirectional(a: &mut A, b: &mut B, send_buffer_size: usize, recv_buffer_size: usize) -> Result<(u64, u64), std::io::Error> - where - A: AsyncRead + AsyncWrite + Unpin + ?Sized, - B: AsyncRead + AsyncWrite + Unpin + ?Sized, +pub async fn copy_bidirectional( + a: &mut A, + b: &mut B, + send_buffer_size: usize, + recv_buffer_size: usize, +) -> Result<(u64, u64), std::io::Error> +where + A: AsyncRead + AsyncWrite + Unpin + ?Sized, + B: AsyncRead + AsyncWrite + Unpin + ?Sized, { CopyBidirectional { a, @@ -119,5 +119,5 @@ pub async fn copy_bidirectional(a: &mut A, b: &mut B, send_buffer_size: us a_to_b: TransferState::Running(CopyBuffer::new(send_buffer_size)), b_to_a: TransferState::Running(CopyBuffer::new(recv_buffer_size)), } - .await -} \ No newline at end of file + .await +} diff --git a/crates/transport/src/copy_buffer.rs b/crates/transport/src/copy_buffer.rs index 24879b421..91b4b8ec3 100644 --- a/crates/transport/src/copy_buffer.rs +++ b/crates/transport/src/copy_buffer.rs @@ -8,7 +8,6 @@ use std::io::{self}; use std::pin::Pin; use std::task::{Context, Poll}; - #[derive(Debug)] pub(super) struct CopyBuffer { read_done: bool, @@ -20,7 +19,8 @@ pub(super) struct CopyBuffer { } impl CopyBuffer { - pub(super) fn new(buffer_size: usize) -> Self { // <- This is our change + pub(super) fn new(buffer_size: usize) -> Self { + // <- This is our change Self { read_done: false, need_flush: false, @@ -31,13 +31,9 @@ impl CopyBuffer { } } - fn poll_fill_buf( - &mut self, - cx: &mut Context<'_>, - reader: Pin<&mut R>, - ) -> Poll> - where - R: AsyncRead + ?Sized, + fn poll_fill_buf(&mut self, cx: &mut Context<'_>, reader: Pin<&mut R>) -> Poll> + where + R: AsyncRead + ?Sized, { let me = &mut *self; let mut buf = ReadBuf::new(&mut me.buf); @@ -58,9 +54,9 @@ impl CopyBuffer { mut reader: Pin<&mut R>, mut writer: Pin<&mut W>, ) -> Poll> - where - R: AsyncRead + ?Sized, - W: AsyncWrite + ?Sized, + where + R: AsyncRead + ?Sized, + W: AsyncWrite + ?Sized, { let me = &mut *self; match writer.as_mut().poll_write(cx, &me.buf[me.pos..me.cap]) { @@ -82,9 +78,9 @@ impl CopyBuffer { mut reader: Pin<&mut R>, mut writer: Pin<&mut W>, ) -> Poll> - where - R: AsyncRead + ?Sized, - W: AsyncWrite + ?Sized, + where + R: AsyncRead + ?Sized, + W: AsyncWrite + ?Sized, { loop { // If our buffer is empty, then we need to read some data to @@ -127,10 +123,7 @@ impl CopyBuffer { // If pos larger than cap, this loop will never stop. // In particular, user's wrong poll_write implementation returning // incorrect written length may lead to thread blocking. - debug_assert!( - self.pos <= self.cap, - "writer returned length larger than input slice" - ); + debug_assert!(self.pos <= self.cap, "writer returned length larger than input slice"); // If we've written all the data and we've seen EOF, flush out the // data and finish the transfer. @@ -140,4 +133,4 @@ impl CopyBuffer { } } } -} \ No newline at end of file +} diff --git a/crates/transport/src/lib.rs b/crates/transport/src/lib.rs index ea518f610..e4814489d 100644 --- a/crates/transport/src/lib.rs +++ b/crates/transport/src/lib.rs @@ -1,11 +1,11 @@ -mod forward; -mod ws; mod copy_bidirectional; mod copy_buffer; +mod forward; +mod ws; -pub use copy_bidirectional::*; pub use self::forward::*; pub use self::ws::*; +pub use copy_bidirectional::*; use tokio::io::{AsyncRead, AsyncWrite}; diff --git a/devolutions-gateway/src/proxy.rs b/devolutions-gateway/src/proxy.rs index c50ff2c1c..15d0453b1 100644 --- a/devolutions-gateway/src/proxy.rs +++ b/devolutions-gateway/src/proxy.rs @@ -129,7 +129,8 @@ where let res = if let Some(buffer_size) = self.buffer_size { // Use our for of copy_bidirectional because tokio doesn't have an API to set the buffer size. // See https://github.com/tokio-rs/tokio/issues/6454. - let forward_fut = transport::copy_bidirectional(&mut transport_a, &mut transport_b, buffer_size, buffer_size); + let forward_fut = + transport::copy_bidirectional(&mut transport_a, &mut transport_b, buffer_size, buffer_size); match futures::future::select(pin!(forward_fut), pin!(kill_notified)).await { Either::Left((res, _)) => res.map(|_| ()), Either::Right(_) => Ok(()), From 9a3c01fa428edb87e84d2020c91aa53b6d7520fb Mon Sep 17 00:00:00 2001 From: Alex Yusiuk <55661041+RRRadicalEdward@users.noreply.github.com> Date: Mon, 15 Apr 2024 07:27:50 +0000 Subject: [PATCH 3/7] Update crates/transport/src/copy_bidirectional.rs MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: BenoƮt Cortier --- crates/transport/src/copy_bidirectional.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/crates/transport/src/copy_bidirectional.rs b/crates/transport/src/copy_bidirectional.rs index 231f6e9fa..0c19ba5ac 100644 --- a/crates/transport/src/copy_bidirectional.rs +++ b/crates/transport/src/copy_bidirectional.rs @@ -1,5 +1,5 @@ -//! Fork of https://github.com/tokio-rs/tokio/blob/master/tokio/src/io/util/copy.rs to allow us set -//! variable length `CopyBuffer` size instead of default 8k. +//! Vendored code from https://github.com/tokio-rs/tokio/blob/1f6fc55917f971791d76dc91cce795e656c0e0d3/tokio/src/io/util/copy.rs +//! It is modified to allow us setting the `CopyBuffer` size instead of hardcoding 8k. //! See . use super::copy_buffer::CopyBuffer; From 4fea288fe003735d92847e09184900141cec0469 Mon Sep 17 00:00:00 2001 From: Alexandr Yusiuk Date: Mon, 15 Apr 2024 10:37:37 +0300 Subject: [PATCH 4/7] review: fix reexport --- crates/transport/src/lib.rs | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/crates/transport/src/lib.rs b/crates/transport/src/lib.rs index e4814489d..e75723847 100644 --- a/crates/transport/src/lib.rs +++ b/crates/transport/src/lib.rs @@ -1,11 +1,10 @@ mod copy_bidirectional; -mod copy_buffer; mod forward; mod ws; pub use self::forward::*; pub use self::ws::*; -pub use copy_bidirectional::*; +pub use self::copy_bidirectional::*; use tokio::io::{AsyncRead, AsyncWrite}; From d8dd5e3bf3e9a0802b052e721ee61ebdfe07e4d3 Mon Sep 17 00:00:00 2001 From: Alexandr Yusiuk Date: Mon, 15 Apr 2024 10:37:43 +0300 Subject: [PATCH 5/7] review: move copy_buffer.rs code to copy_bidirectional.rs --- crates/transport/src/copy_bidirectional.rs | 130 +++++++++++++++++++- crates/transport/src/copy_buffer.rs | 136 --------------------- 2 files changed, 128 insertions(+), 138 deletions(-) delete mode 100644 crates/transport/src/copy_buffer.rs diff --git a/crates/transport/src/copy_bidirectional.rs b/crates/transport/src/copy_bidirectional.rs index 0c19ba5ac..41c7fed67 100644 --- a/crates/transport/src/copy_bidirectional.rs +++ b/crates/transport/src/copy_bidirectional.rs @@ -2,9 +2,8 @@ //! It is modified to allow us setting the `CopyBuffer` size instead of hardcoding 8k. //! See . -use super::copy_buffer::CopyBuffer; use futures_core::ready; -use tokio::io::{AsyncRead, AsyncWrite}; +use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; use std::future::Future; use std::io::{self}; @@ -121,3 +120,130 @@ where } .await } + +#[derive(Debug)] +pub(super) struct CopyBuffer { + read_done: bool, + need_flush: bool, + pos: usize, + cap: usize, + amt: u64, + buf: Box<[u8]>, +} + +impl CopyBuffer { + pub(super) fn new(buffer_size: usize) -> Self { + // <- This is our change + Self { + read_done: false, + need_flush: false, + pos: 0, + cap: 0, + amt: 0, + buf: vec![0; buffer_size].into_boxed_slice(), + } + } + + fn poll_fill_buf(&mut self, cx: &mut Context<'_>, reader: Pin<&mut R>) -> Poll> + where + R: AsyncRead + ?Sized, + { + let me = &mut *self; + let mut buf = ReadBuf::new(&mut me.buf); + buf.set_filled(me.cap); + + let res = reader.poll_read(cx, &mut buf); + if let Poll::Ready(Ok(_)) = res { + let filled_len = buf.filled().len(); + me.read_done = me.cap == filled_len; + me.cap = filled_len; + } + res + } + + fn poll_write_buf( + &mut self, + cx: &mut Context<'_>, + mut reader: Pin<&mut R>, + mut writer: Pin<&mut W>, + ) -> Poll> + where + R: AsyncRead + ?Sized, + W: AsyncWrite + ?Sized, + { + let me = &mut *self; + match writer.as_mut().poll_write(cx, &me.buf[me.pos..me.cap]) { + Poll::Pending => { + // Top up the buffer towards full if we can read a bit more + // data - this should improve the chances of a large write + if !me.read_done && me.cap < me.buf.len() { + ready!(me.poll_fill_buf(cx, reader.as_mut()))?; + } + Poll::Pending + } + res => res, + } + } + + pub(super) fn poll_copy( + &mut self, + cx: &mut Context<'_>, + mut reader: Pin<&mut R>, + mut writer: Pin<&mut W>, + ) -> Poll> + where + R: AsyncRead + ?Sized, + W: AsyncWrite + ?Sized, + { + loop { + // If our buffer is empty, then we need to read some data to + // continue. + if self.pos == self.cap && !self.read_done { + self.pos = 0; + self.cap = 0; + + match self.poll_fill_buf(cx, reader.as_mut()) { + Poll::Ready(Ok(_)) => (), + Poll::Ready(Err(err)) => return Poll::Ready(Err(err)), + Poll::Pending => { + // Try flushing when the reader has no progress to avoid deadlock + // when the reader depends on buffered writer. + if self.need_flush { + ready!(writer.as_mut().poll_flush(cx))?; + self.need_flush = false; + } + + return Poll::Pending; + } + } + } + + // If our buffer has some data, let's write it out! + while self.pos < self.cap { + let i = ready!(self.poll_write_buf(cx, reader.as_mut(), writer.as_mut()))?; + if i == 0 { + return Poll::Ready(Err(io::Error::new( + io::ErrorKind::WriteZero, + "write zero byte into writer", + ))); + } else { + self.pos += i; + self.amt += i as u64; + self.need_flush = true; + } + } + + // If pos larger than cap, this loop will never stop. + // In particular, user's wrong poll_write implementation returning + // incorrect written length may lead to thread blocking. + debug_assert!(self.pos <= self.cap, "writer returned length larger than input slice"); + + // If we've written all the data and we've seen EOF, flush out the + // data and finish the transfer. + if self.pos == self.cap && self.read_done { + ready!(writer.as_mut().poll_flush(cx))?; + return Poll::Ready(Ok(self.amt)); + } + } + } +} diff --git a/crates/transport/src/copy_buffer.rs b/crates/transport/src/copy_buffer.rs deleted file mode 100644 index 91b4b8ec3..000000000 --- a/crates/transport/src/copy_buffer.rs +++ /dev/null @@ -1,136 +0,0 @@ -//! Fork of https://github.com/tokio-rs/tokio/blob/master/tokio/src/io/util/copy.rs to allow us set -//! variable length `CopyBuffer` size instead of default 8k. -//! See . -use futures_core::ready; -use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; - -use std::io::{self}; -use std::pin::Pin; -use std::task::{Context, Poll}; - -#[derive(Debug)] -pub(super) struct CopyBuffer { - read_done: bool, - need_flush: bool, - pos: usize, - cap: usize, - amt: u64, - buf: Box<[u8]>, -} - -impl CopyBuffer { - pub(super) fn new(buffer_size: usize) -> Self { - // <- This is our change - Self { - read_done: false, - need_flush: false, - pos: 0, - cap: 0, - amt: 0, - buf: vec![0; buffer_size].into_boxed_slice(), - } - } - - fn poll_fill_buf(&mut self, cx: &mut Context<'_>, reader: Pin<&mut R>) -> Poll> - where - R: AsyncRead + ?Sized, - { - let me = &mut *self; - let mut buf = ReadBuf::new(&mut me.buf); - buf.set_filled(me.cap); - - let res = reader.poll_read(cx, &mut buf); - if let Poll::Ready(Ok(_)) = res { - let filled_len = buf.filled().len(); - me.read_done = me.cap == filled_len; - me.cap = filled_len; - } - res - } - - fn poll_write_buf( - &mut self, - cx: &mut Context<'_>, - mut reader: Pin<&mut R>, - mut writer: Pin<&mut W>, - ) -> Poll> - where - R: AsyncRead + ?Sized, - W: AsyncWrite + ?Sized, - { - let me = &mut *self; - match writer.as_mut().poll_write(cx, &me.buf[me.pos..me.cap]) { - Poll::Pending => { - // Top up the buffer towards full if we can read a bit more - // data - this should improve the chances of a large write - if !me.read_done && me.cap < me.buf.len() { - ready!(me.poll_fill_buf(cx, reader.as_mut()))?; - } - Poll::Pending - } - res => res, - } - } - - pub(super) fn poll_copy( - &mut self, - cx: &mut Context<'_>, - mut reader: Pin<&mut R>, - mut writer: Pin<&mut W>, - ) -> Poll> - where - R: AsyncRead + ?Sized, - W: AsyncWrite + ?Sized, - { - loop { - // If our buffer is empty, then we need to read some data to - // continue. - if self.pos == self.cap && !self.read_done { - self.pos = 0; - self.cap = 0; - - match self.poll_fill_buf(cx, reader.as_mut()) { - Poll::Ready(Ok(_)) => (), - Poll::Ready(Err(err)) => return Poll::Ready(Err(err)), - Poll::Pending => { - // Try flushing when the reader has no progress to avoid deadlock - // when the reader depends on buffered writer. - if self.need_flush { - ready!(writer.as_mut().poll_flush(cx))?; - self.need_flush = false; - } - - return Poll::Pending; - } - } - } - - // If our buffer has some data, let's write it out! - while self.pos < self.cap { - let i = ready!(self.poll_write_buf(cx, reader.as_mut(), writer.as_mut()))?; - if i == 0 { - return Poll::Ready(Err(io::Error::new( - io::ErrorKind::WriteZero, - "write zero byte into writer", - ))); - } else { - self.pos += i; - self.amt += i as u64; - self.need_flush = true; - } - } - - // If pos larger than cap, this loop will never stop. - // In particular, user's wrong poll_write implementation returning - // incorrect written length may lead to thread blocking. - debug_assert!(self.pos <= self.cap, "writer returned length larger than input slice"); - - // If we've written all the data and we've seen EOF, flush out the - // data and finish the transfer. - if self.pos == self.cap && self.read_done { - ready!(writer.as_mut().poll_flush(cx))?; - return Poll::Ready(Ok(self.amt)); - } - } - } -} From d604e9504291cf62ad02bf85f82c3a7952f9333f Mon Sep 17 00:00:00 2001 From: Alexandr Yusiuk Date: Mon, 15 Apr 2024 10:39:22 +0300 Subject: [PATCH 6/7] fix(ci): run cargo fmt --- crates/transport/src/copy_bidirectional.rs | 16 ++++++++-------- crates/transport/src/lib.rs | 2 +- 2 files changed, 9 insertions(+), 9 deletions(-) diff --git a/crates/transport/src/copy_bidirectional.rs b/crates/transport/src/copy_bidirectional.rs index 41c7fed67..0b199500c 100644 --- a/crates/transport/src/copy_bidirectional.rs +++ b/crates/transport/src/copy_bidirectional.rs @@ -145,8 +145,8 @@ impl CopyBuffer { } fn poll_fill_buf(&mut self, cx: &mut Context<'_>, reader: Pin<&mut R>) -> Poll> - where - R: AsyncRead + ?Sized, + where + R: AsyncRead + ?Sized, { let me = &mut *self; let mut buf = ReadBuf::new(&mut me.buf); @@ -167,9 +167,9 @@ impl CopyBuffer { mut reader: Pin<&mut R>, mut writer: Pin<&mut W>, ) -> Poll> - where - R: AsyncRead + ?Sized, - W: AsyncWrite + ?Sized, + where + R: AsyncRead + ?Sized, + W: AsyncWrite + ?Sized, { let me = &mut *self; match writer.as_mut().poll_write(cx, &me.buf[me.pos..me.cap]) { @@ -191,9 +191,9 @@ impl CopyBuffer { mut reader: Pin<&mut R>, mut writer: Pin<&mut W>, ) -> Poll> - where - R: AsyncRead + ?Sized, - W: AsyncWrite + ?Sized, + where + R: AsyncRead + ?Sized, + W: AsyncWrite + ?Sized, { loop { // If our buffer is empty, then we need to read some data to diff --git a/crates/transport/src/lib.rs b/crates/transport/src/lib.rs index e75723847..82c64fb3c 100644 --- a/crates/transport/src/lib.rs +++ b/crates/transport/src/lib.rs @@ -2,9 +2,9 @@ mod copy_bidirectional; mod forward; mod ws; +pub use self::copy_bidirectional::*; pub use self::forward::*; pub use self::ws::*; -pub use self::copy_bidirectional::*; use tokio::io::{AsyncRead, AsyncWrite}; From f31d8b3fdb3d28267901ce3edec54f796ead9b86 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Beno=C3=AEt=20Cortier?= Date: Mon, 15 Apr 2024 03:45:54 -0400 Subject: [PATCH 7/7] Update crates/transport/src/copy_bidirectional.rs --- crates/transport/src/copy_bidirectional.rs | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/crates/transport/src/copy_bidirectional.rs b/crates/transport/src/copy_bidirectional.rs index 0b199500c..71396c52f 100644 --- a/crates/transport/src/copy_bidirectional.rs +++ b/crates/transport/src/copy_bidirectional.rs @@ -1,4 +1,6 @@ -//! Vendored code from https://github.com/tokio-rs/tokio/blob/1f6fc55917f971791d76dc91cce795e656c0e0d3/tokio/src/io/util/copy.rs +//! Vendored code from: +//! - https://github.com/tokio-rs/tokio/blob/1f6fc55917f971791d76dc91cce795e656c0e0d3/tokio/src/io/util/copy.rs +//! - https://github.com/tokio-rs/tokio/blob/1f6fc55917f971791d76dc91cce795e656c0e0d3/tokio/src/io/util/copy_bidirectional.rs //! It is modified to allow us setting the `CopyBuffer` size instead of hardcoding 8k. //! See .