Skip to content

Commit

Permalink
added exception building logic
Browse files Browse the repository at this point in the history
  • Loading branch information
daniel-sanche committed Aug 15, 2023
1 parent 9cadd63 commit c9ef1d5
Show file tree
Hide file tree
Showing 4 changed files with 334 additions and 51 deletions.
85 changes: 62 additions & 23 deletions google/api_core/retry_streaming.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

import logging
import time
from functools import partial

from google.api_core import exceptions

Expand All @@ -26,6 +27,34 @@
T = TypeVar("T")


def _build_timeout_error(
exc_list: list[Exception], is_timeout: bool, timeout_val: float
) -> tuple[Exception, Exception | None]:
"""
Default exception_factory implementation. Builds an exception after the retry fails
Args:
- exc_list (list[Exception]): list of exceptions that occurred during the retry
- is_timeout (bool): whether the failure is due to the timeout value being exceeded,
or due to a non-retryable exception
- timeout_val (float): the original timeout value for the retry, for use in the exception message
Returns:
- tuple[Exception, Exception|None]: a tuple of the exception to be raised, and the cause exception if any
"""
src_exc = exc_list[-1] if exc_list else None
if is_timeout:
return (
exceptions.RetryError(
"Timeout of {:.1f}s exceeded".format(timeout_val),
src_exc,
),
src_exc,
)
else:
return exc_list[-1], None


