From f419e88a79250851f3acb6335beb3926cb78353d Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Beno=C3=AEt=20CORTIER?=
Date: Thu, 27 Jun 2024 10:36:37 +0900
Subject: [PATCH] fix(dgw): enforce recording policy

When the recording flag is set and the recording stream is closed, the
associated session is killed within 10 seconds.

Issue: DGW-86
---
 devolutions-gateway/src/proxy.rs     |   3 -
 devolutions-gateway/src/recording.rs |  61 ++++++++++++-
 devolutions-gateway/src/service.rs   |   7 +-
 devolutions-gateway/src/session.rs   | 128 ++++++++++++++++++++++++---
 tools/tokengen/src/main.rs           |  15 +++-
 5 files changed, 194 insertions(+), 20 deletions(-)

diff --git a/devolutions-gateway/src/proxy.rs b/devolutions-gateway/src/proxy.rs
index 15d0453b1..aa1339724 100644
--- a/devolutions-gateway/src/proxy.rs
+++ b/devolutions-gateway/src/proxy.rs
@@ -121,9 +121,6 @@ where
         )
         .await?;
 
-        // NOTE(DGW-86): when recording is required, should we wait for it to start before we forward, or simply spawn
-        // a timer to check if the recording is started within a few seconds?
-
         let kill_notified = notify_kill.notified();
 
         let res = if let Some(buffer_size) = self.buffer_size {
diff --git a/devolutions-gateway/src/recording.rs b/devolutions-gateway/src/recording.rs
index 925a3bce4..6ca100abe 100644
--- a/devolutions-gateway/src/recording.rs
+++ b/devolutions-gateway/src/recording.rs
@@ -16,6 +16,7 @@ use tokio::{fs, io};
 use typed_builder::TypedBuilder;
 use uuid::Uuid;
 
+use crate::session::SessionMessageSender;
 use crate::token::{JrecTokenClaims, RecordingFileType};
 
 const DISCONNECTED_TTL_SECS: i64 = 10;
@@ -162,6 +163,7 @@ struct OnGoingRecording {
     state: OnGoingRecordingState,
     manifest: JrecManifest,
     manifest_path: Utf8PathBuf,
+    session_must_be_recorded: bool,
 }
 
 enum RecordingManagerMessage {
@@ -310,14 +312,20 @@ pub struct RecordingManagerTask {
     rx: RecordingMessageReceiver,
     ongoing_recordings: HashMap<Uuid, OnGoingRecording>,
     recordings_path: Utf8PathBuf,
+    session_manager_handle: SessionMessageSender,
 }
 
 impl RecordingManagerTask {
-    pub fn new(rx: RecordingMessageReceiver, recordings_path: Utf8PathBuf) -> Self {
+    pub fn new(
+        rx: RecordingMessageReceiver,
+        recordings_path: Utf8PathBuf,
+        session_manager_handle: SessionMessageSender,
+    ) -> Self {
         Self {
            rx,
            ongoing_recordings: HashMap::new(),
            recordings_path,
+           session_manager_handle,
        }
    }
 
@@ -389,12 +397,26 @@
 
         let active_recording_count = self.rx.active_recordings.insert(id);
 
+        // NOTE: the session associated with this recording is not always running through the Devolutions Gateway.
+        // This is a normal situation when the Devolutions Gateway is used solely as a recording server.
+        // In such cases, we can only assume there is no recording policy.
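+        // Hence the `unwrap_or(false)` below: when the lookup fails or no session is
+        // found, we assume "no recording policy" rather than flagging the recording
+        // for enforcement against a session the gateway never saw.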
+        let session_must_be_recorded = self
+            .session_manager_handle
+            .get_session_info(id)
+            .await
+            .inspect_err(|error| error!(%error, session.id = %id, "Failed to retrieve session info"))
+            .ok()
+            .flatten()
+            .map(|info| info.recording_policy)
+            .unwrap_or(false);
+
         self.ongoing_recordings.insert(
             id,
             OnGoingRecording {
                 state: OnGoingRecordingState::Connected,
                 manifest,
                 manifest_path,
+                session_must_be_recorded,
             },
         );
 
         let ongoing_recording_count = self.ongoing_recordings.len();
@@ -453,9 +475,42 @@
             OnGoingRecordingState::LastSeen { timestamp } if now >= timestamp + DISCONNECTED_TTL_SECS - 1 => {
                 debug!(%id, "Mark recording as terminated");
                 self.rx.active_recordings.remove(id);
-                self.ongoing_recordings.remove(&id);
 
-                // TODO(DGW-86): now is a good timing to kill sessions that _must_ be recorded
+                // Check the recording policy of the associated session and kill it if necessary.
+                if ongoing.session_must_be_recorded {
+                    tokio::spawn({
+                        let session_manager_handle = self.session_manager_handle.clone();
+
+                        async move {
+                            let result = session_manager_handle.kill_session(id).await;
+
+                            match result {
+                                Ok(crate::session::KillResult::Success) => {
+                                    warn!(
+                                        session.id = %id,
+                                        reason = "recording policy violated",
+                                        "Session killed",
+                                    );
+                                }
+                                Ok(crate::session::KillResult::NotFound) => {
+                                    trace!(
+                                        session.id = %id,
+                                        "Associated session is not running, as expected",
+                                    );
+                                }
+                                Err(error) => {
+                                    error!(
+                                        session.id = %id,
+                                        %error,
+                                        "Couldn’t kill session",
+                                    )
+                                }
+                            }
+                        }
+                    });
+                }
+
+                self.ongoing_recordings.remove(&id);
             }
             _ => {
                 trace!(%id, "Recording should not be removed yet");
diff --git a/devolutions-gateway/src/service.rs b/devolutions-gateway/src/service.rs
index dceaf4cdd..b664505ff 100644
--- a/devolutions-gateway/src/service.rs
+++ b/devolutions-gateway/src/service.rs
@@ -208,7 +208,7 @@ async fn spawn_tasks(conf_handle: ConfHandle) -> anyhow::Result<Tasks> {
         sessions: session_manager_handle.clone(),
         subscriber_tx: subscriber_tx.clone(),
         shutdown_signal: tasks.shutdown_signal.clone(),
-        recordings: recording_manager_handle,
+        recordings: recording_manager_handle.clone(),
     };
 
     conf.listeners
@@ -243,7 +243,7 @@ async fn spawn_tasks(conf_handle: ConfHandle) -> anyhow::Result<Tasks> {
     ));
 
     tasks.register(devolutions_gateway::subscriber::SubscriberPollingTask {
-        sessions: session_manager_handle,
+        sessions: session_manager_handle.clone(),
         subscriber: subscriber_tx,
     });
 
@@ -253,12 +253,15 @@ async fn spawn_tasks(conf_handle: ConfHandle) -> anyhow::Result<Tasks> {
     });
 
     tasks.register(devolutions_gateway::session::SessionManagerTask::new(
+        session_manager_handle.clone(),
         session_manager_rx,
+        recording_manager_handle,
     ));
 
     tasks.register(devolutions_gateway::recording::RecordingManagerTask::new(
         recording_manager_rx,
         conf.recording_path.clone(),
+        session_manager_handle,
     ));
 
     Ok(tasks)
diff --git a/devolutions-gateway/src/session.rs b/devolutions-gateway/src/session.rs
index ca0d15df5..0d51d6475 100644
--- a/devolutions-gateway/src/session.rs
+++ b/devolutions-gateway/src/session.rs
@@ -1,3 +1,4 @@
+use crate::recording::RecordingMessageSender;
 use crate::subscriber;
 use crate::target_addr::TargetAddr;
 use crate::token::{ApplicationProtocol, SessionTtl};
@@ -133,6 +134,10 @@ enum SessionManagerMessage {
         info: SessionInfo,
         notify_kill: Arc<Notify>,
     },
+    GetInfo {
+        id: Uuid,
+        channel: oneshot::Sender<Option<SessionInfo>>,
+    },
     Remove {
         id: Uuid,
         channel: oneshot::Sender<Option<SessionInfo>>,
     },
@@ -155,6 +160,9 @@ impl fmt::Debug for SessionManagerMessage {
             SessionManagerMessage::New { info, notify_kill: _ } => {
                 f.debug_struct("New").field("info", info).finish_non_exhaustive()
             }
+            SessionManagerMessage::GetInfo { id, channel: _ } => {
+                f.debug_struct("GetInfo").field("id", id).finish_non_exhaustive()
+            }
             SessionManagerMessage::Remove { id, channel: _ } => {
                 f.debug_struct("Remove").field("id", id).finish_non_exhaustive()
             }
@@ -179,6 +187,16 @@ impl SessionMessageSender {
             .context("couldn't send New message")
     }
 
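+    /// Asks the session manager task for a snapshot of the `SessionInfo` associated
+    /// with `id`. Resolves to `None` when no such session is currently running.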
f.debug_struct("New").field("info", info).finish_non_exhaustive() } + SessionManagerMessage::GetInfo { id, channel: _ } => { + f.debug_struct("GetInfo").field("id", id).finish_non_exhaustive() + } SessionManagerMessage::Remove { id, channel: _ } => { f.debug_struct("Remove").field("id", id).finish_non_exhaustive() } @@ -179,6 +187,16 @@ impl SessionMessageSender { .context("couldn't send New message") } + pub async fn get_session_info(&self, id: Uuid) -> anyhow::Result> { + let (tx, rx) = oneshot::channel(); + self.0 + .send(SessionManagerMessage::GetInfo { id, channel: tx }) + .await + .ok() + .context("couldn't send Remove message")?; + rx.await.context("couldn't receive info for session") + } + pub async fn remove_session(&self, id: Uuid) -> anyhow::Result> { let (tx, rx) = oneshot::channel(); self.0 @@ -256,26 +274,48 @@ impl Ord for WithTtlInfo { } pub struct SessionManagerTask { + tx: SessionMessageSender, rx: SessionMessageReceiver, all_running: RunningSessions, all_notify_kill: HashMap>, + recording_manager_handle: RecordingMessageSender, } impl SessionManagerTask { - pub fn new(rx: SessionMessageReceiver) -> Self { + pub fn init(recording_manager_handle: RecordingMessageSender) -> Self { + let (tx, rx) = session_manager_channel(); + + Self::new(tx, rx, recording_manager_handle) + } + + pub fn new( + tx: SessionMessageSender, + rx: SessionMessageReceiver, + recording_manager_handle: RecordingMessageSender, + ) -> Self { Self { + tx, rx, all_running: HashMap::new(), all_notify_kill: HashMap::new(), + recording_manager_handle, } } + pub fn handle(&self) -> SessionMessageSender { + self.tx.clone() + } + fn handle_new(&mut self, info: SessionInfo, notify_kill: Arc) { let id = info.association_id; self.all_running.insert(id, info); self.all_notify_kill.insert(id, notify_kill); } + fn handle_get_info(&mut self, id: Uuid) -> Option { + self.all_running.get(&id).cloned() + } + fn handle_remove(&mut self, id: Uuid) -> Option { let removed_session = self.all_running.remove(&id); let _ = self.all_notify_kill.remove(&id); @@ -312,17 +352,14 @@ async fn session_manager_task( debug!("Task started"); let mut with_ttl = BinaryHeap::::new(); - let auto_kill_sleep = tokio::time::sleep_until(tokio::time::Instant::now()); tokio::pin!(auto_kill_sleep); - - // Consume initial sleep - (&mut auto_kill_sleep).await; + (&mut auto_kill_sleep).await; // Consume initial sleep. loop { tokio::select! { () = &mut auto_kill_sleep, if !with_ttl.is_empty() => { - // Will never panic since we check for non-emptiness before entering this block + // Will never panic since we check for non-emptiness before entering this block. let to_kill = with_ttl.pop().unwrap(); match manager.handle_kill(to_kill.session_id) { @@ -334,7 +371,7 @@ async fn session_manager_task( } } - // Re-arm the Sleep instance with the next deadline if required + // Re-arm the Sleep instance with the next deadline if required. 
                 if let Some(next) = with_ttl.peek() {
                     auto_kill_sleep.as_mut().reset(next.deadline)
                 }
             }
 
@@ -350,24 +387,41 @@ async fn session_manager_task(
                 match msg {
                     SessionManagerMessage::New { info, notify_kill } => {
                         if let SessionTtl::Limited { minutes } = info.time_to_live {
-                            let duration = Duration::from_secs(minutes.get() * 60);
                             let now = tokio::time::Instant::now();
+                            let duration = Duration::from_secs(minutes.get() * 60);
                             let deadline = now + duration;
+
                             with_ttl.push(WithTtlInfo {
                                 deadline,
                                 session_id: info.id(),
                             });
 
-                            // Reset the Sleep instance if the new deadline is sooner or it is already elapsed
+                            // Reset the Sleep instance if the new deadline is sooner or it is already elapsed.
                             if auto_kill_sleep.is_elapsed() || deadline < auto_kill_sleep.deadline() {
                                 auto_kill_sleep.as_mut().reset(deadline);
                             }
 
-                            debug!(session.id = %info.id(), minutes = minutes.get(), "Limited TTL session registed");
+                            debug!(session.id = %info.id(), minutes = minutes.get(), "Limited TTL session registered");
+                        }
+
+                        if info.recording_policy {
+                            let task = EnsureRecordingPolicyTask {
+                                session_id: info.id(),
+                                session_manager_handle: manager.tx.clone(),
+                                recording_manager_handle: manager.recording_manager_handle.clone(),
+                            };
+
+                            devolutions_gateway_task::spawn_task(task, shutdown_signal.clone()).detach();
+
+                            debug!(session.id = %info.id(), "Session with recording policy registered");
                         }
 
                         manager.handle_new(info, notify_kill);
-                    },
+                    }
+                    SessionManagerMessage::GetInfo { id, channel } => {
+                        let session_info = manager.handle_get_info(id);
+                        let _ = channel.send(session_info);
+                    }
                     SessionManagerMessage::Remove { id, channel } => {
                         let removed_session = manager.handle_remove(id);
                         let _ = channel.send(removed_session);
@@ -416,3 +470,55 @@ async fn session_manager_task(
 
     Ok(())
 }
+
+struct EnsureRecordingPolicyTask {
+    session_id: Uuid,
+    session_manager_handle: SessionMessageSender,
+    recording_manager_handle: RecordingMessageSender,
+}
+
+#[async_trait]
+impl Task for EnsureRecordingPolicyTask {
+    type Output = ();
+
+    const NAME: &'static str = "ensure recording policy";
+
+    async fn run(self, mut shutdown_signal: ShutdownSignal) -> Self::Output {
+        use futures::future::Either;
+        use std::pin::pin;
+
+        let sleep = tokio::time::sleep(Duration::from_secs(10));
+        let shutdown_signal = shutdown_signal.wait();
+
+        match futures::future::select(pin!(sleep), pin!(shutdown_signal)).await {
+            Either::Left(_) => {}
+            Either::Right(_) => return,
+        }
+
+        let is_not_recording = self
+            .recording_manager_handle
+            .get_state(self.session_id)
+            .await
+            .ok()
+            .flatten()
+            .is_none();
+
+        if is_not_recording {
+            match self.session_manager_handle.kill_session(self.session_id).await {
+                Ok(KillResult::Success) => {
+                    warn!(
+                        session.id = %self.session_id,
+                        reason = "recording policy violated",
+                        "Session killed",
+                    );
+                }
+                Ok(KillResult::NotFound) => {
+                    trace!(session.id = %self.session_id, "Session already ended");
+                }
+                Err(error) => {
+                    debug!(session.id = %self.session_id, error = format!("{error:#}"), "Couldn’t kill the session");
+                }
+            }
+        }
+    }
+}
diff --git a/tools/tokengen/src/main.rs b/tools/tokengen/src/main.rs
index 8d11545a6..acbc69ee7 100644
--- a/tools/tokengen/src/main.rs
+++ b/tools/tokengen/src/main.rs
@@ -32,6 +32,7 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
             jet_ap,
             jet_ttl,
             jet_aid,
+            jet_rec,
         } => {
             let claims = AssociationClaims {
                 exp,
@@ -40,6 +41,7 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
                 dst_hst: Some(&dst_hst),
                 jet_cm: "fwd",
                 jet_ap: jet_ap.unwrap_or(ApplicationProtocol::Unknown),
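+                // `jet_rec` carries the recording policy into the token: when true, the
+                // gateway now enforces that the session is actually recorded (DGW-86).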
+                jet_rec,
                 jet_aid: jet_aid.unwrap_or_else(Uuid::new_v4),
                 jet_ttl,
                 jet_gw_id: app.jet_gw_id,
@@ -62,6 +64,7 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
                 dst_hst: Some(&dst_hst),
                 jet_cm: "fwd",
                 jet_ap: ApplicationProtocol::Rdp,
+                jet_rec: false,
                 jet_aid: jet_aid.unwrap_or_else(Uuid::new_v4),
                 jet_ttl: None,
                 jet_gw_id: app.jet_gw_id,
@@ -74,7 +77,11 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
             };
             ("ASSOCIATION", serde_json::to_value(claims)?)
         }
-        SubCommand::Rendezvous { jet_ap, jet_aid } => {
+        SubCommand::Rendezvous {
+            jet_ap,
+            jet_aid,
+            jet_rec,
+        } => {
             let claims = AssociationClaims {
                 exp,
                 nbf,
@@ -82,6 +89,7 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
                 dst_hst: None,
                 jet_cm: "rdv",
                 jet_ap: jet_ap.unwrap_or(ApplicationProtocol::Unknown),
+                jet_rec,
                 jet_aid: jet_aid.unwrap_or_else(Uuid::new_v4),
                 jet_ttl: None,
                 jet_gw_id: app.jet_gw_id,
@@ -223,12 +231,16 @@ enum SubCommand {
         jet_ttl: Option<u64>,
         #[clap(long)]
         jet_aid: Option<Uuid>,
+        #[clap(long)]
+        jet_rec: bool,
     },
     Rendezvous {
         #[clap(long)]
         jet_ap: Option<ApplicationProtocol>,
         #[clap(long)]
         jet_aid: Option<Uuid>,
+        #[clap(long)]
+        jet_rec: bool,
     },
     RdpTls {
         #[clap(long)]
@@ -287,6 +299,7 @@ struct AssociationClaims<'a> {
     jti: Uuid,
     jet_cm: &'a str,
     jet_ap: ApplicationProtocol,
+    jet_rec: bool,
     jet_aid: Uuid,
     #[serde(skip_serializing_if = "Option::is_none")]
     jet_ttl: Option<u64>,
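+    // NOTE: unlike `jet_ttl` above, `jet_rec` has no skip attribute, so every
+    // generated token carries an explicit recording policy (with clap, the new
+    // `--jet-rec` flag defaults to false when omitted).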