Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

37 Refactor Driver and SlipWrapper in preparation for async support #41

Merged
merged 2 commits into from
Mar 21, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ python = ["3.8", "3.9", "3.10", "3.11", "3.12"]
[tool.hatch.envs.test]
extra-dependencies = [
"coverage[toml]>=6.5",
"pytest-asyncio",
]

[tool.hatch.envs.test.scripts]
Expand Down Expand Up @@ -142,3 +143,6 @@ line_length = 120

[tool.mypy]
python_version = "3.8"

[tool.pytest.ini_options]
asyncio_mode = "auto"
115 changes: 47 additions & 68 deletions src/sliplib/slip.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,8 +47,8 @@

from __future__ import annotations

import collections
import re
from queue import Empty, Queue

END = b"\xc0" #: The SLIP `END` byte.
ESC = b"\xdb" #: The SLIP `ESC` byte.
Expand Down Expand Up @@ -129,11 +129,11 @@ class Driver:
"""

def __init__(self) -> None:
self._finished = False
self._recv_buffer = b""
self._packets: collections.deque[bytes] = collections.deque()
self._messages: list[bytes] = []
self._packets: Queue[bytes] = Queue()

def send(self, message: bytes) -> bytes: # pylint: disable=no-self-use
def send(self, message: bytes) -> bytes:
"""Encodes a message into a SLIP-encoded packet.

The message can be any arbitrary byte sequence.
Expand All @@ -146,29 +146,28 @@ def send(self, message: bytes) -> bytes: # pylint: disable=no-self-use
"""
return encode(message)

def receive(self, data: bytes | int) -> list[bytes]:
"""Decodes data and gives a list of decoded messages.
def receive(self, data: bytes | int) -> None:
"""Decodes data to extract the SLIP-encoded messages.

Processes :obj:`data`, which must be a bytes-like object,
and returns a (possibly empty) list with :class:`bytes` objects,
each containing a decoded message.
Any non-terminated SLIP packets in :obj:`data`
are buffered, and processed with the next call to :meth:`receive`.
and extracts and buffers the SLIP messages contained therein.

A non-terminated SLIP packet in :obj:`data`
is also buffered, and processed with the next call to :meth:`receive`.

Args:
data: A bytes-like object to be processed.

An empty :obj:`data` parameter forces the internal
buffer to be flushed and decoded.
An empty :obj:`data` parameter indicates that no more data will follow.

To accommodate iteration over byte sequences, an
integer in the range(0, 256) is also accepted.

Returns:
A (possibly empty) list of decoded messages.
None.

Raises:
ProtocolError: When `data` contains an invalid byte sequence.
.. versionchanged:: 0.7
`receive()` no longer returns a list of decoded messages.
"""

# When a single byte is fed into this function
Expand All @@ -181,71 +180,51 @@ def receive(self, data: bytes | int) -> list[bytes]:
# To force a buffer flush, an END byte is added, so that the
# current contents of _recv_buffer will form a complete message.
if not data:
self._finished = True
data = END

self._recv_buffer += data

# The following situations can occur:
#
# 1) _recv_buffer is empty or contains only END bytes --> no packets available
# 2) _recv_buffer contains non-END bytes --> packets are available
# 2) _recv_buffer contains non-END bytes --> one or more (partial) packets are available
#
# Strip leading END bytes from _recv_buffer to avoid handling empty _packets.
# Strip leading END bytes from _recv_buffer to avoid handling empty packets.
self._recv_buffer = self._recv_buffer.lstrip(END)
if self._recv_buffer:
# The _recv_buffer contains non-END bytes.
# It is now split on sequences of one or more END bytes.
# The trailing element from the split operation is a possibly incomplete
# packet; this element is therefore used as the new _recv_buffer.
# If _recv_buffer contains one or more trailing END bytes,
# (meaning that there are no incomplete packets), then the last element,
# and therefore the new _recv_buffer, is an empty bytes object.
self._packets.extend(re.split(END + b"+", self._recv_buffer))
self._recv_buffer = self._packets.pop()

# Process the buffered packets
return self.flush()

def flush(self) -> list[bytes]:
"""Gives a list of decoded messages.

Decodes the packets in the internal buffer.
This enables the continuation of the processing
of received packets after a :exc:`ProtocolError`
has been handled.

# The _recv_buffer is now split on sequences of one or more END bytes.
# The trailing element from the split operation is a possibly incomplete
# packet; this element is therefore used as the new _recv_buffer.
# If _recv_buffer contains one or more trailing END bytes,
# (meaning that there are no incomplete packets), then the last element,
# and therefore the new _recv_buffer, is an empty bytes object.
*new_packets, self._recv_buffer = re.split(END + b"+", self._recv_buffer)

# Add the packets to the buffer
for packet in new_packets:
self._packets.put(packet)

def get(self, *, block: bool = True, timeout: float | None = None) -> bytes | None:
"""Get the next decoded message.

