Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Oauth2 template #155

Merged
merged 4 commits into from
Mar 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions axum/oauth2/.gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
target/
Secrets*.toml
22 changes: 22 additions & 0 deletions axum/oauth2/Cargo.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
[package]
name = "oauth-axum-shuttle-ex"
version = "0.1.0"
edition = "2021"

[dependencies]
anyhow = "1.0.72"
axum = { version = "0.7.2", features = ["multipart", "macros"] }
axum-extra = { version = "0.9.2", features = ["cookie-private"] }
chrono = { version = "0.4.35", features = ["clock"] }
oauth2 = "4.4.1"
reqwest = { version = "0.11.18", features = ["json"] }
serde = { version = "1.0.183", features = ["derive"] }
shuttle-axum = "0.41.0"
shuttle-runtime = "0.41.0"
shuttle-secrets = "0.41.0"
shuttle-shared-db = { version = "0.41.0", features = ["postgres", "sqlx"] }
sqlx = { version = "0.7.2", features = ["runtime-tokio-rustls", "macros", "chrono"] }
thiserror = "1.0.57"
time = "0.3.25"
tokio = "1.28.2"
tracing = "0.1.37"
18 changes: 18 additions & 0 deletions axum/oauth2/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
## OAuth Axum Rust Example
This repo is an example of how you can quickly and easily implement OAuth using the Axum web framework in Rust. Hosted on Shuttle.

