Skip to content

Commit

Permalink
feat(dgw): WebSocket-TCP endpoint (/jet/tcp) (#399)
Browse files Browse the repository at this point in the history
Issue: DGW-82
  • Loading branch information
CBenoit authored Mar 14, 2023
1 parent 5980763 commit 265f0db
Show file tree
Hide file tree
Showing 3 changed files with 156 additions and 0 deletions.
1 change: 1 addition & 0 deletions devolutions-gateway/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ pub mod registry;
pub mod service;
pub mod session;
pub mod subscriber;
pub mod tcp;
pub mod token;
pub mod transport;
pub mod utils;
Expand Down
89 changes: 89 additions & 0 deletions devolutions-gateway/src/tcp.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
use std::net::SocketAddr;
use std::sync::Arc;

use crate::config::Conf;
use crate::proxy::Proxy;
use crate::session::{ConnectionModeDetails, SessionInfo, SessionManagerHandle};
use crate::subscriber::SubscriberSender;
use crate::token::{AssociationTokenClaims, ConnectionMode, CurrentJrl, TokenCache, TokenError};
use crate::utils;

use anyhow::Context as _;
use thiserror::Error;
use tokio::io::{AsyncRead, AsyncWrite};

#[derive(Debug, Error)]
pub enum AuthorizationError {
#[error("token not allowed")]
Forbidden,
#[error("bad token")]
BadToken(#[from] TokenError),
}

pub fn authorize(
client_addr: SocketAddr,
token: &str,
conf: &Conf,
token_cache: &TokenCache,
jrl: &CurrentJrl,
) -> Result<AssociationTokenClaims, AuthorizationError> {
use crate::token::AccessTokenClaims;

if let AccessTokenClaims::Association(claims) =
crate::http::middlewares::auth::authenticate(client_addr, token, conf, token_cache, jrl)?
{
Ok(claims)
} else {
Err(AuthorizationError::Forbidden)
}
}

#[instrument(skip_all)]
pub async fn handle(
client_stream: impl AsyncRead + AsyncWrite + Unpin + Send + 'static,
client_addr: SocketAddr,
conf: Arc<Conf>,
claims: AssociationTokenClaims,
sessions: SessionManagerHandle,
subscriber_tx: SubscriberSender,
) -> anyhow::Result<()> {
info!(
"Starting WebSocket-TCP forwarding with application protocol {:?}",
claims.jet_ap
);

if claims.jet_rec {
anyhow::bail!("can't meet recording policy");
}

let ConnectionMode::Fwd { targets, .. } = claims.jet_cm else {
anyhow::bail!("invalid connection mode")
};

let (server_transport, selected_target) = utils::successive_try(&targets, utils::tcp_transport_connect).await?;

let info = SessionInfo::new(
claims.jet_aid,
claims.jet_ap,
ConnectionModeDetails::Fwd {
destination_host: selected_target.clone(),
},
)
.with_ttl(claims.jet_ttl)
.with_recording_policy(claims.jet_rec)
.with_filtering_policy(claims.jet_flt);

Proxy::builder()
.conf(conf)
.session_info(info)
.address_a(client_addr)
.transport_a(client_stream)
.address_b(server_transport.addr)
.transport_b(server_transport)
.sessions(sessions)
.subscriber_tx(subscriber_tx)
.build()
.select_dissector_and_forward()
.await
.context("Encountered a failure during plain tcp traffic proxying")
}
66 changes: 66 additions & 0 deletions devolutions-gateway/src/websocket_client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,19 @@ impl WebsocketService {
)
.await
.map_err(|err| io::Error::new(ErrorKind::Other, format!("Handle RDP error - {err:#}")))
} else if req.method() == Method::GET && req_uri.starts_with("/jet/tcp") {
info!("{} {}", req.method(), req_uri);
handle_tcp(
req,
client_addr,
self.conf.clone(),
&self.token_cache,
&self.jrl,
self.sessions.clone(),
self.subscriber_tx.clone(),
)
.await
.map_err(|err| io::Error::new(ErrorKind::Other, format!("Handle TCP error - {err:#}")))
} else {
saphir::server::inject_raw_with_peer_addr(req, Some(client_addr))
.await
Expand Down Expand Up @@ -603,6 +616,59 @@ async fn handle_rdp(
Ok(rsp)
}

async fn handle_tcp(
mut req: Request<Body>,
client_addr: SocketAddr,
conf: Arc<Conf>,
token_cache: &TokenCache,
jrl: &CurrentJrl,
sessions: SessionManagerHandle,
subscriber_tx: SubscriberSender,
) -> anyhow::Result<Response<Body>> {
use crate::http::middlewares::auth::{parse_auth_header, AuthHeaderType};

let token = if let Some(authorization_value) = req.headers().get(header::AUTHORIZATION) {
let authorization_value = authorization_value.to_str().context("bad authorization header value")?; // BAD REQUEST
match parse_auth_header(authorization_value) {
Some((AuthHeaderType::Bearer, token)) => token,
_ => anyhow::bail!("bad authorization header value"), // BAD REQUEST
}
} else if let Some(token) = req.uri().query().and_then(|q| {
q.split('&')
.filter_map(|segment| segment.split_once('='))
.find_map(|(key, val)| key.eq("token").then_some(val))
}) {
token
} else {
anyhow::bail!("missing authorization"); // AUTHORIZATION
};

let claims = crate::tcp::authorize(client_addr, token, &conf, token_cache, jrl)?; // FORBIDDEN

if let Some(upgrade_val) = req.headers().get("upgrade").and_then(|v| v.to_str().ok()) {
if upgrade_val != "websocket" {
anyhow::bail!("unexpected upgrade header value: {}", upgrade_val) // BAD REQUEST
}
}

let rsp = process_req(&req);

tokio::spawn(async move {
let fut = async {
let stream = upgrade_websocket(&mut req).await?;
crate::tcp::handle(stream, client_addr, conf, claims, sessions, subscriber_tx).await
}
.instrument(info_span!("tcp", client = %client_addr));

match fut.await {
Ok(()) => {}
Err(error) => error!(client = %client_addr, error = format!("{error:#}"), "WebSocket-TCP failure"),
}
});

Ok(rsp)
}

type WebsocketTransport = transport::WebSocketStream<tokio_tungstenite::WebSocketStream<hyper::upgrade::Upgraded>>;

async fn upgrade_websocket(req: &mut Request<Body>) -> anyhow::Result<WebsocketTransport> {
Expand Down

0 comments on commit 265f0db

Please sign in to comment.