Skip to content

Commit

Permalink
Add retry for custom exception types.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 578988167
  • Loading branch information
dustinvtran authored and edward-bot committed Nov 2, 2023
1 parent d0b772c commit 2be8ed2
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 3 deletions.
18 changes: 15 additions & 3 deletions edward2/maps.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ def robust_map(
max_retries: int | None = None,
max_workers: int | None = None,
raise_error: bool = False,
retry_exception_types: list[Exception] | None = None,
) -> Sequence[U | V]:
"""Maps a function to inputs using a threadpool.
Expand All @@ -87,23 +88,34 @@ def robust_map(
os.cpu_count() + 4)`.
raise_error: Whether to raise an error if an input exceeds `max_retries`.
Will override any setting of `error_output`.
retry_exception_types: Exception types to retry on. Defaults to retrying
only on grpc's RPC exceptions.
Returns:
A list of items each of type U. They are the outputs of `fn` applied to
the elements of `inputs`.
"""
if index_to_output is None:
index_to_output = {}
if retry_exception_types is None:
retry_exception_types = []
retry_exception_types = retry_exception_types + [
grpc.RpcError,
]
retry_exception_types = list(set(retry_exception_types))
retry = tenacity.retry_if_exception_type(retry_exception_types[0])
for retry_exception_type in retry_exception_types[1:]:
retry = retry | tenacity.retry_if_exception_type(retry_exception_type)
if max_retries is None:
fn_with_backoff = tenacity.retry(
retry=tenacity.retry_if_exception_type(grpc.RpcError),
retry=retry,
wait=tenacity.wait_random_exponential(min=1, max=30),
)(fn)
else:
fn_with_backoff = tenacity.retry(
retry=tenacity.retry_if_exception_type(grpc.RpcError),
retry=retry,
wait=tenacity.wait_random_exponential(min=1, max=30),
stop=tenacity.stop_after_attempt(max_retries),
stop=tenacity.stop_after_attempt(max_retries + 1),
)(fn)
num_existing = len(index_to_output)
num_inputs = len(inputs)
Expand Down
25 changes: 25 additions & 0 deletions edward2/maps_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,31 @@ def fn(x):
with self.assertRaises(ValueError):
maps.robust_map(fn, x)

def test_robust_map_retry_exception_types(self):

def make_fn():
busy = True
def fn(x):
nonlocal busy
if busy:
busy = False
raise RuntimeError("Sorry, can't process request right now.")
else:
busy = True
return x + 1
return fn

fn = make_fn()
x = [0, 1, 2]
y = maps.robust_map(
fn,
x,
max_retries=1,
max_workers=1,
retry_exception_types=[RuntimeError],
)
self.assertEqual(y, [1, 2, 3])


if __name__ == '__main__':
tf.test.main()

0 comments on commit 2be8ed2

Please sign in to comment.