Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add support for "unix" transport where socket module contains AF_UNIX #829

Merged
merged 2 commits into from
Apr 29, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 20 additions & 5 deletions src/paho/mqtt/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -682,6 +682,10 @@ class Client:

:param transport: use "websockets" to use WebSockets as the transport
mechanism. Set to "tcp" to use raw TCP, which is the default.
Use "unix" to use Unix sockets as the transport mechanism; note that
this option is only available on platforms that support Unix sockets,
and the "host" argument is interpreted as the path to the Unix socket
file in this case.

:param bool manual_ack: normally, when a message is received, the library automatically
acknowledges after on_message callback returns. manual_ack=True allows the application to
Expand Down Expand Up @@ -733,14 +737,16 @@ def __init__(
clean_session: bool | None = None,
userdata: Any = None,
protocol: MQTTProtocolVersion = MQTTv311,
transport: Literal["tcp", "websockets"] = "tcp",
transport: Literal["tcp", "websockets", "unix"] = "tcp",
reconnect_on_failure: bool = True,
manual_ack: bool = False,
) -> None:
transport = transport.lower() # type: ignore
if transport not in ("websockets", "tcp"):
if transport == "unix" and not hasattr(socket, "AF_UNIX"):
raise ValueError('"unix" transport not supported')
elif transport not in ("websockets", "tcp", "unix"):
raise ValueError(
f'transport must be "websockets" or "tcp", not {transport}')
f'transport must be "websockets", "tcp" or "unix", not {transport}')

self._manual_ack = manual_ack
self._transport = transport
Expand Down Expand Up @@ -931,7 +937,7 @@ def keepalive(self, value: int) -> None:
self._keepalive = value

@property
def transport(self) -> Literal["tcp", "websockets"]:
def transport(self) -> Literal["tcp", "websockets", "unix"]:
"""
Transport method used for the connection ("tcp" or "websockets").

Expand Down Expand Up @@ -4595,7 +4601,11 @@ def _get_proxy(self) -> dict[str, Any] | None:
return None

def _create_socket(self) -> SocketLike:
sock = self._create_socket_connection()
if self._transport == "unix":
sock = self._create_unix_socket_connection()
else:
sock = self._create_socket_connection()

if self._ssl:
sock = self._ssl_wrap_socket(sock)

Expand All @@ -4612,6 +4622,11 @@ def _create_socket(self) -> SocketLike:

return sock

def _create_unix_socket_connection(self) -> _socket.socket:
unix_socket = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
unix_socket.connect(self._host)
return unix_socket

def _create_socket_connection(self) -> _socket.socket:
proxy = self._get_proxy()
addr = (self._host, self._port)
Expand Down
29 changes: 20 additions & 9 deletions tests/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ def test_01_con_discon_success(self, proto_ver, callback_version, fake_broker):
callback_version,
"01-con-discon-success",
protocol=proto_ver,
transport=fake_broker.transport,
)

def on_connect(mqttc, obj, flags, rc_or_reason_code, properties_or_none=None):
Expand Down Expand Up @@ -70,7 +71,8 @@ def on_connect(mqttc, obj, flags, rc_or_reason_code, properties_or_none=None):

def test_01_con_failure_rc(self, proto_ver, callback_version, fake_broker):
mqttc = client.Client(
callback_version, "01-con-failure-rc", protocol=proto_ver)
callback_version, "01-con-failure-rc",
protocol=proto_ver, transport=fake_broker.transport)

def on_connect(mqttc, obj, flags, rc_or_reason_code, properties_or_none=None):
assert rc_or_reason_code > 0
Expand Down Expand Up @@ -107,7 +109,9 @@ def on_connect(mqttc, obj, flags, rc_or_reason_code, properties_or_none=None):
mqttc.loop_stop()

def test_connection_properties(self, proto_ver, callback_version, fake_broker):
mqttc = client.Client(CallbackAPIVersion.VERSION2, "client-id", protocol=proto_ver)
mqttc = client.Client(
CallbackAPIVersion.VERSION2, "client-id",
protocol=proto_ver, transport=fake_broker.transport)
mqttc.enable_logger()

is_connected = threading.Event()
Expand All @@ -131,7 +135,7 @@ def on_disconnect(*args):
mqttc.keepalive = 7
mqttc.max_inflight_messages = 7
mqttc.max_queued_messages = 7
mqttc.transport = "tcp"
mqttc.transport = fake_broker.transport
mqttc.username = "username"
mqttc.password = "password"

Expand Down Expand Up @@ -184,7 +188,7 @@ def on_disconnect(*args):
mqttc.max_queued_messages = 7

with pytest.raises(RuntimeError):
mqttc.transport = "tcp"
mqttc.transport = fake_broker.transport

with pytest.raises(RuntimeError):
mqttc.username = "username"
Expand Down Expand Up @@ -217,7 +221,9 @@ class Test_connect_v5:
"""

def test_01_broker_no_support(self, fake_broker):
mqttc = client.Client(CallbackAPIVersion.VERSION2, "01-broker-no-support", protocol=MQTTProtocolVersion.MQTTv5)
mqttc = client.Client(
CallbackAPIVersion.VERSION2, "01-broker-no-support",
protocol=MQTTProtocolVersion.MQTTv5, transport=fake_broker.transport)

