Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

chore(proxy/jwks): reduce the rightward drift of jwks renewal #9853

Merged
merged 4 commits into from
Nov 22, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
177 changes: 96 additions & 81 deletions proxy/src/auth/backend/jwt.rs
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,93 @@ struct JwkSet<'a> {
keys: Vec<&'a RawValue>,
}

/// Given a jwks_url, fetch the JWKS and parse out all the signing JWKs.
/// Returns `None` and log a warning if there are any errors.
async fn fetch_jwks(
client: &reqwest_middleware::ClientWithMiddleware,
jwks_url: url::Url,
) -> Option<jose_jwk::JwkSet> {
let req = client.get(jwks_url.clone());
// TODO(conrad): We need to filter out URLs that point to local resources. Public internet only.
let resp = req.send().await.and_then(|r| {
r.error_for_status()
.map_err(reqwest_middleware::Error::Reqwest)
});

let resp = match resp {
Ok(r) => r,
// TODO: should we re-insert JWKs if we want to keep this JWKs URL?
// I expect these failures would be quite sparse.
Err(e) => {
tracing::warn!(url=?jwks_url, error=?e, "could not fetch JWKs");
return None;
}
};

let resp: http::Response<reqwest::Body> = resp.into();

let bytes = match read_body_with_limit(resp.into_body(), MAX_JWK_BODY_SIZE).await {
Ok(bytes) => bytes,
Err(e) => {
tracing::warn!(url=?jwks_url, error=?e, "could not decode JWKs");
return None;
}
};

let jwks = match serde_json::from_slice::<JwkSet>(&bytes) {
Ok(jwks) => jwks,
Err(e) => {
tracing::warn!(url=?jwks_url, error=?e, "could not decode JWKs");
return None;
}
};

// `jose_jwk::Jwk` is quite large (288 bytes). Let's not pre-allocate for what we don't need.
//
// Even though we limit our responses to 64KiB, we could still receive a payload like
// `{"keys":[` + repeat(`0`).take(30000).join(`,`) + `]}`. Parsing this as `RawValue` uses 468KiB.
// Pre-allocating the corresponding `Vec::<jose_jwk::Jwk>::with_capacity(30000)` uses 8.2MiB.
let mut keys = vec![];

let mut failed = 0;
for key in jwks.keys {
let key = match serde_json::from_str::<jose_jwk::Jwk>(key.get()) {
Ok(key) => key,
Err(e) => {
tracing::debug!(url=?jwks_url, failed=?e, "could not decode JWK");
failed += 1;
continue;
}
};

// if `use` (called `cls` in rust) is specified to be something other than signing,
// we can skip storing it.
if key
.prm
.cls
.as_ref()
.is_some_and(|c| *c != jose_jwk::Class::Signing)
{
continue;
}

keys.push(key);
}

keys.shrink_to_fit();

if failed > 0 {
tracing::warn!(url=?jwks_url, failed, "could not decode JWKs");
}

if keys.is_empty() {
tracing::warn!(url=?jwks_url, "no valid JWKs found inside the response body");
return None;
}

Some(jose_jwk::JwkSet { keys })
}

impl JwkCacheEntryLock {
async fn acquire_permit<'a>(self: &'a Arc<Self>) -> JwkRenewalPermit<'a> {
JwkRenewalPermit::acquire_permit(self).await
Expand Down Expand Up @@ -166,87 +253,15 @@ impl JwkCacheEntryLock {
// TODO(conrad): run concurrently
// TODO(conrad): strip the JWKs urls (should be checked by cplane as well - cloud#16284)
for rule in rules {
let req = client.get(rule.jwks_url.clone());
// TODO(conrad): eventually switch to using reqwest_middleware/`new_client_with_timeout`.
// TODO(conrad): We need to filter out URLs that point to local resources. Public internet only.
match req.send().await.and_then(|r| {
r.error_for_status()
.map_err(reqwest_middleware::Error::Reqwest)
}) {
// todo: should we re-insert JWKs if we want to keep this JWKs URL?
// I expect these failures would be quite sparse.
Err(e) => tracing::warn!(url=?rule.jwks_url, error=?e, "could not fetch JWKs"),
Ok(r) => {
let resp: http::Response<reqwest::Body> = r.into();

let bytes = match read_body_with_limit(resp.into_body(), MAX_JWK_BODY_SIZE)
.await
{
Ok(bytes) => bytes,
Err(e) => {
tracing::warn!(url=?rule.jwks_url, error=?e, "could not decode JWKs");
continue;
}
};

match serde_json::from_slice::<JwkSet>(&bytes) {
Err(e) => {
tracing::warn!(url=?rule.jwks_url, error=?e, "could not decode JWKs");
}
Ok(jwks) => {
// size_of::<&RawValue>() == 16
// size_of::<jose_jwk::Jwk>() == 288
// better to not pre-allocate this as it might be pretty large - especially if it has many
// keys we don't want or need.
// trivial 'attack': `{"keys":[` + repeat(`0`).take(30000).join(`,`) + `]}`
// this would consume 8MiB just like that!
let mut keys = vec![];
let mut failed = 0;
for key in jwks.keys {
match serde_json::from_str::<jose_jwk::Jwk>(key.get()) {
Ok(key) => {
// if `use` (called `cls` in rust) is specified to be something other than signing,
// we can skip storing it.
if key
.prm
.cls
.as_ref()
.is_some_and(|c| *c != jose_jwk::Class::Signing)
{
continue;
}

keys.push(key);
}
Err(e) => {
tracing::debug!(url=?rule.jwks_url, failed=?e, "could not decode JWK");
failed += 1;
}
}
}
keys.shrink_to_fit();

if failed > 0 {
tracing::warn!(url=?rule.jwks_url, failed, "could not decode JWKs");
}

if keys.is_empty() {
tracing::warn!(url=?rule.jwks_url, "no valid JWKs found inside the response body");
continue;
}

let jwks = jose_jwk::JwkSet { keys };
key_sets.insert(
rule.id,
KeySet {
jwks,
audience: rule.audience,
role_names: rule.role_names,
},
);
}
};
}
if let Some(jwks) = fetch_jwks(client, rule.jwks_url).await {
key_sets.insert(
rule.id,
KeySet {
jwks,
audience: rule.audience,
role_names: rule.role_names,
},
);
}
}

Expand Down
Loading