Skip to content

Commit

Permalink
Merge pull request #6 from maxcountryman/feat/support-regenerate
Browse files Browse the repository at this point in the history
initial sketch for `session.regenerate()` support
  • Loading branch information
maxcountryman authored Aug 8, 2022
2 parents 0b9af40 + 848d364 commit 4948e30
Show file tree
Hide file tree
Showing 7 changed files with 212 additions and 33 deletions.
5 changes: 5 additions & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,11 @@ features = ["headers"]
version = "0.3.4"
features = ["cookie-signed"]

[dependencies.tokio]
version = "1.20.1"
default-features = false
features = ["sync"]

[dev-dependencies]
http = "0.2.8"
hyper = "0.14.19"
Expand Down
4 changes: 4 additions & 0 deletions examples/Cargo.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
[workspace]
members = ["*"]
exclude = ["target"]
resolver = "2"
17 changes: 17 additions & 0 deletions examples/regenerate/Cargo.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
[package]
name = "example-regenerate"
version = "0.1.0"
edition = "2021"
publish = false

[dependencies]
axum = "0.5.13"
axum-sessions = { path = "../../" }

[dependencies.rand]
version = "0.8.5"
features = ["min_const_gen"]

[dependencies.tokio]
version = "1.0"
features = ["full"]
44 changes: 44 additions & 0 deletions examples/regenerate/src/main.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
use axum::{routing::get, Router};
use axum_sessions::{
async_session::MemoryStore,
extractors::{ReadableSession, WritableSession},
SessionLayer,
};
use rand::Rng;

#[tokio::main]
async fn main() {
let store = MemoryStore::new();
let secret = rand::thread_rng().gen::<[u8; 128]>();
let session_layer = SessionLayer::new(store, &secret);

async fn regenerate_handler(mut session: WritableSession) {
// NB: This DOES NOT update the store, meaning that both sessions will still be
// found.
session.regenerate();
}

async fn insert_handler(mut session: WritableSession) {
session
.insert("foo", 42)
.expect("Could not store the answer.");
}

async fn handler(session: ReadableSession) -> String {
session
.get::<usize>("foo")
.map(|answer| format!("{}", answer))
.unwrap_or_else(|| "Nothing in session yet; try /insert.".to_string())
}

let app = Router::new()
.route("/regenerate", get(regenerate_handler))
.route("/insert", get(insert_handler))
.route("/", get(handler))
.layer(session_layer);

axum::Server::bind(&"0.0.0.0:3000".parse().unwrap())
.serve(app.into_make_service())
.await
.unwrap();
}
78 changes: 78 additions & 0 deletions src/extractors.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
use std::ops::{Deref, DerefMut};

use axum::{
async_trait,
extract::{FromRequest, RequestParts},
http, Extension,
};
use tokio::sync::{OwnedRwLockReadGuard, OwnedRwLockWriteGuard};

use crate::SessionHandle;

/// An extractor which provides a readable session. Sessions may have many
/// readers.
pub struct ReadableSession {
session: OwnedRwLockReadGuard<async_session::Session>,
}

impl Deref for ReadableSession {
type Target = OwnedRwLockReadGuard<async_session::Session>;

fn deref(&self) -> &Self::Target {
&self.session
}
}

#[async_trait]
impl<B> FromRequest<B> for ReadableSession
where
B: Send,
{
type Rejection = std::convert::Infallible;

async fn from_request(request: &mut RequestParts<B>) -> Result<Self, Self::Rejection> {
let Extension(session_handle): Extension<SessionHandle> = Extension::from_request(request)
.await
.expect("Session extension missing. Is the session layer installed?");
let session = session_handle.read_owned().await;

Ok(Self { session })
}
}

/// An extractor which provides a writable session. Sessions may have only one
/// writer.
pub struct WritableSession {
session: OwnedRwLockWriteGuard<async_session::Session>,
}

impl Deref for WritableSession {
type Target = OwnedRwLockWriteGuard<async_session::Session>;

fn deref(&self) -> &Self::Target {
&self.session
}
}

impl DerefMut for WritableSession {
fn deref_mut(&mut self) -> &mut Self::Target {
&mut self.session
}
}

