diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index bededc22..4953778d 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -18,11 +18,12 @@ repos: - --quiet - repo: https://github.com/pycqa/flake8 - rev: 6.0.0 + rev: 7.0.0 hooks: - id: flake8 + entry: pflake8 additional_dependencies: - - Flake8-pyproject==1.2.3 + - pyproject-flake8==7.0.0 - repo: https://github.com/PyCQA/isort rev: 5.12.0 diff --git a/bellows/ash.py b/bellows/ash.py new file mode 100644 index 00000000..3c2a4b4f --- /dev/null +++ b/bellows/ash.py @@ -0,0 +1,658 @@ +from __future__ import annotations + +import abc +import asyncio +import binascii +import dataclasses +import enum +import logging +import sys +import time + +if sys.version_info[:2] < (3, 11): + from async_timeout import timeout as asyncio_timeout # pragma: no cover +else: + from asyncio import timeout as asyncio_timeout # pragma: no cover + +from zigpy.types import BaseDataclassMixin + +import bellows.types as t + +_LOGGER = logging.getLogger(__name__) + +MAX_BUFFER_SIZE = 1024 + + +class Reserved(enum.IntEnum): + FLAG = 0x7E # Marks end of frame + ESCAPE = 0x7D + XON = 0x11 # Resume transmission + XOFF = 0x13 # Stop transmission + SUBSTITUTE = 0x18 # Replaces a byte received with a low-level communication error + CANCEL = 0x1A # Terminates a frame in progress + + +RESERVED_BYTES = frozenset(Reserved) +RESERVED_WITHOUT_ESCAPE = frozenset([v for v in Reserved if v != Reserved.ESCAPE]) + +# Initial value of t_rx_ack, the maximum time the NCP waits to receive acknowledgement +# of a DATA frame +T_RX_ACK_INIT = 1.6 + +# Minimum value of t_rx_ack +T_RX_ACK_MIN = 0.4 + +# Maximum value of t_rx_ack +T_RX_ACK_MAX = 3.2 + +# Delay before sending a non-piggybacked acknowledgement +T_TX_ACK_DELAY = 0.02 + +# Time from receiving an ACK or NAK with the nRdy flag set after which the NCP resumes +# sending callback frames to the host without requiring an ACK or NAK with the nRdy +# flag clear +T_REMOTE_NOTRDY = 1.0 + 
# Maximum number of DATA frames the NCP can transmit without having received
# acknowledgements
TX_K = 1

# Maximum number of consecutive timeouts allowed while waiting to receive an ACK before
# going to the FAILED state. The value 0 prevents the NCP from entering the error state
# due to timeouts.
ACK_TIMEOUTS = 4


def generate_random_sequence(length: int) -> bytes:
    """Return the first `length` bytes of the pseudo-random sequence.

    The sequence starts at 0x42. Each following byte is the previous one shifted
    right by one bit, XORed with 0xB8 whenever the bit shifted out was set.
    A `length` of zero (or less) yields an empty result.
    """
    output = bytearray()
    rand = 0x42

    for _ in range(length):
        output.append(rand)

        if rand & 0b00000001 == 0:
            rand = rand >> 1
        else:
            rand = (rand >> 1) ^ 0xB8

    # Freeze the result so that it matches the annotated return type and so that the
    # module-level sequence built from it cannot be mutated by accident
    return bytes(output)


# Since the sequence is static for every frame, we only need to generate it once
PSEUDO_RANDOM_DATA_SEQUENCE = generate_random_sequence(256)


class NcpState(enum.Enum):
    """Connectivity state of the NCP as tracked by this module."""

    CONNECTED = "connected"
    FAILED = "failed"


class ParsingError(Exception):
    """Raised when received bytes cannot be parsed into a frame."""


class AshException(Exception):
    """Base class of the ASH-level errors defined in this module."""


class NotAcked(AshException):
    """A NAK frame was received; the offending frame is kept in `frame`."""

    def __init__(self, frame: NakFrame) -> None:
        # Pass the payload up so that `args` and `str()` are populated as well
        super().__init__(frame)
        self.frame = frame

    def __repr__(self) -> str:
        return f"<{self.__class__.__name__}(frame={self.frame})>"


class NcpFailure(AshException):
    """The NCP is in a failed state; the reported reset code is kept in `code`."""

    def __init__(self, code: t.NcpResetCode) -> None:
        # Pass the payload up so that `args` and `str()` are populated as well
        super().__init__(code)
        self.code = code

    def __repr__(self) -> str:
        return f"<{self.__class__.__name__}(code={self.code})>"
class AshFrame(abc.ABC, BaseDataclassMixin):
    """Base class of all ASH frames: CRC helpers plus the (de)serialization interface."""

    # Bits of the control byte that identify the frame type, and the value those bits
    # must have for a frame to be of this type (see `parse_frame`)
    MASK: t.uint8_t
    MASK_VALUE: t.uint8_t

    @classmethod
    def from_bytes(cls, data: bytes) -> AshFrame:
        """Parse an unstuffed frame (control byte, payload, CRC)."""
        raise NotImplementedError()

    def to_bytes(self) -> bytes:
        """Serialize the frame (control byte, payload, CRC), without byte stuffing."""
        raise NotImplementedError()

    @classmethod
    def _unwrap(cls, data: bytes) -> tuple[int, bytes]:
        """Validate the trailing CRC and split `data` into (control byte, payload).

        Raises `ParsingError` if the frame is shorter than three bytes or if the CRC
        does not match.
        """
        if len(data) < 3:
            raise ParsingError(f"Frame is too short: {data!r}")

        computed_crc = binascii.crc_hqx(data[:-2], 0xFFFF).to_bytes(2, "big")

        if computed_crc != data[-2:]:
            raise ParsingError(
                f"Invalid CRC bytes in frame {data!r}:"
                f" expected {computed_crc.hex()}, got {data[-2:].hex()}"
            )

        return data[0], data[1:-2]

    @staticmethod
    def append_crc(data: bytes) -> bytes:
        """Return `data` followed by its big-endian 16-bit CRC."""
        return data + binascii.crc_hqx(data, 0xFFFF).to_bytes(2, "big")


@dataclasses.dataclass(frozen=True)
class DataFrame(AshFrame):
    """DATA frame: carries an EZSP frame plus sequence and acknowledgement numbers."""

    MASK = 0b10000000
    MASK_VALUE = 0b00000000

    frm_num: int
    re_tx: bool
    ack_num: int
    ezsp_frame: bytes

    @staticmethod
    def _randomize(data: bytes) -> bytes:
        """XOR `data` with the pseudo-random sequence (the operation is its own inverse).

        Raises `ValueError` if `data` is longer than the sequence. An `assert` is not
        enough here: with assertions disabled `zip` would silently truncate the
        payload.
        """
        if len(data) > len(PSEUDO_RANDOM_DATA_SEQUENCE):
            raise ValueError(
                f"Payload of {len(data)} bytes is longer than the"
                f" {len(PSEUDO_RANDOM_DATA_SEQUENCE)} byte randomization sequence"
            )

        return bytes(a ^ b for a, b in zip(data, PSEUDO_RANDOM_DATA_SEQUENCE))

    @classmethod
    def from_bytes(cls, data: bytes) -> DataFrame:
        control, data = cls._unwrap(data)

        return cls(
            frm_num=(control & 0b01110000) >> 4,
            re_tx=(control & 0b00001000) >> 3,
            ack_num=(control & 0b00000111) >> 0,
            ezsp_frame=cls._randomize(data),
        )

    def to_bytes(self, *, randomize: bool = True) -> bytes:
        """Serialize the frame.

        The payload is XORed with the pseudo-random sequence unless `randomize` is
        false, in which case it is emitted as-is.
        """
        control = bytes(
            [
                self.MASK_VALUE
                | (self.frm_num) << 4
                | (self.re_tx) << 3
                | (self.ack_num) << 0
            ]
        )
        payload = self._randomize(self.ezsp_frame) if randomize else self.ezsp_frame

        return self.append_crc(control + payload)


@dataclasses.dataclass(frozen=True)
class AckFrame(AshFrame):
    """ACK frame: acknowledges received DATA frames."""

    MASK = 0b11100000
    MASK_VALUE = 0b10000000

    res: int
    ncp_ready: bool
    ack_num: int

    @classmethod
    def from_bytes(cls, data: bytes) -> AckFrame:
        control, data = cls._unwrap(data)

        return cls(
            res=(control & 0b00010000) >> 4,
            ncp_ready=(control & 0b00001000) >> 3,
            ack_num=(control & 0b00000111) >> 0,
        )

    def to_bytes(self) -> bytes:
        return self.append_crc(
            bytes(
                [
                    self.MASK_VALUE
                    | (self.res) << 4
                    | (self.ncp_ready) << 3
                    | (self.ack_num) << 0
                ]
            )
        )


@dataclasses.dataclass(frozen=True)
class NakFrame(AshFrame):
    """NAK frame: same layout as an ACK frame, with a different type value."""

    MASK = 0b11100000
    MASK_VALUE = 0b10100000

    res: int
    ncp_ready: bool
    ack_num: int

    @classmethod
    def from_bytes(cls, data: bytes) -> NakFrame:
        control, data = cls._unwrap(data)

        return cls(
            res=(control & 0b00010000) >> 4,
            ncp_ready=(control & 0b00001000) >> 3,
            ack_num=(control & 0b00000111) >> 0,
        )

    def to_bytes(self) -> bytes:
        return self.append_crc(
            bytes(
                [
                    self.MASK_VALUE
                    | (self.res) << 4
                    | (self.ncp_ready) << 3
                    | (self.ack_num) << 0
                ]
            )
        )


