Skip to content

Commit

Permalink
feat(stream): include connection id information as part of the StreamId
Browse files Browse the repository at this point in the history
- Make use of the connection id string to be part of the `StreamId` string. In some
cases the user may have several connections to the same peer, so having the connection
id as part of the stream id string can help the user to uniquely identify them.
  • Loading branch information
bochaco authored and joshuef committed Dec 6, 2022
1 parent 5465aeb commit 8acd490
Showing 1 changed file with 50 additions and 31 deletions.
81 changes: 50 additions & 31 deletions src/connection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -50,15 +50,18 @@ impl Connection {
let (alive_tx, alive_rx) = watch::channel(());
let alive_tx = Arc::new(alive_tx);
let peer_address = connection.connection.remote_address();
let conn = Self {
inner: connection.connection,
default_retry_config,
_alive_tx: Arc::clone(&alive_tx),
};
let conn_id = conn.id();

(
Self {
inner: connection.connection,
default_retry_config,
_alive_tx: Arc::clone(&alive_tx),
},
conn,
ConnectionIncoming::new(
endpoint,
conn_id,
peer_address,
connection.uni_streams,
connection.bi_streams,
Expand All @@ -76,9 +79,8 @@ impl Connection {
/// the peer. So this _should_ be unique per peer (without IP spoofing).
///
pub fn id(&self) -> String {
let socket = self.remote_address();

format!("{}{}", socket, self.inner.stable_id())
let remote_addr = self.remote_address();
format!("{remote_addr}{}", self.inner.stable_id())
}

/// The address of the remote peer.
Expand Down Expand Up @@ -134,7 +136,7 @@ impl Connection {
/// Messages sent over the stream will arrive at the peer in the order they were sent.
pub async fn open_uni(&self) -> Result<SendStream, ConnectionError> {
let send_stream = self.inner.open_uni().await?;
Ok(SendStream::new(send_stream))
Ok(SendStream::new(send_stream, self.id()))
}

/// Open a bidirectional stream to the peer.
Expand All @@ -145,7 +147,10 @@ impl Connection {
/// Messages sent over the stream will arrive at the peer in the order they were sent.
pub async fn open_bi(&self) -> Result<(SendStream, RecvStream), ConnectionError> {
let (send_stream, recv_stream) = self.inner.open_bi().await?;
Ok((SendStream::new(send_stream), RecvStream::new(recv_stream)))
Ok((
SendStream::new(send_stream, self.id()),
RecvStream::new(recv_stream, self.id()),
))
}

/// Close the connection immediately.
Expand Down Expand Up @@ -188,41 +193,46 @@ impl fmt::Debug for Connection {

/// Identifier for a stream within a particular connection
#[cfg_attr(feature = "arbitrary", derive(Arbitrary))]
#[derive(Debug, Copy, Clone, Eq, PartialEq, Ord, PartialOrd, Hash)]
pub struct StreamId(quinn::StreamId);
#[derive(Debug, Clone, Eq, PartialEq, Ord, PartialOrd, Hash)]
pub struct StreamId {
stream_id: quinn::StreamId,
conn_id: String,
}

impl fmt::Display for StreamId {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
let initiator = match self.0.initiator() {
quinn_proto::Side::Client => "initiator",
quinn_proto::Side::Server => "acceptor",
};
let dir = match self.0.dir() {
quinn_proto::Dir::Uni => "uni",
quinn_proto::Dir::Bi => "bi",
let initiator = if self.stream_id.initiator().is_client() {
"initiator"
} else {
"acceptor"
};
let dir = self.stream_id.dir();
write!(
f,
"{} {}directional stream {}",
initiator,
dir,
self.0.index()
"{initiator} {dir:?}directional stream {}@{}",
self.stream_id.index(),
self.conn_id
)
}
}

/// The sending API for a QUIC stream.
pub struct SendStream {
conn_id: String,
inner: quinn::SendStream,
}

impl SendStream {
fn new(inner: quinn::SendStream) -> Self {
Self { inner }
fn new(inner: quinn::SendStream, conn_id: String) -> Self {
Self { conn_id, inner }
}

/// Get the identity of this stream
pub fn id(&self) -> StreamId {
StreamId(self.inner.id())
StreamId {
stream_id: self.inner.id(),
conn_id: self.conn_id.clone(),
}
}

/// Set the priority of the send stream.
Expand Down Expand Up @@ -269,17 +279,21 @@ impl fmt::Debug for SendStream {

/// The receiving API for a bidirectional QUIC stream.
pub struct RecvStream {
conn_id: String,
inner: quinn::RecvStream,
}

impl RecvStream {
fn new(inner: quinn::RecvStream) -> Self {
Self { inner }
fn new(inner: quinn::RecvStream, conn_id: String) -> Self {
Self { conn_id, inner }
}

/// Get the identity of this stream
pub fn id(&self) -> StreamId {
StreamId(self.inner.id())
StreamId {
stream_id: self.inner.id(),
conn_id: self.conn_id.clone(),
}
}

/// Get the next message sent by the peer over this stream.
Expand Down Expand Up @@ -312,6 +326,7 @@ pub struct ConnectionIncoming {
impl ConnectionIncoming {
fn new(
endpoint: quinn::Endpoint,
conn_id: String,
peer_addr: SocketAddr,
uni_streams: quinn::IncomingUniStreams,
bi_streams: quinn::IncomingBiStreams,
Expand All @@ -324,6 +339,7 @@ impl ConnectionIncoming {
// `alive_tx` is dropped, which would be when both sides of the connection are dropped.
start_message_listeners(
endpoint,
conn_id,
peer_addr,
uni_streams,
bi_streams,
Expand Down Expand Up @@ -361,6 +377,7 @@ impl ConnectionIncoming {
// `message_tx` is used to exfiltrate messages and stream errors.
fn start_message_listeners(
endpoint: quinn::Endpoint,
conn_id: String,
peer_addr: SocketAddr,
uni_streams: quinn::IncomingUniStreams,
bi_streams: quinn::IncomingBiStreams,
Expand All @@ -375,7 +392,7 @@ fn start_message_listeners(
));

let _ = tokio::spawn(listen_on_bi_streams(
endpoint, peer_addr, bi_streams, alive_rx, message_tx,
endpoint, conn_id, peer_addr, bi_streams, alive_rx, message_tx,
));
}

Expand Down Expand Up @@ -456,6 +473,7 @@ async fn listen_on_uni_streams(

async fn listen_on_bi_streams(
endpoint: quinn::Endpoint,
conn_id: String,
peer_addr: SocketAddr,
bi_streams: quinn::IncomingBiStreams,
mut alive_rx: watch::Receiver<()>,
Expand All @@ -465,6 +483,7 @@ async fn listen_on_bi_streams(
let streaming = bi_streams.try_for_each_concurrent(None, |(send_stream, mut recv_stream)| {
let endpoint = &endpoint;
let message_tx = &message_tx;
let conn_id = &conn_id;
async move {
trace!("Handling incoming bi-stream from {peer_addr}");
match WireMsg::read_from_stream(&mut recv_stream).await {
Expand All @@ -479,7 +498,7 @@ async fn listen_on_bi_streams(
if let Err(msg) = message_tx
.send(Ok((
(header, dst, payload),
Some(SendStream::new(send_stream)),
Some(SendStream::new(send_stream, conn_id.clone())),
)))
.await
{
Expand Down

0 comments on commit 8acd490

Please sign in to comment.