class RetryableGenerator(Generator[T, Any, None]):
"""
Generator wrapper for retryable streaming RPCs.
Expand Down Expand Up @@ -103,6 +132,9 @@ def __init__(
sleep_generator: Iterable[float],
timeout: Optional[float] = None,
on_error: Optional[Callable[[Exception], None]] = None,
exception_factory: Optional[
Callable[[list[Exception], bool, float], tuple[Exception, Exception | None]]
] = None,
check_timeout_on_yield=False,
):
"""
Expand All @@ -119,23 +151,34 @@ def __init__(
on_error: A function to call while processing a
retryable exception. Any error raised by this function will *not*
be caught.
exception_factory: A function that creates an exception to raise
when the retry fails. The function takes three arguments:
a list of exceptions that occurred during the retry, a boolean
indicating whether the failure is due to retry timeout, and the original
timeout value (for building a helpful error message). It is expected to
return a tuple of the exception to raise and (optionally) a source
exception to chain to the raised exception.
If not provided, a default exception will be raised.
check_timeout_on_yield: If True, the timeout value will be checked
after each yield. If the timeout has been exceeded, the generator
will raise a RetryError. Note that this adds an overhead to each
yield, so it is preferred to add the timeout logic to the wrapped
stream when possible.
will raise an exception from exception_factory.
Note that this adds an overhead to each yield, so it is better
to add the timeout logic to the wrapped stream when possible.
"""
self.target_fn = target
self.active_target: Iterator[T] = self.target_fn().__iter__()
self.predicate = predicate
self.sleep_generator = iter(sleep_generator)
self.on_error = on_error
self.timeout = timeout
if self.timeout is not None:
self.deadline = time.monotonic() + self.timeout
if timeout is not None:
self.deadline = time.monotonic() + timeout
else:
self.deadline = None
self._check_timeout_on_yield = check_timeout_on_yield
self.error_list: list[Exception] = []
self._exc_factory = partial(
exception_factory or _build_timeout_error, timeout_val=timeout
)

def __iter__(self) -> Generator[T, Any, None]:
"""
Expand All @@ -149,8 +192,12 @@ def _handle_exception(self, exc) -> None:
check if it is retryable. If so, create a new active_target and
continue iterating. If not, raise the exception.
"""
self.error_list.append(exc)
if not self.predicate(exc):
raise exc
final_exc, src_exc = self._exc_factory(
exc_list=self.error_list, is_timeout=False
)
raise final_exc from src_exc
else:
# run on_error callback if provided
if self.on_error:
Expand All @@ -159,38 +206,30 @@ def _handle_exception(self, exc) -> None:
next_sleep = next(self.sleep_generator)
except StopIteration:
raise ValueError("Sleep generator stopped yielding sleep values")
# if deadline is exceeded, raise RetryError
# if deadline is exceeded, raise exception
if self.deadline is not None:
next_attempt = time.monotonic() + next_sleep
self._check_timeout(next_attempt, exc)
self._check_timeout(next_attempt)
# sleep before retrying
_LOGGER.debug(
"Retrying due to {}, sleeping {:.1f}s ...".format(exc, next_sleep)
)
time.sleep(next_sleep)
self.active_target = self.target_fn().__iter__()

def _check_timeout(
self, current_time: float, source_exception: Optional[Exception] = None
) -> None:
def _check_timeout(self, current_time: float) -> None:
"""
Helper function to check if the timeout has been exceeded, and raise a RetryError if so.
Helper function to check if the timeout has been exceeded, and raise an exception if so.
Args:
- current_time: the timestamp to check against the deadline
- source_exception: the exception that triggered the timeout check, if any
Raises:
- RetryError if the deadline has been exceeded
- Exception from exception_factory if the timeout has been exceeded
"""
if (
self.deadline is not None
and self.timeout is not None
and self.deadline < current_time
):
raise exceptions.RetryError(
"Timeout of {:.1f}s exceeded".format(self.timeout),
source_exception,
) from source_exception
if self.deadline is not None and self.deadline < current_time:
exc, src_exc = self._exc_factory(exc_list=self.error_list, is_timeout=True)
raise exc from src_exc

def __next__(self) -> T:
"""
Expand Down
53 changes: 33 additions & 20 deletions google/api_core/retry_streaming_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,10 @@
import asyncio
import logging
import time
from functools import partial

from google.api_core import exceptions
from google.api_core.retry_streaming import _build_timeout_error

_LOGGER = logging.getLogger(__name__)

Expand Down Expand Up @@ -118,6 +120,9 @@ def __init__(
sleep_generator: Iterable[float],
timeout: Optional[float] = None,
on_error: Optional[Callable[[Exception], None]] = None,
exception_factory: Optional[
Callable[[list[Exception], bool, float], tuple[Exception, Exception | None]]
] = None,
check_timeout_on_yield=False,
):
"""
Expand All @@ -134,47 +139,51 @@ def __init__(
on_error: A function to call while processing a
retryable exception. Any error raised by this function will *not*
be caught.
exception_factory: A function that creates an exception to raise
when the retry fails. The function takes three arguments:
a list of exceptions that occurred during the retry, a boolean
indicating whether the failure is due to retry timeout, and the original
timeout value (for building a helpful error message). It is expected to
return a tuple of the exception to raise and (optionally) a source
exception to chain to the raised exception.
If not provided, a default exception will be raised.
check_timeout_on_yield: If True, the timeout value will be checked
after each yield. If the timeout has been exceeded, the generator
will raise a RetryError. Note that this adds an overhead to each
yield, so it is preferred to add the timeout logic to the wrapped
stream when possible.
will raise an exception from exception_factory.
Note that this adds an overhead to each yield, so it is better
to add the timeout logic to the wrapped stream when possible.
"""
self.target_fn = target
# active target must be populated in an async context
self.active_target: Optional[AsyncIterator[T]] = None
self.predicate = predicate
self.sleep_generator = iter(sleep_generator)
self.on_error = on_error
self.timeout = timeout
self.timeout_task = None
if self.timeout is not None:
self.deadline = time.monotonic() + self.timeout
if timeout is not None:
self.deadline = time.monotonic() + timeout
else:
self.deadline = None
self._check_timeout_on_yield = check_timeout_on_yield
self.error_list: list[Exception] = []
self._exc_factory = partial(
exception_factory or _build_timeout_error, timeout_val=timeout
)

def _check_timeout(
self, current_time: float, source_exception: Optional[Exception] = None
) -> None:
"""
Helper function to check if the timeout has been exceeded, and raise a RetryError if so.
Helper function to check if the timeout has been exceeded, and raise an exception if so.
Args:
- current_time: the timestamp to check against the deadline
- source_exception: the exception that triggered the timeout check, if any
Raises:
- RetryError if the deadline has been exceeded
- Exception from exception_factory if the timeout has been exceeded
"""
if (
self.deadline is not None
and self.timeout is not None
and self.deadline < current_time
):
raise exceptions.RetryError(
"Timeout of {:.1f}s exceeded".format(self.timeout),
source_exception,
) from source_exception
if self.deadline is not None and self.deadline < current_time:
exc, src_exc = self._exc_factory(exc_list=self.error_list, is_timeout=True)
raise exc from src_exc

async def _ensure_active_target(self) -> AsyncIterator[T]:
"""
Expand All @@ -200,8 +209,12 @@ async def _handle_exception(self, exc) -> None:
check if it is retryable. If so, create a new active_target and
continue iterating. If not, raise the exception.
"""
self.error_list.append(exc)
if not self.predicate(exc) and not isinstance(exc, asyncio.TimeoutError):
raise exc
final_exc, src_exc = self._exc_factory(
exc_list=self.error_list, is_timeout=False
)
raise final_exc from src_exc
else:
# run on_error callback if provided
if self.on_error:
Expand All @@ -210,7 +223,7 @@ async def _handle_exception(self, exc) -> None:
next_sleep = next(self.sleep_generator)
except StopIteration:
raise ValueError("Sleep generator stopped yielding sleep values")
# if deadline is exceeded, raise RetryError
# if deadline is exceeded, raise exception
if self.deadline is not None:
next_attempt = time.monotonic() + next_sleep
self._check_timeout(next_attempt, exc)
Expand Down
Loading

0 comments on commit c9ef1d5

Please sign in to comment.