Skip to content

Commit

Permalink
Merge pull request #5 from yjinjo/master
Browse files Browse the repository at this point in the history
Change variable from label to idp_name and optimize the code
  • Loading branch information
yjinjo authored Jun 7, 2024
2 parents 8b3b647 + bac8f79 commit 1705dc2
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 28 deletions.
61 changes: 37 additions & 24 deletions src/plugin/connector/saml_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,23 +36,21 @@ def init(self, options: dict) -> dict:
metadata_url = options.get("metadata_url")

xml_data = self._fetch_xml(metadata_url)
_, _, sso_url = self._parse_xml(xml_data)
_, _, sso_url = self._parse_idp_xml(xml_data)

label = self._get_label(identity_provider)
idp_name = self._get_idp_name(identity_provider)

metadata = {
"identity_provider": identity_provider,
"protocol": protocol,
"icon": icon,
"label": label,
"idp_name": idp_name,
"sso_url": sso_url,
}

return metadata

def authorize(
self, params: dict, metadata_url: str, sp_metadata_url: str, domain_id: str
) -> dict:
def authorize(self, params: dict, metadata_url: str, domain_id: str) -> dict:
"""Authorizes the user using SAML.
Args:
Expand All @@ -67,7 +65,7 @@ def authorize(
Raises:
ERROR_AUTHENTICATE_FAILURE: If authentication fails
"""
self._set_saml_settings(metadata_url, sp_metadata_url, domain_id)
self._set_saml_settings(params, metadata_url, domain_id)

auth = OneLogin_Saml2_Auth(
params,
Expand Down Expand Up @@ -98,44 +96,56 @@ def _get_user_info_from_auth(auth: OneLogin_Saml2_Auth) -> dict:
'user_info': 'dict'
'user_id': 'str'
"""
user_info = {"user_id": auth.get_nameid()}
user_info = {}

return user_info
try:
name_id = auth.get_nameid()

user_info = {"user_id": name_id}

return user_info
except Exception as e:
_LOGGER.error(f"[_get_user_info_from_auth] ERROR_NOT_FOUND: {e}")
raise ERROR_NOT_FOUND(message=f"ERROR_NOT_FOUND: {e}")

def _set_saml_settings(
self,
params: dict,
metadata_url: str,
sp_metadata_url: str,
domain_id: str,
) -> None:
"""Sets the SAML settings using the metadata URL and domain ID.
Args:
'params': 'dict',
'metadata_url': 'str',
'sp_metadata_url': 'str',
'domain_id': 'str',
"""
xml_data = self._fetch_xml(metadata_url)
entity_id, x509_certificate, sso_url = self._parse_xml(xml_data)
idp_xml_data = self._fetch_xml(metadata_url)
entity_id, idp_x509_certificate, sso_url = self._parse_idp_xml(idp_xml_data)

http_host = params.get("http_host")
acs_url = f"https://{http_host}/console-api/extension/auth/saml/{domain_id}"

self.saml_settings = {
"strict": False,
"strict": True,
"debug": True,
"idp": {
"entityId": entity_id,
"singleSignOnService": {
"url": sso_url,
"binding": "urn:oasis:names:tc:SAML:2.0:bindings:HTTP-Redirect",
},
"x509cert": x509_certificate,
"x509cert": idp_x509_certificate,
},
"sp": {
"entityId": domain_id,
"assertionConsumerService": {
"url": sso_url,
"url": acs_url,
"binding": "urn:oasis:names:tc:SAML:2.0:bindings:HTTP-POST",
},
"x509cert": x509_certificate,
# "x509cert": sp_x509_certificate,
},
}

Expand All @@ -161,7 +171,7 @@ def _fetch_xml(metadata_url: str) -> bytes:
raise ERROR_NOT_FOUND(message=f"ERROR_NOT_FOUND: {e}")

@staticmethod
def _parse_xml(xml_data: bytes) -> Tuple[str, str, str]:
def _parse_idp_xml(xml_data: bytes) -> Tuple[str, str, str]:
"""Parses the XML data to extract entity ID, x509 certificate, and SSO URL.
Args:
Expand All @@ -181,21 +191,24 @@ def _parse_xml(xml_data: bytes) -> Tuple[str, str, str]:
x509_certificate = root.find(".//ds:X509Certificate", ns).text

sso_service = root.find(".//md:SingleSignOnService", ns)
sso_url = sso_service.attrib["Location"] if sso_service is not None else None

sso_url = None
if sso_service is not None:
sso_url = sso_service.attrib["Location"]

return entity_id, x509_certificate, sso_url

@staticmethod
def _get_label(identity_provider: str) -> str:
"""Generates a label for the identity provider.
def _get_idp_name(identity_provider: str) -> str:
"""Generates a name for the identity provider.
Args:
'identity_provider': 'str'
Returns:
'label': 'str'
'idp_name': 'str'
"""
labels = {
idp_name = {
"okta": "Okta",
"frontegg": "Frontegg",
"auth0": "Auth0",
Expand All @@ -206,6 +219,6 @@ def _get_label(identity_provider: str) -> str:
"keycloak": "Keycloak",
"microsoft_entra_id": "Microsoft Entra ID",
}
label = labels.get(identity_provider, identity_provider.capitalize())
idp_name = idp_name.get(identity_provider, identity_provider.capitalize())

return f"Sign In with {label}"
return f"Sign In with {idp_name}"
5 changes: 1 addition & 4 deletions src/plugin/manager/external_auth_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,10 +43,7 @@ def authorize(self, params: dict) -> dict:
"""
credentials = params["credentials"]
metadata_url = params["options"].get("metadata_url")
sp_metadata_url = params["options"].get("sp_metadata_url")
domain_id = params["domain_id"]
user_info = self.saml_connector.authorize(
credentials, metadata_url, sp_metadata_url, domain_id
)
user_info = self.saml_connector.authorize(credentials, metadata_url, domain_id)

return user_info

0 comments on commit 1705dc2

Please sign in to comment.