From c1eea514faaf4ac5777f72bed67e36bbef5762b5 Mon Sep 17 00:00:00 2001
From: Andrew White
Date: Mon, 29 Jan 2024 11:58:14 -0800
Subject: [PATCH] Added status info on used contexts

---
 paperqa/__init__.py   |  2 ++
 paperqa/types.py      |  9 +++++++-
 paperqa/utils.py      | 19 +++++++++++++---
 paperqa/version.py    |  2 +-
 tests/test_paperqa.py | 52 ++++++++++++++++++++++++++++++++++---------
 5 files changed, 68 insertions(+), 16 deletions(-)

diff --git a/paperqa/__init__.py b/paperqa/__init__.py
index 285f4bbfd..f6a4a268a 100644
--- a/paperqa/__init__.py
+++ b/paperqa/__init__.py
@@ -11,6 +11,7 @@
     NumpyVectorStore,
     LangchainVectorStore,
     SentenceTransformerEmbeddingModel,
+    LLMResult,
 )
 
 __all__ = [
@@ -32,4 +33,5 @@
     "NumpyVectorStore",
     "LangchainVectorStore",
     "print_callback",
+    "LLMResult",
 ]
diff --git a/paperqa/types.py b/paperqa/types.py
index 6275e184f..57864d2a7 100644
--- a/paperqa/types.py
+++ b/paperqa/types.py
@@ -1,7 +1,7 @@
 from typing import Any, Callable
 from uuid import UUID, uuid4
 
-from pydantic import BaseModel, ConfigDict, Field, field_validator
+from pydantic import BaseModel, ConfigDict, Field, computed_field, field_validator
 
 from .prompts import (
     citation_prompt,
@@ -12,6 +12,7 @@
     summary_json_system_prompt,
     summary_prompt,
 )
+from .utils import get_citenames
 
 # Just for clarity
 DocKey = Any
@@ -174,6 +175,12 @@ def __str__(self) -> str:
         """Return the answer as a string."""
         return self.formatted_answer
 
+    @computed_field  # type: ignore
+    @property
+    def used_contexts(self) -> set[str]:
+        """Return the used contexts."""
+        return get_citenames(self.formatted_answer)
+
     def get_citation(self, name: str) -> str:
         """Return the formatted citation for the gien docname."""
         try:
diff --git a/paperqa/utils.py b/paperqa/utils.py
index 5152a79b7..de6162bd8 100644
--- a/paperqa/utils.py
+++ b/paperqa/utils.py
@@ -107,11 +107,24 @@ def strip_citations(text: str) -> str:
     return text
 
 
-def iter_citations(text: str) -> list[str]:
+def get_citenames(text: str) -> set[str]:
     # Combined regex for identifying citations (see unit tests for examples)
     citation_regex = r"\b[\w\-]+\set\sal\.\s\([0-9]{4}\)|\((?:[^\)]*?[a-zA-Z][^\)]*?[0-9]{4}[^\)]*?)\)"
-    result = re.findall(citation_regex, text, flags=re.MULTILINE)
-    return result
+    results = re.findall(citation_regex, text, flags=re.MULTILINE)
+    # now find None patterns
+    none_citation_regex = r"(\(None[a-f]{0,1} pages [0-9]{1,10}-[0-9]{1,10}\))"
+    none_results = re.findall(none_citation_regex, text, flags=re.MULTILINE)
+    results.extend(none_results)
+    values = []
+    for citation in results:
+        citation = citation.strip("() ")
+        for c in re.split(",|;", citation):
+            if c == "Extra background information":
+                continue
+            # remove leading/trailing spaces
+            c = c.strip()
+            values.append(c)
+    return set(values)
 
 
 def extract_doi(reference: str) -> str:
diff --git a/paperqa/version.py b/paperqa/version.py
index 7ce614769..f26b9a90c 100644
--- a/paperqa/version.py
+++ b/paperqa/version.py
@@ -1 +1 @@
-__version__ = "4.0.0-pre.4"
+__version__ = "4.0.0-pre.5"
diff --git a/tests/test_paperqa.py b/tests/test_paperqa.py
index c34437274..1aecb8030 100644
--- a/tests/test_paperqa.py
+++ b/tests/test_paperqa.py
@@ -30,7 +30,7 @@
 )
 from paperqa.readers import read_doc
 from paperqa.utils import (
-    iter_citations,
+    get_citenames,
     maybe_is_html,
     maybe_is_text,
     name_in_text,
@@ -57,7 +57,7 @@ def test_guess_model_type():
     assert guess_model_type("davinci-002") == "completion"
 
 
-def test_iter_citations():
+def test_get_citations():
     text = (
         "Yes, COVID-19 vaccines are effective. Various studies have documented the "
         "effectiveness of COVID-19 vaccines in preventing severe disease, "
@@ -79,15 +79,18 @@
         "(Chemaitelly2022WaningEO, Foo2019Bar). Despite this, vaccines still provide "
         "significant protection against severe outcomes (Bar2000Foo pg 1-3; Far2000 pg 2-5)."
     )
-    ref = [
-        "(Dorabawila2022EffectivenessOT)",
-        "(Bernal2021EffectivenessOC pg. 1-3)",
-        "(Thompson2021EffectivenessOC pg. 3-5, Goo2031Foo pg. 3-4)",
-        "(Marfé2021EffectivenessOC)",
-        "(Chemaitelly2022WaningEO, Foo2019Bar)",
-        "(Bar2000Foo pg 1-3; Far2000 pg 2-5)",
-    ]
-    assert list(iter_citations(text)) == ref
+    ref = {
+        "Dorabawila2022EffectivenessOT",
+        "Bernal2021EffectivenessOC pg. 1-3",
+        "Thompson2021EffectivenessOC pg. 3-5",
+        "Goo2031Foo pg. 3-4",
+        "Marfé2021EffectivenessOC",
+        "Chemaitelly2022WaningEO",
+        "Foo2019Bar",
+        "Bar2000Foo pg 1-3",
+        "Far2000 pg 2-5",
+    }
+    assert get_citenames(text) == ref
 
 
 def test_single_author():
@@ -529,6 +532,27 @@ def test_query():
     docs.query("What is Frederick Bates's greatest accomplishment?")
 
 
+def test_answer_attributes():
+    docs = Docs()
+    docs.add_url(
+        "https://en.wikipedia.org/wiki/Frederick_Bates_(politician)",
+        citation="WikiMedia Foundation, 2023, Accessed now",
+        dockey="test",
+    )
+    answer = docs.query("What is Frederick Bates's greatest accomplishment?")
+    used_citations = answer.used_contexts
+    assert len(used_citations) > 0
+    assert len(used_citations) < len(answer.contexts)
+    assert (
+        answer.get_citation(list(used_citations)[0])
+        == "WikiMedia Foundation, 2023, Accessed now"
+    )
+
+    # make sure it is serialized correctly
+    js = answer.model_dump_json()
+    assert "used_contexts" in js
+
+
 def test_llmresult_callbacks():
     my_results = []
 
@@ -906,6 +930,12 @@ def test_docs_pickle():
     # make sure we can query
     docs.query("What date is bring your dog to work in the US?")
 
+    # make sure we can embed documents
+    docs2.add_url(
+        "https://en.wikipedia.org/wiki/Frederick_Bates_(politician)",
+        citation="WikiMedia Foundation, 2023, Accessed now",
+    )
+
 
 def test_bad_context():
     doc_path = "example.html"
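
Usage sketch (not part of the patch above): a minimal illustration of how the new Answer.used_contexts computed field could be inspected after a query. It assumes the same Docs.add_url / Docs.query flow exercised in test_answer_attributes; the URL, citation string, and question are simply reused from that test and are illustrative only.

    from paperqa import Docs

    docs = Docs()
    docs.add_url(
        "https://en.wikipedia.org/wiki/Frederick_Bates_(politician)",
        citation="WikiMedia Foundation, 2023, Accessed now",
    )
    answer = docs.query("What is Frederick Bates's greatest accomplishment?")

    # used_contexts is derived from the formatted answer via get_citenames(),
    # so it lists only the context names actually cited in the answer text.
    for name in answer.used_contexts:
        print(name, "->", answer.get_citation(name))

    # As a pydantic computed_field, it is also included when serializing.
    assert "used_contexts" in answer.model_dump_json()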