Skip to content

Commit

Permalink
wip
Browse files Browse the repository at this point in the history
  • Loading branch information
davidlm committed Jun 6, 2023
1 parent 545e39d commit 0680e36
Show file tree
Hide file tree
Showing 4 changed files with 57 additions and 104 deletions.
32 changes: 13 additions & 19 deletions botocore/awsrequest.py
Original file line number Diff line number Diff line change
Expand Up @@ -539,21 +539,18 @@ def reset_stream(self):
class AWSRequestCompressor:
"""A class that can compress the body of an ``AWSRequest``."""

def compress(self, config, request, operation_model):
def compress(self, config, request_dict, operation_model):
"""Compresses the request body using the specified encodings if conditions
are met.
Check if the request should be compressed based on the contents of its
body and config settings. Set or append the `Content-Encoding` header
with the matched encoding if not present.
"""
# ``AWSRequest.body`` is a computed property that computes a URL-encoded
# value if it's a dictionary. So we should call it once here to minimize
# recalculation.
body = request.body
body = request_dict['body']
if self._should_compress_request(config, body, operation_model):
encodings = operation_model.request_compression['encodings']
headers = request.headers
headers = request_dict.get('headers', {})
for encoding in encodings:
encoder = getattr(self, f'_{encoding}_compress_body', None)
if encoder is not None:
Expand All @@ -562,30 +559,29 @@ def compress(self, config, request, operation_model):
elif encoding not in headers['Content-Encoding'].split(
','
):
headers.replace_header(
'Content-Encoding',
f'{headers["Content-Encoding"]},{encoding}',
)
# AWSRequest.data is the raw input for the request body.
# This is the parameter that should be updated to properly
# construct an AWSPreparedRequest object.
request.data = encoder(body)
headers[
'Content-Encoding'
] = f'{headers["Content-Encoding"]},{encoding}'
request_dict['body'] = encoder(body)
if 'headers' not in request_dict:
request_dict['headers'] = headers
else:
logger.debug(
'Unsupported compression encoding: %s' % encoding
)

def _gzip_compress_body(self, body):
if isinstance(body, (bytes, bytearray)):
if isinstance(body, str):
return gzip.compress(body.encode('utf-8'))
elif isinstance(body, (bytes, bytearray)):
return gzip.compress(body)
elif hasattr(body, 'read'):
if hasattr(body, 'seek') and hasattr(body, 'tell'):
current_position = body.tell()
compressed_obj = self._gzip_compress_fileobj(body)
body.seek(current_position)
return compressed_obj
else:
return self._gzip_compress_fileobj(body)
return self._gzip_compress_fileobj(body)

def _gzip_compress_fileobj(self, body):
compressed_obj = io.BytesIO()
Expand All @@ -598,8 +594,6 @@ def _gzip_compress_fileobj(self, body):
chunk = chunk.encode('utf-8')
gz.write(chunk)
compressed_obj.seek(0)
if hasattr(body, 'seek') and hasattr(body, 'tell'):
body.seek(0)
return compressed_obj

def _should_compress_request(self, config, body, operation_model):
Expand Down
9 changes: 9 additions & 0 deletions botocore/handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
import botocore
import botocore.auth
from botocore import utils
from botocore.awsrequest import AWSRequestCompressor
from botocore.compat import (
ETree,
OrderedDict,
Expand Down Expand Up @@ -101,6 +102,7 @@
VERSION_ID_SUFFIX = re.compile(r'\?versionId=[^\s]+$')

SERVICE_NAME_ALIASES = {'runtime.sagemaker': 'sagemaker-runtime'}
AWS_REQUEST_COMPRESSOR = AWSRequestCompressor()


def handle_service_name_alias(service_name, **kwargs):
Expand Down Expand Up @@ -1159,6 +1161,12 @@ def urlencode_body(model, params, **kwargs):
)


def compress_request(model, params, context, **kwargs):
body = params.get('body')
if model.request_compression and body is not None:
AWS_REQUEST_COMPRESSOR.compress(body, context['client_config'], model)


