Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor(server): replace mio with tokio #1581

Merged
merged 12 commits into from
Feb 1, 2024
4 changes: 2 additions & 2 deletions neqo-server/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,8 @@ rust-version = "1.70.0"
license = "MIT OR Apache-2.0"

[dependencies]
futures = "0.3"
log = {version = "0.4.17", default-features = false}
mio = "0.6.23"
mio-extras = "2.0.6"
neqo-common = { path="./../neqo-common" }
neqo-crypto = { path = "./../neqo-crypto" }
neqo-http3 = { path = "./../neqo-http3" }
Expand All @@ -18,6 +17,7 @@ neqo-transport = { path = "./../neqo-transport" }
qlog = "0.11.0"
regex = "1.9"
structopt = "0.3"
tokio = { version = "1", features = ["net", "time", "macros", "rt", "rt-multi-thread"] }

[features]
deny-warnings = []
251 changes: 93 additions & 158 deletions neqo-server/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,23 +10,27 @@
use std::{
cell::RefCell,
cmp::min,
collections::{HashMap, HashSet},
collections::HashMap,
convert::TryFrom,
fmt::{self, Display},
fs::OpenOptions,
io,
io::Read,
mem,
net::{SocketAddr, ToSocketAddrs},
path::PathBuf,
pin::Pin,
process::exit,
rc::Rc,
str::FromStr,
time::{Duration, Instant},
};

use mio::{net::UdpSocket, Events, Poll, PollOpt, Ready, Token};
use mio_extras::timer::{Builder, Timeout, Timer};
use futures::{
future::{select, select_all, Either},
FutureExt,
};
use tokio::{net::UdpSocket, time::Sleep};

use neqo_common::{hex, qdebug, qinfo, qwarn, Datagram, Header, IpTos};
use neqo_crypto::{
constants::{TLS_AES_128_GCM_SHA256, TLS_AES_256_GCM_SHA384, TLS_CHACHA20_POLY1305_SHA256},
Expand All @@ -44,7 +48,6 @@ use structopt::StructOpt;

use crate::old_https::Http09Server;

const TIMER_TOKEN: Token = Token(0xffff_ffff);
const ANTI_REPLAY_WINDOW: Duration = Duration::from_secs(10);

mod old_https;
Expand Down Expand Up @@ -316,8 +319,8 @@ impl QuicParameters {
}
}

