diff --git a/app/chunks/models.py b/app/chunks/models.py index c236ccc..3482d69 100644 --- a/app/chunks/models.py +++ b/app/chunks/models.py @@ -2,6 +2,31 @@ from pydantic import BaseModel, Field +class TextChunk(BaseModel): + id: int + document_id: int + kind: Literal["text"] = "text" + + raw: bool + start_char_idx: Optional[int] = Field( + default=None, description="Start char index of the chunk." + ) + end_char_idx: Optional[int] = Field( + default=None, description="End char index of the chunk." + ) + + +class ImageChunk(BaseModel): + id: int + document_id: int + kind: Literal["image"] = "image" + + image: Optional[str] = Field(..., description="Image of the node.") + image_mimetype: Optional[str] = Field(..., description="Mimetype of the image.") + image_path: Optional[str] = Field(..., description="Path of the image.") + image_url: Optional[str] = Field(..., description="URL of the image.") + +Chunk = Annotated[Union[TextChunk, ImageChunk], Field(discriminator='kind')] class RetrieveRequest(BaseModel): """A request for retrieving unstructured (document) results.""" @@ -16,24 +41,52 @@ class RetrieveRequest(BaseModel): # For instance -- if we summarize the text statements, maybe it only includes # images and tables in the response. But for now, this is a big switch to # exclude statements entirely. - include_statements: bool = True - """Whether to include statements in the result. + include_text_chunks: bool = True + """Whether to include text chunks in the result. + + If this is false, no text chunks will be included in the result, although + the summary (if enabled) may include information from the chunks. + """ + + include_image_chunks: bool = True + """Whether to include image chunks in the result. - If this is false, no statements will be included in the result, although - the summary (if enbaled) may include information from the statements. 
+ If this is false, no image chunks will be included in the result, although + the summary (if enabled) may include information from the chunks. """ include_summary: bool = False """Whether to include a generated summary.""" -class TextChunk(BaseModel): - id: Optional[int] = None + + +class RetrieveResult(BaseModel): + chunk_id: int + document_id: int - kind: Literal["text"] = "text" - raw: bool - text: str = Field(default="", description="Text content of the chunk.") + score: float + """The similarity score of this chunk.""" + + text: Optional[str] = Field(..., description="Textual description of the chunk.") + + metadata: Union[TextChunk, ImageChunk] = Field(..., discriminator='kind') + +class TextResult(BaseModel): + chunk_id: int + """The ID of the chunk associated with this result""" + + document_id: int + """The ID of the document associated with this result""" + + score: float + """The similarity score of this result.""" + + text: str + "Textual description of the chunk." + + raw: bool start_char_idx: Optional[int] = Field( default=None, description="Start char index of the chunk." ) @@ -41,32 +94,29 @@ class TextChunk(BaseModel): default=None, description="End char index of the chunk." 
) - -class ImageChunk(BaseModel): - id: Optional[int] = None +class ImageResult(BaseModel): + chunk_id: int + """The ID of the chunk associated with this result""" + document_id: int - kind: Literal["image"] = "image" + """The ID of the document associated with this result""" + + score: float + """The similarity score of this result.""" - text: Optional[str] = Field(..., description="Textual description of the image.") image: Optional[str] = Field(..., description="Image of the node.") image_mimetype: Optional[str] = Field(..., description="Mimetype of the image.") image_path: Optional[str] = Field(..., description="Path of the image.") image_url: Optional[str] = Field(..., description="URL of the image.") -Chunk = Annotated[Union[TextChunk, ImageChunk], Field(discriminator='kind')] - -class RetrieveResult(BaseModel): - score: Optional[float] = None - """The similarity score of this chunk.""" - - chunk: Chunk - """Retrieved chunks.""" - class RetrieveResponse(BaseModel): - """The response from a chunk retrieval request.""" + """The response from a retrieval request.""" summary: Optional[str] """Summary of the retrieved chunks.""" - results: Sequence[RetrieveResult] - """Retrieved results.""" + text_results: Sequence[TextResult] + """Retrieved text chunks.""" + + image_results: Sequence[ImageResult] + """Retrieved image chunks.""" \ No newline at end of file diff --git a/app/chunks/router.py b/app/chunks/router.py index 163fe5a..ba961eb 100644 --- a/app/chunks/router.py +++ b/app/chunks/router.py @@ -7,7 +7,7 @@ from app.common.db import PgConnectionDep, PgPoolDep from app.ingest.store import StoreDep -from .models import Chunk, ImageChunk, RetrieveRequest, RetrieveResponse, TextChunk +from .models import Chunk, ImageChunk, ImageResult, RetrieveRequest, RetrieveResponse, TextResult router = APIRouter(prefix="/chunks") 
@@ -61,35 +61,42 @@ async def retrieve_chunks( # TODO: metadata filters / ACLs ).query(request.query) - statements = [node_to_statement(node) for node in results.source_nodes] + from llama_index.schema import ImageNode, TextNode + # Exclude ImageNode explicitly: in llama_index it appears to derive from TextNode (verify against the pinned version). + text_results = [node_to_text_result(node) for node in results.source_nodes if isinstance(node.node, TextNode) and not isinstance(node.node, ImageNode)] + image_results = [node_to_image_result(node) for node in results.source_nodes if isinstance(node.node, ImageNode)] return RetrieveResponse( summary=results.response, - chunks=statements if request.include_statements else [], + text_results=text_results if request.include_text_chunks else [], + image_results=image_results if request.include_image_chunks else [], ) -def node_to_statement(node: NodeWithScore) -> Union[TextChunk, ImageChunk]: - from llama_index.schema import ImageNode, TextNode +def node_to_text_result(node: NodeWithScore) -> TextResult: + """Build a TextResult from a scored text node.""" + return TextResult( + # TODO: Populate for real + chunk_id=0, + document_id=0, + score=node.score, + + raw=True, + text=node.node.text, + start_char_idx=node.node.start_char_idx, + end_char_idx=node.node.end_char_idx, + ) + +def node_to_image_result(node: NodeWithScore) -> ImageResult: + """Build an ImageResult from a scored image node.""" + return ImageResult( + # TODO: Populate for real + chunk_id=0, + document_id=0, + score=node.score, - if isinstance(node.node, TextNode): - return TextChunk( - raw=True, - score=node.score, - text=node.node.text, - start_char_idx=node.node.start_char_idx, - end_char_idx=node.node.end_char_idx, - ) - elif isinstance(node.node, ImageNode): - return ImageChunk( - score=node.score, - text=node.node.text if node.node.text else None, image=node.node.image, image_mimetype=node.node.image_mimetype, image_path=node.node.image_path, image_url=node.node.image_url, - ) - else: - raise NotImplementedError( 
- f"Unsupported node type ({node.node.class_name()}): {node!r}" - ) + )