Skip to content

Commit

Permalink
feat(dgw): graceful shutdown
Browse files Browse the repository at this point in the history
  • Loading branch information
CBenoit committed May 31, 2023
1 parent 7de2a2f commit ef1d12d
Show file tree
Hide file tree
Showing 20 changed files with 518 additions and 134 deletions.
10 changes: 10 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

14 changes: 14 additions & 0 deletions crates/devolutions-gateway-task/Cargo.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
[package]
name = "devolutions-gateway-task"
version = "0.0.0"
authors = ["Devolutions Inc. <infos@devolutions.net>"]
edition = "2021"
publish = false

[features]
default = []
named_tasks = ["tokio/tracing"]

[dependencies]
tokio = { version = "1.28.1", features = ["sync", "rt", "tracing"] }
async-trait = "0.1.68"
115 changes: 115 additions & 0 deletions crates/devolutions-gateway-task/src/lib.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,115 @@
use std::future::Future;

use async_trait::async_trait;
use tokio::task::JoinHandle;

#[derive(Debug)]
pub struct ShutdownHandle(tokio::sync::watch::Sender<()>);

impl ShutdownHandle {
pub fn new() -> (Self, ShutdownSignal) {
let (sender, receiver) = tokio::sync::watch::channel(());
(Self(sender), ShutdownSignal(receiver))
}

pub fn signal(&self) {
let _ = self.0.send(());
}

pub async fn all_closed(&self) {
self.0.closed().await;
}
}

#[derive(Clone, Debug)]
pub struct ShutdownSignal(tokio::sync::watch::Receiver<()>);

impl ShutdownSignal {
pub async fn wait(&mut self) {
let _ = self.0.changed().await;
}
}

/// Aborts the running task when dropped.
/// Also see https://github.com/tokio-rs/tokio/issues/1830 for some background.
#[must_use]
pub struct ChildTask<T>(JoinHandle<T>);

impl<T> ChildTask<T> {
pub fn spawn<F>(future: F) -> Self
where
F: Future<Output = T> + Send + 'static,
T: Send + 'static,
{
ChildTask(tokio::task::spawn(future))
}

pub async fn join(mut self) -> Result<T, tokio::task::JoinError> {
(&mut self.0).await
}

/// Immediately abort the task
pub fn abort(&self) {
self.0.abort()
}

/// Drop without aborting the task
pub fn detach(self) {
core::mem::forget(self);
}
}

impl<T> From<JoinHandle<T>> for ChildTask<T> {
fn from(value: JoinHandle<T>) -> Self {
Self(value)
}
}

impl<T> Drop for ChildTask<T> {
fn drop(&mut self) {
self.abort();
}
}

#[async_trait]
pub trait Task {
type Output: Send;

const NAME: &'static str;

async fn run(self, shutdown_signal: ShutdownSignal) -> Self::Output;
}

pub fn spawn_task<T>(task: T, shutdown_signal: ShutdownSignal) -> ChildTask<T::Output>
where
T: Task + 'static,
{
let task_fut = task.run(shutdown_signal);
let handle = spawn_task_impl(task_fut, T::NAME);
ChildTask(handle)
}

#[cfg(not(all(feature = "named_tasks", tokio_unstable)))]
#[track_caller]
fn spawn_task_impl<T>(future: T, _name: &str) -> JoinHandle<T::Output>
where
T: Future + Send + 'static,
T::Output: Send + 'static,
{
tokio::task::spawn(future)
}

#[cfg(all(feature = "named_tasks", tokio_unstable))]
#[track_caller]
fn spawn_task_impl<T>(future: T, name: &str) -> JoinHandle<T::Output>
where
T: Future + Send + 'static,
T::Output: Send + 'static,
{
// NOTE: To enable this code, this crate must be built using a command similar to:
// > RUSTFLAGS="--cfg tokio_unstable" cargo check --features named_tasks

// NOTE: unwrap because as of now (tokio 1.28), this never returns an error,
// and production build never enable tokio-console instrumentation anyway (unstable).
tokio::task::Builder::new().name(name).spawn(future).unwrap()
}
2 changes: 0 additions & 2 deletions crates/sogar-registry/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
// TODO: extract this module into another crate

