
Commit cfde7a7

feat: Automated e2e ingest/retrieve tests (#36)
* feat: Automated e2e ingest/retrieve tests
bjchambers authored Jan 29, 2024
1 parent 4993b5a commit cfde7a7
Showing 8 changed files with 97 additions and 11 deletions.
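Note: the new tests in tests/test_e2e.py (below) are written against a client fixture that is not part of this diff. As a minimal sketch only, assuming the FastAPI app is importable as dewy.main.app (not shown in this commit), such a fixture could route requests straight into the app with httpx:

import httpx
import pytest_asyncio

from dewy.main import app  # assumed entry point; not part of this diff


@pytest_asyncio.fixture
async def client():
    # In-process ASGI transport: no server process or network needed.
    transport = httpx.ASGITransport(app=app)
    async with httpx.AsyncClient(transport=transport, base_url="http://test") as c:
        yield c

The tests also embed a real PDF with openai:text-embedding-ada-002 and hf:BAAI/bge-small-en, so running them presumably requires an OpenAI API key and a local download of the Hugging Face model.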
1 change: 1 addition & 0 deletions dewy/chunks/models.py
@@ -7,6 +7,7 @@ class TextChunk(BaseModel):
id: int
document_id: int
kind: Literal["text"] = "text"
text: str

raw: bool
text: str
1 change: 0 additions & 1 deletion dewy/chunks/router.py
@@ -82,5 +82,4 @@ async def retrieve_chunks(
summary=None,
text_results=text_results if request.include_text_chunks else [],
image_results=[],
# image_results=image_results if request.include_image_chunks else [],
)
2 changes: 1 addition & 1 deletion dewy/collections/router.py
@@ -1,11 +1,11 @@
from typing import Annotated, List

from fastapi import APIRouter, Path
from loguru import logger

from dewy.collections.models import Collection, CollectionCreate
from dewy.common.collection_embeddings import get_dimensions
from dewy.common.db import PgConnectionDep
from loguru import logger

router = APIRouter(prefix="/collections")

6 changes: 1 addition & 5 deletions dewy/common/collection_embeddings.py
@@ -1,4 +1,3 @@
from io import text_encoding
from typing import List, Self, Tuple

import asyncpg
@@ -172,10 +171,7 @@ async def retrieve_text_chunks(self, query: str, n: int = 10) -> List[TextResult
async with self._pg_pool.acquire() as conn:
logger.info("Executing SQL query for chunks from {}", self.collection_id)
embeddings = await conn.fetch(
self._retrieve_chunks,
self.collection_id,
embedded_query,
n
self._retrieve_chunks, self.collection_id, embedded_query, n
)
embeddings = [
TextResult(
2 changes: 1 addition & 1 deletion dewy/common/db_migration.py
@@ -108,7 +108,7 @@ async def _apply_migration(
return False
elif applied_sha256 is not None:
raise ValueError(
f"Migration '{migration_path}' already applied with different SHA. Recreate DB."
f"'{migration_path}' applied with different SHA. Recreate DB."
)
else:
logger.info("Applying migration '{}'", migration_path)
2 changes: 1 addition & 1 deletion dewy/documents/models.py
@@ -4,7 +4,7 @@
from pydantic import BaseModel


class CreateRequest(BaseModel):
class AddDocumentRequest(BaseModel):
collection_id: Optional[int] = None
"""The id of the collection the document should be added to. Either `collection` or `collection_id` must be provided"""

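For reference, the renamed request model is exercised by the new tests further down; a minimal usage sketch, with the field names and endpoint taken from tests/test_e2e.py and the collection id chosen arbitrarily:

from dewy.documents.models import AddDocumentRequest

req = AddDocumentRequest(collection_id=1, url="https://arxiv.org/pdf/2307.15337.pdf")
payload = req.model_dump_json()  # sent as the body of PUT /api/documents/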
4 changes: 2 additions & 2 deletions dewy/documents/router.py
@@ -7,7 +7,7 @@
from dewy.common.db import PgConnectionDep, PgPoolDep
from dewy.documents.models import Document

from .models import CreateRequest
from .models import AddDocumentRequest

router = APIRouter(prefix="/documents")

@@ -21,7 +21,7 @@ async def ingest_document(document_id: int, pg_pool: asyncpg.Pool) -> None:
async def add_document(
pg_pool: PgPoolDep,
background: BackgroundTasks,
req: CreateRequest,
req: AddDocumentRequest,
) -> Document:
"""Add a document."""

90 changes: 90 additions & 0 deletions tests/test_e2e.py
@@ -0,0 +1,90 @@
import random
import string
import time
from typing import List

from pydantic import TypeAdapter

from dewy.chunks.models import Chunk, RetrieveRequest, RetrieveResponse
from dewy.documents.models import AddDocumentRequest

SKELETON_OF_THOUGHT_PDF = "https://arxiv.org/pdf/2307.15337.pdf"


async def create_collection(client, text_embedding_model: str) -> int:
name = "".join(random.choices(string.ascii_lowercase, k=5))
create_response = await client.put("/api/collections/", json={"name": name})
assert create_response.status_code == 200

return create_response.json()["id"]


async def ingest(client, collection: int, url: str) -> int:
add_request = AddDocumentRequest(collection_id=collection, url=url)
add_response = await client.put(
"/api/documents/", data=add_request.model_dump_json()
)
assert add_response.status_code == 200

document_id = add_response.json()["id"]

# TODO(https://github.com/DewyKB/dewy/issues/34): Move waiting to the server
# and eliminate need to poll.
status = await client.get(f"/api/documents/{document_id}")
while status.json()["ingest_state"] != "ingested":
time.sleep(1)
status = await client.get(f"/api/documents/{document_id}")

return document_id

async def list_chunks(client, collection: int, document: int):
response = await client.get("/api/chunks/", params = {
'collection_id': collection,
'document_id': document
})
assert response.status_code == 200
ta = TypeAdapter(List[Chunk])
return ta.validate_json(response.content)

async def retrieve(client, collection: int, query: str) -> RetrieveResponse:
request = RetrieveRequest(
collection_id=collection, query=query, include_image_chunks=False
)

response = await client.post("/api/chunks/retrieve", data=request.model_dump_json())
assert response.status_code == 200
return RetrieveResponse.model_validate_json(response.content)


async def test_e2e_openai_ada002(client):
collection = await create_collection(client, "openai:text-embedding-ada-002")
document = await ingest(client, collection, SKELETON_OF_THOUGHT_PDF)
chunks = await list_chunks(client, collection, document)
assert len(chunks) > 0
assert chunks[0].document_id == document

results = await retrieve(
client, collection, "outline the steps to using skeleton-of-thought prompting"
)
assert len(results.text_results) > 0
print(results.text_results)

assert results.text_results[0].document_id == document
assert "skeleton" in results.text_results[0].text.lower()


async def test_e2e_hf_bge_small(client):
collection = await create_collection(client, "hf:BAAI/bge-small-en")
document = await ingest(client, collection, SKELETON_OF_THOUGHT_PDF)
chunks = await list_chunks(client, collection, document)
assert len(chunks) > 0
assert chunks[0].document_id == document

results = await retrieve(
client, collection, "outline the steps to using skeleton-of-thought prompting"
)
assert len(results.text_results) > 0
print(results.text_results)

assert results.text_results[0].document_id == document
assert "skeleton" in results.text_results[0].text.lower()
