diff --git a/zmq/eventloop/zmqstream.py b/zmq/eventloop/zmqstream.py index 4108e056a..20c1866e1 100644 --- a/zmq/eventloop/zmqstream.py +++ b/zmq/eventloop/zmqstream.py @@ -614,6 +614,19 @@ def _handle_events(self, fd, events): zmq_events = self.socket.EVENTS except zmq.ContextTerminated: gen_log.warning("Got events for stream %s after terminating context", self) + # trigger close check, this will unregister callbacks + self.closed() + return + except zmq.ZMQError as e: + # run close check + # shadow sockets may have been closed elsewhere, + # which should show up as ENOTSOCK here + if self.closed(): + gen_log.warning( + "Got events for stream %s attached to closed socket: %s", self, e + ) + else: + gen_log.error("Error getting events for %s: %s", self, e) return try: # dispatch events: diff --git a/zmq/tests/test_zmqstream.py b/zmq/tests/test_zmqstream.py index e45193fb0..39d7d4962 100644 --- a/zmq/tests/test_zmqstream.py +++ b/zmq/tests/test_zmqstream.py @@ -4,6 +4,7 @@ import asyncio import logging +import warnings import pytest @@ -130,3 +131,29 @@ async def test_shadow_socket(context): assert type(stream.socket) is zmq.Socket assert stream.socket.underlying == socket.underlying stream.close() + + +async def test_shadow_socket_close(context, caplog): + with context.socket(zmq.PUSH) as push, context.socket(zmq.PULL) as pull: + push.linger = pull.linger = 0 + port = push.bind_to_random_port('tcp://127.0.0.1') + pull.connect(f'tcp://127.0.0.1:{port}') + shadow_pull = zmq.Socket.shadow(pull) + stream = zmqstream.ZMQStream(shadow_pull) + # send some messages + for i in range(10): + push.send_string(str(i)) + # make sure at least one message has been delivered + pull.recv() + # register callback + # this should schedule event callback on the next tick + stream.on_recv(print) + # close the shadowed socket + pull.close() + # run the event loop, which should see some events on the shadow socket + # but the socket has been closed! + with warnings.catch_warnings(record=True) as records: + await asyncio.sleep(0.2) + warning_text = "\n".join(str(r.message) for r in records) + assert "after closing socket" in warning_text + assert "closed socket" in caplog.text