diff --git a/redis/retry.py b/redis/retry.py index 606443053e..03fd973c4c 100644 --- a/redis/retry.py +++ b/redis/retry.py @@ -1,17 +1,27 @@ import socket from time import sleep +from typing import TYPE_CHECKING, Any, Callable, Iterable, Tuple, Type, TypeVar from redis.exceptions import ConnectionError, TimeoutError +T = TypeVar("T") + +if TYPE_CHECKING: + from redis.backoff import AbstractBackoff + class Retry: """Retry a specific number of times after a failure""" def __init__( self, - backoff, - retries, - supported_errors=(ConnectionError, TimeoutError, socket.timeout), + backoff: "AbstractBackoff", + retries: int, + supported_errors: Tuple[Type[Exception], ...] = ( + ConnectionError, + TimeoutError, + socket.timeout, + ), ): """ Initialize a `Retry` object with a `Backoff` object @@ -24,7 +34,9 @@ def __init__( self._retries = retries self._supported_errors = supported_errors - def update_supported_errors(self, specified_errors: list): + def update_supported_errors( + self, specified_errors: Iterable[Type[Exception]] + ) -> None: """ Updates the supported errors with the specified error types """ @@ -32,7 +44,11 @@ def update_supported_errors(self, specified_errors: list): set(self._supported_errors + tuple(specified_errors)) ) - def call_with_retry(self, do, fail): + def call_with_retry( + self, + do: Callable[[], T], + fail: Callable[[Exception], Any], + ) -> T: """ Execute an operation that might fail and returns its result, or raise the exception that was thrown depending on the `Backoff` object. diff --git a/tests/test_retry.py b/tests/test_retry.py index e9d3015897..183807386d 100644 --- a/tests/test_retry.py +++ b/tests/test_retry.py @@ -1,7 +1,7 @@ from unittest.mock import patch import pytest -from redis.backoff import ExponentialBackoff, NoBackoff +from redis.backoff import AbstractBackoff, ExponentialBackoff, NoBackoff from redis.client import Redis from redis.connection import Connection, UnixDomainSocketConnection from redis.exceptions import ( @@ -15,7 +15,7 @@ from .conftest import _get_client -class BackoffMock: +class BackoffMock(AbstractBackoff): def __init__(self): self.reset_calls = 0 self.calls = 0