diff --git a/devolutions-gateway/src/lib.rs b/devolutions-gateway/src/lib.rs index 658664863..7a51c6050 100644 --- a/devolutions-gateway/src/lib.rs +++ b/devolutions-gateway/src/lib.rs @@ -27,6 +27,7 @@ pub mod registry; pub mod service; pub mod session; pub mod subscriber; +pub mod tcp; pub mod token; pub mod transport; pub mod utils; diff --git a/devolutions-gateway/src/tcp.rs b/devolutions-gateway/src/tcp.rs new file mode 100644 index 000000000..eb8a1ed07 --- /dev/null +++ b/devolutions-gateway/src/tcp.rs @@ -0,0 +1,89 @@ +use std::net::SocketAddr; +use std::sync::Arc; + +use crate::config::Conf; +use crate::proxy::Proxy; +use crate::session::{ConnectionModeDetails, SessionInfo, SessionManagerHandle}; +use crate::subscriber::SubscriberSender; +use crate::token::{AssociationTokenClaims, ConnectionMode, CurrentJrl, TokenCache, TokenError}; +use crate::utils; + +use anyhow::Context as _; +use thiserror::Error; +use tokio::io::{AsyncRead, AsyncWrite}; + +#[derive(Debug, Error)] +pub enum AuthorizationError { + #[error("token not allowed")] + Forbidden, + #[error("bad token")] + BadToken(#[from] TokenError), +} + +pub fn authorize( + client_addr: SocketAddr, + token: &str, + conf: &Conf, + token_cache: &TokenCache, + jrl: &CurrentJrl, +) -> Result { + use crate::token::AccessTokenClaims; + + if let AccessTokenClaims::Association(claims) = + crate::http::middlewares::auth::authenticate(client_addr, token, conf, token_cache, jrl)? + { + Ok(claims) + } else { + Err(AuthorizationError::Forbidden) + } +} + +#[instrument(skip_all)] +pub async fn handle( + client_stream: impl AsyncRead + AsyncWrite + Unpin + Send + 'static, + client_addr: SocketAddr, + conf: Arc, + claims: AssociationTokenClaims, + sessions: SessionManagerHandle, + subscriber_tx: SubscriberSender, +) -> anyhow::Result<()> { + info!( + "Starting WebSocket-TCP forwarding with application protocol {:?}", + claims.jet_ap + ); + + if claims.jet_rec { + anyhow::bail!("can't meet recording policy"); + } + + let ConnectionMode::Fwd { targets, .. } = claims.jet_cm else { + anyhow::bail!("invalid connection mode") + }; + + let (server_transport, selected_target) = utils::successive_try(&targets, utils::tcp_transport_connect).await?; + + let info = SessionInfo::new( + claims.jet_aid, + claims.jet_ap, + ConnectionModeDetails::Fwd { + destination_host: selected_target.clone(), + }, + ) + .with_ttl(claims.jet_ttl) + .with_recording_policy(claims.jet_rec) + .with_filtering_policy(claims.jet_flt); + + Proxy::builder() + .conf(conf) + .session_info(info) + .address_a(client_addr) + .transport_a(client_stream) + .address_b(server_transport.addr) + .transport_b(server_transport) + .sessions(sessions) + .subscriber_tx(subscriber_tx) + .build() + .select_dissector_and_forward() + .await + .context("Encountered a failure during plain tcp traffic proxying") +} diff --git a/devolutions-gateway/src/websocket_client.rs b/devolutions-gateway/src/websocket_client.rs index 40b46c12b..b81f2d884 100644 --- a/devolutions-gateway/src/websocket_client.rs +++ b/devolutions-gateway/src/websocket_client.rs @@ -80,6 +80,19 @@ impl WebsocketService { ) .await .map_err(|err| io::Error::new(ErrorKind::Other, format!("Handle RDP error - {err:#}"))) + } else if req.method() == Method::GET && req_uri.starts_with("/jet/tcp") { + info!("{} {}", req.method(), req_uri); + handle_tcp( + req, + client_addr, + self.conf.clone(), + &self.token_cache, + &self.jrl, + self.sessions.clone(), + self.subscriber_tx.clone(), + ) + .await + .map_err(|err| io::Error::new(ErrorKind::Other, format!("Handle TCP error - {err:#}"))) } else { saphir::server::inject_raw_with_peer_addr(req, Some(client_addr)) .await @@ -603,6 +616,59 @@ async fn handle_rdp( Ok(rsp) } +async fn handle_tcp( + mut req: Request, + client_addr: SocketAddr, + conf: Arc, + token_cache: &TokenCache, + jrl: &CurrentJrl, + sessions: SessionManagerHandle, + subscriber_tx: SubscriberSender, +) -> anyhow::Result> { + use crate::http::middlewares::auth::{parse_auth_header, AuthHeaderType}; + + let token = if let Some(authorization_value) = req.headers().get(header::AUTHORIZATION) { + let authorization_value = authorization_value.to_str().context("bad authorization header value")?; // BAD REQUEST + match parse_auth_header(authorization_value) { + Some((AuthHeaderType::Bearer, token)) => token, + _ => anyhow::bail!("bad authorization header value"), // BAD REQUEST + } + } else if let Some(token) = req.uri().query().and_then(|q| { + q.split('&') + .filter_map(|segment| segment.split_once('=')) + .find_map(|(key, val)| key.eq("token").then_some(val)) + }) { + token + } else { + anyhow::bail!("missing authorization"); // AUTHORIZATION + }; + + let claims = crate::tcp::authorize(client_addr, token, &conf, token_cache, jrl)?; // FORBIDDEN + + if let Some(upgrade_val) = req.headers().get("upgrade").and_then(|v| v.to_str().ok()) { + if upgrade_val != "websocket" { + anyhow::bail!("unexpected upgrade header value: {}", upgrade_val) // BAD REQUEST + } + } + + let rsp = process_req(&req); + + tokio::spawn(async move { + let fut = async { + let stream = upgrade_websocket(&mut req).await?; + crate::tcp::handle(stream, client_addr, conf, claims, sessions, subscriber_tx).await + } + .instrument(info_span!("tcp", client = %client_addr)); + + match fut.await { + Ok(()) => {} + Err(error) => error!(client = %client_addr, error = format!("{error:#}"), "WebSocket-TCP failure"), + } + }); + + Ok(rsp) +} + type WebsocketTransport = transport::WebSocketStream>; async fn upgrade_websocket(req: &mut Request) -> anyhow::Result {