diff --git a/requests_oauth2client/client.py b/requests_oauth2client/client.py index cdaba8f..0ede2d0 100644 --- a/requests_oauth2client/client.py +++ b/requests_oauth2client/client.py @@ -1679,29 +1679,37 @@ def from_discovery_endpoint( testing: bool = False, **kwargs: Any, ) -> OAuth2Client: - """Initialise an OAuth2Client based on Authorization Server Metadata. + """Initialize an `OAuth2Client` using an AS Discovery Document endpoint. - This will retrieve the standardised metadata document available at `url`, and will extract + If an `url` is provided, an HTTPS request will be done to that URL to obtain the Authorization Server Metadata. + + If an `issuer` is provided, the OpenID Connect Discovery document url will be automatically + derived from it, as specified in [OpenID Connect Discovery](https://openid.net/specs/openid-connect-discovery-1_0.html#ProviderConfigurationRequest). + + Once the standardized metadata document is obtained, this will extract all Endpoint Uris from that document, will fetch the current public keys from its - `jwks_uri`, then will initialise an OAuth2Client based on those endpoints. + `jwks_uri`, then will initialize an OAuth2Client based on those endpoints. Args: - url: the url where the server metadata will be retrieved - auth: the authentication handler to use for client authentication - client_id: client ID - client_secret: client secret to use to authenticate the client - private_key: private key to sign client assertions - session: a `requests.Session` to use to retrieve the document and initialise the client with - issuer: if an issuer is given, check that it matches the one from the retrieved document - testing: if True, don't try to validate the endpoint urls that are part of the document - **kwargs: additional keyword parameters to pass to OAuth2Client + url: The url where the server metadata will be retrieved. + issuer: The issuer value that is expected in the discovery document. + If not `url` is given, the OpenID Connect Discovery url for this issuer will be retrieved. + auth: The authentication handler to use for client authentication. + client_id: Client ID. + client_secret: Client secret to use to authenticate the client. + private_key: Private key to sign client assertions. + session: A `requests.Session` to use to retrieve the document and initialise the client with. + testing: If `True`, do not try to validate the issuer uri nor the endpoint urls + that are part of the document. + **kwargs: Additional keyword parameters to pass to `OAuth2Client`. Returns: - an OAuth2Client with endpoint initialised based on the obtained metadata + An `OAuth2Client` with endpoints initialized based on the obtained metadata. Raises: - InvalidParam: if neither `url` nor `issuer` are suitable urls - requests.HTTPError: if an error happens while fetching the documents + InvalidIssuer: If `issuer` is not using https, or contains credentials or fragment. + InvalidParam: If neither `url` nor `issuer` are suitable urls. + requests.HTTPError: If an error happens while fetching the documents. Example: ```python @@ -1710,25 +1718,30 @@ def from_discovery_endpoint( client = OAuth2Client.from_discovery_endpoint( issuer="https://myserver.net", client_id="my_client_id, - client_secret="my_client_secret" + client_secret="my_client_secret", ) ``` """ + if issuer is not None and not testing: + try: + validate_issuer_uri(issuer) + except InvalidUri as exc: + raise InvalidIssuer("issuer", issuer, exc) from exc # noqa: EM101 if url is None and issuer is not None: url = oidc_discovery_document_url(issuer) if url is None: msg = "Please specify at least one of `issuer` or `url`" raise InvalidParam(msg) - validate_endpoint_uri(url, path=False) + if not testing: + validate_endpoint_uri(url, path=False) session = session or requests.Session() discovery = session.get(url).json() jwks_uri = discovery.get("jwks_uri") - if jwks_uri: - jwks = JwkSet(session.get(jwks_uri).json()) + jwks = JwkSet(session.get(jwks_uri).json()) if jwks_uri else None return cls.from_discovery_document( discovery, @@ -1754,46 +1767,56 @@ def from_discovery_document( client_secret: str | None = None, private_key: Jwk | dict[str, Any] | None = None, authorization_server_jwks: JwkSet | dict[str, Any] | None = None, - session: requests.Session | None = None, https: bool = True, testing: bool = False, **kwargs: Any, ) -> OAuth2Client: - """Initialize an OAuth2Client, based on the server metadata from `discovery`. + """Initialize an `OAuth2Client`, based on an AS Discovery Document. Args: - discovery: a dict of server metadata, in the same format as retrieved from a discovery endpoint. - issuer: if an issuer is given, check that it matches the one mentioned in the document - auth: the authentication handler to use for client authentication - client_id: client ID - client_secret: client secret to use to authenticate the client - private_key: private key to sign client assertions - authorization_server_jwks: the current authorization server JWKS keys - session: a requests Session to use to retrieve the document and initialise the client with - https: (deprecated) if `True`, validates that urls in the discovery document use the https scheme - testing: if True, don't try to validate the endpoint urls that are part of the document - **kwargs: additional args that will be passed to OAuth2Client + discovery: A `dict` of server metadata, in the same format as retrieved from a discovery endpoint. + issuer: If an issuer is given, check that it matches the one mentioned in the document. + auth: The authentication handler to use for client authentication. + client_id: Client ID. + client_secret: Client secret to use to authenticate the client. + private_key: Private key to sign client assertions. + authorization_server_jwks: The current authorization server JWKS keys. + https: (deprecated) If `True`, validates that urls in the discovery document use the https scheme. + testing: If `True`, don't try to validate the endpoint urls that are part of the document. + **kwargs: Additional args that will be passed to `OAuth2Client`. Returns: - an `OAuth2Client` initialized with the endpoints from the discovery document + An `OAuth2Client` initialized with the endpoints from the discovery document. Raises: - InvalidDiscoveryDocument: if the document does not contain at least a `"token_endpoint"`. + InvalidDiscoveryDocument: If the document does not contain at least a `"token_endpoint"`. + + Examples: + ```python + from requests_oauth2client import OAuth2Client + + client = OAuth2Client.from_discovery_document( + { + "issuer": "https://myas.local", + "token_endpoint": "https://myas.local/token", + }, + client_id="client_id", + client_secret="client_secret", + ) + ``` """ if not https: warnings.warn( """\ -The https parameter is deprecated. +The `https` parameter is deprecated. To disable endpoint uri validation, set `testing=True` when initializing your `OAuth2Client`.""", stacklevel=1, ) testing = True if issuer and discovery.get("issuer") != issuer: - msg = ( - f"Mismatching `issuer` value in discovery document" - f" (received '{discovery.get('issuer')}', expected '{issuer}')" - ) + msg = f"""\ +Mismatching `issuer` value in discovery document (received '{discovery.get('issuer')}', expected '{issuer}').""" raise InvalidParam( msg, issuer, @@ -1811,8 +1834,8 @@ def from_discovery_document( introspection_endpoint = discovery.get(Endpoints.INSTROSPECTION) userinfo_endpoint = discovery.get(Endpoints.USER_INFO) jwks_uri = discovery.get(Endpoints.JWKS) - if jwks_uri is not None: - validate_endpoint_uri(jwks_uri, https=https) + if jwks_uri is not None and not testing: + validate_endpoint_uri(jwks_uri) authorization_response_iss_parameter_supported = discovery.get( "authorization_response_iss_parameter_supported", False, @@ -1830,7 +1853,6 @@ def from_discovery_document( client_id=client_id, client_secret=client_secret, private_key=private_key, - session=session, issuer=issuer, authorization_response_iss_parameter_supported=authorization_response_iss_parameter_supported, testing=testing, diff --git a/requests_oauth2client/utils.py b/requests_oauth2client/utils.py index f7fff90..a2d7f4e 100644 --- a/requests_oauth2client/utils.py +++ b/requests_oauth2client/utils.py @@ -51,7 +51,7 @@ def validate_endpoint_uri( *, https: bool = True, no_credentials: bool = True, - no_port: bool = True, + no_port: bool = False, no_fragment: bool = True, path: bool = True, ) -> str: diff --git a/tests/unit_tests/test_client.py b/tests/unit_tests/test_client.py index 5f56854..96cc638 100644 --- a/tests/unit_tests/test_client.py +++ b/tests/unit_tests/test_client.py @@ -19,8 +19,11 @@ ClientSecretPost, DeviceAuthorizationResponse, IdToken, + InvalidIssuer, + InvalidParam, InvalidPushedAuthorizationResponse, InvalidTokenResponse, + InvalidUri, OAuth2Client, PrivateKeyJwt, PublicApp, @@ -658,7 +661,7 @@ def test_from_discovery_document( auth=client_id, ) - with pytest.warns(match="https parameter is deprecated"): + with pytest.warns(match="`https` parameter is deprecated"): OAuth2Client.from_discovery_document( { "issuer": issuer, @@ -1516,9 +1519,6 @@ def test_testing_oauth2client() -> None: with pytest.raises(ValueError, match="must use https"): OAuth2Client(token_endpoint="https://valid.token/endpoint", client_id="client_id", issuer=issuer) - with pytest.raises(ValueError, match="no custom port number allowed"): - OAuth2Client(token_endpoint="https://valid.token/endpoint", client_id="client_id", issuer=issuer) - with pytest.raises(ValueError, match="must include a path"): OAuth2Client(token_endpoint="https://foo.bar/", client_id="client_id") @@ -1546,3 +1546,52 @@ class ProxyAuthorizationBearerToken(BearerToken): requests.post(target_api, auth=ProxyAuthorizationBearerToken(access_token)) assert requests_mock.last_request is not None assert requests_mock.last_request.headers[auth_header] == f"Bearer {access_token}" + + +def test_custom_ports_in_endpoints(requests_mock: RequestsMocker) -> None: + issuer = "https://as.local:8443" + token_endpoint = "https://as.local:8443/token" + client = OAuth2Client(token_endpoint=token_endpoint, client_id="client_id", client_secret="client_secret") + assert client.token_endpoint == token_endpoint + + assert not requests_mock.called_once + with pytest.raises(InvalidIssuer, match="must use https"): + OAuth2Client.from_discovery_endpoint(issuer="http://as.local") + assert not requests_mock.called_once + + with pytest.raises(InvalidIssuer, match="must use https"): + OAuth2Client.from_discovery_endpoint(issuer="http://as.local:8080") + assert not requests_mock.called_once + + with pytest.raises(InvalidUri, match="must use https"): + OAuth2Client.from_discovery_endpoint(url="http://as.local/.well-known/openid-configuration") + assert not requests_mock.called_once + + requests_mock.get( + "https://as.local/.well-known/openid-configuration", json={"issuer": issuer, "token_endpoint": token_endpoint} + ) + with pytest.raises( + InvalidParam, + match=rf"Mismatching `issuer` value in discovery document \(received '{issuer}', expected 'https://as.local'\)", + ): + OAuth2Client.from_discovery_endpoint(issuer="https://as.local", client_id="client_id") + assert requests_mock.called_once + + discovery_url = "https://as.local:8443/.well-known/openid-configuration" + requests_mock.get(discovery_url, json={"issuer": issuer, "token_endpoint": token_endpoint}) + + requests_mock.reset() + assert ( + OAuth2Client.from_discovery_endpoint( + url="https://as.local:8443/.well-known/openid-configuration", client_id="client_id" + ).token_endpoint + == token_endpoint + ) + assert requests_mock.called_once + + requests_mock.reset() + assert ( + OAuth2Client.from_discovery_endpoint(issuer="https://as.local:8443", client_id="client_id").token_endpoint + == token_endpoint + ) + assert requests_mock.called_once diff --git a/tests/unit_tests/test_utils.py b/tests/unit_tests/test_utils.py index 015c0de..7cc2e01 100644 --- a/tests/unit_tests/test_utils.py +++ b/tests/unit_tests/test_utils.py @@ -10,6 +10,7 @@ def test_validate_uri() -> None: validate_endpoint_uri("https://myas.local/token") + validate_endpoint_uri("https://myas.local:443/token", no_port=True) with pytest.raises(ValueError, match="https") as exc: validate_endpoint_uri("http://myas.local/token") assert exc.type is InvalidUri @@ -23,7 +24,7 @@ def test_validate_uri() -> None: validate_endpoint_uri("https://user:passwd@myas.local/token") assert exc.type is InvalidUri with pytest.raises(ValueError, match="port") as exc: - validate_endpoint_uri("https://myas.local:1234/token") + validate_endpoint_uri("https://myas.local:1234/token", no_port=True) assert exc.type is InvalidUri