Skip to content

Commit

Permalink
Refactor Observer (#14)
Browse files Browse the repository at this point in the history
This PR does three things:

1. Removes the nested `if` blocks and makes that part more readable by
   splitting it into more functions.

2. Adds some tests (this needs more work to add an exact TCP-packet-like
   unit test). Also introduces a stop mechanism to the observer, primarily
   for the tests.

3. Make the hashmap ttl-like so abandoned SYN packets are not stored
   forever.
  • Loading branch information
sudarshan-reddy authored Aug 6, 2024
1 parent 05bc0d6 commit a5d1539
Show file tree
Hide file tree
Showing 3 changed files with 204 additions and 34 deletions.
2 changes: 2 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
# Aragorn

[![Rust](https://github.com/sudarshan-reddy/aragorn/actions/workflows/rust.yml/badge.svg)](https://github.com/sudarshan-reddy/aragorn/actions/workflows/rust.yml)

Proof of concept of a watcher tool that runs in user space
and monitors tcpdump for predefinable patterns and has a
configurable module to act upon these observed metrics.
Expand Down
8 changes: 7 additions & 1 deletion src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,12 +32,18 @@ async fn main() -> io::Result<()> {
let redis_handler = Arc::new(Mutex::new(RespHandler::new(args.redis_port)));
let active_packet_reader =
LivePacketReader::new(&args.interface).expect("Failed to create packet reader");
let observer = Observer::new();
let observer = Observer::new(tun::ObsConfig {
..Default::default()
});

observer.start_cleanup();

observer
.capture_packets(active_packet_reader, redis_handler)
.await
.unwrap();

observer.stop();

Ok(())
}
228 changes: 195 additions & 33 deletions src/tun.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@ use pnet::packet::Packet;
use std::collections::HashMap;
use std::sync::Arc;
use std::time::Instant;
use tokio::sync::Mutex;
use tokio::sync::{watch, Mutex};
use tokio::time::Duration;

pub trait Handler<T>: Send + Sync {
async fn port(&self) -> u16;
Expand All @@ -20,21 +21,55 @@ pub trait PacketReader {
}

/// Packet observer state: the table of pending SYN timestamps, the TTL applied
/// to that table, and both ends of the stop-signal channel.
pub struct Observer {
    // IMMEDIATE TODO: Need to find a way to set a TTL here since we dont want to store all SYN packets that we never receive an ACK for.
    // This isn't very simple because we have a few edge cases:
    // 1. We might have a rogue SYN packet that we dont want to track anyway.
    // 2. A legitimate SYN packet might timeout and we might not receive an ACK for it. We need to
    // record/observe this.
    //
    // TODO (for later): This should also be an LRU perhaps so we dont grow indiscriminately.
    /// `Instant` recorded per `u32` key; `start_cleanup` retains only entries
    /// younger than `ttl`.
    // NOTE(review): the key is presumably a TCP sequence/ack-derived identifier — confirm
    // against `get_metrics`, which is not visible here.
    syn_packets: Arc<Mutex<HashMap<u32, Instant>>>,
    /// Maximum age of an entry in `syn_packets`; copied from `ObsConfig::ttl`.
    ttl: Duration,
    /// Sending half of the stop signal; `stop` publishes `true` on it.
    stop_tx: watch::Sender<bool>,
    /// Receiving half of the stop signal; `capture_packets` clones it and exits
    /// its loop once the value is `true`.
    stop_rx: watch::Receiver<bool>,
}

/// Tunable settings handed to `Observer::new`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ObsConfig {
    /// Maximum age of a tracked SYN entry before the cleanup sweep discards it.
    pub ttl: Duration,
    /// Intended pause between two cleanup sweeps.
    // NOTE(review): the visible `Observer::new` copies only `ttl`; confirm whether
    // this field is consumed anywhere.
    pub cleanup_interval: Duration,
}

impl Default for ObsConfig {
    /// Returns a configuration with a 5 second TTL and a 1 second cleanup interval.
    fn default() -> Self {
        Self {
            ttl: Duration::from_secs(5),
            cleanup_interval: Duration::from_secs(1),
        }
    }
}

impl Observer {
pub fn new() -> Self {
Observer {
/// Create a new Observer instance.
/// Default TTL is 5 seconds.
/// Default cleanup interval is 1 second.
pub fn new(cfg: ObsConfig) -> Self {
let (stop_tx, stop_rx) = watch::channel(false);
let obs = Observer {
syn_packets: Arc::new(Mutex::new(HashMap::new())),
}
ttl: cfg.ttl,
stop_tx,
stop_rx,
};

obs
}

/// Spawns a background task that periodically evicts stale SYN entries.
///
/// Each sweep keeps only the entries whose recorded `Instant` is younger than
/// `ttl`. The task ends when `stop` publishes `true` on the stop channel, or
/// when the sending half is dropped, so it no longer runs forever once the
/// observer is stopped or gone.
///
/// Must be called from within a Tokio runtime because it uses `tokio::spawn`.
pub fn start_cleanup(&self) {
    let syn_packets = Arc::clone(&self.syn_packets);
    let ttl = self.ttl;
    let mut stop_rx = self.stop_rx.clone();

    tokio::spawn(async move {
        loop {
            tokio::select! {
                changed = stop_rx.changed() => {
                    // `Err` means the sender was dropped: nobody can ever ask
                    // us to stop again, so treat it as a stop request too.
                    if changed.is_err() || *stop_rx.borrow() {
                        break;
                    }
                }
                // NOTE(review): the sweep period is `ttl` because `Observer` does not
                // store `ObsConfig::cleanup_interval`; an entry can therefore live
                // for up to roughly twice `ttl` before it is evicted.
                _ = tokio::time::sleep(ttl) => {
                    let now = Instant::now();
                    syn_packets
                        .lock()
                        .await
                        .retain(|_, seen_at| now.duration_since(*seen_at) < ttl);
                }
            }
        }
    });
}

pub async fn capture_packets<T>(
Expand All @@ -45,33 +80,66 @@ impl Observer {
where
T: Send + 'static,
{
let mut stop_rx = self.stop_rx.clone();
loop {
if let Some(packet) = reader.read_packet() {
// TODO: This isnt the most reliable way to measure time.
// Ideally we should be using the timestamp from the packet header/kernel.
// But this isnt easy enough. One way to do this is to set SO_TIMESTAMP on the socket
// and then read the timestamp from the packet header. For the purpose of the
// POC and simplicity, we are using this method temporarily. Moreover, this also
// doesn't work if we are playing back a pcap file.
let timestamp = Instant::now();
if let Some(ethernet_packet) = EthernetPacket::new(&packet) {
if ethernet_packet.get_ethertype() == EtherTypes::Ipv4 {
if let Some(ipv4_packet) = Ipv4Packet::new(ethernet_packet.payload()) {
match ipv4_packet.get_next_level_protocol() {
IpNextHeaderProtocols::Tcp => {
let res = self
.handle_tcp_packet(&handler, ipv4_packet, timestamp)
.await;
if res.is_err() {
eprintln!("Failed to handle TCP packet: {:?}", res);
}
}
_ => {}
}
}
tokio::select! {
_ = stop_rx.changed() => {
if *stop_rx.borrow() {
break;
}
}
Some(packet) = async { reader.read_packet() } => {
self.handle_packet(&handler, packet).await?;
}
}
}
Ok(())
}

/// Parses one raw frame and forwards it to the IPv4 handler when applicable.
///
/// Frames that cannot be parsed as Ethernet, carry a non-IPv4 ethertype, or
/// whose payload is not a valid IPv4 packet are ignored and yield `Ok(())`.
/// Errors from `handle_ipv4_packet` are returned to the caller.
async fn handle_packet<T>(
    &self,
    handler: &Arc<Mutex<impl Handler<T>>>,
    packet: Vec<u8>,
) -> Result<()>
where
    T: Send + 'static,
{
    // TODO: This isn't the most reliable way to measure time.
    // Ideally we should use the timestamp from the packet header/kernel, but
    // that is not easy to get at. One option is to set SO_TIMESTAMP on the
    // socket and read the timestamp from the packet header. For the purpose of
    // the POC and simplicity we use the wall-clock arrival time for now. Note
    // that this also doesn't work when playing back a pcap file.
    let timestamp = Instant::now();

    let Some(frame) = EthernetPacket::new(&packet) else {
        return Ok(());
    };
    if frame.get_ethertype() != EtherTypes::Ipv4 {
        return Ok(());
    }
    let Some(ipv4_packet) = Ipv4Packet::new(frame.payload()) else {
        return Ok(());
    };

    self.handle_ipv4_packet(handler, ipv4_packet, timestamp)
        .await
}

/// Routes an IPv4 packet by its next-level protocol.
///
/// TCP packets are passed to `handle_tcp_packet` together with `timestamp`;
/// every other protocol is ignored and yields `Ok(())`.
async fn handle_ipv4_packet<T>(
    &self,
    handler: &Arc<Mutex<impl Handler<T>>>,
    ipv4_packet: Ipv4Packet<'_>,
    timestamp: Instant,
) -> Result<()>
where
    T: Send + 'static,
{
    let is_tcp = ipv4_packet.get_next_level_protocol() == IpNextHeaderProtocols::Tcp;
    if !is_tcp {
        return Ok(());
    }

    self.handle_tcp_packet(handler, ipv4_packet, timestamp)
        .await
}

Expand Down Expand Up @@ -139,9 +207,103 @@ impl Observer {
}
None
}

/// Publishes `true` on the stop channel so that loops watching it can exit.
///
/// This is a best-effort signal and never panics.
pub fn stop(&self) {
    // `watch::Sender::send` only errors when every receiver has been dropped.
    // This struct keeps `stop_rx` alive, and even if that changed there would
    // be nobody left to notify, so the result is deliberately ignored.
    let _ = self.stop_tx.send(true);
}
}

/// Measurement handed to `Handler::process` alongside the parsed input.
#[derive(Debug)]
pub struct Metrics {
    /// Numeric identifier of the observed exchange.
    // NOTE(review): presumably the same `u32` used as the `syn_packets` key — confirm
    // against `get_metrics`.
    pub identifier: u32,
    /// Measured latency, or `None` when no latency is available.
    // NOTE(review): presumably the time between a recorded SYN and its matching
    // reply — confirm against `get_metrics`.
    pub latency: Option<std::time::Duration>,
}

#[cfg(test)]
mod tests {
    use super::*;

    /// In-memory `PacketReader` that replays a fixed list of frames.
    struct MockPacketReader {
        packets: Vec<Vec<u8>>,
    }

    impl PacketReader for MockPacketReader {
        /// Yields the queued frames in insertion order, then `None` forever.
        fn read_packet(&mut self) -> Option<Vec<u8>> {
            // `pop()` would replay the frames in reverse; a reader must keep
            // arrival order, so take from the front instead.
            if self.packets.is_empty() {
                None
            } else {
                Some(self.packets.remove(0))
            }
        }
    }

    /// An all-zero TCP header must not produce any metrics.
    #[tokio::test]
    async fn test_get_metrics() {
        let obs = Observer::new(ObsConfig::default());
        let tcp_packet = TcpPacket::new(&[0; 20]).unwrap();
        let timestamp = Instant::now();
        let port = 1234;
        let metrics = obs.get_metrics(&tcp_packet, timestamp, port).await;
        assert!(metrics.is_none());
    }

    /// `Handler` stub that accepts every packet and does nothing with it.
    struct MockHandler;

    impl MockHandler {
        fn new() -> Self {
            MockHandler
        }
    }

    impl Handler<Vec<u8>> for MockHandler {
        async fn port(&self) -> u16 {
            1234
        }

        async fn parse_packet(&self, buf: Vec<u8>) -> Result<Vec<u8>> {
            Ok(buf)
        }

        async fn process(&self, _input: Vec<u8>, _metrics: Option<Metrics>) -> Result<()> {
            Ok(())
        }
    }

    /// Runs the capture loop against the mock reader, stops it via the watch
    /// channel, and checks that it returned cleanly without tracking anything.
    #[tokio::test]
    async fn test_capture_packets() {
        let reader = MockPacketReader {
            // TODO: send a fake tcp packet.
            // These 20 bytes are a bare IPv4 header (TCP, 127.0.0.1 -> 127.0.0.1)
            // with no Ethernet framing in front of it, so the observer is expected
            // to ignore the frame and leave `syn_packets` empty.
            packets: vec![vec![
                0x45, 0x00, 0x00, 0x3c, 0x00, 0x00, 0x40, 0x00, 0x40, 0x06, 0x00, 0x00, 0x7f, 0x00,
                0x00, 0x01, 0x7f, 0x00, 0x00, 0x01,
            ]],
        };
        let handler = Arc::new(Mutex::new(MockHandler::new()));
        let obs = Arc::new(Mutex::new(Observer::new(ObsConfig::default())));

        // Grab the sender first: the capture task holds the observer lock for
        // its whole lifetime, so it cannot be reached through `obs` afterwards.
        let stop_tx = obs.lock().await.stop_tx.clone();
        let obs_clone = Arc::clone(&obs);

        // Start the packet capture in a separate task.
        let capture_task = tokio::spawn(async move {
            obs_clone
                .lock()
                .await
                .capture_packets(reader, handler)
                .await
        });

        // Run the capture for a short duration and then signal stop.
        tokio::time::sleep(Duration::from_secs(1)).await;
        let _ = stop_tx.send(true);

        // The outer result only says whether the task panicked; the inner one is
        // what `capture_packets` actually returned. Check both.
        let res = capture_task.await.expect("capture task panicked");
        assert!(res.is_ok());

        // Nothing should have been recorded for the non-TCP fixture frame.
        let obs = obs.lock().await;
        let syn_packets = obs.syn_packets.lock().await;
        assert_eq!(syn_packets.len(), 0);
    }
}

0 comments on commit a5d1539

Please sign in to comment.