diff --git a/dewy/chunks/models.py b/dewy/chunks/models.py index cf02c90..3222db8 100644 --- a/dewy/chunks/models.py +++ b/dewy/chunks/models.py @@ -7,6 +7,6 @@ class TextChunk(BaseModel): id: int document_id: int kind: Literal["text"] = "text" + text: str raw: bool - text: str diff --git a/dewy/chunks/router.py b/dewy/chunks/router.py index 0963ee4..279e2ec 100644 --- a/dewy/chunks/router.py +++ b/dewy/chunks/router.py @@ -82,5 +82,4 @@ async def retrieve_chunks( summary=None, text_results=text_results if request.include_text_chunks else [], image_results=[], - # image_results=image_results if request.include_image_chunks else [], ) diff --git a/dewy/collections/router.py b/dewy/collections/router.py index aa3e1ae..02eb954 100644 --- a/dewy/collections/router.py +++ b/dewy/collections/router.py @@ -1,11 +1,11 @@ from typing import Annotated, List from fastapi import APIRouter, Path +from loguru import logger from dewy.collections.models import Collection, CollectionCreate from dewy.common.collection_embeddings import get_dimensions from dewy.common.db import PgConnectionDep -from loguru import logger router = APIRouter(prefix="/collections") diff --git a/dewy/common/collection_embeddings.py b/dewy/common/collection_embeddings.py index 754a223..12044d0 100644 --- a/dewy/common/collection_embeddings.py +++ b/dewy/common/collection_embeddings.py @@ -1,4 +1,3 @@ -from io import text_encoding from typing import List, Self, Tuple import asyncpg @@ -172,10 +171,7 @@ async def retrieve_text_chunks(self, query: str, n: int = 10) -> List[TextResult async with self._pg_pool.acquire() as conn: logger.info("Executing SQL query for chunks from {}", self.collection_id) embeddings = await conn.fetch( - self._retrieve_chunks, - self.collection_id, - embedded_query, - n + self._retrieve_chunks, self.collection_id, embedded_query, n ) embeddings = [ TextResult( diff --git a/dewy/common/db_migration.py b/dewy/common/db_migration.py index 0fc75bb..4fc99c3 100644 --- 
a/dewy/common/db_migration.py +++ b/dewy/common/db_migration.py @@ -108,7 +108,7 @@ async def _apply_migration( return False elif applied_sha256 is not None: raise ValueError( - f"Migration '{migration_path}' already applied with different SHA. Recreate DB." + f"'{migration_path}' applied with different SHA. Recreate DB." ) else: logger.info("Applying migration '{}'", migration_path) diff --git a/dewy/documents/models.py b/dewy/documents/models.py index 1c81ac3..6f62dca 100644 --- a/dewy/documents/models.py +++ b/dewy/documents/models.py @@ -4,7 +4,7 @@ from pydantic import BaseModel -class CreateRequest(BaseModel): +class AddDocumentRequest(BaseModel): collection_id: Optional[int] = None """The id of the collection the document should be added to. Either `collection` or `collection_id` must be provided""" diff --git a/dewy/documents/router.py b/dewy/documents/router.py index 8e01e45..e194fbc 100644 --- a/dewy/documents/router.py +++ b/dewy/documents/router.py @@ -7,7 +7,7 @@ from dewy.common.db import PgConnectionDep, PgPoolDep from dewy.documents.models import Document -from .models import CreateRequest +from .models import AddDocumentRequest router = APIRouter(prefix="/documents") @@ -21,7 +21,7 @@ async def ingest_document(document_id: int, pg_pool: asyncpg.Pool) -> None: async def add_document( pg_pool: PgPoolDep, background: BackgroundTasks, - req: CreateRequest, + req: AddDocumentRequest, ) -> Document: """Add a document.""" diff --git a/tests/test_e2e.py b/tests/test_e2e.py new file mode 100644 index 0000000..7c39c34 --- /dev/null +++ b/tests/test_e2e.py @@ -0,0 +1,90 @@ +import random +import string +import time +from typing import List + +from pydantic import TypeAdapter + +from dewy.chunks.models import Chunk, RetrieveRequest, RetrieveResponse +from dewy.documents.models import AddDocumentRequest + +SKELETON_OF_THOUGHT_PDF = "https://arxiv.org/pdf/2307.15337.pdf" + + +async def create_collection(client, text_embedding_model: str) -> int: + name = 
"".join(random.choices(string.ascii_lowercase, k=5)) + create_response = await client.put("/api/collections/", json={"name": name}) + assert create_response.status_code == 200 + + return create_response.json()["id"] + + +async def ingest(client, collection: int, url: str) -> int: + add_request = AddDocumentRequest(collection_id=collection, url=url) + add_response = await client.put( + "/api/documents/", data=add_request.model_dump_json() + ) + assert add_response.status_code == 200 + + document_id = add_response.json()["id"] + + # TODO(https://github.com/DewyKB/dewy/issues/34): Move waiting to the server + # and eliminate need to poll. + status = await client.get(f"/api/documents/{document_id}") + while status.json()["ingest_state"] != "ingested": + time.sleep(1) + status = await client.get(f"/api/documents/{document_id}") + + return document_id + +async def list_chunks(client, collection: int, document: int): + response = await client.get("/api/chunks/", params = { + 'collection_id': collection, + 'document_id': document + }) + assert response.status_code == 200 + ta = TypeAdapter(List[Chunk]) + return ta.validate_json(response.content) + +async def retrieve(client, collection: int, query: str) -> RetrieveResponse: + request = RetrieveRequest( + collection_id=collection, query=query, include_image_chunks=False + ) + + response = await client.post("/api/chunks/retrieve", data=request.model_dump_json()) + assert response.status_code == 200 + return RetrieveResponse.model_validate_json(response.content) + + +async def test_e2e_openai_ada002(client): + collection = await create_collection(client, "openai:text-embedding-ada-002") + document = await ingest(client, collection, SKELETON_OF_THOUGHT_PDF) + chunks = await list_chunks(client, collection, document) + assert len(chunks) > 0 + assert chunks[0].document_id == document + + results = await retrieve( + client, collection, "outline the steps to using skeleton-of-thought prompting" + ) + assert 
len(results.text_results) > 0 + print(results.text_results) + + assert results.text_results[0].document_id == document + assert "skeleton" in results.text_results[0].text.lower() + + +async def test_e2e_hf_bge_small(client): + collection = await create_collection(client, "hf:BAAI/bge-small-en") + document = await ingest(client, collection, SKELETON_OF_THOUGHT_PDF) + chunks = await list_chunks(client, collection, document) + assert len(chunks) > 0 + assert chunks[0].document_id == document + + results = await retrieve( + client, collection, "outline the steps to using skeleton-of-thought prompting" + ) + assert len(results.text_results) > 0 + print(results.text_results) + + assert results.text_results[0].document_id == document + assert "skeleton" in results.text_results[0].text.lower()