Skip to content

Commit

Permalink
Quick & dirty get endpoints for chunks
Browse files Browse the repository at this point in the history
This makes some changes to the chunk data model to make it easier to
return as a resource - I expect this will evolve further as we start
actually populating the DB with chunks.
  • Loading branch information
kerinin committed Jan 26, 2024
1 parent 52b685b commit 6de4e32
Show file tree
Hide file tree
Showing 3 changed files with 123 additions and 28 deletions.
82 changes: 62 additions & 20 deletions dewy/chunks/models.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,32 @@
from typing import Literal, Optional, Sequence, Union
from typing import Literal, Optional, Sequence, Union, Annotated

from pydantic import BaseModel, Field

class TextChunk(BaseModel):
    """A text chunk belonging to a document, as returned by the chunk endpoints.

    The constant ``kind`` value ``"text"`` is the tag that tells this model
    apart from other chunk models when they are combined in a discriminated
    union.
    """

    # Identifier of the chunk itself.
    id: int
    # Identifier of the document this chunk belongs to.
    document_id: int
    # Discriminator tag; always "text" for this model.
    kind: Literal["text"] = "text"

    # Required flag with no default.
    # NOTE(review): presumably marks unprocessed source text as opposed to
    # derived text -- confirm against the code that populates chunks.
    raw: bool
    # Character offsets are optional and fall back to None when not supplied.
    start_char_idx: Optional[int] = Field(
        None, description="Start char index of the chunk."
    )
    end_char_idx: Optional[int] = Field(
        None, description="End char index of the chunk."
    )


class ImageChunk(BaseModel):
    """An image chunk belonging to a document, as returned by the chunk endpoints.

    The constant ``kind`` value ``"image"`` is the tag that tells this model
    apart from other chunk models when they are combined in a discriminated
    union.
    """

    # Identifier of the chunk itself.
    id: int
    # Identifier of the document this chunk belongs to.
    document_id: int
    # Discriminator tag; always "image" for this model.
    kind: Literal["image"] = "image"

    # Each image field below is declared with `...`: the key must be present
    # in the input, although its value is allowed to be None.
    # NOTE(review): the encoding of `image` (e.g. base64) is not visible
    # here -- confirm with the producer of these chunks.
    image: Optional[str] = Field(
        ...,
        description="Image of the node.",
    )
    image_mimetype: Optional[str] = Field(
        ...,
        description="Mimetype of the image.",
    )
    image_path: Optional[str] = Field(
        ...,
        description="Path of the image.",
    )
    image_url: Optional[str] = Field(
        ...,
        description="URL of the image.",
    )

# Discriminated union of the chunk models: the "kind" field of the input
# selects which model is used. This is a typing alias, not a model class.
Chunk = Annotated[
    Union[TextChunk, ImageChunk],
    Field(discriminator="kind"),
]

