Skip to content

Commit

Permalink
Merge pull request #147 from RMoodsTeam/141-websocket-backend-service
Browse files Browse the repository at this point in the history
141 websocket backend service
  • Loading branch information
SebastianNowak01 authored Oct 28, 2024
2 parents f538be4 + 1acab32 commit cf74c3b
Show file tree
Hide file tree
Showing 8 changed files with 383 additions and 14 deletions.
2 changes: 1 addition & 1 deletion backend/.env.example
Original file line number Diff line number Diff line change
Expand Up @@ -3,4 +3,4 @@ CLIENT_SECRET=
DATABASE_URL=
JWT_SECRET=
GOOGLE_CLIENT_ID=
GOOGLE_CLIENT_SECRET=
GOOGLE_CLIENT_SECRET=
13 changes: 8 additions & 5 deletions backend/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,13 @@ opt-level = 1

[dependencies]
tokio = { version = "1.37.0", features = ["full"] }
axum = "0.7.5"
axum = { version = "0.7.5", features = ["ws"] }
env_logger = "0.11.3"
log = "0.4.21"
serde = "1.0.203"
serde_json = "1.0.117"
tower-http = {version = "0.5.2", features = ["full"]}
utoipa = {version = "4.2.3", features = ["axum_extras", "chrono"]}
tower-http = { version = "0.5.2", features = ["full"] }
utoipa = { version = "4.2.3", features = ["axum_extras", "chrono"] }
utoipa-swagger-ui = { version = "7", features = ["axum"] }
lazy_static = "1.5.0"
anyhow = "1.0.86"
Expand All @@ -25,8 +25,11 @@ thiserror = "1.0.62"
serde_with = "3.8.3"
log-derive = "0.4.1"
dotenvy = "0.15.7"
sqlx = { version = "0.7", features = [ "runtime-tokio", "postgres", "macros" ] }
sqlx = { version = "0.7", features = ["runtime-tokio", "postgres", "macros"] }
jsonwebtoken = "9.3.0"
chrono = "0.4.38"
http = "1.1.0"

tokio-tungstenite = "0.24.0"
futures-util = "0.3.31"
tokio-util = "0.7.12"
futures = "0.3.31"
42 changes: 40 additions & 2 deletions backend/src/api/auth/google.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,14 @@
use super::error::AuthError;
use crate::api::auth::jwt::decode_jwt;
use axum::async_trait;
use axum::extract::FromRequestParts;
use derive_getters::Getters;
use http::request::Parts;
use http::StatusCode;
use log_derive::logfn;
use reqwest::{multipart::Form, Client};
use serde::{Deserialize, Serialize};

use super::error::AuthError;

