diff --git a/Cargo.lock b/Cargo.lock index a3200896c5..e8aab9b0bc 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -362,6 +362,45 @@ version = "0.5.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "23b62fc65de8e4e7f52534fb52b0f3ed04746ae267519eef2a83941e8085068b" +[[package]] +name = "asn1-rs" +version = "0.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cf6690c370453db30743b373a60ba498fc0d6d83b11f4abfd87a84a075db5dd4" +dependencies = [ + "asn1-rs-derive", + "asn1-rs-impl", + "displaydoc", + "nom", + "num-traits", + "rusticata-macros", + "thiserror", + "time 0.3.11", +] + +[[package]] +name = "asn1-rs-derive" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "726535892e8eae7e70657b4c8ea93d26b8553afb1ce617caee529ef96d7dee6c" +dependencies = [ + "proc-macro2 1.0.47", + "quote 1.0.21", + "syn 1.0.104", + "synstructure", +] + +[[package]] +name = "asn1-rs-impl" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2777730b2039ac0f95f093556e61b6d26cebed5393ca6f152717777cec3a42ed" +dependencies = [ + "proc-macro2 1.0.47", + "quote 1.0.21", + "syn 1.0.104", +] + [[package]] name = "assert_cmd" version = "2.0.6" @@ -2285,6 +2324,20 @@ version = "2.3.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3ee2393c4a91429dffb4bedf19f4d6abf27d8a732c8ce4980305d782e5426d57" +[[package]] +name = "der-parser" +version = "8.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "42d4bc9b0db0a0df9ae64634ac5bdefb7afcb534e182275ca0beadbe486701c1" +dependencies = [ + "asn1-rs", + "displaydoc", + "nom", + "num-bigint", + "num-traits", + "rusticata-macros", +] + [[package]] name = "derivative" version = "2.2.0" @@ -2433,6 +2486,17 @@ version = "1.0.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = 
"212d0f5754cb6769937f4501cc0e67f4f4483c8d2c3e1e922ee9edbe4ab4c7c0" +[[package]] +name = "displaydoc" +version = "0.2.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3bf95dc3f046b9da4f2d51833c0d3547d8564ef6910f5c1ed130306a75b92886" +dependencies = [ + "proc-macro2 1.0.47", + "quote 1.0.21", + "syn 1.0.104", +] + [[package]] name = "doc-comment" version = "0.3.3" @@ -4111,6 +4175,15 @@ dependencies = [ "malloc_buf", ] +[[package]] +name = "oid-registry" +version = "0.6.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9bedf36ffb6ba96c2eb7144ef6270557b52e54b20c0a8e1eb2ff99a6c6959bff" +dependencies = [ + "asn1-rs", +] + [[package]] name = "once_cell" version = "1.16.0" @@ -5432,6 +5505,15 @@ dependencies = [ "serde_json", ] +[[package]] +name = "rusticata-macros" +version = "4.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "faf0c4a6ece9950b9abdb62b1cfcf2a68b3b67a10ba445b3bb85be2a293d0632" +dependencies = [ + "nom", +] + [[package]] name = "rustix" version = "0.36.3" @@ -6107,6 +6189,7 @@ dependencies = [ "tracing-subscriber", "ttl_cache", "uuid 1.2.2", + "x509-parser", ] [[package]] @@ -6623,6 +6706,18 @@ version = "0.1.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "20518fe4a4c9acf048008599e464deb21beeae3d3578418951a189c235a7a9a8" +[[package]] +name = "synstructure" +version = "0.12.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f36bdaa60a83aca3921b5259d5400cbf5e90fc51931376a9bd4a0eb79aa7210f" +dependencies = [ + "proc-macro2 1.0.47", + "quote 1.0.21", + "syn 1.0.104", + "unicode-xid 0.2.3", +] + [[package]] name = "take_mut" version = "0.2.2" @@ -8062,6 +8157,24 @@ dependencies = [ "winapi", ] +[[package]] +name = "x509-parser" +version = "0.14.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e0ecbeb7b67ce215e40e3cc7f2ff902f94a223acf44995934763467e7b1febc8" 
+dependencies = [ + "asn1-rs", + "base64 0.13.1", + "data-encoding", + "der-parser", + "lazy_static", + "nom", + "oid-registry", + "rusticata-macros", + "thiserror", + "time 0.3.11", +] + [[package]] name = "xattr" version = "0.2.3" @@ -8126,3 +8239,19 @@ dependencies = [ "cc", "libc", ] + +[[patch.unused]] +name = "shuttle-aws-rds" +version = "0.10.0" + +[[patch.unused]] +name = "shuttle-persist" +version = "0.10.0" + +[[patch.unused]] +name = "shuttle-shared-db" +version = "0.10.0" + +[[patch.unused]] +name = "shuttle-static-folder" +version = "0.10.0" diff --git a/Containerfile b/Containerfile index c2a9218259..5881b2d0c8 100644 --- a/Containerfile +++ b/Containerfile @@ -3,11 +3,14 @@ ARG RUSTUP_TOOLCHAIN FROM rust:${RUSTUP_TOOLCHAIN}-buster as shuttle-build RUN apt-get update &&\ apt-get install -y curl + # download protoc binary and unzip it in usr/bin -RUN curl -OL https://github.com/protocolbuffers/protobuf/releases/download/v21.9/protoc-21.9-linux-x86_64.zip &&\ - unzip -o protoc-21.9-linux-x86_64.zip -d /usr bin/protoc &&\ - unzip -o protoc-21.9-linux-x86_64.zip -d /usr/ 'include/*' &&\ - rm -f protoc-21.9-linux-x86_64.zip +ARG PROTOC_ARCH=x86_64 +RUN curl -OL https://github.com/protocolbuffers/protobuf/releases/download/v21.9/protoc-21.9-linux-${PROTOC_ARCH}.zip &&\ + unzip -o protoc-21.9-linux-${PROTOC_ARCH}.zip -d /usr bin/protoc &&\ + unzip -o protoc-21.9-linux-${PROTOC_ARCH}.zip -d /usr/ 'include/*' &&\ + rm -f protoc-21.9-linux-${PROTOC_ARCH}.zip + RUN cargo install cargo-chef WORKDIR /build diff --git a/Makefile b/Makefile index 5e3706ac0a..abea9a2d8c 100644 --- a/Makefile +++ b/Makefile @@ -61,6 +61,12 @@ USE_TLS?=disable CARGO_PROFILE=debug endif +ARCH=$(shell uname -m) +PROTOC_ARCH=$(ARCH) +ifneq ($(filter arm64 aarch64,$(ARCH)),) +PROTOC_ARCH=aarch_64 +endif + POSTGRES_EXTRA_PATH?=./extras/postgres POSTGRES_TAG?=14 @@ -81,6 +87,7 @@ images: shuttle-provisioner shuttle-deployer shuttle-gateway shuttle-auth postgr postgres: docker buildx build \ + --build-arg 
PROTOC_ARCH=$(PROTOC_ARCH) \ --build-arg POSTGRES_TAG=$(POSTGRES_TAG) \ --tag $(CONTAINER_REGISTRY)/postgres:$(POSTGRES_TAG) \ $(BUILDX_FLAGS) \ @@ -89,6 +96,8 @@ postgres: panamax: docker buildx build \ + --build-arg PROTOC_ARCH=$(PROTOC_ARCH) \ + --platform linux/$(ARCH) \ --build-arg PANAMAX_TAG=$(PANAMAX_TAG) \ --tag $(CONTAINER_REGISTRY)/panamax:$(PANAMAX_TAG) \ $(BUILDX_FLAGS) \ @@ -112,6 +121,7 @@ down: docker-compose.rendered.yml shuttle-%: ${SRC} Cargo.lock docker buildx build \ + --build-arg PROTOC_ARCH=$(PROTOC_ARCH) \ --build-arg folder=$(*) \ --build-arg RUSTUP_TOOLCHAIN=$(RUSTUP_TOOLCHAIN) \ --build-arg CARGO_PROFILE=$(CARGO_PROFILE) \ diff --git a/admin/src/args.rs b/admin/src/args.rs index e7ae681225..b1e02f9937 100644 --- a/admin/src/args.rs +++ b/admin/src/args.rs @@ -18,7 +18,7 @@ pub enum Command { /// Try to revive projects in the crashed state Revive, - /// Manage custom domains + /// Manage domains #[command(subcommand)] Acme(AcmeCommand), @@ -58,6 +58,30 @@ pub enum AcmeCommand { #[arg(long)] credentials: PathBuf, }, + + /// Renew the certificate for a FQDN + RenewCustomDomainCertificate { + /// Fqdn to renew the certificate for + #[arg(long)] + fqdn: String, + + /// Project to renew the certificate for + #[arg(long)] + project: ProjectName, + + /// Path to acme credentials file + /// This should have been created with `acme create-account` + #[arg(long)] + credentials: PathBuf, + }, + + /// Renew certificate for the shuttle gateway + RenewGatewayCertificate { + /// Path to acme credentials file + /// This should have been created with `acme create-account` + #[arg(long)] + credentials: PathBuf, + }, } #[derive(Subcommand, Debug)] diff --git a/admin/src/client.rs b/admin/src/client.rs index 43e756a789..bb55c2b92d 100644 --- a/admin/src/client.rs +++ b/admin/src/client.rs @@ -39,6 +39,24 @@ impl Client { self.post(&path, Some(credentials)).await } + pub async fn acme_renew_custom_domain_certificate( + &self, + fqdn: &str, + project_name: 
&ProjectName, + credentials: &serde_json::Value, + ) -> Result { + let path = format!("/admin/acme/renew/{project_name}/{fqdn}"); + self.post(&path, Some(credentials)).await + } + + pub async fn acme_renew_gateway_certificate( + &self, + credentials: &serde_json::Value, + ) -> Result { + let path = "/admin/acme/gateway/renew".to_string(); + self.post(&path, Some(credentials)).await + } + pub async fn get_projects(&self) -> Result> { self.get("/admin/projects").await } diff --git a/admin/src/main.rs b/admin/src/main.rs index 55cb40ee54..81921fde10 100644 --- a/admin/src/main.rs +++ b/admin/src/main.rs @@ -50,6 +50,30 @@ async fn main() { .await .expect("to get a certificate challenge response") } + Command::Acme(AcmeCommand::RenewCustomDomainCertificate { + fqdn, + project, + credentials, + }) => { + let credentials = fs::read_to_string(credentials).expect("to read credentials file"); + let credentials = + serde_json::from_str(&credentials).expect("to parse content of credentials file"); + + client + .acme_renew_custom_domain_certificate(&fqdn, &project, &credentials) + .await + .expect("to get a certificate challenge response") + } + Command::Acme(AcmeCommand::RenewGatewayCertificate { credentials }) => { + let credentials = fs::read_to_string(credentials).expect("to read credentials file"); + let credentials = + serde_json::from_str(&credentials).expect("to parse content of credentials file"); + + client + .acme_renew_gateway_certificate(&credentials) + .await + .expect("to get a certificate challenge response") + } Command::ProjectNames => { let projects = client .get_projects() diff --git a/common/src/backends/auth.rs b/common/src/backends/auth.rs index e8521f2381..fff8fb3904 100644 --- a/common/src/backends/auth.rs +++ b/common/src/backends/auth.rs @@ -136,10 +136,18 @@ pub enum Scope { /// Create an ACME account AcmeCreate, - /// Create a custom domain, + /// Create a custom domain. 
CustomDomainCreate, - /// Admin level scope to internals + /// Renew the certificate of a custom domain. + CustomDomainCertificateRenew, + + /// Request renewal of the gateway certificate. + /// Note: this step should be completed manually in terms + /// of DNS-01 challenge completion. + GatewayCertificateRenew, + + /// Admin level scope to internals. Admin, } diff --git a/gateway/Cargo.toml b/gateway/Cargo.toml index 31a0ee25f1..115571a447 100644 --- a/gateway/Cargo.toml +++ b/gateway/Cargo.toml @@ -43,6 +43,7 @@ tracing-opentelemetry = { workspace = true } tracing-subscriber = { workspace = true, features = ["env-filter"] } ttl_cache = "0.5.1" uuid = { workspace = true, features = [ "v4" ] } +x509-parser = "0.14.0" [dependencies.shuttle-common] workspace = true diff --git a/gateway/src/acme.rs b/gateway/src/acme.rs index 4f448df1ed..ff122cb6f9 100644 --- a/gateway/src/acme.rs +++ b/gateway/src/acme.rs @@ -32,6 +32,11 @@ pub struct CustomDomain { pub private_key: String, } +pub enum AcmeCredentials<'a> { + InMemory(AccountCredentials<'a>), + GatewayState, +} + /// An ACME client implementation that completes Http01 challenges /// It is safe to clone this type as it functions as a singleton #[derive(Clone, Default)] @@ -60,7 +65,8 @@ impl AcmeClient { self.0.lock().await.remove(token); } - /// Create a new ACME account that can be restored using by deserializing the returned JSON into a [instant_acme::AccountCredentials] + /// Create a new ACME account that can be restored by using the deserialization + /// of the returned JSON into a [instant_acme::AccountCredentials] pub async fn create_account( &self, email: &str, @@ -101,15 +107,8 @@ impl AcmeClient { ) -> Result<(String, String), AcmeClientError> { trace!(identifier, "requesting acme certificate"); - let account = Account::from_credentials(credentials).map_err(|error| { - error!( - error = &error as &dyn std::error::Error, - "failed to convert acme credentials into account" - ); - 
AcmeClientError::AccountCreation - })?; - - let (mut order, state) = account + let (mut order, state) = AccountWrapper::from(credentials) + .0 .new_order(&NewOrder { identifiers: &[Identifier::Dns(identifier.to_string())], }) @@ -288,6 +287,24 @@ impl AcmeClient { } } +#[derive(Clone)] +pub struct AccountWrapper(pub Account); + +impl<'a> From> for AccountWrapper { + fn from(value: AccountCredentials<'a>) -> Self { + AccountWrapper( + Account::from_credentials(value) + .map_err(|error| { + error!( + error = &error as &dyn std::error::Error, + "failed to convert acme credentials into account" + ); + }) + .expect("Malformed account credentials."), + ) + } +} + #[derive(Debug, strum::Display)] pub enum AcmeClientError { AccountCreation, diff --git a/gateway/src/api/latest.rs b/gateway/src/api/latest.rs index 4e64cfc903..e8e0fb2a27 100644 --- a/gateway/src/api/latest.rs +++ b/gateway/src/api/latest.rs @@ -1,5 +1,6 @@ use std::io::Cursor; use std::net::SocketAddr; +use std::ops::Sub; use std::sync::Arc; use std::time::Duration; @@ -26,14 +27,17 @@ use tokio::sync::{Mutex, MutexGuard}; use tracing::{field, instrument}; use ttl_cache::TtlCache; use uuid::Uuid; +use x509_parser::parse_x509_certificate; +use x509_parser::time::ASN1Time; use crate::acme::{AcmeClient, CustomDomain}; use crate::auth::{ScopedUser, User}; use crate::project::{Project, ProjectCreating}; +use crate::service::GatewayService; use crate::task::{self, BoxedTask, TaskResult}; use crate::tls::GatewayCertResolver; use crate::worker::WORKER_QUEUE_SIZE; -use crate::{Error, GatewayService, ProjectName}; +use crate::{Error, ProjectName}; use super::auth_layer::ShuttleAuthLayer; @@ -103,23 +107,23 @@ async fn get_projects_list( Ok(AxumJson(projects)) } -async fn get_projects_list_with_filter( - State(RouterState { service, .. }): State, - User { name, .. 
}: User, - Path(project_status): Path, -) -> Result>, Error> { - let projects = service - .iter_user_projects_detailed_filtered(name.clone(), project_status) - .await? - .into_iter() - .map(|project| project::Response { - name: project.0.to_string(), - state: project.1.into(), - }) - .collect(); - - Ok(AxumJson(projects)) -} +// async fn get_projects_list_with_filter( +// State(RouterState { service, .. }): State, +// User { name, .. }: User, +// Path(project_status): Path, +// ) -> Result>, Error> { +// let projects = service +// .iter_user_projects_detailed_filtered(name.clone(), project_status) +// .await? +// .into_iter() +// .map(|project| project::Response { +// name: project.0.to_string(), +// state: project.1.into(), +// }) +// .collect(); + +// Ok(AxumJson(projects)) +// } #[instrument(skip_all, fields(%project))] async fn post_project( @@ -296,7 +300,7 @@ async fn create_acme_account( } #[instrument(skip_all, fields(%project_name, %fqdn))] -async fn request_acme_certificate( +async fn request_custom_domain_acme_certificate( State(RouterState { service, sender, .. }): State, @@ -309,28 +313,14 @@ async fn request_acme_certificate( .parse() .map_err(|_err| Error::from(ErrorKind::InvalidCustomDomain))?; - let (certs, private_key) = match service.project_details_for_custom_domain(&fqdn).await { - Ok(CustomDomain { - certificate, - private_key, - .. 
- }) => (certificate, private_key), - Err(err) if err.kind() == ErrorKind::CustomDomainNotFound => { - let (certs, private_key) = acme_client - .create_certificate(&fqdn.to_string(), ChallengeType::Http01, credentials) - .await?; - service - .create_custom_domain(project_name.clone(), &fqdn, &certs, &private_key) - .await?; - (certs, private_key) - } - Err(err) => return Err(err), - }; + let (certs, private_key) = service + .create_custom_domain_certificate(&fqdn, &acme_client, &project_name, credentials) + .await?; - // destroy and recreate the project with the new domain + // Destroy and recreate the project with the new domain. service .new_task() - .project(project_name) + .project(project_name.clone()) .and_then(task::destroy()) .and_then(task::run_until_done()) .and_then(task::run({ @@ -354,7 +344,79 @@ async fn request_acme_certificate( .serve_pem(&fqdn.to_string(), Cursor::new(buf)) .await?; - Ok("certificate created".to_string()) + Ok(format!( + "New certificate created for {} project.", + project_name + )) +} + +#[instrument(skip_all, fields(%project_name, %fqdn))] +async fn renew_custom_domain_acme_certificate( + State(RouterState { service, .. }): State, + Extension(acme_client): Extension, + Extension(resolver): Extension>, + Path((project_name, fqdn)): Path<(ProjectName, String)>, + AxumJson(credentials): AxumJson>, +) -> Result { + let fqdn: FQDN = fqdn + .parse() + .map_err(|_err| Error::from(ErrorKind::InvalidCustomDomain))?; + // Try retrieve the current certificate if any. + match service.project_details_for_custom_domain(&fqdn).await { + Ok(CustomDomain { certificate, .. }) => { + let (_, x509_cert_chain) = parse_x509_certificate(certificate.as_bytes()) + .unwrap_or_else(|_| { + panic!( + "Malformed existing X509 certificate for {} project.", + project_name + ) + }); + let diff = x509_cert_chain + .validity() + .not_after + .sub(ASN1Time::now()) + .unwrap(); + // If current certificate validity less_or_eq than 30 days, attempt renewal. 
+ if diff.whole_days() <= 30 { + return match acme_client + .create_certificate(&fqdn.to_string(), ChallengeType::Http01, credentials) + .await + { + // If successfully created, save the certificate in memory to be + // served in the future. + Ok((certs, private_key)) => { + let mut buf = Vec::new(); + buf.extend(certs.as_bytes()); + buf.extend(private_key.as_bytes()); + resolver + .serve_pem(&fqdn.to_string(), Cursor::new(buf)) + .await?; + Ok(format!("Certificate renewed for {} project.", project_name)) + } + Err(err) => Err(err.into()), + }; + } else { + Ok(format!( + "Certificate renewal skipped, {} project certificate still valid for {} days.", + project_name, diff.whole_days() + )) + } + } + Err(err) => Err(err), + } +} + +#[instrument(skip_all)] +async fn renew_gateway_acme_certificate( + State(RouterState { service, .. }): State, + Extension(acme_client): Extension, + Extension(resolver): Extension>, + AxumJson(credentials): AxumJson>, +) -> Result { + service + .renew_certificate(&acme_client, resolver, credentials) + .await; + Ok("Renewed the gateway certificate.".to_string()) } async fn get_projects( @@ -410,10 +472,24 @@ impl ApiBuilder { .route( "/admin/acme/request/:project_name/:fqdn", post( - request_acme_certificate + request_custom_domain_acme_certificate .layer(ScopedLayer::new(vec![Scope::CustomDomainCreate])), ), ) + .route( + "/admin/acme/renew/:project_name/:fqdn", + post( + renew_custom_domain_acme_certificate + .layer(ScopedLayer::new(vec![Scope::CustomDomainCertificateRenew])), + ), + ) + .route( + "/admin/acme/gateway/renew", + post( + renew_gateway_acme_certificate + .layer(ScopedLayer::new(vec![Scope::GatewayCertificateRenew])), + ), + ) .layer(Extension(acme)) .layer(Extension(resolver)); self @@ -543,7 +619,7 @@ pub mod tests { #[tokio::test] async fn api_create_get_delete_projects() -> anyhow::Result<()> { let world = World::new().await; - let service = Arc::new(GatewayService::init(world.args(), world.pool()).await); + let service = 
Arc::new(GatewayService::init(world.args(), world.pool(), "".into()).await); let (sender, mut receiver) = channel::(256); tokio::spawn(async move { @@ -553,7 +629,7 @@ pub mod tests { }); let mut router = ApiBuilder::new() - .with_service(Arc::clone(&service)) + .with_service(Arc::::clone(&service)) .with_sender(sender) .with_default_routes() .with_auth_service(world.context().auth_uri) @@ -690,7 +766,7 @@ pub mod tests { #[tokio::test(flavor = "multi_thread")] async fn status() { let world = World::new().await; - let service = Arc::new(GatewayService::init(world.args(), world.pool()).await); + let service = Arc::new(GatewayService::init(world.args(), world.pool(), "".into()).await); let (sender, mut receiver) = channel::(1); let (ctl_send, ctl_recv) = oneshot::channel(); @@ -707,7 +783,7 @@ pub mod tests { }); let mut router = ApiBuilder::new() - .with_service(Arc::clone(&service)) + .with_service(Arc::::clone(&service)) .with_sender(sender) .with_default_routes() .with_auth_service(world.context().auth_uri) diff --git a/gateway/src/lib.rs b/gateway/src/lib.rs index 86529e4d0c..f160562786 100644 --- a/gateway/src/lib.rs +++ b/gateway/src/lib.rs @@ -14,6 +14,7 @@ use axum::Json; use bollard::Docker; use futures::prelude::*; use serde::{Deserialize, Deserializer, Serialize}; +use service::ContainerSettings; use shuttle_common::models::error::{ApiError, ErrorKind}; use tokio::sync::mpsc::error::SendError; use tracing::error; @@ -29,8 +30,6 @@ pub mod task; pub mod tls; pub mod worker; -use crate::service::{ContainerSettings, GatewayService}; - /// Server-side errors that do not have to do with the user runtime /// should be [`Error`]s. 
/// @@ -750,7 +749,7 @@ pub mod tests { #[tokio::test] async fn end_to_end() { let world = World::new().await; - let service = Arc::new(GatewayService::init(world.args(), world.pool()).await); + let service = Arc::new(GatewayService::init(world.args(), world.pool(), "".into()).await); let worker = Worker::new(); let (log_out, mut log_in) = channel(256); @@ -769,7 +768,7 @@ pub mod tests { let api_client = world.client(world.args.control); let api = ApiBuilder::new() - .with_service(Arc::clone(&service)) + .with_service(Arc::::clone(&service)) .with_sender(log_out) .with_default_routes() .with_auth_service(world.context().auth_uri) diff --git a/gateway/src/main.rs b/gateway/src/main.rs index d990918351..c7e61a1073 100644 --- a/gateway/src/main.rs +++ b/gateway/src/main.rs @@ -1,22 +1,21 @@ use clap::Parser; -use fqdn::FQDN; use futures::prelude::*; -use instant_acme::{AccountCredentials, ChallengeType}; use opentelemetry::global; -use shuttle_gateway::acme::{AcmeClient, CustomDomain}; +use shuttle_gateway::acme::{AcmeClient, AcmeCredentials, CustomDomain}; use shuttle_gateway::api::latest::{ApiBuilder, SVC_DEGRADED_THRESHOLD}; use shuttle_gateway::args::StartArgs; use shuttle_gateway::args::{Args, Commands, UseTls}; use shuttle_gateway::proxy::UserServiceBuilder; use shuttle_gateway::service::{GatewayService, MIGRATIONS}; use shuttle_gateway::task; -use shuttle_gateway::tls::{make_tls_acceptor, ChainAndPrivateKey}; +use shuttle_gateway::tls::make_tls_acceptor; use shuttle_gateway::worker::{Worker, WORKER_QUEUE_SIZE}; use sqlx::migrate::MigrateDatabase; use sqlx::sqlite::{SqliteConnectOptions, SqliteJournalMode, SqliteSynchronous}; use sqlx::{Sqlite, SqlitePool}; use std::io::{self, Cursor}; -use std::path::{Path, PathBuf}; + +use std::path::PathBuf; use std::str::FromStr; use std::sync::Arc; use std::time::Duration; @@ -76,7 +75,7 @@ async fn main() -> io::Result<()> { } async fn start(db: SqlitePool, fs: PathBuf, args: StartArgs) -> io::Result<()> { - let gateway = 
Arc::new(GatewayService::init(args.context.clone(), db).await); + let gateway = Arc::new(GatewayService::init(args.context.clone(), db, fs).await); let worker = Worker::new(); @@ -105,8 +104,7 @@ async fn start(db: SqlitePool, fs: PathBuf, args: StartArgs) -> io::Result<()> { .map_err(|err| error!("worker error: {}", err)), ); - // Every 60secs go over all `::Ready` projects and check their - // health + // Every 60 secs go over all `::Ready` projects and check their health. let ambulance_handle = tokio::spawn({ let gateway = Arc::clone(&gateway); let sender = sender.clone(); @@ -114,7 +112,7 @@ async fn start(db: SqlitePool, fs: PathBuf, args: StartArgs) -> io::Result<()> { loop { tokio::time::sleep(Duration::from_secs(60)).await; if sender.capacity() < WORKER_QUEUE_SIZE - SVC_DEGRADED_THRESHOLD { - // if degraded, don't stack more health checks + // If degraded, don't stack more health checks. warn!( sender.capacity = sender.capacity(), "skipping health checks" @@ -139,8 +137,8 @@ async fn start(db: SqlitePool, fs: PathBuf, args: StartArgs) -> io::Result<()> { .send(&sender) .await { - // we wait for the check to be done before - // queuing up the next one + // We wait for the check to be done before + // queuing up the next one. handle.await } } @@ -155,7 +153,7 @@ async fn start(db: SqlitePool, fs: PathBuf, args: StartArgs) -> io::Result<()> { let acme_client = AcmeClient::new(); let mut api_builder = ApiBuilder::new() - .with_service(Arc::clone(&gateway)) + .with_service(Arc::::clone(&gateway)) .with_sender(sender) .binding_to(args.control); @@ -191,9 +189,14 @@ async fn start(db: SqlitePool, fs: PathBuf, args: StartArgs) -> io::Result<()> { } tokio::spawn(async move { - // make sure we have a certificate for ourselves - let certs = init_certs(fs, args.context.proxy_fqdn.clone(), acme_client.clone()).await; - resolver.serve_default_der(certs).await.unwrap(); + // Make sure we have a certificate for ourselves. 
+ gateway + .fetch_certificate( + &acme_client, + resolver.clone(), + AcmeCredentials::GatewayState, + ) + .await; }); } else { warn!("TLS is disabled in the proxy service. This is only acceptable in testing, and should *never* be used in deployments."); @@ -218,47 +221,3 @@ async fn start(db: SqlitePool, fs: PathBuf, args: StartArgs) -> io::Result<()> { Ok(()) } - -async fn init_certs>(fs: P, public: FQDN, acme: AcmeClient) -> ChainAndPrivateKey { - let tls_path = fs.as_ref().join("ssl.pem"); - - match ChainAndPrivateKey::load_pem(&tls_path) { - Ok(valid) => valid, - Err(_) => { - let creds_path = fs.as_ref().join("acme.json"); - warn!( - "no valid certificate found at {}, creating one...", - tls_path.display() - ); - - if !creds_path.exists() { - panic!( - "no ACME credentials found at {}, cannot continue with certificate creation", - creds_path.display() - ); - } - - let creds = std::fs::File::open(creds_path).unwrap(); - let creds: AccountCredentials = serde_json::from_reader(&creds).unwrap(); - - let identifier = format!("*.{public}"); - - // Use ::Dns01 challenge because that's the only supported - // challenge type for wildcard domains - let (chain, private_key) = acme - .create_certificate(&identifier, ChallengeType::Dns01, creds) - .await - .unwrap(); - - let mut buf = Vec::new(); - buf.extend(chain.as_bytes()); - buf.extend(private_key.as_bytes()); - - let certs = ChainAndPrivateKey::parse_pem(Cursor::new(buf)).unwrap(); - - certs.clone().save_pem(&tls_path).unwrap(); - - certs - } - } -} diff --git a/gateway/src/project.rs b/gateway/src/project.rs index 7af842ba69..c8fc45b064 100644 --- a/gateway/src/project.rs +++ b/gateway/src/project.rs @@ -22,9 +22,9 @@ use serde::{Deserialize, Serialize}; use tokio::time::{sleep, timeout}; use tracing::{debug, error, info, instrument}; +use crate::service::ContainerSettings; use crate::{ - ContainerSettings, DockerContext, EndState, Error, ErrorKind, IntoTryState, ProjectName, - Refresh, State, TryState, + 
DockerContext, EndState, Error, ErrorKind, IntoTryState, ProjectName, Refresh, State, TryState, }; macro_rules! safe_unwrap { diff --git a/gateway/src/service.rs b/gateway/src/service.rs index 60242fcd31..455eab84b3 100644 --- a/gateway/src/service.rs +++ b/gateway/src/service.rs @@ -1,4 +1,7 @@ +use std::io::Cursor; use std::net::Ipv4Addr; +use std::ops::Sub; +use std::path::PathBuf; use std::sync::Arc; use axum::body::Body; @@ -6,11 +9,12 @@ use axum::headers::HeaderMapExt; use axum::http::Request; use axum::response::Response; use bollard::{Docker, API_DEFAULT_VERSION}; -use fqdn::Fqdn; +use fqdn::{Fqdn, FQDN}; use hyper::client::connect::dns::GaiResolver; use hyper::client::HttpConnector; use hyper::Client; use hyper_reverse_proxy::ReverseProxy; +use instant_acme::{AccountCredentials, ChallengeType}; use once_cell::sync::Lazy; use opentelemetry::global; use opentelemetry_http::HeaderInjector; @@ -20,14 +24,18 @@ use sqlx::migrate::Migrator; use sqlx::sqlite::SqlitePool; use sqlx::types::Json as SqlxJson; use sqlx::{query, Error as SqlxError, Row}; +use tracing::log::warn; use tracing::{debug, Span}; use tracing_opentelemetry::OpenTelemetrySpanExt; +use x509_parser::parse_x509_certificate; +use x509_parser::time::ASN1Time; -use crate::acme::CustomDomain; +use crate::acme::{AccountWrapper, AcmeClient, AcmeCredentials, CustomDomain}; use crate::args::ContextArgs; use crate::auth::ScopedUser; use crate::project::{Project, ProjectCreating}; use crate::task::{BoxedTask, TaskBuilder}; +use crate::tls::{ChainAndPrivateKey, GatewayCertResolver}; use crate::worker::TaskRouter; use crate::{AccountName, DockerContext, Error, ErrorKind, ProjectDetails, ProjectName}; @@ -177,6 +185,7 @@ pub struct GatewayService { provider: GatewayContextProvider, db: SqlitePool, task_router: TaskRouter, + state_location: PathBuf, } impl GatewayService { @@ -184,7 +193,7 @@ impl GatewayService { /// /// * `args` - The [`Args`] with which the service was /// started. 
Will be passed as [`Context`] to workers and state. - pub async fn init(args: ContextArgs, db: SqlitePool) -> Self { + pub async fn init(args: ContextArgs, db: SqlitePool, state_location: PathBuf) -> Self { let docker = Docker::connect_with_unix(&args.docker_host, 60, API_DEFAULT_VERSION).unwrap(); let container_settings = ContainerSettings::builder().from_args(&args).await; @@ -197,6 +206,7 @@ impl GatewayService { provider, db, task_router, + state_location, } } @@ -441,14 +451,14 @@ impl GatewayService { pub async fn create_custom_domain( &self, - project_name: ProjectName, + project_name: &ProjectName, fqdn: &Fqdn, certs: &str, private_key: &str, ) -> Result<(), Error> { query("INSERT OR REPLACE INTO custom_domains (fqdn, project_name, certificate, private_key) VALUES (?1, ?2, ?3, ?4)") .bind(fqdn.to_string()) - .bind(&project_name) + .bind(project_name) .bind(certs) .bind(private_key) .execute(&self.db) @@ -526,6 +536,136 @@ impl GatewayService { Ok(iter) } + /// Returns the current certificate as a pair of the chain and private key. + /// If the pair doesn't exist for a specific project, create both the certificate + /// and the custom domain it will represent. + pub async fn create_custom_domain_certificate( + &self, + fqdn: &Fqdn, + acme_client: &AcmeClient, + project_name: &ProjectName, + creds: AccountCredentials<'_>, + ) -> Result<(String, String), Error> { + match self.project_details_for_custom_domain(fqdn).await { + Ok(CustomDomain { + certificate, + private_key, + .. 
+ }) => Ok((certificate, private_key)), + Err(err) if err.kind() == ErrorKind::CustomDomainNotFound => { + let (certs, private_key) = acme_client + .create_certificate(&fqdn.to_string(), ChallengeType::Http01, creds) + .await?; + self.create_custom_domain(project_name, fqdn, &certs, &private_key) + .await?; + Ok((certs, private_key)) + } + Err(err) => Err(err), + } + } + + async fn create_certificate<'a>( + &self, + acme: &AcmeClient, + resolver: Arc, + creds: AccountCredentials<'a>, + ) -> ChainAndPrivateKey { + let public: FQDN = self.context().settings.fqdn.parse().unwrap(); + let identifier = format!("*.{public}"); + + // Use ::Dns01 challenge because that's the only supported + // challenge type for wildcard domains. + let (chain, private_key) = acme + .create_certificate(&identifier, ChallengeType::Dns01, creds) + .await + .unwrap(); + + let mut buf = Vec::new(); + buf.extend(chain.as_bytes()); + buf.extend(private_key.as_bytes()); + let certs = ChainAndPrivateKey::parse_pem(Cursor::new(buf)).expect("Malformed PEM buffer."); + resolver + .serve_default_der(certs.clone()) + .await + .expect("Failed to serve the default certs"); + + certs + } + + /// Fetch the gateway certificate from the state location. + /// If not existent, create the gateway certificate and save it to the + /// gateway state. 
+    ///
+    /// The certificate is loaded from `ssl.pem` under `self.state_location`.
+    /// When loading fails for any reason, a warning is logged, a new
+    /// certificate is created through `create_certificate`, and the result is
+    /// written back to the same path.
+    ///
+    /// `creds` selects where the ACME account credentials come from:
+    /// `AcmeCredentials::InMemory` carries them directly, while
+    /// `AcmeCredentials::GatewayState` reads them from `acme.json` under
+    /// `self.state_location`.
+    ///
+    /// # Panics
+    ///
+    /// Panics if `acme.json` is needed but is missing, cannot be opened or
+    /// cannot be deserialized, if `create_certificate` panics, or if the new
+    /// certificate cannot be saved.
+    pub async fn fetch_certificate(
+        &self,
+        acme: &AcmeClient,
+        // NOTE(review): the type argument of `Arc` appears to be missing in
+        // this text — confirm against the real source.
+        resolver: Arc,
+        creds: AcmeCredentials<'_>,
+    ) -> ChainAndPrivateKey {
+        let tls_path = self.state_location.join("ssl.pem");
+        match ChainAndPrivateKey::load_pem(&tls_path) {
+            Ok(valid) => valid,
+            Err(_) => {
+                warn!(
+                    "no valid certificate found at {}, creating one...",
+                    tls_path.display()
+                );
+
+                let creds = match creds {
+                    AcmeCredentials::InMemory(creds) => creds,
+                    AcmeCredentials::GatewayState => {
+                        let creds_path = self.state_location.join("acme.json");
+                        if !creds_path.exists() {
+                            panic!(
+                                "no ACME credentials found at {}, cannot continue with certificate creation",
+                                creds_path.display()
+                            );
+                        }
+
+                        let creds =
+                            std::fs::File::open(creds_path).expect("Invalid credentials path");
+                        serde_json::from_reader(&creds)
+                            .expect("Can not parse admin credentials from path")
+                    }
+                };
+
+                let certs = self.create_certificate(acme, resolver, creds).await;
+                // `save_pem` takes the value, so a clone is saved and the
+                // original is returned to the caller.
+                certs.clone().save_pem(&tls_path).unwrap();
+                certs
+            }
+        }
+    }
+
+    /// Renew the gateway certificate if there are 30 days or less until the
+    /// current certificate expires, or if it has already expired.
+    ///
+    /// The current certificate is obtained through `fetch_certificate` (which
+    /// creates one when none can be loaded). When renewal is due, a new
+    /// certificate is created through `create_certificate` and saved to
+    /// `ssl.pem` under `self.state_location`.
+    ///
+    /// # Panics
+    ///
+    /// Panics if the current certificate cannot be converted to PEM or parsed
+    /// as X.509, or if the renewed certificate cannot be saved.
+    pub(crate) async fn renew_certificate(
+        &self,
+        acme: &AcmeClient,
+        // NOTE(review): the type argument of `Arc` appears to be missing in
+        // this text — confirm against the real source.
+        resolver: Arc,
+        creds: AccountCredentials<'_>,
+    ) {
+        let account = AccountWrapper::from(creds).0;
+        let certs = self
+            .fetch_certificate(
+                acme,
+                resolver.clone(),
+                AcmeCredentials::InMemory(account.credentials()),
+            )
+            .await;
+        // Safe to unwrap because a 'ChainAndPrivateKey' is built from a PEM.
+        let chain_and_pk = certs.into_pem().unwrap();
+
+        // `parse_x509_certificate` parses DER, so the first PEM block has to be
+        // decoded before its contents can be handed to the X.509 parser.
+        let (_, pem) = x509_parser::pem::parse_x509_pem(chain_and_pk.as_bytes())
+            .expect("Malformed existing PEM certificate for the gateway.");
+        let (_, x509_cert) = parse_x509_certificate(&pem.contents)
+            .expect("Malformed existing X509 certificate for the gateway.");
+
+        // The subtraction yields `None` when the expiry date is already in the
+        // past; an expired certificate must be renewed rather than panic.
+        let remaining = x509_cert.validity().not_after.sub(ASN1Time::now());
+        let needs_renewal = remaining.map_or(true, |left| left.whole_days() <= 30);
+
+        if needs_renewal {
+            let tls_path = self.state_location.join("ssl.pem");
+            let certs = self
+                .create_certificate(acme, resolver.clone(), account.credentials())
+                .await;
+            certs.save_pem(&tls_path).unwrap();
+        }
+    }
+
+    /// Return the context provided by `self.provider`.
     pub fn context(&self) -> GatewayContext {
         self.provider.context()
     }
 @@ -561,6 +701,7 @@ pub mod tests { use fqdn::FQDN; use super::*; + use crate::task::{self, TaskResult}; use crate::tests::{assert_err_kind, World}; use crate::{Error, ErrorKind}; @@ -568,7 +709,7 @@ pub mod tests { #[tokio::test] async fn service_create_find_delete_project() -> anyhow::Result<()> { let world = World::new().await; - let svc = Arc::new(GatewayService::init(world.args(), world.pool()).await); + let svc = Arc::new(GatewayService::init(world.args(), world.pool(), "".into()).await); let neo: AccountName = "neo".parse().unwrap(); let trinity: AccountName = "trinity".parse().unwrap(); @@ -662,7 +803,7 @@ pub mod tests { #[tokio::test] async fn service_create_ready_kill_restart_docker() -> anyhow::Result<()> { let world = World::new().await; - let svc = Arc::new(GatewayService::init(world.args(), world.pool()).await); + let svc = Arc::new(GatewayService::init(world.args(), world.pool(), "".into()).await); let neo: AccountName = "neo".parse().unwrap(); let matrix: ProjectName = "matrix".parse().unwrap(); @@ -718,7 +859,7 @@ pub mod tests { #[tokio::test] async fn service_create_find_custom_domain() -> anyhow::Result<()> { let world = World::new().await; - let svc = Arc::new(GatewayService::init(world.args(), world.pool()).await); + let svc = Arc::new(GatewayService::init(world.args(), world.pool(), "".into()).await); let account: AccountName = "neo".parse().unwrap(); let project_name: ProjectName = "matrix".parse().unwrap(); @@ -736,7 +877,7 @@ pub mod tests {
.await .unwrap(); - svc.create_custom_domain(project_name.clone(), &domain, certificate, private_key) + svc.create_custom_domain(&project_name, &domain, certificate, private_key) .await .unwrap(); @@ -753,7 +894,7 @@ pub mod tests { let certificate = "dummy certificate update"; let private_key = "dummy private key update"; - svc.create_custom_domain(project_name.clone(), &domain, certificate, private_key) + svc.create_custom_domain(&project_name, &domain, certificate, private_key) .await .unwrap(); @@ -772,7 +913,7 @@ pub mod tests { #[tokio::test] async fn service_create_custom_domain_destroy_recreate_project() -> anyhow::Result<()> { let world = World::new().await; - let svc = Arc::new(GatewayService::init(world.args(), world.pool()).await); + let svc = Arc::new(GatewayService::init(world.args(), world.pool(), "".into()).await); let account: AccountName = "neo".parse().unwrap(); let project_name: ProjectName = "matrix".parse().unwrap(); @@ -790,7 +931,7 @@ pub mod tests { .await .unwrap(); - svc.create_custom_domain(project_name.clone(), &domain, certificate, private_key) + svc.create_custom_domain(&project_name, &domain, certificate, private_key) .await .unwrap();