Skip to content

Commit

Permalink
add account id to credentials where possible
Browse files Browse the repository at this point in the history
  • Loading branch information
davidlm committed Aug 11, 2023
1 parent 0103868 commit 2d05ce7
Show file tree
Hide file tree
Showing 3 changed files with 213 additions and 35 deletions.
68 changes: 58 additions & 10 deletions botocore/credentials.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,13 +49,14 @@
InstanceMetadataFetcher,
JSONFileCache,
SSOTokenLoader,
ArnParser,
parse_key_val_file,
resolve_imds_endpoint_mode,
)

logger = logging.getLogger(__name__)
ReadOnlyCredentials = namedtuple(
'ReadOnlyCredentials', ['access_key', 'secret_key', 'token']
'ReadOnlyCredentials', ['access_key', 'secret_key', 'token', 'account_id']
)

_DEFAULT_MANDATORY_REFRESH_TIMEOUT = 10 * 60 # 10 min
Expand Down Expand Up @@ -305,14 +306,16 @@ class Credentials:
:param str access_key: The access key part of the credentials.
:param str secret_key: The secret key part of the credentials.
:param str token: The security token, valid only for session credentials.
:param str account_id: The account ID associated with the credentials.
:param str method: A string which identifies where the credentials
were found.
"""

def __init__(self, access_key, secret_key, token=None, method=None):
def __init__(self, access_key, secret_key, token=None, account_id=None, method=None):
self.access_key = access_key
self.secret_key = secret_key
self.token = token
self.account_id = account_id

if method is None:
method = 'explicit'
Expand All @@ -332,7 +335,7 @@ def _normalize(self):

def get_frozen_credentials(self):
return ReadOnlyCredentials(
self.access_key, self.secret_key, self.token
self.access_key, self.secret_key, self.token, self.account_id
)


Expand All @@ -344,6 +347,7 @@ class RefreshableCredentials(Credentials):
:param str access_key: The access key part of the credentials.
:param str secret_key: The secret key part of the credentials.
:param str token: The security token, valid only for session credentials.
:param str account_id: The account ID associated with the credentials.
:param function refresh_using: Callback function to refresh the credentials.
:param str method: A string which identifies where the credentials
were found.
Expand All @@ -362,6 +366,7 @@ def __init__(
access_key,
secret_key,
token,
account_id,
expiry_time,
refresh_using,
method,
Expand All @@ -371,12 +376,13 @@ def __init__(
self._access_key = access_key
self._secret_key = secret_key
self._token = token
self._account_id = account_id
self._expiry_time = expiry_time
self._time_fetcher = time_fetcher
self._refresh_lock = threading.Lock()
self.method = method
self._frozen_credentials = ReadOnlyCredentials(
access_key, secret_key, token
access_key, secret_key, token, account_id
)
self._normalize()

Expand All @@ -390,6 +396,7 @@ def create_from_metadata(cls, metadata, refresh_using, method):
access_key=metadata['access_key'],
secret_key=metadata['secret_key'],
token=metadata['token'],
account_id=metadata.get('account_id'),
expiry_time=cls._expiry_datetime(metadata['expiry_time']),
method=method,
refresh_using=refresh_using,
Expand Down Expand Up @@ -435,6 +442,19 @@ def token(self):
def token(self, value):
self._token = value

@property
def account_id(self):
"""Warning: Using this property can lead to race conditions if you
access another property subsequently along the refresh boundary.
Please use get_frozen_credentials instead.
"""
self._refresh()
return self._account_id

@account_id.setter
def account_id(self, value):
self._account_id = value

def _seconds_remaining(self):
delta = self._expiry_time - self._time_fetcher()
return total_seconds(delta)
Expand All @@ -443,7 +463,7 @@ def refresh_needed(self, refresh_in=None):
"""Check if a refresh is needed.
A refresh is needed if the expiry time associated
with the temporary credentials is less than the
with the temporary credentials is dfless than the
provided ``refresh_in``. If ``time_delta`` is not
provided, ``self.advisory_refresh_needed`` will be used.
Expand Down Expand Up @@ -531,7 +551,7 @@ def _protected_refresh(self, is_mandatory):
return
self._set_from_data(metadata)
self._frozen_credentials = ReadOnlyCredentials(
self._access_key, self._secret_key, self._token
self._access_key, self._secret_key, self._token, self._account_id
)
if self._is_expired():
# We successfully refreshed credentials but for whatever
Expand Down Expand Up @@ -571,6 +591,7 @@ def _set_from_data(self, data):
logger.debug(
"Retrieved credentials will expire at: %s", self._expiry_time
)
self._account_id = data.get('account_id')
self._normalize()

