Skip to content

Commit

Permalink
unit tests
Browse files Browse the repository at this point in the history
  • Loading branch information
davidlm committed Sep 28, 2023
1 parent c5b2915 commit 8afd119
Show file tree
Hide file tree
Showing 4 changed files with 278 additions and 3 deletions.
3 changes: 2 additions & 1 deletion tests/functional/test_credentials.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,7 +198,7 @@ def create_assume_role_response(self, credentials, expiration=None):
},
'AssumedRoleUser': {
'AssumedRoleId': 'myroleid',
'Arn': 'arn:aws:iam::1234567890:user/myuser',
'Arn': f'arn:aws:iam::{credentials.account_id}:user/myuser',
},
}

Expand All @@ -209,6 +209,7 @@ def create_random_credentials(self):
'fake-%s' % random_chars(15),
'fake-%s' % random_chars(35),
'fake-%s' % random_chars(45),
account_id='fake-%s' % random_chars(12),
)

def assert_creds_equal(self, c1, c2):
Expand Down
10 changes: 8 additions & 2 deletions tests/integration/test_credentials.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,10 @@ def test_access_secret_vs_profile_code(self, credentials_cls):
)

credentials_cls.assert_called_with(
access_key='code', secret_key='code-secret', token=mock.ANY
access_key='code',
secret_key='code-secret',
token=mock.ANY,
account_id=mock.ANY,
)

def test_profile_env_vs_code(self):
Expand All @@ -97,7 +100,10 @@ def test_access_secret_env_vs_code(self, credentials_cls):
)

credentials_cls.assert_called_with(
access_key='code', secret_key='code-secret', token=mock.ANY
access_key='code',
secret_key='code-secret',
token=mock.ANY,
account_id=mock.ANY,
)

def test_access_secret_env_vs_profile_code(self):
Expand Down
251 changes: 251 additions & 0 deletions tests/unit/test_credentials.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,6 +190,14 @@ def test_refresh_returns_partial_credentials(self):
with self.assertRaises(botocore.exceptions.CredentialRetrievalError):
self.creds.access_key

def test_account_id_refresh(self):
metadata = self.metadata.copy()
metadata['account_id'] = '123456789012'
self.refresher.return_value = metadata
self.mock_time.return_value = datetime.now(tzlocal())
self.assertTrue(self.creds.refresh_needed())
self.assertEqual(self.creds.account_id, '123456789012')


class TestDeferredRefreshableCredentials(unittest.TestCase):
def setUp(self):
Expand Down Expand Up @@ -257,6 +265,7 @@ def get_expected_creds_from_response(self, response):
'secret_key': response['Credentials']['SecretAccessKey'],
'token': response['Credentials']['SessionToken'],
'expiry_time': expiration,
'account_id': None,
}

def some_future_time(self):
Expand Down Expand Up @@ -688,6 +697,49 @@ def test_mfa_refresh_enabled(self):
]
self.assertEqual(calls, expected_calls)

def test_account_id(self):
response = {
'Credentials': {
'AccessKeyId': 'foo',
'SecretAccessKey': 'bar',
'SessionToken': 'baz',
'Expiration': self.some_future_time().isoformat(),
},
'AssumedRoleUser': {
'AssumedRoleId': 'ARO123EXAMPLE123:myrole',
'Arn': 'arn:aws:sts::123456789012:assumed-role/myrole',
},
}
client_creator = self.create_client_creator(with_response=response)
refresher = credentials.AssumeRoleCredentialFetcher(
client_creator, self.source_creds, self.role_arn
)
expected_response = self.get_expected_creds_from_response(response)
expected_response['account_id'] = '123456789012'
response = refresher.fetch_credentials()
self.assertEqual(response, expected_response)

def test_account_id_invalid_arn(self):
response = {
'Credentials': {
'AccessKeyId': 'foo',
'SecretAccessKey': 'bar',
'SessionToken': 'baz',
'Expiration': self.some_future_time().isoformat(),
},
'AssumedRoleUser': {
'AssumedRoleId': 'ARO123EXAMPLE123:myrole',
'Arn': 'foo',
},
}
client_creator = self.create_client_creator(with_response=response)
refresher = credentials.AssumeRoleCredentialFetcher(
client_creator, self.source_creds, self.role_arn
)
expected_response = self.get_expected_creds_from_response(response)
response = refresher.fetch_credentials()
self.assertEqual(response, expected_response)


class TestAssumeRoleWithWebIdentityCredentialFetcher(BaseEnvVar):
def setUp(self):
Expand Down Expand Up @@ -720,6 +772,7 @@ def get_expected_creds_from_response(self, response):
'secret_key': response['Credentials']['SecretAccessKey'],
'token': response['Credentials']['SessionToken'],
'expiry_time': expiration,
'account_id': None,
}

