Skip to content

Commit

Permalink
perf(quic): batch recv on channel
Browse files Browse the repository at this point in the history
  • Loading branch information
AsakuraMizu committed Aug 28, 2024
1 parent 148f3ac commit 7e3b2fb
Show file tree
Hide file tree
Showing 4 changed files with 77 additions and 59 deletions.
1 change: 1 addition & 0 deletions compio-quic/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ h3 = { version = "0.0.6", optional = true }
bytes = { workspace = true }
flume = { workspace = true }
futures-util = { workspace = true }
rustc-hash = "2.0.0"
thiserror = "1.0.63"

# Windows specific dependencies
Expand Down
5 changes: 3 additions & 2 deletions compio-quic/benches/quic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -127,13 +127,13 @@ fn echo_quinn(b: &mut Bencher, content: &[u8], streams: usize) {
client.set_default_client_config(client_config);
let addr = server.local_addr().unwrap();

let (client_conn, server_conn) = futures_util::join!(
let (client_conn, server_conn) = tokio::join!(
async move { client.connect(addr, "localhost").unwrap().await.unwrap() },
async move { server.accept().await.unwrap().await.unwrap() }
);

let start = Instant::now();
tokio::spawn(async move {
let handle = tokio::spawn(async move {
while let Ok((mut send, mut recv)) = server_conn.accept_bi().await {
tokio::spawn(async move {
echo_impl!(send, recv);
Expand All @@ -157,6 +157,7 @@ fn echo_quinn(b: &mut Bencher, content: &[u8], streams: usize) {
.collect::<FuturesUnordered<_>>();
while futures.next().await.is_some() {}
}
handle.abort();
start.elapsed()
});
}
Expand Down
102 changes: 55 additions & 47 deletions compio-quic/src/connection.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use std::{
collections::{HashMap, VecDeque},
collections::VecDeque,
io,
net::{IpAddr, SocketAddr},
pin::{pin, Pin},
Expand All @@ -21,6 +21,7 @@ use quinn_proto::{
congestion::Controller, crypto::rustls::HandshakeData, ConnectionHandle, ConnectionStats, Dir,
EndpointEvent, StreamEvent, StreamId, VarInt,
};
use rustc_hash::FxHashMap as HashMap;
use thiserror::Error;

use crate::{RecvStream, SendStream, Socket};
Expand All @@ -37,7 +38,7 @@ pub(crate) struct ConnectionState {
pub(crate) error: Option<ConnectionError>,
connected: bool,
worker: Option<JoinHandle<()>>,
poll_waker: Option<Waker>,
poller: Option<Waker>,
on_connected: Option<Waker>,
on_handshake_data: Option<Waker>,
datagram_received: VecDeque<Waker>,
Expand Down Expand Up @@ -73,8 +74,14 @@ impl ConnectionState {
wake_all_streams(&mut self.stopped);
}

fn close(&mut self, error_code: VarInt, reason: Bytes) {
self.conn.close(Instant::now(), error_code, reason);
self.terminate(ConnectionError::LocallyClosed);
self.wake();
}

pub(crate) fn wake(&mut self) {
if let Some(waker) = self.poll_waker.take() {
if let Some(waker) = self.poller.take() {
waker.wake()
}
}
Expand Down Expand Up @@ -110,6 +117,12 @@ pub(crate) struct ConnectionInner {
events_rx: Receiver<ConnectionEvent>,
}

fn implicit_close(this: &Arc<ConnectionInner>) {
if Arc::strong_count(this) == 2 {
this.state().close(0u32.into(), Bytes::new())
}
}

impl ConnectionInner {
fn new(
handle: ConnectionHandle,
Expand All @@ -124,16 +137,16 @@ impl ConnectionInner {
connected: false,
error: None,
worker: None,
poll_waker: None,
poller: None,
on_connected: None,
on_handshake_data: None,
datagram_received: VecDeque::new(),
datagrams_unblocked: VecDeque::new(),
stream_opened: [VecDeque::new(), VecDeque::new()],
stream_available: [VecDeque::new(), VecDeque::new()],
writable: HashMap::new(),
readable: HashMap::new(),
stopped: HashMap::new(),
writable: HashMap::default(),
readable: HashMap::default(),
stopped: HashMap::default(),
}),
handle,
socket,
Expand All @@ -157,25 +170,13 @@ impl ConnectionInner {
}
}

fn close(&self, error_code: VarInt, reason: Bytes) {
let mut state = self.state();
state.conn.close(Instant::now(), error_code, reason);
state.terminate(ConnectionError::LocallyClosed);
state.wake();
}

async fn run(&self) -> io::Result<()> {
let mut send_buf = Some(Vec::with_capacity(self.state().conn.current_mtu() as usize));
let mut transmit_fut = pin!(Fuse::terminated());

let mut timer = Timer::new();

async fn run(self: &Arc<Self>) -> io::Result<()> {
let mut poller = stream::poll_fn(|cx| {
let mut state = self.state();
let ready = state.poll_waker.is_none();
match &state.poll_waker {
let ready = state.poller.is_none();
match &state.poller {
Some(waker) if waker.will_wake(cx.waker()) => {}
_ => state.poll_waker = Some(cx.waker().clone()),
_ => state.poller = Some(cx.waker().clone()),
};
if ready {
Poll::Ready(Some(()))
Expand All @@ -185,36 +186,45 @@ impl ConnectionInner {
})
.fuse();

let mut timer = Timer::new();
let mut event_stream = self.events_rx.stream().ready_chunks(100);
let mut send_buf = Some(Vec::with_capacity(self.state().conn.current_mtu() as usize));
let mut transmit_fut = pin!(Fuse::terminated());

loop {
select! {
_ = poller.next() => {}
let mut state = select! {
_ = poller.select_next_some() => self.state(),
_ = timer => {
self.state().conn.handle_timeout(Instant::now());
timer.reset(None);
let mut state = self.state();
state.conn.handle_timeout(Instant::now());
state
}
ev = self.events_rx.recv_async() => match ev {
Ok(ConnectionEvent::Close(error_code, reason)) => self.close(error_code, reason),
Ok(ConnectionEvent::Proto(ev)) => self.state().conn.handle_event(ev),
Err(_) => unreachable!("endpoint dropped connection"),
events = event_stream.select_next_some() => {
let mut state = self.state();
for event in events {
match event {
ConnectionEvent::Close(error_code, reason) => state.close(error_code, reason),
ConnectionEvent::Proto(event) => state.conn.handle_event(event),
}
}
state
},
BufResult::<(), Vec<u8>>(res, mut buf) = transmit_fut => match res {
Ok(()) => {
buf.clear();
send_buf = Some(buf);
self.state()
},
Err(e) => break Err(e),
},
}

let now = Instant::now();
let mut state = self.state();
};

if let Some(mut buf) = send_buf.take() {
if let Some(transmit) =
state
.conn
.poll_transmit(now, self.socket.max_gso_segments(), &mut buf)
{
if let Some(transmit) = state.conn.poll_transmit(
Instant::now(),
self.socket.max_gso_segments(),
&mut buf,
) {
transmit_fut.set(async move { self.socket.send(buf, &transmit).await }.fuse())
} else {
send_buf = Some(buf);
Expand Down Expand Up @@ -480,9 +490,7 @@ impl Future for Connecting {

impl Drop for Connecting {
fn drop(&mut self) {
if Arc::strong_count(&self.0) == 2 {
self.0.close(0u32.into(), Bytes::new())
}
implicit_close(&self.0)
}
}

Expand Down Expand Up @@ -593,7 +601,9 @@ impl Connection {
/// [`Endpoint::shutdown()`]: crate::Endpoint::shutdown
/// [`close()`]: Connection::close
pub fn close(&self, error_code: VarInt, reason: &[u8]) {
self.0.close(error_code, Bytes::copy_from_slice(reason));
self.0
.state()
.close(error_code, Bytes::copy_from_slice(reason));
}

/// Wait for the connection to be closed for any reason.
Expand Down Expand Up @@ -838,9 +848,7 @@ impl Eq for Connection {}

impl Drop for Connection {
fn drop(&mut self) {
if Arc::strong_count(&self.0) == 2 {
self.close(0u32.into(), b"")
}
implicit_close(&self.0)
}
}

Expand Down
28 changes: 18 additions & 10 deletions compio-quic/src/endpoint.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use std::{
collections::{HashMap, VecDeque},
collections::VecDeque,
io,
mem::ManuallyDrop,
net::{SocketAddr, SocketAddrV6},
Expand All @@ -19,12 +19,13 @@ use futures_util::{
future::{self},
select,
task::AtomicWaker,
FutureExt,
FutureExt, StreamExt,
};
use quinn_proto::{
ClientConfig, ConnectError, ConnectionError, ConnectionHandle, DatagramEvent, EndpointConfig,
EndpointEvent, ServerConfig, Transmit, VarInt,
};
use rustc_hash::FxHashMap as HashMap;

use crate::{Connecting, ConnectionEvent, Incoming, RecvMeta, Socket};

Expand Down Expand Up @@ -153,7 +154,7 @@ impl EndpointInner {
None,
),
worker: None,
connections: HashMap::new(),
connections: HashMap::default(),
close: None,
exit_on_idle: false,
incoming: VecDeque::new(),
Expand Down Expand Up @@ -254,6 +255,8 @@ impl EndpointInner {
}

async fn run(&self) -> io::Result<()> {
let respond_fn = |buf: Vec<u8>, transmit: Transmit| self.respond(buf, transmit);

let mut recv_fut = pin!(
self.socket
.recv(Vec::with_capacity(
Expand All @@ -269,26 +272,31 @@ impl EndpointInner {
.fuse()
);

let respond_fn = |buf: Vec<u8>, transmit: Transmit| self.respond(buf, transmit);
let mut event_stream = self.events.1.stream().ready_chunks(100);

loop {
select! {
let mut state = select! {
BufResult(res, recv_buf) = recv_fut => {
let mut state = self.state.lock().unwrap();
match res {
Ok(meta) => self.state.lock().unwrap().handle_data(meta, &recv_buf, respond_fn),
Ok(meta) => state.handle_data(meta, &recv_buf, respond_fn),
Err(e) if e.kind() == io::ErrorKind::ConnectionReset => {}
#[cfg(windows)]
Err(e) if e.raw_os_error() == Some(windows_sys::Win32::Foundation::ERROR_PORT_UNREACHABLE as _) => {}
Err(e) => break Err(e),
}
recv_fut.set(self.socket.recv(recv_buf).fuse());
state
},
(ch, event) = self.events.1.recv_async().map(Result::unwrap) => {
self.state.lock().unwrap().handle_event(ch, event);
events = event_stream.select_next_some() => {
let mut state = self.state.lock().unwrap();
for (ch, event) in events {
state.handle_event(ch, event);
}
state
},
}
};

let mut state = self.state.lock().unwrap();
if state.exit_on_idle && state.is_idle() {
break Ok(());
}
Expand Down

0 comments on commit 7e3b2fb

Please sign in to comment.