Skip to content

Commit

Permalink
Make the ip2location argument immutable
Browse files Browse the repository at this point in the history
The APIs taking an ip2location database non longer requires a mutable
reference to the database. This change improves the ergonomics of the
API and removes the need for locking the database.
  • Loading branch information
msk committed Dec 30, 2024
1 parent b2c14a4 commit a00951e
Show file tree
Hide file tree
Showing 3 changed files with 35 additions and 40 deletions.
6 changes: 6 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,12 @@ Versioning](https://semver.org/spec/v2.0.0.html).
- Added `Account::theme` field to represent user's selected screen color theme
on the user interface.

### Changed

- The APIs taking an ip2location database non longer requires a mutable
reference to the database. This change improves the ergonomics of the API and
removes the need for locking the database.

## [0.33.1] - 2024-12-20

### Fixed
Expand Down
55 changes: 24 additions & 31 deletions src/event.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@ use std::{
fmt::{self},
net::IpAddr,
num::NonZeroU8,
sync::{Arc, Mutex, MutexGuard},
};

use aho_corasick::AhoCorasickBuilder;
Expand Down Expand Up @@ -473,7 +472,7 @@ impl Event {
/// not available.
pub fn matches(
&self,
locator: Option<Arc<Mutex<ip2location::DB>>>,
locator: Option<&ip2location::DB>,
filter: &EventFilter,
) -> Result<(bool, Option<Vec<TriageScore>>)> {
match self {
Expand Down Expand Up @@ -521,7 +520,7 @@ impl Event {

fn address_pair(
&self,
locator: Option<Arc<Mutex<ip2location::DB>>>,
locator: Option<&ip2location::DB>,
filter: &EventFilter,
) -> Result<(Option<IpAddr>, Option<IpAddr>)> {
let mut addr_pair = (None, None);
Expand Down Expand Up @@ -711,7 +710,7 @@ impl Event {

fn kind(
&self,
locator: Option<Arc<Mutex<ip2location::DB>>>,
locator: Option<&ip2location::DB>,
filter: &EventFilter,
) -> Result<Option<&'static str>> {
let mut kind = None;
Expand Down Expand Up @@ -963,25 +962,22 @@ impl Event {
/// # Errors
///
/// Returns an error if matching the event against the filter fails.
#[allow(clippy::needless_pass_by_value)] // function prototype must be the same as other `count_*` functions.
pub fn count_country(
&self,
counter: &mut HashMap<String, usize>,
locator: Option<Arc<Mutex<ip2location::DB>>>,
locator: Option<&ip2location::DB>,
filter: &EventFilter,
) -> Result<()> {
let addr_pair = self.address_pair(locator.clone(), filter)?;
let addr_pair = self.address_pair(locator, filter)?;

let mut src_country = "ZZ".to_string();
let mut dst_country = "ZZ".to_string();
if let Some(mutex) = &locator {
if let Ok(mut guarded_locator) = mutex.lock() {
if let Some(src_addr) = addr_pair.0 {
src_country = find_ip_country(&mut guarded_locator, src_addr);
}
if let Some(dst_addr) = addr_pair.1 {
dst_country = find_ip_country(&mut guarded_locator, dst_addr);
}
if let Some(locator) = locator {
if let Some(src_addr) = addr_pair.0 {
src_country = find_ip_country(locator, src_addr);
}
if let Some(dst_addr) = addr_pair.1 {
dst_country = find_ip_country(locator, dst_addr);
}
}
if src_country != dst_country && addr_pair.0.is_some() && addr_pair.1.is_some() {
Expand All @@ -1008,7 +1004,7 @@ impl Event {
pub fn count_category(
&self,
counter: &mut HashMap<EventCategory, usize>,
locator: Option<Arc<Mutex<ip2location::DB>>>,
locator: Option<&ip2location::DB>,
filter: &EventFilter,
) -> Result<()> {
let mut category = None;
Expand Down Expand Up @@ -1217,7 +1213,7 @@ impl Event {
pub fn count_ip_address(
&self,
counter: &mut HashMap<IpAddr, usize>,
locator: Option<Arc<Mutex<ip2location::DB>>>,
locator: Option<&ip2location::DB>,
filter: &EventFilter,
) -> Result<()> {
let addr_pair = self.address_pair(locator, filter)?;
Expand All @@ -1240,7 +1236,7 @@ impl Event {
pub fn count_ip_address_pair(
&self,
counter: &mut HashMap<(IpAddr, IpAddr), usize>,
locator: Option<Arc<Mutex<ip2location::DB>>>,
locator: Option<&ip2location::DB>,
filter: &EventFilter,
) -> Result<()> {
let addr_pair = self.address_pair(locator, filter)?;
Expand All @@ -1265,10 +1261,10 @@ impl Event {
pub fn count_ip_address_pair_and_kind(
&self,
counter: &mut HashMap<(IpAddr, IpAddr, &'static str), usize>,
locator: Option<Arc<Mutex<ip2location::DB>>>,
locator: Option<&ip2location::DB>,
filter: &EventFilter,
) -> Result<()> {
let addr_pair = self.address_pair(locator.clone(), filter)?;
let addr_pair = self.address_pair(locator, filter)?;
let kind = self.kind(locator, filter)?;

if let Some(src_addr) = addr_pair.0 {
Expand All @@ -1293,7 +1289,7 @@ impl Event {
pub fn count_src_ip_address(
&self,
counter: &mut HashMap<IpAddr, usize>,
locator: Option<Arc<Mutex<ip2location::DB>>>,
locator: Option<&ip2location::DB>,
filter: &EventFilter,
) -> Result<()> {
let addr_pair = self.address_pair(locator, filter)?;
Expand All @@ -1313,7 +1309,7 @@ impl Event {
pub fn count_dst_ip_address(
&self,
counter: &mut HashMap<IpAddr, usize>,
locator: Option<Arc<Mutex<ip2location::DB>>>,
locator: Option<&ip2location::DB>,
filter: &EventFilter,
) -> Result<()> {
let addr_pair = self.address_pair(locator, filter)?;
Expand All @@ -1333,7 +1329,7 @@ impl Event {
pub fn count_kind(
&self,
counter: &mut HashMap<String, usize>,
locator: Option<Arc<Mutex<ip2location::DB>>>,
locator: Option<&ip2location::DB>,
filter: &EventFilter,
) -> Result<()> {
let kind = if let Event::HttpThreat(event) = self {
Expand Down Expand Up @@ -1361,7 +1357,7 @@ impl Event {
pub fn count_level(
&self,
counter: &mut HashMap<NonZeroU8, usize>,
locator: Option<Arc<Mutex<ip2location::DB>>>,
locator: Option<&ip2location::DB>,
filter: &EventFilter,
) -> Result<()> {
let mut level = None;
Expand Down Expand Up @@ -1571,7 +1567,7 @@ impl Event {
&self,
counter: &mut HashMap<u32, usize>,
networks: &[Network],
locator: Option<Arc<Mutex<ip2location::DB>>>,
locator: Option<&ip2location::DB>,
filter: &EventFilter,
) -> Result<()> {
let addr_pair = self.address_pair(locator, filter)?;
Expand Down Expand Up @@ -2536,7 +2532,8 @@ pub enum TrafficDirection {
To,
}

pub fn find_ip_country(locator: &mut ip2location::DB, addr: IpAddr) -> String {
#[must_use]
pub fn find_ip_country(locator: &ip2location::DB, addr: IpAddr) -> String {
locator
.ip_lookup(addr)
.map(|r| get_record_country_short_name(&r))
Expand All @@ -2545,11 +2542,7 @@ pub fn find_ip_country(locator: &mut ip2location::DB, addr: IpAddr) -> String {
.unwrap_or_else(|| "XX".to_string())
}

fn eq_ip_country(
locator: &mut MutexGuard<ip2location::DB>,
addr: IpAddr,
country: [u8; 2],
) -> bool {
fn eq_ip_country(locator: &ip2location::DB, addr: IpAddr, country: [u8; 2]) -> bool {
locator
.ip_lookup(addr)
.ok()
Expand Down
14 changes: 5 additions & 9 deletions src/event/common.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,9 @@ use std::{
fmt::{self, Formatter},
net::IpAddr,
num::NonZeroU8,
sync::{Arc, Mutex},
};

use anyhow::{bail, Result};
use anyhow::Result;
use num_traits::ToPrimitive;
use serde::{Deserialize, Serialize};

Expand Down Expand Up @@ -48,7 +47,7 @@ pub(super) trait Match {
/// not available.
fn matches(
&self,
locator: Option<Arc<Mutex<ip2location::DB>>>,
locator: Option<&ip2location::DB>,
filter: &EventFilter,
) -> Result<(bool, Option<Vec<TriageScore>>)> {
if !self.kind_matches(filter) {
Expand Down Expand Up @@ -78,7 +77,7 @@ pub(super) trait Match {
fn other_matches(
&self,
filter: &EventFilter,
locator: Option<Arc<Mutex<ip2location::DB>>>,
locator: Option<&ip2location::DB>,
) -> Result<(bool, Option<Vec<TriageScore>>)> {
if let Some(customers) = &filter.customers {
if customers.iter().all(|customer| {
Expand Down Expand Up @@ -138,12 +137,9 @@ pub(super) trait Match {

if let Some(countries) = &filter.countries {
if let Some(locator) = locator {
let Ok(mut locator) = locator.lock() else {
bail!("IP location database unavailable")
};
if countries.iter().all(|country| {
!eq_ip_country(&mut locator, self.src_addr(), *country)
&& !eq_ip_country(&mut locator, self.dst_addr(), *country)
!eq_ip_country(locator, self.src_addr(), *country)
&& !eq_ip_country(locator, self.dst_addr(), *country)
}) {
return Ok((false, None));
}
Expand Down

0 comments on commit a00951e

Please sign in to comment.