Add EvalAnswerMode to HotPotQAEnv #102

Merged: 3 commits, Oct 30, 2024
46 changes: 37 additions & 9 deletions packages/hotpotqa/src/aviary/envs/hotpotqa/env.py
@@ -27,9 +27,17 @@
from pydantic import BaseModel, ConfigDict, Field
from tenacity import retry, stop_after_attempt, wait_exponential_jitter

-from aviary.env import Environment, Frame, TaskDataset
-from aviary.message import Message
-from aviary.tools import Tool, ToolRequestMessage, ToolResponseMessage
+from aviary.core import (
+    Environment,
+    EvalAnswerMode,
+    Frame,
+    Message,
+    TaskDataset,
+    Tool,
+    ToolRequestMessage,
+    ToolResponseMessage,
+    eval_answer,
+)

logger = logging.getLogger(__name__)

@@ -78,6 +86,8 @@ class HotPotQAEnvState(BaseModel):
)
page: str | None = Field(default=None, description="The current Wikipedia page.")

+evaluation_mode: EvalAnswerMode = EvalAnswerMode.CONTAINS


def create_tool(function: Callable, name: str) -> Tool:
"""Create a Tool object from a function and set its name.
@@ -176,6 +186,7 @@ def __init__(
correct_reward: float = 1.0,
incorrect_reward: float = 0.0,
tool_failure_reward: float = 0.0,
+evaluation_mode: EvalAnswerMode = EvalAnswerMode.CONTAINS,
proxy: str | None = None,
):
super().__init__()
@@ -186,6 +197,14 @@ def __init__(
self.incorrect_reward = incorrect_reward
self.tool_failure_reward = tool_failure_reward
self.proxy = proxy
+self.evaluation_mode = evaluation_mode
+
+if evaluation_mode == EvalAnswerMode.LLM_SCORE:
+raise NotImplementedError(
+f'{HotPotQAEnv.__name__} does not support "{evaluation_mode}"'
+" since the environment was built around binary evaluation of the"
+" answer. Further development is needed for this mode."
+)

# Title case tool names to match third party demonstration data
self.tools = [
@@ -198,7 +217,7 @@ def __init__(
def from_task(cls, task: str) -> "HotPotQAEnv":
return cls(question=task, correct_answer=0.0)

-def calculate_reward(self, answer: str | None) -> float:
+async def calculate_answer_reward(self, answer: str | None) -> float:
"""Calculate the reward based on the agent's answer.

Returns:
@@ -207,8 +226,17 @@
"""
if answer is None:
return self.incorrect_reward
-pred, gt = normalize_answer(answer), self.normalized_correct_answer
-return self.correct_reward if pred == gt else self.incorrect_reward
+return (
+self.correct_reward
+if (
+await eval_answer(
+normalize_answer(answer),
+self.normalized_correct_answer,
+self.evaluation_mode,
+)
+)
+else self.incorrect_reward
+)

async def reset(self) -> tuple[list[Message], list[Tool]]:
"""Reset the HotPotQA environment to an initial state.
@@ -331,8 +359,8 @@ def export_frame(self) -> Frame:
}
)

-def finish(self, answer: str) -> str:
-"""Finish the episode.
+async def finish(self, answer: str) -> str:
+"""Finish the task by submitting an answer to the question.

Args:
answer: The answer to the question.
@@ -342,7 +370,7 @@ def finish(self, answer: str) -> str:
return "Finish failed. No answer provided."

self.state.answer = answer
-self.state.reward += self.calculate_reward(answer)
+self.state.reward += await self.calculate_answer_reward(answer)

self.state.last_action_is_lookup = False
return "Finished."
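For reviewers, a minimal usage sketch of the new parameter (the question/answer strings below are illustrative, not from the benchmark; the imports follow the ones added in this diff):

```python
import asyncio

from aviary.core import EvalAnswerMode
from aviary.envs.hotpotqa import HotPotQAEnv


async def main() -> None:
    # Illustrative QA pair; any question/correct_answer strings work the same way
    env = HotPotQAEnv(
        question="In what year did Apollo 11 land on the Moon?",
        correct_answer="1969",
        evaluation_mode=EvalAnswerMode.EXACT,
    )
    # calculate_answer_reward is now async because eval_answer may call an LLM in the
    # LLM mode; EXACT and CONTAINS should remain local string comparisons.
    reward = await env.calculate_answer_reward("1969")
    print(reward)  # expected: correct_reward (1.0 by default) on a normalized exact match


asyncio.run(main())
```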
19 changes: 19 additions & 0 deletions packages/hotpotqa/tests/test_hotpotqa_env.py
@@ -4,6 +4,7 @@

from aviary.core import Environment, TaskDataset
from aviary.envs.hotpotqa import HotPotQAEnv
+from aviary.tools.utils import EvalAnswerMode


def test_env_construction() -> None:
@@ -65,3 +66,21 @@ async def test_tool_results() -> None:

# Ensure that the observations are different
assert obs1 != obs2 != obs3 != obs4 != obs5


+@pytest.mark.parametrize(
+    "evaluation_mode",
+    [EvalAnswerMode.EXACT, EvalAnswerMode.CONTAINS, EvalAnswerMode.LLM],
+)
+@pytest.mark.asyncio
+async def test_answer_evaluation_mode(evaluation_mode: EvalAnswerMode) -> None:
+    correct_answer = "Golden Gate Bridge"
+    incorrect_answer = "Bay Bridge"
+    env = HotPotQAEnv(
+        question="What is the reddest bridge in San Francisco?",
+        correct_answer=correct_answer,
+        evaluation_mode=evaluation_mode,
+    )
+
+    assert (await env.calculate_answer_reward(correct_answer)) == 1
+    assert (await env.calculate_answer_reward(incorrect_answer)) == 0
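As context for the modes parametrized above, the reward path now delegates the string comparison to aviary.core.eval_answer (see the env.py diff). A rough sketch of calling it directly, assuming the signature used there, eval_answer(proposed, correct, mode), and that EXACT/CONTAINS are plain string checks while LLM asks a model to judge:

```python
import asyncio

from aviary.core import EvalAnswerMode, eval_answer


async def demo() -> None:
    exact = await eval_answer("golden gate bridge", "golden gate bridge", EvalAnswerMode.EXACT)
    # Assumption: CONTAINS passes when the correct answer appears inside the proposed one
    contains = await eval_answer(
        "it is the golden gate bridge in san francisco",
        "golden gate bridge",
        EvalAnswerMode.CONTAINS,
    )
    # The env only uses the result as a pass/fail signal, so truthiness is what matters here
    print(bool(exact), bool(contains))


asyncio.run(demo())
```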
2 changes: 2 additions & 0 deletions src/aviary/core.py
@@ -13,6 +13,7 @@
from aviary.render import Renderer
from aviary.tools import (
INVALID_TOOL_NAME,
+EvalAnswerMode,
FunctionInfo,
Messages,
MessagesAdapter,
@@ -41,6 +42,7 @@
"EnvStateMessage",
"Environment",
"EnvironmentClient",
"EvalAnswerMode",
"Frame",
"FunctionInfo",
"MalformedMessageError",
3 changes: 2 additions & 1 deletion src/aviary/tools/__init__.py
@@ -14,10 +14,11 @@
ToolsAdapter,
wraps_doc_only,
)
-from .utils import ToolSelector, ToolSelectorLedger, eval_answer
+from .utils import EvalAnswerMode, ToolSelector, ToolSelectorLedger, eval_answer

__all__ = [
"INVALID_TOOL_NAME",
"EvalAnswerMode",
"FunctionInfo",
"Messages",
"MessagesAdapter",
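One practical effect of the two re-export changes above: EvalAnswerMode should now be importable from aviary.core and aviary.tools as well as the aviary.tools.utils path used by the new test. A quick sanity sketch (assuming the enum is a standard Enum):

```python
from aviary.core import EvalAnswerMode as FromCore
from aviary.tools import EvalAnswerMode as FromTools
from aviary.tools.utils import EvalAnswerMode as FromUtils

# All three paths should resolve to the same class after this PR
assert FromCore is FromTools is FromUtils
print([mode.name for mode in FromCore])  # e.g. EXACT, CONTAINS, LLM, LLM_SCORE
```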