Skip to content

Commit

Permalink
Refactor to have a plugins folder (#15)
Browse files Browse the repository at this point in the history
This refactors and demonstrates how a lot of plugins can be added to
this repository.
  • Loading branch information
sudarshan-reddy authored Aug 7, 2024
1 parent a5d1539 commit 964ac21
Show file tree
Hide file tree
Showing 6 changed files with 98 additions and 93 deletions.
4 changes: 2 additions & 2 deletions src/main.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
mod live_packet_reader;
mod redis;
mod plugin;
mod tun;

use clap::Parser;
use env_logger;
use live_packet_reader::LivePacketReader;
use redis::RespHandler;
use plugin::redis::handler::RespHandler;
use std::io;
use std::sync::Arc;
use tokio::sync::Mutex;
Expand Down
1 change: 1 addition & 0 deletions src/plugin/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
pub mod redis;
69 changes: 69 additions & 0 deletions src/plugin/redis/handler.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
use anyhow::Result;
use std::{collections::HashMap, sync::Arc};
use tokio::sync::Mutex;

use crate::tun::{Metrics, Plugin};

use super::resp_parser::{parse_resp, RespValue};

pub struct RespHandler {
port: u16,
key_map: Arc<Mutex<HashMap<u32, RespValue>>>,
}

impl RespHandler {
pub fn new(port: u16) -> Self {
RespHandler {
port,
key_map: Arc::new(Mutex::new(HashMap::new())),
}
}
}

impl Plugin<RespValue> for RespHandler {
async fn port(&self) -> u16 {
self.port
}

async fn parse_packet(&self, buf: Vec<u8>) -> Result<RespValue> {
let resp = parse_resp(&buf).map_err(|_| anyhow::anyhow!("Failed to parse packet"))?;
Ok(resp.1)
}

async fn process(&self, input: RespValue, metrics: Option<Metrics>) -> Result<()> {
// Return if none and unpack the metrics
if metrics.is_none() {
return Ok(());
}
// We already know that metrics is not None
let metrics = metrics.unwrap();

let mut store = self.key_map.lock().await;
if !store.contains_key(&metrics.identifier) {
// Check if the identifier exists and save it in the store
store.insert(metrics.identifier, input.clone());
}

if let Some(latency) = metrics.latency {
let status = if input.to_string().contains("ERR") {
"ERR"
} else {
"OK"
};
// Print the latency and the key
let stored_value = store
.get(&metrics.identifier)
.ok_or_else(|| anyhow::anyhow!("Failed to get value from store"))?;
println!(
"Key: {}, Latency: {}ms, Status: {}",
stored_value.key.as_ref().unwrap(),
latency.as_millis(),
status,
);
// clean up the store
store.remove(&metrics.identifier);
}

Ok(())
}
}
2 changes: 2 additions & 0 deletions src/plugin/redis/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
pub mod handler;
mod resp_parser;
67 changes: 1 addition & 66 deletions src/redis.rs → src/plugin/redis/resp_parser.rs
Original file line number Diff line number Diff line change
@@ -1,14 +1,11 @@
use anyhow::Result;
use nom::{
branch::alt,
bytes::complete::{tag, take, take_while},
character::complete::char,
IResult,
};
use std::{collections::HashMap, str, sync::Arc};
use tokio::sync::Mutex;

use crate::tun::{Handler, Metrics};
use std::str;

#[derive(Debug, Clone, PartialEq)]
pub struct RespValue {
Expand Down Expand Up @@ -232,65 +229,3 @@ mod tests {
// assert_eq!(parse_array(input).unwrap().1, expected);
//}
}

pub struct RespHandler {
port: u16,
key_map: Arc<Mutex<HashMap<u32, RespValue>>>,
}

impl RespHandler {
pub fn new(port: u16) -> Self {
RespHandler {
port,
key_map: Arc::new(Mutex::new(HashMap::new())),
}
}
}

impl Handler<RespValue> for RespHandler {
async fn port(&self) -> u16 {
self.port
}

async fn parse_packet(&self, buf: Vec<u8>) -> Result<RespValue> {
let resp = parse_resp(&buf).map_err(|_| anyhow::anyhow!("Failed to parse packet"))?;
Ok(resp.1)
}

async fn process(&self, input: RespValue, metrics: Option<Metrics>) -> Result<()> {
// Return if none and unpack the metrics
if metrics.is_none() {
return Ok(());
}
// We already know that metrics is not None
let metrics = metrics.unwrap();

let mut store = self.key_map.lock().await;
if !store.contains_key(&metrics.identifier) {
// Check if the identifier exists and save it in the store
store.insert(metrics.identifier, input.clone());
}

if let Some(latency) = metrics.latency {
let status = if input.to_string().contains("ERR") {
"ERR"
} else {
"OK"
};
// Print the latency and the key
let stored_value = store
.get(&metrics.identifier)
.ok_or_else(|| anyhow::anyhow!("Failed to get value from store"))?;
println!(
"Key: {}, Latency: {}ms, Status: {}",
stored_value.key.as_ref().unwrap(),
latency.as_millis(),
status,
);
// clean up the store
store.remove(&metrics.identifier);
}

Ok(())
}
}
48 changes: 23 additions & 25 deletions src/tun.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,10 @@ use std::time::Instant;
use tokio::sync::{watch, Mutex};
use tokio::time::Duration;

pub trait Handler<T>: Send + Sync {
/// Plugin trait that defines the interface for a plugin.
/// A plugin is a module that can parse a packet, process it and send the result to a handler.
/// The plugin can be used to implement different types of handlers like a Redis handler, a HTTP handler etc.
pub trait Plugin<T>: Send + Sync {
async fn port(&self) -> u16;
async fn parse_packet(&self, buf: Vec<u8>) -> Result<T>;
async fn process(&self, input: T, metrics: Option<Metrics>) -> Result<()>;
Expand All @@ -21,7 +24,6 @@ pub trait PacketReader {
}

pub struct Observer {
// TODO (for later): This should also be an LRU perhaps so we dont grow indiscriminately.
syn_packets: Arc<Mutex<HashMap<u32, Instant>>>,
ttl: Duration,
stop_tx: watch::Sender<bool>,
Expand Down Expand Up @@ -75,7 +77,7 @@ impl Observer {
pub async fn capture_packets<T>(
&self,
mut reader: impl PacketReader,
handler: Arc<Mutex<impl Handler<T>>>,
plugin: Arc<Mutex<impl Plugin<T>>>,
) -> Result<()>
where
T: Send + 'static,
Expand All @@ -89,7 +91,7 @@ impl Observer {
}
}
Some(packet) = async { reader.read_packet() } => {
self.handle_packet(&handler, packet).await?;
self.handle_packet(&plugin, packet).await?;
}
}
}
Expand All @@ -98,7 +100,7 @@ impl Observer {

async fn handle_packet<T>(
&self,
handler: &Arc<Mutex<impl Handler<T>>>,
plugin: &Arc<Mutex<impl Plugin<T>>>,
packet: Vec<u8>,
) -> Result<()>
where
Expand All @@ -115,7 +117,7 @@ impl Observer {
match ethernet_packet.get_ethertype() {
EtherTypes::Ipv4 => {
if let Some(ipv4_packet) = Ipv4Packet::new(ethernet_packet.payload()) {
self.handle_ipv4_packet(handler, ipv4_packet, timestamp)
self.handle_ipv4_packet(plugin, ipv4_packet, timestamp)
.await?;
}
}
Expand All @@ -127,7 +129,7 @@ impl Observer {

async fn handle_ipv4_packet<T>(
&self,
handler: &Arc<Mutex<impl Handler<T>>>,
plugin: &Arc<Mutex<impl Plugin<T>>>,
ipv4_packet: Ipv4Packet<'_>,
timestamp: Instant,
) -> Result<()>
Expand All @@ -136,16 +138,15 @@ impl Observer {
{
match ipv4_packet.get_next_level_protocol() {
IpNextHeaderProtocols::Tcp => {
self.handle_tcp_packet(handler, ipv4_packet, timestamp)
.await
self.handle_tcp_packet(plugin, ipv4_packet, timestamp).await
}
_ => Ok(()),
}
}

async fn handle_tcp_packet<T>(
&self,
handler: &Arc<Mutex<impl Handler<T>>>,
plugin: &Arc<Mutex<impl Plugin<T>>>,
ipv4_packet: Ipv4Packet<'_>,
timestamp: Instant,
) -> Result<()>
Expand All @@ -154,7 +155,7 @@ impl Observer {
{
let tcp_packet = TcpPacket::new(ipv4_packet.payload())
.ok_or_else(|| anyhow::anyhow!("Failed to parse TCP packet from IPv4 payload"))?;
let port = handler.lock().await.port().await;
let port = plugin.lock().await.port().await;
let dst_port = tcp_packet.get_destination();
let src_port = tcp_packet.get_source();
if dst_port != port && src_port != port {
Expand All @@ -168,8 +169,8 @@ impl Observer {
return Ok(()); // Skip if payload is empty
}

let parsed_packet = handler.lock().await.parse_packet(payload.to_vec()).await?;
handler.lock().await.process(parsed_packet, metrics).await
let parsed_packet = plugin.lock().await.parse_packet(payload.to_vec()).await?;
plugin.lock().await.process(parsed_packet, metrics).await
}

async fn get_metrics(
Expand Down Expand Up @@ -244,15 +245,15 @@ mod tests {
assert!(metrics.is_none());
}

struct MockHandler;
struct MockPlugin;

impl MockHandler {
impl MockPlugin {
fn new() -> Self {
MockHandler
MockPlugin
}
}

impl Handler<Vec<u8>> for MockHandler {
impl Plugin<Vec<u8>> for MockPlugin {
async fn port(&self) -> u16 {
1234
}
Expand All @@ -275,21 +276,18 @@ mod tests {
0x00, 0x01, 0x7f, 0x00, 0x00, 0x01,
]],
};
let handler = Arc::new(Mutex::new(MockHandler::new()));
let plugin = Arc::new(Mutex::new(MockPlugin::new()));
let obs = Arc::new(Mutex::new(Observer::new(ObsConfig::default())));

let stop_tx = obs.lock().await.stop_tx.clone();
// Clone the Arc and receiver to pass into the spawned task
let obs_clone = Arc::clone(&obs);

// Start the packet capture in a separate task
let capture_task = tokio::spawn(async move {
obs_clone
.lock()
.await
.capture_packets(reader, handler)
.await
});
let capture_task =
tokio::spawn(
async move { obs_clone.lock().await.capture_packets(reader, plugin).await },
);

// Run the capture for a short duration and then signal stop
tokio::time::sleep(Duration::from_secs(1)).await;
Expand Down

0 comments on commit 964ac21

Please sign in to comment.