Skip to content

Commit

Permalink
fix: revert addition of apikey to auth (#886)
Browse files Browse the repository at this point in the history
* fix: revert addition of apikey to auth

* fix: display impl is needed for key.to_string()
  • Loading branch information
oddgrd authored May 8, 2023
1 parent ab12fdd commit 7054e6a
Show file tree
Hide file tree
Showing 4 changed files with 46 additions and 37 deletions.
2 changes: 1 addition & 1 deletion auth/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ tracing-subscriber = { workspace = true }

[dependencies.shuttle-common]
workspace = true
features = ["backend", "models", "persist"]
features = ["backend", "models"]

[dev-dependencies]
axum-extra = { version = "0.7.1", features = ["cookie"] }
Expand Down
2 changes: 1 addition & 1 deletion auth/src/api/handlers.rs
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ pub(crate) async fn convert_key(
let User {
name, account_tier, ..
} = user_manager
.get_user_by_key(key.as_ref().clone())
.get_user_by_key(key)
.await
.map_err(|_| StatusCode::UNAUTHORIZED)?;

Expand Down
17 changes: 8 additions & 9 deletions auth/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@ mod user;
use std::{io, str::FromStr, time::Duration};

use args::StartArgs;
use shuttle_common::ApiKey;
use sqlx::{
migrate::Migrator,
query,
Expand All @@ -16,7 +15,10 @@ use sqlx::{
};
use tracing::info;

use crate::{api::serve, user::AccountTier};
use crate::{
api::serve,
user::{AccountTier, Key},
};
pub use api::ApiBuilder;
pub use args::{Args, Commands, InitArgs};

Expand All @@ -39,8 +41,8 @@ pub async fn start(pool: SqlitePool, args: StartArgs) -> io::Result<()> {

pub async fn init(pool: SqlitePool, args: InitArgs) -> io::Result<()> {
let key = match args.key {
Some(ref key) => ApiKey::parse(key).unwrap(),
None => ApiKey::generate(),
Some(ref key) => Key::from_str(key).unwrap(),
None => Key::new_random(),
};

query("INSERT INTO users (account_name, key, account_tier) VALUES (?1, ?2, ?3)")
Expand All @@ -51,11 +53,8 @@ pub async fn init(pool: SqlitePool, args: InitArgs) -> io::Result<()> {
.await
.map_err(|e| io::Error::new(io::ErrorKind::Other, e))?;

println!(
"`{}` created as super user with key: {}",
args.name,
key.as_ref()
);
println!("`{}` created as super user with key: {key}", args.name,);

Ok(())
}

Expand Down
62 changes: 36 additions & 26 deletions auth/src/user.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,9 @@ use axum::{
http::request::Parts,
TypedHeader,
};
use rand::distributions::{Alphanumeric, DistString};
use serde::{Deserialize, Deserializer, Serialize};
use shuttle_common::{
claims::{Scope, ScopeBuilder},
ApiKey,
};
use shuttle_common::claims::{Scope, ScopeBuilder};
use sqlx::{query, Row, SqlitePool};
use tracing::{trace, Span};

Expand All @@ -21,7 +19,7 @@ use crate::{api::UserManagerState, error::Error};
pub trait UserManagement: Send + Sync {
async fn create_user(&self, name: AccountName, tier: AccountTier) -> Result<User, Error>;
async fn get_user(&self, name: AccountName) -> Result<User, Error>;
async fn get_user_by_key(&self, key: ApiKey) -> Result<User, Error>;
async fn get_user_by_key(&self, key: Key) -> Result<User, Error>;
}

#[derive(Clone)]
Expand All @@ -32,7 +30,7 @@ pub struct UserManager {
#[async_trait]
impl UserManagement for UserManager {
async fn create_user(&self, name: AccountName, tier: AccountTier) -> Result<User, Error> {
let key = ApiKey::generate();
let key = Key::new_random();

query("INSERT INTO users (account_name, key, account_tier) VALUES (?1, ?2, ?3)")
.bind(&name)
Expand All @@ -57,7 +55,7 @@ impl UserManagement for UserManager {
.ok_or(Error::UserNotFound)
}

async fn get_user_by_key(&self, key: ApiKey) -> Result<User, Error> {
async fn get_user_by_key(&self, key: Key) -> Result<User, Error> {
query("SELECT account_name, key, account_tier FROM users WHERE key = ?1")
.bind(&key)
.fetch_optional(&self.pool)
Expand All @@ -71,10 +69,10 @@ impl UserManagement for UserManager {
}
}

#[derive(Clone, Deserialize, PartialEq, Eq, Serialize, Debug)]
#[derive(Clone, Debug, Deserialize, PartialEq, Eq, Serialize)]
pub struct User {
pub name: AccountName,
pub key: ApiKey,
pub key: Key,
pub account_tier: AccountTier,
}

Expand All @@ -83,7 +81,7 @@ impl User {
self.account_tier == AccountTier::Admin
}

pub fn new(name: AccountName, key: ApiKey, account_tier: AccountTier) -> Self {
pub fn new(name: AccountName, key: Key, account_tier: AccountTier) -> Self {
Self {
name,
key,
Expand All @@ -106,7 +104,7 @@ where
let user_manager: UserManagerState = UserManagerState::from_ref(state);

let user = user_manager
.get_user_by_key(key.as_ref().clone())
.get_user_by_key(key)
.await
// Absorb any error into `Unauthorized`
.map_err(|_| Error::Unauthorized)?;
Expand All @@ -122,21 +120,16 @@ impl From<User> for shuttle_common::models::user::Response {
fn from(user: User) -> Self {
Self {
name: user.name.to_string(),
key: user.key.as_ref().to_string(),
key: user.key.to_string(),
account_tier: user.account_tier.to_string(),
}
}
}

/// A wrapper around [ApiKey] so we can implement [FromRequestParts]
/// for it.
pub struct Key(ApiKey);

impl AsRef<ApiKey> for Key {
fn as_ref(&self) -> &ApiKey {
&self.0
}
}
#[derive(Clone, sqlx::Type, PartialEq, Hash, Eq, Serialize, Deserialize, Debug)]
#[serde(transparent)]
#[sqlx(transparent)]
pub struct Key(String);

#[async_trait]
impl<S> FromRequestParts<S> for Key
Expand All @@ -149,14 +142,31 @@ where
let key = TypedHeader::<Authorization<Bearer>>::from_request_parts(parts, state)
.await
.map_err(|_| Error::KeyMissing)
.and_then(|TypedHeader(Authorization(bearer))| {
let bearer = bearer.token().trim();
ApiKey::parse(bearer).map_err(|_| Self::Rejection::Unauthorized)
})?;
.and_then(|TypedHeader(Authorization(bearer))| bearer.token().trim().parse())?;

trace!("got bearer key");

Ok(Key(key))
Ok(key)
}
}

impl std::fmt::Display for Key {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
write!(f, "{}", self.0)
}
}

impl FromStr for Key {
type Err = Error;

fn from_str(s: &str) -> Result<Self, Self::Err> {
Ok(Self(s.to_string()))
}
}

impl Key {
pub fn new_random() -> Self {
Self(Alphanumeric.sample_string(&mut rand::thread_rng(), 16))
}
}

Expand Down

0 comments on commit 7054e6a

Please sign in to comment.