From 79dd9029b3b104b7c116f573ffa69c254abed9fb Mon Sep 17 00:00:00 2001 From: Jialing He Date: Fri, 8 Apr 2022 19:06:20 +0800 Subject: [PATCH] [Ray] Ray client channel get recv when first complied (#2740) --- mars/oscar/backends/ray/communication.py | 55 ++++++++++++++---------- 1 file changed, 33 insertions(+), 22 deletions(-) diff --git a/mars/oscar/backends/ray/communication.py b/mars/oscar/backends/ray/communication.py index 63550fac68..7e1c4207fd 100644 --- a/mars/oscar/backends/ray/communication.py +++ b/mars/oscar/backends/ray/communication.py @@ -120,7 +120,7 @@ class RayChannelBase(Channel, ABC): Channel for communications between ray processes. """ - __slots__ = "_channel_index", "_channel_id", "_in_queue", "_closed" + __slots__ = "_channel_index", "_channel_id", "_closed" name = "ray" _channel_index_gen = itertools.count() @@ -142,7 +142,6 @@ def __init__( self._channel_id = channel_id or ChannelID( local_address, _gen_client_id(), self._channel_index, dest_address ) - self._in_queue = asyncio.Queue() self._closed = asyncio.Event() @property @@ -169,7 +168,7 @@ class RayClientChannel(RayChannelBase): A channel from ray driver/actor to ray actor. Use ray call reply for client channel recv. """ - __slots__ = ("_peer_actor",) + __slots__ = "_peer_actor", "_done", "_todo" def __init__( self, @@ -181,36 +180,47 @@ def __init__( super().__init__(None, dest_address, channel_index, channel_id, compression) # ray actor should be created with the address as the name. self._peer_actor: "ray.actor.ActorHandle" = ray.get_actor(dest_address) + self._done = asyncio.Queue() + self._todo = set() + + def _submit_task(self, message: Any, object_ref: "ray.ObjectRef"): + async def handle_task(message: Any, object_ref: "ray.ObjectRef"): + # use `%.500` to avoid print too long messages + with debug_async_timeout( + "ray_object_retrieval_timeout", "Client sent message is %.500s", message + ): + result = await object_ref + if isinstance(result, RayChannelException): + raise result.exc_value.with_traceback(result.exc_traceback) + return result.message + + def _on_completion(future): + self._todo.remove(future) + self._done.put_nowait(future) + + future = asyncio.ensure_future(handle_task(message, object_ref)) + future.add_done_callback(_on_completion) + self._todo.add(future) @implements(Channel.send) async def send(self, message: Any): if self._closed.is_set(): # pragma: no cover raise ChannelClosed("Channel already closed, cannot send message") - # Put ray object ref to queue - self._in_queue.put_nowait( - ( - message, - self._peer_actor.__on_ray_recv__.remote( - self.channel_id, _ArgWrapper(message) - ), - ) + # Put ray object ref to todo queue + task = self._peer_actor.__on_ray_recv__.remote( + self.channel_id, _ArgWrapper(message) ) + self._submit_task(message, task) + await asyncio.sleep(0) @implements(Channel.recv) async def recv(self): if self._closed.is_set(): # pragma: no cover raise ChannelClosed("Channel already closed, cannot recv message") try: - # Wait on ray object ref - message, object_ref = await self._in_queue.get() - # use `%.500` to avoid print too long messages - with debug_async_timeout( - "ray_object_retrieval_timeout", "Client sent message is %.500s", message - ): - result = await object_ref - if isinstance(result, RayChannelException): - raise result.exc_value.with_traceback(result.exc_traceback) - return result.message + # Wait first done. + future = await self._done.get() + return future.result() except ray.exceptions.RayActorError: if not self._closed.is_set(): # raise a EOFError as the SocketChannel does @@ -228,7 +238,7 @@ class RayServerChannel(RayChannelBase): message's reply. """ - __slots__ = "_out_queue", "_msg_recv_counter", "_msg_sent_counter" + __slots__ = "_in_queue", "_out_queue", "_msg_recv_counter", "_msg_sent_counter" def __init__( self, @@ -238,6 +248,7 @@ def __init__( compression=None, ): super().__init__(local_address, None, channel_index, channel_id, compression) + self._in_queue = asyncio.Queue() self._out_queue = asyncio.Queue() self._msg_recv_counter = 0 self._msg_sent_counter = 0