Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: make nested retry blocks work for RPC calls #589

Merged
merged 4 commits into from
Jan 21, 2021
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 25 additions & 1 deletion google/cloud/ndb/_retry.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

from google.api_core import retry as core_retry
from google.api_core import exceptions as core_exceptions
from google.cloud.ndb import exceptions
from google.cloud.ndb import tasklets

_DEFAULT_INITIAL_DELAY = 1.0 # seconds
Expand Down Expand Up @@ -59,24 +60,47 @@ def retry_async(callback, retries=_DEFAULT_RETRIES):
@tasklets.tasklet
@wraps_safely(callback)
def retry_wrapper(*args, **kwargs):
from google.cloud.ndb import context as context_module

sleep_generator = core_retry.exponential_sleep_generator(
_DEFAULT_INITIAL_DELAY,
_DEFAULT_MAXIMUM_DELAY,
_DEFAULT_DELAY_MULTIPLIER,
)

for sleep_time in itertools.islice(sleep_generator, retries + 1):
context = context_module.get_context()
if not context.in_retry():
# We need to be able to identify if we are inside a nested
# retry. Here, we set the retry state in the context. This is
# used for deciding if an exception should be raised
# immediately or passed up to the outer retry block.
context.set_retry_state(repr(callback))
try:
result = callback(*args, **kwargs)
if isinstance(result, tasklets.Future):
result = yield result
except exceptions.NestedRetryException as e:
error = e
except Exception as e:
# `e` is removed from locals at end of block
error = e # See: https://goo.gl/5J8BMK
if not is_transient_error(error):
raise error
# If we are in an inner retry block, use special nested
# retry exception to bubble up to outer retry. Else, raise
# actual exception.
if context.get_retry_state() != repr(callback):
message = getattr(error, "message", str(error))
raise exceptions.NestedRetryException(message)
else:
raise error
else:
raise tasklets.Return(result)
finally:
# No matter what, if we are exiting the top level retry,
# clear the retry state in the context.
if context.get_retry_state() == repr(callback): # pragma: NO BRANCH
context.clear_retry_state()

yield tasklets.sleep(sleep_time)

Expand Down
1 change: 1 addition & 0 deletions google/cloud/ndb/_transaction.py
Original file line number Diff line number Diff line change
Expand Up @@ -270,6 +270,7 @@ def _transaction_async(context, callback, read_only=False):
# new event loop is of the same type as the current one, to propagate
# the event loop class used for testing.
eventloop=type(context.eventloop)(),
retry=context.get_retry_state(),
)

# The outer loop is dependent on the inner loop
Expand Down
24 changes: 23 additions & 1 deletion google/cloud/ndb/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -247,6 +247,7 @@ def __new__(
datastore_policy=None,
on_commit_callbacks=None,
legacy_data=True,
retry=None,
rpc_time=None,
wait_time=None,
):
Expand Down Expand Up @@ -286,6 +287,7 @@ def __new__(
context.set_global_cache_policy(global_cache_policy)
context.set_global_cache_timeout_policy(global_cache_timeout_policy)
context.set_datastore_policy(datastore_policy)
context.set_retry_state(retry)

return context

Expand All @@ -296,7 +298,9 @@ def new(self, **kwargs):
will be substituted.
"""
fields = self._fields + tuple(self.__dict__.keys())
state = {name: getattr(self, name) for name in fields}
state = {
name: getattr(self, name) for name in fields if not name.startswith("_")
}
state.update(kwargs)
return type(self)(**state)

Expand Down Expand Up @@ -544,6 +548,15 @@ def policy(key):

set_memcache_timeout_policy = set_global_cache_timeout_policy

def get_retry_state(self):
return self._retry

def set_retry_state(self, state):
self._retry = state

def clear_retry_state(self):
self._retry = None

def call_on_commit(self, callback):
"""Call a callback upon successful commit of a transaction.

Expand Down Expand Up @@ -578,6 +591,15 @@ def in_transaction(self):
"""
return self.transaction is not None

def in_retry(self):
"""Get whether we are already in a retry block.

Returns:
bool: :data:`True` if currently in a retry block, otherwise
:data:`False`.
"""
return self._retry is not None

def memcache_add(self, *args, **kwargs):
"""Direct pass-through to memcache client."""
raise exceptions.NoLongerImplementedError()
Expand Down
8 changes: 8 additions & 0 deletions google/cloud/ndb/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,3 +121,11 @@ class Cancelled(Error):
a call to ``Future.cancel`` (possibly on a future that depends on this
future).
"""


class NestedRetryException(Error):
"""A nested retry block raised an exception.

Raised when a nested retry block cannot complete due to an exception. This
allows the outer retry to get back control and retry the whole operation.
"""
31 changes: 31 additions & 0 deletions tests/unit/test__retry.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,37 @@ def callback():
retry = _retry.retry_async(callback)
assert retry().result() == "foo"

@staticmethod
@pytest.mark.usefixtures("in_context")
def test_nested_retry():
def callback():
def nested_callback():
return "bar"

nested = _retry.retry_async(nested_callback)
assert nested().result() == "bar"

return "foo"

retry = _retry.retry_async(callback)
assert retry().result() == "foo"

@staticmethod
@pytest.mark.usefixtures("in_context")
def test_nested_retry_with_exception():
error = Exception("Fail")

def callback():
def nested_callback():
raise error

nested = _retry.retry_async(nested_callback, retries=1)
return nested()

with pytest.raises(core_exceptions.RetryError):
retry = _retry.retry_async(callback, retries=1)
retry().result()

@staticmethod
@pytest.mark.usefixtures("in_context")
def test_success_callback_is_tasklet():
Expand Down