### How to Run
Make sure you set up your Google OAuth, which you can find a link to set up [here.](https://console.cloud.google.com/apis/dashboard)

Initialise your Shuttle project with `cargo shuttle init`:
```sh
cargo shuttle init --from shuttle-hq/examples --subfolder axum/oauth2
```

Set your secrets in the Secrets.toml file:
```toml
GOOGLE_OAUTH_CLIENT_ID = "your-client-id"
GOOGLE_OAUTH_CLIENT_SECRET = "your-client-secret"
```

Use `cargo shuttle run` and visit `http://localhost:8000` once the app is running, then try it out!
3 changes: 3 additions & 0 deletions axum/oauth2/migrations/20230815100114_schema.down.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
-- Add down migration script here
DROP TABLE users;
DROP TABLE sessions;
15 changes: 15 additions & 0 deletions axum/oauth2/migrations/20230815100114_schema.up.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
-- Add up migration script here
CREATE TABLE IF NOT EXISTS users (
id SERIAL PRIMARY KEY,
email VARCHAR(255) NOT NULL UNIQUE,
created_at TIMESTAMP WITH TIME ZONE DEFAULT CURRENT_TIMESTAMP,
last_updated TIMESTAMP WITH TIME ZONE DEFAULT CURRENT_TIMESTAMP
);

CREATE TABLE IF NOT EXISTS sessions (
id SERIAL PRIMARY KEY,
user_id INT NOT NULL UNIQUE,
session_id VARCHAR NOT NULL,
expires_at TIMESTAMP WITH TIME ZONE NOT NULL,
FOREIGN KEY (user_id) REFERENCES users(id)
);
93 changes: 93 additions & 0 deletions axum/oauth2/src/main.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
use axum::{extract::FromRef, response::Html, routing::get, Extension, Router};
use axum_extra::extract::cookie::Key;
use oauth2::{basic::BasicClient, AuthUrl, ClientId, ClientSecret, RedirectUrl, TokenUrl};
use reqwest::Client;
use routes::oauth;
use shuttle_secrets::SecretStore;
use sqlx::PgPool;
pub mod routes;

#[derive(Clone)]
pub struct AppState {
db: PgPool,
ctx: Client,
key: Key,
}

// this impl tells `SignedCookieJar` how to access the key from our state
impl FromRef<AppState> for Key {
fn from_ref(state: &AppState) -> Self {
state.key.clone()
}
}

#[shuttle_runtime::main]
async fn axum(
#[shuttle_shared_db::Postgres] db: PgPool,
#[shuttle_secrets::Secrets] secrets: SecretStore,
) -> shuttle_axum::ShuttleAxum {
sqlx::migrate!()
.run(&db)
.await
.expect("Failed migrations :(");

let oauth_id = secrets.get("GOOGLE_OAUTH_CLIENT_ID").unwrap();
let oauth_secret = secrets.get("GOOGLE_OAUTH_CLIENT_SECRET").unwrap();

let ctx = Client::new();

let state = AppState {
db,
ctx,
key: Key::generate(),
};

let oauth_client = build_oauth_client(oauth_id.clone(), oauth_secret);

let router = init_router(state, oauth_client, oauth_id);

Ok(router.into())
}

fn init_router(state: AppState, oauth_client: BasicClient, oauth_id: String) -> Router {
let auth_router = Router::new().route("/auth/google_callback", get(oauth::google_callback));

let protected_router = Router::new().route("/", get(oauth::protected));

let homepage_router = Router::new()
.route("/", get(homepage))
.layer(Extension(oauth_id));

Router::new()
.nest("/api", auth_router)
.nest("/protected", protected_router)
.nest("/", homepage_router)
.layer(Extension(oauth_client))
.with_state(state)
}

fn build_oauth_client(client_id: String, client_secret: String) -> BasicClient {
let redirect_url = "http://localhost:8000/api/auth/google_callback".to_string();

let auth_url = AuthUrl::new("https://accounts.google.com/o/oauth2/v2/auth".to_string())
.expect("Invalid authorization endpoint URL");
let token_url = TokenUrl::new("https://www.googleapis.com/oauth2/v3/token".to_string())
.expect("Invalid token endpoint URL");

BasicClient::new(
ClientId::new(client_id),
Some(ClientSecret::new(client_secret)),
auth_url,
Some(token_url),
)
.set_redirect_uri(RedirectUrl::new(redirect_url).unwrap())
}

#[axum::debug_handler]
async fn homepage(Extension(oauth_id): Extension<String>) -> Html<String> {
Html(format!("<p>Welcome!</p>

<a href=\"https://accounts.google.com/o/oauth2/v2/auth?scope=openid%20profile%20email&client_id={oauth_id}&response_type=code&redirect_uri=http://localhost:8000/api/auth/google_callback\">
Click here to sign into Google!
</a>"))
}
46 changes: 46 additions & 0 deletions axum/oauth2/src/routes/errors.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
use axum::http::StatusCode;
use axum::response::{IntoResponse, Response};
use thiserror::Error;

#[derive(Debug, Error)]
pub enum ApiError {
#[error("SQL error: {0}")]
SQL(#[from] sqlx::Error),
#[error("HTTP request error: {0}")]
Request(#[from] reqwest::Error),
#[error("OAuth token error: {0}")]
TokenError(
#[from]
oauth2::RequestTokenError<
oauth2::reqwest::Error<reqwest::Error>,
oauth2::StandardErrorResponse<oauth2::basic::BasicErrorResponseType>,
>,
),
#[error("You're not authorized!")]
Unauthorized,
#[error("Attempted to get a non-none value but found none")]
OptionError,
#[error("Attempted to parse a number to an integer but errored out: {0}")]
ParseIntError(#[from] std::num::TryFromIntError),
#[error("Encountered an error trying to convert an infallible value: {0}")]
FromRequestPartsError(#[from] std::convert::Infallible),
}

impl IntoResponse for ApiError {
fn into_response(self) -> Response {
let response = match self {
Self::SQL(e) => (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()),
Self::Request(e) => (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()),
Self::TokenError(e) => (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()),
Self::Unauthorized => (StatusCode::UNAUTHORIZED, "Unauthorized!".to_string()),
Self::OptionError => (
StatusCode::INTERNAL_SERVER_ERROR,
"Attempted to get a non-none value but found none".to_string(),
),
Self::ParseIntError(e) => (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()),
Self::FromRequestPartsError(e) => (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()),
};

response.into_response()
}
}
2 changes: 2 additions & 0 deletions axum/oauth2/src/routes/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
pub mod errors;
pub mod oauth;
114 changes: 114 additions & 0 deletions axum/oauth2/src/routes/oauth.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
use crate::routes::errors::ApiError;

use crate::AppState;
use axum::{
extract::{FromRequest, FromRequestParts, Query, Request, State},
http::StatusCode,
response::{IntoResponse, Redirect},
Extension,
};
use axum_extra::extract::cookie::{Cookie, PrivateCookieJar};
use chrono::{Duration, Local};
use oauth2::{basic::BasicClient, reqwest::async_http_client, AuthorizationCode, TokenResponse};
use serde::Deserialize;
use time::Duration as TimeDuration;

#[derive(Debug, Deserialize)]
pub struct AuthRequest {
code: String,
}

pub async fn google_callback(
State(state): State<AppState>,
jar: PrivateCookieJar,
Query(query): Query<AuthRequest>,
Extension(oauth_client): Extension<BasicClient>,
) -> Result<impl IntoResponse, ApiError> {
let token = oauth_client
.exchange_code(AuthorizationCode::new(query.code))
.request_async(async_http_client)
.await?;

let profile = state
.ctx
.get("https://openidconnect.googleapis.com/v1/userinfo")
.bearer_auth(token.access_token().secret().to_owned())
.send()
.await?;

let profile = profile.json::<UserProfile>().await?;

let Some(secs) = token.expires_in() else {
return Err(ApiError::OptionError);
};

let secs: i64 = secs.as_secs().try_into()?;

let max_age = Local::now().naive_local() + Duration::try_seconds(secs).unwrap();

let cookie = Cookie::build(("sid", token.access_token().secret().to_owned()))
.domain(".app.localhost")
.path("/")
.secure(true)
.http_only(true)
.max_age(TimeDuration::seconds(secs));

sqlx::query("INSERT INTO users (email) VALUES ($1) ON CONFLICT (email) DO NOTHING")
.bind(profile.email.clone())
.execute(&state.db)
.await?;

sqlx::query(
"INSERT INTO sessions (user_id, session_id, expires_at) VALUES (
(SELECT ID FROM USERS WHERE email = $1 LIMIT 1),
$2, $3)
ON CONFLICT (user_id) DO UPDATE SET
session_id = excluded.session_id,
expires_at = excluded.expires_at",
)
.bind(profile.email)
.bind(token.access_token().secret().to_owned())
.bind(max_age)
.execute(&state.db)
.await?;

Ok((jar.add(cookie), Redirect::to("/protected")))
}

#[derive(Deserialize, sqlx::FromRow, Clone)]
pub struct UserProfile {
email: String,
}

#[axum::async_trait]
impl FromRequest<AppState> for UserProfile {
type Rejection = ApiError;
async fn from_request(req: Request, state: &AppState) -> Result<Self, Self::Rejection> {
let state = state.to_owned();
let (mut parts, _body) = req.into_parts();
let cookiejar: PrivateCookieJar =
PrivateCookieJar::from_request_parts(&mut parts, &state).await?;

let Some(cookie) = cookiejar.get("sid").map(|cookie| cookie.value().to_owned()) else {
return Err(ApiError::Unauthorized);
};

let res = sqlx::query_as::<_, UserProfile>(
"SELECT
users.email
FROM sessions
LEFT JOIN USERS ON sessions.user_id = users.id
WHERE sessions.session_id = $1
LIMIT 1",
)
.bind(cookie)
.fetch_one(&state.db)
.await?;

Ok(Self { email: res.email })
}
}

pub async fn protected(profile: UserProfile) -> impl IntoResponse {
(StatusCode::OK, profile.email)
}
7 changes: 7 additions & 0 deletions templates.toml
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,13 @@ path = "axum/jwt-authentication"
use_cases = ["Web app", "Authentication"]
tags = ["axum", "jwt"]

[templates.axum-oauth2]
title = "OAuth authentication"
description = "Use Google OAuth to authenticate API endpoints"
path = "axum/oauth2"
use_cases = ["Web app", "Authentication"]
tags = ["axum", "oauth"]

[templates.axum-postgres]
title = "Postgres"
description = "Todo list with a Postgres database"
Expand Down