Skip to content

Commit

Permalink
refactor(server): replace mio with tokio (#1581)
Browse files Browse the repository at this point in the history
* refactor(server): replace mio with tokio

* Move ready logic into fn

* Extend expect docs

* Restrict tokio features

* Only process datagram once

* Remove superfluous pub

* fmt

* Fix busy loop

* Fold `ServersRunner::init into ServersRunner::new

* Fix imports

---------

Signed-off-by: Max Inden <mail@max-inden.de>
  • Loading branch information
mxinden committed Feb 1, 2024
1 parent 5e32696 commit 20c8e8c
Show file tree
Hide file tree
Showing 2 changed files with 95 additions and 160 deletions.
4 changes: 2 additions & 2 deletions neqo-server/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,8 @@ rust-version = "1.70.0"
license = "MIT OR Apache-2.0"

[dependencies]
futures = "0.3"
log = {version = "0.4.17", default-features = false}
mio = "0.6.23"
mio-extras = "2.0.6"
neqo-common = { path="./../neqo-common" }
neqo-crypto = { path = "./../neqo-crypto" }
neqo-http3 = { path = "./../neqo-http3" }
Expand All @@ -18,6 +17,7 @@ neqo-transport = { path = "./../neqo-transport" }
qlog = "0.11.0"
regex = "1.9"
structopt = "0.3"
tokio = { version = "1", features = ["net", "time", "macros", "rt", "rt-multi-thread"] }

[features]
deny-warnings = []
251 changes: 93 additions & 158 deletions neqo-server/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,23 +10,27 @@
use std::{
cell::RefCell,
cmp::min,
collections::{HashMap, HashSet},
collections::HashMap,
convert::TryFrom,
fmt::{self, Display},
fs::OpenOptions,
io,
io::Read,
mem,
net::{SocketAddr, ToSocketAddrs},
path::PathBuf,
pin::Pin,
process::exit,
rc::Rc,
str::FromStr,
time::{Duration, Instant},
};

use mio::{net::UdpSocket, Events, Poll, PollOpt, Ready, Token};
use mio_extras::timer::{Builder, Timeout, Timer};
use futures::{
future::{select, select_all, Either},
FutureExt,
};
use tokio::{net::UdpSocket, time::Sleep};

use neqo_common::{hex, qdebug, qinfo, qwarn, Datagram, Header, IpTos};
use neqo_crypto::{
constants::{TLS_AES_128_GCM_SHA256, TLS_AES_256_GCM_SHA384, TLS_CHACHA20_POLY1305_SHA256},
Expand All @@ -44,7 +48,6 @@ use structopt::StructOpt;

use crate::old_https::Http09Server;

const TIMER_TOKEN: Token = Token(0xffff_ffff);
const ANTI_REPLAY_WINDOW: Duration = Duration::from_secs(10);

mod old_https;
Expand Down Expand Up @@ -316,8 +319,8 @@ impl QuicParameters {
}
}

fn emit_packet(socket: &mut UdpSocket, out_dgram: Datagram) {
let sent = match socket.send_to(&out_dgram, &out_dgram.destination()) {
async fn emit_packet(socket: &mut UdpSocket, out_dgram: Datagram) {
let sent = match socket.send_to(&out_dgram, &out_dgram.destination()).await {
Err(ref err) => {
if err.kind() != io::ErrorKind::WouldBlock || err.kind() == io::ErrorKind::Interrupted {
eprintln!("UDP send error: {err:?}");
Expand Down Expand Up @@ -594,7 +597,7 @@ fn read_dgram(
local_address: &SocketAddr,
) -> Result<Option<Datagram>, io::Error> {
let buf = &mut [0u8; 2048];
let (sz, remote_addr) = match socket.recv_from(&mut buf[..]) {
let (sz, remote_addr) = match socket.try_recv_from(&mut buf[..]) {
Err(ref err)
if err.kind() == io::ErrorKind::WouldBlock
|| err.kind() == io::ErrorKind::Interrupted =>
Expand Down Expand Up @@ -628,82 +631,36 @@ fn read_dgram(

struct ServersRunner {
args: Args,
poll: Poll,
hosts: Vec<SocketAddr>,
server: Box<dyn HttpServer>,
timeout: Option<Timeout>,
sockets: Vec<UdpSocket>,
active_sockets: HashSet<usize>,
timer: Timer<usize>,
timeout: Option<Pin<Box<Sleep>>>,
sockets: Vec<(SocketAddr, UdpSocket)>,
}

impl ServersRunner {
pub fn new(args: Args) -> Result<Self, io::Error> {
let server = Self::create_server(&args);
let mut runner = Self {
args,
poll: Poll::new()?,
hosts: Vec::new(),
server,
timeout: None,
sockets: Vec::new(),
active_sockets: HashSet::new(),
timer: Builder::default()
.tick_duration(Duration::from_millis(1))
.build::<usize>(),
};
runner.init()?;
Ok(runner)
}

/// Init Poll for all hosts. Create sockets, and a map of the
/// socketaddrs to instances of the HttpServer handling that addr.
fn init(&mut self) -> Result<(), io::Error> {
self.hosts = self.args.listen_addresses();
if self.hosts.is_empty() {
let hosts = args.listen_addresses();
if hosts.is_empty() {
eprintln!("No valid hosts defined");
return Err(io::Error::new(io::ErrorKind::InvalidInput, "No hosts"));
}
let sockets = hosts
.into_iter()
.map(|host| {
let socket = std::net::UdpSocket::bind(host)?;
let local_addr = socket.local_addr()?;
println!("Server waiting for connection on: {local_addr:?}");
socket.set_nonblocking(true)?;
Ok((host, UdpSocket::from_std(socket)?))
})
.collect::<Result<_, io::Error>>()?;
let server = Self::create_server(&args);

for (i, host) in self.hosts.iter().enumerate() {
let socket = match UdpSocket::bind(host) {
Err(err) => {
eprintln!("Unable to bind UDP socket: {err}");
return Err(err);
}
Ok(s) => s,
};

let local_addr = match socket.local_addr() {
Err(err) => {
eprintln!("Socket local address not bound: {err}");
return Err(err);
}
Ok(s) => s,
};

print!("Server waiting for connection on: {local_addr:?}");
// On Windows, this is not supported.
#[cfg(not(target_os = "windows"))]
if !socket.only_v6().unwrap_or(true) {
print!(" as well as V4");
};
println!();

self.poll.register(
&socket,
Token(i),
Ready::readable() | Ready::writable(),
PollOpt::edge(),
)?;

self.sockets.push(socket);
}

self.poll
.register(&self.timer, TIMER_TOKEN, Ready::readable(), PollOpt::edge())?;

Ok(())
Ok(Self {
args,
server,
timeout: None,
sockets,
})
}

fn create_server(args: &Args) -> Box<dyn HttpServer> {
Expand Down Expand Up @@ -741,110 +698,88 @@ impl ServersRunner {

/// Tries to find a socket, but then just falls back to sending from the first.
fn find_socket(&mut self, addr: SocketAddr) -> &mut UdpSocket {
let (first, rest) = self.sockets.split_first_mut().unwrap();
let ((_host, first_socket), rest) = self.sockets.split_first_mut().unwrap();
rest.iter_mut()
.find(|s| {
s.local_addr()
.map(|(_host, socket)| socket)
.find(|socket| {
socket
.local_addr()
.ok()
.map_or(false, |socket_addr| socket_addr == addr)
})
.unwrap_or(first)
.unwrap_or(first_socket)
}

fn process(&mut self, inx: usize, dgram: Option<&Datagram>) -> bool {
match self.server.process(dgram, self.args.now()) {
Output::Datagram(dgram) => {
let socket = self.find_socket(dgram.source());
emit_packet(socket, dgram);
true
}
Output::Callback(new_timeout) => {
if let Some(to) = &self.timeout {
self.timer.cancel_timeout(to);
async fn process(&mut self, mut dgram: Option<&Datagram>) {
loop {
match self.server.process(dgram.take(), self.args.now()) {
Output::Datagram(dgram) => {
let socket = self.find_socket(dgram.source());
emit_packet(socket, dgram).await;
}
Output::Callback(new_timeout) => {
qinfo!("Setting timeout of {:?}", new_timeout);
self.timeout = Some(Box::pin(tokio::time::sleep(new_timeout)));
break;
}
Output::None => {
qdebug!("Output::None");
break;
}

qinfo!("Setting timeout of {:?} for socket {}", new_timeout, inx);
self.timeout = Some(self.timer.set_timeout(new_timeout, inx));
false
}
Output::None => {
qdebug!("Output::None");
false
}
}
}

fn process_datagrams_and_events(
&mut self,
inx: usize,
read_socket: bool,
) -> Result<(), io::Error> {
if self.sockets.get_mut(inx).is_some() {
if read_socket {
loop {
let socket = self.sockets.get_mut(inx).unwrap();
let dgram = read_dgram(socket, &self.hosts[inx])?;
// Wait for any of the sockets to be readable or the timeout to fire.
async fn ready(&mut self) -> Result<Ready, io::Error> {
let sockets_ready = select_all(
self.sockets
.iter()
.map(|(_host, socket)| Box::pin(socket.readable())),
)
.map(|(res, inx, _)| match res {
Ok(()) => Ok(Ready::Socket(inx)),
Err(e) => Err(e),
});
let timeout_ready = self
.timeout
.as_mut()
.map(Either::Left)
.unwrap_or(Either::Right(futures::future::pending()))
.map(|()| Ok(Ready::Timeout));
select(sockets_ready, timeout_ready).await.factor_first().0
}

async fn run(&mut self) -> Result<(), io::Error> {
loop {
match self.ready().await? {
Ready::Socket(inx) => loop {
let (host, socket) = self.sockets.get_mut(inx).unwrap();
let dgram = read_dgram(socket, host)?;
if dgram.is_none() {
break;
}
_ = self.process(inx, dgram.as_ref());
self.process(dgram.as_ref()).await;
},
Ready::Timeout => {
self.timeout = None;
self.process(None).await;
}
} else {
_ = self.process(inx, None);
}
self.server.process_events(&self.args, self.args.now());
if self.process(inx, None) {
self.active_sockets.insert(inx);
}
}
Ok(())
}

fn process_active_conns(&mut self) -> Result<(), io::Error> {
let curr_active = mem::take(&mut self.active_sockets);
for inx in curr_active {
self.process_datagrams_and_events(inx, false)?;
}
Ok(())
}

fn process_timeout(&mut self) -> Result<(), io::Error> {
while let Some(inx) = self.timer.poll() {
qinfo!("Timer expired for {:?}", inx);
self.process_datagrams_and_events(inx, false)?;
self.server.process_events(&self.args, self.args.now());
self.process(None).await;
}
Ok(())
}
}

pub fn run(&mut self) -> Result<(), io::Error> {
let mut events = Events::with_capacity(1024);
loop {
// If there are active servers do not block in poll.
self.poll.poll(
&mut events,
if self.active_sockets.is_empty() {
None
} else {
Some(Duration::from_millis(0))
},
)?;

for event in &events {
if event.token() == TIMER_TOKEN {
self.process_timeout()?;
} else {
if !event.readiness().is_readable() {
continue;
}
self.process_datagrams_and_events(event.token().0, true)?;
}
}
self.process_active_conns()?;
}
}
enum Ready {
Socket(usize),
Timeout,
}

fn main() -> Result<(), io::Error> {
#[tokio::main]
async fn main() -> Result<(), io::Error> {
const HQ_INTEROP: &str = "hq-interop";

let mut args = Args::from_args();
Expand Down Expand Up @@ -896,5 +831,5 @@ fn main() -> Result<(), io::Error> {
}

let mut servers_runner = ServersRunner::new(args)?;
servers_runner.run()
servers_runner.run().await
}

0 comments on commit 20c8e8c

Please sign in to comment.