Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

swarm - Enhance retry decorator to support both sync and async functions #810

Merged
merged 1 commit into from
Nov 17, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
76 changes: 58 additions & 18 deletions pkgs/swarmauri/swarmauri/utils/retry_decorator.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,42 +2,82 @@
import logging
import httpx
from functools import wraps
from typing import List
import asyncio # Import asyncio to use async sleep
from typing import List, Callable, Any
import asyncio
import inspect


def retry_on_status_codes(
    status_codes: List[int] = [429], max_retries: int = 3, retry_delay: int = 2
) -> Callable[[Callable[..., Any]], Callable[..., Any]]:
    """
    A decorator to retry both sync and async functions when specific status codes are
    encountered, with exponential backoff.

    The wrapped function is retried only when it raises ``httpx.HTTPStatusError`` whose
    response status code is listed in ``status_codes``. Any other ``HTTPStatusError`` is
    re-raised immediately, and any other exception type propagates untouched.

    Parameters:
    - status_codes: List of HTTP status codes that should trigger retries (default [429]).
      The list is only read, never mutated, so the shared default is safe.
    - max_retries: The maximum number of calls made to the wrapped function (default 3).
      Values below 1 are treated as 1 so the function is always called at least once.
    - retry_delay: The initial delay between retries, in seconds (default 2). The delay
      doubles after every failed attempt.

    Returns:
    - A decorator. It yields an async wrapper for coroutine functions and a sync wrapper
      for everything else.

    Raises (from the wrapped call):
    - Exception: when every attempt failed with a retryable status code. The last
      ``httpx.HTTPStatusError`` is attached as ``__cause__``.
    """
    # Guarantee a single call even for max_retries <= 0; otherwise the loop body would
    # never run and the caller would get a "failed after 0 retries" error for a call
    # that was never made.
    max_attempts = max(1, max_retries)

    def _announce_retry(status_code: int, attempt: int) -> int:
        """Log the upcoming retry and return the backoff delay (seconds) for this attempt."""
        backoff_time = retry_delay * (2 ** (attempt - 1))
        logging.warning(
            "Received status %s. Retrying in %s seconds... Attempt %s/%s",
            status_code,
            backoff_time,
            attempt,
            max_attempts,
        )
        return backoff_time

    def _exhausted_error() -> Exception:
        """Log the final failure and build the exception raised once attempts run out."""
        logging.error(
            "Failed after %s retries due to rate limit or other status codes.",
            max_attempts,
        )
        return Exception(f"Failed after {max_attempts} retries due to {status_codes}.")

    def decorator(func: Callable[..., Any]) -> Callable[..., Any]:
        @wraps(func)
        async def async_wrapper(*args: Any, **kwargs: Any) -> Any:
            last_error = None
            for attempt in range(1, max_attempts + 1):
                try:
                    return await func(*args, **kwargs)
                except httpx.HTTPStatusError as e:
                    if e.response.status_code not in status_codes:
                        # Not a retryable status: surface the original error unchanged.
                        raise
                    # `e` is unbound when the handler exits, so keep our own reference.
                    last_error = e
                # No sleep after the final attempt; fall through to the failure below.
                if attempt < max_attempts:
                    await asyncio.sleep(
                        _announce_retry(last_error.response.status_code, attempt)
                    )
            # Chain the HTTP error so the root cause stays visible in tracebacks.
            raise _exhausted_error() from last_error

        @wraps(func)
        def sync_wrapper(*args: Any, **kwargs: Any) -> Any:
            last_error = None
            for attempt in range(1, max_attempts + 1):
                try:
                    return func(*args, **kwargs)
                except httpx.HTTPStatusError as e:
                    if e.response.status_code not in status_codes:
                        # Not a retryable status: surface the original error unchanged.
                        raise
                    # `e` is unbound when the handler exits, so keep our own reference.
                    last_error = e
                # No sleep after the final attempt; fall through to the failure below.
                if attempt < max_attempts:
                    time.sleep(
                        _announce_retry(last_error.response.status_code, attempt)
                    )
            # Chain the HTTP error so the root cause stays visible in tracebacks.
            raise _exhausted_error() from last_error

        # Check if the function is async or sync and return appropriate wrapper
        return async_wrapper if inspect.iscoroutinefunction(func) else sync_wrapper

    return decorator
5 changes: 4 additions & 1 deletion pkgs/swarmauri/tests/unit/llms/GroqToolModel_unit_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ def groq_tool_model():
llm = LLM(api_key=API_KEY)
return llm


def get_allowed_models():
if not API_KEY:
return []
Expand Down Expand Up @@ -95,6 +96,7 @@ def test_agent_exec(groq_tool_model, toolkit, conversation, model_name):
result = agent.exec("Add 512+671")
assert type(result) is str


@retry_on_status_codes([429])
@timeout(5)
@pytest.mark.unit
Expand All @@ -107,6 +109,7 @@ def test_predict(groq_tool_model, toolkit, conversation, model_name):

assert type(conversation.get_last().content) == str


@retry_on_status_codes([429])
@timeout(5)
@pytest.mark.unit
Expand All @@ -124,6 +127,7 @@ def test_stream(groq_tool_model, toolkit, conversation, model_name):
# assert len(full_response) > 0
assert conversation.get_last().content == full_response


@retry_on_status_codes([429])
@timeout(5)
@pytest.mark.unit
Expand Down Expand Up @@ -156,7 +160,6 @@ async def test_apredict(groq_tool_model, toolkit, conversation, model_name):
assert isinstance(prediction, str)



@retry_on_status_codes([429])
@timeout(5)
@pytest.mark.unit
Expand Down