Skip to content

Commit

Permalink
add account_id_endpoint_mode config setting to endpoint resolution
Browse files Browse the repository at this point in the history
  • Loading branch information
davidlm committed Oct 3, 2023
1 parent 8afd119 commit 67372e5
Show file tree
Hide file tree
Showing 6 changed files with 98 additions and 3 deletions.
18 changes: 18 additions & 0 deletions botocore/args.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,6 +165,7 @@ def get_client_args(
is_secure,
endpoint_bridge,
event_emitter,
credentials,
)

# Copy the session's user agent factory and adds client configuration.
Expand Down Expand Up @@ -266,11 +267,15 @@ def compute_client_args(
disable_request_compression=(
client_config.disable_request_compression
),
account_id_endpoint_mode=(
client_config.account_id_endpoint_mode
),
)
self._compute_retry_config(config_kwargs)
self._compute_connect_timeout(config_kwargs)
self._compute_user_agent_appid_config(config_kwargs)
self._compute_request_compression_config(config_kwargs)
self._compute_account_id_endpoint_mode(config_kwargs)
s3_config = self.compute_s3_config(client_config)

is_s3_service = self._is_s3_service(service_name)
Expand Down Expand Up @@ -597,6 +602,14 @@ def _validate_min_compression_size(self, min_size):

return min_size

def _compute_account_id_endpoint_mode(self, config_kwargs):
ep_mode = config_kwargs.get('account_id_endpoint_mode')
if ep_mode is None:
ep_mode = self._config_store.get_config_variable(
'account_id_endpoint_mode'
)
config_kwargs['account_id_endpoint_mode'] = ep_mode

def _ensure_boolean(self, val):
if isinstance(val, bool):
return val
Expand All @@ -616,6 +629,7 @@ def _build_endpoint_resolver(
is_secure,
endpoint_bridge,
event_emitter,
credentials,
):
if endpoints_ruleset_data is None:
return None
Expand Down Expand Up @@ -663,6 +677,7 @@ def _build_endpoint_resolver(
event_emitter=event_emitter,
use_ssl=is_secure,
requested_auth_scheme=sig_version,
credentials=credentials,
)

def compute_endpoint_resolver_builtin_defaults(
Expand Down Expand Up @@ -747,6 +762,9 @@ def compute_endpoint_resolver_builtin_defaults(
's3_disable_multiregion_access_points', False
),
EPRBuiltins.SDK_ENDPOINT: given_endpoint,
# account ID is calculated later if account based routing is
# enabled and configured for the service
EPRBuiltins.AWS_ACCOUNT_ID: None,
}

