Skip to content

Commit

Permalink
chg: [kunai] use tokio tasks for all workers
Browse files Browse the repository at this point in the history
  • Loading branch information
qjerome committed Jun 26, 2024
1 parent 7556dfe commit d486a59
Showing 1 changed file with 52 additions and 29 deletions.
81 changes: 52 additions & 29 deletions kunai/src/bin/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ use kunai_common::inspect_err;
use kunai_common::version::KernelVersion;
use log::LevelFilter;
use serde::{Deserialize, Serialize};
use tokio::sync::mpsc::error::SendError;

use std::borrow::Cow;
use std::collections::{HashMap, HashSet, VecDeque};
Expand All @@ -38,12 +39,10 @@ use std::net::IpAddr;
use std::path::{Path, PathBuf};
use std::str::FromStr;

use std::sync::mpsc::{channel, Receiver, SendError, Sender};
use std::sync::{Arc, RwLock};

use std::thread::JoinHandle;
use std::process;
use std::time::Duration;
use std::{process, thread};

use aya::{
include_bytes_aligned,
Expand All @@ -57,7 +56,7 @@ use aya::{BpfLoader, VerifierLogLevel};

use log::{debug, error, info, warn};

use tokio::sync::{Barrier, Mutex};
use tokio::sync::{mpsc, Barrier, Mutex};
use tokio::{signal, task, time};

use kunai::cache::*;
Expand Down Expand Up @@ -135,7 +134,6 @@ struct EventConsumer {
tasks: HashMap<TaskKey, Task>,
resolved: HashMap<IpAddr, String>,
output: std::fs::File,
handle: Option<JoinHandle<Result<(), anyhow::Error>>>,
}

impl EventConsumer {
Expand Down Expand Up @@ -165,7 +163,6 @@ impl EventConsumer {
.append(true)
.create(true)
.open(output)?,
handle: None,
};

// loading rules in the engine
Expand Down Expand Up @@ -204,17 +201,15 @@ impl EventConsumer {
/// Listen for events on the receiver
pub fn consume(
self,
receiver: Receiver<EncodedEvent>,
mut receiver: mpsc::Receiver<EncodedEvent>,
) -> anyhow::Result<Arc<RwLock<EventConsumer>>> {
let ep = Arc::new(RwLock::new(self));

let shared = Arc::clone(&ep);

// we spawn thread only if there is a receiver
let h = thread::spawn(move || {
// the thread must drop CLONE_FS in order to be able to navigate in mnt namespaces
unshare(libc::CLONE_FS)?;
while let Ok(mut enc) = receiver.recv() {
tokio::spawn(async move {
while let Some(mut enc) = receiver.recv().await {
// lock error is a symptom of implementation mistake so we panic
let mut ep = shared.write().unwrap();
ep.handle_event(&mut enc);
Expand All @@ -223,9 +218,6 @@ impl EventConsumer {
Ok::<(), anyhow::Error>(())
});

// lock error is a symptom of implementation mistake so we panic
ep.write().unwrap().handle = Some(h);

Ok(ep)
}

Expand Down Expand Up @@ -1261,7 +1253,7 @@ struct EventProducer {
config: Config,
batch: usize,
pipe: VecDeque<EncodedEvent>,
sender: Sender<EncodedEvent>,
sender: mpsc::Sender<EncodedEvent>,
filter: Filter,
stats: AyaHashMap<MapData, Type, u64>,
perf_array: AsyncPerfEventArray<MapData>,
Expand All @@ -1281,7 +1273,7 @@ impl EventProducer {
pub fn with_params(
bpf: &mut Bpf,
config: Config,
sender: Sender<EncodedEvent>,
sender: mpsc::Sender<EncodedEvent>,
) -> anyhow::Result<Self> {
let filter = (&config).try_into()?;
let stats_map: AyaHashMap<_, Type, u64> =
Expand Down Expand Up @@ -1370,15 +1362,15 @@ impl EventProducer {
.expect("pop_front should never fail here");

// send event to event processor
self.sender.send(enc_evt).unwrap();
self.sender.send(enc_evt).await.unwrap();

counter -= 1;
}
}

#[inline]
fn send_event<T>(&self, event: Event<T>) -> Result<(), SendError<EncodedEvent>> {
self.sender.send(EncodedEvent::from_event(event))
async fn send_event<T>(&self, event: Event<T>) -> Result<(), SendError<EncodedEvent>> {
self.sender.send(EncodedEvent::from_event(event)).await
}

/// function used to pre-process some targetted events where time is critical and for which
Expand Down Expand Up @@ -1449,20 +1441,22 @@ impl EventProducer {

/// this method pass through some events directly to the event processor
/// only events that can be processed asynchronously should be passed through
fn pass_through_events(&self, e: &EncodedEvent) {
async fn pass_through_events(&self, e: &EncodedEvent) {
let i = unsafe { e.info() }.unwrap();

match i.etype {
Type::Execve | Type::ExecveScript => {
let event = event!(e, bpf_events::ExecveEvent).unwrap();
for e in bpf_events::HashEvent::all_from_execve(event) {
self.send_event(e).unwrap()
self.send_event(e).await.unwrap();
}
}

Type::MmapExec => {
let event = event!(e, bpf_events::MmapExecEvent).unwrap();
self.send_event(bpf_events::HashEvent::from(event)).unwrap();
self.send_event(bpf_events::HashEvent::from(event))
.await
.unwrap();
}

_ => {}
Expand Down Expand Up @@ -1567,7 +1561,7 @@ impl EventProducer {
}

// passing through some events used for correlation
er.pass_through_events(&dec);
er.pass_through_events(&dec).await;

// we must get the event type here because we eventually
// changed it
Expand Down Expand Up @@ -1673,6 +1667,12 @@ struct Cli {
#[arg(long)]
debug: bool,

/// Number of worker threads used by kunai. By default kunai runs
/// in a single threaded mode. If you want to use all available
/// threads, set this option to 0.
#[arg(short, long)]
jobs: Option<usize>,

/// Specify a configuration file to use. Command line options supersede the ones specified in the configuration file.
#[arg(short, long, value_name = "FILE")]
config: Option<PathBuf>,
Expand Down Expand Up @@ -1887,7 +1887,7 @@ impl Command {

// we start event reader and event processor before loading the programs
// if we load the programs first we might have some event lost errors
let (sender, receiver) = channel::<EncodedEvent>();
let (sender, receiver) = mpsc::channel(512);

// we start consumer
EventConsumer::with_config(conf.clone())?.consume(receiver)?;
Expand Down Expand Up @@ -1931,9 +1931,7 @@ impl Command {
}
}

// todo: make single-threaded / multi-threaded available in configuration
#[tokio::main(flavor = "current_thread")]
async fn main() -> Result<(), anyhow::Error> {
fn main() -> Result<(), anyhow::Error> {
let c = {
let c: clap::Command = Cli::command();
let styles = styling::Styles::styled()
Expand Down Expand Up @@ -2060,9 +2058,34 @@ async fn main() -> Result<(), anyhow::Error> {
}
}

// create the tokio runtime builder
let mut builder = {
match cli.jobs {
Some(workers) => {
let mut b = tokio::runtime::Builder::new_multi_thread();
// if number of workers is positive we set desired
// number of workers. If not tokio default will be
// taken (i.e. number of available threads).
if workers > 0 {
b.worker_threads(workers);
}
b
}
None => tokio::runtime::Builder::new_current_thread(),
}
};

// creating tokio runtime
let runtime = builder
// the thread must drop CLONE_FS in order to be able to navigate in mnt namespaces
.on_thread_start(|| unshare(libc::CLONE_FS).unwrap())
.enable_all()
.build()
.unwrap();

// We finished preparing config
match cli.command {
Some(Command::Replay(o)) => return Command::replay(conf, o),
_ => Command::run(conf, verifier_level).await,
Some(Command::Replay(o)) => Command::replay(conf, o),
_ => runtime.block_on(Command::run(conf, verifier_level)),
}
}

0 comments on commit d486a59

Please sign in to comment.