From 197a344384e47a046e5013b34db8a9e54bac201b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Oddbj=C3=B8rn=20Gr=C3=B8dem?= <29732646+oddgrd@users.noreply.github.com> Date: Tue, 2 Apr 2024 14:10:55 +0200 Subject: [PATCH] feat: merge runtime updates in main ecs branch (#1709) * feat: runtime healthcheck, start runtime on 0.0.0.0 running on unspecified ip was necessary for the runner to be able to reach the runtime when they are running in separate containers * feat(proto): update runtime::get_client to work with * misc(proto): get client takes u16 port * feat: add health toggle to runtime * feat: set runtime to unhealthy if it doesn't start within 60s * feat: change runtime::get_client to take address * feat: kill runtime if it doesn't become healthy in time * feat: increase provisioning timeout duration --- proto/runtime.proto | 5 ++ proto/src/generated/runtime.rs | 63 +++++++++++++++++++++++ proto/src/lib.rs | 9 ++-- runtime/src/alpha.rs | 92 +++++++++++++++++++++++++++++----- service/src/runner.rs | 6 +-- 5 files changed, 153 insertions(+), 22 deletions(-) diff --git a/proto/runtime.proto b/proto/runtime.proto index 3748e61299..fd0c974b09 100644 --- a/proto/runtime.proto +++ b/proto/runtime.proto @@ -13,6 +13,8 @@ service Runtime { // Channel to notify a service has been stopped rpc SubscribeStop(SubscribeStopRequest) returns (stream SubscribeStopResponse); + + rpc HealthCheck(Ping) returns (Pong); } message LoadRequest { @@ -78,3 +80,6 @@ enum StopReason { // Service crashed Crash = 2; } + +message Ping {} +message Pong {} diff --git a/proto/src/generated/runtime.rs b/proto/src/generated/runtime.rs index 09eb3cca5f..fa7097a29a 100644 --- a/proto/src/generated/runtime.rs +++ b/proto/src/generated/runtime.rs @@ -73,6 +73,12 @@ pub struct SubscribeStopResponse { #[prost(string, tag = "2")] pub message: ::prost::alloc::string::String, } +#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct Ping {} +#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct Pong {} #[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, ::prost::Enumeration)] #[repr(i32)] pub enum StopReason { @@ -264,6 +270,23 @@ pub mod runtime_client { .insert(GrpcMethod::new("runtime.Runtime", "SubscribeStop")); self.inner.server_streaming(req, path, codec).await } + pub async fn health_check( + &mut self, + request: impl tonic::IntoRequest, + ) -> std::result::Result, tonic::Status> { + self.inner.ready().await.map_err(|e| { + tonic::Status::new( + tonic::Code::Unknown, + format!("Service was not ready: {}", e.into()), + ) + })?; + let codec = tonic::codec::ProstCodec::default(); + let path = http::uri::PathAndQuery::from_static("/runtime.Runtime/HealthCheck"); + let mut req = request.into_request(); + req.extensions_mut() + .insert(GrpcMethod::new("runtime.Runtime", "HealthCheck")); + self.inner.unary(req, path, codec).await + } } } /// Generated server implementations. @@ -298,6 +321,10 @@ pub mod runtime_server { &self, request: tonic::Request, ) -> std::result::Result, tonic::Status>; + async fn health_check( + &self, + request: tonic::Request, + ) -> std::result::Result, tonic::Status>; } #[derive(Debug)] pub struct RuntimeServer { @@ -534,6 +561,42 @@ pub mod runtime_server { }; Box::pin(fut) } + "/runtime.Runtime/HealthCheck" => { + #[allow(non_camel_case_types)] + struct HealthCheckSvc(pub Arc); + impl tonic::server::UnaryService for HealthCheckSvc { + type Response = super::Pong; + type Future = BoxFuture, tonic::Status>; + fn call(&mut self, request: tonic::Request) -> Self::Future { + let inner = Arc::clone(&self.0); + let fut = + async move { ::health_check(&inner, request).await }; + Box::pin(fut) + } + } + let accept_compression_encodings = self.accept_compression_encodings; + let send_compression_encodings = self.send_compression_encodings; + let max_decoding_message_size = self.max_decoding_message_size; + let max_encoding_message_size = self.max_encoding_message_size; + let inner = self.inner.clone(); + let fut = async move { + let inner = inner.0; + let method = HealthCheckSvc(inner); + let codec = tonic::codec::ProstCodec::default(); + let mut grpc = tonic::server::Grpc::new(codec) + .apply_compression_config( + accept_compression_encodings, + send_compression_encodings, + ) + .apply_max_message_size_config( + max_decoding_message_size, + max_encoding_message_size, + ); + let res = grpc.unary(method, req).await; + Ok(res) + }; + Box::pin(fut) + } _ => Box::pin(async move { Ok(http::Response::builder() .status(200) diff --git a/proto/src/lib.rs b/proto/src/lib.rs index b2676f35a2..0a6f221b60 100644 --- a/proto/src/lib.rs +++ b/proto/src/lib.rs @@ -146,16 +146,14 @@ mod _runtime_client { use tracing::{info, trace}; pub type Client = runtime_client::RuntimeClient< - shuttle_common::claims::ClaimService< - shuttle_common::claims::InjectPropagation, - >, + shuttle_common::claims::InjectPropagation, >; /// Get a runtime client that is correctly configured #[cfg(feature = "client")] - pub async fn get_client(port: &str) -> anyhow::Result { + pub async fn get_client(address: String) -> anyhow::Result { info!("connecting runtime client"); - let conn = Endpoint::new(format!("http://127.0.0.1:{port}")) + let conn = Endpoint::new(address) .context("creating runtime client endpoint")? .connect_timeout(Duration::from_secs(5)); @@ -177,7 +175,6 @@ mod _runtime_client { .context("runtime control port did not open in time")?; let runtime_service = tower::ServiceBuilder::new() - .layer(shuttle_common::claims::ClaimLayer) .layer(shuttle_common::claims::InjectPropagationLayer) .service(channel); diff --git a/runtime/src/alpha.rs b/runtime/src/alpha.rs index 0675b55f50..a3d5070bd3 100644 --- a/runtime/src/alpha.rs +++ b/runtime/src/alpha.rs @@ -2,9 +2,9 @@ use std::{ collections::BTreeMap, iter::FromIterator, net::{Ipv4Addr, SocketAddr}, - ops::DerefMut, + ops::{Deref, DerefMut}, str::FromStr, - sync::Mutex, + sync::{Arc, Mutex}, time::Duration, }; @@ -12,10 +12,13 @@ use anyhow::Context; use async_trait::async_trait; use core::future::Future; use shuttle_common::{extract_propagation::ExtractPropagationLayer, secrets::Secret}; -use shuttle_proto::runtime::{ - runtime_server::{Runtime, RuntimeServer}, - LoadRequest, LoadResponse, StartRequest, StartResponse, StopReason, StopRequest, StopResponse, - SubscribeStopRequest, SubscribeStopResponse, +use shuttle_proto::{ + runtime::{ + runtime_server::{Runtime, RuntimeServer}, + LoadRequest, LoadResponse, StartRequest, StartResponse, StopReason, StopRequest, + StopResponse, SubscribeStopRequest, SubscribeStopResponse, + }, + runtime::{Ping, Pong}, }; use shuttle_service::{ResourceFactory, Service}; use tokio::sync::{ @@ -84,23 +87,42 @@ pub async fn start(loader: impl Loader + Send + 'static, runner: impl Runner + S } // where to serve the gRPC control layer - let addr = SocketAddr::new(Ipv4Addr::LOCALHOST.into(), args.port); + let addr = SocketAddr::new(Ipv4Addr::UNSPECIFIED.into(), args.port); let mut server_builder = Server::builder() .http2_keepalive_interval(Some(Duration::from_secs(60))) .layer(ExtractPropagationLayer); + // A channel we can use to kill the runtime if it does not become healthy in time. + let (tx, rx) = tokio::sync::oneshot::channel::<()>(); + let router = { - let alpha = Alpha::new(loader, runner); + let alpha = Alpha::new(loader, runner, tx); let svc = RuntimeServer::new(alpha); server_builder.add_service(svc) }; - match router.serve(addr).await { - Ok(_) => {} - Err(e) => panic!("Error while serving address {addr}: {e}"), - }; + tokio::select! { + res = router.serve(addr) => { + match res{ + Ok(_) => {} + Err(e) => panic!("Error while serving address {addr}: {e}") + } + } + res = rx => { + match res{ + Ok(_) => panic!("Received runtime kill signal"), + Err(e) => panic!("Receiver error: {e}") + } + } + } +} + +pub enum State { + Unhealthy, + Loading, + Running, } pub struct Alpha { @@ -109,10 +131,14 @@ pub struct Alpha { kill_tx: Mutex>>, loader: Mutex>, runner: Mutex>, + /// The current state of the runtime, which is used by the ECS task to determine if the runtime + /// is healthy. + state: Arc>, + runtime_kill_tx: Mutex>>, } impl Alpha { - pub fn new(loader: L, runner: R) -> Self { + pub fn new(loader: L, runner: R, runtime_kill_tx: tokio::sync::oneshot::Sender<()>) -> Self { let (stopped_tx, _stopped_rx) = broadcast::channel(10); Self { @@ -120,6 +146,8 @@ impl Alpha { kill_tx: Mutex::new(None), loader: Mutex::new(Some(loader)), runner: Mutex::new(Some(runner)), + state: Arc::new(Mutex::new(State::Unhealthy)), + runtime_kill_tx: Mutex::new(Some(runtime_kill_tx)), } } } @@ -223,6 +251,31 @@ where } }; + println!("setting current state to healthy"); + *self.state.lock().unwrap() = State::Loading; + + let state = self.state.clone(); + let runtime_kill_tx = self + .runtime_kill_tx + .lock() + .unwrap() + .deref_mut() + .take() + .unwrap(); + + // Ensure that the runtime is set to unhealthy if it doesn't reach the running state after + // it has sent a load response, so that the ECS task will fail. + tokio::spawn(async move { + // Note: The timeout is quite low as we are not actually provisioning resources after + // sending the load response. + tokio::time::sleep(Duration::from_secs(180)).await; + if !matches!(state.lock().unwrap().deref(), State::Running) { + println!("the runtime failed to enter the running state before timing out"); + + runtime_kill_tx.send(()).unwrap(); + } + }); + Ok(Response::new(LoadResponse { success: true, message: String::new(), @@ -355,6 +408,8 @@ where ..Default::default() }; + *self.state.lock().unwrap() = State::Running; + Ok(Response::new(message)) } @@ -398,4 +453,15 @@ where Ok(Response::new(ReceiverStream::new(rx))) } + + async fn health_check(&self, _request: Request) -> Result, Status> { + if matches!(self.state.lock().unwrap().deref(), State::Unhealthy) { + println!("runtime health check failed"); + return Err(Status::unavailable( + "runtime has not reached a healthy state", + )); + } + + Ok(Response::new(Pong {})) + } } diff --git a/service/src/runner.rs b/service/src/runner.rs index 5cb949558d..b772e573f5 100644 --- a/service/src/runner.rs +++ b/service/src/runner.rs @@ -13,8 +13,8 @@ pub async fn start( runtime_executable: PathBuf, project_path: &Path, ) -> anyhow::Result<(process::Child, runtime::Client)> { - let port = &port.to_string(); - let args = vec!["--port", port]; + let port_str = port.to_string(); + let args = vec!["--port", &port_str]; info!( args = %format!("{} {}", runtime_executable.display(), args.join(" ")), @@ -30,7 +30,7 @@ pub async fn start( .spawn() .context("spawning runtime process")?; - let runtime_client = runtime::get_client(port).await?; + let runtime_client = runtime::get_client(format!("http://0.0.0.0:{port}")).await?; Ok((runtime, runtime_client)) }