Skip to content

Commit

Permalink
37 Refactor Driver and SlipWrapper in preparation for async support (#41
Browse files Browse the repository at this point in the history
)

* Refactored Driver

* Refactored SlipWrapper
  • Loading branch information
rhjdjong authored Mar 21, 2024
1 parent 227ef21 commit 916ae3b
Show file tree
Hide file tree
Showing 5 changed files with 97 additions and 161 deletions.
4 changes: 4 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ python = ["3.8", "3.9", "3.10", "3.11", "3.12"]
[tool.hatch.envs.test]
extra-dependencies = [
"coverage[toml]>=6.5",
"pytest-asyncio",
]

[tool.hatch.envs.test.scripts]
Expand Down Expand Up @@ -142,3 +143,6 @@ line_length = 120

[tool.mypy]
python_version = "3.8"

[tool.pytest.ini_options]
asyncio_mode = "auto"
115 changes: 47 additions & 68 deletions src/sliplib/slip.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,8 +47,8 @@

from __future__ import annotations

import collections
import re
from queue import Empty, Queue

END = b"\xc0" #: The SLIP `END` byte.
ESC = b"\xdb" #: The SLIP `ESC` byte.
Expand Down Expand Up @@ -129,11 +129,11 @@ class Driver:
"""

def __init__(self) -> None:
self._finished = False
self._recv_buffer = b""
self._packets: collections.deque[bytes] = collections.deque()
self._messages: list[bytes] = []
self._packets: Queue[bytes] = Queue()

def send(self, message: bytes) -> bytes: # pylint: disable=no-self-use
def send(self, message: bytes) -> bytes:
"""Encodes a message into a SLIP-encoded packet.
The message can be any arbitrary byte sequence.
Expand All @@ -146,29 +146,28 @@ def send(self, message: bytes) -> bytes: # pylint: disable=no-self-use
"""
return encode(message)

def receive(self, data: bytes | int) -> list[bytes]:
"""Decodes data and gives a list of decoded messages.
def receive(self, data: bytes | int) -> None:
"""Decodes data to extract the SLIP-encoded messages.
Processes :obj:`data`, which must be a bytes-like object,
and returns a (possibly empty) list with :class:`bytes` objects,
each containing a decoded message.
Any non-terminated SLIP packets in :obj:`data`
are buffered, and processed with the next call to :meth:`receive`.
and extracts and buffers the SLIP messages contained therein.
A non-terminated SLIP packet in :obj:`data`
is also buffered, and processed with the next call to :meth:`receive`.
Args:
data: A bytes-like object to be processed.
An empty :obj:`data` parameter forces the internal
buffer to be flushed and decoded.
An empty :obj:`data` parameter indicates that no more data will follow.
To accommodate iteration over byte sequences, an
integer in the range(0, 256) is also accepted.
Returns:
A (possibly empty) list of decoded messages.
None.
Raises:
ProtocolError: When `data` contains an invalid byte sequence.
.. versionchanged:: 0.7
`receive()` no longer returns a list of decoded messages.
"""

# When a single byte is fed into this function
Expand All @@ -181,71 +180,51 @@ def receive(self, data: bytes | int) -> list[bytes]:
# To force a buffer flush, an END byte is added, so that the
# current contents of _recv_buffer will form a complete message.
if not data:
self._finished = True
data = END

self._recv_buffer += data

# The following situations can occur:
#
# 1) _recv_buffer is empty or contains only END bytes --> no packets available
# 2) _recv_buffer contains non-END bytes --> packets are available
# 2) _recv_buffer contains non-END bytes --> one or more (partial) packets are available
#
# Strip leading END bytes from _recv_buffer to avoid handling empty _packets.
# Strip leading END bytes from _recv_buffer to avoid handling empty packets.
self._recv_buffer = self._recv_buffer.lstrip(END)
if self._recv_buffer:
# The _recv_buffer contains non-END bytes.
# It is now split on sequences of one or more END bytes.
# The trailing element from the split operation is a possibly incomplete
# packet; this element is therefore used as the new _recv_buffer.
# If _recv_buffer contains one or more trailing END bytes,
# (meaning that there are no incomplete packets), then the last element,
# and therefore the new _recv_buffer, is an empty bytes object.
self._packets.extend(re.split(END + b"+", self._recv_buffer))
self._recv_buffer = self._packets.pop()

# Process the buffered packets
return self.flush()

