From 46368f6d43bd83177a8f983229e6a17eb6684c53 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Beno=C3=AEt=20Cortier?= Date: Tue, 14 Mar 2023 18:41:26 -0400 Subject: [PATCH] feat(dgw): WebSocket-TLS endpoint (/jet/tls) (#400) Issue: DGW-83 --- devolutions-gateway/src/lib.rs | 2 +- devolutions-gateway/src/rdp.rs | 3 +- devolutions-gateway/src/rdp_extension.rs | 5 +- devolutions-gateway/src/tcp.rs | 89 --------- devolutions-gateway/src/websocket_client.rs | 92 ++++++++- devolutions-gateway/src/websocket_forward.rs | 187 +++++++++++++++++++ 6 files changed, 282 insertions(+), 96 deletions(-) delete mode 100644 devolutions-gateway/src/tcp.rs create mode 100644 devolutions-gateway/src/websocket_forward.rs diff --git a/devolutions-gateway/src/lib.rs b/devolutions-gateway/src/lib.rs index 7a51c6050..a20ba2a46 100644 --- a/devolutions-gateway/src/lib.rs +++ b/devolutions-gateway/src/lib.rs @@ -27,11 +27,11 @@ pub mod registry; pub mod service; pub mod session; pub mod subscriber; -pub mod tcp; pub mod token; pub mod transport; pub mod utils; pub mod websocket_client; +pub mod websocket_forward; pub mod tls_sanity { use anyhow::Context as _; diff --git a/devolutions-gateway/src/rdp.rs b/devolutions-gateway/src/rdp.rs index 17e0c06a1..0fbedf56b 100644 --- a/devolutions-gateway/src/rdp.rs +++ b/devolutions-gateway/src/rdp.rs @@ -19,8 +19,7 @@ use crate::utils::{self, TargetAddr}; use anyhow::Context; use bytes::BytesMut; use nonempty::NonEmpty; -use sspi::credssp; -use sspi::AuthIdentity; +use sspi::{credssp, AuthIdentity}; use std::io; use std::net::SocketAddr; use std::sync::Arc; diff --git a/devolutions-gateway/src/rdp_extension.rs b/devolutions-gateway/src/rdp_extension.rs index 231eb46d7..cc19e4043 100644 --- a/devolutions-gateway/src/rdp_extension.rs +++ b/devolutions-gateway/src/rdp_extension.rs @@ -190,7 +190,10 @@ async fn process_cleanpath( let mut server_transport = { // Establish TLS connection with server - let dns_name = "stub_string".try_into().unwrap(); + let dns_name = destination + .host() + .try_into() + .context("Invalid DNS name in selected target")?; // TODO: optimize client config creation // diff --git a/devolutions-gateway/src/tcp.rs b/devolutions-gateway/src/tcp.rs deleted file mode 100644 index eb8a1ed07..000000000 --- a/devolutions-gateway/src/tcp.rs +++ /dev/null @@ -1,89 +0,0 @@ -use std::net::SocketAddr; -use std::sync::Arc; - -use crate::config::Conf; -use crate::proxy::Proxy; -use crate::session::{ConnectionModeDetails, SessionInfo, SessionManagerHandle}; -use crate::subscriber::SubscriberSender; -use crate::token::{AssociationTokenClaims, ConnectionMode, CurrentJrl, TokenCache, TokenError}; -use crate::utils; - -use anyhow::Context as _; -use thiserror::Error; -use tokio::io::{AsyncRead, AsyncWrite}; - -#[derive(Debug, Error)] -pub enum AuthorizationError { - #[error("token not allowed")] - Forbidden, - #[error("bad token")] - BadToken(#[from] TokenError), -} - -pub fn authorize( - client_addr: SocketAddr, - token: &str, - conf: &Conf, - token_cache: &TokenCache, - jrl: &CurrentJrl, -) -> Result { - use crate::token::AccessTokenClaims; - - if let AccessTokenClaims::Association(claims) = - crate::http::middlewares::auth::authenticate(client_addr, token, conf, token_cache, jrl)? - { - Ok(claims) - } else { - Err(AuthorizationError::Forbidden) - } -} - -#[instrument(skip_all)] -pub async fn handle( - client_stream: impl AsyncRead + AsyncWrite + Unpin + Send + 'static, - client_addr: SocketAddr, - conf: Arc, - claims: AssociationTokenClaims, - sessions: SessionManagerHandle, - subscriber_tx: SubscriberSender, -) -> anyhow::Result<()> { - info!( - "Starting WebSocket-TCP forwarding with application protocol {:?}", - claims.jet_ap - ); - - if claims.jet_rec { - anyhow::bail!("can't meet recording policy"); - } - - let ConnectionMode::Fwd { targets, .. } = claims.jet_cm else { - anyhow::bail!("invalid connection mode") - }; - - let (server_transport, selected_target) = utils::successive_try(&targets, utils::tcp_transport_connect).await?; - - let info = SessionInfo::new( - claims.jet_aid, - claims.jet_ap, - ConnectionModeDetails::Fwd { - destination_host: selected_target.clone(), - }, - ) - .with_ttl(claims.jet_ttl) - .with_recording_policy(claims.jet_rec) - .with_filtering_policy(claims.jet_flt); - - Proxy::builder() - .conf(conf) - .session_info(info) - .address_a(client_addr) - .transport_a(client_stream) - .address_b(server_transport.addr) - .transport_b(server_transport) - .sessions(sessions) - .subscriber_tx(subscriber_tx) - .build() - .select_dissector_and_forward() - .await - .context("Encountered a failure during plain tcp traffic proxying") -} diff --git a/devolutions-gateway/src/websocket_client.rs b/devolutions-gateway/src/websocket_client.rs index b81f2d884..859b25296 100644 --- a/devolutions-gateway/src/websocket_client.rs +++ b/devolutions-gateway/src/websocket_client.rs @@ -93,6 +93,19 @@ impl WebsocketService { ) .await .map_err(|err| io::Error::new(ErrorKind::Other, format!("Handle TCP error - {err:#}"))) + } else if req.method() == Method::GET && req_uri.starts_with("/jet/tls") { + info!("{} {}", req.method(), req_uri); + handle_tls( + req, + client_addr, + self.conf.clone(), + &self.token_cache, + &self.jrl, + self.sessions.clone(), + self.subscriber_tx.clone(), + ) + .await + .map_err(|err| io::Error::new(ErrorKind::Other, format!("Handle TLS error - {err:#}"))) } else { saphir::server::inject_raw_with_peer_addr(req, Some(client_addr)) .await @@ -462,7 +475,8 @@ fn process_req(req: &Request) -> Response { Author: Ran Benita (ran234@gmail.com) */ - use base64::{engine::general_purpose::STANDARD, Engine as _}; + use base64::engine::general_purpose::STANDARD; + use base64::Engine as _; fn convert_key(input: &[u8]) -> String { const WS_GUID: &[u8] = b"258EAFA5-E914-47DA-95CA-C5AB0DC85B11"; @@ -643,7 +657,7 @@ async fn handle_tcp( anyhow::bail!("missing authorization"); // AUTHORIZATION }; - let claims = crate::tcp::authorize(client_addr, token, &conf, token_cache, jrl)?; // FORBIDDEN + let claims = crate::websocket_forward::authorize(client_addr, token, &conf, token_cache, jrl)?; // FORBIDDEN if let Some(upgrade_val) = req.headers().get("upgrade").and_then(|v| v.to_str().ok()) { if upgrade_val != "websocket" { @@ -656,7 +670,16 @@ async fn handle_tcp( tokio::spawn(async move { let fut = async { let stream = upgrade_websocket(&mut req).await?; - crate::tcp::handle(stream, client_addr, conf, claims, sessions, subscriber_tx).await + crate::websocket_forward::PlainForward::builder() + .client_addr(client_addr) + .client_stream(stream) + .conf(conf) + .claims(claims) + .sessions(sessions) + .subscriber_tx(subscriber_tx) + .build() + .run() + .await } .instrument(info_span!("tcp", client = %client_addr)); @@ -669,6 +692,69 @@ async fn handle_tcp( Ok(rsp) } +async fn handle_tls( + mut req: Request, + client_addr: SocketAddr, + conf: Arc, + token_cache: &TokenCache, + jrl: &CurrentJrl, + sessions: SessionManagerHandle, + subscriber_tx: SubscriberSender, +) -> anyhow::Result> { + use crate::http::middlewares::auth::{parse_auth_header, AuthHeaderType}; + + let token = if let Some(authorization_value) = req.headers().get(header::AUTHORIZATION) { + let authorization_value = authorization_value.to_str().context("bad authorization header value")?; // BAD REQUEST + match parse_auth_header(authorization_value) { + Some((AuthHeaderType::Bearer, token)) => token, + _ => anyhow::bail!("bad authorization header value"), // BAD REQUEST + } + } else if let Some(token) = req.uri().query().and_then(|q| { + q.split('&') + .filter_map(|segment| segment.split_once('=')) + .find_map(|(key, val)| key.eq("token").then_some(val)) + }) { + token + } else { + anyhow::bail!("missing authorization"); // AUTHORIZATION + }; + + let claims = crate::websocket_forward::authorize(client_addr, token, &conf, token_cache, jrl)?; // FORBIDDEN + + if let Some(upgrade_val) = req.headers().get("upgrade").and_then(|v| v.to_str().ok()) { + if upgrade_val != "websocket" { + anyhow::bail!("unexpected upgrade header value: {}", upgrade_val) // BAD REQUEST + } + } + + let rsp = process_req(&req); + + tokio::spawn(async move { + let fut = async { + let stream = upgrade_websocket(&mut req).await?; + crate::websocket_forward::PlainForward::builder() + .client_addr(client_addr) + .client_stream(stream) + .conf(conf) + .claims(claims) + .sessions(sessions) + .subscriber_tx(subscriber_tx) + .with_tls(true) + .build() + .run() + .await + } + .instrument(info_span!("tls", client = %client_addr)); + + match fut.await { + Ok(()) => {} + Err(error) => error!(client = %client_addr, error = format!("{error:#}"), "WebSocket-TLS failure"), + } + }); + + Ok(rsp) +} + type WebsocketTransport = transport::WebSocketStream>; async fn upgrade_websocket(req: &mut Request) -> anyhow::Result { diff --git a/devolutions-gateway/src/websocket_forward.rs b/devolutions-gateway/src/websocket_forward.rs new file mode 100644 index 000000000..d1d3bd4d3 --- /dev/null +++ b/devolutions-gateway/src/websocket_forward.rs @@ -0,0 +1,187 @@ +use std::net::SocketAddr; +use std::sync::Arc; + +use crate::config::Conf; +use crate::proxy::Proxy; +use crate::session::{ConnectionModeDetails, SessionInfo, SessionManagerHandle}; +use crate::subscriber::SubscriberSender; +use crate::token::{AssociationTokenClaims, ConnectionMode, CurrentJrl, TokenCache, TokenError}; +use crate::utils; + +use anyhow::Context as _; +use tap::prelude::*; +use thiserror::Error; +use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt as _}; +use tokio_rustls::rustls::client::ClientConfig as TlsClientConfig; +use typed_builder::TypedBuilder; + +#[derive(Debug, Error)] +pub enum AuthorizationError { + #[error("token not allowed")] + Forbidden, + #[error("bad token")] + BadToken(#[from] TokenError), +} + +pub fn authorize( + client_addr: SocketAddr, + token: &str, + conf: &Conf, + token_cache: &TokenCache, + jrl: &CurrentJrl, +) -> Result { + use crate::token::AccessTokenClaims; + + if let AccessTokenClaims::Association(claims) = + crate::http::middlewares::auth::authenticate(client_addr, token, conf, token_cache, jrl)? + { + Ok(claims) + } else { + Err(AuthorizationError::Forbidden) + } +} + +#[derive(TypedBuilder)] +pub struct PlainForward { + conf: Arc, + claims: AssociationTokenClaims, + client_stream: S, + client_addr: SocketAddr, + sessions: SessionManagerHandle, + subscriber_tx: SubscriberSender, + #[builder(default = false)] + with_tls: bool, +} + +impl PlainForward +where + S: AsyncRead + AsyncWrite + Unpin + Send + 'static, +{ + #[instrument(skip_all)] + pub async fn run(self) -> anyhow::Result<()> { + let Self { + conf, + claims, + client_stream, + client_addr, + sessions, + subscriber_tx, + with_tls, + } = self; + + if claims.jet_rec { + anyhow::bail!("can't meet recording policy"); + } + + let ConnectionMode::Fwd { targets, .. } = claims.jet_cm else { + anyhow::bail!("invalid connection mode") + }; + + trace!("Connecting to target"); + + let (server_transport, selected_target) = utils::successive_try(&targets, utils::tcp_transport_connect).await?; + + trace!("Connected"); + + if with_tls { + trace!("Establishing TLS connection with server"); + + // Establish TLS connection with server + + let dns_name = selected_target + .host() + .try_into() + .context("Invalid DNS name in selected target")?; + + // TODO: optimize client config creation + // + // rustls doc says: + // + // > Making one of these can be expensive, and should be once per process rather than once per connection. + // + // source: https://docs.rs/rustls/latest/rustls/struct.ClientConfig.html + // + // In our case, this doesn’t work, so I’m creating a new ClientConfig from scratch each time (slow). + // rustls issue: https://github.com/rustls/rustls/issues/1186 + let tls_client_config = TlsClientConfig::builder() + .with_safe_defaults() + .with_custom_certificate_verifier(std::sync::Arc::new( + crate::utils::danger_transport::NoCertificateVerification, + )) + .with_no_client_auth() + .pipe(Arc::new); + + let server_addr = server_transport.addr; + + let mut server_transport = tokio_rustls::TlsConnector::from(tls_client_config) + .connect(dns_name, server_transport) + .await + .context("TLS connect")?; + + // https://docs.rs/tokio-rustls/latest/tokio_rustls/#why-do-i-need-to-call-poll_flush + server_transport.flush().await?; + + trace!("TLS connection established with success"); + + info!( + "Starting WebSocket-TLS forwarding with application protocol {:?}", + claims.jet_ap + ); + + let info = SessionInfo::new( + claims.jet_aid, + claims.jet_ap, + ConnectionModeDetails::Fwd { + destination_host: selected_target.clone(), + }, + ) + .with_ttl(claims.jet_ttl) + .with_recording_policy(claims.jet_rec) + .with_filtering_policy(claims.jet_flt); + + Proxy::builder() + .conf(conf) + .session_info(info) + .address_a(client_addr) + .transport_a(client_stream) + .address_b(server_addr) + .transport_b(server_transport) + .sessions(sessions) + .subscriber_tx(subscriber_tx) + .build() + .select_dissector_and_forward() + .await + .context("Encountered a failure during plain tls traffic proxying") + } else { + info!( + "Starting WebSocket-TCP forwarding with application protocol {:?}", + claims.jet_ap + ); + + let info = SessionInfo::new( + claims.jet_aid, + claims.jet_ap, + ConnectionModeDetails::Fwd { + destination_host: selected_target.clone(), + }, + ) + .with_ttl(claims.jet_ttl) + .with_recording_policy(claims.jet_rec) + .with_filtering_policy(claims.jet_flt); + + Proxy::builder() + .conf(conf) + .session_info(info) + .address_a(client_addr) + .transport_a(client_stream) + .address_b(server_transport.addr) + .transport_b(server_transport) + .sessions(sessions) + .subscriber_tx(subscriber_tx) + .build() + .select_dissector_and_forward() + .await + .context("Encountered a failure during plain tcp traffic proxying") + } + } +}