From 0680e36769bdccf4717c31eb3a01d5ed5d3e774a Mon Sep 17 00:00:00 2001 From: davidlm Date: Tue, 6 Jun 2023 19:00:06 -0400 Subject: [PATCH] wip --- botocore/awsrequest.py | 32 +++++------ botocore/handlers.py | 9 +++ tests/unit/test_awsrequest.py | 100 ++++++++++++---------------------- tests/unit/test_endpoint.py | 20 ------- 4 files changed, 57 insertions(+), 104 deletions(-) diff --git a/botocore/awsrequest.py b/botocore/awsrequest.py index 918fe60126..45b408443c 100644 --- a/botocore/awsrequest.py +++ b/botocore/awsrequest.py @@ -539,7 +539,7 @@ def reset_stream(self): class AWSRequestCompressor: """A class that can compress the body of an ``AWSRequest``.""" - def compress(self, config, request, operation_model): + def compress(self, config, request_dict, operation_model): """Compresses the request body using the specified encodings if conditions are met. @@ -547,13 +547,10 @@ def compress(self, config, request, operation_model): body and config settings. Set or append the `Content-Encoding` header with the matched encoding if not present. """ - # ``AWSRequest.body`` is a computed property that computes a URL-encoded - # value if it's a dictionary. So we should call it once here to minimize - # recalculation. - body = request.body + body = request_dict['body'] if self._should_compress_request(config, body, operation_model): encodings = operation_model.request_compression['encodings'] - headers = request.headers + headers = request_dict.get('headers', {}) for encoding in encodings: encoder = getattr(self, f'_{encoding}_compress_body', None) if encoder is not None: @@ -562,21 +559,21 @@ def compress(self, config, request, operation_model): elif encoding not in headers['Content-Encoding'].split( ',' ): - headers.replace_header( - 'Content-Encoding', - f'{headers["Content-Encoding"]},{encoding}', - ) - # AWSRequest.data is the raw input for the request body. - # This is the parameter that should be updated to properly - # construct an AWSPreparedRequest object. - request.data = encoder(body) + headers[ + 'Content-Encoding' + ] = f'{headers["Content-Encoding"]},{encoding}' + request_dict['body'] = encoder(body) + if 'headers' not in request_dict: + request_dict['headers'] = headers else: logger.debug( 'Unsupported compression encoding: %s' % encoding ) def _gzip_compress_body(self, body): - if isinstance(body, (bytes, bytearray)): + if isinstance(body, str): + return gzip.compress(body.encode('utf-8')) + elif isinstance(body, (bytes, bytearray)): return gzip.compress(body) elif hasattr(body, 'read'): if hasattr(body, 'seek') and hasattr(body, 'tell'): @@ -584,8 +581,7 @@ def _gzip_compress_body(self, body): compressed_obj = self._gzip_compress_fileobj(body) body.seek(current_position) return compressed_obj - else: - return self._gzip_compress_fileobj(body) + return self._gzip_compress_fileobj(body) def _gzip_compress_fileobj(self, body): compressed_obj = io.BytesIO() @@ -598,8 +594,6 @@ def _gzip_compress_fileobj(self, body): chunk = chunk.encode('utf-8') gz.write(chunk) compressed_obj.seek(0) - if hasattr(body, 'seek') and hasattr(body, 'tell'): - body.seek(0) return compressed_obj def _should_compress_request(self, config, body, operation_model): diff --git a/botocore/handlers.py b/botocore/handlers.py index e5f371b9fb..dcb814640b 100644 --- a/botocore/handlers.py +++ b/botocore/handlers.py @@ -28,6 +28,7 @@ import botocore import botocore.auth from botocore import utils +from botocore.awsrequest import AWSRequestCompressor from botocore.compat import ( ETree, OrderedDict, @@ -101,6 +102,7 @@ VERSION_ID_SUFFIX = re.compile(r'\?versionId=[^\s]+$') SERVICE_NAME_ALIASES = {'runtime.sagemaker': 'sagemaker-runtime'} +AWS_REQUEST_COMPRESSOR = AWSRequestCompressor() def handle_service_name_alias(service_name, **kwargs): @@ -1159,6 +1161,12 @@ def urlencode_body(model, params, **kwargs): ) +def compress_request(model, params, context, **kwargs): + body = params.get('body') + if model.request_compression and body is not None: + AWS_REQUEST_COMPRESSOR.compress(body, context['client_config'], model) + + # This is a list of (event_name, handler). # When a Session is created, everything in this list will be # automatically registered with that Session. @@ -1232,6 +1240,7 @@ def urlencode_body(model, params, **kwargs): ('before-call.docdb.CopyDBClusterSnapshot', inject_presigned_url_rds), ('before-call.docdb.CreateDBCluster', inject_presigned_url_rds), ('before-call', urlencode_body), + ('before-call', compress_request), ('before-call.s3.PutObject', conditionally_calculate_md5), ('before-call.s3.UploadPart', conditionally_calculate_md5), ('before-call.s3.DeleteObjects', escape_xml_payload), diff --git a/tests/unit/test_awsrequest.py b/tests/unit/test_awsrequest.py index 3acabcda55..a39b60915d 100644 --- a/tests/unit/test_awsrequest.py +++ b/tests/unit/test_awsrequest.py @@ -171,22 +171,18 @@ def _streaming_op_with_compression_requires_length(): GZIP_ENCODINGS = ['gzip'] -def aws_request(): - return AWSRequest( - method='POST', - url='http://example.com', - data=REQUEST_BODY, - headers={'foo': 'bar'}, - ) +def request_dict(): + return { + 'body': REQUEST_BODY, + 'headers': {'foo': 'bar'}, + } -def aws_request_with_content_encoding_header(): - return AWSRequest( - method='POST', - url='http://example.com', - data=REQUEST_BODY, - headers={'foo': 'bar', 'Content-Encoding': 'identity'}, - ) +def request_dict_with_content_encoding_header(): + return { + 'body': REQUEST_BODY, + 'headers': {'foo': 'bar', 'Content-Encoding': 'identity'}, + } @pytest.fixture(scope="module") @@ -899,26 +895,21 @@ def test_iteration(self): @pytest.mark.parametrize( - 'config, aws_request, operation_model, is_compressed, encoding', + 'config, request_dict, operation_model, is_compressed, encoding', [ ( Config( disable_request_compression=True, request_min_compression_size_bytes=1000, ), - AWSRequest( - method='POST', - url='https://s3.amazonaws.com/mybucket?prefix=foo', - headers={}, - data=b'foo', - ), + {'body': b'foo'}, OP_WITH_COMPRESSION, False, None, ), ( COMPRESSION_CONFIG_128_BYTES, - aws_request(), + request_dict(), OP_WITH_COMPRESSION, True, 'gzip', @@ -928,104 +919,84 @@ def test_iteration(self): disable_request_compression=False, request_min_compression_size_bytes=256, ), - aws_request(), + request_dict(), OP_WITH_COMPRESSION, False, None, ), ( Config(request_min_compression_size_bytes=128), - aws_request(), + request_dict(), OP_WITH_COMPRESSION, True, 'gzip', ), ( DEFAULT_COMPRESSION_CONFIG, - aws_request(), + request_dict(), STREAMING_OP_WITH_COMPRESSION, True, 'gzip', ), ( DEFAULT_COMPRESSION_CONFIG, - aws_request(), + request_dict(), _streaming_op_with_compression_requires_length(), False, None, ), ( DEFAULT_COMPRESSION_CONFIG, - aws_request(), + request_dict(), _op_without_compression(), False, None, ), ( COMPRESSION_CONFIG_128_BYTES, - aws_request(), + request_dict(), OP_UNKNOWN_COMPRESSION, False, None, ), ( COMPRESSION_CONFIG_128_BYTES, - AWSRequest( - method='POST', - url='https://s3.amazonaws.com/mybucket?prefix=foo', - headers={}, - data=REQUEST_BODY.decode(), - ), + {'body': REQUEST_BODY.decode()}, OP_WITH_COMPRESSION, True, 'gzip', ), ( COMPRESSION_CONFIG_128_BYTES, - AWSRequest( - method='POST', - url='https://s3.amazonaws.com/mybucket?prefix=foo', - headers={}, - data=bytearray(REQUEST_BODY), - ), + {'body': bytearray(REQUEST_BODY)}, OP_WITH_COMPRESSION, True, 'gzip', ), ( COMPRESSION_CONFIG_128_BYTES, - AWSRequest( - method='POST', - url='https://s3.amazonaws.com/mybucket?prefix=foo', - headers={}, - data=io.BytesIO(REQUEST_BODY), - ), + {'body': io.BytesIO(REQUEST_BODY)}, OP_WITH_COMPRESSION, True, 'gzip', ), ( COMPRESSION_CONFIG_128_BYTES, - AWSRequest( - method='POST', - url='https://s3.amazonaws.com/mybucket?prefix=foo', - headers={}, - data=io.StringIO(REQUEST_BODY.decode()), - ), + {'body': io.StringIO(REQUEST_BODY.decode())}, OP_WITH_COMPRESSION, True, 'gzip', ), ( COMPRESSION_CONFIG_128_BYTES, - aws_request_with_content_encoding_header(), + request_dict_with_content_encoding_header(), OP_UNKNOWN_COMPRESSION, False, "foo", ), ( COMPRESSION_CONFIG_128_BYTES, - aws_request_with_content_encoding_header(), + request_dict_with_content_encoding_header(), OP_WITH_COMPRESSION, True, "gzip", @@ -1035,27 +1006,27 @@ def test_iteration(self): def test_compress( aws_request_compressor, config, - aws_request, + request_dict, operation_model, is_compressed, encoding, ): - aws_request_compressor.compress(config, aws_request, operation_model) - _assert_compression(is_compressed, aws_request.data) + aws_request_compressor.compress(config, request_dict, operation_model) + _assert_compression(is_compressed, request_dict['body']) assert ( - 'Content-Encoding' in aws_request.headers - and encoding in aws_request.headers['Content-Encoding'] + 'headers' in request_dict + and 'Content-Encoding' in request_dict['headers'] + and encoding in request_dict['headers']['Content-Encoding'] ) == is_compressed @pytest.mark.parametrize('body', [1, object(), None, True, 1.0]) def test_compress_bad_types(aws_request_compressor, body): - aws_request = AWSRequest(data=body) + request_dict = {'body': body} aws_request_compressor.compress( - COMPRESSION_CONFIG_0_BYTES, aws_request, OP_WITH_COMPRESSION + COMPRESSION_CONFIG_0_BYTES, request_dict, OP_WITH_COMPRESSION ) - # no compression will happen on request bodies that do not have a length - assert aws_request.data == body + assert request_dict['body'] == body @pytest.mark.parametrize( @@ -1063,9 +1034,8 @@ def test_compress_bad_types(aws_request_compressor, body): [io.StringIO("foo"), io.BytesIO(b"foo")], ) def test_body_streams_position_reset(aws_request_compressor, body): - aws_request = AWSRequest(data=body) aws_request_compressor.compress( - COMPRESSION_CONFIG_0_BYTES, aws_request, OP_WITH_COMPRESSION + COMPRESSION_CONFIG_0_BYTES, {'body': body}, OP_WITH_COMPRESSION ) assert body.tell() == 0 diff --git a/tests/unit/test_endpoint.py b/tests/unit/test_endpoint.py index 0a3a662e2c..33930b7746 100644 --- a/tests/unit/test_endpoint.py +++ b/tests/unit/test_endpoint.py @@ -206,26 +206,6 @@ def test_close(self): self.endpoint.close() self.endpoint.http_session.close.assert_called_once_with() - def test_make_request_with_compressed_body(self): - r = request_dict( - **{ - 'context': { - 'client_config': Config( - request_min_compression_size_bytes=0 - ) - } - } - ) - self.op.request_compression = {'encodings': ['gzip']} - self.op.has_streaming_input = False - with mock.patch( - 'botocore.endpoint.Endpoint.prepare_request' - ) as prepare: - self.endpoint.make_request(self.op, r) - request = prepare.call_args[0][0] - self.assertEqual(request.headers['Content-Encoding'], 'gzip') - self.assertEqual(request.data[:2], b'\x1f\x8b') - class TestRetryInterface(TestEndpointBase): def setUp(self):