Skip to content

Commit

Permalink
Added status info on used contexts
Browse files Browse the repository at this point in the history
  • Loading branch information
whitead committed Jan 29, 2024
1 parent d8a2387 commit c1eea51
Show file tree
Hide file tree
Showing 5 changed files with 68 additions and 16 deletions.
2 changes: 2 additions & 0 deletions paperqa/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
NumpyVectorStore,
LangchainVectorStore,
SentenceTransformerEmbeddingModel,
LLMResult,
)

__all__ = [
Expand All @@ -32,4 +33,5 @@
"NumpyVectorStore",
"LangchainVectorStore",
"print_callback",
"LLMResult",
]
9 changes: 8 additions & 1 deletion paperqa/types.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from typing import Any, Callable
from uuid import UUID, uuid4

from pydantic import BaseModel, ConfigDict, Field, field_validator
from pydantic import BaseModel, ConfigDict, Field, computed_field, field_validator

from .prompts import (
citation_prompt,
Expand All @@ -12,6 +12,7 @@
summary_json_system_prompt,
summary_prompt,
)
from .utils import get_citenames

# Just for clarity
DocKey = Any
Expand Down Expand Up @@ -174,6 +175,12 @@ def __str__(self) -> str:
"""Return the answer as a string."""
return self.formatted_answer

# Exposed as a pydantic computed field so the value is included when the
# model is serialized (see test asserting "used_contexts" in model_dump_json).
# NOTE(review): decorator order matters — @computed_field must wrap @property.
@computed_field # type: ignore
@property
def used_contexts(self) -> set[str]:
    """Return the set of citation names actually cited in the formatted answer."""
    # Derived on the fly from formatted_answer rather than stored state.
    return get_citenames(self.formatted_answer)

def get_citation(self, name: str) -> str:
"""Return the formatted citation for the gien docname."""
try:
Expand Down
19 changes: 16 additions & 3 deletions paperqa/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,11 +107,24 @@ def strip_citations(text: str) -> str:
return text


def iter_citations(text: str) -> list[str]:
def get_citenames(text: str) -> set[str]:
    """Extract the set of individual citation names referenced in *text*.

    Recognizes "Author et al. (YYYY)" style citations, parenthesized
    citations containing a 4-digit year (e.g. "(Foo2019Bar pg. 1-3)"),
    and placeholder "(None pages a-b)" citations. Parentheticals holding
    several citations are split on commas/semicolons into separate names.

    Args:
        text: Text (e.g. a formatted answer) to scan for citations.

    Returns:
        Deduplicated set of citation name strings, with the filler phrase
        "Extra background information" removed.
    """
    # Combined regex for identifying citations (see unit tests for examples)
    citation_regex = r"\b[\w\-]+\set\sal\.\s\([0-9]{4}\)|\((?:[^\)]*?[a-zA-Z][^\)]*?[0-9]{4}[^\)]*?)\)"
    results = re.findall(citation_regex, text, flags=re.MULTILINE)
    # now find None patterns (placeholder citations that carry no year)
    none_citation_regex = r"(\(None[a-f]{0,1} pages [0-9]{1,10}-[0-9]{1,10}\))"
    results.extend(re.findall(none_citation_regex, text, flags=re.MULTILINE))
    values: set[str] = set()
    for citation in results:
        citation = citation.strip("() ")
        for c in re.split(",|;", citation):
            # BUG FIX: strip whitespace BEFORE the comparison. Splitting
            # "(Foo2019Bar, Extra background information)" on "," yields
            # " Extra background information" with a leading space, which
            # previously slipped past the equality check below.
            c = c.strip()
            if c == "Extra background information":
                continue
            values.add(c)
    return values


def extract_doi(reference: str) -> str:
Expand Down
2 changes: 1 addition & 1 deletion paperqa/version.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = "4.0.0-pre.4"
__version__ = "4.0.0-pre.5"
52 changes: 41 additions & 11 deletions tests/test_paperqa.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
)
from paperqa.readers import read_doc
from paperqa.utils import (
iter_citations,
get_citenames,
maybe_is_html,
maybe_is_text,
name_in_text,
Expand All @@ -57,7 +57,7 @@ def test_guess_model_type():
assert guess_model_type("davinci-002") == "completion"


def test_iter_citations():
def test_get_citations():
text = (
"Yes, COVID-19 vaccines are effective. Various studies have documented the "
"effectiveness of COVID-19 vaccines in preventing severe disease, "
Expand All @@ -79,15 +79,18 @@ def test_iter_citations():
"(Chemaitelly2022WaningEO, Foo2019Bar). Despite this, vaccines still provide "
"significant protection against severe outcomes (Bar2000Foo pg 1-3; Far2000 pg 2-5)."
)
ref = [
"(Dorabawila2022EffectivenessOT)",
"(Bernal2021EffectivenessOC pg. 1-3)",
"(Thompson2021EffectivenessOC pg. 3-5, Goo2031Foo pg. 3-4)",
"(Marfé2021EffectivenessOC)",
"(Chemaitelly2022WaningEO, Foo2019Bar)",
"(Bar2000Foo pg 1-3; Far2000 pg 2-5)",
]
assert list(iter_citations(text)) == ref
ref = {
"Dorabawila2022EffectivenessOT",
"Bernal2021EffectivenessOC pg. 1-3",
"Thompson2021EffectivenessOC pg. 3-5",
"Goo2031Foo pg. 3-4",
"Marfé2021EffectivenessOC",
"Chemaitelly2022WaningEO",
"Foo2019Bar",
"Bar2000Foo pg 1-3",
"Far2000 pg 2-5",
}
assert get_citenames(text) == ref


def test_single_author():
Expand Down Expand Up @@ -529,6 +532,27 @@ def test_query():
docs.query("What is Frederick Bates's greatest accomplishment?")


def test_answer_attributes():
    """Check that Answer.used_contexts is populated, resolvable, and serialized."""
    docs = Docs()
    docs.add_url(
        "https://en.wikipedia.org/wiki/Frederick_Bates_(politician)",
        citation="WikiMedia Foundation, 2023, Accessed now",
        dockey="test",
    )
    answer = docs.query("What is Frederick Bates's greatest accomplishment?")

    cited = answer.used_contexts
    # Some contexts should be cited, but not necessarily all of them.
    assert len(cited) > 0
    assert len(cited) < len(answer.contexts)
    # Any cited name must resolve back to the document's citation string.
    some_name = next(iter(cited))
    assert (
        answer.get_citation(some_name)
        == "WikiMedia Foundation, 2023, Accessed now"
    )

    # make sure it is serialized correctly
    serialized = answer.model_dump_json()
    assert "used_contexts" in serialized


def test_llmresult_callbacks():
my_results = []

Expand Down Expand Up @@ -906,6 +930,12 @@ def test_docs_pickle():
# make sure we can query
docs.query("What date is bring your dog to work in the US?")

# make sure we can embed documents
docs2.add_url(
"https://en.wikipedia.org/wiki/Frederick_Bates_(politician)",
citation="WikiMedia Foundation, 2023, Accessed now",
)


def test_bad_context():
doc_path = "example.html"
Expand Down

0 comments on commit c1eea51

Please sign in to comment.