Skip to content

Commit

Permalink
Refactor retrieve API to make it shallower and more specific to retri…
Browse files Browse the repository at this point in the history
…eval
  • Loading branch information
kerinin committed Jan 25, 2024
1 parent 2e7071f commit e7bca9c
Show file tree
Hide file tree
Showing 2 changed files with 106 additions and 47 deletions.
102 changes: 76 additions & 26 deletions dewy/chunks/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,31 @@

from pydantic import BaseModel, Field

class TextChunk(BaseModel):
    """A text chunk, identified by its own ID and the ID of its document."""

    id: int
    document_id: int
    # Fixed tag; the `Chunk` union in this module discriminates on `kind`.
    kind: Literal["text"] = "text"

    # NOTE(review): the meaning of `raw` is not stated in this module -- confirm.
    raw: bool
    # Character offsets of the chunk; both are optional and default to None.
    start_char_idx: Optional[int] = Field(
        None,
        description="Start char index of the chunk.",
    )
    end_char_idx: Optional[int] = Field(
        None,
        description="End char index of the chunk.",
    )


class ImageChunk(BaseModel):
    """An image chunk, identified by its own ID and the ID of its document."""

    id: int
    document_id: int
    # Fixed tag; the `Chunk` union in this module discriminates on `kind`.
    kind: Literal["image"] = "image"

    # The fields below accept None but declare no default (`Field(...)`),
    # so each one must be passed explicitly when building the model.
    image: Optional[str] = Field(
        ...,
        description="Image of the node.",
    )
    image_mimetype: Optional[str] = Field(
        ...,
        description="Mimetype of the image.",
    )
    image_path: Optional[str] = Field(
        ...,
        description="Path of the image.",
    )
    image_url: Optional[str] = Field(
        ...,
        description="URL of the image.",
    )

Chunk = Annotated[Union[TextChunk, ImageChunk], Field(discriminator='kind')]

