fix(dgw): enforce recording policy #906

Merged · 1 commit · Jul 1, 2024
devolutions-gateway/src/proxy.rs (0 additions, 3 deletions)

@@ -121,9 +121,6 @@ where
         )
         .await?;

-        // NOTE(DGW-86): when recording is required, should we wait for it to start before we forward, or simply spawn
-        // a timer to check if the recording is started within a few seconds?
-
         let kill_notified = notify_kill.notified();

         let res = if let Some(buffer_size) = self.buffer_size {
devolutions-gateway/src/recording.rs (58 additions, 3 deletions)

@@ -16,6 +16,7 @@ use tokio::{fs, io};
 use typed_builder::TypedBuilder;
 use uuid::Uuid;

+use crate::session::SessionMessageSender;
 use crate::token::{JrecTokenClaims, RecordingFileType};

 const DISCONNECTED_TTL_SECS: i64 = 10;
@@ -162,6 +163,7 @@ struct OnGoingRecording {
     state: OnGoingRecordingState,
     manifest: JrecManifest,
     manifest_path: Utf8PathBuf,
+    session_must_be_recorded: bool,
 }

 enum RecordingManagerMessage {
@@ -310,14 +312,20 @@ pub struct RecordingManagerTask {
     rx: RecordingMessageReceiver,
     ongoing_recordings: HashMap<Uuid, OnGoingRecording>,
     recordings_path: Utf8PathBuf,
+    session_manager_handle: SessionMessageSender,
 }

 impl RecordingManagerTask {
-    pub fn new(rx: RecordingMessageReceiver, recordings_path: Utf8PathBuf) -> Self {
+    pub fn new(
+        rx: RecordingMessageReceiver,
+        recordings_path: Utf8PathBuf,
+        session_manager_handle: SessionMessageSender,
+    ) -> Self {
         Self {
             rx,
             ongoing_recordings: HashMap::new(),
             recordings_path,
+            session_manager_handle,
         }
     }

@@ -389,12 +397,26 @@ impl RecordingManagerTask {

let active_recording_count = self.rx.active_recordings.insert(id);

// NOTE: the session associated to this recording is not always running through the Devolutions Gateway.
// It is a normal situation when the Devolutions is used solely as a recording server.
// In such cases, we can only assume there is no recording policy.
let session_must_be_recorded = self
.session_manager_handle
.get_session_info(id)
.await
.inspect_err(|error| error!(%error, session.id = %id, "Failed to retrieve session info"))
.ok()
.flatten()
.map(|info| info.recording_policy)
.unwrap_or(false);

self.ongoing_recordings.insert(
id,
OnGoingRecording {
state: OnGoingRecordingState::Connected,
manifest,
manifest_path,
session_must_be_recorded,
},
);
let ongoing_recording_count = self.ongoing_recordings.len();
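Note that the lookup above fails open: an unreachable session manager and an unknown session both collapse to false, matching the NOTE's assumption that a Gateway acting purely as a recording server has no policy to enforce. A minimal sketch of the same chain, where Info is a hypothetical stand-in for SessionInfo and anyhow is the only dependency:

    // `Info` is a hypothetical stand-in for SessionInfo.
    struct Info {
        recording_policy: bool,
    }

    fn must_be_recorded(lookup: anyhow::Result<Option<Info>>) -> bool {
        lookup
            .ok() // Query failed (manager unreachable): assume no policy.
            .flatten() // Session not managed by this Gateway: assume no policy.
            .map(|info| info.recording_policy)
            .unwrap_or(false)
    }

    fn main() {
        assert!(!must_be_recorded(Err(anyhow::anyhow!("channel closed"))));
        assert!(!must_be_recorded(Ok(None)));
        assert!(must_be_recorded(Ok(Some(Info { recording_policy: true }))));
    }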
@@ -453,9 +475,42 @@
             OnGoingRecordingState::LastSeen { timestamp } if now >= timestamp + DISCONNECTED_TTL_SECS - 1 => {
                 debug!(%id, "Mark recording as terminated");
                 self.rx.active_recordings.remove(id);
-                self.ongoing_recordings.remove(&id);

-                // TODO(DGW-86): now is a good timing to kill sessions that _must_ be recorded
+                // Check the recording policy of the associated session and kill it if necessary.
+                if ongoing.session_must_be_recorded {
+                    tokio::spawn({
+                        let session_manager_handle = self.session_manager_handle.clone();
+
+                        async move {
+                            let result = session_manager_handle.kill_session(id).await;
+
+                            match result {
+                                Ok(crate::session::KillResult::Success) => {
+                                    warn!(
+                                        session.id = %id,
+                                        reason = "recording policy violated",
+                                        "Session killed",
+                                    );
+                                }
+                                Ok(crate::session::KillResult::NotFound) => {
+                                    trace!(
+                                        session.id = %id,
+                                        "Associated session is not running, as expected",
+                                    );
+                                }
+                                Err(error) => {
+                                    error!(
+                                        session.id = %id,
+                                        %error,
+                                        "Couldn’t kill session",
+                                    )
+                                }
+                            }
+                        }
+                    });
+                }
+
+                self.ongoing_recordings.remove(&id);
             }
             _ => {
                 trace!(%id, "Recording should not be removed yet");
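The kill is spawned as a detached task so the recording manager's event loop never blocks on the round-trip, and the recording entry is only removed after its policy flag has been read. A fire-and-forget sketch of that shape, where kill_session is a hypothetical stand-in:

    use std::time::Duration;

    // Hypothetical stand-in for SessionMessageSender::kill_session.
    async fn kill_session(id: u32) -> Result<(), String> {
        tokio::time::sleep(Duration::from_millis(50)).await; // Simulated round-trip.
        Ok(())
    }

    #[tokio::main]
    async fn main() {
        let id = 42;

        // Fire and forget: the caller's event loop continues immediately.
        let task = tokio::spawn(async move {
            match kill_session(id).await {
                Ok(()) => println!("session {id} killed"),
                Err(error) => eprintln!("couldn't kill session {id}: {error}"),
            }
        });

        task.await.unwrap(); // Awaited here only so the demo prints before exiting.
    }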
devolutions-gateway/src/service.rs (5 additions, 2 deletions)

@@ -208,7 +208,7 @@ async fn spawn_tasks(conf_handle: ConfHandle) -> anyhow::Result<Tasks> {
         sessions: session_manager_handle.clone(),
         subscriber_tx: subscriber_tx.clone(),
         shutdown_signal: tasks.shutdown_signal.clone(),
-        recordings: recording_manager_handle,
+        recordings: recording_manager_handle.clone(),
     };

     conf.listeners
conf.listeners
Expand Down Expand Up @@ -243,7 +243,7 @@ async fn spawn_tasks(conf_handle: ConfHandle) -> anyhow::Result<Tasks> {
));

tasks.register(devolutions_gateway::subscriber::SubscriberPollingTask {
sessions: session_manager_handle,
sessions: session_manager_handle.clone(),
subscriber: subscriber_tx,
});

@@ -253,12 +253,15 @@ async fn spawn_tasks(conf_handle: ConfHandle) -> anyhow::Result<Tasks> {
     });

     tasks.register(devolutions_gateway::session::SessionManagerTask::new(
+        session_manager_handle.clone(),
         session_manager_rx,
+        recording_manager_handle,
     ));

     tasks.register(devolutions_gateway::recording::RecordingManagerTask::new(
         recording_manager_rx,
         conf.recording_path.clone(),
+        session_manager_handle,
     ));

     Ok(tasks)
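The wiring is now circular: the session manager holds a recording handle and the recording manager holds a session handle. This works because both channels are created before either task is registered; only the receivers are moved. A self-contained sketch of that cross-handle setup, with hypothetical Ping/Pong messages:

    use tokio::sync::mpsc;

    #[derive(Debug)]
    struct Ping;
    #[derive(Debug)]
    struct Pong;

    #[tokio::main]
    async fn main() {
        // Both channels exist before either task starts, so handles can cross.
        let (ping_tx, mut ping_rx) = mpsc::channel::<Ping>(8);
        let (pong_tx, mut pong_rx) = mpsc::channel::<Pong>(8);

        let pong_handle = pong_tx.clone();
        let task_a = tokio::spawn(async move {
            if let Some(msg) = ping_rx.recv().await {
                println!("task A got {msg:?}, replying through its handle to B");
                let _ = pong_handle.send(Pong).await;
            }
        });

        let ping_handle = ping_tx.clone();
        let task_b = tokio::spawn(async move {
            let _ = ping_handle.send(Ping).await; // B messages A through its handle.
            if let Some(msg) = pong_rx.recv().await {
                println!("task B got {msg:?}");
            }
        });

        let _ = tokio::join!(task_a, task_b);
    }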
devolutions-gateway/src/session.rs (117 additions, 11 deletions)

@@ -1,3 +1,4 @@
+use crate::recording::RecordingMessageSender;
 use crate::subscriber;
 use crate::target_addr::TargetAddr;
 use crate::token::{ApplicationProtocol, SessionTtl};
@@ -133,6 +134,10 @@ enum SessionManagerMessage {
         info: SessionInfo,
         notify_kill: Arc<Notify>,
     },
+    GetInfo {
+        id: Uuid,
+        channel: oneshot::Sender<Option<SessionInfo>>,
+    },
     Remove {
         id: Uuid,
         channel: oneshot::Sender<Option<SessionInfo>>,
@@ -155,6 +160,9 @@ impl fmt::Debug for SessionManagerMessage {
             SessionManagerMessage::New { info, notify_kill: _ } => {
                 f.debug_struct("New").field("info", info).finish_non_exhaustive()
             }
+            SessionManagerMessage::GetInfo { id, channel: _ } => {
+                f.debug_struct("GetInfo").field("id", id).finish_non_exhaustive()
+            }
             SessionManagerMessage::Remove { id, channel: _ } => {
                 f.debug_struct("Remove").field("id", id).finish_non_exhaustive()
             }
@@ -179,6 +187,16 @@ impl SessionMessageSender {
             .context("couldn't send New message")
     }

+    pub async fn get_session_info(&self, id: Uuid) -> anyhow::Result<Option<SessionInfo>> {
+        let (tx, rx) = oneshot::channel();
+        self.0
+            .send(SessionManagerMessage::GetInfo { id, channel: tx })
+            .await
+            .ok()
+            .context("couldn't send GetInfo message")?;
+        rx.await.context("couldn't receive info for session")
+    }
+
     pub async fn remove_session(&self, id: Uuid) -> anyhow::Result<Option<SessionInfo>> {
         let (tx, rx) = oneshot::channel();
         self.0
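get_session_info follows the same request/response convention as remove_session: the request carries a oneshot sender, and the reply comes back on it. A minimal, self-contained sketch of the pattern, where the Msg type and its payload are hypothetical:

    use tokio::sync::{mpsc, oneshot};

    // Hypothetical message type mirroring SessionManagerMessage::GetInfo.
    enum Msg {
        GetInfo {
            id: u64,
            channel: oneshot::Sender<Option<String>>,
        },
    }

    #[tokio::main]
    async fn main() {
        let (tx, mut rx) = mpsc::channel::<Msg>(8);

        // Manager side: owns the state and answers queries.
        tokio::spawn(async move {
            while let Some(Msg::GetInfo { id, channel }) = rx.recv().await {
                let info = (id == 1).then(|| "session info".to_owned());
                let _ = channel.send(info); // Receiver may be gone; ignore the error.
            }
        });

        // Caller side, mirroring SessionMessageSender::get_session_info.
        let (reply_tx, reply_rx) = oneshot::channel();
        tx.send(Msg::GetInfo { id: 1, channel: reply_tx }).await.unwrap();
        let info = reply_rx.await.unwrap();
        assert_eq!(info.as_deref(), Some("session info"));
    }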
@@ -256,26 +274,48 @@
 }

 pub struct SessionManagerTask {
+    tx: SessionMessageSender,
     rx: SessionMessageReceiver,
     all_running: RunningSessions,
     all_notify_kill: HashMap<Uuid, Arc<Notify>>,
+    recording_manager_handle: RecordingMessageSender,
 }

 impl SessionManagerTask {
-    pub fn new(rx: SessionMessageReceiver) -> Self {
+    pub fn init(recording_manager_handle: RecordingMessageSender) -> Self {
+        let (tx, rx) = session_manager_channel();
+
+        Self::new(tx, rx, recording_manager_handle)
+    }
+
+    pub fn new(
+        tx: SessionMessageSender,
+        rx: SessionMessageReceiver,
+        recording_manager_handle: RecordingMessageSender,
+    ) -> Self {
         Self {
+            tx,
             rx,
             all_running: HashMap::new(),
             all_notify_kill: HashMap::new(),
+            recording_manager_handle,
         }
     }

+    pub fn handle(&self) -> SessionMessageSender {
+        self.tx.clone()
+    }
+
     fn handle_new(&mut self, info: SessionInfo, notify_kill: Arc<Notify>) {
         let id = info.association_id;
         self.all_running.insert(id, info);
         self.all_notify_kill.insert(id, notify_kill);
     }

+    fn handle_get_info(&mut self, id: Uuid) -> Option<SessionInfo> {
+        self.all_running.get(&id).cloned()
+    }
+
     fn handle_remove(&mut self, id: Uuid) -> Option<SessionInfo> {
         let removed_session = self.all_running.remove(&id);
         let _ = self.all_notify_kill.remove(&id);
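Storing tx inside the task itself is what lets session_manager_task hand manager.tx.clone() to the EnsureRecordingPolicyTask it spawns further down: an actor keeps a sender for its own mailbox so subtasks can message it back. A minimal sketch of that shape, with a hypothetical Msg type:

    use tokio::sync::mpsc;

    #[derive(Debug)]
    enum Msg {
        Start,
        TimerFired,
    }

    #[tokio::main]
    async fn main() {
        let (tx, mut rx) = mpsc::channel::<Msg>(8);

        // The actor keeps a clone of its own sender, like SessionManagerTask's `tx` field.
        let self_tx = tx.clone();
        let actor = tokio::spawn(async move {
            while let Some(msg) = rx.recv().await {
                match msg {
                    Msg::Start => {
                        // A spawned subtask reports back through the actor's own mailbox.
                        let tx = self_tx.clone();
                        tokio::spawn(async move {
                            tokio::time::sleep(std::time::Duration::from_millis(10)).await;
                            let _ = tx.send(Msg::TimerFired).await;
                        });
                    }
                    Msg::TimerFired => {
                        println!("timer fired: run the policy check here");
                        break;
                    }
                }
            }
        });

        tx.send(Msg::Start).await.unwrap();
        actor.await.unwrap();
    }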
@@ -312,17 +352,14 @@ async fn session_manager_task(
     debug!("Task started");

     let mut with_ttl = BinaryHeap::<WithTtlInfo>::new();
-
     let auto_kill_sleep = tokio::time::sleep_until(tokio::time::Instant::now());
     tokio::pin!(auto_kill_sleep);
-
-    // Consume initial sleep
-    (&mut auto_kill_sleep).await;
+    (&mut auto_kill_sleep).await; // Consume initial sleep.

     loop {
         tokio::select! {
             () = &mut auto_kill_sleep, if !with_ttl.is_empty() => {
-                // Will never panic since we check for non-emptiness before entering this block
+                // Will never panic since we check for non-emptiness before entering this block.
                 let to_kill = with_ttl.pop().unwrap();

                 match manager.handle_kill(to_kill.session_id) {
@@ -334,7 +371,7 @@
                     }
                 }

-                // Re-arm the Sleep instance with the next deadline if required
+                // Re-arm the Sleep instance with the next deadline if required.
                 if let Some(next) = with_ttl.peek() {
                     auto_kill_sleep.as_mut().reset(next.deadline)
                 }
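with_ttl.peek() must always yield the soonest deadline, yet BinaryHeap is a max-heap, so WithTtlInfo's Ord (not shown in this diff) presumably inverts the comparison; wrapping in std::cmp::Reverse would be the alternative. A sketch of that inversion with a simplified hypothetical struct:

    use std::cmp::Ordering;
    use std::collections::BinaryHeap;

    #[derive(PartialEq, Eq)]
    struct Deadline(u64); // Seconds until expiry, for illustration.

    impl PartialOrd for Deadline {
        fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
            Some(self.cmp(other))
        }
    }

    impl Ord for Deadline {
        fn cmp(&self, other: &Self) -> Ordering {
            // Reversed on purpose: the smallest deadline compares as "greatest",
            // so BinaryHeap::peek()/pop() yield the soonest expiry.
            other.0.cmp(&self.0)
        }
    }

    fn main() {
        let mut heap = BinaryHeap::new();
        heap.push(Deadline(300));
        heap.push(Deadline(60));
        heap.push(Deadline(600));
        assert_eq!(heap.pop().unwrap().0, 60); // The timer is re-armed with this one.
    }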
@@ -350,24 +387,41 @@
             match msg {
                 SessionManagerMessage::New { info, notify_kill } => {
                     if let SessionTtl::Limited { minutes } = info.time_to_live {
-                        let duration = Duration::from_secs(minutes.get() * 60);
                         let now = tokio::time::Instant::now();
+                        let duration = Duration::from_secs(minutes.get() * 60);
                         let deadline = now + duration;

                         with_ttl.push(WithTtlInfo {
                             deadline,
                             session_id: info.id(),
                         });

-                        // Reset the Sleep instance if the new deadline is sooner or it is already elapsed
+                        // Reset the Sleep instance if the new deadline is sooner or it is already elapsed.
                         if auto_kill_sleep.is_elapsed() || deadline < auto_kill_sleep.deadline() {
                             auto_kill_sleep.as_mut().reset(deadline);
                         }

-                        debug!(session.id = %info.id(), minutes = minutes.get(), "Limited TTL session registed");
+                        debug!(session.id = %info.id(), minutes = minutes.get(), "Limited TTL session registered");
                     }

+                    if info.recording_policy {
+                        let task = EnsureRecordingPolicyTask {
+                            session_id: info.id(),
+                            session_manager_handle: manager.tx.clone(),
+                            recording_manager_handle: manager.recording_manager_handle.clone(),
+                        };
+
+                        devolutions_gateway_task::spawn_task(task, shutdown_signal.clone()).detach();
+
+                        debug!(session.id = %info.id(), "Session with recording policy registered");
+                    }
+
                     manager.handle_new(info, notify_kill);
-                },
+                }
+                SessionManagerMessage::GetInfo { id, channel } => {
+                    let session_info = manager.handle_get_info(id);
+                    let _ = channel.send(session_info);
+                }
                 SessionManagerMessage::Remove { id, channel } => {
                     let removed_session = manager.handle_remove(id);
                     let _ = channel.send(removed_session);
@@ -416,3 +470,55 @@

     Ok(())
 }
+
+struct EnsureRecordingPolicyTask {
+    session_id: Uuid,
+    session_manager_handle: SessionMessageSender,
+    recording_manager_handle: RecordingMessageSender,
+}
+
+#[async_trait]
+impl Task for EnsureRecordingPolicyTask {
+    type Output = ();
+
+    const NAME: &'static str = "ensure recording policy";
+
+    async fn run(self, mut shutdown_signal: ShutdownSignal) -> Self::Output {
+        use futures::future::Either;
+        use std::pin::pin;
+
+        let sleep = tokio::time::sleep(Duration::from_secs(10));
+        let shutdown_signal = shutdown_signal.wait();
+
+        match futures::future::select(pin!(sleep), pin!(shutdown_signal)).await {
+            Either::Left(_) => {}
+            Either::Right(_) => return,
+        }
+
+        let is_not_recording = self
+            .recording_manager_handle
+            .get_state(self.session_id)
+            .await
+            .ok()
+            .flatten()
+            .is_none();
+
+        if is_not_recording {
+            match self.session_manager_handle.kill_session(self.session_id).await {
+                Ok(KillResult::Success) => {
+                    warn!(
+                        session.id = %self.session_id,
+                        reason = "recording policy violated",
+                        "Session killed",
+                    );
+                }
+                Ok(KillResult::NotFound) => {
+                    trace!(session.id = %self.session_id, "Session already ended");
+                }
+                Err(error) => {
+                    debug!(session.id = %self.session_id, error = format!("{error:#}"), "Couldn’t kill the session");
+                }
+            }
+        }
+    }
+}
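The ten-second grace period races against the shutdown signal, and enforcement only proceeds if the timer wins; get_state returning Ok(None) (or an error) is read as "not recording". A runnable sketch of the select half, with a oneshot channel standing in for ShutdownSignal:

    use std::pin::pin;
    use std::time::Duration;

    use futures::future::Either;
    use tokio::sync::oneshot;

    #[tokio::main]
    async fn main() {
        // Hypothetical shutdown signal; the real task awaits ShutdownSignal::wait().
        let (shutdown_tx, shutdown_rx) = oneshot::channel::<()>();

        let grace_period = tokio::time::sleep(Duration::from_millis(100));

        match futures::future::select(pin!(grace_period), pin!(shutdown_rx)).await {
            Either::Left(((), _)) => println!("grace period elapsed: check the recording state"),
            Either::Right((_, _)) => println!("shutting down: skip enforcement"),
        }

        drop(shutdown_tx); // Kept alive until here so the race above is meaningful.
    }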