def flush(self) -> list[bytes]:
"""Gives a list of decoded messages.
Decodes the packets in the internal buffer.
This enables the continuation of the processing
of received packets after a :exc:`ProtocolError`
has been handled.

# The _recv_buffer is now split on sequences of one or more END bytes.
# The trailing element from the split operation is a possibly incomplete
# packet; this element is therefore used as the new _recv_buffer.
# If _recv_buffer contains one or more trailing END bytes,
# (meaning that there are no incomplete packets), then the last element,
# and therefore the new _recv_buffer, is an empty bytes object.
*new_packets, self._recv_buffer = re.split(END + b"+", self._recv_buffer)

# Add the packets to the buffer
for packet in new_packets:
self._packets.put(packet)

def get(self, *, block: bool = True, timeout: float | None = None) -> bytes | None:
"""Get the next decoded message.
Remove and decode a SLIP packet from the internal buffer, and return the resulting message.
If `block` is `True` and `timeout` is `None`(the default), then this method blocks until a message is available.
If `timeout` is a positive number, the blocking will last for at most `timeout` seconds,
and the method will return `None` if no message became available in that time.
If `block` is `False` the method returns immediately with either a message or `None`.
Returns:
A (possibly empty) list of decoded messages from the buffered packets.
A decoded SLIP message, or an empty bytestring `b""` if no further message will come available.
Raises:
ProtocolError: When any of the buffered packets contains an invalid byte sequence.
"""
messages: list[bytes] = []
while self._packets:
packet = self._packets.popleft()
try:
msg = decode(packet)
except ProtocolError:
# Add any already decoded messages to the internal message buffer
self._messages = messages
raise
messages.append(msg)
return messages

@property
def messages(self) -> list[bytes]:
"""A list of decoded messages.
The read-only attribute :attr:`messages` contains
the messages that were
already decoded before a
:exc:`ProtocolError` was raised.
This enables the handler of the :exc:`ProtocolError`
exception to recover the messages up to the
point where the error occurred.
This attribute is cleared after it has been read.
ProtocolError: When the packet that contained the message had an invalid byte sequence.
.. versionadded:: 0.7
"""
try:
return self._messages
finally:
self._messages = []
packet = self._packets.get(block, timeout)
except Empty:
return b"" if self._finished else None

return decode(packet)
62 changes: 6 additions & 56 deletions src/sliplib/slipwrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,16 +27,12 @@
"""
from __future__ import annotations

import sys
from types import TracebackType # noqa: TCH003
from typing import TYPE_CHECKING, Generic, TypeVar

if TYPE_CHECKING:
from collections.abc import Iterator

from collections import deque

from sliplib.slip import Driver, ProtocolError
from sliplib.slip import Driver

#: ByteStream is a :class:`TypeVar` that stands for a generic byte stream.
ByteStream = TypeVar("ByteStream")
Expand Down Expand Up @@ -74,11 +70,6 @@ def __init__(self, stream: ByteStream):
self.stream = stream
#: The :class:`SlipWrapper`'s :class:`Driver` instance.
self.driver = Driver()
self._messages: deque[bytes] = deque()
self._protocol_error: ProtocolError | None = None
self._traceback: TracebackType | None = None
self._flush_needed = False
self._stream_closed = False

def send_bytes(self, packet: bytes) -> None:
"""Send a packet over the stream.
Expand Down Expand Up @@ -128,54 +119,13 @@ def recv_msg(self) -> bytes:
A subsequent call to :meth:`recv_msg` (after handling the exception)
will return the message from the next packet.
"""

# First check if there are any pending messages
if self._messages:
return self._messages.popleft()

# No pending messages left. If a ProtocolError has occurred
# it must be re-raised here:
self._handle_pending_protocol_error()

while not self._messages and not self._stream_closed:
# As long as no messages are available,
# flush the internal packet buffer,
# and try to read data
try:
if self._flush_needed:
self._flush_needed = False
self._messages.extend(self.driver.flush())
else:
data = self.recv_bytes()
if data == b"":
self._stream_closed = True
if isinstance(data, int): # Single byte reads are represented as integers
data = bytes([data])
self._messages.extend(self.driver.receive(data))
except ProtocolError as protocol_error:
self._messages.extend(self.driver.messages)
self._protocol_error = protocol_error
self._traceback = sys.exc_info()[2]
break

if self._messages:
return self._messages.popleft()

self._handle_pending_protocol_error()
return b""

def _handle_pending_protocol_error(self) -> None:
if self._protocol_error:
try:
raise self._protocol_error.with_traceback(self._traceback)
finally:
self._protocol_error = None
self._traceback = None
self._flush_needed = True
while (message := self.driver.get(block=False)) is None:
data = self.recv_bytes()
self.driver.receive(data)
return message

def __iter__(self) -> Iterator[bytes]:
while True:
msg = self.recv_msg()
if not msg:
if not (msg := self.recv_msg()):
break
yield msg
69 changes: 36 additions & 33 deletions tests/unit/test_slip.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,20 +134,24 @@ def test_single_message_decoding(self) -> None:
"""Test decoding of a byte string with a single packet."""
msg = b"hallo"
packet = encode(msg)
msg_list = self.driver.receive(packet)
assert msg_list == [msg]
self.driver.receive(packet)
assert self.driver.get(timeout=0.5) == msg

def test_multi_message_decoding(self) -> None:
"""Test decoding of a byte string with multiple packets."""
msgs = [b"hi", b"there"]
packet = END + msgs[0] + END + msgs[1] + END
assert self.driver.receive(packet) == msgs
self.driver.receive(packet)
assert self.driver.get(timeout=0.5) == msgs[0]
assert self.driver.get(timeout=0.5) == msgs[1]

