Skip to content

Commit

Permalink
Add ClientAliveInterval analogue to server
Browse files Browse the repository at this point in the history
Also bring client and server into parity regarding timers.

Also, per OpenSSH documentation, only reset keepalive timer when
receiving data, not when sending it.

Also, always reset the inactivity timer unless the iteration was ended
via sending a keepalive request.
  • Loading branch information
mmirate committed Dec 6, 2023
1 parent b1aab65 commit 7b9a3f0
Show file tree
Hide file tree
Showing 7 changed files with 106 additions and 56 deletions.
13 changes: 5 additions & 8 deletions russh/src/client/encrypted.rs
Original file line number Diff line number Diff line change
Expand Up @@ -525,7 +525,6 @@ impl Session {
);
match req {
b"xon-xoff" => {
self.activity = false;
r.read_byte().map_err(crate::Error::from)?; // should be 0.
let client_can_do = r.read_byte().map_err(crate::Error::from)? != 0;
if let Some(chan) = self.channels.get(&channel_num) {
Expand Down Expand Up @@ -572,7 +571,6 @@ impl Session {
.await
}
b"keepalive@openssh.com" => {
self.activity = false;
let wants_reply = r.read_byte().map_err(crate::Error::from)?;
if wants_reply == 1 {
if let Some(ref mut enc) = self.common.encrypted {
Expand All @@ -592,7 +590,7 @@ impl Session {
Ok((client, self))
}
_ => {
self.activity = false;
self.common.received_data = false;
let wants_reply = r.read_byte().map_err(crate::Error::from)?;
if wants_reply == 1 {
if let Some(ref mut enc) = self.common.encrypted {
Expand Down Expand Up @@ -692,7 +690,7 @@ impl Session {
push_packet!(enc.write, enc.write.push(msg::REQUEST_FAILURE))
}
}
self.activity = false;
self.common.received_data = false;
Ok((client, self))
}
Some(&msg::CHANNEL_SUCCESS) => {
Expand Down Expand Up @@ -806,15 +804,14 @@ impl Session {
}
}
Some(&msg::REQUEST_SUCCESS | &msg::REQUEST_FAILURE)
if self.server_alive_timeouts > 0 =>
if self.common.alive_timeouts > 0 =>
{
self.activity = false;
// TODO what other things might need to happen in response to these two opcodes?
self.server_alive_timeouts = 0;
self.common.alive_timeouts = 0;
Ok((client, self))
}
_ => {
self.activity = false;
self.common.received_data = false;
info!("Unhandled packet: {:?}", buf);
Ok((client, self))
}
Expand Down
68 changes: 29 additions & 39 deletions russh/src/client/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -124,8 +124,6 @@ pub struct Session {
pending_len: u32,
inbound_channel_sender: Sender<Msg>,
inbound_channel_receiver: Receiver<Msg>,
server_alive_timeouts: usize,
activity: bool,
}

impl Drop for Session {
Expand Down Expand Up @@ -693,6 +691,8 @@ where
wants_reply: false,
disconnected: false,
buffer: CryptoVec::new(),
alive_timeouts: 0,
received_data: false,
},
session_receiver,
session_sender,
Expand Down Expand Up @@ -723,16 +723,6 @@ async fn start_reading<R: AsyncRead + Unpin>(
Ok((n, stream_read, buffer, cipher))
}

fn future_or_pending<F: futures::Future, T>(
val: Option<T>,
f: impl FnOnce(T) -> F,
) -> futures::future::Either<futures::future::Pending<<F as futures::Future>::Output>, F> {
val.map_or(
futures::future::Either::Left(futures::future::pending()),
|x| futures::future::Either::Right(f(x)),
)
}

impl Session {
fn new(
target_window_size: u32,
Expand All @@ -751,8 +741,6 @@ impl Session {
channels: HashMap::new(),
pending_reads: Vec::new(),
pending_len: 0,
server_alive_timeouts: 0,
activity: false,
}
}

Expand Down Expand Up @@ -782,32 +770,21 @@ impl Session {
std::mem::swap(&mut opening_cipher, &mut self.common.cipher.remote_to_local);

let keepalive_timer =
future_or_pending(self.common.config.keepalive_interval, tokio::time::sleep);
crate::future_or_pending(self.common.config.keepalive_interval, tokio::time::sleep);
pin!(keepalive_timer);

let inactivity_timer =
future_or_pending(self.common.config.inactivity_timeout, tokio::time::sleep);
crate::future_or_pending(self.common.config.inactivity_timeout, tokio::time::sleep);
pin!(inactivity_timer);

let reading = start_reading(stream_read, buffer, opening_cipher);
pin!(reading);

#[allow(clippy::panic)] // false positive in select! macro
while !self.common.disconnected {
self.activity = false;
self.common.received_data = false;
let mut only_sent_keepalive = false;
tokio::select! {
() = &mut keepalive_timer => {
self.send_keepalive(true);
if self.common.config.keepalive_max != 0 && self.server_alive_timeouts > self.common.config.keepalive_max {
debug!("Timeout, server not responding to keepalives");
break
}
self.server_alive_timeouts = self.server_alive_timeouts.saturating_add(1);
}
() = &mut inactivity_timer => {
debug!("timeout");
break
}
r = &mut reading => {
let (stream_read, buffer, mut opening_cipher) = match r {
Ok((_, stream_read, buffer, opening_cipher)) => (stream_read, buffer, opening_cipher),
Expand Down Expand Up @@ -839,7 +816,7 @@ impl Session {
if buf[0] == crate::msg::DISCONNECT {
break;
} else if buf[0] > 4 {
self.activity = true;
self.common.received_data = true;
let (h, s) = reply(self, handler, &mut encrypted_signal, buf).await?;
handler = h;
self = s;
Expand All @@ -849,6 +826,19 @@ impl Session {
std::mem::swap(&mut opening_cipher, &mut self.common.cipher.remote_to_local);
reading.set(start_reading(stream_read, buffer, opening_cipher));
}
() = &mut keepalive_timer => {
self.send_keepalive(true);
only_sent_keepalive = true;
if self.common.config.keepalive_max != 0 && self.common.alive_timeouts > self.common.config.keepalive_max {
debug!("Timeout, server not responding to keepalives");
break
}
self.common.alive_timeouts = self.common.alive_timeouts.saturating_add(1);
}
() = &mut inactivity_timer => {
debug!("timeout");
break
}
msg = self.receiver.recv(), if !self.is_rekeying() => {
match msg {
Some(msg) => self.handle_msg(msg)?,
Expand Down Expand Up @@ -888,7 +878,6 @@ impl Session {
"writing to stream: {:?} bytes",
self.common.write_buffer.buffer.len()
);
self.activity = true;
stream_write
.write_all(&self.common.write_buffer.buffer)
.await
Expand All @@ -903,14 +892,15 @@ impl Session {
}
}

if let (futures::future::Either::Right(ref mut sleep), Some(d)) = (
keepalive_timer.as_mut().as_pin_mut(),
self.common.config.keepalive_interval,
) {
sleep.as_mut().reset(tokio::time::Instant::now() + d);
if self.common.received_data {
if let (futures::future::Either::Right(ref mut sleep), Some(d)) = (
keepalive_timer.as_mut().as_pin_mut(),
self.common.config.keepalive_interval,
) {
sleep.as_mut().reset(tokio::time::Instant::now() + d);
}
}

if self.activity {
if !only_sent_keepalive {
if let (futures::future::Either::Right(ref mut sleep), Some(d)) = (
inactivity_timer.as_mut().as_pin_mut(),
self.common.config.inactivity_timeout,
Expand Down Expand Up @@ -1315,7 +1305,7 @@ pub struct Config {
pub preferred: negotiation::Preferred,
/// Time after which the connection is garbage-collected.
pub inactivity_timeout: Option<std::time::Duration>,
/// If nothing is sent or received for this amount of time, send a keepalive message.
/// If nothing is received from the server for this amount of time, send a keepalive message.
pub keepalive_interval: Option<std::time::Duration>,
/// If this many keepalives have been sent without reply, close the connection.
pub keepalive_max: usize,
Expand Down
14 changes: 8 additions & 6 deletions russh/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -466,10 +466,12 @@ impl ChannelParams {
}
}

pub(crate) async fn timeout(delay: Option<std::time::Duration>) {
if let Some(delay) = delay {
tokio::time::sleep(delay).await
} else {
futures::future::pending().await
};
pub(crate) fn future_or_pending<F: futures::Future, T>(
val: Option<T>,
f: impl FnOnce(T) -> F,
) -> futures::future::Either<futures::future::Pending<<F as futures::Future>::Output>, F> {
val.map_or(
futures::future::Either::Left(futures::future::pending()),
|x| futures::future::Either::Right(f(x)),
)
}
3 changes: 3 additions & 0 deletions russh/src/server/encrypted.rs
Original file line number Diff line number Diff line change
Expand Up @@ -948,6 +948,7 @@ impl Session {
handler.signal(channel_num, signal, self).await
}
x => {
self.common.received_data = false;
warn!("unknown channel request {}", String::from_utf8_lossy(x));
self.channel_failure(channel_num);
Ok((handler, self))
Expand Down Expand Up @@ -1001,6 +1002,7 @@ impl Session {
Ok((h, s))
}
_ => {
self.common.received_data = false;
if let Some(ref mut enc) = self.common.encrypted {
push_packet!(enc.write, {
enc.write.push(msg::REQUEST_FAILURE);
Expand Down Expand Up @@ -1040,6 +1042,7 @@ impl Session {
Ok((handler, self))
}
m => {
self.common.received_data = false;
debug!("unknown message received: {:?}", m);
Ok((handler, self))
}
Expand Down
8 changes: 8 additions & 0 deletions russh/src/server/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,10 @@ pub struct Config {
pub max_auth_attempts: usize,
/// Time after which the connection is garbage-collected.
pub inactivity_timeout: Option<std::time::Duration>,
/// If nothing is received from the client for this amount of time, send a keepalive message.
pub keepalive_interval: Option<std::time::Duration>,
/// If this many keepalives have been sent without reply, close the connection.
pub keepalive_max: usize,
}

impl Default for Config {
Expand All @@ -190,6 +194,8 @@ impl Default for Config {
preferred: Default::default(),
max_auth_attempts: 10,
inactivity_timeout: Some(std::time::Duration::from_secs(600)),
keepalive_interval: None,
keepalive_max: 3,
}
}
}
Expand Down Expand Up @@ -805,6 +811,8 @@ async fn read_ssh_id<R: AsyncRead + Unpin>(
wants_reply: false,
disconnected: false,
buffer: CryptoVec::new(),
alive_timeouts: 0,
received_data: false,
})
}

Expand Down
54 changes: 51 additions & 3 deletions russh/src/server/session.rs
Original file line number Diff line number Diff line change
Expand Up @@ -350,14 +350,23 @@ impl Session {
let mut opening_cipher = Box::new(clear::Key) as Box<dyn OpeningKey + Send>;
std::mem::swap(&mut opening_cipher, &mut self.common.cipher.remote_to_local);

let keepalive_timer =
future_or_pending(self.common.config.keepalive_interval, tokio::time::sleep);
pin!(keepalive_timer);

let inactivity_timer =
future_or_pending(self.common.config.inactivity_timeout, tokio::time::sleep);
pin!(inactivity_timer);

let reading = start_reading(stream_read, buffer, opening_cipher);
pin!(reading);
let mut is_reading = None;
let mut decomp = CryptoVec::new();
let delay = self.common.config.inactivity_timeout;

#[allow(clippy::panic)] // false positive in macro
while !self.common.disconnected {
self.common.received_data = false;
let mut only_sent_keepalive = false;
tokio::select! {
r = &mut reading => {
let (stream_read, buffer, mut opening_cipher) = match r {
Expand Down Expand Up @@ -391,6 +400,7 @@ impl Session {
is_reading = Some((stream_read, buffer, opening_cipher));
break;
} else if buf[0] > 4 {
self.common.received_data = true;
std::mem::swap(&mut opening_cipher, &mut self.common.cipher.remote_to_local);
// TODO it'd be cleaner to just pass cipher to reply()
match reply(self, handler, buf).await {
Expand All @@ -405,10 +415,19 @@ impl Session {
}
reading.set(start_reading(stream_read, buffer, opening_cipher));
}
_ = timeout(delay) => {
() = &mut keepalive_timer => {
only_sent_keepalive = true;
self.keepalive_request();
if self.common.config.keepalive_max != 0 && self.common.alive_timeouts > self.common.config.keepalive_max {
debug!("Timeout, server not responding to keepalives");
break
}
self.common.alive_timeouts = self.common.alive_timeouts.saturating_add(1);
}
() = &mut inactivity_timer => {
debug!("timeout");
break
},
}
msg = self.receiver.recv(), if !self.is_rekeying() => {
match msg {
Some(Msg::Channel(id, ChannelMsg::Data { data })) => {
Expand Down Expand Up @@ -480,6 +499,23 @@ impl Session {
.await
.map_err(crate::Error::from)?;
self.common.write_buffer.buffer.clear();

if self.common.received_data {
if let (futures::future::Either::Right(ref mut sleep), Some(d)) = (
keepalive_timer.as_mut().as_pin_mut(),
self.common.config.keepalive_interval,
) {
sleep.as_mut().reset(tokio::time::Instant::now() + d);
}
}
if !only_sent_keepalive {
if let (futures::future::Either::Right(ref mut sleep), Some(d)) = (
inactivity_timer.as_mut().as_pin_mut(),
self.common.config.inactivity_timeout,
) {
sleep.as_mut().reset(tokio::time::Instant::now() + d);
}
}
}
debug!("disconnected");
// Shutdown
Expand Down Expand Up @@ -722,6 +758,18 @@ impl Session {
}
}

/// Ping the client to verify there is still connectivity.
pub fn keepalive_request(&mut self) {
let want_reply = u8::from(true);
if let Some(ref mut enc) = self.common.encrypted {
push_packet!(enc.write, {
enc.write.push(msg::GLOBAL_REQUEST);
enc.write.extend_ssh_string(b"keepalive@openssh.com");
enc.write.push(want_reply);
})
}
}

/// Send the exit status of a program.
pub fn exit_status_request(&mut self, channel: ChannelId, exit_status: u32) {
if let Some(ref mut enc) = self.common.encrypted {
Expand Down
2 changes: 2 additions & 0 deletions russh/src/session.rs
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,8 @@ pub(crate) struct CommonSession<Config> {
pub wants_reply: bool,
pub disconnected: bool,
pub buffer: CryptoVec,
pub alive_timeouts: usize,
pub received_data: bool,
}

impl<C> CommonSession<C> {
Expand Down

0 comments on commit 7b9a3f0

Please sign in to comment.