mod push_files;
mod sogar_auth;
mod sogar_token;
Expand Down
1 change: 1 addition & 0 deletions devolutions-gateway/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ transport = { path = "../crates/transport" }
jmux-proxy = { path = "../crates/jmux-proxy" }
ironrdp-pdu = { version = "0.1.0", git = "https://github.com/Devolutions/IronRDP", rev = "f11131b18afa5f8ff9" }
ironrdp-rdcleanpath = { version = "0.1.0", git = "https://github.com/Devolutions/IronRDP", rev = "f11131b18afa5f8ff9" }
devolutions-gateway-task = { path = "../crates/devolutions-gateway-task" }
ceviche = "0.5.2"
picky-krb = "0.6.0"

Expand Down
2 changes: 1 addition & 1 deletion devolutions-gateway/src/api/jrec.rs
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ async fn handle_jrec_push(
) {
let stream = crate::ws::websocket_compat(ws);

let result = crate::jrec::ClientPush::builder()
let result = crate::recording::ClientPush::builder()
.client_stream(stream)
.conf(conf)
.claims(claims)
Expand Down
3 changes: 2 additions & 1 deletion devolutions-gateway/src/interceptor/pcap.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
use crate::interceptor::{Dissector, Inspector, PeerSide};
use anyhow::Context as _;
use bytes::BytesMut;
use devolutions_gateway_task::ChildTask;
use packet::builder::Builder;
use packet::ether::{Builder as BuildEthernet, Protocol};
use packet::ip::v6::Builder as BuildV6;
Expand Down Expand Up @@ -40,7 +41,7 @@ impl PcapInspector {

let (sender, receiver) = mpsc::unbounded_channel();

tokio::spawn(writer_task(receiver, pcap_writer, client_addr, server_addr, dissector));
ChildTask::spawn(writer_task(receiver, pcap_writer, client_addr, server_addr, dissector)).detach();

Ok((
Self {
Expand Down
6 changes: 4 additions & 2 deletions devolutions-gateway/src/interceptor/plugin_recording.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
use crate::interceptor::{Inspector, PeerSide};
use crate::plugin_manager::{PacketsParser, Recorder, PLUGIN_MANAGER};
use anyhow::Context as _;
use devolutions_gateway_task::ChildTask;
use parking_lot::{Condvar, Mutex};
use std::path::PathBuf;
use std::sync::Arc;
Expand Down Expand Up @@ -80,13 +81,14 @@ impl PluginRecordingInspector {
move || timeout_task(recorder, condition_timeout)
});

tokio::spawn(inspector_task(
ChildTask::spawn(inspector_task(
receiver,
handle,
packet_parser,
recorder,
condition_timeout,
));
))
.detach();

Ok(InitResult {
client_inspector: Self {
Expand Down
21 changes: 10 additions & 11 deletions devolutions-gateway/src/jmux.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ use crate::subscriber::SubscriberSender;
use crate::token::JmuxTokenClaims;

use anyhow::Context as _;
use devolutions_gateway_task::ChildTask;
use jmux_proxy::JmuxProxy;
use tap::prelude::*;
use tokio::io::{AsyncRead, AsyncWrite};
Expand All @@ -17,7 +18,6 @@ pub async fn handle(
sessions: SessionManagerHandle,
subscriber_tx: SubscriberSender,
) -> anyhow::Result<()> {
use futures::future::Either;
use jmux_proxy::{FilteringRule, JmuxConfig};

let (reader, writer) = tokio::io::split(stream);
Expand Down Expand Up @@ -58,20 +58,19 @@ pub async fn handle(
crate::session::add_session_in_progress(&sessions, &subscriber_tx, info, notify_kill.clone()).await?;

let proxy_fut = JmuxProxy::new(reader, writer).with_config(config).run();

let proxy_handle = tokio::spawn(proxy_fut);
tokio::pin!(proxy_handle);
let proxy_handle = ChildTask::spawn(proxy_fut);
let join_fut = proxy_handle.join();
tokio::pin!(join_fut);

let kill_notified = notify_kill.notified();
tokio::pin!(kill_notified);

let res = match futures::future::select(proxy_handle, kill_notified).await {
Either::Left((Ok(res), _)) => res.context("JMUX proxy error"),
Either::Left((Err(e), _)) => anyhow::Error::new(e).context("Failed to wait for proxy task").pipe(Err),
Either::Right((_, proxy_handle)) => {
proxy_handle.abort();
Ok(())
}
let res = tokio::select! {
res = join_fut => match res {
Ok(res) => res.context("JMUX proxy error"),
Err(e) => anyhow::Error::new(e).context("Failed to wait for proxy task").pipe(Err),
},
_ = kill_notified => Ok(()),
};

crate::session::remove_session_in_progress(&sessions, &subscriber_tx, session_id).await?;
Expand Down
3 changes: 2 additions & 1 deletion devolutions-gateway/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@ pub mod generic_client;
pub mod http;
pub mod interceptor;
pub mod jmux;
pub mod jrec;
pub mod listener;
pub mod log;
pub mod middleware;
Expand All @@ -24,6 +23,7 @@ pub mod plugin_manager;
pub mod proxy;
pub mod rdp_extension;
pub mod rdp_pcb;
pub mod recording;
pub mod session;
pub mod subscriber;
pub mod target_addr;
Expand All @@ -39,6 +39,7 @@ pub struct DgwState {
pub jrl: Arc<token::CurrentJrl>,
pub sessions: session::SessionManagerHandle,
pub subscriber_tx: subscriber::SubscriberSender,
pub shutdown_signal: devolutions_gateway_task::ShutdownSignal,
}

pub fn make_http_service(state: DgwState) -> axum::Router<()> {
Expand Down
26 changes: 22 additions & 4 deletions devolutions-gateway/src/listener.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
use anyhow::Context;
use async_trait::async_trait;
use devolutions_gateway_task::{ChildTask, ShutdownSignal, Task};
use std::net::SocketAddr;
use tap::Pipe as _;
use tokio::io::{AsyncRead, AsyncWrite};
Expand Down Expand Up @@ -92,16 +94,32 @@ impl GatewayListener {
}
}

#[async_trait]
impl Task for GatewayListener {
type Output = anyhow::Result<()>;

const NAME: &'static str = "gateway listener";

async fn run(self, mut shutdown_signal: ShutdownSignal) -> Self::Output {
tokio::select! {
result = self.run() => result,
_ = shutdown_signal.wait() => Ok(()),
}
}
}

async fn run_tcp_listener(listener: TcpListener, state: DgwState) -> anyhow::Result<()> {
loop {
match listener.accept().await.context("failed to accept connection") {
Ok((stream, peer_addr)) => {
let state = state.clone();
tokio::spawn(async move {

ChildTask::spawn(async move {
if let Err(e) = handle_tcp_peer(stream, state, peer_addr).await {
error!(error = format!("{e:#}"), "Peer failure");
}
});
})
.detach();
}
Err(e) => error!(error = format!("{e:#}"), "Listener failure"),
}
Expand Down Expand Up @@ -156,7 +174,7 @@ async fn run_http_listener(listener: TcpListener, state: DgwState) -> anyhow::Re
}
.instrument(info_span!("http", client = %peer_addr));

tokio::spawn(fut);
ChildTask::spawn(fut).detach();
}
Err(error) => {
error!(%error, "failed to accept connection");
Expand All @@ -183,7 +201,7 @@ async fn run_https_listener(listener: TcpListener, state: DgwState) -> anyhow::R
}
.instrument(info_span!("https", client = %peer_addr));

tokio::spawn(fut);
ChildTask::spawn(fut).detach();
}
Err(error) => {
error!(%error, "failed to accept connection");
Expand Down
Loading

0 comments on commit ef1d12d

Please sign in to comment.