Skip to content

Commit

Permalink
fix(hydroflow): cleanup temp tcp networking code, fix race condition fix
Browse files Browse the repository at this point in the history
 hydro-project#1458

only spawn one task to prevent races between tasks
  • Loading branch information
MingweiSamuel committed Sep 20, 2024
1 parent afe78c3 commit bced4ad
Showing 1 changed file with 96 additions and 72 deletions.
168 changes: 96 additions & 72 deletions hydroflow/src/util/tcp.rs
Original file line number Diff line number Diff line change
@@ -1,16 +1,15 @@
#![cfg(not(target_arch = "wasm32"))]

use std::cell::RefCell;
use std::collections::hash_map::Entry::{Occupied, Vacant};
use std::collections::HashMap;
use std::net::SocketAddr;
use std::pin::pin;
use std::rc::Rc;

use futures::{SinkExt, StreamExt};
use tokio::net::tcp::{OwnedReadHalf, OwnedWriteHalf};
use tokio::net::{TcpListener, TcpSocket, TcpStream};
use tokio::select;
use tokio::task::spawn_local;
use tokio_stream::StreamMap;
use tokio_util::codec::{
BytesCodec, Decoder, Encoder, FramedRead, FramedWrite, LengthDelimitedCodec, LinesCodec,
};
Expand Down Expand Up @@ -74,6 +73,7 @@ pub type TcpFramedSink<T> = Sender<(T, SocketAddr)>;
pub type TcpFramedStream<Codec: Decoder> =
Receiver<Result<(<Codec as Decoder>::Item, SocketAddr), <Codec as Decoder>::Error>>;

