Skip to content

Commit

Permalink
Merge pull request #2953 from kdaily/shallow-copy-config-value-store
Browse files Browse the repository at this point in the history
Shallow copy config value store and providers
  • Loading branch information
kdaily committed May 27, 2023
2 parents f104576 + 249dbcd commit 619a317
Show file tree
Hide file tree
Showing 7 changed files with 267 additions and 49 deletions.
5 changes: 5 additions & 0 deletions .changes/next-release/enhancement-configprovider-27540.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
{
"type": "enhancement",
"category": "configprovider",
"description": "Always use shallow copy of session config value store for clients"
}
27 changes: 19 additions & 8 deletions botocore/configprovider.py
Original file line number Diff line number Diff line change
Expand Up @@ -408,6 +408,13 @@ def __deepcopy__(self, memo):

return config_store

def __copy__(self):
config_store = ConfigValueStore(copy.copy(self._mapping))
for logical_name, override_value in self._overrides.items():
config_store.set_config_variable(logical_name, override_value)

return config_store

def get_config_variable(self, logical_name):
"""
Retrieve the value associeated with the specified logical_name
Expand Down Expand Up @@ -546,24 +553,28 @@ def resolve_auto_mode(self, region_name):
return 'standard'

def _update_provider(self, config_store, variable, value):
provider = config_store.get_config_provider(variable)
original_provider = config_store.get_config_provider(variable)
default_provider = ConstantProvider(value)
if isinstance(provider, ChainProvider):
provider.set_default_provider(default_provider)
return
elif isinstance(provider, BaseProvider):
if isinstance(original_provider, ChainProvider):
chain_provider_copy = copy.deepcopy(original_provider)
chain_provider_copy.set_default_provider(default_provider)
default_provider = chain_provider_copy
elif isinstance(original_provider, BaseProvider):
default_provider = ChainProvider(
providers=[provider, default_provider]
providers=[original_provider, default_provider]
)
config_store.set_config_provider(variable, default_provider)

def _update_section_provider(
self, config_store, section_name, variable, value
):
section_provider = config_store.get_config_provider(section_name)
section_provider.set_default_provider(
section_provider_copy = copy.deepcopy(
config_store.get_config_provider(section_name)
)
section_provider_copy.set_default_provider(
variable, ConstantProvider(value)
)
config_store.set_config_provider(section_name, section_provider_copy)

def _set_retryMode(self, config_store, value):
self._update_provider(config_store, 'retry_mode', value)
Expand Down
4 changes: 2 additions & 2 deletions botocore/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -952,13 +952,13 @@ def create_client(
auth_token = self.get_auth_token()
endpoint_resolver = self._get_internal_component('endpoint_resolver')
exceptions_factory = self._get_internal_component('exceptions_factory')
config_store = self.get_component('config_store')

config_store = copy.copy(self.get_component('config_store'))
defaults_mode = self._resolve_defaults_mode(config, config_store)
if defaults_mode != 'legacy':
smart_defaults_factory = self._get_internal_component(
'smart_defaults_factory'
)
config_store = copy.deepcopy(config_store)
smart_defaults_factory.merge_smart_defaults(
config_store, defaults_mode, region_name
)
Expand Down
55 changes: 55 additions & 0 deletions tests/functional/models/sdk-default-configuration.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
{
"version": 1,
"base": {
"retryMode": "standard",
"stsRegionalEndpoints": "regional",
"s3UsEast1RegionalEndpoints": "regional",
"connectTimeoutInMillis": 9999000,
"tlsNegotiationTimeoutInMillis": 9999000
},
"modes": {
"standard": {
"connectTimeoutInMillis": {
"override": 9999000
},
"tlsNegotiationTimeoutInMillis": {
"override": 9999000
}
},
"in-region": {
},
"cross-region": {
"connectTimeoutInMillis": {
"override": 9999000
},
"tlsNegotiationTimeoutInMillis": {
"override": 9999000
}
},
"mobile": {
"connectTimeoutInMillis": {
"override": 99999000
},
"tlsNegotiationTimeoutInMillis": {
"override": 99999000
}
}
},
"documentation": {
"modes": {
"standard": "<p>FOR TESTING ONLY: The STANDARD mode provides the latest recommended default values that should be safe to run in most scenarios</p><p>Note that the default values vended from this mode might change as best practices may evolve. As a result, it is encouraged to perform tests when upgrading the SDK</p>",
"in-region": "<p>FOR TESTING ONLY: The IN_REGION mode builds on the standard mode and includes optimization tailored for applications which call AWS services from within the same AWS region</p><p>Note that the default values vended from this mode might change as best practices may evolve. As a result, it is encouraged to perform tests when upgrading the SDK</p>",
"cross-region": "<p>FOR TESTING ONLY: The CROSS_REGION mode builds on the standard mode and includes optimization tailored for applications which call AWS services in a different region</p><p>Note that the default values vended from this mode might change as best practices may evolve. As a result, it is encouraged to perform tests when upgrading the SDK</p>",
"mobile": "<p>FOR TESTING ONLY: The MOBILE mode builds on the standard mode and includes optimization tailored for mobile applications</p><p>Note that the default values vended from this mode might change as best practices may evolve. As a result, it is encouraged to perform tests when upgrading the SDK</p>",
"auto": "<p>FOR TESTING ONLY: The AUTO mode is an experimental mode that builds on the standard mode. The SDK will attempt to discover the execution environment to determine the appropriate settings automatically.</p><p>Note that the auto detection is heuristics-based and does not guarantee 100% accuracy. STANDARD mode will be used if the execution environment cannot be determined. The auto detection might query <a href=\"https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/ec2-instance-metadata.html\">EC2 Instance Metadata service</a>, which might introduce latency. Therefore we recommend choosing an explicit defaults_mode instead if startup latency is critical to your application</p>",
"legacy": "<p>FOR TESTING ONLY: The LEGACY mode provides default settings that vary per SDK and were used prior to establishment of defaults_mode</p>"
},
"configuration": {
"retryMode": "<p>FOR TESTING ONLY: A retry mode specifies how the SDK attempts retries. See <a href=\"https://docs.aws.amazon.com/sdkref/latest/guide/setting-global-retry_mode.html\">Retry Mode</a></p>",
"stsRegionalEndpoints": "<p>FOR TESTING ONLY: Specifies how the SDK determines the AWS service endpoint that it uses to talk to the AWS Security Token Service (AWS STS). See <a href=\"https://docs.aws.amazon.com/sdkref/latest/guide/setting-global-sts_regional_endpoints.html\">Setting STS Regional endpoints</a></p>",
"s3UsEast1RegionalEndpoints": "<p>FOR TESTING ONLY: Specifies how the SDK determines the AWS service endpoint that it uses to talk to the Amazon S3 for the us-east-1 region</p>",
"connectTimeoutInMillis": "<p>FOR TESTING ONLY: The amount of time after making an initial connection attempt on a socket, where if the client does not receive a completion of the connect handshake, the client gives up and fails the operation</p>",
"tlsNegotiationTimeoutInMillis": "<p>FOR TESTING ONLY: The maximum amount of time that a TLS handshake is allowed to take from the time the CLIENT HELLO message is sent to ethe time the client and server have fully negotiated ciphers and exchanged keys</p>"
}
}
}
76 changes: 74 additions & 2 deletions tests/functional/test_config_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,11 @@
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
from pathlib import Path

import pytest

import botocore.exceptions
from botocore.config import Config
from botocore.session import get_session

Expand All @@ -28,6 +31,13 @@
sdk_default_configuration = loader.load_data('sdk-default-configuration')


def assert_client_uses_standard_defaults(client):
assert client.meta.config.s3['us_east_1_regional_endpoint'] == 'regional'
assert client.meta.config.connect_timeout == 3.1
assert client.meta.endpoint_url == 'https://sts.us-west-2.amazonaws.com'
assert client.meta.config.retries['mode'] == 'standard'


@pytest.mark.parametrize("mode", sdk_default_configuration['base'])
def test_no_new_sdk_default_configuration_values(mode):
err_msg = (
Expand All @@ -45,7 +55,69 @@ def test_default_configurations_resolve_correctly():
client = session.create_client(
'sts', config=config, region_name='us-west-2'
)
assert_client_uses_standard_defaults(client)


@pytest.fixture
def loader():
test_models_dir = Path(__file__).parent / 'models'
loader = botocore.loaders.Loader()
loader.search_paths.insert(0, test_models_dir)
return loader


@pytest.fixture
def session(loader):
session = botocore.session.Session()
session.register_component('data_loader', loader)
return session


def assert_client_uses_legacy_defaults(client):
assert client.meta.config.s3 is None
assert client.meta.config.connect_timeout == 60
assert client.meta.endpoint_url == 'https://sts.amazonaws.com'
assert client.meta.config.retries['mode'] == 'legacy'


def assert_client_uses_testing_defaults(client):
assert client.meta.config.s3['us_east_1_regional_endpoint'] == 'regional'
assert client.meta.config.connect_timeout == 3.1
assert client.meta.endpoint_url == 'https://sts.us-west-2.amazonaws.com'
assert client.meta.config.connect_timeout == 9999
assert client.meta.endpoint_url == 'https://sts.amazonaws.com'
assert client.meta.config.retries['mode'] == 'standard'


class TestConfigurationDefaults:
def test_defaults_mode_resolved_from_config_store(self, session):
config_store = session.get_component('config_store')
config_store.set_config_variable('defaults_mode', 'standard')
client = session.create_client('sts', 'us-west-2')
assert_client_uses_testing_defaults(client)

def test_no_mutate_session_provider(self, session):
# Using the standard default mode should change the connect timeout
# on the client, but not the session
standard_client = session.create_client(
'sts', 'us-west-2', config=Config(defaults_mode='standard')
)
assert_client_uses_testing_defaults(standard_client)

# Using the legacy default mode should not change the connect timeout
# on the client or the session. By default the connect timeout for a client
# is 60 seconds, and unset on the session.
legacy_client = session.create_client('sts', 'us-west-2')
assert_client_uses_legacy_defaults(legacy_client)

def test_defaults_mode_resolved_from_client_config(self, session):
config = Config(defaults_mode='standard')
client = session.create_client('sts', 'us-west-2', config=config)
assert_client_uses_testing_defaults(client)

def test_defaults_mode_resolved_invalid_mode_exception(self, session):
with pytest.raises(botocore.exceptions.InvalidDefaultsMode):
config = Config(defaults_mode='invalid_default_mode')
session.create_client('sts', 'us-west-2', config=config)

def test_defaults_mode_resolved_legacy(self, session):
client = session.create_client('sts', 'us-west-2')
assert_client_uses_legacy_defaults(client)
121 changes: 112 additions & 9 deletions tests/unit/test_config_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -373,6 +373,50 @@ def test_deepcopy_preserves_overrides(self):
value = config_store_deepcopy.get_config_variable('fake_variable')
self.assertEqual(value, 'override-value')

def test_copy_preserves_provider_identities(self):
fake_variable_provider = ConstantProvider(100)
config_store = ConfigValueStore(
mapping={
'fake_variable': fake_variable_provider,
}
)

config_store_copy = copy.copy(config_store)

self.assertIs(
config_store.get_config_provider('fake_variable'),
config_store_copy.get_config_provider('fake_variable'),
)

def test_copy_preserves_overrides(self):
provider = ConstantProvider(100)
config_store = ConfigValueStore(mapping={'fake_variable': provider})
config_store.set_config_variable('fake_variable', 'override-value')

config_store_copy = copy.copy(config_store)

value = config_store_copy.get_config_variable('fake_variable')
self.assertEqual(value, 'override-value')

def test_copy_update_does_not_mutate_source_config_store(self):
fake_variable_provider = ConstantProvider(100)
config_store = ConfigValueStore(
mapping={
'fake_variable': fake_variable_provider,
}
)

config_store_copy = copy.copy(config_store)

another_variable_provider = ConstantProvider('ABC')

config_store_copy.set_config_provider(
'fake_variable', another_variable_provider
)

assert config_store.get_config_variable('fake_variable') == 100
assert config_store_copy.get_config_variable('fake_variable') == 'ABC'


class TestInstanceVarProvider(unittest.TestCase):
def assert_provides_value(self, name, instance_map, expected_value):
Expand Down Expand Up @@ -643,17 +687,19 @@ def fake_session(self):
return fake_session

def _create_config_value_store(self, s3_mapping={}, **override_kwargs):
provider_foo = ConstantProvider(value='foo')
environment_provider_foo = EnvironmentProvider(
constant_provider = ConstantProvider(value='my_sts_regional_endpoint')
environment_provider = EnvironmentProvider(
name='AWS_RETRY_MODE', env={'AWS_RETRY_MODE': None}
)
fake_session = mock.Mock(spec=session.Session)
fake_session.get_scoped_config.return_value = {}
# Testing with three different providers to validate
# SmartDefaultsConfigStoreFactory._get_new_chain_provider
mapping = {
'sts_regional_endpoints': ChainProvider(providers=[provider_foo]),
'retry_mode': ChainProvider(providers=[environment_provider_foo]),
'sts_regional_endpoints': ChainProvider(
providers=[constant_provider]
),
'retry_mode': ChainProvider(providers=[environment_provider]),
's3': SectionConfigProvider('s3', fake_session, s3_mapping),
}
mapping.update(**override_kwargs)
Expand All @@ -667,11 +713,68 @@ def _create_os_environ_patcher(self):

def test_config_store_deepcopy(self):
config_store = ConfigValueStore()
config_store.set_config_provider('foo', ConstantProvider('bar'))
config_store.set_config_provider(
'constant_value', ConstantProvider('ABC')
)
config_store_copy = copy.deepcopy(config_store)
config_store_copy.set_config_provider('fizz', ConstantProvider('buzz'))
assert config_store.get_config_variable('fizz') is None
assert config_store_copy.get_config_variable('foo') == 'bar'
config_store_copy.set_config_provider(
'constant_value_copy', ConstantProvider('123')
)
assert config_store.get_config_variable('constant_value_copy') is None
assert config_store_copy.get_config_variable('constant_value') == 'ABC'

def _create_config_value_store_to_test_merge(self):
environment_provider = EnvironmentProvider(
name='AWS_S3_US_EAST_1_REGIONAL_ENDPOINT',
env={},
)

s3_mapping = {
'us_east_1_regional_endpoint': ChainProvider(
providers=[environment_provider]
)
}

override_kwargs = {'connect_timeout': ConstantProvider(value=None)}

config_value_store = self._create_config_value_store(
s3_mapping=s3_mapping, **override_kwargs
)

return config_value_store

@pytest.mark.parametrize(
'config_variable,expected_value_before,expected_value_after',
[
['retry_mode', None, 'standard'],
['sts_regional_endpoints', 'my_sts_regional_endpoint', 'regional'],
['connect_timeout', None, 2],
['s3', None, {'us_east_1_regional_endpoint': 'regional'}],
],
)
def test_config_store_providers_not_mutated_after_merge(
self,
config_variable,
expected_value_before,
expected_value_after,
smart_defaults_factory,
):
"""Test uses the standard default mode from the template"""

config_value_store = self._create_config_value_store_to_test_merge()

provider = config_value_store.get_config_provider(config_variable)

smart_defaults_factory.merge_smart_defaults(
config_value_store, 'standard', 'some-region'
)

assert provider.provide() == expected_value_before

assert (
config_value_store.get_config_variable(config_variable)
== expected_value_after
)

@pytest.mark.parametrize(
'defaults_mode, retry_mode, sts_regional_endpoints,'
Expand Down Expand Up @@ -720,7 +823,7 @@ def test_resolve_default_values_on_config(
assert config_store.get_config_variable('connect_timeout') == 2

def test_no_resolve_default_s3_values_on_config(
self, smart_defaults_factory, fake_session
self, smart_defaults_factory
):
environment_provider = EnvironmentProvider(
name='AWS_S3_US_EAST_1_REGIONAL_ENDPOINT',
Expand Down
Loading

0 comments on commit 619a317

Please sign in to comment.