diff --git a/quinn-proto/src/connection/streams/mod.rs b/quinn-proto/src/connection/streams/mod.rs index 69d6b8abc..b08f3302d 100644 --- a/quinn-proto/src/connection/streams/mod.rs +++ b/quinn-proto/src/connection/streams/mod.rs @@ -7,8 +7,10 @@ use bytes::Bytes; use thiserror::Error; use tracing::trace; +use self::state::get_or_insert_recv; + use super::spaces::{Retransmits, ThinRetransmits}; -use crate::{frame, Dir, StreamId, VarInt}; +use crate::{connection::streams::state::get_or_insert_send, frame, Dir, StreamId, VarInt}; mod recv; use recv::Recv; @@ -133,7 +135,7 @@ impl<'a> RecvStream<'a> { hash_map::Entry::Occupied(s) => s, hash_map::Entry::Vacant(_) => return Err(UnknownStream { _private: () }), }; - let stream = entry.get_mut(); + let stream = get_or_insert_recv(self.state.stream_receive_window)(entry.get_mut()); let (read_credits, stop_sending) = stream.stop()?; if stop_sending.should_transmit() { @@ -207,11 +209,16 @@ impl<'a> SendStream<'a> { } let limit = self.state.write_limit(); + + let max_send_data = self.state.max_send_data(&self.id); + let stream = self .state .send .get_mut(&self.id) + .map(get_or_insert_send(max_send_data)) .ok_or(WriteError::UnknownStream)?; + if limit == 0 { trace!( stream = %self.id, max_data = self.state.max_data, data_sent = self.state.data_sent, @@ -237,8 +244,9 @@ impl<'a> SendStream<'a> { /// Check if this stream was stopped, get the reason if it was pub fn stopped(&mut self) -> Result, UnknownStream> { - match self.state.send.get(&self.id) { - Some(s) => Ok(s.stop_reason), + match self.state.send.get(&self.id).as_ref() { + Some(Some(s)) => Ok(s.stop_reason), + Some(None) => Ok(None), None => Err(UnknownStream { _private: () }), } } @@ -249,10 +257,12 @@ impl<'a> SendStream<'a> { /// /// [`StreamEvent::Finished`]: crate::StreamEvent::Finished pub fn finish(&mut self) -> Result<(), FinishError> { + let max_send_data = self.state.max_send_data(&self.id); let stream = self .state .send .get_mut(&self.id) + .map(get_or_insert_send(max_send_data)) .ok_or(FinishError::UnknownStream)?; let was_pending = stream.is_pending(); @@ -269,10 +279,12 @@ impl<'a> SendStream<'a> { /// # Panics /// - when applied to a receive stream pub fn reset(&mut self, error_code: VarInt) -> Result<(), UnknownStream> { + let max_send_data = self.state.max_send_data(&self.id); let stream = self .state .send .get_mut(&self.id) + .map(get_or_insert_send(max_send_data)) .ok_or(UnknownStream { _private: () })?; if matches!(stream.state, SendState::ResetSent) { @@ -296,10 +308,12 @@ impl<'a> SendStream<'a> { /// # Panics /// - when applied to a receive stream pub fn set_priority(&mut self, priority: i32) -> Result<(), UnknownStream> { + let max_send_data = self.state.max_send_data(&self.id); let stream = self .state .send .get_mut(&self.id) + .map(get_or_insert_send(max_send_data)) .ok_or(UnknownStream { _private: () })?; stream.priority = priority; @@ -317,7 +331,7 @@ impl<'a> SendStream<'a> { .get(&self.id) .ok_or(UnknownStream { _private: () })?; - Ok(stream.priority) + Ok(stream.as_ref().map(|s| s.priority).unwrap_or_default()) } } diff --git a/quinn-proto/src/connection/streams/recv.rs b/quinn-proto/src/connection/streams/recv.rs index 627666b83..6f46ec865 100644 --- a/quinn-proto/src/connection/streams/recv.rs +++ b/quinn-proto/src/connection/streams/recv.rs @@ -4,6 +4,7 @@ use std::mem; use thiserror::Error; use tracing::debug; +use super::state::get_or_insert_recv; use super::{Retransmits, ShouldTransmit, StreamHalf, StreamId, StreamsState, UnknownStream}; use crate::connection::assembler::{Assembler, Chunk, IllegalOrderedRead}; use crate::{frame, TransportError, VarInt}; @@ -18,14 +19,14 @@ pub(super) struct Recv { } impl Recv { - pub(super) fn new(initial_max_data: u64) -> Self { - Self { + pub(super) fn new(initial_max_data: u64) -> Box { + Box::new(Self { state: RecvState::default(), assembler: Assembler::new(), sent_max_stream_data: initial_max_data, end: 0, stopped: false, - } + }) } /// Process a STREAM frame @@ -215,15 +216,16 @@ impl<'a> Chunks<'a> { streams: &'a mut StreamsState, pending: &'a mut Retransmits, ) -> Result { - let entry = match streams.recv.entry(id) { + let mut entry = match streams.recv.entry(id) { Entry::Occupied(entry) => entry, Entry::Vacant(_) => return Err(ReadableError::UnknownStream), }; - let mut recv = match entry.get().stopped { - true => return Err(ReadableError::UnknownStream), - false => entry.remove(), - }; + let mut recv = + match get_or_insert_recv(streams.stream_receive_window)(entry.get_mut()).stopped { + true => return Err(ReadableError::UnknownStream), + false => entry.remove().unwrap(), // this can't fail due to the previous get_or_insert_with + }; recv.assembler.ensure_ordering(ordered)?; Ok(Self { @@ -313,7 +315,7 @@ impl<'a> Chunks<'a> { self.pending.max_stream_data.insert(self.id); } // Return the stream to storage for future use - self.streams.recv.insert(self.id, rs); + self.streams.recv.insert(self.id, Some(rs)); } // Issue connection-level flow control credit for any data we read regardless of state @@ -331,7 +333,7 @@ impl<'a> Drop for Chunks<'a> { } enum ChunksState { - Readable(Recv), + Readable(Box), Reset(VarInt), Finished, Finalized, diff --git a/quinn-proto/src/connection/streams/send.rs b/quinn-proto/src/connection/streams/send.rs index 5a2f39c71..4a26cd01c 100644 --- a/quinn-proto/src/connection/streams/send.rs +++ b/quinn-proto/src/connection/streams/send.rs @@ -18,8 +18,8 @@ pub(super) struct Send { } impl Send { - pub(super) fn new(max_data: VarInt) -> Self { - Self { + pub(super) fn new(max_data: VarInt) -> Box { + Box::new(Self { max_data: max_data.into(), state: SendState::Ready, pending: SendBuffer::new(), @@ -27,7 +27,7 @@ impl Send { fin_pending: false, connection_blocked: false, stop_reason: None, - } + }) } /// Whether the stream has been reset diff --git a/quinn-proto/src/connection/streams/state.rs b/quinn-proto/src/connection/streams/state.rs index 207778f00..ce2e86669 100644 --- a/quinn-proto/src/connection/streams/state.rs +++ b/quinn-proto/src/connection/streams/state.rs @@ -24,8 +24,8 @@ use crate::{ pub struct StreamsState { pub(super) side: Side, // Set of streams that are currently open, or could be immediately opened by the peer - pub(super) send: FxHashMap, - pub(super) recv: FxHashMap, + pub(super) send: FxHashMap>>, + pub(super) recv: FxHashMap>>, pub(super) next: [u64; 2], /// Maximum number of locally-initiated streams that may be opened over the lifetime of the /// connection so far, per direction @@ -152,8 +152,9 @@ impl StreamsState { self.received_max_data(params.initial_max_data); for i in 0..self.max_remote[Dir::Bi as usize] { let id = StreamId::new(!self.side, Dir::Bi, i); - self.send.get_mut(&id).unwrap().max_data = - params.initial_max_stream_data_bidi_local.into(); + if let Some(s) = self.send.get_mut(&id).and_then(|s| s.as_mut()) { + s.max_data = params.initial_max_stream_data_bidi_local.into(); + } } } @@ -205,13 +206,17 @@ impl StreamsState { frame: frame::Stream, payload_len: usize, ) -> Result { - let stream = frame.id; - self.validate_receive_id(stream).map_err(|e| { + let id = frame.id; + self.validate_receive_id(id).map_err(|e| { debug!("received illegal STREAM frame"); e })?; - let rs = match self.recv.get_mut(&stream) { + let rs = match self + .recv + .get_mut(&id) + .map(get_or_insert_recv(self.stream_receive_window)) + { Some(rs) => rs, None => { trace!("dropping frame for closed stream"); @@ -229,14 +234,14 @@ impl StreamsState { self.data_recvd = self.data_recvd.saturating_add(new_bytes); if !rs.stopped { - self.on_stream_frame(true, stream); + self.on_stream_frame(true, id); return Ok(ShouldTransmit(false)); } // Stopped streams become closed instantly on FIN, so check whether we need to clean up if closed { - self.recv.remove(&stream); - self.stream_freed(stream, StreamHalf::Recv); + self.recv.remove(&id); + self.stream_freed(id, StreamHalf::Recv); } // We don't buffer data on stopped streams, so issue flow control credit immediately @@ -261,7 +266,11 @@ impl StreamsState { e })?; - let rs = match self.recv.get_mut(&id) { + let rs = match self + .recv + .get_mut(&id) + .map(get_or_insert_recv(self.stream_receive_window)) + { Some(stream) => stream, None => { trace!("received RESET_STREAM on closed stream"); @@ -304,7 +313,12 @@ impl StreamsState { /// Process incoming `STOP_SENDING` frame #[allow(unreachable_pub)] // fuzzing only pub fn received_stop_sending(&mut self, id: StreamId, error_code: VarInt) { - let stream = match self.send.get_mut(&id) { + let max_send_data = self.max_send_data(&id); + let stream = match self + .send + .get_mut(&id) + .map(get_or_insert_send(max_send_data)) + { Some(ss) => ss, None => return, }; @@ -320,7 +334,7 @@ impl StreamsState { match self.send.entry(id) { hash_map::Entry::Vacant(_) => {} hash_map::Entry::Occupied(e) => { - if let SendState::ResetSent = e.get().state { + if let Some(SendState::ResetSent) = e.get().as_ref().map(|s| s.state) { e.remove_entry(); self.stream_freed(id, StreamHalf::Send); } @@ -332,11 +346,12 @@ impl StreamsState { pub(crate) fn can_send_stream_data(&self) -> bool { // Reset streams may linger in the pending stream list, but will never produce stream frames self.pending.iter().any(|level| { - level - .queue - .borrow() - .iter() - .any(|id| self.send.get(id).map_or(false, |s| !s.is_reset())) + level.queue.borrow().iter().any(|id| { + self.send + .get(id) + .and_then(|s| s.as_ref()) + .map_or(false, |s| !s.is_reset()) + }) }) } @@ -344,6 +359,7 @@ impl StreamsState { pub(crate) fn can_send_flow_control(&self, id: StreamId) -> bool { self.recv .get(&id) + .and_then(|s| s.as_ref()) .map_or(false, |s| s.receiving_unknown_size()) } @@ -361,7 +377,7 @@ impl StreamsState { Some(x) => x, None => break, }; - let stream = match self.send.get_mut(&id) { + let stream = match self.send.get_mut(&id).and_then(|s| s.as_mut()) { Some(x) => x, None => continue, }; @@ -428,7 +444,7 @@ impl StreamsState { None => break, }; pending.max_stream_data.remove(&id); - let rs = match self.recv.get_mut(&id) { + let rs = match self.recv.get_mut(&id).and_then(|s| s.as_mut()) { Some(x) => x, None => continue, }; @@ -507,7 +523,7 @@ impl StreamsState { break; } }; - let stream = match self.send.get_mut(&id) { + let stream = match self.send.get_mut(&id).and_then(|s| s.as_mut()) { Some(s) => s, // Stream was reset with pending data and the reset was acknowledged None => continue, @@ -593,7 +609,18 @@ impl StreamsState { hash_map::Entry::Vacant(_) => return, hash_map::Entry::Occupied(e) => e, }; - let stream = entry.get_mut(); + + let stream = match entry.get_mut().as_mut() { + Some(s) => s, + None => { + // Because we only call this after sending data on this stream, + // this closure should be unreachable. If we did somehow screw that up, + // then we might hit an underflow below with unpredictable effects down + // the line. Best to short-circuit. + return; + } + }; + if stream.is_reset() { // We account for outstanding data on reset streams at time of reset return; @@ -611,7 +638,7 @@ impl StreamsState { } pub(crate) fn retransmit(&mut self, frame: frame::StreamMeta) { - let stream = match self.send.get_mut(&frame.id) { + let stream = match self.send.get_mut(&frame.id).and_then(|s| s.as_mut()) { // Loss of data on a closed stream is a noop None => return, Some(x) => x, @@ -627,7 +654,10 @@ impl StreamsState { for dir in Dir::iter() { for index in 0..self.next[dir as usize] { let id = StreamId::new(Side::Client, dir, index); - let stream = self.send.get_mut(&id).unwrap(); + let stream = match self.send.get_mut(&id).and_then(|s| s.as_mut()) { + Some(stream) => stream, + None => continue, + }; if stream.pending.is_fully_acked() && !stream.fin_pending { // Stream data can't be acked in 0-RTT, so we must not have sent anything on // this stream @@ -679,7 +709,12 @@ impl StreamsState { } let write_limit = self.write_limit(); - if let Some(ss) = self.send.get_mut(&id) { + let max_send_data = self.max_send_data(&id); + if let Some(ss) = self + .send + .get_mut(&id) + .map(get_or_insert_send(max_send_data)) + { if ss.increase_max_data(offset) { if write_limit > 0 { self.events.push_back(StreamEvent::Writable { id }); @@ -716,7 +751,7 @@ impl StreamsState { if self.write_limit() > 0 { while let Some(id) = self.connection_blocked.pop() { - let stream = match self.send.get_mut(&id) { + let stream = match self.send.get_mut(&id).and_then(|s| s.as_mut()) { None => continue, Some(s) => s, }; @@ -797,24 +832,26 @@ impl StreamsState { expanded } + pub(super) fn max_send_data(&self, id: &StreamId) -> VarInt { + let remote = self.side != id.initiator(); + match id.dir() { + Dir::Uni => self.initial_max_stream_data_uni, + // Remote/local appear reversed here because the transport parameters are named from + // the perspective of the peer. + Dir::Bi if remote => self.initial_max_stream_data_bidi_local, + Dir::Bi => self.initial_max_stream_data_bidi_remote, + } + } + pub(super) fn insert(&mut self, remote: bool, id: StreamId) { let bi = id.dir() == Dir::Bi; + // bidirectional OR (unidirectional AND NOT remote) if bi || !remote { - let max_data = match id.dir() { - Dir::Uni => self.initial_max_stream_data_uni, - // Remote/local appear reversed here because the transport parameters are named from - // the perspective of the peer. - Dir::Bi if remote => self.initial_max_stream_data_bidi_local, - Dir::Bi => self.initial_max_stream_data_bidi_remote, - }; - let stream = Send::new(max_data); - assert!(self.send.insert(id, stream).is_none()); + assert!(self.send.insert(id, None).is_none()); } + // bidirectional OR (unidirectional AND remote) if bi || remote { - assert!(self - .recv - .insert(id, Recv::new(self.stream_receive_window)) - .is_none()); + assert!(self.recv.insert(id, None).is_none()); } } @@ -867,6 +904,20 @@ impl StreamsState { } } +#[inline] +pub(super) fn get_or_insert_send( + max_data: VarInt, +) -> impl Fn(&mut Option>) -> &mut Box { + move |opt| opt.get_or_insert_with(|| Send::new(max_data)) +} + +#[inline] +pub(super) fn get_or_insert_recv( + initial_max_data: u64, +) -> impl Fn(&mut Option>) -> &mut Box { + move |opt| opt.get_or_insert_with(|| Recv::new(initial_max_data)) +} + #[cfg(test)] mod tests { use super::*;