Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Proof of Concept: Implementation of Refresh Access Token in Rust #13

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 31 additions & 0 deletions CustomIdentityComponent/lambda_rust/Cargo.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
[package]
name = "refresh-access-token"
version = "0.1.0"
edition = "2021"

[dependencies]
aws-config = "0"
aws-sdk-secretsmanager = "0"
cached = { version ="0", features = ["async"] }
http = "0.2"
jsonwebkey = { version ="0.3", features = ["jwt-convert"] }
jsonwebtoken = "8"
lambda_http = { version = "0.8", default-features = false, features = ["apigw_http"] }
lambda_runtime = "0.8"
metrics = "0.21.1"
metrics_cloudwatch_embedded = { version = "0.4.1", features = ["lambda"] }
reqwest = { version = "0.11", default-features = false, features = ["json", "rustls-tls"] }
reqwest-middleware = "0.2"
reqwest-retry = "0.3"
serde = {version = "1.0", features = ["derive"] }
serde_json = "1.0"
time = "0.3"
tokio = { version = "1", features = ["macros"] }
tracing = { version = "0.1", features = ["log"] }
tracing-subscriber = { version = "0.3", default-features = false, features = ["fmt", "env-filter", "json"] }

[profile.release]
opt-level = "z"
lto = true
codegen-units = 1
panic = "abort"
257 changes: 257 additions & 0 deletions CustomIdentityComponent/lambda_rust/src/main.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,257 @@
use cached::proc_macro::{cached, once};
use lambda_http::{Body, Error, Request, RequestExt, Response};
use metrics_cloudwatch_embedded::lambda::handler::run_http;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::sync::Arc;
use tracing::{debug, info, info_span};

/// Input Jwt token claims
#[allow(dead_code)]
#[derive(Debug, Deserialize)]
struct ClaimsIn {
sub: String,
iss: String,
kid: String,
aud: String,
scope: String,
access_token_scope: Option<String>,
iat: i64,
nbf: i64,
exp: i64,
}

/// Output Jwt token claims with references to save some allocations
#[derive(Debug, Serialize)]
struct ClaimsOut<'a> {
sub: &'a str,
iss: &'a str,
kid: &'a str,
aud: &'a str,
scope: &'a str,
access_token_scope: Option<&'a str>,
iat: i64,
nbf: i64,
exp: i64,
}

/// Json body of success responses
#[derive(Debug, Serialize)]
struct ResponsePayload<'a> {
user_id: &'a str,
auth_token: &'a str,
refresh_token: &'a str,
auth_token_expires_in: i64,
refresh_token_expires_in: i64,
}

fn generate_response(code: u16, body: &str) -> Response<Body> {
Response::builder()
.status(code)
.header("Access-Control-Allow-Origin", "*")
.header("Access-Control-Allow-Credentials", "true")
.body(body.into())
.expect("failed to generate response")
}

#[cached]
/// get our (cached) aws configuration
async fn get_aws_config() -> Arc<aws_config::SdkConfig> {
Arc::new(aws_config::load_from_env().await)
}

#[cached(time = 900)]
/// get our private kid and key from secrects manager, panic on failure
async fn get_private_key() -> (Arc<String>, Arc<jsonwebtoken::EncodingKey>) {
info!("refreshing private key from Secrets Manager");

let aws_config = get_aws_config().await;
let secrets_client = aws_sdk_secretsmanager::Client::new(&aws_config);

let jwk: jsonwebkey::JsonWebKey = secrets_client
.get_secret_value()
.secret_id(std::env::var("SECRET_KEY_ID").unwrap())
.send()
.await
.expect("failed to get SECRET_KEY_ID")
.secret_string()
.expect("SECRET_KEY_ID is blank")
.to_string()
.parse()
.expect("private key is not a valid jwk");

(
Arc::new(jwk.key_id.unwrap()),
Arc::new(jsonwebtoken::EncodingKey::from_rsa_pem(jwk.key.to_pem().as_bytes()).unwrap()),
)
}

