diff --git a/coordinator/src/handlers/block_streams.rs b/coordinator/src/handlers/block_streams.rs index b3e9fb5c..c3dc21e4 100644 --- a/coordinator/src/handlers/block_streams.rs +++ b/coordinator/src/handlers/block_streams.rs @@ -3,31 +3,76 @@ use std::time::{Duration, SystemTime}; pub use block_streamer::StreamInfo; +use block_streamer::{StartStreamResponse, StopStreamResponse}; use anyhow::Context; use block_streamer::block_streamer_client::BlockStreamerClient; use block_streamer::{ start_stream_request::Rule, ActionAnyRule, ActionFunctionCallRule, GetStreamRequest, - ListStreamsRequest, ProcessingState, StartStreamRequest, Status, StopStreamRequest, + ProcessingState, StartStreamRequest, Status, StopStreamRequest, }; use near_primitives::types::AccountId; use registry_types::StartBlock; use tonic::transport::channel::Channel; -use tonic::Request; use crate::indexer_config::IndexerConfig; use crate::redis::{KeyProvider, RedisClient}; -use crate::utils::exponential_retry; const RESTART_TIMEOUT_SECONDS: u64 = 600; +#[cfg(not(test))] +use BlockStreamsClientWrapperImpl as BlockStreamsClientWrapper; +#[cfg(test)] +use MockBlockStreamsClientWrapperImpl as BlockStreamsClientWrapper; + +#[derive(Clone)] +struct BlockStreamsClientWrapperImpl { + inner: BlockStreamerClient, +} + +#[cfg_attr(test, mockall::automock)] +impl BlockStreamsClientWrapperImpl { + pub fn new(inner: BlockStreamerClient) -> Self { + Self { inner } + } + + pub async fn stop_stream( + &self, + request: R, + ) -> std::result::Result, tonic::Status> + where + R: tonic::IntoRequest + 'static, + { + self.inner.clone().stop_stream(request).await + } + + pub async fn get_stream( + &self, + request: R, + ) -> std::result::Result, tonic::Status> + where + R: tonic::IntoRequest + 'static, + { + self.inner.clone().get_stream(request).await + } + + pub async fn start_stream( + &self, + request: R, + ) -> std::result::Result, tonic::Status> + where + R: tonic::IntoRequest + 'static, + { + self.inner.clone().start_stream(request).await + } +} + #[derive(Clone)] pub struct BlockStreamsHandler { - client: BlockStreamerClient, + client: BlockStreamsClientWrapper, redis_client: RedisClient, } -#[cfg_attr(test, mockall::automock)] impl BlockStreamsHandler { pub fn connect(block_streamer_url: &str, redis_client: RedisClient) -> anyhow::Result { let channel = Channel::from_shared(block_streamer_url.to_string()) @@ -36,38 +81,17 @@ impl BlockStreamsHandler { let client = BlockStreamerClient::new(channel); Ok(Self { - client, + client: BlockStreamsClientWrapper::new(client), redis_client, }) } - pub async fn list(&self) -> anyhow::Result> { - exponential_retry(|| async { - let response = self - .client - .clone() - .list_streams(Request::new(ListStreamsRequest {})) - .await - .context("Failed to list streams")?; - - let streams = response.into_inner().streams; - - tracing::debug!("List streams response: {:#?}", streams); - - Ok(streams) - }) - .await - } - pub async fn stop(&self, stream_id: String) -> anyhow::Result<()> { - let request = StopStreamRequest { - stream_id: stream_id.clone(), - }; - let response = self .client - .clone() - .stop_stream(Request::new(request.clone())) + .stop_stream(StopStreamRequest { + stream_id: stream_id.clone(), + }) .await .context(format!("Failed to stop stream: {stream_id}"))?; @@ -95,7 +119,7 @@ impl BlockStreamsHandler { function_name: function_name.clone(), }; - match self.client.clone().get_stream(Request::new(request)).await { + match self.client.get_stream(request).await { Ok(response) => Ok(Some(response.into_inner())), Err(status) if status.code() == tonic::Code::NotFound => Ok(None), Err(err) => Err(err).context(format!( @@ -145,15 +169,10 @@ impl BlockStreamsHandler { rule: Some(rule), }; - let response = self - .client - .clone() - .start_stream(Request::new(request.clone())) - .await - .context(format!( - "Failed to start stream: {}", - indexer_config.get_full_name() - ))?; + let response = self.client.start_stream(request).await.context(format!( + "Failed to start stream: {}", + indexer_config.get_full_name() + ))?; tracing::debug!( account_id = indexer_config.account_id.as_str(), @@ -166,7 +185,7 @@ impl BlockStreamsHandler { Ok(()) } - async fn reconfigure_block_stream(&self, config: &IndexerConfig) -> anyhow::Result<()> { + async fn reconfigure(&self, config: &IndexerConfig) -> anyhow::Result<()> { if matches!( config.start_block, StartBlock::Latest | StartBlock::Height(..) @@ -229,7 +248,7 @@ impl BlockStreamsHandler { Ok(height) } - async fn resume_block_stream(&self, config: &IndexerConfig) -> anyhow::Result<()> { + async fn resume(&self, config: &IndexerConfig) -> anyhow::Result<()> { let height = self.get_continuation_block_height(config).await?; tracing::info!(height, "Resuming block stream"); @@ -277,7 +296,7 @@ impl BlockStreamsHandler { Ok(()) } - pub async fn synchronise_block_stream( + pub async fn synchronise( &self, config: &IndexerConfig, previous_sync_version: Option, @@ -299,7 +318,7 @@ impl BlockStreamsHandler { self.stop(block_stream.stream_id.clone()).await?; - self.reconfigure_block_stream(config).await?; + self.reconfigure(config).await?; return Ok(()); } @@ -311,12 +330,12 @@ impl BlockStreamsHandler { } if previous_sync_version.unwrap() != config.get_registry_version() { - self.reconfigure_block_stream(config).await?; + self.reconfigure(config).await?; return Ok(()); } - self.resume_block_stream(config).await?; + self.resume(config).await?; Ok(()) } @@ -335,3 +354,382 @@ impl BlockStreamsHandler { Ok(()) } } + +#[cfg(test)] +mod tests { + use super::*; + + use mockall::predicate::*; + use tonic::Response; + + impl Clone for MockBlockStreamsClientWrapperImpl { + fn clone(&self) -> Self { + Self::default() + } + } + + #[tokio::test] + async fn resumes_stopped_streams() { + let config = IndexerConfig::default(); + let last_published_block = 10; + + let mut mock_client = BlockStreamsClientWrapper::default(); + mock_client + .expect_get_stream::() + .with(eq(GetStreamRequest { + account_id: config.account_id.to_string(), + function_name: config.function_name.clone(), + })) + .returning(|_| Err(tonic::Status::not_found("not found"))); + mock_client + .expect_start_stream::() + .with(eq(StartStreamRequest { + account_id: config.account_id.to_string(), + function_name: config.function_name.clone(), + redis_stream: config.get_redis_stream_key(), + rule: Some(Rule::ActionAnyRule(ActionAnyRule { + affected_account_id: "queryapi.dataplatform.near".to_string(), + status: Status::Any.into(), + })), + start_block_height: last_published_block + 1, + version: config.get_registry_version(), + })) + .returning(|_| Ok(Response::new(StartStreamResponse::default()))); + + let mut mock_redis = RedisClient::default(); + mock_redis + .expect_get_last_published_block::() + .returning(move |_| Ok(Some(last_published_block))); + + let handler = BlockStreamsHandler { + client: mock_client, + redis_client: mock_redis, + }; + + handler + .synchronise(&config, Some(config.get_registry_version())) + .await + .unwrap(); + } + + #[tokio::test] + async fn reconfigures_outdated_streams() { + let config = IndexerConfig::default(); + + let existing_stream = StreamInfo { + account_id: config.account_id.to_string(), + function_name: config.function_name.clone(), + stream_id: "stream-id".to_string(), + version: config.get_registry_version() - 1, + health: None, + }; + + let mut mock_client = BlockStreamsClientWrapper::default(); + mock_client + .expect_stop_stream::() + .with(eq(StopStreamRequest { + stream_id: existing_stream.stream_id.clone(), + })) + .returning(|_| Ok(Response::new(StopStreamResponse::default()))); + mock_client + .expect_get_stream::() + .with(eq(GetStreamRequest { + account_id: config.account_id.to_string(), + function_name: config.function_name.clone(), + })) + .returning(move |_| Ok(Response::new(existing_stream.clone()))); + mock_client + .expect_start_stream::() + .with(eq(StartStreamRequest { + account_id: config.account_id.to_string(), + function_name: config.function_name.clone(), + redis_stream: config.get_redis_stream_key(), + rule: Some(Rule::ActionAnyRule(ActionAnyRule { + affected_account_id: "queryapi.dataplatform.near".to_string(), + status: Status::Any.into(), + })), + start_block_height: if let StartBlock::Height(height) = config.start_block { + height + } else { + unreachable!() + }, + version: config.get_registry_version(), + })) + .returning(|_| Ok(Response::new(StartStreamResponse::default()))); + + let mut mock_redis = RedisClient::default(); + mock_redis + .expect_clear_block_stream::() + .returning(|_| Ok(())) + .once(); + + let handler = BlockStreamsHandler { + client: mock_client, + redis_client: mock_redis, + }; + + handler + .synchronise(&config, Some(config.get_registry_version())) + .await + .unwrap(); + } + + #[tokio::test] + async fn starts_new_streams() { + let config = IndexerConfig::default(); + + let mut mock_client = BlockStreamsClientWrapper::default(); + mock_client + .expect_get_stream::() + .with(eq(GetStreamRequest { + account_id: config.account_id.to_string(), + function_name: config.function_name.clone(), + })) + .returning(|_| Err(tonic::Status::not_found("not found"))); + mock_client + .expect_start_stream::() + .with(eq(StartStreamRequest { + account_id: config.account_id.to_string(), + function_name: config.function_name.clone(), + redis_stream: config.get_redis_stream_key(), + rule: Some(Rule::ActionAnyRule(ActionAnyRule { + affected_account_id: "queryapi.dataplatform.near".to_string(), + status: Status::Any.into(), + })), + start_block_height: if let StartBlock::Height(height) = config.start_block { + height + } else { + unreachable!() + }, + version: config.get_registry_version(), + })) + .returning(|_| Ok(Response::new(StartStreamResponse::default()))); + + let mock_redis = RedisClient::default(); + + let handler = BlockStreamsHandler { + client: mock_client, + redis_client: mock_redis, + }; + + handler.synchronise(&config, None).await.unwrap(); + } + + #[tokio::test] + async fn reconfigures_outdated_and_stopped_streams() { + let config = IndexerConfig { + start_block: StartBlock::Latest, + ..Default::default() + }; + + let mut mock_client = BlockStreamsClientWrapper::default(); + mock_client + .expect_get_stream::() + .with(eq(GetStreamRequest { + account_id: config.account_id.to_string(), + function_name: config.function_name.clone(), + })) + .returning(|_| Err(tonic::Status::not_found("not found"))); + mock_client + .expect_start_stream::() + .with(eq(StartStreamRequest { + account_id: config.account_id.to_string(), + function_name: config.function_name.clone(), + redis_stream: config.get_redis_stream_key(), + rule: Some(Rule::ActionAnyRule(ActionAnyRule { + affected_account_id: "queryapi.dataplatform.near".to_string(), + status: Status::Any.into(), + })), + start_block_height: config.get_registry_version(), + version: config.get_registry_version(), + })) + .returning(|_| Ok(Response::new(StartStreamResponse::default()))); + + let mut mock_redis = RedisClient::default(); + mock_redis + .expect_clear_block_stream::() + .returning(|_| Ok(())) + .once(); + + let handler = BlockStreamsHandler { + client: mock_client, + redis_client: mock_redis, + }; + + handler + .synchronise(&config, Some(config.get_registry_version() - 1)) + .await + .unwrap(); + } + + #[tokio::test] + async fn restarts_unhealthy_streams() { + tokio::time::pause(); + + let config = IndexerConfig::default(); + let last_published_block = 10; + + let existing_stream = StreamInfo { + account_id: config.account_id.to_string(), + function_name: config.function_name.clone(), + stream_id: "stream-id".to_string(), + version: config.get_registry_version(), + health: Some(block_streamer::Health { + updated_at_timestamp_secs: SystemTime::now() + .duration_since(SystemTime::UNIX_EPOCH) + .unwrap() + .as_secs(), + processing_state: ProcessingState::Stalled.into(), + }), + }; + + let mut mock_client = BlockStreamsClientWrapper::default(); + mock_client + .expect_stop_stream::() + .with(eq(StopStreamRequest { + stream_id: existing_stream.stream_id.clone(), + })) + .returning(|_| Ok(Response::new(StopStreamResponse::default()))); + mock_client + .expect_get_stream::() + .with(eq(GetStreamRequest { + account_id: config.account_id.to_string(), + function_name: config.function_name.clone(), + })) + .returning(move |_| Ok(Response::new(existing_stream.clone()))); + mock_client + .expect_start_stream::() + .with(eq(StartStreamRequest { + account_id: config.account_id.to_string(), + function_name: config.function_name.clone(), + redis_stream: config.get_redis_stream_key(), + rule: Some(Rule::ActionAnyRule(ActionAnyRule { + affected_account_id: "queryapi.dataplatform.near".to_string(), + status: Status::Any.into(), + })), + start_block_height: last_published_block + 1, + version: config.get_registry_version(), + })) + .returning(|_| Ok(Response::new(StartStreamResponse::default()))); + + let mut mock_redis = RedisClient::default(); + mock_redis + .expect_get_last_published_block::() + .returning(move |_| Ok(Some(last_published_block))); + + let handler = BlockStreamsHandler { + client: mock_client, + redis_client: mock_redis, + }; + + handler + .synchronise(&config, Some(config.get_registry_version() - 1)) + .await + .unwrap(); + } + + #[tokio::test] + async fn ignores_healthy_streams() { + tokio::time::pause(); + + let config = IndexerConfig::default(); + + let healthy_states = vec![ + ProcessingState::Running, + ProcessingState::Idle, + ProcessingState::Waiting, + ]; + + for healthy_state in healthy_states { + let existing_stream = StreamInfo { + account_id: config.account_id.to_string(), + function_name: config.function_name.clone(), + stream_id: "stream-id".to_string(), + version: config.get_registry_version(), + health: Some(block_streamer::Health { + updated_at_timestamp_secs: SystemTime::now() + .duration_since(SystemTime::UNIX_EPOCH) + .unwrap() + .as_secs(), + processing_state: healthy_state.into(), + }), + }; + + let mut mock_client = BlockStreamsClientWrapper::default(); + mock_client + .expect_get_stream::() + .with(eq(GetStreamRequest { + account_id: config.account_id.to_string(), + function_name: config.function_name.clone(), + })) + .returning(move |_| Ok(Response::new(existing_stream.clone()))); + mock_client + .expect_stop_stream::() + .never(); + mock_client + .expect_start_stream::() + .never(); + + let mock_redis = RedisClient::default(); + + let handler = BlockStreamsHandler { + client: mock_client, + redis_client: mock_redis, + }; + + handler + .synchronise(&config, Some(config.get_registry_version())) + .await + .unwrap(); + } + } + + #[tokio::test] + async fn clears_redis_stream() { + let config_with_height = IndexerConfig::default(); + let config_with_latest = IndexerConfig { + start_block: StartBlock::Latest, + ..Default::default() + }; + let config_with_continue = IndexerConfig { + start_block: StartBlock::Continue, + ..Default::default() + }; + + let mut mock_client = BlockStreamsClientWrapper::default(); + mock_client + .expect_start_stream::() + .with(always()) + .returning(|_| Ok(Response::new(StartStreamResponse::default()))) + .times(3); + + let mut mock_redis = RedisClient::default(); + mock_redis + .expect_clear_block_stream::() + .with(eq(config_with_height.clone())) + .returning(|_| Ok(())) + .once(); + mock_redis + .expect_clear_block_stream::() + .with(eq(config_with_latest.clone())) + .returning(|_| Ok(())) + .once(); + mock_redis + .expect_clear_block_stream::() + .with(eq(config_with_continue.clone())) + .never(); + mock_redis + .expect_get_last_published_block::() + .returning(|_| Ok(None)) + .once(); + + let handler = BlockStreamsHandler { + client: mock_client, + redis_client: mock_redis, + }; + + handler.reconfigure(&config_with_latest).await.unwrap(); + handler.reconfigure(&config_with_continue).await.unwrap(); + handler.reconfigure(&config_with_height).await.unwrap(); + } +} diff --git a/coordinator/src/handlers/data_layer.rs b/coordinator/src/handlers/data_layer.rs index ff2e2657..e0ea273c 100644 --- a/coordinator/src/handlers/data_layer.rs +++ b/coordinator/src/handlers/data_layer.rs @@ -6,9 +6,12 @@ pub use runner::data_layer::TaskStatus; use anyhow::Context; use runner::data_layer::data_layer_client::DataLayerClient; -use runner::data_layer::{DeprovisionRequest, GetTaskStatusRequest, ProvisionRequest}; +use runner::data_layer::{ + DeprovisionRequest, GetTaskStatusRequest, GetTaskStatusResponse, ProvisionRequest, + StartTaskResponse, +}; use tonic::transport::channel::Channel; -use tonic::{Request, Status}; +use tonic::Status; use crate::indexer_config::IndexerConfig; @@ -16,9 +19,63 @@ type TaskId = String; const TASK_TIMEOUT_SECONDS: u64 = 600; // 10 minutes +#[cfg(not(test))] +use DataLayerClientWrapperImpl as DataLayerClientWrapper; +#[cfg(test)] +use MockDataLayerClientWrapperImpl as DataLayerClientWrapper; + +#[derive(Clone)] +struct DataLayerClientWrapperImpl { + inner: DataLayerClient, +} + +#[cfg(test)] +impl Clone for MockDataLayerClientWrapperImpl { + fn clone(&self) -> Self { + Self::default() + } +} + +#[cfg_attr(test, mockall::automock)] +impl DataLayerClientWrapperImpl { + pub fn new(inner: DataLayerClient) -> Self { + Self { inner } + } + + pub async fn start_provisioning_task( + &self, + request: R, + ) -> std::result::Result, tonic::Status> + where + R: tonic::IntoRequest + 'static, + { + self.inner.clone().start_provisioning_task(request).await + } + + pub async fn start_deprovisioning_task( + &self, + request: R, + ) -> std::result::Result, tonic::Status> + where + R: tonic::IntoRequest + 'static, + { + self.inner.clone().start_deprovisioning_task(request).await + } + + pub async fn get_task_status( + &self, + request: R, + ) -> std::result::Result, tonic::Status> + where + R: tonic::IntoRequest + 'static, + { + self.inner.clone().get_task_status(request).await + } +} + #[derive(Clone)] pub struct DataLayerHandler { - client: DataLayerClient, + client: DataLayerClientWrapper, } impl DataLayerHandler { @@ -28,7 +85,9 @@ impl DataLayerHandler { .connect_lazy(); let client = DataLayerClient::new(channel); - Ok(Self { client }) + Ok(Self { + client: DataLayerClientWrapper::new(client), + }) } pub async fn start_provisioning_task( @@ -41,11 +100,7 @@ impl DataLayerHandler { schema: indexer_config.schema.clone(), }; - let response = self - .client - .clone() - .start_provisioning_task(Request::new(request)) - .await?; + let response = self.client.start_provisioning_task(request).await?; Ok(response.into_inner().task_id) } @@ -60,11 +115,7 @@ impl DataLayerHandler { function_name, }; - let response = self - .client - .clone() - .start_deprovisioning_task(Request::new(request)) - .await?; + let response = self.client.start_deprovisioning_task(request).await?; Ok(response.into_inner().task_id) } @@ -72,11 +123,7 @@ impl DataLayerHandler { pub async fn get_task_status(&self, task_id: TaskId) -> anyhow::Result { let request = GetTaskStatusRequest { task_id }; - let response = self - .client - .clone() - .get_task_status(Request::new(request)) - .await; + let response = self.client.get_task_status(request).await; if let Err(error) = response { if error.code() == tonic::Code::NotFound { @@ -193,3 +240,284 @@ impl DataLayerHandler { Ok(()) } } + +#[cfg(test)] +mod tests { + use crate::redis::KeyProvider; + + use super::*; + + use mockall::predicate::*; + + #[tokio::test] + async fn provisions_data_layer() { + let config = IndexerConfig::default(); + + let mut mock_client = DataLayerClientWrapper::default(); + mock_client + .expect_start_provisioning_task::() + .with(eq(ProvisionRequest { + account_id: config.account_id.to_string(), + function_name: config.function_name.clone(), + schema: config.schema.clone(), + })) + .returning(|_| { + Ok(tonic::Response::new(StartTaskResponse { + task_id: "task_id".to_string(), + })) + }) + .once(); + mock_client + .expect_get_task_status::() + .with(eq(GetTaskStatusRequest { + task_id: "task_id".to_string(), + })) + .returning(|_| { + Ok(tonic::Response::new(GetTaskStatusResponse { + status: TaskStatus::Pending.into(), + })) + }) + .once(); + mock_client + .expect_get_task_status::() + .with(eq(GetTaskStatusRequest { + task_id: "task_id".to_string(), + })) + .returning(|_| { + Ok(tonic::Response::new(GetTaskStatusResponse { + status: TaskStatus::Complete.into(), + })) + }) + .once(); + + let handler = DataLayerHandler { + client: mock_client, + }; + + handler.ensure_provisioned(&config).await.unwrap(); + } + + #[tokio::test] + async fn timesout_provisioning_task() { + tokio::time::pause(); + + let config = IndexerConfig::default(); + + let mut mock_client = DataLayerClientWrapper::default(); + mock_client + .expect_start_provisioning_task::() + .with(eq(ProvisionRequest { + account_id: config.account_id.to_string(), + function_name: config.function_name.clone(), + schema: config.schema.clone(), + })) + .returning(|_| { + Ok(tonic::Response::new(StartTaskResponse { + task_id: "task_id".to_string(), + })) + }) + .once(); + mock_client + .expect_get_task_status::() + .with(eq(GetTaskStatusRequest { + task_id: "task_id".to_string(), + })) + .returning(|_| { + Ok(tonic::Response::new(GetTaskStatusResponse { + status: TaskStatus::Pending.into(), + })) + }) + .times(610); + + let handler = DataLayerHandler { + client: mock_client, + }; + + let result = handler.ensure_provisioned(&config).await; + + assert_eq!( + result.err().unwrap().to_string(), + "Provisioning task timed out" + ); + } + + #[tokio::test] + async fn propagates_provisioning_failures() { + let config = IndexerConfig::default(); + + let mut mock_client = DataLayerClientWrapper::default(); + mock_client + .expect_start_provisioning_task::() + .with(eq(ProvisionRequest { + account_id: config.account_id.to_string(), + function_name: config.function_name.clone(), + schema: config.schema.clone(), + })) + .returning(|_| { + Ok(tonic::Response::new(StartTaskResponse { + task_id: "task_id".to_string(), + })) + }) + .once(); + mock_client + .expect_get_task_status::() + .with(eq(GetTaskStatusRequest { + task_id: "task_id".to_string(), + })) + .returning(|_| { + Ok(tonic::Response::new(GetTaskStatusResponse { + status: TaskStatus::Failed.into(), + })) + }) + .once(); + + let handler = DataLayerHandler { + client: mock_client, + }; + + let result = handler.ensure_provisioned(&config).await; + + assert_eq!( + result.err().unwrap().to_string(), + "Provisioning task failed" + ); + } + + #[tokio::test] + async fn deprovisions_data_layer() { + let config = IndexerConfig::default(); + + let mut mock_client = DataLayerClientWrapper::default(); + mock_client + .expect_start_deprovisioning_task::() + .with(eq(DeprovisionRequest { + account_id: config.account_id.to_string(), + function_name: config.function_name.clone(), + })) + .returning(|_| { + Ok(tonic::Response::new(StartTaskResponse { + task_id: "task_id".to_string(), + })) + }) + .once(); + mock_client + .expect_get_task_status::() + .with(eq(GetTaskStatusRequest { + task_id: "task_id".to_string(), + })) + .returning(|_| { + Ok(tonic::Response::new(GetTaskStatusResponse { + status: TaskStatus::Pending.into(), + })) + }) + .once(); + mock_client + .expect_get_task_status::() + .with(eq(GetTaskStatusRequest { + task_id: "task_id".to_string(), + })) + .returning(|_| { + Ok(tonic::Response::new(GetTaskStatusResponse { + status: TaskStatus::Complete.into(), + })) + }) + .once(); + + let handler = DataLayerHandler { + client: mock_client, + }; + + handler + .ensure_deprovisioned(config.account_id, config.function_name) + .await + .unwrap(); + } + + #[tokio::test] + async fn timesout_deprovisioning_task() { + tokio::time::pause(); + + let config = IndexerConfig::default(); + + let mut mock_client = DataLayerClientWrapper::default(); + mock_client + .expect_start_deprovisioning_task::() + .with(eq(DeprovisionRequest { + account_id: config.account_id.to_string(), + function_name: config.function_name.clone(), + })) + .returning(|_| { + Ok(tonic::Response::new(StartTaskResponse { + task_id: "task_id".to_string(), + })) + }) + .once(); + mock_client + .expect_get_task_status::() + .with(eq(GetTaskStatusRequest { + task_id: "task_id".to_string(), + })) + .returning(|_| { + Ok(tonic::Response::new(GetTaskStatusResponse { + status: TaskStatus::Pending.into(), + })) + }) + .times(610); + + let handler = DataLayerHandler { + client: mock_client, + }; + + let result = handler + .ensure_deprovisioned(config.account_id, config.function_name) + .await; + + assert_eq!( + result.err().unwrap().to_string(), + "Deprovisioning task timed out" + ); + } + + #[tokio::test] + async fn propagates_deprovisioning_failures() { + let config = IndexerConfig::default(); + + let mut mock_client = DataLayerClientWrapper::default(); + mock_client + .expect_start_deprovisioning_task::() + .with(eq(DeprovisionRequest { + account_id: config.account_id.to_string(), + function_name: config.function_name.clone(), + })) + .returning(|_| { + Ok(tonic::Response::new(StartTaskResponse { + task_id: "task_id".to_string(), + })) + }) + .once(); + mock_client + .expect_get_task_status::() + .with(eq(GetTaskStatusRequest { + task_id: "task_id".to_string(), + })) + .returning(|_| { + Ok(tonic::Response::new(GetTaskStatusResponse { + status: TaskStatus::Failed.into(), + })) + }) + .once(); + + let handler = DataLayerHandler { + client: mock_client, + }; + + let result = handler + .ensure_deprovisioned(config.account_id, config.function_name) + .await; + + assert_eq!( + result.err().unwrap().to_string(), + "Deprovisioning task failed" + ); + } +} diff --git a/coordinator/src/handlers/executors.rs b/coordinator/src/handlers/executors.rs index 184718d0..4fc67a67 100644 --- a/coordinator/src/handlers/executors.rs +++ b/coordinator/src/handlers/executors.rs @@ -6,21 +6,66 @@ pub use runner::ExecutorInfo; use anyhow::Context; use runner::runner_client::RunnerClient; use runner::{ - ExecutionState, GetExecutorRequest, ListExecutorsRequest, StartExecutorRequest, - StopExecutorRequest, + ExecutionState, GetExecutorRequest, StartExecutorRequest, StartExecutorResponse, + StopExecutorRequest, StopExecutorResponse, }; use tonic::transport::channel::Channel; -use tonic::Request; use crate::indexer_config::IndexerConfig; use crate::redis::KeyProvider; -use crate::utils::exponential_retry; const RESTART_TIMEOUT_SECONDS: u64 = 600; +#[cfg(not(test))] +use ExecutorsClientWrapperImpl as ExecutorsClientWrapper; +#[cfg(test)] +use MockExecutorsClientWrapperImpl as ExecutorsClientWrapper; + +#[derive(Clone)] +struct ExecutorsClientWrapperImpl { + inner: RunnerClient, +} + +#[cfg_attr(test, mockall::automock)] +impl ExecutorsClientWrapperImpl { + pub fn new(inner: RunnerClient) -> Self { + Self { inner } + } + + pub async fn get_executor( + &self, + request: R, + ) -> std::result::Result, tonic::Status> + where + R: tonic::IntoRequest + 'static, + { + self.inner.clone().get_executor(request).await + } + + pub async fn start_executor( + &self, + request: R, + ) -> std::result::Result, tonic::Status> + where + R: tonic::IntoRequest + 'static, + { + self.inner.clone().start_executor(request).await + } + + pub async fn stop_executor( + &self, + request: R, + ) -> std::result::Result, tonic::Status> + where + R: tonic::IntoRequest + 'static, + { + self.inner.clone().stop_executor(request).await + } +} + #[derive(Clone)] pub struct ExecutorsHandler { - client: RunnerClient, + client: ExecutorsClientWrapper, } impl ExecutorsHandler { @@ -30,25 +75,9 @@ impl ExecutorsHandler { .connect_lazy(); let client = RunnerClient::new(channel); - Ok(Self { client }) - } - - pub async fn list(&self) -> anyhow::Result> { - exponential_retry(|| async { - let response = self - .client - .clone() - .list_executors(Request::new(ListExecutorsRequest {})) - .await - .context("Failed to list executors")?; - - let executors = response.into_inner().executors; - - tracing::debug!("List executors response: {:#?}", executors); - - Ok(executors) + Ok(Self { + client: ExecutorsClientWrapper::new(client), }) - .await } pub async fn get( @@ -61,12 +90,7 @@ impl ExecutorsHandler { function_name: function_name.clone(), }; - match self - .client - .clone() - .get_executor(Request::new(request)) - .await - { + match self.client.get_executor(request).await { Ok(response) => Ok(Some(response.into_inner())), Err(status) if status.code() == tonic::Code::NotFound => Ok(None), Err(err) => Err(err).context(format!( @@ -86,15 +110,10 @@ impl ExecutorsHandler { function_name: indexer_config.function_name.clone(), }; - let response = self - .client - .clone() - .start_executor(Request::new(request.clone())) - .await - .context(format!( - "Failed to start executor: {}", - indexer_config.get_full_name() - ))?; + let response = self.client.start_executor(request).await.context(format!( + "Failed to start executor: {}", + indexer_config.get_full_name() + ))?; tracing::debug!( account_id = indexer_config.account_id.as_str(), @@ -114,8 +133,7 @@ impl ExecutorsHandler { let response = self .client - .clone() - .stop_executor(Request::new(request.clone())) + .stop_executor(request) .await .context(format!("Failed to stop executor: {executor_id}"))?; @@ -147,7 +165,7 @@ impl ExecutorsHandler { Ok(()) } - pub async fn synchronise_executor(&self, config: &IndexerConfig) -> anyhow::Result<()> { + pub async fn synchronise(&self, config: &IndexerConfig) -> anyhow::Result<()> { let executor = self .get(config.account_id.clone(), config.function_name.clone()) .await?; @@ -193,3 +211,208 @@ impl ExecutorsHandler { Ok(()) } } + +#[cfg(test)] +mod tests { + use super::*; + + use mockall::predicate::*; + use tonic::Response; + + impl Clone for MockExecutorsClientWrapperImpl { + fn clone(&self) -> Self { + Self::default() + } + } + + #[tokio::test] + async fn resumes_stopped_executors() { + let config = IndexerConfig::default(); + + let mut mock_client = ExecutorsClientWrapper::default(); + mock_client + .expect_get_executor::() + .with(eq(GetExecutorRequest { + account_id: config.account_id.to_string(), + function_name: config.function_name.clone(), + })) + .returning(|_| Err(tonic::Status::not_found("not found"))) + .once(); + mock_client + .expect_start_executor::() + .with(eq(StartExecutorRequest { + code: config.code.clone(), + schema: config.schema.clone(), + redis_stream: config.get_redis_stream_key(), + version: config.get_registry_version(), + account_id: config.account_id.to_string(), + function_name: config.function_name.clone(), + })) + .returning(|_| { + Ok(tonic::Response::new(StartExecutorResponse { + executor_id: "executor_id".to_string(), + })) + }) + .once(); + + let handler = ExecutorsHandler { + client: mock_client, + }; + + handler.synchronise(&config).await.unwrap() + } + + #[tokio::test] + async fn reconfigures_outdated_executors() { + let config = IndexerConfig::default(); + + let executor = ExecutorInfo { + account_id: config.account_id.to_string(), + function_name: config.function_name.clone(), + executor_id: "executor_id".to_string(), + version: config.get_registry_version() - 1, + health: None, + }; + + let mut mock_client = ExecutorsClientWrapper::default(); + mock_client + .expect_stop_executor::() + .with(eq(StopExecutorRequest { + executor_id: executor.executor_id.clone(), + })) + .returning(|_| { + Ok(Response::new(StopExecutorResponse { + executor_id: "executor_id".to_string(), + })) + }) + .once(); + mock_client + .expect_start_executor::() + .with(eq(StartExecutorRequest { + code: config.code.clone(), + schema: config.schema.clone(), + redis_stream: config.get_redis_stream_key(), + version: config.get_registry_version(), + account_id: config.account_id.to_string(), + function_name: config.function_name.clone(), + })) + .returning(|_| { + Ok(tonic::Response::new(StartExecutorResponse { + executor_id: "executor_id".to_string(), + })) + }) + .once(); + mock_client + .expect_get_executor::() + .with(always()) + .returning(move |_| Ok(Response::new(executor.clone()))) + .once(); + + let handler = ExecutorsHandler { + client: mock_client, + }; + + handler.synchronise(&config).await.unwrap() + } + + #[tokio::test] + async fn restarts_unhealthy_executors() { + tokio::time::pause(); + + let config = IndexerConfig::default(); + + let executor = ExecutorInfo { + account_id: config.account_id.to_string(), + function_name: config.function_name.clone(), + executor_id: "executor_id".to_string(), + version: config.get_registry_version(), + health: Some(runner::Health { + execution_state: runner::ExecutionState::Stalled.into(), + }), + }; + + let mut mock_client = ExecutorsClientWrapper::default(); + mock_client + .expect_stop_executor::() + .with(eq(StopExecutorRequest { + executor_id: executor.executor_id.clone(), + })) + .returning(|_| { + Ok(Response::new(StopExecutorResponse { + executor_id: "executor_id".to_string(), + })) + }) + .once(); + mock_client + .expect_start_executor::() + .with(eq(StartExecutorRequest { + code: config.code.clone(), + schema: config.schema.clone(), + redis_stream: config.get_redis_stream_key(), + version: config.get_registry_version(), + account_id: config.account_id.to_string(), + function_name: config.function_name.clone(), + })) + .returning(|_| { + Ok(tonic::Response::new(StartExecutorResponse { + executor_id: "executor_id".to_string(), + })) + }) + .once(); + mock_client + .expect_get_executor::() + .with(always()) + .returning(move |_| Ok(Response::new(executor.clone()))) + .once(); + + let handler = ExecutorsHandler { + client: mock_client, + }; + + handler.synchronise(&config).await.unwrap() + } + + #[tokio::test] + async fn ignores_healthy_executors() { + tokio::time::pause(); + + let config = IndexerConfig::default(); + + let healthy_states = vec![ + runner::ExecutionState::Running, + runner::ExecutionState::Failing, + runner::ExecutionState::Waiting, + runner::ExecutionState::Stopped, + ]; + + for healthy_state in healthy_states { + let executor = ExecutorInfo { + account_id: config.account_id.to_string(), + function_name: config.function_name.clone(), + executor_id: "executor_id".to_string(), + version: config.get_registry_version(), + health: Some(runner::Health { + execution_state: healthy_state.into(), + }), + }; + + let mut mock_client = ExecutorsClientWrapper::default(); + mock_client + .expect_stop_executor::() + .never(); + mock_client + .expect_start_executor::() + .never(); + mock_client + .expect_get_executor::() + .with(always()) + .returning(move |_| Ok(Response::new(executor.clone()))); + + let handler = ExecutorsHandler { + client: mock_client, + }; + + handler.synchronise(&config).await.unwrap() + } + } +} diff --git a/coordinator/src/lifecycle.rs b/coordinator/src/lifecycle.rs index 80ebeddc..32fb2f9f 100644 --- a/coordinator/src/lifecycle.rs +++ b/coordinator/src/lifecycle.rs @@ -125,7 +125,7 @@ impl<'a> LifecycleManager<'a> { if let Err(error) = self .block_streams_handler - .synchronise_block_stream(config, state.block_stream_synced_at) + .synchronise(config, state.block_stream_synced_at) .await { warn!(?error, "Failed to synchronise block stream, retrying..."); @@ -135,7 +135,7 @@ impl<'a> LifecycleManager<'a> { state.block_stream_synced_at = Some(config.get_registry_version()); - if let Err(error) = self.executors_handler.synchronise_executor(config).await { + if let Err(error) = self.executors_handler.synchronise(config).await { warn!(?error, "Failed to synchronise executor, retrying..."); return LifecycleState::Running;