From d536aa408658ca0783028670d8b67b844e06f16a Mon Sep 17 00:00:00 2001 From: Pieter Date: Thu, 4 Apr 2024 17:43:24 +0100 Subject: [PATCH] refactor: simplify by getting rid of the unneeded User wrapper (#1722) * refactor: simplify by getting rid of the unneeded User wrapper * refactor: use async_trait from axum --- common/Cargo.toml | 1 + common/src/claims.rs | 23 ++++++++++++- gateway/src/api/latest.rs | 22 ++++++------ gateway/src/api/project_caller.rs | 2 +- gateway/src/auth.rs | 57 +++++-------------------------- 5 files changed, 44 insertions(+), 61 deletions(-) diff --git a/common/Cargo.toml b/common/Cargo.toml index ba5a8ff45..1d96b009e 100644 --- a/common/Cargo.toml +++ b/common/Cargo.toml @@ -44,6 +44,7 @@ wiremock = { workspace = true, optional = true } [features] axum = ["dep:axum"] claims = [ + "axum", "bytes", "chrono/clock", "headers", diff --git a/common/src/claims.rs b/common/src/claims.rs index a27c6e4ec..89fe09682 100644 --- a/common/src/claims.rs +++ b/common/src/claims.rs @@ -5,10 +5,11 @@ use std::{ task::{Context, Poll}, }; +use axum::extract::FromRequestParts; use bytes::Bytes; use chrono::{Duration, Utc}; use headers::{Authorization, HeaderMapExt}; -use http::{Request, StatusCode}; +use http::{request::Parts, Request, StatusCode}; use http_body::combinators::UnsyncBoxBody; use jsonwebtoken::{decode, encode, DecodingKey, EncodingKey, Header, Validation}; use opentelemetry::global; @@ -332,6 +333,26 @@ impl Claim { } } +/// Extract the claim from the request and fail with unauthorized if the claim doesn't exist +#[axum::async_trait] +impl FromRequestParts for Claim { + type Rejection = StatusCode; + + async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result { + let claim = parts + .extensions + .get::() + .ok_or(StatusCode::UNAUTHORIZED)?; + + // Record current account name for tracing purposes + Span::current().record("account.user_id", &claim.sub); + + trace!(?claim, "got user"); + + Ok(claim.clone()) + } +} + // Future for layers that just return the inner response #[pin_project] pub struct ResponseFuture(#[pin] pub F); diff --git a/gateway/src/api/latest.rs b/gateway/src/api/latest.rs index f6abf863f..10ea0278b 100644 --- a/gateway/src/api/latest.rs +++ b/gateway/src/api/latest.rs @@ -25,7 +25,7 @@ use shuttle_backends::metrics::{Metrics, TraceLayer}; use shuttle_backends::project_name::ProjectName; use shuttle_backends::request_span; use shuttle_backends::ClaimExt; -use shuttle_common::claims::{Scope, EXP_MINUTES}; +use shuttle_common::claims::{Claim, Scope, EXP_MINUTES}; use shuttle_common::models::error::ErrorKind; use shuttle_common::models::service; use shuttle_common::models::{admin::ProjectResponse, project, stats}; @@ -47,7 +47,7 @@ use x509_parser::time::ASN1Time; use crate::acme::{AccountWrapper, AcmeClient, CustomDomain}; use crate::api::tracing::project_name_tracing_layer; -use crate::auth::{ScopedUser, User}; +use crate::auth::ScopedUser; use crate::service::{ContainerSettings, GatewayService}; use crate::task::{self, BoxedTask}; use crate::tls::{GatewayCertResolver, RENEWAL_VALIDITY_THRESHOLD_IN_DAYS}; @@ -131,12 +131,12 @@ async fn check_project_name( } async fn get_projects_list( State(RouterState { service, .. }): State, - User { id, .. }: User, + Claim { sub, .. }: Claim, ) -> Result>, Error> { let mut projects = vec![]; for p in service .permit_client - .get_user_projects(&id) + .get_user_projects(&sub) .await .map_err(|_| Error::from(ErrorKind::Internal))? { @@ -163,7 +163,7 @@ async fn create_project( State(RouterState { service, sender, .. }): State, - User { id, claim, .. }: User, + claim: Claim, CustomErrorPath(project_name): CustomErrorPath, AxumJson(config): AxumJson, ) -> Result, Error> { @@ -172,7 +172,7 @@ async fn create_project( // Check that the user is within their project limits. let can_create_project = claim.can_create_project( service - .get_project_count(&id) + .get_project_count(&claim.sub) .await? .saturating_sub(is_cch_project as u32), ); @@ -184,7 +184,7 @@ async fn create_project( let project = service .create_project( project_name.clone(), - &id, + &claim.sub, claim.is_admin(), can_create_project, if is_cch_project { @@ -398,7 +398,7 @@ async fn override_create_service( scoped_user: ScopedUser, req: Request, ) -> Result, Error> { - let user_id = scoped_user.user.id.clone(); + let user_id = scoped_user.claim.sub.clone(); let posthog_client = state.posthog_client.clone(); tokio::spawn(async move { let event = async_posthog::Event::new("shuttle_api_start_deployment", &user_id); @@ -460,9 +460,9 @@ async fn route_project( let project_name = scoped_user.scope; let is_cch_project = project_name.is_cch_project(); - if !scoped_user.user.claim.is_admin() { + if !scoped_user.claim.is_admin() { service - .has_capacity(is_cch_project, &scoped_user.user.claim.tier) + .has_capacity(is_cch_project, &scoped_user.claim.tier) .await?; } @@ -471,7 +471,7 @@ async fn route_project( .await? .0; service - .route(&project.state, &project_name, &scoped_user.user.id, req) + .route(&project.state, &project_name, &scoped_user.claim.sub, req) .await } diff --git a/gateway/src/api/project_caller.rs b/gateway/src/api/project_caller.rs index cddf214ed..ff1676e82 100644 --- a/gateway/src/api/project_caller.rs +++ b/gateway/src/api/project_caller.rs @@ -43,7 +43,7 @@ impl ProjectCaller { Ok(Self { project: project.state, project_name, - user_id: scoped_user.user.id, + user_id: scoped_user.claim.sub, service, headers: headers.clone(), }) diff --git a/gateway/src/auth.rs b/gateway/src/auth.rs index cb3dc6b4b..b2bc72d65 100644 --- a/gateway/src/auth.rs +++ b/gateway/src/auth.rs @@ -1,55 +1,14 @@ -use std::fmt::Debug; - use axum::extract::{FromRef, FromRequestParts, Path}; use axum::http::request::Parts; -use serde::{Deserialize, Serialize}; use shuttle_backends::project_name::ProjectName; use shuttle_backends::ClaimExt; use shuttle_common::claims::Claim; use shuttle_common::models::error::InvalidProjectName; -use shuttle_common::models::user::UserId; -use tracing::{error, trace, Span}; +use tracing::error; use crate::api::latest::RouterState; use crate::{Error, ErrorKind}; -/// A wrapper to enrich a token with user details -/// -/// The `FromRequest` impl consumes the API claim and enriches it with project -/// details. Generally you want to use [`ScopedUser`] instead to ensure the request -/// is valid against the user's owned resources. -#[derive(Clone, Deserialize, PartialEq, Eq, Serialize, Debug)] -pub struct User { - pub claim: Claim, - pub id: UserId, -} - -#[async_trait] -impl FromRequestParts for User -where - S: Send + Sync, - RouterState: FromRef, -{ - type Rejection = Error; - - async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result { - let claim = parts.extensions.get::().ok_or(ErrorKind::Internal)?; - let user_id = claim.sub.clone(); - - // Record current account name for tracing purposes - Span::current().record("account.user_id", &user_id); - - let user = User { - claim: claim.clone(), - id: user_id, - }; - - trace!(?user, "got user"); - - Ok(user) - } -} - /// A wrapper for a guard that validates a user's API token *and* /// scopes the request to a project they own. /// @@ -57,7 +16,7 @@ where /// by [`ScopedUser::name`]. #[derive(Clone)] pub struct ScopedUser { - pub user: User, + pub claim: Claim, pub scope: ProjectName, } @@ -70,7 +29,9 @@ where type Rejection = Error; async fn from_request_parts(parts: &mut Parts, state: &S) -> Result { - let user = User::from_request_parts(parts, state).await?; + let claim = Claim::from_request_parts(parts, state) + .await + .map_err(|_| ErrorKind::Unauthorized)?; let scope = match Path::::from_request_parts(parts, state).await { Ok(Path(p)) => p, @@ -82,12 +43,12 @@ where let RouterState { service, .. } = RouterState::from_ref(state); - let allowed = user.claim.is_admin() - || user.claim.is_deployer() + let allowed = claim.is_admin() + || claim.is_deployer() || service .permit_client .allowed( - &user.id, + &claim.sub, &service.find_project_by_name(&scope).await?.id, "develop", // TODO: make this configurable per endpoint? ) @@ -98,7 +59,7 @@ where })?; if allowed { - Ok(Self { user, scope }) + Ok(Self { claim, scope }) } else { Err(Error::from(ErrorKind::ProjectNotFound(scope.to_string()))) }