diff --git a/tests/functional/test_credentials.py b/tests/functional/test_credentials.py index 5981d1cedc..6c05d940b0 100644 --- a/tests/functional/test_credentials.py +++ b/tests/functional/test_credentials.py @@ -790,7 +790,8 @@ def assert_session_credentials(self, expected_params, **kwargs): expected_creds = self.create_random_credentials() response = self.create_assume_role_response(expected_creds) session = StubbedSession(**kwargs) - stubber = session.stub('sts') + config = Config(signature_version=UNSIGNED) + stubber = session.stub('sts', config=config) stubber.add_response( 'assume_role_with_web_identity', response, expected_params ) diff --git a/tests/unit/data/endpoints/valid-rules/aws-account-id.json b/tests/unit/data/endpoints/valid-rules/aws-account-id.json new file mode 100644 index 0000000000..debea100cd --- /dev/null +++ b/tests/unit/data/endpoints/valid-rules/aws-account-id.json @@ -0,0 +1,83 @@ +{ + "parameters": { + "Region": { + "type": "string", + "builtIn": "AWS::Region", + "documentation": "The region to dispatch this request, eg. `us-east-1`." + }, + "AccountId": { + "type": "string", + "builtIn": "AWS::Auth::AccountId", + "documentation": "The account ID to dispatch this request, eg. `123456789012`." + } + }, + "rules": [ + { + "documentation": "Template the account ID into the URI when account ID is set", + "conditions": [ + { + "fn": "isSet", + "argv": [ + { + "ref": "AccountId" + } + ] + }, + { + "fn": "isSet", + "argv": [ + { + "ref": "Region" + } + ] + } + ], + "endpoint": { + "url": "https://{AccountId}.amazonaws.com", + "properties": { + "authSchemes": [ + { + "name": "sigv4", + "signingName": "serviceName", + "signingRegion": "{Region}" + } + ] + } + }, + "type": "endpoint" + }, + { + "documentation": "Fallback when account ID isn't set", + "conditions": [ + { + "fn": "isSet", + "argv": [ + { + "ref": "Region" + } + ] + } + ], + "endpoint": { + "url": "https://amazonaws.com", + "properties": { + "authSchemes": [ + { + "name": "sigv4", + "signingName": "serviceName", + "signingRegion": "{Region}" + } + ] + } + }, + "type": "endpoint" + }, + { + "documentation": "fallback when region is unset", + "conditions": [], + "error": "Region must be set to resolve a valid endpoint", + "type": "error" + } + ], + "version": "1.3" +} diff --git a/tests/unit/test_args.py b/tests/unit/test_args.py index 8f1c992422..4d90c0d17d 100644 --- a/tests/unit/test_args.py +++ b/tests/unit/test_args.py @@ -614,6 +614,30 @@ def test_bad_value_disable_request_compression(self): config = client_args['client_config'] self.assertFalse(config.disable_request_compression) + def test_account_id_endpoint_mode_config_store(self): + self.config_store.set_config_variable( + 'account_id_endpoint_mode', 'preferred' + ) + config = self.call_get_client_args()['client_config'] + self.assertEqual(config.account_id_endpoint_mode, 'preferred') + + def test_account_id_endpoint_mode_client_config(self): + config = Config(account_id_endpoint_mode='preferred') + config = self.call_get_client_args(client_config=config) + client_config = config['client_config'] + self.assertEqual(client_config.account_id_endpoint_mode, 'preferred') + + def test_account_id_endpoint_mode_client_config_overrides_config_store( + self, + ): + self.config_store.set_config_variable( + 'account_id_endpoint_mode', 'preferred' + ) + config = Config(account_id_endpoint_mode='required') + config = self.call_get_client_args(client_config=config) + client_config = config['client_config'] + self.assertEqual(client_config.account_id_endpoint_mode, 'required') + class TestEndpointResolverBuiltins(unittest.TestCase): def setUp(self): @@ -679,6 +703,7 @@ def test_builtins_defaults(self): bins['AWS::S3::DisableMultiRegionAccessPoints'], False ) self.assertEqual(bins['SDK::Endpoint'], None) + self.assertEqual(bins['AWS::Auth::AccountId'], None) def test_aws_region(self): bins = self.call_compute_endpoint_resolver_builtin_defaults( diff --git a/tests/unit/test_endpoint_provider.py b/tests/unit/test_endpoint_provider.py index 51a07079bc..9cc241780d 100644 --- a/tests/unit/test_endpoint_provider.py +++ b/tests/unit/test_endpoint_provider.py @@ -18,6 +18,9 @@ import pytest +from botocore import UNSIGNED +from botocore.config import Config +from botocore.credentials import Credentials from botocore.endpoint_provider import ( EndpointProvider, EndpointRule, @@ -28,12 +31,14 @@ TreeRule, ) from botocore.exceptions import ( + AccountIDNotFound, EndpointResolutionError, + InvalidConfigError, MissingDependencyException, UnknownSignatureVersionError, ) from botocore.loaders import Loader -from botocore.regions import EndpointRulesetResolver +from botocore.regions import EndpointResolverBuiltins, EndpointRulesetResolver from tests import requires_crt REGION_TEMPLATE = "{Region}" @@ -98,8 +103,7 @@ def rule_lib(partitions): return RuleSetStandardLibrary(partitions) -@pytest.fixture(scope="module") -def ruleset_dict(): +def _ruleset_dict(): path = os.path.join( os.path.dirname(__file__), "data", @@ -111,6 +115,11 @@ def ruleset_dict(): return json.load(f) +@pytest.fixture(scope="module") +def ruleset_dict(): + return _ruleset_dict() + + @pytest.fixture(scope="module") def endpoint_provider(ruleset_dict, partitions): return EndpointProvider(ruleset_dict, partitions) @@ -511,3 +520,218 @@ def test_aws_is_virtual_hostable_s3_bucket_allow_subdomains( rule_lib.aws_is_virtual_hostable_s3_bucket(bucket, True) == expected_value ) + + +def _account_id_ruleset(): + rule_path = os.path.join( + os.path.dirname(__file__), + "data", + "endpoints", + "valid-rules", + "aws-account-id.json", + ) + with open(rule_path) as f: + return json.load(f) + + +@pytest.fixture +def operation_model_empty_context_params(): + operation_model = Mock() + operation_model.static_context_parameters = [] + operation_model.context_parameters = [] + return operation_model + + +ACCOUNT_ID_RULESET = _account_id_ruleset() +BUILTINS_WITH_UNRESOLVED_ACCOUNT_ID = { + EndpointResolverBuiltins.AWS_REGION: "us-west-2", + EndpointResolverBuiltins.AWS_ACCOUNT_ID: None, +} +BUILTINS_WITH_RESOLVED_ACCOUNT_ID = { + EndpointResolverBuiltins.AWS_REGION: "us-west-2", + EndpointResolverBuiltins.AWS_ACCOUNT_ID: "0987654321", +} +STATIC_CREDENTIALS = Credentials( + access_key="access_key", + secret_key="secret_key", + token="token", + account_id="1234567890", +) + + +def create_ruleset_resolver(ruleset, bulitins, credentials, auth_scheme): + service_model = Mock() + service_model.client_context_parameters = [] + return EndpointRulesetResolver( + endpoint_ruleset_data=ruleset, + partition_data={}, + service_model=service_model, + builtins=bulitins, + client_context=None, + event_emitter=Mock(), + use_ssl=True, + credentials=credentials, + requested_auth_scheme=auth_scheme, + ) + + +ACT_ID_REQUIRED_CONTEXT = { + "client_config": Config(account_id_endpoint_mode="required") +} +ACT_ID_PREFERRED_CONTEXT = { + "client_config": Config(account_id_endpoint_mode="preferred") +} +ACT_ID_DISABLED_CONTEXT = { + "client_config": Config(account_id_endpoint_mode="disabled") +} + +URL_NO_ACCOUNT_ID = "https://amazonaws.com" +URL_WITH_ACCOUNT_ID = "https://1234567890.amazonaws.com" + + +@pytest.mark.parametrize( + "ruleset, builtins, credentials, auth_scheme, request_context, expected_url", + [ + ( + ACCOUNT_ID_RULESET, + BUILTINS_WITH_UNRESOLVED_ACCOUNT_ID, + STATIC_CREDENTIALS, + None, + ACT_ID_REQUIRED_CONTEXT, + URL_WITH_ACCOUNT_ID, + ), + ( + ACCOUNT_ID_RULESET, + BUILTINS_WITH_UNRESOLVED_ACCOUNT_ID, + STATIC_CREDENTIALS, + None, + ACT_ID_PREFERRED_CONTEXT, + URL_WITH_ACCOUNT_ID, + ), + # custom account ID takes precedence over credentials + ( + ACCOUNT_ID_RULESET, + BUILTINS_WITH_RESOLVED_ACCOUNT_ID, + STATIC_CREDENTIALS, + None, + ACT_ID_REQUIRED_CONTEXT, + "https://0987654321.amazonaws.com", + ), + # no account ID builtin in ruleset + ( + _ruleset_dict(), + BUILTINS_WITH_UNRESOLVED_ACCOUNT_ID, + STATIC_CREDENTIALS, + None, + ACT_ID_REQUIRED_CONTEXT, + "https://us-west-2.amazonaws.com", + ), + ( + ACCOUNT_ID_RULESET, + BUILTINS_WITH_UNRESOLVED_ACCOUNT_ID, + STATIC_CREDENTIALS, + None, + ACT_ID_DISABLED_CONTEXT, + URL_NO_ACCOUNT_ID, + ), + # custom account ID removed if account ID mode is disabled + ( + ACCOUNT_ID_RULESET, + BUILTINS_WITH_RESOLVED_ACCOUNT_ID, + STATIC_CREDENTIALS, + None, + ACT_ID_DISABLED_CONTEXT, + URL_NO_ACCOUNT_ID, + ), + ( + ACCOUNT_ID_RULESET, + BUILTINS_WITH_UNRESOLVED_ACCOUNT_ID, + STATIC_CREDENTIALS, + UNSIGNED, + ACT_ID_REQUIRED_CONTEXT, + URL_NO_ACCOUNT_ID, + ), + ( + ACCOUNT_ID_RULESET, + BUILTINS_WITH_UNRESOLVED_ACCOUNT_ID, + STATIC_CREDENTIALS, + None, + {**ACT_ID_REQUIRED_CONTEXT, "is_presign_request": True}, + URL_NO_ACCOUNT_ID, + ), + # no credentials + ( + ACCOUNT_ID_RULESET, + BUILTINS_WITH_UNRESOLVED_ACCOUNT_ID, + None, + None, + ACT_ID_PREFERRED_CONTEXT, + URL_NO_ACCOUNT_ID, + ), + # no account ID in credentials + ( + ACCOUNT_ID_RULESET, + BUILTINS_WITH_UNRESOLVED_ACCOUNT_ID, + Credentials(access_key="foo", secret_key="bar", token="baz"), + None, + ACT_ID_PREFERRED_CONTEXT, + URL_NO_ACCOUNT_ID, + ), + ], +) +def test_account_id_builtin( + operation_model_empty_context_params, + ruleset, + builtins, + credentials, + auth_scheme, + request_context, + expected_url, +): + resolver = create_ruleset_resolver( + ruleset, builtins, credentials, auth_scheme + ) + endpoint = resolver.construct_endpoint( + operation_model=operation_model_empty_context_params, + request_context=request_context, + call_args={}, + ) + assert endpoint.url == expected_url + + +@pytest.mark.parametrize( + "credentials, auth_scheme, request_context, expected_error", + [ + ( + STATIC_CREDENTIALS, + None, + {'client_config': Config(account_id_endpoint_mode="foo")}, + InvalidConfigError, + ), + ( + Credentials(access_key="foo", secret_key="bar", token="baz"), + None, + ACT_ID_REQUIRED_CONTEXT, + AccountIDNotFound, + ), + ], +) +def test_account_id_error_cases( + operation_model_empty_context_params, + credentials, + auth_scheme, + request_context, + expected_error, +): + resolver = create_ruleset_resolver( + ACCOUNT_ID_RULESET, + BUILTINS_WITH_UNRESOLVED_ACCOUNT_ID, + credentials, + auth_scheme, + ) + with pytest.raises(expected_error): + resolver.construct_endpoint( + operation_model=operation_model_empty_context_params, + request_context=request_context, + call_args={}, + )