Skip to content

Commit

Permalink
Merge branch 'main' into 325adependabot/pip/paho-mqtt-2.1.0
Browse files Browse the repository at this point in the history
  • Loading branch information
allenporter authored Jan 13, 2025
2 parents 2a84279 + 5add0da commit 7613930
Show file tree
Hide file tree
Showing 5 changed files with 152 additions and 34 deletions.
82 changes: 53 additions & 29 deletions roborock/cloud_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,16 +21,44 @@
DISCONNECT_REQUEST_ID = 1


class RoborockMqttClient(RoborockClient, mqtt.Client, ABC):
class _Mqtt(mqtt.Client):
"""Internal MQTT client.
This is a subclass of the Paho MQTT client that adds some additional functionality
for error cases where things get stuck.
"""

_thread: threading.Thread
_client_id: str

def __init__(self) -> None:
"""Initialize the MQTT client."""
super().__init__(protocol=mqtt.MQTTv5)
self.reset_client_id()

def reset_client_id(self):
"""Generate a new client id to make a new session when reconnecting."""
self._client_id = mqtt.base62(uuid.uuid4().int, padding=22)

def maybe_restart_loop(self) -> None:
"""Ensure that the MQTT loop is running in case it previously exited."""
if not self._thread or not self._thread.is_alive():
if self._thread:
_LOGGER.info("Stopping mqtt loop")
super().loop_stop()
_LOGGER.info("Starting mqtt loop")
super().loop_start()


class RoborockMqttClient(RoborockClient, ABC):
"""Roborock MQTT client base class."""

