Skip to content

Commit

Permalink
feat(serde): Use serde and bitcode to store structs as blobs (#24)
Browse files Browse the repository at this point in the history
  • Loading branch information
Mike Rivnak authored May 18, 2024
1 parent eff47b6 commit c4ed45f
Show file tree
Hide file tree
Showing 2 changed files with 77 additions and 39 deletions.
8 changes: 6 additions & 2 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,14 @@ repository = "https://github.com/mrivnak/pond"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html

[dependencies]
bitcode = { version = "0.6.0", default-features = false, features = ["serde"] }
chrono = { version = "0.4.38", features = ["serde"] }
rusqlite = { version = "0.31.0", features = ["bundled"] }
rusqlite = { version = "0.31.0", features = ["blob", "bundled"] }
serde = "1.0.202"

[dev-dependencies]
bitcode = { version = "0.6.0", features = ["serde"] }
rand = "0.8.5"
uuid = { version = "1.8.0", features = ["v4"] }
serde = { version = "1.0.202", features = ["derive"] }
uuid = { version = "1.8.0", features = ["v4", "serde"] }

108 changes: 71 additions & 37 deletions src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,29 +1,30 @@
use std::hash::{DefaultHasher, Hash, Hasher};
use std::path::PathBuf;
use std::time::Instant;

use chrono::{DateTime, Duration, Utc};
use rusqlite::Connection;
use serde::de::DeserializeOwned;
use serde::Serialize;

pub use rusqlite::types::{FromSql, ToSql};
pub use rusqlite::Error;

pub struct Cache {
pub struct Cache<T> {
path: PathBuf,
ttl: Duration,
data: std::marker::PhantomData<T>,
}

#[derive(Debug)]
pub struct CacheEntry<T>
struct CacheEntry<T>
where
T: ToSql + FromSql,
T: Serialize + DeserializeOwned + Clone,
{
key: u32,
value: T,
expiration: DateTime<Utc>,
}

impl Cache {
impl<T: Serialize + DeserializeOwned + Clone> Cache<T> {
pub fn new(path: PathBuf) -> Result<Self, Error> {
Self::with_time_to_live(path, Duration::minutes(10))
}
Expand All @@ -35,20 +36,23 @@ impl Cache {
"CREATE TABLE IF NOT EXISTS items (
id TEXT PRIMARY KEY,
expires TEXT NOT NULL,
data TEXT NOT NULL
data BLOB NOT NULL
)",
(), // empty list of parameters.
(),
)?;

db.close().expect("Failed to close database connection");

Ok(Self { path, ttl })
Ok(Self {
path,
ttl,
data: std::marker::PhantomData,
})
}

pub fn get<K, T>(&self, key: K) -> Result<Option<T>, Error>
pub fn get<K>(&self, key: K) -> Result<Option<T>, Error>
where
K: Hash,
T: ToSql + FromSql,
{
let db = Connection::open(self.path.as_path())?;

Expand Down Expand Up @@ -77,37 +81,31 @@ impl Cache {
.with_timezone(&Utc)
})
.unwrap();
let data: T = row.get(2).unwrap();
let data: Vec<u8> = row.get(2).unwrap();

drop(rows);
drop(stmt);
db.close().expect("Failed to close database connection");

let data: T = bitcode::deserialize(&data).unwrap();

if expires < Utc::now() {
Ok(None)
} else {
Ok(Some(data))
}
}

pub fn store<K, T>(&self, key: K, value: T) -> Result<(), Error>
where
K: Hash,
T: ToSql + FromSql,
{
pub fn store<K: Hash>(&self, key: K, value: T) -> Result<(), Error> {
self.store_with_expiration(key, value, Utc::now() + self.ttl)
}

pub fn store_with_expiration<K, T>(
pub fn store_with_expiration<K: Hash>(
&self,
key: K,
value: T,
expiration: DateTime<Utc>,
) -> Result<(), Error>
where
K: Hash,
T: ToSql + FromSql,
{
) -> Result<(), Error> {
let mut hasher = DefaultHasher::new();
let hash = {
key.hash(&mut hasher);
Expand All @@ -127,7 +125,7 @@ impl Cache {
(
&value.key.to_string(),
&value.expiration.to_rfc3339(),
&value.value,
&bitcode::serialize(&value.value).unwrap(),
),
)?;

Expand All @@ -152,14 +150,22 @@ impl Cache {

#[cfg(test)]
mod tests {
use serde::Deserialize;
use serde::Serialize;
use uuid::Uuid;

use super::*;

#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
struct User {
id: Uuid,
name: String,
}

fn store_manual(
path: PathBuf,
key: String,
value: String,
value: Vec<u8>,
expires: DateTime<Utc>,
) -> Result<(), Error> {
let mut hasher = DefaultHasher::new();
Expand All @@ -180,7 +186,7 @@ mod tests {
Ok(())
}

fn get_manual<T: ToSql + FromSql>(
fn get_manual<T: Serialize + DeserializeOwned + Clone>(
path: PathBuf,
key: String,
) -> Result<Option<CacheEntry<T>>, Error> {
Expand Down Expand Up @@ -212,12 +218,14 @@ mod tests {
.with_timezone(&Utc)
})
.unwrap();
let data: T = row.get(2).unwrap();
let data: Vec<u8> = row.get(2).unwrap();

drop(rows);
drop(stmt);
db.close().expect("Failed to close database connection");

let data: T = bitcode::deserialize(&data).unwrap();

Ok(Some(CacheEntry {
key: hash,
value: data,
Expand All @@ -232,7 +240,7 @@ mod tests {
Uuid::new_v4(),
rand::random::<u8>()
));
let cache = Cache::new(filename.clone()).unwrap();
let cache: Cache<String> = Cache::new(filename.clone()).unwrap();
assert_eq!(cache.path, filename);
assert_eq!(cache.ttl, Duration::minutes(10));
}
Expand All @@ -244,8 +252,8 @@ mod tests {
Uuid::new_v4(),
rand::random::<u8>()
));
let _ = Cache::new(filename.clone()).unwrap();
let _ = Cache::new(filename).unwrap();
let _: Cache<String> = Cache::new(filename.clone()).unwrap();
let _: Cache<String> = Cache::new(filename).unwrap();
}

#[test]
Expand All @@ -255,7 +263,8 @@ mod tests {
Uuid::new_v4(),
rand::random::<u8>()
));
let cache = Cache::with_time_to_live(filename.clone(), Duration::minutes(5)).unwrap();
let cache: Cache<String> =
Cache::with_time_to_live(filename.clone(), Duration::minutes(5)).unwrap();
assert_eq!(cache.path, filename);
assert_eq!(cache.ttl, Duration::minutes(5));
}
Expand All @@ -279,6 +288,28 @@ mod tests {
assert_eq!(result, Some(value));
}

#[test]
fn test_store_get_struct() {
let filename = std::env::temp_dir().join(format!(
"pond-test-{}-{}.sqlite",
Uuid::new_v4(),
rand::random::<u8>()
));

let cache = Cache::new(filename).unwrap();

let key = Uuid::new_v4();
let value = User {
id: Uuid::new_v4(),
name: String::from("Alice"),
};

cache.store(key, value.clone()).unwrap();
let result: Option<_> = cache.get(key).unwrap();

assert_eq!(result, Some(value));
}

#[test]
fn test_store_existing() {
let filename = std::env::temp_dir().join(format!(
Expand Down Expand Up @@ -317,7 +348,7 @@ mod tests {
store_manual(
filename,
key.to_string(),
value,
bitcode::serialize(&value).unwrap(),
Utc::now() - Duration::minutes(5),
)
.unwrap();
Expand Down Expand Up @@ -345,7 +376,8 @@ mod tests {

#[test]
fn test_invalid_path() {
let cache = Cache::new(PathBuf::from("invalid/path/db.sqlite"));
let cache: Result<Cache<String>, Error> =
Cache::new(PathBuf::from("invalid/path/db.sqlite"));

assert!(cache.is_err());
}
Expand All @@ -358,15 +390,16 @@ mod tests {
rand::random::<u8>()
));

let cache = Cache::with_time_to_live(filename.clone(), Duration::minutes(5)).unwrap();
let cache: Cache<String> =
Cache::with_time_to_live(filename.clone(), Duration::minutes(5)).unwrap();

let key = Uuid::new_v4().to_string();
let value = String::from("Hello, world!");

store_manual(
filename.clone(),
key.clone(),
value.clone(),
bitcode::serialize(&value).unwrap(),
Utc::now() - Duration::minutes(5),
)
.unwrap();
Expand All @@ -391,15 +424,16 @@ mod tests {
rand::random::<u8>()
));

let cache = Cache::with_time_to_live(filename.clone(), Duration::minutes(5)).unwrap();
let cache: Cache<String> =
Cache::with_time_to_live(filename.clone(), Duration::minutes(5)).unwrap();

let key = Uuid::new_v4().to_string();
let value = String::from("Hello, world!");

store_manual(
filename.clone(),
key.clone(),
value.clone(),
bitcode::serialize(&value).unwrap(),
Utc::now() + Duration::minutes(15),
)
.unwrap();
Expand Down

0 comments on commit c4ed45f

Please sign in to comment.