diff --git a/crates/jmux-proto/src/lib.rs b/crates/jmux-proto/src/lib.rs index 93aca9af..26076193 100644 --- a/crates/jmux-proto/src/lib.rs +++ b/crates/jmux-proto/src/lib.rs @@ -160,7 +160,7 @@ impl fmt::Display for Error { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { match self { Error::PacketOversized { packet_size, max } => { - write!(f, "Packet oversized: max is {max}, got {packet_size}") + write!(f, "packet oversized: max is {max}, got {packet_size}") } Error::NotEnoughBytes { name, @@ -168,13 +168,13 @@ impl fmt::Display for Error { expected, } => write!( f, - "Not enough bytes provided to decode {name}: received {received} bytes, expected {expected} bytes" + "not enough bytes provided to decode {name}: received {received} bytes, expected {expected} bytes" ), Error::InvalidPacket { name, field, reason } => { - write!(f, "Invalid `{field}` in {name}: {reason}") + write!(f, "invalid `{field}` in {name}: {reason}") } Error::InvalidDestinationUrl { value, reason } => { - write!(f, "Invalid destination URL `{value}`: {reason}") + write!(f, "invalid destination URL `{value}`: {reason}") } } } diff --git a/crates/jmux-proto/tests/message.rs b/crates/jmux-proto/tests/message.rs index e56b5263..72abe019 100644 --- a/crates/jmux-proto/tests/message.rs +++ b/crates/jmux-proto/tests/message.rs @@ -35,7 +35,7 @@ fn message_type_try_err_on_invalid_bytes() { fn header_decode_buffer_too_short_err() { let err = Header::decode(Bytes::from_static(&[])).err().unwrap(); assert_eq!( - "Not enough bytes provided to decode HEADER: received 0 bytes, expected 4 bytes", + "not enough bytes provided to decode HEADER: received 0 bytes, expected 4 bytes", err.to_string() ); } @@ -156,7 +156,7 @@ pub fn error_on_oversized_packet() { .encode(&mut buf) .err() .unwrap(); - assert_eq!("Packet oversized: max is 65535, got 65543", err.to_string()); + assert_eq!("packet oversized: max is 65535, got 65543", err.to_string()); } #[test] diff --git a/crates/jmux-proxy/src/lib.rs b/crates/jmux-proxy/src/lib.rs index 3e7c7454..7b464b79 100644 --- a/crates/jmux-proxy/src/lib.rs +++ b/crates/jmux-proxy/src/lib.rs @@ -1,4 +1,6 @@ -//! [Specification document](https://github.com/awakecoding/qmux/blob/protocol-update/SPEC.md) +//! [Specification document][source] +//! +//! [source]: https://github.com/Devolutions/devolutions-gateway/blob/master/docs/JMUX-spec.md #[macro_use] extern crate tracing; @@ -14,7 +16,6 @@ use self::codec::JmuxCodec; use self::id_allocator::IdAllocator; use anyhow::Context as _; use bytes::Bytes; -use futures_util::{SinkExt, StreamExt}; use jmux_proto::{ChannelData, DistantChannelId, Header, LocalChannelId, Message, ReasonCode}; use std::collections::HashMap; use std::convert::TryFrom; @@ -26,7 +27,7 @@ use tokio::net::tcp::{OwnedReadHalf, OwnedWriteHalf}; use tokio::net::TcpStream; use tokio::sync::{mpsc, oneshot, Notify}; use tokio::task::JoinHandle; -use tokio_util::codec::{FramedRead, FramedWrite}; +use tokio_util::codec::FramedRead; use tracing::{Instrument as _, Span}; // PERF/FIXME: changing this parameter to 16 * 1024 greatly improves the throughput, @@ -104,12 +105,6 @@ impl JmuxProxy { self } - // TODO: consider using something like ChildTask more widely in Devolutions Gateway - pub fn spawn(self) -> JoinHandle> { - let fut = self.run(); - tokio::spawn(fut) - } - pub async fn run(self) -> anyhow::Result<()> { let span = Span::current(); run_proxy_impl(self, span.clone()).instrument(span).await @@ -127,10 +122,9 @@ async fn run_proxy_impl(proxy: JmuxProxy, span: Span) -> anyhow::Result<()> { let (msg_to_send_tx, msg_to_send_rx) = mpsc::unbounded_channel::(); let jmux_stream = FramedRead::new(jmux_reader, JmuxCodec); - let jmux_sink = FramedWrite::new(jmux_writer, JmuxCodec); let sender_task_handle = JmuxSenderTask { - jmux_sink, + jmux_writer, msg_to_send_rx, } .spawn(span.clone()); @@ -243,7 +237,7 @@ enum InternalMessage { // ---------------------- // struct JmuxSenderTask { - jmux_sink: FramedWrite, + jmux_writer: T, msg_to_send_rx: MessageReceiver, } @@ -256,16 +250,26 @@ impl JmuxSenderTask { #[instrument("sender", skip_all)] async fn run(self) -> anyhow::Result<()> { let Self { - mut jmux_sink, + mut jmux_writer, mut msg_to_send_rx, } = self; + let mut buf = bytes::BytesMut::new(); + while let Some(msg) = msg_to_send_rx.recv().await { trace!(?msg, "Send channel message"); - jmux_sink.feed(msg).await?; - jmux_sink.flush().await?; + + buf.clear(); + msg.encode(&mut buf)?; + + jmux_writer.write_all(&buf).await?; + + jmux_writer.flush().await?; } + // TODO: send a signal to the main scheduler when we are done processing channel data messages + // and adjust windows for all the channels only then. + info!("Closing JMUX sender task..."); Ok(()) @@ -292,6 +296,8 @@ impl JmuxSchedulerTask { #[instrument("scheduler", skip_all)] async fn scheduler_task_impl(task: JmuxSchedulerTask) -> anyhow::Result<()> { + use futures_util::StreamExt as _; + let JmuxSchedulerTask { cfg, mut jmux_stream, @@ -342,7 +348,7 @@ async fn scheduler_task_impl(task: JmuxSc anyhow::bail!("detected two streams with the same ID {}", id); } - // Send leftover bytes if any + // Send leftover bytes if any. if let Some(leftover) = leftover { if let Err(error) = msg_to_send_tx.send(Message::data(channel.distant_id, leftover)) { error!(%error, "Couldn't send leftover bytes"); @@ -483,9 +489,8 @@ async fn scheduler_task_impl(task: JmuxSc nb_consecutive_pipe_failures += 1; if nb_consecutive_pipe_failures > MAX_CONSECUTIVE_PIPE_FAILURES { - // Some underlying `AsyncRead` implementations might handle errors poorly - // and cause infinite polling on errors such as broken pipe (this should - // stop instead of returning the same error indefinitely). + // Some underlying `AsyncRead` implementations might handle errors poorly and cause infinite polling on errors such as broken pipe. + // (This should stop instead of returning the same error indefinitely.) // Hence, this safety net to escape from such infinite loops. anyhow::bail!("forced JMUX proxy shutdown because of too many consecutive pipe failures"); } else { @@ -744,6 +749,8 @@ impl DataReaderTask { } async fn run(self) -> anyhow::Result<()> { + use futures_util::StreamExt as _; + let Self { reader, local_id,