diff --git a/graphiti_core/llm_client/client.py b/graphiti_core/llm_client/client.py index 02bd6f4..b76464c 100644 --- a/graphiti_core/llm_client/client.py +++ b/graphiti_core/llm_client/client.py @@ -20,7 +20,9 @@ import typing from abc import ABC, abstractmethod +import httpx from diskcache import Cache +from tenacity import retry, retry_if_exception_type, stop_after_attempt, wait_exponential from ..prompts.models import Message from .config import LLMConfig @@ -47,6 +49,23 @@ def __init__(self, config: LLMConfig | None, cache: bool = False): def get_embedder(self) -> typing.Any: pass + @staticmethod + def _is_server_error(exception): + return ( + isinstance(exception, httpx.HTTPStatusError) + and 500 <= exception.response.status_code < 600 + ) + + @retry( + stop=stop_after_attempt(3), + wait=wait_exponential(multiplier=1, min=4, max=10), + retry=retry_if_exception_type(httpx.HTTPStatusError), + ) + async def _generate_response_with_retry(self, messages: list[Message]) -> dict[str, typing.Any]: + if self._is_server_error(httpx.HTTPStatusError): + raise + return await self._generate_response(messages) + @abstractmethod async def _generate_response(self, messages: list[Message]) -> dict[str, typing.Any]: pass @@ -66,7 +85,15 @@ async def generate_response(self, messages: list[Message]) -> dict[str, typing.A logger.debug(f'Cache hit for {cache_key}') return cached_response - response = await self._generate_response(messages) + try: + response = await self._generate_response_with_retry(messages) + except httpx.HTTPStatusError as e: + error_type = 'server' if 500 <= e.response.status_code < 600 else 'client' + logger.error(f'Failed to generate response due to {error_type} error: {str(e)}') + raise + except Exception as e: + logger.error(f'Failed to generate response due to unexpected error: {str(e)}') + raise if self.cache_enabled: self.cache_dir.set(cache_key, response) diff --git a/pyproject.toml b/pyproject.toml index 2456f13..956e48d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -23,6 +23,7 @@ diskcache = "^5.6.3" arrow = "^1.3.0" openai = "^1.38.0" anthropic = "^0.34.1" +tenacity = "^9.0.0" [tool.poetry.dev-dependencies] pytest = "^8.3.2"