Skip to content

Commit

Permalink
simplified ensure_tareget; fixed mypy issues
Browse files Browse the repository at this point in the history
  • Loading branch information
daniel-sanche committed Aug 15, 2023
1 parent 41c7868 commit 30fccb9
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 49 deletions.
39 changes: 14 additions & 25 deletions google/api_core/retry_streaming.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

"""Helpers for retries for streaming APIs."""

from typing import Callable, Optional, Iterable, Iterator, Generator, TypeVar, Any, cast
from typing import Callable, Optional, List, Tuple, Iterable, Iterator, Generator, TypeVar, Any, cast

import logging
import time
Expand All @@ -27,9 +27,7 @@
T = TypeVar("T")


def _build_timeout_error(
exc_list: list[Exception], is_timeout: bool, timeout_val: float
) -> tuple[Exception, Exception | None]:
def _build_timeout_error(exc_list:List[Exception], is_timeout:bool, timeout_val:float) -> Tuple[Exception, Optional[Exception]]:
"""
Default exception_factory implementation. Builds an exception after the retry fails
Expand All @@ -44,13 +42,10 @@ def _build_timeout_error(
"""
src_exc = exc_list[-1] if exc_list else None
if is_timeout:
return (
exceptions.RetryError(
"Timeout of {:.1f}s exceeded".format(timeout_val),
src_exc,
),
return exceptions.RetryError(
"Timeout of {:.1f}s exceeded".format(timeout_val),
src_exc,
)
), src_exc
else:
return exc_list[-1], None

Expand Down Expand Up @@ -132,9 +127,7 @@ def __init__(
sleep_generator: Iterable[float],
timeout: Optional[float] = None,
on_error: Optional[Callable[[Exception], None]] = None,
exception_factory: Optional[
Callable[[list[Exception], bool, float], tuple[Exception, Exception | None]]
] = None,
exception_factory: Optional[Callable[[List[Exception], bool, float], Tuple[Exception, Optional[Exception]]]] = None,
check_timeout_on_yield: bool = False,
):
"""
Expand Down Expand Up @@ -170,15 +163,10 @@ def __init__(
self.predicate = predicate
self.sleep_generator = iter(sleep_generator)
self.on_error = on_error
if timeout is not None:
self.deadline = time.monotonic() + timeout
else:
self.deadline = None
self.deadline: Optional[float] = time.monotonic() + timeout if timeout else None
self._check_timeout_on_yield = check_timeout_on_yield
self.error_list: list[Exception] = []
self._exc_factory = partial(
exception_factory or _build_timeout_error, timeout_val=timeout
)
self.error_list : List[Exception] = []
self._exc_factory = partial(exception_factory or _build_timeout_error, timeout_val=timeout)

def __iter__(self) -> Generator[T, Any, None]:
"""
Expand All @@ -194,9 +182,7 @@ def _handle_exception(self, exc) -> None:
"""
self.error_list.append(exc)
if not self.predicate(exc):
final_exc, src_exc = self._exc_factory(
exc_list=self.error_list, is_timeout=False
)
final_exc, src_exc = self._exc_factory(exc_list=self.error_list, is_timeout=False)
raise final_exc from src_exc
else:
# run on_error callback if provided
Expand Down Expand Up @@ -227,7 +213,10 @@ def _check_timeout(self, current_time: float) -> None:
Raises:
- Exception from exception_factory if the timeout has been exceeded
"""
if self.deadline is not None and self.deadline < current_time:
if (
self.deadline is not None
and self.deadline < current_time
):
exc, src_exc = self._exc_factory(exc_list=self.error_list, is_timeout=True)
raise exc from src_exc

