diff --git a/admin/src/args.rs b/admin/src/args.rs index 3fb3312b1..8b1e878fc 100644 --- a/admin/src/args.rs +++ b/admin/src/args.rs @@ -18,7 +18,7 @@ pub enum Command { /// Try to revive projects in the crashed state Revive, - /// Manage custom domains + /// Manage domains #[command(subcommand)] Acme(AcmeCommand), @@ -59,13 +59,13 @@ pub enum AcmeCommand { credentials: PathBuf, }, - /// Automate certificate renewal for a FQDN - AutomateCertificateRenewal { - /// Fqdn to automate certificate renewal for + /// Renew the certificate for a FQDN + RenewProjectCertificate { + /// Fqdn to renew the certificate for #[arg(long)] fqdn: String, - /// Project to automate certificate renewal for + /// Project to renew the certificate for #[arg(long)] project: ProjectName, @@ -74,6 +74,14 @@ pub enum AcmeCommand { #[arg(long)] credentials: PathBuf, }, + + /// Renew certificate for the shuttle gateway + RenewGatewayCertificate { + /// Path to acme credentials file + /// This should have been created with `acme create-account` + #[arg(long)] + credentials: PathBuf, + }, } #[derive(Subcommand, Debug)] diff --git a/admin/src/client.rs b/admin/src/client.rs index 3f42dc080..bb55c2b92 100644 --- a/admin/src/client.rs +++ b/admin/src/client.rs @@ -39,7 +39,7 @@ impl Client { self.post(&path, Some(credentials)).await } - pub async fn acme_renew_certificate( + pub async fn acme_renew_custom_domain_certificate( &self, fqdn: &str, project_name: &ProjectName, @@ -49,6 +49,14 @@ impl Client { self.post(&path, Some(credentials)).await } + pub async fn acme_renew_gateway_certificate( + &self, + credentials: &serde_json::Value, + ) -> Result { + let path = "/admin//acme/gateway/renew".to_string(); + self.post(&path, Some(credentials)).await + } + pub async fn get_projects(&self) -> Result> { self.get("/admin/projects").await } diff --git a/admin/src/main.rs b/admin/src/main.rs index 16a444272..ac3c8815e 100644 --- a/admin/src/main.rs +++ b/admin/src/main.rs @@ -50,7 +50,7 @@ async fn main() { .await .expect("to get a certificate challenge response") } - Command::Acme(AcmeCommand::AutomateCertificateRenewal { + Command::Acme(AcmeCommand::RenewProjectCertificate { fqdn, project, credentials, @@ -60,7 +60,17 @@ async fn main() { serde_json::from_str(&credentials).expect("to parse content of credentials file"); client - .acme_renew_certificate(&fqdn, &project, &credentials) + .acme_renew_custom_domain_certificate(&fqdn, &project, &credentials) + .await + .expect("to get a certificate challenge response") + } + Command::Acme(AcmeCommand::RenewGatewayCertificate { credentials }) => { + let credentials = fs::read_to_string(credentials).expect("to read credentials file"); + let credentials = + serde_json::from_str(&credentials).expect("to parse content of credentials file"); + + client + .acme_renew_gateway_certificate(&credentials) .await .expect("to get a certificate challenge response") } diff --git a/gateway/src/acme.rs b/gateway/src/acme.rs index b89d97d8f..502402c60 100644 --- a/gateway/src/acme.rs +++ b/gateway/src/acme.rs @@ -32,6 +32,11 @@ pub struct CustomDomain { pub private_key: String, } +pub enum AcmeCredentials<'a> { + InMemory(AccountCredentials<'a>), + GatewayState, +} + /// An ACME client implementation that completes Http01 challenges /// It is safe to clone this type as it functions as a singleton #[derive(Clone, Default)] @@ -98,10 +103,10 @@ impl AcmeClient { &self, identifier: &str, challenge_type: ChallengeType, - account: &Account, + creds: AccountCredentials<'_>, ) -> Result<(String, String), AcmeClientError> { trace!(identifier, "requesting acme certificate"); - + let account = AccountWrapper::from(creds).0; let (mut order, state) = account .new_order(&NewOrder { identifiers: &[Identifier::Dns(identifier.to_string())], @@ -129,7 +134,6 @@ impl AcmeClient { self.complete_challenge(challenge_type, authorization, &mut order) .await?; - let certificate = { let mut params = CertificateParams::new(vec![identifier.to_owned()]); params.distinguished_name = DistinguishedName::new(); @@ -138,11 +142,11 @@ impl AcmeClient { AcmeClientError::CertificateCreation })? }; + let signing_request = certificate.serialize_request_der().map_err(|error| { error!(%error, "failed to create certificate signing request"); AcmeClientError::CertificateSigning })?; - let certificate_chain = order .finalize(&signing_request, &state.finalize) .await @@ -294,7 +298,7 @@ impl<'a> From> for AccountWrapper { "failed to convert acme credentials into account" ); }) - .expect("Account credentials malformed"), + .expect("Malformed account credentials."), ) } } diff --git a/gateway/src/api/latest.rs b/gateway/src/api/latest.rs index f3ea3b28f..a9f6207c8 100644 --- a/gateway/src/api/latest.rs +++ b/gateway/src/api/latest.rs @@ -28,7 +28,7 @@ use uuid::Uuid; use x509_parser::parse_x509_certificate; use x509_parser::time::ASN1Time; -use crate::acme::{AccountWrapper, AcmeClient, CustomDomain}; +use crate::acme::{AcmeClient, CustomDomain}; use crate::auth::{Admin, ScopedUser, User}; use crate::project::{Project, ProjectCreating}; use crate::task::{self, BoxedTask, TaskResult}; @@ -303,7 +303,7 @@ async fn create_acme_account( } #[instrument(skip_all, fields(%project_name, %fqdn))] -async fn request_acme_certificate( +async fn request_custom_domain_acme_certificate( _: Admin, State(RouterState { service, sender, .. @@ -317,15 +317,14 @@ async fn request_acme_certificate( .parse() .map_err(|_err| Error::from(ErrorKind::InvalidCustomDomain))?; - let account = AccountWrapper::from(credentials).0; let (certs, private_key) = service - .get_or_create_certificate(&fqdn, acme_client.clone(), &account, &project_name) + .create_custom_domain_certificate(&fqdn, &acme_client, &project_name, credentials) .await?; // Destroy and recreate the project with the new domain. service .new_task() - .project(project_name) + .project(project_name.clone()) .and_then(task::destroy()) .and_then(task::run_until_done()) .and_then(task::run({ @@ -349,15 +348,16 @@ async fn request_acme_certificate( .serve_pem(&fqdn.to_string(), Cursor::new(buf)) .await?; - Ok("certificate created".to_string()) + Ok(format!( + "New certificate created for {} project.", + project_name + )) } #[instrument(skip_all, fields(%project_name, %fqdn))] -async fn renew_acme_certificate( +async fn renew_custom_domain_acme_certificate( _: Admin, - State(RouterState { - service, sender, .. - }): State, + State(RouterState { service, .. }): State, Extension(acme_client): Extension, Extension(resolver): Extension>, Path((project_name, fqdn)): Path<(ProjectName, String)>, @@ -366,75 +366,64 @@ async fn renew_acme_certificate( let fqdn: FQDN = fqdn .parse() .map_err(|_err| Error::from(ErrorKind::InvalidCustomDomain))?; - - let account = AccountWrapper::from(credentials).0; - let fqdn_clone = fqdn.clone(); - service - .new_task() - .project(project_name) - .and_then(task::run(move |ctx| { - let service_clone = service.clone(); - let fqdn_clone_clone = fqdn_clone.clone(); - let acme_client_clone = acme_client.clone(); - let account_clone = account.clone(); - let resolve_clone = resolver.clone(); - async move { - // If project not ready yet, don't attept certificare renewal. - if !ctx.state.is_ready() { - return TaskResult::Pending(ctx.state); - } - - // Try retrieve the current certificate if any. - match service_clone - .project_details_for_custom_domain(&fqdn_clone_clone) + // Try retrieve the current certificate if any. + match service.project_details_for_custom_domain(&fqdn).await { + Ok(CustomDomain { certificate, .. }) => { + let (_, x509_cert_chain) = parse_x509_certificate(certificate.as_bytes()) + .unwrap_or_else(|_| { + panic!( + "Malformed existing X509 certificate for {} project.", + project_name + ) + }); + let diff = x509_cert_chain + .validity() + .not_after + .sub(ASN1Time::now()) + .unwrap(); + // If current certificate validity less_or_eq than 30 days, attempt + // renewal. + if diff.whole_days() <= 30 { + return match acme_client + .create_certificate(&fqdn.to_string(), ChallengeType::Http01, credentials) .await { - Ok(CustomDomain { certificate, .. }) => { - let (_, x509_cert_chain) = - parse_x509_certificate(certificate.as_bytes()).unwrap(); - let diff = x509_cert_chain - .validity() - .not_after - .sub(ASN1Time::now()) - .unwrap(); - // If current certificate validity less_or_eq than 30 days, attempt - // renewal. - if diff.whole_days() <= 30 { - return match acme_client_clone - .create_certificate( - &fqdn_clone_clone.to_string(), - ChallengeType::Http01, - &account_clone, - ) - .await - { - // If successfuly created, save the certificate in memory to be - // served in the future. - Ok((certs, private_key)) => { - let mut buf = Vec::new(); - buf.extend(certs.as_bytes()); - buf.extend(private_key.as_bytes()); - match resolve_clone - .serve_pem(&fqdn_clone_clone.to_string(), Cursor::new(buf)) - .await - { - Ok(_) => TaskResult::Pending(ctx.state), - Err(err) => TaskResult::Err(err), - } - } - Err(err) => TaskResult::Err(err.into()), - }; - }; - TaskResult::Pending(ctx.state) + // If successfuly created, save the certificate in memory to be + // served in the future. + Ok((certs, private_key)) => { + let mut buf = Vec::new(); + buf.extend(certs.as_bytes()); + buf.extend(private_key.as_bytes()); + resolver + .serve_pem(&fqdn.to_string(), Cursor::new(buf)) + .await?; + Ok(format!("Certificate renewed for {} project.", project_name)) } - Err(err) => TaskResult::Err(err), - } + Err(err) => Err(err.into()), + }; + } else { + Ok(format!( + "Certificate renewal skipped, {} project certificate still valid for {} days.", + project_name, diff + )) } - })) - .send(&sender) - .await?; + } + Err(err) => Err(err), + } +} - Ok("automated certificate renewal started".to_string()) +#[instrument(skip_all)] +async fn renew_gateway_acme_certificate( + _: Admin, + State(RouterState { service, .. }): State, + Extension(acme_client): Extension, + Extension(resolver): Extension>, + AxumJson(credentials): AxumJson>, +) -> Result { + service + .renew_certificate(&acme_client, resolver, credentials) + .await; + Ok("Renewed the gate certificate.".to_string()) } async fn get_projects( @@ -487,11 +476,15 @@ impl ApiBuilder { .route("/admin/acme/:email", post(create_acme_account)) .route( "/admin/acme/request/:project_name/:fqdn", - post(request_acme_certificate), + post(request_custom_domain_acme_certificate), ) .route( "/admin/acme/renew/:project_name/:fqdn", - post(renew_acme_certificate), + post(renew_custom_domain_acme_certificate), + ) + .route( + "/admin/acme/gateway/renew", + post(renew_gateway_acme_certificate), ) .layer(Extension(acme)) .layer(Extension(resolver)); @@ -596,7 +589,7 @@ pub mod tests { #[tokio::test] async fn api_create_get_delete_projects() -> anyhow::Result<()> { let world = World::new().await; - let service = Arc::new(GatewayService::init(world.args(), world.pool()).await); + let service = Arc::new(GatewayService::init(world.args(), world.pool(), "".into()).await); let (sender, mut receiver) = channel::(256); tokio::spawn(async move { @@ -744,7 +737,7 @@ pub mod tests { #[tokio::test] async fn api_create_get_users() -> anyhow::Result<()> { let world = World::new().await; - let service = Arc::new(GatewayService::init(world.args(), world.pool()).await); + let service = Arc::new(GatewayService::init(world.args(), world.pool(), "".into()).await); let (sender, mut receiver) = channel::(256); tokio::spawn(async move { @@ -837,7 +830,7 @@ pub mod tests { #[tokio::test(flavor = "multi_thread")] async fn status() { let world = World::new().await; - let service = Arc::new(GatewayService::init(world.args(), world.pool()).await); + let service = Arc::new(GatewayService::init(world.args(), world.pool(), "".into()).await); let (sender, mut receiver) = channel::(1); let (ctl_send, ctl_recv) = oneshot::channel(); diff --git a/gateway/src/lib.rs b/gateway/src/lib.rs index 25e0895c0..6f80a4923 100644 --- a/gateway/src/lib.rs +++ b/gateway/src/lib.rs @@ -665,7 +665,7 @@ pub mod tests { #[tokio::test] async fn end_to_end() { let world = World::new().await; - let service = Arc::new(GatewayService::init(world.args(), world.pool()).await); + let service = Arc::new(GatewayService::init(world.args(), world.pool(), "".into()).await); let worker = Worker::new(); let (log_out, mut log_in) = channel(256); diff --git a/gateway/src/main.rs b/gateway/src/main.rs index 5cda56b9f..86c1ef976 100644 --- a/gateway/src/main.rs +++ b/gateway/src/main.rs @@ -1,9 +1,7 @@ use clap::Parser; -use fqdn::FQDN; use futures::prelude::*; -use instant_acme::{Account, AccountCredentials, ChallengeType}; use opentelemetry::global; -use shuttle_gateway::acme::{AcmeClient, AcmeClientError, CustomDomain}; +use shuttle_gateway::acme::{AcmeClient, AcmeCredentials, CustomDomain}; use shuttle_gateway::api::latest::{ApiBuilder, SVC_DEGRADED_THRESHOLD}; use shuttle_gateway::args::StartArgs; use shuttle_gateway::args::{Args, Commands, InitArgs, UseTls}; @@ -11,21 +9,19 @@ use shuttle_gateway::auth::Key; use shuttle_gateway::proxy::UserServiceBuilder; use shuttle_gateway::service::{GatewayService, MIGRATIONS}; use shuttle_gateway::task; -use shuttle_gateway::tls::{make_tls_acceptor, ChainAndPrivateKey, GatewayCertResolver}; +use shuttle_gateway::tls::make_tls_acceptor; use shuttle_gateway::worker::{Worker, WORKER_QUEUE_SIZE}; use sqlx::migrate::MigrateDatabase; use sqlx::sqlite::{SqliteConnectOptions, SqliteJournalMode, SqliteSynchronous}; use sqlx::{query, Sqlite, SqlitePool}; use std::io::{self, Cursor}; -use std::ops::Sub; -use std::path::{Path, PathBuf}; + +use std::path::PathBuf; use std::str::FromStr; use std::sync::Arc; use std::time::Duration; use tracing::{debug, error, info, info_span, trace, warn, Instrument}; use tracing_subscriber::{fmt, prelude::*, EnvFilter}; -use x509_parser::parse_x509_certificate; -use x509_parser::time::ASN1Time; #[tokio::main(flavor = "multi_thread")] async fn main() -> io::Result<()> { @@ -81,7 +77,7 @@ async fn main() -> io::Result<()> { } async fn start(db: SqlitePool, fs: PathBuf, args: StartArgs) -> io::Result<()> { - let gateway = Arc::new(GatewayService::init(args.context.clone(), db).await); + let gateway = Arc::new(GatewayService::init(args.context.clone(), db, fs).await); let worker = Worker::new(); @@ -110,8 +106,7 @@ async fn start(db: SqlitePool, fs: PathBuf, args: StartArgs) -> io::Result<()> { .map_err(|err| error!("worker error: {}", err)), ); - // Every 60secs go over all `::Ready` projects and check their - // health + // Every 60 secs go over all `::Ready` projects and check their health. let ambulance_handle = tokio::spawn({ let gateway = Arc::clone(&gateway); let sender = sender.clone(); @@ -119,7 +114,7 @@ async fn start(db: SqlitePool, fs: PathBuf, args: StartArgs) -> io::Result<()> { loop { tokio::time::sleep(Duration::from_secs(60)).await; if sender.capacity() < WORKER_QUEUE_SIZE - SVC_DEGRADED_THRESHOLD { - // if degraded, don't stack more health checks + // If degraded, don't stack more health checks. warn!( sender.capacity = sender.capacity(), "skipping health checks" @@ -144,8 +139,8 @@ async fn start(db: SqlitePool, fs: PathBuf, args: StartArgs) -> io::Result<()> { .send(&sender) .await { - // we wait for the check to be done before - // queuing up the next one + // We wait for the check to be done before + // queuing up the next one. handle.await } } @@ -195,24 +190,16 @@ async fn start(db: SqlitePool, fs: PathBuf, args: StartArgs) -> io::Result<()> { .unwrap(); } - let fs_clone = fs.clone(); let resolver_clone = resolver.clone(); tokio::spawn(async move { - // make sure we have a certificate for ourselves - request_gateway_certs( - fs, - args.context.proxy_fqdn.clone(), - acme_client.clone(), - resolver_clone.clone(), - ) - .await; - automate_renewal( - fs_clone, - args.context.proxy_fqdn.clone(), - acme_client.clone(), - resolver_clone, - ) - .await + // Make sure we have a certificate for ourselves. + gateway + .fetch_certificate( + &acme_client, + resolver_clone.clone(), + AcmeCredentials::GatewayState, + ) + .await; }); } else { warn!("TLS is disabled in the proxy service. This is only acceptable in testing, and should *never* be used in deployments."); @@ -253,94 +240,3 @@ async fn init(db: SqlitePool, args: InitArgs) -> io::Result<()> { println!("`{}` created as super user with key: {key}", args.name); Ok(()) } - -async fn request_gateway_certs>( - fs: P, - public: FQDN, - acme: AcmeClient, - resolver: Arc, -) -> ChainAndPrivateKey { - let tls_path = fs.as_ref().join("ssl.pem"); - match ChainAndPrivateKey::load_pem(&tls_path) { - Ok(valid) => valid, - Err(_) => { - warn!( - "no valid certificate found at {}, creating one...", - tls_path.display() - ); - let certs = create_gateway_certificate(fs, public, acme, resolver).await; - certs.clone().save_pem(&tls_path).unwrap(); - certs - } - } -} - -async fn create_gateway_certificate>( - fs: P, - public: FQDN, - acme: AcmeClient, - resolver: Arc, -) -> ChainAndPrivateKey { - let creds_path = fs.as_ref().join("acme.json"); - let identifier = format!("*.{public}"); - if !creds_path.exists() { - panic!( - "no ACME credentials found at {}, cannot continue with certificate creation", - creds_path.display() - ); - } - let creds = std::fs::File::open(creds_path).expect("Invalid credentials path"); - let creds: AccountCredentials = - serde_json::from_reader(&creds).expect("Can not parse admin credentials from path"); - let account = Account::from_credentials(creds) - .map_err(|error| { - error!( - error = &error as &dyn std::error::Error, - "failed to convert acme credentials into account" - ); - AcmeClientError::AccountCreation - }) - .unwrap(); - - // Use ::Dns01 challenge because that's the only supported - // challenge type for wildcard domains - let (chain, private_key) = acme - .create_certificate(&identifier, ChallengeType::Dns01, &account) - .await - .unwrap(); - - let mut buf = Vec::new(); - buf.extend(chain.as_bytes()); - buf.extend(private_key.as_bytes()); - let certs = ChainAndPrivateKey::parse_pem(Cursor::new(buf)).unwrap(); - resolver.serve_default_der(certs.clone()).await.unwrap(); - certs -} - -async fn automate_renewal + Clone>( - fs: P, - public: FQDN, - acme: AcmeClient, - resolver: Arc, -) { - loop { - tokio::time::sleep(Duration::from_secs(240)).await; - let certs = - request_gateway_certs(fs.clone(), public.clone(), acme.clone(), resolver.clone()).await; - // Safe because a 'ChainAndPrivateKey' is built from a PEM. - let chain_and_pk = certs.into_pem().unwrap(); - let (_, x509_cert) = parse_x509_certificate(chain_and_pk.as_bytes()).unwrap(); - let diff = x509_cert.validity().not_after.sub(ASN1Time::now()).unwrap(); - if diff.whole_days() <= 30 { - let tls_path = fs.as_ref().join("ssl.pem"); - let certs = create_gateway_certificate( - fs.clone(), - public.clone(), - acme.clone(), - resolver.clone(), - ) - .await; - certs.save_pem(&tls_path).unwrap(); - } - } -} diff --git a/gateway/src/service.rs b/gateway/src/service.rs index 039056cb2..ab919c45f 100644 --- a/gateway/src/service.rs +++ b/gateway/src/service.rs @@ -1,3 +1,6 @@ +use std::io::Cursor; +use std::ops::Sub; +use std::path::PathBuf; use std::sync::Arc; use axum::body::Body; @@ -5,13 +8,13 @@ use axum::headers::{Authorization, HeaderMapExt}; use axum::http::Request; use axum::response::Response; use bollard::{Docker, API_DEFAULT_VERSION}; -use fqdn::Fqdn; +use fqdn::{Fqdn, FQDN}; use http::HeaderValue; use hyper::client::connect::dns::GaiResolver; use hyper::client::HttpConnector; use hyper::Client; use hyper_reverse_proxy::ReverseProxy; -use instant_acme::{Account, ChallengeType}; +use instant_acme::{AccountCredentials, ChallengeType}; use once_cell::sync::Lazy; use opentelemetry::global; use opentelemetry_http::HeaderInjector; @@ -20,14 +23,18 @@ use sqlx::migrate::Migrator; use sqlx::sqlite::SqlitePool; use sqlx::types::Json as SqlxJson; use sqlx::{query, Error as SqlxError, Row}; +use tracing::log::warn; use tracing::{debug, Span}; use tracing_opentelemetry::OpenTelemetrySpanExt; +use x509_parser::parse_x509_certificate; +use x509_parser::time::ASN1Time; -use crate::acme::{AcmeClient, CustomDomain}; +use crate::acme::{AccountWrapper, AcmeClient, AcmeCredentials, CustomDomain}; use crate::args::ContextArgs; use crate::auth::{Key, Permissions, ScopedUser, User}; use crate::project::Project; use crate::task::{BoxedTask, TaskBuilder}; +use crate::tls::{ChainAndPrivateKey, GatewayCertResolver}; use crate::worker::TaskRouter; use crate::{AccountName, DockerContext, Error, ErrorKind, ProjectDetails, ProjectName}; @@ -165,6 +172,7 @@ pub struct GatewayService { provider: GatewayContextProvider, db: SqlitePool, task_router: TaskRouter, + state_location: PathBuf, } impl GatewayService { @@ -172,7 +180,7 @@ impl GatewayService { /// /// * `args` - The [`Args`] with which the service was /// started. Will be passed as [`Context`] to workers and state. - pub async fn init(args: ContextArgs, db: SqlitePool) -> Self { + pub async fn init(args: ContextArgs, db: SqlitePool, state_location: PathBuf) -> Self { let docker = Docker::connect_with_unix(&args.docker_host, 60, API_DEFAULT_VERSION).unwrap(); let container_settings = ContainerSettings::builder().from_args(&args).await; @@ -185,6 +193,7 @@ impl GatewayService { provider, db, task_router, + state_location, } } @@ -553,12 +562,12 @@ impl GatewayService { /// Returns the current certificate as a pair of the chain and private key. /// If the pair doesn't exist for a specific project, create both the certificate /// and the custom domain it will represent. - pub async fn get_or_create_certificate( + pub async fn create_custom_domain_certificate( &self, fqdn: &Fqdn, - acme_client: AcmeClient, - account: &Account, + acme_client: &AcmeClient, project_name: &ProjectName, + creds: AccountCredentials<'_>, ) -> Result<(String, String), Error> { match self.project_details_for_custom_domain(fqdn).await { Ok(CustomDomain { @@ -568,7 +577,7 @@ impl GatewayService { }) => Ok((certificate, private_key)), Err(err) if err.kind() == ErrorKind::CustomDomainNotFound => { let (certs, private_key) = acme_client - .create_certificate(&fqdn.to_string(), ChallengeType::Http01, account) + .create_certificate(&fqdn.to_string(), ChallengeType::Http01, creds) .await?; self.create_custom_domain(project_name, fqdn, &certs, &private_key) .await?; @@ -578,6 +587,108 @@ impl GatewayService { } } + async fn create_certificate<'a>( + &self, + acme: &AcmeClient, + resolver: Arc, + creds: AccountCredentials<'a>, + ) -> ChainAndPrivateKey { + let public: FQDN = self.context().settings.fqdn.parse().unwrap(); + let identifier = format!("*.{public}"); + + // Use ::Dns01 challenge because that's the only supported + // challenge type for wildcard domains. + let (chain, private_key) = acme + .create_certificate(&identifier, ChallengeType::Dns01, creds) + .await + .unwrap(); + + let mut buf = Vec::new(); + buf.extend(chain.as_bytes()); + buf.extend(private_key.as_bytes()); + let certs = ChainAndPrivateKey::parse_pem(Cursor::new(buf)).expect("Malformed PEM buffer."); + resolver + .serve_default_der(certs.clone()) + .await + .expect("Failed to serve the default certs"); + + certs + } + + /// Fetch the gateway certificate from the state location. + /// If not existent, create the gateway certificate and save it to the + /// gateway state. + pub async fn fetch_certificate( + &self, + acme: &AcmeClient, + resolver: Arc, + creds: AcmeCredentials<'_>, + ) -> ChainAndPrivateKey { + let tls_path = self.state_location.join("ssl.pem"); + match ChainAndPrivateKey::load_pem(&tls_path) { + Ok(valid) => valid, + Err(_) => { + warn!( + "no valid certificate found at {}, creating one...", + tls_path.display() + ); + + let creds = match creds { + AcmeCredentials::InMemory(creds) => creds, + AcmeCredentials::GatewayState => { + let creds_path = self.state_location.join("acme.json"); + if !creds_path.exists() { + panic!( + "no ACME credentials found at {}, cannot continue with certificate creation", + creds_path.display() + ); + } + + let creds = + std::fs::File::open(creds_path).expect("Invalid credentials path"); + serde_json::from_reader(&creds) + .expect("Can not parse admin credentials from path") + } + }; + + let certs = self.create_certificate(acme, resolver, creds).await; + certs.clone().save_pem(&tls_path).unwrap(); + certs + } + } + } + + /// Renew the gateway certificate if there less than 30 days until the current + /// certificate expiration. + pub(crate) async fn renew_certificate( + &self, + acme: &AcmeClient, + resolver: Arc, + creds: AccountCredentials<'_>, + ) { + let account = AccountWrapper::from(creds).0; + let certs = self + .fetch_certificate( + acme, + resolver.clone(), + AcmeCredentials::InMemory(account.credentials()), + ) + .await; + // Safe to unwrap because a 'ChainAndPrivateKey' is built from a PEM. + let chain_and_pk = certs.into_pem().unwrap(); + + let (_, x509_cert) = parse_x509_certificate(chain_and_pk.as_bytes()).unwrap(); + let diff = x509_cert.validity().not_after.sub(ASN1Time::now()).unwrap(); + + if diff.whole_days() <= 30 { + let tls_path = self.state_location.join("ssl.pem"); + let certs = self + .create_certificate(acme, resolver.clone(), account.credentials()) + .await; + certs.save_pem(&tls_path).unwrap(); + } + } + pub fn context(&self) -> GatewayContext { self.provider.context() } @@ -624,7 +735,7 @@ pub mod tests { #[tokio::test] async fn service_create_find_user() -> anyhow::Result<()> { let world = World::new().await; - let svc = GatewayService::init(world.args(), world.pool()).await; + let svc = GatewayService::init(world.args(), world.pool(), "".into()).await; let account_name: AccountName = "test_user_123".parse()?; @@ -675,7 +786,7 @@ pub mod tests { #[tokio::test] async fn service_create_find_delete_project() -> anyhow::Result<()> { let world = World::new().await; - let svc = Arc::new(GatewayService::init(world.args(), world.pool()).await); + let svc = Arc::new(GatewayService::init(world.args(), world.pool(), "".into()).await); let neo: AccountName = "neo".parse().unwrap(); let trinity: AccountName = "trinity".parse().unwrap(); @@ -755,7 +866,7 @@ pub mod tests { #[tokio::test] async fn service_create_ready_kill_restart_docker() -> anyhow::Result<()> { let world = World::new().await; - let svc = Arc::new(GatewayService::init(world.args(), world.pool()).await); + let svc = Arc::new(GatewayService::init(world.args(), world.pool(), "".into()).await); let neo: AccountName = "neo".parse().unwrap(); let matrix: ProjectName = "matrix".parse().unwrap(); @@ -812,7 +923,7 @@ pub mod tests { #[tokio::test] async fn service_create_find_custom_domain() -> anyhow::Result<()> { let world = World::new().await; - let svc = Arc::new(GatewayService::init(world.args(), world.pool()).await); + let svc = Arc::new(GatewayService::init(world.args(), world.pool(), "".into()).await); let account: AccountName = "neo".parse().unwrap(); let project_name: ProjectName = "matrix".parse().unwrap();