Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

perf(jetsocat,dgw): remove to_vec calls from JMUX implementation #973

Merged
merged 3 commits into from
Aug 14, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion crates/jmux-generators/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,8 @@ pub fn message_window_adjust() -> impl Strategy<Value = Message> {
}

pub fn message_data() -> impl Strategy<Value = Message> {
(distant_channel_id(), vec(any::<u8>(), 0..512)).prop_map(|(distant_id, data)| Message::data(distant_id, data))
(distant_channel_id(), vec(any::<u8>(), 0..512))
.prop_map(|(distant_id, data)| Message::data(distant_id, Bytes::from(data)))
}

pub fn message_eof() -> impl Strategy<Value = Message> {
Expand Down
15 changes: 9 additions & 6 deletions crates/jmux-proto/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
//! [Specification document](https://github.com/Devolutions/devolutions-gateway/blob/master/crates/jmux-proto/spec/JMUX_Spec.md)

use bytes::{Buf as _, BufMut as _, Bytes, BytesMut};
use bytes::{Buf as _, BufMut as _};
use core::fmt;
use smol_str::SmolStr;

// We re-export these types, because they are used in the public API.
pub use bytes::{Bytes, BytesMut};

/// Distant identifier for a channel
#[derive(Debug, PartialEq, Eq, Clone, Copy, Hash)]
pub struct DistantChannelId(u32);
Expand Down Expand Up @@ -235,7 +238,7 @@ impl Message {
Self::WindowAdjust(ChannelWindowAdjust::new(distant_id, window_adjustment))
}

pub fn data(id: DistantChannelId, data: Vec<u8>) -> Self {
pub fn data(id: DistantChannelId, data: Bytes) -> Self {
Self::Data(ChannelData::new(id, data))
}

Expand Down Expand Up @@ -654,7 +657,7 @@ impl ChannelWindowAdjust {
#[derive(PartialEq, Eq)]
pub struct ChannelData {
pub recipient_channel_id: u32,
pub transfer_data: Vec<u8>,
pub transfer_data: Bytes,
}

// We don't want to print `transfer_data` content (usually too big)
Expand All @@ -671,7 +674,7 @@ impl ChannelData {
pub const NAME: &'static str = "CHANNEL DATA";
pub const FIXED_PART_SIZE: usize = 4 /*recipientChannelId*/;

pub fn new(id: DistantChannelId, data: Vec<u8>) -> Self {
pub fn new(id: DistantChannelId, data: Bytes) -> Self {
ChannelData {
recipient_channel_id: u32::from(id),
transfer_data: data,
Expand All @@ -684,14 +687,14 @@ impl ChannelData {

pub fn encode(&self, buf: &mut BytesMut) {
buf.put_u32(self.recipient_channel_id);
buf.put(self.transfer_data.as_slice());
buf.put(self.transfer_data.slice(..));
}

pub fn decode(mut buf: Bytes) -> Result<Self, Error> {
ensure_size!(fixed Self in buf);
Ok(Self {
recipient_channel_id: buf.get_u32(),
transfer_data: buf.to_vec(),
transfer_data: buf,
})
}
}
Expand Down
4 changes: 2 additions & 2 deletions crates/jmux-proto/tests/message.rs
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,7 @@ pub fn channel_window_adjust() {
#[test]
pub fn error_on_oversized_packet() {
let mut buf = BytesMut::new();
let err = Message::data(DistantChannelId::from(1), vec![0; u16::MAX as usize])
let err = Message::data(DistantChannelId::from(1), vec![0; u16::MAX as usize].into())
.encode(&mut buf)
.err()
.unwrap();
Expand All @@ -171,7 +171,7 @@ pub fn channel_data() {

let msg_example = ChannelData {
recipient_channel_id: 1,
transfer_data: vec![11, 12, 13, 14],
transfer_data: vec![11, 12, 13, 14].into(),
};

check_encode_decode(Message::Data(msg_example), raw_msg);
Expand Down
25 changes: 13 additions & 12 deletions crates/jmux-proxy/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -217,8 +217,8 @@ impl JmuxCtx {

type MessageReceiver = mpsc::UnboundedReceiver<Message>;
type MessageSender = mpsc::UnboundedSender<Message>;
type DataReceiver = mpsc::UnboundedReceiver<Vec<u8>>;
type DataSender = mpsc::UnboundedSender<Vec<u8>>;
type DataReceiver = mpsc::UnboundedReceiver<Bytes>;
type DataSender = mpsc::UnboundedSender<Bytes>;
type InternalMessageSender = mpsc::UnboundedSender<InternalMessage>;

#[derive(Debug)]
Expand Down Expand Up @@ -325,15 +325,15 @@ async fn scheduler_task_impl<T: AsyncRead + Unpin + Send + 'static>(task: JmuxSc
JmuxApiRequest::Start { id, stream, leftover } => {
let channel = jmux_ctx.get_channel(id).with_context(|| format!("couldn’t find channel with id {id}"))?;

let (data_tx, data_rx) = mpsc::unbounded_channel::<Vec<u8>>();
let (data_tx, data_rx) = mpsc::unbounded_channel::<Bytes>();

if data_senders.insert(id, data_tx).is_some() {
anyhow::bail!("detected two streams with the same ID {}", id);
}

// Send leftover bytes if any
if let Some(leftover) = leftover {
if let Err(error) = msg_to_send_tx.send(Message::data(channel.distant_id, leftover.to_vec())) {
if let Err(error) = msg_to_send_tx.send(Message::data(channel.distant_id, leftover)) {
error!(%error, "Couldn't send leftover bytes");
} ;
}
Expand Down Expand Up @@ -405,7 +405,7 @@ async fn scheduler_task_impl<T: AsyncRead + Unpin + Send + 'static>(task: JmuxSc
let window_size = Arc::clone(&channel.window_size);
let channel_span = channel.span.clone();

let (data_tx, data_rx) = mpsc::unbounded_channel::<Vec<u8>>();
let (data_tx, data_rx) = mpsc::unbounded_channel::<Bytes>();

if data_senders.insert(channel.local_id, data_tx).is_some() {
anyhow::bail!("detected two streams with the same local ID {}", channel.local_id);
Expand Down Expand Up @@ -746,7 +746,7 @@ impl DataReaderTask {
trace!("Started forwarding");

while let Some(bytes) = bytes_stream.next().await {
let bytes = match bytes {
let mut bytes = match bytes {
Ok(bytes) => bytes,
Err(error) if is_really_an_error(&error) => {
return Err(anyhow::Error::new(error).context("couldn’t read next bytes from stream"))
Expand All @@ -759,31 +759,32 @@ impl DataReaderTask {

let chunk_size = maximum_packet_size - Header::SIZE - ChannelData::FIXED_PART_SIZE;

let queue: Vec<Vec<u8>> = bytes.chunks(chunk_size).map(|slice| slice.to_vec()).collect();
while !bytes.is_empty() {
let split_at = core::cmp::min(chunk_size, bytes.len());
let mut chunk = bytes.split_to(split_at);

for mut bytes in queue {
loop {
let window_size_now = window_size.load(Ordering::SeqCst);
if window_size_now < bytes.len() {
if window_size_now < chunk.len() {
trace!(
window_size_now,
full_packet_size = bytes.len(),
"Window size insufficient to send full packet. Truncate and wait."
);

if window_size_now > 0 {
let bytes_to_send_now: Vec<u8> = bytes.drain(..window_size_now).collect();
let bytes_to_send_now = chunk.split_to(window_size_now);
window_size.fetch_sub(bytes_to_send_now.len(), Ordering::SeqCst);
msg_to_send_tx
.send(Message::data(distant_id, bytes_to_send_now))
.send(Message::data(distant_id, bytes_to_send_now.freeze()))
.context("couldn’t send DATA message")?;
}

window_size_updated.notified().await;
} else {
window_size.fetch_sub(bytes.len(), Ordering::SeqCst);
msg_to_send_tx
.send(Message::data(distant_id, bytes))
.send(Message::data(distant_id, chunk.freeze()))
.context("couldn’t send DATA message")?;
break;
}
Expand Down