Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Store the extracted text on documents #41

Merged
merged 3 commits into from
Jan 30, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion dewy/common/collection_embeddings.py
Original file line number Diff line number Diff line change
Expand Up @@ -270,10 +270,14 @@ async def ingest(self, document_id: int, url: str) -> None:
await conn.execute(
"""
UPDATE document
SET ingest_state = 'ingested', ingest_error = NULL
SET
ingest_state = 'ingested',
ingest_error = NULL,
extracted_text = $2
WHERE id = $1
""",
document_id,
extracted.text
)

async def _chunk_sentences(self, text: str) -> List[str]:
Expand Down
8 changes: 8 additions & 0 deletions dewy/documents/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,14 @@ class Document(BaseModel):
id: Optional[int] = None
collection_id: int

extracted_text: Optional[str] = None
"""The text that was extracted for this document.

This is only returned when getting a specific document, not listing documents.

Will not be set until after the document is ingested.
"""

url: str

ingest_state: Optional[IngestState] = None
Expand Down
2 changes: 1 addition & 1 deletion dewy/documents/router.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ async def get_document(conn: PgConnectionDep, id: PathDocumentId) -> Document:
# TODO: Test / return not found?
result = await conn.fetchrow(
"""
SELECT id, collection_id, url, ingest_state, ingest_error
SELECT id, collection_id, url, ingest_state, ingest_error, extracted_text
FROM document WHERE id = $1
""",
id,
Expand Down
2 changes: 2 additions & 0 deletions migrations/0002_document_text.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
-- Migration 0002: persist the text extracted during ingestion on the
-- document row itself, so it can be returned from the get-document API.
-- Column is nullable: it stays NULL until the document has been ingested
-- (set alongside ingest_state = 'ingested' in collection_embeddings.py).
ALTER TABLE document
ADD COLUMN extracted_text TEXT;
32 changes: 23 additions & 9 deletions tests/test_e2e.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from pydantic import TypeAdapter

from dewy.chunks.models import Chunk, RetrieveRequest, RetrieveResponse
from dewy.documents.models import AddDocumentRequest
from dewy.documents.models import AddDocumentRequest, Document

SKELETON_OF_THOUGHT_PDF = "https://arxiv.org/pdf/2307.15337.pdf"

Expand Down Expand Up @@ -46,6 +46,12 @@ async def list_chunks(client, collection: int, document: int):
ta = TypeAdapter(List[Chunk])
return ta.validate_json(response.content)

async def get_document(client, document_id: int) -> Document:
    """Fetch a single document by id via the API and parse the response.

    Unlike the document *list* endpoint, the single-document endpoint also
    returns ``extracted_text`` (populated once ingestion completes).

    Args:
        client: An async HTTP test client bound to the app.
        document_id: Id of the document to fetch.

    Returns:
        The parsed ``Document`` model.
    """
    response = await client.get(f"/api/documents/{document_id}")
    # Include the body in the failure message so a non-200 explains itself;
    # the previous bare `assert response` was redundant (httpx truthiness is
    # derived from the status code already checked here).
    assert response.status_code == 200, response.text
    return Document.model_validate_json(response.content)

async def retrieve(client, collection: int, query: str) -> RetrieveResponse:
request = RetrieveRequest(
collection_id=collection, query=query, include_image_chunks=False
Expand All @@ -58,33 +64,41 @@ async def retrieve(client, collection: int, query: str) -> RetrieveResponse:

async def test_e2e_openai_ada002(client):
collection = await create_collection(client, "openai:text-embedding-ada-002")
document = await ingest(client, collection, SKELETON_OF_THOUGHT_PDF)
chunks = await list_chunks(client, collection, document)
document_id = await ingest(client, collection, SKELETON_OF_THOUGHT_PDF)

document = await get_document(client, document_id)
assert document.extracted_text.startswith("Skeleton-of-Thought")

chunks = await list_chunks(client, collection, document_id)
assert len(chunks) > 0
assert chunks[0].document_id == document
assert chunks[0].document_id == document_id

results = await retrieve(
client, collection, "outline the steps to using skeleton-of-thought prompting"
)
assert len(results.text_results) > 0
print(results.text_results)

assert results.text_results[0].document_id == document
assert results.text_results[0].document_id == document_id
assert "skeleton" in results.text_results[0].text.lower()


async def test_e2e_hf_bge_small(client):
collection = await create_collection(client, "hf:BAAI/bge-small-en")
document = await ingest(client, collection, SKELETON_OF_THOUGHT_PDF)
chunks = await list_chunks(client, collection, document)
document_id = await ingest(client, collection, SKELETON_OF_THOUGHT_PDF)

document = await get_document(client, document_id)
assert document.extracted_text.startswith("Skeleton-of-Thought")

chunks = await list_chunks(client, collection, document_id)
assert len(chunks) > 0
assert chunks[0].document_id == document
assert chunks[0].document_id == document_id

results = await retrieve(
client, collection, "outline the steps to using skeleton-of-thought prompting"
)
assert len(results.text_results) > 0
print(results.text_results)

assert results.text_results[0].document_id == document
assert results.text_results[0].document_id == document_id
assert "skeleton" in results.text_results[0].text.lower()