// TODO(mingwei): this temporary code should be replaced with a properly thought out networking system.
/// Create a listening tcp socket, and then as new connections come in, receive their data and forward it to a queue.
pub async fn bind_tcp<T: 'static, Codec: 'static + Clone + Decoder + Encoder<T>>(
endpoint: SocketAddr,
Expand All @@ -83,60 +83,64 @@ pub async fn bind_tcp<T: 'static, Codec: 'static + Clone + Decoder + Encoder<T>>

let bound_endpoint = listener.local_addr()?;

let (tx_egress, mut rx_egress) = unsync_channel(None);
let (tx_ingress, rx_ingress) = unsync_channel(None);

let clients = Rc::new(RefCell::new(HashMap::new()));

spawn_local({
let clients = clients.clone();

async move {
while let Some((payload, addr)) = rx_egress.next().await {
let client = clients.borrow_mut().remove(&addr);

if let Some(mut sender) = client {
let _ = SinkExt::send(&mut sender, payload).await;
clients.borrow_mut().insert(addr, sender);
}
}
}
});
let (send_egress, mut recv_egress) = unsync_channel::<(T, SocketAddr)>(None);
let (send_ingres, recv_ingres) = unsync_channel(None);

spawn_local(async move {
let send_ingress = send_ingres;
let mut peers_send = HashMap::new();
let mut peers_recv = StreamMap::new();

loop {
let (stream, peer_addr) = if let Ok((stream, _)) = listener.accept().await {
if let Ok(peer_addr) = stream.peer_addr() {
(stream, peer_addr)
} else {
continue;
// Calling methods in a loop, futures must be cancel-safe.
select! {
biased;
// Accept new clients.
new_peer = listener.accept() => {
let Ok((stream, _addr)) = new_peer else {
continue;
};
let Ok(peer_addr) = stream.peer_addr() else {
continue;
};
let (peer_send, peer_recv) = tcp_framed(stream, codec.clone());

// TODO: Using peer_addr here as the key is a little bit sketchy.
// It's possible that a peer could send a message, disconnect, then another peer connects from the
// same IP address (and the same src port), and then the response could be sent to that new client.
// This can be solved by using monotonically increasing IDs for each new peer, but would break the
// similarity with the UDP versions of this function.
peers_send.insert(peer_addr, peer_send);
peers_recv.insert(peer_addr, peer_recv);
}
} else {
continue;
};

let mut tx_ingress = tx_ingress.clone();

let (send, recv) = tcp_framed(stream, codec.clone());

// TODO: Using peer_addr here as the key is a little bit sketchy.
// It's possible that a client could send a message, disconnect, then another client connects from the same IP address (and the same src port), and then the response could be sent to that new client.
// This can be solved by using monotonically increasing IDs for each new client, but would break the similarity with the UDP versions of this function.
clients.borrow_mut().insert(peer_addr, send);

spawn_local({
let clients = clients.clone();
async move {
let mapped = recv.map(|x| Ok(x.map(|x| (x, peer_addr))));
let _ = tx_ingress.send_all(&mut pin!(mapped)).await;

clients.borrow_mut().remove(&peer_addr);
// Send outgoing messages.
msg_send = recv_egress.next() => {
let Some((payload, peer_addr)) = msg_send else {
continue;
};
let Some(stream) = peers_send.get_mut(&peer_addr) else {
eprintln!("Dropping message to non-connected peer: {}", peer_addr);
continue;
};
if let Err(_err) = SinkExt::send(stream, payload).await {
eprintln!("Failed to send message to peer: {}", peer_addr);
};
}
// Receive incoming messages.
msg_recv = peers_recv.next() => {
let Some((peer_addr, payload_result)) = msg_recv else {
eprintln!("Error receiving message");
continue;
};
if let Err(err) = send_ingress.send(payload_result.map(|payload| (payload, peer_addr))).await {
eprintln!("Error passing along received message: {:?}", err);
}
}
});
}
}
});

Ok((tx_egress, rx_ingress, bound_endpoint))
Ok((send_egress, recv_ingres, bound_endpoint))
}

/// The inverse of [`bind_tcp`].
Expand All @@ -147,34 +151,54 @@ pub async fn bind_tcp<T: 'static, Codec: 'static + Clone + Decoder + Encoder<T>>
pub fn connect_tcp<T: 'static, Codec: 'static + Clone + Decoder + Encoder<T>>(
codec: Codec,
) -> (TcpFramedSink<T>, TcpFramedStream<Codec>) {
let (tx_egress, mut rx_egress) = unsync_channel(None);
let (tx_ingress, rx_ingress) = unsync_channel(None);
let (send_egress, mut recv_egress) = unsync_channel(None);
let (send_ingres, recv_ingres) = unsync_channel(None);

spawn_local(async move {
let mut streams = HashMap::new();

while let Some((payload, addr)) = rx_egress.next().await {
let stream = match streams.entry(addr) {
Occupied(entry) => entry.into_mut(),
Vacant(entry) => {
let socket = TcpSocket::new_v4().unwrap();
let stream = socket.connect(addr).await.unwrap();

let (send, recv) = tcp_framed(stream, codec.clone());
let send_ingres = send_ingres;
let mut peers_send = HashMap::new();
let mut peers_recv = StreamMap::new();

let mut tx_ingress = tx_ingress.clone();
spawn_local(async move {
let mapped = recv.map(|x| Ok(x.map(|x| (x, addr))));
let _ = tx_ingress.send_all(&mut pin!(mapped)).await;
});

entry.insert(send)
loop {
// Calling methods in a loop, futures must be cancel-safe.
select! {
biased;
// Send outgoing messages.
msg_send = recv_egress.next() => {
let Some((payload, peer_addr)) = msg_send else {
continue;
};

let stream = match peers_send.entry(peer_addr) {
Occupied(entry) => entry.into_mut(),
Vacant(entry) => {
let socket = TcpSocket::new_v4().unwrap();
let stream = socket.connect(peer_addr).await.unwrap();

let (peer_send, peer_recv) = tcp_framed(stream, codec.clone());

peers_recv.insert(peer_addr, peer_recv);
entry.insert(peer_send)
}
};

if let Err(_err) = stream.send(payload).await {
eprintln!("Failed to send message to peer: {}", peer_addr);
}
}
};

let _ = stream.send(payload).await;
// Receive incoming messages.
msg_recv = peers_recv.next() => {
let Some((peer_addr, payload_result)) = msg_recv else {
eprintln!("Error receiving message");
continue;
};
if let Err(err) = send_ingres.send(payload_result.map(|payload| (payload, peer_addr))).await {
eprintln!("Error passing along received message: {:?}", err);
}
}
}
}
});

(tx_egress, rx_ingress)
(send_egress, recv_ingres)
}

0 comments on commit bced4ad

Please sign in to comment.