Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor(iroh-net)!: remove async channel #2620

Merged
merged 8 commits into from
Aug 14, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

9 changes: 6 additions & 3 deletions iroh-net/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@ workspace = true

[dependencies]
anyhow = { version = "1" }
async-channel = "2.3.1"
base64 = "0.22.1"
backoff = "0.4.0"
bytes = "1"
Expand Down Expand Up @@ -58,7 +57,6 @@ ring = "0.17"
rustls = { version = "0.21.11", default-features = false, features = ["dangerous_configuration"] }
serde = { version = "1", features = ["derive", "rc"] }
smallvec = "1.11.1"
swarm-discovery = { version = "0.2.1", optional = true }
socket2 = "0.5.3"
stun-rs = "0.1.5"
surge-ping = "0.8.0"
Expand Down Expand Up @@ -92,6 +90,11 @@ tokio-rustls-acme = { version = "0.3", optional = true }
iroh-metrics = { version = "0.22.0", path = "../iroh-metrics", default-features = false }
strum = { version = "0.26.2", features = ["derive"] }

# local_swarm_discovery
swarm-discovery = { version = "0.2.1", optional = true }
tokio-stream = { version = "0.1.15", optional = true }


[target.'cfg(any(target_os = "linux", target_os = "android"))'.dependencies]
netlink-packet-core = "0.7.0"
netlink-packet-route = "0.17.0"
Expand Down Expand Up @@ -140,7 +143,7 @@ iroh-relay = [
]
metrics = ["iroh-metrics/metrics"]
test-utils = ["iroh-relay"]
local_swarm_discovery = ["dep:swarm-discovery"]
local_swarm_discovery = ["dep:swarm-discovery", "dep:tokio-stream"]

