relax litellm provider constraint #820

Merged
61 changes: 23 additions & 38 deletions garak/generators/litellm.py
@@ -6,17 +6,13 @@
reads API keys from the respective environment variables.
(e.g. OPENAI_API_KEY for OpenAI models)

API key can also be directly set in the supplied generator json config.
This also enables support for any custom provider that follows the OAI format.

e.g. supply a JSON like this for Ollama's OAI API:
```json
{
"litellm": {
"LiteLLMGenerator" : {
"api_base" : "http://localhost:11434/v1",
"provider" : "openai",
"api_key" : "test"
"provider" : "openai"
}
}
}
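
For orientation, here is a minimal sketch (an editorial illustration, not part of this PR) of driving the generator against such an OpenAI-compatible endpoint from Python. The endpoint URL and model name are assumptions, and the config dict mirrors the shape used by the new test further down.

```python
# Sketch only: a LiteLLM generator pointed at an OpenAI-compatible endpoint
# such as Ollama. The URL and model name below are illustrative assumptions.
from garak.generators.litellm import LiteLLMGenerator

custom_config = {
    "generators": {
        "litellm": {
            "api_base": "http://localhost:11434/v1",  # assumed local Ollama endpoint
            "provider": "openai",
        }
    }
}

# API keys are still read from the provider's environment variable
# (e.g. OPENAI_API_KEY), as described in the docstring above.
generator = LiteLLMGenerator(name="llama3", config_root=custom_config)
outputs = generator.generate("Say hello")  # list of generated strings (or None entries)
```

With this change, a missing or invalid key surfaces as a `BadGeneratorException` when generation is attempted, rather than as `APIKeyMissingError` at construction time.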
@@ -41,6 +37,7 @@
import litellm

from garak import _config
from garak.exception import BadGeneratorException
from garak.generators.base import Generator

# Fix issue with Ollama which does not support `presence_penalty`
@@ -79,7 +76,6 @@
class LiteLLMGenerator(Generator):
"""Generator wrapper using LiteLLM to allow access to different providers using the OpenAI API format."""

ENV_VAR = "OPENAI_API_KEY"
DEFAULT_PARAMS = Generator.DEFAULT_PARAMS | {
"temperature": 0.7,
"top_p": 1.0,
@@ -110,9 +106,7 @@ class LiteLLMGenerator(Generator):
def __init__(self, name: str = "", generations: int = 10, config_root=_config):
self.name = name
self.api_base = None
self.api_key = None
self.provider = None
self.key_env_var = self.ENV_VAR
self.generations = generations
self._load_config(config_root)
self.fullname = f"LiteLLM {self.name}"
@@ -125,22 +119,7 @@ def __init__(self, name: str = "", generations: int = 10, config_root=_config):
self.name, generations=self.generations, config_root=config_root
)

if self.provider is None:
raise ValueError(
"litellm generator needs to have a provider value configured - see docs"
)
elif (
self.api_key is None
): # TODO: special case where api_key is not always required
if self.provider == "openai":
self.api_key = getenv(self.key_env_var, None)
if self.api_key is None:
raise APIKeyMissingError(
f"Please supply an OpenAI API key in the {self.key_env_var} environment variable"
" or in the configuration file"
)

@backoff.on_exception(backoff.fibo, Exception, max_value=70)
@backoff.on_exception(backoff.fibo, litellm.exceptions.APIError, max_value=70)
def _call_model(
self, prompt: str, generations_this_call: int = 1
) -> List[Union[str, None]]:
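
The decorator swap above narrows retries from any `Exception` to `litellm.exceptions.APIError`, so transient API failures still get fibonacci backoff while authentication and request errors fail fast. A minimal sketch of the same pattern (illustrative model name, not project code):

```python
# Illustrative only: fibonacci backoff retries apply solely to litellm's
# APIError; AuthenticationError / BadRequestError escape on the first attempt.
import backoff
import litellm


@backoff.on_exception(backoff.fibo, litellm.exceptions.APIError, max_value=70)
def call_with_retries():
    return litellm.completion(
        model="openai/gpt-4o-mini",  # placeholder model name
        messages=[{"role": "user", "content": "hello"}],
    )
```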
@@ -157,20 +136,26 @@ def _call_model(
print(msg)
return []

response = litellm.completion(
model=self.name,
messages=prompt,
temperature=self.temperature,
top_p=self.top_p,
n=generations_this_call,
stop=self.stop,
max_tokens=self.max_tokens,
frequency_penalty=self.frequency_penalty,
presence_penalty=self.presence_penalty,
api_base=self.api_base,
custom_llm_provider=self.provider,
api_key=self.api_key,
)
try:
response = litellm.completion(
model=self.name,
messages=prompt,
temperature=self.temperature,
top_p=self.top_p,
n=generations_this_call,
stop=self.stop,
max_tokens=self.max_tokens,
frequency_penalty=self.frequency_penalty,
presence_penalty=self.presence_penalty,
api_base=self.api_base,
custom_llm_provider=self.provider,
)
except (
litellm.exceptions.AuthenticationError, # authentication failed for detected or passed `provider`
litellm.exceptions.BadRequestError,
) as e:
raise BadGeneratorException("Unrecoverable error during litellm completion; see log for details") from e

if self.supports_multiple_generations:
return [c.message.content for c in response.choices]
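
From the caller's perspective, provider, credential, and model-name problems now surface as garak's `BadGeneratorException`, with the original litellm error kept as the chained cause. A hedged sketch of handling this (configuration values are placeholders, not project code):

```python
# Sketch of caller-side handling; the endpoint and model name are placeholders.
from garak.exception import BadGeneratorException
from garak.generators.litellm import LiteLLMGenerator

generator = LiteLLMGenerator(
    name="openai/some-model",
    config_root={"generators": {"litellm": {"api_base": "https://example.invalid/v1"}}},
)
try:
    generator.generate("probe prompt")
except BadGeneratorException as err:
    # `raise ... from e` keeps the underlying litellm exception reachable
    print(f"generator unusable: {err.__cause__!r}")
```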
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -65,7 +65,7 @@ dependencies = [
"ecoji>=0.1.1",
"deepl==1.17.0",
"fschat>=0.2.36",
"litellm>=1.33.8",
"litellm>=1.41.21",
"jsonpath-ng>=1.6.1",
"huggingface_hub>=0.21.0",
'python-magic-bin>=0.4.14; sys_platform == "win32"',
2 changes: 1 addition & 1 deletion requirements.txt
@@ -26,7 +26,7 @@ zalgolib>=0.2.2
ecoji>=0.1.1
deepl==1.17.0
fschat>=0.2.36
litellm>=1.33.8
litellm>=1.41.21
jsonpath-ng>=1.6.1
huggingface_hub>=0.21.0
python-magic-bin>=0.4.14; sys_platform == "win32"
18 changes: 18 additions & 0 deletions tests/generators/test_litellm.py
@@ -2,6 +2,7 @@

from os import getenv

from garak.exception import BadGeneratorException
from garak.generators.litellm import LiteLLMGenerator

DEFAULT_GENERATIONS_QTY = 10
@@ -43,3 +44,20 @@ def test_litellm_openrouter():
for item in output:
assert isinstance(item, str)
print("test passed!")


def test_litellm_model_detection():
custom_config = {
"generators": {
"litellm": {
"api_base": "https://garak.example.com/v1",
}
}
}
model_name = "non-existent-model"
generator = LiteLLMGenerator(name=model_name, config_root=custom_config)
with pytest.raises(BadGeneratorException):
generator.generate("This should raise an exception")
generator = LiteLLMGenerator(name="openai/invalid-model", config_root=custom_config)
with pytest.raises(BadGeneratorException):
generator.generate("This should raise an exception")