class RetrieveRequest(BaseModel):
"""A request for retrieving chunks from a collection."""
Expand All @@ -19,50 +44,67 @@ class RetrieveRequest(BaseModel):
# For instance -- if we summarize the text statements, maybe it only includes
# images and tables in the response. But for now, this is a big switch to
# exclude statements entirely.
include_statements: bool = True
"""Whether to include statements in the result.
include_text_chunks: bool = True
"""Whether to include text chunks in the result.
If this is false, no text chunks will be included in the result, although
the summary (if enabled) may include information from the chunks.
"""

include_image_chunks: bool = True
"""Whether to include image chunks in the result.
If this is false, no statements will be included in the result, although
the summary (if enabled) may include information from the statements.
If this is false, no image chunks will be included in the result, although
the summary (if enabled) may include information from the chunks.
"""

include_summary: bool = False
"""Whether to include a generated summary."""

class TextResult(BaseModel):
chunk_id: int
"""The ID of the chunk associated with this result"""

document_id: int
"""The ID of the document associated with this result"""

class BaseChunk(BaseModel):
kind: Literal["text", "raw_text", "image"]
score: float
"""The similarity score of this result."""

score: Optional[float] = None
"""The similarity score of this chunk."""
text: str
"Textual description of the chunk."


class TextChunk(BaseChunk):
kind: Literal["text"] = "text"
raw: bool
text: str = Field(default="", description="Text content of the chunk.")
start_char_idx: Optional[int] = Field(
default=None, description="Start char index of the chunk."
)
end_char_idx: Optional[int] = Field(
default=None, description="End char index of the chunk."
)

class ImageResult(BaseModel):
chunk_id: int
"""The ID of the chunk associated with this result"""

document_id: int
"""The ID of the document associated with this result"""

score: float
"""The similarity score of this result."""

class ImageChunk(BaseChunk):
kind: Literal["image"] = "image"
text: Optional[str] = Field(..., description="Textual description of the image.")
image: Optional[str] = Field(..., description="Image of the node.")
image_mimetype: Optional[str] = Field(..., description="Mimetype of the image.")
image_path: Optional[str] = Field(..., description="Path of the image.")
image_url: Optional[str] = Field(..., description="URL of the image.")


class RetrieveResponse(BaseModel):
"""The response from a chunk retrieval request."""
"""The response from a retrieval request."""

summary: Optional[str]
"""Summary of the retrieved chunks."""

chunks: Sequence[Union[TextChunk, ImageChunk]]
"""Retrieved chunks."""
text_results: Sequence[TextResult]
"""Retrieved text chunks."""

image_results: Sequence[ImageResult]
"""Retrieved image chunks."""
52 changes: 48 additions & 4 deletions dewy/chunks/router.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,52 @@
from fastapi import APIRouter
from typing import Union, Annotated, List

from fastapi import APIRouter, Query, Path

from dewy.common.collection_embeddings import CollectionEmbeddings
from dewy.common.db import PgPoolDep

from .models import RetrieveRequest, RetrieveResponse
from .models import Chunk, RetrieveRequest, RetrieveResponse

# Router for the chunk endpoints; routes registered on it are served under
# the "/chunks" prefix.
router = APIRouter(prefix="/chunks")

@router.get("/")
async def list_chunks(
    pg_pool: PgPoolDep,
    collection_id: Annotated[
        int | None,
        Query(description="Limit to chunks associated with this collection"),
    ] = None,
    document_id: Annotated[
        int | None,
        Query(description="Limit to chunks associated with this document"),
    ] = None,
) -> List[Chunk]:
    """List chunks, optionally filtered by collection and/or document.

    Parameters:
    - pg_pool: Database pool dependency used to run the query.
    - collection_id: When given, only chunks whose document belongs to this
      collection are returned.
    - document_id: When given, only chunks of this document are returned.

    Returns:
    One validated chunk model per matching row.
    """
    # Imported locally so this function does not depend on a module-level
    # import being added elsewhere.
    from pydantic import TypeAdapter

    # `Chunk` is an Annotated[Union[...]] alias rather than a BaseModel, so
    # it has no `model_validate`; a TypeAdapter is the way to validate it.
    chunk_adapter = TypeAdapter(Chunk)

    # The JOIN must be part of the FROM clause, i.e. come before WHERE.
    # A NULL filter collapses to `column = column` via coalesce, which
    # disables that filter.
    # NOTE(review): the SELECT list has no `raw` column, which TextChunk
    # declares as required -- confirm the column list against the models.
    results = await pg_pool.fetch(
        """
        SELECT chunk.id, chunk.document_id, chunk.kind, chunk.text
        FROM chunk
        JOIN document ON document.id = chunk.document_id
        WHERE document.collection_id = coalesce($1, document.collection_id)
        AND chunk.document_id = coalesce($2, chunk.document_id)
        """,
        collection_id,
        document_id,
    )
    return [chunk_adapter.validate_python(dict(result)) for result in results]

# Reusable annotation for an integer chunk ID taken from the URL path.
PathChunkId = Annotated[
    int,
    Path(..., description="The chunk ID."),
]

@router.get("/{id}")
async def get_chunk(
    pg_pool: PgPoolDep,
    id: PathChunkId,
) -> Chunk:
    """Fetch a single chunk by its ID.

    Parameters:
    - pg_pool: Database pool dependency used to run the query.
    - id: ID of the chunk, taken from the URL path.

    Returns:
    The validated chunk model for the matching row.

    Raises:
    HTTPException: With status 404 when no chunk has the given ID.
    """
    # Imported locally so this function does not depend on module-level
    # imports being added elsewhere.
    from fastapi import HTTPException
    from pydantic import TypeAdapter

    # NOTE(review): the SELECT list has no `raw` column, which TextChunk
    # declares as required -- confirm the column list against the models.
    result = await pg_pool.fetchrow(
        """
        SELECT id, document_id, kind, text
        FROM chunk WHERE id = $1
        """,
        id,
    )
    # Without this guard a missing row would reach dict(None) and surface
    # as an unhandled TypeError (HTTP 500) instead of "not found".
    if result is None:
        raise HTTPException(status_code=404, detail=f"No chunk with ID {id}")

    # `Chunk` is an Annotated[Union[...]] alias rather than a BaseModel, so
    # it has no `model_validate`; a TypeAdapter is the way to validate it.
    return TypeAdapter(Chunk).validate_python(dict(result))

@router.post("/retrieve")
async def retrieve_chunks(
Expand All @@ -19,6 +59,10 @@ async def retrieve_chunks(
collection = await CollectionEmbeddings.for_collection_id(
pg_pool, request.collection_id
)
chunks = await collection.retrieve_text_chunks(query=request.query, n=request.n)
text_results = await collection.retrieve_text_chunks(query=request.query, n=request.n)

return RetrieveResponse(summary=None, chunks=chunks)
return RetrieveResponse(
summary=None,
text_results=text_results if request.include_text_chunks else [],
# image_results=image_results if request.include_image_chunks else [],
)
17 changes: 13 additions & 4 deletions dewy/common/collection_embeddings.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from llama_index.schema import TextNode
from loguru import logger

from dewy.chunks.models import TextChunk
from dewy.chunks.models import TextResult
from dewy.collections.models import DistanceMetric
from dewy.collections.router import get_dimensions

Expand Down Expand Up @@ -65,7 +65,8 @@ def __init__(
SELECT
relevant_embeddings.chunk_id AS chunk_id,
chunk.text AS text,
relevant_embeddings.score AS score
relevant_embeddings.score AS score,
chunk.document_id AS document_id
FROM relevant_embeddings
JOIN chunk
ON chunk.id = relevant_embeddings.chunk_id
Expand Down Expand Up @@ -149,7 +150,7 @@ async def retrieve_text_embeddings(
embeddings = [e["chunk_id"] for e in embeddings]
return embeddings

async def retrieve_text_chunks(self, query: str, n: int = 10) -> List[TextChunk]:
async def retrieve_text_chunks(self, query: str, n: int = 10) -> List[TextResult]:
"""Retrieve embeddings related to the given query.
Parameters:
Expand All @@ -168,7 +169,15 @@ async def retrieve_text_chunks(self, query: str, n: int = 10) -> List[TextChunk]
embedded_query,
n)
embeddings = [
TextChunk(raw=True, score=e["score"], text=e["text"])
TextResult(
chunk_id=e["chunk_id"],
document_id=e["document_id"],
score=e["score"],
text=e["text"],
raw=True,
start_char_idx=None,
end_char_idx=None,
)
for e in embeddings
]
return embeddings
Expand Down

0 comments on commit 6de4e32

Please sign in to comment.