diff --git a/libs/langgraph/langgraph/pregel/retry.py b/libs/langgraph/langgraph/pregel/retry.py index cdf7b33d3..553b25468 100644 --- a/libs/langgraph/langgraph/pregel/retry.py +++ b/libs/langgraph/langgraph/pregel/retry.py @@ -2,7 +2,7 @@ import logging import random import time -from typing import Optional +from typing import Optional, Sequence from langgraph.constants import CONFIG_KEY_RESUMING from langgraph.errors import GraphInterrupt @@ -38,11 +38,21 @@ def run_with_retry( # increment attempts attempts += 1 # check if we should retry - if callable(retry_policy.retry_on): + if isinstance(retry_policy.retry_on, Sequence): + if not isinstance(exc, tuple(retry_policy.retry_on)): + raise + elif isinstance(retry_policy.retry_on, type) and issubclass( + retry_policy.retry_on, Exception + ): + if not isinstance(exc, retry_policy.retry_on): + raise + elif callable(retry_policy.retry_on): if not retry_policy.retry_on(exc): raise - elif not isinstance(exc, retry_policy.retry_on): - raise + else: + raise TypeError( + "retry_on must be an Exception class, a list or tuple of Exception classes, or a callable" + ) # check if we should give up if attempts >= retry_policy.max_attempts: raise @@ -94,11 +104,21 @@ async def arun_with_retry( # increment attempts attempts += 1 # check if we should retry - if callable(retry_policy.retry_on): + if isinstance(retry_policy.retry_on, Sequence): + if not isinstance(exc, tuple(retry_policy.retry_on)): + raise + elif isinstance(retry_policy.retry_on, type) and issubclass( + retry_policy.retry_on, Exception + ): + if not isinstance(exc, retry_policy.retry_on): + raise + elif callable(retry_policy.retry_on): if not retry_policy.retry_on(exc): raise - elif not isinstance(exc, retry_policy.retry_on): - raise + else: + raise TypeError( + "retry_on must be an Exception class, a list or tuple of Exception classes, or a callable" + ) # check if we should give up if attempts >= retry_policy.max_attempts: raise diff --git a/libs/langgraph/langgraph/pregel/types.py b/libs/langgraph/langgraph/pregel/types.py index 589ed43ab..f8682d50b 100644 --- a/libs/langgraph/langgraph/pregel/types.py +++ b/libs/langgraph/langgraph/pregel/types.py @@ -1,5 +1,5 @@ from collections import deque -from typing import Any, Callable, Literal, NamedTuple, Optional, Type, Union +from typing import Any, Callable, Literal, NamedTuple, Optional, Sequence, Type, Union from langchain_core.runnables import Runnable, RunnableConfig @@ -52,7 +52,7 @@ class RetryPolicy(NamedTuple): jitter: bool = True """Whether to add random jitter to the interval between retries.""" retry_on: Union[ - Type[Exception], tuple[Type[Exception], ...], Callable[[Exception], bool] + Type[Exception], Sequence[Type[Exception]], Callable[[Exception], bool] ] = default_retry_on """List of exception classes that should trigger a retry, or a callable that returns True for exceptions that should trigger a retry."""