Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Enhancement]: Handle non-string outputs gracefully in auto_contains_json evaluator #1987

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
18 commits
Select commit Hold shift + click to select a range
d3f7315
refactor (backend): improve error handling for auto_contains_json eva…
aybruhm Aug 13, 2024
08b9e87
feat (tests): add tests for dictionary-based output handling in conta…
aybruhm Aug 13, 2024
f0cc8c6
Merge branch 'feature/age-491-poc-1e-expose-running-evaluators-via-ap…
aybruhm Aug 20, 2024
d1fe5aa
chore (backend): remove redundant error message
aybruhm Aug 20, 2024
05ae4b5
Merge branch 'feature/age-491-poc-1e-expose-running-evaluators-via-ap…
aybruhm Aug 20, 2024
ac1ac7e
Merge branch 'feature/age-491-poc-1e-expose-running-evaluators-via-ap…
aybruhm Aug 21, 2024
9309f43
Merge branch 'feature/age-491-poc-1e-expose-running-evaluators-via-ap…
aybruhm Aug 21, 2024
23be8b6
refactor (backend): centralize validation of string and json output a…
aybruhm Aug 21, 2024
b6db4f1
feat (tests): update parameters for BaseResponse compatibility and re…
aybruhm Aug 21, 2024
80f3eff
minor refactor (backend): update 'validate_json_output' function retu…
aybruhm Aug 21, 2024
892a351
chore (style): format evaluators_service with black@23.12.0
aybruhm Aug 21, 2024
2e76a1c
Merge branch 'main' of github.com:Agenta-AI/agenta
jp-agenta Aug 23, 2024
3cad5db
Enforce in Union[str, Dict[str, Any]] in BaseResponse in SDK
jp-agenta Aug 23, 2024
35e6fec
fix(frontend): Migrate Inter font to use @next/font
bekossy Aug 23, 2024
1dec2c6
Merge pull request #2016 from Agenta-AI/AGE-654/-migrate-Inter-Font-t…
jp-agenta Aug 23, 2024
6238cd8
Merge branch 'main' of github.com:Agenta-AI/agenta
jp-agenta Aug 23, 2024
f3546ef
Merge branch 'main' into feature/age-573-evaluators-fail-gracefully-w…
jp-agenta Aug 23, 2024
2402f94
fix exception message and bump SDK out of pre-release
jp-agenta Aug 23, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions agenta-backend/agenta_backend/routers/evaluators_router.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,11 +103,11 @@ async def evaluator_run(
)
return result
except Exception as e:
logger.error(f"Error while running evaluator: {str(e)}")
logger.error(f"Error while running {evaluator_key} evaluator: {str(e)}")
raise HTTPException(
status_code=500,
detail={
"message": "Error while running evaluator",
"message": f"Error while running {evaluator_key} evaluator",
"stacktrace": traceback.format_exc(),
},
)
Expand Down
130 changes: 83 additions & 47 deletions agenta-backend/agenta_backend/services/evaluators_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,71 @@
logger.setLevel(logging.DEBUG)


def validate_string_output(
evaluator_key: str, output: Union[str, Dict[str, Any]]
) -> str:
"""Checks and validate the output to be of type string.

Args:
evaluator_key (str): the key of the evaluator
output (Union[str, Dict[str, Any]]): the llm response

Raises:
Exception: requires output to be a string

Returns:
str: output
"""

output = output.get("data", "") if isinstance(output, dict) else output
if not isinstance(output, str):
raise Exception(
f"Evaluator {evaluator_key} requires the output to be a string, but received {type(output).__name__} instead. "
)
return output


def validate_json_output(
evaluator_key: str, output: Union[str, Dict[str, Any]]
) -> Union[str, dict]:
"""Checks and validate the output to be of type JSON string or dictionary.

Args:
evaluator_key (str): the key of the evaluator
output (Union[str, Dict[str, Any]]): the llm response

Raises:
Exception: requires output to be a JSON string

Returns:
str, dict: output
"""

