Moved eval_answer from tools to core (#142)
Co-authored-by: Mayk Caldas <mayk@futurehouse.org>
maykcaldas and maykcaldas authored Dec 6, 2024
1 parent c51c64b commit 02dc2e2
Showing 5 changed files with 109 additions and 106 deletions.
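
In practical terms, the commit changes where downstream code imports eval_answer and EvalAnswerMode from. A minimal sketch of the before/after import paths, inferred from the diffs below (the old path comes from the removed src/aviary/tools/utils.py code, the new path from src/aviary/utils.py; per the core.py hunk, eval_answer also remains re-exported from aviary.core):

# Old location, removed by this commit:
#   from aviary.tools.utils import EvalAnswerMode, eval_answer

# New location:
from aviary.utils import EvalAnswerMode, eval_answer

# eval_answer is still re-exported at the top level (see the core.py __all__ hunk below):
# from aviary.core import eval_answer
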
2 changes: 1 addition & 1 deletion packages/hotpotqa/tests/test_hotpotqa_env.py
@@ -5,7 +5,7 @@
from aviary.core import Environment, TaskDataset
from aviary.envs.hotpotqa import HotPotQAEnv
from aviary.envs.hotpotqa.env import HotPotQADataset
from aviary.tools.utils import EvalAnswerMode
from aviary.utils import EvalAnswerMode


def test_env_construction() -> None:
11 changes: 8 additions & 3 deletions src/aviary/core.py
@@ -20,7 +20,6 @@
from aviary.render import Renderer
from aviary.tools import (
INVALID_TOOL_NAME,
EvalAnswerMode,
FunctionInfo,
Messages,
MessagesAdapter,
@@ -35,10 +34,15 @@
ToolSelector,
ToolSelectorLedger,
argref_by_name,
eval_answer,
wraps_doc_only,
)
from aviary.utils import encode_image_to_base64, is_coroutine_callable, partial_format
from aviary.utils import (
EvalAnswerMode,
encode_image_to_base64,
eval_answer,
is_coroutine_callable,
partial_format,
)

__all__ = [
"INVALID_TOOL_NAME",
@@ -77,6 +81,7 @@
"argref_by_name",
"encode_image_to_base64",
"eval_answer",
"eval_answer",
"fenv",
"is_coroutine_callable",
"join",
4 changes: 1 addition & 3 deletions src/aviary/tools/__init__.py
@@ -14,11 +14,10 @@
ToolsAdapter,
wraps_doc_only,
)
from .utils import EvalAnswerMode, ToolSelector, ToolSelectorLedger, eval_answer
from .utils import ToolSelector, ToolSelectorLedger

__all__ = [
"INVALID_TOOL_NAME",
"EvalAnswerMode",
"FunctionInfo",
"Messages",
"MessagesAdapter",
@@ -33,6 +32,5 @@
"Tools",
"ToolsAdapter",
"argref_by_name",
"eval_answer",
"wraps_doc_only",
]
98 changes: 0 additions & 98 deletions src/aviary/tools/utils.py
@@ -1,5 +1,4 @@
from collections.abc import Callable
from enum import StrEnum
from functools import partial
from typing import TYPE_CHECKING, Any, ClassVar, cast

@@ -21,103 +20,6 @@
from litellm import ModelResponse


class EvalAnswerMode(StrEnum):
EXACT = "exact" # strings must match exactly
CONTAINS = "contains" # the correct answer is contained in the supplied answer
LLM = "llm" # Ask an LLM to evaluate
LLM_SCORE = "llm-score" # Ask an LLM to evaluate and return the score (normalized)


LLM_EVAL_CONFIG = {
"prompt": (
"Here is a question, the correct answer to the question, and a proposed answer"
" to the question. Please tell me if the proposed answer is correct, given the"
" correct answer. ONLY SAY 'YES' OR 'NO'. No other output is permitted."
"\n\nQuestion: {question}"
"\n\nCorrect answer: {correct_answer}"
"\n\nProposed answer: {proposed_answer}"
),
"model": "gpt-4o-mini",
"temperature": 0,
}

LLM_SCORE_EVAL_CONFIG = {
"prompt": (
"Here is a question, the correct answer to the question, and a rubric for"
" evaluating the question. Judge the proposed answer based on the given rubric."
" Give a score from 0 to 10. No other output is permitted."
"\n\nQuestion: {question}"
"\n\nRubric: {correct_answer}"
"\n\nProposed answer: {proposed_answer}"
),
"model": "gpt-4o-mini",
"temperature": 0,
"max_score": 10,
}


async def eval_answer(
proposed: str,
correct: str,
question: str | None = None,
eval_mode: EvalAnswerMode = EvalAnswerMode.CONTAINS,
llm_eval_config: dict | None = None,
) -> float:
"""Evaluate a proposed answer against a correct answer.
Will return 0 or 1, except for llm-score which should be between 0 and 1
"""
if eval_mode in {EvalAnswerMode.LLM, EvalAnswerMode.LLM_SCORE}:
try:
from litellm import acompletion
except ImportError as e:
raise ImportError(
"eval_answer requires the 'llm' extra for 'litellm'. Please:"
" `pip install aviary[llm]`."
) from e
if question is None:
raise ValueError("Question must be provided for LLM evaluation mode.")
default_config = (
LLM_EVAL_CONFIG
if eval_mode == EvalAnswerMode.LLM
else LLM_SCORE_EVAL_CONFIG
)
config = llm_eval_config or default_config
prompt = cast(str, config.get("prompt", default_config["prompt"])).format(
question=question,
correct_answer=correct,
proposed_answer=proposed,
)
response = await acompletion(
model=config.get("model", default_config["model"]),
temperature=config.get("temperature", default_config["temperature"]),
messages=[{"content": prompt, "role": "user"}],
)
if eval_mode == EvalAnswerMode.LLM:
return await eval_answer(
response.choices[0].message.content.strip().casefold(),
"yes",
eval_mode=EvalAnswerMode.EXACT,
)
try:
return float(response.choices[0].message.content.strip()) / float(
config.get("max_score", default_config["max_score"]) # type: ignore[arg-type]
)
except ValueError:
return 0

gt = correct.strip().casefold()
pred = proposed.strip().casefold()

if eval_mode == EvalAnswerMode.EXACT:
return float(pred == gt)

if eval_mode == EvalAnswerMode.CONTAINS:
return float(gt in pred)

raise RuntimeError(f"Invalid evaluation mode: {eval_mode}")


class ToolSelector:
"""Simple entity to select a tool based on messages."""

100 changes: 99 additions & 1 deletion src/aviary/utils.py
@@ -2,12 +2,48 @@
import contextlib
import inspect
import io
from typing import TYPE_CHECKING, Any
from enum import StrEnum
from typing import TYPE_CHECKING, Any, cast

if TYPE_CHECKING:
import numpy as np


LLM_EVAL_CONFIG = {
"prompt": (
"Here is a question, the correct answer to the question, and a proposed answer"
" to the question. Please tell me if the proposed answer is correct, given the"
" correct answer. ONLY SAY 'YES' OR 'NO'. No other output is permitted."
"\n\nQuestion: {question}"
"\n\nCorrect answer: {correct_answer}"
"\n\nProposed answer: {proposed_answer}"
),
"model": "gpt-4o-mini",
"temperature": 0,
}

LLM_SCORE_EVAL_CONFIG = {
"prompt": (
"Here is a question, the correct answer to the question, and a rubric for"
" evaluating the question. Judge the proposed answer based on the given rubric."
" Give a score from 0 to 10. No other output is permitted."
"\n\nQuestion: {question}"
"\n\nRubric: {correct_answer}"
"\n\nProposed answer: {proposed_answer}"
),
"model": "gpt-4o-mini",
"temperature": 0,
"max_score": 10,
}


class EvalAnswerMode(StrEnum):
EXACT = "exact" # strings must match exactly
CONTAINS = "contains" # the correct answer is contained in the supplied answer
LLM = "llm" # Ask an LLM to evaluate
LLM_SCORE = "llm-score" # Ask an LLM to evaluate and return the score (normalized)


def partial_format(value: str, **formats: dict[str, Any]) -> str:
"""Partially format a string given a variable amount of formats."""
for template_key, template_value in formats.items():
@@ -41,3 +77,65 @@ def is_coroutine_callable(obj) -> bool:
if callable(obj):
return inspect.iscoroutinefunction(obj.__call__)
return False


async def eval_answer(
proposed: str,
correct: str,
question: str | None = None,
eval_mode: EvalAnswerMode = EvalAnswerMode.CONTAINS,
llm_eval_config: dict | None = None,
) -> float:
"""Evaluate a proposed answer against a correct answer.
Will return 0 or 1, except for llm-score which should be between 0 and 1
"""
if eval_mode in {EvalAnswerMode.LLM, EvalAnswerMode.LLM_SCORE}:
try:
from litellm import acompletion
except ImportError as e:
raise ImportError(
"eval_answer requires the 'llm' extra for 'litellm'. Please:"
" `pip install aviary[llm]`."
) from e
if question is None:
raise ValueError("Question must be provided for LLM evaluation mode.")
default_config = (
LLM_EVAL_CONFIG
if eval_mode == EvalAnswerMode.LLM
else LLM_SCORE_EVAL_CONFIG
)
config = llm_eval_config or default_config
prompt = cast(str, config.get("prompt", default_config["prompt"])).format(
question=question,
correct_answer=correct,
proposed_answer=proposed,
)
response = await acompletion(
model=config.get("model", default_config["model"]),
temperature=config.get("temperature", default_config["temperature"]),
messages=[{"content": prompt, "role": "user"}],
)
if eval_mode == EvalAnswerMode.LLM:
return await eval_answer(
response.choices[0].message.content.strip().casefold(),
"yes",
eval_mode=EvalAnswerMode.EXACT,
)
try:
return float(response.choices[0].message.content.strip()) / float(
config.get("max_score", default_config["max_score"]) # type: ignore[arg-type]
)
except ValueError:
return 0

gt = correct.strip().casefold()
pred = proposed.strip().casefold()

if eval_mode == EvalAnswerMode.EXACT:
return float(pred == gt)

if eval_mode == EvalAnswerMode.CONTAINS:
return float(gt in pred)

raise RuntimeError(f"Invalid evaluation mode: {eval_mode}")
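
For context only (not part of the commit), a minimal usage sketch of the relocated helper under its new import path; the strings below are illustrative. The EXACT and CONTAINS modes run entirely locally, while the LLM and LLM_SCORE modes additionally require the llm extra (litellm) and a question argument, as the implementation above enforces:

import asyncio

from aviary.utils import EvalAnswerMode, eval_answer


async def main() -> None:
    # CONTAINS (the default): 1.0, because "paris" appears in the proposed answer
    # after strip()/casefold().
    print(await eval_answer("The capital is Paris.", "Paris"))

    # EXACT: 0.0, because "paris." != "paris" even after normalization.
    print(await eval_answer("Paris.", "Paris", eval_mode=EvalAnswerMode.EXACT))

    # LLM mode (requires `pip install aviary[llm]` and an API key); any keys missing
    # from llm_eval_config fall back to LLM_EVAL_CONFIG above.
    # score = await eval_answer(
    #     "Paris",
    #     "Paris",
    #     question="What is the capital of France?",
    #     eval_mode=EvalAnswerMode.LLM,
    #     llm_eval_config={"model": "gpt-4o-mini", "temperature": 0},
    # )


asyncio.run(main())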