def on_connect(mqttc, obj, flags, reason, properties):
assert reason == 132
Expand Down Expand Up @@ -261,6 +267,7 @@ def test_with_loop_start(self, fake_broker: FakeBroker):
"test_with_loop_start",
protocol=MQTTProtocolVersion.MQTTv311,
reconnect_on_failure=False,
transport=fake_broker.transport
)

on_connect_reached = threading.Event()
Expand Down Expand Up @@ -311,6 +318,7 @@ def test_with_loop(self, fake_broker: FakeBroker):
CallbackAPIVersion.VERSION1,
"test_with_loop",
clean_session=True,
transport=fake_broker.transport,
)

on_connect_reached = threading.Event()
Expand Down Expand Up @@ -367,6 +375,7 @@ def test_publish_before_connect(self, fake_broker: FakeBroker) -> None:
mqttc = client.Client(
CallbackAPIVersion.VERSION1,
"test_publish_before_connect",
transport=fake_broker.transport,
)

def on_connect(mqttc, obj, flags, rc):
Expand Down Expand Up @@ -424,7 +433,7 @@ def on_connect(mqttc, obj, flags, rc):
])
class TestPublishBroker2Client:
def test_invalid_utf8_topic(self, callback_version, fake_broker):
mqttc = client.Client(callback_version, "client-id")
mqttc = client.Client(callback_version, "client-id", transport=fake_broker.transport)

def on_message(client, userdata, msg):
with pytest.raises(UnicodeDecodeError):
Expand Down Expand Up @@ -466,7 +475,7 @@ def on_message(client, userdata, msg):
assert not packet_in # Check connection is closed

def test_valid_utf8_topic_recv(self, callback_version, fake_broker):
mqttc = client.Client(callback_version, "client-id")
mqttc = client.Client(callback_version, "client-id", transport=fake_broker.transport)

# It should be non-ascii multi-bytes character
topic = unicodedata.lookup('SNOWMAN')
Expand Down Expand Up @@ -512,7 +521,7 @@ def on_message(client, userdata, msg):
assert not packet_in # Check connection is closed

def test_valid_utf8_topic_publish(self, callback_version, fake_broker):
mqttc = client.Client(callback_version, "client-id")
mqttc = client.Client(callback_version, "client-id", transport=fake_broker.transport)

# It should be non-ascii multi-bytes character
topic = unicodedata.lookup('SNOWMAN')
Expand Down Expand Up @@ -558,7 +567,7 @@ def test_valid_utf8_topic_publish(self, callback_version, fake_broker):
assert not packet_in # Check connection is closed

def test_message_callback(self, callback_version, fake_broker):
mqttc = client.Client(callback_version, "client-id")
mqttc = client.Client(callback_version, "client-id", transport=fake_broker.transport)
userdata = {
'on_message': 0,
'callback1': 0,
Expand Down Expand Up @@ -698,6 +707,7 @@ def test_callback_v1_mqtt3(self, fake_broker):
CallbackAPIVersion.VERSION1,
"client-id",
userdata=callback_called,
transport=fake_broker.transport,
)

def on_connect(cl, userdata, flags, rc):
Expand Down Expand Up @@ -823,6 +833,7 @@ def test_callback_v2_mqtt3(self, fake_broker):
CallbackAPIVersion.VERSION2,
"client-id",
userdata=callback_called,
transport=fake_broker.transport,
)

def on_connect(cl, userdata, flags, reason, properties):
Expand Down
36 changes: 26 additions & 10 deletions tests/testsupport/broker.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import contextlib
import os
import socket
import socketserver
import threading
Expand All @@ -9,18 +10,27 @@


class FakeBroker:
def __init__(self):
# Bind to "localhost" for maximum performance, as described in:
# http://docs.python.org/howto/sockets.html#ipc
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
def __init__(self, transport):
if transport == "tcp":
# Bind to "localhost" for maximum performance, as described in:
# http://docs.python.org/howto/sockets.html#ipc
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
sock.bind(("localhost", 0))
self.port = sock.getsockname()[1]
elif transport == "unix":
sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
sock.bind("localhost")
self.port = 1883
else:
raise ValueError(f"unsupported transport {transport}")

sock.settimeout(5)
sock.bind(("localhost", 0))
self.port = sock.getsockname()[1]
sock.listen(1)

self._sock = sock
self._conn = None
self.transport = transport

def start(self):
if self._sock is None:
Expand All @@ -39,6 +49,12 @@ def finish(self):
self._sock.close()
self._sock = None

if self.transport == 'unix':
try:
os.unlink('localhost')
except OSError:
pass

def receive_packet(self, num_bytes):
if self._conn is None:
raise ValueError('Connection is not open')
Expand All @@ -60,10 +76,10 @@ def expect_packet(self, name, packet):
paho_test.expect_packet(self._conn, name, packet)


@pytest.fixture
def fake_broker():
@pytest.fixture(params=["tcp"] + (["unix"] if hasattr(socket, 'AF_UNIX') else []))
def fake_broker(request):
# print('Setup broker')
broker = FakeBroker()
broker = FakeBroker(request.param)

yield broker

Expand Down
Loading