From 68a9e3d07947a4f261dd38463fc71dfb86f6f60e Mon Sep 17 00:00:00 2001 From: Max Inden Date: Wed, 6 Dec 2023 13:24:22 +0100 Subject: [PATCH] feat: auto-tune (dynamic) stream receive window (#176) - Send Yamux' Pings on an interval to measure the connection round-trip-time. - Dynamically grow the stream receive window based on the round-trip-time and the estimated bandwidth. --- CHANGELOG.md | 3 + Cargo.toml | 2 +- quickcheck-ext/Cargo.toml | 13 + quickcheck-ext/src/lib.rs | 46 +++ test-harness/Cargo.toml | 3 +- test-harness/src/lib.rs | 16 +- test-harness/tests/ack_backlog.rs | 4 +- yamux/Cargo.toml | 4 +- yamux/src/connection.rs | 100 +++---- yamux/src/connection/rtt.rs | 140 +++++++++ yamux/src/connection/stream.rs | 163 ++++++----- yamux/src/connection/stream/flow_control.rs | 304 ++++++++++++++++++++ yamux/src/frame.rs | 16 ++ yamux/src/frame/io.rs | 15 +- yamux/src/lib.rs | 98 +++++-- 15 files changed, 769 insertions(+), 158 deletions(-) create mode 100644 quickcheck-ext/Cargo.toml create mode 100644 quickcheck-ext/src/lib.rs create mode 100644 yamux/src/connection/rtt.rs create mode 100644 yamux/src/connection/stream/flow_control.rs diff --git a/CHANGELOG.md b/CHANGELOG.md index afd00caa..295cf9df 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,8 @@ # 0.13.0 +- Introduce dynamic stream receive window auto-tuning. + While low-resourced deployments maintain the benefit of small buffers, high resource deployments eventually end-up with a window of roughly the bandwidth-delay-product (ideal) and are thus able to use the entire available bandwidth. + See [PR 176](https://github.com/libp2p/rust-yamux/pull/176) for performance results and details on the implementation. - Remove `WindowUpdateMode`. Behavior will always be `WindowUpdateMode::OnRead`, thus enabling flow-control and enforcing backpressure. See [PR 178](https://github.com/libp2p/rust-yamux/pull/178). diff --git a/Cargo.toml b/Cargo.toml index d28c2d96..7c60ecb3 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,3 +1,3 @@ [workspace] -members = ["yamux", "test-harness"] +members = ["yamux", "test-harness", "quickcheck-ext"] resolver = "2" diff --git a/quickcheck-ext/Cargo.toml b/quickcheck-ext/Cargo.toml new file mode 100644 index 00000000..50aed154 --- /dev/null +++ b/quickcheck-ext/Cargo.toml @@ -0,0 +1,13 @@ +[package] +name = "quickcheck-ext" +version = "0.1.0" +edition = "2021" +publish = false +license = "Unlicense/MIT" + +[package.metadata.release] +release = false + +[dependencies] +quickcheck = "1" +num-traits = "0.2" diff --git a/quickcheck-ext/src/lib.rs b/quickcheck-ext/src/lib.rs new file mode 100644 index 00000000..a3b9ce26 --- /dev/null +++ b/quickcheck-ext/src/lib.rs @@ -0,0 +1,46 @@ +#![cfg_attr(docsrs, feature(doc_cfg, doc_auto_cfg))] + +pub use quickcheck::*; + +use core::ops::Range; +use num_traits::sign::Unsigned; + +pub trait GenRange { + fn gen_range(&mut self, _range: Range) -> T; + + fn gen_index(&mut self, ubound: usize) -> usize { + if ubound <= (core::u32::MAX as usize) { + self.gen_range(0..ubound as u32) as usize + } else { + self.gen_range(0..ubound) + } + } +} + +impl GenRange for Gen { + fn gen_range(&mut self, range: Range) -> T { + ::arbitrary(self) % (range.end - range.start) + range.start + } +} + +pub trait SliceRandom { + fn shuffle(&mut self, arr: &mut [T]); + fn choose_multiple<'a, T>( + &mut self, + arr: &'a [T], + amount: usize, + ) -> std::iter::Take> { + let mut v: Vec<&T> = arr.iter().collect(); + self.shuffle(&mut v); + v.into_iter().take(amount) + } +} + +impl SliceRandom for Gen { + fn shuffle(&mut self, arr: &mut [T]) { + for i in (1..arr.len()).rev() { + // invariant: elements with index > i have been locked in place. + arr.swap(i, self.gen_index(i + 1)); + } + } +} diff --git a/test-harness/Cargo.toml b/test-harness/Cargo.toml index 376d2900..4a3d6daa 100644 --- a/test-harness/Cargo.toml +++ b/test-harness/Cargo.toml @@ -7,7 +7,7 @@ publish = false [dependencies] yamux = { path = "../yamux" } futures = "0.3.4" -quickcheck = "1.0" +quickcheck = { package = "quickcheck-ext", path = "../quickcheck-ext" } tokio = { version = "1.0", features = ["net", "rt-multi-thread", "macros", "time"] } tokio-util = { version = "0.7", features = ["compat"] } anyhow = "1" @@ -17,7 +17,6 @@ log = "0.4.17" criterion = "0.5" env_logger = "0.10" futures = "0.3.4" -quickcheck = "1.0" tokio = { version = "1.0", features = ["net", "rt-multi-thread", "macros", "time"] } tokio-util = { version = "0.7", features = ["compat"] } constrained-connection = "0.1" diff --git a/test-harness/src/lib.rs b/test-harness/src/lib.rs index e02c12e0..58412ac3 100644 --- a/test-harness/src/lib.rs +++ b/test-harness/src/lib.rs @@ -76,7 +76,7 @@ where .try_for_each_concurrent(None, |mut stream| async move { { let (mut r, mut w) = AsyncReadExt::split(&mut stream); - futures::io::copy(&mut r, &mut w).await?; + futures::io::copy(&mut r, &mut w).await.unwrap(); } stream.close().await?; Ok(()) @@ -447,9 +447,21 @@ pub struct TestConfig(pub Config); impl Arbitrary for TestConfig { fn arbitrary(g: &mut Gen) -> Self { + use quickcheck::GenRange; + let mut c = Config::default(); + let max_num_streams = 512; + c.set_read_after_close(Arbitrary::arbitrary(g)); - c.set_receive_window(256 * 1024 + u32::arbitrary(g) % (768 * 1024)); + c.set_max_num_streams(max_num_streams); + if bool::arbitrary(g) { + c.set_max_connection_receive_window(Some( + g.gen_range(max_num_streams * (yamux::DEFAULT_CREDIT as usize)..usize::MAX), + )); + } else { + c.set_max_connection_receive_window(None); + } + TestConfig(c) } } diff --git a/test-harness/tests/ack_backlog.rs b/test-harness/tests/ack_backlog.rs index 2aa4e33a..30030425 100644 --- a/test-harness/tests/ack_backlog.rs +++ b/test-harness/tests/ack_backlog.rs @@ -197,8 +197,8 @@ where this.worker_streams.push(ping_pong(stream.unwrap()).boxed()); continue; } - (Poll::Ready(_), Some(_)) => { - panic!("should not be able to open stream if server hasn't acknowledged existing streams") + (Poll::Ready(e), Some(_)) => { + panic!("should not be able to open stream if server hasn't acknowledged existing streams: {:?}", e) } (Poll::Pending, None) => {} } diff --git a/yamux/Cargo.toml b/yamux/Cargo.toml index 69973e05..3cac8f02 100644 --- a/yamux/Cargo.toml +++ b/yamux/Cargo.toml @@ -10,7 +10,7 @@ repository = "https://github.com/paritytech/yamux" edition = "2021" [dependencies] -futures = { version = "0.3.12", default-features = false, features = ["std"] } +futures = { version = "0.3.12", default-features = false, features = ["std", "executor"] } log = "0.4.8" nohash-hasher = "0.2" parking_lot = "0.12" @@ -20,4 +20,4 @@ pin-project = "1.1.0" [dev-dependencies] futures = { version = "0.3.12", default-features = false, features = ["executor"] } -quickcheck = "1.0" +quickcheck = { package = "quickcheck-ext", path = "../quickcheck-ext" } diff --git a/yamux/src/connection.rs b/yamux/src/connection.rs index 0474f79d..27beb28e 100644 --- a/yamux/src/connection.rs +++ b/yamux/src/connection.rs @@ -14,6 +14,7 @@ mod cleanup; mod closing; +mod rtt; mod stream; use crate::tagged_stream::TaggedStream; @@ -287,8 +288,15 @@ struct Active { pending_frames: VecDeque>, new_outbound_stream_waker: Option, -} + rtt: rtt::Rtt, + + /// A stream's `max_stream_receive_window` can grow beyond [`DEFAULT_CREDIT`], see + /// [`Stream::next_window_update`]. This field is the sum of the bytes by which all streams' + /// `max_stream_receive_window` have each exceeded [`DEFAULT_CREDIT`]. Used to enforce + /// [`Config::max_connection_receive_window`]. + accumulated_max_stream_windows: Arc>, +} /// `Stream` to `Connection` commands. #[derive(Debug)] pub(crate) enum StreamCommand { @@ -300,15 +308,13 @@ pub(crate) enum StreamCommand { /// Possible actions as a result of incoming frame handling. #[derive(Debug)] -enum Action { +pub(crate) enum Action { /// Nothing to be done. None, /// A new stream has been opened by the remote. New(Stream), /// A ping should be answered. Ping(Frame), - /// A stream should be reset. - Reset(Frame), /// The connection should be terminated. Terminate(Frame), } @@ -341,7 +347,7 @@ impl Active { fn new(socket: T, cfg: Config, mode: Mode) -> Self { let id = Id::random(); log::debug!("new connection: {} ({:?})", id, mode); - let socket = frame::Io::new(id, socket, cfg.max_buffer_size).fuse(); + let socket = frame::Io::new(id, socket).fuse(); Active { id, mode, @@ -356,6 +362,8 @@ impl Active { }, pending_frames: VecDeque::default(), new_outbound_stream_waker: None, + rtt: rtt::Rtt::new(), + accumulated_max_stream_windows: Default::default(), } } @@ -376,6 +384,14 @@ impl Active { fn poll(&mut self, cx: &mut Context<'_>) -> Poll> { loop { if self.socket.poll_ready_unpin(cx).is_ready() { + // Note `next_ping` does not register a waker and thus if not called regularly (idle + // connection) no ping is sent. This is deliberate as an idle connection does not + // need RTT measurements to increase its stream receive window. + if let Some(frame) = self.rtt.next_ping() { + self.socket.start_send_unpin(frame.into())?; + continue; + } + if let Some(frame) = self.pending_frames.pop_front() { self.socket.start_send_unpin(frame)?; continue; @@ -439,20 +455,7 @@ impl Active { log::trace!("{}: creating new outbound stream", self.id); let id = self.next_stream_id()?; - let extra_credit = self.config.receive_window - DEFAULT_CREDIT; - - if extra_credit > 0 { - let mut frame = Frame::window_update(id, extra_credit); - frame.header_mut().syn(); - log::trace!("{}/{}: sending initial {}", self.id, id, frame.header()); - self.pending_frames.push_back(frame.into()); - } - - let mut stream = self.make_new_outbound_stream(id, self.config.receive_window); - - if extra_credit == 0 { - stream.set_flag(stream::Flag::Syn) - } + let stream = self.make_new_outbound_stream(id); log::debug!("{}: new outbound {} of {}", self.id, stream, self); self.streams.insert(id, stream.clone_shared()); @@ -537,7 +540,9 @@ impl Active { fn on_frame(&mut self, frame: Frame<()>) -> Result> { log::trace!("{}: received: {}", self.id, frame.header()); - if frame.header().flags().contains(header::ACK) { + if frame.header().flags().contains(header::ACK) + && matches!(frame.header().tag(), Tag::Data | Tag::WindowUpdate) + { let id = frame.header().stream_id(); if let Some(stream) = self.streams.get(&id) { stream @@ -565,10 +570,6 @@ impl Active { log::trace!("{}/{}: pong", self.id, f.header().stream_id()); self.pending_frames.push_back(f.into()); } - Action::Reset(f) => { - log::trace!("{}/{}: sending reset", self.id, f.header().stream_id()); - self.pending_frames.push_back(f.into()); - } Action::Terminate(f) => { log::trace!("{}: sending term", self.id); self.pending_frames.push_back(f.into()); @@ -620,23 +621,22 @@ impl Active { log::error!("{}: maximum number of streams reached", self.id); return Action::Terminate(Frame::internal_error()); } - let mut stream = self.make_new_inbound_stream(stream_id, DEFAULT_CREDIT); + let stream = self.make_new_inbound_stream(stream_id, DEFAULT_CREDIT); { let mut shared = stream.shared(); if is_finish { shared.update_state(self.id, stream_id, State::RecvClosed); } - shared.window = shared.window.saturating_sub(frame.body_len()); + shared.consume_receive_window(frame.body_len()); shared.buffer.push(frame.into_body()); } - stream.set_flag(stream::Flag::Ack); self.streams.insert(stream_id, stream.clone_shared()); return Action::New(stream); } if let Some(s) = self.streams.get_mut(&stream_id) { let mut shared = s.lock(); - if frame.body().len() > shared.window as usize { + if frame.body_len() > shared.receive_window() { log::error!( "{}/{}: frame body larger than window of stream", self.id, @@ -647,18 +647,7 @@ impl Active { if is_finish { shared.update_state(self.id, stream_id, State::RecvClosed); } - let max_buffer_size = self.config.max_buffer_size; - if shared.buffer.len() >= max_buffer_size { - log::error!( - "{}/{}: buffer of stream grows beyond limit", - self.id, - stream_id - ); - let mut header = Header::data(stream_id, 0); - header.rst(); - return Action::Reset(Frame::new(header)); - } - shared.window = shared.window.saturating_sub(frame.body_len()); + shared.consume_receive_window(frame.body_len()); shared.buffer.push(frame.into_body()); if let Some(w) = shared.reader.take() { w.wake() @@ -718,8 +707,7 @@ impl Active { } let credit = frame.header().credit() + DEFAULT_CREDIT; - let mut stream = self.make_new_inbound_stream(stream_id, credit); - stream.set_flag(stream::Flag::Ack); + let stream = self.make_new_inbound_stream(stream_id, credit); if is_finish { stream @@ -732,7 +720,7 @@ impl Active { if let Some(s) = self.streams.get_mut(&stream_id) { let mut shared = s.lock(); - shared.credit += frame.header().credit(); + shared.increase_send_window_by(frame.header().credit()); if is_finish { shared.update_state(self.id, stream_id, State::RecvClosed); } @@ -761,15 +749,14 @@ impl Active { fn on_ping(&mut self, frame: &Frame) -> Action { let stream_id = frame.header().stream_id(); if frame.header().flags().contains(header::ACK) { - // pong - return Action::None; + return self.rtt.handle_pong(frame.nonce()); } if stream_id == CONNECTION_ID || self.streams.contains_key(&stream_id) { let mut hdr = Header::ping(frame.header().nonce()); hdr.ack(); return Action::Ping(Frame::new(hdr)); } - log::trace!( + log::debug!( "{}/{}: ping for unknown stream, possibly dropped earlier: {:?}", self.id, stream_id, @@ -794,10 +781,18 @@ impl Active { waker.wake(); } - Stream::new_inbound(id, self.id, config, credit, sender) + Stream::new_inbound( + id, + self.id, + config, + credit, + sender, + self.rtt.clone(), + self.accumulated_max_stream_windows.clone(), + ) } - fn make_new_outbound_stream(&mut self, id: StreamId, window: u32) -> Stream { + fn make_new_outbound_stream(&mut self, id: StreamId) -> Stream { let config = self.config.clone(); let (sender, receiver) = mpsc::channel(10); // 10 is an arbitrary number. @@ -806,7 +801,14 @@ impl Active { waker.wake(); } - Stream::new_outbound(id, self.id, config, window, sender) + Stream::new_outbound( + id, + self.id, + config, + sender, + self.rtt.clone(), + self.accumulated_max_stream_windows.clone(), + ) } fn next_stream_id(&mut self) -> Result { diff --git a/yamux/src/connection/rtt.rs b/yamux/src/connection/rtt.rs new file mode 100644 index 00000000..56938d7f --- /dev/null +++ b/yamux/src/connection/rtt.rs @@ -0,0 +1,140 @@ +// Copyright (c) 2023 Protocol Labs. +// +// Licensed under the Apache License, Version 2.0 or MIT license, at your option. +// +// A copy of the Apache License, Version 2.0 is included in the software as +// LICENSE-APACHE and a copy of the MIT license is included in the software +// as LICENSE-MIT. You may also obtain a copy of the Apache License, Version 2.0 +// at https://www.apache.org/licenses/LICENSE-2.0 and a copy of the MIT license +// at https://opensource.org/licenses/MIT. + +//! Connection round-trip time measurement + +use std::{ + sync::Arc, + time::{Duration, Instant}, +}; + +use parking_lot::Mutex; + +use crate::connection::Action; +use crate::frame::{header::Ping, Frame}; + +const PING_INTERVAL: Duration = Duration::from_secs(10); + +#[derive(Clone, Debug)] +pub(crate) struct Rtt(Arc>); + +impl Rtt { + pub(crate) fn new() -> Self { + Self(Arc::new(Mutex::new(RttInner { + rtt: None, + state: RttState::Waiting { + next: Instant::now(), + }, + }))) + } + + pub(crate) fn next_ping(&mut self) -> Option> { + let state = &mut self.0.lock().state; + + match state { + RttState::AwaitingPong { .. } => return None, + RttState::Waiting { next } => { + if *next > Instant::now() { + return None; + } + } + } + + let nonce = rand::random(); + *state = RttState::AwaitingPong { + sent_at: Instant::now(), + nonce, + }; + log::debug!("sending ping {nonce}"); + Some(Frame::ping(nonce)) + } + + pub(crate) fn handle_pong(&mut self, received_nonce: u32) -> Action { + let inner = &mut self.0.lock(); + + let (sent_at, expected_nonce) = match inner.state { + RttState::Waiting { .. } => { + log::error!("received unexpected pong {received_nonce}"); + return Action::Terminate(Frame::protocol_error()); + } + RttState::AwaitingPong { sent_at, nonce } => (sent_at, nonce), + }; + + if received_nonce != expected_nonce { + log::error!("received pong with {received_nonce} but expected {expected_nonce}"); + return Action::Terminate(Frame::protocol_error()); + } + + let rtt = sent_at.elapsed(); + inner.rtt = Some(rtt); + log::debug!("received pong {received_nonce}, estimated round-trip-time {rtt:?}"); + + inner.state = RttState::Waiting { + next: Instant::now() + PING_INTERVAL, + }; + + return Action::None; + } + + pub(crate) fn get(&self) -> Option { + self.0.lock().rtt + } +} + +#[cfg(test)] +impl quickcheck::Arbitrary for Rtt { + fn arbitrary(g: &mut quickcheck::Gen) -> Self { + Self(Arc::new(Mutex::new(RttInner::arbitrary(g)))) + } +} + +#[derive(Debug)] +#[cfg_attr(test, derive(Clone))] +struct RttInner { + state: RttState, + rtt: Option, +} + +#[cfg(test)] +impl quickcheck::Arbitrary for RttInner { + fn arbitrary(g: &mut quickcheck::Gen) -> Self { + Self { + state: RttState::arbitrary(g), + rtt: if bool::arbitrary(g) { + Some(Duration::arbitrary(g)) + } else { + None + }, + } + } +} + +#[derive(Debug)] +#[cfg_attr(test, derive(Clone))] +enum RttState { + AwaitingPong { sent_at: Instant, nonce: u32 }, + Waiting { next: Instant }, +} + +#[cfg(test)] +impl quickcheck::Arbitrary for RttState { + fn arbitrary(g: &mut quickcheck::Gen) -> Self { + if bool::arbitrary(g) { + RttState::AwaitingPong { + sent_at: Instant::now(), + nonce: u32::arbitrary(g), + } + } else { + RttState::Waiting { + next: Instant::now(), + } + } + } +} diff --git a/yamux/src/connection/stream.rs b/yamux/src/connection/stream.rs index 84ab08f8..1f48e1b9 100644 --- a/yamux/src/connection/stream.rs +++ b/yamux/src/connection/stream.rs @@ -8,16 +8,18 @@ // at https://www.apache.org/licenses/LICENSE-2.0 and a copy of the MIT license // at https://opensource.org/licenses/MIT. +use crate::connection::rtt::Rtt; use crate::frame::header::ACK; use crate::{ chunks::Chunks, - connection::{self, StreamCommand}, + connection::{self, rtt, StreamCommand}, frame::{ header::{Data, Header, StreamId, WindowUpdate}, Frame, }, Config, DEFAULT_CREDIT, }; +use flow_control::FlowController; use futures::{ channel::mpsc, future::Either, @@ -25,7 +27,6 @@ use futures::{ ready, SinkExt, }; use parking_lot::{Mutex, MutexGuard}; -use std::convert::TryInto; use std::{ fmt, io, pin::Pin, @@ -33,6 +34,8 @@ use std::{ task::{Context, Poll, Waker}, }; +mod flow_control; + /// The state of a Yamux stream. #[derive(Copy, Clone, Debug, PartialEq, Eq)] pub enum State { @@ -116,16 +119,24 @@ impl Stream { id: StreamId, conn: connection::Id, config: Arc, - credit: u32, + send_window: u32, sender: mpsc::Sender, + rtt: rtt::Rtt, + accumulated_max_stream_windows: Arc>, ) -> Self { Self { id, conn, config: config.clone(), sender, - flag: Flag::None, - shared: Arc::new(Mutex::new(Shared::new(DEFAULT_CREDIT, credit, config))), + flag: Flag::Ack, + shared: Arc::new(Mutex::new(Shared::new( + DEFAULT_CREDIT, + send_window, + accumulated_max_stream_windows, + rtt, + config, + ))), } } @@ -133,16 +144,23 @@ impl Stream { id: StreamId, conn: connection::Id, config: Arc, - window: u32, sender: mpsc::Sender, + rtt: rtt::Rtt, + accumulated_max_stream_windows: Arc>, ) -> Self { Self { id, conn, config: config.clone(), sender, - flag: Flag::None, - shared: Arc::new(Mutex::new(Shared::new(window, DEFAULT_CREDIT, config))), + flag: Flag::Syn, + shared: Arc::new(Mutex::new(Shared::new( + DEFAULT_CREDIT, + DEFAULT_CREDIT, + accumulated_max_stream_windows, + rtt, + config, + ))), } } @@ -164,11 +182,6 @@ impl Stream { self.shared().is_pending_ack() } - /// Set the flag that should be set on the next outbound frame header. - pub(crate) fn set_flag(&mut self, flag: Flag) { - self.flag = flag - } - pub(crate) fn shared(&self) -> MutexGuard<'_, Shared> { self.shared.lock() } @@ -200,25 +213,26 @@ impl Stream { /// Send new credit to the sending side via a window update message if /// permitted. fn send_window_update(&mut self, cx: &mut Context) -> Poll> { - let mut shared = self.shared.lock(); - - if let Some(credit) = shared.next_window_update() { - ready!(self - .sender - .poll_ready(cx) - .map_err(|_| self.write_zero_err())?); - - shared.window += credit; - drop(shared); - - let mut frame = Frame::window_update(self.id, credit).right(); - self.add_flag(frame.header_mut()); - let cmd = StreamCommand::SendFrame(frame); - self.sender - .start_send(cmd) - .map_err(|_| self.write_zero_err())?; + if !self.shared.lock().state.can_read() { + return Poll::Ready(Ok(())); } + ready!(self + .sender + .poll_ready(cx) + .map_err(|_| self.write_zero_err())?); + + let Some(credit) = self.shared.lock().next_window_update() else { + return Poll::Ready(Ok(())); + }; + + let mut frame = Frame::window_update(self.id, credit).right(); + self.add_flag(frame.header_mut()); + let cmd = StreamCommand::SendFrame(frame); + self.sender + .start_send(cmd) + .map_err(|_| self.write_zero_err())?; + Poll::Ready(Ok(())) } } @@ -353,15 +367,21 @@ impl AsyncWrite for Stream { log::debug!("{}/{}: can no longer write", self.conn, self.id); return Poll::Ready(Err(self.write_zero_err())); } - if shared.credit == 0 { + if shared.send_window() == 0 { log::trace!("{}/{}: no more credit left", self.conn, self.id); shared.writer = Some(cx.waker().clone()); return Poll::Pending; } - let k = std::cmp::min(shared.credit as usize, buf.len()); - let k = std::cmp::min(k, self.config.split_send_size); - shared.credit = shared.credit.saturating_sub(k as u32); - Vec::from(&buf[..k]) + let k = std::cmp::min( + shared.send_window(), + buf.len().try_into().unwrap_or(u32::MAX), + ); + let k = std::cmp::min( + k, + self.config.split_send_size.try_into().unwrap_or(u32::MAX), + ); + shared.consume_send_window(k); + Vec::from(&buf[..k as usize]) }; let n = body.len(); let mut frame = Frame::data(self.id, body).expect("body <= u32::MAX").left(); @@ -418,26 +438,34 @@ impl AsyncWrite for Stream { #[derive(Debug)] pub(crate) struct Shared { state: State, - pub(crate) window: u32, - pub(crate) credit: u32, + flow_controller: FlowController, pub(crate) buffer: Chunks, pub(crate) reader: Option, pub(crate) writer: Option, - config: Arc, } impl Shared { - fn new(window: u32, credit: u32, config: Arc) -> Self { + fn new( + receive_window: u32, + send_window: u32, + accumulated_max_stream_windows: Arc>, + rtt: Rtt, + config: Arc, + ) -> Self { Shared { state: State::Open { acknowledged: false, }, - window, - credit, + flow_controller: FlowController::new( + receive_window, + send_window, + accumulated_max_stream_windows, + rtt, + config, + ), buffer: Chunks::new(), reader: None, writer: None, - config, } } @@ -481,37 +509,8 @@ impl Shared { current // Return the previous stream state for informational purposes. } - // TODO: This does not need to live in shared any longer. - /// Calculate the number of additional window bytes the receiving side - /// should grant the sending side via a window update message. - /// - /// Returns `None` if too small to justify a window update message. - /// - /// Note: Once a caller successfully sent a window update message, the - /// locally tracked window size needs to be updated manually by the caller. pub(crate) fn next_window_update(&mut self) -> Option { - if !self.state.can_read() { - return None; - } - - let new_credit = { - debug_assert!(self.config.receive_window >= self.window); - let bytes_received = self.config.receive_window.saturating_sub(self.window); - let buffer_len: u32 = self.buffer.len().try_into().unwrap_or(std::u32::MAX); - - bytes_received.saturating_sub(buffer_len) - }; - - // Send WindowUpdate message when half or more of the configured receive - // window can be granted as additional credit to the sender. - // - // See https://github.com/paritytech/yamux/issues/100 for a detailed - // discussion. - if new_credit >= self.config.receive_window / 2 { - Some(new_credit) - } else { - None - } + self.flow_controller.next_window_update(self.buffer.len()) } /// Whether we are still waiting for the remote to acknowledge this stream. @@ -523,4 +522,24 @@ impl Shared { } ) } + + pub(crate) fn send_window(&self) -> u32 { + self.flow_controller.send_window() + } + + pub(crate) fn consume_send_window(&mut self, i: u32) { + self.flow_controller.consume_send_window(i) + } + + pub(crate) fn increase_send_window_by(&mut self, i: u32) { + self.flow_controller.increase_send_window_by(i) + } + + pub(crate) fn receive_window(&self) -> u32 { + self.flow_controller.receive_window() + } + + pub(crate) fn consume_receive_window(&mut self, i: u32) { + self.flow_controller.consume_receive_window(i) + } } diff --git a/yamux/src/connection/stream/flow_control.rs b/yamux/src/connection/stream/flow_control.rs new file mode 100644 index 00000000..9d9bd6cf --- /dev/null +++ b/yamux/src/connection/stream/flow_control.rs @@ -0,0 +1,304 @@ +use std::{cmp, sync::Arc, time::Instant}; + +use parking_lot::Mutex; + +use crate::{connection::rtt::Rtt, Config, DEFAULT_CREDIT}; + +#[derive(Debug)] +pub(crate) struct FlowController { + config: Arc, + last_window_update: Instant, + /// See [`Connection::rtt`]. + rtt: Rtt, + /// See [`Connection::accumulated_max_stream_windows`]. + accumulated_max_stream_windows: Arc>, + receive_window: u32, + max_receive_window: u32, + send_window: u32, +} + +impl FlowController { + pub(crate) fn new( + receive_window: u32, + send_window: u32, + accumulated_max_stream_windows: Arc>, + rtt: Rtt, + config: Arc, + ) -> Self { + Self { + receive_window, + send_window, + config, + rtt, + accumulated_max_stream_windows, + max_receive_window: DEFAULT_CREDIT, + last_window_update: Instant::now(), + } + } + + /// Calculate the number of additional window bytes the receiving side (local) should grant the + /// sending side (remote) via a window update message. + /// + /// Returns `None` if too small to justify a window update message. + pub(crate) fn next_window_update(&mut self, buffer_len: usize) -> Option { + self.assert_invariants(buffer_len); + + let bytes_received = self.max_receive_window - self.receive_window; + let mut next_window_update = + bytes_received.saturating_sub(buffer_len.try_into().unwrap_or(u32::MAX)); + + // Don't send an update in case half or more of the window is still available to the sender. + if next_window_update < self.max_receive_window / 2 { + return None; + } + + log::trace!( + "received {} mb in {} seconds ({} mbit/s)", + next_window_update as f64 / crate::MIB as f64, + self.last_window_update.elapsed().as_secs_f64(), + next_window_update as f64 / crate::MIB as f64 * 8.0 + / self.last_window_update.elapsed().as_secs_f64() + ); + + // Auto-tuning `max_receive_window` + // + // The ideal `max_receive_window` is equal to the bandwidth-delay-product (BDP), thus + // allowing the remote sender to exhaust the entire available bandwidth on a single stream. + // Choosing `max_receive_window` too small prevents the remote sender from exhausting the + // available bandwidth. Choosing `max_receive_window` to large is wasteful and delays + // backpressure from the receiver to the sender on the stream. + // + // In case the remote sender has exhausted half or more of its credit in less than 2 + // round-trips, try to double `max_receive_window`. + // + // For simplicity `max_receive_window` is never decreased. + // + // This implementation is heavily influenced by QUIC. See document below for rational on the + // above strategy. + // + // https://docs.google.com/document/d/1F2YfdDXKpy20WVKJueEf4abn_LVZHhMUMS5gX6Pgjl4/edit?usp=sharing + if self + .rtt + .get() + .map(|rtt| self.last_window_update.elapsed() < rtt * 2) + .unwrap_or(false) + { + let mut accumulated_max_stream_windows = self.accumulated_max_stream_windows.lock(); + + // Ideally one can just double it: + let new_max = self.max_receive_window.saturating_mul(2); + + // But one has to consider the configured connection limit: + let new_max = { + let connection_limit: usize = self.max_receive_window as usize + + // the overall configured conneciton limit + (self.config.max_connection_receive_window.unwrap_or(usize::MAX) + // minus the minimum amount of window guaranteed to each stream + - self.config.max_num_streams * DEFAULT_CREDIT as usize + // minus the amount of bytes beyond the minimum amount (`DEFAULT_CREDIT`) + // already allocated by this and other streams on the connection. + - *accumulated_max_stream_windows); + + cmp::min(new_max, connection_limit.try_into().unwrap_or(u32::MAX)) + }; + + // Account for the additional credit on the accumulated connection counter. + *accumulated_max_stream_windows += (new_max - self.max_receive_window) as usize; + drop(accumulated_max_stream_windows); + + log::debug!( + "old window_max: {} mb, new window_max: {} mb", + self.max_receive_window as f64 / crate::MIB as f64, + new_max as f64 / crate::MIB as f64 + ); + + self.max_receive_window = new_max; + + // Recalculate `next_window_update` with the new `max_receive_window`. + let bytes_received = self.max_receive_window - self.receive_window; + next_window_update = + bytes_received.saturating_sub(buffer_len.try_into().unwrap_or(u32::MAX)); + } + + self.last_window_update = Instant::now(); + self.receive_window += next_window_update; + + self.assert_invariants(buffer_len); + + return Some(next_window_update); + } + + fn assert_invariants(&self, buffer_len: usize) { + if !cfg!(debug_assertions) { + return; + } + + let config = &self.config; + let rtt = self.rtt.get(); + let accumulated_max_stream_windows = *self.accumulated_max_stream_windows.lock(); + + assert!( + buffer_len <= self.max_receive_window as usize, + "The current buffer size never exceeds the maximum stream receive window." + ); + assert!( + self.receive_window <= self.max_receive_window, + "The current window never exceeds the maximum." + ); + assert!( + (self.max_receive_window - DEFAULT_CREDIT) as usize + <= config.max_connection_receive_window.unwrap_or(usize::MAX) + - config.max_num_streams * DEFAULT_CREDIT as usize, + "The maximum never exceeds its maximum portion of the configured connection limit." + ); + assert!( + (self.max_receive_window - DEFAULT_CREDIT) as usize + <= accumulated_max_stream_windows, + "The amount by which the stream maximum exceeds DEFAULT_CREDIT is tracked in accumulated_max_stream_windows." + ); + if rtt.is_none() { + assert_eq!( + self.max_receive_window, DEFAULT_CREDIT, + "The maximum is only increased iff an rtt measurement is available." + ); + } + } + + pub(crate) fn send_window(&self) -> u32 { + self.send_window + } + + pub(crate) fn consume_send_window(&mut self, i: u32) { + self.send_window = self + .send_window + .checked_sub(i) + .expect("not exceed send window"); + } + + pub(crate) fn increase_send_window_by(&mut self, i: u32) { + self.send_window = self + .send_window + .checked_add(i) + .expect("send window not to exceed u32"); + } + + pub(crate) fn receive_window(&self) -> u32 { + self.receive_window + } + + pub(crate) fn consume_receive_window(&mut self, i: u32) { + self.receive_window = self + .receive_window + .checked_sub(i) + .expect("not exceed receive window"); + } +} + +impl Drop for FlowController { + fn drop(&mut self) { + let mut accumulated_max_stream_windows = self.accumulated_max_stream_windows.lock(); + + debug_assert!( + *accumulated_max_stream_windows >= (self.max_receive_window - DEFAULT_CREDIT) as usize, + "{accumulated_max_stream_windows} {}", + self.max_receive_window + ); + + *accumulated_max_stream_windows -= (self.max_receive_window - DEFAULT_CREDIT) as usize; + } +} + +#[cfg(test)] +mod tests { + use super::*; + use quickcheck::{GenRange, QuickCheck}; + + #[derive(Debug)] + struct Input { + controller: FlowController, + buffer_len: usize, + } + + #[cfg(test)] + impl Clone for Input { + fn clone(&self) -> Self { + Self { + controller: FlowController { + config: self.controller.config.clone(), + accumulated_max_stream_windows: Arc::new(Mutex::new( + self.controller + .accumulated_max_stream_windows + .lock() + .clone(), + )), + rtt: self.controller.rtt.clone(), + last_window_update: self.controller.last_window_update.clone(), + receive_window: self.controller.receive_window, + max_receive_window: self.controller.max_receive_window, + send_window: self.controller.send_window, + }, + buffer_len: self.buffer_len, + } + } + } + + impl quickcheck::Arbitrary for Input { + fn arbitrary(g: &mut quickcheck::Gen) -> Self { + let config = Arc::new(Config::arbitrary(g)); + let rtt = Rtt::arbitrary(g); + + let max_connection_minus_default = + config.max_connection_receive_window.unwrap_or(usize::MAX) + - (config.max_num_streams * (DEFAULT_CREDIT as usize)); + + let max_receive_window = if rtt.get().is_none() { + DEFAULT_CREDIT + } else { + g.gen_range( + DEFAULT_CREDIT + ..(DEFAULT_CREDIT as usize) + .saturating_add(max_connection_minus_default) + .try_into() + .unwrap_or(u32::MAX) + .saturating_add(1), + ) + }; + let receive_window = g.gen_range(0..max_receive_window); + let buffer_len = g.gen_range(0..max_receive_window as usize); + let accumulated_max_stream_windows = Arc::new(Mutex::new(g.gen_range( + (max_receive_window - DEFAULT_CREDIT) as usize + ..max_connection_minus_default.saturating_add(1), + ))); + let last_window_update = + Instant::now() - std::time::Duration::from_secs(g.gen_range(0..(60 * 60 * 24))); + let send_window = g.gen_range(0..u32::MAX); + + Self { + controller: FlowController { + accumulated_max_stream_windows, + rtt, + last_window_update, + config, + receive_window, + max_receive_window, + send_window, + }, + buffer_len, + } + } + } + + #[test] + fn next_window_update() { + fn property( + Input { + mut controller, + buffer_len, + }: Input, + ) { + controller.next_window_update(buffer_len); + } + + QuickCheck::new().quickcheck(property as fn(_)) + } +} diff --git a/yamux/src/frame.rs b/yamux/src/frame.rs index 692840a4..be010a3b 100644 --- a/yamux/src/frame.rs +++ b/yamux/src/frame.rs @@ -132,6 +132,22 @@ impl Frame { } } +impl Frame { + pub fn ping(nonce: u32) -> Self { + let mut header = Header::ping(nonce); + header.syn(); + + Frame { + header, + body: Vec::new(), + } + } + + pub fn nonce(&self) -> u32 { + self.header.nonce() + } +} + impl Frame { pub fn term() -> Self { Frame { diff --git a/yamux/src/frame/io.rs b/yamux/src/frame/io.rs index 43e03987..3712a7c2 100644 --- a/yamux/src/frame/io.rs +++ b/yamux/src/frame/io.rs @@ -20,6 +20,13 @@ use std::{ task::{Context, Poll}, }; +/// Maximum Yamux frame body length +/// +/// Limits the amount of bytes a remote can cause the local node to allocate at once when reading. +/// +/// Chosen based on intuition in past iterations. +const MAX_FRAME_BODY_LEN: usize = 1 * crate::MIB; + /// A [`Stream`] and writer of [`Frame`] values. #[derive(Debug)] pub(crate) struct Io { @@ -27,17 +34,15 @@ pub(crate) struct Io { io: T, read_state: ReadState, write_state: WriteState, - max_body_len: usize, } impl Io { - pub(crate) fn new(id: Id, io: T, max_frame_body_len: usize) -> Self { + pub(crate) fn new(id: Id, io: T) -> Self { Io { id, io, read_state: ReadState::Init, write_state: WriteState::Init, - max_body_len: max_frame_body_len, } } } @@ -200,7 +205,7 @@ impl Stream for Io { let body_len = header.len().val() as usize; - if body_len > this.max_body_len { + if body_len > MAX_FRAME_BODY_LEN { return Poll::Ready(Some(Err(FrameDecodeError::FrameTooLarge( body_len, )))); @@ -349,7 +354,7 @@ mod tests { fn property(f: Frame<()>) -> bool { futures::executor::block_on(async move { let id = crate::connection::Id::random(); - let mut io = Io::new(id, futures::io::Cursor::new(Vec::new()), f.body.len()); + let mut io = Io::new(id, futures::io::Cursor::new(Vec::new())); if io.send(f.clone()).await.is_err() { return false; } diff --git a/yamux/src/lib.rs b/yamux/src/lib.rs index 1df9d193..c082abb0 100644 --- a/yamux/src/lib.rs +++ b/yamux/src/lib.rs @@ -38,7 +38,11 @@ pub use crate::frame::{ FrameDecodeError, }; -pub const DEFAULT_CREDIT: u32 = 256 * 1024; // as per yamux specification +const KIB: usize = 1024; +const MIB: usize = KIB * 1024; +const GIB: usize = MIB * 1024; + +pub const DEFAULT_CREDIT: u32 = 256 * KIB as u32; // as per yamux specification pub type Result = std::result::Result; @@ -61,22 +65,19 @@ const MAX_ACK_BACKLOG: usize = 256; /// /// For details on why this concrete value was chosen, see /// https://github.com/paritytech/yamux/issues/100. -const DEFAULT_SPLIT_SEND_SIZE: usize = 16 * 1024; +const DEFAULT_SPLIT_SEND_SIZE: usize = 16 * KIB; /// Yamux configuration. /// /// The default configuration values are as follows: /// -/// - receive window = 256 KiB -/// - max. buffer size (per stream) = 1 MiB -/// - max. number of streams = 8192 -/// - window update mode = on read +/// - max. for the total receive window size across all streams of a connection = 1 GiB +/// - max. number of streams = 512 /// - read after close = true /// - split send size = 16 KiB #[derive(Debug, Clone)] pub struct Config { - receive_window: u32, - max_buffer_size: usize, + max_connection_receive_window: Option, max_num_streams: usize, read_after_close: bool, split_send_size: usize, @@ -85,9 +86,8 @@ pub struct Config { impl Default for Config { fn default() -> Self { Config { - receive_window: DEFAULT_CREDIT, - max_buffer_size: 1024 * 1024, - max_num_streams: 8192, + max_connection_receive_window: Some(1 * GIB), + max_num_streams: 512, read_after_close: true, split_send_size: DEFAULT_SPLIT_SEND_SIZE, } @@ -95,26 +95,58 @@ impl Default for Config { } impl Config { - /// Set the receive window per stream (must be >= 256 KiB). + /// Set the upper limit for the total receive window size across all streams of a connection. /// - /// # Panics + /// Must be `>= 256 KiB * max_num_streams` to allow each stream at least the Yamux default + /// window size. /// - /// If the given receive window is < 256 KiB. - pub fn set_receive_window(&mut self, n: u32) -> &mut Self { - assert!(n >= DEFAULT_CREDIT); - self.receive_window = n; - self - } + /// The window of a stream starts at 256 KiB and is increased (auto-tuned) based on the + /// connection's round-trip time and the stream's bandwidth (striving for the + /// bandwidth-delay-product). + /// + /// Set to `None` to disable limit, i.e. allow each stream to grow receive window based on + /// connection's round-trip time and stream's bandwidth without limit. + /// + /// ## DOS attack mitigation + /// + /// A remote node (attacker) might trick the local node (target) into allocating large stream + /// receive windows, trying to make the local node run out of memory. + /// + /// This attack is difficult, as the local node only increases the stream receive window up to + /// 2x the bandwidth-delay-product, where bandwidth is the amount of bytes read, not just + /// received. In other words, the attacker has to send (and have the local node read) + /// significant amount of bytes on a stream over a long period of time to increase the stream + /// receive window. E.g. on a 60ms 10Gbit/s connection the bandwidth-delay-product is ~75 MiB + /// and thus the local node will at most allocate ~150 MiB (2x bandwidth-delay-product) per + /// stream. + /// + /// Despite the difficulty of the attack one should choose a reasonable + /// `max_connection_receive_window` to protect against this attack, especially since an attacker + /// might use more than one stream per connection. + pub fn set_max_connection_receive_window(&mut self, n: Option) -> &mut Self { + self.max_connection_receive_window = n; + + assert!( + self.max_connection_receive_window.unwrap_or(usize::MAX) + >= self.max_num_streams * DEFAULT_CREDIT as usize, + "`max_connection_receive_window` must be `>= 256 KiB * max_num_streams` to allow each + stream at least the Yamux default window size" + ); - /// Set the max. buffer size per stream. - pub fn set_max_buffer_size(&mut self, n: usize) -> &mut Self { - self.max_buffer_size = n; self } - /// Set the max. number of streams. + /// Set the max. number of streams per connection. pub fn set_max_num_streams(&mut self, n: usize) -> &mut Self { self.max_num_streams = n; + + assert!( + self.max_connection_receive_window.unwrap_or(usize::MAX) + >= self.max_num_streams * DEFAULT_CREDIT as usize, + "`max_connection_receive_window` must be `>= 256 KiB * max_num_streams` to allow each + stream at least the Yamux default window size" + ); + self } @@ -142,3 +174,23 @@ static_assertions::const_assert! { static_assertions::const_assert! { std::mem::size_of::() <= std::mem::size_of::() } + +#[cfg(test)] +impl quickcheck::Arbitrary for Config { + fn arbitrary(g: &mut quickcheck::Gen) -> Self { + use quickcheck::GenRange; + + let max_num_streams = g.gen_range(0..u16::MAX as usize); + + Config { + max_connection_receive_window: if bool::arbitrary(g) { + Some(g.gen_range((DEFAULT_CREDIT as usize * max_num_streams)..usize::MAX)) + } else { + None + }, + max_num_streams, + read_after_close: bool::arbitrary(g), + split_send_size: g.gen_range(DEFAULT_SPLIT_SEND_SIZE..usize::MAX), + } + } +}