Skip to content

Commit

Permalink
network: Buffer partial reads for cancellation safety
Browse files Browse the repository at this point in the history
  • Loading branch information
joeykraut committed Aug 1, 2023
1 parent aae4236 commit e19282a
Showing 1 changed file with 106 additions and 19 deletions.
125 changes: 106 additions & 19 deletions src/network.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ mod config;
mod mock;

#[cfg(any(feature = "test_helpers", test))]
pub use mock::{NoRecvNetwork, UnboundedDuplexStream, MockNetwork};
pub use mock::{MockNetwork, NoRecvNetwork, UnboundedDuplexStream};

use async_trait::async_trait;
use quinn::{Endpoint, RecvStream, SendStream};
Expand All @@ -27,6 +27,8 @@ const BYTES_PER_U64: usize = 8;

/// Error message emitted when reading a message length from the stream fails
const ERR_READ_MESSAGE_LENGTH: &str = "error reading message length from stream";
/// Error thrown when a stream finishes early
const ERR_STREAM_FINISHED_EARLY: &str = "stream finished early";

// ---------
// | Trait |
Expand Down Expand Up @@ -121,6 +123,60 @@ pub enum ReadWriteOrder {
WriteFirst,
}

/// A wrapper around a raw [u8] buffer that tracks the cursor within the buffer
/// to allow partial fills across cancelled futures
///
/// Similar to `tokio::io::ReadBuf` but takes ownership of the underlying buffer to
/// avoid coloring interfaces with lifetime parameters
#[derive(Debug)]
struct BufferWithCursor {
/// The underlying buffer
buffer: Vec<u8>,
/// The current cursor position
cursor: usize,
}

impl BufferWithCursor {
/// Create a new buffer with a cursor at the start of the buffer
pub fn new(buf: Vec<u8>) -> Self {
assert_eq!(
buf.len(),
buf.capacity(),
"buffer must be fully initialized"
);

Self {
buffer: buf,
cursor: 0,
}
}

/// The number of bytes remaining in the buffer
pub fn remaining(&self) -> usize {
self.buffer.capacity() - self.cursor
}

/// Whether the buffer is full
pub fn is_full(&self) -> bool {
self.remaining() == 0
}

/// Get a mutable reference to the empty section of the underlying buffer
pub fn get_unfilled(&mut self) -> &mut [u8] {
&mut self.buffer[self.cursor..]
}

/// Advance the cursor by `n` bytes
pub fn advance_cursor(&mut self, n: usize) {
self.cursor += n
}

/// Take ownership of the underlying buffer
pub fn into_vec(self) -> Vec<u8> {
self.buffer
}
}

/// Implements an MpcNetwork on top of QUIC
#[derive(Debug)]
pub struct QuicTwoPartyNet {
Expand All @@ -132,6 +188,19 @@ pub struct QuicTwoPartyNet {
local_addr: SocketAddr,
/// Addresses of the counterparties in the MPC
peer_addr: SocketAddr,
/// A buffered message length read from the stream
///
/// In the case that the whole message is not available yet, `read_exact` may block
/// and the `read` method may be cancelled. We buffer the message length to avoid re-reading
/// the message length incorrectly from the stream. Essentially this field gives cancellation
/// safety to the `read` method.
buffered_message_length: Option<u64>,
/// A buffered partial message read from the stream
///
/// This buffer exists to provide cancellation safety to a `read` future as the underlying `quinn`
/// stream is not cancellation safe, i.e. if a `ReadBuf` future is dropped, the buffer is dropped with
/// it and the message skipped
buffered_message: Option<BufferWithCursor>,
/// The send side of the bidirectional stream
send_stream: Option<SendStream>,
/// The receive side of the bidirectional stream
Expand All @@ -148,6 +217,8 @@ impl<'a> QuicTwoPartyNet {
local_addr,
peer_addr,
connected: false,
buffered_message_length: None,
buffered_message: None,
send_stream: None,
recv_stream: None,
}
Expand Down Expand Up @@ -244,14 +315,7 @@ impl<'a> QuicTwoPartyNet {

/// Read a message length from the stream
async fn read_message_length(&mut self) -> Result<u64, MpcNetworkError> {
let mut read_buffer = vec![0u8; BYTES_PER_U64];
self.recv_stream
.as_mut()
.unwrap()
.read_exact(&mut read_buffer)
.await
.map_err(|e| MpcNetworkError::RecvError(e.to_string()))?;

let read_buffer = self.read_bytes(BYTES_PER_U64).await?;
Ok(u64::from_le_bytes(read_buffer.try_into().map_err(
|_| MpcNetworkError::SerializationError(ERR_READ_MESSAGE_LENGTH.to_string()),
)?))
Expand All @@ -269,15 +333,30 @@ impl<'a> QuicTwoPartyNet {

/// Read exactly `n` bytes from the stream
async fn read_bytes(&mut self, num_bytes: usize) -> Result<Vec<u8>, MpcNetworkError> {
let mut read_buffer = vec![0u8; num_bytes];
self.recv_stream
.as_mut()
.unwrap()
.read_exact(&mut read_buffer)
.await
.map_err(|e| MpcNetworkError::RecvError(e.to_string()))?;
// Allocate a buffer for the next message if one does not already exist
if self.buffered_message.is_none() {
self.buffered_message = Some(BufferWithCursor::new(vec![0u8; num_bytes]));
}

// Read until the buffer is full
let read_buffer = self.buffered_message.as_mut().unwrap();
while !read_buffer.is_full() {
let bytes_read = self
.recv_stream
.as_mut()
.unwrap()
.read(read_buffer.get_unfilled())
.await
.map_err(|e| MpcNetworkError::RecvError(e.to_string()))?
.ok_or(MpcNetworkError::RecvError(
ERR_STREAM_FINISHED_EARLY.to_string(),
))?;

read_buffer.advance_cursor(bytes_read);
}

Ok(read_buffer.to_vec())
// Take ownership of the buffer, and reset the buffered message to `None`
Ok(self.buffered_message.take().unwrap().into_vec())
}
}

Expand All @@ -298,10 +377,18 @@ impl MpcNetwork for QuicTwoPartyNet {
}

async fn receive_message(&mut self) -> Result<NetworkOutbound, MpcNetworkError> {
// Read the message length from the buffer
let len = self.read_message_length().await?;
// Read the message length from the buffer if already read from the stream
if self.buffered_message_length.is_none() {
self.buffered_message_length = Some(self.read_message_length().await?);
}

// Read the data from the stream
let len = self.buffered_message_length.unwrap();
let bytes = self.read_bytes(len as usize).await?;

// Reset the message length buffer after the data has been pulled from the stream
self.buffered_message_length = None;

// Deserialize the message
serde_json::from_slice(&bytes)
.map_err(|err| MpcNetworkError::SerializationError(err.to_string()))
Expand Down

0 comments on commit e19282a

Please sign in to comment.