diff --git a/Cargo.lock b/Cargo.lock index d7c3cf1..272e286 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -840,6 +840,41 @@ dependencies = [ "typenum", ] +[[package]] +name = "darling" +version = "0.20.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "83b2eb4d90d12bdda5ed17de686c2acb4c57914f8f921b8da7e112b5a36f3fe1" +dependencies = [ + "darling_core", + "darling_macro", +] + +[[package]] +name = "darling_core" +version = "0.20.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "622687fe0bac72a04e5599029151f5796111b90f1baaa9b544d807a5e31cd120" +dependencies = [ + "fnv", + "ident_case", + "proc-macro2", + "quote", + "strsim 0.11.1", + "syn 2.0.68", +] + +[[package]] +name = "darling_macro" +version = "0.20.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "733cabb43482b1a1b53eee8583c2b9e8684d592215ea83efd305dd31bc2f0178" +dependencies = [ + "darling_core", + "quote", + "syn 2.0.68", +] + [[package]] name = "dashmap" version = "5.5.3" @@ -883,6 +918,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b42b6fa04a440b495c8b04d0e71b707c585f83cb9cb28cf8cd0d976c315e31b4" dependencies = [ "powerfmt", + "serde", ] [[package]] @@ -1582,6 +1618,12 @@ version = "0.3.9" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d231dfb89cfffdbc30e7fc41579ed6066ad03abda9e567ccafae602b97ec5024" +[[package]] +name = "hex" +version = "0.4.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7f24254aa9a54b5c858eaee2f5bccdb46aaf0e486a595ed5fd8f86ba55232a70" + [[package]] name = "hmac" version = "0.12.1" @@ -1765,6 +1807,12 @@ version = "2.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "25a2bc672d1148e28034f176e01fffebb08b35768468cc954630da77a1449005" +[[package]] +name = "ident_case" +version = "1.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b9e0384b61958566e926dc50660321d12159025e767c18e043daf26b70104c39" + [[package]] name = "idna" version = "0.5.0" @@ -1798,6 +1846,7 @@ dependencies = [ "autocfg", "hashbrown 0.12.3", "rustc-rayon 0.5.0", + "serde", ] [[package]] @@ -2678,9 +2727,11 @@ dependencies = [ "kclvm-sema", "logcraft-runtime", "rayon", + "regex", "reqwest", "serde", "serde_json", + "serde_with", "serde_yaml_ng", "similar", "tempfile", @@ -4358,6 +4409,36 @@ dependencies = [ "serde", ] +[[package]] +name = "serde_with" +version = "3.8.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0ad483d2ab0149d5a5ebcd9972a3852711e0153d863bf5a5d0391d28883c4a20" +dependencies = [ + "base64 0.22.1", + "chrono", + "hex", + "indexmap 1.9.3", + "indexmap 2.2.6", + "serde", + "serde_derive", + "serde_json", + "serde_with_macros", + "time 0.3.36", +] + +[[package]] +name = "serde_with_macros" +version = "3.8.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "65569b702f41443e8bc8bbb1c5779bd0450bbe723b56198980e80ec45780bce2" +dependencies = [ + "darling", + "proc-macro2", + "quote", + "syn 2.0.68", +] + [[package]] name = "serde_yaml" version = "0.9.34+deprecated" diff --git a/crates/common/Cargo.toml b/crates/common/Cargo.toml index cdfaab6..2426770 100644 --- a/crates/common/Cargo.toml +++ b/crates/common/Cargo.toml @@ -21,7 +21,7 @@ serde_yaml_ng.workspace = true url.workspace = true tokio.workspace = true tokio-util.workspace = true -reqwest = { workspace = true, features = ["stream"] } +reqwest.workspace = true wasmtime.workspace = true serde_json.workspace = true inquire.workspace = true @@ -40,6 +40,8 @@ async-trait = "0.1" tempfile = "3.10" uuid = "1.8" similar = "2.5" +regex = "1.10" +serde_with = "3.8" # Local dependencies logcraft-runtime = { path = "../runtime" } diff --git a/crates/common/src/configuration.rs b/crates/common/src/configuration.rs index 43df3c3..0b0f686 100644 --- a/crates/common/src/configuration.rs +++ b/crates/common/src/configuration.rs @@ -28,6 +28,7 @@ pub const LGC_CONFIG_PATH: &str = "lgc.yaml"; pub const LGC_RULES_DIR: &str = "rules"; use crate::plugins::Plugin; +use crate::state::backends::StateBackend; use crate::utils::ensure_kebab_case; /// ProjectConfiguration definition @@ -36,6 +37,8 @@ use crate::utils::ensure_kebab_case; /// Hash is calculated for the name field to provide unique objects. #[derive(Serialize, Deserialize, Default, Clone)] pub struct ProjectConfiguration { + #[serde(default)] + pub state: StateBackend, pub plugins: BTreeMap, pub environments: BTreeSet, pub services: BTreeSet, diff --git a/crates/common/src/state.rs b/crates/common/src/state.rs deleted file mode 100644 index 3923b1b..0000000 --- a/crates/common/src/state.rs +++ /dev/null @@ -1,99 +0,0 @@ -// Copyright (c) 2023 LogCraft, SAS. -// SPDX-License-Identifier: MPL-2.0 - -use crate::detections::{DetectionState, ServiceDetections}; -use anyhow::{anyhow, bail, Result}; -use console::style; -use dashmap::DashMap; -use rayon::iter::{IntoParallelRefIterator, ParallelIterator}; -use serde::{Deserialize, Serialize}; -use std::{ - collections::{HashMap, HashSet}, - fs, - io::{BufReader, BufWriter}, - path::PathBuf, -}; -use uuid::Uuid; - -const LGC_STATE_PATH: &str = ".logcraft/state.json"; -const LGC_STATE_VERSION: usize = 1; - -#[derive(Debug, Serialize, Deserialize)] -pub struct State { - /// State unique ID - lineage: Uuid, - /// Serial number of the state file. - /// Increments every time the state file is written. - serial: usize, - /// Version of the state schema - version: usize, - /// Version of LogCraft CLI - lgc_version: String, - /// List of rules to track service_name => (rule_name, rule_settings) - pub services: ServiceDetections, -} - -impl State { - pub fn clean(&mut self) -> Result<()> { - self.services.clear(); - self.write() - } - - pub fn read() -> Result { - let path = PathBuf::from(LGC_STATE_PATH); - if !path.is_file() { - return Ok(Self { - lineage: Uuid::new_v4(), - serial: 0, - version: LGC_STATE_VERSION, - lgc_version: env!("CARGO_PKG_VERSION").to_string(), - services: HashMap::new(), - }); - } - - let f = fs::File::open(path)?; - let reader = BufReader::new(f); - - match serde_json::from_reader(reader) { - Ok(state) => Ok(state), - Err(e) => { - bail!("unable to load state file: {}", e) - } - } - } - - pub fn write(&mut self) -> Result<()> { - let f = fs::File::create(LGC_STATE_PATH)?; - - self.serial += 1; - self.lgc_version = env!("CARGO_PKG_VERSION").to_string(); - - let writer = BufWriter::new(f); - serde_json::to_writer_pretty(writer, self) - .map_err(|e| anyhow!("unable to write state file: {}", e)) - } - - pub fn missing_rules(&self, detections: &ServiceDetections) -> ServiceDetections { - let to_remove: DashMap> = DashMap::new(); - - detections.par_iter().for_each(|(service_id, rules)| { - if let Some(state_rules) = self.services.get(service_id) { - state_rules.difference(rules).for_each(|rule| { - to_remove - .entry(service_id.clone()) - .and_modify(|s| { - s.insert(rule.clone()); - }) - .or_insert(HashSet::from([rule.clone()])); - println!( - "[-] rule: `{}` will be deleted from `{}`", - style(&rule.name).red(), - &service_id - ); - }); - } - }); - - to_remove.into_iter().collect() - } -} diff --git a/crates/common/src/state/backends/http.rs b/crates/common/src/state/backends/http.rs new file mode 100644 index 0000000..84686cb --- /dev/null +++ b/crates/common/src/state/backends/http.rs @@ -0,0 +1,231 @@ +// Copyright (c) 2023 LogCraft, SAS. +// SPDX-License-Identifier: MPL-2.0 + +use std::collections::HashMap; +use std::str::FromStr; +use std::time::Duration; + +use super::State; +use anyhow::{anyhow, bail, Result}; +use async_trait::async_trait; +use regex::Regex; +use reqwest::{ + header, + header::{HeaderMap, HeaderName, HeaderValue}, + Certificate, Client, ClientBuilder, Method, RequestBuilder, Response, StatusCode, +}; +use serde::{Deserialize, Serialize}; +use serde_with::skip_serializing_none; +use url::Url; +use uuid::Uuid; + +use super::BackendActions; + +#[skip_serializing_none] +#[derive(Serialize, Deserialize, Clone)] +pub struct HttpBackend { + address: String, + update_method: Option, + lock_address: Option, + unlock_address: Option, + lock_method: Option, + unlock_method: Option, + username: Option, + password: Option, + skip_cert_verification: Option, + timeout: Option, + client_ca_certificate_pem: Option, + client_certificate_pem: Option, + client_private_key_pem: Option, + headers: Option>, +} + +impl HttpBackend { + fn check_headers(&self) -> Result { + let mut headermap = HeaderMap::new(); + if let Some(headers) = &self.headers { + for (key, value) in headers { + if key.is_empty() || value.is_empty() { + bail!("remote http state header key or value cannot be empty") + } + + if !value.is_ascii() { + bail!("remote http state header value must only contain ascii characters") + } + + if !Regex::new("[^a-zA-Z0-9-_]").unwrap().is_match(key) { + bail!("remote http state header key value must only contain A-Za-z0-9-_ characters") + } + + if ["content-type", "content-md5"].contains(&key.to_lowercase().as_str()) { + bail!("remote http state header key {} is reserved", key) + } + + headermap.insert(HeaderName::from_str(key)?, HeaderValue::from_str(value)?); + } + } + Ok(headermap) + } + + fn client(&self) -> Result { + let headermap = self.check_headers()?; + if headermap.get(header::AUTHORIZATION).is_some() && self.username.is_some() { + bail!( + "http remote state request headers {} cannot be set when providing username", + header::AUTHORIZATION + ) + } + + let client = ClientBuilder::new() + .default_headers(headermap) + .timeout(Duration::from_secs(self.timeout.unwrap_or(60))) + .danger_accept_invalid_certs(self.skip_cert_verification.unwrap_or_default()); + + // Set certificates + let client = match ( + &self.client_ca_certificate_pem, + &self.client_certificate_pem, + &self.client_private_key_pem, + ) { + (Some(ca), Some(cert), Some(key)) => { + let bundle = format!("{}\n{}\n{}", ca, cert, key); + client.add_root_certificate(Certificate::from_pem(bundle.as_bytes())?) + } + (None, Some(cert), Some(key)) => { + let bundle = format!("{}\n{}", cert, key); + client.add_root_certificate(Certificate::from_pem(bundle.as_bytes())?) + } + (Some(ca), None, None) => { + client.add_root_certificate(Certificate::from_pem(ca.as_bytes())?) + } + _ => client, + }; + + client + .build() + .map_err(|e| anyhow::anyhow!("unable to retrieve state: {}", e)) + } + + async fn send_auth(&self, req: RequestBuilder) -> Result { + if let Some(usr) = &self.username { + req.basic_auth(usr, self.password.clone()).send().await + } else { + req.send().await + } + .map_err(|e| anyhow::anyhow!("unable to retrieve state: {}", e)) + } + + async fn lock(&self, client: &Client, lock_address: &str) -> Result { + let lock_method = self.lock_method.clone().unwrap_or("LOCK".to_string()); + + let lock_id = Uuid::new_v4(); + + let req = client + .request(Method::from_str(&lock_method)?, Url::parse(lock_address)?) + .query(&[("ID", &lock_id)]); + + match self.send_auth(req).await { + Ok(resp) => match resp.status() { + StatusCode::OK => Ok(lock_id), + // StatusCode::CONFLICT => bail!("unable to lock state: already locked"), + _ => bail!( + "unable to lock state: {} {}", + resp.status(), + resp.text().await? + ), + }, + Err(e) => bail!("unable to lock state: {}", e), + } + } + + async fn unlock(&self, client: &Client, lock_id: &str) -> Result<()> { + let unlock_address = if let Some(address) = &self.unlock_address { + address + } else { + return Ok(()); + }; + let unlock_method = self.unlock_method.clone().unwrap_or("UNLOCK".to_string()); + let req = client + .request( + Method::from_str(&unlock_method)?, + Url::parse(unlock_address)?, + ) + .query(&[("ID", lock_id)]); + + match self.send_auth(req).await { + Ok(resp) => match resp.status() { + StatusCode::OK => Ok(()), + _ => bail!("unable to unlock state: {}", resp.status()), + }, + Err(e) => bail!("unable to unlock state: {}", e), + } + } +} + +#[async_trait] +impl BackendActions for HttpBackend { + async fn load(&self) -> Result { + let client = self.client()?; + + let req = client.request(Method::GET, Url::from_str(&self.address)?); + + let resp = self.send_auth(req).await?; + match resp.status() { + StatusCode::OK => resp + .json() + .await + .map_err(|e| anyhow::anyhow!("unable to decode state: {}", e)), + StatusCode::NOT_FOUND => Ok(State::default()), + _ => bail!("unable to retrieve state: {}", resp.status()), + } + } + + async fn save(&self, state: &mut State) -> anyhow::Result<()> { + let client = self.client()?; + + state.serial += 1; + state.lgc_version = env!("CARGO_PKG_VERSION").to_string(); + + // Lock state - If lock address is not set ignore state locking + let (req, lock_id) = if let Some(address) = &self.lock_address { + let lock_id = self.lock(&client, address).await?; + ( + client + .request( + Method::from_str( + self.update_method.as_ref().unwrap_or(&"POST".to_string()), + )?, + Url::from_str(&self.address)?, + ) + .query(&[("ID", lock_id)]) + .json(state), + &self.lock_address, + ) + } else { + ( + client + .request( + Method::from_str( + self.update_method.as_ref().unwrap_or(&"POST".to_string()), + )?, + Url::from_str(&self.address)?, + ) + .json(state), + &None, + ) + }; + + match lock_id { + Some(lock_id) => { + self.send_auth(req) + .await + .map_err(|e| anyhow!("unable to save state: {}", e))?; + self.unlock(&client, lock_id).await + } + None => { + self.send_auth(req).await?; + Ok(()) + } + } + } +} diff --git a/crates/common/src/state/backends/local.rs b/crates/common/src/state/backends/local.rs new file mode 100644 index 0000000..a1ce015 --- /dev/null +++ b/crates/common/src/state/backends/local.rs @@ -0,0 +1,50 @@ +// Copyright (c) 2023 LogCraft, SAS. +// SPDX-License-Identifier: MPL-2.0 + +use crate::state::LGC_DEFAULT_STATE_PATH; +use anyhow::{anyhow, Ok, Result}; +use async_trait::async_trait; +use serde::{Deserialize, Serialize}; +use std::{fs, io, path}; + +use super::State; + +use super::BackendActions; + +#[derive(Serialize, Deserialize, Clone)] +pub struct LocalBackend { + path: path::PathBuf, +} + +impl Default for LocalBackend { + fn default() -> Self { + Self { + path: path::PathBuf::from(LGC_DEFAULT_STATE_PATH), + } + } +} + +#[async_trait] +impl BackendActions for LocalBackend { + async fn load(&self) -> Result { + if !self.path.is_file() { + return Ok(State::default()); + } + + let f = fs::File::open(&self.path)?; + let reader = io::BufReader::new(f); + + serde_json::from_reader(reader).map_err(|e| anyhow!("unable to load state file: {}", e)) + } + + async fn save(&self, state: &mut State) -> anyhow::Result<()> { + let f = fs::File::create(&self.path)?; + + state.serial += 1; + state.lgc_version = env!("CARGO_PKG_VERSION").to_string(); + + let writer = io::BufWriter::new(f); + serde_json::to_writer_pretty(writer, state) + .map_err(|e| anyhow!("unable to write state file: {}", e)) + } +} diff --git a/crates/common/src/state/backends/mod.rs b/crates/common/src/state/backends/mod.rs new file mode 100644 index 0000000..f497e2d --- /dev/null +++ b/crates/common/src/state/backends/mod.rs @@ -0,0 +1,44 @@ +// Copyright (c) 2023 LogCraft, SAS. +// SPDX-License-Identifier: MPL-2.0 + +use super::State; +use anyhow::Result; +use async_trait::async_trait; +use local::LocalBackend; +use serde::{Deserialize, Serialize}; + +// Backends +mod http; +mod local; + +use http::HttpBackend; + +#[derive(Serialize, Deserialize, Clone)] +#[serde(tag = "type")] +pub enum StateBackend { + /// Local state backend + Local(LocalBackend), + /// Http state backend + Http(Box), +} + +impl StateBackend { + pub async fn load(&self) -> Result { + match self { + Self::Local(path) => path.load().await, + Self::Http(backend) => backend.load().await, + } + } +} + +impl Default for StateBackend { + fn default() -> Self { + Self::Local(LocalBackend::default()) + } +} + +#[async_trait] +pub trait BackendActions { + async fn load(&self) -> Result; + async fn save(&self, state: &mut State) -> Result<()>; +} diff --git a/crates/common/src/state/mod.rs b/crates/common/src/state/mod.rs new file mode 100644 index 0000000..792e582 --- /dev/null +++ b/crates/common/src/state/mod.rs @@ -0,0 +1,79 @@ +// Copyright (c) 2023 LogCraft, SAS. +// SPDX-License-Identifier: MPL-2.0 + +use crate::detections::{DetectionState, ServiceDetections}; +use anyhow::Result; +use console::style; +use dashmap::DashMap; +use rayon::iter::{IntoParallelRefIterator, ParallelIterator}; +use serde::{Deserialize, Serialize}; +use std::collections::{HashMap, HashSet}; +use uuid::Uuid; + +const LGC_DEFAULT_STATE_PATH: &str = ".logcraft/state.json"; +const LGC_STATE_VERSION: usize = 1; + +pub mod backends; +use backends::{BackendActions, StateBackend}; + +#[derive(Debug, Serialize, Deserialize)] +pub struct State { + /// State unique ID + lineage: Uuid, + /// Serial number of the state file. + /// Increments every time the state file is written. + serial: usize, + /// Version of the state schema + version: usize, + /// Version of LogCraft CLI + lgc_version: String, + /// List of rules to track service_name => (rule_name, rule_settings) + pub services: ServiceDetections, +} + +impl Default for State { + fn default() -> Self { + Self { + lineage: Uuid::new_v4(), + serial: 0, + version: LGC_STATE_VERSION, + lgc_version: env!("CARGO_PKG_VERSION").to_string(), + services: HashMap::new(), + } + } +} + +impl State { + pub async fn save(&mut self, backend: &StateBackend) -> Result<()> { + match backend { + StateBackend::Local(path) => path.save(self).await, + StateBackend::Http(backend) => backend.save(self).await, + } + } + + pub fn missing_rules(&self, detections: &ServiceDetections, silent: bool) -> ServiceDetections { + let to_remove: DashMap> = DashMap::new(); + + detections.par_iter().for_each(|(service_id, rules)| { + if let Some(state_rules) = self.services.get(service_id) { + state_rules.difference(rules).for_each(|rule| { + to_remove + .entry(service_id.clone()) + .and_modify(|s| { + s.insert(rule.clone()); + }) + .or_insert(HashSet::from([rule.clone()])); + if !silent { + println!( + "[-] rule: `{}` will be deleted from `{}`", + style(&rule.name).red(), + &service_id + ); + } + }); + } + }); + + to_remove.into_iter().collect() + } +} diff --git a/src/commands/deploy.rs b/src/commands/deploy.rs index aa4a128..93ea875 100644 --- a/src/commands/deploy.rs +++ b/src/commands/deploy.rs @@ -11,7 +11,6 @@ use logcraft_common::{ configuration::{Environment, ProjectConfiguration, Service}, detections::{compare_detections, map_plugin_detections, DetectionState, ServiceDetections}, plugins::manager::{PluginActions, PluginManager}, - state::State, }; use serde_json::Value; use tokio::task::JoinSet; @@ -149,8 +148,8 @@ impl DeployCommand { } } - let mut state = State::read()?; - let to_remove = state.missing_rules(&returned_rules); + let mut state = config.state.load().await?; + let to_remove = state.missing_rules(&returned_rules, self.auto_approve); let changed = compare_detections(&detections, &returned_rules, &services, !self.auto_approve); @@ -186,7 +185,7 @@ impl DeployCommand { ) } Err(e) => { - state.write()?; + state.save(&config.state).await?; bail!( "on update for `{}` in `{}`: {}", style(&rule.name).red(), @@ -220,7 +219,7 @@ impl DeployCommand { ) } Err(e) => { - state.write()?; + state.save(&config.state).await?; bail!( "on update for `{}` in `{}`: {}", style(&rule.name).red(), @@ -254,7 +253,7 @@ impl DeployCommand { ); } Err(e) => { - state.write()?; + state.save(&config.state).await?; bail!( "on deletion for `{}` in `{}`: {}", style(&rule.name).red(), @@ -266,11 +265,21 @@ impl DeployCommand { } } } - state.write()?; + state.save(&config.state).await?; } else { bail!("action aborted") } } else { + // Update state to include any missing rules detected + if returned_rules + .iter() + .any(|(k, v)| state.services.get(k) != Some(v)) + { + tracing::info!("including unchanged remote detection rules that are not currently referenced in state"); + state.services.extend(returned_rules); + state.save(&config.state).await?; + } + tracing::info!("no differences found"); } } diff --git a/src/commands/destroy.rs b/src/commands/destroy.rs index b2bea4e..3951056 100644 --- a/src/commands/destroy.rs +++ b/src/commands/destroy.rs @@ -8,7 +8,6 @@ use dialoguer::{theme::ColorfulTheme, Confirm, Select}; use logcraft_common::{ configuration::{Environment, ProjectConfiguration, Service}, plugins::manager::{PluginActions, PluginManager}, - state::State, }; use std::collections::HashMap; use tokio::task::JoinSet; @@ -34,7 +33,7 @@ pub struct DestroyCommand { impl DestroyCommand { pub async fn run(self, config: &ProjectConfiguration) -> Result<()> { // Load all detections - let mut state = State::read()?; + let mut state = config.state.load().await?; // Prompt theme let prompt_theme = ColorfulTheme::default(); @@ -171,7 +170,7 @@ impl DestroyCommand { service.remove(&rule_state); } Err(e) => { - state.write()?; + state.save(&config.state).await?; bail!( "on deletion for `{}` in `{}`: {}", style(&rule_state.name).red(), @@ -193,6 +192,6 @@ impl DestroyCommand { } } - state.write() + state.save(&config.state).await } } diff --git a/src/commands/diff.rs b/src/commands/diff.rs index cf8a72d..1c59129 100644 --- a/src/commands/diff.rs +++ b/src/commands/diff.rs @@ -12,7 +12,6 @@ use logcraft_common::{ ServiceDetections, }, plugins::manager::{PluginActions, PluginManager}, - state::State, }; use serde_json::Value; use std::collections::{HashMap, HashSet}; @@ -148,7 +147,15 @@ impl DiffCommand { let changes = compare_detections(&detections, &returned_rules, &services, true).is_empty(); - if State::read()?.missing_rules(&returned_rules).is_empty() && changes && !has_diff { + if config + .state + .load() + .await? + .missing_rules(&returned_rules, false) + .is_empty() + && changes + && !has_diff + { tracing::info!("no differences found"); }