From 291a36373af7a9a185e55d6482a55a6a8f924279 Mon Sep 17 00:00:00 2001 From: Andriy Kushnir Date: Fri, 23 Feb 2024 13:43:40 +0200 Subject: [PATCH] Simplify decoding (#256) --- roombapy/roomba.py | 79 ++++++++++++++++--------------------------- tests/test_payload.py | 37 ++++++++++++++++++++ 2 files changed, 67 insertions(+), 49 deletions(-) create mode 100644 tests/test_payload.py diff --git a/roombapy/roomba.py b/roombapy/roomba.py index ef4c3b6..ff6917f 100755 --- a/roombapy/roomba.py +++ b/roombapy/roomba.py @@ -14,11 +14,14 @@ outlines if you don't have OpenCV """ +from __future__ import annotations + import logging import threading import time from collections.abc import Mapping from datetime import datetime +from typing import Any import orjson @@ -185,29 +188,28 @@ def on_message(self, _mosq, _obj, msg): if self.indent == 0: self.master_indent = max(self.master_indent, len(msg.topic)) - log_string, json_data = self.decode_payload(msg.topic, msg.payload) - self.dict_merge(self.master_state, json_data) + decoded_message = _decode_payload(msg.payload) + client_ip = self.remote_client.address - self.log.debug( - "Received Roomba Data %s: %s, %s", - self.remote_client.address, - str(msg.topic), - str(msg.payload), - ) + if decoded_message is None: + self.log.warning( + "Got malformed message from %s: %s", client_ip, msg + ) + return - self.decode_topics(json_data) + self.dict_merge(self.master_state, decoded_message) + self.log.debug("Received message from %s: %s", client_ip, msg) + self.decode_topics(decoded_message) # default every 5 minutes if time.time() - self.time > self.update_seconds: - self.log.debug( - "Publishing master_state %s", self.remote_client.address - ) + self.log.debug("Publishing master_state %s", client_ip) self.decode_topics(self.master_state) # publish all values self.time = time.time() # call the callback functions for callback in self.on_message_callbacks: - callback(json_data) + callback(decoded_message) def send_command(self, command, params=None): """Send a command to the Roomba.""" @@ -267,42 +269,6 @@ def dict_merge(self, dct, merge_dct): else: dct[k] = merge_dct[k] - def decode_payload(self, _topic, payload): - """Format json for pretty printing. - - Returns string sutiable for logging, and a dict of the json data - """ - indent = self.master_indent + 31 # number of spaces to indent json data - - json_data = None - try: - # if it's json data, decode it. OrderedDict is no longer - # needed since python 3.6 and later guarantees dict - # insertion order - json_data = orjson.loads( - payload.decode("utf-8") - .replace(":nan", ":NaN") - .replace(":inf", ":Infinity") - .replace(":-inf", ":-Infinity"), - ) - # if it's not a dictionary, probably just a number - if not isinstance(json_data, dict): - return json_data, dict(json_data) - json_data_string = "\n".join( - (indent * " ") + i - for i in ( - orjson.dumps(json_data, option=orjson.OPT_INDENT_2).decode( - "utf-8" - ) - ).splitlines() - ) - - formatted_data = "Decoded JSON: \n%s" % json_data_string - - except ValueError: - formatted_data = payload - return formatted_data, dict(json_data) - def decode_topics(self, state, prefix=None): """Decode json data dict and publish as individual topics. @@ -483,3 +449,18 @@ def update_state_machine(self, new_state=None): if self.current_state != current_mission: self.log.debug("State updated to: %s", self.current_state) + + +def _decode_payload(raw_payload: bytes) -> dict[str, Any] | None: + try: + payload = raw_payload.decode() + message = orjson.loads(payload) + except UnicodeDecodeError: + return None + except orjson.JSONDecodeError: + return None + + if not isinstance(message, dict): + return None + + return message diff --git a/tests/test_payload.py b/tests/test_payload.py new file mode 100644 index 0000000..a2783e6 --- /dev/null +++ b/tests/test_payload.py @@ -0,0 +1,37 @@ +"""Test the decoding of the Roomba messages.""" +from roombapy.roomba import _decode_payload + + +def test_skip_garbage() -> None: + """Skip garbage data in payload.""" + assert _decode_payload(b"\x00") is None + + +def test_skip_broken_json() -> None: + """Skip broken JSON.""" + assert _decode_payload(b"[") is None + assert _decode_payload(b"{") is None + + +def test_skip_non_object_json() -> None: + """Allow only objects in messages.""" + assert _decode_payload(b"[]") is None + assert _decode_payload(b"12") is None + + +def test_allow_empty_json() -> None: + """Allow empty objects.""" + assert _decode_payload(b"{}") == {} + + +def test_allow_valid_json() -> None: + """Properly decode valid JSON object.""" + payload = b""" + {"state": {"reported": {"signal": {"rssi": -45, "snr": 18, "noise": -63}}}} + """ + decoded = { + "state": { + "reported": {"signal": {"rssi": -45, "snr": 18, "noise": -63}} + } + } + assert _decode_payload(payload) == decoded