# This is a list of (event_name, handler).
# When a Session is created, everything in this list will be
# automatically registered with that Session.
Expand Down Expand Up @@ -1232,6 +1240,7 @@ def urlencode_body(model, params, **kwargs):
('before-call.docdb.CopyDBClusterSnapshot', inject_presigned_url_rds),
('before-call.docdb.CreateDBCluster', inject_presigned_url_rds),
('before-call', urlencode_body),
('before-call', compress_request),
('before-call.s3.PutObject', conditionally_calculate_md5),
('before-call.s3.UploadPart', conditionally_calculate_md5),
('before-call.s3.DeleteObjects', escape_xml_payload),
Expand Down
100 changes: 35 additions & 65 deletions tests/unit/test_awsrequest.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,22 +171,18 @@ def _streaming_op_with_compression_requires_length():
GZIP_ENCODINGS = ['gzip']


def aws_request():
return AWSRequest(
method='POST',
url='http://example.com',
data=REQUEST_BODY,
headers={'foo': 'bar'},
)
def request_dict():
return {
'body': REQUEST_BODY,
'headers': {'foo': 'bar'},
}


def aws_request_with_content_encoding_header():
return AWSRequest(
method='POST',
url='http://example.com',
data=REQUEST_BODY,
headers={'foo': 'bar', 'Content-Encoding': 'identity'},
)
def request_dict_with_content_encoding_header():
return {
'body': REQUEST_BODY,
'headers': {'foo': 'bar', 'Content-Encoding': 'identity'},
}


@pytest.fixture(scope="module")
Expand Down Expand Up @@ -899,26 +895,21 @@ def test_iteration(self):