def _compute_user_agent_appid_config(self, config_kwargs):
Expand Down
7 changes: 7 additions & 0 deletions botocore/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,6 +219,12 @@ class Config:
:param disable_request_compression: Disables request body compression if
set to True.
Defaults to None.
:type account_id_endpoint_mode: str
:param account_id_endpoint_mode: Enables or disables account ID based
endpoint routing for supported operations.
Defaults to None.
"""

Expand Down Expand Up @@ -247,6 +253,7 @@ class Config:
('tcp_keepalive', None),
('request_min_compression_size_bytes', None),
('disable_request_compression', None),
('account_id_endpoint_mode', None),
]
)

Expand Down
6 changes: 6 additions & 0 deletions botocore/configprovider.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,12 @@
False,
utils.ensure_boolean,
),
'account_id_endpoint_mode': (
'account_id_endpoint_mode',
'AWS_ACCOUNT_ID_ENDPOINT_MODE',
'preferred',
None,
),
}
# A mapping for the s3 specific configuration vars. These are the configuration
# vars that typically go in the s3 section of the config file. This mapping
Expand Down
4 changes: 4 additions & 0 deletions botocore/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -814,3 +814,7 @@ class EndpointResolutionError(EndpointProviderError):

class UnknownEndpointResolutionBuiltInName(EndpointProviderError):
fmt = 'Unknown builtin variable name: {name}'


class AccountIDNotFound(EndpointResolutionError):
fmt = '`account_id_endpoint_mode is set to `required` but no account ID was found.'
2 changes: 0 additions & 2 deletions botocore/httpsession.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,6 @@ def create_urllib3_context(
context.set_ciphers(ciphers)
elif DEFAULT_CIPHERS:
context.set_ciphers(DEFAULT_CIPHERS)

# Setting the default here, as we may have no ssl module on import
cert_reqs = ssl.CERT_REQUIRED if cert_reqs is None else cert_reqs

Expand Down Expand Up @@ -362,7 +361,6 @@ def _get_proxy_manager(self, proxy_url):
proxy_manager = proxy_from_url(proxy_url, **proxy_manager_kwargs)
proxy_manager.pool_classes_by_scheme = self._pool_classes_by_scheme
self._proxy_managers[proxy_url] = proxy_manager

return self._proxy_managers[proxy_url]

def _path_url(self, url):
Expand Down
64 changes: 63 additions & 1 deletion botocore/regions.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,10 @@
from botocore.crt import CRT_SUPPORTED_AUTH_TYPES
from botocore.endpoint_provider import EndpointProvider
from botocore.exceptions import (
AccountIDNotFound,
EndpointProviderError,
EndpointVariantError,
InvalidConfigError,
InvalidEndpointConfigurationError,
InvalidHostLabelError,
MissingDependencyException,
Expand All @@ -46,6 +48,12 @@
LOG = logging.getLogger(__name__)
DEFAULT_URI_TEMPLATE = '{service}.{region}.{dnsSuffix}' # noqa
DEFAULT_SERVICE_DATA = {'endpoints': {}}
# Allowed values for the ``account_id_endpoint_mode`` config field.
VALID_ACCOUNT_ID_ENDPOINT_MODES = [
'preferred',
'disabled',
'required',
]


class BaseEndpointResolver:
Expand Down Expand Up @@ -450,6 +458,8 @@ class EndpointResolverBuiltins(str, Enum):
AWS_S3_DISABLE_MRAP = "AWS::S3::DisableMultiRegionAccessPoints"
# Whether a custom endpoint has been configured (str)
SDK_ENDPOINT = "SDK::Endpoint"
# An account ID source from the credential resolution process.
AWS_ACCOUNT_ID = "AWS::Auth::AccountId"


class EndpointRulesetResolver:
Expand All @@ -465,6 +475,7 @@ def __init__(
event_emitter,
use_ssl=True,
requested_auth_scheme=None,
credentials=None,
):
self._provider = EndpointProvider(
ruleset_data=endpoint_ruleset_data,
Expand All @@ -478,6 +489,7 @@ def __init__(
self._use_ssl = use_ssl
self._requested_auth_scheme = requested_auth_scheme
self._instance_cache = {}
self._credentials = credentials

def construct_endpoint(
self,
Expand Down Expand Up @@ -546,6 +558,7 @@ def _get_provider_params(
customized_builtins = self._get_customized_builtins(
operation_model, call_args, request_context
)
self._resolve_account_id_builtin(request_context, customized_builtins)
for param_name, param_def in self._param_definitions.items():
param_val = self._resolve_param_from_context(
param_name=param_name,
Expand All @@ -559,9 +572,58 @@ def _get_provider_params(
)
if param_val is not None:
provider_params[param_name] = param_val

return provider_params

def _resolve_account_id_builtin(self, request_context, builtins):
"""Resolve the ``AWS::Auth::AccountId`` builtin if account ID based
routing is enabled.
"""
account_id_endpoint_mode = self._resolve_account_id_endpoint_mode(
request_context
)
if account_id_endpoint_mode == 'disabled':
return
# This will make a call to resolve credentials if they are not already
# or need to be refreshed.
frozen_creds = self._credentials.get_frozen_credentials()
account_id = frozen_creds.account_id
if account_id is None:
if account_id_endpoint_mode == 'preferred':
LOG.debug(
'`account_id_endpoint_mode` is set to `preferred`, but no '
'account ID was found.'
)
elif account_id_endpoint_mode == 'required':
raise AccountIDNotFound()
else:
builtins[EndpointResolverBuiltins.AWS_ACCOUNT_ID] = account_id

def _resolve_account_id_endpoint_mode(self, request_context):
"""Resolve the account ID endpoint mode for the request. Account ID
based routing is always disabled for presigned and unsigned requests.
Otherwise, the mode is determined by the ``account_id_endpoint_mode``
config setting.
"""
not_presign = not request_context.get('is_presign_request', False)
should_sign = self._requested_auth_scheme != UNSIGNED
creds_available = self._credentials is not None
if all((not_presign, should_sign, creds_available)):
config = request_context['client_config']
act_id_ep_mode = config.account_id_endpoint_mode
return self._validate_account_id_endpoint_mode(act_id_ep_mode)
return 'disabled'

def _validate_account_id_endpoint_mode(self, account_id_endpoint_mode):
if account_id_endpoint_mode not in VALID_ACCOUNT_ID_ENDPOINT_MODES:
valid_modes_str = ', '.join(VALID_ACCOUNT_ID_ENDPOINT_MODES)
error_msg = (
f"Invalid value '{account_id_endpoint_mode}' for "
"account_id_endpoint_mode. Valid values are: "
f"{valid_modes_str}"
)
raise InvalidConfigError(error_msg=error_msg)
return account_id_endpoint_mode

def _resolve_param_from_context(
self, param_name, operation_model, call_args
):
Expand Down

0 comments on commit 67372e5

Please sign in to comment.