diff --git a/.sqlx/query-eb982489a09c45fcaec74346f499c657d3018d01be7e095683a40160d533f410.json b/.sqlx/query-00454ac37de808986d66b6abd808fb648b288f49586113cea21d889dca9655b9.json similarity index 78% rename from .sqlx/query-eb982489a09c45fcaec74346f499c657d3018d01be7e095683a40160d533f410.json rename to .sqlx/query-00454ac37de808986d66b6abd808fb648b288f49586113cea21d889dca9655b9.json index 5e191c1e1..e91d531de 100644 --- a/.sqlx/query-eb982489a09c45fcaec74346f499c657d3018d01be7e095683a40160d533f410.json +++ b/.sqlx/query-00454ac37de808986d66b6abd808fb648b288f49586113cea21d889dca9655b9.json @@ -1,6 +1,6 @@ { "db_name": "PostgreSQL", - "query": "\n SELECT instance_name, main_logo_url, nav_logo_url, wireguard_enabled, webhooks_enabled, worker_enabled, openid_enabled FROM settings WHERE id = 1;\n ", + "query": "SELECT instance_name, main_logo_url, nav_logo_url, wireguard_enabled, webhooks_enabled, worker_enabled, openid_enabled FROM settings WHERE id = 1", "describe": { "columns": [ { @@ -52,5 +52,5 @@ false ] }, - "hash": "eb982489a09c45fcaec74346f499c657d3018d01be7e095683a40160d533f410" + "hash": "00454ac37de808986d66b6abd808fb648b288f49586113cea21d889dca9655b9" } diff --git a/.sqlx/query-cdda0d8e9b34aef0728fc390bf77a3211b708f23ecdb3df5cada3d628280a025.json b/.sqlx/query-d8b3cbc7317bfdee111b80accd5d87781e7fa62a39998d54bf79d736ed7a8827.json similarity index 50% rename from .sqlx/query-cdda0d8e9b34aef0728fc390bf77a3211b708f23ecdb3df5cada3d628280a025.json rename to .sqlx/query-d8b3cbc7317bfdee111b80accd5d87781e7fa62a39998d54bf79d736ed7a8827.json index c17124309..2c1f0f69e 100644 --- a/.sqlx/query-cdda0d8e9b34aef0728fc390bf77a3211b708f23ecdb3df5cada3d628280a025.json +++ b/.sqlx/query-d8b3cbc7317bfdee111b80accd5d87781e7fa62a39998d54bf79d736ed7a8827.json @@ -1,6 +1,6 @@ { "db_name": "PostgreSQL", - "query": "\n SELECT d.wireguard_pubkey as pubkey, array[host(wnd.wireguard_ip)] as \"allowed_ips!: Vec\" FROM wireguard_network_device wnd\n JOIN device d\n ON wnd.device_id = d.id\n WHERE wireguard_network_id = $1\n ORDER BY d.id ASC\n ", + "query": "SELECT d.wireguard_pubkey as pubkey, array[host(wnd.wireguard_ip)] as \"allowed_ips!: Vec\" FROM wireguard_network_device wnd JOIN device d ON wnd.device_id = d.id WHERE wireguard_network_id = $1 ORDER BY d.id ASC", "describe": { "columns": [ { @@ -24,5 +24,5 @@ null ] }, - "hash": "cdda0d8e9b34aef0728fc390bf77a3211b708f23ecdb3df5cada3d628280a025" + "hash": "d8b3cbc7317bfdee111b80accd5d87781e7fa62a39998d54bf79d736ed7a8827" } diff --git a/Cargo.lock b/Cargo.lock index e52b4edee..9f38afd68 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -714,9 +714,9 @@ dependencies = [ [[package]] name = "const-oid" -version = "0.9.5" +version = "0.9.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "28c122c3980598d243d63d9a704629a2d748d101f278052ff068be5a4423ab6f" +checksum = "c2459377285ad874054d797f3ccebf984978aa39129f6eafde5cdc8315b612f8" [[package]] name = "const-random" @@ -1780,11 +1780,11 @@ dependencies = [ [[package]] name = "home" -version = "0.5.5" +version = "0.5.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5444c27eef6923071f7ebcc33e3444508466a76f7a2b93da00ed6e19f30c1ddb" +checksum = "e3d1354bf6b7235cb4a0576c2619fd4ed18183f689b12b006a0ee7329eeff9a5" dependencies = [ - "windows-sys 0.48.0", + "windows-sys 0.52.0", ] [[package]] @@ -4478,18 +4478,18 @@ dependencies = [ [[package]] name = "thiserror" -version = "1.0.50" +version = "1.0.51" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f9a7210f5c9a7156bb50aa36aed4c95afb51df0df00713949448cf9e97d382d2" +checksum = "f11c217e1416d6f036b870f14e0413d480dbf28edbee1f877abaf0206af43bb7" dependencies = [ "thiserror-impl", ] [[package]] name = "thiserror-impl" -version = "1.0.50" +version = "1.0.51" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "266b2e40bc00e5a6c09c3584011e08b06f123c00362c92b975ba9843aaaa14b8" +checksum = "01742297787513b79cf8e29d1056ede1313e2420b7b3b15d0a768b4921f549df" dependencies = [ "proc-macro2", "quote", diff --git a/Cargo.toml b/Cargo.toml index 60ef1d613..2678ab831 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -30,7 +30,7 @@ humantime = "2.1" # match ipnetwork version from sqlx ipnetwork = { version = "0.20", features = ["serde"] } jsonwebtoken = "9.2" -ldap3 = "0.11" +ldap3 = { version = "0.11", default-features = false, features = ["tls"] } lettre = { version = "0.11", features = ["tokio1", "tokio1-native-tls"] } md4 = "0.10" otpauth = "0.4" diff --git a/src/appstate.rs b/src/appstate.rs index 7a662b612..feba4d36d 100644 --- a/src/appstate.rs +++ b/src/appstate.rs @@ -36,13 +36,14 @@ pub struct AppState { } impl AppState { - pub fn trigger_action(&self, event: AppEvent) { + pub(crate) fn trigger_action(&self, event: AppEvent) { let event_name = event.name().to_owned(); match self.tx.send(event) { Ok(()) => info!("Sent trigger {event_name}"), Err(err) => error!("Error sending trigger {event_name}: {err}"), } } + /// Handle webhook events async fn handle_triggers(pool: DbPool, mut rx: UnboundedReceiver) { let reqwest_client = Client::builder().user_agent("reqwest").build().unwrap(); diff --git a/src/bin/defguard.rs b/src/bin/defguard.rs index 60f25c616..851d1831a 100644 --- a/src/bin/defguard.rs +++ b/src/bin/defguard.rs @@ -1,3 +1,12 @@ +use std::{ + fs::read_to_string, + sync::{Arc, Mutex}, +}; + +use secrecy::ExposeSecret; +use tokio::sync::{broadcast, mpsc::unbounded_channel}; +use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt}; + use defguard::{ auth::failed_login::FailedLoginMap, config::{Command, DefGuardConfig}, @@ -10,13 +19,6 @@ use defguard::{ wireguard_stats_purge::run_periodic_stats_purge, SERVER_CONFIG, }; -use secrecy::ExposeSecret; -use std::{ - fs::read_to_string, - sync::{Arc, Mutex}, -}; -use tokio::sync::{broadcast, mpsc::unbounded_channel}; -use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt}; #[macro_use] extern crate tracing; diff --git a/src/db/models/device.rs b/src/db/models/device.rs index 98692d614..f1c9fc9a3 100644 --- a/src/db/models/device.rs +++ b/src/db/models/device.rs @@ -655,11 +655,11 @@ mod test { network.save(&pool).await.unwrap(); let mut user = User::new( - "testuser".to_string(), + "testuser", Some("hunter2"), - "Tester".to_string(), - "Test".to_string(), - "test@test.com".to_string(), + "Tester", + "Test", + "test@test.com", None, ); user.save(&pool).await.unwrap(); diff --git a/src/db/models/group.rs b/src/db/models/group.rs index b8a49e914..2ee5f91ff 100644 --- a/src/db/models/group.rs +++ b/src/db/models/group.rs @@ -11,7 +11,7 @@ pub struct Group { impl Group { #[must_use] - pub fn new(name: &str) -> Self { + pub fn new>(name: S) -> Self { Self { id: None, name: name.into(), @@ -48,7 +48,7 @@ impl Group { } } - pub async fn fetch_all_members<'e, E>(&self, executor: E) -> Result, SqlxError> + pub async fn members<'e, E>(&self, executor: E) -> Result, SqlxError> where E: PgExecutor<'e>, { @@ -233,11 +233,11 @@ mod test { group.save(&pool).await.unwrap(); let mut user = User::new( - "hpotter".into(), + "hpotter", Some("pass123"), - "Potter".into(), - "Harry".into(), - "h.potter@hogwart.edu.uk".into(), + "Potter", + "Harry", + "h.potter@hogwart.edu.uk", None, ); user.save(&pool).await.unwrap(); diff --git a/src/db/models/mod.rs b/src/db/models/mod.rs index fc29365d1..6e46502d9 100644 --- a/src/db/models/mod.rs +++ b/src/db/models/mod.rs @@ -268,11 +268,11 @@ mod test { #[sqlx::test] async fn test_user_info(pool: DbPool) { let mut user = User::new( - "hpotter".into(), + "hpotter", Some("pass123"), - "Potter".into(), - "Harry".into(), - "h.potter@hogwart.edu.uk".into(), + "Potter", + "Harry", + "h.potter@hogwart.edu.uk", None, ); user.save(&pool).await.unwrap(); diff --git a/src/db/models/settings.rs b/src/db/models/settings.rs index b9eb7f489..ed8257ba5 100644 --- a/src/db/models/settings.rs +++ b/src/db/models/settings.rs @@ -1,7 +1,7 @@ use std::collections::HashMap; use model_derive::Model; -use sqlx::{query, Error as SqlxError, PgExecutor, Type}; +use sqlx::{query, query_as, Error as SqlxError, PgExecutor, Type}; use struct_patch::Patch; use super::DbPool; @@ -46,7 +46,7 @@ pub struct Settings { pub enrollment_welcome_email: Option, pub enrollment_welcome_email_subject: Option, pub enrollment_use_welcome_message_as_email: bool, - // Instance uuid needed for desktop client + // Instance UUID needed for desktop client #[serde(skip)] pub uuid: uuid::Uuid, // LDAP @@ -110,14 +110,7 @@ impl Settings { } } -#[derive(Debug, Serialize, Clone)] -pub struct SettingsBranding { - pub instance_name: String, - pub main_logo_url: String, - pub nav_logo_url: String, -} - -#[derive(Debug, Serialize, Clone)] +#[derive(Serialize)] pub struct SettingsEssentials { pub instance_name: String, pub main_logo_url: String, @@ -129,11 +122,18 @@ pub struct SettingsEssentials { } impl SettingsEssentials { - pub async fn get_settings_essentials(pool: &DbPool) -> Result { - let res = sqlx::query_as!(SettingsEssentials, r#" - SELECT instance_name, main_logo_url, nav_logo_url, wireguard_enabled, webhooks_enabled, worker_enabled, openid_enabled FROM settings WHERE id = 1; - "#).fetch_one(pool).await?; - Ok(res) + pub(crate) async fn get_settings_essentials<'e, E>(executor: E) -> Result + where + E: PgExecutor<'e>, + { + query_as!( + SettingsEssentials, + "SELECT instance_name, main_logo_url, nav_logo_url, wireguard_enabled, \ + webhooks_enabled, worker_enabled, openid_enabled \ + FROM settings WHERE id = 1" + ) + .fetch_one(executor) + .await } } @@ -152,7 +152,7 @@ impl From for SettingsEssentials { } mod defaults { - pub const WELCOME_MESSAGE: &str = "Dear {{ first_name }} {{ last_name }}, + pub static WELCOME_MESSAGE: &str = "Dear {{ first_name }} {{ last_name }}, By completing the enrollment process, you now have now access to all company systems. @@ -187,5 +187,5 @@ Sent by defguard {{ defguard_version }} Star us on GitHub! https://github.com/defguard/defguard\ "; - pub const WELCOME_EMAIL_SUBJECT: &str = "[defguard] Welcome message after enrollment"; + pub static WELCOME_EMAIL_SUBJECT: &str = "[defguard] Welcome message after enrollment"; } diff --git a/src/db/models/user.rs b/src/db/models/user.rs index 40d294989..aa3902886 100644 --- a/src/db/models/user.rs +++ b/src/db/models/user.rs @@ -93,23 +93,23 @@ impl User { } #[must_use] - pub fn new( - username: String, + pub fn new>( + username: S, password: Option<&str>, - last_name: String, - first_name: String, - email: String, + last_name: S, + first_name: S, + email: S, phone: Option, ) -> Self { let password_hash = password.and_then(|password_hash| Self::hash_password(password_hash).ok()); Self { id: None, - username, + username: username.into(), password_hash, - last_name, - first_name, - email, + last_name: last_name.into(), + first_name: first_name.into(), + email: email.into(), phone, ssh_key: None, pgp_key: None, @@ -152,7 +152,10 @@ impl User { } /// Generate new TOTP secret, save it, then return it as RFC 4648 base32-encoded string. - pub async fn new_totp_secret(&mut self, pool: &DbPool) -> Result { + pub async fn new_totp_secret<'e, E>(&mut self, executor: E) -> Result + where + E: PgExecutor<'e>, + { let secret = gen_totp_secret(); if let Some(id) = self.id { query!( @@ -160,7 +163,7 @@ impl User { secret, id ) - .execute(pool) + .execute(executor) .await?; } let secret_base32 = TOTP::from_bytes(&secret).base32_secret(); @@ -169,7 +172,10 @@ impl User { } /// Generate new email secret, similar to TOTP secret above, but don't return generated value. - pub async fn new_email_secret(&mut self, pool: &DbPool) -> Result<(), SqlxError> { + pub async fn new_email_secret<'e, E>(&mut self, executor: E) -> Result<(), SqlxError> + where + E: PgExecutor<'e>, + { let email_secret = gen_totp_secret(); if let Some(id) = self.id { query!( @@ -177,7 +183,7 @@ impl User { email_secret, id ) - .execute(pool) + .execute(executor) .await?; } self.email_mfa_secret = Some(email_secret); @@ -783,7 +789,7 @@ impl User { // if new user was created add them to admin group (ID 1) if let Some(new_user_id) = result { - info!("New admin user was created, adding to Admin group..."); + info!("New admin user has been created, adding to Admin group..."); query("INSERT INTO group_user (group_id, user_id) VALUES (1, $1)") .bind(new_user_id) .execute(pool) @@ -801,11 +807,11 @@ mod test { #[sqlx::test] async fn test_user(pool: DbPool) { let mut user = User::new( - "hpotter".into(), + "hpotter", Some("pass123"), - "Potter".into(), - "Harry".into(), - "h.potter@hogwart.edu.uk".into(), + "Potter", + "Harry", + "h.potter@hogwart.edu.uk", None, ); user.save(&pool).await.unwrap(); @@ -830,21 +836,21 @@ mod test { #[sqlx::test] async fn test_all_users(pool: DbPool) { let mut harry = User::new( - "hpotter".into(), + "hpotter", Some("pass123"), - "Potter".into(), - "Harry".into(), - "h.potter@hogwart.edu.uk".into(), + "Potter", + "Harry", + "h.potter@hogwart.edu.uk", None, ); harry.save(&pool).await.unwrap(); let mut albus = User::new( - "adumbledore".into(), + "adumbledore", Some("magic!"), - "Dumbledore".into(), - "Albus".into(), - "a.dumbledore@hogwart.edu.uk".into(), + "Dumbledore", + "Albus", + "a.dumbledore@hogwart.edu.uk", None, ); albus.save(&pool).await.unwrap(); @@ -861,11 +867,11 @@ mod test { #[sqlx::test] async fn test_recovery_codes(pool: DbPool) { let mut harry = User::new( - "hpotter".into(), + "hpotter", Some("pass123"), - "Potter".into(), - "Harry".into(), - "h.potter@hogwart.edu.uk".into(), + "Potter", + "Harry", + "h.potter@hogwart.edu.uk", None, ); harry.get_recovery_codes(&pool).await.unwrap(); diff --git a/src/db/models/wallet.rs b/src/db/models/wallet.rs index 6983ef24b..e5de96e6f 100644 --- a/src/db/models/wallet.rs +++ b/src/db/models/wallet.rs @@ -77,20 +77,20 @@ pub struct Wallet { impl Wallet { #[must_use] - pub fn new_for_user( + pub fn new_for_user>( user_id: i64, - address: String, - name: String, + address: S, + name: S, chain_id: i64, - challenge_message: String, + challenge_message: S, ) -> Self { Self { id: None, user_id, - address, - name, + address: address.into(), + name: name.into(), chain_id, - challenge_message, + challenge_message: challenge_message.into(), challenge_signature: None, creation_timestamp: Utc::now().naive_utc(), validation_timestamp: None, @@ -184,11 +184,14 @@ impl Wallet { .collect() } - pub async fn find_by_user_and_address( - pool: &DbPool, + pub async fn find_by_user_and_address<'e, E>( + executor: E, user_id: i64, address: &str, - ) -> Result, SqlxError> { + ) -> Result, SqlxError> + where + E: PgExecutor<'e>, + { query_as!( Self, "SELECT id \"id?\", user_id, address, name, chain_id, challenge_message, challenge_signature, \ @@ -197,7 +200,7 @@ impl Wallet { user_id, address ) - .fetch_optional(pool) + .fetch_optional(executor) .await } diff --git a/src/db/models/wireguard.rs b/src/db/models/wireguard.rs index 084654583..2961b4b98 100644 --- a/src/db/models/wireguard.rs +++ b/src/db/models/wireguard.rs @@ -438,7 +438,7 @@ impl WireguardNetwork { })); } else { let msg = format!("Device {} does not exist", device_network_config.device_id); - error!("{msg}"); + error!(msg); return Err(WireguardNetworkError::Unexpected(msg)); } } @@ -1013,11 +1013,11 @@ mod test { async fn add_devices(pool: &DbPool, network: &WireguardNetwork, count: usize) { let mut user = User::new( - "testuser".to_string(), + "testuser", Some("hunter2"), - "Tester".to_string(), - "Test".to_string(), - "test@test.com".to_string(), + "Tester", + "Test", + "test@test.com", None, ); user.save(pool).await.unwrap(); @@ -1094,11 +1094,11 @@ mod test { network.save(&pool).await.unwrap(); let mut user = User::new( - "testuser".to_string(), + "testuser", Some("hunter2"), - "Tester".to_string(), - "Test".to_string(), - "test@test.com".to_string(), + "Tester", + "Test", + "test@test.com", None, ); user.save(&pool).await.unwrap(); @@ -1144,11 +1144,11 @@ mod test { network.save(&pool).await.unwrap(); let mut user = User::new( - "testuser".to_string(), + "testuser", Some("hunter2"), - "Tester".to_string(), - "Test".to_string(), - "test@test.com".to_string(), + "Tester", + "Test", + "test@test.com", None, ); user.save(&pool).await.unwrap(); diff --git a/src/error.rs b/src/error.rs index dfc3a9ac6..3c003b0f9 100644 --- a/src/error.rs +++ b/src/error.rs @@ -9,7 +9,7 @@ use crate::{ wireguard::WireguardNetworkError, }, grpc::GatewayMapError, - ldap::error::OriLDAPError, + ldap::error::LdapError, templates::TemplateError, }; @@ -64,12 +64,12 @@ impl From for WebError { } } -impl From for WebError { - fn from(error: OriLDAPError) -> Self { +impl From for WebError { + fn from(error: LdapError) -> Self { match error { - OriLDAPError::ObjectNotFound(msg) => Self::ObjectNotFound(msg), - OriLDAPError::Ldap(msg) => Self::Ldap(msg), - OriLDAPError::MissingSettings => Self::Ldap("LDAP settings are missing".to_string()), + LdapError::ObjectNotFound(msg) => Self::ObjectNotFound(msg), + LdapError::Ldap(msg) => Self::Ldap(msg), + LdapError::MissingSettings => Self::Ldap("LDAP settings are missing".to_string()), } } } diff --git a/src/grpc/gateway.rs b/src/grpc/gateway.rs index 6c9370939..6b35d5387 100644 --- a/src/grpc/gateway.rs +++ b/src/grpc/gateway.rs @@ -43,13 +43,11 @@ impl WireguardNetwork { debug!("Fetching all peers for network {}", self.id.unwrap()); let result = query_as!( Peer, - r#" - SELECT d.wireguard_pubkey as pubkey, array[host(wnd.wireguard_ip)] as "allowed_ips!: Vec" FROM wireguard_network_device wnd - JOIN device d - ON wnd.device_id = d.id - WHERE wireguard_network_id = $1 - ORDER BY d.id ASC - "#, + "SELECT d.wireguard_pubkey as pubkey, array[host(wnd.wireguard_ip)] as \"allowed_ips!: Vec\" \ + FROM wireguard_network_device wnd \ + JOIN device d ON wnd.device_id = d.id \ + WHERE wireguard_network_id = $1 \ + ORDER BY d.id ASC", self.id ) .fetch_all(executor) @@ -185,7 +183,7 @@ impl GatewayUpdatesHandler { self.gateway_hostname, self.network ); while let Ok(update) = self.events_rx.recv().await { - debug!("Received wireguard update: {:?}", update); + debug!("Received wireguard update: {update:?}"); let result = match update { GatewayEvent::NetworkCreated(network_id, network) => { if network_id == self.network_id { @@ -277,7 +275,7 @@ impl GatewayUpdatesHandler { peers: Vec, update_type: i32, ) -> Result<(), Status> { - debug!("Sending network update for network {}", network); + debug!("Sending network update for network {network}"); if let Err(err) = self .tx .send(Ok(Update { @@ -295,8 +293,8 @@ impl GatewayUpdatesHandler { let msg = format!( "Failed to send network update, network {network}, update type: {update_type}, error: {err}", ); - error!("{msg}"); - return Err(Status::new(tonic::Code::Internal, msg)); + error!(msg); + return Err(Status::new(Code::Internal, msg)); } Ok(()) } @@ -322,11 +320,11 @@ impl GatewayUpdatesHandler { .await { let msg = format!( - "Failed to send network update, network {}, update type: {}, error: {}", - self.network, 2, err, + "Failed to send network update, network {}, update type: 2, error: {err}", + self.network, ); - error!("{}", msg); - return Err(Status::new(tonic::Code::Internal, msg)); + error!(msg); + return Err(Status::new(Code::Internal, msg)); } Ok(()) } @@ -343,11 +341,11 @@ impl GatewayUpdatesHandler { .await { let msg = format!( - "Failed to send peer update for network {}, update type: {}, error: {}", - self.network, update_type, err, + "Failed to send peer update for network {}, update type: {update_type}, error: {err}", + self.network ); - error!("{}", msg); - return Err(Status::new(tonic::Code::Internal, msg)); + error!(msg); + return Err(Status::new(Code::Internal, msg)); } Ok(()) } @@ -367,11 +365,11 @@ impl GatewayUpdatesHandler { .await { let msg = format!( - "Failed to send peer update for network {}, peer {}, update type: 2, error: {}", - self.network, peer_pubkey, err, + "Failed to send peer update for network {}, peer {peer_pubkey}, update type: 2, error: {err}", + self.network, ); - error!("{}", msg); - return Err(Status::new(tonic::Code::Internal, msg)); + error!(msg); + return Err(Status::new(Code::Internal, msg)); } Ok(()) } @@ -383,6 +381,7 @@ pub struct GatewayUpdatesStream { network_id: i64, gateway_hostname: String, gateway_state: Arc>, + pool: DbPool, } impl GatewayUpdatesStream { @@ -393,6 +392,7 @@ impl GatewayUpdatesStream { network_id: i64, gateway_hostname: String, gateway_state: Arc>, + pool: DbPool, ) -> Self { Self { task_handle, @@ -400,6 +400,7 @@ impl GatewayUpdatesStream { network_id, gateway_hostname, gateway_state, + pool, } } } @@ -421,7 +422,7 @@ impl Drop for GatewayUpdatesStream { self.gateway_state .lock() .unwrap() - .disconnect_gateway(self.network_id, self.gateway_hostname.clone()) + .disconnect_gateway(self.network_id, self.gateway_hostname.clone(), &self.pool) .expect("Unable to disconnect gateway."); } } @@ -445,25 +446,19 @@ impl gateway_service_server::GatewayService for GatewayServer { stats.device_id = match Device::find_by_pubkey(&self.pool, &public_key).await { Ok(Some(device)) => device .id - .ok_or_else(|| Status::new(tonic::Code::Internal, "Device has no id"))?, + .ok_or_else(|| Status::new(Code::Internal, "Device has no ID"))?, Ok(None) => { - error!("Device with public key {} not found", &public_key); + error!("Device with public key {public_key} not found"); return Err(Status::new( - tonic::Code::Internal, - format!("Device with public key {} not found", &public_key), + Code::Internal, + format!("Device with public key {public_key} not found"), )); } Err(err) => { - error!( - "Failed to retrieve device with public key {}: {err}", - &public_key - ); + error!("Failed to retrieve device with public key {public_key}: {err}",); return Err(Status::new( - tonic::Code::Internal, - format!( - "Failed to retrieve device with public key {}: {err}", - &public_key - ), + Code::Internal, + format!("Failed to retrieve device with public key {public_key}: {err}",), )); } }; @@ -471,7 +466,7 @@ impl gateway_service_server::GatewayService for GatewayServer { if let Err(err) = stats.save(&self.pool).await { error!("Saving WireGuard peer stats to db failed: {err}"); return Err(Status::new( - tonic::Code::Internal, + Code::Internal, format!("Saving WireGuard peer stats to db failed: {err}"), )); } @@ -488,17 +483,13 @@ impl gateway_service_server::GatewayService for GatewayServer { let network_id = Self::get_network_id(request.metadata())?; let hostname = Self::get_gateway_hostname(request.metadata())?; - let pool = self.pool.clone(); - let mut network = WireguardNetwork::find_by_id(&pool, network_id) + let mut network = WireguardNetwork::find_by_id(&self.pool, network_id) .await .map_err(|e| { - error!("Network {} not found", network_id); - Status::new( - tonic::Code::Internal, - format!("Failed to retrieve network: {e}"), - ) + error!("Network {network_id} not found"); + Status::new(Code::Internal, format!("Failed to retrieve network: {e}")) })? - .ok_or_else(|| Status::new(tonic::Code::Internal, "Network not found"))?; + .ok_or_else(|| Status::new(Code::Internal, "Network not found"))?; info!("Sending configuration to gateway client, network {network}."); @@ -506,23 +497,22 @@ impl gateway_service_server::GatewayService for GatewayServer { let mut state = self.state.lock().unwrap(); state.add_gateway( network_id, - network.name.clone(), + &network.name, hostname, request.into_inner().name, - self.pool.clone(), self.mail_tx.clone(), ); } network.connected_at = Some(Utc::now().naive_utc()); - if let Err(err) = network.save(&pool).await { + if let Err(err) = network.save(&self.pool).await { error!("Failed to update network {network_id} status: {err}"); } - let peers = network.get_peers(&pool).await.map_err(|error| { + let peers = network.get_peers(&self.pool).await.map_err(|error| { error!("Failed to fetch peers for network {network_id}: {error}",); Status::new( - tonic::Code::Internal, + Code::Internal, format!("Failed to retrieve peers for network: {network_id}"), ) })?; @@ -539,7 +529,7 @@ impl gateway_service_server::GatewayService for GatewayServer { .map_err(|_| { error!("Failed to fetch network {gateway_network_id}"); Status::new( - tonic::Code::Internal, + Code::Internal, format!("Failed to retrieve network {gateway_network_id}"), ) })? @@ -556,7 +546,7 @@ impl gateway_service_server::GatewayService for GatewayServer { .connect_gateway(gateway_network_id, &hostname) .map_err(|err| { error!("Failed to connect gateway: {err}"); - Status::new(tonic::Code::Internal, "Failed to connect gateway ") + Status::new(Code::Internal, "Failed to connect gateway") })?; // clone here before moving into a closure @@ -578,6 +568,7 @@ impl gateway_service_server::GatewayService for GatewayServer { gateway_network_id, hostname, Arc::clone(&self.state), + self.pool.clone(), ))) } } diff --git a/src/grpc/mod.rs b/src/grpc/mod.rs index e40801bbc..2b8114de1 100644 --- a/src/grpc/mod.rs +++ b/src/grpc/mod.rs @@ -1,4 +1,3 @@ -use chrono::Duration as ChronoDuration; use std::{ collections::hash_map::HashMap, time::{Duration, Instant}, @@ -9,7 +8,7 @@ use std::{ sync::{Arc, Mutex}, }; -use chrono::{NaiveDateTime, Utc}; +use chrono::{Duration as ChronoDuration, NaiveDateTime, Utc}; use serde::Serialize; use thiserror::Error; use tokio::sync::{broadcast::Sender, mpsc::UnboundedSender}; @@ -88,31 +87,20 @@ impl GatewayMap { pub fn add_gateway( &mut self, network_id: i64, - network_name: String, + network_name: &str, hostname: String, name: Option, - pool: DbPool, mail_tx: UnboundedSender, ) { info!("Adding gateway {hostname} with to gateway map for network {network_id}",); + let gateway_state = GatewayState::new(network_id, network_name, &hostname, name, mail_tx); + if let Some(network_gateway_map) = self.0.get_mut(&network_id) { - network_gateway_map - .entry(hostname.clone()) - .or_insert(GatewayState::new( - network_id, - network_name, - hostname, - name, - pool, - mail_tx, - )); + network_gateway_map.entry(hostname).or_insert(gateway_state); } else { // no map for a given network exists yet let mut network_gateway_map = HashMap::new(); - network_gateway_map.insert( - hostname.clone(), - GatewayState::new(network_id, network_name, hostname, name, pool, mail_tx), - ); + network_gateway_map.insert(hostname, gateway_state); self.0.insert(network_id, network_gateway_map); } } @@ -177,13 +165,14 @@ impl GatewayMap { &mut self, network_id: i64, hostname: String, + pool: &DbPool, ) -> Result<(), GatewayMapError> { info!("Disconnecting gateway {hostname} in network {network_id}"); if let Some(network_gateway_map) = self.0.get_mut(&network_id) { if let Some(state) = network_gateway_map.get_mut(&hostname) { state.connected = false; state.disconnected_at = Some(Utc::now().naive_utc()); - state.send_disconnect_notification()?; + state.send_disconnect_notification(pool)?; return Ok(()); }; }; @@ -246,42 +235,39 @@ pub struct GatewayState { #[serde(skip)] pub mail_tx: UnboundedSender, #[serde(skip)] - pub pool: DbPool, - #[serde(skip)] pub last_email_notification: Option, } impl GatewayState { #[must_use] - pub fn new( + pub fn new>( network_id: i64, - network_name: String, - hostname: String, + network_name: S, + hostname: S, name: Option, - pool: DbPool, mail_tx: UnboundedSender, ) -> Self { Self { uid: Uuid::new_v4(), connected: false, network_id, - network_name, + network_name: network_name.into(), name, - hostname, + hostname: hostname.into(), connected_at: None, disconnected_at: None, mail_tx, - pool, last_email_notification: None, } } + /// Send gateway disconnected notification /// Sends notification only if last notification time is bigger than specified in config - fn send_disconnect_notification(&mut self) -> Result<(), GatewayMapError> { + fn send_disconnect_notification(&mut self, pool: &DbPool) -> Result<(), GatewayMapError> { // Clone here because self doesn't live long enough let name = self.name.clone(); let mail_tx = self.mail_tx.clone(); - let pool = self.pool.clone(); + let pool = pool.clone(); let hostname = self.hostname.clone(); let network_name = self.network_name.clone(); let send_email = if let Some(last_notification_time) = self.last_email_notification { @@ -305,12 +291,12 @@ impl GatewayState { send_gateway_disconnected_email(name, network_name, &hostname, &mail_tx, &pool) .await { - error!("Sending gateway disconnected notification failed: {e}"); + error!("Failed to send gateway disconnect notification: {e}"); } }); } else { debug!( - "Gateway {hostname} disconnected not sending email. Last notification time was at {:?}", + "Gateway {hostname} disconnected. Email notification not sent. Last notification was at {:?}", self.last_email_notification ); }; diff --git a/src/handlers/auth.rs b/src/handlers/auth.rs index fa01c4b99..be0643ea5 100644 --- a/src/handlers/auth.rs +++ b/src/handlers/auth.rs @@ -106,7 +106,7 @@ pub async fn authenticate( }; let server_config = SERVER_CONFIG.get().ok_or(WebError::ServerConfigMissing)?; - let auth_cookie = Cookie::build((SESSION_COOKIE_NAME, session.clone().id)) + let auth_cookie = Cookie::build((SESSION_COOKIE_NAME, session.id.clone())) .domain( server_config .cookie_domain diff --git a/src/handlers/group.rs b/src/handlers/group.rs index ec899c971..afa4d96eb 100644 --- a/src/handlers/group.rs +++ b/src/handlers/group.rs @@ -4,17 +4,17 @@ use axum::{ }; use serde_json::json; -use super::{ApiResponse, ApiResult, Username}; +use super::{ApiResponse, GroupInfo, Username}; use crate::{ appstate::AppState, auth::{SessionInfo, UserAdminRole}, db::{Group, User}, error::WebError, - ldap::utils::{ldap_add_user_to_group, ldap_remove_user_from_group}, + // ldap::utils::{ldap_add_user_to_group, ldap_modify_group, ldap_remove_user_from_group}, }; #[derive(Serialize)] -pub struct Groups { +pub(crate) struct Groups { groups: Vec, } @@ -25,20 +25,11 @@ impl Groups { } } -#[derive(Serialize)] -pub struct GroupInfo { - name: String, - members: Vec, -} - -impl GroupInfo { - #[must_use] - pub fn new(name: String, members: Vec) -> Self { - Self { name, members } - } -} - -pub async fn list_groups(_session: SessionInfo, State(appstate): State) -> ApiResult { +/// GET: Retrieve all groups. +pub(crate) async fn list_groups( + _session: SessionInfo, + State(appstate): State, +) -> Result { debug!("Listing groups"); let groups = Group::all(&appstate.pool) .await? @@ -52,36 +43,165 @@ pub async fn list_groups(_session: SessionInfo, State(appstate): State }) } -pub async fn get_group( +/// GET: Retrieve group with `name`. +pub(crate) async fn get_group( _session: SessionInfo, State(appstate): State, Path(name): Path, -) -> ApiResult { +) -> Result { debug!("Retrieving group {name}"); if let Some(group) = Group::find_by_name(&appstate.pool, &name).await? { let members = group.member_usernames(&appstate.pool).await?; info!("Retrieved group {name}"); Ok(ApiResponse { - json: json!(GroupInfo::new(name, members)), + json: json!(GroupInfo::new(name, Some(members))), status: StatusCode::OK, }) } else { - error!("Group {name} not found"); - Err(WebError::ObjectNotFound(format!("Group {name} not found",))) + let msg = format!("Group {name} not found"); + error!(msg); + Err(WebError::ObjectNotFound(msg)) + } +} + +/// POST: Create group with a given name and member list. +pub(crate) async fn create_group( + _role: UserAdminRole, + State(appstate): State, + Json(group_info): Json, +) -> Result { + debug!("Creating group {}", group_info.name); + + // FIXME: LDAP operations are not reverted. + let mut transaction = appstate.pool.begin().await?; + + let mut group = Group::new(&group_info.name); + // FIXME: conflicts must not return interal server error (500). + group.save(&appstate.pool).await?; + // TODO: create group in LDAP + + if let Some(ref members) = group_info.members { + for username in members { + let Some(user) = User::find_by_username(&mut *transaction, username).await? else { + let msg = format!("Failed to find user {username}"); + error!(msg); + return Err(WebError::ObjectNotFound(msg)); + }; + user.add_to_group(&mut *transaction, &group).await?; + // let _result = ldap_add_user_to_group(&mut *transaction, username, &group.name).await; + } } + + transaction.commit().await?; + + info!("Created group {}", group_info.name); + Ok(ApiResponse { + json: json!(group_info), + status: StatusCode::CREATED, + }) } -pub async fn add_group_member( +/// PUT: Rename group and/or change group members. +pub(crate) async fn modify_group( + _role: UserAdminRole, + State(appstate): State, + Path(name): Path, + Json(group_info): Json, +) -> Result { + debug!("Modifying group {}", group_info.name); + let Some(mut group) = Group::find_by_name(&appstate.pool, &name).await? else { + let msg = format!("Group {name} not found"); + error!(msg); + return Err(WebError::ObjectNotFound(msg)); + }; + + // FIXME: LDAP operations are not reverted. + let mut transaction = appstate.pool.begin().await?; + + // Rename only when needed. + if group.name != group_info.name { + group.name = group_info.name; + group.save(&mut *transaction).await?; + // let _result = ldap_modify_group(&mut *transaction, &group.name, &group).await; + } + + // Modify group members. + if let Some(ref members) = group_info.members { + let mut current_members = group.members(&mut *transaction).await?; + for username in members { + if let Some(index) = current_members + .iter() + .position(|gm| &gm.username == username) + { + // This member is already in the group. + current_members.remove(index); + continue; + } + + // Add new members to the group. + if let Some(user) = User::find_by_username(&mut *transaction, username).await? { + user.add_to_group(&mut *transaction, &group).await?; + // let _result = + // ldap_add_user_to_group(&mut *transaction, username, &group.name).await; + } + } + + // Remove outstanding members. + for user in current_members { + user.remove_from_group(&mut *transaction, &group).await?; + // let _result = + // ldap_remove_user_from_group(&mut *transaction, &user.username, &group.name).await; + } + } + + transaction.commit().await?; + + info!("Modified group {}", group.name); + Ok(ApiResponse::default()) +} + +/// DELETE: Remove group with `name`. +pub(crate) async fn delete_group( + _session: SessionInfo, + + State(appstate): State, + Path(name): Path, +) -> Result { + debug!("Deleting group {name}"); + // Administrative group must not be removed. + // Note: Group names are unique, so this condition should be sufficient. + if name == appstate.config.admin_groupname { + return Ok(ApiResponse { + json: json!({}), + status: StatusCode::BAD_REQUEST, + }); + } + + if let Some(group) = Group::find_by_name(&appstate.pool, &name).await? { + group.delete(&appstate.pool).await?; + // TODO: delete group from LDAP + + info!("Deleted group {name}"); + Ok(ApiResponse::default()) + } else { + let msg = format!("Failed to find group {name}"); + error!(msg); + Err(WebError::ObjectNotFound(msg)) + } +} + +/// POST: Find a group with `name` and add `username` as a member. +pub(crate) async fn add_group_member( _role: UserAdminRole, State(appstate): State, Path(name): Path, Json(data): Json, -) -> ApiResult { +) -> Result { if let Some(group) = Group::find_by_name(&appstate.pool, &name).await? { if let Some(user) = User::find_by_username(&appstate.pool, &data.username).await? { debug!("Adding user: {} to group: {}", user.username, group.name); user.add_to_group(&appstate.pool, &group).await?; - let _result = ldap_add_user_to_group(&appstate.pool, &user.username, &group.name).await; + // let _result = ldap_add_user_to_group(&appstate.pool, &user.username, &group.name).await; info!("Added user: {} to group: {}", user.username, group.name); Ok(ApiResponse::default()) } else { @@ -92,16 +212,18 @@ pub async fn add_group_member( ))) } } else { - error!("Group {name} not found"); - Err(WebError::ObjectNotFound(format!("Group {name} not found"))) + let msg = format!("Group {name} not found"); + error!(msg); + Err(WebError::ObjectNotFound(msg)) } } -pub async fn remove_group_member( +/// DELETE: Remove `username` from group with `name`. +pub(crate) async fn remove_group_member( _role: UserAdminRole, State(appstate): State, Path((name, username)): Path<(String, String)>, -) -> ApiResult { +) -> Result { if let Some(group) = Group::find_by_name(&appstate.pool, &name).await? { if let Some(user) = User::find_by_username(&appstate.pool, &username).await? { debug!( @@ -109,18 +231,17 @@ pub async fn remove_group_member( user.username, group.name ); user.remove_from_group(&appstate.pool, &group).await?; - let _result = - ldap_remove_user_from_group(&appstate.pool, &user.username, &group.name).await; + // let _result = + // ldap_remove_user_from_group(&appstate.pool, &user.username, &group.name).await; info!("Removed user: {} from group: {}", user.username, group.name); Ok(ApiResponse { json: json!({}), status: StatusCode::OK, }) } else { - error!("User not found {username}"); - Err(WebError::ObjectNotFound(format!( - "User {username} not found" - ))) + let msg = format!("User {username} not found"); + error!(msg); + Err(WebError::ObjectNotFound(msg)) } } else { error!("Group {name} not found"); diff --git a/src/handlers/mail.rs b/src/handlers/mail.rs index 61307237d..0a273a309 100644 --- a/src/handlers/mail.rs +++ b/src/handlers/mail.rs @@ -430,12 +430,7 @@ pub fn send_password_reset_email( let mail = Mail { to: user.email.clone(), subject: EMAIL_PASSOWRD_RESET_START_SUBJECT.into(), - content: templates::email_password_reset_mail( - service_url.clone(), - token, - ip_address, - device_info, - )?, + content: templates::email_password_reset_mail(service_url, token, ip_address, device_info)?, attachments: Vec::new(), result_tx: None, }; diff --git a/src/handlers/mod.rs b/src/handlers/mod.rs index 63328dbf2..f836b0a5b 100644 --- a/src/handlers/mod.rs +++ b/src/handlers/mod.rs @@ -57,11 +57,11 @@ impl From for ApiResponse { ApiResponse::new(json!({ "msg": msg }), StatusCode::NOT_FOUND) } WebError::Authorization(msg) => { - error!("{msg}"); + error!(msg); ApiResponse::new(json!({ "msg": msg }), StatusCode::UNAUTHORIZED) } WebError::Forbidden(msg) => { - error!("{msg}"); + error!(msg); ApiResponse::new(json!({ "msg": msg }), StatusCode::FORBIDDEN) } WebError::DbError(_) @@ -92,7 +92,7 @@ impl From for ApiResponse { WebError::IncorrectUsername(msg) | WebError::PubkeyValidation(msg) | WebError::BadRequest(msg) => { - error!("{msg}"); + error!(msg); ApiResponse::new(json!({ "msg": msg }), StatusCode::BAD_REQUEST) } WebError::TemplateError(err) => { @@ -135,8 +135,11 @@ pub struct Auth { impl Auth { #[must_use] - pub fn new(username: String, password: String) -> Self { - Self { username, password } + pub fn new>(username: S, password: S) -> Self { + Self { + username: username.into(), + password: password.into(), + } } } @@ -147,8 +150,10 @@ pub struct AuthTotp { impl AuthTotp { #[must_use] - pub fn new(secret: String) -> Self { - Self { secret } + pub fn new>(secret: S) -> Self { + Self { + secret: secret.into(), + } } } @@ -164,6 +169,22 @@ impl AuthCode { } } +#[derive(Deserialize, Serialize)] +pub struct GroupInfo { + pub name: String, + pub members: Option>, +} + +impl GroupInfo { + #[must_use] + pub fn new>(name: S, members: Option>) -> Self { + Self { + name: name.into(), + members, + } + } +} + #[derive(Deserialize, Serialize)] pub struct Username { pub username: String, diff --git a/src/handlers/ssh_authorized_keys.rs b/src/handlers/ssh_authorized_keys.rs index 067fa39de..9fecb8c57 100644 --- a/src/handlers/ssh_authorized_keys.rs +++ b/src/handlers/ssh_authorized_keys.rs @@ -68,7 +68,7 @@ pub async fn get_authorized_keys( None => { debug!("Fetching SSH keys for all users in group {group_name}"); // fetch all users in group - let users = group.fetch_all_members(&appstate.pool).await?; + let users = group.members(&appstate.pool).await?; for user in users { add_user_keys_to_list(user); } diff --git a/src/handlers/user.rs b/src/handlers/user.rs index 1f6a9ae1a..c3b650632 100644 --- a/src/handlers/user.rs +++ b/src/handlers/user.rs @@ -113,7 +113,7 @@ pub async fn add_user( // check username if let Err(err) = check_username(&username) { - debug!("{}", err); + debug!("{err}"); return Ok(ApiResponse { json: json!({}), status: StatusCode::BAD_REQUEST, @@ -281,7 +281,7 @@ pub async fn modify_user( debug!("User {} updating user {username}", session.user.username); let mut user = user_for_admin_or_self(&appstate.pool, &session, &username).await?; if let Err(err) = check_username(&user_info.username) { - debug!("{}", err); + debug!("Failed to check username {} {err}", user_info.username); return Ok(ApiResponse { json: json!({}), status: StatusCode::BAD_REQUEST, diff --git a/src/hex.rs b/src/hex.rs index 44fb14693..3f6cd5438 100644 --- a/src/hex.rs +++ b/src/hex.rs @@ -73,7 +73,7 @@ pub fn to_lower_hex(bytes: &[u8]) -> String { mod tests { use super::*; - #[std::prelude::v1::test] + #[test] fn test_hex_decode() { assert_eq!(hex_decode("deadf00d"), Ok(vec![0xde, 0xad, 0xf0, 0x0d])); assert_eq!(hex_decode("0Xdeadf00d"), Ok(vec![0xde, 0xad, 0xf0, 0x0d])); diff --git a/src/ldap/error.rs b/src/ldap/error.rs index 185fd33ec..ac790cd33 100644 --- a/src/ldap/error.rs +++ b/src/ldap/error.rs @@ -1,29 +1,28 @@ -use ldap3::LdapError; use std::{error::Error, fmt}; #[derive(Debug)] -pub enum OriLDAPError { +pub enum LdapError { Ldap(String), ObjectNotFound(String), MissingSettings, } -impl fmt::Display for OriLDAPError { +impl fmt::Display for LdapError { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { match self { - OriLDAPError::Ldap(msg) => write!(f, "LDAP error: {msg}"), - OriLDAPError::ObjectNotFound(msg) => write!(f, "Object not found: {msg}"), - OriLDAPError::MissingSettings => { + LdapError::Ldap(msg) => write!(f, "LDAP error: {msg}"), + LdapError::ObjectNotFound(msg) => write!(f, "Object not found: {msg}"), + LdapError::MissingSettings => { write!(f, "LDAP settings are missing.") } } } } -impl Error for OriLDAPError {} +impl Error for LdapError {} -impl From for OriLDAPError { - fn from(error: LdapError) -> Self { +impl From for LdapError { + fn from(error: ldap3::LdapError) -> Self { Self::Ldap(error.to_string()) } } diff --git a/src/ldap/hash.rs b/src/ldap/hash.rs index 1fd5e624b..33cc69725 100644 --- a/src/ldap/hash.rs +++ b/src/ldap/hash.rs @@ -6,6 +6,8 @@ use sha1::{ Digest, Sha1, }; +use crate::hex::to_lower_hex; + /// Calculate salted SHA1 hash from given password in SSHA password storage scheme. #[must_use] pub fn salted_sha1_hash(password: &str) -> String { @@ -32,14 +34,14 @@ pub fn nthash(password: &str) -> String { .encode_utf16() .flat_map(|c| IntoIterator::into_iter(c.to_le_bytes())) .collect(); - format!("{:x}", Md4::digest(password_utf16_le)) + to_lower_hex(&Md4::digest(password_utf16_le)) } #[cfg(test)] mod tests { use super::*; - #[std::prelude::v1::test] + #[test] fn test_hash() { assert_eq!(nthash("password"), "8846f7eaee8fb117ad06bdd830b7586c"); assert_eq!( diff --git a/src/ldap/mod.rs b/src/ldap/mod.rs index 62459a938..770a0a1a3 100644 --- a/src/ldap/mod.rs +++ b/src/ldap/mod.rs @@ -1,8 +1,11 @@ -use self::{error::OriLDAPError, model::Group}; -use crate::db::{DbPool, Settings, User}; -use ldap3::{drive, Ldap, LdapConnAsync, Mod, Scope, SearchEntry}; use std::collections::HashSet; +use ldap3::{drive, Ldap, LdapConnAsync, Mod, Scope, SearchEntry}; +use sqlx::PgExecutor; + +use self::error::LdapError; +use crate::db::{self, Settings, User}; + pub mod error; pub mod hash; pub mod model; @@ -21,7 +24,7 @@ macro_rules! hashset { }; } -#[derive(Debug, Clone)] +#[derive(Clone)] pub struct LDAPConfig { pub ldap_bind_username: String, pub ldap_group_search_base: String, @@ -39,8 +42,8 @@ impl LDAPConfig { #[must_use] pub fn user_dn(&self, username: &str) -> String { format!( - "{}={},{}", - &self.ldap_username_attr, username, &self.ldap_user_search_base, + "{}={username},{}", + self.ldap_username_attr, self.ldap_user_search_base, ) } @@ -48,44 +51,44 @@ impl LDAPConfig { #[must_use] pub fn group_dn(&self, groupname: &str) -> String { format!( - "{}={},{}", - &self.ldap_groupname_attr, groupname, &self.ldap_group_search_base, + "{}={groupname},{}", + self.ldap_groupname_attr, self.ldap_group_search_base, ) } } impl TryFrom for LDAPConfig { - type Error = OriLDAPError; + type Error = LdapError; - fn try_from(settings: Settings) -> Result { + fn try_from(settings: Settings) -> Result { Ok(Self { ldap_member_attr: settings .ldap_member_attr - .ok_or(OriLDAPError::MissingSettings)?, + .ok_or(LdapError::MissingSettings)?, ldap_group_member_attr: settings .ldap_group_member_attr - .ok_or(OriLDAPError::MissingSettings)?, + .ok_or(LdapError::MissingSettings)?, ldap_groupname_attr: settings .ldap_groupname_attr - .ok_or(OriLDAPError::MissingSettings)?, + .ok_or(LdapError::MissingSettings)?, ldap_username_attr: settings .ldap_username_attr - .ok_or(OriLDAPError::MissingSettings)?, + .ok_or(LdapError::MissingSettings)?, ldap_group_obj_class: settings .ldap_group_obj_class - .ok_or(OriLDAPError::MissingSettings)?, + .ok_or(LdapError::MissingSettings)?, ldap_user_obj_class: settings .ldap_user_obj_class - .ok_or(OriLDAPError::MissingSettings)?, + .ok_or(LdapError::MissingSettings)?, ldap_user_search_base: settings .ldap_user_search_base - .ok_or(OriLDAPError::MissingSettings)?, + .ok_or(LdapError::MissingSettings)?, ldap_bind_username: settings .ldap_bind_username - .ok_or(OriLDAPError::MissingSettings)?, + .ok_or(LdapError::MissingSettings)?, ldap_group_search_base: settings .ldap_group_search_base - .ok_or(OriLDAPError::MissingSettings)?, + .ok_or(LdapError::MissingSettings)?, }) } } @@ -96,18 +99,21 @@ pub struct LDAPConnection { } impl LDAPConnection { - pub async fn create(pool: &DbPool) -> Result { - let settings = Settings::get_settings(pool) + pub async fn create<'e, E>(executor: E) -> Result + where + E: PgExecutor<'e>, + { + let settings = Settings::get_settings(executor) .await - .map_err(|_| OriLDAPError::MissingSettings)?; + .map_err(|_| LdapError::MissingSettings)?; let config = LDAPConfig::try_from(settings.clone())?; - let url = settings.ldap_url.ok_or(OriLDAPError::MissingSettings)?; + let url = settings.ldap_url.ok_or(LdapError::MissingSettings)?; let password = settings .ldap_bind_password - .ok_or(OriLDAPError::MissingSettings)?; + .ok_or(LdapError::MissingSettings)?; let (conn, mut ldap) = LdapConnAsync::new(&url).await?; drive!(conn); - info!("Connected to LDAP: {}", &url); + info!("Connected to LDAP: {url}"); ldap.simple_bind(&config.ldap_bind_username, password.expose_secret()) .await? .success()?; @@ -115,7 +121,7 @@ impl LDAPConnection { } /// Searches LDAP for users. - async fn search_users(&mut self, filter: &str) -> Result, OriLDAPError> { + async fn search_users(&mut self, filter: &str) -> Result, LdapError> { let (rs, _res) = self .ldap .search( @@ -126,38 +132,34 @@ impl LDAPConnection { ) .await? .success()?; - info!("Performed LDAP user search with filter = {}", filter); + info!("Performed LDAP user search with filter = {filter}"); Ok(rs.into_iter().map(SearchEntry::construct).collect()) } /// Searches LDAP for groups. - async fn search_groups(&mut self, filter: &str) -> Result, OriLDAPError> { - let (rs, _res) = self - .ldap - .search( - &self.config.ldap_group_search_base, - Scope::Subtree, - filter, - vec![ - &self.config.ldap_username_attr, - &self.config.ldap_group_member_attr, - ], - ) - .await? - .success()?; - info!("Performed LDAP group search with filter = {}", filter); - Ok(rs.into_iter().map(SearchEntry::construct).collect()) - } + // async fn search_groups(&mut self, filter: &str) -> Result, LdapError> { + // let (rs, _res) = self + // .ldap + // .search( + // &self.config.ldap_group_search_base, + // Scope::Subtree, + // filter, + // vec![ + // &self.config.ldap_username_attr, + // &self.config.ldap_group_member_attr, + // ], + // ) + // .await? + // .success()?; + // info!("Performed LDAP group search with filter = {filter}"); + // Ok(rs.into_iter().map(SearchEntry::construct).collect()) + // } /// Creates LDAP object with specified distinguished name and attributes. - async fn add( - &mut self, - dn: &str, - attrs: Vec<(&str, HashSet<&str>)>, - ) -> Result<(), OriLDAPError> { - debug!("Adding object {}", dn); + async fn add(&mut self, dn: &str, attrs: Vec<(&str, HashSet<&str>)>) -> Result<(), LdapError> { + debug!("Adding object {dn}"); self.ldap.add(dn, attrs).await?.success()?; - info!("Added object {}", dn); + info!("Added object {dn}"); Ok(()) } @@ -167,23 +169,23 @@ impl LDAPConnection { old_dn: &str, new_dn: &str, mods: Vec>, - ) -> Result<(), OriLDAPError> { - debug!("Modifying object {}", old_dn); + ) -> Result<(), LdapError> { + debug!("Modifying LDAP object {old_dn}"); self.ldap.modify(old_dn, mods).await?; if old_dn != new_dn { if let Some((new_rdn, _rest)) = new_dn.split_once(',') { self.ldap.modifydn(old_dn, new_rdn, true, None).await?; } } - info!("Modified object {}", old_dn); + info!("Modified LDAP object {old_dn}"); Ok(()) } /// Deletes LDAP object with specified distinguished name. - pub async fn delete(&mut self, dn: &str) -> Result<(), OriLDAPError> { - debug!("Deleting object {}", dn); + pub async fn delete(&mut self, dn: &str) -> Result<(), LdapError> { + debug!("Deleting LDAP object {dn}"); self.ldap.delete(dn).await?; - info!("Deleted object {}", dn); + info!("Deleted LDAP object {dn}"); Ok(()) } @@ -191,8 +193,8 @@ impl LDAPConnection { pub async fn is_username_available(&mut self, username: &str) -> bool { let users = self .search_users(&format!( - "(&({}={})(|(objectClass={})))", - self.config.ldap_username_attr, username, self.config.ldap_user_obj_class + "(&({}={username})(|(objectClass={})))", + self.config.ldap_username_attr, self.config.ldap_user_obj_class )) .await; match users { @@ -203,26 +205,26 @@ impl LDAPConnection { /// Retrieves user with given username from LDAP. /// TODO: Password must agree with the password stored in LDAP. - pub async fn get_user(&mut self, username: &str, password: &str) -> Result { + pub async fn get_user(&mut self, username: &str, password: &str) -> Result { debug!("Performing LDAP user search: {username}"); let mut entries = self .search_users(&format!( - "(&({}={})(objectClass={}))", - self.config.ldap_username_attr, username, self.config.ldap_user_obj_class + "(&({}={username})(objectClass={}))", + self.config.ldap_username_attr, self.config.ldap_user_obj_class )) .await?; if let Some(entry) = entries.pop() { info!("Performed LDAP user search: {username}"); Ok(User::from_searchentry(&entry, username, password)) } else { - Err(OriLDAPError::ObjectNotFound(format!( + Err(LdapError::ObjectNotFound(format!( "User {username} not found", ))) } } /// Adds user to LDAP. - pub async fn add_user(&mut self, user: &User, password: &str) -> Result<(), OriLDAPError> { + pub async fn add_user(&mut self, user: &User, password: &str) -> Result<(), LdapError> { debug!("Adding LDAP user {}", user.username); let dn = self.config.user_dn(&user.username); let ssha_password = hash::salted_sha1_hash(password); @@ -234,7 +236,7 @@ impl LDAPConnection { } /// Modifies LDAP user. - pub async fn modify_user(&mut self, username: &str, user: &User) -> Result<(), OriLDAPError> { + pub async fn modify_user(&mut self, username: &str, user: &User) -> Result<(), LdapError> { debug!("Modifying user {username}"); let old_dn = self.config.user_dn(username); let new_dn = self.config.user_dn(&user.username); @@ -245,7 +247,7 @@ impl LDAPConnection { } /// Deletes user from LDAP. - pub async fn delete_user(&mut self, username: &str) -> Result<(), OriLDAPError> { + pub async fn delete_user(&mut self, username: &str) -> Result<(), LdapError> { debug!("Deleting user {username}"); let dn = self.config.user_dn(username); self.delete(&dn).await?; @@ -254,11 +256,7 @@ impl LDAPConnection { } /// Changes user password. - pub async fn set_password( - &mut self, - username: &str, - password: &str, - ) -> Result<(), OriLDAPError> { + pub async fn set_password(&mut self, username: &str, password: &str) -> Result<(), LdapError> { debug!("Setting password for user {username}"); let user_dn = self.config.user_dn(username); let ssha_password = hash::salted_sha1_hash(password); @@ -277,47 +275,66 @@ impl LDAPConnection { } /// Retrieves group with given groupname from LDAP. - pub async fn get_group(&mut self, groupname: &str) -> Result { - debug!("Performing LDAP group search: {groupname}"); - let mut enties = self - .search_groups(&format!( - "(&({}={})(objectClass={}))", - self.config.ldap_groupname_attr, groupname, self.config.ldap_group_obj_class - )) - .await?; - if let Some(entry) = enties.pop() { - info!("Performed LDAP user search: {groupname}"); - Ok(Group::from_searchentry(&entry, &self.config)) - } else { - Err(OriLDAPError::ObjectNotFound(format!( - "Group {groupname} not found" - ))) - } - } + // pub async fn get_group(&mut self, groupname: &str) -> Result { + // debug!("Performing LDAP group search: {groupname}"); + // let mut enties = self + // .search_groups(&format!( + // "(&({}={})(objectClass={}))", + // self.config.ldap_groupname_attr, groupname, self.config.ldap_group_obj_class + // )) + // .await?; + // if let Some(entry) = enties.pop() { + // info!("Performed LDAP user search: {groupname}"); + // Ok(Group::from_searchentry(&entry, &self.config)) + // } else { + // Err(LdapError::ObjectNotFound(format!( + // "Group {groupname} not found" + // ))) + // } + // } - /// Lists users satisfying specified criteria - pub async fn get_groups(&mut self) -> Result, OriLDAPError> { - debug!("Performing LDAP group search"); - let mut entries = self - .search_groups(&format!( - "(objectClass={})", - self.config.ldap_group_obj_class - )) - .await?; - let users = entries - .drain(..) - .map(|entry| Group::from_searchentry(&entry, &self.config)) - .collect(); - info!("Performed LDAP group search"); - Ok(users) + /// Modifies LDAP group. + pub async fn modify_group( + &mut self, + groupname: &str, + group: &db::Group, + ) -> Result<(), LdapError> { + debug!("Modifying LDAP group {groupname}"); + let old_dn = self.config.group_dn(groupname); + let new_dn = self.config.group_dn(&group.name); + self.modify( + &old_dn, + &new_dn, + vec![Mod::Replace("cn", hashset![group.name.as_str()])], + ) + .await?; + info!("Modified LDAP group {groupname}"); + Ok(()) } + /// Lists groups satisfying specified criteria + // pub async fn get_groups(&mut self) -> Result, LdapError> { + // debug!("Performing LDAP group search"); + // let mut entries = self + // .search_groups(&format!( + // "(objectClass={})", + // self.config.ldap_group_obj_class + // )) + // .await?; + // let users = entries + // .drain(..) + // .map(|entry| Group::from_searchentry(&entry, &self.config)) + // .collect(); + // info!("Performed LDAP group search"); + // Ok(users) + // } + /// Add user to a group. pub async fn add_user_to_group( &mut self, username: &str, groupname: &str, - ) -> Result<(), OriLDAPError> { + ) -> Result<(), LdapError> { let user_dn = self.config.user_dn(username); let group_dn = self.config.group_dn(groupname); self.modify( @@ -337,7 +354,7 @@ impl LDAPConnection { &mut self, username: &str, groupname: &str, - ) -> Result<(), OriLDAPError> { + ) -> Result<(), LdapError> { let user_dn = self.config.user_dn(username); let group_dn = self.config.group_dn(groupname); self.modify( diff --git a/src/ldap/model.rs b/src/ldap/model.rs index 40494da6d..e16d0ca90 100644 --- a/src/ldap/model.rs +++ b/src/ldap/model.rs @@ -1,8 +1,9 @@ -use crate::{db::User, hashset}; -use ldap3::{Mod, SearchEntry}; use std::collections::HashSet; +use ldap3::{Mod, SearchEntry}; + use super::LDAPConfig; +use crate::{db::User, hashset}; impl User { #[must_use] @@ -67,26 +68,27 @@ impl User { } } -pub struct Group { - pub name: String, - pub members: Vec, -} +// TODO: This struct is similar to `GroupInfo`, so maybe use one? +// pub(crate) struct Group { +// pub name: String, +// pub members: Vec, +// } -impl Group { - #[must_use] - pub fn from_searchentry(entry: &SearchEntry, config: &LDAPConfig) -> Self { - Self { - name: get_value_or_default(entry, &config.ldap_groupname_attr), - members: match entry.attrs.get(&config.ldap_group_member_attr) { - Some(members) => members - .iter() - .filter_map(|member| extract_dn_value(member)) - .collect(), - None => Vec::new(), - }, - } - } -} +// impl Group { +// #[must_use] +// pub(crate) fn from_searchentry(entry: &SearchEntry, config: &LDAPConfig) -> Self { +// Self { +// name: get_value_or_default(entry, &config.ldap_groupname_attr), +// members: match entry.attrs.get(&config.ldap_group_member_attr) { +// Some(members) => members +// .iter() +// .filter_map(|member| extract_dn_value(member)) +// .collect(), +// None => Vec::new(), +// }, +// } +// } +// } fn get_value_or_default(entry: &SearchEntry, key: &str) -> String { match entry.attrs.get(key) { diff --git a/src/ldap/utils.rs b/src/ldap/utils.rs index deebde04e..1fa8e07e5 100644 --- a/src/ldap/utils.rs +++ b/src/ldap/utils.rs @@ -1,19 +1,24 @@ -use super::{error::OriLDAPError, LDAPConnection}; -use crate::db::{DbPool, User}; +use sqlx::PgExecutor; + +use super::{error::LdapError, LDAPConnection}; +use crate::db::{DbPool, Group, User}; pub async fn user_from_ldap( pool: &DbPool, username: &str, password: &str, -) -> Result { +) -> Result { let mut ldap_connection = LDAPConnection::create(pool).await?; let mut user = ldap_connection.get_user(username, password).await?; let _result = user.save(pool).await; // FIXME: do not ignore errors Ok(user) } -pub async fn ldap_add_user(pool: &DbPool, user: &User, password: &str) -> Result<(), OriLDAPError> { - let mut ldap_connection = LDAPConnection::create(pool).await?; +pub async fn ldap_add_user<'e, E>(executor: E, user: &User, password: &str) -> Result<(), LdapError> +where + E: PgExecutor<'e>, +{ + let mut ldap_connection = LDAPConnection::create(executor).await?; match ldap_connection.add_user(user, password).await { Ok(()) => Ok(()), // this user might exist in LDAP, just try to set the password @@ -21,45 +26,72 @@ pub async fn ldap_add_user(pool: &DbPool, user: &User, password: &str) -> Result } } -pub async fn ldap_modify_user( - pool: &DbPool, +pub async fn ldap_modify_user<'e, E>( + executor: E, username: &str, user: &User, -) -> Result<(), OriLDAPError> { - let mut ldap_connection = LDAPConnection::create(pool).await?; +) -> Result<(), LdapError> +where + E: PgExecutor<'e>, +{ + let mut ldap_connection = LDAPConnection::create(executor).await?; ldap_connection.modify_user(username, user).await } -pub async fn ldap_delete_user(pool: &DbPool, username: &str) -> Result<(), OriLDAPError> { - let mut ldap_connection = LDAPConnection::create(pool).await?; +pub async fn ldap_delete_user<'e, E>(executor: E, username: &str) -> Result<(), LdapError> +where + E: PgExecutor<'e>, +{ + let mut ldap_connection = LDAPConnection::create(executor).await?; ldap_connection.delete_user(username).await } -pub async fn ldap_add_user_to_group( - pool: &DbPool, +pub async fn ldap_add_user_to_group<'e, E>( + executor: E, username: &str, groupname: &str, -) -> Result<(), OriLDAPError> { - let mut ldap_connection = LDAPConnection::create(pool).await?; +) -> Result<(), LdapError> +where + E: PgExecutor<'e>, +{ + let mut ldap_connection = LDAPConnection::create(executor).await?; ldap_connection.add_user_to_group(username, groupname).await } -pub async fn ldap_remove_user_from_group( - pool: &DbPool, +pub async fn ldap_remove_user_from_group<'e, E>( + executor: E, username: &str, groupname: &str, -) -> Result<(), OriLDAPError> { - let mut ldap_connection = LDAPConnection::create(pool).await?; +) -> Result<(), LdapError> +where + E: PgExecutor<'e>, +{ + let mut ldap_connection = LDAPConnection::create(executor).await?; ldap_connection .remove_user_from_group(username, groupname) .await } -pub async fn ldap_change_password( - pool: &DbPool, +pub async fn ldap_change_password<'e, E>( + executor: E, username: &str, password: &str, -) -> Result<(), OriLDAPError> { - let mut ldap_connection = LDAPConnection::create(pool).await?; +) -> Result<(), LdapError> +where + E: PgExecutor<'e>, +{ + let mut ldap_connection = LDAPConnection::create(executor).await?; ldap_connection.set_password(username, password).await } + +pub async fn ldap_modify_group<'e, E>( + executor: E, + groupname: &str, + group: &Group, +) -> Result<(), LdapError> +where + E: PgExecutor<'e>, +{ + let mut ldap_connection = LDAPConnection::create(executor).await?; + ldap_connection.modify_group(groupname, group).await +} diff --git a/src/lib.rs b/src/lib.rs index 82fe3bef8..f84ba0bde 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -13,6 +13,7 @@ use axum::{ serve, Extension, Router, }; use handlers::{ + group::{create_group, delete_group, modify_group}, settings::{get_settings_essentials, patch_settings, test_ldap_settings}, user::reset_password, }; @@ -205,6 +206,9 @@ pub fn build_webapp( // group .route("/group", get(list_groups)) .route("/group/:name", get(get_group)) + .route("/group", post(create_group)) + .route("/group/:name", put(modify_group)) + .route("/group/:name", delete(delete_group)) .route("/group/:name", post(add_group_member)) .route("/group/:name/user/:username", delete(remove_group_member)) // mail diff --git a/src/support.rs b/src/support.rs index 5655b7f54..19836167c 100644 --- a/src/support.rs +++ b/src/support.rs @@ -2,16 +2,15 @@ use std::{collections::HashMap, fmt::Display}; use serde::Serialize; use serde_json::{json, value::to_value, Value}; -use sqlx::{Pool, Postgres}; use crate::{ config::DefGuardConfig, - db::{models::device::WireguardNetworkDevice, Settings, User, WireguardNetwork}, + db::{models::device::WireguardNetworkDevice, DbPool, Settings, User, WireguardNetwork}, VERSION, }; /// Unwraps the result returning a JSON representation of value or error -fn unwrap_json(result: Result) -> Value { +fn unwrap_json(result: Result) -> Value { match result { Ok(value) => to_value(value).expect("conversion to JSON failed"), Err(err) => json!({"error": err.to_string()}), @@ -19,7 +18,7 @@ fn unwrap_json(result: Result) -> Value { } /// Dumps all data that could be used for debugging. -pub async fn dump_config(db: &Pool, config: &DefGuardConfig) -> Value { +pub async fn dump_config(db: &DbPool, config: &DefGuardConfig) -> Value { // App settings DB records let settings = match Settings::find_by_id(db, 1).await { Ok(Some(mut settings)) => { @@ -33,7 +32,7 @@ pub async fn dump_config(db: &Pool, config: &DefGuardConfig) -> Value let (networks, devices) = match WireguardNetwork::all(db).await { Ok(networks) => { // Devices for each network - let mut devices = HashMap::::default(); + let mut devices = HashMap::::new(); for network in &networks { let Some(network_id) = network.id else { continue; diff --git a/src/templates.rs b/src/templates.rs index 5c19636aa..0f9cbb1d3 100644 --- a/src/templates.rs +++ b/src/templates.rs @@ -410,11 +410,11 @@ mod test { #[test] fn test_enrollment_admin_notification() { let test_user: User = User::new( - "test".into(), - "1234".into(), - "test_last".into(), - "test_first".into(), - "test@example.com".into(), + "test", + Some("1234"), + "test_last", + "test_first", + "test@example.com", Some("99999".into()), ); assert_ok!(enrollment_admin_notification( diff --git a/tests/auth.rs b/tests/auth.rs index 2162bf81a..05088fe62 100644 --- a/tests/auth.rs +++ b/tests/auth.rs @@ -35,10 +35,10 @@ async fn make_client() -> TestClient { let mut wallet = Wallet::new_for_user( client_state.test_user.id.unwrap(), - "0x4aF8803CBAD86BA65ED347a3fbB3fb50e96eDD3e".into(), - "test".into(), + "0x4aF8803CBAD86BA65ED347a3fbB3fb50e96eDD3e", + "test", 5, - String::new(), + "", ); wallet.save(&client_state.pool).await.unwrap(); @@ -50,10 +50,10 @@ async fn make_client_with_db() -> (TestClient, DbPool) { let mut wallet = Wallet::new_for_user( client_state.test_user.id.unwrap(), - "0x4aF8803CBAD86BA65ED347a3fbB3fb50e96eDD3e".into(), - "test".into(), + "0x4aF8803CBAD86BA65ED347a3fbB3fb50e96eDD3e", + "test", 5, - String::new(), + "", ); wallet.save(&client_state.pool).await.unwrap(); @@ -65,26 +65,21 @@ async fn make_client_with_state() -> (TestClient, ClientState) { let mut wallet = Wallet::new_for_user( client_state.test_user.id.unwrap(), - "0x4aF8803CBAD86BA65ED347a3fbB3fb50e96eDD3e".into(), - "test".into(), + "0x4aF8803CBAD86BA65ED347a3fbB3fb50e96eDD3e", + "test", 5, - String::new(), + "", ); wallet.save(&client_state.pool).await.unwrap(); (client, client_state) } -async fn make_client_with_wallet(address: String) -> TestClient { +async fn make_client_with_wallet(address: &str) -> TestClient { let (client, client_state) = make_test_client().await; - let mut wallet = Wallet::new_for_user( - client_state.test_user.id.unwrap(), - address, - "test".into(), - 5, - String::new(), - ); + let mut wallet = + Wallet::new_for_user(client_state.test_user.id.unwrap(), address, "test", 5, ""); wallet.save(&client_state.pool).await.unwrap(); client @@ -94,7 +89,7 @@ async fn make_client_with_wallet(address: String) -> TestClient { async fn test_logout() { let mut client = make_client().await; - let auth = Auth::new("hpotter".into(), "pass123".into()); + let auth = Auth::new("hpotter", "pass123"); let response = client.post("/api/v1/auth").json(&auth).send().await; assert_eq!(response.status(), StatusCode::OK); @@ -123,7 +118,7 @@ async fn test_logout() { async fn test_login_bruteforce() { let client = make_client().await; - let invalid_auth = Auth::new("hpotter".into(), "invalid".into()); + let invalid_auth = Auth::new("hpotter", "invalid"); // fail login 5 times in a row for i in 0..6 { @@ -140,7 +135,7 @@ async fn test_login_bruteforce() { async fn test_cannot_enable_mfa() { let client = make_client().await; - let auth = Auth::new("hpotter".into(), "pass123".into()); + let auth = Auth::new("hpotter", "pass123"); let response = client.post("/api/v1/auth").json(&auth).send().await; assert_eq!(response.status(), StatusCode::OK); @@ -163,7 +158,7 @@ async fn test_totp() { let client = make_client().await; // login - let auth = Auth::new("hpotter".into(), "pass123".into()); + let auth = Auth::new("hpotter", "pass123"); let response = client.post("/api/v1/auth").json(&auth).send().await; assert_eq!(response.status(), StatusCode::OK); @@ -262,7 +257,7 @@ async fn test_totp() { assert_eq!(response.status(), StatusCode::OK); // login again - let auth = Auth::new("hpotter".into(), "pass123".into()); + let auth = Auth::new("hpotter", "pass123"); let response = client.post("/api/v1/auth").json(&auth).send().await; assert_eq!(response.status(), StatusCode::OK); } @@ -284,7 +279,7 @@ async fn test_email_mfa() { assert_eq!(response.status(), StatusCode::UNAUTHORIZED); // login - let auth = Auth::new("hpotter".into(), "pass123".into()); + let auth = Auth::new("hpotter", "pass123"); let response = client.post("/api/v1/auth").json(&auth).send().await; assert_eq!(response.status(), StatusCode::OK); @@ -410,7 +405,7 @@ async fn test_email_mfa() { assert_eq!(response.status(), StatusCode::OK); // login again - let auth = Auth::new("hpotter".into(), "pass123".into()); + let auth = Auth::new("hpotter", "pass123"); let response = client.post("/api/v1/auth").json(&auth).send().await; assert_eq!(response.status(), StatusCode::OK); } @@ -423,7 +418,7 @@ async fn test_webauthn() { let origin = Url::parse("http://localhost:8000").unwrap(); // login - let auth = Auth::new("hpotter".into(), "pass123".into()); + let auth = Auth::new("hpotter", "pass123"); let response = client.post("/api/v1/auth").json(&auth).send().await; assert_eq!(response.status(), StatusCode::OK); @@ -451,7 +446,7 @@ async fn test_webauthn() { assert_eq!(response.status(), StatusCode::OK); // login again - let auth = Auth::new("hpotter".into(), "pass123".into()); + let auth = Auth::new("hpotter", "pass123"); let response = client.post("/api/v1/auth").json(&auth).send().await; assert_eq!(response.status(), StatusCode::CREATED); @@ -480,7 +475,7 @@ async fn test_webauthn() { assert_eq!(response.status(), StatusCode::OK); // login again - let auth = Auth::new("hpotter".into(), "pass123".into()); + let auth = Auth::new("hpotter", "pass123"); let response = client.post("/api/v1/auth").json(&auth).send().await; assert_eq!(response.status(), StatusCode::OK); @@ -500,7 +495,7 @@ async fn test_cannot_skip_otp_by_adding_yubikey() { let client = make_client().await; // login - let auth = Auth::new("hpotter".into(), "pass123".into()); + let auth = Auth::new("hpotter", "pass123"); let response = client.post("/api/v1/auth").json(&auth).send().await; assert_eq!(response.status(), StatusCode::OK); @@ -535,7 +530,7 @@ async fn test_cannot_skip_security_key_by_adding_yubikey() { let origin = Url::parse("http://localhost:8000").unwrap(); // login - let auth = Auth::new("hpotter".into(), "pass123".into()); + let auth = Auth::new("hpotter", "pass123"); let response = client.post("/api/v1/auth").json(&auth).send().await; assert_eq!(response.status(), StatusCode::OK); @@ -559,7 +554,7 @@ async fn test_cannot_skip_security_key_by_adding_yubikey() { assert_eq!(response.status(), StatusCode::OK); // login again - let auth = Auth::new("hpotter".into(), "pass123".into()); + let auth = Auth::new("hpotter", "pass123"); let response = client.post("/api/v1/auth").json(&auth).send().await; assert_eq!(response.status(), StatusCode::CREATED); @@ -573,7 +568,7 @@ async fn test_mfa_method_is_updated_when_removing_last_webauthn_passkey() { let client = make_client().await; // login - let auth = Auth::new("hpotter".into(), "pass123".into()); + let auth = Auth::new("hpotter", "pass123"); let response = client.post("/api/v1/auth").json(&auth).send().await; assert_eq!(response.status(), StatusCode::OK); @@ -651,7 +646,7 @@ async fn test_mfa_method_is_updated_when_removing_last_webauthn_passkey() { assert_eq!(response.status(), StatusCode::OK); // login again - let auth = Auth::new("hpotter".into(), "pass123".into()); + let auth = Auth::new("hpotter", "pass123"); let response = client.post("/api/v1/auth").json(&auth).send().await; assert_eq!(response.status(), StatusCode::CREATED); @@ -780,10 +775,10 @@ async fn test_web3() { let wallet_address = to_lower_hex(addr); // create client - let client = make_client_with_wallet(wallet_address.clone()).await; + let client = make_client_with_wallet(&wallet_address).await; // login - let auth = Auth::new("hpotter".into(), "pass123".into()); + let auth = Auth::new("hpotter", "pass123"); let response = client.post("/api/v1/auth").json(&auth).send().await; assert_eq!(response.status(), StatusCode::OK); @@ -806,7 +801,7 @@ async fn test_web3() { assert_eq!(response.status(), StatusCode::OK); // login with wallet - let auth = Auth::new("hpotter".into(), "pass123".into()); + let auth = Auth::new("hpotter", "pass123"); let response = client.post("/api/v1/auth").json(&auth).send().await; assert_eq!(response.status(), StatusCode::CREATED); wallet_login(&client, wallet_address, &secp, secret_key).await; @@ -816,7 +811,7 @@ async fn test_web3() { assert_eq!(response.status(), StatusCode::OK); // login again - let auth = Auth::new("hpotter".into(), "pass123".into()); + let auth = Auth::new("hpotter", "pass123"); let response = client.post("/api/v1/auth").json(&auth).send().await; assert_eq!(response.status(), StatusCode::OK); } @@ -836,7 +831,7 @@ async fn test_re_adding_wallet() { let client = make_client().await; // login - let auth = Auth::new("hpotter".into(), "pass123".into()); + let auth = Auth::new("hpotter", "pass123"); let response = client.post("/api/v1/auth").json(&auth).send().await; assert_eq!(response.status(), StatusCode::OK); @@ -881,7 +876,7 @@ async fn test_re_adding_wallet() { assert_eq!(response.status(), StatusCode::OK); // login with wallet - let auth = Auth::new("hpotter".into(), "pass123".into()); + let auth = Auth::new("hpotter", "pass123"); let response = client.post("/api/v1/auth").json(&auth).send().await; assert_eq!(response.status(), StatusCode::CREATED); wallet_login(&client, wallet_address.clone(), &secp, secret_key).await; @@ -898,7 +893,7 @@ async fn test_re_adding_wallet() { assert_eq!(response.status(), StatusCode::OK); // login without MFA - let auth = Auth::new("hpotter".into(), "pass123".into()); + let auth = Auth::new("hpotter", "pass123"); let response = client.post("/api/v1/auth").json(&auth).send().await; assert_eq!(response.status(), StatusCode::OK); @@ -941,7 +936,7 @@ async fn test_re_adding_wallet() { assert_eq!(response.status(), StatusCode::OK); // login with wallet - let auth = Auth::new("hpotter".into(), "pass123".into()); + let auth = Auth::new("hpotter", "pass123"); let response = client.post("/api/v1/auth").json(&auth).send().await; assert_eq!(response.status(), StatusCode::CREATED); wallet_login(&client, wallet_address.clone(), &secp, secret_key).await; @@ -954,7 +949,7 @@ async fn test_mfa_method_totp_enabled_mail() { let user_agent_header = "Mozilla/5.0 (iPhone; CPU iPhone OS 17_1 like Mac OS X) AppleWebKit/605.1.15 (KHTML, like Gecko) Version/17.1 Mobile/15E148 Safari/604.1"; // login - let auth = Auth::new("hpotter".into(), "pass123".into()); + let auth = Auth::new("hpotter", "pass123"); let response = client .post("/api/v1/auth") .header(USER_AGENT, user_agent_header) @@ -994,7 +989,7 @@ async fn test_new_device_login() { let user_agent_header_android = "Mozilla/5.0 (Linux; Android 7.0; SM-G930VC Build/NRD90M; wv) AppleWebKit/537.36 (KHTML, like Gecko) Version/4.0 Chrome/58.0.3029.83 Mobile Safari/537.36"; // login - let auth = Auth::new("hpotter".into(), "pass123".into()); + let auth = Auth::new("hpotter", "pass123"); let response = client .post("/api/v1/auth") .header(USER_AGENT, user_agent_header_iphone) @@ -1018,7 +1013,7 @@ async fn test_new_device_login() { assert_eq!(response.status(), StatusCode::OK); // login using the same device - let auth = Auth::new("hpotter".into(), "pass123".into()); + let auth = Auth::new("hpotter", "pass123"); let response = client .post("/api/v1/auth") .header(USER_AGENT, user_agent_header_iphone) @@ -1030,7 +1025,7 @@ async fn test_new_device_login() { assert_err!(mail_rx.try_recv()); // login using a different device - let auth = Auth::new("hpotter".into(), "pass123".into()); + let auth = Auth::new("hpotter", "pass123"); let response = client .post("/api/v1/auth") .header(USER_AGENT, user_agent_header_android) @@ -1057,7 +1052,7 @@ async fn test_login_ip_headers() { let user_agent_header_iphone = "Mozilla/5.0 (iPhone; CPU iPhone OS 17_1 like Mac OS X) AppleWebKit/605.1.15 (KHTML, like Gecko) Version/17.1 Mobile/15E148 Safari/604.1"; // Works with X-Forwarded-For header - let auth = Auth::new("hpotter".into(), "pass123".into()); + let auth = Auth::new("hpotter", "pass123"); let response = client .post("/api/v1/auth") .header(USER_AGENT, user_agent_header_iphone) @@ -1080,7 +1075,7 @@ async fn test_login_ip_headers() { async fn test_session_cookie() { let (client, pool) = make_client_with_db().await; - let auth = Auth::new("hpotter".into(), "pass123".into()); + let auth = Auth::new("hpotter", "pass123"); let response = client.post("/api/v1/auth").json(&auth).send().await; assert_eq!(response.status(), StatusCode::OK); diff --git a/tests/common/client.rs b/tests/common/client.rs index e6a8fe07b..af1374d09 100644 --- a/tests/common/client.rs +++ b/tests/common/client.rs @@ -6,7 +6,7 @@ use reqwest::{ cookie::{Cookie, Jar}, header::{HeaderMap, HeaderName}, redirect::Policy, - Client, StatusCode, Url, + Body, Client, StatusCode, Url, }; use tokio::net::TcpListener; @@ -121,7 +121,7 @@ impl RequestBuilder { } } - pub fn body(mut self, body: impl Into) -> Self { + pub fn body>(mut self, body: B) -> Self { self.builder = self.builder.body(body); self } diff --git a/tests/common/mod.rs b/tests/common/mod.rs index b46517f0b..7fd848b31 100644 --- a/tests/common/mod.rs +++ b/tests/common/mod.rs @@ -66,11 +66,11 @@ async fn initialize_users(pool: &DbPool, config: DefGuardConfig) { .unwrap(); let mut test_user = User::new( - "hpotter".into(), + "hpotter", Some("pass123"), - "Potter".into(), - "Harry".into(), - "h.potter@hogwart.edu.uk".into(), + "Potter", + "Harry", + "h.potter@hogwart.edu.uk", None, ); test_user.save(pool).await.unwrap(); diff --git a/tests/enrollment.rs b/tests/enrollment.rs index 5bb40fb66..1561e990b 100644 --- a/tests/enrollment.rs +++ b/tests/enrollment.rs @@ -19,7 +19,7 @@ async fn make_client() -> (TestClient, DbPool) { async fn test_initialize_enrollment() { let (client, pool) = make_client().await; - let auth = Auth::new("admin".into(), "pass123".into()); + let auth = Auth::new("admin", "pass123"); let response = client.post("/api/v1/auth").json(&auth).send().await; assert_eq!(response.status(), StatusCode::OK); diff --git a/tests/forward_auth.rs b/tests/forward_auth.rs index 6c7f9debc..9b939b354 100644 --- a/tests/forward_auth.rs +++ b/tests/forward_auth.rs @@ -10,10 +10,10 @@ async fn make_client() -> TestClient { let mut wallet = Wallet::new_for_user( client_state.test_user.id.unwrap(), - "0x4aF8803CBAD86BA65ED347a3fbB3fb50e96eDD3e".into(), - "test".into(), + "0x4aF8803CBAD86BA65ED347a3fbB3fb50e96eDD3e", + "test", 5, - String::new(), + "", ); wallet.save(&client_state.pool).await.unwrap(); @@ -43,7 +43,7 @@ async fn test_forward_auth() { ); // login - let auth = Auth::new("hpotter".into(), "pass123".into()); + let auth = Auth::new("hpotter", "pass123"); let response = client.post("/api/v1/auth").json(&auth).send().await; assert_eq!(response.status(), StatusCode::OK); diff --git a/tests/group.rs b/tests/group.rs new file mode 100644 index 000000000..dc62ff9d2 --- /dev/null +++ b/tests/group.rs @@ -0,0 +1,103 @@ +mod common; + +use defguard::handlers::{Auth, GroupInfo}; +use reqwest::StatusCode; + +use self::common::make_test_client; + +#[tokio::test] +async fn test_create_group() { + let (client, _) = make_test_client().await; + + // Authorize as an administrator. + let auth = Auth::new("admin", "pass123"); + let response = client.post("/api/v1/auth").json(&auth).send().await; + assert_eq!(response.status(), StatusCode::OK); + + // Create new group. + let data = GroupInfo::new("hogwards", Some(vec!["hpotter".into()])); + let response = client.post("/api/v1/group").json(&data).send().await; + assert_eq!(response.status(), StatusCode::CREATED); + + // Try to create the same group again. + let response = client.post("/api/v1/group").json(&data).send().await; + assert_eq!(response.status(), StatusCode::INTERNAL_SERVER_ERROR); + + // Delete the group. + let response = client.delete("/api/v1/group/hogwards").send().await; + assert_eq!(response.status(), StatusCode::OK); + + // Try to delete again. + let response = client.delete("/api/v1/group/hogwards").send().await; + assert_eq!(response.status(), StatusCode::NOT_FOUND); +} + +#[tokio::test] +async fn test_modify_group() { + let (client, _) = make_test_client().await; + + // Authorize as an administrator. + let auth = Auth::new("admin", "pass123"); + let response = client.post("/api/v1/auth").json(&auth).send().await; + assert_eq!(response.status(), StatusCode::OK); + + // Create new group. + let data = GroupInfo::new("hogwards", Some(vec!["hpotter".into()])); + let response = client.post("/api/v1/group").json(&data).send().await; + assert_eq!(response.status(), StatusCode::CREATED); + + // Rename group. + let data = GroupInfo::new("gryffindor", None); + let response = client + .put("/api/v1/group/hogwards") + .json(&data) + .send() + .await; + assert_eq!(response.status(), StatusCode::OK); + + // Try to get the group by its old name. + let response = client.get("/api/v1/group/hogwards").send().await; + assert_eq!(response.status(), StatusCode::NOT_FOUND); + + // Get group info. + let response = client.get("/api/v1/group/gryffindor").send().await; + assert_eq!(response.status(), StatusCode::OK); + let group_info: GroupInfo = response.json().await; + assert_eq!(group_info.name, "gryffindor"); +} + +#[tokio::test] +async fn test_modify_group_members() { + let (client, _) = make_test_client().await; + + // Authorize as an administrator. + let auth = Auth::new("admin", "pass123"); + let response = client.post("/api/v1/auth").json(&auth).send().await; + assert_eq!(response.status(), StatusCode::OK); + + // Create new group. + let data = GroupInfo::new("hogwards", Some(vec!["hpotter".into()])); + let response = client.post("/api/v1/group").json(&data).send().await; + assert_eq!(response.status(), StatusCode::CREATED); + + // Get group info. + let response = client.get("/api/v1/group/hogwards").send().await; + assert_eq!(response.status(), StatusCode::OK); + let group_info: GroupInfo = response.json().await; + assert_eq!(group_info.members.unwrap(), vec!["hpotter".to_string()]); + + // Change group members. + let data = GroupInfo::new("hogwards", Some(Vec::new())); + let response = client + .put("/api/v1/group/hogwards") + .json(&data) + .send() + .await; + assert_eq!(response.status(), StatusCode::OK); + + // Get group info. + let response = client.get("/api/v1/group/hogwards").send().await; + assert_eq!(response.status(), StatusCode::OK); + let group_info: GroupInfo = response.json().await; + assert!(group_info.members.unwrap().is_empty()); +} diff --git a/tests/oauth.rs b/tests/oauth.rs index 7186206da..aebe081d1 100644 --- a/tests/oauth.rs +++ b/tests/oauth.rs @@ -26,7 +26,7 @@ async fn make_client() -> (TestClient, DbPool) { async fn test_authorize() { let (client, pool) = make_client().await; - let auth = Auth::new("admin".into(), "pass123".into()); + let auth = Auth::new("admin", "pass123"); let response = client.post("/api/v1/auth").json(&auth).send().await; assert_eq!(response.status(), StatusCode::OK); @@ -165,7 +165,7 @@ async fn test_openid_app_management_access() { let (client, _) = make_client().await; // login as admin - let auth = Auth::new("admin".into(), "pass123".into()); + let auth = Auth::new("admin", "pass123"); let response = client.post("/api/v1/auth").json(&auth).send().await; assert_eq!(response.status(), StatusCode::OK); @@ -267,7 +267,7 @@ async fn test_openid_app_management_access() { let test_app = &apps[0]; // // login as standard user - let auth = Auth::new("hpotter".into(), "pass123".into()); + let auth = Auth::new("hpotter", "pass123"); let response = client.post("/api/v1/auth").json(&auth).send().await; assert_eq!(response.status(), StatusCode::OK); diff --git a/tests/openid.rs b/tests/openid.rs index bd58b222e..260576622 100644 --- a/tests/openid.rs +++ b/tests/openid.rs @@ -49,7 +49,7 @@ pub struct AuthenticationResponse<'r> { async fn test_openid_client() { let client = make_client().await; - let auth = Auth::new("admin".into(), "pass123".into()); + let auth = Auth::new("admin", "pass123"); let response = client.post("/api/v1/auth").json(&auth).send().await; assert_eq!(response.status(), StatusCode::OK); @@ -107,7 +107,7 @@ async fn test_openid_client() { #[tokio::test] async fn test_openid_flow() { let client = make_client().await; - let auth = Auth::new("admin".into(), "pass123".into()); + let auth = Auth::new("admin", "pass123"); let response = client.post("/api/v1/auth").json(&auth).send().await; assert_eq!(response.status(), StatusCode::OK); let openid_client = NewOpenIDClient { @@ -252,7 +252,7 @@ async fn test_openid_flow() { assert_eq!(response.status(), StatusCode::UNAUTHORIZED); // log back in - let auth = Auth::new("admin".into(), "pass123".into()); + let auth = Auth::new("admin", "pass123"); let response = client.post("/api/v1/auth").json(&auth).send().await; assert_eq!(response.status(), StatusCode::OK); @@ -414,7 +414,7 @@ async fn test_openid_authorization_code() { .unwrap(); // create OAuth2 client - let auth = Auth::new("admin".into(), "pass123".into()); + let auth = Auth::new("admin", "pass123"); let response = client.post("/api/v1/auth").json(&auth).send().await; assert_eq!(response.status(), StatusCode::OK); let oauth2client = NewOpenIDClient { @@ -519,7 +519,7 @@ async fn test_openid_authorization_code_with_pkce() { .unwrap(); // create OAuth2 client/application - let auth = Auth::new("admin".into(), "pass123".into()); + let auth = Auth::new("admin", "pass123"); let response = client.post("/api/v1/auth").json(&auth).send().await; assert_eq!(response.status(), StatusCode::OK); let oauth2client = NewOpenIDClient { @@ -622,7 +622,7 @@ async fn test_openid_flow_new_login_mail() { let mut mail_rx = state.mail_rx; let user_agent_header = "Mozilla/5.0 (iPhone; CPU iPhone OS 17_1 like Mac OS X) AppleWebKit/605.1.15 (KHTML, like Gecko) Version/17.1 Mobile/15E148 Safari/604.1"; - let auth = Auth::new("admin".into(), "pass123".into()); + let auth = Auth::new("admin", "pass123"); let response = client .post("/api/v1/auth") .header(USER_AGENT, user_agent_header) diff --git a/tests/settings.rs b/tests/settings.rs index 2c38e4a2c..9d50e95c0 100644 --- a/tests/settings.rs +++ b/tests/settings.rs @@ -17,7 +17,7 @@ async fn make_client() -> (TestClient, ClientState) { #[tokio::test] async fn test_settings() { let (client, _client_state) = make_client().await; - let auth = Auth::new("admin".into(), "pass123".into()); + let auth = Auth::new("admin", "pass123"); let response = &client.post("/api/v1/auth").json(&auth).send().await; assert_eq!(response.status(), StatusCode::OK); // get settings diff --git a/tests/user.rs b/tests/user.rs index c28af9d4f..dfcaf57a5 100644 --- a/tests/user.rs +++ b/tests/user.rs @@ -25,15 +25,15 @@ async fn make_client() -> TestClient { async fn test_authenticate() { let client = make_client().await; - let auth = Auth::new("hpotter".into(), "pass123".into()); + let auth = Auth::new("hpotter", "pass123"); let response = client.post("/api/v1/auth").json(&auth).send().await; assert_eq!(response.status(), StatusCode::OK); - let auth = Auth::new("hpotter".into(), "-wrong-".into()); + let auth = Auth::new("hpotter", "-wrong-"); let response = client.post("/api/v1/auth").json(&auth).send().await; assert_eq!(response.status(), StatusCode::UNAUTHORIZED); - let auth = Auth::new("adumbledore".into(), "pass123".into()); + let auth = Auth::new("adumbledore", "pass123"); let response = client.post("/api/v1/auth").json(&auth).send().await; assert_eq!(response.status(), StatusCode::UNAUTHORIZED); } @@ -42,7 +42,7 @@ async fn test_authenticate() { async fn test_me() { let client = make_client().await; - let auth = Auth::new("hpotter".into(), "pass123".into()); + let auth = Auth::new("hpotter", "pass123"); let response = client.post("/api/v1/auth").json(&auth).send().await; assert_eq!(response.status(), StatusCode::OK); @@ -57,7 +57,7 @@ async fn test_me() { async fn test_change_self_password() { let client = make_client().await; - let auth = Auth::new("hpotter".into(), "pass123".into()); + let auth = Auth::new("hpotter", "pass123"); let response = client.post("/api/v1/auth").json(&auth).send().await; assert_eq!(response.status(), StatusCode::OK); @@ -106,7 +106,7 @@ async fn test_change_self_password() { let response = client.post("/api/v1/auth").json(&auth).send().await; assert_eq!(response.status(), StatusCode::UNAUTHORIZED); - let new_auth = Auth::new("hpotter".into(), new_password.into()); + let new_auth = Auth::new("hpotter", new_password); let response = client.post("/api/v1/auth").json(&new_auth).send().await; assert_eq!(response.status(), StatusCode::OK); @@ -116,7 +116,7 @@ async fn test_change_self_password() { async fn test_change_password() { let client = make_client().await; - let auth = Auth::new("admin".into(), "pass123".into()); + let auth = Auth::new("admin", "pass123"); let response = client.post("/api/v1/auth").json(&auth).send().await; assert_eq!(response.status(), StatusCode::OK); @@ -146,7 +146,7 @@ async fn test_change_password() { assert_eq!(response.status(), StatusCode::OK); - let auth = Auth::new("hpotter".into(), new_password.to_string()); + let auth = Auth::new("hpotter", new_password); let response = client.post("/api/v1/auth").json(&auth).send().await; assert_eq!(response.status(), StatusCode::OK); @@ -167,7 +167,7 @@ async fn test_list_users() { assert_eq!(response.status(), StatusCode::UNAUTHORIZED); // normal user cannot list users - let auth = Auth::new("hpotter".into(), "pass123".into()); + let auth = Auth::new("hpotter", "pass123"); let response = client.post("/api/v1/auth").json(&auth).send().await; assert_eq!(response.status(), StatusCode::OK); @@ -175,7 +175,7 @@ async fn test_list_users() { assert_eq!(response.status(), StatusCode::FORBIDDEN); // admin can list users - let auth = Auth::new("admin".into(), "pass123".into()); + let auth = Auth::new("admin", "pass123"); let response = client.post("/api/v1/auth").json(&auth).send().await; assert_eq!(response.status(), StatusCode::OK); @@ -190,7 +190,7 @@ async fn test_get_user() { let response = client.get("/api/v1/user/hpotter").send().await; assert_eq!(response.status(), StatusCode::UNAUTHORIZED); - let auth = Auth::new("hpotter".into(), "pass123".into()); + let auth = Auth::new("hpotter", "pass123"); let response = client.post("/api/v1/auth").json(&auth).send().await; assert_eq!(response.status(), StatusCode::OK); @@ -204,7 +204,7 @@ async fn test_username_available() { let client = make_client().await; // standard user cannot check username availability - let auth = Auth::new("hpotter".into(), "pass123".into()); + let auth = Auth::new("hpotter", "pass123"); let response = client.post("/api/v1/auth").json(&auth).send().await; assert_eq!(response.status(), StatusCode::OK); @@ -219,7 +219,7 @@ async fn test_username_available() { assert_eq!(response.status(), StatusCode::FORBIDDEN); // log in as admin - let auth = Auth::new("admin".into(), "pass123".into()); + let auth = Auth::new("admin", "pass123"); let response = client.post("/api/v1/auth").json(&auth).send().await; assert_eq!(response.status(), StatusCode::OK); @@ -258,7 +258,7 @@ async fn test_username_available() { async fn test_crud_user() { let client = make_client().await; - let auth = Auth::new("admin".into(), "pass123".into()); + let auth = Auth::new("admin", "pass123"); let response = client.post("/api/v1/auth").json(&auth).send().await; assert_eq!(response.status(), StatusCode::OK); @@ -296,7 +296,7 @@ async fn test_crud_user() { async fn test_admin_group() { let client = make_client().await; - let auth = Auth::new("hpotter".into(), "pass123".into()); + let auth = Auth::new("hpotter", "pass123"); let response = client.post("/api/v1/auth").json(&auth).send().await; assert_eq!(response.status(), StatusCode::OK); @@ -313,7 +313,7 @@ async fn test_admin_group() { async fn test_wallet() { let client = make_client().await; - let auth = Auth::new("hpotter".into(), "pass123".into()); + let auth = Auth::new("hpotter", "pass123"); let response = client.post("/api/v1/auth").json(&auth).send().await; assert_eq!(response.status(), StatusCode::OK); @@ -427,7 +427,7 @@ This request will not trigger a blockchain transaction or cost any gas fees."; async fn test_check_username() { let client = make_client().await; - let auth = Auth::new("admin".into(), "pass123".into()); + let auth = Auth::new("admin", "pass123"); let response = client.post("/api/v1/auth").json(&auth).send().await; assert_eq!(response.status(), StatusCode::OK); @@ -466,7 +466,7 @@ async fn test_check_password_strength() { let client = make_client().await; // auth session with admin - let auth = Auth::new("admin".into(), "pass123".into()); + let auth = Auth::new("admin", "pass123"); let response = client.post("/api/v1/auth").json(&auth).send().await; assert_eq!(response.status(), StatusCode::OK); @@ -513,7 +513,7 @@ async fn test_check_password_strength() { #[tokio::test] async fn test_user_unregister_authorized_app() { let client = make_client().await; - let auth = Auth::new("admin".into(), "pass123".into()); + let auth = Auth::new("admin", "pass123"); let response = client.post("/api/v1/auth").json(&auth).send().await; assert_eq!(response.status(), StatusCode::OK); let openid_client = NewOpenIDClient { @@ -579,7 +579,7 @@ async fn test_user_add_device() { let user_agent_header = "Mozilla/5.0 (iPhone; CPU iPhone OS 17_1 like Mac OS X) AppleWebKit/605.1.15 (KHTML, like Gecko) Version/17.1 Mobile/15E148 Safari/604.1"; // log in as admin - let auth = Auth::new("admin".into(), "pass123".into()); + let auth = Auth::new("admin", "pass123"); let response = client .post("/api/v1/auth") .header(USER_AGENT, user_agent_header) @@ -649,7 +649,7 @@ async fn test_user_add_device() { .contains("Device type: iPhone, OS: iOS 17.1, Mobile Safari")); // log in as normal user - let auth = Auth::new("hpotter".into(), "pass123".into()); + let auth = Auth::new("hpotter", "pass123"); let response = client .post("/api/v1/auth") .header(USER_AGENT, user_agent_header) diff --git a/tests/webhook.rs b/tests/webhook.rs index eab72420c..fe986d582 100644 --- a/tests/webhook.rs +++ b/tests/webhook.rs @@ -14,7 +14,7 @@ async fn make_client() -> TestClient { async fn test_webhooks() { let client = make_client().await; - let auth = Auth::new("admin".into(), "pass123".into()); + let auth = Auth::new("admin", "pass123"); let response = client.post("/api/v1/auth").json(&auth).send().await; assert_eq!(response.status(), StatusCode::OK); diff --git a/tests/wireguard.rs b/tests/wireguard.rs index 1ec6bc110..5d4951d84 100644 --- a/tests/wireguard.rs +++ b/tests/wireguard.rs @@ -28,7 +28,7 @@ async fn test_network() { let mut wg_rx = client_state.wireguard_rx; - let auth = Auth::new("admin".into(), "pass123".into()); + let auth = Auth::new("admin", "pass123"); let response = &client.post("/api/v1/auth").json(&auth).send().await; assert_eq!(response.status(), StatusCode::OK); @@ -96,7 +96,7 @@ async fn test_device() { let mut wg_rx = client_state.wireguard_rx; - let auth = Auth::new("admin".into(), "pass123".into()); + let auth = Auth::new("admin", "pass123"); let response = &client.post("/api/v1/auth").json(&auth).send().await; assert_eq!(response.status(), StatusCode::OK); @@ -261,7 +261,7 @@ async fn test_device() { async fn test_device_permissions() { let (client, _) = make_test_client().await; - let auth = Auth::new("admin".into(), "pass123".into()); + let auth = Auth::new("admin", "pass123"); let response = &client.post("/api/v1/auth").json(&auth).send().await; assert_eq!(response.status(), StatusCode::OK); @@ -323,7 +323,7 @@ async fn test_device_permissions() { assert_eq!(response.status(), StatusCode::CREATED); // normal user cannot add devices for other users or import multiple devices - let auth = Auth::new("hpotter".into(), "pass123".into()); + let auth = Auth::new("hpotter", "pass123"); let response = &client.post("/api/v1/auth").json(&auth).send().await; assert_eq!(response.status(), StatusCode::OK); @@ -385,7 +385,7 @@ async fn test_device_permissions() { assert_eq!(user_devices.len(), 3); // admin can list devices of other users - let auth = Auth::new("admin".into(), "pass123".into()); + let auth = Auth::new("admin", "pass123"); let response = &client.post("/api/v1/auth").json(&auth).send().await; assert_eq!(response.status(), StatusCode::OK); @@ -406,7 +406,7 @@ async fn test_device_pubkey() { let mut wg_rx = client_state.wireguard_rx; - let auth = Auth::new("admin".into(), "pass123".into()); + let auth = Auth::new("admin", "pass123"); let response = &client.post("/api/v1/auth").json(&auth).send().await; assert_eq!(response.status(), StatusCode::OK); diff --git a/tests/wireguard_network_allowed_groups.rs b/tests/wireguard_network_allowed_groups.rs index 8c34295aa..bf0a48569 100644 --- a/tests/wireguard_network_allowed_groups.rs +++ b/tests/wireguard_network_allowed_groups.rs @@ -53,11 +53,11 @@ async fn setup_test_users(pool: &DbPool) -> (Vec, Vec) { // standard user in other, non-allowed group let mut other_user = User::new( - "ssnape".into(), + "ssnape", Some("pass123"), - "Snape".into(), - "Severus".into(), - "s.snape@hogwart.edu.uk".into(), + "Snape", + "Severus", + "s.snape@hogwart.edu.uk", None, ); other_user.save(pool).await.unwrap(); @@ -76,11 +76,11 @@ async fn setup_test_users(pool: &DbPool) -> (Vec, Vec) { // standard user in no groups let mut non_group_user = User::new( - "dobby".into(), + "dobby", Some("pass123"), - "Elf".into(), - "Dobby".into(), - "dobby@hogwart.edu.uk".into(), + "Elf", + "Dobby", + "dobby@hogwart.edu.uk", None, ); non_group_user.save(pool).await.unwrap(); @@ -103,7 +103,7 @@ async fn test_create_new_network() { let mut wg_rx = client_state.wireguard_rx; - let auth = Auth::new("admin".into(), "pass123".into()); + let auth = Auth::new("admin", "pass123"); let response = &client.post("/api/v1/auth").json(&auth).send().await; assert_eq!(response.status(), StatusCode::OK); @@ -142,7 +142,7 @@ async fn test_modify_network() { let mut wg_rx = client_state.wireguard_rx; - let auth = Auth::new("admin".into(), "pass123".into()); + let auth = Auth::new("admin", "pass123"); let response = &client.post("/api/v1/auth").json(&auth).send().await; assert_eq!(response.status(), StatusCode::OK); @@ -276,7 +276,7 @@ async fn test_import_network_existing_devices() { let mut wg_rx = client_state.wireguard_rx; - let auth = Auth::new("admin".into(), "pass123".into()); + let auth = Auth::new("admin", "pass123"); let response = &client.post("/api/v1/auth").json(&auth).send().await; assert_eq!(response.status(), StatusCode::OK); @@ -364,7 +364,7 @@ async fn test_import_mapping_devices() { let mut wg_rx = client_state.wireguard_rx; - let auth = Auth::new("admin".into(), "pass123".into()); + let auth = Auth::new("admin", "pass123"); let response = &client.post("/api/v1/auth").json(&auth).send().await; assert_eq!(response.status(), StatusCode::OK); @@ -472,7 +472,7 @@ async fn test_modify_user() { let mut wg_rx = client_state.wireguard_rx; - let auth = Auth::new("admin".into(), "pass123".into()); + let auth = Auth::new("admin", "pass123"); let response = &client.post("/api/v1/auth").json(&auth).send().await; assert_eq!(response.status(), StatusCode::OK); diff --git a/tests/wireguard_network_import.rs b/tests/wireguard_network_import.rs index cedaa0587..ac80028c5 100644 --- a/tests/wireguard_network_import.rs +++ b/tests/wireguard_network_import.rs @@ -79,7 +79,7 @@ async fn test_config_import() { let mut wg_rx = client_state.wireguard_rx; - let auth = Auth::new("admin".into(), "pass123".into()); + let auth = Auth::new("admin", "pass123"); let response = &client.post("/api/v1/auth").json(&auth).send().await; assert_eq!(response.status(), StatusCode::OK); @@ -228,7 +228,7 @@ async fn test_config_import_missing_interface() { "; let (client, _) = make_test_client().await; - let auth = Auth::new("admin".into(), "pass123".into()); + let auth = Auth::new("admin", "pass123"); let response = &client.post("/api/v1/auth").json(&auth).send().await; assert_eq!(response.status(), StatusCode::OK); @@ -262,7 +262,7 @@ async fn test_config_import_invalid_key() { "; let (client, _) = make_test_client().await; - let auth = Auth::new("admin".into(), "pass123".into()); + let auth = Auth::new("admin", "pass123"); let response = &client.post("/api/v1/auth").json(&auth).send().await; assert_eq!(response.status(), StatusCode::OK); @@ -321,7 +321,7 @@ async fn test_config_import_invalid_ip() { "; let (client, _) = make_test_client().await; - let auth = Auth::new("admin".into(), "pass123".into()); + let auth = Auth::new("admin", "pass123"); let response = &client.post("/api/v1/auth").json(&auth).send().await; assert_eq!(response.status(), StatusCode::OK); @@ -354,7 +354,7 @@ async fn test_config_import_nonadmin() { PersistentKeepalive = 300 "; let (client, _) = make_test_client().await; - let auth = Auth::new("hpotter".into(), "pass123".into()); + let auth = Auth::new("hpotter", "pass123"); let response = &client.post("/api/v1/auth").json(&auth).send().await; assert_eq!(response.status(), StatusCode::OK); diff --git a/tests/wireguard_network_stats.rs b/tests/wireguard_network_stats.rs index 4df73325b..2d909bd66 100644 --- a/tests/wireguard_network_stats.rs +++ b/tests/wireguard_network_stats.rs @@ -32,7 +32,7 @@ async fn test_stats() { let (client, client_state) = make_test_client().await; let pool = client_state.pool; - let auth = Auth::new("admin".into(), "pass123".into()); + let auth = Auth::new("admin", "pass123"); let response = &client.post("/api/v1/auth").json(&auth).send().await; assert_eq!(response.status(), StatusCode::OK); diff --git a/tests/worker.rs b/tests/worker.rs index 844bd38c0..702e3bb32 100644 --- a/tests/worker.rs +++ b/tests/worker.rs @@ -29,7 +29,7 @@ async fn test_scheduling_worker_jobs() { }; // normal user can only provision keys for themselves - let auth = Auth::new("hpotter".into(), "pass123".into()); + let auth = Auth::new("hpotter", "pass123"); let response = client.post("/api/v1/auth").json(&auth).send().await; assert_eq!(response.status(), StatusCode::OK); @@ -57,7 +57,7 @@ async fn test_scheduling_worker_jobs() { assert_eq!(response.status(), StatusCode::FORBIDDEN); // admin user can provision keys for other users - let auth = Auth::new("admin".into(), "pass123".into()); + let auth = Auth::new("admin", "pass123"); let response = client.post("/api/v1/auth").json(&auth).send().await; assert_eq!(response.status(), StatusCode::OK); @@ -146,7 +146,7 @@ async fn test_scheduling_worker_jobs() { assert_eq!(response.status(), StatusCode::OK); // // normal user can only fetch status of their own jobs - let auth = Auth::new("hpotter".into(), "pass123".into()); + let auth = Auth::new("hpotter", "pass123"); let response = client.post("/api/v1/auth").json(&auth).send().await; assert_eq!(response.status(), StatusCode::OK); @@ -182,7 +182,7 @@ async fn test_worker_management_permissions() { } // admin can create worker tokens - let auth = Auth::new("admin".into(), "pass123".into()); + let auth = Auth::new("admin", "pass123"); let response = client.post("/api/v1/auth").json(&auth).send().await; assert_eq!(response.status(), StatusCode::OK); @@ -203,7 +203,7 @@ async fn test_worker_management_permissions() { assert_eq!(workers.len(), 2); // normal user cannot create worker tokens - let auth = Auth::new("hpotter".into(), "pass123".into()); + let auth = Auth::new("hpotter", "pass123"); let response = client.post("/api/v1/auth").json(&auth).send().await; assert_eq!(response.status(), StatusCode::OK);