fn emit_packet(socket: &mut UdpSocket, out_dgram: Datagram) {
let sent = match socket.send_to(&out_dgram, &out_dgram.destination()) {
async fn emit_packet(socket: &mut UdpSocket, out_dgram: Datagram) {
let sent = match socket.send_to(&out_dgram, &out_dgram.destination()).await {
Err(ref err) => {
if err.kind() != io::ErrorKind::WouldBlock || err.kind() == io::ErrorKind::Interrupted {
eprintln!("UDP send error: {err:?}");
Expand Down Expand Up @@ -594,7 +597,7 @@ fn read_dgram(
local_address: &SocketAddr,
) -> Result<Option<Datagram>, io::Error> {
let buf = &mut [0u8; 2048];
let (sz, remote_addr) = match socket.recv_from(&mut buf[..]) {
let (sz, remote_addr) = match socket.try_recv_from(&mut buf[..]) {
Err(ref err)
if err.kind() == io::ErrorKind::WouldBlock
|| err.kind() == io::ErrorKind::Interrupted =>
Expand Down Expand Up @@ -628,82 +631,36 @@ fn read_dgram(

struct ServersRunner {
args: Args,
poll: Poll,
hosts: Vec<SocketAddr>,
server: Box<dyn HttpServer>,
timeout: Option<Timeout>,
sockets: Vec<UdpSocket>,
active_sockets: HashSet<usize>,
timer: Timer<usize>,
timeout: Option<Pin<Box<Sleep>>>,
sockets: Vec<(SocketAddr, UdpSocket)>,
}

impl ServersRunner {
pub fn new(args: Args) -> Result<Self, io::Error> {
let server = Self::create_server(&args);
let mut runner = Self {
args,
poll: Poll::new()?,
hosts: Vec::new(),
server,
timeout: None,
sockets: Vec::new(),
active_sockets: HashSet::new(),
timer: Builder::default()
.tick_duration(Duration::from_millis(1))
.build::<usize>(),
};
runner.init()?;
Ok(runner)
}

/// Init Poll for all hosts. Create sockets, and a map of the
/// socketaddrs to instances of the HttpServer handling that addr.
fn init(&mut self) -> Result<(), io::Error> {
self.hosts = self.args.listen_addresses();
if self.hosts.is_empty() {
let hosts = args.listen_addresses();
if hosts.is_empty() {
eprintln!("No valid hosts defined");
return Err(io::Error::new(io::ErrorKind::InvalidInput, "No hosts"));
}
let sockets = hosts
.into_iter()
.map(|host| {
let socket = std::net::UdpSocket::bind(host)?;
let local_addr = socket.local_addr()?;
println!("Server waiting for connection on: {local_addr:?}");
socket.set_nonblocking(true)?;
Ok((host, UdpSocket::from_std(socket)?))
})
.collect::<Result<_, io::Error>>()?;
let server = Self::create_server(&args);

for (i, host) in self.hosts.iter().enumerate() {
let socket = match UdpSocket::bind(host) {
Err(err) => {
eprintln!("Unable to bind UDP socket: {err}");
return Err(err);
}
Ok(s) => s,
};

let local_addr = match socket.local_addr() {
Err(err) => {
eprintln!("Socket local address not bound: {err}");
return Err(err);
}
Ok(s) => s,
};

print!("Server waiting for connection on: {local_addr:?}");
// On Windows, this is not supported.
#[cfg(not(target_os = "windows"))]
if !socket.only_v6().unwrap_or(true) {
print!(" as well as V4");
};
println!();
mxinden marked this conversation as resolved.
Show resolved Hide resolved

self.poll.register(
&socket,
Token(i),
Ready::readable() | Ready::writable(),
PollOpt::edge(),
)?;

self.sockets.push(socket);
}

self.poll
.register(&self.timer, TIMER_TOKEN, Ready::readable(), PollOpt::edge())?;

Ok(())
Ok(Self {
args,
server,
timeout: None,
sockets,
})
}

fn create_server(args: &Args) -> Box<dyn HttpServer> {
Expand Down Expand Up @@ -741,110 +698,88 @@ impl ServersRunner {

/// Tries to find a socket, but then just falls back to sending from the first.
fn find_socket(&mut self, addr: SocketAddr) -> &mut UdpSocket {
let (first, rest) = self.sockets.split_first_mut().unwrap();
let ((_host, first_socket), rest) = self.sockets.split_first_mut().unwrap();
rest.iter_mut()
.find(|s| {
s.local_addr()
.map(|(_host, socket)| socket)
.find(|socket| {
socket
.local_addr()
.ok()
.map_or(false, |socket_addr| socket_addr == addr)
})
.unwrap_or(first)
.unwrap_or(first_socket)
}

fn process(&mut self, inx: usize, dgram: Option<&Datagram>) -> bool {
match self.server.process(dgram, self.args.now()) {
Output::Datagram(dgram) => {
let socket = self.find_socket(dgram.source());
emit_packet(socket, dgram);
true
}
Output::Callback(new_timeout) => {
if let Some(to) = &self.timeout {
self.timer.cancel_timeout(to);
async fn process(&mut self, mut dgram: Option<&Datagram>) {
loop {
match self.server.process(dgram.take(), self.args.now()) {
mxinden marked this conversation as resolved.
Show resolved Hide resolved
Output::Datagram(dgram) => {
let socket = self.find_socket(dgram.source());
emit_packet(socket, dgram).await;
}
Output::Callback(new_timeout) => {
qinfo!("Setting timeout of {:?}", new_timeout);
self.timeout = Some(Box::pin(tokio::time::sleep(new_timeout)));
break;
}
Output::None => {
qdebug!("Output::None");
break;
}

qinfo!("Setting timeout of {:?} for socket {}", new_timeout, inx);
self.timeout = Some(self.timer.set_timeout(new_timeout, inx));
false
}
Output::None => {
qdebug!("Output::None");
false
}
}
}

fn process_datagrams_and_events(
&mut self,
inx: usize,
read_socket: bool,
) -> Result<(), io::Error> {
if self.sockets.get_mut(inx).is_some() {
if read_socket {
loop {
let socket = self.sockets.get_mut(inx).unwrap();
let dgram = read_dgram(socket, &self.hosts[inx])?;
// Wait for any of the sockets to be readable or the timeout to fire.
async fn ready(&mut self) -> Result<Ready, io::Error> {
let sockets_ready = select_all(
self.sockets
.iter()
.map(|(_host, socket)| Box::pin(socket.readable())),
)
.map(|(res, inx, _)| match res {
Ok(()) => Ok(Ready::Socket(inx)),
Err(e) => Err(e),
});
let timeout_ready = self
.timeout
.as_mut()
.map(Either::Left)
.unwrap_or(Either::Right(futures::future::pending()))
.map(|()| Ok(Ready::Timeout));
select(sockets_ready, timeout_ready).await.factor_first().0
}

async fn run(&mut self) -> Result<(), io::Error> {
loop {
match self.ready().await? {
Ready::Socket(inx) => loop {
let (host, socket) = self.sockets.get_mut(inx).unwrap();
let dgram = read_dgram(socket, host)?;
if dgram.is_none() {
break;
}
_ = self.process(inx, dgram.as_ref());
self.process(dgram.as_ref()).await;
},
Ready::Timeout => {
self.timeout = None;
self.process(None).await;
}
} else {
_ = self.process(inx, None);
}
self.server.process_events(&self.args, self.args.now());
if self.process(inx, None) {
self.active_sockets.insert(inx);
}
}
Ok(())
}

fn process_active_conns(&mut self) -> Result<(), io::Error> {
let curr_active = mem::take(&mut self.active_sockets);
for inx in curr_active {
self.process_datagrams_and_events(inx, false)?;
}
Ok(())
}

fn process_timeout(&mut self) -> Result<(), io::Error> {
while let Some(inx) = self.timer.poll() {
qinfo!("Timer expired for {:?}", inx);
self.process_datagrams_and_events(inx, false)?;
self.server.process_events(&self.args, self.args.now());
self.process(None).await;
}
Ok(())
}
}

pub fn run(&mut self) -> Result<(), io::Error> {
let mut events = Events::with_capacity(1024);
loop {
// If there are active servers do not block in poll.
self.poll.poll(
&mut events,
if self.active_sockets.is_empty() {
None
} else {
Some(Duration::from_millis(0))
},
)?;

for event in &events {
if event.token() == TIMER_TOKEN {
self.process_timeout()?;
} else {
if !event.readiness().is_readable() {
continue;
}
self.process_datagrams_and_events(event.token().0, true)?;
}
}
self.process_active_conns()?;
}
}
enum Ready {
Socket(usize),
Timeout,
}

fn main() -> Result<(), io::Error> {
#[tokio::main]
async fn main() -> Result<(), io::Error> {
const HQ_INTEROP: &str = "hq-interop";

let mut args = Args::from_args();
Expand Down Expand Up @@ -896,5 +831,5 @@ fn main() -> Result<(), io::Error> {
}

let mut servers_runner = ServersRunner::new(args)?;
servers_runner.run()
servers_runner.run().await
}
Loading