Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: support reading tables via Unity Catalog provided credentials #3078

Merged
merged 4 commits into from
Jan 20, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 21 additions & 7 deletions crates/catalog-unity/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -17,26 +17,40 @@ tokio.workspace = true
serde.workspace = true
serde_json.workspace = true
thiserror.workspace = true
futures.workspace = true
chrono.workspace = true
tracing.workspace = true
deltalake-core = { version = "0.24.0", path = "../core", features = [
"datafusion",
]}
] }
deltalake-aws = { version = "0.7.0", path = "../aws", optional = true }
deltalake-azure = { version = "0.7.0", path = "../azure", optional = true }
deltalake-gcp = { version = "0.8.0", path = "../gcp", optional = true }
reqwest = { version = "0.12", default-features = false, features = ["rustls-tls", "json", "http2"] }
reqwest-retry = "0.7"
reqwest-middleware = "0.4.0"
reqwest-middleware = { version = "0.4.0", features = ["json"] }
rand = "0.8"
futures = { workspace = true }
chrono = { workspace = true }
dashmap = "6"
tracing = { workspace = true }
datafusion = { workspace = true, optional = true }
datafusion-common = { workspace = true, optional = true }
moka = { version = "0.12", optional = true, features = ["future"] }

[dev-dependencies]
tokio = { version = "1", features = ["macros", "rt-multi-thread"] }
tempfile = "3"
httpmock = { version = "0.8.0-alpha.1" }
tracing-subscriber = { version = "0.3", features = ["env-filter"] }

[features]
default = []
datafusion = ["dep:datafusion", "datafusion-common"]
default = ["datafusion", "aws"]
ion-elgreco marked this conversation as resolved.
Show resolved Hide resolved
aws = ["deltalake-aws"]
ion-elgreco marked this conversation as resolved.
Show resolved Hide resolved
azure = ["deltalake-azure"]
gcp = ["deltalake-gcp"]
r2 = ["deltalake-aws"]
datafusion = ["dep:datafusion", "datafusion-common", "deltalake-core/datafusion", "moka"]

[[example]]
name = "uc_example"
path = "examples/uc_example.rs"
required-features = ["datafusion", "aws"]

46 changes: 46 additions & 0 deletions crates/catalog-unity/examples/uc_example.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
use datafusion::prelude::*;
use deltalake_catalog_unity::prelude::*;
use std::error::Error;
use std::sync::Arc;

#[tokio::main]
async fn main() -> Result<(), Box<dyn Error>> {
let filter = tracing_subscriber::EnvFilter::builder().parse("deltalake_catalog_unity=info")?;
let subscriber = tracing_subscriber::fmt()
.pretty()
.with_env_filter(filter)
.finish();
tracing::subscriber::set_global_default(subscriber)?;

let uc = UnityCatalogBuilder::from_env().build()?;

deltalake_aws::register_handlers(None);

let catalog = UnityCatalogProvider::try_new(Arc::new(uc), "scarman_sandbox").await?;
let ctx = SessionContext::new();
ctx.register_catalog("scarman_sandbox", Arc::new(catalog));

ctx.sql(
"select hdci.city_name, hdci.country_code, hdci.latitude, hdci.longitude from \
scarman_sandbox.external_data.historical_hourly_imperial hhi \
join scarman_sandbox.external_data.historical_daily_calendar_imperial hdci on hdci.country_code = hhi.country_code \
order by city_name \
limit 50;"
)
.await?
.show()
.await?;

ctx.table("scarman_sandbox.external_data.historical_hourly_imperial")
.await?
.select(vec![
col("city_name"),
col("country_code"),
col("latitude"),
col("longitude"),
])?
.show_limit(50)
.await?;

Ok(())
}
Empty file.
107 changes: 71 additions & 36 deletions crates/catalog-unity/src/datafusion.rs
Original file line number Diff line number Diff line change
@@ -1,21 +1,24 @@
//! Datafusion integration for UnityCatalog

use std::any::Any;
use std::collections::HashMap;
use std::sync::Arc;

use chrono::prelude::*;
use dashmap::DashMap;
use datafusion::catalog::SchemaProvider;
use datafusion::catalog::{CatalogProvider, CatalogProviderList};
use datafusion::datasource::TableProvider;
use datafusion_common::DataFusionError;
use futures::FutureExt;
use moka::future::Cache;
use moka::Expiry;
use std::any::Any;
use std::sync::Arc;
use std::time::{Duration, Instant};
use tracing::error;

use super::models::{
GetTableResponse, ListCatalogsResponse, ListSchemasResponse, ListTableSummariesResponse,
TableTempCredentialsResponse, TemporaryTableCredentials,
};
use super::{DataCatalogResult, UnityCatalog};

use super::{DataCatalogResult, UnityCatalog, UnityCatalogError};
use deltalake_core::DeltaTableBuilder;

