Exposed seeding of LitQA2 read and shuffling (#758)
jamesbraza authored Dec 10, 2024
1 parent e3623ed commit 58dbfc0
Showing 4 changed files with 61 additions and 14 deletions.
12 changes: 8 additions & 4 deletions paperqa/agents/task.py
@@ -196,6 +196,7 @@ def __init__(
base_query: QueryRequest | dict | None = None,
base_docs: Docs | dict | None = None,
rewards: Mapping[str, float] = DEFAULT_REWARD_MAPPING,
question_kwargs: Mapping[str, Any] | None = None,
eval_model: LLMModel | str = DEFAULT_EVAL_MODEL_NAME,
**env_kwargs,
):
@@ -210,23 +211,23 @@ def __init__(
base_docs = Docs(**base_docs)
self._base_docs = base_docs
self._rewards = rewards
self._env_kwargs = env_kwargs
self._question_kwargs = question_kwargs
self._eval_model = eval_model
self._env_kwargs = env_kwargs

def _make_gradable_environment(
self,
ideal: str,
distractors: str | list[str],
question: str,
use_unsure: bool = True,
sources: str | list[str] | None = None,
) -> GradablePaperQAEnvironment:
qa_prompt, evaluation_from_answer = LitQAEvaluation.from_question(
ideal=ideal,
distractors=distractors,
question=question,
use_unsure=use_unsure,
eval_model=self._eval_model,
**(self._question_kwargs or {}),
)
query = self._base_query.model_copy()
query.query = qa_prompt
@@ -305,11 +306,14 @@ def __init__(
self,
*args,
labbench_dataset: str = DEFAULT_LABBENCH_HF_HUB_NAME,
read_data_kwargs: Mapping[str, Any] | None = None,
split: str | LitQAv2TaskSplit = LitQAv2TaskSplit.EVAL,
**kwargs,
):
super().__init__(*args, **kwargs)
train_df, eval_df = read_litqa_v2_from_hub(labbench_dataset)
train_df, eval_df = read_litqa_v2_from_hub(
labbench_dataset, **(read_data_kwargs or {})
)
split = LitQAv2TaskSplit(split)
if split == LitQAv2TaskSplit.TRAIN:
self.data = train_df
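For context, a rough usage sketch (not part of the commit) of the two pass-through mappings exposed above. The LitQAv2TaskDataset import path is assumed from the test imports further down, and base_query is omitted here, relying on its None default.

from paperqa.agents.task import LitQAv2TaskDataset, LitQAv2TaskSplit  # assumed import path
from paperqa.litqa import SEED_USING_QUESTION

# Note: constructing the dataset reads LitQA2 from Hugging Face Hub.
dataset = LitQAv2TaskDataset(
    # Forwarded to read_litqa_v2_from_hub, e.g. to seed its shuffling of the read data
    read_data_kwargs={"seed": 42},
    # Forwarded to LitQAEvaluation.from_question, seeding the multiple-choice letter shuffle
    question_kwargs={"seed": SEED_USING_QUESTION},
    split=LitQAv2TaskSplit.EVAL,
)
env = dataset.get_new_env_by_idx(0)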
9 changes: 7 additions & 2 deletions paperqa/litqa.py
@@ -8,7 +8,7 @@
from ast import literal_eval
from collections.abc import Awaitable, Callable, Mapping, Sequence
from enum import StrEnum
from typing import TYPE_CHECKING, Self
from typing import TYPE_CHECKING, Literal, Self

try:
from ldp.utils import discounted_returns
@@ -92,6 +92,7 @@ def make_mc_options(

DEFAULT_EVAL_MODEL_NAME = "gpt-4-turbo-2024-04-09"
DEFAULT_REWARD_MAPPING = {"correct": 1.0, "unsure": 0.1, "incorrect": -1.0}
SEED_USING_QUESTION: Literal["SEED_USING_QUESTION"] = "SEED_USING_QUESTION" # Sentinel


class LitQAEvaluation(StrEnum):
@@ -161,7 +162,7 @@ def from_question(
question: str,
use_unsure: bool = True,
eval_model: LLMModel | str = DEFAULT_EVAL_MODEL_NAME,
seed: int | None = None,
seed: int | Literal["SEED_USING_QUESTION"] | None = None,
) -> tuple[str, Callable[[PQASession | str], Awaitable[LitQAEvaluation]]]:
"""
Create a LitQA question and an answer-to-evaluation function.
@@ -174,11 +175,15 @@
eval_model: Evaluation model to use for multiple choice letter extraction
from a text answer.
seed: Optional seed to use in randomization of multiple choice letters.
Optionally pass in the string literal "SEED_USING_QUESTION" to hash the
input question for the seed.
Returns:
Two-tuple of created LitQA question, function (that can be thought of as
stateless) to use to extract an evaluation result from an answer.
"""
if seed == SEED_USING_QUESTION:
seed = hash(question)
text, ideal_answer, unsure_answer, distractor_answers = make_mc_options(
ideal=ideal,
distractors=distractors,
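A minimal sketch of the new SEED_USING_QUESTION sentinel used on its own; the question and answer strings below are placeholders for illustration, not taken from LitQA2.

from paperqa.litqa import SEED_USING_QUESTION, LitQAEvaluation

qa_prompt, grade_answer = LitQAEvaluation.from_question(
    ideal="42",
    distractors=["41", "43", "a fish"],
    question="What is the meaning of life?",
    # Hash the question text to seed the multiple-choice letter shuffle, so each
    # question gets its own deterministic option ordering.
    seed=SEED_USING_QUESTION,
)
print(qa_prompt)  # same question -> same option ordering within this process

Since Python salts its built-in string hashing per interpreter process, this ordering is reproducible within a run but not necessarily across runs unless PYTHONHASHSEED is fixed.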
34 changes: 28 additions & 6 deletions tests/test_litqa.py
@@ -3,7 +3,7 @@

import pytest

from paperqa.litqa import LitQAEvaluation, read_litqa_v2_from_hub
from paperqa.litqa import SEED_USING_QUESTION, LitQAEvaluation, read_litqa_v2_from_hub
from tests.conftest import VCR_DEFAULT_MATCH_ON


@@ -140,16 +140,38 @@ def test_consistent_mc_options(self) -> None:
"""Tests that creating multiple evaluations with the same seed results in the same prompt."""
question, ideal, distractors = self.MEANING_OF_LIFE_QUESTION_IDEAL_DISTRACTORS

qa_prompt_1, _ = LitQAEvaluation.from_question(
qa_prompt_1a, _ = LitQAEvaluation.from_question(
ideal=ideal, distractors=distractors, question=question, seed=0
)
self._assert_prompt_is_valid(qa_prompt_1, question, ideal, distractors)
self._assert_prompt_is_valid(qa_prompt_1a, question, ideal, distractors)

qa_prompt_2, _ = LitQAEvaluation.from_question(
qa_prompt_1b, _ = LitQAEvaluation.from_question(
ideal=ideal, distractors=distractors, question=question, seed=0
)
self._assert_prompt_is_valid(qa_prompt_1, question, ideal, distractors)
assert qa_prompt_1 == qa_prompt_2
self._assert_prompt_is_valid(qa_prompt_1b, question, ideal, distractors)
assert qa_prompt_1a == qa_prompt_1b, "Same seeding should lead to same prompts"

qa_prompt_2a, _ = LitQAEvaluation.from_question(
ideal=ideal,
distractors=distractors,
question=question,
seed=SEED_USING_QUESTION,
)
self._assert_prompt_is_valid(qa_prompt_2a, question, ideal, distractors)

qa_prompt_2b, _ = LitQAEvaluation.from_question(
ideal=ideal,
distractors=distractors,
question=question,
seed=SEED_USING_QUESTION,
)
self._assert_prompt_is_valid(qa_prompt_2b, question, ideal, distractors)
assert (
qa_prompt_2a == qa_prompt_2b
), "Same seeding strategy should lead to same prompts"
assert (
qa_prompt_2a != qa_prompt_1a
), "Different seeding strategies should lead to different prompts"

def test_creating_litqa_questions(self) -> None:
"""Test making LitQA eval questions after downloading from Hugging Face Hub."""
20 changes: 18 additions & 2 deletions tests/test_task.py
@@ -20,7 +20,7 @@
LitQAv2TaskSplit,
)
from paperqa.agents.tools import GenerateAnswer
from paperqa.litqa import DEFAULT_REWARD_MAPPING, LitQAEvaluation
from paperqa.litqa import DEFAULT_REWARD_MAPPING, SEED_USING_QUESTION, LitQAEvaluation


@pytest.fixture(name="base_query_request")
@@ -103,12 +103,27 @@ async def test___len__(
expected_length: int,
base_query_request: QueryRequest,
) -> None:
task_dataset = LitQAv2TaskDataset(base_query=base_query_request, split=split)
task_dataset = LitQAv2TaskDataset(
base_query=base_query_request,
question_kwargs={"seed": 42},
read_data_kwargs={"seed": 42},
split=split,
)
assert len(task_dataset) == expected_length

# Now let's check we could use the sources in a validation
for i in range(len(task_dataset)):
env = task_dataset.get_new_env_by_idx(i)
if i == 0 and split == LitQAv2TaskSplit.TRAIN:
# Yes this assertion is somewhat brittle, but it reliably
# checks the seeding's behavior so we keep it
obs, _ = await env.reset()
assert (
"Q: SLC14A1 been identified as a specific marker for endothelial"
" cells in which organ?\n\nOptions:\nA) heart\nB) eye\nC)"
" prostate\nD) Insufficient information to answer this question\nE)"
" liver" in (obs[0].content or "")
)
assert env.sources, "Sources need to be accessible"
assert isinstance(
env.sources, Iterable
@@ -144,6 +159,7 @@ async def test_evaluation(
"deleted_dockeys",
}
),
"question_kwargs": {"seed": SEED_USING_QUESTION},
},
)
# NOTE: set base_query after construction of the TaskConfig. because in
