Skip to content

Commit

Permalink
feat: add {http1,http2}_only for auto conn
Browse files Browse the repository at this point in the history
  • Loading branch information
dswij committed Mar 17, 2024
1 parent 16daef6 commit b29fc46
Showing 1 changed file with 143 additions and 4 deletions.
147 changes: 143 additions & 4 deletions src/server/conn/auto.rs
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,8 @@ pub struct Builder<E> {
http1: http1::Builder,
#[cfg(feature = "http2")]
http2: http2::Builder<E>,
#[cfg(any(feature = "http1", feature = "http2"))]
version: Option<Version>,
#[cfg(not(feature = "http2"))]
_executor: E,
}
Expand All @@ -84,6 +86,8 @@ impl<E> Builder<E> {
http1: http1::Builder::new(),
#[cfg(feature = "http2")]
http2: http2::Builder::new(executor),
#[cfg(any(feature = "http1", feature = "http2"))]
version: None,
#[cfg(not(feature = "http2"))]
_executor: executor,
}
Expand All @@ -101,6 +105,26 @@ impl<E> Builder<E> {
Http2Builder { inner: self }
}

/// Only accepts HTTP/2
///
/// Does not do anything if used with [`serve_connection_with_upgrades`]
#[cfg(feature = "http2")]
pub fn http2_only(mut self) -> Self {
assert!(self.version.is_none());
self.version = Some(Version::H2);
self
}

/// Only accepts HTTP/1
///
/// Does not do anything if used with [`serve_connection_with_upgrades`]
#[cfg(feature = "http1")]
pub fn http1_only(mut self) -> Self {
assert!(self.version.is_none());
self.version = Some(Version::H1);
self
}

/// Bind a connection together with a [`Service`].
pub fn serve_connection<I, S, B>(&self, io: I, service: S) -> Connection<'_, I, S, E>
where
Expand All @@ -112,13 +136,28 @@ impl<E> Builder<E> {
I: Read + Write + Unpin + 'static,
E: HttpServerConnExec<S::Future, B>,
{
Connection {
state: ConnState::ReadVersion {
let state = match self.version {
#[cfg(feature = "http1")]
Some(Version::H1) => {
let io = Rewind::new_buffered(io, Bytes::new());
let conn = self.http1.serve_connection(io, service);
ConnState::H1 { conn }
}
#[cfg(feature = "http2")]
Some(Version::H2) => {
let io = Rewind::new_buffered(io, Bytes::new());
let conn = self.http2.serve_connection(io, service);
ConnState::H2 { conn }
}
#[cfg(any(feature = "http1", feature = "http2"))]
_ => ConnState::ReadVersion {
read_version: read_version(io),
builder: self,
service: Some(service),
},
}
};

Connection { state }
}

/// Bind a connection together with a [`Service`], with the ability to
Expand Down Expand Up @@ -148,7 +187,7 @@ impl<E> Builder<E> {
}
}

#[derive(Copy, Clone)]
#[derive(Copy, Clone, Debug)]
enum Version {
H1,
H2,
Expand Down Expand Up @@ -894,6 +933,62 @@ mod tests {
assert_eq!(body, BODY);
}

#[cfg(not(miri))]
#[tokio::test]
async fn http2_only() {
let addr = start_server_h2_only().await;
let mut sender = connect_h2(addr).await;

let response = sender
.send_request(Request::new(Empty::<Bytes>::new()))
.await
.unwrap();

let body = response.into_body().collect().await.unwrap().to_bytes();

assert_eq!(body, BODY);
}

#[cfg(not(miri))]
#[tokio::test]
async fn http2_only_fail_if_client_is_http1() {
let addr = start_server_h2_only().await;
let mut sender = connect_h1(addr).await;

let _ = sender
.send_request(Request::new(Empty::<Bytes>::new()))
.await
.expect_err("should fail");
}

#[cfg(not(miri))]
#[tokio::test]
async fn http1_only() {
let addr = start_server_h1_only().await;
let mut sender = connect_h1(addr).await;

let response = sender
.send_request(Request::new(Empty::<Bytes>::new()))
.await
.unwrap();

let body = response.into_body().collect().await.unwrap().to_bytes();

assert_eq!(body, BODY);
}

#[cfg(not(miri))]
#[tokio::test]
async fn http1_only_fail_if_client_is_http2() {
let addr = start_server_h1_only().await;
let mut sender = connect_h2(addr).await;

let _ = sender
.send_request(Request::new(Empty::<Bytes>::new()))
.await
.expect_err("should fail");
}

#[cfg(not(miri))]
#[tokio::test]
async fn graceful_shutdown() {
Expand Down Expand Up @@ -980,6 +1075,50 @@ mod tests {
local_addr
}

async fn start_server_h2_only() -> SocketAddr {
let addr: SocketAddr = ([127, 0, 0, 1], 0).into();
let listener = TcpListener::bind(addr).await.unwrap();

let local_addr = listener.local_addr().unwrap();

tokio::spawn(async move {
loop {
let (stream, _) = listener.accept().await.unwrap();
let stream = TokioIo::new(stream);
tokio::task::spawn(async move {
let _ = auto::Builder::new(TokioExecutor::new())
.http2_only()
.serve_connection(stream, service_fn(hello))
.await;
});
}
});

local_addr
}

async fn start_server_h1_only() -> SocketAddr {
let addr: SocketAddr = ([127, 0, 0, 1], 0).into();
let listener = TcpListener::bind(addr).await.unwrap();

let local_addr = listener.local_addr().unwrap();

tokio::spawn(async move {
loop {
let (stream, _) = listener.accept().await.unwrap();
let stream = TokioIo::new(stream);
tokio::task::spawn(async move {
let _ = auto::Builder::new(TokioExecutor::new())
.http1_only()
.serve_connection(stream, service_fn(hello))
.await;
});
}
});

local_addr
}

async fn hello(_req: Request<body::Incoming>) -> Result<Response<Full<Bytes>>, Infallible> {
Ok(Response::new(Full::new(Bytes::from(BODY))))
}
Expand Down

0 comments on commit b29fc46

Please sign in to comment.