diff --git a/muxers/mplex/Cargo.toml b/muxers/mplex/Cargo.toml index f00e7879f492..f88262a1aa0b 100644 --- a/muxers/mplex/Cargo.toml +++ b/muxers/mplex/Cargo.toml @@ -11,11 +11,11 @@ categories = ["network-programming", "asynchronous"] [dependencies] bytes = "0.5" -fnv = "1.0" futures = "0.3.1" futures_codec = "0.4" libp2p-core = { version = "0.22.0", path = "../../core" } log = "0.4" +nohash-hasher = "0.2" parking_lot = "0.11" rand = "0.7" smallvec = "1.4" diff --git a/muxers/mplex/src/codec.rs b/muxers/mplex/src/codec.rs index 83caf8692148..9a367a7ae90c 100644 --- a/muxers/mplex/src/codec.rs +++ b/muxers/mplex/src/codec.rs @@ -18,11 +18,10 @@ // FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER // DEALINGS IN THE SOFTWARE. -use libp2p_core::Endpoint; -use futures_codec::{Decoder, Encoder}; -use std::io::{Error as IoError, ErrorKind as IoErrorKind}; -use std::{fmt, mem}; use bytes::{BufMut, Bytes, BytesMut}; +use futures_codec::{Decoder, Encoder}; +use libp2p_core::Endpoint; +use std::{fmt, hash::{Hash, Hasher}, io, mem}; use unsigned_varint::{codec, encode}; // Maximum size for a packet: 1MB as per the spec. @@ -46,7 +45,7 @@ pub(crate) const MAX_FRAME_SIZE: usize = 1024 * 1024; /// > we initiated the stream, so the local ID has the role `Endpoint::Dialer`. /// > Conversely, when receiving a frame with a flag identifying the remote as a "sender", /// > the corresponding local ID has the role `Endpoint::Listener`. -#[derive(Copy, Clone, PartialEq, Eq, Hash, Debug)] +#[derive(Copy, Clone, PartialEq, Eq, Debug)] pub struct LocalStreamId { num: u32, role: Endpoint, @@ -61,6 +60,14 @@ impl fmt::Display for LocalStreamId { } } +impl Hash for LocalStreamId { + fn hash(&self, state: &mut H) { + state.write_u32(self.num); + } +} + +impl nohash_hasher::IsEnabled for LocalStreamId {} + /// A unique identifier used by the remote node for a substream. /// /// `RemoteStreamId`s are received with frames from the remote @@ -148,7 +155,7 @@ impl Codec { impl Decoder for Codec { type Item = Frame; - type Error = IoError; + type Error = io::Error; fn decode(&mut self, src: &mut BytesMut) -> Result, Self::Error> { loop { @@ -169,7 +176,7 @@ impl Decoder for Codec { Some(len) => { if len as usize > MAX_FRAME_SIZE { let msg = format!("Mplex frame length {} exceeds maximum", len); - return Err(IoError::new(IoErrorKind::InvalidData, msg)); + return Err(io::Error::new(io::ErrorKind::InvalidData, msg)); } self.decoder_state = CodecDecodeState::HasHeaderAndLen(header, len as usize); @@ -200,7 +207,7 @@ impl Decoder for Codec { 6 => Frame::Reset { stream_id: RemoteStreamId::dialer(num) }, _ => { let msg = format!("Invalid mplex header value 0x{:x}", header); - return Err(IoError::new(IoErrorKind::InvalidData, msg)); + return Err(io::Error::new(io::ErrorKind::InvalidData, msg)); }, }; @@ -209,7 +216,7 @@ impl Decoder for Codec { }, CodecDecodeState::Poisoned => { - return Err(IoError::new(IoErrorKind::InvalidData, "Mplex codec poisoned")); + return Err(io::Error::new(io::ErrorKind::InvalidData, "Mplex codec poisoned")); } } } @@ -218,7 +225,7 @@ impl Decoder for Codec { impl Encoder for Codec { type Item = Frame; - type Error = IoError; + type Error = io::Error; fn encode(&mut self, item: Self::Item, dst: &mut BytesMut) -> Result<(), Self::Error> { let (header, data) = match item { @@ -253,7 +260,7 @@ impl Encoder for Codec { let data_len_bytes = encode::usize(data_len, &mut data_buf); if data_len > MAX_FRAME_SIZE { - return Err(IoError::new(IoErrorKind::InvalidData, "data size exceed maximum")); + return Err(io::Error::new(io::ErrorKind::InvalidData, "data size exceed maximum")); } dst.reserve(header_bytes.len() + data_len_bytes.len() + data_len); diff --git a/muxers/mplex/src/io.rs b/muxers/mplex/src/io.rs index 9769ae7ed92f..9ad53fcf8ae0 100644 --- a/muxers/mplex/src/io.rs +++ b/muxers/mplex/src/io.rs @@ -22,10 +22,10 @@ use bytes::Bytes; use crate::{MplexConfig, MaxBufferBehaviour}; use crate::codec::{Codec, Frame, LocalStreamId, RemoteStreamId}; use log::{debug, trace}; -use fnv::FnvHashMap; use futures::{prelude::*, ready, stream::Fuse}; use futures::task::{AtomicWaker, ArcWake, waker_ref, WakerRef}; use futures_codec::Framed; +use nohash_hasher::{IntMap, IntSet}; use parking_lot::Mutex; use smallvec::SmallVec; use std::collections::VecDeque; @@ -66,7 +66,7 @@ pub struct Multiplexed { open_buffer: VecDeque, /// Whether a flush is pending due to one or more new outbound /// `Open` frames, before reading frames can proceed. - pending_flush_open: bool, + pending_flush_open: IntSet, /// The stream that currently blocks reading for all streams /// due to a full buffer, if any. Only applicable for use /// with [`MaxBufferBehaviour::Block`]. @@ -80,7 +80,7 @@ pub struct Multiplexed { /// if some or all of the pending frames cannot be sent. pending_frames: VecDeque>, /// The managed substreams. - substreams: FnvHashMap, + substreams: IntMap, /// The ID for the next outbound substream. next_outbound_stream_id: LocalStreamId, /// Registry of wakers for pending tasks interested in reading. @@ -121,7 +121,7 @@ where io: Framed::new(io, Codec::new()).fuse(), open_buffer: Default::default(), substreams: Default::default(), - pending_flush_open: false, + pending_flush_open: Default::default(), pending_frames: Default::default(), blocked_stream: None, next_outbound_stream_id: LocalStreamId::dialer(0), @@ -154,7 +154,7 @@ where match ready!(self.io.poll_flush_unpin(&mut Context::from_waker(&waker))) { Err(e) => Poll::Ready(self.on_error(e)), Ok(()) => { - self.pending_flush_open = false; + self.pending_flush_open = Default::default(); Poll::Ready(Ok(())) } } @@ -264,7 +264,7 @@ where self.id, stream_id, self.substreams.len()); // The flush is delayed and the `Open` frame may be sent // together with other frames in the same transport packet. - self.pending_flush_open = true; + self.pending_flush_open.insert(stream_id); Poll::Ready(Ok(stream_id)) } Err(e) => Poll::Ready(self.on_error(e)), @@ -591,10 +591,11 @@ where } // Perform any pending flush before reading. - if self.pending_flush_open { - trace!("{}: Executing pending flush.", self.id); - ready!(self.poll_flush(cx))?; - debug_assert!(!self.pending_flush_open); + if let Some(id) = &stream_id { + if self.pending_flush_open.remove(id) { + trace!("{}: Executing pending flush for {}.", self.id, id); + ready!(self.poll_flush(cx))?; + } } // Try to read another frame from the underlying I/O stream. @@ -821,7 +822,7 @@ struct NotifierRead { next_stream: AtomicWaker, /// The wakers of currently pending tasks that last /// called `poll_read_stream` for a particular substream. - read_stream: Mutex>, + read_stream: Mutex>, } impl NotifierRead {