Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: update to support hyper 1.0 #4

Merged
merged 5 commits into from
Jan 1, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 12 additions & 5 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -9,14 +9,21 @@ keywords = ["hyper", "tls", "http", "https", "rustls"]
categories = ["network-programming"]
exclude = ["examples/certs"]

[features]
default = ["rustls_helpers"]
rustls_helpers = ["dep:rustls-pemfile"]

[dependencies]
futures-util = { version = "0.3.28", default-features = false, features = ["std"] }
hyper = { version = "0.14.28", features = ["server", "tcp"] }
rustls-pemfile = "2.0.0"
thiserror = "1.0.51"
# Will figure out how to handle http1 vs. http2 later
hyper = { version = "1.1.0", features = ["server", "http1"] }
hyper-util = { version = "0.1.2", features = ["tokio"] }
rustls-pemfile = { version = "2.0.0", optional = true }
thiserror = "1.0.52"
tls-listener = { version = "0.9.1", features = ["rustls"] }
tokio = { version = "1.35.1", features = ["net", "time"] }
tokio-rustls = "0.25.0"

[dev-dependencies]
hyper = { version = "0.14.28", features = ["http1", "http2"] }
http-body-util = "0.1.0"
hyper = { version = "1.1.0", features = ["http1"] }
tokio = { version = "1.35.1", features = ["rt", "macros"] }
48 changes: 23 additions & 25 deletions examples/hello-world-https.rs
Original file line number Diff line number Diff line change
@@ -1,15 +1,22 @@
use flexible_hyper_server_tls::*;
use hyper::service::{make_service_fn, service_fn};
use hyper::{Body, Request, Response, Server};
use http_body_util::Full;
use hyper::body::{Bytes, Incoming};
use hyper::service::service_fn;
use hyper::{Request, Response};
use std::convert::Infallible;
use std::time::Duration;
use tokio::net::TcpListener;

const CERT_DATA: &str = include_str!(concat!(env!("CARGO_MANIFEST_DIR"), "/examples/certs/cert.pem"));
const KEY_DATA: &str = include_str!(concat!(env!("CARGO_MANIFEST_DIR"), "/examples/certs/key.pem"));

