From 3abf52eec59bf93433fca56b50c8f320a852487e Mon Sep 17 00:00:00 2001
From: Ben Chambers <35960+bjchambers@users.noreply.github.com>
Date: Mon, 29 Jan 2024 15:41:45 -0800
Subject: [PATCH] feat: Store the extracted text on documents

This closes #22.
---
 dewy/common/collection_embeddings.py |  6 +++++-
 dewy/documents/models.py             |  8 ++++++++
 dewy/documents/router.py             |  2 +-
 migrations/0002_document_text.sql    |  2 ++
 tests/test_e2e.py                    | 32 ++++++++++++++++++++--------
 5 files changed, 39 insertions(+), 11 deletions(-)
 create mode 100644 migrations/0002_document_text.sql

diff --git a/dewy/common/collection_embeddings.py b/dewy/common/collection_embeddings.py
index 12044d0..51dc7a0 100644
--- a/dewy/common/collection_embeddings.py
+++ b/dewy/common/collection_embeddings.py
@@ -270,10 +270,14 @@ async def ingest(self, document_id: int, url: str) -> None:
         await conn.execute(
             """
             UPDATE document
-            SET ingest_state = 'ingested', ingest_error = NULL
+            SET
+                ingest_state = 'ingested',
+                ingest_error = NULL,
+                extracted_text = $2
             WHERE id = $1
             """,
             document_id,
+            extracted.text
         )
 
     async def _chunk_sentences(self, text: str) -> List[str]:
diff --git a/dewy/documents/models.py b/dewy/documents/models.py
index 6f62dca..1648e1e 100644
--- a/dewy/documents/models.py
+++ b/dewy/documents/models.py
@@ -29,6 +29,14 @@ class Document(BaseModel):
     id: Optional[int] = None
     collection_id: int
 
+    extracted_text: Optional[str] = None
+    """The text that was extracted for this document.
+
+    This is only returned when getting a specific document, not listing documents.
+
+    Will not be set until after the document is ingested.
+    """
+
     url: str
 
     ingest_state: Optional[IngestState] = None
diff --git a/dewy/documents/router.py b/dewy/documents/router.py
index e194fbc..be04aae 100644
--- a/dewy/documents/router.py
+++ b/dewy/documents/router.py
@@ -78,7 +78,7 @@ async def get_document(conn: PgConnectionDep, id: PathDocumentId) -> Document:
     # TODO: Test / return not found?
     result = await conn.fetchrow(
         """
-        SELECT id, collection_id, url, ingest_state, ingest_error
+        SELECT id, collection_id, url, ingest_state, ingest_error, extracted_text
         FROM document
        WHERE id = $1
         """,
         id,
diff --git a/migrations/0002_document_text.sql b/migrations/0002_document_text.sql
new file mode 100644
index 0000000..47f95f5
--- /dev/null
+++ b/migrations/0002_document_text.sql
@@ -0,0 +1,2 @@
+ALTER TABLE document
+ADD COLUMN extracted_text TEXT;
\ No newline at end of file
diff --git a/tests/test_e2e.py b/tests/test_e2e.py
index 7c39c34..559bcba 100644
--- a/tests/test_e2e.py
+++ b/tests/test_e2e.py
@@ -6,7 +6,7 @@
 from pydantic import TypeAdapter
 
 from dewy.chunks.models import Chunk, RetrieveRequest, RetrieveResponse
-from dewy.documents.models import AddDocumentRequest
+from dewy.documents.models import AddDocumentRequest, Document
 
 SKELETON_OF_THOUGHT_PDF = "https://arxiv.org/pdf/2307.15337.pdf"
 
@@ -46,6 +46,12 @@ async def list_chunks(client, collection: int, document: int):
     ta = TypeAdapter(List[Chunk])
     return ta.validate_json(response.content)
 
+async def get_document(client, document_id: int) -> Document:
+    response = await client.get(f"/api/documents/{document_id}")
+    assert response.status_code == 200
+    assert response
+    return Document.model_validate_json(response.content)
+
 async def retrieve(client, collection: int, query: str) -> RetrieveResponse:
     request = RetrieveRequest(
         collection_id=collection, query=query, include_image_chunks=False
@@ -58,10 +64,14 @@
 
 async def test_e2e_openai_ada002(client):
     collection = await create_collection(client, "openai:text-embedding-ada-002")
-    document = await ingest(client, collection, SKELETON_OF_THOUGHT_PDF)
-    chunks = await list_chunks(client, collection, document)
+    document_id = await ingest(client, collection, SKELETON_OF_THOUGHT_PDF)
+
+    document = await get_document(client, document_id)
+    assert document.extracted_text.startswith("Skeleton-of-Thought")
+
+    chunks = await list_chunks(client, collection, document_id)
     assert len(chunks) > 0
-    assert chunks[0].document_id == document
+    assert chunks[0].document_id == document_id
 
     results = await retrieve(
         client, collection, "outline the steps to using skeleton-of-thought prompting"
@@ -69,16 +79,20 @@
     assert len(results.text_results) > 0
 
     print(results.text_results)
-    assert results.text_results[0].document_id == document
+    assert results.text_results[0].document_id == document_id
     assert "skeleton" in results.text_results[0].text.lower()
 
 
 async def test_e2e_hf_bge_small(client):
     collection = await create_collection(client, "hf:BAAI/bge-small-en")
-    document = await ingest(client, collection, SKELETON_OF_THOUGHT_PDF)
-    chunks = await list_chunks(client, collection, document)
+    document_id = await ingest(client, collection, SKELETON_OF_THOUGHT_PDF)
+
+    document = await get_document(client, document_id)
+    assert document.extracted_text.startswith("Skeleton-of-Thought")
+
+    chunks = await list_chunks(client, collection, document_id)
     assert len(chunks) > 0
-    assert chunks[0].document_id == document
+    assert chunks[0].document_id == document_id
 
     results = await retrieve(
         client, collection, "outline the steps to using skeleton-of-thought prompting"
@@ -86,5 +100,5 @@
     assert len(results.text_results) > 0
 
     print(results.text_results)
-    assert results.text_results[0].document_id == document
+    assert results.text_results[0].document_id == document_id
     assert "skeleton" in results.text_results[0].text.lower()
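Not part of the patch: a minimal usage sketch of reading the newly stored text through the GET /api/documents/{id} route touched above, once ingestion has finished. It assumes a running Dewy server at http://localhost:8000 and a document with id 1 (both illustrative), and uses the httpx client directly rather than the project's test fixtures.

import asyncio

import httpx


async def main() -> None:
    # Assumed base URL and document id; adjust for your deployment.
    async with httpx.AsyncClient(base_url="http://localhost:8000") as client:
        response = await client.get("/api/documents/1")
        response.raise_for_status()
        document = response.json()
        # extracted_text stays null until ingest_state reaches 'ingested'.
        print(document["ingest_state"])
        print((document["extracted_text"] or "")[:200])


asyncio.run(main())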