Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[multistream-select] Require remaining negotiation data to be flushed. #1781

Merged
merged 4 commits into from
Oct 1, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions misc/multistream-select/CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,10 @@
# 0.8.3 [unreleased]

- Fix a potential deadlock during protocol negotiation due
to a missing flush, potentially resulting in sporadic protocol
upgrade timeouts.
[PR 1781](https://github.com/libp2p/rust-libp2p/pull/1781).

- Update dependencies.

# 0.8.2 [2020-06-22]
Expand Down
3 changes: 1 addition & 2 deletions misc/multistream-select/src/dialer_select.rs
Original file line number Diff line number Diff line change
Expand Up @@ -241,8 +241,7 @@ where
}
Message::Protocol(ref p) if p.as_ref() == protocol.as_ref() => {
log::debug!("Dialer: Received confirmation for protocol: {}", p);
let (io, remaining) = io.into_inner();
let io = Negotiated::completed(io, remaining);
let io = Negotiated::completed(io.into_inner());
return Poll::Ready(Ok((protocol, io)));
}
Message::NotAvailable => {
Expand Down
22 changes: 9 additions & 13 deletions misc/multistream-select/src/length_delimited.rs
Original file line number Diff line number Diff line change
Expand Up @@ -76,22 +76,18 @@ impl<R> LengthDelimited<R> {
}
}

/// Drops the [`LengthDelimited`] resource, yielding the underlying I/O stream
/// together with the remaining write buffer containing the uvi-framed data
/// that has not yet been written to the underlying I/O stream.
///
/// The returned remaining write buffer may be prepended to follow-up
/// protocol data to send with a single `write`. Either way, if non-empty,
/// the write buffer _must_ eventually be written to the I/O stream
/// _before_ any follow-up data, in order to maintain a correct data stream.
/// Drops the [`LengthDelimited`] resource, yielding the underlying I/O stream.
///
/// # Panic
///
/// Will panic if called while there is data in the read buffer. The read buffer is
/// guaranteed to be empty whenever `Stream::poll` yields a new `Bytes` frame.
pub fn into_inner(self) -> (R, BytesMut) {
/// Will panic if called while there is data in the read or write buffer.
/// The read buffer is guaranteed to be empty whenever `Stream::poll` yields
/// a new `Bytes` frame. The write buffer is guaranteed to be empty after
/// flushing.
pub fn into_inner(self) -> R {
assert!(self.read_buffer.is_empty());
(self.inner, self.write_buffer)
assert!(self.write_buffer.is_empty());
self.inner
}

/// Converts the [`LengthDelimited`] into a [`LengthDelimitedReader`], dropping the
Expand Down Expand Up @@ -303,7 +299,7 @@ impl<R> LengthDelimitedReader<R> {
/// yield a new `Message`. The write buffer is guaranteed to be empty whenever
/// [`LengthDelimited::poll_write_buffer`] yields [`Poll::Ready`] or after
/// the [`Sink`] has been completely flushed via [`Sink::poll_flush`].
pub fn into_inner(self) -> (R, BytesMut) {
pub fn into_inner(self) -> R {
self.inner.into_inner()
}
}
Expand Down
39 changes: 21 additions & 18 deletions misc/multistream-select/src/listener_select.rs
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,10 @@ where
message: Message,
protocol: Option<N>
},
Flush { io: MessageIO<R> },
Flush {
io: MessageIO<R>,
protocol: Option<N>
},
Done
}

Expand Down Expand Up @@ -141,7 +144,7 @@ where
}

*this.state = match version {
Version::V1 => State::Flush { io },
Version::V1 => State::Flush { io, protocol: None },
Version::V1Lazy => State::RecvMessage { io },
}
}
Expand Down Expand Up @@ -204,28 +207,28 @@ where
return Poll::Ready(Err(From::from(err)));
}

// If a protocol has been selected, finish negotiation.
// Otherwise flush the sink and expect to receive another
// message.
*this.state = match protocol {
Some(protocol) => {
log::debug!("Listener: sent confirmed protocol: {}",
String::from_utf8_lossy(protocol.as_ref()));
let (io, remaining) = io.into_inner();
let io = Negotiated::completed(io, remaining);
return Poll::Ready(Ok((protocol, io)));
}
None => State::Flush { io }
};
*this.state = State::Flush { io, protocol };
}

