From 20c8e8c77e650de0f45492442bd55b499ff71755 Mon Sep 17 00:00:00 2001 From: Max Inden Date: Thu, 1 Feb 2024 17:30:53 +0100 Subject: [PATCH] refactor(server): replace mio with tokio (#1581) * refactor(server): replace mio with tokio * Move ready logic into fn * Extend expect docs * Restrict tokio features * Only process datagram once * Remove superfluous pub * fmt * Fix busy loop * Fold `ServersRunner::init into ServersRunner::new * Fix imports --------- Signed-off-by: Max Inden --- neqo-server/Cargo.toml | 4 +- neqo-server/src/main.rs | 251 +++++++++++++++------------------------- 2 files changed, 95 insertions(+), 160 deletions(-) diff --git a/neqo-server/Cargo.toml b/neqo-server/Cargo.toml index b3f8aae46..1d6b5df86 100644 --- a/neqo-server/Cargo.toml +++ b/neqo-server/Cargo.toml @@ -7,9 +7,8 @@ rust-version = "1.70.0" license = "MIT OR Apache-2.0" [dependencies] +futures = "0.3" log = {version = "0.4.17", default-features = false} -mio = "0.6.23" -mio-extras = "2.0.6" neqo-common = { path="./../neqo-common" } neqo-crypto = { path = "./../neqo-crypto" } neqo-http3 = { path = "./../neqo-http3" } @@ -18,6 +17,7 @@ neqo-transport = { path = "./../neqo-transport" } qlog = "0.11.0" regex = "1.9" structopt = "0.3" +tokio = { version = "1", features = ["net", "time", "macros", "rt", "rt-multi-thread"] } [features] deny-warnings = [] diff --git a/neqo-server/src/main.rs b/neqo-server/src/main.rs index 590e0d55d..0000ea4f8 100644 --- a/neqo-server/src/main.rs +++ b/neqo-server/src/main.rs @@ -10,23 +10,27 @@ use std::{ cell::RefCell, cmp::min, - collections::{HashMap, HashSet}, + collections::HashMap, convert::TryFrom, fmt::{self, Display}, fs::OpenOptions, io, io::Read, - mem, net::{SocketAddr, ToSocketAddrs}, path::PathBuf, + pin::Pin, process::exit, rc::Rc, str::FromStr, time::{Duration, Instant}, }; -use mio::{net::UdpSocket, Events, Poll, PollOpt, Ready, Token}; -use mio_extras::timer::{Builder, Timeout, Timer}; +use futures::{ + future::{select, select_all, Either}, + FutureExt, +}; +use tokio::{net::UdpSocket, time::Sleep}; + use neqo_common::{hex, qdebug, qinfo, qwarn, Datagram, Header, IpTos}; use neqo_crypto::{ constants::{TLS_AES_128_GCM_SHA256, TLS_AES_256_GCM_SHA384, TLS_CHACHA20_POLY1305_SHA256}, @@ -44,7 +48,6 @@ use structopt::StructOpt; use crate::old_https::Http09Server; -const TIMER_TOKEN: Token = Token(0xffff_ffff); const ANTI_REPLAY_WINDOW: Duration = Duration::from_secs(10); mod old_https; @@ -316,8 +319,8 @@ impl QuicParameters { } } -fn emit_packet(socket: &mut UdpSocket, out_dgram: Datagram) { - let sent = match socket.send_to(&out_dgram, &out_dgram.destination()) { +async fn emit_packet(socket: &mut UdpSocket, out_dgram: Datagram) { + let sent = match socket.send_to(&out_dgram, &out_dgram.destination()).await { Err(ref err) => { if err.kind() != io::ErrorKind::WouldBlock || err.kind() == io::ErrorKind::Interrupted { eprintln!("UDP send error: {err:?}"); @@ -594,7 +597,7 @@ fn read_dgram( local_address: &SocketAddr, ) -> Result, io::Error> { let buf = &mut [0u8; 2048]; - let (sz, remote_addr) = match socket.recv_from(&mut buf[..]) { + let (sz, remote_addr) = match socket.try_recv_from(&mut buf[..]) { Err(ref err) if err.kind() == io::ErrorKind::WouldBlock || err.kind() == io::ErrorKind::Interrupted => @@ -628,82 +631,36 @@ fn read_dgram( struct ServersRunner { args: Args, - poll: Poll, - hosts: Vec, server: Box, - timeout: Option, - sockets: Vec, - active_sockets: HashSet, - timer: Timer, + timeout: Option>>, + sockets: Vec<(SocketAddr, UdpSocket)>, } impl ServersRunner { pub fn new(args: Args) -> Result { - let server = Self::create_server(&args); - let mut runner = Self { - args, - poll: Poll::new()?, - hosts: Vec::new(), - server, - timeout: None, - sockets: Vec::new(), - active_sockets: HashSet::new(), - timer: Builder::default() - .tick_duration(Duration::from_millis(1)) - .build::(), - }; - runner.init()?; - Ok(runner) - } - - /// Init Poll for all hosts. Create sockets, and a map of the - /// socketaddrs to instances of the HttpServer handling that addr. - fn init(&mut self) -> Result<(), io::Error> { - self.hosts = self.args.listen_addresses(); - if self.hosts.is_empty() { + let hosts = args.listen_addresses(); + if hosts.is_empty() { eprintln!("No valid hosts defined"); return Err(io::Error::new(io::ErrorKind::InvalidInput, "No hosts")); } + let sockets = hosts + .into_iter() + .map(|host| { + let socket = std::net::UdpSocket::bind(host)?; + let local_addr = socket.local_addr()?; + println!("Server waiting for connection on: {local_addr:?}"); + socket.set_nonblocking(true)?; + Ok((host, UdpSocket::from_std(socket)?)) + }) + .collect::>()?; + let server = Self::create_server(&args); - for (i, host) in self.hosts.iter().enumerate() { - let socket = match UdpSocket::bind(host) { - Err(err) => { - eprintln!("Unable to bind UDP socket: {err}"); - return Err(err); - } - Ok(s) => s, - }; - - let local_addr = match socket.local_addr() { - Err(err) => { - eprintln!("Socket local address not bound: {err}"); - return Err(err); - } - Ok(s) => s, - }; - - print!("Server waiting for connection on: {local_addr:?}"); - // On Windows, this is not supported. - #[cfg(not(target_os = "windows"))] - if !socket.only_v6().unwrap_or(true) { - print!(" as well as V4"); - }; - println!(); - - self.poll.register( - &socket, - Token(i), - Ready::readable() | Ready::writable(), - PollOpt::edge(), - )?; - - self.sockets.push(socket); - } - - self.poll - .register(&self.timer, TIMER_TOKEN, Ready::readable(), PollOpt::edge())?; - - Ok(()) + Ok(Self { + args, + server, + timeout: None, + sockets, + }) } fn create_server(args: &Args) -> Box { @@ -741,110 +698,88 @@ impl ServersRunner { /// Tries to find a socket, but then just falls back to sending from the first. fn find_socket(&mut self, addr: SocketAddr) -> &mut UdpSocket { - let (first, rest) = self.sockets.split_first_mut().unwrap(); + let ((_host, first_socket), rest) = self.sockets.split_first_mut().unwrap(); rest.iter_mut() - .find(|s| { - s.local_addr() + .map(|(_host, socket)| socket) + .find(|socket| { + socket + .local_addr() .ok() .map_or(false, |socket_addr| socket_addr == addr) }) - .unwrap_or(first) + .unwrap_or(first_socket) } - fn process(&mut self, inx: usize, dgram: Option<&Datagram>) -> bool { - match self.server.process(dgram, self.args.now()) { - Output::Datagram(dgram) => { - let socket = self.find_socket(dgram.source()); - emit_packet(socket, dgram); - true - } - Output::Callback(new_timeout) => { - if let Some(to) = &self.timeout { - self.timer.cancel_timeout(to); + async fn process(&mut self, mut dgram: Option<&Datagram>) { + loop { + match self.server.process(dgram.take(), self.args.now()) { + Output::Datagram(dgram) => { + let socket = self.find_socket(dgram.source()); + emit_packet(socket, dgram).await; + } + Output::Callback(new_timeout) => { + qinfo!("Setting timeout of {:?}", new_timeout); + self.timeout = Some(Box::pin(tokio::time::sleep(new_timeout))); + break; + } + Output::None => { + qdebug!("Output::None"); + break; } - - qinfo!("Setting timeout of {:?} for socket {}", new_timeout, inx); - self.timeout = Some(self.timer.set_timeout(new_timeout, inx)); - false - } - Output::None => { - qdebug!("Output::None"); - false } } } - fn process_datagrams_and_events( - &mut self, - inx: usize, - read_socket: bool, - ) -> Result<(), io::Error> { - if self.sockets.get_mut(inx).is_some() { - if read_socket { - loop { - let socket = self.sockets.get_mut(inx).unwrap(); - let dgram = read_dgram(socket, &self.hosts[inx])?; + // Wait for any of the sockets to be readable or the timeout to fire. + async fn ready(&mut self) -> Result { + let sockets_ready = select_all( + self.sockets + .iter() + .map(|(_host, socket)| Box::pin(socket.readable())), + ) + .map(|(res, inx, _)| match res { + Ok(()) => Ok(Ready::Socket(inx)), + Err(e) => Err(e), + }); + let timeout_ready = self + .timeout + .as_mut() + .map(Either::Left) + .unwrap_or(Either::Right(futures::future::pending())) + .map(|()| Ok(Ready::Timeout)); + select(sockets_ready, timeout_ready).await.factor_first().0 + } + + async fn run(&mut self) -> Result<(), io::Error> { + loop { + match self.ready().await? { + Ready::Socket(inx) => loop { + let (host, socket) = self.sockets.get_mut(inx).unwrap(); + let dgram = read_dgram(socket, host)?; if dgram.is_none() { break; } - _ = self.process(inx, dgram.as_ref()); + self.process(dgram.as_ref()).await; + }, + Ready::Timeout => { + self.timeout = None; + self.process(None).await; } - } else { - _ = self.process(inx, None); } - self.server.process_events(&self.args, self.args.now()); - if self.process(inx, None) { - self.active_sockets.insert(inx); - } - } - Ok(()) - } - - fn process_active_conns(&mut self) -> Result<(), io::Error> { - let curr_active = mem::take(&mut self.active_sockets); - for inx in curr_active { - self.process_datagrams_and_events(inx, false)?; - } - Ok(()) - } - fn process_timeout(&mut self) -> Result<(), io::Error> { - while let Some(inx) = self.timer.poll() { - qinfo!("Timer expired for {:?}", inx); - self.process_datagrams_and_events(inx, false)?; + self.server.process_events(&self.args, self.args.now()); + self.process(None).await; } - Ok(()) } +} - pub fn run(&mut self) -> Result<(), io::Error> { - let mut events = Events::with_capacity(1024); - loop { - // If there are active servers do not block in poll. - self.poll.poll( - &mut events, - if self.active_sockets.is_empty() { - None - } else { - Some(Duration::from_millis(0)) - }, - )?; - - for event in &events { - if event.token() == TIMER_TOKEN { - self.process_timeout()?; - } else { - if !event.readiness().is_readable() { - continue; - } - self.process_datagrams_and_events(event.token().0, true)?; - } - } - self.process_active_conns()?; - } - } +enum Ready { + Socket(usize), + Timeout, } -fn main() -> Result<(), io::Error> { +#[tokio::main] +async fn main() -> Result<(), io::Error> { const HQ_INTEROP: &str = "hq-interop"; let mut args = Args::from_args(); @@ -896,5 +831,5 @@ fn main() -> Result<(), io::Error> { } let mut servers_runner = ServersRunner::new(args)?; - servers_runner.run() + servers_runner.run().await }