Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Send/receive ECN bits on linux #126

Merged
merged 1 commit into from
Dec 22, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions quinn/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,8 @@ tokio-timer = "0.2.1"
untrusted = "0.6.2"
webpki = "0.18"
webpki-roots = "0.15"
libc = { git = "https://github.com/Ralith/libc/", branch = "ip-tos" }
mio = "0.6"

[dev-dependencies]
slog-term = "2"
Expand Down
28 changes: 16 additions & 12 deletions quinn/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,9 @@ extern crate failure;
#[macro_use]
extern crate slog;

mod platform;
mod udp;

use std::borrow::Cow;
use std::cell::RefCell;
use std::collections::{hash_map, VecDeque};
Expand All @@ -74,11 +77,11 @@ use rustls::{Certificate, KeyLogFile, PrivateKey, ProtocolVersion, TLSError};
use slog::Logger;
use tokio_io::{AsyncRead, AsyncWrite};
use tokio_timer::Delay;
use tokio_udp::UdpSocket;

pub use crate::quinn::{
Config, ConnectError, ConnectionError, ConnectionId, ServerConfig, ALPN_QUIC_HTTP,
};
use crate::udp::UdpSocket;

/// Errors that can occur during the construction of an `Endpoint`.
#[derive(Debug, Fail)]
Expand Down Expand Up @@ -113,7 +116,7 @@ struct EndpointInner {
log: Logger,
socket: UdpSocket,
inner: quinn::Endpoint,
outgoing: VecDeque<(SocketAddrV6, Box<[u8]>)>,
outgoing: VecDeque<(SocketAddrV6, Option<quinn::EcnCodepoint>, Box<[u8]>)>,
epoch: Instant,
pending: FnvHashMap<ConnectionHandle, Pending>,
// TODO: Replace this with something custom that avoids using oneshots to cancel
Expand Down Expand Up @@ -646,12 +649,11 @@ impl Future for Driver {
let now = micros_from(endpoint.epoch.elapsed());
loop {
loop {
// TODO: Read ECN codepoint
match endpoint.socket.poll_recv_from(&mut buf) {
Ok(Async::Ready((n, addr))) => {
match endpoint.socket.poll_recv(&mut buf) {
Ok(Async::Ready((n, addr, ecn))) => {
endpoint
.inner
.handle(now, normalize(addr), None, (&buf[0..n]).into());
.handle(now, normalize(addr), ecn, (&buf[0..n]).into());
}
Ok(Async::NotReady) => {
break;
Expand Down Expand Up @@ -749,8 +751,11 @@ impl Future for Driver {
let mut blocked = false;
while !endpoint.outgoing.is_empty() {
{
let front = endpoint.outgoing.front().unwrap();
match endpoint.socket.poll_send_to(&front.1, &front.0.into()) {
let (destination, ecn, packet) = endpoint.outgoing.front().unwrap();
match endpoint
.socket
.poll_send(&(*destination).into(), *ecn, packet)
{
Ok(Async::Ready(_)) => {}
Ok(Async::NotReady) => {
blocked = true;
Expand All @@ -773,11 +778,10 @@ impl Future for Driver {
Transmit {
destination,
packet,
..
ecn,
} => {
if !blocked {
// TODO: Set ECN codepoint
match endpoint.socket.poll_send_to(&packet, &destination.into()) {
match endpoint.socket.poll_send(&destination.into(), ecn, &packet) {
Ok(Async::Ready(_)) => {}
Ok(Async::NotReady) => {
blocked = true;
Expand All @@ -791,7 +795,7 @@ impl Future for Driver {
}
}
if blocked {
endpoint.outgoing.push_front((destination, packet));
endpoint.outgoing.push_front((destination, ecn, packet));
}
}
TimerUpdate {
Expand Down
102 changes: 102 additions & 0 deletions quinn/src/platform/cmsg.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
use std::{mem, ptr};

// Generates the `Cmsg` enum and its size/encode/decode helpers from a table of the form
// `LEVEL { NAME: PayloadType; ... }`. `LEVEL` and `NAME` must be constants in the `libc`
// crate; they are used both as `libc::LEVEL` / `libc::NAME` and (for `NAME`) as the enum
// variant identifier.
macro_rules! cmsgs {
{$($level:ident { $($name:ident : $ty:ty;)* })*} => {
// One variant per declared message, named after its `libc` type constant (hence the
// non-camel-case allowance) and carrying the payload value.
#[allow(non_camel_case_types)]
#[derive(Debug, Copy, Clone)]
pub enum Cmsg {
$($($name($ty),)*)*
}

impl Cmsg {
// Number of control-buffer bytes this message needs, as computed by `libc::CMSG_SPACE`
// for the size of its payload type.
pub fn space(&self) -> usize {
let x = match *self {
$($(Cmsg::$name(_) => unsafe { libc::CMSG_SPACE(mem::size_of::<$ty>() as _)},)*)*
};
x as usize
}

// Fills in `cmsg`'s level, type and length fields and writes the payload to
// `CMSG_DATA(cmsg)`. The payload lives in the memory following the header, so the caller
// must guarantee that `cmsg` sits inside a buffer with room for it.
unsafe fn encode(&self, cmsg: &mut libc::cmsghdr) {
match *self {
$($(Cmsg::$name(x) => {
cmsg.cmsg_level = libc::$level as _;
cmsg.cmsg_type = libc::$name as _;
cmsg.cmsg_len = libc::CMSG_LEN(mem::size_of::<$ty>() as _) as _;
ptr::write::<$ty>(libc::CMSG_DATA(cmsg) as *mut $ty, x);
})*)*
}
}

// Reads the payload that follows `cmsg` if its (level, type) pair is one of the declared
// messages; returns `None` for anything else. `cmsg_len` is not inspected, so the caller
// must guarantee that a full payload is present behind the header.
unsafe fn decode(cmsg: &libc::cmsghdr) -> Option<Self> {
Some(match cmsg.cmsg_level {
$(libc::$level => match cmsg.cmsg_type {
$(libc::$name => Cmsg::$name(ptr::read::<$ty>(libc::CMSG_DATA(cmsg) as *const $ty)),)*
_ => { return None; }
},)*
_ => { return None; }
})
}
}
}
}

// The control messages this crate sends and receives: the IPv4 TOS byte (a `u8` payload)
// and the IPv6 traffic class (a `c_int` payload).
cmsgs! {
IPPROTO_IP {
IP_TOS: u8;
}
IPPROTO_IPV6 {
IPV6_TCLASS: libc::c_int;
}
}

pub fn encode(hdr: &mut libc::msghdr, buf: &mut [u8], msgs: &[Cmsg]) {
assert!(buf.len() >= msgs.iter().map(|msg| msg.space()).sum());
hdr.msg_control = buf.as_mut_ptr() as _;
hdr.msg_controllen = buf.len() as _;

let mut len = 0;
let mut cursor = unsafe { libc::CMSG_FIRSTHDR(hdr) };
for msg in msgs {
unsafe {
msg.encode(&mut *cursor);
}
len += msg.space();
cursor = unsafe { libc::CMSG_NXTHDR(hdr, cursor) };
}
debug_assert!(len as usize <= buf.len());
hdr.msg_controllen = len;
}

/// Iterator over the control messages attached to a `libc::msghdr`, yielding only those
/// that `Cmsg` knows how to decode.
pub struct Iter<'a> {
// Message header whose control buffer is being walked.
hdr: &'a libc::msghdr,
// Next control-message header to examine; null once the buffer is exhausted.
cmsg: *const libc::cmsghdr,
}

impl<'a> Iter<'a> {
/// # Safety
///
/// `hdr.msg_control` must point to mutable memory containing at least `hdr.msg_controllen`
/// bytes, which lives at least as long as `'a`.
pub unsafe fn new(hdr: &'a libc::msghdr) -> Self {
Self {
hdr,
cmsg: libc::CMSG_FIRSTHDR(hdr),
}
}
}

impl<'a> Iterator for Iter<'a> {
    type Item = Cmsg;

    /// Returns the next control message that decodes to a `Cmsg`, silently skipping any
    /// message with an unrecognized level or type.
    fn next(&mut self) -> Option<Cmsg> {
        while !self.cmsg.is_null() {
            let header = self.cmsg;
            // Advance before decoding so the cursor has moved on whether or not the
            // current header is recognized.
            self.cmsg = unsafe { libc::CMSG_NXTHDR(self.hdr, header) };
            let decoded = unsafe { Cmsg::decode(&*header) };
            if decoded.is_some() {
                return decoded;
            }
        }
        None
    }
}
24 changes: 24 additions & 0 deletions quinn/src/platform/fallback.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
use std::{io, net::SocketAddr};

use mio::net::UdpSocket;

use quinn_proto::EcnCodepoint;

// ECN-unaware implementation: datagrams are sent and received with the plain socket calls,
// outgoing codepoints are dropped and incoming ones are never reported.
impl super::UdpExt for UdpSocket {
    /// Nothing needs configuring on this platform.
    fn init_ext(&self) -> io::Result<()> {
        Ok(())
    }

    /// Sends `msg` to `remote`, ignoring the requested ECN codepoint.
    fn send_ext(
        &self,
        remote: &SocketAddr,
        _ecn: Option<EcnCodepoint>,
        msg: &[u8],
    ) -> io::Result<usize> {
        self.send_to(msg, remote)
    }

    /// Receives one datagram into `buf`; the ECN slot of the result is always `None`.
    fn recv_ext(&self, buf: &mut [u8]) -> io::Result<(usize, SocketAddr, Option<EcnCodepoint>)> {
        let (len, source) = self.recv_from(buf)?;
        Ok((len, source, None))
    }
}
143 changes: 143 additions & 0 deletions quinn/src/platform/linux.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,143 @@
use std::os::unix::io::AsRawFd;
use std::{
io, mem,
net::{SocketAddr, SocketAddrV4, SocketAddrV6},
ptr,
};

use mio::net::UdpSocket;

use quinn_proto::EcnCodepoint;

use super::cmsg::{self, Cmsg};

impl super::UdpExt for UdpSocket {
    /// Asks the kernel to attach the IPv4 TOS byte and/or IPv6 traffic class to received
    /// datagrams so that `recv_ext` can report ECN codepoints.
    ///
    /// Which options are set depends on the socket's local address family: `IP_RECVTOS` for
    /// IPv4 sockets and for IPv6 sockets that are not IPv6-only, `IPV6_RECVTCLASS` for IPv6
    /// sockets. Setting or querying IPv6 options on an IPv4 socket fails, so those are
    /// skipped there.
    ///
    /// # Errors
    ///
    /// Returns the OS error if the local address cannot be read or an option cannot be set.
    ///
    /// # Panics
    ///
    /// Panics if `CMSG_LEN` is too small for the control messages used by this module.
    fn init_ext(&self) -> io::Result<()> {
        // The fixed-size control buffers must be able to hold either message.
        assert!(
            CMSG_LEN >= std::cmp::max(Cmsg::IP_TOS(0).space(), Cmsg::IPV6_TCLASS(0).space()),
            "CMSG_LEN is too small for the control messages in use"
        );

        let fd = self.as_raw_fd();
        // Sets an integer-valued boolean socket option to 1.
        let enable = |level: libc::c_int, option: libc::c_int| -> io::Result<()> {
            let on: libc::c_int = 1;
            // SAFETY: `on` outlives the call and the length passed matches its type.
            let rc = unsafe {
                libc::setsockopt(
                    fd,
                    level,
                    option,
                    &on as *const libc::c_int as *const libc::c_void,
                    mem::size_of::<libc::c_int>() as libc::socklen_t,
                )
            };
            if rc == -1 {
                Err(io::Error::last_os_error())
            } else {
                Ok(())
            }
        };

        let local = self.local_addr()?;
        // `only_v6` is consulted for IPv6 sockets only, thanks to short-circuiting.
        if local.is_ipv4() || !self.only_v6()? {
            enable(libc::IPPROTO_IP, libc::IP_RECVTOS)?;
        }
        if local.is_ipv6() {
            enable(libc::IPPROTO_IPV6, libc::IPV6_RECVTCLASS)?;
        }
        Ok(())
    }

    /// Sends `msg` to `remote` with the given ECN codepoint attached as an `IP_TOS` (IPv4
    /// destination) or `IPV6_TCLASS` (IPv6 destination) control message. `None` is sent as
    /// a value of zero.
    ///
    /// Returns the byte count reported by `sendmsg`, or the OS error if it fails.
    fn send_ext(
        &self,
        remote: &SocketAddr,
        ecn: Option<EcnCodepoint>,
        msg: &[u8],
    ) -> io::Result<usize> {
        let (mut name, namelen) = to_sockaddr(remote);
        let ecn = ecn.map_or(0, |x| x as u8);
        let mut iov = libc::iovec {
            // `sendmsg` only reads the payload; the mutable pointer is an API artifact.
            iov_base: msg.as_ptr() as *mut libc::c_void,
            iov_len: msg.len(),
        };
        let mut hdr = libc::msghdr {
            msg_name: &mut name as *mut libc::sockaddr_storage as *mut libc::c_void,
            msg_namelen: namelen as _,
            msg_iov: &mut iov,
            msg_iovlen: 1,
            msg_control: ptr::null_mut(),
            msg_controllen: 0,
            msg_flags: 0,
        };
        let cmsg = if remote.is_ipv4() {
            Cmsg::IP_TOS(ecn)
        } else {
            Cmsg::IPV6_TCLASS(libc::c_int::from(ecn))
        };
        let mut ctrl = ControlBuffer([0; CMSG_LEN]);
        cmsg::encode(&mut hdr, &mut ctrl.0, &[cmsg]);
        // SAFETY: every pointer in `hdr` refers to a local that is alive for the call.
        let n = unsafe { libc::sendmsg(self.as_raw_fd(), &hdr, 0) };
        if n == -1 {
            return Err(io::Error::last_os_error());
        }
        Ok(n as usize)
    }

    /// Receives one datagram into `buf`.
    ///
    /// Returns the byte count reported by `recvmsg`, the sender's address, and the ECN
    /// codepoint decoded from an `IP_TOS` or `IPV6_TCLASS` control message (`None` if no
    /// such message arrived or `EcnCodepoint::from_bits` rejected the value).
    ///
    /// # Panics
    ///
    /// Panics if the kernel reports a source address that is neither IPv4 nor IPv6.
    fn recv_ext(&self, buf: &mut [u8]) -> io::Result<(usize, SocketAddr, Option<EcnCodepoint>)> {
        // SAFETY: all-zero bytes are a valid `sockaddr_storage`.
        let mut name: libc::sockaddr_storage = unsafe { mem::zeroed() };
        let mut iov = libc::iovec {
            // The kernel writes into `buf`, so the pointer must come from the mutable borrow.
            iov_base: buf.as_mut_ptr() as *mut libc::c_void,
            iov_len: buf.len(),
        };
        let mut ctrl = ControlBuffer([0; CMSG_LEN]);
        let mut hdr = libc::msghdr {
            msg_name: &mut name as *mut libc::sockaddr_storage as *mut libc::c_void,
            msg_namelen: mem::size_of::<libc::sockaddr_storage>() as _,
            msg_iov: &mut iov,
            msg_iovlen: 1,
            msg_control: ctrl.0.as_mut_ptr() as *mut libc::c_void,
            msg_controllen: CMSG_LEN as _,
            msg_flags: 0,
        };
        // SAFETY: every pointer in `hdr` refers to a local (or `buf`) alive for the call.
        let n = unsafe { libc::recvmsg(self.as_raw_fd(), &mut hdr, 0) };
        if n == -1 {
            return Err(io::Error::last_os_error());
        }
        let mut ecn = None;
        // SAFETY: `hdr.msg_control` points at `ctrl`, which outlives the iterator, and the
        // kernel has set `msg_controllen` to the number of bytes it filled in.
        for cmsg in unsafe { cmsg::Iter::new(&hdr) } {
            match cmsg {
                Cmsg::IP_TOS(bits) => {
                    ecn = EcnCodepoint::from_bits(bits);
                }
                Cmsg::IPV6_TCLASS(bits) => {
                    ecn = EcnCodepoint::from_bits(bits as u8);
                }
            }
        }
        Ok((n as usize, from_sockaddr(&name), ecn))
    }
}

/// Converts `addr` into its C representation, returning the storage together with the
/// number of bytes that are meaningful for the address family.
///
/// The fields are copied one by one rather than reinterpreting the `std::net` types, whose
/// memory layout is not guaranteed to match the C structures.
fn to_sockaddr(addr: &SocketAddr) -> (libc::sockaddr_storage, usize) {
    // SAFETY: all-zero bytes are a valid `sockaddr_storage`.
    let mut storage: libc::sockaddr_storage = unsafe { mem::zeroed() };
    let len = match *addr {
        SocketAddr::V4(ref v4) => {
            // SAFETY: `sockaddr_storage` is large enough and suitably aligned for any
            // socket address type, including `sockaddr_in`.
            let sin = unsafe {
                &mut *(&mut storage as *mut libc::sockaddr_storage as *mut libc::sockaddr_in)
            };
            sin.sin_family = libc::AF_INET as libc::sa_family_t;
            sin.sin_port = v4.port().to_be();
            sin.sin_addr.s_addr = u32::from(*v4.ip()).to_be();
            mem::size_of::<libc::sockaddr_in>()
        }
        SocketAddr::V6(ref v6) => {
            // SAFETY: as above, for `sockaddr_in6`.
            let sin6 = unsafe {
                &mut *(&mut storage as *mut libc::sockaddr_storage as *mut libc::sockaddr_in6)
            };
            sin6.sin6_family = libc::AF_INET6 as libc::sa_family_t;
            sin6.sin6_port = v6.port().to_be();
            sin6.sin6_flowinfo = v6.flowinfo();
            sin6.sin6_addr.s6_addr = v6.ip().octets();
            sin6.sin6_scope_id = v6.scope_id();
            mem::size_of::<libc::sockaddr_in6>()
        }
    };
    (storage, len)
}

/// Converts a kernel-filled socket address back into a `SocketAddr`.
///
/// # Panics
///
/// Panics if the address family is neither `AF_INET` nor `AF_INET6`.
fn from_sockaddr(storage: &libc::sockaddr_storage) -> SocketAddr {
    match storage.ss_family as libc::c_int {
        libc::AF_INET => {
            // SAFETY: the family tag says the storage holds a `sockaddr_in`.
            let sin = unsafe {
                &*(storage as *const libc::sockaddr_storage as *const libc::sockaddr_in)
            };
            SocketAddr::V4(SocketAddrV4::new(
                std::net::Ipv4Addr::from(u32::from_be(sin.sin_addr.s_addr)),
                u16::from_be(sin.sin_port),
            ))
        }
        libc::AF_INET6 => {
            // SAFETY: the family tag says the storage holds a `sockaddr_in6`.
            let sin6 = unsafe {
                &*(storage as *const libc::sockaddr_storage as *const libc::sockaddr_in6)
            };
            SocketAddr::V6(SocketAddrV6::new(
                std::net::Ipv6Addr::from(sin6.sin6_addr.s6_addr),
                u16::from_be(sin6.sin6_port),
                sin6.sin6_flowinfo,
                sin6.sin6_scope_id,
            ))
        }
        _ => unreachable!(),
    }
}

/// Size in bytes of the control buffers used for sending and receiving. `init_ext` checks
/// that it is large enough for either control message.
const CMSG_LEN: usize = 24;

/// Control-message buffer with an alignment suitable for `libc::cmsghdr`, which a plain
/// byte array does not guarantee.
#[repr(align(8))]
struct ControlBuffer([u8; CMSG_LEN]);
25 changes: 25 additions & 0 deletions quinn/src/platform/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
//! Uniform interface to send/recv UDP packets with ECN information.
use quinn_proto::EcnCodepoint;
use std::{io, net::SocketAddr};

// The Linux code should work for most unixes, but as of this writing nobody's ported the
// CMSG_... macros to the libc crate for any of the BSDs.
#[cfg(target_os = "linux")]
mod cmsg;
#[cfg(target_os = "linux")]
mod linux;

// No ECN support
#[cfg(not(target_os = "linux"))]
mod fallback;

/// UDP socket operations that carry an ECN codepoint alongside each datagram.
///
/// Implemented per platform: the Linux implementation uses control messages, while the
/// fallback implementation ignores ECN entirely.
pub trait UdpExt {
/// Performs the one-time socket configuration the other methods rely on. Must succeed
/// before ECN information can be expected from `recv_ext`; a no-op on platforms without
/// ECN support.
fn init_ext(&self) -> io::Result<()>;
/// Sends `msg` to `remote`, marking it with `ecn` where the platform supports it.
///
/// Returns the byte count reported by the underlying send call.
fn send_ext(
&self,
remote: &SocketAddr,
ecn: Option<EcnCodepoint>,
msg: &[u8],
) -> io::Result<usize>;
/// Receives one datagram into `buf`.
///
/// Returns the byte count reported by the underlying receive call, the sender's address,
/// and the packet's ECN codepoint if the platform reported one (always `None` in the
/// fallback implementation).
fn recv_ext(&self, buf: &mut [u8]) -> io::Result<(usize, SocketAddr, Option<EcnCodepoint>)>;
}
Loading