Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add MQTT Sink #659

Open
wants to merge 5 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion conda/post-link.sh
Original file line number Diff line number Diff line change
Expand Up @@ -5,4 +5,5 @@ $PREFIX/bin/pip install \
'protobuf>=5.27.2,<6.0' \
'influxdb3-python>=0.7,<1.0' \
'pyiceberg[pyarrow,glue]>=0.7,<0.8' \
'redis[hiredis]>=5.2.0,<6'
'redis[hiredis]>=5.2.0,<6' \
'paho-mqtt>=2.1.0,<3'
4 changes: 3 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,8 @@ all = [
"psycopg2-binary>=2.9.9,<3",
"boto3>=1.35.65,<2.0",
"boto3-stubs>=1.35.65,<2.0",
"redis[hiredis]>=5.2.0,<6"
"redis[hiredis]>=5.2.0,<6",
"paho-mqtt>=2.1.0,<3"
]

avro = ["fastavro>=1.8,<2.0"]
Expand All @@ -50,6 +51,7 @@ pubsub = ["google-cloud-pubsub>=2.23.1,<3"]
postgresql = ["psycopg2-binary>=2.9.9,<3"]
kinesis = ["boto3>=1.35.65,<2.0", "boto3-stubs[kinesis]>=1.35.65,<2.0"]
redis=["redis[hiredis]>=5.2.0,<6"]
mqtt=["paho-mqtt>=2.1.0,<3"]

[tool.setuptools.packages.find]
include = ["quixstreams*"]
Expand Down
169 changes: 169 additions & 0 deletions quixstreams/sinks/community/mqtt.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,169 @@
import json
from datetime import datetime
from typing import Any, List, Tuple

from quixstreams.models.types import HeaderValue
from quixstreams.sinks.base.sink import BaseSink

try:
import paho.mqtt.client as paho
from paho import mqtt
except ImportError as exc:
raise ImportError(
'Package "paho-mqtt" is missing: ' "run pip install quixstreams[mqtt] to fix it"
) from exc


class MQTTSink(BaseSink):
"""
A sink that publishes messages to an MQTT broker.
"""

def __init__(
self,
mqtt_client_id: str,
mqtt_server: str,
mqtt_port: int,
mqtt_topic_root: str,
mqtt_username: str = None,
mqtt_password: str = None,
mqtt_version: str = "3.1.1",
tls_enabled: bool = True,
qos: int = 1,
Comment on lines +24 to +32
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
mqtt_client_id: str,
mqtt_server: str,
mqtt_port: int,
mqtt_topic_root: str,
mqtt_username: str = None,
mqtt_password: str = None,
mqtt_version: str = "3.1.1",
tls_enabled: bool = True,
qos: int = 1,
client_id: str,
server: str,
port: int,
topic_root: str,
username: str = None,
password: str = None,
version: str = "3.1.1",
tls_enabled: bool = True,
qos: int = 1,

):
"""
Initialize the MQTTSink.

:param mqtt_client_id: MQTT client identifier.
:param mqtt_server: MQTT broker server address.
:param mqtt_port: MQTT broker server port.
:param mqtt_topic_root: Root topic to publish messages to.
:param mqtt_username: Username for MQTT broker authentication. Defaults to None
:param mqtt_password: Password for MQTT broker authentication. Defaults to None
:param mqtt_version: MQTT protocol version ("3.1", "3.1.1", or "5"). Defaults to 3.1.1
:param tls_enabled: Whether to use TLS encryption. Defaults to True
:param qos: Quality of Service level (0, 1, or 2). Defaults to 1
"""

super().__init__()

self.mqtt_version = mqtt_version
self.mqtt_username = mqtt_username
self.mqtt_password = mqtt_password
self.mqtt_topic_root = mqtt_topic_root
self.tls_enabled = tls_enabled
self.qos = qos

self.mqtt_client = paho.Client(
callback_api_version=paho.CallbackAPIVersion.VERSION2,
client_id=mqtt_client_id,
userdata=None,
protocol=self._mqtt_protocol_version(),
)

if self.tls_enabled:
self.mqtt_client.tls_set(
tls_version=mqtt.client.ssl.PROTOCOL_TLS
) # we'll be using tls now

self.mqtt_client.reconnect_delay_set(5, 60)
self._configure_authentication()
self.mqtt_client.on_connect = self._mqtt_on_connect_cb
self.mqtt_client.on_disconnect = self._mqtt_on_disconnect_cb
self.mqtt_client.connect(mqtt_server, int(mqtt_port))

# setting callbacks for different events to see if it works, print the message etc.
def _mqtt_on_connect_cb(
self,
client: paho.Client,
userdata: any,
connect_flags: paho.ConnectFlags,
reason_code: paho.ReasonCode,
properties: paho.Properties,
):
if reason_code == 0:
print("CONNECTED!") # required for Quix to know this has connected
else:
print(f"ERROR ({reason_code.value}). {reason_code.getName()}")

def _mqtt_on_disconnect_cb(
self,
client: paho.Client,
userdata: any,
disconnect_flags: paho.DisconnectFlags,
reason_code: paho.ReasonCode,
properties: paho.Properties,
):
print(
f"DISCONNECTED! Reason code ({reason_code.value}) {reason_code.getName()}!"
)

