diff --git a/russh/src/client/encrypted.rs b/russh/src/client/encrypted.rs index af2bb143..f237140f 100644 --- a/russh/src/client/encrypted.rs +++ b/russh/src/client/encrypted.rs @@ -525,7 +525,6 @@ impl Session { ); match req { b"xon-xoff" => { - self.activity = false; r.read_byte().map_err(crate::Error::from)?; // should be 0. let client_can_do = r.read_byte().map_err(crate::Error::from)? != 0; if let Some(chan) = self.channels.get(&channel_num) { @@ -572,7 +571,6 @@ impl Session { .await } b"keepalive@openssh.com" => { - self.activity = false; let wants_reply = r.read_byte().map_err(crate::Error::from)?; if wants_reply == 1 { if let Some(ref mut enc) = self.common.encrypted { @@ -592,7 +590,7 @@ impl Session { Ok((client, self)) } _ => { - self.activity = false; + self.common.received_data = false; let wants_reply = r.read_byte().map_err(crate::Error::from)?; if wants_reply == 1 { if let Some(ref mut enc) = self.common.encrypted { @@ -692,7 +690,7 @@ impl Session { push_packet!(enc.write, enc.write.push(msg::REQUEST_FAILURE)) } } - self.activity = false; + self.common.received_data = false; Ok((client, self)) } Some(&msg::CHANNEL_SUCCESS) => { @@ -806,15 +804,14 @@ impl Session { } } Some(&msg::REQUEST_SUCCESS | &msg::REQUEST_FAILURE) - if self.server_alive_timeouts > 0 => + if self.common.alive_timeouts > 0 => { - self.activity = false; // TODO what other things might need to happen in response to these two opcodes? - self.server_alive_timeouts = 0; + self.common.alive_timeouts = 0; Ok((client, self)) } _ => { - self.activity = false; + self.common.received_data = false; info!("Unhandled packet: {:?}", buf); Ok((client, self)) } diff --git a/russh/src/client/mod.rs b/russh/src/client/mod.rs index c0651c7d..d83df836 100644 --- a/russh/src/client/mod.rs +++ b/russh/src/client/mod.rs @@ -124,8 +124,6 @@ pub struct Session { pending_len: u32, inbound_channel_sender: Sender, inbound_channel_receiver: Receiver, - server_alive_timeouts: usize, - activity: bool, } impl Drop for Session { @@ -693,6 +691,8 @@ where wants_reply: false, disconnected: false, buffer: CryptoVec::new(), + alive_timeouts: 0, + received_data: false, }, session_receiver, session_sender, @@ -723,16 +723,6 @@ async fn start_reading( Ok((n, stream_read, buffer, cipher)) } -fn future_or_pending( - val: Option, - f: impl FnOnce(T) -> F, -) -> futures::future::Either::Output>, F> { - val.map_or( - futures::future::Either::Left(futures::future::pending()), - |x| futures::future::Either::Right(f(x)), - ) -} - impl Session { fn new( target_window_size: u32, @@ -751,8 +741,6 @@ impl Session { channels: HashMap::new(), pending_reads: Vec::new(), pending_len: 0, - server_alive_timeouts: 0, - activity: false, } } @@ -782,11 +770,11 @@ impl Session { std::mem::swap(&mut opening_cipher, &mut self.common.cipher.remote_to_local); let keepalive_timer = - future_or_pending(self.common.config.keepalive_interval, tokio::time::sleep); + crate::future_or_pending(self.common.config.keepalive_interval, tokio::time::sleep); pin!(keepalive_timer); let inactivity_timer = - future_or_pending(self.common.config.inactivity_timeout, tokio::time::sleep); + crate::future_or_pending(self.common.config.inactivity_timeout, tokio::time::sleep); pin!(inactivity_timer); let reading = start_reading(stream_read, buffer, opening_cipher); @@ -794,20 +782,9 @@ impl Session { #[allow(clippy::panic)] // false positive in select! macro while !self.common.disconnected { - self.activity = false; + self.common.received_data = false; + let mut only_sent_keepalive = false; tokio::select! { - () = &mut keepalive_timer => { - self.send_keepalive(true); - if self.common.config.keepalive_max != 0 && self.server_alive_timeouts > self.common.config.keepalive_max { - debug!("Timeout, server not responding to keepalives"); - break - } - self.server_alive_timeouts = self.server_alive_timeouts.saturating_add(1); - } - () = &mut inactivity_timer => { - debug!("timeout"); - break - } r = &mut reading => { let (stream_read, buffer, mut opening_cipher) = match r { Ok((_, stream_read, buffer, opening_cipher)) => (stream_read, buffer, opening_cipher), @@ -839,7 +816,7 @@ impl Session { if buf[0] == crate::msg::DISCONNECT { break; } else if buf[0] > 4 { - self.activity = true; + self.common.received_data = true; let (h, s) = reply(self, handler, &mut encrypted_signal, buf).await?; handler = h; self = s; @@ -849,6 +826,19 @@ impl Session { std::mem::swap(&mut opening_cipher, &mut self.common.cipher.remote_to_local); reading.set(start_reading(stream_read, buffer, opening_cipher)); } + () = &mut keepalive_timer => { + self.send_keepalive(true); + only_sent_keepalive = true; + if self.common.config.keepalive_max != 0 && self.common.alive_timeouts > self.common.config.keepalive_max { + debug!("Timeout, server not responding to keepalives"); + break + } + self.common.alive_timeouts = self.common.alive_timeouts.saturating_add(1); + } + () = &mut inactivity_timer => { + debug!("timeout"); + break + } msg = self.receiver.recv(), if !self.is_rekeying() => { match msg { Some(msg) => self.handle_msg(msg)?, @@ -888,7 +878,6 @@ impl Session { "writing to stream: {:?} bytes", self.common.write_buffer.buffer.len() ); - self.activity = true; stream_write .write_all(&self.common.write_buffer.buffer) .await @@ -903,14 +892,15 @@ impl Session { } } - if let (futures::future::Either::Right(ref mut sleep), Some(d)) = ( - keepalive_timer.as_mut().as_pin_mut(), - self.common.config.keepalive_interval, - ) { - sleep.as_mut().reset(tokio::time::Instant::now() + d); + if self.common.received_data { + if let (futures::future::Either::Right(ref mut sleep), Some(d)) = ( + keepalive_timer.as_mut().as_pin_mut(), + self.common.config.keepalive_interval, + ) { + sleep.as_mut().reset(tokio::time::Instant::now() + d); + } } - - if self.activity { + if !only_sent_keepalive { if let (futures::future::Either::Right(ref mut sleep), Some(d)) = ( inactivity_timer.as_mut().as_pin_mut(), self.common.config.inactivity_timeout, @@ -1315,7 +1305,7 @@ pub struct Config { pub preferred: negotiation::Preferred, /// Time after which the connection is garbage-collected. pub inactivity_timeout: Option, - /// If nothing is sent or received for this amount of time, send a keepalive message. + /// If nothing is received from the server for this amount of time, send a keepalive message. pub keepalive_interval: Option, /// If this many keepalives have been sent without reply, close the connection. pub keepalive_max: usize, diff --git a/russh/src/lib.rs b/russh/src/lib.rs index a9eed8bf..237edaa5 100644 --- a/russh/src/lib.rs +++ b/russh/src/lib.rs @@ -466,10 +466,12 @@ impl ChannelParams { } } -pub(crate) async fn timeout(delay: Option) { - if let Some(delay) = delay { - tokio::time::sleep(delay).await - } else { - futures::future::pending().await - }; +pub(crate) fn future_or_pending( + val: Option, + f: impl FnOnce(T) -> F, +) -> futures::future::Either::Output>, F> { + val.map_or( + futures::future::Either::Left(futures::future::pending()), + |x| futures::future::Either::Right(f(x)), + ) } diff --git a/russh/src/server/encrypted.rs b/russh/src/server/encrypted.rs index 354eddbd..00af6259 100644 --- a/russh/src/server/encrypted.rs +++ b/russh/src/server/encrypted.rs @@ -948,6 +948,7 @@ impl Session { handler.signal(channel_num, signal, self).await } x => { + self.common.received_data = false; warn!("unknown channel request {}", String::from_utf8_lossy(x)); self.channel_failure(channel_num); Ok((handler, self)) @@ -1001,6 +1002,7 @@ impl Session { Ok((h, s)) } _ => { + self.common.received_data = false; if let Some(ref mut enc) = self.common.encrypted { push_packet!(enc.write, { enc.write.push(msg::REQUEST_FAILURE); @@ -1040,6 +1042,7 @@ impl Session { Ok((handler, self)) } m => { + self.common.received_data = false; debug!("unknown message received: {:?}", m); Ok((handler, self)) } diff --git a/russh/src/server/mod.rs b/russh/src/server/mod.rs index ab0c9ec8..b5077f7d 100644 --- a/russh/src/server/mod.rs +++ b/russh/src/server/mod.rs @@ -168,6 +168,10 @@ pub struct Config { pub max_auth_attempts: usize, /// Time after which the connection is garbage-collected. pub inactivity_timeout: Option, + /// If nothing is received from the client for this amount of time, send a keepalive message. + pub keepalive_interval: Option, + /// If this many keepalives have been sent without reply, close the connection. + pub keepalive_max: usize, } impl Default for Config { @@ -190,6 +194,8 @@ impl Default for Config { preferred: Default::default(), max_auth_attempts: 10, inactivity_timeout: Some(std::time::Duration::from_secs(600)), + keepalive_interval: None, + keepalive_max: 3, } } } @@ -805,6 +811,8 @@ async fn read_ssh_id( wants_reply: false, disconnected: false, buffer: CryptoVec::new(), + alive_timeouts: 0, + received_data: false, }) } diff --git a/russh/src/server/session.rs b/russh/src/server/session.rs index 9361e579..69ab59d9 100644 --- a/russh/src/server/session.rs +++ b/russh/src/server/session.rs @@ -350,14 +350,23 @@ impl Session { let mut opening_cipher = Box::new(clear::Key) as Box; std::mem::swap(&mut opening_cipher, &mut self.common.cipher.remote_to_local); + let keepalive_timer = + future_or_pending(self.common.config.keepalive_interval, tokio::time::sleep); + pin!(keepalive_timer); + + let inactivity_timer = + future_or_pending(self.common.config.inactivity_timeout, tokio::time::sleep); + pin!(inactivity_timer); + let reading = start_reading(stream_read, buffer, opening_cipher); pin!(reading); let mut is_reading = None; let mut decomp = CryptoVec::new(); - let delay = self.common.config.inactivity_timeout; #[allow(clippy::panic)] // false positive in macro while !self.common.disconnected { + self.common.received_data = false; + let mut only_sent_keepalive = false; tokio::select! { r = &mut reading => { let (stream_read, buffer, mut opening_cipher) = match r { @@ -391,6 +400,7 @@ impl Session { is_reading = Some((stream_read, buffer, opening_cipher)); break; } else if buf[0] > 4 { + self.common.received_data = true; std::mem::swap(&mut opening_cipher, &mut self.common.cipher.remote_to_local); // TODO it'd be cleaner to just pass cipher to reply() match reply(self, handler, buf).await { @@ -405,10 +415,19 @@ impl Session { } reading.set(start_reading(stream_read, buffer, opening_cipher)); } - _ = timeout(delay) => { + () = &mut keepalive_timer => { + only_sent_keepalive = true; + self.keepalive_request(); + if self.common.config.keepalive_max != 0 && self.common.alive_timeouts > self.common.config.keepalive_max { + debug!("Timeout, server not responding to keepalives"); + break + } + self.common.alive_timeouts = self.common.alive_timeouts.saturating_add(1); + } + () = &mut inactivity_timer => { debug!("timeout"); break - }, + } msg = self.receiver.recv(), if !self.is_rekeying() => { match msg { Some(Msg::Channel(id, ChannelMsg::Data { data })) => { @@ -480,6 +499,23 @@ impl Session { .await .map_err(crate::Error::from)?; self.common.write_buffer.buffer.clear(); + + if self.common.received_data { + if let (futures::future::Either::Right(ref mut sleep), Some(d)) = ( + keepalive_timer.as_mut().as_pin_mut(), + self.common.config.keepalive_interval, + ) { + sleep.as_mut().reset(tokio::time::Instant::now() + d); + } + } + if !only_sent_keepalive { + if let (futures::future::Either::Right(ref mut sleep), Some(d)) = ( + inactivity_timer.as_mut().as_pin_mut(), + self.common.config.inactivity_timeout, + ) { + sleep.as_mut().reset(tokio::time::Instant::now() + d); + } + } } debug!("disconnected"); // Shutdown @@ -722,6 +758,18 @@ impl Session { } } + /// Ping the client to verify there is still connectivity. + pub fn keepalive_request(&mut self) { + let want_reply = u8::from(true); + if let Some(ref mut enc) = self.common.encrypted { + push_packet!(enc.write, { + enc.write.push(msg::GLOBAL_REQUEST); + enc.write.extend_ssh_string(b"keepalive@openssh.com"); + enc.write.push(want_reply); + }) + } + } + /// Send the exit status of a program. pub fn exit_status_request(&mut self, channel: ChannelId, exit_status: u32) { if let Some(ref mut enc) = self.common.encrypted { diff --git a/russh/src/session.rs b/russh/src/session.rs index 09afa95a..f2790687 100644 --- a/russh/src/session.rs +++ b/russh/src/session.rs @@ -63,6 +63,8 @@ pub(crate) struct CommonSession { pub wants_reply: bool, pub disconnected: bool, pub buffer: CryptoVec, + pub alive_timeouts: usize, + pub received_data: bool, } impl CommonSession {