Skip to content

Commit

Permalink
support custom async func
Browse files Browse the repository at this point in the history
  • Loading branch information
pan-x-c committed Aug 9, 2024
1 parent 2386bb8 commit 9ff9234
Show file tree
Hide file tree
Showing 5 changed files with 104 additions and 14 deletions.
2 changes: 1 addition & 1 deletion src/agentscope/message/placeholder.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,7 @@ def __update_task_id(self) -> None:
raise ValueError(
f"Failed to get task_id: {self._stub.get_response()}",
) from e
self._task_id = resp["task_id"] # type: ignore[call-overload]
self._task_id = resp
self._stub = None

def __getstate__(self) -> dict:
Expand Down
4 changes: 4 additions & 0 deletions src/agentscope/rpc/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
call_func_in_thread,
)

from .rpc_config import DistConf, async_func

try:
from .rpc_agent_pb2 import RpcMsg # pylint: disable=E0611
from .rpc_agent_pb2_grpc import RpcAgentServicer
Expand All @@ -29,6 +31,8 @@
"RpcMsg",
"RpcAgentServicer",
"RpcAgentStub",
"DistConf",
"async_func",
"call_func_in_thread",
"add_RpcAgentServicer_to_server",
]
14 changes: 14 additions & 0 deletions src/agentscope/rpc/rpc_config.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,22 @@
# -*- coding: utf-8 -*-
"""Configs for Distributed mode."""
from typing import Callable
from loguru import logger


def async_func(func: Callable) -> Callable:
"""A decorator for async function.
In distributed mode, async functions will return a placeholder message
immediately.
Args:
func (`Callable`): The function to decorate.
"""

func._is_async = True # pylint: disable=W0212
return func


class DistConf(dict):
"""Distribution configuration for agents."""

Expand Down
70 changes: 58 additions & 12 deletions src/agentscope/server/servicer.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,14 +36,33 @@
from agentscope.studio._client import _studio_client
from agentscope.exception import StudioRegisterError
from agentscope.rpc.rpc_agent_pb2_grpc import RpcAgentServicer
from agentscope.rpc.rpc_agent_client import RpcAgentClient
from agentscope.message import (
Msg,
PlaceholderMessage,
deserialize,
serialize,
)


class TaskResult:
"""Use this class to get the the result from rpc server."""

# TODO: merge into placeholder

def __init__(self, host: str, port: int, task_id: int) -> None:
self.host = host
self.port = port
self.task_id = task_id

def get(self) -> Any:
"""Get the value"""
return deserialize(
RpcAgentClient(self.host, self.port).update_placeholder(
self.task_id,
),
)


