From d486a5977763e2ac024eb73dc78ccb947217c608 Mon Sep 17 00:00:00 2001 From: qjerome Date: Wed, 26 Jun 2024 18:12:30 +0200 Subject: [PATCH] chg: [kunai] use tokio tasks for all workers --- kunai/src/bin/main.rs | 81 +++++++++++++++++++++++++++---------------- 1 file changed, 52 insertions(+), 29 deletions(-) diff --git a/kunai/src/bin/main.rs b/kunai/src/bin/main.rs index db3925c..45a37eb 100644 --- a/kunai/src/bin/main.rs +++ b/kunai/src/bin/main.rs @@ -27,6 +27,7 @@ use kunai_common::inspect_err; use kunai_common::version::KernelVersion; use log::LevelFilter; use serde::{Deserialize, Serialize}; +use tokio::sync::mpsc::error::SendError; use std::borrow::Cow; use std::collections::{HashMap, HashSet, VecDeque}; @@ -38,12 +39,10 @@ use std::net::IpAddr; use std::path::{Path, PathBuf}; use std::str::FromStr; -use std::sync::mpsc::{channel, Receiver, SendError, Sender}; use std::sync::{Arc, RwLock}; -use std::thread::JoinHandle; +use std::process; use std::time::Duration; -use std::{process, thread}; use aya::{ include_bytes_aligned, @@ -57,7 +56,7 @@ use aya::{BpfLoader, VerifierLogLevel}; use log::{debug, error, info, warn}; -use tokio::sync::{Barrier, Mutex}; +use tokio::sync::{mpsc, Barrier, Mutex}; use tokio::{signal, task, time}; use kunai::cache::*; @@ -135,7 +134,6 @@ struct EventConsumer { tasks: HashMap, resolved: HashMap, output: std::fs::File, - handle: Option>>, } impl EventConsumer { @@ -165,7 +163,6 @@ impl EventConsumer { .append(true) .create(true) .open(output)?, - handle: None, }; // loading rules in the engine @@ -204,17 +201,15 @@ impl EventConsumer { /// Listen for events on the receiver pub fn consume( self, - receiver: Receiver, + mut receiver: mpsc::Receiver, ) -> anyhow::Result>> { let ep = Arc::new(RwLock::new(self)); let shared = Arc::clone(&ep); // we spawn thread only if there is a receiver - let h = thread::spawn(move || { - // the thread must drop CLONE_FS in order to be able to navigate in mnt namespaces - unshare(libc::CLONE_FS)?; - while let Ok(mut enc) = receiver.recv() { + tokio::spawn(async move { + while let Some(mut enc) = receiver.recv().await { // lock error is a symptom of implementation mistake so we panic let mut ep = shared.write().unwrap(); ep.handle_event(&mut enc); @@ -223,9 +218,6 @@ impl EventConsumer { Ok::<(), anyhow::Error>(()) }); - // lock error is a symptom of implementation mistake so we panic - ep.write().unwrap().handle = Some(h); - Ok(ep) } @@ -1261,7 +1253,7 @@ struct EventProducer { config: Config, batch: usize, pipe: VecDeque, - sender: Sender, + sender: mpsc::Sender, filter: Filter, stats: AyaHashMap, perf_array: AsyncPerfEventArray, @@ -1281,7 +1273,7 @@ impl EventProducer { pub fn with_params( bpf: &mut Bpf, config: Config, - sender: Sender, + sender: mpsc::Sender, ) -> anyhow::Result { let filter = (&config).try_into()?; let stats_map: AyaHashMap<_, Type, u64> = @@ -1370,15 +1362,15 @@ impl EventProducer { .expect("pop_front should never fail here"); // send event to event processor - self.sender.send(enc_evt).unwrap(); + self.sender.send(enc_evt).await.unwrap(); counter -= 1; } } #[inline] - fn send_event(&self, event: Event) -> Result<(), SendError> { - self.sender.send(EncodedEvent::from_event(event)) + async fn send_event(&self, event: Event) -> Result<(), SendError> { + self.sender.send(EncodedEvent::from_event(event)).await } /// function used to pre-process some targetted events where time is critical and for which @@ -1449,20 +1441,22 @@ impl EventProducer { /// this method pass through some events directly to the event processor /// only events that can be processed asynchronously should be passed through - fn pass_through_events(&self, e: &EncodedEvent) { + async fn pass_through_events(&self, e: &EncodedEvent) { let i = unsafe { e.info() }.unwrap(); match i.etype { Type::Execve | Type::ExecveScript => { let event = event!(e, bpf_events::ExecveEvent).unwrap(); for e in bpf_events::HashEvent::all_from_execve(event) { - self.send_event(e).unwrap() + self.send_event(e).await.unwrap(); } } Type::MmapExec => { let event = event!(e, bpf_events::MmapExecEvent).unwrap(); - self.send_event(bpf_events::HashEvent::from(event)).unwrap(); + self.send_event(bpf_events::HashEvent::from(event)) + .await + .unwrap(); } _ => {} @@ -1567,7 +1561,7 @@ impl EventProducer { } // passing through some events used for correlation - er.pass_through_events(&dec); + er.pass_through_events(&dec).await; // we must get the event type here because we eventually // changed it @@ -1673,6 +1667,12 @@ struct Cli { #[arg(long)] debug: bool, + /// Number of worker threads used by kunai. By default kunai runs + /// in a single threaded mode. If you want to use all available + /// threads, set this option to 0. + #[arg(short, long)] + jobs: Option, + /// Specify a configuration file to use. Command line options supersede the ones specified in the configuration file. #[arg(short, long, value_name = "FILE")] config: Option, @@ -1887,7 +1887,7 @@ impl Command { // we start event reader and event processor before loading the programs // if we load the programs first we might have some event lost errors - let (sender, receiver) = channel::(); + let (sender, receiver) = mpsc::channel(512); // we start consumer EventConsumer::with_config(conf.clone())?.consume(receiver)?; @@ -1931,9 +1931,7 @@ impl Command { } } -// todo: make single-threaded / multi-threaded available in configuration -#[tokio::main(flavor = "current_thread")] -async fn main() -> Result<(), anyhow::Error> { +fn main() -> Result<(), anyhow::Error> { let c = { let c: clap::Command = Cli::command(); let styles = styling::Styles::styled() @@ -2060,9 +2058,34 @@ async fn main() -> Result<(), anyhow::Error> { } } + // create the tokio runtime builder + let mut builder = { + match cli.jobs { + Some(workers) => { + let mut b = tokio::runtime::Builder::new_multi_thread(); + // if number of workers is positive we set desired + // number of workers. If not tokio default will be + // taken (i.e. number of available threads). + if workers > 0 { + b.worker_threads(workers); + } + b + } + None => tokio::runtime::Builder::new_current_thread(), + } + }; + + // creating tokio runtime + let runtime = builder + // the thread must drop CLONE_FS in order to be able to navigate in mnt namespaces + .on_thread_start(|| unshare(libc::CLONE_FS).unwrap()) + .enable_all() + .build() + .unwrap(); + // We finished preparing config match cli.command { - Some(Command::Replay(o)) => return Command::replay(conf, o), - _ => Command::run(conf, verifier_level).await, + Some(Command::Replay(o)) => Command::replay(conf, o), + _ => runtime.block_on(Command::run(conf, verifier_level)), } }