Skip to content

Commit

Permalink
Retry decorator (#864)
Browse files Browse the repository at this point in the history
* New backoff numbers

* Retry decorator + unit tests

* logging
  • Loading branch information
bkorycki authored Feb 17, 2025
1 parent 068870f commit 5e4f5b1
Show file tree
Hide file tree
Showing 6 changed files with 127 additions and 1 deletion.
3 changes: 3 additions & 0 deletions plugins/google/modelgauge/suts/google_genai_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,13 @@
from typing import Dict, List, Optional

import google.generativeai as genai # type: ignore
from google.api_core.exceptions import InternalServerError, ResourceExhausted, RetryError, TooManyRequests
from google.generativeai.types import HarmCategory, HarmBlockThreshold # type: ignore
from pydantic import BaseModel

from modelgauge.general import APIException
from modelgauge.prompt import TextPrompt
from modelgauge.retry_decorator import retry
from modelgauge.secret_values import InjectSecret, RequiredSecret, SecretDescription
from modelgauge.sut import REFUSAL_RESPONSE, PromptResponseSUT, SUTCompletion, SUTResponse
from modelgauge.sut_capabilities import AcceptsTextPrompt
Expand Down Expand Up @@ -108,6 +110,7 @@ def translate_text_prompt(self, prompt: TextPrompt) -> GoogleGenAiRequest:
contents=prompt.text, generation_config=generation_config, safety_settings=self.safety_settings
)

@retry(transient_exceptions=[InternalServerError, ResourceExhausted, RetryError, TooManyRequests])
def evaluate(self, request: GoogleGenAiRequest) -> GoogleGenAiResponse:
if self.model is None:
# Handle lazy init.
Expand Down
5 changes: 4 additions & 1 deletion plugins/mistral/modelgauge/suts/mistral_sut.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
import warnings
from typing import Optional

from mistralai.models import ChatCompletionResponse, ClassificationResponse
from mistralai.models import ChatCompletionResponse, ClassificationResponse, SDKError
from modelgauge.prompt import TextPrompt
from modelgauge.retry_decorator import retry
from modelgauge.secret_values import InjectSecret
from modelgauge.sut import PromptResponseSUT, SUTCompletion, SUTResponse
from modelgauge.sut_capabilities import AcceptsTextPrompt
Expand Down Expand Up @@ -60,6 +61,7 @@ def translate_text_prompt(self, prompt: TextPrompt) -> MistralAIRequest:
args["max_tokens"] = prompt.options.max_tokens
return MistralAIRequest(**args)

@retry(transient_exceptions=[SDKError])
def evaluate(self, request: MistralAIRequest) -> ChatCompletionResponse:
response = self.client.request(request.model_dump(exclude_none=True)) # type: ignore
return response
Expand Down Expand Up @@ -130,6 +132,7 @@ def translate_text_prompt(self, prompt: TextPrompt) -> MistralAIRequest:
args["max_tokens"] = prompt.options.max_tokens
return MistralAIRequest(temperature=self.temperature, n=self.num_generations, **args)

