From 4dfd65c4bc27eeb69e246c3eaf557e1f8596e47a Mon Sep 17 00:00:00 2001 From: Pieter Date: Wed, 2 Aug 2023 17:24:54 +0200 Subject: [PATCH] feat: gateway to start last deploy from idle project (#1121) * refactor: stop cycling resources * feat: gateway start idle deploys * refactor: comment deploy scopes * refactor: better name * bug: use current time * bug: skip newest not oldest * refactor: reverse order from query --------- Co-authored-by: Iulian Barbu <14218860+iulianbarbu@users.noreply.github.com> --- auth/src/user.rs | 5 +- common/src/claims.rs | 8 ++- deployer/src/deployment/deploy_layer.rs | 6 +- deployer/src/deployment/queue.rs | 2 +- deployer/src/deployment/run.rs | 31 ++++------ deployer/src/handlers/mod.rs | 4 +- deployer/src/lib.rs | 21 +++---- deployer/src/persistence/mod.rs | 18 +++++- gateway/src/api/latest.rs | 8 +-- gateway/src/lib.rs | 6 +- gateway/src/project.rs | 78 ++++++++++++++++++++++--- gateway/src/service.rs | 71 ++++++++++++++++++---- gateway/src/task.rs | 27 +++++---- 13 files changed, 206 insertions(+), 79 deletions(-) diff --git a/auth/src/user.rs b/auth/src/user.rs index d9fec5269..b6e64a2e3 100644 --- a/auth/src/user.rs +++ b/auth/src/user.rs @@ -376,7 +376,10 @@ mod tests { fn deployer_machine() { let scopes: Vec = AccountTier::Deployer.into(); - assert_eq!(scopes, vec![Scope::DeploymentPush, Scope::Resources]); + assert_eq!( + scopes, + vec![Scope::DeploymentPush, Scope::Resources, Scope::Service] + ); } } } diff --git a/common/src/claims.rs b/common/src/claims.rs index 0f4c0e0fc..efa4d0aa4 100644 --- a/common/src/claims.rs +++ b/common/src/claims.rs @@ -131,7 +131,11 @@ impl ScopeBuilder { /// Extend the current scopes with those needed by a deployer machine / user. pub fn with_deploy_rights(mut self) -> Self { - self.0.extend(vec![Scope::DeploymentPush, Scope::Resources]); + self.0.extend(vec![ + Scope::DeploymentPush, // To start an idle deploy + Scope::Resources, // To get past resources for an idle deploy + Scope::Service, // To get the running deploy for a service + ]); self } @@ -146,7 +150,7 @@ impl Default for ScopeBuilder { } } -#[derive(Clone, Debug, Deserialize, Serialize, Eq, PartialEq)] +#[derive(Clone, Debug, Default, Deserialize, Serialize, Eq, PartialEq)] pub struct Claim { /// Expiration time (as UTC timestamp). pub exp: usize, diff --git a/deployer/src/deployment/deploy_layer.rs b/deployer/src/deployment/deploy_layer.rs index 64a5fda04..4b72fe11c 100644 --- a/deployer/src/deployment/deploy_layer.rs +++ b/deployer/src/deployment/deploy_layer.rs @@ -828,7 +828,7 @@ mod tests { service_id: Uuid::new_v4(), tracing_context: Default::default(), is_next: false, - claim: None, + claim: Default::default(), }) .await; @@ -872,7 +872,7 @@ mod tests { data: Bytes::from("violets are red").to_vec(), will_run_tests: false, tracing_context: Default::default(), - claim: None, + claim: Default::default(), }) .await; @@ -931,7 +931,7 @@ mod tests { data: bytes, will_run_tests: false, tracing_context: Default::default(), - claim: None, + claim: Default::default(), } } } diff --git a/deployer/src/deployment/queue.rs b/deployer/src/deployment/queue.rs index a59a20a72..55f0d6aa5 100644 --- a/deployer/src/deployment/queue.rs +++ b/deployer/src/deployment/queue.rs @@ -165,7 +165,7 @@ pub struct Queued { pub data: Vec, pub will_run_tests: bool, pub tracing_context: HashMap, - pub claim: Option, + pub claim: Claim, } impl Queued { diff --git a/deployer/src/deployment/run.rs b/deployer/src/deployment/run.rs index ebcb3b501..19123cbaa 100644 --- a/deployer/src/deployment/run.rs +++ b/deployer/src/deployment/run.rs @@ -206,7 +206,7 @@ pub struct Built { pub service_id: Uuid, pub tracing_context: HashMap, pub is_next: bool, - pub claim: Option, + pub claim: Claim, } impl Built { @@ -286,7 +286,7 @@ async fn load( secret_getter: impl SecretGetter, resource_manager: impl ResourceManager, mut runtime_client: RuntimeClient>>, - claim: Option, + claim: Claim, ) -> Result<()> { info!( "loading project from: {}", @@ -297,19 +297,14 @@ async fn load( .unwrap_or_default() ); - // Get resources from cache when a claim is not set (ie an idl project is started) - let resources = if claim.is_none() { - resource_manager - .get_resources(&service_id) - .await - .unwrap() - .into_iter() - .map(resource::Response::from) - .map(resource::Response::into_bytes) - .collect() - } else { - Default::default() - }; + let resources = resource_manager + .get_resources(&service_id) + .await + .unwrap() + .into_iter() + .map(resource::Response::from) + .map(resource::Response::into_bytes) + .collect(); let secrets = secret_getter .get_secrets(&service_id) @@ -329,9 +324,7 @@ async fn load( secrets, }); - if let Some(claim) = claim { - load_request.extensions_mut().insert(claim); - } + load_request.extensions_mut().insert(claim); debug!(service_name = %service_name, "loading service"); let response = runtime_client.load(load_request).await; @@ -749,7 +742,7 @@ mod tests { service_id: Uuid::new_v4(), tracing_context: Default::default(), is_next: false, - claim: None, + claim: Default::default(), }, storage_manager, ) diff --git a/deployer/src/handlers/mod.rs b/deployer/src/handlers/mod.rs index 63f7b3170..7ccc61043 100644 --- a/deployer/src/handlers/mod.rs +++ b/deployer/src/handlers/mod.rs @@ -356,7 +356,7 @@ pub async fn create_service( data: deployment_req.data, will_run_tests: !deployment_req.no_test, tracing_context: Default::default(), - claim: Some(claim), + claim, }; deployment_manager.queue_push(queued).await; @@ -521,7 +521,7 @@ pub async fn start_deployment( service_id: deployment.service_id, tracing_context: Default::default(), is_next: deployment.is_next, - claim: Some(claim), + claim, }; deployment_manager.run_push(built).await; diff --git a/deployer/src/lib.rs b/deployer/src/lib.rs index 0448f6577..272abbc01 100644 --- a/deployer/src/lib.rs +++ b/deployer/src/lib.rs @@ -2,7 +2,7 @@ use std::{convert::Infallible, net::SocketAddr, sync::Arc}; pub use args::Args; pub use deployment::deploy_layer::DeployLayer; -use deployment::{Built, DeploymentManager}; +use deployment::DeploymentManager; use fqdn::FQDN; use hyper::{ server::conn::AddrStream, @@ -45,17 +45,14 @@ pub async fn start( persistence.cleanup_invalid_states().await.unwrap(); let runnable_deployments = persistence.get_all_runnable_deployments().await.unwrap(); - info!(count = %runnable_deployments.len(), "enqueuing runnable deployments"); - for existing_deployment in runnable_deployments { - let built = Built { - id: existing_deployment.id, - service_name: existing_deployment.service_name, - service_id: existing_deployment.service_id, - tracing_context: Default::default(), - is_next: existing_deployment.is_next, - claim: None, // This will cause us to read the resource info from past provisions - }; - deployment_manager.run_push(built).await; + info!(count = %runnable_deployments.len(), "stopping all but last running deploy"); + + // Make sure we don't stop the last running deploy. This works because they are returned in descending order. + for existing_deployment in runnable_deployments.into_iter().skip(1) { + persistence + .stop_running_deployment(existing_deployment) + .await + .unwrap(); } let mut builder = handlers::RouterBuilder::new( diff --git a/deployer/src/persistence/mod.rs b/deployer/src/persistence/mod.rs index 97f273411..ccf6e8f04 100644 --- a/deployer/src/persistence/mod.rs +++ b/deployer/src/persistence/mod.rs @@ -288,7 +288,7 @@ impl Persistence { FROM deployments AS d JOIN services AS s ON s.id = d.service_id WHERE state = ? - ORDER BY last_update"#, + ORDER BY last_update DESC"#, ) .bind(State::Running) .fetch_all(&self.pool) @@ -328,6 +328,18 @@ impl Persistence { pub fn get_log_sender(&self) -> crossbeam_channel::Sender { self.log_send.clone() } + + pub async fn stop_running_deployment(&self, deployable: DeploymentRunnable) -> Result<()> { + update_deployment( + &self.pool, + DeploymentState { + id: deployable.id, + last_update: Utc::now(), + state: State::Stopped, + }, + ) + .await + } } async fn update_deployment(pool: &SqlitePool, state: impl Into) -> Result<()> { @@ -935,7 +947,7 @@ mod tests { runnable, [ DeploymentRunnable { - id: id_1, + id: id_3, service_name: "foo".to_string(), service_id: foo_id, is_next: false, @@ -947,7 +959,7 @@ mod tests { is_next: true, }, DeploymentRunnable { - id: id_3, + id: id_1, service_name: "foo".to_string(), service_id: foo_id, is_next: false, diff --git a/gateway/src/api/latest.rs b/gateway/src/api/latest.rs index b3e773c2b..99f4c7d6f 100644 --- a/gateway/src/api/latest.rs +++ b/gateway/src/api/latest.rs @@ -182,6 +182,8 @@ async fn create_project( service .new_task() .project(project.clone()) + .and_then(task::run_until_done()) + .and_then(task::start_idle_deploys()) .send(&sender) .await?; @@ -291,11 +293,7 @@ async fn get_status( // Compute auth status. let auth_status = { - let response = AUTH_CLIENT - .get_or_init(reqwest::Client::new) - .get(service.auth_uri().to_string()) - .send() - .await; + let response = AUTH_CLIENT.get(service.auth_uri().clone()).await; match response { Ok(response) if response.status() == 200 => StatusResponse::healthy(), Ok(_) | Err(_) => StatusResponse::unhealthy(), diff --git a/gateway/src/lib.rs b/gateway/src/lib.rs index 9e192a636..3848ad7b6 100644 --- a/gateway/src/lib.rs +++ b/gateway/src/lib.rs @@ -7,13 +7,15 @@ use std::fmt::Formatter; use std::io; use std::pin::Pin; use std::str::FromStr; -use std::sync::OnceLock; use acme::AcmeClientError; use axum::response::{IntoResponse, Response}; use axum::Json; use bollard::Docker; use futures::prelude::*; +use hyper::client::HttpConnector; +use hyper::Client; +use once_cell::sync::Lazy; use serde::{Deserialize, Deserializer, Serialize}; use service::ContainerSettings; use shuttle_common::models::error::{ApiError, ErrorKind}; @@ -31,7 +33,7 @@ pub mod task; pub mod tls; pub mod worker; -static AUTH_CLIENT: OnceLock = OnceLock::new(); +static AUTH_CLIENT: Lazy> = Lazy::new(Client::new); /// Server-side errors that do not have to do with the user runtime /// should be [`Error`]s. diff --git a/gateway/src/project.rs b/gateway/src/project.rs index 53b85fec2..4910f42ee 100644 --- a/gateway/src/project.rs +++ b/gateway/src/project.rs @@ -13,16 +13,20 @@ use bollard::network::{ConnectNetworkOptions, DisconnectNetworkOptions}; use bollard::system::EventsOptions; use fqdn::FQDN; use futures::prelude::*; +use http::header::AUTHORIZATION; use http::uri::InvalidUri; -use http::Uri; +use http::{Method, Request, Uri}; use hyper::client::HttpConnector; -use hyper::Client; +use hyper::{Body, Client}; use once_cell::sync::Lazy; use rand::distributions::{Alphanumeric, DistString}; use serde::{Deserialize, Serialize}; +use shuttle_common::backends::headers::{X_SHUTTLE_ACCOUNT_NAME, X_SHUTTLE_ADMIN_SECRET}; use shuttle_common::models::project::{idle_minutes, IDLE_MINUTES}; +use shuttle_common::models::service; use tokio::time::{sleep, timeout}; -use tracing::{debug, error, info, instrument}; +use tracing::{debug, error, info, instrument, trace}; +use uuid::Uuid; use crate::service::ContainerSettings; use crate::{ @@ -72,7 +76,7 @@ const MAX_REBOOTS: usize = 3; // Client used for health checks static CLIENT: Lazy> = Lazy::new(Client::new); // Health check must succeed within 10 seconds -static IS_HEALTHY_TIMEOUT: Duration = Duration::from_secs(10); +pub static IS_HEALTHY_TIMEOUT: Duration = Duration::from_secs(10); #[async_trait] impl Refresh for ContainerInspectResponse @@ -1178,8 +1182,10 @@ impl ProjectReady { self.service.is_healthy().await } - pub async fn start_last_deploy(&mut self, api_key: String) { - self.service.start_last_deploy(api_key).await + pub async fn start_last_deploy(&mut self, jwt: String, admin_secret: String) { + if let Err(error) = self.service.start_last_deploy(jwt, admin_secret).await { + error!(error, "failed to start last running deploy"); + }; } } @@ -1239,8 +1245,64 @@ impl Service { is_healthy } - pub async fn start_last_deploy(&mut self, _api_key: String) { - // TODO: convert the key to a JWT, get last deployment and start it (ENG-816) + pub async fn start_last_deploy( + &mut self, + jwt: String, + admin_secret: String, + ) -> Result<(), Box> { + trace!(jwt, "getting last deploy"); + + let running_id = self.get_running_deploy(&jwt, &admin_secret).await?; + + trace!(?running_id, "starting deploy"); + + if let Some(running_id) = running_id { + // Start this deployment + let uri = self.uri(format!( + "/projects/{}/deployments/{}", + self.name, running_id + ))?; + + let req = Request::builder() + .method(Method::PUT) + .uri(uri) + .header(AUTHORIZATION, format!("Bearer {}", jwt)) + .header(X_SHUTTLE_ACCOUNT_NAME.clone(), "gateway") + .header(X_SHUTTLE_ADMIN_SECRET.clone(), admin_secret) + .body(Body::empty())?; + + let _ = timeout(IS_HEALTHY_TIMEOUT, CLIENT.request(req)).await; + } + + Ok(()) + } + + /// Get the last running deployment + async fn get_running_deploy( + &self, + jwt: &str, + admin_secret: &str, + ) -> Result, Box> { + let uri = self.uri(format!("/projects/{}/services/{}", self.name, self.name))?; + + let req = Request::builder() + .uri(uri) + .header(AUTHORIZATION, format!("Bearer {}", jwt)) + .header(X_SHUTTLE_ACCOUNT_NAME.clone(), "gateway") + .header(X_SHUTTLE_ADMIN_SECRET.clone(), admin_secret) + .body(Body::empty())?; + + let resp = timeout(IS_HEALTHY_TIMEOUT, CLIENT.request(req)).await??; + + let body = hyper::body::to_bytes(resp.into_body()).await?; + + let service: service::Summary = serde_json::from_slice(&body)?; + + if let Some(deployment) = service.deployment { + Ok(Some(deployment.id)) + } else { + Ok(None) + } } } diff --git a/gateway/src/service.rs b/gateway/src/service.rs index d57b11030..a1431eb03 100644 --- a/gateway/src/service.rs +++ b/gateway/src/service.rs @@ -10,6 +10,7 @@ use axum::http::Request; use axum::response::Response; use bollard::{Docker, API_DEFAULT_VERSION}; use fqdn::{Fqdn, FQDN}; +use http::header::AUTHORIZATION; use http::Uri; use hyper::client::connect::dns::GaiResolver; use hyper::client::HttpConnector; @@ -19,6 +20,7 @@ use instant_acme::{AccountCredentials, ChallengeType}; use once_cell::sync::Lazy; use opentelemetry::global; use opentelemetry_http::HeaderInjector; +use serde_json::Value; use shuttle_common::backends::headers::{XShuttleAccountName, XShuttleAdminSecret}; use sqlx::error::DatabaseError; use sqlx::migrate::Migrator; @@ -26,8 +28,9 @@ use sqlx::sqlite::SqlitePool; use sqlx::types::Json as SqlxJson; use sqlx::{query, Error as SqlxError, QueryBuilder, Row}; use tokio::sync::mpsc::Sender; +use tokio::time::timeout; use tonic::transport::Endpoint; -use tracing::{debug, trace, warn, Span}; +use tracing::{debug, instrument, trace, warn, Span}; use tracing_opentelemetry::OpenTelemetrySpanExt; use x509_parser::nom::AsBytes; use x509_parser::parse_x509_certificate; @@ -36,11 +39,13 @@ use x509_parser::time::ASN1Time; use crate::acme::{AccountWrapper, AcmeClient, CustomDomain}; use crate::args::ContextArgs; -use crate::project::{Project, ProjectCreating}; +use crate::project::{Project, ProjectCreating, IS_HEALTHY_TIMEOUT}; use crate::task::{self, BoxedTask, TaskBuilder}; use crate::tls::{ChainAndPrivateKey, GatewayCertResolver, RENEWAL_VALIDITY_THRESHOLD_IN_DAYS}; use crate::worker::TaskRouter; -use crate::{AccountName, DockerContext, Error, ErrorKind, ProjectDetails, ProjectName}; +use crate::{ + AccountName, DockerContext, Error, ErrorKind, ProjectDetails, ProjectName, AUTH_CLIENT, +}; pub static MIGRATIONS: Migrator = sqlx::migrate!("./migrations"); static PROXY_CLIENT: Lazy>> = @@ -169,17 +174,31 @@ impl ContainerSettings { pub struct GatewayContextProvider { docker: Docker, settings: ContainerSettings, + api_key: String, + auth_key_uri: Uri, } impl GatewayContextProvider { - pub fn new(docker: Docker, settings: ContainerSettings) -> Self { - Self { docker, settings } + pub fn new( + docker: Docker, + settings: ContainerSettings, + api_key: String, + auth_key_uri: Uri, + ) -> Self { + Self { + docker, + settings, + api_key, + auth_key_uri, + } } pub fn context(&self) -> GatewayContext { GatewayContext { docker: self.docker.clone(), settings: self.settings.clone(), + api_key: self.api_key.clone(), + auth_key_uri: self.auth_key_uri.clone(), } } } @@ -193,7 +212,6 @@ pub struct GatewayService { // We store these because we'll need them for the health checks provisioner_host: Endpoint, auth_host: Uri, - api_key: String, } impl GatewayService { @@ -206,7 +224,12 @@ impl GatewayService { let container_settings = ContainerSettings::builder().from_args(&args).await; - let provider = GatewayContextProvider::new(docker, container_settings); + let provider = GatewayContextProvider::new( + docker, + container_settings, + args.deploys_api_key, + format!("{}auth/key", args.auth_uri).parse().unwrap(), + ); let task_router = TaskRouter::new(); Self { @@ -217,7 +240,6 @@ impl GatewayService { provisioner_host: Endpoint::new(format!("http://{}:8000", args.provisioner_host)) .expect("to have a valid provisioner endpoint"), auth_host: args.auth_uri, - api_key: args.deploys_api_key, } } @@ -759,15 +781,14 @@ impl GatewayService { pub fn auth_uri(&self) -> &Uri { &self.auth_host } - pub fn api_key(&self) -> String { - self.api_key.clone() - } } #[derive(Clone)] pub struct GatewayContext { docker: Docker, settings: ContainerSettings, + api_key: String, + auth_key_uri: Uri, } impl DockerContext for GatewayContext { @@ -780,6 +801,34 @@ impl DockerContext for GatewayContext { } } +impl GatewayContext { + #[instrument(skip(self), fields(auth_key_uri = %self.auth_key_uri, api_key = self.api_key))] + pub async fn get_jwt(&self) -> String { + let req = Request::builder() + .uri(self.auth_key_uri.clone()) + .header(AUTHORIZATION, format!("Bearer {}", self.api_key)) + .body(Body::empty()) + .unwrap(); + + trace!("getting jwt"); + + let resp = timeout(IS_HEALTHY_TIMEOUT, AUTH_CLIENT.request(req)).await; + + if let Ok(Ok(resp)) = resp { + let body = hyper::body::to_bytes(resp.into_body()) + .await + .unwrap_or_default(); + let convert: Value = serde_json::from_slice(&body).unwrap_or_default(); + + trace!(?convert, "got jwt response"); + + convert["token"].as_str().unwrap_or_default().to_string() + } else { + Default::default() + } + } +} + #[cfg(test)] pub mod tests { use fqdn::FQDN; diff --git a/gateway/src/task.rs b/gateway/src/task.rs index 54c1dc28f..a5c6e4374 100644 --- a/gateway/src/task.rs +++ b/gateway/src/task.rs @@ -137,15 +137,14 @@ pub fn start() -> impl Task { pub fn start_idle_deploys() -> impl Task { run(|ctx| async move { - match ctx.state.refresh(&ctx.gateway).await { - Ok(Project::Ready(mut ready)) => { - let api_key = ctx.api_key; - - ready.start_last_deploy(api_key).await; + match ctx.state { + Project::Ready(mut ready) => { + ready + .start_last_deploy(ctx.gateway.get_jwt().await, ctx.admin_secret.clone()) + .await; TaskResult::Done(Project::Ready(ready)) } - Ok(update) => TaskResult::Done(update), - Err(err) => TaskResult::Err(err), + other => TaskResult::Done(other), } }) } @@ -438,8 +437,8 @@ pub struct ProjectContext { pub gateway: GatewayContext, /// The last known state of the project pub state: Project, - /// The api key for this machine - pub api_key: String, + /// The secret needed to communicate with the project + pub admin_secret: String, } pub type BoxedTask = Box>; @@ -473,13 +472,21 @@ where Ok(account_name) => account_name, Err(err) => return TaskResult::Err(err), }; + let admin_secret = match self + .service + .control_key_from_project_name(&self.project_name) + .await + { + Ok(account_name) => account_name, + Err(err) => return TaskResult::Err(err), + }; let project_ctx = ProjectContext { project_name: self.project_name.clone(), account_name: account_name.clone(), gateway: ctx, state: project, - api_key: self.service.api_key(), + admin_secret, }; let span = info_span!(