Skip to content

Commit

Permalink
gateway/certs: automated gateway certs renewal
Browse files Browse the repository at this point in the history
Now starting the gateway will start a tokio task which will
loop indefinitely, checking if the certificate must be renewed
in case it's approaching 30 days before expiration.

Signed-off-by: Iulian Barbu <iulianbarbu2@gmail.com>
  • Loading branch information
iulianbarbu committed Feb 19, 2023
1 parent a22b814 commit 2384988
Show file tree
Hide file tree
Showing 3 changed files with 103 additions and 40 deletions.
2 changes: 1 addition & 1 deletion gateway/src/api/latest.rs
Original file line number Diff line number Diff line change
Expand Up @@ -319,7 +319,7 @@ async fn request_acme_certificate(

let account = AccountWrapper::from(credentials).0;
let (certs, private_key) = service
.get_or_create_certificate(&fqdn, acme_client.clone(), &account, project_name.clone())
.get_or_create_certificate(&fqdn, acme_client.clone(), &account, &project_name)
.await?;

// Destroy and recreate the project with the new domain.
Expand Down
127 changes: 95 additions & 32 deletions gateway/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,18 +11,21 @@ use shuttle_gateway::auth::Key;
use shuttle_gateway::proxy::UserServiceBuilder;
use shuttle_gateway::service::{GatewayService, MIGRATIONS};
use shuttle_gateway::task;
use shuttle_gateway::tls::{make_tls_acceptor, ChainAndPrivateKey};
use shuttle_gateway::tls::{make_tls_acceptor, ChainAndPrivateKey, GatewayCertResolver};
use shuttle_gateway::worker::{Worker, WORKER_QUEUE_SIZE};
use sqlx::migrate::MigrateDatabase;
use sqlx::sqlite::{SqliteConnectOptions, SqliteJournalMode, SqliteSynchronous};
use sqlx::{query, Sqlite, SqlitePool};
use std::io::{self, Cursor};
use std::ops::Sub;
use std::path::{Path, PathBuf};
use std::str::FromStr;
use std::sync::Arc;
use std::time::Duration;
use tracing::{debug, error, info, info_span, trace, warn, Instrument};
use tracing_subscriber::{fmt, prelude::*, EnvFilter};
use x509_parser::parse_x509_certificate;
use x509_parser::time::ASN1Time;

#[tokio::main(flavor = "multi_thread")]
async fn main() -> io::Result<()> {
Expand Down Expand Up @@ -192,10 +195,24 @@ async fn start(db: SqlitePool, fs: PathBuf, args: StartArgs) -> io::Result<()> {
.unwrap();
}

let fs_clone = fs.clone();
let resolver_clone = resolver.clone();
tokio::spawn(async move {
// make sure we have a certificate for ourselves
let certs = init_certs(fs, args.context.proxy_fqdn.clone(), acme_client.clone()).await;
resolver.serve_default_der(certs).await.unwrap();
request_gateway_certs(
fs,
args.context.proxy_fqdn.clone(),
acme_client.clone(),
resolver_clone.clone(),
)
.await;
automate_renewal(
fs_clone,
args.context.proxy_fqdn.clone(),
acme_client.clone(),
resolver_clone,
)
.await
});
} else {
warn!("TLS is disabled in the proxy service. This is only acceptable in testing, and should *never* be used in deployments.");
Expand Down Expand Up @@ -237,47 +254,93 @@ async fn init(db: SqlitePool, args: InitArgs) -> io::Result<()> {
Ok(())
}

async fn init_certs<P: AsRef<Path>>(fs: P, public: FQDN, acme: AcmeClient) -> ChainAndPrivateKey {
async fn request_gateway_certs<P: AsRef<Path>>(
fs: P,
public: FQDN,
acme: AcmeClient,
resolver: Arc<GatewayCertResolver>,
) -> ChainAndPrivateKey {
let tls_path = fs.as_ref().join("ssl.pem");

match ChainAndPrivateKey::load_pem(&tls_path) {
Ok(valid) => valid,
Err(_) => {
let creds_path = fs.as_ref().join("acme.json");
warn!(
"no valid certificate found at {}, creating one...",
tls_path.display()
);
let certs = create_gateway_certificate(fs, public, acme, resolver).await;
certs.clone().save_pem(&tls_path).unwrap();
certs
}
}
}

if !creds_path.exists() {
panic!(
"no ACME credentials found at {}, cannot continue with certificate creation",
creds_path.display()
);
}

let creds = std::fs::File::open(creds_path).unwrap();
let creds: AccountCredentials = serde_json::from_reader(&creds).unwrap();
let account = AccountWrapper::from(creds).0;

let identifier = format!("*.{public}");

// Use ::Dns01 challenge because that's the only supported
// challenge type for wildcard domains
let (chain, private_key) = acme
.create_certificate(&identifier, ChallengeType::Dns01, &account)
.await
.unwrap();

let mut buf = Vec::new();
buf.extend(chain.as_bytes());
buf.extend(private_key.as_bytes());
async fn create_gateway_certificate<P: AsRef<Path>>(
fs: P,
public: FQDN,
acme: AcmeClient,
resolver: Arc<GatewayCertResolver>,
) -> ChainAndPrivateKey {
let creds_path = fs.as_ref().join("acme.json");
let identifier = format!("*.{public}");
if !creds_path.exists() {
panic!(
"no ACME credentials found at {}, cannot continue with certificate creation",
creds_path.display()
);
}
let creds = std::fs::File::open(creds_path).expect("Invalid credentials path");
let creds: AccountCredentials =
serde_json::from_reader(&creds).expect("Can not parse admin credentials from path");
let account = Account::from_credentials(creds)
.map_err(|error| {
error!(
error = &error as &dyn std::error::Error,
"failed to convert acme credentials into account"
);
AcmeClientError::AccountCreation
})
.unwrap();

let certs = ChainAndPrivateKey::parse_pem(Cursor::new(buf)).unwrap();
// Use ::Dns01 challenge because that's the only supported
// challenge type for wildcard domains
let (chain, private_key) = acme
.create_certificate(&identifier, ChallengeType::Dns01, &account)
.await
.unwrap();

certs.clone().save_pem(&tls_path).unwrap();
let mut buf = Vec::new();
buf.extend(chain.as_bytes());
buf.extend(private_key.as_bytes());
let certs = ChainAndPrivateKey::parse_pem(Cursor::new(buf)).unwrap();
resolver.serve_default_der(certs.clone()).await.unwrap();
certs
}

certs
async fn automate_renewal<P: AsRef<Path> + Clone>(
fs: P,
public: FQDN,
acme: AcmeClient,
resolver: Arc<GatewayCertResolver>,
) {
loop {
tokio::time::sleep(Duration::from_secs(240)).await;
let certs =
request_gateway_certs(fs.clone(), public.clone(), acme.clone(), resolver.clone()).await;
// Safe because a 'ChainAndPrivateKey' is built from a PEM.
let chain_and_pk = certs.into_pem().unwrap();
let (_, x509_cert) = parse_x509_certificate(chain_and_pk.as_bytes()).unwrap();
let diff = x509_cert.validity().not_after.sub(ASN1Time::now()).unwrap();
if diff.whole_days() <= 30 {
let tls_path = fs.as_ref().join("ssl.pem");
let certs = create_gateway_certificate(
fs.clone(),
public.clone(),
acme.clone(),
resolver.clone(),
)
.await;
certs.save_pem(&tls_path).unwrap();
}
}
}
14 changes: 7 additions & 7 deletions gateway/src/service.rs
Original file line number Diff line number Diff line change
Expand Up @@ -485,14 +485,14 @@ impl GatewayService {

pub async fn create_custom_domain(
&self,
project_name: ProjectName,
project_name: &ProjectName,
fqdn: &Fqdn,
certs: &str,
private_key: &str,
) -> Result<(), Error> {
query("INSERT OR REPLACE INTO custom_domains (fqdn, project_name, certificate, private_key) VALUES (?1, ?2, ?3, ?4)")
.bind(fqdn.to_string())
.bind(&project_name)
.bind(project_name)
.bind(certs)
.bind(private_key)
.execute(&self.db)
Expand Down Expand Up @@ -558,7 +558,7 @@ impl GatewayService {
fqdn: &Fqdn,
acme_client: AcmeClient,
account: &Account,
project_name: ProjectName,
project_name: &ProjectName,
) -> Result<(String, String), Error> {
match self.project_details_for_custom_domain(fqdn).await {
Ok(CustomDomain {
Expand All @@ -568,9 +568,9 @@ impl GatewayService {
}) => Ok((certificate, private_key)),
Err(err) if err.kind() == ErrorKind::CustomDomainNotFound => {
let (certs, private_key) = acme_client
.create_certificate(&fqdn.to_string(), ChallengeType::Http01, &account)
.create_certificate(&fqdn.to_string(), ChallengeType::Http01, account)
.await?;
self.create_custom_domain(project_name.clone(), &fqdn, &certs, &private_key)
self.create_custom_domain(&project_name, &fqdn, &certs, &private_key)
.await?;
Ok((certs, private_key))
}
Expand Down Expand Up @@ -832,7 +832,7 @@ pub mod tests {
.await
.unwrap();

svc.create_custom_domain(project_name.clone(), &domain, certificate, private_key)
svc.create_custom_domain(&project_name, &domain, certificate, private_key)
.await
.unwrap();

Expand All @@ -849,7 +849,7 @@ pub mod tests {
let certificate = "dummy certificate update";
let private_key = "dummy private key update";

svc.create_custom_domain(project_name.clone(), &domain, certificate, private_key)
svc.create_custom_domain(&project_name, &domain, certificate, private_key)
.await
.unwrap();

Expand Down

0 comments on commit 2384988

Please sign in to comment.