@dataclasses.dataclass(frozen=True)
class RstFrame(AshFrame):
    """RST frame: a bare control byte without any payload."""

    MASK = 0b11111111
    MASK_VALUE = 0b11000000

    @classmethod
    def from_bytes(cls, data: bytes) -> RstFrame:
        control, data = cls._unwrap(data)

        if data:
            raise ParsingError(f"Invalid data for RST frame: {data!r}")

        return cls()

    def to_bytes(self) -> bytes:
        return self.append_crc(bytes([self.MASK_VALUE]))


@dataclasses.dataclass(frozen=True)
class RStackFrame(AshFrame):
    """RSTACK frame: carries a version byte and a reset code."""

    MASK = 0b11111111
    MASK_VALUE = 0b11000001

    version: t.uint8_t
    reset_code: t.NcpResetCode

    @classmethod
    def from_bytes(cls, data: bytes) -> RStackFrame:
        control, data = cls._unwrap(data)

        if len(data) != 2:
            raise ParsingError(f"Invalid data length for RSTACK frame: {data!r}")

        version = data[0]

        # Only version 2 is accepted
        if version != 0x02:
            raise ParsingError(f"Invalid version for RSTACK frame: {data!r}")

        reset_code = t.NcpResetCode(data[1])

        return cls(
            version=version,
            reset_code=reset_code,
        )

    def to_bytes(self) -> bytes:
        return self.append_crc(bytes([self.MASK_VALUE, self.version, self.reset_code]))


@dataclasses.dataclass(frozen=True)
class ErrorFrame(AshFrame):
    """ERROR frame: same payload layout as RSTACK, with a different control byte."""

    MASK = 0b11111111
    MASK_VALUE = 0b11000010

    version: t.uint8_t
    reset_code: t.NcpResetCode

    # We do not want to inherit from `RStackFrame`
    from_bytes = classmethod(RStackFrame.from_bytes.__func__)
    to_bytes = RStackFrame.to_bytes


def parse_frame(
    data: bytes,
) -> DataFrame | AckFrame | NakFrame | RstFrame | RStackFrame | ErrorFrame:
    """Parse a frame from the given data, looking at the control byte.

    Raises `ParsingError` if `data` is empty, if no frame type matches the control
    byte, or if the matching frame type rejects the data.
    """
    if not data:
        # Without this guard an empty input would surface as an `IndexError`
        raise ParsingError(f"Frame is empty: {data!r}")

    control_byte = data[0]

    # In order of use
    for frame in (
        DataFrame,
        AckFrame,
        NakFrame,
        RstFrame,
        RStackFrame,
        ErrorFrame,
    ):
        if control_byte & frame.MASK == frame.MASK_VALUE:
            return frame.from_bytes(data)

    raise ParsingError(f"Could not determine frame type: {data!r}")
class AshProtocol(asyncio.Protocol):
    """ASH layer sitting between the serial transport and the EZSP protocol object.

    Incoming bytes are split into frames, parsed and dispatched; outgoing EZSP frames
    are wrapped into DATA frames and re-sent until acknowledged.
    """

    def __init__(self, ezsp_protocol) -> None:
        self._ezsp_protocol = ezsp_protocol
        self._transport = None
        self._buffer = bytearray()
        self._discarding_until_next_flag: bool = False
        # Futures of the DATA frames awaiting acknowledgement, keyed by frame number
        self._pending_data_frames: dict[int, asyncio.Future] = {}
        self._send_data_frame_semaphore = asyncio.Semaphore(TX_K)
        self._tx_seq: int = 0
        self._rx_seq: int = 0
        self._t_rx_ack = T_RX_ACK_INIT

        self._ncp_reset_code: t.NcpResetCode | None = None
        self._ncp_state: NcpState = NcpState.CONNECTED

    def connection_made(self, transport):
        self._transport = transport
        # The EZSP protocol object is handed this object as its transport
        self._ezsp_protocol.connection_made(self)

    def connection_lost(self, exc):
        self._ezsp_protocol.connection_lost(exc)

    def eof_received(self):
        self._ezsp_protocol.eof_received()

    def close(self):
        if self._transport is not None:
            self._transport.close()

    @staticmethod
    def _stuff_bytes(data: bytes) -> bytes:
        """Stuff bytes for transmission"""
        out = bytearray()

        for c in data:
            if c in RESERVED_BYTES:
                out.extend([Reserved.ESCAPE, c ^ 0b00100000])
            else:
                out.append(c)

        return out

    @staticmethod
    def _unstuff_bytes(data: bytes) -> bytes:
        """Unstuff bytes after receipt

        Raises `ParsingError` if an escape byte is followed by a byte that does not
        decode to a reserved byte. This is received data, so it has to be rejected
        explicitly: an `assert` would disappear with assertions disabled.
        """
        out = bytearray()
        escaped = False

        for c in data:
            if escaped:
                byte = c ^ 0b00100000

                if byte not in RESERVED_BYTES:
                    raise ParsingError(
                        f"Invalid escaped byte 0x{c:02X} in frame: {bytes(data)!r}"
                    )

                out.append(byte)
                escaped = False
            elif c == Reserved.ESCAPE:
                escaped = True
            else:
                out.append(c)

        return out

    def data_received(self, data: bytes) -> None:
        """Buffer incoming bytes and dispatch every complete frame found in them."""
        _LOGGER.debug("Received data %s", data.hex())
        self._buffer.extend(data)

        if len(self._buffer) > MAX_BUFFER_SIZE:
            _LOGGER.debug(
                "Truncating buffer to %s bytes, it is growing too fast", MAX_BUFFER_SIZE
            )
            # Keep the newest bytes. Keeping the oldest ones would throw away all
            # future input once the buffer is full of bytes that contain no reserved
            # byte, including the Flag Byte needed to get back in sync.
            self._buffer = self._buffer[-MAX_BUFFER_SIZE:]

        while self._buffer:
            if self._discarding_until_next_flag:
                if bytes([Reserved.FLAG]) not in self._buffer:
                    self._buffer.clear()
                    break

                self._discarding_until_next_flag = False
                _, _, self._buffer = self._buffer.partition(bytes([Reserved.FLAG]))

            try:
                # Find the index of the first reserved byte that isn't an escape byte
                reserved_index, reserved_byte = next(
                    (index, byte)
                    for index, byte in enumerate(self._buffer)
                    if byte in RESERVED_WITHOUT_ESCAPE
                )
            except StopIteration:
                break

            if reserved_byte == Reserved.FLAG:
                # Flag Byte marks the end of a frame
                frame_bytes = self._buffer[:reserved_index]
                self._buffer = self._buffer[reserved_index + 1 :]

                # Consecutive EOFs can be received, empty frames are ignored
                if not frame_bytes:
                    continue

                try:
                    # Unstuffing can fail on corrupted input as well, so it has to
                    # be covered by the same handler as parsing
                    frame = parse_frame(self._unstuff_bytes(frame_bytes))
                except Exception:
                    _LOGGER.debug(
                        "Failed to parse frame %r", frame_bytes, exc_info=True
                    )
                else:
                    self.frame_received(frame)
            elif reserved_byte == Reserved.CANCEL:
                _LOGGER.debug("Received cancel byte, clearing buffer")
                # All data received since the previous Flag Byte to be ignored
                self._buffer = self._buffer[reserved_index + 1 :]
            elif reserved_byte == Reserved.SUBSTITUTE:
                _LOGGER.debug("Received substitute byte, marking buffer as corrupted")
                # The data between the previous and the next Flag Byte is ignored
                self._discarding_until_next_flag = True
                self._buffer = self._buffer[reserved_index + 1 :]
            elif reserved_byte == Reserved.XON:
                # Resume transmission: not implemented!
                _LOGGER.debug("Received XON byte, resuming transmission")
                self._buffer.pop(reserved_index)
            elif reserved_byte == Reserved.XOFF:
                # Pause transmission: not implemented!
                _LOGGER.debug("Received XOFF byte, pausing transmission")
                self._buffer.pop(reserved_index)
            else:
                raise RuntimeError(
                    f"Unexpected reserved byte found: 0x{reserved_byte:02X}"
                )  # pragma: no cover

    def _handle_ack(self, frame: DataFrame | AckFrame) -> None:
        """Resolve the pending DATA frame acknowledged by `frame`, if there is one."""
        # Note that ackNum is the number of the next frame the receiver expects and it
        # is one greater than the last frame received.
        ack_num = (frame.ack_num - 1) % 8

        fut = self._pending_data_frames.get(ack_num)

        if fut is None or fut.done():
            return

        # _LOGGER.debug("Resolving frame %d", ack_num)
        fut.set_result(True)

    def frame_received(self, frame: AshFrame) -> None:
        """Dispatch a parsed frame to the handler for its type."""
        _LOGGER.debug("Received frame %r", frame)

        if isinstance(frame, DataFrame):
            self.data_frame_received(frame)
        elif isinstance(frame, RStackFrame):
            self.rstack_frame_received(frame)
        elif isinstance(frame, AckFrame):
            self.ack_frame_received(frame)
        elif isinstance(frame, NakFrame):
            self.nak_frame_received(frame)
        elif isinstance(frame, RstFrame):
            self.rst_frame_received(frame)
        elif isinstance(frame, ErrorFrame):
            self.error_frame_received(frame)
        else:
            raise TypeError(f"Unknown frame received: {frame}")  # pragma: no cover

    def data_frame_received(self, frame: DataFrame) -> None:
        # The Host may not piggyback acknowledgments and should promptly send an ACK
        # frame when it receives a DATA frame.
        if frame.frm_num == self._rx_seq:
            self._handle_ack(frame)
            self._rx_seq = (frame.frm_num + 1) % 8
            self._write_frame(AckFrame(res=0, ncp_ready=0, ack_num=self._rx_seq))

            self._ezsp_protocol.data_received(frame.ezsp_frame)
        elif frame.re_tx:
            # Retransmitted frames must be immediately ACKed even if they are out of
            # sequence
            self._write_frame(AckFrame(res=0, ncp_ready=0, ack_num=self._rx_seq))
        else:
            _LOGGER.debug("Received an out of sequence frame: %r", frame)
            self._write_frame(NakFrame(res=0, ncp_ready=0, ack_num=self._rx_seq))

    def rstack_frame_received(self, frame: RStackFrame) -> None:
        """Reset the link state and report the reset code to the EZSP protocol."""
        self._ncp_reset_code = None
        self._ncp_state = NcpState.CONNECTED

        self._tx_seq = 0
        self._rx_seq = 0
        self._change_ack_timeout(T_RX_ACK_INIT)
        self._ezsp_protocol.reset_received(frame.reset_code)

    def ack_frame_received(self, frame: AckFrame) -> None:
        self._handle_ack(frame)

    def nak_frame_received(self, frame: NakFrame) -> None:
        """Fail every DATA frame that is still awaiting acknowledgement."""
        err = NotAcked(frame=frame)

        for fut in self._pending_data_frames.values():
            if not fut.done():
                fut.set_exception(err)

    def rst_frame_received(self, frame: RstFrame) -> None:
        self._ncp_reset_code = None
        self._ncp_state = NcpState.CONNECTED

    def error_frame_received(self, frame: ErrorFrame) -> None:
        """Enter the failed state and fail every pending DATA frame."""
        _LOGGER.debug("NCP has entered failed state: %s", frame.reset_code)
        self._ncp_reset_code = frame.reset_code
        self._ncp_state = NcpState.FAILED

        # Cancel all pending requests
        exc = NcpFailure(code=self._ncp_reset_code)

        for fut in self._pending_data_frames.values():
            if not fut.done():
                fut.set_exception(exc)

        self._ezsp_protocol.reset_received(frame.reset_code)

    def _write_frame(
        self,
        frame: AshFrame,
        *,
        prefix: tuple[Reserved, ...] = (),
        suffix: tuple[Reserved, ...] = (Reserved.FLAG,),
    ) -> None:
        """Stuff `frame` and write it to the transport between `prefix` and `suffix`."""
        if _LOGGER.isEnabledFor(logging.DEBUG):
            prefix_str = "".join([f"{r.name} + " for r in prefix])
            suffix_str = "".join([f" + {r.name}" for r in suffix])
            _LOGGER.debug("Sending frame %s%r%s", prefix_str, frame, suffix_str)

        data = bytes(prefix) + self._stuff_bytes(frame.to_bytes()) + bytes(suffix)
        _LOGGER.debug("Sending data %s", data.hex())
        self._transport.write(data)

    def _change_ack_timeout(self, new_value: float) -> None:
        """Set the ACK timeout, clamped to [T_RX_ACK_MIN, T_RX_ACK_MAX]."""
        new_value = max(T_RX_ACK_MIN, min(new_value, T_RX_ACK_MAX))

        # Only noticeable changes are logged
        if abs(new_value - self._t_rx_ack) > 0.01:
            _LOGGER.debug(
                "Changing ACK timeout from %0.2f to %0.2f", self._t_rx_ack, new_value
            )

        self._t_rx_ack = new_value

    async def _send_data_frame(self, frame: AshFrame) -> None:
        """Send a DATA frame, re-sending it until it is acknowledged.

        At most `ACK_TIMEOUTS` attempts are made. Raises `NcpFailure` if the NCP is
        (or enters) in a failed state, and re-raises the `NotAcked` or timeout error
        of the last attempt.
        """
        if self._send_data_frame_semaphore.locked():
            _LOGGER.debug("Semaphore is locked, waiting")

        async with self._send_data_frame_semaphore:
            frm_num = None

            try:
                for attempt in range(ACK_TIMEOUTS):
                    if self._ncp_state == NcpState.FAILED:
                        _LOGGER.debug(
                            "NCP is in a failed state, not re-sending: %r", frame
                        )
                        raise NcpFailure(
                            t.NcpResetCode.ERROR_EXCEEDED_MAXIMUM_ACK_TIMEOUT_COUNT
                        )

                    # The frame number is allocated once and reused by every retry
                    if frm_num is None:
                        frm_num = self._tx_seq
                        self._tx_seq = (self._tx_seq + 1) % 8

                    # Use a fresh ACK number on every retry
                    frame = frame.replace(
                        frm_num=frm_num,
                        re_tx=(attempt > 0),
                        ack_num=self._rx_seq,
                    )

                    send_time = time.monotonic()

                    ack_future = asyncio.get_running_loop().create_future()
                    self._pending_data_frames[frm_num] = ack_future
                    self._write_frame(frame)

                    try:
                        async with asyncio_timeout(self._t_rx_ack):
                            await ack_future
                    except NotAcked:
                        _LOGGER.debug(
                            "NCP responded with NAK. Retrying (attempt %d)", attempt + 1
                        )

                        # For timing purposes, NAK can be treated as an ACK
                        delta = time.monotonic() - send_time
                        self._change_ack_timeout((7 / 8) * self._t_rx_ack + 0.5 * delta)

                        if attempt >= ACK_TIMEOUTS - 1:
                            raise
                    except NcpFailure:
                        _LOGGER.debug(
                            "NCP has entered into a failed state, not retrying"
                        )
                        raise
                    except asyncio.TimeoutError:
                        _LOGGER.debug(
                            "No ACK received in %0.2fs (attempt %d)",
                            self._t_rx_ack,
                            attempt + 1,
                        )
                        # If a DATA frame acknowledgement is not received within the
                        # current timeout value, then t_rx_ack is doubled.
                        self._change_ack_timeout(2 * self._t_rx_ack)

                        if attempt >= ACK_TIMEOUTS - 1:
                            raise
                    else:
                        # Whenever an acknowledgement is received, t_rx_ack is set to
                        # 7/8 of its current value plus 1/2 of the measured time for the
                        # acknowledgement.
                        delta = time.monotonic() - send_time
                        self._change_ack_timeout((7 / 8) * self._t_rx_ack + 0.5 * delta)

                        break
            finally:
                if frm_num is not None:
                    self._pending_data_frames.pop(frm_num)

    async def send_data(self, data: bytes) -> None:
        """Send `data` in a DATA frame and wait until it has been acknowledged."""
        await self._send_data_frame(
            # All of the other fields will be set during transmission/retries
            DataFrame(frm_num=None, re_tx=None, ack_num=None, ezsp_frame=data)
        )

    def send_reset(self) -> None:
        """Write a RST frame, preceded by a Cancel Byte."""
        self._write_frame(RstFrame(), prefix=(Reserved.CANCEL,))
