relax litellm provider constraint #820

Merged
40 changes: 20 additions & 20 deletions garak/generators/litellm.py
@@ -41,6 +41,7 @@
 import litellm

 from garak import _config
+from garak.exception import BadGeneratorException
 from garak.generators.base import Generator

 # Fix issue with Ollama which does not support `presence_penalty`
@@ -125,11 +126,7 @@ def __init__(self, name: str = "", generations: int = 10, config_root=_config):
             self.name, generations=self.generations, config_root=config_root
         )

-        if self.provider is None:
-            raise ValueError(
-                "litellm generator needs to have a provider value configured - see docs"
-            )
-        elif (
+        if (
             self.api_key is None
         ): # TODO: special case where api_key is not always required
             if self.provider == "openai":
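
With the hard provider requirement dropped, the generator no longer refuses to start when no provider is configured; litellm can route the request from the model string alone, and a provider only needs to be set when litellm cannot infer it. A minimal sketch of the relaxed usage, assuming an OpenAI-style model name and that OPENAI_API_KEY is available in the environment (the model name below is illustrative, not taken from this PR):

from garak.generators.litellm import LiteLLMGenerator

# No explicit provider configured; litellm infers the backend from the
# model string. "gpt-3.5-turbo" is an illustrative choice.
generator = LiteLLMGenerator(name="gpt-3.5-turbo")
outputs = generator.generate("Say hello in one short sentence.")
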
@@ -140,7 +137,7 @@ def __init__(self, name: str = "", generations: int = 10, config_root=_config):
" or in the configuration file"
)

@backoff.on_exception(backoff.fibo, Exception, max_value=70)
@backoff.on_exception(backoff.fibo, litellm.exceptions.APIError, max_value=70)
def _call_model(
self, prompt: str, generations_this_call: int = 1
) -> List[Union[str, None]]:
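
Narrowing the backoff trigger from the blanket Exception to litellm.exceptions.APIError means retries are now limited to litellm's API errors (typically transient, server-side failures), while bad requests and programming errors fail fast. A small sketch of the decorator's behaviour, with a throwaway function and an illustrative model name used purely for demonstration:

import backoff
import litellm

# Retry with Fibonacci-spaced waits (capped at 70s) only when litellm raises
# APIError; any other exception propagates to the caller immediately.
@backoff.on_exception(backoff.fibo, litellm.exceptions.APIError, max_value=70)
def fetch_completion(prompt: str):
    return litellm.completion(
        model="gpt-3.5-turbo",  # illustrative model name
        messages=[{"role": "user", "content": prompt}],
    )
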
@@ -157,20 +154,23 @@ def _call_model(
             print(msg)
             return []

-        response = litellm.completion(
-            model=self.name,
-            messages=prompt,
-            temperature=self.temperature,
-            top_p=self.top_p,
-            n=generations_this_call,
-            stop=self.stop,
-            max_tokens=self.max_tokens,
-            frequency_penalty=self.frequency_penalty,
-            presence_penalty=self.presence_penalty,
-            api_base=self.api_base,
-            custom_llm_provider=self.provider,
-            api_key=self.api_key,
-        )
+        try:
+            response = litellm.completion(
+                model=self.name,
+                messages=prompt,
+                temperature=self.temperature,
+                top_p=self.top_p,
+                n=generations_this_call,
+                stop=self.stop,
+                max_tokens=self.max_tokens,
+                frequency_penalty=self.frequency_penalty,
+                presence_penalty=self.presence_penalty,
+                api_base=self.api_base,
+                custom_llm_provider=self.provider,
+                api_key=self.api_key,
+            )
+        except litellm.exceptions.BadRequestError as e:
+            raise BadGeneratorException() from e

         if self.supports_multiple_generations:
             return [c.message.content for c in response.choices]
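
Taken together, the generator now converts litellm's BadRequestError (raised for unknown models, malformed provider strings, and similar misconfigurations) into garak's BadGeneratorException instead of retrying it under backoff. A sketch of how calling code might separate a misconfigured generator from a transient failure; the model id below is deliberately fictitious and used only for illustration:

from garak.exception import BadGeneratorException
from garak.generators.litellm import LiteLLMGenerator

generator = LiteLLMGenerator(name="definitely-not-a-real-model")  # illustrative, non-existent id
try:
    outputs = generator.generate("ping")
except BadGeneratorException:
    # Misconfiguration: fix the model/provider settings rather than retrying.
    print("litellm rejected the request; check the configured model name")
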
9 changes: 9 additions & 0 deletions tests/generators/test_litellm.py
@@ -2,6 +2,7 @@

 from os import getenv

+from garak.exception import BadGeneratorException
 from garak.generators.litellm import LiteLLMGenerator

 DEFAULT_GENERATIONS_QTY = 10
@@ -43,3 +44,11 @@ def test_litellm_openrouter():
     for item in output:
         assert isinstance(item, str)
     print("test passed!")
+
+
+def test_litellm_model_non_existence():
+    model_name = "non-existent-model"
+    generator = LiteLLMGenerator(name=model_name)
+    with pytest.raises(BadGeneratorException):
+        output = generator.generate("This should raise an exception")
+        assert "Exceptions on model non-existence raised by litellm should be bubbled up"