Skip to content

Commit

Permalink
refactor: simplify by getting rid of the unneeded User wrapper (#1722)
Browse files Browse the repository at this point in the history
* refactor: simplify by getting rid of the unneeded User wrapper

* refactor: use async_trait from axum
  • Loading branch information
chesedo authored Apr 4, 2024
1 parent d155595 commit d536aa4
Show file tree
Hide file tree
Showing 5 changed files with 44 additions and 61 deletions.
1 change: 1 addition & 0 deletions common/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ wiremock = { workspace = true, optional = true }
[features]
axum = ["dep:axum"]
claims = [
"axum",
"bytes",
"chrono/clock",
"headers",
Expand Down
23 changes: 22 additions & 1 deletion common/src/claims.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,11 @@ use std::{
task::{Context, Poll},
};

use axum::extract::FromRequestParts;
use bytes::Bytes;
use chrono::{Duration, Utc};
use headers::{Authorization, HeaderMapExt};
use http::{Request, StatusCode};
use http::{request::Parts, Request, StatusCode};
use http_body::combinators::UnsyncBoxBody;
use jsonwebtoken::{decode, encode, DecodingKey, EncodingKey, Header, Validation};
use opentelemetry::global;
Expand Down Expand Up @@ -332,6 +333,26 @@ impl Claim {
}
}

/// Extract the claim from the request and fail with unauthorized if the claim doesn't exist
#[axum::async_trait]
impl<S> FromRequestParts<S> for Claim {
type Rejection = StatusCode;

async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
let claim = parts
.extensions
.get::<Claim>()
.ok_or(StatusCode::UNAUTHORIZED)?;

// Record current account name for tracing purposes
Span::current().record("account.user_id", &claim.sub);

trace!(?claim, "got user");

Ok(claim.clone())
}
}

