diff --git a/quinn-proto/src/connection/mod.rs b/quinn-proto/src/connection/mod.rs index 3418b4ded..2df494b97 100644 --- a/quinn-proto/src/connection/mod.rs +++ b/quinn-proto/src/connection/mod.rs @@ -1095,6 +1095,14 @@ impl Connection { self.path.congestion.as_ref() } + /// Modify the number of remotely initiated streams that may be concurrently open + /// + /// No streams may be opened by the peer unless fewer than `count` are already open. Large + /// `count`s increase both minimum and worst-case memory consumption. + pub fn set_max_concurrent_streams(&mut self, dir: Dir, count: VarInt) { + self.streams.set_max_concurrent(dir, count); + } + fn on_ack_received( &mut self, now: Instant, diff --git a/quinn-proto/src/connection/streams/state.rs b/quinn-proto/src/connection/streams/state.rs index a23c85f7e..d3ae57ea0 100644 --- a/quinn-proto/src/connection/streams/state.rs +++ b/quinn-proto/src/connection/streams/state.rs @@ -26,11 +26,20 @@ pub struct StreamsState { pub(super) send: FxHashMap, pub(super) recv: FxHashMap, pub(super) next: [u64; 2], - // Locally initiated + /// Maximum number of locally-initiated streams that may be opened over the lifetime of the + /// connection so far, per direction pub(super) max: [u64; 2], - // Maximum that can be remotely initiated + /// Maximum number of remotely-initiated streams that may be opened over the lifetime of the + /// connection so far, per direction max_remote: [u64; 2], - // Lowest that hasn't actually been opened + /// Number of streams that we've given the peer permission to open + allocated_remote_count: [u64; 2], + /// Size of the desired stream flow control window. May be smaller than `allocated_remote_count` + /// due to `set_max_concurrent` calls. + max_concurrent_remote_count: [u64; 2], + /// Whether `max_concurrent_remote_count` has ever changed + flow_control_adjusted: bool, + /// Lowest remotely-initiated stream index that haven't actually been opened by the peer pub(super) next_remote: [u64; 2], /// Whether the remote endpoint has opened any streams the application doesn't know about yet, /// per directionality @@ -94,6 +103,9 @@ impl StreamsState { next: [0, 0], max: [0, 0], max_remote: [max_remote_bi.into(), max_remote_uni.into()], + allocated_remote_count: [max_remote_bi.into(), max_remote_uni.into()], + max_concurrent_remote_count: [max_remote_bi.into(), max_remote_uni.into()], + flow_control_adjusted: false, next_remote: [0, 0], opened: [false, false], next_reported_remote: [0, 0], @@ -139,11 +151,18 @@ impl StreamsState { } } - fn alloc_remote_stream(&mut self, dir: Dir) { - self.max_remote[dir as usize] += 1; - let id = StreamId::new(!self.side, dir, self.max_remote[dir as usize] - 1); - self.insert(true, id); - self.max_streams_dirty[dir as usize] = true; + /// Ensure we have space for at least a full flow control window of remotely-initiated streams + /// to be open, and notify the peer if the window has moved + fn ensure_remote_streams(&mut self, dir: Dir) { + let new_count = self.max_concurrent_remote_count[dir as usize] + .saturating_sub(self.allocated_remote_count[dir as usize]); + for i in 0..new_count { + let id = StreamId::new(!self.side, dir, self.max_remote[dir as usize] + i); + self.insert(true, id); + } + self.allocated_remote_count[dir as usize] += new_count; + self.max_remote[dir as usize] += new_count; + self.max_streams_dirty[dir as usize] = new_count != 0; } pub fn zero_rtt_rejected(&mut self) { @@ -159,7 +178,13 @@ impl StreamsState { } } self.next[dir as usize] = 0; + + // If 0-RTT was rejected, any flow control frames we sent were lost. + if self.flow_control_adjusted { + self.max_streams_dirty[dir as usize] = true; + } } + self.pending.clear(); self.send_streams = 0; self.data_sent = 0; @@ -728,6 +753,12 @@ impl StreamsState { id.index() >= self.next[id.dir() as usize] } + pub fn set_max_concurrent(&mut self, dir: Dir, count: VarInt) { + self.flow_control_adjusted = true; + self.max_concurrent_remote_count[dir as usize] = count.into(); + self.ensure_remote_streams(dir); + } + pub(super) fn insert(&mut self, remote: bool, id: StreamId) { let bi = id.dir() == Dir::Bi; if bi || !remote { @@ -782,7 +813,8 @@ impl StreamsState { StreamHalf::Recv => !self.send.contains_key(&id), }; if fully_free { - self.alloc_remote_stream(id.dir()); + self.allocated_remote_count[id.dir() as usize] -= 1; + self.ensure_remote_streams(id.dir()); } } if half == StreamHalf::Send { @@ -1316,4 +1348,209 @@ mod tests { assert_eq!(pending.reset_stream, &[(id, 0u32.into())]); assert!(!server.can_send_stream_data()); } + + #[test] + fn stream_limit_fixed() { + let mut client = make(Side::Client); + // Open streams 0-127 + assert_eq!( + client.received( + frame::Stream { + id: StreamId::new(Side::Server, Dir::Uni, 127), + offset: 0, + fin: true, + data: Bytes::from_static(&[]), + }, + 0 + ), + Ok(ShouldTransmit(false)) + ); + // Try to open stream 128, exceeding limit + assert_eq!( + client + .received( + frame::Stream { + id: StreamId::new(Side::Server, Dir::Uni, 128), + offset: 0, + fin: true, + data: Bytes::from_static(&[]), + }, + 0 + ) + .unwrap_err() + .code, + TransportErrorCode::STREAM_LIMIT_ERROR + ); + + // Free stream 127 + let mut pending = Retransmits::default(); + let mut stream = RecvStream { + id: StreamId::new(Side::Server, Dir::Uni, 127), + state: &mut client, + pending: &mut pending, + }; + stream.stop(0u32.into()).unwrap(); + + assert!(client.max_streams_dirty[Dir::Uni as usize]); + + // Open stream 128 + assert_eq!( + client.received( + frame::Stream { + id: StreamId::new(Side::Server, Dir::Uni, 128), + offset: 0, + fin: true, + data: Bytes::from_static(&[]), + }, + 0 + ), + Ok(ShouldTransmit(false)) + ); + } + + #[test] + fn stream_limit_grows() { + let mut client = make(Side::Client); + // Open streams 0-127 + assert_eq!( + client.received( + frame::Stream { + id: StreamId::new(Side::Server, Dir::Uni, 127), + offset: 0, + fin: true, + data: Bytes::from_static(&[]), + }, + 0 + ), + Ok(ShouldTransmit(false)) + ); + // Try to open stream 128, exceeding limit + assert_eq!( + client + .received( + frame::Stream { + id: StreamId::new(Side::Server, Dir::Uni, 128), + offset: 0, + fin: true, + data: Bytes::from_static(&[]), + }, + 0 + ) + .unwrap_err() + .code, + TransportErrorCode::STREAM_LIMIT_ERROR + ); + + // Relax limit by one + client.set_max_concurrent(Dir::Uni, 129u32.into()); + + assert!(client.max_streams_dirty[Dir::Uni as usize]); + + // Open stream 128 + assert_eq!( + client.received( + frame::Stream { + id: StreamId::new(Side::Server, Dir::Uni, 128), + offset: 0, + fin: true, + data: Bytes::from_static(&[]), + }, + 0 + ), + Ok(ShouldTransmit(false)) + ); + } + + #[test] + fn stream_limit_shrinks() { + let mut client = make(Side::Client); + // Open streams 0-127 + assert_eq!( + client.received( + frame::Stream { + id: StreamId::new(Side::Server, Dir::Uni, 127), + offset: 0, + fin: true, + data: Bytes::from_static(&[]), + }, + 0 + ), + Ok(ShouldTransmit(false)) + ); + + // Tighten limit by one + client.set_max_concurrent(Dir::Uni, 127u32.into()); + + // Free stream 127 + let mut pending = Retransmits::default(); + let mut stream = RecvStream { + id: StreamId::new(Side::Server, Dir::Uni, 127), + state: &mut client, + pending: &mut pending, + }; + stream.stop(0u32.into()).unwrap(); + assert!(!client.max_streams_dirty[Dir::Uni as usize]); + + // Try to open stream 128, still exceeding limit + assert_eq!( + client + .received( + frame::Stream { + id: StreamId::new(Side::Server, Dir::Uni, 128), + offset: 0, + fin: true, + data: Bytes::from_static(&[]), + }, + 0 + ) + .unwrap_err() + .code, + TransportErrorCode::STREAM_LIMIT_ERROR + ); + + // Free stream 126 + assert_eq!( + client.received_reset(frame::ResetStream { + id: StreamId::new(Side::Server, Dir::Uni, 126), + error_code: 0u32.into(), + final_offset: 0u32.into(), + }), + Ok(ShouldTransmit(false)) + ); + let mut pending = Retransmits::default(); + let mut stream = RecvStream { + id: StreamId::new(Side::Server, Dir::Uni, 126), + state: &mut client, + pending: &mut pending, + }; + stream.stop(0u32.into()).unwrap(); + + assert!(client.max_streams_dirty[Dir::Uni as usize]); + + // Open stream 128 + assert_eq!( + client.received( + frame::Stream { + id: StreamId::new(Side::Server, Dir::Uni, 128), + offset: 0, + fin: true, + data: Bytes::from_static(&[]), + }, + 0 + ), + Ok(ShouldTransmit(false)) + ); + } + + #[test] + fn remote_stream_capacity() { + let mut client = make(Side::Client); + for _ in 0..2 { + client.set_max_concurrent(Dir::Uni, 200u32.into()); + client.set_max_concurrent(Dir::Bi, 201u32.into()); + assert_eq!(client.recv.len(), 200 + 201); + assert_eq!(client.max_remote[Dir::Uni as usize], 200); + assert_eq!(client.max_remote[Dir::Bi as usize], 201); + } + } } diff --git a/quinn/src/connection.rs b/quinn/src/connection.rs index ea84a7a63..dc920f90a 100644 --- a/quinn/src/connection.rs +++ b/quinn/src/connection.rs @@ -534,6 +534,28 @@ impl Connection { .crypto_session() .export_keying_material(output, label, context) } + + /// Modify the number of remotely initiated unidirectional streams that may be concurrently open + /// + /// No streams may be opened by the peer unless fewer than `count` are already open. Large + /// `count`s increase both minimum and worst-case memory consumption. + pub fn set_max_concurrent_uni_streams(&self, count: VarInt) { + let mut conn = self.0.lock("set_max_concurrent_uni_streams"); + conn.inner.set_max_concurrent_streams(Dir::Uni, count); + // May need to send MAX_STREAMS to make progress + conn.wake(); + } + + /// Modify the number of remotely initiated bidirectional streams that may be concurrently open + /// + /// No streams may be opened by the peer unless fewer than `count` are already open. Large + /// `count`s increase both minimum and worst-case memory consumption. + pub fn set_max_concurrent_bi_streams(&self, count: VarInt) { + let mut conn = self.0.lock("set_max_concurrent_bi_streams"); + conn.inner.set_max_concurrent_streams(Dir::Bi, count); + // May need to send MAX_STREAMS to make progress + conn.wake(); + } } impl Clone for Connection {