Skip to content

Commit

Permalink
(todo refactor commits) address feedback
Browse files Browse the repository at this point in the history
  • Loading branch information
gretchenfrage committed Feb 15, 2024
1 parent 9ba8924 commit f674458
Show file tree
Hide file tree
Showing 9 changed files with 134 additions and 77 deletions.
12 changes: 10 additions & 2 deletions quinn-proto/src/connection/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2203,7 +2203,10 @@ impl Connection {
}
ConnectionError::VersionMismatch => State::Draining,
ConnectionError::LocallyClosed => {
unreachable!("LocallyClosed isn't generated by packet processing")
unreachable!("LocallyClosed isn't generated by packet processing");
}
ConnectionError::ConnectionLimitExceeded => {
unreachable!("ConnectionLimitExceeded isn't generated by packet processing");
}
};
}
Expand Down Expand Up @@ -3501,6 +3504,9 @@ pub enum ConnectionError {
/// The local application closed the connection
#[error("closed")]
LocallyClosed,
/// The connection could not be created without exceeding the endpoint's connection limit
#[error("connection limit exceeded")]
ConnectionLimitExceeded,
}

impl From<Close> for ConnectionError {
Expand All @@ -3520,7 +3526,9 @@ impl From<ConnectionError> for io::Error {
TimedOut => io::ErrorKind::TimedOut,
Reset => io::ErrorKind::ConnectionReset,
ApplicationClosed(_) | ConnectionClosed(_) => io::ErrorKind::ConnectionAborted,
TransportError(_) | VersionMismatch | LocallyClosed => io::ErrorKind::Other,
TransportError(_) | VersionMismatch | LocallyClosed | ConnectionLimitExceeded => {
io::ErrorKind::Other
}
};
Self::new(kind, x)
}
Expand Down
29 changes: 13 additions & 16 deletions quinn-proto/src/endpoint.rs
Original file line number Diff line number Diff line change
Expand Up @@ -506,15 +506,15 @@ impl Endpoint {
incoming: IncomingConnection,
now: Instant,
buf: &mut BytesMut,
) -> Result<(ConnectionHandle, Connection), Option<Transmit>> {
) -> Result<(ConnectionHandle, Connection), (ConnectionError, Option<Transmit>)> {
self.check_connection_limit(
incoming.version,
incoming.addresses,
&incoming.crypto,
&incoming.src_cid,
buf,
)
.map_err(Some)?;
.map_err(|response| (ConnectionError::ConnectionLimitExceeded, Some(response)))?;

let server_config = self.server_config.as_ref().unwrap().clone();

Expand Down Expand Up @@ -567,17 +567,18 @@ impl Endpoint {
Err(e) => {
debug!("handshake failed: {}", e);
self.handle_event(ch, EndpointEvent(EndpointEventInner::Drained));
match e {
ConnectionError::TransportError(e) => Err(Some(self.initial_close(
let response = match e {
ConnectionError::TransportError(ref e) => Some(self.initial_close(
incoming.version,
incoming.addresses,
&incoming.crypto,
&incoming.src_cid,
e,
e.clone(),
buf,
))),
_ => Err(None),
}
)),
_ => None,
};
Err((e, response))
}
}
}
Expand Down Expand Up @@ -1050,9 +1051,11 @@ pub enum ConnectError {
UnsupportedVersion,
}

/// Error for attempting to retry an [`IncomingConnection`] that can not be retried
/// Error for attempting to retry an [`IncomingConnection`] which already bears an address
/// validation token from a previous retry
#[derive(Debug, Error)]
pub struct RetryError(pub IncomingConnection);
#[error("retry() with validated IncomingConnection")]
pub struct RetryError(IncomingConnection);

impl RetryError {
/// Get the [`IncomingConnection`]
Expand All @@ -1061,12 +1064,6 @@ impl RetryError {
}
}

impl fmt::Display for RetryError {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
f.write_str("retry() with validated IncomingConnection")
}
}

