From 01a9b874da3a793fb568e05d773e7f75490db60d Mon Sep 17 00:00:00 2001 From: Ruud de Jong Date: Thu, 21 Mar 2024 10:59:54 +0100 Subject: [PATCH 1/2] Refactored Driver --- pyproject.toml | 4 ++ src/sliplib/slip.py | 113 +++++++++++++-------------------- src/sliplib/slipwrapper.py | 90 +++++++++++++++++++++++++- tests/unit/test_slip.py | 70 ++++++++++---------- tests/unit/test_slipwrapper.py | 18 +++++- 5 files changed, 190 insertions(+), 105 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 2c21e69..df9f987 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -48,6 +48,7 @@ python = ["3.8", "3.9", "3.10", "3.11", "3.12"] [tool.hatch.envs.test] extra-dependencies = [ "coverage[toml]>=6.5", + "pytest-asyncio", ] [tool.hatch.envs.test.scripts] @@ -142,3 +143,6 @@ line_length = 120 [tool.mypy] python_version = "3.8" + +[tool.pytest.ini_options] +asyncio_mode = "auto" diff --git a/src/sliplib/slip.py b/src/sliplib/slip.py index 280bb3e..b424e88 100644 --- a/src/sliplib/slip.py +++ b/src/sliplib/slip.py @@ -49,6 +49,9 @@ import collections import re +from numbers import Number +from queue import Empty, Queue +from typing import Callable, Optional END = b"\xc0" #: The SLIP `END` byte. ESC = b"\xdb" #: The SLIP `ESC` byte. @@ -129,11 +132,11 @@ class Driver: """ def __init__(self) -> None: + self._finished = False self._recv_buffer = b"" - self._packets: collections.deque[bytes] = collections.deque() - self._messages: list[bytes] = [] + self._packets: Queue[bytes] = Queue() - def send(self, message: bytes) -> bytes: # pylint: disable=no-self-use + def send(self, message: bytes) -> bytes: """Encodes a message into a SLIP-encoded packet. The message can be any arbitrary byte sequence. @@ -146,29 +149,25 @@ def send(self, message: bytes) -> bytes: # pylint: disable=no-self-use """ return encode(message) - def receive(self, data: bytes | int) -> list[bytes]: - """Decodes data and gives a list of decoded messages. + def receive(self, data: bytes | int) -> None: + """Decodes data to extract the SLIP-encoded messages. Processes :obj:`data`, which must be a bytes-like object, - and returns a (possibly empty) list with :class:`bytes` objects, - each containing a decoded message. - Any non-terminated SLIP packets in :obj:`data` - are buffered, and processed with the next call to :meth:`receive`. + and extracts and buffers the SLIP messages contained therein. + + A non-terminated SLIP packet in :obj:`data` + is also buffered, and processed with the next call to :meth:`receive`. Args: data: A bytes-like object to be processed. - An empty :obj:`data` parameter forces the internal - buffer to be flushed and decoded. + An empty :obj:`data` parameter indicates that no more data will follow. To accommodate iteration over byte sequences, an integer in the range(0, 256) is also accepted. Returns: - A (possibly empty) list of decoded messages. - - Raises: - ProtocolError: When `data` contains an invalid byte sequence. + None. """ # When a single byte is fed into this function @@ -181,6 +180,7 @@ def receive(self, data: bytes | int) -> list[bytes]: # To force a buffer flush, an END byte is added, so that the # current contents of _recv_buffer will form a complete message. if not data: + self._finished = True data = END self._recv_buffer += data @@ -188,64 +188,41 @@ def receive(self, data: bytes | int) -> list[bytes]: # The following situations can occur: # # 1) _recv_buffer is empty or contains only END bytes --> no packets available - # 2) _recv_buffer contains non-END bytes --> packets are available + # 2) _recv_buffer contains non-END bytes --> one or more (partial) packets are available # - # Strip leading END bytes from _recv_buffer to avoid handling empty _packets. + # Strip leading END bytes from _recv_buffer to avoid handling empty packets. self._recv_buffer = self._recv_buffer.lstrip(END) - if self._recv_buffer: - # The _recv_buffer contains non-END bytes. - # It is now split on sequences of one or more END bytes. - # The trailing element from the split operation is a possibly incomplete - # packet; this element is therefore used as the new _recv_buffer. - # If _recv_buffer contains one or more trailing END bytes, - # (meaning that there are no incomplete packets), then the last element, - # and therefore the new _recv_buffer, is an empty bytes object. - self._packets.extend(re.split(END + b"+", self._recv_buffer)) - self._recv_buffer = self._packets.pop() - - # Process the buffered packets - return self.flush() - - def flush(self) -> list[bytes]: - """Gives a list of decoded messages. - - Decodes the packets in the internal buffer. - This enables the continuation of the processing - of received packets after a :exc:`ProtocolError` - has been handled. + + # The _recv_buffer is now split on sequences of one or more END bytes. + # The trailing element from the split operation is a possibly incomplete + # packet; this element is therefore used as the new _recv_buffer. + # If _recv_buffer contains one or more trailing END bytes, + # (meaning that there are no incomplete packets), then the last element, + # and therefore the new _recv_buffer, is an empty bytes object. + *new_packets, self._recv_buffer = re.split(END + b"+", self._recv_buffer) + + # Add the packets to the buffer + for packet in new_packets: + self._packets.put(packet) + + def get(self, block: bool = True, timeout: Optional[Number] = None) -> bytes: + """Get the next decoded message. + + Remove and decode a SLIP packet from the internal buffer, and return the resulting message. + If `block` is `True` and `timeout` is `None`(the default), then this method blocks until a message is available. + If `timeout` is a positive number, the blocking will last for at most `timeout` seconds, + and the method will return `None` if no message became available in that time. + If `block` is `False` the method returns immediately with either a message or `None`. Returns: - A (possibly empty) list of decoded messages from the buffered packets. + A decoded SLIP message, or an empty bytestring `b""` if no further message will come available. Raises: - ProtocolError: When any of the buffered packets contains an invalid byte sequence. - """ - messages: list[bytes] = [] - while self._packets: - packet = self._packets.popleft() - try: - msg = decode(packet) - except ProtocolError: - # Add any already decoded messages to the internal message buffer - self._messages = messages - raise - messages.append(msg) - return messages - - @property - def messages(self) -> list[bytes]: - """A list of decoded messages. - - The read-only attribute :attr:`messages` contains - the messages that were - already decoded before a - :exc:`ProtocolError` was raised. - This enables the handler of the :exc:`ProtocolError` - exception to recover the messages up to the - point where the error occurred. - This attribute is cleared after it has been read. + ProtocolError: When the packet that contained the message had an invalid byte sequence. """ try: - return self._messages - finally: - self._messages = [] + packet = self._packets.get(block, timeout) + except Empty: + return b"" if self._finished else None + + return decode(packet) diff --git a/src/sliplib/slipwrapper.py b/src/sliplib/slipwrapper.py index 2e491e3..36d79ef 100644 --- a/src/sliplib/slipwrapper.py +++ b/src/sliplib/slipwrapper.py @@ -27,9 +27,10 @@ """ from __future__ import annotations +import asyncio import sys from types import TracebackType # noqa: TCH003 -from typing import TYPE_CHECKING, Generic, TypeVar +from typing import Coroutine, Optional, TYPE_CHECKING, Generic, TypeVar, Union if TYPE_CHECKING: from collections.abc import Iterator @@ -42,6 +43,15 @@ ByteStream = TypeVar("ByteStream") +def _is_async() -> bool: + try: + asyncio.get_running_loop() + except RuntimeError: + return False + else: + return True + + class SlipWrapper(Generic[ByteStream]): """Base class that provides a message based interface to a byte stream @@ -108,16 +118,46 @@ def recv_bytes(self) -> bytes: """ raise NotImplementedError - def send_msg(self, message: bytes) -> None: + async def async_send_bytes(self, packet: bytes) -> None: + """Send a packet over the stream. + + Derived classes must implement this method. + + Args: + packet: the packet to send over the stream + """ + raise NotImplementedError + + async def async_recv_bytes(self) -> bytes: + """Receive data from the stream. + + Derived classes must implement this method. + + .. note:: + The convention used within the :class:`SlipWrapper` class + is that :meth:`recv_bytes` returns an empty bytes object + to indicate that the end of + the byte stream has been reached and no further data will + be received. Derived implementations must ensure that + this convention is followed. + + Returns: + The bytes received from the stream + """ + raise NotImplementedError + + def send_msg(self, message: bytes) -> Union[None, Coroutine[None, None, None]]: """Send a SLIP-encoded message over the stream. Args: message (bytes): The message to encode and send """ packet = self.driver.send(message) + if _is_async(): + return self.async_send_bytes(packet) self.send_bytes(packet) - def recv_msg(self) -> bytes: + def recv_msg(self) -> Union[bytes, Coroutine[bytes, None, None]]: """Receive a single message from the stream. Returns: @@ -129,6 +169,50 @@ def recv_msg(self) -> bytes: will return the message from the next packet. """ + if _is_async(): + return self._async_recv_msg() + return self._sync_recv_msg() + + def _get_message_from_buffer(self) -> Optional[bytes]: + if not self._messages and self._protocol_error: + # No pending messages left, and a ProtocolError is waiting to be handled. + # The ProtocolError must be re-raised here + try: + raise self._protocol_error.with_traceback(self._traceback) + finally: + self._protocol_error = None + self._traceback = None + self._flush_needed = True + + if self._flush_needed: + # A previously raised ProtocolError has been handled, so the next valid message must be retrieved. + self._flush_needed = False + try: + self._messages.extend(self.driver.flush()) + except ProtocolError as protocol_error: + self._messages.extend(self.driver.messages) + self._protocol_error = protocol_error + self._traceback = sys.exc_info()[2] + + if self._messages: + return self._messages.popleft() + else: + return None + + + def _sync_recv_msg(self) -> bytes: + message = self._get_message_from_buffer() + if message is not None: + return message + + # Get data from the wrapped stream and feed it to the driver. + data = self.recv_bytes() + if data == b"": + self._stream_closed = True + if isinstance(data, int): # Single byte reads are represented as integers + data = bytes([data]) + self._messages.extend(self.driver.receive(data)) + # First check if there are any pending messages if self._messages: return self._messages.popleft() diff --git a/tests/unit/test_slip.py b/tests/unit/test_slip.py index a50c6d8..913d585 100644 --- a/tests/unit/test_slip.py +++ b/tests/unit/test_slip.py @@ -134,20 +134,24 @@ def test_single_message_decoding(self) -> None: """Test decoding of a byte string with a single packet.""" msg = b"hallo" packet = encode(msg) - msg_list = self.driver.receive(packet) - assert msg_list == [msg] + self.driver.receive(packet) + assert self.driver.get(timeout=0.5) == msg def test_multi_message_decoding(self) -> None: """Test decoding of a byte string with multiple packets.""" msgs = [b"hi", b"there"] packet = END + msgs[0] + END + msgs[1] + END - assert self.driver.receive(packet) == msgs + self.driver.receive(packet) + assert self.driver.get(timeout=0.5) == msgs[0] + assert self.driver.get(timeout=0.5) == msgs[1] def test_multiple_end_bytes_are_ignored_during_decoding(self) -> None: """Test decoding of a byte string with multiple packets.""" msgs = [b"hi", b"there"] packet = END + END + msgs[0] + END + END + END + END + msgs[1] + END + END + END - assert self.driver.receive(packet) == msgs + self.driver.receive(packet) + assert self.driver.get(timeout=0.5) == msgs[0] + assert self.driver.get(timeout=0.5) == msgs[1] def test_split_message_decoding(self) -> None: """Test that receives only returns the message after the complete packet has been received. @@ -157,51 +161,52 @@ def test_split_message_decoding(self) -> None: msg = b"hallo\0bye" packet = END + msg for byte_ in packet: - assert self.driver.receive(byte_) == [] - assert self.driver.receive(END) == [msg] + self.driver.receive(byte_) + assert self.driver.get(block=False) is None + self.driver.receive(END) + assert self.driver.get(timeout=0.5) == msg def test_flush_buffers_with_empty_packet(self) -> None: """Test that receiving an empty byte string results in completion of the pending packet.""" expected_msg_list = [b"hi", b"there"] packet = END + expected_msg_list[0] + END + expected_msg_list[1] - assert self.driver.receive(packet) == expected_msg_list[:1] - assert self.driver.receive(b"") == expected_msg_list[1:] + self.driver.receive(packet) + assert self.driver.get(timeout=0.5) == expected_msg_list[0] + assert self.driver.get(block=False) is None + self.driver.receive(b"") + assert self.driver.get(timeout=0.5) == expected_msg_list[1] - def test_packet_with_wrong_escape_sequence(self) -> None: + @pytest.mark.parametrize( + "message", + [b"with" + ESC + b" error", b"with trailing" + ESC] + ) + def test_packet_with_protocol_error(self, message) -> None: """Test that an invalid bytes sequence in the packet results in a protocol error.""" - msg = b"with" + ESC + b" error" - packet = END + msg + END - with pytest.raises(ProtocolError) as exc_info: - self.driver.receive(packet) - assert exc_info.value.args == (msg,) - - def test_packet_with_trailing_escape_byte(self) -> None: - """Test that a packet with a trailing escape byte results in a protocol error.""" - msg = b"with trailing" + ESC - packet = END + msg + END + packet = END + message + END + self.driver.receive(packet) with pytest.raises(ProtocolError) as exc_info: - self.driver.receive(packet) - assert exc_info.value.args == (msg,) + self.driver.get(timeout=0.5) + assert exc_info.value.args == (message,) def test_messages_before_invalid_packets(self) -> None: """Test that the messages that were received before an invalid packet can be retrieved.""" msgs = [b"hallo", b"with" + ESC + b" error"] packet = END + END.join(msgs) + END + self.driver.receive(packet) + assert self.driver.get(timeout=0.5) == msgs[0] with pytest.raises(ProtocolError) as exc_info: - self.driver.receive(packet) - assert self.driver.messages == msgs[:1] - # Verify that the messages attribute is cleared after reading - assert self.driver.messages == [] + self.driver.get(timeout=0.5) assert exc_info.value.args == (msgs[1],) def test_messages_after_invalid_packets(self) -> None: """Test that the messages that were received before an invalid packet can be retrieved.""" msgs = [b"with" + ESC + b" error", b"bye"] packet = END + END.join(msgs) + END + self.driver.receive(packet) with pytest.raises(ProtocolError) as exc_info: - self.driver.receive(packet) + self.driver.get(timeout=0.5) assert exc_info.value.args == (msgs[0],) - assert self.driver.flush() == msgs[1:] + assert self.driver.get(timeout=0.5) == msgs[1] def test_subsequent_packets_with_wrong_escape_sequence(self) -> None: """Test that each invalid packet results in a protocol error.""" @@ -213,12 +218,13 @@ def test_subsequent_packets_with_wrong_escape_sequence(self) -> None: b"bye", ] packet = END + END.join(msgs) + END + self.driver.receive(packet) + assert self.driver.get(timeout=0.5) == msgs[0] with pytest.raises(ProtocolError) as exc_info: - self.driver.receive(packet) - assert self.driver.messages == [msgs[0]] + self.driver.get(timeout=0.5) assert exc_info.value.args == (msgs[1],) + assert self.driver.get(timeout=0.5) == msgs[2] with pytest.raises(ProtocolError) as exc_info: - self.driver.flush() - assert self.driver.messages == [msgs[2]] + self.driver.get(timeout=0.5) assert exc_info.value.args == (msgs[3],) - assert self.driver.flush() == [msgs[4]] + assert self.driver.get(timeout=0.5) == msgs[4] diff --git a/tests/unit/test_slipwrapper.py b/tests/unit/test_slipwrapper.py index 92e9bb6..d777eb8 100644 --- a/tests/unit/test_slipwrapper.py +++ b/tests/unit/test_slipwrapper.py @@ -20,15 +20,29 @@ def setup(self) -> None: self.subwrapper = type("SubSlipWrapper", (SlipWrapper,), {})(None) # Dummy subclass without implementation def test_slip_wrapper_recv_msg_is_not_implemented(self) -> None: - """Verify that calling recv_msg on a SlipWrapper calls that does not implement read_bytes fails.""" + """Verify that calling recv_msg on a SlipWrapper instance that does not implement read_bytes fails.""" with pytest.raises(NotImplementedError): _ = self.slipwrapper.recv_msg() with pytest.raises(NotImplementedError): _ = self.subwrapper.recv_msg() def test_slip_wrapper_send_msg_is_not_implemented(self) -> None: - """Verify that calling send_msg on a SlipWrapper calls that does not implement send_bytes fails.""" + """Verify that calling send_msg on a SlipWrapper instance that does not implement send_bytes fails.""" with pytest.raises(NotImplementedError): self.slipwrapper.send_msg(b"oops") with pytest.raises(NotImplementedError): self.subwrapper.send_msg(b"oops") + + async def test_slip_wrapper_async_recv_msg_is_not_implemented(self) -> None: + """Verify that awaiting recv_msg on a SlipWrapper instance that does not implement async_recv_bytes fails.""" + with pytest.raises(NotImplementedError): + _ = await self.slipwrapper.recv_msg() + with pytest.raises(NotImplementedError): + _ = await self.slipwrapper.recv_msg() + + async def test_slip_wrapper_async_send_msg_is_not_implemented(self) -> None: + """Verify that awaiting send_msg on a SlipWrapper instance that does not implement async_send_bytes fails.""" + with pytest.raises(NotImplementedError): + _ = await self.slipwrapper.send_msg(b"oops") + with pytest.raises(NotImplementedError): + _ = await self.slipwrapper.send_msg(b"oops") \ No newline at end of file From e15a1c1833f046a64b8ce5c096cc9f2fd292a07c Mon Sep 17 00:00:00 2001 From: Ruud de Jong Date: Thu, 21 Mar 2024 12:45:45 +0100 Subject: [PATCH 2/2] Refactored SlipWrapper --- src/sliplib/slip.py | 10 ++- src/sliplib/slipwrapper.py | 152 ++------------------------------- tests/unit/test_slip.py | 7 +- tests/unit/test_slipwrapper.py | 18 +--- 4 files changed, 19 insertions(+), 168 deletions(-) diff --git a/src/sliplib/slip.py b/src/sliplib/slip.py index b424e88..6aaeedd 100644 --- a/src/sliplib/slip.py +++ b/src/sliplib/slip.py @@ -47,11 +47,8 @@ from __future__ import annotations -import collections import re -from numbers import Number from queue import Empty, Queue -from typing import Callable, Optional END = b"\xc0" #: The SLIP `END` byte. ESC = b"\xdb" #: The SLIP `ESC` byte. @@ -168,6 +165,9 @@ def receive(self, data: bytes | int) -> None: Returns: None. + + .. versionchanged:: 0.7 + `receive()` no longer returns a list of decoded messages. """ # When a single byte is fed into this function @@ -205,7 +205,7 @@ def receive(self, data: bytes | int) -> None: for packet in new_packets: self._packets.put(packet) - def get(self, block: bool = True, timeout: Optional[Number] = None) -> bytes: + def get(self, *, block: bool = True, timeout: float | None = None) -> bytes | None: """Get the next decoded message. Remove and decode a SLIP packet from the internal buffer, and return the resulting message. @@ -219,6 +219,8 @@ def get(self, block: bool = True, timeout: Optional[Number] = None) -> bytes: Raises: ProtocolError: When the packet that contained the message had an invalid byte sequence. + + .. versionadded:: 0.7 """ try: packet = self._packets.get(block, timeout) diff --git a/src/sliplib/slipwrapper.py b/src/sliplib/slipwrapper.py index 36d79ef..b5c2d4c 100644 --- a/src/sliplib/slipwrapper.py +++ b/src/sliplib/slipwrapper.py @@ -27,31 +27,17 @@ """ from __future__ import annotations -import asyncio -import sys -from types import TracebackType # noqa: TCH003 -from typing import Coroutine, Optional, TYPE_CHECKING, Generic, TypeVar, Union +from typing import TYPE_CHECKING, Generic, TypeVar if TYPE_CHECKING: from collections.abc import Iterator -from collections import deque - -from sliplib.slip import Driver, ProtocolError +from sliplib.slip import Driver #: ByteStream is a :class:`TypeVar` that stands for a generic byte stream. ByteStream = TypeVar("ByteStream") -def _is_async() -> bool: - try: - asyncio.get_running_loop() - except RuntimeError: - return False - else: - return True - - class SlipWrapper(Generic[ByteStream]): """Base class that provides a message based interface to a byte stream @@ -84,11 +70,6 @@ def __init__(self, stream: ByteStream): self.stream = stream #: The :class:`SlipWrapper`'s :class:`Driver` instance. self.driver = Driver() - self._messages: deque[bytes] = deque() - self._protocol_error: ProtocolError | None = None - self._traceback: TracebackType | None = None - self._flush_needed = False - self._stream_closed = False def send_bytes(self, packet: bytes) -> None: """Send a packet over the stream. @@ -118,46 +99,16 @@ def recv_bytes(self) -> bytes: """ raise NotImplementedError - async def async_send_bytes(self, packet: bytes) -> None: - """Send a packet over the stream. - - Derived classes must implement this method. - - Args: - packet: the packet to send over the stream - """ - raise NotImplementedError - - async def async_recv_bytes(self) -> bytes: - """Receive data from the stream. - - Derived classes must implement this method. - - .. note:: - The convention used within the :class:`SlipWrapper` class - is that :meth:`recv_bytes` returns an empty bytes object - to indicate that the end of - the byte stream has been reached and no further data will - be received. Derived implementations must ensure that - this convention is followed. - - Returns: - The bytes received from the stream - """ - raise NotImplementedError - - def send_msg(self, message: bytes) -> Union[None, Coroutine[None, None, None]]: + def send_msg(self, message: bytes) -> None: """Send a SLIP-encoded message over the stream. Args: message (bytes): The message to encode and send """ packet = self.driver.send(message) - if _is_async(): - return self.async_send_bytes(packet) self.send_bytes(packet) - def recv_msg(self) -> Union[bytes, Coroutine[bytes, None, None]]: + def recv_msg(self) -> bytes: """Receive a single message from the stream. Returns: @@ -168,98 +119,13 @@ def recv_msg(self) -> Union[bytes, Coroutine[bytes, None, None]]: A subsequent call to :meth:`recv_msg` (after handling the exception) will return the message from the next packet. """ - - if _is_async(): - return self._async_recv_msg() - return self._sync_recv_msg() - - def _get_message_from_buffer(self) -> Optional[bytes]: - if not self._messages and self._protocol_error: - # No pending messages left, and a ProtocolError is waiting to be handled. - # The ProtocolError must be re-raised here - try: - raise self._protocol_error.with_traceback(self._traceback) - finally: - self._protocol_error = None - self._traceback = None - self._flush_needed = True - - if self._flush_needed: - # A previously raised ProtocolError has been handled, so the next valid message must be retrieved. - self._flush_needed = False - try: - self._messages.extend(self.driver.flush()) - except ProtocolError as protocol_error: - self._messages.extend(self.driver.messages) - self._protocol_error = protocol_error - self._traceback = sys.exc_info()[2] - - if self._messages: - return self._messages.popleft() - else: - return None - - - def _sync_recv_msg(self) -> bytes: - message = self._get_message_from_buffer() - if message is not None: - return message - - # Get data from the wrapped stream and feed it to the driver. - data = self.recv_bytes() - if data == b"": - self._stream_closed = True - if isinstance(data, int): # Single byte reads are represented as integers - data = bytes([data]) - self._messages.extend(self.driver.receive(data)) - - # First check if there are any pending messages - if self._messages: - return self._messages.popleft() - - # No pending messages left. If a ProtocolError has occurred - # it must be re-raised here: - self._handle_pending_protocol_error() - - while not self._messages and not self._stream_closed: - # As long as no messages are available, - # flush the internal packet buffer, - # and try to read data - try: - if self._flush_needed: - self._flush_needed = False - self._messages.extend(self.driver.flush()) - else: - data = self.recv_bytes() - if data == b"": - self._stream_closed = True - if isinstance(data, int): # Single byte reads are represented as integers - data = bytes([data]) - self._messages.extend(self.driver.receive(data)) - except ProtocolError as protocol_error: - self._messages.extend(self.driver.messages) - self._protocol_error = protocol_error - self._traceback = sys.exc_info()[2] - break - - if self._messages: - return self._messages.popleft() - - self._handle_pending_protocol_error() - return b"" - - def _handle_pending_protocol_error(self) -> None: - if self._protocol_error: - try: - raise self._protocol_error.with_traceback(self._traceback) - finally: - self._protocol_error = None - self._traceback = None - self._flush_needed = True + while (message := self.driver.get(block=False)) is None: + data = self.recv_bytes() + self.driver.receive(data) + return message def __iter__(self) -> Iterator[bytes]: while True: - msg = self.recv_msg() - if not msg: + if not (msg := self.recv_msg()): break yield msg diff --git a/tests/unit/test_slip.py b/tests/unit/test_slip.py index 913d585..ebd7843 100644 --- a/tests/unit/test_slip.py +++ b/tests/unit/test_slip.py @@ -176,11 +176,8 @@ def test_flush_buffers_with_empty_packet(self) -> None: self.driver.receive(b"") assert self.driver.get(timeout=0.5) == expected_msg_list[1] - @pytest.mark.parametrize( - "message", - [b"with" + ESC + b" error", b"with trailing" + ESC] - ) - def test_packet_with_protocol_error(self, message) -> None: + @pytest.mark.parametrize("message", [b"with" + ESC + b" error", b"with trailing" + ESC]) + def test_packet_with_protocol_error(self, message: bytes) -> None: """Test that an invalid bytes sequence in the packet results in a protocol error.""" packet = END + message + END self.driver.receive(packet) diff --git a/tests/unit/test_slipwrapper.py b/tests/unit/test_slipwrapper.py index d777eb8..cb6aec3 100644 --- a/tests/unit/test_slipwrapper.py +++ b/tests/unit/test_slipwrapper.py @@ -19,30 +19,16 @@ def setup(self) -> None: self.slipwrapper = SlipWrapper("not a valid byte stream") self.subwrapper = type("SubSlipWrapper", (SlipWrapper,), {})(None) # Dummy subclass without implementation - def test_slip_wrapper_recv_msg_is_not_implemented(self) -> None: + def test_slip_wrapper_recv_bytes_is_not_implemented(self) -> None: """Verify that calling recv_msg on a SlipWrapper instance that does not implement read_bytes fails.""" with pytest.raises(NotImplementedError): _ = self.slipwrapper.recv_msg() with pytest.raises(NotImplementedError): _ = self.subwrapper.recv_msg() - def test_slip_wrapper_send_msg_is_not_implemented(self) -> None: + def test_slip_wrapper_send_bytes_is_not_implemented(self) -> None: """Verify that calling send_msg on a SlipWrapper instance that does not implement send_bytes fails.""" with pytest.raises(NotImplementedError): self.slipwrapper.send_msg(b"oops") with pytest.raises(NotImplementedError): self.subwrapper.send_msg(b"oops") - - async def test_slip_wrapper_async_recv_msg_is_not_implemented(self) -> None: - """Verify that awaiting recv_msg on a SlipWrapper instance that does not implement async_recv_bytes fails.""" - with pytest.raises(NotImplementedError): - _ = await self.slipwrapper.recv_msg() - with pytest.raises(NotImplementedError): - _ = await self.slipwrapper.recv_msg() - - async def test_slip_wrapper_async_send_msg_is_not_implemented(self) -> None: - """Verify that awaiting send_msg on a SlipWrapper instance that does not implement async_send_bytes fails.""" - with pytest.raises(NotImplementedError): - _ = await self.slipwrapper.send_msg(b"oops") - with pytest.raises(NotImplementedError): - _ = await self.slipwrapper.send_msg(b"oops") \ No newline at end of file