Skip to content

Commit

Permalink
Improve type hints (#110)
Browse files Browse the repository at this point in the history
  • Loading branch information
Kludex authored Feb 11, 2024
1 parent 1474796 commit 2baf8b1
Show file tree
Hide file tree
Showing 3 changed files with 49 additions and 32 deletions.
2 changes: 1 addition & 1 deletion multipart/decoders.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ def write(self, data):
# Return the length of the data to indicate no error.
return len(data)

def close(self):
def close(self) -> None:
"""Close this decoder. If the underlying object has a `close()`
method, this function will call it.
"""
Expand Down
49 changes: 32 additions & 17 deletions multipart/multipart.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

import io
import logging
import os
import shutil
Expand Down Expand Up @@ -534,14 +535,14 @@ def _get_disk_file(self):
self._actual_file_name = fname
return tmp_file

def write(self, data: bytes):
def write(self, data: bytes) -> int:
"""Write some data to the File.
:param data: a bytestring
"""
return self.on_data(data)

def on_data(self, data: bytes):
def on_data(self, data: bytes) -> int:
"""This method is a callback that will be called whenever data is
written to the File.
Expand Down Expand Up @@ -652,7 +653,7 @@ def callback(self, name: str, data: bytes | None = None, start: int | None = Non
self.logger.debug("Calling %s with no data", name)
func()

def set_callback(self, name: str, new_func):
def set_callback(self, name: str, new_func: Callable[..., Any] | None) -> None:
"""Update the function for a callback. Removes from the callbacks dict
if new_func is None.
Expand Down Expand Up @@ -1096,7 +1097,7 @@ def __init__(self, boundary: bytes | str, callbacks: MultipartCallbacks = {}, ma
# Note: the +8 is since we can have, at maximum, "\r\n--" + boundary +
# "--\r\n" at the final boundary, and the length of '\r\n--' and
# '--\r\n' is 8 bytes.
self.lookbehind = [NULL for x in range(len(boundary) + 8)]
self.lookbehind = [NULL for _ in range(len(boundary) + 8)]

def write(self, data: bytes) -> int:
"""Write some data to the parser, which will perform size verification,
Expand Down Expand Up @@ -1642,22 +1643,23 @@ def __init__(

# Depending on the Content-Type, we instantiate the correct parser.
if content_type == "application/octet-stream":
f: FileProtocol | None = None
file: FileProtocol = None # type: ignore

def on_start() -> None:
nonlocal f
f = FileClass(file_name, None, config=self.config)
nonlocal file
file = FileClass(file_name, None, config=self.config)

def on_data(data: bytes, start: int, end: int) -> None:
nonlocal f
f.write(data[start:end])
nonlocal file
file.write(data[start:end])

def _on_end() -> None:
nonlocal file
# Finalize the file itself.
f.finalize()
file.finalize()

# Call our callback.
on_file(f)
on_file(file)

# Call the on-end callback.
if self.on_end is not None:
Expand All @@ -1672,7 +1674,7 @@ def _on_end() -> None:
elif content_type == "application/x-www-form-urlencoded" or content_type == "application/x-url-encoded":
name_buffer: list[bytes] = []

f: FieldProtocol | None = None
f: FieldProtocol = None # type: ignore

def on_field_start() -> None:
pass
Expand Down Expand Up @@ -1747,13 +1749,13 @@ def on_part_end() -> None:
else:
on_field(f)

def on_header_field(data: bytes, start: int, end: int):
def on_header_field(data: bytes, start: int, end: int) -> None:
header_name.append(data[start:end])

def on_header_value(data: bytes, start: int, end: int):
def on_header_value(data: bytes, start: int, end: int) -> None:
header_value.append(data[start:end])

def on_header_end():
def on_header_end() -> None:
headers[b"".join(header_name)] = b"".join(header_value)
del header_name[:]
del header_value[:]
Expand Down Expand Up @@ -1855,7 +1857,13 @@ def __repr__(self) -> str:
return "{}(content_type={!r}, parser={!r})".format(self.__class__.__name__, self.content_type, self.parser)


def create_form_parser(headers, on_field, on_file, trust_x_headers=False, config={}):
def create_form_parser(
headers: dict[str, bytes],
on_field: OnFieldCallback,
on_file: OnFileCallback,
trust_x_headers: bool = False,
config={},
):
"""This function is a helper function to aid in creating a FormParser
instances. Given a dictionary-like headers object, it will determine
the correct information needed, instantiate a FormParser with the
Expand Down Expand Up @@ -1898,7 +1906,14 @@ def create_form_parser(headers, on_field, on_file, trust_x_headers=False, config
return form_parser


def parse_form(headers, input_stream, on_field, on_file, chunk_size=1048576, **kwargs):
def parse_form(
headers: dict[str, bytes],
input_stream: io.FileIO,
on_field: OnFieldCallback,
on_file: OnFileCallback,
chunk_size: int = 1048576,
**kwargs,
):
"""This function is useful if you just want to parse a request body,
without too much work. Pass it a dictionary-like object of the request's
headers, and a file-like object for the input stream, along with two
Expand Down
30 changes: 16 additions & 14 deletions tests/test_multipart.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from __future__ import annotations

import os
import random
import sys
Expand Down Expand Up @@ -288,19 +290,19 @@ def setUp(self):
self.b.callbacks = {}

def test_callbacks(self):
# The stupid list-ness is to get around lack of nonlocal on py2
l = [0]
called = 0

def on_foo():
l[0] += 1
nonlocal called
called += 1

self.b.set_callback("foo", on_foo)
self.b.callback("foo")
self.assertEqual(l[0], 1)
self.assertEqual(called, 1)

self.b.set_callback("foo", None)
self.b.callback("foo")
self.assertEqual(l[0], 1)
self.assertEqual(called, 1)


class TestQuerystringParser(unittest.TestCase):
Expand All @@ -316,15 +318,15 @@ def setUp(self):
self.reset()

def reset(self):
self.f = []
self.f: list[tuple[bytes, bytes]] = []

name_buffer = []
data_buffer = []
name_buffer: list[bytes] = []
data_buffer: list[bytes] = []

def on_field_name(data, start, end):
def on_field_name(data: bytes, start: int, end: int) -> None:
name_buffer.append(data[start:end])

def on_field_data(data, start, end):
def on_field_data(data: bytes, start: int, end: int) -> None:
data_buffer.append(data[start:end])

def on_field_end():
Expand Down Expand Up @@ -705,13 +707,13 @@ def split_all(val):
class TestFormParser(unittest.TestCase):
def make(self, boundary, config={}):
self.ended = False
self.files = []
self.fields = []
self.files: list[File] = []
self.fields: list[Field] = []

def on_field(f):
def on_field(f: Field) -> None:
self.fields.append(f)

def on_file(f):
def on_file(f: File) -> None:
self.files.append(f)

def on_end():
Expand Down

0 comments on commit 2baf8b1

Please sign in to comment.