@pytest.mark.parametrize(
'config, aws_request, operation_model, is_compressed, encoding',
'config, request_dict, operation_model, is_compressed, encoding',
[
(
Config(
disable_request_compression=True,
request_min_compression_size_bytes=1000,
),
AWSRequest(
method='POST',
url='https://s3.amazonaws.com/mybucket?prefix=foo',
headers={},
data=b'foo',
),
{'body': b'foo'},
OP_WITH_COMPRESSION,
False,
None,
),
(
COMPRESSION_CONFIG_128_BYTES,
aws_request(),
request_dict(),
OP_WITH_COMPRESSION,
True,
'gzip',
Expand All @@ -928,104 +919,84 @@ def test_iteration(self):
disable_request_compression=False,
request_min_compression_size_bytes=256,
),
aws_request(),
request_dict(),
OP_WITH_COMPRESSION,
False,
None,
),
(
Config(request_min_compression_size_bytes=128),
aws_request(),
request_dict(),
OP_WITH_COMPRESSION,
True,
'gzip',
),
(
DEFAULT_COMPRESSION_CONFIG,
aws_request(),
request_dict(),
STREAMING_OP_WITH_COMPRESSION,
True,
'gzip',
),
(
DEFAULT_COMPRESSION_CONFIG,
aws_request(),
request_dict(),
_streaming_op_with_compression_requires_length(),
False,
None,
),
(
DEFAULT_COMPRESSION_CONFIG,
aws_request(),
request_dict(),
_op_without_compression(),
False,
None,
),
(
COMPRESSION_CONFIG_128_BYTES,
aws_request(),
request_dict(),
OP_UNKNOWN_COMPRESSION,
False,
None,
),
(
COMPRESSION_CONFIG_128_BYTES,
AWSRequest(
method='POST',
url='https://s3.amazonaws.com/mybucket?prefix=foo',
headers={},
data=REQUEST_BODY.decode(),
),
{'body': REQUEST_BODY.decode()},
OP_WITH_COMPRESSION,
True,
'gzip',
),
(
COMPRESSION_CONFIG_128_BYTES,
AWSRequest(
method='POST',
url='https://s3.amazonaws.com/mybucket?prefix=foo',
headers={},
data=bytearray(REQUEST_BODY),
),
{'body': bytearray(REQUEST_BODY)},
OP_WITH_COMPRESSION,
True,
'gzip',
),
(
COMPRESSION_CONFIG_128_BYTES,
AWSRequest(
method='POST',
url='https://s3.amazonaws.com/mybucket?prefix=foo',
headers={},
data=io.BytesIO(REQUEST_BODY),
),
{'body': io.BytesIO(REQUEST_BODY)},
OP_WITH_COMPRESSION,
True,
'gzip',
),
(
COMPRESSION_CONFIG_128_BYTES,
AWSRequest(
method='POST',
url='https://s3.amazonaws.com/mybucket?prefix=foo',
headers={},
data=io.StringIO(REQUEST_BODY.decode()),
),
{'body': io.StringIO(REQUEST_BODY.decode())},
OP_WITH_COMPRESSION,
True,
'gzip',
),
(
COMPRESSION_CONFIG_128_BYTES,
aws_request_with_content_encoding_header(),
request_dict_with_content_encoding_header(),
OP_UNKNOWN_COMPRESSION,
False,
"foo",
),
(
COMPRESSION_CONFIG_128_BYTES,
aws_request_with_content_encoding_header(),
request_dict_with_content_encoding_header(),
OP_WITH_COMPRESSION,
True,
"gzip",
Expand All @@ -1035,37 +1006,36 @@ def test_iteration(self):
def test_compress(
aws_request_compressor,
config,
aws_request,
request_dict,
operation_model,
is_compressed,
encoding,
):
aws_request_compressor.compress(config, aws_request, operation_model)
_assert_compression(is_compressed, aws_request.data)
aws_request_compressor.compress(config, request_dict, operation_model)
_assert_compression(is_compressed, request_dict['body'])
assert (
'Content-Encoding' in aws_request.headers
and encoding in aws_request.headers['Content-Encoding']
'headers' in request_dict
and 'Content-Encoding' in request_dict['headers']
and encoding in request_dict['headers']['Content-Encoding']
) == is_compressed


@pytest.mark.parametrize('body', [1, object(), None, True, 1.0])
def test_compress_bad_types(aws_request_compressor, body):
aws_request = AWSRequest(data=body)
request_dict = {'body': body}
aws_request_compressor.compress(
COMPRESSION_CONFIG_0_BYTES, aws_request, OP_WITH_COMPRESSION
COMPRESSION_CONFIG_0_BYTES, request_dict, OP_WITH_COMPRESSION
)
# no compression will happen on request bodies that do not have a length
assert aws_request.data == body
assert request_dict['body'] == body


@pytest.mark.parametrize(
'body',
[io.StringIO("foo"), io.BytesIO(b"foo")],
)
def test_body_streams_position_reset(aws_request_compressor, body):
aws_request = AWSRequest(data=body)
aws_request_compressor.compress(
COMPRESSION_CONFIG_0_BYTES, aws_request, OP_WITH_COMPRESSION
COMPRESSION_CONFIG_0_BYTES, {'body': body}, OP_WITH_COMPRESSION
)
assert body.tell() == 0

Expand Down
20 changes: 0 additions & 20 deletions tests/unit/test_endpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,26 +206,6 @@ def test_close(self):
self.endpoint.close()
self.endpoint.http_session.close.assert_called_once_with()

def test_make_request_with_compressed_body(self):
r = request_dict(
**{
'context': {
'client_config': Config(
request_min_compression_size_bytes=0
)
}
}
)
self.op.request_compression = {'encodings': ['gzip']}
self.op.has_streaming_input = False
with mock.patch(
'botocore.endpoint.Endpoint.prepare_request'
) as prepare:
self.endpoint.make_request(self.op, r)
request = prepare.call_args[0][0]
self.assertEqual(request.headers['Content-Encoding'], 'gzip')
self.assertEqual(request.data[:2], b'\x1f\x8b')


class TestRetryInterface(TestEndpointBase):
def setUp(self):
Expand Down

0 comments on commit 0680e36

Please sign in to comment.