From 6baeb6dc0122d4d5c0fbe8af1442930bbc7320c1 Mon Sep 17 00:00:00 2001 From: Vladislav Mamon Date: Mon, 26 Feb 2024 02:41:30 +0300 Subject: [PATCH] feat: mvp --- Cargo.toml | 10 ++++ src/cli.rs | 81 +++++++++++++++++++++++++++ src/coord.rs | 73 ++++++++++++++++++++++++ src/filters.rs | 65 ++++++++++++++++++++++ src/lib.rs | 5 ++ src/main.rs | 71 +++++++++++++++++++++++- src/pinger.rs | 98 +++++++++++++++++++++++++++++++++ src/relays.rs | 147 +++++++++++++++++++++++++++++++++++++++++++++++++ 8 files changed, 548 insertions(+), 2 deletions(-) create mode 100644 src/cli.rs create mode 100644 src/coord.rs create mode 100644 src/filters.rs create mode 100644 src/lib.rs create mode 100644 src/pinger.rs create mode 100644 src/relays.rs diff --git a/Cargo.toml b/Cargo.toml index 9297584..8bb4fb4 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -7,6 +7,16 @@ edition = "2021" # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html +[dependencies] +anyhow = "1.0.80" +clap = { version = "4.5.1", features = ["derive"] } +indicatif = "0.17.8" +reqwest = { version = "0.11.24", features = ["json"] } +serde = { version = "1.0.197", features = ["derive"] } +serde_json = "1.0.114" +thiserror = "1.0.57" +tokio = { version = "1.36.0", features = ["full"] } + [profile.release] lto = "thin" panic = "abort" diff --git a/src/cli.rs b/src/cli.rs new file mode 100644 index 0000000..2362801 --- /dev/null +++ b/src/cli.rs @@ -0,0 +1,81 @@ +use std::time::Duration; + +use clap::builder::PossibleValue; +use clap::{Parser, ValueEnum}; +use indicatif::{ProgressBar, ProgressStyle}; + +use crate::relays::Protocol; + +#[derive(Parser, Debug)] +#[command(version, about, long_about = None)] +pub struct Cli { + /// Filter servers by used protocol. + #[arg(short, long, value_parser = clap::value_parser!(Protocol))] + pub protocol: Option, + + /// Filter servers by maximum physical distance (in km). + #[arg(short, long, default_value_t = 500)] + pub distance: usize, + + /// Filter servers by maximum rtt (in ms). + #[arg(short, long)] + pub rtt: Option, + + /// How many pings to send for each relay. + #[arg(short, long, default_value_t = 4)] + pub count: usize, + + /// Specify ping timeout (in ms). + #[arg(long, default_value_t = 750)] + pub timeout: u64, + + /// Specify the latitude. + #[arg(long, requires = "longitude")] + pub latitude: Option, + + /// Specify the longitude. + #[arg(long, requires = "latitude")] + pub longitude: Option, +} + +impl ValueEnum for Protocol { + fn value_variants<'a>() -> &'a [Self] { + &[Self::OpenVPN, Self::WireGuard] + } + + fn to_possible_value(&self) -> Option { + Some(match self { + | Protocol::OpenVPN => PossibleValue::new("openvpn"), + | Protocol::WireGuard => PossibleValue::new("wireguard"), + }) + } +} + +/// Small wrapper around the `indicatif` spinner. +pub struct Spinner { + spinner: ProgressBar, +} + +impl Spinner { + pub fn new() -> Self { + let style = ProgressStyle::default_spinner() + .tick_strings(&[" ", "· ", "·· ", "···", " ··", " ·", " "]); + + let spinner = ProgressBar::new_spinner(); + + spinner.set_style(style); + spinner.enable_steady_tick(Duration::from_millis(150)); + + Self { spinner } + } + + /// Sets the message of the spinner. + pub fn set_message(&self, message: &'static str) { + self.spinner.set_message(message); + } + + /// Stops the spinner and clears the message. + pub fn stop(&self) { + self.spinner.finish_and_clear(); + } +} diff --git a/src/coord.rs b/src/coord.rs new file mode 100644 index 0000000..f623efe --- /dev/null +++ b/src/coord.rs @@ -0,0 +1,73 @@ +use serde::{Deserialize, Serialize}; +use serde_json::Value; +use thiserror::Error; + +#[derive(Debug, Error)] +pub enum CoordError { + #[error("Failed to fetch coordinates")] + FetchFailed(reqwest::Error), + #[error("Failed to parse response")] + ParseResponseFailed(reqwest::Error), + #[error("Failed to get latitude and longitude from the response")] + GetCoordsFailed, +} + +/// Represents a point on Earth. +#[derive(Clone, Debug, Serialize, Deserialize)] +pub struct Coord { + latitude: f64, + longitude: f64, +} + +impl Coord { + /// Constructs a new `Coord`. + pub fn new(latitude: f64, longitude: f64) -> Self { + Self { + latitude, + longitude, + } + } + + /// Fetches the current coordinates using the Mullvad API. + pub async fn fetch() -> Result { + let response = reqwest::get("https://am.i.mullvad.net/json") + .await + .map_err(CoordError::FetchFailed)?; + + let data = response + .json::() + .await + .map_err(CoordError::ParseResponseFailed)?; + + let lat = data["latitude"].as_f64(); + let lon = data["longitude"].as_f64(); + + lat + .zip(lon) + .map(|(latitude, longitude)| Self::new(latitude, longitude)) + .ok_or_else(|| CoordError::GetCoordsFailed) + } + + /// Finds the distance (in meters) between two coordinates using the haversine formula. + pub fn distance_to(&self, other: &Self) -> f64 { + // Earth radius in meters. This is *average*, since Earth is not a sphere, but a spheroid. + const R: f64 = 6_371_000f64; + + // Turn latitudes and longitudes into radians. + let phi1 = self.latitude.to_radians(); + let phi2 = other.latitude.to_radians(); + let lam1 = self.longitude.to_radians(); + let lam2 = other.longitude.to_radians(); + + // The haversine function. Computes half a versine of the given angle `theta`. + let haversine = |theta: f64| (1.0 - theta.cos()) / 2.0; + + let hav_delta_phi = haversine(phi2 - phi1); + let hav_delta_lam = phi1.cos() * phi2.cos() * haversine(lam2 - lam1); + let hav_delta = hav_delta_phi + hav_delta_lam; + + let distance = (2.0 * R * hav_delta.sqrt().asin() * 1_000.0).round() / 1_000.0; + + distance + } +} diff --git a/src/filters.rs b/src/filters.rs new file mode 100644 index 0000000..6eac381 --- /dev/null +++ b/src/filters.rs @@ -0,0 +1,65 @@ +use std::fmt::Debug; + +use crate::coord::Coord; +use crate::relays::{Protocol, Relay}; + +#[derive(PartialEq)] +pub enum FilterStage { + /// Such filters apply when loading them from the relays file. + Load, + /// Such filters apply after pinging relays. + Ping, +} + +/// Filter trait to dynamically dispatch filters. +pub trait Filter: Debug { + /// Returns the stage of the filter. + fn stage(&self) -> FilterStage; + + /// Filter predicate. + fn matches(&self, relay: &Relay) -> bool; +} + +#[derive(Debug)] +pub struct FilterByDistance { + user: Coord, + distance: f64, +} + +impl FilterByDistance { + pub fn new(user: Coord, distance: f64) -> Self { + Self { user, distance } + } +} + +impl Filter for FilterByDistance { + fn stage(&self) -> FilterStage { + FilterStage::Load + } + + fn matches(&self, relay: &Relay) -> bool { + (relay.coord.distance_to(&self.user) / 1_000.0) < self.distance + } +} + +#[derive(Debug)] +pub struct FilterByProtocol(Option); + +impl FilterByProtocol { + pub fn new(protocol: Option) -> Self { + Self(protocol) + } +} + +impl Filter for FilterByProtocol { + fn stage(&self) -> FilterStage { + FilterStage::Load + } + + fn matches(&self, relay: &Relay) -> bool { + self + .0 + .as_ref() + .map_or(true, |use_protocol| relay.protocol == *use_protocol) + } +} diff --git a/src/lib.rs b/src/lib.rs new file mode 100644 index 0000000..1e3c8de --- /dev/null +++ b/src/lib.rs @@ -0,0 +1,5 @@ +pub mod cli; +pub mod coord; +pub mod filters; +pub mod pinger; +pub mod relays; diff --git a/src/main.rs b/src/main.rs index 80a1832..99ac4e0 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,3 +1,70 @@ -fn main() { - println!("Hello, world!"); +use std::thread; +use std::time::Duration; + +use clap::Parser; +use pingmole::cli::{Cli, Spinner}; +use pingmole::coord::Coord; +use pingmole::filters::{FilterByDistance, FilterByProtocol}; +use pingmole::pinger::RelayPinger; +use pingmole::relays::RelaysLoader; + +#[tokio::main] +async fn main() -> anyhow::Result<()> { + let cli = Cli::parse(); + let spinner = Spinner::new(); + + // ----------------------------------------------------------------------------------------------- + // 1. Get current location, either via arguments or via Mullvad API. + spinner.set_message("Getting current location"); + + let user = match cli.latitude.zip(cli.longitude) { + | Some((latitude, longitude)) => Coord::new(latitude, longitude), + | None => Coord::fetch().await?, + }; + + thread::sleep(std::time::Duration::from_secs(1)); + + // ----------------------------------------------------------------------------------------------- + // 2. Load relays from file and filter them. + spinner.set_message("Loading relays"); + + let loader = RelaysLoader::new( + RelaysLoader::resolve_path()?, + vec![ + Box::new(FilterByDistance::new(user, cli.distance as f64)), + Box::new(FilterByProtocol::new(cli.protocol)), + ], + ); + + let relays = loader.load()?; + + thread::sleep(std::time::Duration::from_secs(1)); + + // ----------------------------------------------------------------------------------------------- + // 3. Pinging relays. + spinner.set_message("Pinging relays"); + + let mut tasks = Vec::new(); + let mut timings = Vec::new(); + + for relay in relays { + let mut pinger = RelayPinger::new(relay); + + pinger.set_count(cli.count); + pinger.set_timeout(Duration::from_millis(cli.timeout)); + + tasks.push(tokio::spawn(pinger.execute())); + } + + for task in tasks { + timings.push(task.await?); + } + + // ----------------------------------------------------------------------------------------------- + // 4. Print results. + spinner.stop(); + + dbg!(timings); + + Ok(()) } diff --git a/src/pinger.rs b/src/pinger.rs new file mode 100644 index 0000000..825a59e --- /dev/null +++ b/src/pinger.rs @@ -0,0 +1,98 @@ +use tokio::net::TcpStream; +use tokio::time::{self, Duration, Instant, MissedTickBehavior}; + +use crate::relays::Relay; + +#[derive(Debug)] +pub struct RelayTimings { + /// Relay. + relay: Relay, + /// Relay timings. + timings: Vec, +} + +impl RelayTimings { + pub fn new(relay: Relay, timings: Vec) -> Self { + Self { relay, timings } + } + + pub fn relay(&self) -> &Relay { + &self.relay + } + + pub fn rtt(&self) -> Option { + match self.timings.len() { + | 0 => None, + | len => Some(self.timings.iter().sum::() / len as u32), + } + } +} + +#[derive(Debug)] +pub struct RelayPinger { + /// Relay to ping. + relay: Relay, + /// How many times to ping the relay. Defaults to 4. + count: usize, + /// How long to wait before timing out a ping. Defaults to 750 ms. + timeout: Duration, + /// How long to wait between pings. Defaults to 1 second. + interval: Duration, +} + +impl RelayPinger { + pub fn new(relay: Relay) -> Self { + Self { + relay, + count: 4, + timeout: Duration::from_millis(750), + interval: Duration::from_millis(1_000), + } + } + + pub fn set_count(&mut self, count: usize) { + self.count = count; + } + + pub fn set_timeout(&mut self, timeout: Duration) { + self.timeout = timeout; + } + + pub fn set_interval(&mut self, interval: Duration) { + self.interval = interval; + } + + pub async fn execute(self) -> RelayTimings { + // I'm not entirely sure about hardcoding port 80, but it seems to be open on servers I checked. + let ping_addr = format!("{}:80", self.relay.ip); + + // Set up the interval... + let mut interval = time::interval(self.interval); + + // ...and use a different behavior for missed ticks. I'm not really sure why, but this works + // better than the default one. + interval.set_missed_tick_behavior(MissedTickBehavior::Delay); + + let mut timings = Vec::new(); + + for _ in 1..=self.count { + interval.tick().await; + + let start = Instant::now(); + let stream = TcpStream::connect(&ping_addr); + + match time::timeout(self.timeout, stream).await { + | Ok(Ok(..)) => { + let end = Instant::now(); + let elapsed = end.duration_since(start); + + timings.push(elapsed); + }, + | Ok(Err(..)) => continue, + | Err(..) => continue, + } + } + + RelayTimings::new(self.relay, timings) + } +} diff --git a/src/relays.rs b/src/relays.rs new file mode 100644 index 0000000..090e1d4 --- /dev/null +++ b/src/relays.rs @@ -0,0 +1,147 @@ +use std::env::consts; +use std::fmt::Debug; +use std::fs; +use std::path::PathBuf; + +use serde::{Deserialize, Serialize}; +use serde_json::Value; +use thiserror::Error; + +use crate::coord::Coord; +use crate::filters::{Filter, FilterStage}; + +#[derive(Debug, Error)] +pub enum RelaysError { + #[error("Failed to read the relay file: {path}")] + ReadFileFailed { + path: PathBuf, + source: std::io::Error, + }, + #[error("Failed to parse the relay file")] + ParseFileFailed(serde_json::Error), + #[error("Failed to parse the field {0}: it's either missing or malformed")] + ParseFieldFailed(String), + #[error("Unsupported system: {0}")] + UnsupportedSystem(String), +} + +#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)] +pub enum Protocol { + OpenVPN, + WireGuard, +} + +#[derive(Clone, Debug, Serialize, Deserialize)] +pub struct Relay { + pub ip: String, + pub city: String, + pub country: String, + pub coord: Coord, + pub protocol: Protocol, + pub is_active: bool, + pub is_mullvad_owned: bool, +} + +#[derive(Debug)] +pub struct RelaysLoader { + path: PathBuf, + filters: Vec>, +} + +impl RelaysLoader { + pub fn new(path: PathBuf, filters: Vec>) -> Self { + Self { path, filters } + } + + pub fn resolve_path() -> Result { + let path = match consts::OS { + | "linux" => "/var/cache/mullvad-vpn/relays.json", + | "macos" => "/Library/Caches/mullvad-vpn/relays.json", + | "windows" => "C:/ProgramData/Mullvad VPN/cache/relays.json", + | system => return Err(RelaysError::UnsupportedSystem(system.to_string())), + }; + + Ok(PathBuf::from(path)) + } + + /// Parses a protocol stored in the `endpoint_data` field of a relay, which can be either of the + /// following: + /// + /// ```json + /// "endpoint_data": "openvpn", + /// "endpoint_data": "bridge", + /// "endpoint_data": { + /// "wireguard": { + /// "public_key": "..." + /// } + /// } + /// ``` + /// + /// We actually not interested in those with "bridge", so skip them with other ones. + pub fn resolve_protocol(relay: &Value) -> Option { + match &relay["endpoint_data"] { + | Value::String(ref s) => s.eq("openvpn").then_some(Protocol::OpenVPN), + | Value::Object(o) => o.get("wireguard").map(|_| Protocol::WireGuard), + | _ => None, + } + } + + pub fn load(&self) -> Result, RelaysError> { + /// Simple macro helper to simplify accessing JSON fields and casting them. + macro_rules! get { + ($data:expr, $field:expr, $method:ident) => { + $data[$field] + .$method() + .ok_or_else(|| RelaysError::ParseFieldFailed(stringify!($field).into()))? + }; + } + + let mut locations = Vec::new(); + + // Read into a string. + let data = fs::read_to_string(&self.path).map_err(|source| { + RelaysError::ReadFileFailed { + path: self.path.clone(), + source, + } + })?; + + // Parse the string as arbitrary JSON. + let data = serde_json::from_str::(&data).map_err(RelaysError::ParseFileFailed)?; + + for country in get!(data, "countries", as_array) { + for city in get!(country, "cities", as_array) { + for relay in get!(city, "relays", as_array) { + // We only need relays that have either "openvpn" or "wireguard" protocols. + if let Some(protocol) = Self::resolve_protocol(&relay) { + let coord = Coord::new( + get!(city, "latitude", as_f64), + get!(city, "longitude", as_f64), + ); + + let relay = Relay { + coord, + protocol, + ip: get!(relay, "ipv4_addr_in", as_str).to_string(), + city: get!(city, "name", as_str).to_string(), + country: get!(country, "name", as_str).to_string(), + is_active: get!(relay, "active", as_bool), + is_mullvad_owned: get!(relay, "owned", as_bool), + }; + + if self + .filters + .iter() + .filter(|filter| filter.stage() == FilterStage::Load) + .all(|filter| filter.matches(&relay)) + { + locations.push(relay); + } + } + } + } + } + + Ok(locations) + } +}