Skip to content

Commit

Permalink
Another round of refactoring. Move compression to a separate file and…
Browse files Browse the repository at this point in the history
… query body serialization to a utility
  • Loading branch information
davidlm committed Jul 7, 2023
1 parent 32e3716 commit 5b48c41
Show file tree
Hide file tree
Showing 12 changed files with 417 additions and 374 deletions.
88 changes: 0 additions & 88 deletions botocore/awsrequest.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,6 @@
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
import functools
import gzip
import io
import logging
from collections.abc import Mapping

Expand Down Expand Up @@ -536,92 +534,6 @@ def reset_stream(self):
raise UnseekableStreamError(stream_object=self.body)


class RequestCompressor:
    """A class that can compress the body of an ``AWSRequest``."""

    def compress(self, config, request_dict, operation_model):
        """Compresses the request body using the specified encodings.
        Check if the request should be compressed based on the contents of its
        body and config settings. Set or append the `Content-Encoding` header
        with the matched encoding if not present.

        :param config: Client ``Config``; consulted (via
            ``_should_compress_request``) for ``disable_request_compression``,
            ``signature_version`` and ``request_min_compression_size_bytes``.
        :param request_dict: Serialized request dict; ``body`` (and possibly
            ``headers``) are mutated in place.
        :param operation_model: Operation model whose ``request_compression``
            trait lists the candidate encodings.
        """
        body = request_dict['body']
        if self._should_compress_request(config, body, operation_model):
            encodings = operation_model.request_compression['encodings']
            # May be a fresh dict when the request has no headers yet; it is
            # attached to request_dict below once an encoding is applied.
            headers = request_dict.get('headers', {})
            for encoding in encodings:
                # Encoders are resolved by naming convention, e.g. 'gzip'
                # maps to ``_gzip_compress_body``.
                encoder = getattr(self, f'_{encoding}_compress_body', None)
                if encoder is not None:
                    ce_header = headers.get('Content-Encoding')
                    if ce_header is None:
                        headers['Content-Encoding'] = encoding
                    elif encoding not in ce_header.split(','):
                        headers['Content-Encoding'] = f'{ce_header},{encoding}'
                    # NOTE(review): each supported encoding compresses the
                    # *original* body, so if more than one encoding matched,
                    # only the last one's output would survive while the
                    # header listed all of them — confirm at most one
                    # encoding is ever expected to be supported.
                    request_dict['body'] = encoder(body)
                    if 'headers' not in request_dict:
                        request_dict['headers'] = headers
                else:
                    logger.debug(
                        'Unsupported compression encoding: %s' % encoding
                    )

    def _gzip_compress_body(self, body):
        """Gzip-compress ``body``.

        Accepts a str (encoded as UTF-8 first), bytes-like object, or a
        file-like object with a ``read`` method. Returns ``None`` for any
        other body type (no branch matches).
        """
        if isinstance(body, str):
            return gzip.compress(body.encode('utf-8'))
        elif isinstance(body, (bytes, bytearray)):
            return gzip.compress(body)
        elif hasattr(body, 'read'):
            if hasattr(body, 'seek') and hasattr(body, 'tell'):
                # Restore the stream position afterwards so callers can
                # still read the uncompressed stream from where it was.
                current_position = body.tell()
                compressed_obj = self._gzip_compress_fileobj(body)
                body.seek(current_position)
                return compressed_obj
            return self._gzip_compress_fileobj(body)

    def _gzip_compress_fileobj(self, body):
        """Stream-compress a file-like object in 8 KiB chunks into an
        in-memory ``BytesIO``, returned rewound to position 0."""
        compressed_obj = io.BytesIO()
        with gzip.GzipFile(fileobj=compressed_obj, mode='wb') as gz:
            while True:
                chunk = body.read(8192)
                if not chunk:
                    break
                if isinstance(chunk, str):
                    # Text-mode streams yield str chunks; encode for gzip.
                    chunk = chunk.encode('utf-8')
                gz.write(chunk)
        compressed_obj.seek(0)
        return compressed_obj

    def _should_compress_request(self, config, body, operation_model):
        """Return True when the operation, config, and body size make the
        request eligible for compression."""
        if (
            config.disable_request_compression is not True
            and config.signature_version != 'v2'
            and operation_model.request_compression is not None
        ):
            # Request is compressed no matter the content length if it has a streaming input.
            # However, if the stream has the `requiresLength` trait it is NOT compressed.
            if operation_model.has_streaming_input:
                return (
                    'requiresLength'
                    not in operation_model.get_streaming_input().metadata
                )
            return (
                config.request_min_compression_size_bytes
                <= self._get_body_size(body)
            )
        return False

    def _get_body_size(self, body):
        """Return the body length in bytes, or -1 when it cannot be
        determined (-1 always fails the minimum-size check above)."""
        size = botocore.utils.determine_content_length(body)
        if size is None:
            logger.debug(
                'Unable to get length of the request body: %s. Not compressing.'
                % body
            )
            return -1
        return size


