diff --git a/.github/workflows/rust.yml b/.github/workflows/rust.yml index e352e294..44436aa3 100644 --- a/.github/workflows/rust.yml +++ b/.github/workflows/rust.yml @@ -16,6 +16,8 @@ jobs: steps: - uses: actions/checkout@v3 + - name: Check formatting + run: cargo fmt -- --check - name: Build run: cargo build --verbose - name: Run tests diff --git a/benches/concurrent.rs b/benches/concurrent.rs index f7f59d2e..7466702f 100644 --- a/benches/concurrent.rs +++ b/benches/concurrent.rs @@ -8,9 +8,9 @@ // at https://www.apache.org/licenses/LICENSE-2.0 and a copy of the MIT license // at https://opensource.org/licenses/MIT. -use constrained_connection::{Endpoint, new_unconstrained_connection, samples}; -use criterion::{BenchmarkId, criterion_group, criterion_main, Criterion, Throughput}; -use futures::{channel::mpsc, future, prelude::*, io::AsyncReadExt}; +use constrained_connection::{new_unconstrained_connection, samples, Endpoint}; +use criterion::{criterion_group, criterion_main, BenchmarkId, Criterion, Throughput}; +use futures::{channel::mpsc, future, io::AsyncReadExt, prelude::*}; use std::sync::Arc; use tokio::{runtime::Runtime, task}; use yamux::{Config, Connection, Mode}; @@ -31,9 +31,15 @@ fn concurrent(c: &mut Criterion) { let data = Bytes(Arc::new(vec![0x42; 4096])); let networks = vec![ ("mobile", (|| samples::mobile_hsdpa().2) as fn() -> (_, _)), - ("adsl2+", (|| samples::residential_adsl2().2) as fn() -> (_, _)), + ( + "adsl2+", + (|| samples::residential_adsl2().2) as fn() -> (_, _), + ), ("gbit-lan", (|| samples::gbit_lan().2) as fn() -> (_, _)), - ("unconstrained", new_unconstrained_connection as fn() -> (_, _)), + ( + "unconstrained", + new_unconstrained_connection as fn() -> (_, _), + ), ]; let mut group = c.benchmark_group("concurrent"); @@ -45,15 +51,20 @@ fn concurrent(c: &mut Criterion) { let data = data.clone(); let rt = Runtime::new().unwrap(); - group.throughput(Throughput::Bytes((nstreams * nmessages * data.0.len()) as u64)); + group.throughput(Throughput::Bytes( + (nstreams * nmessages * data.0.len()) as u64, + )); group.bench_function( - BenchmarkId::from_parameter( - format!("{}/#streams{}/#messages{}", network_name, nstreams, nmessages), - ), - |b| b.iter(|| { - let (server, client) = new_connection(); - rt.block_on(oneway(*nstreams, *nmessages, data.clone(), server, client)) - }), + BenchmarkId::from_parameter(format!( + "{}/#streams{}/#messages{}", + network_name, nstreams, nmessages + )), + |b| { + b.iter(|| { + let (server, client) = new_connection(); + rt.block_on(oneway(*nstreams, *nmessages, data.clone(), server, client)) + }) + }, ); } } @@ -89,7 +100,7 @@ async fn oneway( let mut b = vec![0; msg_len]; // Receive `nmessages` messages. - for _ in 0 .. nmessages { + for _ in 0..nmessages { stream.read_exact(&mut b[..]).await.unwrap(); n += b.len(); } @@ -103,16 +114,19 @@ async fn oneway( let conn = Connection::new(client, config(), Mode::Client); let mut ctrl = conn.control(); - task::spawn(yamux::into_stream(conn).for_each(|r| {r.unwrap(); future::ready(())} )); + task::spawn(yamux::into_stream(conn).for_each(|r| { + r.unwrap(); + future::ready(()) + })); - for _ in 0 .. nstreams { + for _ in 0..nstreams { let data = data.clone(); let mut ctrl = ctrl.clone(); task::spawn(async move { let mut stream = ctrl.open_stream().await.unwrap(); // Send `nmessages` messages. - for _ in 0 .. nmessages { + for _ in 0..nmessages { stream.write_all(data.as_ref()).await.unwrap(); } @@ -120,7 +134,10 @@ async fn oneway( }); } - let n = rx.take(nstreams).fold(0, |acc, n| future::ready(acc + n)).await; + let n = rx + .take(nstreams) + .fold(0, |acc, n| future::ready(acc + n)) + .await; assert_eq!(n, nstreams * nmessages * msg_len); ctrl.close().await.expect("close"); } diff --git a/src/chunks.rs b/src/chunks.rs index 213bbec3..0e66d894 100644 --- a/src/chunks.rs +++ b/src/chunks.rs @@ -18,13 +18,16 @@ use std::{collections::VecDeque, io}; #[derive(Debug)] pub(crate) struct Chunks { seq: VecDeque, - len: usize + len: usize, } impl Chunks { /// A new empty chunk list. pub(crate) fn new() -> Self { - Chunks { seq: VecDeque::new(), len: 0 } + Chunks { + seq: VecDeque::new(), + len: 0, + } } /// The total length of bytes yet-to-be-read in all `Chunk`s. @@ -36,7 +39,9 @@ impl Chunks { pub(crate) fn push(&mut self, x: Vec) { self.len += x.len(); if !x.is_empty() { - self.seq.push_back(Chunk { cursor: io::Cursor::new(x) }) + self.seq.push_back(Chunk { + cursor: io::Cursor::new(x), + }) } } @@ -59,7 +64,7 @@ impl Chunks { /// vector can be consumed in steps. #[derive(Debug)] pub(crate) struct Chunk { - cursor: io::Cursor> + cursor: io::Cursor>, } impl Chunk { @@ -83,13 +88,15 @@ impl Chunk { /// The `AsRef<[u8]>` impl of `Chunk` provides a byte-slice view /// from the current position to the end. pub(crate) fn advance(&mut self, amount: usize) { - assert!({ // the new position must not exceed the vector's length + assert!({ + // the new position must not exceed the vector's length let pos = self.offset().checked_add(amount); let max = self.cursor.get_ref().len(); pos.is_some() && pos <= Some(max) }); - self.cursor.set_position(self.cursor.position() + amount as u64); + self.cursor + .set_position(self.cursor.position() + amount as u64); } /// Consume `self` and return the inner vector. @@ -100,7 +107,6 @@ impl Chunk { impl AsRef<[u8]> for Chunk { fn as_ref(&self) -> &[u8] { - &self.cursor.get_ref()[self.offset() ..] + &self.cursor.get_ref()[self.offset()..] } } - diff --git a/src/connection/control.rs b/src/connection/control.rs index e317086e..cf260350 100644 --- a/src/connection/control.rs +++ b/src/connection/control.rs @@ -8,10 +8,17 @@ // at https://www.apache.org/licenses/LICENSE-2.0 and a copy of the MIT license // at https://opensource.org/licenses/MIT. -use crate::{Stream, error::ConnectionError}; -use futures::{ready, channel::{mpsc, oneshot}, prelude::*}; -use std::{pin::Pin, task::{Context, Poll}}; use super::ControlCommand; +use crate::{error::ConnectionError, Stream}; +use futures::{ + channel::{mpsc, oneshot}, + prelude::*, + ready, +}; +use std::{ + pin::Pin, + task::{Context, Poll}, +}; type Result = std::result::Result; @@ -31,7 +38,7 @@ pub struct Control { /// Pending state of `poll_open_stream`. pending_open: Option>>, /// Pending state of `poll_close`. - pending_close: Option> + pending_close: Option>, } impl Clone for Control { @@ -39,7 +46,7 @@ impl Clone for Control { Control { sender: self.sender.clone(), pending_open: None, - pending_close: None + pending_close: None, } } } @@ -49,7 +56,7 @@ impl Control { Control { sender, pending_open: None, - pending_close: None + pending_close: None, } } @@ -63,9 +70,14 @@ impl Control { /// Close the connection. pub async fn close(&mut self) -> Result<()> { let (tx, rx) = oneshot::channel(); - if self.sender.send(ControlCommand::CloseConnection(tx)).await.is_err() { + if self + .sender + .send(ControlCommand::CloseConnection(tx)) + .await + .is_err() + { // The receiver is closed which means the connection is already closed. - return Ok(()) + return Ok(()); } // A dropped `oneshot::Sender` means the `Connection` is gone, // so we do not treat receive errors differently here. @@ -84,14 +96,12 @@ impl Control { self.pending_open = Some(rx) } Some(mut rx) => match rx.poll_unpin(cx)? { - Poll::Ready(result) => { - return Poll::Ready(result) - } + Poll::Ready(result) => return Poll::Ready(result), Poll::Pending => { self.pending_open = Some(rx); - return Poll::Pending + return Poll::Pending; } - } + }, } } } @@ -108,35 +118,32 @@ impl Control { None => { if ready!(self.sender.poll_ready(cx)).is_err() { // The receiver is closed which means the connection is already closed. - return Poll::Ready(Ok(())) + return Poll::Ready(Ok(())); } let (tx, rx) = oneshot::channel(); if let Err(e) = self.sender.start_send(ControlCommand::CloseConnection(tx)) { if e.is_full() { - continue + continue; } debug_assert!(e.is_disconnected()); // The receiver is closed which means the connection is already closed. - return Poll::Ready(Ok(())) + return Poll::Ready(Ok(())); } self.pending_close = Some(rx) } Some(mut rx) => match rx.poll_unpin(cx) { - Poll::Ready(Ok(())) => { - return Poll::Ready(Ok(())) - } + Poll::Ready(Ok(())) => return Poll::Ready(Ok(())), Poll::Ready(Err(oneshot::Canceled)) => { // A dropped `oneshot::Sender` means the `Connection` is gone, // which is `Ok`ay for us here. - return Poll::Ready(Ok(())) + return Poll::Ready(Ok(())); } Poll::Pending => { self.pending_close = Some(rx); - return Poll::Pending + return Poll::Pending; } - } + }, } } } } - diff --git a/src/connection/stream.rs b/src/connection/stream.rs index 2c8229f7..22d25031 100644 --- a/src/connection/stream.rs +++ b/src/connection/stream.rs @@ -9,19 +9,28 @@ // at https://opensource.org/licenses/MIT. use crate::{ - Config, - WindowUpdateMode, chunks::Chunks, connection::{self, StreamCommand}, frame::{ + header::{Data, Header, StreamId, WindowUpdate}, Frame, - header::{Header, StreamId, Data, WindowUpdate} - } + }, + Config, WindowUpdateMode, +}; +use futures::{ + channel::mpsc, + future::Either, + io::{AsyncRead, AsyncWrite}, + ready, }; -use futures::{future::Either, ready, channel::mpsc, io::{AsyncRead, AsyncWrite}}; use parking_lot::{Mutex, MutexGuard}; -use std::{fmt, io, pin::Pin, sync::Arc, task::{Context, Poll, Waker}}; use std::convert::TryInto; +use std::{ + fmt, io, + pin::Pin, + sync::Arc, + task::{Context, Poll, Waker}, +}; /// The state of a Yamux stream. #[derive(Copy, Clone, Debug, PartialEq, Eq)] @@ -33,7 +42,7 @@ pub enum State { /// Open for outgoing messages. RecvClosed, /// Closed (terminal state). - Closed + Closed, } impl State { @@ -64,7 +73,7 @@ pub(crate) enum Flag { /// The stream was opened lazily, so set the initial SYN flag. Syn, /// The stream still needs acknowledgement, so set the ACK flag. - Ack + Ack, } /// A multiplexed Yamux stream. @@ -80,7 +89,7 @@ pub struct Stream { config: Arc, sender: mpsc::Sender, flag: Flag, - shared: Arc> + shared: Arc>, } impl fmt::Debug for Stream { @@ -99,15 +108,14 @@ impl fmt::Display for Stream { } impl Stream { - pub(crate) fn new - ( id: StreamId - , conn: connection::Id - , config: Arc - , window: u32 - , credit: u32 - , sender: mpsc::Sender - ) -> Self - { + pub(crate) fn new( + id: StreamId, + conn: connection::Id, + config: Arc, + window: u32, + credit: u32, + sender: mpsc::Sender, + ) -> Self { Stream { id, conn, @@ -148,7 +156,7 @@ impl Stream { config: self.config.clone(), sender: self.sender.clone(), flag: self.flag, - shared: self.shared.clone() + shared: self.shared.clone(), } } @@ -184,7 +192,10 @@ impl Stream { let mut shared = self.shared.lock(); if let Some(credit) = shared.next_window_update() { - ready!(self.sender.poll_ready(cx).map_err(|_| self.write_zero_err())?); + ready!(self + .sender + .poll_ready(cx) + .map_err(|_| self.write_zero_err())?); shared.window += credit; drop(shared); @@ -192,7 +203,9 @@ impl Stream { let mut frame = Frame::window_update(self.id, credit).right(); self.add_flag(frame.header_mut()); let cmd = StreamCommand::SendFrame(frame); - self.sender.start_send(cmd).map_err(|_| self.write_zero_err())?; + self.sender + .start_send(cmd) + .map_err(|_| self.write_zero_err())?; } Poll::Ready(Ok(())) @@ -214,14 +227,14 @@ impl futures::stream::Stream for Stream { fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll> { if !self.config.read_after_close && self.sender.is_closed() { - return Poll::Ready(None) + return Poll::Ready(None); } match self.send_window_update(cx) { - Poll::Ready(Ok(())) => {}, + Poll::Ready(Ok(())) => {} Poll::Ready(Err(e)) => return Poll::Ready(Some(Err(e))), // Continue reading buffered data even though sending a window update blocked. - Poll::Pending => {}, + Poll::Pending => {} } let mut shared = self.shared(); @@ -234,16 +247,20 @@ impl futures::stream::Stream for Stream { // a `futures::stream::Stream` since the whole point of this impl is // to consume chunks atomically. It may perhaps happen when mixing // this impl and the `AsyncRead` one. - log::debug!("{}/{}: chunk has been partially consumed", self.conn, self.id); + log::debug!( + "{}/{}: chunk has been partially consumed", + self.conn, + self.id + ); vec = vec.split_off(off) } - return Poll::Ready(Some(Ok(Packet(vec)))) + return Poll::Ready(Some(Ok(Packet(vec)))); } // Buffer is empty, let's check if we can expect to read more data. if !shared.state().can_read() { log::debug!("{}/{}: eof", self.conn, self.id); - return Poll::Ready(None) // stream has been reset + return Poll::Ready(None); // stream has been reset } // Since we have no more data at this point, we want to be woken up @@ -257,16 +274,20 @@ impl futures::stream::Stream for Stream { // Like the `futures::stream::Stream` impl above, but copies bytes into the // provided mutable slice. impl AsyncRead for Stream { - fn poll_read(mut self: Pin<&mut Self>, cx: &mut Context, buf: &mut [u8]) -> Poll> { + fn poll_read( + mut self: Pin<&mut Self>, + cx: &mut Context, + buf: &mut [u8], + ) -> Poll> { if !self.config.read_after_close && self.sender.is_closed() { - return Poll::Ready(Ok(0)) + return Poll::Ready(Ok(0)); } match self.send_window_update(cx) { - Poll::Ready(Ok(())) => {}, + Poll::Ready(Ok(())) => {} Poll::Ready(Err(e)) => return Poll::Ready(Err(e)), // Continue reading buffered data even though sending a window update blocked. - Poll::Pending => {}, + Poll::Pending => {} } // Copy data from stream buffer. @@ -275,26 +296,26 @@ impl AsyncRead for Stream { while let Some(chunk) = shared.buffer.front_mut() { if chunk.is_empty() { shared.buffer.pop(); - continue + continue; } let k = std::cmp::min(chunk.len(), buf.len() - n); - (&mut buf[n .. n + k]).copy_from_slice(&chunk.as_ref()[.. k]); + (&mut buf[n..n + k]).copy_from_slice(&chunk.as_ref()[..k]); n += k; chunk.advance(k); if n == buf.len() { - break + break; } } if n > 0 { log::trace!("{}/{}: read {} bytes", self.conn, self.id, n); - return Poll::Ready(Ok(n)) + return Poll::Ready(Ok(n)); } // Buffer is empty, let's check if we can expect to read more data. if !shared.state().can_read() { log::debug!("{}/{}: eof", self.conn, self.id); - return Poll::Ready(Ok(0)) // stream has been reset + return Poll::Ready(Ok(0)); // stream has been reset } // Since we have no more data at this point, we want to be woken up @@ -306,30 +327,39 @@ impl AsyncRead for Stream { } impl AsyncWrite for Stream { - fn poll_write(mut self: Pin<&mut Self>, cx: &mut Context, buf: &[u8]) -> Poll> { - ready!(self.sender.poll_ready(cx).map_err(|_| self.write_zero_err())?); + fn poll_write( + mut self: Pin<&mut Self>, + cx: &mut Context, + buf: &[u8], + ) -> Poll> { + ready!(self + .sender + .poll_ready(cx) + .map_err(|_| self.write_zero_err())?); let body = { let mut shared = self.shared(); if !shared.state().can_write() { log::debug!("{}/{}: can no longer write", self.conn, self.id); - return Poll::Ready(Err(self.write_zero_err())) + return Poll::Ready(Err(self.write_zero_err())); } if shared.credit == 0 { log::trace!("{}/{}: no more credit left", self.conn, self.id); shared.writer = Some(cx.waker().clone()); - return Poll::Pending + return Poll::Pending; } let k = std::cmp::min(shared.credit as usize, buf.len()); let k = std::cmp::min(k, self.config.split_send_size); shared.credit = shared.credit.saturating_sub(k as u32); - Vec::from(&buf[.. k]) + Vec::from(&buf[..k]) }; let n = body.len(); let mut frame = Frame::data(self.id, body).expect("body <= u32::MAX").left(); self.add_flag(frame.header_mut()); log::trace!("{}/{}: write {} bytes", self.conn, self.id, n); let cmd = StreamCommand::SendFrame(frame); - self.sender.start_send(cmd).map_err(|_| self.write_zero_err())?; + self.sender + .start_send(cmd) + .map_err(|_| self.write_zero_err())?; Poll::Ready(Ok(n)) } @@ -339,9 +369,12 @@ impl AsyncWrite for Stream { fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll> { if self.state() == State::Closed { - return Poll::Ready(Ok(())) + return Poll::Ready(Ok(())); } - ready!(self.sender.poll_ready(cx).map_err(|_| self.write_zero_err())?); + ready!(self + .sender + .poll_ready(cx) + .map_err(|_| self.write_zero_err())?); let ack = if self.flag == Flag::Ack { self.flag = Flag::None; true @@ -350,8 +383,11 @@ impl AsyncWrite for Stream { }; log::trace!("{}/{}: close", self.conn, self.id); let cmd = StreamCommand::CloseStream { id: self.id, ack }; - self.sender.start_send(cmd).map_err(|_| self.write_zero_err())?; - self.shared().update_state(self.conn, self.id, State::SendClosed); + self.sender + .start_send(cmd) + .map_err(|_| self.write_zero_err())?; + self.shared() + .update_state(self.conn, self.id, State::SendClosed); Poll::Ready(Ok(())) } } @@ -364,7 +400,7 @@ pub(crate) struct Shared { pub(crate) buffer: Chunks, pub(crate) reader: Option, pub(crate) writer: Option, - config: Arc + config: Arc, } impl Shared { @@ -376,7 +412,7 @@ impl Shared { buffer: Chunks::new(), reader: None, writer: None, - config + config, } } @@ -385,25 +421,37 @@ impl Shared { } /// Update the stream state and return the state before it was updated. - pub(crate) fn update_state(&mut self, cid: connection::Id, sid: StreamId, next: State) -> State { + pub(crate) fn update_state( + &mut self, + cid: connection::Id, + sid: StreamId, + next: State, + ) -> State { use self::State::*; let current = self.state; match (current, next) { - (Closed, _) => {} - (Open, _) => self.state = next, - (RecvClosed, Closed) => self.state = Closed, - (RecvClosed, Open) => {} + (Closed, _) => {} + (Open, _) => self.state = next, + (RecvClosed, Closed) => self.state = Closed, + (RecvClosed, Open) => {} (RecvClosed, RecvClosed) => {} (RecvClosed, SendClosed) => self.state = Closed, - (SendClosed, Closed) => self.state = Closed, - (SendClosed, Open) => {} + (SendClosed, Closed) => self.state = Closed, + (SendClosed, Open) => {} (SendClosed, RecvClosed) => self.state = Closed, (SendClosed, SendClosed) => {} } - log::trace!("{}/{}: update state: ({:?} {:?} {:?})", cid, sid, current, next, self.state); + log::trace!( + "{}/{}: update state: ({:?} {:?} {:?})", + cid, + sid, + current, + next, + self.state + ); current // Return the previous stream state for informational purposes. } @@ -425,7 +473,7 @@ impl Shared { debug_assert!(self.config.receive_window >= self.window); let bytes_received = self.config.receive_window.saturating_sub(self.window); bytes_received - }, + } WindowUpdateMode::OnRead => { debug_assert!(self.config.receive_window >= self.window); let bytes_received = self.config.receive_window.saturating_sub(self.window); @@ -447,4 +495,3 @@ impl Shared { } } } - diff --git a/src/error.rs b/src/error.rs index 86420d29..f9a20c21 100644 --- a/src/error.rs +++ b/src/error.rs @@ -23,7 +23,7 @@ pub enum ConnectionError { /// An operation fails because the connection is closed. Closed, /// Too many streams are open, so no further ones can be opened at this time. - TooManyStreams + TooManyStreams, } impl ConnectionError { @@ -32,7 +32,7 @@ impl ConnectionError { match self { ConnectionError::Io(e) => Some(e.kind()), ConnectionError::Decode(FrameDecodeError::Io(e)) => Some(e.kind()), - _ => None + _ => None, } } } @@ -42,9 +42,11 @@ impl std::fmt::Display for ConnectionError { match self { ConnectionError::Io(e) => write!(f, "i/o error: {}", e), ConnectionError::Decode(e) => write!(f, "decode error: {}", e), - ConnectionError::NoMoreStreamIds => f.write_str("number of stream ids has been exhausted"), + ConnectionError::NoMoreStreamIds => { + f.write_str("number of stream ids has been exhausted") + } ConnectionError::Closed => f.write_str("connection is closed"), - ConnectionError::TooManyStreams => f.write_str("maximum number of streams reached") + ConnectionError::TooManyStreams => f.write_str("maximum number of streams reached"), } } } @@ -56,8 +58,7 @@ impl std::error::Error for ConnectionError { ConnectionError::Decode(e) => Some(e), ConnectionError::NoMoreStreamIds | ConnectionError::Closed - | ConnectionError::TooManyStreams - => None + | ConnectionError::TooManyStreams => None, } } } diff --git a/src/frame.rs b/src/frame.rs index 0e93e5e1..e107456e 100644 --- a/src/frame.rs +++ b/src/frame.rs @@ -12,22 +12,25 @@ pub mod header; mod io; use futures::future::Either; -use header::{Header, StreamId, Data, WindowUpdate, GoAway, Ping}; +use header::{Data, GoAway, Header, Ping, StreamId, WindowUpdate}; use std::{convert::TryInto, num::TryFromIntError}; -pub(crate) use io::Io; pub use io::FrameDecodeError; +pub(crate) use io::Io; /// A Yamux message frame consisting of header and body. #[derive(Clone, Debug, PartialEq, Eq)] pub struct Frame { header: Header, - body: Vec + body: Vec, } impl Frame { pub fn new(header: Header) -> Self { - Frame { header, body: Vec::new() } + Frame { + header, + body: Vec::new(), + } } pub fn header(&self) -> &Header { @@ -40,12 +43,18 @@ impl Frame { /// Introduce this frame to the right of a binary frame type. pub(crate) fn right(self) -> Frame> { - Frame { header: self.header.right(), body: self.body } + Frame { + header: self.header.right(), + body: self.body, + } } /// Introduce this frame to the left of a binary frame type. pub(crate) fn left(self) -> Frame> { - Frame { header: self.header.left(), body: self.body } + Frame { + header: self.header.left(), + body: self.body, + } } } @@ -53,22 +62,31 @@ impl From> for Frame<()> { fn from(f: Frame) -> Frame<()> { Frame { header: f.header.into(), - body: f.body + body: f.body, } } } impl Frame<()> { pub(crate) fn into_data(self) -> Frame { - Frame { header: self.header.into_data(), body: self.body } + Frame { + header: self.header.into_data(), + body: self.body, + } } pub(crate) fn into_window_update(self) -> Frame { - Frame { header: self.header.into_window_update(), body: self.body } + Frame { + header: self.header.into_window_update(), + body: self.body, + } } pub(crate) fn into_ping(self) -> Frame { - Frame { header: self.header.into_ping(), body: self.body } + Frame { + header: self.header.into_ping(), + body: self.body, + } } } @@ -76,7 +94,7 @@ impl Frame { pub fn data(id: StreamId, b: Vec) -> Result { Ok(Frame { header: Header::data(id, b.len().try_into()?), - body: b + body: b, }) } @@ -99,7 +117,7 @@ impl Frame { pub fn window_update(id: StreamId, credit: u32) -> Self { Frame { header: Header::window_update(id, credit), - body: Vec::new() + body: Vec::new(), } } } @@ -108,21 +126,21 @@ impl Frame { pub fn term() -> Self { Frame { header: Header::term(), - body: Vec::new() + body: Vec::new(), } } pub fn protocol_error() -> Self { Frame { header: Header::protocol_error(), - body: Vec::new() + body: Vec::new(), } } pub fn internal_error() -> Self { Frame { header: Header::internal_error(), - body: Vec::new() + body: Vec::new(), } } } diff --git a/src/frame/header.rs b/src/frame/header.rs index b3e45c0d..a3e4f0c1 100644 --- a/src/frame/header.rs +++ b/src/frame/header.rs @@ -19,16 +19,19 @@ pub struct Header { flags: Flags, stream_id: StreamId, length: Len, - _marker: std::marker::PhantomData + _marker: std::marker::PhantomData, } impl fmt::Display for Header { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - write!(f, "(Header {:?} {} (len {}) (flags {:?}))", + write!( + f, + "(Header {:?} {} (len {}) (flags {:?}))", self.tag, self.stream_id, self.length.val(), - self.flags.val()) + self.flags.val() + ) } } @@ -62,7 +65,7 @@ impl Header { flags: self.flags, stream_id: self.stream_id, length: self.length, - _marker: std::marker::PhantomData + _marker: std::marker::PhantomData, } } @@ -137,7 +140,7 @@ impl Header { flags: Flags(0), stream_id: id, length: Len(len), - _marker: std::marker::PhantomData + _marker: std::marker::PhantomData, } } } @@ -151,7 +154,7 @@ impl Header { flags: Flags(0), stream_id: id, length: Len(credit), - _marker: std::marker::PhantomData + _marker: std::marker::PhantomData, } } @@ -170,7 +173,7 @@ impl Header { flags: Flags(0), stream_id: StreamId(0), length: Len(nonce), - _marker: std::marker::PhantomData + _marker: std::marker::PhantomData, } } @@ -203,7 +206,7 @@ impl Header { flags: Flags(0), stream_id: StreamId(0), length: Len(code), - _marker: std::marker::PhantomData + _marker: std::marker::PhantomData, } } } @@ -264,7 +267,7 @@ pub enum Tag { Data, WindowUpdate, Ping, - GoAway + GoAway, } /// The protocol version a message corresponds to. @@ -359,16 +362,16 @@ pub fn encode(hdr: &Header) -> [u8; HEADER_SIZE] { let mut buf = [0; HEADER_SIZE]; buf[0] = hdr.version.0; buf[1] = hdr.tag as u8; - buf[2 .. 4].copy_from_slice(&hdr.flags.0.to_be_bytes()); - buf[4 .. 8].copy_from_slice(&hdr.stream_id.0.to_be_bytes()); - buf[8 .. HEADER_SIZE].copy_from_slice(&hdr.length.0.to_be_bytes()); + buf[2..4].copy_from_slice(&hdr.flags.0.to_be_bytes()); + buf[4..8].copy_from_slice(&hdr.stream_id.0.to_be_bytes()); + buf[8..HEADER_SIZE].copy_from_slice(&hdr.length.0.to_be_bytes()); buf } /// Decode a [`Header`] value. pub fn decode(buf: &[u8; HEADER_SIZE]) -> Result, HeaderDecodeError> { if buf[0] != 0 { - return Err(HeaderDecodeError::Version(buf[0])) + return Err(HeaderDecodeError::Version(buf[0])); } let hdr = Header { @@ -378,12 +381,12 @@ pub fn decode(buf: &[u8; HEADER_SIZE]) -> Result, HeaderDecodeError> 1 => Tag::WindowUpdate, 2 => Tag::Ping, 3 => Tag::GoAway, - t => return Err(HeaderDecodeError::Type(t)) + t => return Err(HeaderDecodeError::Type(t)), }, flags: Flags(u16::from_be_bytes([buf[2], buf[3]])), stream_id: StreamId(u32::from_be_bytes([buf[4], buf[5], buf[6], buf[7]])), length: Len(u32::from_be_bytes([buf[8], buf[9], buf[10], buf[11]])), - _marker: std::marker::PhantomData + _marker: std::marker::PhantomData, }; Ok(hdr) @@ -396,14 +399,14 @@ pub enum HeaderDecodeError { /// Unknown version. Version(u8), /// An unknown frame type. - Type(u8) + Type(u8), } impl std::fmt::Display for HeaderDecodeError { fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { match self { HeaderDecodeError::Version(v) => write!(f, "unknown version: {}", v), - HeaderDecodeError::Type(t) => write!(f, "unknown frame type: {}", t) + HeaderDecodeError::Type(t) => write!(f, "unknown frame type: {}", t), } } } @@ -412,12 +415,13 @@ impl std::error::Error for HeaderDecodeError {} #[cfg(test)] mod tests { - use quickcheck::{Arbitrary, Gen, QuickCheck}; use super::*; + use quickcheck::{Arbitrary, Gen, QuickCheck}; impl Arbitrary for Header<()> { fn arbitrary(g: &mut Gen) -> Self { - let tag = *g.choose(&[Tag::Data, Tag::WindowUpdate, Tag::Ping, Tag::GoAway]) + let tag = *g + .choose(&[Tag::Data, Tag::WindowUpdate, Tag::Ping, Tag::GoAway]) .unwrap(); Header { @@ -426,7 +430,7 @@ mod tests { flags: Flags(Arbitrary::arbitrary(g)), stream_id: StreamId(Arbitrary::arbitrary(g)), length: Len(Arbitrary::arbitrary(g)), - _marker: std::marker::PhantomData + _marker: std::marker::PhantomData, } } } diff --git a/src/frame/io.rs b/src/frame/io.rs index d274b2cc..8953b19a 100644 --- a/src/frame/io.rs +++ b/src/frame/io.rs @@ -8,10 +8,17 @@ // at https://www.apache.org/licenses/LICENSE-2.0 and a copy of the MIT license // at https://opensource.org/licenses/MIT. +use super::{ + header::{self, HeaderDecodeError}, + Frame, +}; use crate::connection::Id; use futures::{prelude::*, ready}; -use std::{fmt, io, pin::Pin, task::{Context, Poll}}; -use super::{Frame, header::{self, HeaderDecodeError}}; +use std::{ + fmt, io, + pin::Pin, + task::{Context, Poll}, +}; /// A [`Stream`] and writer of [`Frame`] values. #[derive(Debug)] @@ -20,7 +27,7 @@ pub(crate) struct Io { io: T, read_state: ReadState, write_state: WriteState, - max_body_len: usize + max_body_len: usize, } impl Io { @@ -30,7 +37,7 @@ impl Io { io, read_state: ReadState::Init, write_state: WriteState::Init, - max_body_len: max_frame_body_len + max_body_len: max_frame_body_len, } } } @@ -41,25 +48,28 @@ enum WriteState { Header { header: [u8; header::HEADER_SIZE], buffer: Vec, - offset: usize + offset: usize, }, Body { buffer: Vec, - offset: usize - } + offset: usize, + }, } impl fmt::Debug for WriteState { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { match self { - WriteState::Init => { - f.write_str("(WriteState::Init)") - } + WriteState::Init => f.write_str("(WriteState::Init)"), WriteState::Header { offset, .. } => { write!(f, "(WriteState::Header (offset {}))", offset) } WriteState::Body { offset, buffer } => { - write!(f, "(WriteState::Body (offset {}) (buffer-len {}))", offset, buffer.len()) + write!( + f, + "(WriteState::Body (offset {}) (buffer-len {}))", + offset, + buffer.len() + ) } } } @@ -68,50 +78,50 @@ impl fmt::Debug for WriteState { impl Sink> for Io { type Error = io::Error; - fn poll_ready( - self: Pin<&mut Self>, - cx: &mut Context<'_> - ) -> Poll> { + fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { let this = Pin::into_inner(self); loop { log::trace!("{}: write: {:?}", this.id, this.write_state); match &mut this.write_state { WriteState::Init => return Poll::Ready(Ok(())), - WriteState::Header { header, buffer, ref mut offset } => { - match Pin::new(&mut this.io).poll_write(cx, &header[*offset ..]) { - Poll::Pending => return Poll::Pending, - Poll::Ready(Err(e)) => return Poll::Ready(Err(e)), - Poll::Ready(Ok(n)) => { - if n == 0 { - return Poll::Ready(Err(io::ErrorKind::WriteZero.into())) - } - *offset += n; - if *offset == header.len() { - if buffer.len() > 0 { - let buffer = std::mem::take(buffer); - this.write_state = WriteState::Body { buffer, offset: 0 }; - } else { - this.write_state = WriteState::Init; - } - } + WriteState::Header { + header, + buffer, + ref mut offset, + } => match Pin::new(&mut this.io).poll_write(cx, &header[*offset..]) { + Poll::Pending => return Poll::Pending, + Poll::Ready(Err(e)) => return Poll::Ready(Err(e)), + Poll::Ready(Ok(n)) => { + if n == 0 { + return Poll::Ready(Err(io::ErrorKind::WriteZero.into())); } - } - } - WriteState::Body { buffer, ref mut offset } => { - match Pin::new(&mut this.io).poll_write(cx, &buffer[*offset ..]) { - Poll::Pending => return Poll::Pending, - Poll::Ready(Err(e)) => return Poll::Ready(Err(e)), - Poll::Ready(Ok(n)) => { - if n == 0 { - return Poll::Ready(Err(io::ErrorKind::WriteZero.into())) - } - *offset += n; - if *offset == buffer.len() { + *offset += n; + if *offset == header.len() { + if buffer.len() > 0 { + let buffer = std::mem::take(buffer); + this.write_state = WriteState::Body { buffer, offset: 0 }; + } else { this.write_state = WriteState::Init; } } } - } + }, + WriteState::Body { + buffer, + ref mut offset, + } => match Pin::new(&mut this.io).poll_write(cx, &buffer[*offset..]) { + Poll::Pending => return Poll::Pending, + Poll::Ready(Err(e)) => return Poll::Ready(Err(e)), + Poll::Ready(Ok(n)) => { + if n == 0 { + return Poll::Ready(Err(io::ErrorKind::WriteZero.into())); + } + *offset += n; + if *offset == buffer.len() { + this.write_state = WriteState::Init; + } + } + }, } } } @@ -119,23 +129,21 @@ impl Sink> for Io { fn start_send(self: Pin<&mut Self>, f: Frame<()>) -> Result<(), Self::Error> { let header = header::encode(&f.header); let buffer = f.body; - self.get_mut().write_state = WriteState::Header { header, buffer, offset: 0 }; + self.get_mut().write_state = WriteState::Header { + header, + buffer, + offset: 0, + }; Ok(()) } - fn poll_flush( - self: Pin<&mut Self>, - cx: &mut Context<'_> - ) -> Poll> { + fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { let this = Pin::into_inner(self); ready!(this.poll_ready_unpin(cx))?; Pin::new(&mut this.io).poll_flush(cx) } - fn poll_close( - self: Pin<&mut Self>, - cx: &mut Context<'_> - ) -> Poll> { + fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { let this = Pin::into_inner(self); ready!(this.poll_ready_unpin(cx))?; Pin::new(&mut this.io).poll_close(cx) @@ -149,14 +157,14 @@ enum ReadState { /// Reading the frame header. Header { offset: usize, - buffer: [u8; header::HEADER_SIZE] + buffer: [u8; header::HEADER_SIZE], }, /// Reading the frame body. Body { header: header::Header<()>, offset: usize, - buffer: Vec - } + buffer: Vec, + }, } impl Stream for Io { @@ -170,68 +178,76 @@ impl Stream for Io { ReadState::Init => { this.read_state = ReadState::Header { offset: 0, - buffer: [0; header::HEADER_SIZE] + buffer: [0; header::HEADER_SIZE], }; } - ReadState::Header { ref mut offset, ref mut buffer } => { + ReadState::Header { + ref mut offset, + ref mut buffer, + } => { if *offset == header::HEADER_SIZE { - let header = - match header::decode(&buffer) { - Ok(hd) => hd, - Err(e) => return Poll::Ready(Some(Err(e.into()))) - }; + let header = match header::decode(&buffer) { + Ok(hd) => hd, + Err(e) => return Poll::Ready(Some(Err(e.into()))), + }; log::trace!("{}: read: {}", this.id, header); if header.tag() != header::Tag::Data { this.read_state = ReadState::Init; - return Poll::Ready(Some(Ok(Frame::new(header)))) + return Poll::Ready(Some(Ok(Frame::new(header)))); } let body_len = header.len().val() as usize; if body_len > this.max_body_len { - return Poll::Ready(Some(Err(FrameDecodeError::FrameTooLarge(body_len)))) + return Poll::Ready(Some(Err(FrameDecodeError::FrameTooLarge( + body_len, + )))); } this.read_state = ReadState::Body { header, offset: 0, - buffer: vec![0; body_len] + buffer: vec![0; body_len], }; - continue + continue; } - let buf = &mut buffer[*offset .. header::HEADER_SIZE]; + let buf = &mut buffer[*offset..header::HEADER_SIZE]; match ready!(Pin::new(&mut this.io).poll_read(cx, buf))? { 0 => { if *offset == 0 { - return Poll::Ready(None) + return Poll::Ready(None); } let e = FrameDecodeError::Io(io::ErrorKind::UnexpectedEof.into()); - return Poll::Ready(Some(Err(e))) + return Poll::Ready(Some(Err(e))); } - n => *offset += n + n => *offset += n, } } - ReadState::Body { ref header, ref mut offset, ref mut buffer } => { + ReadState::Body { + ref header, + ref mut offset, + ref mut buffer, + } => { let body_len = header.len().val() as usize; if *offset == body_len { let h = header.clone(); let v = std::mem::take(buffer); this.read_state = ReadState::Init; - return Poll::Ready(Some(Ok(Frame { header: h, body: v }))) + return Poll::Ready(Some(Ok(Frame { header: h, body: v }))); } - let buf = &mut buffer[*offset .. body_len]; + let buf = &mut buffer[*offset..body_len]; match ready!(Pin::new(&mut this.io).poll_read(cx, buf))? { 0 => { let e = FrameDecodeError::Io(io::ErrorKind::UnexpectedEof.into()); - return Poll::Ready(Some(Err(e))) + return Poll::Ready(Some(Err(e))); } - n => *offset += n + n => *offset += n, } } } @@ -242,17 +258,22 @@ impl Stream for Io { impl fmt::Debug for ReadState { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { match self { - ReadState::Init => { - f.write_str("(ReadState::Init)") - } + ReadState::Init => f.write_str("(ReadState::Init)"), ReadState::Header { offset, .. } => { write!(f, "(ReadState::Header (offset {}))", offset) } - ReadState::Body { header, offset, buffer } => { - write!(f, "(ReadState::Body (header {}) (offset {}) (buffer-len {}))", + ReadState::Body { + header, + offset, + buffer, + } => { + write!( + f, + "(ReadState::Body (header {}) (offset {}) (buffer-len {}))", header, offset, - buffer.len()) + buffer.len() + ) } } } @@ -267,7 +288,7 @@ pub enum FrameDecodeError { /// Decoding the frame header failed. Header(HeaderDecodeError), /// A data frame body length is larger than the configured maximum. - FrameTooLarge(usize) + FrameTooLarge(usize), } impl std::fmt::Display for FrameDecodeError { @@ -275,7 +296,7 @@ impl std::fmt::Display for FrameDecodeError { match self { FrameDecodeError::Io(e) => write!(f, "i/o error: {}", e), FrameDecodeError::Header(e) => write!(f, "decode error: {}", e), - FrameDecodeError::FrameTooLarge(n) => write!(f, "frame body is too large ({})", n) + FrameDecodeError::FrameTooLarge(n) => write!(f, "frame body is too large ({})", n), } } } @@ -285,7 +306,7 @@ impl std::error::Error for FrameDecodeError { match self { FrameDecodeError::Io(e) => Some(e), FrameDecodeError::Header(e) => Some(e), - FrameDecodeError::FrameTooLarge(_) => None + FrameDecodeError::FrameTooLarge(_) => None, } } } @@ -304,22 +325,21 @@ impl From for FrameDecodeError { #[cfg(test)] mod tests { + use super::*; use quickcheck::{Arbitrary, Gen, QuickCheck}; use rand::RngCore; - use super::*; impl Arbitrary for Frame<()> { fn arbitrary(g: &mut Gen) -> Self { let mut header: header::Header<()> = Arbitrary::arbitrary(g); - let body = - if header.tag() == header::Tag::Data { - header.set_len(header.len().val() % 4096); - let mut b = vec![0; header.len().val() as usize]; - rand::thread_rng().fill_bytes(&mut b); - b - } else { - Vec::new() - }; + let body = if header.tag() == header::Tag::Data { + header.set_len(header.len().val() % 4096); + let mut b = vec![0; header.len().val() as usize]; + rand::thread_rng().fill_bytes(&mut b); + b + } else { + Vec::new() + }; Frame { header, body } } } @@ -331,10 +351,10 @@ mod tests { let id = crate::connection::Id::random(); let mut io = Io::new(id, futures::io::Cursor::new(Vec::new()), f.body.len()); if io.send(f.clone()).await.is_err() { - return false + return false; } if io.flush().await.is_err() { - return false + return false; } io.io.set_position(0); if let Ok(Some(x)) = io.try_next().await { @@ -350,4 +370,3 @@ mod tests { .quickcheck(property as fn(Frame<()>) -> bool) } } - diff --git a/src/lib.rs b/src/lib.rs index 28e5a499..a5e2eb4c 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -34,9 +34,12 @@ mod tests; pub(crate) mod connection; -pub use crate::connection::{Connection, Mode, Control, Packet, Stream, into_stream}; +pub use crate::connection::{into_stream, Connection, Control, Mode, Packet, Stream}; pub use crate::error::ConnectionError; -pub use crate::frame::{FrameDecodeError, header::{HeaderDecodeError, StreamId}}; +pub use crate::frame::{ + header::{HeaderDecodeError, StreamId}, + FrameDecodeError, +}; const DEFAULT_CREDIT: u32 = 256 * 1024; // as per yamux specification @@ -80,7 +83,7 @@ pub enum WindowUpdateMode { /// - Endpoints *A* and *B* write at most *n* frames concurrently such that the sum /// of the frame lengths is less or equal to the available credit of *A* and *B* /// respectively. - OnRead + OnRead, } /// Yamux configuration. @@ -100,7 +103,7 @@ pub struct Config { max_num_streams: usize, window_update_mode: WindowUpdateMode, read_after_close: bool, - split_send_size: usize + split_send_size: usize, } impl Default for Config { @@ -111,7 +114,7 @@ impl Default for Config { max_num_streams: 8192, window_update_mode: WindowUpdateMode::OnRead, read_after_close: true, - split_send_size: DEFAULT_SPLIT_SEND_SIZE + split_send_size: DEFAULT_SPLIT_SEND_SIZE, } } } @@ -170,4 +173,3 @@ static_assertions::const_assert! { static_assertions::const_assert! { std::mem::size_of::() <= std::mem::size_of::() } - diff --git a/src/pause.rs b/src/pause.rs index 16dde865..1ada2185 100644 --- a/src/pause.rs +++ b/src/pause.rs @@ -9,7 +9,10 @@ // at https://opensource.org/licenses/MIT. use futures::{prelude::*, stream::FusedStream}; -use std::{pin::Pin, task::{Context, Poll, Waker}}; +use std::{ + pin::Pin, + task::{Context, Poll, Waker}, +}; /// Wraps a [`futures::stream::Stream`] and adds the ability to pause it. /// @@ -21,7 +24,7 @@ use std::{pin::Pin, task::{Context, Poll, Waker}}; pub(crate) struct Pausable { paused: bool, stream: S, - waker: Option + waker: Option, } impl Pausable { @@ -29,7 +32,7 @@ impl Pausable { Pausable { paused: false, stream, - waker: None + waker: None, } } @@ -58,7 +61,7 @@ impl Stream for Pausable { fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll> { if !self.paused { - return self.stream.poll_next_unpin(cx) + return self.stream.poll_next_unpin(cx); } self.waker = Some(cx.waker().clone()); Poll::Pending @@ -77,8 +80,8 @@ impl FusedStream for Pausable { #[cfg(test)] mod tests { - use futures::prelude::*; use super::Pausable; + use futures::prelude::*; #[test] fn pause_unpause() { diff --git a/src/tests.rs b/src/tests.rs index 854d1139..bdfc15e0 100644 --- a/src/tests.rs +++ b/src/tests.rs @@ -8,21 +8,29 @@ // at https://www.apache.org/licenses/LICENSE-2.0 and a copy of the MIT license // at https://opensource.org/licenses/MIT. -use crate::{Config, Connection, ConnectionError, Mode, Control, connection::State}; use crate::WindowUpdateMode; -use futures::{future, prelude::*}; +use crate::{connection::State, Config, Connection, ConnectionError, Control, Mode}; +use futures::channel::mpsc::{unbounded, UnboundedReceiver, UnboundedSender}; +use futures::executor::LocalPool; +use futures::future::join; use futures::io::AsyncReadExt; +use futures::task::{Spawn, SpawnExt}; +use futures::{future, prelude::*}; use quickcheck::{Arbitrary, Gen, QuickCheck, TestResult}; -use std::{fmt::Debug, io, net::{Ipv4Addr, SocketAddr, SocketAddrV4}}; -use tokio::{net::{TcpStream, TcpListener}, runtime::Runtime, task}; -use tokio_util::compat::{Compat, TokioAsyncReadCompatExt}; -use futures::channel::mpsc::{unbounded, UnboundedSender, UnboundedReceiver}; -use futures::executor::LocalPool; +use std::pin::Pin; use std::sync::{Arc, Mutex}; use std::task::{Context, Poll, Waker}; -use std::pin::Pin; -use futures::future::join; -use futures::task::{Spawn, SpawnExt}; +use std::{ + fmt::Debug, + io, + net::{Ipv4Addr, SocketAddr, SocketAddrV4}, +}; +use tokio::{ + net::{TcpListener, TcpStream}, + runtime::Runtime, + task, +}; +use tokio_util::compat::{Compat, TokioAsyncReadCompatExt}; #[test] fn prop_config_send_recv_single() { @@ -46,14 +54,18 @@ fn prop_config_send_recv_single() { let connection = Connection::new(socket, cfg2.0, Mode::Client); let control = connection.control(); task::spawn(crate::into_stream(connection).for_each(|_| future::ready(()))); - send_recv_single(control, iter.clone()).await.expect("send_recv") + send_recv_single(control, iter.clone()) + .await + .expect("send_recv") }; let result = futures::join!(server, client).1; TestResult::from_bool(result.len() == num_requests && result.into_iter().eq(iter)) }) } - QuickCheck::new().tests(10).quickcheck(prop as fn(_, _, _) -> _) + QuickCheck::new() + .tests(10) + .quickcheck(prop as fn(_, _, _) -> _) } #[test] @@ -85,14 +97,16 @@ fn prop_config_send_recv_multi() { TestResult::from_bool(result.len() == num_requests && result.into_iter().eq(iter)) }) } - QuickCheck::new().tests(10).quickcheck(prop as fn(_, _, _) -> _) + QuickCheck::new() + .tests(10) + .quickcheck(prop as fn(_, _, _) -> _) } #[test] fn prop_send_recv() { fn prop(msgs: Vec) -> TestResult { if msgs.is_empty() { - return TestResult::discard() + return TestResult::discard(); } let rt = Runtime::new().unwrap(); rt.block_on(async move { @@ -147,7 +161,7 @@ fn prop_max_streams() { let mut control = connection.control(); task::spawn(crate::into_stream(connection).for_each(|_| future::ready(()))); let mut v = Vec::new(); - for _ in 0 .. max_streams { + for _ in 0..max_streams { v.push(control.open_stream().await.expect("open_stream")) } if let Err(ConnectionError::TooManyStreams) = control.open_stream().await { @@ -172,7 +186,9 @@ fn prop_send_recv_half_closed() { let server = async { let socket = listener.accept().await.expect("accept").0.compat(); let mut connection = Connection::new(socket, Config::default(), Mode::Server); - let mut stream = connection.next_stream().await + let mut stream = connection + .next_stream() + .await .expect("S: next_stream") .expect("S: some stream"); task::spawn(crate::into_stream(connection).for_each(|_| future::ready(()))); @@ -236,21 +252,27 @@ fn write_deadlock() { // Create and spawn a "server" that echoes every message back to the client. let server = Connection::new(server_endpoint, Config::default(), Mode::Server); - pool.spawner().spawn_obj(async move { - crate::into_stream(server).try_for_each_concurrent( - None, |mut stream| async move { - { - let (mut r, mut w) = AsyncReadExt::split(&mut stream); - // Write back the bytes received. This may buffer internally. - futures::io::copy(&mut r, &mut w).await?; - } - log::debug!("S: stream {} done.", stream.id()); - stream.close().await?; - Ok(()) - }) - .await - .expect("server failed") - }.boxed().into()).unwrap(); + pool.spawner() + .spawn_obj( + async move { + crate::into_stream(server) + .try_for_each_concurrent(None, |mut stream| async move { + { + let (mut r, mut w) = AsyncReadExt::split(&mut stream); + // Write back the bytes received. This may buffer internally. + futures::io::copy(&mut r, &mut w).await?; + } + log::debug!("S: stream {} done.", stream.id()); + stream.close().await?; + Ok(()) + }) + .await + .expect("server failed") + } + .boxed() + .into(), + ) + .unwrap(); // Create and spawn a "client" that sends messages expected to be echoed // by the server. @@ -258,31 +280,44 @@ fn write_deadlock() { let mut ctrl = client.control(); // Continuously advance the Yamux connection of the client in a background task. - pool.spawner().spawn_obj( - crate::into_stream(client).for_each(|_| { - panic!("Unexpected inbound stream for client"); - #[allow(unreachable_code)] - future::ready(()) - }).boxed().into() - ).unwrap(); + pool.spawner() + .spawn_obj( + crate::into_stream(client) + .for_each(|_| { + panic!("Unexpected inbound stream for client"); + #[allow(unreachable_code)] + future::ready(()) + }) + .boxed() + .into(), + ) + .unwrap(); // Send the message, expecting it to be echo'd. - pool.run_until(pool.spawner().spawn_with_handle(async move { - let stream = ctrl.open_stream().await.unwrap(); - let (mut reader, mut writer) = AsyncReadExt::split(stream); - let mut b = vec![0; msg.len()]; - // Write & read concurrently, so that the client is able - // to start reading the echo'd bytes before it even finished - // sending them all. - let _ = join( - writer.write_all(msg.as_ref()).map_err(|e| panic!(e)), - reader.read_exact(&mut b[..]).map_err(|e| panic!(e)), - ).await; - let mut stream = reader.reunite(writer).unwrap(); - stream.close().await.unwrap(); - log::debug!("C: Stream {} done.", stream.id()); - assert_eq!(b, msg); - }.boxed()).unwrap()); + pool.run_until( + pool.spawner() + .spawn_with_handle( + async move { + let stream = ctrl.open_stream().await.unwrap(); + let (mut reader, mut writer) = AsyncReadExt::split(stream); + let mut b = vec![0; msg.len()]; + // Write & read concurrently, so that the client is able + // to start reading the echo'd bytes before it even finished + // sending them all. + let _ = join( + writer.write_all(msg.as_ref()).map_err(|e| panic!(e)), + reader.read_exact(&mut b[..]).map_err(|e| panic!(e)), + ) + .await; + let mut stream = reader.reunite(writer).unwrap(); + stream.close().await.unwrap(); + log::debug!("C: Stream {} done.", stream.id()); + assert_eq!(b, msg); + } + .boxed(), + ) + .unwrap(), + ); } #[derive(Clone, Debug)] @@ -346,7 +381,7 @@ async fn repeat_echo(c: Connection>) -> Result<(), ConnectionE /// collect the response. The sequence of responses will be returned. async fn send_recv(mut control: Control, iter: I) -> Result>, ConnectionError> where - I: Iterator> + I: Iterator>, { let mut result = Vec::new(); @@ -379,7 +414,7 @@ where /// sequence of responses will be returned. async fn send_recv_single(mut control: Control, iter: I) -> Result>, ConnectionError> where - I: Iterator> + I: Iterator>, { let stream = control.open_stream().await?; log::debug!("C: new stream: {}", stream); @@ -436,7 +471,7 @@ mod bounded { pub fn channel( (name_a, capacity_a): (&'static str, usize), - (name_b, capacity_b): (&'static str, usize) + (name_b, capacity_b): (&'static str, usize), ) -> (Endpoint, Endpoint) { let (a_to_b_sender, a_to_b_receiver) = unbounded(); let (b_to_a_sender, b_to_a_receiver) = unbounded(); @@ -475,8 +510,10 @@ mod bounded { ) -> Poll> { if self.recv_buf.is_empty() { match ready!(self.recv.poll_next_unpin(cx)) { - Some(bytes) => { self.recv_buf = bytes; } - None => return Poll::Ready(Ok(0)) + Some(bytes) => { + self.recv_buf = bytes; + } + None => return Poll::Ready(Ok(0)), } } @@ -486,19 +523,32 @@ mod bounded { let mut guard = self.recv_guard.lock().unwrap(); if let Some(waker) = guard.waker.take() { - log::debug!("{}: read: notifying waker after read of {} bytes", self.name, n); + log::debug!( + "{}: read: notifying waker after read of {} bytes", + self.name, + n + ); waker.wake(); } guard.size -= n; - log::debug!("{}: read: channel: {}/{}", self.name, guard.size, self.capacity); + log::debug!( + "{}: read: channel: {}/{}", + self.name, + guard.size, + self.capacity + ); Poll::Ready(Ok(n)) } } impl AsyncWrite for Endpoint { - fn poll_write(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll> { + fn poll_write( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { debug_assert!(buf.len() > 0); let mut guard = self.send_guard.lock().unwrap(); let n = std::cmp::min(self.capacity - guard.size, buf.len()); @@ -508,11 +558,17 @@ mod bounded { return Poll::Pending; } - self.send.unbounded_send(buf[0..n].to_vec()) + self.send + .unbounded_send(buf[0..n].to_vec()) .map_err(|e| Error::new(ErrorKind::ConnectionAborted, e))?; guard.size += n; - log::debug!("{}: write: channel: {}/{}", self.name, guard.size, self.capacity); + log::debug!( + "{}: write: channel: {}/{}", + self.name, + guard.size, + self.capacity + ); Poll::Ready(Ok(n)) } diff --git a/tests/concurrent.rs b/tests/concurrent.rs index 1bfe835a..d816eb87 100644 --- a/tests/concurrent.rs +++ b/tests/concurrent.rs @@ -9,11 +9,14 @@ // at https://opensource.org/licenses/MIT. use futures::{channel::mpsc, prelude::*}; -use std::{net::{Ipv4Addr, SocketAddr, SocketAddrV4}, sync::Arc}; -use tokio::{net::TcpSocket, task, runtime::Runtime}; +use quickcheck::{Arbitrary, Gen, QuickCheck}; +use std::{ + net::{Ipv4Addr, SocketAddr, SocketAddrV4}, + sync::Arc, +}; +use tokio::{net::TcpSocket, runtime::Runtime, task}; use tokio_util::compat::TokioAsyncReadCompatExt; use yamux::{Config, Connection, Mode, WindowUpdateMode}; -use quickcheck::{Arbitrary, Gen, QuickCheck}; const PAYLOAD_SIZE: usize = 128 * 1024; async fn roundtrip( @@ -25,8 +28,12 @@ async fn roundtrip( let listener = { let socket = TcpSocket::new_v4().expect("new_v4"); if let Some(size) = tcp_buffer_sizes { - socket.set_send_buffer_size(size.send).expect("send size set"); - socket.set_recv_buffer_size(size.recv).expect("recv size set"); + socket + .set_send_buffer_size(size.send) + .expect("send size set"); + socket + .set_recv_buffer_size(size.recv) + .expect("recv size set"); } socket.bind(address).expect("bind"); socket.listen(1024).expect("listen") @@ -68,8 +75,12 @@ async fn roundtrip( let conn = { let socket = TcpSocket::new_v4().expect("new_v4"); if let Some(size) = tcp_buffer_sizes { - socket.set_send_buffer_size(size.send).expect("send size set"); - socket.set_recv_buffer_size(size.recv).expect("recv size set"); + socket + .set_send_buffer_size(size.send) + .expect("send size set"); + socket + .set_recv_buffer_size(size.recv) + .expect("recv size set"); } let stream = socket.connect(address).await.expect("connect").compat(); Connection::new(stream, client_cfg, Mode::Client) @@ -77,14 +88,16 @@ async fn roundtrip( let (tx, rx) = mpsc::unbounded(); let mut ctrl = conn.control(); task::spawn(yamux::into_stream(conn).for_each(|_| future::ready(()))); - for _ in 0 .. nstreams { + for _ in 0..nstreams { let data = data.clone(); let tx = tx.clone(); let mut ctrl = ctrl.clone(); task::spawn(async move { let mut stream = ctrl.open_stream().await?; log::debug!("C: opened new stream {}", stream.id()); - stream.write_all(&(data.len() as u32).to_be_bytes()[..]).await?; + stream + .write_all(&(data.len() as u32).to_be_bytes()[..]) + .await?; stream.write_all(&data).await?; stream.close().await?; log::debug!("C: {}: wrote {} bytes", stream.id(), data.len()); @@ -96,7 +109,10 @@ async fn roundtrip( Ok::<(), yamux::ConnectionError>(()) }); } - let n = rx.take(nstreams).fold(0, |acc, n| future::ready(acc + n)).await; + let n = rx + .take(nstreams) + .fold(0, |acc, n| future::ready(acc + n)) + .await; ctrl.close().await.expect("close connection"); assert_eq!(nstreams, n) } @@ -111,9 +127,9 @@ struct TcpBufferSizes { impl Arbitrary for TcpBufferSizes { fn arbitrary(g: &mut Gen) -> Self { let send = if bool::arbitrary(g) { - 16*1024 + 16 * 1024 } else { - 32*1024 + 32 * 1024 }; // Have receive buffer size be some multiple of send buffer size. @@ -135,7 +151,12 @@ fn concurrent_streams() { let data = Arc::new(vec![0x42; PAYLOAD_SIZE]); let addr = SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(127, 0, 0, 1), 0)); - Runtime::new().expect("new runtime").block_on(roundtrip(addr, 1000, data, tcp_buffer_sizes)); + Runtime::new().expect("new runtime").block_on(roundtrip( + addr, + 1000, + data, + tcp_buffer_sizes, + )); } QuickCheck::new().tests(3).quickcheck(prop as fn(_) -> _)