From 2be8ed2da2518cfdcd1bfc5effbc1a1075cd9d30 Mon Sep 17 00:00:00 2001 From: Dustin Tran Date: Thu, 2 Nov 2023 14:57:39 -0700 Subject: [PATCH] Add retry for custom exception types. PiperOrigin-RevId: 578988167 --- edward2/maps.py | 18 +++++++++++++++--- edward2/maps_test.py | 25 +++++++++++++++++++++++++ 2 files changed, 40 insertions(+), 3 deletions(-) diff --git a/edward2/maps.py b/edward2/maps.py index df5a689f..ee5861e0 100644 --- a/edward2/maps.py +++ b/edward2/maps.py @@ -66,6 +66,7 @@ def robust_map( max_retries: int | None = None, max_workers: int | None = None, raise_error: bool = False, + retry_exception_types: list[Exception] | None = None, ) -> Sequence[U | V]: """Maps a function to inputs using a threadpool. @@ -87,6 +88,8 @@ def robust_map( os.cpu_count() + 4)`. raise_error: Whether to raise an error if an input exceeds `max_retries`. Will override any setting of `error_output`. + retry_exception_types: Exception types to retry on. Defaults to retrying + only on grpc's RPC exceptions. Returns: A list of items each of type U. They are the outputs of `fn` applied to @@ -94,16 +97,25 @@ def robust_map( """ if index_to_output is None: index_to_output = {} + if retry_exception_types is None: + retry_exception_types = [] + retry_exception_types = retry_exception_types + [ + grpc.RpcError, + ] + retry_exception_types = list(set(retry_exception_types)) + retry = tenacity.retry_if_exception_type(retry_exception_types[0]) + for retry_exception_type in retry_exception_types[1:]: + retry = retry | tenacity.retry_if_exception_type(retry_exception_type) if max_retries is None: fn_with_backoff = tenacity.retry( - retry=tenacity.retry_if_exception_type(grpc.RpcError), + retry=retry, wait=tenacity.wait_random_exponential(min=1, max=30), )(fn) else: fn_with_backoff = tenacity.retry( - retry=tenacity.retry_if_exception_type(grpc.RpcError), + retry=retry, wait=tenacity.wait_random_exponential(min=1, max=30), - stop=tenacity.stop_after_attempt(max_retries), + stop=tenacity.stop_after_attempt(max_retries + 1), )(fn) num_existing = len(index_to_output) num_inputs = len(inputs) diff --git a/edward2/maps_test.py b/edward2/maps_test.py index 849099e5..c2b79f11 100644 --- a/edward2/maps_test.py +++ b/edward2/maps_test.py @@ -109,6 +109,31 @@ def fn(x): with self.assertRaises(ValueError): maps.robust_map(fn, x) + def test_robust_map_retry_exception_types(self): + + def make_fn(): + busy = True + def fn(x): + nonlocal busy + if busy: + busy = False + raise RuntimeError("Sorry, can't process request right now.") + else: + busy = True + return x + 1 + return fn + + fn = make_fn() + x = [0, 1, 2] + y = maps.robust_map( + fn, + x, + max_retries=1, + max_workers=1, + retry_exception_types=[RuntimeError], + ) + self.assertEqual(y, [1, 2, 3]) + if __name__ == '__main__': tf.test.main()