Skip to content

Commit

Permalink
feat: Store the extracted text on documents
Browse files Browse the repository at this point in the history
This closes #22.
  • Loading branch information
bjchambers committed Jan 29, 2024
1 parent cfde7a7 commit 3abf52e
Show file tree
Hide file tree
Showing 5 changed files with 39 additions and 11 deletions.
6 changes: 5 additions & 1 deletion dewy/common/collection_embeddings.py
Original file line number Diff line number Diff line change
Expand Up @@ -270,10 +270,14 @@ async def ingest(self, document_id: int, url: str) -> None:
await conn.execute(
"""
UPDATE document
SET ingest_state = 'ingested', ingest_error = NULL
SET
ingest_state = 'ingested',
ingest_error = NULL,
extracted_text = $2
WHERE id = $1
""",
document_id,
extracted.text
)

async def _chunk_sentences(self, text: str) -> List[str]:
Expand Down
8 changes: 8 additions & 0 deletions dewy/documents/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,14 @@ class Document(BaseModel):
id: Optional[int] = None
collection_id: int

extracted_text: Optional[str] = None
"""The text that was extracted for this document.
This is only returned when getting a specific document, not listing documents.
Will not be set until after the document is ingested.
"""

url: str

ingest_state: Optional[IngestState] = None
Expand Down
2 changes: 1 addition & 1 deletion dewy/documents/router.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ async def get_document(conn: PgConnectionDep, id: PathDocumentId) -> Document:
# TODO: Test / return not found?
result = await conn.fetchrow(
"""
SELECT id, collection_id, url, ingest_state, ingest_error
SELECT id, collection_id, url, ingest_state, ingest_error, extracted_text
FROM document WHERE id = $1
""",
id,
Expand Down
2 changes: 2 additions & 0 deletions migrations/0002_document_text.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
ALTER TABLE document
ADD COLUMN extracted_text TEXT;
32 changes: 23 additions & 9 deletions tests/test_e2e.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from pydantic import TypeAdapter

from dewy.chunks.models import Chunk, RetrieveRequest, RetrieveResponse
from dewy.documents.models import AddDocumentRequest
from dewy.documents.models import AddDocumentRequest, Document

SKELETON_OF_THOUGHT_PDF = "https://arxiv.org/pdf/2307.15337.pdf"

Expand Down Expand Up @@ -46,6 +46,12 @@ async def list_chunks(client, collection: int, document: int):
ta = TypeAdapter(List[Chunk])
return ta.validate_json(response.content)

async def get_document(client, document_id: int) -> Document:
response = await client.get(f"/api/documents/{document_id}")
assert response.status_code == 200
assert response
return Document.model_validate_json(response.content)

async def retrieve(client, collection: int, query: str) -> RetrieveResponse:
request = RetrieveRequest(
collection_id=collection, query=query, include_image_chunks=False
Expand All @@ -58,33 +64,41 @@ async def retrieve(client, collection: int, query: str) -> RetrieveResponse:

async def test_e2e_openai_ada002(client):
collection = await create_collection(client, "openai:text-embedding-ada-002")
document = await ingest(client, collection, SKELETON_OF_THOUGHT_PDF)
chunks = await list_chunks(client, collection, document)
document_id = await ingest(client, collection, SKELETON_OF_THOUGHT_PDF)

document = await get_document(client, document_id)
assert document.extracted_text.startswith("Skeleton-of-Thought")

chunks = await list_chunks(client, collection, document_id)
assert len(chunks) > 0
assert chunks[0].document_id == document
assert chunks[0].document_id == document_id

results = await retrieve(
client, collection, "outline the steps to using skeleton-of-thought prompting"
)
assert len(results.text_results) > 0
print(results.text_results)

assert results.text_results[0].document_id == document
assert results.text_results[0].document_id == document_id
assert "skeleton" in results.text_results[0].text.lower()


async def test_e2e_hf_bge_small(client):
collection = await create_collection(client, "hf:BAAI/bge-small-en")
document = await ingest(client, collection, SKELETON_OF_THOUGHT_PDF)
chunks = await list_chunks(client, collection, document)
document_id = await ingest(client, collection, SKELETON_OF_THOUGHT_PDF)

document = await get_document(client, document_id)
assert document.extracted_text.startswith("Skeleton-of-Thought")

chunks = await list_chunks(client, collection, document_id)
assert len(chunks) > 0
assert chunks[0].document_id == document
assert chunks[0].document_id == document_id

results = await retrieve(
client, collection, "outline the steps to using skeleton-of-thought prompting"
)
assert len(results.text_results) > 0
print(results.text_results)

assert results.text_results[0].document_id == document
assert results.text_results[0].document_id == document_id
assert "skeleton" in results.text_results[0].text.lower()

0 comments on commit 3abf52e

Please sign in to comment.