From 0655b37c809ba32097706da0354bdfa4944bc639 Mon Sep 17 00:00:00 2001 From: puddly <32534428+puddly@users.noreply.github.com> Date: Mon, 1 Jan 2024 21:09:08 -0500 Subject: [PATCH 01/42] Implement independent ASHv2 protocol parsing --- bellows/ash.py | 529 ++++++++++++++++++++++++++++++++++++++++++++++ tests/test_ash.py | 21 ++ 2 files changed, 550 insertions(+) create mode 100644 bellows/ash.py create mode 100644 tests/test_ash.py diff --git a/bellows/ash.py b/bellows/ash.py new file mode 100644 index 00000000..79528853 --- /dev/null +++ b/bellows/ash.py @@ -0,0 +1,529 @@ +from __future__ import annotations + +import abc +import asyncio +import binascii +import dataclasses +import enum +import logging + +from zigpy.types import BaseDataclassMixin + +import bellows.types as t + +_LOGGER = logging.getLogger(__name__) + +FLAG = b"\x7E" # Marks end of frame +ESCAPE = b"\x7D" +XON = b"\x11" # Resume transmission +XOFF = b"\x13" # Stop transmission +SUBSTITUTE = b"\x18" +CANCEL = b"\x1A" # Terminates a frame in progress + +RESERVED = frozenset(FLAG + ESCAPE + XON + XOFF + SUBSTITUTE + CANCEL) + +# Initial value of t_rx_ack, the maximum time the NCP waits to receive acknowledgement +# of a DATA frame +T_RX_ACK_INIT = 1.6 + +# Minimum value of t_rx_ack +T_RX_ACK_MIN = 0.4 + +# Maximum value of t_rx_ack +T_RX_ACK_MAX = 3.2 + +# Delay before sending a non-piggybacked acknowledgement +T_TX_ACK_DELAY = 0.02 + +# Time from receiving an ACK or NAK with the nRdy flag set after which the NCP resumes +# sending callback frames to the host without requiring an ACK or NAK with the nRdy +# flag clear +T_REMOTE_NOTRDY = 1.0 + +# Maximum number of DATA frames the NCP can transmit without having received +# acknowledgements +TX_K = 5 + +# Maximum number of consecutive timeouts allowed while waiting to receive an ACK before +# going to the FAILED state. The value 0 prevents the NCP from entering the error state +# due to timeouts. +ACK_TIMEOUTS = 4 + + +def generate_random_sequence(length: int) -> bytes: + output = bytearray() + rand = 0x42 + + for _i in range(length): + output.append(rand) + + if rand & 0b00000001 == 0: + rand = rand >> 1 + else: + rand = (rand >> 1) ^ 0xB8 + + return output + + +# Since the sequence is static for every frame, we only need to generate it once +PSEUDO_RANDOM_DATA_SEQUENCE = generate_random_sequence(256) + + +class NCPState(enum.Enum): + CONNECTED = "connected" + FAILED = "failed" + + +class AshException(Exception): + pass + + +class NotAcked(AshException): + def __init__(self, frame: NakFrame): + self.frame = frame + + def __repr__(self) -> str: + return f"<{self.__class__.__name__}(" f"frame={self.frame}" f")>" + + +class OutOfSequenceError(AshException): + def __init__(self, expected_seq: int, frame: AshFrame): + self.expected_seq = expected_seq + self.frame = frame + + def __repr__(self) -> str: + return ( + f"<{self.__class__.__name__}(" + f"expected_seq={self.expected_seq}" + f", frame={self.frame}" + f")>" + ) + + +class AshFrame(abc.ABC, BaseDataclassMixin): + MASK: t.uint8_t + MASK_VALUE: t.uint8_t + + @classmethod + def from_bytes(cls, data: bytes) -> DataFrame: + raise NotImplementedError() + + def to_bytes(self) -> bytes: + raise NotImplementedError() + + @classmethod + def _unwrap(cls, data: bytes) -> tuple[int, bytes]: + if len(data) < 3: + raise ValueError(f"Frame is too short: {data!r}") + + computed_crc = binascii.crc_hqx(data[:-2], 0xFFFF).to_bytes(2, "big") + + if computed_crc != data[-2:]: + raise ValueError(f"Invalid CRC bytes in frame {data!r}") + + return data[0], data[1:-2] + + @staticmethod + def append_crc(data: bytes) -> bytes: + return data + binascii.crc_hqx(data, 0xFFFF).to_bytes(2, "big") + + +@dataclasses.dataclass(frozen=True) +class DataFrame(AshFrame): + MASK = 0b10000000 + MASK_VALUE = 0b00000000 + + frm_num: int + re_tx: bool + ack_num: int + ezsp_frame: bytes + + @staticmethod + def _randomize(data: bytes) -> bytes: + assert len(data) <= len(PSEUDO_RANDOM_DATA_SEQUENCE) + return bytes([a ^ b for a, b in zip(data, PSEUDO_RANDOM_DATA_SEQUENCE)]) + + @classmethod + def from_bytes(cls, data: bytes) -> DataFrame: + control, data = cls._unwrap(data) + + return cls( + frm_num=(control & 0b01110000) >> 4, + re_tx=(control & 0b00001000) >> 3, + ack_num=(control & 0b00000111) >> 0, + ezsp_frame=cls._randomize(data), + ) + + def to_bytes(self, *, randomize: bool = True) -> bytes: + return self.append_crc( + bytes( + [ + self.MASK_VALUE + | (self.frm_num) << 4 + | (self.re_tx) << 3 + | (self.ack_num) << 0 + ] + ) + + self._randomize(self.ezsp_frame) + ) + + def __str__(self) -> str: + return f"DATA(num={self.frm_num}, ack={self.ack_num}, re_tx={self.re_tx}) = {self.ezsp_frame.hex()}" + + +@dataclasses.dataclass(frozen=True) +class AckFrame(AshFrame): + MASK = 0b11100000 + MASK_VALUE = 0b10000000 + + res: int + ncp_ready: bool + ack_num: int + + @classmethod + def from_bytes(cls, data: bytes) -> AckFrame: + control, data = cls._unwrap(data) + + return cls( + res=(control & 0b00010000) >> 4, + ncp_ready=(control & 0b00001000) >> 3, + ack_num=(control & 0b00000111) >> 0, + ) + + def to_bytes(self) -> bytes: + return self.append_crc( + bytes( + [ + self.MASK_VALUE + | (self.res) << 4 + | (self.ncp_ready) << 3 + | (self.ack_num) << 0 + ] + ) + ) + + def __str__(self) -> str: + return f"ACK(ack={self.ack_num}, ready={'+' if self.ncp_ready == 0 else '-'!r})" + + +@dataclasses.dataclass(frozen=True) +class NakFrame(AshFrame): + MASK = 0b11100000 + MASK_VALUE = 0b10100000 + + res: int + ncp_ready: bool + ack_num: int + + @classmethod + def from_bytes(cls, data: bytes) -> AckFrame: + control, data = cls._unwrap(data) + + return cls( + res=(control & 0b00010000) >> 4, + ncp_ready=(control & 0b00001000) >> 3, + ack_num=(control & 0b00000111) >> 0, + ) + + def to_bytes(self) -> bytes: + return self.append_crc( + bytes( + [ + self.MASK_VALUE + | (self.res) << 4 + | (self.ncp_ready) << 3 + | (self.ack_num) << 0 + ] + ) + ) + + def __str__(self) -> str: + return f"NAK(ack={self.ack_num}, ready={'+' if self.ncp_ready == 0 else '-'!r})" + + +@dataclasses.dataclass(frozen=True) +class RstFrame(AshFrame): + MASK = 0b11111111 + MASK_VALUE = 0b11000000 + + @classmethod + def from_bytes(cls, data: bytes) -> RstFrame: + control, data = cls._unwrap(data) + + if data: + raise ValueError(f"Invalid data for RST frame: {data!r}") + + return cls() + + def to_bytes(self) -> bytes: + return self.append_crc(bytes([self.MASK_VALUE])) + + def __str__(self) -> str: + return "RST()" + + +@dataclasses.dataclass(frozen=True) +class RStackFrame(AshFrame): + MASK = 0b11111111 + MASK_VALUE = 0b11000001 + + version: t.uint8_t + reset_code: t.NcpResetCode + + @classmethod + def from_bytes(cls, data: bytes) -> RStackFrame: + control, data = cls._unwrap(data) + + if len(data) != 2: + raise ValueError(f"Invalid data length for RSTACK frame: {data!r}") + + version = data[0] + + if version != 0x02: + raise ValueError(f"Invalid version for RSTACK frame: {version}") + + reset_code = t.NcpResetCode(data[1]) + + return cls( + version=version, + reset_code=reset_code, + ) + + def to_bytes(self) -> bytes: + return self.append_crc(bytes([self.MASK_VALUE]) + self.data) + + def __str__(self) -> str: + return f"RSTACK(ver={self.version}, code={self.reset_code})" + + +@dataclasses.dataclass(frozen=True) +class ErrorFrame(RStackFrame): + MASK_VALUE = 0b11000010 + + def __str__(self) -> str: + return f"ERROR(ver={self.version}, code={self.reset_code})" + + +class AshProtocol(asyncio.Protocol): + def __init__(self, ezsp_protocol) -> None: + self._transport = None + self._buffer = bytearray() + self._discarding_until_flag: bool = False + self._pending_data_frames: dict[int, asyncio.Future] = {} + self._ncp_state = NCPState.CONNECTED + self._send_data_frame_lock = asyncio.Lock() + self._tx_seq: int = 0 + self._rx_seq: int = 0 + + def _get_tx_seq(self) -> int: + result = self._tx_seq + self._tx_seq = (self._tx_seq + 1) % 8 + + return result + + def _extract_frame(self, data: bytes) -> AshFrame: + control_byte = data[0] + + for frame in [ + DataFrame, + AckFrame, + NakFrame, + RstFrame, + RStackFrame, + ErrorFrame, + ]: + if control_byte & frame.MASK == frame.MASK_VALUE: + return frame.from_bytes(data) + else: + raise ValueError(f"Could not determine frame type: {data!r}") + + @staticmethod + def _stuff_bytes(data: bytes) -> bytes: + """Stuff bytes for transmission""" + out = bytearray() + + for c in data: + if c in RESERVED: + out.extend([ESCAPE[0], c ^ 0b00100000]) + else: + out.append(c) + + return out + + @staticmethod + def _unstuff_bytes(data: bytes) -> bytes: + """Unstuff bytes after receipt""" + out = bytearray() + escaped = False + + for c in data: + if escaped: + byte = c ^ 0b00100000 + assert byte in RESERVED + out.append(byte) + escaped = False + elif c == ESCAPE[0]: + escaped = True + else: + out.append(c) + + return out + + def data_received(self, data: bytes) -> None: + _LOGGER.debug("Received data %r", data) + self._buffer.extend(data) + + while self._buffer: + if self._discarding_until_flag: + if FLAG not in self._buffer: + self._buffer.clear() + return + + self._discarding_until_flag = False + _, _, self._buffer = self._buffer.partition(FLAG) + + if self._buffer.startswith(FLAG): + # Consecutive Flag Bytes after the first Flag Byte are ignored + self._buffer = self._buffer[1:] + elif self._buffer.startswith(CANCEL): + _, _, self._buffer = self._buffer.partition(CANCEL) + elif self._buffer.startswith(XON): + _LOGGER.debug("Received XON byte, resuming transmission") + self._buffer = self._buffer[1:] + elif self._buffer.startswith(XOFF): + _LOGGER.debug("Received XOFF byte, pausing transmission") + self._buffer = self._buffer[1:] + elif self._buffer.startswith(SUBSTITUTE): + self._discarding_until_flag = True + self._buffer = self._buffer[1:] + elif FLAG in self._buffer: + frame_bytes, _, self._buffer = self._buffer.partition(FLAG) + data = self._unstuff_bytes(frame_bytes) + + try: + frame = self._extract_frame(data) + except ValueError: + _LOGGER.warning( + "Failed to parse frame %r", frame_bytes, exc_info=True + ) + else: + self.frame_received(frame) + else: + break + + def _handle_ack(self, frame: DataFrame | AckFrame) -> None: + # Note that ackNum is the number of the next frame the receiver expects and it + # is one greater than the last frame received. + + ack_num = (frame.ack_num - 1) % 8 + + if ack_num not in self._pending_data_frames: + _LOGGER.warning("Received an unexpected ACK: %r", frame) + return + + self._pending_data_frames[ack_num].set_result(True) + + def frame_received(self, frame: AshFrame) -> None: + _LOGGER.debug("Received frame %r", frame) + return + + if isinstance(frame, DataFrame): + # The Host may not piggyback acknowledgments and should promptly send an ACK + # frame when it receives a DATA frame. + self.send_frame(AckFrame(res=0, ncp_ready=0, ack_num=self._rx_seq + 1)) + self._handle_ack(frame) + + if frame.re_tx: + expected_seq = self._rx_seq + else: + expected_seq = (self._rx_seq + 1) % 8 + + if frame.frm_num != expected_seq: + _LOGGER.warning("Received an out of sequence frame: %r", frame) + else: + self._rx_seq = expected_seq + elif isinstance(frame, AckFrame): + self._handle_ack(frame) + elif isinstance(frame, NakFrame): + error = NotAcked(frame=frame) + self._pending_data_frames[frame.ack_num].set_exception(error) + + def _write_frame(self, frame: AshFrame) -> None: + _LOGGER.debug("Sending frame %r", frame) + data = self._stuff_bytes(frame.to_bytes()) + FLAG + + _LOGGER.debug("Sending data %r", data) + self._transport.write(data) + + async def send_frame(self, frame: AshFrame) -> None: + return await asyncio.shield(self._send_frame(frame)) + + async def _send_frame(self, frame: AshFrame) -> None: + if not isinstance(frame, DataFrame): + self._write_frame(frame) + return + + async with self._send_data_frame_lock: + frm_num = self._get_tx_seq() + ack_future = asyncio.get_running_loop().create_future() + self._pending_data_frames[frm_num] = ack_future + + for attempt in range(ACK_TIMEOUTS): + self.send_frame( + frame.replace( + frm_num=frm_num, + re_tx=(attempt > 0), + ack_num=self._rx_seq, + ) + ) + + try: + await asyncio.wait_for(ack_future, timeout=T_RX_ACK_MAX) + except asyncio.TimeoutError: + pass + else: + break + else: + self._enter_failed_state() + raise + + self._pending_data_frames.pop(frm_num) + + +if __name__ == "__main__": + import ast + import pathlib + import sys + + import coloredlogs + + coloredlogs.install(level="INFO") + + protocol = AshProtocol(None) + + def frame_received(frame): + protocol.last_frame = frame + + protocol.frame_received = frame_received + + for log_f in sys.argv[1:]: + with pathlib.Path(log_f).open("r") as f: + for line in f: + if "xbee" in line: + continue + + if "uart] Sending: b'" in line: + send_frame = bytes.fromhex(line.split(": b'")[1].split("'")[0]) + + protocol.data_received(send_frame) + _LOGGER.info(" -----> %s", protocol.last_frame) + elif " frame: b'" in line and "ZCL" not in line: + decoded = ast.literal_eval(line.split(": b")[1]) + unstuffed_frame = bytes.fromhex(decoded) + receive_frame = ( + AshProtocol._stuff_bytes(unstuffed_frame[:-1]) + + unstuffed_frame[-1:] + ) + + protocol.data_received(receive_frame) + _LOGGER.info("<----- %s", protocol.last_frame) diff --git a/tests/test_ash.py b/tests/test_ash.py new file mode 100644 index 00000000..9a22f894 --- /dev/null +++ b/tests/test_ash.py @@ -0,0 +1,21 @@ +from bellows.ash import PSEUDO_RANDOM_DATA_SEQUENCE, AshProtocol + + +def test_stuffing(): + assert AshProtocol._stuff_bytes(b"\x7E") == b"\x7D\x5E" + assert AshProtocol._stuff_bytes(b"\x11") == b"\x7D\x31" + assert AshProtocol._stuff_bytes(b"\x13") == b"\x7D\x33" + assert AshProtocol._stuff_bytes(b"\x18") == b"\x7D\x38" + assert AshProtocol._stuff_bytes(b"\x1A") == b"\x7D\x3A" + assert AshProtocol._stuff_bytes(b"\x7D") == b"\x7D\x5D" + + assert AshProtocol._unstuff_bytes(b"\x7D\x5E") == b"\x7E" + assert AshProtocol._unstuff_bytes(b"\x7D\x31") == b"\x11" + assert AshProtocol._unstuff_bytes(b"\x7D\x33") == b"\x13" + assert AshProtocol._unstuff_bytes(b"\x7D\x38") == b"\x18" + assert AshProtocol._unstuff_bytes(b"\x7D\x3A") == b"\x1A" + assert AshProtocol._unstuff_bytes(b"\x7D\x5D") == b"\x7D" + + +def test_pseudo_random_data_sequence(): + assert PSEUDO_RANDOM_DATA_SEQUENCE.startswith(b"\x42\x21\xA8\x54\x2A") From 585ffc6b5da0dbbb70dd52e201285fb4cbe9c925 Mon Sep 17 00:00:00 2001 From: puddly <32534428+puddly@users.noreply.github.com> Date: Tue, 2 Jan 2024 15:17:58 -0500 Subject: [PATCH 02/42] Implement dynamic timeout computation --- bellows/ash.py | 37 ++++++++++++++++++++++++++----------- 1 file changed, 26 insertions(+), 11 deletions(-) diff --git a/bellows/ash.py b/bellows/ash.py index 79528853..66a781a2 100644 --- a/bellows/ash.py +++ b/bellows/ash.py @@ -6,6 +6,7 @@ import dataclasses import enum import logging +import time from zigpy.types import BaseDataclassMixin @@ -314,6 +315,7 @@ def __init__(self, ezsp_protocol) -> None: self._send_data_frame_lock = asyncio.Lock() self._tx_seq: int = 0 self._rx_seq: int = 0 + self._t_rx_ack = T_RX_ACK_INIT def _get_tx_seq(self) -> int: result = self._tx_seq @@ -455,10 +457,14 @@ def _write_frame(self, frame: AshFrame) -> None: _LOGGER.debug("Sending data %r", data) self._transport.write(data) - async def send_frame(self, frame: AshFrame) -> None: - return await asyncio.shield(self._send_frame(frame)) + def _change_ack_timeout(self, new_value: float) -> None: + new_value = max(T_RX_ACK_MIN, min(new_value, T_RX_ACK_MAX)) + _LOGGER.debug( + "Changing ACK timeout from %0.2f to %0.2f", self._t_rx_ack, new_value + ) + self._t_rx_ack = new_value - async def _send_frame(self, frame: AshFrame) -> None: + async def send_frame(self, frame: AshFrame) -> None: if not isinstance(frame, DataFrame): self._write_frame(frame) return @@ -469,19 +475,28 @@ async def _send_frame(self, frame: AshFrame) -> None: self._pending_data_frames[frm_num] = ack_future for attempt in range(ACK_TIMEOUTS): - self.send_frame( - frame.replace( - frm_num=frm_num, - re_tx=(attempt > 0), - ack_num=self._rx_seq, - ) + # Use a fresh ACK number on every try + frame = frame.replace( + frm_num=frm_num, + re_tx=(attempt > 0), + ack_num=self._rx_seq, ) + send_time = time.monotonic() + self.send_frame(frame) + try: - await asyncio.wait_for(ack_future, timeout=T_RX_ACK_MAX) + await asyncio.wait_for(ack_future, timeout=self._t_rx_ack) except asyncio.TimeoutError: - pass + # If a DATA frame acknowledgement is not received within the current + # timeout value, then t_rx_ack isdoubled. + self._change_ack_timeout(2 * self._t_rx_ack) else: + # Whenever an acknowledgement is received, t_rx_ack is set to 7/8 of + # its current value plus 1/2 of the measured time for the + # acknowledgement. + delta = time.monotonic() - send_time + self._change_ack_timeout((7 / 8) * self._t_rx_ack + 0.5 * delta) break else: self._enter_failed_state() From 951eb4fceb81e42d490a05608e94b821ee7c18fa Mon Sep 17 00:00:00 2001 From: puddly <32534428+puddly@users.noreply.github.com> Date: Tue, 2 Jan 2024 15:19:19 -0500 Subject: [PATCH 03/42] Use a semaphore instead of a lock to allow concurrent sending of un-ACKed frames --- bellows/ash.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/bellows/ash.py b/bellows/ash.py index 66a781a2..073eacaf 100644 --- a/bellows/ash.py +++ b/bellows/ash.py @@ -312,7 +312,7 @@ def __init__(self, ezsp_protocol) -> None: self._discarding_until_flag: bool = False self._pending_data_frames: dict[int, asyncio.Future] = {} self._ncp_state = NCPState.CONNECTED - self._send_data_frame_lock = asyncio.Lock() + self._send_data_frame_semaphore = asyncio.Semaphore(TX_K) self._tx_seq: int = 0 self._rx_seq: int = 0 self._t_rx_ack = T_RX_ACK_INIT @@ -388,6 +388,7 @@ def data_received(self, data: bytes) -> None: # Consecutive Flag Bytes after the first Flag Byte are ignored self._buffer = self._buffer[1:] elif self._buffer.startswith(CANCEL): + # all data received since the previous Flag Byte to be ignored _, _, self._buffer = self._buffer.partition(CANCEL) elif self._buffer.startswith(XON): _LOGGER.debug("Received XON byte, resuming transmission") @@ -416,7 +417,6 @@ def data_received(self, data: bytes) -> None: def _handle_ack(self, frame: DataFrame | AckFrame) -> None: # Note that ackNum is the number of the next frame the receiver expects and it # is one greater than the last frame received. - ack_num = (frame.ack_num - 1) % 8 if ack_num not in self._pending_data_frames: @@ -427,7 +427,6 @@ def _handle_ack(self, frame: DataFrame | AckFrame) -> None: def frame_received(self, frame: AshFrame) -> None: _LOGGER.debug("Received frame %r", frame) - return if isinstance(frame, DataFrame): # The Host may not piggyback acknowledgments and should promptly send an ACK @@ -442,6 +441,7 @@ def frame_received(self, frame: AshFrame) -> None: if frame.frm_num != expected_seq: _LOGGER.warning("Received an out of sequence frame: %r", frame) + self.send_frame(NakFrame(res=0, ncp_ready=0, ack_num=self._rx_seq)) else: self._rx_seq = expected_seq elif isinstance(frame, AckFrame): @@ -450,7 +450,7 @@ def frame_received(self, frame: AshFrame) -> None: error = NotAcked(frame=frame) self._pending_data_frames[frame.ack_num].set_exception(error) - def _write_frame(self, frame: AshFrame) -> None: + def write_frame(self, frame: AshFrame) -> None: _LOGGER.debug("Sending frame %r", frame) data = self._stuff_bytes(frame.to_bytes()) + FLAG @@ -466,10 +466,11 @@ def _change_ack_timeout(self, new_value: float) -> None: async def send_frame(self, frame: AshFrame) -> None: if not isinstance(frame, DataFrame): - self._write_frame(frame) + # Non-DATA frames can be sent immediately and do not require an ACK + self.write_frame(frame) return - async with self._send_data_frame_lock: + async with self._send_data_frame_semaphore: frm_num = self._get_tx_seq() ack_future = asyncio.get_running_loop().create_future() self._pending_data_frames[frm_num] = ack_future From 937f159a7f69ddc538ba86a51c9fbe4c4025a966 Mon Sep 17 00:00:00 2001 From: puddly <32534428+puddly@users.noreply.github.com> Date: Tue, 2 Jan 2024 21:54:22 -0500 Subject: [PATCH 04/42] Replace current ASH+EZSP implementation --- bellows/ash.py | 133 +++++++++++++------- bellows/ezsp/protocol.py | 4 +- bellows/uart.py | 265 ++++----------------------------------- 3 files changed, 116 insertions(+), 286 deletions(-) diff --git a/bellows/ash.py b/bellows/ash.py index 073eacaf..79acba11 100644 --- a/bellows/ash.py +++ b/bellows/ash.py @@ -307,6 +307,7 @@ def __str__(self) -> str: class AshProtocol(asyncio.Protocol): def __init__(self, ezsp_protocol) -> None: + self._ezsp_protocol = ezsp_protocol self._transport = None self._buffer = bytearray() self._discarding_until_flag: bool = False @@ -317,12 +318,26 @@ def __init__(self, ezsp_protocol) -> None: self._rx_seq: int = 0 self._t_rx_ack = T_RX_ACK_INIT + def connection_made(self, transport): + self._transport = transport + self._ezsp_protocol.connection_made(self) + + def connection_lost(self, exc): + self._ezsp_protocol.connection_lost(exc) + + def eof_received(self): + self._ezsp_protocol.eof_received() + def _get_tx_seq(self) -> int: result = self._tx_seq self._tx_seq = (self._tx_seq + 1) % 8 return result + def close(self): + if self._transport is not None: + self._transport.close() + def _extract_frame(self, data: bytes) -> AshFrame: control_byte = data[0] @@ -372,7 +387,7 @@ def _unstuff_bytes(data: bytes) -> bytes: return out def data_received(self, data: bytes) -> None: - _LOGGER.debug("Received data %r", data) + _LOGGER.debug("Received data: %s", data.hex()) self._buffer.extend(data) while self._buffer: @@ -419,10 +434,16 @@ def _handle_ack(self, frame: DataFrame | AckFrame) -> None: # is one greater than the last frame received. ack_num = (frame.ack_num - 1) % 8 - if ack_num not in self._pending_data_frames: + fut = self._pending_data_frames.get(ack_num) + + if fut is None: _LOGGER.warning("Received an unexpected ACK: %r", frame) return + elif fut.done(): + _LOGGER.debug("Received a double ACK, ignoring...") + return + _LOGGER.debug("Resolving frame %d", ack_num) self._pending_data_frames[ack_num].set_result(True) def frame_received(self, frame: AshFrame) -> None: @@ -431,79 +452,107 @@ def frame_received(self, frame: AshFrame) -> None: if isinstance(frame, DataFrame): # The Host may not piggyback acknowledgments and should promptly send an ACK # frame when it receives a DATA frame. - self.send_frame(AckFrame(res=0, ncp_ready=0, ack_num=self._rx_seq + 1)) - self._handle_ack(frame) if frame.re_tx: - expected_seq = self._rx_seq + expected_seq = (self._rx_seq - 1) % 8 else: - expected_seq = (self._rx_seq + 1) % 8 + expected_seq = self._rx_seq if frame.frm_num != expected_seq: _LOGGER.warning("Received an out of sequence frame: %r", frame) - self.send_frame(NakFrame(res=0, ncp_ready=0, ack_num=self._rx_seq)) + self._write_frame(NakFrame(res=0, ncp_ready=0, ack_num=self._rx_seq)) else: - self._rx_seq = expected_seq + self._handle_ack(frame) + + self._rx_seq = (frame.frm_num + 1) % 8 + self._write_frame(AckFrame(res=0, ncp_ready=0, ack_num=self._rx_seq)) + + self._ezsp_protocol.data_received(frame.ezsp_frame) + elif isinstance(frame, ErrorFrame): + self._ezsp_protocol.error_received(frame.reset_code) + elif isinstance(frame, RStackFrame): + self._tx_seq = 0 + self._rx_seq = 0 + self._change_ack_timeout(T_RX_ACK_INIT) + self._ezsp_protocol.reset_received(frame.reset_code) elif isinstance(frame, AckFrame): self._handle_ack(frame) elif isinstance(frame, NakFrame): error = NotAcked(frame=frame) self._pending_data_frames[frame.ack_num].set_exception(error) - def write_frame(self, frame: AshFrame) -> None: + def _write_frame(self, frame: AshFrame) -> None: _LOGGER.debug("Sending frame %r", frame) data = self._stuff_bytes(frame.to_bytes()) + FLAG - _LOGGER.debug("Sending data %r", data) + _LOGGER.debug("Sending data %s", data.hex()) self._transport.write(data) def _change_ack_timeout(self, new_value: float) -> None: new_value = max(T_RX_ACK_MIN, min(new_value, T_RX_ACK_MAX)) - _LOGGER.debug( - "Changing ACK timeout from %0.2f to %0.2f", self._t_rx_ack, new_value - ) + + if abs(new_value - self._t_rx_ack) > 0.01: + _LOGGER.debug( + "Changing ACK timeout from %0.2f to %0.2f", self._t_rx_ack, new_value + ) + self._t_rx_ack = new_value - async def send_frame(self, frame: AshFrame) -> None: + async def _send_frame(self, frame: AshFrame) -> None: if not isinstance(frame, DataFrame): # Non-DATA frames can be sent immediately and do not require an ACK - self.write_frame(frame) + self._write_frame(frame) return + if self._send_data_frame_semaphore.locked(): + _LOGGER.debug("Semaphore is locked, waiting") + async with self._send_data_frame_semaphore: - frm_num = self._get_tx_seq() + frm_num = self._tx_seq + self._tx_seq = (self._tx_seq + 1) % 8 + ack_future = asyncio.get_running_loop().create_future() self._pending_data_frames[frm_num] = ack_future - for attempt in range(ACK_TIMEOUTS): - # Use a fresh ACK number on every try - frame = frame.replace( - frm_num=frm_num, - re_tx=(attempt > 0), - ack_num=self._rx_seq, - ) - - send_time = time.monotonic() - self.send_frame(frame) + try: + for attempt in range(ACK_TIMEOUTS): + # Use a fresh ACK number on every retry + frame = frame.replace( + frm_num=frm_num, + re_tx=(attempt > 0), + ack_num=self._rx_seq, + ) - try: - await asyncio.wait_for(ack_future, timeout=self._t_rx_ack) - except asyncio.TimeoutError: - # If a DATA frame acknowledgement is not received within the current - # timeout value, then t_rx_ack isdoubled. - self._change_ack_timeout(2 * self._t_rx_ack) + send_time = time.monotonic() + self._write_frame(frame) + + try: + await asyncio.wait_for(ack_future, timeout=self._t_rx_ack) + except asyncio.TimeoutError: + # If a DATA frame acknowledgement is not received within the current + # timeout value, then t_rx_ack isdoubled. + self._change_ack_timeout(2 * self._t_rx_ack) + else: + # Whenever an acknowledgement is received, t_rx_ack is set to 7/8 of + # its current value plus 1/2 of the measured time for the + # acknowledgement. + delta = time.monotonic() - send_time + self._change_ack_timeout((7 / 8) * self._t_rx_ack + 0.5 * delta) + break else: - # Whenever an acknowledgement is received, t_rx_ack is set to 7/8 of - # its current value plus 1/2 of the measured time for the - # acknowledgement. - delta = time.monotonic() - send_time - self._change_ack_timeout((7 / 8) * self._t_rx_ack + 0.5 * delta) - break - else: - self._enter_failed_state() - raise + self._enter_failed_state() + raise + finally: + self._pending_data_frames.pop(frm_num) + + async def send_data(self, data: bytes) -> None: + await self._send_frame( + # All of the other fields will be set during transmission/retries + DataFrame(frm_num=None, re_tx=None, ack_num=None, ezsp_frame=data) + ) - self._pending_data_frames.pop(frm_num) + def send_reset(self) -> None: + self._write_frame(RstFrame()) if __name__ == "__main__": diff --git a/bellows/ezsp/protocol.py b/bellows/ezsp/protocol.py index c7b08054..5c2f2aeb 100644 --- a/bellows/ezsp/protocol.py +++ b/bellows/ezsp/protocol.py @@ -64,12 +64,13 @@ async def command(self, name, *args) -> Any: """Serialize command and send it.""" LOGGER.debug("Send command %s: %s", name, args) data = self._ezsp_frame(name, *args) - self._gw.data(data) c = self.COMMANDS[name] future = asyncio.Future() self._awaiting[self._seq] = (c[0], c[2], future) self._seq = (self._seq + 1) % 256 + await self._gw.send_data(data) + async with asyncio_timeout(EZSP_CMD_TIMEOUT): return await future @@ -85,6 +86,7 @@ async def update_policies(self, policy_config: dict) -> None: def __call__(self, data: bytes) -> None: """Handler for received data frame.""" + LOGGER.debug("Received EZSP frame %s", data) orig_data = data sequence, frame_id, data = self._ezsp_frame_rx(data) diff --git a/bellows/uart.py b/bellows/uart.py index 73bad18b..13d4b029 100644 --- a/bellows/uart.py +++ b/bellows/uart.py @@ -1,5 +1,4 @@ import asyncio -import binascii import logging import sys @@ -11,6 +10,7 @@ import zigpy.config import zigpy.serial +from bellows.ash import AshProtocol from bellows.thread import EventLoopThread, ThreadsafeProxy import bellows.types as t @@ -35,120 +35,33 @@ class Terminator: pass def __init__(self, application, connected_future=None, connection_done_future=None): - self._send_seq = 0 - self._rec_seq = 0 - self._buffer = b"" self._application = application + self._reset_future = None self._startup_reset_future = None self._connected_future = connected_future - self._sendq = asyncio.Queue() - self._pending = (-1, None) self._connection_done_future = connection_done_future - self._send_task = None + self._transport = None + + def close(self): + self._transport.close() def connection_made(self, transport): """Callback when the uart is connected""" self._transport = transport if self._connected_future is not None: self._connected_future.set_result(True) - self._send_task = asyncio.create_task(self._send_loop()) + + async def send_data(self, data: bytes) -> None: + await self._transport.send_data(data) def data_received(self, data): """Callback when there is data received from the uart""" - # TODO: Fix this handling for multiple instances of the characters - # If a Cancel Byte or Substitute Byte is received, the bytes received - # so far are discarded. In the case of a Substitute Byte, subsequent - # bytes will also be discarded until the next Flag Byte. - if self.CANCEL in data: - self._buffer = b"" - data = data[data.rfind(self.CANCEL) + 1 :] - if self.SUBSTITUTE in data: - self._buffer = b"" - data = data[data.find(self.FLAG) + 1 :] - - self._buffer += data - while self._buffer: - frame, self._buffer = self._extract_frame(self._buffer) - if frame is None: - break - self.frame_received(frame) - - def _extract_frame(self, data): - """Extract a frame from the data buffer""" - if self.FLAG in data: - place = data.find(self.FLAG) - frame = self._unstuff(data[: place + 1]) - rest = data[place + 1 :] - crc = binascii.crc_hqx(frame[:-3], 0xFFFF) - crc = bytes([crc >> 8, crc % 256]) - if crc != frame[-3:-1]: - LOGGER.error( - "CRC error in frame %s (%s != %s)", - binascii.hexlify(frame), - binascii.hexlify(frame[-3:-1]), - binascii.hexlify(crc), - ) - self.write(self._nak_frame()) - # Make sure that we also handle the next frame if it is already received - return self._extract_frame(rest) - - return frame, rest - return None, data - - def frame_received(self, data): - """Frame receive handler""" - if (data[0] & 0b10000000) == 0: - self.data_frame_received(data) - elif (data[0] & 0b11100000) == 0b10000000: - self.ack_frame_received(data) - elif (data[0] & 0b11100000) == 0b10100000: - self.nak_frame_received(data) - elif data[0] == 0b11000000: - self.rst_frame_received(data) - elif data[0] == 0b11000001: - self.rstack_frame_received(data) - elif data[0] == 0b11000010: - self.error_frame_received(data) - else: - LOGGER.error("UNKNOWN FRAME RECEIVED: %r", data) # TODO - - def data_frame_received(self, data): - """Data frame receive handler""" - LOGGER.debug("Data frame: %s", binascii.hexlify(data)) - seq = (data[0] & 0b01110000) >> 4 - self._rec_seq = (seq + 1) % 8 - self.write(self._ack_frame()) - self._handle_ack(data[0]) - self._application.frame_received(self._randomize(data[1:-3])) - - def ack_frame_received(self, data): - """Acknowledgement frame receive handler""" - LOGGER.debug("ACK frame: %s", binascii.hexlify(data)) - self._handle_ack(data[0]) - - def nak_frame_received(self, data): - """Negative acknowledgement frame receive handler""" - LOGGER.debug("NAK frame: %s", binascii.hexlify(data)) - self._handle_nak(data[0]) - - def rst_frame_received(self, data): - """Reset frame handler""" - LOGGER.debug("RST frame: %s", binascii.hexlify(data)) - - def rstack_frame_received(self, data): + self._application.frame_received(data) + + def reset_received(self, code): """Reset acknowledgement frame receive handler""" - self._send_seq = 0 - self._rec_seq = 0 - code, version = self._get_error_code(data) - - LOGGER.debug( - "RSTACK Version: %d Reason: %s frame: %s", - version, - code.name, - binascii.hexlify(data), - ) # not a reset we've requested. Signal application reset if code is not t.NcpResetCode.RESET_SOFTWARE: self._application.enter_failed_state(code) @@ -161,6 +74,10 @@ def rstack_frame_received(self, data): else: LOGGER.warning("Received an unexpected reset: %r", code) + def error_received(self, code): + """Error frame receive handler.""" + self._application.enter_failed_state(code) + async def wait_for_startup_reset(self) -> None: """Wait for the first reset frame on startup.""" assert self._startup_reset_future is None @@ -171,31 +88,6 @@ async def wait_for_startup_reset(self) -> None: finally: self._startup_reset_future = None - @staticmethod - def _get_error_code(data): - """Extracts error code from RSTACK or ERROR frames.""" - return t.NcpResetCode(data[2]), data[1] - - def error_frame_received(self, data): - """Error frame receive handler.""" - error_code, version = self._get_error_code(data) - LOGGER.debug( - "Error code: %s, Version: %d, frame: %s", - error_code.name, - version, - binascii.hexlify(data), - ) - self._application.enter_failed_state(error_code) - - def write(self, data): - """Send data to the uart""" - LOGGER.debug("Sending: %s", binascii.hexlify(data)) - self._transport.write(data) - - def close(self): - self._sendq.put_nowait(self.Terminator) - self._transport.close() - def _reset_cleanup(self, future): """Delete reset future.""" self._reset_future = None @@ -225,10 +117,6 @@ def connection_lost(self, exc): self._reset_future.set_exception(reason) self._reset_future = None - if self._send_task: - self._send_task.cancel() - self._send_task = None - if exc is None: LOGGER.debug("Closed serial connection") return @@ -245,138 +133,29 @@ async def reset(self): ) return await self._reset_future - self._send_seq = 0 - self._rec_seq = 0 - self._buffer = b"" - while not self._sendq.empty(): - self._sendq.get_nowait() - if self._pending[1]: - self._pending[1].set_result(True) - self._pending = (-1, None) - + self._transport.send_reset() self._reset_future = asyncio.get_event_loop().create_future() self._reset_future.add_done_callback(self._reset_cleanup) - self.write(self._rst_frame()) async with asyncio_timeout(RESET_TIMEOUT): return await self._reset_future - async def _send_loop(self): - """Send queue handler""" - while True: - item = await self._sendq.get() - if item is self.Terminator: - break - data, seq = item - success = False - rxmit = 0 - while not success: - self._pending = (seq, asyncio.get_event_loop().create_future()) - self.write(self._data_frame(data, seq, rxmit)) - rxmit = 1 - success = await self._pending[1] - - def _handle_ack(self, control): - """Handle an acknowledgement frame""" - ack = ((control & 0b00000111) - 1) % 8 - if ack == self._pending[0]: - pending, self._pending = self._pending, (-1, None) - pending[1].set_result(True) - - def _handle_nak(self, control): - """Handle negative acknowledgment frame""" - nak = control & 0b00000111 - if nak == self._pending[0]: - self._pending[1].set_result(False) - - def data(self, data): - """Send a data frame""" - seq = self._send_seq - self._send_seq = (seq + 1) % 8 - self._sendq.put_nowait((data, seq)) - - def _data_frame(self, data, seq, rxmit): - """Construct a data frame""" - assert 0 <= seq <= 7 - assert 0 <= rxmit <= 1 - control = (seq << 4) | (rxmit << 3) | self._rec_seq - return self._frame(bytes([control]), self._randomize(data)) - - def _ack_frame(self): - """Construct a acknowledgement frame""" - assert 0 <= self._rec_seq < 8 - control = bytes([0b10000000 | (self._rec_seq & 0b00000111)]) - return self._frame(control, b"") - - def _nak_frame(self): - """Construct a negative acknowledgement frame""" - assert 0 <= self._rec_seq < 8 - control = bytes([0b10100000 | (self._rec_seq & 0b00000111)]) - return self._frame(control, b"") - - def _rst_frame(self): - """Construct a reset frame""" - return self.CANCEL + self._frame(b"\xC0", b"") - - def _frame(self, control, data): - """Construct a frame""" - crc = binascii.crc_hqx(control + data, 0xFFFF) - crc = bytes([crc >> 8, crc % 256]) - return self._stuff(control + data + crc) + self.FLAG - - def _randomize(self, s): - """XOR s with a pseudo-random sequence for transmission - - Used only in data frames - """ - rand = self.RANDOMIZE_START - out = b"" - for c in s: - out += bytes([c ^ rand]) - if rand % 2: - rand = (rand >> 1) ^ self.RANDOMIZE_SEQ - else: - rand = rand >> 1 - return out - - def _stuff(self, s): - """Byte stuff (escape) a string for transmission""" - out = b"" - for c in s: - if c in self.RESERVED: - out += self.ESCAPE + bytes([c ^ self.STUFF]) - else: - out += bytes([c]) - return out - - def _unstuff(self, s): - """Unstuff (unescape) a string after receipt""" - out = b"" - escaped = False - for c in s: - if escaped: - out += bytes([c ^ self.STUFF]) - escaped = False - elif c in self.ESCAPE: - escaped = True - else: - out += bytes([c]) - return out - async def _connect(config, application): loop = asyncio.get_event_loop() connection_future = loop.create_future() connection_done_future = loop.create_future() - protocol = Gateway(application, connection_future, connection_done_future) + + gateway = Gateway(application, connection_future, connection_done_future) + protocol = AshProtocol(gateway) if config[zigpy.config.CONF_DEVICE_FLOW_CONTROL] is None: xon_xoff, rtscts = True, False else: xon_xoff, rtscts = False, True - transport, protocol = await zigpy.serial.create_serial_connection( + transport, _ = await zigpy.serial.create_serial_connection( loop, lambda: protocol, url=config[zigpy.config.CONF_DEVICE_PATH], @@ -387,7 +166,7 @@ async def _connect(config, application): await connection_future - thread_safe_protocol = ThreadsafeProxy(protocol, loop) + thread_safe_protocol = ThreadsafeProxy(gateway, loop) return thread_safe_protocol, connection_done_future From d8349c545dc1afecc1b56e52f8aecdfc78a49960 Mon Sep 17 00:00:00 2001 From: puddly <32534428+puddly@users.noreply.github.com> Date: Tue, 2 Jan 2024 21:57:05 -0500 Subject: [PATCH 05/42] Increase max concurrency to match ASH --- bellows/ezsp/__init__.py | 2 +- bellows/ezsp/protocol.py | 3 +-- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/bellows/ezsp/__init__.py b/bellows/ezsp/__init__.py index fa4aab0b..52f008b9 100644 --- a/bellows/ezsp/__init__.py +++ b/bellows/ezsp/__init__.py @@ -41,7 +41,7 @@ NETWORK_OPS_TIMEOUT = 10 NETWORK_COORDINATOR_STARTUP_RESET_WAIT = 1 -MAX_COMMAND_CONCURRENCY = 4 +MAX_COMMAND_CONCURRENCY = 5 class EZSP: diff --git a/bellows/ezsp/protocol.py b/bellows/ezsp/protocol.py index 5c2f2aeb..0ab21878 100644 --- a/bellows/ezsp/protocol.py +++ b/bellows/ezsp/protocol.py @@ -69,9 +69,8 @@ async def command(self, name, *args) -> Any: self._awaiting[self._seq] = (c[0], c[2], future) self._seq = (self._seq + 1) % 256 - await self._gw.send_data(data) - async with asyncio_timeout(EZSP_CMD_TIMEOUT): + await self._gw.send_data(data) return await future async def update_policies(self, policy_config: dict) -> None: From 4e5df9404eab5cd153a5bd5de666705b70ed53c1 Mon Sep 17 00:00:00 2001 From: puddly <32534428+puddly@users.noreply.github.com> Date: Wed, 3 Jan 2024 10:37:02 -0500 Subject: [PATCH 06/42] Allow setting the ACK mode as host or NCP --- bellows/ash.py | 57 ++++++++++++-------------------------------------- 1 file changed, 13 insertions(+), 44 deletions(-) diff --git a/bellows/ash.py b/bellows/ash.py index 79acba11..cdaecc6c 100644 --- a/bellows/ash.py +++ b/bellows/ash.py @@ -306,7 +306,7 @@ def __str__(self) -> str: class AshProtocol(asyncio.Protocol): - def __init__(self, ezsp_protocol) -> None: + def __init__(self, ezsp_protocol, *, host: bool = True) -> None: self._ezsp_protocol = ezsp_protocol self._transport = None self._buffer = bytearray() @@ -316,7 +316,9 @@ def __init__(self, ezsp_protocol) -> None: self._send_data_frame_semaphore = asyncio.Semaphore(TX_K) self._tx_seq: int = 0 self._rx_seq: int = 0 - self._t_rx_ack = T_RX_ACK_INIT + + self._host = host + self._t_rx_ack = T_RX_ACK_MAX if host else T_RX_ACK_INIT def connection_made(self, transport): self._transport = transport @@ -529,15 +531,21 @@ async def _send_frame(self, frame: AshFrame) -> None: try: await asyncio.wait_for(ack_future, timeout=self._t_rx_ack) except asyncio.TimeoutError: + _LOGGER.debug("No ACK received in %0.2fs", self._t_rx_ack) # If a DATA frame acknowledgement is not received within the current # timeout value, then t_rx_ack isdoubled. - self._change_ack_timeout(2 * self._t_rx_ack) + if not self._host: + self._change_ack_timeout(2 * self._t_rx_ack) else: # Whenever an acknowledgement is received, t_rx_ack is set to 7/8 of # its current value plus 1/2 of the measured time for the # acknowledgement. - delta = time.monotonic() - send_time - self._change_ack_timeout((7 / 8) * self._t_rx_ack + 0.5 * delta) + if not self._host: + delta = time.monotonic() - send_time + self._change_ack_timeout( + (7 / 8) * self._t_rx_ack + 0.5 * delta + ) + break else: self._enter_failed_state() @@ -553,42 +561,3 @@ async def send_data(self, data: bytes) -> None: def send_reset(self) -> None: self._write_frame(RstFrame()) - - -if __name__ == "__main__": - import ast - import pathlib - import sys - - import coloredlogs - - coloredlogs.install(level="INFO") - - protocol = AshProtocol(None) - - def frame_received(frame): - protocol.last_frame = frame - - protocol.frame_received = frame_received - - for log_f in sys.argv[1:]: - with pathlib.Path(log_f).open("r") as f: - for line in f: - if "xbee" in line: - continue - - if "uart] Sending: b'" in line: - send_frame = bytes.fromhex(line.split(": b'")[1].split("'")[0]) - - protocol.data_received(send_frame) - _LOGGER.info(" -----> %s", protocol.last_frame) - elif " frame: b'" in line and "ZCL" not in line: - decoded = ast.literal_eval(line.split(": b")[1]) - unstuffed_frame = bytes.fromhex(decoded) - receive_frame = ( - AshProtocol._stuff_bytes(unstuffed_frame[:-1]) - + unstuffed_frame[-1:] - ) - - protocol.data_received(receive_frame) - _LOGGER.info("<----- %s", protocol.last_frame) From d155888241f4f7a9d93e97b9058a0d5f95806f31 Mon Sep 17 00:00:00 2001 From: puddly <32534428+puddly@users.noreply.github.com> Date: Wed, 3 Jan 2024 10:37:34 -0500 Subject: [PATCH 07/42] [TEST] Shut down the event loop with a separate exception --- bellows/thread.py | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/bellows/thread.py b/bellows/thread.py index 6d8c1309..0f92958b 100644 --- a/bellows/thread.py +++ b/bellows/thread.py @@ -4,6 +4,11 @@ import logging import sys + +class EventLoopShuttingDown(RuntimeError): + pass + + LOGGER = logging.getLogger(__name__) @@ -59,11 +64,16 @@ def force_stop(self): if self.loop is None: return + LOGGER.debug("Shutting down thread") + def cancel_tasks_and_stop_loop(): tasks = asyncio.all_tasks(loop=self.loop) for task in tasks: - self.loop.call_soon_threadsafe(task.cancel) + coro = task.get_coro() + + if coro is not None: + self.loop.call_soon_threadsafe(coro.throw, EventLoopShuttingDown()) gather = asyncio.gather(*tasks, return_exceptions=True) gather.add_done_callback( From 515d7db597e3e71c3aa2a5df098835ad09bfd5d4 Mon Sep 17 00:00:00 2001 From: puddly <32534428+puddly@users.noreply.github.com> Date: Fri, 2 Feb 2024 17:08:41 -0500 Subject: [PATCH 08/42] Re-implement a CLI tool to parse ASH frames from debug logs --- bellows/ash.py | 53 ++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 53 insertions(+) diff --git a/bellows/ash.py b/bellows/ash.py index cdaecc6c..f3ad9011 100644 --- a/bellows/ash.py +++ b/bellows/ash.py @@ -561,3 +561,56 @@ async def send_data(self, data: bytes) -> None: def send_reset(self) -> None: self._write_frame(RstFrame()) + + +def main(): + import ast + import pathlib + import sys + import unittest.mock + + import coloredlogs + + coloredlogs.install(level="DEBUG") + + class CapturingAshProtocol(AshProtocol): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self._parsed_frames = [] + + def frame_received(self, frame: AshFrame) -> None: + self._parsed_frames.append(frame) + + with pathlib.Path(sys.argv[1]).open("r") as f: + for line in f: + if "bellows.uart" not in line: + continue + + if "Sending: " in line: + direction = " --->" + elif ( + "Data frame:" in line or "ACK frame: " in line or "NAK frame: " in line + ): + direction = "<--- " + else: + continue + + data = bytes.fromhex(ast.literal_eval(line.split(": b", 1)[1])) + + # Data frames are logged already unstuffed + if direction == "<--- ": + data = AshProtocol._stuff_bytes(data[:-1]) + data[-1:] + + protocol = CapturingAshProtocol(ezsp_protocol=unittest.mock.Mock()) + protocol.data_received(data) + + if len(protocol._parsed_frames) != 1: + raise ValueError(f"Failed to parse frames: {protocol._parsed_frames}") + + frame = protocol._parsed_frames[0] + + _LOGGER.info("%s: %s", direction, frame) + + +if __name__ == "__main__": + main() From 93b4c21cfd72ea86f29d73bb7a70f600cd69bc19 Mon Sep 17 00:00:00 2001 From: puddly <32534428+puddly@users.noreply.github.com> Date: Fri, 2 Feb 2024 21:01:22 -0500 Subject: [PATCH 09/42] Remove `host` handling of timeouts --- bellows/ash.py | 18 ++++++------------ 1 file changed, 6 insertions(+), 12 deletions(-) diff --git a/bellows/ash.py b/bellows/ash.py index f3ad9011..5b5d9fa2 100644 --- a/bellows/ash.py +++ b/bellows/ash.py @@ -306,7 +306,7 @@ def __str__(self) -> str: class AshProtocol(asyncio.Protocol): - def __init__(self, ezsp_protocol, *, host: bool = True) -> None: + def __init__(self, ezsp_protocol) -> None: self._ezsp_protocol = ezsp_protocol self._transport = None self._buffer = bytearray() @@ -316,9 +316,7 @@ def __init__(self, ezsp_protocol, *, host: bool = True) -> None: self._send_data_frame_semaphore = asyncio.Semaphore(TX_K) self._tx_seq: int = 0 self._rx_seq: int = 0 - - self._host = host - self._t_rx_ack = T_RX_ACK_MAX if host else T_RX_ACK_INIT + self._t_rx_ack = T_RX_ACK_INIT def connection_made(self, transport): self._transport = transport @@ -533,18 +531,14 @@ async def _send_frame(self, frame: AshFrame) -> None: except asyncio.TimeoutError: _LOGGER.debug("No ACK received in %0.2fs", self._t_rx_ack) # If a DATA frame acknowledgement is not received within the current - # timeout value, then t_rx_ack isdoubled. - if not self._host: - self._change_ack_timeout(2 * self._t_rx_ack) + # timeout value, then t_rx_ack is doubled. + self._change_ack_timeout(2 * self._t_rx_ack) else: # Whenever an acknowledgement is received, t_rx_ack is set to 7/8 of # its current value plus 1/2 of the measured time for the # acknowledgement. - if not self._host: - delta = time.monotonic() - send_time - self._change_ack_timeout( - (7 / 8) * self._t_rx_ack + 0.5 * delta - ) + delta = time.monotonic() - send_time + self._change_ack_timeout((7 / 8) * self._t_rx_ack + 0.5 * delta) break else: From c2481f41b21b15419cac059e3e203f9094518b81 Mon Sep 17 00:00:00 2001 From: puddly <32534428+puddly@users.noreply.github.com> Date: Fri, 2 Feb 2024 21:01:41 -0500 Subject: [PATCH 10/42] Properly send ACKs in response to re-transmitted frames --- bellows/ash.py | 20 +++++++++----------- 1 file changed, 9 insertions(+), 11 deletions(-) diff --git a/bellows/ash.py b/bellows/ash.py index 5b5d9fa2..d5adefae 100644 --- a/bellows/ash.py +++ b/bellows/ash.py @@ -453,21 +453,19 @@ def frame_received(self, frame: AshFrame) -> None: # The Host may not piggyback acknowledgments and should promptly send an ACK # frame when it receives a DATA frame. - if frame.re_tx: - expected_seq = (self._rx_seq - 1) % 8 - else: - expected_seq = self._rx_seq - - if frame.frm_num != expected_seq: - _LOGGER.warning("Received an out of sequence frame: %r", frame) - self._write_frame(NakFrame(res=0, ncp_ready=0, ack_num=self._rx_seq)) - else: + if frame.frm_num == self._rx_seq: self._handle_ack(frame) - self._rx_seq = (frame.frm_num + 1) % 8 self._write_frame(AckFrame(res=0, ncp_ready=0, ack_num=self._rx_seq)) - self._ezsp_protocol.data_received(frame.ezsp_frame) + self._ezsp_protocol.data_received(frame.ezsp_frame) + elif frame.re_tx: + # Retransmitted frames must be immediately ACKed even if they are out of + # sequence + self._write_frame(AckFrame(res=0, ncp_ready=0, ack_num=self._rx_seq)) + else: + _LOGGER.warning("Received an out of sequence frame: %r", frame) + self._write_frame(NakFrame(res=0, ncp_ready=0, ack_num=self._rx_seq)) elif isinstance(frame, ErrorFrame): self._ezsp_protocol.error_received(frame.reset_code) elif isinstance(frame, RStackFrame): From 2fb1f32775bc1833cdc64409f7906f05e9f43ed2 Mon Sep 17 00:00:00 2001 From: puddly <32534428+puddly@users.noreply.github.com> Date: Fri, 2 Feb 2024 21:01:54 -0500 Subject: [PATCH 11/42] Set maximum command concurrency to 1 --- bellows/ezsp/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/bellows/ezsp/__init__.py b/bellows/ezsp/__init__.py index 52f008b9..1872a740 100644 --- a/bellows/ezsp/__init__.py +++ b/bellows/ezsp/__init__.py @@ -41,7 +41,7 @@ NETWORK_OPS_TIMEOUT = 10 NETWORK_COORDINATOR_STARTUP_RESET_WAIT = 1 -MAX_COMMAND_CONCURRENCY = 5 +MAX_COMMAND_CONCURRENCY = 1 class EZSP: From 5aaf78540fb024e88d8c76f841d6bc697bd49688 Mon Sep 17 00:00:00 2001 From: puddly <32534428+puddly@users.noreply.github.com> Date: Sat, 3 Feb 2024 01:00:39 -0500 Subject: [PATCH 12/42] Get ASH TX and NAK handling working reliably --- bellows/ash.py | 90 +++++++++++++++++++++++++--------------- bellows/ezsp/protocol.py | 2 +- 2 files changed, 57 insertions(+), 35 deletions(-) diff --git a/bellows/ash.py b/bellows/ash.py index d5adefae..da0689d2 100644 --- a/bellows/ash.py +++ b/bellows/ash.py @@ -6,8 +6,14 @@ import dataclasses import enum import logging +import sys import time +if sys.version_info[:2] < (3, 11): + from async_timeout import timeout as asyncio_timeout # pragma: no cover +else: + from asyncio import timeout as asyncio_timeout # pragma: no cover + from zigpy.types import BaseDataclassMixin import bellows.types as t @@ -43,7 +49,7 @@ # Maximum number of DATA frames the NCP can transmit without having received # acknowledgements -TX_K = 5 +TX_K = 1 # Maximum number of consecutive timeouts allowed while waiting to receive an ACK before # going to the FAILED state. The value 0 prevents the NCP from entering the error state @@ -437,10 +443,8 @@ def _handle_ack(self, frame: DataFrame | AckFrame) -> None: fut = self._pending_data_frames.get(ack_num) if fut is None: - _LOGGER.warning("Received an unexpected ACK: %r", frame) return elif fut.done(): - _LOGGER.debug("Received a double ACK, ignoring...") return _LOGGER.debug("Resolving frame %d", ack_num) @@ -477,7 +481,13 @@ def frame_received(self, frame: AshFrame) -> None: self._handle_ack(frame) elif isinstance(frame, NakFrame): error = NotAcked(frame=frame) - self._pending_data_frames[frame.ack_num].set_exception(error) + + for frm_num, fut in self._pending_data_frames.items(): + if ( + not frame.ack_num - TX_K <= frm_num <= frame.ack_num + and not fut.done() + ): + fut.set_exception(error) def _write_frame(self, frame: AshFrame) -> None: _LOGGER.debug("Sending frame %r", frame) @@ -509,41 +519,53 @@ async def _send_frame(self, frame: AshFrame) -> None: frm_num = self._tx_seq self._tx_seq = (self._tx_seq + 1) % 8 - ack_future = asyncio.get_running_loop().create_future() - self._pending_data_frames[frm_num] = ack_future + for attempt in range(ACK_TIMEOUTS): + # Use a fresh ACK number on every retry + frame = frame.replace( + frm_num=frm_num, + re_tx=(attempt > 0), + ack_num=self._rx_seq, + ) + + send_time = time.monotonic() - try: - for attempt in range(ACK_TIMEOUTS): - # Use a fresh ACK number on every retry - frame = frame.replace( - frm_num=frm_num, - re_tx=(attempt > 0), - ack_num=self._rx_seq, + ack_future = asyncio.get_running_loop().create_future() + self._pending_data_frames[frm_num] = ack_future + self._write_frame(frame) + + try: + async with asyncio_timeout(self._t_rx_ack): + await ack_future + except NotAcked: + _LOGGER.debug( + "NCP responded with NAK. Retrying (attempt %d)", attempt + 1 ) - send_time = time.monotonic() - self._write_frame(frame) - - try: - await asyncio.wait_for(ack_future, timeout=self._t_rx_ack) - except asyncio.TimeoutError: - _LOGGER.debug("No ACK received in %0.2fs", self._t_rx_ack) - # If a DATA frame acknowledgement is not received within the current - # timeout value, then t_rx_ack is doubled. - self._change_ack_timeout(2 * self._t_rx_ack) - else: - # Whenever an acknowledgement is received, t_rx_ack is set to 7/8 of - # its current value plus 1/2 of the measured time for the - # acknowledgement. - delta = time.monotonic() - send_time - self._change_ack_timeout((7 / 8) * self._t_rx_ack + 0.5 * delta) - - break + # For timing purposes, NAK can be treated as an ACK + delta = time.monotonic() - send_time + self._change_ack_timeout((7 / 8) * self._t_rx_ack + 0.5 * delta) + except asyncio.TimeoutError: + _LOGGER.debug( + "No ACK received in %0.2fs (attempt %d)", + self._t_rx_ack, + attempt + 1, + ) + # If a DATA frame acknowledgement is not received within the current + # timeout value, then t_rx_ack is doubled. + self._change_ack_timeout(2 * self._t_rx_ack) else: - self._enter_failed_state() - raise - finally: + # Whenever an acknowledgement is received, t_rx_ack is set to 7/8 of + # its current value plus 1/2 of the measured time for the + # acknowledgement. + delta = time.monotonic() - send_time + self._change_ack_timeout((7 / 8) * self._t_rx_ack + 0.5 * delta) + + break + + # Any exception will trigger this self._pending_data_frames.pop(frm_num) + else: + raise async def send_data(self, data: bytes) -> None: await self._send_frame( diff --git a/bellows/ezsp/protocol.py b/bellows/ezsp/protocol.py index 0ab21878..2b07dea6 100644 --- a/bellows/ezsp/protocol.py +++ b/bellows/ezsp/protocol.py @@ -17,7 +17,7 @@ from bellows.typing import GatewayType LOGGER = logging.getLogger(__name__) -EZSP_CMD_TIMEOUT = 5 +EZSP_CMD_TIMEOUT = 10 class ProtocolHandler(abc.ABC): From b92e9779dce5b1c205a6300672af8b8922e706d8 Mon Sep 17 00:00:00 2001 From: puddly <32534428+puddly@users.noreply.github.com> Date: Sun, 4 Feb 2024 14:06:20 -0500 Subject: [PATCH 13/42] Fix RStackFrame `to_bytes()` --- bellows/ash.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/bellows/ash.py b/bellows/ash.py index da0689d2..b6cf9688 100644 --- a/bellows/ash.py +++ b/bellows/ash.py @@ -297,7 +297,7 @@ def from_bytes(cls, data: bytes) -> RStackFrame: ) def to_bytes(self) -> bytes: - return self.append_crc(bytes([self.MASK_VALUE]) + self.data) + return self.append_crc(bytes([self.MASK_VALUE, self.version, self.reset_code])) def __str__(self) -> str: return f"RSTACK(ver={self.version}, code={self.reset_code})" From 8646ecbdcbf967a483334a34b6ce8035e15f1512 Mon Sep 17 00:00:00 2001 From: puddly <32534428+puddly@users.noreply.github.com> Date: Sun, 14 Apr 2024 17:38:45 -0400 Subject: [PATCH 14/42] Bump flake8 --- .pre-commit-config.yaml | 5 +++-- pyproject.toml | 8 +++++--- 2 files changed, 8 insertions(+), 5 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index bededc22..4953778d 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -18,11 +18,12 @@ repos: - --quiet - repo: https://github.com/pycqa/flake8 - rev: 6.0.0 + rev: 7.0.0 hooks: - id: flake8 + entry: pflake8 additional_dependencies: - - Flake8-pyproject==1.2.3 + - pyproject-flake8==7.0.0 - repo: https://github.com/PyCQA/isort rev: 5.12.0 diff --git a/pyproject.toml b/pyproject.toml index 674197b3..9c065a20 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -55,15 +55,17 @@ ignore_errors = true asyncio_mode = "auto" [tool.flake8] -exclude = [".venv", ".git", ".tox", "docs", "venv", "bin", "lib", "deps", "build"] +exclude = ".venv,.git,.tox,docs,venv,bin,lib,deps,build" # To work with Black max-line-length = 88 # W503: Line break occurred before a binary operator # E203: Whitespace before ':' # E501: line too long # D202 No blank lines allowed after function docstring -ignore = ["W503", "E203", "E501", "D202"] -per-file-ignores = ["tests/*:F811,F401,F403"] +ignore = "W503,E203,E501,D202" +per-file-ignores = """ + tests/*:F811,F401,F403 +""" [tool.coverage.run] source = ["bellows"] From 471e2cc81e286981cebdae49f8cb2cb587c6bb8c Mon Sep 17 00:00:00 2001 From: puddly <32534428+puddly@users.noreply.github.com> Date: Sun, 14 Apr 2024 17:38:51 -0400 Subject: [PATCH 15/42] Optimize command logging for readability --- bellows/ash.py | 6 +++--- bellows/ezsp/protocol.py | 5 ++--- 2 files changed, 5 insertions(+), 6 deletions(-) diff --git a/bellows/ash.py b/bellows/ash.py index b6cf9688..bd76ffd1 100644 --- a/bellows/ash.py +++ b/bellows/ash.py @@ -393,7 +393,7 @@ def _unstuff_bytes(data: bytes) -> bytes: return out def data_received(self, data: bytes) -> None: - _LOGGER.debug("Received data: %s", data.hex()) + # _LOGGER.debug("Received data: %s", data.hex()) self._buffer.extend(data) while self._buffer: @@ -447,7 +447,7 @@ def _handle_ack(self, frame: DataFrame | AckFrame) -> None: elif fut.done(): return - _LOGGER.debug("Resolving frame %d", ack_num) + # _LOGGER.debug("Resolving frame %d", ack_num) self._pending_data_frames[ack_num].set_result(True) def frame_received(self, frame: AshFrame) -> None: @@ -493,7 +493,7 @@ def _write_frame(self, frame: AshFrame) -> None: _LOGGER.debug("Sending frame %r", frame) data = self._stuff_bytes(frame.to_bytes()) + FLAG - _LOGGER.debug("Sending data %s", data.hex()) + # _LOGGER.debug("Sending data %s", data.hex()) self._transport.write(data) def _change_ack_timeout(self, new_value: float) -> None: diff --git a/bellows/ezsp/protocol.py b/bellows/ezsp/protocol.py index 4c25d806..a5b58653 100644 --- a/bellows/ezsp/protocol.py +++ b/bellows/ezsp/protocol.py @@ -62,7 +62,7 @@ async def add_transient_link_key( async def command(self, name, *args) -> Any: """Serialize command and send it.""" - LOGGER.debug("Send command %s: %s", name, args) + LOGGER.debug("Sending command %s: %s", name, args) data = self._ezsp_frame(name, *args) c = self.COMMANDS[name] future = asyncio.Future() @@ -85,7 +85,6 @@ async def update_policies(self, policy_config: dict) -> None: def __call__(self, data: bytes) -> None: """Handler for received data frame.""" - LOGGER.debug("Received EZSP frame %s", data) orig_data = data sequence, frame_id, data = self._ezsp_frame_rx(data) @@ -108,7 +107,7 @@ def __call__(self, data: bytes) -> None: ) raise - LOGGER.debug("Application frame received %s: %s", frame_name, result) + LOGGER.debug("Received command %s: %s", frame_name, result) if data: LOGGER.debug("Frame contains trailing data: %s", data) From 79a8adc01348a47dc672cb30128e70c95c998cb4 Mon Sep 17 00:00:00 2001 From: puddly <32534428+puddly@users.noreply.github.com> Date: Sun, 14 Apr 2024 17:52:20 -0400 Subject: [PATCH 16/42] Get unit tests passing again --- bellows/ezsp/protocol.py | 14 +- tests/test_ezsp_protocol.py | 22 ++- tests/test_uart.py | 289 ------------------------------------ 3 files changed, 24 insertions(+), 301 deletions(-) diff --git a/bellows/ezsp/protocol.py b/bellows/ezsp/protocol.py index a5b58653..857825ee 100644 --- a/bellows/ezsp/protocol.py +++ b/bellows/ezsp/protocol.py @@ -1,10 +1,12 @@ +from __future__ import annotations + import abc import asyncio import binascii import functools import logging import sys -from typing import Any, Callable, Tuple +from typing import TYPE_CHECKING, Any, Callable if sys.version_info[:2] < (3, 11): from async_timeout import timeout as asyncio_timeout # pragma: no cover @@ -14,7 +16,9 @@ from bellows.config import CONF_EZSP_POLICIES from bellows.exception import InvalidCommandError import bellows.types as t -from bellows.typing import GatewayType + +if TYPE_CHECKING: + from bellows.uart import Gateway LOGGER = logging.getLogger(__name__) EZSP_CMD_TIMEOUT = 6 # Sum of all ASH retry timeouts: 0.4 + 0.8 + 1.6 + 3.2 @@ -26,7 +30,7 @@ class ProtocolHandler(abc.ABC): COMMANDS = {} VERSION = None - def __init__(self, cb_handler: Callable, gateway: GatewayType) -> None: + def __init__(self, cb_handler: Callable, gateway: Gateway) -> None: self._handle_callback = cb_handler self._awaiting = {} self._gw = gateway @@ -37,7 +41,7 @@ def __init__(self, cb_handler: Callable, gateway: GatewayType) -> None: } self.tc_policy = 0 - def _ezsp_frame(self, name: str, *args: Tuple[Any, ...]) -> bytes: + def _ezsp_frame(self, name: str, *args: tuple[Any, ...]) -> bytes: """Serialize the named frame and data.""" c = self.COMMANDS[name] frame = self._ezsp_frame_tx(name) @@ -45,7 +49,7 @@ def _ezsp_frame(self, name: str, *args: Tuple[Any, ...]) -> bytes: return frame + data @abc.abstractmethod - def _ezsp_frame_rx(self, data: bytes) -> Tuple[int, int, bytes]: + def _ezsp_frame_rx(self, data: bytes) -> tuple[int, int, bytes]: """Handler for received data frame.""" @abc.abstractmethod diff --git a/tests/test_ezsp_protocol.py b/tests/test_ezsp_protocol.py index 6b727754..754d8bdd 100644 --- a/tests/test_ezsp_protocol.py +++ b/tests/test_ezsp_protocol.py @@ -6,6 +6,7 @@ from bellows.ezsp import EZSP import bellows.ezsp.v4 import bellows.ezsp.v4.types as t +from bellows.uart import Gateway from .async_mock import ANY, AsyncMock, MagicMock, call, patch @@ -13,17 +14,24 @@ @pytest.fixture def prot_hndl(): """Protocol handler mock.""" - return bellows.ezsp.v4.EZSPv4(MagicMock(), MagicMock()) + app = MagicMock() + gateway = Gateway(app) + gateway._transport = AsyncMock() + + callback_handler = MagicMock() + return bellows.ezsp.v4.EZSPv4(callback_handler, gateway) async def test_command(prot_hndl): - coro = prot_hndl.command("nop") - asyncio.get_running_loop().call_soon( - lambda: prot_hndl._awaiting[prot_hndl._seq - 1][2].set_result(True) - ) + with patch.object(prot_hndl._gw, "send_data") as mock_send_data: + coro = prot_hndl.command("nop") + asyncio.get_running_loop().call_soon( + lambda: prot_hndl._awaiting[prot_hndl._seq - 1][2].set_result(True) + ) + + await coro - await coro - assert prot_hndl._gw.data.call_count == 1 + assert mock_send_data.mock_calls == [call(b"\x00\x00\x05")] def test_receive_reply(prot_hndl): diff --git a/tests/test_uart.py b/tests/test_uart.py index 1dc5574c..e017ab5a 100644 --- a/tests/test_uart.py +++ b/tests/test_uart.py @@ -145,169 +145,11 @@ def gw(): return gw -def test_randomize(gw): - assert gw._randomize(b"\x00\x00\x00\x00\x00") == b"\x42\x21\xa8\x54\x2a" - assert gw._randomize(b"\x42\x21\xa8\x54\x2a") == b"\x00\x00\x00\x00\x00" - - -def test_stuff(gw): - orig = b"\x00\x7E\x01\x7D\x02\x11\x03\x13\x04\x18\x05\x1a\x06" - stuff = ( - b"\x00\x7D\x5E\x01\x7D\x5D\x02\x7D\x31\x03\x7D\x33\x04\x7D\x38\x05\x7D\x3a\x06" - ) - assert gw._stuff(orig) == stuff - - -def test_unstuff(gw): - orig = b"\x00\x7E\x01\x7D\x02\x11\x03\x13\x04\x18\x05\x1a\x06" - stuff = ( - b"\x00\x7D\x5E\x01\x7D\x5D\x02\x7D\x31\x03\x7D\x33\x04\x7D\x38\x05\x7D\x3a\x06" - ) - assert gw._unstuff(stuff) == orig - - -def test_rst(gw): - assert gw._rst_frame() == b"\x1a\xc0\x38\xbc\x7e" - - -def test_data_frame(gw): - expected = b"\x42\x21\xa8\x54\x2a" - assert gw._data_frame(b"\x00\x00\x00\x00\x00", 0, False)[1:-3] == expected - - -def test_cancel_received(gw): - gw.rst_frame_received = MagicMock() - gw.data_received(b"garbage") - gw.data_received(b"\x1a\xc0\x38\xbc\x7e") - assert gw.rst_frame_received.call_count == 1 - assert gw._buffer == b"" - - -def test_substitute_received(gw): - gw.rst_frame_received = MagicMock() - gw.data_received(b"garbage") - gw.data_received(b"\x18\x38\xbc\x7epart") - gw.data_received(b"ial") - gw.rst_frame_received.assert_not_called() - assert gw._buffer == b"partial" - - -def test_partial_data_received(gw): - gw.write = MagicMock() - gw._rec_seq = 5 - gw.data_received(b"\x54\x79\xa1\xb0") - gw.data_received(b"\x50\xf2\x6e\x7e") - assert gw.write.call_count == 1 - assert gw._application.frame_received.call_count == 1 - - -def test_crc_error(gw): - gw.write = MagicMock() - gw.data_received(b"L\xa1\x8e\x03\xcd\x07\xb9Y\xfbG%\xae\xbd~") - assert gw.write.call_count == 1 - assert gw._application.frame_received.call_count == 0 - - -def test_crc_error_and_valid_frame(gw): - gw.write = MagicMock() - gw._rec_seq = 5 - gw.data_received( - b"L\xa1\x8e\x03\xcd\x07\xb9Y\xfbG%\xae\xbd~\x54\x79\xa1\xb0\x50\xf2\x6e\x7e" - ) - assert gw.write.call_count == 2 - assert gw._application.frame_received.call_count == 1 - - -def test_data_frame_received(gw): - gw.write = MagicMock() - gw._rec_seq = 5 - gw.data_received(b"\x54\x79\xa1\xb0\x50\xf2\x6e\x7e") - assert gw.write.call_count == 1 - assert gw._application.frame_received.call_count == 1 - - -def test_ack_frame_received(gw): - gw.data_received(b"\x86\x10\xbe\x7e") - - -def test_nak_frame_received(gw): - gw.frame_received(bytes([0b10100000])) - - -def test_rst_frame_received(gw): - gw.data_received(b"garbage\x1a\xc0\x38\xbc\x7e") - - -def test_rstack_frame_received(gw): - gw._reset_future = MagicMock() - gw._reset_future.done = MagicMock(return_value=False) - gw.data_received(b"\xc1\x02\x0b\nR\x7e") - assert gw._reset_future.done.call_count == 1 - assert gw._reset_future.set_result.call_count == 1 - - -def test_wrong_rstack_frame_received(gw): - gw._reset_future = MagicMock() - gw.data_received(b"\xc1\x02\x01\xab\x18\x7e") - assert gw._reset_future.set_result.call_count == 0 - - -def test_error_rstack_frame_received(gw): - gw._reset_future = MagicMock() - gw.data_received(b"\xc1\x02\x81\x3a\x90\x7e") - assert gw._reset_future.set_result.call_count == 0 - - -def test_rstack_frame_received_nofut(gw): - gw.data_received(b"\xc1\x02\x0b\nR\x7e") - - -def test_rstack_frame_received_out_of_order(gw): - gw._reset_future = MagicMock() - gw._reset_future.done = MagicMock(return_value=True) - gw.data_received(b"\xc1\x02\x0b\nR\x7e") - assert gw._reset_future.done.call_count == 1 - assert gw._reset_future.set_result.call_count == 0 - - -def test_error_frame_received(gw): - from bellows.types import NcpResetCode - - gw.frame_received(b"\xc2\x02\x03\xd2\x0a\x7e") - efs = gw._application.enter_failed_state - assert efs.call_count == 1 - assert efs.call_args[0][0] == NcpResetCode.RESET_WATCHDOG - - -def test_unknown_frame_received(gw): - gw.frame_received(bytes([0b11011111])) - - def test_close(gw): gw.close() assert gw._transport.close.call_count == 1 -async def test_reset(gw): - gw._sendq.put_nowait(sentinel.queue_item) - fut = asyncio.Future() - gw._pending = (sentinel.seq, fut) - gw._transport.write.side_effect = lambda *args: gw._reset_future.set_result( - sentinel.reset_result - ) - reset_result = await gw.reset() - - assert gw._transport.write.call_count == 1 - assert gw._send_seq == 0 - assert gw._rec_seq == 0 - assert len(gw._buffer) == 0 - assert gw._sendq.empty() - assert fut.done() - assert gw._pending == (-1, None) - - assert reset_result is sentinel.reset_result - - async def test_reset_timeout(gw, monkeypatch): monkeypatch.setattr(uart, "RESET_TIMEOUT", 0.1) with pytest.raises(asyncio.TimeoutError): @@ -323,29 +165,6 @@ async def test_reset_old(gw): gw._transport.write.assert_not_called() -async def test_data(gw): - loop = asyncio.get_running_loop() - write_call_count = 0 - - def mockwrite(data): - nonlocal loop, write_call_count - if data == b"\x10 @\xda}^Z~": - loop.call_soon(gw._handle_nak, gw._pending[0]) - else: - loop.call_soon(gw._handle_ack, (gw._pending[0] + 1) % 8) - write_call_count += 1 - - gw.write = mockwrite - - gw.data(b"foo") - gw.data(b"bar") - gw.data(b"baz") - gw._sendq.put_nowait(gw.Terminator) - - await gw._send_loop() - assert write_call_count == 4 - - def test_connection_lost_exc(gw): gw.connection_lost(sentinel.exception) @@ -419,111 +238,3 @@ async def test_wait_for_startup_reset_failure(gw): await asyncio.wait_for(gw.wait_for_startup_reset(), 0.01) assert gw._startup_reset_future is None - - -ASH_ACK_MIN = 0.01 - - -@patch("bellows.uart.ASH_RX_ACK_MIN", new=ASH_ACK_MIN * 2**0) -@patch("bellows.uart.ASH_RX_ACK_INIT", new=ASH_ACK_MIN * 2**2) -@patch("bellows.uart.ASH_RX_ACK_MAX", new=ASH_ACK_MIN * 2**3) -async def test_retry_success(): - app = MagicMock() - transport = MagicMock() - connected_future = asyncio.get_running_loop().create_future() - - gw = uart.Gateway(app, connected_future) - gw.connection_made(transport) - - old_timeout = gw._ack_timeout - gw.data(b"TX 1") - await asyncio.sleep(0) - - # Wait more than one ACK cycle to reply - assert len(transport.write.mock_calls) == 1 - await asyncio.sleep(ASH_ACK_MIN * 5) - - # The gateway has retried once by now - assert len(transport.write.mock_calls) == 2 - - gw.frame_received( - # ash.DataFrame(frm_num=0, re_tx=0, ack_num=1, ezsp_frame=b"RX 1").to_bytes() - bytes.fromhex("01107988654851") - ) - - # An ACK has been received and the pending frame has been acknowledged - await asyncio.sleep(0) - assert gw._pending == (-1, None) - - assert gw._ack_timeout > old_timeout - - gw.close() - - -@patch("bellows.uart.ASH_RX_ACK_MIN", new=ASH_ACK_MIN * 2**0) -@patch("bellows.uart.ASH_RX_ACK_INIT", new=ASH_ACK_MIN * 2**2) -@patch("bellows.uart.ASH_RX_ACK_MAX", new=ASH_ACK_MIN * 2**3) -async def test_retry_nak_then_success(): - app = MagicMock() - transport = MagicMock() - connected_future = asyncio.get_running_loop().create_future() - - gw = uart.Gateway(app, connected_future) - gw.connection_made(transport) - - old_timeout = gw._ack_timeout - gw.data(b"TX 1") - await asyncio.sleep(0) - assert len(transport.write.mock_calls) == 1 - - # Wait less than one ACK cycle so that we can NAK the frame during the RX window - await asyncio.sleep(ASH_ACK_MIN) - # NAK the frame - gw.frame_received( - # ash.NakFrame(res=0, ncp_ready=0, ack_num=0).to_bytes() - bytes.fromhex("a0541a") - ) - - # The gateway has retried once more, instantly - await asyncio.sleep(0) - assert len(transport.write.mock_calls) == 2 - - # Send a proper ACK - gw.frame_received( - # ash.AckFrame(res=0, ncp_ready=0, ack_num=1).to_bytes() - bytes.fromhex("816059") - ) - await asyncio.sleep(0) - assert gw._pending == (-1, None) - assert gw._ack_timeout < old_timeout - - gw.close() - - -@patch("bellows.uart.ASH_RX_ACK_MIN", new=ASH_ACK_MIN * 2**0) -@patch("bellows.uart.ASH_RX_ACK_INIT", new=ASH_ACK_MIN * 2**2) -@patch("bellows.uart.ASH_RX_ACK_MAX", new=ASH_ACK_MIN * 2**3) -async def test_retry_failure(): - app = MagicMock() - transport = MagicMock() - connected_future = asyncio.get_running_loop().create_future() - - gw = uart.Gateway(app, connected_future) - gw.connection_made(transport) - - old_timeout = gw._ack_timeout - gw.data(b"TX 1") - await asyncio.sleep(0) - - # Wait more than one ACK cycle to reply - assert len(transport.write.mock_calls) == 1 - await asyncio.sleep(ASH_ACK_MIN * 40) - - # The gateway has exhausted retries - assert len(transport.write.mock_calls) == 5 - - assert gw._pending == (-1, None) - assert gw._ack_timeout > old_timeout - assert gw._ack_timeout == ASH_ACK_MIN * 2**3 # max timeout - - gw.close() From afa4d231f3399502397b70c53ef7ed698cda7c1b Mon Sep 17 00:00:00 2001 From: puddly <32534428+puddly@users.noreply.github.com> Date: Sun, 14 Apr 2024 19:32:43 -0400 Subject: [PATCH 17/42] Revert "[TEST] Shut down the event loop with a separate exception" This reverts commit d155888241f4f7a9d93e97b9058a0d5f95806f31. --- bellows/thread.py | 12 +----------- 1 file changed, 1 insertion(+), 11 deletions(-) diff --git a/bellows/thread.py b/bellows/thread.py index 0f92958b..6d8c1309 100644 --- a/bellows/thread.py +++ b/bellows/thread.py @@ -4,11 +4,6 @@ import logging import sys - -class EventLoopShuttingDown(RuntimeError): - pass - - LOGGER = logging.getLogger(__name__) @@ -64,16 +59,11 @@ def force_stop(self): if self.loop is None: return - LOGGER.debug("Shutting down thread") - def cancel_tasks_and_stop_loop(): tasks = asyncio.all_tasks(loop=self.loop) for task in tasks: - coro = task.get_coro() - - if coro is not None: - self.loop.call_soon_threadsafe(coro.throw, EventLoopShuttingDown()) + self.loop.call_soon_threadsafe(task.cancel) gather = asyncio.gather(*tasks, return_exceptions=True) gather.add_done_callback( From 4c364817b0511516df3499410eaec50967e616ae Mon Sep 17 00:00:00 2001 From: puddly <32534428+puddly@users.noreply.github.com> Date: Mon, 15 Apr 2024 09:32:12 -0400 Subject: [PATCH 18/42] Fix startup reset unit test --- bellows/uart.py | 2 +- tests/test_uart.py | 3 ++- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/bellows/uart.py b/bellows/uart.py index 13d4b029..4a0bc1e6 100644 --- a/bellows/uart.py +++ b/bellows/uart.py @@ -60,7 +60,7 @@ def data_received(self, data): """Callback when there is data received from the uart""" self._application.frame_received(data) - def reset_received(self, code): + def reset_received(self, code: t.NcpResetCode) -> None: """Reset acknowledgement frame receive handler""" # not a reset we've requested. Signal application reset if code is not t.NcpResetCode.RESET_SOFTWARE: diff --git a/tests/test_uart.py b/tests/test_uart.py index e017ab5a..16cb8930 100644 --- a/tests/test_uart.py +++ b/tests/test_uart.py @@ -6,6 +6,7 @@ import zigpy.config as conf from bellows import uart +import bellows.types as t from .async_mock import AsyncMock, MagicMock, patch, sentinel @@ -224,7 +225,7 @@ def on_transport_close(): async def test_wait_for_startup_reset(gw): loop = asyncio.get_running_loop() - loop.call_later(0.01, gw.data_received, b"\xc1\x02\x0b\nR\x7e") + loop.call_later(0.01, gw.reset_received, t.NcpResetCode.RESET_SOFTWARE) assert gw._startup_reset_future is None await gw.wait_for_startup_reset() From ce3da98fd4b5d99b9339f7a83e8e9951c01a3909 Mon Sep 17 00:00:00 2001 From: puddly <32534428+puddly@users.noreply.github.com> Date: Mon, 15 Apr 2024 09:43:23 -0400 Subject: [PATCH 19/42] Use strict `zip` --- bellows/ash.py | 5 +++-- bellows/types/__init__.py | 2 +- bellows/zigbee/application.py | 4 +++- tests/test_application.py | 2 +- 4 files changed, 8 insertions(+), 5 deletions(-) diff --git a/bellows/ash.py b/bellows/ash.py index bd76ffd1..e43d0f6a 100644 --- a/bellows/ash.py +++ b/bellows/ash.py @@ -147,8 +147,9 @@ class DataFrame(AshFrame): @staticmethod def _randomize(data: bytes) -> bytes: - assert len(data) <= len(PSEUDO_RANDOM_DATA_SEQUENCE) - return bytes([a ^ b for a, b in zip(data, PSEUDO_RANDOM_DATA_SEQUENCE)]) + return bytes( + [a ^ b for a, b in zip(data, PSEUDO_RANDOM_DATA_SEQUENCE, strict=True)] + ) @classmethod def from_bytes(cls, data: bytes) -> DataFrame: diff --git a/bellows/types/__init__.py b/bellows/types/__init__.py index 4a25b948..e8a9e552 100644 --- a/bellows/types/__init__.py +++ b/bellows/types/__init__.py @@ -12,4 +12,4 @@ def deserialize(data, schema): def serialize(data, schema): - return b"".join(t(v).serialize() for t, v in zip(schema, data)) + return b"".join(t(v).serialize() for t, v in zip(schema, data, strict=True)) diff --git a/bellows/zigbee/application.py b/bellows/zigbee/application.py index 5b757bf0..2e9b1302 100644 --- a/bellows/zigbee/application.py +++ b/bellows/zigbee/application.py @@ -1000,7 +1000,9 @@ async def _watchdog_feed(self): else: (res,) = await self._ezsp.readAndClearCounters() - for cnt_type, value in zip(self._ezsp.types.EmberCounterType, res): + for cnt_type, value in zip( + self._ezsp.types.EmberCounterType, res, strict=True + ): counters[cnt_type.name[8:]].update(value) if remainder == 0: diff --git a/tests/test_application.py b/tests/test_application.py index 7434e230..acb48a30 100644 --- a/tests/test_application.py +++ b/tests/test_application.py @@ -1592,7 +1592,7 @@ async def test_startup_new_coordinator_no_groups_joined(app, ieee): ) async def test_energy_scanning(app, scan_results): app._ezsp.startScan = AsyncMock( - return_value=list(zip(range(11, 26 + 1), scan_results)) + return_value=list(zip(range(11, 26 + 1), scan_results, strict=True)) ) results = await app.energy_scan( From c11000a18af04a5f088a0d58ce4f99fb5650b649 Mon Sep 17 00:00:00 2001 From: puddly <32534428+puddly@users.noreply.github.com> Date: Mon, 15 Apr 2024 10:04:18 -0400 Subject: [PATCH 20/42] Use better parsing errors --- bellows/ash.py | 18 +++++++++++++++--- 1 file changed, 15 insertions(+), 3 deletions(-) diff --git a/bellows/ash.py b/bellows/ash.py index e43d0f6a..f34b300b 100644 --- a/bellows/ash.py +++ b/bellows/ash.py @@ -81,6 +81,18 @@ class NCPState(enum.Enum): FAILED = "failed" +class ParsingError(Exception): + pass + + +class InvalidChecksum(ParsingError): + pass + + +class FrameTooShort(ParsingError): + pass + + class AshException(Exception): pass @@ -121,12 +133,12 @@ def to_bytes(self) -> bytes: @classmethod def _unwrap(cls, data: bytes) -> tuple[int, bytes]: if len(data) < 3: - raise ValueError(f"Frame is too short: {data!r}") + raise FrameTooShort(f"Frame is too short: {data!r}") computed_crc = binascii.crc_hqx(data[:-2], 0xFFFF).to_bytes(2, "big") if computed_crc != data[-2:]: - raise ValueError(f"Invalid CRC bytes in frame {data!r}") + raise InvalidChecksum(f"Invalid CRC bytes in frame {data!r}") return data[0], data[1:-2] @@ -427,7 +439,7 @@ def data_received(self, data: bytes) -> None: try: frame = self._extract_frame(data) - except ValueError: + except Exception: _LOGGER.warning( "Failed to parse frame %r", frame_bytes, exc_info=True ) From 2038c538a69b21aac4076d0626816cc33b6924cf Mon Sep 17 00:00:00 2001 From: puddly <32534428+puddly@users.noreply.github.com> Date: Mon, 15 Apr 2024 10:05:27 -0400 Subject: [PATCH 21/42] Revert "Use strict `zip`" This reverts commit ce3da98fd4b5d99b9339f7a83e8e9951c01a3909. --- bellows/ash.py | 5 ++--- bellows/types/__init__.py | 2 +- bellows/zigbee/application.py | 4 +--- tests/test_application.py | 2 +- 4 files changed, 5 insertions(+), 8 deletions(-) diff --git a/bellows/ash.py b/bellows/ash.py index f34b300b..d675fc15 100644 --- a/bellows/ash.py +++ b/bellows/ash.py @@ -159,9 +159,8 @@ class DataFrame(AshFrame): @staticmethod def _randomize(data: bytes) -> bytes: - return bytes( - [a ^ b for a, b in zip(data, PSEUDO_RANDOM_DATA_SEQUENCE, strict=True)] - ) + assert len(data) <= len(PSEUDO_RANDOM_DATA_SEQUENCE) + return bytes([a ^ b for a, b in zip(data, PSEUDO_RANDOM_DATA_SEQUENCE)]) @classmethod def from_bytes(cls, data: bytes) -> DataFrame: diff --git a/bellows/types/__init__.py b/bellows/types/__init__.py index e8a9e552..4a25b948 100644 --- a/bellows/types/__init__.py +++ b/bellows/types/__init__.py @@ -12,4 +12,4 @@ def deserialize(data, schema): def serialize(data, schema): - return b"".join(t(v).serialize() for t, v in zip(schema, data, strict=True)) + return b"".join(t(v).serialize() for t, v in zip(schema, data)) diff --git a/bellows/zigbee/application.py b/bellows/zigbee/application.py index 2e9b1302..5b757bf0 100644 --- a/bellows/zigbee/application.py +++ b/bellows/zigbee/application.py @@ -1000,9 +1000,7 @@ async def _watchdog_feed(self): else: (res,) = await self._ezsp.readAndClearCounters() - for cnt_type, value in zip( - self._ezsp.types.EmberCounterType, res, strict=True - ): + for cnt_type, value in zip(self._ezsp.types.EmberCounterType, res): counters[cnt_type.name[8:]].update(value) if remainder == 0: diff --git a/tests/test_application.py b/tests/test_application.py index acb48a30..7434e230 100644 --- a/tests/test_application.py +++ b/tests/test_application.py @@ -1592,7 +1592,7 @@ async def test_startup_new_coordinator_no_groups_joined(app, ieee): ) async def test_energy_scanning(app, scan_results): app._ezsp.startScan = AsyncMock( - return_value=list(zip(range(11, 26 + 1), scan_results, strict=True)) + return_value=list(zip(range(11, 26 + 1), scan_results)) ) results = await app.energy_scan( From 863d869d29a9247a4e5ec38746cb3a67b3a59a64 Mon Sep 17 00:00:00 2001 From: puddly <32534428+puddly@users.noreply.github.com> Date: Mon, 15 Apr 2024 16:58:05 -0400 Subject: [PATCH 22/42] Implement both sides of ASH to make testing easier --- bellows/ash.py | 293 ++++++++++++++++++++++++++-------------------- pyproject.toml | 10 ++ tests/test_ash.py | 209 ++++++++++++++++++++++++++++++--- 3 files changed, 368 insertions(+), 144 deletions(-) diff --git a/bellows/ash.py b/bellows/ash.py index d675fc15..c424919c 100644 --- a/bellows/ash.py +++ b/bellows/ash.py @@ -76,11 +76,16 @@ def generate_random_sequence(length: int) -> bytes: PSEUDO_RANDOM_DATA_SEQUENCE = generate_random_sequence(256) -class NCPState(enum.Enum): +class NcpState(enum.Enum): CONNECTED = "connected" FAILED = "failed" +class AshRole(enum.Enum): + HOST = "host" + NCP = "ncp" + + class ParsingError(Exception): pass @@ -98,11 +103,19 @@ class AshException(Exception): class NotAcked(AshException): - def __init__(self, frame: NakFrame): + def __init__(self, frame: NakFrame) -> None: self.frame = frame def __repr__(self) -> str: - return f"<{self.__class__.__name__}(" f"frame={self.frame}" f")>" + return f"<{self.__class__.__name__}(frame={self.frame})>" + + +class NcpFailure(AshException): + def __init__(self, code: t.NcpResetCode) -> None: + self.code = code + + def __repr__(self) -> str: + return f"<{self.__class__.__name__}(code={self.code})>" class OutOfSequenceError(AshException): @@ -186,9 +199,6 @@ def to_bytes(self, *, randomize: bool = True) -> bytes: + self._randomize(self.ezsp_frame) ) - def __str__(self) -> str: - return f"DATA(num={self.frm_num}, ack={self.ack_num}, re_tx={self.re_tx}) = {self.ezsp_frame.hex()}" - @dataclasses.dataclass(frozen=True) class AckFrame(AshFrame): @@ -221,9 +231,6 @@ def to_bytes(self) -> bytes: ) ) - def __str__(self) -> str: - return f"ACK(ack={self.ack_num}, ready={'+' if self.ncp_ready == 0 else '-'!r})" - @dataclasses.dataclass(frozen=True) class NakFrame(AshFrame): @@ -256,9 +263,6 @@ def to_bytes(self) -> bytes: ) ) - def __str__(self) -> str: - return f"NAK(ack={self.ack_num}, ready={'+' if self.ncp_ready == 0 else '-'!r})" - @dataclasses.dataclass(frozen=True) class RstFrame(AshFrame): @@ -277,9 +281,6 @@ def from_bytes(cls, data: bytes) -> RstFrame: def to_bytes(self) -> bytes: return self.append_crc(bytes([self.MASK_VALUE])) - def __str__(self) -> str: - return "RST()" - @dataclasses.dataclass(frozen=True) class RStackFrame(AshFrame): @@ -311,31 +312,36 @@ def from_bytes(cls, data: bytes) -> RStackFrame: def to_bytes(self) -> bytes: return self.append_crc(bytes([self.MASK_VALUE, self.version, self.reset_code])) - def __str__(self) -> str: - return f"RSTACK(ver={self.version}, code={self.reset_code})" - @dataclasses.dataclass(frozen=True) -class ErrorFrame(RStackFrame): +class ErrorFrame(AshFrame): + MASK = 0b11111111 MASK_VALUE = 0b11000010 - def __str__(self) -> str: - return f"ERROR(ver={self.version}, code={self.reset_code})" + version: t.uint8_t + reset_code: t.NcpResetCode + + # We do not want to inherit from `RStackFrame` + from_bytes = classmethod(RStackFrame.from_bytes.__func__) + to_bytes = RStackFrame.to_bytes class AshProtocol(asyncio.Protocol): - def __init__(self, ezsp_protocol) -> None: + def __init__(self, ezsp_protocol, *, role: AshRole = AshRole.HOST) -> None: self._ezsp_protocol = ezsp_protocol self._transport = None self._buffer = bytearray() self._discarding_until_flag: bool = False self._pending_data_frames: dict[int, asyncio.Future] = {} - self._ncp_state = NCPState.CONNECTED self._send_data_frame_semaphore = asyncio.Semaphore(TX_K) self._tx_seq: int = 0 self._rx_seq: int = 0 self._t_rx_ack = T_RX_ACK_INIT + self._role: AshRole = role + self._ncp_reset_code: t.NcpResetCode | None = None + self._ncp_state: NcpState = NcpState.CONNECTED + def connection_made(self, transport): self._transport = transport self._ezsp_protocol.connection_made(self) @@ -405,7 +411,7 @@ def _unstuff_bytes(data: bytes) -> bytes: return out def data_received(self, data: bytes) -> None: - # _LOGGER.debug("Received data: %s", data.hex()) + _LOGGER.debug("Received data %s", data.hex()) self._buffer.extend(data) while self._buffer: @@ -465,10 +471,22 @@ def _handle_ack(self, frame: DataFrame | AckFrame) -> None: def frame_received(self, frame: AshFrame) -> None: _LOGGER.debug("Received frame %r", frame) + if ( + self._ncp_reset_code is not None + and self._role == AshRole.NCP + and not isinstance(frame, RstFrame) + ): + _LOGGER.debug( + "NCP in failure state %r, ignoring frame: %r", + self._ncp_reset_code, + frame, + ) + self._write_frame(ErrorFrame(version=2, reset_code=self._ncp_reset_code)) + return + if isinstance(frame, DataFrame): # The Host may not piggyback acknowledgments and should promptly send an ACK # frame when it receives a DATA frame. - if frame.frm_num == self._rx_seq: self._handle_ack(frame) self._rx_seq = (frame.frm_num + 1) % 8 @@ -480,11 +498,12 @@ def frame_received(self, frame: AshFrame) -> None: # sequence self._write_frame(AckFrame(res=0, ncp_ready=0, ack_num=self._rx_seq)) else: - _LOGGER.warning("Received an out of sequence frame: %r", frame) + _LOGGER.debug("Received an out of sequence frame: %r", frame) self._write_frame(NakFrame(res=0, ncp_ready=0, ack_num=self._rx_seq)) - elif isinstance(frame, ErrorFrame): - self._ezsp_protocol.error_received(frame.reset_code) elif isinstance(frame, RStackFrame): + self._ncp_reset_code = None + self._ncp_state = NcpState.CONNECTED + self._tx_seq = 0 self._rx_seq = 0 self._change_ack_timeout(T_RX_ACK_INIT) @@ -500,12 +519,36 @@ def frame_received(self, frame: AshFrame) -> None: and not fut.done() ): fut.set_exception(error) + elif isinstance(frame, RstFrame): + self._ncp_reset_code = None + self._ncp_state = NcpState.CONNECTED + + if self._role == AshRole.NCP: + self._tx_seq = 0 + self._rx_seq = 0 + self._change_ack_timeout(T_RX_ACK_INIT) + + self._enter_ncp_error_state(None) + self._write_frame( + RStackFrame(version=2, reset_code=t.NcpResetCode.RESET_SOFTWARE) + ) + elif isinstance(frame, ErrorFrame) and self._role == AshRole.HOST: + _LOGGER.debug("NCP has entered failed state: %s", frame.reset_code) + self._ncp_reset_code = frame.reset_code + self._ncp_state = NcpState.FAILED + + # Cancel all pending requests + exc = NcpFailure(code=self._ncp_reset_code) + + for fut in self._pending_data_frames.values(): + if not fut.done(): + fut.set_exception(exc) def _write_frame(self, frame: AshFrame) -> None: - _LOGGER.debug("Sending frame %r", frame) + _LOGGER.debug("Sending frame %r", frame) data = self._stuff_bytes(frame.to_bytes()) + FLAG - # _LOGGER.debug("Sending data %s", data.hex()) + _LOGGER.debug("Sending data %s", data.hex()) self._transport.write(data) def _change_ack_timeout(self, new_value: float) -> None: @@ -518,6 +561,20 @@ def _change_ack_timeout(self, new_value: float) -> None: self._t_rx_ack = new_value + def _enter_ncp_error_state(self, code: t.NcpResetCode | None) -> None: + self._ncp_reset_code = code + + if code is None: + self._ncp_state = NcpState.CONNECTED + else: + self._ncp_state = NcpState.FAILED + + _LOGGER.debug("Changing connectivity state: %r", self._ncp_state) + _LOGGER.debug("Changing reset code: %r", self._ncp_reset_code) + + if self._ncp_state == NcpState.FAILED: + self._write_frame(ErrorFrame(version=2, reset_code=self._ncp_reset_code)) + async def _send_frame(self, frame: AshFrame) -> None: if not isinstance(frame, DataFrame): # Non-DATA frames can be sent immediately and do not require an ACK @@ -528,56 +585,85 @@ async def _send_frame(self, frame: AshFrame) -> None: _LOGGER.debug("Semaphore is locked, waiting") async with self._send_data_frame_semaphore: - frm_num = self._tx_seq - self._tx_seq = (self._tx_seq + 1) % 8 - - for attempt in range(ACK_TIMEOUTS): - # Use a fresh ACK number on every retry - frame = frame.replace( - frm_num=frm_num, - re_tx=(attempt > 0), - ack_num=self._rx_seq, - ) - - send_time = time.monotonic() - - ack_future = asyncio.get_running_loop().create_future() - self._pending_data_frames[frm_num] = ack_future - self._write_frame(frame) - - try: - async with asyncio_timeout(self._t_rx_ack): - await ack_future - except NotAcked: - _LOGGER.debug( - "NCP responded with NAK. Retrying (attempt %d)", attempt + 1 - ) - - # For timing purposes, NAK can be treated as an ACK - delta = time.monotonic() - send_time - self._change_ack_timeout((7 / 8) * self._t_rx_ack + 0.5 * delta) - except asyncio.TimeoutError: - _LOGGER.debug( - "No ACK received in %0.2fs (attempt %d)", - self._t_rx_ack, - attempt + 1, + frm_num = None + + try: + for attempt in range(ACK_TIMEOUTS): + if ( + self._role == AshRole.HOST + and self._ncp_state == NcpState.FAILED + ): + _LOGGER.debug( + "NCP is in a failed state, not re-sending: %r", frame + ) + raise NcpFailure( + t.NcpResetCode.ERROR_EXCEEDED_MAXIMUM_ACK_TIMEOUT_COUNT + ) + + if frm_num is None: + frm_num = self._tx_seq + self._tx_seq = (self._tx_seq + 1) % 8 + + # Use a fresh ACK number on every retry + frame = frame.replace( + frm_num=frm_num, + re_tx=(attempt > 0), + ack_num=self._rx_seq, ) - # If a DATA frame acknowledgement is not received within the current - # timeout value, then t_rx_ack is doubled. - self._change_ack_timeout(2 * self._t_rx_ack) - else: - # Whenever an acknowledgement is received, t_rx_ack is set to 7/8 of - # its current value plus 1/2 of the measured time for the - # acknowledgement. - delta = time.monotonic() - send_time - self._change_ack_timeout((7 / 8) * self._t_rx_ack + 0.5 * delta) - break - - # Any exception will trigger this + send_time = time.monotonic() + + ack_future = asyncio.get_running_loop().create_future() + self._pending_data_frames[frm_num] = ack_future + self._write_frame(frame) + + try: + async with asyncio_timeout(self._t_rx_ack): + await ack_future + except NotAcked: + _LOGGER.debug( + "NCP responded with NAK. Retrying (attempt %d)", attempt + 1 + ) + + # For timing purposes, NAK can be treated as an ACK + delta = time.monotonic() - send_time + self._change_ack_timeout((7 / 8) * self._t_rx_ack + 0.5 * delta) + + if attempt >= ACK_TIMEOUTS - 1: + raise + except NcpFailure: + _LOGGER.debug( + "NCP has entered into a failed state, not retrying" + ) + raise + except asyncio.TimeoutError: + _LOGGER.debug( + "No ACK received in %0.2fs (attempt %d)", + self._t_rx_ack, + attempt + 1, + ) + # If a DATA frame acknowledgement is not received within the + # current timeout value, then t_rx_ack is doubled. + self._change_ack_timeout(2 * self._t_rx_ack) + + if attempt >= ACK_TIMEOUTS - 1: + # Only a timeout is enough to enter an error state + if self._role == AshRole.NCP: + self._enter_ncp_error_state( + t.NcpResetCode.ERROR_EXCEEDED_MAXIMUM_ACK_TIMEOUT_COUNT + ) + + raise + else: + # Whenever an acknowledgement is received, t_rx_ack is set to + # 7/8 of its current value plus 1/2 of the measured time for the + # acknowledgement. + delta = time.monotonic() - send_time + self._change_ack_timeout((7 / 8) * self._t_rx_ack + 0.5 * delta) + + break + finally: self._pending_data_frames.pop(frm_num) - else: - raise async def send_data(self, data: bytes) -> None: await self._send_frame( @@ -587,56 +673,3 @@ async def send_data(self, data: bytes) -> None: def send_reset(self) -> None: self._write_frame(RstFrame()) - - -def main(): - import ast - import pathlib - import sys - import unittest.mock - - import coloredlogs - - coloredlogs.install(level="DEBUG") - - class CapturingAshProtocol(AshProtocol): - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self._parsed_frames = [] - - def frame_received(self, frame: AshFrame) -> None: - self._parsed_frames.append(frame) - - with pathlib.Path(sys.argv[1]).open("r") as f: - for line in f: - if "bellows.uart" not in line: - continue - - if "Sending: " in line: - direction = " --->" - elif ( - "Data frame:" in line or "ACK frame: " in line or "NAK frame: " in line - ): - direction = "<--- " - else: - continue - - data = bytes.fromhex(ast.literal_eval(line.split(": b", 1)[1])) - - # Data frames are logged already unstuffed - if direction == "<--- ": - data = AshProtocol._stuff_bytes(data[:-1]) + data[-1:] - - protocol = CapturingAshProtocol(ezsp_protocol=unittest.mock.Mock()) - protocol.data_received(data) - - if len(protocol._parsed_frames) != 1: - raise ValueError(f"Failed to parse frames: {protocol._parsed_frames}") - - frame = protocol._parsed_frames[0] - - _LOGGER.info("%s: %s", direction, frame) - - -if __name__ == "__main__": - main() diff --git a/pyproject.toml b/pyproject.toml index 9c065a20..9ecfb1f2 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -73,3 +73,13 @@ omit = [ "bellows/cli/*.py", "bellows/typing.py", ] + + +[tool.coverage.report] +exclude_also = [ + "raise AssertionError", + "raise NotImplementedError", + "if TYPE_CHECKING", + "if typing.TYPE_CHECKING", + "@(abc\\.)?abstractmethod", +] \ No newline at end of file diff --git a/tests/test_ash.py b/tests/test_ash.py index 9a22f894..0d4a2e50 100644 --- a/tests/test_ash.py +++ b/tests/test_ash.py @@ -1,21 +1,202 @@ -from bellows.ash import PSEUDO_RANDOM_DATA_SEQUENCE, AshProtocol +import asyncio +from unittest.mock import MagicMock, call, patch + +import pytest + +from bellows import ash +import bellows.types as t def test_stuffing(): - assert AshProtocol._stuff_bytes(b"\x7E") == b"\x7D\x5E" - assert AshProtocol._stuff_bytes(b"\x11") == b"\x7D\x31" - assert AshProtocol._stuff_bytes(b"\x13") == b"\x7D\x33" - assert AshProtocol._stuff_bytes(b"\x18") == b"\x7D\x38" - assert AshProtocol._stuff_bytes(b"\x1A") == b"\x7D\x3A" - assert AshProtocol._stuff_bytes(b"\x7D") == b"\x7D\x5D" + assert ash.AshProtocol._stuff_bytes(b"\x7E") == b"\x7D\x5E" + assert ash.AshProtocol._stuff_bytes(b"\x11") == b"\x7D\x31" + assert ash.AshProtocol._stuff_bytes(b"\x13") == b"\x7D\x33" + assert ash.AshProtocol._stuff_bytes(b"\x18") == b"\x7D\x38" + assert ash.AshProtocol._stuff_bytes(b"\x1A") == b"\x7D\x3A" + assert ash.AshProtocol._stuff_bytes(b"\x7D") == b"\x7D\x5D" - assert AshProtocol._unstuff_bytes(b"\x7D\x5E") == b"\x7E" - assert AshProtocol._unstuff_bytes(b"\x7D\x31") == b"\x11" - assert AshProtocol._unstuff_bytes(b"\x7D\x33") == b"\x13" - assert AshProtocol._unstuff_bytes(b"\x7D\x38") == b"\x18" - assert AshProtocol._unstuff_bytes(b"\x7D\x3A") == b"\x1A" - assert AshProtocol._unstuff_bytes(b"\x7D\x5D") == b"\x7D" + assert ash.AshProtocol._unstuff_bytes(b"\x7D\x5E") == b"\x7E" + assert ash.AshProtocol._unstuff_bytes(b"\x7D\x31") == b"\x11" + assert ash.AshProtocol._unstuff_bytes(b"\x7D\x33") == b"\x13" + assert ash.AshProtocol._unstuff_bytes(b"\x7D\x38") == b"\x18" + assert ash.AshProtocol._unstuff_bytes(b"\x7D\x3A") == b"\x1A" + assert ash.AshProtocol._unstuff_bytes(b"\x7D\x5D") == b"\x7D" def test_pseudo_random_data_sequence(): - assert PSEUDO_RANDOM_DATA_SEQUENCE.startswith(b"\x42\x21\xA8\x54\x2A") + assert ash.PSEUDO_RANDOM_DATA_SEQUENCE.startswith(b"\x42\x21\xA8\x54\x2A") + + +def test_rst_frame(): + assert ash.RstFrame() == ash.RstFrame() + assert ash.RstFrame().to_bytes() == bytes.fromhex("c038bc") + assert ash.RstFrame.from_bytes(bytes.fromhex("c038bc")) == ash.RstFrame() + assert str(ash.RstFrame()) == "RstFrame()" + + +async def test_ash_protocol_startup(): + """Simple EZSP startup: reset, version(4), then version(8).""" + + loop = asyncio.get_running_loop() + + ezsp = MagicMock() + transport = MagicMock() + + protocol = ash.AshProtocol(ezsp) + protocol._write_frame = MagicMock(wraps=protocol._write_frame) + protocol.connection_made(transport) + + assert ezsp.connection_made.mock_calls == [call(protocol)] + + assert protocol._rx_seq == 0 + assert protocol._tx_seq == 0 + + # ASH reset + protocol.send_reset() + loop.call_later( + 0, + protocol.frame_received, + ash.RStackFrame(version=2, reset_code=t.NcpResetCode.RESET_SOFTWARE), + ) + + await asyncio.sleep(0.01) + + assert ezsp.reset_received.mock_calls == [call(t.NcpResetCode.RESET_SOFTWARE)] + assert protocol._write_frame.mock_calls == [call(ash.RstFrame())] + + protocol._write_frame.reset_mock() + + # EZSP version(4) + loop.call_later( + 0, + protocol.frame_received, + ash.DataFrame( + frm_num=0, re_tx=False, ack_num=1, ezsp_frame=b"\x00\x80\x00\x08\x02\x80g" + ), + ) + await protocol.send_data(b"\x00\x00\x00\x04") + assert protocol._write_frame.mock_calls == [ + call( + ash.DataFrame( + frm_num=0, re_tx=False, ack_num=0, ezsp_frame=b"\x00\x00\x00\x04" + ) + ), + call(ash.AckFrame(res=0, ncp_ready=0, ack_num=1)), + ] + + protocol._write_frame.reset_mock() + + # EZSP version(8) + loop.call_later( + 0, + protocol.frame_received, + ash.DataFrame( + frm_num=1, + re_tx=False, + ack_num=2, + ezsp_frame=b"\x00\x80\x01\x00\x00\x08\x02\x80g", + ), + ) + await protocol.send_data(b"\x00\x00\x01\x00\x00\x08") + assert protocol._write_frame.mock_calls == [ + call( + ash.DataFrame( + frm_num=1, + re_tx=False, + ack_num=1, + ezsp_frame=b"\x00\x00\x01\x00\x00\x08", + ) + ), + call(ash.AckFrame(res=0, ncp_ready=0, ack_num=2)), + ] + + +@patch("bellows.ash.T_RX_ACK_INIT", ash.T_RX_ACK_INIT / 100) +@patch("bellows.ash.T_RX_ACK_MIN", ash.T_RX_ACK_MIN / 100) +@patch("bellows.ash.T_RX_ACK_MAX", ash.T_RX_ACK_MAX / 100) +async def test_ash_end_to_end(): + asyncio.get_running_loop() + + host_ezsp = MagicMock() + ncp_ezsp = MagicMock() + + class FakeTransport: + def __init__(self, receiver): + self.receiver = receiver + self.paused = False + + def write(self, data): + if not self.paused: + self.receiver.data_received(data) + + host = ash.AshProtocol(host_ezsp, role=ash.AshRole.HOST) + ncp = ash.AshProtocol(ncp_ezsp, role=ash.AshRole.NCP) + + host_transport = FakeTransport(ncp) + ncp_transport = FakeTransport(host) + + host.connection_made(host_transport) + ncp.connection_made(ncp_transport) + + # Ping pong works + await host.send_data(b"Hello!") + assert ncp_ezsp.data_received.mock_calls == [call(b"Hello!")] + + await ncp.send_data(b"World!") + assert host_ezsp.data_received.mock_calls == [call(b"World!")] + + ncp_ezsp.data_received.reset_mock() + host_ezsp.data_received.reset_mock() + + # Let's pause the ncp so it can't ACK + with patch.object(ncp_transport, "paused", True): + send_task = asyncio.create_task(host.send_data(b"delayed")) + await asyncio.sleep(host._t_rx_ack * 2) + + # It'll still succeed + await send_task + + assert ncp_ezsp.data_received.mock_calls == [call(b"delayed")] + + ncp_ezsp.data_received.reset_mock() + host_ezsp.data_received.reset_mock() + + # Let's let a request fail due to a connectivity issue + with patch.object(ncp_transport, "paused", True): + send_task = asyncio.create_task(host.send_data(b"host failure")) + await asyncio.sleep(host._t_rx_ack * 15) + + with pytest.raises(asyncio.TimeoutError): + await send_task + + ncp_ezsp.data_received.reset_mock() + host_ezsp.data_received.reset_mock() + + # When the NCP fail to receive a reply, it enters a failed state + assert host._ncp_reset_code is None + assert ncp._ncp_reset_code is None + + with patch.object(host_transport, "paused", True): + send_task = asyncio.create_task(ncp.send_data(b"ncp failure")) + await asyncio.sleep(ncp._t_rx_ack * 15) + + with pytest.raises(asyncio.TimeoutError): + await send_task + + assert ( + host._ncp_reset_code is t.NcpResetCode.ERROR_EXCEEDED_MAXIMUM_ACK_TIMEOUT_COUNT + ) + assert ( + ncp._ncp_reset_code is t.NcpResetCode.ERROR_EXCEEDED_MAXIMUM_ACK_TIMEOUT_COUNT + ) + + ncp_ezsp.data_received.reset_mock() + host_ezsp.data_received.reset_mock() + + # All communication attempts with it will fail until it is reset + with pytest.raises(ash.NcpFailure): + await host.send_data(b"test") + + host.send_reset() + await asyncio.sleep(0.01) + await host.send_data(b"test") From ef3c3c136ec9c8a5457aaee6cb6e3a46636dd6ac Mon Sep 17 00:00:00 2001 From: puddly <32534428+puddly@users.noreply.github.com> Date: Mon, 15 Apr 2024 17:09:46 -0400 Subject: [PATCH 23/42] Only pop the pending frame future if the frame number was actually assigned --- bellows/ash.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/bellows/ash.py b/bellows/ash.py index c424919c..ef1fd21d 100644 --- a/bellows/ash.py +++ b/bellows/ash.py @@ -663,7 +663,8 @@ async def _send_frame(self, frame: AshFrame) -> None: break finally: - self._pending_data_frames.pop(frm_num) + if frm_num is not None: + self._pending_data_frames.pop(frm_num) async def send_data(self, data: bytes) -> None: await self._send_frame( From 7f604e0df96fdadc40a6a6e40437f1f0b25f8019 Mon Sep 17 00:00:00 2001 From: puddly <32534428+puddly@users.noreply.github.com> Date: Mon, 15 Apr 2024 17:26:41 -0400 Subject: [PATCH 24/42] Move NCP ASH implementation into tests --- bellows/ash.py | 167 +++++++++++++++++----------------------------- tests/test_ash.py | 64 +++++++++++++++++- 2 files changed, 125 insertions(+), 106 deletions(-) diff --git a/bellows/ash.py b/bellows/ash.py index ef1fd21d..e5fc9292 100644 --- a/bellows/ash.py +++ b/bellows/ash.py @@ -81,11 +81,6 @@ class NcpState(enum.Enum): FAILED = "failed" -class AshRole(enum.Enum): - HOST = "host" - NCP = "ncp" - - class ParsingError(Exception): pass @@ -327,7 +322,7 @@ class ErrorFrame(AshFrame): class AshProtocol(asyncio.Protocol): - def __init__(self, ezsp_protocol, *, role: AshRole = AshRole.HOST) -> None: + def __init__(self, ezsp_protocol) -> None: self._ezsp_protocol = ezsp_protocol self._transport = None self._buffer = bytearray() @@ -338,7 +333,6 @@ def __init__(self, ezsp_protocol, *, role: AshRole = AshRole.HOST) -> None: self._rx_seq: int = 0 self._t_rx_ack = T_RX_ACK_INIT - self._role: AshRole = role self._ncp_reset_code: t.NcpResetCode | None = None self._ncp_state: NcpState = NcpState.CONNECTED @@ -352,12 +346,6 @@ def connection_lost(self, exc): def eof_received(self): self._ezsp_protocol.eof_received() - def _get_tx_seq(self) -> int: - result = self._tx_seq - self._tx_seq = (self._tx_seq + 1) % 8 - - return result - def close(self): if self._transport is not None: self._transport.close() @@ -471,78 +459,72 @@ def _handle_ack(self, frame: DataFrame | AckFrame) -> None: def frame_received(self, frame: AshFrame) -> None: _LOGGER.debug("Received frame %r", frame) - if ( - self._ncp_reset_code is not None - and self._role == AshRole.NCP - and not isinstance(frame, RstFrame) - ): - _LOGGER.debug( - "NCP in failure state %r, ignoring frame: %r", - self._ncp_reset_code, - frame, - ) - self._write_frame(ErrorFrame(version=2, reset_code=self._ncp_reset_code)) - return - if isinstance(frame, DataFrame): - # The Host may not piggyback acknowledgments and should promptly send an ACK - # frame when it receives a DATA frame. - if frame.frm_num == self._rx_seq: - self._handle_ack(frame) - self._rx_seq = (frame.frm_num + 1) % 8 - self._write_frame(AckFrame(res=0, ncp_ready=0, ack_num=self._rx_seq)) - - self._ezsp_protocol.data_received(frame.ezsp_frame) - elif frame.re_tx: - # Retransmitted frames must be immediately ACKed even if they are out of - # sequence - self._write_frame(AckFrame(res=0, ncp_ready=0, ack_num=self._rx_seq)) - else: - _LOGGER.debug("Received an out of sequence frame: %r", frame) - self._write_frame(NakFrame(res=0, ncp_ready=0, ack_num=self._rx_seq)) + self.data_frame_received(frame) elif isinstance(frame, RStackFrame): - self._ncp_reset_code = None - self._ncp_state = NcpState.CONNECTED - - self._tx_seq = 0 - self._rx_seq = 0 - self._change_ack_timeout(T_RX_ACK_INIT) - self._ezsp_protocol.reset_received(frame.reset_code) + self.rstack_frame_received(frame) elif isinstance(frame, AckFrame): - self._handle_ack(frame) + self.ack_frame_received(frame) elif isinstance(frame, NakFrame): - error = NotAcked(frame=frame) - - for frm_num, fut in self._pending_data_frames.items(): - if ( - not frame.ack_num - TX_K <= frm_num <= frame.ack_num - and not fut.done() - ): - fut.set_exception(error) + self.nak_frame_received(frame) elif isinstance(frame, RstFrame): - self._ncp_reset_code = None - self._ncp_state = NcpState.CONNECTED - - if self._role == AshRole.NCP: - self._tx_seq = 0 - self._rx_seq = 0 - self._change_ack_timeout(T_RX_ACK_INIT) - - self._enter_ncp_error_state(None) - self._write_frame( - RStackFrame(version=2, reset_code=t.NcpResetCode.RESET_SOFTWARE) - ) - elif isinstance(frame, ErrorFrame) and self._role == AshRole.HOST: - _LOGGER.debug("NCP has entered failed state: %s", frame.reset_code) - self._ncp_reset_code = frame.reset_code - self._ncp_state = NcpState.FAILED - - # Cancel all pending requests - exc = NcpFailure(code=self._ncp_reset_code) - - for fut in self._pending_data_frames.values(): - if not fut.done(): - fut.set_exception(exc) + self.rst_frame_received(frame) + elif isinstance(frame, ErrorFrame): + self.error_frame_received(frame) + else: + raise TypeError(f"Unknown frame received: {frame}") # pragma: no cover + + def data_frame_received(self, frame: DataFrame) -> None: + # The Host may not piggyback acknowledgments and should promptly send an ACK + # frame when it receives a DATA frame. + if frame.frm_num == self._rx_seq: + self._handle_ack(frame) + self._rx_seq = (frame.frm_num + 1) % 8 + self._write_frame(AckFrame(res=0, ncp_ready=0, ack_num=self._rx_seq)) + + self._ezsp_protocol.data_received(frame.ezsp_frame) + elif frame.re_tx: + # Retransmitted frames must be immediately ACKed even if they are out of + # sequence + self._write_frame(AckFrame(res=0, ncp_ready=0, ack_num=self._rx_seq)) + else: + _LOGGER.debug("Received an out of sequence frame: %r", frame) + self._write_frame(NakFrame(res=0, ncp_ready=0, ack_num=self._rx_seq)) + + def rstack_frame_received(self, frame: RStackFrame) -> None: + self._ncp_reset_code = None + self._ncp_state = NcpState.CONNECTED + + self._tx_seq = 0 + self._rx_seq = 0 + self._change_ack_timeout(T_RX_ACK_INIT) + self._ezsp_protocol.reset_received(frame.reset_code) + + def ack_frame_received(self, frame: AckFrame) -> None: + self._handle_ack(frame) + + def nak_frame_received(self, frame: NakFrame) -> None: + err = NotAcked(frame=frame) + + for frm_num, fut in self._pending_data_frames.items(): + if not frame.ack_num - TX_K <= frm_num <= frame.ack_num and not fut.done(): + fut.set_exception(err) + + def rst_frame_received(self, frame: RstFrame) -> None: + self._ncp_reset_code = None + self._ncp_state = NcpState.CONNECTED + + def error_frame_received(self, frame: ErrorFrame) -> None: + _LOGGER.debug("NCP has entered failed state: %s", frame.reset_code) + self._ncp_reset_code = frame.reset_code + self._ncp_state = NcpState.FAILED + + # Cancel all pending requests + exc = NcpFailure(code=self._ncp_reset_code) + + for fut in self._pending_data_frames.values(): + if not fut.done(): + fut.set_exception(exc) def _write_frame(self, frame: AshFrame) -> None: _LOGGER.debug("Sending frame %r", frame) @@ -561,20 +543,6 @@ def _change_ack_timeout(self, new_value: float) -> None: self._t_rx_ack = new_value - def _enter_ncp_error_state(self, code: t.NcpResetCode | None) -> None: - self._ncp_reset_code = code - - if code is None: - self._ncp_state = NcpState.CONNECTED - else: - self._ncp_state = NcpState.FAILED - - _LOGGER.debug("Changing connectivity state: %r", self._ncp_state) - _LOGGER.debug("Changing reset code: %r", self._ncp_reset_code) - - if self._ncp_state == NcpState.FAILED: - self._write_frame(ErrorFrame(version=2, reset_code=self._ncp_reset_code)) - async def _send_frame(self, frame: AshFrame) -> None: if not isinstance(frame, DataFrame): # Non-DATA frames can be sent immediately and do not require an ACK @@ -589,10 +557,7 @@ async def _send_frame(self, frame: AshFrame) -> None: try: for attempt in range(ACK_TIMEOUTS): - if ( - self._role == AshRole.HOST - and self._ncp_state == NcpState.FAILED - ): + if self._ncp_state == NcpState.FAILED: _LOGGER.debug( "NCP is in a failed state, not re-sending: %r", frame ) @@ -647,12 +612,6 @@ async def _send_frame(self, frame: AshFrame) -> None: self._change_ack_timeout(2 * self._t_rx_ack) if attempt >= ACK_TIMEOUTS - 1: - # Only a timeout is enough to enter an error state - if self._role == AshRole.NCP: - self._enter_ncp_error_state( - t.NcpResetCode.ERROR_EXCEEDED_MAXIMUM_ACK_TIMEOUT_COUNT - ) - raise else: # Whenever an acknowledgement is received, t_rx_ack is set to diff --git a/tests/test_ash.py b/tests/test_ash.py index 0d4a2e50..95550878 100644 --- a/tests/test_ash.py +++ b/tests/test_ash.py @@ -7,6 +7,66 @@ import bellows.types as t +class AshNcpProtocol(ash.AshProtocol): + def frame_received(self, frame: ash.AshFrame) -> None: + if self._ncp_reset_code is not None and not isinstance(frame, ash.RstFrame): + ash._LOGGER.debug( + "NCP in failure state %r, ignoring frame: %r", + self._ncp_reset_code, + frame, + ) + self._write_frame( + ash.ErrorFrame(version=2, reset_code=self._ncp_reset_code) + ) + return + + super().frame_received(frame) + + def _enter_ncp_error_state(self, code: t.NcpResetCode | None) -> None: + self._ncp_reset_code = code + + if code is None: + self._ncp_state = ash.NcpState.CONNECTED + else: + self._ncp_state = ash.NcpState.FAILED + + ash._LOGGER.debug("Changing connectivity state: %r", self._ncp_state) + ash._LOGGER.debug("Changing reset code: %r", self._ncp_reset_code) + + if self._ncp_state == ash.NcpState.FAILED: + self._write_frame( + ash.ErrorFrame(version=2, reset_code=self._ncp_reset_code) + ) + + def rst_frame_received(self, frame: ash.RstFrame) -> None: + super().rst_frame_received(frame) + + self._tx_seq = 0 + self._rx_seq = 0 + self._change_ack_timeout(ash.T_RX_ACK_INIT) + + self._enter_ncp_error_state(None) + self._write_frame( + ash.RStackFrame(version=2, reset_code=t.NcpResetCode.RESET_SOFTWARE) + ) + + async def _send_frame(self, frame: ash.AshFrame) -> None: + try: + return await super()._send_frame(frame) + except asyncio.TimeoutError: + self._enter_ncp_error_state( + t.NcpResetCode.ERROR_EXCEEDED_MAXIMUM_ACK_TIMEOUT_COUNT + ) + raise + if not isinstance(frame, ash.DataFrame): + # Non-DATA frames can be sent immediately and do not require an ACK + self._write_frame(frame) + return + + def send_reset(self) -> None: + raise NotImplementedError() + + def test_stuffing(): assert ash.AshProtocol._stuff_bytes(b"\x7E") == b"\x7D\x5E" assert ash.AshProtocol._stuff_bytes(b"\x11") == b"\x7D\x31" @@ -129,8 +189,8 @@ def write(self, data): if not self.paused: self.receiver.data_received(data) - host = ash.AshProtocol(host_ezsp, role=ash.AshRole.HOST) - ncp = ash.AshProtocol(ncp_ezsp, role=ash.AshRole.NCP) + host = ash.AshProtocol(host_ezsp) + ncp = AshNcpProtocol(ncp_ezsp) host_transport = FakeTransport(ncp) ncp_transport = FakeTransport(host) From 082d370d5531542813e70acacb17ce2ef13ebefc Mon Sep 17 00:00:00 2001 From: puddly <32534428+puddly@users.noreply.github.com> Date: Mon, 15 Apr 2024 17:29:44 -0400 Subject: [PATCH 25/42] Ensure tests pass with 3.8 --- tests/test_ash.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/test_ash.py b/tests/test_ash.py index 95550878..6a1a044b 100644 --- a/tests/test_ash.py +++ b/tests/test_ash.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import asyncio from unittest.mock import MagicMock, call, patch From 630ece72bb193c41bea6d14fdb0fbf28deb1bc98 Mon Sep 17 00:00:00 2001 From: puddly <32534428+puddly@users.noreply.github.com> Date: Mon, 29 Apr 2024 15:46:42 -0400 Subject: [PATCH 26/42] Properly handle cancel and substitute bytes --- bellows/ash.py | 76 +++++++++++++++++++++++++++++++++++--------------- 1 file changed, 53 insertions(+), 23 deletions(-) diff --git a/bellows/ash.py b/bellows/ash.py index e5fc9292..c9371b00 100644 --- a/bellows/ash.py +++ b/bellows/ash.py @@ -20,14 +20,17 @@ _LOGGER = logging.getLogger(__name__) +MAX_BUFFER_SIZE = 1024 + FLAG = b"\x7E" # Marks end of frame ESCAPE = b"\x7D" XON = b"\x11" # Resume transmission XOFF = b"\x13" # Stop transmission -SUBSTITUTE = b"\x18" +SUBSTITUTE = b"\x18" # Replaces a byte received with a low-level communication error CANCEL = b"\x1A" # Terminates a frame in progress RESERVED = frozenset(FLAG + ESCAPE + XON + XOFF + SUBSTITUTE + CANCEL) +RESERVED_WITHOUT_ESCAPE = RESERVED - frozenset([ESCAPE[0]]) # Initial value of t_rx_ack, the maximum time the NCP waits to receive acknowledgement # of a DATA frame @@ -326,7 +329,7 @@ def __init__(self, ezsp_protocol) -> None: self._ezsp_protocol = ezsp_protocol self._transport = None self._buffer = bytearray() - self._discarding_until_flag: bool = False + self._discarding_until_next_flag: bool = False self._pending_data_frames: dict[int, asyncio.Future] = {} self._send_data_frame_semaphore = asyncio.Semaphore(TX_K) self._tx_seq: int = 0 @@ -402,32 +405,40 @@ def data_received(self, data: bytes) -> None: _LOGGER.debug("Received data %s", data.hex()) self._buffer.extend(data) + if len(self._buffer) > MAX_BUFFER_SIZE: + _LOGGER.debug( + "Truncating buffer to %s bytes, it is growing too fast", MAX_BUFFER_SIZE + ) + self._buffer = self._buffer[:MAX_BUFFER_SIZE] + while self._buffer: - if self._discarding_until_flag: + if self._discarding_until_next_flag: if FLAG not in self._buffer: self._buffer.clear() - return + break - self._discarding_until_flag = False + self._discarding_until_next_flag = False _, _, self._buffer = self._buffer.partition(FLAG) - if self._buffer.startswith(FLAG): - # Consecutive Flag Bytes after the first Flag Byte are ignored - self._buffer = self._buffer[1:] - elif self._buffer.startswith(CANCEL): - # all data received since the previous Flag Byte to be ignored - _, _, self._buffer = self._buffer.partition(CANCEL) - elif self._buffer.startswith(XON): - _LOGGER.debug("Received XON byte, resuming transmission") - self._buffer = self._buffer[1:] - elif self._buffer.startswith(XOFF): - _LOGGER.debug("Received XOFF byte, pausing transmission") - self._buffer = self._buffer[1:] - elif self._buffer.startswith(SUBSTITUTE): - self._discarding_until_flag = True - self._buffer = self._buffer[1:] - elif FLAG in self._buffer: - frame_bytes, _, self._buffer = self._buffer.partition(FLAG) + try: + # Find the index of the first reserved byte that isn't an escape byte + reserved_index, reserved_byte = next( + (index, byte) + for index, byte in enumerate(self._buffer) + if byte in RESERVED_WITHOUT_ESCAPE + ) + except StopIteration: + break + + if reserved_byte == FLAG[0]: + # Flag Byte marks the end of a frame + frame_bytes = self._buffer[:reserved_index] + self._buffer = self._buffer[reserved_index + 1 :] + + # Consecutive EOFs can be received, empty frames are ignored + if not frame_bytes: + continue + data = self._unstuff_bytes(frame_bytes) try: @@ -438,8 +449,27 @@ def data_received(self, data: bytes) -> None: ) else: self.frame_received(frame) + elif reserved_byte == CANCEL[0]: + _LOGGER.debug("Received cancel byte, clearing buffer") + # All data received since the previous Flag Byte to be ignored + self._buffer = self._buffer[reserved_index + 1 :] + elif reserved_byte == SUBSTITUTE[0]: + _LOGGER.warning("Received substitute byte, marking buffer as corrupted") + # The data between the previous and the next Flag Byte is ignored + self._discarding_until_next_flag = True + self._buffer = self._buffer[reserved_index + 1 :] + elif reserved_byte == XON[0]: + # Resume transmission: not implemented! + _LOGGER.debug("Received XON byte, resuming transmission") + self._buffer.pop(reserved_index) + elif reserved_byte == XOFF[0]: + # Pause transmission: not implemented! + _LOGGER.debug("Received XOFF byte, pausing transmission") + self._buffer.pop(reserved_index) else: - break + raise RuntimeError( + f"Unexpected reserved byte found: 0x{reserved_byte:02X}" + ) # pragma: no cover def _handle_ack(self, frame: DataFrame | AckFrame) -> None: # Note that ackNum is the number of the next frame the receiver expects and it From 8aef3ba71165fd542765b2b177de45936414d2a8 Mon Sep 17 00:00:00 2001 From: puddly <32534428+puddly@users.noreply.github.com> Date: Mon, 29 Apr 2024 15:51:53 -0400 Subject: [PATCH 27/42] Add a unit test --- tests/test_ash.py | 27 +++++++++++++++++++++++++++ 1 file changed, 27 insertions(+) diff --git a/tests/test_ash.py b/tests/test_ash.py index 6a1a044b..2cef7684 100644 --- a/tests/test_ash.py +++ b/tests/test_ash.py @@ -96,6 +96,33 @@ def test_rst_frame(): assert str(ash.RstFrame()) == "RstFrame()" +def test_cancel_byte(): + ezsp = MagicMock() + protocol = ash.AshProtocol(ezsp) + protocol.frame_received = MagicMock(wraps=protocol.frame_received) + + protocol.data_received(bytes.fromhex("ddf9ff")) + protocol.data_received(bytes.fromhex("1ac1020b0a527e")) # starts with a CANCEL byte + + # We still parse out the RSTACK frame + assert protocol.frame_received.mock_calls == [ + call(ash.RStackFrame(version=2, reset_code=t.NcpResetCode.RESET_SOFTWARE)) + ] + assert not protocol._buffer + + +def test_buffer_growth(): + ezsp = MagicMock() + protocol = ash.AshProtocol(ezsp) + + # Receive a lot of bogus data + for i in range(1000): + protocol.data_received(b"\xEE" * 100) + + # Make sure our internal buffer doesn't blow up + assert len(protocol._buffer) == ash.MAX_BUFFER_SIZE + + async def test_ash_protocol_startup(): """Simple EZSP startup: reset, version(4), then version(8).""" From a020473da78b29ffdc9ade0a75d8a99f0e9bf942 Mon Sep 17 00:00:00 2001 From: puddly <32534428+puddly@users.noreply.github.com> Date: Mon, 29 Apr 2024 16:13:43 -0400 Subject: [PATCH 28/42] Simulate NAK state during end-to-end testing --- tests/test_ash.py | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) diff --git a/tests/test_ash.py b/tests/test_ash.py index 2cef7684..da2a871f 100644 --- a/tests/test_ash.py +++ b/tests/test_ash.py @@ -10,6 +10,10 @@ class AshNcpProtocol(ash.AshProtocol): + def __init__(self, *args, **kwargs) -> None: + super().__init__(*args, **kwargs) + self.nak_state = False + def frame_received(self, frame: ash.AshFrame) -> None: if self._ncp_reset_code is not None and not isinstance(frame, ash.RstFrame): ash._LOGGER.debug( @@ -22,6 +26,10 @@ def frame_received(self, frame: ash.AshFrame) -> None: ) return + if self.nak_state: + self._write_frame(ash.NakFrame(res=0, ncp_ready=0, ack_num=self._rx_seq)) + return + super().frame_received(frame) def _enter_ncp_error_state(self, code: t.NcpResetCode | None) -> None: @@ -261,6 +269,17 @@ def write(self, data): ncp_ezsp.data_received.reset_mock() host_ezsp.data_received.reset_mock() + # Simulate OOM on the NCP and send NAKs for a bit + with patch.object(ncp, "nak_state", True): + send_task = asyncio.create_task(host.send_data(b"ncp NAKing")) + await asyncio.sleep(host._t_rx_ack) + + # It'll still succeed + await send_task + + ncp_ezsp.data_received.reset_mock() + host_ezsp.data_received.reset_mock() + # When the NCP fail to receive a reply, it enters a failed state assert host._ncp_reset_code is None assert ncp._ncp_reset_code is None From e83e7153bf704536deea7bde237630995b660468 Mon Sep 17 00:00:00 2001 From: puddly <32534428+puddly@users.noreply.github.com> Date: Mon, 29 Apr 2024 16:16:11 -0400 Subject: [PATCH 29/42] Ensure transports are resilient when it comes to framing --- tests/test_ash.py | 32 ++++++++++++++++++++------------ 1 file changed, 20 insertions(+), 12 deletions(-) diff --git a/tests/test_ash.py b/tests/test_ash.py index da2a871f..7629ec56 100644 --- a/tests/test_ash.py +++ b/tests/test_ash.py @@ -208,29 +208,37 @@ async def test_ash_protocol_startup(): ] +class FakeTransport: + def __init__(self, receiver): + self.receiver = receiver + self.paused = False + + def write(self, data): + if not self.paused: + self.receiver.data_received(data) + + +class FakeTransportOneByteAtATime(FakeTransport): + def write(self, data): + for byte in data: + super().write(bytes([byte])) + + @patch("bellows.ash.T_RX_ACK_INIT", ash.T_RX_ACK_INIT / 100) @patch("bellows.ash.T_RX_ACK_MIN", ash.T_RX_ACK_MIN / 100) @patch("bellows.ash.T_RX_ACK_MAX", ash.T_RX_ACK_MAX / 100) -async def test_ash_end_to_end(): +@pytest.mark.parametrize("transport_cls", [FakeTransport, FakeTransportOneByteAtATime]) +async def test_ash_end_to_end(transport_cls: type[FakeTransport]) -> None: asyncio.get_running_loop() host_ezsp = MagicMock() ncp_ezsp = MagicMock() - class FakeTransport: - def __init__(self, receiver): - self.receiver = receiver - self.paused = False - - def write(self, data): - if not self.paused: - self.receiver.data_received(data) - host = ash.AshProtocol(host_ezsp) ncp = AshNcpProtocol(ncp_ezsp) - host_transport = FakeTransport(ncp) - ncp_transport = FakeTransport(host) + host_transport = transport_cls(ncp) + ncp_transport = transport_cls(host) host.connection_made(host_transport) ncp.connection_made(ncp_transport) From bfb780078e850b792ff2d27749eab52f55732923 Mon Sep 17 00:00:00 2001 From: puddly <32534428+puddly@users.noreply.github.com> Date: Mon, 29 Apr 2024 16:28:04 -0400 Subject: [PATCH 30/42] Introduce random loss testing as well --- tests/test_ash.py | 20 +++++++++++++++++++- 1 file changed, 19 insertions(+), 1 deletion(-) diff --git a/tests/test_ash.py b/tests/test_ash.py index 7629ec56..c0517fb1 100644 --- a/tests/test_ash.py +++ b/tests/test_ash.py @@ -1,6 +1,7 @@ from __future__ import annotations import asyncio +import random from unittest.mock import MagicMock, call, patch import pytest @@ -9,6 +10,11 @@ import bellows.types as t +@pytest.fixture(autouse=True, scope="function") +def random_seed(): + random.seed(0) + + class AshNcpProtocol(ash.AshProtocol): def __init__(self, *args, **kwargs) -> None: super().__init__(*args, **kwargs) @@ -224,10 +230,22 @@ def write(self, data): super().write(bytes([byte])) +class FakeTransportRandomLoss(FakeTransport): + def write(self, data): + if random.random() < 0.25: + return + + for byte in data: + super().write(bytes([byte])) + + @patch("bellows.ash.T_RX_ACK_INIT", ash.T_RX_ACK_INIT / 100) @patch("bellows.ash.T_RX_ACK_MIN", ash.T_RX_ACK_MIN / 100) @patch("bellows.ash.T_RX_ACK_MAX", ash.T_RX_ACK_MAX / 100) -@pytest.mark.parametrize("transport_cls", [FakeTransport, FakeTransportOneByteAtATime]) +@pytest.mark.parametrize( + "transport_cls", + [FakeTransport, FakeTransportOneByteAtATime, FakeTransportRandomLoss], +) async def test_ash_end_to_end(transport_cls: type[FakeTransport]) -> None: asyncio.get_running_loop() From 20472c2de8e3a7615313f263eabf9e0f844c095b Mon Sep 17 00:00:00 2001 From: puddly <32534428+puddly@users.noreply.github.com> Date: Mon, 29 Apr 2024 17:04:49 -0400 Subject: [PATCH 31/42] Add more tests --- bellows/ash.py | 19 +++++------ tests/test_ash.py | 80 +++++++++++++++++++++++++++++++---------------- 2 files changed, 60 insertions(+), 39 deletions(-) diff --git a/bellows/ash.py b/bellows/ash.py index c9371b00..637bb85f 100644 --- a/bellows/ash.py +++ b/bellows/ash.py @@ -88,14 +88,6 @@ class ParsingError(Exception): pass -class InvalidChecksum(ParsingError): - pass - - -class FrameTooShort(ParsingError): - pass - - class AshException(Exception): pass @@ -144,12 +136,15 @@ def to_bytes(self) -> bytes: @classmethod def _unwrap(cls, data: bytes) -> tuple[int, bytes]: if len(data) < 3: - raise FrameTooShort(f"Frame is too short: {data!r}") + raise ParsingError(f"Frame is too short: {data!r}") computed_crc = binascii.crc_hqx(data[:-2], 0xFFFF).to_bytes(2, "big") if computed_crc != data[-2:]: - raise InvalidChecksum(f"Invalid CRC bytes in frame {data!r}") + raise ParsingError( + f"Invalid CRC bytes in frame {data!r}:" + f" expected {computed_crc.hex()}, got {data[-2:].hex()}" + ) return data[0], data[1:-2] @@ -444,7 +439,7 @@ def data_received(self, data: bytes) -> None: try: frame = self._extract_frame(data) except Exception: - _LOGGER.warning( + _LOGGER.debug( "Failed to parse frame %r", frame_bytes, exc_info=True ) else: @@ -454,7 +449,7 @@ def data_received(self, data: bytes) -> None: # All data received since the previous Flag Byte to be ignored self._buffer = self._buffer[reserved_index + 1 :] elif reserved_byte == SUBSTITUTE[0]: - _LOGGER.warning("Received substitute byte, marking buffer as corrupted") + _LOGGER.debug("Received substitute byte, marking buffer as corrupted") # The data between the previous and the next Flag Byte is ignored self._discarding_until_next_flag = True self._buffer = self._buffer[reserved_index + 1 :] diff --git a/tests/test_ash.py b/tests/test_ash.py index c0517fb1..df01f171 100644 --- a/tests/test_ash.py +++ b/tests/test_ash.py @@ -83,6 +83,31 @@ def send_reset(self) -> None: raise NotImplementedError() +class FakeTransport: + def __init__(self, receiver): + self.receiver = receiver + self.paused = False + + def write(self, data): + if not self.paused: + self.receiver.data_received(data) + + +class FakeTransportOneByteAtATime(FakeTransport): + def write(self, data): + for byte in data: + super().write(bytes([byte])) + + +class FakeTransportRandomLoss(FakeTransport): + def write(self, data): + if random.random() < 0.25: + return + + for byte in data: + super().write(bytes([byte])) + + def test_stuffing(): assert ash.AshProtocol._stuff_bytes(b"\x7E") == b"\x7D\x5E" assert ash.AshProtocol._stuff_bytes(b"\x11") == b"\x7D\x31" @@ -98,6 +123,9 @@ def test_stuffing(): assert ash.AshProtocol._unstuff_bytes(b"\x7D\x3A") == b"\x1A" assert ash.AshProtocol._unstuff_bytes(b"\x7D\x5D") == b"\x7D" + assert ash.AshProtocol._stuff_bytes(b"\x7F") == b"\x7F" + assert ash.AshProtocol._unstuff_bytes(b"\x7F") == b"\x7F" + def test_pseudo_random_data_sequence(): assert ash.PSEUDO_RANDOM_DATA_SEQUENCE.startswith(b"\x42\x21\xA8\x54\x2A") @@ -105,10 +133,33 @@ def test_pseudo_random_data_sequence(): def test_rst_frame(): assert ash.RstFrame() == ash.RstFrame() - assert ash.RstFrame().to_bytes() == bytes.fromhex("c038bc") - assert ash.RstFrame.from_bytes(bytes.fromhex("c038bc")) == ash.RstFrame() + assert ash.RstFrame().to_bytes() == b"\xC0\x38\xBC" + assert ash.RstFrame.from_bytes(b"\xC0\x38\xBC") == ash.RstFrame() assert str(ash.RstFrame()) == "RstFrame()" + with pytest.raises(ValueError, match=r"Invalid data for RST frame: "): + ash.RstFrame.from_bytes(ash.AshFrame.append_crc(b"\xC0\xAB")) + + +def test_rstack_frame(): + frm = ash.RStackFrame(version=2, reset_code=t.NcpResetCode.RESET_SOFTWARE) + + assert frm == frm + assert frm.to_bytes() == b"\xc1\x02\x0b\x0a\x52" + assert ash.RStackFrame.from_bytes(frm.to_bytes()) == frm + assert ( + str(frm) + == "RStackFrame(version=2, reset_code=)" + ) + + with pytest.raises(ValueError, match=r"Invalid data length for RSTACK frame: "): + # Adding \xAB in the middle of the frame makes it invalid + ash.RStackFrame.from_bytes(ash.AshFrame.append_crc(b"\xc1\x02\xab\x0b")) + + with pytest.raises(ValueError, match=r"Invalid version for RSTACK frame: 3"): + # Version 3 is unknown + ash.RStackFrame.from_bytes(ash.AshFrame.append_crc(b"\xc1\x03\x0b")) + def test_cancel_byte(): ezsp = MagicMock() @@ -214,31 +265,6 @@ async def test_ash_protocol_startup(): ] -class FakeTransport: - def __init__(self, receiver): - self.receiver = receiver - self.paused = False - - def write(self, data): - if not self.paused: - self.receiver.data_received(data) - - -class FakeTransportOneByteAtATime(FakeTransport): - def write(self, data): - for byte in data: - super().write(bytes([byte])) - - -class FakeTransportRandomLoss(FakeTransport): - def write(self, data): - if random.random() < 0.25: - return - - for byte in data: - super().write(bytes([byte])) - - @patch("bellows.ash.T_RX_ACK_INIT", ash.T_RX_ACK_INIT / 100) @patch("bellows.ash.T_RX_ACK_MIN", ash.T_RX_ACK_MIN / 100) @patch("bellows.ash.T_RX_ACK_MAX", ash.T_RX_ACK_MAX / 100) From c4de177ed10a73baab65a19d51d26b4ec52abe42 Mon Sep 17 00:00:00 2001 From: puddly <32534428+puddly@users.noreply.github.com> Date: Mon, 29 Apr 2024 18:08:44 -0400 Subject: [PATCH 32/42] Add more tests --- bellows/ash.py | 63 ++++++++++++++++------------------- bellows/uart.py | 2 +- tests/test_ash.py | 84 ++++++++++++++++++++++++++++++++++++++++++++--- 3 files changed, 109 insertions(+), 40 deletions(-) diff --git a/bellows/ash.py b/bellows/ash.py index 637bb85f..f23977c9 100644 --- a/bellows/ash.py +++ b/bellows/ash.py @@ -108,26 +108,12 @@ def __repr__(self) -> str: return f"<{self.__class__.__name__}(code={self.code})>" -class OutOfSequenceError(AshException): - def __init__(self, expected_seq: int, frame: AshFrame): - self.expected_seq = expected_seq - self.frame = frame - - def __repr__(self) -> str: - return ( - f"<{self.__class__.__name__}(" - f"expected_seq={self.expected_seq}" - f", frame={self.frame}" - f")>" - ) - - class AshFrame(abc.ABC, BaseDataclassMixin): MASK: t.uint8_t MASK_VALUE: t.uint8_t @classmethod - def from_bytes(cls, data: bytes) -> DataFrame: + def from_bytes(cls, data: bytes) -> AshFrame: raise NotImplementedError() def to_bytes(self) -> bytes: @@ -267,7 +253,7 @@ def from_bytes(cls, data: bytes) -> RstFrame: control, data = cls._unwrap(data) if data: - raise ValueError(f"Invalid data for RST frame: {data!r}") + raise ParsingError(f"Invalid data for RST frame: {data!r}") return cls() @@ -288,12 +274,12 @@ def from_bytes(cls, data: bytes) -> RStackFrame: control, data = cls._unwrap(data) if len(data) != 2: - raise ValueError(f"Invalid data length for RSTACK frame: {data!r}") + raise ParsingError(f"Invalid data length for RSTACK frame: {data!r}") version = data[0] if version != 0x02: - raise ValueError(f"Invalid version for RSTACK frame: {version}") + raise ParsingError(f"Invalid version for RSTACK frame: {data!r}") reset_code = t.NcpResetCode(data[1]) @@ -319,6 +305,27 @@ class ErrorFrame(AshFrame): to_bytes = RStackFrame.to_bytes +def parse_frame( + data: bytes, +) -> DataFrame | AckFrame | NakFrame | RstFrame | RStackFrame | ErrorFrame: + """Parse a frame from the given data, looking at the control byte.""" + control_byte = data[0] + + # In order of use + for frame in [ + DataFrame, + AckFrame, + NakFrame, + RstFrame, + RStackFrame, + ErrorFrame, + ]: + if control_byte & frame.MASK == frame.MASK_VALUE: + return frame.from_bytes(data) + else: + raise ParsingError(f"Could not determine frame type: {data!r}") + + class AshProtocol(asyncio.Protocol): def __init__(self, ezsp_protocol) -> None: self._ezsp_protocol = ezsp_protocol @@ -348,22 +355,6 @@ def close(self): if self._transport is not None: self._transport.close() - def _extract_frame(self, data: bytes) -> AshFrame: - control_byte = data[0] - - for frame in [ - DataFrame, - AckFrame, - NakFrame, - RstFrame, - RStackFrame, - ErrorFrame, - ]: - if control_byte & frame.MASK == frame.MASK_VALUE: - return frame.from_bytes(data) - else: - raise ValueError(f"Could not determine frame type: {data!r}") - @staticmethod def _stuff_bytes(data: bytes) -> bytes: """Stuff bytes for transmission""" @@ -437,7 +428,7 @@ def data_received(self, data: bytes) -> None: data = self._unstuff_bytes(frame_bytes) try: - frame = self._extract_frame(data) + frame = parse_frame(data) except Exception: _LOGGER.debug( "Failed to parse frame %r", frame_bytes, exc_info=True @@ -551,6 +542,8 @@ def error_frame_received(self, frame: ErrorFrame) -> None: if not fut.done(): fut.set_exception(exc) + self._ezsp_protocol.error_received(frame.reset_code) + def _write_frame(self, frame: AshFrame) -> None: _LOGGER.debug("Sending frame %r", frame) data = self._stuff_bytes(frame.to_bytes()) + FLAG diff --git a/bellows/uart.py b/bellows/uart.py index 4a0bc1e6..ee2aea08 100644 --- a/bellows/uart.py +++ b/bellows/uart.py @@ -74,7 +74,7 @@ def reset_received(self, code: t.NcpResetCode) -> None: else: LOGGER.warning("Received an unexpected reset: %r", code) - def error_received(self, code): + def error_received(self, code: t.NcpResetCode) -> None: """Error frame receive handler.""" self._application.enter_failed_state(code) diff --git a/tests/test_ash.py b/tests/test_ash.py index df01f171..2fdc46bf 100644 --- a/tests/test_ash.py +++ b/tests/test_ash.py @@ -131,20 +131,27 @@ def test_pseudo_random_data_sequence(): assert ash.PSEUDO_RANDOM_DATA_SEQUENCE.startswith(b"\x42\x21\xA8\x54\x2A") +def test_frame_parsing_errors(): + with pytest.raises(ash.ParsingError, match=r"Frame is too short:"): + assert ash.RstFrame.from_bytes(b"\xC0\x38") + + with pytest.raises(ash.ParsingError, match=r"Invalid CRC bytes in frame"): + assert ash.RstFrame.from_bytes(b"\xC0\xAB\xCD") + + def test_rst_frame(): assert ash.RstFrame() == ash.RstFrame() assert ash.RstFrame().to_bytes() == b"\xC0\x38\xBC" assert ash.RstFrame.from_bytes(b"\xC0\x38\xBC") == ash.RstFrame() assert str(ash.RstFrame()) == "RstFrame()" - with pytest.raises(ValueError, match=r"Invalid data for RST frame: "): + with pytest.raises(ash.ParsingError, match=r"Invalid data for RST frame:"): ash.RstFrame.from_bytes(ash.AshFrame.append_crc(b"\xC0\xAB")) def test_rstack_frame(): frm = ash.RStackFrame(version=2, reset_code=t.NcpResetCode.RESET_SOFTWARE) - assert frm == frm assert frm.to_bytes() == b"\xc1\x02\x0b\x0a\x52" assert ash.RStackFrame.from_bytes(frm.to_bytes()) == frm assert ( @@ -152,11 +159,13 @@ def test_rstack_frame(): == "RStackFrame(version=2, reset_code=)" ) - with pytest.raises(ValueError, match=r"Invalid data length for RSTACK frame: "): + with pytest.raises( + ash.ParsingError, match=r"Invalid data length for RSTACK frame:" + ): # Adding \xAB in the middle of the frame makes it invalid ash.RStackFrame.from_bytes(ash.AshFrame.append_crc(b"\xc1\x02\xab\x0b")) - with pytest.raises(ValueError, match=r"Invalid version for RSTACK frame: 3"): + with pytest.raises(ash.ParsingError, match=r"Invalid version for RSTACK frame:"): # Version 3 is unknown ash.RStackFrame.from_bytes(ash.AshFrame.append_crc(b"\xc1\x03\x0b")) @@ -175,6 +184,73 @@ def test_cancel_byte(): ] assert not protocol._buffer + # Parse byte-by-byte + protocol.frame_received.reset_mock() + + for byte in bytes.fromhex("ddf9ff 1ac1020b0a527e"): + protocol.data_received(bytes([byte])) + + # We still parse out the RSTACK frame + assert protocol.frame_received.mock_calls == [ + call(ash.RStackFrame(version=2, reset_code=t.NcpResetCode.RESET_SOFTWARE)) + ] + assert not protocol._buffer + + +def test_substitute_byte(): + ezsp = MagicMock() + protocol = ash.AshProtocol(ezsp) + protocol.frame_received = MagicMock(wraps=protocol.frame_received) + + protocol.data_received(bytes.fromhex("c0 38bc 7e")) # RST frame + assert protocol.frame_received.mock_calls == [call(ash.RstFrame())] + protocol.data_received(bytes.fromhex("c0 18 38bc 7e")) # RST frame + SUBSTITUTE + assert protocol.frame_received.mock_calls == [call(ash.RstFrame())] # ignored! + protocol.data_received(bytes.fromhex("c1 020b 0a52 7e")) # RSTACK frame + assert protocol.frame_received.mock_calls == [ + call(ash.RstFrame()), + call(ash.RStackFrame(version=2, reset_code=t.NcpResetCode.RESET_SOFTWARE)), + ] + + assert not protocol._buffer + + +def test_xon_xoff_bytes(): + ezsp = MagicMock() + protocol = ash.AshProtocol(ezsp) + protocol.frame_received = MagicMock(wraps=protocol.frame_received) + + protocol.data_received(bytes.fromhex("c0 11 38bc 13 7e")) # RST frame + XON + XOFF + assert protocol.frame_received.mock_calls == [call(ash.RstFrame())] + + +def test_multiple_eof(): + ezsp = MagicMock() + protocol = ash.AshProtocol(ezsp) + protocol.frame_received = MagicMock(wraps=protocol.frame_received) + + protocol.data_received(bytes.fromhex("c0 38bc 7e 7e 7e")) # RST frame + assert protocol.frame_received.mock_calls == [call(ash.RstFrame())] + + +def test_discarding(): + ezsp = MagicMock() + protocol = ash.AshProtocol(ezsp) + protocol.frame_received = MagicMock(wraps=protocol.frame_received) + + # RST frame with embedded SUBSTITUTE + protocol.data_received(bytes.fromhex("c0 18 38bc")) + + # Garbage: still ignored + protocol.data_received(bytes.fromhex("aa bb cc dd ee ff")) + + # Frame boundary: we now will handle data + protocol.data_received(bytes.fromhex("7e")) + + # Normal RST frame + protocol.data_received(bytes.fromhex("c0 38bc 7e 7e 7e")) + assert protocol.frame_received.mock_calls == [call(ash.RstFrame())] + def test_buffer_growth(): ezsp = MagicMock() From 1acb4a08d24fd240acca93d4b41671fc6ec2894e Mon Sep 17 00:00:00 2001 From: puddly <32534428+puddly@users.noreply.github.com> Date: Tue, 30 Apr 2024 12:52:37 -0400 Subject: [PATCH 33/42] Cancel all pending frames when receiving a NAK --- bellows/ash.py | 15 ++++------ tests/test_ash.py | 71 ++++++++++++++++++++++++++++++++++++++++++----- 2 files changed, 69 insertions(+), 17 deletions(-) diff --git a/bellows/ash.py b/bellows/ash.py index f23977c9..204075ce 100644 --- a/bellows/ash.py +++ b/bellows/ash.py @@ -522,8 +522,8 @@ def ack_frame_received(self, frame: AckFrame) -> None: def nak_frame_received(self, frame: NakFrame) -> None: err = NotAcked(frame=frame) - for frm_num, fut in self._pending_data_frames.items(): - if not frame.ack_num - TX_K <= frm_num <= frame.ack_num and not fut.done(): + for fut in self._pending_data_frames.values(): + if not fut.done(): fut.set_exception(err) def rst_frame_received(self, frame: RstFrame) -> None: @@ -542,7 +542,7 @@ def error_frame_received(self, frame: ErrorFrame) -> None: if not fut.done(): fut.set_exception(exc) - self._ezsp_protocol.error_received(frame.reset_code) + self._ezsp_protocol.reset_received(frame.reset_code) def _write_frame(self, frame: AshFrame) -> None: _LOGGER.debug("Sending frame %r", frame) @@ -561,12 +561,7 @@ def _change_ack_timeout(self, new_value: float) -> None: self._t_rx_ack = new_value - async def _send_frame(self, frame: AshFrame) -> None: - if not isinstance(frame, DataFrame): - # Non-DATA frames can be sent immediately and do not require an ACK - self._write_frame(frame) - return - + async def _send_data_frame(self, frame: AshFrame) -> None: if self._send_data_frame_semaphore.locked(): _LOGGER.debug("Semaphore is locked, waiting") @@ -644,7 +639,7 @@ async def _send_frame(self, frame: AshFrame) -> None: self._pending_data_frames.pop(frm_num) async def send_data(self, data: bytes) -> None: - await self._send_frame( + await self._send_data_frame( # All of the other fields will be set during transmission/retries DataFrame(frm_num=None, re_tx=None, ack_num=None, ezsp_frame=data) ) diff --git a/tests/test_ash.py b/tests/test_ash.py index 2fdc46bf..33907e79 100644 --- a/tests/test_ash.py +++ b/tests/test_ash.py @@ -33,7 +33,12 @@ def frame_received(self, frame: ash.AshFrame) -> None: return if self.nak_state: - self._write_frame(ash.NakFrame(res=0, ncp_ready=0, ack_num=self._rx_seq)) + asyncio.get_running_loop().call_later( + 2 * self._t_rx_ack, + lambda: self._write_frame( + ash.NakFrame(res=0, ncp_ready=0, ack_num=self._rx_seq) + ), + ) return super().frame_received(frame) @@ -66,18 +71,14 @@ def rst_frame_received(self, frame: ash.RstFrame) -> None: ash.RStackFrame(version=2, reset_code=t.NcpResetCode.RESET_SOFTWARE) ) - async def _send_frame(self, frame: ash.AshFrame) -> None: + async def _send_data_frame(self, frame: ash.AshFrame) -> None: try: - return await super()._send_frame(frame) + return await super()._send_data_frame(frame) except asyncio.TimeoutError: self._enter_ncp_error_state( t.NcpResetCode.ERROR_EXCEEDED_MAXIMUM_ACK_TIMEOUT_COUNT ) raise - if not isinstance(frame, ash.DataFrame): - # Non-DATA frames can be sent immediately and do not require an ACK - self._write_frame(frame) - return def send_reset(self) -> None: raise NotImplementedError() @@ -264,6 +265,62 @@ def test_buffer_growth(): assert len(protocol._buffer) == ash.MAX_BUFFER_SIZE +async def test_sequence(): + loop = asyncio.get_running_loop() + ezsp = MagicMock() + transport = MagicMock() + + protocol = ash.AshProtocol(ezsp) + protocol._write_frame = MagicMock(wraps=protocol._write_frame) + protocol.connection_made(transport) + + # Normal send/receive + loop.call_later( + 0, + protocol.frame_received, + ash.DataFrame(frm_num=0, re_tx=False, ack_num=1, ezsp_frame=b"rx 1"), + ) + await protocol.send_data(b"tx 1") + assert protocol._write_frame.mock_calls[-1] == call( + ash.AckFrame(res=0, ncp_ready=0, ack_num=1) + ) + + assert protocol._rx_seq == 1 + assert protocol._tx_seq == 1 + assert ezsp.data_received.mock_calls == [call(b"rx 1")] + + # Skip ACK 2: we are out of sync! + protocol.frame_received( + ash.DataFrame(frm_num=2, re_tx=False, ack_num=1, ezsp_frame=b"out of sequence") + ) + + # We NAK it, it is out of sequence! + assert protocol._write_frame.mock_calls[-1] == call( + ash.NakFrame(res=0, ncp_ready=0, ack_num=1) + ) + + # Sequence numbers remain intact + assert protocol._rx_seq == 1 + assert protocol._tx_seq == 1 + + # Re-sync properly + protocol.frame_received( + ash.DataFrame(frm_num=1, re_tx=False, ack_num=1, ezsp_frame=b"rx 2") + ) + + assert ezsp.data_received.mock_calls == [call(b"rx 1"), call(b"rx 2")] + + # Trigger an error + loop.call_later( + 0, + protocol.frame_received, + ash.ErrorFrame(version=2, reset_code=t.NcpResetCode.RESET_SOFTWARE), + ) + + with pytest.raises(ash.NcpFailure): + await protocol.send_data(b"tx 2") + + async def test_ash_protocol_startup(): """Simple EZSP startup: reset, version(4), then version(8).""" From 7d3735dff543fb90fcbe271ddc645d0dcd21ba20 Mon Sep 17 00:00:00 2001 From: puddly <32534428+puddly@users.noreply.github.com> Date: Tue, 30 Apr 2024 13:03:23 -0400 Subject: [PATCH 34/42] Move reserved bytes into an enum --- bellows/ash.py | 42 ++++++++++++++++++++++-------------------- 1 file changed, 22 insertions(+), 20 deletions(-) diff --git a/bellows/ash.py b/bellows/ash.py index 204075ce..3bcc1867 100644 --- a/bellows/ash.py +++ b/bellows/ash.py @@ -22,15 +22,17 @@ MAX_BUFFER_SIZE = 1024 -FLAG = b"\x7E" # Marks end of frame -ESCAPE = b"\x7D" -XON = b"\x11" # Resume transmission -XOFF = b"\x13" # Stop transmission -SUBSTITUTE = b"\x18" # Replaces a byte received with a low-level communication error -CANCEL = b"\x1A" # Terminates a frame in progress -RESERVED = frozenset(FLAG + ESCAPE + XON + XOFF + SUBSTITUTE + CANCEL) -RESERVED_WITHOUT_ESCAPE = RESERVED - frozenset([ESCAPE[0]]) +class Reserved(enum.IntEnum): + FLAG = 0x7E # Marks end of frame + ESCAPE = 0x7D + XON = 0x11 # Resume transmission + XOFF = 0x13 # Stop transmission + SUBSTITUTE = 0x18 # Replaces a byte received with a low-level communication error + CANCEL = 0x1A # Terminates a frame in progress + + +RESERVED_WITHOUT_ESCAPE = frozenset([v for v in Reserved if v != Reserved.ESCAPE]) # Initial value of t_rx_ack, the maximum time the NCP waits to receive acknowledgement # of a DATA frame @@ -361,8 +363,8 @@ def _stuff_bytes(data: bytes) -> bytes: out = bytearray() for c in data: - if c in RESERVED: - out.extend([ESCAPE[0], c ^ 0b00100000]) + if c in Reserved: + out.extend([Reserved.ESCAPE, c ^ 0b00100000]) else: out.append(c) @@ -377,10 +379,10 @@ def _unstuff_bytes(data: bytes) -> bytes: for c in data: if escaped: byte = c ^ 0b00100000 - assert byte in RESERVED + assert byte in Reserved out.append(byte) escaped = False - elif c == ESCAPE[0]: + elif c == Reserved.ESCAPE: escaped = True else: out.append(c) @@ -399,12 +401,12 @@ def data_received(self, data: bytes) -> None: while self._buffer: if self._discarding_until_next_flag: - if FLAG not in self._buffer: + if bytes([Reserved.FLAG]) not in self._buffer: self._buffer.clear() break self._discarding_until_next_flag = False - _, _, self._buffer = self._buffer.partition(FLAG) + _, _, self._buffer = self._buffer.partition(bytes([Reserved.FLAG])) try: # Find the index of the first reserved byte that isn't an escape byte @@ -416,7 +418,7 @@ def data_received(self, data: bytes) -> None: except StopIteration: break - if reserved_byte == FLAG[0]: + if reserved_byte == Reserved.FLAG: # Flag Byte marks the end of a frame frame_bytes = self._buffer[:reserved_index] self._buffer = self._buffer[reserved_index + 1 :] @@ -435,20 +437,20 @@ def data_received(self, data: bytes) -> None: ) else: self.frame_received(frame) - elif reserved_byte == CANCEL[0]: + elif reserved_byte == Reserved.CANCEL: _LOGGER.debug("Received cancel byte, clearing buffer") # All data received since the previous Flag Byte to be ignored self._buffer = self._buffer[reserved_index + 1 :] - elif reserved_byte == SUBSTITUTE[0]: + elif reserved_byte == Reserved.SUBSTITUTE: _LOGGER.debug("Received substitute byte, marking buffer as corrupted") # The data between the previous and the next Flag Byte is ignored self._discarding_until_next_flag = True self._buffer = self._buffer[reserved_index + 1 :] - elif reserved_byte == XON[0]: + elif reserved_byte == Reserved.XON: # Resume transmission: not implemented! _LOGGER.debug("Received XON byte, resuming transmission") self._buffer.pop(reserved_index) - elif reserved_byte == XOFF[0]: + elif reserved_byte == Reserved.XOFF: # Pause transmission: not implemented! _LOGGER.debug("Received XOFF byte, pausing transmission") self._buffer.pop(reserved_index) @@ -546,7 +548,7 @@ def error_frame_received(self, frame: ErrorFrame) -> None: def _write_frame(self, frame: AshFrame) -> None: _LOGGER.debug("Sending frame %r", frame) - data = self._stuff_bytes(frame.to_bytes()) + FLAG + data = self._stuff_bytes(frame.to_bytes()) + bytes([Reserved.FLAG]) _LOGGER.debug("Sending data %s", data.hex()) self._transport.write(data) From bb409101cc566f420c8c7026907303d2553785d2 Mon Sep 17 00:00:00 2001 From: puddly <32534428+puddly@users.noreply.github.com> Date: Tue, 30 Apr 2024 13:04:12 -0400 Subject: [PATCH 35/42] Send a `CANCEL` byte before the reset frame --- bellows/ash.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/bellows/ash.py b/bellows/ash.py index 3bcc1867..1adf1abc 100644 --- a/bellows/ash.py +++ b/bellows/ash.py @@ -546,9 +546,12 @@ def error_frame_received(self, frame: ErrorFrame) -> None: self._ezsp_protocol.reset_received(frame.reset_code) - def _write_frame(self, frame: AshFrame) -> None: + def _write_frame(self, frame: AshFrame, *, cancel: bool = False) -> None: _LOGGER.debug("Sending frame %r", frame) + data = self._stuff_bytes(frame.to_bytes()) + bytes([Reserved.FLAG]) + if cancel: + data = bytes([Reserved.CANCEL]) + data _LOGGER.debug("Sending data %s", data.hex()) self._transport.write(data) @@ -647,4 +650,4 @@ async def send_data(self, data: bytes) -> None: ) def send_reset(self) -> None: - self._write_frame(RstFrame()) + self._write_frame(RstFrame(), cancel=True) From d221d5984bbad5f7f3cb168dfbab5c3c2023b9ed Mon Sep 17 00:00:00 2001 From: puddly <32534428+puddly@users.noreply.github.com> Date: Tue, 30 Apr 2024 13:17:31 -0400 Subject: [PATCH 36/42] Improve logging --- bellows/ash.py | 22 ++++++++++++++-------- 1 file changed, 14 insertions(+), 8 deletions(-) diff --git a/bellows/ash.py b/bellows/ash.py index 1adf1abc..c3ad00e5 100644 --- a/bellows/ash.py +++ b/bellows/ash.py @@ -546,13 +546,19 @@ def error_frame_received(self, frame: ErrorFrame) -> None: self._ezsp_protocol.reset_received(frame.reset_code) - def _write_frame(self, frame: AshFrame, *, cancel: bool = False) -> None: - _LOGGER.debug("Sending frame %r", frame) - - data = self._stuff_bytes(frame.to_bytes()) + bytes([Reserved.FLAG]) - if cancel: - data = bytes([Reserved.CANCEL]) + data - + def _write_frame( + self, + frame: AshFrame, + *, + prefix: tuple[Reserved] = (), + suffix: tuple[Reserved] = (Reserved.FLAG,), + ) -> None: + if _LOGGER.isEnabledFor(logging.DEBUG): + prefix_str = "".join([f"{r.name} + " for r in prefix]) + suffix_str = "".join([f" + {r.name}" for r in suffix]) + _LOGGER.debug("Sending frame %s%r%s", prefix_str, frame, suffix_str) + + data = bytes(prefix) + self._stuff_bytes(frame.to_bytes()) + bytes(suffix) _LOGGER.debug("Sending data %s", data.hex()) self._transport.write(data) @@ -650,4 +656,4 @@ async def send_data(self, data: bytes) -> None: ) def send_reset(self) -> None: - self._write_frame(RstFrame(), cancel=True) + self._write_frame(RstFrame(), prefix=(Reserved.CANCEL,)) From 547d974084eb8667284759bc9890fa479f051cd4 Mon Sep 17 00:00:00 2001 From: puddly <32534428+puddly@users.noreply.github.com> Date: Tue, 30 Apr 2024 13:27:09 -0400 Subject: [PATCH 37/42] Fix unit tests --- tests/test_ash.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/tests/test_ash.py b/tests/test_ash.py index 33907e79..fd3ad5ca 100644 --- a/tests/test_ash.py +++ b/tests/test_ash.py @@ -349,7 +349,9 @@ async def test_ash_protocol_startup(): await asyncio.sleep(0.01) assert ezsp.reset_received.mock_calls == [call(t.NcpResetCode.RESET_SOFTWARE)] - assert protocol._write_frame.mock_calls == [call(ash.RstFrame())] + assert protocol._write_frame.mock_calls == [ + call(ash.RstFrame(), prefix=(ash.Reserved.CANCEL,)) + ] protocol._write_frame.reset_mock() From 58b41bd35feb5656534a0fab0ab8b5a53f52010b Mon Sep 17 00:00:00 2001 From: puddly <32534428+puddly@users.noreply.github.com> Date: Tue, 30 Apr 2024 13:37:02 -0400 Subject: [PATCH 38/42] Ensure codebase works with 3.8 --- bellows/ash.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/bellows/ash.py b/bellows/ash.py index c3ad00e5..527376e4 100644 --- a/bellows/ash.py +++ b/bellows/ash.py @@ -32,6 +32,7 @@ class Reserved(enum.IntEnum): CANCEL = 0x1A # Terminates a frame in progress +RESERVED_BYTES = frozenset(Reserved) RESERVED_WITHOUT_ESCAPE = frozenset([v for v in Reserved if v != Reserved.ESCAPE]) # Initial value of t_rx_ack, the maximum time the NCP waits to receive acknowledgement @@ -363,7 +364,7 @@ def _stuff_bytes(data: bytes) -> bytes: out = bytearray() for c in data: - if c in Reserved: + if c in RESERVED_BYTES: out.extend([Reserved.ESCAPE, c ^ 0b00100000]) else: out.append(c) @@ -379,7 +380,7 @@ def _unstuff_bytes(data: bytes) -> bytes: for c in data: if escaped: byte = c ^ 0b00100000 - assert byte in Reserved + assert byte in RESERVED_BYTES out.append(byte) escaped = False elif c == Reserved.ESCAPE: From 5c9bb44690d6c964d4b9fee0e3685d9951e5ac37 Mon Sep 17 00:00:00 2001 From: puddly <32534428+puddly@users.noreply.github.com> Date: Tue, 30 Apr 2024 17:21:22 -0400 Subject: [PATCH 39/42] Almost at 100% coverage --- bellows/ash.py | 4 +- tests/test_ash.py | 105 ++++++++++++++++++++++++++++++++++++++++++---- 2 files changed, 98 insertions(+), 11 deletions(-) diff --git a/bellows/ash.py b/bellows/ash.py index 527376e4..3c2a4b4f 100644 --- a/bellows/ash.py +++ b/bellows/ash.py @@ -467,9 +467,7 @@ def _handle_ack(self, frame: DataFrame | AckFrame) -> None: fut = self._pending_data_frames.get(ack_num) - if fut is None: - return - elif fut.done(): + if fut is None or fut.done(): return # _LOGGER.debug("Resolving frame %d", ack_num) diff --git a/tests/test_ash.py b/tests/test_ash.py index fd3ad5ca..4bc3f0e2 100644 --- a/tests/test_ash.py +++ b/tests/test_ash.py @@ -1,6 +1,7 @@ from __future__ import annotations import asyncio +import logging import random from unittest.mock import MagicMock, call, patch @@ -89,19 +90,19 @@ def __init__(self, receiver): self.receiver = receiver self.paused = False - def write(self, data): + def write(self, data: bytes) -> None: if not self.paused: self.receiver.data_received(data) class FakeTransportOneByteAtATime(FakeTransport): - def write(self, data): + def write(self, data: bytes) -> None: for byte in data: super().write(bytes([byte])) class FakeTransportRandomLoss(FakeTransport): - def write(self, data): + def write(self, data: bytes) -> None: if random.random() < 0.25: return @@ -109,6 +110,54 @@ def write(self, data): super().write(bytes([byte])) +class FakeTransportWithDelays(FakeTransport): + def write(self, data): + asyncio.get_running_loop().call_later(0, super().write, data) + + +def test_ash_exception_repr() -> None: + assert ( + repr(ash.NotAcked(ash.NakFrame(res=0, ncp_ready=0, ack_num=1))) + == "" + ) + assert ( + repr(ash.NcpFailure(t.NcpResetCode.RESET_SOFTWARE)) + == ")>" + ) + + +@pytest.mark.parametrize( + "frame", + [ + ash.RstFrame(), + ash.AckFrame(res=0, ncp_ready=0, ack_num=1), + ash.NakFrame(res=0, ncp_ready=0, ack_num=1), + ash.RStackFrame(version=2, reset_code=t.NcpResetCode.RESET_SOFTWARE), + ash.ErrorFrame(version=2, reset_code=t.NcpResetCode.RESET_SOFTWARE), + ash.DataFrame(frm_num=0, re_tx=False, ack_num=1, ezsp_frame=b"test"), + ], +) +def test_parse_frame(frame: ash.AshFrame) -> None: + assert ash.parse_frame(frame.to_bytes()) == frame + + +def test_parse_frame_failure() -> None: + with pytest.raises(ash.ParsingError): + ash.parse_frame(b"test") + + +def test_ash_protocol_event_propagation() -> None: + ezsp = MagicMock() + protocol = ash.AshProtocol(ezsp) + + err = RuntimeError("test") + protocol.connection_lost(err) + assert ezsp.connection_lost.mock_calls == [call(err)] + + protocol.eof_received() + assert ezsp.eof_received.mock_calls == [call()] + + def test_stuffing(): assert ash.AshProtocol._stuff_bytes(b"\x7E") == b"\x7D\x5E" assert ash.AshProtocol._stuff_bytes(b"\x11") == b"\x7D\x31" @@ -321,9 +370,41 @@ async def test_sequence(): await protocol.send_data(b"tx 2") -async def test_ash_protocol_startup(): +async def test_frame_parsing_failure_recovery(caplog) -> None: + ezsp = MagicMock() + protocol = ash.AshProtocol(ezsp) + protocol.frame_received = MagicMock(spec_set=protocol.frame_received) + + protocol.data_received( + ash.DataFrame(frm_num=0, re_tx=0, ack_num=0, ezsp_frame=b"frame 1").to_bytes() + + bytes([ash.Reserved.FLAG]) + ) + + with caplog.at_level(logging.DEBUG): + protocol.data_received( + ash.AshFrame.append_crc(b"\xFESome unknown frame") + + bytes([ash.Reserved.FLAG]) + ) + + assert "Some unknown frame" in caplog.text + + protocol.data_received( + ash.DataFrame(frm_num=1, re_tx=0, ack_num=0, ezsp_frame=b"frame 2").to_bytes() + + bytes([ash.Reserved.FLAG]) + ) + + assert protocol.frame_received.mock_calls == [ + call(ash.DataFrame(frm_num=0, re_tx=0, ack_num=0, ezsp_frame=b"frame 1")), + call(ash.DataFrame(frm_num=1, re_tx=0, ack_num=0, ezsp_frame=b"frame 2")), + ] + + +async def test_ash_protocol_startup(caplog): """Simple EZSP startup: reset, version(4), then version(8).""" + # We have branching dependent on `_LOGGER.isEnabledFor` so test it here + caplog.set_level(logging.DEBUG) + loop = asyncio.get_running_loop() ezsp = MagicMock() @@ -405,7 +486,12 @@ async def test_ash_protocol_startup(): @patch("bellows.ash.T_RX_ACK_MAX", ash.T_RX_ACK_MAX / 100) @pytest.mark.parametrize( "transport_cls", - [FakeTransport, FakeTransportOneByteAtATime, FakeTransportRandomLoss], + [ + FakeTransport, + FakeTransportOneByteAtATime, + FakeTransportRandomLoss, + FakeTransportWithDelays, + ], ) async def test_ash_end_to_end(transport_cls: type[FakeTransport]) -> None: asyncio.get_running_loop() @@ -423,8 +509,11 @@ async def test_ash_end_to_end(transport_cls: type[FakeTransport]) -> None: ncp.connection_made(ncp_transport) # Ping pong works - await host.send_data(b"Hello!") - assert ncp_ezsp.data_received.mock_calls == [call(b"Hello!")] + await asyncio.gather( + host.send_data(b"Hello 1!"), + host.send_data(b"Hello 2!"), + ) + assert ncp_ezsp.data_received.mock_calls == [call(b"Hello 1!"), call(b"Hello 2!")] await ncp.send_data(b"World!") assert host_ezsp.data_received.mock_calls == [call(b"World!")] @@ -467,7 +556,7 @@ async def test_ash_end_to_end(transport_cls: type[FakeTransport]) -> None: ncp_ezsp.data_received.reset_mock() host_ezsp.data_received.reset_mock() - # When the NCP fail to receive a reply, it enters a failed state + # When the NCP fails to receive a reply, it enters a failed state assert host._ncp_reset_code is None assert ncp._ncp_reset_code is None From 61658a574aff1ab900fc6de22911ac8615826298 Mon Sep 17 00:00:00 2001 From: puddly <32534428+puddly@users.noreply.github.com> Date: Thu, 2 May 2024 14:41:34 -0400 Subject: [PATCH 40/42] Unit test UART callbacks --- tests/test_uart.py | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/tests/test_uart.py b/tests/test_uart.py index 16cb8930..68ac8664 100644 --- a/tests/test_uart.py +++ b/tests/test_uart.py @@ -8,7 +8,7 @@ from bellows import uart import bellows.types as t -from .async_mock import AsyncMock, MagicMock, patch, sentinel +from .async_mock import AsyncMock, MagicMock, call, patch, sentinel @pytest.mark.parametrize("flow_control", ["software", "hardware"]) @@ -239,3 +239,13 @@ async def test_wait_for_startup_reset_failure(gw): await asyncio.wait_for(gw.wait_for_startup_reset(), 0.01) assert gw._startup_reset_future is None + + +async def test_callbacks(gw): + gw.data_received(b"some ezsp packet") + assert gw._application.frame_received.mock_calls == [call(b"some ezsp packet")] + + gw.error_received(t.NcpResetCode.RESET_SOFTWARE) + assert gw._application.enter_failed_state.mock_calls == [ + call(t.NcpResetCode.RESET_SOFTWARE) + ] From 20378f61bf942e6fa7e0a197820884cdb9f746ae Mon Sep 17 00:00:00 2001 From: puddly <32534428+puddly@users.noreply.github.com> Date: Thu, 2 May 2024 16:31:06 -0400 Subject: [PATCH 41/42] Get coverage up to 100% --- tests/test_ash.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/tests/test_ash.py b/tests/test_ash.py index 4bc3f0e2..690e21a0 100644 --- a/tests/test_ash.py +++ b/tests/test_ash.py @@ -103,7 +103,7 @@ def write(self, data: bytes) -> None: class FakeTransportRandomLoss(FakeTransport): def write(self, data: bytes) -> None: - if random.random() < 0.25: + if random.random() < 0.20: return for byte in data: @@ -584,3 +584,8 @@ async def test_ash_end_to_end(transport_cls: type[FakeTransport]) -> None: host.send_reset() await asyncio.sleep(0.01) await host.send_data(b"test") + + # Trigger a failure caused by excessive NAKs + with patch.object(ncp, "nak_state", True): + with pytest.raises(ash.NotAcked): + await host.send_data(b"ncp NAKing until failure") From 087d68256884fa95c34d23260da894460b1c0568 Mon Sep 17 00:00:00 2001 From: puddly <32534428+puddly@users.noreply.github.com> Date: Thu, 2 May 2024 16:36:57 -0400 Subject: [PATCH 42/42] Make tests less flaky --- tests/test_ash.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/tests/test_ash.py b/tests/test_ash.py index 690e21a0..41000ce3 100644 --- a/tests/test_ash.py +++ b/tests/test_ash.py @@ -106,8 +106,7 @@ def write(self, data: bytes) -> None: if random.random() < 0.20: return - for byte in data: - super().write(bytes([byte])) + super().write(data) class FakeTransportWithDelays(FakeTransport): @@ -586,6 +585,9 @@ async def test_ash_end_to_end(transport_cls: type[FakeTransport]) -> None: await host.send_data(b"test") # Trigger a failure caused by excessive NAKs + ncp._t_rx_ack = ash.T_RX_ACK_INIT / 1000 + host._t_rx_ack = ash.T_RX_ACK_INIT / 1000 + with patch.object(ncp, "nak_state", True): with pytest.raises(ash.NotAcked): await host.send_data(b"ncp NAKing until failure")