From 2746a0ea88881b3c7e2f731f9e96d8b5866149a6 Mon Sep 17 00:00:00 2001 From: Yujia Qiao Date: Tue, 8 Mar 2022 23:27:18 +0800 Subject: [PATCH] feat: application layer heartbeat (#136) * feat: application layer heartbeat * feat: make heartbeat configurable * fix: update keepalive params * docs: update about heartbeat --- README.md | 10 +++--- src/client.rs | 30 ++++++++++++------ src/config.rs | 16 ++++++++++ src/lib.rs | 6 ++-- src/protocol.rs | 6 ++-- src/server.rs | 73 ++++++++++++++++++++++++++++++++------------ src/transport/mod.rs | 6 ++-- 7 files changed, 106 insertions(+), 41 deletions(-) diff --git a/README.md b/README.md index c8449842..988adac1 100644 --- a/README.md +++ b/README.md @@ -105,6 +105,7 @@ Here is the full configuration specification: [client] remote_addr = "example.com:2333" # Necessary. The address of the server default_token = "default_token_if_not_specify" # Optional. The default token of services, if they don't define their own ones +heartbeat_timeout = 40 # Optional. Set to 0 to disable the application-layer heartbeat test. The value must be greater than `server.heartbeat_interval`. Default: 40 secs [client.transport] # The whole block is optional. Specify which transport to use type = "tcp" # Optional. Possible values: ["tcp", "tls", "noise"]. Default: "tcp" @@ -112,8 +113,8 @@ type = "tcp" # Optional. Possible values: ["tcp", "tls", "noise"]. Default: "tcp [client.transport.tcp] # Optional proxy = "socks5://user:passwd@127.0.0.1:1080" # Optional. Use the proxy to connect to the server nodelay = false # Optional. Determine whether to enable TCP_NODELAY, if applicable, to improve the latency but decrease the bandwidth. Default: false -keepalive_secs = 10 # Optional. Specify `tcp_keepalive_time` in `tcp(7)`, if applicable. Default: 10 seconds -keepalive_interval = 5 # Optional. Specify `tcp_keepalive_intvl` in `tcp(7)`, if applicable. Default: 5 seconds +keepalive_secs = 20 # Optional. Specify `tcp_keepalive_time` in `tcp(7)`, if applicable. Default: 20 seconds +keepalive_interval = 8 # Optional. Specify `tcp_keepalive_intvl` in `tcp(7)`, if applicable. Default: 8 seconds [client.transport.tls] # Necessary if `type` is "tls" trusted_root = "ca.pem" # Necessary. The certificate of CA that signed the server's certificate @@ -136,12 +137,13 @@ local_addr = "127.0.0.1:1082" [server] bind_addr = "0.0.0.0:2333" # Necessary. The address that the server listens for clients. Generally only the port needs to be change. default_token = "default_token_if_not_specify" # Optional +heartbeat_interval = 30 # Optional. The interval between two application-layer heartbeat. Set to 0 to disable sending heartbeat. Default: 30 secs [server.transport] # Same as `[client.transport]` type = "tcp" nodelay = false -keepalive_secs = 10 -keepalive_interval = 5 +keepalive_secs = 20 +keepalive_interval = 8 [server.transport.tls] # Necessary if `type` is "tls" pkcs12 = "identify.pfx" # Necessary. pkcs12 file of server's certificate and private key diff --git a/src/client.rs b/src/client.rs index 8285c632..9c04944d 100644 --- a/src/client.rs +++ b/src/client.rs @@ -29,11 +29,11 @@ use crate::constants::{run_control_chan_backoff, UDP_BUFFER_SIZE, UDP_SENDQ_SIZE // The entrypoint of running a client pub async fn run_client( - config: &Config, + config: Config, shutdown_rx: broadcast::Receiver, service_rx: mpsc::Receiver, ) -> Result<()> { - let config = config.client.as_ref().ok_or(anyhow!( + let config = config.client.ok_or(anyhow!( "Try to run as a client, but the configuration is missing. Please add the `[client]` block" ))?; @@ -67,21 +67,21 @@ type ServiceDigest = protocol::Digest; type Nonce = protocol::Digest; // Holds the state of a client -struct Client<'a, T: Transport> { - config: &'a ClientConfig, +struct Client { + config: ClientConfig, service_handles: HashMap, transport: Arc, } -impl<'a, T: 'static + Transport> Client<'a, T> { +impl Client { // Create a Client from `[client]` config block - async fn from(config: &'a ClientConfig) -> Result> { + async fn from(config: ClientConfig) -> Result> { + let transport = + Arc::new(T::new(&config.transport).with_context(|| "Failed to create the transport")?); Ok(Client { config, service_handles: HashMap::new(), - transport: Arc::new( - T::new(&config.transport).with_context(|| "Failed to create the transport")?, - ), + transport, }) } @@ -97,6 +97,7 @@ impl<'a, T: 'static + Transport> Client<'a, T> { (*config).clone(), self.config.remote_addr.clone(), self.transport.clone(), + self.config.heartbeat_timeout, ); self.service_handles.insert(name.clone(), handle); } @@ -122,6 +123,7 @@ impl<'a, T: 'static + Transport> Client<'a, T> { s, self.config.remote_addr.clone(), self.transport.clone(), + self.config.heartbeat_timeout ); let _ = self.service_handles.insert(name, handle); }, @@ -369,6 +371,7 @@ struct ControlChannel { shutdown_rx: oneshot::Receiver, // Receives the shutdown signal remote_addr: String, // `client.remote_addr` transport: Arc, // Wrapper around the transport layer + heartbeat_timeout: u64, // Application layer heartbeat timeout in secs } // Handle of a control channel @@ -451,9 +454,14 @@ impl ControlChannel { warn!("{:#}", e); } }.instrument(Span::current())); - } + }, + ControlChannelCmd::HeartBeat => () } }, + _ = time::sleep(Duration::from_secs(self.heartbeat_timeout)), if self.heartbeat_timeout != 0 => { + warn!("Heartbeat timed out"); + break; + } _ = &mut self.shutdown_rx => { break; } @@ -471,6 +479,7 @@ impl ControlChannelHandle { service: ClientServiceConfig, remote_addr: String, transport: Arc, + heartbeat_timeout: u64, ) -> ControlChannelHandle { let digest = protocol::digest(service.name.as_bytes()); @@ -482,6 +491,7 @@ impl ControlChannelHandle { shutdown_rx, remote_addr, transport, + heartbeat_timeout, }; tokio::spawn( diff --git a/src/config.rs b/src/config.rs index 31e98af4..5ce96ce1 100644 --- a/src/config.rs +++ b/src/config.rs @@ -9,6 +9,10 @@ use url::Url; use crate::transport::{DEFAULT_KEEPALIVE_INTERVAL, DEFAULT_KEEPALIVE_SECS, DEFAULT_NODELAY}; +/// Application-layer heartbeat interval in secs +const DEFAULT_HEARTBEAT_INTERVAL_SECS: u64 = 30; +const DEFAULT_HEARTBEAT_TIMEOUT_SECS: u64 = 40; + /// String with Debug implementation that emits "MASKED" /// Used to mask sensitive strings when logging #[derive(Serialize, Deserialize, Default, PartialEq, Clone)] @@ -177,6 +181,10 @@ pub struct TransportConfig { pub noise: Option, } +fn default_heartbeat_timeout() -> u64 { + DEFAULT_HEARTBEAT_TIMEOUT_SECS +} + #[derive(Debug, Serialize, Deserialize, Default, PartialEq, Clone)] #[serde(deny_unknown_fields)] pub struct ClientConfig { @@ -185,6 +193,12 @@ pub struct ClientConfig { pub services: HashMap, #[serde(default)] pub transport: TransportConfig, + #[serde(default = "default_heartbeat_timeout")] + pub heartbeat_timeout: u64, +} + +fn default_heartbeat_interval() -> u64 { + DEFAULT_HEARTBEAT_INTERVAL_SECS } #[derive(Debug, Serialize, Deserialize, Default, PartialEq, Clone)] @@ -195,6 +209,8 @@ pub struct ServerConfig { pub services: HashMap, #[serde(default)] pub transport: TransportConfig, + #[serde(default = "default_heartbeat_interval")] + pub heartbeat_interval: u64, } #[derive(Debug, Serialize, Deserialize, PartialEq, Clone)] diff --git a/src/lib.rs b/src/lib.rs index 481200a2..c31da23b 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -93,7 +93,7 @@ pub async fn run(args: Cli, shutdown_rx: broadcast::Receiver) -> Result<() last_instance = Some(( tokio::spawn(run_instance( - *(config.clone()), + *config, args.clone(), shutdown_tx.subscribe(), service_update_rx, @@ -127,13 +127,13 @@ async fn run_instance( #[cfg(not(feature = "client"))] crate::helper::feature_not_compile("client"); #[cfg(feature = "client")] - run_client(&config, shutdown_rx, service_update).await + run_client(config, shutdown_rx, service_update).await } RunMode::Server => { #[cfg(not(feature = "server"))] crate::helper::feature_not_compile("server"); #[cfg(feature = "server")] - run_server(&config, shutdown_rx, service_update).await + run_server(config, shutdown_rx, service_update).await } }; ret.unwrap(); diff --git a/src/protocol.rs b/src/protocol.rs index 4419c856..577c7323 100644 --- a/src/protocol.rs +++ b/src/protocol.rs @@ -9,9 +9,10 @@ use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}; use tracing::trace; type ProtocolVersion = u8; -const PROTO_V0: u8 = 0u8; +const _PROTO_V0: u8 = 0u8; +const PROTO_V1: u8 = 1u8; -pub const CURRENT_PROTO_VERSION: ProtocolVersion = PROTO_V0; +pub const CURRENT_PROTO_VERSION: ProtocolVersion = PROTO_V1; pub type Digest = [u8; HASH_WIDTH_IN_BYTES]; @@ -48,6 +49,7 @@ impl std::fmt::Display for Ack { #[derive(Deserialize, Serialize, Debug)] pub enum ControlChannelCmd { CreateDataChannel, + HeartBeat, } #[derive(Deserialize, Serialize, Debug)] diff --git a/src/server.rs b/src/server.rs index abd181b6..57fb20a6 100644 --- a/src/server.rs +++ b/src/server.rs @@ -38,11 +38,11 @@ const HANDSHAKE_TIMEOUT: u64 = 5; // Timeout for transport handshake // The entrypoint of running a server pub async fn run_server( - config: &Config, + config: Config, shutdown_rx: broadcast::Receiver, service_rx: mpsc::Receiver, ) -> Result<()> { - let config = match &config.server { + let config = match config.server { Some(config) => config, None => { return Err(anyhow!("Try to run as a server, but the configuration is missing. Please add the `[server]` block")) @@ -82,9 +82,9 @@ pub async fn run_server( type ControlChannelMap = MultiMap>; // Server holds all states of running a server -struct Server<'a, T: Transport> { +struct Server { // `[server]` config - config: &'a ServerConfig, + config: Arc, // `[server.services]` config, indexed by ServiceDigest services: Arc>>, @@ -105,14 +105,18 @@ fn generate_service_hashmap( ret } -impl<'a, T: 'static + Transport> Server<'a, T> { +impl Server { // Create a server from `[server]` - pub async fn from(config: &'a ServerConfig) -> Result> { + pub async fn from(config: ServerConfig) -> Result> { + let config = Arc::new(config); + let services = Arc::new(RwLock::new(generate_service_hashmap(&config))); + let control_channels = Arc::new(RwLock::new(ControlChannelMap::new())); + let transport = Arc::new(T::new(&config.transport)?); Ok(Server { config, - services: Arc::new(RwLock::new(generate_service_hashmap(config))), - control_channels: Arc::new(RwLock::new(ControlChannelMap::new())), - transport: Arc::new(T::new(&config.transport)?), + services, + control_channels, + transport, }) } @@ -171,8 +175,9 @@ impl<'a, T: 'static + Transport> Server<'a, T> { Ok(conn) => { let services = self.services.clone(); let control_channels = self.control_channels.clone(); + let server_config = self.config.clone(); tokio::spawn(async move { - if let Err(err) = handle_connection(conn, services, control_channels).await { + if let Err(err) = handle_connection(conn, services, control_channels, server_config).await { error!("{:#}", err); } }.instrument(info_span!("connection", %addr))); @@ -233,12 +238,20 @@ async fn handle_connection( mut conn: T::Stream, services: Arc>>, control_channels: Arc>>, + server_config: Arc, ) -> Result<()> { // Read hello let hello = read_hello(&mut conn).await?; match hello { ControlChannelHello(_, service_digest) => { - do_control_channel_handshake(conn, services, control_channels, service_digest).await?; + do_control_channel_handshake( + conn, + services, + control_channels, + service_digest, + server_config, + ) + .await?; } DataChannelHello(_, nonce) => { do_data_channel_handshake(conn, control_channels, nonce).await?; @@ -252,6 +265,7 @@ async fn do_control_channel_handshake( services: Arc>>, control_channels: Arc>>, service_digest: ServiceDigest, + server_config: Arc, ) -> Result<()> { info!("Try to handshake a control channel"); @@ -321,7 +335,8 @@ async fn do_control_channel_handshake( conn.flush().await?; info!(service = %service_config.name, "Control channel established"); - let handle = ControlChannelHandle::new(conn, service_config); + let handle = + ControlChannelHandle::new(conn, service_config, server_config.heartbeat_interval); // Insert the new handle let _ = h.insert(service_digest, session_key, handle); @@ -371,7 +386,11 @@ where // Create a control channel handle, where the control channel handling task // and the connection pool task are created. #[instrument(name = "handle", skip_all, fields(service = %service.name))] - fn new(conn: T::Stream, service: ServerServiceConfig) -> ControlChannelHandle { + fn new( + conn: T::Stream, + service: ServerServiceConfig, + heartbeat_interval: u64, + ) -> ControlChannelHandle { // Create a shutdown channel let (shutdown_tx, shutdown_rx) = broadcast::channel::(1); @@ -435,6 +454,7 @@ where conn, shutdown_rx, data_ch_req_rx, + heartbeat_interval, }; // Run the control channel @@ -460,13 +480,26 @@ struct ControlChannel { conn: T::Stream, // The connection of control channel shutdown_rx: broadcast::Receiver, // Receives the shutdown signal data_ch_req_rx: mpsc::UnboundedReceiver, // Receives visitor connections + heartbeat_interval: u64, // Application-layer heartbeat interval in secs } impl ControlChannel { + async fn write_and_flush(&mut self, data: &[u8]) -> Result<()> { + self.conn + .write_all(data) + .await + .with_context(|| "Failed to write control cmds")?; + self.conn + .flush() + .await + .with_context(|| "Failed to flush control cmds")?; + Ok(()) + } // Run a control channel #[instrument(skip_all)] async fn run(mut self) -> Result<()> { - let cmd = bincode::serialize(&ControlChannelCmd::CreateDataChannel).unwrap(); + let create_ch_cmd = bincode::serialize(&ControlChannelCmd::CreateDataChannel).unwrap(); + let heartbeat = bincode::serialize(&ControlChannelCmd::HeartBeat).unwrap(); // Wait for data channel requests and the shutdown signal loop { @@ -474,11 +507,7 @@ impl ControlChannel { val = self.data_ch_req_rx.recv() => { match val { Some(_) => { - if let Err(e) = self.conn.write_all(&cmd).await.with_context(||"Failed to write control cmds") { - error!("{:#}", e); - break; - } - if let Err(e) = self.conn.flush().await.with_context(|| "Failed to flush control cmds") { + if let Err(e) = self.write_and_flush(&create_ch_cmd).await { error!("{:#}", e); break; } @@ -488,6 +517,12 @@ impl ControlChannel { } } }, + _ = time::sleep(Duration::from_secs(self.heartbeat_interval)), if self.heartbeat_interval != 0 => { + if let Err(e) = self.write_and_flush(&heartbeat).await { + error!("{:#}", e); + break; + } + } // Wait for the shutdown signal _ = self.shutdown_rx.recv() => { break; diff --git a/src/transport/mod.rs b/src/transport/mod.rs index c76d7f26..cc0e139f 100644 --- a/src/transport/mod.rs +++ b/src/transport/mod.rs @@ -9,10 +9,10 @@ use tokio::io::{AsyncRead, AsyncWrite}; use tokio::net::{TcpStream, ToSocketAddrs}; use tracing::{error, trace}; -pub static DEFAULT_NODELAY: bool = false; +pub const DEFAULT_NODELAY: bool = false; -pub static DEFAULT_KEEPALIVE_SECS: u64 = 10; -pub static DEFAULT_KEEPALIVE_INTERVAL: u64 = 3; +pub const DEFAULT_KEEPALIVE_SECS: u64 = 20; +pub const DEFAULT_KEEPALIVE_INTERVAL: u64 = 8; /// Specify a transport layer, like TCP, TLS #[async_trait]