From c39e9c9e5270049ef7b28b49f7de1857fdd696ed Mon Sep 17 00:00:00 2001 From: "panxuchen.pxc" Date: Thu, 12 Sep 2024 18:40:58 +0800 Subject: [PATCH] fix linux test --- src/agentscope/rpc/__init__.py | 6 +--- src/agentscope/rpc/rpc_client.py | 16 ++-------- src/agentscope/rpc/rpc_object.py | 51 ++++++++++++++++++++++---------- tests/async_result_pool_test.py | 21 +++++++------ 4 files changed, 49 insertions(+), 45 deletions(-) diff --git a/src/agentscope/rpc/__init__.py b/src/agentscope/rpc/__init__.py index 07881c6a3..cf2f350cf 100644 --- a/src/agentscope/rpc/__init__.py +++ b/src/agentscope/rpc/__init__.py @@ -1,9 +1,6 @@ # -*- coding: utf-8 -*- """Import all rpc related modules in the package.""" -from .rpc_client import ( - RpcClient, - call_func_in_thread, -) +from .rpc_client import RpcClient from .rpc_meta import async_func, sync_func, RpcMeta from .rpc_config import DistConf from .rpc_async import AsyncResult @@ -18,5 +15,4 @@ "sync_func", "AsyncResult", "DistConf", - "call_func_in_thread", ] diff --git a/src/agentscope/rpc/rpc_client.py b/src/agentscope/rpc/rpc_client.py index fe1311207..1df1edbfc 100644 --- a/src/agentscope/rpc/rpc_client.py +++ b/src/agentscope/rpc/rpc_client.py @@ -3,8 +3,8 @@ import json import os -from typing import Optional, Sequence, Union, Generator, Callable, Any -from concurrent.futures import Future, ThreadPoolExecutor +from typing import Optional, Sequence, Union, Generator, Any +from concurrent.futures import ThreadPoolExecutor from loguru import logger from ..message import Msg @@ -327,18 +327,6 @@ def __reduce__(self) -> tuple: ) -def call_func_in_thread(func: Callable) -> Future: - """Call a function in a sub-thread. - - Args: - func (`Callable`): The function to be called in sub-thread. - - Returns: - `Future`: A stub to get the response. - """ - return RpcClient._EXECUTOR.submit(func) # pylint: disable=W0212 - - class RpcAgentClient(RpcClient): """`RpcAgentClient` has renamed to `RpcClient`. This class is kept for backward compatibility, please use `RpcClient` diff --git a/src/agentscope/rpc/rpc_object.py b/src/agentscope/rpc/rpc_object.py index a3db16726..c62e9d86e 100644 --- a/src/agentscope/rpc/rpc_object.py +++ b/src/agentscope/rpc/rpc_object.py @@ -1,10 +1,11 @@ # -*- coding: utf-8 -*- """A proxy object which represent a object located in a rpc server.""" from typing import Any, Callable -from functools import partial from abc import ABC from inspect import getmembers, isfunction from types import FunctionType +from concurrent.futures import ThreadPoolExecutor +import threading try: import cloudpickle as pickle @@ -13,7 +14,7 @@ pickle = ImportErrorReporter(e, "distribute") -from .rpc_client import RpcClient, call_func_in_thread +from .rpc_client import RpcClient from .rpc_async import AsyncResult from ..exception import AgentCreationError, AgentServerNotAliveError @@ -27,6 +28,28 @@ def get_public_methods(cls: type) -> list[str]: ] +class _RpcThreadPool: + """Executor for rpc object tasks.""" + + _executor = None + _lock = threading.Lock() + + def __init__(self, max_workers: int = 32) -> None: + if _RpcThreadPool._executor is None: + with _RpcThreadPool._lock: + if _RpcThreadPool._executor is None: + _RpcThreadPool._executor = ThreadPoolExecutor( + max_workers=max_workers, + ) + + @classmethod + def submit(cls, fn: Callable, *args: Any, **kwargs: Any) -> Any: + """Submit a task to the executor.""" + if cls._executor is None: + cls() + return cls._executor.submit(fn, *args, **kwargs) + + class RpcObject(ABC): """A proxy object which represent an object located in a rpc server.""" @@ -65,6 +88,7 @@ def __init__( self._oid = oid self._cls = cls self.connect_existing = connect_existing + self.executor = ThreadPoolExecutor(max_workers=1) from ..studio._client import _studio_client @@ -108,15 +132,14 @@ def __init__( def create(self, configs: dict) -> None: """create the object on the rpc server.""" - self._creating_stub = call_func_in_thread( - partial( - self.client.create_agent, - configs, - self._oid, - ), + self._creating_stub = _RpcThreadPool.submit( + self.client.create_agent, + configs, + self._oid, ) def __call__(self, *args: Any, **kwargs: Any) -> Any: + self._check_created() if "__call__" in self._cls._async_func: return self._async_func("__call__")(*args, **kwargs) else: @@ -159,7 +182,6 @@ def _check_created(self) -> None: def _call_func(self, func_name: str, args: dict) -> Any: """Call a function in rpc server.""" - self._check_created() return pickle.loads( self.client.call_agent_func( agent_id=self._oid, @@ -173,12 +195,10 @@ def async_wrapper(*args, **kwargs) -> Any: # type: ignore[no-untyped-def] return AsyncResult( host=self.host, port=self.port, - stub=call_func_in_thread( - partial( - self._call_func, - func_name=name, - args={"args": args, "kwargs": kwargs}, - ), + stub=_RpcThreadPool.submit( + self._call_func, + func_name=name, + args={"args": args, "kwargs": kwargs}, ), ) @@ -194,6 +214,7 @@ def sync_wrapper(*args, **kwargs) -> Any: # type: ignore[no-untyped-def] return sync_wrapper def __getattr__(self, name: str) -> Callable: + self._check_created() if name in self._cls._async_func: # for async functions return self._async_func(name) diff --git a/tests/async_result_pool_test.py b/tests/async_result_pool_test.py index 5b8122eb7..871ec3fb5 100644 --- a/tests/async_result_pool_test.py +++ b/tests/async_result_pool_test.py @@ -2,12 +2,11 @@ """Test the async result pool.""" import unittest import time -import functools import pickle from loguru import logger -from agentscope.rpc import call_func_in_thread +from agentscope.rpc.rpc_object import _RpcThreadPool from agentscope.server.async_result_pool import ( AsyncResultPool, get_pool, @@ -38,18 +37,18 @@ def _test_result_pool(self, pool: AsyncResultPool) -> None: for target_value in range(10): oid = pool.prepare() get_stubs.append( - call_func_in_thread( - functools.partial(test_get_func, oid=oid, pool=pool), + _RpcThreadPool.submit( + test_get_func, + oid=oid, + pool=pool, ), ) set_stubs.append( - call_func_in_thread( - functools.partial( - test_set_func, - oid=oid, - value=target_value, - pool=pool, - ), + _RpcThreadPool.submit( + test_set_func, + oid=oid, + value=target_value, + pool=pool, ), ) et = time.time()