class AWSResponse:
"""A data class representing an HTTP response.
Expand Down
10 changes: 7 additions & 3 deletions botocore/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,8 @@
from botocore import waiter, xform_name
from botocore.args import ClientArgsCreator
from botocore.auth import AUTH_TYPE_MAPS
from botocore.awsrequest import RequestCompressor, prepare_request_dict
from botocore.awsrequest import prepare_request_dict
from botocore.compress import RequestCompressor
from botocore.config import Config
from botocore.discovery import (
EndpointDiscoveryHandler,
Expand Down Expand Up @@ -47,6 +48,7 @@
S3RegionRedirectorv2,
ensure_boolean,
get_service_module_name,
urlencode_query_body,
)

# Keep these imported. There's pre-existing code that uses:
Expand All @@ -72,7 +74,6 @@
's3v4',
)
)
REQUEST_COMPRESSOR = RequestCompressor()


logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -956,7 +957,10 @@ def _make_api_call(self, operation_name, api_params):
if event_response is not None:
http, parsed_response = event_response
else:
REQUEST_COMPRESSOR.compress(
urlencode_query_body(
request_dict, operation_model, self.meta.config
)
RequestCompressor.compress(
self.meta.config, request_dict, operation_model
)
apply_request_checksum(request_dict)
Expand Down
114 changes: 114 additions & 0 deletions botocore/compress.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
# Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.

import gzip
import io
import logging

from botocore.utils import determine_content_length

logger = logging.getLogger(__name__)


class RequestCompressor:
    """A class that can compress the body of an ``AWSRequest``."""

    @classmethod
    def compress(cls, config, request_dict, operation_model):
        """Compresses the request body using the specified encodings.
        Check if the request should be compressed based on the contents of its
        body and config settings. Set or append the `Content-Encoding` header
        with the matched encoding if not present.

        :param config: Client ``Config``; provides
            ``disable_request_compression``, ``signature_version`` and
            ``request_min_compression_size_bytes``.
        :param request_dict: Serialized request dict; ``body`` (and possibly
            ``headers``) are mutated in place.
        :param operation_model: Operation model whose ``request_compression``
            trait lists the candidate encodings.
        """
        body = request_dict['body']
        if cls._should_compress_request(config, body, operation_model):
            encodings = operation_model.request_compression['encodings']
            # May be a fresh dict when the request has no headers yet; it is
            # attached to request_dict below once an encoding is applied.
            headers = request_dict.get('headers', {})
            for encoding in encodings:
                # Encoders are resolved by naming convention, e.g. 'gzip'
                # maps to ``_gzip_compress_body``.
                encoder = getattr(cls, f'_{encoding}_compress_body', None)
                if encoder is not None:
                    ce_header = headers.get('Content-Encoding')
                    if ce_header is None:
                        headers['Content-Encoding'] = encoding
                    elif encoding not in ce_header.split(','):
                        headers['Content-Encoding'] = f'{ce_header},{encoding}'
                    logger.debug(
                        'Compressing request with %s encoding', encoding
                    )
                    request_dict['body'] = encoder(body)
                    if 'headers' not in request_dict:
                        request_dict['headers'] = headers
                    # Apply only the first supported encoding. Without this
                    # break, a second supported encoding would overwrite the
                    # body with a compression of the *original* body while
                    # 'Content-Encoding' kept accumulating values, leaving
                    # the header and the body inconsistent.
                    break
                else:
                    logger.debug(
                        'Unsupported compression encoding: %s' % encoding
                    )

    @classmethod
    def _gzip_compress_body(cls, body):
        """Gzip-compress ``body``.

        Accepts a str (encoded as UTF-8 first), bytes-like object, or a
        file-like object with a ``read`` method. Returns ``None`` for any
        other body type (no branch matches).
        """
        if isinstance(body, str):
            return gzip.compress(body.encode('utf-8'))
        elif isinstance(body, (bytes, bytearray)):
            return gzip.compress(body)
        elif hasattr(body, 'read'):
            if hasattr(body, 'seek') and hasattr(body, 'tell'):
                # Restore the stream position afterwards so callers can
                # still read the uncompressed stream from where it was.
                current_position = body.tell()
                compressed_obj = cls._gzip_compress_fileobj(body)
                body.seek(current_position)
                return compressed_obj
            return cls._gzip_compress_fileobj(body)

    @staticmethod
    def _gzip_compress_fileobj(body):
        """Stream-compress a file-like object in 8 KiB chunks into an
        in-memory ``BytesIO``, returned rewound to position 0."""
        compressed_obj = io.BytesIO()
        with gzip.GzipFile(fileobj=compressed_obj, mode='wb') as gz:
            while True:
                chunk = body.read(8192)
                if not chunk:
                    break
                if isinstance(chunk, str):
                    # Text-mode streams yield str chunks; encode for gzip.
                    chunk = chunk.encode('utf-8')
                gz.write(chunk)
        compressed_obj.seek(0)
        return compressed_obj

    @classmethod
    def _should_compress_request(cls, config, body, operation_model):
        """Return True when the operation, config, and body size make the
        request eligible for compression."""
        if (
            config.disable_request_compression is not True
            and config.signature_version != 'v2'
            and operation_model.request_compression is not None
        ):
            # Request is compressed no matter the content length if it has a streaming input.
            # However, if the stream has the `requiresLength` trait it is NOT compressed.
            if operation_model.has_streaming_input:
                return (
                    'requiresLength'
                    not in operation_model.get_streaming_input().metadata
                )
            return (
                config.request_min_compression_size_bytes
                <= cls._get_body_size(body)
            )
        return False

    @staticmethod
    def _get_body_size(body):
        """Return the body length in bytes, or -1 when it cannot be
        determined (-1 always fails the minimum-size check above)."""
        size = determine_content_length(body)
        if size is None:
            logger.debug(
                'Unable to get length of the request body: %s. Not compressing.'
                % body
            )
            return -1
        return size
2 changes: 2 additions & 0 deletions botocore/configprovider.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,8 @@
# whatever the defaults are in _retry.json.
'max_attempts': ('max_attempts', 'AWS_MAX_ATTEMPTS', None, int),
'user_agent_appid': ('sdk_ua_app_id', 'AWS_SDK_UA_APP_ID', None, None),
# This must be a parsable integer between 0 and 1048576, but validation
# is performed during client initialization instead of here.
'request_min_compression_size_bytes': (
'request_min_compression_size_bytes',
'AWS_REQUEST_MIN_COMPRESSION_SIZE_BYTES',
Expand Down
19 changes: 0 additions & 19 deletions botocore/handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,6 @@
quote,
unquote,
unquote_str,
urlencode,
urlsplit,
urlunsplit,
)
Expand Down Expand Up @@ -1145,23 +1144,6 @@ def remove_content_type_header_for_presigning(request, **kwargs):
del request.headers['Content-Type']


def urlencode_body(model, params, context, **kwargs):
    """URL-encode the request body if it is a dictionary.

    This is used for services using the query protocol. The body must be
    serialized as a URL-encoded string before it can be compressed.
    """
    raw_body = params.get('body')
    # Guard clauses mirror the original condition order: sigv2 requests,
    # non-query protocols, and non-dict bodies are all left untouched.
    if context['client_config'].signature_version == 'v2':
        return
    if model.service_model.protocol != 'query':
        return
    if not isinstance(raw_body, dict):
        return
    serialized = urlencode(raw_body, doseq=True, encoding='utf-8')
    params['body'] = serialized.encode('utf-8')


# This is a list of (event_name, handler).
# When a Session is created, everything in this list will be
# automatically registered with that Session.
Expand Down Expand Up @@ -1420,6 +1402,5 @@ def urlencode_body(model, params, context, **kwargs):
AutoPopulatedParam('PreSignedUrl').document_auto_populated_param,
),
('before-call', inject_api_version_header_if_needed),
('before-call', urlencode_body),
]
_add_parameter_aliases(BUILTIN_HANDLERS)
14 changes: 14 additions & 0 deletions botocore/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@
get_tzinfo_options,
json,
quote,
urlencode,
urlparse,
urlsplit,
urlunsplit,
Expand Down Expand Up @@ -3429,3 +3430,16 @@ def _serialize_if_needed(self, value, iso=False):
'stepfunctions': 'sfn',
'storagegateway': 'storage-gateway',
}


def urlencode_query_body(request_dict, operation_model, config):
    """URL encode a request's body if it is using the query protocol.

    Query-protocol requests carry their serialized parameters as a dict;
    that dict must be flattened to a URL-encoded UTF-8 byte string before
    the request can be compressed or sent. The body is left untouched for
    non-query protocols, non-dict bodies, and sigv2-signed requests.

    :param request_dict: The serialized request dict; ``body`` is replaced
        in place when encoding applies.
    :param operation_model: The ``OperationModel`` for the operation being
        invoked; only its service protocol is inspected.
    :param config: The client ``Config``; only ``signature_version`` is
        inspected.
    """
    body = request_dict['body']
    if (
        operation_model.service_model.protocol == 'query'
        and isinstance(body, dict)
        and config.signature_version != 'v2'
    ):
        request_dict['body'] = urlencode(
            body, doseq=True, encoding='utf-8'
        ).encode('utf-8')
Original file line number Diff line number Diff line change
Expand Up @@ -97,9 +97,10 @@ def _all_compression_operations():
@pytest.mark.parametrize("operation_model", _all_compression_operations())
def test_no_unknown_compression_encodings(operation_model):
for encoding in operation_model.request_compression["encodings"]:
assert (
encoding in KNOWN_COMPRESSION_ENCODINGS
), f"Found unknown compression encoding '{encoding}' in operation {operation_model.name}"
assert encoding in KNOWN_COMPRESSION_ENCODINGS, (
f"Found unknown compression encoding '{encoding}' "
f"in operation {operation_model.name}"
)


def test_compression(patched_session, monkeypatch):
Expand Down
Loading

0 comments on commit 5b48c41

Please sign in to comment.