diff --git a/botocore/awsrequest.py b/botocore/awsrequest.py index d12f935945..f00a0dde57 100644 --- a/botocore/awsrequest.py +++ b/botocore/awsrequest.py @@ -12,8 +12,6 @@ # ANY KIND, either express or implied. See the License for the specific # language governing permissions and limitations under the License. import functools -import gzip -import io import logging from collections.abc import Mapping @@ -536,92 +534,6 @@ def reset_stream(self): raise UnseekableStreamError(stream_object=self.body) -class RequestCompressor: - """A class that can compress the body of an ``AWSRequest``.""" - - def compress(self, config, request_dict, operation_model): - """Compresses the request body using the specified encodings. - - Check if the request should be compressed based on the contents of its - body and config settings. Set or append the `Content-Encoding` header - with the matched encoding if not present. - """ - body = request_dict['body'] - if self._should_compress_request(config, body, operation_model): - encodings = operation_model.request_compression['encodings'] - headers = request_dict.get('headers', {}) - for encoding in encodings: - encoder = getattr(self, f'_{encoding}_compress_body', None) - if encoder is not None: - ce_header = headers.get('Content-Encoding') - if ce_header is None: - headers['Content-Encoding'] = encoding - elif encoding not in ce_header.split(','): - headers['Content-Encoding'] = f'{ce_header},{encoding}' - request_dict['body'] = encoder(body) - if 'headers' not in request_dict: - request_dict['headers'] = headers - else: - logger.debug( - 'Unsupported compression encoding: %s' % encoding - ) - - def _gzip_compress_body(self, body): - if isinstance(body, str): - return gzip.compress(body.encode('utf-8')) - elif isinstance(body, (bytes, bytearray)): - return gzip.compress(body) - elif hasattr(body, 'read'): - if hasattr(body, 'seek') and hasattr(body, 'tell'): - current_position = body.tell() - compressed_obj = self._gzip_compress_fileobj(body) - body.seek(current_position) - return compressed_obj - return self._gzip_compress_fileobj(body) - - def _gzip_compress_fileobj(self, body): - compressed_obj = io.BytesIO() - with gzip.GzipFile(fileobj=compressed_obj, mode='wb') as gz: - while True: - chunk = body.read(8192) - if not chunk: - break - if isinstance(chunk, str): - chunk = chunk.encode('utf-8') - gz.write(chunk) - compressed_obj.seek(0) - return compressed_obj - - def _should_compress_request(self, config, body, operation_model): - if ( - config.disable_request_compression is not True - and config.signature_version != 'v2' - and operation_model.request_compression is not None - ): - # Request is compressed no matter the content length if it has a streaming input. - # However, if the stream has the `requiresLength` trait it is NOT compressed. - if operation_model.has_streaming_input: - return ( - 'requiresLength' - not in operation_model.get_streaming_input().metadata - ) - return ( - config.request_min_compression_size_bytes - <= self._get_body_size(body) - ) - return False - - def _get_body_size(self, body): - size = botocore.utils.determine_content_length(body) - if size is None: - logger.debug( - 'Unable to get length of the request body: %s. Not compressing.' - % body - ) - return -1 - return size - - class AWSResponse: """A data class representing an HTTP response. diff --git a/botocore/client.py b/botocore/client.py index b884bff0af..b03c26df97 100644 --- a/botocore/client.py +++ b/botocore/client.py @@ -15,7 +15,8 @@ from botocore import waiter, xform_name from botocore.args import ClientArgsCreator from botocore.auth import AUTH_TYPE_MAPS -from botocore.awsrequest import RequestCompressor, prepare_request_dict +from botocore.awsrequest import prepare_request_dict +from botocore.compress import RequestCompressor from botocore.config import Config from botocore.discovery import ( EndpointDiscoveryHandler, @@ -47,6 +48,7 @@ S3RegionRedirectorv2, ensure_boolean, get_service_module_name, + urlencode_query_body, ) # Keep these imported. There's pre-existing code that uses: @@ -72,7 +74,6 @@ 's3v4', ) ) -REQUEST_COMPRESSOR = RequestCompressor() logger = logging.getLogger(__name__) @@ -956,7 +957,10 @@ def _make_api_call(self, operation_name, api_params): if event_response is not None: http, parsed_response = event_response else: - REQUEST_COMPRESSOR.compress( + urlencode_query_body( + request_dict, operation_model, self.meta.config + ) + RequestCompressor.compress( self.meta.config, request_dict, operation_model ) apply_request_checksum(request_dict) diff --git a/botocore/compress.py b/botocore/compress.py new file mode 100644 index 0000000000..52eaa86b88 --- /dev/null +++ b/botocore/compress.py @@ -0,0 +1,114 @@ +# Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. + +import gzip +import io +import logging + +from botocore.utils import determine_content_length + +logger = logging.getLogger(__name__) + + +class RequestCompressor: + """A class that can compress the body of an ``AWSRequest``.""" + + @classmethod + def compress(cls, config, request_dict, operation_model): + """Compresses the request body using the specified encodings. + + Check if the request should be compressed based on the contents of its + body and config settings. Set or append the `Content-Encoding` header + with the matched encoding if not present. + """ + body = request_dict['body'] + if cls._should_compress_request(config, body, operation_model): + encodings = operation_model.request_compression['encodings'] + headers = request_dict.get('headers', {}) + for encoding in encodings: + encoder = getattr(cls, f'_{encoding}_compress_body', None) + if encoder is not None: + ce_header = headers.get('Content-Encoding') + if ce_header is None: + headers['Content-Encoding'] = encoding + elif encoding not in ce_header.split(','): + headers['Content-Encoding'] = f'{ce_header},{encoding}' + logger.debug( + 'Compressing request with %s encoding', encoding + ) + request_dict['body'] = encoder(body) + if 'headers' not in request_dict: + request_dict['headers'] = headers + else: + logger.debug( + 'Unsupported compression encoding: %s' % encoding + ) + + @classmethod + def _gzip_compress_body(cls, body): + if isinstance(body, str): + return gzip.compress(body.encode('utf-8')) + elif isinstance(body, (bytes, bytearray)): + return gzip.compress(body) + elif hasattr(body, 'read'): + if hasattr(body, 'seek') and hasattr(body, 'tell'): + current_position = body.tell() + compressed_obj = cls._gzip_compress_fileobj(body) + body.seek(current_position) + return compressed_obj + return cls._gzip_compress_fileobj(body) + + @staticmethod + def _gzip_compress_fileobj(body): + compressed_obj = io.BytesIO() + with gzip.GzipFile(fileobj=compressed_obj, mode='wb') as gz: + while True: + chunk = body.read(8192) + if not chunk: + break + if isinstance(chunk, str): + chunk = chunk.encode('utf-8') + gz.write(chunk) + compressed_obj.seek(0) + return compressed_obj + + @classmethod + def _should_compress_request(cls, config, body, operation_model): + if ( + config.disable_request_compression is not True + and config.signature_version != 'v2' + and operation_model.request_compression is not None + ): + # Request is compressed no matter the content length if it has a streaming input. + # However, if the stream has the `requiresLength` trait it is NOT compressed. + if operation_model.has_streaming_input: + return ( + 'requiresLength' + not in operation_model.get_streaming_input().metadata + ) + return ( + config.request_min_compression_size_bytes + <= cls._get_body_size(body) + ) + return False + + @staticmethod + def _get_body_size(body): + size = determine_content_length(body) + if size is None: + logger.debug( + 'Unable to get length of the request body: %s. Not compressing.' + % body + ) + return -1 + return size diff --git a/botocore/configprovider.py b/botocore/configprovider.py index 1f9155e320..778507775d 100644 --- a/botocore/configprovider.py +++ b/botocore/configprovider.py @@ -147,6 +147,8 @@ # whatever the defaults are in _retry.json. 'max_attempts': ('max_attempts', 'AWS_MAX_ATTEMPTS', None, int), 'user_agent_appid': ('sdk_ua_app_id', 'AWS_SDK_UA_APP_ID', None, None), + # This must be a parsable integer between 0 and 1048576, but validation + # is performed during client initialization instead of here. 'request_min_compression_size_bytes': ( 'request_min_compression_size_bytes', 'AWS_REQUEST_MIN_COMPRESSION_SIZE_BYTES', diff --git a/botocore/handlers.py b/botocore/handlers.py index 863bb54921..55087f6749 100644 --- a/botocore/handlers.py +++ b/botocore/handlers.py @@ -38,7 +38,6 @@ quote, unquote, unquote_str, - urlencode, urlsplit, urlunsplit, ) @@ -1145,23 +1144,6 @@ def remove_content_type_header_for_presigning(request, **kwargs): del request.headers['Content-Type'] -def urlencode_body(model, params, context, **kwargs): - """URL-encode the request body if it is a dictionary. - - This is used for services using the query protocol. The body must be - serialized as a URL-encoded string before it can be compressed. - """ - body = params.get('body') - if ( - context['client_config'].signature_version != 'v2' - and model.service_model.protocol == 'query' - and isinstance(body, dict) - ): - params['body'] = urlencode(body, doseq=True, encoding='utf-8').encode( - 'utf-8' - ) - - # This is a list of (event_name, handler). # When a Session is created, everything in this list will be # automatically registered with that Session. @@ -1420,6 +1402,5 @@ def urlencode_body(model, params, context, **kwargs): AutoPopulatedParam('PreSignedUrl').document_auto_populated_param, ), ('before-call', inject_api_version_header_if_needed), - ('before-call', urlencode_body), ] _add_parameter_aliases(BUILTIN_HANDLERS) diff --git a/botocore/utils.py b/botocore/utils.py index 266d204629..9d35c61211 100644 --- a/botocore/utils.py +++ b/botocore/utils.py @@ -55,6 +55,7 @@ get_tzinfo_options, json, quote, + urlencode, urlparse, urlsplit, urlunsplit, @@ -3429,3 +3430,16 @@ def _serialize_if_needed(self, value, iso=False): 'stepfunctions': 'sfn', 'storagegateway': 'storage-gateway', } + + +def urlencode_query_body(request_dict, operation_model, config): + """URL encode a request's body if it using the query protocol.""" + body = request_dict['body'] + if ( + operation_model.service_model.protocol == 'query' + and isinstance(body, dict) + and config.signature_version != 'v2' + ): + request_dict['body'] = urlencode( + body, doseq=True, encoding='utf-8' + ).encode('utf-8') diff --git a/tests/functional/test_compression.py b/tests/functional/test_compress.py similarity index 95% rename from tests/functional/test_compression.py rename to tests/functional/test_compress.py index b5289fe07e..3ad0465240 100644 --- a/tests/functional/test_compression.py +++ b/tests/functional/test_compress.py @@ -97,9 +97,10 @@ def _all_compression_operations(): @pytest.mark.parametrize("operation_model", _all_compression_operations()) def test_no_unknown_compression_encodings(operation_model): for encoding in operation_model.request_compression["encodings"]: - assert ( - encoding in KNOWN_COMPRESSION_ENCODINGS - ), f"Found unknown compression encoding '{encoding}' in operation {operation_model.name}" + assert encoding in KNOWN_COMPRESSION_ENCODINGS, ( + f"Found unknown compression encoding '{encoding}' " + f"in operation {operation_model.name}" + ) def test_compression(patched_session, monkeypatch): diff --git a/tests/unit/test_awsrequest.py b/tests/unit/test_awsrequest.py index bba00332c1..a29a846116 100644 --- a/tests/unit/test_awsrequest.py +++ b/tests/unit/test_awsrequest.py @@ -27,12 +27,10 @@ AWSRequest, AWSResponse, HeadersDict, - RequestCompressor, create_request_object, prepare_request_dict, ) from botocore.compat import file_type -from botocore.config import Config from botocore.exceptions import UnseekableStreamError from tests import mock, unittest @@ -109,97 +107,6 @@ def tell(self): return self._stream.tell() -def _op_with_compression(): - op = mock.Mock() - op.request_compression = {'encodings': ['gzip']} - op.has_streaming_input = False - return op - - -def _op_unknown_compression(): - op = mock.Mock() - op.request_compression = {'encodings': ['foo']} - op.has_streaming_input = None - return op - - -def _op_without_compression(): - op = mock.Mock() - op.request_compression = None - op.has_streaming_input = False - return op - - -def _streaming_op_with_compression(): - op = _op_with_compression() - op.has_streaming_input = True - streaming_shape = mock.Mock() - streaming_shape.metadata = {} - op.get_streaming_input.return_value = streaming_shape - return op - - -def _streaming_op_with_compression_requires_length(): - op = _streaming_op_with_compression() - streaming_shape = mock.Mock() - streaming_shape.metadata = {'requiresLength': True} - op.get_streaming_input.return_value = streaming_shape - return op - - -OP_WITH_COMPRESSION = _op_with_compression() -OP_UNKNOWN_COMPRESSION = _op_unknown_compression() -STREAMING_OP_WITH_COMPRESSION = _streaming_op_with_compression() -REQUEST_BODY = ( - b'Action=PutMetricData&Version=2010-08-01&Namespace=Namespace' - b'&MetricData.member.1.MetricName=metric&MetricData.member.1.Unit=Bytes' - b'&MetricData.member.1.Value=128' -) - -DEFAULT_COMPRESSION_CONFIG = Config( - disable_request_compression=False, - request_min_compression_size_bytes=10420, -) -COMPRESSION_CONFIG_128_BYTES = Config( - disable_request_compression=False, - request_min_compression_size_bytes=128, -) -COMPRESSION_CONFIG_0_BYTES = Config( - disable_request_compression=False, - request_min_compression_size_bytes=0, -) - - -def request_dict(): - return { - 'body': REQUEST_BODY, - 'headers': {'foo': 'bar'}, - } - - -def request_dict_with_content_encoding_header(): - return { - 'body': REQUEST_BODY, - 'headers': {'foo': 'bar', 'Content-Encoding': 'identity'}, - } - - -@pytest.fixture(scope='module') -def request_compressor(): - return RequestCompressor() - - -COMPRESSION_HEADERS = {'gzip': b'\x1f\x8b'} - - -def _assert_compression(is_compressed, body, encoding): - if hasattr(body, 'read'): - header = body.read(2) - else: - header = body[:2] - assert is_compressed == (header == COMPRESSION_HEADERS.get(encoding)) - - class TestAWSRequest(unittest.TestCase): def setUp(self): self.tempdir = tempfile.mkdtemp() @@ -891,151 +798,5 @@ def test_iteration(self): self.assertIn(('dead', 'beef'), headers_items) -@pytest.mark.parametrize( - 'config, request_dict, operation_model, is_compressed, encoding', - [ - ( - Config( - disable_request_compression=True, - request_min_compression_size_bytes=1000, - ), - {'body': b'foo'}, - OP_WITH_COMPRESSION, - False, - None, - ), - ( - COMPRESSION_CONFIG_128_BYTES, - request_dict(), - OP_WITH_COMPRESSION, - True, - 'gzip', - ), - ( - Config( - disable_request_compression=False, - request_min_compression_size_bytes=256, - ), - request_dict(), - OP_WITH_COMPRESSION, - False, - None, - ), - ( - Config(request_min_compression_size_bytes=128), - request_dict(), - OP_WITH_COMPRESSION, - True, - 'gzip', - ), - ( - DEFAULT_COMPRESSION_CONFIG, - request_dict(), - STREAMING_OP_WITH_COMPRESSION, - True, - 'gzip', - ), - ( - DEFAULT_COMPRESSION_CONFIG, - request_dict(), - _streaming_op_with_compression_requires_length(), - False, - None, - ), - ( - DEFAULT_COMPRESSION_CONFIG, - request_dict(), - _op_without_compression(), - False, - None, - ), - ( - COMPRESSION_CONFIG_128_BYTES, - request_dict(), - OP_UNKNOWN_COMPRESSION, - False, - None, - ), - ( - COMPRESSION_CONFIG_128_BYTES, - {'body': REQUEST_BODY.decode()}, - OP_WITH_COMPRESSION, - True, - 'gzip', - ), - ( - COMPRESSION_CONFIG_128_BYTES, - {'body': bytearray(REQUEST_BODY)}, - OP_WITH_COMPRESSION, - True, - 'gzip', - ), - ( - COMPRESSION_CONFIG_128_BYTES, - {'body': io.BytesIO(REQUEST_BODY)}, - OP_WITH_COMPRESSION, - True, - 'gzip', - ), - ( - COMPRESSION_CONFIG_128_BYTES, - {'body': io.StringIO(REQUEST_BODY.decode())}, - OP_WITH_COMPRESSION, - True, - 'gzip', - ), - ( - COMPRESSION_CONFIG_128_BYTES, - request_dict_with_content_encoding_header(), - OP_UNKNOWN_COMPRESSION, - False, - 'foo', - ), - ( - COMPRESSION_CONFIG_128_BYTES, - request_dict_with_content_encoding_header(), - OP_WITH_COMPRESSION, - True, - 'gzip', - ), - ], -) -def test_compress( - request_compressor, - config, - request_dict, - operation_model, - is_compressed, - encoding, -): - request_compressor.compress(config, request_dict, operation_model) - _assert_compression(is_compressed, request_dict['body'], encoding) - assert ( - 'headers' in request_dict - and 'Content-Encoding' in request_dict['headers'] - and encoding in request_dict['headers']['Content-Encoding'] - ) == is_compressed - - -@pytest.mark.parametrize('body', [1, object(), None, True, 1.0]) -def test_compress_bad_types(request_compressor, body): - request_dict = {'body': body} - request_compressor.compress( - COMPRESSION_CONFIG_0_BYTES, request_dict, OP_WITH_COMPRESSION - ) - assert request_dict['body'] == body - - -@pytest.mark.parametrize( - 'body', - [io.StringIO('foo'), io.BytesIO(b'foo')], -) -def test_body_streams_position_reset(request_compressor, body): - request_compressor.compress( - COMPRESSION_CONFIG_0_BYTES, {'body': body}, OP_WITH_COMPRESSION - ) - assert body.tell() == 0 - - if __name__ == "__main__": unittest.main() diff --git a/tests/unit/test_client.py b/tests/unit/test_client.py index 3a9529665d..c06aa6c672 100644 --- a/tests/unit/test_client.py +++ b/tests/unit/test_client.py @@ -1468,7 +1468,7 @@ def inject_params(params, **kwargs): # Ensure the handler passed on the correct param values. body = self.endpoint.make_request.call_args[0][1]['body'] - self.assertEqual(body['Foo'], 'zero') + self.assertIn(b'Foo=zero', body) def test_client_default_for_s3_addressing_style(self): creator = self.create_client_creator() @@ -1748,8 +1748,8 @@ def test_request_compression_client_config_overrides_config_store(self): self.assertEqual( service_client.meta.config.request_min_compression_size_bytes, 0 ) - self.assertEqual( - service_client.meta.config.disable_request_compression, False + self.assertFalse( + service_client.meta.config.disable_request_compression ) def test_bad_request_min_compression_size_bytes(self): diff --git a/tests/unit/test_compress.py b/tests/unit/test_compress.py new file mode 100644 index 0000000000..08a06c7069 --- /dev/null +++ b/tests/unit/test_compress.py @@ -0,0 +1,250 @@ +# Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +import io + +import pytest + +from botocore.compress import RequestCompressor +from botocore.config import Config +from tests import mock + + +def _op_with_compression(): + op = mock.Mock() + op.request_compression = {'encodings': ['gzip']} + op.has_streaming_input = False + return op + + +def _op_unknown_compression(): + op = mock.Mock() + op.request_compression = {'encodings': ['foo']} + op.has_streaming_input = None + return op + + +def _op_without_compression(): + op = mock.Mock() + op.request_compression = None + op.has_streaming_input = False + return op + + +def _streaming_op_with_compression(): + op = _op_with_compression() + op.has_streaming_input = True + streaming_shape = mock.Mock() + streaming_shape.metadata = {} + op.get_streaming_input.return_value = streaming_shape + return op + + +def _streaming_op_with_compression_requires_length(): + op = _streaming_op_with_compression() + streaming_shape = mock.Mock() + streaming_shape.metadata = {'requiresLength': True} + op.get_streaming_input.return_value = streaming_shape + return op + + +OP_WITH_COMPRESSION = _op_with_compression() +OP_UNKNOWN_COMPRESSION = _op_unknown_compression() +STREAMING_OP_WITH_COMPRESSION = _streaming_op_with_compression() +REQUEST_BODY = ( + b'Action=PutMetricData&Version=2010-08-01&Namespace=Namespace' + b'&MetricData.member.1.MetricName=metric&MetricData.member.1.Unit=Bytes' + b'&MetricData.member.1.Value=128' +) + +DEFAULT_COMPRESSION_CONFIG = Config( + disable_request_compression=False, + request_min_compression_size_bytes=10420, +) +COMPRESSION_CONFIG_128_BYTES = Config( + disable_request_compression=False, + request_min_compression_size_bytes=128, +) +COMPRESSION_CONFIG_0_BYTES = Config( + disable_request_compression=False, + request_min_compression_size_bytes=0, +) + + +def request_dict(): + return { + 'body': REQUEST_BODY, + 'headers': {'foo': 'bar'}, + } + + +def request_dict_with_content_encoding_header(): + return { + 'body': REQUEST_BODY, + 'headers': {'foo': 'bar', 'Content-Encoding': 'identity'}, + } + + +COMPRESSION_HEADERS = {'gzip': b'\x1f\x8b'} + + +def _assert_compression(is_compressed, body, encoding): + if hasattr(body, 'read'): + header = body.read(2) + else: + header = body[:2] + assert is_compressed == (header == COMPRESSION_HEADERS.get(encoding)) + + +@pytest.mark.parametrize( + 'config, request_dict, operation_model, is_compressed, encoding', + [ + ( + Config( + disable_request_compression=True, + request_min_compression_size_bytes=1000, + ), + {'body': b'foo'}, + OP_WITH_COMPRESSION, + False, + None, + ), + ( + COMPRESSION_CONFIG_128_BYTES, + request_dict(), + OP_WITH_COMPRESSION, + True, + 'gzip', + ), + ( + Config( + disable_request_compression=False, + request_min_compression_size_bytes=256, + ), + request_dict(), + OP_WITH_COMPRESSION, + False, + None, + ), + ( + Config(request_min_compression_size_bytes=128), + request_dict(), + OP_WITH_COMPRESSION, + True, + 'gzip', + ), + ( + DEFAULT_COMPRESSION_CONFIG, + request_dict(), + STREAMING_OP_WITH_COMPRESSION, + True, + 'gzip', + ), + ( + DEFAULT_COMPRESSION_CONFIG, + request_dict(), + _streaming_op_with_compression_requires_length(), + False, + None, + ), + ( + DEFAULT_COMPRESSION_CONFIG, + request_dict(), + _op_without_compression(), + False, + None, + ), + ( + COMPRESSION_CONFIG_128_BYTES, + request_dict(), + OP_UNKNOWN_COMPRESSION, + False, + None, + ), + ( + COMPRESSION_CONFIG_128_BYTES, + {'body': REQUEST_BODY.decode()}, + OP_WITH_COMPRESSION, + True, + 'gzip', + ), + ( + COMPRESSION_CONFIG_128_BYTES, + {'body': bytearray(REQUEST_BODY)}, + OP_WITH_COMPRESSION, + True, + 'gzip', + ), + ( + COMPRESSION_CONFIG_128_BYTES, + {'body': io.BytesIO(REQUEST_BODY)}, + OP_WITH_COMPRESSION, + True, + 'gzip', + ), + ( + COMPRESSION_CONFIG_128_BYTES, + {'body': io.StringIO(REQUEST_BODY.decode())}, + OP_WITH_COMPRESSION, + True, + 'gzip', + ), + ( + COMPRESSION_CONFIG_128_BYTES, + request_dict_with_content_encoding_header(), + OP_UNKNOWN_COMPRESSION, + False, + 'foo', + ), + ( + COMPRESSION_CONFIG_128_BYTES, + request_dict_with_content_encoding_header(), + OP_WITH_COMPRESSION, + True, + 'gzip', + ), + ], +) +def test_compress( + config, + request_dict, + operation_model, + is_compressed, + encoding, +): + RequestCompressor.compress(config, request_dict, operation_model) + _assert_compression(is_compressed, request_dict['body'], encoding) + assert ( + 'headers' in request_dict + and 'Content-Encoding' in request_dict['headers'] + and encoding in request_dict['headers']['Content-Encoding'] + ) == is_compressed + + +@pytest.mark.parametrize('body', [1, object(), None, True, 1.0]) +def test_compress_bad_types(body): + request_dict = {'body': body} + RequestCompressor.compress( + COMPRESSION_CONFIG_0_BYTES, request_dict, OP_WITH_COMPRESSION + ) + assert request_dict['body'] == body + + +@pytest.mark.parametrize( + 'body', + [io.StringIO('foo'), io.BytesIO(b'foo')], +) +def test_body_streams_position_reset(body): + RequestCompressor.compress( + COMPRESSION_CONFIG_0_BYTES, {'body': body}, OP_WITH_COMPRESSION + ) + assert body.tell() == 0 diff --git a/tests/unit/test_handlers.py b/tests/unit/test_handlers.py index 47d093ad87..08a315a78a 100644 --- a/tests/unit/test_handlers.py +++ b/tests/unit/test_handlers.py @@ -1089,25 +1089,6 @@ def test_set_operation_specific_signer_s3v4(auth_type, expected_response): assert response == expected_response -@pytest.mark.parametrize( - 'protocol, signature_version, params, expected_body', - [ - ('query', 'v4', {'body': {'foo': 'bar'}}, b'foo=bar'), - ('query', 'v4', {}, None), - ('json', 'v4', {'body': {'foo': 'bar'}}, {'foo': 'bar'}), - ('query', 'v2', {'body': {'foo': 'bar'}}, {'foo': 'bar'}), - ('query', 'v4', {'body': 'foo=bar'}, 'foo=bar'), - ], -) -def test_urlencode_body(protocol, signature_version, params, expected_body): - operation_def = {'name': 'CreateFoo'} - service_def = {'metadata': {'protocol': protocol}, 'shapes': {}} - model = OperationModel(operation_def, ServiceModel(service_def)) - context = {'client_config': Config(signature_version=signature_version)} - handlers.urlencode_body(model, params, context) - assert params.get('body') == expected_body - - class TestConvertStringBodyToFileLikeObject(BaseSessionTest): def assert_converts_to_file_like_object_with_bytes(self, body, body_bytes): params = {'Body': body} diff --git a/tests/unit/test_utils.py b/tests/unit/test_utils.py index 6b1a33ba78..427abd7d0d 100644 --- a/tests/unit/test_utils.py +++ b/tests/unit/test_utils.py @@ -94,6 +94,7 @@ set_value_from_jmespath, switch_host_s3_accelerate, switch_to_virtual_host_style, + urlencode_query_body, validate_jmespath_for_set, ) from tests import FreezeTime, RawResponse, create_session, mock, unittest @@ -3476,3 +3477,25 @@ def cached_fn(self, a, b): assert cls2.cached_fn.cache_info().currsize == 2 assert cls2.cached_fn.cache_info().hits == 1 # the call was a cache hit assert cls2.cached_fn.cache_info().misses == 2 + + +@pytest.mark.parametrize( + 'request_dict, protocol, signature_version, expected_body', + [ + ({'body': {'foo': 'bar'}}, 'query', 'v4', b'foo=bar'), + ({'body': b''}, 'query', 'v4', b''), + ({'body': {'foo': 'bar'}}, 'json', 'v4', {'foo': 'bar'}), + ({'body': {'foo': 'bar'}}, 'query', 'v2', {'foo': 'bar'}), + ({'body': 'foo=bar'}, 'query', 'v4', 'foo=bar'), + ], +) +def test_urlencode_query_body( + request_dict, protocol, signature_version, expected_body +): + operation_def = {'name': 'CreateFoo'} + service_def = {'metadata': {'protocol': protocol}, 'shapes': {}} + model = OperationModel(operation_def, ServiceModel(service_def)) + urlencode_query_body( + request_dict, model, Config(signature_version=signature_version) + ) + assert request_dict['body'] == expected_body