def test_no_cache(self):
Expand Down Expand Up @@ -795,6 +848,53 @@ def test_assume_role_in_cache_but_expired(self):

self.assertEqual(response, expected)

def test_account_id(self):
response = {
'Credentials': {
'AccessKeyId': 'foo',
'SecretAccessKey': 'bar',
'SessionToken': 'baz',
'Expiration': self.some_future_time().isoformat(),
},
'AssumedRoleUser': {
'AssumedRoleId': 'ARO123EXAMPLE123:myrole',
'Arn': 'arn:aws:sts::123456789012:assumed-role/myrole',
},
}
client_creator = self.create_client_creator(with_response=response)
refresher = credentials.AssumeRoleWithWebIdentityCredentialFetcher(
client_creator,
self.load_token,
self.role_arn,
)
expected_response = self.get_expected_creds_from_response(response)
expected_response['account_id'] = '123456789012'
response = refresher.fetch_credentials()
self.assertEqual(response, expected_response)

def test_account_id_invalid_arn(self):
response = {
'Credentials': {
'AccessKeyId': 'foo',
'SecretAccessKey': 'bar',
'SessionToken': 'baz',
'Expiration': self.some_future_time().isoformat(),
},
'AssumedRoleUser': {
'AssumedRoleId': 'ARO123EXAMPLE123:myrole',
'Arn': 'foo',
},
}
client_creator = self.create_client_creator(with_response=response)
refresher = credentials.AssumeRoleWithWebIdentityCredentialFetcher(
client_creator,
self.load_token,
self.role_arn,
)
expected_response = self.get_expected_creds_from_response(response)
response = refresher.fetch_credentials()
self.assertEqual(response, expected_response)


class TestAssumeRoleWithWebIdentityCredentialProvider(unittest.TestCase):
def setUp(self):
Expand Down Expand Up @@ -1012,6 +1112,22 @@ def test_envvars_found_with_session_token(self):
self.assertEqual(creds.token, 'baz')
self.assertEqual(creds.method, 'env')

def test_envvars_found_with_account_id(self):
environ = {
'AWS_ACCESS_KEY_ID': 'foo',
'AWS_SECRET_ACCESS_KEY': 'bar',
'AWS_SESSION_TOKEN': 'baz',
'AWS_ACCOUNT_ID': '123456789012',
}
provider = credentials.EnvProvider(environ)
creds = provider.load()
self.assertIsNotNone(creds)
self.assertEqual(creds.access_key, 'foo')
self.assertEqual(creds.secret_key, 'bar')
self.assertEqual(creds.token, 'baz')
self.assertEqual(creds.account_id, '123456789012')
self.assertEqual(creds.method, 'env')

def test_envvars_not_found(self):
provider = credentials.EnvProvider(environ={})
creds = provider.load()
Expand Down Expand Up @@ -1119,6 +1235,22 @@ def test_can_override_expiry_env_var_mapping(self):
with self.assertRaisesRegex(RuntimeError, error_message):
creds.get_frozen_credentials()

def test_can_override_account_id_env_var_mapping(self):
environ = {
'AWS_ACCESS_KEY_ID': 'foo',
'AWS_SECRET_ACCESS_KEY': 'bar',
'AWS_SESSION_TOKEN': 'baz',
'FOO_ACCOUNT_ID': '123456789012',
}
provider = credentials.EnvProvider(
environ, {'account_id': 'FOO_ACCOUNT_ID'}
)
creds = provider.load()
self.assertEqual(creds.access_key, 'foo')
self.assertEqual(creds.secret_key, 'bar')
self.assertEqual(creds.token, 'baz')
self.assertEqual(creds.account_id, '123456789012')

def test_partial_creds_is_an_error(self):
# If the user provides an access key, they must also
# provide a secret key. Not doing so will generate an
Expand Down Expand Up @@ -1390,6 +1522,28 @@ def test_credentials_file_does_not_exist_returns_none(self):
creds = provider.load()
self.assertIsNone(creds)

def test_credentials_file_exists_with_account_id(self):
self.ini_parser.return_value = {
'default': {
'aws_access_key_id': 'foo',
'aws_secret_access_key': 'bar',
'aws_session_token': 'baz',
'aws_account_id': '123456789012',
}
}
provider = credentials.SharedCredentialProvider(
creds_filename='~/.aws/creds',
profile_name='default',
ini_parser=self.ini_parser,
)
creds = provider.load()
self.assertIsNotNone(creds)
self.assertEqual(creds.access_key, 'foo')
self.assertEqual(creds.secret_key, 'bar')
self.assertEqual(creds.token, 'baz')
self.assertEqual(creds.method, 'shared-credentials-file')
self.assertEqual(creds.account_id, '123456789012')


