Merge pull request #2251 from Agenta-AI/feature/age-1285-inline-traces-update-sdk-to-return-baseresponse-version30

[Enhancement] Update SDK to return BaseResponse version 3.0 and new trace tree format
aybruhm authored Nov 19, 2024
2 parents 0f6dc16 + 438b87f commit df10790
Showing 58 changed files with 2,355 additions and 292 deletions.
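The diff keys its behavior off a "version" field in the SDK response: "2.0" payloads carry a flat "trace" object with latency and cost, while "3.0" payloads carry the new trace tree under "tree". The exact payload schemas are not part of this diff; the sketch below is illustrative only, inferred from the fields the changed code reads.

    # Illustrative only: approximate response shapes inferred from the fields
    # this commit reads; actual SDK payloads may carry additional fields.

    response_v2 = {
        "version": "2.0",
        "data": {"message": "..."},
        "trace": {"latency": 1.23, "cost": 0.0004},  # flat trace summary
    }

    response_v3 = {
        "version": "3.0",
        "data": {"message": "..."},
        "tree": [  # list of root spans in the new trace tree format
            {
                "time": {"span": 1.23},  # presumably seconds, given the 1_000_000 factor below
                "metrics": {"acc": {"costs": {"total": 0.0004}}},
            }
        ],
    }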
20 changes: 18 additions & 2 deletions agenta-backend/agenta_backend/services/evaluators_service.py
@@ -65,9 +65,25 @@ async def map(
"""

mapping_outputs = {}
trace = process_distributed_trace_into_trace_tree(mapping_input.inputs["trace"])
mapping_inputs = mapping_input.inputs
response_version = mapping_input.inputs.get("version")

trace = process_distributed_trace_into_trace_tree(
trace=(
mapping_inputs["tree"]
if response_version == "3.0"
else mapping_inputs["trace"]
if response_version == "2.0"
else {}
),
version=mapping_input.inputs.get("version"),
)
for to_key, from_key in mapping_input.mapping.items():
mapping_outputs[to_key] = get_field_value_from_trace_tree(trace, from_key)
mapping_outputs[to_key] = get_field_value_from_trace_tree(
trace,
from_key,
version=mapping_input.inputs.get("version"),
)
return {"outputs": mapping_outputs}


92 changes: 79 additions & 13 deletions agenta-backend/agenta_backend/services/llm_apps_service.py
@@ -14,26 +14,92 @@
 logger.setLevel(logging.DEBUG)


-def extract_result_from_response(response):
+def extract_result_from_response(response: dict):
+    def get_nested_value(d: dict, keys: list, default=None):
+        """
+        Helper function to safely retrieve nested values.
+        """
+        try:
+            for key in keys:
+                if isinstance(d, dict):
+                    d = d.get(key, default)
+                else:
+                    return default
+            return d
+        except Exception as e:
+            print(f"Error accessing nested value: {e}")
+            return default
+
+    # Initialize default values
     value = None
     latency = None
     cost = None

-    if response.get("version", None) == "2.0":
-        value = response
+    try:
+        # Validate input
+        if not isinstance(response, dict):
+            raise ValueError("The response must be a dictionary.")
+
+        # Handle version 3.0 response
+        if response.get("version") == "3.0":
+            value = response
+            # Ensure 'data' is a dictionary or convert it to a string
+            if not isinstance(value.get("data"), dict):
+                value["data"] = str(value.get("data"))
+
+            if "tree" in response:
+                trace_tree = (
+                    response["tree"][0]
+                    if isinstance(response.get("tree"), list)
+                    else {}
+                )
+                latency = (
+                    get_nested_value(trace_tree, ["time", "span"]) * 1_000_000
+                    if trace_tree
+                    else None
+                )
+                cost = get_nested_value(
+                    trace_tree, ["metrics", "acc", "costs", "total"]
+                )

-        if not isinstance(value["data"], dict):
-            value["data"] = str(value["data"])
+        # Handle version 2.0 response
+        elif response.get("version") == "2.0":
+            value = response
+            if not isinstance(value.get("data"), dict):
+                value["data"] = str(value.get("data"))

-        if "trace" in response:
-            latency = response["trace"].get("latency", None)
-            cost = response["trace"].get("cost", None)
-    else:
-        value = {"data": str(response["message"])}
-        latency = response.get("latency", None)
-        cost = response.get("cost", None)
+            if "trace" in response:
+                latency = response["trace"].get("latency")
+                cost = response["trace"].get("cost")

-    kind = "text" if isinstance(value, str) else "object"
+        # Handle generic response (neither 2.0 nor 3.0)
+        else:
+            value = {"data": str(response.get("message", ""))}
+            latency = response.get("latency")
+            cost = response.get("cost")
+
+        # Determine the type of 'value' (either 'text' or 'object')
+        kind = "text" if isinstance(value, str) else "object"
+
+    except ValueError as ve:
+        print(f"Input validation error: {ve}")
+        value = {"error": str(ve)}
+        kind = "error"
+
+    except KeyError as ke:
+        print(f"Missing key: {ke}")
+        value = {"error": f"Missing key: {ke}"}
+        kind = "error"
+
+    except TypeError as te:
+        print(f"Type error: {te}")
+        value = {"error": f"Type error: {te}"}
+        kind = "error"
+
+    except Exception as e:
+        print(f"Unexpected error: {e}")
+        value = {"error": f"Unexpected error: {e}"}
+        kind = "error"

     return value, kind, cost, latency
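A hedged usage sketch of the new extractor, reusing the illustrative v3.0 payload shape from the top of this page; the expected tuple is given as comments, derived by tracing the code above:

    from agenta_backend.services.llm_apps_service import extract_result_from_response

    # Hypothetical payload, assuming the shape inferred earlier.
    response = {
        "version": "3.0",
        "data": {"message": "hello"},
        "tree": [
            {
                "time": {"span": 1.5},  # presumably seconds, given the 1_000_000 factor
                "metrics": {"acc": {"costs": {"total": 0.0004}}},
            }
        ],
    }
    value, kind, cost, latency = extract_result_from_response(response)
    # Expected per the code above: kind == "object", cost == 0.0004,
    # latency == 1.5 * 1_000_000, and value["data"] untouched since it is a dict.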

101 changes: 100 additions & 1 deletion agenta-backend/agenta_backend/tests/unit/test_evaluators.py
@@ -1,7 +1,10 @@
 import os
 import pytest

-from agenta_backend.tests.unit.test_traces import simple_rag_trace
+from agenta_backend.tests.unit.test_traces import (
+    simple_rag_trace,
+    simple_rag_trace_for_baseresponse_v3,
+)
 from agenta_backend.services.evaluators_service import (
     auto_levenshtein_distance,
     auto_ai_critique,
@@ -535,3 +538,99 @@ async def test_rag_context_relevancy_evaluator(
     # - raised by evaluator (agenta) -> TypeError
     assert not isinstance(result.value, float) or not isinstance(result.value, int)
     assert result.error.message == "Error during RAG Context Relevancy evaluation"
+
+
+@pytest.mark.parametrize(
+    "settings_values, openai_api_key, expected_min, expected_max",
+    [
+        (
+            {
+                "question_key": "rag.retriever.internals.prompt",
+                "answer_key": "rag.reporter.outputs.report",
+                "contexts_key": "rag.retriever.outputs.movies",
+            },
+            os.environ.get("OPENAI_API_KEY"),
+            0.0,
+            1.0,
+        ),
+        (
+            {
+                "question_key": "rag.retriever.internals.prompt",
+                "answer_key": "rag.reporter.outputs.report",
+                "contexts_key": "rag.retriever.outputs.movies",
+            },
+            None,
+            None,
+            None,
+        ),
+        # add more use cases
+    ],
+)
+@pytest.mark.asyncio
+async def test_rag_faithfulness_evaluator_for_baseresponse_v3(
+    settings_values, openai_api_key, expected_min, expected_max
+):
+    result = await rag_faithfulness(
+        {},
+        simple_rag_trace_for_baseresponse_v3,
+        {},
+        {},
+        settings_values,
+        {"OPENAI_API_KEY": openai_api_key},
+    )
+
+    try:
+        assert expected_min <= round(result.value, 1) <= expected_max
+    except TypeError:
+        # exceptions
+        # - raised by evaluator (agenta) -> TypeError
+        assert not isinstance(result.value, float) or not isinstance(result.value, int)
+
+
+@pytest.mark.parametrize(
+    "settings_values, openai_api_key, expected_min, expected_max",
+    [
+        (
+            {
+                "question_key": "rag.retriever.internals.prompt",
+                "answer_key": "rag.reporter.outputs.report",
+                "contexts_key": "rag.retriever.outputs.movies",
+            },
+            os.environ.get("OPENAI_API_KEY"),
+            0.0,
+            1.0,
+        ),
+        (
+            {
+                "question_key": "rag.retriever.internals.prompt",
+                "answer_key": "rag.reporter.outputs.report",
+                "contexts_key": "rag.retriever.outputs.movies",
+            },
+            None,
+            None,
+            None,
+        ),
+        # add more use cases
+    ],
+)
+@pytest.mark.asyncio
+async def test_rag_context_relevancy_evaluator_for_baseresponse_v3(
+    settings_values, openai_api_key, expected_min, expected_max
+):
+    result = await rag_context_relevancy(
+        {},
+        simple_rag_trace_for_baseresponse_v3,
+        {},
+        {},
+        settings_values,
+        {"OPENAI_API_KEY": openai_api_key},
+    )
+
+    try:
+        assert expected_min <= round(result.value, 1) <= expected_max
+    except TypeError:
+        # exceptions
+        # - raised by autoevals -> ValueError (caught already and then passed as a stacktrace to the result)
+        # - raised by evaluator (agenta) -> TypeError
+        assert not isinstance(result.value, float) or not isinstance(result.value, int)
+        assert result.error.message == "Error during RAG Context Relevancy evaluation"
