diff --git a/requests_oauth2client/__init__.py b/requests_oauth2client/__init__.py index 32fd1e0..5db86e4 100644 --- a/requests_oauth2client/__init__.py +++ b/requests_oauth2client/__init__.py @@ -9,7 +9,6 @@ from .api_client import ApiClient, InvalidBoolFieldsParam, InvalidPathParam from .auth import ( - BaseOAuth2RenewableTokenAuth, NonRenewableTokenError, OAuth2AccessTokenAuth, OAuth2AuthorizationCodeAuth, @@ -152,7 +151,6 @@ "BackChannelAuthenticationPoolingJob", "BackChannelAuthenticationResponse", "BaseClientAuthenticationMethod", - "BaseOAuth2RenewableTokenAuth", "BaseTokenEndpointPoolingJob", "BearerToken", "BearerTokenSerializer", diff --git a/requests_oauth2client/api_client.py b/requests_oauth2client/api_client.py index 9769bcd..30ef287 100644 --- a/requests_oauth2client/api_client.py +++ b/requests_oauth2client/api_client.py @@ -7,7 +7,7 @@ from urllib.parse import urljoin import requests -from attrs import field, frozen +from attrs import frozen from typing_extensions import Literal, Self if TYPE_CHECKING: @@ -21,10 +21,10 @@ class InvalidBoolFieldsParam(ValueError): def __init__(self, bool_fields: object) -> None: super().__init__("""\ -Invalid value for 'bool_fields' parameter. It must be an iterable of 2 str values: -- first one for the True value -- second one for the False value -boolean fields in `data` or `params` with a boolean value (`True` or `False`) +Invalid value for `bool_fields` parameter. It must be an iterable of 2 `str` values: +- first one for the `True` value, +- second one for the `False` value. +Boolean fields in `data` or `params` with a boolean value (`True` or `False`) will be serialized to the corresponding value. Default is `('true', 'false')` Use this parameter when the target API expects some other values, e.g.: @@ -36,7 +36,7 @@ def __init__(self, bool_fields: object) -> None: def validate_bool_fields(bool_fields: tuple[str, str]) -> tuple[str, str]: - """Validate the `bool_fields` paremeter. + """Validate the `bool_fields` parameter. It must be a sequence of 2 values. First one is the `True` value, second one is the `False` value. Both must be `str` or string-able values. @@ -135,12 +135,12 @@ class ApiClient: """ base_url: str - auth: requests.auth.AuthBase | None = None - timeout: int | None = 60 - raise_for_status: bool = True - none_fields: Literal["include", "exclude", "empty"] = "exclude" - bool_fields: tuple[Any, Any] | None = "true", "false" - session: requests.Session = field(factory=requests.Session) + auth: requests.auth.AuthBase | None + timeout: int | None + raise_for_status: bool + none_fields: Literal["include", "exclude", "empty"] + bool_fields: tuple[Any, Any] | None + session: requests.Session def __init__( self, diff --git a/requests_oauth2client/auth.py b/requests_oauth2client/auth.py index 5427ca2..137a20f 100644 --- a/requests_oauth2client/auth.py +++ b/requests_oauth2client/auth.py @@ -21,8 +21,14 @@ class NonRenewableTokenError(Exception): @define(init=False) -class BaseOAuth2RenewableTokenAuth(requests.auth.AuthBase): - """Base class for BearerToken-based Auth Handlers, with an obtainable or renewable token. +class OAuth2AccessTokenAuth(requests.auth.AuthBase): + """Authentication Handler for OAuth 2.0 Access Tokens and (optional) Refresh Tokens. + + This [Requests Auth handler][requests.auth.AuthBase] implementation uses an access token as + Bearer or DPoP token, and can automatically refresh it when expired, if a refresh token is available. + + Token can be a simple `str` containing a raw access token value, or a + [BearerToken][requests_oauth2client.tokens.BearerToken] that can contain a `refresh_token`. In addition to adding a properly formatted `Authorization` header, this will obtain a new token once the current token is expired. Expiration is detected based on the `expires_in` hint @@ -30,6 +36,24 @@ class BaseOAuth2RenewableTokenAuth(requests.auth.AuthBase): token is obtained some seconds before the actual expiration is reached. This may help in situations where the client, AS and RS have slightly offset clocks. + Args: + client: the client to use to refresh tokens. + token: an initial Access Token, if you have one already. In most cases, leave `None`. + leeway: expiration leeway, in number of seconds. + **token_kwargs: additional kwargs to pass to the token endpoint. + + Example: + ```python + from requests_oauth2client import BearerToken, OAuth2Client, OAuth2AccessTokenAuth, requests + + client = OAuth2Client(token_endpoint="https://my.as.local/token", auth=("client_id", "client_secret")) + # obtain a BearerToken any way you see fit, optionally including a refresh token + # for this example, the token value is hardcoded + token = BearerToken(access_token="access_token", expires_in=600, refresh_token="refresh_token") + auth = OAuth2AccessTokenAuth(client, token, scope="my_scope") + resp = requests.post("https://my.api.local/resource", auth=auth) + ``` + """ client: OAuth2Client = field(on_setattr=setters.frozen) @@ -37,6 +61,13 @@ class BaseOAuth2RenewableTokenAuth(requests.auth.AuthBase): leeway: int = field(on_setattr=setters.frozen) token_kwargs: dict[str, Any] = field(on_setattr=setters.frozen) + def __init__( + self, client: OAuth2Client, token: str | BearerToken, *, leeway: int = 20, **token_kwargs: Any + ) -> None: + if isinstance(token, str): + token = BearerToken(token) + self.__attrs_init__(client=client, token=token, leeway=leeway, token_kwargs=token_kwargs) + def __call__(self, request: requests.PreparedRequest) -> requests.PreparedRequest: """Add the Access Token to the request. @@ -55,10 +86,11 @@ def __call__(self, request: requests.PreparedRequest) -> requests.PreparedReques def renew_token(self) -> None: """Obtain a new Bearer Token. - Subclasses should implement this. + This will try to use the `refresh_token`, if there is one. """ - raise NotImplementedError + if self.token is not None and self.token.refresh_token is not None: + self.token = self.client.refresh_token(refresh_token=self.token, **self.token_kwargs) def forget_token(self) -> None: """Forget the current token, forcing a renewal on the next HTTP request.""" @@ -66,29 +98,7 @@ def forget_token(self) -> None: @define(init=False) -class BaseOAuth2RefreshTokenAuth(BaseOAuth2RenewableTokenAuth): - """Base class for flows which can have a refresh-token. - - This implements a `renew_token()` method which uses the refresh token to obtain new tokens. - - """ - - @override - def renew_token(self) -> None: - """Obtain a new token, using the Refresh Token, if available. - - Raises: - NonRenewableTokenError: if the token is not renewable. - - """ - if self.token is None or self.token.refresh_token is None: - raise NonRenewableTokenError - - self.token = self.client.refresh_token(refresh_token=self.token, **self.token_kwargs) - - -@define(init=False) -class OAuth2ClientCredentialsAuth(BaseOAuth2RenewableTokenAuth): +class OAuth2ClientCredentialsAuth(OAuth2AccessTokenAuth): """An Auth Handler for the [Client Credentials grant](https://www.rfc-editor.org/rfc/rfc6749#section-4.4). This [requests AuthBase][requests.auth.AuthBase] automatically gets Access Tokens from an OAuth @@ -126,47 +136,7 @@ def renew_token(self) -> None: @define(init=False) -class OAuth2AccessTokenAuth(BaseOAuth2RefreshTokenAuth): - """Authentication Handler for OAuth 2.0 Access Tokens and (optional) Refresh Tokens. - - This [Requests Auth handler][requests.auth.AuthBase] implementation uses an access token as - Bearer token, and can automatically refresh it when expired, if a refresh token is available. - - Token can be a simple `str` containing a raw access token value, or a - [BearerToken][requests_oauth2client.tokens.BearerToken] that can contain a `refresh_token`. - If a `refresh_token` and an expiration date are available (based on `expires_in` hint), - this Auth Handler will automatically refresh the access token once it is expired. - - Args: - client: the client to use to refresh tokens. - token: an initial Access Token, if you have one already. In most cases, leave `None`. - leeway: expiration leeway, in number of seconds. - **token_kwargs: additional kwargs to pass to the token endpoint. - - Example: - ```python - from requests_oauth2client import BearerToken, OAuth2Client, OAuth2AccessTokenAuth, requests - - client = OAuth2Client(token_endpoint="https://my.as.local/token", auth=("client_id", "client_secret")) - # obtain a BearerToken any way you see fit, optionally including a refresh token - # for this example, the token value is hardcoded - token = BearerToken(access_token="access_token", expires_in=600, refresh_token="refresh_token") - auth = OAuth2AccessTokenAuth(client, token, scope="my_scope") - resp = requests.post("https://my.api.local/resource", auth=auth) - ``` - - """ - - def __init__( - self, client: OAuth2Client, token: str | BearerToken, *, leeway: int = 20, **token_kwargs: Any - ) -> None: - if isinstance(token, str): - token = BearerToken(token) - self.__attrs_init__(client=client, token=token, leeway=leeway, token_kwargs=token_kwargs) - - -@define(init=False) -class OAuth2AuthorizationCodeAuth(BaseOAuth2RefreshTokenAuth): # type: ignore[override] +class OAuth2AuthorizationCodeAuth(OAuth2AccessTokenAuth): # type: ignore[override] """Authentication handler for the [Authorization Code grant](https://www.rfc-editor.org/rfc/rfc6749#section-4.1). This [Requests Auth handler][requests.auth.AuthBase] implementation exchanges an Authorization @@ -235,7 +205,7 @@ def exchange_code_for_token(self) -> None: @define(init=False) -class OAuth2ResourceOwnerPasswordAuth(BaseOAuth2RenewableTokenAuth): # type: ignore[override] +class OAuth2ResourceOwnerPasswordAuth(OAuth2AccessTokenAuth): # type: ignore[override] """Authentication Handler for the [Resource Owner Password Credentials Flow](https://www.rfc-editor.org/rfc/rfc6749#section-4.3). This [Requests Auth handler][requests.auth.AuthBase] implementation exchanges the user @@ -313,7 +283,7 @@ def renew_token(self) -> None: @define(init=False) -class OAuth2DeviceCodeAuth(BaseOAuth2RefreshTokenAuth): # type: ignore[override] +class OAuth2DeviceCodeAuth(OAuth2AccessTokenAuth): # type: ignore[override] """Authentication Handler for the [Device Code Flow](https://www.rfc-editor.org/rfc/rfc8628). This [Requests Auth handler][requests.auth.AuthBase] implementation exchanges a Device Code for diff --git a/requests_oauth2client/authorization_request.py b/requests_oauth2client/authorization_request.py index 3c35494..e29a7a2 100644 --- a/requests_oauth2client/authorization_request.py +++ b/requests_oauth2client/authorization_request.py @@ -5,9 +5,10 @@ import re import secrets from enum import Enum +from functools import cached_property from typing import TYPE_CHECKING, Any, Callable, ClassVar, Iterable, Sequence -from attrs import Factory, asdict, field, fields, frozen +from attrs import asdict, field, fields, frozen from binapy import BinaPy from furl import furl # type: ignore[import-untyped] from jwskate import JweCompact, Jwk, Jwt, SignedJwt @@ -47,7 +48,7 @@ class ResponseTypes(str, Enum): class CodeChallengeMethods(str, Enum): - """All standardised `code_challenge` values. + """All standardised `code_challenge_method` values. You should always use `S256`. @@ -58,7 +59,7 @@ class CodeChallengeMethods(str, Enum): class UnsupportedCodeChallengeMethod(ValueError): - """Raised when an unsupported code_challenge_method is provided.""" + """Raised when an unsupported `code_challenge_method` is provided.""" class InvalidCodeVerifierParam(ValueError): @@ -126,7 +127,7 @@ def derive_challenge(cls, verifier: str | bytes, method: str = CodeChallengeMeth @classmethod def generate_code_verifier_and_challenge(cls, method: str = CodeChallengeMethods.S256) -> tuple[str, str]: - """Generate a valid `code_verifier` and derive its `code_challenge`. + """Generate a valid `code_verifier` and its matching `code_challenge`. Args: method: the method to use for deriving the challenge. Accepts 'S256' or 'plain'. @@ -191,14 +192,15 @@ class AuthorizationResponse: """Represent a successful Authorization Response. An Authorization Response is the redirection initiated by the AS to the client's redirection - endpoint (redirect_uri) after an Authorization Request. This Response is typically created with - a call to `AuthorizationRequest.validate_callback()` once the call to the client Redirection - Endpoint is made. AuthorizationResponse contains the following, all accessible as attributes: + endpoint (redirect_uri), after an Authorization Request. + This Response is typically created with a call to `AuthorizationRequest.validate_callback()` + once the call to the client Redirection Endpoint is made. + `AuthorizationResponse` contains the following attributes: - all the parameters that have been returned by the AS, most notably the `code`, and optional parameters such as `state`. - - the redirect_uri that was used for the Authorization Request - - the code_verifier matching the code_challenge that was used for the Authorization Request + - the `redirect_uri` that was used for the Authorization Request + - the `code_verifier` matching the `code_challenge` that was used for the Authorization Request Parameters `redirect_uri` and `code_verifier` must be those from the matching `AuthorizationRequest`. All other parameters including `code` and `state` must be those @@ -215,14 +217,14 @@ class AuthorizationResponse: """ code: str - redirect_uri: str | None = None - code_verifier: str | None = None - state: str | None = None - nonce: str | None = None - acr_values: tuple[str, ...] | None = None - max_age: int | None = None - issuer: str | None = None - kwargs: dict[str, Any] = Factory(dict) + redirect_uri: str | None + code_verifier: str | None + state: str | None + nonce: str | None + acr_values: tuple[str, ...] | None + max_age: int | None + issuer: str | None + kwargs: dict[str, Any] def __init__( self, @@ -348,20 +350,19 @@ class AuthorizationRequest: authorization_endpoint: str client_id: str = field(metadata={"query": True}) - redirect_uri: str | None = field(metadata={"query": True}, default=None) - scope: tuple[str, ...] | None = field(metadata={"query": True}, default=("openid",)) - response_type: str = field(metadata={"query": True}, default=ResponseTypes.CODE) - state: str | None = field(metadata={"query": True}, default=None) - nonce: str | None = field(metadata={"query": True}, default=None) - code_challenge_method: str | None = field(metadata={"query": True}, default=CodeChallengeMethods.S256) - acr_values: tuple[str, ...] | None = field(metadata={"query": True}, default=None) - max_age: int | None = field(metadata={"query": True}, default=None) - kwargs: dict[str, Any] = Factory(dict) - - code_verifier: str | None = None - code_challenge: str | None = field(init=False, metadata={"query": True}) - authorization_response_iss_parameter_supported: bool = False - issuer: str | None = None + redirect_uri: str | None = field(metadata={"query": True}) + scope: tuple[str, ...] | None = field(metadata={"query": True}) + response_type: str = field(metadata={"query": True}) + state: str | None = field(metadata={"query": True}) + nonce: str | None = field(metadata={"query": True}) + code_challenge_method: str | None = field(metadata={"query": True}) + acr_values: tuple[str, ...] | None = field(metadata={"query": True}) + max_age: int | None = field(metadata={"query": True}) + kwargs: dict[str, Any] + + code_verifier: str | None + authorization_response_iss_parameter_supported: bool + issuer: str | None exception_classes: ClassVar[dict[str, type[AuthorizationResponseError]]] = { "interaction_required": InteractionRequired, @@ -434,11 +435,9 @@ def __init__( # noqa: PLR0913, C901 ) raise ValueError(msg) - code_challenge: str | None = None if code_challenge_method: if not code_verifier: code_verifier = PkceUtils.generate_code_verifier() - code_challenge = PkceUtils.derive_challenge(code_verifier, code_challenge_method) else: code_verifier = None @@ -458,7 +457,13 @@ def __init__( # noqa: PLR0913, C901 authorization_response_iss_parameter_supported=authorization_response_iss_parameter_supported, kwargs=kwargs, ) - object.__setattr__(self, "code_challenge", code_challenge) + + @cached_property + def code_challenge(self) -> str | None: + """The `code_challenge` that matches `code_verifier` and `code_challenge_method`.""" + if self.code_verifier and self.code_challenge_method: + return PkceUtils.derive_challenge(self.code_verifier, self.code_challenge_method) + return None def as_dict(self) -> dict[str, Any]: """Return the full argument dict. @@ -468,7 +473,6 @@ def as_dict(self) -> dict[str, Any]: """ d = asdict(self) d.update(**d.pop("kwargs", {})) - d.pop("code_challenge") return d @property @@ -482,6 +486,7 @@ def args(self) -> dict[str, Any]: d = {field.name: getattr(self, field.name) for field in fields(type(self)) if field.metadata.get("query")} if d["scope"]: d["scope"] = " ".join(d["scope"]) + d["code_challenge"] = self.code_challenge d.update(self.kwargs) return {key: val for key, val in d.items() if val is not None} @@ -763,8 +768,8 @@ class RequestParameterAuthorizationRequest: authorization_endpoint: str client_id: str request: Jwt - expires_at: datetime | None = None - kwargs: dict[str, Any] = Factory(dict) + expires_at: datetime | None + kwargs: dict[str, Any] @accepts_expires_in def __init__( @@ -829,8 +834,8 @@ class RequestUriParameterAuthorizationRequest: authorization_endpoint: str client_id: str request_uri: str - expires_at: datetime | None = None - kwargs: dict[str, Any] = Factory(dict) + expires_at: datetime | None + kwargs: dict[str, Any] @accepts_expires_in def __init__( @@ -903,7 +908,6 @@ def default_dumper(azr: AuthorizationRequest) -> str: """ d = asdict(azr) d.update(**d.pop("kwargs", {})) - d.pop("code_challenge") return BinaPy.serialize_to("json", d).to("deflate").to("b64u").ascii() @staticmethod diff --git a/requests_oauth2client/client.py b/requests_oauth2client/client.py index 0ede2d0..8fe319c 100644 --- a/requests_oauth2client/client.py +++ b/requests_oauth2client/client.py @@ -175,7 +175,7 @@ class Endpoints(str, Enum): AUTHORIZATION = "authorization_endpoint" BACKCHANNEL_AUTHENTICATION = "backchannel_authentication_endpoint" DEVICE_AUTHORIZATION = "device_authorization_endpoint" - INSTROSPECTION = "introspection_endpoint" + INTROSPECTION = "introspection_endpoint" REVOCATION = "revocation_endpoint" PUSHED_AUTHORIZATION_REQUEST = "pushed_authorization_request_endpoint" JWKS = "jwks_uri" @@ -283,7 +283,7 @@ class OAuth2Client: """ - auth: requests.auth.AuthBase = field(converter=client_auth_factory) + auth: requests.auth.AuthBase token_endpoint: str = field() revocation_endpoint: str | None = field() introspection_endpoint: str | None = field() @@ -296,14 +296,14 @@ class OAuth2Client: jwks_uri: str | None = field() authorization_server_jwks: JwkSet issuer: str | None = field() - id_token_signed_response_alg: str | None = SignatureAlgs.RS256 - id_token_encrypted_response_alg: str | None = None - id_token_decryption_key: Jwk | None = None - code_challenge_method: str | None = CodeChallengeMethods.S256 - authorization_response_iss_parameter_supported: bool = False - session: requests.Session = field(factory=requests.Session) - extra_metadata: dict[str, Any] = field(factory=dict) - testing: bool = False + id_token_signed_response_alg: str | None + id_token_encrypted_response_alg: str | None + id_token_decryption_key: Jwk | None + code_challenge_method: str | None + authorization_response_iss_parameter_supported: bool + session: requests.Session + extra_metadata: dict[str, Any] + testing: bool token_class: type[BearerToken] = BearerToken @@ -1346,7 +1346,7 @@ def introspect_token( data["token_type_hint"] = token_type_hint return self._request( - Endpoints.INSTROSPECTION, + Endpoints.INTROSPECTION, data=data, auth=self.auth, on_success=self.parse_introspection_response, @@ -1831,7 +1831,7 @@ def from_discovery_document( raise InvalidDiscoveryDocument(msg, discovery) authorization_endpoint = discovery.get(Endpoints.AUTHORIZATION) revocation_endpoint = discovery.get(Endpoints.REVOCATION) - introspection_endpoint = discovery.get(Endpoints.INSTROSPECTION) + introspection_endpoint = discovery.get(Endpoints.INTROSPECTION) userinfo_endpoint = discovery.get(Endpoints.USER_INFO) jwks_uri = discovery.get(Endpoints.JWKS) if jwks_uri is not None and not testing: diff --git a/requests_oauth2client/client_authentication.py b/requests_oauth2client/client_authentication.py index d2b78c9..44263ca 100644 --- a/requests_oauth2client/client_authentication.py +++ b/requests_oauth2client/client_authentication.py @@ -14,7 +14,7 @@ from uuid import uuid4 import requests -from attr import field, frozen +from attrs import frozen from binapy import BinaPy from jwskate import Jwk, Jwt, SignatureAlgs, SymmetricJwk, to_jwk @@ -170,6 +170,7 @@ class BaseClientAssertionAuthenticationMethod(BaseClientAuthenticationMethod): lifetime: int jti_gen: Callable[[], str] aud: str | None + alg: str | None def client_assertion(self, audience: str) -> str: """Generate a Client Assertion for a specific audience. @@ -236,7 +237,6 @@ class ClientSecretJwt(BaseClientAssertionAuthenticationMethod): """ client_secret: str - alg: str def __init__( self, @@ -339,8 +339,7 @@ class PrivateKeyJwt(BaseClientAssertionAuthenticationMethod): """ - private_jwk: Jwk = field(converter=to_jwk) - alg: str | None + private_jwk: Jwk def __init__( self, @@ -352,29 +351,31 @@ def __init__( jti_gen: Callable[[], str] = lambda: str(uuid4()), aud: str | None = None, ) -> None: - self.__attrs_init__( - client_id=client_id, - private_jwk=private_jwk, - alg=alg, - lifetime=lifetime, - jti_gen=jti_gen, - aud=aud, - ) + private_jwk = to_jwk(private_jwk) - alg = self.private_jwk.alg or alg + alg = private_jwk.alg or alg if not alg: raise InvalidClientAssertionSigningKeyOrAlg(alg) - if alg not in self.private_jwk.supported_signing_algorithms(): + if alg not in private_jwk.supported_signing_algorithms(): raise InvalidClientAssertionSigningKeyOrAlg(alg) - if not self.private_jwk.is_private or self.private_jwk.is_symmetric: + if not private_jwk.is_private or private_jwk.is_symmetric: raise InvalidClientAssertionSigningKeyOrAlg(alg) - kid = self.private_jwk.get("kid") + kid = private_jwk.get("kid") if not kid: raise InvalidClientAssertionSigningKeyOrAlg(alg) + self.__attrs_init__( + client_id=client_id, + private_jwk=private_jwk, + alg=alg, + lifetime=lifetime, + jti_gen=jti_gen, + aud=aud, + ) + def client_assertion(self, audience: str) -> str: """Generate a Client Assertion, asymmetrically signed with `private_jwk` as key. @@ -481,6 +482,10 @@ def client_auth_factory( an Auth Handler that will manage client authentication to the AS Token Endpoint or other backend endpoints. + Raises: + UnsupportedClientCredentials: if the provided parameters are not suitable to guess the + desired authentication method. + """ if auth is not None and (client_id is not None or client_secret is not None or private_key is not None): msg = """\ diff --git a/requests_oauth2client/pooling.py b/requests_oauth2client/pooling.py index 75d3b36..1beb9d0 100644 --- a/requests_oauth2client/pooling.py +++ b/requests_oauth2client/pooling.py @@ -5,7 +5,7 @@ import time from typing import TYPE_CHECKING, Any -from attrs import define +from attrs import define, field, setters from .exceptions import AuthorizationPending, SlowDown @@ -26,11 +26,11 @@ class BaseTokenEndpointPoolingJob: """ - client: OAuth2Client - requests_kwargs: dict[str, Any] - token_kwargs: dict[str, Any] + client: OAuth2Client = field(on_setattr=setters.frozen) + requests_kwargs: dict[str, Any] = field(on_setattr=setters.frozen) + token_kwargs: dict[str, Any] = field(on_setattr=setters.frozen) + slow_down_interval: int = field(on_setattr=setters.frozen) interval: int - slow_down_interval: int def __call__(self) -> BearerToken | None: """Wrap the actual Token Endpoint call with a pooling interval. diff --git a/requests_oauth2client/tokens.py b/requests_oauth2client/tokens.py index fc72234..4387acb 100644 --- a/requests_oauth2client/tokens.py +++ b/requests_oauth2client/tokens.py @@ -9,7 +9,7 @@ import jwskate import requests -from attrs import Factory, asdict, frozen +from attrs import asdict, frozen from binapy import BinaPy from typing_extensions import Self @@ -251,12 +251,12 @@ class BearerToken(TokenResponse, requests.auth.AuthBase): AUTHORIZATION_HEADER: ClassVar[str] = "Authorization" access_token: str - expires_at: datetime | None = None - scope: str | None = None - refresh_token: str | None = None - token_type: str = TOKEN_TYPE - id_token: IdToken | jwskate.JweCompact | None = None - kwargs: dict[str, Any] = Factory(dict) + expires_at: datetime | None + scope: str | None + refresh_token: str | None + token_type: str + id_token: IdToken | jwskate.JweCompact | None + kwargs: dict[str, Any] @accepts_expires_in def __init__( diff --git a/tests/unit_tests/conftest.py b/tests/unit_tests/conftest.py index b9779cb..0520dc8 100644 --- a/tests/unit_tests/conftest.py +++ b/tests/unit_tests/conftest.py @@ -368,7 +368,6 @@ def code_challenge_method(request: FixtureRequest) -> str | None: @pytest.fixture(scope="session") -@pytest.mark.slow def authorization_request( # noqa: C901 authorization_endpoint: str, client_id: str, diff --git a/tests/unit_tests/test_api_client.py b/tests/unit_tests/test_api_client.py index 0c40399..b251ec4 100644 --- a/tests/unit_tests/test_api_client.py +++ b/tests/unit_tests/test_api_client.py @@ -288,7 +288,7 @@ def test_bool_fields(requests_mock: RequestsMocker, target_api: str) -> None: assert requests_mock.last_request.query == "foo=bar&true=1&false=0" assert requests_mock.last_request.text == "foo=bar&true=1&false=0" - with pytest.raises(ValueError, match="Invalid value for 'bool_fields'") as exc: + with pytest.raises(ValueError, match="Invalid value for `bool_fields`") as exc: ApiClient(target_api).get(bool_fields=(1, 2, 3)) assert exc.type is InvalidBoolFieldsParam diff --git a/tests/unit_tests/test_authorization_request.py b/tests/unit_tests/test_authorization_request.py index ba3681e..596ab60 100644 --- a/tests/unit_tests/test_authorization_request.py +++ b/tests/unit_tests/test_authorization_request.py @@ -115,7 +115,7 @@ def test_request_uri_authorization_request_with_custom_param(authorization_endpo ) assert isinstance(request_uri_azr.uri, str) url = request_uri_azr.furl - assert url.origin + url.pathstr == authorization_endpoint + assert url.origin + str(url.path) == authorization_endpoint assert url.args == {"client_id": client_id, "request_uri": request_uri, "custom_attr": custom_attr} diff --git a/tests/unit_tests/test_backchannel_authentication.py b/tests/unit_tests/test_backchannel_authentication.py index c30e836..a3afe03 100644 --- a/tests/unit_tests/test_backchannel_authentication.py +++ b/tests/unit_tests/test_backchannel_authentication.py @@ -285,9 +285,11 @@ def test_pooling_job( ) requests_mock.post(token_endpoint, status_code=401, json={"error": "authorization_pending"}) - with mocker.patch("time.sleep"): - assert job() is None + mocker.patch("time.sleep") + + assert job() is None time.sleep.assert_called_once_with(job.interval) # type: ignore[attr-defined] + time.sleep.reset_mock() # type: ignore[attr-defined] assert requests_mock.called_once assert job.interval == interval @@ -296,9 +298,10 @@ def test_pooling_job( freezer.tick(job.interval) requests_mock.reset_mock() requests_mock.post(token_endpoint, status_code=401, json={"error": "slow_down"}) - with mocker.patch("time.sleep"): - assert job() is None + + assert job() is None time.sleep.assert_called_once_with(interval) # type: ignore[attr-defined] + time.sleep.reset_mock() # type: ignore[attr-defined] assert requests_mock.called_once assert job.interval == interval + job.slow_down_interval ciba_request_validator(requests_mock.last_request, auth_req_id=auth_req_id) @@ -306,9 +309,10 @@ def test_pooling_job( freezer.tick(job.interval) requests_mock.reset_mock() requests_mock.post(token_endpoint, json={"access_token": access_token}) - with mocker.patch("time.sleep"): - token = job() + + token = job() time.sleep.assert_called_once_with(interval + job.slow_down_interval) # type: ignore[attr-defined] + time.sleep.reset_mock() # type: ignore[attr-defined] assert requests_mock.called_once assert job.interval == interval + job.slow_down_interval ciba_request_validator(requests_mock.last_request, auth_req_id=auth_req_id) diff --git a/tests/unit_tests/test_oidc.py b/tests/unit_tests/test_oidc.py index e5b9c1e..a55e589 100644 --- a/tests/unit_tests/test_oidc.py +++ b/tests/unit_tests/test_oidc.py @@ -127,7 +127,7 @@ def test_invalid_id_token(token_endpoint: str) -> None: sig_jwk = Jwk.generate(alg=SignatureAlgs.RS256).with_kid_thumbprint() enc_jwk = Jwk.generate(alg=KeyManagementAlgs.ECDH_ES_A256KW).with_kid_thumbprint() - as_jwks = sig_jwk.as_jwks() + as_jwks = sig_jwk.public_jwk().as_jwks() issuer = "http://issuer.local" client_id = "my_client_id"