From 356db3e74df4067dd4dde818cf606005bc2c2e03 Mon Sep 17 00:00:00 2001 From: albertbou92 Date: Sun, 27 Oct 2024 14:51:15 -0700 Subject: [PATCH 1/7] increase hotpotqa lookup range to a full paragraph --- packages/hotpotqa/src/aviary/hotpotqa/env.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/packages/hotpotqa/src/aviary/hotpotqa/env.py b/packages/hotpotqa/src/aviary/hotpotqa/env.py index 8e58275c..23c625df 100644 --- a/packages/hotpotqa/src/aviary/hotpotqa/env.py +++ b/packages/hotpotqa/src/aviary/hotpotqa/env.py @@ -445,7 +445,7 @@ def construct_lookup_list(self, keyword: str) -> str: self.state.last_lookup = keyword self.state.lookup_results = [ s.strip() + "." - for s in self.state.page.split(". ") + for s in self.state.page.split("\n") if s.strip() and keyword.lower() in s.lower() ] self.state.lookup_index = 0 From 5b997a4f0149ed0ed154f52cbff591be12eb48b3 Mon Sep 17 00:00:00 2001 From: albertbou92 Date: Mon, 28 Oct 2024 12:13:00 -0700 Subject: [PATCH 2/7] lookup bug fix plus test --- .../hotpotqa/src/aviary/envs/hotpotqa/env.py | 11 ++++-- packages/hotpotqa/tests/test_hotpotqa_env.py | 38 +++++++++++++++++++ 2 files changed, 46 insertions(+), 3 deletions(-) diff --git a/packages/hotpotqa/src/aviary/envs/hotpotqa/env.py b/packages/hotpotqa/src/aviary/envs/hotpotqa/env.py index 5689d4c2..e2b935ec 100644 --- a/packages/hotpotqa/src/aviary/envs/hotpotqa/env.py +++ b/packages/hotpotqa/src/aviary/envs/hotpotqa/env.py @@ -63,7 +63,9 @@ class HotPotQAEnvState(BaseModel): default=None, description="The answer to the question, or None if not yet answered.", ) - + last_action: str | None = Field( + default=None, description="The last action taken by the agent." + ) last_lookup: str | None = Field( default=None, description="The last lookup keyword." ) @@ -340,6 +342,7 @@ def finish(self, answer: str) -> str: self.state.answer = answer self.state.reward += self.calculate_reward(answer) + self.state.last_action = "Finish" return "Finished." async def search(self, entity: str) -> str: @@ -404,6 +407,7 @@ async def search(self, entity: str) -> str: for s in p.split(". ") if s.strip() ] + self.state.last_action = "Search" return " ".join(obs_list[:5]) def construct_lookup_list(self, keyword: str) -> str: @@ -441,10 +445,10 @@ def construct_lookup_list(self, keyword: str) -> str: if not self.state.page: return "Lookup failed. You have not specified a Wikipedia page yet." - if self.state.last_lookup != keyword: + if self.state.last_action != "Lookup" or self.state.last_lookup != keyword: self.state.last_lookup = keyword self.state.lookup_results = [ - s.strip() + "." + s.strip() for s in self.state.page.split("\n") if s.strip() and keyword.lower() in s.lower() ] @@ -458,6 +462,7 @@ def construct_lookup_list(self, keyword: str) -> str: f" {self.state.lookup_results[self.state.lookup_index]}" ) self.state.lookup_index += 1 + self.state.last_action = "Lookup" return obs diff --git a/packages/hotpotqa/tests/test_hotpotqa_env.py b/packages/hotpotqa/tests/test_hotpotqa_env.py index 76afcbc1..e431e847 100644 --- a/packages/hotpotqa/tests/test_hotpotqa_env.py +++ b/packages/hotpotqa/tests/test_hotpotqa_env.py @@ -1,3 +1,5 @@ +import re + import pytest from aviary.core import Environment, TaskDataset @@ -27,3 +29,39 @@ def test_dataset_from_name() -> None: with pytest.raises(ValueError, match="answer"): TaskDataset.from_name("hotpotqa", split="test") + + +@pytest.mark.asyncio +async def test_tool_results() -> None: + hotpotqa_env: HotPotQAEnv = Environment.from_name( + "hotpotqa", + question=("Which country has a larger population: China or France?"), + correct_answer="China", + ) + lookup_pattern = r"^\(Result \d+ / \d+\)\s*(.*)" + + _, _ = await hotpotqa_env.reset() + obs1 = await hotpotqa_env.search("China") + obs2 = hotpotqa_env.construct_lookup_list("population") + + # Check lookup return format + match = re.match(lookup_pattern, obs2) + assert match # Starts with the right pattern + assert ( + match.group(1) + "\n" in hotpotqa_env.state.page + ) # Everything after the pattern should be a paragraph in current page + + obs3 = await hotpotqa_env.search("France") + obs4 = hotpotqa_env.construct_lookup_list("population") + + # Check lookup return format + match = re.match(lookup_pattern, obs4) + assert match # Starts with the right pattern + assert ( + match.group(1) + "\n" in hotpotqa_env.state.page + ) # Everything after the pattern should be a paragraph in current page + + obs5 = hotpotqa_env.finish("China") + + # Ensure that the observations are different + assert obs1 != obs2 != obs3 != obs4 != obs5 From 4fb7bcb9aadf60713380ee7360e785ca0619c088 Mon Sep 17 00:00:00 2001 From: albertbou92 Date: Mon, 28 Oct 2024 12:22:53 -0700 Subject: [PATCH 3/7] format --- packages/hotpotqa/src/aviary/envs/hotpotqa/env.py | 1 - 1 file changed, 1 deletion(-) diff --git a/packages/hotpotqa/src/aviary/envs/hotpotqa/env.py b/packages/hotpotqa/src/aviary/envs/hotpotqa/env.py index e2b935ec..85786ff7 100644 --- a/packages/hotpotqa/src/aviary/envs/hotpotqa/env.py +++ b/packages/hotpotqa/src/aviary/envs/hotpotqa/env.py @@ -58,7 +58,6 @@ class HotPotQAEnvState(BaseModel): reward: float = Field( default=0.0, description="Current reward value, reset each environment step." ) - answer: str | None = Field( default=None, description="The answer to the question, or None if not yet answered.", From cff751339c1321e04ea3e24ebd7c3597b43cfc16 Mon Sep 17 00:00:00 2001 From: albertbou92 Date: Mon, 28 Oct 2024 13:43:52 -0700 Subject: [PATCH 4/7] fixes --- .../hotpotqa/src/aviary/envs/hotpotqa/env.py | 16 +++++++++++----- 1 file changed, 11 insertions(+), 5 deletions(-) diff --git a/packages/hotpotqa/src/aviary/envs/hotpotqa/env.py b/packages/hotpotqa/src/aviary/envs/hotpotqa/env.py index 85786ff7..be7e2916 100644 --- a/packages/hotpotqa/src/aviary/envs/hotpotqa/env.py +++ b/packages/hotpotqa/src/aviary/envs/hotpotqa/env.py @@ -63,7 +63,9 @@ class HotPotQAEnvState(BaseModel): description="The answer to the question, or None if not yet answered.", ) last_action: str | None = Field( - default=None, description="The last action taken by the agent." + default=None, + description="The last action taken by the agent." + "Default is None, as after reset the agent has not yet taken any action.", ) last_lookup: str | None = Field( default=None, description="The last lookup keyword." @@ -341,7 +343,8 @@ def finish(self, answer: str) -> str: self.state.answer = answer self.state.reward += self.calculate_reward(answer) - self.state.last_action = "Finish" + + self.state.last_action = self.tools[2].info.name return "Finished." async def search(self, entity: str) -> str: @@ -406,7 +409,7 @@ async def search(self, entity: str) -> str: for s in p.split(". ") if s.strip() ] - self.state.last_action = "Search" + self.state.last_action = self.tools[0].info.name return " ".join(obs_list[:5]) def construct_lookup_list(self, keyword: str) -> str: @@ -444,7 +447,10 @@ def construct_lookup_list(self, keyword: str) -> str: if not self.state.page: return "Lookup failed. You have not specified a Wikipedia page yet." - if self.state.last_action != "Lookup" or self.state.last_lookup != keyword: + if ( + self.state.last_action != self.tools[1].info.name + or self.state.last_lookup != keyword + ): self.state.last_lookup = keyword self.state.lookup_results = [ s.strip() @@ -461,7 +467,7 @@ def construct_lookup_list(self, keyword: str) -> str: f" {self.state.lookup_results[self.state.lookup_index]}" ) self.state.lookup_index += 1 - self.state.last_action = "Lookup" + self.state.last_action = self.tools[1].info.name return obs From 43fb054f27bdae583208de9d1a955bfb455a2f7a Mon Sep 17 00:00:00 2001 From: Albert Bou Date: Mon, 28 Oct 2024 15:42:49 -0700 Subject: [PATCH 5/7] Update packages/hotpotqa/tests/test_hotpotqa_env.py Co-authored-by: James Braza --- packages/hotpotqa/tests/test_hotpotqa_env.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/packages/hotpotqa/tests/test_hotpotqa_env.py b/packages/hotpotqa/tests/test_hotpotqa_env.py index e431e847..be5f7033 100644 --- a/packages/hotpotqa/tests/test_hotpotqa_env.py +++ b/packages/hotpotqa/tests/test_hotpotqa_env.py @@ -56,10 +56,10 @@ async def test_tool_results() -> None: # Check lookup return format match = re.match(lookup_pattern, obs4) - assert match # Starts with the right pattern + assert match, "Expected lookup" assert ( match.group(1) + "\n" in hotpotqa_env.state.page - ) # Everything after the pattern should be a paragraph in current page + ), "Expected text after the match to be a paragraph" obs5 = hotpotqa_env.finish("China") From 04a31c0820145fc3483a93fadc2d12a3d35c8dfe Mon Sep 17 00:00:00 2001 From: albertbou92 Date: Mon, 28 Oct 2024 15:52:45 -0700 Subject: [PATCH 6/7] fixes --- .../hotpotqa/src/aviary/envs/hotpotqa/env.py | 19 ++++++++----------- 1 file changed, 8 insertions(+), 11 deletions(-) diff --git a/packages/hotpotqa/src/aviary/envs/hotpotqa/env.py b/packages/hotpotqa/src/aviary/envs/hotpotqa/env.py index be7e2916..a81d89a5 100644 --- a/packages/hotpotqa/src/aviary/envs/hotpotqa/env.py +++ b/packages/hotpotqa/src/aviary/envs/hotpotqa/env.py @@ -62,10 +62,10 @@ class HotPotQAEnvState(BaseModel): default=None, description="The answer to the question, or None if not yet answered.", ) - last_action: str | None = Field( - default=None, - description="The last action taken by the agent." - "Default is None, as after reset the agent has not yet taken any action.", + last_action_is_lookup: bool = Field( + default=False, + description="Whether the last action was a lookup action." + "Default is False, as after reset the agent has not yet taken any action.", ) last_lookup: str | None = Field( default=None, description="The last lookup keyword." @@ -344,7 +344,7 @@ def finish(self, answer: str) -> str: self.state.answer = answer self.state.reward += self.calculate_reward(answer) - self.state.last_action = self.tools[2].info.name + self.state.last_action_is_lookup = False return "Finished." async def search(self, entity: str) -> str: @@ -409,7 +409,7 @@ async def search(self, entity: str) -> str: for s in p.split(". ") if s.strip() ] - self.state.last_action = self.tools[0].info.name + self.state.last_action_is_lookup = False return " ".join(obs_list[:5]) def construct_lookup_list(self, keyword: str) -> str: @@ -447,10 +447,7 @@ def construct_lookup_list(self, keyword: str) -> str: if not self.state.page: return "Lookup failed. You have not specified a Wikipedia page yet." - if ( - self.state.last_action != self.tools[1].info.name - or self.state.last_lookup != keyword - ): + if not self.state.last_action_is_lookup or self.state.last_lookup != keyword: self.state.last_lookup = keyword self.state.lookup_results = [ s.strip() @@ -467,7 +464,7 @@ def construct_lookup_list(self, keyword: str) -> str: f" {self.state.lookup_results[self.state.lookup_index]}" ) self.state.lookup_index += 1 - self.state.last_action = self.tools[1].info.name + self.state.last_action_is_lookup = True return obs From 3274129d77908739b8a4502cf5f60cf23bb826b6 Mon Sep 17 00:00:00 2001 From: albertbou92 Date: Mon, 28 Oct 2024 15:53:47 -0700 Subject: [PATCH 7/7] format --- packages/hotpotqa/tests/test_hotpotqa_env.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/packages/hotpotqa/tests/test_hotpotqa_env.py b/packages/hotpotqa/tests/test_hotpotqa_env.py index be5f7033..33cae63c 100644 --- a/packages/hotpotqa/tests/test_hotpotqa_env.py +++ b/packages/hotpotqa/tests/test_hotpotqa_env.py @@ -57,9 +57,9 @@ async def test_tool_results() -> None: # Check lookup return format match = re.match(lookup_pattern, obs4) assert match, "Expected lookup" - assert ( - match.group(1) + "\n" in hotpotqa_env.state.page - ), "Expected text after the match to be a paragraph" + assert match.group(1) + "\n" in hotpotqa_env.state.page, ( + "Expected text after the match to be a paragraph" + ) obs5 = hotpotqa_env.finish("China")