Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add option to raise response validation exceptions #37

Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion src/onelogin/saml2/logout_request.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,12 +212,15 @@ def get_session_indexes(request):
session_indexes.append(session_index_node.text)
return session_indexes

def is_valid(self, request_data):
def is_valid(self, request_data, raises=False):
"""
Checks if the Logout Request received is valid
:param request_data: Request Data
:type request_data: dict

:param raises: Optional argument. If true, the function will raise an exception as soon as first validation test fails
:type raises: bool

:return: If the Logout Request is or not valid
:rtype: boolean
"""
Expand Down Expand Up @@ -274,6 +277,8 @@ def is_valid(self, request_data):
debug = self.__settings.is_debug_active()
if debug:
print(err)
if raises:
raise
return False

def get_error(self):
Expand Down
8 changes: 7 additions & 1 deletion src/onelogin/saml2/logout_response.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,11 +63,15 @@ def get_status(self):
status = entries[0].attrib['Value']
return status

def is_valid(self, request_data, request_id=None):
def is_valid(self, request_data, request_id=None, raises=False):
"""
Determines if the SAML LogoutResponse is valid
:param request_id: The ID of the LogoutRequest sent by this SP to the IdP
:type request_id: string

:param raises: Optional argument. If true, the function will raise an exception as soon as first validation test fails
:type raises: bool

:return: Returns if the SAML LogoutResponse is or not valid
:rtype: boolean
"""
Expand Down Expand Up @@ -111,6 +115,8 @@ def is_valid(self, request_data, request_id=None):
debug = self.__settings.is_debug_active()
if debug:
print(err)
if raises:
raise
return False

def __query(self, query):
Expand Down
7 changes: 6 additions & 1 deletion src/onelogin/saml2/response.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ def __init__(self, settings, response):
self.encrypted = True
self.decrypted_document = self.__decrypt_assertion(decrypted_document)

def is_valid(self, request_data, request_id=None):
def is_valid(self, request_data, request_id=None, raises=False):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should add this functionality to rest of is_valid methods (logoutrequest and logoutresponse)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That sounds like a good idea. I'll let you know when I have made the changes.

"""
Validates the response object.

Expand All @@ -57,6 +57,9 @@ def is_valid(self, request_data, request_id=None):
:param request_id: Optional argument. The ID of the AuthNRequest sent by this SP to the IdP
:type request_id: string

:param raises: Optional argument. If true, the function will raise an exception as soon as first validation test fails
:type raises: bool

:returns: True if the SAML Response is valid, False if not
:rtype: bool
"""
Expand Down Expand Up @@ -226,6 +229,8 @@ def is_valid(self, request_data, request_id=None):
debug = self.__settings.is_debug_active()
if debug:
print(err)
if raises:
raise
return False

def check_status(self):
Expand Down
59 changes: 29 additions & 30 deletions tests/src/OneLogin/saml2_tests/logout_request_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,10 +121,8 @@ def testGetNameIdData(self):
self.assertEqual(expected_name_id_data, name_id_data_2)

request_2 = self.file_contents(join(self.data_path, 'logout_requests', 'logout_request_encrypted_nameid.xml'))
with self.assertRaises(Exception) as context:
with self.assertRaisesRegexp(Exception, 'Key is required in order to decrypt the NameID'):
OneLogin_Saml2_Logout_Request.get_nameid(request_2)
exception = context.exception
self.assertIn("Key is required in order to decrypt the NameID", str(exception))

settings = OneLogin_Saml2_Settings(self.loadSettingsJSON())
key = settings.get_sp_key()
Expand All @@ -140,16 +138,12 @@ def testGetNameIdData(self):
encrypted_id_nodes = dom_2.getElementsByTagName('saml:EncryptedID')
encrypted_data = encrypted_id_nodes[0].firstChild.nextSibling
encrypted_id_nodes[0].removeChild(encrypted_data)
with self.assertRaises(Exception) as context:
with self.assertRaisesRegexp(Exception, 'Not NameID found in the Logout Request'):
OneLogin_Saml2_Logout_Request.get_nameid(dom_2.toxml(), key)
exception = context.exception
self.assertIn("Not NameID found in the Logout Request", str(exception))

inv_request = self.file_contents(join(self.data_path, 'logout_requests', 'invalids', 'no_nameId.xml'))
with self.assertRaises(Exception) as context:
with self.assertRaisesRegexp(Exception, 'Not NameID found in the Logout Request'):
OneLogin_Saml2_Logout_Request.get_nameid(inv_request)
exception = context.exception
self.assertIn("Not NameID found in the Logout Request", str(exception))

def testGetNameId(self):
"""
Expand All @@ -160,10 +154,8 @@ def testGetNameId(self):
self.assertEqual(name_id, 'ONELOGIN_1e442c129e1f822c8096086a1103c5ee2c7cae1c')

request_2 = self.file_contents(join(self.data_path, 'logout_requests', 'logout_request_encrypted_nameid.xml'))
with self.assertRaises(Exception) as context:
with self.assertRaisesRegexp(Exception, 'Key is required in order to decrypt the NameID'):
OneLogin_Saml2_Logout_Request.get_nameid(request_2)
exception = context.exception
self.assertIn("Key is required in order to decrypt the NameID", str(exception))

settings = OneLogin_Saml2_Settings(self.loadSettingsJSON())
key = settings.get_sp_key()
Expand Down Expand Up @@ -242,12 +234,9 @@ def testIsInvalidIssuer(self):
self.assertTrue(logout_request.is_valid(request_data))

settings.set_strict(True)
try:
logout_request2 = OneLogin_Saml2_Logout_Request(settings, OneLogin_Saml2_Utils.b64encode(request))
valid = logout_request2.is_valid(request_data)
self.assertFalse(valid)
except Exception as e:
self.assertIn('Invalid issuer in the Logout Request', str(e))
logout_request2 = OneLogin_Saml2_Logout_Request(settings, OneLogin_Saml2_Utils.b64encode(request))
with self.assertRaisesRegexp(Exception, 'Invalid issuer in the Logout Request'):
logout_request2.is_valid(request_data, raises=True)

def testIsInvalidDestination(self):
"""
Expand All @@ -264,12 +253,9 @@ def testIsInvalidDestination(self):
self.assertTrue(logout_request.is_valid(request_data))