#[once(time = 900)]
/// get the json web keyset for our issuer, panic on failure
async fn get_keyset(issuer: &str) -> Arc<HashMap<String, jsonwebtoken::DecodingKey>> {
info!("Refreshing json web keyset");

use reqwest_retry::{policies::ExponentialBackoff, RetryTransientMiddleware};

let retry_policy = ExponentialBackoff::builder().build_with_max_retries(3);
let client = reqwest_middleware::ClientBuilder::new(reqwest::Client::new())
.with(RetryTransientMiddleware::new_with_policy(retry_policy))
.build();

let jwks = client
.get(format!("{issuer}/.well-known/jwks.json"))
.send()
.await
.unwrap()
.json::<jsonwebtoken::jwk::JwkSet>()
.await
.unwrap();

let mut dict = HashMap::new();
for jwk in jwks.keys {
if let (Some(key_id), jsonwebtoken::jwk::AlgorithmParameters::RSA(rsa)) =
(jwk.common.key_id, &jwk.algorithm)
{
dict.insert(
key_id,
jsonwebtoken::DecodingKey::from_rsa_components(&rsa.n, &rsa.e).unwrap(),
);
}
}

if dict.is_empty() {
panic!("jwks has no valid keys");
}

Arc::new(dict)
}

async fn process_token(issuer: &str, refresh_token: &str) -> Result<Response<Body>, Error> {
let header = jsonwebtoken::decode_header(refresh_token)?;
let kid = header.kid.ok_or("kid missing from jwt header")?;

let jks = get_keyset(issuer).await;
let public_key = jks.get(&kid).ok_or("kid not in jks")?;

let mut validation = jsonwebtoken::Validation::new(jsonwebtoken::Algorithm::RS256);
validation.set_audience(&["refresh"]);
validation.set_issuer(&[issuer]);

let jwt = jsonwebtoken::decode::<ClaimsIn>(refresh_token, public_key, &validation)?;
debug!("jwt = {jwt:?}");

let user_id = jwt.claims.sub.as_str();
let access_token_scope = &jwt
.claims
.access_token_scope
.ok_or("missing access_token_scope claim")?;
let access_token_duration_sec = 15 * 60;
let existing_exp_value = jwt.claims.exp;

let (private_kid, private_key) = get_private_key().await;

// Build a new header with the latest kid
let mut new_header = jsonwebtoken::Header::new(jsonwebtoken::Algorithm::RS256);
new_header.kid = Some(private_kid.to_string());

let now = time::OffsetDateTime::now_utc().unix_timestamp();

// Build a new refresh token
let refresh_claims = ClaimsOut {
sub: user_id,
iss: issuer,
kid: &private_kid,
aud: "refresh",
scope: "refresh",
access_token_scope: Some(access_token_scope),
iat: now,
nbf: now,
exp: existing_exp_value,
};
let refresh_token = jsonwebtoken::encode(&new_header, &refresh_claims, &private_key)?;

// Build a new access token
let access_claims = ClaimsOut {
sub: user_id,
iss: issuer,
kid: &private_kid,
aud: "gamebackend",
scope: access_token_scope,
access_token_scope: None,
iat: now,
nbf: now,
exp: now + access_token_duration_sec,
};
let access_token = jsonwebtoken::encode(&new_header, &access_claims, &private_key)?;

let response_payload = ResponsePayload {
user_id,
auth_token: &access_token,
auth_token_expires_in: access_token_duration_sec,
refresh_token: &refresh_token,
refresh_token_expires_in: existing_exp_value - now,
};

Ok(generate_response(
200,
&serde_json::to_string(&response_payload)?,
))
}

async fn function_handler(issuer: &str, request: Request) -> Result<Response<Body>, Error> {
// Get the refresh_token from the query string
let query = request.query_string_parameters();
let refresh_token = query.first("refresh_token");

match refresh_token {
None => {
metrics::increment_counter!("deny", "reason" => "No refresh token provided");
Ok(generate_response(401, "Error: No refresh token provided"))
}
Some(refresh_token) => match process_token(issuer, refresh_token).await {
Ok(response) => {
metrics::increment_counter!("allow");
Ok(response)
}
Err(e) => {
// Record the details but don't give the remote client specifics
metrics::increment_counter!("deny", "reason" => e.to_string());
Ok(generate_response(
401,
"Error: Failed to validate refresh token",
))
}
},
}
}