Remove and decode a SLIP packet from the internal buffer, and return the resulting message.
If `block` is `True` and `timeout` is `None`(the default), then this method blocks until a message is available.
If `timeout` is a positive number, the blocking will last for at most `timeout` seconds,
and the method will return `None` if no message became available in that time.
If `block` is `False` the method returns immediately with either a message or `None`.

Returns:
A (possibly empty) list of decoded messages from the buffered packets.
A decoded SLIP message, or an empty bytestring `b""` if no further message will come available.

Raises:
ProtocolError: When any of the buffered packets contains an invalid byte sequence.
"""
messages: list[bytes] = []
while self._packets:
packet = self._packets.popleft()
try:
msg = decode(packet)
except ProtocolError:
# Add any already decoded messages to the internal message buffer
self._messages = messages
raise
messages.append(msg)
return messages

@property
def messages(self) -> list[bytes]:
"""A list of decoded messages.

The read-only attribute :attr:`messages` contains
the messages that were
already decoded before a
:exc:`ProtocolError` was raised.
This enables the handler of the :exc:`ProtocolError`
exception to recover the messages up to the
point where the error occurred.
This attribute is cleared after it has been read.
ProtocolError: When the packet that contained the message had an invalid byte sequence.

.. versionadded:: 0.7
"""
try:
return self._messages
finally:
self._messages = []
packet = self._packets.get(block, timeout)
except Empty:
return b"" if self._finished else None

return decode(packet)
62 changes: 6 additions & 56 deletions src/sliplib/slipwrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,16 +27,12 @@
"""
from __future__ import annotations

import sys
from types import TracebackType # noqa: TCH003
from typing import TYPE_CHECKING, Generic, TypeVar

if TYPE_CHECKING:
from collections.abc import Iterator

from collections import deque

from sliplib.slip import Driver, ProtocolError
from sliplib.slip import Driver

#: ByteStream is a :class:`TypeVar` that stands for a generic byte stream.
ByteStream = TypeVar("ByteStream")
Expand Down Expand Up @@ -74,11 +70,6 @@ def __init__(self, stream: ByteStream):
self.stream = stream
#: The :class:`SlipWrapper`'s :class:`Driver` instance.
self.driver = Driver()
self._messages: deque[bytes] = deque()
self._protocol_error: ProtocolError | None = None
self._traceback: TracebackType | None = None
self._flush_needed = False
self._stream_closed = False

def send_bytes(self, packet: bytes) -> None:
"""Send a packet over the stream.
Expand Down Expand Up @@ -128,54 +119,13 @@ def recv_msg(self) -> bytes:
A subsequent call to :meth:`recv_msg` (after handling the exception)
will return the message from the next packet.
"""

# First check if there are any pending messages
if self._messages:
return self._messages.popleft()

# No pending messages left. If a ProtocolError has occurred
# it must be re-raised here:
self._handle_pending_protocol_error()

while not self._messages and not self._stream_closed:
# As long as no messages are available,
# flush the internal packet buffer,
# and try to read data
try:
if self._flush_needed:
self._flush_needed = False
self._messages.extend(self.driver.flush())
else:
data = self.recv_bytes()
if data == b"":
self._stream_closed = True
if isinstance(data, int): # Single byte reads are represented as integers
data = bytes([data])
self._messages.extend(self.driver.receive(data))
except ProtocolError as protocol_error:
self._messages.extend(self.driver.messages)
self._protocol_error = protocol_error
self._traceback = sys.exc_info()[2]
break

if self._messages:
return self._messages.popleft()

self._handle_pending_protocol_error()
return b""

def _handle_pending_protocol_error(self) -> None:
if self._protocol_error:
try:
raise self._protocol_error.with_traceback(self._traceback)
finally:
self._protocol_error = None
self._traceback = None
self._flush_needed = True
while (message := self.driver.get(block=False)) is None:
data = self.recv_bytes()
self.driver.receive(data)
return message

def __iter__(self) -> Iterator[bytes]:
while True:
msg = self.recv_msg()
if not msg:
if not (msg := self.recv_msg()):
break
yield msg
69 changes: 36 additions & 33 deletions tests/unit/test_slip.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,20 +134,24 @@ def test_single_message_decoding(self) -> None:
"""Test decoding of a byte string with a single packet."""
msg = b"hallo"
packet = encode(msg)
msg_list = self.driver.receive(packet)
assert msg_list == [msg]
self.driver.receive(packet)
assert self.driver.get(timeout=0.5) == msg

def test_multi_message_decoding(self) -> None:
"""Test decoding of a byte string with multiple packets."""
msgs = [b"hi", b"there"]
packet = END + msgs[0] + END + msgs[1] + END
assert self.driver.receive(packet) == msgs
self.driver.receive(packet)
assert self.driver.get(timeout=0.5) == msgs[0]
assert self.driver.get(timeout=0.5) == msgs[1]

