Made it possible to get answers from litqa evaluations (#760)
We now expose answer and ideal on LitQAEvaluation, enabling tracking of individual agent responses.
whitead authored Dec 11, 2024
1 parent 0f5c494 · commit 5eaa9ee
Showing 2 changed files with 67 additions and 10 deletions.
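Before the file diffs, a minimal usage sketch (not part of the commit) of what this change enables: reading the matched answer and the ideal answer off an evaluation result. The from_question keyword arguments, the question text, and the distractor list are illustrative assumptions loosely based on the tests in this commit, and grading requires a configured evaluation LLM because eval_fn extracts the chosen option with a model.

import asyncio

from paperqa.litqa import LitQAEvaluation


async def main() -> None:
    # Build a multiple-choice prompt and an async grading function for one question.
    # Keyword names below are assumptions; check from_question's actual signature.
    qa_prompt, eval_fn = LitQAEvaluation.from_question(
        ideal="94107",
        distractors=["94106", "94108", "10001"],
        question="What is the zip code in question?",  # illustrative placeholder
    )
    print(qa_prompt)  # the rendered multiple-choice question

    # Grade a raw agent response (uses an LLM to extract the chosen option).
    evaluation = await eval_fn("the answer is 94107")
    assert evaluation == LitQAEvaluation.CORRECT

    # New in this commit: the evaluation also carries the matched option and the
    # ideal answer, so individual agent responses can be tracked alongside the grade.
    print(evaluation.answer)  # expected "94107" (full-text option, or None if unmatched)
    print(evaluation.ideal)   # expected "94107"
    print(evaluation.make_discounted_returns(3, discount=0.5))


asyncio.run(main())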
55 changes: 46 additions & 9 deletions paperqa/litqa.py
@@ -96,12 +96,28 @@ def make_mc_options(
 
 
 class LitQAEvaluation(StrEnum):
-    """Possible evaluation results for a LitQA question."""
+    """Possible evaluation results for a LitQA question and methods for working with answers."""
 
     CORRECT = "correct"
     INCORRECT = "incorrect"
     UNSURE = "unsure"
 
+    @property
+    def answer(self) -> str | None:
+        return getattr(self, "_answer", None)
+
+    @answer.setter
+    def answer(self, value: str | None) -> None:
+        self._answer = value
+
+    @property
+    def ideal(self) -> str | None:
+        return getattr(self, "_ideal", None)
+
+    @ideal.setter
+    def ideal(self, value: str) -> None:
+        self._ideal = value
+
     def make_discounted_returns(
         self,
         num_steps: int,
@@ -144,15 +160,21 @@ def extract_answer(answer: str) -> str:
             and ord(result[0]) - _CAPITAL_A_INDEX + 1 > total_options
         ):
             # The result extracted was not in the options
-            return cls.INCORRECT
+            evaluation = cls.INCORRECT
+            evaluation.answer = result
         # From here, if we don't match either the ideal or the unsure multiple choice
         # options then we declare the answer as incorrect.
-        evaluation_result = cls.INCORRECT
-        if unsure_mc_answer and result[0].lower() == unsure_mc_answer[0].lower():
-            evaluation_result = cls.UNSURE
-        if result[0].lower() == ideal_mc_answer[0].lower():
-            evaluation_result = cls.CORRECT
-        return evaluation_result
+        elif unsure_mc_answer and result[0].lower() == unsure_mc_answer[0].lower():
+            evaluation = cls.UNSURE
+            evaluation.answer = unsure_mc_answer
+        elif result[0].lower() == ideal_mc_answer[0].lower():
+            evaluation = cls.CORRECT
+            evaluation.answer = ideal_mc_answer
+        else:
+            evaluation = cls.INCORRECT
+            evaluation.answer = result
+        evaluation.ideal = ideal_mc_answer
+        return evaluation
 
     @classmethod
     def from_question(
@@ -215,12 +237,27 @@ async def llm_from_answer(answer: PQASession | str) -> LitQAEvaluation:
                 raise NotImplementedError(
                     f"Expected evaluation chunk to be a string, not {eval_chunk.text}."
                 )
-            return cls.from_answer(
+            evaluation = cls.from_answer(
                 text=eval_chunk.text,
                 ideal_mc_answer=ideal_answer,
                 unsure_mc_answer=unsure_answer,
                 total_options=len(distractor_answers) + (2 if use_unsure else 1),
             )
+            # convert MC answers back to full text option so that it
+            # is meaningful
+            evaluation.ideal = ideal
+            if evaluation == cls.CORRECT:
+                evaluation.answer = ideal
+            elif evaluation == cls.UNSURE:
+                evaluation.answer = UNSURE_OPTION
+            else:
+                try:
+                    evaluation.answer = distractors[
+                        distractor_answers.index(evaluation.answer or "")
+                    ]
+                except ValueError:
+                    evaluation.answer = None
+            return evaluation
 
         return qa_prompt, llm_from_answer
 
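As an aside on the mechanism used in litqa.py above (not part of the commit): properties and methods defined in an enum class body are not converted into members, so LitQAEvaluation can expose answer and ideal as ordinary properties whose setters stash state on the member itself. Below is a self-contained sketch of that pattern with hypothetical names; note that enum members are singletons, so an attribute set this way is visible through every reference to that member (StrEnum requires Python 3.11+).

from enum import StrEnum


class GradedResult(StrEnum):  # hypothetical stand-in for LitQAEvaluation
    PASS = "pass"
    FAIL = "fail"

    @property
    def detail(self) -> str | None:
        # Defaults to None until the setter below has been called.
        return getattr(self, "_detail", None)

    @detail.setter
    def detail(self, value: str | None) -> None:
        # Enum members are ordinary objects, so instance attributes stick.
        self._detail = value


result = GradedResult.PASS
result.detail = "matched option B"
assert result.detail == "matched option B"
# Members are singletons, so the attribute is shared by every reference:
assert GradedResult.PASS.detail == "matched option B"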
22 changes: 21 additions & 1 deletion tests/test_litqa.py
@@ -3,7 +3,12 @@
 
 import pytest
 
-from paperqa.litqa import SEED_USING_QUESTION, LitQAEvaluation, read_litqa_v2_from_hub
+from paperqa.litqa import (
+    SEED_USING_QUESTION,
+    UNSURE_OPTION,
+    LitQAEvaluation,
+    read_litqa_v2_from_hub,
+)
 from tests.conftest import VCR_DEFAULT_MATCH_ON
 
 
@@ -47,69 +52,79 @@ def _assert_prompt_is_valid(
             "answer",
             "expected_eval",
             "expected_dreturns",
+            "extracted_answer",
         ),
         [
             pytest.param(
                 *ZIP_CODE_QUESTION_IDEAL_DISTRACTORS,
                 "the answer is 94107",
                 LitQAEvaluation.CORRECT,
                 [0.25, 0.5, 1.0],
+                "94107",
                 id="matched-correct-option",
             ),
             pytest.param(
                 *ZIP_CODE_QUESTION_IDEAL_DISTRACTORS,
                 "the answer is 14004",
                 LitQAEvaluation.INCORRECT,
                 [-0.25, -0.5, -1.0],
+                None,
                 id="didnt-match-and-no-llm-innate-knowledge",
             ),
             pytest.param(
                 *ZIP_CODE_QUESTION_IDEAL_DISTRACTORS,
                 "the answer is 94106",
                 LitQAEvaluation.INCORRECT,
                 [-0.25, -0.5, -1.0],
+                "94106",
                 id="matched-incorrect-option",
             ),
             pytest.param(
                 *ZIP_CODE_QUESTION_IDEAL_DISTRACTORS,
                 "Insufficient information",
                 LitQAEvaluation.UNSURE,
                 [0.025, 0.05, 0.1],
+                UNSURE_OPTION,
                 id="matched-unsure-option",
             ),
             pytest.param(
                 *ZIP_CODE_QUESTION_IDEAL_DISTRACTORS,
                 "the answer is 94106 or 94107",
                 LitQAEvaluation.INCORRECT,
                 [-0.25, -0.5, -1.0],
+                None,
                 id="matched-several-options",
             ),
             pytest.param(
                 *ZIP_CODE_QUESTION_IDEAL_DISTRACTORS,
                 "",
                 LitQAEvaluation.INCORRECT,
                 [-0.25, -0.5, -1.0],
+                None,
                 id="empty-answer1",
             ),
             pytest.param(
                 *MEANING_OF_LIFE_QUESTION_IDEAL_DISTRACTORS,
                 "14",
                 LitQAEvaluation.INCORRECT,
                 [-0.25, -0.5, -1.0],
+                None,
                 id="didnt-match-and-llm-has-innate-knowledge",
            ),
             pytest.param(
                 *MEANING_OF_LIFE_QUESTION_IDEAL_DISTRACTORS,
                 "",
                 LitQAEvaluation.INCORRECT,
                 [-0.25, -0.5, -1.0],
+                None,
                 id="empty-answer2",
             ),
             pytest.param(
                 *LITQA2_QUESTION_IDEAL_DISTRACTORS,
                 "",
                 LitQAEvaluation.INCORRECT,
                 [-0.25, -0.5, -1.0],
+                None,
                 id="empty-answer3",
             ),
         ],
@@ -122,6 +137,7 @@ async def test_from_question(
         answer: str,
         expected_eval: LitQAEvaluation,
         expected_dreturns: list[float],
+        extracted_answer: str,
     ) -> None:
         """Tests that we can create a LitQA question and evaluate answers."""
         qa_prompt, eval_fn = LitQAEvaluation.from_question(
Expand All @@ -134,6 +150,10 @@ async def test_from_question(

evaluation = await eval_fn(answer)
assert evaluation == expected_eval
if evaluation == LitQAEvaluation.CORRECT:
assert evaluation.answer == ideal
assert evaluation.answer == extracted_answer
assert evaluation.ideal == ideal
assert evaluation.make_discounted_returns(3, discount=0.5) == expected_dreturns

def test_consistent_mc_options(self) -> None:
