Skip to content

Commit

Permalink
WAYK-2566: Add token type + bridge token + update powershell to gener…
Browse files Browse the repository at this point in the history
…ate token
  • Loading branch information
fdubois1 committed Aug 3, 2021
1 parent e989ddb commit 2beacd1
Show file tree
Hide file tree
Showing 16 changed files with 127 additions and 127 deletions.
133 changes: 54 additions & 79 deletions devolutions-gateway/src/http/controllers/http_bridge.rs
Original file line number Diff line number Diff line change
@@ -1,112 +1,87 @@
use crate::config::Config;
use crate::http::guards::access::{AccessGuard, JetTokenType};
use crate::http::HttpErrorStatus;
use saphir::http::StatusCode;
use jet_proto::token::JetAccessTokenClaims;
use saphir::macros::controller;
use saphir::request::Request;
use saphir::response::Builder;
use std::sync::Arc;

pub const GATEWAY_BRIDGE_TOKEN_HDR_NAME: &str = "Gateway-Bridge-Token";

#[derive(Deserialize)]
struct HttpBridgeClaims {
target: url::Url,
}
pub const REQUEST_AUTHORIZATION_TOKEN_HDR_NAME: &str = "Request-Authorization-Token";

pub struct HttpBridgeController {
config: Arc<Config>,
client: reqwest::Client,
}

