Merge pull request #1989 from Agenta-AI/feature/age-532-poc-1e-add-llm-api-key-checks-in-llm-based-evaluators

[Enhancement] Add LLM API key checks to LLM-based evaluators
aybruhm authored Aug 29, 2024
2 parents 532a4bb + cc33a66 commit 9212c7b
Showing 10 changed files with 308 additions and 62 deletions.
20 changes: 20 additions & 0 deletions agenta-backend/agenta_backend/resources/evaluators/evaluators.py
@@ -80,6 +80,7 @@
"name": "Semantic Similarity Match",
"key": "auto_semantic_similarity",
"direct_use": False,
"requires_llm_api_keys": True,
"description": "Semantic Similarity Match evaluator measures the similarity between two pieces of text by analyzing their meaning and context. It compares the semantic content, providing a score that reflects how closely the texts match in terms of meaning, rather than just exact word matches.",
"settings_template": {
"correct_answer_key": {
@@ -181,6 +182,7 @@
"name": "LLM-as-a-judge",
"key": "auto_ai_critique",
"direct_use": False,
"requires_llm_api_keys": True,
"settings_template": {
"prompt_template": {
"label": "Prompt Template",
@@ -206,6 +208,14 @@
"key": "auto_custom_code_run",
"direct_use": False,
"settings_template": {
"requires_llm_api_keys": {
"label": "Requires LLM API Key(s)",
"type": "boolean",
"required": True,
"default": False,
"advanced": True,
"description": "Indicates whether the evaluation requires LLM API key(s) to function.",
},
"code": {
"label": "Evaluation Code",
"type": "code",
@@ -230,6 +240,14 @@
"key": "auto_webhook_test",
"direct_use": False,
"settings_template": {
"requires_llm_api_keys": {
"label": "Requires LLM API Key(s)",
"type": "boolean",
"required": True,
"default": False,
"advanced": True,
"description": "Indicates whether the evaluation requires LLM API key(s) to function.",
},
"webhook_url": {
"label": "Webhook URL",
"type": "string",
@@ -380,13 +398,15 @@
"name": "RAG Faithfulness",
"key": "rag_faithfulness",
"direct_use": False,
"requires_llm_api_keys": True,
"settings_template": rag_evaluator_settings_template,
"description": "RAG Faithfulness evaluator assesses the accuracy and reliability of responses generated by Retrieval-Augmented Generation (RAG) models. It evaluates how faithfully the responses adhere to the retrieved documents or sources, ensuring that the generated text accurately reflects the information from the original sources.",
},
{
"name": "RAG Context Relevancy",
"key": "rag_context_relevancy",
"direct_use": False,
"requires_llm_api_keys": True,
"settings_template": rag_evaluator_settings_template,
"description": "RAG Context Relevancy evaluator measures how relevant the retrieved documents or contexts are to the given question or prompt. It ensures that the selected documents provide the necessary information for generating accurate and meaningful responses, improving the overall quality of the RAG model's output.",
},
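The new `requires_llm_api_keys` flag appears in two shapes in this file: as a top-level boolean on the built-in LLM evaluators (Semantic Similarity Match, LLM-as-a-judge, the RAG evaluators) and as an advanced boolean setting with a `default` under `settings_template` for the custom-code and webhook evaluators. A minimal sketch of how a consumer could read either shape; `evaluator_requires_llm_keys` is a hypothetical helper, mirroring the filter added later in `helpers.ensure_required_llm_keys_exist`:

```python
# Minimal sketch (illustrative; this helper is not part of the PR):
# read the flag from either location, mirroring the filter in helpers.py below.
from typing import Any, Dict


def evaluator_requires_llm_keys(evaluator: Dict[str, Any]) -> bool:
    """True if the evaluator declares `requires_llm_api_keys` at the top level
    or as a `settings_template` default."""
    return bool(
        evaluator.get("requires_llm_api_keys", False)
        or evaluator.get("settings_template", {})
        .get("requires_llm_api_keys", {})
        .get("default", False)
    )


# The built-in LLM-as-a-judge entry sets the flag at the top level, while the
# custom-code evaluator exposes it as an advanced setting that defaults to False.
assert evaluator_requires_llm_keys(
    {"key": "auto_ai_critique", "requires_llm_api_keys": True}
)
assert not evaluator_requires_llm_keys(
    {"key": "auto_custom_code_run",
     "settings_template": {"requires_llm_api_keys": {"default": False}}}
)
```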
13 changes: 6 additions & 7 deletions agenta-backend/agenta_backend/routers/evaluation_router.py
@@ -5,6 +5,7 @@
from fastapi.responses import JSONResponse
from fastapi import HTTPException, Request, status, Response, Query

from agenta_backend.services import helpers
from agenta_backend.models import converters
from agenta_backend.tasks.evaluations import evaluate
from agenta_backend.utils.common import APIRouter, isCloudEE
@@ -15,9 +16,6 @@
NewEvaluation,
DeleteEvaluation,
)
from agenta_backend.services.evaluator_manager import (
check_ai_critique_inputs,
)

if isCloudEE():
from agenta_backend.commons.models.shared_models import Permission
@@ -112,8 +110,9 @@ async def create_evaluation(
status_code=403,
)

success, response = await check_ai_critique_inputs(
payload.evaluators_configs, payload.lm_providers_keys
llm_provider_keys = helpers.format_llm_provider_keys(payload.lm_providers_keys)
success, response = await helpers.ensure_required_llm_keys_exist(
payload.evaluators_configs, llm_provider_keys
)
if not success:
return response
@@ -134,8 +133,8 @@ async def create_evaluation(
evaluators_config_ids=payload.evaluators_configs,
testset_id=payload.testset_id,
evaluation_id=evaluation.id,
rate_limit_config=payload.rate_limit.dict(),
lm_providers_keys=payload.lm_providers_keys,
rate_limit_config=payload.rate_limit.model_dump(),
lm_providers_keys=llm_provider_keys,
)
evaluations.append(evaluation)

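Besides the new key check, the task invocation above switches from `payload.rate_limit.dict()` to `payload.rate_limit.model_dump()`, the Pydantic v2 name for serializing a model to a plain dict. A standalone sketch; the `RateLimitConfig` fields below are placeholders, not the project's actual model:

```python
# Standalone sketch: Pydantic v2 renames BaseModel.dict() to model_dump().
# RateLimitConfig here is a stand-in with made-up fields.
from pydantic import BaseModel


class RateLimitConfig(BaseModel):
    batch_size: int = 10
    max_retries: int = 3


config = RateLimitConfig()
print(config.model_dump())  # {'batch_size': 10, 'max_retries': 3}
# In Pydantic v2, config.dict() still works but emits a deprecation warning.
```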
25 changes: 16 additions & 9 deletions agenta-backend/agenta_backend/services/db_manager.py
@@ -3000,13 +3000,17 @@ async def fetch_evaluator_config(evaluator_config_id: str):
return evaluator_config


async def check_if_ai_critique_exists_in_list_of_evaluators_configs(
evaluators_configs_ids: List[str],
async def check_if_evaluators_exist_in_list_of_evaluators_configs(
evaluators_configs_ids: List[str], evaluators_keys: List[str]
) -> bool:
"""Fetch evaluator configurations from the database.
"""Check if the provided evaluators exist in the database within the given evaluator configurations.
Arguments:
evaluators_configs_ids (List[str]): List of evaluator configuration IDs to search within.
evaluators_keys (List[str]): List of evaluator keys to check for existence.
Returns:
EvaluatorConfigDB: the evaluator configuration object.
bool: True if any of the specified evaluators exist, False otherwise.
"""

async with db_engine.get_session() as session:
@@ -3015,15 +3019,18 @@ async def check_if_ai_critique_exists_in_list_of_evaluators_configs(
for evaluator_config_id in evaluators_configs_ids
]

query = select(EvaluatorConfigDB).where(
query = select(EvaluatorConfigDB.id, EvaluatorConfigDB.evaluator_key).where(
EvaluatorConfigDB.id.in_(evaluator_config_uuids),
EvaluatorConfigDB.evaluator_key == "auto_ai_critique",
EvaluatorConfigDB.evaluator_key.in_(evaluators_keys),
)

result = await session.execute(query)
evaluators_configs = result.scalars().all()

return bool(evaluators_configs)
# NOTE: result.all() returns the records as a list of tuples
# 0 is the evaluator_id and 1 is evaluator_key
fetched_evaluators_keys = {config[1] for config in result.all()}

# Return True if at least one of the passed evaluator keys was found
return any(key in fetched_evaluators_keys for key in evaluators_keys)


async def fetch_evaluator_config_by_appId(
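The rewritten query selects only the `id` and `evaluator_key` columns, so `result.all()` yields row tuples rather than ORM objects. A standalone sketch of the membership logic, with hypothetical rows standing in for the database result:

```python
# Standalone sketch (no database required): the membership check the new query
# performs. The rows below are hypothetical (id, evaluator_key) tuples, shaped
# like SQLAlchemy's result.all() when selecting individual columns.
from uuid import uuid4

evaluators_keys = ["auto_ai_critique", "rag_faithfulness"]  # keys that require LLM keys

rows = [
    (uuid4(), "auto_ai_critique"),
    (uuid4(), "auto_exact_match"),
]

fetched_evaluators_keys = {row[1] for row in rows}

# True as soon as at least one configured evaluator requires LLM API keys.
print(any(key in fetched_evaluators_keys for key in evaluators_keys))  # True
```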
25 changes: 0 additions & 25 deletions agenta-backend/agenta_backend/services/evaluator_manager.py
@@ -166,28 +166,3 @@ async def create_ready_to_use_evaluators(app: AppDB):
evaluator_key=evaluator.key,
settings_values=settings_values,
)


async def check_ai_critique_inputs(
evaluators_configs: List[str], lm_providers_keys: Optional[Dict[str, Any]]
) -> Tuple[bool, Optional[JSONResponse]]:
"""
Checks if AI critique exists in evaluators configs and validates lm_providers_keys.
Args:
evaluators_configs (List[str]): List of evaluator configurations.
lm_providers_keys (Optional[Dict[str, Any]]): Language model provider keys.
Returns:
Tuple[bool, Optional[JSONResponse]]: Returns a tuple containing a boolean indicating success,
and a JSONResponse in case of error.
"""
if await db_manager.check_if_ai_critique_exists_in_list_of_evaluators_configs(
evaluators_configs
):
if not lm_providers_keys:
return False, JSONResponse(
{"detail": "Missing LM provider Key"},
status_code=400,
)
return True, None
34 changes: 24 additions & 10 deletions agenta-backend/agenta_backend/services/evaluators_service.py
@@ -398,15 +398,18 @@ async def auto_ai_critique(
output = validate_string_output("ai_critique", output)
correct_answer = get_correct_answer(data_point, settings_values)
inputs = {
"prompt_user": app_params.get("prompt_user", ""),
"prompt_user": app_params.get("prompt_user", "").format(**data_point),
"prediction": output,
"ground_truth": correct_answer,
}
settings = {
"prompt_template": settings_values.get("prompt_template", ""),
}
response = await ai_critique(
input=EvaluatorInputInterface(
**{
"inputs": inputs,
"settings": settings_values,
"settings": settings,
"credentials": lm_providers_keys,
}
)
@@ -424,7 +427,12 @@


async def ai_critique(input: EvaluatorInputInterface) -> EvaluatorOutputInterface:
openai_api_key = input.credentials["OPENAI_API_KEY"]
openai_api_key = input.credentials.get("OPENAI_API_KEY", None)

if not openai_api_key:
raise Exception(
"No OpenAI key was found. AI Critique evaluator requires a valid OpenAI API key to function. Please configure your OpenAI API and try again."
)

chain_run_args = {
"llm_app_prompt_template": input.inputs.get("prompt_user", ""),
@@ -434,18 +442,20 @@ async def ai_critique(input: EvaluatorInputInterface) -> EvaluatorOutputInterfac
for key, value in input.inputs.items():
chain_run_args[key] = value

prompt_template = input.settings["prompt_template"]
prompt_template = input.settings.get("prompt_template", "")
messages = [
{"role": "system", "content": prompt_template},
{"role": "user", "content": str(chain_run_args)},
]

print(input)

client = AsyncOpenAI(api_key=openai_api_key)
response = await client.chat.completions.create(
model="gpt-3.5-turbo", messages=messages, temperature=0.8
)
evaluation_output = response.choices[0].message.content.strip()
return {"outputs": {"score": float(evaluation_output)}}
return {"outputs": {"score": evaluation_output}}


async def auto_starts_with(
@@ -846,7 +856,7 @@ async def measure_rag_consistency(
openai_api_key = input.credentials.get("OPENAI_API_KEY", None)
if not openai_api_key:
raise Exception(
"No LLM keys OpenAI key found. Please configure your OpenAI keys and try again."
"No OpenAI key was found. RAG evaluator requires a valid OpenAI API key to function. Please configure your OpenAI API and try again."
)

# Initialize RAG evaluator to calculate faithfulness score
@@ -945,10 +955,9 @@ async def measure_context_coherence(
input: EvaluatorInputInterface,
) -> EvaluatorOutputInterface:
openai_api_key = input.credentials.get("OPENAI_API_KEY", None)

if not openai_api_key:
raise Exception(
"No LLM keys OpenAI key found. Please configure your OpenAI keys and try again."
"No OpenAI key was found. RAG evaluator requires a valid OpenAI API key to function. Please configure your OpenAI API and try again."
)

# Initialize RAG evaluator to calculate context relevancy score
@@ -1176,8 +1185,13 @@ async def semantic_similarity(
float: the semantic similarity score
"""

api_key = input.credentials["OPENAI_API_KEY"]
openai = AsyncOpenAI(api_key=api_key)
openai_api_key = input.credentials.get("OPENAI_API_KEY", None)
if not openai_api_key:
raise Exception(
"No OpenAI key was found. Semantic evaluator requires a valid OpenAI API key to function. Please configure your OpenAI API and try again."
)

openai = AsyncOpenAI(api_key=openai_api_key)

async def encode(text: str):
response = await openai.embeddings.create(
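The same credential guard is now repeated inline in `ai_critique`, both RAG evaluators, and `semantic_similarity`: read `OPENAI_API_KEY` from the evaluator's credentials and fail fast with a descriptive message when it is missing. A sketch of that shared pattern as a helper; the helper itself is not part of this PR:

```python
# Illustrative only: get_openai_api_key_or_raise does not exist in the PR;
# it just factors out the guard that the diff repeats inline per evaluator.
from typing import Dict


def get_openai_api_key_or_raise(credentials: Dict[str, str], evaluator_name: str) -> str:
    api_key = credentials.get("OPENAI_API_KEY", None)
    if not api_key:
        raise Exception(
            f"No OpenAI key was found. {evaluator_name} requires a valid OpenAI "
            "API key to function. Please configure your OpenAI API key and try again."
        )
    return api_key


# Fails fast before any OpenAI client call is attempted.
key = get_openai_api_key_or_raise(
    {"OPENAI_API_KEY": "sk-placeholder"}, "The AI Critique evaluator"
)
```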
70 changes: 68 additions & 2 deletions agenta-backend/agenta_backend/services/helpers.py
@@ -1,6 +1,10 @@
import json
from typing import List, Dict, Any, Tuple, Union
from datetime import datetime, timedelta, timezone
from datetime import datetime, timezone
from typing import List, Dict, Any, Union, Tuple

from agenta_backend.services import db_manager
from agenta_backend.models.api.evaluation_model import LMProvidersEnum
from agenta_backend.resources.evaluators.evaluators import get_all_evaluators


def format_inputs(list_of_dictionaries: List[Dict[str, Any]]) -> Dict:
@@ -76,3 +80,65 @@ def convert_to_utc_datetime(dt: Union[datetime, str, None]) -> datetime:
if dt.tzinfo is None:
return dt.replace(tzinfo=timezone.utc)
return dt


def format_llm_provider_keys(
llm_provider_keys: Dict[LMProvidersEnum, str]
) -> Dict[str, str]:
"""Formats a dictionary of LLM provider keys into a dictionary of strings.
Args:
llm_provider_keys (Dict[LMProvidersEnum, str]): LLM provider keys
Returns:
Dict[str, str]: formatted LLM provider keys
Example:
Input: {<LMProvidersEnum.mistralai: 'MISTRAL_API_KEY'>: '...', ...}
Output: {'MISTRAL_API_KEY': '...', ...}
"""

llm_provider_keys = {key.value: value for key, value in llm_provider_keys.items()}
return llm_provider_keys


async def ensure_required_llm_keys_exist(
evaluator_configs: List[str], llm_provider_keys: Dict[str, str]
) -> Tuple[bool, None]:
"""
Validates that the necessary LLM API keys are present when evaluators that require them are used.
Args:
evaluator_configs (List[str]): List of evaluator configurations to check.
llm_provider_keys (Dict[str, str]): Dictionary of LLM provider keys (e.g., {"OPENAI_API_KEY": "your-key"}).
Returns:
Tuple[bool, None]: Returns (True, None) if validation passes.
Raises:
ValueError: If an evaluator requiring LLM keys is configured but no LLM API key is provided.
"""

evaluators_requiring_llm_keys = [
evaluator["key"]
for evaluator in get_all_evaluators()
if evaluator.get("requires_llm_api_keys", False)
or (
evaluator.get("settings_template", {})
.get("requires_llm_api_keys", {})
.get("default", False)
)
]
evaluators_found = (
await db_manager.check_if_evaluators_exist_in_list_of_evaluators_configs(
evaluator_configs, evaluators_requiring_llm_keys
)
)

if evaluators_found and "OPENAI_API_KEY" not in llm_provider_keys:
raise ValueError(
"OpenAI API key is required to run one or more of the specified evaluators."
)

return True, None
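Taken together, the two helpers let the router normalize the incoming provider keys and refuse to start an evaluation that cannot run. An illustrative usage sketch: it needs a configured database session to actually execute, the config ID is a placeholder, and the `mistralai` member is taken from the docstring example above:

```python
# Illustrative usage sketch (requires a configured database session to run).
import asyncio

from agenta_backend.services import helpers
from agenta_backend.models.api.evaluation_model import LMProvidersEnum


async def main():
    # Enum-keyed input, as received from the API payload.
    raw_keys = {LMProvidersEnum.mistralai: "placeholder-key"}
    provider_keys = helpers.format_llm_provider_keys(raw_keys)
    # -> {"MISTRAL_API_KEY": "placeholder-key"}

    try:
        # "<evaluator-config-id>" is a placeholder. If that config points to an
        # evaluator requiring LLM keys, the missing OPENAI_API_KEY raises.
        await helpers.ensure_required_llm_keys_exist(
            ["<evaluator-config-id>"], provider_keys
        )
    except ValueError as err:
        print(err)  # "OpenAI API key is required to run one or more of the specified evaluators."


asyncio.run(main())
```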