Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

perf(dgw): use a buffer of 1k bytes for ARD VNC sessions #809

Merged
merged 7 commits into from
Apr 15, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
251 changes: 251 additions & 0 deletions crates/transport/src/copy_bidirectional.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,251 @@
//! Vendored code from:
//! - https://github.com/tokio-rs/tokio/blob/1f6fc55917f971791d76dc91cce795e656c0e0d3/tokio/src/io/util/copy.rs
//! - https://github.com/tokio-rs/tokio/blob/1f6fc55917f971791d76dc91cce795e656c0e0d3/tokio/src/io/util/copy_bidirectional.rs
//! It is modified to allow us setting the `CopyBuffer` size instead of hardcoding 8k.
//! See <https://github.com/tokio-rs/tokio/issues/6454>.

use futures_core::ready;
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};

use std::future::Future;
use std::io::{self};
use std::pin::Pin;
use std::task::{Context, Poll};

enum TransferState {
Running(CopyBuffer),
ShuttingDown(u64),
Done(u64),
}

struct CopyBidirectional<'a, A: ?Sized, B: ?Sized> {
a: &'a mut A,
b: &'a mut B,
a_to_b: TransferState,
b_to_a: TransferState,
}

fn transfer_one_direction<A, B>(
cx: &mut Context<'_>,
state: &mut TransferState,
r: &mut A,
w: &mut B,
) -> Poll<io::Result<u64>>
where
A: AsyncRead + AsyncWrite + Unpin + ?Sized,
B: AsyncRead + AsyncWrite + Unpin + ?Sized,
{
let mut r = Pin::new(r);
let mut w = Pin::new(w);

loop {
match state {
TransferState::Running(buf) => {
let count = ready!(buf.poll_copy(cx, r.as_mut(), w.as_mut()))?;
*state = TransferState::ShuttingDown(count);
}
TransferState::ShuttingDown(count) => {
ready!(w.as_mut().poll_shutdown(cx))?;

*state = TransferState::Done(*count);
}
TransferState::Done(count) => return Poll::Ready(Ok(*count)),
}
}
}

impl<'a, A, B> Future for CopyBidirectional<'a, A, B>
where
A: AsyncRead + AsyncWrite + Unpin + ?Sized,
B: AsyncRead + AsyncWrite + Unpin + ?Sized,
{
type Output = io::Result<(u64, u64)>;

fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
// Unpack self into mut refs to each field to avoid borrow check issues.
let CopyBidirectional { a, b, a_to_b, b_to_a } = &mut *self;

let a_to_b = transfer_one_direction(cx, a_to_b, &mut *a, &mut *b)?;
let b_to_a = transfer_one_direction(cx, b_to_a, &mut *b, &mut *a)?;

// It is not a problem if ready! returns early because transfer_one_direction for the
// other direction will keep returning TransferState::Done(count) in future calls to poll
let a_to_b = ready!(a_to_b);
let b_to_a = ready!(b_to_a);

Poll::Ready(Ok((a_to_b, b_to_a)))
}
}

/// Copies data in both directions between `a` and `b`.
///
/// This function returns a future that will read from both streams,
/// writing any data read to the opposing stream.
/// This happens in both directions concurrently.
///
/// If an EOF is observed on one stream, [`shutdown()`] will be invoked on
/// the other, and reading from that stream will stop. Copying of data in
/// the other direction will continue.
///
/// The future will complete successfully once both directions of communication has been shut down.
/// A direction is shut down when the reader reports EOF,
/// at which point [`shutdown()`] is called on the corresponding writer. When finished,
/// it will return a tuple of the number of bytes copied from a to b
/// and the number of bytes copied from b to a, in that order.
///
/// [`shutdown()`]: crate::io::AsyncWriteExt::shutdown
///
/// # Errors
///
/// The future will immediately return an error if any IO operation on `a`
/// or `b` returns an error. Some data read from either stream may be lost (not
/// written to the other stream) in this case.
///
/// # Return value
///
/// Returns a tuple of bytes copied `a` to `b` and bytes copied `b` to `a`.
pub async fn copy_bidirectional<A, B>(
a: &mut A,
b: &mut B,
send_buffer_size: usize,
recv_buffer_size: usize,
) -> Result<(u64, u64), std::io::Error>
where
A: AsyncRead + AsyncWrite + Unpin + ?Sized,
B: AsyncRead + AsyncWrite + Unpin + ?Sized,
{
CopyBidirectional {
a,
b,
a_to_b: TransferState::Running(CopyBuffer::new(send_buffer_size)),
b_to_a: TransferState::Running(CopyBuffer::new(recv_buffer_size)),
}
.await
}