/// In-memory list of catalogs populated by unity catalog
Expand All @@ -27,20 +30,13 @@ pub struct UnityCatalogList {

impl UnityCatalogList {
/// Create a new instance of [`UnityCatalogList`]
pub async fn try_new(
client: Arc<UnityCatalog>,
storage_options: impl IntoIterator<Item = (impl Into<String>, impl Into<String>)> + Clone,
) -> DataCatalogResult<Self> {
pub async fn try_new(client: Arc<UnityCatalog>) -> DataCatalogResult<Self> {
let catalogs = match client.list_catalogs().await? {
ListCatalogsResponse::Success { catalogs } => {
ListCatalogsResponse::Success { catalogs, .. } => {
let mut providers = Vec::new();
for catalog in catalogs {
let provider = UnityCatalogProvider::try_new(
client.clone(),
&catalog.name,
storage_options.clone(),
)
.await?;
let provider =
UnityCatalogProvider::try_new(client.clone(), &catalog.name).await?;
providers.push((catalog.name, Arc::new(provider) as Arc<dyn CatalogProvider>));
}
providers
Expand Down Expand Up @@ -87,20 +83,15 @@ impl UnityCatalogProvider {
pub async fn try_new(
client: Arc<UnityCatalog>,
catalog_name: impl Into<String>,
storage_options: impl IntoIterator<Item = (impl Into<String>, impl Into<String>)> + Clone,
) -> DataCatalogResult<Self> {
let catalog_name = catalog_name.into();
let schemas = match client.list_schemas(&catalog_name).await? {
ListSchemasResponse::Success { schemas } => {
let mut providers = Vec::new();
for schema in schemas {
let provider = UnitySchemaProvider::try_new(
client.clone(),
&catalog_name,
&schema.name,
storage_options.clone(),
)
.await?;
let provider =
UnitySchemaProvider::try_new(client.clone(), &catalog_name, &schema.name)
.await?;
providers.push((schema.name, Arc::new(provider) as Arc<dyn SchemaProvider>));
}
providers
Expand All @@ -127,20 +118,33 @@ impl CatalogProvider for UnityCatalogProvider {
}
}

struct TokenExpiry;

impl Expiry<String, TemporaryTableCredentials> for TokenExpiry {
fn expire_after_read(
&self,
_key: &String,
value: &TemporaryTableCredentials,
_read_at: Instant,
_duration_until_expiry: Option<Duration>,
_last_modified_at: Instant,
) -> Option<Duration> {
let time_to_expire = value.expiration_time - Utc::now();
tracing::info!("Token {} expires in {}", _key, time_to_expire);
time_to_expire.to_std().ok()
}
}

/// A datafusion [`SchemaProvider`] backed by Databricks UnityCatalog
#[derive(Debug)]
pub struct UnitySchemaProvider {
/// UnityCatalog Api client
client: Arc<UnityCatalog>,

catalog_name: String,

schema_name: String,

/// Parent catalog for schemas of interest.
table_names: Vec<String>,

storage_options: HashMap<String, String>,
token_cache: Cache<String, TemporaryTableCredentials>,
}

impl UnitySchemaProvider {
Expand All @@ -149,7 +153,6 @@ impl UnitySchemaProvider {
client: Arc<UnityCatalog>,
catalog_name: impl Into<String>,
schema_name: impl Into<String>,
storage_options: impl IntoIterator<Item = (impl Into<String>, impl Into<String>)>,
) -> DataCatalogResult<Self> {
let catalog_name = catalog_name.into();
let schema_name = schema_name.into();
Expand All @@ -163,17 +166,37 @@ impl UnitySchemaProvider {
.collect(),
ListTableSummariesResponse::Error(_) => vec![],
};
let token_cache = Cache::builder().expire_after(TokenExpiry).build();
Ok(Self {
client,
table_names,
catalog_name,
schema_name,
storage_options: storage_options
.into_iter()
.map(|(key, value)| (key.into(), value.into()))
.collect(),
token_cache,
})
}

async fn get_creds(
&self,
catalog: &str,
schema: &str,
table: &str,
) -> Result<TemporaryTableCredentials, UnityCatalogError> {
tracing::debug!(
"Fetching new credential for: {}.{}.{}",
catalog,
schema,
table
);
self.client
.get_temp_table_credentials(catalog, schema, table)
.map(|resp| match resp {
Ok(TableTempCredentialsResponse::Success(temp_creds)) => Ok(temp_creds),
Ok(TableTempCredentialsResponse::Error(err)) => Err(err.into()),
Err(err) => Err(err.into()),
})
.await
}
}

#[async_trait::async_trait]
Expand All @@ -195,8 +218,20 @@ impl SchemaProvider for UnitySchemaProvider {

match maybe_table {
GetTableResponse::Success(table) => {
let temp_creds = self
.token_cache
.try_get_with(
table.table_id,
self.get_creds(&self.catalog_name, &self.schema_name, name),
)
.await
.map_err(|err| DataFusionError::External(err.into()))?;

let new_storage_opts = temp_creds.get_credentials().ok_or_else(|| {
DataFusionError::External(UnityCatalogError::MissingCredential.into())
})?;
let table = DeltaTableBuilder::from_uri(table.storage_location)
.with_storage_options(self.storage_options.clone())
.with_storage_options(new_storage_opts)
.load()
.await?;
Ok(Some(Arc::new(table)))
Expand Down
38 changes: 0 additions & 38 deletions crates/catalog-unity/src/error.rs

This file was deleted.

Loading
Loading