[[bin]]
name = "iroh-relay"
Expand Down
41 changes: 24 additions & 17 deletions iroh-net/src/discovery/local_swarm_discovery.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,12 @@ use std::{

use anyhow::Result;
use derive_more::FromStr;
use futures_lite::{stream::Boxed as BoxStream, StreamExt};
use futures_lite::stream::Boxed as BoxStream;
use tracing::{debug, error, trace, warn};

use async_channel::Sender;
use iroh_base::key::PublicKey;
use swarm_discovery::{Discoverer, DropGuard, IpClass, Peer};
use tokio::task::JoinSet;
use tokio::{sync::mpsc, task::JoinSet};

use crate::{
discovery::{Discovery, DiscoveryItem},
Expand All @@ -39,13 +38,13 @@ const DISCOVERY_DURATION: Duration = Duration::from_secs(10);
pub struct LocalSwarmDiscovery {
#[allow(dead_code)]
handle: AbortingJoinHandle<()>,
sender: Sender<Message>,
sender: mpsc::Sender<Message>,
}

#[derive(Debug)]
enum Message {
Discovery(String, Peer),
SendAddrs(NodeId, Sender<Result<DiscoveryItem>>),
SendAddrs(NodeId, mpsc::Sender<Result<DiscoveryItem>>),
ChangeLocalAddrs(AddrInfo),
Timeout(NodeId, usize),
}
Expand All @@ -62,7 +61,7 @@ impl LocalSwarmDiscovery {
/// This relies on [`tokio::runtime::Handle::current`] and will panic if called outside of the context of a tokio runtime.
pub fn new(node_id: NodeId) -> Result<Self> {
debug!("Creating new LocalSwarmDiscovery service");
let (send, recv) = async_channel::bounded(64);
let (send, mut recv) = mpsc::channel(64);
let task_sender = send.clone();
let rt = tokio::runtime::Handle::current();
let discovery = LocalSwarmDiscovery::spawn_discoverer(
Expand All @@ -75,19 +74,21 @@ impl LocalSwarmDiscovery {
let handle = tokio::spawn(async move {
let mut node_addrs: HashMap<PublicKey, Peer> = HashMap::default();
let mut last_id = 0;
let mut senders: HashMap<PublicKey, HashMap<usize, Sender<Result<DiscoveryItem>>>> =
HashMap::default();
let mut senders: HashMap<
PublicKey,
HashMap<usize, mpsc::Sender<Result<DiscoveryItem>>>,
> = HashMap::default();
let mut timeouts = JoinSet::new();
loop {
trace!(?node_addrs, "LocalSwarmDiscovery Service loop tick");
let msg = match recv.recv().await {
Err(err) => {
error!("LocalSwarmDiscovery service error: {err:?}");
None => {
error!("LocalSwarmDiscovery channel closed");
error!("closing LocalSwarmDiscovery");
timeouts.abort_all();
return;
}
Ok(msg) => msg,
Some(msg) => msg,
};
match msg {
Message::Discovery(discovered_node_id, peer_info) => {
Expand Down Expand Up @@ -189,20 +190,24 @@ impl LocalSwarmDiscovery {

fn spawn_discoverer(
node_id: PublicKey,
sender: Sender<Message>,
sender: mpsc::Sender<Message>,
socketaddrs: BTreeSet<SocketAddr>,
rt: &tokio::runtime::Handle,
) -> Result<DropGuard> {
let spawn_rt = rt.clone();
let callback = move |node_id: &str, peer: &Peer| {
trace!(
node_id,
?peer,
"Received peer information from LocalSwarmDiscovery"
);

sender
.send_blocking(Message::Discovery(node_id.to_string(), peer.clone()))
.ok();
let sender = sender.clone();
let node_id = node_id.to_string();
let peer = peer.clone();
spawn_rt.spawn(async move {
sender.send(Message::Discovery(node_id, peer)).await.ok();
});
};
let addrs = LocalSwarmDiscovery::socketaddrs_to_addrs(socketaddrs);
let mut discoverer =
Expand Down Expand Up @@ -247,15 +252,16 @@ impl From<&Peer> for DiscoveryItem {

impl Discovery for LocalSwarmDiscovery {
fn resolve(&self, _ep: Endpoint, node_id: NodeId) -> Option<BoxStream<Result<DiscoveryItem>>> {
let (send, recv) = async_channel::bounded(20);
let (send, recv) = mpsc::channel(20);
let discovery_sender = self.sender.clone();
tokio::spawn(async move {
discovery_sender
.send(Message::SendAddrs(node_id, send))
.await
.ok();
});
Some(recv.boxed())
let stream = tokio_stream::wrappers::ReceiverStream::new(recv);
Some(Box::pin(stream))
}

fn publish(&self, info: &AddrInfo) {
Expand All @@ -277,6 +283,7 @@ mod tests {
/// tests)
mod run_in_isolation {
use super::super::*;
use futures_lite::StreamExt;
use testresult::TestResult;

#[tokio::test]
Expand Down
15 changes: 8 additions & 7 deletions iroh-net/src/magicsock.rs
Original file line number Diff line number Diff line change
Expand Up @@ -177,7 +177,7 @@ pub(crate) struct MagicSock {
proxy_url: Option<Url>,

/// Used for receiving relay messages.
relay_recv_receiver: async_channel::Receiver<RelayRecvResult>,
relay_recv_receiver: parking_lot::Mutex<mpsc::Receiver<RelayRecvResult>>,
/// Stores wakers, to be called when relay_recv_ch receives new data.
network_recv_wakers: parking_lot::Mutex<Option<Waker>>,
network_send_wakers: parking_lot::Mutex<Option<Waker>>,
Expand Down Expand Up @@ -788,12 +788,13 @@ impl MagicSock {
if self.is_closed() {
break;
}
match self.relay_recv_receiver.try_recv() {
Err(async_channel::TryRecvError::Empty) => {
let mut relay_recv_receiver = self.relay_recv_receiver.lock();
match relay_recv_receiver.try_recv() {
Err(mpsc::error::TryRecvError::Empty) => {
self.network_recv_wakers.lock().replace(cx.waker().clone());
break;
}
Err(async_channel::TryRecvError::Closed) => {
Err(mpsc::error::TryRecvError::Disconnected) => {
return Poll::Ready(Err(io::Error::new(
io::ErrorKind::NotConnected,
"connection closed",
Expand Down Expand Up @@ -1378,7 +1379,7 @@ impl Handle {
insecure_skip_relay_cert_verify,
} = opts;

let (relay_recv_sender, relay_recv_receiver) = async_channel::bounded(128);
let (relay_recv_sender, relay_recv_receiver) = mpsc::channel(128);

let (pconn4, pconn6) = bind(port)?;
let port = pconn4.port();
Expand Down Expand Up @@ -1412,7 +1413,7 @@ impl Handle {
local_addrs: std::sync::RwLock::new((ipv4_addr, ipv6_addr)),
closing: AtomicBool::new(false),
closed: AtomicBool::new(false),
relay_recv_receiver,
relay_recv_receiver: parking_lot::Mutex::new(relay_recv_receiver),
network_recv_wakers: parking_lot::Mutex::new(None),
network_send_wakers: parking_lot::Mutex::new(None),
actor_sender: actor_sender.clone(),
Expand Down Expand Up @@ -1704,7 +1705,7 @@ struct Actor {
relay_actor_sender: mpsc::Sender<RelayActorMessage>,
relay_actor_cancel_token: CancellationToken,
/// Channel to send received relay messages on, for processing.
relay_recv_sender: async_channel::Sender<RelayRecvResult>,
relay_recv_sender: mpsc::Sender<RelayRecvResult>,
/// When set, is an AfterFunc timer that will call MagicSock::do_periodic_stun.
periodic_re_stun_timer: time::Interval,
/// The `NetInfo` provided in the last call to `net_info_func`. It's used to deduplicate calls to netInfoFunc.
Expand Down
5 changes: 3 additions & 2 deletions iroh-net/src/magicsock/udp_conn.rs
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,7 @@ mod tests {

use super::*;
use anyhow::Result;
use tokio::sync::mpsc;

const ALPN: &[u8] = b"n0/test/1";

Expand Down Expand Up @@ -192,7 +193,7 @@ mod tests {
let (m2, _m2_key) = wrap_socket(m2)?;

let m1_addr = SocketAddr::new(network.local_addr(), m1.local_addr()?.port());
let (m1_send, m1_recv) = async_channel::bounded(8);
let (m1_send, mut m1_recv) = mpsc::channel(8);

let m1_task = tokio::task::spawn(async move {
if let Some(conn) = m1.accept().await {
Expand Down Expand Up @@ -220,7 +221,7 @@ mod tests {
drop(send_bi);

// make sure the right values arrived
let val = m1_recv.recv().await?;
let val = m1_recv.recv().await.unwrap();
assert_eq!(val, b"hello");

m1_task.await??;
Expand Down
6 changes: 3 additions & 3 deletions iroh-net/src/net/netmon/actor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ pub(super) struct Actor {
/// OS specific monitor.
#[allow(dead_code)]
route_monitor: RouteMonitor,
mon_receiver: async_channel::Receiver<NetworkMessage>,
mon_receiver: mpsc::Receiver<NetworkMessage>,
actor_receiver: mpsc::Receiver<ActorMessage>,
actor_sender: mpsc::Sender<ActorMessage>,
/// Callback registry.
Expand All @@ -84,7 +84,7 @@ impl Actor {
let wall_time = Instant::now();

// Use flume channels, as tokio::mpsc is not safe to use across ffi boundaries.
let (mon_sender, mon_receiver) = async_channel::bounded(MON_CHAN_CAPACITY);
let (mon_sender, mon_receiver) = mpsc::channel(MON_CHAN_CAPACITY);
let route_monitor = RouteMonitor::new(mon_sender)?;
let (actor_sender, actor_receiver) = mpsc::channel(ACTOR_CHAN_CAPACITY);

Expand Down Expand Up @@ -129,7 +129,7 @@ impl Actor {
debounce_interval.reset_immediately();
}
}
Ok(_event) = self.mon_receiver.recv() => {
Some(_event) = self.mon_receiver.recv() => {
trace!("network activity detected");
last_event.replace(false);
debounce_interval.reset_immediately();
Expand Down
3 changes: 2 additions & 1 deletion iroh-net/src/net/netmon/android.rs
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
use anyhow::Result;
use tokio::sync::mpsc;

use super::actor::NetworkMessage;

#[derive(Debug)]
pub(super) struct RouteMonitor {}

impl RouteMonitor {
pub(super) fn new(_sender: async_channel::Sender<NetworkMessage>) -> Result<Self> {
pub(super) fn new(_sender: mpsc::Sender<NetworkMessage>) -> Result<Self> {
// Very sad monitor. Android doesn't allow us to do this

Ok(RouteMonitor {})
Expand Down
4 changes: 2 additions & 2 deletions iroh-net/src/net/netmon/bsd.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use anyhow::Result;
use tokio::{io::AsyncReadExt, task::JoinHandle};
use tokio::{io::AsyncReadExt, sync::mpsc, task::JoinHandle};
use tracing::{trace, warn};

#[cfg(any(target_os = "freebsd", target_os = "netbsd", target_os = "openbsd"))]
Expand All @@ -23,7 +23,7 @@ impl Drop for RouteMonitor {
}

impl RouteMonitor {
pub(super) fn new(sender: async_channel::Sender<NetworkMessage>) -> Result<Self> {
pub(super) fn new(sender: mpsc::Sender<NetworkMessage>) -> Result<Self> {
let socket = socket2::Socket::new(libc::AF_ROUTE.into(), socket2::Type::RAW, None)?;
socket.set_nonblocking(true)?;
let socket_std: std::os::unix::net::UnixStream = socket.into();
Expand Down
4 changes: 2 additions & 2 deletions iroh-net/src/net/netmon/linux.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ use netlink_packet_core::NetlinkPayload;
use netlink_packet_route::{address, constants::*, route, RtnlMessage};
use netlink_sys::{AsyncSocket, SocketAddr};
use rtnetlink::new_connection;
use tokio::task::JoinHandle;
use tokio::{sync::mpsc, task::JoinHandle};
use tracing::{info, trace, warn};

use crate::net::ip::is_link_local;
Expand Down Expand Up @@ -49,7 +49,7 @@ macro_rules! get_nla {
}

impl RouteMonitor {
pub(super) fn new(sender: async_channel::Sender<NetworkMessage>) -> Result<Self> {
pub(super) fn new(sender: mpsc::Sender<NetworkMessage>) -> Result<Self> {
let (mut conn, mut _handle, mut messages) = new_connection()?;

// Specify flags to listen on.
Expand Down
7 changes: 4 additions & 3 deletions iroh-net/src/net/netmon/windows.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ use std::{collections::HashMap, sync::Arc};

use anyhow::Result;
use libc::c_void;
use tokio::sync::mpsc;
use tracing::{trace, warn};
use windows::Win32::{
Foundation::{BOOLEAN, HANDLE as Handle},
Expand All @@ -19,21 +20,21 @@ pub(super) struct RouteMonitor {
}

impl RouteMonitor {
pub(super) fn new(sender: async_channel::Sender<NetworkMessage>) -> Result<Self> {
pub(super) fn new(sender: mpsc::Sender<NetworkMessage>) -> Result<Self> {
// Register two callbacks with the windows api
let mut cb_handler = CallbackHandler::default();

// 1. Unicast Address Changes
let s = sender.clone();
cb_handler.register_unicast_address_change_callback(Box::new(move || {
if let Err(err) = s.send_blocking(NetworkMessage::Change) {
if let Err(err) = s.blocking_send(NetworkMessage::Change) {
warn!("unable to send: unicast change notification: {:?}", err);
}
}))?;

// 2. Route Changes
cb_handler.register_route_change_callback(Box::new(move || {
if let Err(err) = sender.send_blocking(NetworkMessage::Change) {
if let Err(err) = sender.blocking_send(NetworkMessage::Change) {
warn!("unable to send: route change notification: {:?}", err);
}
}))?;
Expand Down
Loading