From 5bd0ef73bda3e9786c9c48cf1a259091156b6c3a Mon Sep 17 00:00:00 2001 From: Rossil <40714231+Rossil2012@users.noreply.github.com> Date: Mon, 30 Oct 2023 17:54:24 +0800 Subject: [PATCH 01/12] pass oauth token with plaintext --- proto/user.proto | 1 + src/frontend/src/handler/alter_user.rs | 6 +++++- src/frontend/src/handler/create_user.rs | 3 ++- src/frontend/src/session.rs | 2 ++ src/frontend/src/user/user_authentication.rs | 10 ++++++++++ src/sqlparser/src/ast/statement.rs | 7 +++++-- src/sqlparser/src/keywords.rs | 1 + src/sqlparser/src/parser.rs | 2 +- src/utils/pgwire/src/pg_protocol.rs | 2 +- src/utils/pgwire/src/pg_server.rs | 5 +++++ 10 files changed, 33 insertions(+), 6 deletions(-) diff --git a/proto/user.proto b/proto/user.proto index c998f66d15133..86ccdb8e16066 100644 --- a/proto/user.proto +++ b/proto/user.proto @@ -14,6 +14,7 @@ message AuthInfo { PLAINTEXT = 1; SHA256 = 2; MD5 = 3; + OAuth = 4; } EncryptionType encryption_type = 1; bytes encrypted_value = 2; diff --git a/src/frontend/src/handler/alter_user.rs b/src/frontend/src/handler/alter_user.rs index 0d83c3ae867d5..99e3f74f3e12f 100644 --- a/src/frontend/src/handler/alter_user.rs +++ b/src/frontend/src/handler/alter_user.rs @@ -23,7 +23,7 @@ use super::RwPgResponse; use crate::binder::Binder; use crate::catalog::CatalogError; use crate::handler::HandlerArgs; -use crate::user::user_authentication::encrypted_password; +use crate::user::user_authentication::{build_oauth_info, encrypted_password}; use crate::user::user_catalog::UserCatalog; fn alter_prost_user_info( @@ -109,6 +109,10 @@ fn alter_prost_user_info( } update_fields.push(UpdateField::AuthInfo); } + UserOption::OAuth => { + user_info.auth_info = build_oauth_info(); + update_fields.push(UpdateField::AuthInfo) + } } } Ok((user_info, update_fields)) diff --git a/src/frontend/src/handler/create_user.rs b/src/frontend/src/handler/create_user.rs index 8659e1b647c33..bb68cd18df44b 100644 --- a/src/frontend/src/handler/create_user.rs +++ b/src/frontend/src/handler/create_user.rs @@ -23,7 +23,7 @@ use super::RwPgResponse; use crate::binder::Binder; use crate::catalog::{CatalogError, DatabaseId}; use crate::handler::HandlerArgs; -use crate::user::user_authentication::encrypted_password; +use crate::user::user_authentication::{build_oauth_info, encrypted_password}; use crate::user::user_catalog::UserCatalog; fn make_prost_user_info( @@ -89,6 +89,7 @@ fn make_prost_user_info( user_info.auth_info = encrypted_password(&user_info.name, &password.0); } } + UserOption::OAuth => user_info.auth_info = build_oauth_info(), } } diff --git a/src/frontend/src/session.rs b/src/frontend/src/session.rs index 120671ea73cb2..bf43f52e80de1 100644 --- a/src/frontend/src/session.rs +++ b/src/frontend/src/session.rs @@ -825,6 +825,8 @@ impl SessionManager for SessionManagerImpl { ), salt, } + } else if auth_info.encryption_type == EncryptionType::OAuth as i32 { + UserAuthenticator::OAuth } else { return Err(Box::new(Error::new( ErrorKind::Unsupported, diff --git a/src/frontend/src/user/user_authentication.rs b/src/frontend/src/user/user_authentication.rs index ad6c6d2e758a8..67c2dae8f1309 100644 --- a/src/frontend/src/user/user_authentication.rs +++ b/src/frontend/src/user/user_authentication.rs @@ -24,6 +24,15 @@ const MD5_ENCRYPTED_PREFIX: &str = "md5"; const VALID_SHA256_ENCRYPTED_LEN: usize = SHA256_ENCRYPTED_PREFIX.len() + 64; const VALID_MD5_ENCRYPTED_LEN: usize = MD5_ENCRYPTED_PREFIX.len() + 32; +/// Build AuthInfo for OAuth. +#[inline(always)] +pub fn build_oauth_info() -> Option { + Some(AuthInfo { + encryption_type: EncryptionType::OAuth as i32, + encrypted_value: Vec::new(), + }) +} + /// Try to extract the encryption password from given password. The password is always stored /// encrypted in the system catalogs. The ENCRYPTED keyword has no effect, but is accepted for /// backwards compatibility. The method of encryption is by default SHA-256-encrypted. If the @@ -81,6 +90,7 @@ pub fn encrypted_raw_password(info: &AuthInfo) -> String { EncryptionType::Plaintext => "", EncryptionType::Sha256 => SHA256_ENCRYPTED_PREFIX, EncryptionType::Md5 => MD5_ENCRYPTED_PREFIX, + EncryptionType::OAuth => "", }; format!("{}{}", prefix, encrypted_pwd) } diff --git a/src/sqlparser/src/ast/statement.rs b/src/sqlparser/src/ast/statement.rs index 3ff012c81b766..17de6c8cad0ad 100644 --- a/src/sqlparser/src/ast/statement.rs +++ b/src/sqlparser/src/ast/statement.rs @@ -1053,6 +1053,7 @@ pub enum UserOption { NoLogin, EncryptedPassword(AstString), Password(Option), + OAuth, } impl fmt::Display for UserOption { @@ -1069,6 +1070,7 @@ impl fmt::Display for UserOption { UserOption::EncryptedPassword(p) => write!(f, "ENCRYPTED PASSWORD {}", p), UserOption::Password(None) => write!(f, "PASSWORD NULL"), UserOption::Password(Some(p)) => write!(f, "PASSWORD {}", p), + UserOption::OAuth => write!(f, "OAUTH"), } } } @@ -1156,10 +1158,11 @@ impl ParseTo for UserOptions { UserOption::EncryptedPassword(AstString::parse_to(parser)?), ) } + Keyword::OAUTH => (&mut builder.password, UserOption::OAuth), _ => { parser.expected( "SUPERUSER | NOSUPERUSER | CREATEDB | NOCREATEDB | LOGIN \ - | NOLOGIN | CREATEUSER | NOCREATEUSER | [ENCRYPTED] PASSWORD | NULL", + | NOLOGIN | CREATEUSER | NOCREATEUSER | [ENCRYPTED] PASSWORD | NULL | OAUTH", token, )?; unreachable!() @@ -1169,7 +1172,7 @@ impl ParseTo for UserOptions { } else { parser.expected( "SUPERUSER | NOSUPERUSER | CREATEDB | NOCREATEDB | LOGIN | NOLOGIN \ - | CREATEUSER | NOCREATEUSER | [ENCRYPTED] PASSWORD | NULL", + | CREATEUSER | NOCREATEUSER | [ENCRYPTED] PASSWORD | NULL | OAUTH", token, )? } diff --git a/src/sqlparser/src/keywords.rs b/src/sqlparser/src/keywords.rs index 4188f06f76ae3..5f1864dce444a 100644 --- a/src/sqlparser/src/keywords.rs +++ b/src/sqlparser/src/keywords.rs @@ -337,6 +337,7 @@ define_keywords!( NULLIF, NULLS, NUMERIC, + OAUTH, OBJECT, OCCURRENCES_REGEX, OCTET_LENGTH, diff --git a/src/sqlparser/src/parser.rs b/src/sqlparser/src/parser.rs index 5cc094a204268..26878d0908a61 100644 --- a/src/sqlparser/src/parser.rs +++ b/src/sqlparser/src/parser.rs @@ -2322,7 +2322,7 @@ impl Parser { // | CREATEDB | NOCREATEDB // | CREATEUSER | NOCREATEUSER // | LOGIN | NOLOGIN - // | [ ENCRYPTED ] PASSWORD 'password' | PASSWORD NULL + // | [ ENCRYPTED ] PASSWORD 'password' | PASSWORD NULL | OAUTH fn parse_create_user(&mut self) -> Result { Ok(Statement::CreateUser(CreateUserStatement::parse_to(self)?)) } diff --git a/src/utils/pgwire/src/pg_protocol.rs b/src/utils/pgwire/src/pg_protocol.rs index ff705025a0d64..3ec41a53a8da4 100644 --- a/src/utils/pgwire/src/pg_protocol.rs +++ b/src/utils/pgwire/src/pg_protocol.rs @@ -387,7 +387,7 @@ where })?; self.ready_for_query()?; } - UserAuthenticator::ClearText(_) => { + UserAuthenticator::ClearText(_) | UserAuthenticator::OAuth => { self.stream .write_no_flush(&BeMessage::AuthenticationCleartextPassword)?; } diff --git a/src/utils/pgwire/src/pg_server.rs b/src/utils/pgwire/src/pg_server.rs index ba52215e4d34a..c26bf57b6c0b3 100644 --- a/src/utils/pgwire/src/pg_server.rs +++ b/src/utils/pgwire/src/pg_server.rs @@ -116,6 +116,7 @@ pub enum UserAuthenticator { encrypted_password: Vec, salt: [u8; 4], }, + OAuth, } impl UserAuthenticator { @@ -126,6 +127,10 @@ impl UserAuthenticator { UserAuthenticator::Md5WithSalt { encrypted_password, .. } => encrypted_password == password, + UserAuthenticator::OAuth => { + // TODO: OAuth authentication happens here. + true + } } } } From 1cc3d16560d734c3a53107cc48f2d8b4ea355c70 Mon Sep 17 00:00:00 2001 From: Rossil <40714231+Rossil2012@users.noreply.github.com> Date: Mon, 30 Oct 2023 18:15:37 +0800 Subject: [PATCH 02/12] print oauth token for debug --- src/utils/pgwire/src/pg_server.rs | 1 + 1 file changed, 1 insertion(+) diff --git a/src/utils/pgwire/src/pg_server.rs b/src/utils/pgwire/src/pg_server.rs index c26bf57b6c0b3..e4def7482d26d 100644 --- a/src/utils/pgwire/src/pg_server.rs +++ b/src/utils/pgwire/src/pg_server.rs @@ -129,6 +129,7 @@ impl UserAuthenticator { } => encrypted_password == password, UserAuthenticator::OAuth => { // TODO: OAuth authentication happens here. + tracing::info!("OAuth authenticator gets: {}", String::from_utf8_lossy(password)); true } } From 5d4bfe24876d7f977f76d0ae728e923f18cad9a3 Mon Sep 17 00:00:00 2001 From: Rossil <40714231+Rossil2012@users.noreply.github.com> Date: Wed, 21 Feb 2024 19:26:20 +0800 Subject: [PATCH 03/12] impl jwt auth --- Cargo.lock | 3 + proto/meta.proto | 1 + src/common/src/system_param/mod.rs | 33 +++++---- src/common/src/system_param/reader.rs | 7 ++ src/config/docs.md | 1 + src/config/example.toml | 1 + src/frontend/src/session.rs | 5 ++ src/frontend/src/user/user_authentication.rs | 2 +- src/utils/pgwire/Cargo.toml | 3 + src/utils/pgwire/src/pg_protocol.rs | 10 +-- src/utils/pgwire/src/pg_server.rs | 75 ++++++++++++++++++-- 11 files changed, 114 insertions(+), 27 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 74c6eee85703a..7ca5fda6a0f13 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -7280,12 +7280,15 @@ dependencies = [ "bytes", "futures", "itertools 0.12.0", + "jsonwebtoken 9.2.0", "madsim-tokio", "openssl", "panic-message", "parking_lot 0.12.1", + "reqwest", "risingwave_common", "risingwave_sqlparser", + "serde", "tempfile", "thiserror", "thiserror-ext", diff --git a/proto/meta.proto b/proto/meta.proto index 01492cc0c4fff..ec03c1c91e148 100644 --- a/proto/meta.proto +++ b/proto/meta.proto @@ -557,6 +557,7 @@ message SystemParams { optional bool pause_on_next_bootstrap = 13; optional string wasm_storage_url = 14; optional bool enable_tracing = 15; + optional string oauth_jwks_url = 16; } message GetSystemParamsRequest {} diff --git a/src/common/src/system_param/mod.rs b/src/common/src/system_param/mod.rs index 278390887dd51..5894a2dc71275 100644 --- a/src/common/src/system_param/mod.rs +++ b/src/common/src/system_param/mod.rs @@ -73,21 +73,23 @@ impl_param_value!(String => &'a str); macro_rules! for_all_params { ($macro:ident) => { $macro! { - // name type default value mut? doc - { barrier_interval_ms, u32, Some(1000_u32), true, "The interval of periodic barrier.", }, - { checkpoint_frequency, u64, Some(1_u64), true, "There will be a checkpoint for every n barriers.", }, - { sstable_size_mb, u32, Some(256_u32), false, "Target size of the Sstable.", }, - { parallel_compact_size_mb, u32, Some(512_u32), false, "", }, - { block_size_kb, u32, Some(64_u32), false, "Size of each block in bytes in SST.", }, - { bloom_false_positive, f64, Some(0.001_f64), false, "False positive probability of bloom filter.", }, - { state_store, String, None, false, "", }, - { data_directory, String, None, false, "Remote directory for storing data and metadata objects.", }, - { backup_storage_url, String, None, true, "Remote storage url for storing snapshots.", }, - { backup_storage_directory, String, None, true, "Remote directory for storing snapshots.", }, - { max_concurrent_creating_streaming_jobs, u32, Some(1_u32), true, "Max number of concurrent creating streaming jobs.", }, - { pause_on_next_bootstrap, bool, Some(false), true, "Whether to pause all data sources on next bootstrap.", }, - { wasm_storage_url, String, Some("fs://.risingwave/data".to_string()), false, "", }, - { enable_tracing, bool, Some(false), true, "Whether to enable distributed tracing.", }, + // name type default value mut? doc + { barrier_interval_ms, u32, Some(1000_u32), true, "The interval of periodic barrier.", }, + { checkpoint_frequency, u64, Some(1_u64), true, "There will be a checkpoint for every n barriers.", }, + { sstable_size_mb, u32, Some(256_u32), false, "Target size of the Sstable.", }, + { parallel_compact_size_mb, u32, Some(512_u32), false, "", }, + { block_size_kb, u32, Some(64_u32), false, "Size of each block in bytes in SST.", }, + { bloom_false_positive, f64, Some(0.001_f64), false, "False positive probability of bloom filter.", }, + { state_store, String, None, false, "", }, + { data_directory, String, None, false, "Remote directory for storing data and metadata objects.", }, + { backup_storage_url, String, None, true, "Remote storage url for storing snapshots.", }, + { backup_storage_directory, String, None, true, "Remote directory for storing snapshots.", }, + { max_concurrent_creating_streaming_jobs, u32, Some(1_u32), true, "Max number of concurrent creating streaming jobs.", }, + { pause_on_next_bootstrap, bool, Some(false), true, "Whether to pause all data sources on next bootstrap.", }, + { wasm_storage_url, String, Some("fs://.risingwave/data".to_string()), false, "", }, + { enable_tracing, bool, Some(false), true, "Whether to enable distributed tracing.", }, + // TODO: modify default value + { oauth_jwks_url, String, Some("https://auth-static.confluent.io/jwks".to_string()), true, "Url to get JSON Web Key Set(JWKS) for oauth authentication.", }, } }; } @@ -442,6 +444,7 @@ mod tests { (PAUSE_ON_NEXT_BOOTSTRAP_KEY, "false"), (WASM_STORAGE_URL_KEY, "a"), (ENABLE_TRACING_KEY, "true"), + (OAUTH_JWKS_URL_KEY, "a"), ("a_deprecated_param", "foo"), ]; diff --git a/src/common/src/system_param/reader.rs b/src/common/src/system_param/reader.rs index cf17c7bb43dd5..442be850f0004 100644 --- a/src/common/src/system_param/reader.rs +++ b/src/common/src/system_param/reader.rs @@ -167,4 +167,11 @@ where .as_ref() .unwrap_or(&default::WASM_STORAGE_URL) } + + fn oauth_jwks_url(&self) -> &str { + self.inner() + .oauth_jwks_url + .as_ref() + .unwrap_or(&default::OAUTH_JWKS_URL) + } } diff --git a/src/config/docs.md b/src/config/docs.md index 36fd40ce2d13a..db1725a9c0765 100644 --- a/src/config/docs.md +++ b/src/config/docs.md @@ -151,3 +151,4 @@ This page is automatically generated by `./risedev generate-example-config` | sstable_size_mb | Target size of the Sstable. | 256 | | state_store | | | | wasm_storage_url | | "fs://.risingwave/data" | +| oauth_jwks_url | Url to get JSON Web Key Set(JWKS) for oauth authentication. | https://auth-static.confluent.io/jwks | diff --git a/src/config/example.toml b/src/config/example.toml index 59c68aff3c7c0..4c645eaddcc99 100644 --- a/src/config/example.toml +++ b/src/config/example.toml @@ -197,3 +197,4 @@ max_concurrent_creating_streaming_jobs = 1 pause_on_next_bootstrap = false wasm_storage_url = "fs://.risingwave/data" enable_tracing = false +oauth_jwks_url = "https://auth-static.confluent.io/jwks" diff --git a/src/frontend/src/session.rs b/src/frontend/src/session.rs index 976e42fc2685b..dc407e15b2846 100644 --- a/src/frontend/src/session.rs +++ b/src/frontend/src/session.rs @@ -46,6 +46,7 @@ use risingwave_common::session_config::{ConfigMap, ConfigReporter, VisibilityMod use risingwave_common::system_param::local_manager::{ LocalSystemParamsManager, LocalSystemParamsManagerRef, }; +use risingwave_common::system_param::reader::SystemParamsReader; use risingwave_common::telemetry::manager::TelemetryManager; use risingwave_common::telemetry::telemetry_env_enabled; use risingwave_common::types::DataType; @@ -1089,6 +1090,10 @@ impl Session for SessionImpl { &self.user_authenticator } + async fn get_system_params(&self) -> std::result::Result { + Ok(self.env.meta_client.get_system_params().await?) + } + fn id(&self) -> SessionId { self.id } diff --git a/src/frontend/src/user/user_authentication.rs b/src/frontend/src/user/user_authentication.rs index 3c475b79e1ff0..32e1211b0f2b8 100644 --- a/src/frontend/src/user/user_authentication.rs +++ b/src/frontend/src/user/user_authentication.rs @@ -24,7 +24,7 @@ const MD5_ENCRYPTED_PREFIX: &str = "md5"; const VALID_SHA256_ENCRYPTED_LEN: usize = SHA256_ENCRYPTED_PREFIX.len() + 64; const VALID_MD5_ENCRYPTED_LEN: usize = MD5_ENCRYPTED_PREFIX.len() + 32; -/// Build AuthInfo for OAuth. +/// Build `AuthInfo` for `OAuth`. #[inline(always)] pub fn build_oauth_info() -> Option { Some(AuthInfo { diff --git a/src/utils/pgwire/Cargo.toml b/src/utils/pgwire/Cargo.toml index 0e5b4e98faefd..907b2ac5904db 100644 --- a/src/utils/pgwire/Cargo.toml +++ b/src/utils/pgwire/Cargo.toml @@ -21,11 +21,14 @@ byteorder = "1.5" bytes = "1" futures = { version = "0.3", default-features = false, features = ["alloc"] } itertools = "0.12" +jsonwebtoken = "9" openssl = "0.10.60" panic-message = "0.3" parking_lot = "0.12" +reqwest = { version = "0.11" } risingwave_common = { workspace = true } risingwave_sqlparser = { workspace = true } +serde = { version = "1", features = ["derive"] } thiserror = "1" thiserror-ext = { workspace = true } tokio = { version = "0.2", package = "madsim-tokio", features = ["rt", "macros"] } diff --git a/src/utils/pgwire/src/pg_protocol.rs b/src/utils/pgwire/src/pg_protocol.rs index 1c12ed308c393..cffcc8459a00e 100644 --- a/src/utils/pgwire/src/pg_protocol.rs +++ b/src/utils/pgwire/src/pg_protocol.rs @@ -387,7 +387,7 @@ where match msg { FeMessage::Ssl => self.process_ssl_msg().await?, FeMessage::Startup(msg) => self.process_startup_msg(msg)?, - FeMessage::Password(msg) => self.process_password_msg(msg)?, + FeMessage::Password(msg) => self.process_password_msg(msg).await?, FeMessage::Query(query_msg) => self.process_query_msg(query_msg.get_sql()).await?, FeMessage::CancelQuery(m) => self.process_cancel_msg(m)?, FeMessage::Terminate => self.process_terminate(), @@ -523,11 +523,11 @@ where Ok(()) } - fn process_password_msg(&mut self, msg: FePasswordMessage) -> PsqlResult<()> { + async fn process_password_msg(&mut self, msg: FePasswordMessage) -> PsqlResult<()> { let authenticator = self.session.as_ref().unwrap().user_authenticator(); - if !authenticator.authenticate(&msg.password) { - return Err(PsqlError::PasswordError); - } + authenticator + .authenticate(&msg.password, Arc::clone(self.session.as_ref().unwrap())) + .await?; self.stream.write_no_flush(&BeMessage::AuthenticationOk)?; self.stream .write_parameter_status_msg_no_flush(&ParameterStatus::default())?; diff --git a/src/utils/pgwire/src/pg_server.rs b/src/utils/pgwire/src/pg_server.rs index fd62bc1ec3207..f74b92240a2ba 100644 --- a/src/utils/pgwire/src/pg_server.rs +++ b/src/utils/pgwire/src/pg_server.rs @@ -12,20 +12,25 @@ // See the License for the specific language governing permissions and // limitations under the License. +use std::collections::HashMap; use std::future::Future; use std::io; use std::result::Result; +use std::str::FromStr; use std::sync::Arc; use std::time::Instant; use bytes::Bytes; +use jsonwebtoken::{decode, decode_header, Algorithm, DecodingKey, Validation}; use parking_lot::Mutex; +use risingwave_common::system_param::reader::{SystemParamsRead, SystemParamsReader}; use risingwave_common::types::DataType; use risingwave_sqlparser::ast::Statement; +use serde::Deserialize; use thiserror_ext::AsReport; use tokio::io::{AsyncRead, AsyncWrite}; -use crate::error::PsqlResult; +use crate::error::{PsqlError, PsqlResult}; use crate::net::{AddressRef, Listener}; use crate::pg_field_descriptor::PgFieldDescriptor; use crate::pg_message::TransactionStatus; @@ -107,6 +112,10 @@ pub trait Session: Send + Sync { fn user_authenticator(&self) -> &UserAuthenticator; + fn get_system_params( + &self, + ) -> impl Future> + Send; + fn id(&self) -> SessionId; fn set_config(&self, key: &str, value: String) -> Result<(), BoxedError>; @@ -158,20 +167,69 @@ pub enum UserAuthenticator { OAuth, } +#[derive(Debug, Deserialize)] +struct Jwks { + keys: Vec, +} + +#[allow(dead_code)] +#[derive(Debug, Deserialize)] +struct Jwk { + kid: String, + alg: String, + n: String, + e: String, +} + +async fn fetch_jwks(url: &str) -> Result { + let resp: Jwks = reqwest::get(url).await?.json().await?; + Ok(resp) +} + +async fn validate_jwt(jwt: &str, jwks_url: &str) -> Result { + let header = decode_header(jwt)?; + let jwks = fetch_jwks(jwks_url).await?; + + let kid = header.kid.ok_or("kid not found in jwt header")?; + let jwk = jwks + .keys + .into_iter() + .find(|k| k.kid == kid) + .ok_or("kid not found in jwks")?; + + let decoding_key = DecodingKey::from_rsa_components(&jwk.n, &jwk.e)?; + let validation = Validation::new(Algorithm::from_str(&jwk.alg)?); + + Ok(decode::>(jwt, &decoding_key, &validation).is_ok()) +} + impl UserAuthenticator { - pub fn authenticate(&self, password: &[u8]) -> bool { - match self { + pub async fn authenticate( + &self, + password: &[u8], + session: Arc, + ) -> PsqlResult<()> { + let success = match self { UserAuthenticator::None => true, UserAuthenticator::ClearText(text) => password == text, UserAuthenticator::Md5WithSalt { encrypted_password, .. } => encrypted_password == password, UserAuthenticator::OAuth => { - // TODO: OAuth authentication happens here. - tracing::info!("OAuth authenticator gets: {}", String::from_utf8_lossy(password)); - true + let system_params_reader = session + .get_system_params() + .await + .map_err(PsqlError::StartupError)?; + let oauth_jwks_url = system_params_reader.oauth_jwks_url(); + validate_jwt(&String::from_utf8_lossy(password), oauth_jwks_url) + .await + .map_err(PsqlError::StartupError)? } + }; + if !success { + return Err(PsqlError::PasswordError); } + Ok(()) } } @@ -239,6 +297,7 @@ mod tests { use bytes::Bytes; use futures::stream::BoxStream; use futures::StreamExt; + use risingwave_common::system_param::reader::SystemParamsReader; use risingwave_common::types::DataType; use risingwave_sqlparser::ast::Statement; use tokio_postgres::NoTls; @@ -361,6 +420,10 @@ mod tests { &UserAuthenticator::None } + async fn get_system_params(&self) -> Result { + Ok(SystemParamsReader::new(Default::default())) + } + fn id(&self) -> SessionId { (0, 0) } From 106d5f87ac5fe0f35e6a9527d214f28b7324b146 Mon Sep 17 00:00:00 2001 From: Rossil <40714231+Rossil2012@users.noreply.github.com> Date: Fri, 23 Feb 2024 11:45:33 +0800 Subject: [PATCH 04/12] get_url on connect --- proto/user.proto | 2 +- src/frontend/src/session.rs | 50 ++++++++++++-------- src/frontend/src/test_utils.rs | 6 +-- src/frontend/src/user/user_authentication.rs | 4 +- src/utils/pgwire/src/pg_protocol.rs | 13 +++-- src/utils/pgwire/src/pg_server.rs | 29 ++++-------- 6 files changed, 52 insertions(+), 52 deletions(-) diff --git a/proto/user.proto b/proto/user.proto index 9bc5fee1748cc..dd04dd558a6a3 100644 --- a/proto/user.proto +++ b/proto/user.proto @@ -14,7 +14,7 @@ message AuthInfo { PLAINTEXT = 1; SHA256 = 2; MD5 = 3; - OAuth = 4; + OAUTH = 4; } EncryptionType encryption_type = 1; bytes encrypted_value = 2; diff --git a/src/frontend/src/session.rs b/src/frontend/src/session.rs index dc407e15b2846..8c2f201d961f7 100644 --- a/src/frontend/src/session.rs +++ b/src/frontend/src/session.rs @@ -46,7 +46,7 @@ use risingwave_common::session_config::{ConfigMap, ConfigReporter, VisibilityMod use risingwave_common::system_param::local_manager::{ LocalSystemParamsManager, LocalSystemParamsManagerRef, }; -use risingwave_common::system_param::reader::SystemParamsReader; +use risingwave_common::system_param::reader::{SystemParamsRead, SystemParamsReader}; use risingwave_common::telemetry::manager::TelemetryManager; use risingwave_common::telemetry::telemetry_env_enabled; use risingwave_common::types::DataType; @@ -925,26 +925,29 @@ pub struct SessionManagerImpl { impl SessionManager for SessionManagerImpl { type Session = SessionImpl; - fn connect( + async fn connect( &self, - database: &str, - user_name: &str, + database: String, + user_name: String, peer_addr: AddressRef, ) -> std::result::Result, BoxedError> { - let catalog_reader = self.env.catalog_reader(); - let reader = catalog_reader.read_guard(); - let database_id = reader - .get_database_by_name(database) - .map_err(|_| { - Box::new(Error::new( - ErrorKind::InvalidInput, - format!("database \"{}\" does not exist", database), - )) - })? - .id(); - let user_reader = self.env.user_info_reader(); - let reader = user_reader.read_guard(); - if let Some(user) = reader.get_user_by_name(user_name) { + let database_id = { + let catalog_reader = self.env.catalog_reader().read_guard(); + catalog_reader + .get_database_by_name(&database) + .map_err(|_| { + Box::new(Error::new( + ErrorKind::InvalidInput, + format!("database \"{}\" does not exist", database), + )) + })? + .id() + }; + let user = { + let user_reader = self.env.user_info_reader().read_guard(); + user_reader.get_user_by_name(&user_name).cloned() + }; + if let Some(user) = user { if !user.can_login { return Err(Box::new(Error::new( ErrorKind::InvalidInput, @@ -975,8 +978,15 @@ impl SessionManager for SessionManagerImpl { ), salt, } - } else if auth_info.encryption_type == EncryptionType::OAuth as i32 { - UserAuthenticator::OAuth + } else if auth_info.encryption_type == EncryptionType::Oauth as i32 { + let reader = self + .env + .meta_client() + .get_system_params() + .await + .map_err(|e| PsqlError::StartupError(e.into()))?; + let oauth_jwks_url = reader.oauth_jwks_url().to_string(); + UserAuthenticator::OAuth(oauth_jwks_url) } else { return Err(Box::new(Error::new( ErrorKind::Unsupported, diff --git a/src/frontend/src/test_utils.rs b/src/frontend/src/test_utils.rs index 55772ba9ed068..1ff663f178aa4 100644 --- a/src/frontend/src/test_utils.rs +++ b/src/frontend/src/test_utils.rs @@ -80,10 +80,10 @@ pub struct LocalFrontend { impl SessionManager for LocalFrontend { type Session = SessionImpl; - fn connect( + async fn connect( &self, - _database: &str, - _user_name: &str, + _database: String, + _user_name: String, _peer_addr: AddressRef, ) -> std::result::Result, BoxedError> { Ok(self.session_ref()) diff --git a/src/frontend/src/user/user_authentication.rs b/src/frontend/src/user/user_authentication.rs index 32e1211b0f2b8..d558fb03ee3b6 100644 --- a/src/frontend/src/user/user_authentication.rs +++ b/src/frontend/src/user/user_authentication.rs @@ -28,7 +28,7 @@ const VALID_MD5_ENCRYPTED_LEN: usize = MD5_ENCRYPTED_PREFIX.len() + 32; #[inline(always)] pub fn build_oauth_info() -> Option { Some(AuthInfo { - encryption_type: EncryptionType::OAuth as i32, + encryption_type: EncryptionType::Oauth as i32, encrypted_value: Vec::new(), }) } @@ -90,7 +90,7 @@ pub fn encrypted_raw_password(info: &AuthInfo) -> String { EncryptionType::Plaintext => "", EncryptionType::Sha256 => SHA256_ENCRYPTED_PREFIX, EncryptionType::Md5 => MD5_ENCRYPTED_PREFIX, - EncryptionType::OAuth => "", + EncryptionType::Oauth => "", }; format!("{}{}", prefix, encrypted_pwd) } diff --git a/src/utils/pgwire/src/pg_protocol.rs b/src/utils/pgwire/src/pg_protocol.rs index cffcc8459a00e..5d01b7e07f40b 100644 --- a/src/utils/pgwire/src/pg_protocol.rs +++ b/src/utils/pgwire/src/pg_protocol.rs @@ -386,7 +386,7 @@ where match msg { FeMessage::Ssl => self.process_ssl_msg().await?, - FeMessage::Startup(msg) => self.process_startup_msg(msg)?, + FeMessage::Startup(msg) => self.process_startup_msg(msg).await?, FeMessage::Password(msg) => self.process_password_msg(msg).await?, FeMessage::Query(query_msg) => self.process_query_msg(query_msg.get_sql()).await?, FeMessage::CancelQuery(m) => self.process_cancel_msg(m)?, @@ -469,7 +469,7 @@ where Ok(()) } - fn process_startup_msg(&mut self, msg: FeStartupMessage) -> PsqlResult<()> { + async fn process_startup_msg(&mut self, msg: FeStartupMessage) -> PsqlResult<()> { let db_name = msg .config .get("database") @@ -483,7 +483,8 @@ where let session = self .session_mgr - .connect(&db_name, &user_name, self.peer_addr.clone()) + .connect(db_name, user_name, self.peer_addr.clone()) + .await .map_err(PsqlError::StartupError)?; let application_name = msg.config.get("application_name"); @@ -508,7 +509,7 @@ where })?; self.ready_for_query()?; } - UserAuthenticator::ClearText(_) | UserAuthenticator::OAuth => { + UserAuthenticator::ClearText(_) | UserAuthenticator::OAuth(_) => { self.stream .write_no_flush(&BeMessage::AuthenticationCleartextPassword)?; } @@ -525,9 +526,7 @@ where async fn process_password_msg(&mut self, msg: FePasswordMessage) -> PsqlResult<()> { let authenticator = self.session.as_ref().unwrap().user_authenticator(); - authenticator - .authenticate(&msg.password, Arc::clone(self.session.as_ref().unwrap())) - .await?; + authenticator.authenticate(&msg.password).await?; self.stream.write_no_flush(&BeMessage::AuthenticationOk)?; self.stream .write_parameter_status_msg_no_flush(&ParameterStatus::default())?; diff --git a/src/utils/pgwire/src/pg_server.rs b/src/utils/pgwire/src/pg_server.rs index f74b92240a2ba..5b2a81f273ba9 100644 --- a/src/utils/pgwire/src/pg_server.rs +++ b/src/utils/pgwire/src/pg_server.rs @@ -23,7 +23,7 @@ use std::time::Instant; use bytes::Bytes; use jsonwebtoken::{decode, decode_header, Algorithm, DecodingKey, Validation}; use parking_lot::Mutex; -use risingwave_common::system_param::reader::{SystemParamsRead, SystemParamsReader}; +use risingwave_common::system_param::reader::SystemParamsReader; use risingwave_common::types::DataType; use risingwave_sqlparser::ast::Statement; use serde::Deserialize; @@ -50,10 +50,10 @@ pub trait SessionManager: Send + Sync + 'static { fn connect( &self, - database: &str, - user_name: &str, + database: String, + user_name: String, peer_addr: AddressRef, - ) -> Result, BoxedError>; + ) -> impl Future, BoxedError>> + Send; fn cancel_queries_in_session(&self, session_id: SessionId); @@ -164,7 +164,7 @@ pub enum UserAuthenticator { encrypted_password: Vec, salt: [u8; 4], }, - OAuth, + OAuth(String), } #[derive(Debug, Deserialize)] @@ -204,23 +204,14 @@ async fn validate_jwt(jwt: &str, jwks_url: &str) -> Result { } impl UserAuthenticator { - pub async fn authenticate( - &self, - password: &[u8], - session: Arc, - ) -> PsqlResult<()> { + pub async fn authenticate(&self, password: &[u8]) -> PsqlResult<()> { let success = match self { UserAuthenticator::None => true, UserAuthenticator::ClearText(text) => password == text, UserAuthenticator::Md5WithSalt { encrypted_password, .. } => encrypted_password == password, - UserAuthenticator::OAuth => { - let system_params_reader = session - .get_system_params() - .await - .map_err(PsqlError::StartupError)?; - let oauth_jwks_url = system_params_reader.oauth_jwks_url(); + UserAuthenticator::OAuth(oauth_jwks_url) => { validate_jwt(&String::from_utf8_lossy(password), oauth_jwks_url) .await .map_err(PsqlError::StartupError)? @@ -319,10 +310,10 @@ mod tests { impl SessionManager for MockSessionManager { type Session = MockSession; - fn connect( + async fn connect( &self, - _database: &str, - _user_name: &str, + _database: String, + _user_name: String, _peer_addr: crate::net::AddressRef, ) -> Result, Box> { Ok(Arc::new(MockSession {})) From 716400da6372b03f8181326e74e8674de52f4b53 Mon Sep 17 00:00:00 2001 From: Rossil <40714231+Rossil2012@users.noreply.github.com> Date: Fri, 23 Feb 2024 12:13:40 +0800 Subject: [PATCH 05/12] use None as default --- src/common/src/system_param/mod.rs | 33 +++++++++++++-------------- src/common/src/system_param/reader.rs | 5 +--- src/config/docs.md | 2 +- 3 files changed, 18 insertions(+), 22 deletions(-) diff --git a/src/common/src/system_param/mod.rs b/src/common/src/system_param/mod.rs index 5894a2dc71275..e0169a24182d8 100644 --- a/src/common/src/system_param/mod.rs +++ b/src/common/src/system_param/mod.rs @@ -73,23 +73,22 @@ impl_param_value!(String => &'a str); macro_rules! for_all_params { ($macro:ident) => { $macro! { - // name type default value mut? doc - { barrier_interval_ms, u32, Some(1000_u32), true, "The interval of periodic barrier.", }, - { checkpoint_frequency, u64, Some(1_u64), true, "There will be a checkpoint for every n barriers.", }, - { sstable_size_mb, u32, Some(256_u32), false, "Target size of the Sstable.", }, - { parallel_compact_size_mb, u32, Some(512_u32), false, "", }, - { block_size_kb, u32, Some(64_u32), false, "Size of each block in bytes in SST.", }, - { bloom_false_positive, f64, Some(0.001_f64), false, "False positive probability of bloom filter.", }, - { state_store, String, None, false, "", }, - { data_directory, String, None, false, "Remote directory for storing data and metadata objects.", }, - { backup_storage_url, String, None, true, "Remote storage url for storing snapshots.", }, - { backup_storage_directory, String, None, true, "Remote directory for storing snapshots.", }, - { max_concurrent_creating_streaming_jobs, u32, Some(1_u32), true, "Max number of concurrent creating streaming jobs.", }, - { pause_on_next_bootstrap, bool, Some(false), true, "Whether to pause all data sources on next bootstrap.", }, - { wasm_storage_url, String, Some("fs://.risingwave/data".to_string()), false, "", }, - { enable_tracing, bool, Some(false), true, "Whether to enable distributed tracing.", }, - // TODO: modify default value - { oauth_jwks_url, String, Some("https://auth-static.confluent.io/jwks".to_string()), true, "Url to get JSON Web Key Set(JWKS) for oauth authentication.", }, + // name type default value mut? doc + { barrier_interval_ms, u32, Some(1000_u32), true, "The interval of periodic barrier.", }, + { checkpoint_frequency, u64, Some(1_u64), true, "There will be a checkpoint for every n barriers.", }, + { sstable_size_mb, u32, Some(256_u32), false, "Target size of the Sstable.", }, + { parallel_compact_size_mb, u32, Some(512_u32), false, "", }, + { block_size_kb, u32, Some(64_u32), false, "Size of each block in bytes in SST.", }, + { bloom_false_positive, f64, Some(0.001_f64), false, "False positive probability of bloom filter.", }, + { state_store, String, None, false, "", }, + { data_directory, String, None, false, "Remote directory for storing data and metadata objects.", }, + { backup_storage_url, String, None, true, "Remote storage url for storing snapshots.", }, + { backup_storage_directory, String, None, true, "Remote directory for storing snapshots.", }, + { max_concurrent_creating_streaming_jobs, u32, Some(1_u32), true, "Max number of concurrent creating streaming jobs.", }, + { pause_on_next_bootstrap, bool, Some(false), true, "Whether to pause all data sources on next bootstrap.", }, + { wasm_storage_url, String, Some("fs://.risingwave/data".to_string()), false, "", }, + { enable_tracing, bool, Some(false), true, "Whether to enable distributed tracing.", }, + { oauth_jwks_url, String, None, true, "Url to get JSON Web Key Set(JWKS) for oauth authentication.", }, } }; } diff --git a/src/common/src/system_param/reader.rs b/src/common/src/system_param/reader.rs index 442be850f0004..821c5fd80cc2b 100644 --- a/src/common/src/system_param/reader.rs +++ b/src/common/src/system_param/reader.rs @@ -169,9 +169,6 @@ where } fn oauth_jwks_url(&self) -> &str { - self.inner() - .oauth_jwks_url - .as_ref() - .unwrap_or(&default::OAUTH_JWKS_URL) + self.inner().oauth_jwks_url.as_ref().unwrap() } } diff --git a/src/config/docs.md b/src/config/docs.md index db1725a9c0765..b7a72755e9e6b 100644 --- a/src/config/docs.md +++ b/src/config/docs.md @@ -151,4 +151,4 @@ This page is automatically generated by `./risedev generate-example-config` | sstable_size_mb | Target size of the Sstable. | 256 | | state_store | | | | wasm_storage_url | | "fs://.risingwave/data" | -| oauth_jwks_url | Url to get JSON Web Key Set(JWKS) for oauth authentication. | https://auth-static.confluent.io/jwks | +| oauth_jwks_url | Url to get JSON Web Key Set(JWKS) for oauth authentication. | | From 7eeef9fb9acdd9272491d6c909e6150e41f09495 Mon Sep 17 00:00:00 2001 From: Rossil <40714231+Rossil2012@users.noreply.github.com> Date: Fri, 23 Feb 2024 13:53:52 +0800 Subject: [PATCH 06/12] fix unit test --- src/common/src/system_param/mod.rs | 1 + src/config/docs.md | 2 +- src/config/example.toml | 1 - 3 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/common/src/system_param/mod.rs b/src/common/src/system_param/mod.rs index e0169a24182d8..b80b6a331f57a 100644 --- a/src/common/src/system_param/mod.rs +++ b/src/common/src/system_param/mod.rs @@ -377,6 +377,7 @@ macro_rules! impl_system_params_for_test { ret.state_store = Some("hummock+memory".to_string()); ret.backup_storage_url = Some("memory".into()); ret.backup_storage_directory = Some("backup".into()); + ret.oauth_jwks_url = Some("https://auth-static.confluent.io/jwks".into()); ret } }; diff --git a/src/config/docs.md b/src/config/docs.md index b7a72755e9e6b..5488241712ca0 100644 --- a/src/config/docs.md +++ b/src/config/docs.md @@ -146,9 +146,9 @@ This page is automatically generated by `./risedev generate-example-config` | data_directory | Remote directory for storing data and metadata objects. | | | enable_tracing | Whether to enable distributed tracing. | false | | max_concurrent_creating_streaming_jobs | Max number of concurrent creating streaming jobs. | 1 | +| oauth_jwks_url | Url to get JSON Web Key Set(JWKS) for oauth authentication. | | | parallel_compact_size_mb | | 512 | | pause_on_next_bootstrap | Whether to pause all data sources on next bootstrap. | false | | sstable_size_mb | Target size of the Sstable. | 256 | | state_store | | | | wasm_storage_url | | "fs://.risingwave/data" | -| oauth_jwks_url | Url to get JSON Web Key Set(JWKS) for oauth authentication. | | diff --git a/src/config/example.toml b/src/config/example.toml index 4c645eaddcc99..59c68aff3c7c0 100644 --- a/src/config/example.toml +++ b/src/config/example.toml @@ -197,4 +197,3 @@ max_concurrent_creating_streaming_jobs = 1 pause_on_next_bootstrap = false wasm_storage_url = "fs://.risingwave/data" enable_tracing = false -oauth_jwks_url = "https://auth-static.confluent.io/jwks" From dc97f269747bda51a36c358ec01db7f7426457ab Mon Sep 17 00:00:00 2001 From: August Date: Thu, 29 Feb 2024 02:17:08 +0800 Subject: [PATCH 07/12] fix default and add validation --- src/common/src/system_param/reader.rs | 5 +++- src/frontend/src/session.rs | 34 ++++++++++++++------------- src/frontend/src/test_utils.rs | 6 ++--- src/utils/pgwire/src/pg_protocol.rs | 7 +++--- src/utils/pgwire/src/pg_server.rs | 22 +++++------------ 5 files changed, 34 insertions(+), 40 deletions(-) diff --git a/src/common/src/system_param/reader.rs b/src/common/src/system_param/reader.rs index 72d7ef5ad7a02..7b0e0d4667e08 100644 --- a/src/common/src/system_param/reader.rs +++ b/src/common/src/system_param/reader.rs @@ -162,6 +162,9 @@ where } fn oauth_jwks_url(&self) -> &str { - self.inner().oauth_jwks_url.as_ref().unwrap() + self.inner() + .oauth_jwks_url + .as_ref() + .unwrap_or(&default::OAUTH_JWKS_URL) } } diff --git a/src/frontend/src/session.rs b/src/frontend/src/session.rs index 9feb30f7c53e5..dfb14380e34cd 100644 --- a/src/frontend/src/session.rs +++ b/src/frontend/src/session.rs @@ -46,7 +46,7 @@ use risingwave_common::session_config::{ConfigMap, ConfigReporter, VisibilityMod use risingwave_common::system_param::local_manager::{ LocalSystemParamsManager, LocalSystemParamsManagerRef, }; -use risingwave_common::system_param::reader::{SystemParamsRead, SystemParamsReader}; +use risingwave_common::system_param::reader::SystemParamsRead; use risingwave_common::telemetry::manager::TelemetryManager; use risingwave_common::telemetry::telemetry_env_enabled; use risingwave_common::types::DataType; @@ -927,16 +927,16 @@ pub struct SessionManagerImpl { impl SessionManager for SessionManagerImpl { type Session = SessionImpl; - async fn connect( + fn connect( &self, - database: String, - user_name: String, + database: &str, + user_name: &str, peer_addr: AddressRef, ) -> std::result::Result, BoxedError> { let database_id = { let catalog_reader = self.env.catalog_reader().read_guard(); catalog_reader - .get_database_by_name(&database) + .get_database_by_name(database) .map_err(|_| { Box::new(Error::new( ErrorKind::InvalidInput, @@ -947,7 +947,7 @@ impl SessionManager for SessionManagerImpl { }; let user = { let user_reader = self.env.user_info_reader().read_guard(); - user_reader.get_user_by_name(&user_name).cloned() + user_reader.get_user_by_name(user_name).cloned() }; if let Some(user) = user { if !user.can_login { @@ -981,13 +981,19 @@ impl SessionManager for SessionManagerImpl { salt, } } else if auth_info.encryption_type == EncryptionType::Oauth as i32 { - let reader = self + let oauth_jwks_url = self .env - .meta_client() - .get_system_params() - .await - .map_err(|e| PsqlError::StartupError(e.into()))?; - let oauth_jwks_url = reader.oauth_jwks_url().to_string(); + .system_params_manager + .get_params() + .load() + .oauth_jwks_url() + .to_string(); + if oauth_jwks_url.is_empty() { + return Err(Box::new(Error::new( + ErrorKind::PermissionDenied, + "OAuth JWKS URL is not set", + ))); + } UserAuthenticator::OAuth(oauth_jwks_url) } else { return Err(Box::new(Error::new( @@ -1102,10 +1108,6 @@ impl Session for SessionImpl { &self.user_authenticator } - async fn get_system_params(&self) -> std::result::Result { - Ok(self.env.meta_client.get_system_params().await?) - } - fn id(&self) -> SessionId { self.id } diff --git a/src/frontend/src/test_utils.rs b/src/frontend/src/test_utils.rs index 1ff663f178aa4..55772ba9ed068 100644 --- a/src/frontend/src/test_utils.rs +++ b/src/frontend/src/test_utils.rs @@ -80,10 +80,10 @@ pub struct LocalFrontend { impl SessionManager for LocalFrontend { type Session = SessionImpl; - async fn connect( + fn connect( &self, - _database: String, - _user_name: String, + _database: &str, + _user_name: &str, _peer_addr: AddressRef, ) -> std::result::Result, BoxedError> { Ok(self.session_ref()) diff --git a/src/utils/pgwire/src/pg_protocol.rs b/src/utils/pgwire/src/pg_protocol.rs index 5d01b7e07f40b..18411b1a02359 100644 --- a/src/utils/pgwire/src/pg_protocol.rs +++ b/src/utils/pgwire/src/pg_protocol.rs @@ -386,7 +386,7 @@ where match msg { FeMessage::Ssl => self.process_ssl_msg().await?, - FeMessage::Startup(msg) => self.process_startup_msg(msg).await?, + FeMessage::Startup(msg) => self.process_startup_msg(msg)?, FeMessage::Password(msg) => self.process_password_msg(msg).await?, FeMessage::Query(query_msg) => self.process_query_msg(query_msg.get_sql()).await?, FeMessage::CancelQuery(m) => self.process_cancel_msg(m)?, @@ -469,7 +469,7 @@ where Ok(()) } - async fn process_startup_msg(&mut self, msg: FeStartupMessage) -> PsqlResult<()> { + fn process_startup_msg(&mut self, msg: FeStartupMessage) -> PsqlResult<()> { let db_name = msg .config .get("database") @@ -483,8 +483,7 @@ where let session = self .session_mgr - .connect(db_name, user_name, self.peer_addr.clone()) - .await + .connect(&db_name, &user_name, self.peer_addr.clone()) .map_err(PsqlError::StartupError)?; let application_name = msg.config.get("application_name"); diff --git a/src/utils/pgwire/src/pg_server.rs b/src/utils/pgwire/src/pg_server.rs index 5b2a81f273ba9..eb9c31442b85c 100644 --- a/src/utils/pgwire/src/pg_server.rs +++ b/src/utils/pgwire/src/pg_server.rs @@ -23,7 +23,6 @@ use std::time::Instant; use bytes::Bytes; use jsonwebtoken::{decode, decode_header, Algorithm, DecodingKey, Validation}; use parking_lot::Mutex; -use risingwave_common::system_param::reader::SystemParamsReader; use risingwave_common::types::DataType; use risingwave_sqlparser::ast::Statement; use serde::Deserialize; @@ -50,10 +49,10 @@ pub trait SessionManager: Send + Sync + 'static { fn connect( &self, - database: String, - user_name: String, + database: &str, + user_name: &str, peer_addr: AddressRef, - ) -> impl Future, BoxedError>> + Send; + ) -> Result, BoxedError>; fn cancel_queries_in_session(&self, session_id: SessionId); @@ -112,10 +111,6 @@ pub trait Session: Send + Sync { fn user_authenticator(&self) -> &UserAuthenticator; - fn get_system_params( - &self, - ) -> impl Future> + Send; - fn id(&self) -> SessionId; fn set_config(&self, key: &str, value: String) -> Result<(), BoxedError>; @@ -288,7 +283,6 @@ mod tests { use bytes::Bytes; use futures::stream::BoxStream; use futures::StreamExt; - use risingwave_common::system_param::reader::SystemParamsReader; use risingwave_common::types::DataType; use risingwave_sqlparser::ast::Statement; use tokio_postgres::NoTls; @@ -310,10 +304,10 @@ mod tests { impl SessionManager for MockSessionManager { type Session = MockSession; - async fn connect( + fn connect( &self, - _database: String, - _user_name: String, + _database: &str, + _user_name: &str, _peer_addr: crate::net::AddressRef, ) -> Result, Box> { Ok(Arc::new(MockSession {})) @@ -411,10 +405,6 @@ mod tests { &UserAuthenticator::None } - async fn get_system_params(&self) -> Result { - Ok(SystemParamsReader::new(Default::default())) - } - fn id(&self) -> SessionId { (0, 0) } From 80e410bdc65f5c6adfeab5722a20686136d91b1b Mon Sep 17 00:00:00 2001 From: August Date: Thu, 29 Feb 2024 02:22:00 +0800 Subject: [PATCH 08/12] update doc --- src/config/docs.md | 2 +- src/config/example.toml | 1 + src/frontend/src/session.rs | 31 ++++++++++++++----------------- 3 files changed, 16 insertions(+), 18 deletions(-) diff --git a/src/config/docs.md b/src/config/docs.md index 4486d320de07b..07736636b9e37 100644 --- a/src/config/docs.md +++ b/src/config/docs.md @@ -146,7 +146,7 @@ This page is automatically generated by `./risedev generate-example-config` | data_directory | Remote directory for storing data and metadata objects. | | | enable_tracing | Whether to enable distributed tracing. | false | | max_concurrent_creating_streaming_jobs | Max number of concurrent creating streaming jobs. | 1 | -| oauth_jwks_url | Url to get JSON Web Key Set(JWKS) for oauth authentication. | | +| oauth_jwks_url | Url to get JSON Web Key Set(JWKS) for oauth authentication. | "" | | parallel_compact_size_mb | | 512 | | pause_on_next_bootstrap | Whether to pause all data sources on next bootstrap. | false | | sstable_size_mb | Target size of the Sstable. | 256 | diff --git a/src/config/example.toml b/src/config/example.toml index a1d5fadb52c53..9d83b7d5502f5 100644 --- a/src/config/example.toml +++ b/src/config/example.toml @@ -196,3 +196,4 @@ bloom_false_positive = 0.001 max_concurrent_creating_streaming_jobs = 1 pause_on_next_bootstrap = false enable_tracing = false +oauth_jwks_url = "" diff --git a/src/frontend/src/session.rs b/src/frontend/src/session.rs index dfb14380e34cd..17e4f7ef09e0c 100644 --- a/src/frontend/src/session.rs +++ b/src/frontend/src/session.rs @@ -933,23 +933,20 @@ impl SessionManager for SessionManagerImpl { user_name: &str, peer_addr: AddressRef, ) -> std::result::Result, BoxedError> { - let database_id = { - let catalog_reader = self.env.catalog_reader().read_guard(); - catalog_reader - .get_database_by_name(database) - .map_err(|_| { - Box::new(Error::new( - ErrorKind::InvalidInput, - format!("database \"{}\" does not exist", database), - )) - })? - .id() - }; - let user = { - let user_reader = self.env.user_info_reader().read_guard(); - user_reader.get_user_by_name(user_name).cloned() - }; - if let Some(user) = user { + let catalog_reader = self.env.catalog_reader(); + let reader = catalog_reader.read_guard(); + let database_id = reader + .get_database_by_name(database) + .map_err(|_| { + Box::new(Error::new( + ErrorKind::InvalidInput, + format!("database \"{}\" does not exist", database), + )) + })? + .id(); + let user_reader = self.env.user_info_reader(); + let reader = user_reader.read_guard(); + if let Some(user) = reader.get_user_by_name(user_name) { if !user.can_login { return Err(Box::new(Error::new( ErrorKind::InvalidInput, From e490f802db86d408d62eb5aa16c5132d57f179c9 Mon Sep 17 00:00:00 2001 From: August Date: Thu, 29 Feb 2024 02:52:25 +0800 Subject: [PATCH 09/12] fix e2e --- e2e_test/batch/catalog/pg_settings.slt.part | 1 + 1 file changed, 1 insertion(+) diff --git a/e2e_test/batch/catalog/pg_settings.slt.part b/e2e_test/batch/catalog/pg_settings.slt.part index 9ff41b8dbea45..eeec9713382d2 100644 --- a/e2e_test/batch/catalog/pg_settings.slt.part +++ b/e2e_test/batch/catalog/pg_settings.slt.part @@ -13,6 +13,7 @@ postmaster barrier_interval_ms postmaster checkpoint_frequency postmaster enable_tracing postmaster max_concurrent_creating_streaming_jobs +postmaster oauth_jwks_url postmaster pause_on_next_bootstrap user application_name user background_ddl From 9ce4feea4f7a71a59405c294d1284120ee3c32d9 Mon Sep 17 00:00:00 2001 From: Rossil <40714231+Rossil2012@users.noreply.github.com> Date: Thu, 29 Feb 2024 09:29:00 +0800 Subject: [PATCH 10/12] fix comments --- e2e_test/batch/catalog/pg_settings.slt.part | 1 - proto/meta.proto | 1 - proto/user.proto | 1 + src/common/src/system_param/mod.rs | 3 --- src/common/src/system_param/reader.rs | 7 ------- src/config/docs.md | 1 - src/config/example.toml | 1 - src/frontend/src/handler/alter_user.rs | 21 ++++++++++++++----- src/frontend/src/handler/create_user.rs | 21 +++++++++++++++---- src/frontend/src/session.rs | 16 +------------- src/frontend/src/user/user_authentication.rs | 22 +++++++++++++++++++- src/sqlparser/src/ast/statement.rs | 11 +++++++--- src/utils/pgwire/src/pg_server.rs | 19 ++++++++++++----- 13 files changed, 78 insertions(+), 47 deletions(-) diff --git a/e2e_test/batch/catalog/pg_settings.slt.part b/e2e_test/batch/catalog/pg_settings.slt.part index eeec9713382d2..9ff41b8dbea45 100644 --- a/e2e_test/batch/catalog/pg_settings.slt.part +++ b/e2e_test/batch/catalog/pg_settings.slt.part @@ -13,7 +13,6 @@ postmaster barrier_interval_ms postmaster checkpoint_frequency postmaster enable_tracing postmaster max_concurrent_creating_streaming_jobs -postmaster oauth_jwks_url postmaster pause_on_next_bootstrap user application_name user background_ddl diff --git a/proto/meta.proto b/proto/meta.proto index 4cb08f872f6a2..1db290af7b308 100644 --- a/proto/meta.proto +++ b/proto/meta.proto @@ -557,7 +557,6 @@ message SystemParams { optional bool pause_on_next_bootstrap = 13; optional string wasm_storage_url = 14 [deprecated = true]; optional bool enable_tracing = 15; - optional string oauth_jwks_url = 16; } message GetSystemParamsRequest {} diff --git a/proto/user.proto b/proto/user.proto index dd04dd558a6a3..0ebb1cb30649b 100644 --- a/proto/user.proto +++ b/proto/user.proto @@ -18,6 +18,7 @@ message AuthInfo { } EncryptionType encryption_type = 1; bytes encrypted_value = 2; + map meta_data = 3; } // User defines a user in the system. diff --git a/src/common/src/system_param/mod.rs b/src/common/src/system_param/mod.rs index 998450a7f79dd..19c36baf09c68 100644 --- a/src/common/src/system_param/mod.rs +++ b/src/common/src/system_param/mod.rs @@ -87,7 +87,6 @@ macro_rules! for_all_params { { max_concurrent_creating_streaming_jobs, u32, Some(1_u32), true, "Max number of concurrent creating streaming jobs.", }, { pause_on_next_bootstrap, bool, Some(false), true, "Whether to pause all data sources on next bootstrap.", }, { enable_tracing, bool, Some(false), true, "Whether to enable distributed tracing.", }, - { oauth_jwks_url, String, Some("".to_string()), true, "Url to get JSON Web Key Set(JWKS) for oauth authentication.", }, } }; } @@ -376,7 +375,6 @@ macro_rules! impl_system_params_for_test { ret.state_store = Some("hummock+memory".to_string()); ret.backup_storage_url = Some("memory".into()); ret.backup_storage_directory = Some("backup".into()); - ret.oauth_jwks_url = Some("https://auth-static.confluent.io/jwks".into()); ret } }; @@ -442,7 +440,6 @@ mod tests { (MAX_CONCURRENT_CREATING_STREAMING_JOBS_KEY, "1"), (PAUSE_ON_NEXT_BOOTSTRAP_KEY, "false"), (ENABLE_TRACING_KEY, "true"), - (OAUTH_JWKS_URL_KEY, "a"), ("a_deprecated_param", "foo"), ]; diff --git a/src/common/src/system_param/reader.rs b/src/common/src/system_param/reader.rs index 7b0e0d4667e08..3374e72120238 100644 --- a/src/common/src/system_param/reader.rs +++ b/src/common/src/system_param/reader.rs @@ -160,11 +160,4 @@ where .enable_tracing .unwrap_or_else(default::enable_tracing) } - - fn oauth_jwks_url(&self) -> &str { - self.inner() - .oauth_jwks_url - .as_ref() - .unwrap_or(&default::OAUTH_JWKS_URL) - } } diff --git a/src/config/docs.md b/src/config/docs.md index 07736636b9e37..63e8f7ce1278d 100644 --- a/src/config/docs.md +++ b/src/config/docs.md @@ -146,7 +146,6 @@ This page is automatically generated by `./risedev generate-example-config` | data_directory | Remote directory for storing data and metadata objects. | | | enable_tracing | Whether to enable distributed tracing. | false | | max_concurrent_creating_streaming_jobs | Max number of concurrent creating streaming jobs. | 1 | -| oauth_jwks_url | Url to get JSON Web Key Set(JWKS) for oauth authentication. | "" | | parallel_compact_size_mb | | 512 | | pause_on_next_bootstrap | Whether to pause all data sources on next bootstrap. | false | | sstable_size_mb | Target size of the Sstable. | 256 | diff --git a/src/config/example.toml b/src/config/example.toml index 9d83b7d5502f5..a1d5fadb52c53 100644 --- a/src/config/example.toml +++ b/src/config/example.toml @@ -196,4 +196,3 @@ bloom_false_positive = 0.001 max_concurrent_creating_streaming_jobs = 1 pause_on_next_bootstrap = false enable_tracing = false -oauth_jwks_url = "" diff --git a/src/frontend/src/handler/alter_user.rs b/src/frontend/src/handler/alter_user.rs index 062b9c401b5a1..05dffaa57ef0d 100644 --- a/src/frontend/src/handler/alter_user.rs +++ b/src/frontend/src/handler/alter_user.rs @@ -20,10 +20,12 @@ use risingwave_sqlparser::ast::{AlterUserStatement, ObjectName, UserOption, User use super::RwPgResponse; use crate::binder::Binder; use crate::catalog::CatalogError; -use crate::error::ErrorCode::{InternalError, PermissionDenied}; +use crate::error::ErrorCode::{self, InternalError, PermissionDenied}; use crate::error::Result; use crate::handler::HandlerArgs; -use crate::user::user_authentication::{build_oauth_info, encrypted_password}; +use crate::user::user_authentication::{ + build_oauth_info, encrypted_password, OAUTH_ISSUER_KEY, OAUTH_JWKS_URL_KEY, +}; use crate::user::user_catalog::UserCatalog; fn alter_prost_user_info( @@ -111,8 +113,14 @@ fn alter_prost_user_info( } update_fields.push(UpdateField::AuthInfo); } - UserOption::OAuth => { - user_info.auth_info = build_oauth_info(); + UserOption::OAuth(options) => { + let auth_info = build_oauth_info(options).ok_or_else(|| { + ErrorCode::InvalidParameterValue(format!( + "{} and {} must be provided", + OAUTH_JWKS_URL_KEY, OAUTH_ISSUER_KEY + )) + })?; + user_info.auth_info = Some(auth_info); update_fields.push(UpdateField::AuthInfo) } } @@ -185,6 +193,8 @@ pub async fn handle_alter_user( #[cfg(test)] mod tests { + use std::collections::HashMap; + use risingwave_pb::user::auth_info::EncryptionType; use risingwave_pb::user::AuthInfo; @@ -223,7 +233,8 @@ mod tests { user_info.auth_info, Some(AuthInfo { encryption_type: EncryptionType::Md5 as i32, - encrypted_value: b"9f2fa6a30871a92249bdd2f1eeee4ef6".to_vec() + encrypted_value: b"9f2fa6a30871a92249bdd2f1eeee4ef6".to_vec(), + meta_data: HashMap::new(), }) ); } diff --git a/src/frontend/src/handler/create_user.rs b/src/frontend/src/handler/create_user.rs index 429a35754d1aa..6022693c5cc36 100644 --- a/src/frontend/src/handler/create_user.rs +++ b/src/frontend/src/handler/create_user.rs @@ -20,10 +20,12 @@ use risingwave_sqlparser::ast::{CreateUserStatement, UserOption, UserOptions}; use super::RwPgResponse; use crate::binder::Binder; use crate::catalog::{CatalogError, DatabaseId}; -use crate::error::ErrorCode::PermissionDenied; +use crate::error::ErrorCode::{self, PermissionDenied}; use crate::error::Result; use crate::handler::HandlerArgs; -use crate::user::user_authentication::{build_oauth_info, encrypted_password}; +use crate::user::user_authentication::{ + build_oauth_info, encrypted_password, OAUTH_ISSUER_KEY, OAUTH_JWKS_URL_KEY, +}; use crate::user::user_catalog::UserCatalog; fn make_prost_user_info( @@ -91,7 +93,15 @@ fn make_prost_user_info( user_info.auth_info = encrypted_password(&user_info.name, &password.0); } } - UserOption::OAuth => user_info.auth_info = build_oauth_info(), + UserOption::OAuth(options) => { + let auth_info = build_oauth_info(options).ok_or_else(|| { + ErrorCode::InvalidParameterValue(format!( + "{} and {} must be provided", + OAUTH_JWKS_URL_KEY, OAUTH_ISSUER_KEY + )) + })?; + user_info.auth_info = Some(auth_info); + } } } @@ -131,6 +141,8 @@ pub async fn handle_create_user( #[cfg(test)] mod tests { + use std::collections::HashMap; + use risingwave_common::catalog::DEFAULT_DATABASE_NAME; use risingwave_pb::user::auth_info::EncryptionType; use risingwave_pb::user::AuthInfo; @@ -158,7 +170,8 @@ mod tests { user_info.auth_info, Some(AuthInfo { encryption_type: EncryptionType::Md5 as i32, - encrypted_value: b"827ccb0eea8a706c4c34a16891f84e7b".to_vec() + encrypted_value: b"827ccb0eea8a706c4c34a16891f84e7b".to_vec(), + meta_data: HashMap::new(), }) ); frontend diff --git a/src/frontend/src/session.rs b/src/frontend/src/session.rs index 17e4f7ef09e0c..3ffd8dc7a6f6a 100644 --- a/src/frontend/src/session.rs +++ b/src/frontend/src/session.rs @@ -46,7 +46,6 @@ use risingwave_common::session_config::{ConfigMap, ConfigReporter, VisibilityMod use risingwave_common::system_param::local_manager::{ LocalSystemParamsManager, LocalSystemParamsManagerRef, }; -use risingwave_common::system_param::reader::SystemParamsRead; use risingwave_common::telemetry::manager::TelemetryManager; use risingwave_common::telemetry::telemetry_env_enabled; use risingwave_common::types::DataType; @@ -978,20 +977,7 @@ impl SessionManager for SessionManagerImpl { salt, } } else if auth_info.encryption_type == EncryptionType::Oauth as i32 { - let oauth_jwks_url = self - .env - .system_params_manager - .get_params() - .load() - .oauth_jwks_url() - .to_string(); - if oauth_jwks_url.is_empty() { - return Err(Box::new(Error::new( - ErrorKind::PermissionDenied, - "OAuth JWKS URL is not set", - ))); - } - UserAuthenticator::OAuth(oauth_jwks_url) + UserAuthenticator::OAuth(auth_info.meta_data.clone()) } else { return Err(Box::new(Error::new( ErrorKind::Unsupported, diff --git a/src/frontend/src/user/user_authentication.rs b/src/frontend/src/user/user_authentication.rs index d558fb03ee3b6..c1f3e570878c5 100644 --- a/src/frontend/src/user/user_authentication.rs +++ b/src/frontend/src/user/user_authentication.rs @@ -12,8 +12,11 @@ // See the License for the specific language governing permissions and // limitations under the License. +use std::collections::HashMap; + use risingwave_pb::user::auth_info::EncryptionType; use risingwave_pb::user::AuthInfo; +use risingwave_sqlparser::ast::SqlOption; use sha2::{Digest, Sha256}; // SHA-256 is not supported in PostgreSQL protocol. We need to implement SCRAM-SHA-256 instead @@ -24,12 +27,23 @@ const MD5_ENCRYPTED_PREFIX: &str = "md5"; const VALID_SHA256_ENCRYPTED_LEN: usize = SHA256_ENCRYPTED_PREFIX.len() + 64; const VALID_MD5_ENCRYPTED_LEN: usize = MD5_ENCRYPTED_PREFIX.len() + 32; +pub const OAUTH_JWKS_URL_KEY: &str = "jwks_url"; +pub const OAUTH_ISSUER_KEY: &str = "issuer"; + /// Build `AuthInfo` for `OAuth`. #[inline(always)] -pub fn build_oauth_info() -> Option { +pub fn build_oauth_info(options: &Vec) -> Option { + let meta_data: HashMap = options + .iter() + .map(|opt| (opt.name.real_value(), opt.value.to_string())) + .collect(); + if !meta_data.contains_key(OAUTH_JWKS_URL_KEY) || !meta_data.contains_key(OAUTH_ISSUER_KEY) { + return None; + } Some(AuthInfo { encryption_type: EncryptionType::Oauth as i32, encrypted_value: Vec::new(), + meta_data, }) } @@ -62,11 +76,13 @@ pub fn encrypted_password(name: &str, password: &str) -> Option { Some(AuthInfo { encryption_type: EncryptionType::Sha256 as i32, encrypted_value: password.trim_start_matches(SHA256_ENCRYPTED_PREFIX).into(), + meta_data: HashMap::new(), }) } else if valid_md5_password(password) { Some(AuthInfo { encryption_type: EncryptionType::Md5 as i32, encrypted_value: password.trim_start_matches(MD5_ENCRYPTED_PREFIX).into(), + meta_data: HashMap::new(), }) } else { Some(encrypt_default(name, password)) @@ -79,6 +95,7 @@ fn encrypt_default(name: &str, password: &str) -> AuthInfo { AuthInfo { encryption_type: EncryptionType::Md5 as i32, encrypted_value: md5_hash(name, password), + meta_data: HashMap::new(), } } @@ -166,15 +183,18 @@ mod tests { Some(AuthInfo { encryption_type: EncryptionType::Md5 as i32, encrypted_value: md5_hash(user_name, password), + meta_data: HashMap::new(), }), None, Some(AuthInfo { encryption_type: EncryptionType::Md5 as i32, encrypted_value: md5_hash(user_name, password), + meta_data: HashMap::new(), }), Some(AuthInfo { encryption_type: EncryptionType::Sha256 as i32, encrypted_value: sha256_hash(user_name, password), + meta_data: HashMap::new(), }), ]; let output_passwords = input_passwords diff --git a/src/sqlparser/src/ast/statement.rs b/src/sqlparser/src/ast/statement.rs index d5def93c81fef..6edc0702a425d 100644 --- a/src/sqlparser/src/ast/statement.rs +++ b/src/sqlparser/src/ast/statement.rs @@ -741,7 +741,7 @@ pub enum UserOption { NoLogin, EncryptedPassword(AstString), Password(Option), - OAuth, + OAuth(Vec), } impl fmt::Display for UserOption { @@ -758,7 +758,9 @@ impl fmt::Display for UserOption { UserOption::EncryptedPassword(p) => write!(f, "ENCRYPTED PASSWORD {}", p), UserOption::Password(None) => write!(f, "PASSWORD NULL"), UserOption::Password(Some(p)) => write!(f, "PASSWORD {}", p), - UserOption::OAuth => write!(f, "OAUTH"), + UserOption::OAuth(options) => { + write!(f, "({})", display_comma_separated(options.as_slice())) + } } } } @@ -846,7 +848,10 @@ impl ParseTo for UserOptions { UserOption::EncryptedPassword(AstString::parse_to(parser)?), ) } - Keyword::OAUTH => (&mut builder.password, UserOption::OAuth), + Keyword::OAUTH => { + let options = parser.parse_options()?; + (&mut builder.password, UserOption::OAuth(options)) + } _ => { parser.expected( "SUPERUSER | NOSUPERUSER | CREATEDB | NOCREATEDB | LOGIN \ diff --git a/src/utils/pgwire/src/pg_server.rs b/src/utils/pgwire/src/pg_server.rs index eb9c31442b85c..5fef18a61bff9 100644 --- a/src/utils/pgwire/src/pg_server.rs +++ b/src/utils/pgwire/src/pg_server.rs @@ -159,7 +159,7 @@ pub enum UserAuthenticator { encrypted_password: Vec, salt: [u8; 4], }, - OAuth(String), + OAuth(HashMap), } #[derive(Debug, Deserialize)] @@ -181,7 +181,11 @@ async fn fetch_jwks(url: &str) -> Result { Ok(resp) } -async fn validate_jwt(jwt: &str, jwks_url: &str) -> Result { +async fn validate_jwt( + jwt: &str, + jwks_url: &str, + meta_data: &HashMap, +) -> Result { let header = decode_header(jwt)?; let jwks = fetch_jwks(jwks_url).await?; @@ -194,8 +198,11 @@ async fn validate_jwt(jwt: &str, jwks_url: &str) -> Result { let decoding_key = DecodingKey::from_rsa_components(&jwk.n, &jwk.e)?; let validation = Validation::new(Algorithm::from_str(&jwk.alg)?); + let token_data = decode::>(jwt, &decoding_key, &validation)?; - Ok(decode::>(jwt, &decoding_key, &validation).is_ok()) + Ok(meta_data + .iter() + .all(|(k, v)| token_data.claims.get(k) == Some(v))) } impl UserAuthenticator { @@ -206,8 +213,10 @@ impl UserAuthenticator { UserAuthenticator::Md5WithSalt { encrypted_password, .. } => encrypted_password == password, - UserAuthenticator::OAuth(oauth_jwks_url) => { - validate_jwt(&String::from_utf8_lossy(password), oauth_jwks_url) + UserAuthenticator::OAuth(meta_data) => { + let mut meta_data = meta_data.clone(); + let jwks_url = meta_data.remove("jwks_url").unwrap(); + validate_jwt(&String::from_utf8_lossy(password), &jwks_url, &meta_data) .await .map_err(PsqlError::StartupError)? } From 94d4462e027cad68c3b30ec6a235758a3b8e5398 Mon Sep 17 00:00:00 2001 From: Rossil <40714231+Rossil2012@users.noreply.github.com> Date: Thu, 29 Feb 2024 15:47:51 +0800 Subject: [PATCH 11/12] minor fix --- Cargo.lock | 1 + src/frontend/src/user/user_authentication.rs | 9 ++++-- src/utils/pgwire/Cargo.toml | 1 + src/utils/pgwire/src/pg_server.rs | 30 +++++++++++++------- 4 files changed, 27 insertions(+), 14 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 8582bd0e76ec4..ec832a85e95cc 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -7358,6 +7358,7 @@ dependencies = [ "risingwave_common", "risingwave_sqlparser", "serde", + "serde_json", "tempfile", "thiserror", "thiserror-ext", diff --git a/src/frontend/src/user/user_authentication.rs b/src/frontend/src/user/user_authentication.rs index c1f3e570878c5..10dea11c4e13c 100644 --- a/src/frontend/src/user/user_authentication.rs +++ b/src/frontend/src/user/user_authentication.rs @@ -19,6 +19,8 @@ use risingwave_pb::user::AuthInfo; use risingwave_sqlparser::ast::SqlOption; use sha2::{Digest, Sha256}; +use crate::WithOptions; + // SHA-256 is not supported in PostgreSQL protocol. We need to implement SCRAM-SHA-256 instead // if necessary. const SHA256_ENCRYPTED_PREFIX: &str = "SHA-256:"; @@ -33,9 +35,10 @@ pub const OAUTH_ISSUER_KEY: &str = "issuer"; /// Build `AuthInfo` for `OAuth`. #[inline(always)] pub fn build_oauth_info(options: &Vec) -> Option { - let meta_data: HashMap = options - .iter() - .map(|opt| (opt.name.real_value(), opt.value.to_string())) + let meta_data: HashMap = WithOptions::try_from(options.as_slice()) + .ok()? + .into_inner() + .into_iter() .collect(); if !meta_data.contains_key(OAUTH_JWKS_URL_KEY) || !meta_data.contains_key(OAUTH_ISSUER_KEY) { return None; diff --git a/src/utils/pgwire/Cargo.toml b/src/utils/pgwire/Cargo.toml index 907b2ac5904db..47840b0cf4983 100644 --- a/src/utils/pgwire/Cargo.toml +++ b/src/utils/pgwire/Cargo.toml @@ -29,6 +29,7 @@ reqwest = { version = "0.11" } risingwave_common = { workspace = true } risingwave_sqlparser = { workspace = true } serde = { version = "1", features = ["derive"] } +serde_json = "1" thiserror = "1" thiserror-ext = { workspace = true } tokio = { version = "0.2", package = "madsim-tokio", features = ["rt", "macros"] } diff --git a/src/utils/pgwire/src/pg_server.rs b/src/utils/pgwire/src/pg_server.rs index 5fef18a61bff9..5f61d6d5ab6e9 100644 --- a/src/utils/pgwire/src/pg_server.rs +++ b/src/utils/pgwire/src/pg_server.rs @@ -16,12 +16,11 @@ use std::collections::HashMap; use std::future::Future; use std::io; use std::result::Result; -use std::str::FromStr; use std::sync::Arc; use std::time::Instant; use bytes::Bytes; -use jsonwebtoken::{decode, decode_header, Algorithm, DecodingKey, Validation}; +use jsonwebtoken::{decode, decode_header, DecodingKey, Validation}; use parking_lot::Mutex; use risingwave_common::types::DataType; use risingwave_sqlparser::ast::Statement; @@ -184,6 +183,7 @@ async fn fetch_jwks(url: &str) -> Result { async fn validate_jwt( jwt: &str, jwks_url: &str, + issuer: &str, meta_data: &HashMap, ) -> Result { let header = decode_header(jwt)?; @@ -197,12 +197,14 @@ async fn validate_jwt( .ok_or("kid not found in jwks")?; let decoding_key = DecodingKey::from_rsa_components(&jwk.n, &jwk.e)?; - let validation = Validation::new(Algorithm::from_str(&jwk.alg)?); - let token_data = decode::>(jwt, &decoding_key, &validation)?; - - Ok(meta_data - .iter() - .all(|(k, v)| token_data.claims.get(k) == Some(v))) + let mut validation = Validation::new(header.alg); + validation.set_issuer(&[issuer]); + validation.set_required_spec_claims(&["exp", "iss"]); + let token_data = decode::>(jwt, &decoding_key, &validation)?; + + Ok(meta_data.iter().all( + |(k, v)| matches!(token_data.claims.get(k), Some(serde_json::Value::String(s)) if s == v), + )) } impl UserAuthenticator { @@ -216,9 +218,15 @@ impl UserAuthenticator { UserAuthenticator::OAuth(meta_data) => { let mut meta_data = meta_data.clone(); let jwks_url = meta_data.remove("jwks_url").unwrap(); - validate_jwt(&String::from_utf8_lossy(password), &jwks_url, &meta_data) - .await - .map_err(PsqlError::StartupError)? + let issuer = meta_data.remove("issuer").unwrap(); + validate_jwt( + &String::from_utf8_lossy(password), + &jwks_url, + &issuer, + &meta_data, + ) + .await + .map_err(PsqlError::StartupError)? } }; if !success { From b5bc3173987e199e231870d222ce59561e6722f8 Mon Sep 17 00:00:00 2001 From: Rossil <40714231+Rossil2012@users.noreply.github.com> Date: Fri, 1 Mar 2024 20:10:37 +0800 Subject: [PATCH 12/12] fix comments --- proto/user.proto | 2 +- src/frontend/src/handler/alter_user.rs | 2 +- src/frontend/src/handler/create_user.rs | 2 +- src/frontend/src/session.rs | 2 +- src/frontend/src/user/user_authentication.rs | 18 +++---- src/storage/src/hummock/sstable_store.rs | 4 +- src/utils/pgwire/src/pg_server.rs | 51 ++++++++++++-------- 7 files changed, 46 insertions(+), 35 deletions(-) diff --git a/proto/user.proto b/proto/user.proto index 0ebb1cb30649b..014a8d0c1b0d3 100644 --- a/proto/user.proto +++ b/proto/user.proto @@ -18,7 +18,7 @@ message AuthInfo { } EncryptionType encryption_type = 1; bytes encrypted_value = 2; - map meta_data = 3; + map metadata = 3; } // User defines a user in the system. diff --git a/src/frontend/src/handler/alter_user.rs b/src/frontend/src/handler/alter_user.rs index 05dffaa57ef0d..431a217a20cf3 100644 --- a/src/frontend/src/handler/alter_user.rs +++ b/src/frontend/src/handler/alter_user.rs @@ -234,7 +234,7 @@ mod tests { Some(AuthInfo { encryption_type: EncryptionType::Md5 as i32, encrypted_value: b"9f2fa6a30871a92249bdd2f1eeee4ef6".to_vec(), - meta_data: HashMap::new(), + metadata: HashMap::new(), }) ); } diff --git a/src/frontend/src/handler/create_user.rs b/src/frontend/src/handler/create_user.rs index 6022693c5cc36..bfdc33e6db80f 100644 --- a/src/frontend/src/handler/create_user.rs +++ b/src/frontend/src/handler/create_user.rs @@ -171,7 +171,7 @@ mod tests { Some(AuthInfo { encryption_type: EncryptionType::Md5 as i32, encrypted_value: b"827ccb0eea8a706c4c34a16891f84e7b".to_vec(), - meta_data: HashMap::new(), + metadata: HashMap::new(), }) ); frontend diff --git a/src/frontend/src/session.rs b/src/frontend/src/session.rs index 3ffd8dc7a6f6a..30d1b02df7c03 100644 --- a/src/frontend/src/session.rs +++ b/src/frontend/src/session.rs @@ -977,7 +977,7 @@ impl SessionManager for SessionManagerImpl { salt, } } else if auth_info.encryption_type == EncryptionType::Oauth as i32 { - UserAuthenticator::OAuth(auth_info.meta_data.clone()) + UserAuthenticator::OAuth(auth_info.metadata.clone()) } else { return Err(Box::new(Error::new( ErrorKind::Unsupported, diff --git a/src/frontend/src/user/user_authentication.rs b/src/frontend/src/user/user_authentication.rs index 10dea11c4e13c..b0cefc1faedcb 100644 --- a/src/frontend/src/user/user_authentication.rs +++ b/src/frontend/src/user/user_authentication.rs @@ -35,18 +35,18 @@ pub const OAUTH_ISSUER_KEY: &str = "issuer"; /// Build `AuthInfo` for `OAuth`. #[inline(always)] pub fn build_oauth_info(options: &Vec) -> Option { - let meta_data: HashMap = WithOptions::try_from(options.as_slice()) + let metadata: HashMap = WithOptions::try_from(options.as_slice()) .ok()? .into_inner() .into_iter() .collect(); - if !meta_data.contains_key(OAUTH_JWKS_URL_KEY) || !meta_data.contains_key(OAUTH_ISSUER_KEY) { + if !metadata.contains_key(OAUTH_JWKS_URL_KEY) || !metadata.contains_key(OAUTH_ISSUER_KEY) { return None; } Some(AuthInfo { encryption_type: EncryptionType::Oauth as i32, encrypted_value: Vec::new(), - meta_data, + metadata, }) } @@ -79,13 +79,13 @@ pub fn encrypted_password(name: &str, password: &str) -> Option { Some(AuthInfo { encryption_type: EncryptionType::Sha256 as i32, encrypted_value: password.trim_start_matches(SHA256_ENCRYPTED_PREFIX).into(), - meta_data: HashMap::new(), + metadata: HashMap::new(), }) } else if valid_md5_password(password) { Some(AuthInfo { encryption_type: EncryptionType::Md5 as i32, encrypted_value: password.trim_start_matches(MD5_ENCRYPTED_PREFIX).into(), - meta_data: HashMap::new(), + metadata: HashMap::new(), }) } else { Some(encrypt_default(name, password)) @@ -98,7 +98,7 @@ fn encrypt_default(name: &str, password: &str) -> AuthInfo { AuthInfo { encryption_type: EncryptionType::Md5 as i32, encrypted_value: md5_hash(name, password), - meta_data: HashMap::new(), + metadata: HashMap::new(), } } @@ -186,18 +186,18 @@ mod tests { Some(AuthInfo { encryption_type: EncryptionType::Md5 as i32, encrypted_value: md5_hash(user_name, password), - meta_data: HashMap::new(), + metadata: HashMap::new(), }), None, Some(AuthInfo { encryption_type: EncryptionType::Md5 as i32, encrypted_value: md5_hash(user_name, password), - meta_data: HashMap::new(), + metadata: HashMap::new(), }), Some(AuthInfo { encryption_type: EncryptionType::Sha256 as i32, encrypted_value: sha256_hash(user_name, password), - meta_data: HashMap::new(), + metadata: HashMap::new(), }), ]; let output_passwords = input_passwords diff --git a/src/storage/src/hummock/sstable_store.rs b/src/storage/src/hummock/sstable_store.rs index c603b7d8f503a..f0cacf863fcc9 100644 --- a/src/storage/src/hummock/sstable_store.rs +++ b/src/storage/src/hummock/sstable_store.rs @@ -1020,9 +1020,9 @@ impl SstableWriter for StreamingUploadWriter { } async fn finish(mut self, meta: SstableMeta) -> HummockResult { - let meta_data = Bytes::from(meta.encode_to_bytes()); + let metadata = Bytes::from(meta.encode_to_bytes()); - self.object_uploader.write_bytes(meta_data).await?; + self.object_uploader.write_bytes(metadata).await?; let join_handle = tokio::spawn(async move { let uploader_memory_usage = self.object_uploader.get_memory_usage(); let _tracker = self.tracker.map(|mut t| { diff --git a/src/utils/pgwire/src/pg_server.rs b/src/utils/pgwire/src/pg_server.rs index 5f61d6d5ab6e9..7f6dd41368d45 100644 --- a/src/utils/pgwire/src/pg_server.rs +++ b/src/utils/pgwire/src/pg_server.rs @@ -16,11 +16,12 @@ use std::collections::HashMap; use std::future::Future; use std::io; use std::result::Result; +use std::str::FromStr; use std::sync::Arc; use std::time::Instant; use bytes::Bytes; -use jsonwebtoken::{decode, decode_header, DecodingKey, Validation}; +use jsonwebtoken::{decode, decode_header, Algorithm, DecodingKey, Validation}; use parking_lot::Mutex; use risingwave_common::types::DataType; use risingwave_sqlparser::ast::Statement; @@ -161,34 +162,34 @@ pub enum UserAuthenticator { OAuth(HashMap), } +/// A JWK Set is a JSON object that represents a set of JWKs. +/// The JSON object MUST have a "keys" member, with its value being an array of JWKs. +/// See for more details. #[derive(Debug, Deserialize)] struct Jwks { keys: Vec, } -#[allow(dead_code)] +/// A JSON Web Key (JWK) is a JSON object that represents a cryptographic key. +/// See for more details. #[derive(Debug, Deserialize)] struct Jwk { - kid: String, - alg: String, - n: String, - e: String, -} - -async fn fetch_jwks(url: &str) -> Result { - let resp: Jwks = reqwest::get(url).await?.json().await?; - Ok(resp) + kid: String, // Key ID + alg: String, // Algorithm + n: String, // Modulus + e: String, // Exponent } async fn validate_jwt( jwt: &str, jwks_url: &str, issuer: &str, - meta_data: &HashMap, + metadata: &HashMap, ) -> Result { let header = decode_header(jwt)?; - let jwks = fetch_jwks(jwks_url).await?; + let jwks: Jwks = reqwest::get(jwks_url).await?.json().await?; + // 1. Retrieve the kid from the header to find the right JWK in the JWK Set. let kid = header.kid.ok_or("kid not found in jwt header")?; let jwk = jwks .keys @@ -196,15 +197,25 @@ async fn validate_jwt( .find(|k| k.kid == kid) .ok_or("kid not found in jwks")?; + // 2. Check if the algorithms are matched. + if Algorithm::from_str(&jwk.alg)? != header.alg { + return Err("alg in jwt header does not match with alg in jwk".into()); + } + + // 3. Decode the JWT and validate the claims. let decoding_key = DecodingKey::from_rsa_components(&jwk.n, &jwk.e)?; let mut validation = Validation::new(header.alg); validation.set_issuer(&[issuer]); validation.set_required_spec_claims(&["exp", "iss"]); let token_data = decode::>(jwt, &decoding_key, &validation)?; - Ok(meta_data.iter().all( + // 4. Check if the metadata in the token matches. + if !metadata.iter().all( |(k, v)| matches!(token_data.claims.get(k), Some(serde_json::Value::String(s)) if s == v), - )) + ) { + return Err("metadata in jwt does not match with metadata declared with user".into()); + } + Ok(true) } impl UserAuthenticator { @@ -215,15 +226,15 @@ impl UserAuthenticator { UserAuthenticator::Md5WithSalt { encrypted_password, .. } => encrypted_password == password, - UserAuthenticator::OAuth(meta_data) => { - let mut meta_data = meta_data.clone(); - let jwks_url = meta_data.remove("jwks_url").unwrap(); - let issuer = meta_data.remove("issuer").unwrap(); + UserAuthenticator::OAuth(metadata) => { + let mut metadata = metadata.clone(); + let jwks_url = metadata.remove("jwks_url").unwrap(); + let issuer = metadata.remove("issuer").unwrap(); validate_jwt( &String::from_utf8_lossy(password), &jwks_url, &issuer, - &meta_data, + &metadata, ) .await .map_err(PsqlError::StartupError)?