def get_frozen_credentials(self):
Expand Down Expand Up @@ -627,6 +648,7 @@ def __init__(self, refresh_using, method, time_fetcher=_local_now):
self._refresh_lock = threading.Lock()
self.method = method
self._frozen_credentials = None
self._account_id = None

def refresh_needed(self, refresh_in=None):
if self._frozen_credentials is None:
Expand Down Expand Up @@ -680,6 +702,7 @@ def _get_cached_credentials(self):
'secret_key': creds['SecretAccessKey'],
'token': creds['SessionToken'],
'expiry_time': expiration,
'account_id': response['AccountId'],
}

def _load_from_cache(self):
Expand Down Expand Up @@ -754,6 +777,10 @@ def _create_cache_key(self):
args = json.dumps(args, sort_keys=True)
argument_hash = sha1(args.encode('utf-8')).hexdigest()
return self._make_file_safe(argument_hash)

def _generate_account_id(self, resp):
user_arn = resp['AssumedRoleUser']['Arn']
return ArnParser().parse_arn(user_arn)['account']


class AssumeRoleCredentialFetcher(BaseAssumeRoleCredentialFetcher):
Expand Down Expand Up @@ -815,7 +842,9 @@ def _get_credentials(self):
"""Get credentials by calling assume role."""
kwargs = self._assume_role_kwargs()
client = self._create_client()
return client.assume_role(**kwargs)
resp = client.assume_role(**kwargs)
resp['AccountId'] = self._generate_account_id(resp)
return resp

def _assume_role_kwargs(self):
"""Get the arguments for assume role based on current configuration."""
Expand Down Expand Up @@ -902,7 +931,9 @@ def _get_credentials(self):
# the token, explicitly configure the client to not sign requests.
config = Config(signature_version=UNSIGNED)
client = self._client_creator('sts', config=config)
return client.assume_role_with_web_identity(**kwargs)
resp = client.assume_role_with_web_identity(**kwargs)
resp['AccountId'] = self._generate_account_id(resp)
return resp

def _assume_role_kwargs(self):
"""Get the arguments for assume role based on current configuration."""
Expand Down Expand Up @@ -985,6 +1016,7 @@ def load(self):
access_key=creds_dict['access_key'],
secret_key=creds_dict['secret_key'],
token=creds_dict.get('token'),
account_id=creds_dict.get('account_id'),
method=self.METHOD,
)

