From 32de1d50de509559e2b8f2d6c7e1259c0db85cb1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Beno=C3=AEt=20Cortier?= Date: Tue, 13 Aug 2024 22:20:07 -0400 Subject: [PATCH] perf(jetsocat,dgw): remove to_vec calls from JMUX implementation (#973) By removing a few `to_vec()` calls and reusing the `Bytes` buffer as-is, performance of JMUX proxy is increased by ~62.3%. Performance was measured using `iperf` on local network. JMUX proxy performance before this patch: > 0.0000-10.0493 sec 8.21 GBytes 7.02 Gbits/sec JMUX proxy performance after this patch: > 0.0000-19.0245 sec 25.2 GBytes 11.4 Gbits/sec This is still 78.6% slower than jetsocat regular TCP port forwarding. --- crates/jmux-generators/src/lib.rs | 3 ++- crates/jmux-proto/src/lib.rs | 15 +++++++++------ crates/jmux-proto/tests/message.rs | 4 ++-- crates/jmux-proxy/src/lib.rs | 25 +++++++++++++------------ 4 files changed, 26 insertions(+), 21 deletions(-) diff --git a/crates/jmux-generators/src/lib.rs b/crates/jmux-generators/src/lib.rs index 09bd6c88d..95fe49eae 100644 --- a/crates/jmux-generators/src/lib.rs +++ b/crates/jmux-generators/src/lib.rs @@ -46,7 +46,8 @@ pub fn message_window_adjust() -> impl Strategy { } pub fn message_data() -> impl Strategy { - (distant_channel_id(), vec(any::(), 0..512)).prop_map(|(distant_id, data)| Message::data(distant_id, data)) + (distant_channel_id(), vec(any::(), 0..512)) + .prop_map(|(distant_id, data)| Message::data(distant_id, Bytes::from(data))) } pub fn message_eof() -> impl Strategy { diff --git a/crates/jmux-proto/src/lib.rs b/crates/jmux-proto/src/lib.rs index 628d8e79a..93aca9af9 100644 --- a/crates/jmux-proto/src/lib.rs +++ b/crates/jmux-proto/src/lib.rs @@ -1,9 +1,12 @@ //! 
[Specification document](https://github.com/Devolutions/devolutions-gateway/blob/master/crates/jmux-proto/spec/JMUX_Spec.md) -use bytes::{Buf as _, BufMut as _, Bytes, BytesMut}; +use bytes::{Buf as _, BufMut as _}; use core::fmt; use smol_str::SmolStr; +// We re-export these types, because they are used in the public API. +pub use bytes::{Bytes, BytesMut}; + /// Distant identifier for a channel #[derive(Debug, PartialEq, Eq, Clone, Copy, Hash)] pub struct DistantChannelId(u32); @@ -235,7 +238,7 @@ impl Message { Self::WindowAdjust(ChannelWindowAdjust::new(distant_id, window_adjustment)) } - pub fn data(id: DistantChannelId, data: Vec) -> Self { + pub fn data(id: DistantChannelId, data: Bytes) -> Self { Self::Data(ChannelData::new(id, data)) } @@ -654,7 +657,7 @@ impl ChannelWindowAdjust { #[derive(PartialEq, Eq)] pub struct ChannelData { pub recipient_channel_id: u32, - pub transfer_data: Vec, + pub transfer_data: Bytes, } // We don't want to print `transfer_data` content (usually too big) @@ -671,7 +674,7 @@ impl ChannelData { pub const NAME: &'static str = "CHANNEL DATA"; pub const FIXED_PART_SIZE: usize = 4 /*recipientChannelId*/; - pub fn new(id: DistantChannelId, data: Vec) -> Self { + pub fn new(id: DistantChannelId, data: Bytes) -> Self { ChannelData { recipient_channel_id: u32::from(id), transfer_data: data, @@ -684,14 +687,14 @@ impl ChannelData { pub fn encode(&self, buf: &mut BytesMut) { buf.put_u32(self.recipient_channel_id); - buf.put(self.transfer_data.as_slice()); + buf.put(self.transfer_data.slice(..)); } pub fn decode(mut buf: Bytes) -> Result { ensure_size!(fixed Self in buf); Ok(Self { recipient_channel_id: buf.get_u32(), - transfer_data: buf.to_vec(), + transfer_data: buf, }) } } diff --git a/crates/jmux-proto/tests/message.rs b/crates/jmux-proto/tests/message.rs index 3eea270db..e56b52632 100644 --- a/crates/jmux-proto/tests/message.rs +++ b/crates/jmux-proto/tests/message.rs @@ -152,7 +152,7 @@ pub fn channel_window_adjust() { #[test] pub fn 
error_on_oversized_packet() { let mut buf = BytesMut::new(); - let err = Message::data(DistantChannelId::from(1), vec![0; u16::MAX as usize]) + let err = Message::data(DistantChannelId::from(1), vec![0; u16::MAX as usize].into()) .encode(&mut buf) .err() .unwrap(); @@ -171,7 +171,7 @@ pub fn channel_data() { let msg_example = ChannelData { recipient_channel_id: 1, - transfer_data: vec![11, 12, 13, 14], + transfer_data: vec![11, 12, 13, 14].into(), }; check_encode_decode(Message::Data(msg_example), raw_msg); diff --git a/crates/jmux-proxy/src/lib.rs b/crates/jmux-proxy/src/lib.rs index 3bb8c0ede..749e2c2af 100644 --- a/crates/jmux-proxy/src/lib.rs +++ b/crates/jmux-proxy/src/lib.rs @@ -217,8 +217,8 @@ impl JmuxCtx { type MessageReceiver = mpsc::UnboundedReceiver; type MessageSender = mpsc::UnboundedSender; -type DataReceiver = mpsc::UnboundedReceiver>; -type DataSender = mpsc::UnboundedSender>; +type DataReceiver = mpsc::UnboundedReceiver; +type DataSender = mpsc::UnboundedSender; type InternalMessageSender = mpsc::UnboundedSender; #[derive(Debug)] @@ -325,7 +325,7 @@ async fn scheduler_task_impl(task: JmuxSc JmuxApiRequest::Start { id, stream, leftover } => { let channel = jmux_ctx.get_channel(id).with_context(|| format!("couldn’t find channel with id {id}"))?; - let (data_tx, data_rx) = mpsc::unbounded_channel::>(); + let (data_tx, data_rx) = mpsc::unbounded_channel::(); if data_senders.insert(id, data_tx).is_some() { anyhow::bail!("detected two streams with the same ID {}", id); @@ -333,7 +333,7 @@ async fn scheduler_task_impl(task: JmuxSc // Send leftover bytes if any if let Some(leftover) = leftover { - if let Err(error) = msg_to_send_tx.send(Message::data(channel.distant_id, leftover.to_vec())) { + if let Err(error) = msg_to_send_tx.send(Message::data(channel.distant_id, leftover)) { error!(%error, "Couldn't send leftover bytes"); } ; } @@ -405,7 +405,7 @@ async fn scheduler_task_impl(task: JmuxSc let window_size = Arc::clone(&channel.window_size); let 
channel_span = channel.span.clone(); - let (data_tx, data_rx) = mpsc::unbounded_channel::>(); + let (data_tx, data_rx) = mpsc::unbounded_channel::(); if data_senders.insert(channel.local_id, data_tx).is_some() { anyhow::bail!("detected two streams with the same local ID {}", channel.local_id); @@ -746,7 +746,7 @@ impl DataReaderTask { trace!("Started forwarding"); while let Some(bytes) = bytes_stream.next().await { - let bytes = match bytes { + let mut bytes = match bytes { Ok(bytes) => bytes, Err(error) if is_really_an_error(&error) => { return Err(anyhow::Error::new(error).context("couldn’t read next bytes from stream")) @@ -759,12 +759,13 @@ impl DataReaderTask { let chunk_size = maximum_packet_size - Header::SIZE - ChannelData::FIXED_PART_SIZE; - let queue: Vec> = bytes.chunks(chunk_size).map(|slice| slice.to_vec()).collect(); + while !bytes.is_empty() { + let split_at = core::cmp::min(chunk_size, bytes.len()); + let mut chunk = bytes.split_to(split_at); - for mut bytes in queue { loop { let window_size_now = window_size.load(Ordering::SeqCst); - if window_size_now < bytes.len() { + if window_size_now < chunk.len() { trace!( window_size_now, full_packet_size = bytes.len(), @@ -772,10 +773,10 @@ impl DataReaderTask { ); if window_size_now > 0 { - let bytes_to_send_now: Vec = bytes.drain(..window_size_now).collect(); + let bytes_to_send_now = chunk.split_to(window_size_now); window_size.fetch_sub(bytes_to_send_now.len(), Ordering::SeqCst); msg_to_send_tx - .send(Message::data(distant_id, bytes_to_send_now)) + .send(Message::data(distant_id, bytes_to_send_now.freeze())) .context("couldn’t send DATA message")?; } @@ -783,7 +784,7 @@ impl DataReaderTask { } else { window_size.fetch_sub(bytes.len(), Ordering::SeqCst); msg_to_send_tx - .send(Message::data(distant_id, bytes)) + .send(Message::data(distant_id, chunk.freeze())) .context("couldn’t send DATA message")?; break; }