class RetrieveRequest(BaseModel):
"""A request for retrieving unstructured (document) results."""
Expand All @@ -16,57 +41,82 @@ class RetrieveRequest(BaseModel):
# For instance -- if we summarize the text statements, maybe it only includes
# images and tables in the response. But for now, this is a big switch to
# exclude statements entirely.
include_statements: bool = True
"""Whether to include statements in the result.
include_text_chunks: bool = True
"""Whether to include text chunks in the result.
If this is false, no text chunks will be included in the result, although
the summary (if enabled) may include information from the chunks.
"""

include_image_chunks: bool = True
"""Whether to include image chunks in the result.
If this is false, no statements will be included in the result, although
the summary (if enabled) may include information from the statements.
If this is false, no image chunks will be included in the result, although
the summary (if enabled) may include information from the chunks.
"""

include_summary: bool = False
"""Whether to include a generated summary."""


class TextChunk(BaseModel):
id: Optional[int] = None


class RetrieveResult(BaseModel):
chunk_id: int

document_id: int
kind: Literal["text"] = "text"
raw: bool

text: str = Field(default="", description="Text content of the chunk.")
score: float
"""The similarity score of this chunk."""

text: Optional[str] = Field(..., description="Textual description of the chunk.")

metadata: Union[TextChunk, ImageChunk] = Field(..., discriminator='kind')

class TextResult(BaseModel):
    """One text result of a retrieval request."""

    # The ID of the chunk associated with this result.
    chunk_id: int

    # The ID of the document associated with this result.
    document_id: int

    # The similarity score of this result.
    score: float

    # Textual description of the chunk.
    text: str

    # NOTE(review): the meaning of `raw` is not stated in this module -- confirm.
    raw: bool
    # Character offsets of the chunk; both are optional and default to None.
    start_char_idx: Optional[int] = Field(
        None,
        description="Start char index of the chunk.",
    )
    end_char_idx: Optional[int] = Field(
        None,
        description="End char index of the chunk.",
    )


class ImageChunk(BaseModel):
id: Optional[int] = None
class ImageResult(BaseModel):
chunk_id: int
"""The ID of the chunk associated with this result"""

document_id: int
kind: Literal["image"] = "image"
"""The ID of the document associated with this result"""

score: float
"""The similarity score of this result."""

text: Optional[str] = Field(..., description="Textual description of the image.")
image: Optional[str] = Field(..., description="Image of the node.")
image_mimetype: Optional[str] = Field(..., description="Mimetype of the image.")
image_path: Optional[str] = Field(..., description="Path of the image.")
image_url: Optional[str] = Field(..., description="URL of the image.")

Chunk = Annotated[Union[TextChunk, ImageChunk], Field(discriminator='kind')]

class RetrieveResult(BaseModel):
score: Optional[float] = None
"""The similarity score of this chunk."""

chunk: Chunk
"""Retrieved chunks."""

class RetrieveResponse(BaseModel):
"""The response from a chunk retrieval request."""
"""The response from a retrieval request."""

summary: Optional[str]
"""Summary of the retrieved chunks."""

results: Sequence[RetrieveResult]
"""Retrieved results."""
text_results: Sequence[TextResult]
"""Retrieved text chunks."""

image_results: Sequence[ImageResult]
"""Retrieved image chunks."""
51 changes: 30 additions & 21 deletions dewy/chunks/router.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from dewy.common.db import PgConnectionDep, PgPoolDep
from dewy.ingest.store import StoreDep

from .models import Chunk, ImageChunk, RetrieveRequest, RetrieveResponse, TextChunk
from .models import Chunk, ImageChunk, RetrieveRequest, RetrieveResponse, RetrieveResult, TextResult, ImageResult

router = APIRouter(prefix="/chunks")

Expand Down Expand Up @@ -47,7 +47,7 @@ async def get_chunk(
@router.post("/retrieve")
async def retrieve_chunks(
store: StoreDep, request: RetrieveRequest
) -> RetrieveResponse:
) -> List[RetrieveResult]:
"""Retrieve chunks based on a given query."""

from llama_index.response_synthesizers import ResponseMode
Expand All @@ -61,35 +61,44 @@ async def retrieve_chunks(
# TODO: metadata filters / ACLs
).query(request.query)

statements = [node_to_statement(node) for node in results.source_nodes]
from llama_index.schema import ImageNode, TextNode
text_results = [node_to_text_result(node) for node in results.source_nodes if isinstance(node.node, TextNode)]
image_results = [node_to_image_result(node) for node in results.source_nodes if isinstance(node.node, ImageNode)]

return RetrieveResponse(
summary=results.response,
chunks=statements if request.include_statements else [],
text_results=text_results if request.include_text_chunks else [],
image_results=image_results if request.include_image_chunks else [],
)


def node_to_statement(node: NodeWithScore) -> Union[TextChunk, ImageChunk]:
from llama_index.schema import ImageNode, TextNode
def node_to_text_result(node: NodeWithScore) -> TextResult:
    """Convert a scored text node into a `TextResult`.

    The previous body constructed `RetrieveResult` with `id=`/`kind=`
    keywords, which matches neither the declared return type nor the fields
    of `TextResult` (`chunk_id`, `document_id`, `score`, `text`, `raw`,
    `start_char_idx`, `end_char_idx`).

    Args:
        node: Scored node; `node.node` is read for `text`,
            `start_char_idx` and `end_char_idx`, and `node.score` for the
            similarity score.

    Returns:
        A `TextResult` built from the node's score, text and char offsets.
    """
    text_node = node.node
    return TextResult(
        # TODO: Populate for real
        chunk_id=0,
        document_id=0,
        score=node.score,
        text=text_node.text,
        raw=True,
        start_char_idx=text_node.start_char_idx,
        end_char_idx=text_node.end_char_idx,
    )

def node_to_image_result(node: NodeWithScore) -> ImageResult:
return RetrieveResult(
score=node.score,
chunk=ImageChunk(
# TODO: Populate for real
id=0,
kind='image',
document_id=0,

if isinstance(node.node, TextNode):
return TextChunk(
raw=True,
score=node.score,
text=node.node.text,
start_char_idx=node.node.start_char_idx,
end_char_idx=node.node.end_char_idx,
)
elif isinstance(node.node, ImageNode):
return ImageChunk(
score=node.score,
text=node.node.text if node.node.text else None,
image=node.node.image,
image_mimetype=node.node.image_mimetype,
image_path=node.node.image_path,
image_url=node.node.image_url,
)
else:
raise NotImplementedError(
f"Unsupported node type ({node.node.class_name()}): {node!r}"
)
)

0 comments on commit e7bca9c

Please sign in to comment.