diff --git a/crates/aws/src/lib.rs b/crates/aws/src/lib.rs index ee7f222701..c062e47334 100644 --- a/crates/aws/src/lib.rs +++ b/crates/aws/src/lib.rs @@ -771,7 +771,9 @@ mod tests { let factory = S3LogStoreFactory::default(); let store = InMemory::new(); let url = Url::parse("s3://test-bucket").unwrap(); - std::env::remove_var(crate::constants::AWS_S3_LOCKING_PROVIDER); + unsafe { + std::env::remove_var(crate::constants::AWS_S3_LOCKING_PROVIDER); + } let logstore = factory .with_options(Arc::new(store), &url, &StorageOptions::from(HashMap::new())) .unwrap(); diff --git a/crates/aws/src/storage.rs b/crates/aws/src/storage.rs index b2ad64d0c1..019071a60f 100644 --- a/crates/aws/src/storage.rs +++ b/crates/aws/src/storage.rs @@ -586,6 +586,7 @@ mod tests { ScopedEnv::run(|| { clear_env_of_aws_keys(); std::env::remove_var(constants::AWS_ENDPOINT_URL); + let options = S3StorageOptions::from_map(&hashmap! { constants::AWS_REGION.to_string() => "eu-west-1".to_string(), constants::AWS_ACCESS_KEY_ID.to_string() => "test".to_string(), @@ -767,12 +768,10 @@ mod tests { ScopedEnv::run(|| { clear_env_of_aws_keys(); let raw_options = hashmap! {}; - std::env::set_var(constants::AWS_ACCESS_KEY_ID, "env_key"); std::env::set_var(constants::AWS_ENDPOINT_URL, "env_key"); std::env::set_var(constants::AWS_SECRET_ACCESS_KEY, "env_key"); std::env::set_var(constants::AWS_REGION, "env_key"); - let combined_options = S3ObjectStoreFactory {}.with_env_s3(&StorageOptions(raw_options)); @@ -795,7 +794,6 @@ mod tests { "AWS_SECRET_ACCESS_KEY".to_string() => "options_key".to_string(), "AWS_REGION".to_string() => "options_key".to_string() }; - std::env::set_var("aws_access_key_id", "env_key"); std::env::set_var("aws_endpoint", "env_key"); std::env::set_var("aws_secret_access_key", "env_key"); diff --git a/crates/aws/tests/common.rs b/crates/aws/tests/common.rs index dfa2a9cd51..e32522e2d3 100644 --- a/crates/aws/tests/common.rs +++ b/crates/aws/tests/common.rs @@ -3,7 +3,7 @@ use deltalake_aws::constants; use deltalake_aws::register_handlers; use deltalake_aws::storage::*; use deltalake_test::utils::*; -use rand::Rng; +use rand::random; use std::process::{Command, ExitStatus, Stdio}; #[derive(Clone, Debug)] @@ -43,14 +43,16 @@ impl StorageIntegration for S3Integration { fn prepare_env(&self) { set_env_if_not_set( constants::LOCK_TABLE_KEY_NAME, - format!("delta_log_it_{}", rand::thread_rng().gen::()), + format!("delta_log_it_{}", random::()), ); match std::env::var(s3_constants::AWS_ENDPOINT_URL).ok() { - Some(endpoint_url) if endpoint_url.to_lowercase() == "none" => { + Some(endpoint_url) if endpoint_url.to_lowercase() == "none" => unsafe { std::env::remove_var(s3_constants::AWS_ENDPOINT_URL) - } + }, Some(_) => (), - None => std::env::set_var(s3_constants::AWS_ENDPOINT_URL, "http://localhost:4566"), + None => unsafe { + std::env::set_var(s3_constants::AWS_ENDPOINT_URL, "http://localhost:4566") + }, } set_env_if_not_set(s3_constants::AWS_ACCESS_KEY_ID, "deltalake"); set_env_if_not_set(s3_constants::AWS_SECRET_ACCESS_KEY, "weloverust"); diff --git a/crates/catalog-glue/src/lib.rs b/crates/catalog-glue/src/lib.rs index e9ef449be2..089ce56ce2 100644 --- a/crates/catalog-glue/src/lib.rs +++ b/crates/catalog-glue/src/lib.rs @@ -60,6 +60,8 @@ const PLACEHOLDER_SUFFIX: &str = "-__PLACEHOLDER__"; #[async_trait::async_trait] impl DataCatalog for GlueDataCatalog { + type Error = DataCatalogError; + /// Get the table storage location from the Glue Data Catalog async fn get_table_storage_location( &self, diff --git a/crates/catalog-unity/Cargo.toml b/crates/catalog-unity/Cargo.toml new file mode 100644 index 0000000000..8a0827386b --- /dev/null +++ b/crates/catalog-unity/Cargo.toml @@ -0,0 +1,40 @@ +[package] +name = "deltalake-catalog-unity" +version = "0.6.0" +authors.workspace = true +keywords.workspace = true +readme.workspace = true +edition.workspace = true +homepage.workspace = true +description.workspace = true +license.workspace = true +repository.workspace = true +rust-version.workspace = true + +[dependencies] +async-trait.workspace = true +tokio.workspace = true +serde.workspace = true +serde_json.workspace = true +thiserror.workspace = true +deltalake-core = { version = "0.22", path = "../core" } +reqwest = { version = "0.12", default-features = false, features = ["rustls-tls", "json", "http2"] } +reqwest-retry = "0.7" +reqwest-middleware = "0.4.0" +rand = "0.8" +futures = "0.3" +chrono = "0.4" +dashmap = "6" +tracing = "0.1" +datafusion = { version = "43", optional = true } +datafusion-common = { version = "43", optional = true } + +[dev-dependencies] +tokio = { version = "1", features = ["macros", "rt-multi-thread"] } +tempfile = "3" +httpmock = { version = "0.8.0-alpha.1" } + +[features] +default = [] +datafusion = ["dep:datafusion", "datafusion-common"] + diff --git a/crates/core/src/data_catalog/client/backoff.rs b/crates/catalog-unity/src/client/backoff.rs similarity index 100% rename from crates/core/src/data_catalog/client/backoff.rs rename to crates/catalog-unity/src/client/backoff.rs diff --git a/crates/catalog-unity/src/client/mock_server.rs b/crates/catalog-unity/src/client/mock_server.rs new file mode 100644 index 0000000000..e69de29bb2 diff --git a/crates/core/src/data_catalog/client/mod.rs b/crates/catalog-unity/src/client/mod.rs similarity index 87% rename from crates/core/src/data_catalog/client/mod.rs rename to crates/catalog-unity/src/client/mod.rs index c6cd838076..e88d0fa040 100644 --- a/crates/core/src/data_catalog/client/mod.rs +++ b/crates/catalog-unity/src/client/mod.rs @@ -1,15 +1,18 @@ //! Generic utilities reqwest based Catalog implementations pub mod backoff; -#[cfg(test)] -pub mod mock_server; #[allow(unused)] pub mod pagination; pub mod retry; pub mod token; +use crate::client::retry::RetryConfig; +use crate::UnityCatalogError; +use deltalake_core::data_catalog::DataCatalogResult; use reqwest::header::{HeaderMap, HeaderValue}; -use reqwest::{Client, ClientBuilder, Proxy}; +use reqwest::{ClientBuilder, Proxy}; +use reqwest_middleware::ClientWithMiddleware; +use reqwest_retry::{policies::ExponentialBackoff, RetryTransientMiddleware}; use std::time::Duration; fn map_client_error(e: reqwest::Error) -> super::DataCatalogError { @@ -38,6 +41,7 @@ pub struct ClientOptions { http2_keep_alive_while_idle: bool, http1_only: bool, http2_only: bool, + retry_config: Option, } impl ClientOptions { @@ -164,7 +168,12 @@ impl ClientOptions { self } - pub(crate) fn client(&self) -> super::DataCatalogResult { + pub fn with_retry_config(mut self, cfg: RetryConfig) -> Self { + self.retry_config = Some(cfg); + self + } + + pub(crate) fn client(&self) -> DataCatalogResult { let mut builder = ClientBuilder::new(); match &self.user_agent { @@ -221,9 +230,19 @@ impl ClientOptions { builder = builder.danger_accept_invalid_certs(self.allow_insecure) } - builder + let inner_client = builder .https_only(!self.allow_http) .build() - .map_err(map_client_error) + .map_err(UnityCatalogError::from)?; + let retry_policy = self + .retry_config + .as_ref() + .map(|retry| retry.into()) + .unwrap_or(ExponentialBackoff::builder().build_with_max_retries(3)); + + let middleware = RetryTransientMiddleware::new_with_policy(retry_policy); + Ok(reqwest_middleware::ClientBuilder::new(inner_client) + .with(middleware) + .build()) } } diff --git a/crates/core/src/data_catalog/client/pagination.rs b/crates/catalog-unity/src/client/pagination.rs similarity index 97% rename from crates/core/src/data_catalog/client/pagination.rs rename to crates/catalog-unity/src/client/pagination.rs index a5225237b4..630ef2aace 100644 --- a/crates/core/src/data_catalog/client/pagination.rs +++ b/crates/catalog-unity/src/client/pagination.rs @@ -3,7 +3,7 @@ use std::future::Future; use futures::Stream; -use crate::data_catalog::DataCatalogResult; +use deltalake_core::data_catalog::DataCatalogResult; /// Takes a paginated operation `op` that when called with: /// diff --git a/crates/catalog-unity/src/client/retry.rs b/crates/catalog-unity/src/client/retry.rs new file mode 100644 index 0000000000..9b3828274e --- /dev/null +++ b/crates/catalog-unity/src/client/retry.rs @@ -0,0 +1,118 @@ +//! A shared HTTP client implementation incorporating retries + +use super::backoff::BackoffConfig; +use deltalake_core::DataCatalogError; +use reqwest::StatusCode; +use reqwest_retry::policies::ExponentialBackoff; +use std::time::Duration; + +/// Retry request error +#[derive(Debug)] +pub struct RetryError { + retries: usize, + message: String, + source: Option, +} + +impl std::fmt::Display for RetryError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!( + f, + "response error \"{}\", after {} retries", + self.message, self.retries + )?; + if let Some(source) = &self.source { + write!(f, ": {source}")?; + } + Ok(()) + } +} + +impl std::error::Error for RetryError { + fn source(&self) -> Option<&(dyn std::error::Error + 'static)> { + self.source.as_ref().map(|e| e as _) + } +} + +impl RetryError { + /// Returns the status code associated with this error if any + pub fn status(&self) -> Option { + self.source.as_ref().and_then(|e| e.status()) + } +} + +impl From for std::io::Error { + fn from(err: RetryError) -> Self { + use std::io::ErrorKind; + match (&err.source, err.status()) { + (Some(source), _) if source.is_builder() || source.is_request() => { + Self::new(ErrorKind::InvalidInput, err) + } + (_, Some(StatusCode::NOT_FOUND)) => Self::new(ErrorKind::NotFound, err), + (_, Some(StatusCode::BAD_REQUEST)) => Self::new(ErrorKind::InvalidInput, err), + (Some(source), None) if source.is_timeout() => Self::new(ErrorKind::TimedOut, err), + (Some(source), None) if source.is_connect() => Self::new(ErrorKind::NotConnected, err), + _ => Self::new(ErrorKind::Other, err), + } + } +} + +impl From for DataCatalogError { + fn from(value: RetryError) -> Self { + DataCatalogError::Generic { + catalog: "", + source: Box::new(value), + } + } +} + +/// Error retrying http requests +pub type Result = std::result::Result; + +/// Contains the configuration for how to respond to server errors +/// +/// By default, they will be retried up to some limit, using exponential +/// backoff with jitter. See [`BackoffConfig`] for more information +/// +#[derive(Debug, Clone)] +pub struct RetryConfig { + /// The backoff configuration + pub backoff: BackoffConfig, + + /// The maximum number of times to retry a request + /// + /// Set to 0 to disable retries + pub max_retries: usize, + + /// The maximum length of time from the initial request + /// after which no further retries will be attempted + /// + /// This not only bounds the length of time before a server + /// error will be surfaced to the application, but also bounds + /// the length of time a request's credentials must remain valid. + /// + /// As requests are retried without renewing credentials or + /// regenerating request payloads, this number should be kept + /// below 5 minutes to avoid errors due to expired credentials + /// and/or request payloads + pub retry_timeout: Duration, +} + +impl Default for RetryConfig { + fn default() -> Self { + Self { + backoff: Default::default(), + max_retries: 10, + retry_timeout: Duration::from_secs(3 * 60), + } + } +} + +impl From<&RetryConfig> for ExponentialBackoff { + fn from(val: &RetryConfig) -> ExponentialBackoff { + ExponentialBackoff::builder() + .retry_bounds(val.backoff.init_backoff, val.backoff.max_backoff) + .base(val.backoff.base as u32) + .build_with_max_retries(val.max_retries as u32) + } +} diff --git a/crates/core/src/data_catalog/client/token.rs b/crates/catalog-unity/src/client/token.rs similarity index 100% rename from crates/core/src/data_catalog/client/token.rs rename to crates/catalog-unity/src/client/token.rs diff --git a/crates/core/src/data_catalog/unity/credential.rs b/crates/catalog-unity/src/credential.rs similarity index 83% rename from crates/core/src/data_catalog/unity/credential.rs rename to crates/catalog-unity/src/credential.rs index e0a833f182..b6b21b47eb 100644 --- a/crates/core/src/data_catalog/unity/credential.rs +++ b/crates/catalog-unity/src/credential.rs @@ -4,13 +4,12 @@ use std::process::Command; use std::time::{Duration, Instant}; use reqwest::header::{HeaderValue, ACCEPT}; -use reqwest::{Client, Method}; +use reqwest::Method; +use reqwest_middleware::ClientWithMiddleware; use serde::Deserialize; use super::UnityCatalogError; -use crate::data_catalog::client::retry::{RetryConfig, RetryExt}; -use crate::data_catalog::client::token::{TemporaryToken, TokenCache}; -use crate::data_catalog::DataCatalogResult; +use crate::client::token::{TemporaryToken, TokenCache}; // https://learn.microsoft.com/en-us/azure/databricks/dev-tools/api/latest/authentication @@ -37,9 +36,8 @@ pub trait TokenCredential: std::fmt::Debug + Send + Sync + 'static { /// get the token async fn fetch_token( &self, - client: &Client, - retry: &RetryConfig, - ) -> DataCatalogResult>; + client: &ClientWithMiddleware, + ) -> Result, UnityCatalogError>; } /// Provides credentials for use when signing requests @@ -95,9 +93,8 @@ impl TokenCredential for ClientSecretOAuthProvider { /// Fetch a token async fn fetch_token( &self, - client: &Client, - retry: &RetryConfig, - ) -> DataCatalogResult> { + client: &ClientWithMiddleware, + ) -> Result, UnityCatalogError> { let response: TokenResponse = client .request(Method::POST, &self.token_url) .header(ACCEPT, HeaderValue::from_static(CONTENT_TYPE_JSON)) @@ -107,10 +104,12 @@ impl TokenCredential for ClientSecretOAuthProvider { ("scope", &format!("{}/.default", DATABRICKS_RESOURCE_SCOPE)), ("grant_type", "client_credentials"), ]) - .send_retry(retry) - .await? + .send() + .await + .map_err(UnityCatalogError::from)? .json() - .await?; + .await + .map_err(UnityCatalogError::from)?; Ok(TemporaryToken { token: response.access_token, @@ -167,9 +166,8 @@ impl TokenCredential for AzureCliCredential { /// Fetch a token async fn fetch_token( &self, - _client: &Client, - _retry: &RetryConfig, - ) -> DataCatalogResult> { + _client: &ClientWithMiddleware, + ) -> Result, UnityCatalogError> { // on window az is a cmd and it should be called like this // see https://doc.rust-lang.org/nightly/std/process/struct.Command.html let program = if cfg!(target_os = "windows") { @@ -281,9 +279,8 @@ impl TokenCredential for WorkloadIdentityOAuthProvider { /// Fetch a token async fn fetch_token( &self, - client: &Client, - retry: &RetryConfig, - ) -> DataCatalogResult> { + client: &ClientWithMiddleware, + ) -> Result, UnityCatalogError> { let token_str = std::fs::read_to_string(&self.federated_token_file) .map_err(|_| UnityCatalogError::FederatedTokenFile)?; @@ -301,10 +298,12 @@ impl TokenCredential for WorkloadIdentityOAuthProvider { ("scope", &format!("{}/.default", DATABRICKS_RESOURCE_SCOPE)), ("grant_type", "client_credentials"), ]) - .send_retry(retry) - .await? + .send() + .await + .map_err(UnityCatalogError::from)? .json() - .await?; + .await + .map_err(UnityCatalogError::from)?; Ok(TemporaryToken { token: response.access_token, @@ -340,7 +339,7 @@ pub struct ImdsManagedIdentityOAuthProvider { client_id: Option, object_id: Option, msi_res_id: Option, - client: Client, + client: ClientWithMiddleware, } impl ImdsManagedIdentityOAuthProvider { @@ -350,7 +349,7 @@ impl ImdsManagedIdentityOAuthProvider { object_id: Option, msi_res_id: Option, msi_endpoint: Option, - client: Client, + client: ClientWithMiddleware, ) -> Self { let msi_endpoint = msi_endpoint .unwrap_or_else(|| "http://169.254.169.254/metadata/identity/oauth2/token".to_owned()); @@ -370,9 +369,8 @@ impl TokenCredential for ImdsManagedIdentityOAuthProvider { /// Fetch a token async fn fetch_token( &self, - _client: &Client, - retry: &RetryConfig, - ) -> DataCatalogResult> { + _client: &ClientWithMiddleware, + ) -> Result, UnityCatalogError> { let resource_scope = format!("{}/.default", DATABRICKS_RESOURCE_SCOPE); let mut query_items = vec![ ("api-version", MSI_API_VERSION), @@ -403,7 +401,13 @@ impl TokenCredential for ImdsManagedIdentityOAuthProvider { builder = builder.header("x-identity-header", val); }; - let response: MsiTokenResponse = builder.send_retry(retry).await?.json().await?; + let response: MsiTokenResponse = builder + .send() + .await + .map_err(UnityCatalogError::from)? + .json() + .await + .map_err(UnityCatalogError::from)?; Ok(TemporaryToken { token: response.access_token, @@ -415,39 +419,27 @@ impl TokenCredential for ImdsManagedIdentityOAuthProvider { #[cfg(test)] mod tests { use super::*; - use crate::data_catalog::client::mock_server::MockServer; - use futures::executor::block_on; - use hyper::body::to_bytes; - use hyper::{Body, Response}; - use reqwest::{Client, Method}; + use httpmock::prelude::*; + use reqwest::Client; use tempfile::NamedTempFile; #[tokio::test] async fn test_managed_identity() { - let server = MockServer::new(); + let server = MockServer::start_async().await; std::env::set_var(MSI_SECRET_ENV_KEY, "env-secret"); - let endpoint = server.url(); - let client = Client::new(); - let retry_config = RetryConfig::default(); - - // Test IMDS - server.push_fn(|req| { - assert_eq!(req.uri().path(), "/metadata/identity/oauth2/token"); - assert!(req.uri().query().unwrap().contains("client_id=client_id")); - assert_eq!(req.method(), &Method::GET); - let t = req - .headers() - .get("x-identity-header") - .unwrap() - .to_str() - .unwrap(); - assert_eq!(t, "env-secret"); - let t = req.headers().get("metadata").unwrap().to_str().unwrap(); - assert_eq!(t, "true"); - Response::new(Body::from( - r#" + let client = reqwest_middleware::ClientBuilder::new(Client::new()).build(); + + server + .mock_async(|when, then| { + when.path("/metadata/identity/oauth2/token") + .query_param("client_id", "client_id") + .method("GET") + .header("x-identity-header", "env-secret") + .header("metadata", "true"); + then.body( + r#" { "access_token": "TOKEN", "refresh_token": "", @@ -458,45 +450,40 @@ mod tests { "token_type": "Bearer" } "#, - )) - }); + ); + }) + .await; let credential = ImdsManagedIdentityOAuthProvider::new( Some("client_id".into()), None, None, - Some(format!("{endpoint}/metadata/identity/oauth2/token")), + Some(server.url("/metadata/identity/oauth2/token")), client.clone(), ); - let token = credential - .fetch_token(&client, &retry_config) - .await - .unwrap(); + let token = credential.fetch_token(&client).await.unwrap(); assert_eq!(&token.token, "TOKEN"); } #[tokio::test] async fn test_workload_identity() { - let server = MockServer::new(); + let server = MockServer::start_async().await; let tokenfile = NamedTempFile::new().unwrap(); let tenant = "tenant"; std::fs::write(tokenfile.path(), "federated-token").unwrap(); - let endpoint = server.url(); - let client = Client::new(); - let retry_config = RetryConfig::default(); - - // Test IMDS - server.push_fn(move |req| { - assert_eq!(req.uri().path(), format!("/{tenant}/oauth2/v2.0/token")); - assert_eq!(req.method(), &Method::POST); - let body = block_on(to_bytes(req.into_body())).unwrap(); - let body = String::from_utf8(body.to_vec()).unwrap(); - assert!(body.contains("federated-token")); - Response::new(Body::from( - r#" + let client = reqwest_middleware::ClientBuilder::new(Client::new()).build(); + + server + .mock_async(|when, then| { + when.path_includes(format!("/{tenant}/oauth2/v2.0/token")) + .method("POST") + .body_includes("federated-token"); + + then.body( + r#" { "access_token": "TOKEN", "refresh_token": "", @@ -507,20 +494,18 @@ mod tests { "token_type": "Bearer" } "#, - )) - }); + ); + }) + .await; let credential = WorkloadIdentityOAuthProvider::new( "client_id", tokenfile.path().to_str().unwrap(), tenant, - Some(endpoint.to_string()), + Some(server.url(format!("/{tenant}/oauth2/v2.0/token"))), ); - let token = credential - .fetch_token(&client, &retry_config) - .await - .unwrap(); + let token = credential.fetch_token(&client).await.unwrap(); assert_eq!(&token.token, "TOKEN"); } diff --git a/crates/core/src/data_catalog/unity/datafusion.rs b/crates/catalog-unity/src/datafusion.rs similarity index 97% rename from crates/core/src/data_catalog/unity/datafusion.rs rename to crates/catalog-unity/src/datafusion.rs index 3e32a3ad68..23339b0b16 100644 --- a/crates/core/src/data_catalog/unity/datafusion.rs +++ b/crates/catalog-unity/src/datafusion.rs @@ -11,10 +11,12 @@ use datafusion::datasource::TableProvider; use datafusion_common::DataFusionError; use tracing::error; -use super::models::{GetTableResponse, ListCatalogsResponse, ListTableSummariesResponse}; +use super::models::{ + GetTableResponse, ListCatalogsResponse, ListSchemasResponse, ListTableSummariesResponse, +}; use super::{DataCatalogResult, UnityCatalog}; -use crate::data_catalog::models::ListSchemasResponse; -use crate::DeltaTableBuilder; + +use deltalake_core::DeltaTableBuilder; /// In-memory list of catalogs populated by unity catalog #[derive(Debug)] @@ -56,10 +58,6 @@ impl CatalogProviderList for UnityCatalogList { self } - fn catalog_names(&self) -> Vec { - self.catalogs.iter().map(|c| c.key().clone()).collect() - } - fn register_catalog( &self, name: String, @@ -68,6 +66,10 @@ impl CatalogProviderList for UnityCatalogList { self.catalogs.insert(name, catalog) } + fn catalog_names(&self) -> Vec { + self.catalogs.iter().map(|c| c.key().clone()).collect() + } + fn catalog(&self, name: &str) -> Option> { self.catalogs.get(name).map(|c| c.value().clone()) } diff --git a/crates/catalog-unity/src/error.rs b/crates/catalog-unity/src/error.rs new file mode 100644 index 0000000000..67610c07fe --- /dev/null +++ b/crates/catalog-unity/src/error.rs @@ -0,0 +1,41 @@ +#[derive(thiserror::Error, Debug)] +pub enum UnityCatalogError { + /// A generic error qualified in the message + #[error("Error in {catalog} catalog: {source}")] + Generic { + /// Name of the catalog + catalog: &'static str, + /// Error message + source: Box, + }, + + /// A generic error qualified in the message + + #[error("{source}")] + Retry { + /// Error message + #[from] + source: crate::client::retry::RetryError, + }, + + #[error("Request error: {source}")] + + /// Error from reqwest library + RequestError { + /// The underlying reqwest_middleware::Error + #[from] + source: reqwest::Error, + }, + + /// Error caused by missing environment variable for Unity Catalog. + #[error("Missing Unity Catalog environment variable: {var_name}")] + MissingEnvVar { + /// Variable name + var_name: String, + }, + + /// Error caused by invalid access token value + + #[error("Invalid Databricks personal access token")] + InvalidAccessToken, +} diff --git a/crates/core/src/data_catalog/unity/mod.rs b/crates/catalog-unity/src/lib.rs similarity index 86% rename from crates/core/src/data_catalog/unity/mod.rs rename to crates/catalog-unity/src/lib.rs index e9de725923..f5b8d1d08a 100644 --- a/crates/core/src/data_catalog/unity/mod.rs +++ b/crates/catalog-unity/src/lib.rs @@ -1,27 +1,30 @@ //! Databricks Unity Catalog. -//! -//! This module is gated behind the "unity-experimental" feature. use std::str::FromStr; -use reqwest::header::{HeaderValue, AUTHORIZATION}; +use reqwest::header::{HeaderValue, InvalidHeaderValue, AUTHORIZATION}; -use self::credential::{AzureCliCredential, ClientSecretOAuthProvider, CredentialProvider}; -use self::models::{ +use crate::credential::{AzureCliCredential, ClientSecretOAuthProvider, CredentialProvider}; +use crate::models::{ GetSchemaResponse, GetTableResponse, ListCatalogsResponse, ListSchemasResponse, ListTableSummariesResponse, }; -use super::client::retry::RetryExt; -use super::{client::retry::RetryConfig, DataCatalog, DataCatalogError, DataCatalogResult}; -use crate::storage::str_is_truthy; +use deltalake_core::data_catalog::DataCatalogResult; +use deltalake_core::{DataCatalog, DataCatalogError}; + +use crate::client::retry::*; +use deltalake_core::storage::str_is_truthy; + +pub mod client; pub mod credential; #[cfg(feature = "datafusion")] pub mod datafusion; +pub mod error; pub mod models; /// Possible errors from the unity-catalog/tables API call #[derive(thiserror::Error, Debug)] -enum UnityCatalogError { +pub enum UnityCatalogError { #[error("GET request error: {source}")] /// Error from reqwest library RequestError { @@ -30,6 +33,13 @@ enum UnityCatalogError { source: reqwest::Error, }, + #[error("Error in middleware: {source}")] + RequestMiddlewareError { + /// The underlying reqwest_middleware::Error + #[from] + source: reqwest_middleware::Error, + }, + /// Request returned error response #[error("Invalid table error: {error_code}: {message}")] InvalidTable { @@ -39,9 +49,11 @@ enum UnityCatalogError { message: String, }, - /// Unknown configuration key - #[error("Unknown configuration key: {0}")] - UnknownConfigKey(String), + #[error("Invalid token for auth header: {header_error}")] + InvalidHeader { + #[from] + header_error: InvalidHeaderValue, + }, /// Unknown configuration key #[error("Missing configuration key: {0}")] @@ -64,10 +76,6 @@ enum UnityCatalogError { impl From for DataCatalogError { fn from(value: UnityCatalogError) -> Self { match value { - UnityCatalogError::UnknownConfigKey(key) => DataCatalogError::UnknownConfigKey { - catalog: "Unity", - key, - }, _ => DataCatalogError::Generic { catalog: "Unity", source: Box::new(value), @@ -216,7 +224,10 @@ impl FromStr for UnityCatalogConfigKey { "workspace_url" | "unity_workspace_url" | "databricks_workspace_url" => { Ok(UnityCatalogConfigKey::WorkspaceUrl) } - _ => Err(UnityCatalogError::UnknownConfigKey(s.into()).into()), + _ => Err(DataCatalogError::UnknownConfigKey { + catalog: "unity", + key: s.to_string(), + }), } } } @@ -242,7 +253,7 @@ impl AsRef for UnityCatalogConfigKey { } } -/// Builder for crateing a UnityCatalogClient +/// Builder for creating a UnityCatalogClient #[derive(Default)] pub struct UnityCatalogBuilder { /// Url of a Databricks workspace @@ -282,7 +293,7 @@ pub struct UnityCatalogBuilder { retry_config: RetryConfig, /// Options for the underlying http client - client_options: super::client::ClientOptions, + client_options: client::ClientOptions, } #[allow(deprecated)] @@ -385,7 +396,7 @@ impl UnityCatalogBuilder { } /// Sets the client options, overriding any already set - pub fn with_client_options(mut self, options: super::client::ClientOptions) -> Self { + pub fn with_client_options(mut self, options: client::ClientOptions) -> Self { self.client_options = options; self } @@ -446,45 +457,33 @@ impl UnityCatalogBuilder { client, workspace_url, credential, - retry_config: self.retry_config, }) } } /// Databricks Unity Catalog pub struct UnityCatalog { - client: reqwest::Client, + client: reqwest_middleware::ClientWithMiddleware, credential: CredentialProvider, workspace_url: String, - retry_config: RetryConfig, } impl UnityCatalog { - async fn get_credential(&self) -> DataCatalogResult { + async fn get_credential(&self) -> Result { match &self.credential { CredentialProvider::BearerToken(token) => { - // we do the conversion to a HeaderValue here, since it is fallible + // we do the conversion to a HeaderValue here, since it is fallible, // and we want to use it in an infallible function - HeaderValue::from_str(&format!("Bearer {token}")).map_err(|err| { - super::DataCatalogError::Generic { - catalog: "Unity", - source: Box::new(err), - } - }) + Ok(HeaderValue::from_str(&format!("Bearer {token}"))?) } CredentialProvider::TokenCredential(cache, cred) => { let token = cache - .get_or_insert_with(|| cred.fetch_token(&self.client, &self.retry_config)) + .get_or_insert_with(|| cred.fetch_token(&self.client)) .await?; - // we do the conversion to a HeaderValue here, since it is fallible + // we do the conversion to a HeaderValue here, since it is fallible, // and we want to use it in an infallible function - HeaderValue::from_str(&format!("Bearer {token}")).map_err(|err| { - super::DataCatalogError::Generic { - catalog: "Unity", - source: Box::new(err), - } - }) + Ok(HeaderValue::from_str(&format!("Bearer {token}"))?) } } } @@ -497,14 +496,14 @@ impl UnityCatalog { /// all catalogs will be retrieved. Otherwise, only catalogs owned by the caller /// (or for which the caller has the USE_CATALOG privilege) will be retrieved. /// There is no guarantee of a specific ordering of the elements in the array. - pub async fn list_catalogs(&self) -> DataCatalogResult { + pub async fn list_catalogs(&self) -> Result { let token = self.get_credential().await?; // https://docs.databricks.com/api-explorer/workspace/schemas/list let resp = self .client .get(format!("{}/catalogs", self.catalog_url())) .header(AUTHORIZATION, token) - .send_retry(&self.retry_config) + .send() .await?; Ok(resp.json().await?) } @@ -521,7 +520,7 @@ impl UnityCatalog { pub async fn list_schemas( &self, catalog_name: impl AsRef, - ) -> DataCatalogResult { + ) -> Result { let token = self.get_credential().await?; // https://docs.databricks.com/api-explorer/workspace/schemas/list let resp = self @@ -529,7 +528,7 @@ impl UnityCatalog { .get(format!("{}/schemas", self.catalog_url())) .header(AUTHORIZATION, token) .query(&[("catalog_name", catalog_name.as_ref())]) - .send_retry(&self.retry_config) + .send() .await?; Ok(resp.json().await?) } @@ -542,7 +541,7 @@ impl UnityCatalog { &self, catalog_name: impl AsRef, schema_name: impl AsRef, - ) -> DataCatalogResult { + ) -> Result { let token = self.get_credential().await?; // https://docs.databricks.com/api-explorer/workspace/schemas/get let resp = self @@ -554,7 +553,7 @@ impl UnityCatalog { schema_name.as_ref() )) .header(AUTHORIZATION, token) - .send_retry(&self.retry_config) + .send() .await?; Ok(resp.json().await?) } @@ -574,7 +573,7 @@ impl UnityCatalog { &self, catalog_name: impl AsRef, schema_name_pattern: impl AsRef, - ) -> DataCatalogResult { + ) -> Result { let token = self.get_credential().await?; // https://docs.databricks.com/api-explorer/workspace/tables/listsummaries let resp = self @@ -585,7 +584,7 @@ impl UnityCatalog { ("schema_name_pattern", schema_name_pattern.as_ref()), ]) .header(AUTHORIZATION, token) - .send_retry(&self.retry_config) + .send() .await?; Ok(resp.json().await?) @@ -603,7 +602,7 @@ impl UnityCatalog { catalog_id: impl AsRef, database_name: impl AsRef, table_name: impl AsRef, - ) -> DataCatalogResult { + ) -> Result { let token = self.get_credential().await?; // https://docs.databricks.com/api-explorer/workspace/tables/get let resp = self @@ -616,7 +615,7 @@ impl UnityCatalog { table_name.as_ref() )) .header(AUTHORIZATION, token) - .send_retry(&self.retry_config) + .send() .await?; Ok(resp.json().await?) @@ -625,13 +624,14 @@ impl UnityCatalog { #[async_trait::async_trait] impl DataCatalog for UnityCatalog { + type Error = UnityCatalogError; /// Get the table storage location from the UnityCatalog async fn get_table_storage_location( &self, catalog_id: Option, database_name: &str, table_name: &str, - ) -> Result { + ) -> Result { match self .get_table( catalog_id.unwrap_or("main".into()), @@ -658,31 +658,47 @@ impl std::fmt::Debug for UnityCatalog { #[cfg(test)] mod tests { - use crate::data_catalog::client::ClientOptions; - - use super::super::client::mock_server::MockServer; - use super::models::tests::{GET_SCHEMA_RESPONSE, GET_TABLE_RESPONSE, LIST_SCHEMAS_RESPONSE}; - use super::*; - use hyper::{Body, Response}; - use reqwest::Method; + use crate::client::ClientOptions; + use crate::models::tests::{GET_SCHEMA_RESPONSE, GET_TABLE_RESPONSE, LIST_SCHEMAS_RESPONSE}; + use crate::models::*; + use crate::UnityCatalogBuilder; + use httpmock::prelude::*; #[tokio::test] async fn test_unity_client() { - let server = MockServer::new(); + let server = MockServer::start_async().await; let options = ClientOptions::default().with_allow_http(true); + let client = UnityCatalogBuilder::new() - .with_workspace_url(server.url()) + .with_workspace_url(server.url("")) .with_bearer_token("bearer_token") .with_client_options(options) .build() .unwrap(); - server.push_fn(move |req| { - assert_eq!(req.uri().path(), "/api/2.1/unity-catalog/schemas"); - assert_eq!(req.method(), &Method::GET); - Response::new(Body::from(LIST_SCHEMAS_RESPONSE)) - }); + server + .mock_async(|when, then| { + when.path("/api/2.1/unity-catalog/schemas").method("GET"); + then.body(LIST_SCHEMAS_RESPONSE); + }) + .await; + + server + .mock_async(|when, then| { + when.path("/api/2.1/unity-catalog/schemas/catalog_name.schema_name") + .method("GET"); + then.body(GET_SCHEMA_RESPONSE); + }) + .await; + + server + .mock_async(|when, then| { + when.path("/api/2.1/unity-catalog/tables/catalog_name.schema_name.table_name") + .method("GET"); + then.body(GET_TABLE_RESPONSE); + }) + .await; let list_schemas_response = client.list_schemas("catalog_name").await.unwrap(); assert!(matches!( @@ -690,30 +706,12 @@ mod tests { ListSchemasResponse::Success { .. } )); - server.push_fn(move |req| { - assert_eq!( - req.uri().path(), - "/api/2.1/unity-catalog/schemas/catalog_name.schema_name" - ); - assert_eq!(req.method(), &Method::GET); - Response::new(Body::from(GET_SCHEMA_RESPONSE)) - }); - let get_schema_response = client .get_schema("catalog_name", "schema_name") .await .unwrap(); assert!(matches!(get_schema_response, GetSchemaResponse::Success(_))); - server.push_fn(move |req| { - assert_eq!( - req.uri().path(), - "/api/2.1/unity-catalog/tables/catalog_name.schema_name.table_name" - ); - assert_eq!(req.method(), &Method::GET); - Response::new(Body::from(GET_TABLE_RESPONSE)) - }); - let get_table_response = client .get_table("catalog_name", "schema_name", "table_name") .await diff --git a/crates/core/src/data_catalog/unity/models.rs b/crates/catalog-unity/src/models.rs similarity index 100% rename from crates/core/src/data_catalog/unity/models.rs rename to crates/catalog-unity/src/models.rs diff --git a/crates/core/Cargo.toml b/crates/core/Cargo.toml index 57a9496070..f499e76d06 100644 --- a/crates/core/Cargo.toml +++ b/crates/core/Cargo.toml @@ -20,7 +20,7 @@ delta_kernel.workspace = true # arrow arrow = { workspace = true } arrow-arith = { workspace = true } -arrow-array = { workspace = true , features = ["chrono-tz"]} +arrow-array = { workspace = true, features = ["chrono-tz"] } arrow-buffer = { workspace = true } arrow-cast = { workspace = true } arrow-ipc = { workspace = true } @@ -58,7 +58,7 @@ regex = { workspace = true } thiserror = { workspace = true } uuid = { workspace = true, features = ["serde", "v4"] } url = { workspace = true } -urlencoding = { workspace = true} +urlencoding = { workspace = true } # runtime async-trait = { workspace = true } @@ -81,7 +81,6 @@ dashmap = "6" errno = "0.3" either = "1.8" fix-hidden-lifetime-bug = "0.2" -hyper = { version = "0.14", optional = true } indexmap = "2.2.1" itertools = "0.13" lazy_static = "1" @@ -97,13 +96,7 @@ tracing = { workspace = true } rand = "0.8" z85 = "3.0.5" maplit = "1" -sqlparser = { version = "0.51" } - -# Unity -reqwest = { version = "0.11.18", default-features = false, features = [ - "rustls-tls", - "json", -], optional = true } +sqlparser = { version = "0.52.0" } [dev-dependencies] criterion = "0.5" @@ -111,7 +104,6 @@ ctor = "0" deltalake-test = { path = "../test", features = ["datafusion"] } dotenvy = "0" fs_extra = "1.2.0" -hyper = { version = "0.14", features = ["server"] } maplit = "1" pretty_assertions = "1.2.1" pretty_env_logger = "0.5.0" @@ -136,5 +128,4 @@ datafusion = [ ] datafusion-ext = ["datafusion"] json = ["parquet/json"] -python = ["arrow/pyarrow"] -unity-experimental = ["reqwest", "hyper"] +python = ["arrow/pyarrow"] \ No newline at end of file diff --git a/crates/core/src/data_catalog/client/mock_server.rs b/crates/core/src/data_catalog/client/mock_server.rs deleted file mode 100644 index 9bed67e75c..0000000000 --- a/crates/core/src/data_catalog/client/mock_server.rs +++ /dev/null @@ -1,94 +0,0 @@ -use std::collections::VecDeque; -use std::convert::Infallible; -use std::net::SocketAddr; -use std::sync::Arc; - -use hyper::service::{make_service_fn, service_fn}; -use hyper::{Body, Request, Response, Server}; -use parking_lot::Mutex; -use tokio::sync::oneshot; -use tokio::task::JoinHandle; - -pub type ResponseFn = Box) -> Response + Send>; - -/// A mock server -pub struct MockServer { - responses: Arc>>, - shutdown: oneshot::Sender<()>, - handle: JoinHandle<()>, - url: String, -} - -impl Default for MockServer { - fn default() -> Self { - Self::new() - } -} - -impl MockServer { - pub fn new() -> Self { - let responses: Arc>> = - Arc::new(Mutex::new(VecDeque::with_capacity(10))); - - let r = Arc::clone(&responses); - let make_service = make_service_fn(move |_conn| { - let r = Arc::clone(&r); - async move { - Ok::<_, Infallible>(service_fn(move |req| { - let r = Arc::clone(&r); - async move { - Ok::<_, Infallible>(match r.lock().pop_front() { - Some(r) => r(req), - None => Response::new(Body::from("Hello World")), - }) - } - })) - } - }); - - let (shutdown, rx) = oneshot::channel::<()>(); - let server = Server::bind(&SocketAddr::from(([127, 0, 0, 1], 0))).serve(make_service); - - let url = format!("http://{}", server.local_addr()); - - let handle = tokio::spawn(async move { - server - .with_graceful_shutdown(async { - rx.await.ok(); - }) - .await - .unwrap() - }); - - Self { - responses, - shutdown, - handle, - url, - } - } - - /// The url of the mock server - pub fn url(&self) -> &str { - &self.url - } - - /// Add a response - pub fn push(&self, response: Response) { - self.push_fn(|_| response) - } - - /// Add a response function - pub fn push_fn(&self, f: F) - where - F: FnOnce(Request) -> Response + Send + 'static, - { - self.responses.lock().push_back(Box::new(f)) - } - - /// Shutdown the mock server - pub async fn shutdown(self) { - let _ = self.shutdown.send(()); - self.handle.await.unwrap() - } -} diff --git a/crates/core/src/data_catalog/client/retry.rs b/crates/core/src/data_catalog/client/retry.rs deleted file mode 100644 index 300e7afe7b..0000000000 --- a/crates/core/src/data_catalog/client/retry.rs +++ /dev/null @@ -1,365 +0,0 @@ -//! A shared HTTP client implementation incorporating retries -use std::error::Error as StdError; - -use futures::future::BoxFuture; -use futures::FutureExt; -use reqwest::header::LOCATION; -use reqwest::{Response, StatusCode}; -use std::time::{Duration, Instant}; -use tracing::info; - -use super::backoff::{Backoff, BackoffConfig}; - -/// Retry request error -#[derive(Debug)] -pub struct RetryError { - retries: usize, - message: String, - source: Option, -} - -impl std::fmt::Display for RetryError { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!( - f, - "response error \"{}\", after {} retries", - self.message, self.retries - )?; - if let Some(source) = &self.source { - write!(f, ": {source}")?; - } - Ok(()) - } -} - -impl std::error::Error for RetryError { - fn source(&self) -> Option<&(dyn std::error::Error + 'static)> { - self.source.as_ref().map(|e| e as _) - } -} - -impl RetryError { - /// Returns the status code associated with this error if any - pub fn status(&self) -> Option { - self.source.as_ref().and_then(|e| e.status()) - } -} - -impl From for std::io::Error { - fn from(err: RetryError) -> Self { - use std::io::ErrorKind; - match (&err.source, err.status()) { - (Some(source), _) if source.is_builder() || source.is_request() => { - Self::new(ErrorKind::InvalidInput, err) - } - (_, Some(StatusCode::NOT_FOUND)) => Self::new(ErrorKind::NotFound, err), - (_, Some(StatusCode::BAD_REQUEST)) => Self::new(ErrorKind::InvalidInput, err), - (Some(source), None) if source.is_timeout() => Self::new(ErrorKind::TimedOut, err), - (Some(source), None) if source.is_connect() => Self::new(ErrorKind::NotConnected, err), - _ => Self::new(ErrorKind::Other, err), - } - } -} - -/// Error retrying http requests -pub type Result = std::result::Result; - -/// Contains the configuration for how to respond to server errors -/// -/// By default they will be retried up to some limit, using exponential -/// backoff with jitter. See [`BackoffConfig`] for more information -/// -#[derive(Debug, Clone)] -pub struct RetryConfig { - /// The backoff configuration - pub backoff: BackoffConfig, - - /// The maximum number of times to retry a request - /// - /// Set to 0 to disable retries - pub max_retries: usize, - - /// The maximum length of time from the initial request - /// after which no further retries will be attempted - /// - /// This not only bounds the length of time before a server - /// error will be surfaced to the application, but also bounds - /// the length of time a request's credentials must remain valid. - /// - /// As requests are retried without renewing credentials or - /// regenerating request payloads, this number should be kept - /// below 5 minutes to avoid errors due to expired credentials - /// and/or request payloads - pub retry_timeout: Duration, -} - -impl Default for RetryConfig { - fn default() -> Self { - Self { - backoff: Default::default(), - max_retries: 10, - retry_timeout: Duration::from_secs(3 * 60), - } - } -} - -/// Trait to rend requests with retry -pub trait RetryExt { - /// Dispatch a request with the given retry configuration - /// - /// # Panic - /// - /// This will panic if the request body is a stream - fn send_retry(self, config: &RetryConfig) -> BoxFuture<'static, Result>; -} - -impl RetryExt for reqwest::RequestBuilder { - fn send_retry(self, config: &RetryConfig) -> BoxFuture<'static, Result> { - let mut backoff = Backoff::new(&config.backoff); - let max_retries = config.max_retries; - let retry_timeout = config.retry_timeout; - - async move { - let mut retries = 0; - let now = Instant::now(); - - loop { - let s = self.try_clone().expect("request body must be cloneable"); - match s.send().await { - Ok(r) => match r.error_for_status_ref() { - Ok(_) if r.status().is_success() => return Ok(r), - Ok(r) => { - let is_bare_redirect = r.status().is_redirection() && !r.headers().contains_key(LOCATION); - let message = match is_bare_redirect { - true => "Received redirect without LOCATION, this normally indicates an incorrectly configured region".to_string(), - // Not actually sure if this is reachable, but here for completeness - false => format!("request unsuccessful: {}", r.status()), - }; - - return Err(RetryError{ - message, - retries, - source: None, - }) - } - Err(e) => { - let status = r.status(); - - if retries == max_retries - || now.elapsed() > retry_timeout - || !status.is_server_error() { - - // Get the response message if returned a client error - let message = match status.is_client_error() { - true => match r.text().await { - Ok(message) if !message.is_empty() => message, - Ok(_) => "No Body".to_string(), - Err(e) => format!("error getting response body: {e}") - } - false => status.to_string(), - }; - - return Err(RetryError{ - message, - retries, - source: Some(e), - }) - - } - - let sleep = backoff.tick(); - retries += 1; - info!("Encountered server error, backing off for {} seconds, retry {} of {}", sleep.as_secs_f32(), retries, max_retries); - tokio::time::sleep(sleep).await; - } - }, - Err(e) => - { - let mut do_retry = false; - if let Some(source) = e.source() { - if let Some(e) = source.downcast_ref::() { - if e.is_connect() || e.is_closed() || e.is_incomplete_message() { - do_retry = true; - } - } - } - - if retries == max_retries - || now.elapsed() > retry_timeout - || !do_retry { - - return Err(RetryError{ - retries, - message: "request error".to_string(), - source: Some(e) - }) - } - let sleep = backoff.tick(); - retries += 1; - info!("Encountered request error ({}) backing off for {} seconds, retry {} of {}", e, sleep.as_secs_f32(), retries, max_retries); - tokio::time::sleep(sleep).await; - } - } - } - } - .boxed() - } -} - -#[cfg(test)] -mod tests { - use super::super::mock_server::MockServer; - use super::RetryConfig; - use super::RetryExt; - use hyper::header::LOCATION; - use hyper::{Body, Response}; - use reqwest::{Client, Method, StatusCode}; - use std::time::Duration; - - #[tokio::test] - async fn test_retry() { - let mock = MockServer::new(); - - let retry = RetryConfig { - backoff: Default::default(), - max_retries: 2, - retry_timeout: Duration::from_secs(1000), - }; - - let client = Client::new(); - let do_request = || client.request(Method::GET, mock.url()).send_retry(&retry); - - // Simple request should work - let r = do_request().await.unwrap(); - assert_eq!(r.status(), StatusCode::OK); - - // Returns client errors immediately with status message - mock.push( - Response::builder() - .status(StatusCode::BAD_REQUEST) - .body(Body::from("cupcakes")) - .unwrap(), - ); - - let e = do_request().await.unwrap_err(); - assert_eq!(e.status().unwrap(), StatusCode::BAD_REQUEST); - assert_eq!(e.retries, 0); - assert_eq!(&e.message, "cupcakes"); - - // Handles client errors with no payload - mock.push( - Response::builder() - .status(StatusCode::BAD_REQUEST) - .body(Body::empty()) - .unwrap(), - ); - - let e = do_request().await.unwrap_err(); - assert_eq!(e.status().unwrap(), StatusCode::BAD_REQUEST); - assert_eq!(e.retries, 0); - assert_eq!(&e.message, "No Body"); - - // Should retry server error request - mock.push( - Response::builder() - .status(StatusCode::BAD_GATEWAY) - .body(Body::empty()) - .unwrap(), - ); - - let r = do_request().await.unwrap(); - assert_eq!(r.status(), StatusCode::OK); - - // Accepts 204 status code - mock.push( - Response::builder() - .status(StatusCode::NO_CONTENT) - .body(Body::empty()) - .unwrap(), - ); - - let r = do_request().await.unwrap(); - assert_eq!(r.status(), StatusCode::NO_CONTENT); - - // Follows 402 redirects - mock.push( - Response::builder() - .status(StatusCode::FOUND) - .header(LOCATION, "/foo") - .body(Body::empty()) - .unwrap(), - ); - - let r = do_request().await.unwrap(); - assert_eq!(r.status(), StatusCode::OK); - assert_eq!(r.url().path(), "/foo"); - - // Follows 401 redirects - mock.push( - Response::builder() - .status(StatusCode::FOUND) - .header(LOCATION, "/bar") - .body(Body::empty()) - .unwrap(), - ); - - let r = do_request().await.unwrap(); - assert_eq!(r.status(), StatusCode::OK); - assert_eq!(r.url().path(), "/bar"); - - // Handles redirect loop - for _ in 0..10 { - mock.push( - Response::builder() - .status(StatusCode::FOUND) - .header(LOCATION, "/bar") - .body(Body::empty()) - .unwrap(), - ); - } - - let e = do_request().await.unwrap_err().to_string(); - assert!(e.ends_with("too many redirects"), "{}", e); - - // Handles redirect missing location - mock.push( - Response::builder() - .status(StatusCode::FOUND) - .body(Body::empty()) - .unwrap(), - ); - - let e = do_request().await.unwrap_err(); - assert_eq!(e.message, "Received redirect without LOCATION, this normally indicates an incorrectly configured region"); - - // Gives up after the retrying the specified number of times - for _ in 0..=retry.max_retries { - mock.push( - Response::builder() - .status(StatusCode::BAD_GATEWAY) - .body(Body::from("ignored")) - .unwrap(), - ); - } - - let e = do_request().await.unwrap_err(); - assert_eq!(e.retries, retry.max_retries); - assert_eq!(e.message, "502 Bad Gateway"); - - // Panic results in an incomplete message error in the client - mock.push_fn(|_| panic!()); - let r = do_request().await.unwrap(); - assert_eq!(r.status(), StatusCode::OK); - - // Gives up after retrying mulitiple panics - for _ in 0..=retry.max_retries { - mock.push_fn(|_| panic!()); - } - let e = do_request().await.unwrap_err(); - assert_eq!(e.retries, retry.max_retries); - assert_eq!(e.message, "request error"); - - // Shutdown - mock.shutdown().await - } -} diff --git a/crates/core/src/data_catalog/mod.rs b/crates/core/src/data_catalog/mod.rs index eaa02ff09a..fbb44d95c1 100644 --- a/crates/core/src/data_catalog/mod.rs +++ b/crates/core/src/data_catalog/mod.rs @@ -2,15 +2,8 @@ use std::fmt::Debug; -#[cfg(feature = "unity-experimental")] -pub use unity::*; - -#[cfg(feature = "unity-experimental")] -pub mod client; #[cfg(feature = "datafusion")] pub mod storage; -#[cfg(feature = "unity-experimental")] -pub mod unity; /// A result type for data catalog implementations pub type DataCatalogResult = Result; @@ -27,37 +20,6 @@ pub enum DataCatalogError { source: Box, }, - /// A generic error qualified in the message - #[cfg(feature = "unity-experimental")] - #[error("{source}")] - Retry { - /// Error message - #[from] - source: client::retry::RetryError, - }, - - #[error("Request error: {source}")] - #[cfg(feature = "unity-experimental")] - /// Error from reqwest library - RequestError { - /// The underlying reqwest_middleware::Error - #[from] - source: reqwest::Error, - }, - - /// Error caused by missing environment variable for Unity Catalog. - #[cfg(feature = "unity-experimental")] - #[error("Missing Unity Catalog environment variable: {var_name}")] - MissingEnvVar { - /// Variable name - var_name: String, - }, - - /// Error caused by invalid access token value - #[cfg(feature = "unity-experimental")] - #[error("Invalid Databricks personal access token")] - InvalidAccessToken, - /// Error representing an invalid Data Catalog. #[error("This data catalog doesn't exist: {data_catalog}")] InvalidDataCatalog { @@ -74,16 +36,23 @@ pub enum DataCatalogError { /// configuration key key: String, }, + + #[error("Error in request: {source}")] + RequestError { + source: Box, + }, } /// Abstractions for data catalog for the Delta table. To add support for new cloud, simply implement this trait. #[async_trait::async_trait] pub trait DataCatalog: Send + Sync + Debug { + type Error; + /// Get the table storage location from the Data Catalog async fn get_table_storage_location( &self, catalog_id: Option, database_name: &str, table_name: &str, - ) -> Result; + ) -> Result; } diff --git a/crates/core/src/data_catalog/storage/mod.rs b/crates/core/src/data_catalog/storage/mod.rs index 110e4aa075..236caf79a8 100644 --- a/crates/core/src/data_catalog/storage/mod.rs +++ b/crates/core/src/data_catalog/storage/mod.rs @@ -88,9 +88,9 @@ impl ListingSchemaProvider { } } -// noramalizes a path fragment to be a valida table name in datafusion +// normalizes a path fragment to be a valida table name in datafusion // - removes some reserved characters (-, +, ., " ") -// - lowecase ascii +// - lowercase ascii fn normalize_table_name(path: &Path) -> Result { Ok(path .file_name() diff --git a/crates/core/src/delta_datafusion/expr.rs b/crates/core/src/delta_datafusion/expr.rs index c0e79ba490..b633cae141 100644 --- a/crates/core/src/delta_datafusion/expr.rs +++ b/crates/core/src/delta_datafusion/expr.rs @@ -234,6 +234,10 @@ impl ContextProvider for DeltaContextProvider<'_> { self.state.aggregate_functions().get(name).cloned() } + fn get_window_meta(&self, name: &str) -> Option> { + self.state.window_functions().get(name).cloned() + } + fn get_variable_type(&self, _var: &[String]) -> Option { unimplemented!() } @@ -242,10 +246,6 @@ impl ContextProvider for DeltaContextProvider<'_> { self.state.config_options() } - fn get_window_meta(&self, name: &str) -> Option> { - self.state.window_functions().get(name).cloned() - } - fn udf_names(&self) -> Vec { self.state.scalar_functions().keys().cloned().collect() } diff --git a/crates/core/src/delta_datafusion/mod.rs b/crates/core/src/delta_datafusion/mod.rs index 034781b85c..e692dd054b 100644 --- a/crates/core/src/delta_datafusion/mod.rs +++ b/crates/core/src/delta_datafusion/mod.rs @@ -1555,7 +1555,7 @@ fn join_batches_with_add_actions( } /// Determine which files contain a record that satisfies the predicate -pub(crate) async fn find_files_scan<'a>( +pub(crate) async fn find_files_scan( snapshot: &DeltaTableState, log_store: LogStoreRef, state: &SessionState, @@ -1668,7 +1668,7 @@ pub(crate) async fn scan_memory_table( } /// Finds files in a snapshot that match the provided predicate. -pub async fn find_files<'a>( +pub async fn find_files( snapshot: &DeltaTableState, log_store: LogStoreRef, state: &SessionState, diff --git a/crates/core/src/kernel/models/actions.rs b/crates/core/src/kernel/models/actions.rs index ef370b4956..119f561b80 100644 --- a/crates/core/src/kernel/models/actions.rs +++ b/crates/core/src/kernel/models/actions.rs @@ -187,9 +187,9 @@ impl Protocol { let mut converted_writer_features = configuration .iter() .filter(|(_, value)| { - value.as_ref().map_or(false, |v| { - v.to_ascii_lowercase().parse::().is_ok_and(|v| v) - }) + value + .as_ref() + .is_some_and(|v| v.to_ascii_lowercase().parse::().is_ok_and(|v| v)) }) .collect::>>() .keys() @@ -216,9 +216,9 @@ impl Protocol { let converted_reader_features = configuration .iter() .filter(|(_, value)| { - value.as_ref().map_or(false, |v| { - v.to_ascii_lowercase().parse::().is_ok_and(|v| v) - }) + value + .as_ref() + .is_some_and(|v| v.to_ascii_lowercase().parse::().is_ok_and(|v| v)) }) .map(|(key, _)| (*key).clone().into()) .filter(|v| !matches!(v, ReaderFeatures::Other(_))) diff --git a/crates/core/src/kernel/snapshot/mod.rs b/crates/core/src/kernel/snapshot/mod.rs index 25e11b88ca..a85087ea9b 100644 --- a/crates/core/src/kernel/snapshot/mod.rs +++ b/crates/core/src/kernel/snapshot/mod.rs @@ -416,7 +416,7 @@ impl EagerSnapshot { } /// Update the snapshot to the given version - pub async fn update<'a>( + pub async fn update( &mut self, log_store: Arc, target_version: Option, diff --git a/crates/core/src/operations/load_cdf.rs b/crates/core/src/operations/load_cdf.rs index a63b5182b2..3d5bed2d26 100644 --- a/crates/core/src/operations/load_cdf.rs +++ b/crates/core/src/operations/load_cdf.rs @@ -173,9 +173,7 @@ impl CdfLoadBuilder { return if self.allow_out_of_range { Ok((change_files, add_files, remove_files)) } else { - Err(DeltaTableError::ChangeDataTimestampGreaterThanCommit { - ending_timestamp: ending_timestamp, - }) + Err(DeltaTableError::ChangeDataTimestampGreaterThanCommit { ending_timestamp }) }; } } diff --git a/crates/core/src/operations/merge/filter.rs b/crates/core/src/operations/merge/filter.rs index 0745c55830..602df519a1 100644 --- a/crates/core/src/operations/merge/filter.rs +++ b/crates/core/src/operations/merge/filter.rs @@ -252,16 +252,13 @@ pub(crate) fn generalize_filter( } } Expr::InList(in_list) => { - let compare_expr = match generalize_filter( + let compare_expr = generalize_filter( *in_list.expr, partition_columns, source_name, target_name, placeholders, - ) { - Some(expr) => expr, - None => return None, // Return early - }; + )?; let mut list_expr = Vec::new(); for item in in_list.list.into_iter() { diff --git a/crates/core/src/operations/transaction/mod.rs b/crates/core/src/operations/transaction/mod.rs index 88b28a8627..6d80d858b0 100644 --- a/crates/core/src/operations/transaction/mod.rs +++ b/crates/core/src/operations/transaction/mod.rs @@ -649,7 +649,7 @@ pub struct PostCommit<'a> { table_data: Option<&'a dyn TableReference>, } -impl<'a> PostCommit<'a> { +impl PostCommit<'_> { /// Runs the post commit activities async fn run_post_commit_hook(&self) -> DeltaResult { if let Some(table) = self.table_data { diff --git a/crates/core/src/writer/stats.rs b/crates/core/src/writer/stats.rs index 4fe448ea76..10260b8364 100644 --- a/crates/core/src/writer/stats.rs +++ b/crates/core/src/writer/stats.rs @@ -352,6 +352,7 @@ impl StatsScalar { // the precision - scale range take the next smaller (by magnitude) value val = f64::from_bits(val.to_bits() - 1); } + Ok(Self::Decimal(val)) } (Statistics::FixedLenByteArray(v), Some(LogicalType::Uuid)) => { diff --git a/crates/core/tests/command_optimize.rs b/crates/core/tests/command_optimize.rs index 13cbd168e4..e96ec08a6e 100644 --- a/crates/core/tests/command_optimize.rs +++ b/crates/core/tests/command_optimize.rs @@ -77,8 +77,8 @@ fn generate_random_batch>( let s = partition.into(); for _ in 0..rows { - x_vec.push(rng.gen()); - y_vec.push(rng.gen()); + x_vec.push(rng.r#gen()); + y_vec.push(rng.r#gen()); date_vec.push(s.clone()); } diff --git a/crates/core/tests/command_restore.rs b/crates/core/tests/command_restore.rs index 5013556ab8..9ac3f331da 100644 --- a/crates/core/tests/command_restore.rs +++ b/crates/core/tests/command_restore.rs @@ -74,8 +74,8 @@ fn get_record_batch() -> RecordBatch { let mut rng = rand::thread_rng(); for _ in 0..10 { - id_vec.push(rng.gen()); - value_vec.push(rng.gen()); + id_vec.push(rng.r#gen()); + value_vec.push(rng.r#gen()); } let schema = ArrowSchema::new(vec![ diff --git a/crates/deltalake/Cargo.toml b/crates/deltalake/Cargo.toml index 476f0b5d60..c760a55971 100644 --- a/crates/deltalake/Cargo.toml +++ b/crates/deltalake/Cargo.toml @@ -22,6 +22,7 @@ deltalake-azure = { version = "0.5.0", path = "../azure", optional = true } deltalake-gcp = { version = "0.6.0", path = "../gcp", optional = true } deltalake-hdfs = { version = "0.6.0", path = "../hdfs", optional = true } deltalake-catalog-glue = { version = "0.6.0", path = "../catalog-glue", optional = true } +deltalake-catalog-unity = { version = "0.6.0", path = "../catalog-unity", optional = true } [features] # All of these features are just reflected into the core crate until that @@ -37,7 +38,7 @@ json = ["deltalake-core/json"] python = ["deltalake-core/python"] s3-native-tls = ["deltalake-aws/native-tls"] s3 = ["deltalake-aws/rustls"] -unity-experimental = ["deltalake-core/unity-experimental"] +unity-experimental = ["deltalake-catalog-unity"] [dev-dependencies] tokio = { version = "1", features = ["macros", "rt-multi-thread"] } diff --git a/crates/gcp/tests/context.rs b/crates/gcp/tests/context.rs index 4bcc2c1b3b..5dc0f8cb44 100644 --- a/crates/gcp/tests/context.rs +++ b/crates/gcp/tests/context.rs @@ -76,10 +76,12 @@ impl StorageIntegration for GcpIntegration { let account_path = self.temp_dir.path().join("gcs.json"); info!("account_path: {account_path:?}"); std::fs::write(&account_path, serde_json::to_vec(&token).unwrap()).unwrap(); - std::env::set_var( - "GOOGLE_SERVICE_ACCOUNT", - account_path.as_path().to_str().unwrap(), - ); + unsafe { + std::env::set_var( + "GOOGLE_SERVICE_ACCOUNT", + account_path.as_path().to_str().unwrap(), + ); + } } fn bucket_name(&self) -> String { diff --git a/crates/sql/src/logical_plan.rs b/crates/sql/src/logical_plan.rs index 9f154c0204..27da52a96d 100644 --- a/crates/sql/src/logical_plan.rs +++ b/crates/sql/src/logical_plan.rs @@ -33,7 +33,7 @@ impl DeltaStatement { impl Display for Wrapper<'_> { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { match self.0 { - DeltaStatement::Vacuum(Vacuum { + &DeltaStatement::Vacuum(Vacuum { ref table, ref dry_run, ref retention_hours, diff --git a/python/Cargo.toml b/python/Cargo.toml index bb6fbba621..8f44393819 100644 --- a/python/Cargo.toml +++ b/python/Cargo.toml @@ -46,8 +46,8 @@ reqwest = { version = "*", features = ["native-tls-vendored"] } deltalake-mount = { path = "../crates/mount" } [dependencies.pyo3] -version = "0.22.2" -features = ["extension-module", "abi3", "abi3-py39"] +version = "0.22.6" +features = ["extension-module", "abi3", "abi3-py39", "gil-refs"] [dependencies.deltalake] path = "../crates/deltalake"