bellows.typing import GatewayType + +if TYPE_CHECKING: + from bellows.uart import Gateway LOGGER = logging.getLogger(__name__) EZSP_CMD_TIMEOUT = 6 # Sum of all ASH retry timeouts: 0.4 + 0.8 + 1.6 + 3.2 @@ -26,7 +30,7 @@ class ProtocolHandler(abc.ABC): COMMANDS = {} VERSION = None - def __init__(self, cb_handler: Callable, gateway: GatewayType) -> None: + def __init__(self, cb_handler: Callable, gateway: Gateway) -> None: self._handle_callback = cb_handler self._awaiting = {} self._gw = gateway @@ -37,7 +41,7 @@ def __init__(self, cb_handler: Callable, gateway: GatewayType) -> None: } self.tc_policy = 0 - def _ezsp_frame(self, name: str, *args: Tuple[Any, ...]) -> bytes: + def _ezsp_frame(self, name: str, *args: tuple[Any, ...]) -> bytes: """Serialize the named frame and data.""" c = self.COMMANDS[name] frame = self._ezsp_frame_tx(name) @@ -45,7 +49,7 @@ def _ezsp_frame(self, name: str, *args: Tuple[Any, ...]) -> bytes: return frame + data @abc.abstractmethod - def _ezsp_frame_rx(self, data: bytes) -> Tuple[int, int, bytes]: + def _ezsp_frame_rx(self, data: bytes) -> tuple[int, int, bytes]: """Handler for received data frame.""" @abc.abstractmethod @@ -62,15 +66,15 @@ async def add_transient_link_key( async def command(self, name, *args) -> Any: """Serialize command and send it.""" - LOGGER.debug("Send command %s: %s", name, args) + LOGGER.debug("Sending command %s: %s", name, args) data = self._ezsp_frame(name, *args) - self._gw.data(data) cmd_id, _, rx_schema = self.COMMANDS[name] future = asyncio.get_running_loop().create_future() self._awaiting[self._seq] = (cmd_id, rx_schema, future) self._seq = (self._seq + 1) % 256 async with asyncio_timeout(EZSP_CMD_TIMEOUT): + await self._gw.send_data(data) return await future async def update_policies(self, policy_config: dict) -> None: @@ -110,7 +114,7 @@ def __call__(self, data: bytes) -> None: ) raise - LOGGER.debug("Application frame received %s: %s", frame_name, result) + LOGGER.debug("Received command %s: %s", 
frame_name, result) if data: LOGGER.debug("Frame contains trailing data: %s", data) diff --git a/bellows/uart.py b/bellows/uart.py index bf068a3e..ee2aea08 100644 --- a/bellows/uart.py +++ b/bellows/uart.py @@ -1,8 +1,6 @@ import asyncio -import binascii import logging import sys -import time if sys.version_info[:2] < (3, 11): from async_timeout import timeout as asyncio_timeout # pragma: no cover @@ -12,18 +10,13 @@ import zigpy.config import zigpy.serial +from bellows.ash import AshProtocol from bellows.thread import EventLoopThread, ThreadsafeProxy import bellows.types as t LOGGER = logging.getLogger(__name__) RESET_TIMEOUT = 5 -ASH_ACK_RETRIES = 4 - -ASH_RX_ACK_INIT = 1.6 -ASH_RX_ACK_MIN = 0.4 -ASH_RX_ACK_MAX = 3.2 - class Gateway(asyncio.Protocol): FLAG = b"\x7E" # Marks end of frame @@ -42,129 +35,33 @@ class Terminator: pass def __init__(self, application, connected_future=None, connection_done_future=None): - self._send_seq = 0 - self._rec_seq = 0 - self._buffer = b"" self._application = application + self._reset_future = None self._startup_reset_future = None self._connected_future = connected_future - self._sendq = asyncio.Queue() - self._pending = (-1, None) self._connection_done_future = connection_done_future - self._send_task = None - self._ack_timeout = ASH_RX_ACK_INIT + self._transport = None + + def close(self): + self._transport.close() def connection_made(self, transport): """Callback when the uart is connected""" self._transport = transport if self._connected_future is not None: self._connected_future.set_result(True) - self._send_task = asyncio.create_task(self._send_loop()) + + async def send_data(self, data: bytes) -> None: + await self._transport.send_data(data) def data_received(self, data): """Callback when there is data received from the uart""" - # TODO: Fix this handling for multiple instances of the characters - # If a Cancel Byte or Substitute Byte is received, the bytes received - # so far are discarded. 
In the case of a Substitute Byte, subsequent - # bytes will also be discarded until the next Flag Byte. - if self.CANCEL in data: - self._buffer = b"" - data = data[data.rfind(self.CANCEL) + 1 :] - if self.SUBSTITUTE in data: - self._buffer = b"" - data = data[data.find(self.FLAG) + 1 :] - - self._buffer += data - while self._buffer: - frame, self._buffer = self._extract_frame(self._buffer) - if frame is None: - break - self.frame_received(frame) - - def _extract_frame(self, data): - """Extract a frame from the data buffer""" - if self.FLAG in data: - place = data.find(self.FLAG) - frame = self._unstuff(data[: place + 1]) - rest = data[place + 1 :] - crc = binascii.crc_hqx(frame[:-3], 0xFFFF) - crc = bytes([crc >> 8, crc % 256]) - if crc != frame[-3:-1]: - LOGGER.error( - "CRC error in frame %s (%s != %s)", - binascii.hexlify(frame), - binascii.hexlify(frame[-3:-1]), - binascii.hexlify(crc), - ) - self.write(self._nak_frame()) - # Make sure that we also handle the next frame if it is already received - return self._extract_frame(rest) - - return frame, rest - return None, data - - def frame_received(self, data): - """Frame receive handler""" - if (data[0] & 0b10000000) == 0: - self.data_frame_received(data) - elif (data[0] & 0b11100000) == 0b10000000: - self.ack_frame_received(data) - elif (data[0] & 0b11100000) == 0b10100000: - self.nak_frame_received(data) - elif data[0] == 0b11000000: - self.rst_frame_received(data) - elif data[0] == 0b11000001: - self.rstack_frame_received(data) - elif data[0] == 0b11000010: - self.error_frame_received(data) - else: - LOGGER.error("UNKNOWN FRAME RECEIVED: %r", data) # TODO - - def data_frame_received(self, data): - """Data frame receive handler""" - LOGGER.debug("Data frame: %s", binascii.hexlify(data)) - seq = (data[0] & 0b01110000) >> 4 - re_tx = (data[0] & 0b00001000) >> 3 - - if seq == self._rec_seq: - self._rec_seq = (seq + 1) % 8 - self.write(self._ack_frame()) - - self._handle_ack(data[0]) - 
self._application.frame_received(self._randomize(data[1:-3])) - elif re_tx: - self.write(self._ack_frame()) - else: - self.write(self._nak_frame()) - - def ack_frame_received(self, data): - """Acknowledgement frame receive handler""" - LOGGER.debug("ACK frame: %s", binascii.hexlify(data)) - self._handle_ack(data[0]) - - def nak_frame_received(self, data): - """Negative acknowledgement frame receive handler""" - LOGGER.debug("NAK frame: %s", binascii.hexlify(data)) - self._handle_nak(data[0]) - - def rst_frame_received(self, data): - """Reset frame handler""" - LOGGER.debug("RST frame: %s", binascii.hexlify(data)) + self._application.frame_received(data) - def rstack_frame_received(self, data): + def reset_received(self, code: t.NcpResetCode) -> None: """Reset acknowledgement frame receive handler""" - self._send_seq = 0 - self._rec_seq = 0 - code, version = self._get_error_code(data) - - LOGGER.debug( - "RSTACK Version: %d Reason: %s frame: %s", - version, - code.name, - binascii.hexlify(data), - ) # not a reset we've requested. 
Signal application reset if code is not t.NcpResetCode.RESET_SOFTWARE: self._application.enter_failed_state(code) @@ -177,6 +74,10 @@ def rstack_frame_received(self, data): else: LOGGER.warning("Received an unexpected reset: %r", code) + def error_received(self, code: t.NcpResetCode) -> None: + """Error frame receive handler.""" + self._application.enter_failed_state(code) + async def wait_for_startup_reset(self) -> None: """Wait for the first reset frame on startup.""" assert self._startup_reset_future is None @@ -187,31 +88,6 @@ async def wait_for_startup_reset(self) -> None: finally: self._startup_reset_future = None - @staticmethod - def _get_error_code(data): - """Extracts error code from RSTACK or ERROR frames.""" - return t.NcpResetCode(data[2]), data[1] - - def error_frame_received(self, data): - """Error frame receive handler.""" - error_code, version = self._get_error_code(data) - LOGGER.debug( - "Error code: %s, Version: %d, frame: %s", - error_code.name, - version, - binascii.hexlify(data), - ) - self._application.enter_failed_state(error_code) - - def write(self, data): - """Send data to the uart""" - LOGGER.debug("Sending: %s", binascii.hexlify(data)) - self._transport.write(data) - - def close(self): - self._sendq.put_nowait(self.Terminator) - self._transport.close() - def _reset_cleanup(self, future): """Delete reset future.""" self._reset_future = None @@ -241,10 +117,6 @@ def connection_lost(self, exc): self._reset_future.set_exception(reason) self._reset_future = None - if self._send_task: - self._send_task.cancel() - self._send_task = None - if exc is None: LOGGER.debug("Closed serial connection") return @@ -261,192 +133,29 @@ async def reset(self): ) return await self._reset_future - self._send_seq = 0 - self._rec_seq = 0 - self._buffer = b"" - while not self._sendq.empty(): - self._sendq.get_nowait() - if self._pending[1]: - self._pending[1].set_result(True) - self._pending = (-1, None) - + self._transport.send_reset() self._reset_future = 
asyncio.get_event_loop().create_future() self._reset_future.add_done_callback(self._reset_cleanup) - self.write(self._rst_frame()) async with asyncio_timeout(RESET_TIMEOUT): return await self._reset_future - async def _send_loop(self): - """Send queue handler""" - while True: - item = await self._sendq.get() - if item is self.Terminator: - break - data, seq = item - - for attempt in range(ASH_ACK_RETRIES + 1): - self._pending = (seq, asyncio.get_event_loop().create_future()) - - send_time = time.monotonic() - rxmit = attempt > 0 - self.write(self._data_frame(data, seq, rxmit)) - - try: - async with asyncio_timeout(self._ack_timeout): - success = await self._pending[1] - except asyncio.TimeoutError: - success = None - LOGGER.debug( - "Frame %s (seq %s) timed out on attempt %d, retrying", - data, - seq, - attempt, - ) - else: - if success: - break - - LOGGER.debug( - "Frame %s (seq %s) failed to transmit on attempt %d, retrying", - data, - seq, - attempt, - ) - finally: - delta = time.monotonic() - send_time - - if success is not None: - new_ack_timeout = max( - ASH_RX_ACK_MIN, - min( - ASH_RX_ACK_MAX, - (7 / 8) * self._ack_timeout + 0.5 * delta, - ), - ) - else: - new_ack_timeout = max( - ASH_RX_ACK_MIN, min(ASH_RX_ACK_MAX, 2 * self._ack_timeout) - ) - - if abs(self._ack_timeout - new_ack_timeout) > 0.01: - LOGGER.debug( - "Adjusting ACK timeout from %.2f to %.2f", - self._ack_timeout, - new_ack_timeout, - ) - - self._ack_timeout = new_ack_timeout - self._pending = (-1, None) - else: - self.connection_lost( - ConnectionResetError( - f"Failed to transmit ASH frame after {ASH_ACK_RETRIES} retries" - ) - ) - return - - def _handle_ack(self, control): - """Handle an acknowledgement frame""" - ack = ((control & 0b00000111) - 1) % 8 - if ack == self._pending[0]: - pending, self._pending = self._pending, (-1, None) - pending[1].set_result(True) - - def _handle_nak(self, control): - """Handle negative acknowledgment frame""" - nak = control & 0b00000111 - if nak == 
self._pending[0]: - self._pending[1].set_result(False) - - def data(self, data): - """Send a data frame""" - seq = self._send_seq - self._send_seq = (seq + 1) % 8 - self._sendq.put_nowait((data, seq)) - - def _data_frame(self, data, seq, rxmit): - """Construct a data frame""" - assert 0 <= seq <= 7 - assert 0 <= rxmit <= 1 - control = (seq << 4) | (rxmit << 3) | self._rec_seq - return self._frame(bytes([control]), self._randomize(data)) - - def _ack_frame(self): - """Construct a acknowledgement frame""" - assert 0 <= self._rec_seq < 8 - control = bytes([0b10000000 | (self._rec_seq & 0b00000111)]) - return self._frame(control, b"") - - def _nak_frame(self): - """Construct a negative acknowledgement frame""" - assert 0 <= self._rec_seq < 8 - control = bytes([0b10100000 | (self._rec_seq & 0b00000111)]) - return self._frame(control, b"") - - def _rst_frame(self): - """Construct a reset frame""" - return self.CANCEL + self._frame(b"\xC0", b"") - - def _frame(self, control, data): - """Construct a frame""" - crc = binascii.crc_hqx(control + data, 0xFFFF) - crc = bytes([crc >> 8, crc % 256]) - return self._stuff(control + data + crc) + self.FLAG - - def _randomize(self, s): - """XOR s with a pseudo-random sequence for transmission - - Used only in data frames - """ - rand = self.RANDOMIZE_START - out = b"" - for c in s: - out += bytes([c ^ rand]) - if rand % 2: - rand = (rand >> 1) ^ self.RANDOMIZE_SEQ - else: - rand = rand >> 1 - return out - - def _stuff(self, s): - """Byte stuff (escape) a string for transmission""" - out = b"" - for c in s: - if c in self.RESERVED: - out += self.ESCAPE + bytes([c ^ self.STUFF]) - else: - out += bytes([c]) - return out - - def _unstuff(self, s): - """Unstuff (unescape) a string after receipt""" - out = b"" - escaped = False - for c in s: - if escaped: - out += bytes([c ^ self.STUFF]) - escaped = False - elif c in self.ESCAPE: - escaped = True - else: - out += bytes([c]) - return out - async def _connect(config, application): loop = 
asyncio.get_event_loop() connection_future = loop.create_future() connection_done_future = loop.create_future() - protocol = Gateway(application, connection_future, connection_done_future) + + gateway = Gateway(application, connection_future, connection_done_future) + protocol = AshProtocol(gateway) if config[zigpy.config.CONF_DEVICE_FLOW_CONTROL] is None: xon_xoff, rtscts = True, False else: xon_xoff, rtscts = False, True - transport, protocol = await zigpy.serial.create_serial_connection( + transport, _ = await zigpy.serial.create_serial_connection( loop, lambda: protocol, url=config[zigpy.config.CONF_DEVICE_PATH], @@ -457,7 +166,7 @@ async def _connect(config, application): await connection_future - thread_safe_protocol = ThreadsafeProxy(protocol, loop) + thread_safe_protocol = ThreadsafeProxy(gateway, loop) return thread_safe_protocol, connection_done_future diff --git a/pyproject.toml b/pyproject.toml index 674197b3..9ecfb1f2 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -55,15 +55,17 @@ ignore_errors = true asyncio_mode = "auto" [tool.flake8] -exclude = [".venv", ".git", ".tox", "docs", "venv", "bin", "lib", "deps", "build"] +exclude = ".venv,.git,.tox,docs,venv,bin,lib,deps,build" # To work with Black max-line-length = 88 # W503: Line break occurred before a binary operator # E203: Whitespace before ':' # E501: line too long # D202 No blank lines allowed after function docstring -ignore = ["W503", "E203", "E501", "D202"] -per-file-ignores = ["tests/*:F811,F401,F403"] +ignore = "W503,E203,E501,D202" +per-file-ignores = """ + tests/*:F811,F401,F403 +""" [tool.coverage.run] source = ["bellows"] @@ -71,3 +73,13 @@ omit = [ "bellows/cli/*.py", "bellows/typing.py", ] + + +[tool.coverage.report] +exclude_also = [ + "raise AssertionError", + "raise NotImplementedError", + "if TYPE_CHECKING", + "if typing.TYPE_CHECKING", + "@(abc\\.)?abstractmethod", +] \ No newline at end of file diff --git a/tests/test_ash.py b/tests/test_ash.py new file mode 100644 index 
00000000..41000ce3 --- /dev/null +++ b/tests/test_ash.py @@ -0,0 +1,593 @@ +from __future__ import annotations + +import asyncio +import logging +import random +from unittest.mock import MagicMock, call, patch + +import pytest + +from bellows import ash +import bellows.types as t + + +@pytest.fixture(autouse=True, scope="function") +def random_seed(): + random.seed(0) + + +class AshNcpProtocol(ash.AshProtocol): + def __init__(self, *args, **kwargs) -> None: + super().__init__(*args, **kwargs) + self.nak_state = False + + def frame_received(self, frame: ash.AshFrame) -> None: + if self._ncp_reset_code is not None and not isinstance(frame, ash.RstFrame): + ash._LOGGER.debug( + "NCP in failure state %r, ignoring frame: %r", + self._ncp_reset_code, + frame, + ) + self._write_frame( + ash.ErrorFrame(version=2, reset_code=self._ncp_reset_code) + ) + return + + if self.nak_state: + asyncio.get_running_loop().call_later( + 2 * self._t_rx_ack, + lambda: self._write_frame( + ash.NakFrame(res=0, ncp_ready=0, ack_num=self._rx_seq) + ), + ) + return + + super().frame_received(frame) + + def _enter_ncp_error_state(self, code: t.NcpResetCode | None) -> None: + self._ncp_reset_code = code + + if code is None: + self._ncp_state = ash.NcpState.CONNECTED + else: + self._ncp_state = ash.NcpState.FAILED + + ash._LOGGER.debug("Changing connectivity state: %r", self._ncp_state) + ash._LOGGER.debug("Changing reset code: %r", self._ncp_reset_code) + + if self._ncp_state == ash.NcpState.FAILED: + self._write_frame( + ash.ErrorFrame(version=2, reset_code=self._ncp_reset_code) + ) + + def rst_frame_received(self, frame: ash.RstFrame) -> None: + super().rst_frame_received(frame) + + self._tx_seq = 0 + self._rx_seq = 0 + self._change_ack_timeout(ash.T_RX_ACK_INIT) + + self._enter_ncp_error_state(None) + self._write_frame( + ash.RStackFrame(version=2, reset_code=t.NcpResetCode.RESET_SOFTWARE) + ) + + async def _send_data_frame(self, frame: ash.AshFrame) -> None: + try: + return await 
super()._send_data_frame(frame) + except asyncio.TimeoutError: + self._enter_ncp_error_state( + t.NcpResetCode.ERROR_EXCEEDED_MAXIMUM_ACK_TIMEOUT_COUNT + ) + raise + + def send_reset(self) -> None: + raise NotImplementedError() + + +class FakeTransport: + def __init__(self, receiver): + self.receiver = receiver + self.paused = False + + def write(self, data: bytes) -> None: + if not self.paused: + self.receiver.data_received(data) + + +class FakeTransportOneByteAtATime(FakeTransport): + def write(self, data: bytes) -> None: + for byte in data: + super().write(bytes([byte])) + + +class FakeTransportRandomLoss(FakeTransport): + def write(self, data: bytes) -> None: + if random.random() < 0.20: + return + + super().write(data) + + +class FakeTransportWithDelays(FakeTransport): + def write(self, data): + asyncio.get_running_loop().call_later(0, super().write, data) + + +def test_ash_exception_repr() -> None: + assert ( + repr(ash.NotAcked(ash.NakFrame(res=0, ncp_ready=0, ack_num=1))) + == "" + ) + assert ( + repr(ash.NcpFailure(t.NcpResetCode.RESET_SOFTWARE)) + == ")>" + ) + + +@pytest.mark.parametrize( + "frame", + [ + ash.RstFrame(), + ash.AckFrame(res=0, ncp_ready=0, ack_num=1), + ash.NakFrame(res=0, ncp_ready=0, ack_num=1), + ash.RStackFrame(version=2, reset_code=t.NcpResetCode.RESET_SOFTWARE), + ash.ErrorFrame(version=2, reset_code=t.NcpResetCode.RESET_SOFTWARE), + ash.DataFrame(frm_num=0, re_tx=False, ack_num=1, ezsp_frame=b"test"), + ], +) +def test_parse_frame(frame: ash.AshFrame) -> None: + assert ash.parse_frame(frame.to_bytes()) == frame + + +def test_parse_frame_failure() -> None: + with pytest.raises(ash.ParsingError): + ash.parse_frame(b"test") + + +def test_ash_protocol_event_propagation() -> None: + ezsp = MagicMock() + protocol = ash.AshProtocol(ezsp) + + err = RuntimeError("test") + protocol.connection_lost(err) + assert ezsp.connection_lost.mock_calls == [call(err)] + + protocol.eof_received() + assert ezsp.eof_received.mock_calls == [call()] + + 
+def test_stuffing(): + assert ash.AshProtocol._stuff_bytes(b"\x7E") == b"\x7D\x5E" + assert ash.AshProtocol._stuff_bytes(b"\x11") == b"\x7D\x31" + assert ash.AshProtocol._stuff_bytes(b"\x13") == b"\x7D\x33" + assert ash.AshProtocol._stuff_bytes(b"\x18") == b"\x7D\x38" + assert ash.AshProtocol._stuff_bytes(b"\x1A") == b"\x7D\x3A" + assert ash.AshProtocol._stuff_bytes(b"\x7D") == b"\x7D\x5D" + + assert ash.AshProtocol._unstuff_bytes(b"\x7D\x5E") == b"\x7E" + assert ash.AshProtocol._unstuff_bytes(b"\x7D\x31") == b"\x11" + assert ash.AshProtocol._unstuff_bytes(b"\x7D\x33") == b"\x13" + assert ash.AshProtocol._unstuff_bytes(b"\x7D\x38") == b"\x18" + assert ash.AshProtocol._unstuff_bytes(b"\x7D\x3A") == b"\x1A" + assert ash.AshProtocol._unstuff_bytes(b"\x7D\x5D") == b"\x7D" + + assert ash.AshProtocol._stuff_bytes(b"\x7F") == b"\x7F" + assert ash.AshProtocol._unstuff_bytes(b"\x7F") == b"\x7F" + + +def test_pseudo_random_data_sequence(): + assert ash.PSEUDO_RANDOM_DATA_SEQUENCE.startswith(b"\x42\x21\xA8\x54\x2A") + + +def test_frame_parsing_errors(): + with pytest.raises(ash.ParsingError, match=r"Frame is too short:"): + assert ash.RstFrame.from_bytes(b"\xC0\x38") + + with pytest.raises(ash.ParsingError, match=r"Invalid CRC bytes in frame"): + assert ash.RstFrame.from_bytes(b"\xC0\xAB\xCD") + + +def test_rst_frame(): + assert ash.RstFrame() == ash.RstFrame() + assert ash.RstFrame().to_bytes() == b"\xC0\x38\xBC" + assert ash.RstFrame.from_bytes(b"\xC0\x38\xBC") == ash.RstFrame() + assert str(ash.RstFrame()) == "RstFrame()" + + with pytest.raises(ash.ParsingError, match=r"Invalid data for RST frame:"): + ash.RstFrame.from_bytes(ash.AshFrame.append_crc(b"\xC0\xAB")) + + +def test_rstack_frame(): + frm = ash.RStackFrame(version=2, reset_code=t.NcpResetCode.RESET_SOFTWARE) + + assert frm.to_bytes() == b"\xc1\x02\x0b\x0a\x52" + assert ash.RStackFrame.from_bytes(frm.to_bytes()) == frm + assert ( + str(frm) + == "RStackFrame(version=2, reset_code=)" + ) + + with pytest.raises( + 
ash.ParsingError, match=r"Invalid data length for RSTACK frame:" + ): + # Adding \xAB in the middle of the frame makes it invalid + ash.RStackFrame.from_bytes(ash.AshFrame.append_crc(b"\xc1\x02\xab\x0b")) + + with pytest.raises(ash.ParsingError, match=r"Invalid version for RSTACK frame:"): + # Version 3 is unknown + ash.RStackFrame.from_bytes(ash.AshFrame.append_crc(b"\xc1\x03\x0b")) + + +def test_cancel_byte(): + ezsp = MagicMock() + protocol = ash.AshProtocol(ezsp) + protocol.frame_received = MagicMock(wraps=protocol.frame_received) + + protocol.data_received(bytes.fromhex("ddf9ff")) + protocol.data_received(bytes.fromhex("1ac1020b0a527e")) # starts with a CANCEL byte + + # We still parse out the RSTACK frame + assert protocol.frame_received.mock_calls == [ + call(ash.RStackFrame(version=2, reset_code=t.NcpResetCode.RESET_SOFTWARE)) + ] + assert not protocol._buffer + + # Parse byte-by-byte + protocol.frame_received.reset_mock() + + for byte in bytes.fromhex("ddf9ff 1ac1020b0a527e"): + protocol.data_received(bytes([byte])) + + # We still parse out the RSTACK frame + assert protocol.frame_received.mock_calls == [ + call(ash.RStackFrame(version=2, reset_code=t.NcpResetCode.RESET_SOFTWARE)) + ] + assert not protocol._buffer + + +def test_substitute_byte(): + ezsp = MagicMock() + protocol = ash.AshProtocol(ezsp) + protocol.frame_received = MagicMock(wraps=protocol.frame_received) + + protocol.data_received(bytes.fromhex("c0 38bc 7e")) # RST frame + assert protocol.frame_received.mock_calls == [call(ash.RstFrame())] + protocol.data_received(bytes.fromhex("c0 18 38bc 7e")) # RST frame + SUBSTITUTE + assert protocol.frame_received.mock_calls == [call(ash.RstFrame())] # ignored! 
+ protocol.data_received(bytes.fromhex("c1 020b 0a52 7e")) # RSTACK frame + assert protocol.frame_received.mock_calls == [ + call(ash.RstFrame()), + call(ash.RStackFrame(version=2, reset_code=t.NcpResetCode.RESET_SOFTWARE)), + ] + + assert not protocol._buffer + + +def test_xon_xoff_bytes(): + ezsp = MagicMock() + protocol = ash.AshProtocol(ezsp) + protocol.frame_received = MagicMock(wraps=protocol.frame_received) + + protocol.data_received(bytes.fromhex("c0 11 38bc 13 7e")) # RST frame + XON + XOFF + assert protocol.frame_received.mock_calls == [call(ash.RstFrame())] + + +def test_multiple_eof(): + ezsp = MagicMock() + protocol = ash.AshProtocol(ezsp) + protocol.frame_received = MagicMock(wraps=protocol.frame_received) + + protocol.data_received(bytes.fromhex("c0 38bc 7e 7e 7e")) # RST frame + assert protocol.frame_received.mock_calls == [call(ash.RstFrame())] + + +def test_discarding(): + ezsp = MagicMock() + protocol = ash.AshProtocol(ezsp) + protocol.frame_received = MagicMock(wraps=protocol.frame_received) + + # RST frame with embedded SUBSTITUTE + protocol.data_received(bytes.fromhex("c0 18 38bc")) + + # Garbage: still ignored + protocol.data_received(bytes.fromhex("aa bb cc dd ee ff")) + + # Frame boundary: we now will handle data + protocol.data_received(bytes.fromhex("7e")) + + # Normal RST frame + protocol.data_received(bytes.fromhex("c0 38bc 7e 7e 7e")) + assert protocol.frame_received.mock_calls == [call(ash.RstFrame())] + + +def test_buffer_growth(): + ezsp = MagicMock() + protocol = ash.AshProtocol(ezsp) + + # Receive a lot of bogus data + for i in range(1000): + protocol.data_received(b"\xEE" * 100) + + # Make sure our internal buffer doesn't blow up + assert len(protocol._buffer) == ash.MAX_BUFFER_SIZE + + +async def test_sequence(): + loop = asyncio.get_running_loop() + ezsp = MagicMock() + transport = MagicMock() + + protocol = ash.AshProtocol(ezsp) + protocol._write_frame = MagicMock(wraps=protocol._write_frame) + 
protocol.connection_made(transport) + + # Normal send/receive + loop.call_later( + 0, + protocol.frame_received, + ash.DataFrame(frm_num=0, re_tx=False, ack_num=1, ezsp_frame=b"rx 1"), + ) + await protocol.send_data(b"tx 1") + assert protocol._write_frame.mock_calls[-1] == call( + ash.AckFrame(res=0, ncp_ready=0, ack_num=1) + ) + + assert protocol._rx_seq == 1 + assert protocol._tx_seq == 1 + assert ezsp.data_received.mock_calls == [call(b"rx 1")] + + # Skip ACK 2: we are out of sync! + protocol.frame_received( + ash.DataFrame(frm_num=2, re_tx=False, ack_num=1, ezsp_frame=b"out of sequence") + ) + + # We NAK it, it is out of sequence! + assert protocol._write_frame.mock_calls[-1] == call( + ash.NakFrame(res=0, ncp_ready=0, ack_num=1) + ) + + # Sequence numbers remain intact + assert protocol._rx_seq == 1 + assert protocol._tx_seq == 1 + + # Re-sync properly + protocol.frame_received( + ash.DataFrame(frm_num=1, re_tx=False, ack_num=1, ezsp_frame=b"rx 2") + ) + + assert ezsp.data_received.mock_calls == [call(b"rx 1"), call(b"rx 2")] + + # Trigger an error + loop.call_later( + 0, + protocol.frame_received, + ash.ErrorFrame(version=2, reset_code=t.NcpResetCode.RESET_SOFTWARE), + ) + + with pytest.raises(ash.NcpFailure): + await protocol.send_data(b"tx 2") + + +async def test_frame_parsing_failure_recovery(caplog) -> None: + ezsp = MagicMock() + protocol = ash.AshProtocol(ezsp) + protocol.frame_received = MagicMock(spec_set=protocol.frame_received) + + protocol.data_received( + ash.DataFrame(frm_num=0, re_tx=0, ack_num=0, ezsp_frame=b"frame 1").to_bytes() + + bytes([ash.Reserved.FLAG]) + ) + + with caplog.at_level(logging.DEBUG): + protocol.data_received( + ash.AshFrame.append_crc(b"\xFESome unknown frame") + + bytes([ash.Reserved.FLAG]) + ) + + assert "Some unknown frame" in caplog.text + + protocol.data_received( + ash.DataFrame(frm_num=1, re_tx=0, ack_num=0, ezsp_frame=b"frame 2").to_bytes() + + bytes([ash.Reserved.FLAG]) + ) + + assert 
protocol.frame_received.mock_calls == [ + call(ash.DataFrame(frm_num=0, re_tx=0, ack_num=0, ezsp_frame=b"frame 1")), + call(ash.DataFrame(frm_num=1, re_tx=0, ack_num=0, ezsp_frame=b"frame 2")), + ] + + +async def test_ash_protocol_startup(caplog): + """Simple EZSP startup: reset, version(4), then version(8).""" + + # We have branching dependent on `_LOGGER.isEnabledFor` so test it here + caplog.set_level(logging.DEBUG) + + loop = asyncio.get_running_loop() + + ezsp = MagicMock() + transport = MagicMock() + + protocol = ash.AshProtocol(ezsp) + protocol._write_frame = MagicMock(wraps=protocol._write_frame) + protocol.connection_made(transport) + + assert ezsp.connection_made.mock_calls == [call(protocol)] + + assert protocol._rx_seq == 0 + assert protocol._tx_seq == 0 + + # ASH reset + protocol.send_reset() + loop.call_later( + 0, + protocol.frame_received, + ash.RStackFrame(version=2, reset_code=t.NcpResetCode.RESET_SOFTWARE), + ) + + await asyncio.sleep(0.01) + + assert ezsp.reset_received.mock_calls == [call(t.NcpResetCode.RESET_SOFTWARE)] + assert protocol._write_frame.mock_calls == [ + call(ash.RstFrame(), prefix=(ash.Reserved.CANCEL,)) + ] + + protocol._write_frame.reset_mock() + + # EZSP version(4) + loop.call_later( + 0, + protocol.frame_received, + ash.DataFrame( + frm_num=0, re_tx=False, ack_num=1, ezsp_frame=b"\x00\x80\x00\x08\x02\x80g" + ), + ) + await protocol.send_data(b"\x00\x00\x00\x04") + assert protocol._write_frame.mock_calls == [ + call( + ash.DataFrame( + frm_num=0, re_tx=False, ack_num=0, ezsp_frame=b"\x00\x00\x00\x04" + ) + ), + call(ash.AckFrame(res=0, ncp_ready=0, ack_num=1)), + ] + + protocol._write_frame.reset_mock() + + # EZSP version(8) + loop.call_later( + 0, + protocol.frame_received, + ash.DataFrame( + frm_num=1, + re_tx=False, + ack_num=2, + ezsp_frame=b"\x00\x80\x01\x00\x00\x08\x02\x80g", + ), + ) + await protocol.send_data(b"\x00\x00\x01\x00\x00\x08") + assert protocol._write_frame.mock_calls == [ + call( + ash.DataFrame( + 
frm_num=1, + re_tx=False, + ack_num=1, + ezsp_frame=b"\x00\x00\x01\x00\x00\x08", + ) + ), + call(ash.AckFrame(res=0, ncp_ready=0, ack_num=2)), + ] + + +@patch("bellows.ash.T_RX_ACK_INIT", ash.T_RX_ACK_INIT / 100) +@patch("bellows.ash.T_RX_ACK_MIN", ash.T_RX_ACK_MIN / 100) +@patch("bellows.ash.T_RX_ACK_MAX", ash.T_RX_ACK_MAX / 100) +@pytest.mark.parametrize( + "transport_cls", + [ + FakeTransport, + FakeTransportOneByteAtATime, + FakeTransportRandomLoss, + FakeTransportWithDelays, + ], +) +async def test_ash_end_to_end(transport_cls: type[FakeTransport]) -> None: + asyncio.get_running_loop() + + host_ezsp = MagicMock() + ncp_ezsp = MagicMock() + + host = ash.AshProtocol(host_ezsp) + ncp = AshNcpProtocol(ncp_ezsp) + + host_transport = transport_cls(ncp) + ncp_transport = transport_cls(host) + + host.connection_made(host_transport) + ncp.connection_made(ncp_transport) + + # Ping pong works + await asyncio.gather( + host.send_data(b"Hello 1!"), + host.send_data(b"Hello 2!"), + ) + assert ncp_ezsp.data_received.mock_calls == [call(b"Hello 1!"), call(b"Hello 2!")] + + await ncp.send_data(b"World!") + assert host_ezsp.data_received.mock_calls == [call(b"World!")] + + ncp_ezsp.data_received.reset_mock() + host_ezsp.data_received.reset_mock() + + # Let's pause the ncp so it can't ACK + with patch.object(ncp_transport, "paused", True): + send_task = asyncio.create_task(host.send_data(b"delayed")) + await asyncio.sleep(host._t_rx_ack * 2) + + # It'll still succeed + await send_task + + assert ncp_ezsp.data_received.mock_calls == [call(b"delayed")] + + ncp_ezsp.data_received.reset_mock() + host_ezsp.data_received.reset_mock() + + # Let's let a request fail due to a connectivity issue + with patch.object(ncp_transport, "paused", True): + send_task = asyncio.create_task(host.send_data(b"host failure")) + await asyncio.sleep(host._t_rx_ack * 15) + + with pytest.raises(asyncio.TimeoutError): + await send_task + + ncp_ezsp.data_received.reset_mock() + 
host_ezsp.data_received.reset_mock() + + # Simulate OOM on the NCP and send NAKs for a bit + with patch.object(ncp, "nak_state", True): + send_task = asyncio.create_task(host.send_data(b"ncp NAKing")) + await asyncio.sleep(host._t_rx_ack) + + # It'll still succeed + await send_task + + ncp_ezsp.data_received.reset_mock() + host_ezsp.data_received.reset_mock() + + # When the NCP fails to receive a reply, it enters a failed state + assert host._ncp_reset_code is None + assert ncp._ncp_reset_code is None + + with patch.object(host_transport, "paused", True): + send_task = asyncio.create_task(ncp.send_data(b"ncp failure")) + await asyncio.sleep(ncp._t_rx_ack * 15) + + with pytest.raises(asyncio.TimeoutError): + await send_task + + assert ( + host._ncp_reset_code is t.NcpResetCode.ERROR_EXCEEDED_MAXIMUM_ACK_TIMEOUT_COUNT + ) + assert ( + ncp._ncp_reset_code is t.NcpResetCode.ERROR_EXCEEDED_MAXIMUM_ACK_TIMEOUT_COUNT + ) + + ncp_ezsp.data_received.reset_mock() + host_ezsp.data_received.reset_mock() + + # All communication attempts with it will fail until it is reset + with pytest.raises(ash.NcpFailure): + await host.send_data(b"test") + + host.send_reset() + await asyncio.sleep(0.01) + await host.send_data(b"test") + + # Trigger a failure caused by excessive NAKs + ncp._t_rx_ack = ash.T_RX_ACK_INIT / 1000 + host._t_rx_ack = ash.T_RX_ACK_INIT / 1000 + + with patch.object(ncp, "nak_state", True): + with pytest.raises(ash.NotAcked): + await host.send_data(b"ncp NAKing until failure") diff --git a/tests/test_ezsp_protocol.py b/tests/test_ezsp_protocol.py index 29794512..798881af 100644 --- a/tests/test_ezsp_protocol.py +++ b/tests/test_ezsp_protocol.py @@ -9,6 +9,7 @@ import bellows.ezsp.v9 from bellows.ezsp.v9.commands import GetTokenDataRsp from bellows.types import NV3KeyId +from bellows.uart import Gateway from .async_mock import ANY, AsyncMock, MagicMock, call, patch @@ -16,23 +17,35 @@ @pytest.fixture def prot_hndl(): """Protocol handler mock.""" - return 
bellows.ezsp.v4.EZSPv4(MagicMock(), MagicMock()) + app = MagicMock() + gateway = Gateway(app) + gateway._transport = AsyncMock() + + callback_handler = MagicMock() + return bellows.ezsp.v4.EZSPv4(callback_handler, gateway) @pytest.fixture def prot_hndl_v9(): """Protocol handler mock.""" - return bellows.ezsp.v9.EZSPv9(MagicMock(), MagicMock()) + app = MagicMock() + gateway = Gateway(app) + gateway._transport = AsyncMock() + + callback_handler = MagicMock() + return bellows.ezsp.v9.EZSPv9(callback_handler, gateway) async def test_command(prot_hndl): - coro = prot_hndl.command("nop") - asyncio.get_running_loop().call_soon( - lambda: prot_hndl._awaiting[prot_hndl._seq - 1][2].set_result(True) - ) + with patch.object(prot_hndl._gw, "send_data") as mock_send_data: + coro = prot_hndl.command("nop") + asyncio.get_running_loop().call_soon( + lambda: prot_hndl._awaiting[prot_hndl._seq - 1][2].set_result(True) + ) + + await coro - await coro - assert prot_hndl._gw.data.call_count == 1 + assert mock_send_data.mock_calls == [call(b"\x00\x00\x05")] def test_receive_reply(prot_hndl): diff --git a/tests/test_uart.py b/tests/test_uart.py index 1dc5574c..68ac8664 100644 --- a/tests/test_uart.py +++ b/tests/test_uart.py @@ -6,8 +6,9 @@ import zigpy.config as conf from bellows import uart +import bellows.types as t -from .async_mock import AsyncMock, MagicMock, patch, sentinel +from .async_mock import AsyncMock, MagicMock, call, patch, sentinel @pytest.mark.parametrize("flow_control", ["software", "hardware"]) @@ -145,169 +146,11 @@ def gw(): return gw -def test_randomize(gw): - assert gw._randomize(b"\x00\x00\x00\x00\x00") == b"\x42\x21\xa8\x54\x2a" - assert gw._randomize(b"\x42\x21\xa8\x54\x2a") == b"\x00\x00\x00\x00\x00" - - -def test_stuff(gw): - orig = b"\x00\x7E\x01\x7D\x02\x11\x03\x13\x04\x18\x05\x1a\x06" - stuff = ( - b"\x00\x7D\x5E\x01\x7D\x5D\x02\x7D\x31\x03\x7D\x33\x04\x7D\x38\x05\x7D\x3a\x06" - ) - assert gw._stuff(orig) == stuff - - -def test_unstuff(gw): - orig = 
b"\x00\x7E\x01\x7D\x02\x11\x03\x13\x04\x18\x05\x1a\x06" - stuff = ( - b"\x00\x7D\x5E\x01\x7D\x5D\x02\x7D\x31\x03\x7D\x33\x04\x7D\x38\x05\x7D\x3a\x06" - ) - assert gw._unstuff(stuff) == orig - - -def test_rst(gw): - assert gw._rst_frame() == b"\x1a\xc0\x38\xbc\x7e" - - -def test_data_frame(gw): - expected = b"\x42\x21\xa8\x54\x2a" - assert gw._data_frame(b"\x00\x00\x00\x00\x00", 0, False)[1:-3] == expected - - -def test_cancel_received(gw): - gw.rst_frame_received = MagicMock() - gw.data_received(b"garbage") - gw.data_received(b"\x1a\xc0\x38\xbc\x7e") - assert gw.rst_frame_received.call_count == 1 - assert gw._buffer == b"" - - -def test_substitute_received(gw): - gw.rst_frame_received = MagicMock() - gw.data_received(b"garbage") - gw.data_received(b"\x18\x38\xbc\x7epart") - gw.data_received(b"ial") - gw.rst_frame_received.assert_not_called() - assert gw._buffer == b"partial" - - -def test_partial_data_received(gw): - gw.write = MagicMock() - gw._rec_seq = 5 - gw.data_received(b"\x54\x79\xa1\xb0") - gw.data_received(b"\x50\xf2\x6e\x7e") - assert gw.write.call_count == 1 - assert gw._application.frame_received.call_count == 1 - - -def test_crc_error(gw): - gw.write = MagicMock() - gw.data_received(b"L\xa1\x8e\x03\xcd\x07\xb9Y\xfbG%\xae\xbd~") - assert gw.write.call_count == 1 - assert gw._application.frame_received.call_count == 0 - - -def test_crc_error_and_valid_frame(gw): - gw.write = MagicMock() - gw._rec_seq = 5 - gw.data_received( - b"L\xa1\x8e\x03\xcd\x07\xb9Y\xfbG%\xae\xbd~\x54\x79\xa1\xb0\x50\xf2\x6e\x7e" - ) - assert gw.write.call_count == 2 - assert gw._application.frame_received.call_count == 1 - - -def test_data_frame_received(gw): - gw.write = MagicMock() - gw._rec_seq = 5 - gw.data_received(b"\x54\x79\xa1\xb0\x50\xf2\x6e\x7e") - assert gw.write.call_count == 1 - assert gw._application.frame_received.call_count == 1 - - -def test_ack_frame_received(gw): - gw.data_received(b"\x86\x10\xbe\x7e") - - -def test_nak_frame_received(gw): - 
gw.frame_received(bytes([0b10100000])) - - -def test_rst_frame_received(gw): - gw.data_received(b"garbage\x1a\xc0\x38\xbc\x7e") - - -def test_rstack_frame_received(gw): - gw._reset_future = MagicMock() - gw._reset_future.done = MagicMock(return_value=False) - gw.data_received(b"\xc1\x02\x0b\nR\x7e") - assert gw._reset_future.done.call_count == 1 - assert gw._reset_future.set_result.call_count == 1 - - -def test_wrong_rstack_frame_received(gw): - gw._reset_future = MagicMock() - gw.data_received(b"\xc1\x02\x01\xab\x18\x7e") - assert gw._reset_future.set_result.call_count == 0 - - -def test_error_rstack_frame_received(gw): - gw._reset_future = MagicMock() - gw.data_received(b"\xc1\x02\x81\x3a\x90\x7e") - assert gw._reset_future.set_result.call_count == 0 - - -def test_rstack_frame_received_nofut(gw): - gw.data_received(b"\xc1\x02\x0b\nR\x7e") - - -def test_rstack_frame_received_out_of_order(gw): - gw._reset_future = MagicMock() - gw._reset_future.done = MagicMock(return_value=True) - gw.data_received(b"\xc1\x02\x0b\nR\x7e") - assert gw._reset_future.done.call_count == 1 - assert gw._reset_future.set_result.call_count == 0 - - -def test_error_frame_received(gw): - from bellows.types import NcpResetCode - - gw.frame_received(b"\xc2\x02\x03\xd2\x0a\x7e") - efs = gw._application.enter_failed_state - assert efs.call_count == 1 - assert efs.call_args[0][0] == NcpResetCode.RESET_WATCHDOG - - -def test_unknown_frame_received(gw): - gw.frame_received(bytes([0b11011111])) - - def test_close(gw): gw.close() assert gw._transport.close.call_count == 1 -async def test_reset(gw): - gw._sendq.put_nowait(sentinel.queue_item) - fut = asyncio.Future() - gw._pending = (sentinel.seq, fut) - gw._transport.write.side_effect = lambda *args: gw._reset_future.set_result( - sentinel.reset_result - ) - reset_result = await gw.reset() - - assert gw._transport.write.call_count == 1 - assert gw._send_seq == 0 - assert gw._rec_seq == 0 - assert len(gw._buffer) == 0 - assert gw._sendq.empty() - 
assert fut.done() - assert gw._pending == (-1, None) - - assert reset_result is sentinel.reset_result - - async def test_reset_timeout(gw, monkeypatch): monkeypatch.setattr(uart, "RESET_TIMEOUT", 0.1) with pytest.raises(asyncio.TimeoutError): @@ -323,29 +166,6 @@ async def test_reset_old(gw): gw._transport.write.assert_not_called() -async def test_data(gw): - loop = asyncio.get_running_loop() - write_call_count = 0 - - def mockwrite(data): - nonlocal loop, write_call_count - if data == b"\x10 @\xda}^Z~": - loop.call_soon(gw._handle_nak, gw._pending[0]) - else: - loop.call_soon(gw._handle_ack, (gw._pending[0] + 1) % 8) - write_call_count += 1 - - gw.write = mockwrite - - gw.data(b"foo") - gw.data(b"bar") - gw.data(b"baz") - gw._sendq.put_nowait(gw.Terminator) - - await gw._send_loop() - assert write_call_count == 4 - - def test_connection_lost_exc(gw): gw.connection_lost(sentinel.exception) @@ -405,7 +225,7 @@ def on_transport_close(): async def test_wait_for_startup_reset(gw): loop = asyncio.get_running_loop() - loop.call_later(0.01, gw.data_received, b"\xc1\x02\x0b\nR\x7e") + loop.call_later(0.01, gw.reset_received, t.NcpResetCode.RESET_SOFTWARE) assert gw._startup_reset_future is None await gw.wait_for_startup_reset() @@ -421,109 +241,11 @@ async def test_wait_for_startup_reset_failure(gw): assert gw._startup_reset_future is None -ASH_ACK_MIN = 0.01 - - -@patch("bellows.uart.ASH_RX_ACK_MIN", new=ASH_ACK_MIN * 2**0) -@patch("bellows.uart.ASH_RX_ACK_INIT", new=ASH_ACK_MIN * 2**2) -@patch("bellows.uart.ASH_RX_ACK_MAX", new=ASH_ACK_MIN * 2**3) -async def test_retry_success(): - app = MagicMock() - transport = MagicMock() - connected_future = asyncio.get_running_loop().create_future() - - gw = uart.Gateway(app, connected_future) - gw.connection_made(transport) - - old_timeout = gw._ack_timeout - gw.data(b"TX 1") - await asyncio.sleep(0) - - # Wait more than one ACK cycle to reply - assert len(transport.write.mock_calls) == 1 - await asyncio.sleep(ASH_ACK_MIN * 5) - - 
# The gateway has retried once by now - assert len(transport.write.mock_calls) == 2 - - gw.frame_received( - # ash.DataFrame(frm_num=0, re_tx=0, ack_num=1, ezsp_frame=b"RX 1").to_bytes() - bytes.fromhex("01107988654851") - ) - - # An ACK has been received and the pending frame has been acknowledged - await asyncio.sleep(0) - assert gw._pending == (-1, None) - - assert gw._ack_timeout > old_timeout - - gw.close() - - -@patch("bellows.uart.ASH_RX_ACK_MIN", new=ASH_ACK_MIN * 2**0) -@patch("bellows.uart.ASH_RX_ACK_INIT", new=ASH_ACK_MIN * 2**2) -@patch("bellows.uart.ASH_RX_ACK_MAX", new=ASH_ACK_MIN * 2**3) -async def test_retry_nak_then_success(): - app = MagicMock() - transport = MagicMock() - connected_future = asyncio.get_running_loop().create_future() - - gw = uart.Gateway(app, connected_future) - gw.connection_made(transport) - - old_timeout = gw._ack_timeout - gw.data(b"TX 1") - await asyncio.sleep(0) - assert len(transport.write.mock_calls) == 1 - - # Wait less than one ACK cycle so that we can NAK the frame during the RX window - await asyncio.sleep(ASH_ACK_MIN) - # NAK the frame - gw.frame_received( - # ash.NakFrame(res=0, ncp_ready=0, ack_num=0).to_bytes() - bytes.fromhex("a0541a") - ) - - # The gateway has retried once more, instantly - await asyncio.sleep(0) - assert len(transport.write.mock_calls) == 2 +async def test_callbacks(gw): + gw.data_received(b"some ezsp packet") + assert gw._application.frame_received.mock_calls == [call(b"some ezsp packet")] - # Send a proper ACK - gw.frame_received( - # ash.AckFrame(res=0, ncp_ready=0, ack_num=1).to_bytes() - bytes.fromhex("816059") - ) - await asyncio.sleep(0) - assert gw._pending == (-1, None) - assert gw._ack_timeout < old_timeout - - gw.close() - - -@patch("bellows.uart.ASH_RX_ACK_MIN", new=ASH_ACK_MIN * 2**0) -@patch("bellows.uart.ASH_RX_ACK_INIT", new=ASH_ACK_MIN * 2**2) -@patch("bellows.uart.ASH_RX_ACK_MAX", new=ASH_ACK_MIN * 2**3) -async def test_retry_failure(): - app = MagicMock() - transport = 
MagicMock() - connected_future = asyncio.get_running_loop().create_future() - - gw = uart.Gateway(app, connected_future) - gw.connection_made(transport) - - old_timeout = gw._ack_timeout - gw.data(b"TX 1") - await asyncio.sleep(0) - - # Wait more than one ACK cycle to reply - assert len(transport.write.mock_calls) == 1 - await asyncio.sleep(ASH_ACK_MIN * 40) - - # The gateway has exhausted retries - assert len(transport.write.mock_calls) == 5 - - assert gw._pending == (-1, None) - assert gw._ack_timeout > old_timeout - assert gw._ack_timeout == ASH_ACK_MIN * 2**3 # max timeout - - gw.close() + gw.error_received(t.NcpResetCode.RESET_SOFTWARE) + assert gw._application.enter_failed_state.mock_calls == [ + call(t.NcpResetCode.RESET_SOFTWARE) + ]