State::Flush { mut io } => {
State::Flush { mut io, protocol } => {
match Pin::new(&mut io).poll_flush(cx) {
Poll::Pending => {
*this.state = State::Flush { io };
*this.state = State::Flush { io, protocol };
return Poll::Pending
},
Poll::Ready(Ok(())) => *this.state = State::RecvMessage { io },
Poll::Ready(Ok(())) => {
// If a protocol has been selected, finish negotiation.
// Otherwise expect to receive another message.
match protocol {
Some(protocol) => {
log::debug!("Listener: sent confirmed protocol: {}",
String::from_utf8_lossy(protocol.as_ref()));
let io = Negotiated::completed(io.into_inner());
return Poll::Ready(Ok((protocol, io)))
}
None => *this.state = State::RecvMessage { io }
}
}
Poll::Ready(Err(err)) => return Poll::Ready(Err(From::from(err))),
}
}
Expand Down
151 changes: 19 additions & 132 deletions misc/multistream-select/src/negotiated.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@

use crate::protocol::{Protocol, MessageReader, Message, Version, ProtocolError};

use bytes::{BytesMut, Buf};
use futures::{prelude::*, io::{IoSlice, IoSliceMut}, ready};
use pin_project::pin_project;
use std::{error::Error, fmt, io, mem, pin::Pin, task::{Context, Poll}};
Expand Down Expand Up @@ -74,10 +73,9 @@ where
}

