Skip to content

Commit

Permalink
Feat: cache public key (#655)
Browse files Browse the repository at this point in the history
* feat: move cache to common, cache public key

* fix: missing import

* refactor: public key cache only needs one key

* refactor: workspace deps

* feat: tracing calls in public_key cache flow

* refactor: init cache in AuthPublicKey::new

* fix: clippy
  • Loading branch information
oddgrd authored Feb 27, 2023
1 parent fb7c5ae commit 13d8bf0
Show file tree
Hide file tree
Showing 12 changed files with 112 additions and 75 deletions.
3 changes: 2 additions & 1 deletion Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -61,4 +61,5 @@ tower-http = { version = "0.3.4", features = ["trace"] }
tracing = "0.1.37"
tracing-opentelemetry = "0.18.0"
tracing-subscriber = "0.3.16"
ttl_cache = "0.5.1"
uuid = "1.2.2"
6 changes: 4 additions & 2 deletions auth/Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
[package]
name = "shuttle-auth"
version = "0.1.0"
edition = "2021"
version.workspace = true
edition.workspace = true
license.workspace = true
repository.workspace = true
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html

[dependencies]
Expand Down
3 changes: 2 additions & 1 deletion common/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -35,10 +35,11 @@ tower = { workspace = true, optional = true }
tower-http = { workspace = true, optional = true }
tracing = { workspace = true }
tracing-opentelemetry = { workspace = true, optional = true }
ttl_cache = { workspace = true, optional = true }
uuid = { workspace = true, features = ["v4", "serde"] }

[features]
backend = ["async-trait", "axum", "bytes", "http", "http-body", "hyper/client", "jsonwebtoken", "opentelemetry", "opentelemetry-http", "thiserror", "tower", "tower-http", "tracing-opentelemetry"]
backend = ["async-trait", "axum", "bytes", "http", "http-body", "hyper/client", "jsonwebtoken", "opentelemetry", "opentelemetry-http", "thiserror", "tower", "tower-http", "tracing-opentelemetry", "ttl_cache"]
display = ["comfy-table", "crossterm"]
models = ["anyhow", "async-trait", "display", "http", "reqwest", "serde_json"]

Expand Down
53 changes: 40 additions & 13 deletions common/src/backends/auth.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ use std::{
future::Future,
ops::Add,
pin::Pin,
sync::Arc,
task::{Context, Poll},
};

Expand All @@ -20,10 +21,14 @@ use thiserror::Error;
use tower::{Layer, Service};
use tracing::{error, trace};

use super::headers::XShuttleAdminSecret;
use super::{
cache::{CacheManagement, CacheManager},
headers::XShuttleAdminSecret,
};

const EXP_MINUTES: i64 = 5;
const ISS: &str = "shuttle";
const PUBLIC_KEY_CACHE_KEY: &str = "shuttle.public-key";

/// Layer to check the admin secret set by deployer is correct
#[derive(Clone)]
Expand Down Expand Up @@ -278,11 +283,16 @@ where
#[derive(Clone)]
pub struct AuthPublicKey {
auth_uri: Uri,
cache_manager: Arc<Box<dyn CacheManagement<Value = Vec<u8>>>>,
}

impl AuthPublicKey {
pub fn new(auth_uri: Uri) -> Self {
Self { auth_uri }
let public_key_cache_manager = CacheManager::new(1);
Self {
auth_uri,
cache_manager: Arc::new(Box::new(public_key_cache_manager)),
}
}
}

Expand All @@ -291,11 +301,25 @@ impl PublicKeyFn for AuthPublicKey {
type Error = PublicKeyFnError;

async fn public_key(&self) -> Result<Vec<u8>, Self::Error> {
let client = Client::new();
let uri = format!("{}public-key", self.auth_uri).parse()?;
let res = client.get(uri).await?;
let buf = body::to_bytes(res).await?;
Ok(buf.to_vec())
if let Some(public_key) = self.cache_manager.get(PUBLIC_KEY_CACHE_KEY) {
trace!("found public key in the cache, returning it");

Ok(public_key)
} else {
let client = Client::new();
let uri = format!("{}public-key", self.auth_uri).parse()?;
let res = client.get(uri).await?;
let buf = body::to_bytes(res).await?;

trace!("inserting public key from auth service into cache");
self.cache_manager.insert(
PUBLIC_KEY_CACHE_KEY,
buf.to_vec(),
std::time::Duration::from_secs(60),
);

Ok(buf.to_vec())
}
}
}

Expand Down Expand Up @@ -378,16 +402,20 @@ where

this.inner.call(req).await
}
Err(code) => Ok(Response::builder()
.status(code)
.body(Default::default())
.unwrap()),
Err(code) => {
error!(code = %code, "failed to decode JWT");

Ok(Response::builder()
.status(code)
.body(Default::default())
.unwrap())
}
}
}
Err(error) => {
error!(
error = &error as &dyn std::error::Error,
"failed to get public key"
"failed to get public key from auth service"
);

Ok(Response::builder()
Expand Down Expand Up @@ -653,7 +681,6 @@ mod tests {
ServiceBuilder::new()
.layer(JwtAuthenticationLayer::new(move || {
let public_key = public_key.clone();

async move { public_key.clone() }
}))
.layer(ScopedLayer::new(vec![Scope::Project])),
Expand Down
51 changes: 51 additions & 0 deletions common/src/backends/cache.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
use std::{
sync::{Arc, RwLock},
time::Duration,
};
use ttl_cache::TtlCache;

pub trait CacheManagement: Send + Sync {
type Value;

fn get(&self, key: &str) -> Option<Self::Value>;
fn insert(&self, key: &str, value: Self::Value, ttl: Duration) -> Option<Self::Value>;
fn invalidate(&self, key: &str) -> Option<Self::Value>;
}

pub struct CacheManager<T> {
pub cache: Arc<RwLock<TtlCache<String, T>>>,
}

impl<T> CacheManager<T> {
pub fn new(capacity: usize) -> Self {
let cache = Arc::new(RwLock::new(TtlCache::new(capacity)));

Self { cache }
}
}

impl<T: Send + Sync + Clone> CacheManagement for CacheManager<T> {
type Value = T;

fn get(&self, key: &str) -> Option<Self::Value> {
self.cache
.read()
.expect("cache lock should not be poisoned")
.get(key)
.cloned()
}

fn insert(&self, key: &str, value: T, ttl: Duration) -> Option<Self::Value> {
self.cache
.write()
.expect("cache lock should not be poisoned")
.insert(key.to_string(), value, ttl)
}

fn invalidate(&self, key: &str) -> Option<Self::Value> {
self.cache
.write()
.expect("cache lock should not be poisoned")
.remove(key)
}
}
1 change: 1 addition & 0 deletions common/src/backends/mod.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
pub mod auth;
pub mod cache;
pub mod headers;
pub mod metrics;
2 changes: 1 addition & 1 deletion gateway/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ tower = { workspace = true, features = [ "steer" ] }
tracing = { workspace = true }
tracing-opentelemetry = { workspace = true }
tracing-subscriber = { workspace = true, features = ["env-filter"] }
ttl_cache = "0.5.1"
ttl_cache = { workspace = true }
uuid = { workspace = true, features = [ "v4" ] }

[dependencies.shuttle-common]
Expand Down
13 changes: 7 additions & 6 deletions gateway/src/api/auth_layer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,11 @@ use hyper_reverse_proxy::ReverseProxy;
use once_cell::sync::Lazy;
use opentelemetry::global;
use opentelemetry_http::HeaderInjector;
use shuttle_common::backends::auth::ConvertResponse;
use shuttle_common::backends::{auth::ConvertResponse, cache::CacheManagement};
use tower::{Layer, Service};
use tracing::{error, trace, Span};
use tracing_opentelemetry::OpenTelemetrySpanExt;

use super::cache::CacheManagement;

static PROXY_CLIENT: Lazy<ReverseProxy<HttpConnector<GaiResolver>>> =
Lazy::new(|| ReverseProxy::new(Client::new()));

Expand All @@ -34,11 +32,14 @@ static PROXY_CLIENT: Lazy<ReverseProxy<HttpConnector<GaiResolver>>> =
#[derive(Clone)]
pub struct ShuttleAuthLayer {
auth_uri: Uri,
cache_manager: Arc<Box<dyn CacheManagement>>,
cache_manager: Arc<Box<dyn CacheManagement<Value = String>>>,
}

impl ShuttleAuthLayer {
pub fn new(auth_uri: Uri, cache_manager: Arc<Box<dyn CacheManagement>>) -> Self {
pub fn new(
auth_uri: Uri,
cache_manager: Arc<Box<dyn CacheManagement<Value = String>>>,
) -> Self {
Self {
auth_uri,
cache_manager,
Expand All @@ -62,7 +63,7 @@ impl<S> Layer<S> for ShuttleAuthLayer {
pub struct ShuttleAuthService<S> {
inner: S,
auth_uri: Uri,
cache_manager: Arc<Box<dyn CacheManagement>>,
cache_manager: Arc<Box<dyn CacheManagement<Value = String>>>,
}

impl<S> Service<Request<Body>> for ShuttleAuthService<S>
Expand Down
47 changes: 0 additions & 47 deletions gateway/src/api/cache.rs

This file was deleted.

6 changes: 3 additions & 3 deletions gateway/src/api/latest.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ use http::{StatusCode, Uri};
use instant_acme::{AccountCredentials, ChallengeType};
use serde::{Deserialize, Serialize};
use shuttle_common::backends::auth::{AuthPublicKey, JwtAuthenticationLayer, Scope, ScopedLayer};
use shuttle_common::backends::cache::CacheManager;
use shuttle_common::backends::metrics::{Metrics, TraceLayer};
use shuttle_common::models::error::ErrorKind;
use shuttle_common::models::{project, stats};
Expand All @@ -36,7 +37,6 @@ use crate::worker::WORKER_QUEUE_SIZE;
use crate::{Error, GatewayService, ProjectName};

use super::auth_layer::ShuttleAuthLayer;
use super::cache::CacheManager;

pub const SVC_DEGRADED_THRESHOLD: usize = 128;

Expand Down Expand Up @@ -493,14 +493,14 @@ impl ApiBuilder {
pub fn with_auth_service(mut self, auth_uri: Uri) -> Self {
let auth_public_key = AuthPublicKey::new(auth_uri.clone());

let cache_manager = CacheManager::new();
let jwt_cache_manager = CacheManager::new(1000);

self.router = self
.router
.layer(JwtAuthenticationLayer::new(auth_public_key))
.layer(ShuttleAuthLayer::new(
auth_uri,
Arc::new(Box::new(cache_manager)),
Arc::new(Box::new(jwt_cache_manager)),
));

self
Expand Down
1 change: 0 additions & 1 deletion gateway/src/api/mod.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
mod auth_layer;
mod cache;

pub mod latest;

0 comments on commit 13d8bf0

Please sign in to comment.