Skip to content

Commit

Permalink
feat: add experimental GDCH support
Browse files Browse the repository at this point in the history
  • Loading branch information
arithmetic1728 committed May 5, 2022
1 parent de1fd41 commit 95dfce8
Show file tree
Hide file tree
Showing 7 changed files with 574 additions and 19 deletions.
41 changes: 41 additions & 0 deletions google/auth/_default.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,11 +36,13 @@
_SERVICE_ACCOUNT_TYPE = "service_account"
_EXTERNAL_ACCOUNT_TYPE = "external_account"
_IMPERSONATED_SERVICE_ACCOUNT_TYPE = "impersonated_service_account"
_GDCH_SERVICE_ACCOUNT_TYPE = "gdch_service_account"
_VALID_TYPES = (
_AUTHORIZED_USER_TYPE,
_SERVICE_ACCOUNT_TYPE,
_EXTERNAL_ACCOUNT_TYPE,
_IMPERSONATED_SERVICE_ACCOUNT_TYPE,
_GDCH_SERVICE_ACCOUNT_TYPE,
)

# Help message when no credentials can be found.
Expand Down Expand Up @@ -158,6 +160,8 @@ def _load_credentials_from_info(
credentials, project_id = _get_impersonated_service_account_credentials(
filename, info, scopes
)
elif credential_type == _GDCH_SERVICE_ACCOUNT_TYPE:
credentials, project_id = _get_gdch_service_account_credentials(info)
else:
raise exceptions.DefaultCredentialsError(
"The file {file} does not have a valid type. "
Expand Down Expand Up @@ -421,6 +425,36 @@ def _get_impersonated_service_account_credentials(filename, info, scopes):
return credentials, None


def _get_gdch_service_account_credentials(info):
from google.oauth2 import gdch_credentials

k8s_ca_cert_path = info.get("k8s_ca_cert_path")
k8s_cert_path = info.get("k8s_cert_path")
k8s_key_path = info.get("k8s_key_path")
k8s_token_endpoint = info.get("k8s_token_endpoint")
ais_ca_cert_path = info.get("ais_ca_cert_path")
ais_token_endpoint = info.get("ais_token_endpoint")

format_version = info.get("format_version")
if format_version != "v1":
raise exceptions.DefaultCredentialsError(
"format_version is not provided or unsupported. Supported version is: v1"
)

return (
gdch_credentials.ServiceAccountCredentials(
k8s_ca_cert_path,
k8s_cert_path,
k8s_key_path,
k8s_token_endpoint,
ais_ca_cert_path,
ais_token_endpoint,
None,
),
None,
)


def _apply_quota_project_id(credentials, quota_project_id):
if quota_project_id:
credentials = credentials.with_quota_project(quota_project_id)
Expand Down Expand Up @@ -456,6 +490,11 @@ def default(scopes=None, request=None, quota_project_id=None, default_scopes=Non
endpoint.
The project ID returned in this case is the one corresponding to the
underlying workload identity pool resource if determinable.
If the environment variable is set to the path of a valid GDCH service
account JSON file (`Google Distributed Cloud Hosted`_), then a GDCH
credential will be returned. The project ID returned is None unless it
is set via `GOOGLE_CLOUD_PROJECT` environment variable.
2. If the `Google Cloud SDK`_ is installed and has application default
credentials set they are loaded and returned.
Expand Down Expand Up @@ -490,6 +529,8 @@ def default(scopes=None, request=None, quota_project_id=None, default_scopes=Non
.. _Metadata Service: https://cloud.google.com/compute/docs\
/storing-retrieving-metadata
.. _Cloud Run: https://cloud.google.com/run
.. _Google Distributed Cloud Hosted: https://cloud.google.com/blog/topics\
/hybrid-cloud/announcing-google-distributed-cloud-edge-and-hosted
Example::
Expand Down
72 changes: 54 additions & 18 deletions google/oauth2/_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,11 +44,13 @@ def _handle_error_response(response_data):
"""Translates an error response into an exception.
Args:
response_data (Mapping): The decoded response data.
response_data (Mapping | str): The decoded response data.
Raises:
google.auth.exceptions.RefreshError: The errors contained in response_data.
"""
if isinstance(response_data, six.string_types):
raise exceptions.RefreshError(response_data)
try:
error_details = "{}: {}".format(
response_data["error"], response_data.get("error_description")
Expand Down Expand Up @@ -79,7 +81,13 @@ def _parse_expiry(response_data):


def _token_endpoint_request_no_throw(
request, token_uri, body, access_token=None, use_json=False
request,
token_uri,
body,
access_token=None,
use_json=False,
expected_status_code=http_client.OK,
**kwargs
):
"""Makes a request to the OAuth 2.0 authorization server's token endpoint.
This function doesn't throw on response errors.
Expand All @@ -93,6 +101,10 @@ def _token_endpoint_request_no_throw(
access_token (Optional(str)): The access token needed to make the request.
use_json (Optional(bool)): Use urlencoded format or json format for the
content type. The default value is False.
expected_status_code (Optional(int)): The expected the status code of
the token response. The default value is 200. We may expect other
status code like 201 for GDCH credentials.
kwargs: Additional arguments passed on to the request method.
Returns:
Tuple(bool, Mapping[str, str]): A boolean indicating if the request is
Expand All @@ -112,32 +124,46 @@ def _token_endpoint_request_no_throw(
# retry to fetch token for maximum of two times if any internal failure
# occurs.
while True:
response = request(method="POST", url=token_uri, headers=headers, body=body)
response = request(
method="POST", url=token_uri, headers=headers, body=body, **kwargs
)
response_body = (
response.data.decode("utf-8")
if hasattr(response.data, "decode")
else response.data
)
response_data = json.loads(response_body)

if response.status == http_client.OK:
if response.status == expected_status_code:
# response_body should be a JSON
response_data = json.loads(response_body)
break
else:
error_desc = response_data.get("error_description") or ""
error_code = response_data.get("error") or ""
if (
any(e == "internal_failure" for e in (error_code, error_desc))
and retry < 1
):
retry += 1
continue
return response.status == http_client.OK, response_data

return response.status == http_client.OK, response_data
# For a failed response, response_body could be a string
try:
response_data = json.loads(response_body)
error_desc = response_data.get("error_description") or ""
error_code = response_data.get("error") or ""
if (
any(e == "internal_failure" for e in (error_code, error_desc))
and retry < 1
):
retry += 1
continue
except ValueError:
response_data = response_body
return response.status == expected_status_code, response_data

return response.status == expected_status_code, response_data


def _token_endpoint_request(
request, token_uri, body, access_token=None, use_json=False
request,
token_uri,
body,
access_token=None,
use_json=False,
expected_status_code=http_client.OK,
**kwargs
):
"""Makes a request to the OAuth 2.0 authorization server's token endpoint.
Expand All @@ -150,6 +176,10 @@ def _token_endpoint_request(
access_token (Optional(str)): The access token needed to make the request.
use_json (Optional(bool)): Use urlencoded format or json format for the
content type. The default value is False.
expected_status_code (Optional(int)): The expected the status code of
the token response. The default value is 200. We may expect other
status code like 201 for GDCH credentials.
kwargs: Additional arguments passed on to the request method.
Returns:
Mapping[str, str]: The JSON-decoded response data.
Expand All @@ -159,7 +189,13 @@ def _token_endpoint_request(
an error.
"""
response_status_ok, response_data = _token_endpoint_request_no_throw(
request, token_uri, body, access_token=access_token, use_json=use_json
request,
token_uri,
body,
access_token=access_token,
use_json=use_json,
expected_status_code=expected_status_code,
**kwargs
)
if not response_status_ok:
_handle_error_response(response_data)
Expand Down
Loading

0 comments on commit 95dfce8

Please sign in to comment.