From d93e7247a83f25c631306dcf2e846c9cd2856934 Mon Sep 17 00:00:00 2001 From: Nikhil Benesch Date: Sun, 4 Dec 2022 08:10:32 -0500 Subject: [PATCH] feat(console): add support for Unix domain sockets Add support for console connections that use Unix domain sockets rather than TCP. Fix #296. --- Cargo.lock | 1 + console-subscriber/Cargo.toml | 2 +- console-subscriber/examples/uds.rs | 20 ++++++++ console-subscriber/src/builder.rs | 77 ++++++++++++++++++++++++++---- console-subscriber/src/lib.rs | 32 +++++++++---- tokio-console/Cargo.toml | 1 + tokio-console/src/conn.rs | 36 +++++++++++++- 7 files changed, 148 insertions(+), 21 deletions(-) create mode 100644 console-subscriber/examples/uds.rs diff --git a/Cargo.lock b/Cargo.lock index 615baa841..ad49dbfe5 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1454,6 +1454,7 @@ dependencies = [ "tokio", "toml", "tonic", + "tower", "tracing", "tracing-journald", "tracing-subscriber 0.3.11", diff --git a/console-subscriber/Cargo.toml b/console-subscriber/Cargo.toml index f541a5aeb..867690036 100644 --- a/console-subscriber/Cargo.toml +++ b/console-subscriber/Cargo.toml @@ -33,7 +33,7 @@ env-filter = ["tracing-subscriber/env-filter"] crossbeam-utils = "0.8.7" tokio = { version = "^1.15", features = ["sync", "time", "macros", "tracing"] } -tokio-stream = "0.1" +tokio-stream = { version = "0.1", features = ["net"] } thread_local = "1.1.3" console-api = { version = "0.4.0", path = "../console-api", features = ["transport"] } tonic = { version = "0.8", features = ["transport"] } diff --git a/console-subscriber/examples/uds.rs b/console-subscriber/examples/uds.rs new file mode 100644 index 000000000..d74d92769 --- /dev/null +++ b/console-subscriber/examples/uds.rs @@ -0,0 +1,20 @@ +use std::path::Path; +use std::time::Duration; +use tokio::{task, time}; +use tracing::info; + +#[tokio::main] +async fn main() -> Result<(), Box> { + let addr = Path::new("./console-server"); + console_subscriber::ConsoleLayer::builder() + .server_addr(addr) + .init(); + info!("listening for console connections at {}", addr.display()); + task::Builder::default() + .name("sleepy") + .spawn(async move { time::sleep(Duration::from_secs(90)).await }) + .unwrap() + .await?; + + Ok(()) +} diff --git a/console-subscriber/src/builder.rs b/console-subscriber/src/builder.rs index 2c9fb0cc6..d8c58b5ab 100644 --- a/console-subscriber/src/builder.rs +++ b/console-subscriber/src/builder.rs @@ -1,6 +1,8 @@ use super::{ConsoleLayer, Server}; +#[cfg(unix)] +use std::path::Path; use std::{ - net::{SocketAddr, ToSocketAddrs}, + net::{IpAddr, SocketAddr, SocketAddrV4, SocketAddrV6, ToSocketAddrs}, path::PathBuf, thread, time::Duration, @@ -32,7 +34,7 @@ pub struct Builder { pub(crate) retention: Duration, /// The address on which to serve the RPC server. - pub(super) server_addr: SocketAddr, + pub(super) server_addr: ServerAddr, /// If and where to save a recording of the events. pub(super) recording_path: Option, @@ -58,7 +60,7 @@ impl Default for Builder { publish_interval: ConsoleLayer::DEFAULT_PUBLISH_INTERVAL, retention: ConsoleLayer::DEFAULT_RETENTION, poll_duration_max: ConsoleLayer::DEFAULT_POLL_DURATION_MAX, - server_addr: SocketAddr::new(Server::DEFAULT_IP, Server::DEFAULT_PORT), + server_addr: ServerAddr::Tcp(SocketAddr::new(Server::DEFAULT_IP, Server::DEFAULT_PORT)), recording_path: None, filter_env_var: "RUST_LOG".to_string(), self_trace: false, @@ -138,7 +140,7 @@ impl Builder { /// defaults. /// /// [environment variable]: `Builder::with_default_env` - pub fn server_addr(self, server_addr: impl Into) -> Self { + pub fn server_addr(self, server_addr: impl Into) -> Self { Self { server_addr: server_addr.into(), ..self @@ -231,11 +233,14 @@ impl Builder { } if let Ok(bind) = std::env::var("TOKIO_CONSOLE_BIND") { - self.server_addr = bind - .to_socket_addrs() - .expect("TOKIO_CONSOLE_BIND must be formatted as HOST:PORT, such as localhost:4321") - .next() - .expect("tokio console could not resolve TOKIO_CONSOLE_BIND"); + self.server_addr = ServerAddr::Tcp( + bind.to_socket_addrs() + .expect( + "TOKIO_CONSOLE_BIND must be formatted as HOST:PORT, such as localhost:4321", + ) + .next() + .expect("tokio console could not resolve TOKIO_CONSOLE_BIND"), + ); } if let Some(interval) = duration_from_env("TOKIO_CONSOLE_PUBLISH_INTERVAL") { @@ -456,6 +461,60 @@ impl Builder { } } +/// Specifies the address on which a [`Server`] should listen. +/// +/// [`Server`]: crate::Server +#[derive(Clone, Debug)] +#[non_exhaustive] +pub enum ServerAddr { + /// A TCP address. + Tcp(SocketAddr), + /// A Unix socket address. + #[cfg(unix)] + Unix(PathBuf), +} + +impl From for ServerAddr { + fn from(addr: SocketAddr) -> ServerAddr { + ServerAddr::Tcp(addr) + } +} + +impl From for ServerAddr { + fn from(addr: SocketAddrV4) -> ServerAddr { + ServerAddr::Tcp(addr.into()) + } +} + +impl From for ServerAddr { + fn from(addr: SocketAddrV6) -> ServerAddr { + ServerAddr::Tcp(addr.into()) + } +} + +impl From<(I, u16)> for ServerAddr +where + I: Into, +{ + fn from(pieces: (I, u16)) -> ServerAddr { + ServerAddr::Tcp(pieces.into()) + } +} + +#[cfg(unix)] +impl From for ServerAddr { + fn from(path: PathBuf) -> ServerAddr { + ServerAddr::Unix(path) + } +} + +#[cfg(unix)] +impl<'a> From<&'a Path> for ServerAddr { + fn from(path: &'a Path) -> ServerAddr { + ServerAddr::Unix(path.to_path_buf()) + } +} + /// Initializes the console [tracing `Subscriber`][sub] and starts the console /// subscriber [`Server`] on its own background thread. /// diff --git a/console-subscriber/src/lib.rs b/console-subscriber/src/lib.rs index bde80347c..18ea2ed12 100644 --- a/console-subscriber/src/lib.rs +++ b/console-subscriber/src/lib.rs @@ -5,7 +5,7 @@ use serde::Serialize; use std::{ cell::RefCell, fmt, - net::{IpAddr, Ipv4Addr, SocketAddr}, + net::{IpAddr, Ipv4Addr}, sync::{ atomic::{AtomicUsize, Ordering}, Arc, @@ -13,7 +13,11 @@ use std::{ time::{Duration, Instant}, }; use thread_local::ThreadLocal; +#[cfg(unix)] +use tokio::net::UnixListener; use tokio::sync::{mpsc, oneshot}; +#[cfg(unix)] +use tokio_stream::wrappers::UnixListenerStream; use tracing_core::{ span::{self, Id}, subscriber::{self, Subscriber}, @@ -37,6 +41,7 @@ mod visitors; use aggregator::Aggregator; pub use builder::Builder; +use builder::ServerAddr; use callsites::Callsites; use record::Recorder; use stack::SpanStack; @@ -134,7 +139,7 @@ pub struct ConsoleLayer { /// [cli]: https://crates.io/crates/tokio-console pub struct Server { subscribe: mpsc::Sender, - addr: SocketAddr, + addr: ServerAddr, aggregator: Option, client_buffer: usize, } @@ -945,13 +950,22 @@ impl Server { .take() .expect("cannot start server multiple times"); let aggregate = spawn_named(aggregate.run(), "console::aggregate"); - let addr = self.addr; - let serve = builder - .add_service(proto::instrument::instrument_server::InstrumentServer::new( - self, - )) - .serve(addr); - let res = spawn_named(serve, "console::serve").await; + let addr = self.addr.clone(); + let router = builder.add_service( + proto::instrument::instrument_server::InstrumentServer::new(self), + ); + let res = match addr { + ServerAddr::Tcp(addr) => { + let serve = router.serve(addr); + spawn_named(serve, "console::serve").await + } + #[cfg(unix)] + ServerAddr::Unix(path) => { + let incoming = UnixListener::bind(path)?; + let serve = router.serve_with_incoming(UnixListenerStream::new(incoming)); + spawn_named(serve, "console::serve").await + } + }; aggregate.abort(); res?.map_err(Into::into) } diff --git a/tokio-console/Cargo.toml b/tokio-console/Cargo.toml index 30496f672..69a4001e4 100644 --- a/tokio-console/Cargo.toml +++ b/tokio-console/Cargo.toml @@ -34,6 +34,7 @@ tokio = { version = "1", features = ["full", "rt-multi-thread"] } tonic = { version = "0.8", features = ["transport"] } futures = "0.3" tui = { version = "0.16.0", default-features = false, features = ["crossterm"] } +tower = "0.4.12" tracing = "0.1" tracing-subscriber = { version = "0.3.0", features = ["env-filter"] } tracing-journald = { version = "0.2", optional = true } diff --git a/tokio-console/src/conn.rs b/tokio-console/src/conn.rs index faf42e9e0..7269746a2 100644 --- a/tokio-console/src/conn.rs +++ b/tokio-console/src/conn.rs @@ -5,7 +5,12 @@ use console_api::instrument::{ use console_api::tasks::TaskDetails; use futures::stream::StreamExt; use std::{error::Error, pin::Pin, time::Duration}; -use tonic::{transport::Channel, transport::Uri, Streaming}; +#[cfg(unix)] +use tokio::net::UnixStream; +use tonic::{ + transport::{Channel, Endpoint, Uri}, + Streaming, +}; #[derive(Debug)] pub struct Connection { @@ -78,7 +83,34 @@ impl Connection { tokio::time::sleep(backoff).await; } let try_connect = async { - let mut client = InstrumentClient::connect(self.target.clone()).await?; + let channel = match self.target.scheme_str() { + #[cfg(unix)] + Some("file") => { + // Dummy endpoint is ignored by the connector. + let endpoint = Endpoint::from_static("http://localhost"); + // Reconstruct the full path, which will have been split + // between the host and path components of the URI. + let path = match (self.target.host(), self.target.path()) { + (None, _) => self.target.path().to_owned(), + (Some(host), "/") => host.to_owned(), + (Some(host), path) => format!("{host}{path}"), + }; + endpoint + .connect_with_connector(tower::service_fn(move |_| { + UnixStream::connect(path.clone()) + })) + .await? + } + #[cfg(not(unix))] + Some("file") => { + return Err("unix domain sockets are not supported on this platform".into()); + } + _ => { + let endpoint = Endpoint::try_from(self.target.clone())?; + endpoint.connect().await? + } + }; + let mut client = InstrumentClient::new(channel); let request = tonic::Request::new(InstrumentRequest {}); let stream = Box::new(client.watch_updates(request).await?.into_inner()); Ok::>(State::Connected { client, stream })