/// Reset Tokens which are associated with peer socket addresses
///
/// The standard `HashMap` is used since both `SocketAddr` and `ResetToken` are
Expand Down
23 changes: 20 additions & 3 deletions quinn-proto/src/tests/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -166,7 +166,7 @@ fn draft_version_compat() {
fn stateless_retry() {
let _guard = subscribe();
let mut pair = Pair::default();
pair.server.use_retry = true;
pair.server.retry_policy = RetryPolicy::yes();
pair.connect();
}

Expand Down Expand Up @@ -455,7 +455,7 @@ fn high_latency_handshake() {
fn zero_rtt_happypath() {
let _guard = subscribe();
let mut pair = Pair::default();
pair.server.use_retry = true;
pair.server.retry_policy = RetryPolicy::yes();
let config = client_config();

// Establish normal connection
Expand Down Expand Up @@ -1980,7 +1980,7 @@ fn connect_too_low_mtu() {

pair.begin_connect(client_config());
pair.drive();
pair.server.assert_no_accept()
pair.server.assert_no_accept();
}

#[test]
Expand Down Expand Up @@ -2750,3 +2750,20 @@ fn reject_new_connections() {
pair.server.assert_no_accept();
assert!(pair.client.connections.get(&client_ch).unwrap().is_closed());
}

#[test]
fn reject_manually() {
let _guard = subscribe();
let mut pair = Pair::default();
pair.server.retry_policy = RetryPolicy(Box::new(|_| IncomingConnectionResponse::Reject));

// The server should now reject incoming connections.
let client_ch = pair.begin_connect(client_config());
pair.drive();
let e = pair.server.assert_accept_error();
assert!(
matches!(e, crate::ConnectionError::ConnectionClosed(_)),
"wrong error"
);
assert!(pair.client.connections.get(&client_ch).unwrap().is_closed());
}
79 changes: 61 additions & 18 deletions quinn-proto/src/tests/util.rs
Original file line number Diff line number Diff line change
Expand Up @@ -287,12 +287,39 @@ pub(super) struct TestEndpoint {
pub(super) outbound: VecDeque<(Transmit, Bytes)>,
delayed: VecDeque<(Transmit, Bytes)>,
pub(super) inbound: VecDeque<(Instant, Option<EcnCodepoint>, BytesMut)>,
accepted: Option<ConnectionHandle>,
accepted: Option<Result<ConnectionHandle, ConnectionError>>,
pub(super) connections: HashMap<ConnectionHandle, Connection>,
conn_events: HashMap<ConnectionHandle, VecDeque<ConnectionEvent>>,
pub(super) captured_packets: Vec<Vec<u8>>,
pub(super) capture_inbound_packets: bool,
pub(super) use_retry: bool,
pub(super) retry_policy: RetryPolicy,
}

pub(super) struct RetryPolicy(
pub(super) Box<dyn Fn(&IncomingConnection) -> IncomingConnectionResponse>,
);

impl RetryPolicy {
pub(super) fn no() -> Self {
Self(Box::new(|_| IncomingConnectionResponse::Accept))
}

pub(super) fn yes() -> Self {
Self(Box::new(|incoming| {
if incoming.remote_address_validated() {
IncomingConnectionResponse::Accept
} else {
IncomingConnectionResponse::Retry
}
}))
}
}

#[derive(Debug, Copy, Clone)]
pub(super) enum IncomingConnectionResponse {
Accept,
Reject,
Retry,
}

impl TestEndpoint {
Expand All @@ -319,7 +346,7 @@ impl TestEndpoint {
conn_events: HashMap::default(),
captured_packets: Vec::new(),
capture_inbound_packets: false,
use_retry: false,
retry_policy: RetryPolicy::no(),
}
}

Expand All @@ -343,10 +370,16 @@ impl TestEndpoint {
{
match event {
DatagramEvent::NewConnection(incoming) => {
if self.use_retry && !incoming.remote_address_validated() {
self.retry(incoming);
} else {
self.try_accept(incoming, now);
match (self.retry_policy.0)(&incoming) {
IncomingConnectionResponse::Accept => {
let _ = self.try_accept(incoming, now);
}
IncomingConnectionResponse::Reject => {
self.reject(incoming);
}
IncomingConnectionResponse::Retry => {
self.retry(incoming);
}
}
}
DatagramEvent::ConnectionEvent(ch, event) => {
Expand Down Expand Up @@ -427,23 +460,23 @@ impl TestEndpoint {
&mut self,
incoming: IncomingConnection,
now: Instant,
) -> Option<ConnectionHandle> {
) -> Result<ConnectionHandle, ConnectionError> {
let mut buf = BytesMut::new();
match self.endpoint.accept(incoming, now, &mut buf) {
Ok((ch, conn)) => {
self.endpoint
.accept(incoming, now, &mut buf)
.map(|(ch, conn)| {
self.connections.insert(ch, conn);
self.accepted = Some(ch);
Some(ch)
}
Err(transmit) => {
self.accepted = Some(Ok(ch));
ch
})
.map_err(|(e, transmit)| {
if let Some(transmit) = transmit {
let size = transmit.size;
self.outbound
.extend(split_transmit(transmit, buf.split_to(size).freeze()));
}
None
}
}
e
})
}

pub(super) fn retry(&mut self, incoming: IncomingConnection) {
Expand All @@ -463,7 +496,17 @@ impl TestEndpoint {
}

pub(super) fn assert_accept(&mut self) -> ConnectionHandle {
self.accepted.take().expect("server didn't connect")
self.accepted
.take()
.expect("server didn't try connecting")
.expect("server experienced error connecting")
}

pub(super) fn assert_accept_error(&mut self) -> ConnectionError {
self.accepted
.take()
.expect("server didn't try connecting")
.expect_err("server did unexpectedly connect without error")
}

pub(super) fn assert_no_accept(&self) {
Expand Down
5 changes: 3 additions & 2 deletions quinn/src/connection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@ use proto::congestion::Controller;

/// In-progress connection attempt future
#[derive(Debug)]
#[must_use = "futures/streams/sinks do nothing unless you `.await` or poll them"]
pub struct Connecting {
conn: Option<ConnectionRef>,
connected: oneshot::Receiver<bool>,
Expand Down Expand Up @@ -152,14 +151,16 @@ impl Connecting {
///
/// On all non-supported platforms the local IP address will not be available,
/// and the method will return `None`.
///
/// Will panic if called after `poll` has returned `Ready`.
pub fn local_ip(&self) -> Option<IpAddr> {
let conn = self.conn.as_ref().unwrap();
let inner = conn.state.lock("local_ip");

inner.inner.local_ip()
}

/// The peer's UDP address.
/// The peer's UDP address
///
/// Will panic if called after `poll` has returned `Ready`.
pub fn remote_address(&self) -> SocketAddr {
Expand Down
29 changes: 15 additions & 14 deletions quinn/src/endpoint.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,8 @@ use crate::runtime::{default_runtime, AsyncUdpSocket, Runtime};
use bytes::{Bytes, BytesMut};
use pin_project_lite::pin_project;
use proto::{
self as proto, ClientConfig, ConnectError, ConnectionHandle, DatagramEvent, ServerConfig,
self as proto, ClientConfig, ConnectError, ConnectionError, ConnectionHandle, DatagramEvent,
ServerConfig,
};
use rustc_hash::FxHashMap;
use tokio::sync::{futures::Notified, mpsc, Notify};
Expand Down Expand Up @@ -137,9 +138,11 @@ impl Endpoint {

/// Get the next incoming connection attempt from a client
///
/// Yields [`IncomingConnection`]s, which can be `await`ed to obtain the final
/// [`Connection`](crate::Connection) or used in more complex ways such as to perform retries,
/// or `None` if the endpoint is [`close`](Self::close)d.
/// Yields [`IncomingConnection`]s, or `None` if the endpoint is [`close`](Self::close)d.
/// [`IncomingConnection`] can be `await`ed to obtain the final
/// [`Connection`](crate::Connection), or used to eg. filter connection attempts or force
/// address validation, or converted into an intermediate `Connecting` future which can be
/// used to eg. send 0.5-RTT data.
pub fn accept(&self) -> Accept<'_> {
Accept {
endpoint: self,
Expand Down Expand Up @@ -770,24 +773,22 @@ impl EndpointInner {
&self,
incoming: proto::IncomingConnection,
mut response_buffer: BytesMut,
) -> Option<Connecting> {
) -> Result<Connecting, ConnectionError> {
let mut state = self.state.lock().unwrap();
match state
state
.inner
.accept(incoming, Instant::now(), &mut response_buffer)
{
Ok((handle, conn)) => {
.map(|(handle, conn)| {
let socket = state.socket.clone();
let runtime = state.runtime.clone();
Some(state.connections.insert(handle, conn, socket, runtime))
}
Err(response) => {
state.connections.insert(handle, conn, socket, runtime)
})
.map_err(|(e, response)| {
if let Some(transmit) = response {
state.transmit_state.respond(transmit, response_buffer);
}
None
}
}
e
})
}

pub(crate) fn reject(
Expand Down
Loading

0 comments on commit f674458

Please sign in to comment.