diff --git a/sdk/core/azure-core/CHANGELOG.md b/sdk/core/azure-core/CHANGELOG.md index 2abf318002b8..470f32b7b067 100644 --- a/sdk/core/azure-core/CHANGELOG.md +++ b/sdk/core/azure-core/CHANGELOG.md @@ -1,9 +1,11 @@ # Release History -## 1.31.1 (Unreleased) +## 1.32.0 (Unreleased) ### Features Added +- Added a default implementation to handle token challenges in `BearerTokenCredentialPolicy` and `AsyncBearerTokenCredentialPolicy`. + ### Breaking Changes ### Bugs Fixed diff --git a/sdk/core/azure-core/azure/core/_version.py b/sdk/core/azure-core/azure/core/_version.py index 10fcd28a3fcf..1c43dbb9b140 100644 --- a/sdk/core/azure-core/azure/core/_version.py +++ b/sdk/core/azure-core/azure/core/_version.py @@ -9,4 +9,4 @@ # regenerated. # -------------------------------------------------------------------------- -VERSION = "1.31.1" +VERSION = "1.32.0" diff --git a/sdk/core/azure-core/azure/core/pipeline/policies/_authentication.py b/sdk/core/azure-core/azure/core/pipeline/policies/_authentication.py index dc3e23de37c8..537270038eee 100644 --- a/sdk/core/azure-core/azure/core/pipeline/policies/_authentication.py +++ b/sdk/core/azure-core/azure/core/pipeline/policies/_authentication.py @@ -4,6 +4,7 @@ # license information. # ------------------------------------------------------------------------- import time +import base64 from typing import TYPE_CHECKING, Optional, TypeVar, MutableMapping, Any, Union, cast from azure.core.credentials import ( TokenCredential, @@ -19,6 +20,7 @@ from azure.core.rest import HttpResponse, HttpRequest from . import HTTPPolicy, SansIOHTTPPolicy from ...exceptions import ServiceRequestError +from ._utils import get_challenge_parameter if TYPE_CHECKING: @@ -82,13 +84,7 @@ def _need_new_token(self) -> bool: refresh_on = getattr(self._token, "refresh_on", None) return not self._token or (refresh_on and refresh_on <= now) or self._token.expires_on - now < 300 - def _request_token(self, *scopes: str, **kwargs: Any) -> None: - """Request a new token from the credential. - - This will call the credential's appropriate method to get a token and store it in the policy. - - :param str scopes: The type of access needed. - """ + def _get_token(self, *scopes: str, **kwargs: Any) -> Union["AccessToken", "AccessTokenInfo"]: if self._enable_cae: kwargs.setdefault("enable_cae", self._enable_cae) @@ -99,9 +95,17 @@ def _request_token(self, *scopes: str, **kwargs: Any) -> None: if key in TokenRequestOptions.__annotations__: # pylint: disable=no-member options[key] = kwargs.pop(key) # type: ignore[literal-required] - self._token = cast(SupportsTokenInfo, self._credential).get_token_info(*scopes, options=options) - else: - self._token = cast(TokenCredential, self._credential).get_token(*scopes, **kwargs) + return cast(SupportsTokenInfo, self._credential).get_token_info(*scopes, options=options) + return cast(TokenCredential, self._credential).get_token(*scopes, **kwargs) + + def _request_token(self, *scopes: str, **kwargs: Any) -> None: + """Request a new token from the credential. + + This will call the credential's appropriate method to get a token and store it in the policy. + + :param str scopes: The type of access needed. + """ + self._token = self._get_token(*scopes, **kwargs) class BearerTokenCredentialPolicy(_BearerTokenCredentialPolicyBase, HTTPPolicy[HTTPRequestType, HTTPResponseType]): @@ -191,6 +195,22 @@ def on_challenge( :rtype: bool """ # pylint:disable=unused-argument + headers = response.http_response.headers + error = get_challenge_parameter(headers, "Bearer", "error") + if error == "insufficient_claims": + encoded_claims = get_challenge_parameter(headers, "Bearer", "claims") + if not encoded_claims: + return False + try: + padding_needed = -len(encoded_claims) % 4 + claims = base64.urlsafe_b64decode(encoded_claims + "=" * padding_needed).decode("utf-8") + if claims: + token = self._get_token(*self._scopes, claims=claims) + bearer_token = cast(Union["AccessToken", "AccessTokenInfo"], token).token + request.http_request.headers["Authorization"] = "Bearer " + bearer_token + return True + except Exception: # pylint:disable=broad-except + return False return False def on_response( diff --git a/sdk/core/azure-core/azure/core/pipeline/policies/_authentication_async.py b/sdk/core/azure-core/azure/core/pipeline/policies/_authentication_async.py index 7fb68a606a39..f97b8df3b7b2 100644 --- a/sdk/core/azure-core/azure/core/pipeline/policies/_authentication_async.py +++ b/sdk/core/azure-core/azure/core/pipeline/policies/_authentication_async.py @@ -4,6 +4,7 @@ # license information. # ------------------------------------------------------------------------- import time +import base64 from typing import Any, Awaitable, Optional, cast, TypeVar, Union from azure.core.credentials import AccessToken, AccessTokenInfo, TokenRequestOptions @@ -23,6 +24,7 @@ ) from azure.core.rest import AsyncHttpResponse, HttpRequest from azure.core.utils._utils import get_running_async_lock +from ._utils import get_challenge_parameter from .._tools_async import await_result @@ -138,6 +140,22 @@ async def on_challenge( :rtype: bool """ # pylint:disable=unused-argument + headers = response.http_response.headers + error = get_challenge_parameter(headers, "Bearer", "error") + if error == "insufficient_claims": + encoded_claims = get_challenge_parameter(headers, "Bearer", "claims") + if not encoded_claims: + return False + try: + padding_needed = -len(encoded_claims) % 4 + claims = base64.urlsafe_b64decode(encoded_claims + "=" * padding_needed).decode("utf-8") + if claims: + token = await self._get_token(*self._scopes, claims=claims) + bearer_token = cast(Union["AccessToken", "AccessTokenInfo"], token).token + request.http_request.headers["Authorization"] = "Bearer " + bearer_token + return True + except Exception: # pylint:disable=broad-except + return False return False def on_response( @@ -169,13 +187,7 @@ def _need_new_token(self) -> bool: refresh_on = getattr(self._token, "refresh_on", None) return not self._token or (refresh_on and refresh_on <= now) or self._token.expires_on - now < 300 - async def _request_token(self, *scopes: str, **kwargs: Any) -> None: - """Request a new token from the credential. - - This will call the credential's appropriate method to get a token and store it in the policy. - - :param str scopes: The type of access needed. - """ + async def _get_token(self, *scopes: str, **kwargs: Any) -> Union["AccessToken", "AccessTokenInfo"]: if self._enable_cae: kwargs.setdefault("enable_cae", self._enable_cae) @@ -186,14 +198,22 @@ async def _request_token(self, *scopes: str, **kwargs: Any) -> None: if key in TokenRequestOptions.__annotations__: # pylint: disable=no-member options[key] = kwargs.pop(key) # type: ignore[literal-required] - self._token = await await_result( + return await await_result( cast(AsyncSupportsTokenInfo, self._credential).get_token_info, *scopes, options=options, ) - else: - self._token = await await_result( - cast(AsyncTokenCredential, self._credential).get_token, - *scopes, - **kwargs, - ) + return await await_result( + cast(AsyncTokenCredential, self._credential).get_token, + *scopes, + **kwargs, + ) + + async def _request_token(self, *scopes: str, **kwargs: Any) -> None: + """Request a new token from the credential. + + This will call the credential's appropriate method to get a token and store it in the policy. + + :param str scopes: The type of access needed. + """ + self._token = await self._get_token(*scopes, **kwargs) diff --git a/sdk/core/azure-core/azure/core/pipeline/policies/_utils.py b/sdk/core/azure-core/azure/core/pipeline/policies/_utils.py index 1733632a9ab2..dce2c45bc5a3 100644 --- a/sdk/core/azure-core/azure/core/pipeline/policies/_utils.py +++ b/sdk/core/azure-core/azure/core/pipeline/policies/_utils.py @@ -25,7 +25,7 @@ # -------------------------------------------------------------------------- import datetime import email.utils -from typing import Optional, cast, Union +from typing import Optional, cast, Union, Tuple from urllib.parse import urlparse from azure.core.pipeline.transport import ( @@ -102,3 +102,103 @@ def get_domain(url: str) -> str: :return: The domain of the url. """ return str(urlparse(url).netloc).lower() + + +def get_challenge_parameter(headers, challenge_scheme: str, challenge_parameter: str) -> Optional[str]: + """ + Parses the specified parameter from a challenge header found in the response. + + :param dict[str, str] headers: The response headers to parse. + :param str challenge_scheme: The challenge scheme containing the challenge parameter, e.g., "Bearer". + :param str challenge_parameter: The parameter key name to search for. + :return: The value of the parameter name if found. + :rtype: str or None + """ + header_value = headers.get("WWW-Authenticate") + if not header_value: + return None + + scheme = challenge_scheme + parameter = challenge_parameter + header_span = header_value + + # Iterate through each challenge value. + while True: + challenge = get_next_challenge(header_span) + if not challenge: + break + challenge_key, header_span = challenge + if challenge_key.lower() != scheme.lower(): + continue + # Enumerate each key-value parameter until we find the parameter key on the specified scheme challenge. + while True: + parameters = get_next_parameter(header_span) + if not parameters: + break + key, value, header_span = parameters + if key.lower() == parameter.lower(): + return value + + return None + + +def get_next_challenge(header_value: str) -> Optional[Tuple[str, str]]: + """ + Iterates through the challenge schemes present in a challenge header. + + :param str header_value: The header value which will be sliced to remove the first parsed challenge key. + :return: The parsed challenge scheme and the remaining header value. + :rtype: tuple[str, str] or None + """ + header_value = header_value.lstrip(" ") + end_of_challenge_key = header_value.find(" ") + + if end_of_challenge_key < 0: + return None + + challenge_key = header_value[:end_of_challenge_key] + header_value = header_value[end_of_challenge_key + 1 :] + + return challenge_key, header_value + + +def get_next_parameter(header_value: str, separator: str = "=") -> Optional[Tuple[str, str, str]]: + """ + Iterates through a challenge header value to extract key-value parameters. + + :param str header_value: The header value after being parsed by get_next_challenge. + :param str separator: The challenge parameter key-value pair separator, default is '='. + :return: The next available challenge parameter as a tuple (param_key, param_value, remaining header_value). + :rtype: tuple[str, str, str] or None + """ + space_or_comma = " ," + header_value = header_value.lstrip(space_or_comma) + + next_space = header_value.find(" ") + next_separator = header_value.find(separator) + + if next_space < next_separator and next_space != -1: + return None + + if next_separator < 0: + return None + + param_key = header_value[:next_separator].strip() + header_value = header_value[next_separator + 1 :] + + quote_index = header_value.find('"') + + if quote_index >= 0: + header_value = header_value[quote_index + 1 :] + param_value = header_value[: header_value.find('"')] + else: + trailing_delimiter_index = header_value.find(" ") + if trailing_delimiter_index >= 0: + param_value = header_value[:trailing_delimiter_index] + else: + param_value = header_value + + if header_value != param_value: + header_value = header_value[len(param_value) + 1 :] + + return param_key, param_value, header_value diff --git a/sdk/core/azure-core/tests/test_utils.py b/sdk/core/azure-core/tests/test_utils.py index c09b48c9c5c5..015557dbec8e 100644 --- a/sdk/core/azure-core/tests/test_utils.py +++ b/sdk/core/azure-core/tests/test_utils.py @@ -8,7 +8,7 @@ import pytest from azure.core.utils import case_insensitive_dict from azure.core.utils._utils import get_running_async_lock -from azure.core.pipeline.policies._utils import parse_retry_after +from azure.core.pipeline.policies._utils import parse_retry_after, get_challenge_parameter @pytest.fixture() @@ -146,3 +146,58 @@ def test_parse_retry_after(): assert ret == 0 ret = parse_retry_after("0.9") assert ret == 0.9 + + +def test_get_challenge_parameter(): + headers = { + "WWW-Authenticate": 'Bearer authorization_uri="https://login.microsoftonline.com/tenant-id", resource="https://vault.azure.net"' + } + assert ( + get_challenge_parameter(headers, "Bearer", "authorization_uri") == "https://login.microsoftonline.com/tenant-id" + ) + assert get_challenge_parameter(headers, "Bearer", "resource") == "https://vault.azure.net" + assert get_challenge_parameter(headers, "Bearer", "foo") is None + + headers = { + "WWW-Authenticate": 'Bearer realm="", authorization_uri="https://login.microsoftonline.com/common/oauth2/authorize", error="insufficient_claims", claims="eyJhY2Nlc3NfdG9rZW4iOnsibmJmIjp7ImVzc2VudGlhbCI6dHJ1ZSwidmFsdWUiOiIxNzI2MDc3NTk1In0sInhtc19jYWVlcnJvciI6eyJ2YWx1ZSI6IjEwMDEyIn19fQ=="' + } + assert ( + get_challenge_parameter(headers, "Bearer", "authorization_uri") + == "https://login.microsoftonline.com/common/oauth2/authorize" + ) + assert get_challenge_parameter(headers, "Bearer", "error") == "insufficient_claims" + assert ( + get_challenge_parameter(headers, "Bearer", "claims") + == "eyJhY2Nlc3NfdG9rZW4iOnsibmJmIjp7ImVzc2VudGlhbCI6dHJ1ZSwidmFsdWUiOiIxNzI2MDc3NTk1In0sInhtc19jYWVlcnJvciI6eyJ2YWx1ZSI6IjEwMDEyIn19fQ==" + ) + + +def test_get_challenge_parameter_not_found(): + headers = { + "WWW-Authenticate": 'Pop authorization_uri="https://login.microsoftonline.com/tenant-id", resource="https://vault.azure.net"' + } + assert get_challenge_parameter(headers, "Bearer", "resource") is None + + +def test_get_multi_challenge_parameter(): + headers = { + "WWW-Authenticate": 'Bearer authorization_uri="https://login.microsoftonline.com/tenant-id", resource="https://vault.azure.net" Bearer authorization_uri="https://login.microsoftonline.com/tenant-id", resource="https://vault.azure.net"' + } + assert ( + get_challenge_parameter(headers, "Bearer", "authorization_uri") == "https://login.microsoftonline.com/tenant-id" + ) + assert get_challenge_parameter(headers, "Bearer", "resource") == "https://vault.azure.net" + assert get_challenge_parameter(headers, "Bearer", "foo") is None + + headers = { + "WWW-Authenticate": 'Digest realm="foo@test.com", qop="auth,auth-int", nonce="123456abcdefg", opaque="123456", Bearer realm="", authorization_uri="https://login.microsoftonline.com/common/oauth2/authorize", error="insufficient_claims", claims="eyJhY2Nlc3NfdG9rZW4iOnsibmJmIjp7ImVzc2VudGlhbCI6dHJ1ZSwidmFsdWUiOiIxNzI2MDc3NTk1In0sInhtc19jYWVlcnJvciI6eyJ2YWx1ZSI6IjEwMDEyIn19fQ=="' + } + assert ( + get_challenge_parameter(headers, "Bearer", "authorization_uri") + == "https://login.microsoftonline.com/common/oauth2/authorize" + ) + assert get_challenge_parameter(headers, "Bearer", "error") == "insufficient_claims" + assert ( + get_challenge_parameter(headers, "Bearer", "claims") + == "eyJhY2Nlc3NfdG9rZW4iOnsibmJmIjp7ImVzc2VudGlhbCI6dHJ1ZSwidmFsdWUiOiIxNzI2MDc3NTk1In0sInhtc19jYWVlcnJvciI6eyJ2YWx1ZSI6IjEwMDEyIn19fQ==" + )