Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

enhancement: use Arc + Mutex for dbpool instance instead of clone() #5037

Open
wants to merge 7 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 5 additions & 2 deletions src/api/core/accounts.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
use std::sync::Arc;

use crate::db::DbPool;
use chrono::{SecondsFormat, Utc};
use rocket::serde::json::Json;
use serde_json::Value;
use tokio::sync::RwLock;

use crate::{
api::{
Expand Down Expand Up @@ -1282,9 +1285,9 @@ async fn get_auth_requests(headers: Headers, mut conn: DbConn) -> JsonResult {
})))
}

pub async fn purge_auth_requests(pool: DbPool) {
pub async fn purge_auth_requests(pool: Arc<RwLock<DbPool>>) {
debug!("Purging auth requests");
if let Ok(mut conn) = pool.get().await {
if let Ok(mut conn) = pool.read().await.get().await {
AuthRequest::purge_expired_auth_requests(&mut conn).await;
} else {
error!("Failed to get DB connection while purging trashed ciphers")
Expand Down
6 changes: 4 additions & 2 deletions src/api/core/ciphers.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use std::collections::{HashMap, HashSet};
use std::sync::Arc;

use chrono::{NaiveDateTime, Utc};
use num_traits::ToPrimitive;
Expand All @@ -9,6 +10,7 @@ use rocket::{
Route,
};
use serde_json::Value;
use tokio::sync::RwLock;

use crate::util::NumberOrString;
use crate::{
Expand Down Expand Up @@ -88,9 +90,9 @@ pub fn routes() -> Vec<Route> {
]
}

pub async fn purge_trashed_ciphers(pool: DbPool) {
pub async fn purge_trashed_ciphers(pool: Arc<RwLock<DbPool>>) {
debug!("Purging trashed ciphers");
if let Ok(mut conn) = pool.get().await {
if let Ok(mut conn) = pool.read().await.get().await {
Cipher::purge_trash(&mut conn).await;
} else {
error!("Failed to get DB connection while purging trashed ciphers")
Expand Down
11 changes: 7 additions & 4 deletions src/api/core/emergency_access.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
use std::sync::Arc;

use chrono::{TimeDelta, Utc};
use rocket::{serde::json::Json, Route};
use serde_json::Value;
use tokio::sync::RwLock;

use crate::{
api::{
Expand Down Expand Up @@ -729,13 +732,13 @@ fn check_emergency_access_enabled() -> EmptyResult {
Ok(())
}

pub async fn emergency_request_timeout_job(pool: DbPool) {
pub async fn emergency_request_timeout_job(pool: Arc<RwLock<DbPool>>) {
debug!("Start emergency_request_timeout_job");
if !CONFIG.emergency_access_allowed() {
return;
}

if let Ok(mut conn) = pool.get().await {
if let Ok(mut conn) = pool.read().await.get().await {
let emergency_access_list = EmergencyAccess::find_all_recoveries_initiated(&mut conn).await;

if emergency_access_list.is_empty() {
Expand Down Expand Up @@ -784,13 +787,13 @@ pub async fn emergency_request_timeout_job(pool: DbPool) {
}
}

pub async fn emergency_notification_reminder_job(pool: DbPool) {
pub async fn emergency_notification_reminder_job(pool: Arc<RwLock<DbPool>>) {
debug!("Start emergency_notification_reminder_job");
if !CONFIG.emergency_access_allowed() {
return;
}

if let Ok(mut conn) = pool.get().await {
if let Ok(mut conn) = pool.read().await.get().await {
let emergency_access_list = EmergencyAccess::find_all_recoveries_initiated(&mut conn).await;

if emergency_access_list.is_empty() {
Expand Down
7 changes: 4 additions & 3 deletions src/api/core/events.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
use std::net::IpAddr;
use std::{net::IpAddr, sync::Arc};

use chrono::NaiveDateTime;
use rocket::{form::FromForm, serde::json::Json, Route};
use serde_json::Value;
use tokio::sync::RwLock;

use crate::{
api::{EmptyResult, JsonResult},
Expand Down Expand Up @@ -320,14 +321,14 @@ async fn _log_event(
event.save(conn).await.unwrap_or(());
}

pub async fn event_cleanup_job(pool: DbPool) {
pub async fn event_cleanup_job(pool: Arc<RwLock<DbPool>>) {
debug!("Start events cleanup job");
if CONFIG.events_days_retain().is_none() {
debug!("events_days_retain is not configured, abort");
return;
}

if let Ok(mut conn) = pool.get().await {
if let Ok(mut conn) = pool.read().await.get().await {
Event::clean_events(&mut conn).await.ok();
} else {
error!("Failed to get DB connection while trying to cleanup the events table")
Expand Down
6 changes: 4 additions & 2 deletions src/api/core/sends.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use std::path::Path;
use std::sync::Arc;

use chrono::{DateTime, TimeDelta, Utc};
use num_traits::ToPrimitive;
Expand All @@ -7,6 +8,7 @@ use rocket::fs::NamedFile;
use rocket::fs::TempFile;
use rocket::serde::json::Json;
use serde_json::Value;
use tokio::sync::RwLock;

use crate::{
api::{ApiResult, EmptyResult, JsonResult, Notify, UpdateType},
Expand Down Expand Up @@ -38,9 +40,9 @@ pub fn routes() -> Vec<rocket::Route> {
]
}

pub async fn purge_sends(pool: DbPool) {
pub async fn purge_sends(pool: Arc<RwLock<DbPool>>) {
debug!("Purging sends");
if let Ok(mut conn) = pool.get().await {
if let Ok(mut conn) = pool.read().await.get().await {
Send::purge(&mut conn).await;
} else {
error!("Failed to get DB connection while purging sends")
Expand Down
7 changes: 4 additions & 3 deletions src/api/core/two_factor/duo_oidc.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@ use jsonwebtoken::{Algorithm, DecodingKey, EncodingKey, Header, Validation};
use reqwest::{header, StatusCode};
use ring::digest::{digest, Digest, SHA512_256};
use serde::Serialize;
use std::collections::HashMap;
use std::{collections::HashMap, sync::Arc};
use tokio::sync::RwLock;

use crate::{
api::{core::two_factor::duo::get_duo_keys_email, EmptyResult},
Expand Down Expand Up @@ -345,9 +346,9 @@ async fn extract_context(state: &str, conn: &mut DbConn) -> Option<DuoAuthContex
}

// Task to clean up expired Duo authentication contexts that may have accumulated in the database.
pub async fn purge_duo_contexts(pool: DbPool) {
pub async fn purge_duo_contexts(pool: Arc<RwLock<DbPool>>) {
debug!("Purging Duo authentication contexts");
if let Ok(mut conn) = pool.get().await {
if let Ok(mut conn) = pool.read().await.get().await {
TwoFactorDuoContext::purge_expired_duo_contexts(&mut conn).await;
} else {
error!("Failed to get DB connection while purging expired Duo authentications")
Expand Down
7 changes: 5 additions & 2 deletions src/api/core/two_factor/mod.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
use std::sync::Arc;

use chrono::{TimeDelta, Utc};
use data_encoding::BASE32;
use rocket::serde::json::Json;
use rocket::Route;
use serde_json::Value;
use tokio::sync::RwLock;

use crate::{
api::{
Expand Down Expand Up @@ -244,14 +247,14 @@ pub async fn enforce_2fa_policy_for_org(
Ok(())
}

pub async fn send_incomplete_2fa_notifications(pool: DbPool) {
pub async fn send_incomplete_2fa_notifications(pool: Arc<RwLock<DbPool>>) {
debug!("Sending notifications for incomplete 2FA logins");

if CONFIG.incomplete_2fa_time_limit() <= 0 || !CONFIG.mail_enabled() {
return;
}

let mut conn = match pool.get().await {
let mut conn = match pool.read().await.get().await {
Ok(conn) => conn,
_ => {
error!("Failed to get DB connection in send_incomplete_2fa_notifications()");
Expand Down
6 changes: 3 additions & 3 deletions src/db/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ use rocket::{
};

use tokio::{
sync::{Mutex, OwnedSemaphorePermit, Semaphore},
sync::{Mutex, OwnedSemaphorePermit, RwLock, Semaphore},
time::timeout,
};

Expand Down Expand Up @@ -417,8 +417,8 @@ impl<'r> FromRequest<'r> for DbConn {
type Error = ();

async fn from_request(request: &'r Request<'_>) -> Outcome<Self, Self::Error> {
match request.rocket().state::<DbPool>() {
Some(p) => match p.get().await {
match request.rocket().state::<Arc<RwLock<DbPool>>>() {
Some(p) => match p.read().await.get().await {
Ok(dbconn) => Outcome::Success(dbconn),
_ => Outcome::Error((Status::ServiceUnavailable, ())),
},
Expand Down
29 changes: 15 additions & 14 deletions src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ use std::{
use tokio::{
fs::File,
io::{AsyncBufReadExt, BufReader},
sync::RwLock,
};

#[cfg(unix)]
Expand Down Expand Up @@ -84,9 +85,9 @@ async fn main() -> Result<(), Error> {
create_dir(&CONFIG.sends_folder(), "sends folder");
create_dir(&CONFIG.attachments_folder(), "attachments folder");

let pool = create_db_pool().await;
schedule_jobs(pool.clone());
db::models::TwoFactor::migrate_u2f_to_webauthn(&mut pool.get().await.unwrap()).await.unwrap();
let pool = Arc::new(RwLock::new(create_db_pool().await));
schedule_jobs(Arc::clone(&pool));
db::models::TwoFactor::migrate_u2f_to_webauthn(&mut pool.read().await.get().await.unwrap()).await.unwrap();

let extra_debug = matches!(level, log::LevelFilter::Trace | log::LevelFilter::Debug);
launch_rocket(pool, extra_debug).await // Blocks until program termination.
Expand Down Expand Up @@ -560,7 +561,7 @@ async fn create_db_pool() -> db::DbPool {
}
}

async fn launch_rocket(pool: db::DbPool, extra_debug: bool) -> Result<(), Error> {
async fn launch_rocket(pool: Arc<RwLock<db::DbPool>>, extra_debug: bool) -> Result<(), Error> {
let basepath = &CONFIG.domain_path();

let mut config = rocket::Config::from(rocket::Config::figment());
Expand All @@ -584,7 +585,7 @@ async fn launch_rocket(pool: db::DbPool, extra_debug: bool) -> Result<(), Error>
.register([basepath, "/"].concat(), api::web_catchers())
.register([basepath, "/api"].concat(), api::core_catchers())
.register([basepath, "/admin"].concat(), api::admin_catchers())
.manage(pool)
.manage(Arc::clone(&pool))
.manage(Arc::clone(&WS_USERS))
.manage(Arc::clone(&WS_ANONYMOUS_SUBSCRIPTIONS))
.attach(util::AppHeaders())
Expand Down Expand Up @@ -623,7 +624,7 @@ async fn launch_rocket(pool: db::DbPool, extra_debug: bool) -> Result<(), Error>
Ok(())
}

fn schedule_jobs(pool: db::DbPool) {
fn schedule_jobs(pool: Arc<RwLock<db::DbPool>>) {
if CONFIG.job_poll_interval_ms() == 0 {
info!("Job scheduler disabled.");
return;
Expand All @@ -642,22 +643,22 @@ fn schedule_jobs(pool: db::DbPool) {
// Purge sends that are past their deletion date.
if !CONFIG.send_purge_schedule().is_empty() {
sched.add(Job::new(CONFIG.send_purge_schedule().parse().unwrap(), || {
runtime.spawn(api::purge_sends(pool.clone()));
runtime.spawn(api::purge_sends(Arc::clone(&pool)));
}));
}

// Purge trashed items that are old enough to be auto-deleted.
if !CONFIG.trash_purge_schedule().is_empty() {
sched.add(Job::new(CONFIG.trash_purge_schedule().parse().unwrap(), || {
runtime.spawn(api::purge_trashed_ciphers(pool.clone()));
runtime.spawn(api::purge_trashed_ciphers(Arc::clone(&pool)));
}));
}

// Send email notifications about incomplete 2FA logins, which potentially
// indicates that a user's master password has been compromised.
if !CONFIG.incomplete_2fa_schedule().is_empty() {
sched.add(Job::new(CONFIG.incomplete_2fa_schedule().parse().unwrap(), || {
runtime.spawn(api::send_incomplete_2fa_notifications(pool.clone()));
runtime.spawn(api::send_incomplete_2fa_notifications(Arc::clone(&pool)));
}));
}

Expand All @@ -666,28 +667,28 @@ fn schedule_jobs(pool: db::DbPool) {
// sending reminders for requests that are about to be granted anyway.
if !CONFIG.emergency_request_timeout_schedule().is_empty() {
sched.add(Job::new(CONFIG.emergency_request_timeout_schedule().parse().unwrap(), || {
runtime.spawn(api::emergency_request_timeout_job(pool.clone()));
runtime.spawn(api::emergency_request_timeout_job(Arc::clone(&pool)));
}));
}

// Send reminders to emergency access grantors that there are pending
// emergency access requests.
if !CONFIG.emergency_notification_reminder_schedule().is_empty() {
sched.add(Job::new(CONFIG.emergency_notification_reminder_schedule().parse().unwrap(), || {
runtime.spawn(api::emergency_notification_reminder_job(pool.clone()));
runtime.spawn(api::emergency_notification_reminder_job(Arc::clone(&pool)));
}));
}

if !CONFIG.auth_request_purge_schedule().is_empty() {
sched.add(Job::new(CONFIG.auth_request_purge_schedule().parse().unwrap(), || {
runtime.spawn(purge_auth_requests(pool.clone()));
runtime.spawn(purge_auth_requests(Arc::clone(&pool)));
}));
}

// Clean unused, expired Duo authentication contexts.
if !CONFIG.duo_context_purge_schedule().is_empty() && CONFIG._enable_duo() && !CONFIG.duo_use_iframe() {
sched.add(Job::new(CONFIG.duo_context_purge_schedule().parse().unwrap(), || {
runtime.spawn(purge_duo_contexts(pool.clone()));
runtime.spawn(purge_duo_contexts(Arc::clone(&pool)));
}));
}

Expand All @@ -697,7 +698,7 @@ fn schedule_jobs(pool: db::DbPool) {
&& CONFIG.events_days_retain().is_some()
{
sched.add(Job::new(CONFIG.event_cleanup_schedule().parse().unwrap(), || {
runtime.spawn(api::event_cleanup_job(pool.clone()));
runtime.spawn(api::event_cleanup_job(Arc::clone(&pool)));
}));
}

Expand Down
Loading