diff --git a/src/connections.rs b/src/connections.rs index 22e3e346..a07f6d5e 100644 --- a/src/connections.rs +++ b/src/connections.rs @@ -10,7 +10,7 @@ use super::{ api::Message, error::{Error, Result}, - wire_msg::WireMsg, + wire_msg::{MsgHeader, WireMsg, HEADER_LEN}, }; use bytes::Bytes; use futures::{lock::Mutex, stream::StreamExt}; @@ -233,13 +233,15 @@ impl SendStream { // Helper to read the message's bytes from the provided stream async fn read_bytes(recv: &mut quinn::RecvStream) -> Result { - let mut data_len: [u8; 8] = [0; 8]; - recv.read_exact(&mut data_len).await?; - let data_len = usize::from_le_bytes(data_len); - let mut data: Vec = vec![0; data_len]; + let mut header_bytes = [0; HEADER_LEN]; + recv.read_exact(&mut header_bytes).await?; + + let msg_header = MsgHeader::from_bytes(header_bytes); + let mut data: Vec = vec![0; msg_header.data_len()]; + recv.read_exact(&mut data).await?; trace!("Got new message with {} bytes.", data.len()); - match WireMsg::from_raw(data)? { + match WireMsg::from_raw(data, msg_header.usr_msg_flag())? { WireMsg::UserMsg(msg_bytes) => Ok(Bytes::copy_from_slice(&msg_bytes)), WireMsg::EndpointEchoReq | WireMsg::EndpointEchoResp(_) => { // TODO: handle the echo request/response message @@ -253,20 +255,17 @@ async fn send_msg(send_stream: &mut quinn::SendStream, msg: Bytes) -> Result<()> // Let's generate the message bytes let wire_msg = WireMsg::UserMsg(msg); let (msg_bytes, msg_flag) = wire_msg.into(); + trace!("Sending message to remote peer ({} bytes)", msg_bytes.len()); - trace!("Sending message to remote peer ({} bytes)", msg_bytes.len(),); + let msg_header = MsgHeader::new(&msg_bytes, msg_flag)?; + let header_bytes = msg_header.to_bytes(); - // Send the length of the message + 1 (for the flag) - send_stream - .write_all(&(msg_bytes.len() + 1).to_le_bytes()) - .await?; + // Send the message header + send_stream.write_all(&header_bytes).await?; // Send message bytes over QUIC send_stream.write_all(&msg_bytes[..]).await?; - // Then send message flag over QUIC - send_stream.write_all(&[msg_flag]).await?; - trace!("Message was sent to remote peer"); Ok(()) diff --git a/src/error.rs b/src/error.rs index 3453ce01..6da42fa8 100644 --- a/src/error.rs +++ b/src/error.rs @@ -83,6 +83,8 @@ pub enum Error { EmptyResponse, #[error(display = "Type of the message received was not the expected one")] UnexpectedMessageType, + #[error(display = "Maximum data length exceeded")] + MaxLengthExceeded, #[error(display = "Unexpected: {}", 0)] Unexpected(String), } diff --git a/src/wire_msg.rs b/src/wire_msg.rs index e4806405..390c1839 100644 --- a/src/wire_msg.rs +++ b/src/wire_msg.rs @@ -16,6 +16,73 @@ use serde::{Deserialize, Serialize}; use std::{fmt, net::SocketAddr}; use unwrap::unwrap; +pub(crate) const HEADER_LEN: usize = 9; +pub(crate) const VERSION: i16 = 0; +pub(crate) const MAX_DATA_LEN: usize = usize::from_be_bytes([0, 0, 0, 0, 255, 255, 255, 255]); + +/// Message Header that is sent over the wire +/// Format of the message header is as follows +/// | version | message length | usr_msg_flag | reserved | +/// | 2 bytes | 4 bytes | 1 byte | 2 bytes | +pub(crate) struct MsgHeader { + version: i16, + data_len: usize, + usr_msg_flag: u8, + #[allow(unused)] + reserved: [u8; 2], +} + +impl MsgHeader { + pub fn new(msg: &Bytes, usr_msg_flag: u8) -> Result { + let data_len = msg.len(); + if data_len > MAX_DATA_LEN { + return Err(Error::MaxLengthExceeded); + } + Ok(Self { + version: VERSION, + data_len, + usr_msg_flag, + reserved: [0, 0], + }) + } + + pub fn data_len(&self) -> usize { + self.data_len + } + + pub fn usr_msg_flag(&self) -> u8 { + self.usr_msg_flag + } + + pub fn to_bytes(&self) -> [u8; HEADER_LEN] { + let version = self.version.to_be_bytes(); + let data_len = self.data_len.to_be_bytes(); + [ + version[0], + version[1], + data_len[4], + data_len[5], + data_len[6], + data_len[7], + self.usr_msg_flag, + 0, + 0, + ] + } + + pub fn from_bytes(bytes: [u8; HEADER_LEN]) -> Self { + let version = i16::from_be_bytes([bytes[0], bytes[1]]); + let data_len = usize::from_be_bytes([0, 0, 0, 0, bytes[2], bytes[3], bytes[4], bytes[5]]); + let usr_msg_flag = bytes[6]; + Self { + version, + data_len, + usr_msg_flag, + reserved: [0, 0], + } + } +} + /// Final type serialised and sent on the wire by QuicP2p #[derive(Serialize, Deserialize, Debug, Clone)] pub enum WireMsg { @@ -39,16 +106,15 @@ impl Into<(Bytes, u8)> for WireMsg { } impl WireMsg { - pub fn from_raw(mut raw: Vec) -> Result { + pub fn from_raw(raw: Vec, msg_flag: u8) -> Result { if raw.is_empty() { Err(Error::EmptyResponse) + } else if msg_flag == USER_MSG_FLAG { + Ok(WireMsg::UserMsg(From::from(raw))) + } else if msg_flag == !USER_MSG_FLAG { + Ok(bincode::deserialize(&raw)?) } else { - let msg_flag = raw.pop(); - match msg_flag { - Some(flag) if flag == USER_MSG_FLAG => Ok(WireMsg::UserMsg(From::from(raw))), - Some(flag) if flag == !USER_MSG_FLAG => Ok(bincode::deserialize(&raw)?), - _ => Err(Error::InvalidWireMsgFlag), - } + Err(Error::InvalidWireMsgFlag) } } }