From f888d1a12a3640f4a57ee7336bf20ed76bc13939 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Micha=C5=82=20Mi=C5=82ek?= Date: Fri, 25 Oct 2024 00:20:15 +0200 Subject: [PATCH 1/8] 141: Add basic WebSocket service alongside the main thread --- backend/.env.example | 1 + backend/Cargo.toml | 9 ++-- backend/src/main.rs | 12 +++++ backend/src/startup.rs | 1 + backend/src/websocket/mod.rs | 94 ++++++++++++++++++++++++++++++++++++ 5 files changed, 113 insertions(+), 4 deletions(-) create mode 100644 backend/src/websocket/mod.rs diff --git a/backend/.env.example b/backend/.env.example index 13e52ee..db3f476 100644 --- a/backend/.env.example +++ b/backend/.env.example @@ -4,3 +4,4 @@ DATABASE_URL= JWT_SECRET= GOOGLE_CLIENT_ID= GOOGLE_CLIENT_SECRET= +WEBSOCKET_PORT= \ No newline at end of file diff --git a/backend/Cargo.toml b/backend/Cargo.toml index a24eafb..cb9e5f4 100644 --- a/backend/Cargo.toml +++ b/backend/Cargo.toml @@ -13,8 +13,8 @@ env_logger = "0.11.3" log = "0.4.21" serde = "1.0.203" serde_json = "1.0.117" -tower-http = {version = "0.5.2", features = ["full"]} -utoipa = {version = "4.2.3", features = ["axum_extras", "chrono"]} +tower-http = { version = "0.5.2", features = ["full"] } +utoipa = { version = "4.2.3", features = ["axum_extras", "chrono"] } utoipa-swagger-ui = { version = "7", features = ["axum"] } lazy_static = "1.5.0" anyhow = "1.0.86" @@ -25,8 +25,9 @@ thiserror = "1.0.62" serde_with = "3.8.3" log-derive = "0.4.1" dotenvy = "0.15.7" -sqlx = { version = "0.7", features = [ "runtime-tokio", "postgres", "macros" ] } +sqlx = { version = "0.7", features = ["runtime-tokio", "postgres", "macros"] } jsonwebtoken = "9.3.0" chrono = "0.4.38" http = "1.1.0" - +tokio-tungstenite = "0.24.0" +futures-util = "0.3.31" \ No newline at end of file diff --git a/backend/src/main.rs b/backend/src/main.rs index 948bf5d..3192871 100644 --- a/backend/src/main.rs +++ b/backend/src/main.rs @@ -1,6 +1,7 @@ use crate::open_api::ApiDoc; use crate::reddit_fetcher::fetcher::RMoodsFetcher; use crate::startup::{shutdown_signal, verify_environment}; +use crate::websocket::WebSocketMessage; use api::auth; use axum::Router; use http::header::{ACCEPT, AUTHORIZATION, CONTENT_TYPE}; @@ -19,6 +20,7 @@ mod app_error; mod open_api; mod reddit_fetcher; mod startup; +mod websocket; /// State to be shared between all routes. /// Contains common resources that shouldn't be created over and over again. @@ -27,6 +29,7 @@ pub struct AppState { pub fetcher: RMoodsFetcher, pub pool: Pool, pub http: Client, + pub websocket_service_tx: tokio::sync::mpsc::Sender, } /// Run the server, assuming the environment has been already validated. @@ -42,10 +45,18 @@ async fn run() -> anyhow::Result<()> { let fetcher = RMoodsFetcher::new(http.clone()).await?; info!("Connected to Reddit"); + info!("Starting the WebSocket service"); + // Drop the receiver, we don't need it + let (tx, rx) = tokio::sync::mpsc::channel::(100); + let port = std::env::var("WEBSOCKET_PORT").expect("WEBSOCKET_PORT is set"); + let port = port.parse::().expect("WEBSOCKET_PORT is a valid u16"); + tokio::spawn(websocket::start_service(port, rx)); + let state = AppState { fetcher, pool, http, + websocket_service_tx: tx, }; // Allow browsers to use GET and PUT from any origin @@ -60,6 +71,7 @@ async fn run() -> anyhow::Result<()> { let authorization = axum::middleware::from_fn(auth::middleware::authorization); // Routes after the layers won't have the layers applied + // Example: /auth routes won't have the authorization layer, but /api will let app = Router::::new() .nest("/api", api::router()) .layer(authorization) diff --git a/backend/src/startup.rs b/backend/src/startup.rs index 56c9ded..61f46b4 100644 --- a/backend/src/startup.rs +++ b/backend/src/startup.rs @@ -10,6 +10,7 @@ pub fn verify_environment() -> bool { "JWT_SECRET", "GOOGLE_CLIENT_ID", "GOOGLE_CLIENT_SECRET", + "WEBSOCKET_PORT", ]; let defined: Vec = std::env::vars().map(|(k, _)| k).collect(); diff --git a/backend/src/websocket/mod.rs b/backend/src/websocket/mod.rs new file mode 100644 index 0000000..c537b9a --- /dev/null +++ b/backend/src/websocket/mod.rs @@ -0,0 +1,94 @@ +use futures_util::stream::{SplitSink, SplitStream}; +use futures_util::StreamExt; +use log::{info, warn}; +use std::collections::HashMap; +use tokio::net::{TcpListener, TcpStream}; +use tokio::sync::mpsc::Receiver; +use tokio_tungstenite::WebSocketStream; + +#[derive(Debug)] +pub enum WebSocketMessage { + NewRequestLimit(u16), + ReportDone(()), + Ping(String), +} + +async fn authenticate(mut ws: SplitStream>) -> anyhow::Result { + let msg = ws.next().await.unwrap()?.to_string(); + warn!("Authenticating: {}", msg); + Ok(msg) +} + +async fn accept_connection( + (stream, socket_addr): (TcpStream, std::net::SocketAddr), + peers: &mut HashMap< + String, + SplitSink, tokio_tungstenite::tungstenite::protocol::Message>, + >, +) { + info!("New connection from: {}", socket_addr.to_string()); + let ws_stream = tokio_tungstenite::accept_async(stream) + .await + .expect("Error during WebSocket handshake"); + let (write, read) = ws_stream.split(); + + tokio::select! { + _ = tokio::time::sleep(std::time::Duration::from_secs(10)) => { + info!("Ping timeout"); + }, + res = authenticate(read) => { + match res { + Ok(email) => { + info!("Authenticated email {email}"); + peers.insert(email, write); + info!("Peers: {}", peers.len()); + }, + Err(e) => { + warn!("Failed to authenticate: {}", e); + } + } + } + } +} + +async fn handle_system_message(msg: WebSocketMessage) { + match msg { + WebSocketMessage::NewRequestLimit(limit) => { + info!("New request limit: {}", limit); + } + WebSocketMessage::ReportDone(()) => { + info!("Report done"); + } + WebSocketMessage::Ping(msg) => { + info!("Ping: {}", msg); + } + } +} + +pub async fn start_service(port: u16, mut rx: Receiver) -> anyhow::Result<()> { + let addr = format!("127.0.0.1:{}", port); + let listener = TcpListener::bind(&addr).await?; + let mut peers: HashMap< + String, + SplitSink, tokio_tungstenite::tungstenite::protocol::Message>, + > = HashMap::new(); + + info!("WebSocket service listening on: {}", addr); + + // Some frontend client connects to this WebSocket service here + while let Ok((stream, socket_addr)) = listener.accept().await { + tokio::select! { + _ = accept_connection((stream, socket_addr), &mut peers) => { + info!("Connection accepted"); + }, + msg = rx.recv() => { + info!("Received message: {:?}", msg); + if let Some(msg) = msg { + handle_system_message(msg).await; + } + } + } + } + + Ok(()) +} From 6e990d9ae26b84c755574862309143e38ec849b7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Micha=C5=82=20Mi=C5=82ek?= Date: Fri, 25 Oct 2024 00:47:37 +0200 Subject: [PATCH 2/8] 141: Graceful shutdown of both services with tokio_util CancellationToken --- backend/Cargo.toml | 3 ++- backend/src/main.rs | 9 +++++++-- backend/src/startup.rs | 17 ++++++++++++++--- backend/src/websocket/mod.rs | 37 +++++++++++++++++++++++------------- 4 files changed, 47 insertions(+), 19 deletions(-) diff --git a/backend/Cargo.toml b/backend/Cargo.toml index cb9e5f4..b445acb 100644 --- a/backend/Cargo.toml +++ b/backend/Cargo.toml @@ -30,4 +30,5 @@ jsonwebtoken = "9.3.0" chrono = "0.4.38" http = "1.1.0" tokio-tungstenite = "0.24.0" -futures-util = "0.3.31" \ No newline at end of file +futures-util = "0.3.31" +tokio-util = "0.7.12" \ No newline at end of file diff --git a/backend/src/main.rs b/backend/src/main.rs index 3192871..097afe8 100644 --- a/backend/src/main.rs +++ b/backend/src/main.rs @@ -46,11 +46,16 @@ async fn run() -> anyhow::Result<()> { info!("Connected to Reddit"); info!("Starting the WebSocket service"); + let cancellation_token = tokio_util::sync::CancellationToken::new(); // Drop the receiver, we don't need it let (tx, rx) = tokio::sync::mpsc::channel::(100); let port = std::env::var("WEBSOCKET_PORT").expect("WEBSOCKET_PORT is set"); let port = port.parse::().expect("WEBSOCKET_PORT is a valid u16"); - tokio::spawn(websocket::start_service(port, rx)); + tokio::spawn(websocket::start_service( + port, + rx, + cancellation_token.clone(), + )); let state = AppState { fetcher, @@ -89,7 +94,7 @@ async fn run() -> anyhow::Result<()> { info!("Started the RMoods server at {}", addr); axum::serve(listener, app) - .with_graceful_shutdown(shutdown_signal()) + .with_graceful_shutdown(shutdown_signal(cancellation_token)) .await?; Ok(()) diff --git a/backend/src/startup.rs b/backend/src/startup.rs index 61f46b4..b500171 100644 --- a/backend/src/startup.rs +++ b/backend/src/startup.rs @@ -1,4 +1,6 @@ +use log::info; use tokio::signal; +use tokio_util::sync::CancellationToken; /// Ensure that all necessary environment variables are available at server startup. /// It's important to keep this updated as our .env file grows. @@ -25,7 +27,7 @@ pub fn verify_environment() -> bool { is_ok } -pub async fn shutdown_signal() { +pub async fn shutdown_signal(cancellation_token: CancellationToken) { let ctrl_c = async { signal::ctrl_c() .await @@ -45,8 +47,17 @@ pub async fn shutdown_signal() { tokio::select! { _ = ctrl_c => { - log::info!("Received Ctrl+C signal, shutting down"); + log::info!("Received Ctrl+C signal, shutting down main HTTP server"); + cancellation_token.cancel(); + info!("Shutting down WebSocket service"); + }, + _ = terminate => { + log::info!("Received SIGTERM, shutting down"); + cancellation_token.cancel(); + info!("Shutting down WebSocket service"); + }, + _ = cancellation_token.cancelled() => { + info!("WebSocket service triggered shutdown, shutting down main HTTP server"); }, - _ = terminate => {}, } } diff --git a/backend/src/websocket/mod.rs b/backend/src/websocket/mod.rs index c537b9a..6dd6690 100644 --- a/backend/src/websocket/mod.rs +++ b/backend/src/websocket/mod.rs @@ -5,6 +5,7 @@ use std::collections::HashMap; use tokio::net::{TcpListener, TcpStream}; use tokio::sync::mpsc::Receiver; use tokio_tungstenite::WebSocketStream; +use tokio_util::sync::CancellationToken; #[derive(Debug)] pub enum WebSocketMessage { @@ -19,7 +20,7 @@ async fn authenticate(mut ws: SplitStream>) -> anyhow Ok(msg) } -async fn accept_connection( +async fn accept_websocket_connection( (stream, socket_addr): (TcpStream, std::net::SocketAddr), peers: &mut HashMap< String, @@ -65,7 +66,11 @@ async fn handle_system_message(msg: WebSocketMessage) { } } -pub async fn start_service(port: u16, mut rx: Receiver) -> anyhow::Result<()> { +pub async fn start_service( + port: u16, + mut rx: Receiver, + cancellation_token: CancellationToken, +) -> anyhow::Result<()> { let addr = format!("127.0.0.1:{}", port); let listener = TcpListener::bind(&addr).await?; let mut peers: HashMap< @@ -75,18 +80,24 @@ pub async fn start_service(port: u16, mut rx: Receiver) -> any info!("WebSocket service listening on: {}", addr); - // Some frontend client connects to this WebSocket service here - while let Ok((stream, socket_addr)) = listener.accept().await { + loop { tokio::select! { - _ = accept_connection((stream, socket_addr), &mut peers) => { - info!("Connection accepted"); - }, - msg = rx.recv() => { - info!("Received message: {:?}", msg); - if let Some(msg) = msg { - handle_system_message(msg).await; - } - } + res = listener.accept() => { + info!("Connection accepted"); + if let Ok((stream, socket_addr)) = res { + info!("Connection from: {}", socket_addr.to_string()); + accept_websocket_connection((stream, socket_addr), &mut peers).await; + } + }, + msg = rx.recv() => { + info!("Received message: {:?}", msg); + if let Some(msg) = msg { + handle_system_message(msg).await; + } + } + _ = cancellation_token.cancelled() => { + break; + }, } } From cf512dd74dc780a856a2d1f64a480f171d88d4fd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Micha=C5=82=20Mi=C5=82ek?= Date: Fri, 25 Oct 2024 19:47:03 +0200 Subject: [PATCH 3/8] 141: Peer cleanup task, error handling, general improvements --- backend/src/main.rs | 6 +- backend/src/websocket/mod.rs | 127 ++++++++++++++++++++++++----------- 2 files changed, 91 insertions(+), 42 deletions(-) diff --git a/backend/src/main.rs b/backend/src/main.rs index 097afe8..ec13294 100644 --- a/backend/src/main.rs +++ b/backend/src/main.rs @@ -1,7 +1,7 @@ use crate::open_api::ApiDoc; use crate::reddit_fetcher::fetcher::RMoodsFetcher; use crate::startup::{shutdown_signal, verify_environment}; -use crate::websocket::WebSocketMessage; +use crate::websocket::SystemMessage; use api::auth; use axum::Router; use http::header::{ACCEPT, AUTHORIZATION, CONTENT_TYPE}; @@ -29,7 +29,7 @@ pub struct AppState { pub fetcher: RMoodsFetcher, pub pool: Pool, pub http: Client, - pub websocket_service_tx: tokio::sync::mpsc::Sender, + pub websocket_service_tx: tokio::sync::mpsc::Sender, } /// Run the server, assuming the environment has been already validated. @@ -48,7 +48,7 @@ async fn run() -> anyhow::Result<()> { info!("Starting the WebSocket service"); let cancellation_token = tokio_util::sync::CancellationToken::new(); // Drop the receiver, we don't need it - let (tx, rx) = tokio::sync::mpsc::channel::(100); + let (tx, rx) = tokio::sync::mpsc::channel::(100); let port = std::env::var("WEBSOCKET_PORT").expect("WEBSOCKET_PORT is set"); let port = port.parse::().expect("WEBSOCKET_PORT is a valid u16"); tokio::spawn(websocket::start_service( diff --git a/backend/src/websocket/mod.rs b/backend/src/websocket/mod.rs index 6dd6690..d1653d9 100644 --- a/backend/src/websocket/mod.rs +++ b/backend/src/websocket/mod.rs @@ -1,92 +1,141 @@ -use futures_util::stream::{SplitSink, SplitStream}; +use futures_util::stream::FusedStream; use futures_util::StreamExt; use log::{info, warn}; use std::collections::HashMap; +use std::sync::{Arc, Mutex}; +use thiserror::Error; use tokio::net::{TcpListener, TcpStream}; use tokio::sync::mpsc::Receiver; use tokio_tungstenite::WebSocketStream; use tokio_util::sync::CancellationToken; #[derive(Debug)] -pub enum WebSocketMessage { - NewRequestLimit(u16), +pub enum SystemMessage { + RemainingRequestsUpdate(u16), ReportDone(()), - Ping(String), } -async fn authenticate(mut ws: SplitStream>) -> anyhow::Result { - let msg = ws.next().await.unwrap()?.to_string(); +#[derive(Debug, Error)] +pub enum WebSocketServiceError { + #[error("Authentication timeout")] + AuthTimeout, + + #[error("Failed to authenticate: {0}")] + AuthError(String), + + #[error("WebSocket error: {0}")] + WebSocketError(#[from] tokio_tungstenite::tungstenite::Error), +} + +type PeersMap = Arc>>>; +const AUTH_TIMEOUT_SECS: u64 = 10; + +async fn authenticate_new_connection( + ws: &mut WebSocketStream, +) -> Result { + let msg = ws + .next() + .await + .ok_or_else(|| WebSocketServiceError::AuthError("No message in sink".to_string()))?? + .to_string(); warn!("Authenticating: {}", msg); + // TODO: Verify JWT Ok(msg) } async fn accept_websocket_connection( (stream, socket_addr): (TcpStream, std::net::SocketAddr), - peers: &mut HashMap< - String, - SplitSink, tokio_tungstenite::tungstenite::protocol::Message>, - >, -) { + peers: PeersMap, +) -> Result<(), WebSocketServiceError> { info!("New connection from: {}", socket_addr.to_string()); - let ws_stream = tokio_tungstenite::accept_async(stream) - .await - .expect("Error during WebSocket handshake"); - let (write, read) = ws_stream.split(); + let mut ws_stream = tokio_tungstenite::accept_async(stream).await?; + // Authenticate the new connection within a timeout of AUTH_TIMEOUT_SECS seconds tokio::select! { - _ = tokio::time::sleep(std::time::Duration::from_secs(10)) => { - info!("Ping timeout"); + _ = tokio::time::sleep(std::time::Duration::from_secs(AUTH_TIMEOUT_SECS)) => { + warn!("Authentication timeout"); + Err(WebSocketServiceError::AuthTimeout) }, - res = authenticate(read) => { - match res { - Ok(email) => { - info!("Authenticated email {email}"); - peers.insert(email, write); - info!("Peers: {}", peers.len()); - }, - Err(e) => { - warn!("Failed to authenticate: {}", e); - } + auth_res = authenticate_new_connection(&mut ws_stream) => { + if let Ok(email) = auth_res { + info!("Authenticated email {email}"); + let mut peers = peers.lock().unwrap(); + peers.insert(email, ws_stream); + info!("Peers: {}", peers.len()); + } else { + warn!("Failed to authenticate: {}", auth_res.unwrap_err()); } + Ok(()) } } } -async fn handle_system_message(msg: WebSocketMessage) { +async fn handle_system_message(msg: SystemMessage) { match msg { - WebSocketMessage::NewRequestLimit(limit) => { + SystemMessage::RemainingRequestsUpdate(limit) => { info!("New request limit: {}", limit); } - WebSocketMessage::ReportDone(()) => { + SystemMessage::ReportDone(()) => { info!("Report done"); } - WebSocketMessage::Ping(msg) => { - info!("Ping: {}", msg); + } +} + +async fn peer_cleanup_task(peers: PeersMap, cancellation_token: CancellationToken) { + info!("Starting peer cleanup task"); + + loop { + tokio::select! { + _ = tokio::time::sleep(std::time::Duration::from_secs(10)) => { + info!("Peer cleanup started"); + let mut removed = 0; + let mut peers_lock = peers.lock().expect("Failed to lock peers map"); + + let to_remove: Vec = peers_lock + .iter() + .filter_map(|(email, ws)| match ws.is_terminated() { + true => Some(email.clone()), + _ => None, + }) + .collect(); + + for email in to_remove { + peers_lock.remove(&email); + removed += 1; + } + + info!("Removed {} peers", removed); + }, + _ = cancellation_token.cancelled() => { + info!("Peer cleanup task shutting down"); + return; + } } } } pub async fn start_service( port: u16, - mut rx: Receiver, + mut rx: Receiver, cancellation_token: CancellationToken, ) -> anyhow::Result<()> { let addr = format!("127.0.0.1:{}", port); let listener = TcpListener::bind(&addr).await?; - let mut peers: HashMap< - String, - SplitSink, tokio_tungstenite::tungstenite::protocol::Message>, - > = HashMap::new(); + let peers: PeersMap = Arc::new(Mutex::new(HashMap::new())); + + info!("WebSocket service running on: {}", addr); - info!("WebSocket service listening on: {}", addr); + // spawn a task for peer cleanup + tokio::spawn(peer_cleanup_task(peers.clone(), cancellation_token.clone())); + // Until cancellation, accept new WS connections and handle system messages loop { tokio::select! { res = listener.accept() => { info!("Connection accepted"); if let Ok((stream, socket_addr)) = res { info!("Connection from: {}", socket_addr.to_string()); - accept_websocket_connection((stream, socket_addr), &mut peers).await; + let _ = accept_websocket_connection((stream, socket_addr), peers.clone()).await; } }, msg = rx.recv() => { From 8f437662f1ee70e40e6763d8ac5c33004b061a66 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Micha=C5=82=20Mi=C5=82ek?= Date: Sat, 26 Oct 2024 20:45:52 +0200 Subject: [PATCH 4/8] 141: Create WS connections with Axum, improve architecture and dead client handling --- backend/Cargo.toml | 5 +- backend/src/api/auth/google.rs | 42 ++++++- backend/src/api/auth/mod.rs | 2 +- backend/src/main.rs | 19 +-- backend/src/websocket/mod.rs | 222 +++++++++++++++------------------ backend/src/websocket/peers.rs | 88 +++++++++++++ 6 files changed, 245 insertions(+), 133 deletions(-) create mode 100644 backend/src/websocket/peers.rs diff --git a/backend/Cargo.toml b/backend/Cargo.toml index b445acb..de4d833 100644 --- a/backend/Cargo.toml +++ b/backend/Cargo.toml @@ -8,7 +8,7 @@ opt-level = 1 [dependencies] tokio = { version = "1.37.0", features = ["full"] } -axum = "0.7.5" +axum = { version = "0.7.5", features = ["ws"] } env_logger = "0.11.3" log = "0.4.21" serde = "1.0.203" @@ -31,4 +31,5 @@ chrono = "0.4.38" http = "1.1.0" tokio-tungstenite = "0.24.0" futures-util = "0.3.31" -tokio-util = "0.7.12" \ No newline at end of file +tokio-util = "0.7.12" +futures = "0.3.31" \ No newline at end of file diff --git a/backend/src/api/auth/google.rs b/backend/src/api/auth/google.rs index b6faf94..ccdb45d 100644 --- a/backend/src/api/auth/google.rs +++ b/backend/src/api/auth/google.rs @@ -1,10 +1,14 @@ +use super::error::AuthError; +use crate::api::auth::jwt::decode_jwt; +use axum::async_trait; +use axum::extract::FromRequestParts; use derive_getters::Getters; +use http::request::Parts; +use http::StatusCode; use log_derive::logfn; use reqwest::{multipart::Form, Client}; use serde::{Deserialize, Serialize}; -use super::error::AuthError; - #[derive(Deserialize, Debug, Getters)] #[allow(unused)] pub struct GoogleTokenResponse { @@ -29,6 +33,40 @@ pub struct GoogleUserInfo { email_verified: bool, } +/// Implement `FromRequestParts` for `GoogleUserInfo` to extract the user info from the request. +/// +/// With this trait implemented, the user info can be extracted by any Axum HTTP handler without +/// jumping through extra hoops like obtaining an `Authorization` header and decoding the JWT. +#[async_trait] +impl FromRequestParts for GoogleUserInfo { + type Rejection = (StatusCode, String); + + async fn from_request_parts(parts: &mut Parts, _: &T) -> Result { + let auth_header = parts + .headers + .get("Authorization") + .ok_or(( + StatusCode::UNAUTHORIZED, + "No Authorization header".to_string(), + ))? + .to_str() + .map_err(|_| { + ( + StatusCode::UNAUTHORIZED, + "Invalid Authorization header".to_string(), + ) + })?; + + let token = auth_header.split_whitespace().nth(1).ok_or(( + StatusCode::UNAUTHORIZED, + "No token in Authorization header".to_string(), + ))?; + + let claims = decode_jwt(token).map_err(|e| (StatusCode::UNAUTHORIZED, e.to_string()))?; + Ok(claims.claims.user_info) + } +} + #[logfn(err = "ERROR", fmt = "Failed to fetch access token: {:?}")] pub async fn fetch_google_access_token( auth_code: String, diff --git a/backend/src/api/auth/mod.rs b/backend/src/api/auth/mod.rs index 0103cbc..d14919b 100644 --- a/backend/src/api/auth/mod.rs +++ b/backend/src/api/auth/mod.rs @@ -2,7 +2,7 @@ use crate::AppState; use axum::{routing::post, Router}; pub mod error; -mod google; +pub mod google; pub mod jwt; pub(crate) mod login; pub mod middleware; diff --git a/backend/src/main.rs b/backend/src/main.rs index ec13294..62ff3a9 100644 --- a/backend/src/main.rs +++ b/backend/src/main.rs @@ -3,11 +3,13 @@ use crate::reddit_fetcher::fetcher::RMoodsFetcher; use crate::startup::{shutdown_signal, verify_environment}; use crate::websocket::SystemMessage; use api::auth; -use axum::Router; +use axum::handler::HandlerWithoutStateExt; +use axum::{Router, ServiceExt}; use http::header::{ACCEPT, AUTHORIZATION, CONTENT_TYPE}; use log::{error, info, warn}; use reqwest::Client; use sqlx::{postgres::PgPoolOptions, Pool, Postgres}; +use std::net::SocketAddr; use tower_http::{ cors::{Any, CorsLayer}, trace::TraceLayer, @@ -29,7 +31,7 @@ pub struct AppState { pub fetcher: RMoodsFetcher, pub pool: Pool, pub http: Client, - pub websocket_service_tx: tokio::sync::mpsc::Sender, + pub system_tx: tokio::sync::mpsc::Sender, } /// Run the server, assuming the environment has been already validated. @@ -47,13 +49,12 @@ async fn run() -> anyhow::Result<()> { info!("Starting the WebSocket service"); let cancellation_token = tokio_util::sync::CancellationToken::new(); - // Drop the receiver, we don't need it - let (tx, rx) = tokio::sync::mpsc::channel::(100); + + let (system_tx, mut system_rx) = tokio::sync::mpsc::channel::(100); let port = std::env::var("WEBSOCKET_PORT").expect("WEBSOCKET_PORT is set"); let port = port.parse::().expect("WEBSOCKET_PORT is a valid u16"); tokio::spawn(websocket::start_service( - port, - rx, + system_rx, cancellation_token.clone(), )); @@ -61,7 +62,7 @@ async fn run() -> anyhow::Result<()> { fetcher, pool, http, - websocket_service_tx: tx, + system_tx, }; // Allow browsers to use GET and PUT from any origin @@ -79,12 +80,14 @@ async fn run() -> anyhow::Result<()> { // Example: /auth routes won't have the authorization layer, but /api will let app = Router::::new() .nest("/api", api::router()) + .nest("/ws", websocket::router()) .layer(authorization) .nest("/auth", auth::router()) .with_state(state) .layer(tracing) .layer(cors) - .merge(SwaggerUi::new("/doc/ui").url("/doc/api.json", ApiDoc::openapi())); + .merge(SwaggerUi::new("/doc/ui").url("/doc/api.json", ApiDoc::openapi())) + .into_make_service_with_connect_info::(); let port = std::env::var("PORT").unwrap_or_else(|_| "8001".to_string()); // Listen on all addresses diff --git a/backend/src/websocket/mod.rs b/backend/src/websocket/mod.rs index d1653d9..d67e117 100644 --- a/backend/src/websocket/mod.rs +++ b/backend/src/websocket/mod.rs @@ -1,154 +1,136 @@ -use futures_util::stream::FusedStream; +use crate::api::auth::google::GoogleUserInfo; +use crate::AppState; +use axum::extract::ws::WebSocket; +use axum::extract::{ConnectInfo, State, WebSocketUpgrade}; +use axum::response::IntoResponse; +use axum::routing::any; use futures_util::StreamExt; use log::{info, warn}; -use std::collections::HashMap; -use std::sync::{Arc, Mutex}; -use thiserror::Error; -use tokio::net::{TcpListener, TcpStream}; -use tokio::sync::mpsc::Receiver; -use tokio_tungstenite::WebSocketStream; +use peers::PeersMap; +use std::net::SocketAddr; +use std::sync::atomic::AtomicUsize; +use tokio::sync::mpsc::{Receiver, Sender}; use tokio_util::sync::CancellationToken; +mod peers; + +/// Generates a unique user ID, thread safe. +fn generate_user_id() -> String { + static USER_ID_GEN: AtomicUsize = AtomicUsize::new(0); + USER_ID_GEN + .fetch_add(1, std::sync::atomic::Ordering::SeqCst) + .to_string() +} + +#[derive(Debug)] +pub struct ServiceToClientMessage; + +type ConnectionId = String; +type GoogleId = String; +/// A tuple of the user's Google ID and the WebSocket connection ID. +/// The Google ID is used to identify the user, while the WebSocket ID is used to identify the +/// connection. +/// +/// Thanks to the Google ID, the server can send messages to a specific user, even if they have +/// multiple connections. +type WsUserId = (GoogleId, ConnectionId); + #[derive(Debug)] pub enum SystemMessage { RemainingRequestsUpdate(u16), ReportDone(()), + AddPeer((WsUserId, Sender)), + RemovePeer(ConnectionId), } -#[derive(Debug, Error)] -pub enum WebSocketServiceError { - #[error("Authentication timeout")] - AuthTimeout, - - #[error("Failed to authenticate: {0}")] - AuthError(String), - - #[error("WebSocket error: {0}")] - WebSocketError(#[from] tokio_tungstenite::tungstenite::Error), -} - -type PeersMap = Arc>>>; -const AUTH_TIMEOUT_SECS: u64 = 10; - -async fn authenticate_new_connection( - ws: &mut WebSocketStream, -) -> Result { - let msg = ws - .next() - .await - .ok_or_else(|| WebSocketServiceError::AuthError("No message in sink".to_string()))?? - .to_string(); - warn!("Authenticating: {}", msg); - // TODO: Verify JWT - Ok(msg) +pub fn router() -> axum::Router { + axum::Router::new().route("/connect", any(websocket_handler)) } -async fn accept_websocket_connection( - (stream, socket_addr): (TcpStream, std::net::SocketAddr), - peers: PeersMap, -) -> Result<(), WebSocketServiceError> { - info!("New connection from: {}", socket_addr.to_string()); - let mut ws_stream = tokio_tungstenite::accept_async(stream).await?; - - // Authenticate the new connection within a timeout of AUTH_TIMEOUT_SECS seconds - tokio::select! { - _ = tokio::time::sleep(std::time::Duration::from_secs(AUTH_TIMEOUT_SECS)) => { - warn!("Authentication timeout"); - Err(WebSocketServiceError::AuthTimeout) - }, - auth_res = authenticate_new_connection(&mut ws_stream) => { - if let Ok(email) = auth_res { - info!("Authenticated email {email}"); - let mut peers = peers.lock().unwrap(); - peers.insert(email, ws_stream); - info!("Peers: {}", peers.len()); - } else { - warn!("Failed to authenticate: {}", auth_res.unwrap_err()); - } - Ok(()) - } - } +async fn websocket_handler( + ws: WebSocketUpgrade, + State(state): State, + ConnectInfo(socket_info): ConnectInfo, + google_user_info: GoogleUserInfo, +) -> impl IntoResponse { + ws.on_upgrade(move |ws| handle_socket(ws, state.system_tx, socket_info, google_user_info)) } -async fn handle_system_message(msg: SystemMessage) { - match msg { - SystemMessage::RemainingRequestsUpdate(limit) => { - info!("New request limit: {}", limit); - } - SystemMessage::ReportDone(()) => { - info!("Report done"); - } - } -} +async fn handle_socket( + mut socket: WebSocket, + system_tx: Sender, + socket_addr: SocketAddr, + user_info: GoogleUserInfo, +) { + info!("New WebSocket connection: {:?}", socket_addr); -async fn peer_cleanup_task(peers: PeersMap, cancellation_token: CancellationToken) { - info!("Starting peer cleanup task"); + dbg!(&user_info); - loop { - tokio::select! { - _ = tokio::time::sleep(std::time::Duration::from_secs(10)) => { - info!("Peer cleanup started"); - let mut removed = 0; - let mut peers_lock = peers.lock().expect("Failed to lock peers map"); + let user_ws_id_pair = (user_info.sub().to_string(), generate_user_id()); - let to_remove: Vec = peers_lock - .iter() - .filter_map(|(email, ws)| match ws.is_terminated() { - true => Some(email.clone()), - _ => None, - }) - .collect(); + let (tx, mut rx) = tokio::sync::mpsc::channel::(100); - for email in to_remove { - peers_lock.remove(&email); - removed += 1; - } + system_tx + .send(SystemMessage::AddPeer((user_ws_id_pair, tx))) + .await + .unwrap(); - info!("Removed {} peers", removed); - }, - _ = cancellation_token.cancelled() => { - info!("Peer cleanup task shutting down"); - return; + tokio::select! { + ws_msg_res = socket.next() => { + if let Some(msg) = ws_msg_res { + let msg = msg.unwrap(); + type Message = axum::extract::ws::Message; + match msg { + Message::Close(_) => { + warn!("Closing connection"); + system_tx + .send(SystemMessage::RemovePeer(user_info.sub().to_string())) + .await + .unwrap(); + return; + } + _ => { + warn!("Received message: {:?}", msg); + socket.send(msg).await.unwrap(); + } + }; + } + }, + service_msg_res = rx.recv() => { + if let Some(msg) = service_msg_res { + warn!("Received message from the main service: {:?}", msg); } } } } pub async fn start_service( - port: u16, - mut rx: Receiver, + mut system_rx: Receiver, cancellation_token: CancellationToken, ) -> anyhow::Result<()> { - let addr = format!("127.0.0.1:{}", port); - let listener = TcpListener::bind(&addr).await?; - let peers: PeersMap = Arc::new(Mutex::new(HashMap::new())); - - info!("WebSocket service running on: {}", addr); + let mut peers = PeersMap::new(); - // spawn a task for peer cleanup - tokio::spawn(peer_cleanup_task(peers.clone(), cancellation_token.clone())); - - // Until cancellation, accept new WS connections and handle system messages loop { tokio::select! { - res = listener.accept() => { - info!("Connection accepted"); - if let Ok((stream, socket_addr)) = res { - info!("Connection from: {}", socket_addr.to_string()); - let _ = accept_websocket_connection((stream, socket_addr), peers.clone()).await; - } - }, - msg = rx.recv() => { - info!("Received message: {:?}", msg); - if let Some(msg) = msg { - handle_system_message(msg).await; - } - } - _ = cancellation_token.cancelled() => { - break; - }, + msg = system_rx.recv() => { + if let Some(msg) = msg { + info!("Received system message: {:?}", msg); + match msg { + SystemMessage::AddPeer(ws_user_id) => { + peers.add_peer(ws_user_id); + } + SystemMessage::RemovePeer(connection_id) => { + peers.remove_peer(connection_id); + } + _ => {} + } + } + } + _ = cancellation_token.cancelled() => { + info!("WebSocket service shut down"); + break; + } } } - Ok(()) } diff --git a/backend/src/websocket/peers.rs b/backend/src/websocket/peers.rs new file mode 100644 index 0000000..33970e6 --- /dev/null +++ b/backend/src/websocket/peers.rs @@ -0,0 +1,88 @@ +use crate::websocket::{ConnectionId, ServiceToClientMessage, WsUserId}; +use std::collections::HashMap; +use tokio::sync::mpsc::Sender; + +#[derive(Debug)] +pub struct PeersMap { + peers: HashMap>, +} + +impl PeersMap { + pub fn new() -> Self { + Self { + peers: HashMap::new(), + } + } + + pub fn add_peer(&mut self, (user_id, sender): (WsUserId, Sender)) { + self.peers.insert(user_id, sender); + } + + pub fn remove_peer(&mut self, connection_id: ConnectionId) { + self.peers + .retain(|(_, conn_id), _| conn_id != &connection_id); + } +} + +#[cfg(test)] +mod tests { + use super::*; + use tokio::sync::mpsc::channel; + + #[tokio::test] + async fn test_add_peer() { + let mut peers = PeersMap::new(); + let (sender, _) = channel(1); + let user_id = ("google_id".to_string(), "conn_id".to_string()); + + peers.add_peer((user_id.clone(), sender.clone())); + assert_eq!(peers.peers.len(), 1); + assert!(peers.peers.get(&user_id).is_some()); + } + + #[tokio::test] + async fn test_remove_peer() { + let mut peers = PeersMap::new(); + let (sender, _) = channel(1); + let user_id = ("google_id".to_string(), "conn_id".to_string()); + + peers.add_peer((user_id.clone(), sender)); + assert_eq!(peers.peers.len(), 1); + peers.remove_peer("conn_id".to_string()); + assert_eq!(peers.peers.len(), 0); + } + + #[tokio::test] + async fn test_remove_peer_no_match() { + let mut peers = PeersMap::new(); + let (sender, _) = channel(1); + let user_id = ("google_id".to_string(), "conn_id".to_string()); + + peers.add_peer((user_id.clone(), sender)); + assert_eq!(peers.peers.len(), 1); + peers.remove_peer("conn_id_2".to_string()); + assert_eq!(peers.peers.len(), 1); + } + + #[tokio::test] + async fn test_remove_peer_empty() { + let mut peers = PeersMap::new(); + peers.remove_peer("conn_id".to_string()); + assert_eq!(peers.peers.len(), 0); + } + + #[tokio::test] + async fn test_remove_peer_multiple() { + let mut peers = PeersMap::new(); + let (sender, _) = channel(1); + let user_id = ("google_id".to_string(), "conn_id".to_string()); + let user_id_2 = ("google_id_2".to_string(), "conn_id_2".to_string()); + + peers.add_peer((user_id.clone(), sender.clone())); + peers.add_peer((user_id_2.clone(), sender.clone())); + assert_eq!(peers.peers.len(), 2); + peers.remove_peer("conn_id".to_string()); + assert_eq!(peers.peers.len(), 1); + assert!(peers.peers.get(&user_id_2).is_some()); + } +} From 865448fdfa180454031c57c9dff0ed72681911d0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Micha=C5=82=20Mi=C5=82ek?= Date: Sun, 27 Oct 2024 13:40:57 +0100 Subject: [PATCH 5/8] 141: Debug peer addition and removal --- backend/src/websocket/mod.rs | 52 ++++++++++++++++++---------------- backend/src/websocket/peers.rs | 8 ++++++ 2 files changed, 35 insertions(+), 25 deletions(-) diff --git a/backend/src/websocket/mod.rs b/backend/src/websocket/mod.rs index d67e117..3d04a7e 100644 --- a/backend/src/websocket/mod.rs +++ b/backend/src/websocket/mod.rs @@ -71,34 +71,36 @@ async fn handle_socket( let (tx, mut rx) = tokio::sync::mpsc::channel::(100); system_tx - .send(SystemMessage::AddPeer((user_ws_id_pair, tx))) + .send(SystemMessage::AddPeer((user_ws_id_pair.clone(), tx))) .await .unwrap(); - tokio::select! { - ws_msg_res = socket.next() => { - if let Some(msg) = ws_msg_res { - let msg = msg.unwrap(); - type Message = axum::extract::ws::Message; - match msg { - Message::Close(_) => { - warn!("Closing connection"); - system_tx - .send(SystemMessage::RemovePeer(user_info.sub().to_string())) - .await - .unwrap(); - return; - } - _ => { - warn!("Received message: {:?}", msg); - socket.send(msg).await.unwrap(); - } - }; - } - }, - service_msg_res = rx.recv() => { - if let Some(msg) = service_msg_res { - warn!("Received message from the main service: {:?}", msg); + loop { + tokio::select! { + ws_msg_res = socket.next() => { + if let Some(msg) = ws_msg_res { + let msg = msg.unwrap(); + type Message = axum::extract::ws::Message; + match msg { + Message::Close(_) => { + warn!("Closing connection"); + system_tx + .send(SystemMessage::RemovePeer(user_ws_id_pair.1)) + .await + .unwrap(); + return; + } + _ => { + warn!("Received message: {:?}", msg); + socket.send(msg).await.unwrap(); + } + }; + } + }, + service_msg_res = rx.recv() => { + if let Some(msg) = service_msg_res { + warn!("Received message from the main service: {:?}", msg); + } } } } diff --git a/backend/src/websocket/peers.rs b/backend/src/websocket/peers.rs index 33970e6..4d8c4b6 100644 --- a/backend/src/websocket/peers.rs +++ b/backend/src/websocket/peers.rs @@ -19,8 +19,16 @@ impl PeersMap { } pub fn remove_peer(&mut self, connection_id: ConnectionId) { + let len_before = self.peers.len(); self.peers .retain(|(_, conn_id), _| conn_id != &connection_id); + let len_after = self.peers.len(); + if len_before == len_after { + log::error!( + "No peer with connection ID {}. No peers removed. Each connection_id that is to be removed must come from the peer-specific handler task.", + connection_id + ); + } } } From c7a30836cd363ea4770a7bf40ef845b6092cc8e1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Micha=C5=82=20Mi=C5=82ek?= Date: Sun, 27 Oct 2024 15:19:39 +0100 Subject: [PATCH 6/8] 141: Refactor, add documentation --- backend/src/websocket/mod.rs | 96 +++++++++++++++++++++++++--------- backend/src/websocket/peers.rs | 11 ++++ 2 files changed, 82 insertions(+), 25 deletions(-) diff --git a/backend/src/websocket/mod.rs b/backend/src/websocket/mod.rs index 3d04a7e..97eaed6 100644 --- a/backend/src/websocket/mod.rs +++ b/backend/src/websocket/mod.rs @@ -5,7 +5,7 @@ use axum::extract::{ConnectInfo, State, WebSocketUpgrade}; use axum::response::IntoResponse; use axum::routing::any; use futures_util::StreamExt; -use log::{info, warn}; +use log::{error, info, warn}; use peers::PeersMap; use std::net::SocketAddr; use std::sync::atomic::AtomicUsize; @@ -27,14 +27,16 @@ pub struct ServiceToClientMessage; type ConnectionId = String; type GoogleId = String; + /// A tuple of the user's Google ID and the WebSocket connection ID. -/// The Google ID is used to identify the user, while the WebSocket ID is used to identify the -/// connection. +/// * Google ID is used to identify the user +/// * WebSocket ID is used to identify the connection. /// /// Thanks to the Google ID, the server can send messages to a specific user, even if they have /// multiple connections. type WsUserId = (GoogleId, ConnectionId); +/// Messages that the WebSocket Service can receive from the main HTTP process. #[derive(Debug)] pub enum SystemMessage { RemainingRequestsUpdate(u16), @@ -43,10 +45,14 @@ pub enum SystemMessage { RemovePeer(ConnectionId), } +/// Defines the WebSocket routes. pub fn router() -> axum::Router { axum::Router::new().route("/connect", any(websocket_handler)) } +/// Handles the WebSocket upgrade request. +/// The handler is responsible for creating a new WebSocket connection and managing it. +/// Frontend calls this endpoint to establish a WebSocket connection. async fn websocket_handler( ws: WebSocketUpgrade, State(state): State, @@ -56,6 +62,13 @@ async fn websocket_handler( ws.on_upgrade(move |ws| handle_socket(ws, state.system_tx, socket_info, google_user_info)) } +/// Handles a new WebSocket connection. +/// One client (one Google account) can have many active connections at once. +/// The server uses the Google ID to identify the user, and the WebSocket ID to identify the connection. +/// +/// The handler can communicate with the WebSocket Service through `mpsc` channels. +/// 1. `system_tx` channel is the app-wide channel to send messages to the WebSocket Service. +/// 2. `service_to_client_tx` channel is the chanel where the Service sends messages to the client handlers. async fn handle_socket( mut socket: WebSocket, system_tx: Sender, @@ -63,49 +76,77 @@ async fn handle_socket( user_info: GoogleUserInfo, ) { info!("New WebSocket connection: {:?}", socket_addr); - dbg!(&user_info); let user_ws_id_pair = (user_info.sub().to_string(), generate_user_id()); - let (tx, mut rx) = tokio::sync::mpsc::channel::(100); + let (service_to_client_tx, mut service_to_client_rx) = + tokio::sync::mpsc::channel::(100); + // Register the connection. We give the Service our tx, so it can call the handler when needed. system_tx - .send(SystemMessage::AddPeer((user_ws_id_pair.clone(), tx))) + .send(SystemMessage::AddPeer(( + user_ws_id_pair.clone(), + service_to_client_tx, + ))) .await .unwrap(); + // Cleanup function to remove the peer from the system + let peer_cleanup = || async { + system_tx + .send(SystemMessage::RemovePeer(user_ws_id_pair.1)) + .await + .unwrap(); + }; + loop { tokio::select! { - ws_msg_res = socket.next() => { - if let Some(msg) = ws_msg_res { - let msg = msg.unwrap(); - type Message = axum::extract::ws::Message; - match msg { - Message::Close(_) => { - warn!("Closing connection"); - system_tx - .send(SystemMessage::RemovePeer(user_ws_id_pair.1)) - .await - .unwrap(); - return; - } - _ => { - warn!("Received message: {:?}", msg); - socket.send(msg).await.unwrap(); - } - }; + ws_msg_res = socket.next() => match ws_msg_res { + Some(Ok(msg)) => match msg { + axum::extract::ws::Message::Close(_) => { + info!("Closing connection"); + peer_cleanup().await; + break; + } + _ => { + info!("Received message: {:?}. Echoing", msg); + socket.send(msg).await.unwrap(); + } + }, + Some(Err(e)) => { + warn!("Error receiving message: {:?}", e); + peer_cleanup().await; + break; + } + None => { + warn!("Connection closed - WS stream ended"); + peer_cleanup().await; + break; } }, - service_msg_res = rx.recv() => { + service_msg_res = service_to_client_rx.recv() => { if let Some(msg) = service_msg_res { warn!("Received message from the main service: {:?}", msg); + todo!("Handle service message"); + } else { + error!("WS Service task has exited or closed the mpsc channel"); + break; } } } } } +/// Starts and maintains the WebSocket Service. +/// +/// This service is responsible for managing the WebSocket connections, and sending messages to the clients. +/// * When a new WebSocket connection is established, the service is asked to register it. +/// * When a connection is closed, the service is asked to remove it. +/// +/// The service is also responsible for sending messages to the clients. +/// * When a report request completes, the service sends the report to the client. +/// * When the number of remaining Reddit API requests changes, the service sends the new number to all clients. pub async fn start_service( mut system_rx: Receiver, cancellation_token: CancellationToken, @@ -120,12 +161,17 @@ pub async fn start_service( match msg { SystemMessage::AddPeer(ws_user_id) => { peers.add_peer(ws_user_id); + info!("Peers number: {}", peers.len()); } SystemMessage::RemovePeer(connection_id) => { peers.remove_peer(connection_id); + info!("Peers number: {}", peers.len()); } _ => {} } + } else { + error!("The main HTTP process has exited or closed the mpsc channel"); + break; } } _ = cancellation_token.cancelled() => { diff --git a/backend/src/websocket/peers.rs b/backend/src/websocket/peers.rs index 4d8c4b6..72e49f1 100644 --- a/backend/src/websocket/peers.rs +++ b/backend/src/websocket/peers.rs @@ -2,6 +2,13 @@ use crate::websocket::{ConnectionId, ServiceToClientMessage, WsUserId}; use std::collections::HashMap; use tokio::sync::mpsc::Sender; +/// A map of all connected peers. +/// +/// Key: A tuple of the user's Google ID and the WebSocket connection ID. +/// Value: The sender half of a channel that is used to send messages to that particular connection. +/// +/// Thanks to the key being a pair, the server can send messages to a specific user, even if they have +/// multiple connections. We can identify a user and all their connections. #[derive(Debug)] pub struct PeersMap { peers: HashMap>, @@ -30,6 +37,10 @@ impl PeersMap { ); } } + + pub fn len(&self) -> usize { + self.peers.len() + } } #[cfg(test)] From 0edcfac7f841586ce2d824675b2293d343352761 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Micha=C5=82=20Mi=C5=82ek?= Date: Sun, 27 Oct 2024 15:29:29 +0100 Subject: [PATCH 7/8] 141: Harden by using return instead of break --- backend/src/main.rs | 7 ++---- backend/src/websocket/mod.rs | 43 +++++++++++++++++++--------------- backend/src/websocket/peers.rs | 3 +++ 3 files changed, 29 insertions(+), 24 deletions(-) diff --git a/backend/src/main.rs b/backend/src/main.rs index 62ff3a9..a8dd495 100644 --- a/backend/src/main.rs +++ b/backend/src/main.rs @@ -3,8 +3,7 @@ use crate::reddit_fetcher::fetcher::RMoodsFetcher; use crate::startup::{shutdown_signal, verify_environment}; use crate::websocket::SystemMessage; use api::auth; -use axum::handler::HandlerWithoutStateExt; -use axum::{Router, ServiceExt}; +use axum::Router; use http::header::{ACCEPT, AUTHORIZATION, CONTENT_TYPE}; use log::{error, info, warn}; use reqwest::Client; @@ -50,9 +49,7 @@ async fn run() -> anyhow::Result<()> { info!("Starting the WebSocket service"); let cancellation_token = tokio_util::sync::CancellationToken::new(); - let (system_tx, mut system_rx) = tokio::sync::mpsc::channel::(100); - let port = std::env::var("WEBSOCKET_PORT").expect("WEBSOCKET_PORT is set"); - let port = port.parse::().expect("WEBSOCKET_PORT is a valid u16"); + let (system_tx, system_rx) = tokio::sync::mpsc::channel::(100); tokio::spawn(websocket::start_service( system_rx, cancellation_token.clone(), diff --git a/backend/src/websocket/mod.rs b/backend/src/websocket/mod.rs index 97eaed6..80424a7 100644 --- a/backend/src/websocket/mod.rs +++ b/backend/src/websocket/mod.rs @@ -84,20 +84,26 @@ async fn handle_socket( tokio::sync::mpsc::channel::(100); // Register the connection. We give the Service our tx, so it can call the handler when needed. - system_tx + let res = system_tx .send(SystemMessage::AddPeer(( user_ws_id_pair.clone(), service_to_client_tx, ))) - .await - .unwrap(); + .await; + + if let Err(e) = res { + error!("Failed to register the new peer: {:?}", e); + return; + } // Cleanup function to remove the peer from the system - let peer_cleanup = || async { - system_tx + let ask_to_remove_this_peer = || async { + let res = system_tx .send(SystemMessage::RemovePeer(user_ws_id_pair.1)) - .await - .unwrap(); + .await; + if let Err(e) = res { + error!("Failed to remove the peer: {:?}", e); + } }; loop { @@ -106,23 +112,23 @@ async fn handle_socket( Some(Ok(msg)) => match msg { axum::extract::ws::Message::Close(_) => { info!("Closing connection"); - peer_cleanup().await; - break; + ask_to_remove_this_peer().await; + return; } _ => { info!("Received message: {:?}. Echoing", msg); - socket.send(msg).await.unwrap(); + let _ = socket.send(msg).await; } }, Some(Err(e)) => { warn!("Error receiving message: {:?}", e); - peer_cleanup().await; - break; + ask_to_remove_this_peer().await; + return; } None => { warn!("Connection closed - WS stream ended"); - peer_cleanup().await; - break; + ask_to_remove_this_peer().await; + return; } }, service_msg_res = service_to_client_rx.recv() => { @@ -131,7 +137,7 @@ async fn handle_socket( todo!("Handle service message"); } else { error!("WS Service task has exited or closed the mpsc channel"); - break; + return; } } } @@ -150,7 +156,7 @@ async fn handle_socket( pub async fn start_service( mut system_rx: Receiver, cancellation_token: CancellationToken, -) -> anyhow::Result<()> { +) { let mut peers = PeersMap::new(); loop { @@ -171,14 +177,13 @@ pub async fn start_service( } } else { error!("The main HTTP process has exited or closed the mpsc channel"); - break; + return; } } _ = cancellation_token.cancelled() => { info!("WebSocket service shut down"); - break; + return; } } } - Ok(()) } diff --git a/backend/src/websocket/peers.rs b/backend/src/websocket/peers.rs index 72e49f1..485f1dd 100644 --- a/backend/src/websocket/peers.rs +++ b/backend/src/websocket/peers.rs @@ -25,6 +25,9 @@ impl PeersMap { self.peers.insert(user_id, sender); } + /// Remove a particular WS connection from the map. + /// This does not remove all connections for a user, only the one with the given connection ID. + /// Called when a connection is closed. pub fn remove_peer(&mut self, connection_id: ConnectionId) { let len_before = self.peers.len(); self.peers From 1acab320c0cc4ff5ef8cdbdeb9cf0ffdb3ff2806 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Micha=C5=82=20Mi=C5=82ek?= Date: Sun, 27 Oct 2024 18:04:10 +0100 Subject: [PATCH 8/8] 141: Remove WEBSOCKET_PORT env variable --- backend/.env.example | 3 +-- backend/src/main.rs | 1 + backend/src/startup.rs | 1 - 3 files changed, 2 insertions(+), 3 deletions(-) diff --git a/backend/.env.example b/backend/.env.example index db3f476..f71d290 100644 --- a/backend/.env.example +++ b/backend/.env.example @@ -3,5 +3,4 @@ CLIENT_SECRET= DATABASE_URL= JWT_SECRET= GOOGLE_CLIENT_ID= -GOOGLE_CLIENT_SECRET= -WEBSOCKET_PORT= \ No newline at end of file +GOOGLE_CLIENT_SECRET= \ No newline at end of file diff --git a/backend/src/main.rs b/backend/src/main.rs index a8dd495..e187edd 100644 --- a/backend/src/main.rs +++ b/backend/src/main.rs @@ -24,6 +24,7 @@ mod startup; mod websocket; /// State to be shared between all routes. +/// /// Contains common resources that shouldn't be created over and over again. #[derive(Clone)] pub struct AppState { diff --git a/backend/src/startup.rs b/backend/src/startup.rs index b500171..32081f8 100644 --- a/backend/src/startup.rs +++ b/backend/src/startup.rs @@ -12,7 +12,6 @@ pub fn verify_environment() -> bool { "JWT_SECRET", "GOOGLE_CLIENT_ID", "GOOGLE_CLIENT_SECRET", - "WEBSOCKET_PORT", ]; let defined: Vec = std::env::vars().map(|(k, _)| k).collect();