output = output.get("data", "") if isinstance(output, dict) else output
if isinstance(output, dict):
output = json.dumps(output)
elif isinstance(output, str):
try:
json.loads(output)
except json.JSONDecodeError:
raise Exception(
f"Evaluator {evaluator_key} requires the output to be a JSON string or object."
)

if not isinstance(
output,
(
str,
dict,
),
):
raise Exception(
f"Evaluator {evaluator_key} requires the output to be either a JSON string or object, but received {type(output).__name__} instead."
)

return output


async def map(
mapping_input: EvaluatorMappingInputInterface,
) -> EvaluatorMappingOutputInterface:
Expand Down Expand Up @@ -94,9 +159,9 @@ async def auto_exact_match(
Returns:
Result: A Result object containing the evaluation result.
"""
if not isinstance(output, str):
output = output.get("data", "")

try:
output = validate_string_output("exact_match", output)
correct_answer = get_correct_answer(data_point, settings_values)
inputs = {"ground_truth": correct_answer, "prediction": output}
response = exact_match(input=EvaluatorInputInterface(**{"inputs": inputs}))
Expand Down Expand Up @@ -136,9 +201,8 @@ async def auto_regex_test(
settings_values: Dict[str, Any],
lm_providers_keys: Dict[str, Any], # pylint: disable=unused-argument
) -> Result:
if not isinstance(output, str):
output = output.get("data", "")
try:
output = validate_string_output("regex_test", output)
inputs = {"ground_truth": data_point, "prediction": output}
response = await regex_test(
input=EvaluatorInputInterface(
Expand Down Expand Up @@ -174,9 +238,8 @@ async def auto_field_match_test(
settings_values: Dict[str, Any],
lm_providers_keys: Dict[str, Any], # pylint: disable=unused-argument
) -> Result:
if not isinstance(output, str):
output = output.get("data", "")
try:
output = validate_string_output("field_match_test", output)
correct_answer = get_correct_answer(data_point, settings_values)
inputs = {"ground_truth": correct_answer, "prediction": output}
response = await field_match_test(
Expand Down Expand Up @@ -210,9 +273,8 @@ async def auto_webhook_test(
settings_values: Dict[str, Any],
lm_providers_keys: Dict[str, Any], # pylint: disable=unused-argument
) -> Result:
if not isinstance(output, str):
output = output.get("data", "")
try:
output = validate_string_output("webhook_test", output)
correct_answer = get_correct_answer(data_point, settings_values)
inputs = {"prediction": output, "ground_truth": correct_answer}
response = await webhook_test(
Expand Down Expand Up @@ -272,9 +334,8 @@ async def auto_custom_code_run(
settings_values: Dict[str, Any],
lm_providers_keys: Dict[str, Any], # pylint: disable=unused-argument
) -> Result:
if not isinstance(output, str):
output = output.get("data", "")
try:
output = validate_string_output("custom_code_run", output)
correct_answer = get_correct_answer(data_point, settings_values)
inputs = {
"app_config": app_params,
Expand Down Expand Up @@ -332,9 +393,9 @@ async def auto_ai_critique(
Returns:
Result: Evaluation result.
"""
if not isinstance(output, str):
output = output.get("data", "")

try:
output = validate_string_output("ai_critique", output)
correct_answer = get_correct_answer(data_point, settings_values)
inputs = {
"prompt_user": app_params.get("prompt_user", ""),
Expand Down Expand Up @@ -391,9 +452,8 @@ async def auto_starts_with(
settings_values: Dict[str, Any],
lm_providers_keys: Dict[str, Any], # pylint: disable=unused-argument
) -> Result:
if not isinstance(output, str):
output = output.get("data", "")
try:
output = validate_string_output("starts_with", output)
inputs = {"prediction": output}
response = await starts_with(
input=EvaluatorInputInterface(
Expand Down Expand Up @@ -433,9 +493,8 @@ async def auto_ends_with(
settings_values: Dict[str, Any],
lm_providers_keys: Dict[str, Any], # pylint: disable=unused-argument
) -> Result:
if not isinstance(output, str):
output = output.get("data", "")
try:
output = validate_string_output("ends_with", output)
inputs = {"prediction": output}
response = await ends_with(
input=EvaluatorInputInterface(
Expand Down Expand Up @@ -476,9 +535,8 @@ async def auto_contains(
settings_values: Dict[str, Any],
lm_providers_keys: Dict[str, Any], # pylint: disable=unused-argument
) -> Result:
if not isinstance(output, str):
output = output.get("data", "")
try:
output = validate_string_output("contains", output)
inputs = {"prediction": output}
response = await contains(
input=EvaluatorInputInterface(
Expand Down Expand Up @@ -519,9 +577,8 @@ async def auto_contains_any(
settings_values: Dict[str, Any],
lm_providers_keys: Dict[str, Any], # pylint: disable=unused-argument
) -> Result:
if not isinstance(output, str):
output = output.get("data", "")
try:
output = validate_string_output("contains_any", output)
inputs = {"prediction": output}
response = await contains_any(
input=EvaluatorInputInterface(
Expand Down Expand Up @@ -564,9 +621,8 @@ async def auto_contains_all(
settings_values: Dict[str, Any],
lm_providers_keys: Dict[str, Any], # pylint: disable=unused-argument
) -> Result:
if not isinstance(output, str):
output = output.get("data", "")
try:
output = validate_string_output("contains_all", output)
response = await contains_all(
input=EvaluatorInputInterface(
**{"inputs": {"prediction": output}, "settings": settings_values}
Expand Down Expand Up @@ -607,9 +663,8 @@ async def auto_contains_json(
settings_values: Dict[str, Any], # pylint: disable=unused-argument
lm_providers_keys: Dict[str, Any], # pylint: disable=unused-argument
) -> Result:
if not isinstance(output, str):
output = output.get("data", "")
try:
output = validate_json_output("contains_json", output)
response = await contains_json(
input=EvaluatorInputInterface(**{"inputs": {"prediction": output}})
)
Expand Down Expand Up @@ -750,22 +805,7 @@ async def auto_json_diff(
lm_providers_keys: Dict[str, Any], # pylint: disable=unused-argument
) -> Result:
try:
output = output.get("data", "") if isinstance(output, dict) else output

if isinstance(output, dict):
output = json.dumps(output)
elif isinstance(output, str):
try:
json.loads(output)
except:
raise Exception(
f"Evaluator 'auto_json_diff' requires string outputs to be JSON strings."
)
else:
raise Exception(
f"Evaluator 'auto_json_diff' requires the output to be either a JSON string or a JSON object, but received {type(output).__name__} instead."
)

output = validate_json_output("json_diff", output)
correct_answer = get_correct_answer(data_point, settings_values)
response = await json_diff(
input=EvaluatorInputInterface(
Expand Down Expand Up @@ -1035,9 +1075,8 @@ async def auto_levenshtein_distance(
settings_values: Dict[str, Any],
lm_providers_keys: Dict[str, Any], # pylint: disable=unused-argument
) -> Result:
if not isinstance(output, str):
output = output.get("data", "")
try:
output = validate_string_output("levenshtein_distance", output)
correct_answer = get_correct_answer(data_point, settings_values)
response = await levenshtein_distance(
input=EvaluatorInputInterface(
Expand Down Expand Up @@ -1078,9 +1117,8 @@ async def auto_similarity_match(
settings_values: Dict[str, Any],
lm_providers_keys: Dict[str, Any],
) -> Result:
if not isinstance(output, str):
output = output.get("data", "")
try:
output = validate_string_output("similarity_match", output)
correct_answer = get_correct_answer(data_point, settings_values)
response = await similarity_match(
input=EvaluatorInputInterface(
Expand Down Expand Up @@ -1160,10 +1198,8 @@ async def auto_semantic_similarity(
settings_values: Dict[str, Any],
lm_providers_keys: Dict[str, Any],
) -> Result:
if not isinstance(output, str):
output = output.get("data", "")

try:
output = validate_string_output("semantic_similarity", output)
correct_answer = get_correct_answer(data_point, settings_values)
inputs = {"prediction": output, "ground_truth": correct_answer}
response = await semantic_similarity(
Expand Down
46 changes: 41 additions & 5 deletions agenta-backend/agenta_backend/tests/unit/test_evaluators.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
import os
import pytest

from test_traces import simple_rag_trace

from agenta_backend.tests.unit.test_traces import simple_rag_trace
from agenta_backend.services.evaluators_service import (
auto_levenshtein_distance,
auto_starts_with,
Expand Down Expand Up @@ -175,10 +174,13 @@ async def test_auto_contains_all(output, substrings, case_sensitive, expected):
@pytest.mark.parametrize(
"output, expected",
[
('Some random text {"key": "value"} more text', True),
("No JSON here!", False),
("{Malformed JSON, nope!}", False),
('Some random text {"key": "value"} more text', None),
("No JSON here!", None),
("{Malformed JSON, nope!}", None),
('{"valid": "json", "number": 123}', True),
({"data": {"message": "The capital of Azerbaijan is Baku."}}, True),
({"data": '{"message": "The capital of Azerbaijan is Baku."}'}, True),
({"data": "The capital of Azerbaijan is Baku."}, None),
],
)
@pytest.mark.asyncio
Expand Down Expand Up @@ -232,6 +234,40 @@ async def test_auto_contains_json(output, expected):
0.0,
1.0,
),
(
{
"correct_answer": '{"user": {"name": "John", "details": {"age": 30, "location": "New York"}}}'
},
{
"data": '{"USER": {"NAME": "John", "DETAILS": {"AGE": 30, "LOCATION": "New York"}}}'
},
{
"predict_keys": True,
"compare_schema_only": False,
"case_insensitive_keys": True,
"correct_answer_key": "correct_answer",
},
0.0,
1.0,
),
(
{
"correct_answer": '{"user": {"name": "John", "details": {"age": 30, "location": "New York"}}}'
},
{
"data": {
"output": '{"USER": {"NAME": "John", "DETAILS": {"AGE": 30, "LOCATION": "New York"}}}'
}
},
{
"predict_keys": True,
"compare_schema_only": False,
"case_insensitive_keys": True,
"correct_answer_key": "correct_answer",
},
0.0,
1.0,
),
],
)
@pytest.mark.asyncio
Expand Down
21 changes: 13 additions & 8 deletions agenta-cli/agenta/sdk/decorators/llm_entrypoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,9 +216,11 @@ async def wrapper(*args, **kwargs) -> Any:
{
"func": func.__name__,
"endpoint": route,
"params": {**config_params, **func_signature.parameters}
if not config
else func_signature.parameters,
"params": (
{**config_params, **func_signature.parameters}
if not config
else func_signature.parameters
),
"config": config,
}
)
Expand All @@ -229,9 +231,11 @@ async def wrapper(*args, **kwargs) -> Any:
{
"func": func.__name__,
"endpoint": route,
"params": {**config_params, **func_signature.parameters}
if not config
else func_signature.parameters,
"params": (
{**config_params, **func_signature.parameters}
if not config
else func_signature.parameters
),
"config": config,
}
)
Expand Down Expand Up @@ -402,15 +406,16 @@ async def execute_function(

# PATCH : if result is not a dict, make it a dict
if not isinstance(result, dict):
data = result
data = str(result)
else:
# PATCH : if result is a legacy dict, clean it up
if (
"message" in result.keys()
and "cost" in result.keys()
and "usage" in result.keys()
):
data = result["message"]
data = str(result["message"])

# END OF PATH

if data is None:
Expand Down
Loading
Loading