diff --git a/src/query/service/src/interpreters/interpreter_role_set.rs b/src/query/service/src/interpreters/interpreter_role_set.rs index 9ce4145ba722b..cf11b4f6b3438 100644 --- a/src/query/service/src/interpreters/interpreter_role_set.rs +++ b/src/query/service/src/interpreters/interpreter_role_set.rs @@ -44,40 +44,23 @@ impl Interpreter for SetRoleInterpreter { #[tracing::instrument(level = "debug", skip(self), fields(ctx.id = self.ctx.get_id().as_str()))] async fn execute2(&self) -> Result { let session = self.ctx.get_current_session(); - let current_user = session.get_current_user()?; - let available_roles = session.get_all_available_roles().await?; - let role = available_roles - .iter() - .find(|r| r.name == self.plan.role_name); - match role { - None => { - let available_role_names = available_roles - .iter() - .map(|r| r.name.clone()) - .collect::>() - .join(","); - return Err(common_exception::ErrorCode::InvalidRole(format!( - "Invalid role ({}) for {}, available: {}", - self.plan.role_name, + let role = session + .validate_available_role(&self.plan.role_name) + .await?; + if self.plan.is_default { + let current_user = self.ctx.get_current_user()?; + UserApiProvider::instance() + .update_user_default_role( + &self.ctx.get_tenant(), current_user.identity(), - available_role_names, - ))); - } - Some(role) => { - if self.plan.is_default { - let current_user = self.ctx.get_current_user()?; - UserApiProvider::instance() - .update_user_default_role( - &self.ctx.get_tenant(), - current_user.identity(), - Some(role.name.clone()), - ) - .await?; - } else { - session.set_current_role(Some(role.clone())); - } - } + Some(role.name.clone()), + ) + .await?; + } else { + session + .set_current_role_checked(&self.plan.role_name) + .await?; } Ok(PipelineBuildResult::create()) } diff --git a/src/query/service/src/sessions/session.rs b/src/query/service/src/sessions/session.rs index 6866dcee3d441..ecf9fe4c78b88 100644 --- a/src/query/service/src/sessions/session.rs +++ b/src/query/service/src/sessions/session.rs @@ -266,8 +266,31 @@ impl Session { Ok(()) } - pub fn set_current_role(self: &Arc, role: Option) { - self.session_ctx.set_current_role(role); + pub async fn validate_available_role(self: &Arc, role_name: &str) -> Result { + let available_roles = self.get_all_available_roles().await?; + let role = available_roles.iter().find(|r| r.name == role_name); + match role { + Some(role) => Ok(role.clone()), + None => { + let available_role_names = available_roles + .iter() + .map(|r| r.name.clone()) + .collect::>() + .join(","); + return Err(ErrorCode::InvalidRole(format!( + "Invalid role {} for current session, available: {}", + role_name, available_role_names, + ))); + } + } + } + + // Only the available role can be set as current role. The current role can be set by the SET + // ROLE statement, or by the X-DATABEND-ROLE header in HTTP protocol (not implemented yet). + pub async fn set_current_role_checked(self: &Arc, role_name: &str) -> Result<()> { + let role = self.validate_available_role(role_name).await?; + self.session_ctx.set_current_role(Some(role)); + Ok(()) } pub fn get_current_role(self: &Arc) -> Option {