Commit

Hack multiplexing...
JoshLind committed Oct 7, 2023
1 parent 7a1a302 commit c433f1b
Showing 1 changed file with 196 additions and 65 deletions.
network/netcore/src/transport/quic.rs: 261 changes (196 additions & 65 deletions)
@@ -29,6 +29,8 @@ use std::{
 use tokio_util::compat::{Compat, TokioAsyncReadCompatExt, TokioAsyncWriteCompatExt};

 // Useful constants
+const NUM_BYTES_PER_STREAM_FRAGMENT: usize = 1000;
+const NUM_STREAMS_TO_CREATE: u64 = 5;
 const SERVER_STRING: &str = "aptos-node";
 const STREAM_START_MESSAGE: &str = "start-stream";
 const STREAM_START_MESSSAGE_LENGTH: usize = 12; // Update this if the stream start message changes!
@@ -391,111 +393,240 @@ async fn create_quic_connection(connection: quinn::Connection) -> io::Result<Qui
 #[allow(dead_code)]
 pub struct QuicConnection {
     connection: quinn::Connection,
-    send_stream: Compat<quinn::SendStream>,
-    recv_stream: Compat<quinn::RecvStream>,

+    // The set of send and receive streams
+    send_streams: Vec<Compat<quinn::SendStream>>,
+    recv_streams: Vec<Compat<quinn::RecvStream>>,

+    // Indices to identify which streams to multiplex across
+    next_stream_index_to_write: usize,
+    next_stream_index_to_read: usize,

+    // The number of bytes written and read in the current fragment
+    num_bytes_written_in_fragment: usize,
+    num_bytes_read_in_fragment: usize,
 }

 impl QuicConnection {
     pub async fn new(connection: quinn::Connection) -> io::Result<Self> {
-        // Open a uni-directional stream and send a stream start
-        // message so that the receiver can accept it.
-        let send_connection = connection.clone();
-        let send_stream = tokio::task::spawn(async move {
-            // Open the stream
-            let mut send_stream = send_connection.open_uni().await.unwrap();
-            info!(
-                "(QUIC remote: {:?}) Opened a new send stream!",
-                send_connection.remote_address(),
-            );

-            // Send a stream start message
-            send_stream
-                .write_all(STREAM_START_MESSAGE.as_bytes())
-                .await
-                .unwrap();
-            info!(
-                "(QUIC remote: {:?}) Wrote stream start message",
-                send_connection.remote_address(),
-            );
+        // Create the specified number of streams
+        let mut send_streams = Vec::new();
+        let mut recv_streams = Vec::new();
+        for stream_id in 0..NUM_STREAMS_TO_CREATE {
+            // Open a uni-directional stream and send a stream start
+            // message so that the receiver can accept it.
+            let send_connection = connection.clone();
+            let send_stream = tokio::task::spawn(async move {
+                // Open the stream
+                let mut send_stream = send_connection.open_uni().await.unwrap();
+                info!(
+                    "(QUIC remote: {:?}) Opened a new send stream! Stream ID: {:?}",
+                    send_connection.remote_address(),
+                    stream_id
+                );

-            send_stream
-        });
+                // Send a stream start message
+                send_stream
+                    .write_all(STREAM_START_MESSAGE.as_bytes())
+                    .await
+                    .unwrap();
+                info!(
+                    "(QUIC remote: {:?}) Wrote stream start message. Stream ID: {:?}",
+                    send_connection.remote_address(),
+                    stream_id
+                );

-        // Accept the stream so that we have a receiver
-        let recv_connection = connection.clone();
-        let recv_stream = tokio::task::spawn(async move {
-            // Accept the stream
-            let mut recv_stream = recv_connection.accept_uni().await.unwrap();
-            info!(
-                "(QUIC remote: {:?}) Accepted a new recv stream!",
-                recv_connection.remote_address(),
-            );
+                send_stream
+            });

-            // Read the stream start message
-            let mut buf = [0; STREAM_START_MESSSAGE_LENGTH];
-            recv_stream.read_exact(&mut buf).await.unwrap();
-            info!(
-                "(QUIC remote: {:?}) Read stream start message",
-                recv_connection.remote_address(),
-            );
+            // Accept the stream so that we have a receiver
+            let recv_connection = connection.clone();
+            let recv_stream = tokio::task::spawn(async move {
+                // Accept the stream
+                let mut recv_stream = recv_connection.accept_uni().await.unwrap();
+                info!(
+                    "(QUIC remote: {:?}) Accepted a new recv stream! Stream ID: {:?}",
+                    recv_connection.remote_address(),
+                    stream_id
+                );

-            // Verify the stream start message
-            if buf == STREAM_START_MESSAGE.as_bytes() {
+                // Read the stream start message
+                let mut buf = [0; STREAM_START_MESSSAGE_LENGTH];
+                recv_stream.read_exact(&mut buf).await.unwrap();
                 info!(
-                    "(QUIC remote: {:?}) Stream start message is valid!",
+                    "(QUIC remote: {:?}) Read stream start message. Stream ID: {:?}",
                     recv_connection.remote_address(),
+                    stream_id
                 );
-            } else {
-                panic!("The stream start message is invalid!!");
-            }

-            recv_stream
-        });
+                // Verify the stream start message
+                if buf == STREAM_START_MESSAGE.as_bytes() {
+                    info!(
+                        "(QUIC remote: {:?}) Stream start message is valid! Stream ID: {:?}",
+                        recv_connection.remote_address(),
+                        stream_id
+                    );
+                } else {
+                    panic!(
+                        "The stream start message is invalid!! Stream ID: {:?}",
+                        stream_id
+                    );
+                }

+                recv_stream
+            });

-        let (send_stream, recv_stream) = join(send_stream, recv_stream).await;
-        let send_stream = send_stream?;
-        let recv_stream = recv_stream?;
+            // Join the tasks to wait for the send and receive streams
+            let (send_stream, recv_stream) = join(send_stream, recv_stream).await;
+            let send_stream = send_stream?;
+            let recv_stream = recv_stream?;

+            // Save the send and receive streams
+            send_streams.push(send_stream.compat_write());
+            recv_streams.push(recv_stream.compat());
+        }

         // Create the QUIC connection
         Ok(Self {
             connection,
-            send_stream: send_stream.compat_write(),
-            recv_stream: recv_stream.compat(),
+            send_streams,
+            recv_streams,
+            next_stream_index_to_write: 0,
+            next_stream_index_to_read: 0,
+            num_bytes_written_in_fragment: 0,
+            num_bytes_read_in_fragment: 0,
         })
     }
 }

+// TODO: this is unbelievably hacky!
 impl AsyncRead for QuicConnection {
+    #[allow(clippy::comparison_chain)]
     fn poll_read(
         mut self: Pin<&mut Self>,
         context: &mut Context,
         buf: &mut [u8],
     ) -> Poll<io::Result<usize>> {
-        Pin::new(&mut self.recv_stream).poll_read(context, buf)
+        // Calculate the number of bytes to read
+        let buffer_to_read_len = buf.len();
+        let min_buffer_size = std::cmp::min(buffer_to_read_len, NUM_BYTES_PER_STREAM_FRAGMENT);
+        let bytes_to_read = std::cmp::min(
+            min_buffer_size,
+            NUM_BYTES_PER_STREAM_FRAGMENT - self.num_bytes_read_in_fragment,
+        );

+        // Create an internal buffer for the read
+        let mut internal_buffer = [0; NUM_BYTES_PER_STREAM_FRAGMENT];
+        let internal_buffer_slice = &mut internal_buffer[0..bytes_to_read];

+        // Identify the next stream to read from
+        let next_stream_to_read = self.next_stream_index_to_read;
+        let recv_stream = &mut self.recv_streams[next_stream_to_read];

+        // Read the bytes into the internal buffer
+        let recv_stream = Pin::new(recv_stream);
+        match recv_stream.poll_read(context, internal_buffer_slice) {
+            Poll::Ready(Err(error)) => {
+                // Something went wrong!
+                Poll::Ready(Err(error))
+            },
+            Poll::Pending => {
+                // We need to be polled again
+                Poll::Pending
+            },
+            Poll::Ready(Ok(num_bytes_read)) => {
+                // We read some bytes from the stream. Copy them into the
+                // output buffer and update our tracking.
+                buf[0..num_bytes_read].copy_from_slice(&internal_buffer[0..num_bytes_read]);

+                // Update our tracking
+                self.num_bytes_read_in_fragment += num_bytes_read;

+                // If we have read the entire fragment, we need to move to the next stream
+                // and reset our tracking.
+                if self.num_bytes_read_in_fragment == NUM_BYTES_PER_STREAM_FRAGMENT {
+                    self.next_stream_index_to_read =
+                        (self.next_stream_index_to_read + 1) % NUM_STREAMS_TO_CREATE as usize;
+                    self.num_bytes_read_in_fragment = 0;
+                } else if self.num_bytes_read_in_fragment > NUM_BYTES_PER_STREAM_FRAGMENT {
+                    panic!("We read too many bytes from the stream!");
+                }

+                // Return the number of bytes read
+                Poll::Ready(Ok(num_bytes_read))
+            },
+        }
     }
 }

+// TODO: this is unbelievably hacky!
 impl AsyncWrite for QuicConnection {
+    #[allow(clippy::comparison_chain)]
     fn poll_write(
         mut self: Pin<&mut Self>,
         context: &mut Context,
         buf: &[u8],
     ) -> Poll<io::Result<usize>> {
-        // TODO: write to multiple streams?
-        let send_stream = &mut self.send_stream;
-        Pin::new(send_stream).poll_write(context, buf)
+        // Calculate the number of bytes to write
+        let buffer_to_write_len = buf.len();
+        let min_buffer_size = std::cmp::min(buffer_to_write_len, NUM_BYTES_PER_STREAM_FRAGMENT);
+        let bytes_to_write = std::cmp::min(
+            min_buffer_size,
+            NUM_BYTES_PER_STREAM_FRAGMENT - self.num_bytes_written_in_fragment,
+        );

+        // Create an internal buffer for the write
+        let mut internal_buffer = [0; NUM_BYTES_PER_STREAM_FRAGMENT];

+        // Copy the bytes to write into the internal buffer
+        internal_buffer[0..bytes_to_write].copy_from_slice(&buf[0..bytes_to_write]);

+        // Create a slice for the internal buffer
+        let internal_buffer_slice = &mut internal_buffer[0..bytes_to_write];

+        // Identify the next stream to write to
+        let next_stream_to_write = self.next_stream_index_to_write;
+        let send_stream = &mut self.send_streams[next_stream_to_write];

+        // Write the bytes to the stream
+        let send_stream = Pin::new(send_stream);
+        match send_stream.poll_write(context, internal_buffer_slice) {
+            Poll::Ready(Err(error)) => {
+                // Something went wrong!
+                Poll::Ready(Err(error))
+            },
+            Poll::Pending => {
+                // We need to be polled again
+                Poll::Pending
+            },
+            Poll::Ready(Ok(num_bytes_written)) => {
+                // We wrote some bytes to the stream. Update our tracking.
+                self.num_bytes_written_in_fragment += num_bytes_written;

+                // If we have written the entire fragment, we need to move to the next stream
+                // and reset our tracking.
+                if self.num_bytes_written_in_fragment == NUM_BYTES_PER_STREAM_FRAGMENT {
+                    self.next_stream_index_to_write =
+                        (self.next_stream_index_to_write + 1) % NUM_STREAMS_TO_CREATE as usize;
+                    self.num_bytes_written_in_fragment = 0;
+                } else if self.num_bytes_written_in_fragment > NUM_BYTES_PER_STREAM_FRAGMENT {
+                    panic!("We wrote too many bytes to the stream!");
+                }

+                // Return the number of bytes written
+                Poll::Ready(Ok(num_bytes_written))
+            },
+        }
     }

-    fn poll_flush(mut self: Pin<&mut Self>, context: &mut Context) -> Poll<io::Result<()>> {
-        // TODO: flush all streams?
-        let send_stream = &mut self.send_stream;
-        Pin::new(send_stream).poll_flush(context)
+    fn poll_flush(self: Pin<&mut Self>, _context: &mut Context) -> Poll<io::Result<()>> {
+        // TODO: do we want to support this?
+        Poll::Ready(Ok(()))
     }

-    fn poll_close(mut self: Pin<&mut Self>, _context: &mut Context) -> Poll<io::Result<()>> {
-        // TODO: close all streams?
-        let send_stream = &mut self.send_stream;
-        Pin::new(send_stream).poll_close(_context)
+    fn poll_close(self: Pin<&mut Self>, _context: &mut Context) -> Poll<io::Result<()>> {
+        // TODO: do we want to support this?
+        Poll::Ready(Ok(()))
     }
 }
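
Not part of the commit, just an illustration of the multiplexing scheme it introduces: with the constants above (1000-byte fragments, 5 streams), the write path fills one fragment on the current stream before advancing round-robin to the next. The helper below is hypothetical (the name fragment_layout does not exist in the repository) and assumes every poll_write accepts the full slice it is offered; it sketches how a payload would be spread across the streams.

// Illustrative sketch only; mirrors the constants and round-robin fragment logic above.
const NUM_BYTES_PER_STREAM_FRAGMENT: usize = 1000;
const NUM_STREAMS_TO_CREATE: u64 = 5;

/// Hypothetical helper: returns (stream_index, fragment_len) pairs describing how a
/// payload of `payload_len` bytes would land across the streams, starting from
/// stream 0 with an empty fragment, assuming each poll_write accepts its full slice.
fn fragment_layout(mut payload_len: usize) -> Vec<(usize, usize)> {
    let mut layout = Vec::new();
    let mut stream_index = 0;
    while payload_len > 0 {
        let fragment_len = payload_len.min(NUM_BYTES_PER_STREAM_FRAGMENT);
        layout.push((stream_index, fragment_len));
        payload_len -= fragment_len;
        // Advance to the next stream only once a full fragment has been written
        if fragment_len == NUM_BYTES_PER_STREAM_FRAGMENT {
            stream_index = (stream_index + 1) % NUM_STREAMS_TO_CREATE as usize;
        }
    }
    layout
}

fn main() {
    // A 2500-byte payload lands as 1000 bytes on stream 0, 1000 on stream 1, 500 on stream 2.
    assert_eq!(fragment_layout(2500), vec![(0, 1000), (1, 1000), (2, 500)]);
}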
