diff --git a/api/src/infrastructure/repositories/alert_config_repo.rs b/api/src/infrastructure/repositories/alert_config_repo.rs index 2f103648..c52c1220 100644 --- a/api/src/infrastructure/repositories/alert_config_repo.rs +++ b/api/src/infrastructure/repositories/alert_config_repo.rs @@ -15,6 +15,7 @@ use crate::infrastructure::models::alert_config::NewSlackAlertConfigData; use crate::infrastructure::models::alert_config::{ AlertConfigData, MonitorAlertConfigData, NewAlertConfigData, }; +use crate::infrastructure::repositories::alert_configs::GetByMonitors; use crate::infrastructure::repositories::Repository; macro_rules! build_polymorphic_query { @@ -35,6 +36,7 @@ macro_rules! build_polymorphic_query { slack_alert_config::dsl::slack_channel.nullable(), slack_alert_config::dsl::slack_bot_oauth_token.nullable(), )) + .distinct() .into_boxed() }}; } @@ -124,6 +126,68 @@ impl<'a> AlertConfigRepository<'a> { Ok(()) } + + async fn fetch_alert_configs( + &mut self, + tenant: &str, + monitor_ids: Option<&[Uuid]>, + ) -> Result, Error> { + let mut connection = get_connection(self.pool).await?; + let (alert_config_datas, monitor_alert_config_datas) = connection + .transaction::<(Vec, Vec), DieselError, _>( + |conn| { + Box::pin(async move { + let query = + build_polymorphic_query!().filter(alert_config::tenant.eq(tenant)); + let alert_configs: Vec = + if let Some(monitor_ids) = monitor_ids { + query + .inner_join( + monitor_alert_config::table + .on(monitor_alert_config::alert_config_id + .eq(alert_config::alert_config_id)), + ) + .filter(monitor_alert_config::monitor_id.eq_any(monitor_ids)) + .load(conn) + .await? + } else { + query.load(conn).await? + }; + + let monitor_alert_configs = + MonitorAlertConfigData::belonging_to(&alert_configs) + .select(MonitorAlertConfigData::as_select()) + .load(conn) + .await?; + + Ok((alert_configs, monitor_alert_configs)) + }) + }, + ) + .await + .map_err(|err| Error::RepositoryError(err.to_string()))?; + + monitor_alert_config_datas + .grouped_by(&alert_config_datas) + .into_iter() + .zip(alert_config_datas) + .map(|(monitor_alert_config_datas, alert_config_datas)| { + self.db_to_model(&alert_config_datas, &monitor_alert_config_datas) + }) + .collect::, Error>>() + } +} + +#[async_trait] +#[allow(clippy::needless_lifetimes)] // This is needed for the lifetime of the pool +impl<'a> GetByMonitors for AlertConfigRepository<'a> { + async fn get_by_monitors( + &mut self, + monitor_ids: &[Uuid], + tenant: &str, + ) -> Result, Error> { + self.fetch_alert_configs(tenant, Some(monitor_ids)).await + } } #[async_trait] @@ -174,37 +238,7 @@ impl<'a> Repository for AlertConfigRepository<'a> { } async fn all(&mut self, tenant: &str) -> Result, Error> { - let mut connection = get_connection(self.pool).await?; - let (alert_config_datas, monitor_alert_config_datas) = connection - .transaction::<(Vec, Vec), DieselError, _>( - |conn| { - Box::pin(async move { - let alert_configs: Vec = build_polymorphic_query!() - .filter(alert_config::tenant.eq(tenant)) - .load(conn) - .await?; - - let monitor_alert_configs = - MonitorAlertConfigData::belonging_to(&alert_configs) - .select(MonitorAlertConfigData::as_select()) - .load(conn) - .await?; - - Ok((alert_configs, monitor_alert_configs)) - }) - }, - ) - .await - .map_err(|err| Error::RepositoryError(err.to_string()))?; - - Ok(monitor_alert_config_datas - .grouped_by(&alert_config_datas) - .into_iter() - .zip(alert_config_datas) - .map(|(monitor_alert_config_datas, alert_config_datas)| { - self.db_to_model(&alert_config_datas, &monitor_alert_config_datas) - }) - .collect::, Error>>()?) + self.fetch_alert_configs(tenant, None).await } async fn save(&mut self, alert_config: &AlertConfig) -> Result<(), Error> { diff --git a/api/src/infrastructure/repositories/alert_configs.rs b/api/src/infrastructure/repositories/alert_configs.rs new file mode 100644 index 00000000..78d24dac --- /dev/null +++ b/api/src/infrastructure/repositories/alert_configs.rs @@ -0,0 +1,18 @@ +use async_trait::async_trait; +use uuid::Uuid; + +#[cfg(test)] +use mockall::automock; + +use crate::domain::models::alert_config::AlertConfig; +use crate::errors::Error; + +#[cfg_attr(test, automock)] +#[async_trait] +pub trait GetByMonitors { + async fn get_by_monitors( + &mut self, + monitor_ids: &[Uuid], + tenant: &str, + ) -> Result, Error>; +} diff --git a/api/src/infrastructure/repositories/mod.rs b/api/src/infrastructure/repositories/mod.rs index 3066089a..6f526dd9 100644 --- a/api/src/infrastructure/repositories/mod.rs +++ b/api/src/infrastructure/repositories/mod.rs @@ -1,4 +1,5 @@ pub mod alert_config_repo; +pub mod alert_configs; pub mod api_key_repo; pub mod api_keys; pub mod monitor; diff --git a/api/tests/alert_config_repo_test.rs b/api/tests/alert_config_repo_test.rs index 20b9c8f2..58fd9a05 100644 --- a/api/tests/alert_config_repo_test.rs +++ b/api/tests/alert_config_repo_test.rs @@ -9,9 +9,50 @@ use cron_mon_api::domain::models::alert_config::{AlertConfig, AlertType, SlackAl use cron_mon_api::errors::Error; use cron_mon_api::infrastructure::models::alert_config::NewAlertConfigData; use cron_mon_api::infrastructure::repositories::alert_config_repo::AlertConfigRepository; +use cron_mon_api::infrastructure::repositories::alert_configs::GetByMonitors; use cron_mon_api::infrastructure::repositories::Repository; use common::{infrastructure, Infrastructure}; +use uuid::Uuid; + +#[rstest] +#[case( + vec![ + gen_uuid("c1bf0515-df39-448b-aa95-686360a33b36"), + gen_uuid("f0b291fe-bd41-4787-bc2d-1329903f7a6a") + ], + vec![ + "Test Slack alert (for errors)".to_owned(), + "Test Slack alert (for lates and errors)".to_owned(), + "Test Slack alert (for lates)".to_owned(), + ] +)] +#[case( + vec![ + gen_uuid("f0b291fe-bd41-4787-bc2d-1329903f7a6a") + ], + vec![ + "Test Slack alert (for errors)".to_owned() + ] +)] +#[tokio::test] +async fn test_get_by_monitors( + #[case] monitor_ids: Vec, + #[case] alert_config_names: Vec, + #[future] infrastructure: Infrastructure, +) { + let infra = infrastructure.await; + let mut repo = AlertConfigRepository::new(&infra.pool); + + let alert_configs = repo.get_by_monitors(&monitor_ids, "foo").await.unwrap(); + + let names: Vec = alert_configs + .iter() + .map(|alert_config| alert_config.name.clone()) + .collect(); + + assert_eq!(names, alert_config_names); +} #[rstest] #[tokio::test] @@ -28,9 +69,9 @@ async fn test_all(#[future] infrastructure: Infrastructure) { assert_eq!( names, vec![ - "Test Slack alert (for lates)".to_owned(), "Test Slack alert (for errors)".to_owned(), - "Test Slack alert (for lates and errors)".to_owned() + "Test Slack alert (for lates and errors)".to_owned(), + "Test Slack alert (for lates)".to_owned(), ] );