def __init__(self, user_data: UserData, device_info: DeviceData, queue_timeout: int = 10) -> None:
"""Initialize the Roborock MQTT client."""
rriot = user_data.rriot
if rriot is None:
raise RoborockException("Got no rriot data from user_data")
RoborockClient.__init__(self, device_info, queue_timeout)
mqtt.Client.__init__(self, protocol=mqtt.MQTTv5)
self._mqtt_user = rriot.u
self._hashed_user = md5hex(self._mqtt_user + ":" + rriot.k)[2:10]
url = urlparse(rriot.r.m)
Expand All @@ -39,16 +67,21 @@ def __init__(self, user_data: UserData, device_info: DeviceData, queue_timeout:
self._mqtt_host = str(url.hostname)
self._mqtt_port = url.port
self._mqtt_ssl = url.scheme == "ssl"

self._mqtt_client = _Mqtt()
self._mqtt_client.on_connect = self._mqtt_on_connect
self._mqtt_client.on_message = self._mqtt_on_message
self._mqtt_client.on_disconnect = self._mqtt_on_disconnect
if self._mqtt_ssl:
super().tls_set()
self._mqtt_client.tls_set()

self._mqtt_password = rriot.s
self._hashed_password = md5hex(self._mqtt_password + ":" + rriot.k)[16:]
super().username_pw_set(self._hashed_user, self._hashed_password)
self._mqtt_client.username_pw_set(self._hashed_user, self._hashed_password)
self._waiting_queue: dict[int, RoborockFuture] = {}
self._mutex = Lock()
self.update_client_id()

def on_connect(self, *args, **kwargs):
def _mqtt_on_connect(self, *args, **kwargs):
_, __, ___, rc, ____ = args
connection_queue = self._waiting_queue.get(CONNECT_REQUEST_ID)
if rc != mqtt.MQTT_ERR_SUCCESS:
Expand All @@ -59,7 +92,7 @@ def on_connect(self, *args, **kwargs):
return
self._logger.info(f"Connected to mqtt {self._mqtt_host}:{self._mqtt_port}")
topic = f"rr/m/o/{self._mqtt_user}/{self._hashed_user}/{self.device_info.device.duid}"
(result, mid) = self.subscribe(topic)
(result, mid) = self._mqtt_client.subscribe(topic)
if result != 0:
message = f"Failed to subscribe ({mqtt.error_string(rc)})"
self._logger.error(message)
Expand All @@ -70,48 +103,38 @@ def on_connect(self, *args, **kwargs):
if connection_queue:
connection_queue.set_result(True)

def on_message(self, *args, **kwargs):
def _mqtt_on_message(self, *args, **kwargs):
client, __, msg = args
try:
messages, _ = MessageParser.parse(msg.payload, local_key=self.device_info.device.local_key)
super().on_message_received(messages)
except Exception as ex:
self._logger.exception(ex)

def on_disconnect(self, *args, **kwargs):
def _mqtt_on_disconnect(self, *args, **kwargs):
_, __, rc, ___ = args
try:
exc = RoborockException(mqtt.error_string(rc)) if rc != mqtt.MQTT_ERR_SUCCESS else None
super().on_connection_lost(exc)
if rc == mqtt.MQTT_ERR_PROTOCOL:
self.update_client_id()
self._mqtt_client.reset_client_id()
connection_queue = self._waiting_queue.get(DISCONNECT_REQUEST_ID)
if connection_queue:
connection_queue.set_result(True)
except Exception as ex:
self._logger.exception(ex)

def update_client_id(self):
self._client_id = mqtt.base62(uuid.uuid4().int, padding=22)

def sync_stop_loop(self) -> None:
if self._thread:
self._logger.info("Stopping mqtt loop")
super().loop_stop()

def sync_start_loop(self) -> None:
if not self._thread or not self._thread.is_alive():
self.sync_stop_loop()
self._logger.info("Starting mqtt loop")
super().loop_start()
def is_connected(self) -> bool:
"""Check if the mqtt client is connected."""
return self._mqtt_client.is_connected()

def sync_disconnect(self) -> Any:
if not self.is_connected():
return None

self._logger.info("Disconnecting from mqtt")
disconnected_future = self._async_response(DISCONNECT_REQUEST_ID)
rc = super().disconnect()
rc = self._mqtt_client.disconnect()

if rc == mqtt.MQTT_ERR_NO_CONN:
disconnected_future.cancel()
Expand All @@ -125,17 +148,16 @@ def sync_disconnect(self) -> Any:

def sync_connect(self) -> Any:
if self.is_connected():
self.sync_start_loop()
self._mqtt_client.maybe_restart_loop()
return None

if self._mqtt_port is None or self._mqtt_host is None:
raise RoborockException("Mqtt information was not entered. Cannot connect.")

self._logger.debug("Connecting to mqtt")
connected_future = self._async_response(CONNECT_REQUEST_ID)
super().connect(host=self._mqtt_host, port=self._mqtt_port, keepalive=KEEPALIVE)

self.sync_start_loop()
self._mqtt_client.connect(host=self._mqtt_host, port=self._mqtt_port, keepalive=KEEPALIVE)
self._mqtt_client.maybe_restart_loop()
return connected_future

async def async_disconnect(self) -> None:
Expand All @@ -155,6 +177,8 @@ async def async_connect(self) -> None:
raise RoborockException(err) from err

def _send_msg_raw(self, msg: bytes) -> None:
info = self.publish(f"rr/m/i/{self._mqtt_user}/{self._hashed_user}/{self.device_info.device.duid}", msg)
info = self._mqtt_client.publish(
f"rr/m/i/{self._mqtt_user}/{self._hashed_user}/{self.device_info.device.duid}", msg
)
if info.rc != mqtt.MQTT_ERR_SUCCESS:
raise RoborockException(f"Failed to publish ({mqtt.error_string(info.rc)})")
21 changes: 17 additions & 4 deletions roborock/local_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
import logging
from abc import ABC
from asyncio import Lock, TimerHandle, Transport
from collections.abc import Callable
from dataclasses import dataclass

import async_timeout

Expand All @@ -16,7 +18,15 @@
_LOGGER = logging.getLogger(__name__)


class RoborockLocalClient(RoborockClient, asyncio.Protocol, ABC):
@dataclass
class _LocalProtocol(asyncio.Protocol):
"""Callbacks for the Roborock local client transport."""

messages_cb: Callable[[bytes], None]
connection_lost_cb: Callable[[Exception | None], None]


class RoborockLocalClient(RoborockClient, ABC):
"""Roborock local client base class."""

def __init__(self, device_data: DeviceData, queue_timeout: int = 4):
Expand All @@ -31,15 +41,18 @@ def __init__(self, device_data: DeviceData, queue_timeout: int = 4):
self._mutex = Lock()
self.keep_alive_task: TimerHandle | None = None
RoborockClient.__init__(self, device_data, queue_timeout)
self._local_protocol = _LocalProtocol(self._data_received, self._connection_lost)

def data_received(self, message):
def _data_received(self, message):
"""Called when data is received from the transport."""
if self.remaining:
message = self.remaining + message
self.remaining = b""
parser_msg, self.remaining = MessageParser.parse(message, local_key=self.device_info.device.local_key)
self.on_message_received(parser_msg)

def connection_lost(self, exc: Exception | None):
def _connection_lost(self, exc: Exception | None):
"""Called when the transport connection is lost."""
self.sync_disconnect()
self.on_connection_lost(exc)

Expand All @@ -62,7 +75,7 @@ async def async_connect(self) -> None:
async with async_timeout.timeout(self.queue_timeout):
self._logger.debug(f"Connecting to {self.host}")
self.transport, _ = await self.event_loop.create_connection( # type: ignore
lambda: self, self.host, 58867
lambda: self._local_protocol, self.host, 58867
)
self._logger.info(f"Connected to {self.host}")
should_ping = True
Expand Down
45 changes: 44 additions & 1 deletion tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import asyncio
import io
import logging
import re
from asyncio import Protocol
from collections.abc import Callable, Generator
from queue import Queue
from typing import Any
Expand All @@ -11,8 +13,9 @@

from roborock import HomeData, UserData
from roborock.containers import DeviceData
from roborock.version_1_apis.roborock_local_client_v1 import RoborockLocalClientV1
from roborock.version_1_apis.roborock_mqtt_client_v1 import RoborockMqttClientV1
from tests.mock_data import HOME_DATA_RAW, USER_DATA
from tests.mock_data import HOME_DATA_RAW, TEST_LOCAL_API_HOST, USER_DATA

_LOGGER = logging.getLogger(__name__)

Expand Down Expand Up @@ -191,3 +194,43 @@ def mock_rest() -> aioresponses:
payload={"api": None, "code": 200, "result": HOME_DATA_RAW, "status": "ok", "success": True},
)
yield mocked


