Skip to content

Commit

Permalink
added endpoint impl
Browse files Browse the repository at this point in the history
  • Loading branch information
davidlm committed Aug 18, 2023
1 parent 3cff200 commit ece90ed
Show file tree
Hide file tree
Showing 6 changed files with 82 additions and 11 deletions.
35 changes: 35 additions & 0 deletions botocore/args.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,12 @@
# Maximum allowed length of the ``user_agent_appid`` config field. Longer
# values result in a warning-level log message.
USERAGENT_APPID_MAXLEN = 50
# Allowed values for the ``account_id_endpoint_mode`` config field.
VALID_ACCOUNT_ID_ENDPOINT_MODES = [
'preferred',
'disabled',
'required',
]


class ClientArgsCreator:
Expand Down Expand Up @@ -165,6 +171,7 @@ def get_client_args(
is_secure,
endpoint_bridge,
event_emitter,
credentials,
)

# Copy the session's user agent factory and adds client configuration.
Expand Down Expand Up @@ -266,11 +273,15 @@ def compute_client_args(
disable_request_compression=(
client_config.disable_request_compression
),
account_id_endpoint_mode=(
client_config.account_id_endpoint_mode
),
)
self._compute_retry_config(config_kwargs)
self._compute_connect_timeout(config_kwargs)
self._compute_user_agent_appid_config(config_kwargs)
self._compute_request_compression_config(config_kwargs)
self._compute_account_id_endpoint_mode(config_kwargs)
s3_config = self.compute_s3_config(client_config)

