diff --git a/compio-quic/Cargo.toml b/compio-quic/Cargo.toml index 1642eb68..00575023 100644 --- a/compio-quic/Cargo.toml +++ b/compio-quic/Cargo.toml @@ -33,6 +33,7 @@ h3 = { version = "0.0.6", optional = true } bytes = { workspace = true } flume = { workspace = true } futures-util = { workspace = true } +rustc-hash = "2.0.0" thiserror = "1.0.63" # Windows specific dependencies diff --git a/compio-quic/benches/quic.rs b/compio-quic/benches/quic.rs index 63d12c88..66694da5 100644 --- a/compio-quic/benches/quic.rs +++ b/compio-quic/benches/quic.rs @@ -127,13 +127,13 @@ fn echo_quinn(b: &mut Bencher, content: &[u8], streams: usize) { client.set_default_client_config(client_config); let addr = server.local_addr().unwrap(); - let (client_conn, server_conn) = futures_util::join!( + let (client_conn, server_conn) = tokio::join!( async move { client.connect(addr, "localhost").unwrap().await.unwrap() }, async move { server.accept().await.unwrap().await.unwrap() } ); let start = Instant::now(); - tokio::spawn(async move { + let handle = tokio::spawn(async move { while let Ok((mut send, mut recv)) = server_conn.accept_bi().await { tokio::spawn(async move { echo_impl!(send, recv); @@ -157,6 +157,7 @@ fn echo_quinn(b: &mut Bencher, content: &[u8], streams: usize) { .collect::>(); while futures.next().await.is_some() {} } + handle.abort(); start.elapsed() }); } diff --git a/compio-quic/src/connection.rs b/compio-quic/src/connection.rs index 83a41cf0..89a24b10 100644 --- a/compio-quic/src/connection.rs +++ b/compio-quic/src/connection.rs @@ -1,5 +1,5 @@ use std::{ - collections::{HashMap, VecDeque}, + collections::VecDeque, io, net::{IpAddr, SocketAddr}, pin::{pin, Pin}, @@ -21,6 +21,7 @@ use quinn_proto::{ congestion::Controller, crypto::rustls::HandshakeData, ConnectionHandle, ConnectionStats, Dir, EndpointEvent, StreamEvent, StreamId, VarInt, }; +use rustc_hash::FxHashMap as HashMap; use thiserror::Error; use crate::{RecvStream, SendStream, Socket}; @@ -37,7 +38,7 @@ pub(crate) struct ConnectionState { pub(crate) error: Option, connected: bool, worker: Option>, - poll_waker: Option, + poller: Option, on_connected: Option, on_handshake_data: Option, datagram_received: VecDeque, @@ -73,8 +74,14 @@ impl ConnectionState { wake_all_streams(&mut self.stopped); } + fn close(&mut self, error_code: VarInt, reason: Bytes) { + self.conn.close(Instant::now(), error_code, reason); + self.terminate(ConnectionError::LocallyClosed); + self.wake(); + } + pub(crate) fn wake(&mut self) { - if let Some(waker) = self.poll_waker.take() { + if let Some(waker) = self.poller.take() { waker.wake() } } @@ -110,6 +117,12 @@ pub(crate) struct ConnectionInner { events_rx: Receiver, } +fn implicit_close(this: &Arc) { + if Arc::strong_count(this) == 2 { + this.state().close(0u32.into(), Bytes::new()) + } +} + impl ConnectionInner { fn new( handle: ConnectionHandle, @@ -124,16 +137,16 @@ impl ConnectionInner { connected: false, error: None, worker: None, - poll_waker: None, + poller: None, on_connected: None, on_handshake_data: None, datagram_received: VecDeque::new(), datagrams_unblocked: VecDeque::new(), stream_opened: [VecDeque::new(), VecDeque::new()], stream_available: [VecDeque::new(), VecDeque::new()], - writable: HashMap::new(), - readable: HashMap::new(), - stopped: HashMap::new(), + writable: HashMap::default(), + readable: HashMap::default(), + stopped: HashMap::default(), }), handle, socket, @@ -157,25 +170,13 @@ impl ConnectionInner { } } - fn close(&self, error_code: VarInt, reason: Bytes) { - let mut state = self.state(); - state.conn.close(Instant::now(), error_code, reason); - state.terminate(ConnectionError::LocallyClosed); - state.wake(); - } - - async fn run(&self) -> io::Result<()> { - let mut send_buf = Some(Vec::with_capacity(self.state().conn.current_mtu() as usize)); - let mut transmit_fut = pin!(Fuse::terminated()); - - let mut timer = Timer::new(); - + async fn run(self: &Arc) -> io::Result<()> { let mut poller = stream::poll_fn(|cx| { let mut state = self.state(); - let ready = state.poll_waker.is_none(); - match &state.poll_waker { + let ready = state.poller.is_none(); + match &state.poller { Some(waker) if waker.will_wake(cx.waker()) => {} - _ => state.poll_waker = Some(cx.waker().clone()), + _ => state.poller = Some(cx.waker().clone()), }; if ready { Poll::Ready(Some(())) @@ -185,36 +186,45 @@ impl ConnectionInner { }) .fuse(); + let mut timer = Timer::new(); + let mut event_stream = self.events_rx.stream().ready_chunks(100); + let mut send_buf = Some(Vec::with_capacity(self.state().conn.current_mtu() as usize)); + let mut transmit_fut = pin!(Fuse::terminated()); + loop { - select! { - _ = poller.next() => {} + let mut state = select! { + _ = poller.select_next_some() => self.state(), _ = timer => { - self.state().conn.handle_timeout(Instant::now()); - timer.reset(None); + let mut state = self.state(); + state.conn.handle_timeout(Instant::now()); + state } - ev = self.events_rx.recv_async() => match ev { - Ok(ConnectionEvent::Close(error_code, reason)) => self.close(error_code, reason), - Ok(ConnectionEvent::Proto(ev)) => self.state().conn.handle_event(ev), - Err(_) => unreachable!("endpoint dropped connection"), + events = event_stream.select_next_some() => { + let mut state = self.state(); + for event in events { + match event { + ConnectionEvent::Close(error_code, reason) => state.close(error_code, reason), + ConnectionEvent::Proto(event) => state.conn.handle_event(event), + } + } + state }, BufResult::<(), Vec>(res, mut buf) = transmit_fut => match res { Ok(()) => { buf.clear(); send_buf = Some(buf); + self.state() }, Err(e) => break Err(e), }, - } - - let now = Instant::now(); - let mut state = self.state(); + }; if let Some(mut buf) = send_buf.take() { - if let Some(transmit) = - state - .conn - .poll_transmit(now, self.socket.max_gso_segments(), &mut buf) - { + if let Some(transmit) = state.conn.poll_transmit( + Instant::now(), + self.socket.max_gso_segments(), + &mut buf, + ) { transmit_fut.set(async move { self.socket.send(buf, &transmit).await }.fuse()) } else { send_buf = Some(buf); @@ -480,9 +490,7 @@ impl Future for Connecting { impl Drop for Connecting { fn drop(&mut self) { - if Arc::strong_count(&self.0) == 2 { - self.0.close(0u32.into(), Bytes::new()) - } + implicit_close(&self.0) } } @@ -593,7 +601,9 @@ impl Connection { /// [`Endpoint::shutdown()`]: crate::Endpoint::shutdown /// [`close()`]: Connection::close pub fn close(&self, error_code: VarInt, reason: &[u8]) { - self.0.close(error_code, Bytes::copy_from_slice(reason)); + self.0 + .state() + .close(error_code, Bytes::copy_from_slice(reason)); } /// Wait for the connection to be closed for any reason. @@ -838,9 +848,7 @@ impl Eq for Connection {} impl Drop for Connection { fn drop(&mut self) { - if Arc::strong_count(&self.0) == 2 { - self.close(0u32.into(), b"") - } + implicit_close(&self.0) } } diff --git a/compio-quic/src/endpoint.rs b/compio-quic/src/endpoint.rs index 2721ffd4..99d7400f 100644 --- a/compio-quic/src/endpoint.rs +++ b/compio-quic/src/endpoint.rs @@ -1,5 +1,5 @@ use std::{ - collections::{HashMap, VecDeque}, + collections::VecDeque, io, mem::ManuallyDrop, net::{SocketAddr, SocketAddrV6}, @@ -19,12 +19,13 @@ use futures_util::{ future::{self}, select, task::AtomicWaker, - FutureExt, + FutureExt, StreamExt, }; use quinn_proto::{ ClientConfig, ConnectError, ConnectionError, ConnectionHandle, DatagramEvent, EndpointConfig, EndpointEvent, ServerConfig, Transmit, VarInt, }; +use rustc_hash::FxHashMap as HashMap; use crate::{Connecting, ConnectionEvent, Incoming, RecvMeta, Socket}; @@ -153,7 +154,7 @@ impl EndpointInner { None, ), worker: None, - connections: HashMap::new(), + connections: HashMap::default(), close: None, exit_on_idle: false, incoming: VecDeque::new(), @@ -254,6 +255,8 @@ impl EndpointInner { } async fn run(&self) -> io::Result<()> { + let respond_fn = |buf: Vec, transmit: Transmit| self.respond(buf, transmit); + let mut recv_fut = pin!( self.socket .recv(Vec::with_capacity( @@ -269,26 +272,31 @@ impl EndpointInner { .fuse() ); - let respond_fn = |buf: Vec, transmit: Transmit| self.respond(buf, transmit); + let mut event_stream = self.events.1.stream().ready_chunks(100); loop { - select! { + let mut state = select! { BufResult(res, recv_buf) = recv_fut => { + let mut state = self.state.lock().unwrap(); match res { - Ok(meta) => self.state.lock().unwrap().handle_data(meta, &recv_buf, respond_fn), + Ok(meta) => state.handle_data(meta, &recv_buf, respond_fn), Err(e) if e.kind() == io::ErrorKind::ConnectionReset => {} #[cfg(windows)] Err(e) if e.raw_os_error() == Some(windows_sys::Win32::Foundation::ERROR_PORT_UNREACHABLE as _) => {} Err(e) => break Err(e), } recv_fut.set(self.socket.recv(recv_buf).fuse()); + state }, - (ch, event) = self.events.1.recv_async().map(Result::unwrap) => { - self.state.lock().unwrap().handle_event(ch, event); + events = event_stream.select_next_some() => { + let mut state = self.state.lock().unwrap(); + for (ch, event) in events { + state.handle_event(ch, event); + } + state }, - } + }; - let mut state = self.state.lock().unwrap(); if state.exit_on_idle && state.is_idle() { break Ok(()); }