diff --git a/src/api/core/accounts.rs b/src/api/core/accounts.rs index f9822629b2..a4439d8254 100644 --- a/src/api/core/accounts.rs +++ b/src/api/core/accounts.rs @@ -1,7 +1,10 @@ +use std::sync::Arc; + use crate::db::DbPool; use chrono::{SecondsFormat, Utc}; use rocket::serde::json::Json; use serde_json::Value; +use tokio::sync::RwLock; use crate::{ api::{ @@ -1282,9 +1285,9 @@ async fn get_auth_requests(headers: Headers, mut conn: DbConn) -> JsonResult { }))) } -pub async fn purge_auth_requests(pool: DbPool) { +pub async fn purge_auth_requests(pool: Arc>) { debug!("Purging auth requests"); - if let Ok(mut conn) = pool.get().await { + if let Ok(mut conn) = pool.read().await.get().await { AuthRequest::purge_expired_auth_requests(&mut conn).await; } else { error!("Failed to get DB connection while purging trashed ciphers") diff --git a/src/api/core/ciphers.rs b/src/api/core/ciphers.rs index da7189428f..1d68433fbc 100644 --- a/src/api/core/ciphers.rs +++ b/src/api/core/ciphers.rs @@ -1,4 +1,5 @@ use std::collections::{HashMap, HashSet}; +use std::sync::Arc; use chrono::{NaiveDateTime, Utc}; use num_traits::ToPrimitive; @@ -9,6 +10,7 @@ use rocket::{ Route, }; use serde_json::Value; +use tokio::sync::RwLock; use crate::util::NumberOrString; use crate::{ @@ -88,9 +90,9 @@ pub fn routes() -> Vec { ] } -pub async fn purge_trashed_ciphers(pool: DbPool) { +pub async fn purge_trashed_ciphers(pool: Arc>) { debug!("Purging trashed ciphers"); - if let Ok(mut conn) = pool.get().await { + if let Ok(mut conn) = pool.read().await.get().await { Cipher::purge_trash(&mut conn).await; } else { error!("Failed to get DB connection while purging trashed ciphers") diff --git a/src/api/core/emergency_access.rs b/src/api/core/emergency_access.rs index 1c29b7748a..dca04c83bd 100644 --- a/src/api/core/emergency_access.rs +++ b/src/api/core/emergency_access.rs @@ -1,6 +1,9 @@ +use std::sync::Arc; + use chrono::{TimeDelta, Utc}; use rocket::{serde::json::Json, Route}; use serde_json::Value; +use tokio::sync::RwLock; use crate::{ api::{ @@ -729,13 +732,13 @@ fn check_emergency_access_enabled() -> EmptyResult { Ok(()) } -pub async fn emergency_request_timeout_job(pool: DbPool) { +pub async fn emergency_request_timeout_job(pool: Arc>) { debug!("Start emergency_request_timeout_job"); if !CONFIG.emergency_access_allowed() { return; } - if let Ok(mut conn) = pool.get().await { + if let Ok(mut conn) = pool.read().await.get().await { let emergency_access_list = EmergencyAccess::find_all_recoveries_initiated(&mut conn).await; if emergency_access_list.is_empty() { @@ -784,13 +787,13 @@ pub async fn emergency_request_timeout_job(pool: DbPool) { } } -pub async fn emergency_notification_reminder_job(pool: DbPool) { +pub async fn emergency_notification_reminder_job(pool: Arc>) { debug!("Start emergency_notification_reminder_job"); if !CONFIG.emergency_access_allowed() { return; } - if let Ok(mut conn) = pool.get().await { + if let Ok(mut conn) = pool.read().await.get().await { let emergency_access_list = EmergencyAccess::find_all_recoveries_initiated(&mut conn).await; if emergency_access_list.is_empty() { diff --git a/src/api/core/events.rs b/src/api/core/events.rs index 484094f52e..303deb5277 100644 --- a/src/api/core/events.rs +++ b/src/api/core/events.rs @@ -1,8 +1,9 @@ -use std::net::IpAddr; +use std::{net::IpAddr, sync::Arc}; use chrono::NaiveDateTime; use rocket::{form::FromForm, serde::json::Json, Route}; use serde_json::Value; +use tokio::sync::RwLock; use crate::{ api::{EmptyResult, JsonResult}, @@ -320,14 +321,14 @@ async fn _log_event( event.save(conn).await.unwrap_or(()); } -pub async fn event_cleanup_job(pool: DbPool) { +pub async fn event_cleanup_job(pool: Arc>) { debug!("Start events cleanup job"); if CONFIG.events_days_retain().is_none() { debug!("events_days_retain is not configured, abort"); return; } - if let Ok(mut conn) = pool.get().await { + if let Ok(mut conn) = pool.read().await.get().await { Event::clean_events(&mut conn).await.ok(); } else { error!("Failed to get DB connection while trying to cleanup the events table") diff --git a/src/api/core/sends.rs b/src/api/core/sends.rs index a7e5bcf04b..77d4a9deb4 100644 --- a/src/api/core/sends.rs +++ b/src/api/core/sends.rs @@ -1,4 +1,5 @@ use std::path::Path; +use std::sync::Arc; use chrono::{DateTime, TimeDelta, Utc}; use num_traits::ToPrimitive; @@ -7,6 +8,7 @@ use rocket::fs::NamedFile; use rocket::fs::TempFile; use rocket::serde::json::Json; use serde_json::Value; +use tokio::sync::RwLock; use crate::{ api::{ApiResult, EmptyResult, JsonResult, Notify, UpdateType}, @@ -38,9 +40,9 @@ pub fn routes() -> Vec { ] } -pub async fn purge_sends(pool: DbPool) { +pub async fn purge_sends(pool: Arc>) { debug!("Purging sends"); - if let Ok(mut conn) = pool.get().await { + if let Ok(mut conn) = pool.read().await.get().await { Send::purge(&mut conn).await; } else { error!("Failed to get DB connection while purging sends") diff --git a/src/api/core/two_factor/duo_oidc.rs b/src/api/core/two_factor/duo_oidc.rs index d252df9196..8c317c9e7a 100644 --- a/src/api/core/two_factor/duo_oidc.rs +++ b/src/api/core/two_factor/duo_oidc.rs @@ -4,7 +4,8 @@ use jsonwebtoken::{Algorithm, DecodingKey, EncodingKey, Header, Validation}; use reqwest::{header, StatusCode}; use ring::digest::{digest, Digest, SHA512_256}; use serde::Serialize; -use std::collections::HashMap; +use std::{collections::HashMap, sync::Arc}; +use tokio::sync::RwLock; use crate::{ api::{core::two_factor::duo::get_duo_keys_email, EmptyResult}, @@ -345,9 +346,9 @@ async fn extract_context(state: &str, conn: &mut DbConn) -> Option>) { debug!("Purging Duo authentication contexts"); - if let Ok(mut conn) = pool.get().await { + if let Ok(mut conn) = pool.read().await.get().await { TwoFactorDuoContext::purge_expired_duo_contexts(&mut conn).await; } else { error!("Failed to get DB connection while purging expired Duo authentications") diff --git a/src/api/core/two_factor/mod.rs b/src/api/core/two_factor/mod.rs index e3795eb8d3..77e845c6c2 100644 --- a/src/api/core/two_factor/mod.rs +++ b/src/api/core/two_factor/mod.rs @@ -1,8 +1,11 @@ +use std::sync::Arc; + use chrono::{TimeDelta, Utc}; use data_encoding::BASE32; use rocket::serde::json::Json; use rocket::Route; use serde_json::Value; +use tokio::sync::RwLock; use crate::{ api::{ @@ -244,14 +247,14 @@ pub async fn enforce_2fa_policy_for_org( Ok(()) } -pub async fn send_incomplete_2fa_notifications(pool: DbPool) { +pub async fn send_incomplete_2fa_notifications(pool: Arc>) { debug!("Sending notifications for incomplete 2FA logins"); if CONFIG.incomplete_2fa_time_limit() <= 0 || !CONFIG.mail_enabled() { return; } - let mut conn = match pool.get().await { + let mut conn = match pool.read().await.get().await { Ok(conn) => conn, _ => { error!("Failed to get DB connection in send_incomplete_2fa_notifications()"); diff --git a/src/db/mod.rs b/src/db/mod.rs index fe1ab79bae..09b58e9758 100644 --- a/src/db/mod.rs +++ b/src/db/mod.rs @@ -12,7 +12,7 @@ use rocket::{ }; use tokio::{ - sync::{Mutex, OwnedSemaphorePermit, Semaphore}, + sync::{Mutex, OwnedSemaphorePermit, RwLock, Semaphore}, time::timeout, }; @@ -417,8 +417,8 @@ impl<'r> FromRequest<'r> for DbConn { type Error = (); async fn from_request(request: &'r Request<'_>) -> Outcome { - match request.rocket().state::() { - Some(p) => match p.get().await { + match request.rocket().state::>>() { + Some(p) => match p.read().await.get().await { Ok(dbconn) => Outcome::Success(dbconn), _ => Outcome::Error((Status::ServiceUnavailable, ())), }, diff --git a/src/main.rs b/src/main.rs index 33c38027be..a46b75493f 100644 --- a/src/main.rs +++ b/src/main.rs @@ -38,6 +38,7 @@ use std::{ use tokio::{ fs::File, io::{AsyncBufReadExt, BufReader}, + sync::RwLock, }; #[cfg(unix)] @@ -84,9 +85,9 @@ async fn main() -> Result<(), Error> { create_dir(&CONFIG.sends_folder(), "sends folder"); create_dir(&CONFIG.attachments_folder(), "attachments folder"); - let pool = create_db_pool().await; - schedule_jobs(pool.clone()); - db::models::TwoFactor::migrate_u2f_to_webauthn(&mut pool.get().await.unwrap()).await.unwrap(); + let pool = Arc::new(RwLock::new(create_db_pool().await)); + schedule_jobs(Arc::clone(&pool)); + db::models::TwoFactor::migrate_u2f_to_webauthn(&mut pool.read().await.get().await.unwrap()).await.unwrap(); let extra_debug = matches!(level, log::LevelFilter::Trace | log::LevelFilter::Debug); launch_rocket(pool, extra_debug).await // Blocks until program termination. @@ -560,7 +561,7 @@ async fn create_db_pool() -> db::DbPool { } } -async fn launch_rocket(pool: db::DbPool, extra_debug: bool) -> Result<(), Error> { +async fn launch_rocket(pool: Arc>, extra_debug: bool) -> Result<(), Error> { let basepath = &CONFIG.domain_path(); let mut config = rocket::Config::from(rocket::Config::figment()); @@ -584,7 +585,7 @@ async fn launch_rocket(pool: db::DbPool, extra_debug: bool) -> Result<(), Error> .register([basepath, "/"].concat(), api::web_catchers()) .register([basepath, "/api"].concat(), api::core_catchers()) .register([basepath, "/admin"].concat(), api::admin_catchers()) - .manage(pool) + .manage(Arc::clone(&pool)) .manage(Arc::clone(&WS_USERS)) .manage(Arc::clone(&WS_ANONYMOUS_SUBSCRIPTIONS)) .attach(util::AppHeaders()) @@ -623,7 +624,7 @@ async fn launch_rocket(pool: db::DbPool, extra_debug: bool) -> Result<(), Error> Ok(()) } -fn schedule_jobs(pool: db::DbPool) { +fn schedule_jobs(pool: Arc>) { if CONFIG.job_poll_interval_ms() == 0 { info!("Job scheduler disabled."); return; @@ -642,14 +643,14 @@ fn schedule_jobs(pool: db::DbPool) { // Purge sends that are past their deletion date. if !CONFIG.send_purge_schedule().is_empty() { sched.add(Job::new(CONFIG.send_purge_schedule().parse().unwrap(), || { - runtime.spawn(api::purge_sends(pool.clone())); + runtime.spawn(api::purge_sends(Arc::clone(&pool))); })); } // Purge trashed items that are old enough to be auto-deleted. if !CONFIG.trash_purge_schedule().is_empty() { sched.add(Job::new(CONFIG.trash_purge_schedule().parse().unwrap(), || { - runtime.spawn(api::purge_trashed_ciphers(pool.clone())); + runtime.spawn(api::purge_trashed_ciphers(Arc::clone(&pool))); })); } @@ -657,7 +658,7 @@ fn schedule_jobs(pool: db::DbPool) { // indicates that a user's master password has been compromised. if !CONFIG.incomplete_2fa_schedule().is_empty() { sched.add(Job::new(CONFIG.incomplete_2fa_schedule().parse().unwrap(), || { - runtime.spawn(api::send_incomplete_2fa_notifications(pool.clone())); + runtime.spawn(api::send_incomplete_2fa_notifications(Arc::clone(&pool))); })); } @@ -666,7 +667,7 @@ fn schedule_jobs(pool: db::DbPool) { // sending reminders for requests that are about to be granted anyway. if !CONFIG.emergency_request_timeout_schedule().is_empty() { sched.add(Job::new(CONFIG.emergency_request_timeout_schedule().parse().unwrap(), || { - runtime.spawn(api::emergency_request_timeout_job(pool.clone())); + runtime.spawn(api::emergency_request_timeout_job(Arc::clone(&pool))); })); } @@ -674,20 +675,20 @@ fn schedule_jobs(pool: db::DbPool) { // emergency access requests. if !CONFIG.emergency_notification_reminder_schedule().is_empty() { sched.add(Job::new(CONFIG.emergency_notification_reminder_schedule().parse().unwrap(), || { - runtime.spawn(api::emergency_notification_reminder_job(pool.clone())); + runtime.spawn(api::emergency_notification_reminder_job(Arc::clone(&pool))); })); } if !CONFIG.auth_request_purge_schedule().is_empty() { sched.add(Job::new(CONFIG.auth_request_purge_schedule().parse().unwrap(), || { - runtime.spawn(purge_auth_requests(pool.clone())); + runtime.spawn(purge_auth_requests(Arc::clone(&pool))); })); } // Clean unused, expired Duo authentication contexts. if !CONFIG.duo_context_purge_schedule().is_empty() && CONFIG._enable_duo() && !CONFIG.duo_use_iframe() { sched.add(Job::new(CONFIG.duo_context_purge_schedule().parse().unwrap(), || { - runtime.spawn(purge_duo_contexts(pool.clone())); + runtime.spawn(purge_duo_contexts(Arc::clone(&pool))); })); } @@ -697,7 +698,7 @@ fn schedule_jobs(pool: db::DbPool) { && CONFIG.events_days_retain().is_some() { sched.add(Job::new(CONFIG.event_cleanup_schedule().parse().unwrap(), || { - runtime.spawn(api::event_cleanup_job(pool.clone())); + runtime.spawn(api::event_cleanup_job(Arc::clone(&pool))); })); }