def test_multiple_end_bytes_are_ignored_during_decoding(self) -> None:
"""Test decoding of a byte string with multiple packets."""
msgs = [b"hi", b"there"]
packet = END + END + msgs[0] + END + END + END + END + msgs[1] + END + END + END
assert self.driver.receive(packet) == msgs
self.driver.receive(packet)
assert self.driver.get(timeout=0.5) == msgs[0]
assert self.driver.get(timeout=0.5) == msgs[1]

def test_split_message_decoding(self) -> None:
"""Test that receives only returns the message after the complete packet has been received.
Expand All @@ -157,51 +161,49 @@ def test_split_message_decoding(self) -> None:
msg = b"hallo\0bye"
packet = END + msg
for byte_ in packet:
assert self.driver.receive(byte_) == []
assert self.driver.receive(END) == [msg]
self.driver.receive(byte_)
assert self.driver.get(block=False) is None
self.driver.receive(END)
assert self.driver.get(timeout=0.5) == msg

def test_flush_buffers_with_empty_packet(self) -> None:
"""Test that receiving an empty byte string results in completion of the pending packet."""
expected_msg_list = [b"hi", b"there"]
packet = END + expected_msg_list[0] + END + expected_msg_list[1]
assert self.driver.receive(packet) == expected_msg_list[:1]
assert self.driver.receive(b"") == expected_msg_list[1:]

def test_packet_with_wrong_escape_sequence(self) -> None:
self.driver.receive(packet)
assert self.driver.get(timeout=0.5) == expected_msg_list[0]
assert self.driver.get(block=False) is None
self.driver.receive(b"")
assert self.driver.get(timeout=0.5) == expected_msg_list[1]

@pytest.mark.parametrize("message", [b"with" + ESC + b" error", b"with trailing" + ESC])
def test_packet_with_protocol_error(self, message: bytes) -> None:
"""Test that an invalid bytes sequence in the packet results in a protocol error."""
msg = b"with" + ESC + b" error"
packet = END + msg + END
with pytest.raises(ProtocolError) as exc_info:
self.driver.receive(packet)
assert exc_info.value.args == (msg,)

def test_packet_with_trailing_escape_byte(self) -> None:
"""Test that a packet with a trailing escape byte results in a protocol error."""
msg = b"with trailing" + ESC
packet = END + msg + END
packet = END + message + END
self.driver.receive(packet)
with pytest.raises(ProtocolError) as exc_info:
self.driver.receive(packet)
assert exc_info.value.args == (msg,)
self.driver.get(timeout=0.5)
assert exc_info.value.args == (message,)

def test_messages_before_invalid_packets(self) -> None:
"""Test that the messages that were received before an invalid packet can be retrieved."""
msgs = [b"hallo", b"with" + ESC + b" error"]
packet = END + END.join(msgs) + END
self.driver.receive(packet)
assert self.driver.get(timeout=0.5) == msgs[0]
with pytest.raises(ProtocolError) as exc_info:
self.driver.receive(packet)
assert self.driver.messages == msgs[:1]
# Verify that the messages attribute is cleared after reading
assert self.driver.messages == []
self.driver.get(timeout=0.5)
assert exc_info.value.args == (msgs[1],)

def test_messages_after_invalid_packets(self) -> None:
"""Test that the messages that were received before an invalid packet can be retrieved."""
msgs = [b"with" + ESC + b" error", b"bye"]
packet = END + END.join(msgs) + END
self.driver.receive(packet)
with pytest.raises(ProtocolError) as exc_info:
self.driver.receive(packet)
self.driver.get(timeout=0.5)
assert exc_info.value.args == (msgs[0],)
assert self.driver.flush() == msgs[1:]
assert self.driver.get(timeout=0.5) == msgs[1]

def test_subsequent_packets_with_wrong_escape_sequence(self) -> None:
"""Test that each invalid packet results in a protocol error."""
Expand All @@ -213,12 +215,13 @@ def test_subsequent_packets_with_wrong_escape_sequence(self) -> None:
b"bye",
]
packet = END + END.join(msgs) + END
self.driver.receive(packet)
assert self.driver.get(timeout=0.5) == msgs[0]
with pytest.raises(ProtocolError) as exc_info:
self.driver.receive(packet)
assert self.driver.messages == [msgs[0]]
self.driver.get(timeout=0.5)
assert exc_info.value.args == (msgs[1],)
assert self.driver.get(timeout=0.5) == msgs[2]
with pytest.raises(ProtocolError) as exc_info:
self.driver.flush()
assert self.driver.messages == [msgs[2]]
self.driver.get(timeout=0.5)
assert exc_info.value.args == (msgs[3],)
assert self.driver.flush() == [msgs[4]]
assert self.driver.get(timeout=0.5) == msgs[4]
Loading

0 comments on commit 916ae3b

Please sign in to comment.