diff --git a/axum/oauth2/.gitignore b/axum/oauth2/.gitignore new file mode 100644 index 00000000..4fcf1cd7 --- /dev/null +++ b/axum/oauth2/.gitignore @@ -0,0 +1,2 @@ +target/ +Secrets*.toml diff --git a/axum/oauth2/Cargo.toml b/axum/oauth2/Cargo.toml new file mode 100644 index 00000000..881193fb --- /dev/null +++ b/axum/oauth2/Cargo.toml @@ -0,0 +1,22 @@ +[package] +name = "oauth-axum-shuttle-ex" +version = "0.1.0" +edition = "2021" + +[dependencies] +anyhow = "1.0.72" +axum = { version = "0.7.2", features = ["multipart", "macros"] } +axum-extra = { version = "0.9.2", features = ["cookie-private"] } +chrono = { version = "0.4.35", features = ["clock"] } +oauth2 = "4.4.1" +reqwest = { version = "0.11.18", features = ["json"] } +serde = { version = "1.0.183", features = ["derive"] } +shuttle-axum = "0.41.0" +shuttle-runtime = "0.41.0" +shuttle-secrets = "0.41.0" +shuttle-shared-db = { version = "0.41.0", features = ["postgres", "sqlx"] } +sqlx = { version = "0.7.2", features = ["runtime-tokio-rustls", "macros", "chrono"] } +thiserror = "1.0.57" +time = "0.3.25" +tokio = "1.28.2" +tracing = "0.1.37" diff --git a/axum/oauth2/README.md b/axum/oauth2/README.md new file mode 100644 index 00000000..bc3237a2 --- /dev/null +++ b/axum/oauth2/README.md @@ -0,0 +1,18 @@ +## OAuth Axum Rust Example +This repo is an example of how you can quickly and easily implement OAuth using the Axum web framework in Rust. Hosted on Shuttle. + +### How to Run +Make sure you set up your Google OAuth, which you can find a link to set up [here.](https://console.cloud.google.com/apis/dashboard) + +Initialise your Shuttle project with `cargo shuttle init`: +```sh +cargo shuttle init --from shuttle-hq/examples --subfolder axum/oauth2 +``` + +Set your secrets in the Secrets.toml file: +```toml +GOOGLE_OAUTH_CLIENT_ID = "your-client-id" +GOOGLE_OAUTH_CLIENT_SECRET = "your-client-secret" +``` + +Use `cargo shuttle run` and visit `http://localhost:8000` once the app is running, then try it out! diff --git a/axum/oauth2/migrations/20230815100114_schema.down.sql b/axum/oauth2/migrations/20230815100114_schema.down.sql new file mode 100644 index 00000000..16f2ef4c --- /dev/null +++ b/axum/oauth2/migrations/20230815100114_schema.down.sql @@ -0,0 +1,3 @@ +-- Add down migration script here +DROP TABLE users; +DROP TABLE sessions; \ No newline at end of file diff --git a/axum/oauth2/migrations/20230815100114_schema.up.sql b/axum/oauth2/migrations/20230815100114_schema.up.sql new file mode 100644 index 00000000..5e5d8234 --- /dev/null +++ b/axum/oauth2/migrations/20230815100114_schema.up.sql @@ -0,0 +1,15 @@ +-- Add up migration script here +CREATE TABLE IF NOT EXISTS users ( + id SERIAL PRIMARY KEY, + email VARCHAR(255) NOT NULL UNIQUE, + created_at TIMESTAMP WITH TIME ZONE DEFAULT CURRENT_TIMESTAMP, + last_updated TIMESTAMP WITH TIME ZONE DEFAULT CURRENT_TIMESTAMP +); + +CREATE TABLE IF NOT EXISTS sessions ( + id SERIAL PRIMARY KEY, + user_id INT NOT NULL UNIQUE, + session_id VARCHAR NOT NULL, + expires_at TIMESTAMP WITH TIME ZONE NOT NULL, + FOREIGN KEY (user_id) REFERENCES users(id) +); \ No newline at end of file diff --git a/axum/oauth2/src/main.rs b/axum/oauth2/src/main.rs new file mode 100644 index 00000000..e6ff535f --- /dev/null +++ b/axum/oauth2/src/main.rs @@ -0,0 +1,93 @@ +use axum::{extract::FromRef, response::Html, routing::get, Extension, Router}; +use axum_extra::extract::cookie::Key; +use oauth2::{basic::BasicClient, AuthUrl, ClientId, ClientSecret, RedirectUrl, TokenUrl}; +use reqwest::Client; +use routes::oauth; +use shuttle_secrets::SecretStore; +use sqlx::PgPool; +pub mod routes; + +#[derive(Clone)] +pub struct AppState { + db: PgPool, + ctx: Client, + key: Key, +} + +// this impl tells `SignedCookieJar` how to access the key from our state +impl FromRef for Key { + fn from_ref(state: &AppState) -> Self { + state.key.clone() + } +} + +#[shuttle_runtime::main] +async fn axum( + #[shuttle_shared_db::Postgres] db: PgPool, + #[shuttle_secrets::Secrets] secrets: SecretStore, +) -> shuttle_axum::ShuttleAxum { + sqlx::migrate!() + .run(&db) + .await + .expect("Failed migrations :("); + + let oauth_id = secrets.get("GOOGLE_OAUTH_CLIENT_ID").unwrap(); + let oauth_secret = secrets.get("GOOGLE_OAUTH_CLIENT_SECRET").unwrap(); + + let ctx = Client::new(); + + let state = AppState { + db, + ctx, + key: Key::generate(), + }; + + let oauth_client = build_oauth_client(oauth_id.clone(), oauth_secret); + + let router = init_router(state, oauth_client, oauth_id); + + Ok(router.into()) +} + +fn init_router(state: AppState, oauth_client: BasicClient, oauth_id: String) -> Router { + let auth_router = Router::new().route("/auth/google_callback", get(oauth::google_callback)); + + let protected_router = Router::new().route("/", get(oauth::protected)); + + let homepage_router = Router::new() + .route("/", get(homepage)) + .layer(Extension(oauth_id)); + + Router::new() + .nest("/api", auth_router) + .nest("/protected", protected_router) + .nest("/", homepage_router) + .layer(Extension(oauth_client)) + .with_state(state) +} + +fn build_oauth_client(client_id: String, client_secret: String) -> BasicClient { + let redirect_url = "http://localhost:8000/api/auth/google_callback".to_string(); + + let auth_url = AuthUrl::new("https://accounts.google.com/o/oauth2/v2/auth".to_string()) + .expect("Invalid authorization endpoint URL"); + let token_url = TokenUrl::new("https://www.googleapis.com/oauth2/v3/token".to_string()) + .expect("Invalid token endpoint URL"); + + BasicClient::new( + ClientId::new(client_id), + Some(ClientSecret::new(client_secret)), + auth_url, + Some(token_url), + ) + .set_redirect_uri(RedirectUrl::new(redirect_url).unwrap()) +} + +#[axum::debug_handler] +async fn homepage(Extension(oauth_id): Extension) -> Html { + Html(format!("

