-
Notifications
You must be signed in to change notification settings - Fork 13
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* New backoff numbers * Retry decorator + unit tests * logging
- Loading branch information
Showing
6 changed files
with
127 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,50 @@ | ||
import functools | ||
import logging | ||
import time | ||
|
||
BASE_RETRY_COUNT = 3 | ||
MAX_RETRY_DURATION = 86400 # 1 day in seconds | ||
MAX_BACKOFF = 60 # 1 minute in seconds | ||
|
||
logger = logging.getLogger(__name__) | ||
|
||
|
||
def retry( | ||
transient_exceptions=None, | ||
base_retry_count=BASE_RETRY_COUNT, | ||
max_retry_duration=MAX_RETRY_DURATION, | ||
max_backoff=MAX_BACKOFF, | ||
): | ||
""" | ||
A decorator that retries a function at least base_retry_count times. | ||
If transient_exceptions are specified, it will retry for up to 1 day if any of those exceptions occur. | ||
""" | ||
transient_exceptions = tuple(transient_exceptions) if transient_exceptions else () | ||
|
||
def decorator(func): | ||
@functools.wraps(func) | ||
def wrapper(*args, **kwargs): | ||
attempt = 0 | ||
start_time = time.time() | ||
|
||
while True: | ||
try: | ||
return func(*args, **kwargs) | ||
except transient_exceptions as e: | ||
# Keep retrying transient exceptions for 1 day. | ||
elapsed_time = time.time() - start_time | ||
if elapsed_time >= max_retry_duration: | ||
raise | ||
logger.warning(f"Transient exception occurred: {e}. Retrying...") | ||
except Exception as e: | ||
# Retry all other exceptions BASE_RETRY_COUNT times. | ||
attempt += 1 | ||
if attempt >= base_retry_count: | ||
raise | ||
logger.warning(f"Exception occurred after {attempt}/{base_retry_count} attempts: {e}. Retrying...") | ||
sleep_time = min(2**attempt, max_backoff) # Exponential backoff with cap | ||
time.sleep(sleep_time) | ||
|
||
return wrapper | ||
|
||
return decorator |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,64 @@ | ||
import pytest | ||
import time | ||
|
||
from modelgauge.retry_decorator import retry, BASE_RETRY_COUNT | ||
|
||
|
||
def test_retry_success(): | ||
attempt_counter = 0 | ||
|
||
@retry() | ||
def always_succeed(): | ||
nonlocal attempt_counter | ||
attempt_counter += 1 | ||
return "success" | ||
|
||
assert always_succeed() == "success" | ||
assert attempt_counter == 1 | ||
|
||
|
||
@pytest.mark.parametrize("exceptions", [None, [ValueError]]) | ||
def test_retry_fails_after_base_retries(exceptions): | ||
attempt_counter = 0 | ||
|
||
@retry(transient_exceptions=exceptions) | ||
def always_fail(): | ||
nonlocal attempt_counter | ||
attempt_counter += 1 | ||
raise KeyError("Intentional failure") | ||
|
||
with pytest.raises(KeyError): | ||
always_fail() | ||
|
||
assert attempt_counter == BASE_RETRY_COUNT | ||
|
||
|
||
def test_retry_eventually_succeeds(): | ||
attempt_counter = 0 | ||
|
||
@retry(transient_exceptions=[ValueError]) | ||
def succeed_before_base_retry_total(): | ||
nonlocal attempt_counter | ||
attempt_counter += 1 | ||
if attempt_counter < BASE_RETRY_COUNT: | ||
raise ValueError("Intentional failure") | ||
return "success" | ||
|
||
assert succeed_before_base_retry_total() == "success" | ||
assert attempt_counter == BASE_RETRY_COUNT | ||
|
||
|
||
def test_retry_transient_eventually_succeeds(): | ||
attempt_counter = 0 | ||
start_time = time.time() | ||
|
||
@retry(transient_exceptions=[ValueError], max_retry_duration=3, base_retry_count=1) | ||
def succeed_eventually(): | ||
nonlocal attempt_counter | ||
attempt_counter += 1 | ||
elapsed_time = time.time() - start_time | ||
if elapsed_time < 1: | ||
raise ValueError("Intentional failure") | ||
return "success" | ||
|
||
assert succeed_eventually() == "success" |