feat: Store the extracted text on documents (#41)

This closes #22.
bjchambers authored Jan 30, 2024
1 parent 5efe21d commit c4af67b
Showing 6 changed files with 53 additions and 19 deletions.
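
What this means for API consumers: once a document has been ingested, the text extracted from it is stored on the document row and returned by the single-document endpoint. A minimal sketch of reading it over HTTP with httpx; the base URL and document id are placeholders, and a running Dewy server is assumed:

import asyncio

import httpx


async def show_extracted_text(document_id: int) -> None:
    # Adjust the base URL to wherever the Dewy API is served.
    async with httpx.AsyncClient(base_url="http://localhost") as client:
        response = await client.get(f"/api/documents/{document_id}")
        response.raise_for_status()
        document = response.json()
        # extracted_text stays null until ingestion has finished.
        print(document["ingest_state"], (document["extracted_text"] or "")[:80])


asyncio.run(show_extracted_text(1))
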
6 changes: 5 additions & 1 deletion dewy/common/collection_embeddings.py
@@ -283,10 +283,14 @@ def encode_chunk(c: str) -> str:
         await conn.execute(
             """
             UPDATE document
-            SET ingest_state = 'ingested', ingest_error = NULL
+            SET
+                ingest_state = 'ingested',
+                ingest_error = NULL,
+                extracted_text = $2
             WHERE id = $1
             """,
             document_id,
+            extracted.text,
         )

     async def _chunk_sentences(self, text: str) -> List[str]:
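
For readers less familiar with asyncpg: the positional arguments after the query string bind to $1, $2, … in order, so document_id fills $1 and extracted.text fills $2. A standalone sketch of the same update pattern, with the connection string as a placeholder:

import asyncio

import asyncpg


async def mark_ingested(dsn: str, document_id: int, text: str) -> None:
    conn = await asyncpg.connect(dsn)
    try:
        # $1 and $2 are bound from the positional arguments, in order.
        await conn.execute(
            """
            UPDATE document
            SET
                ingest_state = 'ingested',
                ingest_error = NULL,
                extracted_text = $2
            WHERE id = $1
            """,
            document_id,
            text,
        )
    finally:
        await conn.close()


# asyncio.run(mark_ingested("postgresql://localhost/dewy", 1, "extracted text goes here"))
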
8 changes: 8 additions & 0 deletions dewy/document/models.py
@@ -29,6 +29,14 @@ class Document(BaseModel):
     id: Optional[int] = None
     collection_id: int

+    extracted_text: Optional[str] = None
+    """The text that was extracted for this document.
+
+    This is only returned when getting a specific document, not listing documents.
+
+    Will not be set until after the document is ingested.
+    """
+
     url: str

     ingest_state: Optional[IngestState] = None
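
As the new docstring notes, extracted_text is optional: it is absent from listings and stays unset until ingestion completes. A small illustration of the updated model; the payload values here are invented for the example:

from dewy.document.models import Document

# A newly added document: nothing has been extracted yet.
pending = Document(collection_id=1, url="https://example.com/a.pdf")
assert pending.extracted_text is None

# A response payload for an ingested document also carries the stored text.
ingested = Document.model_validate(
    {
        "id": 1,
        "collection_id": 1,
        "url": "https://example.com/a.pdf",
        "extracted_text": "Skeleton-of-Thought: ...",
    }
)
assert ingested.extracted_text.startswith("Skeleton-of-Thought")
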
2 changes: 1 addition & 1 deletion dewy/document/router.py
@@ -78,7 +78,7 @@ async def get_document(conn: PgConnectionDep, id: PathDocumentId) -> Document:
     # TODO: Test / return not found?
     result = await conn.fetchrow(
         """
-        SELECT id, collection_id, url, ingest_state, ingest_error
+        SELECT id, collection_id, url, ingest_state, ingest_error, extracted_text
         FROM document WHERE id = $1
         """,
         id,
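
The remainder of the handler is collapsed in this view. One plausible way the fetched row becomes a Document response — a hypothetical helper, not necessarily the code in the repository — is to validate the asyncpg record as a mapping:

import asyncpg

from dewy.document.models import Document


def row_to_document(row: asyncpg.Record) -> Document:
    # Hypothetical sketch: asyncpg records support mapping-style access,
    # so dict(row) yields the selected columns ready for validation.
    return Document.model_validate(dict(row))
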
19 changes: 12 additions & 7 deletions dewy/main.py
@@ -1,10 +1,10 @@
 import contextlib
-from typing import AsyncIterator, TypedDict
-from pathlib import Path
 import os
+from pathlib import Path
+from typing import AsyncIterator, TypedDict

-import uvicorn
 import asyncpg
+import uvicorn
 from fastapi import FastAPI
 from fastapi.middleware.cors import CORSMiddleware
 from fastapi.staticfiles import StaticFiles
@@ -19,10 +19,12 @@
 class State(TypedDict):
     pg_pool: asyncpg.Pool

+
 # Resolve paths, independent of PWD
 current_file_path = Path(__file__).resolve()
-react_build_path = current_file_path.parent.parent / 'frontend' / 'dist'
-migrations_path = current_file_path.parent.parent / 'migrations'
+react_build_path = current_file_path.parent.parent / "frontend" / "dist"
+migrations_path = current_file_path.parent.parent / "migrations"

+
 @contextlib.asynccontextmanager
 async def lifespan(_app: FastAPI) -> AsyncIterator[State]:
@@ -66,8 +68,11 @@ async def healthcheck() -> dict[str, str]:
 if settings.SERVE_ADMIN_UI and os.path.isdir(react_build_path):
     logger.info("Running admin UI at http://localhost/admin")
     # Serve static files from the React app build directory
-    app.mount("/admin", StaticFiles(directory=str(react_build_path), html=True), name="static")
+    app.mount(
+        "/admin", StaticFiles(directory=str(react_build_path), html=True), name="static"
+    )

+
 # Function for running Dewy as a script
 def run(*args):
-    uvicorn.run("dewy.main:app", host="0.0.0.0", port=80)
+    uvicorn.run("dewy.main:app", host="0.0.0.0", port=80)
2 changes: 2 additions & 0 deletions migrations/0002_document_text.sql
@@ -0,0 +1,2 @@
+ALTER TABLE document
+ADD COLUMN extracted_text TEXT;
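
The column is nullable and is only populated at ingest time, so rows ingested before this migration keep a NULL extracted_text until they are re-ingested. An illustrative check for such rows, with the connection string as a placeholder:

import asyncio

import asyncpg


async def count_missing_text(dsn: str) -> int:
    conn = await asyncpg.connect(dsn)
    try:
        # Documents ingested before the migration have no stored text yet.
        return await conn.fetchval(
            """
            SELECT count(*) FROM document
            WHERE ingest_state = 'ingested' AND extracted_text IS NULL
            """
        )
    finally:
        await conn.close()


# asyncio.run(count_missing_text("postgresql://localhost/dewy"))
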
35 changes: 25 additions & 10 deletions tests/test_e2e.py
@@ -5,8 +5,8 @@

 from pydantic import TypeAdapter

-from dewy.chunks.models import Chunk, RetrieveRequest, RetrieveResponse
-from dewy.documents.models import AddDocumentRequest
+from dewy.chunk.models import Chunk, RetrieveRequest, RetrieveResponse
+from dewy.document.models import AddDocumentRequest, Document

 SKELETON_OF_THOUGHT_PDF = "https://arxiv.org/pdf/2307.15337.pdf"
@@ -47,6 +47,13 @@ async def list_chunks(client, collection: int, document: int):
     return ta.validate_json(response.content)


+async def get_document(client, document_id: int) -> Document:
+    response = await client.get(f"/api/documents/{document_id}")
+    assert response.status_code == 200
+    assert response
+    return Document.model_validate_json(response.content)
+
+
 async def retrieve(client, collection: int, query: str) -> RetrieveResponse:
     request = RetrieveRequest(
         collection_id=collection, query=query, include_image_chunks=False
@@ -61,33 +68,41 @@ async def retrieve(client, collection: int, query: str) -> RetrieveResponse:

 async def test_e2e_openai_ada002(client):
     collection = await create_collection(client, "openai:text-embedding-ada-002")
-    document = await ingest(client, collection, SKELETON_OF_THOUGHT_PDF)
-    chunks = await list_chunks(client, collection, document)
+    document_id = await ingest(client, collection, SKELETON_OF_THOUGHT_PDF)
+
+    document = await get_document(client, document_id)
+    assert document.extracted_text.startswith("Skeleton-of-Thought")
+
+    chunks = await list_chunks(client, collection, document_id)
     assert len(chunks) > 0
-    assert chunks[0].document_id == document
+    assert chunks[0].document_id == document_id

     results = await retrieve(
         client, collection, "outline the steps to using skeleton-of-thought prompting"
     )
     assert len(results.text_results) > 0
     print(results.text_results)

-    assert results.text_results[0].document_id == document
+    assert results.text_results[0].document_id == document_id
     assert "skeleton" in results.text_results[0].text.lower()


 async def test_e2e_hf_bge_small(client):
     collection = await create_collection(client, "hf:BAAI/bge-small-en")
-    document = await ingest(client, collection, SKELETON_OF_THOUGHT_PDF)
-    chunks = await list_chunks(client, collection, document)
+    document_id = await ingest(client, collection, SKELETON_OF_THOUGHT_PDF)
+
+    document = await get_document(client, document_id)
+    assert document.extracted_text.startswith("Skeleton-of-Thought")
+
+    chunks = await list_chunks(client, collection, document_id)
     assert len(chunks) > 0
-    assert chunks[0].document_id == document
+    assert chunks[0].document_id == document_id

     results = await retrieve(
         client, collection, "outline the steps to using skeleton-of-thought prompting"
     )
     assert len(results.text_results) > 0
     print(results.text_results)

-    assert results.text_results[0].document_id == document
+    assert results.text_results[0].document_id == document_id
     assert "skeleton" in results.text_results[0].text.lower()