class TestConfigFileProvider(BaseEnvVar):
def setUp(self):
Expand Down Expand Up @@ -1452,6 +1606,24 @@ def test_partial_creds_is_error(self):
with self.assertRaises(botocore.exceptions.PartialCredentialsError):
provider.load()

def test_config_account_id(self):
profile_config = {
'aws_access_key_id': 'a',
'aws_secret_access_key': 'b',
'aws_session_token': 'c',
'aws_account_id': '123456789012',
}
parsed = {'profiles': {'default': profile_config}}
parser = mock.Mock()
parser.return_value = parsed
provider = credentials.ConfigProvider('cli.cfg', 'default', parser)
creds = provider.load()
self.assertIsNotNone(creds)
self.assertEqual(creds.access_key, 'a')
self.assertEqual(creds.secret_key, 'b')
self.assertEqual(creds.token, 'c')
self.assertEqual(creds.account_id, '123456789012')


class TestBotoProvider(BaseEnvVar):
def setUp(self):
Expand Down Expand Up @@ -3407,6 +3579,82 @@ def test_missing_expiration_and_session_token(self):
self.assertIsNone(creds.token)
self.assertEqual(creds.method, 'custom-process')

def test_missing_account_id(self):
self.loaded_config['profiles'] = {
'default': {'credential_process': 'my-process'}
}
self._set_process_return_value(
{
'Version': 1,
'AccessKeyId': 'foo',
'SecretAccessKey': 'bar',
'SessionToken': 'baz',
'Expiration': '2999-01-01T00:00:00Z',
# Missing AccountId.
}
)

provider = self.create_process_provider()
creds = provider.load()
self.assertIsNotNone(creds)
self.assertEqual(creds.access_key, 'foo')
self.assertEqual(creds.secret_key, 'bar')
self.assertEqual(creds.token, 'baz')
self.assertEqual(creds.method, 'custom-process')
self.assertIsNone(creds.account_id)

def test_account_id_from_profile(self):
self.loaded_config['profiles'] = {
'default': {
'credential_process': 'my-process',
'aws_account_id': '1234567890',
}
}
self._set_process_return_value(
{
'Version': 1,
'AccessKeyId': 'foo',
'SecretAccessKey': 'bar',
'SessionToken': 'baz',
'Expiration': '2999-01-01T00:00:00Z',
# Missing AccountId.
}
)
provider = self.create_process_provider()
creds = provider.load()
self.assertIsNotNone(creds)
self.assertEqual(creds.access_key, 'foo')
self.assertEqual(creds.secret_key, 'bar')
self.assertEqual(creds.token, 'baz')
self.assertEqual(creds.method, 'custom-process')
self.assertEqual(creds.account_id, '1234567890')

def test_account_id_from_process_takes_precedence(self):
self.loaded_config['profiles'] = {
'default': {
'credential_process': 'my-process',
'aws_account_id': '1234567890',
}
}
self._set_process_return_value(
{
'Version': 1,
'AccessKeyId': 'foo',
'SecretAccessKey': 'bar',
'SessionToken': 'baz',
'Expiration': '2999-01-01T00:00:00Z',
'AccountId': '0987654321',
}
)
provider = self.create_process_provider()
creds = provider.load()
self.assertIsNotNone(creds)
self.assertEqual(creds.access_key, 'foo')
self.assertEqual(creds.secret_key, 'bar')
self.assertEqual(creds.token, 'baz')
self.assertEqual(creds.method, 'custom-process')
self.assertEqual(creds.account_id, '0987654321')


class TestProfileProviderBuilder(unittest.TestCase):
def setUp(self):
Expand Down Expand Up @@ -3494,6 +3742,7 @@ def test_can_fetch_credentials(self):
'SecretAccessKey': 'bar',
'SessionToken': 'baz',
'Expiration': '2008-09-23T12:43:20Z',
'AccountId': '1234567890',
},
}
self.assertEqual(self.cache[cache_key], expected_cached_credentials)
Expand Down Expand Up @@ -3587,6 +3836,7 @@ def test_load_sso_credentials_without_cache(self):
self.assertEqual(credentials.access_key, 'foo')
self.assertEqual(credentials.secret_key, 'bar')
self.assertEqual(credentials.token, 'baz')
self.assertEqual(credentials.account_id, '1234567890')

def test_load_sso_credentials_with_cache(self):
cached_creds = {
Expand Down Expand Up @@ -3620,6 +3870,7 @@ def test_load_sso_credentials_with_cache_expired(self):
self.assertEqual(credentials.access_key, 'foo')
self.assertEqual(credentials.secret_key, 'bar')
self.assertEqual(credentials.token, 'baz')
self.assertEqual(credentials.account_id, '1234567890')

def test_required_config_not_set(self):
del self.config['sso_start_url']
Expand Down
Loading

0 comments on commit 8afd119

Please sign in to comment.