Welcome!

+ + + Click here to sign into Google! + ")) +} diff --git a/axum/oauth2/src/routes/errors.rs b/axum/oauth2/src/routes/errors.rs new file mode 100644 index 00000000..0a01fdf7 --- /dev/null +++ b/axum/oauth2/src/routes/errors.rs @@ -0,0 +1,46 @@ +use axum::http::StatusCode; +use axum::response::{IntoResponse, Response}; +use thiserror::Error; + +#[derive(Debug, Error)] +pub enum ApiError { + #[error("SQL error: {0}")] + SQL(#[from] sqlx::Error), + #[error("HTTP request error: {0}")] + Request(#[from] reqwest::Error), + #[error("OAuth token error: {0}")] + TokenError( + #[from] + oauth2::RequestTokenError< + oauth2::reqwest::Error, + oauth2::StandardErrorResponse, + >, + ), + #[error("You're not authorized!")] + Unauthorized, + #[error("Attempted to get a non-none value but found none")] + OptionError, + #[error("Attempted to parse a number to an integer but errored out: {0}")] + ParseIntError(#[from] std::num::TryFromIntError), + #[error("Encountered an error trying to convert an infallible value: {0}")] + FromRequestPartsError(#[from] std::convert::Infallible), +} + +impl IntoResponse for ApiError { + fn into_response(self) -> Response { + let response = match self { + Self::SQL(e) => (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()), + Self::Request(e) => (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()), + Self::TokenError(e) => (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()), + Self::Unauthorized => (StatusCode::UNAUTHORIZED, "Unauthorized!".to_string()), + Self::OptionError => ( + StatusCode::INTERNAL_SERVER_ERROR, + "Attempted to get a non-none value but found none".to_string(), + ), + Self::ParseIntError(e) => (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()), + Self::FromRequestPartsError(e) => (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()), + }; + + response.into_response() + } +} diff --git a/axum/oauth2/src/routes/mod.rs b/axum/oauth2/src/routes/mod.rs new file mode 100644 index 00000000..11e76bc2 --- /dev/null +++ b/axum/oauth2/src/routes/mod.rs @@ -0,0 +1,2 @@ +pub mod errors; +pub mod oauth; diff --git a/axum/oauth2/src/routes/oauth.rs b/axum/oauth2/src/routes/oauth.rs new file mode 100644 index 00000000..957e3b61 --- /dev/null +++ b/axum/oauth2/src/routes/oauth.rs @@ -0,0 +1,114 @@ +use crate::routes::errors::ApiError; + +use crate::AppState; +use axum::{ + extract::{FromRequest, FromRequestParts, Query, Request, State}, + http::StatusCode, + response::{IntoResponse, Redirect}, + Extension, +}; +use axum_extra::extract::cookie::{Cookie, PrivateCookieJar}; +use chrono::{Duration, Local}; +use oauth2::{basic::BasicClient, reqwest::async_http_client, AuthorizationCode, TokenResponse}; +use serde::Deserialize; +use time::Duration as TimeDuration; + +#[derive(Debug, Deserialize)] +pub struct AuthRequest { + code: String, +} + +pub async fn google_callback( + State(state): State, + jar: PrivateCookieJar, + Query(query): Query, + Extension(oauth_client): Extension, +) -> Result { + let token = oauth_client + .exchange_code(AuthorizationCode::new(query.code)) + .request_async(async_http_client) + .await?; + + let profile = state + .ctx + .get("https://openidconnect.googleapis.com/v1/userinfo") + .bearer_auth(token.access_token().secret().to_owned()) + .send() + .await?; + + let profile = profile.json::().await?; + + let Some(secs) = token.expires_in() else { + return Err(ApiError::OptionError); + }; + + let secs: i64 = secs.as_secs().try_into()?; + + let max_age = Local::now().naive_local() + Duration::try_seconds(secs).unwrap(); + + let cookie = Cookie::build(("sid", token.access_token().secret().to_owned())) + .domain(".app.localhost") + .path("/") + .secure(true) + .http_only(true) + .max_age(TimeDuration::seconds(secs)); + + sqlx::query("INSERT INTO users (email) VALUES ($1) ON CONFLICT (email) DO NOTHING") + .bind(profile.email.clone()) + .execute(&state.db) + .await?; + + sqlx::query( + "INSERT INTO sessions (user_id, session_id, expires_at) VALUES ( + (SELECT ID FROM USERS WHERE email = $1 LIMIT 1), + $2, $3) + ON CONFLICT (user_id) DO UPDATE SET + session_id = excluded.session_id, + expires_at = excluded.expires_at", + ) + .bind(profile.email) + .bind(token.access_token().secret().to_owned()) + .bind(max_age) + .execute(&state.db) + .await?; + + Ok((jar.add(cookie), Redirect::to("/protected"))) +} + +#[derive(Deserialize, sqlx::FromRow, Clone)] +pub struct UserProfile { + email: String, +} + +#[axum::async_trait] +impl FromRequest for UserProfile { + type Rejection = ApiError; + async fn from_request(req: Request, state: &AppState) -> Result { + let state = state.to_owned(); + let (mut parts, _body) = req.into_parts(); + let cookiejar: PrivateCookieJar = + PrivateCookieJar::from_request_parts(&mut parts, &state).await?; + + let Some(cookie) = cookiejar.get("sid").map(|cookie| cookie.value().to_owned()) else { + return Err(ApiError::Unauthorized); + }; + + let res = sqlx::query_as::<_, UserProfile>( + "SELECT + users.email + FROM sessions + LEFT JOIN USERS ON sessions.user_id = users.id + WHERE sessions.session_id = $1 + LIMIT 1", + ) + .bind(cookie) + .fetch_one(&state.db) + .await?; + + Ok(Self { email: res.email }) + } +} + +pub async fn protected(profile: UserProfile) -> impl IntoResponse { + (StatusCode::OK, profile.email) +} diff --git a/templates.toml b/templates.toml index 4596e18d..d4a07311 100644 --- a/templates.toml +++ b/templates.toml @@ -163,6 +163,13 @@ path = "axum/jwt-authentication" use_cases = ["Web app", "Authentication"] tags = ["axum", "jwt"] +[templates.axum-oauth2] +title = "OAuth authentication" +description = "Use Google OAuth to authenticate API endpoints" +path = "axum/oauth2" +use_cases = ["Web app", "Authentication"] +tags = ["axum", "oauth"] + [templates.axum-postgres] title = "Postgres" description = "Todo list with a Postgres database"