def _mqtt_protocol_version(self):
if self.mqtt_version == "3.1":
return paho.MQTTv31
elif self.mqtt_version == "3.1.1":
return paho.MQTTv311
elif self.mqtt_version == "5":
return paho.MQTTv5
else:
raise ValueError(f"Unsupported MQTT version: {self.mqtt_version}")

def _configure_authentication(self):
if self.mqtt_username:
self.mqtt_client.username_pw_set(self.mqtt_username, self.mqtt_password)

def _publish_to_mqtt(
self,
data: str,
key: bytes,
timestamp: datetime,
headers: List[Tuple[str, HeaderValue]],
):
if isinstance(data, bytes):
data = data.decode("utf-8") # Decode bytes to string using utf-8

json_data = json.dumps(data)
message_key_string = key.decode(
"utf-8"
) # Convert to string using utf-8 encoding
# publish to MQTT
self.mqtt_client.publish(
self.mqtt_topic_root + "/" + message_key_string,
payload=json_data,
qos=self.qos,
)

def add(
self,
topic: str,
partition: int,
offset: int,
key: bytes,
value: bytes,
timestamp: datetime,
headers: List[Tuple[str, HeaderValue]],
**kwargs: Any,
):
self._publish_to_mqtt(value, key, timestamp, headers)

def _construct_topic(self, key):
if key:
key_str = key.decode("utf-8") if isinstance(key, bytes) else str(key)
return f"{self.mqtt_topic_root}/{key_str}"
else:
return self.mqtt_topic_root

def on_paused(self, topic: str, partition: int):
# not used
pass

def flush(self, topic: str, partition: str):
# not used
pass

def cleanup(self):
self.mqtt_client.loop_stop()
self.mqtt_client.disconnect()

def __del__(self):
self.cleanup()
1 change: 1 addition & 0 deletions tests/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -7,3 +7,4 @@ protobuf>=5.27.2
influxdb3-python>=0.7.0,<1.0
pyiceberg[pyarrow,glue]>=0.7,<0.8
redis[hiredis]>=5.2.0,<6
paho-mqtt==2.1.0
85 changes: 85 additions & 0 deletions tests/test_quixstreams/test_sinks/test_community/test_mqtt_sink.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
from datetime import datetime
from unittest.mock import patch

import pytest

from quixstreams.sinks.community.mqtt import MQTTSink


@pytest.fixture()
def mqtt_sink_factory():
def factory(
mqtt_client_id: str = "test_client",
mqtt_server: str = "localhost",
mqtt_port: int = 1883,
mqtt_topic_root: str = "test/topic",
mqtt_username: str = None,
mqtt_password: str = None,
mqtt_version: str = "3.1.1",
tls_enabled: bool = True,
qos: int = 1,
) -> MQTTSink:
with patch("paho.mqtt.client.Client") as MockClient:
mock_mqtt_client = MockClient.return_value
sink = MQTTSink(
mqtt_client_id=mqtt_client_id,
mqtt_server=mqtt_server,
mqtt_port=mqtt_port,
mqtt_topic_root=mqtt_topic_root,
mqtt_username=mqtt_username,
mqtt_password=mqtt_password,
mqtt_version=mqtt_version,
tls_enabled=tls_enabled,
qos=qos,
)
sink.mqtt_client = mock_mqtt_client
return sink, mock_mqtt_client

return factory


class TestMQTTSink:
def test_mqtt_connect(self, mqtt_sink_factory):
sink, mock_mqtt_client = mqtt_sink_factory()
mock_mqtt_client.connect.assert_called_once_with("localhost", 1883)

def test_mqtt_tls_enabled(self, mqtt_sink_factory):
sink, mock_mqtt_client = mqtt_sink_factory(tls_enabled=True)
mock_mqtt_client.tls_set.assert_called_once()

def test_mqtt_tls_disabled(self, mqtt_sink_factory):
sink, mock_mqtt_client = mqtt_sink_factory(tls_enabled=False)
mock_mqtt_client.tls_set.assert_not_called()

def test_mqtt_publish(self, mqtt_sink_factory):
sink, mock_mqtt_client = mqtt_sink_factory()
data = "test_data"
key = b"test_key"
timestamp = datetime.now()
headers = []

sink.add(
topic="test-topic",
partition=0,
offset=1,
key=key,
value=data.encode("utf-8"),
timestamp=timestamp,
headers=headers,
)

mock_mqtt_client.publish.assert_called_once_with(
"test/topic/test_key", payload='"test_data"', qos=1
)

def test_mqtt_authentication(self, mqtt_sink_factory):
sink, mock_mqtt_client = mqtt_sink_factory(
mqtt_username="user", mqtt_password="pass"
)
mock_mqtt_client.username_pw_set.assert_called_once_with("user", "pass")

def test_mqtt_disconnect_on_delete(self, mqtt_sink_factory):
sink, mock_mqtt_client = mqtt_sink_factory()
sink.cleanup() # Explicitly call cleanup
mock_mqtt_client.loop_stop.assert_called_once()
mock_mqtt_client.disconnect.assert_called_once()