Skip to content

Commit

Permalink
feat(s3s/ops): forward region & service to S3Request
Browse files Browse the repository at this point in the history
  • Loading branch information
Nugine committed Dec 10, 2024
1 parent 5c64b47 commit 09fae30
Show file tree
Hide file tree
Showing 4 changed files with 102 additions and 46 deletions.
2 changes: 2 additions & 0 deletions crates/s3s/src/http/request.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@ pub(crate) struct S3Extensions {
pub vec_stream: Option<VecByteStream>,

pub credentials: Option<Credentials>,
pub region: Option<String>,
pub service: Option<String>,
}

impl From<hyper::Request<Body>> for Request {
Expand Down
36 changes: 27 additions & 9 deletions crates/s3s/src/ops/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ mod get_object;
mod tests;

use crate::access::{S3Access, S3AccessContext};
use crate::auth::S3Auth;
use crate::auth::{Credentials, S3Auth};
use crate::error::*;
use crate::header;
use crate::host::S3Host;
Expand Down Expand Up @@ -55,18 +55,23 @@ pub struct CallContext<'a> {
}

fn build_s3_request<T>(input: T, req: &mut Request) -> S3Request<T> {
let credentials = req.s3ext.credentials.take();
let extensions = mem::take(&mut req.extensions);
let headers = mem::take(&mut req.headers);
let method = req.method.clone();
let uri = mem::take(&mut req.uri);
let headers = mem::take(&mut req.headers);
let extensions = mem::take(&mut req.extensions);
let credentials = req.s3ext.credentials.take();
let region = req.s3ext.region.take();
let service = req.s3ext.service.take();

S3Request {
method,
uri,
headers,
input,
credentials,
extensions,
headers,
uri,
method: req.method.clone(),
credentials,
region,
service,
}
}

Expand Down Expand Up @@ -297,7 +302,20 @@ async fn prepare(req: &mut Request, ccx: &CallContext<'_>) -> S3Result<Prepare>
transformed_body = scx.transformed_body;

req.s3ext.multipart = scx.multipart;
req.s3ext.credentials = credentials;

match credentials {
Some(cred) => {
req.s3ext.credentials = Some(Credentials {
access_key: cred.access_key,
secret_key: cred.secret_key,
});
req.s3ext.region = cred.region;
req.s3ext.service = cred.service;
}
None => {
req.s3ext.credentials = None;
}
}
}

if body_changed {
Expand Down
62 changes: 44 additions & 18 deletions crates/s3s/src/ops/signature.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use crate::auth::Credentials;
use crate::auth::S3Auth;
use crate::auth::SecretKey;
use crate::error::*;
use crate::http;
use crate::http::{AwsChunkedStream, Body, Multipart};
Expand Down Expand Up @@ -72,12 +72,19 @@ pub struct SignatureContext<'a> {
pub multipart: Option<Multipart>,
}

pub struct CredentialsExt {
pub access_key: String,
pub secret_key: SecretKey,
pub region: Option<String>,
pub service: Option<String>,
}

fn require_auth(auth: Option<&dyn S3Auth>) -> S3Result<&dyn S3Auth> {
auth.ok_or_else(|| s3_error!(NotImplemented, "This service has no authentication provider"))
}

