Skip to content

Commit

Permalink
fix linux test
Browse files Browse the repository at this point in the history
  • Loading branch information
pan-x-c committed Sep 12, 2024
1 parent d9698dd commit c39e9c9
Show file tree
Hide file tree
Showing 4 changed files with 49 additions and 45 deletions.
6 changes: 1 addition & 5 deletions src/agentscope/rpc/__init__.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,6 @@
# -*- coding: utf-8 -*-
"""Import all rpc related modules in the package."""
from .rpc_client import (
RpcClient,
call_func_in_thread,
)
from .rpc_client import RpcClient
from .rpc_meta import async_func, sync_func, RpcMeta
from .rpc_config import DistConf
from .rpc_async import AsyncResult
Expand All @@ -18,5 +15,4 @@
"sync_func",
"AsyncResult",
"DistConf",
"call_func_in_thread",
]
16 changes: 2 additions & 14 deletions src/agentscope/rpc/rpc_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@

import json
import os
from typing import Optional, Sequence, Union, Generator, Callable, Any
from concurrent.futures import Future, ThreadPoolExecutor
from typing import Optional, Sequence, Union, Generator, Any
from concurrent.futures import ThreadPoolExecutor
from loguru import logger

from ..message import Msg
Expand Down Expand Up @@ -327,18 +327,6 @@ def __reduce__(self) -> tuple:
)


def call_func_in_thread(func: Callable) -> Future:
"""Call a function in a sub-thread.
Args:
func (`Callable`): The function to be called in sub-thread.
Returns:
`Future`: A stub to get the response.
"""
return RpcClient._EXECUTOR.submit(func) # pylint: disable=W0212


class RpcAgentClient(RpcClient):
"""`RpcAgentClient` has renamed to `RpcClient`.
This class is kept for backward compatibility, please use `RpcClient`
Expand Down
51 changes: 36 additions & 15 deletions src/agentscope/rpc/rpc_object.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
# -*- coding: utf-8 -*-
"""A proxy object which represent a object located in a rpc server."""
from typing import Any, Callable
from functools import partial
from abc import ABC
from inspect import getmembers, isfunction
from types import FunctionType
from concurrent.futures import ThreadPoolExecutor
import threading

try:
import cloudpickle as pickle
Expand All @@ -13,7 +14,7 @@

pickle = ImportErrorReporter(e, "distribute")

from .rpc_client import RpcClient, call_func_in_thread
from .rpc_client import RpcClient
from .rpc_async import AsyncResult
from ..exception import AgentCreationError, AgentServerNotAliveError

Expand All @@ -27,6 +28,28 @@ def get_public_methods(cls: type) -> list[str]:
]


class _RpcThreadPool:
"""Executor for rpc object tasks."""

_executor = None
_lock = threading.Lock()

def __init__(self, max_workers: int = 32) -> None:
if _RpcThreadPool._executor is None:
with _RpcThreadPool._lock:
if _RpcThreadPool._executor is None:
_RpcThreadPool._executor = ThreadPoolExecutor(
max_workers=max_workers,
)

@classmethod
def submit(cls, fn: Callable, *args: Any, **kwargs: Any) -> Any:
"""Submit a task to the executor."""
if cls._executor is None:
cls()
return cls._executor.submit(fn, *args, **kwargs)


class RpcObject(ABC):
"""A proxy object which represent an object located in a rpc server."""

Expand Down Expand Up @@ -65,6 +88,7 @@ def __init__(
self._oid = oid
self._cls = cls
self.connect_existing = connect_existing
self.executor = ThreadPoolExecutor(max_workers=1)

from ..studio._client import _studio_client

Expand Down Expand Up @@ -108,15 +132,14 @@ def __init__(

def create(self, configs: dict) -> None:
"""create the object on the rpc server."""
self._creating_stub = call_func_in_thread(
partial(
self.client.create_agent,
configs,
self._oid,
),
self._creating_stub = _RpcThreadPool.submit(
self.client.create_agent,
configs,
self._oid,
)

def __call__(self, *args: Any, **kwargs: Any) -> Any:
self._check_created()
if "__call__" in self._cls._async_func:
return self._async_func("__call__")(*args, **kwargs)
else:
Expand Down Expand Up @@ -159,7 +182,6 @@ def _check_created(self) -> None:

def _call_func(self, func_name: str, args: dict) -> Any:
"""Call a function in rpc server."""
self._check_created()
return pickle.loads(
self.client.call_agent_func(
agent_id=self._oid,
Expand All @@ -173,12 +195,10 @@ def async_wrapper(*args, **kwargs) -> Any: # type: ignore[no-untyped-def]
return AsyncResult(
host=self.host,
port=self.port,
stub=call_func_in_thread(
partial(
self._call_func,
func_name=name,
args={"args": args, "kwargs": kwargs},
),
stub=_RpcThreadPool.submit(
self._call_func,
func_name=name,
args={"args": args, "kwargs": kwargs},
),
)

Expand All @@ -194,6 +214,7 @@ def sync_wrapper(*args, **kwargs) -> Any: # type: ignore[no-untyped-def]
return sync_wrapper

def __getattr__(self, name: str) -> Callable:
self._check_created()
if name in self._cls._async_func:
# for async functions
return self._async_func(name)
Expand Down
21 changes: 10 additions & 11 deletions tests/async_result_pool_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,11 @@
"""Test the async result pool."""
import unittest
import time
import functools
import pickle

from loguru import logger

from agentscope.rpc import call_func_in_thread
from agentscope.rpc.rpc_object import _RpcThreadPool
from agentscope.server.async_result_pool import (
AsyncResultPool,
get_pool,
Expand Down Expand Up @@ -38,18 +37,18 @@ def _test_result_pool(self, pool: AsyncResultPool) -> None:
for target_value in range(10):
oid = pool.prepare()
get_stubs.append(
call_func_in_thread(
functools.partial(test_get_func, oid=oid, pool=pool),
_RpcThreadPool.submit(
test_get_func,
oid=oid,
pool=pool,
),
)
set_stubs.append(
call_func_in_thread(
functools.partial(
test_set_func,
oid=oid,
value=target_value,
pool=pool,
),
_RpcThreadPool.submit(
test_set_func,
oid=oid,
value=target_value,
pool=pool,
),
)
et = time.time()
Expand Down

0 comments on commit c39e9c9

Please sign in to comment.