From db30e73bd208549cecc8acc3e25623a4e78cd1b4 Mon Sep 17 00:00:00 2001 From: Quentin Gliech Date: Mon, 26 Sep 2022 12:56:17 +0200 Subject: [PATCH 1/8] Refactor OIDC tests to better mimic an actual OIDC provider Instead of constantly mocking the internal methods of the OIDC handler, it now mocks HTTP responses Signed-off-by: Quentin Gliech --- poetry.lock | 2 +- pyproject.toml | 2 +- synapse/handlers/oidc.py | 11 +- tests/federation/test_federation_client.py | 31 +- tests/handlers/test_oidc.py | 580 +++++++++------------ tests/rest/client/test_auth.py | 27 +- tests/rest/client/test_login.py | 40 +- tests/rest/client/utils.py | 125 +++-- tests/test_utils/__init__.py | 40 +- tests/test_utils/oidc.py | 313 +++++++++++ 10 files changed, 718 insertions(+), 453 deletions(-) create mode 100644 tests/test_utils/oidc.py diff --git a/poetry.lock b/poetry.lock index 0f6d1cfa6944..0906e09d8398 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1626,7 +1626,7 @@ url_preview = ["lxml"] [metadata] lock-version = "1.1" python-versions = "^3.7.1" -content-hash = "1b14fc274d9e2a495a7f864150f3ffcf4d9f585e09a67e53301ae4ef3c2f3e48" +content-hash = "1d3463bc88a8db5ce8be85e868a80090b2b164bfc9e06c64e907f7ec9ccaee2c" [metadata.files] attrs = [ diff --git a/pyproject.toml b/pyproject.toml index 0a4242fb7201..77e890fbcb6d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -192,7 +192,7 @@ psycopg2 = { version = ">=2.8", markers = "platform_python_implementation != 'Py psycopg2cffi = { version = ">=2.8", markers = "platform_python_implementation == 'PyPy'", optional = true } psycopg2cffi-compat = { version = "==1.1", markers = "platform_python_implementation == 'PyPy'", optional = true } pysaml2 = { version = ">=4.5.0", optional = true } -authlib = { version = ">=0.14.0", optional = true } +authlib = { version = ">=0.15.1", optional = true } # systemd-python is necessary for logging to the systemd journal via # `systemd.journal.JournalHandler`, as is documented in # `contrib/systemd/log_config.yaml`. diff --git a/synapse/handlers/oidc.py b/synapse/handlers/oidc.py index d7a82269006a..3f46e0f74865 100644 --- a/synapse/handlers/oidc.py +++ b/synapse/handlers/oidc.py @@ -275,6 +275,7 @@ def __init__( provider: OidcProviderConfig, ): self._store = hs.get_datastores().main + self._clock = hs.get_clock() self._macaroon_generaton = macaroon_generator @@ -673,6 +674,12 @@ async def _parse_id_token(self, token: Token, nonce: str) -> CodeIDToken: Returns: The decoded claims in the ID token. """ + id_token = token.get("id_token") + + # That has been theoritically been checked by the caller, so even though + # assertion are not enabled in production, it is mainly here to appease mypy + assert id_token is not None + metadata = await self.load_metadata() claims_params = { "nonce": nonce, @@ -715,7 +722,9 @@ async def _parse_id_token(self, token: Token, nonce: str) -> CodeIDToken: logger.debug("Decoded id_token JWT %r; validating", claims) - claims.validate(leeway=120) # allows 2 min of clock skew + claims.validate( + now=self._clock.time(), leeway=120 + ) # allows 2 min of clock skew return claims diff --git a/tests/federation/test_federation_client.py b/tests/federation/test_federation_client.py index 50e376f69574..7839797136ec 100644 --- a/tests/federation/test_federation_client.py +++ b/tests/federation/test_federation_client.py @@ -12,21 +12,18 @@ # See the License for the specific language governing permissions and # limitations under the License. -import json from unittest import mock import twisted.web.client from twisted.internet import defer -from twisted.internet.protocol import Protocol -from twisted.python.failure import Failure from twisted.test.proto_helpers import MemoryReactor from synapse.api.room_versions import RoomVersions from synapse.events import EventBase from synapse.server import HomeServer -from synapse.types import JsonDict from synapse.util import Clock +from tests.test_utils import FakeResponse from tests.unittest import FederatingHomeserverTestCase @@ -89,8 +86,8 @@ def test_get_room_state(self): # mock up the response, and have the agent return it self._mock_agent.request.side_effect = lambda *args, **kwargs: defer.succeed( - _mock_response( - { + FakeResponse.json( + payload={ "pdus": [ create_event_dict, member_event_dict, @@ -199,8 +196,8 @@ def _get_pdu_once(self) -> EventBase: # mock up the response, and have the agent return it self._mock_agent.request.side_effect = lambda *args, **kwargs: defer.succeed( - _mock_response( - { + FakeResponse.json( + payload={ "origin": "yet.another.server", "origin_server_ts": 900, "pdus": [ @@ -230,21 +227,3 @@ def _get_pdu_once(self) -> EventBase: self.assertEqual(remote_pdu.internal_metadata.outlier, False) return remote_pdu - - -def _mock_response(resp: JsonDict): - body = json.dumps(resp).encode("utf-8") - - def deliver_body(p: Protocol): - p.dataReceived(body) - p.connectionLost(Failure(twisted.web.client.ResponseDone())) - - response = mock.Mock( - code=200, - phrase=b"OK", - headers=twisted.web.client.Headers({"content-Type": ["application/json"]}), - length=len(body), - deliverBody=deliver_body, - ) - mock.seal(response) - return response diff --git a/tests/handlers/test_oidc.py b/tests/handlers/test_oidc.py index e6cd3af7b756..09feb1f52b2d 100644 --- a/tests/handlers/test_oidc.py +++ b/tests/handlers/test_oidc.py @@ -11,9 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import json import os -from typing import Any, Dict +from typing import Any, Dict, Tuple from unittest.mock import ANY, Mock, patch from urllib.parse import parse_qs, urlparse @@ -22,12 +21,15 @@ from twisted.test.proto_helpers import MemoryReactor from synapse.handlers.sso import MappingException +from synapse.http.site import SynapseRequest from synapse.server import HomeServer -from synapse.types import JsonDict, UserID +from synapse.types import UserID from synapse.util import Clock -from synapse.util.macaroons import OidcSessionData, get_value_from_macaroon +from synapse.util.macaroons import get_value_from_macaroon +from synapse.util.stringutils import random_string from tests.test_utils import FakeResponse, get_awaitable_result, simple_async_mock +from tests.test_utils.oidc import FakeAuthorizationGrant, FakeOidcProvider from tests.unittest import HomeserverTestCase, override_config try: @@ -46,12 +48,6 @@ CALLBACK_URL = BASE_URL + "_synapse/client/oidc/callback" SCOPES = ["openid"] -AUTHORIZATION_ENDPOINT = ISSUER + "authorize" -TOKEN_ENDPOINT = ISSUER + "token" -USERINFO_ENDPOINT = ISSUER + "userinfo" -WELL_KNOWN = ISSUER + ".well-known/openid-configuration" -JWKS_URI = ISSUER + ".well-known/jwks.json" - # config for common cases DEFAULT_CONFIG = { "enabled": True, @@ -66,9 +62,9 @@ EXPLICIT_ENDPOINT_CONFIG = { **DEFAULT_CONFIG, "discover": False, - "authorization_endpoint": AUTHORIZATION_ENDPOINT, - "token_endpoint": TOKEN_ENDPOINT, - "jwks_uri": JWKS_URI, + "authorization_endpoint": ISSUER + "authorize", + "token_endpoint": ISSUER + "token", + "jwks_uri": ISSUER + "jwks", } @@ -102,27 +98,6 @@ async def map_user_attributes(self, userinfo, token, failures): } -async def get_json(url: str) -> JsonDict: - # Mock get_json calls to handle jwks & oidc discovery endpoints - if url == WELL_KNOWN: - # Minimal discovery document, as defined in OpenID.Discovery - # https://openid.net/specs/openid-connect-discovery-1_0.html#ProviderMetadata - return { - "issuer": ISSUER, - "authorization_endpoint": AUTHORIZATION_ENDPOINT, - "token_endpoint": TOKEN_ENDPOINT, - "jwks_uri": JWKS_URI, - "userinfo_endpoint": USERINFO_ENDPOINT, - "response_types_supported": ["code"], - "subject_types_supported": ["public"], - "id_token_signing_alg_values_supported": ["RS256"], - } - elif url == JWKS_URI: - return {"keys": []} - - return {} - - def _key_file_path() -> str: """path to a file containing the private half of a test key""" @@ -159,11 +134,11 @@ def default_config(self) -> Dict[str, Any]: return config def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: - self.http_client = Mock(spec=["get_json"]) - self.http_client.get_json.side_effect = get_json - self.http_client.user_agent = b"Synapse Test" + self.fake_provider = FakeOidcProvider(clock=clock, issuer=ISSUER) - hs = self.setup_test_homeserver(proxied_http_client=self.http_client) + hs = self.setup_test_homeserver() + self.hs_patcher = self.fake_provider.patch_homeserver(hs=hs) + self.hs_patcher.start() self.handler = hs.get_oidc_handler() self.provider = self.handler._providers["oidc"] @@ -175,18 +150,51 @@ def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: # Reduce the number of attempts when generating MXIDs. sso_handler._MAP_USERNAME_RETRIES = 3 + auth_handler = hs.get_auth_handler() + # Mock the complete SSO login method. + self.complete_sso_login = simple_async_mock() + auth_handler.complete_sso_login = self.complete_sso_login # type: ignore[assignment] + return hs + def tearDown(self) -> None: + self.hs_patcher.stop() + return super().tearDown() + + def reset_mocks(self): + """Reset all the Mocks.""" + self.fake_provider.reset_mocks() + self.render_error.reset_mock() + self.complete_sso_login.reset_mock() + def metadata_edit(self, values): """Modify the result that will be returned by the well-known query""" - async def patched_get_json(uri): - res = await get_json(uri) - if uri == WELL_KNOWN: - res.update(values) - return res + metadata = self.fake_provider.get_metadata() + metadata.update(values) + return patch.object(self.fake_provider, "get_metadata", return_value=metadata) - return patch.object(self.http_client, "get_json", patched_get_json) + def start_authorization( + self, + userinfo: dict, + client_redirect_url: str = "http://client/redirect", + scope: str = "openid", + with_sid: bool = False, + ) -> Tuple[SynapseRequest, FakeAuthorizationGrant]: + """Start an authorization request, and get the callback request back.""" + nonce = random_string(10) + state = random_string(10) + + code, grant = self.fake_provider.start_authorization( + userinfo=userinfo, + scope=scope, + client_id=self.provider._client_auth.client_id, + redirect_uri=self.provider._callback_url, + nonce=nonce, + with_sid=with_sid, + ) + session = self._generate_oidc_session_token(state, nonce, client_redirect_url) + return _build_callback_request(code, state, session), grant def assertRenderedError(self, error, error_description=None): self.render_error.assert_called_once() @@ -210,52 +218,54 @@ def test_discovery(self) -> None: """The handler should discover the endpoints from OIDC discovery document.""" # This would throw if some metadata were invalid metadata = self.get_success(self.provider.load_metadata()) - self.http_client.get_json.assert_called_once_with(WELL_KNOWN) + self.fake_provider.get_metadata_handler.assert_called_once() - self.assertEqual(metadata.issuer, ISSUER) - self.assertEqual(metadata.authorization_endpoint, AUTHORIZATION_ENDPOINT) - self.assertEqual(metadata.token_endpoint, TOKEN_ENDPOINT) - self.assertEqual(metadata.jwks_uri, JWKS_URI) - # FIXME: it seems like authlib does not have that defined in its metadata models - # self.assertEqual(metadata.userinfo_endpoint, USERINFO_ENDPOINT) + self.assertEqual(metadata.issuer, self.fake_provider.issuer) + self.assertEqual( + metadata.authorization_endpoint, + self.fake_provider.authorization_endpoint, + ) + self.assertEqual(metadata.token_endpoint, self.fake_provider.token_endpoint) + self.assertEqual(metadata.jwks_uri, self.fake_provider.jwks_uri) + # It seems like authlib does not have that defined in its metadata models + self.assertEqual( + metadata.get("userinfo_endpoint"), + self.fake_provider.userinfo_endpoint, + ) # subsequent calls should be cached - self.http_client.reset_mock() + self.reset_mocks() self.get_success(self.provider.load_metadata()) - self.http_client.get_json.assert_not_called() + self.fake_provider.get_metadata_handler.assert_not_called() @override_config({"oidc_config": EXPLICIT_ENDPOINT_CONFIG}) def test_no_discovery(self) -> None: """When discovery is disabled, it should not try to load from discovery document.""" self.get_success(self.provider.load_metadata()) - self.http_client.get_json.assert_not_called() + self.fake_provider.get_metadata_handler.assert_not_called() - @override_config({"oidc_config": EXPLICIT_ENDPOINT_CONFIG}) + @override_config({"oidc_config": DEFAULT_CONFIG}) def test_load_jwks(self) -> None: """JWKS loading is done once (then cached) if used.""" jwks = self.get_success(self.provider.load_jwks()) - self.http_client.get_json.assert_called_once_with(JWKS_URI) - self.assertEqual(jwks, {"keys": []}) + self.fake_provider.get_jwks_handler.assert_called_once() + self.assertEqual(jwks, self.fake_provider.get_jwks()) # subsequent calls should be cached… - self.http_client.reset_mock() + self.reset_mocks() self.get_success(self.provider.load_jwks()) - self.http_client.get_json.assert_not_called() + self.fake_provider.get_jwks_handler.assert_not_called() # …unless forced - self.http_client.reset_mock() + self.reset_mocks() self.get_success(self.provider.load_jwks(force=True)) - self.http_client.get_json.assert_called_once_with(JWKS_URI) + self.fake_provider.get_jwks_handler.assert_called_once() - # Throw if the JWKS uri is missing - original = self.provider.load_metadata - - async def patched_load_metadata(): - m = (await original()).copy() - m.update({"jwks_uri": None}) - return m - - with patch.object(self.provider, "load_metadata", patched_load_metadata): + with self.metadata_edit({"jwks_uri": None}): + # If we don't do this, the load_metadata call will throw because of the + # missing jwks_uri + self.provider._user_profile_method = "userinfo_endpoint" + self.get_success(self.provider.load_metadata(force=True)) self.get_failure(self.provider.load_jwks(force=True), RuntimeError) @override_config({"oidc_config": DEFAULT_CONFIG}) @@ -359,7 +369,7 @@ def test_redirect_request(self) -> None: self.provider.handle_redirect_request(req, b"http://client/redirect") ) ) - auth_endpoint = urlparse(AUTHORIZATION_ENDPOINT) + auth_endpoint = urlparse(self.fake_provider.authorization_endpoint) self.assertEqual(url.scheme, auth_endpoint.scheme) self.assertEqual(url.netloc, auth_endpoint.netloc) @@ -424,48 +434,34 @@ def test_callback(self) -> None: with self.assertRaises(AttributeError): _ = mapping_provider.get_extra_attributes - token = { - "type": "bearer", - "id_token": "id_token", - "access_token": "access_token", - } username = "bar" userinfo = { "sub": "foo", "username": username, } expected_user_id = "@%s:%s" % (username, self.hs.hostname) - self.provider._exchange_code = simple_async_mock(return_value=token) # type: ignore[assignment] - self.provider._parse_id_token = simple_async_mock(return_value=userinfo) # type: ignore[assignment] - self.provider._fetch_userinfo = simple_async_mock(return_value=userinfo) # type: ignore[assignment] - auth_handler = self.hs.get_auth_handler() - auth_handler.complete_sso_login = simple_async_mock() - code = "code" - state = "state" - nonce = "nonce" client_redirect_url = "http://client/redirect" - ip_address = "10.0.0.1" - session = self._generate_oidc_session_token(state, nonce, client_redirect_url) - request = _build_callback_request(code, state, session, ip_address=ip_address) - + request, _ = self.start_authorization( + userinfo, client_redirect_url=client_redirect_url + ) self.get_success(self.handler.handle_oidc_callback(request)) - auth_handler.complete_sso_login.assert_called_once_with( + self.complete_sso_login.assert_called_once_with( expected_user_id, - "oidc", + self.provider.idp_id, request, client_redirect_url, None, new_user=True, auth_provider_session_id=None, ) - self.provider._exchange_code.assert_called_once_with(code) - self.provider._parse_id_token.assert_called_once_with(token, nonce=nonce) - self.provider._fetch_userinfo.assert_not_called() + self.fake_provider.post_token_handler.assert_called_once() + self.fake_provider.get_userinfo_handler.assert_not_called() self.render_error.assert_not_called() # Handle mapping errors + request, _ = self.start_authorization(userinfo) with patch.object( self.provider, "_remote_id_from_userinfo", @@ -475,81 +471,63 @@ def test_callback(self) -> None: self.assertRenderedError("mapping_error") # Handle ID token errors - self.provider._parse_id_token = simple_async_mock(raises=Exception()) # type: ignore[assignment] - self.get_success(self.handler.handle_oidc_callback(request)) + request, _ = self.start_authorization(userinfo) + with self.fake_provider.id_token_override({"iss": "https://bad.issuer/"}): + self.get_success(self.handler.handle_oidc_callback(request)) self.assertRenderedError("invalid_token") - auth_handler.complete_sso_login.reset_mock() - self.provider._exchange_code.reset_mock() - self.provider._parse_id_token.reset_mock() - self.provider._fetch_userinfo.reset_mock() + self.reset_mocks() # With userinfo fetching self.provider._user_profile_method = "userinfo_endpoint" - token = { - "type": "bearer", - "access_token": "access_token", - } - self.provider._exchange_code = simple_async_mock(return_value=token) # type: ignore[assignment] + # Without the "openid" scope, the FakeProvider does not generate an id_token + request, _ = self.start_authorization(userinfo, scope="") self.get_success(self.handler.handle_oidc_callback(request)) - auth_handler.complete_sso_login.assert_called_once_with( + self.complete_sso_login.assert_called_once_with( expected_user_id, - "oidc", + self.provider.idp_id, request, - client_redirect_url, + ANY, None, new_user=False, auth_provider_session_id=None, ) - self.provider._exchange_code.assert_called_once_with(code) - self.provider._parse_id_token.assert_not_called() - self.provider._fetch_userinfo.assert_called_once_with(token) + self.fake_provider.post_token_handler.assert_called_once() + self.fake_provider.get_userinfo_handler.assert_called_once() self.render_error.assert_not_called() + self.reset_mocks() + # With an ID token, userinfo fetching and sid in the ID token self.provider._user_profile_method = "userinfo_endpoint" - token = { - "type": "bearer", - "access_token": "access_token", - "id_token": "id_token", - } - id_token = { - "sid": "abcdefgh", - } - self.provider._parse_id_token = simple_async_mock(return_value=id_token) # type: ignore[assignment] - self.provider._exchange_code = simple_async_mock(return_value=token) # type: ignore[assignment] - auth_handler.complete_sso_login.reset_mock() - self.provider._fetch_userinfo.reset_mock() + request, grant = self.start_authorization(userinfo, with_sid=True) + self.assertIsNotNone(grant.sid) self.get_success(self.handler.handle_oidc_callback(request)) - auth_handler.complete_sso_login.assert_called_once_with( + self.complete_sso_login.assert_called_once_with( expected_user_id, - "oidc", + self.provider.idp_id, request, - client_redirect_url, + ANY, None, new_user=False, - auth_provider_session_id=id_token["sid"], + auth_provider_session_id=grant.sid, ) - self.provider._exchange_code.assert_called_once_with(code) - self.provider._parse_id_token.assert_called_once_with(token, nonce=nonce) - self.provider._fetch_userinfo.assert_called_once_with(token) + self.fake_provider.post_token_handler.assert_called_once() + self.fake_provider.get_userinfo_handler.assert_called_once() self.render_error.assert_not_called() # Handle userinfo fetching error - self.provider._fetch_userinfo = simple_async_mock(raises=Exception()) # type: ignore[assignment] - self.get_success(self.handler.handle_oidc_callback(request)) + request, _ = self.start_authorization(userinfo) + with self.fake_provider.buggy_endpoint(userinfo=True): + self.get_success(self.handler.handle_oidc_callback(request)) self.assertRenderedError("fetch_error") - # Handle code exchange failure - from synapse.handlers.oidc import OidcError - - self.provider._exchange_code = simple_async_mock( # type: ignore[assignment] - raises=OidcError("invalid_request") - ) - self.get_success(self.handler.handle_oidc_callback(request)) - self.assertRenderedError("invalid_request") + request, _ = self.start_authorization(userinfo) + with self.fake_provider.buggy_endpoint(token=True): + self.get_success(self.handler.handle_oidc_callback(request)) + self.assertRenderedError("server_error") @override_config({"oidc_config": DEFAULT_CONFIG}) def test_callback_session(self) -> None: @@ -599,18 +577,22 @@ def test_callback_session(self) -> None: ) def test_exchange_code(self) -> None: """Code exchange behaves correctly and handles various error scenarios.""" - token = {"type": "bearer"} - token_json = json.dumps(token).encode("utf-8") - self.http_client.request = simple_async_mock( - return_value=FakeResponse(code=200, phrase=b"OK", body=token_json) + token = { + "type": "Bearer", + "access_token": "aabbcc", + } + + self.fake_provider.post_token_handler.side_effect = None + self.fake_provider.post_token_handler.return_value = FakeResponse.json( + payload=token ) code = "code" ret = self.get_success(self.provider._exchange_code(code)) - kwargs = self.http_client.request.call_args[1] + kwargs = self.fake_provider.request.call_args[1] self.assertEqual(ret, token) self.assertEqual(kwargs["method"], "POST") - self.assertEqual(kwargs["uri"], TOKEN_ENDPOINT) + self.assertEqual(kwargs["uri"], self.fake_provider.token_endpoint) args = parse_qs(kwargs["data"].decode("utf-8")) self.assertEqual(args["grant_type"], ["authorization_code"]) @@ -620,12 +602,8 @@ def test_exchange_code(self) -> None: self.assertEqual(args["redirect_uri"], [CALLBACK_URL]) # Test error handling - self.http_client.request = simple_async_mock( - return_value=FakeResponse( - code=400, - phrase=b"Bad Request", - body=b'{"error": "foo", "error_description": "bar"}', - ) + self.fake_provider.post_token_handler.return_value = FakeResponse.json( + code=400, payload={"error": "foo", "error_description": "bar"} ) from synapse.handlers.oidc import OidcError @@ -634,46 +612,30 @@ def test_exchange_code(self) -> None: self.assertEqual(exc.value.error_description, "bar") # Internal server error with no JSON body - self.http_client.request = simple_async_mock( - return_value=FakeResponse( - code=500, - phrase=b"Internal Server Error", - body=b"Not JSON", - ) + self.fake_provider.post_token_handler.return_value = FakeResponse( + code=500, body=b"Not JSON" ) exc = self.get_failure(self.provider._exchange_code(code), OidcError) self.assertEqual(exc.value.error, "server_error") # Internal server error with JSON body - self.http_client.request = simple_async_mock( - return_value=FakeResponse( - code=500, - phrase=b"Internal Server Error", - body=b'{"error": "internal_server_error"}', - ) + self.fake_provider.post_token_handler.return_value = FakeResponse.json( + code=500, payload={"error": "internal_server_error"} ) exc = self.get_failure(self.provider._exchange_code(code), OidcError) self.assertEqual(exc.value.error, "internal_server_error") # 4xx error without "error" field - self.http_client.request = simple_async_mock( - return_value=FakeResponse( - code=400, - phrase=b"Bad request", - body=b"{}", - ) + self.fake_provider.post_token_handler.return_value = FakeResponse.json( + code=400, payload={} ) exc = self.get_failure(self.provider._exchange_code(code), OidcError) self.assertEqual(exc.value.error, "server_error") # 2xx error with "error" field - self.http_client.request = simple_async_mock( - return_value=FakeResponse( - code=200, - phrase=b"OK", - body=b'{"error": "some_error"}', - ) + self.fake_provider.post_token_handler.return_value = FakeResponse.json( + code=200, payload={"error": "some_error"} ) exc = self.get_failure(self.provider._exchange_code(code), OidcError) self.assertEqual(exc.value.error, "some_error") @@ -697,11 +659,14 @@ def test_exchange_code_jwt_key(self) -> None: """Test that code exchange works with a JWK client secret.""" from authlib.jose import jwt - token = {"type": "bearer"} - self.http_client.request = simple_async_mock( - return_value=FakeResponse( - code=200, phrase=b"OK", body=json.dumps(token).encode("utf-8") - ) + token = { + "type": "Bearer", + "access_token": "aabbcc", + } + + self.fake_provider.post_token_handler.side_effect = None + self.fake_provider.post_token_handler.return_value = FakeResponse.json( + payload=token ) code = "code" @@ -714,9 +679,9 @@ def test_exchange_code_jwt_key(self) -> None: self.assertEqual(ret, token) # the request should have hit the token endpoint - kwargs = self.http_client.request.call_args[1] + kwargs = self.fake_provider.request.call_args[1] self.assertEqual(kwargs["method"], "POST") - self.assertEqual(kwargs["uri"], TOKEN_ENDPOINT) + self.assertEqual(kwargs["uri"], self.fake_provider.token_endpoint) # the client secret provided to the should be a jwt which can be checked with # the public key @@ -750,11 +715,14 @@ def test_exchange_code_jwt_key(self) -> None: ) def test_exchange_code_no_auth(self) -> None: """Test that code exchange works with no client secret.""" - token = {"type": "bearer"} - self.http_client.request = simple_async_mock( - return_value=FakeResponse( - code=200, phrase=b"OK", body=json.dumps(token).encode("utf-8") - ) + token = { + "type": "Bearer", + "access_token": "aabbcc", + } + + self.fake_provider.post_token_handler.side_effect = None + self.fake_provider.post_token_handler.return_value = FakeResponse.json( + payload=token ) code = "code" ret = self.get_success(self.provider._exchange_code(code)) @@ -762,9 +730,9 @@ def test_exchange_code_no_auth(self) -> None: self.assertEqual(ret, token) # the request should have hit the token endpoint - kwargs = self.http_client.request.call_args[1] + kwargs = self.fake_provider.request.call_args[1] self.assertEqual(kwargs["method"], "POST") - self.assertEqual(kwargs["uri"], TOKEN_ENDPOINT) + self.assertEqual(kwargs["uri"], self.fake_provider.token_endpoint) # check the POSTed data args = parse_qs(kwargs["data"].decode("utf-8")) @@ -787,37 +755,19 @@ def test_extra_attributes(self) -> None: """ Login while using a mapping provider that implements get_extra_attributes. """ - token = { - "type": "bearer", - "id_token": "id_token", - "access_token": "access_token", - } userinfo = { "sub": "foo", "username": "foo", "phone": "1234567", } - self.provider._exchange_code = simple_async_mock(return_value=token) # type: ignore[assignment] - self.provider._parse_id_token = simple_async_mock(return_value=userinfo) # type: ignore[assignment] - auth_handler = self.hs.get_auth_handler() - auth_handler.complete_sso_login = simple_async_mock() - - state = "state" - client_redirect_url = "http://client/redirect" - session = self._generate_oidc_session_token( - state=state, - nonce="nonce", - client_redirect_url=client_redirect_url, - ) - request = _build_callback_request("code", state, session) - + request, _ = self.start_authorization(userinfo) self.get_success(self.handler.handle_oidc_callback(request)) - auth_handler.complete_sso_login.assert_called_once_with( + self.complete_sso_login.assert_called_once_with( "@foo:test", - "oidc", + self.provider.idp_id, request, - client_redirect_url, + ANY, {"phone": "1234567"}, new_user=True, auth_provider_session_id=None, @@ -826,41 +776,40 @@ def test_extra_attributes(self) -> None: @override_config({"oidc_config": DEFAULT_CONFIG}) def test_map_userinfo_to_user(self) -> None: """Ensure that mapping the userinfo returned from a provider to an MXID works properly.""" - auth_handler = self.hs.get_auth_handler() - auth_handler.complete_sso_login = simple_async_mock() - userinfo: dict = { "sub": "test_user", "username": "test_user", } - self.get_success(_make_callback_with_userinfo(self.hs, userinfo)) - auth_handler.complete_sso_login.assert_called_once_with( + request, _ = self.start_authorization(userinfo) + self.get_success(self.handler.handle_oidc_callback(request)) + self.complete_sso_login.assert_called_once_with( "@test_user:test", - "oidc", - ANY, + self.provider.idp_id, + request, ANY, None, new_user=True, auth_provider_session_id=None, ) - auth_handler.complete_sso_login.reset_mock() + self.reset_mocks() # Some providers return an integer ID. userinfo = { "sub": 1234, "username": "test_user_2", } - self.get_success(_make_callback_with_userinfo(self.hs, userinfo)) - auth_handler.complete_sso_login.assert_called_once_with( + request, _ = self.start_authorization(userinfo) + self.get_success(self.handler.handle_oidc_callback(request)) + self.complete_sso_login.assert_called_once_with( "@test_user_2:test", - "oidc", - ANY, + self.provider.idp_id, + request, ANY, None, new_user=True, auth_provider_session_id=None, ) - auth_handler.complete_sso_login.reset_mock() + self.reset_mocks() # Test if the mxid is already taken store = self.hs.get_datastores().main @@ -869,8 +818,9 @@ def test_map_userinfo_to_user(self) -> None: store.register_user(user_id=user3.to_string(), password_hash=None) ) userinfo = {"sub": "test3", "username": "test_user_3"} - self.get_success(_make_callback_with_userinfo(self.hs, userinfo)) - auth_handler.complete_sso_login.assert_not_called() + request, _ = self.start_authorization(userinfo) + self.get_success(self.handler.handle_oidc_callback(request)) + self.complete_sso_login.assert_not_called() self.assertRenderedError( "mapping_error", "Mapping provider does not support de-duplicating Matrix IDs", @@ -885,38 +835,37 @@ def test_map_userinfo_to_existing_user(self) -> None: store.register_user(user_id=user.to_string(), password_hash=None) ) - auth_handler = self.hs.get_auth_handler() - auth_handler.complete_sso_login = simple_async_mock() - # Map a user via SSO. userinfo = { "sub": "test", "username": "test_user", } - self.get_success(_make_callback_with_userinfo(self.hs, userinfo)) - auth_handler.complete_sso_login.assert_called_once_with( + request, _ = self.start_authorization(userinfo) + self.get_success(self.handler.handle_oidc_callback(request)) + self.complete_sso_login.assert_called_once_with( user.to_string(), - "oidc", - ANY, + self.provider.idp_id, + request, ANY, None, new_user=False, auth_provider_session_id=None, ) - auth_handler.complete_sso_login.reset_mock() + self.reset_mocks() # Subsequent calls should map to the same mxid. - self.get_success(_make_callback_with_userinfo(self.hs, userinfo)) - auth_handler.complete_sso_login.assert_called_once_with( + request, _ = self.start_authorization(userinfo) + self.get_success(self.handler.handle_oidc_callback(request)) + self.complete_sso_login.assert_called_once_with( user.to_string(), - "oidc", - ANY, + self.provider.idp_id, + request, ANY, None, new_user=False, auth_provider_session_id=None, ) - auth_handler.complete_sso_login.reset_mock() + self.reset_mocks() # Note that a second SSO user can be mapped to the same Matrix ID. (This # requires a unique sub, but something that maps to the same matrix ID, @@ -927,17 +876,18 @@ def test_map_userinfo_to_existing_user(self) -> None: "sub": "test1", "username": "test_user", } - self.get_success(_make_callback_with_userinfo(self.hs, userinfo)) - auth_handler.complete_sso_login.assert_called_once_with( + request, _ = self.start_authorization(userinfo) + self.get_success(self.handler.handle_oidc_callback(request)) + self.complete_sso_login.assert_called_once_with( user.to_string(), - "oidc", - ANY, + self.provider.idp_id, + request, ANY, None, new_user=False, auth_provider_session_id=None, ) - auth_handler.complete_sso_login.reset_mock() + self.reset_mocks() # Register some non-exact matching cases. user2 = UserID.from_string("@TEST_user_2:test") @@ -954,8 +904,9 @@ def test_map_userinfo_to_existing_user(self) -> None: "sub": "test2", "username": "TEST_USER_2", } - self.get_success(_make_callback_with_userinfo(self.hs, userinfo)) - auth_handler.complete_sso_login.assert_not_called() + request, _ = self.start_authorization(userinfo) + self.get_success(self.handler.handle_oidc_callback(request)) + self.complete_sso_login.assert_not_called() args = self.assertRenderedError("mapping_error") self.assertTrue( args[2].startswith( @@ -969,11 +920,12 @@ def test_map_userinfo_to_existing_user(self) -> None: store.register_user(user_id=user2.to_string(), password_hash=None) ) - self.get_success(_make_callback_with_userinfo(self.hs, userinfo)) - auth_handler.complete_sso_login.assert_called_once_with( + request, _ = self.start_authorization(userinfo) + self.get_success(self.handler.handle_oidc_callback(request)) + self.complete_sso_login.assert_called_once_with( "@TEST_USER_2:test", - "oidc", - ANY, + self.provider.idp_id, + request, ANY, None, new_user=False, @@ -983,9 +935,9 @@ def test_map_userinfo_to_existing_user(self) -> None: @override_config({"oidc_config": DEFAULT_CONFIG}) def test_map_userinfo_to_invalid_localpart(self) -> None: """If the mapping provider generates an invalid localpart it should be rejected.""" - self.get_success( - _make_callback_with_userinfo(self.hs, {"sub": "test2", "username": "föö"}) - ) + userinfo = {"sub": "test2", "username": "föö"} + request, _ = self.start_authorization(userinfo) + self.get_success(self.handler.handle_oidc_callback(request)) self.assertRenderedError("mapping_error", "localpart is invalid: föö") @override_config( @@ -1000,9 +952,6 @@ def test_map_userinfo_to_invalid_localpart(self) -> None: ) def test_map_userinfo_to_user_retries(self) -> None: """The mapping provider can retry generating an MXID if the MXID is already in use.""" - auth_handler = self.hs.get_auth_handler() - auth_handler.complete_sso_login = simple_async_mock() - store = self.hs.get_datastores().main self.get_success( store.register_user(user_id="@test_user:test", password_hash=None) @@ -1011,19 +960,20 @@ def test_map_userinfo_to_user_retries(self) -> None: "sub": "test", "username": "test_user", } - self.get_success(_make_callback_with_userinfo(self.hs, userinfo)) + request, _ = self.start_authorization(userinfo) + self.get_success(self.handler.handle_oidc_callback(request)) # test_user is already taken, so test_user1 gets registered instead. - auth_handler.complete_sso_login.assert_called_once_with( + self.complete_sso_login.assert_called_once_with( "@test_user1:test", - "oidc", - ANY, + self.provider.idp_id, + request, ANY, None, new_user=True, auth_provider_session_id=None, ) - auth_handler.complete_sso_login.reset_mock() + self.reset_mocks() # Register all of the potential mxids for a particular OIDC username. self.get_success( @@ -1039,8 +989,9 @@ def test_map_userinfo_to_user_retries(self) -> None: "sub": "tester", "username": "tester", } - self.get_success(_make_callback_with_userinfo(self.hs, userinfo)) - auth_handler.complete_sso_login.assert_not_called() + request, _ = self.start_authorization(userinfo) + self.get_success(self.handler.handle_oidc_callback(request)) + self.complete_sso_login.assert_not_called() self.assertRenderedError( "mapping_error", "Unable to generate a Matrix ID from the SSO response" ) @@ -1052,7 +1003,8 @@ def test_empty_localpart(self) -> None: "sub": "tester", "username": "", } - self.get_success(_make_callback_with_userinfo(self.hs, userinfo)) + request, _ = self.start_authorization(userinfo) + self.get_success(self.handler.handle_oidc_callback(request)) self.assertRenderedError("mapping_error", "localpart is invalid: ") @override_config( @@ -1071,7 +1023,8 @@ def test_null_localpart(self) -> None: "sub": "tester", "username": None, } - self.get_success(_make_callback_with_userinfo(self.hs, userinfo)) + request, _ = self.start_authorization(userinfo) + self.get_success(self.handler.handle_oidc_callback(request)) self.assertRenderedError("mapping_error", "localpart is invalid: ") @override_config( @@ -1084,16 +1037,14 @@ def test_null_localpart(self) -> None: ) def test_attribute_requirements(self) -> None: """The required attributes must be met from the OIDC userinfo response.""" - auth_handler = self.hs.get_auth_handler() - auth_handler.complete_sso_login = simple_async_mock() - # userinfo lacking "test": "foobar" attribute should fail. userinfo = { "sub": "tester", "username": "tester", } - self.get_success(_make_callback_with_userinfo(self.hs, userinfo)) - auth_handler.complete_sso_login.assert_not_called() + request, _ = self.start_authorization(userinfo) + self.get_success(self.handler.handle_oidc_callback(request)) + self.complete_sso_login.assert_not_called() # userinfo with "test": "foobar" attribute should succeed. userinfo = { @@ -1101,13 +1052,14 @@ def test_attribute_requirements(self) -> None: "username": "tester", "test": "foobar", } - self.get_success(_make_callback_with_userinfo(self.hs, userinfo)) + request, _ = self.start_authorization(userinfo) + self.get_success(self.handler.handle_oidc_callback(request)) # check that the auth handler got called as expected - auth_handler.complete_sso_login.assert_called_once_with( + self.complete_sso_login.assert_called_once_with( "@tester:test", - "oidc", - ANY, + self.provider.idp_id, + request, ANY, None, new_user=True, @@ -1124,21 +1076,20 @@ def test_attribute_requirements(self) -> None: ) def test_attribute_requirements_contains(self) -> None: """Test that auth succeeds if userinfo attribute CONTAINS required value""" - auth_handler = self.hs.get_auth_handler() - auth_handler.complete_sso_login = simple_async_mock() # userinfo with "test": ["foobar", "foo", "bar"] attribute should succeed. userinfo = { "sub": "tester", "username": "tester", "test": ["foobar", "foo", "bar"], } - self.get_success(_make_callback_with_userinfo(self.hs, userinfo)) + request, _ = self.start_authorization(userinfo) + self.get_success(self.handler.handle_oidc_callback(request)) # check that the auth handler got called as expected - auth_handler.complete_sso_login.assert_called_once_with( + self.complete_sso_login.assert_called_once_with( "@tester:test", - "oidc", - ANY, + self.provider.idp_id, + request, ANY, None, new_user=True, @@ -1158,16 +1109,15 @@ def test_attribute_requirements_mismatch(self) -> None: Test that auth fails if attributes exist but don't match, or are non-string values. """ - auth_handler = self.hs.get_auth_handler() - auth_handler.complete_sso_login = simple_async_mock() # userinfo with "test": "not_foobar" attribute should fail userinfo: dict = { "sub": "tester", "username": "tester", "test": "not_foobar", } - self.get_success(_make_callback_with_userinfo(self.hs, userinfo)) - auth_handler.complete_sso_login.assert_not_called() + request, _ = self.start_authorization(userinfo) + self.get_success(self.handler.handle_oidc_callback(request)) + self.complete_sso_login.assert_not_called() # userinfo with "test": ["foo", "bar"] attribute should fail userinfo = { @@ -1175,8 +1125,9 @@ def test_attribute_requirements_mismatch(self) -> None: "username": "tester", "test": ["foo", "bar"], } - self.get_success(_make_callback_with_userinfo(self.hs, userinfo)) - auth_handler.complete_sso_login.assert_not_called() + request, _ = self.start_authorization(userinfo) + self.get_success(self.handler.handle_oidc_callback(request)) + self.complete_sso_login.assert_not_called() # userinfo with "test": False attribute should fail # this is largely just to ensure we don't crash here @@ -1185,8 +1136,9 @@ def test_attribute_requirements_mismatch(self) -> None: "username": "tester", "test": False, } - self.get_success(_make_callback_with_userinfo(self.hs, userinfo)) - auth_handler.complete_sso_login.assert_not_called() + request, _ = self.start_authorization(userinfo) + self.get_success(self.handler.handle_oidc_callback(request)) + self.complete_sso_login.assert_not_called() # userinfo with "test": None attribute should fail # a value of None breaks the OIDC spec, but it's important to not crash here @@ -1195,8 +1147,9 @@ def test_attribute_requirements_mismatch(self) -> None: "username": "tester", "test": None, } - self.get_success(_make_callback_with_userinfo(self.hs, userinfo)) - auth_handler.complete_sso_login.assert_not_called() + request, _ = self.start_authorization(userinfo) + self.get_success(self.handler.handle_oidc_callback(request)) + self.complete_sso_login.assert_not_called() # userinfo with "test": 1 attribute should fail # this is largely just to ensure we don't crash here @@ -1205,8 +1158,9 @@ def test_attribute_requirements_mismatch(self) -> None: "username": "tester", "test": 1, } - self.get_success(_make_callback_with_userinfo(self.hs, userinfo)) - auth_handler.complete_sso_login.assert_not_called() + request, _ = self.start_authorization(userinfo) + self.get_success(self.handler.handle_oidc_callback(request)) + self.complete_sso_login.assert_not_called() # userinfo with "test": 3.14 attribute should fail # this is largely just to ensure we don't crash here @@ -1215,8 +1169,9 @@ def test_attribute_requirements_mismatch(self) -> None: "username": "tester", "test": 3.14, } - self.get_success(_make_callback_with_userinfo(self.hs, userinfo)) - auth_handler.complete_sso_login.assert_not_called() + request, _ = self.start_authorization(userinfo) + self.get_success(self.handler.handle_oidc_callback(request)) + self.complete_sso_login.assert_not_called() def _generate_oidc_session_token( self, @@ -1230,7 +1185,7 @@ def _generate_oidc_session_token( return self.handler._macaroon_generator.generate_oidc_session_token( state=state, session_data=OidcSessionData( - idp_id="oidc", + idp_id=self.provider.idp_id, nonce=nonce, client_redirect_url=client_redirect_url, ui_auth_session_id=ui_auth_session_id, @@ -1238,41 +1193,6 @@ def _generate_oidc_session_token( ) -async def _make_callback_with_userinfo( - hs: HomeServer, userinfo: dict, client_redirect_url: str = "http://client/redirect" -) -> None: - """Mock up an OIDC callback with the given userinfo dict - - We'll pull out the OIDC handler from the homeserver, stub out a couple of methods, - and poke in the userinfo dict as if it were the response to an OIDC userinfo call. - - Args: - hs: the HomeServer impl to send the callback to. - userinfo: the OIDC userinfo dict - client_redirect_url: the URL to redirect to on success. - """ - - handler = hs.get_oidc_handler() - provider = handler._providers["oidc"] - provider._exchange_code = simple_async_mock(return_value={"id_token": ""}) # type: ignore[assignment] - provider._parse_id_token = simple_async_mock(return_value=userinfo) # type: ignore[assignment] - provider._fetch_userinfo = simple_async_mock(return_value=userinfo) # type: ignore[assignment] - - state = "state" - session = handler._macaroon_generator.generate_oidc_session_token( - state=state, - session_data=OidcSessionData( - idp_id="oidc", - nonce="nonce", - client_redirect_url=client_redirect_url, - ui_auth_session_id="", - ), - ) - request = _build_callback_request("code", state, session) - - await handler.handle_oidc_callback(request) - - def _build_callback_request( code: str, state: str, diff --git a/tests/rest/client/test_auth.py b/tests/rest/client/test_auth.py index 05355c7fb6d8..bf28dac2b23e 100644 --- a/tests/rest/client/test_auth.py +++ b/tests/rest/client/test_auth.py @@ -464,9 +464,11 @@ def test_ui_auth_via_sso(self) -> None: * checking that the original operation succeeds """ + fake_oidc_provider = self.helper.fake_oidc_provider() + # log the user in remote_user_id = UserID.from_string(self.user).localpart - login_resp = self.helper.login_via_oidc(remote_user_id) + login_resp, _ = self.helper.login_via_oidc(fake_oidc_provider, remote_user_id) self.assertEqual(login_resp["user_id"], self.user) # initiate a UI Auth process by attempting to delete the device @@ -480,8 +482,8 @@ def test_ui_auth_via_sso(self) -> None: # run the UIA-via-SSO flow session_id = channel.json_body["session"] - channel = self.helper.auth_via_oidc( - {"sub": remote_user_id}, ui_auth_session_id=session_id + channel, _ = self.helper.auth_via_oidc( + fake_oidc_provider, {"sub": remote_user_id}, ui_auth_session_id=session_id ) # that should serve a confirmation page @@ -498,7 +500,8 @@ def test_ui_auth_via_sso(self) -> None: @skip_unless(HAS_OIDC, "requires OIDC") @override_config({"oidc_config": TEST_OIDC_CONFIG}) def test_does_not_offer_password_for_sso_user(self) -> None: - login_resp = self.helper.login_via_oidc("username") + fake_oidc_provider = self.helper.fake_oidc_provider() + login_resp, _ = self.helper.login_via_oidc(fake_oidc_provider, "username") user_tok = login_resp["access_token"] device_id = login_resp["device_id"] @@ -521,7 +524,10 @@ def test_does_not_offer_sso_for_password_user(self) -> None: @override_config({"oidc_config": TEST_OIDC_CONFIG}) def test_offers_both_flows_for_upgraded_user(self) -> None: """A user that had a password and then logged in with SSO should get both flows""" - login_resp = self.helper.login_via_oidc(UserID.from_string(self.user).localpart) + fake_oidc_provider = self.helper.fake_oidc_provider() + login_resp, _ = self.helper.login_via_oidc( + fake_oidc_provider, UserID.from_string(self.user).localpart + ) self.assertEqual(login_resp["user_id"], self.user) channel = self.delete_device( @@ -538,8 +544,13 @@ def test_offers_both_flows_for_upgraded_user(self) -> None: @override_config({"oidc_config": TEST_OIDC_CONFIG}) def test_ui_auth_fails_for_incorrect_sso_user(self) -> None: """If the user tries to authenticate with the wrong SSO user, they get an error""" + + fake_oidc_provider = self.helper.fake_oidc_provider() + # log the user in - login_resp = self.helper.login_via_oidc(UserID.from_string(self.user).localpart) + login_resp, _ = self.helper.login_via_oidc( + fake_oidc_provider, UserID.from_string(self.user).localpart + ) self.assertEqual(login_resp["user_id"], self.user) # start a UI Auth flow by attempting to delete a device @@ -552,8 +563,8 @@ def test_ui_auth_fails_for_incorrect_sso_user(self) -> None: session_id = channel.json_body["session"] # do the OIDC auth, but auth as the wrong user - channel = self.helper.auth_via_oidc( - {"sub": "wrong_user"}, ui_auth_session_id=session_id + channel, _ = self.helper.auth_via_oidc( + fake_oidc_provider, {"sub": "wrong_user"}, ui_auth_session_id=session_id ) # that should return a failure message diff --git a/tests/rest/client/test_login.py b/tests/rest/client/test_login.py index e2a4d982755a..9c8ea6656edc 100644 --- a/tests/rest/client/test_login.py +++ b/tests/rest/client/test_login.py @@ -34,7 +34,7 @@ from tests import unittest from tests.handlers.test_oidc import HAS_OIDC from tests.handlers.test_saml import has_saml2 -from tests.rest.client.utils import TEST_OIDC_AUTH_ENDPOINT, TEST_OIDC_CONFIG +from tests.rest.client.utils import TEST_OIDC_CONFIG from tests.server import FakeChannel from tests.test_utils.html_parsers import TestHtmlParser from tests.unittest import HomeserverTestCase, override_config, skip_unless @@ -571,13 +571,16 @@ def test_multi_sso_redirect_to_saml(self) -> None: def test_login_via_oidc(self) -> None: """If OIDC is chosen, should redirect to the OIDC auth endpoint""" - # pick the default OIDC provider - channel = self.make_request( - "GET", - "/_synapse/client/pick_idp?redirectUrl=" - + urllib.parse.quote_plus(TEST_CLIENT_REDIRECT_URL) - + "&idp=oidc", - ) + fake_oidc_provider = self.helper.fake_oidc_provider() + + with fake_oidc_provider.patch_homeserver(hs=self.hs): + # pick the default OIDC provider + channel = self.make_request( + "GET", + "/_synapse/client/pick_idp?redirectUrl=" + + urllib.parse.quote_plus(TEST_CLIENT_REDIRECT_URL) + + "&idp=oidc", + ) self.assertEqual(channel.code, 302, channel.result) location_headers = channel.headers.getRawHeaders("Location") assert location_headers @@ -585,7 +588,7 @@ def test_login_via_oidc(self) -> None: oidc_uri_path, oidc_uri_query = oidc_uri.split("?", 1) # it should redirect us to the auth page of the OIDC server - self.assertEqual(oidc_uri_path, TEST_OIDC_AUTH_ENDPOINT) + self.assertEqual(oidc_uri_path, fake_oidc_provider.authorization_endpoint) # ... and should have set a cookie including the redirect url cookie_headers = channel.headers.getRawHeaders("Set-Cookie") @@ -602,7 +605,9 @@ def test_login_via_oidc(self) -> None: TEST_CLIENT_REDIRECT_URL, ) - channel = self.helper.complete_oidc_auth(oidc_uri, cookies, {"sub": "user1"}) + channel, _ = self.helper.complete_oidc_auth( + fake_oidc_provider, oidc_uri, cookies, {"sub": "user1"} + ) # that should serve a confirmation page self.assertEqual(channel.code, 200, channel.result) @@ -652,7 +657,10 @@ def test_client_idp_redirect_to_unknown(self) -> None: def test_client_idp_redirect_to_oidc(self) -> None: """If the client pick a known IdP, redirect to it""" - channel = self._make_sso_redirect_request("oidc") + fake_oidc_provider = self.helper.fake_oidc_provider() + + with fake_oidc_provider.patch_homeserver(hs=self.hs): + channel = self._make_sso_redirect_request("oidc") self.assertEqual(channel.code, 302, channel.result) location_headers = channel.headers.getRawHeaders("Location") assert location_headers @@ -660,7 +668,7 @@ def test_client_idp_redirect_to_oidc(self) -> None: oidc_uri_path, oidc_uri_query = oidc_uri.split("?", 1) # it should redirect us to the auth page of the OIDC server - self.assertEqual(oidc_uri_path, TEST_OIDC_AUTH_ENDPOINT) + self.assertEqual(oidc_uri_path, fake_oidc_provider.authorization_endpoint) def _make_sso_redirect_request(self, idp_prov: Optional[str] = None) -> FakeChannel: """Send a request to /_matrix/client/r0/login/sso/redirect @@ -1239,9 +1247,13 @@ def create_resource_dict(self) -> Dict[str, Resource]: def test_username_picker(self) -> None: """Test the happy path of a username picker flow.""" + fake_oidc_provider = self.helper.fake_oidc_provider() + # do the start of the login flow - channel = self.helper.auth_via_oidc( - {"sub": "tester", "displayname": "Jonny"}, TEST_CLIENT_REDIRECT_URL + channel, _ = self.helper.auth_via_oidc( + fake_oidc_provider, + {"sub": "tester", "displayname": "Jonny"}, + TEST_CLIENT_REDIRECT_URL, ) # that should redirect to the username picker diff --git a/tests/rest/client/utils.py b/tests/rest/client/utils.py index dd26145bf8c1..69bb41a62538 100644 --- a/tests/rest/client/utils.py +++ b/tests/rest/client/utils.py @@ -31,7 +31,6 @@ Tuple, overload, ) -from unittest.mock import patch from urllib.parse import urlencode import attr @@ -46,8 +45,19 @@ from synapse.types import JsonDict from tests.server import FakeChannel, FakeSite, make_request -from tests.test_utils import FakeResponse from tests.test_utils.html_parsers import TestHtmlParser +from tests.test_utils.oidc import FakeAuthorizationGrant, FakeOidcProvider + +# an 'oidc_config' suitable for login_via_oidc. +TEST_OIDC_ISSUER = "https://issuer.test/" +TEST_OIDC_CONFIG = { + "enabled": True, + "issuer": TEST_OIDC_ISSUER, + "client_id": "test-client-id", + "client_secret": "test-client-secret", + "scopes": ["openid"], + "user_mapping_provider": {"config": {"localpart_template": "{{ user.sub }}"}}, +} @attr.s(auto_attribs=True) @@ -543,7 +553,15 @@ def upload_media( return channel.json_body - def login_via_oidc(self, remote_user_id: str) -> JsonDict: + def fake_oidc_provider(self, issuer: str = TEST_OIDC_ISSUER) -> FakeOidcProvider: + return FakeOidcProvider( + clock=self.hs.get_clock(), + issuer=issuer, + ) + + def login_via_oidc( + self, provider: FakeOidcProvider, remote_user_id: str, with_sid: bool = False + ) -> Tuple[JsonDict, FakeAuthorizationGrant]: """Log in (as a new user) via OIDC Returns the result of the final token login. @@ -556,7 +574,10 @@ def login_via_oidc(self, remote_user_id: str) -> JsonDict: the normal places. """ client_redirect_url = "https://x" - channel = self.auth_via_oidc({"sub": remote_user_id}, client_redirect_url) + userinfo = {"sub": remote_user_id} + channel, grant = self.auth_via_oidc( + provider, userinfo, client_redirect_url, with_sid=with_sid + ) # expect a confirmation page assert channel.code == HTTPStatus.OK, channel.result @@ -579,14 +600,16 @@ def login_via_oidc(self, remote_user_id: str) -> JsonDict: content={"type": "m.login.token", "token": login_token}, ) assert channel.code == HTTPStatus.OK - return channel.json_body + return channel.json_body, grant def auth_via_oidc( self, + provider: FakeOidcProvider, user_info_dict: JsonDict, client_redirect_url: Optional[str] = None, ui_auth_session_id: Optional[str] = None, - ) -> FakeChannel: + with_sid: bool = False, + ) -> Tuple[FakeChannel, FakeAuthorizationGrant]: """Perform an OIDC authentication flow via a mock OIDC provider. This can be used for either login or user-interactive auth. @@ -610,6 +633,7 @@ def auth_via_oidc( the login redirect endpoint ui_auth_session_id: if set, we will perform a UI Auth flow. The session id of the UI auth. + with_sid: if True, generates a random `sid` (OIDC session ID) Returns: A FakeChannel containing the result of calling the OIDC callback endpoint. @@ -619,14 +643,15 @@ def auth_via_oidc( cookies: Dict[str, str] = {} - # if we're doing a ui auth, hit the ui auth redirect endpoint - if ui_auth_session_id: - # can't set the client redirect url for UI Auth - assert client_redirect_url is None - oauth_uri = self.initiate_sso_ui_auth(ui_auth_session_id, cookies) - else: - # otherwise, hit the login redirect endpoint - oauth_uri = self.initiate_sso_login(client_redirect_url, cookies) + with provider.patch_homeserver(hs=self.hs): + # if we're doing a ui auth, hit the ui auth redirect endpoint + if ui_auth_session_id: + # can't set the client redirect url for UI Auth + assert client_redirect_url is None + oauth_uri = self.initiate_sso_ui_auth(ui_auth_session_id, cookies) + else: + # otherwise, hit the login redirect endpoint + oauth_uri = self.initiate_sso_login(client_redirect_url, cookies) # we now have a URI for the OIDC IdP, but we skip that and go straight # back to synapse's OIDC callback resource. However, we do need the "state" @@ -634,17 +659,21 @@ def auth_via_oidc( # that synapse passes to the client. oauth_uri_path, _ = oauth_uri.split("?", 1) - assert oauth_uri_path == TEST_OIDC_AUTH_ENDPOINT, ( + assert oauth_uri_path == provider.authorization_endpoint, ( "unexpected SSO URI " + oauth_uri_path ) - return self.complete_oidc_auth(oauth_uri, cookies, user_info_dict) + return self.complete_oidc_auth( + provider, oauth_uri, cookies, user_info_dict, with_sid=with_sid + ) def complete_oidc_auth( self, + provider: FakeOidcProvider, oauth_uri: str, cookies: Mapping[str, str], user_info_dict: JsonDict, - ) -> FakeChannel: + with_sid: bool = False, + ) -> Tuple[FakeChannel, FakeAuthorizationGrant]: """Mock out an OIDC authentication flow Assumes that an OIDC auth has been initiated by one of initiate_sso_login or @@ -661,44 +690,30 @@ def complete_oidc_auth( sent back to the callback endpoint. user_info_dict: the remote userinfo that the OIDC provider should present. Typically this should be '{"sub": ""}'. + with_sid: if True, generates a random `sid` (OIDC session ID) Returns: A FakeChannel containing the result of calling the OIDC callback endpoint. """ _, oauth_uri_qs = oauth_uri.split("?", 1) params = urllib.parse.parse_qs(oauth_uri_qs) + + code, grant = provider.start_authorization( + scope=params["scope"][0], + userinfo=user_info_dict, + client_id=params["client_id"][0], + redirect_uri=params["redirect_uri"][0], + nonce=params["nonce"][0], + with_sid=with_sid, + ) + state = params["state"][0] + callback_uri = "%s?%s" % ( urllib.parse.urlparse(params["redirect_uri"][0]).path, - urllib.parse.urlencode({"state": params["state"][0], "code": "TEST_CODE"}), + urllib.parse.urlencode({"state": state, "code": code}), ) - # before we hit the callback uri, stub out some methods in the http client so - # that we don't have to handle full HTTPS requests. - # (expected url, json response) pairs, in the order we expect them. - expected_requests = [ - # first we get a hit to the token endpoint, which we tell to return - # a dummy OIDC access token - (TEST_OIDC_TOKEN_ENDPOINT, {"access_token": "TEST"}), - # and then one to the user_info endpoint, which returns our remote user id. - (TEST_OIDC_USERINFO_ENDPOINT, user_info_dict), - ] - - async def mock_req( - method: str, - uri: str, - data: Optional[dict] = None, - headers: Optional[Iterable[Tuple[AnyStr, AnyStr]]] = None, - ): - (expected_uri, resp_obj) = expected_requests.pop(0) - assert uri == expected_uri - resp = FakeResponse( - code=HTTPStatus.OK, - phrase=b"OK", - body=json.dumps(resp_obj).encode("utf-8"), - ) - return resp - - with patch.object(self.hs.get_proxied_http_client(), "request", mock_req): + with provider.patch_homeserver(hs=self.hs): # now hit the callback URI with the right params and a made-up code channel = make_request( self.hs.get_reactor(), @@ -709,7 +724,7 @@ async def mock_req( ("Cookie", "%s=%s" % (k, v)) for (k, v) in cookies.items() ], ) - return channel + return channel, grant def initiate_sso_login( self, client_redirect_url: Optional[str], cookies: MutableMapping[str, str] @@ -800,21 +815,3 @@ def initiate_sso_ui_auth( assert len(p.links) == 1, "not exactly one link in confirmation page" oauth_uri = p.links[0] return oauth_uri - - -# an 'oidc_config' suitable for login_via_oidc. -TEST_OIDC_AUTH_ENDPOINT = "https://issuer.test/auth" -TEST_OIDC_TOKEN_ENDPOINT = "https://issuer.test/token" -TEST_OIDC_USERINFO_ENDPOINT = "https://issuer.test/userinfo" -TEST_OIDC_CONFIG = { - "enabled": True, - "discover": False, - "issuer": "https://issuer.test", - "client_id": "test-client-id", - "client_secret": "test-client-secret", - "scopes": ["profile"], - "authorization_endpoint": TEST_OIDC_AUTH_ENDPOINT, - "token_endpoint": TEST_OIDC_TOKEN_ENDPOINT, - "userinfo_endpoint": TEST_OIDC_USERINFO_ENDPOINT, - "user_mapping_provider": {"config": {"localpart_template": "{{ user.sub }}"}}, -} diff --git a/tests/test_utils/__init__.py b/tests/test_utils/__init__.py index 0d0d6faf0d3a..c47b63fff45d 100644 --- a/tests/test_utils/__init__.py +++ b/tests/test_utils/__init__.py @@ -15,17 +15,24 @@ """ Utilities for running the unit tests """ +import json import sys import warnings from asyncio import Future from binascii import unhexlify -from typing import Awaitable, Callable, TypeVar +from typing import Awaitable, Callable, Tuple, TypeVar from unittest.mock import Mock import attr +import zope.interface from twisted.python.failure import Failure from twisted.web.client import ResponseDone +from twisted.web.http import RESPONSES +from twisted.web.http_headers import Headers +from twisted.web.iweb import IResponse + +from synapse.types import JsonDict TV = TypeVar("TV") @@ -97,27 +104,44 @@ async def cb(*args, **kwargs): return Mock(side_effect=cb) -@attr.s -class FakeResponse: +# Type ignore: it does not fully implement IResponse, but is good enough for tests +@zope.interface.implementer(IResponse) +@attr.s(slots=True, frozen=True, auto_attribs=True) +class FakeResponse: # type: ignore[misc] """A fake twisted.web.IResponse object there is a similar class at treq.test.test_response, but it lacks a `phrase` attribute, and didn't support deliverBody until recently. """ - # HTTP response code - code = attr.ib(type=int) + verison: Tuple[bytes, int, int] = (b"HTTP", 1, 1) - # HTTP response phrase (eg b'OK' for a 200) - phrase = attr.ib(type=bytes) + # HTTP response code + code: int = 200 # body of the response - body = attr.ib(type=bytes) + body: bytes = b"" + + headers: Headers = Headers() + + @property + def phrase(self): + return RESPONSES.get(self.code, b"Unknown Status") + + @property + def length(self): + return len(self.body) def deliverBody(self, protocol): protocol.dataReceived(self.body) protocol.connectionLost(Failure(ResponseDone())) + @classmethod + def json(cls, *, code: int = 200, payload: JsonDict) -> "FakeResponse": + headers = Headers({"Content-Type": ["application/json"]}) + body = json.dumps(payload).encode("utf-8") + return cls(code=code, body=body, headers=headers) + # A small image used in some tests. # diff --git a/tests/test_utils/oidc.py b/tests/test_utils/oidc.py new file mode 100644 index 000000000000..d71b6150e0db --- /dev/null +++ b/tests/test_utils/oidc.py @@ -0,0 +1,313 @@ +# Copyright 2022 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import json +from typing import Any, Dict, List, Optional, Tuple +from unittest.mock import Mock, patch +from urllib.parse import parse_qs + +import attr + +from twisted.web.http_headers import Headers +from twisted.web.iweb import IResponse + +from synapse.server import HomeServer +from synapse.util import Clock +from synapse.util.stringutils import random_string + +from tests.test_utils import FakeResponse + + +@attr.s(slots=True, frozen=True, auto_attribs=True) +class FakeAuthorizationGrant: + userinfo: dict + client_id: str + redirect_uri: str + scope: str + nonce: Optional[str] + sid: Optional[str] + + +class FakeOidcProvider: + """A fake OpenID Connect Provider.""" + + # All methods here are mocks, so we can track when they are called, and override + # their values + request: Mock + get_jwks_handler: Mock + get_metadata_handler: Mock + get_userinfo_handler: Mock + post_token_handler: Mock + + def __init__(self, clock: Clock, issuer: str): + from authlib.jose import ECKey, KeySet + + self.clock = clock + self.issuer = issuer + + self.request = Mock(side_effect=self._request) + self.get_jwks_handler = Mock(side_effect=self._get_jwks_handler) + self.get_metadata_handler = Mock(side_effect=self._get_metadata_handler) + self.get_userinfo_handler = Mock(side_effect=self._get_userinfo_handler) + self.post_token_handler = Mock(side_effect=self._post_token_handler) + + # A code -> grant mapping + self.authorization_grants: Dict[str, FakeAuthorizationGrant] = {} + # An access token -> grant mapping + self.sessions: Dict[str, FakeAuthorizationGrant] = {} + + # We generate here an ECDSA key with the P-256 curve (ES256 algorithm) used for + # signing JWTs. ECDSA keys are really quick to generate compared to RSA. + self.key = ECKey.generate_key(crv="P-256", is_private=True) + self.jwks = KeySet([ECKey.import_key(self.key.raw_key.public_key())]) + + self._id_token_overrides: Dict[str, Any] = {} + + def reset_mocks(self): + self.request.reset_mock() + self.get_jwks_handler.reset_mock() + self.get_metadata_handler.reset_mock() + self.get_userinfo_handler.reset_mock() + self.post_token_handler.reset_mock() + + def patch_homeserver(self, hs: HomeServer): + return patch.object(hs.get_proxied_http_client(), "request", self.request) + + @property + def authorization_endpoint(self) -> str: + return self.issuer + "authorize" + + @property + def token_endpoint(self) -> str: + return self.issuer + "token" + + @property + def userinfo_endpoint(self) -> str: + return self.issuer + "userinfo" + + @property + def metadata_endpoint(self) -> str: + return self.issuer + ".well-known/openid-configuration" + + @property + def jwks_uri(self) -> str: + return self.issuer + "jwks" + + def get_metadata(self) -> dict: + return { + "issuer": self.issuer, + "authorization_endpoint": self.authorization_endpoint, + "token_endpoint": self.token_endpoint, + "jwks_uri": self.jwks_uri, + "userinfo_endpoint": self.userinfo_endpoint, + "response_types_supported": ["code"], + "subject_types_supported": ["public"], + "id_token_signing_alg_values_supported": ["ES256"], + } + + def get_jwks(self) -> dict: + return self.jwks.as_dict() + + def get_userinfo(self, access_token: str) -> Optional[dict]: + """Given an access token, get the userinfo of the associated session.""" + session = self.sessions.get(access_token, None) + if session is None: + return None + return session.userinfo + + def _sign(self, payload: dict) -> str: + from authlib.jose import JsonWebSignature + + jws = JsonWebSignature() + kid = self.get_jwks()["keys"][0]["kid"] + protected = {"alg": "ES256", "kid": kid} + json_payload = json.dumps(payload) + return jws.serialize_compact(protected, json_payload, self.key).decode("utf-8") + + def generate_id_token(self, grant: FakeAuthorizationGrant) -> str: + now = self.clock.time() + id_token = { + **grant.userinfo, + "iss": self.issuer, + "aud": grant.client_id, + "iat": now, + "nbf": now, + "exp": now + 600, + } + + if grant.nonce is not None: + id_token["nonce"] = grant.nonce + + if grant.sid is not None: + id_token["sid"] = grant.sid + + id_token.update(self._id_token_overrides) + + return self._sign(id_token) + + def id_token_override(self, overrides: dict): + """Temporarily patch the ID token generated by the token endpoint.""" + return patch.object(self, "_id_token_overrides", overrides) + + def start_authorization( + self, + client_id: str, + scope: str, + redirect_uri: str, + userinfo: dict, + nonce: Optional[str] = None, + with_sid: bool = False, + ) -> Tuple[str, FakeAuthorizationGrant]: + """Start an authorization request, and get back the code to use on the authorization endpoint.""" + code = random_string(10) + sid = None + if with_sid: + sid = random_string(10) + + grant = FakeAuthorizationGrant( + userinfo=userinfo, + scope=scope, + redirect_uri=redirect_uri, + nonce=nonce, + client_id=client_id, + sid=sid, + ) + self.authorization_grants[code] = grant + + return code, grant + + def exchange_code(self, code: str) -> Optional[Dict[str, Any]]: + grant = self.authorization_grants.pop(code, None) + if grant is None: + return None + + access_token = random_string(10) + self.sessions[access_token] = grant + + token = { + "token_type": "Bearer", + "access_token": access_token, + "expires_in": 3600, + "scope": grant.scope, + } + + if "openid" in grant.scope: + token["id_token"] = self.generate_id_token(grant) + + return dict(token) + + def buggy_endpoint( + self, + *, + jwks: bool = False, + metadata: bool = False, + token: bool = False, + userinfo: bool = False, + ): + """A context which makes a set of endpoints return a 500 error. + + Args: + jwks: If True, makes the JWKS endpoint return a 500 error. + metadata: If True, makes the OIDC Discovery endpoint return a 500 error. + token: If True, makes the token endpoint return a 500 error. + userinfo: If True, makes the userinfo endpoint return a 500 error. + """ + buggy = FakeResponse(code=500, body=b"Internal server error") + + patches = {} + if jwks: + patches["get_jwks_handler"] = Mock(return_value=buggy) + if metadata: + patches["get_metadata_handler"] = Mock(return_value=buggy) + if token: + patches["post_token_handler"] = Mock(return_value=buggy) + if userinfo: + patches["get_userinfo_handler"] = Mock(return_value=buggy) + + return patch.multiple(self, **patches) + + async def _request( + self, + method: str, + uri: str, + data: Optional[bytes] = None, + headers: Optional[Headers] = None, + ) -> IResponse: + """The override of the SimpleHttpClient#request() method""" + access_token: Optional[str] = None + + if headers is None: + headers = Headers() + + # Try to find the access token in the headers if any + auth_headers = headers.getRawHeaders(b"Authorization") + if auth_headers: + parts = auth_headers[0].split(b" ") + if parts[0] == b"Bearer" and len(parts) == 2: + access_token = parts[1].decode("ascii") + + if method == "POST": + # If the method is POST, assume it has an url-encoded body + if data is None or headers.getRawHeaders(b"Content-Type") != [ + b"application/x-www-form-urlencoded" + ]: + return FakeResponse.json(code=400, payload={"error": "invalid_request"}) + + params = parse_qs(data.decode("utf-8")) + + if uri == self.token_endpoint: + return self.post_token_handler(params) + + elif method == "GET": + if uri == self.jwks_uri: + return self.get_jwks_handler() + elif uri == self.metadata_endpoint: + return self.get_metadata_handler() + elif uri == self.userinfo_endpoint: + return self.get_userinfo_handler(access_token=access_token) + + return FakeResponse(code=404, body=b"404 not found") + + # Request handlers + def _get_jwks_handler(self) -> IResponse: + """Handles requests to the JWKS URI.""" + return FakeResponse.json(payload=self.get_jwks()) + + def _get_metadata_handler(self) -> IResponse: + """Handles requests to the OIDC well-known document.""" + return FakeResponse.json(payload=self.get_metadata()) + + def _get_userinfo_handler(self, access_token: Optional[str]) -> IResponse: + """Handles requests to the userinfo endpoint.""" + if access_token is None: + return FakeResponse(code=401) + user_info = self.get_userinfo(access_token) + if user_info is None: + return FakeResponse(code=401) + + return FakeResponse.json(payload=user_info) + + def _post_token_handler(self, params: Dict[str, List[str]]) -> IResponse: + """Handles requests to the token endpoint.""" + code = params.get("code", []) + + if len(code) != 1: + return FakeResponse.json(code=400, payload={"error": "invalid_request"}) + + grant = self.exchange_code(code=code[0]) + if grant is None: + return FakeResponse.json(code=400, payload={"error": "invalid_grant"}) + + return FakeResponse.json(payload=grant) From 99a3929ab83533b8a8ed8eaac58040a39ba0a43c Mon Sep 17 00:00:00 2001 From: Quentin Gliech Date: Mon, 26 Sep 2022 18:22:50 +0200 Subject: [PATCH 2/8] Newsfile. --- changelog.d/13910.misc | 1 + 1 file changed, 1 insertion(+) create mode 100644 changelog.d/13910.misc diff --git a/changelog.d/13910.misc b/changelog.d/13910.misc new file mode 100644 index 000000000000..e906952aabba --- /dev/null +++ b/changelog.d/13910.misc @@ -0,0 +1 @@ +Refactor OIDC tests to better mimic an actual OIDC provider. From a973c910d3d8ddbc697619260e8b021d085dc9e1 Mon Sep 17 00:00:00 2001 From: Quentin Gliech Date: Mon, 24 Oct 2022 15:25:16 +0200 Subject: [PATCH 3/8] Rename `FakeOidcProvider` -> `FakeOidcServer` + docstrings --- tests/handlers/test_oidc.py | 4 ++-- tests/rest/client/test_auth.py | 10 +++++----- tests/rest/client/test_login.py | 6 +++--- tests/rest/client/utils.py | 20 ++++++++++++++------ tests/test_utils/oidc.py | 11 ++++++++++- 5 files changed, 34 insertions(+), 17 deletions(-) diff --git a/tests/handlers/test_oidc.py b/tests/handlers/test_oidc.py index 09feb1f52b2d..e8394809efb4 100644 --- a/tests/handlers/test_oidc.py +++ b/tests/handlers/test_oidc.py @@ -29,7 +29,7 @@ from synapse.util.stringutils import random_string from tests.test_utils import FakeResponse, get_awaitable_result, simple_async_mock -from tests.test_utils.oidc import FakeAuthorizationGrant, FakeOidcProvider +from tests.test_utils.oidc import FakeAuthorizationGrant, FakeOidcServer from tests.unittest import HomeserverTestCase, override_config try: @@ -134,7 +134,7 @@ def default_config(self) -> Dict[str, Any]: return config def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: - self.fake_provider = FakeOidcProvider(clock=clock, issuer=ISSUER) + self.fake_provider = FakeOidcServer(clock=clock, issuer=ISSUER) hs = self.setup_test_homeserver() self.hs_patcher = self.fake_provider.patch_homeserver(hs=hs) diff --git a/tests/rest/client/test_auth.py b/tests/rest/client/test_auth.py index eb5cbccfd693..74a01104f672 100644 --- a/tests/rest/client/test_auth.py +++ b/tests/rest/client/test_auth.py @@ -465,7 +465,7 @@ def test_ui_auth_via_sso(self) -> None: * checking that the original operation succeeds """ - fake_oidc_provider = self.helper.fake_oidc_provider() + fake_oidc_provider = self.helper.fake_oidc_server() # log the user in remote_user_id = UserID.from_string(self.user).localpart @@ -501,7 +501,7 @@ def test_ui_auth_via_sso(self) -> None: @skip_unless(HAS_OIDC, "requires OIDC") @override_config({"oidc_config": TEST_OIDC_CONFIG}) def test_does_not_offer_password_for_sso_user(self) -> None: - fake_oidc_provider = self.helper.fake_oidc_provider() + fake_oidc_provider = self.helper.fake_oidc_server() login_resp, _ = self.helper.login_via_oidc(fake_oidc_provider, "username") user_tok = login_resp["access_token"] device_id = login_resp["device_id"] @@ -525,7 +525,7 @@ def test_does_not_offer_sso_for_password_user(self) -> None: @override_config({"oidc_config": TEST_OIDC_CONFIG}) def test_offers_both_flows_for_upgraded_user(self) -> None: """A user that had a password and then logged in with SSO should get both flows""" - fake_oidc_provider = self.helper.fake_oidc_provider() + fake_oidc_provider = self.helper.fake_oidc_server() login_resp, _ = self.helper.login_via_oidc( fake_oidc_provider, UserID.from_string(self.user).localpart ) @@ -546,7 +546,7 @@ def test_offers_both_flows_for_upgraded_user(self) -> None: def test_ui_auth_fails_for_incorrect_sso_user(self) -> None: """If the user tries to authenticate with the wrong SSO user, they get an error""" - fake_oidc_provider = self.helper.fake_oidc_provider() + fake_oidc_provider = self.helper.fake_oidc_server() # log the user in login_resp, _ = self.helper.login_via_oidc( @@ -595,7 +595,7 @@ def test_sso_not_approved(self) -> None: """Tests that if we register a user via SSO while requiring approval for new accounts, we still raise the correct error before logging the user in. """ - fake_oidc_provider = self.helper.fake_oidc_provider() + fake_oidc_provider = self.helper.fake_oidc_server() login_resp, _ = self.helper.login_via_oidc( fake_oidc_provider, "username", expected_status=403 ) diff --git a/tests/rest/client/test_login.py b/tests/rest/client/test_login.py index 2802e088cfcc..51e62b650bac 100644 --- a/tests/rest/client/test_login.py +++ b/tests/rest/client/test_login.py @@ -612,7 +612,7 @@ def test_multi_sso_redirect_to_saml(self) -> None: def test_login_via_oidc(self) -> None: """If OIDC is chosen, should redirect to the OIDC auth endpoint""" - fake_oidc_provider = self.helper.fake_oidc_provider() + fake_oidc_provider = self.helper.fake_oidc_server() with fake_oidc_provider.patch_homeserver(hs=self.hs): # pick the default OIDC provider @@ -698,7 +698,7 @@ def test_client_idp_redirect_to_unknown(self) -> None: def test_client_idp_redirect_to_oidc(self) -> None: """If the client pick a known IdP, redirect to it""" - fake_oidc_provider = self.helper.fake_oidc_provider() + fake_oidc_provider = self.helper.fake_oidc_server() with fake_oidc_provider.patch_homeserver(hs=self.hs): channel = self._make_sso_redirect_request("oidc") @@ -1288,7 +1288,7 @@ def create_resource_dict(self) -> Dict[str, Resource]: def test_username_picker(self) -> None: """Test the happy path of a username picker flow.""" - fake_oidc_provider = self.helper.fake_oidc_provider() + fake_oidc_provider = self.helper.fake_oidc_server() # do the start of the login flow channel, _ = self.helper.auth_via_oidc( diff --git a/tests/rest/client/utils.py b/tests/rest/client/utils.py index b1cfd831980d..4dfcc07b6be2 100644 --- a/tests/rest/client/utils.py +++ b/tests/rest/client/utils.py @@ -46,7 +46,7 @@ from tests.server import FakeChannel, FakeSite, make_request from tests.test_utils.html_parsers import TestHtmlParser -from tests.test_utils.oidc import FakeAuthorizationGrant, FakeOidcProvider +from tests.test_utils.oidc import FakeAuthorizationGrant, FakeOidcServer # an 'oidc_config' suitable for login_via_oidc. TEST_OIDC_ISSUER = "https://issuer.test/" @@ -553,15 +553,23 @@ def upload_media( return channel.json_body - def fake_oidc_provider(self, issuer: str = TEST_OIDC_ISSUER) -> FakeOidcProvider: - return FakeOidcProvider( + def fake_oidc_server(self, issuer: str = TEST_OIDC_ISSUER) -> FakeOidcServer: + """Create a ``FakeOidcServer``. + + This can be used in conjuction with ``login_via_oidc``:: + + fake_oidc_server = self.helper.fake_oidc_server() + login_data, _ = self.helper.login_via_oidc(fake_oidc_server, "user") + """ + + return FakeOidcServer( clock=self.hs.get_clock(), issuer=issuer, ) def login_via_oidc( self, - provider: FakeOidcProvider, + provider: FakeOidcServer, remote_user_id: str, with_sid: bool = False, expected_status: int = 200, @@ -610,7 +618,7 @@ def login_via_oidc( def auth_via_oidc( self, - provider: FakeOidcProvider, + provider: FakeOidcServer, user_info_dict: JsonDict, client_redirect_url: Optional[str] = None, ui_auth_session_id: Optional[str] = None, @@ -674,7 +682,7 @@ def auth_via_oidc( def complete_oidc_auth( self, - provider: FakeOidcProvider, + provider: FakeOidcServer, oauth_uri: str, cookies: Mapping[str, str], user_info_dict: JsonDict, diff --git a/tests/test_utils/oidc.py b/tests/test_utils/oidc.py index d71b6150e0db..f0b879bbe924 100644 --- a/tests/test_utils/oidc.py +++ b/tests/test_utils/oidc.py @@ -40,7 +40,7 @@ class FakeAuthorizationGrant: sid: Optional[str] -class FakeOidcProvider: +class FakeOidcServer: """A fake OpenID Connect Provider.""" # All methods here are mocks, so we can track when they are called, and override @@ -83,6 +83,15 @@ def reset_mocks(self): self.post_token_handler.reset_mock() def patch_homeserver(self, hs: HomeServer): + """Patch the ``HomeServer`` HTTP client to handle requests through the ``FakeOidcServer``. + + This patch should be used whenever the HS is expected to perform request to the + OIDC provider, e.g.:: + + fake_oidc_server = self.helper.fake_oidc_server() + with fake_oidc_server.patch_homeserver(hs): + self.make_request("GET", "/_matrix/client/r0/login/sso/redirect") + """ return patch.object(hs.get_proxied_http_client(), "request", self.request) @property From 731c68321530e06b0fec7ad78adc9ea441540808 Mon Sep 17 00:00:00 2001 From: Quentin Gliech Date: Mon, 24 Oct 2022 15:33:36 +0200 Subject: [PATCH 4/8] Comment about the token endpoint not checking auth in the FakeOidcServer --- tests/test_utils/oidc.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/tests/test_utils/oidc.py b/tests/test_utils/oidc.py index f0b879bbe924..95a05adb7dd5 100644 --- a/tests/test_utils/oidc.py +++ b/tests/test_utils/oidc.py @@ -277,6 +277,9 @@ async def _request( params = parse_qs(data.decode("utf-8")) if uri == self.token_endpoint: + # Even though this endpoint should be protected, this does not check + # for client authentication. We're not checking it for simplicity, + # and because client authentication is tested in other standalone tests. return self.post_token_handler(params) elif method == "GET": From cf16d0afa31291251a2447ad4e2f7698117352aa Mon Sep 17 00:00:00 2001 From: Quentin Gliech Date: Mon, 24 Oct 2022 16:44:15 +0200 Subject: [PATCH 5/8] Fix tests with newer authlib versions --- tests/test_utils/oidc.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_utils/oidc.py b/tests/test_utils/oidc.py index 95a05adb7dd5..914446101717 100644 --- a/tests/test_utils/oidc.py +++ b/tests/test_utils/oidc.py @@ -71,7 +71,7 @@ def __init__(self, clock: Clock, issuer: str): # We generate here an ECDSA key with the P-256 curve (ES256 algorithm) used for # signing JWTs. ECDSA keys are really quick to generate compared to RSA. self.key = ECKey.generate_key(crv="P-256", is_private=True) - self.jwks = KeySet([ECKey.import_key(self.key.raw_key.public_key())]) + self.jwks = KeySet([ECKey.import_key(self.key.as_pem(is_private=False))]) self._id_token_overrides: Dict[str, Any] = {} From 7c981b4b9b4de127053b2f656bef474d1c3df57e Mon Sep 17 00:00:00 2001 From: Quentin Gliech Date: Mon, 24 Oct 2022 22:07:36 +0200 Subject: [PATCH 6/8] Suggestion from code review --- synapse/handlers/oidc.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/synapse/handlers/oidc.py b/synapse/handlers/oidc.py index 3f46e0f74865..9759daf043ad 100644 --- a/synapse/handlers/oidc.py +++ b/synapse/handlers/oidc.py @@ -675,6 +675,7 @@ async def _parse_id_token(self, token: Token, nonce: str) -> CodeIDToken: The decoded claims in the ID token. """ id_token = token.get("id_token") + logger.debug("Attempting to decode JWT id_token %r", id_token) # That has been theoritically been checked by the caller, so even though # assertion are not enabled in production, it is mainly here to appease mypy @@ -695,9 +696,6 @@ async def _parse_id_token(self, token: Token, nonce: str) -> CodeIDToken: claim_options = {"iss": {"values": [metadata["issuer"]]}} - id_token = token["id_token"] - logger.debug("Attempting to decode JWT id_token %r", id_token) - # Try to decode the keys in cache first, then retry by forcing the keys # to be reloaded jwk_set = await self.load_jwks() From 32335811fc49d9f7e6527b70d1bd53f90c994634 Mon Sep 17 00:00:00 2001 From: Quentin Gliech Date: Tue, 25 Oct 2022 15:10:23 +0200 Subject: [PATCH 7/8] Apply suggestions from code review Co-authored-by: Patrick Cloke --- tests/handlers/test_oidc.py | 90 ++++++++++++++++----------------- tests/rest/client/test_auth.py | 24 ++++----- tests/rest/client/test_login.py | 18 +++---- tests/test_utils/__init__.py | 4 +- tests/test_utils/oidc.py | 24 ++++----- 5 files changed, 80 insertions(+), 80 deletions(-) diff --git a/tests/handlers/test_oidc.py b/tests/handlers/test_oidc.py index e8394809efb4..5955410524c9 100644 --- a/tests/handlers/test_oidc.py +++ b/tests/handlers/test_oidc.py @@ -134,10 +134,10 @@ def default_config(self) -> Dict[str, Any]: return config def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: - self.fake_provider = FakeOidcServer(clock=clock, issuer=ISSUER) + self.fake_server = FakeOidcServer(clock=clock, issuer=ISSUER) hs = self.setup_test_homeserver() - self.hs_patcher = self.fake_provider.patch_homeserver(hs=hs) + self.hs_patcher = self.fake_server.patch_homeserver(hs=hs) self.hs_patcher.start() self.handler = hs.get_oidc_handler() @@ -163,16 +163,16 @@ def tearDown(self) -> None: def reset_mocks(self): """Reset all the Mocks.""" - self.fake_provider.reset_mocks() + self.fake_server.reset_mocks() self.render_error.reset_mock() self.complete_sso_login.reset_mock() def metadata_edit(self, values): """Modify the result that will be returned by the well-known query""" - metadata = self.fake_provider.get_metadata() + metadata = self.fake_server.get_metadata() metadata.update(values) - return patch.object(self.fake_provider, "get_metadata", return_value=metadata) + return patch.object(self.fake_server, "get_metadata", return_value=metadata) def start_authorization( self, @@ -185,7 +185,7 @@ def start_authorization( nonce = random_string(10) state = random_string(10) - code, grant = self.fake_provider.start_authorization( + code, grant = self.fake_server.start_authorization( userinfo=userinfo, scope=scope, client_id=self.provider._client_auth.client_id, @@ -218,48 +218,48 @@ def test_discovery(self) -> None: """The handler should discover the endpoints from OIDC discovery document.""" # This would throw if some metadata were invalid metadata = self.get_success(self.provider.load_metadata()) - self.fake_provider.get_metadata_handler.assert_called_once() + self.fake_server.get_metadata_handler.assert_called_once() - self.assertEqual(metadata.issuer, self.fake_provider.issuer) + self.assertEqual(metadata.issuer, self.fake_server.issuer) self.assertEqual( metadata.authorization_endpoint, - self.fake_provider.authorization_endpoint, + self.fake_server.authorization_endpoint, ) - self.assertEqual(metadata.token_endpoint, self.fake_provider.token_endpoint) - self.assertEqual(metadata.jwks_uri, self.fake_provider.jwks_uri) + self.assertEqual(metadata.token_endpoint, self.fake_server.token_endpoint) + self.assertEqual(metadata.jwks_uri, self.fake_server.jwks_uri) # It seems like authlib does not have that defined in its metadata models self.assertEqual( metadata.get("userinfo_endpoint"), - self.fake_provider.userinfo_endpoint, + self.fake_server.userinfo_endpoint, ) # subsequent calls should be cached self.reset_mocks() self.get_success(self.provider.load_metadata()) - self.fake_provider.get_metadata_handler.assert_not_called() + self.fake_server.get_metadata_handler.assert_not_called() @override_config({"oidc_config": EXPLICIT_ENDPOINT_CONFIG}) def test_no_discovery(self) -> None: """When discovery is disabled, it should not try to load from discovery document.""" self.get_success(self.provider.load_metadata()) - self.fake_provider.get_metadata_handler.assert_not_called() + self.fake_server.get_metadata_handler.assert_not_called() @override_config({"oidc_config": DEFAULT_CONFIG}) def test_load_jwks(self) -> None: """JWKS loading is done once (then cached) if used.""" jwks = self.get_success(self.provider.load_jwks()) - self.fake_provider.get_jwks_handler.assert_called_once() - self.assertEqual(jwks, self.fake_provider.get_jwks()) + self.fake_server.get_jwks_handler.assert_called_once() + self.assertEqual(jwks, self.fake_server.get_jwks()) # subsequent calls should be cached… self.reset_mocks() self.get_success(self.provider.load_jwks()) - self.fake_provider.get_jwks_handler.assert_not_called() + self.fake_server.get_jwks_handler.assert_not_called() # …unless forced self.reset_mocks() self.get_success(self.provider.load_jwks(force=True)) - self.fake_provider.get_jwks_handler.assert_called_once() + self.fake_server.get_jwks_handler.assert_called_once() with self.metadata_edit({"jwks_uri": None}): # If we don't do this, the load_metadata call will throw because of the @@ -369,7 +369,7 @@ def test_redirect_request(self) -> None: self.provider.handle_redirect_request(req, b"http://client/redirect") ) ) - auth_endpoint = urlparse(self.fake_provider.authorization_endpoint) + auth_endpoint = urlparse(self.fake_server.authorization_endpoint) self.assertEqual(url.scheme, auth_endpoint.scheme) self.assertEqual(url.netloc, auth_endpoint.netloc) @@ -456,8 +456,8 @@ def test_callback(self) -> None: new_user=True, auth_provider_session_id=None, ) - self.fake_provider.post_token_handler.assert_called_once() - self.fake_provider.get_userinfo_handler.assert_not_called() + self.fake_server.post_token_handler.assert_called_once() + self.fake_server.get_userinfo_handler.assert_not_called() self.render_error.assert_not_called() # Handle mapping errors @@ -472,7 +472,7 @@ def test_callback(self) -> None: # Handle ID token errors request, _ = self.start_authorization(userinfo) - with self.fake_provider.id_token_override({"iss": "https://bad.issuer/"}): + with self.fake_server.id_token_override({"iss": "https://bad.issuer/"}): self.get_success(self.handler.handle_oidc_callback(request)) self.assertRenderedError("invalid_token") @@ -493,8 +493,8 @@ def test_callback(self) -> None: new_user=False, auth_provider_session_id=None, ) - self.fake_provider.post_token_handler.assert_called_once() - self.fake_provider.get_userinfo_handler.assert_called_once() + self.fake_server.post_token_handler.assert_called_once() + self.fake_server.get_userinfo_handler.assert_called_once() self.render_error.assert_not_called() self.reset_mocks() @@ -514,18 +514,18 @@ def test_callback(self) -> None: new_user=False, auth_provider_session_id=grant.sid, ) - self.fake_provider.post_token_handler.assert_called_once() - self.fake_provider.get_userinfo_handler.assert_called_once() + self.fake_server.post_token_handler.assert_called_once() + self.fake_server.get_userinfo_handler.assert_called_once() self.render_error.assert_not_called() # Handle userinfo fetching error request, _ = self.start_authorization(userinfo) - with self.fake_provider.buggy_endpoint(userinfo=True): + with self.fake_server.buggy_endpoint(userinfo=True): self.get_success(self.handler.handle_oidc_callback(request)) self.assertRenderedError("fetch_error") request, _ = self.start_authorization(userinfo) - with self.fake_provider.buggy_endpoint(token=True): + with self.fake_server.buggy_endpoint(token=True): self.get_success(self.handler.handle_oidc_callback(request)) self.assertRenderedError("server_error") @@ -582,17 +582,17 @@ def test_exchange_code(self) -> None: "access_token": "aabbcc", } - self.fake_provider.post_token_handler.side_effect = None - self.fake_provider.post_token_handler.return_value = FakeResponse.json( + self.fake_server.post_token_handler.side_effect = None + self.fake_server.post_token_handler.return_value = FakeResponse.json( payload=token ) code = "code" ret = self.get_success(self.provider._exchange_code(code)) - kwargs = self.fake_provider.request.call_args[1] + kwargs = self.fake_server.request.call_args[1] self.assertEqual(ret, token) self.assertEqual(kwargs["method"], "POST") - self.assertEqual(kwargs["uri"], self.fake_provider.token_endpoint) + self.assertEqual(kwargs["uri"], self.fake_server.token_endpoint) args = parse_qs(kwargs["data"].decode("utf-8")) self.assertEqual(args["grant_type"], ["authorization_code"]) @@ -602,7 +602,7 @@ def test_exchange_code(self) -> None: self.assertEqual(args["redirect_uri"], [CALLBACK_URL]) # Test error handling - self.fake_provider.post_token_handler.return_value = FakeResponse.json( + self.fake_server.post_token_handler.return_value = FakeResponse.json( code=400, payload={"error": "foo", "error_description": "bar"} ) from synapse.handlers.oidc import OidcError @@ -612,14 +612,14 @@ def test_exchange_code(self) -> None: self.assertEqual(exc.value.error_description, "bar") # Internal server error with no JSON body - self.fake_provider.post_token_handler.return_value = FakeResponse( + self.fake_server.post_token_handler.return_value = FakeResponse( code=500, body=b"Not JSON" ) exc = self.get_failure(self.provider._exchange_code(code), OidcError) self.assertEqual(exc.value.error, "server_error") # Internal server error with JSON body - self.fake_provider.post_token_handler.return_value = FakeResponse.json( + self.fake_server.post_token_handler.return_value = FakeResponse.json( code=500, payload={"error": "internal_server_error"} ) @@ -627,14 +627,14 @@ def test_exchange_code(self) -> None: self.assertEqual(exc.value.error, "internal_server_error") # 4xx error without "error" field - self.fake_provider.post_token_handler.return_value = FakeResponse.json( + self.fake_server.post_token_handler.return_value = FakeResponse.json( code=400, payload={} ) exc = self.get_failure(self.provider._exchange_code(code), OidcError) self.assertEqual(exc.value.error, "server_error") # 2xx error with "error" field - self.fake_provider.post_token_handler.return_value = FakeResponse.json( + self.fake_server.post_token_handler.return_value = FakeResponse.json( code=200, payload={"error": "some_error"} ) exc = self.get_failure(self.provider._exchange_code(code), OidcError) @@ -664,8 +664,8 @@ def test_exchange_code_jwt_key(self) -> None: "access_token": "aabbcc", } - self.fake_provider.post_token_handler.side_effect = None - self.fake_provider.post_token_handler.return_value = FakeResponse.json( + self.fake_server.post_token_handler.side_effect = None + self.fake_server.post_token_handler.return_value = FakeResponse.json( payload=token ) code = "code" @@ -679,9 +679,9 @@ def test_exchange_code_jwt_key(self) -> None: self.assertEqual(ret, token) # the request should have hit the token endpoint - kwargs = self.fake_provider.request.call_args[1] + kwargs = self.fake_server.request.call_args[1] self.assertEqual(kwargs["method"], "POST") - self.assertEqual(kwargs["uri"], self.fake_provider.token_endpoint) + self.assertEqual(kwargs["uri"], self.fake_server.token_endpoint) # the client secret provided to the should be a jwt which can be checked with # the public key @@ -720,8 +720,8 @@ def test_exchange_code_no_auth(self) -> None: "access_token": "aabbcc", } - self.fake_provider.post_token_handler.side_effect = None - self.fake_provider.post_token_handler.return_value = FakeResponse.json( + self.fake_server.post_token_handler.side_effect = None + self.fake_server.post_token_handler.return_value = FakeResponse.json( payload=token ) code = "code" @@ -730,9 +730,9 @@ def test_exchange_code_no_auth(self) -> None: self.assertEqual(ret, token) # the request should have hit the token endpoint - kwargs = self.fake_provider.request.call_args[1] + kwargs = self.fake_server.request.call_args[1] self.assertEqual(kwargs["method"], "POST") - self.assertEqual(kwargs["uri"], self.fake_provider.token_endpoint) + self.assertEqual(kwargs["uri"], self.fake_server.token_endpoint) # check the POSTed data args = parse_qs(kwargs["data"].decode("utf-8")) diff --git a/tests/rest/client/test_auth.py b/tests/rest/client/test_auth.py index 74a01104f672..ebf653d018f6 100644 --- a/tests/rest/client/test_auth.py +++ b/tests/rest/client/test_auth.py @@ -465,11 +465,11 @@ def test_ui_auth_via_sso(self) -> None: * checking that the original operation succeeds """ - fake_oidc_provider = self.helper.fake_oidc_server() + fake_oidc_server = self.helper.fake_oidc_server() # log the user in remote_user_id = UserID.from_string(self.user).localpart - login_resp, _ = self.helper.login_via_oidc(fake_oidc_provider, remote_user_id) + login_resp, _ = self.helper.login_via_oidc(fake_oidc_server, remote_user_id) self.assertEqual(login_resp["user_id"], self.user) # initiate a UI Auth process by attempting to delete the device @@ -484,7 +484,7 @@ def test_ui_auth_via_sso(self) -> None: # run the UIA-via-SSO flow session_id = channel.json_body["session"] channel, _ = self.helper.auth_via_oidc( - fake_oidc_provider, {"sub": remote_user_id}, ui_auth_session_id=session_id + fake_oidc_server, {"sub": remote_user_id}, ui_auth_session_id=session_id ) # that should serve a confirmation page @@ -501,8 +501,8 @@ def test_ui_auth_via_sso(self) -> None: @skip_unless(HAS_OIDC, "requires OIDC") @override_config({"oidc_config": TEST_OIDC_CONFIG}) def test_does_not_offer_password_for_sso_user(self) -> None: - fake_oidc_provider = self.helper.fake_oidc_server() - login_resp, _ = self.helper.login_via_oidc(fake_oidc_provider, "username") + fake_oidc_server = self.helper.fake_oidc_server() + login_resp, _ = self.helper.login_via_oidc(fake_oidc_server, "username") user_tok = login_resp["access_token"] device_id = login_resp["device_id"] @@ -525,9 +525,9 @@ def test_does_not_offer_sso_for_password_user(self) -> None: @override_config({"oidc_config": TEST_OIDC_CONFIG}) def test_offers_both_flows_for_upgraded_user(self) -> None: """A user that had a password and then logged in with SSO should get both flows""" - fake_oidc_provider = self.helper.fake_oidc_server() + fake_oidc_server = self.helper.fake_oidc_server() login_resp, _ = self.helper.login_via_oidc( - fake_oidc_provider, UserID.from_string(self.user).localpart + fake_oidc_server, UserID.from_string(self.user).localpart ) self.assertEqual(login_resp["user_id"], self.user) @@ -546,11 +546,11 @@ def test_offers_both_flows_for_upgraded_user(self) -> None: def test_ui_auth_fails_for_incorrect_sso_user(self) -> None: """If the user tries to authenticate with the wrong SSO user, they get an error""" - fake_oidc_provider = self.helper.fake_oidc_server() + fake_oidc_server = self.helper.fake_oidc_server() # log the user in login_resp, _ = self.helper.login_via_oidc( - fake_oidc_provider, UserID.from_string(self.user).localpart + fake_oidc_server, UserID.from_string(self.user).localpart ) self.assertEqual(login_resp["user_id"], self.user) @@ -565,7 +565,7 @@ def test_ui_auth_fails_for_incorrect_sso_user(self) -> None: # do the OIDC auth, but auth as the wrong user channel, _ = self.helper.auth_via_oidc( - fake_oidc_provider, {"sub": "wrong_user"}, ui_auth_session_id=session_id + fake_oidc_server, {"sub": "wrong_user"}, ui_auth_session_id=session_id ) # that should return a failure message @@ -595,9 +595,9 @@ def test_sso_not_approved(self) -> None: """Tests that if we register a user via SSO while requiring approval for new accounts, we still raise the correct error before logging the user in. """ - fake_oidc_provider = self.helper.fake_oidc_server() + fake_oidc_server = self.helper.fake_oidc_server() login_resp, _ = self.helper.login_via_oidc( - fake_oidc_provider, "username", expected_status=403 + fake_oidc_server, "username", expected_status=403 ) self.assertEqual(login_resp["errcode"], Codes.USER_AWAITING_APPROVAL) diff --git a/tests/rest/client/test_login.py b/tests/rest/client/test_login.py index 51e62b650bac..ff5baa9f0a78 100644 --- a/tests/rest/client/test_login.py +++ b/tests/rest/client/test_login.py @@ -612,9 +612,9 @@ def test_multi_sso_redirect_to_saml(self) -> None: def test_login_via_oidc(self) -> None: """If OIDC is chosen, should redirect to the OIDC auth endpoint""" - fake_oidc_provider = self.helper.fake_oidc_server() + fake_oidc_server = self.helper.fake_oidc_server() - with fake_oidc_provider.patch_homeserver(hs=self.hs): + with fake_oidc_server.patch_homeserver(hs=self.hs): # pick the default OIDC provider channel = self.make_request( "GET", @@ -629,7 +629,7 @@ def test_login_via_oidc(self) -> None: oidc_uri_path, oidc_uri_query = oidc_uri.split("?", 1) # it should redirect us to the auth page of the OIDC server - self.assertEqual(oidc_uri_path, fake_oidc_provider.authorization_endpoint) + self.assertEqual(oidc_uri_path, fake_oidc_server.authorization_endpoint) # ... and should have set a cookie including the redirect url cookie_headers = channel.headers.getRawHeaders("Set-Cookie") @@ -647,7 +647,7 @@ def test_login_via_oidc(self) -> None: ) channel, _ = self.helper.complete_oidc_auth( - fake_oidc_provider, oidc_uri, cookies, {"sub": "user1"} + fake_oidc_server, oidc_uri, cookies, {"sub": "user1"} ) # that should serve a confirmation page @@ -698,9 +698,9 @@ def test_client_idp_redirect_to_unknown(self) -> None: def test_client_idp_redirect_to_oidc(self) -> None: """If the client pick a known IdP, redirect to it""" - fake_oidc_provider = self.helper.fake_oidc_server() + fake_oidc_server = self.helper.fake_oidc_server() - with fake_oidc_provider.patch_homeserver(hs=self.hs): + with fake_oidc_server.patch_homeserver(hs=self.hs): channel = self._make_sso_redirect_request("oidc") self.assertEqual(channel.code, 302, channel.result) location_headers = channel.headers.getRawHeaders("Location") @@ -709,7 +709,7 @@ def test_client_idp_redirect_to_oidc(self) -> None: oidc_uri_path, oidc_uri_query = oidc_uri.split("?", 1) # it should redirect us to the auth page of the OIDC server - self.assertEqual(oidc_uri_path, fake_oidc_provider.authorization_endpoint) + self.assertEqual(oidc_uri_path, fake_oidc_server.authorization_endpoint) def _make_sso_redirect_request(self, idp_prov: Optional[str] = None) -> FakeChannel: """Send a request to /_matrix/client/r0/login/sso/redirect @@ -1288,11 +1288,11 @@ def create_resource_dict(self) -> Dict[str, Resource]: def test_username_picker(self) -> None: """Test the happy path of a username picker flow.""" - fake_oidc_provider = self.helper.fake_oidc_server() + fake_oidc_server = self.helper.fake_oidc_server() # do the start of the login flow channel, _ = self.helper.auth_via_oidc( - fake_oidc_provider, + fake_oidc_server, {"sub": "tester", "displayname": "Jonny"}, TEST_CLIENT_REDIRECT_URL, ) diff --git a/tests/test_utils/__init__.py b/tests/test_utils/__init__.py index c47b63fff45d..e62ebcc6a5a3 100644 --- a/tests/test_utils/__init__.py +++ b/tests/test_utils/__init__.py @@ -114,7 +114,7 @@ class FakeResponse: # type: ignore[misc] attribute, and didn't support deliverBody until recently. """ - verison: Tuple[bytes, int, int] = (b"HTTP", 1, 1) + version: Tuple[bytes, int, int] = (b"HTTP", 1, 1) # HTTP response code code: int = 200 @@ -122,7 +122,7 @@ class FakeResponse: # type: ignore[misc] # body of the response body: bytes = b"" - headers: Headers = Headers() + headers: Headers = attr.Factory(Headers) @property def phrase(self): diff --git a/tests/test_utils/oidc.py b/tests/test_utils/oidc.py index 914446101717..de134bbc893b 100644 --- a/tests/test_utils/oidc.py +++ b/tests/test_utils/oidc.py @@ -54,7 +54,7 @@ class FakeOidcServer: def __init__(self, clock: Clock, issuer: str): from authlib.jose import ECKey, KeySet - self.clock = clock + self._clock = clock self.issuer = issuer self.request = Mock(side_effect=self._request) @@ -64,14 +64,14 @@ def __init__(self, clock: Clock, issuer: str): self.post_token_handler = Mock(side_effect=self._post_token_handler) # A code -> grant mapping - self.authorization_grants: Dict[str, FakeAuthorizationGrant] = {} + self._authorization_grants: Dict[str, FakeAuthorizationGrant] = {} # An access token -> grant mapping - self.sessions: Dict[str, FakeAuthorizationGrant] = {} + self._sessions: Dict[str, FakeAuthorizationGrant] = {} # We generate here an ECDSA key with the P-256 curve (ES256 algorithm) used for # signing JWTs. ECDSA keys are really quick to generate compared to RSA. - self.key = ECKey.generate_key(crv="P-256", is_private=True) - self.jwks = KeySet([ECKey.import_key(self.key.as_pem(is_private=False))]) + self._key = ECKey.generate_key(crv="P-256", is_private=True) + self._jwks = KeySet([ECKey.import_key(self._key.as_pem(is_private=False))]) self._id_token_overrides: Dict[str, Any] = {} @@ -127,11 +127,11 @@ def get_metadata(self) -> dict: } def get_jwks(self) -> dict: - return self.jwks.as_dict() + return self._jwks.as_dict() def get_userinfo(self, access_token: str) -> Optional[dict]: """Given an access token, get the userinfo of the associated session.""" - session = self.sessions.get(access_token, None) + session = self._sessions.get(access_token, None) if session is None: return None return session.userinfo @@ -143,10 +143,10 @@ def _sign(self, payload: dict) -> str: kid = self.get_jwks()["keys"][0]["kid"] protected = {"alg": "ES256", "kid": kid} json_payload = json.dumps(payload) - return jws.serialize_compact(protected, json_payload, self.key).decode("utf-8") + return jws.serialize_compact(protected, json_payload, self._key).decode("utf-8") def generate_id_token(self, grant: FakeAuthorizationGrant) -> str: - now = self.clock.time() + now = self._clock.time() id_token = { **grant.userinfo, "iss": self.issuer, @@ -193,17 +193,17 @@ def start_authorization( client_id=client_id, sid=sid, ) - self.authorization_grants[code] = grant + self._authorization_grants[code] = grant return code, grant def exchange_code(self, code: str) -> Optional[Dict[str, Any]]: - grant = self.authorization_grants.pop(code, None) + grant = self._authorization_grants.pop(code, None) if grant is None: return None access_token = random_string(10) - self.sessions[access_token] = grant + self._sessions[access_token] = grant token = { "token_type": "Bearer", From f225a91940d6002b1b061605bcd2e8d417127479 Mon Sep 17 00:00:00 2001 From: Quentin Gliech Date: Tue, 25 Oct 2022 15:44:44 +0200 Subject: [PATCH 8/8] Rename provider -> fake_server --- tests/rest/client/utils.py | 19 ++++++++++--------- 1 file changed, 10 insertions(+), 9 deletions(-) diff --git a/tests/rest/client/utils.py b/tests/rest/client/utils.py index 4dfcc07b6be2..967d229223ab 100644 --- a/tests/rest/client/utils.py +++ b/tests/rest/client/utils.py @@ -569,7 +569,7 @@ def fake_oidc_server(self, issuer: str = TEST_OIDC_ISSUER) -> FakeOidcServer: def login_via_oidc( self, - provider: FakeOidcServer, + fake_server: FakeOidcServer, remote_user_id: str, with_sid: bool = False, expected_status: int = 200, @@ -588,7 +588,7 @@ def login_via_oidc( client_redirect_url = "https://x" userinfo = {"sub": remote_user_id} channel, grant = self.auth_via_oidc( - provider, userinfo, client_redirect_url, with_sid=with_sid + fake_server, userinfo, client_redirect_url, with_sid=with_sid ) # expect a confirmation page @@ -618,7 +618,7 @@ def login_via_oidc( def auth_via_oidc( self, - provider: FakeOidcServer, + fake_server: FakeOidcServer, user_info_dict: JsonDict, client_redirect_url: Optional[str] = None, ui_auth_session_id: Optional[str] = None, @@ -657,7 +657,7 @@ def auth_via_oidc( cookies: Dict[str, str] = {} - with provider.patch_homeserver(hs=self.hs): + with fake_server.patch_homeserver(hs=self.hs): # if we're doing a ui auth, hit the ui auth redirect endpoint if ui_auth_session_id: # can't set the client redirect url for UI Auth @@ -673,16 +673,16 @@ def auth_via_oidc( # that synapse passes to the client. oauth_uri_path, _ = oauth_uri.split("?", 1) - assert oauth_uri_path == provider.authorization_endpoint, ( + assert oauth_uri_path == fake_server.authorization_endpoint, ( "unexpected SSO URI " + oauth_uri_path ) return self.complete_oidc_auth( - provider, oauth_uri, cookies, user_info_dict, with_sid=with_sid + fake_server, oauth_uri, cookies, user_info_dict, with_sid=with_sid ) def complete_oidc_auth( self, - provider: FakeOidcServer, + fake_serer: FakeOidcServer, oauth_uri: str, cookies: Mapping[str, str], user_info_dict: JsonDict, @@ -698,6 +698,7 @@ def complete_oidc_auth( Requires the OIDC callback resource to be mounted at the normal place. Args: + fake_server: the fake OIDC server with which the auth should be done oauth_uri: the OIDC URI returned by synapse's redirect endpoint (ie, from initiate_sso_login or initiate_sso_ui_auth). cookies: the cookies set by synapse's redirect endpoint, which will be @@ -712,7 +713,7 @@ def complete_oidc_auth( _, oauth_uri_qs = oauth_uri.split("?", 1) params = urllib.parse.parse_qs(oauth_uri_qs) - code, grant = provider.start_authorization( + code, grant = fake_serer.start_authorization( scope=params["scope"][0], userinfo=user_info_dict, client_id=params["client_id"][0], @@ -727,7 +728,7 @@ def complete_oidc_auth( urllib.parse.urlencode({"state": state, "code": code}), ) - with provider.patch_homeserver(hs=self.hs): + with fake_serer.patch_homeserver(hs=self.hs): # now hit the callback URI with the right params and a made-up code channel = make_request( self.hs.get_reactor(),