Skip to content

Commit

Permalink
chore/cleanups (#73)
Browse files Browse the repository at this point in the history
* remove default values from `attrs.frozen` classes with `init=False`, since those defaults are part of the init method.

* call `__attrs_init__` *after* checking the arguments in `PrivateKeyJwt` instead of *before*

* fix typos

* move common `alg` attribute from `ClientSecretJwt` and `PrivateKeyJwt` subclasses to `BaseClientAssertionAuthenticationMethod` parent class

* freeze all attributes except `interval` in `BaseTokenEndpointPoolingJob`

* add missing `Raises` section in `client_auth_factory` docstring

* use `OAuth2AccessTokenAuth` as base class for all OAuth2 auth handlers, remove `BaseOAuth2RenewableTokenAuth` and `BaseOAuth2RefreshTokenAuth` which are useless in practice.

* fix typo in `Endpoints.INSTROSPECTION` (extra S)

* turn `code_challenge` into a `cached_property` in `AuthorizationRequest`

* fix warnings in tests
  • Loading branch information
guillp committed Sep 19, 2024
1 parent ab12289 commit 8d0ca66
Show file tree
Hide file tree
Showing 13 changed files with 154 additions and 174 deletions.
2 changes: 0 additions & 2 deletions requests_oauth2client/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@

from .api_client import ApiClient, InvalidBoolFieldsParam, InvalidPathParam
from .auth import (
BaseOAuth2RenewableTokenAuth,
NonRenewableTokenError,
OAuth2AccessTokenAuth,
OAuth2AuthorizationCodeAuth,
Expand Down Expand Up @@ -152,7 +151,6 @@
"BackChannelAuthenticationPoolingJob",
"BackChannelAuthenticationResponse",
"BaseClientAuthenticationMethod",
"BaseOAuth2RenewableTokenAuth",
"BaseTokenEndpointPoolingJob",
"BearerToken",
"BearerTokenSerializer",
Expand Down
24 changes: 12 additions & 12 deletions requests_oauth2client/api_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from urllib.parse import urljoin

import requests
from attrs import field, frozen
from attrs import frozen
from typing_extensions import Literal, Self

if TYPE_CHECKING:
Expand All @@ -21,10 +21,10 @@ class InvalidBoolFieldsParam(ValueError):

def __init__(self, bool_fields: object) -> None:
super().__init__("""\
Invalid value for 'bool_fields' parameter. It must be an iterable of 2 str values:
- first one for the True value
- second one for the False value
boolean fields in `data` or `params` with a boolean value (`True` or `False`)
Invalid value for `bool_fields` parameter. It must be an iterable of 2 `str` values:
- first one for the `True` value,
- second one for the `False` value.
Boolean fields in `data` or `params` with a boolean value (`True` or `False`)
will be serialized to the corresponding value.
Default is `('true', 'false')`
Use this parameter when the target API expects some other values, e.g.:
Expand All @@ -36,7 +36,7 @@ def __init__(self, bool_fields: object) -> None:


def validate_bool_fields(bool_fields: tuple[str, str]) -> tuple[str, str]:
"""Validate the `bool_fields` paremeter.
"""Validate the `bool_fields` parameter.
It must be a sequence of 2 values. First one is the `True` value, second one is the `False` value.
Both must be `str` or string-able values.
Expand Down Expand Up @@ -135,12 +135,12 @@ class ApiClient:
"""

base_url: str
auth: requests.auth.AuthBase | None = None
timeout: int | None = 60
raise_for_status: bool = True
none_fields: Literal["include", "exclude", "empty"] = "exclude"
bool_fields: tuple[Any, Any] | None = "true", "false"
session: requests.Session = field(factory=requests.Session)
auth: requests.auth.AuthBase | None
timeout: int | None
raise_for_status: bool
none_fields: Literal["include", "exclude", "empty"]
bool_fields: tuple[Any, Any] | None
session: requests.Session

def __init__(
self,
Expand Down
110 changes: 40 additions & 70 deletions requests_oauth2client/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,22 +21,53 @@ class NonRenewableTokenError(Exception):


@define(init=False)
class BaseOAuth2RenewableTokenAuth(requests.auth.AuthBase):
"""Base class for BearerToken-based Auth Handlers, with an obtainable or renewable token.
class OAuth2AccessTokenAuth(requests.auth.AuthBase):
"""Authentication Handler for OAuth 2.0 Access Tokens and (optional) Refresh Tokens.
This [Requests Auth handler][requests.auth.AuthBase] implementation uses an access token as
Bearer or DPoP token, and can automatically refresh it when expired, if a refresh token is available.
Token can be a simple `str` containing a raw access token value, or a
[BearerToken][requests_oauth2client.tokens.BearerToken] that can contain a `refresh_token`.
In addition to adding a properly formatted `Authorization` header, this will obtain a new token
once the current token is expired. Expiration is detected based on the `expires_in` hint
returned by the AS. A configurable `leeway`, in number of seconds, will make sure that a new
token is obtained some seconds before the actual expiration is reached. This may help in
situations where the client, AS and RS have slightly offset clocks.
Args:
client: the client to use to refresh tokens.
token: an initial Access Token, if you have one already. In most cases, leave `None`.
leeway: expiration leeway, in number of seconds.
**token_kwargs: additional kwargs to pass to the token endpoint.
Example:
```python
from requests_oauth2client import BearerToken, OAuth2Client, OAuth2AccessTokenAuth, requests
client = OAuth2Client(token_endpoint="https://my.as.local/token", auth=("client_id", "client_secret"))
# obtain a BearerToken any way you see fit, optionally including a refresh token
# for this example, the token value is hardcoded
token = BearerToken(access_token="access_token", expires_in=600, refresh_token="refresh_token")
auth = OAuth2AccessTokenAuth(client, token, scope="my_scope")
resp = requests.post("https://my.api.local/resource", auth=auth)
```
"""

client: OAuth2Client = field(on_setattr=setters.frozen)
token: BearerToken | None
leeway: int = field(on_setattr=setters.frozen)
token_kwargs: dict[str, Any] = field(on_setattr=setters.frozen)

def __init__(
self, client: OAuth2Client, token: str | BearerToken, *, leeway: int = 20, **token_kwargs: Any
) -> None:
if isinstance(token, str):
token = BearerToken(token)
self.__attrs_init__(client=client, token=token, leeway=leeway, token_kwargs=token_kwargs)

def __call__(self, request: requests.PreparedRequest) -> requests.PreparedRequest:
"""Add the Access Token to the request.
Expand All @@ -55,40 +86,19 @@ def __call__(self, request: requests.PreparedRequest) -> requests.PreparedReques
def renew_token(self) -> None:
"""Obtain a new Bearer Token.
Subclasses should implement this.
This will try to use the `refresh_token`, if there is one.
"""
raise NotImplementedError
if self.token is not None and self.token.refresh_token is not None:
self.token = self.client.refresh_token(refresh_token=self.token, **self.token_kwargs)

def forget_token(self) -> None:
"""Forget the current token, forcing a renewal on the next HTTP request."""
self.token = None


@define(init=False)
class BaseOAuth2RefreshTokenAuth(BaseOAuth2RenewableTokenAuth):
"""Base class for flows which can have a refresh-token.
This implements a `renew_token()` method which uses the refresh token to obtain new tokens.
"""

@override
def renew_token(self) -> None:
"""Obtain a new token, using the Refresh Token, if available.
Raises:
NonRenewableTokenError: if the token is not renewable.
"""
if self.token is None or self.token.refresh_token is None:
raise NonRenewableTokenError

self.token = self.client.refresh_token(refresh_token=self.token, **self.token_kwargs)


@define(init=False)
class OAuth2ClientCredentialsAuth(BaseOAuth2RenewableTokenAuth):
class OAuth2ClientCredentialsAuth(OAuth2AccessTokenAuth):
"""An Auth Handler for the [Client Credentials grant](https://www.rfc-editor.org/rfc/rfc6749#section-4.4).
This [requests AuthBase][requests.auth.AuthBase] automatically gets Access Tokens from an OAuth
Expand Down Expand Up @@ -126,47 +136,7 @@ def renew_token(self) -> None:


@define(init=False)
class OAuth2AccessTokenAuth(BaseOAuth2RefreshTokenAuth):
"""Authentication Handler for OAuth 2.0 Access Tokens and (optional) Refresh Tokens.
This [Requests Auth handler][requests.auth.AuthBase] implementation uses an access token as
Bearer token, and can automatically refresh it when expired, if a refresh token is available.
Token can be a simple `str` containing a raw access token value, or a
[BearerToken][requests_oauth2client.tokens.BearerToken] that can contain a `refresh_token`.
If a `refresh_token` and an expiration date are available (based on `expires_in` hint),
this Auth Handler will automatically refresh the access token once it is expired.
Args:
client: the client to use to refresh tokens.
token: an initial Access Token, if you have one already. In most cases, leave `None`.
leeway: expiration leeway, in number of seconds.
**token_kwargs: additional kwargs to pass to the token endpoint.
Example:
```python
from requests_oauth2client import BearerToken, OAuth2Client, OAuth2AccessTokenAuth, requests
client = OAuth2Client(token_endpoint="https://my.as.local/token", auth=("client_id", "client_secret"))
# obtain a BearerToken any way you see fit, optionally including a refresh token
# for this example, the token value is hardcoded
token = BearerToken(access_token="access_token", expires_in=600, refresh_token="refresh_token")
auth = OAuth2AccessTokenAuth(client, token, scope="my_scope")
resp = requests.post("https://my.api.local/resource", auth=auth)
```
"""

def __init__(
self, client: OAuth2Client, token: str | BearerToken, *, leeway: int = 20, **token_kwargs: Any
) -> None:
if isinstance(token, str):
token = BearerToken(token)
self.__attrs_init__(client=client, token=token, leeway=leeway, token_kwargs=token_kwargs)


@define(init=False)
class OAuth2AuthorizationCodeAuth(BaseOAuth2RefreshTokenAuth): # type: ignore[override]
class OAuth2AuthorizationCodeAuth(OAuth2AccessTokenAuth): # type: ignore[override]
"""Authentication handler for the [Authorization Code grant](https://www.rfc-editor.org/rfc/rfc6749#section-4.1).
This [Requests Auth handler][requests.auth.AuthBase] implementation exchanges an Authorization
Expand Down Expand Up @@ -235,7 +205,7 @@ def exchange_code_for_token(self) -> None:


@define(init=False)
class OAuth2ResourceOwnerPasswordAuth(BaseOAuth2RenewableTokenAuth): # type: ignore[override]
class OAuth2ResourceOwnerPasswordAuth(OAuth2AccessTokenAuth): # type: ignore[override]
"""Authentication Handler for the [Resource Owner Password Credentials Flow](https://www.rfc-editor.org/rfc/rfc6749#section-4.3).
This [Requests Auth handler][requests.auth.AuthBase] implementation exchanges the user
Expand Down Expand Up @@ -313,7 +283,7 @@ def renew_token(self) -> None:


@define(init=False)
class OAuth2DeviceCodeAuth(BaseOAuth2RefreshTokenAuth): # type: ignore[override]
class OAuth2DeviceCodeAuth(OAuth2AccessTokenAuth): # type: ignore[override]
"""Authentication Handler for the [Device Code Flow](https://www.rfc-editor.org/rfc/rfc8628).
This [Requests Auth handler][requests.auth.AuthBase] implementation exchanges a Device Code for
Expand Down
Loading

0 comments on commit 8d0ca66

Please sign in to comment.