feat(client): Caching requests to servers for feed (#12)
This PR implements a client that caches GET requests to servers when fetching
feeds. It is used wherever `reqwest::Client` was used before. Currently there
are two such places:

- the client used to fetch the source (default TTL: 15 minutes)
- the full_text filter (default TTL: 12 hours)

This mechanism is expected to speed up feed fetching, especially for repeated
full-text requests (see the sketch of the intended call sites below).
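
For illustration only, a minimal sketch of how the two call sites can pass their
default TTLs to the new `ClientConfig::build`; the helper function, the
`config` variable, and the hard-coded durations are assumptions for this
example, and a user-configured `cache_ttl` takes precedence over the default:

```rust
use std::time::Duration;

// Hypothetical helper: `ClientConfig`, `Client`, and the crate's `Result`
// alias are assumed to be the types added in this PR.
fn build_feed_clients(config: &ClientConfig) -> Result<(Client, Client)> {
  // Source fetcher: cache feed responses for 15 minutes by default.
  let source_client = config.build(Duration::from_secs(15 * 60))?;
  // full_text filter: cache fetched pages for 12 hours by default.
  let full_text_client = config.build(Duration::from_secs(12 * 60 * 60))?;
  Ok((source_client, full_text_client))
}
```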

I'm also planning to use this cache in tests to insert custom fixtures for
certain web addresses in future feature tests, as sketched below.
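
This is illustration only: the URL and body are made up, and the test assumes
the `#[cfg(test)]`-only `Client::insert` and `Response::new` helpers added in
this PR, placed inside `src/client.rs` so the private `Client::new` is
reachable.

```rust
#[cfg(test)]
mod fixture_example {
  use std::time::Duration;

  use reqwest::header::HeaderMap;
  use url::Url;

  use super::{Client, Response};

  #[tokio::test]
  async fn serves_fixture_instead_of_fetching() {
    let client =
      Client::new(1, Duration::from_secs(60), reqwest::Client::new());
    let url = Url::parse("http://fixture.invalid/feed.xml").unwrap();
    let body: Box<[u8]> = b"<rss/>".to_vec().into_boxed_slice();

    // Pre-populate the cache; the later `get` is served from it,
    // so no network request is made for this URL.
    client.insert(
      url.clone(),
      Response::new(url.clone(), reqwest::StatusCode::OK, HeaderMap::new(), body),
    );

    let resp = client.get(&url).await.unwrap();
    assert_eq!(resp.status(), reqwest::StatusCode::OK);
    assert_eq!(resp.body(), b"<rss/>");
  }
}
```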

---------

Co-authored-by: Shou Ya <shouya@users.noreply.github.com>
shouya and shouya committed Feb 2, 2024
1 parent b163660 commit 20937b9
Showing 8 changed files with 305 additions and 29 deletions.
21 changes: 21 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default.

2 changes: 2 additions & 0 deletions Cargo.toml
@@ -49,6 +49,8 @@ either = "1.9.0" # used for returning sum types from the JS runtime
# Web client (blocking and async both used, blocking used in the JS runtime)
# TODO: upgrade reqwest after its hyper 1.0 upgrade
reqwest = { version = "0.11.23", default-features = false, features = ["blocking", "rustls-tls", "trust-dns"] }
encoding_rs = "0.8.33"
lru = "0.12.2"

# Used in sanitize filter to remove/replace text contents
regex = "1.10.2"
103 changes: 101 additions & 2 deletions src/client.rs
@@ -1,16 +1,24 @@
mod cache;

use std::time::Duration;

use reqwest::header::HeaderMap;
use serde::{Deserialize, Serialize};
use url::Url;

use crate::util::Result;

use self::cache::{Response, ResponseCache};

#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct ClientConfig {
  user_agent: Option<String>,
  accept: Option<String>,
  set_cookie: Option<String>,
  referer: Option<String>,
  cache_size: Option<usize>,
  #[serde(deserialize_with = "duration_str::deserialize_option_duration")]
  cache_ttl: Option<Duration>,
  #[serde(default = "default_timeout")]
  #[serde(deserialize_with = "duration_str::deserialize_duration")]
  timeout: Duration,
@@ -24,6 +32,8 @@ impl Default for ClientConfig {
      set_cookie: None,
      referer: None,
      timeout: default_timeout(),
      cache_size: None,
      cache_ttl: None,
    }
  }
}
@@ -67,11 +77,100 @@ impl ClientConfig {
    builder
  }

  pub fn build(&self) -> Result<reqwest::Client> {
    Ok(self.to_builder().build()?)
  pub fn build(&self, default_cache_ttl: Duration) -> Result<Client> {
    let reqwest_client = self.to_builder().build()?;
    Ok(Client::new(
      self.cache_size.unwrap_or(0),
      self.cache_ttl.unwrap_or(default_cache_ttl),
      reqwest_client,
    ))
  }
}

pub struct Client {
  cache: ResponseCache,
  client: reqwest::Client,
}

impl Client {
  fn new(
    cache_size: usize,
    cache_ttl: Duration,
    client: reqwest::Client,
  ) -> Self {
    Self {
      cache: ResponseCache::new(cache_size, cache_ttl),
      client,
    }
  }

  pub async fn get(&self, url: &Url) -> Result<Response> {
    self.get_with(url, |req| req).await
  }

  pub async fn get_with(
    &self,
    url: &Url,
    f: impl FnOnce(reqwest::RequestBuilder) -> reqwest::RequestBuilder,
  ) -> Result<Response> {
    if let Some(resp) = self.cache.get_cached(url) {
      return Ok(resp);
    }

    let resp = f(self.client.get(url.clone())).send().await?;
    let resp = Response::from_reqwest_resp(resp).await?;
    self.cache.insert(url.clone(), resp.clone());
    Ok(resp)
  }

  #[cfg(test)]
  pub fn insert(&self, url: Url, response: Response) {
    self.cache.insert(url, response);
  }
}

fn default_timeout() -> Duration {
  Duration::from_secs(10)
}

#[cfg(test)]
mod tests {
  use super::*;

  #[tokio::test]
  async fn test_client_cache() {
    let client = Client::new(1, Duration::from_secs(1), reqwest::Client::new());
    let url = Url::parse("http://example.com").unwrap();
    let body: Box<str> = "foo".into();
    let response = Response::new(
      url.clone(),
      reqwest::StatusCode::OK,
      HeaderMap::new(),
      body.into(),
    );

    client.insert(url.clone(), response.clone());
    let actual = client.get(&url).await.unwrap();
    let expected = response;

    assert_eq!(actual.url(), expected.url());
    assert_eq!(actual.status(), expected.status());
    assert_eq!(actual.headers(), expected.headers());
    assert_eq!(actual.body(), expected.body());
  }

  const YT_SCISHOW_FEED_URL: &str = "https://www.youtube.com/feeds/videos.xml?channel_id=UCZYTClx2T1of7BRZ86-8fow";

  #[tokio::test]
  async fn test_client() {
    let client = Client::new(0, Duration::from_secs(1), reqwest::Client::new());
    let url = Url::parse(YT_SCISHOW_FEED_URL).unwrap();
    let resp = client.get(&url).await.unwrap();
    assert_eq!(resp.status(), reqwest::StatusCode::OK);
    assert_eq!(
      resp.content_type().unwrap().to_string(),
      "text/xml; charset=utf-8"
    );
    assert!(resp.text().unwrap().contains("<title>SciShow</title>"));
  }
}
151 changes: 151 additions & 0 deletions src/client/cache.rs
@@ -0,0 +1,151 @@
use std::{
  num::NonZeroUsize,
  sync::{Arc, RwLock},
  time::{Duration, Instant},
};

use lru::LruCache;
use mime::Mime;
use reqwest::header::HeaderMap;
use url::Url;

use crate::util::{Error, Result};

struct Timed<T> {
  value: T,
  created: Instant,
}

pub struct ResponseCache {
  map: RwLock<LruCache<Url, Timed<Response>>>,
  timeout: Duration,
}

impl ResponseCache {
  pub fn new(max_entries: usize, timeout: Duration) -> Self {
    let max_entries = max_entries.try_into().unwrap_or(NonZeroUsize::MIN);
    Self {
      map: RwLock::new(LruCache::new(max_entries)),
      timeout,
    }
  }

  pub fn get_cached(&self, url: &Url) -> Option<Response> {
    let mut map = self.map.write().ok()?;
    let Some(entry) = map.get(url) else {
      return None;
    };
    if entry.created.elapsed() > self.timeout {
      map.pop(url);
      return None;
    }
    Some(entry.value.clone())
  }

  pub fn insert(&self, url: Url, response: Response) -> Option<()> {
    let timed = Timed {
      value: response,
      created: Instant::now(),
    };
    self.map.write().ok()?.push(url, timed);
    Some(())
  }
}

#[derive(Clone)]
pub struct Response {
  inner: Arc<InnerResponse>,
}

struct InnerResponse {
  url: Url,
  status: reqwest::StatusCode,
  headers: HeaderMap,
  body: Box<[u8]>,
}

impl Response {
  pub async fn from_reqwest_resp(resp: reqwest::Response) -> Result<Self> {
    let status = resp.status();
    let headers = resp.headers().clone();
    let url = resp.url().clone();
    let body = resp.bytes().await?.to_vec().into_boxed_slice();
    let resp = InnerResponse {
      url,
      status,
      headers,
      body,
    };

    Ok(Self {
      inner: Arc::new(resp),
    })
  }

  #[cfg(test)]
  pub fn new(
    url: Url,
    status: reqwest::StatusCode,
    headers: HeaderMap,
    body: Box<[u8]>,
  ) -> Self {
    Self {
      inner: Arc::new(InnerResponse {
        url,
        status,
        headers,
        body,
      }),
    }
  }

  pub fn error_for_status(self) -> Result<Self> {
    let status = self.inner.status;
    if status.is_client_error() || status.is_server_error() {
      return Err(Error::HttpStatus(status, self.inner.url.clone()));
    }

    Ok(self)
  }

  pub fn header(&self, name: &str) -> Option<&str> {
    self.inner.headers.get(name).and_then(|v| v.to_str().ok())
  }

  pub fn text_with_charset(&self, default_encoding: &str) -> Result<String> {
    let content_type = self.content_type();
    let encoding_name = content_type
      .as_ref()
      .and_then(|mime| {
        mime.get_param("charset").map(|charset| charset.as_str())
      })
      .unwrap_or(default_encoding);
    let encoding = encoding_rs::Encoding::for_label(encoding_name.as_bytes())
      .unwrap_or(encoding_rs::UTF_8);

    let full = &self.inner.body;
    let (text, _, _) = encoding.decode(full);
    Ok(text.into_owned())
  }

  pub fn text(&self) -> Result<String> {
    self.text_with_charset("utf-8")
  }

  pub fn content_type(&self) -> Option<Mime> {
    self.header("content-type").and_then(|v| v.parse().ok())
  }

  pub fn url(&self) -> &Url {
    &self.inner.url
  }
  pub fn status(&self) -> reqwest::StatusCode {
    self.inner.status
  }
  pub fn headers(&self) -> &HeaderMap {
    &self.inner.headers
  }
  pub fn body(&self) -> &[u8] {
    &self.inner.body
  }
}
