Skip to content

Commit

Permalink
Zigpy serial protocol (#160)
Browse files Browse the repository at this point in the history
* Migrate zigate to zigpy serial protocol

* Fix unit tests

* Let zigpy handle flow control

* Bump minimum zigpy version

* Remove unnecessary `close`

* Clean API only on close

* Fix annotations

* Test `connection_lost`
  • Loading branch information
puddly authored Oct 27, 2024
1 parent 93c7358 commit eb4d141
Show file tree
Hide file tree
Showing 7 changed files with 46 additions and 72 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ license = {text = "GPL-3.0"}
requires-python = ">=3.8"
dependencies = [
"voluptuous",
"zigpy>=0.66.0",
"zigpy>=0.70.0",
"pyusb>=1.1.0",
"gpiozero",
'async-timeout; python_version<"3.11"',
Expand Down
11 changes: 7 additions & 4 deletions tests/test_api.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import asyncio
from unittest.mock import MagicMock, patch, sentinel
from unittest.mock import AsyncMock, MagicMock, patch, sentinel

import pytest
import serial_asyncio
Expand Down Expand Up @@ -37,10 +37,13 @@ async def mock_conn(loop, protocol_factory, **kwargs):
await api.connect()


def test_close(api):
@pytest.mark.asyncio
async def test_disconnect(api):
uart = api._uart
api.close()
assert uart.close.call_count == 1
uart.disconnect = AsyncMock()

await api.disconnect()
assert uart.disconnect.call_count == 1
assert api._uart is None


Expand Down
12 changes: 6 additions & 6 deletions tests/test_application.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,28 +102,28 @@ async def mock_get_network_state():

@pytest.mark.asyncio
async def test_disconnect_success(app):
api = MagicMock()
api = AsyncMock()

app._api = api
await app.disconnect()

api.close.assert_called_once()
api.disconnect.assert_called_once()
assert app._api is None


@pytest.mark.asyncio
async def test_disconnect_failure(app, caplog):
api = MagicMock()
api.disconnect = MagicMock(side_effect=RuntimeError("Broken"))
api = AsyncMock()
api.reset = AsyncMock(side_effect=RuntimeError("Broken"))

app._api = api

with caplog.at_level(logging.WARNING):
await app.disconnect()

assert "disconnect" in caplog.text
assert "Failed to reset before disconnect" in caplog.text

api.close.assert_called_once()
api.disconnect.assert_called_once()
assert app._api is None


Expand Down
15 changes: 7 additions & 8 deletions tests/test_uart.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from unittest.mock import AsyncMock, MagicMock
from unittest.mock import AsyncMock, MagicMock, call

import gpiozero
import pytest
Expand Down Expand Up @@ -52,6 +52,12 @@ def test_close(gw):
assert gw._transport.close.call_count == 1


def test_connection_lost(gw):
exc = RuntimeError()
gw.connection_lost(exc)
assert gw._api.connection_lost.mock_calls == [call(exc)]


def test_data_received_chunk_frame(gw):
data = b"\x01\x80\x10\x02\x10\x02\x15\xaa\x02\x10\x02\x1f?\xf0\xff\x03"
gw.data_received(data[:-4])
Expand Down Expand Up @@ -108,13 +114,6 @@ def test_escape(gw):
assert r == data_escaped


def test_length(gw):
data = b"\x80\x10\x00\x05\xaa\x00\x0f?\xf0\xff"
length = 5
r = gw._length(data)
assert r == length


def test_checksum(gw):
data = b"\x00\x0f?\xf0"
checksum = 0xAA
Expand Down
6 changes: 3 additions & 3 deletions zigpy_zigate/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -246,9 +246,9 @@ def connection_lost(self, exc: Exception) -> None:
if self._app is not None:
self._app.connection_lost(exc)

def close(self):
if self._uart:
self._uart.close()
async def disconnect(self):
if self._uart is not None:
await self._uart.disconnect()
self._uart = None

def set_application(self, app):
Expand Down
70 changes: 21 additions & 49 deletions zigpy_zigate/uart.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
from __future__ import annotations

import asyncio
import binascii
import logging
import struct
from typing import Any, Dict
from typing import Any

import zigpy.config
import zigpy.serial
Expand All @@ -12,39 +14,24 @@
LOGGER = logging.getLogger(__name__)


class Gateway(asyncio.Protocol):
class Gateway(zigpy.serial.SerialProtocol):
START = b"\x01"
END = b"\x03"

def __init__(self, api, connected_future=None):
self._buffer = b""
self._connected_future = connected_future
def __init__(self, api):
super().__init__()
self._api = api

def connection_lost(self, exc) -> None:
"""Port was closed expecteddly or unexpectedly."""
if self._connected_future and not self._connected_future.done():
if exc is None:
self._connected_future.set_result(True)
else:
self._connected_future.set_exception(exc)
if exc is None:
LOGGER.debug("Closed serial connection")
return

LOGGER.error("Lost serial connection: %s", exc)
self._api.connection_lost(exc)
def connection_lost(self, exc: Exception | None) -> None:
"""Port was closed expectedly or unexpectedly."""
super().connection_lost(exc)

def connection_made(self, transport):
"""Callback when the uart is connected"""
LOGGER.debug("Connection made")
self._transport = transport
if self._connected_future:
self._connected_future.set_result(True)
if self._api is not None:
self._api.connection_lost(exc)

def close(self):
if self._transport:
self._transport.close()
super().close()
self._api = None

def send(self, cmd, data=b""):
"""Send data, taking care of escaping and framing"""
Expand All @@ -60,8 +47,7 @@ def send(self, cmd, data=b""):

def data_received(self, data):
"""Callback when there is data received from the uart"""
self._buffer += data
# LOGGER.debug('data_received %s', self._buffer)
super().data_received(data)
endpos = self._buffer.find(self.END)
while endpos != -1:
startpos = self._buffer.rfind(self.START, 0, endpos)
Expand All @@ -71,7 +57,7 @@ def data_received(self, data):
cmd, length, checksum, f_data, lqi = struct.unpack(
"!HHB%dsB" % (len(frame) - 6), frame
)
if self._length(frame) != length:
if len(frame) - 5 != length:
LOGGER.warning(
"Invalid length: %s, data: %s", length, len(frame) - 6
)
Expand Down Expand Up @@ -126,42 +112,28 @@ def _checksum(self, *args):
chcksum ^= x
return chcksum

def _length(self, frame):
length = len(frame) - 5
return length


async def connect(device_config: Dict[str, Any], api, loop=None):
if loop is None:
loop = asyncio.get_event_loop()

connected_future = asyncio.Future()
protocol = Gateway(api, connected_future)

async def connect(device_config: dict[str, Any], api, loop=None):
loop = asyncio.get_running_loop()
port = device_config[zigpy.config.CONF_DEVICE_PATH]
if port == "auto":
port = await loop.run_in_executor(None, c.discover_port)

if await c.async_is_pizigate(port):
LOGGER.debug("PiZiGate detected")
await c.async_set_pizigate_running_mode()
# in case of pizigate:/dev/ttyAMA0 syntax
if port.startswith("pizigate:"):
port = port.replace("pizigate:", "", 1)
port = port.replace("pizigate:", "", 1)
elif await c.async_is_zigate_din(port):
LOGGER.debug("ZiGate USB DIN detected")
await c.async_set_zigatedin_running_mode()
elif c.is_zigate_wifi(port):
LOGGER.debug("ZiGate WiFi detected")

protocol = Gateway(api)
_, protocol = await zigpy.serial.create_serial_connection(
loop,
lambda: protocol,
url=port,
baudrate=device_config[zigpy.config.CONF_DEVICE_BAUDRATE],
xonxoff=False,
flow_control=device_config[zigpy.config.CONF_DEVICE_FLOW_CONTROL],
)

await connected_future
await protocol.wait_until_connected()

return protocol
2 changes: 1 addition & 1 deletion zigpy_zigate/zigbee/application.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ async def disconnect(self):
except Exception as e:
LOGGER.warning("Failed to reset before disconnect: %s", e)
finally:
self._api.close()
await self._api.disconnect()
self._api = None

async def start_network(self):
Expand Down

0 comments on commit eb4d141

Please sign in to comment.