Skip to content

Commit

Permalink
Improve performance, especially in data with many CR-LF (#137)
Browse files Browse the repository at this point in the history
* Improve parsing content with many cr-lf

Drops the look-behind buffer since the content is always the boundary.

* Improve performance by using built-in bytes.find.

The Boyer-Moore-Horspool algorithm was removed and replaced with Python's built-in `find` method. This appears to be faster, sometimes by an order of magnitude.

* Delete unused join_bytes

---------

Co-authored-by: Marcelo Trylesinski <marcelotryle@gmail.com>
  • Loading branch information
jhnstrk and Kludex authored Sep 28, 2024
1 parent dcf0ba1 commit a790e40
Show file tree
Hide file tree
Showing 2 changed files with 89 additions and 76 deletions.
130 changes: 61 additions & 69 deletions multipart/multipart.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,10 +146,6 @@ def ord_char(c: int) -> int:
return c


def join_bytes(b: bytes) -> bytes:
return bytes(list(b))


def parse_options_header(value: str | bytes) -> tuple[bytes, dict[bytes, bytes]]:
"""Parses a Content-Type header into a value in the following format: (content_type, {parameters})."""
# Uses email.message.Message to parse the header as described in PEP 594.
Expand Down Expand Up @@ -976,29 +972,11 @@ def __init__(
# Setup marks. These are used to track the state of data received.
self.marks: dict[str, int] = {}

# TODO: Actually use this rather than the dumb version we currently use
# # Precompute the skip table for the Boyer-Moore-Horspool algorithm.
# skip = [len(boundary) for x in range(256)]
# for i in range(len(boundary) - 1):
# skip[ord_char(boundary[i])] = len(boundary) - i - 1
#
# # We use a tuple since it's a constant, and marginally faster.
# self.skip = tuple(skip)

# Save our boundary.
if isinstance(boundary, str): # pragma: no cover
boundary = boundary.encode("latin-1")
self.boundary = b"\r\n--" + boundary

# Get a set of characters that belong to our boundary.
self.boundary_chars = frozenset(self.boundary)

# We also create a lookbehind list.
# Note: the +8 is since we can have, at maximum, "\r\n--" + boundary +
# "--\r\n" at the final boundary, and the length of '\r\n--' and
# '--\r\n' is 8 bytes.
self.lookbehind = [NULL for _ in range(len(boundary) + 8)]

def write(self, data: bytes) -> int:
"""Write some data to the parser, which will perform size verification,
and then parse the data into the appropriate location (e.g. header,
Expand Down Expand Up @@ -1061,21 +1039,43 @@ def delete_mark(name: str, reset: bool = False) -> None:
# end of the buffer, and reset the mark, instead of deleting it. This
# is used at the end of the function to call our callbacks with any
# remaining data in this chunk.
def data_callback(name: str, remaining: bool = False) -> None:
def data_callback(name: str, end_i: int, remaining: bool = False) -> None:
marked_index = self.marks.get(name)
if marked_index is None:
return

# If we're getting remaining data, we ignore the current i value
# and just call with the remaining data.
if remaining:
self.callback(name, data, marked_index, length)
self.marks[name] = 0

# Otherwise, we call it from the mark to the current byte we're
# processing.
if end_i <= marked_index:
# There is no additional data to send.
pass
elif marked_index >= 0:
# We are emitting data from the local buffer.
self.callback(name, data, marked_index, end_i)
else:
# Some of the data comes from a partial boundary match.
# and requires look-behind.
# We need to use self.flags (and not flags) because we care about
# the state when we entered the loop.
lookbehind_len = -marked_index
if lookbehind_len <= len(boundary):
self.callback(name, boundary, 0, lookbehind_len)
elif self.flags & FLAG_PART_BOUNDARY:
lookback = boundary + b"\r\n"
self.callback(name, lookback, 0, lookbehind_len)
elif self.flags & FLAG_LAST_BOUNDARY:
lookback = boundary + b"--\r\n"
self.callback(name, lookback, 0, lookbehind_len)
else: # pragma: no cover (error case)
self.logger.warning("Look-back buffer error")

if end_i > 0:
self.callback(name, data, 0, end_i)
# If we're getting remaining data, we have got all the data we
# can be certain is not a boundary, leaving only a partial boundary match.
if remaining:
self.marks[name] = end_i - length
else:
self.callback(name, data, marked_index, i)
self.marks.pop(name, None)

# For each byte...
Expand Down Expand Up @@ -1183,7 +1183,7 @@ def data_callback(name: str, remaining: bool = False) -> None:
raise e

# Call our callback with the header field.
data_callback("header_field")
data_callback("header_field", i)

# Move to parsing the header value.
state = MultipartState.HEADER_VALUE_START
Expand Down Expand Up @@ -1212,7 +1212,7 @@ def data_callback(name: str, remaining: bool = False) -> None:
# If we've got a CR, we're nearly done our headers. Otherwise,
# we do nothing and just move past this character.
if c == CR:
data_callback("header_value")
data_callback("header_value", i)
self.callback("header_end")
state = MultipartState.HEADER_VALUE_ALMOST_DONE

Expand Down Expand Up @@ -1256,46 +1256,46 @@ def data_callback(name: str, remaining: bool = False) -> None:
# We're processing our part data right now. During this, we
# need to efficiently search for our boundary, since any data
# on any number of lines can be a part of the current data.
# We use the Boyer-Moore-Horspool algorithm to efficiently
# search through the remainder of the buffer looking for our
# boundary.

# Save the current value of our index. We use this in case we
# find part of a boundary, but it doesn't match fully.
prev_index = index

# Set up variables.
boundary_length = len(boundary)
boundary_end = boundary_length - 1
data_length = length
boundary_chars = self.boundary_chars

# If our index is 0, we're starting a new part, so start our
# search.
if index == 0:
# Search forward until we either hit the end of our buffer,
# or reach a character that's in our boundary.
i += boundary_end
while i < data_length - 1 and data[i] not in boundary_chars:
i += boundary_length

# Reset i back the length of our boundary, which is the
# earliest possible location that could be our match (i.e.
# if we've just broken out of our loop since we saw the
# last character in our boundary)
i -= boundary_end
# The most common case is likely to be that the whole
# boundary is present in the buffer.
# Calling `find` is much faster than iterating here.
i0 = data.find(boundary, i, data_length)
if i0 >= 0:
# We matched the whole boundary string.
index = boundary_length - 1
i = i0 + boundary_length - 1
else:
# No match found for whole string.
# There may be a partial boundary at the end of the
# data, which the find will not match.
# Since the length should to be searched is limited to
# the boundary length, just perform a naive search.
i = max(i, data_length - boundary_length)

# Search forward until we either hit the end of our buffer,
# or reach a potential start of the boundary.
while i < data_length - 1 and data[i] != boundary[0]:
i += 1

c = data[i]

# Now, we have a couple of cases here. If our index is before
# the end of the boundary...
if index < boundary_length:
# If the character matches...
if boundary[index] == c:
# If we found a match for our boundary, we send the
# existing data.
if index == 0:
data_callback("part_data")

# The current character matches, so continue!
index += 1
else:
Expand Down Expand Up @@ -1332,6 +1332,8 @@ def data_callback(name: str, remaining: bool = False) -> None:
# Unset the part boundary flag.
flags &= ~FLAG_PART_BOUNDARY

# We have identified a boundary, callback for any data before it.
data_callback("part_data", i - index)
# Callback indicating that we've reached the end of
# a part, and are starting a new one.
self.callback("part_end")
Expand All @@ -1353,6 +1355,8 @@ def data_callback(name: str, remaining: bool = False) -> None:
elif flags & FLAG_LAST_BOUNDARY:
# We need a second hyphen here.
if c == HYPHEN:
# We have identified a boundary, callback for any data before it.
data_callback("part_data", i - index)
# Callback to end the current part, and then the
# message.
self.callback("part_end")
Expand All @@ -1362,26 +1366,14 @@ def data_callback(name: str, remaining: bool = False) -> None:
# No match, so reset index.
index = 0

# If we have an index, we need to keep this byte for later, in
# case we can't match the full boundary.
if index > 0:
self.lookbehind[index - 1] = c

# Otherwise, our index is 0. If the previous index is not, it
# means we reset something, and we need to take the data we
# thought was part of our boundary and send it along as actual
# data.
elif prev_index > 0:
# Callback to write the saved data.
lb_data = join_bytes(self.lookbehind)
self.callback("part_data", lb_data, 0, prev_index)

if index == 0 and prev_index > 0:
# Overwrite our previous index.
prev_index = 0

# Re-set our mark for part data.
set_mark("part_data")

# Re-consider the current character, since this could be
# the start of the boundary itself.
i -= 1
Expand Down Expand Up @@ -1410,9 +1402,9 @@ def data_callback(name: str, remaining: bool = False) -> None:
# that we haven't yet reached the end of this 'thing'. So, by setting
# the mark to 0, we cause any data callbacks that take place in future
# calls to this function to start from the beginning of that buffer.
data_callback("header_field", True)
data_callback("header_value", True)
data_callback("part_data", True)
data_callback("header_field", length, True)
data_callback("header_value", length, True)
data_callback("part_data", length - index, True)

# Save values to locals.
self.state = state
Expand Down
35 changes: 28 additions & 7 deletions tests/test_multipart.py
Original file line number Diff line number Diff line change
Expand Up @@ -695,6 +695,14 @@ def test_not_aligned(self):

http_tests.append({"name": fname, "test": test_data, "result": yaml_data})

# Datasets used for single-byte writing test.
single_byte_tests = [
"almost_match_boundary",
"almost_match_boundary_without_CR",
"almost_match_boundary_without_LF",
"almost_match_boundary_without_final_hyphen",
"single_field_single_file",
]

def split_all(val):
"""
Expand Down Expand Up @@ -843,17 +851,19 @@ def test_random_splitting(self):
self.assert_field(b"field", b"test1")
self.assert_file(b"file", b"file.txt", b"test2")

def test_feed_single_bytes(self):
@parametrize("param", [ t for t in http_tests if t["name"] in single_byte_tests])
def test_feed_single_bytes(self, param):
"""
This test parses a simple multipart body 1 byte at a time.
This test parses multipart bodies 1 byte at a time.
"""
# Load test data.
test_file = "single_field_single_file.http"
test_file = param["name"] + ".http"
boundary = param["result"]["boundary"]
with open(os.path.join(http_tests_dir, test_file), "rb") as f:
test_data = f.read()

# Create form parser.
self.make("boundary")
self.make(boundary)

# Write all bytes.
# NOTE: Can't simply do `for b in test_data`, since that gives
Expand All @@ -868,9 +878,20 @@ def test_feed_single_bytes(self):
# Assert we processed everything.
self.assertEqual(i, len(test_data))

# Assert that our file and field are here.
self.assert_field(b"field", b"test1")
self.assert_file(b"file", b"file.txt", b"test2")
# Assert that the parser gave us the appropriate fields/files.
for e in param["result"]["expected"]:
# Get our type and name.
type = e["type"]
name = e["name"].encode("latin-1")

if type == "field":
self.assert_field(name, e["data"])

elif type == "file":
self.assert_file(name, e["file_name"].encode("latin-1"), e["data"])

else:
assert False

def test_feed_blocks(self):
"""
Expand Down

0 comments on commit a790e40

Please sign in to comment.