// Future for layers that just return the inner response
#[pin_project]
pub struct ResponseFuture<F>(#[pin] pub F);
Expand Down
22 changes: 11 additions & 11 deletions gateway/src/api/latest.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ use shuttle_backends::metrics::{Metrics, TraceLayer};
use shuttle_backends::project_name::ProjectName;
use shuttle_backends::request_span;
use shuttle_backends::ClaimExt;
use shuttle_common::claims::{Scope, EXP_MINUTES};
use shuttle_common::claims::{Claim, Scope, EXP_MINUTES};
use shuttle_common::models::error::ErrorKind;
use shuttle_common::models::service;
use shuttle_common::models::{admin::ProjectResponse, project, stats};
Expand All @@ -47,7 +47,7 @@ use x509_parser::time::ASN1Time;

use crate::acme::{AccountWrapper, AcmeClient, CustomDomain};
use crate::api::tracing::project_name_tracing_layer;
use crate::auth::{ScopedUser, User};
use crate::auth::ScopedUser;
use crate::service::{ContainerSettings, GatewayService};
use crate::task::{self, BoxedTask};
use crate::tls::{GatewayCertResolver, RENEWAL_VALIDITY_THRESHOLD_IN_DAYS};
Expand Down Expand Up @@ -131,12 +131,12 @@ async fn check_project_name(
}
async fn get_projects_list(
State(RouterState { service, .. }): State<RouterState>,
User { id, .. }: User,
Claim { sub, .. }: Claim,
) -> Result<AxumJson<Vec<project::Response>>, Error> {
let mut projects = vec![];
for p in service
.permit_client
.get_user_projects(&id)
.get_user_projects(&sub)
.await
.map_err(|_| Error::from(ErrorKind::Internal))?
{
Expand All @@ -163,7 +163,7 @@ async fn create_project(
State(RouterState {
service, sender, ..
}): State<RouterState>,
User { id, claim, .. }: User,
claim: Claim,
CustomErrorPath(project_name): CustomErrorPath<ProjectName>,
AxumJson(config): AxumJson<project::Config>,
) -> Result<AxumJson<project::Response>, Error> {
Expand All @@ -172,7 +172,7 @@ async fn create_project(
// Check that the user is within their project limits.
let can_create_project = claim.can_create_project(
service
.get_project_count(&id)
.get_project_count(&claim.sub)
.await?
.saturating_sub(is_cch_project as u32),
);
Expand All @@ -184,7 +184,7 @@ async fn create_project(
let project = service
.create_project(
project_name.clone(),
&id,
&claim.sub,
claim.is_admin(),
can_create_project,
if is_cch_project {
Expand Down Expand Up @@ -398,7 +398,7 @@ async fn override_create_service(
scoped_user: ScopedUser,
req: Request<Body>,
) -> Result<Response<Body>, Error> {
let user_id = scoped_user.user.id.clone();
let user_id = scoped_user.claim.sub.clone();
let posthog_client = state.posthog_client.clone();
tokio::spawn(async move {
let event = async_posthog::Event::new("shuttle_api_start_deployment", &user_id);
Expand Down Expand Up @@ -460,9 +460,9 @@ async fn route_project(
let project_name = scoped_user.scope;
let is_cch_project = project_name.is_cch_project();

if !scoped_user.user.claim.is_admin() {
if !scoped_user.claim.is_admin() {
service
.has_capacity(is_cch_project, &scoped_user.user.claim.tier)
.has_capacity(is_cch_project, &scoped_user.claim.tier)
.await?;
}

Expand All @@ -471,7 +471,7 @@ async fn route_project(
.await?
.0;
service
.route(&project.state, &project_name, &scoped_user.user.id, req)
.route(&project.state, &project_name, &scoped_user.claim.sub, req)
.await
}

Expand Down
2 changes: 1 addition & 1 deletion gateway/src/api/project_caller.rs
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ impl ProjectCaller {
Ok(Self {
project: project.state,
project_name,
user_id: scoped_user.user.id,
user_id: scoped_user.claim.sub,
service,
headers: headers.clone(),
})
Expand Down
57 changes: 9 additions & 48 deletions gateway/src/auth.rs
Original file line number Diff line number Diff line change
@@ -1,63 +1,22 @@
use std::fmt::Debug;

use axum::extract::{FromRef, FromRequestParts, Path};
use axum::http::request::Parts;
use serde::{Deserialize, Serialize};
use shuttle_backends::project_name::ProjectName;
use shuttle_backends::ClaimExt;
use shuttle_common::claims::Claim;
use shuttle_common::models::error::InvalidProjectName;
use shuttle_common::models::user::UserId;
use tracing::{error, trace, Span};
use tracing::error;

use crate::api::latest::RouterState;
use crate::{Error, ErrorKind};

/// A wrapper to enrich a token with user details
///
/// The `FromRequest` impl consumes the API claim and enriches it with project
/// details. Generally you want to use [`ScopedUser`] instead to ensure the request
/// is valid against the user's owned resources.
#[derive(Clone, Deserialize, PartialEq, Eq, Serialize, Debug)]
pub struct User {
pub claim: Claim,
pub id: UserId,
}

#[async_trait]
impl<S> FromRequestParts<S> for User
where
S: Send + Sync,
RouterState: FromRef<S>,
{
type Rejection = Error;

async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
let claim = parts.extensions.get::<Claim>().ok_or(ErrorKind::Internal)?;
let user_id = claim.sub.clone();

// Record current account name for tracing purposes
Span::current().record("account.user_id", &user_id);

let user = User {
claim: claim.clone(),
id: user_id,
};

trace!(?user, "got user");

Ok(user)
}
}

/// A wrapper for a guard that validates a user's API token *and*
/// scopes the request to a project they own.
///
/// It is guaranteed that [`ScopedUser::scope`] exists and is owned
/// by [`ScopedUser::name`].
#[derive(Clone)]
pub struct ScopedUser {
pub user: User,
pub claim: Claim,
pub scope: ProjectName,
}

Expand All @@ -70,7 +29,9 @@ where
type Rejection = Error;

async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> {
let user = User::from_request_parts(parts, state).await?;
let claim = Claim::from_request_parts(parts, state)
.await
.map_err(|_| ErrorKind::Unauthorized)?;

let scope = match Path::<ProjectName>::from_request_parts(parts, state).await {
Ok(Path(p)) => p,
Expand All @@ -82,12 +43,12 @@ where

let RouterState { service, .. } = RouterState::from_ref(state);

let allowed = user.claim.is_admin()
|| user.claim.is_deployer()
let allowed = claim.is_admin()
|| claim.is_deployer()
|| service
.permit_client
.allowed(
&user.id,
&claim.sub,
&service.find_project_by_name(&scope).await?.id,
"develop", // TODO: make this configurable per endpoint?
)
Expand All @@ -98,7 +59,7 @@ where
})?;

if allowed {
Ok(Self { user, scope })
Ok(Self { claim, scope })
} else {
Err(Error::from(ErrorKind::ProjectNotFound(scope.to_string())))
}
Expand Down

0 comments on commit d536aa4

Please sign in to comment.