diff --git a/python/xoscar/backends/communication/dummy.py b/python/xoscar/backends/communication/dummy.py index 7909b24e..4174bca4 100644 --- a/python/xoscar/backends/communication/dummy.py +++ b/python/xoscar/backends/communication/dummy.py @@ -17,8 +17,9 @@ import asyncio import concurrent.futures as futures +import logging import weakref -from typing import Any, Callable, Coroutine, Dict, Type +from typing import Any, Callable, Coroutine, Dict, Optional, Type from urllib.parse import urlparse from ...errors import ServerClosed @@ -29,13 +30,15 @@ DEFAULT_DUMMY_ADDRESS = "dummy://0" +logger = logging.getLogger(__name__) + class DummyChannel(Channel): """ Channel for communications in same process. """ - __slots__ = "_in_queue", "_out_queue", "_closed" + __slots__ = "__weakref__", "_in_queue", "_out_queue", "_closed" name = "dummy" @@ -100,8 +103,8 @@ class DummyServer(Server): _address_to_instances: weakref.WeakValueDictionary[str, "DummyServer"] = ( weakref.WeakValueDictionary() ) - _channels: list[ChannelType] - _tasks: list[asyncio.Task] + _channels: weakref.WeakSet[Channel] + _tasks: set[asyncio.Task] scheme: str | None = "dummy" def __init__( @@ -111,8 +114,8 @@ def __init__( ): super().__init__(address, channel_handler) self._closed = asyncio.Event() - self._channels = [] - self._tasks = [] + self._channels = weakref.WeakSet() + self._tasks = set() @classmethod def get_instance(cls, address: str): @@ -178,7 +181,7 @@ async def on_connected(self, *args, **kwargs): f"{type(self).__name__} got unexpected " f'arguments: {",".join(kwargs)}' ) - self._channels.append(channel) + self._channels.add(channel) await self.channel_handler(channel) @implements(Server.stop) @@ -203,6 +206,7 @@ def __init__( self, local_address: str | None, dest_address: str | None, channel: Channel ): super().__init__(local_address, dest_address, channel) + self._task: Optional[asyncio.Task] = None @staticmethod @implements(Client.connect) @@ -232,11 +236,25 @@ async def connect( task = asyncio.create_task(conn_coro) client = DummyClient(local_address, dest_address, client_channel) client._task = task - server._tasks.append(task) + server._tasks.add(task) + + def _discard(t): + server._tasks.discard(t) + logger.info("Channel exit: %s", server_channel.info) + + task.add_done_callback(_discard) return client @implements(Client.close) async def close(self): await super().close() - self._task.cancel() - self._task = None + if self._task is not None: + task_loop = self._task.get_loop() + if task_loop is not None: + if not task_loop.is_running(): + logger.warning( + "Dummy channel cancel task on a stopped loop, dest address: %s.", + self.dest_address, + ) + self._task.cancel() + self._task = None diff --git a/python/xoscar/backends/communication/socket.py b/python/xoscar/backends/communication/socket.py index 4f6718b5..4f38e50c 100644 --- a/python/xoscar/backends/communication/socket.py +++ b/python/xoscar/backends/communication/socket.py @@ -113,7 +113,7 @@ def closed(self): class _BaseSocketServer(Server, metaclass=ABCMeta): __slots__ = "_aio_server", "_channels" - _channels: list[ChannelType] + _channels: set[Channel] def __init__( self, @@ -124,7 +124,7 @@ def __init__( super().__init__(address, channel_handler) # asyncio.Server self._aio_server = aio_server - self._channels = [] + self._channels = set() @implements(Server.start) async def start(self): @@ -170,9 +170,16 @@ async def on_connected(self, *args, **kwargs): dest_address=dest_address, channel_type=self.channel_type, ) - self._channels.append(channel) + self._channels.add(channel) # handle over channel to some handlers - await self.channel_handler(channel) + try: + await self.channel_handler(channel) + finally: + if not channel.closed: + await channel.close() + # Remove channel if channel exit + self._channels.discard(channel) + logger.debug("Channel exit: %s", channel.info) @implements(Server.stop) async def stop(self): @@ -185,6 +192,7 @@ async def stop(self): await asyncio.gather( *(channel.close() for channel in self._channels if not channel.closed) ) + self._channels.clear() @property @implements(Server.stopped) diff --git a/python/xoscar/backends/communication/ucx.py b/python/xoscar/backends/communication/ucx.py index 26e40b94..26d69a85 100644 --- a/python/xoscar/backends/communication/ucx.py +++ b/python/xoscar/backends/communication/ucx.py @@ -368,7 +368,7 @@ class UCXServer(Server): scheme = "ucx" _ucp_listener: "ucp.Listener" # type: ignore - _channels: List[UCXChannel] + _channels: set[UCXChannel] def __init__( self, @@ -381,7 +381,7 @@ def __init__( self.host = host self.port = port self._ucp_listener = ucp_listener - self._channels = [] + self._channels = set() self._closed = asyncio.Event() @classproperty @@ -469,9 +469,16 @@ async def on_connected(self, *args, **kwargs): channel = UCXChannel( ucp_endpoint, local_address=local_address, dest_address=dest_address ) - self._channels.append(channel) + self._channels.add(channel) # handle over channel to some handlers - await self.channel_handler(channel) + try: + await self.channel_handler(channel) + finally: + if not channel.closed: + await channel.close() + # Remove channel if channel exit + self._channels.discard(channel) + logger.debug("Channel exit: %s", channel.info) @implements(Server.stop) async def stop(self): @@ -480,7 +487,7 @@ async def stop(self): await asyncio.gather( *(channel.close() for channel in self._channels if not channel.closed) ) - self._channels = [] + self._channels.clear() self._ucp_listener = None self._closed.set() diff --git a/python/xoscar/backends/core.py b/python/xoscar/backends/core.py index c8fb54ff..4a0dade0 100644 --- a/python/xoscar/backends/core.py +++ b/python/xoscar/backends/core.py @@ -18,6 +18,8 @@ import asyncio import copy import logging +import threading +import weakref from typing import Type, Union from .._utils import Timer @@ -31,8 +33,8 @@ logger = logging.getLogger(__name__) -class ActorCaller: - __slots__ = "_client_to_message_futures", "_clients", "_profiling_data" +class ActorCallerThreadLocal: + __slots__ = ("_client_to_message_futures", "_clients", "_profiling_data") _client_to_message_futures: dict[Client, dict[bytes, asyncio.Future]] _clients: dict[Client, asyncio.Task] @@ -193,6 +195,7 @@ async def call( return await self.call_with_client(client, message, wait) async def stop(self): + logger.debug("Actor caller stop.") try: await asyncio.gather(*[client.close() for client in self._clients]) except (ConnectionError, ServerClosed): @@ -202,3 +205,37 @@ async def stop(self): def cancel_tasks(self): # cancel listening for all clients _ = [task.cancel() for task in self._clients.values()] + + +class ActorCaller: + __slots__ = "_thread_local" + + class _RefHolder: + pass + + _close_loop = asyncio.new_event_loop() + _close_thread = threading.Thread(target=_close_loop.run_forever, daemon=True) + _close_thread.start() + + def __init__(self): + self._thread_local = threading.local() + + def __getattr__(self, item): + try: + actor_caller = self._thread_local.actor_caller + except AttributeError: + thread_info = str(threading.current_thread()) + logger.debug("Creating a new actor caller for thread: %s", thread_info) + actor_caller = self._thread_local.actor_caller = ActorCallerThreadLocal() + ref = self._thread_local.ref = ActorCaller._RefHolder() + # If the thread exit, we clean the related actor callers and channels. + + def _cleanup(): + asyncio.run_coroutine_threadsafe(actor_caller.stop(), self._close_loop) + logger.debug( + "Clean up the actor caller due to thread exit: %s", thread_info + ) + + weakref.finalize(ref, _cleanup) + + return getattr(actor_caller, item) diff --git a/python/xoscar/backends/test/tests/test_actor_context.py b/python/xoscar/backends/test/tests/test_actor_context.py index 6eef6c8d..d684778d 100644 --- a/python/xoscar/backends/test/tests/test_actor_context.py +++ b/python/xoscar/backends/test/tests/test_actor_context.py @@ -12,14 +12,19 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - +import asyncio +import gc import os import sys +import threading import pytest import xoscar as mo +from ...communication.dummy import DummyServer +from ...router import Router + class DummyActor(mo.Actor): def __init__(self, value): @@ -60,3 +65,90 @@ async def test_simple(actor_pool_context): allocate_strategy=mo.allocate_strategy.RandomSubPool(), ) assert await actor_ref.add(1) == 101 + + +def _cancel_all_tasks(loop): + to_cancel = asyncio.all_tasks(loop) + if not to_cancel: + return + + for task in to_cancel: + task.cancel() + + loop.run_until_complete(asyncio.gather(*to_cancel, return_exceptions=True)) + + for task in to_cancel: + if task.cancelled(): + continue + if task.exception() is not None: + loop.call_exception_handler( + { + "message": "unhandled exception during asyncio.run() shutdown", + "exception": task.exception(), + "task": task, + } + ) + + +def _run_forever(loop): + loop.run_forever() + _cancel_all_tasks(loop) + + +@pytest.mark.asyncio +async def test_channel_cleanup(actor_pool_context): + pool = actor_pool_context + actor_ref = await mo.create_actor( + DummyActor, + 0, + address=pool.external_address, + allocate_strategy=mo.allocate_strategy.RandomSubPool(), + ) + + curr_router = Router.get_instance() + server_address = curr_router.get_internal_address(actor_ref.address) + dummy_server = DummyServer.get_instance(server_address) + + async def inc(): + await asyncio.gather(*(actor_ref.add.tell(1) for _ in range(10))) + + loops = [] + threads = [] + futures = [] + for _ in range(10): + loop = asyncio.new_event_loop() + t = threading.Thread(target=_run_forever, args=(loop,)) + t.start() + loops.append(loop) + threads.append(t) + fut = asyncio.run_coroutine_threadsafe(inc(), loop=loop) + futures.append(fut) + + for fut in futures: + fut.result() + + while True: + if await actor_ref.add(0) == 100: + break + + assert len(dummy_server._channels) == 12 + assert len(dummy_server._tasks) == 12 + + for loop in loops: + loop.call_soon_threadsafe(loop.stop) + + for t in threads: + t.join() + threads.clear() + + curr_router = Router.get_instance() + server_address = curr_router.get_internal_address(actor_ref.address) + dummy_server = DummyServer.get_instance(server_address) + + while True: + gc.collect() + # Two channels left: + # 1. from the main pool to the actor + # 2. from current main thread to the actor. + if len(dummy_server._channels) == 2 and len(dummy_server._tasks) == 2: + break