def _register_server_to_studio(
studio_url: str,
server_id: str,
Expand Down Expand Up @@ -329,14 +348,14 @@ def call_agent_func( # pylint: disable=W0236
return self.call_custom_func(
request.agent_id,
request.target_func,
deserialize(request.value),
request.value,
)

def call_custom_func(
self,
agent_id: str,
func_name: str,
args: dict,
raw_value: str,
) -> agent_pb2.GeneralResponse:
"""Call a custom function"""
agent = self.get_agent(agent_id)
Expand All @@ -345,6 +364,31 @@ def call_custom_func(
success=False,
message=f"Agent [{agent_id}] not exists.",
)
func = getattr(agent, func_name)
if (
hasattr(func, "_is_async")
and func._is_async # pylint: disable=W0212
): # pylint: disable=W0212
task_id = self.get_task_id()
self.result_pool[task_id] = threading.Condition()
self.executor.submit(
self._process_messages,
task_id,
agent_id,
func_name,
raw_value,
)
return agent_pb2.GeneralResponse(
ok=True,
message=serialize(
TaskResult(
host=self.host,
port=self.port,
task_id=task_id,
),
),
)
args = deserialize(raw_value)
res = getattr(agent, func_name)(
*args.get("args", ()),
**args.get("kwargs", {}),
Expand Down Expand Up @@ -492,11 +536,7 @@ def _reply(self, request: agent_pb2.RpcMsg) -> agent_pb2.GeneralResponse:
return agent_pb2.GeneralResponse(
ok=True,
message=serialize(
Msg( # type: ignore[arg-type]
name=self.get_agent(request.agent_id).name,
content=None,
task_id=task_id,
),
task_id,
),
)

Expand Down Expand Up @@ -534,16 +574,22 @@ def _process_messages(
target_func (`str`): the name of the function that will be called.
raw_msg (`str`): the input serialized message.
"""
if raw_msg:
if raw_msg is not None:
msg = deserialize(raw_msg)
else:
msg = None
if isinstance(msg, PlaceholderMessage):
msg.update_value()
cond = self.result_pool[task_id]
agent = self.get_agent(agent_id)
if isinstance(msg, PlaceholderMessage):
msg.update_value()
try:
result = getattr(agent, target_func)(msg)
if target_func == "reply":
result = getattr(agent, target_func)(msg)
else:
result = getattr(agent, target_func)(
*msg.get("args", ()),
**msg.get("kwargs", {}),
)
self.result_pool[task_id] = result
except Exception:
error_msg = traceback.format_exc()
Expand Down
28 changes: 27 additions & 1 deletion tests/rpc_agent_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,13 @@
from agentscope.agents import AgentBase, DistConf, DialogAgent
from agentscope.manager import MonitorManager, ASManager
from agentscope.server import RpcAgentServerLauncher
from agentscope.server.servicer import TaskResult
from agentscope.message import Msg
from agentscope.message import PlaceholderMessage
from agentscope.message import deserialize, serialize
from agentscope.msghub import msghub
from agentscope.pipelines import sequentialpipeline
from agentscope.rpc.rpc_agent_client import RpcAgentClient
from agentscope.rpc import RpcAgentClient, async_func
from agentscope.agents import RpcAgent
from agentscope.exception import AgentCallError, QuotaExceededError

Expand Down Expand Up @@ -181,6 +182,7 @@ def __init__( # type: ignore[no-untyped-def]
**kwargs,
) -> None:
super().__init__(name, **kwargs)
self.cnt = 0
self.judge_func = judge_func

def reply(self, x: Msg = None) -> Msg:
Expand All @@ -203,6 +205,17 @@ def custom_judge_func(self, x: str) -> bool:
res = self.judge_func(x)
return res

@async_func
def custom_async_func(self, num: int) -> int:
"""A custom function that executes in async"""
time.sleep(num)
self.cnt += num
return self.cnt

def custom_sync_func(self) -> int:
"""A custom function that executes in sync"""
return self.cnt


class BasicRpcAgentTest(unittest.TestCase):
"""Test cases for Rpc Agent"""
Expand Down Expand Up @@ -815,6 +828,7 @@ def test_server_auto_alloc(
"args": (),
"kwargs": {"name": "custom"},
"class_name": "CustomAgent",
"type": "agent",
},
agent_id=custom_agent_id,
),
Expand Down Expand Up @@ -850,3 +864,15 @@ def test_custom_agent_func(self) -> None:
self.assertFalse(agent.custom_judge_func("diuafhsua$FAIL$"))
self.assertTrue(agent.custom_judge_func("72354rfv$PASS$"))
self.assertEqual(r, 1)
start_time = time.time()
r1 = agent.custom_async_func(1)
r2 = agent.custom_async_func(1)
r3 = agent.custom_sync_func()
end_time = time.time()
self.assertTrue(end_time - start_time < 1)
self.assertEqual(r3, 0)
self.assertTrue(isinstance(r1, TaskResult))
self.assertTrue(r1.get() <= 2)
self.assertTrue(r2.get() <= 2)
r4 = agent.custom_sync_func()
self.assertEqual(r4, 2)

0 comments on commit 9ff9234

Please sign in to comment.