#[async_trait]
impl<B> FromRequest<B> for WritableSession
where
B: Send,
{
type Rejection = std::convert::Infallible;

async fn from_request(request: &mut RequestParts<B>) -> Result<Self, Self::Rejection> {
let Extension(session_handle): Extension<SessionHandle> = Extension::from_request(request)
.await
.expect("Session extension missing. Is the session layer installed?");
let session = session_handle.write_owned().await;

Ok(Self { session })
}
}
40 changes: 21 additions & 19 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,31 +3,31 @@
//!
//! [`SessionLayer`] provides client sessions via [`async_session`]. Sessions
//! are backed by cryptographically signed cookies. These cookies are generated
//! when they're not found or otherwise invalid. When a valid, known
//! cookie is received in a request, the session is hydrated from this cookie.
//! The middleware leverages [`http::Extensions`](axum::http::Extensions) to
//! attach an [`async_session::Session`] to the request. Request handlers can
//! then interact with the session.
//! when they're not found or are otherwise invalid. When a valid, known cookie
//! is received in a request, the session is hydrated from this cookie. The
//! middleware provides sessions via [`SessionHandle`]. Handlers use the
//! [`ReadableSession`](crate::extractors::ReadableSession) and
//! [`WritableSession`](crate::extractors::WritableSession) extractors to read
//! from and write to sessions respectively.
//!
//! # Example
//!
//! Using the middleware with axum is straightforward:
//!
//! ```rust,no_run
//! use axum::{routing::get, Extension, Router};
//! use axum_sessions::{
//! async_session::{MemoryStore, Session},
//! SessionLayer,
//! };
//! use axum::{routing::get, Router};
//! use axum_sessions::{async_session::MemoryStore, extractors::WritableSession, SessionLayer};
//!
//! #[tokio::main]
//! async fn main() {
//! let store = async_session::MemoryStore::new();
//! let secret = b"..."; // MUST be at least 64 bytes!
//! let session_layer = SessionLayer::new(store, secret);
//!
//! async fn handler(Extension(session): Extension<Session>) {
//! // Use the session in your handler...
//! async fn handler(mut session: WritableSession) {
//! session
//! .insert("foo", 42)
//! .expect("Could not store the answer.");
//! }
//!
//! let app = Router::new().route("/", get(handler)).layer(session_layer);
Expand All @@ -39,23 +39,24 @@
//! }
//! ```
//!
//! This middleware may also be used as a generic Tower middleware:
//! This middleware may also be used as a generic Tower middleware by making use
//! of the [`SessionHandle`] extension:
//!
//! ```rust
//! use std::convert::Infallible;
//!
//! use axum::http::header::SET_COOKIE;
//! use axum_sessions::SessionLayer;
//! use axum_sessions::{extractors::WritableSession, SessionHandle, SessionLayer};
//! use http::{Request, Response};
//! use hyper::Body;
//! use rand::Rng;
//! use tower::{Service, ServiceBuilder, ServiceExt};
//!
//! async fn handle(request: Request<Body>) -> Result<Response<Body>, Infallible> {
//! assert!(request
//! .extensions()
//! .get::<async_session::Session>()
//! .is_some());
//! let session_handle = request.extensions().get::<SessionHandle>().unwrap();
//! let session = session_handle.read().await;
//! // Use the session as you'd like.
//!
//! Ok(Response::new(Body::empty()))
//! }
//!
Expand Down Expand Up @@ -88,9 +89,10 @@
//! # Ok(())
//! # }
//! ```
pub mod extractors;
mod session;

pub use async_session;
pub use axum_extra::extract::cookie::SameSite;

pub use self::session::{Session, SessionLayer};
pub use self::session::{Session, SessionHandle, SessionLayer};
57 changes: 43 additions & 14 deletions src/session.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
// `tide::sessions::middleware::SessionMiddleware`. See: https://github.com/http-rs/tide/blob/20fe435a9544c10f64245e883847fc3cd1d50538/src/sessions/middleware.rs

use std::{
sync::Arc,
task::{Context, Poll},
time::Duration,
};
Expand All @@ -21,10 +22,22 @@ use axum::{
};
use axum_extra::extract::cookie::{Cookie, Key, SameSite};
use futures::future::BoxFuture;
use tokio::sync::RwLock;
use tower::{Layer, Service};