Expand Down Expand Up @@ -1016,6 +1048,7 @@ def _retrieve_credentials_using(self, credential_process):
'secret_key': parsed['SecretAccessKey'],
'token': parsed.get('SessionToken'),
'expiry_time': parsed.get('Expiration'),
'account_id': parsed.get('AccountId'),
}
except KeyError as e:
raise CredentialRetrievalError(
Expand Down Expand Up @@ -1071,6 +1104,7 @@ class EnvProvider(CredentialProvider):
# AWS_SESSION_TOKEN is what other AWS SDKs have standardized on.
TOKENS = ['AWS_SECURITY_TOKEN', 'AWS_SESSION_TOKEN']
EXPIRY_TIME = 'AWS_CREDENTIAL_EXPIRATION'
ACCOUNT_ID = 'AWS_ACCOUNT_ID'

def __init__(self, environ=None, mapping=None):
"""
Expand All @@ -1097,6 +1131,7 @@ def _build_mapping(self, mapping):
var_mapping['secret_key'] = self.SECRET_KEY
var_mapping['token'] = self.TOKENS
var_mapping['expiry_time'] = self.EXPIRY_TIME
var_mapping['account_id'] = self.ACCOUNT_ID
else:
var_mapping['access_key'] = mapping.get(
'access_key', self.ACCESS_KEY
Expand All @@ -1110,6 +1145,9 @@ def _build_mapping(self, mapping):
var_mapping['expiry_time'] = mapping.get(
'expiry_time', self.EXPIRY_TIME
)
var_mapping['account_id'] = mapping.get(
'account_id', self.ACCOUNT_ID
)
return var_mapping

def load(self):
Expand All @@ -1123,14 +1161,14 @@ def load(self):
logger.info('Found credentials in environment variables.')
fetcher = self._create_credentials_fetcher()
credentials = fetcher(require_expiry=False)

expiry_time = credentials['expiry_time']
if expiry_time is not None:
expiry_time = parse(expiry_time)
return RefreshableCredentials(
credentials['access_key'],
credentials['secret_key'],
credentials['token'],
credentials['account_id'],
expiry_time,
refresh_using=fetcher,
method=self.METHOD,
Expand All @@ -1140,6 +1178,7 @@ def load(self):
credentials['access_key'],
credentials['secret_key'],
credentials['token'],
credentials['account_id'],
method=self.METHOD,
)
else:
Expand Down Expand Up @@ -1182,6 +1221,10 @@ def fetch_credentials(require_expiry=True):
raise PartialCredentialsError(
provider=method, cred_var=mapping['expiry_time']
)
credentials['account_id'] = None
account_id = environ.get(mapping['account_id'], '')
if account_id:
credentials['account_id'] = account_id

return credentials

Expand Down Expand Up @@ -1281,6 +1324,7 @@ class ConfigProvider(CredentialProvider):
# aws_security_token, but the SDKs are standardizing on aws_session_token
# so we support both.
TOKENS = ['aws_security_token', 'aws_session_token']
ACCOUNT_ID = 'aws_account_id'

def __init__(self, config_filename, profile_name, config_parser=None):
"""
Expand Down Expand Up @@ -1316,9 +1360,10 @@ def load(self):
access_key, secret_key = self._extract_creds_from_mapping(
profile_config, self.ACCESS_KEY, self.SECRET_KEY
)
account_id = profile_config.get(self.ACCOUNT_ID)
token = self._get_session_token(profile_config)
return Credentials(
access_key, secret_key, token, method=self.METHOD
access_key, secret_key, token, account_id, method=self.METHOD
)
else:
return None
Expand Down Expand Up @@ -1679,6 +1724,7 @@ def _resolve_static_credentials_from_profile(self, profile):
access_key=profile['aws_access_key_id'],
secret_key=profile['aws_secret_access_key'],
token=profile.get('aws_session_token'),
account_id=profile.get('aws_account_id'),
)
except KeyError as e:
raise PartialCredentialsError(
Expand Down Expand Up @@ -1912,6 +1958,7 @@ def _retrieve_or_fail(self):
access_key=creds['access_key'],
secret_key=creds['secret_key'],
token=creds['token'],
account_id=None,
method=self.METHOD,
expiry_time=_parse_if_needed(creds['expiry_time']),
refresh_using=fetcher,
Expand Down Expand Up @@ -2128,6 +2175,7 @@ def _get_credentials(self):

credentials = {
'ProviderType': 'sso',
'AccountId': self._account_id,
'Credentials': {
'AccessKeyId': credentials['accessKeyId'],
'SecretAccessKey': credentials['secretAccessKey'],
Expand Down
3 changes: 2 additions & 1 deletion tests/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -320,7 +320,7 @@ def __init__(
if refresh_function is None:
refresh_function = self._do_refresh
super().__init__(
'0', '0', '0', expires_in, refresh_function, 'INTREFRESH'
'0', '0', '0', '0', expires_in, refresh_function, 'INTREFRESH'
)
self.creds_last_for = creds_last_for
self.refresh_counter = 0
Expand All @@ -337,6 +337,7 @@ def _do_refresh(self):
'secret_key': next_id,
'token': next_id,
'expiry_time': self._seconds_later(self.creds_last_for),
'account_id': next_id,
}

def _seconds_later(self, num_seconds):
Expand Down
Loading

0 comments on commit 2d05ce7

Please sign in to comment.