Skip to content

Commit

Permalink
feat(rust/driver/snowflake)!: return a Result from `Builder::from_e…
Browse files Browse the repository at this point in the history
…nv` when parsing fails (#2334)

As suggested in
#2207 (comment)
the `Builder::from_env` methods should return a result when parsing
fails.
  • Loading branch information
mbrobbel authored Nov 25, 2024
1 parent 8de65c8 commit 5da710b
Show file tree
Hide file tree
Showing 7 changed files with 184 additions and 113 deletions.
4 changes: 2 additions & 2 deletions rust/driver/snowflake/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -41,10 +41,10 @@ use arrow_array::{cast::AsArray, types::Decimal128Type};
let mut driver = Driver::try_load()?;
// Construct a database using environment variables
let mut database = database::Builder::from_env().build(&mut driver)?;
let mut database = database::Builder::from_env()?.build(&mut driver)?;
// Create a connection to the database
let mut connection = connection::Builder::from_env().build(&mut database)?;
let mut connection = connection::Builder::from_env()?.build(&mut database)?;
// Construct a statement to execute a query
let mut statement = connection.new_statement()?;
Expand Down
36 changes: 36 additions & 0 deletions rust/driver/snowflake/src/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,11 @@
//!
use std::iter::{Chain, Flatten};
#[cfg(feature = "env")]
use std::{env, error::Error as StdError};

#[cfg(feature = "env")]
use adbc_core::error::{Error, Status};
use adbc_core::options::OptionValue;

/// An iterator over the builder options.
Expand Down Expand Up @@ -48,3 +52,35 @@ impl<T, const COUNT: usize> Iterator for BuilderIter<T, COUNT> {
self.0.next()
}
}

#[cfg(feature = "env")]
/// Attempt to read the environment variable with the given `key`, parsing it
/// using the provided `parse` function.
///
/// Returns
///
/// - `Ok(None)` when the env variable is not set.
/// - `Ok(Some(T))` when the env variable is set and the parser succeeds.
/// - `Err(Error)` when the env variable is set and the parse fails.
pub(crate) fn env_parse<T>(
key: &str,
parse: impl FnOnce(&str) -> Result<T, Error>,
) -> Result<Option<T>, Error> {
env::var(key).ok().as_deref().map(parse).transpose()
}

#[cfg(feature = "env")]
/// Attempt to read the environment variable with the given `key`, parsing it
/// using the provided `parse` function, mapping the parse result to an
/// [`Error`] with [`Status::InvalidArguments`].
pub(crate) fn env_parse_map_err<T, E: StdError>(
key: &str,
parse: impl FnOnce(&str) -> Result<T, E>,
) -> Result<Option<T>, Error> {
env::var(key)
.ok()
.as_deref()
.map(parse)
.transpose()
.map_err(|err| Error::with_message_and_status(err.to_string(), Status::InvalidArguments))
}
21 changes: 11 additions & 10 deletions rust/driver/snowflake/src/connection/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,6 @@
//!
//!
#[cfg(feature = "env")]
use std::env;
use std::fmt;

