From e94e21ba39ed14851cd37262bcadfd42c9fd1ef8 Mon Sep 17 00:00:00 2001 From: "Roman S. Borschel" Date: Tue, 16 Jul 2019 10:46:47 +0200 Subject: [PATCH 1/5] Remove tokio-codec dependency from multistream-select. In preparation for the eventual switch from tokio to std futures. Includes some initial refactoring in preparation for further work in the context of https://github.com/libp2p/rust-libp2p/issues/659. --- misc/multistream-select/Cargo.toml | 3 +- misc/multistream-select/src/dialer_select.rs | 43 ++-- misc/multistream-select/src/error.rs | 19 +- .../src/length_delimited.rs | 239 +++++++++--------- misc/multistream-select/src/lib.rs | 11 +- .../multistream-select/src/listener_select.rs | 16 +- .../multistream-select/src/protocol/dialer.rs | 134 +++++----- misc/multistream-select/src/protocol/error.rs | 28 +- .../src/protocol/listener.rs | 140 +++++----- misc/multistream-select/src/protocol/mod.rs | 22 +- misc/multistream-select/src/tests.rs | 12 +- 11 files changed, 323 insertions(+), 344 deletions(-) diff --git a/misc/multistream-select/Cargo.toml b/misc/multistream-select/Cargo.toml index bfa652f03c0..83de968e89d 100644 --- a/misc/multistream-select/Cargo.toml +++ b/misc/multistream-select/Cargo.toml @@ -14,9 +14,8 @@ bytes = "0.4" futures = { version = "0.1" } log = "0.4" smallvec = "0.6" -tokio-codec = "0.1" tokio-io = "0.1" -unsigned-varint = { version = "0.2.1", features = ["codec"] } +unsigned-varint = { version = "0.2.2" } [dev-dependencies] tokio = "0.1" diff --git a/misc/multistream-select/src/dialer_select.rs b/misc/multistream-select/src/dialer_select.rs index bbd40c1a200..59bda8ad396 100644 --- a/misc/multistream-select/src/dialer_select.rs +++ b/misc/multistream-select/src/dialer_select.rs @@ -22,19 +22,14 @@ //! `multistream-select` for the dialer. use futures::{future::Either, prelude::*, stream::StreamFuture}; -use crate::protocol::{ - Dialer, - DialerFuture, - DialerToListenerMessage, - ListenerToDialerMessage -}; +use crate::protocol::{Dialer, DialerFuture, Request, Response}; use log::trace; use std::mem; use tokio_io::{AsyncRead, AsyncWrite}; use crate::{Negotiated, ProtocolChoiceError}; /// Future, returned by `dialer_select_proto`, which selects a protocol and dialer -/// either sequentially of by considering all protocols in parallel. +/// either sequentially or by considering all protocols in parallel. pub type DialerSelectFuture = Either, DialerSelectPar>; /// Helps selecting a protocol amongst the ones supported. @@ -75,7 +70,10 @@ where { let protocols = protocols.into_iter(); DialerSelectSeq { - inner: DialerSelectSeqState::AwaitDialer { dialer_fut: Dialer::dial(inner), protocols } + inner: DialerSelectSeqState::AwaitDialer { + dialer_fut: Dialer::dial(inner), + protocols + } } } @@ -148,9 +146,7 @@ where } DialerSelectSeqState::NextProtocol { mut dialer, protocols, proto_name } => { trace!("sending {:?}", proto_name.as_ref()); - let req = DialerToListenerMessage::ProtocolRequest { - name: proto_name.clone() - }; + let req = Request::Protocol { name: proto_name.clone() }; match dialer.start_send(req)? { AsyncSink::Ready => { self.inner = DialerSelectSeqState::FlushProtocol { @@ -204,12 +200,12 @@ where }; trace!("received {:?}", m); match m.ok_or(ProtocolChoiceError::UnexpectedMessage)? { - ListenerToDialerMessage::ProtocolAck { ref name } + Response::Protocol { ref name } if name.as_ref() == proto_name.as_ref() => { return Ok(Async::Ready((proto_name, Negotiated(r.into_inner())))) } - ListenerToDialerMessage::NotAvailable => { + Response::ProtocolNotAvailable => { let proto_name = protocols.next() .ok_or(ProtocolChoiceError::NoProtocolFound)?; self.inner = DialerSelectSeqState::NextProtocol { @@ -244,9 +240,8 @@ where } } - /// Future, returned by `dialer_select_proto_parallel`, which selects a protocol and dialer in -/// parellel, by first requesting the liste of protocols supported by the remote endpoint and +/// parallel, by first requesting the list of protocols supported by the remote endpoint and /// then selecting the most appropriate one by applying a match predicate to the result. pub struct DialerSelectPar where @@ -319,7 +314,7 @@ where } DialerSelectParState::ProtocolList { mut dialer, protocols } => { trace!("requesting protocols list"); - match dialer.start_send(DialerToListenerMessage::ProtocolsListRequest)? { + match dialer.start_send(Request::ListProtocols)? { AsyncSink::Ready => { self.inner = DialerSelectParState::FlushListRequest { dialer, @@ -359,15 +354,15 @@ where Err((e, _)) => return Err(ProtocolChoiceError::from(e)) }; trace!("protocols list response: {:?}", resp); - let list = - if let Some(ListenerToDialerMessage::ProtocolsListResponse { list }) = resp { - list + let supported = + if let Some(Response::SupportedProtocols { protocols }) = resp { + protocols } else { return Err(ProtocolChoiceError::UnexpectedMessage) }; let mut found = None; for local_name in protocols { - for remote_name in &list { + for remote_name in &supported { if remote_name.as_ref() == local_name.as_ref() { found = Some(local_name); break; @@ -381,10 +376,8 @@ where self.inner = DialerSelectParState::Protocol { dialer, proto_name } } DialerSelectParState::Protocol { mut dialer, proto_name } => { - trace!("requesting protocol: {:?}", proto_name.as_ref()); - let req = DialerToListenerMessage::ProtocolRequest { - name: proto_name.clone() - }; + trace!("Requesting protocol: {:?}", proto_name.as_ref()); + let req = Request::Protocol { name: proto_name.clone() }; match dialer.start_send(req)? { AsyncSink::Ready => { self.inner = DialerSelectParState::FlushProtocol { dialer, proto_name } @@ -420,7 +413,7 @@ where }; trace!("received {:?}", resp); match resp { - Some(ListenerToDialerMessage::ProtocolAck { ref name }) + Some(Response::Protocol { ref name }) if name.as_ref() == proto_name.as_ref() => { return Ok(Async::Ready((proto_name, Negotiated(dialer.into_inner())))) diff --git a/misc/multistream-select/src/error.rs b/misc/multistream-select/src/error.rs index 62b540ec502..1f72b5c0c8a 100644 --- a/misc/multistream-select/src/error.rs +++ b/misc/multistream-select/src/error.rs @@ -21,9 +21,8 @@ //! Main `ProtocolChoiceError` error. use crate::protocol::MultistreamSelectError; -use std::error; -use std::fmt; -use std::io::Error as IoError; +use std::error::Error; +use std::{fmt, io}; /// Error that can happen when negotiating a protocol with the remote. #[derive(Debug)] @@ -39,21 +38,18 @@ pub enum ProtocolChoiceError { } impl From for ProtocolChoiceError { - #[inline] fn from(err: MultistreamSelectError) -> ProtocolChoiceError { ProtocolChoiceError::MultistreamSelectError(err) } } -impl From for ProtocolChoiceError { - #[inline] - fn from(err: IoError) -> ProtocolChoiceError { +impl From for ProtocolChoiceError { + fn from(err: io::Error) -> ProtocolChoiceError { MultistreamSelectError::from(err).into() } } -impl error::Error for ProtocolChoiceError { - #[inline] +impl Error for ProtocolChoiceError { fn description(&self) -> &str { match *self { ProtocolChoiceError::MultistreamSelectError(_) => "error in the protocol", @@ -66,7 +62,7 @@ impl error::Error for ProtocolChoiceError { } } - fn cause(&self) -> Option<&dyn error::Error> { + fn source(&self) -> Option<&(dyn Error + 'static)> { match *self { ProtocolChoiceError::MultistreamSelectError(ref err) => Some(err), _ => None, @@ -75,8 +71,7 @@ impl error::Error for ProtocolChoiceError { } impl fmt::Display for ProtocolChoiceError { - #[inline] fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> Result<(), fmt::Error> { - write!(fmt, "{}", error::Error::description(self)) + write!(fmt, "{}", Error::description(self)) } } diff --git a/misc/multistream-select/src/length_delimited.rs b/misc/multistream-select/src/length_delimited.rs index 72256b83980..44a8ff36090 100644 --- a/misc/multistream-select/src/length_delimited.rs +++ b/misc/multistream-select/src/length_delimited.rs @@ -18,55 +18,59 @@ // FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER // DEALINGS IN THE SOFTWARE. -use bytes::Bytes; -use futures::{Async, Poll, Sink, StartSend, Stream}; -use smallvec::SmallVec; +use bytes::{Bytes, BytesMut, BufMut}; +use futures::{try_ready, Async, Poll, Sink, StartSend, Stream, AsyncSink}; use std::{io, u16}; -use tokio_codec::{Encoder, FramedWrite}; use tokio_io::{AsyncRead, AsyncWrite}; -use unsigned_varint::decode; +use unsigned_varint as uvi; -/// `Stream` and `Sink` wrapping some `AsyncRead + AsyncWrite` object to read +const MAX_LEN_BYTES: u16 = 2; +const MAX_FRAME_SIZE: u16 = (1 << (MAX_LEN_BYTES * 8 - MAX_LEN_BYTES)) - 1; + +/// `Stream` and `Sink` wrapping some `AsyncRead + AsyncWrite` resource to read /// and write unsigned-varint prefixed frames. /// -/// We purposely only support a frame length of under 64kiB. Frames mostly consist -/// in a short protocol name, which is highly unlikely to be more than 64kiB long. -pub struct LengthDelimited { - // The inner socket where data is pulled from. - inner: FramedWrite, - // Intermediary buffer where we put either the length of the next frame of data, or the frame - // of data itself before it is returned. - // Must always contain enough space to read data from `inner`. - internal_buffer: SmallVec<[u8; 64]>, - // Number of bytes within `internal_buffer` that contain valid data. - internal_buffer_pos: usize, - // State of the decoder. - state: State +/// We purposely only support a frame sizes up to 16KiB (2 bytes unsigned varint +/// frame length). Frames mostly consist in a short protocol name, which is highly +/// unlikely to be more than 16KiB long. +pub struct LengthDelimited { + /// The inner I/O resource. + inner: R, + /// Read buffer for a single unsigned-varint length-delimited frame. + read_buffer: BytesMut, + /// Write buffer for a single unsigned-varint length-delimited frame. + write_buffer: BytesMut, + /// The current read state, alternating between reading a frame + /// length and reading a frame payload. + read_state: ReadState, } #[derive(Debug, Copy, Clone, PartialEq, Eq)] -enum State { - // We are currently reading the length of the next frame of data. - ReadingLength, - // We are currently reading the frame of data itself. - ReadingData { frame_len: u16 }, +enum ReadState { + /// We are currently reading the length of the next frame of data. + ReadLength { buf: [u8; MAX_LEN_BYTES as usize], pos: usize }, + /// We are currently reading the frame of data itself. + ReadData { len: u16, pos: usize }, } -impl LengthDelimited -where - R: AsyncWrite, - C: Encoder -{ - pub fn new(inner: R, codec: C) -> LengthDelimited { +impl Default for ReadState { + fn default() -> Self { + ReadState::ReadLength { + buf: [0; MAX_LEN_BYTES as usize], + pos: 0 + } + } +} + +impl LengthDelimited { + /// Creates a new I/O resource for reading and writing unsigned-varint + /// length delimited frames. + pub fn new(inner: R) -> LengthDelimited { LengthDelimited { - inner: FramedWrite::new(inner, codec), - internal_buffer: { - let mut v = SmallVec::new(); - v.push(0); - v - }, - internal_buffer_pos: 0, - state: State::ReadingLength + inner, + read_state: ReadState::default(), + read_buffer: BytesMut::with_capacity(MAX_FRAME_SIZE as usize), + write_buffer: BytesMut::with_capacity((MAX_FRAME_SIZE + MAX_LEN_BYTES) as usize), } } @@ -81,15 +85,14 @@ where /// you call `poll()` manually**. Using this struct as it is intended to be used (i.e. through /// the modifiers provided by the `futures` crate) will always leave the object in a state in /// which `into_inner()` will not panic. - #[inline] pub fn into_inner(self) -> R { - assert_eq!(self.state, State::ReadingLength); - assert_eq!(self.internal_buffer_pos, 0); - self.inner.into_inner() + assert!(self.write_buffer.is_empty()); + assert!(self.read_buffer.is_empty()); + self.inner } } -impl Stream for LengthDelimited +impl Stream for LengthDelimited where R: AsyncRead { @@ -98,16 +101,11 @@ where fn poll(&mut self) -> Poll, Self::Error> { loop { - debug_assert!(!self.internal_buffer.is_empty()); - debug_assert!(self.internal_buffer_pos < self.internal_buffer.len()); - - match self.state { - State::ReadingLength => { - let slice = &mut self.internal_buffer[self.internal_buffer_pos..]; - match self.inner.get_mut().read(slice) { + match &mut self.read_state { + ReadState::ReadLength { buf, pos } => { + match self.inner.read(&mut buf[*pos .. *pos + 1]) { Ok(0) => { - // EOF - if self.internal_buffer_pos == 0 { + if *pos == 0 { return Ok(Async::Ready(None)); } else { return Err(io::ErrorKind::UnexpectedEof.into()); @@ -115,7 +113,7 @@ where } Ok(n) => { debug_assert_eq!(n, 1); - self.internal_buffer_pos += n; + *pos += n; } Err(ref err) if err.kind() == io::ErrorKind::WouldBlock => { return Ok(Async::NotReady); @@ -125,56 +123,45 @@ where } }; - debug_assert_eq!(self.internal_buffer.len(), self.internal_buffer_pos); - - if (*self.internal_buffer.last().unwrap_or(&0) & 0x80) == 0 { - // End of length prefix. Most of the time we will switch to reading data, - // but we need to handle a few corner cases first. - let (frame_len, _) = decode::u16(&self.internal_buffer).map_err(|e| { + if (buf[*pos - 1] & 0x80) == 0 { + // MSB is not set, indicating the end of the length prefix. + let (len, _) = uvi::decode::u16(buf).map_err(|e| { log::debug!("invalid length prefix: {}", e); io::Error::new(io::ErrorKind::InvalidData, "invalid length prefix") })?; - if frame_len >= 1 { - self.state = State::ReadingData { frame_len }; - self.internal_buffer.clear(); - self.internal_buffer.reserve(frame_len as usize); - self.internal_buffer.extend((0..frame_len).map(|_| 0)); - self.internal_buffer_pos = 0; + if len >= 1 { + self.read_state = ReadState::ReadData { len, pos: 0 }; + self.read_buffer.resize(len as usize, 0); } else { - debug_assert_eq!(frame_len, 0); - self.state = State::ReadingLength; - self.internal_buffer.clear(); - self.internal_buffer.push(0); - self.internal_buffer_pos = 0; - return Ok(Async::Ready(Some(From::from(&[][..])))); + debug_assert_eq!(len, 0); + self.read_state = ReadState::default(); + return Ok(Async::Ready(Some(Bytes::new()))); } - } else if self.internal_buffer_pos >= 2 { - // Length prefix is too long. See module doc for info about max frame len. - return Err(io::Error::new(io::ErrorKind::InvalidData, "frame length too long")); - } else { - // Prepare for next read. - self.internal_buffer.push(0); + } else if *pos == MAX_LEN_BYTES as usize { + // MSB signals more length bytes but we have already read the maximum. + // See the module documentation about the max frame len. + return Err(io::Error::new( + io::ErrorKind::InvalidData, + "Maximum frame length exceeded")); } } - State::ReadingData { frame_len } => { - let slice = &mut self.internal_buffer[self.internal_buffer_pos..]; - match self.inner.get_mut().read(slice) { + ReadState::ReadData { len, pos } => { + match self.inner.read(&mut self.read_buffer[*pos..]) { Ok(0) => return Err(io::ErrorKind::UnexpectedEof.into()), - Ok(n) => self.internal_buffer_pos += n, - Err(ref err) if err.kind() == io::ErrorKind::WouldBlock => { - return Ok(Async::NotReady) - } - Err(err) => return Err(err) + Ok(n) => *pos += n, + Err(err) => + if err.kind() == io::ErrorKind::WouldBlock { + return Ok(Async::NotReady) + } else { + return Err(err) + } }; - if self.internal_buffer_pos >= frame_len as usize { - // Finished reading the frame of data. - self.state = State::ReadingLength; - let out_data = From::from(&self.internal_buffer[..]); - self.internal_buffer.clear(); - self.internal_buffer.push(0); - self.internal_buffer_pos = 0; - return Ok(Async::Ready(Some(out_data))); + if *pos == *len as usize { + // Finished reading the frame. + let frame = self.read_buffer.split_off(0).freeze(); + self.read_state = ReadState::default(); + return Ok(Async::Ready(Some(frame))); } } } @@ -182,27 +169,54 @@ where } } -impl Sink for LengthDelimited +impl Sink for LengthDelimited where R: AsyncWrite, - C: Encoder { - type SinkItem = as Sink>::SinkItem; - type SinkError = as Sink>::SinkError; + type SinkItem = Bytes; + type SinkError = io::Error; + + fn start_send(&mut self, msg: Self::SinkItem) -> StartSend { + if !self.write_buffer.is_empty() { + self.poll_complete()?; + if !self.write_buffer.is_empty() { + return Ok(AsyncSink::NotReady(msg)) + } + } + + let len = msg.len() as u16; + if len > MAX_FRAME_SIZE { + return Err(io::Error::new( + io::ErrorKind::InvalidData, + "Maximum frame size exceeded.")) + } - #[inline] - fn start_send(&mut self, item: Self::SinkItem) -> StartSend { - self.inner.start_send(item) + self.write_buffer.put_slice(uvi::encode::u16(len, &mut [0; 3])); + self.write_buffer.extend(msg); + + Ok(AsyncSink::Ready) } - #[inline] fn poll_complete(&mut self) -> Poll<(), Self::SinkError> { - self.inner.poll_complete() + while !self.write_buffer.is_empty() { + let n = try_ready!(self.inner.poll_write(&self.write_buffer)); + + if n == 0 { + return Err(io::Error::new( + io::ErrorKind::WriteZero, + "Failed to write buffered frame.")) + } + + let _ = self.write_buffer.split_to(n); + } + + try_ready!(self.inner.poll_flush()); + return Ok(Async::Ready(())); } - #[inline] fn close(&mut self) -> Poll<(), Self::SinkError> { - self.inner.close() + try_ready!(self.poll_complete()); + Ok(self.inner.shutdown()?) } } @@ -211,12 +225,11 @@ mod tests { use futures::{Future, Stream}; use crate::length_delimited::LengthDelimited; use std::io::{Cursor, ErrorKind}; - use unsigned_varint::codec::UviBytes; #[test] fn basic_read() { let data = vec![6, 9, 8, 7, 6, 5, 4]; - let framed = LengthDelimited::new(Cursor::new(data), UviBytes::>::default()); + let framed = LengthDelimited::new(Cursor::new(data)); let recved = framed.collect().wait().unwrap(); assert_eq!(recved, vec![vec![9, 8, 7, 6, 5, 4]]); } @@ -224,7 +237,7 @@ mod tests { #[test] fn basic_read_two() { let data = vec![6, 9, 8, 7, 6, 5, 4, 3, 9, 8, 7]; - let framed = LengthDelimited::new(Cursor::new(data), UviBytes::>::default()); + let framed = LengthDelimited::new(Cursor::new(data)); let recved = framed.collect().wait().unwrap(); assert_eq!(recved, vec![vec![9, 8, 7, 6, 5, 4], vec![9, 8, 7]]); } @@ -236,7 +249,7 @@ mod tests { let frame = (0..len).map(|n| (n & 0xff) as u8).collect::>(); let mut data = vec![(len & 0x7f) as u8 | 0x80, (len >> 7) as u8]; data.extend(frame.clone().into_iter()); - let framed = LengthDelimited::new(Cursor::new(data), UviBytes::>::default()); + let framed = LengthDelimited::new(Cursor::new(data)); let recved = framed .into_future() .map(|(m, _)| m) @@ -250,7 +263,7 @@ mod tests { fn packet_len_too_long() { let mut data = vec![0x81, 0x81, 0x1]; data.extend((0..16513).map(|_| 0)); - let framed = LengthDelimited::new(Cursor::new(data), UviBytes::>::default()); + let framed = LengthDelimited::new(Cursor::new(data)); let recved = framed .into_future() .map(|(m, _)| m) @@ -267,7 +280,7 @@ mod tests { #[test] fn empty_frames() { let data = vec![0, 0, 6, 9, 8, 7, 6, 5, 4, 0, 3, 9, 8, 7]; - let framed = LengthDelimited::new(Cursor::new(data), UviBytes::>::default()); + let framed = LengthDelimited::new(Cursor::new(data)); let recved = framed.collect().wait().unwrap(); assert_eq!( recved, @@ -284,7 +297,7 @@ mod tests { #[test] fn unexpected_eof_in_len() { let data = vec![0x89]; - let framed = LengthDelimited::new(Cursor::new(data), UviBytes::>::default()); + let framed = LengthDelimited::new(Cursor::new(data)); let recved = framed.collect().wait(); if let Err(io_err) = recved { assert_eq!(io_err.kind(), ErrorKind::UnexpectedEof) @@ -296,7 +309,7 @@ mod tests { #[test] fn unexpected_eof_in_data() { let data = vec![5]; - let framed = LengthDelimited::new(Cursor::new(data), UviBytes::>::default()); + let framed = LengthDelimited::new(Cursor::new(data)); let recved = framed.collect().wait(); if let Err(io_err) = recved { assert_eq!(io_err.kind(), ErrorKind::UnexpectedEof) @@ -308,7 +321,7 @@ mod tests { #[test] fn unexpected_eof_in_data2() { let data = vec![5, 9, 8, 7]; - let framed = LengthDelimited::new(Cursor::new(data), UviBytes::>::default()); + let framed = LengthDelimited::new(Cursor::new(data)); let recved = framed.collect().wait(); if let Err(io_err) = recved { assert_eq!(io_err.kind(), ErrorKind::UnexpectedEof) diff --git a/misc/multistream-select/src/lib.rs b/misc/multistream-select/src/lib.rs index 746f167608c..b2bc054bfff 100644 --- a/misc/multistream-select/src/lib.rs +++ b/misc/multistream-select/src/lib.rs @@ -21,7 +21,7 @@ //! # Multistream-select //! //! This crate implements the `multistream-select` protocol, which is the protocol used by libp2p -//! to negotiate which protocol to use with the remote. +//! to negotiate which protocol to use with the remote on a connection or substream. //! //! > **Note**: This crate is used by the internals of *libp2p*, and it is not required to //! > understand it in order to use *libp2p*. @@ -76,6 +76,7 @@ mod protocol; use futures::prelude::*; use std::io; +use tokio_io::{AsyncRead, AsyncWrite}; pub use self::dialer_select::{dialer_select_proto, DialerSelectFuture}; pub use self::error::ProtocolChoiceError; @@ -93,9 +94,9 @@ where } } -impl tokio_io::AsyncRead for Negotiated +impl AsyncRead for Negotiated where - TInner: tokio_io::AsyncRead + TInner: AsyncRead { unsafe fn prepare_uninitialized_buffer(&self, buf: &mut [u8]) -> bool { self.0.prepare_uninitialized_buffer(buf) @@ -119,9 +120,9 @@ where } } -impl tokio_io::AsyncWrite for Negotiated +impl AsyncWrite for Negotiated where - TInner: tokio_io::AsyncWrite + TInner: AsyncWrite { fn shutdown(&mut self) -> Poll<(), io::Error> { self.0.shutdown() diff --git a/misc/multistream-select/src/listener_select.rs b/misc/multistream-select/src/listener_select.rs index 59492dc0722..40ed92d057e 100644 --- a/misc/multistream-select/src/listener_select.rs +++ b/misc/multistream-select/src/listener_select.rs @@ -23,10 +23,10 @@ use futures::{prelude::*, sink, stream::StreamFuture}; use crate::protocol::{ - DialerToListenerMessage, + Request, + Response, Listener, ListenerFuture, - ListenerToDialerMessage }; use log::{debug, trace}; use std::mem; @@ -126,13 +126,13 @@ where Err((e, _)) => return Err(ProtocolChoiceError::from(e)) }; match msg { - Some(DialerToListenerMessage::ProtocolsListRequest) => { + Some(Request::ListProtocols) => { trace!("protocols list response: {:?}", protocols .into_iter() .map(|p| p.as_ref().into()) .collect::>>()); - let list = protocols.into_iter().collect(); - let msg = ListenerToDialerMessage::ProtocolsListResponse { list }; + let supported = protocols.into_iter().collect(); + let msg = Response::SupportedProtocols { protocols: supported }; let sender = listener.send(msg); self.inner = ListenerSelectState::Outgoing { sender, @@ -140,12 +140,12 @@ where outcome: None } } - Some(DialerToListenerMessage::ProtocolRequest { name }) => { + Some(Request::Protocol { name }) => { let mut outcome = None; - let mut send_back = ListenerToDialerMessage::NotAvailable; + let mut send_back = Response::ProtocolNotAvailable; for supported in &protocols { if name.as_ref() == supported.as_ref() { - send_back = ListenerToDialerMessage::ProtocolAck { + send_back = Response::Protocol { name: supported.clone() }; outcome = Some(supported); diff --git a/misc/multistream-select/src/protocol/dialer.rs b/misc/multistream-select/src/protocol/dialer.rs index 71c9d21077f..d2d732a46ac 100644 --- a/misc/multistream-select/src/protocol/dialer.rs +++ b/misc/multistream-select/src/protocol/dialer.rs @@ -20,23 +20,25 @@ //! Contains the `Dialer` wrapper, which allows raw communications with a listener. +use super::*; + use bytes::{BufMut, Bytes, BytesMut}; use crate::length_delimited::LengthDelimited; -use crate::protocol::DialerToListenerMessage; -use crate::protocol::ListenerToDialerMessage; -use crate::protocol::MultistreamSelectError; -use crate::protocol::MULTISTREAM_PROTOCOL_WITH_LF; +use crate::protocol::{Request, Response, MultistreamSelectError}; use futures::{prelude::*, sink, Async, StartSend, try_ready}; -use std::io; -use tokio_codec::Encoder; use tokio_io::{AsyncRead, AsyncWrite}; -use unsigned_varint::{decode, codec::Uvi}; +use std::marker; +use unsigned_varint as uvi; + +/// The maximum number of supported protocols that can be processed. +const MAX_PROTOCOLS: usize = 1000; /// Wraps around a `AsyncRead+AsyncWrite`. /// Assumes that we're on the dialer's side. Produces and accepts messages. pub struct Dialer { - inner: LengthDelimited>, - handshake_finished: bool + inner: LengthDelimited, + handshake_finished: bool, + _protocol_name: marker::PhantomData, } impl Dialer @@ -45,15 +47,16 @@ where N: AsRef<[u8]> { pub fn dial(inner: R) -> DialerFuture { - let codec = MessageEncoder(std::marker::PhantomData); - let sender = LengthDelimited::new(inner, codec); + let sender = LengthDelimited::new(inner); + let mut buf = BytesMut::new(); + let _ = Message::::Header.encode(&mut buf); DialerFuture { - inner: sender.send(Message::Header) + inner: sender.send(buf.freeze()), + _protocol_name: marker::PhantomData, } } /// Grants back the socket. Typically used after a `ProtocolAck` has been received. - #[inline] pub fn into_inner(self) -> R { self.inner.into_inner() } @@ -64,24 +67,22 @@ where R: AsyncRead + AsyncWrite, N: AsRef<[u8]> { - type SinkItem = DialerToListenerMessage; + type SinkItem = Request; type SinkError = MultistreamSelectError; - #[inline] fn start_send(&mut self, item: Self::SinkItem) -> StartSend { - match self.inner.start_send(Message::Body(item))? { - AsyncSink::NotReady(Message::Body(item)) => Ok(AsyncSink::NotReady(item)), - AsyncSink::NotReady(Message::Header) => unreachable!(), - AsyncSink::Ready => Ok(AsyncSink::Ready) + let mut msg = BytesMut::new(); + Message::Body(&item).encode(&mut msg)?; + match self.inner.start_send(msg.freeze())? { + AsyncSink::NotReady(_) => Ok(AsyncSink::NotReady(item)), + AsyncSink::Ready => Ok(AsyncSink::Ready), } } - #[inline] fn poll_complete(&mut self) -> Poll<(), Self::SinkError> { Ok(self.inner.poll_complete()?) } - #[inline] fn close(&mut self) -> Poll<(), Self::SinkError> { Ok(self.inner.close()?) } @@ -91,20 +92,20 @@ impl Stream for Dialer where R: AsyncRead + AsyncWrite { - type Item = ListenerToDialerMessage; + type Item = Response; type Error = MultistreamSelectError; fn poll(&mut self) -> Poll, Self::Error> { loop { - let mut frame = match self.inner.poll() { - Ok(Async::Ready(Some(frame))) => frame, + let mut msg = match self.inner.poll() { + Ok(Async::Ready(Some(msg))) => msg, Ok(Async::Ready(None)) => return Ok(Async::Ready(None)), Ok(Async::NotReady) => return Ok(Async::NotReady), Err(err) => return Err(err.into()), }; if !self.handshake_finished { - if frame == MULTISTREAM_PROTOCOL_WITH_LF { + if msg == MSG_MULTISTREAM_1_0 { self.handshake_finished = true; continue; } else { @@ -112,31 +113,31 @@ where } } - if frame.get(0) == Some(&b'/') && frame.last() == Some(&b'\n') { - let frame_len = frame.len(); - let protocol = frame.split_to(frame_len - 1); - return Ok(Async::Ready(Some(ListenerToDialerMessage::ProtocolAck { - name: protocol - }))); - } else if frame == b"na\n"[..] { - return Ok(Async::Ready(Some(ListenerToDialerMessage::NotAvailable))); + if msg.get(0) == Some(&b'/') && msg.last() == Some(&b'\n') { + let len = msg.len(); + let name = msg.split_to(len - 1); + return Ok(Async::Ready(Some( + Response::Protocol { name } + ))); + } else if msg == MSG_PROTOCOL_NA { + return Ok(Async::Ready(Some(Response::ProtocolNotAvailable))); } else { // A varint number of protocols - let (num_protocols, mut remaining) = decode::usize(&frame)?; - if num_protocols > 1000 { // TODO: configurable limit - return Err(MultistreamSelectError::VarintParseError("too many protocols".into())) + let (num_protocols, mut remaining) = uvi::decode::usize(&msg)?; + if num_protocols > MAX_PROTOCOLS { // TODO: configurable limit + return Err(MultistreamSelectError::TooManyProtocols) } - let mut out = Vec::with_capacity(num_protocols); + let mut protocols = Vec::with_capacity(num_protocols); for _ in 0 .. num_protocols { - let (len, rem) = decode::usize(remaining)?; + let (len, rem) = uvi::decode::usize(remaining)?; if len == 0 || len > rem.len() || rem[len - 1] != b'\n' { return Err(MultistreamSelectError::UnknownMessage) } - out.push(Bytes::from(&rem[.. len - 1])); + protocols.push(Bytes::from(&rem[.. len - 1])); remaining = &rem[len ..] } return Ok(Async::Ready(Some( - ListenerToDialerMessage::ProtocolsListResponse { list: out }, + Response::SupportedProtocols { protocols }, ))); } } @@ -145,7 +146,8 @@ where /// Future, returned by `Dialer::new`, which send the handshake and returns the actual `Dialer`. pub struct DialerFuture> { - inner: sink::Send>> + inner: sink::Send>, + _protocol_name: marker::PhantomData, } impl> Future for DialerFuture { @@ -154,48 +156,40 @@ impl> Future for DialerFuture { fn poll(&mut self) -> Poll { let inner = try_ready!(self.inner.poll()); - Ok(Async::Ready(Dialer { inner, handshake_finished: false })) + Ok(Async::Ready(Dialer { + inner, + handshake_finished: false, + _protocol_name: marker::PhantomData, + })) } } -/// tokio-codec `Encoder` handling `DialerToListenerMessage` values. -struct MessageEncoder(std::marker::PhantomData); - -enum Message { +enum Message<'a, N> { Header, - Body(DialerToListenerMessage) + Body(&'a Request) } -impl> Encoder for MessageEncoder { - type Item = Message; - type Error = MultistreamSelectError; - - fn encode(&mut self, item: Self::Item, dest: &mut BytesMut) -> Result<(), Self::Error> { - match item { +impl> Message<'_, N> { + fn encode(&self, dest: &mut BytesMut) -> Result<(), MultistreamSelectError> { + match self { Message::Header => { - Uvi::::default().encode(MULTISTREAM_PROTOCOL_WITH_LF.len(), dest)?; - dest.reserve(MULTISTREAM_PROTOCOL_WITH_LF.len()); - dest.put(MULTISTREAM_PROTOCOL_WITH_LF); + dest.reserve(MSG_MULTISTREAM_1_0.len()); + dest.put(MSG_MULTISTREAM_1_0); Ok(()) } - Message::Body(DialerToListenerMessage::ProtocolRequest { name }) => { + Message::Body(Request::Protocol { name }) => { if !name.as_ref().starts_with(b"/") { - return Err(MultistreamSelectError::WrongProtocolName) + return Err(MultistreamSelectError::InvalidProtocolName) } let len = name.as_ref().len() + 1; // + 1 for \n - if len > std::u16::MAX as usize { - return Err(io::Error::new(io::ErrorKind::InvalidData, "name too long").into()) - } - Uvi::::default().encode(len, dest)?; dest.reserve(len); dest.put(name.as_ref()); dest.put(&b"\n"[..]); Ok(()) } - Message::Body(DialerToListenerMessage::ProtocolsListRequest) => { - Uvi::::default().encode(3, dest)?; - dest.reserve(3); - dest.put(&b"ls\n"[..]); + Message::Body(Request::ListProtocols) => { + dest.reserve(MSG_LS.len()); + dest.put(MSG_LS); Ok(()) } } @@ -204,7 +198,7 @@ impl> Encoder for MessageEncoder { #[cfg(test)] mod tests { - use crate::protocol::{Dialer, DialerToListenerMessage, MultistreamSelectError}; + use super::*; use tokio::runtime::current_thread::Runtime; use tokio_tcp::{TcpListener, TcpStream}; use futures::Future; @@ -225,13 +219,13 @@ mod tests { .from_err() .and_then(move |stream| Dialer::dial(stream)) .and_then(move |dialer| { - let p = b"invalid_name"; - dialer.send(DialerToListenerMessage::ProtocolRequest { name: p }) + let name = b"invalid_name"; + dialer.send(Request::Protocol { name }) }); let mut rt = Runtime::new().unwrap(); match rt.block_on(server.join(client)) { - Err(MultistreamSelectError::WrongProtocolName) => (), + Err(MultistreamSelectError::InvalidProtocolName) => (), _ => panic!(), } } diff --git a/misc/multistream-select/src/protocol/error.rs b/misc/multistream-select/src/protocol/error.rs index f3b859f15da..f6686ee9fbd 100644 --- a/misc/multistream-select/src/protocol/error.rs +++ b/misc/multistream-select/src/protocol/error.rs @@ -20,7 +20,7 @@ //! Contains the error structs for the low-level protocol handling. -use std::error; +use std::error::Error; use std::fmt; use std::io; use unsigned_varint::decode; @@ -38,29 +38,25 @@ pub enum MultistreamSelectError { UnknownMessage, /// Protocol names must always start with `/`, otherwise this error is returned. - WrongProtocolName, + InvalidProtocolName, - /// Failure to parse variable-length integer. - // TODO: we don't include the actual error, because that would remove Send from the enum - VarintParseError(String), + /// Too many protocols have been returned by the remote. + TooManyProtocols, } impl From for MultistreamSelectError { - #[inline] fn from(err: io::Error) -> MultistreamSelectError { MultistreamSelectError::IoError(err) } } impl From for MultistreamSelectError { - #[inline] fn from(err: decode::Error) -> MultistreamSelectError { - MultistreamSelectError::VarintParseError(err.to_string()) + Self::from(io::Error::new(io::ErrorKind::InvalidData, err.to_string())) } } -impl error::Error for MultistreamSelectError { - #[inline] +impl Error for MultistreamSelectError { fn description(&self) -> &str { match *self { MultistreamSelectError::IoError(_) => "I/O error", @@ -68,16 +64,15 @@ impl error::Error for MultistreamSelectError { "the remote doesn't use the same multistream-select protocol as we do" } MultistreamSelectError::UnknownMessage => "received an unknown message from the remote", - MultistreamSelectError::WrongProtocolName => { + MultistreamSelectError::InvalidProtocolName => { "protocol names must always start with `/`, otherwise this error is returned" } - MultistreamSelectError::VarintParseError(_) => { - "failure to parse variable-length integer" - } + MultistreamSelectError::TooManyProtocols => + "Too many protocols." } } - fn cause(&self) -> Option<&dyn error::Error> { + fn source(&self) -> Option<&(dyn Error + 'static)> { match *self { MultistreamSelectError::IoError(ref err) => Some(err), _ => None, @@ -86,8 +81,7 @@ impl error::Error for MultistreamSelectError { } impl fmt::Display for MultistreamSelectError { - #[inline] fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> Result<(), fmt::Error> { - write!(fmt, "{}", error::Error::description(self)) + write!(fmt, "{}", Error::description(self)) } } diff --git a/misc/multistream-select/src/protocol/listener.rs b/misc/multistream-select/src/protocol/listener.rs index 7616e8eac8f..ef1a962a3ea 100644 --- a/misc/multistream-select/src/protocol/listener.rs +++ b/misc/multistream-select/src/protocol/listener.rs @@ -20,23 +20,22 @@ //! Contains the `Listener` wrapper, which allows raw communications with a dialer. +use super::*; + use bytes::{BufMut, Bytes, BytesMut}; use crate::length_delimited::LengthDelimited; -use crate::protocol::DialerToListenerMessage; -use crate::protocol::ListenerToDialerMessage; -use crate::protocol::MultistreamSelectError; -use crate::protocol::MULTISTREAM_PROTOCOL_WITH_LF; +use crate::protocol::{Request, Response, MultistreamSelectError}; use futures::{prelude::*, sink, stream::StreamFuture}; use log::{debug, trace}; -use std::{io, mem}; -use tokio_codec::Encoder; +use std::{marker, mem}; use tokio_io::{AsyncRead, AsyncWrite}; -use unsigned_varint::{encode, codec::Uvi}; +use unsigned_varint as uvi; /// Wraps around a `AsyncRead+AsyncWrite`. Assumes that we're on the listener's side. Produces and /// accepts messages. pub struct Listener { - inner: LengthDelimited> + inner: LengthDelimited, + _protocol_name: marker::PhantomData, } impl Listener @@ -47,10 +46,10 @@ where /// Takes ownership of a socket and starts the handshake. If the handshake succeeds, the /// future returns a `Listener`. pub fn listen(inner: R) -> ListenerFuture { - let codec = MessageEncoder(std::marker::PhantomData); - let inner = LengthDelimited::new(inner, codec); + let inner = LengthDelimited::new(inner); ListenerFuture { - inner: ListenerFutureState::Await { inner: inner.into_future() } + inner: ListenerFutureState::Await { inner: inner.into_future() }, + _protocol_name: marker::PhantomData, } } @@ -67,24 +66,22 @@ where R: AsyncRead + AsyncWrite, N: AsRef<[u8]> { - type SinkItem = ListenerToDialerMessage; + type SinkItem = Response; type SinkError = MultistreamSelectError; - #[inline] fn start_send(&mut self, item: Self::SinkItem) -> StartSend { - match self.inner.start_send(Message::Body(item))? { - AsyncSink::NotReady(Message::Body(item)) => Ok(AsyncSink::NotReady(item)), - AsyncSink::NotReady(Message::Header) => unreachable!(), + let mut msg = BytesMut::new(); + Message::Body(&item).encode(&mut msg)?; + match self.inner.start_send(msg.freeze())? { + AsyncSink::NotReady(_) => Ok(AsyncSink::NotReady(item)), AsyncSink::Ready => Ok(AsyncSink::Ready) } } - #[inline] fn poll_complete(&mut self) -> Poll<(), Self::SinkError> { Ok(self.inner.poll_complete()?) } - #[inline] fn close(&mut self) -> Poll<(), Self::SinkError> { Ok(self.inner.close()?) } @@ -94,26 +91,26 @@ impl Stream for Listener where R: AsyncRead + AsyncWrite, { - type Item = DialerToListenerMessage; + type Item = Request; type Error = MultistreamSelectError; fn poll(&mut self) -> Poll, Self::Error> { - let mut frame = match self.inner.poll() { - Ok(Async::Ready(Some(frame))) => frame, + let mut msg = match self.inner.poll() { + Ok(Async::Ready(Some(msg))) => msg, Ok(Async::Ready(None)) => return Ok(Async::Ready(None)), Ok(Async::NotReady) => return Ok(Async::NotReady), Err(err) => return Err(err.into()), }; - if frame.get(0) == Some(&b'/') && frame.last() == Some(&b'\n') { - let frame_len = frame.len(); - let protocol = frame.split_to(frame_len - 1); + if msg.get(0) == Some(&b'/') && msg.last() == Some(&b'\n') { + let len = msg.len(); + let name = msg.split_to(len - 1); Ok(Async::Ready(Some( - DialerToListenerMessage::ProtocolRequest { name: protocol }, + Request::Protocol { name }, ))) - } else if frame == b"ls\n"[..] { + } else if msg == MSG_LS { Ok(Async::Ready(Some( - DialerToListenerMessage::ProtocolsListRequest, + Request::ListProtocols, ))) } else { Err(MultistreamSelectError::UnknownMessage) @@ -124,16 +121,17 @@ where /// Future, returned by `Listener::new` which performs the handshake and returns /// the `Listener` if successful. -pub struct ListenerFuture> { - inner: ListenerFutureState +pub struct ListenerFuture { + inner: ListenerFutureState, + _protocol_name: marker::PhantomData, } -enum ListenerFutureState> { +enum ListenerFutureState { Await { - inner: StreamFuture>> + inner: StreamFuture> }, Reply { - sender: sink::Send>> + sender: sink::Send> }, Undefined } @@ -155,12 +153,14 @@ impl> Future for ListenerFuture } Err((e, _)) => return Err(MultistreamSelectError::from(e)) }; - if msg.as_ref().map(|b| &b[..]) != Some(MULTISTREAM_PROTOCOL_WITH_LF) { - debug!("failed handshake; received: {:?}", msg); + if msg.as_ref().map(|b| &b[..]) != Some(MSG_MULTISTREAM_1_0) { + debug!("Unexpected message: {:?}", msg); return Err(MultistreamSelectError::FailedHandshake) } trace!("sending back /multistream/ to finish the handshake"); - let sender = socket.send(Message::Header); + let mut frame = BytesMut::new(); + Message::::Header.encode(&mut frame)?; + let sender = socket.send(frame.freeze()); self.inner = ListenerFutureState::Reply { sender } } ListenerFutureState::Reply { mut sender } => { @@ -171,69 +171,57 @@ impl> Future for ListenerFuture return Ok(Async::NotReady) } }; - return Ok(Async::Ready(Listener { inner: listener })) + return Ok(Async::Ready(Listener { + inner: listener, + _protocol_name: marker::PhantomData + })) } - ListenerFutureState::Undefined => panic!("ListenerFutureState::poll called after completion") + ListenerFutureState::Undefined => + panic!("ListenerFutureState::poll called after completion") } } } } -/// tokio-codec `Encoder` handling `ListenerToDialerMessage` values. -struct MessageEncoder(std::marker::PhantomData); - -enum Message { +enum Message<'a, N> { Header, - Body(ListenerToDialerMessage) + Body(&'a Response) } -impl> Encoder for MessageEncoder { - type Item = Message; - type Error = MultistreamSelectError; +impl> Message<'_, N> { - fn encode(&mut self, item: Self::Item, dest: &mut BytesMut) -> Result<(), Self::Error> { - match item { + fn encode(&self, dest: &mut BytesMut) -> Result<(), MultistreamSelectError> { + match self { Message::Header => { - Uvi::::default().encode(MULTISTREAM_PROTOCOL_WITH_LF.len(), dest)?; - dest.reserve(MULTISTREAM_PROTOCOL_WITH_LF.len()); - dest.put(MULTISTREAM_PROTOCOL_WITH_LF); + dest.reserve(MSG_MULTISTREAM_1_0.len()); + dest.put(MSG_MULTISTREAM_1_0); Ok(()) } - Message::Body(ListenerToDialerMessage::ProtocolAck { name }) => { + Message::Body(Response::Protocol { name }) => { if !name.as_ref().starts_with(b"/") { - return Err(MultistreamSelectError::WrongProtocolName) + return Err(MultistreamSelectError::InvalidProtocolName) } let len = name.as_ref().len() + 1; // + 1 for \n - if len > std::u16::MAX as usize { - return Err(io::Error::new(io::ErrorKind::InvalidData, "name too long").into()) - } - Uvi::::default().encode(len, dest)?; dest.reserve(len); dest.put(name.as_ref()); dest.put(&b"\n"[..]); Ok(()) } - Message::Body(ListenerToDialerMessage::ProtocolsListResponse { list }) => { - let mut buf = encode::usize_buffer(); - let mut out_msg = Vec::from(encode::usize(list.len(), &mut buf)); - for e in &list { - if e.as_ref().len() + 1 > std::u16::MAX as usize { - return Err(io::Error::new(io::ErrorKind::InvalidData, "name too long").into()) - } - out_msg.extend(encode::usize(e.as_ref().len() + 1, &mut buf)); // +1 for '\n' - out_msg.extend_from_slice(e.as_ref()); + Message::Body(Response::SupportedProtocols { protocols }) => { + let mut buf = uvi::encode::usize_buffer(); + let mut out_msg = Vec::from(uvi::encode::usize(protocols.len(), &mut buf)); + for p in protocols { + out_msg.extend(uvi::encode::usize(p.as_ref().len() + 1, &mut buf)); // +1 for '\n' + out_msg.extend_from_slice(p.as_ref()); out_msg.push(b'\n') } - let len = encode::usize(out_msg.len(), &mut buf); - dest.reserve(len.len() + out_msg.len()); - dest.put(len); + dest.reserve(out_msg.len()); dest.put(out_msg); Ok(()) } - Message::Body(ListenerToDialerMessage::NotAvailable) => { - Uvi::::default().encode(3, dest)?; - dest.reserve(3); - dest.put(&b"na\n"[..]); + Message::Body(Response::ProtocolNotAvailable) => { + dest.reserve(MSG_PROTOCOL_NA.len()); + dest.put(MSG_PROTOCOL_NA); Ok(()) } } @@ -242,12 +230,12 @@ impl> Encoder for MessageEncoder { #[cfg(test)] mod tests { + use super::*; use tokio::runtime::current_thread::Runtime; use tokio_tcp::{TcpListener, TcpStream}; use bytes::Bytes; use futures::Future; use futures::{Sink, Stream}; - use crate::protocol::{Dialer, Listener, ListenerToDialerMessage, MultistreamSelectError}; #[test] fn wrong_proto_name() { @@ -260,8 +248,8 @@ mod tests { .map_err(|(e, _)| e.into()) .and_then(move |(connec, _)| Listener::listen(connec.unwrap())) .and_then(|listener| { - let proto_name = Bytes::from("invalid-proto"); - listener.send(ListenerToDialerMessage::ProtocolAck { name: proto_name }) + let name = Bytes::from("invalid-proto"); + listener.send(Response::Protocol { name }) }); let client = TcpStream::connect(&listener_addr) @@ -270,7 +258,7 @@ mod tests { let mut rt = Runtime::new().unwrap(); match rt.block_on(server.join(client)) { - Err(MultistreamSelectError::WrongProtocolName) => (), + Err(MultistreamSelectError::InvalidProtocolName) => (), _ => panic!(), } } diff --git a/misc/multistream-select/src/protocol/mod.rs b/misc/multistream-select/src/protocol/mod.rs index 7e840b31c33..738f8872234 100644 --- a/misc/multistream-select/src/protocol/mod.rs +++ b/misc/multistream-select/src/protocol/mod.rs @@ -20,47 +20,49 @@ //! Contains lower-level structs to handle the multistream protocol. +const MSG_MULTISTREAM_1_0: &[u8] = b"/multistream/1.0.0\n"; +const MSG_PROTOCOL_NA: &[u8] = b"na\n"; +const MSG_LS: &[u8] = b"ls\n"; + mod dialer; mod error; mod listener; -const MULTISTREAM_PROTOCOL_WITH_LF: &[u8] = b"/multistream/1.0.0\n"; - pub use self::dialer::{Dialer, DialerFuture}; pub use self::error::MultistreamSelectError; pub use self::listener::{Listener, ListenerFuture}; /// Message sent from the dialer to the listener. #[derive(Debug, Clone, PartialEq, Eq)] -pub enum DialerToListenerMessage { +pub enum Request { /// The dialer wants us to use a protocol. /// /// If this is accepted (by receiving back a `ProtocolAck`), then we immediately start /// communicating in the new protocol. - ProtocolRequest { + Protocol { /// Name of the protocol. name: N }, /// The dialer requested the list of protocols that the listener supports. - ProtocolsListRequest, + ListProtocols, } /// Message sent from the listener to the dialer. #[derive(Debug, Clone, PartialEq, Eq)] -pub enum ListenerToDialerMessage { +pub enum Response { /// The protocol requested by the dialer is accepted. The socket immediately starts using the /// new protocol. - ProtocolAck { name: N }, + Protocol { name: N }, /// The protocol requested by the dialer is not supported or available. - NotAvailable, + ProtocolNotAvailable, /// Response to the request for the list of protocols. - ProtocolsListResponse { + SupportedProtocols { /// The list of protocols. // TODO: use some sort of iterator - list: Vec, + protocols: Vec, }, } diff --git a/misc/multistream-select/src/tests.rs b/misc/multistream-select/src/tests.rs index 0feba85469b..dbfc0588d7e 100644 --- a/misc/multistream-select/src/tests.rs +++ b/misc/multistream-select/src/tests.rs @@ -24,7 +24,7 @@ use crate::ProtocolChoiceError; use crate::dialer_select::{dialer_select_proto_parallel, dialer_select_proto_serial}; -use crate::protocol::{Dialer, DialerToListenerMessage, Listener, ListenerToDialerMessage}; +use crate::protocol::{Dialer, Request, Listener, Response}; use crate::{dialer_select_proto, listener_select_proto}; use futures::prelude::*; use tokio::runtime::current_thread::Runtime; @@ -56,23 +56,23 @@ fn negotiate_with_self_succeeds() { .and_then(|l| l.into_future().map_err(|(e, _)| e)) .and_then(|(msg, rest)| { let proto = match msg { - Some(DialerToListenerMessage::ProtocolRequest { name }) => name, + Some(Request::Protocol { name }) => name, _ => panic!(), }; - rest.send(ListenerToDialerMessage::ProtocolAck { name: proto }) + rest.send(Response::Protocol { name: proto }) }); let client = TcpStream::connect(&listener_addr) .from_err() .and_then(move |stream| Dialer::dial(stream)) .and_then(move |dialer| { - let p = b"/hello/1.0.0"; - dialer.send(DialerToListenerMessage::ProtocolRequest { name: p }) + let name = b"/hello/1.0.0"; + dialer.send(Request::Protocol { name }) }) .and_then(move |dialer| dialer.into_future().map_err(|(e, _)| e)) .and_then(move |(msg, _)| { let proto = match msg { - Some(ListenerToDialerMessage::ProtocolAck { name }) => name, + Some(Response::Protocol { name }) => name, _ => panic!(), }; assert_eq!(proto, "/hello/1.0.0"); From d292aacb7beefebb3774e807d5bd3114f1a36b7a Mon Sep 17 00:00:00 2001 From: "Roman S. Borschel" Date: Tue, 16 Jul 2019 16:32:42 +0200 Subject: [PATCH 2/5] Reduce default buffer sizes. --- misc/multistream-select/src/length_delimited.rs | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/misc/multistream-select/src/length_delimited.rs b/misc/multistream-select/src/length_delimited.rs index 44a8ff36090..45caaecbc2d 100644 --- a/misc/multistream-select/src/length_delimited.rs +++ b/misc/multistream-select/src/length_delimited.rs @@ -26,6 +26,7 @@ use unsigned_varint as uvi; const MAX_LEN_BYTES: u16 = 2; const MAX_FRAME_SIZE: u16 = (1 << (MAX_LEN_BYTES * 8 - MAX_LEN_BYTES)) - 1; +const DEFAULT_BUFFER_SIZE: usize = 64; /// `Stream` and `Sink` wrapping some `AsyncRead + AsyncWrite` resource to read /// and write unsigned-varint prefixed frames. @@ -69,8 +70,8 @@ impl LengthDelimited { LengthDelimited { inner, read_state: ReadState::default(), - read_buffer: BytesMut::with_capacity(MAX_FRAME_SIZE as usize), - write_buffer: BytesMut::with_capacity((MAX_FRAME_SIZE + MAX_LEN_BYTES) as usize), + read_buffer: BytesMut::with_capacity(DEFAULT_BUFFER_SIZE), + write_buffer: BytesMut::with_capacity(DEFAULT_BUFFER_SIZE + MAX_LEN_BYTES as usize), } } @@ -191,8 +192,11 @@ where "Maximum frame size exceeded.")) } - self.write_buffer.put_slice(uvi::encode::u16(len, &mut [0; 3])); - self.write_buffer.extend(msg); + let mut uvi_buf = uvi::encode::u16_buffer(); + let uvi_len = uvi::encode::u16(len, &mut uvi_buf); + self.write_buffer.reserve(len as usize + uvi_len.len()); + self.write_buffer.put(uvi_len); + self.write_buffer.put(msg); Ok(AsyncSink::Ready) } From 402f1a25cfbee7b5a3ea7b25016073a093654994 Mon Sep 17 00:00:00 2001 From: "Roman S. Borschel" Date: Tue, 16 Jul 2019 16:54:20 +0200 Subject: [PATCH 3/5] Allow more than one frame to be buffered for sending. --- misc/multistream-select/src/length_delimited.rs | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/misc/multistream-select/src/length_delimited.rs b/misc/multistream-select/src/length_delimited.rs index 45caaecbc2d..51493c33888 100644 --- a/misc/multistream-select/src/length_delimited.rs +++ b/misc/multistream-select/src/length_delimited.rs @@ -178,9 +178,12 @@ where type SinkError = io::Error; fn start_send(&mut self, msg: Self::SinkItem) -> StartSend { - if !self.write_buffer.is_empty() { + // Use the maximum frame length also as a (soft) upper limit + // for the entire write buffer. The actual (hard) limit is thus + // implied to be roughly 2 * MAX_FRAME_SIZE. + if self.write_buffer.len() >= MAX_FRAME_SIZE as usize { self.poll_complete()?; - if !self.write_buffer.is_empty() { + if self.write_buffer.len() >= MAX_FRAME_SIZE as usize { return Ok(AsyncSink::NotReady(msg)) } } From ae2f5099d4a7b6ea01c6355f9f7bc0cbbb248a7a Mon Sep 17 00:00:00 2001 From: "Roman S. Borschel" Date: Tue, 16 Jul 2019 17:07:37 +0200 Subject: [PATCH 4/5] Doc tweaks. --- misc/multistream-select/src/length_delimited.rs | 15 +++++++-------- 1 file changed, 7 insertions(+), 8 deletions(-) diff --git a/misc/multistream-select/src/length_delimited.rs b/misc/multistream-select/src/length_delimited.rs index 51493c33888..4d4f2c02bb4 100644 --- a/misc/multistream-select/src/length_delimited.rs +++ b/misc/multistream-select/src/length_delimited.rs @@ -37,9 +37,9 @@ const DEFAULT_BUFFER_SIZE: usize = 64; pub struct LengthDelimited { /// The inner I/O resource. inner: R, - /// Read buffer for a single unsigned-varint length-delimited frame. + /// Read buffer for a single incoming unsigned-varint length-delimited frame. read_buffer: BytesMut, - /// Write buffer for a single unsigned-varint length-delimited frame. + /// Write buffer for outgoing unsigned-varint length-delimited frames. write_buffer: BytesMut, /// The current read state, alternating between reading a frame /// length and reading a frame payload. @@ -77,15 +77,14 @@ impl LengthDelimited { /// Destroys the `LengthDelimited` and returns the underlying socket. /// - /// Contrary to its equivalent `tokio_io::codec::length_delimited::FramedRead`, this method is - /// guaranteed not to skip any data from the socket. + /// This method is guaranteed not to skip any data from the socket. /// /// # Panic /// - /// Will panic if called while there is data inside the buffer. **This can only happen if - /// you call `poll()` manually**. Using this struct as it is intended to be used (i.e. through - /// the modifiers provided by the `futures` crate) will always leave the object in a state in - /// which `into_inner()` will not panic. + /// Will panic if called while there is data inside the read or write buffer. + /// **This can only happen if you call `poll()` manually**. Using this struct + /// as it is intended to be used (i.e. through the high-level `futures` API) + /// will always leave the object in a state in which `into_inner()` will not panic. pub fn into_inner(self) -> R { assert!(self.write_buffer.is_empty()); assert!(self.read_buffer.is_empty()); From 05e33338ba1a81444e1770114cdae9002010237e Mon Sep 17 00:00:00 2001 From: "Roman S. Borschel" Date: Wed, 17 Jul 2019 15:35:25 +0200 Subject: [PATCH 5/5] Remove superfluous (duplicated) Message types. --- .../multistream-select/src/protocol/dialer.rs | 46 ++--------- .../src/protocol/listener.rs | 57 ++------------ misc/multistream-select/src/protocol/mod.rs | 76 +++++++++++++++++++ 3 files changed, 88 insertions(+), 91 deletions(-) diff --git a/misc/multistream-select/src/protocol/dialer.rs b/misc/multistream-select/src/protocol/dialer.rs index d2d732a46ac..28da191a490 100644 --- a/misc/multistream-select/src/protocol/dialer.rs +++ b/misc/multistream-select/src/protocol/dialer.rs @@ -22,7 +22,7 @@ use super::*; -use bytes::{BufMut, Bytes, BytesMut}; +use bytes::{Bytes, BytesMut}; use crate::length_delimited::LengthDelimited; use crate::protocol::{Request, Response, MultistreamSelectError}; use futures::{prelude::*, sink, Async, StartSend, try_ready}; @@ -47,11 +47,11 @@ where N: AsRef<[u8]> { pub fn dial(inner: R) -> DialerFuture { - let sender = LengthDelimited::new(inner); + let io = LengthDelimited::new(inner); let mut buf = BytesMut::new(); - let _ = Message::::Header.encode(&mut buf); + Header::Multistream10.encode(&mut buf); DialerFuture { - inner: sender.send(buf.freeze()), + inner: io.send(buf.freeze()), _protocol_name: marker::PhantomData, } } @@ -70,11 +70,11 @@ where type SinkItem = Request; type SinkError = MultistreamSelectError; - fn start_send(&mut self, item: Self::SinkItem) -> StartSend { + fn start_send(&mut self, request: Self::SinkItem) -> StartSend { let mut msg = BytesMut::new(); - Message::Body(&item).encode(&mut msg)?; + request.encode(&mut msg)?; match self.inner.start_send(msg.freeze())? { - AsyncSink::NotReady(_) => Ok(AsyncSink::NotReady(item)), + AsyncSink::NotReady(_) => Ok(AsyncSink::NotReady(request)), AsyncSink::Ready => Ok(AsyncSink::Ready), } } @@ -164,38 +164,6 @@ impl> Future for DialerFuture { } } -enum Message<'a, N> { - Header, - Body(&'a Request) -} - -impl> Message<'_, N> { - fn encode(&self, dest: &mut BytesMut) -> Result<(), MultistreamSelectError> { - match self { - Message::Header => { - dest.reserve(MSG_MULTISTREAM_1_0.len()); - dest.put(MSG_MULTISTREAM_1_0); - Ok(()) - } - Message::Body(Request::Protocol { name }) => { - if !name.as_ref().starts_with(b"/") { - return Err(MultistreamSelectError::InvalidProtocolName) - } - let len = name.as_ref().len() + 1; // + 1 for \n - dest.reserve(len); - dest.put(name.as_ref()); - dest.put(&b"\n"[..]); - Ok(()) - } - Message::Body(Request::ListProtocols) => { - dest.reserve(MSG_LS.len()); - dest.put(MSG_LS); - Ok(()) - } - } - } -} - #[cfg(test)] mod tests { use super::*; diff --git a/misc/multistream-select/src/protocol/listener.rs b/misc/multistream-select/src/protocol/listener.rs index ef1a962a3ea..243304edcff 100644 --- a/misc/multistream-select/src/protocol/listener.rs +++ b/misc/multistream-select/src/protocol/listener.rs @@ -22,14 +22,13 @@ use super::*; -use bytes::{BufMut, Bytes, BytesMut}; +use bytes::{Bytes, BytesMut}; use crate::length_delimited::LengthDelimited; use crate::protocol::{Request, Response, MultistreamSelectError}; use futures::{prelude::*, sink, stream::StreamFuture}; use log::{debug, trace}; use std::{marker, mem}; use tokio_io::{AsyncRead, AsyncWrite}; -use unsigned_varint as uvi; /// Wraps around a `AsyncRead+AsyncWrite`. Assumes that we're on the listener's side. Produces and /// accepts messages. @@ -55,7 +54,6 @@ where /// Grants back the socket. Typically used after a `ProtocolRequest` has been received and a /// `ProtocolAck` has been sent back. - #[inline] pub fn into_inner(self) -> R { self.inner.into_inner() } @@ -69,11 +67,11 @@ where type SinkItem = Response; type SinkError = MultistreamSelectError; - fn start_send(&mut self, item: Self::SinkItem) -> StartSend { + fn start_send(&mut self, response: Self::SinkItem) -> StartSend { let mut msg = BytesMut::new(); - Message::Body(&item).encode(&mut msg)?; + response.encode(&mut msg)?; match self.inner.start_send(msg.freeze())? { - AsyncSink::NotReady(_) => Ok(AsyncSink::NotReady(item)), + AsyncSink::NotReady(_) => Ok(AsyncSink::NotReady(response)), AsyncSink::Ready => Ok(AsyncSink::Ready) } } @@ -159,7 +157,7 @@ impl> Future for ListenerFuture } trace!("sending back /multistream/ to finish the handshake"); let mut frame = BytesMut::new(); - Message::::Header.encode(&mut frame)?; + Header::Multistream10.encode(&mut frame); let sender = socket.send(frame.freeze()); self.inner = ListenerFutureState::Reply { sender } } @@ -183,51 +181,6 @@ impl> Future for ListenerFuture } } -enum Message<'a, N> { - Header, - Body(&'a Response) -} - -impl> Message<'_, N> { - - fn encode(&self, dest: &mut BytesMut) -> Result<(), MultistreamSelectError> { - match self { - Message::Header => { - dest.reserve(MSG_MULTISTREAM_1_0.len()); - dest.put(MSG_MULTISTREAM_1_0); - Ok(()) - } - Message::Body(Response::Protocol { name }) => { - if !name.as_ref().starts_with(b"/") { - return Err(MultistreamSelectError::InvalidProtocolName) - } - let len = name.as_ref().len() + 1; // + 1 for \n - dest.reserve(len); - dest.put(name.as_ref()); - dest.put(&b"\n"[..]); - Ok(()) - } - Message::Body(Response::SupportedProtocols { protocols }) => { - let mut buf = uvi::encode::usize_buffer(); - let mut out_msg = Vec::from(uvi::encode::usize(protocols.len(), &mut buf)); - for p in protocols { - out_msg.extend(uvi::encode::usize(p.as_ref().len() + 1, &mut buf)); // +1 for '\n' - out_msg.extend_from_slice(p.as_ref()); - out_msg.push(b'\n') - } - dest.reserve(out_msg.len()); - dest.put(out_msg); - Ok(()) - } - Message::Body(Response::ProtocolNotAvailable) => { - dest.reserve(MSG_PROTOCOL_NA.len()); - dest.put(MSG_PROTOCOL_NA); - Ok(()) - } - } - } -} - #[cfg(test)] mod tests { use super::*; diff --git a/misc/multistream-select/src/protocol/mod.rs b/misc/multistream-select/src/protocol/mod.rs index 738f8872234..5b1fca7153b 100644 --- a/misc/multistream-select/src/protocol/mod.rs +++ b/misc/multistream-select/src/protocol/mod.rs @@ -32,6 +32,24 @@ pub use self::dialer::{Dialer, DialerFuture}; pub use self::error::MultistreamSelectError; pub use self::listener::{Listener, ListenerFuture}; +use bytes::{BytesMut, BufMut}; +use unsigned_varint as uvi; + +pub enum Header { + Multistream10 +} + +impl Header { + fn encode(&self, dest: &mut BytesMut) { + match self { + Header::Multistream10 => { + dest.reserve(MSG_MULTISTREAM_1_0.len()); + dest.put(MSG_MULTISTREAM_1_0); + } + } + } +} + /// Message sent from the dialer to the listener. #[derive(Debug, Clone, PartialEq, Eq)] pub enum Request { @@ -48,6 +66,29 @@ pub enum Request { ListProtocols, } +impl> Request { + fn encode(&self, dest: &mut BytesMut) -> Result<(), MultistreamSelectError> { + match self { + Request::Protocol { name } => { + if !name.as_ref().starts_with(b"/") { + return Err(MultistreamSelectError::InvalidProtocolName) + } + let len = name.as_ref().len() + 1; // + 1 for \n + dest.reserve(len); + dest.put(name.as_ref()); + dest.put(&b"\n"[..]); + Ok(()) + } + Request::ListProtocols => { + dest.reserve(MSG_LS.len()); + dest.put(MSG_LS); + Ok(()) + } + } + } +} + + /// Message sent from the listener to the dialer. #[derive(Debug, Clone, PartialEq, Eq)] pub enum Response { @@ -66,3 +107,38 @@ pub enum Response { }, } +impl> Response { + fn encode(&self, dest: &mut BytesMut) -> Result<(), MultistreamSelectError> { + match self { + Response::Protocol { name } => { + if !name.as_ref().starts_with(b"/") { + return Err(MultistreamSelectError::InvalidProtocolName) + } + let len = name.as_ref().len() + 1; // + 1 for \n + dest.reserve(len); + dest.put(name.as_ref()); + dest.put(&b"\n"[..]); + Ok(()) + } + Response::SupportedProtocols { protocols } => { + let mut buf = uvi::encode::usize_buffer(); + let mut out_msg = Vec::from(uvi::encode::usize(protocols.len(), &mut buf)); + for p in protocols { + out_msg.extend(uvi::encode::usize(p.as_ref().len() + 1, &mut buf)); // +1 for '\n' + out_msg.extend_from_slice(p.as_ref()); + out_msg.push(b'\n') + } + dest.reserve(out_msg.len()); + dest.put(out_msg); + Ok(()) + } + Response::ProtocolNotAvailable => { + dest.reserve(MSG_PROTOCOL_NA.len()); + dest.put(MSG_PROTOCOL_NA); + Ok(()) + } + } + } +} + +