Skip to content

Commit

Permalink
Reimplemented JwtAuthentication with struct-based Future. (#868)
Browse files Browse the repository at this point in the history
* Reimplemented JwtAuthentication with struct-based Future.

* More effective encoding

* Remove explicit lifetime

* Code cleanup

* Code cleanup

---------

Co-authored-by: root <root@razor.localdomain>
  • Loading branch information
arturaz and root authored May 11, 2023
1 parent 3ca63c6 commit d4322be
Showing 1 changed file with 115 additions and 45 deletions.
160 changes: 115 additions & 45 deletions common/src/backends/auth.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use std::{convert::Infallible, future::Future, pin::Pin, sync::Arc};
use std::{convert::Infallible, future::Future, pin::Pin, sync::Arc, task::Poll};

use async_trait::async_trait;
use bytes::Bytes;
Expand All @@ -8,6 +8,7 @@ use http_body::combinators::UnsyncBoxBody;
use hyper::{body, Body, Client};
use opentelemetry::global;
use opentelemetry_http::HeaderInjector;
use pin_project::pin_project;
use serde::{Deserialize, Serialize};
use thiserror::Error;
use tower::{Layer, Service};
Expand Down Expand Up @@ -211,6 +212,106 @@ pub struct JwtAuthentication<S, F> {
public_key_fn: F,
}

type AsyncTraitFuture<A> = Pin<Box<dyn Future<Output = A> + Send>>;

#[pin_project(project = JwtAuthenticationFutureProj, project_replace = JwtAuthenticationFutureProjOwn)]
pub enum JwtAuthenticationFuture<
PubKeyFn: PublicKeyFn,
TService: Service<Request<Body>, Response = Response<UnsyncBoxBody<Bytes, ResponseError>>>,
ResponseError,
> {
// If there was an error return a BAD_REQUEST.
Error,

WaitForFuture {
#[pin]
future: TService::Future,
},

// We have a token and need to run our logic.
HasTokenWaitingForPublicKey {
bearer: Authorization<Bearer>,
request: Request<Body>,
#[pin]
public_key_future: AsyncTraitFuture<Result<Vec<u8>, PubKeyFn::Error>>,
service: TService,
},
}

impl<PubKeyFn, TService, ResponseError> Future
for JwtAuthenticationFuture<PubKeyFn, TService, ResponseError>
where
PubKeyFn: PublicKeyFn + 'static,
TService: Service<Request<Body>, Response = Response<UnsyncBoxBody<Bytes, ResponseError>>>,
{
type Output = Result<TService::Response, TService::Error>;

fn poll(mut self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll<Self::Output> {
match self.as_mut().project() {
JwtAuthenticationFutureProj::Error => {
let response = Response::builder()
.status(StatusCode::BAD_REQUEST)
.body(Default::default())
.unwrap();
Poll::Ready(Ok(response))
}
JwtAuthenticationFutureProj::WaitForFuture { future } => future.poll(cx),
JwtAuthenticationFutureProj::HasTokenWaitingForPublicKey {
bearer,
public_key_future,
..
} => {
match public_key_future.poll(cx) {
Poll::Pending => Poll::Pending,
Poll::Ready(Err(error)) => {
error!(
error = &error as &dyn std::error::Error,
"failed to get public key from auth service"
);
let response = Response::builder()
.status(StatusCode::SERVICE_UNAVAILABLE)
.body(Default::default())
.unwrap();

Poll::Ready(Ok(response))
}
Poll::Ready(Ok(public_key)) => {
let claim_result = Claim::from_token(bearer.token().trim(), &public_key);
match claim_result {
Err(code) => {
error!(code = %code, "failed to decode JWT");

let response = Response::builder()
.status(code)
.body(Default::default())
.unwrap();

Poll::Ready(Ok(response))
}
Ok(claim) => {
let owned = self
.as_mut()
.project_replace(JwtAuthenticationFuture::Error);
match owned {
JwtAuthenticationFutureProjOwn::HasTokenWaitingForPublicKey {
mut request, mut service, ..
} => {
request.extensions_mut().insert(claim);
let future = service.call(request);
self.as_mut().set(JwtAuthenticationFuture::WaitForFuture { future });
self.poll(cx)
},
_ => unreachable!("We know that we're in the 'HasTokenWaitingForPublicKey' state"),
}
}
}
}
}
}
}
}
}

impl<S, F, ResponseError> Service<Request<Body>> for JwtAuthentication<S, F>
where
S: Service<Request<Body>, Response = Response<UnsyncBoxBody<Bytes, ResponseError>>>
Expand All @@ -219,11 +320,11 @@ where
+ 'static,
S::Future: Send + 'static,
F: PublicKeyFn + 'static,
<F as PublicKeyFn>::Error: 'static,
{
type Response = S::Response;
type Error = S::Error;
type Future =
Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send + 'static>>;
type Future = JwtAuthenticationFuture<F, S, ResponseError>;

fn poll_ready(
&mut self,
Expand All @@ -232,55 +333,24 @@ where
self.inner.poll_ready(cx)
}

fn call(&mut self, mut req: Request<Body>) -> Self::Future {
fn call(&mut self, req: Request<Body>) -> Self::Future {
match req.headers().typed_try_get::<Authorization<Bearer>>() {
Ok(Some(bearer)) => {
let mut this = self.clone();

Box::pin(async move {
match this.public_key_fn.public_key().await {
Ok(public_key) => {
match Claim::from_token(bearer.token().trim(), &public_key) {
Ok(claim) => {
req.extensions_mut().insert(claim);

this.inner.call(req).await
}
Err(code) => {
error!(code = %code, "failed to decode JWT");

Ok(Response::builder()
.status(code)
.body(Default::default())
.unwrap())
}
}
}
Err(error) => {
error!(
error = &error as &dyn std::error::Error,
"failed to get public key from auth service"
);

Ok(Response::builder()
.status(StatusCode::SERVICE_UNAVAILABLE)
.body(Default::default())
.unwrap())
}
}
})
let public_key_fn = self.public_key_fn.clone();
let public_key_future = Box::pin(async move { public_key_fn.public_key().await });
Self::Future::HasTokenWaitingForPublicKey {
bearer,
request: req,
public_key_future,
service: self.inner.clone(),
}
}
Ok(None) => {
let future = self.inner.call(req);

Box::pin(async move { future.await })
Self::Future::WaitForFuture { future }
}
Err(_) => Box::pin(async move {
Ok(Response::builder()
.status(StatusCode::BAD_REQUEST)
.body(Default::default())
.unwrap())
}),
Err(_) => Self::Future::Error,
}
}
}
Expand Down

0 comments on commit d4322be

Please sign in to comment.