Skip to content

Commit

Permalink
feat: update to support hyper 1.0 (#4)
Browse files Browse the repository at this point in the history
  • Loading branch information
ravenclaw900 authored Jan 1, 2024
1 parent 55b84f0 commit 868c23e
Show file tree
Hide file tree
Showing 7 changed files with 229 additions and 258 deletions.
17 changes: 12 additions & 5 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -9,14 +9,21 @@ keywords = ["hyper", "tls", "http", "https", "rustls"]
categories = ["network-programming"]
exclude = ["examples/certs"]

[features]
default = ["rustls_helpers"]
rustls_helpers = ["dep:rustls-pemfile"]

[dependencies]
futures-util = { version = "0.3.28", default-features = false, features = ["std"] }
hyper = { version = "0.14.28", features = ["server", "tcp"] }
rustls-pemfile = "2.0.0"
thiserror = "1.0.51"
# Will figure out how to handle http1 vs. http2 later
hyper = { version = "1.1.0", features = ["server", "http1"] }
hyper-util = { version = "0.1.2", features = ["tokio"] }
rustls-pemfile = { version = "2.0.0", optional = true }
thiserror = "1.0.52"
tls-listener = { version = "0.9.1", features = ["rustls"] }
tokio = { version = "1.35.1", features = ["net", "time"] }
tokio-rustls = "0.25.0"

[dev-dependencies]
hyper = { version = "0.14.28", features = ["http1", "http2"] }
http-body-util = "0.1.0"
hyper = { version = "1.1.0", features = ["http1"] }
tokio = { version = "1.35.1", features = ["rt", "macros"] }
48 changes: 23 additions & 25 deletions examples/hello-world-https.rs
Original file line number Diff line number Diff line change
@@ -1,15 +1,22 @@
use flexible_hyper_server_tls::*;
use hyper::service::{make_service_fn, service_fn};
use hyper::{Body, Request, Response, Server};
use http_body_util::Full;
use hyper::body::{Bytes, Incoming};
use hyper::service::service_fn;
use hyper::{Request, Response};
use std::convert::Infallible;
use std::time::Duration;
use tokio::net::TcpListener;

const CERT_DATA: &str = include_str!(concat!(env!("CARGO_MANIFEST_DIR"), "/examples/certs/cert.pem"));
const KEY_DATA: &str = include_str!(concat!(env!("CARGO_MANIFEST_DIR"), "/examples/certs/key.pem"));

async fn hello_world(_req: Request<Body>) -> Result<Response<Body>, Infallible> {
Ok(Response::new("Hello, World".into()))
const CERT_DATA: &str = include_str!(concat!(
env!("CARGO_MANIFEST_DIR"),
"/examples/certs/cert.pem"
));
const KEY_DATA: &str = include_str!(concat!(
env!("CARGO_MANIFEST_DIR"),
"/examples/certs/key.pem"
));

async fn hello_world(_req: Request<Incoming>) -> Result<Response<Full<Bytes>>, Infallible> {
Ok(Response::new(Full::<Bytes>::from("Hello, World!")))
}

#[tokio::main(flavor = "current_thread")]
Expand All @@ -18,27 +25,18 @@ async fn main() {

let listener = TcpListener::bind("127.0.0.1:8080").await.unwrap();

let make_svc = make_service_fn(|conn: &HttpOrHttpsConnection| {
println!("Remote address: {}", conn.remote_addr());
async { Ok::<_, Infallible>(service_fn(hello_world)) }
});
let builder = AcceptorBuilder::new(listener);

let acceptor = if use_tls {
let tls_acceptor = tlsconfig::get_tlsacceptor_from_pem_data(
CERT_DATA,
KEY_DATA,
tlsconfig::HttpProtocol::Both,
)
.unwrap();
HyperHttpOrHttpsAcceptor::new_https(listener, tls_acceptor, Duration::from_secs(10))
let mut acceptor = if use_tls {
let tls_acceptor =
rustls_helpers::get_tlsacceptor_from_pem_data(CERT_DATA, KEY_DATA).unwrap();
builder.https(tls_acceptor).build()
} else {
HyperHttpOrHttpsAcceptor::new_http(listener)
builder.build()
};

let mut server = Server::builder(acceptor).serve(make_svc);

loop {
let res = (&mut server).await;
eprintln!("Error: {:?}", res);
let peer_addr = acceptor.accept(service_fn(hello_world)).await.unwrap();
println!("Connected peer: {}", peer_addr)
}
}
188 changes: 76 additions & 112 deletions src/accept.rs
Original file line number Diff line number Diff line change
@@ -1,57 +1,89 @@
use futures_util::future::BoxFuture;
use futures_util::stream::FuturesUnordered;
use futures_util::{FutureExt, StreamExt};
use hyper::server::accept::Accept;
use std::pin::Pin;
use std::task::{Context, Poll};
use hyper::server::conn::http1;
use hyper_util::rt::TokioIo;
use std::net::SocketAddr;
use thiserror::Error;

use crate::conn::{ConnKind, HttpOrHttpsConnection};

/// Choose to accept either a HTTP or HTTPS connection
pub struct HyperHttpOrHttpsAcceptor {
listener: tokio::net::TcpListener,
kind: AcceptorKind,
}
///
/// Created by calling the `build` method on an `AcceptorBuilder`
// Use a struct instead of the enum directly to avoid users constructing/matching on enum variants
pub struct HttpOrHttpsAcceptor(pub(crate) AcceptorInner);

enum AcceptorKind {
Http,
Https {
tls_acceptor: tokio_rustls::TlsAcceptor,
timeout: std::time::Duration,
// Future has to be boxed because Rust doesn't allow writing out the full type
// Side benefit of allow us to use Timeout without needing pin projection
encryption_futures: FuturesUnordered<
tokio::time::Timeout<BoxFuture<'static, Result<HttpOrHttpsConnection, AcceptorError>>>,
>,
},
pub enum AcceptorInner {
Http(tokio::net::TcpListener),
Https(tls_listener::TlsListener<tokio::net::TcpListener, tokio_rustls::TlsAcceptor>),
}

impl HyperHttpOrHttpsAcceptor {
/// Create an acceptor that will accept HTTP connections
pub const fn new_http(listener: tokio::net::TcpListener) -> Self {
Self {
listener,
kind: AcceptorKind::Http,
impl HttpOrHttpsAcceptor {
/// Accepts every connection using the service provided, never completes.
/// Ignores any connection errors produced by the `accept` method.
pub async fn serve<S>(&mut self, service: S)
where
S: hyper::service::HttpService<hyper::body::Incoming> + Clone + Send + 'static,
S::Future: Send,
S::ResBody: Send + 'static,
<S::ResBody as hyper::body::Body>::Error: std::error::Error + Send + Sync + 'static,
<S::ResBody as hyper::body::Body>::Data: Send,
{
loop {
// Ignore result here
let _ = self.accept(service.clone()).await;
}
}

/// Create an acceptor that will accept HTTPS connections using the provided `TlsAcceptor`
/// Accepts a singular connection and spawns it onto the tokio runtime.
/// Returns the address of the connected client.
///
/// `handshake_timeout` is the length of time that should be allowed to finish a TLS handshake before we drop the connection.
/// Setting it to 0 will not disable the timeout, but will instead instantly drop every connection (you probably don't want this).
pub fn new_https(
listener: tokio::net::TcpListener,
tls_acceptor: tokio_rustls::TlsAcceptor,
handshake_timeout: std::time::Duration,
) -> Self {
Self {
listener,
kind: AcceptorKind::Https {
tls_acceptor,
timeout: handshake_timeout,
encryption_futures: FuturesUnordered::new(),
},
/// # Errors
/// If the TCP connection or TLS handshake (HTTPS only) fails
// Function won't panic, however tokio worker might
#[allow(clippy::missing_panics_doc)]
pub async fn accept<S>(&mut self, service: S) -> Result<SocketAddr, AcceptorError>
where
S: hyper::service::HttpService<hyper::body::Incoming> + Send + 'static,
S::Future: Send,
S::ResBody: Send + 'static,
<S::ResBody as hyper::body::Body>::Error: std::error::Error + Send + Sync + 'static,
<S::ResBody as hyper::body::Body>::Data: Send,
{
let conn_builder = http1::Builder::new();

match &mut self.0 {
AcceptorInner::Http(listener) => {
let (conn, peer_addr) =
listener.accept().await.map_err(AcceptorError::TcpConnect)?;

let conn = TokioIo::new(conn);

let conn = conn_builder.serve_connection(conn, service);

tokio::spawn(async move { conn.await.unwrap() });

Ok(peer_addr)
}
AcceptorInner::Https(listener) => {
let (conn, peer_addr) = loop {
match listener.accept().await {
Err(tls_listener::Error::ListenerError(e)) => {
return Err(AcceptorError::TcpConnect(e))
}
Err(tls_listener::Error::TlsAcceptError { error, .. }) => {
return Err(AcceptorError::TcpConnect(error))
}
// Ignore handshake timeout errors, just try to get another connection
Err(_) => continue,
Ok(conn_and_addr) => break conn_and_addr,
}
};

let conn = TokioIo::new(conn);

let conn = conn_builder.serve_connection(conn, service);

tokio::spawn(async move { conn.await.unwrap() });

Ok(peer_addr)
}
}
}
}
Expand All @@ -66,71 +98,3 @@ pub enum AcceptorError {
#[error("TLS handshake with client failed")]
TlsHandshake(#[source] std::io::Error),
}

impl Accept for HyperHttpOrHttpsAcceptor {
type Conn = HttpOrHttpsConnection;
type Error = AcceptorError;

fn poll_accept(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<Option<Result<Self::Conn, Self::Error>>> {
// Necessary to allow partial borrows
let this = self.get_mut();

match &mut this.kind {
// If just a normal HTTP connection, just poll to accept the new TCP connection
AcceptorKind::Http => match this.listener.poll_accept(cx) {
Poll::Ready(Ok(stream)) => Poll::Ready(Some(Ok(HttpOrHttpsConnection {
remote_addr: stream.1,
kind: ConnKind::Http(stream.0),
}))),
Poll::Ready(Err(err)) => Poll::Ready(Some(Err(AcceptorError::TcpConnect(err)))),
Poll::Pending => Poll::Pending,
},
// Otherwise, if it's an HTTPS connection, check if we're ready to encrypt the connection
AcceptorKind::Https {
tls_acceptor,
timeout,
encryption_futures,
} => {
// Accept all pending TCP connections at once (this future won't be woken up for TCP unless we get a pending here)
loop {
match this.listener.poll_accept(cx) {
Poll::Ready(Ok(stream)) => {
let tls_future = tls_acceptor
.accept(stream.0)
.map(move |f| {
// Map so that we can pass along the remote address
f.map(|conn| HttpOrHttpsConnection {
remote_addr: stream.1,
kind: ConnKind::Https(conn),
})
.map_err(AcceptorError::TlsHandshake)
})
.boxed();
let timed_tls_future = tokio::time::timeout(*timeout, tls_future);
encryption_futures.push(timed_tls_future);
}
Poll::Ready(Err(err)) => {
return Poll::Ready(Some(Err(AcceptorError::TcpConnect(err))))
}
// Break on pending here so we can check on the TLS queue
Poll::Pending => break,
}
}
// Check queue to see if any handshakes are done/timeouts hit
loop {
match encryption_futures.poll_next_unpin(cx) {
// Already `map`ed to a Result<HttpOrHttpsConnection>, so no need to differentiate
// between Some(Err) and Some(Ok)
Poll::Ready(Some(Ok(res))) => return Poll::Ready(Some(res)),
// An error here means that the timeout ran out, so just skip to the next one in the queue
Poll::Ready(Some(Err(_))) => continue,
_ => return Poll::Pending,
}
}
}
}
}
}
84 changes: 84 additions & 0 deletions src/builder.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
use std::time::Duration;

use tokio::net::TcpListener;
use tokio_rustls::TlsAcceptor;

use crate::accept::{AcceptorInner, HttpOrHttpsAcceptor};

pub struct Http;
pub struct Https {
tls_acceptor: tokio_rustls::TlsAcceptor,
max_handshakes: usize,
timeout: Duration,
}

/// Build an `HttpOrHttpsAcceptor`
///
/// Defaults to accepting HTTP connections, call the `https` method to accept HTTPS connections instead
pub struct AcceptorBuilder<State> {
state: State,
listener: TcpListener,
}

impl AcceptorBuilder<Http> {
/// Create a new builder for an `HttpOrHttpsAcceptor`
///
/// Defaults to accepting HTTP
pub const fn new(listener: TcpListener) -> Self {
Self {
state: Http,
listener,
}
}

/// Converts the builder into accepting HTTPS using the provided `TlsAcceptor`
pub fn https(self, tls_acceptor: TlsAcceptor) -> AcceptorBuilder<Https> {
AcceptorBuilder {
state: Https {
tls_acceptor,
max_handshakes: tls_listener::DEFAULT_MAX_HANDSHAKES,
timeout: tls_listener::DEFAULT_HANDSHAKE_TIMEOUT,
},
listener: self.listener,
}
}

/// Builds an `HttpOrHttpsAcceptor` to accept HTTP connections
pub fn build(self) -> HttpOrHttpsAcceptor {
HttpOrHttpsAcceptor(AcceptorInner::Http(self.listener))
}
}

impl AcceptorBuilder<Https> {
/// Set the maximum number of handshakes that will be processed concurrently
///
/// Defaults to 64
#[must_use]
pub const fn max_handshakes(mut self, num: usize) -> Self {
self.state.max_handshakes = num;
self
}

/// Set the maximum amount of time that a handshake can take before being aborted.
/// Setting it to 0 will not disable the timeout, but will instead instantly drop every connection.
///
/// Defaults to 10 seconds
#[must_use]
pub const fn timeout(mut self, timeout: Duration) -> Self {
self.state.timeout = timeout;
self
}

/// Builds an `HttpOrHttpsAcceptor` to accept HTTPS connections
pub fn build(self) -> HttpOrHttpsAcceptor {
let mut tls_builder = tls_listener::builder(self.state.tls_acceptor);

tls_builder
.max_handshakes(self.state.max_handshakes)
.handshake_timeout(self.state.timeout);

let tls_listener = tls_builder.listen(self.listener);

HttpOrHttpsAcceptor(AcceptorInner::Https(tls_listener))
}
}
Loading

0 comments on commit 868c23e

Please sign in to comment.