diff --git a/dewy/common/collection_embeddings.py b/dewy/common/collection_embeddings.py
index 00ce940..9984a4b 100644
--- a/dewy/common/collection_embeddings.py
+++ b/dewy/common/collection_embeddings.py
@@ -283,10 +283,14 @@ def encode_chunk(c: str) -> str:
         await conn.execute(
             """
             UPDATE document
-            SET ingest_state = 'ingested', ingest_error = NULL
+            SET
+                ingest_state = 'ingested',
+                ingest_error = NULL,
+                extracted_text = $2
             WHERE id = $1
             """,
             document_id,
+            extracted.text,
         )

     async def _chunk_sentences(self, text: str) -> List[str]:
diff --git a/dewy/document/models.py b/dewy/document/models.py
index fb59ea7..603cf67 100644
--- a/dewy/document/models.py
+++ b/dewy/document/models.py
@@ -29,6 +29,14 @@ class Document(BaseModel):
     id: Optional[int] = None
     collection_id: int

+    extracted_text: Optional[str] = None
+    """The text that was extracted for this document.
+
+    This is only returned when getting a specific document, not listing documents.
+
+    Will not be set until after the document is ingested.
+    """
+
     url: str
     ingest_state: Optional[IngestState] = None

diff --git a/dewy/document/router.py b/dewy/document/router.py
index f34172f..a29cbc2 100644
--- a/dewy/document/router.py
+++ b/dewy/document/router.py
@@ -78,7 +78,7 @@ async def get_document(conn: PgConnectionDep, id: PathDocumentId) -> Document:
     # TODO: Test / return not found?
     result = await conn.fetchrow(
         """
-        SELECT id, collection_id, url, ingest_state, ingest_error
+        SELECT id, collection_id, url, ingest_state, ingest_error, extracted_text
        FROM document WHERE id = $1
        """,
        id,
diff --git a/dewy/main.py b/dewy/main.py
index 49cce4d..e896edf 100644
--- a/dewy/main.py
+++ b/dewy/main.py
@@ -1,10 +1,10 @@
 import contextlib
-from typing import AsyncIterator, TypedDict
-from pathlib import Path
 import os
+from pathlib import Path
+from typing import AsyncIterator, TypedDict

-import uvicorn
 import asyncpg
+import uvicorn
 from fastapi import FastAPI
 from fastapi.middleware.cors import CORSMiddleware
 from fastapi.staticfiles import StaticFiles
@@ -19,10 +19,12 @@ class State(TypedDict):
     pg_pool: asyncpg.Pool


+# Resolve paths, independent of PWD
 current_file_path = Path(__file__).resolve()
-react_build_path = current_file_path.parent.parent / 'frontend' / 'dist'
-migrations_path = current_file_path.parent.parent / 'migrations'
+react_build_path = current_file_path.parent.parent / "frontend" / "dist"
+migrations_path = current_file_path.parent.parent / "migrations"
+

 @contextlib.asynccontextmanager
 async def lifespan(_app: FastAPI) -> AsyncIterator[State]:
@@ -66,8 +68,11 @@ async def healthcheck() -> dict[str, str]:
     if settings.SERVE_ADMIN_UI and os.path.isdir(react_build_path):
         logger.info("Running admin UI at http://localhost/admin")
         # Serve static files from the React app build directory
-        app.mount("/admin", StaticFiles(directory=str(react_build_path), html=True), name="static")
+        app.mount(
+            "/admin", StaticFiles(directory=str(react_build_path), html=True), name="static"
+        )
+

 # Function for running Dewy as a script
 def run(*args):
-    uvicorn.run("dewy.main:app", host="0.0.0.0", port=80)
\ No newline at end of file
+    uvicorn.run("dewy.main:app", host="0.0.0.0", port=80)
diff --git a/migrations/0002_document_text.sql b/migrations/0002_document_text.sql
new file mode 100644
index 0000000..47f95f5
--- /dev/null
+++ b/migrations/0002_document_text.sql
@@ -0,0 +1,2 @@
+ALTER TABLE document
+ADD COLUMN extracted_text TEXT;
\ No newline at end of file
diff --git a/tests/test_e2e.py b/tests/test_e2e.py
index 15c8c5e..1743dde 100644
--- a/tests/test_e2e.py
+++ b/tests/test_e2e.py
@@ -5,8 +5,8 @@

 from pydantic import TypeAdapter

-from dewy.chunks.models import Chunk, RetrieveRequest, RetrieveResponse
-from dewy.documents.models import AddDocumentRequest
+from dewy.chunk.models import Chunk, RetrieveRequest, RetrieveResponse
+from dewy.document.models import AddDocumentRequest, Document

 SKELETON_OF_THOUGHT_PDF = "https://arxiv.org/pdf/2307.15337.pdf"

@@ -47,6 +47,13 @@ async def list_chunks(client, collection: int, document: int):
     return ta.validate_json(response.content)


+async def get_document(client, document_id: int) -> Document:
+    response = await client.get(f"/api/documents/{document_id}")
+    assert response.status_code == 200
+    assert response
+    return Document.model_validate_json(response.content)
+
+
 async def retrieve(client, collection: int, query: str) -> RetrieveResponse:
     request = RetrieveRequest(
         collection_id=collection, query=query, include_image_chunks=False
@@ -61,10 +68,14 @@

 async def test_e2e_openai_ada002(client):
     collection = await create_collection(client, "openai:text-embedding-ada-002")
-    document = await ingest(client, collection, SKELETON_OF_THOUGHT_PDF)
-    chunks = await list_chunks(client, collection, document)
+    document_id = await ingest(client, collection, SKELETON_OF_THOUGHT_PDF)
+
+    document = await get_document(client, document_id)
+    assert document.extracted_text.startswith("Skeleton-of-Thought")
+
+    chunks = await list_chunks(client, collection, document_id)
     assert len(chunks) > 0
-    assert chunks[0].document_id == document
+    assert chunks[0].document_id == document_id

     results = await retrieve(
         client, collection, "outline the steps to using skeleton-of-thought prompting"
@@ -72,16 +83,20 @@

     assert len(results.text_results) > 0
     print(results.text_results)
-    assert results.text_results[0].document_id == document
+    assert results.text_results[0].document_id == document_id
     assert "skeleton" in results.text_results[0].text.lower()


 async def test_e2e_hf_bge_small(client):
     collection = await create_collection(client, "hf:BAAI/bge-small-en")
-    document = await ingest(client, collection, SKELETON_OF_THOUGHT_PDF)
-    chunks = await list_chunks(client, collection, document)
+    document_id = await ingest(client, collection, SKELETON_OF_THOUGHT_PDF)
+
+    document = await get_document(client, document_id)
+    assert document.extracted_text.startswith("Skeleton-of-Thought")
+
+    chunks = await list_chunks(client, collection, document_id)
     assert len(chunks) > 0
-    assert chunks[0].document_id == document
+    assert chunks[0].document_id == document_id

     results = await retrieve(
         client, collection, "outline the steps to using skeleton-of-thought prompting"
@@ -89,5 +104,5 @@

     assert len(results.text_results) > 0
     print(results.text_results)
-    assert results.text_results[0].document_id == document
+    assert results.text_results[0].document_id == document_id
     assert "skeleton" in results.text_results[0].text.lower()