@retry(transient_exceptions=[SDKError])
def evaluate(self, request: MistralAIRequest) -> MistralAIResponseWithModerations:
response = self.client.request(request.model_dump(exclude_none=True)) # type: ignore
assert (
Expand Down
3 changes: 3 additions & 0 deletions plugins/nvidia/modelgauge/suts/nvidia_nim_api_client.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
from typing import Any, Dict, List, Optional, Union

from openai import OpenAI
from openai import APITimeoutError, ConflictError, InternalServerError, RateLimitError
from openai.types.chat import ChatCompletion
from pydantic import BaseModel

from modelgauge.prompt import ChatPrompt, ChatRole, SUTOptions, TextPrompt
from modelgauge.retry_decorator import retry
from modelgauge.secret_values import (
InjectSecret,
RequiredSecret,
Expand Down Expand Up @@ -115,6 +117,7 @@ def _translate_request(self, messages: List[OpenAIChatMessage], options: SUTOpti
**optional_kwargs,
)

@retry(transient_exceptions=[APITimeoutError, ConflictError, InternalServerError, RateLimitError])
def evaluate(self, request: OpenAIChatRequest) -> ChatCompletion:
if self.client is None:
# Handle lazy init.
Expand Down
3 changes: 3 additions & 0 deletions plugins/openai/modelgauge/suts/openai_client.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
from typing import Any, Dict, List, Optional, Union

from openai import OpenAI
from openai import APITimeoutError, ConflictError, InternalServerError, RateLimitError
from openai.types.chat import ChatCompletion
from pydantic import BaseModel

from modelgauge.prompt import ChatPrompt, ChatRole, SUTOptions, TextPrompt
from modelgauge.retry_decorator import retry
from modelgauge.secret_values import (
InjectSecret,
OptionalSecret,
Expand Down Expand Up @@ -139,6 +141,7 @@ def _translate_request(self, messages: List[OpenAIChatMessage], options: SUTOpti
**optional_kwargs,
)

@retry(transient_exceptions=[APITimeoutError, ConflictError, InternalServerError, RateLimitError])
def evaluate(self, request: OpenAIChatRequest) -> ChatCompletion:
if self.client is None:
# Handle lazy init.
Expand Down
50 changes: 50 additions & 0 deletions src/modelgauge/retry_decorator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
import functools
import logging
import time

BASE_RETRY_COUNT = 3
MAX_RETRY_DURATION = 86400 # 1 day in seconds
MAX_BACKOFF = 60 # 1 minute in seconds

logger = logging.getLogger(__name__)


def retry(
transient_exceptions=None,
base_retry_count=BASE_RETRY_COUNT,
max_retry_duration=MAX_RETRY_DURATION,
max_backoff=MAX_BACKOFF,
):
"""
A decorator that retries a function at least base_retry_count times.
If transient_exceptions are specified, it will retry for up to 1 day if any of those exceptions occur.
"""
transient_exceptions = tuple(transient_exceptions) if transient_exceptions else ()

def decorator(func):
@functools.wraps(func)
def wrapper(*args, **kwargs):
attempt = 0
start_time = time.time()

while True:
try:
return func(*args, **kwargs)
except transient_exceptions as e:
# Keep retrying transient exceptions for 1 day.
elapsed_time = time.time() - start_time
if elapsed_time >= max_retry_duration:
raise
logger.warning(f"Transient exception occurred: {e}. Retrying...")
except Exception as e:
# Retry all other exceptions BASE_RETRY_COUNT times.
attempt += 1
if attempt >= base_retry_count:
raise
logger.warning(f"Exception occurred after {attempt}/{base_retry_count} attempts: {e}. Retrying...")
sleep_time = min(2**attempt, max_backoff) # Exponential backoff with cap
time.sleep(sleep_time)

return wrapper

return decorator
64 changes: 64 additions & 0 deletions tests/modelgauge_tests/test_retry_decorator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
import pytest
import time

from modelgauge.retry_decorator import retry, BASE_RETRY_COUNT


def test_retry_success():
attempt_counter = 0

@retry()
def always_succeed():
nonlocal attempt_counter
attempt_counter += 1
return "success"

assert always_succeed() == "success"
assert attempt_counter == 1


@pytest.mark.parametrize("exceptions", [None, [ValueError]])
def test_retry_fails_after_base_retries(exceptions):
attempt_counter = 0

@retry(transient_exceptions=exceptions)
def always_fail():
nonlocal attempt_counter
attempt_counter += 1
raise KeyError("Intentional failure")

with pytest.raises(KeyError):
always_fail()

assert attempt_counter == BASE_RETRY_COUNT


def test_retry_eventually_succeeds():
attempt_counter = 0

@retry(transient_exceptions=[ValueError])
def succeed_before_base_retry_total():
nonlocal attempt_counter
attempt_counter += 1
if attempt_counter < BASE_RETRY_COUNT:
raise ValueError("Intentional failure")
return "success"

assert succeed_before_base_retry_total() == "success"
assert attempt_counter == BASE_RETRY_COUNT


def test_retry_transient_eventually_succeeds():
attempt_counter = 0
start_time = time.time()

@retry(transient_exceptions=[ValueError], max_retry_duration=3, base_retry_count=1)
def succeed_eventually():
nonlocal attempt_counter
attempt_counter += 1
elapsed_time = time.time() - start_time
if elapsed_time < 1:
raise ValueError("Intentional failure")
return "success"

assert succeed_eventually() == "success"

0 comments on commit 5e4f5b1

Please sign in to comment.