diff --git a/config.yaml b/config.yaml index a69a2c7..a1b1974 100644 --- a/config.yaml +++ b/config.yaml @@ -13,6 +13,7 @@ mqtt: return_brightness_topic: "MQTTAnimator/rbrightness" return_anim_topic: "MQTTAnimator/ranimation" args_topic: "MQTTAnimator/args" + full_args_topic: "MQTTAnimator/fargs" animation_topic: "MQTTAnimator/animation" reconnection: first_reconnect_delay: 1 diff --git a/mqtt_animator.py b/mqtt_animator.py index 1d13ad9..7a359d1 100644 --- a/mqtt_animator.py +++ b/mqtt_animator.py @@ -7,6 +7,7 @@ import time import threading import traceback +import dataclasses import board import neopixel @@ -14,6 +15,7 @@ from paho.mqtt import client as mqtt_client import animator +from animator import AnimationArgs # Import yaml config with open("config.yaml", encoding="utf-8") as stream: @@ -42,6 +44,7 @@ state_topic: str = mqtt_topics.get("state_topic", "MQTTAnimator/state") brightness_topic: str = mqtt_topics.get("brightness_topic", "MQTTAnimator/brightness") args_topic: str = mqtt_topics.get("args_topic", "MQTTAnimator/args") +full_args_topic: str = mqtt_topics.get("full_args_topic", "MQTTAnimator/fargs") animation_topic: str = mqtt_topics.get("animation_topic", "MQTTAnimator/animation") data_request_return_topic: str = mqtt_topics.get("return_data_request_topic", @@ -63,6 +66,8 @@ pixel_pin = getattr(board, driver_config.get("pin", "D18")) # rpi gpio pin pixel_order = driver_config.get("order", "RGB") # Color order +global animation_args + animation_args = animator.AnimationArgs() animation_args.single_color.color = [0, 255, 0] @@ -75,6 +80,25 @@ ) animator = animator.Animator(pixels, num_pixels, animation_state, animation_args) +def validate_arg_import(json_data, dataclass_type): + # Convert the JSON data to a dictionary + data_dict = json.loads(json_data) + + # Get the fields of the dataclass + class_fields = dataclasses.fields(dataclass_type) + + # Check if all the required fields are present in the JSON data + for field in class_fields: + field_name = field.name + if field_name not in data_dict: + return False + + # If the field is a nested dataclass, recursively validate it + if hasattr(field.type, "__annotations__"): + if not validate_arg_import(json.dumps(data_dict[field_name]), field.type): + return False + + return True def on_connect(_, __, ___, rc): "On disconnection of mqtt" @@ -108,6 +132,8 @@ def on_disconnect(cli, _, rc): def on_message(cli: mqtt_client.Client, __, msg): + global animation_args + "Callback for mqtt message recieved" print(f"Received `{msg.payload.decode()}` from `{msg.topic}` topic") @@ -115,7 +141,10 @@ def on_message(cli: mqtt_client.Client, __, msg): cli.publish(data_request_return_topic, json.dumps({"state": animation_state.state, "brightness": animation_state.brightness, - "animation": animation_state.effect})) + "animation": animation_state.effect, + "args": json.dumps(dataclasses.asdict(animation_args)) + }) + ) elif msg.topic == state_topic: animation_state.state = "ON" if msg.payload.decode() == "ON" else "OFF" cli.publish(state_return_topic, "ON" if msg.payload.decode() == "ON" else "OFF") @@ -136,6 +165,16 @@ def on_message(cli: mqtt_client.Client, __, msg): for key, value in data.items(): setattr(argument, key, value) + elif msg.topic == full_args_topic: + try: + data = json.loads(msg.payload.decode("utf-8")) + except json.JSONDecodeError: + return + + if not validate_arg_import(data, animation_args): + return + + animation_args = AnimationArgs(**json.loads(data)) if __name__ == "__main__":