diff --git a/src/connections.rs b/src/connections.rs index bab7745f..1996f112 100644 --- a/src/connections.rs +++ b/src/connections.rs @@ -103,14 +103,18 @@ impl RecvStream { Self { quinn_recv_stream } } - /// Read next message from the stream + /// Read next user message from the stream pub async fn next(&mut self) -> Result { - match read_bytes(&mut self.quinn_recv_stream).await { - Ok(WireMsg::UserMsg(bytes)) => Ok(bytes), - Ok(msg) => Err(Error::UnexpectedMessageType(msg)), - Err(error) => Err(error), + match self.next_wire_msg().await? { + WireMsg::UserMsg(bytes) => Ok(bytes), + msg => Err(Error::UnexpectedMessageType(msg.into())), } } + + /// Read next wire msg from the stream + pub(crate) async fn next_wire_msg(&mut self) -> Result { + read_bytes(&mut self.quinn_recv_stream).await + } } impl Debug for RecvStream { @@ -149,7 +153,7 @@ impl SendStream { } /// Send a wire message - pub async fn send(&mut self, msg: WireMsg) -> Result<()> { + pub(crate) async fn send(&mut self, msg: WireMsg) -> Result<()> { msg.write_to_stream(&mut self.quinn_send_stream).await } diff --git a/src/endpoint.rs b/src/endpoint.rs index d9ba78cf..abd211f2 100644 --- a/src/endpoint.rs +++ b/src/endpoint.rs @@ -195,7 +195,7 @@ impl Endpoint { "Unexpected message when verifying public endpoint: {}", other ); - return Err(Error::UnexpectedMessageType(other)); + return Err(Error::UnexpectedMessageType(other.into())); } Ok(Err(err)) => { error!("Error while verifying Public IP Address"); @@ -416,25 +416,18 @@ impl Endpoint { match timeout( Duration::from_secs(ECHO_SERVICE_QUERY_TIMEOUT), - recv_stream.next(), + recv_stream.next_wire_msg(), ) .await { - Ok(Err(Error::UnexpectedMessageType(WireMsg::EndpointEchoResp(_)))) => Ok(()), - Ok(Err(Error::UnexpectedMessageType(other))) => { + Ok(Ok(WireMsg::EndpointEchoResp(_))) => Ok(()), + Ok(Ok(other)) => { info!( "Unexpected message type when verifying reachability: {}", &other ); Ok(()) } - Ok(Ok(bytes)) => { - info!( - "Unexpected message type when verifying reachability: {}", - WireMsg::UserMsg(bytes) - ); - Ok(()) - } Ok(Err(err)) => { info!("Unable to contact peer: {:?}", err); Err(err) @@ -527,7 +520,7 @@ impl Endpoint { send_stream.send(WireMsg::EndpointEchoReq).await?; match WireMsg::read_from_stream(&mut recv_stream.quinn_recv_stream).await { Ok(WireMsg::EndpointEchoResp(socket_addr)) => Ok(socket_addr), - Ok(msg) => Err(Error::UnexpectedMessageType(msg)), + Ok(msg) => Err(Error::UnexpectedMessageType(msg.into())), Err(err) => Err(err), } }); diff --git a/src/error.rs b/src/error.rs index 5b2ca019..e4c1677a 100644 --- a/src/error.rs +++ b/src/error.rs @@ -78,8 +78,8 @@ pub enum Error { #[error("Empty response message received from peer")] EmptyResponse, /// The type of message received is not the expected one. - #[error("Type of the message received was not the expected one: {0}")] - UnexpectedMessageType(WireMsg), + #[error(transparent)] + UnexpectedMessageType(#[from] UnexpectedMessageType), /// The message exceeds the maximum message length allowed. #[error("Maximum data length exceeded, length: {0}")] MaxLengthExceeded(usize), @@ -300,3 +300,14 @@ impl fmt::Display for TransportErrorCode { write!(f, "{}", self.0) } } + +/// The type of message received is not the expected one. +#[derive(Debug, Error)] +#[error("The of the message received was not the expected one: {0}")] +pub struct UnexpectedMessageType(WireMsg); + +impl From for UnexpectedMessageType { + fn from(msg: WireMsg) -> Self { + UnexpectedMessageType(msg) + } +} diff --git a/src/lib.rs b/src/lib.rs index 46e33c88..488487cc 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -66,7 +66,7 @@ pub use connections::{DisconnectionEvents, RecvStream, SendStream}; pub use endpoint::{Endpoint, IncomingConnections, IncomingMessages}; pub use error::{ ClientEndpointError, Close, ConnectionError, Error, InternalConfigError, Result, - TransportErrorCode, + TransportErrorCode, UnexpectedMessageType, }; #[cfg(test)] diff --git a/src/wire_msg.rs b/src/wire_msg.rs index b75e44bb..25734d19 100644 --- a/src/wire_msg.rs +++ b/src/wire_msg.rs @@ -21,7 +21,7 @@ const MSG_PROTOCOL_VERSION: u16 = 0x0001; /// Final type serialised and sent on the wire by `QuicP2p` #[derive(Serialize, Deserialize, Debug, Clone)] -pub enum WireMsg { +pub(crate) enum WireMsg { EndpointEchoReq, EndpointEchoResp(SocketAddr), EndpointVerificationReq(SocketAddr), @@ -34,7 +34,7 @@ const ECHO_SRVC_MSG_FLAG: u8 = 0x01; impl WireMsg { // Read a message's bytes from the provided stream - pub async fn read_from_stream(recv: &mut quinn::RecvStream) -> Result { + pub(crate) async fn read_from_stream(recv: &mut quinn::RecvStream) -> Result { let mut header_bytes = [0; MSG_HEADER_LEN]; recv.read_exact(&mut header_bytes).await?; @@ -64,7 +64,7 @@ impl WireMsg { } // Helper to write WireMsg bytes to the provided stream. - pub async fn write_to_stream(&self, send_stream: &mut quinn::SendStream) -> Result<()> { + pub(crate) async fn write_to_stream(&self, send_stream: &mut quinn::SendStream) -> Result<()> { // Let's generate the message bytes let (msg_bytes, msg_flag) = match self { WireMsg::UserMsg(ref m) => (m.clone(), USER_MSG_FLAG),