Expand Down
45 changes: 21 additions & 24 deletions google/api_core/retry_streaming_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@
Callable,
Optional,
Iterable,
List,
Tuple,
AsyncIterator,
AsyncIterable,
Awaitable,
Expand Down Expand Up @@ -120,7 +122,7 @@ def __init__(
timeout: Optional[float] = None,
on_error: Optional[Callable[[Exception], None]] = None,
exception_factory: Optional[
Callable[[list[Exception], bool, float], tuple[Exception, Exception | None]]
Callable[[List[Exception], bool, float], Tuple[Exception, Optional[Exception]]]
] = None,
check_timeout_on_yield: bool = False,
):
Expand Down Expand Up @@ -158,12 +160,9 @@ def __init__(
self.predicate = predicate
self.sleep_generator = iter(sleep_generator)
self.on_error = on_error
if timeout is not None:
self.deadline = time.monotonic() + timeout
else:
self.deadline = None
self.deadline: Optional[float] = time.monotonic() + timeout if timeout else None
self._check_timeout_on_yield = check_timeout_on_yield
self.error_list: list[Exception] = []
self.error_list: List[Exception] = []
self._exc_factory = partial(
exception_factory or _build_timeout_error, timeout_val=timeout
)
Expand All @@ -184,19 +183,14 @@ def _check_timeout(
exc, src_exc = self._exc_factory(exc_list=self.error_list, is_timeout=True)
raise exc from src_exc

async def _ensure_active_target(self) -> AsyncIterator[T]:
async def _new_target(self) -> AsyncIterator[T]:
"""
Ensure that the active target is populated and ready to be iterated over.
Returns:
- The active_target iterable
Creates and returns a new target iterator from the target function.
"""
if self.active_target is None:
new_iterable = self.target_fn()
if isinstance(new_iterable, Awaitable):
new_iterable = await new_iterable
self.active_target = new_iterable.__aiter__()
return self.active_target
new_iterable = self.target_fn()
if isinstance(new_iterable, Awaitable):
new_iterable = await new_iterable
return new_iterable.__aiter__()

def __aiter__(self) -> AsyncIterator[T]:
"""Implement the async iterator protocol."""
Expand Down Expand Up @@ -231,8 +225,7 @@ async def _handle_exception(self, exc) -> None:
)
# sleep before retrying
await asyncio.sleep(next_sleep)
self.active_target = None
await self._ensure_active_target()
self.active_target = await self._new_target()

async def _iteration_helper(self, iteration_routine: Awaitable) -> T:
"""
Expand Down Expand Up @@ -265,9 +258,10 @@ async def __anext__(self) -> T:
Returns:
- The next value from the active_target iterator.
"""
iterable = await self._ensure_active_target()
if self.active_target is None:
self.active_target = await self._new_target()
return await self._iteration_helper(
iterable.__anext__(),
self.active_target.__anext__(),
)

async def aclose(self) -> None:
Expand All @@ -277,7 +271,8 @@ async def aclose(self) -> None:
Raises:
- AttributeError if the active_target does not have a aclose() method
"""
await self._ensure_active_target()
if self.active_target is None:
self.active_target = await self._new_target()
if getattr(self.active_target, "aclose", None):
casted_target = cast(AsyncGenerator[T, None], self.active_target)
return await casted_target.aclose()
Expand All @@ -302,7 +297,8 @@ async def asend(self, *args, **kwargs) -> T:
Raises:
- AttributeError if the active_target does not have a asend() method
"""
await self._ensure_active_target()
if self.active_target is None:
self.active_target = await self._new_target()
if getattr(self.active_target, "asend", None):
casted_target = cast(AsyncGenerator[T, None], self.active_target)
return await self._iteration_helper(casted_target.asend(*args, **kwargs))
Expand All @@ -325,7 +321,8 @@ async def athrow(self, *args, **kwargs) -> T:
Raises:
- AttributeError if the active_target does not have a athrow() method
"""
await self._ensure_active_target()
if self.active_target is None:
self.active_target = await self._new_target()
if getattr(self.active_target, "athrow", None):
casted_target = cast(AsyncGenerator[T, None], self.active_target)
try:
Expand Down

0 comments on commit 30fccb9

Please sign in to comment.