diff --git a/devolutions-gateway/Cargo.toml b/devolutions-gateway/Cargo.toml index d6423f64f..c2b2ec7f6 100644 --- a/devolutions-gateway/Cargo.toml +++ b/devolutions-gateway/Cargo.toml @@ -68,7 +68,7 @@ tokio-rustls = { version = "0.24", features = ["dangerous_configuration", "tls12 reqwest = { version = "0.12", default-features = false, features = ["rustls-tls-native-roots", "json"] } # TODO: directly use hyper in subscriber module futures = "0.3" async-trait = "0.1" -tower = "0.4" +tower = { version = "0.4", features = ["timeout"] } ngrok = "0.13" # HTTP diff --git a/devolutions-gateway/src/lib.rs b/devolutions-gateway/src/lib.rs index 6e67ded7b..72287aad2 100644 --- a/devolutions-gateway/src/lib.rs +++ b/devolutions-gateway/src/lib.rs @@ -84,6 +84,9 @@ impl DgwState { } pub fn make_http_service(state: DgwState) -> axum::Router<()> { + use axum::error_handling::HandleErrorLayer; + use std::time::Duration; + use tower::timeout::TimeoutLayer; use tower::ServiceBuilder; trace!("Make http service"); @@ -101,6 +104,11 @@ pub fn make_http_service(state: DgwState) -> axum::Router<()> { .layer(axum::middleware::from_fn_with_state( state, middleware::auth::auth_middleware, - )), + )) + // This middleware goes above `TimeoutLayer` because it will receive errors returned by `TimeoutLayer`. + .layer(HandleErrorLayer::new(|_: axum::BoxError| async { + hyper::StatusCode::REQUEST_TIMEOUT + })) + .layer(TimeoutLayer::new(Duration::from_secs(15))), ) } diff --git a/devolutions-gateway/src/listener.rs b/devolutions-gateway/src/listener.rs index 2f675aa15..c02aba37d 100644 --- a/devolutions-gateway/src/listener.rs +++ b/devolutions-gateway/src/listener.rs @@ -13,7 +13,7 @@ use crate::generic_client::GenericClient; use crate::utils::url_to_socket_addr; use crate::DgwState; -const HTTP_REQUEST_TIMEOUT: tokio::time::Duration = tokio::time::Duration::from_secs(15); +const HTTP_CONNECTION_MAX_DURATION: tokio::time::Duration = tokio::time::Duration::from_secs(10 * 60); #[cfg_attr(feature = "openapi", derive(utoipa::ToSchema))] #[derive(Debug, Clone, Serialize)] @@ -170,12 +170,12 @@ async fn run_http_listener(listener: TcpListener, state: DgwState) -> anyhow::Re Ok((stream, peer_addr)) => { let state = state.clone(); - let fut = tokio::time::timeout(HTTP_REQUEST_TIMEOUT, async move { + let fut = tokio::time::timeout(HTTP_CONNECTION_MAX_DURATION, async move { if let Err(e) = handle_http_peer(stream, state, peer_addr).await { error!(error = format!("{e:#}"), "handle_http_peer failed"); } }) - .inspect_err(|error| warn!(%error, "Request timed out")) + .inspect_err(|error| debug!(%error, "Drop long-lived HTTP connection")) .instrument(info_span!("http", client = %peer_addr)); ChildTask::spawn(fut).detach(); @@ -198,12 +198,12 @@ async fn run_https_listener(listener: TcpListener, state: DgwState) -> anyhow::R let tls_acceptor = tls_conf.acceptor.clone(); let state = state.clone(); - let fut = tokio::time::timeout(HTTP_REQUEST_TIMEOUT, async move { + let fut = tokio::time::timeout(HTTP_CONNECTION_MAX_DURATION, async move { if let Err(e) = handle_https_peer(stream, tls_acceptor, state, peer_addr).await { error!(error = format!("{e:#}"), "handle_https_peer failed"); } }) - .inspect_err(|error| warn!(%error, "Request timed out")) + .inspect_err(|error| debug!(%error, "Drop long-lived HTTP connection")) .instrument(info_span!("https", client = %peer_addr)); ChildTask::spawn(fut).detach();