add retry strategy
pan-x-c committed Oct 8, 2024
1 parent 5b9e1a4 commit bc9a0d1
Showing 7 changed files with 259 additions and 42 deletions.
@@ -35,7 +35,7 @@ cd $upper_dir
touch .pid

# activate your environment
source /mnt/conda/miniconda3/bin/activate as
source /root/miniconda3/bin/activate as

# start all agent servers
for ((i=0; i<(agent_server_num + env_server_num); i++)); do
148 changes: 148 additions & 0 deletions src/agentscope/rpc/retry_strategy.py
@@ -0,0 +1,148 @@
# -*- coding: utf-8 -*-
"""
Timeout retry strategies
"""
from __future__ import annotations
import time
import random
from abc import ABC, abstractmethod
from typing import Callable, Any
from functools import partial
from loguru import logger


class RetryBase(ABC):
"""The base class for all retry strategies"""

@abstractmethod
def retry(self, func: Callable, *args: Any, **kwargs: Any) -> Any:
"""Retry the func when any exception occurs"""

def __call__(self, func: Callable, *args: Any, **kwargs: Any) -> Any:
"""Call the retry method"""
return self.retry(func, *args, **kwargs)

@classmethod
def load_dict(cls, data: dict) -> RetryBase:
"""Load the retry strategy from a dict"""
retry_type = data.pop("type", None)
if retry_type == "fixed":
return RetryFixedTimes(**data)
elif retry_type == "expential":
return RetryExpential(**data)
else:
raise NotImplementedError(
f"Unknown retry strategy type: {retry_type}",
)


class RetryFixedTimes(RetryBase):
"""
Retry a fixed number of times, and wait a fixed delay time between each attempt.
Init dict format:
type: 'fixed'
max_retries (`int`): The max retry times
delay (`float`): The delay time between each attempt
.. code-block:: python
retry = RetryBase.load_dict({
"type": "fixed",
"max_retries": 10,
"delay": 5,
})
"""

def __init__(self, max_retries: int = 10, delay: float = 5) -> None:
"""Initialize the retry strategy
Args:
max_retries (`int`): The max retry times
delay (`float`): The delay time between each attempt
"""
self.max_retries = max_retries
self.delay = delay

def retry( # pylint: disable=R1710
self,
func: Callable,
*args: Any,
**kwargs: Any,
) -> Any:
func = partial(func, *args, **kwargs)
for attempt in range(self.max_retries + 1):
try:
return func()
except Exception as e:
if attempt == self.max_retries:
raise TimeoutError("Max timeout exceeded.") from e
random_delay = (random.random() + 0.5) * self.delay
logger.info(
f"Attempt {attempt + 1} failed: {e}. Retrying in {random_delay} seconds...",
)
time.sleep(random_delay)


class RetryExpential(RetryBase):
"""
Retry with exponential backoff, which means the delay time will increase exponentially.
Init dict format:
type: 'expential'
max_retries (`int`): The max retry times
base_delay (`float`): The base delay time
max_delay (`float`): The max delay time, which will be used if the calculated delay time
exceeds it.
.. code-block:: python
retry = RetryBase.load_dict({
"type": "expential",
"max_retries": 10,
"base_delay": 5,
"max_delay": 300,
})
"""

def __init__(
self,
max_retries: int = 10,
base_delay: float = 5,
max_delay: float = 300,
) -> None:
"""Initialize the retry strategy
Args:
max_retries (`int`): The max retry times
base_delay (`float`): The base delay time
max_delay (`float`): The max delay time
"""
self.max_retries = max_retries
self.base_delay = base_delay
self.max_delay = max_delay

def retry( # pylint: disable=R1710
self,
func: Callable,
*args: Any,
**kwargs: Any,
) -> Any:
func = partial(func, *args, **kwargs)
delay = self.base_delay
for attempt in range(self.max_retries + 1):
try:
return func()
except Exception as e:
if attempt == self.max_retries:
raise TimeoutError("Max timeout exceeded.") from e
delay = (random.random() + 0.5) * delay
delay = min(delay, self.max_delay)
logger.info(
f"Attempt {attempt + 1} failed: {e}. Retrying in {delay} seconds...",
)
time.sleep(delay)
delay *= 2


_DEAFULT_RETRY_STRATEGY = RetryFixedTimes(max_retries=10, delay=5)
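
For reference, a minimal sketch of how the new strategies can be exercised on their own. It assumes the classes are importable from `agentscope.rpc.retry_strategy` (as the file path suggests); `flaky_call` is a made-up stand-in for an unreliable operation and is not part of this commit.

import random

from agentscope.rpc.retry_strategy import RetryExpential, RetryFixedTimes


def flaky_call(x: int) -> int:
    """Hypothetical function that fails most of the time."""
    if random.random() < 0.8:
        raise ConnectionError("transient failure")
    return x * 2


# Fixed-delay strategy: up to 3 retries, roughly 1 second apart (jittered).
retry = RetryFixedTimes(max_retries=3, delay=1)

# A strategy instance is callable: it keeps invoking the function until it
# succeeds, and raises TimeoutError once max_retries is exhausted.
result = retry(flaky_call, 21)

# Exponential backoff: waits grow roughly 1 -> 2 -> 4 ... seconds, capped at 30.
backoff = RetryExpential(max_retries=5, base_delay=1, max_delay=30)
result = backoff(flaky_call, 21)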
9 changes: 4 additions & 5 deletions src/agentscope/rpc/rpc_async.py
@@ -14,7 +14,7 @@
from ..message import Msg
from .rpc_client import RpcClient
from ..utils.common import _is_web_url
from ..constants import _DEFAULT_RPC_RETRY_TIMES, _DEFAULT_RPC_TIMEOUT
from .retry_strategy import RetryBase, _DEAFULT_RETRY_STRATEGY


class AsyncResult:
@@ -26,10 +26,12 @@ def __init__(
port: int,
task_id: int = None,
stub: Future = None,
retry: RetryBase = _DEAFULT_RETRY_STRATEGY,
) -> None:
self._host = host
self._port = port
self._stub = None
self._retry = retry
self._task_id: int = None
if task_id is not None:
self._task_id = task_id
@@ -40,17 +42,14 @@ def __init__(

def _fetch_result(
self,
retry_times: int = _DEFAULT_RPC_RETRY_TIMES,
retry_interval: float = _DEFAULT_RPC_TIMEOUT,
) -> None:
"""Fetch result from the server."""
if self._task_id is None:
self._task_id = self._get_task_id()
self._data = pickle.loads(
RpcClient(self._host, self._port).update_result(
self._task_id,
retry_times=retry_times,
retry_interval=retry_interval,
retry=self._retry,
),
)
        # NOTE: it's a hack here to download files
44 changes: 12 additions & 32 deletions src/agentscope/rpc/rpc_client.py
@@ -3,8 +3,6 @@

import json
import os
import time
import random
from typing import Optional, Sequence, Union, Generator, Any
from concurrent.futures import ThreadPoolExecutor
from loguru import logger
@@ -25,9 +23,10 @@
agent_pb2 = ImportErrorReporter(import_error, "distribute")
RpcAgentStub = ImportErrorReporter(import_error, "distribute")

from .retry_strategy import RetryBase, _DEAFULT_RETRY_STRATEGY
from ..utils.common import _generate_id_from_seed
from ..exception import AgentServerNotAliveError
from ..constants import _DEFAULT_RPC_OPTIONS
from ..constants import _DEFAULT_RPC_OPTIONS, _DEFAULT_RPC_TIMEOUT
from ..exception import AgentCallError, AgentCreationError
from ..manager import FileManager

@@ -223,8 +222,7 @@ def delete_all_agent(self) -> bool:
def update_result(
self,
task_id: int,
retry_times: int = 10,
retry_interval: float = 5,
retry: RetryBase = _DEAFULT_RETRY_STRATEGY,
) -> str:
"""Update the value of the async result.
@@ -233,42 +231,24 @@ def update_result(
Args:
task_id (`int`): `task_id` of the PlaceholderMessage.
retry_times (`int`): Number of retries. Defaults to 10.
retry_interval (`float`): Base interval between retries in seconds.
Defaults to 5. Double the interval between retries for each retry.
retry (`RetryBase`): Retry strategy. Defaults to `RetryFixedTimes(10, 5)`.
Returns:
bytes: Serialized value.
"""
stub = RpcAgentStub(RpcClient._get_channel(self.url))
resp = None
for _ in range(retry_times):
try:
resp = stub.update_placeholder(
agent_pb2.UpdatePlaceholderRequest(task_id=task_id),
)
except grpc.RpcError as e:
if e.code() != grpc.StatusCode.DEADLINE_EXCEEDED:
raise AgentCallError(
host=self.host,
port=self.port,
message=f"Failed to update placeholder: {str(e)}",
) from e
# wait for a random time between retries
interval = (random.random() + 0.5) * retry_interval
logger.info(
f"Update placeholder timeout, retrying after {interval} s...",
)
time.sleep(interval)
retry_interval *= 2
continue
break
if resp is None:
try:
resp = retry.retry(
stub.update_placeholder,
agent_pb2.UpdatePlaceholderRequest(task_id=task_id),
timeout=_DEFAULT_RPC_TIMEOUT,
)
except Exception as e:
raise AgentCallError(
host=self.host,
port=self.port,
message="Failed to update placeholder: timeout",
)
) from e
if not resp.ok:
raise AgentCallError(
host=self.host,
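
Because `update_result` now delegates to a `RetryBase` object, callers are no longer limited to the two built-in strategies: any subclass implementing `retry` can be passed in. A hedged sketch of such a subclass (the `NoRetry` class and the commented call site are illustrative, not part of this commit):

from typing import Any, Callable

from agentscope.rpc.retry_strategy import RetryBase


class NoRetry(RetryBase):
    """Hypothetical strategy that invokes the function exactly once."""

    def retry(self, func: Callable, *args: Any, **kwargs: Any) -> Any:
        # No retry loop: the first exception propagates to the caller,
        # where update_result() wraps it in an AgentCallError.
        return func(*args, **kwargs)


# client = RpcClient(host, port)
# data = client.update_result(task_id, retry=NoRetry())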
9 changes: 9 additions & 0 deletions src/agentscope/rpc/rpc_meta.py
@@ -7,6 +7,7 @@
from loguru import logger

from .rpc_object import RpcObject
from .retry_strategy import RetryBase, _DEAFULT_RETRY_STRATEGY


# Decorator for async and sync functions
@@ -115,6 +116,10 @@ def __call__(cls, *args: tuple, **kwargs: dict) -> Any:
"local_mode",
True,
),
retry_strategy=to_dist.pop(
"retry_strategy",
_DEAFULT_RETRY_STRATEGY,
),
connect_existing=False,
configs={
"args": args,
@@ -178,6 +183,7 @@ def to_dist( # pylint: disable=W0211
max_expire_time: int = 7200,
max_timeout_seconds: int = 5,
local_mode: bool = True,
retry_strategy: RetryBase = _DEAFULT_RETRY_STRATEGY,
) -> Any:
"""Convert current object into its distributed version.
@@ -201,6 +207,8 @@
Only takes effect when `host` and `port` are not filled in.
Whether the started agent server only listens to local
requests.
retry_strategy (`RetryBase`, defaults to `_DEAFULT_RETRY_STRATEGY`):
The retry strategy for the async rpc call.
Returns:
`RpcObject`: the wrapped agent instance with distributed
@@ -219,4 +227,5 @@
max_expire_time=max_expire_time,
max_timeout_seconds=max_timeout_seconds,
local_mode=local_mode,
retry_strategy=retry_strategy,
)
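
With `retry_strategy` threaded through `to_dist`, the backoff behaviour can now be chosen per object. A sketch of the call site; the agent class and its setup are assumptions used only for illustration, while the `retry_strategy` keyword itself comes from this commit:

from agentscope.agents import DialogAgent  # assumed agent class, not part of this diff
from agentscope.rpc.retry_strategy import RetryExpential

# Agent construction is abridged; model configs are assumed to have been
# registered elsewhere via agentscope.init(...).
agent = DialogAgent(
    name="assistant",
    sys_prompt="You are a helpful assistant.",
    model_config_name="my_model",  # hypothetical config name
)

# Convert to the distributed version, fetching async results with
# exponential backoff instead of the default fixed-delay strategy.
agent = agent.to_dist(
    retry_strategy=RetryExpential(max_retries=5, base_delay=2, max_delay=60),
)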
19 changes: 15 additions & 4 deletions src/agentscope/rpc/rpc_object.py
@@ -1,6 +1,6 @@
# -*- coding: utf-8 -*-
"""A proxy object which represent a object located in a rpc server."""
from typing import Any, Callable
from typing import Any, Callable, Union
from abc import ABC
from inspect import getmembers, isfunction
from types import FunctionType
@@ -16,6 +16,7 @@

from .rpc_client import RpcClient
from .rpc_async import AsyncResult
from .retry_strategy import RetryBase, _DEAFULT_RETRY_STRATEGY
from ..exception import AgentCreationError, AgentServerNotAliveError


@@ -59,6 +60,7 @@ def __init__(
max_expire_time: int = 7200,
max_timeout_seconds: int = 5,
local_mode: bool = True,
retry_strategy: Union[RetryBase, dict] = _DEAFULT_RETRY_STRATEGY,
configs: dict = None,
) -> None:
"""Initialize the rpc object.
@@ -68,6 +70,9 @@
oid (`str`): The id of the object in the rpc server.
host (`str`): The host of the rpc server.
port (`int`): The port of the rpc server.
connect_existing (`bool`, defaults to `False`):
Set to `True`, if the object is already running on the
server.
max_pool_size (`int`, defaults to `8192`):
Max number of task results that the server can accommodate.
max_expire_time (`int`, defaults to `7200`):
@@ -77,16 +82,21 @@
local_mode (`bool`, defaults to `True`):
Whether the started gRPC server only listens to local
requests.
connect_existing (`bool`, defaults to `False`):
Set to `True`, if the object is already running on the
server.
            retry_strategy (`Union[RetryBase, dict]`, defaults to `_DEAFULT_RETRY_STRATEGY`):
                The retry strategy for async RPC calls.
configs (`dict`, defaults to `None`):
The configs for the agent. Generated by `RpcMeta`. Don't use this arg manually.
"""
self.host = host
self.port = port
self._oid = oid
self._cls = cls
self.connect_existing = connect_existing
self.executor = ThreadPoolExecutor(max_workers=1)
if isinstance(retry_strategy, RetryBase):
self.retry_strategy = retry_strategy
else:
self.retry_strategy = RetryBase.load_dict(retry_strategy)

from ..studio._client import _studio_client

@@ -200,6 +210,7 @@ def async_wrapper(*args, **kwargs) -> Any: # type: ignore[no-untyped-def]
func_name=name,
args={"args": args, "kwargs": kwargs},
),
retry=self.retry_strategy,
)

return async_wrapper
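
Since `RpcObject` accepts either a `RetryBase` instance or a plain dict (normalised through `RetryBase.load_dict`), the strategy can also be specified declaratively, for example from a config file. A small sketch of the equivalence; note that `load_dict` pops the `type` key, so a copy is passed:

from agentscope.rpc.retry_strategy import RetryBase, RetryFixedTimes

# These two specifications are interchangeable wherever a retry strategy
# is expected by RpcObject.
as_instance = RetryFixedTimes(max_retries=10, delay=5)
as_dict = {"type": "fixed", "max_retries": 10, "delay": 5}

loaded = RetryBase.load_dict(dict(as_dict))
assert isinstance(loaded, RetryFixedTimes)
assert loaded.max_retries == as_instance.max_retries
assert loaded.delay == as_instance.delay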