Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(users): Decision manager flow changes for SSO #4995

Merged
merged 10 commits into from
Jun 24, 2024
4 changes: 4 additions & 0 deletions crates/common_enums/src/enums.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2758,6 +2758,10 @@ pub enum BankHolderType {
#[strum(serialize_all = "snake_case")]
#[serde(rename_all = "snake_case")]
pub enum TokenPurpose {
AuthSelect,
#[serde(rename = "sso")]
#[strum(serialize = "sso")]
SSO,
#[serde(rename = "totp")]
#[strum(serialize = "totp")]
TOTP,
Expand Down
7 changes: 3 additions & 4 deletions crates/router/src/core/user.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1048,7 +1048,7 @@ pub async fn accept_invite_from_email_token_only_flow(
.map_err(|e| logger::error!(?e));

let current_flow = domain::CurrentFlow::new(
user_token.origin,
user_token,
domain::SPTFlow::AcceptInvitationFromEmail.into(),
)?;
let next_flow = current_flow.next(user_from_db.clone(), &state).await?;
Expand Down Expand Up @@ -1502,8 +1502,7 @@ pub async fn verify_email_token_only_flow(
.await
.map_err(|e| logger::error!(?e));

let current_flow =
domain::CurrentFlow::new(user_token.origin, domain::SPTFlow::VerifyEmail.into())?;
let current_flow = domain::CurrentFlow::new(user_token, domain::SPTFlow::VerifyEmail.into())?;
let next_flow = current_flow.next(user_from_db, &state).await?;
let token = next_flow.get_token(&state).await?;

Expand Down Expand Up @@ -1959,7 +1958,7 @@ pub async fn terminate_two_factor_auth(
}
}

let current_flow = domain::CurrentFlow::new(user_token.origin, domain::SPTFlow::TOTP.into())?;
let current_flow = domain::CurrentFlow::new(user_token, domain::SPTFlow::TOTP.into())?;
let next_flow = current_flow.next(user_from_db, &state).await?;
let token = next_flow.get_token(&state).await?;

Expand Down
2 changes: 1 addition & 1 deletion crates/router/src/core/user_role.rs
Original file line number Diff line number Diff line change
Expand Up @@ -288,7 +288,7 @@ pub async fn merchant_select_token_only_flow(
.into();

let current_flow =
domain::CurrentFlow::new(user_token.origin, domain::SPTFlow::MerchantSelect.into())?;
domain::CurrentFlow::new(user_token, domain::SPTFlow::MerchantSelect.into())?;
let next_flow = current_flow.next(user_from_db.clone(), &state).await?;

let token = next_flow
Expand Down
5 changes: 5 additions & 0 deletions crates/router/src/services/authentication.rs
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,7 @@ impl AuthenticationType {
pub struct UserFromSinglePurposeToken {
pub user_id: String,
pub origin: domain::Origin,
pub path: Vec<TokenPurpose>,
}

#[cfg(feature = "olap")]
Expand All @@ -132,6 +133,7 @@ pub struct SinglePurposeToken {
pub user_id: String,
pub purpose: TokenPurpose,
pub origin: domain::Origin,
pub path: Vec<TokenPurpose>,
pub exp: u64,
}

Expand All @@ -142,6 +144,7 @@ impl SinglePurposeToken {
purpose: TokenPurpose,
origin: domain::Origin,
settings: &Settings,
path: Vec<TokenPurpose>,
) -> UserResult<String> {
let exp_duration =
std::time::Duration::from_secs(consts::SINGLE_PURPOSE_TOKEN_TIME_IN_SECS);
Expand All @@ -151,6 +154,7 @@ impl SinglePurposeToken {
purpose,
origin,
exp,
path,
};
jwt::generate_jwt(&token_payload, settings).await
}
Expand Down Expand Up @@ -356,6 +360,7 @@ where
UserFromSinglePurposeToken {
user_id: payload.user_id.clone(),
origin: payload.origin.clone(),
path: payload.path,
},
AuthenticationType::SinglePurposeJwt {
user_id: payload.user_id,
Expand Down
1 change: 1 addition & 0 deletions crates/router/src/types/domain/user.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1107,6 +1107,7 @@ impl SignInWithMultipleRolesStrategy {
TokenPurpose::AcceptInvite,
Origin::SignIn,
&state.conf,
vec![],
)
.await?
.into(),
Expand Down
67 changes: 55 additions & 12 deletions crates/router/src/types/domain/user/decision_manager.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,16 +17,23 @@ pub enum UserFlow {
}

impl UserFlow {
async fn is_required(&self, user: &UserFromStorage, state: &SessionState) -> UserResult<bool> {
async fn is_required(
&self,
user: &UserFromStorage,
path: &[TokenPurpose],
state: &SessionState,
) -> UserResult<bool> {
match self {
Self::SPTFlow(flow) => flow.is_required(user, state).await,
Self::SPTFlow(flow) => flow.is_required(user, path, state).await,
Self::JWTFlow(flow) => flow.is_required(user, state).await,
}
}
}

#[derive(Eq, PartialEq, Clone, Copy)]
pub enum SPTFlow {
AuthSelect,
SSO,
TOTP,
VerifyEmail,
AcceptInvitationFromEmail,
Expand All @@ -36,15 +43,26 @@ pub enum SPTFlow {
}

impl SPTFlow {
async fn is_required(&self, user: &UserFromStorage, state: &SessionState) -> UserResult<bool> {
async fn is_required(
&self,
user: &UserFromStorage,
path: &[TokenPurpose],
state: &SessionState,
) -> UserResult<bool> {
match self {
// Auth
// AuthSelect and SSO flow are not enabled, once the terminate SSO API is ready, we can enable these flows
Self::AuthSelect => Ok(false),
Self::SSO => Ok(false),
// TOTP
Self::TOTP => Ok(true),
Self::TOTP => Ok(!path.contains(&TokenPurpose::SSO)),
// Main email APIs
Self::AcceptInvitationFromEmail | Self::ResetPassword => Ok(true),
Self::VerifyEmail => Ok(true),
// Final Checks
Self::ForceSetPassword => user.is_password_rotate_required(state),
Self::ForceSetPassword => user
.is_password_rotate_required(state)
.map(|rotate_required| rotate_required && !path.contains(&TokenPurpose::SSO)),
Self::MerchantSelect => user
.get_roles_from_db(state)
.await
Expand All @@ -62,6 +80,7 @@ impl SPTFlow {
self.into(),
next_flow.origin.clone(),
&state.conf,
next_flow.path.to_vec(),
)
.await
.map(|token| token.into())
Expand Down Expand Up @@ -103,6 +122,8 @@ impl JWTFlow {
#[derive(serde::Serialize, serde::Deserialize, Clone, Debug)]
#[serde(rename_all = "snake_case")]
pub enum Origin {
#[serde(rename = "sign_in_with_sso")]
SignInWithSSO,
SignIn,
SignUp,
MagicLink,
Expand All @@ -114,6 +135,7 @@ pub enum Origin {
impl Origin {
fn get_flows(&self) -> &'static [UserFlow] {
match self {
Self::SignInWithSSO => &SIGNIN_WITH_SSO_FLOW,
Self::SignIn => &SIGNIN_FLOW,
Self::SignUp => &SIGNUP_FLOW,
Self::VerifyEmail => &VERIFY_EMAIL_FLOW,
Expand All @@ -124,6 +146,11 @@ impl Origin {
}
}

const SIGNIN_WITH_SSO_FLOW: [UserFlow; 2] = [
UserFlow::SPTFlow(SPTFlow::MerchantSelect),
UserFlow::JWTFlow(JWTFlow::UserInfo),
];

const SIGNIN_FLOW: [UserFlow; 4] = [
UserFlow::SPTFlow(SPTFlow::TOTP),
UserFlow::SPTFlow(SPTFlow::ForceSetPassword),
Expand Down Expand Up @@ -154,7 +181,9 @@ const VERIFY_EMAIL_FLOW: [UserFlow; 5] = [
UserFlow::JWTFlow(JWTFlow::UserInfo),
];

const ACCEPT_INVITATION_FROM_EMAIL_FLOW: [UserFlow; 4] = [
const ACCEPT_INVITATION_FROM_EMAIL_FLOW: [UserFlow; 6] = [
UserFlow::SPTFlow(SPTFlow::AuthSelect),
UserFlow::SPTFlow(SPTFlow::SSO),
UserFlow::SPTFlow(SPTFlow::TOTP),
UserFlow::SPTFlow(SPTFlow::AcceptInvitationFromEmail),
UserFlow::SPTFlow(SPTFlow::ForceSetPassword),
Expand All @@ -169,31 +198,40 @@ const RESET_PASSWORD_FLOW: [UserFlow; 2] = [
pub struct CurrentFlow {
origin: Origin,
current_flow_index: usize,
path: Vec<TokenPurpose>,
}

impl CurrentFlow {
pub fn new(origin: Origin, current_flow: UserFlow) -> UserResult<Self> {
let flows = origin.get_flows();
pub fn new(
token: auth::UserFromSinglePurposeToken,
current_flow: UserFlow,
) -> UserResult<Self> {
let flows = token.origin.get_flows();
let index = flows
.iter()
.position(|flow| flow == &current_flow)
.ok_or(UserErrors::InternalServerError)?;
let mut path = token.path;
path.push(current_flow.into());

Ok(Self {
origin,
origin: token.origin,
current_flow_index: index,
path,
})
}

pub async fn next(&self, user: UserFromStorage, state: &SessionState) -> UserResult<NextFlow> {
pub async fn next(self, user: UserFromStorage, state: &SessionState) -> UserResult<NextFlow> {
let flows = self.origin.get_flows();
let remaining_flows = flows.iter().skip(self.current_flow_index + 1);

for flow in remaining_flows {
if flow.is_required(&user, state).await? {
if flow.is_required(&user, &self.path, state).await? {
return Ok(NextFlow {
origin: self.origin.clone(),
next_flow: *flow,
user,
path: self.path,
});
}
}
Expand All @@ -205,6 +243,7 @@ pub struct NextFlow {
origin: Origin,
next_flow: UserFlow,
user: UserFromStorage,
path: Vec<TokenPurpose>,
}

impl NextFlow {
Expand All @@ -214,12 +253,14 @@ impl NextFlow {
state: &SessionState,
) -> UserResult<Self> {
let flows = origin.get_flows();
let path = vec![];
for flow in flows {
if flow.is_required(&user, state).await? {
if flow.is_required(&user, &path, state).await? {
return Ok(Self {
origin,
next_flow: *flow,
user,
path,
});
}
}
Expand Down Expand Up @@ -284,6 +325,8 @@ impl From<UserFlow> for TokenPurpose {
impl From<SPTFlow> for TokenPurpose {
fn from(value: SPTFlow) -> Self {
match value {
SPTFlow::AuthSelect => Self::AuthSelect,
SPTFlow::SSO => Self::SSO,
SPTFlow::TOTP => Self::TOTP,
SPTFlow::VerifyEmail => Self::VerifyEmail,
SPTFlow::AcceptInvitationFromEmail => Self::AcceptInvitationFromEmail,
Expand Down
Loading