impl<TInner> Negotiated<TInner> {
/// Creates a `Negotiated` in state [`State::Completed`], possibly
/// with `remaining` data to be sent.
pub(crate) fn completed(io: TInner, remaining: BytesMut) -> Self {
Negotiated { state: State::Completed { io, remaining } }
/// Creates a `Negotiated` in state [`State::Completed`].
pub(crate) fn completed(io: TInner) -> Self {
Negotiated { state: State::Completed { io } }
}

/// Creates a `Negotiated` in state [`State::Expecting`] that is still
Expand Down Expand Up @@ -107,10 +105,7 @@ impl<TInner> Negotiated<TInner> {
let mut this = self.project();

match this.state.as_mut().project() {
StateProj::Completed { remaining, .. } => {
debug_assert!(remaining.is_empty());
return Poll::Ready(Ok(()))
}
StateProj::Completed { .. } => return Poll::Ready(Ok(())),
_ => {}
}

Expand Down Expand Up @@ -139,8 +134,7 @@ impl<TInner> Negotiated<TInner> {
if let Message::Protocol(p) = &msg {
if p.as_ref() == protocol.as_ref() {
log::debug!("Negotiated: Received confirmation for protocol: {}", p);
let (io, remaining) = io.into_inner();
*this.state = State::Completed { io, remaining };
*this.state = State::Completed { io: io.into_inner() };
return Poll::Ready(Ok(()));
}
}
Expand All @@ -165,7 +159,8 @@ impl<TInner> Negotiated<TInner> {
#[derive(Debug)]
enum State<R> {
/// In this state, a `Negotiated` is still expecting to
/// receive confirmation of the protocol it as settled on.
/// receive confirmation of the protocol it has optimistically
/// settled on.
Expecting {
/// The underlying I/O stream.
#[pin]
Expand All @@ -176,11 +171,9 @@ enum State<R> {
version: Version
},

/// In this state, a protocol has been agreed upon and may
/// only be pending the sending of the final acknowledgement,
/// which is prepended to / combined with the next write for
/// efficiency.
Completed { #[pin] io: R, remaining: BytesMut },
/// In this state, a protocol has been agreed upon and I/O
/// on the underlying stream can commence.
Completed { #[pin] io: R },

/// Temporary state while moving the `io` resource from
/// `Expecting` to `Completed`.
Expand All @@ -196,12 +189,9 @@ where
{
loop {
match self.as_mut().project().state.project() {
StateProj::Completed { io, remaining } => {
// If protocol negotiation is complete and there is no
// remaining data to be flushed, commence with reading.
if remaining.is_empty() {
return io.poll_read(cx, buf)
}
StateProj::Completed { io } => {
// If protocol negotiation is complete, commence with reading.
return io.poll_read(cx, buf)
},
_ => {}
}
Expand Down Expand Up @@ -230,12 +220,9 @@ where
{
loop {
match self.as_mut().project().state.project() {
StateProj::Completed { io, remaining } => {
// If protocol negotiation is complete and there is no
// remaining data to be flushed, commence with reading.
if remaining.is_empty() {
return io.poll_read_vectored(cx, bufs)
}
StateProj::Completed { io } => {
// If protocol negotiation is complete, commence with reading.
return io.poll_read_vectored(cx, bufs)
},
_ => {}
}
Expand All @@ -257,33 +244,15 @@ where
{
fn poll_write(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll<Result<usize, io::Error>> {
match self.project().state.project() {
StateProj::Completed { mut io, remaining } => {
while !remaining.is_empty() {
let n = ready!(io.as_mut().poll_write(cx, &remaining)?);
if n == 0 {
return Poll::Ready(Err(io::ErrorKind::WriteZero.into()))
}
remaining.advance(n);
}
io.poll_write(cx, buf)
},
StateProj::Completed { io } => io.poll_write(cx, buf),
StateProj::Expecting { io, .. } => io.poll_write(cx, buf),
StateProj::Invalid => panic!("Negotiated: Invalid state"),
}
}

fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
match self.project().state.project() {
StateProj::Completed { mut io, remaining } => {
while !remaining.is_empty() {
let n = ready!(io.as_mut().poll_write(cx, &remaining)?);
if n == 0 {
return Poll::Ready(Err(io::ErrorKind::WriteZero.into()))
}
remaining.advance(n);
}
io.poll_flush(cx)
},
StateProj::Completed { io } => io.poll_flush(cx),
StateProj::Expecting { io, .. } => io.poll_flush(cx),
StateProj::Invalid => panic!("Negotiated: Invalid state"),
}
Expand All @@ -307,16 +276,7 @@ where
-> Poll<Result<usize, io::Error>>
{
match self.project().state.project() {
StateProj::Completed { mut io, remaining } => {
while !remaining.is_empty() {
let n = ready!(io.as_mut().poll_write(cx, &remaining)?);
if n == 0 {
return Poll::Ready(Err(io::ErrorKind::WriteZero.into()))
}
remaining.advance(n);
}
io.poll_write_vectored(cx, bufs)
},
StateProj::Completed { io } => io.poll_write_vectored(cx, bufs),
StateProj::Expecting { io, .. } => io.poll_write_vectored(cx, bufs),
StateProj::Invalid => panic!("Negotiated: Invalid state"),
}
Expand Down Expand Up @@ -373,76 +333,3 @@ impl fmt::Display for NegotiationError {
}
}
}

#[cfg(test)]
mod tests {
use super::*;
use quickcheck::*;
use std::{io::Write, task::Poll};

/// An I/O resource with a fixed write capacity (total and per write op).
struct Capped { buf: Vec<u8>, step: usize }

impl AsyncRead for Capped {
fn poll_read(self: Pin<&mut Self>, _: &mut Context<'_>, _: &mut [u8]) -> Poll<Result<usize, io::Error>> {
unreachable!()
}
}

impl AsyncWrite for Capped {
fn poll_write(mut self: Pin<&mut Self>, _: &mut Context<'_>, buf: &[u8]) -> Poll<Result<usize, io::Error>> {
if self.buf.len() + buf.len() > self.buf.capacity() {
return Poll::Ready(Err(io::ErrorKind::WriteZero.into()))
}
let len = usize::min(self.step, buf.len());
let n = Write::write(&mut self.buf, &buf[.. len]).unwrap();
Poll::Ready(Ok(n))
}

fn poll_flush(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
Poll::Ready(Ok(()))
}

fn poll_close(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
Poll::Ready(Ok(()))
}
}

#[test]
fn write_remaining() {
fn prop(rem: Vec<u8>, new: Vec<u8>, free: u8, step: u8) -> TestResult {
let cap = rem.len() + free as usize;
let step = u8::min(free, step) as usize + 1;
let buf = Capped { buf: Vec::with_capacity(cap), step };
let rem = BytesMut::from(&rem[..]);
let mut io = Negotiated::completed(buf, rem.clone());
let mut written = 0;
loop {
// Write until `new` has been fully written or the capped buffer runs
// over capacity and yields WriteZero.
match future::poll_fn(|cx| Pin::new(&mut io).poll_write(cx, &new[written..])).now_or_never().unwrap() {
Ok(n) =>
if let State::Completed { remaining, .. } = &io.state {
assert!(remaining.is_empty());
written += n;
if written == new.len() {
return TestResult::passed()
}
} else {
return TestResult::failed()
}
Err(e) if e.kind() == io::ErrorKind::WriteZero => {
if let State::Completed { .. } = &io.state {
assert!(rem.len() + new.len() > cap);
return TestResult::passed()
} else {
return TestResult::failed()
}
}
Err(e) => panic!("Unexpected error: {:?}", e),
}
}
}
quickcheck(prop as fn(_,_,_,_) -> _)
}
}
Loading