impl HttpBridgeController {
pub fn new(config: Arc<Config>) -> Self {
pub fn new() -> Self {
let client = reqwest::Client::new();
Self { config, client }
}
}

impl HttpBridgeController {
fn h_decode_claims(&self, token_str: &str) -> Result<HttpBridgeClaims, HttpErrorStatus> {
use core::convert::TryFrom;
use picky::jose::jwt;
use std::time::{SystemTime, UNIX_EPOCH};

let key = self
.config
.provisioner_public_key
.as_ref()
.ok_or((StatusCode::INTERNAL_SERVER_ERROR, "provisioner public key is missing"))?;

let numeric_date = SystemTime::now()
.duration_since(UNIX_EPOCH)
.expect("UNIX EPOCH is in the past")
.as_secs();
let date = jwt::JwtDate::new_with_leeway(i64::try_from(numeric_date).unwrap(), 60);
let validator = jwt::JwtValidator::strict(&date);

let jws = jwt::JwtSig::decode(token_str, key, &validator).map_err(HttpErrorStatus::forbidden)?;

Ok(jws.claims)
Self { client }
}
}

#[controller(name = "bridge")]
impl HttpBridgeController {
#[post("/message")]
#[guard(AccessGuard, init_expr = r#"JetTokenType::Bridge"#)]
async fn message(&self, req: Request) -> Result<Builder, HttpErrorStatus> {
use core::convert::TryFrom;

// FIXME: when updating reqwest 0.10 → 0.11 and hyper 0.13 → 0.14:
// Use https://docs.rs/reqwest/0.11.4/reqwest/struct.Body.html#impl-From%3CBody%3E
// to get a streaming reqwest Request instead of loading the whole body in memory.
let req = req.load_body().await.map_err(HttpErrorStatus::internal)?;
let req: saphir::request::Request<reqwest::Body> = req.map(reqwest::Body::from);
let mut req: http::Request<reqwest::Body> = http::Request::from(req);
if let Some(JetAccessTokenClaims::Bridge(claims)) = req
.extensions()
.get::<JetAccessTokenClaims>()
.map(|claim| claim.clone())
{
// FIXME: when updating reqwest 0.10 → 0.11 and hyper 0.13 → 0.14:
// Use https://docs.rs/reqwest/0.11.4/reqwest/struct.Body.html#impl-From%3CBody%3E
// to get a streaming reqwest Request instead of loading the whole body in memory.
let req = req.load_body().await.map_err(HttpErrorStatus::internal)?;
let req: saphir::request::Request<reqwest::Body> = req.map(reqwest::Body::from);
let mut req: http::Request<reqwest::Body> = http::Request::from(req);

// === Replace Authorization header (used to be authorized on the gateway) with the request authorization token === //

let mut rsp = {
let headers = req.headers_mut();
headers.remove(http::header::AUTHORIZATION);
if let Some(auth_token) = headers.remove(REQUEST_AUTHORIZATION_TOKEN_HDR_NAME) {
headers.insert(http::header::AUTHORIZATION, auth_token);
}

// === Filter and validate request to forward === //
// Update request destination
let uri = http::Uri::try_from(claims.target.as_str()).map_err(HttpErrorStatus::bad_request)?;
*req.uri_mut() = uri;

let mut rsp = {
// Gateway Bridge Claims
let headers = req.headers_mut();
let token_hdr = headers
.remove(GATEWAY_BRIDGE_TOKEN_HDR_NAME)
.ok_or((StatusCode::BAD_REQUEST, "Gateway-Bridge-Token header is missing"))?;
let token_str = token_hdr.to_str().map_err(HttpErrorStatus::bad_request)?;
let claims = self.h_decode_claims(token_str)?;
// Forward
slog_scope::debug!("Forward HTTP request to {}", req.uri());
let req = reqwest::Request::try_from(req).map_err(HttpErrorStatus::internal)?;
self.client.execute(req).await.map_err(HttpErrorStatus::bad_gateway)?
};

// Update request destination
let uri = http::Uri::try_from(claims.target.as_str()).map_err(HttpErrorStatus::bad_request)?;
*req.uri_mut() = uri;
// === Create HTTP response using target response === //

// Forward
slog_scope::debug!("Forward HTTP request to {}", req.uri());
let req = reqwest::Request::try_from(req).map_err(HttpErrorStatus::internal)?;
self.client.execute(req).await.map_err(HttpErrorStatus::bad_gateway)?
};
let mut rsp_builder = Builder::new();

// === Create HTTP response using target response === //
{
// Status code
rsp_builder = rsp_builder.status(rsp.status());

let mut rsp_builder = Builder::new();
// Headers
let headers = rsp_builder.headers_mut().unwrap();
rsp.headers_mut().drain().for_each(|(name, value)| {
if let Some(name) = name {
headers.insert(name, value);
}
});

{
// Status code
rsp_builder = rsp_builder.status(rsp.status());

// Headers
let headers = rsp_builder.headers_mut().unwrap();
rsp.headers_mut().drain().for_each(|(name, value)| {
if let Some(name) = name {
headers.insert(name, value);
// Body
match rsp.bytes().await {
Ok(body) => rsp_builder = rsp_builder.body(body),
Err(e) => slog_scope::warn!("Couldn’t get bytes from response body: {}", e),
}
});

// Body
match rsp.bytes().await {
Ok(body) => rsp_builder = rsp_builder.body(body),
Err(e) => slog_scope::warn!("Couldn’t get bytes from response body: {}", e),
}
}

Ok(rsp_builder)
Ok(rsp_builder)
} else {
Err(HttpErrorStatus::unauthorized("Bridge token is mandatory"))
}
}
}
16 changes: 10 additions & 6 deletions devolutions-gateway/src/http/controllers/jet.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
use crate::config::Config;
use crate::http::controllers::health::build_health_response;
use crate::http::guards::access::{AccessGuard, JetAccessType};
use crate::http::guards::access::{AccessGuard, JetTokenType};
use crate::jet::association::{Association, AssociationResponse};
use crate::jet::candidate::Candidate;
use crate::jet_client::JetAssociationsMap;
use crate::utils::association::{remove_jet_association, ACCEPT_REQUEST_TIMEOUT};
use jet_proto::token::JetAccessTokenClaims;
use jet_proto::token::{JetAccessScope, JetAccessTokenClaims};
use jet_proto::JET_VERSION_V2;
use saphir::controller::Controller;
use saphir::http::{Method, StatusCode};
Expand Down Expand Up @@ -34,6 +34,10 @@ impl JetController {
#[controller(name = "jet")]
impl JetController {
#[get("/association")]
#[guard(
AccessGuard,
init_expr = r#"JetTokenType::Scope(JetAccessScope::GatewayAssociationsRead)"#
)]
async fn get_associations(&self, detail: Option<bool>) -> (StatusCode, Option<String>) {
let with_detail = detail.unwrap_or(false);
let associations_response: Vec<AssociationResponse>;
Expand All @@ -52,9 +56,9 @@ impl JetController {
}

#[post("/association/<association_id>")]
#[guard(AccessGuard, init_expr = r#"JetAccessType::Session"#)]
#[guard(AccessGuard, init_expr = r#"JetTokenType::Association"#)]
async fn create_association(&self, req: Request) -> (StatusCode, ()) {
if let Some(JetAccessTokenClaims::Session(session_token)) = req.extensions().get::<JetAccessTokenClaims>() {
if let Some(JetAccessTokenClaims::Association(session_token)) = req.extensions().get::<JetAccessTokenClaims>() {
let association_id = match req
.captures()
.get("association_id")
Expand Down Expand Up @@ -91,9 +95,9 @@ impl JetController {
}

#[post("/association/<association_id>/candidates")]
#[guard(AccessGuard, init_expr = r#"JetAccessType::Session"#)]
#[guard(AccessGuard, init_expr = r#"JetTokenType::Association"#)]
async fn gather_association_candidates(&self, req: Request) -> (StatusCode, Option<String>) {
if let Some(JetAccessTokenClaims::Session(session_token)) = req.extensions().get::<JetAccessTokenClaims>() {
if let Some(JetAccessTokenClaims::Association(session_token)) = req.extensions().get::<JetAccessTokenClaims>() {
let association_id = match req
.captures()
.get("association_id")
Expand Down
6 changes: 3 additions & 3 deletions devolutions-gateway/src/http/controllers/sessions.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use crate::http::guards::access::{AccessGuard, JetAccessType};
use crate::http::guards::access::{AccessGuard, JetTokenType};
use crate::http::HttpErrorStatus;
use crate::{GatewaySessionInfo, SESSIONS_IN_PROGRESS};
use jet_proto::token::JetAccessScope;
Expand All @@ -14,7 +14,7 @@ impl SessionsController {
#[get("/count")]
#[guard(
AccessGuard,
init_expr = r#"JetAccessType::Scope(JetAccessScope::GatewaySessionsRead)"#
init_expr = r#"JetTokenType::Scope(JetAccessScope::GatewaySessionsRead)"#
)]
async fn get_count(&self) -> (StatusCode, String) {
let sessions = SESSIONS_IN_PROGRESS.read().await;
Expand All @@ -24,7 +24,7 @@ impl SessionsController {
#[get("/")]
#[guard(
AccessGuard,
init_expr = r#"JetAccessType::Scope(JetAccessScope::GatewaySessionsRead)"#
init_expr = r#"JetTokenType::Scope(JetAccessScope::GatewaySessionsRead)"#
)]
async fn get_sessions(&self) -> Result<Json<Vec<GatewaySessionInfo>>, HttpErrorStatus> {
let sessions = SESSIONS_IN_PROGRESS.read().await;
Expand Down
27 changes: 17 additions & 10 deletions devolutions-gateway/src/http/guards/access.rs
Original file line number Diff line number Diff line change
@@ -1,37 +1,44 @@
use crate::http::HttpErrorStatus;
use jet_proto::token::{JetAccessScope, JetAccessTokenClaims};
use saphir::prelude::*;

#[derive(Deserialize)]
pub enum JetAccessType {
pub enum JetTokenType {
Scope(JetAccessScope),
Session,
Bridge,
Association,
}

pub struct AccessGuard {
access_type: JetAccessType,
token_type: JetTokenType,
}

#[guard]
impl AccessGuard {
pub fn new(access_type: JetAccessType) -> Self {
AccessGuard { access_type }
pub fn new(token_type: JetTokenType) -> Self {
AccessGuard { token_type }
}

async fn validate(&self, req: Request) -> Result<Request, StatusCode> {
async fn validate(&self, req: Request) -> Result<Request, HttpErrorStatus> {
if let Some(claims) = req.extensions().get::<JetAccessTokenClaims>() {
match (claims, &self.access_type) {
(JetAccessTokenClaims::Session(_), JetAccessType::Session) => {
match (claims, &self.token_type) {
(JetAccessTokenClaims::Association(_), JetTokenType::Association) => {
return Ok(req);
}
(JetAccessTokenClaims::Scope(scope_from_request), JetAccessType::Scope(scope_needed))
(JetAccessTokenClaims::Scope(scope_from_request), JetTokenType::Scope(scope_needed))
if scope_from_request.scope == *scope_needed =>
{
return Ok(req);
}
(JetAccessTokenClaims::Bridge(_), JetTokenType::Bridge) => {
return Ok(req);
}
_ => {}
}
}

Err(StatusCode::FORBIDDEN)
Err(HttpErrorStatus::forbidden(
"Token provided can't be used to access the route",
))
}
}
2 changes: 1 addition & 1 deletion devolutions-gateway/src/http/http_server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ pub fn configure_http_server(config: Arc<Config>, jet_associations: JetAssociati
let sogar = SogarController::new(registry_name.as_str(), registry_namespace.as_str());
let token_controller = TokenController::new(config.clone());

let http_bridge = HttpBridgeController::new(config.clone());
let http_bridge = HttpBridgeController::new();

info!("Configuring HTTP router");

Expand Down
5 changes: 5 additions & 0 deletions devolutions-gateway/src/http/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,11 @@ impl HttpErrorStatus {
Self::new(StatusCode::FORBIDDEN, source)
}

#[track_caller]
fn unauthorized<T: Display + Send + 'static>(source: T) -> Self {
Self::new(StatusCode::UNAUTHORIZED, source)
}

#[track_caller]
fn internal<T: Display + Send + 'static>(source: T) -> Self {
Self::new(StatusCode::INTERNAL_SERVER_ERROR, source)
Expand Down
8 changes: 4 additions & 4 deletions devolutions-gateway/src/jet/association.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ use crate::jet::TransportType;
use chrono::serde::ts_seconds;
use chrono::{DateTime, Utc};
use indexmap::IndexMap;
use jet_proto::token::JetSessionTokenClaims;
use jet_proto::token::JetAssociationTokenClaims;
use serde_json::Value;
use uuid::Uuid;

Expand All @@ -12,11 +12,11 @@ pub struct Association {
version: u8,
creation_timestamp: DateTime<Utc>,
candidates: IndexMap<Uuid, Candidate>,
session_token: JetSessionTokenClaims,
session_token: JetAssociationTokenClaims,
}

impl Association {
pub fn new(id: Uuid, version: u8, session_token: JetSessionTokenClaims) -> Self {
pub fn new(id: Uuid, version: u8, session_token: JetAssociationTokenClaims) -> Self {
Association {
id,
version,
Expand Down Expand Up @@ -92,7 +92,7 @@ impl Association {
.any(|(_, candidate)| candidate.state() == CandidateState::Connected)
}

pub fn jet_session_token_claims(&self) -> &JetSessionTokenClaims {
pub fn jet_session_token_claims(&self) -> &JetAssociationTokenClaims {
&self.session_token
}

Expand Down
4 changes: 2 additions & 2 deletions devolutions-gateway/src/jet_client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ use crate::transport::{JetTransport, Transport};
use crate::utils::association::{remove_jet_association, ACCEPT_REQUEST_TIMEOUT};
use crate::utils::{create_tls_connector, into_other_io_error as error_other};
use crate::Proxy;
use jet_proto::token::JetSessionTokenClaims;
use jet_proto::token::JetAssociationTokenClaims;
use std::path::PathBuf;
use tokio_rustls::{TlsAcceptor, TlsStream};

Expand Down Expand Up @@ -448,7 +448,7 @@ pub struct HandleConnectJetMsgResponse {
pub server_transport: JetTransport,
pub association_id: Uuid,
pub candidate_id: Uuid,
pub session_token: JetSessionTokenClaims,
pub session_token: JetAssociationTokenClaims,
}

async fn handle_test_jet_msg(mut transport: JetTransport, request: JetTestReq) -> Result<(), io::Error> {
Expand Down
6 changes: 3 additions & 3 deletions devolutions-gateway/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ use lazy_static::lazy_static;
use tokio::sync::RwLock;
use uuid::Uuid;

use jet_proto::token::JetSessionTokenClaims;
use jet_proto::token::JetAssociationTokenClaims;
pub use proxy::Proxy;

use jet_proto::token::JetConnectionMode;
Expand Down Expand Up @@ -70,8 +70,8 @@ impl GatewaySessionInfo {
}
}

impl From<JetSessionTokenClaims> for GatewaySessionInfo {
fn from(session_token: JetSessionTokenClaims) -> Self {
impl From<JetAssociationTokenClaims> for GatewaySessionInfo {
fn from(session_token: JetAssociationTokenClaims) -> Self {
GatewaySessionInfo {
association_id: session_token.jet_aid,
application_protocol: session_token.jet_ap.clone(),
Expand Down
Loading

0 comments on commit 2beacd1

Please sign in to comment.