diff --git a/Cargo.toml b/Cargo.toml index d37ed5ec6..843939fa4 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -3,7 +3,7 @@ members = ["crates/*"] resolver = "2" [workspace.package] -version = "0.61.0" +version = "0.62.0" authors = ["Chrislearn Young "] edition = "2021" rust-version = "1.70" @@ -16,14 +16,27 @@ documentation = "https://docs.rs/salvo/" readme = "./README.md" keywords = ["http", "async", "web", "framework", "server"] license = "MIT OR Apache-2.0" -categories = [ - "web-programming::http-server", - "web-programming::websocket", - "network-programming", - "asynchronous", -] +categories = ["web-programming::http-server", "web-programming::websocket", "network-programming", "asynchronous"] [workspace.dependencies] +salvo_macros = { version = "0.62.0", path = "crates/macros", default-features = false } +salvo_core = { version = "0.62.0", path = "crates/core", default-features = false } +salvo_extra = { version = "0.62.0", path = "crates/extra", default-features = false } +salvo-compression = { version = "0.62.0", path = "crates/compression", default-features = false } +salvo-cache = { version = "0.62.0", path = "crates/cache", default-features = false } +salvo-cors = { version = "0.62.0", path = "crates/cors", default-features = false } +salvo-csrf = { version = "0.62.0", path = "crates/csrf", default-features = false } +salvo-flash = { version = "0.62.0", path = "crates/flash", default-features = false } +salvo-http3 = { version = "0.0.5", default-features = false } +salvo-jwt-auth = { version = "0.62.0", path = "crates/jwt-auth", default-features = false } +salvo-oapi = { version = "0.62.0", path = "./crates/oapi", default-features = false } +salvo-oapi-macros = { version = "0.62.0", path = "crates/oapi-macros", default-features = false } +salvo-otel = { version = "0.62.0", path = "crates/otel", default-features = false } +salvo-proxy = { version = "0.62.0", path = "crates/proxy", default-features = false } +salvo-rate-limiter = { version = "0.62.0", path = "crates/rate-limiter", default-features = false } +salvo-serve-static = { version = "0.62.0", path = "crates/serve-static", default-features = false } +salvo-session = { version = "0.62.0", path = "crates/session", default-features = false } + aead = "0.5" aes-gcm = "0.10" anyhow = "1" @@ -47,13 +60,15 @@ fastrand = "2" form_urlencoded = "1" futures-channel = "0.3" futures-util = { version = "0.3", default-features = false } -headers = "0.3" -http = "0.2" -http-body-util = "0.1.0-rc.3" +headers = "0.4" +http = "1" +http-body-util = "0.1" hmac = "0.12" hex = "0.4" hostname-validator = "1" -hyper = "=1.0.0-rc.4" +hyper = { version = "1", features = ["full"] } +hyper-util = { version = "0.1.1", default-features = true } +hyper-tls = "0.6" indexmap = "2" inventory = "0.3" jsonwebtoken = "9.1" @@ -82,30 +97,11 @@ quote = "1" rand = "0.8" rcgen = "0.11" regex = "1" -reqwest = { version = "0.11", default-features = false } ring = "0.17" rust_decimal = "1" -rustls = "0.21" -rustls-pemfile = "1" -rust-embed = { version = ">= 6, <= 9" } -salvo-utils = { version = "0.0.6", default-features = true } -salvo_macros = { version = "0.61.0", path = "crates/macros", default-features = false } -salvo_core = { version = "0.61.0", path = "crates/core", default-features = false } -salvo_extra = { version = "0.61.0", path = "crates/extra", default-features = false } -salvo-compression = { version = "0.61.0", path = "crates/compression", default-features = false } -salvo-cache = { version = "0.61.0", path = "crates/cache", default-features = false } -salvo-cors = { version = "0.61.0", path = "crates/cors", default-features = false } -salvo-csrf = { version = "0.61.0", path = "crates/csrf", default-features = false } -salvo-flash = { version = "0.61.0", path = "crates/flash", default-features = false } -salvo-http3 = { version = "0.0.4", default-features = false } -salvo-jwt-auth = { version = "0.61.0", path = "crates/jwt-auth", default-features = false } -salvo-oapi = { version = "0.61.0", path = "./crates/oapi", default-features = false } -salvo-oapi-macros = { version = "0.61.0", path = "crates/oapi-macros", default-features = false } -salvo-otel = { version = "0.61.0", path = "crates/otel", default-features = false } -salvo-proxy = { version = "0.61.0", path = "crates/proxy", default-features = false } -salvo-rate-limiter = { version = "0.61.0", path = "crates/rate-limiter", default-features = false } -salvo-serve-static = { version = "0.61.0", path = "crates/serve-static", default-features = false } -salvo-session = { version = "0.61.0", path = "crates/session", default-features = false } +rustls = "0.22" +rustls-pemfile = "2" +rust-embed = { version = ">= 6, <= 8" } serde = "1" serde_json = "1" serde-xml-rs = "0.6" @@ -120,7 +116,7 @@ thiserror = "1" time = "0.3" tokio = "1" tokio-native-tls = "0.3" -tokio-rustls = "0.24" +tokio-rustls = "0.25" tokio-openssl = "0.6" tokio-stream = { version = "0.1", default-features = false } tokio-tungstenite = { version = "0.21", default-features = false } @@ -137,4 +133,4 @@ x509-parser = "0.15" # Compress brotli = { version = "3.3", default-features = false } flate2 = { version = "1.0", default-features = false } -zstd = { version = "0.13", default-features = false } +zstd = { version = "0.13", default-features = false } \ No newline at end of file diff --git a/crates/core/Cargo.toml b/crates/core/Cargo.toml index 2d6288507..2d59cd374 100644 --- a/crates/core/Cargo.toml +++ b/crates/core/Cargo.toml @@ -22,16 +22,19 @@ cookie = ["dep:cookie"] http1 = [] fix-http1-request-uri = ["http1"] http2 = ["hyper/http2"] -quinn = ["dep:salvo-http3", "dep:quinn", "rustls"] +quinn = ["dep:salvo-http3", "dep:quinn", "dep:tokio-rustls-old", "dep:rustls-pemfile-old", "rustls"] rustls = ["http1", "http2", "dep:tokio-rustls", "dep:rustls-pemfile"] native-tls = ["http1", "http2", "dep:tokio-native-tls", "dep:native-tls"] openssl = ["http2", "dep:openssl", "dep:tokio-openssl"] unix = ["http1"] test = ["dep:brotli", "dep:flate2", "dep:zstd", "dep:encoding_rs", "dep:serde_urlencoded", "dep:url", "tokio/macros"] -acme = ["http1", "http2", "hyper/client", "dep:reqwest", "dep:rcgen", "dep:ring", "dep:x509-parser", "dep:tokio-rustls", "dep:rustls-pemfile"] +acme = ["http1", "http2", "hyper-util/http1", "hyper-util/http2","hyper-util/client-legacy", "dep:hyper-tls", "dep:rcgen", "dep:ring", "dep:x509-parser", "dep:tokio-rustls", "dep:rustls-pemfile"] tower-compat = ["dep:tower"] [dependencies] +rustls-pemfile-old = { version = "1", package = "rustls-pemfile", optional = true } +tokio-rustls-old = { version = "0.24", package = "tokio-rustls", optional = true } + anyhow = { workspace = true, optional = true } async-trait = { workspace = true } base64 = { workspace = true } @@ -63,12 +66,12 @@ quinn = { workspace = true, optional = true, features = ["runtime-tokio", "ring" rand = { workspace = true } rcgen = { workspace = true, optional = true } regex = { workspace = true } -reqwest = { workspace = true, optional = true, features = ["rustls-tls", "json"] } ring = { workspace = true, optional = true } rustls-pemfile = { workspace = true, optional = true } salvo-http3 = { workspace = true, optional = true, features = ["quinn"] } salvo_macros = { workspace = true } -salvo-utils = { workspace = true, features = ["runtime"] } +hyper-tls = { workspace = true, optional = true } +hyper-util = { workspace = true, features = ["tokio"] } serde = { workspace = true, features = ["derive"] } serde_json = { workspace = true, features = ["raw_value"] } serde-xml-rs = { workspace = true } diff --git a/crates/core/src/conn/acme/client.rs b/crates/core/src/conn/acme/client.rs index ca17e1df1..cb6994821 100644 --- a/crates/core/src/conn/acme/client.rs +++ b/crates/core/src/conn/acme/client.rs @@ -1,20 +1,22 @@ -use std::{ - io::{Error as IoError, ErrorKind, Result as IoResult}, - sync::Arc, - time::Duration, -}; +use std::sync::Arc; use base64::engine::{general_purpose::URL_SAFE_NO_PAD, Engine}; use bytes::Bytes; -use http::header; -use reqwest::Client; +use http_body_util::{BodyExt, Full}; +use hyper::Uri; +use hyper_tls::HttpsConnector; +use hyper_util::client::legacy::{connect::HttpConnector, Client}; +use hyper_util::rt::TokioExecutor; use serde::{Deserialize, Serialize}; -use super::{Challenge, Problem}; - use super::{jose, key_pair::KeyPair, ChallengeType}; +use super::{Challenge, Problem}; use super::{Directory, Identifier}; +use crate::Error; + +pub(super) type HyperClient = Client, Full>; + #[derive(Debug, Deserialize)] #[serde(rename_all = "camelCase")] pub(crate) struct NewOrderResponse { @@ -34,7 +36,7 @@ pub(crate) struct FetchAuthorizationResponse { } pub(crate) struct AcmeClient { - pub(crate) client: Client, + pub(crate) client: HyperClient, pub(crate) directory: Directory, pub(crate) key_pair: Arc, pub(crate) contacts: Vec, @@ -43,8 +45,8 @@ pub(crate) struct AcmeClient { impl AcmeClient { #[inline] - pub(crate) async fn new(directory_url: &str, key_pair: Arc, contacts: Vec) -> IoResult { - let client = Client::builder().timeout(Duration::from_secs(30)).build().unwrap(); + pub(crate) async fn new(directory_url: &str, key_pair: Arc, contacts: Vec) -> crate::Result { + let client = Client::builder(TokioExecutor::new()).build(HttpsConnector::new()); let directory = get_directory(&client, directory_url).await?; Ok(Self { client, @@ -55,7 +57,7 @@ impl AcmeClient { }) } - pub(crate) async fn new_order(&mut self, domains: &[String]) -> IoResult { + pub(crate) async fn new_order(&mut self, domains: &[String]) -> crate::Result { #[derive(Serialize)] #[serde(rename_all = "camelCase")] struct NewOrderRequest { @@ -63,11 +65,11 @@ impl AcmeClient { } impl FetchAuthorizationResponse { - pub(crate) fn find_challenge(&self, ctype: ChallengeType) -> IoResult<&Challenge> { + pub(crate) fn find_challenge(&self, ctype: ChallengeType) -> crate::Result<&Challenge> { self.challenges .iter() .find(|c| c.kind == ctype.to_string()) - .ok_or_else(|| IoError::new(ErrorKind::Other, format!("unable to find `{}` challenge", ctype))) + .ok_or_else(|| Error::other(format!("unable to find `{}` challenge", ctype))) } } @@ -107,7 +109,7 @@ impl AcmeClient { } #[inline] - pub(crate) async fn fetch_authorization(&self, auth_url: &str) -> IoResult { + pub(crate) async fn fetch_authorization(&self, auth_url: &str) -> crate::Result { tracing::debug!(auth_uri = %auth_url, "fetch authorization"); let nonce = get_nonce(&self.client, &self.directory.new_nonce).await?; @@ -136,7 +138,7 @@ impl AcmeClient { domain: &str, challenge_type: ChallengeType, url: &str, - ) -> IoResult<()> { + ) -> crate::Result<()> { tracing::debug!( auth_uri = %url, domain = domain, @@ -159,7 +161,7 @@ impl AcmeClient { } #[inline] - pub(crate) async fn send_csr(&self, url: &str, csr: &[u8]) -> IoResult { + pub(crate) async fn send_csr(&self, url: &str, csr: &[u8]) -> crate::Result { tracing::debug!(url = %url, "send certificate request"); #[derive(Debug, Serialize)] @@ -183,7 +185,7 @@ impl AcmeClient { } #[inline] - pub(crate) async fn obtain_certificate(&self, url: &str) -> IoResult { + pub(crate) async fn obtain_certificate(&self, url: &str) -> crate::Result { tracing::debug!(url = %url, "send certificate request"); let nonce = get_nonce(&self.client, &self.directory.new_nonce).await?; @@ -196,34 +198,35 @@ impl AcmeClient { None::<()>, ) .await?; - res.bytes() - .await - .map_err(|e| IoError::new(ErrorKind::Other, format!("failed to download certificate: {}", e))) + Ok(res.into_body().collect().await?.to_bytes()) } } -async fn get_directory(client: &Client, directory_url: &str) -> IoResult { +async fn get_directory(client: &HyperClient, directory_url: &str) -> crate::Result { tracing::debug!("loading directory"); - + let directory_url = directory_url + .parse::() + .map_err(|e| Error::other(format!("failed to parse directory dir: {}", e)))?; let res = client .get(directory_url) - .send() .await - .map_err(|e| IoError::new(ErrorKind::Other, format!("failed to load directory: {}", e)))?; + .map_err(|e| Error::other(format!("failed to load directory: {}", e)))?; if !res.status().is_success() { - return Err(IoError::new( - ErrorKind::Other, - format!("failed to load directory: status = {}", res.status()), - )); + return Err(Error::other(format!( + "failed to load directory: status = {}", + res.status() + ))); } let data = res - .bytes() + .into_body() + .collect() .await - .map_err(|e| IoError::new(ErrorKind::Other, format!("failed to read response: {}", e)))?; + .map_err(|e| Error::other(format!("failed to load body: {}", e)))? + .to_bytes(); let directory = serde_json::from_slice::(&data) - .map_err(|e| IoError::new(ErrorKind::Other, format!("failed to load directory: {}", e)))?; + .map_err(|e| Error::other(format!("failed to load directory: {}", e)))?; tracing::debug!( new_nonce = ?directory.new_nonce, @@ -234,20 +237,19 @@ async fn get_directory(client: &Client, directory_url: &str) -> IoResult IoResult { +async fn get_nonce(client: &HyperClient, nonce_url: &str) -> crate::Result { tracing::debug!("creating nonce"); let res = client - .get(nonce_url) - .send() + .get(nonce_url.parse::()?) .await - .map_err(|e| IoError::new(ErrorKind::Other, format!("failed to get nonce: {}", e)))?; + .map_err(|e| Error::other(format!("failed to get nonce: {}", e)))?; if !res.status().is_success() { - return Err(IoError::new( - ErrorKind::Other, - format!("failed to load directory: status = {}", res.status()), - )); + return Err(Error::other(format!( + "failed to load directory: status = {}", + res.status() + ))); } let nonce = res @@ -262,11 +264,11 @@ async fn get_nonce(client: &Client, nonce_url: &str) -> IoResult { } async fn create_acme_account( - client: &Client, + client: &HyperClient, directory: &Directory, key_pair: &KeyPair, contacts: Vec, -) -> IoResult { +) -> crate::Result { tracing::debug!("creating acme account"); #[derive(Serialize)] @@ -293,11 +295,11 @@ async fn create_acme_account( .await?; let kid = res .headers() - .get(header::LOCATION) - .ok_or_else(|| IoError::new(ErrorKind::Other, "unable to get account id"))? + .get("location") + .ok_or_else(|| Error::other("unable to get account id"))? .to_str() .map(|s| s.to_owned()) - .map_err(|_| IoError::new(ErrorKind::Other, "unable to get account id")); + .map_err(|_| Error::other("unable to get account id")); tracing::debug!(kid = ?kid, "account created"); kid diff --git a/crates/core/src/conn/acme/issuer.rs b/crates/core/src/conn/acme/issuer.rs index 437260596..e060ba1af 100644 --- a/crates/core/src/conn/acme/issuer.rs +++ b/crates/core/src/conn/acme/issuer.rs @@ -1,10 +1,10 @@ -use std::io::{Error as IoError, ErrorKind, Result as IoResult}; +use std::io::Result as IoResult; use std::sync::Arc; use std::time::Duration; use rcgen::{Certificate, CertificateParams, CustomExtension, DistinguishedName, PKCS_ECDSA_P256_SHA256}; -use tokio_rustls::rustls::sign::{any_ecdsa_type, CertifiedKey}; -use tokio_rustls::rustls::PrivateKey; +use tokio_rustls::rustls::pki_types::{CertificateDer, PrivateKeyDer, PrivatePkcs8KeyDer}; +use tokio_rustls::rustls::{crypto::ring::sign::any_ecdsa_type, sign::CertifiedKey}; use super::cache::AcmeCache; use super::client::AcmeClient; @@ -12,11 +12,13 @@ use super::config::AcmeConfig; use super::resolver::ResolveServerCert; use super::{jose, ChallengeType}; +use crate::Error; + pub(crate) async fn issue_cert( client: &mut AcmeClient, config: &AcmeConfig, resolver: &ResolveServerCert, -) -> IoResult<()> { +) -> crate::Result<()> { tracing::debug!("issue certificate"); let order_res = client.new_order(&config.domains).await?; // trigger challenge @@ -54,14 +56,11 @@ pub(crate) async fn issue_cert( .await?; } else if res.status == "invalid" { tracing::error!(res = ?res, "unable to authorize"); - return Err(IoError::new( - ErrorKind::Other, - format!( - "unable to authorize `{}`: {}", - res.identifier.value, - res.error.as_ref().map(|problem| &*problem.detail).unwrap_or("unknown") - ), - )); + return Err(Error::other(format!( + "unable to authorize `{}`: {}", + res.identifier.value, + res.error.as_ref().map(|problem| &*problem.detail).unwrap_or("unknown") + ))); } } if all_valid { @@ -71,40 +70,37 @@ pub(crate) async fn issue_cert( tokio::time::sleep(Duration::from_secs(i * 10)).await; } if !valid { - return Err(IoError::new(ErrorKind::Other, "authorization failed too many times")); + return Err(Error::other("authorization failed too many times")); } // send csr let mut params = CertificateParams::new(config.domains.clone()); params.distinguished_name = DistinguishedName::new(); params.alg = &PKCS_ECDSA_P256_SHA256; let cert = Certificate::from_params(params) - .map_err(|e| IoError::new(ErrorKind::Other, format!("failed create certificate request: {}", e)))?; - let pk = any_ecdsa_type(&PrivateKey(cert.serialize_private_key_der())).unwrap(); + .map_err(|e| Error::other(format!("failed create certificate request: {}", e)))?; + let pk = any_ecdsa_type(&PrivateKeyDer::Pkcs8(PrivatePkcs8KeyDer::from( + cert.serialize_private_key_der(), + ))) + .unwrap(); let csr = cert .serialize_request_der() - .map_err(|e| IoError::new(ErrorKind::Other, format!("failed to serialize request der {}", e)))?; + .map_err(|e| Error::other(format!("failed to serialize request der {}", e)))?; let order_res = client.send_csr(&order_res.finalize, &csr).await?; if order_res.status == "invalid" { - return Err(IoError::new( - ErrorKind::Other, - format!( - "failed to request certificate: {}", - order_res - .error - .as_ref() - .map(|problem| &*problem.detail) - .unwrap_or("unknown") - ), - )); + return Err(Error::other(format!( + "failed to request certificate: {}", + order_res + .error + .as_ref() + .map(|problem| &*problem.detail) + .unwrap_or("unknown") + ))); } if order_res.status != "valid" { - return Err(IoError::new( - ErrorKind::Other, - format!( - "failed to request certificate: unexpected status `{}`", - order_res.status - ), - )); + return Err(Error::other(format!( + "failed to request certificate: unexpected status `{}`", + order_res.status + ))); } // download certificate let cert_pem = client @@ -112,17 +108,13 @@ pub(crate) async fn issue_cert( order_res .certificate .as_ref() - .ok_or_else(|| IoError::new(ErrorKind::Other, "invalid response: missing `certificate` url"))?, + .ok_or_else(|| Error::other("invalid response: missing `certificate` url"))?, ) .await? .as_ref() .to_vec(); let key_pem = cert.serialize_private_key_pem(); - let cert_chain = rustls_pemfile::certs(&mut cert_pem.as_slice()) - .map_err(|e| IoError::new(ErrorKind::Other, format!("invalid pem: {}", e)))? - .into_iter() - .map(tokio_rustls::rustls::Certificate) - .collect(); + let cert_chain = rustls_pemfile::certs(&mut cert_pem.as_slice()).collect::>>()?; let cert_key = CertifiedKey::new(cert_chain, pk); *resolver.cert.write() = Some(Arc::new(cert_key)); tracing::debug!("certificate obtained"); @@ -138,17 +130,20 @@ pub(crate) async fn issue_cert( } #[inline] -fn gen_acme_cert(domain: &str, acme_hash: &[u8]) -> IoResult { +fn gen_acme_cert(domain: &str, acme_hash: &[u8]) -> crate::Result { let mut params = CertificateParams::new(vec![domain.to_string()]); params.alg = &PKCS_ECDSA_P256_SHA256; params.custom_extensions = vec![CustomExtension::new_acme_identifier(acme_hash)]; - let cert = Certificate::from_params(params) - .map_err(|_| IoError::new(ErrorKind::Other, "failed to generate acme certificate"))?; - let key = any_ecdsa_type(&PrivateKey(cert.serialize_private_key_der())).unwrap(); + let cert = Certificate::from_params(params).map_err(|_| Error::other("failed to generate acme certificate"))?; + let key = any_ecdsa_type(&PrivateKeyDer::Pkcs8(PrivatePkcs8KeyDer::from( + cert.serialize_private_key_der(), + ))) + .unwrap(); Ok(CertifiedKey::new( - vec![tokio_rustls::rustls::Certificate(cert.serialize_der().map_err( - |_| IoError::new(ErrorKind::Other, "failed to serialize acme certificate"), - )?)], + vec![CertificateDer::from( + cert.serialize_der() + .map_err(|_| Error::other("failed to serialize acme certificate"))?, + )], key, )) } diff --git a/crates/core/src/conn/acme/jose.rs b/crates/core/src/conn/acme/jose.rs index 26255d726..17b510e0e 100644 --- a/crates/core/src/conn/acme/jose.rs +++ b/crates/core/src/conn/acme/jose.rs @@ -1,12 +1,15 @@ use std::io::{Error as IoError, ErrorKind, Result as IoResult}; +use super::client::HyperClient; use base64::engine::general_purpose::URL_SAFE_NO_PAD; use base64::Engine; -use reqwest::Client; +use http_body_util::{BodyExt, Full}; +use hyper::{body::Incoming as HyperBody, Method}; use ring::digest::{digest, Digest, SHA256}; use serde::{de::DeserializeOwned, Serialize}; use crate::conn::acme::key_pair::KeyPair; +use crate::Error; #[derive(Serialize)] struct Protected<'a> { @@ -95,13 +98,13 @@ struct Body { } pub(crate) async fn request( - client: &Client, + client: &HyperClient, key_pair: &KeyPair, kid: Option<&str>, nonce: &str, uri: &str, payload: Option, -) -> IoResult { +) -> IoResult> { let jwk = match kid { None => Some(Jwk::new(key_pair)), Some(_) => None, @@ -122,11 +125,15 @@ pub(crate) async fn request( }) .unwrap(); - let res = client - .post(uri) + let req = hyper::Request::builder() .header("content-type", "application/jose+json") - .body(body) - .send() + .method(Method::POST) + .uri(uri) + .body(Full::from(body)) + .unwrap(); + + let res = client + .request(req) .await .map_err(|e| IoError::new(ErrorKind::Other, format!("failed to send http request: {}", e)))?; if !res.status().is_success() { @@ -139,13 +146,13 @@ pub(crate) async fn request( } #[inline] pub(crate) async fn request_json( - cli: &Client, + cli: &HyperClient, key_pair: &KeyPair, kid: Option<&str>, nonce: &str, url: &str, payload: Option, -) -> IoResult +) -> crate::Result where T: Serialize + Send, R: DeserializeOwned, @@ -153,11 +160,12 @@ where let res = request(cli, key_pair, kid, nonce, url, payload).await?; let data = res - .bytes() - .await - .map_err(|e| IoError::new(ErrorKind::Other, format!("failed to read response: {}", e)))?; + .into_body() + .collect() + .await? + .to_bytes(); serde_json::from_slice(&data) - .map_err(|e| IoError::new(ErrorKind::Other, format!("response is not a valid json: {}", e))) + .map_err(|e| Error::other(format!("response is not a valid json: {}", e))) } #[inline] diff --git a/crates/core/src/conn/acme/listener.rs b/crates/core/src/conn/acme/listener.rs index 2afb845ec..42af180ad 100644 --- a/crates/core/src/conn/acme/listener.rs +++ b/crates/core/src/conn/acme/listener.rs @@ -3,10 +3,11 @@ use std::path::PathBuf; use std::sync::{Arc, Weak}; use std::time::Duration; +use tokio_rustls::rustls::pki_types::{CertificateDer, PrivateKeyDer}; use tokio::io::{AsyncRead, AsyncWrite}; +use tokio_rustls::rustls::crypto::ring::sign::any_ecdsa_type; use tokio_rustls::rustls::server::ServerConfig; -use tokio_rustls::rustls::sign::{any_ecdsa_type, CertifiedKey}; -use tokio_rustls::rustls::PrivateKey; +use tokio_rustls::rustls::sign::CertifiedKey; use tokio_rustls::server::TlsStream; use tokio_rustls::TlsAcceptor; @@ -20,13 +21,14 @@ use super::config::{AcmeConfig, AcmeConfigBuilder}; use super::resolver::{ResolveServerCert, ACME_TLS_ALPN_NAME}; use super::{AcmeCache, AcmeClient, ChallengeType, Http01Handler, WELL_KNOWN_PATH}; -cfg_feature! { - #![feature = "quinn"] - use crate::conn::quinn::QuinnAcceptor; - use crate::conn::joined::JoinedAcceptor; - use crate::conn::quinn::QuinnListener; - use futures_util::stream::BoxStream; -} +// TODO: waiting quinn update +// cfg_feature! { +// #![feature = "quinn"] +// use crate::conn::quinn::QuinnAcceptor; +// use crate::conn::joined::JoinedAcceptor; +// use crate::conn::quinn::QuinnListener; +// use futures_util::stream::BoxStream; +// } /// A wrapper around an underlying listener which implements the ACME. pub struct AcmeListener { inner: T, @@ -98,9 +100,10 @@ impl AcmeListener { let handler = Http01Handler { keys: keys_for_http01.clone(), }; - router - .routers - .insert(0, Router::with_path(format!("{}/", WELL_KNOWN_PATH)).goal(handler)); + router.routers.insert( + 0, + Router::with_path(format!("{}/", WELL_KNOWN_PATH)).goal(handler), + ); } else { panic!("`HTTP-01` challenge's key should not be none"); } @@ -128,16 +131,17 @@ impl AcmeListener { } } - cfg_feature! { - #![feature = "quinn"] - /// Enable Http3 using quinn. - pub fn quinn(self, local_addr: A) -> AcmeQuinnListener - where - A: std::net::ToSocketAddrs + Send, - { - AcmeQuinnListener::new(self, local_addr) - } - } + // TODO: waiting quinn update + // cfg_feature! { + // #![feature = "quinn"] + // /// Enable Http3 using quinn. + // pub fn quinn(self, local_addr: A) -> AcmeQuinnListener + // where + // A: std::net::ToSocketAddrs + Send, + // { + // AcmeQuinnListener::new(self, local_addr) + // } + // } } #[async_trait] @@ -148,7 +152,7 @@ where { type Acceptor = AcmeAcceptor; - async fn try_bind(mut self) -> IoResult { + async fn try_bind(mut self) -> crate::Result { let Self { inner, config_builder, @@ -157,49 +161,53 @@ where } = self; let acme_config = config_builder.build()?; let mut cached_key = None; - let mut cached_cert = None; + let mut cached_certs = None; if let Some(cache_path) = &acme_config.cache_path { let key_data = cache_path .read_key(&acme_config.directory_name, &acme_config.domains) .await?; if let Some(key_data) = key_data { tracing::debug!("load private key from cache"); - match rustls_pemfile::pkcs8_private_keys(&mut key_data.as_slice()) { - Ok(key) => cached_key = key.into_iter().next(), - Err(e) => { - tracing::warn!(error = ?e, "parse cached private key failed") + if let Some(key) = rustls_pemfile::pkcs8_private_keys(&mut key_data.as_slice()).next() { + match key { + Ok(key) => { + cached_key = Some(key); + } + Err(e) => { + tracing::warn!(error = ?e, "parse cached private key failed"); + } } - }; + } else { + tracing::warn!("parse cached private key failed"); + } } let cert_data = cache_path .read_cert(&acme_config.directory_name, &acme_config.domains) .await?; if let Some(cert_data) = cert_data { tracing::debug!("load certificate from cache"); - match rustls_pemfile::certs(&mut cert_data.as_slice()) { - Ok(cert) => cached_cert = Some(cert), - Err(e) => { - tracing::warn!(error = ?e, "parse cached tls certificates failed") - } + let certs = rustls_pemfile::certs(&mut cert_data.as_slice()) + .filter_map(|i| i.ok()) + .collect::>(); + if !certs.is_empty() { + cached_certs = Some(certs); + } else { + tracing::warn!("parse cached tls certificates failed") }; } }; let cert_resolver = Arc::new(ResolveServerCert::default()); - if let (Some(cached_cert), Some(cached_key)) = (cached_cert, cached_key) { - let certs = cached_cert - .into_iter() - .map(tokio_rustls::rustls::Certificate) - .collect::>(); + if let (Some(cached_certs), Some(cached_key)) = (cached_certs, cached_key) { + let certs = cached_certs.into_iter().collect::>>(); tracing::debug!("using cached tls certificates"); *cert_resolver.cert.write() = Some(Arc::new(CertifiedKey::new( certs, - any_ecdsa_type(&PrivateKey(cached_key)).unwrap(), + any_ecdsa_type(&PrivateKeyDer::Pkcs8(cached_key)).unwrap(), ))); } let mut server_config = ServerConfig::builder() - .with_safe_defaults() .with_no_client_auth() .with_cert_resolver(cert_resolver.clone()); @@ -212,50 +220,58 @@ where let tls_acceptor = TlsAcceptor::from(server_config.clone()); let inner = inner.try_bind().await?; - let acceptor = AcmeAcceptor::new(acme_config, server_config, cert_resolver, inner, tls_acceptor, check_duration).await?; + let acceptor = AcmeAcceptor::new( + acme_config, + server_config, + cert_resolver, + inner, + tls_acceptor, + check_duration, + ) + .await?; Ok(acceptor) } } +// TODO: waiting quinn update +// cfg_feature! { +// #![feature = "quinn"] +// /// A wrapper around an underlying listener which implements the ACME and Quinn. +// pub struct AcmeQuinnListener { +// acme: AcmeListener, +// local_addr: A, +// } -cfg_feature! { - #![feature = "quinn"] - /// A wrapper around an underlying listener which implements the ACME and Quinn. - pub struct AcmeQuinnListener { - acme: AcmeListener, - local_addr: A, - } - - impl AcmeQuinnListener - where - A: std::net::ToSocketAddrs + Send, - { - pub(crate) fn new(acme: AcmeListener, local_addr: A) -> Self { - Self { acme, local_addr } - } - } +// impl AcmeQuinnListener +// where +// A: std::net::ToSocketAddrs + Send, +// { +// pub(crate) fn new(acme: AcmeListener, local_addr: A) -> Self { +// Self { acme, local_addr } +// } +// } - #[async_trait] - impl Listener for AcmeQuinnListener - where - T: Listener + Send, - T::Acceptor: Send + Unpin + 'static, - A: std::net::ToSocketAddrs + Send, - { - type Acceptor = JoinedAcceptor, QuinnAcceptor, crate::conn::quinn::ServerConfig, std::convert::Infallible>>; +// #[async_trait] +// impl Listener for AcmeQuinnListener +// where +// T: Listener + Send, +// T::Acceptor: Send + Unpin + 'static, +// A: std::net::ToSocketAddrs + Send, +// { +// type Acceptor = JoinedAcceptor, QuinnAcceptor, crate::conn::quinn::ServerConfig, std::convert::Infallible>>; - async fn try_bind(self) -> IoResult { - let Self { acme, local_addr } = self; - let a = acme.try_bind().await?; +// async fn try_bind(self) -> crate::Result { +// let Self { acme, local_addr } = self; +// let a = acme.try_bind().await?; - let mut crypto = a.server_config.as_ref().clone(); - crypto.alpn_protocols = vec![b"h3-29".to_vec(), b"h3-28".to_vec(), b"h3-27".to_vec(), b"h3".to_vec()]; - let config = crate::conn::quinn::ServerConfig::with_crypto(Arc::new(crypto)); - let b = QuinnListener::new(futures_util::stream::once(async {config}), local_addr).try_bind().await?; - let holdings = a.holdings().iter().chain(b.holdings().iter()).cloned().collect(); - Ok(JoinedAcceptor::new(a, b, holdings)) - } - } -} +// let mut crypto = a.server_config.as_ref().clone(); +// crypto.alpn_protocols = vec![b"h3-29".to_vec(), b"h3-28".to_vec(), b"h3-27".to_vec(), b"h3".to_vec()]; +// let config = crate::conn::quinn::ServerConfig::with_crypto(Arc::new(crypto)); +// let b = QuinnListener::new(futures_util::stream::once(async {config}), local_addr).try_bind().await?; +// let holdings = a.holdings().iter().chain(b.holdings().iter()).cloned().collect(); +// Ok(JoinedAcceptor::new(a, b, holdings)) +// } +// } +// } /// AcmeAcceptor pub struct AcmeAcceptor { @@ -277,7 +293,7 @@ where inner: T, tls_acceptor: TlsAcceptor, check_duration: Duration, - ) -> IoResult> + ) -> crate::Result> where T: Send, { diff --git a/crates/core/src/conn/acme/mod.rs b/crates/core/src/conn/acme/mod.rs index f627b3ee6..d8ce117e9 100644 --- a/crates/core/src/conn/acme/mod.rs +++ b/crates/core/src/conn/acme/mod.rs @@ -76,10 +76,11 @@ use crate::{async_trait, Depot, FlowCtrl, Handler, Request, Response}; use cache::AcmeCache; pub use config::{AcmeConfig, AcmeConfigBuilder}; pub use listener::AcmeListener; -cfg_feature! { - #![feature = "quinn"] - pub use listener::AcmeQuinnListener; -} +// TODO: waiting quinn update +// cfg_feature! { +// #![feature = "quinn"] +// pub use listener::AcmeQuinnListener; +// } /// Letsencrypt production directory url pub const LETS_ENCRYPT_PRODUCTION: &str = "https://acme-v02.api.letsencrypt.org/directory"; diff --git a/crates/core/src/conn/acme/resolver.rs b/crates/core/src/conn/acme/resolver.rs index 31c7f6b60..2e5f74afd 100644 --- a/crates/core/src/conn/acme/resolver.rs +++ b/crates/core/src/conn/acme/resolver.rs @@ -9,7 +9,7 @@ use x509_parser::prelude::{FromDer, X509Certificate}; pub(crate) const ACME_TLS_ALPN_NAME: &[u8] = b"acme-tls/1"; -#[derive(Default)] +#[derive(Default, Debug)] pub(crate) struct ResolveServerCert { pub(crate) cert: RwLock>>, pub(crate) acme_keys: RwLock>>, diff --git a/crates/core/src/conn/joined.rs b/crates/core/src/conn/joined.rs index 8a96a525b..98f3be935 100644 --- a/crates/core/src/conn/joined.rs +++ b/crates/core/src/conn/joined.rs @@ -93,7 +93,7 @@ where { type Acceptor = JoinedAcceptor; - async fn try_bind(self) -> IoResult { + async fn try_bind(self) -> crate::Result { let a = self.a.try_bind().await?; let b = self.b.try_bind().await?; let holdings = a.holdings().iter().chain(b.holdings().iter()).cloned().collect(); diff --git a/crates/core/src/conn/mod.rs b/crates/core/src/conn/mod.rs index d2ce56705..d425a85b2 100644 --- a/crates/core/src/conn/mod.rs +++ b/crates/core/src/conn/mod.rs @@ -46,6 +46,7 @@ cfg_feature! { cfg_feature! { #![feature = "quinn"] pub mod quinn; + pub mod rustls_old; pub use self::quinn::{QuinnListener, H3Connection}; } cfg_feature! { @@ -201,7 +202,7 @@ pub trait Listener { } /// Bind and returns acceptor. - async fn try_bind(self) -> IoResult; + async fn try_bind(self) -> crate::Result; /// Join current Listener with the other. #[inline] diff --git a/crates/core/src/conn/native_tls/listener.rs b/crates/core/src/conn/native_tls/listener.rs index 2f5bd4875..4c6dac571 100644 --- a/crates/core/src/conn/native_tls/listener.rs +++ b/crates/core/src/conn/native_tls/listener.rs @@ -54,7 +54,7 @@ where { type Acceptor = NativeTlsAcceptor, C, T::Acceptor, E>; - async fn try_bind(self) -> IoResult { + async fn try_bind(self) -> crate::Result { Ok(NativeTlsAcceptor::new( self.config_stream.into_stream().boxed(), self.inner.try_bind().await?, diff --git a/crates/core/src/conn/openssl/listener.rs b/crates/core/src/conn/openssl/listener.rs index 0d8575f51..03eef63fa 100644 --- a/crates/core/src/conn/openssl/listener.rs +++ b/crates/core/src/conn/openssl/listener.rs @@ -57,7 +57,7 @@ where { type Acceptor = OpensslAcceptor, C, T::Acceptor, E>; - async fn try_bind(self) -> IoResult { + async fn try_bind(self) -> crate::Result { Ok(OpensslAcceptor::new( self.config_stream.into_stream().boxed(), self.inner.try_bind().await?, diff --git a/crates/core/src/conn/quinn/builder.rs b/crates/core/src/conn/quinn/builder.rs index 7b568c73e..13738a8dd 100644 --- a/crates/core/src/conn/quinn/builder.rs +++ b/crates/core/src/conn/quinn/builder.rs @@ -150,8 +150,8 @@ async fn process_web_transport( ) -> IoResult>> { let (parts, _body) = request.into_parts(); let mut request = hyper::Request::from_parts(parts, ReqBody::None); - request.extensions_mut().insert(Mutex::new(conn)); - request.extensions_mut().insert(stream); + request.extensions_mut().insert(Arc::new(Mutex::new(conn))); + request.extensions_mut().insert(Arc::new(stream)); let mut response = hyper::service::Service::call(&hyper_handler, request) .await diff --git a/crates/core/src/conn/quinn/listener.rs b/crates/core/src/conn/quinn/listener.rs index 48253be15..d58413973 100644 --- a/crates/core/src/conn/quinn/listener.rs +++ b/crates/core/src/conn/quinn/listener.rs @@ -54,7 +54,7 @@ where { type Acceptor = QuinnAcceptor, C, C::Error>; - async fn try_bind(self) -> IoResult { + async fn try_bind(self) -> crate::Result { let Self { config_stream, local_addr, diff --git a/crates/core/src/conn/quinn/mod.rs b/crates/core/src/conn/quinn/mod.rs index b0d8b7cf9..ad71cefab 100644 --- a/crates/core/src/conn/quinn/mod.rs +++ b/crates/core/src/conn/quinn/mod.rs @@ -13,7 +13,7 @@ pub use salvo_http3::http3_quinn::ServerConfig; use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; use crate::async_trait; -use crate::conn::rustls::RustlsConfig; +use crate::conn::rustls_old::RustlsConfig; use crate::conn::{HttpBuilder, IntoConfigStream}; use crate::http::HttpConnection; use crate::service::HyperHandler; diff --git a/crates/core/src/conn/rustls/config.rs b/crates/core/src/conn/rustls/config.rs index 400503ab4..d94a04975 100644 --- a/crates/core/src/conn/rustls/config.rs +++ b/crates/core/src/conn/rustls/config.rs @@ -7,12 +7,12 @@ use std::path::Path; use std::sync::Arc; use futures_util::stream::{once, Once, Stream}; +use tokio_rustls::rustls::crypto::ring::sign::any_supported_type; +use tokio_rustls::rustls::pki_types::{CertificateDer, PrivateKeyDer}; +use tokio_rustls::rustls::server::{ClientHello, ResolvesServerCert, WebPkiClientVerifier}; +use tokio_rustls::rustls::sign::CertifiedKey; + pub use tokio_rustls::rustls::server::ServerConfig; -use tokio_rustls::rustls::server::{ - AllowAnyAnonymousOrAuthenticatedClient, AllowAnyAuthenticatedClient, ClientHello, NoClientAuth, ResolvesServerCert, -}; -use tokio_rustls::rustls::sign::{self, CertifiedKey}; -use tokio_rustls::rustls::{Certificate, PrivateKey}; use crate::conn::IntoConfigStream; @@ -84,27 +84,30 @@ impl Keycert { fn build_certified_key(&mut self) -> IoResult { let cert = rustls_pemfile::certs(&mut self.cert.as_ref()) - .map(|certs| certs.into_iter().map(Certificate).collect()) - .map_err(|_| IoError::new(ErrorKind::Other, "failed to parse tls certificates"))?; + .map(|certs| certs.into_iter().collect::>>()) + .next() + .ok_or_else(|| IoError::new(ErrorKind::Other, "failed to parse tls certificates"))?; let key = { let mut pkcs8 = rustls_pemfile::pkcs8_private_keys(&mut self.key.as_ref()) + .collect::, _>>() .map_err(|_| IoError::new(ErrorKind::Other, "failed to parse tls private keys"))?; if !pkcs8.is_empty() { - PrivateKey(pkcs8.remove(0)) + PrivateKeyDer::Pkcs8(pkcs8.remove(0)) } else { let mut rsa = rustls_pemfile::rsa_private_keys(&mut self.key.as_ref()) + .collect::, _>>() .map_err(|_| IoError::new(ErrorKind::Other, "failed to parse tls private keys"))?; if !rsa.is_empty() { - PrivateKey(rsa.remove(0)) + PrivateKeyDer::Pkcs1(rsa.remove(0)) } else { return Err(IoError::new(ErrorKind::Other, "failed to parse tls private keys")); } } }; - let key = sign::any_supported_type(&key).map_err(|_| IoError::new(ErrorKind::Other, "invalid private key"))?; + let key = any_supported_type(&key).map_err(|_| IoError::new(ErrorKind::Other, "invalid private key"))?; Ok(CertifiedKey { cert, @@ -114,7 +117,6 @@ impl Keycert { } else { None }, - sct_list: None, }) } } @@ -225,17 +227,21 @@ impl RustlsConfig { } let client_auth = match &self.client_auth { - TlsClientAuth::Off => NoClientAuth::boxed(), + TlsClientAuth::Off => WebPkiClientVerifier::no_client_auth(), TlsClientAuth::Optional(trust_anchor) => { - AllowAnyAnonymousOrAuthenticatedClient::new(read_trust_anchor(trust_anchor)?).boxed() + WebPkiClientVerifier::builder(read_trust_anchor(trust_anchor)?.into()) + .allow_unauthenticated() + .build() + .map_err(|e| IoError::new(ErrorKind::Other, format!("failed to build server config: {}", e)))? } TlsClientAuth::Required(trust_anchor) => { - AllowAnyAuthenticatedClient::new(read_trust_anchor(trust_anchor)?).boxed() + WebPkiClientVerifier::builder(read_trust_anchor(trust_anchor)?.into()) + .build() + .map_err(|e| IoError::new(ErrorKind::Other, format!("failed to build server config: {}", e)))? } }; let mut config = ServerConfig::builder() - .with_safe_defaults() .with_client_cert_verifier(client_auth) .with_cert_resolver(Arc::new(CertResolver { certified_keys, @@ -254,6 +260,7 @@ impl TryInto for RustlsConfig { } } +#[derive(Debug)] pub(crate) struct CertResolver { fallback: Option>, certified_keys: HashMap>, diff --git a/crates/core/src/conn/rustls/listener.rs b/crates/core/src/conn/rustls/listener.rs index 64c25c1d7..009ae2b36 100644 --- a/crates/core/src/conn/rustls/listener.rs +++ b/crates/core/src/conn/rustls/listener.rs @@ -56,7 +56,7 @@ where { type Acceptor = RustlsAcceptor, C, T::Acceptor, E>; - async fn try_bind(self) -> IoResult { + async fn try_bind(self) -> crate::Result { Ok(RustlsAcceptor::new( self.config_stream.into_stream().boxed(), self.inner.try_bind().await?, diff --git a/crates/core/src/conn/rustls/mod.rs b/crates/core/src/conn/rustls/mod.rs index 6dfd8a189..9bd56994d 100644 --- a/crates/core/src/conn/rustls/mod.rs +++ b/crates/core/src/conn/rustls/mod.rs @@ -1,7 +1,7 @@ //! `RustlsListener` and utils. use std::io::{Error as IoError, ErrorKind, Result as IoResult}; -use tokio_rustls::rustls::{Certificate, RootCertStore}; +use tokio_rustls::rustls::{pki_types::CertificateDer, RootCertStore}; pub(crate) mod config; pub use config::{Keycert, RustlsConfig, ServerConfig}; @@ -9,13 +9,12 @@ pub use config::{Keycert, RustlsConfig, ServerConfig}; mod listener; pub use listener::{RustlsAcceptor, RustlsListener}; -#[inline] pub(crate) fn read_trust_anchor(mut trust_anchor: &[u8]) -> IoResult { - let certs = rustls_pemfile::certs(&mut trust_anchor)?; + let certs = rustls_pemfile::certs(&mut trust_anchor).collect::>>()?; let mut store = RootCertStore::empty(); for cert in certs { store - .add(&Certificate(cert)) + .add(CertificateDer::from(cert)) .map_err(|err| IoError::new(ErrorKind::Other, err.to_string()))?; } Ok(store) @@ -27,7 +26,7 @@ mod tests { use tokio::io::{AsyncReadExt, AsyncWriteExt}; use tokio::net::TcpStream; - use tokio_rustls::rustls::{ClientConfig, ServerName}; + use tokio_rustls::rustls::{ClientConfig, pki_types::ServerName}; use tokio_rustls::TlsConnector; use super::*; @@ -51,7 +50,6 @@ mod tests { let stream = TcpStream::connect(addr).await.unwrap(); let trust_anchor = include_bytes!("../../../certs/chain.pem"); let client_config = ClientConfig::builder() - .with_safe_defaults() .with_root_certificates(read_trust_anchor(trust_anchor.as_slice()).unwrap()) .with_no_client_auth(); let connector = TlsConnector::from(Arc::new(client_config)); diff --git a/crates/core/src/conn/rustls_old/config.rs b/crates/core/src/conn/rustls_old/config.rs new file mode 100644 index 000000000..599bac0f8 --- /dev/null +++ b/crates/core/src/conn/rustls_old/config.rs @@ -0,0 +1,287 @@ +//! rustls 0.24 module +use std::collections::HashMap; +use std::fs::File; +use std::future::{ready, Ready}; +use std::io::{Error as IoError, ErrorKind, Read, Result as IoResult}; +use std::path::Path; +use std::sync::Arc; + +use futures_util::stream::{once, Once, Stream}; +pub use tokio_rustls_old::rustls::server::ServerConfig; +use tokio_rustls_old::rustls::server::{ + AllowAnyAnonymousOrAuthenticatedClient, AllowAnyAuthenticatedClient, ClientHello, NoClientAuth, ResolvesServerCert, +}; +use tokio_rustls_old::rustls::sign::{self, CertifiedKey}; +use tokio_rustls_old::rustls::{Certificate, PrivateKey}; + +use crate::conn::IntoConfigStream; + +use super::read_trust_anchor; + +/// Private key and certificate +#[derive(Clone, Debug)] +#[non_exhaustive] +pub struct Keycert { + /// Private key. + pub key: Vec, + /// Certificate. + pub cert: Vec, + /// OCSP response. + pub ocsp_resp: Vec, +} + +impl Default for Keycert { + fn default() -> Self { + Self::new() + } +} + +impl Keycert { + /// Create a new keycert. + #[inline] + pub fn new() -> Self { + Self { + key: vec![], + cert: vec![], + ocsp_resp: vec![], + } + } + /// Sets the Tls private key via File Path, returns [`IoError`] if the file cannot be open. + #[inline] + pub fn key_from_path(mut self, path: impl AsRef) -> IoResult { + let mut file = File::open(path.as_ref())?; + file.read_to_end(&mut self.key)?; + Ok(self) + } + + /// Sets the Tls private key via bytes slice. + #[inline] + pub fn key(mut self, key: impl Into>) -> Self { + self.key = key.into(); + self + } + + /// Specify the file path for the TLS certificate to use. + #[inline] + pub fn cert_from_path(mut self, path: impl AsRef) -> IoResult { + let mut file = File::open(path)?; + file.read_to_end(&mut self.cert)?; + Ok(self) + } + + /// Sets the Tls certificate via bytes slice + #[inline] + pub fn cert(mut self, cert: impl Into>) -> Self { + self.cert = cert.into(); + self + } + + /// Get ocsp_resp. + #[inline] + pub fn ocsp_resp(&self) -> &[u8] { + &self.ocsp_resp + } + + fn build_certified_key(&mut self) -> IoResult { + let cert = rustls_pemfile_old::certs(&mut self.cert.as_ref()) + .map(|certs| certs.into_iter().map(Certificate).collect()) + .map_err(|_| IoError::new(ErrorKind::Other, "failed to parse tls certificates"))?; + + let key = { + let mut pkcs8 = rustls_pemfile_old::pkcs8_private_keys(&mut self.key.as_ref()) + .map_err(|_| IoError::new(ErrorKind::Other, "failed to parse tls private keys"))?; + if !pkcs8.is_empty() { + PrivateKey(pkcs8.remove(0)) + } else { + let mut rsa = rustls_pemfile_old::rsa_private_keys(&mut self.key.as_ref()) + .map_err(|_| IoError::new(ErrorKind::Other, "failed to parse tls private keys"))?; + + if !rsa.is_empty() { + PrivateKey(rsa.remove(0)) + } else { + return Err(IoError::new(ErrorKind::Other, "failed to parse tls private keys")); + } + } + }; + + let key = sign::any_supported_type(&key).map_err(|_| IoError::new(ErrorKind::Other, "invalid private key"))?; + + Ok(CertifiedKey { + cert, + key, + ocsp: if !self.ocsp_resp.is_empty() { + Some(self.ocsp_resp.clone()) + } else { + None + }, + sct_list: None, + }) + } +} + +/// Tls client authentication configuration. +#[derive(Clone, Debug)] +pub(crate) enum TlsClientAuth { + /// No client auth. + Off, + /// Allow any anonymous or authenticated client. + Optional(Vec), + /// Allow any authenticated client. + Required(Vec), +} + +/// Builder to set the configuration for the Tls server. +#[derive(Clone, Debug)] +pub struct RustlsConfig { + fallback: Option, + keycerts: HashMap, + client_auth: TlsClientAuth, + alpn_protocols: Vec>, +} + +impl RustlsConfig { + /// Create new `RustlsConfig` + #[inline] + pub fn new(fallback: impl Into>) -> Self { + RustlsConfig { + fallback: fallback.into(), + keycerts: HashMap::new(), + client_auth: TlsClientAuth::Off, + alpn_protocols: vec![b"h2".to_vec(), b"http/1.1".to_vec()], + } + } + + /// Sets the trust anchor for optional Tls client authentication via file path. + /// + /// Anonymous and authenticated clients will be accepted. If no trust anchor is provided by any + /// of the `client_auth_` methods, then client authentication is disabled by default. + #[inline] + pub fn client_auth_optional_path(mut self, path: impl AsRef) -> IoResult { + let mut data = vec![]; + let mut file = File::open(path)?; + file.read_to_end(&mut data)?; + self.client_auth = TlsClientAuth::Optional(data); + Ok(self) + } + + /// Sets the trust anchor for optional Tls client authentication via bytes slice. + /// + /// Anonymous and authenticated clients will be accepted. If no trust anchor is provided by any + /// of the `client_auth_` methods, then client authentication is disabled by default. + pub fn client_auth_optional(mut self, trust_anchor: impl Into>) -> Self { + self.client_auth = TlsClientAuth::Optional(trust_anchor.into()); + self + } + + /// Sets the trust anchor for required Tls client authentication via file path. + /// + /// Only authenticated clients will be accepted. If no trust anchor is provided by any of the + /// `client_auth_` methods, then client authentication is disabled by default. + #[inline] + pub fn client_auth_required_path(mut self, path: impl AsRef) -> IoResult { + let mut data = vec![]; + let mut file = File::open(path)?; + file.read_to_end(&mut data)?; + self.client_auth = TlsClientAuth::Required(data); + Ok(self) + } + + /// Sets the trust anchor for required Tls client authentication via bytes slice. + /// + /// Only authenticated clients will be accepted. If no trust anchor is provided by any of the + /// `client_auth_` methods, then client authentication is disabled by default. + #[inline] + pub fn client_auth_required(mut self, trust_anchor: impl Into>) -> Self { + self.client_auth = TlsClientAuth::Required(trust_anchor.into()); + self + } + + /// Add a new keycert to be used for the given SNI `name`. + #[inline] + pub fn keycert(mut self, name: impl Into, keycert: Keycert) -> Self { + self.keycerts.insert(name.into(), keycert); + self + } + + /// Add a new keycert to be used for the given SNI `name`. + #[inline] + pub fn alpn_protocols(mut self, alpn_protocols: impl Into>>) -> Self { + self.alpn_protocols = alpn_protocols.into(); + self + } + + /// ServerConfig + pub(crate) fn build_server_config(mut self) -> IoResult { + let fallback = self + .fallback + .as_mut() + .map(|fallback| fallback.build_certified_key()) + .transpose()? + .map(Arc::new); + let mut certified_keys = HashMap::new(); + + for (name, keycert) in &mut self.keycerts { + certified_keys.insert(name.clone(), Arc::new(keycert.build_certified_key()?)); + } + + let client_auth = match &self.client_auth { + TlsClientAuth::Off => NoClientAuth::boxed(), + TlsClientAuth::Optional(trust_anchor) => { + AllowAnyAnonymousOrAuthenticatedClient::new(read_trust_anchor(trust_anchor)?).boxed() + } + TlsClientAuth::Required(trust_anchor) => { + AllowAnyAuthenticatedClient::new(read_trust_anchor(trust_anchor)?).boxed() + } + }; + + let mut config = ServerConfig::builder() + .with_safe_defaults() + .with_client_cert_verifier(client_auth) + .with_cert_resolver(Arc::new(CertResolver { + certified_keys, + fallback, + })); + config.alpn_protocols = self.alpn_protocols; + Ok(config) + } +} + +impl TryInto for RustlsConfig { + type Error = IoError; + + fn try_into(self) -> IoResult { + self.build_server_config() + } +} + +pub(crate) struct CertResolver { + fallback: Option>, + certified_keys: HashMap>, +} + +impl ResolvesServerCert for CertResolver { + fn resolve(&self, client_hello: ClientHello) -> Option> { + client_hello + .server_name() + .and_then(|name| self.certified_keys.get(name).map(Arc::clone)) + .or_else(|| self.fallback.clone()) + } +} + +impl IntoConfigStream for RustlsConfig { + type Stream = Once>; + + fn into_stream(self) -> Self::Stream { + once(ready(self)) + } +} +impl IntoConfigStream for T +where + T: Stream + Send + 'static, +{ + type Stream = T; + + fn into_stream(self) -> Self { + self + } +} \ No newline at end of file diff --git a/crates/core/src/conn/rustls_old/mod.rs b/crates/core/src/conn/rustls_old/mod.rs new file mode 100644 index 000000000..9b551ca52 --- /dev/null +++ b/crates/core/src/conn/rustls_old/mod.rs @@ -0,0 +1,19 @@ +//! `RustlsListener` and utils. +use std::io::{Error as IoError, ErrorKind, Result as IoResult}; + +use tokio_rustls_old::rustls::{Certificate, RootCertStore}; + +pub(crate) mod config; +pub use config::{Keycert, RustlsConfig, ServerConfig}; + +#[inline] +pub(crate) fn read_trust_anchor(mut trust_anchor: &[u8]) -> IoResult { + let certs = rustls_pemfile_old::certs(&mut trust_anchor)?; + let mut store = RootCertStore::empty(); + for cert in certs { + store + .add(&Certificate(cert)) + .map_err(|err| IoError::new(ErrorKind::Other, err.to_string()))?; + } + Ok(store) +} \ No newline at end of file diff --git a/crates/core/src/conn/tcp.rs b/crates/core/src/conn/tcp.rs index f94451713..1074316a0 100644 --- a/crates/core/src/conn/tcp.rs +++ b/crates/core/src/conn/tcp.rs @@ -102,8 +102,8 @@ where { type Acceptor = TcpAcceptor; - async fn try_bind(self) -> IoResult { - TokioTcpListener::bind(self.local_addr).await?.try_into() + async fn try_bind(self) -> crate::Result { + Ok(TokioTcpListener::bind(self.local_addr).await?.try_into()?) } } /// `TcpAcceptor` is used to accept a TCP connection. diff --git a/crates/core/src/conn/unix.rs b/crates/core/src/conn/unix.rs index bb438b5d5..722bffba4 100644 --- a/crates/core/src/conn/unix.rs +++ b/crates/core/src/conn/unix.rs @@ -9,7 +9,7 @@ use http::uri::Scheme; use tokio::net::{UnixListener as TokioUnixListener, UnixStream}; use nix::unistd::{Gid, chown, Uid}; -use crate::async_trait; +use crate::{Error, async_trait}; use crate::conn::{Holding, HttpBuilder}; use crate::http::{HttpConnection, Version}; use crate::service::HyperHandler; @@ -55,12 +55,12 @@ where { type Acceptor = UnixAcceptor; - async fn try_bind(self) -> IoResult { + async fn try_bind(self) -> crate::Result { let inner = match (self.permissions, self.owner) { (Some(permissions), Some((uid, gid))) => { let inner = TokioUnixListener::bind(self.path.clone())?; set_permissions(self.path.clone(), permissions)?; - chown(self.path.as_ref().as_os_str().into(), uid, gid)?; + chown(self.path.as_ref().as_os_str().into(), uid, gid).map_err(Error::other)?; inner } (Some(permissions), None) => { @@ -70,7 +70,7 @@ where } (None, Some((uid, gid))) => { let inner = TokioUnixListener::bind(self.path.clone())?; - chown(self.path.as_ref().as_os_str().into(), uid, gid)?; + chown(self.path.as_ref().as_os_str().into(), uid, gid).map_err(Error::other)?; inner } (None, None) => TokioUnixListener::bind(self.path)?, diff --git a/crates/core/src/error.rs b/crates/core/src/error.rs index d7c50a0a8..c9c908f92 100644 --- a/crates/core/src/error.rs +++ b/crates/core/src/error.rs @@ -23,6 +23,8 @@ pub enum Error { Io(IoError), /// SerdeJson error. SerdeJson(serde_json::Error), + /// Invalid URI error. + InvalidUri(http::uri::InvalidUri), #[cfg(feature = "quinn")] #[cfg_attr(docsrs, doc(cfg(feature = "quinn")))] /// H3 error. @@ -55,6 +57,7 @@ impl Display for Error { Self::HttpStatus(e) => Display::fmt(e, f), Self::Io(e) => Display::fmt(e, f), Self::SerdeJson(e) => Display::fmt(e, f), + Self::InvalidUri(e) => Display::fmt(e, f), #[cfg(feature = "quinn")] Self::H3(e) => Display::fmt(e, f), #[cfg(feature = "anyhow")] @@ -98,6 +101,12 @@ impl From for Error { Error::Io(e) } } +impl From for Error { + #[inline] + fn from(e: http::uri::InvalidUri) -> Error { + Error::InvalidUri(e) + } +} impl From for Error { #[inline] fn from(e: serde_json::Error) -> Error { diff --git a/crates/core/src/http/form.rs b/crates/core/src/http/form.rs index 091a9d2ab..68ee18eea 100644 --- a/crates/core/src/http/form.rs +++ b/crates/core/src/http/form.rs @@ -17,7 +17,7 @@ use tokio::fs::File; use tokio::io::AsyncWriteExt; use crate::http::body::ReqBody; -use crate::http::header::{HeaderMap, CONTENT_TYPE}; +use crate::http::header::{HeaderMap, HeaderName, HeaderValue, CONTENT_TYPE}; use crate::http::ParseError; /// The extracted text fields and uploaded files from a `multipart/form-data` request. @@ -66,7 +66,8 @@ impl FormData { let mut multipart = Multipart::new(body, boundary); while let Some(mut field) = multipart.next_field().await? { if let Some(name) = field.name().map(|s| s.to_owned()) { - if field.headers().get(CONTENT_TYPE).is_some() { + if field.headers().get("content-type").is_some() { + //TODO: use CONTENT_TYPE after multer updated form_data.files.insert(name, FilePart::create(&mut field).await?); } else { form_data.fields.insert(name, field.text().await?); @@ -162,9 +163,17 @@ impl FilePart { size += chunk.len() as u64; file.write_all(&chunk).await?; } + //TODO: Will remove after multer updated + let mut headers = HeaderMap::with_capacity(field.headers().len()); + headers.extend(field.headers().into_iter().map(|(name, value)| { + let name = HeaderName::from_bytes(name.as_ref()).unwrap(); + let value = HeaderValue::from_bytes(value.as_ref()).unwrap(); + (name, value) + })); Ok(FilePart { name, - headers: field.headers().to_owned(), + //TODO: use field.headers().to_owned() after multer updated + headers, path, size, temp_dir, diff --git a/crates/core/src/http/request.rs b/crates/core/src/http/request.rs index d7513f47a..a92358de4 100644 --- a/crates/core/src/http/request.rs +++ b/crates/core/src/http/request.rs @@ -2,6 +2,8 @@ use std::error::Error as StdError; use std::fmt::{self, Formatter}; +#[cfg(feature = "quinn")] +use std::sync::Arc; use bytes::Bytes; #[cfg(feature = "cookie")] @@ -455,14 +457,14 @@ impl Request { let stream = self.extensions.remove::, Bytes>>(); if conn.is_some() && stream.is_some() { let session = crate::proto::WebTransportSession::accept(stream.unwrap(), conn.unwrap().into_inner().unwrap()).await?; - self.extensions.insert(session); + self.extensions.insert(Arc::new(session)); Ok(self.extensions.get_mut::>().unwrap()) } else { if let Some(conn) = conn { - self.extensions_mut().insert(conn); + self.extensions_mut().insert(Arc::new(conn)); } if let Some(stream) = stream { - self.extensions_mut().insert(stream); + self.extensions_mut().insert(Arc::new(stream)); } Err(crate::Error::Other("invalid web transport".into())) } diff --git a/crates/core/src/rt.rs b/crates/core/src/rt.rs index 1b33c57ab..df00d2da6 100644 --- a/crates/core/src/rt.rs +++ b/crates/core/src/rt.rs @@ -7,5 +7,5 @@ pub use hyper::rt::*; /// Tokio runtimes pub mod tokio { - pub use salvo_utils::rt::{TokioExecutor, TokioIo}; + pub use hyper_util::rt::{TokioExecutor, TokioIo}; } diff --git a/crates/core/src/service.rs b/crates/core/src/service.rs index f43a78a41..6373b509e 100644 --- a/crates/core/src/service.rs +++ b/crates/core/src/service.rs @@ -235,19 +235,19 @@ impl HyperHandler { .extensions .remove::>() { - res.extensions.insert(session); + res.extensions.insert(Arc::new(session)); } if let Some(conn) = req .extensions .remove::>>() { - res.extensions.insert(conn); + res.extensions.insert(Arc::new(conn)); } if let Some(stream) = req .extensions .remove::, Bytes>>() { - res.extensions.insert(stream); + res.extensions.insert(Arc::new(stream)); } } res diff --git a/crates/core/src/tower_compat.rs b/crates/core/src/tower_compat.rs index 569ae8782..edecf4810 100644 --- a/crates/core/src/tower_compat.rs +++ b/crates/core/src/tower_compat.rs @@ -4,11 +4,12 @@ use std::fmt; use std::future::Future; use std::io::{Error as IoError, ErrorKind}; use std::marker::PhantomData; +use std::sync::Arc; use std::task::{Context, Poll}; use futures_util::future::{BoxFuture, FutureExt}; use http_body_util::BodyExt; -use hyper::body::{Body, Bytes, Frame}; +use hyper::body::{Body, Bytes}; use tower::buffer::Buffer; use tower::{Layer, Service, ServiceExt}; @@ -51,6 +52,7 @@ where SB::Error: StdError + Send + Sync + 'static, E: StdError + Send + Sync + 'static, Svc: Service, Response = hyper::Response, Future = Fut> + Send + Sync + Clone + 'static, + Svc::Error: StdError + Send + Sync + 'static, Fut: Future, E>> + Send + 'static, { async fn handle(&self, req: &mut Request, _depot: &mut Depot, res: &mut Response, _ctrl: &mut FlowCtrl) { @@ -71,22 +73,13 @@ where let hyper_res = match svc.call(hyper_req).await { Ok(hyper_res) => hyper_res, - Err(_) => { - tracing::error!("call tower service failed."); - res.render(StatusError::internal_server_error().cause("call tower service failed.")); + Err(e) => { + tracing::error!(error = ?e, "call tower service failed: {}", e); + res.render(StatusError::internal_server_error().cause(format!("call tower service failed: {}", e))); return; } } - .map(|res| { - ResBody::Boxed(Box::pin( - res.map_frame(|f| match f.into_data() { - //TODO: should use Frame::map_data after new version of hyper is released. - Ok(data) => Frame::data(data.into()), - Err(frame) => Frame::trailers(frame.into_trailers().expect("frame must be trailers")), - }) - .map_err(|e| e.into()), - )) - }); + .map(|res| ResBody::Boxed(Box::pin(res.map_frame(|f| f.map_data(|data| data.into())).map_err(|e|e.into())))); res.merge_hyper(hyper_res); } @@ -132,12 +125,13 @@ impl Service> for FlowCtrlService { } fn call(&mut self, mut hyper_req: hyper::Request) -> Self::Future { + let ctx = hyper_req.extensions_mut().remove::>().and_then(Arc::into_inner); let Some(FlowCtrlInContext { mut ctrl, mut request, mut depot, mut response, - }) = hyper_req.extensions_mut().remove::() + }) = ctx else { return futures_util::future::ready(Err(IoError::new( ErrorKind::Other, @@ -150,7 +144,7 @@ impl Service> for FlowCtrlService { ctrl.call_next(&mut request, &mut depot, &mut response).await; response .extensions - .insert(FlowCtrlOutContext::new(ctrl, request, depot)); + .insert(Arc::new(FlowCtrlOutContext::new(ctrl, request, depot))); Ok(response.strip_to_hyper()) }) } @@ -212,36 +206,30 @@ where std::mem::take(depot), std::mem::take(res), ); - hyper_req.extensions_mut().insert(ctx); + hyper_req.extensions_mut().insert(Arc::new(ctx)); let mut hyper_res = match svc.call(hyper_req).await { Ok(hyper_res) => hyper_res, - Err(_) => { - tracing::error!("call tower service failed."); - res.render(StatusError::internal_server_error().cause("call tower service failed.")); + Err(e) => { + tracing::error!(error = ?e, "call tower service failed: {}", e); + res.render(StatusError::internal_server_error().cause(format!("call tower service failed: {}", e))); return; } } - .map(|res| { - ResBody::Boxed(Box::pin( - res.map_frame(|f| match f.into_data() { - //TODO: should use Frame::map_data after new version of hyper is released. - Ok(data) => Frame::data(data.into()), - Err(frame) => Frame::trailers(frame.into_trailers().expect("frame must be trailers")), - }) - .map_err(|e| e.into()), - )) - }); + .map(|res| ResBody::Boxed(Box::pin(res.map_frame(|f| f.map_data(|data| data.into())).map_err(|e|e.into())))); let origin_depot = depot; let origin_ctrl = ctrl; - if let Some(FlowCtrlOutContext { ctrl, request, depot }) = - hyper_res.extensions_mut().remove::() + + let ctx = hyper_res.extensions_mut().remove::>().and_then(Arc::into_inner); + if let Some(FlowCtrlOutContext { ctrl, request, depot }) = ctx { *origin_depot = depot; *origin_ctrl = ctrl; *req = request; } else { - tracing::debug!("`FlowCtrlOutContext` does not exists in response extensions, `FlowCtrlService` may not be used."); + tracing::debug!( + "`FlowCtrlOutContext` does not exists in response extensions, `FlowCtrlService` may not be used." + ); } res.merge_hyper(hyper_res); diff --git a/crates/jwt-auth/Cargo.toml b/crates/jwt-auth/Cargo.toml index 456177ede..37ff52bc3 100644 --- a/crates/jwt-auth/Cargo.toml +++ b/crates/jwt-auth/Cargo.toml @@ -20,12 +20,15 @@ rustdoc-args = ["--cfg", "docsrs"] [features] default = [] full = ["oidc"] -oidc = ["dep:reqwest"] +oidc = ["dep:bytes", "dep:hyper-tls", "dep:hyper-util", "dep:http-body-util"] [dependencies] base64 = { workspace = true } +bytes = { workspace = true, optional = true } jsonwebtoken = { workspace = true } -reqwest = { workspace = true, optional = true, features = ["rustls-tls", "json"] } +http-body-util = { workspace = true, optional = true } +hyper-tls = { workspace = true, optional = true } +hyper-util = { workspace = true, optional = true, features = ["client-legacy", "http1", "http2", "tokio"] } salvo_core = { workspace = true, features = ["cookie"] } serde = { workspace = true, features = ["derive"] } serde_json = { workspace = true } diff --git a/crates/jwt-auth/src/lib.rs b/crates/jwt-auth/src/lib.rs index 0af5455f5..92a8b8e86 100644 --- a/crates/jwt-auth/src/lib.rs +++ b/crates/jwt-auth/src/lib.rs @@ -59,11 +59,18 @@ const ALL_METHODS: [Method; 9] = [ /// JwtAuthError #[derive(Debug, Error)] pub enum JwtAuthError { - /// HTTP request failed + /// HTTP client error #[cfg(feature = "oidc")] #[cfg_attr(docsrs, doc(cfg(feature = "oidc")))] - #[error("HTTP request failed")] - ReqwestError(#[from] reqwest::Error), + #[error("ClientError")] + ClientError(#[from] hyper_util::client::legacy::Error), + + /// Error happened in hyper. + #[cfg(feature = "oidc")] + #[cfg_attr(docsrs, doc(cfg(feature = "oidc")))] + #[error("HyperError")] + Hyper(#[from] salvo_core::hyper::Error), + /// InvalidUri #[error("InvalidUri")] InvalidUri(#[from] salvo_core::http::uri::InvalidUri), diff --git a/crates/jwt-auth/src/oidc/mod.rs b/crates/jwt-auth/src/oidc/mod.rs index 8c6e67b81..e2511e1d9 100644 --- a/crates/jwt-auth/src/oidc/mod.rs +++ b/crates/jwt-auth/src/oidc/mod.rs @@ -1,20 +1,22 @@ -//! Oidc(OpenID Connect) support module +//! Oidc(OpenID Connect) supports. use std::future::Future; use std::str::FromStr; use std::sync::Arc; -use std::time::Duration; use std::time::SystemTime; use std::time::UNIX_EPOCH; use base64::engine::general_purpose::URL_SAFE_NO_PAD; use base64::Engine; +use bytes::Bytes; +use http_body_util::{BodyExt, Full}; +use hyper_tls::HttpsConnector; +use hyper_util::client::legacy::{connect::HttpConnector, Client}; +use hyper_util::rt::TokioExecutor; use jsonwebtoken::jwk::{Jwk, JwkSet}; use jsonwebtoken::{Algorithm, DecodingKey, TokenData, Validation}; -use reqwest::Client; -use salvo_core::async_trait; -use salvo_core::http::header::CACHE_CONTROL; -use salvo_core::Depot; +use salvo_core::http::{header::CACHE_CONTROL, uri::Uri}; +use salvo_core::{async_trait, Depot}; use serde::de::DeserializeOwned; use serde::Deserialize; use tokio::sync::{Notify, RwLock}; @@ -25,11 +27,13 @@ mod cache; pub use cache::{CachePolicy, CacheState, JwkSetStore, UpdateAction}; +pub(super) type HyperClient = Client, Full>; + /// ConstDecoder will decode token with a static secret. #[derive(Clone)] pub struct OidcDecoder { issuer: String, - http_client: reqwest::Client, + http_client: HyperClient, cache: Arc>, cache_state: Arc, notifier: Arc, @@ -62,7 +66,7 @@ where /// The issuer URL of the token. eg: `https://xx-xx.clerk.accounts.dev` pub issuer: T, /// The http client for the decoder. - pub http_client: Option, + pub http_client: Option, /// The validation options for the decoder. pub validation: Option, } @@ -79,7 +83,7 @@ where } } /// Set the http client for the decoder. - pub fn http_client(mut self, client: reqwest::Client) -> Self { + pub fn http_client(mut self, client: HyperClient) -> Self { self.http_client = Some(client); self } @@ -106,7 +110,7 @@ where let cache_state = Arc::new(CacheState::new()); let http_client = - http_client.unwrap_or_else(|| Client::builder().timeout(Duration::from_secs(30)).build().unwrap()); + http_client.unwrap_or_else(|| Client::builder(TokioExecutor::new()).build(HttpsConnector::new())); let decoder = OidcDecoder { issuer, http_client, @@ -139,8 +143,9 @@ impl OidcDecoder { format!("{}/.well-known/openid-configuration", &self.issuer) } async fn get_config(&self) -> Result { - let res = self.http_client.get(self.config_url()).send().await?; - let config = res.json().await?; + let res = self.http_client.get(self.config_url().parse::()?).await?; + let body = res.into_body().collect().await?.to_bytes(); + let config = serde_json::from_slice(&body)?; Ok(config) } async fn jwks_uri(&self) -> Result { @@ -149,10 +154,10 @@ impl OidcDecoder { /// Triggers an HTTP Request to get a fresh `JwkSet` async fn get_jwks(&self) -> Result { - let uri = &self.jwks_uri().await?; + let uri = self.jwks_uri().await?.parse::()?; // Get the jwks endpoint tracing::debug!("Requesting JWKS From Uri: {uri}"); - let res = self.http_client.get(uri).send().await?; + let res = self.http_client.get(uri).await?; let cache_policy = { // Determine it from the cache_control header @@ -160,10 +165,11 @@ impl OidcDecoder { let cache_policy = CachePolicy::from_header_val(cache_control); Some(cache_policy) }; + let jwks = res.into_body().collect().await?.to_bytes(); let fetched_at = current_time(); Ok(JwkSetFetch { - jwks: res.json().await?, + jwks: serde_json::from_slice(&jwks)?, cache_policy, fetched_at, }) diff --git a/crates/oapi/Cargo.toml b/crates/oapi/Cargo.toml index 592bec20b..829ab6cdc 100644 --- a/crates/oapi/Cargo.toml +++ b/crates/oapi/Cargo.toml @@ -3,7 +3,7 @@ name = "salvo-oapi" version = { workspace = true } edition = "2021" description = "OpenApi support for Salvo web framework" -readme = { workspace = true } +readme = "./README.md" license = { workspace = true } documentation = "https://docs.rs/salvo-oapi/" homepage = { workspace = true } diff --git a/crates/oapi/README.md b/crates/oapi/README.md new file mode 100644 index 000000000..5aa29be76 --- /dev/null +++ b/crates/oapi/README.md @@ -0,0 +1,11 @@ +# salvo-oapi + +## Library to Provide a OpenAPI supports for Salvo. + +This is offical crate, so you can enable it in `Cargo.toml` like this: + +```toml +salvo = { version = "*", features=["oapi"] } +``` + +[![Docs](https://docs.rs/salvo-oapi/badge.svg)](https://docs.rs/salvo-oapi) diff --git a/crates/otel/Cargo.toml b/crates/otel/Cargo.toml index 345d29252..05f41cc05 100644 --- a/crates/otel/Cargo.toml +++ b/crates/otel/Cargo.toml @@ -25,6 +25,7 @@ opentelemetry-http = { workspace = true } opentelemetry-semantic-conventions = { workspace = true } opentelemetry = { workspace = true, features = ["metrics"] } salvo_core = { workspace = true, default-features = false } +headers03 = {version = "0.3", package = "headers"} [dev-dependencies] salvo_core = { workspace = true, features = ["test"] } diff --git a/crates/otel/src/tracing.rs b/crates/otel/src/tracing.rs index 4c5848af5..cac5f2618 100644 --- a/crates/otel/src/tracing.rs +++ b/crates/otel/src/tracing.rs @@ -1,3 +1,4 @@ +use headers03::{HeaderMap, HeaderName, HeaderValue}; use opentelemetry::trace::{FutureExt, Span, SpanKind, TraceContextExt, Tracer}; use opentelemetry::{global, Context}; use opentelemetry_http::HeaderExtractor; @@ -26,8 +27,15 @@ where async fn handle(&self, req: &mut Request, depot: &mut Depot, res: &mut Response, ctrl: &mut FlowCtrl) { let remote_addr = req.remote_addr().to_string(); - let parent_cx = - global::get_text_map_propagator(|propagator| propagator.extract(&HeaderExtractor(req.headers()))); + //TODO: Will remove after opentelemetry_http updated + let mut headers = HeaderMap::with_capacity(req.headers().len()); + headers.extend(req.headers().into_iter().map(|(name, value)| { + let name = HeaderName::from_bytes(name.as_ref()).unwrap(); + let value = HeaderValue::from_bytes(value.as_ref()).unwrap(); + (name, value) + })); + + let parent_cx = global::get_text_map_propagator(|propagator| propagator.extract(&HeaderExtractor(&headers))); let mut attributes = Vec::new(); attributes.push(resource::TELEMETRY_SDK_NAME.string(env!("CARGO_CRATE_NAME"))); diff --git a/crates/proxy/Cargo.toml b/crates/proxy/Cargo.toml index f0c98d210..cfe7395e8 100644 --- a/crates/proxy/Cargo.toml +++ b/crates/proxy/Cargo.toml @@ -24,8 +24,8 @@ tracing = { workspace = true } tokio = { workspace = true } fastrand = { workspace = true } hyper = { workspace = true, features = ["server", "http1", "http2"] } -reqwest = { workspace = true, features = ["rustls-tls", "stream"] } -salvo-utils = { workspace = true, features = ["runtime"] } +hyper-util = { workspace = true, features = ["tokio", "http1", "http2", "client-legacy"] } +hyper-tls = { workspace = true } percent-encoding = { workspace = true } [dev-dependencies] diff --git a/crates/proxy/src/clients.rs b/crates/proxy/src/clients.rs new file mode 100644 index 000000000..a3e4caea8 --- /dev/null +++ b/crates/proxy/src/clients.rs @@ -0,0 +1,74 @@ +use hyper::upgrade::OnUpgrade; +use hyper_tls::HttpsConnector; +use hyper_util::client::legacy::{connect::HttpConnector, Client as HyperUtilClient}; +use hyper_util::rt::TokioExecutor; +use salvo_core::http::{ReqBody, ResBody, StatusCode}; +use salvo_core::rt::tokio::TokioIo; +use salvo_core::{async_trait, Error}; +use tokio::io::copy_bidirectional; + +use super::{HyperRequest, HyperResponse}; + +/// A [`Client`] implementation based on [`hyper_util::client::legacy::Client`]. +pub struct HyperClient { + inner: HyperUtilClient, ReqBody>, +} +impl Default for HyperClient { + fn default() -> Self { + Self { + inner: HyperUtilClient::builder(TokioExecutor::new()).build(HttpsConnector::new()), + } + } +} +impl HyperClient { + /// Create a new `HyperClient` with the given `HyperClient`. + pub fn new(inner: HyperUtilClient, ReqBody>) -> Self { + Self { inner } + } +} + +#[async_trait] +impl super::Client for HyperClient { + type Error = salvo_core::Error; + + async fn execute( + &self, + proxied_request: HyperRequest, + request_upgraded: Option, + ) -> Result { + let request_upgrade_type = crate::get_upgrade_type(proxied_request.headers()).map(|s| s.to_owned()); + + let mut response = self.inner.request(proxied_request).await.map_err(Error::other)?; + + if response.status() == StatusCode::SWITCHING_PROTOCOLS { + let response_upgrade_type = crate::get_upgrade_type(response.headers()); + if request_upgrade_type.as_deref() == response_upgrade_type { + let response_upgraded = hyper::upgrade::on(&mut response).await?; + if let Some(request_upgraded) = request_upgraded { + tokio::spawn(async move { + match request_upgraded.await { + Ok(request_upgraded) => { + let mut request_upgraded = TokioIo::new(request_upgraded); + let mut response_upgraded = TokioIo::new(response_upgraded); + if let Err(e) = copy_bidirectional(&mut response_upgraded, &mut request_upgraded).await + { + tracing::error!(error = ?e, "coping between upgraded connections failed."); + } + } + Err(e) => { + tracing::error!(error = ?e, "upgrade request failed."); + } + } + }); + } else { + return Err(Error::other("request does not have an upgrade extension.")); + } + } else { + return Err(Error::other("upgrade type mismatch")); + } + } + Ok(response.map(ResBody::Hyper)) + } +} + +//TODO: ReqwestClient diff --git a/crates/proxy/src/lib.rs b/crates/proxy/src/lib.rs index 327615026..5e01af3f4 100644 --- a/crates/proxy/src/lib.rs +++ b/crates/proxy/src/lib.rs @@ -11,17 +11,17 @@ #![warn(rustdoc::broken_intra_doc_links)] use std::convert::{Infallible, TryFrom}; +use std::error::Error as StdError; -use futures_util::TryStreamExt; use hyper::upgrade::OnUpgrade; use percent_encoding::{utf8_percent_encode, CONTROLS}; -use reqwest::Client; use salvo_core::http::header::{HeaderMap, HeaderName, HeaderValue, CONNECTION, HOST, UPGRADE}; use salvo_core::http::uri::Uri; use salvo_core::http::{ReqBody, ResBody, StatusCode}; -use salvo_core::rt::tokio::TokioIo; use salvo_core::{async_trait, BoxedError, Depot, Error, FlowCtrl, Handler, Request, Response}; -use tokio::io::copy_bidirectional; + +mod clients; +pub use clients::*; type HyperRequest = hyper::Request; type HyperResponse = hyper::Response; @@ -35,10 +35,19 @@ pub(crate) fn encode_url_path(path: &str) -> String { .join("/") } +/// Client trait. +#[async_trait] +pub trait Client: Send + Sync + 'static { + /// Error type. + type Error: StdError + Send + Sync + 'static; + /// Elect a upstream to process current request. + async fn execute(&self, req: HyperRequest, upgraded: Option) -> Result; +} + /// Upstreams trait. pub trait Upstreams: Send + Sync + 'static { /// Error type. - type Error; + type Error: StdError + Send + Sync + 'static; /// Elect a upstream to process current request. fn elect(&self) -> Result<&str, Self::Error>; } @@ -100,34 +109,39 @@ pub fn default_url_query_getter(req: &Request, _depot: &Depot) -> Option /// Handler that can proxy request to other server. #[non_exhaustive] - -pub struct Proxy { +pub struct Proxy +where + U: Upstreams, + C: Client, +{ /// Upstreams list. pub upstreams: U, /// [`Client`] for proxy. - pub client: Client, + pub client: C, /// Url path getter. pub url_path_getter: UrlPartGetter, /// Url query getter. pub url_query_getter: UrlPartGetter, } +impl Proxy +where + U: Upstreams, + U::Error: Into, +{ + /// Create new `Proxy` which use default hyper util client. + pub fn default_hyper_client(upstreams: U) -> Self { + Proxy::new(upstreams, HyperClient::default()) + } +} -impl Proxy +impl Proxy where U: Upstreams, U::Error: Into, + C: Client, { /// Create new `Proxy` with upstreams list. - pub fn new(upstreams: U) -> Self { - Proxy { - upstreams, - client: Client::new(), - url_path_getter: Box::new(default_url_path_getter), - url_query_getter: Box::new(default_url_query_getter), - } - } - /// Create new `Proxy` with upstreams list and [`Client`]. - pub fn with_client(upstreams: U, client: Client) -> Self { + pub fn new(upstreams: U, client: C) -> Self { Proxy { upstreams, client, @@ -169,12 +183,12 @@ where /// Get client reference. #[inline] - pub fn client(&self) -> &Client { + pub fn client(&self) -> &C { &self.client } /// Get client mutable reference. #[inline] - pub fn client_mut(&mut self) -> &mut Client { + pub fn client_mut(&mut self) -> &mut C { &mut self.client } @@ -234,80 +248,22 @@ where // } build.body(req.take_body()).map_err(Error::other) } - - #[inline] - async fn call_proxied_server( - &self, - proxied_request: HyperRequest, - request_upgraded: Option, - ) -> Result { - let request_upgrade_type = get_upgrade_type(proxied_request.headers()).map(|s| s.to_owned()); - - let proxied_request = - proxied_request.map(|s| reqwest::Body::wrap_stream(s.map_ok(|s| s.into_data().unwrap_or_default()))); - let response = self - .client - .execute(proxied_request.try_into().map_err(Error::other)?) - .await - .map_err(Error::other)?; - - let res_headers = response.headers().clone(); - let hyper_response = hyper::Response::builder() - .status(response.status()) - .version(response.version()); - - let mut hyper_response = if response.status() == StatusCode::SWITCHING_PROTOCOLS { - let response_upgrade_type = get_upgrade_type(response.headers()); - - if request_upgrade_type.as_deref() == response_upgrade_type { - let mut response_upgraded = response - .upgrade() - .await - .map_err(|e| Error::other(format!("response does not have an upgrade extension. {}", e)))?; - if let Some(request_upgraded) = request_upgraded { - tokio::spawn(async move { - match request_upgraded.await { - Ok(request_upgraded) => { - let mut request_upgraded = TokioIo::new(request_upgraded); - if let Err(e) = copy_bidirectional(&mut response_upgraded, &mut request_upgraded).await - { - tracing::error!(error = ?e, "coping between upgraded connections failed"); - } - } - Err(e) => { - tracing::error!(error = ?e, "upgrade request failed"); - } - } - }); - } else { - return Err(Error::other("request does not have an upgrade extension")); - } - } else { - return Err(Error::other("upgrade type mismatch")); - } - hyper_response.body(ResBody::None).map_err(Error::other)? - } else { - hyper_response - .body(ResBody::stream(response.bytes_stream())) - .map_err(Error::other)? - }; - *hyper_response.headers_mut() = res_headers; - Ok(hyper_response) - } } #[async_trait] -impl Handler for Proxy +impl Handler for Proxy where U: Upstreams, U::Error: Into, + C: Client, { #[inline] async fn handle(&self, req: &mut Request, depot: &mut Depot, res: &mut Response, ctrl: &mut FlowCtrl) { match self.build_proxied_request(req, depot) { Ok(proxied_request) => { match self - .call_proxied_server(proxied_request, req.extensions_mut().remove()) + .client + .execute(proxied_request, req.extensions_mut().remove()) .await { Ok(response) => { @@ -326,10 +282,10 @@ where res.body(body); } Err(e) => { - tracing::error!(error = ?e, uri = ?req.uri(), "get response data failed"); + tracing::error!( error = ?e, uri = ?req.uri(), "get response data failed: {}", e); res.status_code(StatusCode::INTERNAL_SERVER_ERROR); } - }; + } } Err(e) => { tracing::error!(error = ?e, "build proxied request failed"); @@ -375,7 +331,7 @@ mod tests { #[test] fn test_upstreams_elect() { let upstreams = vec!["https://www.example.com", "https://www.example2.com"]; - let proxy = Proxy::new(upstreams.clone()); + let proxy = Proxy::default_hyper_client(upstreams.clone()); let elected_upstream = proxy.upstreams().elect().unwrap(); assert!(upstreams.contains(&elected_upstream)); } @@ -391,8 +347,9 @@ mod tests { #[tokio::test] async fn test_proxy() { - let router = - Router::new().push(Router::with_path("rust/<**rest>").goal(Proxy::new(vec!["https://www.rust-lang.org"]))); + let router = Router::new().push( + Router::with_path("rust/<**rest>").goal(Proxy::default_hyper_client(vec!["https://www.rust-lang.org"])), + ); let content = TestClient::get("http://127.0.0.1:5801/rust/tools/install") .send(router) @@ -404,7 +361,7 @@ mod tests { } #[test] fn test_others() { - let mut handler = Proxy::new(["https://www.bing.com"]); + let mut handler = Proxy::default_hyper_client(["https://www.bing.com"]); assert_eq!(handler.upstreams().len(), 1); assert_eq!(handler.upstreams_mut().len(), 1); } diff --git a/crates/rate-limiter/src/lib.rs b/crates/rate-limiter/src/lib.rs index 6bad8e1c9..5ec017474 100644 --- a/crates/rate-limiter/src/lib.rs +++ b/crates/rate-limiter/src/lib.rs @@ -169,7 +169,7 @@ where let quota = match self.quota_getter.get(&key).await { Ok(quota) => quota, Err(e) => { - tracing::error!(error = ?e, "RateLimiter error"); + tracing::error!(error = ?e, "RateLimiter error: {}", e); res.status_code(StatusCode::INTERNAL_SERVER_ERROR); ctrl.skip_rest(); return; @@ -178,7 +178,7 @@ where let mut guard = match self.store.load_guard(&key, &self.guard).await { Ok(guard) => guard, Err(e) => { - tracing::error!(error = ?e, "RateLimiter error"); + tracing::error!(error = ?e, "RateLimiter error: {}", e); res.status_code(StatusCode::INTERNAL_SERVER_ERROR); ctrl.skip_rest(); return; diff --git a/crates/salvo/src/lib.rs b/crates/salvo/src/lib.rs index 06b331676..e64fb6f43 100644 --- a/crates/salvo/src/lib.rs +++ b/crates/salvo/src/lib.rs @@ -1,4 +1,6 @@ -//! Salvo is a powerful and simple Rust web server framework. Read more: +//! Salvo is a powerful and simple Rust web server framework. +//! +//! Read more: #![doc(html_favicon_url = "https://salvo.rs/favicon-32x32.png")] #![doc(html_logo_url = "https://salvo.rs/images/logo.svg")] diff --git a/examples/Cargo.toml b/examples/Cargo.toml index 9a8723fc4..46e4edcea 100644 --- a/examples/Cargo.toml +++ b/examples/Cargo.toml @@ -1,6 +1,6 @@ [workspace] members = ["*"] -exclude = ["db-prisma-orm", "target", "temp"] +exclude = ["db-prisma-orm", "target", "temp", "acme-http01-quinn", "webtransport-acme-http01"] resolver = "2" [workspace.package] @@ -28,7 +28,7 @@ rbdc-mysql = "4.3" rbs = "4.3" async-std = "1.12" async-trait = "0.1" -hyper = "=1.0.0-rc.4" +hyper = "1" sea-orm = "0.12" sea-orm-migration = "0.12" eyre = "0.6" @@ -39,8 +39,6 @@ opentelemetry-http = "0.10.0" opentelemetry-jaeger = "0.20.0" opentelemetry-prometheus = "0.14.0" opentelemetry_sdk = "0.21" -utoipa = "3.3.0" -utoipa-swagger-ui = "*" tokio-stream = "0.1.14" async-stream = "0.3.5" futures-util = { version = "0.3", default-features = true } @@ -59,4 +57,4 @@ url = "2" chrono = "0.4" sqlx = "0.7" rust-embed = "8" -time = "0.3" \ No newline at end of file +time = "0.3" diff --git a/examples/acme-http01-quinn/src/main.rs b/examples/acme-http01-quinn/src/main.rs index a35996add..54ee35737 100644 --- a/examples/acme-http01-quinn/src/main.rs +++ b/examples/acme-http01-quinn/src/main.rs @@ -1,21 +1,23 @@ -use salvo::prelude::*; -#[handler] -async fn hello() -> &'static str { - "Hello World" -} +// TODO: waiting quinn update +// use salvo::prelude::*; -#[tokio::main] -async fn main() { - tracing_subscriber::fmt().init(); +// #[handler] +// async fn hello() -> &'static str { +// "Hello World" +// } - let mut router = Router::new().get(hello); - let listener = TcpListener::new("0.0.0.0:443") - .acme() - .cache_path("temp/letsencrypt") - .add_domain("test.salvo.rs") - .http01_challege(&mut router) - .quinn("0.0.0.0:443"); - let acceptor = listener.join(TcpListener::new("0.0.0.0:80")).bind().await; - Server::new(acceptor).serve(router).await; -} +// #[tokio::main] +// async fn main() { +// tracing_subscriber::fmt().init(); + +// let mut router = Router::new().get(hello); +// let listener = TcpListener::new("0.0.0.0:443") +// .acme() +// .cache_path("temp/letsencrypt") +// .add_domain("test.salvo.rs") +// .http01_challege(&mut router) +// .quinn("0.0.0.0:443"); +// let acceptor = listener.join(TcpListener::new("0.0.0.0:80")).bind().await; +// Server::new(acceptor).serve(router).await; +// } diff --git a/examples/hello-h3/src/main.rs b/examples/hello-h3/src/main.rs index e6114c168..137832af3 100644 --- a/examples/hello-h3/src/main.rs +++ b/examples/hello-h3/src/main.rs @@ -1,4 +1,5 @@ use salvo::conn::rustls::{Keycert, RustlsConfig}; +use salvo::conn::rustls_old; use salvo::prelude::*; #[handler] @@ -14,9 +15,11 @@ async fn main() { let router = Router::new().get(hello); let config = RustlsConfig::new(Keycert::new().cert(cert.as_slice()).key(key.as_slice())); - let listener = TcpListener::new(("0.0.0.0", 5800)).rustls(config.clone()); + let config_old = + rustls_old::RustlsConfig::new(rustls_old::Keycert::new().cert(cert.as_slice()).key(key.as_slice())); + let listener = TcpListener::new(("0.0.0.0", 5800)).rustls(config); - let acceptor = QuinnListener::new(config, ("0.0.0.0", 5800)) + let acceptor = QuinnListener::new(config_old, ("0.0.0.0", 5800)) .join(listener) .bind() .await; diff --git a/examples/jwt-clerk/src/main.rs b/examples/jwt-clerk/src/main.rs index 913e4b80f..70465c8e5 100644 --- a/examples/jwt-clerk/src/main.rs +++ b/examples/jwt-clerk/src/main.rs @@ -24,7 +24,7 @@ async fn main() { let acceptor = TcpListener::new("0.0.0.0:5800").bind().await; let router = Router::new() .push(Router::with_hoop(auth_handler).path("welcome").get(welcome)) - .push(Router::with_path("<**rest>").goal(Proxy::new(vec!["http://localhost:5801"]))); + .push(Router::with_path("<**rest>").goal(Proxy::default_hyper_client(vec!["http://localhost:5801"]))); Server::new(acceptor).serve(router).await; } #[handler] diff --git a/examples/jwt-oidc-clerk/src/main.rs b/examples/jwt-oidc-clerk/src/main.rs index 320106b23..6da8fe3e5 100644 --- a/examples/jwt-oidc-clerk/src/main.rs +++ b/examples/jwt-oidc-clerk/src/main.rs @@ -24,7 +24,7 @@ async fn main() { let acceptor = TcpListener::new("0.0.0.0:5800").bind().await; let router = Router::new() .push(Router::with_hoop(auth_handler).path("welcome").get(welcome)) - .push(Router::with_path("<**rest>").goal(Proxy::new(vec!["http://localhost:5801"]))); + .push(Router::with_path("<**rest>").goal(Proxy::default_hyper_client(vec!["http://localhost:5801"]))); Server::new(acceptor).serve(router).await; } #[handler] diff --git a/examples/otel-jaeger/src/server1.rs b/examples/otel-jaeger/src/server1.rs index b2719dad2..66866b92a 100644 --- a/examples/otel-jaeger/src/server1.rs +++ b/examples/otel-jaeger/src/server1.rs @@ -5,9 +5,7 @@ use opentelemetry_sdk::{propagation::TraceContextPropagator, trace::Tracer}; use opentelemetry::trace::{FutureExt, SpanKind, TraceContextExt, Tracer as _}; use opentelemetry::{global, KeyValue}; use opentelemetry_http::HeaderInjector; -use reqwest::Client; -use reqwest::Url; -use salvo::http::Method; +use reqwest::{Url, Method, Client}; use salvo::otel::{Metrics, Tracing}; use salvo::prelude::*; diff --git a/examples/proxy-react-app/src/main.rs b/examples/proxy-react-app/src/main.rs index 7c1cdf154..b3543ebf1 100644 --- a/examples/proxy-react-app/src/main.rs +++ b/examples/proxy-react-app/src/main.rs @@ -5,7 +5,7 @@ use salvo::proxy::Proxy; async fn main() { tracing_subscriber::fmt().init(); - let router = Router::with_path("<**rest>").goal(Proxy::new(vec!["http://localhost:3000"])); + let router = Router::with_path("<**rest>").goal(Proxy::default_hyper_client(vec!["http://localhost:3000"])); println!("{:?}", router); let acceptor = TcpListener::new("0.0.0.0:5800").bind().await; diff --git a/examples/proxy-simple/src/main.rs b/examples/proxy-simple/src/main.rs index ac07251d0..4f1a6da99 100644 --- a/examples/proxy-simple/src/main.rs +++ b/examples/proxy-simple/src/main.rs @@ -8,15 +8,15 @@ async fn main() { let router = Router::new() .push( Router::new() - .host("0.0.0.0") + .host("127.0.0.1") .path("<**rest>") - .goal(Proxy::new("https://www.rust-lang.org")), + .goal(Proxy::default_hyper_client("https://www.rust-lang.org")), ) .push( Router::new() .host("localhost") .path("<**rest>") - .goal(Proxy::new("https://crates.io")), + .goal(Proxy::default_hyper_client("https://crates.io")), ); let acceptor = TcpListener::new("0.0.0.0:5800").bind().await; diff --git a/examples/proxy-websocket/src/main.rs b/examples/proxy-websocket/src/main.rs index ea5de7c7d..fd36a3d4c 100644 --- a/examples/proxy-websocket/src/main.rs +++ b/examples/proxy-websocket/src/main.rs @@ -5,9 +5,9 @@ use salvo::proxy::Proxy; async fn main() { tracing_subscriber::fmt().init(); - let router = Router::with_path("<**rest>").goal(Proxy::new(vec!["http://localhost:5800"])); + let router = Router::with_path("<**rest>").goal(Proxy::default_hyper_client(vec!["http://localhost:5800"])); println!("{:?}", router); - tracing::info!("Run `cargo run --bin example-ws-chat` to start websocket chat server"); + tracing::info!("Run `cargo run --bin example-websocket-chat` to start websocket chat server"); let acceptor = TcpListener::new("0.0.0.0:8888").bind().await; Server::new(acceptor).serve(router).await; } diff --git a/examples/todos-utoipa/Cargo.toml b/examples/todos-utoipa/Cargo.toml index aa7189216..b2a89650f 100644 --- a/examples/todos-utoipa/Cargo.toml +++ b/examples/todos-utoipa/Cargo.toml @@ -4,7 +4,6 @@ version.workspace = true edition.workspace = true publish.workspace = true - [dependencies] once_cell = "1" salvo = { workspace = true, features = ["affix", "size-limiter"] } @@ -13,5 +12,5 @@ serde_json = "1" tokio = { workspace = true, features = ["macros"] } tracing.workspace = true tracing-subscriber.workspace = true -utoipa .workspace = true -utoipa-swagger-ui .workspace = true +utoipa = "4" +utoipa-swagger-ui = "*" diff --git a/examples/websocket-chat/src/main.rs b/examples/websocket-chat/src/main.rs index 239cc1cda..45b070a00 100644 --- a/examples/websocket-chat/src/main.rs +++ b/examples/websocket-chat/src/main.rs @@ -25,7 +25,6 @@ async fn main() { let router = Router::new() .goal(index) .push(Router::with_path("chat").goal(user_connected)); - let acceptor = TcpListener::new("0.0.0.0:5800").bind().await; Server::new(acceptor).serve(router).await; } diff --git a/examples/webtransport-acme-http01/src/main.rs b/examples/webtransport-acme-http01/src/main.rs index 58e1e15bd..060650d94 100644 --- a/examples/webtransport-acme-http01/src/main.rs +++ b/examples/webtransport-acme-http01/src/main.rs @@ -1,131 +1,132 @@ -use std::time::Duration; - -use anyhow::{Context, Result}; -use bytes::{BufMut, Bytes, BytesMut}; -use salvo::prelude::*; -use salvo::proto::webtransport; -use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}; -use tokio::pin; - -macro_rules! log_result { - ($expr:expr) => { - if let Err(err) = $expr { - tracing::error!("{err:?}"); - } - }; -} -async fn echo_stream(send: T, recv: R) -> anyhow::Result<()> -where - T: AsyncWrite, - R: AsyncRead, -{ - pin!(send); - pin!(recv); - - tracing::info!("Got stream"); - let mut buf = Vec::new(); - recv.read_to_end(&mut buf).await?; - - let message = Bytes::from(buf); - send_chunked(send, message).await?; - - Ok(()) -} -// Used to test that all chunks arrive properly as it is easy to write an impl which only reads and -// writes the first chunk. -async fn send_chunked(mut send: impl AsyncWrite + Unpin, data: Bytes) -> anyhow::Result<()> { - for chunk in data.chunks(4) { - tokio::time::sleep(Duration::from_millis(100)).await; - tracing::info!("Sending {chunk:?}"); - send.write_all(chunk).await?; - } - - Ok(()) -} - -#[handler] -async fn connect(req: &mut Request) -> Result<(), salvo::Error> { - let session = req.web_transport_mut().await.unwrap(); - let session_id = session.session_id(); - - // This will open a bidirectional stream and send a message to the client right after connecting! - let stream = session.open_bi(session_id).await?; - - tokio::spawn(async move { - log_result!(open_bidi_test(stream).await); - }); - loop { - tokio::select! { - datagram = session.accept_datagram() => { - let datagram = datagram?; - if let Some((_, datagram)) = datagram { - tracing::info!("Responding with {datagram:?}"); - // Put something before to make sure encoding and decoding works and don't just - // pass through - let mut resp = BytesMut::from(&b"Response: "[..]); - resp.put(datagram); - - session.send_datagram(resp.freeze())?; - tracing::info!("Finished sending datagram"); - } - } - uni_stream = session.accept_uni() => { - let (id, stream) = uni_stream?.unwrap(); - - let send = session.open_uni(id).await?; - tokio::spawn( async move { log_result!(echo_stream(send, stream).await); }); - } - stream = session.accept_bi() => { - if let Some(webtransport::server::AcceptedBi::BidiStream(_, stream)) = stream? { - let (send, recv) = salvo::proto::quic::BidiStream::split(stream); - tokio::spawn( async move { log_result!(echo_stream(send, recv).await); }); - } - } - else => { - break - } - } - } - - tracing::info!("Finished handling session"); - - Ok(()) -} - -async fn open_bidi_test(mut stream: S) -> anyhow::Result<()> -where - S: Unpin + AsyncRead + AsyncWrite, -{ - tracing::info!("Opening bidirectional stream"); - - stream - .write_all(b"Hello from a server initiated bidi stream") - .await - .context("Failed to respond")?; - - let mut resp = Vec::new(); - stream.shutdown().await?; - stream.read_to_end(&mut resp).await?; - - tracing::info!("Got response from client: {resp:?}"); - - Ok(()) -} - -#[tokio::main] -async fn main() { - tracing_subscriber::fmt().init(); - - let mut router = Router::new().push(Router::with_path("counter").goal(connect)).push( - Router::with_path("<*path>").get(StaticDir::new(["webtransport/static", "./static"]).defaults("client.html")), - ); - - let listener = TcpListener::new("0.0.0.0:443") - .acme() - .cache_path("temp/letsencrypt") - .add_domain("test.salvo.rs") - .http01_challege(&mut router) - .quinn("0.0.0.0:443"); - let acceptor = listener.join(TcpListener::new("0.0.0.0:80")).bind().await; - Server::new(acceptor).serve(router).await; -} +// TODO: waiting quinn update +// use std::time::Duration; + +// use anyhow::{Context, Result}; +// use bytes::{BufMut, Bytes, BytesMut}; +// use salvo::prelude::*; +// use salvo::proto::webtransport; +// use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}; +// use tokio::pin; + +// macro_rules! log_result { +// ($expr:expr) => { +// if let Err(err) = $expr { +// tracing::error!("{err:?}"); +// } +// }; +// } +// async fn echo_stream(send: T, recv: R) -> anyhow::Result<()> +// where +// T: AsyncWrite, +// R: AsyncRead, +// { +// pin!(send); +// pin!(recv); + +// tracing::info!("Got stream"); +// let mut buf = Vec::new(); +// recv.read_to_end(&mut buf).await?; + +// let message = Bytes::from(buf); +// send_chunked(send, message).await?; + +// Ok(()) +// } +// // Used to test that all chunks arrive properly as it is easy to write an impl which only reads and +// // writes the first chunk. +// async fn send_chunked(mut send: impl AsyncWrite + Unpin, data: Bytes) -> anyhow::Result<()> { +// for chunk in data.chunks(4) { +// tokio::time::sleep(Duration::from_millis(100)).await; +// tracing::info!("Sending {chunk:?}"); +// send.write_all(chunk).await?; +// } + +// Ok(()) +// } + +// #[handler] +// async fn connect(req: &mut Request) -> Result<(), salvo::Error> { +// let session = req.web_transport_mut().await.unwrap(); +// let session_id = session.session_id(); + +// // This will open a bidirectional stream and send a message to the client right after connecting! +// let stream = session.open_bi(session_id).await?; + +// tokio::spawn(async move { +// log_result!(open_bidi_test(stream).await); +// }); +// loop { +// tokio::select! { +// datagram = session.accept_datagram() => { +// let datagram = datagram?; +// if let Some((_, datagram)) = datagram { +// tracing::info!("Responding with {datagram:?}"); +// // Put something before to make sure encoding and decoding works and don't just +// // pass through +// let mut resp = BytesMut::from(&b"Response: "[..]); +// resp.put(datagram); + +// session.send_datagram(resp.freeze())?; +// tracing::info!("Finished sending datagram"); +// } +// } +// uni_stream = session.accept_uni() => { +// let (id, stream) = uni_stream?.unwrap(); + +// let send = session.open_uni(id).await?; +// tokio::spawn( async move { log_result!(echo_stream(send, stream).await); }); +// } +// stream = session.accept_bi() => { +// if let Some(webtransport::server::AcceptedBi::BidiStream(_, stream)) = stream? { +// let (send, recv) = salvo::proto::quic::BidiStream::split(stream); +// tokio::spawn( async move { log_result!(echo_stream(send, recv).await); }); +// } +// } +// else => { +// break +// } +// } +// } + +// tracing::info!("Finished handling session"); + +// Ok(()) +// } + +// async fn open_bidi_test(mut stream: S) -> anyhow::Result<()> +// where +// S: Unpin + AsyncRead + AsyncWrite, +// { +// tracing::info!("Opening bidirectional stream"); + +// stream +// .write_all(b"Hello from a server initiated bidi stream") +// .await +// .context("Failed to respond")?; + +// let mut resp = Vec::new(); +// stream.shutdown().await?; +// stream.read_to_end(&mut resp).await?; + +// tracing::info!("Got response from client: {resp:?}"); + +// Ok(()) +// } + +// #[tokio::main] +// async fn main() { +// tracing_subscriber::fmt().init(); + +// let mut router = Router::new().push(Router::with_path("counter").goal(connect)).push( +// Router::with_path("<*path>").get(StaticDir::new(["webtransport/static", "./static"]).defaults("client.html")), +// ); + +// let listener = TcpListener::new("0.0.0.0:443") +// .acme() +// .cache_path("temp/letsencrypt") +// .add_domain("test.salvo.rs") +// .http01_challege(&mut router) +// .quinn("0.0.0.0:443"); +// let acceptor = listener.join(TcpListener::new("0.0.0.0:80")).bind().await; +// Server::new(acceptor).serve(router).await; +// } diff --git a/examples/webtransport/src/main.rs b/examples/webtransport/src/main.rs index 1ac49baf5..4ee6411cd 100644 --- a/examples/webtransport/src/main.rs +++ b/examples/webtransport/src/main.rs @@ -3,6 +3,7 @@ use std::time::Duration; use anyhow::{Context, Result}; use bytes::{BufMut, Bytes, BytesMut}; use salvo::conn::rustls::{Keycert, RustlsConfig}; +use salvo::conn::rustls_old; use salvo::prelude::*; use salvo::proto::webtransport; use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}; @@ -125,9 +126,11 @@ async fn main() { ); let config = RustlsConfig::new(Keycert::new().cert(cert.as_slice()).key(key.as_slice())); - let listener = TcpListener::new(("0.0.0.0", 5800)).rustls(config.clone()); + let config_old = + rustls_old::RustlsConfig::new(rustls_old::Keycert::new().cert(cert.as_slice()).key(key.as_slice())); + let listener = TcpListener::new(("0.0.0.0", 5800)).rustls(config); - let acceptor = QuinnListener::new(config, ("0.0.0.0", 5800)) + let acceptor = QuinnListener::new(config_old, ("0.0.0.0", 5800)) .join(listener) .bind() .await;