diff --git a/neqo-bin/src/udp.rs b/neqo-bin/src/udp.rs index 632a1293d7..8bfad219b4 100644 --- a/neqo-bin/src/udp.rs +++ b/neqo-bin/src/udp.rs @@ -9,6 +9,7 @@ use std::{ io::{self, IoSliceMut}, + mem::MaybeUninit, net::{SocketAddr, ToSocketAddrs}, slice, }; @@ -17,6 +18,13 @@ use neqo_common::{Datagram, IpTos}; use quinn_udp::{EcnCodepoint, RecvMeta, Transmit, UdpSocketState}; use tokio::io::Interest; +#[cfg(not(any(target_os = "macos", target_os = "ios")))] +// Chosen somewhat arbitrarily; might benefit from additional tuning. +pub(crate) const BATCH_SIZE: usize = 32; + +#[cfg(any(target_os = "macos", target_os = "ios"))] +pub(crate) const BATCH_SIZE: usize = 1; + /// Socket receive buffer size. /// /// Allows reading multiple datagrams in a single [`Socket::recv`] call. @@ -25,7 +33,8 @@ const RECV_BUF_SIZE: usize = u16::MAX as usize; pub struct Socket { socket: tokio::net::UdpSocket, state: UdpSocketState, - recv_buf: Vec, + // TODO: Rename + recv_buf: [Vec; BATCH_SIZE], } impl Socket { @@ -36,7 +45,11 @@ impl Socket { Ok(Self { state: quinn_udp::UdpSocketState::new((&socket).into())?, socket: tokio::net::UdpSocket::from_std(socket)?, - recv_buf: vec![0; RECV_BUF_SIZE], + recv_buf: (0..BATCH_SIZE) + .map(|_| vec![0; RECV_BUF_SIZE]) + .collect::>() + .try_into() + .expect("successful array instantiation"), }) } @@ -77,18 +90,25 @@ impl Socket { /// Receive a UDP datagram on the specified socket. pub fn recv(&mut self, local_address: &SocketAddr) -> Result, io::Error> { - let mut meta = RecvMeta::default(); - - match self.socket.try_io(Interest::READABLE, || { - self.state.recv( - (&self.socket).into(), - &mut [IoSliceMut::new(&mut self.recv_buf)], - slice::from_mut(&mut meta), - ) + let mut metas = [RecvMeta::default(); BATCH_SIZE]; + + // TODO: Safe? + let mut iovs = MaybeUninit::<[IoSliceMut<'_>; BATCH_SIZE]>::uninit(); + for (i, buf) in self.recv_buf.iter_mut().enumerate() { + unsafe { + iovs.as_mut_ptr() + .cast::() + .add(i) + .write(IoSliceMut::new(buf)); + }; + } + let mut iovs = unsafe { iovs.assume_init() }; + + let msgs = match self.socket.try_io(Interest::READABLE, || { + self.state + .recv((&self.socket).into(), &mut iovs, &mut metas) }) { - Ok(n) => { - assert_eq!(n, 1, "only passed one slice"); - } + Ok(n) => n, Err(ref err) if err.kind() == io::ErrorKind::WouldBlock || err.kind() == io::ErrorKind::Interrupted => @@ -100,123 +120,35 @@ impl Socket { } }; - if meta.len == 0 { - eprintln!("zero length datagram received?"); - return Ok(vec![]); - } - if meta.len == self.recv_buf.len() { - eprintln!( - "Might have received more than {} bytes", - self.recv_buf.len() - ); - } - - Ok(self.recv_buf[0..meta.len] - .chunks(meta.stride.min(self.recv_buf.len())) - .map(|d| { - Datagram::new( - meta.addr, - *local_address, - meta.ecn.map(|n| IpTos::from(n as u8)).unwrap_or_default(), - None, // TODO: get the real TTL https://github.com/quinn-rs/quinn/issues/1749 - d, - ) + // TODO + // if meta.len == 0 { + // eprintln!("zero length datagram received?"); + // return Ok(vec![]); + // } + // if meta.len == self.recv_buf.len() { + // eprintln!( + // "Might have received more than {} bytes", + // self.recv_buf.len() + // ); + // } + + Ok(metas + .iter() + .zip(iovs.iter()) + .take(msgs) + .flat_map(|(meta, buf)| { + buf[0..meta.len] + .chunks(meta.stride.min(buf.len())) + .map(|d| { + Datagram::new( + meta.addr, + *local_address, + meta.ecn.map(|n| IpTos::from(n as u8)).unwrap_or_default(), + None, // TODO: get the real TTL https://github.com/quinn-rs/quinn/issues/1749 + d, + ) + }) }) .collect()) } } - -#[cfg(test)] -mod tests { - use neqo_common::{IpTosDscp, IpTosEcn}; - - use super::*; - - #[tokio::test] - async fn datagram_tos() -> Result<(), io::Error> { - let sender = Socket::bind("127.0.0.1:0")?; - let receiver_addr: SocketAddr = "127.0.0.1:0".parse().unwrap(); - let mut receiver = Socket::bind(receiver_addr)?; - - let datagram = Datagram::new( - sender.local_addr()?, - receiver.local_addr()?, - IpTos::from((IpTosDscp::Le, IpTosEcn::Ect1)), - None, - "Hello, world!".as_bytes().to_vec(), - ); - - sender.writable().await?; - sender.send(datagram.clone())?; - - receiver.readable().await?; - let received_datagram = receiver - .recv(&receiver_addr) - .expect("receive to succeed") - .into_iter() - .next() - .expect("receive to yield datagram"); - - // Assert that the ECN is correct. - assert_eq!( - IpTosEcn::from(datagram.tos()), - IpTosEcn::from(received_datagram.tos()) - ); - - Ok(()) - } - - /// Expect [`Socket::recv`] to handle multiple [`Datagram`]s on GRO read. - #[tokio::test] - #[cfg_attr(not(any(target_os = "linux", target_os = "windows")), ignore)] - async fn many_datagrams_through_gro() -> Result<(), io::Error> { - const SEGMENT_SIZE: usize = 128; - - let sender = Socket::bind("127.0.0.1:0")?; - let receiver_addr: SocketAddr = "127.0.0.1:0".parse().unwrap(); - let mut receiver = Socket::bind(receiver_addr)?; - - // `neqo_common::udp::Socket::send` does not yet - // (https://github.com/mozilla/neqo/issues/1693) support GSO. Use - // `quinn_udp` directly. - let max_gso_segments = sender.state.max_gso_segments(); - let msg = vec![0xAB; SEGMENT_SIZE * max_gso_segments]; - let transmit = Transmit { - destination: receiver.local_addr()?, - ecn: EcnCodepoint::from_bits(Into::::into(IpTos::from(( - IpTosDscp::Le, - IpTosEcn::Ect1, - )))), - contents: msg.clone().into(), - segment_size: Some(SEGMENT_SIZE), - src_ip: None, - }; - sender.writable().await?; - let n = sender.socket.try_io(Interest::WRITABLE, || { - sender - .state - .send((&sender.socket).into(), slice::from_ref(&transmit)) - })?; - assert_eq!(n, 1, "only passed one slice"); - - // Allow for one GSO sendmmsg to result in multiple GRO recvmmsg. - let mut num_received = 0; - while num_received < max_gso_segments { - receiver.readable().await?; - receiver - .recv(&receiver_addr) - .expect("receive to succeed") - .into_iter() - .for_each(|d| { - assert_eq!( - SEGMENT_SIZE, - d.len(), - "Expect received datagrams to have same length as sent datagrams." - ); - num_received += 1; - }); - } - - Ok(()) - } -}