is_s3_service = self._is_s3_service(service_name)
Expand Down Expand Up @@ -616,6 +627,7 @@ def _build_endpoint_resolver(
is_secure,
endpoint_bridge,
event_emitter,
credentials,
):
if endpoints_ruleset_data is None:
return None
Expand All @@ -640,6 +652,7 @@ def _build_endpoint_resolver(
endpoint_bridge=endpoint_bridge,
client_endpoint_url=endpoint_url,
legacy_endpoint_url=endpoint.host,
credentials=credentials,
)
# botocore does not support client context parameters generically
# for every service. Instead, the s3 config section entries are
Expand Down Expand Up @@ -673,6 +686,7 @@ def compute_endpoint_resolver_builtin_defaults(
endpoint_bridge,
client_endpoint_url,
legacy_endpoint_url,
credentials,
):
# EndpointRulesetResolver rulesets may accept an "SDK::Endpoint" as
# input. If the endpoint_url argument of create_client() is set, it
Expand Down Expand Up @@ -701,6 +715,9 @@ def compute_endpoint_resolver_builtin_defaults(
else:
force_path_style = s3_config.get('addressing_style') == 'path'

account_id = None
if credentials is not None:
account_id = credentials.account_id
return {
EPRBuiltins.AWS_REGION: region_name,
EPRBuiltins.AWS_USE_FIPS: (
Expand Down Expand Up @@ -747,6 +764,7 @@ def compute_endpoint_resolver_builtin_defaults(
's3_disable_multiregion_access_points', False
),
EPRBuiltins.SDK_ENDPOINT: given_endpoint,
EPRBuiltins.AWS_ACCOUNT_ID: account_id,
}

def _compute_user_agent_appid_config(self, config_kwargs):
Expand All @@ -764,3 +782,20 @@ def _compute_user_agent_appid_config(self, config_kwargs):
f'maximum length of {USERAGENT_APPID_MAXLEN} characters.'
)
config_kwargs['user_agent_appid'] = user_agent_appid

def _compute_account_id_endpoint_mode(self, config_kwargs):
account_id_endpoint_mode = config_kwargs.get(
'account_id_endpoint_mode'
)
if account_id_endpoint_mode is None:
account_id_endpoint_mode = self._config_store.get_config_variable(
'account_id_endpoint_mode'
)
if account_id_endpoint_mode not in VALID_ACCOUNT_ID_ENDPOINT_MODES:
valid_modes_str = ', '.join(VALID_ACCOUNT_ID_ENDPOINT_MODES)
error_msg = (
f"Invalid value '{account_id_endpoint_mode}' for "
f"account_id_endpoint_mode. Valid values are: {valid_modes_str}"
)
raise botocore.exceptions.InvalidConfigError(error_msg=error_msg)
config_kwargs['account_id_endpoint_mode'] = account_id_endpoint_mode
7 changes: 7 additions & 0 deletions botocore/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,6 +220,12 @@ class Config:
set to True.
Defaults to None.
:type account_id_endpoint_mode: str
:param account_id_endpoint_mode: Enables or disables account ID based
endpoint routing for supported operations.
Defaults to None.
"""

OPTION_DEFAULTS = OrderedDict(
Expand Down Expand Up @@ -247,6 +253,7 @@ class Config:
('tcp_keepalive', None),
('request_min_compression_size_bytes', None),
('disable_request_compression', None),
('account_id_endpoint_mode', None),
]
)

Expand Down
6 changes: 6 additions & 0 deletions botocore/configprovider.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,12 @@
False,
utils.ensure_boolean,
),
'account_id_endpoint_mode': (
'account_id_endpoint_mode',
'AWS_ACCOUNT_ID_ENDPOINT_MODE',
'preferred',
None,
),
}
# A mapping for the s3 specific configuration vars. These are the configuration
# vars that typically go in the s3 section of the config file. This mapping
Expand Down
23 changes: 15 additions & 8 deletions botocore/credentials.py
Original file line number Diff line number Diff line change
Expand Up @@ -530,6 +530,7 @@ def _refresh(self):
self._protected_refresh(is_mandatory=True)

def _protected_refresh(self, is_mandatory):
# breakpoint()
# precondition: this method should only be called if you've acquired
# the self._refresh_lock.
try:
Expand Down Expand Up @@ -680,7 +681,7 @@ def _make_file_safe(self, filename):

def _get_credentials(self):
raise NotImplementedError('_get_credentials()')

def _resolve_account_id(self, response=None):
raise NotImplementedError('_resolve_account_id()')

Expand Down Expand Up @@ -1058,15 +1059,14 @@ def _retrieve_credentials_using(self, credential_process):
provider=self.METHOD,
error_msg=f"Missing required key in response: {e}",
)

def _resolve_account_id(self, parsed_response):
account_id = parsed_response.get('AccountId')
if account_id:
return account_id
return self._profile_config.get('aws_account_id')
return parsed_response.get('AccountId') or self.profile_config.get(
'aws_account_id'
)

@property
def _profile_config(self):
def profile_config(self):
if self._loaded_config is None:
self._loaded_config = self._load_config()
return self._loaded_config.get('profiles', {}).get(
Expand Down Expand Up @@ -1317,7 +1317,11 @@ def load(self):
token = self._get_session_token(config)
account_id = config.get(self.ACCOUNT_ID)
return Credentials(
access_key, secret_key, token, account_id, method=self.METHOD
access_key,
secret_key,
token,
account_id,
method=self.METHOD,
)

def _get_session_token(self, config):
Expand Down Expand Up @@ -2167,6 +2171,9 @@ def _parse_timestamp(self, timestamp_ms):
timestamp = datetime.datetime.fromtimestamp(timestamp_seconds, tzutc())
return timestamp.strftime(self._UTC_DATE_FORMAT)

def _resolve_account_id(self, response=None):
return self._account_id

def _get_credentials(self):
"""Get credentials by calling SSO get role credentials."""
config = Config(
Expand Down
3 changes: 2 additions & 1 deletion botocore/regions.py
Original file line number Diff line number Diff line change
Expand Up @@ -450,6 +450,8 @@ class EndpointResolverBuiltins(str, Enum):
AWS_S3_DISABLE_MRAP = "AWS::S3::DisableMultiRegionAccessPoints"
# Whether a custom endpoint has been configured (str)
SDK_ENDPOINT = "SDK::Endpoint"
# An account ID source from the credential resolution process.
AWS_ACCOUNT_ID = "AWS::Auth::AccountId"


class EndpointRulesetResolver:
Expand Down Expand Up @@ -559,7 +561,6 @@ def _get_provider_params(
)
if param_val is not None:
provider_params[param_name] = param_val

return provider_params

def _resolve_param_from_context(
Expand Down
19 changes: 17 additions & 2 deletions tests/functional/test_credentials.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,7 +198,7 @@ def create_assume_role_response(self, credentials, expiration=None):
},
'AssumedRoleUser': {
'AssumedRoleId': 'myroleid',
'Arn': 'arn:aws:iam::1234567890:user/myuser',
'Arn': f'arn:aws:iam::{credentials.account_id}:user/myuser',
},
}

Expand All @@ -209,6 +209,7 @@ def create_random_credentials(self):
'fake-%s' % random_chars(15),
'fake-%s' % random_chars(35),
'fake-%s' % random_chars(45),
'fake-%s' % random_chars(12),
)

def assert_creds_equal(self, c1, c2):
Expand Down Expand Up @@ -789,7 +790,11 @@ def assert_session_credentials(self, expected_params, **kwargs):
expected_creds = self.create_random_credentials()
response = self.create_assume_role_response(expected_creds)
session = StubbedSession(**kwargs)
stubber = session.stub('sts')
stubber = session.stub(
'sts',
aws_access_key_id='spam',
aws_secret_access_key='eggs',
)
stubber.add_response(
'assume_role_with_web_identity', response, expected_params
)
Expand Down Expand Up @@ -1143,3 +1148,13 @@ def add_credential_response(self, stubber):
}
}
stubber.add_response(body=json.dumps(response).encode('utf-8'))


def _load_account_id_test_cases():
path = os.path.join(
os.path.dirname(__file__),
'credentials',
'accountid-source-testcases.json',
)
with open(os.path.join(path, 'account_id_test_cases.json')) as f:
return json.load(f)

0 comments on commit ece90ed

Please sign in to comment.