From 986605b4428278bf411267e4386d8b5749b73043 Mon Sep 17 00:00:00 2001 From: Benjamin Saunders Date: Sat, 15 Dec 2018 00:49:35 -0800 Subject: [PATCH] Send/receive ECN bits on linux --- quinn/Cargo.toml | 2 + quinn/src/lib.rs | 28 ++++--- quinn/src/platform/cmsg.rs | 102 +++++++++++++++++++++++ quinn/src/platform/fallback.rs | 24 ++++++ quinn/src/platform/linux.rs | 143 +++++++++++++++++++++++++++++++++ quinn/src/platform/mod.rs | 25 ++++++ quinn/src/udp.rs | 60 ++++++++++++++ 7 files changed, 372 insertions(+), 12 deletions(-) create mode 100644 quinn/src/platform/cmsg.rs create mode 100644 quinn/src/platform/fallback.rs create mode 100644 quinn/src/platform/linux.rs create mode 100644 quinn/src/platform/mod.rs create mode 100644 quinn/src/udp.rs diff --git a/quinn/Cargo.toml b/quinn/Cargo.toml index f087a2e6d..b0089022a 100644 --- a/quinn/Cargo.toml +++ b/quinn/Cargo.toml @@ -38,6 +38,8 @@ tokio-timer = "0.2.1" untrusted = "0.6.2" webpki = "0.18" webpki-roots = "0.15" +libc = { git = "https://github.com/Ralith/libc/", branch = "ip-tos" } +mio = "0.6" [dev-dependencies] slog-term = "2" diff --git a/quinn/src/lib.rs b/quinn/src/lib.rs index ec49ad887..93d0c437b 100644 --- a/quinn/src/lib.rs +++ b/quinn/src/lib.rs @@ -52,6 +52,9 @@ extern crate failure; #[macro_use] extern crate slog; +mod platform; +mod udp; + use std::borrow::Cow; use std::cell::RefCell; use std::collections::{hash_map, VecDeque}; @@ -74,11 +77,11 @@ use rustls::{Certificate, KeyLogFile, PrivateKey, ProtocolVersion, TLSError}; use slog::Logger; use tokio_io::{AsyncRead, AsyncWrite}; use tokio_timer::Delay; -use tokio_udp::UdpSocket; pub use crate::quinn::{ Config, ConnectError, ConnectionError, ConnectionId, ServerConfig, ALPN_QUIC_HTTP, }; +use crate::udp::UdpSocket; /// Errors that can occur during the construction of an `Endpoint`. #[derive(Debug, Fail)] @@ -113,7 +116,7 @@ struct EndpointInner { log: Logger, socket: UdpSocket, inner: quinn::Endpoint, - outgoing: VecDeque<(SocketAddrV6, Box<[u8]>)>, + outgoing: VecDeque<(SocketAddrV6, Option, Box<[u8]>)>, epoch: Instant, pending: FnvHashMap, // TODO: Replace this with something custom that avoids using oneshots to cancel @@ -646,12 +649,11 @@ impl Future for Driver { let now = micros_from(endpoint.epoch.elapsed()); loop { loop { - // TODO: Read ECN codepoint - match endpoint.socket.poll_recv_from(&mut buf) { - Ok(Async::Ready((n, addr))) => { + match endpoint.socket.poll_recv(&mut buf) { + Ok(Async::Ready((n, addr, ecn))) => { endpoint .inner - .handle(now, normalize(addr), None, (&buf[0..n]).into()); + .handle(now, normalize(addr), ecn, (&buf[0..n]).into()); } Ok(Async::NotReady) => { break; @@ -749,8 +751,11 @@ impl Future for Driver { let mut blocked = false; while !endpoint.outgoing.is_empty() { { - let front = endpoint.outgoing.front().unwrap(); - match endpoint.socket.poll_send_to(&front.1, &front.0.into()) { + let (destination, ecn, packet) = endpoint.outgoing.front().unwrap(); + match endpoint + .socket + .poll_send(&(*destination).into(), *ecn, packet) + { Ok(Async::Ready(_)) => {} Ok(Async::NotReady) => { blocked = true; @@ -773,11 +778,10 @@ impl Future for Driver { Transmit { destination, packet, - .. + ecn, } => { if !blocked { - // TODO: Set ECN codepoint - match endpoint.socket.poll_send_to(&packet, &destination.into()) { + match endpoint.socket.poll_send(&destination.into(), ecn, &packet) { Ok(Async::Ready(_)) => {} Ok(Async::NotReady) => { blocked = true; @@ -791,7 +795,7 @@ impl Future for Driver { } } if blocked { - endpoint.outgoing.push_front((destination, packet)); + endpoint.outgoing.push_front((destination, ecn, packet)); } } TimerUpdate { diff --git a/quinn/src/platform/cmsg.rs b/quinn/src/platform/cmsg.rs new file mode 100644 index 000000000..9c24bfbc4 --- /dev/null +++ b/quinn/src/platform/cmsg.rs @@ -0,0 +1,102 @@ +use std::{mem, ptr}; + +macro_rules! cmsgs { + {$($level:ident { $($name:ident : $ty:ty;)* })*} => { + #[allow(non_camel_case_types)] + #[derive(Debug, Copy, Clone)] + pub enum Cmsg { + $($($name($ty),)*)* + } + + impl Cmsg { + pub fn space(&self) -> usize { + let x = match *self { + $($(Cmsg::$name(_) => unsafe { libc::CMSG_SPACE(mem::size_of::<$ty>() as _)},)*)* + }; + x as usize + } + + unsafe fn encode(&self, cmsg: &mut libc::cmsghdr) { + match *self { + $($(Cmsg::$name(x) => { + cmsg.cmsg_level = libc::$level as _; + cmsg.cmsg_type = libc::$name as _; + cmsg.cmsg_len = libc::CMSG_LEN(mem::size_of::<$ty>() as _) as _; + ptr::write::<$ty>(libc::CMSG_DATA(cmsg) as *mut $ty, x); + })*)* + } + } + + unsafe fn decode(cmsg: &libc::cmsghdr) -> Option { + Some(match cmsg.cmsg_level { + $(libc::$level => match cmsg.cmsg_type { + $(libc::$name => Cmsg::$name(ptr::read::<$ty>(libc::CMSG_DATA(cmsg) as *const $ty)),)* + _ => { return None; } + },)* + _ => { return None; } + }) + } + } + } +} + +cmsgs! { + IPPROTO_IP { + IP_TOS: u8; + } + IPPROTO_IPV6 { + IPV6_TCLASS: libc::c_int; + } +} + +pub fn encode(hdr: &mut libc::msghdr, buf: &mut [u8], msgs: &[Cmsg]) { + assert!(buf.len() >= msgs.iter().map(|msg| msg.space()).sum()); + hdr.msg_control = buf.as_mut_ptr() as _; + hdr.msg_controllen = buf.len() as _; + + let mut len = 0; + let mut cursor = unsafe { libc::CMSG_FIRSTHDR(hdr) }; + for msg in msgs { + unsafe { + msg.encode(&mut *cursor); + } + len += msg.space(); + cursor = unsafe { libc::CMSG_NXTHDR(hdr, cursor) }; + } + debug_assert!(len as usize <= buf.len()); + hdr.msg_controllen = len; +} + +pub struct Iter<'a> { + hdr: &'a libc::msghdr, + cmsg: *const libc::cmsghdr, +} + +impl<'a> Iter<'a> { + /// # Safety + /// + /// `hdr.msg_control` must point to mutable memory containing at least `hdr.msg_controllen` + /// bytes, which lives at least as long as `'a`. + pub unsafe fn new(hdr: &'a libc::msghdr) -> Self { + Self { + hdr, + cmsg: libc::CMSG_FIRSTHDR(hdr), + } + } +} + +impl<'a> Iterator for Iter<'a> { + type Item = Cmsg; + fn next(&mut self) -> Option { + loop { + if self.cmsg.is_null() { + return None; + } + let current = self.cmsg; + self.cmsg = unsafe { libc::CMSG_NXTHDR(self.hdr, self.cmsg) }; + if let Some(x) = unsafe { Cmsg::decode(&*current) } { + return Some(x); + } + } + } +} diff --git a/quinn/src/platform/fallback.rs b/quinn/src/platform/fallback.rs new file mode 100644 index 000000000..d4632a3c9 --- /dev/null +++ b/quinn/src/platform/fallback.rs @@ -0,0 +1,24 @@ +use std::{io, net::SocketAddr}; + +use mio::net::UdpSocket; + +use quinn_proto::EcnCodepoint; + +impl super::UdpExt for UdpSocket { + fn init_ext(&self) -> io::Result<()> { + Ok(()) + } + + fn send_ext( + &self, + remote: &SocketAddr, + _: Option, + msg: &[u8], + ) -> io::Result { + self.send_to(msg, remote) + } + + fn recv_ext(&self, buf: &mut [u8]) -> io::Result<(usize, SocketAddr, Option)> { + self.recv_from(buf).map(|(x, y)| (x, y, None)) + } +} diff --git a/quinn/src/platform/linux.rs b/quinn/src/platform/linux.rs new file mode 100644 index 000000000..57a7e6d95 --- /dev/null +++ b/quinn/src/platform/linux.rs @@ -0,0 +1,143 @@ +use std::os::unix::io::AsRawFd; +use std::{ + io, mem, + net::{SocketAddr, SocketAddrV4, SocketAddrV6}, + ptr, +}; + +use mio::net::UdpSocket; + +use quinn_proto::EcnCodepoint; + +use super::cmsg::{self, Cmsg}; + +impl super::UdpExt for UdpSocket { + fn init_ext(&self) -> io::Result<()> { + // Safety + assert_eq!( + mem::size_of::(), + mem::size_of::() + ); + assert_eq!( + mem::size_of::(), + mem::size_of::() + ); + assert_eq!( + CMSG_LEN, + std::cmp::max(Cmsg::IP_TOS(0).space(), Cmsg::IPV6_TCLASS(0).space(),) as usize + ); + + if !self.only_v6()? { + let rc = unsafe { + libc::setsockopt( + self.as_raw_fd(), + libc::IPPROTO_IP, + libc::IP_RECVTOS, + &true as *const _ as _, + 1, + ) + }; + if rc == -1 { + return Err(io::Error::last_os_error()); + } + } + let on: libc::c_int = 1; + let rc = unsafe { + libc::setsockopt( + self.as_raw_fd(), + libc::IPPROTO_IPV6, + libc::IPV6_RECVTCLASS, + &on as *const _ as _, + mem::size_of::() as _, + ) + }; + if rc == -1 { + return Err(io::Error::last_os_error()); + } + Ok(()) + } + + fn send_ext( + &self, + remote: &SocketAddr, + ecn: Option, + msg: &[u8], + ) -> io::Result { + let (name, namelen) = match *remote { + SocketAddr::V4(ref addr) => { + (addr as *const _ as _, mem::size_of::()) + } + SocketAddr::V6(ref addr) => { + (addr as *const _ as _, mem::size_of::()) + } + }; + let ecn = ecn.map_or(0, |x| x as u8); + let mut iov = libc::iovec { + iov_base: msg.as_ptr() as *const _ as *mut _, + iov_len: msg.len(), + }; + let mut hdr = libc::msghdr { + msg_name: name, + msg_namelen: namelen as _, + msg_iov: &mut iov, + msg_iovlen: 1, + msg_control: ptr::null_mut(), + msg_controllen: 0, + msg_flags: 0, + }; + let cmsg; + if remote.is_ipv4() { + cmsg = Cmsg::IP_TOS(ecn as _); + } else { + cmsg = Cmsg::IPV6_TCLASS(ecn as _); + } + let mut ctrl: [u8; CMSG_LEN] = unsafe { mem::uninitialized() }; + cmsg::encode(&mut hdr, &mut ctrl, &[cmsg]); + let n = unsafe { libc::sendmsg(self.as_raw_fd(), &hdr, 0) }; + if n == -1 { + return Err(io::Error::last_os_error()); + } + Ok(n as usize) + } + + fn recv_ext(&self, buf: &mut [u8]) -> io::Result<(usize, SocketAddr, Option)> { + let mut name: libc::sockaddr_storage = unsafe { mem::uninitialized() }; + let mut iov = libc::iovec { + iov_base: buf.as_ptr() as *mut _, + iov_len: buf.len(), + }; + let mut ctrl: [u8; CMSG_LEN] = unsafe { mem::uninitialized() }; + let mut hdr = libc::msghdr { + msg_name: &mut name as *mut _ as _, + msg_namelen: mem::size_of::() as _, + msg_iov: &mut iov, + msg_iovlen: 1, + msg_control: ctrl.as_mut_ptr() as _, + msg_controllen: CMSG_LEN as _, + msg_flags: 0, + }; + let n = unsafe { libc::recvmsg(self.as_raw_fd(), &mut hdr, 0) }; + if n == -1 { + return Err(io::Error::last_os_error()); + } + let mut ecn = None; + for cmsg in unsafe { cmsg::Iter::new(&hdr) } { + match cmsg { + Cmsg::IP_TOS(bits) => { + ecn = EcnCodepoint::from_bits(bits); + } + Cmsg::IPV6_TCLASS(bits) => { + ecn = EcnCodepoint::from_bits(bits as u8); + } + } + } + let addr = match name.ss_family as libc::c_int { + libc::AF_INET => unsafe { SocketAddr::V4(ptr::read(&name as *const _ as _)) }, + libc::AF_INET6 => unsafe { SocketAddr::V6(ptr::read(&name as *const _ as _)) }, + _ => unreachable!(), + }; + Ok((n as usize, addr, ecn)) + } +} + +const CMSG_LEN: usize = 24; diff --git a/quinn/src/platform/mod.rs b/quinn/src/platform/mod.rs new file mode 100644 index 000000000..4ee674336 --- /dev/null +++ b/quinn/src/platform/mod.rs @@ -0,0 +1,25 @@ +//! Uniform interface to send/recv UDP packets with ECN information. +use quinn_proto::EcnCodepoint; +use std::{io, net::SocketAddr}; + +// The Linux code should work for most unixes, but as of this writing nobody's ported the +// CMSG_... macros to the libc crate for any of the BSDs. +#[cfg(target_os = "linux")] +mod cmsg; +#[cfg(target_os = "linux")] +mod linux; + +// No ECN support +#[cfg(not(target_os = "linux"))] +mod fallback; + +pub trait UdpExt { + fn init_ext(&self) -> io::Result<()>; + fn send_ext( + &self, + remote: &SocketAddr, + ecn: Option, + msg: &[u8], + ) -> io::Result; + fn recv_ext(&self, buf: &mut [u8]) -> io::Result<(usize, SocketAddr, Option)>; +} diff --git a/quinn/src/udp.rs b/quinn/src/udp.rs new file mode 100644 index 000000000..7d02ef2be --- /dev/null +++ b/quinn/src/udp.rs @@ -0,0 +1,60 @@ +use std::io; +use std::net::SocketAddr; + +use futures::{try_ready, Async, Poll}; +use mio; + +use tokio_reactor::{Handle, PollEvented}; + +use quinn_proto::EcnCodepoint; + +use crate::platform::UdpExt; + +/// Tokio-compatible UDP socket with some useful specializations. +/// +/// Unlike a standard tokio UDP socket, this allows ECN bits to be read and written on some +/// platforms. +pub struct UdpSocket { + io: PollEvented, +} + +impl UdpSocket { + pub fn from_std(socket: std::net::UdpSocket, handle: &Handle) -> io::Result { + let io = mio::net::UdpSocket::from_socket(socket)?; + io.init_ext()?; + let io = PollEvented::new_with_handle(io, handle)?; + Ok(UdpSocket { io }) + } + + pub fn poll_send( + &self, + remote: &SocketAddr, + ecn: Option, + msg: &[u8], + ) -> Poll { + try_ready!(self.io.poll_write_ready()); + match self.io.get_ref().send_ext(remote, ecn, msg) { + Ok(n) => Ok(Async::Ready(n)), + Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => { + self.io.clear_write_ready()?; + Ok(Async::NotReady) + } + Err(e) => Err(e), + } + } + + pub fn poll_recv( + &self, + buf: &mut [u8], + ) -> Poll<(usize, SocketAddr, Option), io::Error> { + try_ready!(self.io.poll_read_ready(mio::Ready::readable())); + match self.io.get_ref().recv_ext(buf) { + Ok(n) => Ok(n.into()), + Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => { + self.io.clear_read_ready(mio::Ready::readable())?; + Ok(Async::NotReady) + } + Err(e) => Err(e), + } + } +}