use adbc_core::{
Expand All @@ -30,7 +28,7 @@ use adbc_core::{
};

#[cfg(feature = "env")]
use crate::database;
use crate::{builder::env_parse_map_err, database};
use crate::{builder::BuilderIter, Connection, Database};

/// A builder for [`Connection`].
Expand Down Expand Up @@ -61,18 +59,21 @@ impl Builder {

/// Construct a builder, setting values based on values of the
/// configuration environment variables.
pub fn from_env() -> Self {
///
/// # Error
///
/// Returns an error when environment variables are set but their values
/// fail to parse.
pub fn from_env() -> Result<Self> {
#[cfg(feature = "dotenv")]
let _ = dotenvy::dotenv();

let use_high_precision = env::var(Self::USE_HIGH_PRECISION_ENV)
.ok()
.as_deref()
.and_then(|value| value.parse().ok());
Self {
let use_high_precision = env_parse_map_err(Self::USE_HIGH_PRECISION_ENV, str::parse)?;

Ok(Self {
use_high_precision,
..Default::default()
}
})
}
}

Expand Down
182 changes: 96 additions & 86 deletions rust/driver/snowflake/src/database/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,9 +31,12 @@ use adbc_core::{
};
use url::{Host, Url};

#[cfg(feature = "env")]
use crate::duration::parse_duration;
use crate::{builder::BuilderIter, Database, Driver};
#[cfg(feature = "env")]
use crate::{
builder::{env_parse, env_parse_map_err},
duration::parse_duration,
};

/// Authentication types.
#[derive(Copy, Clone, Debug, Default)]
Expand Down Expand Up @@ -83,7 +86,15 @@ impl str::FromStr for AuthType {
"auth_jwt" => Ok(Self::Jwt),
"auth_mfa" => Ok(Self::UsernamePasswordMFA),
_ => Err(Error::with_message_and_status(
format!("invalid auth type: {s}"),
format!(
"invalid auth type: {s} (possible values: {}, {}, {}, {}, {}, {})",
Self::Snowflake,
Self::OAuth,
Self::ExternalBrowser,
Self::Okta,
Self::Jwt,
Self::UsernamePasswordMFA
),
Status::InvalidArguments,
)),
}
Expand Down Expand Up @@ -122,7 +133,11 @@ impl str::FromStr for Protocol {
"https" | "HTTPS" => Ok(Self::Https),
"http" | "HTTP" => Ok(Self::Http),
_ => Err(Error::with_message_and_status(
format!("invalid protocol type: {s}"),
format!(
"invalid protocol type: {s} (possible values: {}, {})",
Self::Https,
Self::Http
),
Status::InvalidArguments,
)),
}
Expand Down Expand Up @@ -181,7 +196,16 @@ impl str::FromStr for LogLevel {
"fatal" => Ok(Self::Fatal),
"off" => Ok(Self::Off),
_ => Err(Error::with_message_and_status(
format!("invalid log level: {s}"),
format!(
"invalid log level: {s} (possible values: {}, {}, {}, {}, {}, {}, {})",
Self::Trace,
Self::Debug,
Self::Info,
Self::Warn,
Self::Error,
Self::Fatal,
Self::Off
),
Status::InvalidArguments,
)),
}
Expand Down Expand Up @@ -453,14 +477,16 @@ impl Builder {

/// Construct a builder, setting values based on values of the
/// configuration environment variables.
pub fn from_env() -> Self {
///
/// # Error
///
/// Returns an error when environment variables are set but their values
/// fail to parse.
pub fn from_env() -> Result<Self> {
#[cfg(feature = "dotenv")]
let _ = dotenvy::dotenv();

let uri = env::var(Self::URI_ENV)
.ok()
.as_deref()
.and_then(|value| Url::parse(value).ok());
let uri = env_parse_map_err(Self::URI_ENV, Url::parse)?;
let username = env::var(Self::USERNAME_ENV).ok();
let password = env::var(Self::PASSWORD_ENV).ok();
let database = env::var(Self::DATABASE_ENV).ok();
Expand All @@ -469,86 +495,34 @@ impl Builder {
let role = env::var(Self::ROLE_ENV).ok();
let region = env::var(Self::REGION_ENV).ok();
let account = env::var(Self::ACCOUNT_ENV).ok();
let protocol = env::var(Self::PROTOCOL_ENV)
.ok()
.as_deref()
.and_then(|value| value.parse().ok());
let port = env::var(Self::PORT_ENV)
.ok()
.and_then(|value| value.parse().ok());
let host = env::var(Self::HOST_ENV)
.ok()
.as_deref()
.and_then(|value| Host::parse(value).ok());
let auth_type = env::var(Self::AUTH_TYPE_ENV)
.ok()
.and_then(|value| value.parse().ok());
let login_timeout = env::var(Self::LOGIN_TIMEOUT_ENV)
.ok()
.as_deref()
.and_then(|value| parse_duration(value).ok());
let request_timeout = env::var(Self::REQUEST_TIMEOUT_ENV)
.ok()
.as_deref()
.and_then(|value| parse_duration(value).ok());
let jwt_expire_timeout = env::var(Self::JWT_EXPIRE_TIMEOUT_ENV)
.ok()
.as_deref()
.and_then(|value| parse_duration(value).ok());
let client_timeout = env::var(Self::CLIENT_TIMEOUT_ENV)
.ok()
.as_deref()
.and_then(|value| parse_duration(value).ok());
let use_high_precision = env::var(Self::USE_HIGH_PRECISION_ENV)
.ok()
.as_deref()
.and_then(|value| value.parse().ok());
let protocol = env_parse(Self::PROTOCOL_ENV, str::parse)?;
let port = env_parse_map_err(Self::PORT_ENV, str::parse)?;
let host = env_parse_map_err(Self::HOST_ENV, Host::parse)?;
let auth_type = env_parse(Self::AUTH_TYPE_ENV, str::parse)?;
let login_timeout = env_parse(Self::LOGIN_TIMEOUT_ENV, parse_duration)?;
let request_timeout = env_parse(Self::REQUEST_TIMEOUT_ENV, parse_duration)?;
let jwt_expire_timeout = env_parse(Self::JWT_EXPIRE_TIMEOUT_ENV, parse_duration)?;
let client_timeout = env_parse(Self::CLIENT_TIMEOUT_ENV, parse_duration)?;
let use_high_precision = env_parse_map_err(Self::USE_HIGH_PRECISION_ENV, str::parse)?;
let application_name = env::var(Self::APPLICATION_NAME_ENV).ok();
let ssl_skip_verify = env::var(Self::SSL_SKIP_VERIFY_ENV)
.ok()
.as_deref()
.and_then(|value| value.parse().ok());
let ocsp_fail_open_mode = env::var(Self::OCSP_FAIL_OPEN_MODE_ENV)
.ok()
.as_deref()
.and_then(|value| value.parse().ok());
let ssl_skip_verify = env_parse_map_err(Self::SSL_SKIP_VERIFY_ENV, str::parse)?;
let ocsp_fail_open_mode = env_parse_map_err(Self::OCSP_FAIL_OPEN_MODE_ENV, str::parse)?;
let auth_token = env::var(Self::AUTH_TOKEN_ENV).ok();
let auth_okta_url = env::var(Self::AUTH_OKTA_URL_ENV)
.ok()
.as_deref()
.and_then(|value| Url::parse(value).ok());
let keep_session_alive = env::var(Self::OCSP_FAIL_OPEN_MODE_ENV)
.ok()
.as_deref()
.and_then(|value| value.parse().ok());
let jwt_private_key = env::var(Self::JWT_PRIVATE_KEY_ENV)
.ok()
.as_deref()
.and_then(|value| value.parse().ok());
let auth_okta_url = env_parse_map_err(Self::AUTH_OKTA_URL_ENV, Url::parse)?;
let keep_session_alive = env_parse_map_err(Self::OCSP_FAIL_OPEN_MODE_ENV, str::parse)?;
let jwt_private_key = env_parse_map_err(Self::JWT_PRIVATE_KEY_ENV, str::parse)?;
let jwt_private_key_pkcs8_value = env::var(Self::JWT_PRIVATE_KEY_PKCS8_VALUE_ENV).ok();
let jwt_private_key_pkcs8_password =
env::var(Self::JWT_PRIVATE_KEY_PKCS8_PASSWORD_ENV).ok();
let disable_telemetry = env::var(Self::DISABLE_TELEMETRY_ENV)
.ok()
.as_deref()
.and_then(|value| value.parse().ok());
let log_tracing = env::var(Self::LOG_TRACING_ENV)
.ok()
.as_deref()
.and_then(|value| value.parse().ok());
let client_config_file = env::var(Self::CLIENT_CONFIG_FILE_ENV)
.ok()
.as_deref()
.and_then(|value| value.parse().ok());
let client_cache_mfa_token = env::var(Self::CLIENT_CACHE_MFA_TOKEN_ENV)
.ok()
.as_deref()
.and_then(|value| value.parse().ok());
let client_store_temp_creds = env::var(Self::CLIENT_STORE_TEMP_CREDS_ENV)
.ok()
.as_deref()
.and_then(|value| value.parse().ok());
Self {
let disable_telemetry = env_parse_map_err(Self::DISABLE_TELEMETRY_ENV, str::parse)?;
let log_tracing = env_parse(Self::LOG_TRACING_ENV, str::parse)?;
let client_config_file = env_parse_map_err(Self::CLIENT_CONFIG_FILE_ENV, str::parse)?;
let client_cache_mfa_token =
env_parse_map_err(Self::CLIENT_CACHE_MFA_TOKEN_ENV, str::parse)?;
let client_store_temp_creds =
env_parse_map_err(Self::CLIENT_STORE_TEMP_CREDS_ENV, str::parse)?;

Ok(Self {
uri,
username,
password,
Expand Down Expand Up @@ -582,7 +556,7 @@ impl Builder {
client_cache_mfa_token,
client_store_temp_creds,
..Default::default()
}
})
}
}

Expand Down Expand Up @@ -1076,3 +1050,39 @@ impl IntoIterator for Builder {
)
}
}

#[cfg(test)]
#[cfg(feature = "env")]
mod tests {
use std::env;

use adbc_core::error::Status;

use super::*;

#[test]
fn from_env_parse_error() {
// Set a value that fails to parse to a LogLevel
env::set_var(Builder::LOG_TRACING_ENV, "warning");
let result = Builder::from_env();
assert!(result.is_err());
assert_eq!(
result.unwrap_err(),
Error::with_message_and_status(
"invalid log level: warning (possible values: trace, debug, info, warn, error, fatal, off)",
Status::InvalidArguments
)
);
// Fix it to move on
env::set_var(Builder::LOG_TRACING_ENV, "warn");

// Set a value that fails to parse to a duration
env::set_var(Builder::LOGIN_TIMEOUT_ENV, "forever");
let result = Builder::from_env();
assert!(result.is_err());
assert_eq!(
result.unwrap_err(),
Error::with_message_and_status("invalid duration (valid durations are a sequence of decimal numbers, each with optional fraction and a unit suffix, such as 300ms, 1.5h, 2h45m, valid time units are ns, us, ms, s, m, h)", Status::InvalidArguments)
);
}
}
Loading

0 comments on commit 5da710b

Please sign in to comment.