Skip to content

Commit

Permalink
Added configuration for both Huggingface Hub import and options shuff…
Browse files Browse the repository at this point in the history
…ling, with testing
  • Loading branch information
jamesbraza committed Dec 10, 2024
1 parent 53b0cf2 commit 7cbfb6b
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 6 deletions.
12 changes: 8 additions & 4 deletions paperqa/agents/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,6 +196,7 @@ def __init__(
base_query: QueryRequest | dict | None = None,
base_docs: Docs | dict | None = None,
rewards: Mapping[str, float] = DEFAULT_REWARD_MAPPING,
question_kwargs: Mapping[str, Any] | None = None,
eval_model: LLMModel | str = DEFAULT_EVAL_MODEL_NAME,
**env_kwargs,
):
Expand All @@ -210,23 +211,23 @@ def __init__(
base_docs = Docs(**base_docs)
self._base_docs = base_docs
self._rewards = rewards
self._env_kwargs = env_kwargs
self._question_kwargs = question_kwargs
self._eval_model = eval_model
self._env_kwargs = env_kwargs

def _make_gradable_environment(
self,
ideal: str,
distractors: str | list[str],
question: str,
use_unsure: bool = True,
sources: str | list[str] | None = None,
) -> GradablePaperQAEnvironment:
qa_prompt, evaluation_from_answer = LitQAEvaluation.from_question(
ideal=ideal,
distractors=distractors,
question=question,
use_unsure=use_unsure,
eval_model=self._eval_model,
**(self._question_kwargs or {}),
)
query = self._base_query.model_copy()
query.query = qa_prompt
Expand Down Expand Up @@ -305,11 +306,14 @@ def __init__(
self,
*args,
labbench_dataset: str = DEFAULT_LABBENCH_HF_HUB_NAME,
read_data_kwargs: Mapping[str, Any] | None = None,
split: str | LitQAv2TaskSplit = LitQAv2TaskSplit.EVAL,
**kwargs,
):
super().__init__(*args, **kwargs)
train_df, eval_df = read_litqa_v2_from_hub(labbench_dataset)
train_df, eval_df = read_litqa_v2_from_hub(
labbench_dataset, **(read_data_kwargs or {})
)
split = LitQAv2TaskSplit(split)
if split == LitQAv2TaskSplit.TRAIN:
self.data = train_df
Expand Down
20 changes: 18 additions & 2 deletions tests/test_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
LitQAv2TaskSplit,
)
from paperqa.agents.tools import GenerateAnswer
from paperqa.litqa import DEFAULT_REWARD_MAPPING, LitQAEvaluation
from paperqa.litqa import DEFAULT_REWARD_MAPPING, SEED_USING_QUESTION, LitQAEvaluation


@pytest.fixture(name="base_query_request")
Expand Down Expand Up @@ -103,12 +103,27 @@ async def test___len__(
expected_length: int,
base_query_request: QueryRequest,
) -> None:
task_dataset = LitQAv2TaskDataset(base_query=base_query_request, split=split)
task_dataset = LitQAv2TaskDataset(
base_query=base_query_request,
question_kwargs={"seed": 42},
read_data_kwargs={"seed": 42},
split=split,
)
assert len(task_dataset) == expected_length

# Now let's check we could use the sources in a validation
for i in range(len(task_dataset)):
env = task_dataset.get_new_env_by_idx(i)
if i == 0 and split == LitQAv2TaskSplit.TRAIN:
# Yes this assertion is somewhat brittle, but it reliably
# checks the seeding's behavior so we keep it
obs, _ = await env.reset()
assert (
"Q: SLC14A1 been identified as a specific marker for endothelial"
" cells in which organ?\n\nOptions:\nA) heart\nB) eye\nC)"
" prostate\nD) Insufficient information to answer this question\nE)"
" liver" in (obs[0].content or "")
)
assert env.sources, "Sources need to be accessible"
assert isinstance(
env.sources, Iterable
Expand Down Expand Up @@ -144,6 +159,7 @@ async def test_evaluation(
"deleted_dockeys",
}
),
"question_kwargs": {"seed": SEED_USING_QUESTION},
},
)
# NOTE: set base_query after construction of the TaskConfig. because in
Expand Down

0 comments on commit 7cbfb6b

Please sign in to comment.