Skip to content

Commit

Permalink
Fix more input serialization issues.
Browse files Browse the repository at this point in the history
  • Loading branch information
jonathan343 committed Aug 16, 2024
1 parent 6d55023 commit 7db1cf3
Show file tree
Hide file tree
Showing 2 changed files with 90 additions and 53 deletions.
140 changes: 87 additions & 53 deletions botocore/serialize.py
Original file line number Diff line number Diff line change
Expand Up @@ -462,6 +462,11 @@ def serialize_to_request(self, parameters, operation_model):
'method', self.DEFAULT_METHOD
)
shape = operation_model.input_shape

host_prefix = self._expand_host_prefix(parameters, operation_model)
if host_prefix is not None:
serialized['host_prefix'] = host_prefix

if shape is None:
serialized['url_path'] = operation_model.http['requestUri']
return serialized
Expand Down Expand Up @@ -509,10 +514,6 @@ def serialize_to_request(self, parameters, operation_model):
)
self._serialize_content_type(serialized, shape, shape_members)

host_prefix = self._expand_host_prefix(parameters, operation_model)
if host_prefix is not None:
serialized['host_prefix'] = host_prefix

return serialized

def _render_uri_template(self, uri_template, params):
Expand Down Expand Up @@ -577,19 +578,6 @@ def _serialize_content_type(self, serialized, shape, shape_members):
"""
pass

def _handle_streaming_payload_content_type(
self, serialized, shape, shape_members
):
"""Set Content-Type header for streaming payload member types."""
payload = shape.serialization.get('payload')
if self._has_streaming_payload(payload, shape_members):
if shape_members[payload].type_name == 'string':
serialized['headers']['Content-Type'] = 'text/plain'
elif shape_members[payload].type_name == 'blob':
serialized['headers']['Content-Type'] = (
'application/octet-stream'
)

def _requires_empty_body(self, shape):
"""
Some protocols require a specific body to represent an empty
Expand Down Expand Up @@ -635,27 +623,26 @@ def _partition_parameters(
partitioned['uri_path_kwargs'][key_name] = param_value
elif location == 'querystring':
if isinstance(param_value, dict):
partitioned['query_string_kwargs'].update(param_value)
elif isinstance(param_value, bool):
bool_str = str(param_value).lower()
partitioned['query_string_kwargs'][key_name] = bool_str
elif member.type_name == 'timestamp':
timestamp_format = member.serialization.get(
'timestampFormat', self.QUERY_STRING_TIMESTAMP_FORMAT
)
timestamp = self._convert_timestamp_to_str(
param_value, timestamp_format
)
partitioned['query_string_kwargs'][key_name] = timestamp
# Add only new query string key/value pairs.
# Named query parameters should take precedence.
for key, value in param_value.items():
partitioned['query_string_kwargs'].setdefault(key, value)
elif member.type_name == 'list':
new_param = [
self._get_query_string_value(value, member.member)
for value in param_value
]
partitioned['query_string_kwargs'][key_name] = new_param
else:
partitioned['query_string_kwargs'][key_name] = param_value
new_param = self._get_query_string_value(param_value, member)
partitioned['query_string_kwargs'][key_name] = new_param
elif location == 'header':
shape = shape_members[param_name]
if not param_value and shape.type_name in ['list', 'string']:
# Empty lists and strings should not be set on the headers
return
value = self._convert_header_value(shape, param_value)
partitioned['headers'][key_name] = str(value)
partitioned['headers'][key_name] = value
elif location == 'headers':
# 'headers' is a bit of an oddball. The ``key_name``
# is actually really a prefix for the header names:
Expand All @@ -670,6 +657,19 @@ def _partition_parameters(
else:
partitioned['body_kwargs'][param_name] = param_value

def _get_query_string_value(self, param_value, member):
if isinstance(param_value, bool):
return str(param_value).lower()
elif member.type_name == 'timestamp':
timestamp_format = member.serialization.get(
'timestampFormat', self.QUERY_STRING_TIMESTAMP_FORMAT
)
return self._convert_timestamp_to_str(
param_value, timestamp_format
)
else:
return param_value

def _do_serialize_header_map(self, header_prefix, headers, user_input):
for key, val in user_input.items():
full_key = header_prefix + key
Expand All @@ -679,24 +679,41 @@ def _serialize_body_params(self, params, shape):
raise NotImplementedError('_serialize_body_params')

def _convert_header_value(self, shape, value):
if shape.type_name == 'timestamp':
if isinstance(value, bool):
return str(value).lower()
elif shape.type_name == 'timestamp':
datetime_obj = parse_to_aware_datetime(value)
timestamp = calendar.timegm(datetime_obj.utctimetuple())
timestamp_format = shape.serialization.get(
'timestampFormat', self.HEADER_TIMESTAMP_FORMAT
)
return self._convert_timestamp_to_str(timestamp, timestamp_format)
return str(
self._convert_timestamp_to_str(timestamp, timestamp_format)
)
elif shape.type_name == 'list':
converted_value = [
self._convert_header_value(shape.member, v)
for v in value
if v is not None
]
return ",".join(converted_value)
if shape.member.type_name == "string":
converted_value = [
self._escape_header_list_string(v)
for v in value
if v is not None
]
else:
converted_value = [
self._convert_header_value(shape.member, v)
for v in value
if v is not None
]
return ", ".join(converted_value)
elif is_json_value_header(shape):
# Serialize with no spaces after separators to save space in
# the header.
return self._get_base64(json.dumps(value, separators=(',', ':')))
else:
return str(value)

def _escape_header_list_string(self, value):
if '"' in value or ',' in value:
return '"' + value.replace('"', '\\"') + '"'
else:
return value

Expand All @@ -717,18 +734,29 @@ def _requires_empty_body(self, shape):

def _serialize_content_type(self, serialized, shape, shape_members):
"""Set Content-Type to application/json for all structured bodies."""
self._handle_streaming_payload_content_type(
serialized, shape, shape_members
)

has_body = serialized['body'] != b''
has_content_type = has_header('Content-Type', serialized['headers'])
if has_body and not has_content_type:
serialized['headers']['Content-Type'] = 'application/json'
if has_content_type:
return
payload = shape.serialization.get('payload')
if self._has_streaming_payload(payload, shape_members):
if shape_members[payload].type_name == 'string':
serialized['headers']['Content-Type'] = 'text/plain'
elif shape_members[payload].type_name == 'blob':
serialized['headers']['Content-Type'] = (
'application/octet-stream'
)
else:
if serialized['body'] != b'':
serialized['headers']['Content-Type'] = 'application/json'

def _serialize_body_params(self, params, shape):
serialized_body = self.MAP_TYPE()
self._serialize(serialized_body, params, shape)
# Handle document types as a payload
if list(serialized_body.keys()) == [None] and shape.metadata.get(
'document'
):
serialized_body = serialized_body[None]
return json.dumps(serialized_body).encode(self.DEFAULT_ENCODING)


Expand Down Expand Up @@ -842,14 +870,20 @@ def _default_serialize(self, xmlnode, params, shape, name):

def _serialize_content_type(self, serialized, shape, shape_members):
"""Set Content-Type to application/xml for all structured bodies."""
self._handle_streaming_payload_content_type(
serialized, shape, shape_members
)

has_body = serialized['body'] != b''
has_content_type = has_header('Content-Type', serialized['headers'])
if has_body and not has_content_type:
serialized['headers']['Content-Type'] = 'application/xml'
if has_content_type:
return
payload = shape.serialization.get('payload')
if self._has_streaming_payload(payload, shape_members):
if shape_members[payload].type_name == 'string':
serialized['headers']['Content-Type'] = 'text/plain'
elif shape_members[payload].type_name == 'blob':
serialized['headers']['Content-Type'] = (
'application/octet-stream'
)
else:
if serialized['body'] != b'':
serialized['headers']['Content-Type'] = 'application/xml'

def _add_xml_namespace(self, shape, structure_node):
if 'xmlNamespace' in shape.serialization:
Expand Down
3 changes: 3 additions & 0 deletions tests/unit/test_protocols.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,10 @@
'rest-xml': RestXMLParser,
}
PROTOCOL_TEST_BLACKLIST = [
# These cases test functionality outside the serializers and parsers.
"Test cases for QueryIdempotencyTokenAutoFill operation",
"Test cases for PutWithContentEncoding operation",
"Test cases for HttpChecksumRequired operation",
]


Expand Down

0 comments on commit 7db1cf3

Please sign in to comment.