diff --git a/packages/hotpotqa/src/aviary/envs/hotpotqa/env.py b/packages/hotpotqa/src/aviary/envs/hotpotqa/env.py
index f35ee956..a81d89a5 100644
--- a/packages/hotpotqa/src/aviary/envs/hotpotqa/env.py
+++ b/packages/hotpotqa/src/aviary/envs/hotpotqa/env.py
@@ -58,12 +58,15 @@ class HotPotQAEnvState(BaseModel):
     reward: float = Field(
         default=0.0, description="Current reward value, reset each environment step."
     )
-
     answer: str | None = Field(
         default=None,
         description="The answer to the question, or None if not yet answered.",
     )
-
+    last_action_is_lookup: bool = Field(
+        default=False,
+        description="Whether the last action was a lookup action. "
+        "Default is False, as after reset the agent has not yet taken any action.",
+    )
     last_lookup: str | None = Field(
         default=None, description="The last lookup keyword."
     )
@@ -340,6 +343,8 @@ def finish(self, answer: str) -> str:

         self.state.answer = answer
         self.state.reward += self.calculate_reward(answer)
+
+        self.state.last_action_is_lookup = False
         return "Finished."

     async def search(self, entity: str) -> str:
@@ -404,6 +409,7 @@ async def search(self, entity: str) -> str:
             for s in p.split(". ")
             if s.strip()
         ]
+        self.state.last_action_is_lookup = False
         return " ".join(obs_list[:5])

     def construct_lookup_list(self, keyword: str) -> str:
@@ -441,11 +447,11 @@ def construct_lookup_list(self, keyword: str) -> str:
         if not self.state.page:
             return "Lookup failed. You have not specified a Wikipedia page yet."

-        if self.state.last_lookup != keyword:
+        if not self.state.last_action_is_lookup or self.state.last_lookup != keyword:
             self.state.last_lookup = keyword
             self.state.lookup_results = [
-                s.strip() + "."
-                for s in self.state.page.split(". ")
+                s.strip()
+                for s in self.state.page.split("\n")
                 if s.strip() and keyword.lower() in s.lower()
             ]
             self.state.lookup_index = 0
@@ -458,6 +464,7 @@ def construct_lookup_list(self, keyword: str) -> str:
                 f" {self.state.lookup_results[self.state.lookup_index]}"
             )
             self.state.lookup_index += 1
+        self.state.last_action_is_lookup = True
         return obs


diff --git a/packages/hotpotqa/tests/test_hotpotqa_env.py b/packages/hotpotqa/tests/test_hotpotqa_env.py
index 76afcbc1..33cae63c 100644
--- a/packages/hotpotqa/tests/test_hotpotqa_env.py
+++ b/packages/hotpotqa/tests/test_hotpotqa_env.py
@@ -1,3 +1,5 @@
+import re
+
 import pytest

 from aviary.core import Environment, TaskDataset
@@ -27,3 +29,40 @@ def test_dataset_from_name() -> None:

     with pytest.raises(ValueError, match="answer"):
         TaskDataset.from_name("hotpotqa", split="test")
+
+
+@pytest.mark.asyncio
+async def test_tool_results() -> None:
+    """Check the lookup result format and that search/lookup/finish outputs differ."""
+    hotpotqa_env: HotPotQAEnv = Environment.from_name(
+        "hotpotqa",
+        question=("Which country has a larger population: China or France?"),
+        correct_answer="China",
+    )
+    lookup_pattern = r"^\(Result \d+ / \d+\)\s*(.*)"
+
+    _, _ = await hotpotqa_env.reset()
+    obs1 = await hotpotqa_env.search("China")
+    obs2 = hotpotqa_env.construct_lookup_list("population")
+
+    # Check lookup return format
+    match = re.match(lookup_pattern, obs2)
+    assert match  # Starts with the right pattern
+    assert (
+        match.group(1) + "\n" in hotpotqa_env.state.page
+    )  # Everything after the pattern should be a paragraph in current page
+
+    obs3 = await hotpotqa_env.search("France")
+    obs4 = hotpotqa_env.construct_lookup_list("population")
+
+    # Check lookup return format
+    match = re.match(lookup_pattern, obs4)
+    assert match, "Expected lookup"
+    assert match.group(1) + "\n" in hotpotqa_env.state.page, (
+        "Expected text after the match to be a paragraph"
+    )
+
+    obs5 = hotpotqa_env.finish("China")
+
+    # Ensure all observations are pairwise distinct (chained != only checks neighbors)
+    assert len({obs1, obs2, obs3, obs4, obs5}) == 5