diff --git a/README.rst b/README.rst index 13bf5500..ad59663f 100644 --- a/README.rst +++ b/README.rst @@ -954,6 +954,46 @@ and ``MQTT_LOG_DEBUG``. The message itself is in ``buf``. This may be used at the same time as the standard Python logging, which can be enabled via the ``enable_logger`` method. +on_socket_open() +'''''''''''''''' + +:: + + on_socket_open(client, userdata, sock) + +Called when the socket has been opened. +Use this to register the socket with an external event loop for reading. + +on_socket_close() +''''''''''''''''' + +:: + + on_socket_close(client, userdata, sock) + +Called when the socket is about to be closed. +Use this to unregister a socket from an external event loop for reading. + +on_socket_register_write() +'''''''''''''''''''''''''' + +:: + + on_socket_register_write(client, userdata, sock) + +Called when a write operation to the socket failed because it would have blocked, e.g. output buffer full. +Use this to register the socket with an external event loop for writing. + +on_socket_unregister_write() +'''''''''''''''''''''''''''' + +:: + + on_socket_unregister_write(client, userdata, sock) + +Called when a write operation to the socket succeeded after it had previously failed. +Use this to unregister the socket from an external event loop for writing. + External event loop support ``````````````````````````` @@ -995,6 +1035,9 @@ socket() Returns the socket object in use in the client to allow interfacing with other event loops. +This call is particularly useful for select_ based loops. See ``examples/loop_select.py``. + +.. _select: https://docs.python.org/3/library/select.html#select.select want_write() '''''''''''' @@ -1005,6 +1048,46 @@ want_write() Returns true if there is data waiting to be written, to allow interfacing the client with other event loops. +This call is particularly useful for select_ based loops. See ``examples/loop_select.py``. + +.. _select: https://docs.python.org/3/library/select.html#select.select + +state callbacks +''''''''''''''' + +:: + + on_socket_open + on_socket_close + on_socket_register_write + on_socket_unregister_write + +Use these callbacks to get notified about state changes in the socket. +This is particularly useful for event loops where you register or unregister a socket +for reading+writing. See ``examples/loop_asyncio.py`` for an example. + +When the socket is opened, ``on_socket_open`` is called. +Register the socket with your event loop for reading. + +When the socket is about to be closed, ``on_socket_close`` is called. +Unregister the socket from your event loop for reading. + +When a write to the socket failed because it would have blocked, e.g. output buffer full, +``on_socket_register_write`` is called. +Register the socket with your event loop for writing. + +When the next write to the socket succeeded, ``on_socket_unregister_write`` is called. +Unregister the socket from your event loop for writing. + +The callbacks are always called in this order: + +- ``on_socket_open`` +- Zero or more times: + + - ``on_socket_register_write`` + - ``on_socket_unregister_write`` + +- ``on_socket_close`` Global helper functions ``````````````````````` diff --git a/examples/loop_asyncio.py b/examples/loop_asyncio.py new file mode 100755 index 00000000..07b1ea5a --- /dev/null +++ b/examples/loop_asyncio.py @@ -0,0 +1,108 @@ +#!/usr/bin/env python3 + +import socket +import uuid +import paho.mqtt.client as mqtt +import asyncio + +client_id = 'paho-mqtt-python/issue72/' + str(uuid.uuid4()) +topic = client_id +print("Using client_id / topic: " + client_id) + + +class AsyncioHelper: + def __init__(self, loop, client): + self.loop = loop + self.client = client + self.client.on_socket_open = self.on_socket_open + self.client.on_socket_close = self.on_socket_close + self.client.on_socket_register_write = self.on_socket_register_write + self.client.on_socket_unregister_write = self.on_socket_unregister_write + + def on_socket_open(self, client, userdata, sock): + print("Socket opened") + + def cb(): + print("Socket is readable, calling loop_read") + client.loop_read() + + self.loop.add_reader(sock, cb) + self.misc = self.loop.create_task(self.misc_loop()) + + def on_socket_close(self, client, userdata, sock): + print("Socket closed") + self.loop.remove_reader(sock) + self.misc.cancel() + + def on_socket_register_write(self, client, userdata, sock): + print("Watching socket for writability.") + + def cb(): + print("Socket is writable, calling loop_write") + client.loop_write() + + self.loop.add_writer(sock, cb) + + def on_socket_unregister_write(self, client, userdata, sock): + print("Stop watching socket for writability.") + self.loop.remove_writer(sock) + + async def misc_loop(self): + print("misc_loop started") + while self.client.loop_misc() == mqtt.MQTT_ERR_SUCCESS: + try: + await asyncio.sleep(1) + except asyncio.CancelledError: + break + print("misc_loop finished") + + +class AsyncMqttExample: + def __init__(self, loop): + self.loop = loop + + def on_connect(self, client, userdata, flags, rc): + print("Subscribing") + client.subscribe(topic) + + def on_message(self, client, userdata, msg): + if not self.got_message: + print("Got unexpected message: {}".format(msg.decode())) + else: + self.got_message.set_result(msg.payload) + + def on_disconnect(self, client, userdata, rc): + self.disconnected.set_result(rc) + + async def main(self): + self.disconnected = self.loop.create_future() + self.got_message = None + + self.client = mqtt.Client(client_id=client_id) + self.client.on_connect = self.on_connect + self.client.on_message = self.on_message + self.client.on_disconnect = self.on_disconnect + + aioh = AsyncioHelper(self.loop, self.client) + + self.client.connect('iot.eclipse.org', 1883, 60) + self.client.socket().setsockopt(socket.SOL_SOCKET, socket.SO_SNDBUF, 2048) + + for c in range(3): + await asyncio.sleep(5) + print("Publishing") + self.got_message = self.loop.create_future() + self.client.publish(topic, b'Hello' * 40000, qos=1) + msg = await self.got_message + print("Got response with {} bytes".format(len(msg))) + self.got_message = None + + self.client.disconnect() + print("Disconnected: {}".format(await self.disconnected)) + + +print("Starting") +loop = asyncio.get_event_loop() +loop.run_until_complete(AsyncMqttExample(loop).main()) +loop.close() +print("Finished") diff --git a/examples/loop_select.py b/examples/loop_select.py new file mode 100755 index 00000000..626eef9a --- /dev/null +++ b/examples/loop_select.py @@ -0,0 +1,89 @@ +#!/usr/bin/env python3 + +import socket +import uuid +import paho.mqtt.client as mqtt +from select import select +from time import time + +client_id = 'paho-mqtt-python/issue72/' + str(uuid.uuid4()) +topic = client_id +print("Using client_id / topic: " + client_id) + + +class SelectMqttExample: + def __init__(self): + pass + + def on_connect(self, client, userdata, flags, rc): + print("Subscribing") + client.subscribe(topic) + + def on_message(self, client, userdata, msg): + if self.state not in {1, 3, 5}: + print("Got unexpected message: {}".format(msg.decode())) + return + + print("Got message with len {}".format(len(msg.payload))) + self.state += 1 + self.t = time() + + def on_disconnect(self, client, userdata, rc): + self.disconnected = True, rc + + def do_select(self): + sock = self.client.socket() + if not sock: + raise Exception("Socket is gone") + + print("Selecting for reading" + (" and writing" if self.client.want_write() else "")) + r, w, e = select( + [sock], + [sock] if self.client.want_write() else [], + [], + 1 + ) + + if sock in r: + print("Socket is readable, calling loop_read") + self.client.loop_read() + + if sock in w: + print("Socket is writable, calling loop_write") + self.client.loop_write() + + self.client.loop_misc() + + def main(self): + self.disconnected = (False, None) + self.t = time() + self.state = 0 + + self.client = mqtt.Client(client_id=client_id) + self.client.on_connect = self.on_connect + self.client.on_message = self.on_message + self.client.on_disconnect = self.on_disconnect + + self.client.connect('iot.eclipse.org', 1883, 60) + print("Socket opened") + self.client.socket().setsockopt(socket.SOL_SOCKET, socket.SO_SNDBUF, 2048) + + while not self.disconnected[0]: + self.do_select() + + if self.state in {0, 2, 4}: + if time() - self.t >= 5: + print("Publishing") + self.client.publish(topic, b'Hello' * 40000) + self.state += 1 + + if self.state == 6: + self.state += 1 + self.client.disconnect() + + print("Disconnected: {}".format(self.disconnected[1])) + + +print("Starting") +SelectMqttExample().main() +print("Finished") diff --git a/src/paho/mqtt/client.py b/src/paho/mqtt/client.py index 90fe597c..c6214a74 100644 --- a/src/paho/mqtt/client.py +++ b/src/paho/mqtt/client.py @@ -148,6 +148,10 @@ class WebsocketConnectionError(ValueError): pass +class WouldBlockError(Exception): + pass + + def error_string(mqtt_errno): """Return the error string associated with an mqtt error number.""" if mqtt_errno == MQTT_ERR_SUCCESS: @@ -448,6 +452,19 @@ def on_connect(client, userdata, flags, rc): and will be one of MQTT_LOG_INFO, MQTT_LOG_NOTICE, MQTT_LOG_WARNING, MQTT_LOG_ERR, and MQTT_LOG_DEBUG. The message itself is in buf. + on_socket_open(client, userdata, sock): Called when the socket has been opened. Use this + to register the socket with an external event loop for reading. + + on_socket_close(client, userdata, sock): Called when the socket is about to be closed. + Use this to unregister a socket from an external event loop for reading. + + on_socket_register_write(client, userdata, sock): Called when a write operation to the + socket failed because it would have blocked, e.g. output buffer full. Use this to + register the socket with an external event loop for writing. + + on_socket_unregister_write(client, userdata, sock): Called when a write operation to the + socket succeeded after it had previously failed. Use this to unregister the socket + from an external event loop for writing. """ def __init__(self, client_id="", clean_session=True, userdata=None, @@ -559,6 +576,7 @@ def __init__(self, client_id="", clean_session=True, userdata=None, self._ssl_context = None self._tls_insecure = False # Only used when SSL context does not have check_hostname attribute self._logger = None + self._registered_write = False # No default callbacks self._on_log = None self._on_connect = None @@ -567,16 +585,60 @@ def __init__(self, client_id="", clean_session=True, userdata=None, self._on_publish = None self._on_unsubscribe = None self._on_disconnect = None + self._on_socket_open = None + self._on_socket_close = None + self._on_socket_register_write = None + self._on_socket_unregister_write = None self._websocket_path = "/mqtt" self._websocket_extra_headers = None def __del__(self): pass - def reinitialise(self, client_id="", clean_session=True, userdata=None): - if self._sock: - self._sock.close() + def _sock_recv(self, bufsize): + try: + return self._sock.recv(bufsize) + except socket.error as err: + if self._ssl and err.errno == ssl.SSL_ERROR_WANT_READ: + raise WouldBlockError() + if self._ssl and err.errno == ssl.SSL_ERROR_WANT_WRITE: + self._call_socket_register_write() + raise WouldBlockError() + if err.errno == EAGAIN: + raise WouldBlockError() + raise + + def _sock_send(self, buf): + try: + return self._sock.send(buf) + except socket.error as err: + if self._ssl and err.errno == ssl.SSL_ERROR_WANT_READ: + raise WouldBlockError() + if self._ssl and err.errno == ssl.SSL_ERROR_WANT_WRITE: + self._call_socket_register_write() + raise WouldBlockError() + if err.errno == EAGAIN: + self._call_socket_register_write() + raise WouldBlockError() + raise + + def _sock_close(self): + """Close the connection to the server.""" + if not self._sock: + return + + try: + sock = self._sock self._sock = None + self._call_socket_unregister_write(sock) + self._call_socket_close(sock) + finally: + # In case a callback fails, still close the socket to avoid leaking the file descriptor. + sock.close() + + def reinitialise(self, client_id="", clean_session=True, userdata=None): + self._sock_close() + if self._sockpairR: self._sockpairR.close() self._sockpairR = None @@ -870,9 +932,7 @@ def reconnect(self): self._ping_t = 0 self._state = mqtt_cs_new - if self._sock: - self._sock.close() - self._sock = None + self._sock_close() # Put messages in progress in a valid state. self._messages_reconnect_reset() @@ -925,6 +985,8 @@ def reconnect(self): self._sock = sock self._sock.setblocking(0) + self._registered_write = False + self._call_socket_open() return self._send_connect(self._keepalive, self._clean_session) @@ -1291,13 +1353,19 @@ def loop_write(self, max_packets=1): if max_packets < 1: max_packets = 1 - for _ in range(0, max_packets): - rc = self._packet_write() - if rc > 0: - return self._loop_rc_handle(rc) - elif rc == MQTT_ERR_AGAIN: - return MQTT_ERR_SUCCESS - return MQTT_ERR_SUCCESS + try: + for _ in range(0, max_packets): + rc = self._packet_write() + if rc > 0: + return self._loop_rc_handle(rc) + elif rc == MQTT_ERR_AGAIN: + return MQTT_ERR_SUCCESS + return MQTT_ERR_SUCCESS + finally: + if self.want_write(): + self._call_socket_register_write() + else: + self._call_socket_unregister_write() def want_write(self): """Call to determine if there is network data waiting to be written. @@ -1326,9 +1394,7 @@ def loop_misc(self): if self._ping_t > 0 and now - self._ping_t >= self._keepalive: # client->ping_t != 0 means we are waiting for a pingresp. # This hasn't happened in the keepalive time so we should disconnect. - if self._sock: - self._sock.close() - self._sock = None + self._sock_close() if self._state == mqtt_cs_disconnecting: rc = MQTT_ERR_SUCCESS @@ -1690,7 +1756,7 @@ def on_disconnect(self, func): """ Define the disconnect callback implementation. Expected signature is: - disconnect_callback(client, userdata, self) + disconnect_callback(client, userdata, rc) client: the client instance for this callback userdata: the private user data as set in Client() or userdata_set() @@ -1703,6 +1769,124 @@ def on_disconnect(self, func): with self._callback_mutex: self._on_disconnect = func + @property + def on_socket_open(self): + """If implemented, called just after the socket was opend.""" + return self._on_socket_open + + @on_socket_open.setter + def on_socket_open(self, func): + """Define the socket_open callback implementation. + + This should be used to register the socket to an external event loop for reading. + + Expected signature is: + socket_open_callback(client, userdata, socket) + + client: the client instance for this callback + userdata: the private user data as set in Client() or userdata_set() + sock: the socket which was just opened. + """ + with self._callback_mutex: + self._on_socket_open = func + + def _call_socket_open(self): + """Call the socket_open callback with the just-opened socket""" + with self._callback_mutex: + if self.on_socket_open: + with self._in_callback: + self.on_socket_open(self, self._userdata, self._sock) + + @property + def on_socket_close(self): + """If implemented, called just before the socket is closed.""" + return self._on_socket_close + + @on_socket_close.setter + def on_socket_close(self, func): + """Define the socket_close callback implementation. + + This should be used to unregister the socket from an external event loop for reading. + + Expected signature is: + socket_close_callback(client, userdata, socket) + + client: the client instance for this callback + userdata: the private user data as set in Client() or userdata_set() + sock: the socket which is about to be closed. + """ + with self._callback_mutex: + self._on_socket_close = func + + def _call_socket_close(self, sock): + """Call the socket_close callback with the about-to-be-closed socket""" + with self._callback_mutex: + if self.on_socket_close: + with self._in_callback: + self.on_socket_close(self, self._userdata, sock) + + @property + def on_socket_register_write(self): + """If implemented, called when the socket needs writing but can't.""" + return self._on_socket_register_write + + @on_socket_register_write.setter + def on_socket_register_write(self, func): + """Define the socket_register_write callback implementation. + + This should be used to register the socket with an external event loop for writing. + + Expected signature is: + socket_register_write_callback(client, userdata, socket) + + client: the client instance for this callback + userdata: the private user data as set in Client() or userdata_set() + sock: the socket which should be registered for writing + """ + with self._callback_mutex: + self._on_socket_register_write = func + + def _call_socket_register_write(self): + """Call the socket_register_write callback with the unwritable socket""" + if not self._sock or self._registered_write: + return + self._registered_write = True + with self._callback_mutex: + if self.on_socket_register_write: + self.on_socket_register_write(self, self._userdata, self._sock) + + @property + def on_socket_unregister_write(self): + """If implemented, called when the socket doesn't need writing anymore.""" + return self._on_socket_unregister_write + + @on_socket_unregister_write.setter + def on_socket_unregister_write(self, func): + """Define the socket_unregister_write callback implementation. + + This should be used to unregister the socket from an external event loop for writing. + + Expected signature is: + socket_unregister_write_callback(client, userdata, socket) + + client: the client instance for this callback + userdata: the private user data as set in Client() or userdata_set() + sock: the socket which should be unregistered for writing + """ + with self._callback_mutex: + self._on_socket_unregister_write = func + + def _call_socket_unregister_write(self, sock=None): + """Call the socket_unregister_write callback with the writable socket""" + sock = sock or self._sock + if not sock or not self._registered_write: + return + self._registered_write = False + + with self._callback_mutex: + if self.on_socket_unregister_write: + self.on_socket_unregister_write(self, self._userdata, sock) + def message_callback_add(self, sub, callback): """Register a message callback for a specific topic. Messages that match 'sub' will be passed to 'callback'. Any @@ -1738,9 +1922,7 @@ def message_callback_remove(self, sub): def _loop_rc_handle(self, rc): if rc: - if self._sock: - self._sock.close() - self._sock = None + self._sock_close() if self._state == mqtt_cs_disconnecting: rc = MQTT_ERR_SUCCESS @@ -1767,12 +1949,10 @@ def _packet_read(self): # Finally, free the memory and reset everything to starting conditions. if self._in_packet['command'] == 0: try: - command = self._sock.recv(1) + command = self._sock_recv(1) + except WouldBlockError: + return MQTT_ERR_AGAIN except socket.error as err: - if self._ssl and (err.errno == ssl.SSL_ERROR_WANT_READ or err.errno == ssl.SSL_ERROR_WANT_WRITE): - return MQTT_ERR_AGAIN - if err.errno == EAGAIN: - return MQTT_ERR_AGAIN self._easy_log(MQTT_LOG_ERR, 'failed to receive on socket: %s', err) return 1 else: @@ -1787,12 +1967,10 @@ def _packet_read(self): # http://publib.boulder.ibm.com/infocenter/wmbhelp/v6r0m0/topic/com.ibm.etools.mft.doc/ac10870_.htm while True: try: - byte = self._sock.recv(1) + byte = self._sock_recv(1) + except WouldBlockError: + return MQTT_ERR_AGAIN except socket.error as err: - if self._ssl and (err.errno == ssl.SSL_ERROR_WANT_READ or err.errno == ssl.SSL_ERROR_WANT_WRITE): - return MQTT_ERR_AGAIN - if err.errno == EAGAIN: - return MQTT_ERR_AGAIN self._easy_log(MQTT_LOG_ERR, 'failed to receive on socket: %s', err) return 1 else: @@ -1816,12 +1994,10 @@ def _packet_read(self): while self._in_packet['to_process'] > 0: try: - data = self._sock.recv(self._in_packet['to_process']) + data = self._sock_recv(self._in_packet['to_process']) + except WouldBlockError: + return MQTT_ERR_AGAIN except socket.error as err: - if self._ssl and (err.errno == ssl.SSL_ERROR_WANT_READ or err.errno == ssl.SSL_ERROR_WANT_WRITE): - return MQTT_ERR_AGAIN - if err.errno == EAGAIN: - return MQTT_ERR_AGAIN self._easy_log(MQTT_LOG_ERR, 'failed to receive on socket: %s', err) return 1 else: @@ -1856,16 +2032,15 @@ def _packet_write(self): packet = self._current_out_packet try: - write_length = self._sock.send(packet['packet'][packet['pos']:]) + write_length = self._sock_send(packet['packet'][packet['pos']:]) except (AttributeError, ValueError): self._current_out_packet_mutex.release() return MQTT_ERR_SUCCESS + except WouldBlockError: + self._current_out_packet_mutex.release() + return MQTT_ERR_AGAIN except socket.error as err: self._current_out_packet_mutex.release() - if self._ssl and (err.errno == ssl.SSL_ERROR_WANT_READ or err.errno == ssl.SSL_ERROR_WANT_WRITE): - return MQTT_ERR_AGAIN - if err.errno == EAGAIN: - return MQTT_ERR_AGAIN self._easy_log(MQTT_LOG_ERR, 'failed to receive on socket: %s', err) return 1 @@ -1893,9 +2068,7 @@ def _packet_write(self): with self._in_callback: self.on_disconnect(self, self._userdata, 0) - if self._sock: - self._sock.close() - self._sock = None + self._sock_close() return MQTT_ERR_SUCCESS with self._out_packet_mutex: @@ -1938,9 +2111,7 @@ def _check_keepalive(self): self._last_msg_out = now self._last_msg_in = now else: - if self._sock: - self._sock.close() - self._sock = None + self._sock_close() if self._state == mqtt_cs_disconnecting: rc = MQTT_ERR_SUCCESS @@ -2266,6 +2437,8 @@ def _packet_queue(self, command, packet, mid, qos, info=None): self._in_callback.release() return self.loop_write() + self._call_socket_register_write() + return MQTT_ERR_SUCCESS def _packet_handle(self):