#[derive(Deserialize, Debug, Getters)]
#[allow(unused)]
pub struct GoogleTokenResponse {
Expand All @@ -29,6 +33,40 @@ pub struct GoogleUserInfo {
email_verified: bool,
}

/// Implement `FromRequestParts` for `GoogleUserInfo` to extract the user info from the request.
///
/// With this trait implemented, the user info can be extracted by any Axum HTTP handler without
/// jumping through extra hoops like obtaining an `Authorization` header and decoding the JWT.
#[async_trait]
impl<T> FromRequestParts<T> for GoogleUserInfo {
type Rejection = (StatusCode, String);

async fn from_request_parts(parts: &mut Parts, _: &T) -> Result<Self, Self::Rejection> {
let auth_header = parts
.headers
.get("Authorization")
.ok_or((
StatusCode::UNAUTHORIZED,
"No Authorization header".to_string(),
))?
.to_str()
.map_err(|_| {
(
StatusCode::UNAUTHORIZED,
"Invalid Authorization header".to_string(),
)
})?;

let token = auth_header.split_whitespace().nth(1).ok_or((
StatusCode::UNAUTHORIZED,
"No token in Authorization header".to_string(),
))?;

let claims = decode_jwt(token).map_err(|e| (StatusCode::UNAUTHORIZED, e.to_string()))?;
Ok(claims.claims.user_info)
}
}

#[logfn(err = "ERROR", fmt = "Failed to fetch access token: {:?}")]
pub async fn fetch_google_access_token(
auth_code: String,
Expand Down
2 changes: 1 addition & 1 deletion backend/src/api/auth/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use crate::AppState;
use axum::{routing::post, Router};

pub mod error;
mod google;
pub mod google;
pub mod jwt;
pub(crate) mod login;
pub mod middleware;
Expand Down
22 changes: 20 additions & 2 deletions backend/src/main.rs
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
use crate::open_api::ApiDoc;
use crate::reddit_fetcher::fetcher::RMoodsFetcher;
use crate::startup::{shutdown_signal, verify_environment};
use crate::websocket::SystemMessage;
use api::auth;
use axum::Router;
use http::header::{ACCEPT, AUTHORIZATION, CONTENT_TYPE};
use log::{error, info, warn};
use reqwest::Client;
use sqlx::{postgres::PgPoolOptions, Pool, Postgres};
use std::net::SocketAddr;
use tower_http::{
cors::{Any, CorsLayer},
trace::TraceLayer,
Expand All @@ -19,14 +21,17 @@ mod app_error;
mod open_api;
mod reddit_fetcher;
mod startup;
mod websocket;

/// State to be shared between all routes.
///
/// Contains common resources that shouldn't be created over and over again.
#[derive(Clone)]
pub struct AppState {
pub fetcher: RMoodsFetcher,
pub pool: Pool<Postgres>,
pub http: Client,
pub system_tx: tokio::sync::mpsc::Sender<SystemMessage>,
}

/// Run the server, assuming the environment has been already validated.
Expand All @@ -42,10 +47,20 @@ async fn run() -> anyhow::Result<()> {
let fetcher = RMoodsFetcher::new(http.clone()).await?;
info!("Connected to Reddit");

info!("Starting the WebSocket service");
let cancellation_token = tokio_util::sync::CancellationToken::new();

let (system_tx, system_rx) = tokio::sync::mpsc::channel::<SystemMessage>(100);
tokio::spawn(websocket::start_service(
system_rx,
cancellation_token.clone(),
));

let state = AppState {
fetcher,
pool,
http,
system_tx,
};

// Allow browsers to use GET and PUT from any origin
Expand All @@ -60,14 +75,17 @@ async fn run() -> anyhow::Result<()> {
let authorization = axum::middleware::from_fn(auth::middleware::authorization);

// Routes after the layers won't have the layers applied
// Example: /auth routes won't have the authorization layer, but /api will
let app = Router::<AppState>::new()
.nest("/api", api::router())
.nest("/ws", websocket::router())
.layer(authorization)
.nest("/auth", auth::router())
.with_state(state)
.layer(tracing)
.layer(cors)
.merge(SwaggerUi::new("/doc/ui").url("/doc/api.json", ApiDoc::openapi()));
.merge(SwaggerUi::new("/doc/ui").url("/doc/api.json", ApiDoc::openapi()))
.into_make_service_with_connect_info::<SocketAddr>();

let port = std::env::var("PORT").unwrap_or_else(|_| "8001".to_string());
// Listen on all addresses
Expand All @@ -77,7 +95,7 @@ async fn run() -> anyhow::Result<()> {

info!("Started the RMoods server at {}", addr);
axum::serve(listener, app)
.with_graceful_shutdown(shutdown_signal())
.with_graceful_shutdown(shutdown_signal(cancellation_token))
.await?;

Ok(())
Expand Down
17 changes: 14 additions & 3 deletions backend/src/startup.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
use log::info;
use tokio::signal;
use tokio_util::sync::CancellationToken;

/// Ensure that all necessary environment variables are available at server startup.
/// It's important to keep this updated as our .env file grows.
Expand All @@ -24,7 +26,7 @@ pub fn verify_environment() -> bool {
is_ok
}

pub async fn shutdown_signal() {
pub async fn shutdown_signal(cancellation_token: CancellationToken) {
let ctrl_c = async {
signal::ctrl_c()
.await
Expand All @@ -44,8 +46,17 @@ pub async fn shutdown_signal() {

tokio::select! {
_ = ctrl_c => {
log::info!("Received Ctrl+C signal, shutting down");
log::info!("Received Ctrl+C signal, shutting down main HTTP server");
cancellation_token.cancel();
info!("Shutting down WebSocket service");
},
_ = terminate => {
log::info!("Received SIGTERM, shutting down");
cancellation_token.cancel();
info!("Shutting down WebSocket service");
},
_ = cancellation_token.cancelled() => {
info!("WebSocket service triggered shutdown, shutting down main HTTP server");
},
_ = terminate => {},
}
}
Loading

0 comments on commit cf74c3b

Please sign in to comment.