Skip to content

Commit

Permalink
axum: allow body types other than axum::body::Body in Services pa…
Browse files Browse the repository at this point in the history
…ssed to `serve`
  • Loading branch information
mladedav committed Feb 7, 2025
1 parent 0e6e96f commit 2aef075
Show file tree
Hide file tree
Showing 2 changed files with 66 additions and 19 deletions.
5 changes: 5 additions & 0 deletions axum/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,11 @@ All notable changes to this project will be documented in this file.
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).

# Unreleased

- **changed:** `serve` has an additional generic argument and can now work with any response body
type, not just `axum::body::Body` ([3205])

# 0.8.2

- **added:** Implement `OptionalFromRequest` for `Json` ([#3142])
Expand Down
80 changes: 61 additions & 19 deletions axum/src/serve/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
use std::{
convert::Infallible,
error::Error as StdError,
fmt::Debug,
future::{poll_fn, Future, IntoFuture},
io,
Expand All @@ -11,6 +12,7 @@ use std::{

use axum_core::{body::Body, extract::Request, response::Response};
use futures_util::{pin_mut, FutureExt};
use http_body::Body as HttpBody;
use hyper::body::Incoming;
use hyper_util::rt::{TokioExecutor, TokioIo};
#[cfg(any(feature = "http1", feature = "http2"))]
Expand Down Expand Up @@ -94,12 +96,15 @@ pub use self::listener::{Listener, ListenerExt, TapIo};
/// [`HandlerWithoutStateExt::into_make_service_with_connect_info`]: crate::handler::HandlerWithoutStateExt::into_make_service_with_connect_info
/// [`HandlerService::into_make_service_with_connect_info`]: crate::handler::HandlerService::into_make_service_with_connect_info
#[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))]
pub fn serve<L, M, S>(listener: L, make_service: M) -> Serve<L, M, S>
pub fn serve<L, M, S, B>(listener: L, make_service: M) -> Serve<L, M, S, B>
where
L: Listener,
M: for<'a> Service<IncomingStream<'a, L>, Error = Infallible, Response = S>,
S: Service<Request, Response = Response, Error = Infallible> + Clone + Send + 'static,
S: Service<Request, Response = Response<B>, Error = Infallible> + Clone + Send + 'static,
S::Future: Send,
B: HttpBody + Send + 'static,
B::Data: Send,
B::Error: Into<Box<dyn StdError + Send + Sync>>,
{
Serve {
listener,
Expand All @@ -111,14 +116,14 @@ where
/// Future returned by [`serve`].
#[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))]
#[must_use = "futures must be awaited or polled"]
pub struct Serve<L, M, S> {
pub struct Serve<L, M, S, B> {
listener: L,
make_service: M,
_marker: PhantomData<S>,
_marker: PhantomData<(S, B)>,
}

#[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))]
impl<L, M, S> Serve<L, M, S>
impl<L, M, S, B> Serve<L, M, S, B>
where
L: Listener,
{
Expand Down Expand Up @@ -148,7 +153,7 @@ where
///
/// Similarly to [`serve`], although this future resolves to `io::Result<()>`, it will never
/// error. It returns `Ok(())` only after the `signal` future completes.
pub fn with_graceful_shutdown<F>(self, signal: F) -> WithGracefulShutdown<L, M, S, F>
pub fn with_graceful_shutdown<F>(self, signal: F) -> WithGracefulShutdown<L, M, S, F, B>
where
F: Future<Output = ()> + Send + 'static,
{
Expand All @@ -167,7 +172,7 @@ where
}

#[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))]
impl<L, M, S> Debug for Serve<L, M, S>
impl<L, M, S, B> Debug for Serve<L, M, S, B>
where
L: Debug + 'static,
M: Debug,
Expand All @@ -188,14 +193,17 @@ where
}

#[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))]
impl<L, M, S> IntoFuture for Serve<L, M, S>
impl<L, M, S, B> IntoFuture for Serve<L, M, S, B>
where
L: Listener,
L::Addr: Debug,
M: for<'a> Service<IncomingStream<'a, L>, Error = Infallible, Response = S> + Send + 'static,
for<'a> <M as Service<IncomingStream<'a, L>>>::Future: Send,
S: Service<Request, Response = Response, Error = Infallible> + Clone + Send + 'static,
S: Service<Request, Response = Response<B>, Error = Infallible> + Clone + Send + 'static,
S::Future: Send,
B: HttpBody + Send + 'static,
B::Data: Send,
B::Error: Into<Box<dyn StdError + Send + Sync>>,
{
type Output = io::Result<()>;
type IntoFuture = private::ServeFuture;
Expand All @@ -209,15 +217,15 @@ where
/// Serve future with graceful shutdown enabled.
#[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))]
#[must_use = "futures must be awaited or polled"]
pub struct WithGracefulShutdown<L, M, S, F> {
pub struct WithGracefulShutdown<L, M, S, F, B> {
listener: L,
make_service: M,
signal: F,
_marker: PhantomData<S>,
_marker: PhantomData<(S, B)>,
}

#[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))]
impl<L, M, S, F> WithGracefulShutdown<L, M, S, F>
impl<L, M, S, F, B> WithGracefulShutdown<L, M, S, F, B>
where
L: Listener,
{
Expand All @@ -228,7 +236,7 @@ where
}

#[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))]
impl<L, M, S, F> Debug for WithGracefulShutdown<L, M, S, F>
impl<L, M, S, F, B> Debug for WithGracefulShutdown<L, M, S, F, B>
where
L: Debug + 'static,
M: Debug,
Expand All @@ -252,15 +260,18 @@ where
}

#[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))]
impl<L, M, S, F> IntoFuture for WithGracefulShutdown<L, M, S, F>
impl<L, M, S, F, B> IntoFuture for WithGracefulShutdown<L, M, S, F, B>
where
L: Listener,
L::Addr: Debug,
M: for<'a> Service<IncomingStream<'a, L>, Error = Infallible, Response = S> + Send + 'static,
for<'a> <M as Service<IncomingStream<'a, L>>>::Future: Send,
S: Service<Request, Response = Response, Error = Infallible> + Clone + Send + 'static,
S: Service<Request, Response = Response<B>, Error = Infallible> + Clone + Send + 'static,
S::Future: Send,
F: Future<Output = ()> + Send + 'static,
B: HttpBody + Send + 'static,
B::Data: Send,
B::Error: Into<Box<dyn StdError + Send + Sync>>,
{
type Output = io::Result<()>;
type IntoFuture = private::ServeFuture;
Expand All @@ -274,15 +285,18 @@ where
}

#[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))]
impl<L, M, S, F> WithGracefulShutdown<L, M, S, F>
impl<L, M, S, F, B> WithGracefulShutdown<L, M, S, F, B>
where
L: Listener,
L::Addr: Debug,
M: for<'a> Service<IncomingStream<'a, L>, Error = Infallible, Response = S> + Send + 'static,
for<'a> <M as Service<IncomingStream<'a, L>>>::Future: Send,
S: Service<Request, Response = Response, Error = Infallible> + Clone + Send + 'static,
S: Service<Request, Response = Response<B>, Error = Infallible> + Clone + Send + 'static,
S::Future: Send,
F: Future<Output = ()> + Send + 'static,
B: HttpBody + Send + 'static,
B::Data: Send,
B::Error: Into<Box<dyn StdError + Send + Sync>>,
{
async fn run(self) {
let Self {
Expand Down Expand Up @@ -439,14 +453,15 @@ mod tests {
};

use axum_core::{body::Body, extract::Request};
use http::StatusCode;
use http::{Response, StatusCode};
use hyper_util::rt::TokioIo;
#[cfg(unix)]
use tokio::net::UnixListener;
use tokio::{
io::{self, AsyncRead, AsyncWrite},
net::TcpListener,
};
use tower::ServiceBuilder;

#[cfg(unix)]
use super::IncomingStream;
Expand All @@ -458,7 +473,7 @@ mod tests {
handler::{Handler, HandlerWithoutStateExt},
routing::get,
serve::ListenerExt,
Router,
Router, ServiceExt,
};

#[allow(dead_code, unused_must_use)]
Expand Down Expand Up @@ -686,4 +701,31 @@ mod tests {
let body = String::from_utf8(body.to_vec()).unwrap();
assert_eq!(body, "Hello, World!");
}

#[crate::test]
async fn serving_with_custom_body_type() {
struct CustomBody;
impl http_body::Body for CustomBody {
type Data = bytes::Bytes;
type Error = std::convert::Infallible;
fn poll_frame(
self: std::pin::Pin<&mut Self>,
_cx: &mut std::task::Context<'_>,
) -> std::task::Poll<Option<Result<http_body::Frame<Self::Data>, Self::Error>>>
{
#![allow(clippy::unreachable)] // The implementation is not used, we just need to provide one.
unreachable!();
}
}

let app = ServiceBuilder::new()
.layer_fn(|_| tower::service_fn(|_| std::future::ready(Ok(Response::new(CustomBody)))))
.service(Router::<()>::new().route("/hello", get(|| async {})));
let addr = "0.0.0.0:0";

_ = serve(
TcpListener::bind(addr).await.unwrap(),
app.into_make_service(),
);
}
}

0 comments on commit 2aef075

Please sign in to comment.