async fn hello_world(_req: Request<Body>) -> Result<Response<Body>, Infallible> {
Ok(Response::new("Hello, World".into()))
const CERT_DATA: &str = include_str!(concat!(
env!("CARGO_MANIFEST_DIR"),
"/examples/certs/cert.pem"
));
const KEY_DATA: &str = include_str!(concat!(
env!("CARGO_MANIFEST_DIR"),
"/examples/certs/key.pem"
));

async fn hello_world(_req: Request<Incoming>) -> Result<Response<Full<Bytes>>, Infallible> {
Ok(Response::new(Full::<Bytes>::from("Hello, World!")))
}

#[tokio::main(flavor = "current_thread")]
Expand All @@ -18,27 +25,18 @@ async fn main() {

let listener = TcpListener::bind("127.0.0.1:8080").await.unwrap();

let make_svc = make_service_fn(|conn: &HttpOrHttpsConnection| {
println!("Remote address: {}", conn.remote_addr());
async { Ok::<_, Infallible>(service_fn(hello_world)) }
});
let builder = AcceptorBuilder::new(listener);

let acceptor = if use_tls {
let tls_acceptor = tlsconfig::get_tlsacceptor_from_pem_data(
CERT_DATA,
KEY_DATA,
tlsconfig::HttpProtocol::Both,
)
.unwrap();
HyperHttpOrHttpsAcceptor::new_https(listener, tls_acceptor, Duration::from_secs(10))
let mut acceptor = if use_tls {
let tls_acceptor =
rustls_helpers::get_tlsacceptor_from_pem_data(CERT_DATA, KEY_DATA).unwrap();
builder.https(tls_acceptor).build()
} else {
HyperHttpOrHttpsAcceptor::new_http(listener)
builder.build()
};

let mut server = Server::builder(acceptor).serve(make_svc);

loop {
let res = (&mut server).await;
eprintln!("Error: {:?}", res);
let peer_addr = acceptor.accept(service_fn(hello_world)).await.unwrap();
println!("Connected peer: {}", peer_addr)
}
}
188 changes: 76 additions & 112 deletions src/accept.rs
Original file line number Diff line number Diff line change
@@ -1,57 +1,89 @@
use futures_util::future::BoxFuture;
use futures_util::stream::FuturesUnordered;
use futures_util::{FutureExt, StreamExt};
use hyper::server::accept::Accept;
use std::pin::Pin;
use std::task::{Context, Poll};
use hyper::server::conn::http1;
use hyper_util::rt::TokioIo;
use std::net::SocketAddr;
use thiserror::Error;

use crate::conn::{ConnKind, HttpOrHttpsConnection};

/// Choose to accept either a HTTP or HTTPS connection
pub struct HyperHttpOrHttpsAcceptor {
listener: tokio::net::TcpListener,
kind: AcceptorKind,
}
///
/// Created by calling the `build` method on an `AcceptorBuilder`
// Use a struct instead of the enum directly to avoid users constructing/matching on enum variants
pub struct HttpOrHttpsAcceptor(pub(crate) AcceptorInner);

enum AcceptorKind {
Http,
Https {
tls_acceptor: tokio_rustls::TlsAcceptor,
timeout: std::time::Duration,
// Future has to be boxed because Rust doesn't allow writing out the full type
// Side benefit of allow us to use Timeout without needing pin projection
encryption_futures: FuturesUnordered<
tokio::time::Timeout<BoxFuture<'static, Result<HttpOrHttpsConnection, AcceptorError>>>,
>,
},
pub enum AcceptorInner {
Http(tokio::net::TcpListener),
Https(tls_listener::TlsListener<tokio::net::TcpListener, tokio_rustls::TlsAcceptor>),
}

impl HyperHttpOrHttpsAcceptor {
/// Create an acceptor that will accept HTTP connections
pub const fn new_http(listener: tokio::net::TcpListener) -> Self {
Self {
listener,
kind: AcceptorKind::Http,
impl HttpOrHttpsAcceptor {
/// Accepts every connection using the service provided, never completes.
/// Ignores any connection errors produced by the `accept` method.
pub async fn serve<S>(&mut self, service: S)
where
S: hyper::service::HttpService<hyper::body::Incoming> + Clone + Send + 'static,
S::Future: Send,
S::ResBody: Send + 'static,
<S::ResBody as hyper::body::Body>::Error: std::error::Error + Send + Sync + 'static,
<S::ResBody as hyper::body::Body>::Data: Send,
{
loop {
// Ignore result here
let _ = self.accept(service.clone()).await;
}
}

/// Create an acceptor that will accept HTTPS connections using the provided `TlsAcceptor`
/// Accepts a singular connection and spawns it onto the tokio runtime.
/// Returns the address of the connected client.
///
/// `handshake_timeout` is the length of time that should be allowed to finish a TLS handshake before we drop the connection.
/// Setting it to 0 will not disable the timeout, but will instead instantly drop every connection (you probably don't want this).
pub fn new_https(
listener: tokio::net::TcpListener,
tls_acceptor: tokio_rustls::TlsAcceptor,
handshake_timeout: std::time::Duration,
) -> Self {
Self {
listener,
kind: AcceptorKind::Https {
tls_acceptor,
timeout: handshake_timeout,
encryption_futures: FuturesUnordered::new(),
},
/// # Errors
/// If the TCP connection or TLS handshake (HTTPS only) fails
// Function won't panic, however tokio worker might
#[allow(clippy::missing_panics_doc)]
pub async fn accept<S>(&mut self, service: S) -> Result<SocketAddr, AcceptorError>
where
S: hyper::service::HttpService<hyper::body::Incoming> + Send + 'static,
S::Future: Send,
S::ResBody: Send + 'static,
<S::ResBody as hyper::body::Body>::Error: std::error::Error + Send + Sync + 'static,
<S::ResBody as hyper::body::Body>::Data: Send,
{
let conn_builder = http1::Builder::new();

match &mut self.0 {
AcceptorInner::Http(listener) => {
let (conn, peer_addr) =
listener.accept().await.map_err(AcceptorError::TcpConnect)?;

let conn = TokioIo::new(conn);

let conn = conn_builder.serve_connection(conn, service);

tokio::spawn(async move { conn.await.unwrap() });

Ok(peer_addr)
}
AcceptorInner::Https(listener) => {
let (conn, peer_addr) = loop {
match listener.accept().await {
Err(tls_listener::Error::ListenerError(e)) => {
return Err(AcceptorError::TcpConnect(e))
}
Err(tls_listener::Error::TlsAcceptError { error, .. }) => {
return Err(AcceptorError::TcpConnect(error))
}
// Ignore handshake timeout errors, just try to get another connection
Err(_) => continue,
Ok(conn_and_addr) => break conn_and_addr,
}
};

let conn = TokioIo::new(conn);

let conn = conn_builder.serve_connection(conn, service);

tokio::spawn(async move { conn.await.unwrap() });

Ok(peer_addr)
}
}
}
}
Expand All @@ -66,71 +98,3 @@ pub enum AcceptorError {
#[error("TLS handshake with client failed")]
TlsHandshake(#[source] std::io::Error),
}

impl Accept for HyperHttpOrHttpsAcceptor {
type Conn = HttpOrHttpsConnection;
type Error = AcceptorError;

fn poll_accept(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<Option<Result<Self::Conn, Self::Error>>> {
// Necessary to allow partial borrows
let this = self.get_mut();

match &mut this.kind {
// If just a normal HTTP connection, just poll to accept the new TCP connection
AcceptorKind::Http => match this.listener.poll_accept(cx) {
Poll::Ready(Ok(stream)) => Poll::Ready(Some(Ok(HttpOrHttpsConnection {
remote_addr: stream.1,
kind: ConnKind::Http(stream.0),
}))),
Poll::Ready(Err(err)) => Poll::Ready(Some(Err(AcceptorError::TcpConnect(err)))),
Poll::Pending => Poll::Pending,
},
// Otherwise, if it's an HTTPS connection, check if we're ready to encrypt the connection
AcceptorKind::Https {
tls_acceptor,
timeout,
encryption_futures,
} => {
// Accept all pending TCP connections at once (this future won't be woken up for TCP unless we get a pending here)
loop {
match this.listener.poll_accept(cx) {
Poll::Ready(Ok(stream)) => {
let tls_future = tls_acceptor
.accept(stream.0)
.map(move |f| {
// Map so that we can pass along the remote address
f.map(|conn| HttpOrHttpsConnection {
remote_addr: stream.1,
kind: ConnKind::Https(conn),
})
.map_err(AcceptorError::TlsHandshake)
})
.boxed();
let timed_tls_future = tokio::time::timeout(*timeout, tls_future);
encryption_futures.push(timed_tls_future);
}
Poll::Ready(Err(err)) => {
return Poll::Ready(Some(Err(AcceptorError::TcpConnect(err))))
}
// Break on pending here so we can check on the TLS queue
Poll::Pending => break,
}
}
// Check queue to see if any handshakes are done/timeouts hit
loop {
match encryption_futures.poll_next_unpin(cx) {
// Already `map`ed to a Result<HttpOrHttpsConnection>, so no need to differentiate
// between Some(Err) and Some(Ok)
Poll::Ready(Some(Ok(res))) => return Poll::Ready(Some(res)),
// An error here means that the timeout ran out, so just skip to the next one in the queue
Poll::Ready(Some(Err(_))) => continue,
_ => return Poll::Pending,
}
}
}
}
}
}
84 changes: 84 additions & 0 deletions src/builder.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
use std::time::Duration;

use tokio::net::TcpListener;
use tokio_rustls::TlsAcceptor;

use crate::accept::{AcceptorInner, HttpOrHttpsAcceptor};

pub struct Http;
pub struct Https {
tls_acceptor: tokio_rustls::TlsAcceptor,
max_handshakes: usize,
timeout: Duration,
}

/// Build an `HttpOrHttpsAcceptor`
///
/// Defaults to accepting HTTP connections, call the `https` method to accept HTTPS connections instead
pub struct AcceptorBuilder<State> {
state: State,
listener: TcpListener,
}

impl AcceptorBuilder<Http> {
/// Create a new builder for an `HttpOrHttpsAcceptor`
///
/// Defaults to accepting HTTP
pub const fn new(listener: TcpListener) -> Self {
Self {
state: Http,
listener,
}
}

/// Converts the builder into accepting HTTPS using the provided `TlsAcceptor`
pub fn https(self, tls_acceptor: TlsAcceptor) -> AcceptorBuilder<Https> {
AcceptorBuilder {
state: Https {
tls_acceptor,
max_handshakes: tls_listener::DEFAULT_MAX_HANDSHAKES,
timeout: tls_listener::DEFAULT_HANDSHAKE_TIMEOUT,
},
listener: self.listener,
}
}

/// Builds an `HttpOrHttpsAcceptor` to accept HTTP connections
pub fn build(self) -> HttpOrHttpsAcceptor {
HttpOrHttpsAcceptor(AcceptorInner::Http(self.listener))
}
}

impl AcceptorBuilder<Https> {
/// Set the maximum number of handshakes that will be processed concurrently
///
/// Defaults to 64
#[must_use]
pub const fn max_handshakes(mut self, num: usize) -> Self {
self.state.max_handshakes = num;
self
}

/// Set the maximum amount of time that a handshake can take before being aborted.
/// Setting it to 0 will not disable the timeout, but will instead instantly drop every connection.
///
/// Defaults to 10 seconds
#[must_use]
pub const fn timeout(mut self, timeout: Duration) -> Self {
self.state.timeout = timeout;
self
}

/// Builds an `HttpOrHttpsAcceptor` to accept HTTPS connections
pub fn build(self) -> HttpOrHttpsAcceptor {
let mut tls_builder = tls_listener::builder(self.state.tls_acceptor);

tls_builder
.max_handshakes(self.state.max_handshakes)
.handshake_timeout(self.state.timeout);

let tls_listener = tls_builder.listen(self.listener);

HttpOrHttpsAcceptor(AcceptorInner::Https(tls_listener))
}
}
Loading