Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: revert addition of apikey to auth #886

Merged
merged 2 commits into from
May 8, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion auth/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ tracing-subscriber = { workspace = true }

[dependencies.shuttle-common]
workspace = true
features = ["backend", "models", "persist"]
features = ["backend", "models"]

[dev-dependencies]
axum-extra = { version = "0.7.1", features = ["cookie"] }
Expand Down
2 changes: 1 addition & 1 deletion auth/src/api/handlers.rs
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ pub(crate) async fn convert_key(
let User {
name, account_tier, ..
} = user_manager
.get_user_by_key(key.as_ref().clone())
.get_user_by_key(key)
.await
.map_err(|_| StatusCode::UNAUTHORIZED)?;

Expand Down
17 changes: 8 additions & 9 deletions auth/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@ mod user;
use std::{io, str::FromStr, time::Duration};

use args::StartArgs;
use shuttle_common::ApiKey;
use sqlx::{
migrate::Migrator,
query,
Expand All @@ -16,7 +15,10 @@ use sqlx::{
};
use tracing::info;

use crate::{api::serve, user::AccountTier};
use crate::{
api::serve,
user::{AccountTier, Key},
};
pub use api::ApiBuilder;
pub use args::{Args, Commands, InitArgs};

Expand All @@ -39,8 +41,8 @@ pub async fn start(pool: SqlitePool, args: StartArgs) -> io::Result<()> {

pub async fn init(pool: SqlitePool, args: InitArgs) -> io::Result<()> {
let key = match args.key {
Some(ref key) => ApiKey::parse(key).unwrap(),
None => ApiKey::generate(),
Some(ref key) => Key::from_str(key).unwrap(),
None => Key::new_random(),
};

query("INSERT INTO users (account_name, key, account_tier) VALUES (?1, ?2, ?3)")
Expand All @@ -51,11 +53,8 @@ pub async fn init(pool: SqlitePool, args: InitArgs) -> io::Result<()> {
.await
.map_err(|e| io::Error::new(io::ErrorKind::Other, e))?;

println!(
"`{}` created as super user with key: {}",
args.name,
key.as_ref()
);
println!("`{}` created as super user with key: {key}", args.name,);

Ok(())
}

Expand Down
62 changes: 36 additions & 26 deletions auth/src/user.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,9 @@ use axum::{
http::request::Parts,
TypedHeader,
};
use rand::distributions::{Alphanumeric, DistString};
use serde::{Deserialize, Deserializer, Serialize};
use shuttle_common::{
claims::{Scope, ScopeBuilder},
ApiKey,
};
use shuttle_common::claims::{Scope, ScopeBuilder};
use sqlx::{query, Row, SqlitePool};
use tracing::{trace, Span};

Expand All @@ -21,7 +19,7 @@ use crate::{api::UserManagerState, error::Error};
pub trait UserManagement: Send + Sync {
async fn create_user(&self, name: AccountName, tier: AccountTier) -> Result<User, Error>;
async fn get_user(&self, name: AccountName) -> Result<User, Error>;
async fn get_user_by_key(&self, key: ApiKey) -> Result<User, Error>;
async fn get_user_by_key(&self, key: Key) -> Result<User, Error>;
}

#[derive(Clone)]
Expand All @@ -32,7 +30,7 @@ pub struct UserManager {
#[async_trait]
impl UserManagement for UserManager {
async fn create_user(&self, name: AccountName, tier: AccountTier) -> Result<User, Error> {
let key = ApiKey::generate();
let key = Key::new_random();

query("INSERT INTO users (account_name, key, account_tier) VALUES (?1, ?2, ?3)")
.bind(&name)
Expand All @@ -57,7 +55,7 @@ impl UserManagement for UserManager {
.ok_or(Error::UserNotFound)
}

async fn get_user_by_key(&self, key: ApiKey) -> Result<User, Error> {
async fn get_user_by_key(&self, key: Key) -> Result<User, Error> {
query("SELECT account_name, key, account_tier FROM users WHERE key = ?1")
.bind(&key)
.fetch_optional(&self.pool)
Expand All @@ -71,10 +69,10 @@ impl UserManagement for UserManager {
}
}

#[derive(Clone, Deserialize, PartialEq, Eq, Serialize, Debug)]
#[derive(Clone, Debug, Deserialize, PartialEq, Eq, Serialize)]
pub struct User {
pub name: AccountName,
pub key: ApiKey,
pub key: Key,
pub account_tier: AccountTier,
}

Expand All @@ -83,7 +81,7 @@ impl User {
self.account_tier == AccountTier::Admin
}

pub fn new(name: AccountName, key: ApiKey, account_tier: AccountTier) -> Self {
pub fn new(name: AccountName, key: Key, account_tier: AccountTier) -> Self {
Self {
name,
key,
Expand All @@ -106,7 +104,7 @@ where
let user_manager: UserManagerState = UserManagerState::from_ref(state);

let user = user_manager
.get_user_by_key(key.as_ref().clone())
.get_user_by_key(key)
.await
// Absorb any error into `Unauthorized`
.map_err(|_| Error::Unauthorized)?;
Expand All @@ -122,21 +120,16 @@ impl From<User> for shuttle_common::models::user::Response {
fn from(user: User) -> Self {
Self {
name: user.name.to_string(),
key: user.key.as_ref().to_string(),
key: user.key.to_string(),
account_tier: user.account_tier.to_string(),
}
}
}

/// A wrapper around [ApiKey] so we can implement [FromRequestParts]
/// for it.
pub struct Key(ApiKey);

impl AsRef<ApiKey> for Key {
fn as_ref(&self) -> &ApiKey {
&self.0
}
}
#[derive(Clone, sqlx::Type, PartialEq, Hash, Eq, Serialize, Deserialize, Debug)]
#[serde(transparent)]
#[sqlx(transparent)]
pub struct Key(String);

#[async_trait]
impl<S> FromRequestParts<S> for Key
Expand All @@ -149,14 +142,31 @@ where
let key = TypedHeader::<Authorization<Bearer>>::from_request_parts(parts, state)
.await
.map_err(|_| Error::KeyMissing)
.and_then(|TypedHeader(Authorization(bearer))| {
let bearer = bearer.token().trim();
ApiKey::parse(bearer).map_err(|_| Self::Rejection::Unauthorized)
})?;
.and_then(|TypedHeader(Authorization(bearer))| bearer.token().trim().parse())?;

trace!("got bearer key");

Ok(Key(key))
Ok(key)
}
}

impl std::fmt::Display for Key {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
write!(f, "{}", self.0)
}
}

impl FromStr for Key {
type Err = Error;

fn from_str(s: &str) -> Result<Self, Self::Err> {
Ok(Self(s.to_string()))
}
}

impl Key {
pub fn new_random() -> Self {
Self(Alphanumeric.sample_string(&mut rand::thread_rng(), 16))
}
}

Expand Down