Skip to content

Commit

Permalink
Use content_type in Fields not header
Browse files Browse the repository at this point in the history
  • Loading branch information
jhnstrk committed Dec 8, 2024
1 parent 12b3f26 commit f1e28d6
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 49 deletions.
45 changes: 23 additions & 22 deletions python_multipart/multipart.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,13 +68,13 @@ def finalize(self) -> None: ...
def close(self) -> None: ...

class FieldProtocol(_FormProtocol, Protocol):
def __init__(self, name: bytes, headers: dict[str, bytes]) -> None: ...
def __init__(self, name: bytes, content_type: str | None = None) -> None: ...

def set_none(self) -> None: ...

class FileProtocol(_FormProtocol, Protocol):
def __init__(
self, file_name: bytes | None, field_name: bytes | None, config: FileConfig, headers: dict[str, bytes]
self, file_name: bytes | None, field_name: bytes | None, config: FileConfig, content_type: str | None = None
) -> None: ...

OnFieldCallback = Callable[[FieldProtocol], None]
Expand Down Expand Up @@ -223,12 +223,13 @@ class Field:
Args:
name: The name of the form field.
content_type: The value of the Content-Type header for this field.
"""

def __init__(self, name: bytes, headers: dict[str, bytes] = {}) -> None:
def __init__(self, name: bytes, content_type: str | None = None) -> None:
self._name = name
self._value: list[bytes] = []
self._headers: dict[str, bytes] = headers
self._content_type = content_type

# We cache the joined version of _value for speed.
self._cache = _missing
Expand Down Expand Up @@ -321,9 +322,9 @@ def value(self) -> bytes | None:
return self._cache

@property
def headers(self) -> dict[str, bytes]:
"""This property returns the headers of the field."""
return self._headers
def content_type(self) -> str | None:
"""This property returns the content_type value of the field."""
return self._content_type

def __eq__(self, other: object) -> bool:
if isinstance(other, Field):
Expand Down Expand Up @@ -362,14 +363,15 @@ class File:
file_name: The name of the file that this [`File`][python_multipart.File] represents.
field_name: The name of the form field that this file was uploaded with. This can be None, if, for example,
the file was uploaded with Content-Type application/octet-stream.
content_type: The value of the Content-Type header.
config: The configuration for this File. See above for valid configuration keys and their corresponding values.
""" # noqa: E501