#[tokio::main]
async fn main() -> Result<(), Error> {
tracing_subscriber::fmt()
.json()
.with_env_filter(tracing_subscriber::filter::EnvFilter::from_default_env())
.with_target(false)
.with_current_span(false)
.without_time()
.init();

let issuer = std::env::var("ISSUER_URL").unwrap();

let metrics = metrics_cloudwatch_embedded::Builder::new()
.cloudwatch_namespace(std::env::var("POWERTOOLS_METRICS_NAMESPACE").unwrap())
.with_dimension("service", std::env::var("POWERTOOLS_SERVICE_NAME").unwrap())
.with_dimension(
"function",
std::env::var("AWS_LAMBDA_FUNCTION_NAME").unwrap(),
)
.lambda_cold_start_span(info_span!("cold start").entered())
.lambda_cold_start_metric("ColdStart")
.with_lambda_request_id("requestId")
.init()
.unwrap();

run_http(metrics, |request: Request| {
function_handler(&issuer, request)
})
.await
}
35 changes: 33 additions & 2 deletions CustomIdentityComponent/lib/custom_identity_component-stack.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

import { Stack, StackProps, CfnOutput, Duration } from 'aws-cdk-lib';
import { Construct } from 'constructs';
import { RustFunction } from 'cargo-lambda-cdk';
import * as lambda from 'aws-cdk-lib/aws-lambda';
import * as cloudfront from 'aws-cdk-lib/aws-cloudfront';
import * as s3 from 'aws-cdk-lib/aws-s3';
Expand Down Expand Up @@ -362,8 +363,38 @@ export class CustomIdentityComponentStack extends Stack {
{ id: 'AwsSolutions-IAM5', reason: 'Using the standard Lambda execution role, all custom access resource restricted.' }
], true);

// Map login_as_guest_function to the api_gateway GET requeste login_as_guest
api_gateway.root.addResource('refresh-access-token').addMethod('GET', new apigw.LambdaIntegration(refresh_access_token_function),{
// Map refresh_access_token_function to the api_gateway GET requeste refresh_access_token_function
api_gateway.root.addResource('refresh-access-token').addMethod('GET', new apigw.LambdaIntegration(refresh_access_token_function), {
requestParameters: {
'method.request.querystring.refresh_token': true
},
requestValidator: requestValidator
});

const refresh_access_token_rust_function = new RustFunction(this, 'RefreshAccessToken_Rust', {
role: refresh_access_token_function_role,
manifestPath: 'lambda_rust/Cargo.toml',
architecture: lambda.Architecture.ARM_64,
bundling: {
forcedDockerBundling: true,
},
timeout: Duration.seconds(5),
tracing: lambda.Tracing.ACTIVE,
memorySize: 256,
environment: {
"ISSUER_URL": "https://" + distribution.domainName,
"POWERTOOLS_METRICS_NAMESPACE": "AWS for Games",
"POWERTOOLS_SERVICE_NAME": "CustomIdentityComponent",
"RUST_LOG": "info",
"SECRET_KEY_ID": secret.secretName,
"USER_TABLE": user_table.tableName
}
});
secret.grantRead(refresh_access_token_rust_function);
user_table.grantReadWriteData(refresh_access_token_rust_function);

// Map refresh_access_token_function to the api_gateway GET requeste refresh_access_token_function
api_gateway.root.addResource('refresh-access-token-rust').addMethod('GET', new apigw.LambdaIntegration(refresh_access_token_rust_function), {
requestParameters: {
'method.request.querystring.refresh_token': true
},
Expand Down
1 change: 1 addition & 0 deletions CustomIdentityComponent/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
},
"dependencies": {
"aws-cdk-lib": "^2.97.0",
"cargo-lambda-cdk": "^0.0.16",
"cdk": "^2.81.0-alpha.0",
"cdk-nag": "^2.27.24",
"constructs": "^10.0.0",
Expand Down