Skip to content

Commit

Permalink
Add TypedDict callbacks (#98)
Browse files Browse the repository at this point in the history
  • Loading branch information
Kludex authored Feb 10, 2024
1 parent 9c936b7 commit 8fdf6f8
Show file tree
Hide file tree
Showing 3 changed files with 60 additions and 33 deletions.
2 changes: 1 addition & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ lib64
pip-log.txt

# Unit test / coverage reports
.coverage.*
.coverage*
.tox
nosetests.xml

Expand Down
75 changes: 52 additions & 23 deletions multipart/multipart.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,38 @@
from enum import IntEnum
from io import BytesIO
from numbers import Number
from typing import Dict, Tuple, Union
from typing import TYPE_CHECKING

from .decoders import Base64Decoder, QuotedPrintableDecoder
from .exceptions import FileError, FormParserError, MultipartParseError, QuerystringParseError

if TYPE_CHECKING: # pragma: no cover
from typing import Callable, TypedDict

class QuerystringCallbacks(TypedDict, total=False):
on_field_start: Callable[[], None]
on_field_name: Callable[[bytes, int, int], None]
on_field_data: Callable[[bytes, int, int], None]
on_field_end: Callable[[], None]
on_end: Callable[[], None]

class OctetStreamCallbacks(TypedDict, total=False):
on_start: Callable[[], None]
on_data: Callable[[bytes, int, int], None]
on_end: Callable[[], None]

class MultipartCallbacks(TypedDict, total=False):
on_part_begin: Callable[[], None]
on_part_data: Callable[[bytes, int, int], None]
on_part_end: Callable[[], None]
on_headers_begin: Callable[[], None]
on_header_field: Callable[[bytes, int, int], None]
on_header_value: Callable[[bytes, int, int], None]
on_header_end: Callable[[], None]
on_headers_finished: Callable[[], None]
on_end: Callable[[], None]


# Unique missing object.
_missing = object()

Expand Down Expand Up @@ -86,7 +113,7 @@ def join_bytes(b):
return bytes(list(b))


def parse_options_header(value: Union[str, bytes]) -> Tuple[bytes, Dict[bytes, bytes]]:
def parse_options_header(value: str | bytes) -> tuple[bytes, dict[bytes, bytes]]:
"""
Parses a Content-Type header into a value in the following format:
(content_type, {parameters})
Expand Down Expand Up @@ -148,15 +175,15 @@ class Field:
:param name: the name of the form field
"""

def __init__(self, name):
def __init__(self, name: str):
self._name = name
self._value = []
self._value: list[bytes] = []

# We cache the joined version of _value for speed.
self._cache = _missing

@classmethod
def from_value(klass, name, value):
def from_value(cls, name: str, value: bytes | None) -> Field:
"""Create an instance of a :class:`Field`, and set the corresponding
value - either None or an actual value. This method will also
finalize the Field itself.
Expand All @@ -166,22 +193,22 @@ def from_value(klass, name, value):
None
"""

f = klass(name)
f = cls(name)
if value is None:
f.set_none()
else:
f.write(value)
f.finalize()
return f

def write(self, data):
def write(self, data: bytes) -> int:
"""Write some data into the form field.
:param data: a bytestring
"""
return self.on_data(data)

def on_data(self, data):
def on_data(self, data: bytes) -> int:
"""This method is a callback that will be called whenever data is
written to the Field.
Expand All @@ -191,24 +218,24 @@ def on_data(self, data):
self._cache = _missing
return len(data)

def on_end(self):
def on_end(self) -> None:
"""This method is called whenever the Field is finalized."""
if self._cache is _missing:
self._cache = b"".join(self._value)

def finalize(self):
def finalize(self) -> None:
"""Finalize the form field."""
self.on_end()

def close(self):
def close(self) -> None:
"""Close the Field object. This will free any underlying cache."""
# Free our value array.
if self._cache is _missing:
self._cache = b"".join(self._value)

del self._value

def set_none(self):
def set_none(self) -> None:
"""Some fields in a querystring can possibly have a value of None - for
example, the string "foo&bar=&baz=asdf" will have a field with the
name "foo" and value None, one with name "bar" and value "", and one
Expand All @@ -218,7 +245,7 @@ def set_none(self):
self._cache = None

@property
def field_name(self):
def field_name(self) -> str:
"""This property returns the name of the field."""
return self._name

Expand All @@ -230,13 +257,13 @@ def value(self):

return self._cache

def __eq__(self, other):
def __eq__(self, other: object) -> bool:
if isinstance(other, Field):
return self.field_name == other.field_name and self.value == other.value
else:
return NotImplemented

def __repr__(self):
def __repr__(self) -> str:
if len(self.value) > 97:
# We get the repr, and then insert three dots before the final
# quote.
Expand Down Expand Up @@ -553,7 +580,7 @@ class BaseParser:
def __init__(self):
self.logger = logging.getLogger(__name__)

def callback(self, name, data=None, start=None, end=None):
def callback(self, name: str, data=None, start=None, end=None):
"""This function calls a provided callback with some data. If the
callback is not set, will do nothing.
Expand Down Expand Up @@ -584,7 +611,7 @@ def callback(self, name, data=None, start=None, end=None):
self.logger.debug("Calling %s with no data", name)
func()

def set_callback(self, name, new_func):
def set_callback(self, name: str, new_func):
"""Update the function for a callback. Removes from the callbacks dict
if new_func is None.
Expand Down Expand Up @@ -637,7 +664,7 @@ class OctetStreamParser(BaseParser):
i.e. unbounded.
"""

def __init__(self, callbacks={}, max_size=float("inf")):
def __init__(self, callbacks: OctetStreamCallbacks = {}, max_size=float("inf")):
super().__init__()
self.callbacks = callbacks
self._started = False
Expand All @@ -647,7 +674,7 @@ def __init__(self, callbacks={}, max_size=float("inf")):
self.max_size = max_size
self._current_size = 0

def write(self, data):
def write(self, data: bytes):
"""Write some data to the parser, which will perform size verification,
and then pass the data to the underlying callback.
Expand Down Expand Up @@ -732,7 +759,9 @@ class QuerystringParser(BaseParser):
i.e. unbounded.
"""

def __init__(self, callbacks={}, strict_parsing=False, max_size=float("inf")):
state: QuerystringState

def __init__(self, callbacks: QuerystringCallbacks = {}, strict_parsing=False, max_size=float("inf")):
super().__init__()
self.state = QuerystringState.BEFORE_FIELD
self._found_sep = False
Expand All @@ -748,7 +777,7 @@ def __init__(self, callbacks={}, strict_parsing=False, max_size=float("inf")):
# Should parsing be strict?
self.strict_parsing = strict_parsing

def write(self, data):
def write(self, data: bytes):
"""Write some data to the parser, which will perform size verification,
parse into either a field name or value, and then pass the
corresponding data to the underlying callback. If an error is
Expand Down Expand Up @@ -780,7 +809,7 @@ def write(self, data):

return l

def _internal_write(self, data, length):
def _internal_write(self, data: bytes, length: int):
state = self.state
strict_parsing = self.strict_parsing
found_sep = self._found_sep
Expand Down Expand Up @@ -989,7 +1018,7 @@ class MultipartParser(BaseParser):
i.e. unbounded.
"""

def __init__(self, boundary, callbacks={}, max_size=float("inf")):
def __init__(self, boundary, callbacks: MultipartCallbacks = {}, max_size=float("inf")):
# Initialize parser state.
super().__init__()
self.state = MultipartState.START
Expand Down
16 changes: 7 additions & 9 deletions tests/test_multipart.py
Original file line number Diff line number Diff line change
Expand Up @@ -333,9 +333,9 @@ def on_field_end():
del name_buffer[:]
del data_buffer[:]

callbacks = {"on_field_name": on_field_name, "on_field_data": on_field_data, "on_field_end": on_field_end}

self.p = QuerystringParser(callbacks)
self.p = QuerystringParser(
callbacks={"on_field_name": on_field_name, "on_field_data": on_field_data, "on_field_end": on_field_end}
)

def test_simple_querystring(self):
self.p.write(b"foo=bar")
Expand Down Expand Up @@ -464,18 +464,16 @@ def setUp(self):
self.started = 0
self.finished = 0

def on_start():
def on_start() -> None:
self.started += 1

def on_data(data, start, end):
def on_data(data: bytes, start: int, end: int) -> None:
self.d.append(data[start:end])

def on_end():
def on_end() -> None:
self.finished += 1

callbacks = {"on_start": on_start, "on_data": on_data, "on_end": on_end}

self.p = OctetStreamParser(callbacks)
self.p = OctetStreamParser(callbacks={"on_start": on_start, "on_data": on_data, "on_end": on_end})

def assert_data(self, data, finalize=True):
self.assertEqual(b"".join(self.d), data)
Expand Down

0 comments on commit 8fdf6f8

Please sign in to comment.