Skip to content

Commit

Permalink
Fix retries in async mode (#2180)
Browse files Browse the repository at this point in the history
* Avoid mutating a global retry_on_error list

* Make retries config consistent in sync and async

* Fix async retries

* Add new TestConnectionConstructorWithRetry tests
  • Loading branch information
elemoine authored Jun 19, 2022
1 parent 3370298 commit bea7299
Show file tree
Hide file tree
Showing 6 changed files with 83 additions and 9 deletions.
17 changes: 14 additions & 3 deletions redis/asyncio/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,7 @@ def __init__(
encoding_errors: str = "strict",
decode_responses: bool = False,
retry_on_timeout: bool = False,
retry_on_error: Optional[list] = None,
ssl: bool = False,
ssl_keyfile: Optional[str] = None,
ssl_certfile: Optional[str] = None,
Expand All @@ -176,8 +177,10 @@ def __init__(
):
"""
Initialize a new Redis client.
To specify a retry policy, first set `retry_on_timeout` to `True`
then set `retry` to a valid `Retry` object
To specify a retry policy for specific errors, first set
`retry_on_error` to a list of the error/s to retry on, then set
`retry` to a valid `Retry` object.
To retry on TimeoutError, `retry_on_timeout` can also be set to `True`.
"""
kwargs: Dict[str, Any]
# auto_close_connection_pool only has an effect if connection_pool is
Expand All @@ -188,6 +191,10 @@ def __init__(
auto_close_connection_pool if connection_pool is None else False
)
if not connection_pool:
if not retry_on_error:
retry_on_error = []
if retry_on_timeout is True:
retry_on_error.append(TimeoutError)
kwargs = {
"db": db,
"username": username,
Expand All @@ -197,6 +204,7 @@ def __init__(
"encoding_errors": encoding_errors,
"decode_responses": decode_responses,
"retry_on_timeout": retry_on_timeout,
"retry_on_error": retry_on_error,
"retry": copy.deepcopy(retry),
"max_connections": max_connections,
"health_check_interval": health_check_interval,
Expand Down Expand Up @@ -461,7 +469,10 @@ async def _disconnect_raise(self, conn: Connection, error: Exception):
is not a TimeoutError
"""
await conn.disconnect()
if not (conn.retry_on_timeout and isinstance(error, TimeoutError)):
if (
conn.retry_on_error is None
or isinstance(error, tuple(conn.retry_on_error)) is False
):
raise error

# COMMAND EXECUTION AND PROTOCOL PARSING
Expand Down
17 changes: 17 additions & 0 deletions redis/asyncio/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -578,6 +578,7 @@ class Connection:
"socket_type",
"redis_connect_func",
"retry_on_timeout",
"retry_on_error",
"health_check_interval",
"next_health_check",
"last_active_at",
Expand Down Expand Up @@ -606,6 +607,7 @@ def __init__(
socket_keepalive_options: Optional[Mapping[int, Union[int, bytes]]] = None,
socket_type: int = 0,
retry_on_timeout: bool = False,
retry_on_error: Union[list, _Sentinel] = SENTINEL,
encoding: str = "utf-8",
encoding_errors: str = "strict",
decode_responses: bool = False,
Expand All @@ -631,12 +633,19 @@ def __init__(
self.socket_keepalive_options = socket_keepalive_options or {}
self.socket_type = socket_type
self.retry_on_timeout = retry_on_timeout
if retry_on_error is SENTINEL:
retry_on_error = []
if retry_on_timeout:
retry_on_error.append(TimeoutError)
self.retry_on_error = retry_on_error
if retry_on_error:
if not retry:
self.retry = Retry(NoBackoff(), 1)
else:
# deep-copy the Retry object as it is mutable
self.retry = copy.deepcopy(retry)
# Update the retry's supported errors with the specified errors
self.retry.update_supported_errors(retry_on_error)
else:
self.retry = Retry(NoBackoff(), 0)
self.health_check_interval = health_check_interval
Expand Down Expand Up @@ -1169,6 +1178,7 @@ def __init__(
encoding_errors: str = "strict",
decode_responses: bool = False,
retry_on_timeout: bool = False,
retry_on_error: Union[list, _Sentinel] = SENTINEL,
parser_class: Type[BaseParser] = DefaultParser,
socket_read_size: int = 65536,
health_check_interval: float = 0.0,
Expand All @@ -1190,12 +1200,19 @@ def __init__(
self.socket_timeout = socket_timeout
self.socket_connect_timeout = socket_connect_timeout or socket_timeout or None
self.retry_on_timeout = retry_on_timeout
if retry_on_error is SENTINEL:
retry_on_error = []
if retry_on_timeout:
retry_on_error.append(TimeoutError)
self.retry_on_error = retry_on_error
if retry_on_error:
if retry is None:
self.retry = Retry(NoBackoff(), 1)
else:
# deep-copy the Retry object as it is mutable
self.retry = copy.deepcopy(retry)
# Update the retry's supported errors with the specified errors
self.retry.update_supported_errors(retry_on_error)
else:
self.retry = Retry(NoBackoff(), 0)
self.health_check_interval = health_check_interval
Expand Down
8 changes: 8 additions & 0 deletions redis/asyncio/retry.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,14 @@ def __init__(
self._retries = retries
self._supported_errors = supported_errors

def update_supported_errors(self, specified_errors: list):
"""
Updates the supported errors with the specified error types
"""
self._supported_errors = tuple(
set(self._supported_errors + tuple(specified_errors))
)

async def call_with_retry(
self, do: Callable[[], Awaitable[T]], fail: Callable[[RedisError], Any]
) -> T:
Expand Down
4 changes: 3 additions & 1 deletion redis/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -914,7 +914,7 @@ def __init__(
errors=None,
decode_responses=False,
retry_on_timeout=False,
retry_on_error=[],
retry_on_error=None,
ssl=False,
ssl_keyfile=None,
ssl_certfile=None,
Expand Down Expand Up @@ -958,6 +958,8 @@ def __init__(
)
)
encoding_errors = errors
if not retry_on_error:
retry_on_error = []
if retry_on_timeout is True:
retry_on_error.append(TimeoutError)
kwargs = {
Expand Down
8 changes: 6 additions & 2 deletions redis/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -515,7 +515,7 @@ def __init__(
socket_keepalive_options=None,
socket_type=0,
retry_on_timeout=False,
retry_on_error=[],
retry_on_error=SENTINEL,
encoding="utf-8",
encoding_errors="strict",
decode_responses=False,
Expand Down Expand Up @@ -547,6 +547,8 @@ def __init__(
self.socket_keepalive_options = socket_keepalive_options or {}
self.socket_type = socket_type
self.retry_on_timeout = retry_on_timeout
if retry_on_error is SENTINEL:
retry_on_error = []
if retry_on_timeout:
# Add TimeoutError to the errors list to retry on
retry_on_error.append(TimeoutError)
Expand Down Expand Up @@ -1065,7 +1067,7 @@ def __init__(
encoding_errors="strict",
decode_responses=False,
retry_on_timeout=False,
retry_on_error=[],
retry_on_error=SENTINEL,
parser_class=DefaultParser,
socket_read_size=65536,
health_check_interval=0,
Expand All @@ -1088,6 +1090,8 @@ def __init__(
self.password = password
self.socket_timeout = socket_timeout
self.retry_on_timeout = retry_on_timeout
if retry_on_error is SENTINEL:
retry_on_error = []
if retry_on_timeout:
# Add TimeoutError to the errors list to retry on
retry_on_error.append(TimeoutError)
Expand Down
38 changes: 35 additions & 3 deletions tests/test_asyncio/test_retry.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from redis.asyncio.connection import Connection, UnixDomainSocketConnection
from redis.asyncio.retry import Retry
from redis.backoff import AbstractBackoff, NoBackoff
from redis.exceptions import ConnectionError
from redis.exceptions import ConnectionError, TimeoutError


class BackoffMock(AbstractBackoff):
Expand All @@ -22,23 +22,55 @@ def compute(self, failures):
class TestConnectionConstructorWithRetry:
"Test that the Connection constructors properly handles Retry objects"

@pytest.mark.parametrize("Class", [Connection, UnixDomainSocketConnection])
def test_retry_on_error_set(self, Class):
class CustomError(Exception):
pass

retry_on_error = [ConnectionError, TimeoutError, CustomError]
c = Class(retry_on_error=retry_on_error)
assert c.retry_on_error == retry_on_error
assert isinstance(c.retry, Retry)
assert c.retry._retries == 1
assert set(c.retry._supported_errors) == set(retry_on_error)

@pytest.mark.parametrize("Class", [Connection, UnixDomainSocketConnection])
def test_retry_on_error_not_set(self, Class):
c = Class()
assert c.retry_on_error == []
assert isinstance(c.retry, Retry)
assert c.retry._retries == 0

@pytest.mark.parametrize("retry_on_timeout", [False, True])
@pytest.mark.parametrize("Class", [Connection, UnixDomainSocketConnection])
def test_retry_on_timeout_boolean(self, Class, retry_on_timeout):
def test_retry_on_timeout(self, Class, retry_on_timeout):
c = Class(retry_on_timeout=retry_on_timeout)
assert c.retry_on_timeout == retry_on_timeout
assert isinstance(c.retry, Retry)
assert c.retry._retries == (1 if retry_on_timeout else 0)

@pytest.mark.parametrize("retries", range(10))
@pytest.mark.parametrize("Class", [Connection, UnixDomainSocketConnection])
def test_retry_on_timeout_retry(self, Class, retries: int):
def test_retry_with_retry_on_timeout(self, Class, retries: int):
retry_on_timeout = retries > 0
c = Class(retry_on_timeout=retry_on_timeout, retry=Retry(NoBackoff(), retries))
assert c.retry_on_timeout == retry_on_timeout
assert isinstance(c.retry, Retry)
assert c.retry._retries == retries

@pytest.mark.parametrize("retries", range(10))
@pytest.mark.parametrize("Class", [Connection, UnixDomainSocketConnection])
def test_retry_with_retry_on_error(self, Class, retries: int):
class CustomError(Exception):
pass

retry_on_error = [ConnectionError, TimeoutError, CustomError]
c = Class(retry_on_error=retry_on_error, retry=Retry(NoBackoff(), retries))
assert c.retry_on_error == retry_on_error
assert isinstance(c.retry, Retry)
assert c.retry._retries == retries
assert set(c.retry._supported_errors) == set(retry_on_error)


class TestRetry:
"Test that Retry calls backoff and retries the expected number of times"
Expand Down

0 comments on commit bea7299

Please sign in to comment.