From 30fccb906fe374f1cd1fbd5a6e4133e19066c064 Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Tue, 15 Aug 2023 09:52:00 -0700 Subject: [PATCH] simplified ensure_tareget; fixed mypy issues --- google/api_core/retry_streaming.py | 39 ++++++++------------ google/api_core/retry_streaming_async.py | 45 +++++++++++------------- 2 files changed, 35 insertions(+), 49 deletions(-) diff --git a/google/api_core/retry_streaming.py b/google/api_core/retry_streaming.py index 265a1525..830d87c3 100644 --- a/google/api_core/retry_streaming.py +++ b/google/api_core/retry_streaming.py @@ -14,7 +14,7 @@ """Helpers for retries for streaming APIs.""" -from typing import Callable, Optional, Iterable, Iterator, Generator, TypeVar, Any, cast +from typing import Callable, Optional, List, Tuple, Iterable, Iterator, Generator, TypeVar, Any, cast import logging import time @@ -27,9 +27,7 @@ T = TypeVar("T") -def _build_timeout_error( - exc_list: list[Exception], is_timeout: bool, timeout_val: float -) -> tuple[Exception, Exception | None]: +def _build_timeout_error(exc_list:List[Exception], is_timeout:bool, timeout_val:float) -> Tuple[Exception, Optional[Exception]]: """ Default exception_factory implementation. Builds an exception after the retry fails @@ -44,13 +42,10 @@ def _build_timeout_error( """ src_exc = exc_list[-1] if exc_list else None if is_timeout: - return ( - exceptions.RetryError( - "Timeout of {:.1f}s exceeded".format(timeout_val), - src_exc, - ), + return exceptions.RetryError( + "Timeout of {:.1f}s exceeded".format(timeout_val), src_exc, - ) + ), src_exc else: return exc_list[-1], None @@ -132,9 +127,7 @@ def __init__( sleep_generator: Iterable[float], timeout: Optional[float] = None, on_error: Optional[Callable[[Exception], None]] = None, - exception_factory: Optional[ - Callable[[list[Exception], bool, float], tuple[Exception, Exception | None]] - ] = None, + exception_factory: Optional[Callable[[List[Exception], bool, float], Tuple[Exception, Optional[Exception]]]] = None, check_timeout_on_yield: bool = False, ): """ @@ -170,15 +163,10 @@ def __init__( self.predicate = predicate self.sleep_generator = iter(sleep_generator) self.on_error = on_error - if timeout is not None: - self.deadline = time.monotonic() + timeout - else: - self.deadline = None + self.deadline: Optional[float] = time.monotonic() + timeout if timeout else None self._check_timeout_on_yield = check_timeout_on_yield - self.error_list: list[Exception] = [] - self._exc_factory = partial( - exception_factory or _build_timeout_error, timeout_val=timeout - ) + self.error_list : List[Exception] = [] + self._exc_factory = partial(exception_factory or _build_timeout_error, timeout_val=timeout) def __iter__(self) -> Generator[T, Any, None]: """ @@ -194,9 +182,7 @@ def _handle_exception(self, exc) -> None: """ self.error_list.append(exc) if not self.predicate(exc): - final_exc, src_exc = self._exc_factory( - exc_list=self.error_list, is_timeout=False - ) + final_exc, src_exc = self._exc_factory(exc_list=self.error_list, is_timeout=False) raise final_exc from src_exc else: # run on_error callback if provided @@ -227,7 +213,10 @@ def _check_timeout(self, current_time: float) -> None: Raises: - Exception from exception_factory if the timeout has been exceeded """ - if self.deadline is not None and self.deadline < current_time: + if ( + self.deadline is not None + and self.deadline < current_time + ): exc, src_exc = self._exc_factory(exc_list=self.error_list, is_timeout=True) raise exc from src_exc diff --git a/google/api_core/retry_streaming_async.py b/google/api_core/retry_streaming_async.py index 5f00fa5d..62973d99 100644 --- a/google/api_core/retry_streaming_async.py +++ b/google/api_core/retry_streaming_async.py @@ -19,6 +19,8 @@ Callable, Optional, Iterable, + List, + Tuple, AsyncIterator, AsyncIterable, Awaitable, @@ -120,7 +122,7 @@ def __init__( timeout: Optional[float] = None, on_error: Optional[Callable[[Exception], None]] = None, exception_factory: Optional[ - Callable[[list[Exception], bool, float], tuple[Exception, Exception | None]] + Callable[[List[Exception], bool, float], Tuple[Exception, Optional[Exception]]] ] = None, check_timeout_on_yield: bool = False, ): @@ -158,12 +160,9 @@ def __init__( self.predicate = predicate self.sleep_generator = iter(sleep_generator) self.on_error = on_error - if timeout is not None: - self.deadline = time.monotonic() + timeout - else: - self.deadline = None + self.deadline: Optional[float] = time.monotonic() + timeout if timeout else None self._check_timeout_on_yield = check_timeout_on_yield - self.error_list: list[Exception] = [] + self.error_list: List[Exception] = [] self._exc_factory = partial( exception_factory or _build_timeout_error, timeout_val=timeout ) @@ -184,19 +183,14 @@ def _check_timeout( exc, src_exc = self._exc_factory(exc_list=self.error_list, is_timeout=True) raise exc from src_exc - async def _ensure_active_target(self) -> AsyncIterator[T]: + async def _new_target(self) -> AsyncIterator[T]: """ - Ensure that the active target is populated and ready to be iterated over. - - Returns: - - The active_target iterable + Creates and returns a new target iterator from the target function. """ - if self.active_target is None: - new_iterable = self.target_fn() - if isinstance(new_iterable, Awaitable): - new_iterable = await new_iterable - self.active_target = new_iterable.__aiter__() - return self.active_target + new_iterable = self.target_fn() + if isinstance(new_iterable, Awaitable): + new_iterable = await new_iterable + return new_iterable.__aiter__() def __aiter__(self) -> AsyncIterator[T]: """Implement the async iterator protocol.""" @@ -231,8 +225,7 @@ async def _handle_exception(self, exc) -> None: ) # sleep before retrying await asyncio.sleep(next_sleep) - self.active_target = None - await self._ensure_active_target() + self.active_target = await self._new_target() async def _iteration_helper(self, iteration_routine: Awaitable) -> T: """ @@ -265,9 +258,10 @@ async def __anext__(self) -> T: Returns: - The next value from the active_target iterator. """ - iterable = await self._ensure_active_target() + if self.active_target is None: + self.active_target = await self._new_target() return await self._iteration_helper( - iterable.__anext__(), + self.active_target.__anext__(), ) async def aclose(self) -> None: @@ -277,7 +271,8 @@ async def aclose(self) -> None: Raises: - AttributeError if the active_target does not have a aclose() method """ - await self._ensure_active_target() + if self.active_target is None: + self.active_target = await self._new_target() if getattr(self.active_target, "aclose", None): casted_target = cast(AsyncGenerator[T, None], self.active_target) return await casted_target.aclose() @@ -302,7 +297,8 @@ async def asend(self, *args, **kwargs) -> T: Raises: - AttributeError if the active_target does not have a asend() method """ - await self._ensure_active_target() + if self.active_target is None: + self.active_target = await self._new_target() if getattr(self.active_target, "asend", None): casted_target = cast(AsyncGenerator[T, None], self.active_target) return await self._iteration_helper(casted_target.asend(*args, **kwargs)) @@ -325,7 +321,8 @@ async def athrow(self, *args, **kwargs) -> T: Raises: - AttributeError if the active_target does not have a athrow() method """ - await self._ensure_active_target() + if self.active_target is None: + self.active_target = await self._new_target() if getattr(self.active_target, "athrow", None): casted_target = cast(AsyncGenerator[T, None], self.active_target) try: