diff --git a/dewy/chunks/models.py b/dewy/chunks/models.py index e922634..db2a40d 100644 --- a/dewy/chunks/models.py +++ b/dewy/chunks/models.py @@ -1,7 +1,32 @@ -from typing import Literal, Optional, Sequence, Union +from typing import Literal, Optional, Sequence, Union, Annotated from pydantic import BaseModel, Field +class TextChunk(BaseModel): + id: int + document_id: int + kind: Literal["text"] = "text" + + raw: bool + start_char_idx: Optional[int] = Field( + default=None, description="Start char index of the chunk." + ) + end_char_idx: Optional[int] = Field( + default=None, description="End char index of the chunk." + ) + + +class ImageChunk(BaseModel): + id: int + document_id: int + kind: Literal["image"] = "image" + + image: Optional[str] = Field(..., description="Image of the node.") + image_mimetype: Optional[str] = Field(..., description="Mimetype of the image.") + image_path: Optional[str] = Field(..., description="Path of the image.") + image_url: Optional[str] = Field(..., description="URL of the image.") + +Chunk = Annotated[Union[TextChunk, ImageChunk], Field(discriminator='kind')] class RetrieveRequest(BaseModel): """A request for retrieving chunks from a collection.""" @@ -19,28 +44,37 @@ class RetrieveRequest(BaseModel): # For instance -- if we summarize the text statements, maybe it only includes # images and tables in the response. But for now, this is a big switch to # exclude statements entirely. - include_statements: bool = True - """Whether to include statements in the result. + include_text_chunks: bool = True + """Whether to include text chunks in the result. + + If this is false, no text chunks will be included in the result, although + the summary (if enbaled) may include information from the chunks. + """ + + include_image_chunks: bool = True + """Whether to include image chunks in the result. - If this is false, no statements will be included in the result, although - the summary (if enbaled) may include information from the statements. + If this is false, no image chunks will be included in the result, although + the summary (if enbaled) may include information from the chunks. """ include_summary: bool = False """Whether to include a generated summary.""" +class TextResult(BaseModel): + chunk_id: int + """The ID of the chunk associated with this result""" + + document_id: int + """The ID of the document associated with this result""" -class BaseChunk(BaseModel): - kind: Literal["text", "raw_text", "image"] + score: float + """The similarity score of this result.""" - score: Optional[float] = None - """The similarity score of this chunk.""" + text: str + "Textual description of the chunk." - -class TextChunk(BaseChunk): - kind: Literal["text"] = "text" raw: bool - text: str = Field(default="", description="Text content of the chunk.") start_char_idx: Optional[int] = Field( default=None, description="Start char index of the chunk." ) @@ -48,21 +82,29 @@ class TextChunk(BaseChunk): default=None, description="End char index of the chunk." ) +class ImageResult(BaseModel): + chunk_id: int + """The ID of the chunk associated with this result""" + + document_id: int + """The ID of the document associated with this result""" + + score: float + """The similarity score of this result.""" -class ImageChunk(BaseChunk): - kind: Literal["image"] = "image" - text: Optional[str] = Field(..., description="Textual description of the image.") image: Optional[str] = Field(..., description="Image of the node.") image_mimetype: Optional[str] = Field(..., description="Mimetype of the image.") image_path: Optional[str] = Field(..., description="Path of the image.") image_url: Optional[str] = Field(..., description="URL of the image.") - class RetrieveResponse(BaseModel): - """The response from a chunk retrieval request.""" + """The response from a retrieval request.""" summary: Optional[str] """Summary of the retrieved chunks.""" - chunks: Sequence[Union[TextChunk, ImageChunk]] - """Retrieved chunks.""" + text_results: Sequence[TextResult] + """Retrieved text chunks.""" + + image_results: Sequence[ImageResult] + """Retrieved image chunks.""" \ No newline at end of file diff --git a/dewy/chunks/router.py b/dewy/chunks/router.py index e7f0fa8..16de872 100644 --- a/dewy/chunks/router.py +++ b/dewy/chunks/router.py @@ -1,12 +1,52 @@ -from fastapi import APIRouter +from typing import Union, Annotated, List + +from fastapi import APIRouter, Query, Path from dewy.common.collection_embeddings import CollectionEmbeddings from dewy.common.db import PgPoolDep -from .models import RetrieveRequest, RetrieveResponse +from .models import Chunk, RetrieveRequest, RetrieveResponse router = APIRouter(prefix="/chunks") +@router.get("/") +async def list_chunks( + pg_pool: PgPoolDep, + collection_id: Annotated[int | None, Query(description="Limit to chunks associated with this collection")] = None, + document_id: Annotated[int | None, Query(description="Limit to chunks associated with this document")] = None, +) -> List[Chunk]: + """List chunks.""" + + # TODO: handle collection & document ID + results = await pg_pool.fetch( + """ + SELECT chunk.id, chunk.document_id, chunk.kind, chunk.text + FROM chunk + WHERE document.collection_id = coalesce($1, document.collection_id) + AND chunk.document_id = coalesce($2, chunk.document_id) + JOIN document ON document.id = chunk.document_id + """, + collection_id, + document_id, + ) + return [Chunk.model_validate(dict(result)) for result in results] + +PathChunkId = Annotated[int, Path(..., description="The chunk ID.")] + +@router.get("/{id}") +async def get_chunk( + pg_pool: PgPoolDep, + id: PathChunkId, +) -> Chunk: + # TODO: Test / return not found? + result = await pg_pool.fetchrow( + """ + SELECT id, document_id, kind, text + FROM chunk WHERE id = $1 + """, + id, + ) + return Chunk.model_validate(dict(result)) @router.post("/retrieve") async def retrieve_chunks( @@ -19,6 +59,10 @@ async def retrieve_chunks( collection = await CollectionEmbeddings.for_collection_id( pg_pool, request.collection_id ) - chunks = await collection.retrieve_text_chunks(query=request.query, n=request.n) + text_results = await collection.retrieve_text_chunks(query=request.query, n=request.n) - return RetrieveResponse(summary=None, chunks=chunks) + return RetrieveResponse( + summary=None, + text_results=text_results if request.include_text_chunks else [], + # image_results=image_results if request.include_image_chunks else [], + ) diff --git a/dewy/common/collection_embeddings.py b/dewy/common/collection_embeddings.py index 5ab934e..e252187 100644 --- a/dewy/common/collection_embeddings.py +++ b/dewy/common/collection_embeddings.py @@ -6,7 +6,7 @@ from llama_index.schema import TextNode from loguru import logger -from dewy.chunks.models import TextChunk +from dewy.chunks.models import TextResult from dewy.collections.models import DistanceMetric from dewy.collections.router import get_dimensions @@ -65,7 +65,8 @@ def __init__( SELECT relevant_embeddings.chunk_id AS chunk_id, chunk.text AS text, - relevant_embeddings.score AS score + relevant_embeddings.score AS score, + chunk.document_id AS document_id FROM relevant_embeddings JOIN chunk ON chunk.id = relevant_embeddings.chunk_id @@ -149,7 +150,7 @@ async def retrieve_text_embeddings( embeddings = [e["chunk_id"] for e in embeddings] return embeddings - async def retrieve_text_chunks(self, query: str, n: int = 10) -> List[TextChunk]: + async def retrieve_text_chunks(self, query: str, n: int = 10) -> List[TextResult]: """Retrieve embeddings related to the given query. Parameters: @@ -168,7 +169,15 @@ async def retrieve_text_chunks(self, query: str, n: int = 10) -> List[TextChunk] embedded_query, n) embeddings = [ - TextChunk(raw=True, score=e["score"], text=e["text"]) + TextResult( + chunk_id=e["chunk_id"], + document_id=e["document_id"], + score=e["score"], + text=e["text"], + raw=True, + start_char_idx=None, + end_char_idx=None, + ) for e in embeddings ] return embeddings