From ebf09186ab2df4a9881a2524cc6efb2b8185ac91 Mon Sep 17 00:00:00 2001 From: TennyZhuang Date: Mon, 11 Mar 2024 22:16:28 +0800 Subject: [PATCH 1/6] feat: implement OAuth for catalog rest Signed-off-by: TennyZhuang --- crates/catalog/rest/src/catalog.rs | 77 +++++++++++++++++-- .../catalog/rest/tests/rest_catalog_test.rs | 1 + 2 files changed, 73 insertions(+), 5 deletions(-) diff --git a/crates/catalog/rest/src/catalog.rs b/crates/catalog/rest/src/catalog.rs index c10d904b6..87ddd947c 100644 --- a/crates/catalog/rest/src/catalog.rs +++ b/crates/catalog/rest/src/catalog.rs @@ -38,7 +38,7 @@ use iceberg::{ use self::_serde::{ CatalogConfig, ErrorResponse, ListNamespaceResponse, ListTableResponse, NamespaceSerde, - RenameTableRequest, NO_CONTENT, OK, + RenameTableRequest, TokenResponse, NO_CONTENT, OK, }; const ICEBERG_REST_SPEC_VERSION: &str = "0.14.1"; @@ -96,9 +96,13 @@ impl RestCatalogConfig { .join("/") } + fn get_token_endpoint(&self) -> String { + [&self.uri, PATH_V1, "oauth", "tokens"].join("/") + } + fn try_create_rest_client(&self) -> Result { - //TODO: We will add oauth, ssl config, sigv4 later - let headers = HeaderMap::from_iter([ + // TODO: We will add ssl config, sigv4 later + let mut headers = HeaderMap::from_iter([ ( header::CONTENT_TYPE, HeaderValue::from_static("application/json"), @@ -113,6 +117,19 @@ impl RestCatalogConfig { ), ]); + if let Some(token) = self.props.get("token") { + headers.insert( + "Authorization", + HeaderValue::from_str(&format!("Bearer {token}")).map_err(|e| { + Error::new( + ErrorKind::DataInvalid, + "Invalid token received from catalog server!", + ) + .with_source(e) + })?, + ); + } + Ok(HttpClient( Client::builder().default_headers(headers).build()?, )) @@ -144,6 +161,7 @@ impl HttpClient { .with_source(e) })?) } else { + let code = resp.status(); let text = resp.bytes().await?; let e = serde_json::from_slice::(&text).map_err(|e| { Error::new( @@ -151,6 +169,7 @@ impl HttpClient { "Failed to parse response from rest catalog server!", ) .with_context("json", String::from_utf8_lossy(&text)) + .with_context("code", code.to_string()) .with_source(e) })?; Err(e.into()) @@ -497,13 +516,53 @@ impl RestCatalog { client: config.try_create_rest_client()?, config, }; - + catalog.fetch_access_token().await?; + catalog.client = catalog.config.try_create_rest_client()?; catalog.update_config().await?; catalog.client = catalog.config.try_create_rest_client()?; Ok(catalog) } + async fn fetch_access_token(&mut self) -> Result<()> { + if let Some(credential) = self.config.props.get("credential") { + let (client_id, client_secret) = if credential.contains(':') { + let (client_id, client_secret) = credential.split_once(':').unwrap(); + (Some(client_id), client_secret) + } else { + (None, credential.as_str()) + }; + let mut params = HashMap::with_capacity(4); + params.insert("grant_type", "client_credentials"); + if let Some(client_id) = client_id { + params.insert("client_id", client_id); + } + params.insert("client_secret", client_secret); + params.insert("scope", "catalog"); + let req = self + .client + .0 + .post(self.config.get_token_endpoint()) + .form(¶ms) + .build()?; + let res = self + .client + .query::(req) + .await + .map_err(|e| { + Error::new( + ErrorKind::Unexpected, + "Failed to fetch access token from catalog server!", + ) + .with_source(e) + })?; + let token = res.access_token; + self.config.props.insert("token".to_string(), token); + } + + Ok(()) + } + async fn update_config(&mut self) -> Result<()> { let mut request = self.client.0.get(self.config.config_endpoint()); @@ -626,6 +685,14 @@ mod _serde { } } + #[derive(Debug, Serialize, Deserialize)] + pub(super) struct TokenResponse { + pub(super) access_token: String, + pub(super) token_type: String, + pub(super) expires_in: u64, + pub(super) issued_token_type: String, + } + #[derive(Debug, Serialize, Deserialize)] pub(super) struct NamespaceSerde { pub(super) namespace: Vec, @@ -1557,7 +1624,7 @@ mod tests { "type": "NoSuchTableException", "code": 404 } -} +} "#, ) .create_async() diff --git a/crates/catalog/rest/tests/rest_catalog_test.rs b/crates/catalog/rest/tests/rest_catalog_test.rs index a4d07955b..205428d61 100644 --- a/crates/catalog/rest/tests/rest_catalog_test.rs +++ b/crates/catalog/rest/tests/rest_catalog_test.rs @@ -66,6 +66,7 @@ async fn set_test_fixture(func: &str) -> TestFixture { rest_catalog, } } + #[tokio::test] async fn test_get_non_exist_namespace() { let fixture = set_test_fixture("test_get_non_exist_namespace").await; From 382764457fb13370975156eb31348be2745cb902 Mon Sep 17 00:00:00 2001 From: TennyZhuang Date: Mon, 11 Mar 2024 22:44:43 +0800 Subject: [PATCH 2/6] add a check Signed-off-by: TennyZhuang --- crates/catalog/rest/src/catalog.rs | 3 +++ 1 file changed, 3 insertions(+) diff --git a/crates/catalog/rest/src/catalog.rs b/crates/catalog/rest/src/catalog.rs index 87ddd947c..207deecbd 100644 --- a/crates/catalog/rest/src/catalog.rs +++ b/crates/catalog/rest/src/catalog.rs @@ -525,6 +525,9 @@ impl RestCatalog { } async fn fetch_access_token(&mut self) -> Result<()> { + if self.config.props.contains_key("token") { + return Ok(()); + } if let Some(credential) = self.config.props.get("credential") { let (client_id, client_secret) = if credential.contains(':') { let (client_id, client_secret) = credential.split_once(':').unwrap(); From fea676540edb21f2e6845536203c209c86b58746 Mon Sep 17 00:00:00 2001 From: TennyZhuang Date: Tue, 12 Mar 2024 13:14:25 +0800 Subject: [PATCH 3/6] make issued_token_type optional Signed-off-by: TennyZhuang --- crates/catalog/rest/src/catalog.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/crates/catalog/rest/src/catalog.rs b/crates/catalog/rest/src/catalog.rs index 207deecbd..b57773c89 100644 --- a/crates/catalog/rest/src/catalog.rs +++ b/crates/catalog/rest/src/catalog.rs @@ -693,7 +693,7 @@ mod _serde { pub(super) access_token: String, pub(super) token_type: String, pub(super) expires_in: u64, - pub(super) issued_token_type: String, + pub(super) issued_token_type: Option, } #[derive(Debug, Serialize, Deserialize)] From 4d6573228d2a0a0730aa2e343dc9c3966f11f0dc Mon Sep 17 00:00:00 2001 From: TennyZhuang Date: Tue, 12 Mar 2024 16:24:39 +0800 Subject: [PATCH 4/6] make expires_in optional Signed-off-by: TennyZhuang --- crates/catalog/rest/src/catalog.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/crates/catalog/rest/src/catalog.rs b/crates/catalog/rest/src/catalog.rs index b57773c89..5f903fb92 100644 --- a/crates/catalog/rest/src/catalog.rs +++ b/crates/catalog/rest/src/catalog.rs @@ -692,7 +692,7 @@ mod _serde { pub(super) struct TokenResponse { pub(super) access_token: String, pub(super) token_type: String, - pub(super) expires_in: u64, + pub(super) expires_in: Option, pub(super) issued_token_type: Option, } From 1741ec616cdbd74c68b4965732c4ccf389c133b2 Mon Sep 17 00:00:00 2001 From: TennyZhuang Date: Thu, 14 Mar 2024 10:39:00 +0800 Subject: [PATCH 5/6] use header::AUTHORIZATION Signed-off-by: TennyZhuang --- crates/catalog/rest/src/catalog.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/crates/catalog/rest/src/catalog.rs b/crates/catalog/rest/src/catalog.rs index 5f903fb92..b1d2a6aab 100644 --- a/crates/catalog/rest/src/catalog.rs +++ b/crates/catalog/rest/src/catalog.rs @@ -119,7 +119,7 @@ impl RestCatalogConfig { if let Some(token) = self.props.get("token") { headers.insert( - "Authorization", + header::AUTHORIZATION, HeaderValue::from_str(&format!("Bearer {token}")).map_err(|e| { Error::new( ErrorKind::DataInvalid, From b32eeafea3a44274d7d4141949fc1d7ce3705b6a Mon Sep 17 00:00:00 2001 From: TennyZhuang Date: Thu, 14 Mar 2024 20:05:31 +0800 Subject: [PATCH 6/6] add mock test Signed-off-by: TennyZhuang --- crates/catalog/rest/src/catalog.rs | 38 ++++++++++++++++++++++++++++++ 1 file changed, 38 insertions(+) diff --git a/crates/catalog/rest/src/catalog.rs b/crates/catalog/rest/src/catalog.rs index b1d2a6aab..8112a44b0 100644 --- a/crates/catalog/rest/src/catalog.rs +++ b/crates/catalog/rest/src/catalog.rs @@ -848,6 +848,44 @@ mod tests { .await } + async fn create_oauth_mock(server: &mut ServerGuard) -> Mock { + server + .mock("POST", "/v1/oauth/tokens") + .with_status(200) + .with_body( + r#"{ + "access_token": "ey000000000000", + "token_type": "Bearer", + "issued_token_type": "urn:ietf:params:oauth:token-type:access_token", + "expires_in": 86400 + }"#, + ) + .create_async() + .await + } + + #[tokio::test] + async fn test_oauth() { + let mut server = Server::new_async().await; + let oauth_mock = create_oauth_mock(&mut server).await; + let config_mock = create_config_mock(&mut server).await; + + let mut props = HashMap::new(); + props.insert("credential".to_string(), "client1:secret1".to_string()); + + let _catalog = RestCatalog::new( + RestCatalogConfig::builder() + .uri(server.url()) + .props(props) + .build(), + ) + .await + .unwrap(); + + oauth_mock.assert_async().await; + config_mock.assert_async().await; + } + #[tokio::test] async fn test_list_namespace() { let mut server = Server::new_async().await;