#[derive(Debug)]
pub(super) struct CopyBuffer {
read_done: bool,
need_flush: bool,
pos: usize,
cap: usize,
amt: u64,
buf: Box<[u8]>,
}

impl CopyBuffer {
pub(super) fn new(buffer_size: usize) -> Self {
// <- This is our change
Self {
read_done: false,
need_flush: false,
pos: 0,
cap: 0,
amt: 0,
buf: vec![0; buffer_size].into_boxed_slice(),
}
}

fn poll_fill_buf<R>(&mut self, cx: &mut Context<'_>, reader: Pin<&mut R>) -> Poll<io::Result<()>>
where
R: AsyncRead + ?Sized,
{
let me = &mut *self;
let mut buf = ReadBuf::new(&mut me.buf);
buf.set_filled(me.cap);

let res = reader.poll_read(cx, &mut buf);
if let Poll::Ready(Ok(_)) = res {
let filled_len = buf.filled().len();
me.read_done = me.cap == filled_len;
me.cap = filled_len;
}
res
}

fn poll_write_buf<R, W>(
&mut self,
cx: &mut Context<'_>,
mut reader: Pin<&mut R>,
mut writer: Pin<&mut W>,
) -> Poll<io::Result<usize>>
where
R: AsyncRead + ?Sized,
W: AsyncWrite + ?Sized,
{
let me = &mut *self;
match writer.as_mut().poll_write(cx, &me.buf[me.pos..me.cap]) {
Poll::Pending => {
// Top up the buffer towards full if we can read a bit more
// data - this should improve the chances of a large write
if !me.read_done && me.cap < me.buf.len() {
ready!(me.poll_fill_buf(cx, reader.as_mut()))?;
}
Poll::Pending
}
res => res,
}
}

pub(super) fn poll_copy<R, W>(
&mut self,
cx: &mut Context<'_>,
mut reader: Pin<&mut R>,
mut writer: Pin<&mut W>,
) -> Poll<io::Result<u64>>
where
R: AsyncRead + ?Sized,
W: AsyncWrite + ?Sized,
{
loop {
// If our buffer is empty, then we need to read some data to
// continue.
if self.pos == self.cap && !self.read_done {
self.pos = 0;
self.cap = 0;

match self.poll_fill_buf(cx, reader.as_mut()) {
Poll::Ready(Ok(_)) => (),
Poll::Ready(Err(err)) => return Poll::Ready(Err(err)),
Poll::Pending => {
// Try flushing when the reader has no progress to avoid deadlock
// when the reader depends on buffered writer.
if self.need_flush {
ready!(writer.as_mut().poll_flush(cx))?;
self.need_flush = false;
}

return Poll::Pending;
}
}
}

// If our buffer has some data, let's write it out!
while self.pos < self.cap {
let i = ready!(self.poll_write_buf(cx, reader.as_mut(), writer.as_mut()))?;
if i == 0 {
return Poll::Ready(Err(io::Error::new(
io::ErrorKind::WriteZero,
"write zero byte into writer",
)));
} else {
self.pos += i;
self.amt += i as u64;
self.need_flush = true;
}
}

// If pos larger than cap, this loop will never stop.
// In particular, user's wrong poll_write implementation returning
// incorrect written length may lead to thread blocking.
debug_assert!(self.pos <= self.cap, "writer returned length larger than input slice");

// If we've written all the data and we've seen EOF, flush out the
// data and finish the transfer.
if self.pos == self.cap && self.read_done {
ready!(writer.as_mut().poll_flush(cx))?;
return Poll::Ready(Ok(self.amt));
}
}
}
}
2 changes: 2 additions & 0 deletions crates/transport/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
mod copy_bidirectional;
mod forward;
mod ws;

pub use self::copy_bidirectional::*;
pub use self::forward::*;
pub use self::ws::*;

Expand Down
11 changes: 10 additions & 1 deletion devolutions-gateway/src/api/fwd.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ use crate::http::HttpError;
use crate::proxy::Proxy;
use crate::session::{ConnectionModeDetails, SessionInfo, SessionMessageSender};
use crate::subscriber::SubscriberSender;
use crate::token::{AssociationTokenClaims, ConnectionMode};
use crate::token::{ApplicationProtocol, AssociationTokenClaims, ConnectionMode, Protocol};
use crate::{utils, DgwState};

pub fn make_router<S>(state: DgwState) -> Router<S> {
Expand Down Expand Up @@ -162,6 +162,13 @@ where
trace!(%selected_target, "Connected");
span.record("target", selected_target.to_string());

// ARD uses MVS codec which doesn't like buffering.
let buffer_size = if claims.jet_ap == ApplicationProtocol::Known(Protocol::Ard) {
Some(1024)
} else {
None
};

if with_tls {
trace!("Establishing TLS connection with server");

Expand Down Expand Up @@ -193,6 +200,7 @@ where
.transport_b(server_stream)
.sessions(sessions)
.subscriber_tx(subscriber_tx)
.buffer_size(buffer_size)
.build()
.select_dissector_and_forward()
.await
Expand Down Expand Up @@ -220,6 +228,7 @@ where
.transport_b(server_stream)
.sessions(sessions)
.subscriber_tx(subscriber_tx)
.buffer_size(buffer_size)
.build()
.select_dissector_and_forward()
.await
Expand Down
22 changes: 18 additions & 4 deletions devolutions-gateway/src/proxy.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@ pub struct Proxy<A, B> {
address_b: SocketAddr,
sessions: SessionMessageSender,
subscriber_tx: SubscriberSender,
#[builder(default = None)]
buffer_size: Option<usize>,
}

impl<A, B> Proxy<A, B>
Expand Down Expand Up @@ -95,6 +97,7 @@ where
address_b: self.address_b,
sessions: self.sessions,
subscriber_tx: self.subscriber_tx,
buffer_size: self.buffer_size,
}
.forward()
.await
Expand All @@ -121,12 +124,23 @@ where
// NOTE(DGW-86): when recording is required, should we wait for it to start before we forward, or simply spawn
// a timer to check if the recording is started within a few seconds?

let forward_fut = tokio::io::copy_bidirectional(&mut transport_a, &mut transport_b);
let kill_notified = notify_kill.notified();

let res = match futures::future::select(pin!(forward_fut), pin!(kill_notified)).await {
Either::Left((res, _)) => res.map(|_| ()),
Either::Right(_) => Ok(()),
let res = if let Some(buffer_size) = self.buffer_size {
// Use our for of copy_bidirectional because tokio doesn't have an API to set the buffer size.
// See https://github.com/tokio-rs/tokio/issues/6454.
let forward_fut =
transport::copy_bidirectional(&mut transport_a, &mut transport_b, buffer_size, buffer_size);
match futures::future::select(pin!(forward_fut), pin!(kill_notified)).await {
Either::Left((res, _)) => res.map(|_| ()),
Either::Right(_) => Ok(()),
}
} else {
let forward_fut = tokio::io::copy_bidirectional(&mut transport_a, &mut transport_b);
match futures::future::select(pin!(forward_fut), pin!(kill_notified)).await {
Either::Left((res, _)) => res.map(|_| ()),
Either::Right(_) => Ok(()),
}
};

// Ensure we close the transports cleanly at the end (ignore errors at this point)
Expand Down
Loading