Commit

Add URL cache

Signed-off-by: Elizabeth Myers <elizabeth.jennifer.myers@gmail.com>
Elizafox committed Mar 18, 2024
1 parent cded7db commit da08457
Showing 12 changed files with 288 additions and 84 deletions.
2 changes: 1 addition & 1 deletion Cargo.toml
@@ -83,7 +83,7 @@ serde = { version = "1.0.197", features = ["derive"] }
subtle = { version = "2.5.0", features = ["core_hint_black_box", "const-generics"] }
thiserror = "1.0.58"
time = { version = "0.3.34", features = ["local-offset"] }
tokio = { version = "1.36.0", features = ["macros", "parking_lot", "rt-multi-thread", "signal", "time"] }
tokio = { version = "1.36.0", features = ["macros", "parking_lot", "rt-multi-thread", "signal", "sync", "time"] }
tracing = { version = "0.1.40", features = ["async-await", "log"] }
tracing-subscriber = { version = "0.3.18", features = ["local-time", "parking_lot", "time"] }
tower = { version = "0.4.13", features = ["timeout", "tokio"] }
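Note: the added "sync" feature is what gates tokio::sync primitives such as RwLock, which the new UrlCache below uses to share its compiled filter list across clones. A minimal, self-contained sketch of that sharing pattern (the SharedFilters type here is illustrative, not part of this commit):

use std::sync::Arc;

use regex::Regex;
use tokio::sync::RwLock; // requires tokio's "sync" feature

// Clones are cheap; every clone points at the same Vec<Regex>.
#[derive(Clone)]
struct SharedFilters {
    regex: Arc<RwLock<Vec<Regex>>>,
}

impl SharedFilters {
    async fn is_banned(&self, url: &str) -> bool {
        self.regex.read().await.iter().any(|re| re.is_match(url))
    }
}

#[tokio::main]
async fn main() {
    let filters = SharedFilters {
        regex: Arc::new(RwLock::new(vec![Regex::new(r"evil\.example").unwrap()])),
    };
    assert!(filters.is_banned("https://evil.example/payload").await);
}
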
28 changes: 14 additions & 14 deletions src/csrf.rs
@@ -36,23 +36,23 @@ const MAX_DURATION: Duration = Duration::minutes(10); // FIXME - configurable?

// Actual session data, a random token and the time the session began.
#[derive(Debug, Clone, Deserialize, Serialize)]
-pub struct CsrfSessionData {
+pub struct SessionData {
pub(super) token: String,
pub(super) time: OffsetDateTime,
}

// This is the actual entry that gets put into the session.
-// We serialise CsrfSessionData, then encrypt it for storage.
+// We serialise SessionData, then encrypt it for storage.
// It's safe to store the nonce unencrypted, and in fact necessary. It is, however, important
// that the nonce *never* be reused.
#[derive(Debug, Clone, Deserialize, Serialize)]
-pub struct CsrfSessionEntry {
+pub struct SessionEntry {
encrypted: Vec<u8>,
nonce: Vec<u8>,
}

#[derive(Debug, thiserror::Error)]
-pub enum CsrfSessionError {
+pub enum SessionError {
#[error(transparent)]
Session(#[from] tower_sessions::session::Error),

@@ -72,7 +72,7 @@ pub enum CsrfSessionError {
Expired,
}

-impl CsrfSessionData {
+impl SessionData {
fn new() -> Self {
let len_distr = Lazy::new(|| Uniform::new(24usize, 48usize));
let mut rng = thread_rng();
@@ -83,10 +83,10 @@ impl CsrfSessionData {
}
}

-pub fn cmp(&self, token: &str) -> Result<(), CsrfSessionError> {
+pub fn cmp(&self, token: &str) -> Result<(), SessionError> {
if token.as_bytes().ct_ne(self.token.as_bytes()).into() {
debug!("CSRF tokens mismatched: {token} != {}", self.token);
-return Err(CsrfSessionError::Mismatch);
+return Err(SessionError::Mismatch);
}

// This isn't sensitive info, so it's okay not to compare in constant time
@@ -95,19 +95,19 @@ impl CsrfSessionData {
"CSRF token expired: {}",
OffsetDateTime::now_utc() - self.time
);
-return Err(CsrfSessionError::Expired);
+return Err(SessionError::Expired);
}

Ok(())
}
}

-impl CsrfSessionEntry {
+impl SessionEntry {
pub async fn insert_session(
engine: &CryptoEngine,
session: &Session,
-) -> Result<String, CsrfSessionError> {
-let data = CsrfSessionData::new();
+) -> Result<String, SessionError> {
+let data = SessionData::new();

// Serialise and encrypt the data
let buf = bincode::serialize(&data)?;
@@ -128,16 +128,16 @@ impl CsrfSessionEntry {
engine: &CryptoEngine,
session: &Session,
token: &str,
-) -> Result<(), CsrfSessionError> {
+) -> Result<(), SessionError> {
// Smokey the Bear sez: Only YOU can prevent forest fi... err, session reuse
let entry: Self = session
.remove(SESSION_KEY)
.await?
-.ok_or(CsrfSessionError::NoToken)?;
+.ok_or(SessionError::NoToken)?;

// Decrypt and deserialise the data
let decrypted = engine.decrypt(entry.nonce.as_slice().into(), entry.encrypted.as_ref())?;
-let data: CsrfSessionData = bincode::deserialize(&decrypted)?;
+let data: SessionData = bincode::deserialize(&decrypted)?;

// Verify the token
data.cmp(token)
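For context on what the renamed SessionData::cmp does: the token bytes are compared with subtle's constant-time operations, so a mismatch cannot be detected early through timing, and only the expiry check uses ordinary branching. A standalone sketch of that pattern, using illustrative names rather than the crate's actual items:

use subtle::ConstantTimeEq;
use time::{Duration, OffsetDateTime};

const MAX_AGE: Duration = Duration::minutes(10);

// True only if the tokens match (checked in constant time) and the session is
// still fresh; the age is not secret, so plain branching is fine there.
fn token_ok(expected: &str, presented: &str, created: OffsetDateTime) -> bool {
    let matched: bool = presented.as_bytes().ct_eq(expected.as_bytes()).into();
    matched && (OffsetDateTime::now_utc() - created) <= MAX_AGE
}

fn main() {
    let created = OffsetDateTime::now_utc();
    assert!(token_ok("s3cr3t", "s3cr3t", created));
    assert!(!token_ok("s3cr3t", "wr0ng!", created));
}
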
8 changes: 6 additions & 2 deletions src/err.rs
@@ -25,15 +25,16 @@ use tracing::{error, warn};
use crate::{
auth::{AuthError, Backend},
bancache::BanCacheError,
-csrf::CsrfSessionError,
+csrf::SessionError,
+urlcache::UrlCacheError,
util::net::{AddressError, NetworkPrefixError},
};

// Anything that can go wrong in a handler should go here.
#[derive(Debug, thiserror::Error)]
pub enum AppError {
#[error(transparent)]
-VerifyCsrf(#[from] CsrfSessionError),
+VerifyCsrf(#[from] SessionError),

#[error(transparent)]
Auth(#[from] AuthError),
@@ -50,6 +50,9 @@ pub enum AppError {
#[error(transparent)]
BanCache(#[from] BanCacheError),

+#[error(transparent)]
+UrlCache(#[from] UrlCacheError),

#[error(transparent)]
Address(#[from] AddressError),

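The transparent wrapping plus #[from] is what lets handlers bubble a UrlCacheError up with the ? operator; thiserror generates the From impl. A reduced, standalone sketch of that mechanism (simplified enums, not the project's real definitions):

use thiserror::Error;

#[derive(Debug, Error)]
enum UrlCacheError {
    #[error("Regex not found: {0}")]
    RegexNotFound(String),
}

#[derive(Debug, Error)]
enum AppError {
    #[error(transparent)]
    UrlCache(#[from] UrlCacheError),
}

// `?` converts the UrlCacheError into AppError via the generated From impl.
fn check(url: &str) -> Result<bool, AppError> {
    Ok(lookup(url)?)
}

fn lookup(url: &str) -> Result<bool, UrlCacheError> {
    Err(UrlCacheError::RegexNotFound(url.to_string()))
}

fn main() {
    assert!(matches!(check("x"), Err(AppError::UrlCache(_))));
}
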
1 change: 1 addition & 0 deletions src/main.rs
@@ -24,6 +24,7 @@ mod env;
mod err;
mod generate;
mod state;
+mod urlcache;
mod util;
mod validators;
mod web;
3 changes: 2 additions & 1 deletion src/state.rs
@@ -18,7 +18,7 @@ use std::sync::Arc;

use sea_orm::DbConn;

-use crate::{bancache::BanCache, csrf::CryptoEngine, env::Vars};
+use crate::{bancache::BanCache, csrf::CryptoEngine, env::Vars, urlcache::UrlCache};

// This is the struct that holds state for handlers
#[allow(clippy::module_name_repetitions)]
@@ -29,5 +29,6 @@ pub struct AppState {
pub(crate) db: Arc<DbConn>,
pub(crate) env: Vars,
pub(crate) bancache: BanCache,
+pub(crate) urlcache: UrlCache,
pub(crate) csrf_crypto_engine: CryptoEngine,
}
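The place where the new urlcache field is actually constructed is not among the files expanded here, so the following is only a hedged sketch of what that startup wiring might look like; the build_urlcache helper and its error handling are assumptions, while UrlCache::new comes from this commit.

use std::sync::Arc;

use sea_orm::Database;

use crate::urlcache::UrlCache;

// Hypothetical startup helper: connect to the database, then let UrlCache::new
// load and compile the stored URL filters.
async fn build_urlcache(db_url: &str) -> Result<UrlCache, Box<dyn std::error::Error>> {
    let db = Arc::new(Database::connect(db_url).await?);
    Ok(UrlCache::new(db).await?)
}

// ...and later, when assembling shared state, something along the lines of:
// AppState { db, env, bancache, urlcache, csrf_crypto_engine }
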
159 changes: 159 additions & 0 deletions src/urlcache.rs
@@ -0,0 +1,159 @@
/* SPDX-License-Identifier: CC0-1.0
*
* src/urlcache.rs
*
* This file is a component of ShadyURL by Elizabeth Myers.
*
* To the extent possible under law, the person who associated CC0 with
* ShadyURL has waived all copyright and related or neighboring rights
* to ShadyURL.
*
* You should have received a copy of the CC0 legalcode along with this
* work. If not, see <http://creativecommons.org/publicdomain/zero/1.0/>.
*/

// Compiled regex caching

use std::sync::Arc;

use moka::future::Cache;
use regex::Regex;
use sea_orm::{DbConn, DbErr};
use time::Duration;
use tokio::sync::RwLock;
use tracing::trace;

use service::Query;

// TODO: configurable
const CACHE_ENTRIES: u64 = 1000;

#[derive(Debug, thiserror::Error)]
pub enum UrlCacheError {
#[error(transparent)]
Db(#[from] DbErr),

#[error(transparent)]
Regex(#[from] regex::Error),

#[error("Regex not found: {}", .0)]
RegexNotFound(String),

#[error("Regex is duplicate: {}", .0)]
RegexDuplicated(String),
}

#[derive(Clone, Debug)]
pub struct UrlCache {
// Wrapping in Arc<RwLock<..>> keeps UrlCache cloneable while every clone still points at the same Vec
regex: Arc<RwLock<Vec<Regex>>>,
cache: Cache<String, bool>, // Caches addresses
db: Arc<DbConn>,
}

impl UrlCache {
// Create a new UrlCache instance and initialise regexes from the database.
pub(crate) async fn new(db: Arc<DbConn>) -> Result<Self, UrlCacheError> {
let mut regex = Vec::new();
for url_filter in Query::fetch_all_url_filters(&db).await? {
trace!("Adding URL filter {}", url_filter.0.filter);
let cmpreg = Regex::new(&url_filter.0.filter)?;
regex.push(cmpreg);
}

Ok(Self {
// XXX - should these cache parameters be configurable?
regex: Arc::new(RwLock::new(regex)),
cache: Cache::builder()
.max_capacity(CACHE_ENTRIES)
.time_to_live(Duration::days(7).unsigned_abs())
.time_to_idle(Duration::days(1).unsigned_abs())
.support_invalidation_closures()
.build(),
db,
})
}

// Sync the regex cache with the database, flushing all old entries
// This also flushes the URL cache.
pub(crate) async fn sync_regex_cache(&mut self) -> Result<(), UrlCacheError> {
let mut regex_vec = self.regex.write().await;
let mut new_regex = Vec::with_capacity(regex_vec.len());

for url_filter in Query::fetch_all_url_filters(&self.db).await? {
let cmpreg = Regex::new(&url_filter.0.filter)?;
new_regex.push(cmpreg);
}

regex_vec.clear();
regex_vec.extend(new_regex);
drop(regex_vec);
self.cache.invalidate_all();
Ok(())
}

// Add one regex without flushing the entire cache.
// This operation also removes any matching URLs from the cache.
// NOTE: this does not update the database
pub(crate) async fn add_regex_cache(&mut self, cmpreg: Regex) -> Result<(), UrlCacheError> {
let mut regex_vec = self.regex.write().await;
if regex_vec
.iter()
.any(|ocmpreg| ocmpreg.as_str() == cmpreg.as_str())
{
trace!("URL is duplicated: {}", cmpreg.as_str());
return Err(UrlCacheError::RegexDuplicated(cmpreg.as_str().to_string()));
}

trace!("Adding URL filter to regex cache {}", cmpreg.as_str());

regex_vec.push(cmpreg.clone());
drop(regex_vec);
self.cache
.invalidate_entries_if(move |k, _| cmpreg.is_match(k))
.expect("Could not invalidate cache");
Ok(())
}

// Remove one regex without flushing the entire cache.
// This operation also removes any matching URLs from the cache.
// NOTE: this does not update the database
pub(crate) async fn remove_regex_cache(&mut self, regstr: &str) -> Result<(), UrlCacheError> {
// Invariant: regexes are not duplicated
let mut regex_vec = self.regex.write().await;
let pos = regex_vec
.iter()
.position(|cmpreg| cmpreg.as_str() == regstr)
.ok_or_else(|| UrlCacheError::RegexNotFound(regstr.to_string()))?;
let cmpreg = regex_vec[pos].clone();
regex_vec.swap_remove(pos);
drop(regex_vec);
self.cache
.invalidate_entries_if(move |k, _| cmpreg.is_match(k))
.expect("Could not invalidate cache");

trace!("Removing URL filter from regex cache {}", regstr);
Ok(())
}

// Check a URL against the cache.
// On a cache miss, check the URL against the regexes and cache the result.
// Returns true if the URL matches a ban filter, false otherwise.
pub(crate) async fn check_url_banned(&self, url: &str) -> Result<bool, UrlCacheError> {
if let Some(is_match) = self.cache.get(url).await {
trace!("Cached URL ban result for \"{}\": {:?}", url, is_match);
return Ok(is_match);
}

let regex_vec = self.regex.read().await;
let is_match = regex_vec.iter().any(|cmpreg| cmpreg.is_match(url));
drop(regex_vec);

trace!("Uncached URL ban result for \"{}\": {:?}", url, is_match);
self.cache.insert(url.to_string(), is_match).await;
Ok(is_match)
}
}
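The call sites that consult the cache presumably live in the remaining files of this commit, which are not expanded above, so here is an illustrative sketch of a handler checking a submitted URL through AppState; the route shape, extractor, and responses are assumptions, while check_url_banned, AppState, and AppError come from this commit.

use axum::{extract::State, http::StatusCode};

use crate::{err::AppError, state::AppState};

// Illustrative handler: refuse to process URLs that match a ban filter.
async fn submit(State(state): State<AppState>, url: String) -> Result<StatusCode, AppError> {
    if state.urlcache.check_url_banned(&url).await? {
        return Ok(StatusCode::FORBIDDEN);
    }
    // ...generate and store the shady URL here...
    Ok(StatusCode::OK)
}

Admin-side filter changes would presumably update the database first and then call add_regex_cache or remove_regex_cache, since those methods only adjust the in-memory regex list and invalidate the matching cached URLs.
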
14 changes: 7 additions & 7 deletions src/web/admin/auth.rs
@@ -27,7 +27,7 @@ use tracing::{info, warn};

use crate::{
auth::{AuthSession, Credentials},
-csrf::CsrfSessionEntry,
+csrf::SessionEntry,
err::AppError,
state::AppState,
};
@@ -50,8 +50,8 @@ pub fn router() -> Router<AppState> {

mod post {
use super::{
-info, warn, AppError, AppState, AuthSession, Credentials, CsrfSessionEntry, Form,
-IntoResponse, Messages, Redirect, Response, Session, State,
+info, warn, AppError, AppState, AuthSession, Credentials, Form, IntoResponse, Messages,
+Redirect, Response, Session, SessionEntry, State,
};

pub(super) async fn login(
@@ -61,7 +61,7 @@ ) -> Result<Response, AppError> {
State(state): State<AppState>,
Form(creds): Form<Credentials>,
) -> Result<Response, AppError> {
-CsrfSessionEntry::check_session(
+SessionEntry::check_session(
&state.csrf_crypto_engine,
&session,
&creds.authenticity_token,
@@ -87,8 +87,8 @@ mod get {
mod get {
use super::{
-info, AppError, AppState, AuthSession, CsrfSessionEntry, IntoResponse, LoginTemplate,
-Messages, Redirect, Response, Session, State,
+info, AppError, AppState, AuthSession, IntoResponse, LoginTemplate, Messages, Redirect,
+Response, Session, SessionEntry, State,
};
pub(super) async fn login(
@@ -103,7 +103,7 @@ }
}

let authenticity_token =
-CsrfSessionEntry::insert_session(&state.csrf_crypto_engine, &session).await?;
+SessionEntry::insert_session(&state.csrf_crypto_engine, &session).await?;

Ok(LoginTemplate {
authenticity_token,
(5 more changed files not shown)