diff --git a/python_multipart/multipart.py b/python_multipart/multipart.py index c08dec5..99a8204 100644 --- a/python_multipart/multipart.py +++ b/python_multipart/multipart.py @@ -68,13 +68,13 @@ def finalize(self) -> None: ... def close(self) -> None: ... class FieldProtocol(_FormProtocol, Protocol): - def __init__(self, name: bytes, headers: dict[str, bytes]) -> None: ... + def __init__(self, name: bytes, content_type: str | None = None) -> None: ... def set_none(self) -> None: ... class FileProtocol(_FormProtocol, Protocol): def __init__( - self, file_name: bytes | None, field_name: bytes | None, config: FileConfig, headers: dict[str, bytes] + self, file_name: bytes | None, field_name: bytes | None, config: FileConfig, content_type: str | None = None ) -> None: ... OnFieldCallback = Callable[[FieldProtocol], None] @@ -223,12 +223,13 @@ class Field: Args: name: The name of the form field. + content_type: The value of the Content-Type header for this field. """ - def __init__(self, name: bytes, headers: dict[str, bytes] = {}) -> None: + def __init__(self, name: bytes, content_type: str | None = None) -> None: self._name = name self._value: list[bytes] = [] - self._headers: dict[str, bytes] = headers + self._content_type = content_type # We cache the joined version of _value for speed. self._cache = _missing @@ -321,9 +322,9 @@ def value(self) -> bytes | None: return self._cache @property - def headers(self) -> dict[str, bytes]: - """This property returns the headers of the field.""" - return self._headers + def content_type(self) -> str | None: + """This property returns the content_type value of the field.""" + return self._content_type def __eq__(self, other: object) -> bool: if isinstance(other, Field): @@ -362,6 +363,7 @@ class File: file_name: The name of the file that this [`File`][python_multipart.File] represents. field_name: The name of the form field that this file was uploaded with. This can be None, if, for example, the file was uploaded with Content-Type application/octet-stream. + content_type: The value of the Content-Type header. config: The configuration for this File. See above for valid configuration keys and their corresponding values. """ # noqa: E501 @@ -369,7 +371,7 @@ def __init__( self, file_name: bytes | None, field_name: bytes | None = None, - headers: dict[str, bytes] = {}, + content_type: str | None = None, config: FileConfig = {}, ) -> None: # Save configuration, set other variables default. @@ -382,7 +384,7 @@ def __init__( # Save the provided field/file name and content type. self._field_name = field_name self._file_name = file_name - self._headers = headers + self._content_type = content_type # Our actual file name is None by default, since, depending on our # config, we may not actually use the provided name. @@ -436,14 +438,9 @@ def in_memory(self) -> bool: return self._in_memory @property - def headers(self) -> dict[str, bytes]: - """The headers for this part.""" - return self._headers - - @property - def content_type(self) -> bytes | None: - """The Content-Type value for this part.""" - return self._headers.get("content-type") + def content_type(self) -> str | None: + """The Content-Type value for this part, if it was set.""" + return self._content_type def flush_to_disk(self) -> None: """If the file is already on-disk, do nothing. Otherwise, copy from @@ -1565,7 +1562,7 @@ def __init__( def on_start() -> None: nonlocal file - file = FileClass(file_name, None, headers={}, config=cast("FileConfig", self.config)) + file = FileClass(file_name, None, content_type=None, config=cast("FileConfig", self.config)) def on_data(data: bytes, start: int, end: int) -> None: nonlocal file @@ -1604,7 +1601,7 @@ def on_field_name(data: bytes, start: int, end: int) -> None: def on_field_data(data: bytes, start: int, end: int) -> None: nonlocal f if f is None: - f = FieldClass(b"".join(name_buffer), headers={}) + f = FieldClass(b"".join(name_buffer), content_type=None) del name_buffer[:] f.write(data[start:end]) @@ -1614,7 +1611,7 @@ def on_field_end() -> None: if f is None: # If we get here, it's because there was no field data. # We create a field, set it to None, and then continue. - f = FieldClass(b"".join(name_buffer), headers={}) + f = FieldClass(b"".join(name_buffer), content_type=None) del name_buffer[:] f.set_none() @@ -1700,10 +1697,14 @@ def on_headers_finished() -> None: # TODO: check for errors # Create the proper class. + content_type_b = headers.get("content-type") + content_type = content_type_b.decode("latin-1") if content_type_b is not None else None if file_name is None: - f_multi = FieldClass(field_name, headers=headers) + f_multi = FieldClass(field_name, content_type=content_type) else: - f_multi = FileClass(file_name, field_name, config=cast("FileConfig", self.config), headers=headers) + f_multi = FileClass( + file_name, field_name, config=cast("FileConfig", self.config), content_type=content_type + ) is_file = True # Parse the given Content-Transfer-Encoding to determine what diff --git a/tests/test_multipart.py b/tests/test_multipart.py index 82f7138..a29b7ae 100644 --- a/tests/test_multipart.py +++ b/tests/test_multipart.py @@ -758,7 +758,7 @@ def assert_file_data(self, f: File, data: bytes) -> None: file_data = o.read() self.assertEqual(file_data, data) - def assert_file(self, field_name: bytes, file_name: bytes, content_type: str, data: bytes) -> None: + def assert_file(self, field_name: bytes, file_name: bytes, content_type: str | None, data: bytes) -> None: # Find this file. found = None for f in self.files: @@ -770,7 +770,7 @@ def assert_file(self, field_name: bytes, file_name: bytes, content_type: str, da self.assertIsNotNone(found) assert found is not None - self.assertEqual(found.content_type, content_type.encode()) + self.assertEqual(found.content_type, content_type) try: # Assert about this file. @@ -911,7 +911,8 @@ def test_feed_single_bytes(self, param: TestParams) -> None: self.assert_field(name, e["data"]) elif type == "file": - self.assert_file(name, e["file_name"].encode("latin-1"), "text/plain", e["data"]) + content_type = "text/plain" + self.assert_file(name, e["file_name"].encode("latin-1"), content_type, e["data"]) else: assert False @@ -949,24 +950,16 @@ def test_feed_blocks(self) -> None: # Assert that our field is here. self.assert_field(b"field", b"0123456789ABCDEFGHIJ0123456789ABCDEFGHIJ") - def test_file_headers(self) -> None: + def test_file_content_type_header(self) -> None: """ - This test checks headers for a file part are read. + This test checks the content-type for a file part is passed on. """ # Load test data. test_file = "header_with_number.http" with open(os.path.join(http_tests_dir, test_file), "rb") as f: test_data = f.read() - expected_headers = { - "content-disposition": b'form-data; filename="secret.txt"; name="files"', - "content-type": b"text/plain; charset=utf-8", - "x-funky-header-1": b"bar", - "abcdefghijklmnopqrstuvwxyz01234": b"foo", - "abcdefghijklmnopqrstuvwxyz56789": b"bar", - "other!#$%&'*+-.^_`|~": b"baz", - "content-length": b"6", - } + expected_content_type = "text/plain; charset=utf-8" # Create form parser. self.make(boundary="b8825ae386be4fdc9644d87e392caad3") @@ -975,22 +968,19 @@ def test_file_headers(self) -> None: # Assert that our field is here. self.assertEqual(1, len(self.files)) - actual_headers = self.files[0].headers - self.assertEqual(len(actual_headers), len(expected_headers)) + actual_content_type = self.files[0].content_type + self.assertEqual(actual_content_type, expected_content_type) - for k, v in expected_headers.items(): - self.assertEqual(v, actual_headers[k]) - - def test_field_headers(self) -> None: + def test_field_content_type_header(self) -> None: """ - This test checks headers for a field part are read. + This test checks content-tpye for a field part are read and passed. """ # Load test data. test_file = "single_field.http" with open(os.path.join(http_tests_dir, test_file), "rb") as f: test_data = f.read() - expected_headers = {"content-disposition": b'form-data; name="field"'} + expected_content_type = None # Create form parser. self.make(boundary="----WebKitFormBoundaryTkr3kCBQlBe1nrhc") @@ -999,11 +989,8 @@ def test_field_headers(self) -> None: # Assert that our field is here. self.assertEqual(1, len(self.fields)) - actual_headers = self.fields[0].headers - self.assertEqual(len(actual_headers), len(expected_headers)) - - for k, v in expected_headers.items(): - self.assertEqual(v, actual_headers[k]) + actual_content_type = self.fields[0].content_type + self.assertEqual(actual_content_type, expected_content_type) def test_request_body_fuzz(self) -> None: """