def test_multiple_end_bytes_are_ignored_during_decoding(self) -> None:
"""Test decoding of a byte string with multiple packets."""
msgs = [b"hi", b"there"]
packet = END + END + msgs[0] + END + END + END + END + msgs[1] + END + END + END
assert self.driver.receive(packet) == msgs
self.driver.receive(packet)
assert self.driver.get(timeout=0.5) == msgs[0]
assert self.driver.get(timeout=0.5) == msgs[1]

def test_split_message_decoding(self) -> None:
"""Test that receives only returns the message after the complete packet has been received.
Expand All @@ -157,51 +161,49 @@ def test_split_message_decoding(self) -> None:
msg = b"hallo\0bye"
packet = END + msg
for byte_ in packet:
assert self.driver.receive(byte_) == []
assert self.driver.receive(END) == [msg]
self.driver.receive(byte_)
assert self.driver.get(block=False) is None
self.driver.receive(END)
assert self.driver.get(timeout=0.5) == msg

def test_flush_buffers_with_empty_packet(self) -> None:
"""Test that receiving an empty byte string results in completion of the pending packet."""
expected_msg_list = [b"hi", b"there"]
packet = END + expected_msg_list[0] + END + expected_msg_list[1]
assert self.driver.receive(packet) == expected_msg_list[:1]
assert self.driver.receive(b"") == expected_msg_list[1:]

def test_packet_with_wrong_escape_sequence(self) -> None:
self.driver.receive(packet)
assert self.driver.get(timeout=0.5) == expected_msg_list[0]
assert self.driver.get(block=False) is None
self.driver.receive(b"")
assert self.driver.get(timeout=0.5) == expected_msg_list[1]

@pytest.mark.parametrize("message", [b"with" + ESC + b" error", b"with trailing" + ESC])
def test_packet_with_protocol_error(self, message: bytes) -> None:
"""Test that an invalid bytes sequence in the packet results in a protocol error."""
msg = b"with" + ESC + b" error"
packet = END + msg + END
with pytest.raises(ProtocolError) as exc_info:
self.driver.receive(packet)
assert exc_info.value.args == (msg,)

def test_packet_with_trailing_escape_byte(self) -> None:
"""Test that a packet with a trailing escape byte results in a protocol error."""
msg = b"with trailing" + ESC
packet = END + msg + END
packet = END + message + END
self.driver.receive(packet)
with pytest.raises(ProtocolError) as exc_info:
self.driver.receive(packet)
assert exc_info.value.args == (msg,)
self.driver.get(timeout=0.5)
assert exc_info.value.args == (message,)

def test_messages_before_invalid_packets(self) -> None:
"""Test that the messages that were received before an invalid packet can be retrieved."""
msgs = [b"hallo", b"with" + ESC + b" error"]
packet = END + END.join(msgs) + END
self.driver.receive(packet)
assert self.driver.get(timeout=0.5) == msgs[0]
with pytest.raises(ProtocolError) as exc_info:
self.driver.receive(packet)
assert self.driver.messages == msgs[:1]
# Verify that the messages attribute is cleared after reading
assert self.driver.messages == []
self.driver.get(timeout=0.5)
assert exc_info.value.args == (msgs[1],)

def test_messages_after_invalid_packets(self) -> None:
"""Test that the messages that were received before an invalid packet can be retrieved."""
msgs = [b"with" + ESC + b" error", b"bye"]
packet = END + END.join(msgs) + END
self.driver.receive(packet)
with pytest.raises(ProtocolError) as exc_info:
self.driver.receive(packet)
self.driver.get(timeout=0.5)
assert exc_info.value.args == (msgs[0],)
assert self.driver.flush() == msgs[1:]
assert self.driver.get(timeout=0.5) == msgs[1]

def test_subsequent_packets_with_wrong_escape_sequence(self) -> None:
"""Test that each invalid packet results in a protocol error."""
Expand All @@ -213,12 +215,13 @@ def test_subsequent_packets_with_wrong_escape_sequence(self) -> None:
b"bye",
]
packet = END + END.join(msgs) + END
self.driver.receive(packet)
assert self.driver.get(timeout=0.5) == msgs[0]
with pytest.raises(ProtocolError) as exc_info:
self.driver.receive(packet)
assert self.driver.messages == [msgs[0]]
self.driver.get(timeout=0.5)
assert exc_info.value.args == (msgs[1],)
assert self.driver.get(timeout=0.5) == msgs[2]
with pytest.raises(ProtocolError) as exc_info:
self.driver.flush()
assert self.driver.messages == [msgs[2]]
self.driver.get(timeout=0.5)
assert exc_info.value.args == (msgs[3],)
assert self.driver.flush() == [msgs[4]]
assert self.driver.get(timeout=0.5) == msgs[4]
Loading
Loading