settings.set_strict(True)
try:
logout_request2 = OneLogin_Saml2_Logout_Request(settings, OneLogin_Saml2_Utils.b64encode(request))
valid = logout_request2.is_valid(request_data)
self.assertFalse(valid)
except Exception as e:
self.assertIn('The LogoutRequest was received at', str(e))
logout_request2 = OneLogin_Saml2_Logout_Request(settings, OneLogin_Saml2_Utils.b64encode(request))
with self.assertRaisesRegexp(Exception, 'The LogoutRequest was received at'):
logout_request2.is_valid(request_data, raises=True)

dom = parseString(request)
dom.documentElement.setAttribute('Destination', None)
Expand Down Expand Up @@ -298,12 +284,9 @@ def testIsInvalidNotOnOrAfter(self):
self.assertTrue(logout_request.is_valid(request_data))

settings.set_strict(True)
try:
logout_request2 = OneLogin_Saml2_Logout_Request(settings, OneLogin_Saml2_Utils.b64encode(request))
valid = logout_request2.is_valid(request_data)
self.assertFalse(valid)
except Exception as e:
self.assertIn('Timing issues (please check your clock settings)', str(e))
logout_request2 = OneLogin_Saml2_Logout_Request(settings, OneLogin_Saml2_Utils.b64encode(request))
with self.assertRaisesRegexp(Exception, 'Timing issues \(please check your clock settings\)'):
logout_request2.is_valid(request_data, raises=True)

def testIsValid(self):
"""
Expand Down Expand Up @@ -336,3 +319,19 @@ def testIsValid(self):
request = request.replace('http://stuff.com/endpoints/endpoints/sls.php', current_url)
logout_request5 = OneLogin_Saml2_Logout_Request(settings, OneLogin_Saml2_Utils.b64encode(request))
self.assertTrue(logout_request5.is_valid(request_data))

def testIsValidRaisesExceptionWhenRaisesArgumentIsTrue(self):
request = OneLogin_Saml2_Utils.b64encode('<xml>invalid</xml>')
request_data = {
'http_host': 'example.com',
'script_name': 'index.html',
}
settings = OneLogin_Saml2_Settings(self.loadSettingsJSON())
settings.set_strict(True)

logout_request = OneLogin_Saml2_Logout_Request(settings, request)

self.assertFalse(logout_request.is_valid(request_data))

with self.assertRaises(Exception):
logout_request.is_valid(request_data, raises=True)
38 changes: 23 additions & 15 deletions tests/src/OneLogin/saml2_tests/logout_response_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,11 +201,8 @@ def testIsInValidIssuer(self):

settings.set_strict(True)
response_2 = OneLogin_Saml2_Logout_Response(settings, message)
try:
valid = response_2.is_valid(request_data)
self.assertFalse(valid)
except Exception as e:
self.assertIn('Invalid issuer in the Logout Request', str(e))
with self.assertRaisesRegexp(Exception, 'Invalid issuer in the Logout Request'):
response_2.is_valid(request_data, raises=True)

def testIsInValidDestination(self):
"""
Expand All @@ -226,11 +223,8 @@ def testIsInValidDestination(self):

settings.set_strict(True)
response_2 = OneLogin_Saml2_Logout_Response(settings, message)
try:
valid = response_2.is_valid(request_data)
self.assertFalse(valid)
except Exception as e:
self.assertIn('The LogoutRequest was received at', str(e))
with self.assertRaisesRegexp(Exception, 'The LogoutRequest was received at'):
response_2.is_valid(request_data, raises=True)

# Empty destination
dom = parseString(OneLogin_Saml2_Utils.decode_base64_and_inflate(message))
Expand Down Expand Up @@ -264,11 +258,8 @@ def testIsValid(self):

settings.set_strict(True)
response_2 = OneLogin_Saml2_Logout_Response(settings, message)
try:
valid = response_2.is_valid(request_data)
self.assertFalse(valid)
except Exception as e:
self.assertIn('The LogoutRequest was received at', str(e))
with self.assertRaisesRegexp(Exception, 'The LogoutRequest was received at'):
response_2.is_valid(request_data, raises=True)

plain_message = compat.to_string(OneLogin_Saml2_Utils.decode_base64_and_inflate(message))
current_url = OneLogin_Saml2_Utils.get_self_url_no_query(request_data)
Expand All @@ -277,3 +268,20 @@ def testIsValid(self):

response_3 = OneLogin_Saml2_Logout_Response(settings, message_3)
self.assertTrue(response_3.is_valid(request_data))

def testIsValidRaisesExceptionWhenRaisesArgumentIsTrue(self):
message = OneLogin_Saml2_Utils.deflate_and_base64_encode('<xml>invalid</xml>')
request_data = {
'http_host': 'example.com',
'script_name': 'index.html',
'get_data': {}
}
settings = OneLogin_Saml2_Settings(self.loadSettingsJSON())
settings.set_strict(True)

response = OneLogin_Saml2_Logout_Response(settings, message)

self.assertFalse(response.is_valid(request_data))

with self.assertRaises(Exception):
response.is_valid(request_data, raises=True)
Loading