impl SignatureContext<'_> {
pub async fn check(&mut self) -> S3Result<Option<Credentials>> {
pub async fn check(&mut self) -> S3Result<Option<CredentialsExt>> {
if let Some(result) = self.v2_check().await {
debug!("checked signature v2");
return Ok(Some(result?));
Expand All @@ -92,7 +99,7 @@ impl SignatureContext<'_> {
}

#[tracing::instrument(skip(self))]
pub async fn v4_check(&mut self) -> Option<S3Result<Credentials>> {
pub async fn v4_check(&mut self) -> Option<S3Result<CredentialsExt>> {
// POST auth
if self.req_method == Method::POST {
if let Some(ref mime) = self.mime {
Expand Down Expand Up @@ -120,7 +127,7 @@ impl SignatureContext<'_> {
None
}

pub async fn v4_check_post_signature(&mut self) -> S3Result<Credentials> {
pub async fn v4_check_post_signature(&mut self) -> S3Result<CredentialsExt> {
let auth = require_auth(self.auth)?;

let multipart = {
Expand Down Expand Up @@ -157,21 +164,31 @@ impl SignatureContext<'_> {
let access_key = credential.access_key_id.to_owned();
let secret_key = auth.get_secret_key(&access_key).await?;

let region = credential.aws_region;
let service = credential.aws_service;

let string_to_sign = info.policy;
let signature =
sig_v4::calculate_signature(string_to_sign, &secret_key, &amz_date, credential.aws_region, credential.aws_service);
let signature = sig_v4::calculate_signature(string_to_sign, &secret_key, &amz_date, region, service);

let expected_signature = info.x_amz_signature;
if signature != expected_signature {
debug!(?signature, expected=?expected_signature, "signature mismatch");
return Err(s3_error!(SignatureDoesNotMatch));
}

let region = region.to_owned();
let service = service.to_owned();

self.multipart = Some(multipart);
Ok(Credentials { access_key, secret_key })
Ok(CredentialsExt {
access_key,
secret_key,
region: Some(region),
service: Some(service),
})
}

pub async fn v4_check_presigned_url(&mut self) -> S3Result<Credentials> {
pub async fn v4_check_presigned_url(&mut self) -> S3Result<CredentialsExt> {
let qs = self.qs.unwrap(); // assume: qs has "X-Amz-Signature"

let presigned_url = PresignedUrlV4::parse(qs).map_err(|err| invalid_request!(err, "missing presigned url v4 fields"))?;
Expand Down Expand Up @@ -217,15 +234,16 @@ impl SignatureContext<'_> {
let access_key = presigned_url.credential.access_key_id;
let secret_key = auth.get_secret_key(access_key).await?;

let region = presigned_url.credential.aws_region;
let service = presigned_url.credential.aws_service;

let signature = {
let headers = self.hs.find_multiple(&presigned_url.signed_headers);
let method = &self.req_method;
let uri_path = &self.decoded_uri_path;

let canonical_request = sig_v4::create_presigned_canonical_request(method, uri_path, qs.as_ref(), &headers);

let region = presigned_url.credential.aws_region;
let service = presigned_url.credential.aws_service;
let amz_date = &presigned_url.amz_date;
let string_to_sign = sig_v4::create_string_to_sign(&canonical_request, amz_date, region, service);

Expand All @@ -238,14 +256,16 @@ impl SignatureContext<'_> {
return Err(s3_error!(SignatureDoesNotMatch));
}

Ok(Credentials {
Ok(CredentialsExt {
access_key: access_key.into(),
secret_key,
region: Some(region.into()),
service: Some(service.into()),
})
}

#[tracing::instrument(skip(self))]
pub async fn v4_check_header_auth(&mut self) -> S3Result<Credentials> {
pub async fn v4_check_header_auth(&mut self) -> S3Result<CredentialsExt> {
let authorization: AuthorizationV4<'_> = {
// assume: headers has "authorization"
let mut a = extract_authorization_v4(&self.hs)?.unwrap();
Expand Down Expand Up @@ -344,14 +364,16 @@ impl SignatureContext<'_> {
self.transformed_body = Some(Body::from(stream.into_byte_stream()));
}

Ok(Credentials {
Ok(CredentialsExt {
access_key: access_key.into(),
secret_key,
region: Some(region.into()),
service: Some(service.into()),
})
}

#[tracing::instrument(skip(self))]
pub async fn v2_check(&mut self) -> Option<S3Result<Credentials>> {
pub async fn v2_check(&mut self) -> Option<S3Result<CredentialsExt>> {
if let Some(qs) = self.qs {
if qs.has("Signature") {
debug!("checking presigned url");
Expand All @@ -369,7 +391,7 @@ impl SignatureContext<'_> {
None
}

pub async fn v2_check_header_auth(&mut self, auth_v2: AuthorizationV2<'_>) -> S3Result<Credentials> {
pub async fn v2_check_header_auth(&mut self, auth_v2: AuthorizationV2<'_>) -> S3Result<CredentialsExt> {
let method = &self.req_method;

let date = self.hs.get_unique("date").or_else(|| self.hs.get_unique("x-amz-date"));
Expand Down Expand Up @@ -399,13 +421,15 @@ impl SignatureContext<'_> {
return Err(s3_error!(SignatureDoesNotMatch));
}

Ok(Credentials {
Ok(CredentialsExt {
access_key: access_key.into(),
secret_key,
region: None,
service: Some("s3".into()),
})
}

pub async fn v2_check_presigned_url(&mut self) -> S3Result<Credentials> {
pub async fn v2_check_presigned_url(&mut self) -> S3Result<CredentialsExt> {
let qs = self.qs.unwrap(); // assume: qs has "Signature"
let presigned_url = PresignedUrlV2::parse(qs).map_err(|err| invalid_request!(err, "missing presigned url v2 fields"))?;

Expand Down Expand Up @@ -433,9 +457,11 @@ impl SignatureContext<'_> {
return Err(s3_error!(SignatureDoesNotMatch));
}

Ok(Credentials {
Ok(CredentialsExt {
access_key: access_key.into(),
secret_key,
region: None,
service: Some("s3".into()),
})
}
}
48 changes: 29 additions & 19 deletions crates/s3s/src/request.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,49 +7,59 @@ use stdx::default::default;
#[derive(Debug)]
#[non_exhaustive]
pub struct S3Request<T> {
/// HTTP method
pub method: Method,

// Raw URI
pub uri: Uri,

// Headers
pub headers: HeaderMap<HeaderValue>,

/// Operation input
pub input: T,

/// Identity information.
///
/// `None` means anonymous request.
pub credentials: Option<Credentials>,

/// Request extensions
///
/// It is used to pass custom data between middlewares.
pub extensions: Extensions,

// Headers
pub headers: HeaderMap<HeaderValue>,
/// Identity information.
///
/// `None` means anonymous request.
pub credentials: Option<Credentials>,

// Raw URI
pub uri: Uri,
/// The requested region.
pub region: Option<String>,

/// HTTP method
pub method: Method,
/// The requested service.
pub service: Option<String>,
}

impl<T> S3Request<T> {
pub fn new(input: T) -> Self {
Self {
method: default(),
uri: default(),
headers: default(),
input,
credentials: default(),
extensions: default(),
headers: default(),
uri: default(),
method: default(),
credentials: default(),
region: default(),
service: default(),
}
}

pub fn map_input<U>(self, f: impl FnOnce(T) -> U) -> S3Request<U> {
S3Request {
method: self.method,
uri: self.uri,
headers: self.headers,
input: f(self.input),
credentials: self.credentials,
extensions: self.extensions,
headers: self.headers,
uri: self.uri,
method: self.method,
credentials: self.credentials,
region: self.region,
service: self.service,
}
}
}

0 comments on commit 09fae30

Please sign in to comment.