Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix down migrations #658

Merged
merged 3 commits into from
Jul 2, 2024
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
893 changes: 456 additions & 437 deletions Cargo.lock

Large diffs are not rendered by default.

29 changes: 16 additions & 13 deletions Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
[package]
name = "defguard"
version = "0.11.0"
version = "0.11.1"
edition = "2021"
license = "Apache-2.0"
homepage = "https://defguard.net/"
repository = "https://github.com/DefGuard/defguard"
rust-version = "1.76"

[workspace]

Expand All @@ -18,7 +19,7 @@ axum-extra = { version = "0.9", features = [
"cookie-private",
"typed-header",
] }
base64 = "0.21"
base64 = "0.22"
chrono = { version = "0.4", default-features = false, features = [
"clock",
"serde",
Expand All @@ -30,25 +31,26 @@ ethers-core = "2.0"
humantime = "2.1"
# match ipnetwork version from sqlx
ipnetwork = { version = "0.20", features = ["serde"] }
jsonwebtoken = "9.2"
jsonwebtoken = "9.3"
ldap3 = { version = "0.11", default-features = false, features = ["tls"] }
lettre = { version = "0.11", features = ["tokio1", "tokio1-native-tls"] }
md4 = "0.10"
mime_guess = "2.0"
model_derive = { path = "model-derive" }
openidconnect = { version = "3.4", default-features = false, optional = true }
openidconnect = { version = "3.5", default-features = false, optional = true }
otpauth = "0.4"
prost = "0.12"
pulldown-cmark = "0.9"
pulldown-cmark = "0.11"
rand = "0.8"
rand_core = { version = "0.6", default-features = false, features = [
"getrandom",
] }
# TODO: update reqwest when openidconnect also depends on http >= 1.0.
reqwest = { version = "0.11", features = ["json"] }
rsa = { version = "0.9", features = ["pem"] }
rust-embed = { version = "8.4", features = ["include-exclude"] }
rust-ini = "0.20"
secp256k1 = { version = "0.28", features = [
secp256k1 = { version = "0.29", features = [
"recovery",
"rand-std",
"global-context",
Expand All @@ -69,7 +71,7 @@ sqlx = { version = "0.7", features = [
] }
ssh-key = "0.6"
struct-patch = "0.4"
tera = "1.19"
tera = "1.20"
thiserror = "1.0"
# match axum-extra -> cookies
time = { version = "0.3", default-features = false }
Expand All @@ -88,16 +90,16 @@ tower-http = { version = "0.5", features = ["fs", "trace"] }
tracing = "0.1"
tracing-subscriber = { version = "0.3", features = ["env-filter"] }
uaparser = "0.6"
uuid = { version = "1.4", features = ["v4"] }
webauthn-authenticator-rs = { version = "0.4" }
webauthn-rs = { version = "0.4", features = [
uuid = { version = "1.9", features = ["v4"] }
webauthn-authenticator-rs = { version = "0.5" }
webauthn-rs = { version = "0.5", features = [
"danger-allow-state-serialisation",
] }
webauthn-rs-proto = "0.4"
webauthn-rs-proto = "0.5"
x25519-dalek = { version = "2.0", features = ["static_secrets"] }

[dev-dependencies]
bytes = "1.5"
bytes = "1.6"
claims = "0.7"
matches = "0.1"
regex = "1.10"
Expand All @@ -108,7 +110,8 @@ reqwest = { version = "0.11", features = [
"multipart",
"rustls-tls",
], default-features = false }
serde_qs = "0.12"
serde_qs = "0.13"
webauthn-authenticator-rs = { version = "0.5", features = ["softpasskey"] }

[build-dependencies]
prost-build = "0.12"
Expand Down
4 changes: 4 additions & 0 deletions migrations/20240216195802_authentication_key.down.sql
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
DROP TABLE authentication_key;

DROP TABLE yubikey;

DROP TYPE authentication_key_type;

ALTER TABLE "user"
ADD pgp_key text NULL,
ADD pgp_cert_id text NULL,
Expand Down
2 changes: 0 additions & 2 deletions rust-toolchain.toml

This file was deleted.

3 changes: 1 addition & 2 deletions src/auth/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -82,10 +82,9 @@ impl Claims {

fn get_secret(claims_type: ClaimsType) -> String {
let env_var = match claims_type {
ClaimsType::Auth => AUTH_SECRET_ENV,
ClaimsType::Auth | ClaimsType::DesktopClient => AUTH_SECRET_ENV,
ClaimsType::Gateway => GATEWAY_SECRET_ENV,
ClaimsType::YubiBridge => YUBIBRIDGE_SECRET_ENV,
ClaimsType::DesktopClient => AUTH_SECRET_ENV,
};
env::var(env_var).unwrap_or_default()
}
Expand Down
18 changes: 9 additions & 9 deletions src/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -243,16 +243,16 @@ impl DefGuardConfig {

fn validate_secret_key(&self) {
let secret_key = self.secret_key.expose_secret();
if secret_key.trim().len() != secret_key.len() {
panic!("SECRET_KEY cannot have leading and trailing space",);
}
assert!(
secret_key.trim().len() == secret_key.len(),
"SECRET_KEY cannot have leading and trailing space",
);

if secret_key.len() < 64 {
panic!(
"SECRET_KEY must be at least 64 characters long, provided value has {} characters",
secret_key.len()
);
}
assert!(
secret_key.len() >= 64,
"SECRET_KEY must be at least 64 characters long, provided value has {} characters",
secret_key.len()
);
}

/// Try PKCS#1 and PKCS#8 PEM formats.
Expand Down
6 changes: 3 additions & 3 deletions src/db/models/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -111,15 +111,15 @@ impl UserInfo {
transaction: &mut PgConnection,
user: &mut User,
) -> Result<bool, SqlxError> {
if self.is_active != user.is_active {
if self.is_active == user.is_active {
Ok(false)
} else {
if !self.is_active {
user.logout_all_sessions(&mut *transaction).await?;
}
user.is_active = self.is_active;
user.save(&mut *transaction).await?;
Ok(true)
} else {
Ok(false)
}
}

Expand Down
5 changes: 1 addition & 4 deletions src/db/models/wireguard.rs
Original file line number Diff line number Diff line change
Expand Up @@ -388,10 +388,7 @@ impl WireguardNetwork {
Ok(wireguard_network_device)
} else {
error!("Device {device} not allowed in network {self}");
Err(WireguardNetworkError::DeviceNotAllowed(format!(
"{}",
device
)))
Err(WireguardNetworkError::DeviceNotAllowed(format!("{device}")))
}
}

Expand Down
4 changes: 2 additions & 2 deletions src/grpc/desktop_client_mfa.rs
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ impl ClientMfaServer {
}

/// Validate JWT and extract client pubkey
fn parse_token(&self, token: &str) -> Result<String, Status> {
fn parse_token(token: &str) -> Result<String, Status> {
let claims = Claims::from_jwt(ClaimsType::DesktopClient, token).map_err(|err| {
error!("Failed to parse JWT token: {err:?}");
Status::invalid_argument("invalid token")
Expand Down Expand Up @@ -185,7 +185,7 @@ impl ClientMfaServer {
) -> Result<ClientMfaFinishResponse, Status> {
debug!("Finishing desktop client login: {request:?}");
// get pubkey from token
let pubkey = self.parse_token(&request.token)?;
let pubkey = Self::parse_token(&request.token)?;

// fetch login session
let Some(session) = self.sessions.get(&pubkey) else {
Expand Down
2 changes: 1 addition & 1 deletion src/grpc/gateway.rs
Original file line number Diff line number Diff line change
Expand Up @@ -561,7 +561,7 @@ impl gateway_service_server::GatewayService for GatewayServer {
.ok_or_else(|| {
Status::new(
Code::Internal,
format!("Network with id {} not found", network_id),
format!("Network with id {network_id} not found"),
)
})?;

Expand Down
11 changes: 3 additions & 8 deletions src/grpc/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -193,11 +193,8 @@ impl GatewayMap {
if let Some(state) = network_gateway_map.get_mut(&hostname) {
state.connected = false;
state.disconnected_at = Some(Utc::now().naive_utc());
state.send_disconnect_notification(pool)?;
debug!(
"Gateway {hostname} found in gateway map, current state: {:#?}",
state
);
state.send_disconnect_notification(pool);
debug!("Gateway {hostname} found in gateway map, current state: {state:#?}");
info!("Gateway {hostname} disconnected in network {network_id}");
return Ok(());
};
Expand Down Expand Up @@ -290,7 +287,7 @@ impl GatewayState {

/// Send gateway disconnected notification
/// Sends notification only if last notification time is bigger than specified in config
fn send_disconnect_notification(&mut self, pool: &DbPool) -> Result<(), GatewayMapError> {
fn send_disconnect_notification(&mut self, pool: &DbPool) {
debug!("Sending gateway disconnect email notification");
// Clone here because self doesn't live long enough
let name = self.name.clone();
Expand Down Expand Up @@ -327,8 +324,6 @@ impl GatewayState {
self.last_email_notification
);
};

Ok(())
}
}

Expand Down
4 changes: 2 additions & 2 deletions src/handlers/forward_auth.rs
Original file line number Diff line number Diff line change
Expand Up @@ -84,10 +84,10 @@ pub async fn forward_auth(
}
// If no session cookie provided redirect to login
info!("Valid session not found, redirecting to login page");
login_redirect(headers).await
login_redirect(headers)
}

async fn login_redirect(headers: ForwardAuthHeaders) -> Result<ForwardAuthResponse, WebError> {
fn login_redirect(headers: ForwardAuthHeaders) -> Result<ForwardAuthResponse, WebError> {
let server_url = &server_config().url; // prepare redirect URL for login page
let mut location = server_url.join("/auth/login").map_err(|err| {
error!("Failed to prepare redirect URL: {err}");
Expand Down
14 changes: 7 additions & 7 deletions src/handlers/openid_flow.rs
Original file line number Diff line number Diff line change
Expand Up @@ -334,10 +334,10 @@ fn redirect_to<T: AsRef<str>>(

/// Helper function to redirect unauthorized user to login page
/// and store information about OpenID authorize url in cookie to redirect later
async fn login_redirect(
fn login_redirect(
data: &AuthenticationRequest,
private_cookies: PrivateCookieJar,
) -> Result<(StatusCode, HeaderMap, PrivateCookieJar), WebError> {
) -> (StatusCode, HeaderMap, PrivateCookieJar) {
let config = server_config();
let base_url = config.url.join("api/v1/oauth/authorize").unwrap();
let cookie = Cookie::build((
Expand All @@ -358,7 +358,7 @@ async fn login_redirect(
.same_site(SameSite::Lax)
.http_only(true)
.max_age(Duration::minutes(10));
Ok(redirect_to("/login", private_cookies.add(cookie)))
redirect_to("/login", private_cookies.add(cookie))
}

/// Authorization Endpoint
Expand Down Expand Up @@ -400,7 +400,7 @@ pub async fn authorization(
if session.expired() {
info!("Session {} for user id {} has expired, redirecting to login", session.id, session.user_id);
let _result = session.delete(&appstate.pool).await;
login_redirect(&data, private_cookies).await
Ok(login_redirect(&data, private_cookies))
} else {
let user = User::find_by_id(&appstate.pool, session.user_id)
.await?
Expand All @@ -415,7 +415,7 @@ pub async fn authorization(
"MFA not verified for user id {}, redirecting to login",
session.user_id
);
return login_redirect(&data, private_cookies).await;
return Ok(login_redirect(&data, private_cookies));
}

// If session is present check if app is in user authorized apps.
Expand Down Expand Up @@ -462,13 +462,13 @@ pub async fn authorization(
"Session {} not found, redirecting to login page",
session_cookie.value()
);
login_redirect(&data, private_cookies).await
Ok(login_redirect(&data, private_cookies))
}

// If no session cookie provided redirect to login
} else {
info!("Session cookie not provided, redirecting to login page");
login_redirect(&data, private_cookies).await
Ok(login_redirect(&data, private_cookies))
};
}
}
Expand Down
52 changes: 22 additions & 30 deletions src/handlers/wireguard.rs
Original file line number Diff line number Diff line change
Expand Up @@ -64,11 +64,6 @@ pub struct MappedDevices {
pub devices: Vec<MappedDevice>,
}

#[derive(Serialize)]
struct ConnectionInfo {
connected: bool,
}

#[derive(Deserialize)]
pub struct ImportNetworkData {
pub name: String,
Expand Down Expand Up @@ -425,32 +420,29 @@ pub async fn add_user_devices(
});
}

match WireguardNetwork::find_by_id(&appstate.pool, network_id).await? {
Some(network) => {
// wrap loop in transaction to abort if a device is invalid
let mut transaction = appstate.pool.begin().await?;
let events = network
.handle_mapped_devices(&mut transaction, mapped_devices)
.await?;
appstate.send_multiple_wireguard_events(events);
transaction.commit().await?;

info!(
"User {} mapped {device_count} devices for {network_id} network",
user.username,
);
if let Some(network) = WireguardNetwork::find_by_id(&appstate.pool, network_id).await? {
// wrap loop in transaction to abort if a device is invalid
let mut transaction = appstate.pool.begin().await?;
let events = network
.handle_mapped_devices(&mut transaction, mapped_devices)
.await?;
appstate.send_multiple_wireguard_events(events);
transaction.commit().await?;

Ok(ApiResponse {
json: json!({}),
status: StatusCode::CREATED,
})
}
None => {
error!("Failed to map devices, network {network_id} not found");
Err(WebError::ObjectNotFound(format!(
"Network {network_id} not found"
)))
}
info!(
"User {} mapped {device_count} devices for {network_id} network",
user.username,
);

Ok(ApiResponse {
json: json!({}),
status: StatusCode::CREATED,
})
} else {
error!("Failed to map devices, network {network_id} not found");
Err(WebError::ObjectNotFound(format!(
"Network {network_id} not found"
)))
}
}

Expand Down
Loading
Loading