Skip to content

Commit

Permalink
transports/quic: add Endpoint::try_send
Browse files Browse the repository at this point in the history
  • Loading branch information
elenaf9 committed Sep 10, 2022
1 parent 689460f commit 41d39fb
Show file tree
Hide file tree
Showing 3 changed files with 71 additions and 50 deletions.
27 changes: 15 additions & 12 deletions transports/quic/src/connection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,9 @@ pub enum Error {
/// Endpoint has force-killed this connection because it was too busy.
#[error("Endpoint has force-killed our connection")]
ClosedChannel,
/// The background task driving the endpoint has crashed.
#[error("Background task crashed.")]
TaskCrashed,
/// Error in the inner state machine.
#[error("{0}")]
Quinn(#[from] quinn_proto::ConnectionError),
Expand Down Expand Up @@ -109,15 +112,15 @@ impl Connection {
/// Works for server connections only.
pub fn local_addr(&self) -> SocketAddr {
debug_assert_eq!(self.connection.side(), quinn_proto::Side::Server);
let endpoint_addr = self.endpoint.socket_addr;
let endpoint_addr = self.endpoint.socket_addr();
self.connection
.local_ip()
.map(|ip| SocketAddr::new(ip, endpoint_addr.port()))
.unwrap_or_else(|| {
// In a normal case scenario this should not happen, because
// we get want to get a local addr for a server connection only.
tracing::error!("trying to get quinn::local_ip for a client");
endpoint_addr
*endpoint_addr
})
}

Expand Down Expand Up @@ -214,17 +217,17 @@ impl Connection {
// However we don't deliver substream-related events to the user as long as
// `to_endpoint` is full. This should propagate the back-pressure of `to_endpoint`
// being full to the user.
if self.pending_to_endpoint.is_some() {
match self.endpoint.to_endpoint.poll_ready_unpin(cx) {
Poll::Ready(Ok(())) => {
let to_endpoint = self.pending_to_endpoint.take().expect("is some");
self.endpoint
.to_endpoint
.start_send(to_endpoint)
.expect("Channel is ready.");
if let Some(to_endpoint) = self.pending_to_endpoint.take() {
match self.endpoint.try_send(to_endpoint, cx) {
Ok(Ok(())) => {}
Ok(Err(to_endpoint)) => {
self.pending_to_endpoint = Some(to_endpoint);
return Poll::Pending;
}
Err(_) => {
tracing::error!("Background task crashed.");
return Poll::Ready(ConnectionEvent::ConnectionLost(Error::TaskCrashed));
}
Poll::Ready(Err(_)) => panic!("Background task crashed"),
Poll::Pending => return Poll::Pending,
}
}

Expand Down
28 changes: 24 additions & 4 deletions transports/quic/src/endpoint.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,10 @@
use crate::{connection::Connection, tls, transport};

use futures::{
channel::{mpsc, oneshot},
channel::{
mpsc::{self, SendError},
oneshot,
},
prelude::*,
};
use quinn_proto::{ClientConfig as QuinnClientConfig, ServerConfig as QuinnServerConfig};
Expand All @@ -40,7 +43,7 @@ use std::{
fmt,
net::{Ipv4Addr, Ipv6Addr, SocketAddr, UdpSocket},
sync::Arc,
task::{Poll, Waker},
task::{Context, Poll, Waker},
time::{Duration, Instant},
};

Expand Down Expand Up @@ -87,9 +90,9 @@ impl Config {
#[derive(Clone)]
pub struct Endpoint {
/// Channel to the background of the endpoint.
pub to_endpoint: mpsc::Sender<ToEndpoint>,
to_endpoint: mpsc::Sender<ToEndpoint>,

pub socket_addr: SocketAddr,
socket_addr: SocketAddr,
}

impl Endpoint {
Expand Down Expand Up @@ -142,6 +145,23 @@ impl Endpoint {

Ok(endpoint)
}

pub fn socket_addr(&self) -> &SocketAddr {
&self.socket_addr
}

pub fn try_send(
&mut self,
to_endpoint: ToEndpoint,
cx: &mut Context<'_>,
) -> Result<Result<(), ToEndpoint>, SendError> {
match self.to_endpoint.poll_ready_unpin(cx) {
Poll::Ready(Ok(())) => {}
Poll::Ready(Err(err)) => return Err(err),
Poll::Pending => return Ok(Err(to_endpoint)),
};
self.to_endpoint.start_send(to_endpoint).map(Ok)
}
}

/// Message sent to the endpoint background task.
Expand Down
66 changes: 32 additions & 34 deletions transports/quic/src/transport.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ use crate::endpoint::ToEndpoint;
use crate::Config;
use crate::{endpoint::Endpoint, muxer::QuicMuxer, upgrade::Upgrade};

use futures::channel::mpsc::SendError;
use futures::channel::oneshot;
use futures::ready;
use futures::stream::StreamExt;
Expand Down Expand Up @@ -140,7 +141,7 @@ impl Transport for QuicTransport {
.listeners
.iter_mut()
.filter(|l| {
let listen_addr = l.endpoint.socket_addr;
let listen_addr = l.endpoint.socket_addr();
listen_addr.is_ipv4() == socket_addr.is_ipv4()
&& listen_addr.ip().is_loopback() == socket_addr.ip().is_loopback()
})
Expand Down Expand Up @@ -177,7 +178,7 @@ impl Transport for QuicTransport {
Ok(async move {
let connection = rx
.await
.expect("background task has crashed")
.map_err(|_| Error::TaskCrashed)?
.map_err(Error::Reach)?;
let final_connec = Upgrade::from_connection(connection).await?;
Ok(final_connec)
Expand All @@ -201,10 +202,18 @@ impl Transport for QuicTransport {
cx: &mut Context<'_>,
) -> Poll<TransportEvent<Self::ListenerUpgrade, Self::Error>> {
if let Some(dialer) = self.ipv4_dialer.as_mut() {
dialer.drive_dials(cx)
if dialer.drive_dials(cx).is_err() {
// Background task of dialer crashed.
// Drop dialer and all pending dials so that the connection receiver is notified.
self.ipv4_dialer = None;
}
}
if let Some(dialer) = self.ipv6_dialer.as_mut() {
dialer.drive_dials(cx)
if dialer.drive_dials(cx).is_err() {
// Background task of dialer crashed.
// Drop dialer and all pending dials so that the connection receiver is notified.
self.ipv4_dialer = None;
}
}
match self.listeners.poll_next_unpin(cx) {
Poll::Ready(Some(ev)) => Poll::Ready(ev),
Expand All @@ -228,20 +237,18 @@ impl Dialer {
})
}

fn drive_dials(&mut self, cx: &mut Context<'_>) {
if !self.pending_dials.is_empty() {
match self.endpoint.to_endpoint.poll_ready_unpin(cx) {
Poll::Ready(Ok(())) => {
let to_endpoint = self.pending_dials.pop_front().expect("!is_empty");
self.endpoint
.to_endpoint
.start_send(to_endpoint)
.expect("Channel is ready.");
fn drive_dials(&mut self, cx: &mut Context<'_>) -> Result<(), SendError> {
if let Some(to_endpoint) = self.pending_dials.pop_front() {
match self.endpoint.try_send(to_endpoint, cx) {
Ok(Ok(())) => {}
Ok(Err(to_endpoint)) => self.pending_dials.push_front(to_endpoint),
Err(err) => {
tracing::error!("Background task of dialing endpoint crashed.");
return Err(err);
}
Poll::Ready(Err(_)) => panic!("Background task crashed."),
Poll::Pending => {}
}
}
Ok(())
}
}

Expand Down Expand Up @@ -282,7 +289,7 @@ impl Listener {
pending_event = None;
} else {
if_watcher = None;
let ma = socketaddr_to_multiaddr(&endpoint.socket_addr);
let ma = socketaddr_to_multiaddr(endpoint.socket_addr());
pending_event = Some(TransportEvent::NewAddress {
listener_id,
listen_addr: ma,
Expand Down Expand Up @@ -324,8 +331,8 @@ impl Listener {
match ready!(if_watcher.poll_if_event(cx)) {
Ok(IfEvent::Up(inet)) => {
let ip = inet.addr();
if self.endpoint.socket_addr.is_ipv4() == ip.is_ipv4() {
let socket_addr = SocketAddr::new(ip, self.endpoint.socket_addr.port());
if self.endpoint.socket_addr().is_ipv4() == ip.is_ipv4() {
let socket_addr = SocketAddr::new(ip, self.endpoint.socket_addr().port());
let ma = socketaddr_to_multiaddr(&socket_addr);
tracing::debug!("New listen address: {}", ma);
return Poll::Ready(TransportEvent::NewAddress {
Expand All @@ -336,8 +343,8 @@ impl Listener {
}
Ok(IfEvent::Down(inet)) => {
let ip = inet.addr();
if self.endpoint.socket_addr.is_ipv4() == ip.is_ipv4() {
let socket_addr = SocketAddr::new(ip, self.endpoint.socket_addr.port());
if self.endpoint.socket_addr().is_ipv4() == ip.is_ipv4() {
let socket_addr = SocketAddr::new(ip, self.endpoint.socket_addr().port());
let ma = socketaddr_to_multiaddr(&socket_addr);
tracing::debug!("Expired listen address: {}", ma);
return Poll::Ready(TransportEvent::AddressExpired {
Expand Down Expand Up @@ -371,23 +378,14 @@ impl Stream for Listener {
Poll::Ready(event) => return Poll::Ready(Some(event)),
Poll::Pending => {}
}
if !self.pending_dials.is_empty() {
match self.endpoint.to_endpoint.poll_ready_unpin(cx) {
Poll::Ready(Ok(_)) => {
let to_endpoint = self
.pending_dials
.pop_front()
.expect("Pending dials is not empty.");
self.endpoint
.to_endpoint
.start_send(to_endpoint)
.expect("Channel is ready");
}
Poll::Ready(Err(_)) => {
if let Some(to_endpoint) = self.pending_dials.pop_front() {
match self.endpoint.try_send(to_endpoint, cx) {
Ok(Ok(())) => {}
Ok(Err(to_endpoint)) => self.pending_dials.push_front(to_endpoint),
Err(_) => {
self.close(Err(Error::TaskCrashed));
continue;
}
Poll::Pending => {}
}
}
match self.new_connections_rx.poll_next_unpin(cx) {
Expand Down

0 comments on commit 41d39fb

Please sign in to comment.