diff --git a/tembo-operator/Cargo.lock b/tembo-operator/Cargo.lock index 10d4943b5..320a4cf99 100644 --- a/tembo-operator/Cargo.lock +++ b/tembo-operator/Cargo.lock @@ -494,7 +494,7 @@ dependencies = [ [[package]] name = "controller" -version = "0.29.5" +version = "0.30.0" dependencies = [ "actix-web", "anyhow", diff --git a/tembo-operator/Cargo.toml b/tembo-operator/Cargo.toml index a43feb081..376d1d502 100644 --- a/tembo-operator/Cargo.toml +++ b/tembo-operator/Cargo.toml @@ -1,7 +1,7 @@ [package] name = "controller" description = "Tembo Operator for Postgres" -version = "0.29.5" +version = "0.30.0" edition = "2021" default-run = "controller" license = "Apache-2.0" diff --git a/tembo-operator/src/trunk.rs b/tembo-operator/src/trunk.rs index 3608c7f5f..de0434fa5 100644 --- a/tembo-operator/src/trunk.rs +++ b/tembo-operator/src/trunk.rs @@ -1,10 +1,14 @@ use k8s_openapi::api::core::v1::ConfigMap; use kube::{runtime::controller::Action, Api, Client}; use lazy_static::lazy_static; +use schemars::JsonSchema; +use serde::de::Error; +use serde::{Deserialize, Serialize}; use std::{collections::BTreeMap, env, time::Duration}; use crate::configmap::apply_configmap; use tracing::log::error; +use utoipa::ToSchema; const DEFAULT_TRUNK_REGISTRY_DOMAIN: &str = "registry.pgtrunk.io"; @@ -17,6 +21,51 @@ pub struct ExtensionRequiresLoad { pub library_name: String, } +// TODO(ianstanton) We can publish this as a crate library and use it in other projects, such as Trunk CLI and Registry +#[derive(Clone, Debug, Deserialize, PartialEq, Serialize, ToSchema, JsonSchema)] +pub struct TrunkProjectMetadata { + pub name: String, + pub description: Option, + pub documentation_link: Option, + pub repository_link: Option, + pub version: String, + pub postgres_versions: Option>, + pub extensions: Vec, + pub downloads: Option>, +} + +#[derive(Clone, Debug, Deserialize, PartialEq, Serialize, ToSchema, JsonSchema)] +pub struct TrunkExtensionMetadata { + pub extension_name: String, + pub version: String, + pub trunk_project_name: String, + pub dependencies_extension_names: Option>, + pub loadable_libraries: Option>, + pub configurations: Option>, + pub control_file: TrunkControlFileMetadata, +} + +#[derive(Clone, Debug, Deserialize, PartialEq, Serialize, ToSchema, JsonSchema)] +pub struct TrunkDownloadMetadata { + pub link: String, + pub pg_version: i32, + pub platform: String, + pub sha256: String, +} + +#[derive(Clone, Debug, Deserialize, PartialEq, Serialize, ToSchema, JsonSchema)] +pub struct TrunkLoadableLibrariesMetadata { + pub library_name: String, + pub requires_restart: bool, + pub priority: i32, +} + +#[derive(Clone, Debug, Deserialize, PartialEq, Serialize, ToSchema, JsonSchema)] +pub struct TrunkControlFileMetadata { + pub absent: bool, + pub content: Option, +} + // This is a place to configure specific exceptions before // Trunk handles everything. // In terms of extensions that require load, we need to know @@ -117,6 +166,8 @@ pub async fn reconcile_trunk_configmap(client: Client, namespace: &str) -> Resul } } +// TODO(ianstanton) This information is now available in the trunk project metadata. We should fetch it from there +// instead async fn requires_load_list_from_trunk() -> Result, TrunkError> { let domain = env::var("TRUNK_REGISTRY_DOMAIN") .unwrap_or_else(|_| DEFAULT_TRUNK_REGISTRY_DOMAIN.to_string()); @@ -137,13 +188,329 @@ async fn requires_load_list_from_trunk() -> Result, TrunkError> { } } +// Get all trunk projects +pub async fn get_trunk_projects() -> Result, TrunkError> { + let domain = env::var("TRUNK_REGISTRY_DOMAIN") + .unwrap_or_else(|_| DEFAULT_TRUNK_REGISTRY_DOMAIN.to_string()); + let url = format!("https://{}/api/v1/trunk-projects", domain); + + let response = reqwest::get(&url).await?; + + if response.status().is_success() { + let response_body = response.text().await?; + let project_metadata: Vec = serde_json::from_str(&response_body)?; + Ok(project_metadata.clone()) + } else { + error!("Failed to fetch all trunk projects: {}", response.status()); + Err(TrunkError::NetworkFailure( + response.error_for_status().unwrap_err(), + )) + } +} + +// Get all trunk project names +pub async fn get_trunk_project_names() -> Result, TrunkError> { + let domain = env::var("TRUNK_REGISTRY_DOMAIN") + .unwrap_or_else(|_| DEFAULT_TRUNK_REGISTRY_DOMAIN.to_string()); + let url = format!("https://{}/api/v1/trunk-projects", domain); + + let response = reqwest::get(&url).await?; + + if response.status().is_success() { + let response_body = response.text().await?; + let project_metadata: Vec = serde_json::from_str(&response_body)?; + let project_names: Vec = project_metadata + .iter() + .map(|project_metadata| project_metadata.name.clone()) + .collect(); + Ok(project_names) + } else { + error!("Failed to fetch all trunk projects: {}", response.status()); + Err(TrunkError::NetworkFailure( + response.error_for_status().unwrap_err(), + )) + } +} + +// Get all metadata entries for a given trunk project +async fn get_trunk_project_metadata( + trunk_project: String, +) -> Result, TrunkError> { + let domain = env::var("TRUNK_REGISTRY_DOMAIN") + .unwrap_or_else(|_| DEFAULT_TRUNK_REGISTRY_DOMAIN.to_string()); + let url = format!("https://{}/api/v1/trunk-projects/{}", domain, trunk_project); + + let response = reqwest::get(&url).await?; + + if response.status().is_success() { + let response_body = response.text().await?; + let project_metadata: Vec = serde_json::from_str(&response_body)?; + Ok(project_metadata) + } else { + error!( + "Failed to fetch metadata for trunk project {}: {}", + trunk_project, + response.status() + ); + Err(TrunkError::NetworkFailure( + response.error_for_status().unwrap_err(), + )) + } +} + +// Get trunk project metadata for a specific version +async fn get_trunk_project_metadata_for_version( + trunk_project: String, + version: String, +) -> Result { + let domain = env::var("TRUNK_REGISTRY_DOMAIN") + .unwrap_or_else(|_| DEFAULT_TRUNK_REGISTRY_DOMAIN.to_string()); + let url = format!( + "https://{}/api/v1/trunk-projects/{}/version/{}", + domain, trunk_project, version + ); + + let response = reqwest::get(&url).await?; + + if response.status().is_success() { + let response_body = response.text().await?; + let project_metadata: Vec = serde_json::from_str(&response_body)?; + // There will only be one index here, so we can safely assume index 0 + let project_metadata = match project_metadata.get(0) { + Some(project_metadata) => project_metadata, + None => { + error!( + "Failed to fetch metadata for trunk project {} version {}", + trunk_project, version + ); + return Err(TrunkError::ParsingIssue(serde_json::Error::custom( + "No metadata found", + ))); + } + }; + + Ok(project_metadata.clone()) + } else { + error!( + "Failed to fetch metadata for trunk project {} version {}: {}", + trunk_project, + version, + response.status() + ); + Err(TrunkError::NetworkFailure( + response.error_for_status().unwrap_err(), + )) + } +} + +// Check if extension name is in list of trunk project names +pub async fn extension_name_matches_trunk_project( + extension_name: String, +) -> Result { + let trunk_project_names = match get_trunk_project_names().await { + Ok(trunk_project_names) => trunk_project_names, + Err(e) => { + error!("Failed to get trunk project names: {:?}", e); + return Err(TrunkError::ConfigMapApplyError); + } + }; + Ok(trunk_project_names.contains(&extension_name)) +} + +// Find the trunk project name associated with a given extension +pub async fn get_trunk_project_for_extension( + extension_name: String, +) -> Result, TrunkError> { + let trunk_projects = match get_trunk_projects().await { + Ok(trunk_projects) => trunk_projects, + Err(e) => { + error!("Failed to get trunk projects: {:?}", e); + return Err(TrunkError::ConfigMapApplyError); + } + }; + for trunk_project in trunk_projects { + for extension in trunk_project.extensions { + if extension.extension_name == extension_name { + return Ok(Some(trunk_project.name)); + } + } + } + Ok(None) +} + +// Check if control file is absent for a given trunk project version +pub async fn is_control_file_absent( + trunk_project: String, + version: String, +) -> Result { + let project_metadata: TrunkProjectMetadata = + match get_trunk_project_metadata_for_version(trunk_project, version.clone()).await { + Ok(project_metadata) => project_metadata, + Err(e) => { + error!( + "Failed to get trunk project metadata for version {}: {:?}", + version, e + ); + return Err(Action::requeue(Duration::from_secs(300))); + } + }; + // TODO(ianstanton) This assumes that there is only one extension in the project, but we need to handle the case + // where there are multiple extensions + let control_file_absent = project_metadata + .extensions + .get(0) + .unwrap() + .control_file + .absent; + Ok(control_file_absent) +} + +// Check if extension has loadable_library metadata for a given trunk project version and return the library name +pub async fn get_loadable_library_name( + trunk_project: String, + version: String, + extension_name: String, +) -> Result, Action> { + let project_metadata: TrunkProjectMetadata = match get_trunk_project_metadata_for_version( + trunk_project.clone(), + version.clone(), + ) + .await + { + Ok(project_metadata) => project_metadata, + Err(e) => { + error!( + "Failed to get trunk project metadata for version {}: {:?}", + version, e + ); + return Err(Action::requeue(Duration::from_secs(300))); + } + }; + // Find the extension in the project metadata + let extension_metadata = match project_metadata + .extensions + .iter() + .find(|e| e.extension_name == extension_name) + { + Some(extension_metadata) => extension_metadata, + None => { + error!( + "Failed to find extension {} in trunk project {} version {}", + extension_name, trunk_project, version + ); + return Err(Action::requeue(Duration::from_secs(300))); + } + }; + // Find the loadable library in the extension metadata + let loadable_library_name = match extension_metadata.loadable_libraries { + Some(ref loadable_libraries) => { + let loadable_library = loadable_libraries + .iter() + .find(|l| l.requires_restart == true); + match loadable_library { + Some(loadable_library) => Some(loadable_library.library_name.clone()), + None => None, + } + } + None => None, + }; + Ok(loadable_library_name) +} + // Define error type #[derive(Debug, thiserror::Error)] pub enum TrunkError { - #[error("Failed to update extensions libraries list from trunk: {0}")] + #[error("Failed to fetch metadata from trunk: {0}")] NetworkFailure(#[from] reqwest::Error), #[error("Failed to parse extensions libraries list from trunk: {0}")] ParsingIssue(#[from] serde_json::Error), #[error("Failed to apply trunk configmap")] ConfigMapApplyError, } + +#[cfg(test)] +mod tests { + use super::*; + + #[tokio::test] + async fn test_get_trunk_projects() { + let result = get_trunk_projects().await; + assert!(result.is_ok()); + } + + #[tokio::test] + async fn test_get_trunk_project_metadata() { + let trunk_project = "auto_explain".to_string(); + let result = get_trunk_project_metadata(trunk_project).await; + assert!(result.is_ok()); + } + + #[tokio::test] + async fn test_get_trunk_project_names() { + let result = get_trunk_project_names().await; + assert!(result.is_ok()); + } + + #[tokio::test] + async fn test_get_trunk_project_metadata_for_version() { + let trunk_project = "auto_explain".to_string(); + let version = "15.3.0".to_string(); + let result = get_trunk_project_metadata_for_version(trunk_project, version).await; + assert!(result.is_ok()); + } + + #[tokio::test] + async fn test_extension_name_matches_trunk_project() { + let extension_name = "auto_explain".to_string(); + let result = extension_name_matches_trunk_project(extension_name).await; + assert!(result.is_ok()); + assert_eq!(result.unwrap(), true); + + let extension_name = "pgml".to_string(); + let result = extension_name_matches_trunk_project(extension_name).await; + assert!(result.is_ok()); + assert_eq!(result.unwrap(), false); + + let extension_name = "vector".to_string(); + let result = extension_name_matches_trunk_project(extension_name).await; + assert!(result.is_ok()); + assert_eq!(result.unwrap(), false); + } + + #[tokio::test] + async fn test_get_trunk_project_for_extension() { + let extension_name = "auto_explain".to_string(); + let result = get_trunk_project_for_extension(extension_name).await; + assert!(result.is_ok()); + assert_eq!(result.unwrap(), Some("auto_explain".to_string())); + + let extension_name = "pgml".to_string(); + let result = get_trunk_project_for_extension(extension_name).await; + assert!(result.is_ok()); + assert_eq!(result.unwrap(), Some("postgresml".to_string())); + + let extension_name = "vector".to_string(); + let result = get_trunk_project_for_extension(extension_name).await; + assert!(result.is_ok()); + assert_eq!(result.unwrap(), Some("pgvector".to_string())); + } + + #[tokio::test] + async fn test_is_control_file_absent() { + let trunk_project = "auto_explain".to_string(); + let version = "15.3.0".to_string(); + let result = is_control_file_absent(trunk_project, version).await; + assert!(result.is_ok()); + assert_eq!(result.unwrap(), true); + } + + #[tokio::test] + async fn test_get_loadable_library_name() { + let trunk_project = "auto_explain".to_string(); + let version = "15.3.0".to_string(); + let extension_name = "auto_explain".to_string(); + let result = get_loadable_library_name(trunk_project, version, extension_name).await; + assert!(result.is_ok()); + assert_eq!(result.unwrap(), Some("auto_explain".to_string())); + } +}