diff --git a/jupyter_client/threaded.py b/jupyter_client/threaded.py index dcca6c94e..a62f61fe5 100644 --- a/jupyter_client/threaded.py +++ b/jupyter_client/threaded.py @@ -8,9 +8,11 @@ from threading import Event from threading import Thread from typing import Any +from typing import Awaitable from typing import Dict from typing import List from typing import Optional +from typing import Union import zmq from traitlets import Instance @@ -28,6 +30,10 @@ # during garbage collection of threads at exit +async def get_msg(msg: Awaitable) -> Union[List[bytes], List[zmq.Message]]: + return await msg + + class ThreadedZMQSocketChannel(object): """A ZMQ socket invoking a callback in the ioloop""" @@ -108,11 +114,13 @@ def thread_send(): assert self.ioloop is not None self.ioloop.add_callback(thread_send) - def _handle_recv(self, msg_list: List[bytes]) -> None: + def _handle_recv(self, future_msg: Awaitable) -> None: """Callback for stream.on_recv. Unpacks message, and calls handlers with it. """ + assert self.ioloop is not None + msg_list = self.ioloop._asyncio_event_loop.run_until_complete(get_msg(future_msg)) assert self.session is not None ident, smsg = self.session.feed_identities(msg_list) msg = self.session.deserialize(smsg)