const BASE64_DIGEST_LEN: usize = 44;

/// A type alias which provides a handle to the underlying session.
///
/// This is provided via [`http::Extensions`](axum::http::Extensions). Most
/// applications will use the
/// [`ReadableSession`](crate::extractors::ReadableSession) and
/// [`WritableSession`](crate::extractors::WritableSession) extractors rather
/// than using the handle directly. A notable exception is when using this
/// library as a generic Tower middleware: such use cases will consume the
/// handle directly.
pub type SessionHandle = Arc<RwLock<async_session::Session>>;

#[derive(Clone)]
pub struct SessionLayer<Store> {
store: Store,
Expand All @@ -39,11 +52,11 @@ pub struct SessionLayer<Store> {
}

impl<Store: SessionStore> SessionLayer<Store> {
/// Creates a layer which will attach an [`async_session::Session`] to
/// requests via an extension. This session is derived from a
/// cryptographically signed cookie. When the client sends a valid,
/// known cookie then the session is hydrated from this. Otherwise a new
/// cookie is created and returned in the response.
/// Creates a layer which will attach a [`SessionHandle`] to requests via an
/// extension. This session is derived from a cryptographically signed
/// cookie. When the client sends a valid, known cookie then the session is
/// hydrated from this. Otherwise a new cookie is created and returned in
/// the response.
///
/// # Panics
///
Expand Down Expand Up @@ -134,15 +147,17 @@ impl<Store: SessionStore> SessionLayer<Store> {
self
}

async fn load_or_create(&self, cookie_value: Option<String>) -> async_session::Session {
async fn load_or_create(&self, cookie_value: Option<String>) -> SessionHandle {
let session = match cookie_value {
Some(cookie_value) => self.store.load_session(cookie_value).await.ok().flatten(),
None => None,
};

session
.and_then(|session| session.validate())
.unwrap_or_default()
Arc::new(RwLock::new(
session
.and_then(async_session::Session::validate)
.unwrap_or_default(),
))
}

fn build_cookie(&self, secure: bool, cookie_value: String) -> Cookie<'static> {
Expand Down Expand Up @@ -270,16 +285,28 @@ where

let mut inner = self.inner.clone();
Box::pin(async move {
let mut session = session_layer.load_or_create(cookie_value).await;
let session_handle = session_layer.load_or_create(cookie_value).await;

let mut session = session_handle.write().await;
if let Some(ttl) = session_layer.session_ttl {
session.expire_in(ttl);
(*session).expire_in(ttl);
}
drop(session);

request.extensions_mut().insert(session.clone());
request.extensions_mut().insert(session_handle.clone());
let mut response = inner.call(request).await?;

if session.is_destroyed() {
let session = session_handle.read().await;
let (session_is_destroyed, session_data_changed) =
(session.is_destroyed(), session.data_changed());
drop(session);

// Pull out the session so we can pass it to the store without `Clone` blowing
// away the `cookie_value`.
let session = RwLock::into_inner(
Arc::try_unwrap(session_handle).expect("Session handle still has owners."),
);
if session_is_destroyed {
if let Err(e) = session_layer.store.destroy_session(session).await {
tracing::error!("Failed to destroy session: {:?}", e);
*response.status_mut() = StatusCode::INTERNAL_SERVER_ERROR;
Expand All @@ -291,7 +318,7 @@ where
SET_COOKIE,
HeaderValue::from_str(&removal_cookie.to_string()).unwrap(),
);
} else if session_layer.save_unchanged || session.data_changed() {
} else if session_layer.save_unchanged || session_data_changed {
match session_layer.store.store_session(session).await {
Ok(Some(cookie_value)) => {
let cookie = session_layer.build_cookie(session_layer.secure, cookie_value);
Expand All @@ -300,7 +327,9 @@ where
HeaderValue::from_str(&cookie.to_string()).unwrap(),
);
}

Ok(None) => {}

Err(e) => {
tracing::error!("Failed to reach session storage: {:?}", e);
*response.status_mut() = StatusCode::INTERNAL_SERVER_ERROR;
Expand Down

0 comments on commit 4948e30

Please sign in to comment.