@pytest.fixture(name="mock_create_local_connection")
def create_local_connection_fixture(request_handler: RequestHandler) -> Generator[None, None, None]:
"""Fixture that overrides the transport creation to wire it up to the mock socket."""

async def create_connection(protocol_factory: Callable[[], Protocol], *args) -> tuple[Any, Any]:
protocol = protocol_factory()

def handle_write(data: bytes) -> None:
_LOGGER.debug("Received: %s", data)
response = request_handler(data)
if response is not None:
_LOGGER.debug("Replying with %s", response)
loop = asyncio.get_running_loop()
loop.call_soon(protocol.data_received, response)

closed = asyncio.Event()

mock_transport = Mock()
mock_transport.write = handle_write
mock_transport.close = closed.set
mock_transport.is_reading = lambda: not closed.is_set()

return (mock_transport, "proto")

with patch("roborock.api.get_running_loop_or_create_one") as mock_loop:
mock_loop.return_value.create_connection.side_effect = create_connection
yield


@pytest.fixture(name="local_client")
def local_client_fixture(mock_create_local_connection: None) -> Generator[RoborockLocalClientV1, None, None]:
home_data = HomeData.from_dict(HOME_DATA_RAW)
device_info = DeviceData(
device=home_data.devices[0],
model=home_data.products[0].model,
host=TEST_LOCAL_API_HOST,
)
yield RoborockLocalClientV1(device_info)
1 change: 1 addition & 0 deletions tests/mock_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -347,3 +347,4 @@
GET_CODE_RESPONSE = {"code": 200, "msg": "success", "data": None}
HASHED_USER = hashlib.md5((USER_ID + ":" + K_VALUE).encode()).hexdigest()[2:10]
MQTT_PUBLISH_TOPIC = f"rr/m/o/{USER_ID}/{HASHED_USER}/{PRODUCT_ID}"
TEST_LOCAL_API_HOST = "1.1.1.1"
37 changes: 37 additions & 0 deletions tests/test_local_api_v1.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
"""Tests for the Roborock Local Client V1."""

from queue import Queue

from roborock.protocol import MessageParser
from roborock.roborock_message import RoborockMessage, RoborockMessageProtocol
from roborock.version_1_apis import RoborockLocalClientV1

from .mock_data import LOCAL_KEY


def build_rpc_response(protocol: RoborockMessageProtocol, seq: int) -> bytes:
"""Build an encoded RPC response message."""
message = RoborockMessage(
protocol=protocol,
random=23,
seq=seq,
payload=b"ignored",
)
return MessageParser.build(message, local_key=LOCAL_KEY)


async def test_async_connect(
local_client: RoborockLocalClientV1,
received_requests: Queue,
response_queue: Queue,
):
"""Test that we can connect to the Roborock device."""
response_queue.put(build_rpc_response(RoborockMessageProtocol.HELLO_RESPONSE, 1))
response_queue.put(build_rpc_response(RoborockMessageProtocol.PING_RESPONSE, 2))

await local_client.async_connect()
assert local_client.is_connected()
assert received_requests.qsize() == 2

await local_client.async_disconnect()
assert not local_client.is_connected()

0 comments on commit 7613930

Please sign in to comment.