def __init__(
self,
file_name: bytes | None,
field_name: bytes | None = None,
headers: dict[str, bytes] = {},
content_type: str | None = None,
config: FileConfig = {},
) -> None:
# Save configuration, set other variables default.
Expand All @@ -382,7 +384,7 @@ def __init__(
# Save the provided field/file name and content type.
self._field_name = field_name
self._file_name = file_name
self._headers = headers
self._content_type = content_type

# Our actual file name is None by default, since, depending on our
# config, we may not actually use the provided name.
Expand Down Expand Up @@ -436,14 +438,9 @@ def in_memory(self) -> bool:
return self._in_memory

@property
def headers(self) -> dict[str, bytes]:
"""The headers for this part."""
return self._headers

@property
def content_type(self) -> bytes | None:
"""The Content-Type value for this part."""
return self._headers.get("content-type")
def content_type(self) -> str | None:
"""The Content-Type value for this part, if it was set."""
return self._content_type

def flush_to_disk(self) -> None:
"""If the file is already on-disk, do nothing. Otherwise, copy from
Expand Down Expand Up @@ -1565,7 +1562,7 @@ def __init__(

def on_start() -> None:
nonlocal file
file = FileClass(file_name, None, headers={}, config=cast("FileConfig", self.config))
file = FileClass(file_name, None, content_type=None, config=cast("FileConfig", self.config))

def on_data(data: bytes, start: int, end: int) -> None:
nonlocal file
Expand Down Expand Up @@ -1604,7 +1601,7 @@ def on_field_name(data: bytes, start: int, end: int) -> None:
def on_field_data(data: bytes, start: int, end: int) -> None:
nonlocal f
if f is None:
f = FieldClass(b"".join(name_buffer), headers={})
f = FieldClass(b"".join(name_buffer), content_type=None)
del name_buffer[:]
f.write(data[start:end])

Expand All @@ -1614,7 +1611,7 @@ def on_field_end() -> None:
if f is None:
# If we get here, it's because there was no field data.
# We create a field, set it to None, and then continue.
f = FieldClass(b"".join(name_buffer), headers={})
f = FieldClass(b"".join(name_buffer), content_type=None)
del name_buffer[:]
f.set_none()

Expand Down Expand Up @@ -1700,10 +1697,14 @@ def on_headers_finished() -> None:
# TODO: check for errors

# Create the proper class.
content_type_b = headers.get("content-type")
content_type = content_type_b.decode("latin-1") if content_type_b is not None else None
if file_name is None:
f_multi = FieldClass(field_name, headers=headers)
f_multi = FieldClass(field_name, content_type=content_type)
else:
f_multi = FileClass(file_name, field_name, config=cast("FileConfig", self.config), headers=headers)
f_multi = FileClass(
file_name, field_name, config=cast("FileConfig", self.config), content_type=content_type
)
is_file = True

# Parse the given Content-Transfer-Encoding to determine what
Expand Down
41 changes: 14 additions & 27 deletions tests/test_multipart.py
Original file line number Diff line number Diff line change
Expand Up @@ -758,7 +758,7 @@ def assert_file_data(self, f: File, data: bytes) -> None:
file_data = o.read()
self.assertEqual(file_data, data)

def assert_file(self, field_name: bytes, file_name: bytes, content_type: str, data: bytes) -> None:
def assert_file(self, field_name: bytes, file_name: bytes, content_type: str | None, data: bytes) -> None:
# Find this file.
found = None
for f in self.files:
Expand All @@ -770,7 +770,7 @@ def assert_file(self, field_name: bytes, file_name: bytes, content_type: str, da
self.assertIsNotNone(found)
assert found is not None

self.assertEqual(found.content_type, content_type.encode())
self.assertEqual(found.content_type, content_type)

try:
# Assert about this file.
Expand Down Expand Up @@ -911,7 +911,8 @@ def test_feed_single_bytes(self, param: TestParams) -> None:
self.assert_field(name, e["data"])

elif type == "file":
self.assert_file(name, e["file_name"].encode("latin-1"), "text/plain", e["data"])
content_type = "text/plain"
self.assert_file(name, e["file_name"].encode("latin-1"), content_type, e["data"])

else:
assert False
Expand Down Expand Up @@ -949,24 +950,16 @@ def test_feed_blocks(self) -> None:
# Assert that our field is here.
self.assert_field(b"field", b"0123456789ABCDEFGHIJ0123456789ABCDEFGHIJ")

def test_file_headers(self) -> None:
def test_file_content_type_header(self) -> None:
"""
This test checks headers for a file part are read.
This test checks the content-type for a file part is passed on.
"""
# Load test data.
test_file = "header_with_number.http"
with open(os.path.join(http_tests_dir, test_file), "rb") as f:
test_data = f.read()

expected_headers = {
"content-disposition": b'form-data; filename="secret.txt"; name="files"',
"content-type": b"text/plain; charset=utf-8",
"x-funky-header-1": b"bar",
"abcdefghijklmnopqrstuvwxyz01234": b"foo",
"abcdefghijklmnopqrstuvwxyz56789": b"bar",
"other!#$%&'*+-.^_`|~": b"baz",
"content-length": b"6",
}
expected_content_type = "text/plain; charset=utf-8"

# Create form parser.
self.make(boundary="b8825ae386be4fdc9644d87e392caad3")
Expand All @@ -975,22 +968,19 @@ def test_file_headers(self) -> None:

# Assert that our field is here.
self.assertEqual(1, len(self.files))
actual_headers = self.files[0].headers
self.assertEqual(len(actual_headers), len(expected_headers))
actual_content_type = self.files[0].content_type
self.assertEqual(actual_content_type, expected_content_type)

for k, v in expected_headers.items():
self.assertEqual(v, actual_headers[k])

def test_field_headers(self) -> None:
def test_field_content_type_header(self) -> None:
"""
This test checks headers for a field part are read.
This test checks content-tpye for a field part are read and passed.
"""
# Load test data.
test_file = "single_field.http"
with open(os.path.join(http_tests_dir, test_file), "rb") as f:
test_data = f.read()

expected_headers = {"content-disposition": b'form-data; name="field"'}
expected_content_type = None

# Create form parser.
self.make(boundary="----WebKitFormBoundaryTkr3kCBQlBe1nrhc")
Expand All @@ -999,11 +989,8 @@ def test_field_headers(self) -> None:

# Assert that our field is here.
self.assertEqual(1, len(self.fields))
actual_headers = self.fields[0].headers
self.assertEqual(len(actual_headers), len(expected_headers))

for k, v in expected_headers.items():
self.assertEqual(v, actual_headers[k])
actual_content_type = self.fields[0].content_type
self.assertEqual(actual_content_type, expected_content_type)

def test_request_body_fuzz(self) -> None:
"""
Expand Down

0 comments on commit f1e28d6

Please sign in to comment.